MicroGPT Visualized

Building a GPT from scratch — an interactive visual guide

← 3.6 Attention Output and Residuals 3.8 Training and Results →
Step 3: Attention › 3.7

The KV Cache

Previously Defined

  • keys, values — lists that grow with each token (from 3.5)
  • Attention computes Q·K → weights → weighted sum of V
  • Residual connections around attention and MLP blocks (from 3.6)

You may have noticed that keys and values are passed into the gpt() function and grow with each call:

    keys.append(k)
    values.append(v)

This is the KV cache: a running accumulation of every token's key and value vectors. When the model processes position 5, the cache already contains keys and values from positions 0–4. The call appends position 5's own key and value, and its query then attends to all six entries without recomputing any of the earlier ones.
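As a minimal sketch of this mechanism, the code below runs single-head attention one position at a time over plain Python floats. The helper names (`dot`, `softmax`, `attend_with_cache`) and the toy vectors are assumptions for illustration, not microgpt's code: the real gpt() derives q, k and v from learned projections and works with Value objects. The 1/√d scaling of the scores follows standard scaled dot-product attention.

```python
import math


def dot(a, b):
    # Dot product of two equal-length lists of floats
    return sum(x * y for x, y in zip(a, b))


def softmax(xs):
    # Numerically stable softmax over a list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend_with_cache(q, k, v, keys, values):
    # Hypothetical helper: one attention step for the newest position.
    # Cache this position's key and value first, so the query also sees itself.
    keys.append(k)
    values.append(v)
    # Score the query against every cached key (scaled dot product)
    scale = 1 / math.sqrt(len(q))
    weights = softmax([dot(q, key) * scale for key in keys])
    # Output = attention-weighted sum of the cached value vectors
    return [sum(w * val[i] for w, val in zip(weights, values)) for i in range(len(v))]


# Made-up (q, k, v) vectors for three positions; a real model would
# compute these from each token's hidden state with learned projections.
toy_qkv = [
    ([1.0, 0.0], [1.0, 0.0], [1.0, 2.0]),
    ([0.0, 1.0], [0.0, 1.0], [3.0, 4.0]),
    ([1.0, 1.0], [1.0, 1.0], [5.0, 6.0]),
]

keys, values = [], []
outputs = []
for pos_id, (q, k, v) in enumerate(toy_qkv):
    out = attend_with_cache(q, k, v, keys, values)
    outputs.append(out)
    print(f"pos {pos_id}: cache size {len(keys)}, output {[round(x, 3) for x in out]}")
```

Each iteration appends one entry and attends over everything cached so far, so the printed cache size climbs 1, 2, 3 while earlier keys and values are never recomputed.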

During training

The training loop initializes empty lists and feeds tokens one at a time:

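    n = min(block_size, len(tokens) - 1)  # stay in the context window; keep tokens[pos_id + 1] valid
    keys, values = [], []  # fresh, empty KV cache for this sequence
    losses = []
    for pos_id in range(n):
        token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
        logits = gpt(token_id, pos_id, keys, values)  # also appends this position's k and v
        probs = softmax(logits)
        loss_t = -probs[target_id].log()  # negative log-likelihood of the true next token
        losses.append(loss_t)
    loss = (1 / n) * sum(losses)  # average loss over the sequence
    loss.backward()  # gradients also flow back through the cached keys and values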

Each call to gpt() appends one key and one value to the cache before attention runs. So when position 3 is processed, its query sees keys and values from positions 0, 1, 2, and 3, including its own.
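To trace the cache through the training loop without the rest of the model, a hypothetical `stub_gpt` can stand in for gpt() and keep only the bookkeeping. The token ids are made up, and `n = len(tokens) - 1` is an assumption chosen so that `tokens[pos_id + 1]` stays in range.

```python
def stub_gpt(token_id, pos_id, keys, values):
    # Hypothetical stand-in for gpt(): it performs only the cache
    # bookkeeping. Tuples take the place of real key/value vectors.
    keys.append(("k", pos_id))
    values.append(("v", pos_id))
    # A real gpt() would now attend over all of keys/values and return logits;
    # here we just return how many positions the query can see.
    return len(keys)


tokens = [0, 5, 12, 7, 0]  # made-up token ids
n = len(tokens) - 1        # one (input, target) pair per position
keys, values = [], []
visible = []
for pos_id in range(n):
    token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
    visible.append(stub_gpt(token_id, pos_id, keys, values))
    print(f"pos {pos_id}: cache holds positions {[p for _, p in keys]}")
```

The last printed line shows position 3 seeing positions 0, 1, 2 and 3, its own entry included.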

An important detail: the cached keys and values are live Value objects in the computation graph, so gradients flow through them during loss.backward(). This is unusual. Production training code usually processes the whole sequence in parallel with a causal mask and needs no cache at all; a KV cache typically appears only at inference time, where it holds detached tensors with no gradient history. Because microgpt processes one token at a time even during training, its cache is part of the graph.
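Production training code typically skips the cache by computing all positions at once and hiding future keys with a causal mask. The sketch below checks, on random vectors, that this gives the same attention outputs as microgpt-style token-by-token processing with a growing cache. Both functions (`attention_all_at_once`, `attention_with_cache`) are hypothetical, single-head, and written with loops for clarity; real implementations express the masked version as matrix operations.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(xs):
    # Stable softmax; entries of -inf get weight exactly 0.0
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(weights, vecs):
    return [sum(w * vec[i] for w, vec in zip(weights, vecs)) for i in range(len(vecs[0]))]


def attention_all_at_once(qs, ks, vs):
    # Every position scores against the full sequence, but keys after
    # the query's own position are masked out with -inf (causal mask).
    scale = 1 / math.sqrt(len(qs[0]))
    outs = []
    for t, q in enumerate(qs):
        scores = [dot(q, k) * scale if j <= t else -math.inf for j, k in enumerate(ks)]
        outs.append(weighted_sum(softmax(scores), vs))
    return outs


def attention_with_cache(qs, ks, vs):
    # Token by token: append each key/value, then attend over the cache only.
    scale = 1 / math.sqrt(len(qs[0]))
    keys, values, outs = [], [], []
    for q, k, v in zip(qs, ks, vs):
        keys.append(k)
        values.append(v)
        scores = [dot(q, key) * scale for key in keys]
        outs.append(weighted_sum(softmax(scores), values))
    return outs


def rand_vecs(T, d):
    return [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]


random.seed(0)
T, d = 6, 4  # sequence length and head size (arbitrary)
qs, ks, vs = rand_vecs(T, d), rand_vecs(T, d), rand_vecs(T, d)

full = attention_all_at_once(qs, ks, vs)
cached = attention_with_cache(qs, ks, vs)
print(all(math.isclose(a, b, abs_tol=1e-12) for fo, co in zip(full, cached) for a, b in zip(fo, co)))
```

It should print True: masked entries receive a softmax weight of exactly 0, so each position ends up with the same weights as in the cached version. The two approaches differ in how the work is organized, not in the result.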

During inference

Same pattern — the cache builds up as we generate:

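    keys, values = [], []  # fresh, empty cache for each sample
    token_id = BOS  # generation starts from the BOS token
    sample = []
    for pos_id in range(block_size):  # never run past the context window
        logits = gpt(token_id, pos_id, keys, values)  # also appends this position's k and v
        probs = softmax([l / temperature for l in logits])  # temperature-scaled distribution
        token_id = random.choices(range(vocab_size), weights=[p.data for p in probs])[0]
        if token_id == BOS:  # sampling BOS ends the sample
            break
        sample.append(uchars[token_id])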

Each generated token’s key and value get cached, so the next token can attend to the full history.

Why it matters

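Without the cache, generating position 5 would require recomputing keys and values for positions 0–4 as well, and position 6 would redo 0–5. Each new token would mean re-running the model on the entire sequence so far, which adds up to 1 + 2 + ... + n = n(n+1)/2 per-token forward computations, O(n²) in total, for n tokens.

With the cache, each position computes just its own key and value and reads the rest from the cache, so the embedding, projections, and MLP run once per token: O(n) in total. Attention itself still compares each new query with every cached key, so its cost grows with the position and stays O(n²) over the whole sequence (versus O(n³) without the cache). The cache removes the redundant recomputation; it does not make attention linear.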
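The scaling argument can be made concrete by counting operations. This sketch uses a simplified cost model, which is an assumption: it counts per-token forward computations (embedding, projections, MLP) and query-key dot products for a single head in a single layer, and ignores constant factors such as the number of layers, heads and vector widths. Both function names are hypothetical.

```python
def per_token_forwards(n, use_cache):
    # Embedding + Q/K/V projections + MLP, counted once per (position, step)
    if use_cache:
        return n  # each position is processed exactly once
    # Without a cache, generating token t re-runs positions 0..t
    return sum(t + 1 for t in range(n))


def query_key_products(n, use_cache):
    # Query-key dot products for one attention head in one layer
    if use_cache:
        # Step t: one new query against t + 1 cached keys
        return sum(t + 1 for t in range(n))
    # Step t: all t + 1 queries recomputed, and query j sees j + 1 keys
    return sum(sum(j + 1 for j in range(t + 1)) for t in range(n))


for n in (10, 100, 1000):
    print(
        f"n={n:5d}  forwards: {per_token_forwards(n, True):>7} vs {per_token_forwards(n, False):>7}  "
        f"dot products: {query_key_products(n, True):>7} vs {query_key_products(n, False):>10}"
    )
```

With the cache, the per-token work grows as n and the dot products as n(n+1)/2; without it, they grow as n(n+1)/2 and n(n+1)(n+2)/6. For n = 10 that is 10 vs 55 forward computations and 55 vs 220 dot products.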
