Backward: Walking the Graph
Previously Defined
- Value wraps numbers and records operations
- Each operation stores its local gradient
Now the payoff. The backward() method starts at the loss and walks the entire computation graph in reverse, applying the chain rule at every node.
def backward(self):
    topo = []
    visited = set()
    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for child in v._children:
                build_topo(child)
            topo.append(v)
    build_topo(self)
    self.grad = 1
    for v in reversed(topo):
        for child, local_grad in zip(v._children, v._local_grads):
            child.grad += local_grad * v.grad
Two phases:
Phase 1: Topological Sort
build_topo(self)
Starting from the loss, recursively visit all children first, then append the node. This guarantees that when we process a node, all nodes that depend on it have already been processed. It’s the same “start at the end, work backwards” idea as Step 1 — but now it’s automatic.
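To see the visit order concretely, here is a minimal sketch using a bare stand-in node (the `Node` class and the graph shape are hypothetical, invented just for this illustration; only `_children` matches the text). A node feeding two paths gets appended exactly once, before anything that depends on it:

```python
# Hypothetical stand-in node: just a name and children, to show visit order.
class Node:
    def __init__(self, name, children=()):
        self.name = name
        self._children = children

# Tiny graph: b depends on a, loss depends on both a and b.
a = Node("a")
b = Node("b", (a,))
loss = Node("loss", (a, b))

topo, visited = [], set()
def build_topo(v):
    if v not in visited:
        visited.add(v)
        for child in v._children:
            build_topo(child)
        topo.append(v)

build_topo(loss)
print([v.name for v in topo])  # ['a', 'b', 'loss']
```

Note that `a` appears once even though two nodes reference it; the `visited` set prevents re-appending. Reversing the list gives `loss, b, a`, so each node is processed before any of its children.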
Phase 2: Propagate Gradients
self.grad = 1
for v in reversed(topo):
    for child, local_grad in zip(v._children, v._local_grads):
        child.grad += local_grad * v.grad
Starting with the loss (gradient = 1, because ∂loss/∂loss = 1), we walk in reverse topological order. At each node, we multiply the node’s gradient by each local gradient and add the product to the corresponding child’s gradient.
That one line — child.grad += local_grad * v.grad — is the entire chain rule. It’s doing exactly what we did by hand in Step 1, but for any computation graph, not just our specific MLP.
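Putting both phases together, here is a minimal runnable sketch of a Value class with just `+` and `*`, using the attribute names from the text (`_children`, `_local_grads`, `grad`). The constructor signature and operator overloads are assumptions made for this example, not the author's full implementation:

```python
# Minimal Value sketch: stores data, children, and the local gradient
# of the output with respect to each child.
class Value:
    def __init__(self, data, children=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._children = children
        self._local_grads = local_grads

    def __add__(self, other):
        # d(a+b)/da = 1, d(a+b)/db = 1
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        # d(a*b)/da = b, d(a*b)/db = a
        return Value(self.data * other.data, (self, other),
                     (other.data, self.data))

    def backward(self):
        topo, visited = [], set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._children:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        self.grad = 1.0
        for v in reversed(topo):
            for child, local_grad in zip(v._children, v._local_grads):
                child.grad += local_grad * v.grad

a, b = Value(2.0), Value(3.0)
loss = a * b + a           # d(loss)/da = b + 1 = 4, d(loss)/db = a = 2
loss.backward()
print(a.grad, b.grad)      # 4.0 2.0
```

The result matches the hand derivative: `a` receives gradient contributions from both the `a * b` path (3.0) and the direct `+ a` path (1.0), which sum to 4.0.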
Why += ?
A value might be used in multiple places (e.g., the same embedding row used for multiple tokens). The += accumulates gradients from all paths — this is the multivariate chain rule.
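The accumulation can be checked with plain arithmetic, no graph needed. For y = x * x, the same child `x` sits in both argument slots of the multiply, so it receives one contribution per slot (this standalone snippet is illustrative only; the values are arbitrary):

```python
# Two paths into the same child: y = x * x.
# The local grads of mul with respect to its two arguments are (x, x);
# accumulating both gives x + x = 2x, the correct derivative.
x = 3.0
grad = 0.0
for local_grad in (x, x):     # one contribution per argument slot
    grad += local_grad * 1.0  # upstream gradient of y is 1
print(grad)  # 6.0
```

If the loop used `=` instead of `+=`, the second path would overwrite the first and the result would be 3.0, half the true derivative.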