Scaled dot-product attention hides a brutal cost. To compute
\operatorname{softmax}(QK^{\top}/\sqrt{d_k})\,V, a standard implementation first builds the
full n \times n score matrix S = QK^{\top},
writes it to the GPU's large-but-slow memory (HBM), reads it back to apply the softmax, writes the result out again, and reads
it once more to multiply by V. That n \times n
matrix takes O(n^2) memory, and shuffling it in and out of slow memory,
not the arithmetic, is what dominates the runtime.
FlashAttention computes exactly the same answer without ever building
that matrix. The trick: tile the inputs, do the math in fast on-chip memory, and stitch the tiles
together with an online softmax that never needs to see all the scores at once.
Tiling attention, line by line
The obstacle is the softmax, which seems to need a whole row of scores before it can normalize.
We dismantle that, one step at a time.
Step 1 — name the bottleneck. A GPU has a little very-fast on-chip memory
(SRAM) and a lot of slow off-chip memory (HBM). Standard attention materializes
S = QK^{\top} \in \mathbb{R}^{n\times n} in HBM:
\text{memory} = O(n^2), \qquad \text{cost} = \text{dominated by HBM reads/writes (IO-bound).}
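To get a feel for the O(n^2) term, here is a back-of-the-envelope sketch. The sequence length, head dimension, and 2-byte (fp16) element size are illustrative assumptions, not figures from the text, and the counts are for a single attention head.

```python
# Rough memory arithmetic for one attention head.
# n, d, and the 2-byte element size are illustrative assumptions.

def score_matrix_bytes(n, bytes_per_element=2):
    """Bytes needed to hold the full n x n score matrix S."""
    return n * n * bytes_per_element

def qkv_out_bytes(n, d, bytes_per_element=2):
    """Bytes for Q, K, V and the output, each of shape n x d."""
    return 4 * n * d * bytes_per_element

n, d = 8192, 64
print(score_matrix_bytes(n) / 2**20, "MiB for S")           # 128.0 MiB for S
print(qkv_out_bytes(n, d) / 2**20, "MiB for Q, K, V, out")  # 4.0 MiB for Q, K, V, out
```

Under these assumptions the score matrix is 32 times larger than all the inputs and outputs combined, and the gap widens linearly as n grows.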
Step 2 — tile Q, K, V into blocks. Cut the queries
into row-blocks Q_1, Q_2, \dots and the keys/values into
(K_1, V_1), (K_2, V_2), \dots, each block holding B rows. For one query block Q_i, we will sweep over
the key/value blocks K_j, V_j, holding only one tile in SRAM at a time:
S_{ij} = Q_i K_j^{\top} \in \mathbb{R}^{B \times B} \quad(\text{a small tile, never the full } n\times n).
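The tiling in Step 2 can be checked on a toy example. The sketch below uses plain Python lists and made-up 4 x 2 matrices; the helper names `matmul_t` and `blocks` are hypothetical, introduced only for illustration. It confirms that each tile Q_i K_j^T is exactly the matching B x B window of the full score matrix.

```python
def matmul_t(a, b):
    """Return a @ b^T, where a and b are lists of row vectors."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]

def blocks(rows, size):
    """Split a list of rows into consecutive blocks of `size` rows."""
    return [rows[i:i + size] for i in range(0, len(rows), size)]

# Toy inputs (illustrative values only).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
K = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0], [1.0, 1.0]]
B = 2

full = matmul_t(Q, K)  # the n x n matrix a tiled kernel avoids storing
all_match = True
for i, Qi in enumerate(blocks(Q, B)):
    for j, Kj in enumerate(blocks(K, B)):
        tile = matmul_t(Qi, Kj)  # S_ij, a B x B tile
        window = [row[j * B:(j + 1) * B] for row in full[i * B:(i + 1) * B]]
        all_match = all_match and (tile == window)

print(all_match)
```

The full matrix is built here only to have something to compare against; a tiled kernel would compute each tile, use it, and discard it.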
Step 3 — recall the safe softmax. To avoid overflow, softmax subtracts the row
max m = \max_j s_j:
\operatorname{softmax}(s)_j = \frac{e^{\,s_j - m}}{\sum_k e^{\,s_k - m}} = \frac{e^{\,s_j - m}}{\ell}, \qquad \ell = \sum_k e^{\,s_k - m}.
Both the max m and the normalizer \ell
appear to need the whole row. The online-softmax trick says: they don't — keep them
running.
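A small sketch of why the max subtraction matters, using only Python's `math` module. The function names and the example scores are illustrative assumptions.

```python
import math

def naive_softmax(s):
    """Softmax without max subtraction; overflows for large scores."""
    exps = [math.exp(x) for x in s]
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax(s):
    """Softmax with the row max m subtracted before exponentiating."""
    m = max(s)
    exps = [math.exp(x - m) for x in s]
    ell = sum(exps)  # the normalizer, called ell in the text
    return [e / ell for e in exps]

scores = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(scores)
except OverflowError:
    # math.exp(1000.0) exceeds the range of a double.
    print("naive softmax overflowed")

print(safe_softmax(scores))
```

Both versions agree whenever the naive one does not overflow, because the factor e^{-m} cancels between numerator and denominator.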
Step 4 — the online softmax update. Carry a running max
m, a running normalizer \ell, and a running
weighted output o, initialized to m = -\infty, \ell = 0, o = 0. When a new block of scores arrives with block-max
\tilde m, update the running max and rescale what you had so
far to the new max:
m_{\text{new}} = \max(m, \tilde m), \qquad \ell \leftarrow e^{\,m - m_{\text{new}}}\,\ell + \sum_{j \in \text{block}} e^{\,s_j - m_{\text{new}}},
o \leftarrow e^{\,m - m_{\text{new}}}\,o + \sum_{j \in \text{block}} e^{\,s_j - m_{\text{new}}}\, v_j, \qquad m \leftarrow m_{\text{new}}.
The factor e^{\,m - m_{\text{new}}} retro-corrects the earlier blocks
for the newly discovered larger max — so after the last block, o/\ell
is identical to the softmax computed over the entire row at once. Nothing is
approximated.
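The update rule can be sketched for a single row of scores. To keep it short, the values v_j are scalars here (an assumption for illustration; in attention they are vectors), and the function names are hypothetical. The sketch starts from m = -inf, ell = 0, o = 0 and compares the block-wise result with a one-shot softmax.

```python
import math

def full_softmax_weighted_sum(scores, values):
    """Reference: softmax over the whole row, then a weighted sum of values."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    return sum(pj * vj for pj, vj in zip(p, values)) / sum(p)

def online_softmax_weighted_sum(scores, values, block_size):
    """Same quantity, seeing the row one block at a time."""
    m, ell, o = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # exp(-inf) is 0.0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]
        ell = scale * ell + sum(p)
        o = scale * o + sum(pj * vj for pj, vj in zip(p, v_blk))
        m = m_new  # carry the new max into the next block
    return o / ell

scores = [0.3, 2.0, -1.0, 4.5, 0.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(full_softmax_weighted_sum(scores, values))
print(online_softmax_weighted_sum(scores, values, block_size=2))
```

The two printed numbers agree up to floating-point rounding; the block size only changes the order in which terms are summed.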
Step 5 — sweep, then divide. Loop over all key/value blocks for each query
block, updating (m, \ell, o) in SRAM; at the end, the attention output
for that query block is
\text{out}_i = \frac{o}{\ell}, \qquad \text{stored once to HBM.}
The full n \times n score matrix is never written to HBM — only the
O(n) running statistics and the O(n) output.
Memory drops from O(n^2) to O(n), and the
slow-memory traffic plummets. Same exact result, a fraction of the IO.
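Putting the steps together, here is a minimal pure-Python sketch of the sweep. It is a readability aid under stated assumptions, not a GPU kernel: lists stand in for SRAM tiles, the block size of 2 is arbitrary, the function names are hypothetical, and the 1/\sqrt{d_k} scaling from the opening formula is applied to each score.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention(Q, K, V):
    """Reference: one full row of scores per query, then a safe softmax."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * dot(q, k) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        ell = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / ell
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block=2):
    """Block-wise attention with running (m, ell, o) per query row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for qs in range(0, len(Q), block):        # outer loop: query blocks
        Qi = Q[qs:qs + block]
        m = [-math.inf] * len(Qi)
        ell = [0.0] * len(Qi)
        o = [[0.0] * dv for _ in Qi]
        for ks in range(0, len(K), block):    # inner loop: key/value blocks
            Kj, Vj = K[ks:ks + block], V[ks:ks + block]
            for r, q in enumerate(Qi):
                s = [scale * dot(q, k) for k in Kj]  # one row of the tile
                m_new = max(m[r], max(s))
                c = math.exp(m[r] - m_new)           # rescale earlier blocks
                p = [math.exp(x - m_new) for x in s]
                ell[r] = c * ell[r] + sum(p)
                o[r] = [c * o[r][col] + sum(pj * v[col] for pj, v in zip(p, Vj))
                        for col in range(dv)]
                m[r] = m_new
        # Divide once at the end and emit the finished rows for this block.
        out.extend([[x / ell[r] for x in o[r]] for r in range(len(Qi))])
    return out

# Deterministic toy inputs (illustrative): n = 5 rows, d = 3 columns.
Q = [[math.sin(i + j) for j in range(3)] for i in range(5)]
K = [[math.cos(i * j + 1) for j in range(3)] for i in range(5)]
V = [[float(i - j) for j in range(3)] for i in range(5)]
print(tiled_attention(Q, K, V)[0])
print(naive_attention(Q, K, V)[0])
```

At no point does `tiled_attention` hold more than one `block`-wide slice of scores per row, yet its output matches the reference up to rounding, including when n is not a multiple of the block size.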
Exact attention can be computed block-by-block:
- No n\times n materialization. The full score matrix QK^{\top} is never written to slow HBM; only the output and a few running statistics are.
- Tiling + online softmax. Tile Q, K, V into blocks kept in fast SRAM; maintain a running max m, normalizer \ell, and output o, rescaling by e^{\,m-m_{\text{new}}} as each block arrives.
- Exact, with an IO-aware speedup. The result equals standard attention up to floating-point rounding (the math is identical; only the order of summation differs), in O(n) memory instead of O(n^2), with far fewer HBM reads/writes, so the kernel runs several times faster on long sequences.
A GPU kernel is compute-bound if its arithmetic units are the limit, and
memory-bound if it spends most of its time waiting on data from HBM. Modern
GPUs can do hundreds of floating-point operations in the time it takes to fetch one number from
HBM, so a kernel that does little math per byte moved is memory-bound — the expensive ALUs sit
idle.
Standard attention is exactly this case: reading and writing the
n \times n score matrix moves enormous amounts of data for
relatively little arithmetic, so it is firmly memory-bound. FlashAttention does the
same arithmetic but keeps the scores in SRAM, slashing HBM traffic — which is why it is
faster even though it is not doing fewer floating-point operations. It is an IO-aware
algorithm: it optimizes the thing that actually costs, the memory movement, not the FLOPs.
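A rough way to see the memory-bound claim is to estimate arithmetic intensity, FLOPs per byte moved to or from HBM. The sketch below is a simplified model under stated assumptions: 2-byte (fp16) elements, two matrix multiplies of about 2 n^2 d FLOPs each, and the write/read/write/read sequence for the n x n matrix described at the start of the article. It ignores caches and the softmax arithmetic, and all function names are hypothetical.

```python
def attention_flops(n, d):
    """Approximate FLOPs for Q K^T and P V: two matmuls of 2*n*n*d each."""
    return 4 * n * n * d

def standard_hbm_bytes(n, d, bytes_per_element=2):
    """Q, K, V read and output written (4*n*d elements), plus the
    n x n matrix written, read, written, and read again (4*n*n elements)."""
    return (4 * n * d + 4 * n * n) * bytes_per_element

def ideal_hbm_bytes(n, d, bytes_per_element=2):
    """Idealized floor: Q, K, V, and the output each cross HBM exactly once.
    Real tiled kernels re-read K and V blocks, so they move more than this."""
    return 4 * n * d * bytes_per_element

n, d = 4096, 64  # illustrative sizes
flops = attention_flops(n, d)
print(round(flops / standard_hbm_bytes(n, d), 1))  # 31.5 FLOPs per byte
print(round(flops / ideal_hbm_bytes(n, d), 1))     # 2048.0 FLOPs per byte
```

In this model the standard kernel's intensity can never exceed d/2 FLOPs per byte no matter how long the sequence is, because both the FLOPs and the score-matrix traffic scale as n^2. Keeping the scores on-chip removes that n^2 traffic term, which is what lets intensity grow with n.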