
Matrix Multiplication On GPU: Part 2, Tiling

Breaking down large matrix multiplications into tiles

As a follow-up to the previous post, this post details the tiling scheme used by the matrix-matrix multiplication kernel. None of this is particularly elegant; it is mainly an exercise in sizing tiles to fit hardware resources, followed by trial and error at the finer scales to strike a Goldilocks mix of floating-point, integer, and memory load/store instructions that gets the performance just right.

  1. Background
  2. Thread block, global memory
  3. Thread block, shared memory
  4. Warp, shared memory
  5. Thread, registers
  6. Finishing up
  7. Summary

Background

The purpose of a tiling scheme is twofold:

  1. To parallelize by breaking down the original problem into smaller sub-problems that can be solved concurrently (divide and conquer). In terms of available parallelism, a GPU executes many threads simultaneously (thousands or tens of thousands), organized into teams of 32 called warps, and further into teams of (for this kernel) 8 warps, or 256 threads, called thread blocks. The breakdown into smaller tiles mirrors this thread hierarchy.

  2. To make best use of the memory hierarchy of the GPU by fitting those smaller sub-problems into faster memory types. The test device is an Nvidia GeForce RTX 4080 Laptop GPU, with the following memory types, from slowest and largest to fastest and smallest:

    | Type | Size | Speed (latency) |
    |---|---|---|
    | Global memory | 12 GB | Hundreds of cycles |
    | L2 cache | 48 MB | Tens of cycles |
    | Shared memory and/or L1 cache | 128 KB per streaming multiprocessor | A few cycles |
    | Registers | Up to 255 32-bit registers per thread | One cycle |

As an example to visualize, consider a matrix multiplication $C \leftarrow AB$ where $C$ is a 1024x1024 matrix, $A$ a 1024x512 matrix, and $B$ a 512x1024 matrix, i.e. the dimensions of the problem are $M = 1024$, $N = 1024$, and $K = 512$. Figure 1 depicts the first two levels of tiling breakdown.


Figure 1: Breakdown of the matrix multiplication problem into smaller sub-problems. ■Grey represents the original matrices. ■Blue represents the first breakdown into tiles for thread blocks, with output size 256x128. ■Red represents the second breakdown into tiles for warps, with output size 64x64. Each of these is nested in the previous. The arrows and fine gridlines represent a further breakdown of the input tiles into 256x8 slices of $A$ and 8x128 slices of $B$, whose pairwise products are summed to compute the output tile.
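
For the example dimensions, the tile counts in Figure 1 follow from simple division. A minimal C++ sketch, assuming $M$, $N$ and $K$ are exact multiples of the tile sizes (edge handling is not covered here); the constant and function names are illustrative, not taken from the kernel:

```cpp
#include <cassert>

// Tile sizes at each level, as described in the post.
constexpr int kBlockM = 256;  // rows of C per thread block
constexpr int kBlockN = 128;  // columns of C per thread block
constexpr int kSliceK = 8;    // depth of each A/B slice pair in shared memory
constexpr int kWarpM = 64;    // rows of C per warp
constexpr int kWarpN = 64;    // columns of C per warp

// Illustrative summary of how one problem breaks down into tiles.
struct TileCounts {
    int blocks_m;         // thread-block tiles down the rows of C
    int blocks_n;         // thread-block tiles across the columns of C
    int k_slices;         // 256x8 / 8x128 slice pairs summed per output tile
    int warps_per_block;  // 64x64 warp tiles in one 256x128 block tile
};

// Assumes M, N and K are exact multiples of the tile sizes (no edge handling).
TileCounts count_tiles(int M, int N, int K) {
    assert(M % kBlockM == 0 && N % kBlockN == 0 && K % kSliceK == 0);
    return {M / kBlockM, N / kBlockN, K / kSliceK,
            (kBlockM / kWarpM) * (kBlockN / kWarpN)};
}
```

For $M = N = 1024$ and $K = 512$, `count_tiles` gives a 4x8 grid of 32 thread blocks, each summing 64 slice pairs and split across 8 warps.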

The following sections provide more details on each level of the scheme.

Thread block, global memory

At the first level, the output matrix is divided into output tiles of size 256x128, each assigned to a thread block. To compute its tile, a thread block must multiply 256 rows of matrix $A$ by 128 columns of matrix $B$. This is shown in ■blue in Figure 1.

To improve L2 cache utilization, thread blocks are assigned to output tiles according to a two-dimensional Hilbert curve.

Assigning output tiles to thread blocks with the Hilbert curve

The test device has 58 streaming multiprocessors. For this kernel, each of those can execute one thread block at a time. For simplicity, we can round that to 64 thread blocks at a time, and think about how to assign an output tile to each.

We could go column by column, so a 64x1 grid of 256x128 output tiles. Assuming a large enough problem, each of the 64 thread blocks would read in a different 256 rows of $A$, but all would read the same 128 columns of $B$. That would be reading in $64 \times 256 \times K + 1 \times 128 \times K = 16512K$ elements to compute the 64 output tiles.

We could instead use a square, so an 8x8 grid of 256x128 output tiles. Between them, the thread blocks would then read in eight 256-row slices of $A$ and eight 128-column slices of $B$, for $8 \times 256 \times K + 8 \times 128 \times K = 3072K$ elements. That is, this configuration reads far fewer elements to compute the same amount of $C$, i.e. 64 output tiles. More generally, the count is proportional to the height plus the width of the region of $C$ being computed, so for a fixed number of tiles it is smallest when that region is as close to square as the tile shape allows; with 256x128 tiles, no arrangement of 64 tiles beats the 8x8 grid.
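
The two counts generalize to any grid shape: an $r \times c$ grid of 256x128 tiles reads $r$ distinct 256-row slices of $A$ and $c$ distinct 128-column slices of $B$, i.e. $(256r + 128c)K$ elements. The sketch below (names are illustrative) enumerates every factorization of 64 tiles. With these non-square tiles, the 8x8 grid ties with a 4x16 grid at $3072K$, since both cover a region of $C$ whose height plus width is 3072 elements.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative record: one grid shape and the elements it reads per unit of K.
struct GridCost {
    int rows, cols;         // grid of output tiles (rows x cols)
    std::int64_t elements;  // elements of A and B read, in units of K
};

// For every rows x cols grid with rows * cols == tiles, count the elements
// read per unit of K: each distinct tile row reads tile_m rows of A, and
// each distinct tile column reads tile_n columns of B.
std::vector<GridCost> read_costs(int tiles, int tile_m, int tile_n) {
    assert(tiles > 0);
    std::vector<GridCost> out;
    for (int rows = 1; rows <= tiles; ++rows) {
        if (tiles % rows != 0) continue;
        int cols = tiles / rows;
        out.push_back({rows, cols,
                       std::int64_t(rows) * tile_m + std::int64_t(cols) * tile_n});
    }
    return out;
}
```

Calling `read_costs(64, 256, 128)` returns 16512 for the 64x1 grid and 3072 for the 8x8 grid, matching the arithmetic in the text.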

That motivates an 8x8 square grid, but not necessarily a Hilbert curve. The Hilbert curve has the advantage that we do not need to determine the optimal grid size for any particular device: the test device has 58 streaming multiprocessors, but other devices have more or fewer. The recursive structure of the Hilbert curve gives a good configuration at every scale (cf. cache-oblivious algorithms), although not necessarily an optimal one. It also has the advantage that the next grid of 8x8 thread blocks is never far from the current one; in fact, it shares either its rows or its columns with the current grid, so some of the data from those rows or columns may still reside in L2 cache and benefit the next grid, at least for smaller problem sizes.
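
One way to map a linear thread block index to a tile position along a Hilbert curve is the classic iterative `d2xy` conversion, sketched below. It assumes a square grid of output tiles whose side is a power of two; the post does not say how other grid shapes are handled, and `tile_for_block` is a hypothetical helper. Because the curve fills each quadrant before moving to the next, blocks 0 to 63 of a 16x16 grid cover one 8x8 square, blocks 64 to 127 an adjacent one, and so on.

```cpp
#include <cassert>
#include <utility>

// Rotate/flip a sub-square so its piece of the curve has the right orientation.
void hilbert_rot(int s, int& x, int& y, int rx, int ry) {
    if (ry == 0) {
        if (rx == 1) {
            x = s - 1 - x;
            y = s - 1 - y;
        }
        std::swap(x, y);
    }
}

// Convert distance d along a Hilbert curve that fills an n x n grid
// (n a power of two) into grid coordinates (x, y).
void hilbert_d2xy(int n, int d, int& x, int& y) {
    assert(n > 0 && (n & (n - 1)) == 0 && d >= 0 && d < n * n);
    x = y = 0;
    for (int s = 1, t = d; s < n; s *= 2, t /= 4) {
        int rx = 1 & (t / 2);
        int ry = 1 & (t ^ rx);
        hilbert_rot(s, x, y, rx, ry);
        x += s * rx;
        y += s * ry;
    }
}

// Hypothetical helper: the output tile (row, col) computed by thread block
// `block` when the output tiles form a grid_n x grid_n grid. Treating x as
// the tile column and y as the tile row is an arbitrary choice.
void tile_for_block(int grid_n, int block, int& tile_row, int& tile_col) {
    hilbert_d2xy(grid_n, block, tile_col, tile_row);
}
```

Consecutive blocks always land on neighbouring tiles, so each wave of concurrently running blocks covers a compact, roughly square region of $C$.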

Thread block, shared memory

The 256 rows of matrix $A$ are further broken into tiles of size 256x8, and the 128 columns of $B$ into tiles of size 8x128. This is still shown in ■blue in Figure 1, with the finest gridlines. These tiles are copied into shared memory using a four-stage pipeline: at any time, the product of one pair of tiles is being computed while the next three pairs are being copied asynchronously from global to shared memory. This overlaps copy with compute.
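
The ring-buffer bookkeeping behind a four-stage pipeline can be modelled on the host. This is a sketch under assumptions: the post does not spell out the order of waits, barriers and copy issues, and `pipeline_schedule` is a hypothetical name. Before the main loop, the copies of slices 0 to 2 are issued; at step $k$ the product for slice $k$ is computed from stage $k \bmod 4$ while the copy of slice $k+3$ is issued into stage $(k+3) \bmod 4$, the stage freed by step $k-1$. The four stages of 256x8 and 8x128 floats occupy 48 KB of shared memory.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kStages = 4;             // pipeline depth
constexpr int kTileAFloats = 256 * 8;  // one 256x8 tile of A
constexpr int kTileBFloats = 8 * 128;  // one 8x128 tile of B
// Shared memory for all stages: 4 * (2048 + 1024) floats = 49152 bytes.
constexpr std::size_t kSharedBytes =
    kStages * (kTileAFloats + kTileBFloats) * sizeof(float);

// Hypothetical record of what happens at one step of the main loop.
struct Step {
    int compute_slice;  // k-slice multiplied at this step
    int compute_stage;  // shared-memory stage holding that slice
    int load_slice;     // k-slice whose async copy is issued (-1 if none)
    int load_stage;     // stage that copy writes into (-1 if none)
};

// Assumed ordering: slices 0..kStages-2 are issued before the loop. The
// stage written at step k, (k + 3) % 4, is the one read at step k - 1, so
// a barrier between the two steps must ensure all warps have finished
// reading it before the copy overwrites it.
std::vector<Step> pipeline_schedule(int num_slices) {
    assert(num_slices > 0);
    std::vector<Step> steps;
    for (int k = 0; k < num_slices; ++k) {
        int next = k + kStages - 1;
        bool issue = next < num_slices;
        steps.push_back({k, k % kStages, issue ? next : -1,
                         issue ? next % kStages : -1});
    }
    return steps;
}
```

For $K = 512$ there are 64 slices; the last three steps issue no new copies and simply drain the pipeline.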

Warp, shared memory

Once in shared memory, the work of multiplying each 256x8 tile of $A$ by the corresponding 8x128 tile of $B$ is shared among the 8 warps of the thread block. Each warp computes the product of a 64x8 tile and an 8x64 tile. This is shown in ■red in Figure 1.

Synchronization

When using shared memory, threads in a block must synchronize to ensure that writes made into shared memory by those threads become visible to the other threads.

The thread block consists of 8 warps of 32 threads each. Similarly, the 256x128 output tile is broken into 8 subtiles of 64x64 elements each. Each warp then requires one of four 64-row slices of $A$, and one of two 64-column slices of $B$, to compute its 64x64 output tile of $C$.

The copy of these tiles from global to shared memory can be arranged such that two warps copy each 64-row slice of $A$ and four warps copy each 64-column slice of $B$. Each warp must then synchronize with the one other warp contributing to the copy of the 64-row slice that it requires, and with the three other warps contributing to the copy of the 64-column slice that it requires (assuming that it participated in those copies too). That is, it takes a quarter-block (2-warp) and a half-block (4-warp) synchronization, but not a full-block (8-warp) synchronization, to ensure that all the writes are visible where they need to be. Avoiding the full-block barrier can improve performance significantly.
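
The grouping can be checked with a small host-side model. The mapping from warp index to subtile (warp $w$ takes row slice $w / 2$ and column slice $w \bmod 2$) is an assumption, as are the barrier ids: one way to realize partial synchronization in CUDA is PTX named barriers (`bar.sync` with a barrier id and a thread count), leaving id 0 to the full-block `__syncthreads()`. The post does not say which mechanism the kernel uses.

```cpp
#include <cassert>
#include <vector>

constexpr int kWarps = 8;
constexpr int kRowSlices = 4;  // 256 rows of A / 64
constexpr int kColSlices = 2;  // 128 columns of B / 64

// Hypothetical mapping of warp index to its 64x64 subtile.
int row_slice(int warp) { return warp / kColSlices; }  // 0..3
int col_slice(int warp) { return warp % kColSlices; }  // 0..1

// Warps that jointly copy, and then read, the same 64-row slice of A.
std::vector<int> a_group(int warp) {
    std::vector<int> g;
    for (int w = 0; w < kWarps; ++w)
        if (row_slice(w) == row_slice(warp)) g.push_back(w);
    return g;
}

// Warps that jointly copy, and then read, the same 64-column slice of B.
std::vector<int> b_group(int warp) {
    std::vector<int> g;
    for (int w = 0; w < kWarps; ++w)
        if (col_slice(w) == col_slice(warp)) g.push_back(w);
    return g;
}

// Hypothetical named-barrier ids and participating thread counts:
// ids 1..4 for the A groups (2 warps = 64 threads),
// ids 5..6 for the B groups (4 warps = 128 threads).
int a_barrier_id(int warp) { return 1 + row_slice(warp); }
int b_barrier_id(int warp) { return 1 + kRowSlices + col_slice(warp); }
int a_barrier_threads() { return 32 * kWarps / kRowSlices; }
int b_barrier_threads() { return 32 * kWarps / kColSlices; }
```

With this mapping, every warp's A group has 2 members and its B group 4, matching the quarter-block and half-block synchronizations described above.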

Thread, registers

Zooming into the warp-level computation, we have the thread-level computation shown in Figure 2. Recall that there are 32 threads in each warp. Each thread loads two 4x1 tiles of $A$ and one 1x16 tile of $B$ from shared memory to registers, computes the 8x16 (outer) product of those, and accumulates it into registers. It does this eight times, once for each index of the inner dimension ($K$) of the tiles at this level. The rows handled by a thread are not contiguous, as this avoids shared memory bank conflicts.
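
A plain C++ model of one thread's work on a warp's slices may make the register-level step concrete. Several details are assumptions rather than the kernel's actual layout: lanes are arranged 8x4, `lane / 4` selects two 4-row chunks 32 rows apart and `lane % 4` selects a 16-column group, and the warp's 64x8 slice of $A$ and 8x64 slice of $B$ are stored so that each index along $K$ is a contiguous run of 64 values, so a 4x1 chunk of $A$ or a 1x16 row of $B$ is contiguous. The function names are hypothetical.

```cpp
#include <cassert>

constexpr int kWarpTile = 64;  // warp output tile is 64x64
constexpr int kSliceK = 8;     // inner dimension of the shared-memory slices

// Hypothetical lane layout (8 x 4 lanes). Row i in [0, 8) of a lane's 8x16
// accumulator: the first four rows come from one 4-row chunk, the last
// four from a second chunk 32 rows further down.
int thread_row(int lane, int i) {
    int g = lane / 4;  // 0..7
    return i < 4 ? 4 * g + i : 32 + 4 * g + (i - 4);
}

// Column j in [0, 16) of a lane's accumulator: one contiguous 16-column group.
int thread_col(int lane, int j) { return 16 * (lane % 4) + j; }

// One lane's contribution. a holds A(r, k) at a[k * 64 + r] and b holds
// B(k, c) at b[k * 64 + c]. For each k the lane loads two 4x1 chunks of A
// and one 1x16 row of B into local arrays (registers on the GPU) and
// accumulates their 8x16 outer product.
void thread_microkernel(int lane, const float* a, const float* b,
                        float acc[8][16]) {
    assert(lane >= 0 && lane < 32);
    for (int k = 0; k < kSliceK; ++k) {
        float ra[8], rb[16];
        for (int i = 0; i < 8; ++i) ra[i] = a[k * kWarpTile + thread_row(lane, i)];
        for (int j = 0; j < 16; ++j) rb[j] = b[k * kWarpTile + thread_col(lane, j)];
        for (int i = 0; i < 8; ++i)
            for (int j = 0; j < 16; ++j) acc[i][j] += ra[i] * rb[j];
    }
}
```

Across the 32 lanes, the 8x16 accumulators tile the 64x64 warp output exactly once (32 x 128 = 4096 elements).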


Figure 2: Breakdown of the warp-level computation into the thread-level computation. As before, the thread block level is shown in ■blue and the warp level in ■red, but we are now zoomed in enough to see the thread level in ■darker red.

Finishing up

Once complete, each thread has accumulated an 8x16 tile in registers, representing its share of the output matrix $C$. This is written to global memory to complete the computation.

Two other tricks:

  1. Where possible, 128-bit accesses (i.e. four consecutive single-precision floating-point elements) are used when copying from global memory to shared memory, from shared memory to registers, and from registers to global memory. This reduces the number of instructions that must be executed to move the data, but requires that the memory used to store the matrices is 128-bit aligned.

  2. When copying tiles of $B$ from global to shared memory, we take the opportunity to transpose the tile, as this improves performance later when copying from shared memory to registers. The transpose forces 32-bit copies from global to shared memory rather than the preferred 128-bit copies (see above), but is still a net win.
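
The post does not state the storage order of the matrices. Assuming column-major (BLAS-style) storage, here is a minimal sketch of the transposed copy: in global memory the 8 values of one column of the 8x128 tile are contiguous, while in shared memory the tile is stored so that the 128 values for each index along $K$ are contiguous, ready for 1x16 loads into registers. Consecutive source elements land 128 floats apart in the destination, which is why this copy proceeds one 32-bit element at a time. The function name is hypothetical, and the loops stand in for work that 256 threads would share.

```cpp
#include <cassert>
#include <cstddef>

constexpr int kTileK = 8;    // rows of the B tile (inner dimension)
constexpr int kTileN = 128;  // columns of the B tile

// Hypothetical helper. b points at the tile's top-left element in
// column-major global memory with leading dimension ldb, so B(k, n) is
// b[k + n * ldb]. The tile is written to smem with B(k, n) at
// smem[k * kTileN + n].
void copy_b_tile_transposed(const float* b, std::size_t ldb, float* smem) {
    assert(ldb >= kTileK);
    for (int n = 0; n < kTileN; ++n)
        for (int k = 0; k < kTileK; ++k)            // contiguous 32-bit reads...
            smem[k * kTileN + n] = b[k + n * ldb];  // ...strided 32-bit writes
}
```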

In CUDA C++, 128-bit accesses are enabled by using the float4 type. In PTX, the cp.async instruction takes, as one of its arguments, the number of bytes to copy (4, 8 or 16).
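
Two quick checks make the trade-off concrete. The first is the alignment test that 128-bit access depends on: the base pointer must be 16-byte aligned and, so that every row (or column) start stays aligned, the leading dimension must be a multiple of four floats. The second counts copy instructions per thread for one pipeline stage with 256 threads. Function names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// True if p can be used for 128-bit (16-byte) vector access.
bool is_128bit_aligned(const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
}

// True if every row/column start base + i * ld (in floats) is 16-byte aligned.
bool all_rows_128bit_aligned(const float* base, std::size_t ld) {
    return is_128bit_aligned(base) && ld % 4 == 0;
}

// Copy instructions each thread issues to move `floats` elements with
// `threads` threads, `width` floats per instruction (1 = 32-bit, 4 = 128-bit).
int copies_per_thread(int floats, int threads, int width) {
    assert(floats % (threads * width) == 0);
    return floats / (threads * width);
}
```

With these helpers, the 256x8 tile of $A$ takes 2 128-bit copies per thread instead of 8 32-bit ones, while the transposed 8x128 tile of $B$ takes 4 32-bit copies per thread.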

Summary

This post presented the tiling scheme for the CUDA kernel for matrix-matrix multiplication introduced in the previous post. The breakdown is aimed at creating a hierarchy of parallelization that matches the hierarchy of the GPU hardware with respect to processors and memory types.