Distributed Training: Putting It Together

We now have four tools for training a model too big or too slow for one GPU: data parallelism (split the batch), tensor parallelism (split a layer), pipeline parallelism (split the depth), and ZeRO/FSDP (shard the state), with FlashAttention making each attention layer cheap. To train a frontier model on thousands of GPUs, you don't pick one — you compose them.

The three parallelism axes are orthogonal: each splits a different dimension of the problem, so their sizes multiply. Stack all three and you get 3-D parallelism, a standard recipe for training the largest models.

Composing the axes, line by line

Picture a fleet of GPUs and assign each a coordinate (d, t, p) — its data-parallel rank, tensor-parallel rank, and pipeline-parallel rank.
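To make the coordinate idea concrete, here is a minimal sketch that maps a flat GPU index to (d, t, p) and back. The ordering (t varies fastest, then p, then d) is an assumption chosen so that consecutive ranks land in the same node; real frameworks pick their own layouts, and the names `Coord`, `rank_to_coord`, and `coord_to_rank` are hypothetical.

```python
from typing import NamedTuple


class Coord(NamedTuple):
    d: int  # data-parallel rank (which replica)
    t: int  # tensor-parallel rank (which GPU inside the node)
    p: int  # pipeline-parallel rank (which stage)


def rank_to_coord(rank: int, t_size: int, p_size: int) -> Coord:
    """Map a flat global rank to (d, t, p), with t varying fastest."""
    t = rank % t_size
    p = (rank // t_size) % p_size
    d = rank // (t_size * p_size)
    return Coord(d, t, p)


def coord_to_rank(c: Coord, t_size: int, p_size: int) -> int:
    """Inverse of rank_to_coord under the same assumed ordering."""
    return (c.d * p_size + c.p) * t_size + c.t
```

With t_size = 8 and p_size = 12, ranks 0 to 7 share one node (same d and p), and rank 96 is the first GPU of the second replica.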

Step 1 — tensor parallelism within a node. Tensor parallelism all-reduces every layer, on every forward and backward pass, so it demands the fastest interconnect you have. Put a tensor-parallel group of size t inside a single server, where the GPUs share an ultra-fast link (e.g. NVLink):

t \text{ GPUs} \;\xrightarrow{\text{fast intra-node link}}\; \text{one layer split } t \text{ ways.}
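As a rough illustration of why a tensor-parallel layer needs an all-reduce, the sketch below simulates one matrix-vector product split t ways in plain Python. It assumes one particular split (along the input dimension, so each simulated GPU produces a partial output that must be summed); other splits exist, and the function names are hypothetical.

```python
def matvec(W, x):
    """Plain matrix-vector product; W is a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def tensor_parallel_matvec(W, x, t):
    """Simulate one layer split t ways along its input dimension.

    Each simulated GPU holds a column slice of W and the matching slice
    of x, computes a partial output, and the partial outputs are summed.
    That final sum stands in for the all-reduce.
    """
    n = len(x)
    if n % t != 0:
        raise ValueError("input dimension must be divisible by t")
    width = n // t
    partials = []
    for r in range(t):
        lo, hi = r * width, (r + 1) * width
        W_shard = [row[lo:hi] for row in W]  # this GPU's slice of the weights
        partials.append(matvec(W_shard, x[lo:hi]))
    # "All-reduce": element-wise sum of the partial outputs.
    return [sum(vals) for vals in zip(*partials)]
```

The result matches the unsplit product, but only after every shard has contributed its partial sum, which is the traffic that has to ride the fast intra-node link.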

Step 2 — pipeline parallelism across nodes. Pipeline parallelism talks rarely (only activations on the forward pass, and their gradients on the backward pass, cross the stage cuts), so it tolerates the slower network between servers. Chain p nodes into a pipeline, each holding a block of layers:

p \text{ stages} \;\xrightarrow{\text{slower inter-node link}}\; \text{the depth split } p \text{ ways.}
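Splitting the depth p ways amounts to handing each stage a contiguous block of layers. A minimal sketch, assuming layers are numbered from 0 and should be spread as evenly as possible (the helper `assign_layers` and the 96-layer figure in the usage note are illustrative, not from any particular model):

```python
def assign_layers(num_layers: int, p: int) -> list[range]:
    """Split layers 0..num_layers-1 into p contiguous blocks, as evenly as possible."""
    base, extra = divmod(num_layers, p)
    stages = []
    start = 0
    for s in range(p):
        # The first `extra` stages take one additional layer each.
        size = base + (1 if s < extra else 0)
        stages.append(range(start, start + size))
        start += size
    return stages
```

For a hypothetical 96-layer model and p = 12, every stage holds 8 layers, and only the activations leaving layer 7, 15, 23, and so on have to cross a node boundary.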

Step 3 — data parallelism on top. A single (t \times p) group now holds one complete model replica. Replicate that whole group d times and feed each replica a different slice of the batch — adding throughput without touching the model split:

d \text{ replicas} \;\times\; (t \times p \text{ model-parallel group}).
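The data-parallel step can be sketched in two small pieces: slicing the global batch across the d replicas, and averaging the gradients they produce. This is a toy version in plain Python with hypothetical helper names; in a real run the averaging is a gradient all-reduce across replicas.

```python
def shard_batch(batch: list, d: int) -> list[list]:
    """Give each of the d replicas a contiguous, equal-sized slice of the batch."""
    if len(batch) % d != 0:
        raise ValueError("batch size must be divisible by d")
    per_replica = len(batch) // d
    return [batch[k * per_replica:(k + 1) * per_replica] for k in range(d)]


def average_gradients(per_replica_grads: list[list[float]]) -> list[float]:
    """Element-wise mean of the replicas' gradients (stands in for the all-reduce)."""
    d = len(per_replica_grads)
    return [sum(vals) / d for vals in zip(*per_replica_grads)]
```

Every replica then applies the same averaged gradient, so the d copies of the model stay identical.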

Step 4 — shard and accelerate within. Layer ZeRO/FSDP across the data-parallel replicas so no GPU stores a redundant full copy of the optimizer state, and run FlashAttention inside every attention layer so the full n^2 attention matrix is never written out to GPU memory. These don't add a fourth axis; they make the existing three cheaper in memory and time.
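A back-of-the-envelope estimate shows what sharding the optimizer state buys. The sketch below assumes mixed-precision Adam at 16 bytes per parameter (2 for half-precision weights, 2 for half-precision gradients, 12 for full-precision optimizer state), assumes the parameters divide evenly over the t × p model-parallel group, and ignores activations entirely. The function name and the 96-billion-parameter figure are hypothetical.

```python
def bytes_per_gpu(n_params: float, t: int, p: int, d: int, shard_optimizer: bool) -> float:
    """Rough model-state memory per GPU, ignoring activations and buffers."""
    local_params = n_params / (t * p)      # this GPU's slice of the model
    weights_and_grads = 4 * local_params   # 2 bytes weights + 2 bytes gradients
    optimizer = 12 * local_params          # assumed full-precision Adam state
    if shard_optimizer:
        optimizer /= d                     # spread across the data-parallel replicas
    return weights_and_grads + optimizer


unsharded = bytes_per_gpu(96e9, t=8, p=12, d=64, shard_optimizer=False)
sharded = bytes_per_gpu(96e9, t=8, p=12, d=64, shard_optimizer=True)
print(f"{unsharded / 1e9:.2f} GB -> {sharded / 1e9:.2f} GB per GPU")
```

Under these assumptions the optimizer state, which dominates the unsharded total, shrinks by a factor of d, while the weights and gradients are untouched.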

Step 5 — count the GPUs. Because the axes are independent, the total is simply their product:

N_{\text{GPU}} = d \times t \times p.

With, say, t = 8 (one node), p = 12 stages, and d = 64 replicas, that is 8 \times 12 \times 64 = 6144 GPUs training one model — the rule-of-thumb arrangement: tensor-parallel within a node, pipeline across nodes, data-parallel on top.
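The product rule can be checked by enumerating the communication groups directly. This sketch labels each GPU by a (d, t, p) tuple and builds the three kinds of group; `build_groups` is a hypothetical helper, not a framework API.

```python
from itertools import product


def build_groups(d: int, t: int, p: int):
    """Group GPU coordinates (d, t, p) by the axis they communicate along."""
    # GPUs that differ only in t: one tensor-parallel group per (replica, stage).
    tensor_groups = [[(di, ti, pi) for ti in range(t)]
                     for di, pi in product(range(d), range(p))]
    # GPUs that differ only in p: one pipeline per (replica, tensor rank).
    pipeline_groups = [[(di, ti, pi) for pi in range(p)]
                       for di, ti in product(range(d), range(t))]
    # GPUs that differ only in d: one data-parallel group per model shard.
    data_groups = [[(di, ti, pi) for di in range(d)]
                   for ti, pi in product(range(t), range(p))]
    return tensor_groups, pipeline_groups, data_groups


tensor_groups, pipeline_groups, data_groups = build_groups(d=64, t=8, p=12)
print(len(tensor_groups), len(pipeline_groups), len(data_groups))
```

For the example above this gives 768 tensor-parallel groups of 8 (one per node), 512 pipelines of 12 stages, and 96 data-parallel groups of 64; each way of counting covers the same 6144 GPUs.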

A frontier model is trained by composing three orthogonal axes: data parallelism, tensor parallelism, and pipeline parallelism.

Notice that every design choice above is really about communication. The placement rule — tensor-parallel within a node, pipeline across nodes — exists because the bandwidth inside a server (NVLink, hundreds of GB/s) is an order of magnitude higher than the bandwidth between servers (InfiniBand/Ethernet). Match the chattiest parallelism to the fastest link.
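To put rough numbers on the placement rule, the sketch below estimates how long one all-reduce takes on each kind of link. It uses the standard bandwidth-only cost of a ring all-reduce, in which each GPU moves 2(n − 1)/n times the message size, and ignores latency. The two bandwidth figures are illustrative round numbers, not measurements of any particular hardware.

```python
def ring_allreduce_seconds(message_bytes: float, n_gpus: int,
                           bandwidth_bytes_per_s: float) -> float:
    """Bandwidth-only estimate of a ring all-reduce (latency ignored)."""
    traffic = 2 * (n_gpus - 1) / n_gpus * message_bytes
    return traffic / bandwidth_bytes_per_s


INTRA_NODE = 300e9  # assumed: "hundreds of GB/s" inside a server
INTER_NODE = 25e9   # assumed: roughly an order of magnitude slower between servers

fast = ring_allreduce_seconds(1e9, n_gpus=8, bandwidth_bytes_per_s=INTRA_NODE)
slow = ring_allreduce_seconds(1e9, n_gpus=8, bandwidth_bytes_per_s=INTER_NODE)
print(f"intra-node: {fast * 1e3:.1f} ms, inter-node: {slow * 1e3:.1f} ms")
```

With these assumed numbers the same 1 GB all-reduce across 8 GPUs is twelve times slower over the inter-node link, a penalty tensor parallelism would pay on every layer of every pass.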

At thousands of GPUs the arithmetic is rarely the limit; moving gradients, activations, and shards between devices is. This is why frontier clusters are built around their network as much as their accelerators, why ZeRO (which stops every GPU from holding the same redundant state) and FlashAttention (which cuts traffic between the GPU's fast on-chip memory and its main memory) matter so much, and why a careless parallelism layout can leave most of an expensive GPU fleet stalled waiting on the wire. Scaling laws assume you can keep the GPUs fed; the interconnect is what decides whether you can.

Map the fleet

Each cell of the map is one GPU, labelled by its (t, p) coordinate within a single data-parallel replica: columns are tensor-parallel ranks (one node, fast link), rows are pipeline stages (across nodes). The whole fleet contains d such grids, one per data-parallel replica, for a total of d \times t \times p GPUs.
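The same map can be rendered as text. A minimal sketch with hypothetical helper names: one row per pipeline stage, one column per tensor-parallel rank, plus the fleet total for a few values of d.

```python
def render_replica_grid(t: int, p: int) -> str:
    """One data-parallel replica: rows are pipeline stages, columns are tensor ranks."""
    rows = []
    for stage in range(p):
        rows.append(" ".join(f"({rank},{stage})" for rank in range(t)))
    return "\n".join(rows)


def fleet_size(d: int, t: int, p: int) -> int:
    """Total GPUs: d copies of the t-by-p grid."""
    return d * t * p


print(render_replica_grid(t=4, p=3))
for d in (1, 8, 64):
    print(f"d={d}: {fleet_size(d, t=8, p=12)} GPUs")
```

Each printed row corresponds to one node, so reading across a row follows the fast link and reading down a column follows the pipeline.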