Computation Graphs

Multivariate calculus from first principles

One data structure organizes everything from the last two lessons: the computation graph. Every arithmetic operation in a model (add, multiply, matmul, activation) becomes a node in a directed acyclic graph. Each edge carries a value from the operation that produces it to an operation that consumes it. PyTorch, JAX, and TensorFlow all use this graph to compute gradients automatically.
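The following is a minimal sketch of how such a graph can be stored. It assumes each node keeps only its operation name and references to its input nodes. The `Node` class, the op names, and the example function f(x, y) = (x + y)·y are illustrative choices, not any framework's internals.

```python
from dataclasses import dataclass

@dataclass(eq=False)  # compare nodes by identity: two nodes are equal only if they are the same object
class Node:
    """One vertex of a computation graph: an operation plus its input nodes (illustrative)."""
    op: str              # hypothetical op names: "input", "add", "mul"
    parents: tuple = ()  # edges run from each parent into this node
    name: str = ""

def build_graph():
    # Record f(x, y) = (x + y) * y one operation at a time.
    x = Node("input", name="x")
    y = Node("input", name="y")
    s = Node("add", (x, y), name="s")  # s = x + y
    f = Node("mul", (s, y), name="f")  # f = s * y
    return [x, y, s, f]                # built in topological order (inputs before users)

nodes = build_graph()
for n in nodes:
    print(n.name, n.op, [p.name for p in n.parents])
```

The node `y` feeds both `s` and `f`, so the structure is a DAG rather than a tree. That fan-out is why gradients from several paths must later be added together.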

Training runs the graph in two sweeps. The forward pass flows left to right, computing and caching each node's value. The backward pass flows right to left, using the chain rule to push the gradient from the loss back to every input, one node at a time.
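Here is a sketch of the two sweeps over a flat "tape" of nodes, using f(x, y) = (x + y)·y as the running example. The tape format `(output, op, inputs)` and the helper names are assumptions made for illustration. The forward pass fills a cache of values from left to right. The backward pass walks the tape from right to left and multiplies each incoming gradient by the node's local derivative, which it reads from the cached values.

```python
# Each tape entry: (output_name, op, (input_a, input_b)). Names and ops are illustrative.
TAPE = [
    ("s", "add", ("x", "y")),  # s = x + y
    ("f", "mul", ("s", "y")),  # f = s * y
]

def forward(tape, inputs):
    vals = dict(inputs)                  # cache of every node's value
    for out, op, (a, b) in tape:         # left to right
        if op == "add":
            vals[out] = vals[a] + vals[b]
        else:                            # "mul"
            vals[out] = vals[a] * vals[b]
    return vals

def local_grads(op, va, vb):
    # Local derivatives of the node's output with respect to each of its two inputs.
    if op == "add":
        return 1.0, 1.0
    return vb, va                        # mul: d(ab)/da = b, d(ab)/db = a

def backward(tape, vals, output):
    grads = {name: 0.0 for name in vals}
    grads[output] = 1.0                  # d(output)/d(output)
    for out, op, (a, b) in reversed(tape):   # right to left
        da, db = local_grads(op, vals[a], vals[b])
        grads[a] += grads[out] * da      # incoming gradient times local derivative
        grads[b] += grads[out] * db      # += because a value may feed several nodes
    return grads

vals = forward(TAPE, {"x": 2.0, "y": 3.0})
grads = backward(TAPE, vals, "f")
print(vals["f"], grads["x"], grads["y"])  # 15.0 3.0 8.0
```

You can check the result by hand. Since f = xy + y², the derivatives are ∂f/∂x = y = 3 and ∂f/∂y = x + 2y = 8. The 8 arises because `y` receives 5 through the `mul` node and another 3 through the `add` node.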

The idea that makes it scale: each node only needs to know its own local derivative. To send the gradient backward through a node, multiply the incoming gradient (the gradient of the loss with respect to the node's output) by the node's local Jacobian (how its output depends on its inputs). When a value feeds several nodes, the gradients arriving along each path are summed. No node ever needs the global picture; local rules chained together produce the exact total gradient.
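The same local rule can be written in symbols. This sketch uses one common convention: gradients are column vectors, and v̄ is shorthand for ∂L/∂v (notation introduced here, not used elsewhere in this lesson). The block shows the general rule, two standard node types, and the fan-out sum.

```latex
% General local rule for a node z = g(x): local Jacobian (transposed) times incoming gradient
\bar{\mathbf{x}} = \left(\frac{\partial \mathbf{z}}{\partial \mathbf{x}}\right)^{\top} \bar{\mathbf{z}},
\qquad \bar{v} \equiv \frac{\partial L}{\partial v}

% Multiply node, z = x y (scalars)
\bar{x} = \bar{z}\, y, \qquad \bar{y} = \bar{z}\, x

% Matmul node, z = W x with W \in \mathbb{R}^{m \times n},\ x \in \mathbb{R}^{n}
\bar{\mathbf{x}} = W^{\top} \bar{\mathbf{z}}, \qquad \bar{W} = \bar{\mathbf{z}}\, \mathbf{x}^{\top}

% Fan-out: if v feeds nodes z_1, \dots, z_k, the contributions add
\bar{v} = \sum_{i=1}^{k} \left(\frac{\partial \mathbf{z}_i}{\partial v}\right)^{\top} \bar{\mathbf{z}}_i
```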

Where this lives in ML

In ML, the computation graph is the machinery behind autograd. When you write a model in PyTorch, each operation silently records a node. Calling loss.backward() walks the graph in reverse, multiplies local Jacobians via the chain rule, and stores ∂loss/∂w in the .grad attribute of every parameter. You never write a derivative by hand, and that one convenience, derivatives computed exactly and for free, is much of why modern deep learning is…
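To make this workflow concrete without depending on PyTorch, here is a hypothetical toy scalar autograd engine in plain Python. The class `Value` and its methods are illustrative names and are not PyTorch's API. Arithmetic on `Value` objects silently records each new node with its parents and local derivatives. Calling `backward()` topologically sorts the recorded graph and walks it in reverse, applying the chain rule and accumulating a `.grad` on every node.

```python
import math

class Value:
    """Hypothetical toy scalar autograd node (illustrative, not PyTorch's API)."""
    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # input nodes recorded during the forward pass
        self._local_grads = local_grads  # d(self)/d(parent), one entry per parent

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def tanh(self):
        t = math.tanh(self.data)
        return Value(t, (self,), (1.0 - t * t,))  # d tanh(u)/du = 1 - tanh(u)^2

    def backward(self):
        # Topologically sort the recorded graph so every node comes after its inputs.
        order, seen = [], set()
        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)
        visit(self)
        self.grad = 1.0                          # d(loss)/d(loss)
        for v in reversed(order):                # walk the graph in reverse
            for p, local in zip(v._parents, v._local_grads):
                p.grad += v.grad * local         # chain rule, summed over all paths

# One neuron, squared error: loss = (tanh(w*x + b) - target)^2
w, b = Value(0.5), Value(-1.0)
x, target = 2.0, 1.0
pred = (w * x + b).tanh()
diff = pred + (-target)
loss = diff * diff
loss.backward()
print(w.grad, b.grad)  # -4.0 -2.0
```

At these values, w·x + b = 0, so tanh'(0) = 1 and ∂loss/∂pred = 2·(0 − 1) = −2. That gives ∂loss/∂b = −2 and ∂loss/∂w = −2·x = −4. PyTorch does the same bookkeeping on tensors instead of scalars, with a far more optimized engine.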