Multi-Head Attention: The Code
Previously Defined
- n_head = 4, head_dim = 4
- Each head attends independently over a 4-dimensional slice
- Outputs concatenated back to 16 dimensions
Layer-prefixed keys
Before we look at the attention code, one thing will look different: the parameter names. In 3.4, the state dict had flat keys like 'attn_wq'. Now they're prefixed with a layer index, e.g. 'layer0.attn_wq'.
This is because the attention+MLP block now lives inside a layer loop (for li in range(n_layer)), and each layer has its own set of weight matrices. We'll cover the layer loop fully in 4.4 — for now, just read the layer{li}. prefix as "this layer's."
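As a quick illustration of the naming change (the n_layer value here is just an assumed toy setting), one flat key becomes one key per layer:

```python
# Hypothetical illustration of the naming change (n_layer = 2 is an assumed toy value).
n_layer = 2

flat_key = 'attn_wq'  # 3.4 style: a single attention block, flat names
layer_keys = [f'layer{li}.attn_wq' for li in range(n_layer)]
print(layer_keys)  # ['layer0.attn_wq', 'layer1.attn_wq']
```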
Setup: residual, norm, Q/K/V
The start of the attention block is familiar from 3.4:
x_residual = x
x = rmsnorm(x)
q = linear(x, state_dict[f'layer{li}.attn_wq'])
k = linear(x, state_dict[f'layer{li}.attn_wk'])
v = linear(x, state_dict[f'layer{li}.attn_wv'])
keys[li].append(k)
values[li].append(v)
Save the residual, normalize, project Q/K/V, and cache K and V. The only changes are the layer-prefixed keys and keys[li] instead of keys (each layer has its own cache — more on that in 4.5).
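To make this fragment runnable on its own, here is a minimal sketch. The rmsnorm and linear helpers are assumed stand-ins for the 3.4 definitions (linear computes one output per weight row), and the weights are arbitrary toy values:

```python
import math

# Assumed stand-ins for the 3.4 helpers: pure Python, lists of floats.
def rmsnorm(x, eps=1e-5):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def linear(x, w):
    # one output per row of w
    return [sum(wi[j] * x[j] for j in range(len(x))) for wi in w]

n_embd, li = 16, 0
toy_w = [[0.01 * ((i + j) % 5) for j in range(n_embd)] for i in range(n_embd)]
state_dict = {f'layer{li}.attn_w{n}': toy_w for n in 'qkv'}  # toy weights, same matrix reused
keys, values = {li: []}, {li: []}

# The setup step, exactly as in the text:
x = [0.1 * j for j in range(n_embd)]
x_residual = x
x = rmsnorm(x)
q = linear(x, state_dict[f'layer{li}.attn_wq'])
k = linear(x, state_dict[f'layer{li}.attn_wk'])
v = linear(x, state_dict[f'layer{li}.attn_wv'])
keys[li].append(k)
values[li].append(v)
```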
The head loop
This is the new part. Instead of computing attention over all 16 dimensions at once, we loop over four heads. Step through to see how each head slices its portion and how x_attn builds up:
x_attn = []
for h in range(n_head):
hs = h * head_dim
q_h = q[hs:hs+head_dim]
k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
v_h = [vi[hs:hs+head_dim] for vi in values[li]]
attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
attn_weights = softmax(attn_logits)
head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
x_attn.extend(head_out)
Each head slices its portion of Q, K, and V:
| hs = h * head_dim | → | Start index for this head’s slice (0, 4, 8, 12) |
| q_h = q[hs:hs+head_dim] | → | This head’s 4-dim query (sliced from full 16-dim) |
| k_h = [ki[hs:hs+head_dim] for ki in keys[li]] | → | This head’s 4-dim keys from all cached positions |
| v_h = [vi[hs:hs+head_dim] for vi in values[li]] | → | This head’s 4-dim values from all cached positions |
Q, K, V are still computed as full 16-dimensional projections — the same weight matrices as 3.4. The splitting into heads happens by slicing: head h takes dimensions h*4 through h*4 + 3. There are no separate per-head weight matrices.
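One way to convince yourself there are no hidden per-head weights: slicing the full projection's output is equivalent to projecting with only that head's rows of the weight matrix. A sketch, assuming the pure-Python linear from earlier chapters (one output per weight row) and arbitrary toy values:

```python
head_dim, n_embd = 4, 16

def linear(x, w):
    # assumed 3.4-style helper: one output per row of w
    return [sum(wi[j] * x[j] for j in range(len(x))) for wi in w]

x = [0.1 * j for j in range(n_embd)]
wq = [[0.01 * ((i * 3 + j) % 7) for j in range(n_embd)] for i in range(n_embd)]

h = 2
hs = h * head_dim
q_sliced = linear(x, wq)[hs:hs + head_dim]  # slice the full 16-dim projection
q_direct = linear(x, wq[hs:hs + head_dim])  # project with just head 2's rows
# the two are identical: head h's "weight matrix" is simply rows 8..11 of wq
```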
Then each head runs the same attention computation from 3.5 and 3.6, just on 4 dimensions instead of 16. Here’s what happens inside one head:
| attn_logits = [... / head_dim**0.5 ...] | → | Attention scores (scaled by √4, not √16) |
| attn_weights = softmax(attn_logits) | → | Normalize to probabilities |
| head_out = [...] | → | Weighted sum of values (4-dim output) |
| x_attn.extend(head_out) | → | Append this head’s output to the accumulator |
After all four heads, x_attn is a 16-dimensional vector (4 heads × 4 dims).
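The head loop above can be run end-to-end with toy data. The softmax helper and the cached K/V contents here are assumptions for the demo; the loop body matches the code in the text:

```python
import math

# Assumed softmax helper (numerically stable variant).
def softmax(xs):
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    s = sum(exps)
    return [e / s for e in exps]

n_head, head_dim = 4, 4
li = 0
q = [0.1 * j for j in range(16)]            # toy 16-dim query
keys = {li: [[0.05 * j for j in range(16)],  # toy cache: two positions
             [0.02 * j for j in range(16)]]}
values = {li: [[0.03 * j for j in range(16)],
               [0.04 * j for j in range(16)]]}

x_attn = []
for h in range(n_head):
    hs = h * head_dim
    q_h = q[hs:hs + head_dim]
    k_h = [ki[hs:hs + head_dim] for ki in keys[li]]
    v_h = [vi[hs:hs + head_dim] for vi in values[li]]
    attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5
                   for t in range(len(k_h))]
    attn_weights = softmax(attn_logits)
    head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h)))
                for j in range(head_dim)]
    x_attn.extend(head_out)

print(len(x_attn))  # 16: four 4-dim head outputs, concatenated
```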
Output projection and residual
The last two lines are the same as 3.6, apart from the layer-prefixed key:
x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
x = [a + b for a, b in zip(x, x_residual)]
The output projection remixes the concatenated head outputs, and the residual connection adds the result back to the input.
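A sketch of just these two lines in isolation. The identity wo is an assumption, purely so the effect of the residual is easy to see (the real attn_wo is a learned 16×16 matrix):

```python
n_embd = 16

def linear(x, w):
    # assumed 3.4-style helper: one output per row of w
    return [sum(wi[j] * x[j] for j in range(len(x))) for wi in w]

wo = [[1.0 if i == j else 0.0 for j in range(n_embd)] for i in range(n_embd)]  # identity, demo only
x_attn = [0.1] * n_embd      # pretend concatenated head outputs
x_residual = [1.0] * n_embd  # pretend saved input

x = linear(x_attn, wo)
x = [a + b for a, b in zip(x, x_residual)]
# each element is 0.1 (projected attention output) + 1.0 (residual) = 1.1
```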
What changed from Step 3
The single attention computation over all 16 dimensions becomes a loop of four computations over 4 dimensions each. The computation inside each head is identical to Step 3’s, just narrower (head_dim instead of n_embd).