{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Auto-scheduling a Convolution Layer for GPU\n**Author**: [Lianmin Zheng](https://github.com/merrymercy), [Chengfan Jia](https://github.com/jcf94/)\n\nThis is a tutorial on how to use the auto-scheduler for GPUs.\n\nUnlike the template-based `autotvm`, which relies on\nmanual templates to define the search space, the auto-scheduler does not require any templates.\nUsers only need to write the computation declaration without any schedule commands or templates.\nThe auto-scheduler can automatically generate a large search space and\nfind a good schedule in the space.\n\nWe use a convolution layer as an example in this tutorial.\n\nNote that this tutorial will not run as-is on Windows or recent versions of macOS. To\nget it to run, you will need to wrap the body of this tutorial in an\n`if __name__ == \"__main__\":` block.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\n\nimport numpy as np\nimport tvm\nfrom tvm import te, auto_scheduler, topi\nfrom tvm.topi.testing import conv2d_nchw_python" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the computation\nTo begin with, let us define the computation of a convolution layer.\nThe function should return the list of input/output tensors.\nFrom these tensors, the auto-scheduler can get the whole computational graph.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@auto_scheduler.register_workload\ndef conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):\n data = te.placeholder((N, CI, H, W), name=\"data\")\n kernel = te.placeholder((CO, CI, KH, KW), name=\"kernel\")\n bias = te.placeholder((1, CO, 1, 1), name=\"bias\")\n conv = 
topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype=\"float32\")\n out = topi.nn.relu(conv + bias)\n return [data, kernel, bias, out]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the search task\nWe then create a search task for the last convolution layer in ResNet.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "target = tvm.target.Target(\"cuda\")\n\n# Use the last layer in ResNet-50\nN, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)\ntask = auto_scheduler.SearchTask(\n func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target\n)\n\n# Inspect the computational graph\nprint(\"Computational DAG:\")\nprint(task.compute_dag)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we set parameters for the auto-scheduler. These parameters\nmainly specify how measurements are taken during the search.\n\n* `measure_ctx` launches a separate process for measurement to\n provide isolation. It can protect the main process from GPU crashes\n during measurement and avoid other runtime conflicts.\n* `min_repeat_ms` defines the minimum duration of one \"repeat\" in every measurement.\n This can warm up the GPU, which is necessary to get accurate measurement results.\n Typically, we recommend a value >= 300 ms.\n* `num_measure_trials` is the number of measurement trials we can use during the search.\n We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a\n good value for the search to converge. 
You can run more trials according to your time budget.\n* In addition, we use `RecordToFile` to dump measurement records into a file `conv2d.json`.\n The measurement records can be used to query the best result so far, resume the search,\n and do more analyses later.\n* See `auto_scheduler.TuningOptions` and\n `auto_scheduler.LocalRPCMeasureContext` for more parameters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "log_file = \"conv2d.json\"\nmeasure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)\ntune_option = auto_scheduler.TuningOptions(\n num_measure_trials=10, # change this to 1000 to achieve the best performance\n runner=measure_ctx.runner,\n measure_callbacks=[auto_scheduler.RecordToFile(log_file)],\n verbose=2,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the search\nNow we have all inputs ready. Pretty simple, isn't it?\nWe can kick off the search and let the auto-scheduler do its magic.\nAfter some measurement trials, we can load the best schedule from the log\nfile and apply it.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Run auto-tuning (search)\ntask.tune(tune_option)\n# Apply the best schedule\nsch, args = task.apply_best(log_file)\n\n# Kill the measurement process\ndel measure_ctx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can lower the schedule to see the IR after auto-scheduling.\nThe auto-scheduler correctly performs optimizations including multi-level tiling,\ncooperative fetching, unrolling and operator fusion.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\"Lowered TIR:\")\nprint(tvm.lower(sch, args, simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Check correctness and evaluate performance\nWe build the 
binary and check its correctness and performance.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "func = tvm.build(sch, args, target)\n\n# Check correctness\ndata_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)\nweight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)\nbias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)\nconv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)\nout_np = np.maximum(conv_np + bias_np, 0.0)\n\ndev = tvm.cuda()\ndata_tvm = tvm.nd.array(data_np, device=dev)\nweight_tvm = tvm.nd.array(weight_np, device=dev)\nbias_tvm = tvm.nd.array(bias_np, device=dev)\nout_tvm = tvm.nd.empty(out_np.shape, device=dev)\nfunc(data_tvm, weight_tvm, bias_tvm, out_tvm)\n\n# Check results\nnp.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)\n\n# Evaluate execution time\nevaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)\nprint(\n \"Execution time of this operator: %.3f ms\"\n % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using the record file\nDuring the search, all measurement records are dumped into the record\nfile \"conv2d.json\". 
The measurement records can be used to re-apply search results,\nresume the search, and perform other analyses.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is an example where we load the best schedule from a file and\nprint the equivalent Python schedule API and CUDA source code.\nThey can be used for debugging and learning the behavior of the auto-scheduler.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\"Equivalent python schedule:\")\nprint(task.print_best(log_file, print_mode=\"schedule\"))\n\nprint(\"CUDA source code:\")\nprint(task.print_best(log_file, print_mode=\"cuda\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A more complicated example is to resume the search.\nIn this case, we need to create the search policy and cost model ourselves\nand restore their state from the log file.\nIn the example below, we restore the state and run 5 more trials.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def resume_search(task, log_file):\n print(\"Resume search:\")\n cost_model = auto_scheduler.XGBModel()\n cost_model.update_from_file(log_file)\n search_policy = auto_scheduler.SketchPolicy(\n task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]\n )\n measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)\n tune_option = auto_scheduler.TuningOptions(\n num_measure_trials=5,\n runner=measure_ctx.runner,\n measure_callbacks=[auto_scheduler.RecordToFile(log_file)],\n )\n task.tune(tune_option, search_policy=search_policy)\n\n # Kill the measurement process\n del measure_ctx\n\n\nresume_search(task, log_file)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 0 }