Paper Summary #8 - FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Link: https://arxiv.org/abs/2205.14135
Authors: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Code: https://github.com/HazyResearch/flash-attention
I have also released an annotated version of the paper. You can find it here.
[Update] - I implemented a simplified version of FlashAttention (without the CUDA and SRAM memory optimizations) in PyTorch. Check it out on GitHub.
I finished reading the FlashAttention paper recently and thought that it would be good to have a technical writeup of the paper, so that it can help me understand the concept well. I decided to make it public and hopefully it can help anyone reading this.
Overview
Attention, as we know, is an \(O(N^2)\) operation in its standard implementation, where \(N\) is the sequence length. There are many approximate attention methods out there like Reformer, SMYRF, Performer and others (you can find more details on a few of these in my previous blog) which aim to reduce the compute requirements to linear or near-linear in sequence length, but many of them do not display wall-clock speedup against standard attention. They focus on FLOP reduction (which doesn’t always correlate with wall-clock speed) and tend to ignore overheads from memory access (IO). FlashAttention aims to incorporate IO-awareness, i.e. dividing operations between faster and slower levels of GPU memory to make the whole computation faster. The algorithm uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. FlashAttention can also be extended to block-sparse attention, and this results in the fastest approximate (or not) attention algorithm out there.
All this helps to improve the training time of Transformer models: a 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, and a 3× speedup on GPT-2 (seq. length 1K). This memory-efficient approach also helps to incorporate a longer context (up to 16K/64K tokens), which also results in better models (0.7 better perplexity on GPT-2).
I’ll describe more details in the sections that follow.
Background - Hardware Performance
Since FlashAttention computes exact attention, and the major crux of their work is the efficient hardware usage, it is important to know a bit about GPU memory and the performance characteristics of various kinds of operations on it.
GPU Memory Hierarchy
For an A100 GPU with 40GB of High Bandwidth Memory (HBM), a rough diagram of the memory hierarchy is shown above. The on-chip SRAM is spread across 108 streaming multiprocessors (SMs), 192KB for each. As one can see, the on-chip SRAM is much faster than the HBM but much smaller in size. In terms of compute, the theoretical peak throughput for BFloat16 using Tensor Cores is 312 TFLOPS. Over time, compute has gotten much faster relative to memory speed, hence operations are increasingly bottlenecked by memory (HBM) access. Thus, the goal of the FlashAttention paper was to use the SRAM as efficiently as possible to speed up the computation.
Execution Model
GPUs typically use a large number of threads to perform an operation, which is called a kernel. The input is loaded from the HBM to the registers and SRAM, and the result is written back to the HBM after computation.
Performance Characteristics
Arithmetic intensity is the number of arithmetic operations per byte of memory access. It helps in identifying the bottleneck of an operation. An operation can be characterized as compute-bound (also called math-bound) or memory-bound.

Compute-bound - The bottleneck is the compute, i.e., the time taken by the operation is determined by how many arithmetic operations there are, since the time spent on HBM accesses is relatively lower. Examples are matrix multiplication with a large inner dimension and convolution with a large number of channels.

Memory-bound - The bottleneck is the memory, i.e., the time taken by the operation is determined by the number of memory accesses, since the time spent in computation is relatively lower. Examples are most other operations: elementwise operations (activation, dropout) and reduction operations (sum, softmax, batch normalization, layer normalization).
To understand this better, let’s analyze it mathematically. Let \(N_{op}\) be the number of arithmetic (floating point) operations, \(N_{byte}\) the number of bytes of memory accessed, and \({BW}_{compute}\) and \({BW}_{memory}\) the compute and memory bandwidth respectively. The time taken for compute operations and memory accesses can then be determined as
\[\\ \begin{align} t_{compute} = \frac{N_{op}}{BW_{compute}} \\ t_{memory} = \frac{N_{byte}}{BW_{memory}} \end{align} \\\]The operation is compute-bound if \(t_{compute}\) is greater than \(t_{memory}\), and memory-bound if it is smaller. Rearranging the two expressions, this becomes
For compute-bound
\(\\ \begin{align} \frac{N_{op}}{N_{byte}} \gt \frac{BW_{compute}}{BW_{memory}} \end{align} \\\)
For memory-bound
\(\\ \begin{align} \frac{N_{op}}{N_{byte}} \lt \frac{BW_{compute}}{BW_{memory}} \end{align} \\\)
As mentioned above, matrix multiplication with a large inner dimension is compute-bound, but below a certain size it is memory-bound. Using FP32 and plugging in the numbers for an A100 40GB, the \(N \times N\) matrix multiplication is memory-bound for \(N \lt 74\) and compute-bound for larger \(N\). A great and detailed resource to understand this theory is this blog post by Lei Mao.
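The threshold can be checked with a few lines of Python. This is only a rough sketch: the peak figures (19.5 TFLOPS for FP32 and roughly 1.6 TB/s of HBM bandwidth) are my assumptions for an A100 40GB and are not stated in this post, and the model assumes each matrix is read or written exactly once.

```python
# Assumed peak figures for an A100 40GB (assumptions, not taken from the post).
PEAK_FP32_FLOPS = 19.5e12   # FLOP/s, FP32 without Tensor Cores
PEAK_HBM_BYTES = 1.6e12     # bytes/s, rounded HBM bandwidth
BYTES_PER_FP32 = 4


def matmul_intensity(n):
    """Arithmetic intensity (FLOPs per byte) of an n x n by n x n matmul."""
    n_op = 2 * n ** 3                      # one multiply and one add per term
    n_byte = 3 * n ** 2 * BYTES_PER_FP32   # read A and B once, write C once
    return n_op / n_byte


def is_compute_bound(n):
    """Compare N_op / N_byte against BW_compute / BW_memory."""
    return matmul_intensity(n) > PEAK_FP32_FLOPS / PEAK_HBM_BYTES


# Smallest n that this idealized model classifies as compute-bound.
threshold = next(n for n in range(1, 10_000) if is_compute_bound(n))
print(threshold)
```

With these assumed numbers the intensity is \(N/6\), so the crossover sits where \(N/6\) exceeds the ratio of the two bandwidths.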
Kernel Fusion
Kernel Fusion is often done by compilers to fuse together multiple elementwise operations. It is used to accelerate memory-bound operations. The basic idea is that instead of loading the input from the HBM, performing one operation, writing the result back to the HBM, and repeating that for each operation applied to the same input, the operations can be fused so that all of them are performed at once while the input is loaded from the HBM.
However, one must note that when performing model training, the effectiveness of kernel fusion is reduced as the intermediate values still have to be written to the HBM to save for the backward pass.
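To make the saving concrete, here is a toy sketch that only counts simulated HBM traffic. Plain Python lists stand in for HBM, and the helper names `run_unfused` and `run_fused` are made up for illustration; real fusion happens inside a compiled GPU kernel.

```python
import math


def run_unfused(x, ops):
    """One 'kernel' per op: each op reads its whole input and writes its whole output."""
    traffic = 0
    for op in ops:
        traffic += len(x)          # read the input from (simulated) HBM
        x = [op(v) for v in x]
        traffic += len(x)          # write the result back to (simulated) HBM
    return x, traffic


def run_fused(x, ops):
    """One 'kernel': each element is read once, passed through every op, written once."""
    out = []
    for v in x:
        for op in ops:
            v = op(v)
        out.append(v)
    return out, 2 * len(x)


ops = [lambda v: v * 2.0, math.tanh, lambda v: max(v, 0.0)]
x = [-1.0, 0.5, 2.0, 3.0]
y_unfused, t_unfused = run_unfused(x, ops)
y_fused, t_fused = run_fused(x, ops)
print(t_unfused, t_fused)
```

Both versions produce the same values, but the unfused one moves each element through memory once per operation, so its traffic grows with the number of operations in the chain.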
Background - Standard Attention
For anyone familiar with transformers, this equation is well-known
\[Attention(Q, K, V) = softmax\left(\frac{QK^\mathsf{T}}{\sqrt{d}}\right)V\]Here, the sequences \(Q, K, V \in \mathbb{R}^{N \times d}\), where \(N\) is the sequence length and \(d\) is the head dimension. The attention output above can be denoted by \(O \in \mathbb{R}^{N \times d}\). Leaving the scaling aside, the equation can be broken down as
\[\mathbf{S} = \mathbf{QK^\mathsf{T}} \in \mathbb{R}^{N \times N},\quad \mathbf{P} = softmax(\mathbf{S}) \in \mathbb{R}^{N \times N},\quad \mathbf{O} = \mathbf{PV} \in \mathbb{R}^{N \times d}\]In standard attention implementations, the \(\mathbf{S}\) and \(\mathbf{P}\) matrices are materialized in the HBM, which takes \(O(N^2)\) memory. Also, most of the operations are memory-bound elementwise operations, e.g. softmax applied to \(\mathbf{S}\), masking applied to \(\mathbf{S}\), dropout applied to \(\mathbf{P}\). This leads to slow wall-clock time.
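As a reference point, the three steps can be written out in plain Python. This is a minimal sketch for small inputs, not how any library implements attention; it includes the \(1/\sqrt{d}\) scaling and returns \(\mathbf{S}\) and \(\mathbf{P}\) only to make the two \(N \times N\) intermediates visible.

```python
import math


def softmax(row):
    """Numerically stable softmax of one row."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def standard_attention(q, k, v):
    """Materializes the full N x N matrices S and P before computing O."""
    d = len(q[0])
    k_t = [list(col) for col in zip(*k)]
    s = [[x / math.sqrt(d) for x in row] for row in matmul(q, k_t)]  # N x N
    p = [softmax(row) for row in s]                                  # N x N
    return matmul(p, v), s, p


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
o, s, p = standard_attention(q, k, v)
```

For \(N = 3\) the intermediates are tiny, but both grow quadratically with the sequence length, which is exactly the cost FlashAttention avoids paying in HBM.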
FlashAttention  Algorithm details
As one may understand, the materialization of the \(N \times N\) attention matrix on the HBM and its repeated reading and writing is a major bottleneck. To solve this, two main things need to be done 
 Computing the softmax reduction without access to the whole input
 Not storing the large intermediate attention matrix for the backward pass
Two established techniques, namely tiling and recomputation are used to solve this.
 Tiling - The attention computation is restructured to split the input into blocks and perform the softmax operation incrementally by making several passes over the input blocks.
 Recomputation - The softmax normalization factor from the forward pass is stored to quickly recompute attention on-chip in the backward pass, which is faster than the standard attention approach of reading the intermediate matrix from HBM.
This does lead to increased FLOPs due to recomputation. However, because the amount of HBM access is massively reduced, FlashAttention both runs faster (up to 7.6x on GPT-2) and uses less memory (linear in sequence length).
Understanding the algorithm
The main idea behind the algorithm is to split the inputs \(\mathbf{Q, K, V}\) into blocks, loading them from slow HBM to fast SRAM and then computing the attention output w.r.t those blocks. The output of each block is scaled by the right normalization factor before adding them up, which gives the correct result.
With scaling, masking and dropout included, the full forward computation is
\[\mathbf{S} = \tau \mathbf{QK^\mathsf{T}} \in \mathbb{R}^{N \times N},\quad \mathbf{S}^\mathrm{masked} = \mathrm{MASK}(\mathbf{S}) \in \mathbb{R}^{N \times N},\quad \mathbf{P} = softmax(\mathbf{S^\mathrm{masked}}) \in \mathbb{R}^{N \times N},\] \[\mathbf{P}^\mathrm{dropped} = \mathrm{dropout}(\mathbf{P}, p_\mathrm{drop}), \quad \mathbf{O} = \mathbf{P^\mathrm{dropped}V} \in \mathbb{R}^{N \times d},\]where \(\tau \in \mathbb{R}\) is some softmax scaling factor (typically \(\frac{1}{\sqrt{d}}\)), \(\mathrm{MASK}\) is some masking function that sets some entries of the input to \(-\infty\) and keeps the other entries the same, and \(\mathrm{dropout}(x, p)\) applies dropout to \(x\) elementwise (i.e., it outputs \(\frac{x}{1-p}\) with probability \(1 - p\) and \(0\) with probability \(p\) for each element \(x\)).
Tiling
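The key part in understanding the blockwise computation of attention in the algorithm above is the blockwise computation of the softmax. The paper explains it well though. The softmax of a vector \(x \in \mathbb{R}^B\) can be computed as
\[m(x) := \max_i x_i, \quad f(x) := \left[ e^{x_1 - m(x)}, \ldots, e^{x_B - m(x)} \right], \quad l(x) := \sum_i f(x)_i, \quad softmax(x) := \frac{f(x)}{l(x)}\]And for vectors \(x^\mathrm{(1)}, x^\mathrm{(2)} \in \mathbb{R}^B\), the softmax of the concatenated \(x = [x^\mathrm{(1)}, x^\mathrm{(2)}] \in \mathbb{R}^{2B}\) is given by
\[m(x) = \max(m(x^\mathrm{(1)}), m(x^\mathrm{(2)})), \quad f(x) = \left[ e^{m(x^\mathrm{(1)}) - m(x)} f(x^\mathrm{(1)}), \; e^{m(x^\mathrm{(2)}) - m(x)} f(x^\mathrm{(2)}) \right],\] \[l(x) = e^{m(x^\mathrm{(1)}) - m(x)} l(x^\mathrm{(1)}) + e^{m(x^\mathrm{(2)}) - m(x)} l(x^\mathrm{(2)}), \quad softmax(x) = \frac{f(x)}{l(x)}\]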
Let’s understand this better. In the above equations, \(m(x)\) holds the maximum between \(m(x^\mathrm{(1)})\) and \(m(x^\mathrm{(2)})\). Now, \(m(x^\mathrm{(1)})\) is the maximum element of \(x^\mathrm{(1)}\) and \(m(x^\mathrm{(2)})\) is the maximum element of \(x^\mathrm{(2)}\) which means that \(m(x)\) is basically the maximum of the whole concatenated vector. The beauty is that this was done blockwise.
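The blockwise decomposition is easy to check numerically. In the sketch below, each block contributes its own maximum \(m\), shifted exponentials \(f\) and their sum \(l\), and the two blocks are merged by rescaling with \(e^{m(x^{(k)}) - m(x)}\). The function names are mine, chosen only for this illustration.

```python
import math


def block_stats(block):
    """Return (m, f, l) for one block: max, shifted exponentials, and their sum."""
    m = max(block)
    f = [math.exp(v - m) for v in block]
    return m, f, sum(f)


def merge_softmax(x1, x2):
    """Softmax of the concatenation [x1, x2], built from per-block statistics only."""
    m1, f1, l1 = block_stats(x1)
    m2, f2, l2 = block_stats(x2)
    m = max(m1, m2)
    a1, a2 = math.exp(m1 - m), math.exp(m2 - m)   # per-block rescaling factors
    l = a1 * l1 + a2 * l2
    return [a1 * v / l for v in f1] + [a2 * v / l for v in f2]


def softmax(x):
    """Ordinary softmax over the whole vector, used as the reference."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    total = sum(e)
    return [v / total for v in e]


x1, x2 = [1.0, 2.0, 3.0], [0.5, 4.0, -1.0]
blockwise = merge_softmax(x1, x2)
reference = softmax(x1 + x2)
```

The two results agree up to floating point rounding, even though `merge_softmax` never looks at both blocks at the same time when computing the exponentials.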
So, if the statistics \((m(x), l(x))\) are tracked, then softmax can be computed one block at a time. In line 12 of the algorithm, \(\tilde{m_{ij}}\) holds the maximum element of each row of \(S_{ij}^\mathrm{masked}\), and next, in line 13, \(m_i^\mathrm{new}\) holds the row-wise maximum of the \(m_i\) seen so far and the new \(\tilde{m_{ij}}\). Hence \(m_i\) is updated at every iteration of the outer loop (one per column block) and eventually stores the row-wise max of the matrix \(\mathbf{S}\). The same logic goes for \(l_i\) and the matrix \(\mathbf{P}\). The results are combined to get the output attention matrix in line 15.
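Putting the pieces together, here is a pure-Python sketch of the tiled forward pass without masking or dropout. It is only meant to show the bookkeeping of \(m_i\), \(l_i\) and \(O_i\): list slices stand in for loading blocks into SRAM, the block sizes are arbitrary, and none of the actual CUDA-level optimizations are modeled.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, k, v):
    """Reference that builds a full row of S and P for every query."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        s = [scale * dot(qi, kj) for kj in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / l
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_rows=2, block_cols=2):
    """Forward pass that only ever holds one block_rows x block_cols tile of S."""
    n, d, dv = len(q), len(q[0]), len(v[0])
    scale = 1.0 / math.sqrt(d)
    o = [[0.0] * dv for _ in range(n)]   # running, already normalized output
    l = [0.0] * n                        # running softmax denominators
    m = [-math.inf] * n                  # running row maxima
    for j0 in range(0, n, block_cols):               # outer loop over K, V blocks
        kj, vj = k[j0:j0 + block_cols], v[j0:j0 + block_cols]
        for i0 in range(0, n, block_rows):           # inner loop over Q blocks
            for i in range(i0, min(i0 + block_rows, n)):
                s = [scale * dot(q[i], kr) for kr in kj]   # one row of the tile
                m_tile = max(s)
                p = [math.exp(x - m_tile) for x in s]
                l_tile = sum(p)
                m_new = max(m[i], m_tile)
                a_old = math.exp(m[i] - m_new)      # rescales earlier partial results
                a_tile = math.exp(m_tile - m_new)   # rescales the current tile
                l_new = a_old * l[i] + a_tile * l_tile
                pv = [sum(pc * vr[c] for pc, vr in zip(p, vj)) for c in range(dv)]
                o[i] = [(l[i] * a_old * o[i][c] + a_tile * pv[c]) / l_new
                        for c in range(dv)]
                m[i], l[i] = m_new, l_new
    return o, m, l


n, d = 5, 3
q = [[math.sin(3 * i + c) for c in range(d)] for i in range(n)]
k = [[math.cos(2 * i - c) for c in range(d)] for i in range(n)]
v = [[0.5 * i - c for c in range(d)] for i in range(n)]
out, m_stats, l_stats = tiled_attention(q, k, v)
ref = naive_attention(q, k, v)
```

Here `out` matches `ref` up to rounding, and the only extra state kept per row is the pair `(m, l)`, which is what makes the memory overhead linear in \(N\).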
Recomputation
The backward pass of FlashAttention requires the \(\mathbf{S}\) and \(\mathbf{P}\) matrices to compute the gradients w.r.t. \(\mathbf{Q}\), \(\mathbf{K}\), \(\mathbf{V}\). However, they are \(N \times N\) matrices and, as can be seen in the algorithm above, they aren’t stored explicitly. The trick is that, using the output \(\mathbf{O}\) and the softmax normalization statistics \((m, l)\), we can easily recompute the attention matrices \(\mathbf{S}\) and \(\mathbf{P}\) in the backward pass from blocks of \(\mathbf{Q}\), \(\mathbf{K}\), \(\mathbf{V}\) in SRAM. Even with more FLOPs, the recomputation step speeds up the backward pass due to reduced HBM accesses. The backward pass is very interesting too but slightly more complicated, hence I’ll probably cover it in a separate post. One can refer to Appendix B of the paper to learn more.
Kernel Fusion is also used to implement the algorithm in one CUDA kernel, loading input from HBM, performing all the computation steps (matrix multiply, softmax, optionally masking and dropout, matrix multiply), then writing the result back to HBM. This avoids repeatedly reading and writing inputs and outputs from and to HBM.
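A small sketch of the recomputation idea: given the saved per-row statistics, any tile of \(\mathbf{P}\) can be rebuilt from the matching blocks of \(\mathbf{Q}\) and \(\mathbf{K}\) as \(P_{ij} = e^{S_{ij} - m_i} / l_i\). The helper names are hypothetical, the statistics are computed here in the simplest possible way rather than by a tiled forward pass, and the gradient computation itself is not shown.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def forward_stats(q, k):
    """Per-row max m and denominator l: the O(N) values kept from the forward pass."""
    scale = 1.0 / math.sqrt(len(q[0]))
    m, l = [], []
    for qi in q:
        s = [scale * dot(qi, kj) for kj in k]
        mi = max(s)
        m.append(mi)
        l.append(sum(math.exp(x - mi) for x in s))
    return m, l


def recompute_p_block(q_blk, k_blk, m_blk, l_blk):
    """Rebuild one tile of P from Q and K blocks plus the saved statistics."""
    scale = 1.0 / math.sqrt(len(q_blk[0]))
    return [[math.exp(scale * dot(qi, kj) - mi) / li for kj in k_blk]
            for qi, mi, li in zip(q_blk, m_blk, l_blk)]


q = [[0.1 * (i + 1), -0.2 * i] for i in range(4)]
k = [[0.3 * i, 0.5 - 0.1 * i] for i in range(4)]
m, l = forward_stats(q, k)

# Rows 0-1, columns 2-3 of P, recomputed without ever storing the full matrix.
tile = recompute_p_block(q[0:2], k[2:4], m[0:2], l[0:2])
# Full P, built the same way, used here only for comparison.
full_p = recompute_p_block(q, k, m, l)
```

Each recomputed tile costs an extra block matrix multiply, which is the additional FLOP count traded for not reading an \(N \times N\) matrix from HBM.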
Important Information - The FlashAttention algorithm computes \(\mathbf{O} = softmax(QK^\mathsf{T})V\) with \(O(N^2d)\) FLOPs and requires \(O(N)\) additional memory beyond inputs and output (for the \((l, m)\) statistics).
The proof for the FLOPs calculation is given in Appendix C of the paper, which should be checked out by the curious reader.
Important Information - Let \(N\) be the sequence length, \(d\) be the head dimension, and \(M\) be the size of SRAM with \(d \leq M \leq Nd\). Standard attention requires \(\Theta(Nd + N^2)\) HBM accesses while FlashAttention requires \(\Theta(N^2d^2M^{-1})\) HBM accesses.
For typical values of \(d\) (64-128) and \(M\) (around 100KB), \(d^2\) is many times smaller than \(M\), and thus FlashAttention requires many times fewer HBM accesses than the standard implementation. This leads to both faster execution and a lower memory footprint.
The authors also go on to show that the number of HBM accesses by FlashAttention is a lower bound. There can be no implementation which can asymptotically improve on the number of HBM accesses for all values of \(M\) when doing exact attention calculation.
As the block size increases, the number of HBM accesses decreases as there are fewer passes over the input, and the runtime also decreases. However, beyond a block size of 256, the runtime starts getting bottlenecked by other factors like arithmetic operations. There is also a limit on how large we can choose the block size to be, as the blocks need to fit in the SRAM.
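To get a feel for the gap, the two expressions can be evaluated directly. This is a back-of-the-envelope sketch only: the \(\Theta\) bounds hide constant factors, and measuring \(M\) as 100KB of 4-byte elements is my assumption.

```python
def standard_hbm_accesses(n, d):
    """Leading terms of the standard attention bound: N*d + N^2."""
    return n * d + n * n


def flash_hbm_accesses(n, d, m):
    """Leading term of the FlashAttention bound: N^2 * d^2 / M."""
    return n * n * d * d / m


n, d = 4096, 64
m = 100 * 1024 // 4   # assumed SRAM size in elements: 100KB of 4-byte values

standard = standard_hbm_accesses(n, d)
flash = flash_hbm_accesses(n, d, m)
ratio = standard / flash
print(standard, flash, ratio)
```

The ratio depends mostly on \(M / d^2\), so a smaller head dimension or a larger SRAM widens the gap, while the constants ignored here can shift the real number in either direction.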
Block-Sparse FlashAttention
As mentioned in the overview, FlashAttention can be used to make an approximate attention algorithm as well. The authors call it Block-Sparse FlashAttention, and it is faster than the approximate attention methods they compare against. Its IO complexity is smaller than that of FlashAttention by a factor proportional to the sparsity.
For inputs \(\mathbf{Q, K, V} \in \mathbb{R}^{N \times d}\) and a mask \(\tilde{\mathbf{M}} \in \{ 0,1 \}^{N \times N}\), we want to calculate 
\[\mathbf{S} = \mathbf{QK^\mathsf{T}} \in \mathbb{R}^{N \times N},\quad \mathbf{P} = softmax(\mathbf{S} \odot \mathbb{1}_{\tilde{\mathbf{M}}}) \in \mathbb{R}^{N \times N},\quad \mathbf{O} = \mathbf{PV} \in \mathbb{R}^{N \times d}\]Here \((\mathbf{S} \odot \mathbb{1}_{\tilde{\mathbf{M}}})_{kl}\) equals \(\mathbf{S}_{kl}\) where \(\tilde{\mathbf{M}}_{kl} = 1\) and \(-\infty\) where \(\tilde{\mathbf{M}}_{kl} = 0\). Given a predefined block sparsity mask \(\mathbf{M} \in \{ 0,1 \}^{N/B_r \times N/B_c}\), Algorithm 2 above can be adapted to only compute the nonzero blocks of the attention matrix. We can just skip the zero blocks. The algorithm shown below describes the forward pass of block-sparse FlashAttention.
Important Information - Let \(N\) be the sequence length, \(d\) be the head dimension, and \(M\) be the size of SRAM with \(d \leq M \leq Nd\). Block-sparse FlashAttention requires \(\Theta(Nd + N^2d^2M^{-1}s)\) HBM accesses, where \(s\) is the fraction of nonzero blocks in the block-sparsity mask.
For large sequence lengths, \(s\) is often set to \(N^{-1/2}\) or \(N^{-1} \log N\), resulting in \(\Theta(N \sqrt{N})\) or \(\Theta(N \log N)\) IO complexity. As the sparsity increases, the runtime of block-sparse FlashAttention improves proportionally.
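The effect of the mask can be sketched in a few lines. The banded mask below is a hypothetical example (the paper does not prescribe this pattern), the value of \(M\) is again an assumed 100KB of 4-byte elements, and the access count is just the bound above evaluated with constants dropped.

```python
def banded_block_mask(num_row_blocks, num_col_blocks, bandwidth):
    """Hypothetical mask that keeps blocks within `bandwidth` of the block diagonal."""
    return [[1 if abs(i - j) <= bandwidth else 0 for j in range(num_col_blocks)]
            for i in range(num_row_blocks)]


def nonzero_fraction(mask):
    """The fraction s of nonzero blocks in the block-sparsity mask."""
    total = sum(len(row) for row in mask)
    return sum(sum(row) for row in mask) / total


def blocks_to_compute(mask):
    """Tiles the forward pass would visit; zero blocks are skipped entirely."""
    return [(i, j) for i, row in enumerate(mask) for j, keep in enumerate(row) if keep]


def block_sparse_hbm_accesses(n, d, m, s):
    """Leading terms of the block-sparse bound: N*d + N^2 * d^2 / M * s."""
    return n * d + n * n * d * d / m * s


mask = banded_block_mask(8, 8, bandwidth=1)
s = nonzero_fraction(mask)
visited = blocks_to_compute(mask)
dense = block_sparse_hbm_accesses(4096, 64, 25600, 1.0)
sparse = block_sparse_hbm_accesses(4096, 64, 25600, s)
```

With this mask only 22 of the 64 tiles are visited, and the quadratic term of the bound shrinks by the same fraction while the \(Nd\) term for reading the inputs stays.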
Experiments
There are tons of results in the paper. But the TL;DR is that FlashAttention beats all other exact attention algorithms in both training speed and quality of the models/downstream models, especially when pushed to the limits of sequence length. I’ll add the plots and graphs for their various results here. Additional results are present in the paper.
Training Speed
BERT
GPT-2
Long-range Arena
Block-sparse FlashAttention is faster than all of the approximate attention methods that were tested.
Model Quality
Language Modeling with Long Context
Long Document Classification
Since FlashAttention allows training on longer sequences, it improves performance on such datasets. MIMIC-III contains intensive care unit patient discharge summaries, each annotated with multiple labels. ECtHR contains legal cases from the European Court of Human Rights, each of which is mapped to articles of the Convention of Human Rights that were allegedly violated. Both of these datasets contain very long text documents. The average document in MIMIC-III has 2395 tokens and the longest document contains 14562 tokens.
Path-X and Path-256
These are challenging tasks from the Long-range Arena benchmark where the task is to classify whether two points in a black and white 128×128 (or 256×256) image have a path connecting them, and the images are fed to the transformer one pixel at a time. No transformer model in the past has been able to model these tasks effectively. They have either run out of memory or achieved only random performance. FlashAttention yields the first Transformer that can achieve better-than-random performance on the challenging Path-X task (sequence length 16K), and block-sparse FlashAttention yields the first sequence model that can achieve better-than-random performance on Path-256 (sequence length 64K).
Benchmarking Attention
Runtime
FlashAttention beats all exact attention baselines and is about 3× faster than the PyTorch implementation. The runtimes of many approximate/sparse attention mechanisms grow linearly with sequence length, but FlashAttention still runs faster than approximate and sparse attention for short sequences due to fewer memory accesses. The approximate attention runtimes begin to cross over with FlashAttention at sequence lengths between 512 and 1024. On the other hand, block-sparse FlashAttention is faster than all implementations of exact, sparse, and approximate attention that are available, across all sequence lengths.
Memory Footprint
FlashAttention and block-sparse FlashAttention have the same memory footprint, which grows linearly with sequence length. FlashAttention is up to 20× more memory-efficient than exact attention baselines, and is more memory-efficient than the approximate attention baselines. All other algorithms except for Linformer run out of memory on an A100 GPU before 64K, and FlashAttention is still 2× more efficient than Linformer.
A great paper overall, tremendous impact and personally, I had loads to learn from it!