
Foundations of GPU Computing: Practical Exercises #1

In this practical, we work with C code that trains a deep neural network to predict trips in a bike sharing network. Both the C code and data set are provided for you. We review the code to understand the implementation, investigate performance by tweaking basic configuration options, and explore some essentials of CUDA programming such as memory management and kernel launches.

This practical is part of the course Foundations of GPU Computing.

Not familiar with C?

You will not be required to write any C code from scratch, only read and make guided modifications. Tips are provided along the way.

  1. Prerequisites
  2. Introduction (10 minutes)
  3. Orientation (10 minutes)
  4. Building and running (10 minutes)
  5. Simple tests (10 minutes)
    1. Kernels
    2. Streams
    3. Floating point precision
    4. Cleaning up
  6. Exploration (20 minutes)
    1. Building
    2. Floating point literals
    3. Memory allocation
    4. Kernels
    5. Model forward and backward passes
  7. Summary
  8. Possible extensions

Prerequisites

Make sure that your system satisfies the prerequisites given on the main course page. We will use Visual Studio Code, but not Nsight Systems, in this practical.

Start Visual Studio Code and, if you plan to run code on a remote machine, connect to it.

Introduction (10 minutes)

The example code implements a simple feed-forward neural network to predict hourly trips of the Capital bike sharing scheme in Washington D.C. It implements both the forward and backward passes (there is no automatic differentiation!), an Adam optimizer, and the training and testing loop. The data (provided) gives the number of trips for each hour of 2019, aggregated over the entire city. Features include weather and sunlight.

The model can be described mathematically as follows. For each hour we predict the number of trips $y$ from features $\mathbf{x}$ using a neural network with $L$ layers. The forward pass is given by:

$$\begin{aligned} \mathbf{z}_{1} & =\mathbf{W}_{1}\mathbf{x}+\mathbf{b}_{1}\\ \mathbf{z}_{l} & =\mathbf{W}_{l}\mathbf{z}_{l-1}^{+}+\mathbf{b}_{l}\text{ for }l=2,\ldots,L, \end{aligned}$$

where superscript $+$ denotes elementwise rectification, $z^{+}=\max(z,0)$, and the final layer $\mathbf{z}_L$ (which will actually be a scalar, so we also write $z_L$) forms the prediction (we reserve $y$ for the actual observation). The parameters are the weight matrices $\mathbf{W}_l$ and bias vectors $\mathbf{b}_l$ for $l=1,\ldots,L$.
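The forward pass can be sketched in plain C for a single data point. This is a minimal sketch and not the implementation in src/model.c. It assumes $L = 2$, two input features, two hidden units, and weights stored column-major. The function names dense(), rectify_inplace() and forward2() are hypothetical.

```c
#include <stddef.h>

/* z = W*x + b for an m by n matrix W stored column-major,
 * so that element (i, j) of W is W[i + m*j]. */
static void dense(size_t m, size_t n, const float* W, const float* x,
                  const float* b, float* z) {
  for (size_t i = 0; i < m; ++i) {
    float sum = b[i];
    for (size_t j = 0; j < n; ++j) {
      sum += W[i + m*j]*x[j];
    }
    z[i] = sum;
  }
}

/* Elementwise rectification in place: z[i] = max(z[i], 0). */
static void rectify_inplace(size_t n, float* z) {
  for (size_t i = 0; i < n; ++i) {
    z[i] = z[i] > 0.0f ? z[i] : 0.0f;
  }
}

/* Forward pass with L = 2: z1 = W1*x + b1, then z2 = W2*z1^+ + b2.
 * z1plus receives z1^+, which the backward pass needs; returns the scalar z2. */
static float forward2(const float W1[4], const float b1[2],
                      const float W2[2], const float b2[1],
                      const float x[2], float z1plus[2]) {
  float z2;
  dense(2, 2, W1, x, b1, z1plus);
  rectify_inplace(2, z1plus);
  dense(1, 2, W2, z1plus, b2, &z2);
  return z2;
}
```

Take W1 = {1, -1, 2, 1} (rows (1, 2) and (-1, 1)), b1 = {0, -0.5}, W2 = {0.5, 2}, b2 = {1} and x = {1, 1}. The hidden layer is then z1 = (3, -0.5), rectification gives (3, 0), and the prediction is 0.5*3 + 1 = 2.5.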

We train the model by minimizing the mean of the squared error $(z_L - y)^2$ over all data points. The backward pass computes the gradient of this loss with respect to the parameters. Recall that, for a scalar function with vector argument, $f(\mathbf{x}) \rightarrow \mathbb{R}$, given $\mathbf{y}=\mathbf{A}\mathbf{x}$ and $\partial f/\partial\mathbf{y}$, we have:

$$\begin{aligned} \frac{\partial f}{\partial\mathbf{x}} & =\mathbf{A}^{\top}\frac{\partial f}{\partial\mathbf{y}}\\ \frac{\partial f}{\partial\mathbf{A}} & =\frac{\partial f}{\partial\mathbf{y}}\mathbf{x}^{\top}. \end{aligned}$$
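These identities are easy to check numerically. The sketch below makes the illustrative choice $f(\mathbf{y}) = \tfrac{1}{2}\sum_i y_i^2$, for which $\partial f/\partial\mathbf{y} = \mathbf{y}$. It compares both formulas against central finite differences. The matrix sizes and function names are hypothetical.

```c
#include <string.h>

enum { NROWS = 2, NCOLS = 3 };  /* A is NROWS by NCOLS, stored column-major */

/* f(A, x) = 0.5*sum_i y_i^2 with y = A*x. */
static double f_of(const double A[NROWS*NCOLS], const double x[NCOLS]) {
  double s = 0.0;
  for (int i = 0; i < NROWS; ++i) {
    double y = 0.0;
    for (int j = 0; j < NCOLS; ++j) y += A[i + NROWS*j]*x[j];
    s += 0.5*y*y;
  }
  return s;
}

static double abs_diff(double a, double b) { return a > b ? a - b : b - a; }

/* Largest discrepancy between the analytic gradients
 *   df/dx = A^T (df/dy)  and  df/dA = (df/dy) x^T,  with df/dy = y,
 * and central finite differences with step h. */
static double max_identity_error(const double A[NROWS*NCOLS],
                                 const double x[NCOLS], double h) {
  double y[NROWS], err = 0.0;
  for (int i = 0; i < NROWS; ++i) {
    y[i] = 0.0;
    for (int j = 0; j < NCOLS; ++j) y[i] += A[i + NROWS*j]*x[j];
  }
  for (int j = 0; j < NCOLS; ++j) {
    double gx = 0.0;                          /* (A^T y)_j */
    for (int i = 0; i < NROWS; ++i) gx += A[i + NROWS*j]*y[i];
    double xp[NCOLS], xm[NCOLS];
    memcpy(xp, x, sizeof xp);
    memcpy(xm, x, sizeof xm);
    xp[j] += h;
    xm[j] -= h;
    double num = (f_of(A, xp) - f_of(A, xm))/(2.0*h);
    double e = abs_diff(num, gx);
    if (e > err) err = e;
    for (int i = 0; i < NROWS; ++i) {
      double Ap[NROWS*NCOLS], Am[NROWS*NCOLS];
      memcpy(Ap, A, sizeof Ap);
      memcpy(Am, A, sizeof Am);
      Ap[i + NROWS*j] += h;
      Am[i + NROWS*j] -= h;
      num = (f_of(Ap, x) - f_of(Am, x))/(2.0*h);
      e = abs_diff(num, y[i]*x[j]);           /* ((df/dy) x^T)_{ij} */
      if (e > err) err = e;
    }
  }
  return err;
}
```

Because this $f$ is quadratic in both $\mathbf{x}$ and $\mathbf{A}$, the central differences agree with the analytic gradients up to rounding error.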

The loss function is such a scalar function. We can compute all necessary gradients by working backward, recursively applying the chain rule (the same order as for reverse-mode automatic differentiation):

$$\begin{aligned} \frac{\partial f}{\partial z_L} & = 2(z_L - y) \\ \frac{\partial f}{\partial\mathbf{z}_{l-1}^{+}} & =\mathbf{W}_{l}^{\top}\frac{\partial f}{\partial\mathbf{z}_{l}}\text{ for }l=2,\ldots,L\\ \frac{\partial f}{\partial\mathbf{z}_{l-1}} & =\frac{\partial f}{\partial\mathbf{z}_{l-1}^{+}}\odot H(\mathbf{z}_{l-1}^{+})\text{ for }l=2,\ldots,L\\ \frac{\partial f}{\partial\mathbf{W}_{l}} & =\frac{\partial f}{\partial\mathbf{z}_{l}}\mathbf{z}_{l-1}^{+\top}\text{ for }l=2,\ldots,L\\ \frac{\partial f}{\partial\mathbf{W}_{1}} & =\frac{\partial f}{\partial\mathbf{z}_{1}}\mathbf{x}^{\top}\\ \frac{\partial f}{\partial\mathbf{b}_{l}} & =\frac{\partial f}{\partial\mathbf{z}_{l}}\text{ for }l=1,\ldots,L, \end{aligned}$$

where $H$ denotes the Heaviside step function:

$$\frac{\mathrm{d}}{\mathrm{d}x}x^{+} = H(x):=\begin{cases} 1, & x>0\\ 0, & x\leq0 \end{cases}.$$

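Notice that we need to preserve the values of $\mathbf{W}_l$, $\mathbf{b}_l$ and $\mathbf{z}_l^+$ from the forward pass in order to perform these calculations. We can discard the $\mathbf{z}_l$, however, because $H(\mathbf{z}_l^{+}) = H(\mathbf{z}_l)$, which simplifies the implementation somewhat. In practice we use mini-batches of data. Multiple data points are stacked as the columns of a matrix $\mathbf{X}$, and the corresponding observations as the elements of a vector $\mathbf{y}$. The computations are analogous: matrix-vector products become matrix-matrix products, and the parameter gradients are averaged over the data points of the mini-batch, since the loss is a mean.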
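The backward pass for a small network can be written out equation by equation and checked against finite differences. The sketch assumes $L = 2$ with a scalar output and a single data point rather than a mini-batch. It stores the parameters in a flat array whose layout (OFF_W1, OFF_B1, OFF_W2, OFF_B2) is an illustrative choice, not the one used in src/model.c. Note that backward() uses only $\mathbf{z}_1^+$ from the forward pass, never $\mathbf{z}_1$.

```c
enum { NIN = 2, NHID = 2 };  /* input features and hidden units (L = 2) */

/* Assumed flat layout of all parameters in one array theta:
 * W1 (NHID by NIN, column-major), then b1, then W2 (1 by NHID), then b2. */
enum {
  OFF_W1 = 0,
  OFF_B1 = OFF_W1 + NHID*NIN,
  OFF_W2 = OFF_B1 + NHID,
  OFF_B2 = OFF_W2 + NHID,
  NPARAM = OFF_B2 + 1
};

/* Forward pass: stores z1^+ in z1p and returns the scalar prediction z2. */
static double forward(const double* theta, const double x[NIN], double z1p[NHID]) {
  double z2 = theta[OFF_B2];
  for (int i = 0; i < NHID; ++i) {
    double z = theta[OFF_B1 + i];
    for (int j = 0; j < NIN; ++j) z += theta[OFF_W1 + i + NHID*j]*x[j];
    z1p[i] = z > 0.0 ? z : 0.0;                       /* rectification */
    z2 += theta[OFF_W2 + i]*z1p[i];
  }
  return z2;
}

/* Squared error (z2 - y)^2 for one data point. */
static double loss(const double* theta, const double x[NIN], double y) {
  double z1p[NHID];
  double d = forward(theta, x, z1p) - y;
  return d*d;
}

/* Backward pass following the equations above; grad receives df/dtheta. */
static void backward(const double* theta, const double x[NIN], double y,
                     double* grad) {
  double z1p[NHID];
  double dz2 = 2.0*(forward(theta, x, z1p) - y);      /* df/dz_L */
  grad[OFF_B2] = dz2;                                 /* df/db_2 = df/dz_2 */
  for (int i = 0; i < NHID; ++i) {
    grad[OFF_W2 + i] = dz2*z1p[i];                    /* df/dW_2 = df/dz_2 z_1^{+T} */
    double dz1p = theta[OFF_W2 + i]*dz2;              /* df/dz_1^+ = W_2^T df/dz_2 */
    double dz1 = z1p[i] > 0.0 ? dz1p : 0.0;           /* times H(z_1^+) */
    grad[OFF_B1 + i] = dz1;                           /* df/db_1 = df/dz_1 */
    for (int j = 0; j < NIN; ++j) {
      grad[OFF_W1 + i + NHID*j] = dz1*x[j];           /* df/dW_1 = df/dz_1 x^T */
    }
  }
}

/* Largest discrepancy between backward() and central finite differences. */
static double max_grad_error(const double* theta, const double x[NIN],
                             double y, double h) {
  double grad[NPARAM], tp[NPARAM], tm[NPARAM], err = 0.0;
  backward(theta, x, y, grad);
  for (int k = 0; k < NPARAM; ++k) {
    for (int m = 0; m < NPARAM; ++m) tp[m] = tm[m] = theta[m];
    tp[k] += h;
    tm[k] -= h;
    double d = (loss(tp, x, y) - loss(tm, x, y))/(2.0*h) - grad[k];
    if (d < 0.0) d = -d;
    if (d > err) err = d;
  }
  return err;
}
```

Finite differences are only meaningful away from the kink of rectification. The parameters in a check like this should therefore keep every hidden pre-activation clearly positive or clearly negative.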

For more detail on reverse-mode automatic differentiation, which we have just unrolled above, see Matrix Gradients of Scalar Functions.

Orientation (10 minutes)

We will start by obtaining the source code from GitHub.

From within Visual Studio Code, connect to your GPU instance, open a new terminal, and enter:

git clone https://github.com/lawmurray/gpu-course.git
cd gpu-course

On the left Explorer pane in Visual Studio Code, open the gpu-course directory to access the files.

The source code is split into the following files:

Not familiar with C?

C code is separated into header files and source files. The header files declare functions and the source files define them. Alternatively, we can think of the headers as providing the interface and the source files the implementation. Within any header or source file we can #include any number of header files to access additional functionality, analogous to importing packages in other languages.

By convention we use the following file extensions:

  • *.h for header files,
  • *.c for source files,
  • *.cu for source files that contain CUDA language extensions.

The build is controlled by the Makefile, which provides the commands to execute when calling make on the terminal. These commands use the CUDA compiler nvcc to compile function.cu, and your system compiler (probably gcc) for the remaining *.c files.

Not familiar with C?

Header files are not compiled on their own, only source files. A header's contents are compiled as part of each source file that #includes it.

The data is provided in bikeshare.csv. The last column is the standardized label, and all other columns are features.
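A line of such a file can be split into features and label by reading every column as a number. This is a minimal sketch that assumes purely numeric, unquoted fields and that any header line has already been skipped. parse_row() and MAX_COLUMNS are hypothetical, not part of the course code.

```c
#include <stdlib.h>

enum { MAX_COLUMNS = 256 };  /* assumed upper bound on columns per line */

/* Parse one line of comma-separated numbers. All columns except the last are
 * copied into features (at most max_features of them) and the last column
 * into *label. Returns the number of features, or -1 on malformed input. */
static int parse_row(const char* line, float* features, int max_features,
                     float* label) {
  float values[MAX_COLUMNS];
  int n = 0;
  const char* p = line;
  for (;;) {
    char* end;
    if (n == MAX_COLUMNS) return -1;
    values[n++] = strtof(p, &end);
    if (end == p) return -1;   /* no number where one was expected */
    if (*end != ',') break;    /* last field, possibly followed by '\n' */
    p = end + 1;
  }
  if (n < 2 || n - 1 > max_features) return -1;
  for (int i = 0; i < n - 1; ++i) features[i] = values[i];
  *label = values[n - 1];
  return n - 1;
}
```

For example, parse_row("1.5,-2,0.25,0.75\n", features, 8, &label) returns 3 and sets label to 0.75.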

Building and running (10 minutes)

From within the gpu-course directory, build the program with:

make

then run it with:

./main

If either of these commands fails, you may be missing some dependencies. If the code is running successfully, you will see one line per epoch, reporting test loss and elapsed time.

Throughout the following, whenever you are asked to build and run the code again, just rerun those two commands: make to build and ./main to run.

Simple tests (10 minutes)

Let’s verify some of the assertions in the opening lecture. The file src/config.h defines macros (#define NAME value) with various configuration options that we can modify.
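Options like these are read by the preprocessor, so a changed value only takes effect after rebuilding. The following sketch shows one common way such a flag can select code. The typedef name real and the #ifndef guard are illustrative assumptions, and src/config.h may use its macros differently.

```c
/* Hypothetical excerpt in the style of src/config.h. The #ifndef guard is an
 * addition that lets the value be overridden when compiling, e.g. with
 * -DENABLE_SINGLE=0. */
#ifndef ENABLE_SINGLE
#define ENABLE_SINGLE 1
#endif

/* Code elsewhere can then choose the floating point type at compile time. */
#if ENABLE_SINGLE
typedef float real;
#else
typedef double real;
#endif
```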

Open src/config.h for editing in Visual Studio Code.

Kernels

For BLOCK_SIZE, try a few different power-of-two values, e.g. 32, 64, 128, …, 1024. After each change, build and run the program. How does performance vary? Choose the value that gives the best performance and leave BLOCK_SIZE set to it. The existing choice, or something larger, is likely to be optimal.

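When the block size is too small, we do not use the full capacity of the GPU. Each streaming multiprocessor can hold only a limited number of resident blocks, so small blocks leave thread slots idle and execution is slower. Setting the block size above 1024 will not work, as 1024 is the maximum number of threads per block supported (this may vary by hardware and CUDA version).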
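A back-of-the-envelope calculation shows why small blocks hurt. The sketch assumes limits of 2048 resident threads and 32 resident blocks per streaming multiprocessor (SM). These figures are illustrative and match some recent NVIDIA GPUs; the real values for your device can be queried with cudaGetDeviceProperties(). The function name occupancy() is hypothetical.

```c
/* Fraction of a streaming multiprocessor's thread slots that resident blocks
 * can fill, considering only the per-SM limits on resident threads and on
 * resident blocks. Register and shared memory usage, which can lower this
 * further, are ignored. */
static double occupancy(int block_size, int max_threads_per_sm,
                        int max_blocks_per_sm) {
  int blocks = max_threads_per_sm/block_size;   /* limit from thread slots */
  if (blocks > max_blocks_per_sm) {
    blocks = max_blocks_per_sm;                 /* limit from block slots */
  }
  return (double)(blocks*block_size)/max_threads_per_sm;
}
```

With these assumed limits, blocks of 32 threads fill only half of each SM (32 blocks of 32 threads is 1024 of 2048 slots). Any power of two from 64 to 1024 can fill it completely.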

The example code does not do much error checking and may appear to run with BLOCK_SIZE greater than 1024. We will touch on more robust practices in the closing lecture.
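One inexpensive safeguard is to reject invalid values at compile time rather than at run time. The following is a minimal sketch using C11's _Static_assert, assuming the 1024-thread limit given above. The default of 256 is only a placeholder, and the check does not replace checking CUDA error codes after launches.

```c
/* Hypothetical compile-time guard that could sit next to the option in
 * src/config.h; the course code does not include it. It rejects block sizes
 * that can never be launched, but runtime launch failures still need to be
 * detected by checking CUDA error codes. */
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 256
#endif

_Static_assert(BLOCK_SIZE >= 1 && BLOCK_SIZE <= 1024,
               "BLOCK_SIZE must be between 1 and 1024 threads per block");
```

Because the limit is hard-coded, this guard would need updating if a future device supported larger blocks.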

Streams

Run the code with:

CUDA_LAUNCH_BLOCKING=1 ./main

This command sets the environment variable CUDA_LAUNCH_BLOCKING to a value of 1. CUDA recognizes this and makes all kernel launches blocking, which is to say synchronous rather than asynchronous with respect to the CPU. Each time the CPU enqueues a kernel for the GPU, it will wait for the GPU to complete the work before proceeding.

The execution time should increase as a result. This highlights the importance of asynchronous execution when working with GPUs. The CPU should be free to enqueue further kernels on a stream while the GPU is still busy with earlier ones.
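A toy timeline model makes the effect concrete. It is a simplification, not a measurement. It assumes that enqueueing a kernel costs the CPU a fixed launch time, that each kernel then runs for a fixed time on the GPU, and that the stream can queue any number of pending kernels. The function names are hypothetical.

```c
/* Toy timeline model, not a measurement: enqueueing one kernel costs the CPU
 * launch time units, the kernel then runs for kernel time units on the GPU,
 * and the stream can queue any number of pending kernels. */

/* Blocking launches: the CPU waits for each kernel before enqueueing the next. */
static double time_blocking(int n, double launch, double kernel) {
  return n*(launch + kernel);
}

/* Asynchronous launches: enqueueing overlaps with execution, so after the
 * first step the slower of the two sides (CPU or GPU) sets the pace. */
static double time_async(int n, double launch, double kernel) {
  double slow = launch > kernel ? launch : kernel;
  double fast = launch > kernel ? kernel : launch;
  return fast + n*slow;
}
```

Suppose a launch costs 5 units and each kernel runs for 20. Then 100 blocking launches take 2500 units, while asynchronous launches take 2005, because only the first launch is exposed.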

Floating point precision

Set ENABLE_SINGLE to 0 to use double precision instead of single precision. Build and run again to see how the execution time increases. How much slower is double precision? Once finished, restore ENABLE_SINGLE to 1.

Depending on your hardware, it could be anywhere from 2 to 64 times slower in double precision than single precision.

You may see a slowdown in line with your hardware's ratio of single to double precision peak FLOPS. This might be interpreted as a good sign: at least the program is compute bound, limited by instruction throughput, rather than memory bound, limited by memory access. A memory-bound program would instead slow down by something closer to a factor of two, as a double occupies twice as many bytes as a float.
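A roofline-style estimate makes this reasoning quantitative. The sketch below is a simplified model. The hardware figures in the example (20 TFLOP/s in single precision, one sixty-fourth of that in double precision, 500 GB/s of memory bandwidth) are hypothetical, as is the function name.

```c
/* Roofline estimate of attainable floating point throughput (FLOP/s): a kernel
 * is limited either by the peak arithmetic rate or by how fast memory can
 * supply its operands. intensity is FLOPs performed per byte of memory traffic. */
static double attainable_flops(double peak_flops, double bandwidth, double intensity) {
  double memory_limit = intensity*bandwidth;
  return memory_limit < peak_flops ? memory_limit : peak_flops;
}
```

Switching from float to double keeps the number of FLOPs but doubles the bytes moved, which halves the intensity. Consider a kernel with 100 FLOP/byte in single precision: attainable_flops(20e12, 500e9, 100.0) gives 20e12, and attainable_flops(20e12/64, 500e9, 50.0) gives 20e12/64. That is a 64-fold slowdown (compute bound). For a kernel with 0.25 FLOP/byte, the two results are 1.25e11 and 6.25e10, only a factor of two (memory bound).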

Cleaning up

To finish up this section, make sure that the only change remaining in src/config.h is your choice of BLOCK_SIZE. Build one last time to make sure it is up to date.

Using git diff will identify any unintentional changes. If you simply want to reset to the original sources, use git checkout . from within the gpu-course directory.

Exploration (20 minutes)

Take some time to browse through the code to better understand how it works.

The remaining material highlights some of the important points that you may discover during your exploration.

Building

The Makefile is worth a brief look. The CUDA compiler is nvcc. It works mostly as a wrapper around the system compiler (probably gcc), but behind the scenes it calls the CUDA toolchain to compile CUDA-specific code.

Floating point literals

Throughout the code, notice that floating point literals are written with the suffix f, e.g. 0.0f and 1.0f instead of 0.0 and 1.0. Literals in C are typed: for floating point literals the f suffix denotes type float, while the absence of a suffix denotes type double.

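In C, when the operands of an arithmetic operation have different types, one is first converted to the type of the other (part of the usual arithmetic conversions). For example, 1.0f + 1.0f is evaluated in float, 1.0 + 1.0 in double, and 1.0f + 1.0 also in double.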

In the final case, the operand of float type is promoted to double type before the addition.

If we are sloppy with literals, we can inadvertently perform computations that we intend to be in single precision in double precision instead. On GPUs in particular this has significant performance implications.
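The types involved can be checked at compile time with C11's _Generic. The sketch below reports the type of an expression and contrasts a sloppy and a careful way of halving a float. TYPE_OF, half_sloppy() and half_careful() are hypothetical names. For host code, GCC's -Wdouble-promotion warning flags implicit float-to-double promotions like the one in half_sloppy().

```c
/* Type of an expression as an integer code, chosen at compile time (C11 _Generic). */
enum { TYPE_OTHER, TYPE_FLOAT, TYPE_DOUBLE };
#define TYPE_OF(e) _Generic((e), float: TYPE_FLOAT, double: TYPE_DOUBLE, default: TYPE_OTHER)

/* Halving a float: the literal 0.5 has type double, so x is promoted and the
 * multiplication is done in double before converting back to float. */
static float half_sloppy(float x) { return x*0.5; }

/* With 0.5f, the multiplication stays in single precision throughout. */
static float half_careful(float x) { return x*0.5f; }
```

Both functions return the same value here. The difference is which arithmetic the compiler emits, and on a GPU that can be the difference between single and double precision throughput.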

Memory allocation

Not familiar with C?

C is not a managed language, so it is the programmer’s responsibility to allocate and free memory as required. In a managed language this is handled by the runtime when, for example, objects and arrays are created and eventually garbage collected.

Note the initialization and finalization idiom. The source files src/data.c, src/model.c and src/optimizer.c define the structs data_t, model_t and optimizer_t. These are initialized by the functions data_init(), model_init() and optimizer_init(), and finalized by the functions data_term(), model_term() and optimizer_term(). The initialization functions allocate memory for all variables in the structs, while the finalization functions free it.

Not familiar with C?

A struct declares a compound type consisting of several variables. These are like classes in other languages, without support for member functions.
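The idiom can be sketched with a hypothetical buffer_t that owns a single array. Every init call is paired with exactly one term call once the struct is no longer needed. The sketch uses malloc() and free(). For memory the GPU must also access, the course code uses CUDA allocation functions instead.

```c
#include <stdlib.h>

/* Hypothetical struct in the style of data_t, model_t and optimizer_t. */
typedef struct {
  float* values;  /* buffer owned by the struct */
  size_t n;       /* number of elements in values */
} buffer_t;

/* Initialization: allocate everything the struct owns.
 * Returns 0 on success and nonzero if allocation fails. */
static int buffer_init(buffer_t* b, size_t n) {
  b->values = malloc(n*sizeof(float));
  b->n = b->values ? n : 0;
  return b->values ? 0 : 1;
}

/* Finalization: free everything buffer_init() allocated. */
static void buffer_term(buffer_t* b) {
  free(b->values);
  b->values = NULL;
  b->n = 0;
}
```

Resetting the pointer to NULL in buffer_term() makes an accidental second call harmless, since free(NULL) does nothing.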

To allocate virtual memory accessible only by the CPU, we use the standard C library functions malloc() and free(). To allocate virtual memory that can be accessed by both the CPU and GPU, we use the CUDA Runtime API functions cudaMallocManaged() and cudaFree(). We have a number of different memory allocation alternatives available:

| Allocate | Free | Description | Accessible by |
|---|---|---|---|
| malloc() | free() | Virtual host memory | Host only |
| cudaMalloc() | cudaFree() | Device memory | Device only |
| cudaMallocAsync() | cudaFreeAsync() | Device memory, allocated and freed asynchronously | Device only |
| cudaMallocHost() | cudaFreeHost() | Pinned host memory | Host and device |
| cudaMallocManaged() | cudaFree() | Managed memory | Host and device, on demand |

The ENABLE_PINNED configuration option has the effect of converting all cudaMallocManaged() calls to cudaMallocHost(), and all cudaFree() calls to cudaFreeHost().

Kernels

The source file src/function.cu provides kernel functions and, for each, an associated wrapper function. For example, rectification is implemented in the kernel kernel_rectify(), which we call via the associated wrapper function rectify(). This pattern makes code that runs on the GPU more convenient to call, as the wrapper function, rather than the caller, establishes the execution configuration.

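The code follows BLAS/LAPACK conventions on the representation of vectors and matrices. A vector is specified by providing the number of elements (nelements in the examples below), a pointer to the first element (x), and the increment between consecutive elements (incx).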

Not familiar with C?

Arrays are zero-based, i.e. for array x the first element is x[0], the second x[1], and so on.

For example, the elements of a vector of length 4 with an increment of 3 appear in memory at x[0], x[3], x[6] and x[9]; the positions in between are skipped.

If we wish to apply a transformation to all elements of a vector, we might write a loop:

for (int i = 0; i < nelements; ++i) {
  x[incx*i] = f(x[incx*i]);
}

We could instead parallelize it by writing a kernel:

__global__ void kernel_f(int nelements, float* x, int incx) {
  // f must be a __device__ function to be callable here
  int i = blockIdx.x*blockDim.x + threadIdx.x;  // global index of this thread
  if (i < nelements) {
    x[incx*i] = f(x[incx*i]);
  }
}

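The __global__ keyword is a CUDA language extension for specifying a kernel. Once launched, the kernel is executed for each thread in a grid of blocks of threads (recall: the execution configuration). To determine which particular thread an execution pertains to, we use the built-in variables blockIdx (the index of the current block within the grid), blockDim (the number of threads in each block) and threadIdx (the index of the current thread within its block).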

Notice the guard if (i < nelements) { ... }. The number of elements need not be a multiple of the block size, in which case the last block contains threads with no corresponding element. A conditional like this ensures that each thread only does work if it actually has an element to process.

We can launch the kernel on the GPU with:

dim3 block(256); // ❶
dim3 grid((nelements + block.x - 1)/block.x); // ❷
kernel_f<<<grid,block>>>(nelements, x, incx); // ❸

Recall from the opening lecture that we launch a kernel using an execution configuration that specifies the size of a grid of blocks of threads. Line ❶ configures the block size to 256 threads (setting the block variable). Line ❷ configures the grid size to a sufficient number of blocks for one thread per element (setting the grid variable). Line ❸ launches the kernel, passing grid and block in triple angle brackets after the name of the kernel—this is a CUDA language extension for specifying execution configurations.

Not familiar with C?

The code (nelements + block.x - 1)/block.x above is a convenient way to write the integer division “nelements divided by block.x, rounding up”.
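To see how the index calculation, the guard and the rounded-up grid size fit together, the launch can be emulated sequentially on the CPU. This is purely illustrative. The function names are hypothetical, f is an arbitrary stand-in, and the nested loops replace what the GPU does in parallel.

```c
/* Stand-in for the element transformation f; doubling is an arbitrary choice. */
static float f(float v) { return 2.0f*v; }

/* What one thread of kernel_f does. Returns 1 if the thread had an element. */
static int kernel_f_thread(int block_idx, int block_dim, int thread_idx,
                           int nelements, float* x, int incx) {
  int i = block_idx*block_dim + thread_idx;  /* blockIdx.x*blockDim.x + threadIdx.x */
  if (i < nelements) {                       /* the guard */
    x[incx*i] = f(x[incx*i]);
    return 1;
  }
  return 0;
}

/* Sequential stand-in for kernel_f<<<grid,block>>>(nelements, x, incx).
 * Returns the number of threads that found an element to process. */
static int emulate_launch(int nelements, float* x, int incx, int block_size) {
  int grid_size = (nelements + block_size - 1)/block_size;  /* round up */
  int active = 0;
  for (int b = 0; b < grid_size; ++b) {      /* on a GPU, these iterations */
    for (int t = 0; t < block_size; ++t) {   /* would all run in parallel */
      active += kernel_f_thread(b, block_size, t, nelements, x, incx);
    }
  }
  return active;
}
```

With nelements = 1000 and a block size of 256, the grid has 4 blocks and 1024 threads. Of these, 1000 pass the guard and the other 24 do nothing.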

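A matrix is stored column-major, and is specified by providing the number of rows (nrows), the number of columns (ncols), a pointer to the first element (A), and the lead (ldA). The lead is the distance in memory between the starts of consecutive columns, and is at least nrows.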

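For example, consider a 3 by 4 matrix (3 rows, 4 columns) with a lead of 4 and column-major layout. Its elements appear in memory at A[0] to A[2], A[4] to A[6], A[8] to A[10] and A[12] to A[14]. Each column holds 3 consecutive elements, and the next column starts 4 positions after the previous one.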

If we wish to apply a transformation to all elements of a matrix, we might write a loop:

for (int j = 0; j < ncols; ++j) {    // each column
  for (int i = 0; i < nrows; ++i) {  // each row, contiguous in memory
    A[i + ldA*j] = f(A[i + ldA*j]);
  }
}

We could instead parallelize it by writing a kernel:

__global__ void transform_f(int nrows, int ncols, float* A, int ldA) {
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  int j = blockIdx.y*blockDim.y + threadIdx.y;
  if (i < nrows && j < ncols) {
    A[i + ldA*j] = f(A[i + ldA*j]);
  }
}

and launch it with:

dim3 block(256, 1);
dim3 grid((nrows + block.x - 1)/block.x, ncols);
transform_f<<<grid,block>>>(nrows, ncols, A, ldA);

Note the similar conventions to the vector case, but we are now working with a two-dimensional execution configuration, and so use the additional y members of blockIdx, threadIdx and blockDim. A further z member is available to support three-dimensional execution configurations.
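The same CPU emulation approach extends to two dimensions. As before, this is an illustration with a hypothetical function name and an arbitrary stand-in for f. Applied to a 3 by 4 matrix with a lead of 4, it leaves the fourth position of each column untouched as padding.

```c
/* Stand-in for the element transformation f; adding one is an arbitrary choice. */
static float f(float v) { return v + 1.0f; }

/* Sequential stand-in for transform_f<<<grid,block>>>(nrows, ncols, A, ldA)
 * with block = (block_x, 1) and grid = (ceil(nrows/block_x), ncols).
 * Returns the number of threads that pass the guard. */
static int emulate_launch_2d(int nrows, int ncols, float* A, int ldA, int block_x) {
  const int block_y = 1;
  const int grid_x = (nrows + block_x - 1)/block_x;
  const int grid_y = ncols;
  int active = 0;
  for (int by = 0; by < grid_y; ++by) {
    for (int bx = 0; bx < grid_x; ++bx) {
      for (int ty = 0; ty < block_y; ++ty) {
        for (int tx = 0; tx < block_x; ++tx) {
          int i = bx*block_x + tx;  /* row: blockIdx.x*blockDim.x + threadIdx.x */
          int j = by*block_y + ty;  /* column: blockIdx.y*blockDim.y + threadIdx.y */
          if (i < nrows && j < ncols) {
            A[i + ldA*j] = f(A[i + ldA*j]);
            ++active;
          }
        }
      }
    }
  }
  return active;
}
```

With block_x = 2, each column gets two blocks of two threads. One of the four threads per column fails the row check, so 12 of the 16 threads do work.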

Model forward and backward passes

With respect to the model, the most important functions are model_forward() and model_backward() in src/model.c, which implement the forward and backward passes, respectively. The forward pass takes a mini-batch of features and outputs a prediction. The backward pass takes the prediction and computes the gradient of the loss with respect to the parameters. The implementation aligns with the mathematical description above.

Summary

In this practical we have learned how to build and run a CUDA program. We have done some simple experiments on floating point precision and execution configurations to better understand performance tradeoffs. Finally, we have had a deep dive into the code to see how memory for the GPU is allocated and freed, and how kernels are specified and launched using specific CUDA language extensions.

Possible extensions

Alternative activation function
The model is implemented using rectification. We could replace this with an alternative activation function, such as a sigmoid, by implementing functions analogous to rectify() and rectify_grad().