FlashAttention

Scaled dot-product attention has a hidden cost that has nothing to do with arithmetic. For a sequence of N tokens it forms an N \times N matrix of attention scores — and for a long context, writing that matrix to memory and reading it back is far slower than the maths itself. FlashAttention is an IO-aware algorithm that computes exactly the same attention while never storing the big matrix at all, making attention dramatically faster and freeing it to handle much longer sequences.

The bottleneck is memory, not multiplication

A GPU has a small pool of very fast on-chip memory (SRAM) and a large pool of much slower memory (HBM). Ordinary attention computes the full N \times N score matrix, writes it to slow HBM, reads it back to apply \text{softmax}, writes that out, reads it again to multiply by the values… The chip spends most of its time waiting on memory traffic rather than computing. And because the score matrix has N^2 entries, its memory footprint grows quadratically with the context: doubling the sequence length quadruples it.
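To put rough numbers on that quadratic growth, here is a small back-of-the-envelope sketch. It assumes 2 bytes per score (as with fp16 or bf16) and counts a single attention head for a single sequence; the function name `score_matrix_bytes` is made up for this illustration.

```python
def score_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to hold one seq_len x seq_len score matrix.

    Assumption: one head, one sequence, 2-byte elements by default.
    """
    return seq_len * seq_len * bytes_per_element


for n in (2_048, 32_768, 131_072):
    gib = score_matrix_bytes(n) / 2**30
    print(f"N = {n:>7}: {gib:.4f} GiB per head")
```

Under these assumptions the matrix takes 8 MiB at 2,048 tokens, 2 GiB at 32,768 tokens and 32 GiB at 131,072 tokens, and that is before multiplying by the number of heads and the batch size.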

FlashAttention's insight: fuse all those steps into a single kernel and work in tiles. Load a block of queries and a block of keys/values into fast SRAM, compute their part of the attention, accumulate the result, and move to the next block, never materialising the whole matrix in slow memory. The trick that makes this possible is an online softmax: for each query, keep a running maximum of the scores seen so far and a running sum of their exponentials, and whenever a new block raises the maximum, rescale what you've accumulated so far. The final answer is then exactly the same as ordinary attention (up to floating-point rounding), even though you never saw all the scores at once.
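The online-softmax idea can be sketched in plain Python for a single query. This is an illustration of the bookkeeping only, not the actual GPU kernel: the function names, the toy data and the `block_size` parameter are all assumptions made for the example, and a real implementation would also tile over queries and run on the GPU.

```python
import math


def reference_attention(q, keys, values):
    """Ordinary attention for one query: compute every score, then softmax."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]


def tiled_attention(q, keys, values, block_size=2):
    """Same result, but keys/values are visited one block at a time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf           # largest score seen so far
    running_sum = 0.0                 # softmax denominator relative to running_max
    acc = [0.0] * len(values[0])      # unnormalised weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the old maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        # Fold the new block into the running totals.
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


# Toy data: one 3-dimensional query, five keys, five 2-dimensional values.
q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 0.5], [0.2, -0.3, 1.5], [-1.0, 2.0, 0.0],
        [0.7, 0.7, 0.7], [3.0, -1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, -0.3]]

full = reference_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, block_size=2)
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(full, tiled)))
```

The tiled version only ever holds `block_size` scores at a time, yet it agrees with the reference to within rounding error. The `correction` factor is what keeps earlier blocks consistent once a later block raises the running maximum.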

The jump in context lengths, from a couple of thousand tokens to hundreds of thousands, is in large part a FlashAttention story. The amount of arithmetic is still quadratic in N, but by making attention's extra memory grow linearly with N rather than as N^2, and by cutting the memory traffic that dominated the runtime, it made long contexts both affordable and fast. Nearly every modern large-model training and inference stack now uses a FlashAttention-style kernel under the hood; it is a rare example of a pure systems optimisation reshaping what models can do.