In the practicals, we’ve been very lazy about error checking.
Instead, the usual idiom is:
cudaError_t err = cudaMalloc(...);
/* a kernel launch, kernel<<<...>>>(), returns void; check it afterwards
   with err = cudaGetLastError() */
if (err != cudaSuccess) {
  printf("error: %s\n", cudaGetErrorString(err));
}
See Error Handling in the CUDA Runtime API documentation.
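For convenience, this check is often wrapped in a macro. A minimal sketch, where the name CUDA_CHECK is just an illustration and not part of the CUDA API:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

/* check the result of a CUDA runtime call, aborting with a message on failure */
#define CUDA_CHECK(call) \
  do { \
    cudaError_t err_ = (call); \
    if (err_ != cudaSuccess) { \
      fprintf(stderr, "%s:%d: %s\n", __FILE__, __LINE__, \
          cudaGetErrorString(err_)); \
      exit(EXIT_FAILURE); \
    } \
  } while (0)

int main() {
  float* d_x = nullptr;
  CUDA_CHECK(cudaMalloc(&d_x, 1024*sizeof(float)));
  /* kernel launches return void; follow them with
     CUDA_CHECK(cudaGetLastError()) to pick up launch errors */
  CUDA_CHECK(cudaFree(d_x));
  return 0;
}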
In the practicals, we allocated all memory upfront. What about situations where we do not know memory requirements upfront?
One option is cudaMallocAsync and cudaFreeAsync for stream-ordered allocation and freeing of device memory. This is achieved with memory pools.
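A minimal sketch of how these might be used (requires CUDA 11.2 or later; the kernel launch is a placeholder):

#include <cuda_runtime.h>

/* allocate, use and free a scratch buffer, all ordered on one stream */
void scratch_example(int n) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  float* d_x = nullptr;
  cudaMallocAsync((void**)&d_x, n*sizeof(float), stream);  /* from the default memory pool */
  // some_kernel<<<grid, block, 0, stream>>>(d_x, n);      /* placeholder for real work */
  cudaFreeAsync(d_x, stream);                              /* returned to the pool */

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
}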
In the practicals, we used one stream to enqueue kernels for the GPU. We can in fact use many.
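For example, two independent pieces of work can be enqueued on separate streams, allowing them to overlap on the device (the kernels here are placeholders):

#include <cuda_runtime.h>

void two_streams_example() {
  cudaStream_t s1, s2;
  cudaStreamCreate(&s1);
  cudaStreamCreate(&s2);

  /* work enqueued on different streams may execute concurrently */
  // kernel_a<<<grid_a, block_a, 0, s1>>>(/* ... */);  /* placeholder */
  // kernel_b<<<grid_b, block_b, 0, s2>>>(/* ... */);  /* placeholder */

  cudaStreamSynchronize(s1);
  cudaStreamSynchronize(s2);
  cudaStreamDestroy(s1);
  cudaStreamDestroy(s2);
}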
In the practicals, we used cudaDeviceSynchronize() to wait for the GPU to complete work enqueued in a stream. Events facilitate more fine-grained synchronization (a sketch of how they fit together follows the list):
- cudaEventCreate() and cudaEventDestroy() to create and destroy events.
- cudaEventRecord() to record an event in a stream.
- cudaEventSynchronize() to have the host wait for an event to complete.
- cudaStreamWaitEvent() to have a stream wait for an event recorded in another stream.
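A sketch of how these fit together, making one stream wait for work enqueued in another (the kernels are placeholders):

#include <cuda_runtime.h>

void event_example() {
  cudaStream_t producer, consumer;
  cudaEvent_t ready;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);
  cudaEventCreate(&ready);

  // produce<<<grid, block, 0, producer>>>(/* ... */);  /* placeholder */
  cudaEventRecord(ready, producer);         /* mark this point in the producer stream */
  cudaStreamWaitEvent(consumer, ready, 0);  /* consumer stream waits for that point */
  // consume<<<grid, block, 0, consumer>>>(/* ... */);  /* placeholder */

  cudaEventSynchronize(ready);  /* the host can also wait on the event */

  cudaEventDestroy(ready);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
}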
In the practicals, we only looked at kernels that perform simple transformations, where an operation is applied element-wise to a vector or matrix, and threads do not need to interact. Kernel code can nonetheless contain the usual control flow: conditionals (if, else), loops (for, while) and function calls.
When might we need to write our own custom kernel?
One such case arises in automatic marginalization in Birch.
The problem: given two independent discrete random variables x and y with probability vectors p (of length m) and q (of length n), compute the probability vector r of their sum x + y.
The basic computation is:
$$r_i = \sum_{j=\max(0,\,i-n+1)}^{\min(m-1,\,i)} p_j q_{i-j}.$$
We can write this in matrix form as:
$$r = \begin{pmatrix}
p_0     & 0       & \cdots & 0       \\
\vdots  & p_0     & \ddots & \vdots  \\
p_{m-1} & \vdots  & \ddots & 0       \\
0       & p_{m-1} & \ddots & p_0     \\
\vdots  & 0       & \ddots & \vdots  \\
0       & \cdots  & 0      & p_{m-1}
\end{pmatrix} q,$$
where there are n columns (and m + n - 1 rows).
Or, illustrated for m = n = 4:
$$\begin{pmatrix}
0 &   &   &   \\
1 & 0 &   &   \\
2 & 1 & 0 &   \\
3 & 2 & 1 & 0 \\
  & 3 & 2 & 1 \\
  &   & 3 & 2 \\
  &   &   & 3
\end{pmatrix}
\begin{pmatrix}
0 \\ 1 \\ 2 \\ 3
\end{pmatrix}$$
The numbers here do not represent the values of the elements, but rather the associated x (in the matrix) and y (in the vector). We can see that the matrix-vector product accumulates the probabilities of pairs of x and y that sum to the same outcome: row i combines every pair with x + y = i.
This is the product of a banded matrix and a vector.
BLAS provides gbmv for the multiplication of banded matrices and vectors. Matrices are passed to BLAS routines as a pointer (e.g. A) and a lead (e.g. ldA), the stride between the start of successive columns. For this banded matrix we could, in principle, store p just once and use a lead of -1: each column would then start one element earlier than the previous, giving exactly the shift-down-by-one pattern above.
Unfortunately, as with most (all?) BLAS implementations, the lead must be positive and at least as large as the number of rows in the matrix or an error is given. So we’ll need a custom kernel!
The length of r is m+n−1. We can assign one thread to compute each element of r.
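For comparison, the serial computation on the host is just the double loop implied by the formula above (a sketch, assuming unit strides):

/* r has length m + n - 1 */
void enumerate_host(const int m, const int n, const float* p, const float* q,
    float* r) {
  for (int i = 0; i < m + n - 1; ++i) {
    float result = 0.0f;
    for (int j = 0; j < n; ++j) {
      if (0 <= i - j && i - j < m) {
        result += p[i - j]*q[j];
      }
    }
    r[i] = result;
  }
}

In CUDA, the outer loop over i becomes the thread index: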
__global__ void kernel_enumerate(const int m, const int n, const float* p,
    const int incp, const float* q, const int incq, float* r,
    const int incr) {
  /* element of r for which this thread is responsible */
  int i = threadIdx.y + blockIdx.y*blockDim.y;

  /* sum across elements in each thread */
  float result = 0.0f;
  for (int j = 0; j < n; ++j) {
    if (0 <= i - j && i - j < m) {
      result += p[(i - j)*incp]*q[j*incq];
    }
  }

  /* write element */
  r[i*incr] = result;
}
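A possible launch configuration, assuming unit strides; the kernel indexes threads in the y dimension, so the grid is sized in y to cover all m + n - 1 elements (for simplicity, assume m + n - 1 is a multiple of the block size, otherwise add a bounds check to the kernel):

dim3 block(1, 256);
dim3 grid(1, (m + n - 1)/block.y);
kernel_enumerate<<<grid,block>>>(m, n, p, 1, q, 1, r, 1);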
Next issue: Each thread has a different amount of work to do, leading to warp divergence.
The fact that warps can diverge makes programming kernels much easier, but reducing warp divergence will improve performance.
Consider the following transformation of the problem (for simplicity, assume m = n, as in the illustration above): each thread accumulates two elements of the output, r[i] and r[i + m], so that every thread performs the same n loop iterations:
__global__ void kernel_enumerate(const int m, const int n, const float* p,
    const int incp, const float* q, const int incq, float* r,
    const int incr) {
  /* first element of r for which this thread is responsible */
  int i = threadIdx.y + blockIdx.y*blockDim.y;

  /* sum across elements in each thread */
  float result1 = 0.0f, result2 = 0.0f;
  for (int j = 0; j < n; ++j) {
    if (0 <= i - j) {
      result1 += p[(i - j)*incp]*q[j*incq];
    } else {
      result2 += p[(m + i - j)*incp]*q[j*incq];
    }
  }

  /* write elements */
  r[i*incr] = result1;
  if (i < m - 1) {
    r[(i + m)*incr] = result2;
  }
}
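Only m threads are now needed, one per pair of output elements, so the launch shrinks accordingly (again assuming unit strides and that m is a multiple of the block size):

dim3 block(1, 256);
dim3 grid(1, m/block.y);
kernel_enumerate<<<grid,block>>>(m, n, p, 1, q, 1, r, 1);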
Next issue: Each thread reads the whole vector q from device memory.
We can use shared memory so that q is read from device memory only once per thread block, i.e. threads in the same block can share a single read.
The split between L1 cache and shared memory can be configured with cudaDeviceSetCacheConfig().

__global__ void kernel_enumerate(const int m, const int n, const float* p,
    const int incp, const float* q, const int incq, float* r,
    const int incr) {
  /* dynamically-sized shared memory, provided at kernel launch */
  extern __shared__ float q_shared[];

  /* first element of r for which this thread is responsible */
  int i = threadIdx.y + blockIdx.y*blockDim.y;

  float result1 = 0.0f, result2 = 0.0f;
  for (int base = 0; base < n; base += blockDim.y) {
    /* read the next batch of q into shared memory */
    int j = threadIdx.y;
    q_shared[j] = base + j < n ? q[(base + j)*incq] : 0.0f;
    __syncthreads();

    /* sum across elements in each thread; base + j is the index into q */
    for (j = 0; j < blockDim.y; ++j) {
      if (base + j <= i) {
        result1 += p[(i - base - j)*incp]*q_shared[j];
      } else if (base + j < n) {
        result2 += p[(m + i - base - j)*incp]*q_shared[j];
      }
    }

    /* ensure all threads have finished with this batch before overwriting it */
    __syncthreads();
  }

  /* set the final result */
  r[i*incr] = result1;
  if (i < m - 1) {
    r[(i + m)*incr] = result2;
  }
}
When calling the kernel, a third launch argument specifies the amount of dynamic shared memory required: kernel_enumerate<<<grid,block,shared>>>(...).
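For this kernel the shared allocation holds one batch of q, i.e. blockDim.y floats; a sketch:

dim3 block(1, 256);
dim3 grid(1, m/block.y);                /* assumes m is a multiple of block.y */
size_t shared = block.y*sizeof(float);  /* one batch of q per block */
kernel_enumerate<<<grid,block,shared>>>(m, n, p, 1, q, 1, r, 1);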
Next issue: The kernel offers m-way parallelism, which may not be enough to occupy the GPU.
The current kernel uses a thread per two elements of the output, giving m-way parallelism. If we use a warp per two elements of the output, we can increase this to 32m-way parallelism.
Each thread of the warp computes a partial sum, then the threads interact to compute the total sum. We have a few options for this second step: stage the partial sums in shared memory and combine them after a __syncthreads(), or exchange them directly between registers using warp shuffle intrinsics such as __shfl_xor_sync(). The kernel below takes the latter approach, as a butterfly sum:

#include <cassert>  /* for the device-side assert below */

__global__ void kernel_enumerate(const int m, const int n,
    const float* p, const int incp, const float* q, const int incq, float* r,
    const int incr) {
  /* this kernel assumes one warp in the x dimension per block */
  assert(blockDim.x == warpSize && gridDim.x == 1);

  /* dynamically-sized shared memory, provided at kernel launch */
  extern __shared__ float q_shared[];

  /* first element of r for which this warp is responsible */
  int i = threadIdx.y + blockIdx.y*blockDim.y;

  float result1 = 0.0f, result2 = 0.0f;
  for (int base = 0; base < n; base += warpSize*blockDim.y) {
    /* read the next batch of q into shared memory */
    int j = threadIdx.y*warpSize + threadIdx.x;
    q_shared[j] = base + j < n ? q[(base + j)*incq] : 0.0f;
    __syncthreads();

    /* sum across the batch, each thread of the warp taking every 32nd element;
       base + j is the index into q */
    for (j = threadIdx.x; j < warpSize*blockDim.y; j += warpSize) {
      if (base + j <= i) {
        result1 += p[(i - base - j)*incp]*q_shared[j];
      } else if (base + j < n) {
        result2 += p[(m + i - base - j)*incp]*q_shared[j];
      }
    }

    /* ensure all threads have finished with this batch before overwriting it */
    __syncthreads();
  }

  /* sum across threads of the warp, using a butterfly sum */
  for (int k = 16; k >= 1; k /= 2) {
    result1 += __shfl_xor_sync(0xffffffff, result1, k, warpSize);
    result2 += __shfl_xor_sync(0xffffffff, result2, k, warpSize);
  }

  /* set the final result, only the first thread in each warp */
  if (threadIdx.x == 0) {
    r[i*incr] = result1;
    if (i < m - 1) {
      r[(i + m)*incr] = result2;
    }
  }
}
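A possible launch for this version uses 32 threads in the x dimension (one warp per pair of output elements) and puts the remaining parallelism in y; the shared allocation grows to one batch of warpSize*blockDim.y floats (a sketch, assuming unit strides and that m is a multiple of block.y):

dim3 block(32, 8);                              /* one warp per pair of output elements */
dim3 grid(1, m/block.y);
size_t shared = block.x*block.y*sizeof(float);  /* one batch of q per block */
kernel_enumerate<<<grid,block,shared>>>(m, n, p, 1, q, 1, r, 1);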
Finally, some tips for debugging and numerical accuracy:
- Set the environment variable CUDA_LAUNCH_BLOCKING=1 so that kernel launches become synchronous (blocking).
- Call cudaDeviceSynchronize() after any kernel launch and print out values for debugging.
- Use managed memory, allocated with cudaMallocManaged(), rather than device memory allocated with cudaMalloc(); it avoids the nuisance of copying values back to host memory first.
- Use log-sum-exp or log-softmax operations instead of log(sum(exp(x))) and log(exp(x)/sum(exp(x))); a sketch of the trick follows the list.
- Use log1p and expm1 instead of log(1.0 + x) and exp(x) - 1.0.
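As a reminder of the trick behind log-sum-exp: subtract the maximum before exponentiating, so the largest term becomes exp(0) = 1 and nothing overflows (a host-side sketch):

#include <cmath>

/* numerically stable log(sum(exp(x))) over n elements */
float logsumexp(const float* x, const int n) {
  float mx = x[0];
  for (int i = 1; i < n; ++i) {
    if (x[i] > mx) {
      mx = x[i];
    }
  }
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    sum += expf(x[i] - mx);
  }
  return mx + logf(sum);
}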