We built self-attention one token at a time. Now we write it for the whole sequence at once, as a
handful of matrix multiplications, and add the one small constant that makes it train. The result is
the single most quoted formula in modern deep learning:
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V.
Every transformer you have heard of evaluates this line in every attention layer. Let us earn it,
piece by piece, especially that mysterious \sqrt{d_k}.
The formula, line by line
Stack the per-token vectors into matrices: Q \in \mathbb{R}^{n \times d_k}
(one query per row), K \in \mathbb{R}^{n \times d_k} (one key per row),
and V \in \mathbb{R}^{n \times d_v} (one value per row), for a sequence
of n tokens.
Step 1 — all pairwise scores at once. One matrix product computes every query·key
dot product simultaneously. Entry (i, j) of
Q K^{\top} is exactly q_i \cdot k_j:
S = Q K^{\top} \in \mathbb{R}^{n \times n},\qquad S_{ij} = q_i \cdot k_j.
The transpose K^{\top} is what lines each key up against each query; out
falls the whole n \times n table of relevances in one shot.
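To make Step 1 concrete, here is a minimal NumPy sketch. The sizes n = 4 and d_k = 8 and the random Q and K are illustrative assumptions, not values from the text. It builds the score table once with a single matrix product and once with an explicit double loop, then checks that the two agree.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_k = 4, 8                      # illustrative sizes, not from the text
Q = rng.standard_normal((n, d_k))  # one query per row
K = rng.standard_normal((n, d_k))  # one key per row

# All n*n scores in a single matrix product.
S = Q @ K.T                        # shape (n, n)

# The same table built one dot product at a time, for comparison.
S_loop = np.array([[np.dot(Q[i], K[j]) for j in range(n)] for i in range(n)])

assert S.shape == (n, n)
assert np.allclose(S, S_loop)      # entry (i, j) is q_i . k_j
```

The loop version is only there as a cross-check; in practice the single product is what lets the hardware compute all scores in parallel.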
Step 2 — why divide by \sqrt{d_k}? This is the
"scaled" in the name, and it is not cosmetic. Suppose the components of
q_i and k_j are independent, with mean 0 and variance 1. Then each score is a sum of d_k
independent product terms,
q_i \cdot k_j = \sum_{m=1}^{d_k} q_{im}\, k_{jm}.
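Each term q_{im} k_{jm} has mean \mathbb{E}[q_{im}]\,\mathbb{E}[k_{jm}] = 0 and, because its mean
is zero, variance \mathbb{E}[q_{im}^2]\,\mathbb{E}[k_{jm}^2] = 1 \cdot 1 = 1 (the variance of a
product of independent unit-variance, mean-zero variables). Variances of independent terms add, so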
\operatorname{Var}(q_i \cdot k_j) = \sum_{m=1}^{d_k} \operatorname{Var}(q_{im} k_{jm}) = d_k,\qquad \text{so the standard deviation is } \sqrt{d_k}.
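The variance claim is easy to check empirically. The sketch below assumes standard-normal components (one distribution that satisfies the mean-0, variance-1, independence assumptions) and an arbitrary trial count; it estimates the variance of q · k for a few values of d_k.

```python
import numpy as np

rng = np.random.default_rng(0)
trials = 20_000  # random (query, key) pairs per setting; an arbitrary choice

def score_variance(d_k):
    """Empirical variance of q . k for standard-normal q and k of length d_k."""
    q = rng.standard_normal((trials, d_k))
    k = rng.standard_normal((trials, d_k))
    scores = np.einsum("td,td->t", q, k)  # one dot product per row
    return scores.var()

variances = {d_k: score_variance(d_k) for d_k in (16, 64, 256)}
for d_k, var in variances.items():
    print(f"d_k={d_k:4d}  variance ~ {var:7.2f}  std ~ {np.sqrt(var):5.2f}")
```

Up to sampling noise, the estimated variance should track d_k and the standard deviation should track \sqrt{d_k}.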
With d_k = 64 the standard deviation is already \sqrt{64} = 8, so scores routinely land at \pm 8
or beyond. Feed swings that large into a softmax and it saturates: the biggest score
runs away, the weights collapse onto a near one-hot spike, and the gradient of a saturated softmax is
\approx 0 with respect to every score, so learning stalls. Dividing by
\sqrt{d_k} rescales the scores back to standard deviation 1, keeping them order-1 and the
softmax (and its gradient) healthy:
\tilde{S} = \frac{S}{\sqrt{d_k}},\qquad \operatorname{Var}\!\big(\tilde{S}_{ij}\big) = \frac{d_k}{d_k} = 1.
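The effect of the scaling on the softmax can be seen on a single query. This is a small sketch under assumed conditions (standard-normal query and keys, n = 16, d_k = 64, and a hypothetical helper named softmax); it compares the largest attention weight with and without the division by \sqrt{d_k}.

```python
import numpy as np

def softmax(x):
    """Softmax of a 1-D array, shifted by the max for numerical stability."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

rng = np.random.default_rng(0)
n, d_k = 16, 64                    # illustrative sizes
q = rng.standard_normal(d_k)       # a single query
K = rng.standard_normal((n, d_k))  # n keys, one per row

raw = K @ q                        # unscaled scores, std around sqrt(d_k) = 8
scaled = raw / np.sqrt(d_k)        # scaled scores, std around 1

w_raw, w_scaled = softmax(raw), softmax(scaled)
print("largest weight, unscaled:", w_raw.max())
print("largest weight, scaled:  ", w_scaled.max())
```

Dividing the scores by a constant greater than 1 can only flatten the softmax, so the unscaled weights are always at least as peaked as the scaled ones; with scores of this size the unscaled distribution is typically close to one-hot.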
Step 3 — softmax each row. Apply softmax along each row of the scaled scores, so
every row becomes a probability distribution over the n keys:
A = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right),\qquad A_{ij} \ge 0,\quad \sum_{j=1}^{n} A_{ij} = 1.
A is the n \times n attention matrix: row
i says how token i divides its attention
across the sequence.
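A row-wise softmax can be sketched in a few lines of NumPy. The function name row_softmax and the toy score matrix are assumptions for illustration. Subtracting each row's maximum before exponentiating is a common stability trick that leaves the result unchanged, since softmax is invariant to adding a constant to a row.

```python
import numpy as np

def row_softmax(scores):
    """Softmax along the last axis; subtracting each row's max avoids overflow."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.0],
                   [0.0, 0.0, 0.0],
                   [1000.0, 0.0, -1000.0]])  # a naive exp would overflow on this row
A = row_softmax(scores)

assert np.all(A >= 0)
assert np.allclose(A.sum(axis=-1), 1.0)      # every row is a probability distribution
```

The middle row, with all scores equal, comes out uniform, and the last row shows the saturated case: essentially all of the weight lands on the largest score.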
Step 4 — multiply by the values. One more matrix product mixes the values under
those weights. Row i of AV is exactly token
i's output z_i = \sum_j A_{ij}\, v_j:
\operatorname{Attention}(Q, K, V) = A\,V = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}.
Two matrix products and a row-wise softmax, and every token in the sequence has been rewritten as a
weighted average of every value — all in parallel, no loop in sight.
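Putting the four steps together, one possible NumPy sketch of the whole formula looks like this. The function name, the random inputs, and the sizes n = 5, d_k = 64, d_v = 32 are illustrative assumptions; masking, batching, and multiple heads are deliberately left out.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Return (output, A) for Q, K of shape (n, d_k) and V of shape (n, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # steps 1 and 2: (n, n)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stability shift; softmax is unchanged
    weights = np.exp(scores)
    A = weights / weights.sum(axis=-1, keepdims=True)     # step 3: rows sum to 1
    return A @ V, A                                       # step 4: (n, d_v)

rng = np.random.default_rng(0)
n, d_k, d_v = 5, 64, 32
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))
Z, A = scaled_dot_product_attention(Q, K, V)

# Row 0 of Z is the A-weighted average of the value vectors.
z0 = sum(A[0, j] * V[j] for j in range(n))
assert Z.shape == (n, d_v)
assert np.allclose(Z[0], z0)
```

Returning A alongside the output is a choice made here for inspection only; the formula itself defines just the output.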
For queries, keys, and values stacked as
Q, K \in \mathbb{R}^{n \times d_k},
V \in \mathbb{R}^{n \times d_v}:
- The formula. \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\big(Q K^{\top}/\sqrt{d_k}\big) V.
- All pairwise scores. Q K^{\top} is the n \times n matrix of every dot product q_i \cdot k_j, computed in one product.
- The \sqrt{d_k} scaling. With unit-variance, mean-zero components, \operatorname{Var}(q_i \cdot k_j) = d_k, so dividing by \sqrt{d_k} keeps scores order-1 and stops the softmax saturating into a near one-hot, zero-gradient regime.
- Row-softmax × V. Softmax each row (so each row sums to 1), then right-multiply by V to get each output as a weighted average of value vectors.
Look at the shapes. Q K^{\top} is n \times n, and forming it costs about n^2 d_k operations;
storing it costs n^2 memory. So scaled dot-product attention is O(n^2 \cdot d) in time, where d
stands for the per-token dimension (d_k or d_v), and O(n^2) in memory: quadratic in the sequence
length n. Double the context and you quadruple the work and memory of the attention matrix. This
single fact is the gravity well the whole field fights: it is why early transformers capped context
at a few hundred tokens, and why a parade of tricks (FlashAttention, sparse and linear attention,
and long-context schemes) exists to tame that n^2 without giving up the formula's power.
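The quadratic growth is easy to tabulate. The helper below is a back-of-the-envelope sketch, not a measurement: its name is hypothetical, it counts one multiply-add per term of each dot product, and it assumes 4-byte (float32) entries for a single n × n score matrix.

```python
def attention_matrix_cost(n, d_k, bytes_per_entry=4):
    """Rough cost of forming and storing one n x n score matrix.

    Counts one multiply-add per (i, j, m) triple and assumes float32 storage;
    both are simplifying assumptions, not measurements.
    """
    multiply_adds = n * n * d_k
    bytes_stored = n * n * bytes_per_entry
    return multiply_adds, bytes_stored

for n in (512, 1024, 2048, 4096):
    ops, mem = attention_matrix_cost(n, d_k=64)
    print(f"n={n:5d}  multiply-adds={ops:.2e}  memory={mem / 2**20:6.1f} MiB")
```

Under these assumptions the matrix takes 1 MiB at n = 512 and 64 MiB at n = 4096: each doubling of n multiplies both columns by four.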