Gradient Checkpointing

The backward pass has a hidden appetite. To compute a layer's gradient, automatic differentiation needs that layer's activations, the values it produced on the forward pass. The naive recipe simply stores every layer's activations so they are waiting when backprop reaches them. For a network of depth n, that memory grows linearly, O(n), and for a deep transformer on a long sequence it is often the single biggest consumer of GPU memory.

Gradient checkpointing (also called activation checkpointing) makes a different bargain: store only a few activations — the checkpoints — and recompute the rest on the fly during the backward pass. It spends extra compute (one extra forward pass) to slash activation memory from O(n) toward O(\sqrt{n}).

Deriving the √n trade-off

Take a network of n layers. We choose m evenly spaced checkpoints and recompute everything between them when needed. Count the memory and the extra compute as functions of m.

Step 1 — the naive cost. Store all n activations: memory \propto n. That is the baseline we want to beat.

M_{\text{naive}} = O(n).
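To make the baseline concrete, here is a minimal sketch of naive backprop on a toy chain of n scalar layers, each computing y = sin(w·x). The names `layer_forward`, `layer_backward` and `naive_backprop` are hypothetical illustrations, not any framework's API. The gradient is taken only with respect to the input, which is enough to show that every layer's input has to stay alive until the backward pass reaches it.

```python
import math

# Hypothetical toy layer, for illustration only: y = sin(w * x).
def layer_forward(w, x):
    return math.sin(w * x)

# The backward step needs the layer's *input* x: d/dx sin(w x) = w cos(w x).
def layer_backward(w, x, grad_y):
    return grad_y * w * math.cos(w * x)

def naive_backprop(weights, x0):
    """Gradient of the chain's output w.r.t. x0, storing every layer's input.

    Returns (gradient, number of stored activations); the count is n.
    """
    acts = [x0]                            # acts[i] is the input to layer i
    for w in weights:
        acts.append(layer_forward(w, acts[-1]))
    grad = 1.0                             # d(output)/d(output)
    for i in reversed(range(len(weights))):
        grad = layer_backward(weights[i], acts[i], grad)
    return grad, len(acts) - 1             # n layer inputs kept alive: O(n)

grad, stored = naive_backprop([0.5, 1.2, 0.8], 0.3)
print(stored)  # 3, one activation per layer
```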

Step 2 — keep only m checkpoints. Save activations at m layers and discard the rest. The stored checkpoints cost O(m) memory:

M_{\text{checkpoints}} = O(m).

Step 3 — recompute a segment at a time. The m checkpoints cut the network into m segments of about n/m layers each. To get the gradients inside one segment, replay the forward pass from its checkpoint — holding only that segment's activations in memory at once, an extra O(n/m):

M_{\text{total}}(m) = O\!\left(m + \frac{n}{m}\right).

Step 4 — minimize over m. Balance the two terms. Their sum is smallest when they are equal, m = n/m, i.e. m = \sqrt{n} (calculus agrees: set \frac{d}{dm}\!\left(m + \tfrac{n}{m}\right) = 1 - \tfrac{n}{m^2} = 0). At that choice each term is \sqrt{n}:

M_{\text{total}}(\sqrt{n}) = O\!\left(\sqrt{n} + \frac{n}{\sqrt{n}}\right) = O(\sqrt{n}).
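For a concrete sense of scale: with n = 100 layers, naive storage holds 100 activations, while m = 10 checkpoints hold 10 + 100/10 = 20, a five-fold saving. The sketch below (helper names are hypothetical) brute-forces the integer m that minimizes m + \lceil n/m \rceil, rounding the segment length up because a segment holds whole layers. For perfect squares it lands exactly on m = \sqrt{n} with total 2\sqrt{n}.

```python
def total_memory(n, m):
    """Checkpoints plus one replayed segment, in units of one layer's activations.

    (n + m - 1) // m is the integer ceiling of n / m: the longest segment.
    """
    return m + (n + m - 1) // m

def best_m(n):
    """Integer m in 1..n with the smallest total_memory(n, m) (first one on ties)."""
    return min(range(1, n + 1), key=lambda m: total_memory(n, m))

for n in (16, 100, 1024):
    m = best_m(n)
    print(n, m, total_memory(n, m))
# 16 4 8
# 100 10 20
# 1024 32 64
```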

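Step 5 — the compute it costs. Backprop already does one forward and one backward pass, and the backward typically costs about twice the forward, so a normal training step is roughly 1 + 2 = 3 forward-pass units. Recomputing each segment once during the backward pass adds, in total, about one more forward pass over the network: 1 extra unit on top of 3, a fixed \approx 33\% overhead, independent of n: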

\text{compute: } \ \underbrace{1\ \text{fwd}}_{\text{normal}} + \underbrace{1\ \text{fwd}}_{\text{recompute}} + \underbrace{1\ \text{bwd}}_{\text{normal}}.
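A minimal sketch of that arithmetic, treating the backward-to-forward cost ratio as an assumed parameter `bwd_over_fwd` (2 is a rule of thumb, not a measurement; the function name is hypothetical):

```python
def checkpoint_overhead(bwd_over_fwd=2.0):
    """Extra compute from one recomputed forward pass, as a fraction.

    Costs are in units of one forward pass; bwd_over_fwd is an assumed ratio.
    """
    baseline = 1.0 + bwd_over_fwd          # 1 fwd + 1 bwd
    with_recompute = baseline + 1.0        # + 1 recomputed fwd
    return with_recompute / baseline - 1.0

print(f"{checkpoint_overhead():.0%}")      # 33%
print(f"{checkpoint_overhead(3.0):.0%}")   # 25%
```

The relative cost of recomputation shrinks as the backward pass gets more expensive, since the extra forward is then a smaller share of the step.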

So m = \sqrt{n} checkpoints turn O(n) activation memory into O(\sqrt{n}) for the price of one extra forward pass — the classic compute-for-memory trade.
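To see the whole scheme run, here is a minimal sketch on a toy chain of n scalar layers, each computing y = sin(w·x); all function names are hypothetical. The forward pass keeps only the input of each segment of about n/m layers. The backward pass then replays one segment at a time from its checkpoint, backpropagates through it, and frees that checkpoint. The returned `peak` counts stored checkpoints plus the replayed inputs of the current segment, in units of one layer's activations; for n = 100 and m = 10 it comes to 19, close to 2\sqrt{n} = 20 rather than n = 100.

```python
import math

# Hypothetical toy layer, for illustration only: y = sin(w * x).
def layer_forward(w, x):
    return math.sin(w * x)

# The backward step needs the layer's input: d/dx sin(w x) = w cos(w x).
def layer_backward(w, x, grad_y):
    return grad_y * w * math.cos(w * x)

def checkpointed_backprop(weights, x0, m):
    """Gradient of the chain's output w.r.t. x0 using at most m checkpoints.

    Returns (gradient, peak number of activations held at once).
    """
    n = len(weights)
    seg = math.ceil(n / m)                 # about n/m layers per segment
    # Forward pass: keep only each segment's input (the checkpoints).
    checkpoints = {}
    x = x0
    for i, w in enumerate(weights):
        if i % seg == 0:
            checkpoints[i] = x
        x = layer_forward(w, x)
    peak = len(checkpoints)
    grad = 1.0
    # Backward pass: handle segments from last to first.
    for start in sorted(checkpoints, reverse=True):
        end = min(start + seg, n)
        # Replay the forward pass from the checkpoint; acts[k] is the input
        # to layer start + k. Only this segment's inputs are held at once.
        acts = [checkpoints[start]]
        for i in range(start, end - 1):
            acts.append(layer_forward(weights[i], acts[-1]))
        # acts[0] is the checkpoint itself, already counted in checkpoints.
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            grad = layer_backward(weights[i], acts[i - start], grad)
        del checkpoints[start]             # segment finished: free its checkpoint
    return grad, peak

n = 100
weights = [0.9] * n
grad, peak = checkpointed_backprop(weights, 0.4, m=10)
print(peak)  # 19: 10 checkpoints + 9 replayed inputs in the last segment
```

The gradient is identical to what naive storage gives, because the replay recomputes the same values from the same checkpoint; only the memory profile changes.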


Transformer activation memory scales with depth × sequence length × width, and the attention activations grow with the square of the sequence length. Double the context window and the token-wise activations double while the attention activations quadruple; activations, not the weights, are what pin a long-context model to a too-small batch (or off the GPU entirely). This is the memory wall.
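A rough way to see why long contexts hurt, as a minimal sketch: count stored activation elements for one sequence as a token-wise term proportional to seq_len × width plus an attention-score term of heads × seq_len² per layer. The function and its `per_token_factor` constant are hypothetical simplifications, not a measured model of any real transformer.

```python
def activation_elements(layers, seq_len, width, heads, per_token_factor=1):
    """Rough count of stored activations for one sequence (illustrative only).

    Per layer: per_token_factor * seq_len * width token-wise values
    plus heads * seq_len**2 attention scores. per_token_factor is a
    hypothetical constant standing in for a layer's internal tensors.
    """
    token_wise = per_token_factor * seq_len * width
    attention = heads * seq_len ** 2
    return layers * (token_wise + attention)

short = activation_elements(layers=24, seq_len=1024, width=1024, heads=16)
long_ = activation_elements(layers=24, seq_len=2048, width=1024, heads=16)
print(round(long_ / short, 2))  # 3.88: doubling the context nearly quadruples it
```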

Checkpointing is the standard tool for climbing the memory wall: by recomputing instead of storing, it lets you train deeper networks and longer sequences on the same hardware, or fit a larger micro-batch (which then pairs with gradient accumulation for a still-bigger effective batch). The extra forward pass is usually a price well worth paying when the alternative is not training the model at all.

Find the valley

The curve is total activation memory M(m) = m + n/m against the number of checkpoints m. With too few checkpoints (left), each segment is long, and replaying it holds many activations at once; with too many (right), the stored checkpoints themselves dominate. The sweet spot is the bottom of the valley at m = \sqrt{n}, where memory is 2\sqrt{n}. As depth n grows, the floor moves right to m = \sqrt{n} and rises only as 2\sqrt{n}, not as n.
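A text-only stand-in for the curve, as a minimal sketch (the helper names are hypothetical): evaluate M(m) = m + n/m for a few m and draw one bar per value. For n = 100 the bars shrink to a floor of 20 at m = 10 and grow again symmetrically, since M(m) = M(n/m).

```python
def valley(n, ms):
    """(m, M(m)) pairs for total activation memory M(m) = m + n/m."""
    return [(m, m + n / m) for m in ms]

def text_plot(n, ms, unit=2.0):
    """One text bar per m; each '#' stands for `unit` activations."""
    return [f"m={m:>3}  M={mem:6.1f}  " + "#" * int(mem / unit)
            for m, mem in valley(n, ms)]

for row in text_plot(100, [1, 2, 5, 10, 20, 50, 100]):
    print(row)
```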

Memory bought, one device at a time

Mixed precision, accumulation, and checkpointing are all single-device tricks: each squeezes a bigger model or batch onto the GPU you already have. When even that is not enough — or you simply want to go faster — the next move is to spread the work across many GPUs, which is data parallelism.