{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(tir_blitz)=\n", "# TensorIR 的突击课程\n", "\n", "**作者**: [Siyuan Feng](https://github.com/Hzfengsy)\n", "\n", "TensorIR 是用于深度学习程序的特定域语言,有两个广泛的目的:\n", "\n", "- 在各种硬件后端进行程序变换和优化的实现。\n", "- 用于自动张量化程序优化的抽象。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import tvm\n", "from tvm.script.parser import ir_module\n", "from tvm.ir.module import IRModule\n", "from tvm.script import tir as T\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## IRModule\n", "\n", "IRModule 是 TVM 的中心数据结构,它包含深度学习程序。它是 IR 变换和模型构建的基本关注对象。\n", "\n", "```{image} https://daobook.github.io/tvm-web-data/images/design/tvm_life_of_irmodule.png\n", ":width: 85%\n", "```\n", "\n", "这是 IRModule 的生命周期(life cycle),它可以从 TVMScript 创建。TensorIR 调度原语(primitive)和传递(pass)是变换 IRModule 的两种主要方式。另外,对 IRModule 进行一系列的变换也是可以接受的。请注意,可以在 **任何** 阶段向 TVMScript 打印 IRModule。在所有变换和优化完成后,可以将 IRModule 构建为可运行的模块,以部署在目标设备上。\n", "\n", "基于 TensorIR 和 IRModule 的设计,能够创建新的编程方式:\n", "\n", "1. 用 TVMScript 写基于 Python-AST 语法的程序。\n", "2. 用 python api 变换和优化程序。\n", "3. 通过命令式的变换 API,交互式地检查和尝试性能。\n", "\n", "## 创建 IRModule\n", "\n", "IRModule 可以通过编写 TVMScript 来创建,TVMScript 是 TVM IR 的可圆润化(round-trippable)的语法。\n", "\n", "与通过 [张量表达式](tutorial-tensor-expr-get-started) 创建计算表达式不同,TensorIR 允许用户通过 TVMScript(嵌入式 python AST 的语言)来编程。这种新方法使得编写复杂的程序并进一步调度和优化它成为可能。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n", " B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "@ir_module\n", "class MyModule:\n", " @T.prim_func\n", " def main(A: T.Buffer[8, \"float32\"], B: T.Buffer[8, \"float32\"]):\n", " T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " for i in range(8):\n", " # block 是计算的抽象。\n", " with T.block(\"B\"):\n", " # 定义 spatial block 迭代器,并将其绑定到值 i。\n", " vi = T.axis.spatial(8, i)\n", " B[vi] = A[vi] + 1.0\n", "\n", "\n", "ir_module = MyModule\n", "print(type(ir_module))\n", "ir_module.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "此外,还可以使用张量表达式 DSL 来编写简单的算子,并将其转换为 IRModule。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[i0_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[i0_1])\n", " B[i0_1] \u001b[38;5;129;01m=\u001b[39;00m A[i0_1] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "from tvm import te\n", "\n", "A = te.placeholder((8,), dtype=\"float32\", name=\"A\")\n", "B = te.compute((8,), lambda *i: A(*i) + 1.0, name=\"B\")\n", "func = te.create_prim_func([A, B])\n", "ir_module_from_te = IRModule({\"main\": func})\n", "ir_module_from_te.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 构建和运行 IRModule\n", "\n", "我们可以将 IRModule 构建为具有特定目标后端的可运行模块。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "mod = tvm.build(ir_module, target=\"llvm\") # 用于 CPU 后端的模块\n", "print(type(mod))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "准备好输入 array 和输出 array,然后运行该模块。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 1. 2. 3. 4. 5. 6. 7.]\n", "[1. 2. 3. 4. 5. 6. 7. 8.]\n" ] } ], "source": [ "a = tvm.nd.array(np.arange(8).astype(\"float32\"))\n", "b = tvm.nd.array(np.zeros((8,)).astype(\"float32\"))\n", "mod(a, b)\n", "print(a)\n", "print(b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 转换 IRModule\n", "\n", "IRModule 是程序优化的中心数据结构,它可以通过 `Schedule` 进行转换。调度包含多个原语方法,以交互式地转换程序。每个原语都以某些方式改造程序,以带来额外的性能优化。\n", "\n", "\n", "\n", "上面的图片是优化张量程序的典型工作流程。首先,需要在由 TVMScript 或 Tensor Expression 创建的初始 IRModule 上创建调度。然后,一连串的调度原语将有助于提高性能。最后,我们可以将其降低并构建为可运行的模块。\n", "\n", "这里只演示了非常简单的变换。首先,在输入的 `ir_module` 上创建调度。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "sch = tvm.tir.Schedule(ir_module)\n", "print(type(sch))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "将该循环分为 3 个循环,并打印结果。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0, i_1, i_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n", " B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "# 按 name 获取 block\n", "block_b = sch.get_block(\"B\")\n", "# 获取 block 周围的 loops\n", "(i,) = sch.get_loops(block_b)\n", "# 平铺(tile)循环嵌套。\n", "i_0, i_1, i_2 = sch.split(i, factors=[None, 2, 2])\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "也可以重新调度循环的顺序。现在将循环 `i_2` 移到 `i_1` 的外面。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0, i_2, i_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n", " B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "sch.reorder(i_0, i_2, i_1)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 转化为 GPU 程序\n", "\n", "如果想在 GPU 上部署模型,线程绑定是必要的。幸运的是,也可以使用原语并做增量变换。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m8\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m2\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m2\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m8\u001b[39m, i_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[vi])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(B[vi])\n", " B[vi] \u001b[38;5;129;01m=\u001b[39;00m A[vi] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "sch.bind(i_0, \"blockIdx.x\")\n", "sch.bind(i_2, \"threadIdx.x\")\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "绑定线程后,现在用 `cuda` 后端构建 IRModule。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 1. 2. 3. 4. 5. 6. 7.]\n", "[1. 2. 3. 4. 5. 6. 7. 8.]\n" ] } ], "source": [ "ctx = tvm.cuda(0)\n", "cuda_mod = tvm.build(sch.mod, target=\"cuda\")\n", "cuda_a = tvm.nd.array(np.arange(8).astype(\"float32\"), ctx)\n", "cuda_b = tvm.nd.array(np.zeros((8,)).astype(\"float32\"), ctx)\n", "cuda_mod(cuda_a, cuda_b)\n", "print(cuda_a)\n", "print(cuda_b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('py38': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "vscode": { "interpreter": { "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7" } } }, "nbformat": 4, "nbformat_minor": 0 }