In this project, we implemented an FPGA kernel for the BERT large model and integrated it with Apache TVM. The BERT model has driven much of the recent improvement in natural language processing (NLP) applications: it provides rich contextual embeddings that make transfer learning possible in many downstream NLP tasks. It consists of many layers of multi-headed self-attention, which require large matrix multiplications. We implemented one of the most frequent and computationally costly operators in BERT, matrix multiplication, as an FPGA kernel on the Alveo U50 card. For designing the kernel, we used the Vitis HLS (high-level synthesis) tool to increase design and debugging efficiency, and we utilized the HBM banks of the Alveo U50 card to stream weights into the matrix multiplication kernel.
To run the FPGA kernel as part of inference with the pre-trained Huggingface BERT large model, we use the Apache TVM deep learning compiler and graph executor framework. As a result, we execute the FPGA matmul kernel from the Huggingface BERT large model without modifying the pre-trained model.
BERT model analysis
As part of the BERT model analysis, we implemented the Transformer layer, the key component of the BERT model, in C and Python in its simplest form for educational purposes (Github). From this exercise, we identified that the main computational component of BERT is large matrix multiplication. Each layer of the BERT model performs five matrix multiplications to compute multi-headed self-attention (a more detailed description of the model follows in a later section).
We use the Huggingface BERT large model for our experiment. We reviewed model optimizations such as Common Subexpression Elimination (CSE), as described in this link, using Apache TVM, which allows us to analyze and optimize the input model.
Since the model is based on FP32 (32-bit floating point) and we need integer math for the FPGA kernel design, we applied PyTorch dynamic quantization [link] to the BERT large model for the experiment.
Software framework for inference flow
For the deep learning inference flow, we need software that can ingest the trained model and then optimize it for the deployment target hardware. In other words, we need to choose a deep learning compiler and graph executor for the runtime environment. For this purpose, we selected Apache TVM to parse the Huggingface BERT large model, then optimize and execute the model with the FPGA kernel; further description can be found in the following sections.
Design approach
There are two different approaches to designing FPGA kernels: MPE (matrix of processing elements) and SDF (synchronous data flow), as described in the Xilinx link. The MPE approach builds smaller, general-purpose hardware kernels, such as Xilinx DPUs and TVM's VTA. The SDF approach builds kernels specific to dedicated operations; examples are Xilinx's FINN kernels. For our project, we researched and performed initial tests for both approaches in the exploration phase, such as implementing TVM's VTA on the Alveo U50 accelerator (Github). After the initial assessment, we decided on the SDF approach and built a kernel for large matrix multiplication instead of using the MPE approach. Alveo hardware provides a large number of DSP and HBM resources, which allows us to build a large matrix multiplier, the dominant operator in the given model; the MPE approach is more common in resource-constrained environments.
Report organization
The organization of the report is as follows: we first provide an overview of BERT, the Transformer, and self-attention layers. We then describe Apache TVM as a deep learning compiler and graph executor framework, followed by the FPGA kernel design. Experimental results are then presented. Finally, we conclude the report with a summary and future work. Code repositories for the project are listed at the end of the report.
Transformer and BERT
In this section, we briefly summarize the Transformer as the base layer of BERT, as well as the self-attention mechanism, which is the building block of both the Transformer and BERT.
The Transformer model uses a self-attention mechanism to provide contextual embeddings for NLP applications. BERT stacks Transformer layers to provide a powerful contextual embedding encoder and is used as a backbone model for NLP applications.
Prior to the Transformer and BERT, the RNN (recurrent neural network) was commonly used in NLP because it can process sequence inputs. However, RNNs are slow due to their recurrent connections. To address this slowness, faster 1D convolutions can be used, but they cannot capture relationships longer than the kernel size. This is a disadvantage in NLP, because NLP applications need to process long input sequences.
The Transformer network keeps the best of both approaches (RNNs and 1D convolutions): it can model dependencies over the whole input range. The Transformer is a deeply stacked multi-head self-attention network combined with a feedforward network. It has only two non-linearities, the softmax and the ReLU after the feedforward layer; the rest of the model consists of linear transformations and dot products.
In this section, we first describe the self-attention mechanism as the core module of the Transformer, then the Transformer model, and finally BERT.
Self-attention mechanism
Self-attention is a sequence-to-sequence operation: a sequence x_1, x_2, …, x_t goes in and a sequence y_1, y_2, …, y_t comes out.
To compute output vector y_i, self-attention takes a weighted average over all input vectors:
y_i = sum_j w_ij x_j
where j indexes over the whole input sequence and w_ij is derived from a function over x_i and x_j; in the simplest case this is the dot product, w'_ij = x_i · x_j.
Since the dot product can produce values anywhere between negative and positive infinity, we apply a softmax to map the weights into the [0, 1] range and normalize them into a probability distribution over j. The dot product expresses how related two input vectors are, and each output vector is a weighted sum over the whole input sequence, with weights determined by these dot products.
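As a concrete illustration, here is a minimal numpy sketch of this basic self-attention (no learned weights yet); the sequence length and dimension below are arbitrary example values.

import numpy as np

def basic_self_attention(x):
    # x: (t, k) -- a sequence of t input vectors of dimension k
    w = x @ x.T                                    # raw weights w'_ij = x_i . x_j
    w = np.exp(w - w.max(axis=1, keepdims=True))   # softmax over j ...
    w = w / w.sum(axis=1, keepdims=True)           # ... each row becomes a probability distribution
    return w @ x                                   # y_i = sum_j w_ij x_j

y = basic_self_attention(np.random.randn(14, 1024))  # 14 vectors of dimension 1024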
Transformer
Equipped with the self-attention mechanism, a modern transformer network adds three tricks:
- Queries, keys and values: a way to control parameters; each role uses the same input, but transformed by its own weight matrix
- Scaling the dot product: a way to stop the softmax input from growing too large
- Multi-head attention (shown in the multi-headed attention figure below): effectively an ensemble of multiple attention heads, which allows a word to mean different things to different neighbors
Queries, keys and values
For self-attention, instead of using the input sequence directly, the transformer uses k-by-k weight matrices W_q, W_k, W_v to transform each input x_i into the three different parts of the self-attention. With these matrices, we have a way to control parameters, so each input vector can be modified to fit the three roles (query, key, value) it plays. The following equations show how the queries, keys, and values are computed and used.
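In the standard formulation (as in the Transformer paper), including the scaled dot product mentioned above:
q_i = W_q x_i,  k_i = W_k x_i,  v_i = W_v x_i
w'_ij = (q_i · k_j) / sqrt(k)
w_ij = softmax_j(w'_ij)
y_i = sum_j w_ij v_j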
The Transformer can be defined as an architecture built around self-attention. Here is a definition of the Transformer model borrowed from this link:
- Any architecture designed to process a connected set of units where the only interaction between units is through self-attention
There are some variants, but most of them roughly take the form shown in the figure below (transformer block). The order and composition can vary, but the important idea is to combine the following building blocks, as shown in the figure:
- multi-head self-attention
- feedforward network
- normalization layer
- residual connections
The major component of the transformer is the matrix dot product (matrix multiplication), which appears in both the multi-head self-attention and the feedforward network. (This is the reason we decided to design an FPGA kernel for matrix multiplication.)
Layer normalization and the residual connections help train deep networks faster. In terms of modeling performance the transformer is comparable to RNNs, but it executes much faster due to parallel computation on GPUs.
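To make the matmul-heavy structure concrete, here is a minimal transformer block sketch in PyTorch (illustrative only, not the code used in the project); the dimensions match BERT large (k=1024, 16 heads).

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, k, heads):
        super().__init__()
        # Multi-head self-attention: the Q/K/V projections and the output
        # projection are all k-by-k matrix multiplications.
        self.attention = nn.MultiheadAttention(embed_dim=k, num_heads=heads)
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        # Feedforward network: two more large matrix multiplications, with a ReLU in between.
        self.ff = nn.Sequential(nn.Linear(k, 4 * k), nn.ReLU(), nn.Linear(4 * k, k))

    def forward(self, x):
        attended, _ = self.attention(x, x, x)
        x = self.norm1(attended + x)    # residual connection + layer normalization
        x = self.norm2(self.ff(x) + x)  # residual connection + layer normalization
        return x

block = TransformerBlock(k=1024, heads=16)   # BERT large dimensions
y = block(torch.randn(14, 1, 1024))          # (sequence length, batch, embedding)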
BERT (Bidirectional Encoder Representations from Transformers) is a deeply bidirectional model consisting of many stacked Transformer encoder layers. BERT is pre-trained on a large corpus of roughly 800M words. During pre-training, it is asked to predict a random selection of masked-out words in the sentences. After training, BERT has learned how words and sentences are likely to go together.
BERT provides contextualized word embeddings, a newer type of embedding compared to static word or token embeddings. The downside of token-based embeddings such as word2vec and GloVe is that each word has a single representation wherever it appears, whereas a word can mean different things in different contexts. This is why BERT's contextual embeddings can extract richer information from the input corpus.
BERT embeddings are useful for the following:
- Keyword/search expansion, semantic search and information retrieval
- Output vectors from BERT can be used as high-quality feature inputs to downstream models; BERT offers an advantage over embeddings like word2vec because it encodes information about the surrounding words (context)
BERT uses two tasks for bidirectional training:
- Masked Language Model (MLM)
- Two-sentence tasks: given two sentences (A and B), is B likely to be the sentence that follows A, or not?
A pre-trained BERT can be used for various downstream NLP tasks by simply adding a final feedforward layer and fine-tuning it. A few examples are as follows:
- question answering
- sentiment classification
- classifying whether two sentences naturally follow one another
BERT comes in different model sizes, as shown in the table below; we chose the BERT large uncased model for our experiment. There are many different BERT-like models with different sizes and characteristics [link], but for the purpose of this project we use the Huggingface model of the original BERT large.
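For reference, the configurations reported in the original BERT paper are:
Model      | Layers | Hidden size | Attention heads | Parameters
BERT base  | 12     | 768         | 12              | 110M
BERT large | 24     | 1024        | 16              | 340M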
Finally, just as AlexNet started the CNN-based computer vision breakthrough in the 2010s and provided pre-trained models for numerous applications (transfer learning), BERT started an NLP breakthrough by providing a rich contextual model and pre-trained weights that enable transfer learning, and it is now used in many different applications. Improving and optimizing BERT inference is therefore crucial for a wide range of applications.
Deep learning compilers and graph executors
To prepare the deep learning deployment environment, we need deep learning compilers and graph executors in addition to the popular deep learning training frameworks. In this section, we summarize how a deep learning pipeline is implemented in the deployment environment.
To perform inference with a trained deep learning model in production, we need a graph runner that executes the neural network on the deployment target. Deep learning models are typically trained with an existing framework (TensorFlow, PyTorch, MXNet), which produces the model structure and the weight parameters. Next, the model and parameters are converted for the deployment environment by the graph executor's front-end tool.
Graph executors are tools that optimize the ingested model, for example by fusing operators or otherwise rewriting the original graph for the target execution environment. Examples of such tools are Nvidia's TensorRT, Intel DNNL, and the Apache TVM project. During the optimization process, these tools convert the input model to their own IR (intermediate representation), then optimize the model and quantize it if necessary.
Apache TVM
In this project, we selected Apache TVM as our tensor compiler and graph executor to run the BERT model with the FPGA kernel for the following reasons. It provides a flexible front-end model converter, so we can use models trained in different deep learning frameworks. It provides clean, structured IRs (intermediate representations) and tools to optimize the model graph. In addition, it provides runtime support for many target systems. The following figure shows an overview of the TVM project.
FPGA kernel design
In this section, we describe the design of the matrix multiplication (matmul) kernel for the Alveo accelerator card using vitis_hls (Github code link).
Transformer blocks (at the core of each BERT layer) have a large number of constant weights, and the matrix dimensions are identical in every layer. It therefore makes sense to stream those weights from HBM straight into a large array of multiplication units.
The kernel written for this has a throughput of 1024 MACs per clock. For int16 operands that means a 16,384-bit-wide bus (1024 x 16 bits) for both the weights and the input tensor values.
Matrix multiplication has been implemented on FPGAs many times over. In this project we set out to use only vitis_hls for coding the kernels. HLS struggles with such a large number of compute elements; by breaking the design into modules using the DATAFLOW approach, the compiler achieves a reasonable runtime and outputs RTL that synthesizes to a reasonable amount of FPGA resources.
The functional entities within the kernel are as follows (a numpy reference model of the computation follows the list):
- AXI_MASTER to local memory interface for the input tensor
- AXI_MASTER to streaming interface per HBM channel (128 operands)
- Streaming MAC kernel with 128 DSPs.
- Streaming final summation and scaling from the MAC kernels.
- AXI_MASTER that writes back the outputs from the final summation back to HBM.
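The sketch below is a numpy reference model of the computation as we understand it: each HBM bank streams a 128-operand slice of every weight row, the per-bank MAC kernels produce partial dot products, and the final summation stage adds them up. The slice layout shown (contiguous 128-column chunks per bank) is an assumption for illustration; the authoritative layout is in the HLS source.

import numpy as np

Tsize, Nbanks, Nvec = 1024, 8, 14
chunk = Tsize // Nbanks                       # 128 operands per HBM bank

v = np.random.randint(-8, 8, size=(Nvec, Tsize), dtype=np.int16)      # input tensor
w_banks = [np.random.randint(-8, 8, size=(Tsize, chunk), dtype=np.int16)
           for _ in range(Nbanks)]                                     # one weight slice per bank

# Per-bank partial products (what each streaming MAC kernel computes) ...
partial = [w_banks[b].astype(np.int32) @ v[:, b*chunk:(b+1)*chunk].astype(np.int32).T
           for b in range(Nbanks)]
# ... followed by the final summation stage.
out = sum(partial)                            # shape (Tsize, Nvec), int32 accumulation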
1024 DSPs running at 300 MHz, each performing one multiply-accumulate (two operations) per cycle, give roughly 0.6 TOPS in this one kernel.
For comparison, the Xilinx DPU IP uses two distinct clocks, with a 2x clock to run the DSPs, implementing what Xilinx calls "double-pumped DSPs". That is a feature only RTL designers can accomplish today. In a similar vein, with RTL design the DSP48 can be made to do two 8-bit operations per clock cycle [5]. Getting decent output from HLS with a reasonable compile time was a challenge, and in the end we completed only the matmul kernel. There is certainly space left in the device to add a softmax() kernel and make an entire BERT layer run on-chip.
The memory available through HBM is very large and can easily hold the 304M parameters of the BERT large model (at int16 this is roughly 600 MB, well within the 8 GB of HBM on the U50). Keeping the weights resident reduces per-inference host-to-accelerator communication to just the input tensor.
Integrating the FPGA kernel with TVM
Apache TVM provides flexible ways to target runtime hardware and operations. It can partition the graph and apply different optimizations. TVM defines commonly used operations in TOPI (TVM Operator Inventory) for existing operators and hardware; for these operators it defines their tensor computation (fcompute) and scheduling templates with priorities, which can be used for automated scheduling such as AutoTVM.
For new hardware and operators, we need to add the operator definitions and scheduling templates; this is the bare minimum. However, to fully utilize the underlying hardware and the host runtime, we need more support from the framework. TVM currently provides several mechanisms, the most notable of which is BYOC (Bring Your Own Codegen), which allows hardware vendors to integrate their own compilers or intrinsic libraries into TVM's heterogeneous runtime.
Alternatively, there is a simpler way: a single TVM operator can be defined for a specific operation and implemented as an external call, known as an external tensor function. We decided to use this approach because of its simplicity and the project's time constraints, while being aware that it may not give the most efficient runtime, since TVM only sees a single opaque operator. For calling the FPGA kernel, we decided to use PynQ calls rather than building a library with OpenCL, again due to time constraints.
First, we added a rule to the x86 runtime operator strategy: for the specified data types and shapes, it uses the FPGA kernel we designed, topi.xilinx_fpga.dense_nopack, instead of the original operator, topi.x86.dense_nopack. With the strategy below, 96 out of 146 matmuls are scheduled onto the FPGA kernel. The code below can be found in this Github repo.
@dense_strategy.register("cpu")
def dense_strategy_cpu(attrs, inputs, out_type, target):
    """dense x86 strategy"""
    strategy = _op.OpStrategy()
    # Default x86 implementation.
    strategy.add_implementation(
        wrap_compute_dense(topi.x86.dense_nopack),
        wrap_topi_schedule(topi.x86.schedule_dense_nopack),
        name="dense_nopack.x86",
        plevel=10,
    )
    # Dispatch to the FPGA kernel only for the data types and shapes it supports;
    # the higher plevel makes TVM prefer it over the x86 implementation.
    # use_pynq is a flag defined elsewhere in our fork.
    m, k = get_const_tuple(inputs[0].shape)
    n, _ = get_const_tuple(inputs[1].shape)
    dtype = inputs[0].dtype
    if use_pynq and dtype == "int16" and inputs[1].dtype == "int16" \
            and out_type.dtype == "int32" and k <= 3072 and n <= 1024 and m == 14:
        strategy.add_implementation(
            wrap_compute_dense(topi.xilinx_fpga.dense_nopack),
            wrap_topi_schedule(topi.generic.schedule_extern),
            name="dense_nopack.xilinx_fpga",
            plevel=11,
        )
    return strategy
Next we define the external tensor function that handles the packed-function argument processing and issues the PynQ call. The function below ends up calling tvm.contrib.xilinx_matmul_pynq, which makes calls similar to those described in the matmul kernel test section below. The following code can be found on Github.
@autotvm.register_topi_compute("dense_nopack.xilinx_fpga")
def dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
    """Compute dense without packing, via the external FPGA matmul."""
    if out_dtype is None:
        out_dtype = data.dtype
    M, K = get_const_tuple(data.shape)
    N, _ = get_const_tuple(weight.shape)
    # te.extern wraps the packed function call so it appears as a normal tensor op.
    CC = te.extern(
        (M, N),
        [data, weight],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.xilinx_matmul_pynq", ins[0], ins[1], outs[0]),
        dtype=out_dtype,
        name="matmul_pynq",
    )
    if bias is not None:
        C = te.compute((M, N), lambda i, j: CC[i, j] + bias[j].astype(out_dtype))
        return C
    return CC
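The packed function itself has to be registered with the TVM runtime under the name used above. A minimal sketch of how this might look on the host side is shown below; run_fpga_matmul is a hypothetical wrapper around the PynQ calls shown in the matmul kernel test section, not a function from our repository.

import tvm

@tvm.register_func("tvm.contrib.xilinx_matmul_pynq")
def xilinx_matmul_pynq(data, weight, out):
    # Convert the TVM NDArrays to numpy, run the kernel through PynQ, copy back.
    a = data.asnumpy()
    b = weight.asnumpy()
    c = run_fpga_matmul(a, b)          # hypothetical wrapper around the PynQ calls
    out.copyfrom(c.astype(out.dtype))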
In this section, we have described TVM's available methods for integrating new hardware kernels and the method we chose for the project, the external tensor function.
Experimental Summary
Matmul kernel test
To test the FPGA matmul kernel, we use the PynQ interface in a Jupyter notebook. As shown in the code excerpt below, we allocate the input and weight buffers and run the kernel. The following code is on Github.
# allocate buffers: the input tensor, one weight slice per HBM bank, and the output
source_v = pynq.allocate(shape=(Nvec, Tsize), dtype=np.int16, target=ol.HBM14)
hbm_banks = [ol.HBM0, ol.HBM4, ol.HBM8, ol.HBM12,
             ol.HBM16, ol.HBM20, ol.HBM24, ol.HBM26]
source_w = [pynq.allocate(shape=(Nmat*Tsize, Tsize//Nbanks),
                          dtype=np.int16, target=bank)
            for bank in hbm_banks]
outbuf = pynq.allocate((Tsize*Nmat, Nvec), dtype=np.int32, target=ol.HBM14)

# sync to device
for i in range(NUM_HBM_CHAN):
    source_w[i].sync_to_device()
source_v.sync_to_device()

# call the kernel
ol.feeder_1.call(
    source_v,
    source_w[0], source_w[1], source_w[2], source_w[3],
    source_w[4], source_w[5], source_w[6], source_w[7],
    outbuf,
    Nmat,
    Nvec,  # sequence length
    0)

# receive the output
outbuf.sync_from_device()
In the same Jupyter notebook environment, the following table shows the computation time for various matrix multiplication sizes; each measurement is averaged over 1000 runs.
BERT model experiment with TVM
In this section we describe the BERT large model experiment with the FPGA kernel and TVM using code excerpts; the full code can be found in the listed code repository. First we load the Huggingface BERT large model.
model_org = BertForSequenceClassification.from_pretrained('bert-large-uncased')
Then we quantize the model using PyTorch dynamic quantization.
quantized_model = torch.quantization.quantize_dynamic(
    model_org, {torch.nn.Linear}, dtype=torch.qint8
)
We then create a traced model with TorchScript so that we can ingest the model into TVM.
traced_quantized_model = torch.jit.trace(quantized_model, [tokens_tensor, segments_tensors])
q_mod_bert, q_params_bert = tvm.relay.frontend.pytorch.from_pytorch(
    traced_quantized_model, shape_list, default_dtype="int8")
Then we use TVM to build the internal graph representation for its runtime, which in turn calls the FPGA kernel during execution.
with tvm.transform.PassContext(opt_level=3):
    q_graph, q_lib, q_params = tvm.relay.build(q_mod_bert,
                                               target=target,
                                               target_host=target_host,
                                               params=q_params_bert)

q_module_tvm = tvm.contrib.graph_runtime.create(q_graph, q_lib, ctx)
q_module_tvm.set_input("input_ids", tt_a)
q_module_tvm.set_input("attention_mask", st_a)
q_module_tvm.set_input(**q_params_bert)
q_module_tvm.run()
o0 = q_module_tvm.get_output(0)
Experimental results
We use an input sequence of length 14 for the experiment, with the Huggingface BERT large model quantized using PyTorch; the FPGA kernel computes products of 14x1024 and 1024x1024 matrices. The following table shows the measured latencies; the experimental code is in this Github. The first row is the baseline: a single inference of the 14-token sequence with TVM running entirely on the CPU. The second row (#2) is the latency when the TVM model calls the FPGA kernel.
The next table breaks down the measured and unmeasured latencies of the FPGA kernel run. The first row was obtained by measuring the individual kernel runs plus moving the input and output matrices to and from FPGA HBM. The second row shows the overhead of the TVM external tensor function. The third row shows the total measured overhead (TVM external function plus kernel call and data movement), which is 163 msec. The final row is the remaining part of the difference between the CPU-only TVM run and the FPGA TVM run after subtracting the measured overhead. We conjecture that this remaining 65 msec comes from overhead inside the external tensor function, such as converting between TVM data arrays and numpy arrays on each PynQ operation. Since the weight matrices do not change across repeated inferences, pre-creating all weight matrices and moving them to the Alveo HBM once would further reduce the FPGA inference latency; additional performance profiling could reveal where the remaining time goes. A sketch of this optimization follows.
Summary
We believe that improving BERT inference time is crucial for many NLP applications. With this goal, we analyzed the BERT model and built a working end-to-end development flow. We implemented one of its main computations, the large matrix multiplication used in self-attention, as a Vitis HLS kernel. In addition to the FPGA kernel, software support for model optimization and runtime is critical in a deep learning inference flow; to achieve this, we used Apache TVM as our graph optimizer and graph executor, and we used TVM's external tensor function to call the FPGA kernel from the TVM runtime.
With the method used in this project, the overhead of going from the TVM external tensor function to the PynQ call means we may not see a latency reduction over the baseline. However, we believe that tighter integration of the FPGA kernel with Apache TVM, for example using the BYOC framework to optimize and fuse the graph to reduce overhead, together with more 'native' calls (OpenCL) instead of PynQ, will achieve reduced latency for efficient BERT inference.
Future work
Here are some future directions for the project. Tighter integration of the newly added FPGA kernel operator and schedule with Apache TVM would let the tool optimize the model, instead of relying on the simpler external tensor function used in this project. With HBM at our disposal, there are a number of opportunities to do more work on the FPGA before returning results to the host, which would greatly reduce the overhead currently incurred. Now that the system works end-to-end, the low-hanging fruit is to replace the PynQ Python calls with native OpenCL C++ code. Apache TVM could also be used to analyze the dataflow and schedule data movement and kernel execution so that memory transfer time is hidden.
Another idea is to add features to Apache TVM to produce multiple HLS code templates based on the input model, similar to the runtime code generation currently available for CPUs and GPUs. Together with TVM's auto-tuning tools (AutoTVM and its newer variants), candidate HLS implementations could be evaluated and the best one selected for the inference flow.
References
- Attention Is All You Need, Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, https://arxiv.org/abs/1706.03762
- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova, https://arxiv.org/abs/1810.04805
- TRANSFORMERS FROM SCRATCH, http://peterbloem.nl/blog/transformers
- RNN (recurrent neural network), https://en.wikipedia.org/wiki/Recurrent_neural_network
- The DSP architecture in UltraScale and UltraScale+ devices (Wu), https://www.semanticscholar.org/paper/The-DSP-architecture-in-UltraScale-TM-and-%2B-TM-and-Wu/3ac64259d37ad76c640333bf8cfccd36bb9bc4f0
- BERT model analysis with TVM, https://lernapparat.de/transformers-pytorch-tvm/
- Huggingface transformer model, https://huggingface.co/
- PyTorch BERT quantization, https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html
- Xilinx and TVM, https://sampl.cs.washington.edu/tvmconf/slides/2019/Elliott-Delaye-Xilinx.pdf
- Different types of Bert model, https://medium.com/huggingface/distilbert-8cf3380435b5
- TVM external tensor functions, https://tvm.apache.org/docs/tutorials/language/extern_op.html
- TVM BYOC, https://tvm.apache.org/2020/07/15/how-to-bring-your-own-codegen-to-tvm
- HLS Matrix multiplication for BERT large model
This repository implements a large matrix multiplication kernel in HLS that takes advantage of the HBM on the Alveo U50 card.
https://github.com/gitbisector/vitisbertl
- Apache TVM fork for the project
This is the Apache TVM fork repository, which includes the changes made for the project.
https://github.com/insop/incubator-tvm
- VTA test code on Alveo:
This repository builds a simple VTA on Alveo using OpenCL.
https://github.com/insop/transformer_adaptive_computing
- Simple implementation of the Transformer in C and Python
This repository implements the Transformer in C and Python for educational purposes.
https://github.com/insop/transformer_simple
Acknowledgements
We would like to express our sincere appreciation to Xilinx for allowing us to use the Nimbix cloud Alveo accelerator resources for this project.