The MLP Model
So far we have three building blocks:

- `state_dict` — three weight matrices
- `linear(x, w)` — matrix-vector multiply
- `softmax(logits)` — logits → probabilities
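As a refresher, those two helpers can be sketched in plain Python. The earlier implementations aren't reproduced on this page, so treat these as assumed equivalents:

```python
import math

def linear(x, w):
    # matrix-vector multiply: dot each row of w with the vector x
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def softmax(logits):
    # subtract the max for numerical stability, then exponentiate and normalize
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]
```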
Now we wire everything together. The MLP (Multi-Layer Perceptron) takes a token ID and produces 27 logits — one raw score for each possible next token:
```python
def mlp(token_id):
    x = state_dict['wte'][token_id]
    x = linear(x, state_dict['mlp_fc1'])
    x = [max(0, xi) for xi in x]  # relu
    logits = linear(x, state_dict['mlp_fc2'])
    return logits
```
Let’s trace through what happens when we pass in token ID 0:
| Step | Result |
|---|---|
| `wte[0]` | Look up row 0 of the embedding table → a vector of 16 numbers. |
| `linear(x, mlp_fc1)` | Multiply by the 64×16 hidden-layer matrix → 64 numbers. |
| `max(0, xi)` | ReLU: set all negative values to zero → still 64 numbers, but some are now 0. |
| `linear(x, mlp_fc2)` | Multiply by the 27×64 output-layer matrix → 27 logits. |
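Here is that trace as runnable code. The weights below are random placeholders with the shapes from the table, since the real trained `state_dict` isn't reproduced here:

```python
import random

random.seed(0)

def rand_matrix(rows, cols):
    # placeholder weights; a trained model would load these from its state_dict
    return [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

state_dict = {
    'wte': rand_matrix(27, 16),      # embedding table: 27 tokens x 16 dims
    'mlp_fc1': rand_matrix(64, 16),  # hidden layer
    'mlp_fc2': rand_matrix(27, 64),  # output layer
}

def linear(x, w):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

x = state_dict['wte'][0]                   # look up row 0 -> 16 numbers
h = linear(x, state_dict['mlp_fc1'])       # hidden layer   -> 64 numbers
h = [max(0.0, hi) for hi in h]             # ReLU           -> still 64, some zeroed
logits = linear(h, state_dict['mlp_fc2'])  # output layer   -> 27 logits
print(len(x), len(h), len(logits))         # → 16 64 27
```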
ReLU (Rectified Linear Unit) is the simplest activation function: keep positive values, zero out negatives. Without it, stacking two linear layers would be mathematically equivalent to a single linear layer. ReLU is what makes the network capable of learning nonlinear patterns.
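That collapse is easy to verify numerically: without ReLU, applying matrix `A` then matrix `B` is the same as applying the single matrix `B @ A`. The matrices below are made-up toy examples:

```python
def linear(x, w):
    # matrix-vector multiply: dot each row of w with the vector x
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def matmul(b, a):
    # b is m x k, a is k x n -> result is m x n
    return [[sum(b[i][t] * a[t][j] for t in range(len(a)))
             for j in range(len(a[0]))] for i in range(len(b))]

A = [[1, 2], [3, 4], [5, 6]]  # 3x2 "first layer"
B = [[1, 0, 1], [0, 1, 1]]    # 2x3 "second layer"
x = [2.0, -1.0]

two_layers = linear(linear(x, A), B)  # no ReLU in between
one_layer = linear(x, matmul(B, A))   # a single combined layer
# both give the same result: stacking linears adds no expressive power
```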
Comparing to Step 0
Both models have the same interface: give them a token, get back scores for the next token.
| | Step 0: `bigram()` | Step 1: `mlp()` |
|---|---|---|
| Input | token ID | token ID |
| Output | 27 probabilities | 27 logits (need softmax) |
| Internals | Count table lookup + normalize | Embedding → linear → ReLU → linear |
| Parameters | 729 counts | 3,184 weights |
| Learning | Counting | Gradient descent |
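Both parameter counts follow directly from the shapes used above (27 tokens, 16-dim embeddings, 64 hidden units):

```python
# bigram: one count per (current token, next token) pair
bigram_params = 27 * 27

# mlp: the three weight matrices — wte, mlp_fc1, mlp_fc2
mlp_params = 27 * 16 + 64 * 16 + 27 * 64
print(bigram_params, mlp_params)  # → 729 3184
```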
The MLP is a “differentiable version” of the count table — it can represent the same patterns, but it can also learn subtler ones because it has more capacity and processes information through multiple layers.
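To match the bigram's interface exactly, just pass the MLP's logits through `softmax`. The wrapper name `next_token_probs` is hypothetical, and the stub `mlp` below stands in for the real one:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def next_token_probs(token_id, mlp):
    # same contract as the bigram: token ID in, 27 probabilities out
    return softmax(mlp(token_id))

# with a stub mlp that returns equal logits, every token gets probability 1/27
probs = next_token_probs(0, lambda t: [0.0] * 27)
```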