MLC 作业 1: 端到端模型执行#

第一部分: 模型准备#

本作业的目标是让你对机器学习编译过程中的端到端模型的执行和变换更加熟悉。从简单的图像分类模型开始。

首先使用如下的命令来安装必要的库。

!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels
!python3 -m pip install torch torchvision torchaudio torchsummary --extra-index-url https://download.pytorch.org/whl/cpu
Looking in links: https://mlc.ai/wheels
Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1664+g1f3985de0)
Requirement already satisfied: psutil in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)
Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.8.1)
Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)
Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.1)
Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)
Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (21.4.0)
Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.23.1)
Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu
Collecting torch
  Downloading https://download.pytorch.org/whl/cpu/torch-1.12.0%2Bcpu-cp310-cp310-linux_x86_64.whl (189.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 189.0/189.0 MB 4.3 MB/s eta 0:00:0000:0100:01
?25hCollecting torchvision
  Downloading https://download.pytorch.org/whl/cpu/torchvision-0.13.0%2Bcpu-cp310-cp310-linux_x86_64.whl (13.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 8.8 MB/s eta 0:00:0000:0100:01m
?25hCollecting torchaudio
  Downloading https://download.pytorch.org/whl/cpu/torchaudio-0.12.0%2Bcpu-cp310-cp310-linux_x86_64.whl (3.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.5/3.5 MB 15.3 MB/s eta 0:00:0000:0100:01
?25hCollecting torchsummary
  Using cached torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Requirement already satisfied: typing-extensions in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from torch) (4.3.0)
Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from torchvision) (1.23.1)
Requirement already satisfied: requests in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from torchvision) (2.28.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from torchvision) (9.2.0)
Requirement already satisfied: idna<4,>=2.5 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from requests->torchvision) (3.3)
Requirement already satisfied: certifi>=2017.4.17 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from requests->torchvision) (2022.6.15)
Requirement already satisfied: charset-normalizer<3,>=2 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from requests->torchvision) (2.1.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from requests->torchvision) (1.26.10)
Installing collected packages: torchsummary, torch, torchvision, torchaudio
Successfully installed torch-1.12.0+cpu torchaudio-0.12.0+cpu torchsummary-1.5.1 torchvision-0.13.0+cpu
import numpy as np
import pickle as pkl
import torch
import torch.nn.functional as F
import torchvision
import tvm
import tvm.testing

from matplotlib import pyplot as plt
from torch import nn
from torchvision import transforms
from tvm import topi, relax, te
from tvm.script import tir as T

以下是用 PyTorch 定义的模型。该模型接受一批图像为输入,然后对它们依次作用卷积层,激活层,池化层和全连接层,得到分类结果。

# Mini-batch size shared by the PyTorch baseline and the Relax modules below.
batch_size = 4
input_shape = (batch_size, 1, 28, 28)  # NCHW layout: (batch, channel, height, width)


def pytorch_model(weight_map):
    """Build the reference CNN classifier and load pretrained weights.

    The architecture is conv(1->32, 3x3) -> ReLU -> 2x2 max-pool ->
    flatten -> linear(5408->100) -> ReLU -> linear(100->10) -> softmax.

    Parameters
    ----------
    weight_map : dict
        Maps names like "conv2d_weight" to NumPy arrays with the matching
        parameter shapes.

    Returns
    -------
    nn.Sequential
        The model on CPU with every parameter overwritten from weight_map.
    """
    layers = [
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), bias=True),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=(2, 2)),
        nn.Flatten(),
        nn.Linear(in_features=5408, out_features=100, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=100, out_features=10, bias=True),
        nn.Softmax(dim=1),
    ]
    model = nn.Sequential(*layers).cpu()
    # Translate nn.Sequential parameter names (index.attr) to weight_map keys.
    param_key = {
        "0.weight": "conv2d_weight",
        "0.bias": "conv2d_bias",
        "4.weight": "linear0_weight",
        "4.bias": "linear0_bias",
        "6.weight": "linear1_weight",
        "6.bias": "linear1_bias",
    }
    for torch_name, torch_param in model.named_parameters():
        source = weight_map[param_key[torch_name]]
        torch_param.data = torch.from_numpy(source).cpu()
    return model

