From MHA to MQA to GQA, line by line
Fix a model with h attention heads, per-head dimension d_k, sequence length n, and L layers.
Track one number throughout: the size of the KV cache.
Step 1 — the KV cache in standard MHA. Each head stores, for every token, one key
and one value vector of width d_k. So per layer the cache holds h keys and h values per token;
across the whole model
\text{cache} \;\propto\; 2 \cdot h \cdot n \cdot d_k \cdot L.
The damning factor is h: every head multiplies the cache, and at long
n this is what fills the GPU's memory.
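To make the proportionality concrete, here is a minimal sketch that turns it into bytes for a single sequence. The function name `kv_cache_bytes`, the 2-byte (fp16) element size, and the example configuration (h = 32, d_k = 128, L = 32, n = 4096) are illustrative assumptions, not values taken from any particular model.

```python
def kv_cache_bytes(kv_heads, n, d_k, layers, bytes_per_elem=2):
    """KV cache size in bytes for one sequence.

    The leading 2 counts keys and values; bytes_per_elem=2 assumes fp16 storage.
    """
    return 2 * kv_heads * n * d_k * layers * bytes_per_elem


# Hypothetical MHA configuration: every one of the 32 heads stores its own k and v.
mha_bytes = kv_cache_bytes(kv_heads=32, n=4096, d_k=128, layers=32)
print(mha_bytes / 2**30, "GiB per sequence")
```

Under these assumptions one 4096-token sequence already needs 2 GiB of cache, and the figure grows linearly in both n and h.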
Step 2 — multi-query attention (MQA): share one k,v. Keep all
h query heads, but give them a single shared key
and value. Now only one k and one v are stored per token per layer:
\text{cache}_{\text{MQA}} \;\propto\; 2 \cdot 1 \cdot n \cdot d_k \cdot L,
a factor of h smaller. Tiny cache, much faster decoding — but funnelling
every query through one k,v can cost a little quality, and can make training less stable.
Step 3 — grouped-query attention (GQA): the middle ground. Split the
h query heads into g groups;
each group shares one key–value pair. There are now g distinct k,v sets
rather than h (MHA) or 1 (MQA):
\text{cache}_{\text{GQA}} \;\propto\; 2 \cdot g \cdot n \cdot d_k \cdot L.
For example h = 32 query heads with g = 8 KV
heads: four query heads per group.
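The grouping can be sketched as a simple index mapping. This assumes contiguous, equal-sized groups (so h must be divisible by g), which is one common convention; the helper name `kv_head_for_query` is hypothetical.

```python
def kv_head_for_query(query_head, h, g):
    """Return the index of the KV head that a given query head reads from.

    Assumes contiguous groups of equal size h // g.
    """
    if h % g != 0:
        raise ValueError("h must be divisible by g")
    return query_head // (h // g)


h, g = 32, 8
groups = {}
for query_head in range(h):
    groups.setdefault(kv_head_for_query(query_head, h, g), []).append(query_head)

print(groups[0])  # query heads that share KV head 0
```

With h = 32 and g = 8, query heads 0 to 3 share KV head 0, heads 4 to 7 share KV head 1, and so on.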
Step 4 — the reduction factor. Compare GQA against full MHA:
\frac{\text{cache}_{\text{GQA}}}{\text{cache}_{\text{MHA}}} = \frac{g}{h} \quad\Longrightarrow\quad \text{the cache shrinks by a factor } \frac{h}{g}.
With h = 32, g = 8 the KV cache is 4\times
smaller. The two endpoints fall straight out: g = h recovers MHA
(reduction 1\times), and g = 1 is MQA (reduction
h\times). GQA dials smoothly between them.
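As a quick check of the dial, the sketch below lists every group count g that splits h into equal groups, together with the resulting reduction h/g. Restricting g to divisors of h is an assumption made here so that all groups have the same size; `reduction_factors` is a hypothetical helper name.

```python
def reduction_factors(h):
    """Map each valid group count g (a divisor of h) to the cache reduction h / g."""
    return {g: h // g for g in range(1, h + 1) if h % g == 0}


factors = reduction_factors(32)
for g, reduction in factors.items():
    print(f"g = {g:2d}: KV cache is {reduction}x smaller than MHA")
```

The first entry (g = 1) is MQA and the last (g = h) is MHA; everything in between is a GQA configuration.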
Step 5 — why the quality barely drops. The expressive part of attention is
the query projections — the many points of view — and GQA keeps all
h of them. Only the keys and values are shared, and a handful of distinct
k,v sets turns out to be plenty to serve many query heads. So GQA captures most of MQA's memory win
while staying within a whisker of MHA's quality — the empirical sweet spot, which is why Llama-2 70B,
Llama-3, and most modern large models ship with it.
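To show what "all h query heads kept, only k,v shared" looks like in practice, here is a minimal pure-Python sketch of one decoding step. It assumes contiguous equal-sized groups and plain scaled dot-product attention over the cached tokens; the function name `gqa_decode_step` and the nested-list data layout are illustrative choices, not a production implementation.

```python
import math


def gqa_decode_step(queries, k_cache, v_cache):
    """One decoding step of grouped-query attention for the newest token.

    queries: h vectors, one per query head.
    k_cache, v_cache: g lists, each holding one d_k-vector per cached token.
    Returns h output vectors, one per query head.
    """
    h, g = len(queries), len(k_cache)
    if h % g != 0:
        raise ValueError("h must be divisible by g")
    group_size = h // g
    outputs = []
    for head, q in enumerate(queries):
        kv = head // group_size  # every head in a group reads the same cached k, v
        keys, values = k_cache[kv], v_cache[kv]
        scale = 1.0 / math.sqrt(len(q))
        scores = [scale * sum(a * b for a, b in zip(q, key)) for key in keys]
        top = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        width = len(values[0])
        outputs.append(
            [sum(w * val[j] for w, val in zip(weights, values)) for j in range(width)]
        )
    return outputs


# h = 4 query heads, g = 2 KV heads, one cached token per KV head.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
k_cache = [[[1.0, 2.0]], [[3.0, 4.0]]]
v_cache = [[[10.0, 20.0]], [[30.0, 40.0]]]
outputs = gqa_decode_step(queries, k_cache, v_cache)
```

The cache holds only g sets of keys and values, yet each of the h queries still produces its own attention weights over them, which is where the per-head diversity is preserved.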
GQA interpolates between full multi-head and single-KV attention to shrink the KV cache:
- MHA. Each of the h heads has its own keys and values, so the KV cache scales with h.
- MQA. All h query heads share a single key–value pair, cutting the cache by h\times at some cost to quality/stability.
- GQA. The query heads form g groups, each sharing one k,v, so there are g KV heads and the cache shrinks by h/g; g = h is MHA while g = 1 is MQA. Quality stays near MHA because all h query heads are kept.
The KV cache, not the model weights, is what caps how much context a server can hold. Weights are
loaded once and shared by every request, but the cache grows with every token of every active
sequence. At long context and large batch sizes it can exceed the weights in size, and it is what
forces a request to be evicted or refused. Cutting it by 4\times or
8\times is therefore not a minor optimisation: it directly multiplies how
many long conversations a GPU can keep alive at once, and how cheaply.
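A rough capacity estimate illustrates the point. The sketch below assumes a fixed memory budget left over for the cache after weights are loaded (40 GiB here, a made-up figure), fp16 storage, and the same hypothetical configuration of h = 32, d_k = 128, L = 32, n = 4096; it ignores fragmentation and activation memory.

```python
def max_sequences(cache_budget_bytes, kv_heads, n, d_k, layers, bytes_per_elem=2):
    """How many full-length sequences fit in the cache budget (integer division)."""
    per_sequence = 2 * kv_heads * n * d_k * layers * bytes_per_elem
    return cache_budget_bytes // per_sequence


budget = 40 * 2**30  # hypothetical: 40 GiB available for the KV cache

mha_slots = max_sequences(budget, kv_heads=32, n=4096, d_k=128, layers=32)
gqa_slots = max_sequences(budget, kv_heads=8, n=4096, d_k=128, layers=32)
print(mha_slots, gqa_slots)
```

Under these assumptions the same GPU holds four times as many concurrent 4096-token sequences with g = 8 as with full MHA, matching the h/g reduction.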
That is why GQA is nearly universal in models built to serve long contexts. It pairs naturally with
the rest of this chapter: the smaller cache is exactly what lets a model with RoPE's long-range
positions actually be served at length. The throughput and latency consequences feed directly into
the serving trade-offs that close the chapter.