Everything You Need to Know About Backpropagation
Introduction
No algorithm is more fundamental to deep learning than backpropagation. While the concept itself is beautifully simple, peeking under the hood of modern frameworks like PyTorch or TensorFlow reveals a labyrinth of device dispatchers, type handlers, and optimization layers that obscure the elegant mathematics beneath. JAX changed this: it exposes an elegant API that makes it easy to explore and define backprop operations yourself.
You might wonder: if modern frameworks handle differentiation automatically, why bother understanding backpropagation at all? The answer becomes clear the moment you need to push beyond standard operations. When you optimize critical bottlenecks with custom kernels in CUDA, Pallas (JAX), or Triton, you are suddenly responsible for defining how gradients flow through your code. Without understanding backpropagation's core mechanics (how to compute and propagate gradients correctly), you will struggle to write kernels that integrate seamlessly into your training loop. Even a foundational grasp of vector-Jacobian products (VJPs) and the chain rule transforms your understanding of how deep learning models are trained.
The Core Theory of Backpropagation
At its heart, a neural network is a series of nested mathematical functions. We start with an input, pass it through the first function (or “layer”) to get an intermediate result, pass that result through the next layer, and so on, until we get a final output. To “train” the network, we need to systematically adjust the parameters (the weights and biases) of each function to minimize the final error. Backpropagation is the clever and efficient algorithm that tells us exactly how to do this. It works by calculating the gradient, or the rate of change of the final error with respect to each parameter in the network, allowing an optimization algorithm like gradient descent to update the parameters in the right direction.
To make this concrete, let’s trace the flow of data through a simple two-layer neural network, which will be our working example for the rest of this post.
\[\begin{align*} z_1 &= f_1(x) = W_1 x \hspace{2cm} &\text{// } x: (d_1, 1),\; W_1: (d_2, d_1),\; z_1: (d_2, 1) \\ a &= \tanh(z_1) &\text{// } a: (d_2, 1) \\z_2 &= f_2(a) = W_2 a &\text{// } W_2: (d_3, d_2),\; z_2: (d_3, 1) \\ y &= \text{Softmax}(z_2) &\text{// } y: (d_3, 1) \\ \mathcal{L} &= \text{Loss}(y) &\text{// } \mathcal{L}: \text{scalar} \end{align*}\]
The Chain Rule with Vector-Jacobian Products (VJP)
This is the single most essential mathematical concept for understanding backpropagation:
Definition 1 (Vector-Jacobian Product). Let $y = f(x)$ with $x \in \mathbb{R}^n$ and $y \in \mathbb{R}^m$, and let $\mathcal{L} = \mathcal{L}(y)$ be a scalar loss function. Then
\[\nabla_{x} \mathcal{L} = J_f(x)^T \nabla_{y} \mathcal{L}\]
where $J_f(x) = \frac{\partial f(x)}{\partial x} \in \mathbb{R}^{m \times n}$.
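To make Definition 1 concrete, here is a small sketch (not from the original post) that checks it numerically with JAX. It materializes the Jacobian of an arbitrary, illustrative function `f` with `jax.jacobian`, forms $J_f(x)^T v$ explicitly, and compares the result with `jax.vjp`, which computes the same product without ever building $J_f$. The function `f`, the point `x`, and the vector `v` standing in for $\nabla_y \mathcal{L}$ are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative function from R^3 to R^2 (an arbitrary choice for this check).
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])


x = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 3.0])  # stands in for the upstream gradient ∇_y L

# Route 1: materialize the (m, n) = (2, 3) Jacobian and multiply by its transpose.
J = jax.jacobian(f)(x)
explicit = J.T @ v

# Route 2: jax.vjp returns f(x) plus a function that maps v to J^T v directly.
y, vjp_fn = jax.vjp(f, x)
(implicit,) = vjp_fn(v)

print(J.shape)                           # (2, 3)
print(jnp.allclose(explicit, implicit))  # True
```

The second route is the one that matters in practice: `vjp_fn` takes an $m$-vector and returns an $n$-vector, so the $m \times n$ matrix never has to exist in memory.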
The key insight of backpropagation is understanding what the vector $\nabla_{y} \mathcal{L}$ represents. It is the "upstream gradient": the gradient of the final scalar loss with respect to the output $y$ of the function $f$. The VJP $J_f(x)^T \nabla_{y} \mathcal{L}$ then gives the "downstream gradient": the gradient of the final scalar loss with respect to the function's input $x$. That result in turn becomes the upstream gradient for whichever function produced $x$. Backpropagation is, in essence, a chain of these VJP calculations that passes the upstream gradient from the last layer all the way back to the first.
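The chain of VJPs can be sketched directly for the two-layer network above. Each stage is wrapped in `jax.vjp`, and the backward pass seeds $\partial \mathcal{L} / \partial \mathcal{L} = 1$ and pushes the upstream gradient from right to left. The result is then compared with `jax.grad` of the whole composition. The sizes, the random weights, and the choice of cross-entropy against a one-hot `label` as $\text{Loss}(y)$ are illustrative assumptions, not part of the original post.

```python
import jax
import jax.numpy as jnp

d1, d2, d3 = 6, 5, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W1 = jax.random.normal(k1, (d2, d1))
W2 = jax.random.normal(k2, (d3, d2))
x = jax.random.normal(k3, (d1,))
label = jax.nn.one_hot(0, d3)  # hypothetical one-hot target


def loss_fn(y):
    # Illustrative choice of Loss(y): cross-entropy against the one-hot label.
    return -jnp.sum(label * jnp.log(y))


# Forward pass, keeping one VJP function per stage.
z1, vjp_f1 = jax.vjp(lambda W: W @ x, W1)
a, vjp_tanh = jax.vjp(jnp.tanh, z1)
z2, vjp_f2 = jax.vjp(lambda v: W2 @ v, a)
y, vjp_softmax = jax.vjp(jax.nn.softmax, z2)
L, vjp_loss = jax.vjp(loss_fn, y)

# Backward pass: seed with dL/dL = 1, then pass the upstream gradient right to left.
(grad_y,) = vjp_loss(jnp.ones_like(L))
(grad_z2,) = vjp_softmax(grad_y)
(grad_a,) = vjp_f2(grad_z2)
(grad_z1,) = vjp_tanh(grad_a)
(grad_W1,) = vjp_f1(grad_z1)


def full_loss(W):
    # The same network as one function, differentiated by JAX in one go.
    return loss_fn(jax.nn.softmax(W2 @ jnp.tanh(W @ x)))


print(jnp.allclose(grad_W1, jax.grad(full_loss)(W1), atol=1e-5))  # True
```

Each `vjp_*` call only ever consumes and produces vectors (or a matrix shaped like $W_1$ at the very end), which is the behaviour the rest of this post builds on.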
Applying this rule repeatedly, we can write the gradient of the loss with respect to the weights $W_1$ of the layer furthest from the output:
\[\begin{align*}\nabla_{z_2} \mathcal{L} &= J_{y}^T \nabla_{y} \mathcal{L} \\ \nabla_{a} \mathcal{L} &= J_{z_2}^T \nabla_{z_2} \mathcal{L} = J_{z_2}^T J_{y}^T \nabla_{y} \mathcal{L} \\ \nabla_{z_1} \mathcal{L} &= J_{a}^T \nabla_{a} \mathcal{L}= J_{a}^T J_{z_2}^T J_{y}^T \nabla_{y} \mathcal{L} \\ \nabla_{W_1} \mathcal{L} &= J_{z_1}^T \nabla_{z_1} \mathcal{L} = J_{z_1}^T J_{a}^T J_{z_2}^T J_{y}^T \nabla_{y} \mathcal{L} \end{align*}\]
It is worth pausing here to notice that there are two ways to evaluate this product. Multiplying from left to right is known as forward-mode automatic differentiation. Multiplying from right to left is known as reverse-mode automatic differentiation. Reverse mode is the one used to train neural networks, and it is what is commonly called "backpropagation", or "backprop" for short. Annotating each factor with its shape (with $W_1$ flattened into a vector of $d_2 d_1$ entries):
\[\underbrace{\nabla_{W_1} \mathcal{L}}_{(d_2 d_1,\, 1)} = \underbrace{J_{z_1}^T}_{(d_2 d_1,\, d_2)} \ \underbrace{J_{a}^T}_{(d_2,\, d_2)} \ \underbrace{J_{z_2}^T}_{(d_2,\, d_3)} \ \underbrace{J_{y}^T}_{(d_3,\, d_3)} \ \underbrace{\nabla_{y} \mathcal{L}}_{(d_3,\, 1)}\]
Forward vs. Reverse Mode: Why Backprop Wins
To avoid any hand-waving, I won't move on until I have shown you exactly why backprop is vastly more efficient than forward-mode autodiff for training deep learning models. (If you read this and still aren't convinced, leave a comment and I will try to provide another example.)
In practice we need $\nabla_{W_1} \mathcal{L}$, but I will dissect the computation of $\nabla_{z_1} \mathcal{L}$ instead. The main reason is that multiplying by the leftmost Jacobian $J_{z_1}^T$ reduces to a simple outer product. As we will show later, none of the Jacobians shown here ever need to be materialized.
\[\underbrace{\nabla_{z_1} \mathcal{L}}_{(d_2, 1)} = \underbrace{J_{a}^T}_{(d_2,\, d_2)} \ \underbrace{J_{z_2}^T}_{(d_2,\, d_3)} \ \underbrace{J_{y}^T}_{(d_3,\, d_3)} \ \underbrace{\nabla_{y} \mathcal{L}}_{(d_3,\, 1)}\]
In forward-mode autodiff, we start from the left and compute the matrix multiplications as follows:
\[\nabla_{z_1} \mathcal{L} = ((J_{a}^T J_{z_2}^T )J_{y}^T) \nabla_{y} \mathcal{L}\]
This yields the following number of multiplications (recall that multiplying an $(m, n)$ matrix by an $(n, p)$ matrix requires $mnp$ multiplications):
Total No. of Multiplications for Forward Mode: $d_2^2 d_3 + d_2 d_3^2 + d_2 d_3$
In reverse-mode autodiff, we start from the right and compute the matrix multiplications as follows:
\[\nabla_{z_1} \mathcal{L} = J_{a}^T (J_{z_2}^T (J_{y}^T \nabla_{y} \mathcal{L}))\]
This yields the following number of multiplications:
Total No. of Multiplications for Reverse Mode: $d_3^2 + d_2 d_3 + d_2^2$
The reason backpropagation is universally used for training neural networks becomes clear when we compare the total costs, especially with dimensions typical for these models. In most neural networks, the hidden layers are much wider than the output layer ($d_2 \gg d_3$).
Let’s consider a simple classification problem with a 500-neuron hidden layer ($d_2 = 500$) and 10 output classes ($d_3 = 10$).
- Reverse Mode Cost = $10^2 + (500 \times 10) + 500^2 = 100 + 5,000 + 250,000 = \boldsymbol{255,100}$ operations.
- Forward Mode Cost = $(500^2 \times 10) + (500 \times 10^2) + (500 \times 10) = 2,500,000 + 50,000 + 5,000 = \boldsymbol{2,555,000}$ operations.
In this realistic scenario, forward mode is 10 times more computationally expensive than reverse mode.
This dramatic difference arises because forward mode has to perform expensive matrix-matrix multiplications, which create large intermediate matrices. Reverse mode (backpropagation) avoids this because the rightmost factor $\nabla_y \mathcal{L}$ is a vector. Every product it computes is therefore a matrix-vector multiplication whose result is again a vector, and this is significantly cheaper.
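The operation counts above can be reproduced with a short sketch (an illustration, not part of the original post). Only arithmetic is involved here, so it uses NumPy. Random matrices with the same shapes as the transposed Jacobians stand in for the real ones. The hypothetical helper `matmul` tallies the $mnp$ cost of every product, and the sketch also checks that both multiplication orders give the same vector.

```python
import numpy as np

d2, d3 = 500, 10
rng = np.random.default_rng(0)

# Random stand-ins with the shapes of the factors in the chain for ∇_{z_1} L.
Ja_T = rng.standard_normal((d2, d2))
Jz2_T = rng.standard_normal((d2, d3))
Jy_T = rng.standard_normal((d3, d3))
grad_y = rng.standard_normal((d3, 1))

counts = {"forward": 0, "reverse": 0}


def matmul(A, B, mode):
    """Return A @ B and add its m*n*p scalar-multiplication cost to counts[mode]."""
    m, n = A.shape
    n2, p = B.shape
    assert n == n2, "inner dimensions must match"
    counts[mode] += m * n * p
    return A @ B


# Forward mode: left to right, which produces matrix-matrix products.
fwd = matmul(matmul(matmul(Ja_T, Jz2_T, "forward"), Jy_T, "forward"), grad_y, "forward")

# Reverse mode: right to left, so every product is matrix-vector.
rev = matmul(Ja_T, matmul(Jz2_T, matmul(Jy_T, grad_y, "reverse"), "reverse"), "reverse")

print(counts["forward"], counts["reverse"])  # 2555000 255100
print(np.allclose(fwd, rev))                 # True
```

Both orders produce the same $(d_2, 1)$ gradient. Only the cost of getting there differs.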
The VJP Trick: Never Materialize Jacobians
While we’ve established that backprop is the efficient approach for computing gradients, our algorithm can be made even faster and more memory-efficient by avoiding an expensive computational step: materializing Jacobian matrices.
This leads to a very powerful trick: we never need to materialize a Jacobian to pass the gradients back to a function's inputs.
Deriving the Gradient for a Linear Layer
Let’s demonstrate this principle using the most fundamental operation in neural networks: matrix multiplication. Consider a linear transformation defined as:
\[\begin{align*} y &= Wx \\ \nabla_{W} \mathcal{L} &= J_y^T \nabla_{y} \mathcal{L} \quad \text{From Definition 1 above} \end{align*}\]
Let's expand this equation fully. Doing so shows that the Jacobian is sparse (mostly zeros), which lets us simplify the operation further.
Let’s define the dimensions: $x \in \mathbb{R}^{d_1 \times 1}$, $W \in \mathbb{R}^{d_2 \times d_1}$, and $y \in \mathbb{R}^{d_2 \times 1}$. The components of the output are $y_i = \sum_{j=1}^{d_1} W_{ij} x_j$.
The Jacobian of $y$ with respect to the flattened weight matrix is a $(d_2, d_2 d_1)$ matrix. The derivative $\frac{\partial y_k}{\partial W_{ij}}$ is non-zero only when $k=i$, resulting in a very sparse Jacobian:
\[J_y = \begin{pmatrix} \underbrace{x_1 \ \dots \ x_{d_1}}_{W_{1j}} & \underbrace{0 \ \dots \ 0}_{W_{2j}} & \cdots & \underbrace{0 \ \dots \ 0}_{W_{d_2,j}} \\ 0 \ \dots \ 0 & x_1 \ \dots \ x_{d_1} & \cdots & 0 \ \dots \ 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 \ \dots \ 0 & 0 \ \dots \ 0 & \cdots & x_1 \ \dots \ x_{d_1} \end{pmatrix}\]
The transposed Jacobian $J_y^T$ is a $(d_2 d_1, d_2)$ matrix. We multiply it by the gradient vector $\nabla_y \mathcal{L} \in \mathbb{R}^{d_2 \times 1}$:
\[\nabla_W \mathcal{L} = J_y^T \nabla_y \mathcal{L} = \begin{pmatrix} x_1 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ x_{d_1} & 0 & \cdots & 0 \\[0.4em] \hline 0 & x_1 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & x_{d_1} & \cdots & 0 \\[0.4em] \hline \vdots & \vdots & \ddots & \vdots \\[0.4em] \hline 0 & 0 & \cdots & x_1 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & x_{d_1} \end{pmatrix} \begin{pmatrix} \frac{\partial \mathcal{L}}{\partial y_1} \\ \frac{\partial \mathcal{L}}{\partial y_2} \\ \vdots \\ \frac{\partial \mathcal{L}}{\partial y_{d_2}} \end{pmatrix} = \begin{pmatrix} x_1 \frac{\partial \mathcal{L}}{\partial y_1} \\ \vdots \\ x_{d_1} \frac{\partial \mathcal{L}}{\partial y_1} \\[0.4em] \hline x_1 \frac{\partial \mathcal{L}}{\partial y_2} \\ \vdots \\ x_{d_1} \frac{\partial \mathcal{L}}{\partial y_2} \\[0.4em] \hline \vdots \\[0.4em] \hline x_1 \frac{\partial \mathcal{L}}{\partial y_{d_2}} \\ \vdots \\ x_{d_1} \frac{\partial \mathcal{L}}{\partial y_{d_2}} \end{pmatrix}\]
Reshaping this flattened $(d_2 d_1, 1)$ vector back to the $(d_2, d_1)$ dimensions of $W$, we get:
\[\nabla_W \mathcal{L} = \begin{pmatrix} \frac{\partial \mathcal{L}}{\partial y_1}x_1 & \frac{\partial \mathcal{L}}{\partial y_1}x_2 & \dots & \frac{\partial \mathcal{L}}{\partial y_1}x_{d_1} \\ \frac{\partial \mathcal{L}}{\partial y_2}x_1 & \frac{\partial \mathcal{L}}{\partial y_2}x_2 & \dots & \frac{\partial \mathcal{L}}{\partial y_2}x_{d_1} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial \mathcal{L}}{\partial y_{d_2}}x_1 & \frac{\partial \mathcal{L}}{\partial y_{d_2}}x_2 & \dots & \frac{\partial \mathcal{L}}{\partial y_{d_2}}x_{d_1} \end{pmatrix}\]
This matrix is exactly the outer product of the column vector $\nabla_y \mathcal{L}$ and the row vector $x^T$:
\[\nabla_y \mathcal{L} \cdot x^T = \begin{pmatrix} \frac{\partial \mathcal{L}}{\partial y_1} \\ \frac{\partial \mathcal{L}}{\partial y_2} \\ \vdots \\ \frac{\partial \mathcal{L}}{\partial y_{d_2}} \end{pmatrix} \begin{pmatrix} x_1 & x_2 & \dots & x_{d_1} \end{pmatrix} = \begin{pmatrix} \frac{\partial \mathcal{L}}{\partial y_1}x_1 & \frac{\partial \mathcal{L}}{\partial y_1}x_2 & \dots & \frac{\partial \mathcal{L}}{\partial y_1}x_{d_1} \\ \frac{\partial \mathcal{L}}{\partial y_2}x_1 & \frac{\partial \mathcal{L}}{\partial y_2}x_2 & \dots & \frac{\partial \mathcal{L}}{\partial y_2}x_{d_1} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial \mathcal{L}}{\partial y_{d_2}}x_1 & \frac{\partial \mathcal{L}}{\partial y_{d_2}}x_2 & \dots & \frac{\partial \mathcal{L}}{\partial y_{d_2}}x_{d_1} \end{pmatrix}\]
Thus, the vector-Jacobian product (VJP) simplifies to an efficient outer product:
\[\nabla_W \mathcal{L} = \boxed{ \nabla_y \mathcal{L} \cdot x^T }\]
Take a moment to see why this is so efficient. $x$ is the input to our local operation, which we already have from the forward pass. $\nabla_{y} \mathcal{L}$ is the gradient of the loss with respect to the operation's output, and it is passed down through the computational graph. We computed the gradient without ever materializing the full Jacobian. The same reasoning gives the gradient with respect to the input: since $\partial y_i / \partial x_j = W_{ij}$, the Jacobian of $y$ with respect to $x$ is simply $W$, so $\nabla_x \mathcal{L} = W^T \nabla_y \mathcal{L}$. This idea extends to any operation, as long as you define how to pass the gradients from its outputs to its inputs. That is exactly what we will do next in code using JAX.
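As a sanity check on the boxed result, the snippet below (a sketch, not from the original post) materializes the Jacobian of $y = Wx$ with respect to $W$ using `jax.jacobian`. It then contracts that Jacobian with an upstream gradient `g` and compares the result with the outer product. It also checks $\nabla_x \mathcal{L} = W^T \nabla_y \mathcal{L}$ against `jax.vjp`. The sizes and random values are arbitrary choices made for the example.

```python
import jax
import jax.numpy as jnp

d1, d2 = 4, 3
kW, kx, kg = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(kW, (d2, d1))
x = jax.random.normal(kx, (d1,))
g = jax.random.normal(kg, (d2,))  # upstream gradient ∇_y L

# Materialized route: the Jacobian of y = W @ x w.r.t. W has shape (d2, d2, d1),
# the un-flattened form of the sparse (d2, d2*d1) matrix J_y from the text.
J_W = jax.jacobian(lambda W_: W_ @ x)(W)
grad_W_explicit = jnp.einsum("k,kij->ij", g, J_W)  # J_y^T g, already shaped like W

# VJP shortcut: a single outer product, no Jacobian required.
grad_W_outer = jnp.outer(g, x)

# Reference gradients from JAX's own reverse-mode autodiff.
_, vjp_fn = jax.vjp(lambda W_, x_: W_ @ x_, W, x)
grad_W_ref, grad_x_ref = vjp_fn(g)

print(J_W.shape)                                    # (3, 3, 4)
print(jnp.allclose(grad_W_explicit, grad_W_outer))  # True
print(jnp.allclose(grad_W_outer, grad_W_ref))       # True
print(jnp.allclose(W.T @ g, grad_x_ref))            # True
```

The materialized Jacobian has $d_2 \cdot d_2 \cdot d_1$ entries, almost all zero, while the outer product touches only the $d_2 d_1$ entries that end up in the gradient.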
Implementing Custom VJPs in JAX
JAX already has built-in automatic differentiation that can compute the gradients of any function written with jax.numpy primitives. Here we override it using jax.custom_vjp, a JAX mechanism that lets you write custom gradients for any function you define. We will then build every operation needed to train the simple two-layer classifier we have been working with on the MNIST dataset.
Let's start with the simplest operation, matrix multiplication, which we will call Linear:
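```python
import jax
import jax.numpy as jnp

# 1. Linear Layer
@jax.custom_vjp
def Linear(W, x):
    return W @ x

def Linear_fwd(W, x):
    # Return the output plus the residuals (W, x) needed by the backward pass.
    return Linear(W, x), (W, x)

def Linear_bwd(res, g):
    W, x = res
    grad_W = jnp.outer(g, x)  # ∇_W L = ∇_y L · x^T
    grad_x = W.T @ g          # ∇_x L = W^T ∇_y L
    return (grad_W, grad_x)

Linear.defvjp(Linear_fwd, Linear_bwd)
```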
The g parameter in Linear_bwd represents the upstream gradient, that is, the gradient of the final loss with respect to the output of the Linear function.
There are two things to understand here. First, defining a custom backward pass takes two functions, not just bwd. The fwd function runs the forward computation and also saves the values the backward pass will need. These are called residuals (res), and JAX passes them to bwd. Second, we compute gradients for both $W$ and $x$. Even though we only update $W$, the gradient must keep flowing through $x$ to the rest of the computational graph.
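One way to confirm that a custom rule like `Linear_bwd` is wired up correctly is to compare it with JAX's built-in differentiation of the plain expression `W @ x`. The sketch below repeats the `Linear` definition so that it runs on its own. It differentiates an illustrative scalar loss with respect to both arguments; the `jnp.sum(jnp.sin(...))` loss and the random shapes are assumptions chosen only for the test. If the residuals or the returned tuple were wrong, the two results would disagree.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def Linear(W, x):
    return W @ x


def Linear_fwd(W, x):
    # Save W and x as residuals for the backward pass.
    return Linear(W, x), (W, x)


def Linear_bwd(res, g):
    W, x = res
    return (jnp.outer(g, x), W.T @ g)  # one gradient per primal input


Linear.defvjp(Linear_fwd, Linear_bwd)


def loss_custom(W, x):
    # Illustrative scalar loss built on the custom Linear.
    return jnp.sum(jnp.sin(Linear(W, x)))


def loss_builtin(W, x):
    # The same loss, differentiated entirely by JAX's own autodiff.
    return jnp.sum(jnp.sin(W @ x))


kW, kx = jax.random.split(jax.random.PRNGKey(1))
W = jax.random.normal(kW, (5, 3))
x = jax.random.normal(kx, (3,))

gW_c, gx_c = jax.grad(loss_custom, argnums=(0, 1))(W, x)
gW_b, gx_b = jax.grad(loss_builtin, argnums=(0, 1))(W, x)
print(jnp.allclose(gW_c, gW_b), jnp.allclose(gx_c, gx_b))  # True True
```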
Now let’s define the rest of the operations we need.
To further solidify the VJP (Vector-Jacobian Product) concept, let’s consider an element-wise activation function like $a = \tanh(z)$. The derivative of $\tanh(z)$ is $1 - \tanh^2(z)$, and since the operation is applied element-wise, the Jacobian $J_a$ is a diagonal matrix with each diagonal entry being the derivative of the corresponding component:
\[J_a = \begin{pmatrix} 1 - a_1^2 & 0 & \cdots & 0 \\ 0 & 1 - a_2^2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 1 - a_{d_2}^2 \end{pmatrix}\]
Here, $a_i = \tanh(z_i)$, so each entry is $1 - \tanh^2(z_i) = 1 - a_i^2$.
Since the Jacobian is diagonal and the operation is element-wise, the VJP (Vector-Jacobian Product) simplifies to a Hadamard (element-wise) product:
\[\nabla_z \mathcal{L} = J_a^T \nabla_a \mathcal{L} = J_a \nabla_a \mathcal{L} = (1 - a^2) \odot \nabla_a \mathcal{L}\]
```python
# 2. Tanh Activation
@jax.custom_vjp
def Tanh(z):
    return jnp.tanh(z)

def Tanh_fwd(z):
    a = Tanh(z)
    return a, (a,)

def Tanh_bwd(res, g):
    a, = res
    grad_z = g * (1 - a**2)
    return (grad_z,)  # Gradient for the single input z

Tanh.defvjp(Tanh_fwd, Tanh_bwd)
```
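The diagonal-Jacobian argument can be checked the same way. This sketch (self-contained, with arbitrary test values) materializes $J_a$ with `jax.jacobian` and confirms that it is diagonal with entries $1 - a_i^2$. It then compares the full matrix-vector product with the Hadamard-product rule used by `Tanh_bwd`.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def Tanh(z):
    return jnp.tanh(z)


def Tanh_fwd(z):
    a = Tanh(z)
    return a, (a,)  # the output a is all the backward pass needs


def Tanh_bwd(res, g):
    (a,) = res
    return (g * (1 - a**2),)


Tanh.defvjp(Tanh_fwd, Tanh_bwd)

kz, kg = jax.random.split(jax.random.PRNGKey(2))
z = jax.random.normal(kz, (6,))
g = jax.random.normal(kg, (6,))  # upstream gradient ∇_a L

# Materialize the full (6, 6) Jacobian of tanh using JAX's built-in autodiff.
J_a = jax.jacobian(jnp.tanh)(z)
a = jnp.tanh(z)
print(jnp.allclose(J_a, jnp.diag(1 - a**2)))  # True: diagonal, entries 1 - a_i^2

# Full matrix-vector VJP vs. the element-wise rule used by the custom Tanh.
_, vjp_fn = jax.vjp(Tanh, z)
(grad_z_custom,) = vjp_fn(g)
print(jnp.allclose(J_a.T @ g, grad_z_custom))  # True
```

The element-wise rule needs $O(d_2)$ work and memory, while the materialized Jacobian needs $O(d_2^2)$ just to store zeros.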
Deriving the Softmax Cross-Entropy Gradient
The tanh derivation is straightforward and easy to verify by hand. More importantly, there is no restriction on how complex an operation can be, as long as we know how to flow the gradients from its outputs to its inputs. That leaves us two options: define a softmax operation and a separate cross-entropy loss operation, or combine them into a single operation. There are several reasons to fuse operations (e.g., numerical stability), but we do it here because the combined gradient turns out to be very simple. Let's derive the gradient of the combined softmax + cross-entropy operation and then implement a softmax_cross_entropy function.
Note: here $z$ and $y$ refer to the logits and the true labels, respectively, not to be confused with the $z$ and $y$ used in the two-layer network example above.
First, let’s define the core functions. We have a set of logits (the raw output of the last linear layer), $z = [z_1, z_2, …, z_C]$, where $C$ is the number of classes. The true label is a one-hot encoded vector $y = [y_1, y_2, …, y_C]$, where only one $y_k=1$ and all others are 0.
Softmax Function: Converts logits into probabilities. The probability for the $j$-th class, $p_j$, is:
\[p_j = \text{softmax}(z)_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}}\]
Cross-Entropy Loss: Calculates the loss based on the predicted probabilities $p$ and the true labels $y$.
\[\mathcal{L} = - \sum_{k=1}^{C} y_k \log(p_k)\]
The loss $\mathcal{L}$ is not a direct function of $z_i$. Instead, $\mathcal{L}$ depends on all the probabilities $p_1, p_2, \dots, p_C$, and each of these probabilities in turn depends on the logit $z_i$. Therefore, we must use the chain rule and sum over all paths through which $z_i$ affects $\mathcal{L}$:
\[\frac{\partial \mathcal{L} }{\partial z_i} = \sum_{j=1}^{C} \frac{\partial \mathcal{L} }{\partial p_j} \frac{\partial p_j}{\partial z_i}\]
We will solve this by calculating the two partial derivatives separately.
Derivative of Loss with respect to Probabilities
This part is straightforward. We differentiate the loss function $\mathcal{L} = - \sum_{k=1}^{C} y_k \log(p_k)$ with respect to a single probability $p_j$. Only the term where $k=j$ is non-zero.
\[\frac{\partial \mathcal{L} }{\partial p_j} = \frac{\partial}{\partial p_j} \left( - y_j \log(p_j) \right) = - \frac{y_j}{p_j}\]
Derivative of Softmax with respect to Logits
This part is more complex because the output of softmax at index $j$ depends on all the input logits. We have two cases for $\frac{\partial p_j}{\partial z_i}$.
Case A: $i = j$ (differentiating $p_i$ with respect to $z_i$). We use the quotient rule on $p_i = \frac{e^{z_i}}{\sum_{k} e^{z_k}}$:
\[\begin{align*} \frac{\partial p_i}{\partial z_i} &= \frac{(e^{z_i})' (\sum_{k} e^{z_k}) - (e^{z_i}) (\sum_{k} e^{z_k})'}{(\sum_{k} e^{z_k})^2}\\ &= \frac{e^{z_i} (\sum_{k} e^{z_k}) - e^{z_i} (e^{z_i})}{(\sum_{k} e^{z_k})^2}\\ &= \frac{e^{z_i}}{\sum_{k} e^{z_k}} - \left(\frac{e^{z_i}}{\sum_{k} e^{z_k}}\right)^2 = p_i - p_i^2 = p_i(1 - p_i) \end{align*}\]
Case B: $i \neq j$ (differentiating $p_j$ with respect to $z_i$). Again we use the quotient rule on $p_j = \frac{e^{z_j}}{\sum_{k} e^{z_k}}$, but this time the numerator $e^{z_j}$ is a constant with respect to $z_i$:
\[\begin{align*} \frac{\partial p_j}{\partial z_i} &= \frac{(e^{z_j})' (\sum_{k} e^{z_k}) - (e^{z_j}) (\sum_{k} e^{z_k})'}{(\sum_{k} e^{z_k})^2}\\ &= \frac{0 \cdot (\sum_{k} e^{z_k}) - e^{z_j} (e^{z_i})}{(\sum_{k} e^{z_k})^2}\\ &= - \left(\frac{e^{z_j}}{\sum_{k} e^{z_k}}\right) \left(\frac{e^{z_i}}{\sum_{k} e^{z_k}}\right) = - p_j p_i \end{align*}\]
Now we substitute these results back into our chain-rule sum, splitting it into two parts: the term where $j=i$ and all the terms where $j \neq i$.
\[\frac{\partial \mathcal{L} }{\partial z_i} = \underbrace{\frac{\partial \mathcal{L} }{\partial p_i} \frac{\partial p_i}{\partial z_i}}_{\text{Term for } j=i} + \underbrace{\sum_{j \neq i} \frac{\partial \mathcal{L} }{\partial p_j} \frac{\partial p_j}{\partial z_i}}_{\text{Terms for } j \neq i}\]
Substituting our findings from the previous steps and simplifying:
\[\begin{align*} \frac{\partial \mathcal{L} }{\partial z_i} &= \left(-\frac{y_i}{p_i}\right) \cdot \left(p_i(1-p_i)\right) + \sum_{j \neq i} \left(-\frac{y_j}{p_j}\right) \cdot \left(-p_j p_i\right)\\ &= -y_i(1-p_i) + \sum_{j \neq i} y_j p_i \\ &= -y_i + y_i p_i + p_i \sum_{j \neq i} y_j \\ &= -y_i + p_i \left( y_i + \sum_{j \neq i} y_j \right) \\ &= -y_i + p_i(1) \qquad \text{(the one-hot labels sum to 1)} \\ &= \boxed{p_i - y_i} \end{align*}\]
This remarkable result shows that the gradient of the combined softmax and cross-entropy loss with respect to a single logit $z_i$ is simply the difference between the predicted probability for that class and the true label for that class.
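Cases A and B together say that the softmax Jacobian is $\operatorname{diag}(p) - p p^T$. A quick numerical check confirms both that formula and the boxed result $p_i - y_i$. This is a sketch with arbitrary logits and an arbitrary true class, not part of the original post.

```python
import jax
import jax.numpy as jnp

C = 5
z = jnp.array([2.0, -1.0, 0.5, 0.0, 1.5])  # arbitrary logits
y = jax.nn.one_hot(2, C)                    # true class index 2 (arbitrary)
p = jax.nn.softmax(z)

# Case A on the diagonal, Case B off the diagonal: J[j, i] = dp_j/dz_i.
J_softmax = jax.jacobian(jax.nn.softmax)(z)
print(jnp.allclose(J_softmax, jnp.diag(p) - jnp.outer(p, p), atol=1e-5))  # True

# Chain-rule sum: dL/dz_i = sum_j (dL/dp_j)(dp_j/dz_i) = (J^T (-y / p))_i.
dL_dp = -y / p
grad_chain = J_softmax.T @ dL_dp


def loss(z):
    # Cross-entropy of softmax(z), differentiated by JAX for comparison.
    return -jnp.sum(y * jnp.log(jax.nn.softmax(z)))


print(jnp.allclose(grad_chain, p - y, atol=1e-5))         # True
print(jnp.allclose(jax.grad(loss)(z), p - y, atol=1e-5))  # True
```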
When we generalize this from a single logit to the entire vector of logits $z$, the gradient vector is:
\[\nabla_z \mathcal{L} = p - y\]
This is precisely what the following softmax_cross_entropy_bwd function computes: probs - label_res, scaled by the upstream gradient g.
```python
@jax.custom_vjp
def softmax_cross_entropy(logits, label):
    # This is the forward definition of the function
    max_logit = jnp.max(logits)
    exps = jnp.exp(logits - max_logit)
    probs = exps / jnp.sum(exps)
    return -jnp.sum(label * jnp.log(probs + 1e-7))

def softmax_cross_entropy_fwd(logits, label):
    # Forward pass for VJP: compute output and save residuals
    loss = softmax_cross_entropy(logits, label)
    # Re-compute probabilities here to save them as residuals for the backward pass.
    max_logit = jnp.max(logits)
    exps = jnp.exp(logits - max_logit)
    probs = exps / jnp.sum(exps)
    return loss, (probs, label)

def softmax_cross_entropy_bwd(res, g):
    # Backward pass for VJP
    probs, label_res = res
    # Gradient with respect to logits
    grad_logits = g * (probs - label_res)
    # The gradient for 'label' is not needed, so we return None.
    # The returned tuple must match the number of inputs to the function.
    return (grad_logits, None)

softmax_cross_entropy.defvjp(softmax_cross_entropy_fwd, softmax_cross_entropy_bwd)
```
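To check this rule in isolation, the sketch below (self-contained, with made-up logits and label) compares the gradient JAX obtains through `softmax_cross_entropy_bwd` with JAX's own autodiff of a hypothetical `reference_loss` written with `jax.nn.log_softmax`. The `1e-7` added inside the logarithm only nudges the forward loss value, because the analytic backward rule ignores it. The gradients should therefore agree closely, and the loss values agree up to roughly that epsilon.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def softmax_cross_entropy(logits, label):
    exps = jnp.exp(logits - jnp.max(logits))
    probs = exps / jnp.sum(exps)
    return -jnp.sum(label * jnp.log(probs + 1e-7))


def softmax_cross_entropy_fwd(logits, label):
    exps = jnp.exp(logits - jnp.max(logits))
    probs = exps / jnp.sum(exps)
    return softmax_cross_entropy(logits, label), (probs, label)


def softmax_cross_entropy_bwd(res, g):
    probs, label_res = res
    return (g * (probs - label_res), None)  # no gradient for the label


softmax_cross_entropy.defvjp(softmax_cross_entropy_fwd, softmax_cross_entropy_bwd)


def reference_loss(logits, label):
    # Same loss via JAX's numerically stable log_softmax, differentiated by JAX.
    return -jnp.sum(label * jax.nn.log_softmax(logits))


logits = jnp.array([1.0, 3.0, -2.0, 0.5])
label = jax.nn.one_hot(1, 4)

custom_grad = jax.grad(softmax_cross_entropy)(logits, label)
reference_grad = jax.grad(reference_loss)(logits, label)
print(custom_grad)
print(jnp.allclose(custom_grad, reference_grad, atol=1e-6))  # True
print(jnp.allclose(softmax_cross_entropy(logits, label),
                   reference_loss(logits, label), atol=1e-5))  # True
```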
Putting It All Together: A Full Example
Now, I will combine all the pieces together to define the two-layer neural network and train it on the MNIST dataset using JAX with our custom backpropagation implementations in a single script of 150 lines.
Training a Neural Network on MNIST
```python
import jax
import jax.numpy as jnp
from jax import grad, jit, value_and_grad, vmap
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# 1. Linear Layer
@jax.custom_vjp
def Linear(W, x):
    return W @ x

def Linear_fwd(W, x):
    return Linear(W, x), (W, x)

def Linear_bwd(res, g):
    W, x = res
    grad_W = jnp.outer(g, x)
    grad_x = W.T @ g
    return (grad_W, grad_x)

Linear.defvjp(Linear_fwd, Linear_bwd)

# 2. Tanh Activation
@jax.custom_vjp
def Tanh(z):
    return jnp.tanh(z)

def Tanh_fwd(z):
    a = Tanh(z)
    return a, (a,)

def Tanh_bwd(res, g):
    a, = res
    grad_z = g * (1 - a**2)
    return (grad_z,)

Tanh.defvjp(Tanh_fwd, Tanh_bwd)

# 3. Combined Softmax + Cross-Entropy Loss
@jax.custom_vjp
def softmax_cross_entropy(logits, label):
    max_logit = jnp.max(logits)
    exps = jnp.exp(logits - max_logit)
    probs = exps / jnp.sum(exps)
    return -jnp.sum(label * jnp.log(probs + 1e-7))

def softmax_cross_entropy_fwd(logits, label):
    loss = softmax_cross_entropy(logits, label)
    # Re-compute probabilities here to save them as residuals for the backward pass.
    max_logit = jnp.max(logits)
    exps = jnp.exp(logits - max_logit)
    probs = exps / jnp.sum(exps)
    return loss, (probs, label)

def softmax_cross_entropy_bwd(res, g):
    probs, label_res = res
    grad_logits = g * (probs - label_res)
    # The gradient for 'label' is not needed, so we return None.
    # The returned tuple must match the number of inputs to the function.
    return (grad_logits, None)

softmax_cross_entropy.defvjp(softmax_cross_entropy_fwd, softmax_cross_entropy_bwd)

def init_params(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [jax.random.normal(k, (out_size, in_size)) * jnp.sqrt(2. / in_size)
            for k, in_size, out_size in zip(keys, sizes[:-1], sizes[1:])]

def forward_pass(params, image):
    W1, W2 = params
    z1 = Linear(W1, image)
    a1 = Tanh(z1)
    logits = Linear(W2, a1)
    return logits

# Vectorize the forward pass to handle batches efficiently
batch_forward = vmap(forward_pass, in_axes=(None, 0))

def batch_loss_fn(params, images, labels):
    logits = batch_forward(params, images)
    # Vmap the single-example loss function over the batch
    losses = vmap(softmax_cross_entropy)(logits, labels)
    return jnp.mean(losses)

def accuracy(params, images, labels):
    logits = batch_forward(params, images)
    predicted_class = jnp.argmax(logits, axis=1)
    true_class = jnp.argmax(labels, axis=1)
    return jnp.mean(predicted_class == true_class)

@jit
def update_step(params, images, labels, learning_rate):
    loss, grads = value_and_grad(batch_loss_fn)(params, images, labels)
    updated_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return updated_params, loss

# --- PyTorch Data Loading and Training Loop ---
def load_mnist_data_pytorch(batch_size=128):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=10000, shuffle=False)
    return train_loader, test_loader

if __name__ == '__main__':
    train_loader, test_loader = load_mnist_data_pytorch()
    # Network architecture: 784 -> 512 -> 10
    layer_sizes = [784, 512, 10]
    key = jax.random.PRNGKey(0)
    params = init_params(key, layer_sizes)
    epochs = 10
    learning_rate = 0.01
    print("Starting training...")
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images_np = images.numpy()
            labels_np = labels.numpy()
            images_np = images_np.reshape(images_np.shape[0], -1)
            labels_one_hot = jax.nn.one_hot(labels_np, 10)
            params, loss_val = update_step(params, images_np, labels_one_hot, learning_rate)
            epoch_loss += loss_val
            num_batches += 1
        avg_loss = epoch_loss / num_batches
        test_images_tensor, test_labels_tensor = next(iter(test_loader))
        test_images_np = test_images_tensor.numpy().reshape(-1, 784)
        test_labels_one_hot = jax.nn.one_hot(test_labels_tensor.numpy(), 10)
        test_acc = accuracy(params, test_images_np, test_labels_one_hot)
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f} | Test Accuracy: {test_acc:.4f}")
```
Results and Verification
Let’s see if our definitions worked as expected by running the training script:
```
Starting training...
Epoch 1/10 | Loss: 0.5898 | Test Accuracy: 0.9009
Epoch 2/10 | Loss: 0.3269 | Test Accuracy: 0.9187
Epoch 3/10 | Loss: 0.2807 | Test Accuracy: 0.9262
Epoch 4/10 | Loss: 0.2536 | Test Accuracy: 0.9323
Epoch 5/10 | Loss: 0.2341 | Test Accuracy: 0.9363
Epoch 6/10 | Loss: 0.2186 | Test Accuracy: 0.9387
Epoch 7/10 | Loss: 0.2056 | Test Accuracy: 0.9428
Epoch 8/10 | Loss: 0.1943 | Test Accuracy: 0.9448
Epoch 9/10 | Loss: 0.1844 | Test Accuracy: 0.9466
Epoch 10/10 | Loss: 0.1755 | Test Accuracy: 0.9487
```
The model reaches over 94% accuracy (94.87%) on the MNIST test set after just 10 epochs of training, a strong sign that our custom backpropagation implementations are functioning correctly.
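Test accuracy is indirect evidence. A gradient rule that is off by a constant factor, for example, could still train. A stricter check, sketched below, compares the gradients of the batched loss built from the custom operations (copied from the script above) with JAX's built-in autodiff of the same network written in plain `jax.numpy`. The hypothetical `custom_loss` and `reference_loss` helpers, the tiny layer sizes, the random weights, and the random batch are illustrative assumptions, not the MNIST setup.

```python
import jax
import jax.numpy as jnp
from jax import vmap


# --- Custom operations, copied from the training script above ---
@jax.custom_vjp
def Linear(W, x):
    return W @ x


def Linear_fwd(W, x):
    return Linear(W, x), (W, x)


def Linear_bwd(res, g):
    W, x = res
    return (jnp.outer(g, x), W.T @ g)


Linear.defvjp(Linear_fwd, Linear_bwd)


@jax.custom_vjp
def Tanh(z):
    return jnp.tanh(z)


def Tanh_fwd(z):
    a = Tanh(z)
    return a, (a,)


def Tanh_bwd(res, g):
    (a,) = res
    return (g * (1 - a**2),)


Tanh.defvjp(Tanh_fwd, Tanh_bwd)


@jax.custom_vjp
def softmax_cross_entropy(logits, label):
    exps = jnp.exp(logits - jnp.max(logits))
    probs = exps / jnp.sum(exps)
    return -jnp.sum(label * jnp.log(probs + 1e-7))


def softmax_cross_entropy_fwd(logits, label):
    exps = jnp.exp(logits - jnp.max(logits))
    probs = exps / jnp.sum(exps)
    return softmax_cross_entropy(logits, label), (probs, label)


def softmax_cross_entropy_bwd(res, g):
    probs, label_res = res
    return (g * (probs - label_res), None)


softmax_cross_entropy.defvjp(softmax_cross_entropy_fwd, softmax_cross_entropy_bwd)


def custom_loss(params, images, labels):
    # Same structure as forward_pass + batch_loss_fn in the script.
    W1, W2 = params

    def forward(image):
        return Linear(W2, Tanh(Linear(W1, image)))

    logits = vmap(forward)(images)
    return jnp.mean(vmap(softmax_cross_entropy)(logits, labels))


def reference_loss(params, images, labels):
    # Plain jax.numpy ops only, so JAX derives every gradient itself.
    W1, W2 = params
    logits = jnp.tanh(images @ W1.T) @ W2.T
    return jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=1))


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [jax.random.normal(k1, (16, 8)) * 0.3, jax.random.normal(k2, (4, 16)) * 0.3]
images = jax.random.normal(k3, (32, 8))
labels = jax.nn.one_hot(jax.random.randint(k4, (32,), 0, 4), 4)

custom_grads = jax.grad(custom_loss)(params, images, labels)
reference_grads = jax.grad(reference_loss)(params, images, labels)
for g_custom, g_ref in zip(custom_grads, reference_grads):
    # Should print True for both weight matrices.
    print(g_custom.shape, jnp.allclose(g_custom, g_ref, rtol=1e-4, atol=1e-5))
```

Because `W1` and `W2` are closed over (not batched) inside `vmap`, this also exercises the case where JAX must sum the per-example weight gradients produced by `Linear_bwd`.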
Let’s visualize some predictions from the trained model to see how well it performs.