
Gradients of Softmax and Logsumexp

Essential functions for categorical distributions and attention mechanisms in machine learning

Two mathematical functions that commonly arise in machine learning models are softmax and logsumexp. They occur when dealing with categorical and multinomial probability distributions, as well as in attention mechanisms used in neural network architectures (such as the transformer). The two functions are closely related in that they both involve sums of exponentials, with similar numerical considerations. Indeed, deriving the gradient of softmax involves (or at least can involve) logsumexp, while the gradient of logsumexp is softmax. This note serves as a quick explanation of the functions as well as a derivation of their gradients for the purposes of implementation.

  1. Logsumexp and Softmax
  2. Gradient of Logsumexp
  3. Gradient of Softmax

Logsumexp and Softmax

The two functions of interest are logsumexp and softmax, defined as:

\begin{aligned} \mathrm{logsumexp}(\mathbf{x}) & :=\log\left(\sum_{i}\exp(x_{i})\right)\\ \mathrm{softmax}(\mathbf{x}) & :=\frac{\exp(\mathbf{x})}{\sum_{i}\exp(x_{i})}. \end{aligned}

The first resolves to a scalar, whereas the second resolves to a vector of the same length as $\mathbf{x}$ (by $\exp(\mathbf{x})$ we mean the exponential function applied to each element of $\mathbf{x}$). The latter can be written in terms of the former as:

\begin{aligned} \mathrm{softmax}(\mathbf{x}) & =\frac{\exp(\mathbf{x})}{\sum_{i}\exp(x_{i})}\\ & =\exp\left(\log\left(\frac{\exp(\mathbf{x})}{\sum_{i}\exp(x_{i})}\right)\right)\\ & =\exp\left(\mathbf{x}-\log\left(\sum_{i}\exp(x_{i})\right)\right)\\ & =\exp\left(\mathbf{x}-\mathrm{logsumexp}\left(\mathbf{x}\right)\right). \end{aligned}
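
As a quick numerical check of this identity, here is a minimal NumPy sketch; the logsumexp and softmax helpers are direct transcriptions of the definitions above (not a particular library's API), without the numerical stabilization discussed next:

```python
import numpy as np

def logsumexp(x):
    # Direct transcription of the definition (no numerical stabilization yet).
    return np.log(np.sum(np.exp(x)))

def softmax(x):
    # Direct transcription of the definition (no numerical stabilization yet).
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([0.5, -1.2, 2.0, 0.3])
# The two sides of the identity derived above agree.
print(np.allclose(softmax(x), np.exp(x - logsumexp(x))))  # True
```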

The two functions arise in several contexts, including categorical and multinomial distributions (where $\mathbf{x}$ is a vector of logarithms of unnormalized probabilities) and attention mechanisms in neural network architectures (where $\mathbf{x}$ is a vector of unnormalized attention weights).

When computing the softmax and logsumexp functions, implementations are usually adapted from the above definitions, rather than taking them exactly as written, for reasons of numerical stability. The exponential function is very sensitive to the scale of $\mathbf{x}$, and in floating point arithmetic has a propensity to underflow to zero for large negative arguments and overflow to infinity for large positive arguments. The effect is much worse in single precision than double precision, and much worse again in half precision, which matters because these lower precisions are commonly used in machine learning to reduce compute time and memory use, especially on GPUs. To mitigate the issue, observe that $\mathrm{softmax}(\mathbf{x} - c)=\mathrm{softmax}(\mathbf{x})$ for any scalar $c$ (allowing the subtraction to broadcast the scalar). This can be used to translate $\mathbf{x}$ into a better range for numerical stability of the exponential function. A common choice is to set $c$ to the maximum element of $\mathbf{x}$. In practice the functions may then be computed as:

\begin{aligned} \mathrm{logsumexp}(\mathbf{x}) & =\log\left(\sum_{i}\exp(x_{i}-c)\right)+c\\ \mathrm{softmax}(\mathbf{x}) & =\frac{\exp(\mathbf{x}-c)}{\sum_{i}\exp(x_{i}-c)}. \end{aligned}
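
A minimal NumPy sketch of these stabilized computations might look as follows, taking $c$ to be the maximum element as suggested above (the function names are illustrative, not a particular library's API):

```python
import numpy as np

def logsumexp(x):
    # Shift by the maximum element before exponentiating, then add it back.
    c = np.max(x)
    return np.log(np.sum(np.exp(x - c))) + c

def softmax(x):
    # The shift cancels in the ratio, so the result is unchanged.
    c = np.max(x)
    e = np.exp(x - c)
    return e / np.sum(e)

# exp(1000) overflows even in double precision, but the shifted forms are fine.
x = np.array([1000.0, 1001.0, 1002.0])
print(logsumexp(x))  # approximately 1002.408
print(softmax(x))    # approximately [0.090, 0.245, 0.665]
```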

For the purposes of gradients, however, we can go back to the original definitions. To simplify notation in what follows, let $\mathbf{s}=\mathrm{softmax}(\mathbf{x})$ and $l=\mathrm{logsumexp}(\mathbf{x})$.

Gradient of Logsumexp

Element-wise, using the usual rule for derivatives of the logarithmic function, we have:

\begin{aligned} \frac{\partial l}{\partial x_{i}} & =\frac{\partial}{\partial x_{i}}\log\left(\sum_{j}\exp(x_{j})\right)\\ & =\frac{\exp(x_{i})}{\sum_{j}\exp(x_{j})}\\ & =s_{i}. \end{aligned}

Written in vector form, this is:

\frac{\partial l}{\partial\mathbf{x}}=\mathbf{s}.

That is, the gradient of $\mathrm{logsumexp}(\mathbf{x})$ is just $\mathrm{softmax}(\mathbf{x})$.
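
As a sanity check, this result can be compared against central finite differences. The sketch below assumes the stabilized logsumexp and softmax helpers sketched earlier are in scope:

```python
import numpy as np

# Assumes the logsumexp and softmax helpers sketched above are in scope.
x = np.array([0.5, -1.2, 2.0, 0.3])
eps = 1e-6

# Central finite differences of logsumexp, one element at a time.
fd = np.zeros_like(x)
for i in range(len(x)):
    e = np.zeros_like(x)
    e[i] = eps
    fd[i] = (logsumexp(x + e) - logsumexp(x - e)) / (2 * eps)

print(np.allclose(fd, softmax(x), atol=1e-6))  # True
```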

Gradient of Softmax

Element-wise, using the property introduced above that $\mathrm{softmax}(\mathbf{x})=\exp\left(\mathbf{x}-\mathrm{logsumexp}\left(\mathbf{x}\right)\right)$ (so that $\mathbf{s}=\exp(\mathbf{x}-l)$, and element-wise $s_{i}=\exp(x_{i}-l)$), and the previously-derived $\partial l/\partial x_{i}=s_{i}$, we have:

\begin{aligned} \frac{\partial s_{i}}{\partial x_{j}} & =\frac{\partial}{\partial x_{j}}\exp\left(x_{i}-l\right)\\ & =\begin{cases} \left(1-s_{j}\right)\exp\left(x_{i}-l\right) & i=j\\ -s_{j}\exp\left(x_{i}-l\right) & i\neq j \end{cases}\\ & =\begin{cases} (1-s_{j})s_{i} & i=j\\ -s_{j}s_{i} & i\neq j. \end{cases} \end{aligned}

We can collect these into the matrix of all-pairs partial derivatives (the Jacobian), with entries $J_{ij}=\partial s_{i}/\partial x_{j}$:

\mathbf{J}:=\mathrm{diag}(\mathbf{s})-\mathbf{s}\mathbf{s}^{\top},

where $\mathrm{diag}(\mathbf{s})$ denotes the diagonal matrix with $\mathbf{s}$ along its diagonal. Note that $\mathbf{J}$ is symmetric.
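
The Jacobian can likewise be checked column by column against central finite differences; again, this sketch assumes the softmax helper sketched earlier:

```python
import numpy as np

# Assumes the softmax helper sketched above is in scope.
x = np.array([0.5, -1.2, 2.0, 0.3])
s = softmax(x)

# Jacobian with entries J[i, j] = ds_i/dx_j.
J = np.diag(s) - np.outer(s, s)

# Check each column against central finite differences of softmax.
eps = 1e-6
fd = np.zeros((len(x), len(x)))
for j in range(len(x)):
    e = np.zeros_like(x)
    e[j] = eps
    fd[:, j] = (softmax(x + e) - softmax(x - e)) / (2 * eps)

print(np.allclose(J, fd, atol=1e-6))  # True
```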

Depending on the application, we may never need to compute this matrix explicitly. A typical use case is computing gradients of a scalar function with respect to some parameters of that function. In machine learning, during the training of a model, that scalar function is a loss function, and reverse-mode automatic differentiation is used to backpropagate losses onto parameters for the next gradient update. Let $f$ denote such a scalar function and let $\partial f/\partial\mathbf{s}$ be given (i.e. working backward from a loss given by $f$, reverse-mode automatic differentiation has arrived at the output of the softmax, for which the gradient is $\partial f/\partial\mathbf{s}$); we can write:

\frac{\partial f}{\partial x_{i}}=\sum_{j}\frac{\partial f}{\partial s_{j}}\cdot\frac{\partial s_{j}}{\partial x_{i}},

which in matrix form, using the symmetry of $\mathbf{J}$, is:

\begin{aligned} \frac{\partial f}{\partial\mathbf{x}} & =\mathbf{J}\frac{\partial f}{\partial\mathbf{s}}\\ & =(\mathrm{diag}(\mathbf{s})-\mathbf{s}\mathbf{s}^{\top})\frac{\partial f}{\partial\mathbf{s}}\\ & =\mathbf{s}\odot\frac{\partial f}{\partial\mathbf{s}}-\mathbf{s}\left(\mathbf{s}^{\top}\frac{\partial f}{\partial\mathbf{s}}\right). \end{aligned}

The last line, in which $\odot$ denotes element-wise multiplication, is a more efficient way of computing the gradient than the first, in both time and memory, as it avoids forming the whole Jacobian matrix.
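
To illustrate, the following sketch compares the explicit Jacobian product with the Jacobian-free form, using a made-up upstream gradient g as a stand-in for $\partial f/\partial\mathbf{s}$ (and again assuming the softmax helper sketched earlier):

```python
import numpy as np

# Assumes the softmax helper sketched above is in scope.
x = np.array([0.5, -1.2, 2.0, 0.3])
s = softmax(x)
g = np.array([0.1, -0.4, 0.25, 0.05])  # stand-in for df/ds arriving from upstream

# First line: explicit Jacobian product, O(n^2) time and memory.
J = np.diag(s) - np.outer(s, s)
grad_full = J @ g

# Last line: Jacobian-free form, O(n) time and memory.
grad_fast = s * g - s * np.dot(s, g)

print(np.allclose(grad_full, grad_fast))  # True
```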