{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "(tir_blitz)=\n",
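        "# Blitz Course to TensorIR\n",
        "\n",
        "**Author**: [Siyuan Feng](https://github.com/Hzfengsy)\n",
        "\n",
        "TensorIR is a domain-specific language for deep learning programs that serves two broad purposes:\n",
        "\n",
        "- An implementation for transforming and optimizing programs on various hardware backends.\n",
        "- An abstraction for automatic tensorized program optimization."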
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm.script.parser import ir_module\n",
        "from tvm.ir.module import IRModule\n",
        "from tvm.script import tir as T\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## IRModule\n",
        "\n",
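        "An IRModule is the central data structure in TVM and contains deep learning programs. It is the basic object of interest for IR transformation and model building.\n",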
        "\n",
        "```{image} https://daobook.github.io/tvm-web-data/images/design/tvm_life_of_irmodule.png\n",
        ":width: 85%\n",
        "```\n",
        "\n",
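        "This is the life cycle of an IRModule, which can be created from TVMScript. TensorIR schedule primitives and passes are the two major ways to transform an IRModule, and several transformations can be applied to it in sequence. Note that an IRModule can be printed as TVMScript at **any** stage. After all transformations and optimizations are complete, the IRModule can be built into a runnable module for deployment on target devices.\n",
        "\n",
        "The design of TensorIR and IRModule enables a new programming method:\n",
        "\n",
        "1. Write a program in TVMScript, a syntax based on the Python AST.\n",
        "2. Transform and optimize the program with the Python API.\n",
        "3. Interactively inspect the program and try out its performance with an imperative-style transformation API.\n",
        "\n",
        "## Create an IRModule\n",
        "\n",
        "An IRModule can be created by writing TVMScript, which is a round-trippable syntax for TVM IR.\n",
        "\n",
        "Unlike creating a computational expression with [Tensor Expression](tutorial-tensor-expr-get-started), TensorIR lets users program through TVMScript, a language embedded in the Python AST. This method makes it possible to write complex programs and then schedule and optimize them further."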
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.ir.module.IRModule'>\n",
            "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n",
            "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n",
            "    \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n",
            "    \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
            "        \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n",
            "        T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n",
            "        \u001b[38;5;30;03m# body\u001b[39;00m\n",
            "        \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n",
            "        \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m8\u001b[39m):\n",
            "            \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
            "                vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i)\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n",
            "                B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n",
            "    \n",
            "\n"
          ]
        }
      ],
      "source": [
        "@ir_module\n",
        "class MyModule:\n",
        "    @T.prim_func\n",
        "    def main(A: T.Buffer[8, \"float32\"], B: T.Buffer[8, \"float32\"]):\n",
        "        T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
        "        for i in range(8):\n",
        "            # A block is an abstraction for computation.\n",
        "            with T.block(\"B\"):\n",
        "                # Define a spatial block iterator and bind it to the value i.\n",
        "                vi = T.axis.spatial(8, i)\n",
        "                B[vi] = A[vi] + 1.0\n",
        "\n",
        "\n",
        "ir_module = MyModule\n",
        "print(type(ir_module))\n",
        "ir_module.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In addition, the Tensor Expression DSL can be used to write simple operators and convert them to an IRModule."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n",
            "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n",
            "    \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n",
            "    \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
            "        \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n",
            "        T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n",
            "        \u001b[38;5;30;03m# body\u001b[39;00m\n",
            "        \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n",
            "        \u001b[38;5;28;01mfor\u001b[39;00m i0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m8\u001b[39m):\n",
            "            \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
            "                i0_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i0)\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mreads(A[i0_1])\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[i0_1])\n",
            "                B[i0_1] \u001b[38;5;129;01m=\u001b[39;00m A[i0_1] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n",
            "    \n",
            "\n"
          ]
        }
      ],
      "source": [
        "from tvm import te\n",
        "\n",
        "A = te.placeholder((8,), dtype=\"float32\", name=\"A\")\n",
        "B = te.compute((8,), lambda *i: A(*i) + 1.0, name=\"B\")\n",
        "func = te.create_prim_func([A, B])\n",
        "ir_module_from_te = IRModule({\"main\": func})\n",
        "ir_module_from_te.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Build and Run an IRModule\n",
        "\n",
        "We can build the IRModule into a runnable module for a specific target backend."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.driver.build_module.OperatorModule'>\n"
          ]
        }
      ],
      "source": [
        "mod = tvm.build(ir_module, target=\"llvm\")  # The module for the CPU backend.\n",
        "print(type(mod))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Prepare the input array and the output array, then run the module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[0. 1. 2. 3. 4. 5. 6. 7.]\n",
            "[1. 2. 3. 4. 5. 6. 7. 8.]\n"
          ]
        }
      ],
      "source": [
        "a = tvm.nd.array(np.arange(8).astype(\"float32\"))\n",
        "b = tvm.nd.array(np.zeros((8,)).astype(\"float32\"))\n",
        "mod(a, b)\n",
        "print(a)\n",
        "print(b)"
      ]
    },
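    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The printed result can be cross-checked against a plain-Python reading of block `B`. The sketch below is only an illustration and does not use TVM; the helper name `add_one_reference` is hypothetical, and Python lists stand in for the buffers `A` and `B`.\n",
        "\n",
        "```python\n",
        "def add_one_reference(a):\n",
        "    # Mirrors block 'B': one spatial iteration per element, B[vi] = A[vi] + 1.0\n",
        "    b = [0.0] * len(a)\n",
        "    for i in range(len(a)):\n",
        "        vi = i  # the spatial axis is bound directly to the loop variable i\n",
        "        b[vi] = a[vi] + 1.0\n",
        "    return b\n",
        "\n",
        "\n",
        "a_ref = [float(x) for x in range(8)]\n",
        "print(add_one_reference(a_ref))\n",
        "```\n",
        "\n",
        "Each element of the result is the matching input element plus one, which agrees with the two arrays printed by the built module."
      ]
    },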
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
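        "## Transform an IRModule\n",
        "\n",
        "The IRModule is the central data structure for program optimization, and it can be transformed with `Schedule`. A schedule provides multiple primitive methods for transforming the program interactively. Each primitive transforms the program in a specific way to bring additional performance optimizations.\n",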
        "\n",
        "<img src=\"https://daobook.github.io/tvm-web-data/images/design/tvm_tensor_ir_opt_flow.png\" align=\"center\" width=\"100%\">\n",
        "\n",
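        "The image above shows a typical workflow for optimizing a tensor program. First, create a schedule on the initial IRModule, which comes from either TVMScript or Tensor Expression. Then, a sequence of schedule primitives helps improve performance. Finally, the result can be lowered and built into a runnable module.\n",
        "\n",
        "Only a very simple transformation is demonstrated here. First, create a schedule on the input `ir_module`."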
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.tir.schedule.schedule.Schedule'>\n"
          ]
        }
      ],
      "source": [
        "sch = tvm.tir.Schedule(ir_module)\n",
        "print(type(sch))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Split the loop into 3 loops and print the result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n",
            "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n",
            "    \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n",
            "    \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
            "        \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n",
            "        T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n",
            "        \u001b[38;5;30;03m# body\u001b[39;00m\n",
            "        \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n",
            "        \u001b[38;5;28;01mfor\u001b[39;00m i_0, i_1, i_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n",
            "            \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
            "                vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_2)\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n",
            "                B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n",
            "    \n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Get the block by its name.\n",
        "block_b = sch.get_block(\"B\")\n",
        "# Get the loops surrounding the block.\n",
        "(i,) = sch.get_loops(block_b)\n",
        "# Split the loop into three nested loops; None lets the outer extent be inferred.\n",
        "i_0, i_1, i_2 = sch.split(i, factors=[None, 2, 2])\n",
        "sch.mod.show()"
      ]
    },
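    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why the split preserves the computation, the index arithmetic from the printed block can be replayed in plain Python. This is a minimal sketch rather than TVM code; the function name `split_visit_order` and its parameters are hypothetical, and it assumes the inner factors divide the extent evenly, as they do here.\n",
        "\n",
        "```python\n",
        "def split_visit_order(extent=8, inner_factors=(2, 2)):\n",
        "    # factors=[None, 2, 2]: the outer extent is inferred as 8 // (2 * 2) = 2\n",
        "    f1, f2 = inner_factors\n",
        "    outer = extent // (f1 * f2)\n",
        "    visited = []\n",
        "    for i_0 in range(outer):\n",
        "        for i_1 in range(f1):\n",
        "            for i_2 in range(f2):\n",
        "                # Same binding as the printed block: vi = i_0 * 4 + i_1 * 2 + i_2\n",
        "                visited.append(i_0 * f1 * f2 + i_1 * f2 + i_2)\n",
        "    return visited\n",
        "\n",
        "\n",
        "print(split_visit_order())  # [0, 1, 2, 3, 4, 5, 6, 7]\n",
        "```\n",
        "\n",
        "The three loops together visit every index from 0 to 7 exactly once and in the original order, so the block computes the same values as before the split."
      ]
    },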
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loops can also be reordered. Now move loop `i_2` outside of `i_1`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n",
            "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n",
            "    \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n",
            "    \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
            "        \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n",
            "        T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n",
            "        \u001b[38;5;30;03m# body\u001b[39;00m\n",
            "        \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n",
            "        \u001b[38;5;28;01mfor\u001b[39;00m i_0, i_2, i_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n",
            "            \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
            "                vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_2)\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n",
            "                T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n",
            "                B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n",
            "    \n",
            "\n"
          ]
        }
      ],
      "source": [
        "sch.reorder(i_0, i_2, i_1)\n",
        "sch.mod.show()"
      ]
    },
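    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reordering changes the order in which elements are visited, not the set of elements. The plain-Python sketch below replays the reordered loop nest with the extents shown in the printed module; the function name `reordered_visit_order` is hypothetical and no TVM API is involved.\n",
        "\n",
        "```python\n",
        "def reordered_visit_order():\n",
        "    visited = []\n",
        "    for i_0 in range(2):\n",
        "        for i_2 in range(2):  # i_2 now sits outside i_1\n",
        "            for i_1 in range(2):\n",
        "                # The binding expression itself is unchanged by the reorder\n",
        "                visited.append(i_0 * 4 + i_1 * 2 + i_2)\n",
        "    return visited\n",
        "\n",
        "\n",
        "order = reordered_visit_order()\n",
        "print(order)          # [0, 2, 1, 3, 4, 6, 5, 7]\n",
        "print(sorted(order))  # [0, 1, 2, 3, 4, 5, 6, 7]\n",
        "```\n",
        "\n",
        "Because `vi` is a spatial axis and each iteration writes a distinct element of `B`, visiting the elements in a different order leaves the result unchanged."
      ]
    },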
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Transform to a GPU program\n",
        "\n",
        "Deploying a model on a GPU requires thread binding. Fortunately, this can also be done incrementally with schedule primitives."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n",
            "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n",
            "    \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n",
            "    \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
            "        \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n",
            "        T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n",
            "        \u001b[38;5;30;03m# body\u001b[39;00m\n",
            "        \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n",
            "        \u001b[38;5;28;01mfor\u001b[39;00m i_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m2\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
            "            \u001b[38;5;28;01mfor\u001b[39;00m i_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m2\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
            "                \u001b[38;5;28;01mfor\u001b[39;00m i_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n",
            "                    \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
            "                        vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_2)\n",
            "                        T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n",
            "                        T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n",
            "                        B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n",
            "    \n",
            "\n"
          ]
        }
      ],
      "source": [
        "sch.bind(i_0, \"blockIdx.x\")\n",
        "sch.bind(i_2, \"threadIdx.x\")\n",
        "sch.mod.show()"
      ]
    },
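    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After binding, `i_0` and `i_2` are no longer sequential loops: each value pair identifies one GPU thread. The sketch below is a plain-Python simulation, not CUDA or TVM code, and the function name `elements_per_thread` is hypothetical. It lists which elements each (blockIdx.x, threadIdx.x) pair would handle under the binding printed above.\n",
        "\n",
        "```python\n",
        "def elements_per_thread():\n",
        "    work = {}\n",
        "    for block_idx in range(2):       # i_0 is bound to blockIdx.x\n",
        "        for thread_idx in range(2):  # i_2 is bound to threadIdx.x\n",
        "            # i_1 stays a serial loop that runs inside each thread\n",
        "            work[(block_idx, thread_idx)] = [\n",
        "                block_idx * 4 + i_1 * 2 + thread_idx for i_1 in range(2)\n",
        "            ]\n",
        "    return work\n",
        "\n",
        "\n",
        "for key, elems in elements_per_thread().items():\n",
        "    print(key, elems)\n",
        "```\n",
        "\n",
        "Each of the four simulated threads handles two elements, and together they cover indices 0 to 7 with no overlap."
      ]
    },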
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After binding the threads, build the IRModule with the `cuda` backend."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[0. 1. 2. 3. 4. 5. 6. 7.]\n",
            "[1. 2. 3. 4. 5. 6. 7. 8.]\n"
          ]
        }
      ],
      "source": [
        "ctx = tvm.cuda(0)\n",
        "cuda_mod = tvm.build(sch.mod, target=\"cuda\")\n",
        "cuda_a = tvm.nd.array(np.arange(8).astype(\"float32\"), ctx)\n",
        "cuda_b = tvm.nd.array(np.zeros((8,)).astype(\"float32\"), ctx)\n",
        "cuda_mod(cuda_a, cuda_b)\n",
        "print(cuda_a)\n",
        "print(cuda_b)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.8.13 ('py38': conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}