
Matrix Gradients of Scalar Functions

Understanding the building blocks of reverse-mode automatic differentiation.

These notes were written for the implementation of reverse-mode automatic differentiation in Birch, specifically the NumBirch library that provides its numerics on both CPU and GPU. The motivation is to:

  1. consolidate some essential gradients in one place, especially those commonly required for probabilistic modeling (e.g. Cholesky factorizations, logarithms of determinants), and

  2. present those gradients in the form required for reverse-mode automatic differentiation.

While some gradients are simply stated and cited, others are derived from first principles using element-wise calculations as an aid to understanding.

  1. Background
  2. Results
    1. Transpose
    2. Multiplication
    3. Inverse
    4. Inverse of symmetric positive definite matrix
    5. Solve
    6. Solve with symmetric positive definite matrix
    7. Trace
    8. Logarithm of the determinant
    9. Logarithm of the determinant of a symmetric positive definite matrix
    10. Cholesky factorization
  3. References

Background

In the simplest setting, we are interested in a function $f$ that accepts a matrix argument $\mathbf{X}$ and returns a scalar result. A typical setting is model training, where $f$ is the objective function (e.g. log-likelihood or mean squared error), $\mathbf{X}$ are some parameters (e.g. weights and biases of a neural network), and we wish to compute $\partial f/\partial\mathbf{X}$, being the gradient of $f$ with respect to $\mathbf{X}$. Having the gradient, we then might apply a gradient-based update to the parameters, as when optimizing with stochastic gradient descent or sampling with Langevin Monte Carlo.

We will focus in particular on reverse-mode automatic differentiation. For illustration, assume that $f$ can be decomposed into a single chain of simpler functions $f_{1},\ldots,f_{n}$:

$$f(\mathbf{X})=(f_{n}\circ\cdots\circ f_{1})(\mathbf{X}).$$

We first perform a forward pass to compute the intermediate results, first computing $\mathbf{Y}_{1}=f_{1}(\mathbf{X})$ then proceeding recursively:

$$\begin{aligned} \mathbf{Y}_{i} & =f_{i}(\mathbf{Y}_{i-1})=(f_{i}\circ\cdots\circ f_{1})(\mathbf{X})\text{ for }i=2,\ldots,n-1\\ y & =f_{n}(\mathbf{Y}_{n-1})=(f_{n}\circ\cdots\circ f_{1})(\mathbf{X})=f(\mathbf{X}).\end{aligned}$$

We then perform a backward pass to compute the gradient by applying the chain rule, first computing $\partial f/\partial\mathbf{Y}_{n-1}$ then proceeding recursively:

$$\begin{aligned} \frac{\partial f}{\partial\mathbf{Y}_{i}} & =\frac{\partial f}{\partial\mathbf{Y}_{i+1}}\cdot\frac{\partial\mathbf{Y}_{i+1}}{\partial\mathbf{Y}_{i}}\text{ for }i=n-2,\ldots,1\\ \frac{\partial f}{\partial\mathbf{X}} & =\frac{\partial f}{\partial\mathbf{Y}_{1}}\cdot\frac{\partial\mathbf{Y}_{1}}{\partial\mathbf{X}}.\end{aligned}$$

The intermediate results $\mathbf{Y}_{i}$ computed during the forward pass often appear in the computations required during the backward pass. They can be memoized for reuse rather than computed again.

A particular convenience of working in reverse mode for scalar functions is that each intermediate gradient has the same size (rows and columns) as the corresponding intermediate result, i.e. $\partial f/\partial\mathbf{Y}_{i}$ has the same size as $\mathbf{Y}_{i}$, and $\partial f/\partial\mathbf{X}$ the same size as $\mathbf{X}$.

The above is a simplification in that it assumes $f$ can be decomposed into a single chain. In general, because intermediate functions may take multiple arguments, it will decompose into a tree or graph (often called a compute graph in machine learning, or an expression template elsewhere). The principle is nonetheless the same: we apply the chain rule along branches of the graph.

We limit ourselves to reverse mode here, but it is not the only approach. Forward-mode automatic differentiation works in the opposite order, using a single forward pass to recursively compute $\mathbf{Y}_{i}$ and $\partial\mathbf{Y}_{i}/\partial\mathbf{X}$ simultaneously; for scalar functions, however, this forgoes the sizing convenience of reverse mode. Mixed mode denotes any other order, which can be designed to exploit specific structure in the compute graph.
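To make the forward and backward passes concrete, here is a minimal NumPy sketch (illustrative only, not NumBirch code) of a two-function chain $f(\mathbf{X})=\mathrm{tr}(\mathbf{X}\mathbf{B})$ for an arbitrary constant $\mathbf{B}$ of my own choosing. The backward pass uses the trace and multiplication gradients derived in the Results section below, and each gradient has the same size as the corresponding intermediate result:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 3))     # an arbitrary constant matrix

# forward pass: memoize the intermediate result
Y1 = X @ B                          # f1: matrix multiplication
y = np.trace(Y1)                    # f2: trace, giving the scalar result

# backward pass: apply the chain rule in reverse
dY1 = np.eye(3)                     # df/dY1, from the Trace result below
dX = dY1 @ B.T                      # df/dX, from the Multiplication result below

# each gradient has the same shape as the corresponding value
assert dY1.shape == Y1.shape and dX.shape == X.shape

# check df/dX against central finite differences
eps = 1e-6
fd = np.zeros_like(X)
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        E = np.zeros_like(X)
        E[i, j] = eps
        fd[i, j] = (np.trace((X + E) @ B) - np.trace((X - E) @ B)) / (2*eps)
print(np.allclose(dX, fd, atol=1e-5))   # True
```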

Results

To derive a gradient of $f$ with respect to $\mathbf{X}$, we assume that the upstream gradient $\partial f/\partial g(\mathbf{X})$ is known for some $g$ that denotes a matrix operation of interest, e.g. transpose $g(\mathbf{X})=\mathbf{X}^{\top}$ or inverse $g(\mathbf{X})=\mathbf{X}^{-1}$. We then apply the chain rule element-wise to compute the partial derivative of $f$ with respect to the $(i,j)$th element of $\mathbf{X}$, denoted $(\mathbf{X})_{ij}$:

$$\frac{\partial f}{\partial(\mathbf{X})_{ij}}=\sum_{kl}\frac{\partial f}{\partial g(\mathbf{X})_{kl}}\cdot\frac{\partial g(\mathbf{X})_{kl}}{\partial(\mathbf{X})_{ij}}.$$

We then simplify the element-wise expression into a matrix expression to allow the partial derivatives with respect to all elements to be computed simultaneously—also more efficiently using high-performance linear algebra routines—and so obtain the gradient.

Transpose

The transpose is straightforward: the gradient of $f$ with respect to $\mathbf{X}$ is the transpose of the gradient of $f$ with respect to $\mathbf{X}^{\top}$. But to demonstrate use of the above formulation, let $f$ be a scalar function and $\partial f/\partial\mathbf{X}^{\top}$ be given; we then have, element-wise:

$$\begin{aligned} \frac{\partial f}{\partial(\mathbf{X})_{ij}} & =\sum_{kl}\frac{\partial f}{\partial\left(\mathbf{X}^{\top}\right)_{kl}}\cdot\frac{\partial\left(\mathbf{X}^{\top}\right)_{kl}}{\partial(\mathbf{X})_{ij}}\\ & =\sum_{kl}\frac{\partial f}{\partial\left(\mathbf{X}^{\top}\right)_{kl}}\cdot\delta_{kj}\delta_{li}\\ & =\frac{\partial f}{\partial\left(\mathbf{X}^{\top}\right)_{ji}},\end{aligned}$$

where $\delta_{ij}$ denotes the Kronecker delta (i.e. 1 when $i=j$ and 0 otherwise). We recognize this as a simple transpose operation on all elements simultaneously:

$$\frac{\partial f}{\partial\mathbf{X}}=\left(\frac{\partial f}{\partial\mathbf{X}^{\top}}\right)^{\top}.$$
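A convenient way to sanity-check results like this numerically is to take $f(\mathbf{X})=\sum_{kl}(\mathbf{C})_{kl}\,g(\mathbf{X})_{kl}$ for an arbitrary constant matrix $\mathbf{C}$, so that the upstream gradient $\partial f/\partial g(\mathbf{X})$ is exactly $\mathbf{C}$, and compare the claimed matrix form against central finite differences. A minimal NumPy sketch along these lines (the helper `fd_grad` is my own, not part of any library):

```python
import numpy as np

def fd_grad(f, X, eps=1e-6):
    """Central finite-difference approximation of df/dX for scalar-valued f."""
    G = np.zeros_like(X)
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            E = np.zeros_like(X)
            E[i, j] = eps
            G[i, j] = (f(X + E) - f(X - E)) / (2.0*eps)
    return G

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))
C = rng.standard_normal((4, 3))                  # plays the role of df/dX^T

claimed = C.T                                    # df/dX = (df/dX^T)^T
numeric = fd_grad(lambda A: np.sum(C * A.T), X)  # f(X) = sum_kl C_kl (X^T)_kl
print(np.allclose(claimed, numeric, atol=1e-6))  # True
```

The same pattern, reduced to a spot-check of a single entry, is used for the results that follow.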

Multiplication

Let $f$ be a scalar function and $\partial f/\partial\mathbf{XY}$ be given. Element-wise, we have:

$$\begin{aligned} \frac{\partial f}{\partial(\mathbf{X})_{ij}} & =\sum_{kl}\frac{\partial f}{\partial(\mathbf{XY})_{kl}}\cdot\frac{\partial(\mathbf{XY})_{kl}}{\partial(\mathbf{X})_{ij}}\\ & =\sum_{kl}\frac{\partial f}{\partial(\mathbf{XY})_{kl}}\delta_{ki}(\mathbf{Y})_{jl}\\ & =\sum_{l}\frac{\partial f}{\partial(\mathbf{XY})_{il}}(\mathbf{Y})_{jl}.\end{aligned}$$

We observe from this that $\partial f/\partial(\mathbf{X})_{ij}$ is the dot product between the $i$th row of $\partial f/\partial\mathbf{X}\mathbf{Y}$ and the $j$th row of $\mathbf{Y}$, and consolidate in matrix form as:

$$\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\mathbf{XY}}\mathbf{Y}^{\top}.$$

Next we wish to compute $\partial f/\partial\mathbf{Y}$. We can derive it using the same element-wise approach, or simply apply the transpose property above to obtain:

$$\frac{\partial f}{\partial\mathbf{Y}}=\mathbf{X}^{\top}\frac{\partial f}{\partial\mathbf{XY}}.$$
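As a NumPy sketch (illustrative only; `G` stands in for the given upstream gradient $\partial f/\partial\mathbf{XY}$):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))
Y = rng.standard_normal((4, 5))
G = rng.standard_normal((3, 5))   # upstream gradient df/d(XY), same shape as XY

dX = G @ Y.T                      # df/dX, same shape as X
dY = X.T @ G                      # df/dY, same shape as Y

# spot-check one entry of df/dX against a finite difference of
# f(X, Y) = sum(G * (X @ Y))
eps = 1e-6
E = np.zeros_like(X)
E[1, 2] = eps
fd = (np.sum(G * ((X + E) @ Y)) - np.sum(G * ((X - E) @ Y))) / (2*eps)
print(np.isclose(dX[1, 2], fd))   # True
```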

Inverse

Let $f$ be a scalar function and $\partial f/\partial\mathbf{X}^{-1}$ be given. We have (Petersen & Pedersen, 2016):

$$\frac{\partial f}{\partial\mathbf{X}}=-\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}}\mathbf{X}^{-\top},$$

using the shorthand notation $\mathbf{X}^{-\top}=\left(\mathbf{X}^{-1}\right)^{\top}=\left(\mathbf{X}^{\top}\right)^{-1}$. We can derive this element-wise, but it requires a non-obvious property given in (Petersen & Pedersen, 2016):

$$\frac{\partial(\mathbf{X}^{-1})_{kl}}{\partial(\mathbf{X})_{ij}}=-(\mathbf{X}^{-1})_{ki}(\mathbf{X}^{-1})_{jl}.$$

Using this, the derivation proceeds:

$$\begin{aligned} \frac{\partial f}{\partial(\mathbf{X})_{ij}} & =\sum_{kl}\frac{\partial f}{\partial(\mathbf{X}^{-1})_{kl}}\cdot\frac{\partial(\mathbf{X}^{-1})_{kl}}{\partial(\mathbf{X})_{ij}}\\ & =-\sum_{kl}\frac{\partial f}{\partial(\mathbf{X}^{-1})_{kl}}(\mathbf{X}^{-1})_{ki}(\mathbf{X}^{-1})_{jl}\\ & =-\sum_{k}(\mathbf{X}^{-1})_{ki}\sum_{l}\frac{\partial f}{\partial(\mathbf{X}^{-1})_{kl}}(\mathbf{X}^{-1})_{jl},\end{aligned}$$

from which we recognize the matrix form above.
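In NumPy the result reads as follows (an illustrative sketch; `G` stands in for the upstream gradient $\partial f/\partial\mathbf{X}^{-1}$, and the shift by $4\mathbf{I}$ merely keeps $\mathbf{X}$ comfortably invertible):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 4)) + 4.0*np.eye(4)   # keep X well away from singular
G = rng.standard_normal((4, 4))                   # upstream gradient df/dX^{-1}

Xinv = np.linalg.inv(X)
dX = -Xinv.T @ G @ Xinv.T                         # df/dX

# spot-check one entry against a finite difference of f(X) = sum(G * inv(X))
eps = 1e-6
E = np.zeros((4, 4))
E[2, 1] = eps
fd = (np.sum(G * np.linalg.inv(X + E)) - np.sum(G * np.linalg.inv(X - E))) / (2*eps)
print(np.isclose(dX[2, 1], fd, atol=1e-5))        # True
```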

Inverse of symmetric positive definite matrix

Let $f$ be a scalar function and $\partial f/\partial\mathbf{S}^{-1}$ be given, where $\mathbf{S}$ is a symmetric positive definite matrix with Cholesky decomposition $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$, and so inverse $\mathbf{S}^{-1}=\mathbf{L}^{-\top}\mathbf{L}^{-1}$, where $\mathbf{L}$ is a lower-triangular matrix. We use the transpose, multiplication and inverse properties to obtain:

$$\begin{aligned} \frac{\partial f}{\partial\mathbf{L}} & =-\mathbf{L}^{-\top}\frac{\partial f}{\partial\mathbf{L}^{-1}}\mathbf{L}^{-\top}\\ & =-\mathbf{L}^{-\top}\left(\left(\frac{\partial f}{\partial\mathbf{S}^{-1}}\mathbf{L}^{-\top}\right)^{\top}+\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-1}}\right)\mathbf{L}^{-\top}\\ & =-\mathbf{L}^{-\top}\left(\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-\top}}+\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-1}}\right)\mathbf{L}^{-\top}\\ & =-\mathbf{L}^{-\top}\mathbf{L}^{-1}\left(\frac{\partial f}{\partial\mathbf{S}^{-\top}}+\frac{\partial f}{\partial\mathbf{S}^{-1}}\right)\mathbf{L}^{-\top}.\end{aligned}$$
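A NumPy sketch of this result (illustrative only; `G` stands in for $\partial f/\partial\mathbf{S}^{-1}$), with a finite-difference spot-check of one lower-triangular entry:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
S = A @ A.T + 4.0*np.eye(4)        # symmetric positive definite
L = np.linalg.cholesky(S)          # lower triangular, S = L L^T
G = rng.standard_normal((4, 4))    # upstream gradient df/dS^{-1}

Linv = np.linalg.inv(L)
dL = -Linv.T @ Linv @ (G.T + G) @ Linv.T   # df/dL

# spot-check a lower-triangular entry against a finite difference of
# f(L) = sum(G * inv(L @ L.T))
eps = 1e-6
E = np.zeros((4, 4))
E[2, 1] = eps
f = lambda M: np.sum(G * np.linalg.inv(M @ M.T))
fd = (f(L + E) - f(L - E)) / (2*eps)
print(np.isclose(dL[2, 1], fd, atol=1e-5))   # True
```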

Solve

In computation, matrix inversions $\mathbf{X}^{-1}$ are replaced wherever possible with the solution of a system of equations $\mathbf{Y}=\mathbf{X}\mathbf{Z}$, for numerical stability. We can think of this as $\mathbf{Z}=\mathbf{X}^{-1}\mathbf{Y}$.

Let $f$ be a scalar function and $\partial f/\partial\mathbf{X}^{-1}\mathbf{Y}$ be given. We use the multiplication and inverse properties above to obtain:

$$\begin{aligned} \frac{\partial f}{\partial\mathbf{Y}} & =\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}\mathbf{Y}}\\ \frac{\partial f}{\partial\mathbf{X}} & =-\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}}\mathbf{X}^{-\top}\\ & =-\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}\mathbf{Y}}\mathbf{Y}^{\top}\mathbf{X}^{-\top}\\ & =-\frac{\partial f}{\partial\mathbf{Y}}(\mathbf{X}^{-1}\mathbf{Y})^{\top}\\ & =-\frac{\partial f}{\partial\mathbf{Y}}\mathbf{Z}^{\top}.\end{aligned}$$

In the second last line we substitute in $\partial f/\partial\mathbf{Y}$ from the first line, and in the last line $\mathbf{Z}=\mathbf{X}^{-1}\mathbf{Y}$ from the forward pass, to avoid repeating expensive computations.
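A NumPy sketch of these two results (illustrative only; `G` stands in for the given upstream gradient), using solves rather than explicit inverses and reusing $\mathbf{Z}$ from the forward pass:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 4)) + 4.0*np.eye(4)
Y = rng.standard_normal((4, 3))
G = rng.standard_normal((4, 3))       # upstream gradient df/d(X^{-1} Y)

Z = np.linalg.solve(X, Y)             # forward pass: Z = X^{-1} Y, memoized

dY = np.linalg.solve(X.T, G)          # df/dY = X^{-T} df/d(X^{-1} Y)
dX = -dY @ Z.T                        # df/dX = -(df/dY) Z^T

# spot-check one entry of df/dX against a finite difference of
# f(X) = sum(G * solve(X, Y))
eps = 1e-6
E = np.zeros((4, 4))
E[3, 0] = eps
f = lambda M: np.sum(G * np.linalg.solve(M, Y))
fd = (f(X + E) - f(X - E)) / (2*eps)
print(np.isclose(dX[3, 0], fd, atol=1e-5))   # True
```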

Solve with symmetric positive definite matrix

Let $f$ be a scalar function and $\partial f/\partial\mathbf{S}^{-1}\mathbf{Y}$ be given, where $\mathbf{S}$ is a symmetric positive definite matrix with Cholesky decomposition $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$, and $\mathbf{L}$ a lower-triangular matrix. Using the solve and multiplication properties above we obtain:

$$\begin{aligned} \frac{\partial f}{\partial\mathbf{Y}} & =\mathbf{L}^{-\top}\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-1}\mathbf{Y}}\\ \frac{\partial f}{\partial\mathbf{L}} & =\mathrm{tril}\left(\frac{\partial f}{\partial\mathbf{S}}\mathbf{L}+\left(\mathbf{L}^{\top}\frac{\partial f}{\partial\mathbf{S}}\right)^{\top}\right)\\ & =\mathrm{tril}\left(\left(\frac{\partial f}{\partial\mathbf{S}}+\frac{\partial f}{\partial\mathbf{S}^{\top}}\right)\mathbf{L}\right)\\ & =\mathrm{tril}\left(-\left(\frac{\partial f}{\partial\mathbf{Y}}\mathbf{Z}^{\top}+\mathbf{Z}\left(\frac{\partial f}{\partial\mathbf{Y}}\right)^{\top}\right)\mathbf{L}\right),\end{aligned}$$

where $\mathbf{Z}=\mathbf{S}^{-1}\mathbf{Y}$ is memoized from the forward pass, and $\mathrm{tril}$ is a function that returns a matrix with its lower triangle set to that of its argument, and its strictly upper triangle set to zero. This form gives the gradient in the space of lower-triangular matrices, rather than the space of arbitrary dense matrices, to ensure that a gradient update maintains $\mathbf{L}$ as lower triangular.
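A NumPy sketch of this pair of results (illustrative only; the upstream gradient `G` and the construction of $\mathbf{S}$ are arbitrary choices of mine), again with a finite-difference spot-check of a lower-triangular entry:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
S = A @ A.T + 4.0*np.eye(4)           # symmetric positive definite
L = np.linalg.cholesky(S)             # S = L L^T
Y = rng.standard_normal((4, 3))
G = rng.standard_normal((4, 3))       # upstream gradient df/d(S^{-1} Y)

# forward pass: Z = S^{-1} Y via two triangular solves, memoized
Z = np.linalg.solve(L.T, np.linalg.solve(L, Y))

dY = np.linalg.solve(L.T, np.linalg.solve(L, G))   # df/dY = L^{-T} L^{-1} df/d(S^{-1} Y)
dL = np.tril(-(dY @ Z.T + Z @ dY.T) @ L)           # df/dL, kept lower triangular

# spot-check a lower-triangular entry against a finite difference of
# f(L) = sum(G * solve(L @ L.T, Y))
eps = 1e-6
E = np.zeros((4, 4))
E[2, 1] = eps
f = lambda M: np.sum(G * np.linalg.solve(M @ M.T, Y))
fd = (f(L + E) - f(L - E)) / (2*eps)
print(np.isclose(dL[2, 1], fd, atol=1e-5))   # True
```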

Trace

Let $f$ be a scalar function and $\partial f/\partial\mathrm{tr}(\mathbf{X})$ be given. We have:

$$\frac{\partial f}{\partial(\mathbf{X})_{ij}}=\frac{\partial f}{\partial\mathrm{tr}(\mathbf{X})}\cdot\frac{\partial\mathrm{tr}(\mathbf{X})}{\partial(\mathbf{X})_{ij}}=\frac{\partial f}{\partial\mathrm{tr}(\mathbf{X})}\delta_{ij}.$$

Or in matrix form:

$$\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\mathrm{tr}(\mathbf{X})}\,\mathbf{I}.$$
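For example, in NumPy (a sketch only; the scalar `g` stands in for the upstream gradient $\partial f/\partial\mathrm{tr}(\mathbf{X})$):

```python
import numpy as np

g = 2.5                    # upstream gradient df/dtr(X), a scalar
n = 4
dX = g * np.eye(n)         # df/dX = df/dtr(X) * I

# e.g. for f(X) = g*tr(X): perturbing a diagonal element changes f at rate g,
# while perturbing an off-diagonal element leaves f unchanged.
```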

Logarithm of the determinant

Let $f$ be a scalar function and $\partial f/\partial\det(\mathbf{X})$ be given. We have (Petersen & Pedersen, 2016):

$$\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\det(\mathbf{X})}\cdot\frac{\partial\det(\mathbf{X})}{\partial\mathbf{X}}=\frac{\partial f}{\partial\det(\mathbf{X})}\det(\mathbf{X})\mathbf{X}^{-\top}.$$

By extension, and assuming that the determinant is positive so that we can take the logarithm, let $\partial f/\partial\log\left(\det(\mathbf{X})\right)$ be given:

$$\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\log\left(\det(\mathbf{X})\right)}\cdot\frac{\partial\log\left(\det(\mathbf{X})\right)}{\partial\mathbf{X}}=\frac{\partial f}{\partial\log\left(\det(\mathbf{X})\right)}\mathbf{X}^{-\top}.$$

These seem non-trivial to derive element-wise.
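A quick numerical check of the log-determinant form in NumPy (a sketch; the shift by $4\mathbf{I}$ is only there to keep the example well conditioned with a positive determinant):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 4)) + 4.0*np.eye(4)   # keeps det(X) positive here
g = rng.standard_normal()                         # upstream gradient df/dlog(det(X))

dX = g * np.linalg.inv(X).T                       # df/dX = df/dlog(det(X)) X^{-T}

# spot-check one entry against a finite difference of f(X) = g*log(det(X)),
# computed via slogdet for numerical stability
eps = 1e-6
E = np.zeros((4, 4))
E[0, 2] = eps
f = lambda M: g * np.linalg.slogdet(M)[1]
fd = (f(X + E) - f(X - E)) / (2*eps)
print(np.isclose(dX[0, 2], fd, atol=1e-5))        # True
```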

Logarithm of the determinant of a symmetric positive definite matrix

Let $\mathbf{S}$ be a symmetric positive definite matrix with Cholesky decomposition $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$, with $\mathbf{L}$ a lower-triangular matrix consisting of positive elements along its main diagonal. The determinant of $\mathbf{L}$ (or indeed of any triangular matrix) is the product of the elements along its main diagonal, and so the logarithm of the determinant is the sum of the logarithms of those elements.

Let $f$ be a scalar function and $\partial f/\partial\log\left(\det(\mathbf{L})\right)$ be given. We then have:

$$\begin{aligned} \frac{\partial f}{\partial(\mathbf{L})_{ij}} & =\frac{\partial f}{\partial\log\left(\det(\mathbf{L})\right)}\cdot\frac{\partial\log\left(\det(\mathbf{L})\right)}{\partial(\mathbf{L})_{ij}}\\ & =\frac{\partial f}{\partial\log\left(\det(\mathbf{L})\right)}\cdot\frac{\delta_{ij}}{(\mathbf{L})_{ij}},\end{aligned}$$

which can be written as:

$$\frac{\partial f}{\partial\mathbf{L}}=\frac{\partial f}{\partial\log\left(\det(\mathbf{L})\right)}\cdot\mathrm{diag}(\mathbf{L})^{-1},$$

where $\mathrm{diag}$ is a function that returns a diagonal matrix with its main diagonal set to that of the argument, and all off-diagonal elements set to zero.

As $\det(\mathbf{S})=\det(\mathbf{L})^{2}$, we have $\log\left(\det(\mathbf{S})\right)=2\log\left(\det(\mathbf{L})\right)$ and so:

$$\frac{\partial f}{\partial\mathbf{L}}=2\frac{\partial f}{\partial\log\left(\det(\mathbf{S})\right)}\cdot\mathrm{diag}(\mathbf{L})^{-1}.$$
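In NumPy this is a one-liner on the backward pass (a sketch; `g` stands in for the upstream gradient $\partial f/\partial\log\left(\det(\mathbf{S})\right)$):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
S = A @ A.T + 4.0*np.eye(4)                 # symmetric positive definite
L = np.linalg.cholesky(S)                   # S = L L^T
g = rng.standard_normal()                   # upstream gradient df/dlog(det(S))

# forward pass: log(det(S)) = 2*sum(log(diag(L)))
ldet = 2.0*np.sum(np.log(np.diag(L)))

# backward pass: df/dL = 2 df/dlog(det(S)) diag(L)^{-1}
dL = 2.0*g*np.diag(1.0/np.diag(L))
```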

Cholesky factorization

Let $f$ be a scalar function and $\partial f/\partial\mathbf{L}$ be given, where $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$. Define (Murray, 2016):

$$\mathbf{P}=\Phi\left(\mathbf{L}^{\top}\frac{\partial f}{\partial\mathbf{L}}\right),$$

where $\Phi$ is a function that returns a lower-triangular matrix with the strictly lower triangle set to that of its argument, the main diagonal set to half that of its argument, and the upper triangle set to zero. We then have (Murray, 2016):

$$\frac{\partial f}{\partial\mathbf{S}}=\Phi\left(\mathbf{L}^{-\top}(\mathbf{P}+\mathbf{P}^{\top})\mathbf{L}^{-1}\right).$$

This form gives the gradient in the space of symmetric matrices, rather than the space of arbitrary dense matrices. A gradient update on S\mathbf{S} will only affect the lower triangle. This is suitable when using linear algebra routines for symmetric matrices, which typically reference only one triangle (in this case, the lower) and assume that the other matches by symmetry. If a dense matrix must be formed, however, the strictly lower triangle should be transposed and copied into the strictly upper.
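Below is a NumPy transcription of these two formulas (illustrative only; the helper `phi` and the finite-difference spot-check are my own). The check perturbs $\mathbf{S}$ symmetrically, consistent with the one-triangle convention just described:

```python
import numpy as np

def phi(A):
    """The map Phi above: lower triangle of A, with the main diagonal halved."""
    return np.tril(A) - 0.5*np.diag(np.diag(A))

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
S = A @ A.T + 4.0*np.eye(4)                  # symmetric positive definite
L = np.linalg.cholesky(S)                    # S = L L^T
dL = np.tril(rng.standard_normal((4, 4)))    # upstream gradient df/dL, lower triangular

Linv = np.linalg.inv(L)
P = phi(L.T @ dL)
dS = phi(Linv.T @ (P + P.T) @ Linv)          # df/dS, referencing the lower triangle only

# spot-check a strictly lower entry: perturb S symmetrically (consistent with the
# one-triangle convention) and finite-difference f(S) = sum(dL * cholesky(S))
eps = 1e-6
E = np.zeros((4, 4))
E[2, 1] = eps
E[1, 2] = eps
f = lambda M: np.sum(dL * np.linalg.cholesky(M))
fd = (f(S + E) - f(S - E)) / (2*eps)
print(np.isclose(dS[2, 1], fd, atol=1e-5))   # True
```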

References