One run of scaled dot-product attention gives the model a single point of view: every token
looks at every other token through one set of query, key and value projections, and blends
accordingly. That is powerful, but it is one relationship. A sentence has many at
once — a verb agreeing with its subject, a pronoun bound to its antecedent, a word leaning on the
one beside it. Forcing all of that through a single attention pattern is like reading with one eye.
Multi-head attention runs several attention operations in
parallel, each with its own small projections, so the model can attend to different kinds
of relationship simultaneously — and then it stitches the views back together. The surprise is that
all of this costs about the same as a single full-width attention.
From one head to h heads, line by line
Start with a sequence of n token vectors stacked into a matrix
X \in \mathbb{R}^{n \times d}, where d is the
model width. A single attention head would project X to queries, keys and
values and run one attention. We are going to do h of them at once.
Step 1 — pick the per-head width. Split the model dimension evenly across the
h heads (which requires h to divide d), so each head works in a smaller space of dimension
d_k = \frac{d}{h}.
With d = 512 and h = 8, each head lives in
d_k = 64 dimensions. This single choice is what keeps the cost flat — hold
onto it.
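The arithmetic in Step 1 can be sketched in a few lines of Python. The function name head_width is a hypothetical helper introduced here for illustration; the only thing it adds to the text is an explicit check that h divides d.

```python
def head_width(d: int, h: int) -> int:
    """Return the per-head width d_k = d / h, insisting that the split is even."""
    if d % h != 0:
        raise ValueError(f"model width {d} is not divisible by {h} heads")
    return d // h


d_k = head_width(512, 8)
print(d_k)  # 64
```

An uneven split such as head_width(512, 7) raises an error in this sketch instead of silently rounding.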
Step 2 — give each head its own small projections. Head
i gets three learned matrices that map the full
d-wide tokens down into its own
d_k-dimensional view:
Q_i = X\,W_Q^{(i)}, \quad K_i = X\,W_K^{(i)}, \quad V_i = X\,W_V^{(i)}, \qquad W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d \times d_k}.
Because each projection lands in d_k = d/h dimensions rather than the full
d, each head is a narrow attention, not a full one. Because the
W^{(i)} differ from head to head, each head can specialise in a different relationship.
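As a rough illustration of Step 2, the NumPy sketch below builds one projection triple per head and checks the resulting shapes. The sizes (n = 5 tokens) and the random weights are assumptions standing in for real inputs and learned parameters.

```python
import numpy as np

n, d, h = 5, 512, 8        # n = 5 is an arbitrary illustrative sequence length
d_k = d // h               # 64, the per-head width

rng = np.random.default_rng(0)
X = rng.standard_normal((n, d))   # stand-in for real token vectors

# One (d, d_k) matrix per head for each of Q, K, V, stored as (h, d, d_k).
# Random values stand in for learned weights.
W_Q, W_K, W_V = (rng.standard_normal((h, d, d_k)) / np.sqrt(d) for _ in range(3))

i = 0                      # look at the first head
Q_i = X @ W_Q[i]
K_i = X @ W_K[i]
V_i = X @ W_V[i]
print(Q_i.shape, K_i.shape, V_i.shape)  # (5, 64) (5, 64) (5, 64)
```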
Step 3 — each head attends independently. Run ordinary scaled dot-product
attention inside head i, using its own
Q_i, K_i, V_i:
\text{head}_i = \operatorname{Attention}(Q_i, K_i, V_i) = \operatorname{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i \;\in\; \mathbb{R}^{n \times d_k}.
The h heads share no weights and do not talk to each other here — they
run side by side, each producing its own n \times d_k result.
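The formula for a single head can be written out directly. The sketch below is a minimal NumPy version, assuming random Q_i, K_i, V_i in place of real projections; the function names softmax and attention are chosen here for illustration.

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """Scaled dot-product attention for one head; returns (output, weights)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # (n, n): one score per token pair
    weights = softmax(scores)         # each row is a distribution over tokens
    return weights @ V, weights       # output is (n, d_k)


rng = np.random.default_rng(0)
n, d_k = 5, 64
Q_i, K_i, V_i = (rng.standard_normal((n, d_k)) for _ in range(3))

head_i, weights = attention(Q_i, K_i, V_i)
print(head_i.shape)           # (5, 64)
print(weights.sum(axis=-1))   # every entry is 1 up to rounding
```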
Step 4 — concatenate the heads back to full width. Lay the
h head outputs side by side. Each contributes
d_k columns, and h \cdot d_k = h \cdot (d/h) = d,
so the concatenation is exactly d wide again:
\text{Concat}(\text{head}_1, \dots, \text{head}_h) = \big[\,\text{head}_1 \;\oplus\; \text{head}_2 \;\oplus\; \cdots \;\oplus\; \text{head}_h\,\big] \;\in\; \mathbb{R}^{n \times d}.
Here \oplus is concatenation along the feature axis — the views are
stacked, not yet combined.
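To make the concatenation concrete, here is a small NumPy sketch. The head outputs are random stand-ins; the point is only that the widths add up to d and that each head keeps its own block of columns.

```python
import numpy as np

n, d, h = 5, 512, 8
d_k = d // h
rng = np.random.default_rng(0)

# Stand-ins for the h head outputs, each of shape (n, d_k).
heads = [rng.standard_normal((n, d_k)) for _ in range(h)]

# Concatenate along the feature axis: h blocks of d_k columns give d columns.
concat = np.concatenate(heads, axis=-1)
print(concat.shape)  # (5, 512)

# Head i sits unchanged in columns i*d_k up to (i+1)*d_k.
i = 3
block = concat[:, i * d_k:(i + 1) * d_k]
print(np.array_equal(block, heads[i]))  # True
```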
Step 5 — mix the heads with an output projection. The concatenation is just the
heads parked next to each other; a final learned matrix
W_O \in \mathbb{R}^{d \times d} lets the model blend what the
heads found into a single representation:
\operatorname{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\,W_O \;\in\; \mathbb{R}^{n \times d}.
The output is n \times d — same shape as the input — so multi-head
attention is a drop-in block you can stack.
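Steps 2 to 5 can be put together in one short function. This is a minimal NumPy sketch, not a production implementation: it loops over heads for readability, uses random matrices in place of learned weights, and omits biases, masking and batching, none of which the text above covers.

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax with the usual max-subtraction for stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


def multi_head(X, W_Q, W_K, W_V, W_O):
    """X: (n, d); W_Q, W_K, W_V: (h, d, d_k); W_O: (d, d). Returns (n, d)."""
    h, _, d_k = W_Q.shape
    heads = []
    for i in range(h):
        Q_i, K_i, V_i = X @ W_Q[i], X @ W_K[i], X @ W_V[i]   # Step 2
        weights = softmax(Q_i @ K_i.T / np.sqrt(d_k))        # Step 3
        heads.append(weights @ V_i)
    concat = np.concatenate(heads, axis=-1)                  # Step 4
    return concat @ W_O                                      # Step 5


n, d, h = 5, 512, 8
d_k = d // h
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d))
W_Q, W_K, W_V = (rng.standard_normal((h, d, d_k)) / np.sqrt(d) for _ in range(3))
W_O = rng.standard_normal((d, d)) / np.sqrt(d)

Y = multi_head(X, W_Q, W_K, W_V, W_O)
print(Y.shape)  # (5, 512)

# Same shape in and out, so the block can be applied again to its own output.
Z = multi_head(Y, W_Q, W_K, W_V, W_O)
print(Z.shape)  # (5, 512)
```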
Step 6 — count the cost. All h heads together project
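into h \cdot d_k = d dimensions for each of Q, K, V,
the very same total width a single full attention would use. The three projections therefore hold
3d^2 weights in total however d is split, and forming the scores costs about n^2 d_k multiply-adds
per head, or n^2 d across all h heads. So the number of parameters and the
amount of compute are approximately the same as for one full-width head; splitting into heads buys
several points of view essentially for free.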
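The cost claim can be checked by counting. The sketch below counts only the Q, K, V projection weights and the multiply-adds needed to form the score matrices; biases, W_O and the softmax itself are left out, and n = 128 is an arbitrary sequence length chosen for illustration.

```python
def qkv_params(d: int, h: int) -> int:
    """Weights in the Q, K, V projections: h heads, 3 matrices each, d x d_k."""
    d_k = d // h
    return h * 3 * d * d_k


def score_multiply_adds(n: int, d: int, h: int) -> int:
    """Multiply-adds to form Q_i K_i^T for every head: h * n * n * d_k."""
    d_k = d // h
    return h * n * n * d_k


d, n = 512, 128
for h in (1, 8):
    print(h, qkv_params(d, h), score_multiply_adds(n, d, h))
# 1 786432 8388608
# 8 786432 8388608
```

Both counts are identical for h = 1 and h = 8, because the factor h cancels against d_k = d/h.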
Multi-head attention runs several narrow attentions in parallel and mixes their outputs:
- Parallel heads at width d/h. With model width d and h heads, each head uses its own projections W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d \times d_k} with d_k = d/h, and computes \text{head}_i = \operatorname{Attention}(Q_i, K_i, V_i) independently.
- Concatenate. The h outputs (each n \times d_k) are concatenated to n \times d, since h \cdot d_k = d.
- Mix with W_O. A final W_O \in \mathbb{R}^{d \times d} blends the heads: \operatorname{MultiHead}(X) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)\,W_O.
- Cost ≈ a single head. Because d_k = d/h, the heads together use the same total projection width as one full attention, so the parameter and compute cost is essentially unchanged.
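One way to see that the heads use the same total projection width as one full attention is to lay the h per-head matrices side by side into a single d \times d matrix, project once, and then slice the result into heads. The NumPy sketch below compares that fused route with a head-by-head loop. The helper names fuse and split_heads are hypothetical, the weights are random stand-ins, and W_O is left out because it would be applied identically in both cases.

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


n, d, h = 5, 512, 8
d_k = d // h
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d))
W_Q, W_K, W_V = (rng.standard_normal((h, d, d_k)) / np.sqrt(d) for _ in range(3))

# Route 1: one head at a time, then concatenate.
loop_heads = []
for i in range(h):
    w = softmax((X @ W_Q[i]) @ (X @ W_K[i]).T / np.sqrt(d_k))
    loop_heads.append(w @ (X @ W_V[i]))
loop_concat = np.concatenate(loop_heads, axis=-1)            # (n, d)


def fuse(W: np.ndarray) -> np.ndarray:
    """(h, d, d_k) -> (d, d): per-head matrices placed side by side."""
    return np.concatenate(list(W), axis=-1)


def split_heads(M: np.ndarray) -> np.ndarray:
    """(n, d) -> (h, n, d_k): cut the columns into h blocks of width d_k."""
    return M.reshape(n, h, d_k).transpose(1, 0, 2)


# Route 2: one full-width projection per Q, K, V, then slice into heads.
Q = split_heads(X @ fuse(W_Q))
K = split_heads(X @ fuse(W_K))
V = split_heads(X @ fuse(W_V))
weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))   # (h, n, n)
fused_concat = (weights @ V).transpose(1, 0, 2).reshape(n, d)

print(np.allclose(loop_concat, fused_concat))  # True
```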
The hope that different heads specialise is not just wishful thinking: probing trained models shows it
happening. Some heads turn out to be positional, attending almost entirely to the
previous (or next) token, acting like a local smoothing window. Others track
syntax, linking a verb to its subject or a determiner to its noun. Others handle
coreference, pointing a pronoun back at the entity it refers to.
The most celebrated are induction heads: a head (often a pair working together)
that spots a repeated pattern [A][B]\dots[A] and predicts
[B] next — “the last time I saw A, it was
followed by B, so do that again.” This simple copy-the-pattern
circuit is widely believed to underpin a model's ability to learn from examples given in its
prompt. None of this is hand-designed; it emerges because h independent
heads are free to divide the labour.
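The behaviour attributed to induction heads can be caricatured as a plain lookup rule. The Python sketch below is only that: a hand-written rule that copies whatever followed the most recent earlier occurrence of the current token. It is not how a trained head computes anything, and the function name induction_predict is made up for this illustration.

```python
def induction_predict(tokens):
    """Guess the next token by copying what followed the latest earlier
    occurrence of the current (last) token; None if it has not appeared before."""
    if not tokens:
        return None
    current = tokens[-1]
    # Scan earlier positions from right to left, skipping the current one.
    for j in range(len(tokens) - 2, -1, -1):
        if tokens[j] == current:
            return tokens[j + 1]
    return None


print(induction_predict(["A", "B", "C", "D", "A"]))  # B
print(induction_predict(["A", "B", "C"]))            # None
```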