Install packages
导航
Install packages#
For the purpose of this course, we will use some ongoing development in tvm, which is an open-source machine learning compilation framework. We provide the following command to install a packaged version for mlc course.
!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels
Looking in links: https://mlc.ai/wheels
Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1956+ge3f218d71)
Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.2)
Requirement already satisfied: Pygments in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.12.0)
Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)
Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)
Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.22.3)
Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (22.1.0)
Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)
Requirement already satisfied: psutil in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)
Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.9.0)
Prelude#
In the past chapters, we learned about how to build primitive tensor functions and connect them to form end-to-end model executions. There are three primary types of abstractions we have used so far.
A computational graph view that drives the high-level executions.
Abstraction for primitive tensor functions.
Library function calls via environment function registration.
All of these elements are encapsulated in an IRModule. Most of the MLC processes can be viewed as transformations among tensor functions.
There are many different ways to transform the same program. This chapter will discuss ways to automate some of the processes.
Preparations#
To begin with, we will import necessary dependencies and create helper functions.
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T, relax as R
import numpy as np
from tvm import relax
# This is needed for deferring annotation parsing in TVMScript
from __future__ import annotations
import IPython
def code2html(code):
"""Helper function to use pygments to turn the code string into highlighted html."""
import pygments
from pygments.lexers import Python3Lexer
from pygments.formatters import HtmlFormatter
formatter = HtmlFormatter()
html = pygments.highlight(code, Python3Lexer(), formatter)
return "<style>%s</style>%s\n" % (formatter.get_style_defs(".highlight"), html)
Recap: Transform a Primitive Tensor Function.#
Let us begin by reviewing what we did in our previous chapters – transforming a single primitive tensor function.
@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(128, 128), "float32"],
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i, j, k in T.grid(128, 128, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
First, let us define a set of inputs and outputs for evaluation.
dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
c_mm = a_np @ b_np
We can build and run MyModule
as follows.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
lib = tvm.build(MyModule, target="llvm")
f_timer_before = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule: %.3f ms" % (f_timer_before(a_nd, b_nd, c_nd).mean * 1000))
Time cost of MyModule: 2.261 ms
Next, we transform MyModule
a bit by reorganizing the loop access pattern.
def schedule_mm(sch: tvm.tir.Schedule, jfactor=4):
block_C = sch.get_block("C", "main")
i, j, k = sch.get_loops(block=block_C)
j_0, j_1 = sch.split(loop=j, factors=[None, jfactor])
sch.reorder(i, j_0, k, j_1)
sch.decompose_reduction(block_C, k)
return sch
sch = tvm.tir.Schedule(MyModule)
sch = schedule_mm(sch)
IPython.display.HTML(code2html(sch.mod.script()))
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i, j_0 in T.grid(128, 32):
for j_1_init in T.serial(4):
with T.block("C_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1_init)
T.reads()
T.writes(C[vi, vj])
C[vi, vj] = T.float32(0)
for k, j_1 in T.grid(128, 4):
with T.block("C_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1)
vk = T.axis.reduce(128, k)
T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
T.writes(C[vi, vj])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
Then we can build and run the re-organized program.
lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule=>schedule_mm: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))
Time cost of MyModule=>schedule_mm: 1.553 ms
Transformation Trace#
Besides sch.mod
field, another thing tir.Schedule
offers is a trace field that can be used to show the steps involved to get to the transformed module. We can print it out using the following code.
print(sch.trace)
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 4], preserve_unit_iters=True)
sch.reorder(l1, l4, l3, l5)
b6 = sch.decompose_reduction(block=b0, loop=l3)
def schedule_mm(sch: tvm.tir.Schedule, jfactor=4):
block_C = sch.get_block("C", "main")
i, j, k = sch.get_loops(block=block_C)
j_0, j_1 = sch.split(loop=j, factors=[None, jfactor])
sch.reorder(i, j_0, k, j_1)
sch.decompose_reduction(block_C, k)
return sch
The above trace aligns with the transformations we specified in schedule_mm
. One thing to note is that the trace (plus the original program) gives us a way to completely re-derive the final output program. Let us keep that in mind; we will use trace throughout this chapter as another way to inspect the transformations.
Stochastic Schedule Transformation#
Up until now, we have specified every detail about what transformations we want to make on the original TensorIR program. Many of those choices are based on our understanding of the underlying environment, such as cache and hardware unit.
However, in practice, we may not be able to decide every detail accurately. Instead of doing so, we would like to specify what are possible ways to transform the program, while leaving out some details.
One natural way to achieve the goal is to add some stochastic (randomness) elements to our transformations. The following code does that.
def stochastic_schedule_mm(sch: tvm.tir.Schedule):
block_C = sch.get_block("C", "main")
i, j, k = sch.get_loops(block=block_C)
j_factors = sch.sample_perfect_tile(loop=j, n=2)
j_0, j_1 = sch.split(loop=j, factors=j_factors)
sch.reorder(i, j_0, k, j_1)
sch.decompose_reduction(block_C, k)
return sch
Let us compare stochastic_schedule_mm
and schedule_mm
side by side. We can find that the only difference is how to specify j_factors
. In the case of schedule_mm
, j_factors
is passed in as a parameter specified by us. In the case of stochastic_schedule_mm
, it comes from sch.sample_perfect_tile
.
As the name suggests, sch.sample_perfect_tile
tries to draw random numbers to fill in j_factors
. It samples factors such that they perfectly split the loop. For example, when the original loop size is 128
, possible ways to split the loop include: [8, 16]
, [32, 4]
, [2, 64]
(note 8 * 16 = 32 * 4 = 2 * 64 = 128
).
Let us first try to see what is the effect of stochastic_schedule_mm
by running the following code-block. Try to run the following code block multiple times and observe the outcome difference. You might find that the loop bound of j_1
changes each time we run the code-block.
sch = tvm.tir.Schedule(MyModule)
sch = stochastic_schedule_mm(sch)
IPython.display.HTML(code2html(sch.mod.script()))
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i, j_0 in T.grid(128, 32):
for j_1_init in T.serial(4):
with T.block("C_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1_init)
T.reads()
T.writes(C[vi, vj])
C[vi, vj] = T.float32(0)
for k, j_1 in T.grid(128, 4):
with T.block("C_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1)
vk = T.axis.reduce(128, k)
T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
T.writes(C[vi, vj])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
What is happening here is that each time we run stochastic_schedule_mm
it draws a different j_factors
randomly. We can print out the trace of the latest one to see the decisions we made in sampling.
print(sch.trace)
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[32, 4])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
When we look at the trace, pay close attention to the decision=[...]
part of sample_perfect_tile
. They correspond to the value that the sampling_perfect_tile
picked in our last call to stochastic_schedule_mm
.
As an alternative way to look at different samples of stochastic_schedule_mm
, we can run the following block multiple times and look at the trace.
sch = tvm.tir.Schedule(MyModule)
sch = stochastic_schedule_mm(sch)
print(sch.trace)
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[32, 4])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
Deep Dive into Stochastic Transformation#
Now let us take a deeper dive into what happened in stochastic schedule transformations. We can find that it is a simple generalization of the original deterministic transformations, with two additional elements:
Random variables that come from
sample_perfect_tile
and other sampling operations that we did not cover in the example.Schedule operations that take action depending on the random variables.
Let us try to run the stochastic transformation step by step.
sch = tvm.tir.Schedule(MyModule)
block_C = sch.get_block("C", "main")
i, j, k = sch.get_loops(block=block_C)
j_factors = sch.sample_perfect_tile(loop=j, n=2)
type(j_factors[0])
tvm.tir.expr.Var
Elements in the j_factors
are not real integer numbers. Instead, they are symbolic variables that refer to a random variable being sampled. We can pass these variables to the transformation API to specify choices such as factor values.
print(sch.trace)
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
The schedule trace keeps track of the choices of these symbolic variables in the decisions
field. So follow-up steps will be able to look up these choices to decide how to split the loop.
IPython.display.HTML(code2html(sch.mod.script()))
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i, j, k in T.grid(128, 128, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
T.reads(A[vi, vk], B[vk, vj])
T.writes(C[vi, vj])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
If we look at the code at the current time point, we can find that the module remains the same since we only sampled the random variables but have not yet made any transformation actions based on them.
Let us now take some of the actions:
j_0, j_1 = sch.split(loop=j, factors=j_factors)
sch.reorder(i, j_0, k, j_1)
These actions are recorded in the following trace.
print(sch.trace)
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
If we retake a look at the code, the transformed module now corresponds to the updated versions after the actions are taken.
IPython.display.HTML(code2html(sch.mod.script()))
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i, j_0, k, j_1 in T.grid(128, 8, 128, 16):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 16 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(C[vi, vj])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
We can do some further transformations to get to the final state
sch.reorder(i, j_0, k, j_1)
sch.decompose_reduction(block_C, k)
tir.BlockRV(0x560d1be0a370)
IPython.display.HTML(code2html(sch.mod.script()))
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i, j_0 in T.grid(128, 8):
for j_1_init in T.serial(16):
with T.block("C_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
T.reads()
T.writes(C[vi, vj])
C[vi, vj] = T.float32(0)
for k, j_1 in T.grid(128, 16):
with T.block("C_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 16 + j_1)
vk = T.axis.reduce(128, k)
T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
T.writes(C[vi, vj])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
Search Over Stochastic Transformations#
One thing that you might realize is that stochastic_schedule_mm
create a search space of possible programs depending on the specific decisions made at each sampling step.
Coming back to our initial intuition, we want to be able to specify a set of possible programs instead of one program. stochastic_schedule_mm
did exactly that. Of course, one natural question to ask next is what is the best choice.
We will need a search algorithm to do that. To show what can be done here, let us first try the most straightforward search algorithm – random search, in the following code block. It tries to run stochastic_schedule_mm
repetitively, gets a transformed module, runs benchmark, then book keep the best one in history.
def random_search(mod: tvm.IRModule, num_trials=5):
best_result = None
best_sch = None
for i in range(num_trials):
sch = stochastic_schedule_mm(tvm.tir.Schedule(mod))
lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
result = f_timer_after(a_nd, b_nd, c_nd).mean
print("=====Attempt %d, time-cost: %.3f ms====" % (i, result * 1000))
print(sch.trace)
# book keep the best result so far
if best_result is None or result < best_result:
best_result = result
best_sch = sch
return best_sch
sch = random_search(MyModule)
=====Attempt 0, time-cost: 1.175 ms====
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 1, time-cost: 1.556 ms====
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[32, 4])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 2, time-cost: 1.565 ms====
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[32, 4])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 3, time-cost: 1.211 ms====
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[16, 8])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 4, time-cost: 1.202 ms====
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
If we run the code, we can find that it goes over a few choices and then returns the best run throughout five trials.
print(sch.trace)
b0 = sch.get_block(name="C", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
sch.reorder(l1, l6, l3, l7)
b8 = sch.decompose_reduction(block=b0, loop=l3)
In practice, we use smarter algorithms. We also need to provide additional utilities, such as benchmarking on remote devices, if we are interested in optimization for other devices. TVM’s meta schedule API provides these additional capabilities.
meta_schedule
is the namespace that comes to support search over a space of possible transformations. There are many additional things that meta-schedule do behind the scene:
Parallel benchmarking across many processes.
Use cost models to avoid benchmarking each time.
Evolutionary search on the traces instead of randomly sampling at each time.
Despite these magics, the key idea remains the same: use stochastic transformation to specify a search space of good programs, tune_tir
API helps to search and find an optimized solution within the search space.
from tvm import meta_schedule as ms
sch_tuned = ms.tune_tir(
mod=MyModule,
target="llvm --num-cores=1",
config=ms.TuneConfig(
max_trials_global=64,
num_trials_per_iter=64,
),
space=ms.space_generator.ScheduleFn(stochastic_schedule_mm),
work_dir="./tune_tmp",
task_name="main"
)
2022-08-29 19:50:40.936 INFO Logging directory: ./tune_tmp/logs
2022-08-29 19:50:40.939 INFO Logging directory: ./tune_tmp/logs
2022-08-29 19:50:40.940 INFO Working directory: ./tune_tmp
2022-08-29 19:50:40.941 INFO Creating JSONDatabase. Workload at: ./tune_tmp/database_workload.json. Tuning records at: ./tune_tmp/database_tuning_record.json
2022-08-29 19:50:40.944 INFO LocalBuilder: max_workers = 24
2022-08-29 19:50:41.790 INFO LocalRunner: max_workers = 1
2022-08-29 19:50:42.693 INFO Initializing Task #0: "main"
2022-08-29 19:50:42.759 INFO
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated
------------------------------------------------------------------------------------------------------------
0 | main | 4194304 | 1 | N/A | N/A | N/A | 0 |
------------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0
2022-08-29 19:50:42.760 INFO Scheduler picks Task #0: "main"
2022-08-29 19:50:45.139 INFO Sending 5 sample(s) to builder
2022-08-29 19:50:46.263 INFO Sending 5 sample(s) to runner
---------------------------------------------------------------------------
TVMError Traceback (most recent call last)
/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb Cell 56 in <cell line: 3>()
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a> from tvm import meta_schedule as ms
----> <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a> sch_tuned = ms.tune_tir(
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a> mod=MyModule,
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a> target="llvm --num-cores=1",
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=5'>6</a> config=ms.TuneConfig(
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=6'>7</a> max_trials_global=64,
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=7'>8</a> num_trials_per_iter=64,
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a> ),
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a> space=ms.space_generator.ScheduleFn(stochastic_schedule_mm),
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a> work_dir="./tune_tmp",
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a> task_name="main"
<a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/books/mlcbook/doc/study/5_Automatic_Program_Optimization.ipynb#Y111sdnNjb2RlLXJlbW90ZQ%3D%3D?line=12'>13</a> )
File /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/meta_schedule/tune.py:414, in tune_tir(mod, target, config, work_dir, builder, runner, database, cost_model, measure_callbacks, space, sch_rules, postprocs, mutator_probs, task_name, num_threads)
412 target = default_config.target(target)
413 # pylint: enable=protected-access
--> 414 database = tune_extracted_tasks(
415 extracted_tasks=[
416 ExtractedTask(
417 task_name=task_name,
418 mod=mod,
419 dispatched=[mod],
420 target=target,
421 weight=1,
422 ),
423 ],
424 config=config,
425 work_dir=work_dir,
426 builder=builder,
427 runner=runner,
428 database=database,
429 cost_model=cost_model,
430 measure_callbacks=measure_callbacks,
431 space=space,
432 sch_rules=sch_rules,
433 postprocs=postprocs,
434 mutator_probs=mutator_probs,
435 num_threads=num_threads,
436 )
437 with Profiler.timeit("ApplyHistoryBest"):
438 bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1)
File /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/meta_schedule/tune.py:350, in tune_extracted_tasks(extracted_tasks, config, work_dir, builder, runner, database, cost_model, measure_callbacks, space, sch_rules, postprocs, mutator_probs, num_threads)
340 task_scheduler = config.create_task_scheduler(
341 tasks=tune_contexts,
342 task_weights=[float(t.weight) for t in extracted_tasks],
(...)
347 measure_callbacks=measure_callbacks,
348 )
349 if config.max_trials_global > 0:
--> 350 task_scheduler.tune()
351 cost_model.save(osp.join(work_dir, "cost_model.xgb"))
352 return database
File /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/meta_schedule/task_scheduler/task_scheduler.py:72, in TaskScheduler.tune(self)
70 def tune(self) -> None:
71 """Auto-tuning."""
---> 72 _ffi_api.TaskSchedulerTune(self)
File /media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/_ffi/_ctypes/packed_func.py:237, in PackedFuncBase.__call__(self, *args)
225 ret_tcode = ctypes.c_int()
226 if (
227 _LIB.TVMFuncCall(
228 self.handle,
(...)
235 != 0
236 ):
--> 237 raise get_last_ffi_error()
238 _ = temp_args
239 _ = args
TVMError: Traceback (most recent call last):
8: TVMFuncCall
7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
6: tvm::meta_schedule::TaskSchedulerNode::Tune()
5: tvm::meta_schedule::GradientBasedNode::NextTaskId()
4: tvm::meta_schedule::GradientBasedNode::JoinRunningTask(int)
3: tvm::meta_schedule::TaskSchedulerNode::JoinRunningTask(int)
2: tvm::meta_schedule::UpdateCostModelNode::Apply(tvm::meta_schedule::TaskScheduler const&, int, tvm::runtime::Array<tvm::meta_schedule::MeasureCandidate, void> const&, tvm::runtime::Array<tvm::meta_schedule::BuilderResult, void> const&, tvm::runtime::Array<tvm::meta_schedule::RunnerResult, void> const&)
1: tvm::meta_schedule::PyCostModelNode::Update(tvm::meta_schedule::TuneContext const&, tvm::runtime::Array<tvm::meta_schedule::MeasureCandidate, void> const&, tvm::runtime::Array<tvm::meta_schedule::RunnerResult, void> const&)
0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
File "/media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
rv = local_pyfunc(*pyargs)
File "/media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/meta_schedule/utils.py", line 77, in method
return getattr(inst, name)(*args, **kwargs)
File "/media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/meta_schedule/cost_model/xgb_model.py", line 513, in update
self._train(
File "/media/pc/data/4tb/lxw/libs/anaconda3/envs/mlc/lib/python3.10/site-packages/tvm/meta_schedule/cost_model/xgb_model.py", line 563, in _train
import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel
ModuleNotFoundError: No module named 'xgboost'
tune_tir
functions return an optimized schedule found during the tuning process.
print(sch_tuned.trace)
IPython.display.HTML(code2html(sch_tuned.mod.script()))
lib = tvm.build(sch_tuned.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))
Leverage Default AutoScheduling#
In the last section, we showed how to tune a workload with stochastic transformations that we crafted. Metaschedule comes with its own built-in set of generic stochastic transformations that works for a broad set of TensorIR computations. This approach is also called auto-scheduling, as the search space is generated by the system. We can run it by removing the line space=ms.space_generator.ScheduleFn(stochastic_schedule_mm)
.
Under the hood, the meta-scheduler analyzes each block’s data access and loop patterns and proposes stochastic transformations to the program. We won’t go into these generic transformations in this chapter but want to note that they are also just stochastic transformations coupled with an analysis of the code. We can use the same mechanisms learned in the last section to enhance auto-scheduling. We will touch base on this topic in future chapters.
sch_tuned = ms.tune_tir(
mod=MyModule,
target="llvm --num-cores=1",
config=ms.TuneConfig(
max_trials_global=64,
num_trials_per_iter=64,
),
work_dir="./tune_tmp",
task_name="main",
)
lib = tvm.build(sch_tuned.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))
The result gets much faster than our original code. We can take a glimpse at the trace and the final code. For the purpose of this chapter, you do not need to understand all the transformations. At a high level, the trace involves:
More levels of loop tiling transformations.
Vectorization of intermediate computations.
Parallelization and unrolling of loops.
sch_tuned.trace
IPython.display.HTML(code2html(sch_tuned.mod.script()))
Section Checkpoint#
Let us have a checkpoint about what we have learned so far.
Stochastic schedule allow us to express “what are the possible transformations”.
Metaschedule’s
tune_tir
API helps to find a good solution within the space.Metaschedule comes with a default built-in set of stochastic transformations that covers a broad range of search space.
Putting Things Back to End to End Model Execution#
Up until now, we have learned to automate program optimization of a single tensor primitive function. How can we put it back and improve our end-to-end model execution?
From the MLC perspective, the automated search is a modular step, and we just need to replace the original primitive function implementation with the new one provided by the tuned result.
We will reuse the two-layer MLP example from the last chapter.
import torchvision
import torch
test_data = torchvision.datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=torchvision.transforms.ToTensor()
)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
img, label = next(iter(test_loader))
img = img.reshape(1, 28, 28).numpy()
import matplotlib.pyplot as plt
plt.figure()
plt.imshow(img[0])
plt.colorbar()
plt.grid(False)
plt.show()
print("Class:", class_names[label[0]])
We also download pre-packed model parameters that we will use in our examples.
!wget https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl
As a reminder, the above figure shows the model of interest.
import pickle as pkl
mlp_params = pkl.load(open("fasionmnist_mlp_params.pkl", "rb"))
data_nd = tvm.nd.array(img.reshape(1, 784))
nd_params = {k: tvm.nd.array(v) for k, v in mlp_params.items()}
Let us use a mixture module where most of the components call into environment function and also come with one TensorIR function linear0
.
@tvm.script.ir_module
class MyModuleMixture:
@T.prim_func
def linear0(X: T.Buffer[(1, 784), "float32"],
W: T.Buffer[(128, 784), "float32"],
B: T.Buffer[(128,), "float32"],
Z: T.Buffer[(1, 128), "float32"]):
T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
Y = T.alloc_buffer((1, 128), "float32")
for i, j, k in T.grid(1, 128, 784):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
for i, j in T.grid(1, 128):
with T.block("Z"):
vi, vj = T.axis.remap("SS", [i, j])
Z[vi, vj] = Y[vi, vj] + B[vj]
@R.function
def main(x: Tensor((1, 784), "float32"),
w0: Tensor((128, 784), "float32"),
b0: Tensor((128,), "float32"),
w1: Tensor((10, 128), "float32"),
b1: Tensor((10,), "float32")):
with R.dataflow():
lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32")
out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32")
R.output(out)
return out
@tvm.register_func("env.linear", override=True)
def torch_linear(x: tvm.nd.NDArray,
w: tvm.nd.NDArray,
b: tvm.nd.NDArray,
out: tvm.nd.NDArray):
x_torch = torch.from_dlpack(x)
w_torch = torch.from_dlpack(w)
b_torch = torch.from_dlpack(b)
out_torch = torch.from_dlpack(out)
torch.mm(x_torch, w_torch.T, out=out_torch)
torch.add(out_torch, b_torch, out=out_torch)
@tvm.register_func("env.relu", override=True)
def lnumpy_relu(x: tvm.nd.NDArray,
out: tvm.nd.NDArray):
x_torch = torch.from_dlpack(x)
out_torch = torch.from_dlpack(out)
torch.maximum(x_torch, torch.Tensor([0.0]), out=out_torch)
We can bind the parameters and see if it gives the correct prediction.
MyModuleWithParams = relax.transform.BindParams("main", nd_params)(MyModuleMixture)
ex = relax.vm.build(MyModuleWithParams, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
nd_res = vm["main"](data_nd)
pred_kind = np.argmax(nd_res.numpy(), axis=1)
print("MyModuleWithParams Prediction:", class_names[pred_kind[0]])
The following code evaluates the run time cost of the module before the transformation. Note that because this is a small model, the number can fluctuate a bit between runs, so we just need to read the overall magnitude.
ftimer = vm.module.time_evaluator("main", tvm.cpu(), number=100)
print("MyModuleWithParams time-cost: %g ms" % (ftimer(data_nd).mean * 1000))
We are now ready to tune the linear0
. Our overall process is summarized in the following diagram.
Currently, tune API only takes an IRModule with one main
function, so we first get the linear0
out into another module’s main function and pass it to tune
mod_linear = tvm.IRModule.from_expr(MyModuleMixture["linear0"].with_attr("global_symbol", "main"))
IPython.display.HTML(code2html(mod_linear.script()))
sch_tuned_linear = ms.tune_tir(
mod=mod_linear,
target="llvm --num-cores=1",
config=ms.TuneConfig(
max_trials_global=64,
num_trials_per_iter=64,
),
work_dir="./tune_tmp",
task_name="main",
)
Now we need to replace the original linear0
with the new function after tuning. We can do that by first getting a global_var
, a pointer
reference to the functions inside the IRModule, then calling update_func
to replace the function with the new one.
MyModuleWithParams2 = relax.transform.BindParams("main", nd_params)(MyModuleMixture)
new_func = sch_tuned_linear.mod["main"].with_attr("global_symbol", "linear0")
gv = MyModuleWithParams2.get_global_var("linear0")
MyModuleWithParams2.update_func(gv, new_func)
IPython.display.HTML(code2html(MyModuleWithParams2.script()))
We can find that the linear0
has been replaced in the above code.
ex = relax.vm.build(MyModuleWithParams2, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
nd_res = vm["main"](data_nd)
pred_kind = np.argmax(nd_res.numpy(), axis=1)
print("MyModuleWithParams2 Prediction:", class_names[pred_kind[0]])
Running the code again, we can find that we get an observable amount of time reduction, mainly thanks to the new linear0
function.
ftimer = vm.module.time_evaluator("main", tvm.cpu(), number=50)
print("MyModuleWithParams2 time-cost: %g ms" % (ftimer(data_nd).mean * 1000))
Discussions#
We might notice that our previous two chapters focused on abstraction while this chapter starts to focus on transformation. Stochastic transformations specify what can be possibly optimized without nailing down all the choices. The meta-schedule API helps us to search over the space of possible transformations and pick the best one.
Importantly, putting the search result back into the end-to-end flow is just a matter of replacing the implementation of the original function with a new one that is informed by the tuning process.
So we again are following the generic MLC process in the figure below. In future lectures, we will introduce more kinds of transformations on primitive functions and computational graph functions. A good MLC process composes these transformations together to form an end deployment form.
Summary#
Stochastic transformations help us to specify a search space of possible programs.
MetaSchedule searches over the search space and finds an optimized one.
We can use another transformation to replace the primitive tensor function with optimized ones and an updated end-to-end execution flow.