The bottleneck is memory, not multiplication
A GPU has a small pool of very fast on-chip memory (SRAM) and a large pool of
much slower memory (HBM). Ordinary attention computes the full
$N \times N$ score matrix, writes it to slow HBM, reads it back to apply
softmax, writes that out, reads it again to multiply by
the values… The chip spends most of its time waiting on memory traffic, and the
$N^2$ matrix also blows up memory usage as the context grows.
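To put rough numbers on that growth, here is a back-of-envelope sketch. The helper name `score_matrix_bytes` is made up for illustration, and the figures assume 2-byte (half-precision) scores for a single attention head and a single sequence; real workloads multiply this by the number of heads and the batch size.

```python
def score_matrix_bytes(n_tokens: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to hold one full N x N attention score matrix.

    Assumption for illustration: 2 bytes per score (fp16/bf16),
    one head, one sequence.
    """
    return n_tokens * n_tokens * bytes_per_element


for n in (2_048, 32_768, 131_072):
    gib = score_matrix_bytes(n) / 2**30  # convert bytes to GiB
    print(f"N = {n:>7,}: {gib:8.3f} GiB per head")
```

Under these assumptions the matrix grows from 8 MiB at 2,048 tokens to 2 GiB at 32,768 tokens and 32 GiB at 131,072 tokens, per head and per sequence.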
FlashAttention's insight: fuse all those steps and work in tiles. Load a block of
queries and a block of keys/values into fast SRAM, compute their part of the attention, accumulate
the result, and move to the next block, never materialising the whole matrix in slow
memory. The trick that makes this possible is an online softmax: a way to
keep a running maximum and a running softmax denominator as each new block arrives, rescaling what
you've accumulated so far, so the final answer is exact even though you never see all the scores at once.
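The tiling and rescaling described above can be sketched in NumPy. This is a minimal illustration of the idea, not the actual FlashAttention kernel: it runs on the CPU, handles a single head with no masking, and the function names and block sizes (`tiled_attention`, `block_q`, `block_k`) are assumptions chosen for the example. A real kernel would hold each tile in on-chip SRAM; here the point is only that no $N \times N$ array is ever created.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference version: materialises the full N x N score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def tiled_attention(q, k, v, block_q=4, block_k=4):
    """Same result, computed tile by tile with an online softmax."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[-1]))
    for qs in range(0, n, block_q):
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        running_max = np.full((rows, 1), -np.inf)  # max score seen so far, per row
        running_sum = np.zeros((rows, 1))          # softmax denominator so far
        acc = np.zeros((rows, v.shape[-1]))        # unnormalised output so far
        for ks in range(0, k.shape[0], block_k):
            k_blk = k[ks:ks + block_k]
            v_blk = v[ks:ks + block_k]
            s = (q_blk @ k_blk.T) * scale          # one block_q x block_k tile only
            new_max = np.maximum(running_max, s.max(axis=-1, keepdims=True))
            # Rescale earlier partial results to the new maximum.
            # On the first block this is exp(-inf) = 0, which discards nothing real.
            correction = np.exp(running_max - new_max)
            p = np.exp(s - new_max)
            running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v_blk
            running_max = new_max
        out[qs:qs + rows] = acc / running_sum      # normalise once, at the end
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
max_diff = np.abs(naive_attention(q, k, v) - tiled_attention(q, k, v)).max()
print(f"largest difference between the two versions: {max_diff:.2e}")
```

The two functions agree up to floating-point rounding, which is the sense in which the tiled computation is exact. The sequence length of 10 is deliberately not a multiple of the block size, to show that a ragged final block needs no special handling.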
- Computes exact attention — not an approximation.
- Processes attention in tiles in fast on-chip memory, never writing the
$N \times N$ matrix to slow memory.
- Uses an online (streaming) softmax to combine blocks correctly; the win is in
memory IO, giving big speed-ups and linear memory in $N$.
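The block-combining rule in the last bullet can be written as a short recurrence. The notation here is introduced only for this sketch and does not come from the text above: $S_j$ is the score tile for the $j$-th key/value block, $V_j$ the matching block of values, $m_j$ the running row-wise maximum, $\ell_j$ the running softmax denominator, and $O_j$ the unnormalised output accumulator, starting from $m_0 = -\infty$, $\ell_0 = 0$, $O_0 = 0$.

$$
\begin{aligned}
m_j &= \max\bigl(m_{j-1},\ \operatorname{rowmax}(S_j)\bigr) \\
\ell_j &= e^{m_{j-1} - m_j}\,\ell_{j-1} + \operatorname{rowsum}\bigl(e^{S_j - m_j}\bigr) \\
O_j &= e^{m_{j-1} - m_j}\,O_{j-1} + e^{S_j - m_j}\,V_j
\end{aligned}
$$

After the last block $T$, the attention output for those query rows is $O_T / \ell_T$. The state carried between blocks is one maximum, one denominator and one output row per query, which is why the extra memory grows linearly in $N$ rather than quadratically.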
The jump in context lengths — from a couple of thousand tokens to hundreds of thousands — is in
large part a FlashAttention story. By turning attention's memory from
$N^2$ into something linear in $N$, and cutting the memory traffic that
dominated the runtime, it made long contexts both affordable and fast. Nearly every modern
large-model training and inference stack now uses a FlashAttention-style kernel under the hood; it
is a rare example of a pure systems optimisation reshaping what models can do.
- FlashAttention changes how attention is computed, not what is computed:
the output is numerically the same (bar tiny floating-point differences). It is not one
of the "approximate/linear attention" methods that trade accuracy for speed.
- The speed-up comes from the memory hierarchy (SRAM vs HBM), so it's a
hardware-aware kernel, not an algebraic change to the attention formula.