The training recipe stored every weight, gradient, and activation in 32-bit floating point
(fp32). But a number kept in 16 bits takes half the memory and moves
through the hardware roughly twice as fast. Mixed-precision training does the
heavy compute in 16-bit while keeping a few critical pieces in fp32, halving memory and
nearly doubling throughput almost for free.
"Almost" is the whole story. Sixteen bits buys speed at the cost of either range or
precision, and getting it wrong silently destroys your gradients. The fixes are two
small, exact tricks: a master copy of the weights in fp32, and (for one of
the two 16-bit formats) loss scaling.
Deriving the fixes
A floating-point number splits its bits into a sign, an
exponent (which sets the range — how large or small a value it can
reach), and a mantissa (which sets the precision — how many
significant digits). The 16-bit formats spend their 15 non-sign bits differently.
Step 1 — line up the three formats. Count the bits:
\underbrace{\text{fp32}}_{1+8+23} \qquad \underbrace{\text{fp16}}_{1+5+10} \qquad \underbrace{\text{bf16}}_{1+8+7}.
fp16 keeps 10 mantissa bits but only
5 exponent bits: good precision, tiny range
(normal values span only about 6\times 10^{-5} to 65504). bf16 ("brain float") keeps
fp32's full 8 exponent bits, and with them fp32's huge range, but only
7 mantissa bits, so it is coarser.
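The range and resolution figures above follow directly from the bit counts. The sketch below recomputes them, assuming the usual IEEE-style conventions (an exponent bias of 2^(exp_bits - 1) - 1 and an implicit leading 1); the helper name `format_stats` is made up for illustration.

```python
def format_stats(exp_bits: int, man_bits: int) -> dict:
    """Range and resolution of an IEEE-style format, derived from its bit counts."""
    bias = 2 ** (exp_bits - 1) - 1
    return {
        # Smallest normal value: exponent field 1, mantissa field 0.
        "min_normal": 2.0 ** (1 - bias),
        # Largest finite value: top usable exponent, all mantissa bits set.
        "max": (2.0 - 2.0 ** -man_bits) * 2.0 ** bias,
        # Gap between 1.0 and the next representable value.
        "epsilon": 2.0 ** -man_bits,
    }


# (exponent bits, mantissa bits) for each format in the text.
FORMATS = {"fp32": (8, 23), "fp16": (5, 10), "bf16": (8, 7)}

for name, (e, m) in FORMATS.items():
    s = format_stats(e, m)
    print(f"{name}: min normal {s['min_normal']:.2e}, "
          f"max {s['max']:.2e}, epsilon {s['epsilon']:.2e}")
```

The fp16 row shows the narrow range (about 6.1e-5 to 65504), while bf16 shares fp32's smallest normal value and pays for it with a much larger epsilon.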
Step 2 — the underflow problem. Late in training, gradients are small. A
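gradient like g \approx 10^{-7} sits below fp16's smallest
normal value (about 6\times 10^{-5}), in the subnormal range where only a bit or two of
precision survives, and anything below about 3\times 10^{-8} rounds to exactly zero:
g \approx 10^{-7} \;\xrightarrow{\ \text{fp16}\ }\; 1.2\times 10^{-7}\ (\text{about 19\% error}), \qquad g \approx 10^{-8} \;\xrightarrow{\ \text{fp16}\ }\; 0.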
A gradient that underflows to zero contributes no update — the parameter silently
stops learning. bf16, with fp32's range, represents 10^{-7} fine;
only fp16 has this disease.
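The underflow can be checked without any ML library, since Python's `struct` module packs IEEE half precision through the `"e"` format. The sketch below rounds a few gradient magnitudes to fp16 and reports what survives; `to_fp16` is a helper name chosen for this example. Strictly, a value like 1e-7 lands in fp16's subnormal range and comes back as a coarse approximation, while values below roughly 3e-8 round all the way to zero.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


for g in (1e-3, 1e-5, 1e-7, 1e-8):
    g16 = to_fp16(g)
    rel_err = abs(g16 - g) / g
    print(f"g = {g:.0e} -> fp16 {g16:.3e} (relative error {rel_err:.1%})")
```

The relative error climbs as the gradient shrinks and reaches 100% once the value rounds to zero.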
Step 3 — fix one: an fp32 master copy. Many tiny updates
-\eta_t\, \hat g are each far smaller than the weight itself; added
in 16-bit they round away to nothing. So keep the authoritative weights in fp32 — the
master copy — apply every update there at full precision, and cast a 16-bit
copy only for the fast forward/backward compute:
\theta^{\text{fp32}}_t = \theta^{\text{fp32}}_{t-1} - \eta_t\, \hat g, \qquad \theta^{16} = \text{cast}_{16}(\theta^{\text{fp32}}_t).
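A small experiment makes the rounding loss concrete. The sketch below applies the same tiny update 1000 times to a weight stored in fp16 and to one stored in fp32, emulating both formats with `struct`. The constants `UPDATE` and `STEPS` are illustrative values, not taken from any real run.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


UPDATE = 1e-4   # stands in for eta * g: tiny next to a weight of 1.0
STEPS = 1000

w16 = to_fp16(1.0)   # weight kept in fp16 only
w32 = to_fp32(1.0)   # fp32 master copy
for _ in range(STEPS):
    # Each result is rounded back to the storage format, as real hardware would.
    w16 = to_fp16(w16 - to_fp16(UPDATE))
    w32 = to_fp32(w32 - UPDATE)

print("fp16-only weight:", w16)
print("fp32 master weight:", w32)
```

Just below 1.0, adjacent fp16 values are about 4.9e-4 apart, so every 1e-4 step rounds straight back to 1.0 and the fp16 weight never moves. The fp32 master copy accumulates all 1000 steps and ends near 0.9.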
Step 4 — fix two: loss scaling (fp16 only). To rescue the underflowing
gradients, multiply the loss by a large constant S (say
2^{14}) before the backward pass. By linearity of the
gradient, every gradient is scaled by the same S, lifting it back
into fp16's representable range:
\nabla(S\cdot L) = S\,\nabla L, \qquad S \cdot 10^{-7} = 2^{14}\cdot 10^{-7} \approx 1.6\times 10^{-3} \ \checkmark
Step 5 — unscale before the update. The scaled gradient is
S times too big, so divide it out before stepping the master
weights — restoring the true gradient exactly:
\hat g = \frac{1}{S}\,\big(S\,\nabla L\big) = \nabla L.
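Steps 4 and 5 can be sketched together. The code below models a 16-bit backward pass simply as "round the (scaled) gradient to fp16", which is an assumption made for illustration; `backward_fp16` and `unscale` are hypothetical helper names. Because S is a power of two, multiplying and dividing by it only shifts the exponent, so the unscaling step adds no rounding error of its own.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


S = 2.0 ** 14  # loss scale, the example value used in the text


def backward_fp16(true_grad: float, scale: float = 1.0) -> float:
    """Toy model of a 16-bit backward pass: the scaled gradient is stored in fp16."""
    return to_fp16(scale * true_grad)


def unscale(scaled_grad: float, scale: float) -> float:
    """Divide the scale back out in higher precision before the optimizer step."""
    return scaled_grad / scale


for g in (1e-7, 1e-8):
    naive = backward_fp16(g)
    rescued = unscale(backward_fp16(g, S), S)
    print(f"true {g:.0e}: unscaled fp16 gives {naive:.3e}, "
          f"scaled then unscaled gives {rescued:.3e}")
```

With scaling, both gradients come back within a fraction of a percent of their true values; without it, the smaller one is lost entirely.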
bf16 already reaches 10^{-7} without help, so it needs
no loss scaling, which is why bf16 is today's default: the same
memory and speed as fp16, fp32's full range, and one fewer knob to tune.
Train with 16-bit compute and an fp32 safety net:
- 16-bit compute. Do the forward and backward passes in fp16 or bf16:
  half the memory, roughly double the throughput of fp32.
- fp32 master weights. Keep the authoritative weights in fp32 and apply
  every optimizer update there, so tiny updates are not lost to rounding.
- fp16 loss scaling. For fp16, multiply the loss by
  S before backward (so small gradients survive) and divide by
  S before the update, which is exact by linearity.
- bf16 ⇒ no scaling. bf16 has fp32's exponent range, so gradients do not
  underflow and loss scaling is unnecessary; this is the modern default.
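The recipe above can be exercised end to end on a toy problem. The sketch below is a deliberately artificial setup, not a real training loop: the loss is L(w) = 0.5 * K * w^2 with a tiny K, so the true gradient K * w is around 2e-8, and the learning rate is set unrealistically large so that each update is about 2e-4. The names `K`, `LR`, and `train` are invented for this example, and the 16-bit backward pass is modelled as rounding the scaled gradient to fp16.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


K = 2e-8   # toy loss L(w) = 0.5 * K * w**2, so the true gradient is K * w
LR = 1e4   # large on purpose, so that LR * gradient is about 2e-4


def train(steps: int, use_master: bool, scale: float) -> float:
    """Run `steps` updates from w = 1.0 and return the stored weight."""
    w = 1.0
    for _ in range(steps):
        w16 = to_fp16(w)                  # 16-bit copy used for compute
        g16 = to_fp16(scale * K * w16)    # modelled fp16 backward of the scaled loss
        g = to_fp32(g16 / scale)          # unscale in fp32
        w_new = w - LR * g
        # Store the result either in the fp32 master copy or directly in fp16.
        w = to_fp32(w_new) if use_master else to_fp16(w_new)
    return w


print("master weights, no loss scaling:", train(100, True, 1.0))
print("loss scaling, no master weights:", train(100, False, 2.0 ** 14))
print("master weights and loss scaling:", train(100, True, 2.0 ** 14))
```

Dropping either safeguard leaves the weight stuck at 1.0: without scaling the gradient rounds to zero, and without the master copy the update rounds away. With both in place the weight decays toward the minimum, reaching about 0.98 after 100 steps.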
A floating-point value is (-1)^{\text{sign}} \times 1.\text{mantissa} \times 2^{\text{exponent}}.
The exponent field is a range dial and the mantissa field is a resolution dial, and the
three formats just split their bits between the two:
\begin{aligned} \text{fp32:}&\quad 1\ \text{sign} \;+\; 8\ \text{exp} \;+\; 23\ \text{mantissa} = 32 \\ \text{fp16:}&\quad 1\ \text{sign} \;+\; 5\ \text{exp} \;+\; 10\ \text{mantissa} = 16 \\ \text{bf16:}&\quad 1\ \text{sign} \;+\; 8\ \text{exp} \;+\; 7\ \text{mantissa} = 16 \end{aligned}
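To see the three fields in an actual bit pattern, the sketch below unpacks a value with `struct` and splits it by the bit counts listed above. One detail the formula leaves implicit is that the stored exponent field carries a bias (15 for fp16, 127 for fp32), which the hypothetical `value` helper subtracts when rebuilding the number. Only normal values are handled here.

```python
import struct

# format name -> (float pack code, same-width int code, exponent bits, mantissa bits)
LAYOUTS = {"fp32": ("<f", "<I", 8, 23), "fp16": ("<e", "<H", 5, 10)}


def fields(x: float, fmt: str) -> tuple:
    """Split x into (sign, exponent field, mantissa field) for 'fp32' or 'fp16'."""
    code, int_code, exp_bits, man_bits = LAYOUTS[fmt]
    bits = struct.unpack(int_code, struct.pack(code, x))[0]
    sign = bits >> (exp_bits + man_bits)
    exponent = (bits >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = bits & ((1 << man_bits) - 1)
    return sign, exponent, mantissa


def value(sign: int, exponent: int, mantissa: int, fmt: str) -> float:
    """Rebuild a normal value: (-1)^sign * 1.mantissa * 2^(exponent - bias)."""
    _, _, exp_bits, man_bits = LAYOUTS[fmt]
    bias = (1 << (exp_bits - 1)) - 1
    return (-1) ** sign * (1 + mantissa / (1 << man_bits)) * 2.0 ** (exponent - bias)


for fmt in ("fp32", "fp16"):
    s, e, m = fields(-1.5, fmt)
    print(f"{fmt}: sign={s} exponent={e} mantissa={m} -> {value(s, e, m, fmt)}")
```

The same number, -1.5, yields a different exponent field in each format only because the bias differs; the mantissa field is the same leading bit pattern stored at a different width.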
bf16 is literally fp32 with the bottom 16 mantissa bits chopped off: same exponent field,
so casting fp32 ↔ bf16 never overflows. That is why deep-learning hardware made bf16 the
native training format — you keep the range you reasoned about in fp32 and simply tolerate
coarser rounding, which stochastic gradient noise washes out anyway.
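The "chopped-off fp32" view can be demonstrated by masking bits directly. The sketch below keeps the top 16 bits of the fp32 pattern and zeroes the rest; this is pure truncation for clarity, whereas real conversions typically round to nearest, so treat it as an illustration of the layout rather than of any particular hardware. The helper names are made up for this example.

```python
import struct


def fp32_bits(x: float) -> int:
    """The 32-bit pattern of x stored as fp32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bf16_truncate(x: float) -> float:
    """Keep the top 16 bits of the fp32 pattern: sign, 8 exponent, 7 mantissa bits."""
    bits = fp32_bits(x) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for x in (1e-7, 1.0, 1e38):
    b = to_bf16_truncate(x)
    same_exponent = (fp32_bits(b) >> 23) == (fp32_bits(x) >> 23)
    rel_err = abs(b - x) / x
    print(f"{x:.0e} -> bf16 {b:.4e} "
          f"(relative error {rel_err:.2%}, exponent unchanged: {same_exponent})")
```

Both the tiny gradient 1e-7 and the huge value 1e38 survive with the sign and exponent bits untouched and a relative error below 2^-7, which is the coarser rounding the text describes.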