
Fast Enumeration of Sums of Discrete Random Variables

A CUDA kernel.

I’ve been working on an upgrade to the delayed sampling implementation in Birch to remove reference cycles and reduce garbage collector overhead. One culprit for these is the handling of enumeration of sums of discrete random variables. I’ve taken the opportunity to also revisit performance, which we explore here.

The basic question:

  1. given a random variable $x$ on the integers $\{0,\ldots,m-1\}$ with a vector of probabilities $\mathbf{p}$ with $p_i = P(x = i)$, and
  2. another random variable $y$ on the integers $\{0,\ldots,n-1\}$ with a vector of probabilities $\mathbf{q}$ with $q_i = P(y = i)$,
  3. consider the sum $x + y$ on the integers $\{0,\ldots,m+n-2\}$ and compute its vector of probabilities $\mathbf{r}$ with $r_i = P(x + y = i)$.

The basic computation is the convolution:

$$r_i = \sum_{j = \max(0,\,i-n+1)}^{\min(m-1,\,i)} p_j q_{i-j},$$

but the aim is to develop a kernel to compute this efficiently on GPU.
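
For reference, a plain CPU implementation of this convolution (a hypothetical helper, assuming unit strides) is straightforward, and useful for checking the kernel later:

#include <algorithm>

/* reference CPU convolution: r[i] = sum over j of p[j]*q[i - j] */
void add_enumerate_cpu(const int m, const int n, const float* p,
    const float* q, float* r) {
  for (int i = 0; i < m + n - 1; ++i) {
    float sum = 0.0f;
    for (int j = std::max(0, i - n + 1); j <= std::min(m - 1, i); ++j) {
      sum += p[j]*q[i - j];
    }
    r[i] = sum;
  }
}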

Consider the outer product $\mathbf{R} = \mathbf{p}\mathbf{q}^\top$. To obtain $r_i$ we must sum over the $i$th skew diagonal of $\mathbf{R}$. To see this, assume $x$ and $y$ are on $\{0, 1, 2, 3\}$, and consider a diagram of $\mathbf{R}$ (skewed a little):

[Diagram: the matrix $\mathbf{R}$, skewed so that its skew diagonals run horizontally]

On the left are the possible outcomes for $x$ and $y$. Each cell corresponds to a pair of such outcomes, marked with the sum $x + y$ in the diagram, but having the probability of that pair of outcomes, $R_{xy} = p_x q_y$, in actual fact. Pairs that yield the same sum collect along the same skew diagonal (horizontal in this orientation). Summing along those skew diagonals gives the probabilities $r_{x+y}$ on the right.
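
In code, this skew-diagonal view reproduces the convolution above; a minimal serial sketch (assuming unit strides and $\mathbf{r}$ initialized to zero) accumulates each element of the outer product into its skew diagonal:

/* accumulate each element of the outer product into its skew diagonal;
   r must be zero-initialized beforehand */
for (int x = 0; x < m; ++x) {
  for (int y = 0; y < n; ++y) {
    r[x + y] += p[x]*q[y];  /* R[x][y] lies on skew diagonal x + y */
  }
}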

To perform those sums on GPU, we ideally make use of fast warp shuffle functions. This requires threads to be arranged by skew diagonal, not by rows and columns as is more typical. Assume a warp size of four threads (for the sake of demonstration only; it's actually 32 threads). With the $4 \times 4$ matrix here we use 16 threads arranged into four such warps. In hexadecimal notation (i.e. A through F denote the decimal numbers 10 through 15), the four warps contain the threads $(0,1,2,3)$, $(4,5,6,7)$, $(8,9,A,B)$ and $(C,D,E,F)$. They are assigned to elements of the matrix as follows:

[Diagram: assignment of the 16 threads to the elements of the $4 \times 4$ matrix]

With this arrangement, each warp spans two skew diagonals, with the exception of the last, which has only the main skew diagonal. For example, the first warp has thread 0 on the first skew diagonal, with the single pair where $x + y = 0$, and threads 1, 2, 3 on the fifth skew diagonal, having the three pairs where $x + y = 4$. This ordering has a regular pattern; thread $k$ maps to element $(x,y)$ with:

$$\begin{aligned} x &= (\lfloor k/4 \rfloor - (k \bmod 4)) \bmod 4, \\ y &= k \bmod 4. \end{aligned}$$

Alternatively, with two-dimensional thread indices, thread $(i,j)$, with $i = \lfloor k/4 \rfloor$ and $j = k \bmod 4$, maps to element $(x,y)$ with:

$$\begin{aligned} x &= (i - j) \bmod 4, \\ y &= j. \end{aligned}$$

We also designate a single representative thread for each skew diagonal for operations that are performed only once per unique value of $x + y$. These are the threads for which $x \bmod 4 = 0$ or $x \bmod 4 = 3$, excluding thread F.
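
To make the mapping concrete, a small host-side sketch (illustrative only) enumerates the 16 threads of the $4 \times 4$ example, printing the element each thread maps to, its skew diagonal, and whether it is a representative:

#include <cstdio>

int main() {
  const int W = 4;  /* demonstration warp size; 32 in practice */
  for (int k = 0; k < W*W; ++k) {
    const int i = k/W;            /* warp index */
    const int j = k%W;            /* lane within the warp */
    const int x = (i - j + W)%W;  /* row of the assigned element */
    const int y = j;              /* column of the assigned element */
    const bool rep = (x == 0 || x == W - 1) && k != W*W - 1;
    std::printf("thread %X -> (%d,%d), skew diagonal %d%s\n", k, x, y, x + y,
        rep ? " (representative)" : "");
  }
  return 0;
}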

The above all holds for a standard warp size of 32. Just replace 4 with 32 and 3 with 31.

For $m = n = 32$, the computation proceeds as follows:

  1. Initialize r\mathbf{r} to all zero.
  2. Launch a kernel with a two-dimensional execution configuration, of size 32 by 32.
  3. Each thread $(i,j)$ computes $R_{ij} = p_i q_j$ and stores the result in shared memory.
  4. All threads synchronize.
  5. Each thread $(i,j)$ adds $R_{xy}$ to $r_{x+y}$.

The matrix $\mathbf{R}$ is temporarily constructed in fast shared memory; no device memory is required. We can readily extend the kernel to cases where $m$ and $n$ are not 32. Larger matrices can be decomposed into tiles of size 32 by 32. Matrices where the number of rows or columns is not a multiple of 32 can be padded with zeros.

The last step can make use of warp shuffle functions; indeed, this is the intention of the design. Specifically, we can use __shfl_xor_sync as in the butterfly sum example of the CUDA Programming Guide. While there is a __reduce_add_sync warp reduce function that would seem useful instead, it does not support floating point types.
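
For a full warp, the butterfly pattern looks like the following sketch (the function name warp_sum is illustrative only):

/* butterfly sum across a full warp using __shfl_xor_sync; after the loop,
   every lane holds the sum of value over all 32 lanes */
__device__ float warp_sum(float value) {
  for (int i = warpSize/2; i >= 1; i /= 2) {
    value += __shfl_xor_sync(0xffffffff, value, i, warpSize);
  }
  return value;
}

The kernel below uses the same loop, but with the mask restricted to the lanes that lie on the same skew diagonal.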

A sketch of the code (untested) is as follows:

#include <cassert>

/**
 * Compute the probability distribution for a sum of two discrete random
 * variables.
 *
 * @param m Number of values for the first random variable.
 * @param n Number of values for the second random variable.
 * @param p Array of length @p m of probabilities for each value of the first
 * random variable.
 * @param incp Stride between elements of @p p.
 * @param q Array of length @p n of probabilities for each value of the
 * second random variable.
 * @param incq Stride between elements of @p q.
 * @param[out] r Array of length `m + n - 1` of probabilities for each sum.
 * @param incr Stride between elements of @p r.
 */
__global__ void kernel_add_enumerate(const int m, const int n,
    const float* p, const int incp, const float* q, const int incq, float* r,
    const int incr) {
  assert(blockDim.x == warpSize);
  assert(blockDim.y == warpSize);

  /* populate outer product */
  int x = threadIdx.x + blockIdx.x*warpSize;
  int y = threadIdx.y + blockIdx.y*warpSize;

  __shared__ float R[32*32];  /* warpSize is not a compile-time constant, so size with the literal 32 */
  R[threadIdx.x + threadIdx.y*warpSize] = (x < m && y < n) ? p[x]*q[y] : 0.0f;
  __syncthreads();
  
  /* local sum: remap to element (x, y), with threadIdx.y the warp index i and
     threadIdx.x the lane j, so that each warp spans at most two skew
     diagonals */
  x = (threadIdx.y + warpSize - threadIdx.x) % warpSize;
  y = threadIdx.x;

  /* mask of the lanes that lie on the same skew diagonal as this thread:
     lanes 0..(x + y) for a lower diagonal, lanes (x + y + 1 - warpSize)..31
     for an upper one */
  unsigned mask;
  if (x + y < warpSize) {
    mask = 0xffffffffu >> (warpSize - x - y - 1);
  } else {
    mask = 0xffffffffu << (x + y + 1 - warpSize);
  }

  /* butterfly sum over the lanes on this skew diagonal */
  float value = R[x + y*warpSize];
  for (int i = warpSize/2; i >= 1; i /= 2) {
    value += __shfl_xor_sync(mask, value, i, warpSize);
  }

  /* one thread only, for each skew diagonal, updates the global sum; the last
     thread of the block is excluded as its diagonal already has a
     representative */
  if ((x == 0 || x == warpSize - 1) &&
      !(threadIdx.x == warpSize - 1 && threadIdx.y == warpSize - 1)) {
    x += blockIdx.x*warpSize;
    y += blockIdx.y*warpSize;
    atomicAdd(r + (x + y)*incr, value);
  }
}

It can be launched with, for example:

void add_enumerate(const int m, const int n, const float* p,
    const int incp, const float* q, const int incq, float* r,
    const int incr) {
  const int warpSize = 32;  /* warpSize is not available in host code */
  cudaMemset(r, 0, (m + n - 1)*sizeof(float));  /* assumes incr == 1 */
  dim3 block(warpSize, warpSize);
  dim3 grid((m + warpSize - 1)/warpSize, (n + warpSize - 1)/warpSize);
  kernel_add_enumerate<<<grid,block>>>(m, n, p, incp, q, incq, r, incr);
}
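
As a usage sketch (hypothetical setup code, assuming unit strides), the wrapper could be exercised like this:

#include <vector>

int main() {
  const int m = 32, n = 32;
  float *d_p, *d_q, *d_r;
  cudaMalloc(&d_p, m*sizeof(float));
  cudaMalloc(&d_q, n*sizeof(float));
  cudaMalloc(&d_r, (m + n - 1)*sizeof(float));

  /* e.g. two uniform distributions on {0,...,31} */
  std::vector<float> p(m, 1.0f/m), q(n, 1.0f/n);
  cudaMemcpy(d_p, p.data(), m*sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_q, q.data(), n*sizeof(float), cudaMemcpyHostToDevice);

  add_enumerate(m, n, d_p, 1, d_q, 1, d_r, 1);

  std::vector<float> r(m + n - 1);
  cudaMemcpy(r.data(), d_r, r.size()*sizeof(float), cudaMemcpyDeviceToHost);

  cudaFree(d_p);
  cudaFree(d_q);
  cudaFree(d_r);
  return 0;
}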

Further work may include: