Causal Masking

A language model has one job: given the tokens so far, predict the next one. That job comes with a non-negotiable rule of fair play: when the model makes its prediction at position i, it must not have peeked at positions i+1, i+2, \dots. Those are the future; using them would be reading the answer off the back of the exam. Yet plain self-attention lets every position attend to every other, including the ones to its right. Causal masking (also called autoregressive masking) is the small surgical change that forbids the look-ahead, and it costs almost nothing: a single elementwise addition to the scores before the softmax.

Blocking the future, line by line

Recall the scaled dot-product attention for a sequence of length n: from queries Q and keys K we form a score matrix and softmax each row into attention weights. We intervene on one line — the scores, just before the softmax.

Step 1 — start from the raw scores. Entry S_{ij} is how much query position i wants to attend to key position j:

S_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}}.

Position i reading position j with j > i means a present token consulting a future token — exactly what we must stop.
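A minimal NumPy sketch of Step 1, assuming queries and keys are stored row by row in arrays `Q` and `K` of shape `(n, d_k)`. The function name `attention_scores` is illustrative, not taken from any library, and NumPy indices run from 0 where the text counts from 1.

```python
import numpy as np

def attention_scores(Q: np.ndarray, K: np.ndarray) -> np.ndarray:
    """Raw scores S[i, j] = q_i . k_j / sqrt(d_k) for every query/key pair."""
    d_k = Q.shape[-1]
    # Row i of Q times column j of K.T is the dot product q_i . k_j.
    return Q @ K.T / np.sqrt(d_k)

rng = np.random.default_rng(0)
n, d_k = 4, 8
Q = rng.standard_normal((n, d_k))  # one query vector per position
K = rng.standard_normal((n, d_k))  # one key vector per position
S = attention_scores(Q, K)
print(S.shape)  # (4, 4): rows are queries i, columns are keys j
```

Nothing is masked yet: `S[1, 2]`, a query reading a later key, is an ordinary finite number.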

Step 2 — build the causal mask. Define an n \times n matrix M that is -\infty strictly above the diagonal (the future) and 0 on and below it (the present and past):

M_{ij} = \begin{cases} 0 & j \le i \quad(\text{present or past}) \\ -\infty & j > i \quad(\text{future}). \end{cases}
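The mask can be built with one NumPy call. In this sketch `causal_mask` is a hypothetical helper name; `np.triu(..., k=1)` keeps only the entries strictly above the diagonal (j > i) and sets everything on or below it to 0.

```python
import numpy as np

def causal_mask(n: int) -> np.ndarray:
    """Hypothetical helper: 0 on and below the diagonal, -inf strictly above."""
    # np.triu(..., k=1) keeps entries with column > row (the future)
    # and replaces everything on or below the diagonal with 0.
    return np.triu(np.full((n, n), -np.inf), k=1)

M = causal_mask(4)
print(M)  # first row: one 0 then three -inf; last row: all zeros
```

Some codebases use a large finite negative constant instead of `-inf`. The masked weights then come out tiny rather than exactly zero, so the exactness argument of Step 4 holds only approximately.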

Step 3 — add the mask to the scores. Adding 0 leaves the allowed entries untouched; adding -\infty drives every future entry to -\infty:

(S + M)_{ij} = \begin{cases} S_{ij} & j \le i \\ -\infty & j > i. \end{cases}
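To check Step 3 numerically, the sketch below adds the mask to stand-in random scores and compares the result with an explicit selection via `np.where`. The random `S` is only a placeholder for the scores of Step 1.

```python
import numpy as np

n = 4
rng = np.random.default_rng(1)
S = rng.standard_normal((n, n))             # placeholder for the Step 1 scores
M = np.triu(np.full((n, n), -np.inf), k=1)  # causal mask from Step 2

masked = S + M  # Step 3: finite + 0 stays put, finite + (-inf) is -inf

# The same result written as an explicit selection on the indices.
i, j = np.indices((n, n))
selected = np.where(j <= i, S, -np.inf)
print(np.array_equal(masked, selected))  # True
```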

Step 4 — softmax kills the future exactly. Apply the softmax along each row. Because e^{-\infty} = 0, every masked entry contributes nothing to the normaliser and receives weight exactly zero — not small, not approximately, but 0:

A_{ij} = \operatorname{softmax}_j (S + M)_{ij} = \frac{e^{\,S_{ij}}\,\mathbf{1}[\,j \le i\,]}{\sum_{\ell \le i} e^{\,S_{i\ell}}}, \qquad A_{ij} = 0 \ \text{ for } j > i.
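A row-wise softmax that tolerates `-inf` entries, sketched in NumPy under the assumption that the mask has already been added to the scores. The example row is a hypothetical 3-token case seen from position 2 (index 1 in code). Its future score is the largest of the three, yet it still receives weight exactly 0, while the two allowed weights are e^1/(e^1+e^2) and e^2/(e^1+e^2).

```python
import numpy as np

def masked_softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax of scores that may contain -inf entries."""
    # Shift each row by its max to avoid overflow in exp. Under a causal mask
    # the diagonal is never masked, so every row max is finite.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)  # exp(-inf) evaluates to exactly 0.0
    return weights / weights.sum(axis=-1, keepdims=True)

# Hypothetical scores for position 2 of a 3-token sequence (index 1 in code):
# the future key has the largest raw score, 5.0, but is masked.
S_row = np.array([[1.0, 2.0, 5.0]])
M_row = np.array([[0.0, 0.0, -np.inf]])
print(masked_softmax(S_row + M_row))  # approximately [[0.269, 0.731, 0.0]]
```

Subtracting the row maximum is the usual overflow guard, and it is safe here only because the diagonal is never masked. A row masked in every column (possible when a causal mask is combined with a padding mask) would have a maximum of `-inf` and produce `NaN`.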

Step 5 — read off the shape. Every weight above the diagonal is zero, so the attention matrix A is lower-triangular. Each row i is a clean probability distribution over only positions 1, \dots, i:

\sum_{j=1}^{i} A_{ij} = 1, \qquad A_{ij} = 0 \text{ for all } j > i.

Position i attends to itself and everything before it, and to nothing after. The future is sealed off.
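Putting Steps 1 to 4 together, this sketch computes causal attention weights for random inputs and checks the two properties of Step 5. `causal_attention_weights` is an illustrative name, not a library function.

```python
import numpy as np

def causal_attention_weights(Q: np.ndarray, K: np.ndarray) -> np.ndarray:
    """Steps 1-4: scaled scores, causal mask, then a stable row-wise softmax."""
    n, d_k = Q.shape
    S = Q @ K.T / np.sqrt(d_k)                  # Step 1: raw scores
    M = np.triu(np.full((n, n), -np.inf), k=1)  # Step 2: causal mask
    Z = S + M                                   # Step 3: future -> -inf
    Z = Z - Z.max(axis=-1, keepdims=True)       # overflow guard (diagonal is finite)
    A = np.exp(Z)                               # Step 4: exp(-inf) == 0
    return A / A.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(2)
Q = rng.standard_normal((5, 16))
K = rng.standard_normal((5, 16))
A = causal_attention_weights(Q, K)

print(np.allclose(A.sum(axis=1), 1.0))  # True: every row is a distribution
print(np.array_equal(A, np.tril(A)))    # True: nothing above the diagonal
print(A[0, 0])                          # 1.0: the first position sees only itself
```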

Why this is the whole point of parallel training

Here is the payoff that made the transformer a training rocket. A naïve alternative would run the model once per prefix: feed tokens 1, \dots, i, score the guess for token i+1, and repeat for every i, one slow step at a time. With the mask in place, position i's prediction already depends only on tokens 1, \dots, i, so we can feed the model the entire training sequence at once and compute all n next-token predictions in a single forward pass. The mask guarantees that prediction i never cheats by seeing token i+1, even though that token is sitting right there in the batch. Left-to-right structure preserved; full GPU parallelism unlocked.
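The claim that one masked pass equals many prefix passes can be tested directly. This sketch assumes single-head self-attention with hypothetical projection matrices `Wq`, `Wk`, `Wv` (random here, no training involved) and compares the full-sequence output at each position with a separate run on just that prefix.

```python
import numpy as np

def causal_self_attention(X, Wq, Wk, Wv):
    """Single-head causal self-attention over a sequence X of shape (n, d)."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    n, d_k = Q.shape
    S = Q @ K.T / np.sqrt(d_k) + np.triu(np.full((n, n), -np.inf), k=1)
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)
    return A @ V  # row i mixes only the value vectors of positions 0..i

rng = np.random.default_rng(3)
n, d = 6, 8
X = rng.standard_normal((n, d))  # stand-in token embeddings
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))  # hypothetical weights

# One parallel pass over the whole sequence...
full = causal_self_attention(X, Wq, Wk, Wv)

# ...agrees with n separate passes, each on a prefix that contains no future.
for i in range(n):
    prefix_out = causal_self_attention(X[: i + 1], Wq, Wk, Wv)
    print(i, np.allclose(full[i], prefix_out[i]))
```

The single pass does the work of all n prefix runs at once; the loop exists only to confirm that nothing from the future leaked in.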

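Adding the mask M (with M_{ij} = -\infty for j > i and 0 otherwise) to the attention scores before the softmax has three consequences: every future position receives attention weight exactly 0; each row of A remains a probability distribution over positions 1, \dots, i, which makes A lower-triangular; and all n next-token predictions can be trained in a single parallel forward pass without leaking the answer.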

The mask is a design knob, and which way you set it decides what kind of model you are building.

No mask (bidirectional). Every position attends to every other, future included. The model sees the whole context at once, which is ideal for understanding a fixed input — filling in a blanked-out word needs the words on both sides. This is the encoder reading of a sentence. What it cannot do is generate: if position i may look at i+1, "predict the next token" is meaningless because the answer is already visible.

Causal mask (left-to-right). Each position sees only itself and the past, so "predict the next token" is a genuine, leak-free task at every position simultaneously. This is the decoder reading, and the foundation of every generative language model. You trade away two-sided context for the ability to write the future one token at a time — exactly the trade a generator wants.
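The difference between the two settings shows up as a leak test. In this sketch, `self_attention` and its `causal` flag are illustrative names, and the projections are taken to be the identity (Q = K = V = X), an assumption made only to keep it short. Only the last token is altered, and we check whether the outputs at earlier positions move.

```python
import numpy as np

def self_attention(X: np.ndarray, causal: bool) -> np.ndarray:
    """Single-head self-attention with identity projections (Q = K = V = X)."""
    n, d = X.shape
    S = X @ X.T / np.sqrt(d)
    if causal:
        S = S + np.triu(np.full((n, n), -np.inf), k=1)  # block j > i
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)
    return A @ X

rng = np.random.default_rng(4)
X = rng.standard_normal((5, 4))
X_changed = X.copy()
X_changed[-1] += 10.0  # alter only the last token, the "future" for the others

for causal in (False, True):
    before = self_attention(X, causal)[:-1]  # outputs at positions 0..3
    after = self_attention(X_changed, causal)[:-1]
    print(f"causal={causal}: earlier outputs unchanged? {np.allclose(before, after)}")
```

Without the mask the earlier outputs shift, because every position attended to the altered token; with the mask they stay unchanged, which is exactly the property that makes next-token prediction a fair task.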

The mask, drawn

Each cell is one attention weight A_{ij}: row i is the query, column j the key. The active (coloured) cells are the lower triangle j \le i, the present and past. The greyed-out upper triangle j > i is the future, masked to -\infty and so to weight exactly 0. Row i reaches exactly the first i columns, so each row has one more allowed cell than the row above it.
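For a concrete case, take n = 4. The mask from Step 2 and the weight pattern it produces are written out below. Row 1 has a single allowed cell, so A_{11} = 1; in general the allowed cells number 1 + 2 + \dots + n = n(n+1)/2, which is 10 here.

```latex
M = \begin{pmatrix}
0 & -\infty & -\infty & -\infty \\
0 & 0 & -\infty & -\infty \\
0 & 0 & 0 & -\infty \\
0 & 0 & 0 & 0
\end{pmatrix},
\qquad
A = \begin{pmatrix}
1 & 0 & 0 & 0 \\
A_{21} & A_{22} & 0 & 0 \\
A_{31} & A_{32} & A_{33} & 0 \\
A_{41} & A_{42} & A_{43} & A_{44}
\end{pmatrix}
```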