{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TVM 中的调度原语\n", "\n", "**原作者**: [Ziheng Jiang](https://github.com/ZihengJiang)\n", "\n", "TVM 用于高效构建 kernel 的领域特定语言。\n", "\n", "在本教程中,将您展示如何通过 TVM 提供的各种原语调度计算。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import tvm\n", "from tvm import te\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通常有几种方法可以计算相同的结果,但是,不同的方法会导致不同的局部性(locality)和性能。因此 TVM 要求用户提供如何执行名为 **Schedule** (调度)的计算。\n", "\n", "**Schedule** 是一组变换程序中计算循环的计算变换。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# 声明一些变量以备以后使用\n", "n = te.var(\"n\")\n", "m = te.var(\"m\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "调度可以从 ops 列表中创建,默认情况下,调度以 row-major 顺序的串行方式计算张量。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# 声明矩阵元素级的乘法\n", "A = te.placeholder((m, n), name=\"A\")\n", "B = te.placeholder((m, n), name=\"B\")\n", "C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name=\"C\")\n", "\n", "s = te.create_schedule([C.op])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`lower` 将计算从定义转换为实际的可调用函数。使用 `simple_mode=True` 参数,它将返回可读的 C like 语句,在这里使用它来打印调度结果。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle, C: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " n \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_4 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_5 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " C_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(C, [stride_2 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride, stride_3], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1, stride_4], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(C_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mC_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_2, stride_5], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i, j \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(m, n):\n", " C_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_2 \u001b[38;5;129;01m+\u001b[39;00m j \u001b[38;5;129;01m*\u001b[39;00m stride_5] \u001b[38;5;129;01m=\u001b[39;00m A_1[i \u001b[38;5;129;01m*\u001b[39;00m stride \u001b[38;5;129;01m+\u001b[39;00m j \u001b[38;5;129;01m*\u001b[39;00m stride_3] \u001b[38;5;129;01m*\u001b[39;00m B_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_1 \u001b[38;5;129;01m+\u001b[39;00m j \u001b[38;5;129;01m*\u001b[39;00m stride_4]\n", " \n", "\n" ] } ], "source": [ "tvm.lower(s, [A, B, C], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "每个调度由多个阶段(Stage)组成,每个阶段表示一个运算的调度。\n", "\n", "下面提供各种方法来调度每个阶段。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## split\n", "\n", "`split` 可以通过 `factor` 将指定的轴分裂(split)为两个轴。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_outer, i_inner \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " cse_var_1: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner\n", " B_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m=\u001b[39;00m A_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride] \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m2\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_outer, i_inner \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid((m \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m31\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner \u001b[38;5;129;01m<\u001b[39;00m m:\n", " B_1[(m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner) \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m=\u001b[39;00m A_1[(m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner) \u001b[38;5;129;01m*\u001b[39;00m stride] \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m2\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "m = te.var(\"m\")\n", "A = te.placeholder((m,), name=\"A\")\n", "B = te.compute((m,), lambda i: A[i] * 2, name=\"B\")\n", "\n", "s = te.create_schedule(B.op)\n", "xo, xi = s[B].split(B.op.axis[0], factor=32)\n", "tvm.lower(s, [A, B]).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "你也可以通过 `nparts` 分裂轴,它与 `factor` 分割轴相对。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_outer, i_inner \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, (m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m31\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(i_inner \u001b[38;5;129;01m+\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m ((m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m31\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m) \u001b[38;5;129;01m<\u001b[39;00m m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " B_1[(i_inner \u001b[38;5;129;01m+\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m ((m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m31\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m)) \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m=\u001b[39;00m A_1[(i_inner \u001b[38;5;129;01m+\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m ((m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m31\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m32\u001b[39m)) \u001b[38;5;129;01m*\u001b[39;00m stride]\n", " \n", "\n" ] } ], "source": [ "m = te.var(\"m\")\n", "A = te.placeholder((m,), name=\"A\")\n", "B = te.compute((m,), lambda i: A[i], name=\"B\")\n", "\n", "s = te.create_schedule(B.op)\n", "bx, tx = s[B].split(B.op.axis[0], nparts=32)\n", "tvm.lower(s, [A, B], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## tile\n", "\n", "`tile` 帮助你在两个轴上逐块(tile by tile)执行计算。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " n \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride, stride_2], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1, stride_3], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_outer, j_outer, i_inner \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid((m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m9\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m10\u001b[39m, (n \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m4\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m5\u001b[39m, \u001b[38;5;28m10\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m10\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner \u001b[38;5;129;01m<\u001b[39;00m m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m j_inner \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m5\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(j_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_inner \u001b[38;5;129;01m<\u001b[39;00m n, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " cse_var_2: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m j_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_inner\n", " cse_var_1: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m10\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner\n", " B_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride_1 \u001b[38;5;129;01m+\u001b[39;00m cse_var_2 \u001b[38;5;129;01m*\u001b[39;00m stride_3] \u001b[38;5;129;01m=\u001b[39;00m A_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride \u001b[38;5;129;01m+\u001b[39;00m cse_var_2 \u001b[38;5;129;01m*\u001b[39;00m stride_2]\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((m, n), name=\"A\")\n", "B = te.compute((m, n), lambda i, j: A[i, j], name=\"B\")\n", "\n", "s = te.create_schedule(B.op)\n", "xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)\n", "tvm.lower(s, [A, B], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## fuse\n", "\n", "`fuse` 可以融合一个计算的两个连续轴。\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " n \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride, stride_2], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1, stride_3], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_outer, j_outer, i_inner_j_inner_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid((m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m9\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m10\u001b[39m, (n \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m4\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m5\u001b[39m, \u001b[38;5;28m50\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m10\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner_j_inner_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m<\u001b[39;00m m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(j_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner_j_inner_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m<\u001b[39;00m n, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " cse_var_2: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m j_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner_j_inner_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m5\u001b[39m\n", " cse_var_1: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m10\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner_j_inner_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m5\u001b[39m\n", " B_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride_1 \u001b[38;5;129;01m+\u001b[39;00m cse_var_2 \u001b[38;5;129;01m*\u001b[39;00m stride_3] \u001b[38;5;129;01m=\u001b[39;00m A_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride \u001b[38;5;129;01m+\u001b[39;00m cse_var_2 \u001b[38;5;129;01m*\u001b[39;00m stride_2]\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((m, n), name=\"A\")\n", "B = te.compute((m, n), lambda i, j: A[i, j], name=\"B\")\n", "\n", "s = te.create_schedule(B.op)\n", "# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)\n", "xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)\n", "# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)\n", "fused = s[B].fuse(xi, yi)\n", "tvm.lower(s, [A, B], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## reorder\n", ":code:`reorder` can reorder the axes in the specified order.\n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " n \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride, stride_2], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m, n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1, stride_3], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_inner, j_outer, i_outer \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m10\u001b[39m, (n \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m4\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m5\u001b[39m, (m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m9\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m10\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m10\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner \u001b[38;5;129;01m<\u001b[39;00m m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m j_inner \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m5\u001b[39m):\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(j_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_inner \u001b[38;5;129;01m<\u001b[39;00m n, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " cse_var_2: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m j_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m5\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_inner\n", " cse_var_1: T\u001b[38;5;129;01m.\u001b[39;00mint32 \u001b[38;5;129;01m=\u001b[39;00m i_outer \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m10\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_inner\n", " B_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride_1 \u001b[38;5;129;01m+\u001b[39;00m cse_var_2 \u001b[38;5;129;01m*\u001b[39;00m stride_3] \u001b[38;5;129;01m=\u001b[39;00m A_1[cse_var_1 \u001b[38;5;129;01m*\u001b[39;00m stride \u001b[38;5;129;01m+\u001b[39;00m cse_var_2 \u001b[38;5;129;01m*\u001b[39;00m stride_2]\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((m, n), name=\"A\")\n", "B = te.compute((m, n), lambda i, j: A[i, j], name=\"B\")\n", "\n", "s = te.create_schedule(B.op)\n", "# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)\n", "xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)\n", "# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)\n", "s[B].reorder(xi, yo, xo, yi)\n", "tvm.lower(s, [A, B], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## bind\n", ":code:`bind` can bind a specified axis with a thread axis, often used\n", "in gpu programming.\n", "\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# var definition\u001b[39;00m\n", " threadIdx_x \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00menv_thread(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " blockIdx_x \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00menv_thread(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " n \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [n], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mlaunch_thread(blockIdx_x, (n \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m63\u001b[39m) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mlaunch_thread(threadIdx_x, \u001b[38;5;28m64\u001b[39m)\n", " \u001b[38;5;28;01mif\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mlikely(blockIdx_x \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m threadIdx_x \u001b[38;5;129;01m<\u001b[39;00m n, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " B_1[(blockIdx_x \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m threadIdx_x) \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m=\u001b[39;00m A_1[(blockIdx_x \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m threadIdx_x) \u001b[38;5;129;01m*\u001b[39;00m stride] \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m2\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((n,), name=\"A\")\n", "B = te.compute(A.shape, lambda i: A[i] * 2, name=\"B\")\n", "\n", "s = te.create_schedule(B.op)\n", "bx, tx = s[B].split(B.op.axis[0], factor=64)\n", "s[B].bind(bx, te.thread_axis(\"blockIdx.x\"))\n", "s[B].bind(tx, te.thread_axis(\"threadIdx.x\"))\n", "tvm.lower(s, [A, B], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## compute_at\n", "For a schedule that consists of multiple operators, TVM will compute\n", "tensors at the root separately by default.\n", "\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle, C: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " C_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(C, [stride_2 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(C_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mC_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_2], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(m):\n", " B_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m=\u001b[39;00m A_1[i \u001b[38;5;129;01m*\u001b[39;00m stride] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(m):\n", " C_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_2] \u001b[38;5;129;01m=\u001b[39;00m B_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m2\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((m,), name=\"A\")\n", "B = te.compute((m,), lambda i: A[i] + 1, name=\"B\")\n", "C = te.compute((m,), lambda i: B[i] * 2, name=\"C\")\n", "\n", "s = te.create_schedule(C.op)\n", "tvm.lower(s, [A, B, C], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ":code:`compute_at` can move computation of `B` into the first axis\n", "of computation of `C`.\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle, C: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " C_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(C, [stride_2 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(C_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mC_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_2], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(m):\n", " B_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m=\u001b[39;00m A_1[i \u001b[38;5;129;01m*\u001b[39;00m stride] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " C_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_2] \u001b[38;5;129;01m=\u001b[39;00m B_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m2\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((m,), name=\"A\")\n", "B = te.compute((m,), lambda i: A[i] + 1, name=\"B\")\n", "C = te.compute((m,), lambda i: B[i] * 2, name=\"C\")\n", "\n", "s = te.create_schedule(C.op)\n", "s[B].compute_at(s[C], C.op.axis[0])\n", "tvm.lower(s, [A, B, C], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## compute_inline\n", ":code:`compute_inline` can mark one stage as inline, then the body of\n", "computation will be expanded and inserted at the address where the\n", "tensor is required.\n", "\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle, C: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " C_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(C, [stride_2 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(C_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mC_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_2], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(m):\n", " C_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_2] \u001b[38;5;129;01m=\u001b[39;00m (A_1[i \u001b[38;5;129;01m*\u001b[39;00m stride] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m2\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((m,), name=\"A\")\n", "B = te.compute((m,), lambda i: A[i] + 1, name=\"B\")\n", "C = te.compute((m,), lambda i: B[i] * 2, name=\"C\")\n", "\n", "s = te.create_schedule(C.op)\n", "s[B].compute_inline()\n", "tvm.lower(s, [A, B, C], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## compute_root\n", ":code:`compute_root` can move computation of one stage to the root.\n", "\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mhandle, B: T\u001b[38;5;129;01m.\u001b[39;00mhandle, C: T\u001b[38;5;129;01m.\u001b[39;00mhandle) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_legacy_te_schedule\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " m \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " stride_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(A, [stride \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " B_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(B, [stride_1 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " C_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(C, [stride_2 \u001b[38;5;129;01m*\u001b[39;00m m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(A_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mA_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(B_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mB_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_1], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mpreflattened_buffer(C_1, [m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;129;01m=\u001b[39;00mC_1\u001b[38;5;129;01m.\u001b[39;00mdata, strides\u001b[38;5;129;01m=\u001b[39;00m[stride_2], type\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(m):\n", " B_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m=\u001b[39;00m A_1[i \u001b[38;5;129;01m*\u001b[39;00m stride] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m1\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(m):\n", " C_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_2] \u001b[38;5;129;01m=\u001b[39;00m B_1[i \u001b[38;5;129;01m*\u001b[39;00m stride_1] \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m2\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "A = te.placeholder((m,), name=\"A\")\n", "B = te.compute((m,), lambda i: A[i] + 1, name=\"B\")\n", "C = te.compute((m,), lambda i: B[i] * 2, name=\"C\")\n", "\n", "s = te.create_schedule(C.op)\n", "s[B].compute_at(s[C], C.op.axis[0])\n", "s[B].compute_root()\n", "tvm.lower(s, [A, B, C], simple_mode=True).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "This tutorial provides an introduction to schedule primitives in\n", "tvm, which permits users schedule the computation easily and\n", "flexibly.\n", "\n", "In order to get a good performance kernel implementation, the\n", "general workflow often is:\n", "\n", "- Describe your computation via series of operations.\n", "- Try to schedule the computation with primitives.\n", "- Compile and run to see the performance difference.\n", "- Adjust your schedule according the running result.\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('py38': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "vscode": { "interpreter": { "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7" } } }, "nbformat": 4, "nbformat_minor": 0 }