{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# How to optimize convolution on GPU\n**Author**: [Haichen Shen](https://homes.cs.washington.edu/~haichen/)\n\nIn this tutorial, we will demonstrate how to write a high-performance\nconvolution implementation in TVM. We use square input tensors and filters\nas an example, and assume the input to the convolution has a large batch size. In this\nexample, we use a different layout to store the data in order to achieve better\ndata locality. The buffer layout is HWCN, which stands for height, width,\nchannel, batch.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparation and Algorithm\n\nWe use a fixed size for the input tensors, with 256 channels and 14 x 14\nspatial dimensions. The batch size is 256. The convolution has 512 filters\nof size 3 x 3. We use a stride of 1 and a padding of 1 for the\nconvolution.
The following code defines the convolution algorithm in TVM.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nimport tvm\nfrom tvm import te\n\n# The sizes of inputs and filters\nbatch = 256\nin_channel = 256\nout_channel = 512\nin_size = 14\nkernel = 3\npad = 1\nstride = 1\n\n# Algorithm\nA = te.placeholder((in_size, in_size, in_channel, batch), name=\"A\")\nW = te.placeholder((kernel, kernel, in_channel, out_channel), name=\"W\")\nout_size = (in_size - kernel + 2 * pad) // stride + 1\n# Pad input\nApad = te.compute(\n    (in_size + 2 * pad, in_size + 2 * pad, in_channel, batch),\n    lambda yy, xx, cc, nn: tvm.tir.if_then_else(\n        tvm.tir.all(yy >= pad, yy - pad < in_size, xx >= pad, xx - pad < in_size),\n        A[yy - pad, xx - pad, cc, nn],\n        tvm.tir.const(0.0, \"float32\"),\n    ),\n    name=\"Apad\",\n)\n# Create reduction variables\nrc = te.reduce_axis((0, in_channel), name=\"rc\")\nry = te.reduce_axis((0, kernel), name=\"ry\")\nrx = te.reduce_axis((0, kernel), name=\"rx\")\n# Compute the convolution\nB = te.compute(\n    (out_size, out_size, out_channel, batch),\n    lambda yy, xx, ff, nn: te.sum(\n        Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff], axis=[ry, rx, rc]\n    ),\n    name=\"B\",\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Memory Hierarchy\n\nWe first specify the memory hierarchy for buffers. The figure below shows the\nGPU memory hierarchy. One important difference from the CPU memory hierarchy is\nthat the GPU provides a cache buffer called shared memory, which is managed by\nprogrammers. Thus, maximizing data reuse in shared memory is\ncritical to achieving high performance in GPU kernels.\n\n\n\nIn this example, we load both Apad and W into buffers AA and WW, which are\nstored in shared memory. These buffers will later be shared by all\nthreads within the same thread block to compute the convolution.
Each thread\nthen loads its own part from the shared buffers into its local registers, AL and\nWL. BL is a local cache of the output B, which is also stored in thread-local\nregisters.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Designate the memory hierarchy\ns = te.create_schedule(B.op)\ns[Apad].compute_inline() # compute Apad inline\nAA = s.cache_read(Apad, \"shared\", [B])\nWW = s.cache_read(W, \"shared\", [B])\nAL = s.cache_read(AA, \"local\", [B])\nWL = s.cache_read(WW, \"local\", [B])\nBL = s.cache_write(B, \"local\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Blocking\n\nThe following code splits the workload into thread blocks and individual\nthreads. We follow the blocking scheme used for matrix multiplication. As shown in the\nfigure below, given a pixel coordinate (y, x), a thread block is responsible\nfor computing a region of block_factor x block_factor (64 x 64) for output\nchannels and batch.
Due to the limit of shared memory space, we only load step\nx block_factor (8 x 64) data from Apad and W each time to buffers in the\nshared memory.\n\n\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# tile consts\ntile = 8\nnum_thread = 8\nblock_factor = tile * num_thread\nstep = 8\nvthread = 2\n\n# Get the GPU thread indices\nblock_x = te.thread_axis(\"blockIdx.x\")\nblock_y = te.thread_axis(\"blockIdx.y\")\nblock_z = te.thread_axis(\"blockIdx.z\")\nthread_x = te.thread_axis((0, num_thread), \"threadIdx.x\")\nthread_y = te.thread_axis((0, num_thread), \"threadIdx.y\")\nthread_xz = te.thread_axis((0, vthread), \"vthread\", name=\"vx\")\nthread_yz = te.thread_axis((0, vthread), \"vthread\", name=\"vy\")\n\n# Split the workloads\nhi, wi, fi, ni = s[B].op.axis\nbz = s[B].fuse(hi, wi)\nby, fi = s[B].split(fi, factor=block_factor)\nbx, ni = s[B].split(ni, factor=block_factor)\n\n# Bind the iteration variables to GPU thread indices\ns[B].bind(bz, block_z)\ns[B].bind(by, block_y)\ns[B].bind(bx, block_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Virtual Thread Split\n\nWe further split the workload from a thread block to individual threads. To\navoid *memory bank conflicts*, we use virtual threads to split the area into 4\nparts, and then tile each part into 8x8 grids.
Therefore, as shown in the figure below,\neach thread computes 4 strided grids, where the size of each grid is 4 x 4.\n\n\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "tyz, fi = s[B].split(fi, nparts=vthread) # virtual thread split\ntxz, ni = s[B].split(ni, nparts=vthread) # virtual thread split\nty, fi = s[B].split(fi, nparts=num_thread)\ntx, ni = s[B].split(ni, nparts=num_thread)\ns[B].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)\n\ns[B].bind(tyz, thread_yz)\ns[B].bind(txz, thread_xz)\ns[B].bind(ty, thread_y)\ns[B].bind(tx, thread_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cooperative Fetching\n\nAs mentioned before, at each time step we need to transfer step x block_factor\ndata from GPU global memory to shared memory. In order to reduce the memory\ntransfer per thread, the following code lets threads in the same thread block\ncooperatively fetch dependent data from global memory.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Schedule BL local write\ns[BL].compute_at(s[B], tx)\nyi, xi, fi, ni = s[BL].op.axis\nry, rx, rc = s[BL].op.reduce_axis\nrco, rci = s[BL].split(rc, factor=step)\ns[BL].reorder(rco, ry, rx, rci, fi, ni)\n\n# Attach computation to iteration variables\ns[AA].compute_at(s[BL], rx)\ns[WW].compute_at(s[BL], rx)\ns[AL].compute_at(s[BL], rci)\ns[WL].compute_at(s[BL], rci)\n\n# Schedule for A's shared memory load\nyi, xi, ci, ni = s[AA].op.axis\nty, ci = s[AA].split(ci, nparts=num_thread)\ntx, ni = s[AA].split(ni, nparts=num_thread)\n_, ni = s[AA].split(ni, factor=4)\ns[AA].reorder(ty, tx, yi, xi, ci, ni)\ns[AA].bind(ty, thread_y)\ns[AA].bind(tx, thread_x)\ns[AA].vectorize(ni) # vectorize memory load\n\n# Schedule for W's shared memory load\nyi, xi, ci, fi = s[WW].op.axis\nty, ci = s[WW].split(ci, nparts=num_thread)\ntx, fi = s[WW].split(fi, nparts=num_thread)\n_, fi =
s[WW].split(fi, factor=4)\ns[WW].reorder(ty, tx, yi, xi, ci, fi)\ns[WW].bind(ty, thread_y)\ns[WW].bind(tx, thread_x)\ns[WW].vectorize(fi) # vectorize memory load" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate CUDA Kernel\n\nFinally we use TVM to generate and compile the CUDA kernel, and evaluate the\nlatency of convolution.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "func = tvm.build(s, [A, W, B], \"cuda\")\ndev = tvm.cuda(0)\na_np = np.random.uniform(size=(in_size, in_size, in_channel, batch)).astype(A.dtype)\nw_np = np.random.uniform(size=(kernel, kernel, in_channel, out_channel)).astype(W.dtype)\na = tvm.nd.array(a_np, dev)\nw = tvm.nd.array(w_np, dev)\nb = tvm.nd.array(np.zeros((out_size, out_size, out_channel, batch), dtype=B.dtype), dev)\nfunc(a, w, b)\nevaluator = func.time_evaluator(func.entry_name, dev, number=1)\nprint(\"Convolution: %f ms\" % (evaluator(a, w, b).mean * 1e3))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 0 }