编写定制 Pass#

原作者: Jian Weng

TVM 是抽象出机器学习加速器异质性(heterogenity)的框架。有时用户可能希望定制一些分析和 IR 变换,使 TVM 适应他们自己的专用硬件。本教程帮助用户在 TVM 中编写定制的pass。

前提条件#

在阅读本教程开始之前,假设读者已经很好地了解了以下主题:

  • 在 TVM 中编写算法并对其进行调度。否则,请参阅示例教程,如 How to optimize GEMM on CPU

  • HalideIR 的基本结构。否则,请参见 HalideIR/src/ir/IR.h 来了解 IR 节点定义了哪些属性。

  • 访问者设计模式(Visitor design pattern)。否则,请查看 Python AST 模块,查看 AST visitor 是如何实现的。

  • 如何将 Schedule 降格(lower)为 IRModule class 或 LLVM module。否则,请查看 python/tvm/build_module.py 以获得一些基础知识。

import tvm
from tvm import te

编写非常简单的向量加法,并使用默认的调度来构建它。然后,使用定制的 lower pass 直接操作 IR,而不是使用调度原语(primitives.)。

n = tvm.tir.const(128, "int32")
a = te.placeholder((n,), name="a")
b = te.placeholder((n,), name="b")
c = te.compute((n,), lambda i: a[i] + b[i], name="c")

sch = te.create_schedule(c.op)
ir = tvm.lower(sch, [a, b, c])
ir.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(a: T.Buffer[128, "float32"], b: T.Buffer[128, "float32"], c: T.Buffer[128, "float32"]) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        T.preflattened_buffer(a, [128], dtype="float32", data=a.data)
        T.preflattened_buffer(b, [128], dtype="float32", data=b.data)
        T.preflattened_buffer(c, [128], dtype="float32", data=c.data)
        # body
        for i in T.serial(128):
            c[i] = a[i] + b[i]
    

编写 Pass#

本质上,“IR 变换 pass” 是将语句映射到新语句的函数。因此,下面定义了一个向量化函数,并逐步实现它。

TVM 已经为用户提供了两个类来分析和变换 IR。

IR Visitor#

可以使用 tvm.tir.stmt_functor.post_order_visit(stmt, func) 从 Halide IR 收集信息。func 是回调函数。该函数将在退出当前 IR 节点之前调用,即后序访问(post-order visit)。然后利用副作用来存储 IR 访问的结果,因为 func 的返回值会被忽略。

备注

你必须使用一些数组来存储 IR 访问的结果。甚至值也是 single 变量。这主要是由于 Python-C 运行时中的约束。每次递归都会刷新变量值,但保留数组值。

loops = []

def find_width8(op):
    """找出所有范围能被 8 除的 'tir.For' 节点。"""
    if isinstance(op, tvm.tir.For):
        if isinstance(op.extent, tvm.tir.IntImm):
            if op.extent.value % 8 == 0:
                loops.append(op)

IR 变换#

变换(transformation)接口与 visitor 接口略有不同。在 visitor 中只有 post-order 回调,但是 transformation visitor 同时支持 pre-order 和 post-order 回调。如果您想保留原始 IR 节点,只需返回 None。如果您想将当前节点更改为某个节点,请使用 TVM IR maker 接口来构建它并返回此值。

备注

如果 pre-order 函数被调用并返回非 None 的值,则 post-order 函数将被跳过。

def vectorize8(op):
    """Split can vectorize the loops found in `find_width8`."""
    if op in loops:
        extent = op.extent.value
        name = op.loop_var.name
        lo, li = te.var(name + ".outer"), te.var(name + ".inner")
        body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
        body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)
        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)
        return body
    return None


@tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
    global loops
    tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)
    if not loops:
        return f
    # 最后一个 list 参数表示要转换的节点类型。
    # 因此,在这种情况下,只有 `For` 节点会调用 `vectorize8`
    return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ["tir.For"]))

Glue 到 lower pass#

到目前为止,已经完成了这个 IR 变换过程。接下来需要做的是将这个 pass 粘合到 TVM 的 lower pass 上。

在本例中,通过向 tir.add_lower_pass 提供元组参数列表,将上面编写的 pass 注入到 TVM 标准 lower pass 中。“Tuple” 表明 lower 的不同阶段。在 TVM 中,有四个 lower 阶段,每个阶段(phase)完成后将调用用户自定义的阶段。

备注

以下是每个阶段所做的基本变换:

  • 阶段 0:生成 raw IR 和循环级别(loop levels)。

  • 阶段 1:对 array storage 进行扁平化(flatten)。

  • 阶段 2:变换循环(transforms loops):如 unroll、vectorization 和 thread-binding。

  • 阶段 3:做一些清理工作。

因此,将这个变换过程放置在阶段 1 之后是一个很好的地方。

with tvm.transform.PassContext(config={"tir.add_lower_pass":
                                       [(1, vectorize)]}):
    tvm.lower(sch, [a, b, c]).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(a: T.Buffer[128, "float32"], b: T.Buffer[128, "float32"], c: T.Buffer[128, "float32"]) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        T.preflattened_buffer(a, [128], dtype="float32", data=a.data)
        T.preflattened_buffer(b, [128], dtype="float32", data=b.data)
        T.preflattened_buffer(c, [128], dtype="float32", data=c.data)
        # body
        for i_outer in T.serial(16):
            cse_var_1: T.int32 = i_outer * 8
            c[cse_var_1:cse_var_1 + 8] = a[cse_var_1:cse_var_1 + 8] + b[cse_var_1:cse_var_1 + 8]
    

快速视图#

本教程提供了编写自定义 IR 变换 pass 的快速视图:

  • 使用 tvm.tir.stmt_functor.post_order_visit 收集每个 IR 节点的信息。

  • 使用 tvm.tir.stmt_functor.ir_transform 变换 IR 节点。

  • 包装上面的两个,写出 IR-transformation 函数。

  • 使用 tvm.transform.PassContext 将该函数放入 TVM lowering pass