
Ep7: GPU and Hardware Acceleration, Part 1#

Install packages#

For this course, we will use some ongoing development in tvm, which is an open-source machine learning compilation framework. We provide the following command to install a packaged version for mlc course. This particular notebook depends on a CUDA11 environment.

!python3 -m pip install mlc-ai-nightly-cu110 -f https://mlc.ai/wheels
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly-cu110
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly_cu110-0.9.dev1956%2Bge3f218d71-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (237.1 MB)
     |████████████████████████████████| 237.1 MB 11 kB/s 
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (1.7.3)
Requirement already satisfied: Pygments in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (2.6.1)
Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (22.1.0)
Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (4.4.2)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (1.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (1.21.6)
Requirement already satisfied: tornado in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (5.1.1)
Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly-cu110) (5.4.8)
Collecting synr==0.6.0
  Downloading synr-0.6.0-py3-none-any.whl (18 kB)
Installing collected packages: synr, mlc-ai-nightly-cu110
Successfully installed mlc-ai-nightly-cu110-0.9.dev1956+ge3f218d71 synr-0.6.0
!nvidia-smi
Mon Aug 29 20:06:46 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:03:00.0  On |                  N/A |
| 16%   38C    P8    14W / 250W |      3MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:04:00.0 Off |                  N/A |
| 16%   32C    P8    20W / 250W |      3MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Prelude#

In the previous chapter, we discussed MLC flows in CPU environments. This chapter discusses how to bring some of these optimizations to the GPU. We will use CUDA terminology; however, the same set of concepts applies to other kinds of GPUs as well.

Preparations#

To begin with, let us import the necessary dependencies.

# This is needed for deferring annotation parsing in TVMScript
# (the __future__ import must come before the other imports)
from __future__ import annotations

import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T, relax as R
from tvm import relax
import numpy as np

GPU Architecture#

Let us begin by reviewing what a GPU architecture looks like. A typical GPU contains a collection of streaming multiprocessors, and each multiprocessor has many cores. A GPU device is massively parallel and allows us to execute many tasks concurrently.

(Figure: GPU architecture, with streaming multiprocessors containing many cores)

To program a GPU, we need to create a set of thread blocks, with each thread mapping to a core and each thread block mapping to a streaming multiprocessor.

(Figure: mapping of thread blocks to streaming multiprocessors and of threads to cores)
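To build intuition, the following plain Python sketch (an illustration only, not code generated or executed by TVM) emulates this execution model: every (blockIdx.x, threadIdx.x) pair runs the same kernel body on its own element, and a real GPU runs all of these invocations in parallel.

num_blocks, num_threads = 8, 128   # grid and block sizes used in the vector add example below

def vecadd_kernel(a, b, c, block_idx, thread_idx):
    # each "thread" computes one element, identified by its (block, thread) indices
    i = block_idx * num_threads + thread_idx
    c[i] = a[i] + b[i]

a = np.random.uniform(size=(1024,)).astype("float32")
b = np.random.uniform(size=(1024,)).astype("float32")
c = np.empty(1024, dtype="float32")
# a real GPU launches all of these invocations in parallel; here we emulate them sequentially
for bx in range(num_blocks):
    for tx in range(num_threads):
        vecadd_kernel(a, b, c, bx, tx)
np.testing.assert_allclose(c, a + b)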

Let us start GPU programming with a vector add example. The following TensorIR program takes two vectors, A and B, performs an element-wise add, and stores the result in C.

@tvm.script.ir_module
class MyModuleVecAdd:
    @T.prim_func
    def main(A: T.Buffer[(1024,), "float32"], 
             B: T.Buffer[(1024,), "float32"], 
             C: T.Buffer[(1024,), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.grid(1024):
            with T.block("C"):
                vi = T.axis.remap("S", [i])
                C[vi] = A[vi] + B[vi]

We first split loop i into two loops.

sch = tvm.tir.Schedule(MyModuleVecAdd)
block_C = sch.get_block("C")
i, = sch.get_loops(block=block_C)
i0, i1 = sch.split(i, [None, 128])
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"], C: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i_0, i_1 in T.grid(8, 128):
            with T.block("C"):
                vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                T.reads(A[vi], B[vi])
                T.writes(C[vi])
                C[vi] = A[vi] + B[vi]
    

GPU Thread Blocks#

Then we bind the iterators to the GPU thread blocks. Each thread is parameterized by two indices: threadIdx.x and blockIdx.x. In practice, we can have multi-dimensional thread indices, but here we keep them one-dimensional for simplicity.

(Figure: binding loop iterations to blockIdx.x and threadIdx.x)

sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"], C: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A[vi], B[vi])
                    T.writes(C[vi])
                    C[vi] = A[vi] + B[vi]
    

Build and Run the TensorIR Function on GPU#

We can build and test out the resulting function on the GPU.

rt_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(size=(1024,)).astype("float32")
B_np = np.random.uniform(size=(1024,)).astype("float32")
A_nd = tvm.nd.array(A_np, tvm.cuda(0))
B_nd = tvm.nd.array(B_np, tvm.cuda(0))
C_nd = tvm.nd.array(np.zeros((1024,), dtype="float32"), tvm.cuda(0))

rt_mod["main"](A_nd, B_nd, C_nd)
print(A_nd)
print(B_nd)
print(C_nd)
np.testing.assert_allclose(C_nd.numpy(), A_np + B_np)

Window Sum Example#

Now, let us move on to another example: window sum. This program can be viewed as a basic version of "convolution" with a predefined weight [1, 1, 1]. We slide a window over the input and add three neighboring values together.

(Figure: window sum sliding over the input vector)
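For intuition, a plain numpy reference of the same computation (not part of the TVM flow) might look as follows; note that it matches a "valid" convolution with the weight [1, 1, 1].

a = np.random.uniform(size=(1027,)).astype("float32")
b_ref = a[0:1024] + a[1:1025] + a[2:1026]
np.testing.assert_allclose(
    b_ref, np.convolve(a, [1.0, 1.0, 1.0], mode="valid")[:1024], rtol=1e-5)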

@tvm.script.ir_module
class MyModuleWindowSum:
    @T.prim_func
    def main(A: T.Buffer[(1027,), "float32"], 
             B: T.Buffer[(1024,), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.grid(1024):
            with T.block("C"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + A[vi + 1] + A[vi + 2]

First, we can bind the loop to GPU threads.

sch = tvm.tir.Schedule(MyModuleWindowSum)
nthread = 128
block_C = sch.get_block("C")
i,  = sch.get_loops(block=block_C)
i0, i1 = sch.split(i, [None, nthread])
sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1027, "float32"], B: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A[vi : vi + 3])
                    T.writes(B[vi])
                    B[vi] = A[vi] + A[vi + 1] + A[vi + 2]
    


Importantly, there are reuse opportunities in this case: each element of A is read by up to three neighboring threads. Remember that each GPU thread block has shared memory that all threads within the block can access. We use cache_read to add an intermediate stage that caches segments (in green below) in shared memory. After the caching is finished, the threads can read from the shared memory.

(Figure: segments of A cached into shared memory, shown in green)

A_shared = sch.cache_read(block_C, read_buffer_index=0, storage_scope="shared")
sch.compute_at(A_shared, i1)
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1027, "float32"], B: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        A_shared = T.alloc_buffer([1027], dtype="float32", scope="shared")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                for ax0 in T.serial(130):
                    with T.block("A_shared"):
                        v0 = T.axis.spatial(1027, i_0 * 128 + ax0)
                        T.reads(A[v0])
                        T.writes(A_shared[v0])
                        A_shared[v0] = A[v0]
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A_shared[vi : vi + 3])
                    T.writes(B[vi])
                    B[vi] = A_shared[vi] + A_shared[vi + 1] + A_shared[vi + 2]
    

Because the memory is shared across threads in the block, we need to re-split the loop and bind the inner iterator of the fetching process onto the thread indices. This technique is called cooperative fetching, where multiple threads work together to bring data into shared memory. Since 130 elements must be loaded by 128 threads, each thread loads up to two elements, and a guard (T.where) masks the out-of-range iterations; the loading pattern can thus differ from the access pattern of the computation that follows.

ax = sch.get_loops(A_shared)[-1]
ax0, ax1 = sch.split(ax, [None, nthread])
sch.bind(ax1, "threadIdx.x")
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1027, "float32"], B: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        A_shared = T.alloc_buffer([1027], dtype="float32", scope="shared")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                for ax0_0 in T.serial(2):
                    for ax0_1 in T.thread_binding(128, thread="threadIdx.x"):
                        with T.block("A_shared"):
                            v0 = T.axis.spatial(1027, i_0 * 128 + (ax0_0 * 128 + ax0_1))
                            T.where(ax0_0 * 128 + ax0_1 < 130)
                            T.reads(A[v0])
                            T.writes(A_shared[v0])
                            A_shared[v0] = A[v0]
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A_shared[vi : vi + 3])
                    T.writes(B[vi])
                    B[vi] = A_shared[vi] + A_shared[vi + 1] + A_shared[vi + 2]
    

We can inspect the corresponding low-level code (in CUDA). The generated code contains two parts:

  • A host part that calls into the GPU driver

  • A CUDA kernel that runs the corresponding computation.

We can print out the CUDA kernel using the following code. We still need both the host and the kernel code to run the program, so this is only a quick way to inspect the final code generation result.

Notably, the build process automatically compacts the shared memory stage to use only the minimum region needed within the thread block.

rt_mod = tvm.build(sch.mod, target="cuda")
print(rt_mod.imported_modules[0].get_source())
#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(128) main_kernel0(float* __restrict__ A, float* __restrict__ B) {
  __shared__ float A_shared[130];
  for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) {
    if (((ax0_0 * 64) + (((int)threadIdx.x) >> 1)) < 65) {
      A_shared[((ax0_0 * 128) + ((int)threadIdx.x))] = A[(((((int)blockIdx.x) * 128) + (ax0_0 * 128)) + ((int)threadIdx.x))];
    }
  }
  __syncthreads();
  B[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((A_shared[((int)threadIdx.x)] + A_shared[(((int)threadIdx.x) + 1)]) + A_shared[(((int)threadIdx.x) + 2)]);
}
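Assuming a CUDA-enabled GPU is available (the printed kernel alone cannot run without the host part), we can sanity-check the built module against a numpy reference:

a_np = np.random.uniform(size=(1027,)).astype("float32")
a_nd = tvm.nd.array(a_np, tvm.cuda(0))
b_nd = tvm.nd.array(np.zeros((1024,), dtype="float32"), tvm.cuda(0))
rt_mod["main"](a_nd, b_nd)
np.testing.assert_allclose(b_nd.numpy(), a_np[:1024] + a_np[1:1025] + a_np[2:1026], rtol=1e-5)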

Build Code for Other GPU Platforms#

An MLC process usually supports targeting multiple kinds of hardware platforms. For example, we can generate Metal code (another GPU programming model) simply by changing the target parameter.

rt_mod = tvm.build(sch.mod, target="metal")
print(rt_mod.imported_modules[0].get_source())
// Function: main_kernel0
#include <metal_stdlib>
using namespace metal;

union __TVMArgUnion {
 int v_int[2];
};

kernel void main_kernel0(  device float* A [[ buffer(0) ]],
  device float* B [[ buffer(1) ]],
  uint blockIdx [[threadgroup_position_in_grid]],
  uint threadIdx [[thread_position_in_threadgroup]]
) {
  threadgroup float A_shared[130];
  for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) {
    if (((ax0_0 * 64) + (((int)threadIdx) >> 1)) < 65) {
      A_shared[((ax0_0 * 128) + ((int)threadIdx))] = A[(((((int)blockIdx) * 128) + (ax0_0 * 128)) + ((int)threadIdx))];
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  B[((((int)blockIdx) * 128) + ((int)threadIdx))] = ((A_shared[((int)threadIdx)] + A_shared[(((int)threadIdx) + 1)]) + A_shared[(((int)threadIdx) + 2)]);
}
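Other GPU backends follow the same pattern, provided the TVM build has the corresponding codegen enabled (which may not be the case for every prebuilt wheel). For example, to generate OpenCL code:

# assumes TVM was built with the OpenCL backend enabled
rt_mod_opencl = tvm.build(sch.mod, target="opencl")
print(rt_mod_opencl.imported_modules[0].get_source())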

Matrix Multiplication#

Let us now move on to something slightly more complicated and try optimizing matrix multiplication on the GPU. We will go over two common techniques for GPU performance optimization.

@tvm.script.ir_module
class MyModuleMatmul:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], 
             B: T.Buffer[(1024, 1024), "float32"], 
             C: T.Buffer[(1024, 1024), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Local Blocking#

(Figure: local blocking of the matrix multiplication)

To increase overall memory reuse, we can tile the loops. In particular, we introduce local tiles so that we only need to load each stripe of data from A and B once, and then use it to compute a V * V block of the matrix multiplication result.

This local tiling helps reduce memory pressure, as each element in the stripe is reused V times.
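As a reference for the idea (a plain numpy sketch, not the code TVM generates), each GPU thread owns one V x V tile of C kept in local (register) storage and accumulates it while streaming over K-wide stripes of A and B:

def blocked_matmul_reference(A, B, V=8, K=4):
    n = A.shape[0]
    C = np.zeros((n, n), dtype=A.dtype)
    for i0 in range(0, n, V):            # in the GPU schedule, each (i0, j0) tile
        for j0 in range(0, n, V):        # is computed by one thread
            c_local = np.zeros((V, V), dtype=A.dtype)   # per-thread "local" accumulator
            for k0 in range(0, n, K):                   # stream over the reduction axis
                c_local += A[i0:i0+V, k0:k0+K] @ B[k0:k0+K, j0:j0+V]
            C[i0:i0+V, j0:j0+V] = c_local
    return C

a = np.random.uniform(size=(64, 64)).astype("float32")
b = np.random.uniform(size=(64, 64)).astype("float32")
np.testing.assert_allclose(blocked_matmul_reference(a, b), a @ b, rtol=1e-4)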

def blocking(sch, 
             tile_local_y, 
             tile_local_x, 
             tile_block_y, 
             tile_block_x,
             tile_k):
    block_C = sch.get_block("C")
    C_local = sch.cache_write(block_C, 0, "local")

    i, j, k = sch.get_loops(block=block_C)

    i0, i1, i2 = sch.split(loop=i, factors=[None, tile_block_y, tile_local_y])
    j0, j1, j2 = sch.split(loop=j, factors=[None, tile_block_x, tile_local_x])
    k0, k1 = sch.split(loop=k, factors=[None, tile_k])
    sch.unroll(k1)
    sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2)
    sch.reverse_compute_at(C_local, j1)

    sch.bind(i0, "blockIdx.y")
    sch.bind(j0, "blockIdx.x")

    sch.bind(i1, "threadIdx.y")
    sch.bind(j1, "threadIdx.x")
    sch.decompose_reduction(block_C, k0)

    return sch

sch = tvm.tir.Schedule(MyModuleMatmul)
sch = blocking(sch, 8, 8, 8, 8, 4)
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        C_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local")
        for i_0 in T.thread_binding(16, thread="blockIdx.y"):
            for j_0 in T.thread_binding(16, thread="blockIdx.x"):
                for i_1 in T.thread_binding(8, thread="threadIdx.y"):
                    for j_1 in T.thread_binding(8, thread="threadIdx.x"):
                        for i_2_init, j_2_init in T.grid(8, 8):
                            with T.block("C_init"):
                                vi = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + i_2_init)
                                vj = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + j_2_init)
                                T.reads()
                                T.writes(C_local[vi, vj])
                                C_local[vi, vj] = T.float32(0)
                        for k_0 in T.serial(256):
                            for k_1 in T.unroll(4):
                                for i_2, j_2 in T.grid(8, 8):
                                    with T.block("C_update"):
                                        vi = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + i_2)
                                        vj = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + j_2)
                                        vk = T.axis.reduce(1024, k_0 * 4 + k_1)
                                        T.reads(C_local[vi, vj], A[vi, vk], B[vk, vj])
                                        T.writes(C_local[vi, vj])
                                        C_local[vi, vj] = C_local[vi, vj] + A[vi, vk] * B[vk, vj]
                        for ax0, ax1 in T.grid(8, 8):
                            with T.block("C_local"):
                                v0 = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + ax0)
                                v1 = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + ax1)
                                T.reads(C_local[v0, v1])
                                T.writes(C[v0, v1])
                                C[v0, v1] = C_local[v0, v1]
    
rt_mod = tvm.build(sch.mod, target="cuda")
dev = tvm.cuda(0)
A_np = np.random.uniform(size=(1024, 1024)).astype("float32")
B_np = np.random.uniform(size=(1024, 1024)).astype("float32")
A_nd = tvm.nd.array(A_np, dev)
B_nd = tvm.nd.array(B_np, dev)
C_nd = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)

num_flop = 2 * 1024 * 1024 * 1024
evaluator = rt_mod.time_evaluator("main", dev, number=10)


print("GEMM-Blocking: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))
GEMM-Blocking: 2378.997450 GFLOPS

Shared Memory Blocking#

(Figure: shared memory blocking of the matrix multiplication)

Our first attempt did not take advantage of the neighboring threads sitting in the same GPU thread block; we can load the data they commonly need into a piece of shared memory.
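As a back-of-the-envelope estimate (an illustration assuming the tile sizes used below: a 64 x 64 output tile per thread block and tile_k = 8), staging data in shared memory lets each element fetched from global memory be reused many times:

BM = BN = 64     # output tile computed by one thread block (tile_block * tile_local)
BK = 8           # reduction chunk staged in shared memory (tile_k)
global_loads = BM * BK + BK * BN     # elements of A_shared and B_shared per k-step
macs = BM * BN * BK                  # multiply-accumulates performed on that data
print(macs / global_loads)           # 32.0: each loaded element is reused ~32 times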

The following transformation does that.

def cache_read_and_coop_fetch(sch, block, nthread, read_idx, read_loc):
    read_cache = sch.cache_read(block=block, read_buffer_index=read_idx, storage_scope="shared")
    sch.compute_at(block=read_cache, loop=read_loc)
    # vectorized cooperative fetch
    inner0, inner1 = sch.get_loops(block=read_cache)[-2:]
    inner = sch.fuse(inner0, inner1)
    _, tx, vec = sch.split(loop=inner, factors=[None, nthread, 4])
    sch.vectorize(vec)
    sch.bind(tx, "threadIdx.x")


def blocking_with_shared(
    sch, 
    tile_local_y, 
    tile_local_x, 
    tile_block_y, 
    tile_block_x,
    tile_k):
    block_C = sch.get_block("C")
    C_local = sch.cache_write(block_C, 0, "local")

    i, j, k = sch.get_loops(block=block_C)

    i0, i1, i2 = sch.split(loop=i, factors=[None, tile_block_y, tile_local_y])
    j0, j1, j2 = sch.split(loop=j, factors=[None, tile_block_x, tile_local_x])
    k0, k1 = sch.split(loop=k, factors=[None, tile_k])

    sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2)
    sch.reverse_compute_at(C_local, j1)

    sch.bind(i0, "blockIdx.y")
    sch.bind(j0, "blockIdx.x")

    tx = sch.fuse(i1, j1)
    sch.bind(tx, "threadIdx.x")
    nthread = tile_block_y * tile_block_x
    cache_read_and_coop_fetch(sch, block_C, nthread, 0, k0)
    cache_read_and_coop_fetch(sch, block_C, nthread, 1, k0)    
    sch.decompose_reduction(block_C, k0)

    return sch

sch = tvm.tir.Schedule(MyModuleMatmul)
sch = blocking_with_shared(sch, 8, 8, 8, 8, 8)
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        C_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local")
        A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
        B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
        for i_0 in T.thread_binding(16, thread="blockIdx.y"):
            for j_0 in T.thread_binding(16, thread="blockIdx.x"):
                for i_1_j_1_fused in T.thread_binding(64, thread="threadIdx.x"):
                    for i_2_init, j_2_init in T.grid(8, 8):
                        with T.block("C_init"):
                            vi = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + i_2_init)
                            vj = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + j_2_init)
                            T.reads()
                            T.writes(C_local[vi, vj])
                            C_local[vi, vj] = T.float32(0)
                    for k_0 in T.serial(128):
                        for ax0_ax1_fused_0 in T.serial(2):
                            for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(4):
                                    with T.block("A_shared"):
                                        v0 = T.axis.spatial(1024, i_0 * 64 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 8)
                                        v1 = T.axis.spatial(1024, k_0 * 8 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 8)
                                        T.reads(A[v0, v1])
                                        T.writes(A_shared[v0, v1])
                                        A_shared[v0, v1] = A[v0, v1]
                        for ax0_ax1_fused_0 in T.serial(2):
                            for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(4):
                                    with T.block("B_shared"):
                                        v0 = T.axis.spatial(1024, k_0 * 8 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 64)
                                        v1 = T.axis.spatial(1024, j_0 * 64 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 64)
                                        T.reads(B[v0, v1])
                                        T.writes(B_shared[v0, v1])
                                        B_shared[v0, v1] = B[v0, v1]
                        for k_1, i_2, j_2 in T.grid(8, 8, 8):
                            with T.block("C_update"):
                                vi = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + i_2)
                                vj = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + j_2)
                                vk = T.axis.reduce(1024, k_0 * 8 + k_1)
                                T.reads(C_local[vi, vj], A_shared[vi, vk], B_shared[vk, vj])
                                T.writes(C_local[vi, vj])
                                C_local[vi, vj] = C_local[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
                    for ax0, ax1 in T.grid(8, 8):
                        with T.block("C_local"):
                            v0 = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + ax0)
                            v1 = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + ax1)
                            T.reads(C_local[v0, v1])
                            T.writes(C[v0, v1])
                            C[v0, v1] = C_local[v0, v1]
    
rt_mod = tvm.build(sch.mod, target="cuda")
dev = tvm.cuda(0)
evaluator = rt_mod.time_evaluator("main", dev, number=10)

print("GEMM-Blocking: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))
GEMM-Blocking: 5672.038389 GFLOPS

Leveraging Automatic Program Optimization#

So far, we have been manually writing transformations to optimize the TensorIR program on the GPU. We can also leverage the automatic program optimization framework to tune the same program. The following code does that; we only set a small number of trials here, and it can take a few minutes to finish.

from tvm import meta_schedule as ms

sch_tuned = ms.tune_tir(
    mod=MyModuleMatmul,
    target="nvidia/tesla-p100",
    config=ms.TuneConfig(
      max_trials_global=64,
      num_trials_per_iter=64,
    ),
    work_dir="./tune_tmp",
    task_name="main"
)
2022-08-06 13:02:40.921 INFO Logging directory: ./tune_tmp/logs
2022-08-06 13:02:40.928 INFO Logging directory: ./tune_tmp/logs
2022-08-06 13:02:40.930 INFO Working directory: ./tune_tmp
2022-08-06 13:02:40.933 INFO Creating JSONDatabase. Workload at: ./tune_tmp/database_workload.json. Tuning records at: ./tune_tmp/database_tuning_record.json
2022-08-06 13:02:40.935 INFO LocalBuilder: max_workers = 1
2022-08-06 13:02:41.536 INFO LocalRunner: max_workers = 1
2022-08-06 13:02:42.113 INFO Initializing Task #0: "main"
2022-08-06 13:02:42.130 INFO 
 ID | Name |       FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated 
---------------------------------------------------------------------------------------------------------------
  0 | main | 2147483648 |      1 |            N/A |          N/A |                   N/A |      0 |            
---------------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0

2022-08-06 13:02:42.133 INFO Scheduler picks Task #0: "main"
2022-08-06 13:03:59.506 INFO Sending 64 sample(s) to builder
sch_tuned.mod.show()
rt_mod = tvm.build(sch_tuned.mod, target="nvidia/tesla-p100")
dev = tvm.cuda(0)
evaluator = rt_mod.time_evaluator("main", dev, number=10)

print("MetaSchedule: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))

Summary#

This chapter studies another axis of MLC: how we can transform our programs for hardware acceleration. The MLC process helps us bridge input models to different GPU programming models and environments. We will visit more hardware specialization topics in the upcoming chapter as well.

  • A typical GPU contains a two-level thread hierarchy. Each thread is indexed by (in CUDA terminology) threadIdx.x and blockIdx.x (there can be multi-dimensional indices as well, but they can be fused into one).

  • Shared memory helps cache data commonly used across the threads within the same block.

  • GPU optimizations typically center on encouraging memory reuse.