{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "qXysoqn-vZuF" }, "source": [ "## Install packages \n", "\n", "For the purpose of this course, we will use some ongoing development in tvm, which is an open-source machine learning compilation framework. We provide the following command to install a packaged version for mlc course." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Xe3vClsD9jlq", "outputId": "9321483a-59c5-40a8-d6fc-d00a44d97053" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in links: https://mlc.ai/wheels\n", "Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1956+ge3f218d71)\n", "Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.2)\n", "Requirement already satisfied: Pygments in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.12.0)\n", "Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)\n", "Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)\n", "Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.22.3)\n", "Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (22.1.0)\n", "Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)\n", "Requirement already satisfied: psutil in 
/media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)\n", "Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.9.0)\n" ] } ], "source": [ "!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels" ] }, { "cell_type": "markdown", "metadata": { "id": "i-14C4skxIrJ" }, "source": [ "## Prelude\n", "\n", "In the past chapters, we learned how to build primitive tensor functions and connect them to form end-to-end model executions. There are three primary types of abstractions we have used so far.\n", "\n", "- A computational graph view that drives the high-level executions.\n", "- An abstraction for primitive tensor functions.\n", "- Library function calls via environment function registration.\n", "\n", "All of these elements are encapsulated in an IRModule. Most MLC processes can be viewed as transformations among tensor functions.\n", "\n", "There are many different ways to transform the same program. This chapter will discuss ways to automate some of these processes. 
\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "BBIuE2jc1DaU" }, "source": [ "## Preparations\n", "\n", "To begin with, we will import necessary dependencies and create helper functions.\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "BVp0fHyRkYj6" }, "outputs": [], "source": [ "import tvm\n", "from tvm.ir.module import IRModule\n", "from tvm.script import tir as T, relax as R\n", "import numpy as np\n", "from tvm import relax\n", "# This is needed for deferring annotation parsing in TVMScript\n", "from __future__ import annotations " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "LlyLijMxEcs3" }, "outputs": [], "source": [ "import IPython\n", "\n", "def code2html(code):\n", " \"\"\"Helper function to use pygments to turn the code string into highlighted html.\"\"\"\n", " import pygments\n", " from pygments.lexers import Python3Lexer\n", " from pygments.formatters import HtmlFormatter\n", " formatter = HtmlFormatter()\n", " html = pygments.highlight(code, Python3Lexer(), formatter)\n", " return \"%s\\n\" % (formatter.get_style_defs(\".highlight\"), html)" ] }, { "cell_type": "markdown", "metadata": { "id": "8yH4IMSMvF9o" }, "source": [ "## Recap: Transform a Primitive Tensor Function.\n", "\n", "Let us begin by reviewing what we did in our previous chapters -- transforming a single primitive tensor function." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "zlvL1Zfkt9A4" }, "outputs": [], "source": [ "@tvm.script.ir_module\n", "class MyModule:\n", " @T.prim_func\n", " def main(\n", " A: T.Buffer[(128, 128), \"float32\"],\n", " B: T.Buffer[(128, 128), \"float32\"],\n", " C: T.Buffer[(128, 128), \"float32\"],\n", " ):\n", " T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " for i, j, k in T.grid(128, 128, 128):\n", " with T.block(\"C\"):\n", " vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n", " with T.init():\n", " C[vi, vj] = 0.0\n", " C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]" ] }, { "cell_type": "markdown", "metadata": { "id": "6gtnHe0KcMGO" }, "source": [ "First, let us define a set of inputs and outputs for evaluation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "qVfi2Rb6Ch6t" }, "outputs": [], "source": [ "dtype = \"float32\"\n", "a_np = np.random.rand(128, 128).astype(dtype)\n", "b_np = np.random.rand(128, 128).astype(dtype)\n", "c_mm = a_np @ b_np" ] }, { "cell_type": "markdown", "metadata": { "id": "LxUn0qMZcQYp" }, "source": [ "We can build and run `MyModule` as follows.\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rexBtll-CFtA", "outputId": "46264544-5b84-45f3-d6c3-62758a43a960" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time cost of MyModule: 2.261 ms\n" ] } ], "source": [ "a_nd = tvm.nd.array(a_np)\n", "b_nd = tvm.nd.array(b_np)\n", "c_nd = tvm.nd.empty((128, 128), dtype=\"float32\")\n", "\n", "lib = tvm.build(MyModule, target=\"llvm\")\n", "f_timer_before = lib.time_evaluator(\"main\", tvm.cpu())\n", "print(\"Time cost of MyModule: %.3f ms\" % (f_timer_before(a_nd, b_nd, c_nd).mean * 1000))" ] }, { "cell_type": "markdown", "metadata": { "id": "d55E-IE8cXC8" }, "source": [ "Next, we transform `MyModule` a bit by reorganizing the loop access pattern." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "ANEDn2_BCF3F" }, "outputs": [], "source": [ "def schedule_mm(sch: tvm.tir.Schedule, jfactor=4):\n", " block_C = sch.get_block(\"C\", \"main\")\n", " i, j, k = sch.get_loops(block=block_C)\n", " j_0, j_1 = sch.split(loop=j, factors=[None, jfactor])\n", " sch.reorder(i, j_0, k, j_1)\n", " sch.decompose_reduction(block_C, k)\n", " return sch" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 451 }, "id": "lxXbTgv0CGAH", "outputId": "01b3c601-bced-45c4-c4b3-567b4d18411a" }, "outputs": [ { "data": { "text/html": [ "
@tvm.script.ir_module\n",
"class Module:\n",
" @T.prim_func\n",
" def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:\n",
" # function attr dict\n",
" T.func_attr({"global_symbol": "main", "tir.noalias": True})\n",
" # body\n",
" # with T.block("root")\n",
" for i, j_0 in T.grid(128, 32):\n",
" for j_1_init in T.serial(4):\n",
" with T.block("C_init"):\n",
" vi = T.axis.spatial(128, i)\n",
" vj = T.axis.spatial(128, j_0 * 4 + j_1_init)\n",
" T.reads()\n",
" T.writes(C[vi, vj])\n",
" C[vi, vj] = T.float32(0)\n",
" for k, j_1 in T.grid(128, 4):\n",
" with T.block("C_update"):\n",
" vi = T.axis.spatial(128, i)\n",
" vj = T.axis.spatial(128, j_0 * 4 + j_1)\n",
" vk = T.axis.reduce(128, k)\n",
" T.reads(C[vi, vj], A[vi, vk], B[vk, vj])\n",
" T.writes(C[vi, vj])\n",
" C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]\n",
" \n",
"
@tvm.script.ir_module\n",
"class Module:\n",
" @T.prim_func\n",
" def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:\n",
" # function attr dict\n",
" T.func_attr({"global_symbol": "main", "tir.noalias": True})\n",
" # body\n",
" # with T.block("root")\n",
" for i, j, k in T.grid(128, 128, 128):\n",
" with T.block("C"):\n",
" vi, vj, vk = T.axis.remap("SSR", [i, j, k])\n",
" T.reads(A[vi, vk], B[vk, vj])\n",
" T.writes(C[vi, vj])\n",
" with T.init():\n",
" C[vi, vj] = T.float32(0)\n",
" C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]\n",
" \n",
"
@tvm.script.ir_module\n",
"class Module:\n",
" @T.prim_func\n",
" def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:\n",
" # function attr dict\n",
" T.func_attr({"global_symbol": "main", "tir.noalias": True})\n",
" # body\n",
" # with T.block("root")\n",
" for i, j_0, k, j_1 in T.grid(128, 8, 128, 16):\n",
" with T.block("C"):\n",
" vi = T.axis.spatial(128, i)\n",
" vj = T.axis.spatial(128, j_0 * 16 + j_1)\n",
" vk = T.axis.reduce(128, k)\n",
" T.reads(A[vi, vk], B[vk, vj])\n",
" T.writes(C[vi, vj])\n",
" with T.init():\n",
" C[vi, vj] = T.float32(0)\n",
" C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]\n",
" \n",
"
@tvm.script.ir_module\n",
"class Module:\n",
" @T.prim_func\n",
" def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:\n",
" # function attr dict\n",
" T.func_attr({"global_symbol": "main", "tir.noalias": True})\n",
" # body\n",
" # with T.block("root")\n",
" for i, j_0 in T.grid(128, 8):\n",
" for j_1_init in T.serial(16):\n",
" with T.block("C_init"):\n",
" vi = T.axis.spatial(128, i)\n",
" vj = T.axis.spatial(128, j_0 * 16 + j_1_init)\n",
" T.reads()\n",
" T.writes(C[vi, vj])\n",
" C[vi, vj] = T.float32(0)\n",
" for k, j_1 in T.grid(128, 16):\n",
" with T.block("C_update"):\n",
" vi = T.axis.spatial(128, i)\n",
" vj = T.axis.spatial(128, j_0 * 16 + j_1)\n",
" vk = T.axis.reduce(128, k)\n",
" T.reads(C[vi, vj], A[vi, vk], B[vk, vj])\n",
" T.writes(C[vi, vj])\n",
" C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]\n",
" \n",
"