Ep9: Computational Graph Optimization
Install packages
For this course, we will use some ongoing development in TVM, an open-source machine learning compilation framework. The following command installs a packaged version prepared for the MLC course.
!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Requirement already satisfied: mlc-ai-nightly in /usr/local/lib/python3.7/dist-packages (0.9.dev2226+gf68b7661e)
Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (22.1.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (5.4.8)
Requirement already satisfied: synr==0.6.0 in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (0.6.0)
Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (4.4.2)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (1.7.3)
Requirement already satisfied: tornado in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (5.1.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (1.21.6)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from mlc-ai-nightly) (1.5.0)
Prelude
Most of the MLC process can be viewed as transformations among tensor functions. In the past chapters, we studied how to transform each primitive tensor function individually. In this chapter, let us talk about high-level transformations among computational graphs.
Preparations
To begin with, let us import the necessary dependencies.
# This is needed for deferring annotation parsing in TVMScript
from __future__ import annotations
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T, relax as R
from tvm import relax, topi
import numpy as np
Pattern Match and Rewriting
Let us start with the following example.
@tvm.script.ir_module
class MyModule:
    @R.function
    def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
        with relax.dataflow():
            lv0 = relax.multiply(x, y)
            gv0 = relax.add(lv0, y)
            relax.output(gv0)
        return gv0
MyModule contains a relax function with two high-level operators, relax.multiply and relax.add. Our goal is to find these two operators and replace them with a single call into the relax.ewise_fma operator.
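As a quick check that this rewrite preserves meaning, the NumPy sketch below assumes that relax.ewise_fma(a, b, c) computes the elementwise a * b + c, and compares the original multiply-then-add with the fused form on random (3, 4) inputs. NumPy stands in for the relax operators here; the helper names are hypothetical and this is not how TVM executes them.

```python
import numpy as np

def ewise_fma(a, b, c):
    # Assumed semantics of relax.ewise_fma: elementwise a * b + c
    return a * b + c

def original_main(x, y):
    # Mirrors MyModule.main: lv0 = multiply(x, y); gv0 = add(lv0, y)
    lv0 = np.multiply(x, y)
    return np.add(lv0, y)

def rewritten_main(x, y):
    # After the rewrite: gv0 = ewise_fma(x, y, y)
    return ewise_fma(x, y, y)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4)).astype("float32")
y = rng.standard_normal((3, 4)).astype("float32")
print(np.allclose(original_main(x, y), rewritten_main(x, y)))
```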
Before we dive into how exactly to do that, let us first examine the data structures that make up MyModule. Each IRModule contains a collection of functions, and each function body is composed of a set of data structures called abstract syntax trees (AST).
relax_func = MyModule["main"]
Each function is represented by a relax.Function node.
type(relax_func)
tvm.relax.expr.Function
The function contains a list of parameters.
relax_func.params
[relax.expr.Var(0xa8f2d80), relax.expr.Var(0xa750480)]
The function also contains a body field, which holds the function's binding blocks together with its return value.
func_body = relax_func.body
type(func_body)
tvm.relax.expr.SeqExpr
The function body, a SeqExpr, contains a sequence of (binding) blocks.
func_body.blocks
[relax.expr.DataflowBlock(0x3444520)]
dataflow_block = func_body.blocks[0]
In our particular case, we have a single dataflow block that contains two bindings. Each binding corresponds to one of the following two lines:
lv0 = relax.multiply(x, y)
gv0 = relax.add(lv0, y)
dataflow_block.bindings
binding = dataflow_block.bindings[0]
Each binding has a var field that corresponds to the left-hand side of the binding (lv0, gv0).
binding.var
relax.expr.DataflowVar(0x6c36300)
Its value field corresponds to the right-hand side of the binding. Here each value is a relax.Call node representing a call into a primitive operator.
binding.value
CallNode(Op(relax.multiply), [relax.expr.Var(0x3463080), relax.expr.Var(0x3463300)], (nullptr), [])
The above figure summarizes the data structure involved in this particular function.
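The nesting described above (Function, then SeqExpr, then blocks, then bindings, each with a var and a Call value) can be mimicked with a few plain Python dataclasses. This is a hypothetical, simplified model for illustration only; the real relax node classes carry much more information (types, shapes, spans).

```python
from dataclasses import dataclass
from typing import List

# Hypothetical, simplified stand-ins for the relax node types described above.
@dataclass
class Var:
    name: str

@dataclass
class Call:
    op: str          # e.g. "relax.multiply"
    args: List[Var]

@dataclass
class Binding:
    var: Var         # left-hand side, e.g. lv0
    value: Call      # right-hand side, e.g. relax.multiply(x, y)

@dataclass
class DataflowBlock:
    bindings: List[Binding]

@dataclass
class SeqExpr:
    blocks: List[DataflowBlock]
    body: Var        # the value the function returns

@dataclass
class Function:
    params: List[Var]
    body: SeqExpr

x, y, lv0, gv0 = Var("x"), Var("y"), Var("lv0"), Var("gv0")
main = Function(
    params=[x, y],
    body=SeqExpr(
        blocks=[DataflowBlock([
            Binding(lv0, Call("relax.multiply", [x, y])),
            Binding(gv0, Call("relax.add", [lv0, y])),
        ])],
        body=gv0,
    ),
)

# Walk Function -> SeqExpr -> blocks -> bindings, as in the cells above.
for block in main.body.blocks:
    for b in block.bindings:
        args = ", ".join(a.name for a in b.value.args)
        print(f"{b.var.name} = {b.value.op}({args})")
```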
One approach to rewriting the program is to traverse MyModule's AST recursively and generate a transformed AST. We can certainly do that with the available Python API. However, extra tooling support can simplify the process. The following code block follows a design pattern called the visitor pattern, which allows us to visit each AST node and rewrite it into a transformed version.
@relax.expr_functor.mutator
class EwiseFMARewriter(relax.PyExprMutator):
    def visit_call_(self, call):
        # Rewrite the children first (post-order), then look at this call
        call = self.visit_expr_post_order(call)
        add_op = tvm.ir.Op.get("relax.add")
        multiply_op = tvm.ir.Op.get("relax.multiply")
        ewise_fma_op = tvm.ir.Op.get("relax.ewise_fma")

        # Only calls of the form add(lv, c) are candidates ...
        if call.op != add_op:
            return call

        # ... and only when lv is bound to multiply(a, b)
        value = self.lookup_binding(call.args[0])
        if not isinstance(value, relax.Call) or value.op != multiply_op:
            return call

        # Replace add(multiply(a, b), c) with ewise_fma(a, b, c)
        fma_call = relax.Call(
            ewise_fma_op, [value.args[0], value.args[1], call.args[1]], None, None
        )
        return fma_call
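To see the mechanics of this rewriter without TVM, the sketch below reimplements the same idea over a hypothetical toy IR: bindings are visited in order, each rewritten value is remembered so that a later add can look up what its first argument was bound to (the role lookup_binding plays in PyExprMutator), and add(multiply(a, b), c) becomes ewise_fma(a, b, c). The class names are illustrative and are not part of the relax API.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Hypothetical toy IR: a Call whose args are variable names.
@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[str, ...]

class ToyMutator:
    """A minimal mutator over a list of bindings (a toy, not relax.PyExprMutator)."""

    def __init__(self):
        self.bindings: Dict[str, Call] = {}

    def lookup_binding(self, name: str):
        # The value a variable was bound to, or None if it is a parameter
        return self.bindings.get(name)

    def visit_call(self, call: Call) -> Call:
        return call

    def visit_block(self, block: List[Tuple[str, Call]]) -> List[Tuple[str, Call]]:
        out = []
        for var, call in block:
            new_call = self.visit_call(call)
            self.bindings[var] = new_call  # remember the (possibly rewritten) value
            out.append((var, new_call))
        return out

class ToyEwiseFMARewriter(ToyMutator):
    def visit_call(self, call: Call) -> Call:
        if call.op != "relax.add":
            return call
        value = self.lookup_binding(call.args[0])
        if value is None or value.op != "relax.multiply":
            return call
        # add(multiply(a, b), c) -> ewise_fma(a, b, c)
        return Call("relax.ewise_fma", (value.args[0], value.args[1], call.args[1]))

block = [
    ("lv0", Call("relax.multiply", ("x", "y"))),
    ("gv0", Call("relax.add", ("lv0", "y"))),
]
for var, call in ToyEwiseFMARewriter().visit_block(block):
    print(f"{var} = {call.op}({', '.join(call.args)})")
```

As with the real rewriter, lv0 survives the rewrite even though nothing uses it any more.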
updated_fn = EwiseFMARewriter().visit_expr(MyModule["main"])
updated_fn.show()
@R.function
def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")) -> Tensor(None, "float32", ndim = 2):
    # block 0
    with R.dataflow():
        lv0: Tensor((3, 4), "float32") = relax.multiply(x, y)
        gv0: Tensor((3, 4), "float32") = relax.ewise_fma(x, y, y)
        R.output(gv0)
    return gv0
Running the code shows that the rewriter binds gv0 to the fused operator but leaves the now-unused lv0 in the code. We can use remove_all_unused to further simplify the code block.
relax.analysis.remove_all_unused(updated_fn).show()
@R.function
def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")) -> Tensor(None, "float32", ndim = 2):
    # block 0
    with R.dataflow():
        gv0: Tensor((3, 4), "float32") = relax.ewise_fma(x, y, y)
        R.output(gv0)
    return gv0
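The clean-up above is a form of dead-binding elimination. The sketch below shows one simple way to do it on a hypothetical toy representation of side-effect-free dataflow bindings: scan backwards from the outputs, keep a binding only if its variable is live, and mark its arguments live. This illustrates the idea only; it is not the actual implementation of relax.analysis.remove_all_unused.

```python
from typing import List, Tuple

# Hypothetical toy binding: (var, op, args). Not the relax data structures.
Binding = Tuple[str, str, Tuple[str, ...]]

def remove_unused(bindings: List[Binding], outputs: List[str]) -> List[Binding]:
    """Drop bindings whose variable is never used, scanning backwards from the outputs."""
    live = set(outputs)
    kept = []
    for var, op, args in reversed(bindings):
        if var in live:
            kept.append((var, op, args))
            live.update(args)  # the arguments of a live binding are live too
    return list(reversed(kept))

rewritten = [
    ("lv0", "relax.multiply", ("x", "y")),
    ("gv0", "relax.ewise_fma", ("x", "y", "y")),
]
for var, op, args in remove_unused(rewritten, outputs=["gv0"]):
    print(f"{var} = {op}({', '.join(args)})")
```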
Fuse Linear and ReLU
Now that we have had a basic taste of graph rewriting, let us try it on an end-to-end model.
!wget -nc https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl
File ‘fasionmnist_mlp_params.pkl’ already there; not retrieving.
import pickle as pkl
with open("fasionmnist_mlp_params.pkl", "rb") as f:
    mlp_params = pkl.load(f)
The following code reconstructs the FashionMNIST MLP model we used in past chapters. To simplify our explanation, we construct the model directly using high-level operators such as relax.op.add and relax.op.dense.
def create_model():
    bb = relax.BlockBuilder()
    x = relax.Var("x", (1, 784), relax.DynTensorType(2, "float32"))
    w0 = relax.const(mlp_params["w0"], "float32")
    b0 = relax.const(mlp_params["b0"], "float32")
    w1 = relax.const(mlp_params["w1"], "float32")
    b1 = relax.const(mlp_params["b1"], "float32")
    with bb.function("main", [x]):
        with bb.dataflow():
            lv0 = bb.emit(relax.op.dense(x, w0))
            lv1 = bb.emit(relax.op.add(lv0, b0))
            lv2 = bb.emit(relax.op.relu(lv1))
            lv3 = bb.emit(relax.op.dense(lv2, w1))
            lv4 = bb.emit(relax.op.add(lv3, b1))
            gv = bb.emit_output(lv4)
        bb.emit_func_output(gv)
    return bb.get()
MLPModel = create_model()
MLPModel.show()
@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor((1, 784), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv: Tensor((1, 128), "float32") = relax.nn.dense(x, meta[relay.Constant][0])
            lv1: Tensor((1, 128), "float32") = relax.add(lv, meta[relay.Constant][1])
            lv2: Tensor((1, 128), "float32") = relax.nn.relu(lv1)
            lv3: Tensor((1, 10), "float32") = relax.nn.dense(lv2, meta[relay.Constant][2])
            lv4: Tensor((1, 10), "float32") = relax.add(lv3, meta[relay.Constant][3])
            gv: Tensor((1, 10), "float32") = lv4
            R.output(gv)
        return gv
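For reference, the same operator sequence can be written in NumPy. This sketch assumes that relax.nn.dense(x, w) computes x @ w.T, with weights stored as (out_features, in_features), which matches the T_matmul_NT loop that appears once the model is lowered to TensorIR. Random parameters stand in for mlp_params so the snippet runs on its own; the helper names are hypothetical.

```python
import numpy as np

def dense(x, w):
    # Assumed relax.nn.dense convention: out[i, j] = sum_k x[i, k] * w[j, k], i.e. x @ w.T
    return x @ w.T

def mlp_reference(x, p):
    # Same op sequence as MLPModel.main: dense -> add -> relu -> dense -> add
    lv1 = dense(x, p["w0"]) + p["b0"]
    lv2 = np.maximum(lv1, 0.0)
    return dense(lv2, p["w1"]) + p["b1"]

# Random stand-in parameters (784 -> 128 -> 10); substitute mlp_params for the real model.
rng = np.random.default_rng(0)
params = {
    "w0": rng.standard_normal((128, 784)).astype("float32"),
    "b0": rng.standard_normal((128,)).astype("float32"),
    "w1": rng.standard_normal((10, 128)).astype("float32"),
    "b1": rng.standard_normal((10,)).astype("float32"),
}
x = rng.standard_normal((1, 784)).astype("float32")
print(mlp_reference(x, params).shape)
```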
We aim to "fuse" the dense and add operations into a single group. The following code achieves that in three steps:
1. Identify dense and add patterns.
2. Generate a fused sub-function that calls into the dense and add operators.
3. Replace the matched dense and add calls with a call to the fused sub-function.
@relax.expr_functor.mutator
class DenseAddFusor(relax.PyExprMutator):
    def __init__(self, mod: IRModule) -> None:
        super().__init__()
        self.mod_ = mod
        # cache pre-defined ops
        self.add_op = tvm.ir.Op.get("relax.add")
        self.dense_op = tvm.ir.Op.get("relax.nn.dense")
        self.counter = 0

    def transform(self) -> IRModule:
        for global_var, func in self.mod_.functions.items():
            if not isinstance(func, relax.Function):
                continue
            # avoid already fused primitive functions
            if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
                continue
            updated_func = self.visit_expr(func)
            updated_func = relax.analysis.remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)

        return self.builder_.get()

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)

        def match_call(node, op):
            if not isinstance(node, relax.Call):
                return False
            return node.op == op

        # pattern match dense => add
        if not match_call(call, self.add_op):
            return call

        value = self.lookup_binding(call.args[0])
        if value is None:
            return call

        if not match_call(value, self.dense_op):
            return call

        x = value.args[0]
        w = value.args[1]
        b = call.args[1]

        # construct a new fused primitive function
        param_x = relax.Var("x", x.shape_, x._checked_type_)
        param_w = relax.Var("w", w.shape_, w._checked_type_)
        param_b = relax.Var("b", b.shape_, b._checked_type_)

        bb = relax.BlockBuilder()

        fn_name = "fused_dense_add%d" % (self.counter)
        self.counter += 1
        with bb.function(fn_name, [param_x, param_w, param_b]):
            with bb.dataflow():
                lv0 = bb.emit(relax.op.nn.dense(param_x, param_w))
                gv = bb.emit_output(relax.op.add(lv0, param_b))
            bb.emit_func_output(gv)

        # Add Primitive attribute to the fused functions
        fused_fn = bb.get()[fn_name].with_attr("Primitive", 1)
        global_var = self.builder_.add_func(fused_fn, fn_name)

        # construct call into the fused function
        return relax.Call(global_var, [x, w, b], None, None)


@tvm.ir.transform.module_pass(opt_level=2, name="DenseAddFuse")
class FuseDenseAddPass:
    """The wrapper for the DenseAddFusor pass."""
    def transform_module(self, mod, ctx):
        return DenseAddFusor(mod).transform()
MLPFused = FuseDenseAddPass()(MLPModel)
MLPFused.show()
@tvm.script.ir_module
class Module:
    @R.function
    def fused_dense_add0(x: Tensor((1, 784), "float32"), w: Tensor((128, 784), "float32"), b: Tensor((128,), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv: Tensor((1, 128), "float32") = relax.nn.dense(x, w)
            gv: Tensor((1, 128), "float32") = relax.add(lv, b)
            R.output(gv)
        return gv

    @R.function
    def fused_dense_add1(x1: Tensor((1, 128), "float32"), w1: Tensor((10, 128), "float32"), b1: Tensor((10,), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv1: Tensor((1, 10), "float32") = relax.nn.dense(x1, w1)
            gv1: Tensor((1, 10), "float32") = relax.add(lv1, b1)
            R.output(gv1)
        return gv1

    @R.function
    def main(x2: Tensor((1, 784), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv11: Tensor((1, 128), "float32") = fused_dense_add0(x2, meta[relay.Constant][0], meta[relay.Constant][1])
            lv2: Tensor((1, 128), "float32") = relax.nn.relu(lv11)
            lv4: Tensor((1, 10), "float32") = fused_dense_add1(lv2, meta[relay.Constant][2], meta[relay.Constant][3])
            gv2: Tensor((1, 10), "float32") = lv4
            R.output(gv2)
        return gv2
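The three steps performed by DenseAddFusor can be sketched on a hypothetical toy IR, where each binding is a (var, op, args) tuple. The function matches add(dense(x, w), b), records a sub-function body under a generated fused_dense_add name, replaces the add with a call to it, and finally drops the dense bindings that became dead (the job remove_all_unused does in the real pass). This is an illustration of the idea, not the relax implementation.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy IR: each binding is (var, op, args); args are variable names.
Binding = Tuple[str, str, Tuple[str, ...]]

def fuse_dense_add(bindings: List[Binding]):
    """Toy version of the three steps: match dense -> add, build a sub-function, replace the call."""
    bound: Dict[str, Binding] = {b[0]: b for b in bindings}
    sub_functions: Dict[str, List[Binding]] = {}
    fused_producers: Set[str] = set()
    out: List[Binding] = []
    for var, op, args in bindings:
        producer = bound.get(args[0]) if op == "add" else None
        if producer is not None and producer[1] == "dense":
            x, w = producer[2]
            bias = args[1]
            name = f"fused_dense_add{len(sub_functions)}"
            # The sub-function body records which operators were fused.
            sub_functions[name] = [("lv0", "dense", ("x", "w")), ("gv", "add", ("lv0", "b"))]
            fused_producers.add(producer[0])
            out.append((var, name, (x, w, bias)))
        else:
            out.append((var, op, args))
    # A fused dense binding that nothing else uses is now dead (cf. remove_all_unused).
    used = {arg for _, _, call_args in out for arg in call_args}
    out = [b for b in out if b[0] not in fused_producers or b[0] in used]
    return out, sub_functions

main = [
    ("lv0", "dense", ("x", "w0")),
    ("lv1", "add", ("lv0", "b0")),
    ("lv2", "relu", ("lv1",)),
    ("lv3", "dense", ("lv2", "w1")),
    ("lv4", "add", ("lv3", "b1")),
]
fused_main, subs = fuse_dense_add(main)
for var, op, args in fused_main:
    print(f"{var} = {op}({', '.join(args)})")
```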
Why Create a Sub-function
In the above example, we created two sub-functions with the prefix fused_dense_add. These sub-function bodies record the operations performed by the fused operator. An alternative to this rewriting is to create a separate primitive operator for each fused pattern (like ewise_fma). However, as we fuse more operators, the number of possible combinations can grow exponentially. A sub-function that groups the fused operations together provides the same information to follow-up code lowering, without introducing a dedicated high-level operator for each fusion pattern.
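A rough count illustrates the growth. Under the simplifying assumption that we only consider linear chains of operators (no branching) and ignore whether a chain is actually fusible, n operator kinds give n^k distinct chains of length k, so a dedicated operator per pattern would need the sum of n^k over all chain lengths. The helper name below is hypothetical.

```python
def count_chains(num_ops: int, max_len: int) -> int:
    """Number of distinct linear operator chains of length 2..max_len over num_ops operator kinds."""
    return sum(num_ops ** k for k in range(2, max_len + 1))

# With only dense, add and relu (3 operator kinds):
for max_len in (2, 3, 4):
    print(max_len, count_chains(3, max_len))
```

Even with three operator kinds, chains of up to four operators already give 117 patterns, whereas a sub-function per fused group needs no new operator at all.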
Map to TensorIR Calls
The fused IRModule only contains calls into high-level operations. To enable further low-level optimization and code generation, we need to translate those high-level primitive operators into the corresponding TensorIR functions (or environment library functions).
The following code remaps high-level operations to the corresponding TensorIR functions. Here we leverage the internal block builder in each mutator and return the transformed value using call_te.
@relax.expr_functor.mutator
class LowerToTensorIR(relax.PyExprMutator):
    def __init__(self, mod: IRModule, op_map) -> None:
        super().__init__()
        self.mod_ = mod
        self.op_map = {
            tvm.ir.Op.get(k): v for k, v in op_map.items()
        }

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)

        if call.op in self.op_map:
            return self.op_map[call.op](self.builder_, call)
        return call

    def transform(self) -> IRModule:
        for global_var, func in self.mod_.functions.items():
            if not isinstance(func, relax.Function):
                continue
            updated_func = self.visit_expr(func)
            self.builder_.update_func(global_var, updated_func)

        return self.builder_.get()


def map_dense(bb, call):
    x, w = call.args
    return bb.call_te(topi.nn.dense, x, w)

def map_add(bb, call):
    a, b = call.args
    return bb.call_te(topi.add, a, b)

def map_relu(bb, call):
    return bb.call_te(topi.nn.relu, call.args[0])


op_map = {
    "relax.nn.dense": map_dense,
    "relax.add": map_add,
    "relax.nn.relu": map_relu
}

@tvm.ir.transform.module_pass(opt_level=0, name="LowerToTensorIR")
class LowerToTensorIRPass:
    """The wrapper for the LowerToTensorIR pass."""
    def transform_module(self, mod, ctx):
        return LowerToTensorIR(mod, op_map).transform()
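The op_map above is a table-driven dispatch: each operator name maps to a function that produces the lowered form of the call. The same design can be sketched in plain Python with a hypothetical numpy_op_map whose entries are NumPy kernels instead of bb.call_te(topi...) emitters, driven by a small interpreter over (var, op, args) bindings. Unlike LowerToTensorIR, which leaves unmapped calls untouched, this toy interpreter has to evaluate every call, so it raises on an unknown operator.

```python
import numpy as np

# Hypothetical stand-in for op_map: op name -> NumPy kernel
# (the real op_map maps op names to functions that emit bb.call_te(topi...) calls).
numpy_op_map = {
    "relax.nn.dense": lambda x, w: x @ w.T,
    "relax.add": lambda a, b: a + b,
    "relax.nn.relu": lambda a: np.maximum(a, 0.0),
}

def run_bindings(bindings, env, op_map):
    """Evaluate (var, op, args) bindings in order, dispatching each op through op_map."""
    env = dict(env)
    for var, op, args in bindings:
        if op not in op_map:
            raise NotImplementedError(f"no mapping registered for {op}")
        env[var] = op_map[op](*(env[a] for a in args))
    return env

env = {
    "x": np.array([[1.0, -2.0]]),
    "w": np.array([[1.0, 1.0], [2.0, 0.0]]),
    "b": np.array([0.5, 1.0]),
}
bindings = [
    ("lv0", "relax.nn.dense", ("x", "w")),
    ("lv1", "relax.add", ("lv0", "b")),
    ("lv2", "relax.nn.relu", ("lv1",)),
]
result = run_bindings(bindings, env, numpy_op_map)
print(result["lv2"])
```

Adding support for a new operator only requires a new table entry; the traversal code stays unchanged.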
MLPModelTIR = LowerToTensorIRPass()(MLPFused)
MLPModelTIR.show()
@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor((1, 784), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv1: Tensor((1, 128), "float32") = fused_dense_add0(x, meta[relay.Constant][0], meta[relay.Constant][1])
            lv2 = R.call_tir(relu, (lv1,), (1, 128), dtype="float32")
            lv4: Tensor((1, 10), "float32") = fused_dense_add1(lv2, meta[relay.Constant][2], meta[relay.Constant][3])
            gv: Tensor((1, 10), "float32") = lv4
            R.output(gv)
        return gv

    @T.prim_func
    def dense(rxplaceholder: T.Buffer[(1, 784), "float32"], rxplaceholder_1: T.Buffer[(T.int64(128), T.int64(784)), "float32"], T_matmul_NT: T.Buffer[(1, T.int64(128)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "dense", "tir.noalias": True, "layout_free_buffers": [1]})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(1, T.int64(128), 784):
            with T.block("T_matmul_NT"):
                i = T.axis.spatial(1, i0)
                j = T.axis.spatial(T.int64(128), i1)
                k = T.axis.reduce(784, i2)
                T.reads(rxplaceholder[i, k], rxplaceholder_1[j, k])
                T.writes(T_matmul_NT[i, j])
                with T.init():
                    T_matmul_NT[i, j] = T.float32(0)
                T_matmul_NT[i, j] = T_matmul_NT[i, j] + rxplaceholder[i, k] * rxplaceholder_1[j, k]

    @T.prim_func
    def add1(rxplaceholder: T.Buffer[(1, T.int64(10)), "float32"], rxplaceholder_1: T.Buffer[T.int64(10), "float32"], T_add: T.Buffer[(1, T.int64(10)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "add1", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(1, T.int64(10)):
            with T.block("T_add"):
                ax0 = T.axis.spatial(1, i0)
                ax1 = T.axis.spatial(T.int64(10), i1)
                T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
                T.writes(T_add[ax0, ax1])
                T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]

    @T.prim_func
    def add(rxplaceholder: T.Buffer[(1, T.int64(128)), "float32"], rxplaceholder_1: T.Buffer[T.int64(128), "float32"], T_add: T.Buffer[(1, T.int64(128)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(1, T.int64(128)):
            with T.block("T_add"):
                ax0 = T.axis.spatial(1, i0)
                ax1 = T.axis.spatial(T.int64(128), i1)
                T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
                T.writes(T_add[ax0, ax1])
                T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]

    @R.function
    def fused_dense_add1(x1: Tensor((1, 128), "float32"), w: Tensor((10, 128), "float32"), b: Tensor((10,), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv = R.call_tir(dense1, (x1, w), (1, 10), dtype="float32")
            gv1 = R.call_tir(add1, (lv, b), (1, 10), dtype="float32")
            R.output(gv1)
        return gv1

    @T.prim_func
    def dense1(rxplaceholder: T.Buffer[(1, T.int64(128)), "float32"], rxplaceholder_1: T.Buffer[(T.int64(10), T.int64(128)), "float32"], T_matmul_NT: T.Buffer[(1, T.int64(10)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "dense1", "tir.noalias": True, "layout_free_buffers": [1]})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(1, T.int64(10), T.int64(128)):
            with T.block("T_matmul_NT"):
                i = T.axis.spatial(1, i0)
                j = T.axis.spatial(T.int64(10), i1)
                k = T.axis.reduce(T.int64(128), i2)
                T.reads(rxplaceholder[i, k], rxplaceholder_1[j, k])
                T.writes(T_matmul_NT[i, j])
                with T.init():
                    T_matmul_NT[i, j] = T.float32(0)
                T_matmul_NT[i, j] = T_matmul_NT[i, j] + rxplaceholder[i, k] * rxplaceholder_1[j, k]

    @T.prim_func
    def relu(rxplaceholder: T.Buffer[(1, T.int64(128)), "float32"], compute: T.Buffer[(1, T.int64(128)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "relu", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(1, T.int64(128)):
            with T.block("compute"):
                i0_1 = T.axis.spatial(1, i0)
                i1_1 = T.axis.spatial(T.int64(128), i1)
                T.reads(rxplaceholder[i0_1, i1_1])
                T.writes(compute[i0_1, i1_1])
                compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0))

    @R.function
    def fused_dense_add0(x2: Tensor((1, 784), "float32"), w1: Tensor((128, 784), "float32"), b1: Tensor((128,), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv3 = R.call_tir(dense, (x2, w1), (1, 128), dtype="float32")
            gv2 = R.call_tir(add, (lv3, b1), (1, 128), dtype="float32")
            R.output(gv2)
        return gv2
Note that in the above code, fused_dense_add0 and fused_dense_add1 are still high-level relax functions that call into the corresponding TensorIR dense and add functions. We can turn each of them into a single TensorIR function, which can then be used in the follow-up optimization and code generation phases.
MLPModelFinal = relax.transform.FuseTIR()(MLPModelTIR)
MLPModelFinal.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def fused_dense_add0(x: T.Buffer[(1, 784), "float32"], w: T.Buffer[(T.int64(128), T.int64(784)), "float32"], b: T.Buffer[T.int64(128), "float32"], T_add: T.Buffer[(1, T.int64(128)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "fused_dense_add0"})
        # body
        # with T.block("root")
        T_matmul_NT = T.alloc_buffer([1, T.int64(128)], dtype="float32")
        for i0, i1, i2 in T.grid(1, T.int64(128), 784):
            with T.block("T_matmul_NT"):
                i = T.axis.spatial(1, i0)
                j = T.axis.spatial(T.int64(128), i1)
                k = T.axis.reduce(784, i2)
                T.reads(x[i, k], w[j, k])
                T.writes(T_matmul_NT[i, j])
                with T.init():
                    T_matmul_NT[i, j] = T.float32(0)
                T_matmul_NT[i, j] = T_matmul_NT[i, j] + x[i, k] * w[j, k]
        for i0, i1 in T.grid(1, T.int64(128)):
            with T.block("T_add"):
                ax0 = T.axis.spatial(1, i0)
                ax1 = T.axis.spatial(T.int64(128), i1)
                T.reads(T_matmul_NT[ax0, ax1], b[ax1])
                T.writes(T_add[ax0, ax1])
                T_add[ax0, ax1] = T_matmul_NT[ax0, ax1] + b[ax1]

    @T.prim_func
    def relu(rxplaceholder: T.Buffer[(1, T.int64(128)), "float32"], compute: T.Buffer[(1, T.int64(128)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "relu", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(1, T.int64(128)):
            with T.block("compute"):
                i0_1 = T.axis.spatial(1, i0)
                i1_1 = T.axis.spatial(T.int64(128), i1)
                T.reads(rxplaceholder[i0_1, i1_1])
                T.writes(compute[i0_1, i1_1])
                compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0))

    @T.prim_func
    def fused_dense_add1(x: T.Buffer[(1, T.int64(128)), "float32"], w: T.Buffer[(T.int64(10), T.int64(128)), "float32"], b: T.Buffer[T.int64(10), "float32"], T_add: T.Buffer[(1, T.int64(10)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "fused_dense_add1"})
        # body
        # with T.block("root")
        T_matmul_NT = T.alloc_buffer([1, T.int64(10)], dtype="float32")
        for i0, i1, i2 in T.grid(1, T.int64(10), T.int64(128)):
            with T.block("T_matmul_NT"):
                i = T.axis.spatial(1, i0)
                j = T.axis.spatial(T.int64(10), i1)
                k = T.axis.reduce(T.int64(128), i2)
                T.reads(x[i, k], w[j, k])
                T.writes(T_matmul_NT[i, j])
                with T.init():
                    T_matmul_NT[i, j] = T.float32(0)
                T_matmul_NT[i, j] = T_matmul_NT[i, j] + x[i, k] * w[j, k]
        for i0, i1 in T.grid(1, T.int64(10)):
            with T.block("T_add"):
                ax0 = T.axis.spatial(1, i0)
                ax1 = T.axis.spatial(T.int64(10), i1)
                T.reads(T_matmul_NT[ax0, ax1], b[ax1])
                T.writes(T_add[ax0, ax1])
                T_add[ax0, ax1] = T_matmul_NT[ax0, ax1] + b[ax1]

    @R.function
    def main(x: Tensor((1, 784), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv1 = R.call_tir(fused_dense_add0, (x, meta[relay.Constant][0], meta[relay.Constant][1]), (1, 128), dtype="float32")
            lv2 = R.call_tir(relu, (lv1,), (1, 128), dtype="float32")
            lv4 = R.call_tir(fused_dense_add1, (lv2, meta[relay.Constant][2], meta[relay.Constant][3]), (1, 10), dtype="float32")
            gv: Tensor((1, 10), "float32") = lv4
            R.output(gv)
        return gv
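To make the loop-level meaning of the fused TensorIR functions concrete, the NumPy sketch below transcribes the two loop nests of fused_dense_add1 (fused_dense_add0 is identical apart from its shapes): a reduction loop that fills the intermediate T_matmul_NT buffer, now allocated inside the function, followed by the bias-add loop. It is a readability aid with a hypothetical function name, not how the compiled code runs, and it checks the result against x @ w.T + b.

```python
import numpy as np

def fused_dense_add_loops(x, w, b):
    """Loop-level NumPy transcription of the fused_dense_add TensorIR functions above."""
    n, k_dim = x.shape
    m = w.shape[0]
    # Intermediate buffer (T.alloc_buffer); T.init() corresponds to starting from zero.
    t_matmul_nt = np.zeros((n, m), dtype=x.dtype)
    for i in range(n):
        for j in range(m):
            for k in range(k_dim):  # reduction axis
                t_matmul_nt[i, j] += x[i, k] * w[j, k]
    # Second loop nest: the "T_add" block reads the intermediate buffer.
    t_add = np.empty((n, m), dtype=x.dtype)
    for i in range(n):
        for j in range(m):
            t_add[i, j] = t_matmul_nt[i, j] + b[j]
    return t_add

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 128)).astype("float32")
w = rng.standard_normal((10, 128)).astype("float32")
b = rng.standard_normal((10,)).astype("float32")
print(np.allclose(fused_dense_add_loops(x, w, b), x @ w.T + b, atol=1e-3))
```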
Build and Run
We can go ahead and build the final module and try it out on an example picture.
import torch
import torchvision
test_data = torchvision.datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
img, label = next(iter(test_loader))
img = img.reshape(1, 28, 28).numpy()
import matplotlib.pyplot as plt
plt.figure()
plt.imshow(img[0])
plt.colorbar()
plt.grid(False)
plt.show()
print("Class:", class_names[label[0]])
Class: Sneaker
ex = relax.vm.build(MLPModelFinal, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
data_nd = tvm.nd.array(img.reshape(1, 784))
nd_res = vm["main"](data_nd)
pred_kind = np.argmax(nd_res.numpy(), axis=1)
print("MLPModule Prediction:", class_names[pred_kind[0]])
MLPModule Prediction: Sneaker
Discussions
This section comes back to our common theme of transformations among computational graphs. Although minimal, this sequence of transformations covers two important optimizations we commonly perform in the MLC process: fusion and loop-level code lowering.
Real-world MLC processes can contain more powerful and robust transformations. For example, our fusion pass can duplicate dense computations when a dense operator's result is used by two follow-up add operations. A robust fusion pass will detect that and choose to skip such cases. Additionally, we do not want to write down rules for each combination. Instead, TVM's internal fusor analyzes the loop patterns of TensorIR functions and uses them in fusion decisions.
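One simple guard against the duplication problem is to fuse a dense into an add only when that add is the dense result's sole consumer. The sketch below shows this use-count check on a hypothetical toy IR of (var, op, args) bindings; it illustrates the idea and does not describe how TVM's fusor makes its decisions.

```python
from collections import Counter

# Hypothetical toy bindings (var, op, args): the dense result lv0 feeds two adds.
bindings = [
    ("lv0", "dense", ("x", "w")),
    ("lv1", "add", ("lv0", "b0")),
    ("lv2", "add", ("lv0", "b1")),
]

def fusible_dense_producers(bindings):
    """Return dense vars whose result has exactly one consumer and that consumer is an add."""
    uses = Counter(arg for _, _, args in bindings for arg in args)
    ops = {var: op for var, op, _ in bindings}
    fusible = set()
    for _, op, args in bindings:
        if op == "add" and ops.get(args[0]) == "dense" and uses[args[0]] == 1:
            fusible.add(args[0])
    return fusible

print(fusible_dense_producers(bindings))      # lv0 has two consumers: not fused
print(fusible_dense_producers(bindings[:2]))  # single consumer: safe to fuse
```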
Notably, these transformations compose with each other. For example, we can use our own customized fusor to support new fusion patterns we want to explore, and then feed the result into an existing fusor to handle the remaining steps.
Summary
We can optimize tensor programs by rewriting computational graph data structures.
We can use the visitor pattern to rewrite call nodes.
We can perform computational graph transformations, such as fusion and loop-level program lowering.