Grouped-Query Attention

In multi-head attention, every one of the h heads carries its own keys and values. During autoregressive decoding, the keys and values of all previous tokens are stored so they need not be recomputed at each step; that store is the KV cache, and it is the memory bottleneck of long-context serving. Grouped-query attention shrinks it with a simple observation: the query heads can stay numerous and diverse, but they don't each need a private set of keys and values. Let several query heads share one set, and the cache collapses, at little cost to quality.

From MHA to MQA to GQA, line by line

Fix a model with h attention heads, per-head dimension d_k, sequence length n, and L layers. Track one number throughout: the size of the KV cache.

Step 1 — the KV cache in standard MHA. Each head stores, for every token, one key and one value vector of width d_k. So per layer the cache holds h keys and h values per token; across the whole model

\text{cache}_{\text{MHA}} \;\propto\; 2 \cdot h \cdot n \cdot d_k \cdot L.

The damning factor is h: every head multiplies the cache, and at long n this is what fills the GPU's memory.
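To make the proportionality concrete, here is a minimal Python sketch that turns it into bytes. The function name `kv_cache_bytes`, the fp16/bf16 storage (2 bytes per element), the batch factor, and the 7B-class example configuration are assumptions for illustration, not figures taken from a specific model.

```python
def kv_cache_bytes(n_kv_heads: int, n: int, d_k: int, n_layers: int,
                   bytes_per_elem: int = 2, batch: int = 1) -> int:
    """Bytes needed to cache keys and values for `batch` sequences of length n.

    The leading 2 counts one key and one value vector per KV head, per token,
    per layer. bytes_per_elem=2 assumes an fp16/bf16 cache (an assumption).
    """
    return 2 * n_kv_heads * n * d_k * n_layers * bytes_per_elem * batch


# Hypothetical 7B-class configuration (assumed for illustration):
# 32 heads of width 128, 32 layers, a 4096-token context, fp16 cache.
h, d_k, L, n = 32, 128, 32, 4096
mha = kv_cache_bytes(n_kv_heads=h, n=n, d_k=d_k, n_layers=L)
print(f"MHA cache per sequence: {mha / 2**30:.1f} GiB")  # 2.0 GiB
```

Every factor of h in the cache comes from `n_kv_heads`, which is exactly the knob MQA and GQA turn down.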

Step 2 — multi-query attention (MQA): share one k,v. Keep all h query heads, but give them a single shared key and value. Now only one k and one v are stored per token per layer:

\text{cache}_{\text{MQA}} \;\propto\; 2 \cdot 1 \cdot n \cdot d_k \cdot L,

a factor of h smaller. Tiny cache, much faster decoding — but funnelling every query through one k,v can cost a little quality, and can make training less stable.

Step 3 — grouped-query attention (GQA): the middle ground. Split the h query heads into g groups (g must divide h); all query heads in a group share one key head and one value head. There are now g distinct k,v sets rather than h (MHA) or 1 (MQA):

\text{cache}_{\text{GQA}} \;\propto\; 2 \cdot g \cdot n \cdot d_k \cdot L.

For example h = 32 query heads with g = 8 KV heads: four query heads per group.
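A minimal, pure-Python sketch of one GQA decoding step for a single layer shows where the sharing happens: query head i reads KV head i // (h / g). It assumes the keys and values are already projected (with value width equal to d_k, as above), that the caches already include the current token, and that groups are contiguous blocks of query heads. `gqa_decode_step` and `softmax` are hypothetical names, and batching, masking, and the output projection are omitted.

```python
import math


def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gqa_decode_step(queries, k_cache, v_cache):
    """One decoding step of grouped-query attention for a single layer.

    queries: h query vectors (one per query head) for the current token.
    k_cache, v_cache: g lists (one per KV head), each holding one
        d_k-dimensional vector per cached token.
    Returns h output vectors, one per query head.
    """
    h, g = len(queries), len(k_cache)
    assert h % g == 0, "g must divide h"
    group_size = h // g
    d_k = len(queries[0])
    outputs = []
    for i, q in enumerate(queries):
        kv = i // group_size  # query head i reads the KV head of its group
        keys, values = k_cache[kv], v_cache[kv]
        scores = [sum(qj * kj for qj, kj in zip(q, k)) / math.sqrt(d_k)
                  for k in keys]
        weights = softmax(scores)
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(d_k)]
        outputs.append(out)
    return outputs


# h = 4 query heads, g = 2 KV heads, d_k = 2, two cached tokens.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
k_cache = [[[1.0, 0.0], [0.0, 1.0]],   # KV head 0, read by query heads 0, 1
           [[1.0, 0.0], [0.0, 1.0]]]   # KV head 1, read by query heads 2, 3
v_cache = [[[1.0, 0.0], [0.0, 1.0]],
           [[2.0, 0.0], [0.0, 2.0]]]
out = gqa_decode_step(queries, k_cache, v_cache)
```

Passing g = h KV heads gives every query head its own cache (MHA), and passing a single KV head (`k_cache[:1]`, `v_cache[:1]`) makes every query head read the same one (MQA); the function itself does not change.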

Step 4 — the reduction factor. Compare GQA against full MHA:

\frac{\text{cache}_{\text{GQA}}}{\text{cache}_{\text{MHA}}} = \frac{g}{h} \quad\Longrightarrow\quad \text{the cache shrinks by a factor } \frac{h}{g}.

With h = 32, g = 8 the KV cache is 4\times smaller. The two endpoints fall straight out: g = h recovers MHA (reduction 1\times), and g = 1 is MQA (reduction h\times). GQA dials smoothly between them.
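The reduction factor is simple enough to tabulate. This small sketch (the helper name `cache_reduction` is hypothetical) walks g through the divisors of h = 32, from the MHA endpoint down to the MQA endpoint.

```python
def cache_reduction(h: int, g: int) -> int:
    """Factor by which GQA with g KV heads shrinks the MHA KV cache (h / g)."""
    if g < 1 or h % g != 0:
        raise ValueError("g must be a positive divisor of h")
    return h // g


h = 32
for g in (32, 16, 8, 4, 2, 1):  # divisors of 32, from MHA down to MQA
    label = "MHA" if g == h else "MQA" if g == 1 else "GQA"
    print(f"g = {g:2d} ({label}): cache {cache_reduction(h, g)}x smaller")
```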

Step 5 — why the quality barely drops. GQA keeps all h query projections, so each head can still attend to its own pattern of positions; only the keys and values are shared. Empirically, a handful of distinct k,v sets is enough to serve many query heads. So GQA captures most of MQA's memory win while staying close to MHA's quality. That empirical sweet spot is why Llama-2 70B, Llama-3, and many other modern large models ship with it.

In short, GQA interpolates between full multi-head attention and single-KV attention to shrink the KV cache.

The KV cache, not the model weights, is what caps how long a context a server can hold. Weights are loaded once and shared by every request, but the cache grows with every token of every active sequence — at long context and high batch it dwarfs the weights and is what forces a request to be evicted or refused. Cutting it by 4\times or 8\times is therefore not a minor optimisation: it directly multiplies how many long conversations a GPU can keep alive at once, and how cheaply.
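A rough capacity estimate shows the multiplication directly. The sketch below assumes a hypothetical 40 GiB of GPU memory left for the cache after the weights are loaded and reuses the illustrative 7B-class shape (128-wide heads, 32 layers, fp16); it ignores fragmentation, activations, and any paging scheme, so treat the counts as upper bounds. `max_concurrent_sequences` is a hypothetical helper.

```python
def max_concurrent_sequences(budget_bytes: int, n_kv_heads: int, n: int,
                             d_k: int, n_layers: int,
                             bytes_per_elem: int = 2) -> int:
    """How many length-n sequences fit in a cache budget (ignores overheads)."""
    per_seq = 2 * n_kv_heads * n * d_k * n_layers * bytes_per_elem
    return budget_bytes // per_seq


GiB = 2**30
budget = 40 * GiB  # hypothetical memory left for the KV cache
capacity = {}
for name, kv_heads in (("MHA", 32), ("GQA g=8", 8), ("MQA", 1)):
    capacity[name] = max_concurrent_sequences(
        budget, kv_heads, n=4096, d_k=128, n_layers=32)
    print(f"{name}: {capacity[name]} sequences of 4096 tokens")
```

Under these assumptions the count scales by exactly h/g, since the weights, and hence the budget, do not depend on how many KV heads there are.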

That is why GQA is so common in models built to serve long contexts. It pairs naturally with the rest of this chapter: the smaller cache is what lets a model with RoPE's long-range positions actually be served at length. The throughput and latency consequences land straight in the serving trade-offs we close the chapter on.

Map the query heads onto shared KV heads

The top row is the h = 8 query heads; the bottom row is the g shared KV heads. Each query head connects to the KV head of its group. Slide g from 8 down to 1 and watch the KV row collapse — at g = 8 it's full multi-head attention, at g = 1 every query funnels into one KV head (MQA). The readout shows the live cache reduction factor h/g.
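The same mapping can be printed as text. This hypothetical `group_map` helper assumes contiguous groups and that g divides h, so for h = 8 it steps through g = 8, 4, 2, 1; whether the figure groups heads in exactly this order is an assumption.

```python
def group_map(h: int, g: int) -> list[int]:
    """KV head index used by each of the h query heads (contiguous groups)."""
    if g < 1 or h % g != 0:
        raise ValueError("g must be a positive divisor of h")
    return [i // (h // g) for i in range(h)]


h = 8
for g in (8, 4, 2, 1):
    mapping = group_map(h, g)
    row = " ".join(f"q{i}->kv{kv}" for i, kv in enumerate(mapping))
    print(f"g={g} (reduction {h // g}x): {row}")
```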