- The main goal of YOLOv5 is a fast-running object detector for production systems, optimized for parallel computation, rather than a low theoretical compute indicator (BFLOP).
- YOLOv4 is implemented in Darknet, while YOLOv5 is implemented in PyTorch; hence v5 may be easier to bring to production, while v4 is where top-accuracy research may continue to progress.
- It is a natural extension of the YOLOv3 PyTorch repository.
- After fully replicating the model architecture and training procedure of v3, ultralytics began to make research improvements alongside repository design changes.
[Update] We have released the code and sources of this hackster.io project here: https://github.com/LogicTronix/Vitis-AI-Reference-Tutorials/tree/main/Quantizing-Compiling-Yolov5-Hackster-Tutorial
Updates in Yolov5:
- PANet updates: new heads, reduced parameters, faster inference, and improved mAP.
- FP16 as the new default, which leads to smaller checkpoints and faster inference.
- CSP updates
It uses the same architecture as that of Yolov4.
- Involves creating features from input images. These features are then fed through a prediction system that draws boxes around objects and predicts their classes.
- The YOLO network consists of three main pieces.
Backbone: A convolutional neural network that aggregates and forms image features at different granularities.
Neck: A series of layers to mix and combine image features and pass them forward to prediction.
Head: Consumes features from the neck and takes the box and class prediction steps.
Main training procedures:
- Data Augmentation: transformations of the base training data to expose the model to a wider range of semantic variation than the training set in isolation, e.g. scaling, color space adjustment, and mosaic augmentation.
- Loss calculation: GIoU, object, and class loss.
Similar to Yolov3, the v5 network predicts the bounding boxes as deviations from a list of anchor box dimensions.
- Conversion from 32-bit precision to 16-bit precision which helps to speed up the inference time of models.
Both YOLOv4 and YOLOv5 implement the CSP Bottleneck to formulate image features. Research credit for this architecture is directed to WongKinYiu and their paper on Cross Stage Partial Networks for the convolutional neural network backbone.
The CSP addresses duplicate gradient problems in other, larger ConvNet backbones, resulting in fewer parameters and fewer FLOPS for comparable performance. This is extremely important to the YOLO family, where inference speed and small model size are paramount.
The CSP models are based on DenseNet. DenseNet was designed to connect layers in convolutional neural networks with the following motivations:
- to alleviate the vanishing gradient problem (it is hard to backprop loss signals through a very deep network);
- to bolster feature propagation;
- to encourage the network to reuse features and;
- to reduce the number of network parameters.
It uses residual and dense blocks to overcome the vanishing gradient problem. However, this introduces redundant gradients, which CSPNet tackles by truncating the gradient flow.
PA-Net Neck
- Both Yolov4 and v5 implement the PA-Net neck for feature aggregation.
- Each P_i represents a feature layer in the CSP backbone.
- Improves the information flow and helps in the proper localization of pixels in the task of mask prediction.
- In v5, the network has been modified by applying the CSPNet strategy.
Spatial Pyramid Pooling
The SPP block performs an aggregation of the information received from its inputs and returns a fixed-length output. It thus has the advantage of significantly increasing the receptive field and segregating the most relevant context features without lowering the speed of the network. This block was used in previous versions of YOLO (yolov3 and yolov4) to separate the most important features from the backbone; however, in YOLOv5 (6.0/6.1), SPPF, a faster variant of the SPP block, is used to improve the speed of the network.
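Below is a minimal PyTorch sketch of the SPPF idea, assuming the commonly described structure (a 1x1 convolution, three stacked 5x5 max-pools, a concatenation, and another 1x1 convolution); the actual Ultralytics module additionally wraps its convolutions with batch norm and an activation:

import torch
import torch.nn as nn

class SPPFSketch(nn.Module):
    def __init__(self, c_in, c_out, k=5):
        super().__init__()
        c_hidden = c_in // 2
        self.cv1 = nn.Conv2d(c_in, c_hidden, 1, 1)
        self.cv2 = nn.Conv2d(c_hidden * 4, c_out, 1, 1)
        self.pool = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        y1 = self.pool(x)   # effective 5x5 receptive field
        y2 = self.pool(y1)  # stacked pools reproduce SPP's 9x9 ...
        y3 = self.pool(y2)  # ... and 13x13 kernels at lower cost
        return self.cv2(torch.cat((x, y1, y2, y3), dim=1))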
Different configuration files
v5 formulates model configuration in .yaml, as opposed to the .cfg files used in Darknet.
The main difference between these two formats is that the .yaml file is condensed: it specifies only the network's distinct blocks and the number of times each block is repeated, rather than listing every layer as the .cfg format does.
Activation function used
YOLOv5 uses the SiLU and Sigmoid activation functions: SiLU in the hidden layers and Sigmoid in the output (detection) layer.
Major Improvement
- The Focus Layer: replaced the first three layers of the network. It reduced the number of parameters, the number of FLOPS, and the CUDA memory required, while improving the speed of the forward and backward passes with only minor effects on the mAP (mean Average Precision).
- Eliminating Grid Sensitivity: It was hard for previous versions of YOLO to detect bounding boxes on image corners, mainly because of the equations used to predict the bounding boxes. The revised equations (sketched after this list) expand the range of the center-point offset from (0, 1) to (-0.5, 1.5), so the offset can easily reach 0 or 1 and box centers can sit on the image's edge. Also, the height and width scaling ratios were unbounded in the previous equations, which could lead to training instabilities; the new squared-sigmoid form bounds the scaling and reduces this problem.
- The running environment: The previous versions of YOLO were implemented in the Darknet framework, which is written in C; YOLOv5 is implemented in PyTorch, giving more flexibility to control the encoded operations.
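As a rough illustration of those revised box equations (the same form reappears in the post-processing code later in this tutorial), here is a minimal Python sketch; the decode_box helper is illustrative and not part of the yolov5 codebase:

import torch

def decode_box(t, grid_xy, anchor_wh, stride):
    # t: raw (tx, ty, tw, th) outputs for one scale; grid_xy: integer cell indices;
    # anchor_wh: anchor sizes in pixels; stride: feature-map stride
    xy, wh = t.sigmoid().split(2, dim=-1)
    xy = (xy * 2 - 0.5 + grid_xy) * stride  # center offset in (-0.5, 1.5): corner pixels become reachable
    wh = (wh * 2) ** 2 * anchor_wh          # scale bounded to (0, 4) x anchor: avoids runaway width/height
    return torch.cat((xy, wh), dim=-1)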
Note: we used Vitis AI 3.0 for the steps below.
Steps:
- Download the pre-trained model from the Ultralytics repo (https://github.com/ultralytics/yolov5).
- Clone the code from the given GitHub repo.
- Arrange the dataset and create a .yaml for the custom dataset.
This .yaml is customized for the BDD dataset. Here, the path to the dataset root directory is given as path, while the image folders for training, validation, and testing are given as train, val, and test. Also, the number of classes in the custom dataset should be specified as nc and the class names as names, in the format sketched below.
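Here is a minimal sketch of such a dataset .yaml, written out from Python; the paths and class list below are hypothetical placeholders, so set nc and names to match your own annotations:

import yaml  # PyYAML

data = {
    "path": "../datasets/bdd",   # hypothetical dataset root directory
    "train": "images/train",     # image folders, relative to 'path'
    "val": "images/val",
    "test": "images/test",
    "nc": 10,                    # number of classes in the custom dataset
    "names": ["car", "bus", "truck", "person", "rider",
              "bike", "motor", "train", "traffic light", "traffic sign"],
}

with open("bdd.yaml", "w") as f:
    yaml.safe_dump(data, f, sort_keys=False)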
- Change the activation function from SiLU to LeakyRelu (use negative slope as 26/256).
Vitis AI does not support the SiLU activation function (the list of PyTorch operators supported by Vitis AI is given HERE). Out of all the supported activations, LeakyReLU with a negative slope of 26/256 gives the best results.
In models/common.py (lines 66 and 147) and experimental.py (line 55):
self.act = nn.LeakyReLU(26/256, inplace=True) # in place of nn.SiLU
- Train the model using the given code.
python train.py
- Since permute and view are not supported in Vitis-AI 3.0, remove the last layer in the detection head and reimplement that last layer in the post-processing section.
Inside yolo.py:
class Detect(nn.Module):
    # YOLOv5 Detect head for detection models
    stride = None  # strides computed during build
    dynamic = False  # force grid reconstruction
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        ...  # original __init__ code, unchanged

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
        return x  # return raw conv outputs; the decoding now happens in post-processing
In detect.py:
# replace the original "pred = model(im, augment=augment, visualize=visualize)" call with:
predi = model(im, augment=augment, visualize=visualize)
pred = postprocessing(predi)
def postprocessing(x):
    # Re-implements the decoding that was removed from the Detect head.
    grid = [torch.empty(0) for _ in range(3)]
    z = []
    anchor_grid = [torch.empty(0) for _ in range(3)]
    stride = torch.tensor([8., 16., 32.], device='cuda:0')
    # pixel-space anchors:
    # anchors = torch.tensor([[10.,13., 16.,30., 33.,23.],[30.,61., 62.,45., 59.,119.],[116.,90., 156.,198., 373.,326.]], device='cuda:0')
    # grid-space anchors (pixel anchors already divided by the stride of each level):
    anchors = torch.tensor([[1.25000, 1.62500, 2.00000, 3.75000, 4.12500, 2.87500],
                            [1.87500, 3.81250, 3.87500, 2.81250, 3.68750, 7.43750],
                            [3.62500, 2.81250, 4.87500, 6.18750, 11.65625, 10.18750]], device='cuda:0')
    anchors = anchors.float().view(3, -1, 2)
    # anchors[0] = anchors[0] / stride[0]
    # anchors[1] = anchors[1] / stride[1]
    # anchors[2] = anchors[2] / stride[2]
    for i in range(3):
        bs, _, ny, nx = x[i].shape  # x(bs,54,ny,nx) to x(bs,3,ny,nx,18) for 13 classes
        x[i] = x[i].view(bs, 3, 18, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        if grid[i].shape[2:4] != x[i].shape[2:4]:
            grid[i], anchor_grid[i] = make_grid(nx, ny, i, anchors, stride)
        xy, wh, conf = x[i].sigmoid().split((2, 2, 13 + 1), 4)
        xy = (xy * 2 + grid[i]) * stride[i]  # xy
        wh = (wh * 2) ** 2 * anchor_grid[i]  # wh
        y = torch.cat((xy, wh, conf), 4)
        z.append(y.view(bs, 3 * nx * ny, 18))
    return (torch.cat(z, 1), x)

def make_grid(nx=20, ny=20, i=0, anchors=None, stride=None, torch_1_10=check_version(torch.__version__, '1.10.0')):
    # check_version comes from the yolov5 utils (utils.general)
    d = anchors[i].device
    t = anchors[i].dtype
    shape = 1, 3, ny, nx, 2  # grid shape
    y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
    yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x)  # indexing='ij' requires torch>=1.10
    grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
    anchor_grid = (anchors[i] * stride[i]).view((1, 3, 1, 1, 2)).expand(shape)
    return grid, anchor_grid
The postprocessing function is implemented based on the original Detect class present in the Ultralytics repository inside models/yolo.py.
Quantizing Yolov5 Pytorch with Vitis AI 3.0
Quantization is a technique to reduce the computational and memory costs of running inference by representing the weights and activations with low-precision data types, like an 8-bit integer (int8), instead of the usual 32-bit floating point (float32).
Reducing the number of bits means the resulting model requires less memory storage, consumes less energy (in theory), and operations like matrix multiplication can be performed much faster with integer arithmetic. It also makes it possible to run models on embedded devices, which sometimes only support integer data types.
We will be using Vitis AI 3.0 (GPU) for quantization; the quantization can also be performed in the CPU docker of Vitis AI. The GPU docker needs to be built locally, while the CPU docker can be pulled and used directly. Performing quantization in the GPU docker is faster than in the CPU docker.
The following inputs should be given for the quantization process. The build directory gives the path to the build folder, i.e. where the quantized xmodel is saved. The quant_mode allows us to specify calib or test mode, both of which are required. The weights argument specifies which .pt model is to be quantized, and the dataset argument gives the root of the small dataset used during the forward pass.
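A hedged sketch of how these inputs might be exposed as command-line arguments is shown below; the exact argument names in the tutorial's Quant.py may differ:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--build", default="build", help="build folder where the quantized xmodel is saved")
parser.add_argument("--quant_mode", default="calib", choices=["calib", "test"], help="calibration or test pass")
parser.add_argument("--weights", default="best.pt", help="float .pt model to quantize")
parser.add_argument("--dataset", default="datasets/bdd_subset", help="root of the small dataset used for the forward pass")
args = parser.parse_args()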
Steps for quantization:
- Load the model
- Quantization → quant model
- Forward pass
- Handle the quant model
To quantize and compile the model using Vitis AI, first load the float model (the .pt or .pth model). For YOLOv5, this can be done with a snippet like the one below.
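A minimal sketch of loading the float model, assuming the standard yolov5 helper attempt_load from models/experimental.py (the keyword for the device differs slightly across yolov5 versions):

import torch
from models.experimental import attempt_load  # from the yolov5 repository

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = attempt_load(args.weights, device=device)  # older yolov5 versions use map_location=device
model.eval()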
The quantization is performed using the torch_quantizer function from the pytorch_nndct library which is provided by Vitis AI.
A quantizer object is created by passing the required parameters to the torch_quantizer function of the pytorch_nndct library.
The torch_quantizer function takes in:
● quant_mode which can be either “calib” or “test”
● model which is the float model that we loaded.
● rand_in which is the dummy input required by the float model: a batch of images, each with 3 color channels (RGB) and a resolution of 640 pixels in height and 640 pixels in width.
● output_dir which is the location where the quantized xmodel is to be saved.
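Put together, the quantizer creation looks roughly like the following sketch (it assumes the standard pytorch_nndct API and the model and arguments loaded earlier):

import torch
from pytorch_nndct.apis import torch_quantizer

rand_in = torch.randn(1, 3, 640, 640)  # dummy input: batch of one RGB image at 640 x 640
quantizer = torch_quantizer(quant_mode, model, (rand_in,), output_dir=output_dir)  # quant_mode and output_dir come from the script's arguments
quant_model = quantizer.quant_model    # quantized model used for the forward pass below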
A forward pass is the process of propagating input data through the network's layers to compute an output. During this forward pass, the input data flows through the layers of the model to produce a prediction or output.
A dry run is performed on the dataset that was passed as an argument, followed by non-maximum suppression to get the predictions; a rough sketch is shown below.
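A hedged sketch of that dry run; the calibration_images iterable is an assumed stand-in for the dataloader built from the dataset argument:

from utils.general import non_max_suppression  # from the yolov5 repository

# run the images through the quantized model so activation statistics can be
# collected (calib mode) or the accuracy checked (test mode)
for im in calibration_images:      # assumed iterable of preprocessed 1x3x640x640 tensors
    predi = quant_model(im)        # dry run through the quantized model
    pred = postprocessing(predi)   # decode boxes as re-implemented earlier
    pred = non_max_suppression(pred[0], conf_thres=0.25, iou_thres=0.45)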
Handle the quant model
The quantization result is handled based on the two different modes of quantization: calibration (quant_mode == 'calib') and testing (quant_mode == 'test').
- Calibration Mode (quant_mode == 'calib'): In this mode, the script configures the quantizer to perform model calibration and export the quantization configuration. Model calibration is a process where you collect statistics about the model's inputs (e.g., min and max values of activations) to later quantize the model effectively.
- quantizer.export_quant_config(): This function exports the quantization configuration obtained during calibration. This configuration is crucial for quantizing the model correctly during deployment.
- Testing Mode (quant_mode == 'test'): In this mode, the script performs quantization testing and exports the quantized model in different formats. This mode can also be used to validate the quantized model's performance before deploying it in a production environment.
- quantizer.export_xmodel(deploy_check=True, dynamic_batch=True): This function exports the quantized model in "xmodel" format. The deploy_check parameter suggests that this export includes checks for deployment readiness. The quantized xmodel is necessary to generate a compiled model for a given target.
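In code, the handling described above amounts to something like this sketch:

if quant_mode == "calib":
    quantizer.export_quant_config()  # write the calibration configuration for the later test pass
elif quant_mode == "test":
    quantizer.export_xmodel(deploy_check=True, dynamic_batch=True)  # export the deployable xmodel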
Here is the Quant.py (updated Quantization script with the above mentioned changes).
Compiling a quantized model
Compiling a quantized model refers to preparing a quantized neural network model for deployment on a target platform or hardware. The quantized model is compiled for the target hardware or accelerator. This involves adapting the model to take full advantage of the hardware's capabilities, such as vectorized instructions, hardware accelerators, or specialized memory layouts. The compiled model is integrated into the inference pipeline of the target application.
Once compiled and integrated, the quantized model is ready for deployment on the target platform or device. It can be used to make predictions or perform computations efficiently, taking advantage of the reduced precision and hardware optimizations.
For PyTorch, the NNDCT quantizer outputs the quantized model directly in the XIR format, i.e. an xmodel ready to be compiled.
Use vai_c_xir to compile the quantized model:
vai_c_xir --xmodel /PATH/TO/quantized.xmodel --arch /PATH/TO/arch.json --output_dir /OUTPUTPATH --net_name netname
Example:
vai_c_xir --xmodel quantized_model/YOLOv5_quantized.xmodel --arch /opt/vitis_ai/compiler/arch/DPUCZDX8G/KV260/arch.json --net_name yolov5_kv260 --output_dir ./KV260
Final layer view in Netron, below is the view of the compiled model for the KV260 board.
A similar approach is also discussed in this Vitis AI Forum thread: quantizing-ultralytics-yolov5-vitis-ai-v35-modifying-forward-function
Reference:
[1]. YOLO v5 model architecture [Explained]: https://iq.opengenus.org/yolov5/
[2]. Object Detection Algorithm — YOLO v5 Architecture: https://medium.com/analytics-vidhya/object-detection-algorithm-yolo-v5-architecture-89e0a35472ef
[3]. What is YOLOv5? A Guide for Beginners.: https://blog.roboflow.com/yolov5-improvements-and-evaluation/
[4]. https://github.com/ultralytics/yolov5/issues/
[5]. Vitis AI - Yolov5 Tutorial: https://xilinx.eetrend.com/blog/2022/100565582.html
Kudos,
Kudos to Anupam@LogicTronix.com for writing this detailed and insightful article/tutorial on "Yolov5 Quantization and Compilation". In the next tutorial we will cover deploying the compiled model on the KV260 FPGA board. Kudos to Dikesh@Logictronix.com for planning this tutorial!
You can find Quant.py and the compiled model in the attachments below!
And you can also check the github repo of this tutorial: https://github.com/LogicTronix/Vitis-AI-Reference-Tutorials/tree/main/Quantizing-Compiling-Yolov5-Hackster-Tutorial
For queries, you can write to above email or at info@logictronix.com!