Splitting a matmul, line by line
A linear layer is just Y = X W, with input
X \in \mathbb{R}^{n \times d} and weights
W \in \mathbb{R}^{d \times h}. Suppose W is
too big for one GPU. There are two natural ways to cut it: by columns or by rows.
Step 1 — column-partition the weights. Split W down
the middle into two column blocks, one per GPU:
W = \big[\, W_1 \;\big|\; W_2 \,\big], \qquad W_1, W_2 \in \mathbb{R}^{d \times h/2}.
GPU 1 holds W_1; GPU 2 holds W_2. Neither
stores the whole matrix, so each needs only half the memory.
Step 2 — each GPU computes its own half of the output. Matrix multiplication is
block-structured: multiplying by a block of columns produces exactly the matching block of output
columns. Give both GPUs the same input X and they compute, fully in
parallel,
Y = X W = X\big[\, W_1 \mid W_2 \,\big] = \big[\, X W_1 \;\big|\; X W_2 \,\big].
GPU 1 produces X W_1 (the left half of Y),
GPU 2 produces X W_2 (the right half). One big matmul has become two
independent half-width matmuls, running at the same time on two devices.
Step 3 — concatenate to recover the full output. Lay the two halves side by
side and you have the exact same Y a single GPU would have produced —
no approximation, just a redistribution of the work:
Y = \operatorname{concat}\big(X W_1,\; X W_2\big) \in \mathbb{R}^{n \times h}.
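Steps 1 to 3 can be checked numerically. The following is a minimal sketch in NumPy that simulates the two GPUs as two array slices on one machine; the sizes n, d, h are arbitrary illustrative values, and nothing here touches a real device.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, h = 4, 6, 8                 # illustrative sizes; h must be even for a 2-way split
X = rng.standard_normal((n, d))
W = rng.standard_normal((d, h))

# Step 1: column-partition W into two blocks of shape (d, h/2).
W1, W2 = np.split(W, 2, axis=1)

# Step 2: each simulated GPU multiplies the same X by its own block.
Y1 = X @ W1                       # left half of Y, shape (n, h/2)
Y2 = X @ W2                       # right half of Y, shape (n, h/2)

# Step 3: lay the halves side by side along the column axis.
Y_parallel = np.concatenate([Y1, Y2], axis=1)
Y_single = X @ W                  # what a single GPU would have computed

print(np.allclose(Y_parallel, Y_single))
```

The comparison uses `np.allclose` rather than exact equality because floating-point results can differ in the last bits when the same product is computed in a different blocking.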
Step 4 — feed a row-partitioned matmul to turn concat into all-reduce. In a
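transformer the column-split layer is followed by a second linear map
Z = Y B. (In a transformer MLP an elementwise activation sits between the two; it acts on each column block separately, so omitting it here does not change the argument.) Partition that weight matrix by rows, into blocks B_1 and B_2 of h/2 rows each, to
match the split halves of Y: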
B = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}, \qquad Z = Y B = \big[\, X W_1 \mid X W_2 \,\big]\begin{bmatrix} B_1 \\ B_2 \end{bmatrix} = (X W_1) B_1 + (X W_2) B_2.
Now each GPU already owns one term of that sum: GPU 1 computes
(X W_1) B_1, GPU 2 computes (X W_2) B_2.
We never need to concatenate the intermediate Y at all.
Step 5 — combine the partial sums with one all-reduce. The full output is the
sum of the two partial products, each held on a different device, so a single
all-reduce — sum-across-GPUs — produces the final answer on every GPU:
Z = \underbrace{(X W_1) B_1}_{\text{GPU 1}} + \underbrace{(X W_2) B_2}_{\text{GPU 2}} \;\xrightarrow{\;\text{all-reduce}\;}\; Z \text{ on every GPU.}
One enormous matmul has become several smaller ones spread across devices, paid for by exactly
one communication step. That is the whole bargain of tensor parallelism: more arithmetic in
parallel, in exchange for a little cross-GPU traffic.
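Steps 4 and 5 can be sketched the same way. In this NumPy simulation the all-reduce is replaced by a plain sum whose result is copied to every simulated device; the helper `all_reduce_sum` and the output width `m` of B are hypothetical names introduced only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, h, m = 4, 6, 8, 5           # m: output width of B (hypothetical name)
X = rng.standard_normal((n, d))
W = rng.standard_normal((d, h))
B = rng.standard_normal((h, m))

W1, W2 = np.split(W, 2, axis=1)   # column blocks of W, each (d, h/2)
B1, B2 = np.split(B, 2, axis=0)   # row blocks of B, each (h/2, m)

# Each simulated GPU chains its two local matmuls with no communication.
partial_1 = (X @ W1) @ B1
partial_2 = (X @ W2) @ B2

def all_reduce_sum(partials):
    """Simulated all-reduce: every device ends up holding the full sum."""
    total = np.sum(partials, axis=0)
    return [total.copy() for _ in partials]

Z_on_gpu = all_reduce_sum([partial_1, partial_2])
Z_single = (X @ W) @ B            # single-GPU reference

print(all(np.allclose(Z, Z_single) for Z in Z_on_gpu))
```

Neither partial product equals Z on its own; only the sum does, which is why the single communication step cannot be skipped.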
A single linear layer can be evaluated across t GPUs:
- Split the weight matrix. Partition W \in \mathbb{R}^{d \times h} across GPUs by columns, W = [W_1 \mid \dots \mid W_t], so each GPU stores only h/t columns.
- Compute partial products in parallel. With shared input X, GPU i computes X W_i, its slice of the output, independently of the others.
- Combine with all-reduce (or concat). If the layer stands alone, concatenating the slices recovers Y. A following row-partitioned matmul instead turns the output into a sum of per-GPU partial products, Z = \sum_i (X W_i) B_i, recovered exactly on every device by a single all-reduce. Either way, the result is identical to the single-GPU computation.
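The three-step recipe generalizes directly from 2 to t devices. Below is a sketch of that generalization in NumPy, again simulating devices as list entries; the function name `tensor_parallel_pair` is hypothetical, and the code assumes t divides h evenly.

```python
import numpy as np

def tensor_parallel_pair(X, W, B, t):
    """Evaluate Z = (X W) B with W column-split and B row-split over t simulated devices."""
    h = W.shape[1]
    if h % t != 0:
        raise ValueError("this sketch assumes t divides h")
    W_shards = np.split(W, t, axis=1)   # each shard has shape (d, h/t)
    B_shards = np.split(B, t, axis=0)   # each shard has shape (h/t, m)
    # Local work: device i only ever touches W_i and B_i.
    partials = [(X @ Wi) @ Bi for Wi, Bi in zip(W_shards, B_shards)]
    # Stand-in for the all-reduce: sum the per-device partial products.
    return sum(partials), W_shards

rng = np.random.default_rng(2)
X = rng.standard_normal((3, 12))
W = rng.standard_normal((12, 24))
B = rng.standard_normal((24, 12))

for t in (1, 2, 4, 8):
    Z, W_shards = tensor_parallel_pair(X, W, B, t)
    assert np.allclose(Z, (X @ W) @ B)
    print(t, W_shards[0].shape)         # per-device shard of W shrinks as t grows
```

Each device's shard of W holds h/t columns, so the per-device weight memory for this pair falls in proportion to t while the result stays the same.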
Multi-head attention is tailor-made for tensor parallelism, because the heads are already
independent. With H heads and t GPUs (where t divides H),
simply give each GPU a disjoint block of H/t heads:
\text{GPU } i \;\text{holds heads}\; \{(i-1)\tfrac{H}{t} + 1,\, \dots,\, i\tfrac{H}{t}\}.
Each GPU computes the scaled dot-product attention for its own heads, using its own slices of
the query, key, and value projections, entirely on its own. It then concatenates its heads'
outputs locally and multiplies them by its row block of the output projection, which is
row-partitioned over the head dimension; the resulting partial products are summed by the same
single all-reduce. Because each head's Q, K, V
projections are just column-partitioned weight matrices, this is precisely the
column-then-row pattern above, applied to the attention block.
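The head-wise split can be sketched in NumPy as follows. This is an illustration under simplifying assumptions, not a production attention kernel: there is no batching, masking, or bias; the names `heads_attention`, `d_head`, and `Wo` are hypothetical; devices are simulated by a loop; and indices are 0-based, so device i here corresponds to GPU i+1 in the formula above.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)   # subtract the row max for stability
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def heads_attention(X, Wq, Wk, Wv, d_head):
    """Scaled dot-product attention for every head packed into the given weight columns."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    outs = []
    for s in range(0, Q.shape[1], d_head):  # one iteration per head
        q, k, v = Q[:, s:s + d_head], K[:, s:s + d_head], V[:, s:s + d_head]
        outs.append(softmax(q @ k.T / np.sqrt(d_head)) @ v)
    return np.concatenate(outs, axis=1)

rng = np.random.default_rng(3)
n, H, d_head, t = 5, 8, 4, 4                # illustrative sizes; t divides H
d_model = H * d_head
X = rng.standard_normal((n, d_model))
Wq, Wk, Wv, Wo = (rng.standard_normal((d_model, d_model)) for _ in range(4))

# Single-device reference: all heads, then the full output projection.
Z_single = heads_attention(X, Wq, Wk, Wv, d_head) @ Wo

# Tensor-parallel version: device i owns H/t heads, i.e. a block of columns
# of Wq, Wk, Wv and the matching block of rows of Wo.
cols = (H // t) * d_head
partials = []
for i in range(t):
    sl = slice(i * cols, (i + 1) * cols)
    local = heads_attention(X, Wq[:, sl], Wk[:, sl], Wv[:, sl], d_head)
    partials.append(local @ Wo[sl, :])      # partial product on device i
Z_parallel = sum(partials)                  # stand-in for the all-reduce

print(np.allclose(Z_parallel, Z_single))
```

The softmax never crosses a head boundary, which is what lets each device finish its heads without seeing any other device's queries, keys, or values.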
That all-reduce runs once per column-then-row block (in a transformer layer, once for attention
and once for the MLP) on the forward pass and again on the backward pass, so it happens many
times per step, on large activation tensors. So tensor parallelism is bandwidth-hungry:
it only pays off when the GPUs are joined by a very fast interconnect (e.g. NVLink within a
single server). Spread a tensor-parallel group across a slow network and the communication
swamps the saved computation. This is why, in practice, tensor parallelism is kept
within a node and the coarser pipeline parallelism is used to span nodes.
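A back-of-the-envelope estimate makes the bandwidth sensitivity concrete. The sketch below assumes a ring all-reduce, in which each device sends 2(t-1)/t times the message size; a real collective library may choose a different algorithm. It counts bandwidth only, ignoring latency and any overlap with compute. The workload shape and both link speeds are hypothetical placeholder values, not measurements of any particular hardware.

```python
def ring_all_reduce_bytes(message_bytes, t):
    """Bytes each device sends in a ring all-reduce of one message across t devices."""
    return 2 * (t - 1) / t * message_bytes

def all_reduce_seconds(message_bytes, t, bytes_per_second):
    """Bandwidth-only estimate; ignores latency and overlap with compute."""
    return ring_all_reduce_bytes(message_bytes, t) / bytes_per_second

# Hypothetical workload: one activation tensor of shape (tokens, hidden) in 16-bit floats.
tokens = 8 * 2048
hidden = 4096
bytes_per_elem = 2
message = tokens * hidden * bytes_per_elem

t = 8
fast_link = 300e9   # placeholder bytes/s for a fast intra-node interconnect
slow_link = 10e9    # placeholder bytes/s for a slower cross-node network

for name, bw in (("fast", fast_link), ("slow", slow_link)):
    print(name, all_reduce_seconds(message, t, bw), "seconds per all-reduce")
```

Under these assumptions the time per all-reduce scales inversely with link bandwidth, and that cost is paid at every all-reduce in every layer on both passes.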