The primary goal of my project is to develop a robust and scalable PDF parser that leverages machine learning techniques to extract structured data, specifically aiming to enhance Retrieval-Augmented Generation (RAG) systems using Graph Neural Networks (GNNs) and Reinforcement Learning (RL). The main objectives include:
- Data Preprocessing: Converting PDF pages into images, detecting layout elements, performing OCR to extract text, and creating a graph representation of the document.
- Graph Construction: Representing each layout element (text block, image, table) as a node with features such as text embeddings, bounding box coordinates, and element type. Initializing edges based on spatial proximity.
- GNN Model: Designing a GNN to process the graph, updating node features through message passing and aggregation. Integrating a neural network to predict the existence and weights of edges dynamically.
- Policy Network for RL: Developing a policy network (actor) to output action probabilities for modifying the graph structure and a value network (critic) to estimate the expected reward of the current state.
- RL Training: Defining the environment for the PDF parsing process, collecting trajectories (state, action, reward, next state), computing rewards based on RAG system performance, and updating the policy and value networks using Proximal Policy Optimization (PPO).
- End-to-End Workflow: Iteratively training the model on new PDFs, optimizing the document graph structure and relationships to enhance data extraction accuracy and relevance for downstream RAG tasks.
1. Data Preprocessing:
- PDF Conversion: convert PDF pages into images.
- Layout Detection: detect layout elements using YOLOv10 (https://github.com/moured/YOLOv10-Document-Layout-Analysis).
- OCR: perform OCR on detected elements to extract text.
- Graph Creation: create a graph where each node represents a layout element with features like text embeddings, bounding box coordinates, and element type (a sketch of these steps follows below).
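A minimal sketch of this preprocessing pipeline, assuming the YOLOv10 layout weights from the linked repository are saved locally (the filename below is hypothetical) and that poppler and Tesseract are installed for pdf2image and pytesseract:

```python
from pdf2image import convert_from_path
from ultralytics import YOLO
import pytesseract

def preprocess_pdf(pdf_path: str):
    pages = convert_from_path(pdf_path, dpi=200)   # one PIL image per page
    model = YOLO("yolov10_doclay.pt")              # hypothetical local weights file
    elements = []
    for page_idx, page in enumerate(pages):
        result = model(page)[0]                    # layout detections for this page
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            text = pytesseract.image_to_string(page.crop((x1, y1, x2, y2)))
            elements.append({
                "page": page_idx,
                "bbox": (x1, y1, x2, y2),
                "type": result.names[int(box.cls)],  # e.g. "text", "table"
                "text": text.strip(),
            })
    return elements
```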
2. Graph Construction and GNN Model:
- Construct graphs using layout elements as nodes and spatial proximity as heuristic edges.
- Design a GNN to process the graph, update node features, and dynamically predict edges (graph construction is sketched below; the GNN layers appear in the PPO models section).
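A graph-construction sketch using torch_geometric (knn_graph additionally requires torch_cluster). The embedding model, the element label set, and k=4 spatial neighbours are assumptions:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from sentence_transformers import SentenceTransformer

TYPE_LIST = ["text", "title", "table", "figure"]  # assumed label set

def build_graph(elements, page_w: float, page_h: float) -> Data:
    embedder = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedding model
    text_emb = torch.tensor(embedder.encode([e["text"] for e in elements]))
    boxes = torch.tensor([e["bbox"] for e in elements], dtype=torch.float)
    boxes = boxes / torch.tensor([page_w, page_h, page_w, page_h])  # normalise
    types = torch.nn.functional.one_hot(
        torch.tensor([TYPE_LIST.index(e["type"]) for e in elements]),
        num_classes=len(TYPE_LIST),
    ).float()
    x = torch.cat([text_emb, boxes, types], dim=1)  # node features
    centers = (boxes[:, :2] + boxes[:, 2:]) / 2     # bbox centres as positions
    edge_index = knn_graph(centers, k=4)            # spatial-proximity edges
    return Data(x=x, edge_index=edge_index, pos=centers)
```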
3. Policy Network for RL:
- Develop a policy network to output probabilities for adding or removing edges.
- Create a critic network to estimate rewards based on the current state.
4. RL Training:
- Define an environment encapsulating the PDF parsing process (sketched after this list).
- Collect trajectories by interacting with the environment.
- Compute rewards based on RAG system performance on test questions.
- Update the policy and value networks using PPO.
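A sketch of what such an environment might look like for a torch_geometric graph. `run_rag_eval` is a hypothetical callable that runs the RAG pipeline over the current graph and returns the mean answer score; the reward is the change in that score:

```python
import copy
import torch

def apply_edge_edit(graph, action):
    # Add or remove an undirected edge (i, j) on a torch_geometric Data object.
    kind, i, j = action
    ei = graph.edge_index
    if kind == "add":
        new = torch.tensor([[i, j], [j, i]], dtype=torch.long)  # both directions
        graph.edge_index = torch.cat([ei, new], dim=1)
    elif kind == "remove":
        keep = ~(((ei[0] == i) & (ei[1] == j)) | ((ei[0] == j) & (ei[1] == i)))
        graph.edge_index = ei[:, keep]
    return graph

class PDFGraphEnv:
    """RL environment: states are document graphs, actions edit edges."""
    def __init__(self, graph, run_rag_eval):
        self.initial_graph = graph
        self.run_rag_eval = run_rag_eval  # hypothetical: graph -> mean answer score

    def reset(self):
        self.graph = copy.deepcopy(self.initial_graph)
        self.last_score = self.run_rag_eval(self.graph)
        return self.graph

    def step(self, action):
        # action is ("add", i, j), ("remove", i, j), or ("stop", -1, -1)
        done = action[0] == "stop"
        if not done:
            self.graph = apply_edge_edit(self.graph, action)
        score = self.run_rag_eval(self.graph)
        reward = score - self.last_score  # reward = improvement in RAG score
        self.last_score = score
        return self.graph, reward, done
```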
5. End-to-End Workflow:
- Use caching for repeated data processing steps to enhance efficiency.
- Integrate the LLaMA model (https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) for answer generation, emphasizing the provided context.
- Evaluate the generated answers against provided answers using cosine similarity to score the quality.
- Combine all metrics into a single average score for evaluation, as in the sketch after this list.
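A sketch of the generation and scoring steps. The prompt wording, generation settings, and the MiniLM model used for scoring are assumptions; the Llama checkpoint is the gated Hugging Face model linked above:

```python
import torch
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util

generator = pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
scorer = SentenceTransformer("all-MiniLM-L6-v2")  # assumed scoring model

def answer(question: str, context: str) -> str:
    # Prompt that emphasises the provided context (wording is an assumption).
    prompt = (f"Answer the question using ONLY the context below.\n"
              f"Context:\n{context}\n\nQuestion: {question}\nAnswer:")
    out = generator(prompt, max_new_tokens=128, return_full_text=False)
    return out[0]["generated_text"].strip()

def mean_score(generated: list[str], reference: list[str]) -> float:
    gen = scorer.encode(generated, convert_to_tensor=True)
    ref = scorer.encode(reference, convert_to_tensor=True)
    # Question-wise cosine similarity, averaged into a single score.
    return util.cos_sim(gen, ref).diagonal().mean().item()
```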
Visualization:
- Create visualizations of the graph by drawing bounding boxes and edges on the PDF images.
- Add node indices and ensure edges are distinguishable (see the sketch below).
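One way to implement the overlay with PIL; colours, line widths, and the default font are arbitrary choices:

```python
from PIL import Image, ImageDraw

def draw_graph(page_image: Image.Image, elements, edge_index) -> Image.Image:
    img = page_image.copy()
    draw = ImageDraw.Draw(img)
    centers = []
    for i, e in enumerate(elements):
        x1, y1, x2, y2 = e["bbox"]
        draw.rectangle((x1, y1, x2, y2), outline="red", width=2)  # bounding box
        draw.text((x1 + 2, y1 + 2), str(i), fill="red")           # node index
        centers.append(((x1 + x2) / 2, (y1 + y2) / 2))
    for src, dst in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        draw.line([centers[src], centers[dst]], fill="blue", width=1)  # edge
    return img
```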
Implementation Highlights:
- Caching with Joblib: cache repeated data processing steps.
- Model Integration: use of the LLaMA model with a prompt emphasizing the context for generating answers.
- Evaluation: Compute cosine similarity between generated and provided answers, returning an average score to measure performance.
- Visualization: Enhance visualization functions to draw node indices and edges on the PDF images.
This comprehensive approach ensures efficient, accurate, and scalable extraction of structured data from PDFs, enhancing the performance of RAG systems.
PPO Models Overview
The Proximal Policy Optimization (PPO) models are essential for enhancing the PDF parsing process. Below are the key neural network models used (a combined code sketch follows the list):
1. Add Network:
- Predicts the probability of adding an edge between two nodes.
- Input: Concatenated embeddings of two nodes.
- Layers: Two fully connected layers with ReLU activation followed by a sigmoid output.
2. Remove Network:
- Predicts the probability of removing an edge between two nodes.
- Input: Concatenated embeddings of two nodes.
- Layers: Two fully connected layers with ReLU activation followed by a sigmoid output.
3. Policy Network:
- Outputs action probabilities for adding, removing edges, or stopping the process.
- GCN Layers: Two Graph Convolutional Network (GCN) layers with ReLU activation.
- Output: Fully connected layer that outputs action probabilities, followed by softmax activation.
4. Critic Network:
- Estimates the expected reward of the current state.
- GCN Layers: Two GCN layers with ReLU activation.
- Node Projection: Linear layer to project node embeddings.
- Output: Fully connected layers that take GCN output, projected node embeddings, action one-hot vector, and action probability to predict the value.
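A combined sketch of these networks as described. Hidden sizes (128/64), mean-pooling over nodes, and the 3-action space (add, remove, stop) are assumptions; the Remove Network is omitted since it is architecturally identical to the Add Network:

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool

class AddNet(nn.Module):
    """Probability of adding (or, with separate weights, removing) an edge."""
    def __init__(self, node_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * node_dim, hidden),  # concatenated pair of node embeddings
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),                     # edge probability
        )

    def forward(self, src, dst):
        return self.net(torch.cat([src, dst], dim=-1))

class PolicyNet(nn.Module):
    """Two GCN layers, then a softmax head over the action space."""
    def __init__(self, node_dim: int, hidden: int = 64, num_actions: int = 3):
        super().__init__()
        self.gcn1 = GCNConv(node_dim, hidden)
        self.gcn2 = GCNConv(hidden, hidden)
        self.head = nn.Linear(hidden, num_actions)  # add / remove / stop

    def forward(self, x, edge_index):
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        h = torch.relu(self.gcn1(x, edge_index))
        h = torch.relu(self.gcn2(h, edge_index))
        return torch.softmax(self.head(global_mean_pool(h, batch)), dim=-1)

class CriticNet(nn.Module):
    """Value head over pooled GCN output, projected nodes, and action info."""
    def __init__(self, node_dim: int, hidden: int = 64, num_actions: int = 3):
        super().__init__()
        self.gcn1 = GCNConv(node_dim, hidden)
        self.gcn2 = GCNConv(hidden, hidden)
        self.node_proj = nn.Linear(node_dim, hidden)  # raw-node projection
        # graph embedding + projected nodes + action one-hot + action probability
        self.value_head = nn.Sequential(
            nn.Linear(2 * hidden + num_actions + 1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),                     # state value
        )

    def forward(self, x, edge_index, action_onehot, action_prob):
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        h = torch.relu(self.gcn1(x, edge_index))
        h = torch.relu(self.gcn2(h, edge_index))
        g = global_mean_pool(h, batch)                 # pooled GCN output
        p = global_mean_pool(self.node_proj(x), batch) # pooled projections
        return self.value_head(torch.cat([g, p, action_onehot, action_prob], dim=-1))
```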
PPO Training Steps:
- Collect Trajectory: Interact with the environment to collect state-action-reward sequences.
- Compute Advantages: Calculate advantages and returns using collected trajectories.
- Update Networks: Optimize policy and value networks using the PPO algorithm with collected data (see the sketch following these steps).
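The advantage computation and the clipped PPO objective might look as follows; the gamma, lambda, and clipping values are the usual defaults, assumed here:

```python
import torch

def compute_advantages(rewards, values, gamma=0.99, lam=0.95):
    # Generalised Advantage Estimation over one collected trajectory.
    advantages, gae = [], 0.0
    values = values + [0.0]                      # bootstrap terminal value as 0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages.insert(0, gae)
    returns = [a + v for a, v in zip(advantages, values[:-1])]
    return torch.tensor(advantages), torch.tensor(returns)

def ppo_loss(new_logp, old_logp, advantages, clip_eps=0.2):
    # Clipped surrogate objective from the PPO paper.
    ratio = torch.exp(new_logp - old_logp)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()
```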
The training process and validation for the PPO model involve several key steps to ensure efficient learning and accurate performance evaluation:
Device Configuration: The code checks for available hardware (CUDA, MPS, or CPU) and sets the device accordingly to leverage optimal processing power.
Cache Management: A caching mechanism is implemented using Joblib to store intermediate results and avoid redundant computations. This enhances the efficiency of the training and validation process.
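The device selection and caching described above might be set up like this; the cache directory and the placeholder body of `process_pdf` are illustrative:

```python
import torch
from joblib import Memory

# Prefer CUDA, then Apple MPS, then fall back to CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

memory = Memory("./cache", verbose=0)  # on-disk cache directory

@memory.cache  # results are keyed by the function's arguments
def process_pdf(pdf_path: str):
    ...  # conversion, layout detection, OCR, graph creation
```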
PDF Processing:
- Conversion: Convert PDF pages into images.
- Layout Detection: Detect layout elements in the images.
- OCR: Perform OCR on the detected layout elements.
- Graph Creation: Create graphs from OCR results.
- Graph Merging: Update coordinates and merge graphs for a consolidated representation.
Training Function:
- Episode Training: For each PDF entry, an episode is run where the PPO model interacts with the graph to collect trajectories, compute advantages, and update the networks.
- Caching: Use caching to store processed PDF results, reducing redundant processing.
Validation Function:
- Inference: The model's performance is evaluated by running inference on a held-out PDF entry. This involves generating answers to test questions and computing scores based on cosine similarity.
- Visualization: The graph's state before and after training is visualized to observe the impact of the PPO model.
Training Loop:
- Epoch Iteration: The training loop runs for a defined number of epochs, processing each PDF entry in the training set.
- Score Comparison: After each epoch, the mean score of the generated answers is compared before and after training to evaluate the model's improvement.
In this demonstration, I used three academic papers from arXiv to train and validate the PDF parsing and RAG system. The goal is to showcase the system's ability to extract and utilize structured data to answer specific questions accurately. Below are the details of the documents used and the training process.
Documents and Datasets
Training Documents:
- "Attention Is All You Need": This paper introduces the Transformer model, which has significantly influenced the field of natural language processing.
- "Layer Normalization": This paper presents the technique of layer normalization, which improves the training of deep neural networks.
Validation Document:
- "Neural Machine Translation by Jointly Learning to Align and Translate": This paper discusses an approach to neural machine translation that jointly learns to align and translate, improving translation accuracy.
For each of these documents, I prepared a set of 10 questions and corresponding answers to evaluate the RAG system's performance before and after training.
Results and Visualization
Throughout the training process, the model's performance is evaluated by comparing the mean scores of generated answers before and after training:
- Before Training: The system's initial performance is measured without any training adjustments to the graph structure.
- After Training: Post-training performance is measured to assess the improvements made by the PPO model.
After training for 4 epochs, the system showed notable improvement in answering questions accurately. The validation results indicated a higher mean score for the generated answers post-training, demonstrating the efficacy of the training process. The visualizations of the document graphs highlighted the structural changes and improvements made by the model.
Mean score without training: 0.8213 vs. with training: 0.8402.
While there are many improvements that can be made, the project demonstrates the viability of using reinforcement learning to enhance the PDF parsing and splitting part of a RAG system, as well as the effectiveness of using AMD GPUs for this task. By employing a small subset of PDFs and a defined set of questions and answers, the algorithm effectively learns to optimize the splitting and structuring of PDFs for better data extraction and relevance. This approach can be particularly useful for creating Knowledge Bases (KBs) from large collections of documents, enabling more accurate and efficient RAG performance.
Further Improvements
- Positional Encoding: Currently, the positional information of elements within the PDF is not encoded in the node embeddings. Incorporating positional encoding, similar to the Sinusoidal Positional Embedding used in Stable Diffusion, could enhance the model's understanding of the document layout and improve the accuracy of node relationships (a small sketch follows this list).
- Model Size and Performance: The smallest Llama 3.1 model is prone to hallucinations and repeated answers. While larger models could mitigate this issue, they are computationally intensive. Using the smaller model for quick evaluations during training is effective, but exploring larger models for final evaluations might improve answer quality and reliability.
- Caching Computations: Many computations, especially those involving the RAG process for the same set of questions and context, could be cached to save time and resources. Implementing a more comprehensive caching strategy would reduce redundant calculations and speed up the training and validation processes.
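As a starting point for the positional-encoding idea, a sinusoidal encoding of a node's normalised centre coordinate might look like this; the output dimension is arbitrary:

```python
import math
import torch

def sinusoidal_encode(coord: torch.Tensor, dim: int = 32) -> torch.Tensor:
    # coord: (N,) normalised positions in [0, 1]; returns (N, dim).
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    angles = coord.unsqueeze(1) * freqs.unsqueeze(0)   # (N, half)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

# Encoded x and y could then be concatenated onto each node's feature vector.
```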
These improvements would further enhance the system's performance, making it more robust and efficient for practical applications in various domains requiring structured data extraction from PDFs.