在 Fashion MNIST 数据集上的预训练权重图。

# Hide outputs
!wget -nc https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_assignment_params.pkl
File ‘fasionmnist_mlp_assignment_params.pkl’ already there; not retrieving.

可以看到它的准确率约为84%。

# Load the pretrained weight map from disk.
# NOTE(review): pickle.load on a downloaded file is only safe because the
# source (mlc-ai/web-data) is trusted.
# The weights reach roughly 83.3% accuracy on the test data.
weight_map = pkl.load(open("fasionmnist_mlp_assignment_params.pkl", "rb"))
# Human-readable Fashion MNIST labels, indexed by class id 0-9.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def test(model, test_loader):
    """Evaluate `model` over `test_loader`, displaying the first sample's
    prediction and printing the average NLL loss and overall accuracy."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        shown_sample = False
        for batch, targets in test_loader:
            batch, targets = batch.cpu(), targets.cpu()
            probs = model(batch)
            # Accumulate the *summed* batch loss; averaged at the end.
            test_loss += F.nll_loss(probs, targets, reduction="sum").item()
            # Predicted class = index of the maximum output per row.
            predictions = probs.argmax(dim=1, keepdim=True)
            if not shown_sample:
                imshow(batch[0])
                print("predict: {}, label: {}".format(class_names[predictions[0][0]], class_names[targets[0]]))
                shown_sample = True
            correct += predictions.eq(targets.view_as(predictions)).sum().item()

    test_loss /= len(test_loader.dataset)

    print("\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def imshow(img):
    """Undo the /2 + 0.5 normalization and display a CHW image tensor."""
    unnormalized = img / 2 + 0.5
    # matplotlib expects HWC, so move the channel axis last.
    pixels = unnormalized.numpy().transpose(1, 2, 0)
    plt.imshow(pixels)
    plt.show()


# Download the Fashion MNIST test split and evaluate the pretrained
# PyTorch model on it.
test_data = torchvision.datasets.FashionMNIST(
    "./data",
    download=True,
    train=False,
    transform=transforms.Compose([transforms.ToTensor()])
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False)
test(pytorch_model(weight_map), test_loader)
../_images/4363cf464ac87db604eb77d05e1d3e371f9618f203ed6a32572e9cd8e41ca558.png
predict: Ankle boot, label: Ankle boot

Test set: Average loss: -0.8369, Accuracy: 8388/10000 (84%)

第二部分: 从PyTorch迁移模型#

为了展示机器学习编译对端到端模型的抽象,需要将模型从 PyTorch 迁移并转换为 TVMScript 实现。然而,手工迁移很难。正如你在 TensorIR 练习中所体验的那样,为模型中的每一层写一个元张量函数需要大量的人力来完成。另外,手工写这些函数是容易犯错的。你可以想象,当你写了几百行,但其中有零星几个 bug,那么找到 bug 的过程将会是痛苦的。

幸运的是,在 TVM 中有简单的多的方法能够迁移模型。TVM 提供了 relax.BlockBuilder 类,它能够从空白的 IRModule 开始一步步的构建端到端模型。(回忆在第四节课中介绍的 Relax 的 Dataflow Block,这里的 “block” 就是代表了 Relax 函数中的 Dataflow Block)

具体而言,在 BlockBuilder 中有 emit_te 的 API,它可以将张量表达式(第三节课中介绍过)的算子描述转变成对应 TensorIR 函数的 call_tir 运算(call_tir 在第四节课中介绍过)。与手工写 TensorIR 函数相比,写张量表达式描述可以用几行代码来完成,这减少了需要的工作量和犯错的概率。

emit_te 的函数签名是 emit_te(func, *input),其中 func 是返回张量表达式的函数,而 *input 是 func 的输入。

从一个例子开始详细介绍。在下方的代码块中,relu 是返回 ReLU 算子的张量表达式描述的函数。为了构建执行单个 ReLU 算子的 Relax 函数,在 emit_te_example 中首先定义了 BlockBuilder 实例 bb。也定义了 2 维 128x128 大小的张量变量 x,它将作为 ReLU 运算的输入张量(同时也是 Relax 函数的输入)。

在这之后,用 with bb.function(name, [*input]) API构建以 x 为输入的 Relax 函数 main。然后构建 dataflow block。在这个 dataflow block 里,首先用 emit_te 生成调用 ReLU 算子的 call_tir。这里 emit_te 在 IRModule 中生成了名为 relu 的 TensorIR 函数,然后在 dataflow block 中生成 call_tir(relu, (x,), (128, 128), dtype="float32") 运算。call_tir 之后是函数返回。

在这一构造之后,BlockBuilder 实例 bb 包含构建完的 IRModule,它可以通过 bb.get() 得到。

def relu(A):
    """Tensor-expression description of element-wise ReLU on a 128x128 input."""
    return te.compute((128, 128), lambda i, j: te.max(A[i, j], 0), name="B")


def emit_te_example():
    """Construct an IRModule whose `main` applies one ReLU via emit_te.

    emit_te lowers the `relu` tensor expression to a TensorIR function and
    inserts the corresponding call_tir into the dataflow block.
    """
    builder = relax.BlockBuilder()
    input_var = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32"))
    with builder.function("main", [input_var]):
        with builder.dataflow():
            relu_result = builder.emit_te(relu, input_var)
            output = builder.emit_output(relu_result)
        builder.emit_func_output(output)
    return builder.get()

函数 emit_te_example 返回构造得到的 IRModule。为了看的更清楚,可以输出这一 IRModule。

import IPython

mod = emit_te_example()
IPython.display.Code(mod.script(), language="python")

正如你看到的,通过 BlockBuilder 生成的 IRModule 确实包含了 ReLU 的 TensorIR 实现和含有调用 ReLU 实现的 call_tir 的 Relax 函数。

现在轮到你来用 BlockBuilder 和 emit_te 来创建和之前定义的 PyTorch 模型等价的 IRModule。你可以自己为所有的算子写张量表达式描述。或者,TVM 提供了 TOPI(TVM Operator Inventory)库,它为不同的算子提供了张量表达式描述。如果你愿意阅读 文档 来弄懂它的用法,这也是被鼓励的。我们提供了测试函数来检查你的 IRModule 的正确性。

注意到每个 Conv2d 层和 linear 层都包含了偏置加法,这应该在你构建的 IRModule 中被体现。

def create_model_via_emit_te():
    """Skeleton for Part 2: build, with BlockBuilder.emit_te, an IRModule
    equivalent to the PyTorch model above (conv2d + bias -> ReLU ->
    max-pool -> flatten -> linear0 + bias -> ReLU -> linear1 + bias ->
    softmax).

    NOTE(review): `gv` is only defined once the TODO body is filled in;
    calling this function before that raises NameError.
    """
    bb = relax.BlockBuilder()
    # BUG FIX: DynTensorType's first argument is the tensor *rank* (4 for
    # NCHW), not the batch size. The original passed `batch_size`, which
    # only worked because it happens to equal 4 (compare the correct call
    # in create_model_with_torch_func).
    x = relax.Var("x", input_shape, relax.DynTensorType(4, "float32"))

    # Pretrained parameters baked in as constants; the biases are reshaped
    # so a plain broadcasting add can apply them after conv2d / dense.
    conv2d_weight = relax.const(weight_map["conv2d_weight"], "float32")
    conv2d_bias = relax.const(weight_map["conv2d_bias"].reshape(1, 32, 1, 1), "float32")
    linear0_weight = relax.const(weight_map["linear0_weight"], "float32")
    linear0_bias = relax.const(weight_map["linear0_bias"].reshape(1, 100), "float32")
    linear1_weight = relax.const(weight_map["linear1_weight"], "float32")
    linear1_bias = relax.const(weight_map["linear1_bias"].reshape(1, 10), "float32")

    with bb.function("main", [x]):
        with bb.dataflow():
            # TODO: emit the layers with bb.emit_te (own tensor expressions
            # or topi ops) and bind the final output to `gv` via
            # bb.emit_output(...).
            ...
        bb.emit_func_output(gv)

    return bb.get()


def build_mod(mod):
    """Compile a Relax IRModule for CPU (llvm target) and wrap it in a VM.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to compile; must contain a "main" Relax function.

    Returns
    -------
    relax.VirtualMachine
        A virtual machine ready to execute the compiled module on CPU.
    """
    # Renamed from `exec`, which shadowed the Python builtin exec().
    executable = relax.vm.build(mod, "llvm")
    dev = tvm.cpu()
    vm = relax.VirtualMachine(executable, dev)
    return vm


def check_equivalence(mod, torch_model, test_loader):
    """Assert that the Relax IRModule and the PyTorch model produce the
    same outputs (rtol=1e-4) on every batch of `test_loader`."""
    torch_model.eval()
    with torch.no_grad():
        vm = build_mod(mod)
        for batch, _ in test_loader:
            batch = batch.cpu()
            expected = torch_model(batch).numpy()
            actual = vm["main"](tvm.nd.array(batch, tvm.cpu())).numpy()
            tvm.testing.assert_allclose(expected, actual, rtol=1e-4)


# Re-create the test loader and verify the hand-built IRModule against the
# PyTorch reference model on every batch.
test_data = torchvision.datasets.FashionMNIST(
    "./data",
    download=True,
    train=False,
    transform=transforms.Compose([transforms.ToTensor()])
)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

mod = create_model_via_emit_te()
# BUG FIX: pytorch_model() requires the weight map argument; calling it with
# no arguments (as the original did) raises TypeError.
torch_model = pytorch_model(weight_map)

check_equivalence(mod, torch_model, test_loader)
IPython.display.Code(mod.script(), language="python")

第三部分: 使用库#

正如我们在第四节课中谈到的,我们可以将torch函数整合进IRModule。步骤包括注册一个外部运行时函数,和在 IRModule 中用 call_tir 调用。

这里是用 torch matmul 和 torch add 来实现 linear 层的例子。你也可以在第四节课的笔记中找到这个例子。

@tvm.register_func("env.linear", override=True)
def torch_linear(x: tvm.nd.NDArray,
                 w: tvm.nd.NDArray,
                 b: tvm.nd.NDArray,
                 out: tvm.nd.NDArray):
    """Runtime linear layer backed by PyTorch: out = x @ w.T + b.

    All four NDArrays are wrapped zero-copy through DLPack, so the two
    in-place torch ops write directly into TVM's `out` buffer.
    """
    xt, wt, bt, dest = (torch.from_dlpack(t) for t in (x, w, b, out))
    torch.mm(xt, wt.T, out=dest)
    torch.add(dest, bt, out=dest)


@tvm.script.ir_module
class MyModuleWithExternCall:
    # Illustrative module (the `...` are placeholders, not runnable code):
    # the linear layer is dispatched to the runtime function "env.linear"
    # registered above, instead of a TensorIR implementation.
    @R.function
    def main(x: Tensor((1, 784), "float32"),
             w0: Tensor((128, 784), "float32"),
             b0: Tensor((128,), "float32")):
        # block 0
        with R.dataflow():
            # call_tir hands (x, w0, b0) plus a freshly allocated (1, 128)
            # float32 output buffer to the external "env.linear" function.
            lv0 = R.call_tir("env.linear", (x, w0, b0), (1, 128), dtype="float32")
            ...
        return ...

请为你在第二部分中创建的 IRModule 中的卷积层注册外部函数。你需要使用 NumPy 或者 PyTorch 作为你的函数实现。

你可能需要使用 BlockBuilder.emit 在正在构建的 Relax 函数的结尾直接添加 call_tir 运算。

def create_model_with_torch_func():
    """Skeleton for Part 3: build the same model as Part 2, but with the
    conv2d layer dispatched to a registered external (NumPy/PyTorch)
    function via call_tir.

    NOTE(review): `gv` is only defined once the TODO body is filled in;
    calling this function before that raises NameError.
    """
    bb = relax.BlockBuilder()

    # Input: rank-4 (NCHW) float32 tensor of shape `input_shape`.
    x = relax.Var("x", input_shape, relax.DynTensorType(4, "float32"))

    # Pretrained parameters as Relax constants; biases reshaped so a plain
    # broadcasting add can apply them.
    conv2d_weight = relax.const(weight_map["conv2d_weight"], "float32")
    conv2d_bias = relax.const(weight_map["conv2d_bias"].reshape(1, 32, 1, 1), "float32")
    linear0_weight = relax.const(weight_map["linear0_weight"], "float32")
    linear0_bias = relax.const(weight_map["linear0_bias"].reshape(1, 100), "float32")
    linear1_weight = relax.const(weight_map["linear1_weight"], "float32")
    linear1_bias = relax.const(weight_map["linear1_bias"].reshape(1, 10), "float32")

    with bb.function("main", [x]):
        with bb.dataflow():
            # TODO: emit the layers; use BlockBuilder.emit to insert the
            # call_tir to the external conv2d function, emit_te for the
            # rest, and bind the final output to `gv`.
            ...
        bb.emit_func_output(gv)

    return bb.get()


# Rebuild the loader and check the torch-backed module against the reference.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
mod = create_model_with_torch_func()
check_equivalence(mod, torch_model, test_loader)

第四部分: 端到端模型中的程序变换#

在 TensorIR 练习中,学会了如何变换单个 TensorIR 函数。在端到端模型中变换是类似的。

和批量矩阵乘法相比,让我们关注更加有挑战性的算子:conv2d(二维卷积)。

首先,介绍一些新的原语:

  • compute_inline:它将 block 内联到另一个 block 中,以减少内存使用大小和内存访问次数

  • fuse:和 split 相对。融合多个轴。这里 fuse 和 parallel / vectorize / unroll 一起使用,以增加并行度。

@T.prim_func
def before_inline(a: T.handle, c: T.handle) -> None:
    # Two-stage elementwise pipeline: B = A * 2, then C = B + 1.
    # compute_inline on block "B" folds the first stage into the second,
    # eliminating the intermediate 128x128 buffer B.
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0


# Inline block "B" into its consumer and display the transformed function.
sch = tvm.tir.Schedule(before_inline)
sch.compute_inline(sch.get_block("B"))
IPython.display.Code(sch.mod["main"].script(), language="python")
@T.prim_func
def before_fuse(a: T.handle, b: T.handle) -> None:
    # Single elementwise block over two 128-iteration loops; `fuse` below
    # merges i and j into one 16384-iteration loop.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0


# Fuse the two loop axes of block "B" into one and display the result.
sch = tvm.tir.Schedule(before_fuse)
i, j = sch.get_loops(sch.get_block("B"))
sch.fuse(i, j)
IPython.display.Code(sch.mod["main"].script(), language="python")

现在我们首先为第二部分中得到的IRModule创建一个schedule,然后对其中的conv2d TensorIR函数变换。和TensorIR练习类似,我们提供了一个目标函数。但请注意,目标函数不是标准答案,原因如下:

  • 它可能不能在所有硬件中都取得最佳性能

  • 原始的conv2d TensorIR实现可能不同,这取决于你在第二部分中使用的张量表达式描述:

    • 如果你将conv2d的计算和偏置加法的计算放在了一个张量表达式中,那么在变换完成的TensorIR函数的末尾应该有一个计算偏置加法的block

    • 如果你将上述两个计算分开在不同的张量表达式,或者你使用了TOPI提供的conv2d,那么变换完成的TensorIR函数末尾不应该有计算偏置加法的block。下面给出的目标函数是用TOPI conv2d获得的TensorIR函数做变换后生成的。

@T.prim_func
def target_func(rxplaceholder: T.Buffer[(4, 1, 28, 28), "float32"], rxplaceholder_1: T.Buffer[(32, 1, 3, 3), "float32"], conv2d_nchw: T.Buffer[(4, 32, 26, 26), "float32"]) -> None:
    T.func_attr({"global_symbol": "conv2d", "tir.noalias": True})
    # Reference (not unique) schedule for the 4x32x26x26 conv2d: the four
    # spatial axes are split, the outer parts fused into one parallel loop
    # of 2704 iterations, the inner parts fused and then unrolled (8) /
    # vectorized (4); the reduction is decomposed into an init block and
    # an update block.
    # body
    # with T.block("root")
    for i0_0_i1_0_i2_0_i3_0_fused in T.parallel(2704):
        for i0_1_i1_1_fused_init in T.unroll(8):
            for i2_1_i3_1_fused_init in T.vectorized(4):
                with T.block("conv2d_nchw_init"):
                    nn = T.axis.spatial(
                        4, i0_0_i1_0_i2_0_i3_0_fused // 1352 * 2 + i0_1_i1_1_fused_init // 4)
                    ff = T.axis.spatial(
                        32, i0_0_i1_0_i2_0_i3_0_fused % 1352 // 169 * 4 + i0_1_i1_1_fused_init % 4)
                    yy = T.axis.spatial(
                        26, i0_0_i1_0_i2_0_i3_0_fused % 169 // 13 * 2 + i2_1_i3_1_fused_init // 2)
                    xx = T.axis.spatial(
                        26, i0_0_i1_0_i2_0_i3_0_fused % 13 * 2 + i2_1_i3_1_fused_init % 2)
                    T.reads()
                    T.writes(conv2d_nchw[nn, ff, yy, xx])
                    conv2d_nchw[nn, ff, yy, xx] = T.float32(0)
        for i4, i5, i6 in T.grid(1, 3, 3):
            for i0_1_i1_1_fused in T.unroll(8):
                for i2_1_i3_1_fused in T.vectorized(4):
                    with T.block("conv2d_nchw_update"):
                        nn = T.axis.spatial(
                            4, i0_0_i1_0_i2_0_i3_0_fused // 1352 * 2 + i0_1_i1_1_fused // 4)
                        ff = T.axis.spatial(
                            32, i0_0_i1_0_i2_0_i3_0_fused % 1352 // 169 * 4 + i0_1_i1_1_fused % 4)
                        yy = T.axis.spatial(
                            26, i0_0_i1_0_i2_0_i3_0_fused % 169 // 13 * 2 + i2_1_i3_1_fused // 2)
                        xx = T.axis.spatial(
                            26, i0_0_i1_0_i2_0_i3_0_fused % 13 * 2 + i2_1_i3_1_fused % 2)
                        rc, ry, rx = T.axis.remap("RRR", [i4, i5, i6])
                        T.reads(conv2d_nchw[nn, ff, yy, xx], rxplaceholder[nn,
                                rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx])
                        T.writes(conv2d_nchw[nn, ff, yy, xx])
                        conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + \
                            rxplaceholder[nn, rc, yy + ry, xx +
                                          rx] * rxplaceholder_1[ff, rc, ry, rx]

和TensorIR练习中不同的是, 这里schedule是为一个IRModule创建的,而不是TensorIR函数. 因此,当使用sch.get_block时,需要提供TensorIR函数名字,如下方所示。

# Create a schedule over the *whole* IRModule (not a single TensorIR
# function), so get_block needs the TensorIR function name as well.
mod = create_model_via_emit_te()
sch = tvm.tir.Schedule(mod)

# Step 1. Get blocks
# block = sch.get_block(name="your_block_name", func_name="your_function_name")

# Step 2. Inline the padding block (if exists)

# Step 3. Get loops

# Step 4. Organize the loops

# Step 5. decompose reduction

# Step 6. fuse + vectorize / fuse + parallel / fuse + unroll

IPython.display.Code(sch.mod.script(), language="python")

同样,我们可以测试变换后IRModule的正确性。

# The transformed module must still match the PyTorch reference exactly.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
check_equivalence(sch.mod, torch_model, test_loader)