LSTM networks are a type of recurrent neural network (RNN) designed to overcome the vanishing-gradient problem of traditional RNNs, enabling them to learn and remember over long sequences. Unlike standard RNNs, LSTMs have a more complex cell structure that includes mechanisms called gates. These gates control the flow of information, deciding what to retain in memory and what to discard, making LSTMs particularly suitable for tasks that require an understanding of temporal dynamics.
Model Specifications
- Input Dimension: The input to our model comprises two channels of EOG data: one representing horizontal and the other vertical eye movements. Thus, the input dimension (input_dim) is set to 2. This configuration allows the LSTM to process both dimensions of eye movements simultaneously, capturing the nuanced patterns of gaze direction shifts.
- Hidden Layer Size: The hidden layer size (hidden_dim) dictates the capacity of the LSTM to encode temporal information. While the exact size is adjustable based on the complexity of the task and dataset, a larger hidden size generally increases the model's ability to learn detailed features at the cost of computational efficiency and an increased risk of overfitting.
- Output Structure: Our model yields three outputs: two for the regression task predicting the x and y coordinates of the pointer's position, and one for the classification task identifying blinks. This dual-task capability is achieved through separate output layers: a regressor for pointer coordinates and a classifier for blinks, with the classifier's output passed through a sigmoid activation function to model a probability. A minimal sketch of this architecture is shown after this list.
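The following PyTorch sketch illustrates one way to realize this architecture. It is a minimal sketch, not the exact implementation: the class name EOGLSTM, the hidden size of 64, and the single LSTM layer are assumptions for illustration.

```python
import torch
import torch.nn as nn

class EOGLSTM(nn.Module):
    """LSTM with a regression head (pointer x, y) and a classification head (blink)."""
    def __init__(self, input_dim=2, hidden_dim=64, num_layers=1):
        super().__init__()
        # batch_first=True expects inputs of shape (batch, seq_len, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.regressor = nn.Linear(hidden_dim, 2)   # pointer x, y coordinates
        self.classifier = nn.Linear(hidden_dim, 1)  # blink logit

    def forward(self, x):
        out, _ = self.lstm(x)                 # out: (batch, seq_len, hidden_dim)
        last = out[:, -1, :]                  # hidden state at the last time step
        coords = self.regressor(last)         # (batch, 2) pointer position
        blink = torch.sigmoid(self.classifier(last))  # (batch, 1) blink probability
        return coords, blink
```

Using the last time step's hidden state keeps the sketch simple; pooling over all time steps would be an equally valid design choice.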
To leverage the computational capabilities of the NVIDIA Jetson Orin Nano, we ensure our LSTM model utilizes CUDA, NVIDIA's parallel computing platform and programming model. This approach allows us to take full advantage of the Orin Nano's GPU, significantly enhancing the efficiency of model training and inference processes.
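In PyTorch, this amounts to placing the model and every batch of data on the CUDA device; a minimal sketch (reusing the hypothetical EOGLSTM class above) is:

```python
import torch

# Use the Orin Nano's GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EOGLSTM(input_dim=2, hidden_dim=64).to(device)

# Inputs and targets must be moved to the same device before each forward pass,
# e.g. x = x.to(device) inside the training loop.
```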
To ensure the LSTM model we've developed effectively learns from EOG data, capturing both eye movements and blink detection, we utilize k-fold cross-validation. This method allows for a comprehensive evaluation across different subsets of the data, ensuring the model's generalizability and robustness. Here's an outline of our training process and how we save the best-performing model:
- K-Fold Validation Setup: We employ a 5-fold cross-validation approach, dividing the training and validation dataset into five subsets. This method ensures every data segment serves as a validation set at some point, allowing for a thorough evaluation of the model's capabilities.
- Training Loop: Within each fold, the data is further divided into mini-batches. We train the model for 50 epochs, using the Adam optimizer with a learning rate of 0.001. This optimizer is chosen for its effectiveness in handling sparse gradients and adapting its learning rate during training. A condensed sketch of this loop is given after the results below.
- Loss Functions: For regression tasks (predicting pointer positions), we use the Mean Squared Error Loss (nn.MSELoss()), and for the classification task (blink detection), we use Binary Cross-Entropy Loss (nn.BCELoss()). These loss functions are appropriate for the respective output types and contribute to the model's learning efficacy.
- Model Evaluation and Selection: Throughout the training process, we monitor the loss on the validation set. The model with the lowest validation loss is considered the best performer, and its parameters are saved. This approach ensures that we capture the model version most capable of generalizing from the training data. The best model's results are:
Training Results (Fold 0, Epoch 44):
Loss: 0.05, Accuracy: 0.99, Precision: 0.98, Recall: 0.95, F1: 0.97, ROC-AUC: 1.00
MSE: 0.02, MAE: 0.08, RMSE: 0.14, R²: 0.87
Validation Results:
Validation Loss: 0.02, Accuracy: 1.00, Precision: 1.00, Recall: 1.00, F1: 1.00, ROC-AUC: 1.00
MSE: 0.01, MAE: 0.06, RMSE: 0.11, R²: 0.90
- Final Testing: After identifying and saving the best model through cross-validation, we evaluate its performance on the test set. This final step assesses how well the model can predict eye movements and detect blinks on data it has never seen, providing a measure of its real-world applicability. The results on the test dataset are:
Test Accuracy: 1.0000
Test Precision: 1.0000
Test Recall: 1.0000
Test F1 Score: 1.0000
Test ROC-AUC: 1.0000
Test MSE: 0.0080
Test MAE: 0.0536
Test RMSE: 0.0895
Test R²: 0.9280
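For reference, the sketch below shows one way the 5-fold training and model-selection procedure described above can be wired together in PyTorch. It is a sketch under assumptions: the dataset object yielding (signal, coordinates, blink) triples, the batch size of 32, the equal weighting of the two losses, and the file name best_model.pt are illustrative; EOGLSTM and device refer to the earlier sketches.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

criterion_reg = nn.MSELoss()   # pointer x, y regression
criterion_cls = nn.BCELoss()   # blink classification (expects sigmoid probabilities)

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
best_val_loss = float("inf")

for fold, (train_idx, val_idx) in enumerate(kfold.split(np.arange(len(dataset)))):
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=32, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=32)

    model = EOGLSTM(input_dim=2, hidden_dim=64).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(50):
        # --- training pass over the mini-batches of this fold ---
        model.train()
        for x, coords, blink in train_loader:
            x, coords, blink = x.to(device), coords.to(device), blink.to(device)
            optimizer.zero_grad()
            pred_coords, pred_blink = model(x)
            loss = criterion_reg(pred_coords, coords) + criterion_cls(pred_blink, blink)
            loss.backward()
            optimizer.step()

        # --- validation pass: keep the parameters with the lowest validation loss ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, coords, blink in val_loader:
                x, coords, blink = x.to(device), coords.to(device), blink.to(device)
                pred_coords, pred_blink = model(x)
                val_loss += (criterion_reg(pred_coords, coords)
                             + criterion_cls(pred_blink, blink)).item()
        val_loss /= len(val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pt")
```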
Comments