{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 编写定制 Pass\n", "\n", "\n", "**原作者**: [Jian Weng](https://were.github.io)\n", "\n", "TVM 是抽象出机器学习加速器异质性(heterogenity)的框架。有时用户可能希望定制一些分析和 IR 变换,使 TVM 适应他们自己的专用硬件。本教程帮助用户在 TVM 中编写定制的pass。\n", "\n", "## 前提条件\n", "\n", "在阅读本教程开始之前,假设读者已经很好地了解了以下主题:\n", "\n", "- 在 TVM 中编写算法并对其进行调度。否则,请参阅示例教程,如 {ref}`opt-gemm`。\n", "- HalideIR 的基本结构。否则,请参见 ``HalideIR/src/ir/IR.h`` 来了解 IR 节点定义了哪些属性。\n", "- 访问者设计模式(Visitor design pattern)。否则,请查看 [Python AST 模块](https://docs.python.org/3/library/ast.html),查看 AST visitor 是如何实现的。\n", "- 如何将 Schedule 降格(lower)为 IRModule class 或 LLVM module。否则,请查看 ``python/tvm/build_module.py`` 以获得一些基础知识。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import tvm\n", "from tvm import te" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "编写非常简单的向量加法,并使用默认的调度来构建它。然后,使用定制的 lower pass 直接操作 IR,而不是使用调度原语(primitives.)。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(a: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m128\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], b: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m128\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], c: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m128\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(a, [\u001b[38;5;28m128\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00ma\u001b[38;5;129;01m.\u001b[39;00mdata)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(b, [\u001b[38;5;28m128\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mb\u001b[38;5;129;01m.\u001b[39;00mdata)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(c, [\u001b[38;5;28m128\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mc\u001b[38;5;129;01m.\u001b[39;00mdata)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m128\u001b[39m):\n", " c[i] \u001b[38;5;129;01m=\u001b[39;00m a[i] \u001b[38;5;129;01m+\u001b[39;00m b[i]\n", " \n", "\n" ] } ], "source": [ "n = tvm.tir.const(128, \"int32\")\n", "a = te.placeholder((n,), name=\"a\")\n", "b = te.placeholder((n,), name=\"b\")\n", "c = te.compute((n,), lambda i: a[i] + b[i], name=\"c\")\n", "\n", "sch = te.create_schedule(c.op)\n", "ir = tvm.lower(sch, [a, b, c])\n", "ir.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 编写 Pass\n", "\n", "本质上,“IR 变换 pass” 是将语句映射到新语句的函数。因此,下面定义了一个向量化函数,并逐步实现它。\n", "\n", "TVM 已经为用户提供了两个类来分析和变换 IR。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### IR Visitor\n", "\n", "可以使用 ``tvm.tir.stmt_functor.post_order_visit(stmt, func)`` 从 Halide IR 收集信息。``func`` 是回调函数。该函数将在退出当前 IR 节点之前调用,即后序访问(post-order visit)。然后利用副作用来存储 IR 访问的结果,因为 ``func`` 的返回值会被忽略。\n", "\n", "```{note}\n", ":class: alert alert-info\n", "\n", "你必须使用一些数组来存储 IR 访问的结果。甚至值也是 single 变量。这主要是由于 Python-C 运行时中的约束。每次递归都会刷新变量值,但保留数组值。\n", "```" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "loops = []\n", "\n", "def find_width8(op):\n", " \"\"\"找出所有范围能被 8 除的 'tir.For' 节点。\"\"\"\n", " if isinstance(op, tvm.tir.For):\n", " if isinstance(op.extent, tvm.tir.IntImm):\n", " if op.extent.value % 8 == 0:\n", " loops.append(op)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### IR 变换\n", "\n", "变换(transformation)接口与 visitor 接口略有不同。在 visitor 中只有 post-order 回调,但是 transformation visitor 同时支持 pre-order 和 post-order 回调。如果您想保留原始 IR 节点,只需返回 None。如果您想将当前节点更改为某个节点,请使用 TVM IR maker 接口来构建它并返回此值。\n", "\n", "```{note}\n", ":class: alert alert-info\n", "\n", "如果 pre-order 函数被调用并返回非 None 的值,则 post-order 函数将被跳过。\n", "```" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def vectorize8(op):\n", " \"\"\"Split can vectorize the loops found in `find_width8`.\"\"\"\n", " if op in loops:\n", " extent = op.extent.value\n", " name = op.loop_var.name\n", " lo, li = te.var(name + \".outer\"), te.var(name + \".inner\")\n", " body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})\n", " body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)\n", " body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)\n", " return body\n", " return None\n", "\n", "\n", "@tvm.tir.transform.prim_func_pass(opt_level=0)\n", "def vectorize(f, mod, ctx):\n", " global loops\n", " tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)\n", " if not loops:\n", " return f\n", " # 最后一个 list 参数表示要转换的节点类型。\n", " # 因此,在这种情况下,只有 `For` 节点会调用 `vectorize8`\n", " return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, [\"tir.For\"]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Glue 到 lower pass\n", "\n", "到目前为止,已经完成了这个 IR 变换过程。接下来需要做的是将这个 pass 粘合到 TVM 的 lower pass 上。\n", "\n", "在本例中,通过向 ``tir.add_lower_pass`` 提供元组参数列表,将上面编写的 pass 注入到 TVM 标准 lower pass 中。\"Tuple\" 表明 lower 的不同阶段。在 TVM 中,有四个 lower 阶段,每个阶段(phase)完成后将调用用户自定义的阶段。\n", "\n", "```{note}\n", ":class: alert alert-info\n", "以下是每个阶段所做的基本变换:\n", "\n", "- 阶段 0:生成 raw IR 和循环级别(loop levels)。\n", "- 阶段 1:对 array storage 进行扁平化(flatten)。\n", "- 阶段 2:变换循环(transforms loops):如 unroll、vectorization 和 thread-binding。\n", "- 阶段 3:做一些清理工作。\n", "```\n", "\n", "因此,将这个变换过程放置在阶段 1 之后是一个很好的地方。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(a: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m128\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], b: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m128\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], c: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m128\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(a, [\u001b[38;5;28m128\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00ma\u001b[38;5;129;01m.\u001b[39;00mdata)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(b, [\u001b[38;5;28m128\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mb\u001b[38;5;129;01m.\u001b[39;00mdata)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(c, [\u001b[38;5;28m128\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mc\u001b[38;5;129;01m.\u001b[39;00mdata)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_outer \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " cse_var_1: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m\n", " c[cse_var_1:cse_var_1 \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m8\u001b[39m] \u001b[38;5;129;01m=\u001b[39;00m a[cse_var_1:cse_var_1 \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m8\u001b[39m] \u001b[38;5;129;01m+\u001b[39;00m b[cse_var_1:cse_var_1 \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m8\u001b[39m]\n", " \n", "\n" ] } ], "source": [ "with tvm.transform.PassContext(config={\"tir.add_lower_pass\":\n", " [(1, vectorize)]}):\n", " tvm.lower(sch, [a, b, c]).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 快速视图\n", "\n", "本教程提供了编写自定义 IR 变换 pass 的快速视图:\n", "\n", "- 使用 ``tvm.tir.stmt_functor.post_order_visit`` 收集每个 IR 节点的信息。\n", "- 使用 ``tvm.tir.stmt_functor.ir_transform`` 变换 IR 节点。\n", "- 包装上面的两个,写出 IR-transformation 函数。\n", "- 使用 ``tvm.transform.PassContext`` 将该函数放入 TVM lowering pass" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('py38': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "vscode": { "interpreter": { "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7" } } }, "nbformat": 4, "nbformat_minor": 0 }