{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 在 Relay 中使用管道执行器\n", "\n", "**原作者**: [Hua Jiang](https://github.com/huajsj)\n", "\n", "这是关于如何在 Relay 中使用“管道执行器”(Pipeline Executor)的简短教程。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/media/pc/data/4tb/lxw/home/lxw/tvm/xinetzone/src\n" ] } ], "source": [ "import env # 加载 TVM\n", "import tvm\n", "from tvm import te\n", "import numpy as np\n", "from tvm.contrib import graph_executor as runtime\n", "from tvm.relay.op.contrib.cutlass import partition_for_cutlass\n", "from tvm import relay\n", "from tvm.relay import testing\n", "import tvm.testing\n", "from tvm.contrib.cutlass import (\n", " has_cutlass,\n", " num_cutlass_partitions,\n", " finalize_modules,\n", " finalize_modules_vm,\n", ")\n", "\n", "img_size = 8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 创建简单网络,可以是预训练的模型\n", "\n", "创建非常简单的网络进行演示。它由卷积、batch normalization、dense 和 ReLU 激活组成。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_network():\n", " out_channels = 16\n", " batch_size = 1\n", " data = relay.var(\"data\", relay.TensorType((batch_size, 3, img_size, img_size), \"float16\"))\n", " dense_weight = relay.var(\n", " \"dweight\", relay.TensorType((batch_size, 16 * img_size * img_size), \"float16\")\n", " )\n", " weight = relay.var(\"weight\")\n", " second_weight = relay.var(\"second_weight\")\n", " bn_gamma = relay.var(\"bn_gamma\")\n", " bn_beta = relay.var(\"bn_beta\")\n", " bn_mmean = relay.var(\"bn_mean\")\n", " bn_mvar = relay.var(\"bn_var\")\n", " simple_net = relay.nn.conv2d(\n", " data=data, weight=weight, kernel_size=(3, 3), channels=out_channels, padding=(1, 1)\n", " )\n", " simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]\n", " simple_net = relay.nn.relu(simple_net)\n", " simple_net = 
relay.nn.batch_flatten(simple_net)\n", " simple_net = relay.nn.dense(simple_net, dense_weight)\n", " simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net)\n", " data_shape = (batch_size, 3, img_size, img_size)\n", " net, params = testing.create_workload(simple_net)\n", " return net, params, data_shape\n", "\n", "\n", "net, params, data_shape = get_network()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 将网络分成两个子图\n", "\n", "单元测试中的 'graph_split' 函数只是一个例子。用户可以创建自定义的逻辑来分割计算图。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "from env import TVM_ROOT\n", "os.sys.path.append(f\"{TVM_ROOT}/tests/python/relay\")\n", "\n", "from test_pipeline_executor import graph_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "将网络分成两个子图。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "split_config = [{\"op_name\": \"nn.relu\", \"op_index\": 0}]\n", "subgraphs = graph_split(net[\"main\"], split_config, params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "生成的子图应该如下所示。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "子图 0:\n", " fn (%data: Tensor[(1, 3, 8, 8), float16] /* ty=Tensor[(1, 3, 8, 8), float16] */) {\n", " %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float16] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 8, 8), float16] */;\n", " %1 = nn.batch_norm(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float16] */, meta[relay.Constant][2] /* ty=Tensor[(16), float16] */, meta[relay.Constant][3] /* ty=Tensor[(16), float16] */, meta[relay.Constant][4] /* ty=Tensor[(16), float16] */) /* ty=(Tensor[(1, 16, 8, 8), float16], Tensor[(16), float16], Tensor[(16), float16]) */;\n", " %2 = %1.0 /* ty=Tensor[(1, 16, 8, 8), float16] */;\n", " nn.relu(%2) 
/* ty=Tensor[(1, 16, 8, 8), float16] */\n", "}\n", "\n", "子图 1:\n", " fn (%data_n_0: Tensor[(1, 16, 8, 8), float16] /* ty=Tensor[(1, 16, 8, 8), float16] */) {\n", " %0 = nn.batch_flatten(%data_n_0) /* ty=Tensor[(1, 1024), float16] */;\n", " nn.dense(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 1024), float16] */, units=None) /* ty=Tensor[(1, 1), float16] */\n", "}\n", "\n" ] } ], "source": [ "for k, m in enumerate(subgraphs):\n", " print(f\"子图 {k}:\\n\", m[\"main\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用 cutlass target 构建子图" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "cutlass = tvm.target.Target(\n", " {\n", " \"kind\": \"cutlass\",\n", " \"sm\": int(tvm.target.Target(\"cuda\").arch.split(\"_\")[1]),\n", " \"use_3xtf32\": True,\n", " \"split_k_slices\": [1],\n", " \"profile_all_alignments\": False,\n", " \"find_first_valid\": True,\n", " \"use_multiprocessing\": True,\n", " \"use_fast_math\": False,\n", " \"tmp_dir\": \"./tmp\",\n", " },\n", " host=tvm.target.Target(\"llvm\"),\n", ")\n", "\n", "\n", "def cutlass_build(mod, target, params=None, target_host=None, mod_name=\"default\"):\n", " target = [target, cutlass]\n", " lib = relay.build_module.build(\n", " mod, target=target, params=params, target_host=target_host, mod_name=mod_name\n", " )\n", " return lib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用管道执行器在管道中运行两个子图\n", "\n", "在 cmake 中将 `USE_PIPELINE_EXECUTOR` 设置为 `ON`,并将 `USE_CUTLASS` 设置为 `ON`。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from tvm.contrib import graph_executor, pipeline_executor, pipeline_executor_build" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create subgraph pipeline configuration.\n", "Associate a subgraph module with a target.\n", "Use CUTLASS BYOC to build the second subgraph module.\n", "\n" ] }, { "cell_type": "code", 
"execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ "mod0, mod1 = subgraphs[0], subgraphs[1]\n", "# Use cutlass as the codegen.\n", "mod1 = partition_for_cutlass(mod1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the pipeline executor configuration object.\n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pipe_config = pipeline_executor_build.PipelineConfig()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the compile target of the first subgraph module to llvm.\n", "\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pipe_config[mod0].target = \"llvm\"\n", "pipe_config[mod0].dev = tvm.cpu(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the compile target of the second subgraph module to cuda.\n", "\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pipe_config[mod1].target = \"cuda\"\n", "pipe_config[mod1].dev = tvm.device(\"cuda\", 0)\n", "pipe_config[mod1].build_func = cutlass_build\n", "pipe_config[mod1].export_cc = \"nvcc\"\n", "# Create the pipeline by connecting the subgraph modules.\n", "# The global input will be forwarded to the input interface of the first module named mod0\n", "pipe_config[\"input\"][\"data\"].connect(pipe_config[mod0][\"input\"][\"data\"])\n", "# The first output of mod0 will be forwarded to the input interface of mod1\n", "pipe_config[mod0][\"output\"][0].connect(pipe_config[mod1][\"input\"][\"data_n_0\"])\n", "# The first output of mod1 will be the first global output.\n", "pipe_config[mod1][\"output\"][0].connect(pipe_config[\"output\"][0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The pipeline configuration is shown below.\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ {
"data": { "text/plain": [ "'\\nprint(pipe_config)\\n Inputs\\n |data: mod0:data\\n\\n output\\n |output(0) : mod1.output(0)\\n\\n connections\\n |mod0.output(0)-> mod1.data_n_0\\n'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "print(pipe_config)\n", " Inputs\n", " |data: mod0:data\n", "\n", " output\n", " |output(0) : mod1.output(0)\n", "\n", " connections\n", " |mod0.output(0)-> mod1.data_n_0\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build the pipeline executor.\n", "\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/pc/data/4tb/lxw/home/lxw/tvm/python/tvm/driver/build_module.py:266: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.\n", " warnings.warn(\n", "One or more operators have not been tuned. Please tune your model for better performance. 
Use DEBUG logging level to see more details.\n" ] } ], "source": [ "with tvm.transform.PassContext(opt_level=3):\n", " pipeline_mod_factory = pipeline_executor_build.build(pipe_config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Export the parameter configuration to a file.\n", "\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [], "source": [ "directory_path = tvm.contrib.utils.tempdir().temp_dir\n", "os.makedirs(directory_path, exist_ok=True)\n", "config_file_name = pipeline_mod_factory.export_library(directory_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use the load function to create and initialize PipelineModule.\n", "\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pipeline_module = pipeline_executor.PipelineModule.load_library(config_file_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the pipeline executor.\n", "Allocate input data.\n", "\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [], "source": [ "data = np.random.uniform(-1, 1, size=data_shape).astype(\"float16\")\n", "pipeline_module.set_input(\"data\", tvm.nd.array(data))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the two subgraphs in pipeline mode to get the output asynchronously\n", "or synchronously.
In the following example, it is synchronous.\n", "\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pipeline_module.run()\n", "outputs = pipeline_module.get_output()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use graph_executor for verification.\n", "Run these two subgraphs in sequence with graph_executor to get the output.\n", "\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/pc/data/4tb/lxw/home/lxw/tvm/python/tvm/driver/build_module.py:266: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.\n", " warnings.warn(\n" ] } ], "source": [ "target = \"llvm\"\n", "dev0 = tvm.device(target, 0)\n", "lib0 = relay.build_module.build(mod0, target, params=params)\n", "module0 = runtime.GraphModule(lib0[\"default\"](dev0))\n", "cuda = tvm.target.Target(\"cuda\", host=tvm.target.Target(\"llvm\"))\n", "lib1 = relay.build_module.build(mod1, [cuda, cutlass], params=params)\n", "lib1 = finalize_modules(lib1, \"compile.so\", \"./tmp\")\n", "\n", "dev1 = tvm.device(\"cuda\", 0)\n", "\n", "module1 = runtime.GraphModule(lib1[\"default\"](dev1))\n", "\n", "module0.set_input(\"data\", data)\n", "module0.run()\n", "out_shape = (1, 16, img_size, img_size)\n", "out = module0.get_output(0, tvm.nd.empty(out_shape, \"float16\"))\n", "module1.set_input(\"data_n_0\", out)\n", "module1.run()\n", "out_shape = (1, 1)\n", "out = module1.get_output(0, tvm.nd.empty(out_shape, \"float16\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify the result.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "tvm.testing.assert_allclose(outputs[0].numpy(), out.numpy())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.2 
64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 0 }