Data parallelism is the simplest way to use many GPUs: give each a full copy of the model
and a different slice of the batch. Simple, but wasteful. Every GPU stores the same
parameters, the same gradients, and the same optimizer states. With N GPUs you are
paying for the model N times over, even though it is one model.
ZeRO (Zero Redundancy Optimizer) — and its PyTorch-native form
FSDP (Fully Sharded Data Parallel) — removes that redundancy. Instead of
replicating the model state, it shards it across the data-parallel GPUs and gathers each
piece only when it is needed.
Where the memory goes, line by line
Let the model have \Psi parameters. Train it with mixed precision and AdamW,
and account for the bytes per parameter.
Step 1 — count the replicated state. Plain data parallelism keeps, on
every GPU, three things: the parameters, their gradients, and the optimizer's bookkeeping.
For AdamW the optimizer carries the first moment m, the second moment
v, and an fp32 master copy of the weights:
\underbrace{2\Psi}_{\text{fp16 params}} + \underbrace{2\Psi}_{\text{fp16 grads}} + \underbrace{4\Psi + 4\Psi + 4\Psi}_{\text{Adam } m,\,v,\,\text{fp32 master}} = 16\Psi \text{ bytes.}
That is roughly 16 bytes per parameter, and the optimizer states alone
(12\Psi) dwarf the parameters themselves. With
N GPUs, this entire 16\Psi is copied
N times — pure redundancy.
Step 2 — Stage 1: shard the optimizer states. The optimizer states are needed
only during the update, not during the forward/backward pass, so they are the easiest to split.
Give each of the N GPUs only 1/N of
m, v, and the master weights:
\text{per GPU} = 2\Psi + 2\Psi + \frac{12\Psi}{N}.
Step 3 — Stage 2: also shard the gradients. Each GPU only needs the gradient for
the parameters whose optimizer state it owns, so shard the gradients the same way:
\text{per GPU} = 2\Psi + \frac{2\Psi}{N} + \frac{12\Psi}{N}.
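Gradient sharding of this kind is commonly implemented as a reduce-scatter: the gradients are summed across GPUs, but each GPU keeps only its own slice of the sum. The sketch below simulates that in a single process with plain Python lists. The helper name `reduce_scatter` and the toy values are hypothetical. It assumes the gradient length divides evenly by the number of ranks, and it sums rather than averages.

```python
from typing import List


def reduce_scatter(local_grads: List[List[float]]) -> List[List[float]]:
    """Simulate a sum reduce-scatter across len(local_grads) ranks.

    local_grads[r] is rank r's full gradient vector. result[r] is the
    slice of the summed gradient that rank r owns.
    """
    n = len(local_grads)
    size = len(local_grads[0])
    if any(len(g) != size for g in local_grads):
        raise ValueError("all ranks must hold gradients of the same length")
    if size % n != 0:
        raise ValueError("gradient length must divide evenly across ranks")
    shard = size // n
    # Element-wise sum over ranks: the "reduce" half.
    total = [sum(g[i] for g in local_grads) for i in range(size)]
    # Each rank keeps only its contiguous slice: the "scatter" half.
    return [total[r * shard:(r + 1) * shard] for r in range(n)]


if __name__ == "__main__":
    # Two ranks, four parameters; each rank saw a different slice of the batch.
    grads = [[1.0, 2.0, 3.0, 4.0],
             [10.0, 20.0, 30.0, 40.0]]
    print(reduce_scatter(grads))  # [[11.0, 22.0], [33.0, 44.0]]
```

Each simulated rank ends up holding 1/N of the summed gradient, which is where the 2\Psi / N term comes from.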
Step 4 — Stage 3 (= FSDP): also shard the parameters. Finally split the
parameters themselves. Now nothing is fully replicated — every part of the model state is
divided N ways:
\text{per GPU} = \frac{2\Psi}{N} + \frac{2\Psi}{N} + \frac{12\Psi}{N} = \frac{16\Psi}{N}.
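The per-GPU formulas for plain data parallelism and the three stages can be checked numerically. The following is a minimal sketch, not part of any library: the function name `zero_bytes_per_gpu` and the example values (7.5 billion parameters, 64 GPUs) are assumptions chosen for illustration, and the count covers model state only (no activations or temporary buffers).

```python
def zero_bytes_per_gpu(psi: float, n: int, stage: int) -> float:
    """Model-state bytes per GPU for mixed-precision AdamW.

    stage 0 is plain data parallelism; stages 1-3 shard progressively more.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    params = 2 * psi   # fp16 parameters
    grads = 2 * psi    # fp16 gradients
    optim = 12 * psi   # fp32 m, v, and master weights
    if stage >= 1:
        optim /= n     # Stage 1: shard the optimizer states
    if stage >= 2:
        grads /= n     # Stage 2: also shard the gradients
    if stage >= 3:
        params /= n    # Stage 3: also shard the parameters
    return params + grads + optim


if __name__ == "__main__":
    psi, n = 7.5e9, 64  # illustrative values, not from the text
    for stage in range(4):
        gb = zero_bytes_per_gpu(psi, n, stage) / 1e9
        print(f"stage {stage}: {gb:.2f} GB per GPU")
```

With these assumed values the model state shrinks from 120 GB per GPU (replicated) to about 31.4 GB, 16.6 GB, and 1.9 GB for Stages 1, 2, and 3.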
Step 5 — gather on demand. But the forward and backward passes do need a layer's full
parameters. So just before computing a layer, the GPUs do an all-gather to
momentarily reconstruct that layer's weights; right after, each GPU throws away the shards it does
not own. The full weights exist only for the instant they are used, one layer at a time, so the
temporary overhead on top of each GPU's 16\Psi / N share tracks the largest single layer, not the
whole model:
\text{Stage 3 memory per GPU} \;\approx\; \frac{16\Psi}{N} \;\xrightarrow[N \text{ large}]{}\; \text{model fits, however big.}
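The gather-then-discard pattern can be simulated in a single process to see why the peak stays small. This is only a sketch under stated assumptions: the helper names (`shard`, `all_gather`, `forward_peak_elements`) are hypothetical, plain Python lists stand in for GPU tensors, layer sizes divide evenly by the number of ranks, and it counts parameter elements only (no gradients, optimizer states, or activations). Real implementations use collective communication and typically overlap the gathers with compute.

```python
from typing import List


def shard(values: List[float], n: int) -> List[List[float]]:
    """Split one layer's flat parameters into n equal contiguous shards."""
    if len(values) % n != 0:
        raise ValueError("layer size must divide evenly across ranks")
    step = len(values) // n
    return [values[r * step:(r + 1) * step] for r in range(n)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Rebuild the full layer by concatenating every rank's shard in order."""
    return [x for s in shards for x in s]


def forward_peak_elements(layers: List[List[float]], n: int) -> int:
    """Walk the layers once and return the peak parameter count on rank 0.

    The peak is the rank's resident shards plus the one layer that is
    temporarily gathered in full.
    """
    sharded = [shard(layer, n) for layer in layers]
    resident = sum(len(s[0]) for s in sharded)  # what rank 0 always stores
    peak = resident
    for layer, shards in zip(layers, sharded):
        full = all_gather(shards)        # materialise this layer only
        if full != layer:                # sanity check on the simulation
            raise RuntimeError("gather did not reproduce the layer")
        # Rank 0 already holds its own shard, so only the rest is extra.
        extra = len(full) - len(shards[0])
        peak = max(peak, resident + extra)
        del full                         # drop the gathered copy
    return peak


if __name__ == "__main__":
    layers = [[0.0] * 8, [1.0] * 16, [2.0] * 8]  # toy layer sizes
    print(forward_peak_elements(layers, n=4), "of", sum(map(len, layers)))
```

For the toy sizes above with four ranks, the simulated rank never holds more than 20 of the 32 parameters at once: its 8 resident shard elements plus the 12 it borrows for the largest layer.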
The bargain is the mirror image of pipeline parallelism's: ZeRO pays a little extra communication
(the on-demand gathers) to cut per-GPU memory by up to N\times, all
while keeping data parallelism's dead-simple programming model.
For a model of \Psi parameters trained data-parallel on
N GPUs:
- Replication wastes memory. Plain data parallelism stores all
16\Psi bytes (params + grads + optimizer states) on
every GPU: the same model copied N times.
- Shard in stages. Stage 1 shards the optimizer states; Stage 2 also the
gradients; Stage 3 (= FSDP) also the parameters, until per-GPU memory is
16\Psi / N.
- Gather on demand. A layer's full parameters are all-gathered just before they are
used and dropped right after, and its gradients are reduced back to the shards that own them,
so a model far too big for one GPU still trains. The price is extra communication; the payoff
is a roughly N\times memory saving.
Why 16\Psi? Mixed-precision AdamW training keeps, per parameter:
- 2 bytes — the fp16 (half-precision) parameter used in the forward pass.
- 2 bytes — the fp16 gradient.
- 4 bytes — Adam's first moment m (fp32).
- 4 bytes — Adam's second moment v (fp32).
- 4 bytes — the fp32 "master" copy of the weights, kept for numerical stability.
Total: 2 + 2 + 4 + 4 + 4 = 16 bytes. The striking part is that
12 of those 16 — three quarters — are
optimizer overhead, not the model. That is exactly why ZeRO shards the optimizer states first
(Stage 1): it is the biggest, cheapest win, removing 12\Psi(1 - 1/N)
bytes per GPU before touching anything on the hot forward/backward path.
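The byte accounting and the Stage 1 saving are easy to verify with a few lines of arithmetic. This is a minimal sketch: the names `BYTES_PER_PARAM`, `optimizer_fraction`, and `stage1_saving_bytes` are hypothetical, and the example values (1 billion parameters, 8 GPUs) are assumptions for illustration only.

```python
# Bytes kept per parameter under mixed-precision AdamW, as listed above.
BYTES_PER_PARAM = {
    "fp16 parameter": 2,
    "fp16 gradient": 2,
    "fp32 Adam m": 4,
    "fp32 Adam v": 4,
    "fp32 master weight": 4,
}
OPTIMIZER_KEYS = ("fp32 Adam m", "fp32 Adam v", "fp32 master weight")


def optimizer_fraction() -> float:
    """Share of the per-parameter bytes that belongs to the optimizer."""
    total = sum(BYTES_PER_PARAM.values())
    optim = sum(BYTES_PER_PARAM[k] for k in OPTIMIZER_KEYS)
    return optim / total


def stage1_saving_bytes(psi: float, n: int) -> float:
    """Bytes per GPU removed by Stage 1: 12 * psi * (1 - 1/n)."""
    optim = sum(BYTES_PER_PARAM[k] for k in OPTIMIZER_KEYS)
    return optim * psi * (1 - 1 / n)


if __name__ == "__main__":
    print(f"optimizer share: {optimizer_fraction():.0%}")
    saved_gb = stage1_saving_bytes(1e9, 8) / 1e9  # illustrative values
    print(f"Stage 1 saves {saved_gb:.1f} GB per GPU")
```

For the assumed 1-billion-parameter model on 8 GPUs, Stage 1 alone removes 10.5 GB of the 16 GB of model state on each GPU.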