Mixed-Precision Training

Our training recipe so far has stored every weight, gradient, and activation in 32-bit floating point (fp32). A number kept in 16 bits takes half the memory and half the bandwidth to move, and modern accelerators multiply 16-bit matrices considerably faster than fp32 ones. Mixed-precision training does the heavy compute in 16-bit while keeping a few critical pieces in fp32. That roughly halves activation memory and often comes close to doubling throughput, almost for free.

"Almost" is the whole story. Sixteen bits buys speed at the cost of either range or precision, and getting it wrong silently destroys your gradients. The fixes are two small, exact tricks: a master copy of the weights in fp32, and (for one of the two 16-bit formats) loss scaling.

Deriving the fixes

A floating-point number splits its bits into a sign, an exponent (which sets the range — how large or small a value it can reach), and a mantissa (which sets the precision — how many significant digits). The 16-bit formats spend their 15 non-sign bits differently.

Step 1 — line up the three formats. Count the bits:

\underbrace{\text{fp32}}_{1+8+23} \qquad \underbrace{\text{fp16}}_{1+5+10} \qquad \underbrace{\text{bf16}}_{1+8+7}.

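fp16 keeps 10 mantissa bits but only 5 exponent bits: good precision, tiny range (its largest finite value is 65504 and its smallest normal value is about 6\times 10^{-5}). bf16 ("brain float") keeps fp32's full 8 exponent bits, and with them fp32's huge range, but only 7 mantissa bits, so it is coarser: near 1.0 its representable values are 2^{-7} apart, against 2^{-10} for fp16.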
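To make the trade-off concrete, here is a minimal sketch that emulates both 16-bit formats on ordinary Python floats. It assumes the standard `struct` module's `'e'` format (IEEE binary16) for fp16, and emulates bf16 by rounding the fp32 bit pattern to its top 16 bits with round-to-nearest-even. The helper names `to_fp16` and `to_bf16` are hypothetical, not a library API, and NaN inputs are not handled.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE fp16 (binary16) and back. Raises OverflowError beyond fp16's range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Emulate a bf16 cast: round x to fp32, then keep the top 16 bits (round to nearest even).

    Assumes x is finite and within fp32 range; NaN handling is omitted.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even on the 16 bits being dropped
    bits = (bits >> 16) << 16             # zero the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

# Precision: near 1.0, fp16 resolves steps of 2**-10 but bf16 only 2**-7.
x = 1 + 2**-9
print(to_fp16(x), to_bf16(x))   # 1.001953125 1.0

# Range: bf16 keeps fp32's 8 exponent bits, while fp16 tops out at 65504.
print(to_bf16(1e38))            # still finite, within 1% of 1e38
try:
    to_fp16(1e5)
except OverflowError:
    print("1e5 does not fit in fp16")
```

The same value, 1 + 2^{-9}, survives in fp16 but not in bf16, while 10^{38} survives in bf16 but not in fp16: precision versus range, exactly as the bit counts predict.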

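Step 2 — the underflow problem. Late in training, gradients are small. A gradient like g \approx 10^{-7} sits below fp16's smallest normal value. IEEE fp16 can still store it as a subnormal, but with a single significant bit: it rounds to 2^{-23} \approx 1.2\times 10^{-7}, an error of about 19%. Anything below about 3\times 10^{-8} (half the smallest subnormal, 2^{-24}) rounds to exactly zero:

g \approx 10^{-7} \;\xrightarrow{\ \text{fp16}\ }\; 2^{-23} \approx 1.2\times 10^{-7}, \qquad g \approx 10^{-8} \;\xrightarrow{\ \text{fp16}\ }\; 0.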

A gradient that underflows to zero contributes no update — the parameter silently stops learning. bf16, with fp32's range, represents 10^{-7} fine; only fp16 has this disease.
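The underflow can be checked directly. This sketch defines its own small hypothetical helpers (fp16 via `struct`'s `'e'` format; bf16 emulated by rounding an fp32 bit pattern to its top 16 bits) and casts a few gradient magnitudes to each format.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE fp16 via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Emulated bf16: keep the top 16 bits of x's fp32 pattern, rounding to nearest even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

for g in [1e-4, 1e-7, 1e-8]:
    h, b = to_fp16(g), to_bf16(g)
    print(f"g={g:.0e}  fp16={h:.3e} (rel. err {abs(h - g) / g:.0%})  "
          f"bf16={b:.3e} (rel. err {abs(b - g) / g:.1%})")
```

In this emulation, 10^{-4} is a normal fp16 number, 10^{-7} keeps one significant bit (about 19% error), and 10^{-8} becomes exactly 0, while bf16 stays within a fraction of a percent for all three.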

Step 3 — fix one: an fp32 master copy. Each update -\eta_t\, \hat g is usually far smaller than the weight it modifies. When it is less than half the gap between neighboring 16-bit values near that weight, adding it in 16 bits rounds the weight straight back to its old value; in bf16, for example, 1.0 - 10^{-4} rounds to 1.0. So keep the authoritative weights in fp32 (the master copy), apply every update there at full precision, and cast a 16-bit copy only for the fast forward/backward compute:

\theta^{\text{fp32}}_t = \theta^{\text{fp32}}_{t-1} - \eta_t\, \hat g, \qquad \theta^{16} = \text{cast}_{16}(\theta^{\text{fp32}}_t).
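A minimal sketch of why the master copy matters, assuming a constant update of size 10^{-4} applied 1000 times to a weight that starts at 1.0. fp32 and bf16 rounding are emulated with the standard `struct` module; the names `to_bf16`, `to_fp32`, and `lr_times_grad` are hypothetical.

```python
import struct

def to_fp32(x: float) -> float:
    """Round x to fp32 and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Emulated bf16: keep the top 16 bits of x's fp32 pattern, rounding to nearest even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

lr_times_grad = 1e-4   # one update eta_t * g_hat, far smaller than the weight
steps = 1000

# Naive: keep the weight itself in bf16 and apply every update there.
w16 = to_bf16(1.0)
for _ in range(steps):
    w16 = to_bf16(w16 - lr_times_grad)        # 1.0 - 1e-4 rounds back to 1.0 every time

# Mixed precision: fp32 master copy, with a 16-bit copy cast only for compute.
master = to_fp32(1.0)
for _ in range(steps):
    master = to_fp32(master - lr_times_grad)  # update applied at full precision
    w_compute = to_bf16(master)               # theta^16 = cast_16(theta^fp32)

print(w16, master, w_compute)
```

The bf16-only weight never moves, while the master copy drifts to about 0.9 as intended, and its 16-bit cast follows it to within one bf16 step.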

Step 4 — fix two: loss scaling (fp16 only). To rescue the underflowing gradients, multiply the loss by a large constant S (say 2^{14}, chosen small enough that the largest gradients do not overflow fp16's maximum finite value) before the backward pass. Because differentiation is linear, every gradient is scaled by the same S, lifting small ones back into fp16's normal range:

\nabla(S\cdot L) = S\,\nabla L, \qquad S \cdot 10^{-7} = 2^{14}\cdot 10^{-7} \approx 1.6\times 10^{-3} \ \checkmark
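A small numerical sketch of the rescue. It does not run a real backward pass; it only models the rounding of the final gradient value to fp16 (via `struct`'s `'e'` format), with and without the factor S. The helper `to_fp16` is hypothetical.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE fp16 via struct's 'e' format; raises OverflowError beyond 65504."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

S = 2.0**14   # loss scale from the text
g = 1e-7      # the true gradient value we hope to represent

plain = to_fp16(g)         # gradient stored in fp16 with no scaling
scaled = to_fp16(S * g)    # the same gradient after scaling the loss by S

print(plain, abs(plain - g) / g)         # subnormal: about 19% error
print(scaled, abs(scaled / S - g) / g)   # normal fp16: well under 0.1% error

# The upper limit on S: a gradient of 10 scaled by 2**14 no longer fits in fp16.
try:
    to_fp16(S * 10.0)
except OverflowError:
    print("overflow: S is too large for gradients of this size")
```

The last lines show why S cannot simply be made huge: scaling lifts small gradients into range, but it also pushes large ones toward fp16's ceiling.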

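Step 5 — unscale before the update. The scaled gradient is S times too big, so divide it out, in fp32, before stepping the master weights. With S a power of two this division only shifts the exponent, so it adds no rounding error and recovers exactly the gradient the fp16 backward pass produced: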

\hat g = \frac{1}{S}\,\big(S\,\nabla L\big) = \nabla L.
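A sketch of the unscale step, assuming the fp16 backward pass is modeled only by rounding S\cdot g to fp16 (`struct`'s `'e'` format) and the fp32 unscale by its `'f'` format. The helper names are hypothetical.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE fp16 via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round x to fp32 via struct's 'f' format."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

S = 2.0**14
scaled_grad = to_fp16(S * 1e-7)    # what the fp16 backward pass hands back

# Unscale in fp32 before the master-weight update.
g_hat = to_fp32(scaled_grad / S)

# Dividing by a power of two only changes the exponent, so nothing is rounded away.
assert g_hat * S == scaled_grad
print(g_hat)                       # close to 1e-7, off only by fp16's rounding of S*g
```

With a scale that is not a power of two, the division would in general introduce one more rounding step; power-of-two scales avoid it for free.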

bf16 already reaches 10^{-7} without help, so it needs no loss scaling. That is why bf16 is the usual default wherever the hardware supports it natively: the same memory savings as fp16 (and, on such hardware, the same speed), fp32's range, and one fewer knob to tune.

In short: train with 16-bit compute and an fp32 safety net.
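A minimal end-to-end sketch of that recipe, under deliberately toy assumptions: a single scalar weight fit to y = 3x, 16-bit arithmetic emulated by rounding Python floats through `struct` (fp16 via the `'e'` format), plain SGD with a fixed learning rate, and a static loss scale of 2^{10}. All names (`to_fp16`, `master_w`, and so on) are hypothetical and do not correspond to any framework's API.

```python
import struct

def to_fp16(x: float) -> float:
    """Emulated fp16 rounding (IEEE binary16 via struct's 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Emulated fp32 rounding."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# Hypothetical toy problem: learn w in y = w * x, where the true w is 3.
xs = [i / 10 for i in range(1, 11)]
ys = [3.0 * x for x in xs]

master_w = to_fp32(0.0)   # fp32 master copy of the weight
lr = 0.1                  # learning rate eta_t, held constant here
S = 2.0**10               # static loss scale (fp16 needs it; with bf16 one would use S = 1)

for _ in range(200):
    w16 = to_fp16(master_w)                   # 16-bit copy used only for compute
    scaled_grad = 0.0
    for x, y in zip(xs, ys):
        x16 = to_fp16(x)
        err = to_fp16(to_fp16(w16 * x16) - to_fp16(y))     # 16-bit forward pass
        # gradient of S * mean((w*x - y)**2) for this sample, rounded to fp16,
        # then accumulated in fp32
        scaled_grad = to_fp32(scaled_grad + to_fp16(S * 2.0 * err * x16 / len(xs)))
    g_hat = to_fp32(scaled_grad / S)          # unscale in fp32
    master_w = to_fp32(master_w - lr * g_hat) # update the fp32 master copy

print(master_w)   # ends close to 3.0
```

The 16-bit copy `w16` is recreated from the master weight at every step and never updated directly. Near convergence the per-sample gradients are around 10^{-6}, below fp16's normal range, which is exactly where the factor S earns its keep.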

A normal floating-point value is (-1)^{\text{sign}} \times 1.\text{mantissa} \times 2^{\text{exponent} - \text{bias}}, where the bias is 127 for fp32 and bf16 and 15 for fp16. The exponent field is a range dial and the mantissa field is a resolution dial, and the three formats simply split their bits between the two:

\begin{aligned} \text{fp32:}&\quad 1\ \text{sign} \;+\; 8\ \text{exp} \;+\; 23\ \text{mantissa} = 32 \\ \text{fp16:}&\quad 1\ \text{sign} \;+\; 5\ \text{exp} \;+\; 10\ \text{mantissa} = 16 \\ \text{bf16:}&\quad 1\ \text{sign} \;+\; 8\ \text{exp} \;+\; 7\ \text{mantissa} = 16 \end{aligned}

bf16 is literally fp32 with the bottom 16 mantissa bits chopped off. The exponent field is identical, so casting fp32 to bf16 gives up only precision, never the range that a cast to fp16 would lose. That is why deep-learning accelerators such as TPUs and recent GPUs support bf16 natively for training: you keep the range you reasoned about in fp32 and simply tolerate coarser rounding, which stochastic gradient noise largely washes out.
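The "chopped off" picture can be checked bit by bit. This sketch implements bf16 by pure truncation, keeping the sign, all 8 exponent bits, and the top 7 mantissa bits of an fp32 pattern; real conversions commonly round to nearest instead, so treat `bf16_truncate` (a hypothetical helper) as an illustration of the bit layout rather than of any particular library's cast.

```python
import struct

def fp32_bits(x: float) -> int:
    """The 32-bit pattern of x rounded to fp32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bf16_truncate(x: float) -> float:
    """bf16 by chopping: keep sign, 8 exponent bits, top 7 mantissa bits; zero the rest."""
    return struct.unpack("<f", struct.pack("<I", fp32_bits(x) & 0xFFFF0000))[0]

x = 1 / 3
print(f"fp32 {fp32_bits(x):032b}")
print(f"bf16 {fp32_bits(bf16_truncate(x)):032b}")   # same top 16 bits, low 16 zeroed

# The exponent field survives untouched, so huge and tiny fp32 values keep their magnitude.
for v in [3.0e38, 1.0e-30]:
    print(v, bf16_truncate(v))
```

Both 3\times 10^{38} and 10^{-30} come back finite and nonzero, changed only in their trailing digits.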

Range, on a log scale

Each bar spans the representable range of a format, drawn against the power of ten \log_{10}|x|. fp32 and bf16 reach the same tiny magnitudes (same exponent bits); fp16's bar stops far short on the small-number end. The dashed line marks a typical small gradient \approx 10^{-7}: it lands inside fp32 and bf16 but below fp16's normal range, where it keeps almost no precision, and gradients only slightly smaller underflow to zero.
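The bar endpoints follow directly from the bit counts. A short sketch, assuming IEEE-style conventions (bias 2^{e-1}-1, the all-ones exponent reserved for infinity and NaN); the function name `float_range` is hypothetical.

```python
import math

def float_range(exp_bits: int, mant_bits: int):
    """Smallest normal, smallest subnormal, and largest finite value of an IEEE-style format."""
    bias = 2 ** (exp_bits - 1) - 1
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = 2.0 ** (1 - bias - mant_bits)
    max_finite = (2 - 2.0 ** -mant_bits) * 2.0 ** bias
    return min_normal, min_subnormal, max_finite

for name, e, m in [("fp32", 8, 23), ("fp16", 5, 10), ("bf16", 8, 7)]:
    lo, sub, hi = float_range(e, m)
    print(f"{name}: normal range 10^{math.log10(lo):.1f} .. 10^{math.log10(hi):.1f}, "
          f"subnormals down to 10^{math.log10(sub):.1f}")
```

Only the exponent width enters the bias, so fp32 and bf16 share their normal range; fp16's normal range ends near 10^{-4.2}, and its subnormals stop near 10^{-7.2}.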

The default, and what is left to fit

bf16 mixed precision stores activations, and the working copies of weights and gradients, in half the memory fp32 would need (the fp32 master weights stay full size), and it often roughly doubles speed; it is usually the first thing a large run turns on. But halving is not enough when the batch you want still will not fit. The next pages buy more room a different way: accumulating gradients over micro-batches, and recomputing activations instead of storing them.