Gradients of Softmax and Logsumexp
Two mathematical functions that commonly arise in machine learning models are softmax and logsumexp. They occur when dealing with categorical and multinomial probability distributions, as well as in attention mechanisms used in neural network architectures (such as the transformer). The two functions are closely related in that they both involve sums of exponentials, with similar numerical considerations. In fact they are so closely related that deriving the gradient of softmax involves (or at least can involve) logsumexp, while the gradient of logsumexp is softmax. This note serves as a quick explanation of the functions as well as a derivation of their gradients for the purposes of implementation.
Logsumexp and Softmax
The two functions of interest are logsumexp and softmax, defined as:

$$\operatorname{logsumexp}(\mathbf{x}) = \log\sum_{i=1}^{n} \exp(x_i), \qquad \operatorname{softmax}(\mathbf{x}) = \frac{\exp(\mathbf{x})}{\sum_{i=1}^{n} \exp(x_i)},$$

for a vector $\mathbf{x}$ of length $n$.
The first resolves to a scalar, whereas the second resolves to a vector of the same length as $\mathbf{x}$ (by $\exp(\mathbf{x})$ we mean the exponential function applied to each element of $\mathbf{x}$). The latter can be written in terms of the former as:

$$\operatorname{softmax}(\mathbf{x}) = \exp\left(\mathbf{x} - \operatorname{logsumexp}(\mathbf{x})\right).$$
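As a quick numerical check of this identity, here is a minimal NumPy sketch; the function names below are ours, written directly from the definitions above, not taken from any particular library:

```python
import numpy as np

def logsumexp(x):
    # Direct translation of the definition: log of the sum of exponentials.
    return np.log(np.sum(np.exp(x)))

def softmax(x):
    # Direct translation of the definition: exponentials normalized by their sum.
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([0.5, -1.2, 2.0, 0.0])

# softmax(x) == exp(x - logsumexp(x)), element-wise.
assert np.allclose(softmax(x), np.exp(x - logsumexp(x)))
```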
The two functions arise in several contexts, including categorical and multinomial distributions (where $\mathbf{x}$ is a vector of logarithms of unnormalized probabilities) and attention mechanisms in neural network architectures (where $\mathbf{x}$ is a vector of unnormalized attention weights).
When computing the softmax and logsumexp functions, implementations are usually adapted from the above definitions, rather than taking them exactly as written, for reasons of numerical stability. The exponential function is very sensitive to the scale of $\mathbf{x}$, and in floating point arithmetic has a propensity to underflow to zero for large negative arguments, and overflow to infinity for large positive arguments. The effect is much worse in single precision than double precision, and much worse again in half precision; this is critical, as these lower precisions are commonly used in machine learning to reduce compute time and memory use, especially on GPUs. To mitigate the issue, observe that, for any scalar $a$,

$$\operatorname{logsumexp}(\mathbf{x}) = a + \operatorname{logsumexp}(\mathbf{x} - a), \qquad \operatorname{softmax}(\mathbf{x}) = \operatorname{softmax}(\mathbf{x} - a)$$

(allowing the subtraction to broadcast the scalar). This can be used to translate $\mathbf{x}$ into a better range for numerical stability of the exponential function. A common choice is to set $a$ to the maximum element of $\mathbf{x}$. In practice the functions may then be computed as:

$$\operatorname{logsumexp}(\mathbf{x}) = a + \log\sum_{i=1}^{n} \exp(x_i - a), \qquad \operatorname{softmax}(\mathbf{x}) = \frac{\exp(\mathbf{x} - a)}{\sum_{i=1}^{n} \exp(x_i - a)}, \qquad a = \max_i x_i.$$
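As an illustration, a stabilized implementation along these lines might look as follows. This is a NumPy sketch with our own function names; library routines such as SciPy's `scipy.special.logsumexp` apply the same kind of max-shift:

```python
import numpy as np

def logsumexp_stable(x):
    # Shift by the maximum so the largest exponent is exp(0) = 1,
    # avoiding overflow; underflow of the smallest terms is harmless in the sum.
    a = np.max(x)
    return a + np.log(np.sum(np.exp(x - a)))

def softmax_stable(x):
    # The shift cancels in the ratio, so the result is unchanged,
    # but the intermediate exponentials stay in a safe range.
    a = np.max(x)
    e = np.exp(x - a)
    return e / np.sum(e)

x = np.array([1000.0, 1001.0, 1002.0])   # naive exp() would overflow here
print(logsumexp_stable(x))               # ~1002.41
print(softmax_stable(x))                 # ~[0.090, 0.245, 0.665]
```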
For the purposes of gradients, however, we can go back to the original definitions. To simplify notation in what follows, let $z = \operatorname{logsumexp}(\mathbf{x})$ and $\mathbf{p} = \operatorname{softmax}(\mathbf{x})$.
Gradient of Logsumexp
Element-wise, using the usual rule for derivatives of the logarithmic function, we have:

$$\frac{\partial z}{\partial x_i} = \frac{\partial}{\partial x_i} \log\sum_j \exp(x_j) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} = p_i.$$
Written in vector form, this is:

$$\frac{\partial z}{\partial \mathbf{x}} = \operatorname{softmax}(\mathbf{x}) = \mathbf{p}.$$
That is, the gradient of $\operatorname{logsumexp}(\mathbf{x})$ is just $\operatorname{softmax}(\mathbf{x})$.
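To make this concrete, the following sketch (our own NumPy code, not part of the derivation) checks the gradient of logsumexp against central finite differences and confirms it matches softmax:

```python
import numpy as np

def logsumexp(x):
    a = np.max(x)
    return a + np.log(np.sum(np.exp(x - a)))

def grad_logsumexp(x):
    # The gradient of logsumexp is just softmax (computed with the stable shift).
    a = np.max(x)
    e = np.exp(x - a)
    return e / np.sum(e)

# Check against central finite differences in each coordinate direction.
rng = np.random.default_rng(0)
x = rng.normal(size=5)
eps = 1e-6
fd = np.array([
    (logsumexp(x + eps * e) - logsumexp(x - eps * e)) / (2 * eps)
    for e in np.eye(len(x))
])
assert np.allclose(fd, grad_logsumexp(x), atol=1e-6)
```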
Gradient of Softmax
Element-wise, using the property introduced above that $\operatorname{softmax}(\mathbf{x}) = \exp(\mathbf{x} - \operatorname{logsumexp}(\mathbf{x}))$ (and so $\mathbf{p} = \exp(\mathbf{x} - z)$, and so $p_i = \exp(x_i - z)$), and the previously-derived $\partial z/\partial x_j = p_j$, we have:

$$\frac{\partial p_i}{\partial x_j} = \frac{\partial}{\partial x_j} \exp(x_i - z) = \exp(x_i - z)\left(\frac{\partial x_i}{\partial x_j} - \frac{\partial z}{\partial x_j}\right) = p_i\left(\delta_{ij} - p_j\right),$$

where $\delta_{ij}$ is the Kronecker delta, equal to 1 when $i = j$ and 0 otherwise.
We can write this as a matrix of all-pairs partial derivatives (the Jacobian) as:

$$\frac{\partial \mathbf{p}}{\partial \mathbf{x}} = \operatorname{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^\top.$$
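For a concrete check of this Jacobian, the sketch below (again our own NumPy code) builds $\operatorname{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^\top$ and compares it against central finite differences of the softmax output:

```python
import numpy as np

def softmax(x):
    a = np.max(x)
    e = np.exp(x - a)
    return e / np.sum(e)

def softmax_jacobian(x):
    # J[i, j] = p_i * (delta_ij - p_j), i.e. diag(p) - p p^T.
    p = softmax(x)
    return np.diag(p) - np.outer(p, p)

# Column j of the finite-difference estimate approximates d softmax / d x_j.
rng = np.random.default_rng(1)
x = rng.normal(size=4)
eps = 1e-6
fd = np.column_stack([
    (softmax(x + eps * e) - softmax(x - eps * e)) / (2 * eps)
    for e in np.eye(len(x))
])
assert np.allclose(fd, softmax_jacobian(x), atol=1e-6)
```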
Depending on the application, we may never need to compute this matrix explicitly. A typical use case is computing gradients of a scalar function with respect to some parameters of that function. In machine learning, during the training of a model, that scalar function is a loss function, and reverse-mode automatic differentiation is used to backpropagate losses onto parameters for the next gradient update. Let $f$ denote such a scalar function and let $\partial f/\partial \mathbf{p}$ be given (i.e. working backward from a loss given by $f$, reverse-mode automatic differentiation has arrived at the output of the softmax, for which the gradient is $\partial f/\partial \mathbf{p}$); we can write:

$$\frac{\partial f}{\partial x_j} = \sum_i \frac{\partial f}{\partial p_i}\frac{\partial p_i}{\partial x_j} = \sum_i \frac{\partial f}{\partial p_i}\, p_i\left(\delta_{ij} - p_j\right) = p_j\left(\frac{\partial f}{\partial p_j} - \sum_i p_i\frac{\partial f}{\partial p_i}\right),$$
which in matrix form is:

$$\begin{aligned}
\frac{\partial f}{\partial \mathbf{x}} &= \left(\operatorname{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^\top\right)\frac{\partial f}{\partial \mathbf{p}} \\
&= \mathbf{p} \odot \left(\frac{\partial f}{\partial \mathbf{p}} - \mathbf{p}^\top\frac{\partial f}{\partial \mathbf{p}}\right),
\end{aligned}$$

where $\odot$ denotes element-wise multiplication and the scalar $\mathbf{p}^\top\frac{\partial f}{\partial \mathbf{p}}$ broadcasts across elements.
The last line is a more efficient way of computing the gradient than the first line, in both time and memory, as it avoids computing the whole Jacobian matrix.
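As an illustration of the saving, a reverse-mode sketch might look as follows (our own NumPy code; `softmax_vjp` is a hypothetical helper name). It computes the gradient in $O(n)$ time and memory and compares it against the explicit Jacobian product, which costs $O(n^2)$:

```python
import numpy as np

def softmax(x):
    a = np.max(x)
    e = np.exp(x - a)
    return e / np.sum(e)

def softmax_vjp(x, df_dp):
    # Backpropagate df/dp through the softmax without forming the Jacobian:
    # df/dx = p * (df/dp - p . df/dp), O(n) time and memory.
    p = softmax(x)
    return p * (df_dp - np.dot(p, df_dp))

# Compare with the explicit Jacobian-vector product.
rng = np.random.default_rng(2)
x = rng.normal(size=6)
df_dp = rng.normal(size=6)
p = softmax(x)
J = np.diag(p) - np.outer(p, p)
assert np.allclose(J @ df_dp, softmax_vjp(x, df_dp))
```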