The Great Transformer
Understand the design and application of Transformer-style models
This year, ChatGPT changed the game in the AI market, and the well-known network architecture called the Transformer is the cornerstone of its artificial brain. You might also have heard of BERT, a Natural Language Processing (NLP) model pre-trained with Self-Supervised Learning, which is built entirely on the Transformer. So if you want to study the coolest architecture in AI, the Transformer is the one thing you should not miss.
This article walks through some important academic works around the Transformer. It starts from the basic idea of the Transformer, moves on to its variants in Computer Vision and multimodal learning, and ends with its efficient versions. If you work with AI, this is probably the minimum you should know about the Transformer and its variants.
Transformer
The Transformer is an encoder-decoder architecture, which means it is a Sequence-to-Sequence model that maps one sequence of data to another. When the sequence data is a sentence or an article, it can perform tasks like translation, text conversion, and summarization. An overview of the Transformer is shown below.
Basically, the Transformer is a stack of encoders and decoders, and both are built on Self-Attention (SA) layers.
Self-Attention (SA)
The Transformer is a neural network architecture based on the Self-Attention mechanism. Earlier sequence models usually relied on recurrence, or on convolution to reduce sequential computation. In contrast, the Transformer drops recurrence and convolution entirely and relies on attention instead.
The key concept of SA is the “weighted average”: high-level knowledge can be extracted by combining lower-level context. The following GIF animates the attention weights across tokens for different heads. To try it yourself, see the Tensor2Tensor Intro.
Each word, namely a token, is represented as a high-dimensional vector that stores information. The tokens on the left-hand side are the low-level input representations, and the high-level output tokens are on the right-hand side. Each colored link stands for a multiply-and-sum, just like the connections in a typical neural network illustration. Furthermore, the weights of all links feeding a single output token sum to 1. That is why SA is basically a kind of “weighted average”.
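Mathematically, it is still straightforward if we read it as a weighted average. Here is the formula of the SA layer:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
where d_k is the channel dimension of the keys. Dividing by √d_k keeps the dot products from growing too large before the softmax.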
The attention block has three inputs: query (Q), key (K), and value (V). They usually have the same dimensions, N × C, where N is the number of tokens and C is the number of channels. The only purpose of Q and K is to produce a square N × N weight matrix. Each row of this matrix holds the weights used to average the tokens of V into one output token.
If we visualize this matrix after the softmax, it is clearly a correlation matrix, equivalent to the links in the SA visualization above.
The inputs [Q, K, V] are produced by three separate linear projections of the same input tokens. That is why it is called Self-Attention: all inputs of the attention operation come from the token sequence itself.
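To make the “weighted average” view concrete, here is a minimal NumPy sketch of single-head self-attention. It assumes the scaled dot-product form softmax(QKᵀ/√d_k)V from the original Transformer paper. The projection matrices w_q, w_k, and w_v are random stand-ins for learned weights, and the sizes are arbitrary.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(tokens, w_q, w_k, w_v):
    """Single-head self-attention over an [N, C] token matrix."""
    # Q, K and V are three linear projections of the same tokens.
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    # [N, N] attention matrix: row i holds the weights for output token i.
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]), axis=-1)
    # Each output token is a weighted average of the rows of V.
    return weights @ v, weights

rng = np.random.default_rng(0)
n_tokens, channels = 4, 8
x = rng.normal(size=(n_tokens, channels))
w_q, w_k, w_v = (rng.normal(size=(channels, channels)) for _ in range(3))
out, attn = self_attention(x, w_q, w_k, w_v)
print(out.shape, attn.shape)   # (4, 8) (4, 4)
print(attn.sum(axis=-1))       # every row sums to 1 (up to rounding)
```

Each row of `attn` is one output token's set of link weights. That is why the row is non-negative and sums to 1.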
Multihead Self-Attention (MHSA)
In practice, instead of a single SA block, each layer runs several SA blocks (heads) in parallel. It concatenates their output tokens along the channel dimension and then applies a final linear projection. This is Multihead Self-Attention (MHSA). Each head works on only C/h of the channels (for h heads), so MHSA lets the network extract richer token information at roughly the same computational cost as a single full-width head.
Each MHSA sublayer is followed by a Position-wise Feed-Forward Network: two linear layers with a ReLU in between, applied to every token independently. As the paper notes, this is equivalent to two 1D convolutions with kernel size 1.
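Here is a minimal NumPy sketch of multi-head self-attention followed by the position-wise feed-forward network. It assumes the channel count C divides evenly by the number of heads. Random matrices stand in for learned weights, and the biases of the attention projections are omitted for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multihead_self_attention(tokens, w_q, w_k, w_v, w_o, n_heads):
    """MHSA over [N, C] tokens; C must be divisible by n_heads."""
    n, c = tokens.shape
    head_dim = c // n_heads

    def split(m):
        # [N, C] -> [heads, N, C/heads]: each head gets its own slice of channels.
        return m.reshape(n, n_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split(tokens @ w_q), split(tokens @ w_k), split(tokens @ w_v)
    # One [N, N] attention matrix per head, computed in parallel.
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))
    heads = weights @ v                              # [heads, N, C/heads]
    # Concatenate the heads along channels, then mix them with W_O.
    concat = heads.transpose(1, 0, 2).reshape(n, c)
    return concat @ w_o

def position_wise_ffn(tokens, w1, b1, w2, b2):
    # Two linear layers with ReLU, applied to every token independently.
    return np.maximum(tokens @ w1 + b1, 0.0) @ w2 + b2

rng = np.random.default_rng(0)
n, c, n_heads, hidden = 5, 16, 4, 64
x = rng.normal(size=(n, c))
w_q, w_k, w_v, w_o = (rng.normal(size=(c, c)) * 0.1 for _ in range(4))
w1, b1 = rng.normal(size=(c, hidden)) * 0.1, np.zeros(hidden)
w2, b2 = rng.normal(size=(hidden, c)) * 0.1, np.zeros(c)
y = multihead_self_attention(x, w_q, w_k, w_v, w_o, n_heads)
z = position_wise_ffn(y, w1, b1, w2, b2)
print(y.shape, z.shape)  # (5, 16) (5, 16)
```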
Positional Encoding (PE)
Self-attention alone has no notion of order: shuffling the input tokens just shuffles the outputs. Positional encoding fixes this by injecting each token's position directly into the token values. It can be either learned or predefined. The predefined version uses sinusoidal functions:
PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
Here pos is the token's position in the sequence and i indexes the channel pairs: channel 2i uses sine and channel 2i+1 uses cosine. d_model is the channel dimension of the model. The output of PE is called the positional embedding. It has the same dimensions as the token sequence and is added directly to the input tokens rather than concatenated.
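The sinusoidal encoding takes only a few lines of NumPy. This sketch assumes an even d_model and the base 10000 from the original Transformer paper. The token matrix is random and is only there to show that the encoding is added, not concatenated.

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """[n_positions, d_model] table: sin on even channels, cos on odd ones (d_model even)."""
    pos = np.arange(n_positions)[:, None]              # [N, 1]
    two_i = np.arange(0, d_model, 2)[None, :]          # even channel indices 2i
    angles = pos / np.power(10000.0, two_i / d_model)  # pos / 10000^(2i/d_model)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

tokens = np.random.default_rng(0).normal(size=(10, 16))
# Added element-wise, so the shape stays [N, d_model] (no concatenation).
inputs = tokens + sinusoidal_positional_encoding(*tokens.shape)
print(inputs.shape)  # (10, 16)
```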
This GitHub repository provides good example code for sinusoidal positional encoding in both TensorFlow and PyTorch. There are various other positional encoding methods, such as Learned Positional Encoding, Relative Positional Encoding, and Rotary Positional Encoding.
Transformer Architecture
Here is the detailed architecture of the Transformer. It is much more than a stack of self-attention layers. Some details worth noting:
- The Self-Attention layers are all implemented with MHSA.
- The input is not raw words or characters. Instead, the text is split into tokens, and each token is mapped to a high-dimensional embedding vector.
- Residual connections, each followed by layer normalization, wrap every sublayer in the encoders and decoders. They reduce the risk of vanishing gradients and improve training stability.
- The decoder takes inputs from two sources: the encoded tokens and its own previous outputs.
- The output vocabulary is fixed. The final Softmax produces a probability for each word in a predefined vocabulary, and the most probable word is selected as the output.
- In practice, the maximum length of the input and output token sequences is fixed and predefined.
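The residual pattern can be sketched as the post-LN “Add & Norm” wrapper of the original Transformer, LayerNorm(x + Sublayer(x)). This NumPy sketch drops LayerNorm's learned gain and bias. The ReLU sublayer is a hypothetical stand-in for an attention or feed-forward sublayer.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token over its channels (learned gain/bias omitted for brevity).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def add_and_norm(x, sublayer):
    # Post-LN residual block: the sublayer output is added to its input, then normalized.
    return layer_norm(x + sublayer(x))

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 8))
w = rng.normal(size=(8, 8)) * 0.1
# Hypothetical toy sublayer: a single ReLU-activated linear map.
y = add_and_norm(x, lambda t: np.maximum(t @ w, 0.0))
print(y.mean(axis=-1).round(6), y.std(axis=-1).round(3))
```

The residual path gives gradients a direct route around every sublayer, which is why deep stacks of these blocks remain trainable.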
At inference time, the decoder works as an autoregressive model. It outputs words one by one, each conditioned on the previously generated words. The following GIF animates this workflow and is very easy to understand.
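The autoregressive loop can be sketched as greedy decoding. At each step the decoder scores the vocabulary, given the encoded input and everything generated so far, and the highest-scoring token is appended. `toy_copy_model`, the token ids `BOS`/`EOS`, and the 6-word vocabulary are hypothetical stand-ins for a trained decoder. Real systems often use beam search or sampling instead of greedy selection.

```python
import numpy as np

BOS, EOS = 0, 1  # hypothetical start / end token ids

def greedy_decode(next_token_logits, encoded, max_len=10):
    """Generate tokens one at a time, feeding each prediction back as input."""
    output = [BOS]
    for _ in range(max_len):
        logits = next_token_logits(encoded, output)  # scores over the vocabulary
        token = int(np.argmax(logits))               # pick the most probable token
        output.append(token)
        if token == EOS:
            break
    return output

def toy_copy_model(encoded, prefix, vocab_size=6):
    # Hypothetical stand-in for a trained decoder: it copies the source tokens, then emits EOS.
    logits = np.zeros(vocab_size)
    step = len(prefix) - 1
    logits[encoded[step] if step < len(encoded) else EOS] = 1.0
    return logits

print(greedy_decode(toy_copy_model, [4, 2, 5]))  # [0, 4, 2, 5, 1]
```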
Transformer on Computer Vision
As the Transformer gained popularity, researchers also wanted to use it for computer vision tasks. Since it is a sequence-to-sequence model, the Transformer needs some modification to accept 2D image input.
This section covers three popular applications in computer vision: image encoding, video encoding, and object detection.
Image Encoder: Vision Transformer
Vision Transformer (ViT) represents an image as a sequence of 16x16 “words” (image patches) with the following steps:
- Split the image into small patches (e.g. 16x16 pixels).
- Use a Linear Projection to turn each image patch into an image token.
- Prepend a learnable [CLS] token to the sequence.
- Add the positional embeddings to the tokens.
- Apply a Transformer Encoder to extract features.
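Here is a NumPy sketch of the ViT input pipeline. The image size, patch size, and embedding width D = 64 are assumptions for illustration. Random weights stand in for the learned projection, [CLS] token, and positional embeddings. Following the ViT paper, the [CLS] token is prepended before the positional embeddings are added, so it gets a position too. The Transformer encoder itself is left out.

```python
import numpy as np

def vit_tokens(image, patch, w_proj, cls_token, pos_embed):
    """Turn an [H, W, C] image into ViT input tokens of shape [1 + num_patches, D]."""
    h, w, c = image.shape
    # Split into non-overlapping patch x patch tiles and flatten each tile.
    patches = (image.reshape(h // patch, patch, w // patch, patch, c)
                    .transpose(0, 2, 1, 3, 4)
                    .reshape(-1, patch * patch * c))
    # Linear projection of every flattened patch to a D-dim token.
    tokens = patches @ w_proj
    # Prepend the learnable [CLS] token.
    tokens = np.concatenate([cls_token[None, :], tokens], axis=0)
    # Add one positional embedding per position (including [CLS]).
    return tokens + pos_embed

rng = np.random.default_rng(0)
img = rng.random((224, 224, 3))        # hypothetical 224x224 RGB image
p, d = 16, 64
n_patches = (224 // p) ** 2            # 14 * 14 = 196 "words"
w_proj = rng.normal(size=(p * p * 3, d)) * 0.02
cls = rng.normal(size=d)
pos = rng.normal(size=(n_patches + 1, d))
tokens = vit_tokens(img, p, w_proj, cls, pos)
print(tokens.shape)  # (197, 64)
```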
In the paper, the authors also experimented with 1D and 2D positional encodings, but found no significant difference between them. This is a meaningful result: even though spatial correlation is important for image processing, the Transformer encoder works well with the simplest positional encoding.
Many Transformer-based image encoders have been proposed in recent years. The following GitHub project provides PyTorch implementations of various well-known ones, such as ViT, Swin Transformer, CrossFormer, and Tokens-to-Token ViT.
Video Encoder: Video Vision Transformer
Video Vision Transformer (ViViT) is one of the video versions of ViT. Video is 3-dimensional (T × H × W) spatio-temporal data, but a Transformer only accepts a 1D input sequence. The main contribution of ViViT is a study of the speed-accuracy trade-off of four designs for feeding this data into a Transformer.
First, ViViT extracts embeddings in one of two ways. Uniform frame sampling embeds each sampled frame independently with 2D patches (a 2D convolution), as in ViT. Tubelet embedding extracts non-overlapping spatio-temporal tubes from the frame sequence with a 3D convolution.
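Both embedding options can be sketched as reshapes followed by a linear projection. This is what a convolution with stride equal to its kernel size computes. In this NumPy sketch, the clip size (8 frames of 64x64 RGB), the tubelet size 2 × 16 × 16, and the embedding width 32 are hypothetical. `uniform_frame_sampling` assumes the frames have already been sampled from the video.

```python
import numpy as np

def tubelet_embedding(video, t, p, w_proj):
    """Non-overlapping t x p x p tubes -> tokens (a 3D conv with stride = kernel)."""
    T, H, W, C = video.shape
    tubes = (video.reshape(T // t, t, H // p, p, W // p, p, C)
                  .transpose(0, 2, 4, 1, 3, 5, 6)   # [nt, nh, nw, t, p, p, C]
                  .reshape(-1, t * p * p * C))
    return tubes @ w_proj

def uniform_frame_sampling(video, p, w_proj):
    """ViT-style 2D patches from every frame, concatenated along the sequence."""
    T, H, W, C = video.shape
    patches = (video.reshape(T, H // p, p, W // p, p, C)
                    .transpose(0, 1, 3, 2, 4, 5)    # [T, nh, nw, p, p, C]
                    .reshape(-1, p * p * C))
    return patches @ w_proj

rng = np.random.default_rng(0)
video = rng.random((8, 64, 64, 3))  # hypothetical clip: 8 frames of 64x64 RGB
w_tube = rng.normal(size=(2 * 16 * 16 * 3, 32)) * 0.02
w_frame = rng.normal(size=(16 * 16 * 3, 32)) * 0.02
tube_tokens = tubelet_embedding(video, 2, 16, w_tube)
frame_tokens = uniform_frame_sampling(video, 16, w_frame)
print(tube_tokens.shape, frame_tokens.shape)  # (64, 32) (128, 32)
```

Tubelets halve the token count here because each token already spans two frames. That shorter sequence is what makes the later attention layers cheaper.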
For the encoder, ViViT experimented with four designs:
- Spatio-temporal attention: a vanilla Transformer over all tokens (input length: T × H × W).
- Factorised encoder: a late-fusion design in which a spatial Transformer (input length: H × W) processes each frame, followed by a temporal Transformer (input length: T) across frames.
- Factorised self-attention: the same factorisation inside each Transformer block, with a spatial SA block (input length: H × W) followed by a temporal SA block (input length: T).
- Factorised dot-product attention: a parallel-branch design inside the MHA blocks. Half of the heads attend spatially (input length: H × W) and the other half temporally (input length: T), and all heads share the same query features Q (length: T × H × W).
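A rough way to see the runtime gap is to count query-key scores per attention layer. This sketch assumes a hypothetical grid of T = 16 frames and 14 × 14 patches. It compares joint spatio-temporal attention with a spatial-then-temporal factorisation (the factorised self-attention design) and ignores everything else in the model, such as the MLPs.

```python
def attention_scores_per_layer(T, H, W):
    """Count query-key dot products in one attention layer for a T x H x W token grid."""
    n = T * H * W
    joint = n * n                 # spatio-temporal: every token sees every token
    spatial = T * (H * W) ** 2    # each frame attends within itself
    temporal = (H * W) * T ** 2   # each patch location attends across frames
    return joint, spatial + temporal

joint, factorised = attention_scores_per_layer(T=16, H=14, W=14)
print(joint, factorised, round(joint / factorised, 1))  # 9834496 664832 14.8
```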
Although spatio-temporal attention achieves the best accuracy, the other designs reach comparable accuracy with significantly lower runtime. Moreover, factorised dot-product attention adds no parameters compared with spatio-temporal attention.
Object Detection and Segmentation: DETR
DETR (DEtection TRansformer) might be one of the most important works built on the Transformer. It follows the encoder-decoder structure of the original Transformer. On the decoder side, however, DETR shows that all we need to predict multiple objects is a set of learnable tokens, called object queries.
The object queries are learnable embeddings with shape [N, C], where N is the number of queries and C is the channel dimension. The number of output objects is capped at N, so the queries play a role similar to the anchors in R-CNN-based detectors.
DETR removes many hand-designed components of earlier object detectors, such as the RPN (region proposal network) and anchors, which rely on hyperparameters set by humans.
According to the analysis section of the paper, DETR outperforms Faster R-CNN for the following reasons:
- The object queries learn to specialize in particular object locations and sizes. In contrast, Faster R-CNN places anchors with the same settings across the whole image.
- Self-attention models global and local information at once, which lets DETR handle unseen scenarios. CNN-based detectors, on the other hand, depend heavily on context: this 2018 paper investigated Faster R-CNN and found that it is easily fooled by the context around the target object.
Thanks to its encoder-decoder architecture, DETR can also perform panoptic segmentation by adding an FPN-style CNN mask head on top of the decoder outputs.
For a fully Transformer-based segmentation architecture, Segmenter is another outstanding work.
Cross-Attention: Multiple Input Sequences
The original self-attention design only works on a single input sequence. But some tasks need to encode features from multiple data sources, for example the visual question answering (VQA) task in GPT-4.
Cross-attention (CA) is the key puzzle piece here. CA blends information from one input sequence into another, which is useful when you have two input modalities and want a good multimodal representation.
The only difference between SA and CA is where [Q, K, V] come from. In SA, all three are projections of a single sequence. In CA, some of them are projections of another sequence. Commonly, Q comes from one sequence and K and V come from the other.
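Here is a minimal NumPy sketch of cross-attention. It assumes the common arrangement in which Q comes from one sequence and K and V come from the other, as in the Transformer decoder or text-conditioned image models. The “image” and “text” arrays and all weight matrices are random, hypothetical stand-ins. The two sequences may have different lengths and widths.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(x, context, w_q, w_k, w_v):
    """Queries come from x; keys and values come from another sequence (context)."""
    q = x @ w_q          # [N_x, d]
    k = context @ w_k    # [N_ctx, d]: must share width d with q
    v = context @ w_v    # [N_ctx, d_v]: one row per key
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))  # [N_x, N_ctx]
    return weights @ v   # one output per query token: [N_x, d_v]

rng = np.random.default_rng(0)
image_tokens = rng.normal(size=(196, 32))  # hypothetical patch features
text_tokens = rng.normal(size=(12, 48))    # hypothetical word features, different width
w_q = rng.normal(size=(32, 16))
w_k = rng.normal(size=(48, 16))
w_v = rng.normal(size=(48, 32))
out = cross_attention(image_tokens, text_tokens, w_q, w_k, w_v)
print(out.shape)  # (196, 32)
```

The output has one row per query token, so the sequence that provides Q decides the output length.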
The original Transformer used a CA block to feed the encoded tokens into the decoder. Later works redesigned it into novel modules:
- Perceiver IO uses two cross-attention modules to blend data. The first maps the input onto a latent array, and the second maps the latent array to the output.
- SelfDoc fuses textual and visual features with CA by replacing K or V, respectively, with those from the other modality.
- CrossViT builds two embeddings at different scales from the same image and fuses them with cross-attention. This overcomes a weakness of ViT, whose image embeddings are generated by a single convolution layer with a fixed kernel size.
- Stable Diffusion uses cross-attention to guide the diffusion model's features with an additional conditioning input, e.g. text. Check The Magic of Diffusion Models to learn more about diffusion models.
In most of these works, the tensor taken from the other source is Q, or K and V together, rather than K alone. The reason is that CA is a series of matrix multiplications. QKᵀ only requires Q and K to share the channel dimension, so Q may have a different number of tokens. The resulting attention matrix then needs V to have one row per key. Replacing K alone is therefore valid only if the new K has exactly the same matrix dimensions as the original one, as in SelfDoc.
Today, cross-attention is everywhere. You can find it in almost every Transformer-style paper.
Efficient Transformer
Although the Transformer is powerful, its design comes with high computation and memory costs. Self-attention scales quadratically with the number of tokens, so the input sequence length is limited by the machine's compute (TFLOPS, tera floating-point operations per second) and memory.
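The quadratic cost is easy to quantify: a single n × n attention matrix in float32 takes 4n² bytes per head. The sequence lengths below are arbitrary examples. The count ignores activations, weights, and any framework overhead.

```python
def attention_matrix_bytes(n_tokens, n_heads=1, bytes_per_value=4):
    # One n x n score matrix per head; float32 uses 4 bytes per entry.
    return n_tokens * n_tokens * n_heads * bytes_per_value

for n in (512, 4096, 32768):
    print(n, attention_matrix_bytes(n) / 2**20, "MiB per head")
# 512 1.0 MiB per head
# 4096 64.0 MiB per head
# 32768 4096.0 MiB per head
```

An 8x longer sequence needs 64x more memory for the scores alone.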
Transformer-XL
To handle longer inputs, Transformer-XL splits a long input sequence into smaller segments and feeds them into the Transformer one after another in an RNN-like manner.
Here are the key points of the Transformer-XL design:
- The attention is unidirectional (causal) rather than bidirectional, so the attention matrices over tokens are triangular.
- During training, the model processes the sequence one segment at a time.
- Each segment reuses the cached hidden states of the previous segment as extra context.
- It uses a relative positional encoding, combining sinusoidal relative-distance terms with learned bias vectors, because absolute positions would become ambiguous when states from different segments are reused.
- All the Transformer models across segments share the same weights.
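Here is a single causal attention layer in NumPy that sketches the segment-level recurrence. It rests on simplifying assumptions: one layer and one head, a memory of exactly one previous segment, and no positional encoding (the real model applies its relative positional encoding here). During training, the cached memory would be treated as a constant, with no gradient flowing through it.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def segment_attention(segment, memory, w_q, w_k, w_v):
    """Causal attention whose keys/values also include the cached previous segment."""
    context = np.concatenate([memory, segment], axis=0)  # [M + L, C]
    q, k, v = segment @ w_q, context @ w_k, context @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])              # [L, M + L]
    m, l = len(memory), len(segment)
    # Query i sits at context index m + i: it sees all memory and segment tokens <= i.
    future = np.triu(np.ones((l, m + l), dtype=bool), k=m + 1)
    scores[future] = -np.inf
    return softmax(scores) @ v

def process_long_sequence(tokens, seg_len, w_q, w_k, w_v):
    outputs, memory = [], np.zeros((0, tokens.shape[1]))
    for start in range(0, len(tokens), seg_len):
        segment = tokens[start:start + seg_len]
        # The same weights are reused for every segment.
        outputs.append(segment_attention(segment, memory, w_q, w_k, w_v))
        memory = segment  # cache this segment's (input) states for the next one
    return np.concatenate(outputs, axis=0)

rng = np.random.default_rng(0)
c = 8
x = rng.normal(size=(12, c))
w_q, w_k, w_v = (rng.normal(size=(c, c)) * 0.3 for _ in range(3))
out = process_long_sequence(x, 4, w_q, w_k, w_v)
print(out.shape)  # (12, 8)
```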
Longformer
The main computation cost that limits the sequence length of the Transformer is the matrix multiplication of Q and K in the SA block, which grows quadratically with the number of tokens. To speed this up, Longformer assumes that the correlation matrix in the SA block can be sparse, since each output token only needs to attend to a subset of the input tokens.
The paper visualizes the attention patterns as colored matrices. Longformer introduces a sparse attention pattern that combines sliding windows with global connections.
Full n² attention is the standard self-attention. Sliding window attention and dilated sliding window attention are analogous to convolution, since each output token only takes its neighboring input tokens into account. Although sliding window attention has a lower computation cost, it only produces local features within each layer, because its receptive field is limited.
Longformer combines global attention and local attention by adding the two patterns together. In practice, the authors suggest applying global attention only to certain tokens, for example the learnable [CLS] token in classification tasks.
Note that the sliding window in Longformer is 512 tokens wide, far larger than a typical convolution kernel. This is because 512 is also the input sequence length of BERT, a Transformer-based language model, so a 512-token window preserves the baseline performance of the Transformer on NLP tasks. During evaluation, Longformer sets the input length to 32,256 tokens, which would be infeasible on today's hardware with the vanilla design.
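Longforming's combined pattern can be sketched as a boolean mask over the n × n score matrix. In this NumPy sketch, `window` is the total width (window // 2 tokens on each side). The sequence length 16, window 4, and global token at position 0 are toy values rather than the paper's settings.

```python
import numpy as np

def longformer_mask(n, window, global_idx=(), dilation=1):
    """Boolean [n, n] mask: True where query i may attend to key j."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    offset = j - i
    half = window // 2
    # Sliding window (optionally dilated): keys within +/- half steps of size `dilation`.
    mask = (np.abs(offset) <= half * dilation) & (offset % dilation == 0)
    # Global tokens attend to everything and are attended to by everything.
    for g in global_idx:
        mask[g, :] = True
        mask[:, g] = True
    return mask

m = longformer_mask(16, window=4, global_idx=[0])
print(m.sum(), "of", m.size, "scores computed")  # 100 of 256 scores computed
```

A real implementation never materializes the full n × n mask. It computes only the banded and global entries, and that is where the savings come from.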
Linformer
Linformer analyzes the self-attention matrix with singular value decomposition (SVD). It shows that the matrix is approximately low-rank, so it can be approximated by the product of two low-rank matrices.
Linformer learns additional projection layers that project K and V to a lower dimension along the sequence axis. This reduces the computational complexity from O(n² × d) to O(n × k × d), where n is the number of tokens, d is the token dimension, and k is the projected length (the rank setting). In practice, k is far smaller than n.
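The Linformer idea fits in a few lines of NumPy. Learned matrices E and F, here called `e_proj` and `f_proj` and filled with random values, compress the n keys and values into k_len rows before attention. The sizes n = 1024, d = 64, and k_len = 128 are arbitrary choices for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def linformer_attention(q, k, v, e_proj, f_proj):
    """q, k, v: [n, d]; e_proj, f_proj: [k_len, n] projections along the sequence axis."""
    k_low = e_proj @ k   # [k_len, d]: n keys compressed into k_len rows
    v_low = f_proj @ v   # [k_len, d]
    weights = softmax(q @ k_low.T / np.sqrt(q.shape[-1]))  # [n, k_len] instead of [n, n]
    return weights @ v_low  # [n, d]

rng = np.random.default_rng(0)
n, d, k_len = 1024, 64, 128
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
e_proj, f_proj = (rng.normal(size=(k_len, n)) / np.sqrt(n) for _ in range(2))
out = linformer_attention(q, k, v, e_proj, f_proj)
print(out.shape)  # (1024, 64)
```

The attention matrix is now n × k_len instead of n × n, which is where the O(n × k) cost comes from. With identity projections (k_len = n), the code reduces to standard attention.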
FlashAttention
We can also address the performance issue with IO-aware design. Most methods focus on computational efficiency, whereas FlashAttention maximizes memory-access efficiency during the self-attention QKV computation. A typical GPU machine has three types of memory:
- CPU DRAM (dynamic random access memory): the most commonly used memory, with the largest capacity and the lowest IO bandwidth.
- GPU HBM (high bandwidth memory): 3D-stacked DRAM placed next to the main GPU chip. It provides much higher bandwidth than normal CPU DRAM but has a smaller capacity.
- GPU SRAM (static random access memory): on-chip GPU memory used for the most frequently accessed data. It has the highest bandwidth and the smallest capacity.
FlashAttention tiles Q, K, and V into submatrices small enough to fit in GPU SRAM. It computes attention block by block there with an online softmax and writes only the final output back to HBM.
For sequence data of size [N, d], the computational complexity is still O(N²d), but the number of HBM accesses drops from Θ(Nd + N²) to Θ(N²d²/M), where M is the SRAM size. In typical settings, d is 64 to 128 and M is around 100 KB. d² is therefore many times smaller than M, and HBM traffic shrinks accordingly.
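The core trick that makes tiling possible is the online softmax: partial results from each K/V tile are rescaled whenever a larger score appears. This NumPy sketch shows only that arithmetic. It tiles K and V but not Q, and the block size 32 is arbitrary. It runs on the CPU, so it has none of FlashAttention's memory benefits, which come from a fused GPU kernel that keeps tiles in SRAM.

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=32):
    """Process K/V one tile at a time with an online softmax; no full N x N matrix."""
    n, d = q.shape
    out = np.zeros_like(v, dtype=float)
    row_max = np.full((n, 1), -np.inf)  # running max of scores per query
    row_sum = np.zeros((n, 1))          # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]  # one tile
        s = q @ kb.T / np.sqrt(d)                                 # [n, block]
        new_max = np.maximum(row_max, s.max(axis=1, keepdims=True))
        scale = np.exp(row_max - new_max)  # rescale earlier partial results
        p = np.exp(s - new_max)
        row_sum = row_sum * scale + p.sum(axis=1, keepdims=True)
        out = out * scale + p @ vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(100, 16)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v, block=32), naive_attention(q, k, v)))  # True
```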
FlashAttention is very attractive for GPU makers because it speeds up the Transformer by up to about 3x in practice without changing the model weights or structure. The GitHub code implements the HBM/SRAM memory transfers manually with Triton. In the future, however, this is likely to be handled automatically by runtime libraries, and AI developers will not need to be aware of it.
Many outstanding Transformer research works and applications are not included in this article. The truth is, they could NEVER fit in a single article because there are too many. The Transformer has become a new paradigm for solving problems: converting your raw data into sequential data with a suitable positional encoding works for almost anything.
Although the Transformer is powerful, using it without understanding it is dangerous, since you may waste your time and lose your direction in the ocean of research. In that case, even ChatGPT could do better than you.