ZeRO and FSDP

Data parallelism is the simplest way to use many GPUs: give each a full copy of the model and a different slice of the batch. Simple — but wasteful. Every GPU stores the same parameters, the same gradients, and the same optimizer states. With N GPUs you are paying for the model N times over, even though it is one model.

ZeRO (Zero Redundancy Optimizer) — and its PyTorch-native form FSDP (Fully Sharded Data Parallel) — removes that redundancy. Instead of replicating the model state, it shards it across the data-parallel GPUs and gathers each piece only when it is needed.

Where the memory goes, line by line

Let the model have \Psi parameters. Train it with mixed precision and AdamW, and account for the bytes per parameter.

Step 1 — count the replicated state. Plain data parallelism keeps, on every GPU, three things: the parameters, their gradients, and the optimizer's bookkeeping. For AdamW the optimizer carries the first moment m, the second moment v, and an fp32 master copy of the weights:

\underbrace{2\Psi}_{\text{fp16 params}} + \underbrace{2\Psi}_{\text{fp16 grads}} + \underbrace{4\Psi + 4\Psi + 4\Psi}_{\text{Adam } m,\,v,\,\text{fp32 master}} = 16\Psi \text{ bytes.}

That is roughly 16 bytes per parameter, and the optimizer states alone (12\Psi) dwarf the parameters themselves. With N GPUs, this entire 16\Psi is copied N times — pure redundancy.
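To make the tally concrete, here is a minimal Python sketch of the same count. The dictionary labels, the function name, and the 1-billion-parameter, 8-GPU example are illustrative assumptions rather than values from the text; the sketch counts model state only, so activations and temporary buffers come on top.

```python
# Bytes held per parameter under mixed-precision AdamW, as tallied above.
BYTES_PER_PARAM = {
    "fp16 params": 2,
    "fp16 grads": 2,
    "fp32 Adam m": 4,
    "fp32 Adam v": 4,
    "fp32 master weights": 4,
}


def replicated_state_bytes(num_params: int) -> int:
    """Model-state bytes on each GPU under plain data parallelism (16 * Psi)."""
    return num_params * sum(BYTES_PER_PARAM.values())


psi = 1_000_000_000  # hypothetical model size, chosen only for illustration
n_gpus = 8           # hypothetical cluster size

per_gpu = replicated_state_bytes(psi)
print(f"per GPU: {per_gpu / 1e9:.0f} GB")                       # per GPU: 16 GB
print(f"across {n_gpus} GPUs: {per_gpu * n_gpus / 1e9:.0f} GB")  # across 8 GPUs: 128 GB
```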

Step 2 — Stage 1: shard the optimizer states. The optimizer states are needed only during the update, not during the forward/backward pass, so they are the easiest to split. Give each of the N GPUs only 1/N of m, v, and the master weights:

\text{per GPU} = 2\Psi + 2\Psi + \frac{12\Psi}{N}.
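A toy, single-process sketch of what Stage 1 implies for the update step: each rank holds m, v, and master weights only for its own slice, updates that slice, and the slices are then put back together (standing in for an all-gather of the updated parameters). The names `shard_bounds` and `ShardedAdamState` are hypothetical, the update is plain Adam without weight decay for brevity, and the gradients are assumed to be already averaged across ranks.

```python
import math


def shard_bounds(num_params: int, world_size: int, rank: int) -> tuple[int, int]:
    """Contiguous [start, end) slice of the flat parameter vector owned by `rank`."""
    per_rank = math.ceil(num_params / world_size)
    start = min(rank * per_rank, num_params)
    end = min(start + per_rank, num_params)
    return start, end


class ShardedAdamState:
    """Adam state for one rank's slice only: m, v, and fp32 master weights."""

    def __init__(self, full_params, world_size, rank):
        self.start, self.end = shard_bounds(len(full_params), world_size, rank)
        self.master = list(full_params[self.start:self.end])
        self.m = [0.0] * len(self.master)
        self.v = [0.0] * len(self.master)
        self.t = 0

    def step(self, full_grads, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
        """Update only the owned slice and return it."""
        self.t += 1
        for i, g in enumerate(full_grads[self.start:self.end]):
            self.m[i] = b1 * self.m[i] + (1 - b1) * g
            self.v[i] = b2 * self.v[i] + (1 - b2) * g * g
            m_hat = self.m[i] / (1 - b1 ** self.t)
            v_hat = self.v[i] / (1 - b2 ** self.t)
            self.master[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
        return self.master


world_size = 4
params = [1.0] * 10
grads = [0.5] * 10  # assumed to be already averaged across ranks

states = [ShardedAdamState(params, world_size, r) for r in range(world_size)]
# Concatenating the updated slices stands in for the all-gather of new weights.
new_params = [w for s in states for w in s.step(grads)]
print([len(s.m) for s in states])  # optimizer entries held per rank
```

No rank ever allocates m or v for parameters it does not own, which is where the 12\Psi/N term comes from.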

Step 3 — Stage 2: also shard the gradients. Each GPU only needs the gradient for the parameters whose optimizer state it owns, so shard the gradients the same way:

\text{per GPU} = 2\Psi + \frac{2\Psi}{N} + \frac{12\Psi}{N}.
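The collective that typically delivers this is a reduce-scatter: every rank contributes its full local gradient, and each rank receives only the reduced chunk it owns. Below is a single-process simulation in plain Python; the function name is hypothetical, the reduction is an average (as in data parallelism), and the sketch assumes the gradient length divides evenly by the number of ranks.

```python
def reduce_scatter_mean(per_rank_grads):
    """Simulate reduce-scatter: rank r ends up with the averaged r-th chunk only."""
    world_size = len(per_rank_grads)
    n = len(per_rank_grads[0])
    assert n % world_size == 0, "sketch assumes the length divides evenly"
    chunk = n // world_size
    shards = []
    for r in range(world_size):
        lo, hi = r * chunk, (r + 1) * chunk
        shards.append([
            sum(g[i] for g in per_rank_grads) / world_size
            for i in range(lo, hi)
        ])
    return shards


local_grads = [
    [1.0, 2.0, 3.0, 4.0],  # full local gradient computed on rank 0
    [3.0, 2.0, 1.0, 0.0],  # full local gradient computed on rank 1
]
shards = reduce_scatter_mean(local_grads)
print(shards)  # [[2.0, 2.0], [2.0, 2.0]]
```

After the call, each rank keeps a gradient buffer of length 1/N of the original, matching the 2\Psi/N term.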

Step 4 — Stage 3 (= FSDP): also shard the parameters. Finally split the parameters themselves. Now nothing is fully replicated — every part of the model state is divided N ways:

\text{per GPU} = \frac{2\Psi}{N} + \frac{2\Psi}{N} + \frac{12\Psi}{N} = \frac{16\Psi}{N}.
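As a worked example of the four formulas, take an illustrative model with \Psi = 7.5 \times 10^{9} parameters on N = 64 GPUs (both numbers are assumptions chosen for the arithmetic, with 1 GB = 10^{9} bytes):

```latex
\begin{aligned}
\text{Stage 0:}\quad & 16\Psi = 16 \times 7.5 \times 10^{9}\ \text{bytes} = 120\ \text{GB} \\
\text{Stage 1:}\quad & 2\Psi + 2\Psi + \frac{12\Psi}{N} = 15 + 15 + \frac{90}{64} \approx 31.4\ \text{GB} \\
\text{Stage 2:}\quad & 2\Psi + \frac{2\Psi}{N} + \frac{12\Psi}{N} = 15 + \frac{15}{64} + \frac{90}{64} \approx 16.6\ \text{GB} \\
\text{Stage 3:}\quad & \frac{16\Psi}{N} = \frac{120}{64} \approx 1.9\ \text{GB}
\end{aligned}
```

Only Stage 3 removes the last N-independent term, which is why its footprint keeps falling as GPUs are added.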

Step 5 — gather on demand. But the forward pass does need a layer's full parameters. So just before computing a layer, the GPUs do an all-gather to momentarily reconstruct that layer's weights; right after, each GPU throws away the shards it does not own. The full weights exist only for the instant they are used, one layer at a time — so peak memory tracks the largest single layer, not the whole model:

\text{Stage 3 memory per GPU} \;\approx\; \frac{16\Psi}{N} + \text{(one gathered layer)} \;\xrightarrow[N \text{ large}]{}\; \text{set by the largest layer, not by } \Psi.
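The effect on peak parameter memory can be sketched with a small calculation. This is a simplified model under stated assumptions: the function name and layer sizes are hypothetical, only fp16 parameters are counted, exactly one layer is gathered at a time, and no prefetching of the next layer is modeled.

```python
def peak_param_bytes(layer_sizes, world_size, bytes_per_param=2):
    """Peak fp16 parameter bytes on one GPU during a layer-by-layer pass."""
    # Shards this GPU always holds: 1/N of every layer.
    resident = sum(layer_sizes) * bytes_per_param / world_size
    peak = resident
    for size in layer_sizes:
        # All-gather: the local 1/N slice is already here; the other
        # (N-1)/N of the layer arrives from the remaining ranks.
        gathered_extra = size * bytes_per_param * (world_size - 1) / world_size
        peak = max(peak, resident + gathered_extra)
        # After the layer runs, the non-owned slices are freed again,
        # so nothing carries over to the next iteration.
    return peak


layer_sizes = [100, 400, 100]  # hypothetical parameter counts per layer
sharded_peak = peak_param_bytes(layer_sizes, world_size=4)
replicated = sum(layer_sizes) * 2
print(sharded_peak, replicated)  # 900.0 1200
```

The gap widens with more layers and more GPUs: the resident term shrinks with N, while the gathered term is bounded by the single largest layer.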

The bargain is the mirror image of pipeline parallelism's: ZeRO pays some extra communication (the on-demand gathers, which at Stage 3 raise per-GPU traffic to roughly 1.5× that of plain data parallelism) to cut per-GPU model-state memory by up to N\times, all while keeping data parallelism's dead-simple programming model.
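The size of that communication cost can be estimated with a back-of-the-envelope count. The sketch below follows the usual idealized accounting (an all-reduce treated as a reduce-scatter plus an all-gather, each moving about \Psi elements per GPU when N is large); the function name is hypothetical, and real traffic depends on the collective implementation and network topology.

```python
def per_gpu_comm_elements(psi: float, stage: int) -> float:
    """Approximate elements each GPU sends per training step (large-N limit)."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    reduce_scatter_grads = psi  # every stage reduces the gradients
    if stage <= 2:
        # Stage 0: second half of the gradient all-reduce.
        # Stages 1-2: all-gather of the freshly updated parameters.
        all_gathers = psi
    else:
        # Stage 3: parameters are gathered once for forward, once for backward.
        all_gathers = 2 * psi
    return reduce_scatter_grads + all_gathers


baseline = per_gpu_comm_elements(1.0, 0)
ratios = {s: per_gpu_comm_elements(1.0, s) / baseline for s in range(4)}
print(ratios)  # {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.5}
```

Under this accounting, Stages 1 and 2 cost no more traffic than plain data parallelism; only the parameter gathers of Stage 3 add to it.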

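For a model of \Psi parameters trained data-parallel on N GPUs, the per-GPU model-state memory is:

Stage 0 (plain data parallelism): 2\Psi + 2\Psi + 12\Psi = 16\Psi
Stage 1 (optimizer states sharded): 2\Psi + 2\Psi + 12\Psi/N
Stage 2 (gradients also sharded): 2\Psi + 2\Psi/N + 12\Psi/N
Stage 3 / FSDP (parameters also sharded): 2\Psi/N + 2\Psi/N + 12\Psi/N = 16\Psi/N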

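Why 16\Psi? Mixed-precision AdamW training keeps, per parameter:

fp16 parameter: 2 bytes
fp16 gradient: 2 bytes
fp32 Adam first moment m: 4 bytes
fp32 Adam second moment v: 4 bytes
fp32 master weight: 4 bytes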

Total: 2 + 2 + 4 + 4 + 4 = 16 bytes. The striking part is that 12 of those 16 — three quarters — are optimizer overhead, not the model. That is exactly why ZeRO shards the optimizer states first (Stage 1): it is the biggest, cheapest win, removing 12\Psi(1 - 1/N) bytes per GPU before touching anything on the hot forward/backward path.

Watch per-GPU memory shrink, stage by stage

[Interactive figure: three stacked bands show one GPU's memory in bytes per parameter, split into parameters (2), gradients (2), and optimizer states (12), 16 in total. A stage selector runs from 0 (plain data parallelism, nothing sharded) to 3 (FSDP, everything sharded across N=8 GPUs); each band collapses once its stage shards it.]
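The band heights for each stage can be reproduced with a few lines of Python. This is a minimal sketch; the function name `bands_per_param` is hypothetical, and N=8 is taken from the description above.

```python
def bands_per_param(stage: int, n: int = 8) -> dict[str, float]:
    """Bytes per parameter held by one GPU, split into the three bands."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    bands = {"parameters": 2.0, "gradients": 2.0, "optimizer states": 12.0}
    if stage >= 1:
        bands["optimizer states"] /= n  # Stage 1 shards the optimizer states
    if stage >= 2:
        bands["gradients"] /= n         # Stage 2 also shards the gradients
    if stage >= 3:
        bands["parameters"] /= n        # Stage 3 also shards the parameters
    return bands


totals = {}
for stage in range(4):
    bands = bands_per_param(stage)
    totals[stage] = sum(bands.values())
    print(stage, bands, totals[stage])
# totals: {0: 16.0, 1: 5.5, 2: 3.75, 3: 2.0}
```

With N=8, Stage 1 alone takes the footprint from 16 to 5.5 bytes per parameter; Stages 2 and 3 then remove the remaining replicated gradient and parameter bands.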