FlashAttention

Scaled dot-product attention hides a brutal cost. To compute \operatorname{softmax}(QK^{\top}/\sqrt{d_k})\,V, it first builds the full n \times n score matrix S = QK^{\top} and writes it to the GPU's large but slow memory (HBM). It then reads S back to apply the softmax, writes the resulting probabilities out again, and reads them once more to multiply by V. That n \times n matrix costs O(n^2) memory, and shuffling it in and out of slow memory, not the arithmetic, is what dominates the runtime.
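For reference, here is a minimal pure-Python sketch of the standard computation. The function name `naive_attention` and the list-of-rows representation are illustrative assumptions, not any library's API. The point is that it builds the full S and the full probability matrix, which are the n \times n intermediates a GPU kernel would round-trip through HBM.

```python
import math

def naive_attention(Q, K, V):
    """Standard attention that materializes the full n x n score matrix S.

    Q, K, V are lists of rows (lists of floats); rows of Q and K share width d_k.
    Returns (output, S) so the size of S can be inspected.
    """
    d_k = len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    # S = Q K^T / sqrt(d_k): one score per (query, key) pair, so n x n entries.
    S = [[scale * sum(q * k for q, k in zip(q_row, k_row)) for k_row in K]
         for q_row in Q]
    # Row-wise safe softmax over each fully materialized row of S.
    P = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # Output = P V: weight every value row by its probability.
    d_v = len(V[0])
    out = [[sum(p * v_row[c] for p, v_row in zip(p_row, V)) for c in range(d_v)]
           for p_row in P]
    return out, S
```

With three queries and three keys the returned S has 3 \times 3 entries; doubling the sequence length quadruples it.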

FlashAttention computes exactly the same answer without ever materializing that matrix in HBM. The trick: tile the inputs, do the math in fast on-chip memory (SRAM), and stitch the tiles together with an online softmax that never needs to see all the scores at once.

Tiling attention, line by line

The obstacle is the softmax, which seems to need a whole row of scores before it can normalize. We dismantle that, one step at a time.

Step 1 — name the bottleneck. A GPU has a little very-fast on-chip memory (SRAM) and a lot of slow off-chip memory (HBM). Standard attention materializes S = QK^{\top} \in \mathbb{R}^{n\times n} in HBM:

\text{memory} = O(n^2), \qquad \text{cost} = \text{dominated by HBM reads/writes (IO-bound).}
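A quick worked calculation of what that O(n^2) means in bytes for one attention head. It assumes 2-byte (fp16) scores and 4-byte (fp32) running statistics; both choices, and the helper names, are illustrative.

```python
def score_matrix_bytes(n, bytes_per_elem=2):
    """Bytes needed to materialize one n x n score matrix (one head, one batch item)."""
    return n * n * bytes_per_elem

def running_stats_bytes(n, bytes_per_elem=4):
    """Bytes for one running max and one running normalizer per query row."""
    return 2 * n * bytes_per_elem

for n in (1024, 8192, 65536):
    s = score_matrix_bytes(n)
    r = running_stats_bytes(n)
    print(f"n={n:>6}: S = {s / 2**20:10.1f} MiB   stats = {r / 2**10:8.1f} KiB")
```

For n = 8192 this gives 128 MiB for S versus 64 KiB for a running max and normalizer per row, and the gap widens quadratically with n.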

Step 2 — tile Q, K, V into blocks. Cut the queries into row-blocks Q_1, Q_2, \dots of B rows each, and the keys/values into matching blocks (K_1, V_1), (K_2, V_2), \dots. For each query block, we sweep over the key/value blocks, holding only one tile in SRAM at a time:

S_{ij} = Q_i K_j^{\top} \in \mathbb{R}^{B \times B} \quad(\text{a small tile, never the full } n\times n).
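A sketch of the tiling itself, assuming matrices are plain lists of rows; `blocks`, `score_tile`, and `assemble` are hypothetical helpers. Stitching every tile back together reproduces QK^{\top}, which shows the tiles cover the whole score matrix even though a real kernel only ever holds one of them. The 1/\sqrt{d_k} scaling is left out here to match S_{ij} = Q_i K_j^{\top} above; it can be folded into each tile.

```python
def blocks(n, B):
    """Index ranges [start, end) that cut n rows into blocks of at most B rows."""
    return [(start, min(start + B, n)) for start in range(0, n, B)]

def score_tile(Q, K, q_range, k_range):
    """S_ij = Q_i K_j^T for one (query block, key block) pair: a small tile."""
    (qs, qe), (ks, ke) = q_range, k_range
    return [[sum(a * b for a, b in zip(Q[r], K[c])) for c in range(ks, ke)]
            for r in range(qs, qe)]

def assemble(Q, K, B):
    """Stitch all tiles into the full matrix (only to check coverage; a real kernel never does this)."""
    n_q, n_k = len(Q), len(K)
    S = [[None] * n_k for _ in range(n_q)]
    for qs, qe in blocks(n_q, B):
        for ks, ke in blocks(n_k, B):
            tile = score_tile(Q, K, (qs, qe), (ks, ke))
            for r in range(qs, qe):
                S[r][ks:ke] = tile[r - qs]   # place this tile's row segment
    return S
```

The last block may be shorter than B when B does not divide n; `blocks` handles that with `min`.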

Step 3 — recall the safe softmax. To avoid overflow, softmax subtracts the row max m = \max_j s_j:

\operatorname{softmax}(s)_j = \frac{e^{\,s_j - m}}{\sum_k e^{\,s_k - m}} = \frac{e^{\,s_j - m}}{\ell}, \qquad \ell = \sum_k e^{\,s_k - m}.
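The safe softmax in a few lines of Python (the name `safe_softmax` is illustrative). Without the subtraction, `math.exp(1000.0)` raises `OverflowError`; with it, every exponent is at most 0.

```python
import math

def safe_softmax(s):
    """Softmax with the row max m subtracted first, so every exponent is <= 0."""
    m = max(s)
    exps = [math.exp(x - m) for x in s]
    ell = sum(exps)  # the normalizer \ell
    return [e / ell for e in exps]

# Subtracting m cancels in the ratio, so the result is unchanged mathematically,
# but exp() now stays in range even for very large scores.
print(safe_softmax([1000.0, 1001.0]))
```

The call prints probabilities of roughly 0.269 and 0.731, that is, 1/(1+e) and e/(1+e).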

Both the max m and the normalizer \ell appear to need the whole row. The online-softmax trick says: they don't — keep them running.
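Here is the running-statistics idea on its own, for a single row of scores arriving one at a time. It is a minimal sketch (the name `online_normalizer` is ours): it keeps only m and \ell, never the row, and starts from m = -\infty, \ell = 0 so the first score needs no special case.

```python
import math

def online_normalizer(stream):
    """One pass over a row of scores, keeping only a running max m and normalizer ell."""
    m, ell = -math.inf, 0.0
    for s in stream:
        m_new = max(m, s)
        # Rescale the old sum to the new max (exp(-inf) == 0.0 on the first step),
        # then add the new term.
        ell = ell * math.exp(m - m_new) + math.exp(s - m_new)
        m = m_new
    return m, ell

m, ell = online_normalizer([1.0, 3.0, 2.0])
print(m, ell)  # ell equals sum(exp(s - 3.0)) over the whole row
```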

Step 4 — the online softmax update. For each query row, carry a running max m, a running normalizer \ell, and a running unnormalized output o, initialized to m = -\infty, \ell = 0, o = 0. When a new block of scores arrives with block max \tilde m, update the running max and rescale what you had so far to the new max:

m_{\text{new}} = \max(m, \tilde m), \qquad \ell \leftarrow e^{\,m - m_{\text{new}}}\,\ell + \sum_{j \in \text{block}} e^{\,s_j - m_{\text{new}}}, \qquad o \leftarrow e^{\,m - m_{\text{new}}}\,o + \sum_{j \in \text{block}} e^{\,s_j - m_{\text{new}}}\, v_j, \qquad m \leftarrow m_{\text{new}}.
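The Step 4 update as a function for one query row. This is a sketch under stated assumptions: the row's scores arrive in blocks alongside their value rows, and `online_update` is a hypothetical name. The usage feeds two blocks and compares o/\ell with a one-shot softmax.

```python
import math

def online_update(m, ell, o, scores, values):
    """Fold one block of scores (single query row) and their value rows into (m, ell, o).

    m: running max, ell: running normalizer, o: running unnormalized output (list of floats).
    """
    m_new = max(m, max(scores))
    alpha = math.exp(m - m_new)                       # rescales everything accumulated so far
    weights = [math.exp(s - m_new) for s in scores]   # this block's terms, relative to m_new
    ell = alpha * ell + sum(weights)
    o = [alpha * o_c + sum(w * v[c] for w, v in zip(weights, values))
         for c, o_c in enumerate(o)]
    return m_new, ell, o

# Usage: one query row, four keys split into two blocks of two.
scores = [0.5, -1.0, 2.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
state = (-math.inf, 0.0, [0.0, 0.0])   # initial m, ell, o
for start in (0, 2):
    state = online_update(*state, scores[start:start + 2], values[start:start + 2])
m, ell, o = state
print([x / ell for x in o])  # should match softmax(scores) applied to values in one shot
```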

The factor e^{\,m - m_{\text{new}}} (at most 1, and exactly 1 when the max did not change) retro-corrects the earlier blocks for the newly discovered larger max. So after the last block, o/\ell equals the softmax-weighted sum of the values, exactly as if the entire row had been processed at once. Nothing is approximated.
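A worked instance of the correction with two one-score blocks, s_1 = 1 (value v_1) then s_2 = 3 (value v_2), starting from m = -\infty, \ell = 0, o = 0:

```latex
\begin{aligned}
&\text{block 1 } (s_1 = 1):\quad m = 1,\quad \ell = e^{0} = 1,\quad o = v_1\\
&\text{block 2 } (s_2 = 3):\quad m_{\text{new}} = 3,\quad \ell = e^{1-3}\cdot 1 + e^{3-3} = e^{-2} + 1,\quad o = e^{-2}\, v_1 + v_2\\
&\frac{o}{\ell} = \frac{e^{-2}\, v_1 + v_2}{e^{-2} + 1}
 = \frac{e^{1}\, v_1 + e^{3}\, v_2}{e^{1} + e^{3}}
 = \operatorname{softmax}(1, 3)_1\, v_1 + \operatorname{softmax}(1, 3)_2\, v_2 .
\end{aligned}
```

The middle equality multiplies numerator and denominator by e^{3}; the rescaling by e^{-2} is precisely what turns the first block's weight e^{0} into the e^{1 - 3} it would have had if the max had been known from the start.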

Step 5 — sweep, then divide. Loop over all key/value blocks for each query block, updating (m, \ell, o) in SRAM; at the end, the attention output for that query block is

\text{out}_i = \frac{o}{\ell}, \qquad \text{written to HBM once.}
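Steps 2 through 5 put together as a pure-Python sketch of the forward pass. It models only the arithmetic and the loop structure (outer loop over query blocks, inner sweep over key/value blocks); the lists stand in for SRAM tiles, so it says nothing about real GPU speed, and `flash_attention_forward` is a hypothetical name rather than any library's API.

```python
import math

def flash_attention_forward(Q, K, V, B):
    """Tiled exact attention: per query block, sweep key/value blocks with an online softmax.

    Q, K, V are lists of rows; B is the block size for both queries and keys.
    """
    n_q, n_k, d_v = len(Q), len(K), len(V[0])
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for qs in range(0, n_q, B):                      # outer loop: query blocks Q_i
        q_block = Q[qs:qs + B]
        # Per-row running statistics, kept "on chip" for this query block only.
        m = [-math.inf] * len(q_block)
        ell = [0.0] * len(q_block)
        o = [[0.0] * d_v for _ in q_block]
        for ks in range(0, n_k, B):                  # inner loop: key/value blocks (K_j, V_j)
            k_block, v_block = K[ks:ks + B], V[ks:ks + B]
            for r, q in enumerate(q_block):
                # Row r of the tile S_ij, scaled by 1/sqrt(d_k).
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_block]
                m_new = max(m[r], max(s))
                alpha = math.exp(m[r] - m_new)       # correct earlier blocks to the new max
                w = [math.exp(x - m_new) for x in s]
                ell[r] = alpha * ell[r] + sum(w)
                o[r] = [alpha * o[r][c] + sum(wj * vj[c] for wj, vj in zip(w, v_block))
                        for c in range(d_v)]
                m[r] = m_new
        # Divide once at the end and "write" this block's output rows.
        out.extend([[x / ell[r] for x in o[r]] for r in range(len(q_block))])
    return out
```

Any block size B gives the same output up to floating-point rounding; B \ge n degenerates to one tile per row, which is just the safe softmax of Step 3.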

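The full n \times n score matrix is never written to HBM; only the O(n) running statistics and the n \times d output are. The extra memory beyond the inputs and output drops from O(n^2) to O(n), and traffic to slow memory falls sharply. The result is the same, with no approximation beyond floating-point rounding, for a fraction of the IO.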

Exact attention can be computed block by block.

A GPU kernel is compute-bound if its arithmetic units are the limit, and memory-bound if it spends most of its time waiting on data from HBM. Modern GPUs can do hundreds of floating-point operations in the time it takes to fetch one number from HBM, so a kernel that does little math per byte moved is memory-bound — the expensive ALUs sit idle.
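A back-of-the-envelope version of this test, in the style of the roofline model. The GPU figures (300 TFLOP/s peak, 2 TB/s of HBM bandwidth) are hypothetical round numbers, the 5 FLOPs per element charged to the elementwise pass is an assumption, and the helper names are ours.

```python
def ridge_point(peak_flops, bandwidth_bytes):
    """FLOPs per byte a kernel must perform before arithmetic, not HBM, is the limit."""
    return peak_flops / bandwidth_bytes

def classify(flops, bytes_moved, peak_flops, bandwidth_bytes):
    """Return (arithmetic intensity in FLOPs/byte, 'compute-bound' or 'memory-bound')."""
    intensity = flops / bytes_moved
    if intensity >= ridge_point(peak_flops, bandwidth_bytes):
        return intensity, "compute-bound"
    return intensity, "memory-bound"

PEAK, BW = 300e12, 2e12   # hypothetical GPU: FLOP/s and bytes/s
n = 4096
print(ridge_point(PEAK, BW))                        # 150.0
# Elementwise pass over an n x n fp16 matrix: read 2 bytes, write 2 bytes, ~5 FLOPs each (assumed).
print(classify(5 * n * n, 4 * n * n, PEAK, BW))     # (1.25, 'memory-bound')
# n x n x n fp16 matmul: 2n^3 FLOPs, reading two and writing one n x n matrix.
print(classify(2 * n**3, 6 * n * n, PEAK, BW))      # compute-bound
```

Under these assumptions a kernel needs about 150 FLOPs per byte to keep the arithmetic units busy; the elementwise pass manages 1.25, while a large matrix multiply clears the bar easily.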

Standard attention is exactly this case: reading and writing the n \times n score matrix moves enormous amounts of data for relatively little arithmetic, so it is firmly memory-bound. FlashAttention does essentially the same arithmetic (slightly more, for the rescaling) but keeps the scores in SRAM, slashing HBM traffic. That is why it is faster even though it does not perform fewer floating-point operations. It is an IO-aware algorithm: it optimizes what actually costs, the memory movement, not the FLOPs.

Watch the tiles sweep — and the memory stay flat

Picture the n \times n score matrix as a grid of Q-blocks (rows) by K-blocks (columns). Standard attention stores the whole grid at once. FlashAttention lights up one tile at a time as it sweeps, and only the lit tile and the running statistics live in fast memory. The contrast is between the O(n^2) memory of the full matrix and the flat O(n) footprint FlashAttention actually uses.
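The sweep can also be simulated directly. In this sketch (the name `sweep_footprint` is hypothetical) only score entries and per-row statistics are counted; the Q/K/V tiles and the output accumulator are ignored. The on-chip peak stays at B^2 + 2B however large n grows, while the statistics for all rows total 2n entries, against n^2 for the full matrix.

```python
def sweep_footprint(n, B):
    """Simulate the tile sweep over an n x n score matrix with B x B tiles.

    Counts only score entries and per-row running statistics (m and ell); it ignores
    the Q/K/V tiles and the output accumulator, which a real kernel also keeps on chip.
    """
    peak_on_chip = 0
    tiles_visited = 0
    for qs in range(0, n, B):                 # one query block at a time
        rows = min(B, n - qs)
        stats = 2 * rows                      # m and ell for each row of this query block
        for ks in range(0, n, B):             # sweep its key blocks
            cols = min(B, n - ks)
            live = rows * cols + stats        # the current tile plus the running statistics
            peak_on_chip = max(peak_on_chip, live)
            tiles_visited += 1
    return {
        "full_matrix": n * n,                 # what standard attention stores in HBM
        "peak_on_chip": peak_on_chip,         # bounded by B*B + 2*B, independent of n
        "stats_total": 2 * n,                 # statistics across all rows: O(n)
        "tiles_visited": tiles_visited,
    }

for n in (1024, 4096, 16384):
    print(n, sweep_footprint(n, 128))
```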