
# Matrix Gradients of Scalar Functions

#### Understanding the building blocks of reverse-mode automatic differentiation.

Lawrence Murray on 7 November 2022

These notes were written for the implementation of reverse-mode automatic differentiation in Birch, specifically the NumBirch library that provides its numerics on both CPU and GPU. The motivation is to:

1. consolidate some essential gradients in one place, especially those commonly required for probabilistic modeling (e.g. Cholesky factorizations, logarithms of determinants), and

2. present those gradients in the form required for reverse-mode automatic differentiation.

While some gradients are simply stated and cited, others are derived from first principles using element-wise calculations as an aid to understanding.

## Background

In the simplest setting, we are interested in a function $f$ that accepts a matrix argument $\mathbf{X}$ and returns a scalar result. A typical setting is model training, where $f$ is the objective function (e.g. log-likelihood or mean squared error), $\mathbf{X}$ are some parameters (e.g. weights and biases of a neural network), and we wish to compute $\partial f/\partial\mathbf{X}$, being the gradient of $f$ with respect to $\mathbf{X}$. Having the gradient, we then might apply a gradient-based update to the parameters, as when optimizing with stochastic gradient descent or sampling with Langevin Monte Carlo.

We will focus in particular on reverse-mode automatic differentiation. For illustration, assume that $f$ can be decomposed into a single chain of simpler functions $f_{1},\ldots,f_{n}$: $f(\mathbf{X})=(f_{n}\circ\cdots\circ f_{1})(\mathbf{X}).$ We first perform a forward pass to compute the intermediate results, first computing $\mathbf{Y}_{1}=f_{1}(\mathbf{X})$ then proceeding recursively: \begin{aligned} \mathbf{Y}_{i} & =f_{i}(\mathbf{Y}_{i-1})=(f_{i}\circ\cdots\circ f_{1})(\mathbf{X})\text{ for }i=2,\ldots,n-1\\ y & =f_{n}(\mathbf{Y}_{n-1})=(f_{n}\circ\cdots\circ f_{1})(\mathbf{X})=f(\mathbf{X}).\end{aligned} We then perform a backward pass to compute the gradient by applying the chain rule, first computing $\partial f/\partial\mathbf{Y}_{n-1}$ then proceeding recursively: \begin{aligned} \frac{\partial f}{\partial\mathbf{Y}_{i}} & =\frac{\partial f}{\partial\mathbf{Y}_{i+1}}\cdot\frac{\partial\mathbf{Y}_{i+1}}{\partial\mathbf{Y}_{i}}\text{ for }i=n-2,\ldots,1\\ \frac{\partial f}{\partial\mathbf{X}} & =\frac{\partial f}{\partial\mathbf{Y}_{1}}\cdot\frac{\partial\mathbf{Y}_{1}}{\partial\mathbf{X}}.\end{aligned} The intermediate results $\mathbf{Y}_{i}$ computed during the forward pass often appear in the computations required during the backward pass. They can be memoized for reuse rather than computed again.
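The two passes can be sketched numerically. Below is a minimal sketch (not NumBirch code) using a hypothetical chain with $f_{1}(\mathbf{X})=\mathbf{X}\mathbf{W}$ for some fixed $\mathbf{W}$, $f_{2}$ an element-wise $\tanh$, and $f_{3}$ a sum, with the resulting gradient checked against central finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))
W = rng.standard_normal((4, 2))   # hypothetical fixed weights

# forward pass: memoize the intermediates for reuse in the backward pass
Y1 = X @ W           # f1: matrix product
Y2 = np.tanh(Y1)     # f2: element-wise tanh
y  = Y2.sum()        # f3: scalar result

# backward pass: each gradient has the same shape as its intermediate
dY2 = np.ones_like(Y2)        # df/dY2
dY1 = dY2 * (1 - Y2 ** 2)     # df/dY1, reusing the memoized Y2
dX  = dY1 @ W.T               # df/dX

# central finite differences to verify
eps = 1e-6
fd = np.zeros_like(X)
for idx in np.ndindex(X.shape):
    E = np.zeros_like(X); E[idx] = eps
    fd[idx] = (np.tanh((X + E) @ W).sum() - np.tanh((X - E) @ W).sum()) / (2 * eps)
assert np.allclose(dX, fd, atol=1e-5)
```

Note how $\mathbf{Y}_{2}$, memoized from the forward pass, is reused to compute $\partial f/\partial\mathbf{Y}_{1}$ without recomputing $\tanh$.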

A particular convenience of working in reverse mode for scalar functions is that each intermediate gradient has the same size (rows and columns) as the corresponding intermediate result, i.e. $\partial f/\partial\mathbf{Y}_{i}$ has the same size as $\mathbf{Y}_{i}$, and $\partial f/\partial\mathbf{X}$ the same size as $\mathbf{X}$.

The above is a simplification in that it assumes $f$ can be decomposed into a single chain. Generally, owing to intermediate functions with multiple arguments, it decomposes into a tree or graph (often called a compute graph in machine learning, or an expression template elsewhere). The principle is nonetheless the same: we apply the chain rule along the branches of the graph.

We limit ourselves to reverse mode here, but it is not the only approach. Forward-mode automatic differentiation works in the opposite order, using a single forward pass to recursively compute $\mathbf{Y}_{i}$ and $\partial\mathbf{Y}_{i}/\partial\mathbf{X}$ simultaneously; for scalar functions, however, this forgoes the sizing convenience of reverse mode. Mixed mode denotes any other order, which can be designed to exploit specific structure in the compute graph.

## Results

To derive a gradient of $f$ with respect to $\mathbf{X}$, we assume that the upstream gradient $\partial f/\partial g(\mathbf{X})$ is known for some $g$ that denotes a matrix operation of interest, e.g. transpose $g(\mathbf{X})=\mathbf{X}^{\top}$ or inverse $g(\mathbf{X})=\mathbf{X}^{-1}$. We then apply the chain rule element-wise to compute the partial derivative of $f$ with respect to the $(i,j)$th element of $\mathbf{X}$, denoted $(\mathbf{X})_{ij}$: $\frac{\partial f}{\partial(\mathbf{X})_{ij}}=\sum_{kl}\frac{\partial f}{\partial g(\mathbf{X})_{kl}}\cdot\frac{\partial g(\mathbf{X})_{kl}}{\partial(\mathbf{X})_{ij}}.$ We then simplify the element-wise expression into a matrix expression to allow the partial derivatives with respect to all elements to be computed simultaneously—also more efficiently using high-performance linear algebra routines—and so obtain the gradient.

### Transpose

The transpose is straightforward: the gradient of $f$ with respect to $\mathbf{X}$ is the transpose of the gradient of $f$ with respect to $\mathbf{X}^{\top}$. But to demonstrate use of the above formulation, let $f$ be a scalar function and $\partial f/\partial\mathbf{X}^{\top}$ be given; we then have, element-wise: \begin{aligned} \frac{\partial f}{\partial(\mathbf{X})_{ij}} & =\sum_{kl}\frac{\partial f}{\partial\left(\mathbf{X}^{\top}\right)_{kl}}\cdot\frac{\partial\left(\mathbf{X}^{\top}\right)_{kl}}{\partial(\mathbf{X})_{ij}}\\ & =\sum_{kl}\frac{\partial f}{\partial\left(\mathbf{X}^{\top}\right)_{kl}}\cdot\delta_{kj}\delta_{li}\\ & =\frac{\partial f}{\partial\left(\mathbf{X}^{\top}\right)_{ji}},\end{aligned} where $\delta_{ij}$ denotes the Kronecker delta (i.e. 1 when $i=j$ and 0 otherwise). We recognize this as a simple transpose operation on all elements simultaneously: \begin{aligned} \frac{\partial f}{\partial\mathbf{X}} & =\left(\frac{\partial f}{\partial\mathbf{X}^{\top}}\right)^{\top}.\end{aligned}
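As a quick numerical check, a sketch with the hypothetical choice $f(\mathbf{X})=\sum_{kl}\left(\mathbf{G}\circ\mathbf{X}^{\top}\right)_{kl}$, so that $\partial f/\partial\mathbf{X}^{\top}=\mathbf{G}$ exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))
G = rng.standard_normal((3, 4))   # hypothetical upstream gradient df/dX^T

# the identity: df/dX = (df/dX^T)^T
dX = G.T

# central finite differences of f(X) = sum(G * X^T) to verify
eps = 1e-6
fd = np.zeros_like(X)
for idx in np.ndindex(X.shape):
    E = np.zeros_like(X); E[idx] = eps
    fd[idx] = (np.sum(G * (X + E).T) - np.sum(G * (X - E).T)) / (2 * eps)
assert np.allclose(dX, fd, atol=1e-6)
```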

### Multiplication

Let $f$ be a scalar function and $\partial f/\partial\mathbf{XY}$ be given. Element-wise, we have:

\begin{aligned} \frac{\partial f}{\partial(\mathbf{X})_{ij}} & =\sum_{kl}\frac{\partial f}{\partial(\mathbf{XY})_{kl}}\cdot\frac{\partial(\mathbf{XY})_{kl}}{\partial(\mathbf{X})_{ij}}\\ & =\sum_{kl}\frac{\partial f}{\partial(\mathbf{XY})_{kl}}\delta_{ki}(\mathbf{Y})_{jl}\\ & =\sum_{l}\frac{\partial f}{\partial(\mathbf{XY})_{il}}(\mathbf{Y})_{jl}.\end{aligned} We observe from this that $\partial f/\partial(\mathbf{X})_{ij}$ is the dot product between the $i$th row of $\partial f/\partial\mathbf{XY}$ and the $j$th row of $\mathbf{Y}$, and consolidate in matrix form as: $\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\mathbf{XY}}\mathbf{Y}^{\top}.$ Next we wish to compute $\partial f/\partial\mathbf{Y}$. We can derive using the same element-wise approach, or simply apply the transpose property above to obtain: \begin{aligned} \frac{\partial f}{\partial\mathbf{Y}} & =\mathbf{X}^{\top}\frac{\partial f}{\partial\mathbf{XY}}.\end{aligned}
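Both identities are easy to verify numerically. A sketch, choosing the hypothetical $f(\mathbf{X},\mathbf{Y})=\sum_{kl}\left(\mathbf{G}\circ\mathbf{XY}\right)_{kl}$ so that $\partial f/\partial\mathbf{XY}=\mathbf{G}$ exactly:

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, p = 3, 4, 5
X = rng.standard_normal((m, n))
Y = rng.standard_normal((n, p))
G = rng.standard_normal((m, p))   # hypothetical upstream gradient df/d(XY)

def f(X, Y):
    return np.sum(G * (X @ Y))    # chosen so that df/d(XY) = G

# the matrix identities
dX = G @ Y.T      # df/dX = df/d(XY) Y^T
dY = X.T @ G      # df/dY = X^T df/d(XY)

# central finite differences to verify
def fd_grad(fun, A, eps=1e-6):
    out = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        E = np.zeros_like(A); E[idx] = eps
        out[idx] = (fun(A + E) - fun(A - E)) / (2 * eps)
    return out

assert np.allclose(dX, fd_grad(lambda A: f(A, Y), X), atol=1e-5)
assert np.allclose(dY, fd_grad(lambda B: f(X, B), Y), atol=1e-5)
```

The same pattern (choose $f$ so that the upstream gradient is exactly $\mathbf{G}$, then compare against finite differences) applies to each of the results below.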

### Inverse

Let $f$ be a scalar function and $\partial f/\partial\mathbf{X}^{-1}$ be given. We have (Petersen & Pedersen, 2016): $\frac{\partial f}{\partial\mathbf{X}}=-\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}}\mathbf{X}^{-\top},$ using the shorthand notation $\mathbf{X}^{-\top}=\left(\mathbf{X}^{-1}\right)^{\top}=\left(\mathbf{X}^{\top}\right)^{-1}$. We can derive this element-wise, but it requires a non-obvious property given in (Petersen & Pedersen, 2016): $\frac{\partial(\mathbf{X}^{-1})_{kl}}{\partial(\mathbf{X})_{ij}}=-(\mathbf{X}^{-1})_{ki}(\mathbf{X}^{-1})_{jl}.$ Using this, the derivation proceeds: \begin{aligned} \frac{\partial f}{\partial(\mathbf{X})_{ij}} & =\sum_{kl}\frac{\partial f}{\partial(\mathbf{X}^{-1})_{kl}}\cdot\frac{\partial(\mathbf{X}^{-1})_{kl}}{\partial(\mathbf{X})_{ij}}\\ & =-\sum_{kl}\frac{\partial f}{\partial(\mathbf{X}^{-1})_{kl}}(\mathbf{X}^{-1})_{ki}(\mathbf{X}^{-1})_{jl}\\ & =-\sum_{k}(\mathbf{X}^{-1})_{ki}\sum_{l}\frac{\partial f}{\partial(\mathbf{X}^{-1})_{kl}}(\mathbf{X}^{-1})_{jl},\end{aligned} from which we recognize the matrix form above.
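A numerical sketch, with the hypothetical choice $f(\mathbf{X})=\sum_{kl}\left(\mathbf{G}\circ\mathbf{X}^{-1}\right)_{kl}$ so that $\partial f/\partial\mathbf{X}^{-1}=\mathbf{G}$:

```python
import numpy as np

rng = np.random.default_rng(2)
n = 4
X = rng.standard_normal((n, n)) + n * np.eye(n)   # well-conditioned
G = rng.standard_normal((n, n))   # hypothetical upstream gradient df/dX^{-1}

def f(X):
    return np.sum(G * np.linalg.inv(X))   # chosen so that df/dX^{-1} = G

Xinv = np.linalg.inv(X)
dX = -Xinv.T @ G @ Xinv.T   # df/dX = -X^{-T} df/dX^{-1} X^{-T}

# central finite differences to verify
eps = 1e-6
fd = np.zeros_like(X)
for idx in np.ndindex(X.shape):
    E = np.zeros_like(X); E[idx] = eps
    fd[idx] = (f(X + E) - f(X - E)) / (2 * eps)
assert np.allclose(dX, fd, atol=1e-5)
```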

### Inverse of symmetric positive definite matrix

Let $f$ be a scalar function and $\partial f/\partial\mathbf{S}^{-1}$ be given, where $\mathbf{S}$ is a symmetric positive definite matrix with Cholesky decomposition $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$, and so inverse $\mathbf{S}^{-1}=\mathbf{L}^{-\top}\mathbf{L}^{-1}$, where $\mathbf{L}$ is a lower-triangular matrix. We use the transpose, multiplication and inverse properties to obtain: \begin{aligned} \frac{\partial f}{\partial\mathbf{L}} & =-\mathbf{L}^{-\top}\frac{\partial f}{\partial\mathbf{L}^{-1}}\mathbf{L}^{-\top}\\ & =-\mathbf{L}^{-\top}\left(\left(\frac{\partial f}{\partial\mathbf{S}^{-1}}\mathbf{L}^{-\top}\right)^{\top}+\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-1}}\right)\mathbf{L}^{-\top}\\ & =-\mathbf{L}^{-\top}\left(\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-\top}}+\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-1}}\right)\mathbf{L}^{-\top}\\ & =-\mathbf{L}^{-\top}\mathbf{L}^{-1}\left(\frac{\partial f}{\partial\mathbf{S}^{-\top}}+\frac{\partial f}{\partial\mathbf{S}^{-1}}\right)\mathbf{L}^{-\top}.\end{aligned} The second line applies the multiplication property to both occurrences of $\mathbf{L}^{-1}$ in $\mathbf{S}^{-1}=\mathbf{L}^{-\top}\mathbf{L}^{-1}$ (the first via the transpose property), and we use the shorthand $\partial f/\partial\mathbf{S}^{-\top}=\left(\partial f/\partial\mathbf{S}^{-1}\right)^{\top}$.
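The final line can be checked numerically. A sketch, choosing the hypothetical $f(\mathbf{L})=\sum_{kl}\left(\mathbf{G}\circ\mathbf{L}^{-\top}\mathbf{L}^{-1}\right)_{kl}$ so that $\partial f/\partial\mathbf{S}^{-1}=\mathbf{G}$:

```python
import numpy as np

rng = np.random.default_rng(3)
n = 4
A = rng.standard_normal((n, n))
L = np.linalg.cholesky(A @ A.T + n * np.eye(n))   # lower-triangular factor
G = rng.standard_normal((n, n))   # hypothetical upstream gradient df/dS^{-1}

def f(L):
    Linv = np.linalg.inv(L)
    return np.sum(G * (Linv.T @ Linv))   # chosen so that df/dS^{-1} = G

Linv = np.linalg.inv(L)
dL = -Linv.T @ Linv @ (G.T + G) @ Linv.T   # final line of the derivation

# central finite differences to verify
eps = 1e-6
fd = np.zeros_like(L)
for idx in np.ndindex(L.shape):
    E = np.zeros_like(L); E[idx] = eps
    fd[idx] = (f(L + E) - f(L - E)) / (2 * eps)
assert np.allclose(dL, fd, atol=1e-4)
```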

### Solve

In numerical computation, explicit matrix inversions $\mathbf{X}^{-1}$ are replaced, wherever possible, with solutions of linear systems $\mathbf{Y}=\mathbf{X}\mathbf{Z}$, for numerical stability. We can think of the solution as $\mathbf{Z}=\mathbf{X}^{-1}\mathbf{Y}$.

Let $f$ be a scalar function and $\partial f/\partial\mathbf{X}^{-1}\mathbf{Y}$ be given. We use the multiplication and inverse properties above to obtain: \begin{aligned} \frac{\partial f}{\partial\mathbf{Y}} & =\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}\mathbf{Y}}\\ \frac{\partial f}{\partial\mathbf{X}} & =-\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}}\mathbf{X}^{-\top}\\ & =-\mathbf{X}^{-\top}\frac{\partial f}{\partial\mathbf{X}^{-1}\mathbf{Y}}\mathbf{Y}^{\top}\mathbf{X}^{-\top}\\ & =-\frac{\partial f}{\partial\mathbf{Y}}(\mathbf{X}^{-1}\mathbf{Y})^{\top}\\ & =-\frac{\partial f}{\partial\mathbf{Y}}\mathbf{Z}^{\top}.\end{aligned}

In the second last line we substitute in $\partial f/\partial\mathbf{Y}$ from the first line, and in the last line $\mathbf{Z}=\mathbf{X}^{-1}\mathbf{Y}$ from the forward pass, to avoid repeating expensive computations.
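A numerical sketch, with the hypothetical choice $f(\mathbf{X},\mathbf{Y})=\sum_{kl}\left(\mathbf{G}\circ\mathbf{X}^{-1}\mathbf{Y}\right)_{kl}$ so that $\partial f/\partial\mathbf{X}^{-1}\mathbf{Y}=\mathbf{G}$, computing the gradients via solves rather than explicit inverses:

```python
import numpy as np

rng = np.random.default_rng(4)
n, p = 4, 3
X = rng.standard_normal((n, n)) + n * np.eye(n)   # well-conditioned
Y = rng.standard_normal((n, p))
G = rng.standard_normal((n, p))   # hypothetical upstream gradient df/d(X^{-1}Y)

def f(X, Y):
    return np.sum(G * np.linalg.solve(X, Y))   # chosen so upstream gradient is G

Z = np.linalg.solve(X, Y)         # memoized from the forward pass
dY = np.linalg.solve(X.T, G)      # df/dY = X^{-T} df/d(X^{-1}Y), as a solve
dX = -dY @ Z.T                    # df/dX, reusing dY and Z

# central finite differences to verify
def fd_grad(fun, A, eps=1e-6):
    out = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        E = np.zeros_like(A); E[idx] = eps
        out[idx] = (fun(A + E) - fun(A - E)) / (2 * eps)
    return out

assert np.allclose(dY, fd_grad(lambda B: f(X, B), Y), atol=1e-5)
assert np.allclose(dX, fd_grad(lambda A: f(A, Y), X), atol=1e-5)
```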

### Solve with symmetric positive definite matrix

Let $f$ be a scalar function and $\partial f/\partial\mathbf{S}^{-1}\mathbf{Y}$ be given, where $\mathbf{S}$ is a symmetric positive definite matrix with Cholesky decomposition $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$, and $\mathbf{L}$ a lower-triangular matrix. Using the solve and multiplication properties above we obtain: \begin{aligned} \frac{\partial f}{\partial\mathbf{Y}} & =\mathbf{L}^{-\top}\mathbf{L}^{-1}\frac{\partial f}{\partial\mathbf{S}^{-1}\mathbf{Y}}\\ \frac{\partial f}{\partial\mathbf{L}} & =\mathrm{tril}\left(\frac{\partial f}{\partial\mathbf{S}}\mathbf{L}+\left(\mathbf{L}^{\top}\frac{\partial f}{\partial\mathbf{S}}\right)^{\top}\right)\\ & =\mathrm{tril}\left(\left(\frac{\partial f}{\partial\mathbf{S}}+\frac{\partial f}{\partial\mathbf{S}^{\top}}\right)\mathbf{L}\right)\\ & =\mathrm{tril}\left(-\left(\frac{\partial f}{\partial\mathbf{Y}}\mathbf{Z}^{\top}+\mathbf{Z}\frac{\partial f}{\partial\mathbf{Y}^{\top}}\right)\mathbf{L}\right),\end{aligned} where $\mathrm{tril}$ is a function that returns a matrix with its lower triangle set to that of its argument, and its strictly upper triangle set to zero, and $\mathbf{Z}=\mathbf{S}^{-1}\mathbf{Y}$ is memoized from the forward pass. This form gives the gradient in the space of lower-triangular matrices, rather than the space of arbitrary dense matrices, to ensure that a gradient update maintains $\mathbf{L}$ as lower triangular.
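A numerical sketch, with the hypothetical choice $f(\mathbf{L},\mathbf{Y})=\sum_{kl}\left(\mathbf{G}\circ\mathbf{S}^{-1}\mathbf{Y}\right)_{kl}$; as the gradient is given in the space of lower-triangular matrices, finite differences are taken over the lower triangle of $\mathbf{L}$ only:

```python
import numpy as np

rng = np.random.default_rng(9)
n, p = 4, 3
A = rng.standard_normal((n, n))
S = A @ A.T + n * np.eye(n)       # symmetric positive definite
L = np.linalg.cholesky(S)
Y = rng.standard_normal((n, p))
G = rng.standard_normal((n, p))   # hypothetical upstream gradient df/d(S^{-1}Y)

def f(L, Y):
    return np.sum(G * np.linalg.solve(L @ L.T, Y))

Z = np.linalg.solve(S, Y)                   # memoized from the forward pass
dY = np.linalg.solve(S, G)                  # L^{-T} L^{-1} G, since S is symmetric
dL = np.tril(-(dY @ Z.T + Z @ dY.T) @ L)    # gradient in lower-triangular space

# central finite differences, over the lower triangle of L only
eps = 1e-6
fd = np.zeros((n, n))
for i in range(n):
    for j in range(i + 1):
        E = np.zeros((n, n)); E[i, j] = eps
        fd[i, j] = (f(L + E, Y) - f(L - E, Y)) / (2 * eps)
assert np.allclose(dL, fd, atol=1e-4)

# also verify df/dY
fdY = np.zeros_like(Y)
for idx in np.ndindex(Y.shape):
    E = np.zeros_like(Y); E[idx] = eps
    fdY[idx] = (f(L, Y + E) - f(L, Y - E)) / (2 * eps)
assert np.allclose(dY, fdY, atol=1e-5)
```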

### Trace

Let $f$ be a scalar function and $\partial f/\partial\mathrm{tr}(\mathbf{X})$ be given. We have: $\frac{\partial f}{\partial(\mathbf{X})_{ij}}=\frac{\partial f}{\partial\mathrm{tr}(\mathbf{X})}\cdot\frac{\partial\mathrm{tr}(\mathbf{X})}{\partial(\mathbf{X})_{ij}}=\frac{\partial f}{\partial\mathrm{tr}(\mathbf{X})}\delta_{ij}.$ Or in matrix form: $\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\mathrm{tr}(\mathbf{X})}\mathbf{I}.$
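A numerical sketch, taking the hypothetical $f(\mathbf{X})=\mathrm{tr}(\mathbf{X})^{2}$ so that $\partial f/\partial\mathrm{tr}(\mathbf{X})=2\,\mathrm{tr}(\mathbf{X})$:

```python
import numpy as np

rng = np.random.default_rng(5)
n = 4
X = rng.standard_normal((n, n))

# f(X) = tr(X)^2, so df/dtr(X) = 2 tr(X) and df/dX = 2 tr(X) I
t = np.trace(X)
dX = 2 * t * np.eye(n)

# central finite differences to verify
eps = 1e-6
fd = np.zeros_like(X)
for idx in np.ndindex(X.shape):
    E = np.zeros_like(X); E[idx] = eps
    fd[idx] = (np.trace(X + E) ** 2 - np.trace(X - E) ** 2) / (2 * eps)
assert np.allclose(dX, fd, atol=1e-5)
```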

### Logarithm of the determinant

Let $f$ be a scalar function and $\partial f/\partial\det(\mathbf{X})$ be given. We have (Petersen & Pedersen, 2016): $\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\det(\mathbf{X})}\cdot\frac{\partial\det(\mathbf{X})}{\partial\mathbf{X}}=\frac{\partial f}{\partial\det(\mathbf{X})}\det(\mathbf{X})\mathbf{X}^{-\top}.$ By extension, and assuming that the determinant is positive so that we can take the logarithm, let $\partial f/\partial\log\left(\det(\mathbf{X})\right)$ be given: $\frac{\partial f}{\partial\mathbf{X}}=\frac{\partial f}{\partial\log\left(\det(\mathbf{X})\right)}\cdot\frac{\partial\log\left(\det(\mathbf{X})\right)}{\partial\mathbf{X}}=\frac{\partial f}{\partial\log\left(\det(\mathbf{X})\right)}\mathbf{X}^{-\top}.$ These seem non-trivial to derive element-wise.
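A numerical sketch of the log-determinant case, taking the hypothetical $f(\mathbf{X})=\log\left(\det(\mathbf{X})\right)^{2}$ so that $\partial f/\partial\log\left(\det(\mathbf{X})\right)=2\log\left(\det(\mathbf{X})\right)$, with $\mathbf{X}$ chosen to have a positive determinant:

```python
import numpy as np

rng = np.random.default_rng(6)
n = 4
A = rng.standard_normal((n, n))
X = A @ A.T + n * np.eye(n)   # positive determinant

def f(X):
    return np.linalg.slogdet(X)[1] ** 2   # log(det(X))^2 via slogdet

ld = np.linalg.slogdet(X)[1]
dX = 2 * ld * np.linalg.inv(X).T   # df/dlogdet * X^{-T}

# central finite differences to verify
eps = 1e-6
fd = np.zeros_like(X)
for idx in np.ndindex(X.shape):
    E = np.zeros_like(X); E[idx] = eps
    fd[idx] = (f(X + E) - f(X - E)) / (2 * eps)
assert np.allclose(dX, fd, atol=1e-4)
```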

### Logarithm of the determinant of a symmetric positive definite matrix

Let $\mathbf{S}$ be a symmetric positive definite matrix with Cholesky decomposition $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$, with $\mathbf{L}$ a lower-triangular matrix with positive elements along its main diagonal. The determinant of $\mathbf{L}$ (or indeed of any triangular matrix) is the product of the elements along its main diagonal, and so the logarithm of the determinant is the sum of the logarithms of the elements along its main diagonal.

Let $f$ be a scalar function and $\partial f/\partial\log\left(\det(\mathbf{L})\right)$ be given. We then have: \begin{aligned} \frac{\partial f}{\partial(\mathbf{L})_{ij}} & =\frac{\partial f}{\partial\log\left(\det(\mathbf{L})\right)}\cdot\frac{\partial\log\left(\det(\mathbf{L})\right)}{\partial(\mathbf{L})_{ij}}\\ & =\frac{\partial f}{\partial\log\left(\det(\mathbf{L})\right)}\cdot\frac{\delta_{ij}}{(\mathbf{L})_{ij}},\end{aligned} which can be written as: $\frac{\partial f}{\partial\mathbf{L}}=\frac{\partial f}{\partial\log\left(\det(\mathbf{L})\right)}\cdot\mathrm{diag}(\mathbf{L})^{-1},$ where $\mathrm{diag}$ is a function that returns a diagonal matrix with its main diagonal set to that of the argument, and all off-diagonal elements set to zero.

As $\det(\mathbf{S})=\det(\mathbf{L})^{2}$, we have $\log\left(\det(\mathbf{S})\right)=2\log\left(\det(\mathbf{L})\right)$ and so: $\frac{\partial f}{\partial\mathbf{L}}=2\frac{\partial f}{\partial\log\left(\det(\mathbf{S})\right)}\cdot\mathrm{diag}(\mathbf{L})^{-1}.$
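A numerical sketch, with a hypothetical upstream scalar $g=\partial f/\partial\log\left(\det(\mathbf{S})\right)$, so that $f(\mathbf{L})=g\log\left(\det(\mathbf{L}\mathbf{L}^{\top})\right)=2g\sum_{i}\log(\mathbf{L})_{ii}$:

```python
import numpy as np

rng = np.random.default_rng(8)
n = 4
A = rng.standard_normal((n, n))
S = A @ A.T + n * np.eye(n)
L = np.linalg.cholesky(S)
g = 1.7   # hypothetical upstream scalar df/dlog(det(S))

def f(L):
    return 2 * g * np.sum(np.log(np.diag(L)))   # g * log(det(L L^T))

dL = 2 * g * np.diag(1 / np.diag(L))   # 2 (df/dlogdet(S)) diag(L)^{-1}

# central finite differences to verify; off-diagonal gradients are zero
eps = 1e-6
fd = np.zeros_like(L)
for idx in np.ndindex(L.shape):
    E = np.zeros_like(L); E[idx] = eps
    fd[idx] = (f(L + E) - f(L - E)) / (2 * eps)
assert np.allclose(dL, fd, atol=1e-5)
```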

### Cholesky factorization

Let $f$ be a scalar function and $\partial f/\partial\mathbf{L}$ be given, where $\mathbf{S}=\mathbf{L}\mathbf{L}^{\top}$. Define (Murray, 2016): \begin{aligned} \mathbf{P} & =\Phi\left(\mathbf{L}^{\top}\frac{\partial f}{\partial\mathbf{L}}\right),\end{aligned} where $\Phi$ is a function that returns a lower-triangular matrix with the strictly lower triangle set to that of its argument, the main diagonal set to half that of its argument, and the upper triangle set to zero. We then have (Murray, 2016): \begin{aligned} \frac{\partial f}{\partial\mathbf{S}} & =\Phi\left(\mathbf{L}^{-\top}(\mathbf{P}+\mathbf{P}^{\top})\mathbf{L}^{-1}\right).\end{aligned}

This form gives the gradient in the space of symmetric matrices, rather than the space of arbitrary dense matrices. A gradient update on $\mathbf{S}$ will only affect the lower triangle. This is suitable when using linear algebra routines for symmetric matrices, which typically reference only one triangle (in this case, the lower) and assume that the other matches by symmetry. If a dense matrix must be formed, however, the strictly lower triangle should be transposed and copied into the strictly upper.
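A numerical sketch of the result, with $\Phi$ implemented as described, a hypothetical lower-triangular upstream gradient, and finite differences taken in the space of symmetric matrices (perturbing the $(i,j)$ and $(j,i)$ elements together, matching the convention above):

```python
import numpy as np

rng = np.random.default_rng(7)
n = 4
A = rng.standard_normal((n, n))
S = A @ A.T + n * np.eye(n)                 # symmetric positive definite
L = np.linalg.cholesky(S)
dL = np.tril(rng.standard_normal((n, n)))   # hypothetical upstream df/dL

def phi(A):
    # lower triangle kept, diagonal halved, upper triangle zeroed
    return np.tril(A) - 0.5 * np.diag(np.diag(A))

P = phi(L.T @ dL)
Linv = np.linalg.inv(L)
dS = phi(Linv.T @ (P + P.T) @ Linv)

# f(S) = sum(dL * chol(S)), so the upstream gradient is exactly dL;
# central finite differences over symmetric perturbations of S
def f(S):
    return np.sum(dL * np.linalg.cholesky(S))

eps = 1e-6
fd = np.zeros((n, n))
for i in range(n):
    for j in range(i + 1):
        E = np.zeros((n, n)); E[i, j] = eps; E[j, i] = eps
        fd[i, j] = (f(S + E) - f(S - E)) / (2 * eps)
assert np.allclose(dS, fd, atol=1e-4)
```

Note that the finite-difference gradient for a diagonal element matches the halved diagonal produced by $\Phi$, consistent with the symmetric-space convention.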

## References

- K. B. Petersen and M. S. Pedersen. *The Matrix Cookbook*. Technical University of Denmark.
- I. Murray. Differentiation of the Cholesky decomposition. *arXiv:1602.07527*, 2016.
