Large Language Models (LLMs) are revolutionizing the way applications leverage artificial intelligence. In certain scenarios, fine-tuning an LLM on your own data can achieve better results than prompting alone or training smaller, more efficient models. Although fine-tuning is conceptually simple, applying it to LLMs is not: it can be computationally expensive and resource-intensive, making it a challenge for many users. Running and fine-tuning the largest LLMs quickly requires high-performing infrastructure. Most people run, train, and fine-tune LLMs on x64 platforms, while doing so on ARM64/aarch64 platforms is far less common. Today, I am going to show how to fine-tune LLMs on the NVIDIA Jetson AGX Orin Developer Kit. The goal is to demonstrate that fine-tuning is feasible on low-power devices like the Jetson AGX Orin, which features a unified memory architecture between its Arm-based CPU cores and its NVIDIA Ampere architecture GPU, with 64 GB of memory shared between the CPU and GPU.
Let's get started!
Prerequisites

First of all, if you are going to use Llama, Mistral, or any other gated model, you need to log in to your Hugging Face account so that your token can be used to access the gated repository. Hugging Face is a popular platform for sharing and discovering machine learning models, and it also makes your own models accessible to others.
We can do this by running the following command:
huggingface-cli login --token YOUR_TOKEN
Then, it is essential to install the necessary libraries to avoid any potential errors.
I recommend using the BitsandBytes container from Dustin Franklin's Jetson Containers project:
jetson-containers run -v /path/on/host:/path/in/container $(autotag bitsandbytes)
Install the necessary Python packages and libraries required for fine-tuning Mistral:
pip install peft
pip install trl
pip install wandb
To do fine-tuning with Hugging Face, you need to install both the BitsandBytes library and the PEFT library. BitsandBytes takes care of the 4-bit quantization, while PEFT handles the LoRA part of the fine-tuning. The trl library from Hugging Face provides the trainer used to train large language models.
I will use Weights & Biases (wandb) to track our training metrics. Weights & Biases (W&B) is a popular tool for tracking machine learning training runs.
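Before moving on, it is worth confirming that the key packages import correctly inside the container. A minimal sanity check (versions will vary depending on your container image):

# Quick check that the fine-tuning dependencies are importable.
import bitsandbytes, peft, trl, transformers, wandb

print("bitsandbytes:", bitsandbytes.__version__)
print("peft:", peft.__version__)
print("trl:", trl.__version__)
print("transformers:", transformers.__version__)
print("wandb:", wandb.__version__)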
Dataset Preparation

Fine-tuning is one of the most effective ways to adapt a model to task-specific requirements. The process uses custom datasets to improve the performance of a pre-trained model on specific tasks. For this purpose, I will use a subset of the Open Assistant dataset that contains the highest-rated paths in the conversation tree, comprising a total of 9,846 samples. It is available on the Hugging Face platform.
We'll begin by exploring the dataset and looking at a few training data samples. It is essential to have the dataset in the correct format.
The datasets library allows us to easily import ready-to-use datasets from the Hugging Face platform. Let's import the dataset and examine its format.
from datasets import load_dataset
import pandas as pd
dataset = load_dataset("timdettmers/openassistant-guanaco")
df = pd.DataFrame(dataset['train'])
df.head()
Once the dataset is loaded, we can take a look at it to understand what it contains:
                                                text
0  ### Human: Напиши функцию на языке swift, кото...
1  ### Human: Inventa un monstruo altamente compl...
2  ### Human: Escribe un codigo para ESP32 que in...
3  ### Human: What do you think about ChatGPT?###...
4  ### Human: Can you please provide me the names...
The training set consists of 9,846 rows and one column, while the test set consists of 518 rows and one column. Each training sample consists of a human instruction and its corresponding assistant answer.
This provides a quick understanding of the dataset and serves as a sanity check that the data has been loaded correctly. A full list of their datasets can be found here. Feel free to try this experiment with any custom dataset.
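If you want a closer look at a single example, you can print one row in full. A minimal sketch, reusing the dataset object loaded above:

# Show the split sizes and one complete training example.
print(dataset)                      # DatasetDict with 'train' and 'test' splits
print(dataset['train'][0]['text'])  # one full "### Human: ... ### Assistant: ..." conversation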
Fine-tuning the Mistral 7B model

Now that we have an idea of the dataset, it's time to choose the LLM that will serve as our base model for fine-tuning. Mistral models are released under the Apache 2.0 license, allowing developers and researchers to use them without restrictions. Mistral 7B is designed to be fine-tuned for a variety of tasks.
Mistral released two foundational models: Mistral 7B and Mistral 7B Instruct. Mistral 7B is the base foundation model, while Mistral 7B Instruct is a Mistral 7B model that has been fine-tuned for conversation and question answering. In other words, the Instruct model is fine-tuned to follow instructions, so it can carry out tasks and answer questions in a natural way; the base model does not do that out of the box. The Mistral 7B Instruct model is itself a quick demonstration that the base model can easily be fine-tuned to achieve strong performance.
If we want to fine-tune the Mistral 7B Instruct for conversation and question answering, we need to follow the chat template format provided by Mistral, shown in the code block below.
<s>[INST] Instruction [/INST] Model answer</s>[INST] Follow-up instruction [/INST]
However, in this case, we will be using the base model for the fine-tuning process.
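For reference, if you were fine-tuning the Instruct variant instead, recent versions of transformers can build this format for you from the tokenizer's chat template. A minimal sketch (it assumes you have access to the gated mistralai/Mistral-7B-Instruct-v0.1 repository, and the exact spacing and special tokens depend on the tokenizer version):

from transformers import AutoTokenizer

# Format a two-turn conversation with the Instruct tokenizer's chat template.
tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
messages = [
    {"role": "user", "content": "Instruction"},
    {"role": "assistant", "content": "Model answer"},
    {"role": "user", "content": "Follow-up instruction"},
]
print(tok.apply_chat_template(messages, tokenize=False))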
Most large language models are too big to be fully fine-tuned on consumer hardware. Even if your fine-tuning dataset is small, the backpropagation step needs to compute gradients for billions of parameters. For instance, fully fine-tuning a 65-billion-parameter model requires more than 780 GB of GPU memory, the equivalent of ten NVIDIA A100 80 GB GPUs. Full fine-tuning updates all model parameters, and that is simply too expensive. This is where LoRA and QLoRA come into the picture.

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning (PEFT) technique that greatly reduces the number of trainable parameters for downstream tasks: it freezes the original model weights and injects a small number of new trainable weights into the model. Because LoRA does not retrain all model parameters, training is much faster and more memory-efficient, and the resulting adapter weights are small. QLoRA combines quantization with LoRA. It introduces three ideas that reduce memory usage while keeping the same quality: the 4-bit NormalFloat data type, double quantization, and paged optimizers.
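To make this concrete, here is a rough back-of-the-envelope calculation of how few parameters LoRA actually trains with the configuration used later in this post (r=16 applied to the gate_proj, up_proj, and down_proj projections; Mistral 7B has 32 layers, a hidden size of 4096, and an MLP intermediate size of 14336):

# Rough LoRA trainable-parameter estimate for the configuration used below.
# Each adapted weight W (d_out x d_in) gets two small matrices,
# A (r x d_in) and B (d_out x r), i.e. r * (d_in + d_out) extra parameters.
hidden, intermediate, layers, r = 4096, 14336, 32, 16

per_layer = (
    r * (hidden + intermediate)    # gate_proj: 4096 -> 14336
    + r * (hidden + intermediate)  # up_proj:   4096 -> 14336
    + r * (intermediate + hidden)  # down_proj: 14336 -> 4096
)
print(per_layer * layers)  # 28,311,552 -- matches the count reported by the trainer later

That is well under one percent of the roughly seven billion parameters in the base model, which is why the adapters fit comfortably in memory next to the 4-bit quantized weights.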
You can read the QLoRA paper to understand these techniques in more detail.
One of the most common techniques for fine-tuning is Supervised Fine-Tuning (SFT). The usual approach is to load the model in 4-bit and apply a LoRA config to it, then use TRL's SFTTrainer to fine-tune the model.
It is important to first begin with simple and short runs to make sure all the pieces work together! Remember that tuning parameters can be a time-consuming process, and the best configuration might vary depending on your specific dataset and use case.
The code below defines the training setup for fine-tuning the Mistral-7B large language model. The power mode of the NVIDIA Jetson AGX Orin was set to MAXN.
import os
import torch
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer
import wandb
import gc

login(
    token="ADD YOUR TOKEN HERE FROM HUGGING FACE",  # ADD YOUR TOKEN HERE
    add_to_git_credential=True
)

output_dir = "./fine-tuned_mistral"
model_name = "mistralai/Mistral-7B-v0.1"

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = 'left'

# 4-bit quantization (QLoRA) configuration
compute_dtype = getattr(torch, "bfloat16")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map={"": 0}
)
model = prepare_model_for_kbit_training(model)

# LoRA configuration
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.0,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["gate_proj", "up_proj", "down_proj"]
)

dataset = load_dataset("timdettmers/openassistant-guanaco")

# Monitoring the training run with Weights & Biases
wandb.login(key="ADD YOUR TOKEN HERE FROM WANDB")
run = wandb.init(project='Fine tuning of Mistral 7B', job_type="training", anonymous="allow")

training_arguments = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="steps",
    do_eval=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=6,
    per_device_eval_batch_size=4,
    log_level="debug",
    save_steps=100,
    logging_steps=25,
    learning_rate=2e-4,
    eval_steps=10,
    optim='adamw_8bit',
    bf16=True,  # change to fp16 if not using an Ampere GPU
    weight_decay=0.1,
    max_steps=100,
    warmup_ratio=0.01,
    lr_scheduler_type="linear",
    push_to_hub=True,
    report_to="wandb",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
)

trainer.train()

gc.collect()
torch.cuda.empty_cache()

trainer.save_model(output_dir)
output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
per_device_train_batch_size and gradient_accumulation_steps: I set them to 4 and 6, which gives an effective batch size of 24 (4 × 6). If the batch size is too large, it might not fit in memory, leading to an out-of-memory error; if it is too small, the gradient updates may be very noisy, hurting model performance. So we need to find the largest batch size that fits into the Jetson's GPU memory without sacrificing model performance.
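If you are unsure which batch size fits, a quick way to gauge headroom after a few training steps is to query PyTorch's CUDA memory counters (on the Jetson these reflect usage of the shared memory pool); a minimal sketch:

import torch

# Current and peak allocations made by the PyTorch CUDA allocator, in GB.
print(f"allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"peak:      {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

You can also watch tegrastats on the host for a system-wide view of memory usage.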
learning_rate is set to 2e-4, which seemed to work well. lr_scheduler_type is set to linear.
adamw_8bit performs well while consuming much less memory than the original AdamW implementation. If you have enough GPU memory, as my NVIDIA Jetson does, you can replace it with adamw_torch or adamw_32bit.
Unfortunately, the paged optimizers, paged_adamw_8bit and paged_adamw_32bit, failed to work on the NVIDIA Jetson AGX Orin Developer Kit.
I will use bfloat16 precision for faster and more efficient training. It results in significant memory savings compared to standard FP32. bfloat16 requires an Ampere, Ada, or Hopper GPU.
You can check whether your GPU supports bfloat16 via the following code:
import torch
torch.cuda.is_bf16_supported()
True
I set max_steps to 100 just to confirm that the model is learning. You can set max_steps higher initially and examine at which step the model's performance starts to degrade; that is where you'll find the sweet spot for how many steps to run.
Once you've kicked off the training process, it'll give you useful feedback on your training and validation loss, tokens, and iterations per second.
***** Running training *****
Num examples = 9,846
Num Epochs = 1
Instantaneous batch size per device = 4
Total train batch size (w. parallel, distributed & accumulation) = 24
Gradient Accumulation steps = 6
Total optimization steps = 100
Number of trainable parameters = 28,311,552
0/100 [00:00<?, ?it/s]
The time it takes to fine-tune the model will vary depending on the compute resources, number of trainable parameters and hyperparameters we set. PEFT approaches aim to minimize the number of trainable parameters. The size of an LLM is typically measured in billions or trillions of parameters.
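If you want to see this number before launching a full run, PEFT can report it directly. A minimal sketch, assuming the model and peft_config objects from the training script (SFTTrainer applies the adapters itself, so this wrapped copy is only for inspection):

from peft import get_peft_model

# Attach the LoRA adapters and print trainable vs. total parameter counts.
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
# e.g. "trainable params: 28,311,552 || all params: ... || trainable%: ..."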
While this is running, you should sign in to W&B and check metrics there.
We can also implement early stopping to prevent overfitting: monitor the model's performance on a validation set during training and stop when the validation loss plateaus or starts increasing.
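transformers provides an EarlyStoppingCallback for exactly this. A minimal sketch of how it could be added to the trainer defined above (it also requires load_best_model_at_end=True and metric_for_best_model="eval_loss" in TrainingArguments, which the script above does not set):

from transformers import EarlyStoppingCallback

# Stop training once eval_loss has failed to improve for 3 consecutive evaluations.
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))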
tokenizer config file saved in ./fine-tuned_mistral/final_checkpoint/tokenizer_config.json
Special tokens file saved in ./fine-tuned_mistral/final_checkpoint/special_tokens_map.json
wandb: \ 0.030 MB of 0.030 MB uploaded
wandb: Run history:
wandb:                    eval/loss
wandb:                 eval/runtime
wandb:      eval/samples_per_second
wandb:        eval/steps_per_second
wandb:                  train/epoch
wandb:            train/global_step
wandb:              train/grad_norm
wandb:          train/learning_rate
wandb:                   train/loss
wandb:
wandb: Run summary:
wandb: eval/loss 1.14086
wandb: eval/runtime 383.651
wandb: eval/samples_per_second 1.35
wandb: eval/steps_per_second 0.339
wandb: total_flos 4.973538789487411e+16
wandb: train/epoch 0.2437
wandb: train/global_step 100
wandb: train/grad_norm 0.30859
wandb: train/learning_rate 0.0
wandb: train/loss 1.1271
wandb: train_loss 1.14026
wandb: train_runtime 9322.1962
wandb: train_samples_per_second 0.257
wandb: train_steps_per_second 0.011
wandb:
wandb: View run light-wind-2 at: https://wandb.ai/shakhizat/Fine%20tuning%20of%20Mistral%207B/runs/m7w7u4zc
wandb: View project at: https://wandb.ai/shakhizat/Fine%20tuning%20of%20Mistral%207B
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240601_141306-m7w7u4zc/logs
Our model is now fine-tuned. It took nearly two and a half hours to fine-tune for 100 steps on the NVIDIA Jetson AGX Orin with 64 GB of shared memory.
We should now find a fine-tuned_mistral directory containing checkpoint folders. The checkpoint folder with the highest number holds all of our configuration and adapter files. We will need the files in this folder to deploy our model and serve inference requests.
You can check the project on Weights & Biases web interface.
Here are some interesting metrics to analyze. The graph below shows the loss over time for the validation data.
According to the training and validation losses, fine-tuning is going well:
The above model was trained for 100 steps using the adamw_8bit optimizer, learning_rate=2e-4, and per_device_train_batch_size=4. Depending on the task, you can increase max_seq_length; I set it to 512. This is a low value, but it helps keep memory consumption down and seems to be good enough for our custom dataset.
After the initial training, I modified the following parameters and trained the model again:
max_seq_length=2048,
per_device_train_batch_size=1,
gradient_accumulation_steps=2,
per_device_eval_batch_size=2,
During training, monitor the evaluation loss to check if your model is overfitting. Overfitting is likely to occur if it increases, and you should consider stopping the training run.
Here are the results of the second training run:
The fine-tuning process took about 1 hour and 43 minutes.
The graph above displays the loss over time for the validation data.
Fine-tuning is often an iterative process. Based on the validation results, we may need to make further adjustments to the model's architecture, hyperparameters, or training data to improve its performance. Note that our fine-tuning pipeline can still be improved in different ways.
Let's Try Our Model: Inference with the Fine-Tuned Mistral-7B Model

Now that our model has been fine-tuned, we can test it by running inference.
To test the fine-tuned model, we will load it with the transformers library and ask a simple question such as "What are the key differences between Python and C++?".
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Define paths and model name
output_dir = "./fine-tuned_mistral/checkpoint-100"
model_name = "mistralai/Mistral-7B-v0.1"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(output_dir, use_fast=True)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = 'left'

# BitsAndBytes 4-bit configuration
compute_dtype = getattr(torch, "bfloat16")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load the model
model = AutoModelForCausalLM.from_pretrained(
    output_dir, quantization_config=bnb_config, device_map={"": 0}
)

# Define the text you want to infer
input_text = "What are the key differences between Python and C++?"

# Tokenize the input text
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

# Generate text
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=200, do_sample=True, top_k=50, top_p=0.95)

# Decode the generated text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Human:", generated_text)
You can play with your fine-tuned model using this script by trying different prompts and tuning the model temperature.
Here is the answer:
Human: What are the key differences between Python and C++?
### Assistant: Python is a high-level programming language with syntax similar to the English language. It is popular in machine learning, data analysis, and scientific computing. C++ is a low-level, system programming language used in areas such as game development, graphics, and server applications.
Here are some of the key differences between Python and C++:
1. Syntax: Python has a simple and easy-to-read syntax compared to C++'s more verbose syntax. This makes Python more approachable for beginners, but C++ provides more control over low-level operations.
2. Performance: C++ is often faster than Python due to its optimized performance. However, Python's dynamic typing and interpretation can provide more flexibility for prototyping and exploratory data analysis.
3. Portability: Python's code is often more portable and cross-platform
An LLM's prediction behavior is defined not only by the model weights, but also to a large extent by the prompt and inference parameters such as the maximum token length, top-k, top-p, and temperature.
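For example, you can expose these knobs directly in the generate call and observe how the answers change. A minimal sketch reusing the model, tokenizer, and inputs from the inference script above (the parameter values are only illustrative):

# Lower temperature gives more deterministic answers; higher gives more varied ones.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,  # cap on newly generated tokens
        do_sample=True,
        temperature=0.7,     # try values between roughly 0.2 and 1.0
        top_k=50,
        top_p=0.95,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))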
How does our model compare to the raw, untuned version of Mistral-7B? I tested the untuned Mistral-7B model with the same prompt as above.
What are the key differences between Python and C++?
Python is a high-level
As we can see in the above results, there is a significant improvement in the PEFT model as compared to the original model. We now have a fine-tuned version of one of the most powerful open-source LLMs ever released!
Merge the LoRA adapter into the original model

When using QLoRA, we only train adapters, not the full model, which means that only the adapter weights are saved during training. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the base model weights using the merge_and_unload method and then save the result with save_pretrained.
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

# Load the fine-tuned model
output_dir = "./fine-tuned_mistral/checkpoint-100"  # Path where your fine-tuned model is saved
device_map = "auto"  # Adjust this according to your device setup

model = AutoPeftModelForCausalLM.from_pretrained(
    output_dir,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
)

# Merge LoRA and base model
merged_model = model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained("merged_model", safe_serialization=True)

# Load the tokenizer and save it
tokenizer = AutoTokenizer.from_pretrained(output_dir)
tokenizer.save_pretrained("merged_model")

# Push the merged model to the Hugging Face Hub
hf_model_repo = "fine-tuned_mistral"
merged_model.push_to_hub(hf_model_repo)
tokenizer.push_to_hub(hf_model_repo)
If you are ready to share your trained model with the broader community, you can easily export it to the Hugging Face Hub with just one command as shown above.
Fine-tuning using Flash Attention

It is well known that the memory requirements of the attention mechanism scale quadratically with the number of input tokens. FlashAttention-2 is a faster and more efficient implementation of the standard attention mechanism, helping reduce both model training time and inference latency.
This installation is based on the Docker image developed by Dustin Franklin from NVIDIA.
Inside the BitsandBytes container, clone the FlashAttention repository using the following command:
git clone --depth=1 --branch=v2.5.7 https://github.com/Dao-AILab/flash-attention /opt/flash-attention
Next, navigate to the repository directory:
cd /opt/flash-attention
Apply the patch:
git apply /app/patch.diff
Build the package:
export FLASH_ATTENTION_FORCE_BUILD=1
export FLASH_ATTENTION_FORCE_CXX11_ABI=0
export FLASH_ATTENTION_SKIP_CUDA_BUILD=0
export MAX_JOBS=$(nproc)
python3 setup.py --verbose bdist_wheel --dist-dir /opt
Install the built wheel:
pip3 install --no-cache-dir --verbose /opt/flash_attn*.whl
Check that it was installed:
pip3 show flash-attn && python3 -c 'import flash_attn'
This should display information about the installed package, confirming successful installation.
Name: flash_attn
Version: 2.5.7
Summary: Flash Attention: Fast and Memory-Efficient Exact Attention
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email: trid@cs.stanford.edu
License:
Location: /usr/local/lib/python3.10/dist-packages
Requires: einops, ninja, packaging, torch
Required-by:
Modify the training code by adding use_flash_attention_2=True:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map='auto',
    use_cache=False,
    trust_remote_code=True,
    use_flash_attention_2=True
)
Below is the training run with the same hyperparameters and a sequence length of 512.
I did not observe significant performance differences after enabling Flash Attention; the training time is almost the same. According to the Flash Attention authors, it is more memory-efficient, meaning you can train on much larger sequence lengths without running into out-of-memory issues, with memory savings of up to 20x for larger sequence lengths.
Shorter sequences are faster to fine-tune, while a higher maximum sequence length increases GPU memory consumption.
Here are the results of fine-tuning with a maximum sequence length of 2048 and other hyperparameters used previously.
You can check the project on Weights & Biases.
Training time remains the same as when training without Flash Attention. Here is the validation loss graph:
System memory utilization during training can be monitored using Weights & Biases.
It seems that Flash Attention did not improve memory utilization either; this needs further investigation. In any case, the purpose of this project was to check the feasibility of fine-tuning on the NVIDIA Jetson AGX Orin Developer Kit, and that has been accomplished: we have successfully fine-tuned a state-of-the-art language model, leveraging the power of Mistral 7B alongside Hugging Face's libraries. Here is the link to my fine-tuned model.
In my opinion, the NVIDIA Jetson AGX Orin performs best with smaller batch sizes due to its limited memory and computational power. Training times decrease as the batch size decreases, highlighting the importance of matching the batch size to the hardware's capabilities. Larger batches can also trigger larger data transfers between the CPU and GPU, which introduce latency and slow down the training process; these transfers can become a bottleneck on devices with limited bandwidth.
That's it, you can experiment with your hyperparameters to achieve better results.
Thank you for joining this tutorial on fine-tuning Mistral on the NVIDIA Jetson AGX Orin Developer Kit. If you have any questions, please do not hesitate to contact me here.
References:
- Fine-tune a Mistral-7b model with Direct Preference Optimization
- Train your own South Park Fanatic AI with Mistral-7B
- How to Fine-Tune LLMs in 2024 with Hugging Face
- Mistral-7B Instruct Fine-Tuning using Transformers LoRa
- Scripts for fine-tuning Llama2 via SFT and DPO.
- Fine-tuning an LLM on your texts: part 4 β QLoRA
- Whatβs batch size in LLM training or fine-tuning?