Scaled Dot-Product Attention

We built self-attention one token at a time. Now we write it for the whole sequence at once, as a handful of matrix multiplications, and add the one small constant that makes it train. The result is the single most quoted formula in modern deep learning:

\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V.

Every transformer you have heard of runs this line inside every attention layer. Let us earn it piece by piece, especially that mysterious \sqrt{d_k}.

The formula, line by line

Stack the per-token vectors into matrices: Q \in \mathbb{R}^{n \times d_k} (one query per row), K \in \mathbb{R}^{n \times d_k} (one key per row), and V \in \mathbb{R}^{n \times d_v} (one value per row), for a sequence of n tokens.

Step 1 — all pairwise scores at once. One matrix product computes every query·key dot product simultaneously. Entry (i, j) of Q K^{\top} is exactly q_i \cdot k_j:

S = Q K^{\top} \in \mathbb{R}^{n \times n},\qquad S_{ij} = q_i \cdot k_j.

The transpose K^{\top} is what lines each key up against each query; out falls the whole n \times n table of relevances in one shot.
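Step 1 can be checked in plain Python. This is a minimal sketch that stores matrices as lists of rows; the helper names `dot`, `transpose`, and `matmul` are hypothetical illustrations, not any library's API. It builds S = QK^{\top} for a toy sequence and confirms that entry (i, j) is q_i \cdot k_j.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def transpose(M):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*M)]

def matmul(X, Y):
    """Plain matrix product: (p x q) @ (q x r) -> (p x r)."""
    Yt = transpose(Y)
    return [[dot(row, col) for col in Yt] for row in X]

# Toy sequence: n = 3 tokens, d_k = 2; one query / key per row.
Q = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
K = [[1.0, 2.0],
     [3.0, 0.0],
     [0.0, -1.0]]

S = matmul(Q, transpose(K))   # n x n table of raw scores
print(S)  # [[1.0, 3.0, 0.0], [2.0, 0.0, -1.0], [3.0, 3.0, -1.0]]
```

Multiplying by the transpose means each output entry pairs a row of Q with a row of K, which is exactly the "line each key up against each query" step.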

Step 2 — why divide by \sqrt{d_k}? This is the "scaled" in the name, and it is not cosmetic. Suppose the components of q_i and k_j are independent, with mean 0 and variance 1. Then each score is a sum of d_k independent product terms,

q_i \cdot k_j = \sum_{m=1}^{d_k} q_{im}\, k_{jm}.

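Each term q_{im} k_{jm} has mean 0 and variance 1: for independent mean-zero factors, \operatorname{Var}(q_{im} k_{jm}) = \mathbb{E}[q_{im}^2]\,\mathbb{E}[k_{jm}^2] = 1 \cdot 1 = 1. The d_k terms are also independent of one another, and variances of independent variables add, so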

\operatorname{Var}(q_i \cdot k_j) = \sum_{m=1}^{d_k} \operatorname{Var}(q_{im} k_{jm}) = d_k,\qquad \text{so the standard deviation is } \sqrt{d_k}.

With d_k = 64 the standard deviation is 8, so scores of \pm 8 are typical and \pm 16 turns up routinely. Feed swings that large into a softmax and it saturates: the biggest score runs away with nearly all the weight, and the row collapses onto a near one-hot spike. For a row of scores s with weights a = \operatorname{softmax}(s), the gradient entries \partial a_i / \partial s_j = a_i(\delta_{ij} - a_j) are then all \approx 0, so learning stalls. Dividing by \sqrt{d_k} rescales the scores back to standard deviation 1, keeping them order-1 and the softmax (and its gradient) healthy:

\tilde{S} = \frac{S}{\sqrt{d_k}},\qquad \operatorname{Var}\!\big(\tilde{S}_{ij}\big) = \frac{d_k}{d_k} = 1.
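The variance argument is easy to check empirically. This sketch assumes Gaussian components, which is just one convenient distribution with mean 0 and variance 1; the derivation itself needs only zero mean, unit variance, and independence. The function name `sample_scores` is hypothetical.

```python
import random
import statistics

def sample_scores(d_k, num_samples, seed=0):
    """Draw q, k with i.i.d. N(0, 1) components (an assumption) and return raw q . k."""
    rng = random.Random(seed)
    scores = []
    for _ in range(num_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        scores.append(sum(a * b for a, b in zip(q, k)))
    return scores

d_k = 64
raw = sample_scores(d_k, num_samples=5000)
scaled = [s / d_k ** 0.5 for s in raw]       # divide every score by sqrt(d_k)

raw_var = statistics.pvariance(raw)          # should land near d_k = 64
scaled_var = statistics.pvariance(scaled)    # should land near 1
print(f"Var(raw) ~ {raw_var:.1f}, Var(scaled) ~ {scaled_var:.2f}")
```

The sample variances will wobble a little around 64 and 1 because they are estimates from a finite number of draws.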

Step 3 — softmax each row. Apply softmax along each row of the scaled scores, so every row becomes a probability distribution over the n keys:

A = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right),\qquad A_{ij} \ge 0,\quad \sum_{j=1}^{n} A_{ij} = 1.

A is the n \times n attention matrix: row i says how token i divides its attention across the sequence.
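Step 3 as a minimal sketch: scale a toy 3 \times 3 score table by 1/\sqrt{d_k}, then softmax each row on its own. The softmax subtracts the row maximum before exponentiating. That is a standard numerical-stability step the text does not mention, and it leaves the weights unchanged. The names `softmax` and `attention_weights` are illustrative.

```python
import math

def softmax(row):
    """Numerically stable softmax: subtracting the row max does not change the result."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(S, d_k):
    """Scale each score by 1/sqrt(d_k), then softmax every row independently."""
    scale = 1.0 / math.sqrt(d_k)
    return [softmax([s * scale for s in row]) for row in S]

# A toy 3 x 3 score table (n = 3, d_k = 2).
S = [[1.0, 3.0, 0.0],
     [2.0, 0.0, -1.0],
     [3.0, 3.0, -1.0]]
A = attention_weights(S, d_k=2)
for i, row in enumerate(A):
    print(i, [round(a, 3) for a in row], "sum =", round(sum(row), 6))
```

Each printed row is non-negative and sums to 1, and in the last row the two tied scores receive identical weight.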

Step 4 — multiply by the values. One more matrix product mixes the values under those weights. Row i of AV is exactly token i's output z_i = \sum_j A_{ij}\, v_j:

\operatorname{Attention}(Q, K, V) = A\,V = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}.

Two matrix products and a row-wise softmax, and every token in the sequence has been rewritten as a weighted average of every value — all in parallel, no loop in sight.
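Putting the four steps together gives a small reference version of the whole formula. This is a sketch for clarity, not performance; real code would use a tensor library. The function name `attention` and the toy Q, K, V are hypothetical. Choosing d_v = 4 \ne d_k makes the n \times d_v output shape visible.

```python
import math

def softmax(row):
    """Stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(X, Y):
    """(p x q) @ (q x r) -> (p x r) for list-of-rows matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V; also returns the n x n weights A for inspection."""
    d_k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]                        # d_k x n
    scores = matmul(Q, K_T)                                     # n x n
    scale = 1.0 / math.sqrt(d_k)
    A = [softmax([s * scale for s in row]) for row in scores]   # rows sum to 1
    return matmul(A, V), A                                      # n x d_v output

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]    # n = 3 queries, d_k = 2
K = [[1.0, 2.0], [3.0, 0.0], [0.0, -1.0]]   # n = 3 keys,    d_k = 2
V = [[1.0, 0.0, 2.0, -1.0],                 # n = 3 values,  d_v = 4
     [0.0, 1.0, 0.0, 3.0],
     [5.0, 5.0, 5.0, 5.0]]

Z, A = attention(Q, K, V)
print(len(Z), "x", len(Z[0]))  # 3 x 4, i.e. n x d_v
```

Row i of `Z` equals \sum_j A_{ij} v_j, the same weighted average described in Step 4, computed for all tokens by one matrix product.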

For queries, keys, and values stacked as Q, K \in \mathbb{R}^{n \times d_k} and V \in \mathbb{R}^{n \times d_v}, look at the shapes. Q K^{\top} is n \times n: forming it costs about n^2 d_k multiply-adds, the product A V costs another n^2 d_v, and storing A takes n^2 entries. So scaled dot-product attention is O(n^2 \cdot d) in time (with d the larger of d_k and d_v) and O(n^2) in memory for the attention matrix, quadratic in the sequence length n. Double the context and you quadruple the work and memory of the attention matrix. This single fact is the gravity well the whole field fights: it is why early transformers capped context at a few hundred to about a thousand tokens, and why a parade of tricks exists to tame that n^2 without giving up the formula's power. FlashAttention computes the exact result without ever storing the full n \times n matrix, while sparse attention, linear attention, and other long-context schemes cut the quadratic work itself.
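The quadratic growth shows up with simple arithmetic. This sketch counts only the multiply-adds of the two matrix products (n^2 d_k for QK^{\top}, n^2 d_v for AV) and the n^2 entries of the attention matrix. It ignores the scaling and softmax, which are also O(n^2). The function name `attention_cost` is hypothetical.

```python
def attention_cost(n, d_k, d_v):
    """Rough cost of one attention call: (multiply-adds, entries in the n x n matrix)."""
    score_macs = n * n * d_k   # S = Q K^T: n*n dot products of length d_k
    mix_macs = n * n * d_v     # A V: n*d_v dot products of length n
    return score_macs + mix_macs, n * n

for n in (512, 1024, 2048):
    macs, entries = attention_cost(n, d_k=64, d_v=64)
    print(f"n={n:5d}  multiply-adds={macs:,}  attention-matrix entries={entries:,}")
```

Each doubling of n multiplies both columns by exactly 4, which is the "double the context, quadruple the work" rule in numbers.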

Watch the scaling un-saturate the softmax

Picture one row of scores as a strip of cells whose darkness is the softmax weight. At full raw scale the scores are large, so the softmax is spiky: almost all the mass sits on one cell, a near one-hot pick with a near-zero gradient. Divide the same row by \sqrt{d_k} and it smooths out: the weights spread, the distribution breathes, and gradient flows back to every token. That is the whole job of the \sqrt{d_k}.
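The same contrast can be measured numerically. This sketch uses a hand-picked, illustrative row of raw scores with a spread comparable to what d_k = 64 produces. It softmaxes the row at raw scale and after dividing by \sqrt{d_k}, then reports the top weight and the largest entry of the softmax Jacobian a_i(\delta_{ij} - a_j). The helper `max_softmax_grad` is a hypothetical name.

```python
import math

def softmax(row):
    """Stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def max_softmax_grad(weights):
    """Largest |d a_i / d s_j| = |a_i (delta_ij - a_j)| over the softmax Jacobian."""
    n = len(weights)
    return max(abs(weights[i] * ((1.0 if i == j else 0.0) - weights[j]))
               for i in range(n) for j in range(n))

d_k = 64
raw = [12.0, -3.0, 5.0, -10.0, 1.0, 7.0, -6.0, 0.0]   # illustrative raw scores
scaled = [s / math.sqrt(d_k) for s in raw]             # divided by sqrt(d_k) = 8

for label, row in (("raw", raw), ("scaled", scaled)):
    a = softmax(row)
    print(f"{label:>6}: max weight = {max(a):.3f}, "
          f"largest gradient entry = {max_softmax_grad(a):.4f}")
```

With these numbers the raw row puts about 99% of its weight on one key. After scaling, the top weight drops to roughly 0.36 and the largest gradient entry grows by a factor of about 30.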