Foundations of GPU Computing: Closing Lecture

Lawrence Murray

Outline

  1. Tie up some loose threads from the practicals.
  2. A taste of more advanced kernel programming.

This lecture is part of the course Foundations of GPU Computing.

Error checking

In the practicals, we’ve been very lazy about error checking.

Instead, the usual idiom is:

cudaError_t err = cudaMalloc(...);  // or any other CUDA runtime API call
if (err != cudaSuccess) {
  printf("error: %s\n", cudaGetErrorString(err));
}
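
Kernel launches do not return an error code themselves; a sketch of the usual pattern is to check for launch errors with cudaGetLastError(), and for asynchronous execution errors after a synchronization:

kernel<<<grid,block>>>(...);
cudaError_t err = cudaGetLastError();  // did the launch itself fail?
if (err != cudaSuccess) {
  printf("launch error: %s\n", cudaGetErrorString(err));
}
err = cudaDeviceSynchronize();         // did execution of the kernel fail?
if (err != cudaSuccess) {
  printf("execution error: %s\n", cudaGetErrorString(err));
}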

See Error Handling in the CUDA Runtime API documentation.

Streaming memory allocation

In the practicals, we allocated all memory upfront. What about situations where we do not know memory requirements upfront?

  • CUDA provides cudaMallocAsync and cudaFreeAsync for stream-ordered allocation of device memory, implemented with memory pools.
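
A minimal sketch of stream-ordered allocation (assuming a stream has already been created; the kernel and the size n are purely illustrative):

float* x = nullptr;
cudaMallocAsync(&x, n*sizeof(float), stream);  // allocation ordered in the stream
kernel<<<grid,block,0,stream>>>(x, n);         // hypothetical kernel using x
cudaFreeAsync(x, stream);                      // freed once prior work in the stream completes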

Some alternatives:

  • RAPIDS, also with memory pools.
  • Birch, via its numerical library NumBirch, integrates with jemalloc for pools, arenas and thread caching, and works with managed memory.

Multiple streams

In the practicals, we used one stream to enqueue kernels for the GPU. We can in fact use many.

  • Multiple streams can execute concurrently, allowing kernels to run concurrently (see the sketch below).
  • When we say that a GPU can do task parallelism, not just data parallelism, this is mostly what we mean.
  • Where individual kernels are not large enough to occupy the GPU on their own, launching several concurrently may be enough to do so.
  • Multiple streams can also be used to run on multiple GPUs.
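
As a sketch, two streams whose kernels are independent and so may execute concurrently (kernel1 and kernel2 are hypothetical):

cudaStream_t stream1, stream2;
cudaStreamCreate(&stream1);
cudaStreamCreate(&stream2);

kernel1<<<grid,block,0,stream1>>>(...);  // these two launches are independent
kernel2<<<grid,block,0,stream2>>>(...);  // and may run concurrently on the GPU

cudaStreamSynchronize(stream1);          // wait for each stream's work to finish
cudaStreamSynchronize(stream2);
cudaStreamDestroy(stream1);
cudaStreamDestroy(stream2);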

Events

In the practicals, we used cudaDeviceSynchronize() to wait for the GPU to complete work enqueued in a stream. Events facilitate more fine-grained synchronization.

  • We create and destroy events with cudaEventCreate() and cudaEventDestroy().
  • We record an event in a stream with cudaEventRecord().
  • We can have the CPU wait for the event with cudaEventSynchronize().
  • We can have another stream wait for the event with cudaStreamWaitEvent().
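
A sketch of one stream waiting on work enqueued in another (the producer and consumer kernels are hypothetical):

cudaEvent_t event;
cudaEventCreate(&event);

producer<<<grid,block,0,stream1>>>(...);  // enqueue work in stream1
cudaEventRecord(event, stream1);          // record an event after it
cudaStreamWaitEvent(stream2, event, 0);   // stream2 waits for the event...
consumer<<<grid,block,0,stream2>>>(...);  // ...before this kernel runs

cudaEventSynchronize(event);              // the CPU can wait on the event too
cudaEventDestroy(event);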

Advanced kernel programming

In the practicals, we only looked at kernels that perform simple transformations, where an operation is applied element-wise to a vector or matrix, and threads do not need to interact.

  • We may be interested in kernels that involve interacting threads, such as reductions, scans, sorts and other such primitive operations. The shuffle that we used in the practicals is one such kernel, although we avoided implementing it for GPU.
  • We can write more complex transformations that involve arbitrary C++ code, including conditionals (if, else), loops (for, while) and function calls.

When might we need to write our own custom kernel?

Example: sum of two discrete random variables

This is needed for automatic marginalization in Birch.

The problem:

  1. given a random variable $x$ on the integers $\{0,\ldots,m-1\}$ with a vector of probabilities $\mathbf{p}$ with $p_i = P(x = i)$, and
  2. another random variable $y$ on the integers $\{0,\ldots,n-1\}$ with a vector of probabilities $\mathbf{q}$ with $q_i = P(y = i)$,
  3. consider the sum $x + y$ on the integers $\{0,\ldots,m+n-2\}$ and compute its vector of probabilities $\mathbf{r}$ with $r_i = P(x + y = i)$.

The basic computation:

$$r_i = \sum_{j = \max(0,i-n+1)}^{\min(m-1,i)} p_j q_{i-j}.$$
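
For example, with $m = n = 2$ and $\mathbf{p} = \mathbf{q} = (\tfrac{1}{2}, \tfrac{1}{2})$, this gives $\mathbf{r} = (p_0 q_0,\; p_0 q_1 + p_1 q_0,\; p_1 q_1) = (\tfrac{1}{4}, \tfrac{1}{2}, \tfrac{1}{4})$.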

Example: sum of two discrete random variables

We can write this in matrix form as:

$$\mathbf{r} = \left(\begin{array}{ccccc}
p_0 & 0 & \ldots & \ldots & \ldots \\
\vdots & p_0 & 0 & \ldots & \ldots \\
p_{m-1} & \vdots & p_0 & 0 & \ldots \\
0 & p_{m-1} & \vdots & \ddots & 0 \\
\vdots & 0 & p_{m-1} & \vdots & p_0 \\
\vdots & \vdots & 0 & \ddots & \vdots \\
\vdots & \vdots & \vdots & 0 & p_{m-1} \\
\end{array}\right) \mathbf{q},$$

where the matrix has $n$ columns.

Example: sum of two discrete random variables

Or illustrate it for $m = n = 4$:

Diagram

The numbers here do not represent the values of the elements, but rather the associated values of $x$ and $y$. We can see that the matrix-vector product accumulates the probabilities of pairs of $x$ and $y$ that sum to the same outcome.

This is the product of a banded matrix and a vector.

Example: sum of two discrete random variables

  • cuBLAS provides the function family gbmv for the multiplication of banded matrices and vectors.
  • Recall from the practicals that we specify a matrix with a base pointer (e.g. A) and lead (e.g. ldA).
  • For this banded matrix we could use a lead of -1: with column-major storage, element $(i,j)$ lies at offset $i + j\,\mathrm{ldA}$ from the base pointer, so with base pointer p and lead -1 that offset is $i - j$, i.e. $p_{i-j}$. Diagram

  • Off-band elements are ignored, so it does not matter if they are invalid.

Unfortunately, as with most (all?) BLAS implementations, the lead must be positive and at least as large as the number of rows in the matrix or an error is given. So we’ll need a custom kernel!

Version 0

The length of $\mathbf{r}$ is $m + n - 1$. We can assign one thread to compute each element of $\mathbf{r}$.

__global__ void kernel_enumerate(const int m, const int n, const float* p,
    const int incp, const float* q, const int incq, float* r,
    const int incr) {
  /* element of r for which thread is responsible */
  int i = threadIdx.y + blockIdx.y*blockDim.y;
  
  /* sum across elements in each thread */
  float result = 0.0f;
  for (int j = 0; j < n; ++j) {
    if (0 <= i - j && i - j < m) {
      result += p[(i - j)*incp]*q[j*incq];
    }
  }
  
  /* write element, guarding against excess threads in the last block */
  if (i < m + n - 1) {
    r[i*incr] = result;
  }
}
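
A launch of this kernel might look like the following sketch (256 threads per block in the y dimension is an arbitrary choice; unit increments are assumed, and any excess threads in the last block are handled by the bounds check on the write):

dim3 block(1, 256);
dim3 grid(1, (m + n - 1 + block.y - 1)/block.y);  // enough threads for all m + n - 1 elements
kernel_enumerate<<<grid,block>>>(m, n, p, 1, q, 1, r, 1);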

Next issue: Each thread has a different amount of work to do, leading to warp divergence.

Warp divergence

  • The threads in a warp can execute only one common instruction at a time.
  • If threads take different execution paths, those execution paths must be serialized, e.g.
    • Threads split between true and false branches of a conditional: first threads on the true branch execute then threads on the false branch execute.
    • Threads have different trip counts for a loop: each thread stops after its trip count is reached while others may continue.
  • This is called warp divergence.

The fact that warps can diverge makes programming kernels much easier, but reducing warp divergence will improve performance.
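
For example, a branch on the thread index forces the two paths to execute one after the other within a warp (x, y, f and g are hypothetical here):

if (threadIdx.x % 2 == 0) {
  y[i] = f(x[i]);  /* even lanes execute while odd lanes idle */
} else {
  y[i] = g(x[i]);  /* then odd lanes execute while even lanes idle */
}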

Version 1: Reducing warp divergence

Consider the following transformation of the problem: Diagram

  • We now use $m$ threads, where each computes two elements of $\mathbf{r}$ (except thread $m - 1$, which computes only one); as drawn this assumes $m = n$, so that the $2m - 1$ computed elements match the $m + n - 1$ elements of $\mathbf{r}$.
  • Each thread now has an equal amount of work to do.

Version 1: Reducing warp divergence

__global__ void kernel_enumerate(const int m, const int n, const float* p,
    const int incp, const float* q, const int incq, float* r,
    const int incr) {
  /* first element of r for which thread is responsible */
  int i = threadIdx.y + blockIdx.y*blockDim.y;
  
  /* sum across elements in each thread */
  float result1 = 0.0f, result2 = 0.0f;
  for (int j = 0; j < n; ++j) {
    if (0 <= i - j) {
      result1 += p[(i - j)*incp]*q[j*incq];
    } else {
      result2 += p[(m + i - j)*incp]*q[j*incq];
    }
  }
  
  /* write elements */
  r[i*incr] = result1;
  if (i < m - 1) {
    r[(i + m)*incr] = result2;
  }
}

Next issue: Each thread reads the whole vector q\mathbf{q}q from device memory.

Shared memory

We can use shared memory to read $\mathbf{q}$ from device memory only once per thread block, i.e. threads in the same block can share a single read.

  • Within a thread block, shared memory can be used for communication between threads.
  • Shared memory is fast: it occupies a portion of the same on-chip memory as the L1 cache, and the split between the two can be configured with cudaDeviceSetCacheConfig().
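
Shared memory can be declared statically, when the size is known at compile time, or dynamically, when it is only known at launch time; a brief sketch of the two forms:

__shared__ float fixed[256];        /* static: size fixed at compile time */
extern __shared__ float dynamic[];  /* dynamic: size given as the third launch argument */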

Version 2: Using shared memory

__global__ void kernel_enumerate(const int m, const int n, const float* p,
    const int incp, const float* q, const int incq, float* r,
    const int incr) {
  extern __shared__ float q_shared[];  /* dynamic shared memory, sized at launch */
  int i = threadIdx.y + blockIdx.y*blockDim.y;
  float result1 = 0.0f, result2 = 0.0f;
  
  for (int base = 0; base < n; base += blockDim.y) {
    /* read the next batch into shared memory */
    int j = threadIdx.y;
    q_shared[j] = base + j < n ? q[(base + j)*incq] : 0.0f;
    __syncthreads();

    /* sum across elements in each thread */
    for (j = 0; j < blockDim.y; ++j) {
      if (0 <= i - base - j) {
        result1 += p[(i - base - j)*incp]*q_shared[j];
      } else {
        result2 += p[(m + i - base - j)*incp]*q_shared[j];
      }
    }

    /* wait for all threads before overwriting the batch */
    __syncthreads();
  }

  /* set the final result */
  r[i*incr] = result1;
  if (i < m - 1) {
    r[(i + m)*incr] = result2;
  }
}

When calling the kernel, a third argument is given in the execution configuration to specify the amount of shared memory required, in bytes: kernel_enumerate<<<grid,block,shared>>>(...).
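
For example, a launch of this kernel might look like the following sketch (256 threads per block is an arbitrary choice, and for simplicity $m$ is assumed to be a multiple of the block size):

dim3 block(1, 256);
dim3 grid(1, m/block.y);                // one thread per pair of output elements
size_t shared = block.y*sizeof(float);  // one batch of q in shared memory
kernel_enumerate<<<grid,block,shared>>>(m, n, p, 1, q, 1, r, 1);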

Next issue: The kernel offers $m$-way parallelism, which may not be enough to occupy the GPU.

Increasing parallelism

The current kernel uses a thread per two elements of the output, giving $m$-way parallelism. If we use a warp per two elements of the output, we can increase this to $32m$-way parallelism.

Each thread of the warp computes a partial sum, then they interact to compute the total sum. We have a few options for the second step:

  • Synchronization functions, the most basic of which is __syncthreads().
  • Fence functions
  • Atomic functions
  • Warp shuffle functions
  • Warp reduce functions

Version 3: Increasing parallelism

#include <cassert>

__global__ void kernel_enumerate(const int m, const int n,
    const float* p, const int incp, const float* q, const int incq, float* r,
    const int incr) {
  assert(blockDim.x == warpSize && gridDim.x == 1);

  extern __shared__ float q_shared[];  /* dynamic shared memory, sized at launch */
  int i = threadIdx.y + blockIdx.y*blockDim.y;  /* output pair for this warp */
  float result1 = 0.0f, result2 = 0.0f;
  
  for (int base = 0; base < n; base += warpSize*blockDim.y) {
    /* read the next batch into shared memory, one element per thread */
    int j = threadIdx.y*warpSize + threadIdx.x;
    q_shared[j] = base + j < n ? q[(base + j)*incq] : 0.0f;
    __syncthreads();

    /* accumulate partial sums, the threads of the warp striding through the batch */
    for (j = threadIdx.x; j < warpSize*blockDim.y; j += warpSize) {
      if (0 <= i - base - j) {
        result1 += p[(i - base - j)*incp]*q_shared[j];
      } else {
        result2 += p[(m + i - base - j)*incp]*q_shared[j];
      }
    }

    /* wait for all threads before overwriting the batch */
    __syncthreads();
  }

  /* sum across threads of warp, using butterfly sum */
  for (int k = 16; k >= 1; k /= 2) {
    result1 += __shfl_xor_sync(0xffffffff, result1, k, warpSize);
    result2 += __shfl_xor_sync(0xffffffff, result2, k, warpSize);
  }

  /* set the final result, only first thread in each warp */
  if (threadIdx.x == 0) {
    r[i*incr] = result1;
    if (i < m - 1) {
      r[(i + m)*incr] = result2;
    }
  }
}
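
A corresponding launch might look like the following sketch (8 warps per block is an arbitrary choice, and $m$ is again assumed to be a multiple of the number of warps per block):

dim3 block(32, 8);                              // one warp in x, 8 warps in y
dim3 grid(1, m/block.y);                        // one warp per pair of output elements
size_t shared = block.x*block.y*sizeof(float);  // one batch of q in shared memory
kernel_enumerate<<<grid,block,shared>>>(m, n, p, 1, q, 1, r, 1);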

Summary

Version 0
We started with a basic implementation of a kernel to enumerate a sum of two discrete random variables.
Version 1
We improved performance by reducing warp divergence.
Version 2
We improved performance by using shared memory to avoid repeated reads from device memory.
Version 3
We increased parallelism by having whole warps, rather than single threads, share the computation of each element of the output.

General tips

  • If results are different with and without synchronization, there is certainly a problem! You can debug such issues by setting the environment variable CUDA_LAUNCH_BLOCKING=1, which makes all kernel launches synchronous.
  • For more focused interventions, insert cudaDeviceSynchronize() after any kernel launch and print out values for debugging.
  • Printing those values is much easier if you stick to using managed memory with cudaMallocManaged() rather than device memory with cudaMalloc(). It avoids the nuisance of copying values back to host memory first.
  • Using single precision during development can be a good way to detect and debug numerical issues, even if intending to use double precision for actual runs.
  • Given the lack of extended precision, and especially if tempted by single precision, be careful with numerics:
    • Look for opportunities to use log-sum-exp or log-softmax operations instead of log(sum(exp(x))) and log(exp(x)/sum(exp(x))).
    • Look for opportunities to use log1p and expm1 instead of log(1.0 + x) and exp(x) - 1.0 (see the example after this list).
    • Avoid large sums.
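
To illustrate the log1p point above, a minimal single-precision example (with <cstdio> and <cmath> included):

float x = 1.0e-8f;
printf("%e\n", logf(1.0f + x));  // prints 0: 1.0f + x rounds to 1.0f in single precision
printf("%e\n", log1pf(x));       // prints approximately 1.0e-8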

Further reading

  • The CUDA Programming Guide is worth browsing.
  • The Nvidia Technical Blog has some good tutorial style articles on more specific topics.
  • My work at indii.org and birch.sh.