Training
So far:

- docs — shuffled names
- uchars — [a…z]
- BOS — token 26
- state_dict — 27×27 count table
- bigram(token) — prob of next token
Training this model is beautifully simple. Each step takes one name, tokenizes it, wraps it in BOS tokens, and does two things: measure how well the model currently predicts this name (the forward pass), then update the counts.
```python
num_steps = 1000
for step in range(num_steps):
    doc = docs[step % len(docs)]
    tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
    n = len(tokens) - 1
```
The % (modulo) operator wraps around: when step reaches the end of the list, it starts over from the beginning. With 1,000 training steps and 32,033 names, we’ll only see the first 1,000 — each exactly once. If we had fewer names than steps (say 100 names and 1,000 steps), we’d cycle through the full list 10 times. Each name gets tokenized by looking up each character’s index in uchars, then wrapped with BOS on both sides.
So if doc is “ava”, then tokens becomes [26, 0, 21, 0, 26] — BOS, a, v, a, BOS.

And n is 4 — the number of consecutive (token, next-token) pairs we’ll examine.
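The tokenization can be checked directly with a minimal sketch (uchars is rebuilt here so the snippet stands alone):

```python
# Rebuild the vocabulary: 'a'..'z' map to ids 0..25, BOS is id 26.
uchars = [chr(ord('a') + i) for i in range(26)]
BOS = 26

doc = "ava"
tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
print(tokens)  # [26, 0, 21, 0, 26]

n = len(tokens) - 1  # 4 consecutive (token, next-token) pairs
```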
The Forward Pass
Before we update the counts, we quiz the model: how well can you predict this name right now, before seeing it? This is the forward pass — it doesn’t change the model, it just measures how surprised the model is.
```python
losses = []
for pos_id in range(n):
    token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
    probs = bigram(token_id)
    loss_t = -math.log(probs[target_id])
    losses.append(loss_t)
loss = (1 / n) * sum(losses)
```
For each consecutive pair, we ask the model: “given this token, what probability do you assign to the actual next token?” The less confident the model is, the higher the loss. (More on what loss means in the next substep.)
Walking through “ava”:
| pos 0 | → | Given BOS (26), how likely is a (0)? Record the loss. |
| pos 1 | → | Given a (0), how likely is v (21)? Record the loss. |
| pos 2 | → | Given v (21), how likely is a (0)? Record the loss. |
| pos 3 | → | Given a (0), how likely is BOS (26)? Record the loss. |
The average of these four losses is the model’s loss for this name.
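These four losses are easy to check by hand with a self-contained sketch. The bigram stand-in below is an assumption — it normalizes one row of the count table with add-one smoothing so no probability is zero; with an empty table every row is uniform, so each loss is log(27):

```python
import math

VOCAB = 27  # 26 letters + BOS
BOS = 26
uchars = [chr(ord('a') + i) for i in range(26)]

# A fresh, empty count table; in the real model this is state_dict mid-training.
state_dict = [[0] * VOCAB for _ in range(VOCAB)]

def bigram(token_id):
    # Stand-in for the real helper: normalized counts with add-one smoothing
    # (the smoothing is an assumption made so an empty row is still valid).
    row = state_dict[token_id]
    total = sum(row) + VOCAB
    return [(c + 1) / total for c in row]

tokens = [BOS] + [uchars.index(ch) for ch in "ava"] + [BOS]
n = len(tokens) - 1

losses = []
for pos_id in range(n):
    token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
    probs = bigram(token_id)
    losses.append(-math.log(probs[target_id]))
loss = (1 / n) * sum(losses)
# Every prob is 1/27, so loss == log(27) ≈ 3.2958
```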
The Update
Then a separate loop incorporates what the model just saw:
```python
for pos_id in range(n):
    token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
    state_dict[token_id][target_id] += 1
```
For “ava”, this increments four cells in the count table:
| state_dict[26][0] += 1 | → | BOS (26) was followed by a (0) |
| state_dict[0][21] += 1 | → | a (0) was followed by v (21) |
| state_dict[21][0] += 1 | → | v (21) was followed by a (0) |
| state_dict[0][26] += 1 | → | a (0) was followed by BOS (26) |
That’s the entire learning algorithm. No gradients, no backpropagation, no optimizer. Just increment the count.
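Put together, the whole training loop is just counting. A self-contained sketch — the two-name docs list here is a hypothetical stand-in for the real 32,033 names:

```python
uchars = [chr(ord('a') + i) for i in range(26)]
BOS = 26
state_dict = [[0] * 27 for _ in range(27)]

docs = ["ava", "eva"]  # hypothetical stand-in for the real name list
num_steps = 4          # each name is seen twice
for step in range(num_steps):
    doc = docs[step % len(docs)]
    tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
    for pos_id in range(len(tokens) - 1):
        token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
        state_dict[token_id][target_id] += 1

# Only "ava" starts with 'a', and it was seen twice:
print(state_dict[BOS][0])  # 2
# Both names end with 'a', each seen twice:
print(state_dict[0][BOS])  # 4
```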
Try it
Watch the count table fill up as names are processed. Step through one at a time, or hit Play:
This works because the bigram model is simple enough to have a closed-form solution: the optimal parameters are just the normalized counts. Gradient descent — which we’ll introduce in Step 1 — is what you need when the model is too complex for exact solutions.
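The closed-form claim can be made concrete with a tiny hypothetical example: among all probability distributions, the normalized counts give the lowest average loss on the observed data — any other guess scores worse:

```python
import math

# Hypothetical observed next-token counts after BOS: 'a' three times, 'e' once.
counts = {"a": 3, "e": 1}
total = sum(counts.values())

def avg_loss(probs):
    # Average negative log-likelihood of the observed data under probs.
    return sum(n * -math.log(probs[t]) for t, n in counts.items()) / total

optimal = {t: n / total for t, n in counts.items()}  # normalized counts
other = {"a": 0.5, "e": 0.5}                         # any other distribution

print(avg_loss(optimal) < avg_loss(other))  # True
```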