Understanding the building blocks of reverse-mode automatic differentiation.
These notes were written for the implementation of reverse-mode automatic differentiation in Birch, specifically the NumBirch library that provides its numerics on both CPU and GPU. The motivation is to:
consolidate some essential gradients in one place, especially those commonly required for probabilistic modeling (e.g. Cholesky factorizations, logarithms of determinants), and
present those gradients in the form required for reverse-mode automatic differentiation.
While some gradients are simply stated and cited, others are derived from first principles using element-wise calculations as an aid to understanding.
In the simplest setting, we are interested in a function f that accepts a matrix argument X and returns a scalar result. A typical example is model training, where f is the objective function (e.g. log-likelihood or mean squared error), X is a matrix of parameters (e.g. weights and biases of a neural network), and we wish to compute ∂f/∂X, the gradient of f with respect to X. Having the gradient, we might then apply a gradient-based update to the parameters, as when optimizing with stochastic gradient descent or sampling with Langevin Monte Carlo.
We will focus in particular on reverse-mode automatic differentiation. For illustration, assume that f can be decomposed into a single chain of simpler functions f1,…,fn:
f(X)=(fn∘⋯∘f1)(X).
We first perform a forward pass to compute the intermediate results, first computing Y1=f1(X) then proceeding recursively:
Yi = fi(Yi−1) = (fi ∘ ⋯ ∘ f1)(X), for i = 2, …, n−1,
y = fn(Yn−1) = (fn ∘ ⋯ ∘ f1)(X) = f(X).
We then perform a backward pass to compute the gradient by applying the chain rule, first computing ∂f/∂Yn−1 then proceeding recursively:
∂f/∂Yi = ∂f/∂Yi+1 ⋅ ∂Yi+1/∂Yi, for i = n−2, …, 1,
∂f/∂X = ∂f/∂Y1 ⋅ ∂Y1/∂X.
The intermediate results Yi computed during the forward pass often appear in the computations required during the backward pass. They can be memoized for reuse rather than computed again.
A particular convenience of working in reverse mode for scalar functions is that each intermediate gradient has the same size (rows and columns) as the corresponding intermediate result, i.e. ∂f/∂Yi has the same size as Yi, and ∂f/∂X the same size as X.
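To make the two passes concrete, here is a minimal sketch in NumPy (not NumBirch; the function names and the choice of f1 and f2 are illustrative only) of a two-function chain, with the intermediate result memoized for the backward pass and the result checked against finite differences via SciPy.

```python
import numpy as np
from scipy.optimize import approx_fprime

# A two-function chain: f(X) = f2(f1(X)), with f1(X) = X A and f2(Y1) = sum of squares of Y1.
A = (np.arange(6.0).reshape(3, 2) + 1.0) / 10.0

def forward(X):
    Y1 = X @ A             # intermediate result, memoized for the backward pass
    y = np.sum(Y1 ** 2)    # scalar result
    return y, Y1

def backward(Y1):
    dY1 = 2.0 * Y1         # ∂f/∂Y1, same size as Y1
    dX = dY1 @ A.T         # ∂f/∂X = ∂f/∂Y1 · ∂Y1/∂X, same size as X
    return dX

X = np.arange(12.0).reshape(4, 3) / 10.0
y, Y1 = forward(X)
dX = backward(Y1)

# Compare with finite differences on the flattened argument.
fd = approx_fprime(X.ravel(), lambda x: forward(x.reshape(X.shape))[0], 1e-7).reshape(X.shape)
print(np.abs(dX - fd).max())   # should be near zero
```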
The above is a simplification in that it assumes f can be decomposed into a single chain. Generally, because intermediate functions may take multiple arguments, f will decompose into a tree or graph (often called a compute graph in machine learning, or an expression template elsewhere). The principle is nonetheless the same: we apply the chain rule along the branches of the graph.
We limit ourselves to reverse mode here, but it is not the only approach. Forward-mode automatic differentiation works in the opposite order, using a single forward pass to recursively compute Yi and ∂Yi/∂X simultaneously, but for scalar functions this forgoes the sizing convenience of reverse mode. Mixed mode denotes any other order, which can be designed to exploit specific structure in the compute graph.
Results
To derive a gradient of f with respect to X, we assume that the upstream gradient ∂f/∂g(X) is known for some g that denotes a matrix operation of interest, e.g. transpose g(X)=X⊤ or inverse g(X)=X−1. We then apply the chain rule element-wise to compute the partial derivative of f with respect to the (i,j)th element of X, denoted (X)ij:
∂f/∂(X)ij = ∑kl ∂f/∂g(X)kl ⋅ ∂g(X)kl/∂(X)ij.
We then simplify the element-wise expression into a matrix expression to allow the partial derivatives with respect to all elements to be computed simultaneously—also more efficiently using high-performance linear algebra routines—and so obtain the gradient.
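As a sketch of how the element-wise expression and its matrix form relate (NumPy rather than NumBirch, with illustrative names; G stands in for a given upstream gradient ∂f/∂g(X)), here the double sum is evaluated with explicit loops and compared against the consolidated form, taking g to be the transpose of the next section.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 3, 4
X = rng.standard_normal((m, n))
G = rng.standard_normal((n, m))   # stands in for ∂f/∂g(X) with g(X) = X.T, so it has the shape of X.T

# For g(X) = X.T, the partial ∂g(X)kl/∂(X)ij is 1 when (l, k) == (i, j) and 0 otherwise.
def dg_dX(k, l, i, j):
    return 1.0 if (l, k) == (i, j) else 0.0

# Element-wise chain rule: ∂f/∂(X)ij = sum over k, l of ∂f/∂g(X)kl · ∂g(X)kl/∂(X)ij.
dX = np.zeros_like(X)
for i in range(m):
    for j in range(n):
        dX[i, j] = sum(G[k, l] * dg_dX(k, l, i, j) for k in range(n) for l in range(m))

# Consolidated matrix form for the transpose: ∂f/∂X = (∂f/∂X.T).T.
print(np.abs(dX - G.T).max())   # exactly zero
```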
Transpose
The transpose is straightforward: the gradient of f with respect to X is the transpose of the gradient of f with respect to X⊤. But to demonstrate use of the above formulation, let f be a scalar function and ∂f/∂X⊤ be given; we then have, element-wise:
∂f/∂(X)ij = ∑kl ∂f/∂(X⊤)kl ⋅ ∂(X⊤)kl/∂(X)ij = ∑kl ∂f/∂(X⊤)kl ⋅ δkj δli = ∂f/∂(X⊤)ji,
where δij denotes the Kronecker delta (i.e. 1 when i=j and 0 otherwise). We recognize this as a simple transpose operation on all elements simultaneously:
∂f/∂X = (∂f/∂X⊤)⊤.
Multiplication
Let f be a scalar function and ∂f/∂XY be given. Element-wise, we have:
∂f/∂(X)ij = ∑kl ∂f/∂(XY)kl ⋅ ∂(XY)kl/∂(X)ij = ∑kl ∂f/∂(XY)kl ⋅ δki (Y)jl = ∑l ∂f/∂(XY)il (Y)jl.
We observe from this that ∂f/∂(X)ij is the dot product between the ith row of ∂f/∂XY and the jth row of Y, and consolidate in matrix form as:
∂f/∂X = (∂f/∂XY) Y⊤.
Next we wish to compute ∂f/∂Y. We can derive it using the same element-wise approach, or simply apply the transpose property above to obtain:
∂f/∂Y = X⊤ (∂f/∂XY).
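Both results can be sanity-checked numerically; in the NumPy/SciPy sketch below (illustrative names), f is chosen as a weighted sum of the elements of XY so that the upstream gradient ∂f/∂XY is exactly the weight matrix W.

```python
import numpy as np
from scipy.optimize import approx_fprime

rng = np.random.default_rng(2)
X = rng.standard_normal((3, 4))
Y = rng.standard_normal((4, 2))
W = rng.standard_normal((3, 2))   # f(X, Y) = sum(W * (X @ Y)), so ∂f/∂XY = W

dX = W @ Y.T                      # claimed ∂f/∂X
dY = X.T @ W                      # claimed ∂f/∂Y

fd_X = approx_fprime(X.ravel(), lambda x: np.sum(W * (x.reshape(3, 4) @ Y)), 1e-7).reshape(3, 4)
fd_Y = approx_fprime(Y.ravel(), lambda y: np.sum(W * (X @ y.reshape(4, 2))), 1e-7).reshape(4, 2)
print(np.abs(dX - fd_X).max(), np.abs(dY - fd_Y).max())   # both should be near zero
```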
Inverse
Let f be a scalar function and ∂f/∂X−1 be given. We have (Petersen & Pedersen, 2016):
∂f/∂X = −X−⊤ (∂f/∂X−1) X−⊤,
using the shorthand notation X−⊤ = (X−1)⊤ = (X⊤)−1. We can derive this element-wise, but it requires a non-obvious property given in (Petersen & Pedersen, 2016):
∂(X−1)kl/∂(X)ij = −(X−1)ki (X−1)jl.
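The matrix form can be checked the same way (NumPy/SciPy sketch, illustrative names; X is shifted to be well conditioned so that finite differences behave):

```python
import numpy as np
from scipy.optimize import approx_fprime

rng = np.random.default_rng(3)
n = 4
X = rng.standard_normal((n, n)) + n * np.eye(n)   # keep X well away from singular
W = rng.standard_normal((n, n))                   # f(X) = sum(W * inv(X)), so ∂f/∂X−1 = W

Xinv = np.linalg.inv(X)
dX = -Xinv.T @ W @ Xinv.T                         # claimed ∂f/∂X

fd = approx_fprime(X.ravel(), lambda x: np.sum(W * np.linalg.inv(x.reshape(n, n))), 1e-7).reshape(n, n)
print(np.abs(dX - fd).max())                      # should be near zero
```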
Inverse of a symmetric positive definite matrix
Let f be a scalar function and ∂f/∂S−1 be given, where S is a symmetric positive definite matrix with Cholesky decomposition S = LL⊤, and so inverse S−1 = L−⊤L−1, where L is a lower-triangular matrix. We use the transpose, multiplication and inverse properties to obtain:
∂f/∂L−1 = L−1 (∂f/∂S−1 + (∂f/∂S−1)⊤)
∂f/∂L = −L−⊤ (∂f/∂L−1) L−⊤ = −S−1 (∂f/∂S−1 + (∂f/∂S−1)⊤) L−⊤.
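Since closed forms like this are easy to get wrong, a numerical check helps (NumPy/SciPy sketch, illustrative names; the gradient is formed densely here, without projecting onto the lower triangle of L):

```python
import numpy as np
from scipy.optimize import approx_fprime

rng = np.random.default_rng(4)
n = 4
L = np.tril(rng.standard_normal((n, n))) + n * np.eye(n)   # lower triangular with positive diagonal
W = rng.standard_normal((n, n))                            # f = sum(W * inv(L @ L.T)), so ∂f/∂S−1 = W

Sinv = np.linalg.inv(L @ L.T)
dL = -Sinv @ (W + W.T) @ np.linalg.inv(L).T                # claimed ∂f/∂L (dense form)

fd = approx_fprime(L.ravel(),
                   lambda l: np.sum(W * np.linalg.inv(l.reshape(n, n) @ l.reshape(n, n).T)),
                   1e-7).reshape(n, n)
print(np.abs(dL - fd).max())                               # should be near zero
```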
Solve
When computing, matrix inversions X−1 are replaced with the solutions of equations Y = XZ wherever possible, for numerical stability. We can think of the solution as Z = X−1Y.
Let f be a scalar function and ∂f/∂X−1Y be given. We use the multiplication and inverse properties above to obtain:
∂f/∂Y = X−⊤ (∂f/∂X−1Y)
∂f/∂X = −X−⊤ (∂f/∂X−1Y) (X−1Y)⊤
 = −(∂f/∂Y) (X−1Y)⊤
 = −(∂f/∂Y) Z⊤.
In the second last line we substitute in ∂f/∂Y from the first line, and in the last line Z = X−1Y from the forward pass, to avoid repeating expensive computations.
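The same sanity check applies here (NumPy/SciPy sketch, illustrative names), and also shows the inversions being expressed as solves and the forward-pass Z being reused:

```python
import numpy as np
from scipy.optimize import approx_fprime

rng = np.random.default_rng(5)
n, p = 4, 2
X = rng.standard_normal((n, n)) + n * np.eye(n)
Y = rng.standard_normal((n, p))
W = rng.standard_normal((n, p))    # f = sum(W * solve(X, Y)), so ∂f/∂X−1Y = W

Z = np.linalg.solve(X, Y)          # forward pass: Z = X−1Y, memoized
dY = np.linalg.solve(X.T, W)       # ∂f/∂Y = X−⊤ (∂f/∂X−1Y), computed as a solve
dX = -dY @ Z.T                     # ∂f/∂X = −(∂f/∂Y) Z⊤

fd_X = approx_fprime(X.ravel(), lambda x: np.sum(W * np.linalg.solve(x.reshape(n, n), Y)), 1e-7).reshape(n, n)
fd_Y = approx_fprime(Y.ravel(), lambda y: np.sum(W * np.linalg.solve(X, y.reshape(n, p))), 1e-7).reshape(n, p)
print(np.abs(dX - fd_X).max(), np.abs(dY - fd_Y).max())   # both should be near zero
```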
Solve with symmetric positive definite matrix
Let f be a scalar function and ∂f/∂S−1Y be given, where S is a symmetric positive definite matrix with Cholesky decomposition S = LL⊤, and L a lower-triangular matrix. Using the solve and multiplication properties above, and with Z = S−1Y from the forward pass, we obtain:
∂f/∂Y = L−⊤L−1 (∂f/∂S−1Y)
∂f/∂L = −tril(((∂f/∂Y) Z⊤ + Z (∂f/∂Y)⊤) L),
where tril is a function that returns a matrix with its lower triangle set to that of its argument, and its strictly upper triangle set to zero. This form gives the gradient in the space of lower-triangular matrices, rather than the space of arbitrary dense matrices, to ensure that a gradient update maintains L as lower triangular.
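As before, a numerical check (NumPy/SciPy sketch, illustrative names); the finite-difference gradient is taken over a dense L and compared on the lower triangle only, since the strictly upper triangle of L is structurally zero:

```python
import numpy as np
from scipy.optimize import approx_fprime

rng = np.random.default_rng(6)
n, p = 4, 2
L = np.tril(rng.standard_normal((n, n))) + n * np.eye(n)
Y = rng.standard_normal((n, p))
W = rng.standard_normal((n, p))                  # f = sum(W * solve(L @ L.T, Y)), so ∂f/∂S−1Y = W

S = L @ L.T
Z = np.linalg.solve(S, Y)                        # forward pass
dY = np.linalg.solve(S, W)                       # ∂f/∂Y = L−⊤L−1 (∂f/∂S−1Y)
dL = -np.tril((dY @ Z.T + Z @ dY.T) @ L)         # ∂f/∂L, kept lower triangular

fd = approx_fprime(L.ravel(),
                   lambda l: np.sum(W * np.linalg.solve(l.reshape(n, n) @ l.reshape(n, n).T, Y)),
                   1e-7).reshape(n, n)
print(np.abs(dL - np.tril(fd)).max())            # should be near zero
```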
Trace
Let f be a scalar function and ∂f/∂tr(X) be given. We have:
∂f/∂(X)ij = ∂f/∂tr(X) ⋅ ∂tr(X)/∂(X)ij = ∂f/∂tr(X) ⋅ δij.
Or in matrix form:
∂f/∂X = ∂f/∂tr(X) ⋅ I.
Logarithm of the determinant
Let f be a scalar function and ∂f/∂det(X) be given. We have (Petersen & Pedersen, 2016):
∂f/∂X = ∂f/∂det(X) ⋅ ∂det(X)/∂X = ∂f/∂det(X) ⋅ det(X) X−⊤.
By extension, and assuming that the determinant is positive so that we can take the logarithm, let ∂f/∂log(det(X)) be given; we then have:
∂f/∂X = ∂f/∂log(det(X)) ⋅ X−⊤.
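For the log-determinant in particular, a quick check (NumPy/SciPy sketch; the scalar g stands in for the given upstream gradient ∂f/∂log(det(X)), and X is constructed so its determinant is safely positive):

```python
import numpy as np
from scipy.optimize import approx_fprime

rng = np.random.default_rng(7)
n = 4
X = rng.standard_normal((n, n)) + n * np.eye(n)   # determinant safely positive
g = 0.7                                           # stands in for ∂f/∂log(det(X))

dX = g * np.linalg.inv(X).T                       # claimed ∂f/∂X

fd = approx_fprime(X.ravel(), lambda x: g * np.linalg.slogdet(x.reshape(n, n))[1], 1e-7).reshape(n, n)
print(np.abs(dX - fd).max())                      # should be near zero
```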
Logarithm of the determinant of a symmetric positive definite matrix
Let S be a symmetric positive definite matrix with Cholesky decomposition S = LL⊤, with L a lower-triangular matrix with positive elements along its main diagonal. The determinant of L (or indeed of any triangular matrix) is the product of the elements along its main diagonal, and so the logarithm of the determinant is the sum of the logarithms of those elements.
Let f be a scalar function and ∂f/∂log(det(L)) be given. We then have:
∂f/∂L = ∂f/∂log(det(L)) ⋅ diag(L)−1,
where diag is a function that returns a diagonal matrix with its main diagonal set to that of the argument, and all off-diagonal elements set to zero.
As det(S)=det(L)2, we have log(det(S))=2log(det(L)) and so:
∂f/∂L = 2 ∂f/∂log(det(S)) ⋅ diag(L)−1.
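A check of this result (NumPy/SciPy sketch; g stands in for ∂f/∂log(det(S))), comparing against finite differences of log(det(LL⊤)) on the lower triangle of L, where its free elements live:

```python
import numpy as np
from scipy.optimize import approx_fprime

rng = np.random.default_rng(8)
n = 4
L = np.tril(rng.standard_normal((n, n))) + n * np.eye(n)   # lower triangular with positive diagonal
g = 0.7                                                    # stands in for ∂f/∂log(det(S))

dL = 2.0 * g * np.diag(1.0 / np.diag(L))                   # claimed ∂f/∂L

fd = approx_fprime(L.ravel(),
                   lambda l: g * np.linalg.slogdet(l.reshape(n, n) @ l.reshape(n, n).T)[1],
                   1e-7).reshape(n, n)
print(np.abs(dL - np.tril(fd)).max())                      # should be near zero
```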
Cholesky factorization
Let f be a scalar function and ∂f/∂L be given, where S = LL⊤. Define (Murray, 2016):
P = Φ(L⊤ (∂f/∂L)),
where Φ is a function that returns a lower-triangular matrix with the strictly lower triangle set to that of its argument, the main diagonal set to half that of its argument, and the upper triangle set to zero. We then have (Murray, 2016):
∂f/∂S = Φ(L−⊤ (P + P⊤) L−1).
This form gives the gradient in the space of symmetric matrices, rather than the space of arbitrary dense matrices. A gradient update on S will only affect the lower triangle. This is suitable when using linear algebra routines for symmetric matrices, which typically reference only one triangle (in this case, the lower) and assume that the other matches by symmetry. If a dense matrix must be formed, however, the strictly lower triangle should be transposed and copied into the strictly upper.
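Putting this into code (a NumPy/SciPy sketch; phi and chol_grad are illustrative names, not NumBirch functions), with a finite-difference check in which only the lower triangle of S is free and is mirrored to keep S symmetric:

```python
import numpy as np
from scipy.optimize import approx_fprime

def phi(A):
    # Lower triangle of A with the main diagonal halved (the Φ above).
    return np.tril(A) - 0.5 * np.diag(np.diag(A))

def chol_grad(L, dL):
    # Reverse-mode rule: given L = chol(S) and ∂f/∂L, return ∂f/∂S = Φ(L−⊤ (P + P⊤) L−1).
    P = phi(L.T @ dL)
    return phi(np.linalg.solve(L.T, np.linalg.solve(L.T, P + P.T).T))   # solves rather than explicit inverses

rng = np.random.default_rng(9)
n = 4
A = rng.standard_normal((n, n))
S = A @ A.T + n * np.eye(n)                      # symmetric positive definite
W = np.tril(rng.standard_normal((n, n)))         # f(S) = sum(W * chol(S)), so ∂f/∂L = W
dS = chol_grad(np.linalg.cholesky(S), W)

# Finite differences: only the lower triangle of S is perturbed; mirror it before factorizing.
def f_of_lower(s):
    M = s.reshape(n, n)
    return np.sum(W * np.linalg.cholesky(np.tril(M) + np.tril(M, -1).T))

fd = approx_fprime(S.ravel(), f_of_lower, 1e-7).reshape(n, n)
print(np.abs(dS - np.tril(fd)).max())            # should be near zero
```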