Layer Normalization

Batch normalization has a hidden dependency that bites in practice: it normalises each feature across the mini-batch, so the statistics it uses for one example are tangled up with whatever other examples happen to share the batch. Shrink the batch and the estimates get noisy. At a batch of size 1 they degenerate: the batch mean is the activation itself and the variance is zero, so the normalised value collapses to 0 whatever the input was.

Layer normalization cures this by turning the normalisation sideways. Instead of averaging one feature down a column of examples, it averages all the features within a single example. The example becomes self-sufficient — its normalisation depends on nothing but itself. This is the normalisation the transformer is built on.

One example, normalised across its features, line by line

Take a single token's feature vector x = (x_1, \dots, x_d) \in \mathbb{R}^d. Every statistic below is computed over the d features of this one vector — no other example is involved.

Step 1 — the mean over features. Average the d coordinates of this example:

\mu = \frac{1}{d}\sum_{j=1}^{d} x_j.

Step 2 — the variance over features. Their mean squared deviation from that per-example mean:

\sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2.

Step 3 — normalise the vector. Subtract the mean and divide by the standard deviation, coordinate by coordinate (\varepsilon again guarding the division):

\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}.

Now \hat{x} has mean 0 and, up to the tiny correction from \varepsilon, variance 1 across its own features, and it got there without consulting a single other example.
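
As a quick check of Steps 1 to 3, here is a worked example with made-up numbers (an illustrative vector, not taken from any model). Take d = 4 and x = (1, 2, 3, 4), and ignore \varepsilon for the arithmetic:

\mu = \frac{1 + 2 + 3 + 4}{4} = 2.5,

\sigma^2 = \frac{(-1.5)^2 + (-0.5)^2 + (0.5)^2 + (1.5)^2}{4} = \frac{5}{4} = 1.25,

\hat{x} = \frac{(-1.5,\ -0.5,\ 0.5,\ 1.5)}{\sqrt{1.25}} \approx (-1.342,\ -0.447,\ 0.447,\ 1.342).

The entries of \hat{x} sum to 0, and their mean square is (1.8 + 0.2 + 0.2 + 1.8)/4 = 1, as expected.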

Step 4 — learnable per-feature scale and shift. As with batch norm, two learnable vectors \gamma, \beta \in \mathbb{R}^d restore expressivity, applied elementwise (\odot is the Hadamard product):

y = \gamma \odot \hat{x} + \beta.

Here \gamma and \beta hold one value per feature, so the network can re-weight each coordinate after standardising.
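
The four steps can be sketched in a few lines of plain Python. This is a minimal illustration rather than a production implementation: the function name `layer_norm`, the default `eps=1e-5`, and the sample vector are assumptions chosen for the example, and \gamma and \beta are set to ones and zeros so the output is just \hat{x}.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer-normalise one feature vector x (a list of d floats)."""
    d = len(x)
    # Step 1: mean over the d features of this one example.
    mu = sum(x) / d
    # Step 2: variance over the same features (divide by d, not d - 1).
    var = sum((xj - mu) ** 2 for xj in x) / d
    # Step 3: standardise each coordinate; eps guards the division.
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xj - mu) * inv_std for xj in x]
    # Step 4: learnable per-feature scale (gamma) and shift (beta).
    return [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]

# Illustrative input; gamma = 1 and beta = 0 leave x_hat unchanged.
x = [1.0, 2.0, 3.0, 4.0]
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 3) for v in y])  # [-1.342, -0.447, 0.447, 1.342]
```

Nothing in the function refers to a batch: the only input besides the parameters is the single vector x.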

The axis is the whole point

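Picture the activations as a grid: rows are examples, columns are features. The contrast is entirely about which way you average: batch norm averages down a column (one feature across all examples), while layer norm averages along a row (all features of one example).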

Because a row is self-contained, layer norm is batch-independent: it behaves identically for a batch of 10{,}000 or a batch of 1, and identically in training and at inference. There are no running averages to maintain — the same formula runs everywhere.
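
A small experiment makes the batch-independence concrete. The sketch below is a simplification under stated assumptions: the helper names (`standardise`, `layer_norm_rows`, `batch_norm_cols`) and the data are invented for illustration, \gamma and \beta are omitted, and the batch-norm variant uses training-mode batch statistics only, with no running averages. The same example is placed in two different batches and in a batch of one.

```python
import math

def standardise(values, eps=1e-5):
    """Subtract the mean and divide by the standard deviation of `values`."""
    n = len(values)
    mu = sum(values) / n
    var = sum((v - mu) ** 2 for v in values) / n
    return [(v - mu) / math.sqrt(var + eps) for v in values]

def layer_norm_rows(grid):
    # Layer norm: pool along each row (one example, all its features).
    return [standardise(row) for row in grid]

def batch_norm_cols(grid):
    # Batch norm (batch statistics only): pool down each column
    # (one feature, all the examples), then transpose back to rows.
    cols = [standardise(list(col)) for col in zip(*grid)]
    return [list(row) for row in zip(*cols)]

example = [1.0, 2.0, 3.0, 4.0]
batch_a = [example, [10.0, 0.0, -10.0, 5.0]]  # same example, one neighbour
batch_b = [example, [0.0, 0.0, 0.0, 0.0]]     # same example, another neighbour
alone = [example]                             # batch of one

ln_a, ln_b, ln_alone = (layer_norm_rows(b)[0] for b in (batch_a, batch_b, alone))
bn_a, bn_b, bn_alone = (batch_norm_cols(b)[0] for b in (batch_a, batch_b, alone))

print(ln_a == ln_b == ln_alone)  # True: layer norm ignores the neighbours
print(bn_a == bn_b)              # False: batch norm output depends on them
print(bn_alone)                  # [0.0, 0.0, 0.0, 0.0]: a batch of one collapses
```

The layer-norm output for the example is bit-for-bit identical in all three settings, while the batch-norm output changes with the neighbour and carries no information at batch size 1.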

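For a single example's feature vector x \in \mathbb{R}^d, layer normalization applies:

y = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta,

where \mu and \sigma^2 are the mean and variance over the d features of x, as defined in Steps 1 and 2.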

Transformers process sequences whose length varies from example to example and whose batches, at generation time, are often a single example fed one token at a time. Batch norm is a poor fit here: its statistics would mix unrelated tokens and collapse at batch size 1. Layer norm sidesteps all of it. Each token normalises against its own features, so the same code path serves a training batch of thousands and an inference batch of one. That batch-independence, more than any raw accuracy gain, is why it became the default normaliser of the entire transformer family.

Modern models often go one step further with RMSNorm, which drops the mean subtraction entirely and divides only by the root-mean-square of the features, \hat{x} = \frac{x}{\sqrt{\tfrac{1}{d}\sum_j x_j^2 + \varepsilon}}. It is cheaper (no mean, no centring) and works about as well, suggesting the re-centring was never the load-bearing part — the re-scaling was. We'll meet it properly later.
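
The RMSNorm formula above can be sketched just as briefly. The function name `rms_norm`, the default `eps=1e-5`, and the sample vector are assumptions for illustration. Implementations typically also multiply by a learnable per-feature gain, which is left out here to match the formula as written.

```python
import math

def rms_norm(x, eps=1e-5):
    """Divide x by the root-mean-square of its own features (no centring)."""
    d = len(x)
    mean_sq = sum(xj * xj for xj in x) / d  # (1/d) * sum_j x_j^2
    return [xj / math.sqrt(mean_sq + eps) for xj in x]

x = [1.0, 2.0, 3.0, 4.0]
x_hat = rms_norm(x)

print([round(v, 3) for v in x_hat])                     # [0.365, 0.73, 1.095, 1.461]
print(round(sum(v * v for v in x_hat) / len(x_hat), 3)) # 1.0: unit mean square
print(round(sum(x_hat) / len(x_hat), 3))                # 0.913: the mean is not removed
```

The output has unit mean square but keeps a non-zero mean, which is exactly the difference from layer norm: the vector is re-scaled, not re-centred.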

Flip the normalisation axis

Figure (interactive in the original): a grid of activations, with examples as rows and features as columns, showing which cells each normaliser pools together. Layer norm pools a single row: one example, all its features. Batch norm pools a single column: one feature, all the examples. Same grid, perpendicular axes.