张量程序抽象
导航
张量程序抽象#
安装包#
提供以下命令来安装 mlc 课程的打包版本。
!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels
Looking in links: https://mlc.ai/wheels
Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1664+g1f3985de0)
Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)
Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.23.1)
Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (21.4.0)
Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.8.1)
Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)
Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)
Requirement already satisfied: psutil in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)
Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.1)
构建张量程序#
从构造执行两个向量的加法的张量程序开始。
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np
from IPython.display import Code
@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(A: T.Buffer[128, "float32"],
B: T.Buffer[128, "float32"],
C: T.Buffer[128, "float32"]):
# 函数的额外注解
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in range(128):
with T.block("C"):
# 在 spatial 域中声明数据并行迭代器
vi = T.axis.spatial(128, i)
C[vi] = A[vi] + B[vi]
TVMScript 是用 Python ast 表达张量程序的方式。注意,这段代码实际上并不对应于 python 程序,而是可以在 MLC 进程中使用的张量程序。该语言被设计成与 python 语法保持一致的附加结构,以方便分析和变换。
type(MyModule)
tvm.ir.module.IRModule
MyModule
是 IRModule 数据结构的实例(用来保存张量函数的集合)。
可以使用 script
函数获得基于 IRModule 表示的字符串。这个函数对于在变换的每个步骤中检查模块非常有用。
Code(MyModule.script(), language="python")
@tvm.script.ir_module
class Module:
@tir.prim_func
def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with tir.block("root")
for i in tir.serial(128):
with tir.block("C"):
vi = tir.axis.spatial(128, i)
tir.reads(A[vi], B[vi])
tir.writes(C[vi])
C[vi] = A[vi] + B[vi]
构建并运行#
在任何时间点,都可以通过调用 build
函数将 IRModule 变换为可运行函数。
rt_mod = tvm.build(MyModule, target="llvm") # CPU 后端的模块
print(type(rt_mod))
<class 'tvm.driver.build_module.OperatorModule'>
构建后,mod 包含了一组可运行的函数。可以通过函数名检索每个函数。
func = rt_mod["main"]
func
<tvm.runtime.packed_func.PackedFunc at 0x7fdd0e398980>
a = tvm.nd.array(np.arange(128, dtype="float32"))
b = tvm.nd.array(np.ones(128, dtype="float32"))
c = tvm.nd.empty((128,), dtype="float32")
要调用该函数,可以在 tvm 运行时中创建三个 ndarray,然后调用生成的函数。
func(a, b, c)
print(a)
print(b)
print(c)
[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13.
14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.
28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41.
42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55.
56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69.
70. 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83.
84. 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97.
98. 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.
112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.
126. 127.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1.]
[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14.
15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28.
29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42.
43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56.
57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70.
71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84.
85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 98.
99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111. 112.
113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126.
127. 128.]
变换张量程序#
现在开始变换张量程序。张量程序可以使用一种称为调度(schedule)的辅助数据结构进行变换。
sch = tvm.tir.Schedule(MyModule)
print(type(sch))
<class 'tvm.tir.schedule.schedule.Schedule'>
先试着划分循环:
# 按名称获取块
block_c = sch.get_block("C")
# 获取块周围的循环
(i,) = sch.get_loops(block_c)
# Tile 循环嵌套
i_0, i_1, i_2 = sch.split(i, factors=[None, 4, 4])
Code(sch.mod.script(), language="Python")
@tvm.script.ir_module
class Module:
@tir.prim_func
def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with tir.block("root")
for i_0, i_1, i_2 in tir.grid(8, 4, 4):
with tir.block("C"):
vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
tir.reads(A[vi], B[vi])
tir.writes(C[vi])
C[vi] = A[vi] + B[vi]
也可以重新排序循环。把循环 i_2 移到 i_1 的外面。
sch.reorder(i_0, i_2, i_1)
print(sch.mod.script())
@tvm.script.ir_module
class Module:
@tir.prim_func
def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with tir.block("root")
for i_0, i_2, i_1 in tir.grid(8, 4, 4):
with tir.block("C"):
vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
tir.reads(A[vi], B[vi])
tir.writes(C[vi])
C[vi] = A[vi] + B[vi]
最后,可以向程序生成器添加要对最内部循环进行向量化的提示。
Code(sch.mod.script(), language="python")
@tvm.script.ir_module
class Module:
@tir.prim_func
def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with tir.block("root")
for i_0, i_2, i_1 in tir.grid(8, 4, 4):
with tir.block("C"):
vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
tir.reads(A[vi], B[vi])
tir.writes(C[vi])
C[vi] = A[vi] + B[vi]
sch.parallel(i_0)
Code(sch.mod.script(), language="python")
@tvm.script.ir_module
class Module:
@tir.prim_func
def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with tir.block("root")
for i_0 in tir.parallel(8):
for i_2, i_1 in tir.grid(4, 4):
with tir.block("C"):
vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
tir.reads(A[vi], B[vi])
tir.writes(C[vi])
C[vi] = A[vi] + B[vi]
可以构建并运行变换后的程序
transformed_mod = tvm.build(sch.mod, target="llvm") # The module for CPU backends.
transformed_mod["main"](a, b, c)
使用张量表达式构建张量程序#
在前面的例子中,直接使用 TVMScript 来构造张量程序。在实践中,从现有定义实用地构造这些函数通常是有帮助的。张量表达式是一个 API,它可以帮助构建一些类似表达式的数组计算。
# namespace for tensor expression utility
from tvm import te
# declare the computation using the expression API
A = te.placeholder((128, ), name="A")
B = te.placeholder((128, ), name="B")
C = te.compute((128,), lambda i: A[i] + B[i], name="C")
# create a function with the specified list of arguments.
func = te.create_prim_func([A, B, C])
# mark that the function name is main
func = func.with_attr("global_symbol", "main")
ir_mod_from_te = IRModule({"main": func})
Code(ir_mod_from_te.script(), language="python")
@tvm.script.ir_module
class Module:
@tir.prim_func
def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with tir.block("root")
for i0 in tir.serial(128):
with tir.block("C"):
i = tir.axis.spatial(128, i0)
tir.reads(A[i], B[i])
tir.writes(C[i])
C[i] = A[i] + B[i]
变换矩阵乘法程序#
在上面的例子中,展示了如何变换向量加法。现在试着把它应用到稍微复杂一点的程序中(矩阵乘法)。首先尝试使用张量表达式 API 构建初始代码。
from tvm import te
M = 1024
K = 1024
N = 1024
# The default tensor type in tvm
dtype = "float32"
target = "llvm"
dev = tvm.device(target, 0)
# Algorithm
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")
# Default schedule
func = te.create_prim_func([A, B, C])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
print(Code(ir_module.script(), language="python"))
func = tvm.build(ir_module, target="llvm") # The module for CPU backends.
a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)
c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
print("Baseline: %f" % evaluator(a, b, c).mean)
@tvm.script.ir_module
class Module:
@tir.prim_func
def main(A: tir.Buffer[(1024, 1024), "float32"], B: tir.Buffer[(1024, 1024), "float32"], C: tir.Buffer[(1024, 1024), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2 in tir.grid(1024, 1024, 1024):
with tir.block("C"):
m, n, k = tir.axis.remap("SSR", [i0, i1, i2])
tir.reads(A[m, k], B[k, n])
tir.writes(C[m, n])
with tir.init():
C[m, n] = tir.float32(0)
C[m, n] = C[m, n] + A[m, k] * B[k, n]
Baseline: 2.238068
可以变换循环访问模式,使其对缓存更友好。让我们使用下面的调度。
sch = tvm.tir.Schedule(ir_module)
print(type(sch))
block_c = sch.get_block("C")
# Get loops surronding the block
(y, x, k) = sch.get_loops(block_c)
block_size = 32
yo, yi = sch.split(y, [None, block_size])
xo, xi = sch.split(x, [None, block_size])
sch.reorder(yo, xo, k, yi, xi)
Code(sch.mod.script(), language="python")
func = tvm.build(sch.mod, target="llvm") # The module for CPU backends.
c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
print("after transformation: %f" % evaluator(a, b, c).mean)
<class 'tvm.tir.schedule.schedule.Schedule'>
after transformation: 0.239306
试着改变 bn 的值,看看你能得到什么性能。在实践中,将利用自动系统来搜索一组可能的变换,以找到最优的变换。