Batch normalization has a hidden dependency that bites in practice: it normalises each
feature across the mini-batch, so the statistics it uses for one example are tangled
up with whatever other examples happen to share the batch. Shrink the batch and the estimates
get noisy; reach a batch of size 1 and they degenerate, because the batch variance is zero
and every normalised feature collapses to 0.
Layer normalization cures this by turning the normalisation
sideways. Instead of averaging one feature down a column of examples, it
averages all the features within a single example. The example becomes
self-sufficient — its normalisation depends on nothing but itself. This is the normalisation
the transformer
is built on.
One example, normalised across its features, line by line
Take a single token's feature vector
x = (x_1, \dots, x_d) \in \mathbb{R}^d. Every statistic below is
computed over the d features of this one vector — no other
example is involved.
Step 1 — the mean over features. Average the
d coordinates of this example:
\mu = \frac{1}{d}\sum_{j=1}^{d} x_j.
Step 2 — the variance over features. Their mean squared deviation from that
per-example mean:
\sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2.
Step 3 — normalise the vector. Subtract the mean and divide by the standard
deviation, coordinate by coordinate (\varepsilon again guarding the
division):
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}.
Now \hat{x} has mean 0 and variance
\sigma^2 / (\sigma^2 + \varepsilon) across its own features, which is 1 up to the
tiny \varepsilon correction, and it got there without
consulting a single other example.
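The claim can be checked directly from the definitions in Steps 1 to 3. The derivation below is a short sketch using only \mu, \sigma^2 and \varepsilon as defined above; the numeric vector at the end is an illustrative choice, not taken from the text.

```latex
% Mean of the normalised vector: the centring removes it exactly.
\frac{1}{d}\sum_{j=1}^{d}\hat{x}_j
  = \frac{1}{\sqrt{\sigma^2+\varepsilon}}
    \left(\frac{1}{d}\sum_{j=1}^{d}x_j - \mu\right) = 0.

% Variance of the normalised vector: 1 up to the epsilon guard.
\frac{1}{d}\sum_{j=1}^{d}\hat{x}_j^{\,2}
  = \frac{\tfrac{1}{d}\sum_{j=1}^{d}(x_j-\mu)^2}{\sigma^2+\varepsilon}
  = \frac{\sigma^2}{\sigma^2+\varepsilon} \approx 1.

% Illustrative numbers (assumed example): x = (2, 4, 6, 8), d = 4.
\mu = 5, \qquad
\sigma^2 = \tfrac{1}{4}(9 + 1 + 1 + 9) = 5, \qquad
\hat{x} \approx (-1.342,\ -0.447,\ 0.447,\ 1.342)
\quad \text{for negligible } \varepsilon.
```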
Step 4 — learnable per-feature scale and shift. As with batch norm, two
learnable vectors \gamma, \beta \in \mathbb{R}^d restore
expressivity, applied elementwise (\odot is the Hadamard product):
y = \gamma \odot \hat{x} + \beta.
Here \gamma and \beta hold one value
per feature, so the network can re-weight each coordinate after standardising.
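The four steps translate almost line for line into code. What follows is a minimal sketch in plain Python (lists rather than an array library, to stay dependency-free); the function name `layer_norm` and the default `eps=1e-5` are assumptions made for illustration, not values given in the text.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer-normalise one feature vector x (a list of d floats)."""
    d = len(x)
    # Step 1: mean over the d features of this one example.
    mu = sum(x) / d
    # Step 2: variance over the same d features (divides by d, as in the text).
    var = sum((xj - mu) ** 2 for xj in x) / d
    # Step 3: standardise each coordinate; eps guards the division.
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xj - mu) * inv_std for xj in x]
    # Step 4: per-feature learnable scale (gamma) and shift (beta).
    return [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]

# Illustrative input; gamma = 1 and beta = 0 expose the raw standardised vector.
x = [2.0, 4.0, 6.0, 8.0]
d = len(x)
y = layer_norm(x, gamma=[1.0] * d, beta=[0.0] * d)

mean_y = sum(y) / d
var_y = sum((v - mean_y) ** 2 for v in y) / d
print(y)
print(mean_y, var_y)
```

With the identity parameters used here, `y` comes out close to (-1.342, -0.447, 0.447, 1.342), with mean essentially 0 and variance just under 1. In a trained network `gamma` and `beta` would be learned rather than fixed.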
The axis is the whole point
Picture the activations as a grid: rows are examples, columns are
features. The contrast is entirely about which way you average:
-
Batch norm averages down a column — one feature across all
examples in the batch.
-
Layer norm averages along a row — all features within one
example.
Because a row is self-contained, layer norm is batch-independent: it behaves
identically for a batch of 10{,}000 or a batch of
1, and identically in training and at inference. There are no
running averages to maintain — the same formula runs everywhere.
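The row-versus-column contrast is easy to see on a tiny grid. The sketch below is an illustration under stated assumptions: the helper names (`standardise`, `layer_norm_rows`, `batch_norm_columns`) and `EPS` are hypothetical, the learnable scale and shift are omitted, and the batch-norm variant uses training-style batch statistics only, with no running averages.

```python
import math

EPS = 1e-5  # assumed small constant guarding the division

def standardise(values, eps=EPS):
    """Subtract the mean and divide by the standard deviation of `values`."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [(v - mu) / math.sqrt(var + eps) for v in values]

def layer_norm_rows(grid):
    # Average along each row: all features within one example.
    return [standardise(row) for row in grid]

def batch_norm_columns(grid):
    # Average down each column: one feature across all examples in the batch.
    columns = [standardise(list(col)) for col in zip(*grid)]
    return [list(row) for row in zip(*columns)]

example = [1.0, 2.0, 3.0]
batch_a = [example, [4.0, 0.0, -2.0]]
batch_b = [example, [10.0, 10.0, 10.0], [-5.0, 7.0, 0.5]]
alone = [example]  # a batch of size 1

# Layer norm: the first row's output ignores its batch-mates entirely.
ln_a = layer_norm_rows(batch_a)[0]
ln_b = layer_norm_rows(batch_b)[0]
ln_alone = layer_norm_rows(alone)[0]
print(ln_a == ln_b == ln_alone)  # True

# Batch norm: the same example is normalised differently in each batch,
# and with a single example every column has zero variance.
bn_a = batch_norm_columns(batch_a)[0]
bn_b = batch_norm_columns(batch_b)[0]
bn_alone = batch_norm_columns(alone)[0]
print(bn_a, bn_b)
print(bn_alone)  # [0.0, 0.0, 0.0]
```

The same input row gets three different batch-norm outputs depending on its neighbours, while its layer-norm output never moves.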
For a single example's feature vector
x \in \mathbb{R}^d, layer normalization applies:
-
Per-example statistics:
\mu = \tfrac{1}{d}\sum_j x_j and
\sigma^2 = \tfrac{1}{d}\sum_j (x_j - \mu)^2, then
\hat{x} = (x - \mu)/\sqrt{\sigma^2 + \varepsilon} and
y = \gamma \odot \hat{x} + \beta (with per-feature
\gamma, \beta).
-
Axis: it normalises across the features of one example, whereas
batch norm normalises across the batch for one feature.
-
Batch independence: no batch statistics and no running averages — the
transform is identical for any batch size (including 1) and
identical in training and inference, which is exactly why it suits
variable-length sequences and small or streaming batches.
Transformers process sequences whose length varies from one example to the next and whose batches, at
generation time, are often a single example fed one token at a time. Batch norm is a poor
fit here: its statistics would mix unrelated tokens and collapse at batch size
1. Layer norm sidesteps all of it — each token normalises against
its own features, so the same code path serves a training batch of thousands and an
inference batch of one. That batch-independence, more than any raw accuracy gain, is why it
became the default normaliser of the entire transformer family.
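To make the "same code path" point concrete, here is a small sketch in which one per-token function serves both a full sequence and a one-token-at-a-time loop. The names `layer_norm` and `normalise_sequence`, the toy sequences, and `eps=1e-5` are assumptions for illustration; the learnable scale and shift are left out to keep the focus on the statistics.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalise one token's feature vector against its own features."""
    mu = sum(x) / len(x)
    var = sum((xj - mu) ** 2 for xj in x) / len(x)
    return [(xj - mu) / math.sqrt(var + eps) for xj in x]

def normalise_sequence(tokens):
    # Each token is handled independently, so sequence length is irrelevant.
    return [layer_norm(tok) for tok in tokens]

# Two sequences of different lengths that share their first token.
short_seq = [[0.5, -1.0, 2.0]]
long_seq = [[0.5, -1.0, 2.0], [3.0, 3.5, -0.5], [1.0, 0.0, -1.0]]

# "Training-style" pass over the whole sequence at once.
full_pass = normalise_sequence(long_seq)

# "Generation-style" pass: one token arrives at a time.
streamed = []
for tok in long_seq:
    streamed.append(layer_norm(tok))

print(full_pass == streamed)                              # True
print(normalise_sequence(short_seq)[0] == full_pass[0])   # True
```

No padding, masking, or stored statistics are needed for the normalisation itself, because no token ever looks at another.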
Modern models often go one step further with RMSNorm, which drops the mean
subtraction entirely and divides only by the root-mean-square of the features,
\hat{x} = \frac{x}{\sqrt{\tfrac{1}{d}\sum_j x_j^2 + \varepsilon}}.
It is cheaper (no mean, no centring) and works about as well, suggesting the re-centring
was never the load-bearing part — the re-scaling was. We'll meet it properly later.
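As a preview, the RMSNorm formula above can be sketched in a few lines. This follows the equation exactly as written here, so it has no learnable parameters; practical implementations usually add a per-feature gain, which is omitted. The name `rms_norm` and `eps=1e-5` are assumptions for illustration.

```python
import math

def rms_norm(x, eps=1e-5):
    """Divide x by the root-mean-square of its own features (no centring)."""
    mean_sq = sum(xj * xj for xj in x) / len(x)
    return [xj / math.sqrt(mean_sq + eps) for xj in x]

x = [2.0, 4.0, 6.0, 8.0]  # illustrative input
x_hat = rms_norm(x)

# The output has root-mean-square close to 1 ...
rms_after = math.sqrt(sum(v * v for v in x_hat) / len(x_hat))
# ... but its mean is not forced to 0, because nothing was subtracted.
mean_after = sum(x_hat) / len(x_hat)
print(rms_after, mean_after)
```

For this input the mean of the output stays well away from zero (about 0.91), which is the visible difference from layer norm: the scale is fixed, the centre is left alone.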