RMSNorm

Layer normalization does two things to a token's feature vector: it re-centres it (subtracts the mean) and re-scales it (divides by the standard deviation). RMSNorm asks a pointed question: is the centring actually doing any work? It turns out the answer is mostly no, so RMSNorm drops the mean-centring (and the bias) and keeps only the scaling. The result is simpler, a little faster, and works about as well in practice. Many modern models, including T5 and Llama, use it in place of LayerNorm.

From LayerNorm to RMSNorm, line by line

Start with a token's feature vector x = (x_1, \dots, x_d) \in \mathbb{R}^d and recall what LayerNorm computes over those d features.

Step 1 — recall LayerNorm. It subtracts the per-example mean \mu = \frac1d\sum_j x_j and divides by the standard deviation, then applies a learnable scale and shift:

\operatorname{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta, \qquad \sigma^2 = \frac1d\sum_{j=1}^d (x_j - \mu)^2.
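
As a point of reference, here is a minimal pure-Python sketch of this formula for a single feature vector. The function name `layer_norm`, the list-based representation and the default `eps=1e-5` are illustrative assumptions, not taken from any particular library. Real implementations operate on tensors and normalise over the feature axis.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm of a single feature vector x (a list of d floats)."""
    d = len(x)
    mu = sum(x) / d                                   # first reduction: the mean
    var = sum((xj - mu) ** 2 for xj in x) / d         # second reduction: variance about mu
    inv_std = 1.0 / math.sqrt(var + eps)              # eps keeps the division safe
    # Centre, rescale, then apply the per-feature gain gamma and shift beta.
    return [g * (xj - mu) * inv_std + b for xj, g, b in zip(x, gamma, beta)]


x = [1.0, 2.0, 3.0, 4.0]
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 4) for v in y])  # → [-1.3416, -0.4472, 0.4472, 1.3416]
```

With \gamma = 1 and \beta = 0, the output has mean 0 and variance just under 1 (the shortfall comes from \varepsilon).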

Step 2 — drop the mean. Set \mu = 0 and subtract nothing. Without centring, the "variance" of the vector is just the mean of its squares, the mean square:

\operatorname{MS}(x) = \frac1d\sum_{j=1}^d x_j^2.
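
A short expansion shows how the mean square relates to LayerNorm's statistics. Write each x_j as (x_j - \mu) + \mu and use \sum_{j=1}^d (x_j - \mu) = 0:

\operatorname{MS}(x) = \frac1d\sum_{j=1}^d \big((x_j - \mu) + \mu\big)^2 = \frac1d\sum_{j=1}^d (x_j - \mu)^2 + \frac{2\mu}{d}\sum_{j=1}^d (x_j - \mu) + \mu^2 = \sigma^2 + \mu^2.

So \operatorname{MS}(x) \ge \sigma^2, with equality exactly when \mu = 0. On an input that is already centred, dropping the mean changes nothing.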

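Its square root is the root-mean-square, a measure of the vector's overall magnitude (its Euclidean norm divided by \sqrt d). As in LayerNorm, a small \varepsilon goes inside the root for numerical stability: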

\operatorname{RMS}(x) = \sqrt{\frac1d\sum_{j=1}^d x_j^2 + \varepsilon}.
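
Here is a minimal sketch of this quantity in plain Python. It assumes the feature vector is a list of floats. The helper name `rms` and the default `eps=1e-6` are illustrative choices, not values given in the text. Only one reduction is needed, the sum of squares.

```python
import math


def rms(x, eps=1e-6):
    """Root-mean-square of a feature vector, with eps added inside the square root."""
    mean_square = sum(xj * xj for xj in x) / len(x)   # the only reduction: sum of squares
    return math.sqrt(mean_square + eps)


x = [1.0, 2.0, 3.0, 4.0]
print(rms(x, eps=0.0))   # sqrt(7.5), about 2.7386
print(rms(x))            # very slightly larger, because of eps
```

For this x, \mu = 2.5 and \sigma^2 = 1.25, so the mean square is 1.25 + 2.5^2 = 7.5 and the RMS is \sqrt{7.5} \approx 2.7386.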

Step 3 — divide by the RMS. Re-scale the raw vector by its own magnitude — no subtraction first:

\hat{x} = \frac{x}{\operatorname{RMS}(x)}.

Now \hat{x} has unit root-mean-square (exactly so when \varepsilon = 0). Unlike LayerNorm's output, it is not forced to have mean zero. Its mean is divided by the RMS along with everything else, but it is not removed.
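
Both properties of \hat{x} can be seen at once with the hypothetical helper below. It divides by the RMS, applies no gain, and uses \varepsilon = 0 so that the check is exact. That choice is an assumption. With a small positive \varepsilon, the output RMS sits just below 1.

```python
import math


def rms_normalize(x, eps=0.0):
    """Divide x by its RMS: no mean subtraction, no gain (eps=0 for an exact check)."""
    mean_square = sum(xj * xj for xj in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [xj * scale for xj in x]


x = [1.0, 2.0, 3.0, 4.0]                 # mean 2.5, RMS sqrt(7.5)
x_hat = rms_normalize(x)
rms_out = math.sqrt(sum(v * v for v in x_hat) / len(x_hat))
mean_out = sum(x_hat) / len(x_hat)
print(round(rms_out, 6), round(mean_out, 4))   # → 1.0 0.9129
```

The output mean, 2.5 / \sqrt{7.5} \approx 0.913, is just the input mean divided by the RMS. LayerNorm would have made it 0.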

Step 4 — scale, but drop the bias. Apply a learnable per-feature gain \gamma and stop there — no \beta shift:

\operatorname{RMSNorm}(x) = \frac{x}{\sqrt{\frac1d\sum_{j=1}^d x_j^2 + \varepsilon}} \odot \gamma.
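
Steps 2 to 4 combine into a minimal single-vector RMSNorm, sketched below. It assumes plain Python lists and an illustrative `eps=1e-6`. The name `rms_norm` and the example gains are hypothetical. A production version would work on tensors, normalising each token over its d features.

```python
import math


def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm of one feature vector: x / sqrt(mean(x**2) + eps), scaled per feature by gamma.

    No mean is subtracted and there is no bias term.
    """
    if len(x) != len(gamma):
        raise ValueError("x and gamma must both have length d")
    mean_square = sum(xj * xj for xj in x) / len(x)   # single reduction over the d features
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [xj * inv_rms * g for xj, g in zip(x, gamma)]


gamma = [1.0, 0.5, 2.0, 1.0]          # hypothetical learned gains, shared by all tokens
batch = [[1.0, 2.0, 3.0, 4.0],        # each row is one token's feature vector
         [-0.5, 0.0, 0.5, 8.0]]
out = [rms_norm(row, gamma) for row in batch]   # each token is normalised independently
for row in out:
    print([round(v, 4) for v in row])
```

The gain vector \gamma is the layer's only parameter and is shared across tokens, while each token's RMS is computed from that token alone.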

Step 5 — count what we saved. LayerNorm needs two passes' worth of work over the features (one to find \mu, one to find \sigma^2 about that mean), plus a subtraction per coordinate and the \beta add. RMSNorm needs only a single sum of squares and one division. No mean, no centring, no bias. The saving per call is small, but normalisation runs twice in every layer, across dozens of layers and every token, so it adds up to a measurable speedup.

Step 6 — and the quality? Empirically, roughly on par. Dropping the re-centring barely moves the loss, and that is the surprising part: the load-bearing operation in LayerNorm was the re-scaling all along. Keep the part that matters, pay less for it.

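RMSNorm is LayerNorm with the mean-centring and bias removed. Fixing \mu = 0 (so that \sigma^2 becomes \operatorname{MS}(x)) and \beta = 0 in the LayerNorm formula recovers it exactly:

\operatorname{RMSNorm}(x) = \operatorname{LayerNorm}(x)\big|_{\mu = 0,\ \beta = 0}.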
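
That statement can be checked numerically. The sketch below uses a LayerNorm with a hypothetical `center` flag: with `center=False` the mean is taken to be zero. Setting \beta = 0 as well should then match a direct RMSNorm up to floating-point rounding. Both functions are illustrative, not a library API.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-6, center=True):
    """LayerNorm of one vector; center=False (a hypothetical flag) fixes mu = 0."""
    d = len(x)
    mu = sum(x) / d if center else 0.0
    var = sum((xj - mu) ** 2 for xj in x) / d    # becomes the mean square when mu = 0
    inv = 1.0 / math.sqrt(var + eps)
    return [g * (xj - mu) * inv + b for xj, g, b in zip(x, gamma, beta)]


def rms_norm(x, gamma, eps=1e-6):
    """Direct RMSNorm: divide by sqrt(mean square + eps), then apply the gain."""
    mean_square = sum(xj * xj for xj in x) / len(x)
    inv = 1.0 / math.sqrt(mean_square + eps)
    return [xj * inv * g for xj, g in zip(x, gamma)]


x = [0.3, -1.2, 2.5, 0.7]
gamma = [1.0, 0.8, 1.2, 0.5]
no_centre = layer_norm(x, gamma, beta=[0.0] * 4, center=False)
direct = rms_norm(x, gamma)
centred = layer_norm(x, gamma, beta=[0.0] * 4)   # ordinary LayerNorm, for contrast
print(max(abs(a - b) for a, b in zip(no_centre, direct)))   # rounding error only
print(max(abs(a - b) for a, b in zip(centred, direct)))     # clearly nonzero: mean of x is 0.575
```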

LayerNorm was originally justified as controlling both the mean and the variance of each token's activations. RMSNorm's ablation tells a cleaner story: the stabilising effect comes almost entirely from fixing the scale of the vector. That helps keep activations and gradients from exploding or vanishing as signals pass through a deep stack. Re-centring appears to contribute little, plausibly because the surrounding learned layers can compensate for an un-removed shift on their own. Subtracting the mean was, in effect, work the rest of the network could do itself.

There is a tidy invariance, too. RMSNorm is scale-invariant: multiply the whole input vector by any positive constant c and the output is unchanged. Ignoring the small \varepsilon, \operatorname{RMS}(cx) = c\operatorname{RMS}(x), so cx / \operatorname{RMS}(cx) = x / \operatorname{RMS}(x). This re-scaling invariance is exactly the property the argument above credits for LayerNorm's stabilising effect, and it survives intact without the mean ever being touched.
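
The invariance can be checked numerically with a short sketch. It assumes plain Python lists, sets the gain \gamma to all ones, and uses a hypothetical `rms_norm` helper that takes `eps` explicitly. With \varepsilon = 0, the outputs for cx and x agree to rounding error for every positive c. With \varepsilon > 0, the invariance is only approximate. It visibly breaks down once c is small enough that c^2\operatorname{MS}(x) is comparable to \varepsilon.

```python
import math


def rms_norm(x, eps):
    """RMSNorm with the gain fixed to all ones, so only the normalisation is visible."""
    mean_square = sum(xj * xj for xj in x) / len(x)
    return [xj / math.sqrt(mean_square + eps) for xj in x]


def max_diff(a, b):
    return max(abs(u - v) for u, v in zip(a, b))


x = [0.3, -1.2, 2.5, 0.7]
for c in (0.01, 1.0, 100.0):
    cx = [c * xj for xj in x]
    exact = max_diff(rms_norm(cx, eps=0.0), rms_norm(x, eps=0.0))        # rounding error only
    with_eps = max_diff(rms_norm(cx, eps=1e-6), rms_norm(x, eps=1e-6))   # grows as c shrinks
    print(c, exact, with_eps)
```

A negative c flips the sign of the output, which is why the invariance is stated for positive constants.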

Centre-then-scale, or just scale?

Picture one token's raw features as a row of bars, with a dashed line at their mean. LayerNorm first slides every bar so the mean line drops to zero, then rescales, so the features end up centred. RMSNorm skips the slide entirely. It only rescales by the root-mean-square, so the mean line is scaled along with the bars but never moved to zero. The same control over scale, with one fewer operation.