{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Mpn1ti5Urdsv" }, "source": [ "# Ep6: Integration with Machine Learning Frameworks" ] }, { "cell_type": "markdown", "metadata": { "id": "qXysoqn-vZuF" }, "source": [ "## Install Packages\n", "\n", "For the purpose of this course, we will use some ongoing development in TVM, an open-source machine learning compilation framework. The following command installs a packaged version prepared for the MLC course." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Xe3vClsD9jlq", "outputId": "26a1bfbf-9182-4cfe-d507-042d11b9225a" }, "outputs": [], "source": [ "# !python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels" ] }, { "cell_type": "markdown", "metadata": { "id": "i-14C4skxIrJ" }, "source": [ "## Prelude\n", "\n", "In the past chapters, we learned about abstractions for machine learning compilation and transformations among tensor functions.\n", "\n", "This chapter discusses how to bring machine learning models from existing ML frameworks into an MLC flow." 
] }, { "cell_type": "markdown", "metadata": { "id": "BBIuE2jc1DaU" }, "source": [ "## Preparations\n", "\n", "To begin with, we import the necessary dependencies.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "BVp0fHyRkYj6" }, "outputs": [], "source": [ "# This is needed for deferring annotation parsing in TVMScript.\n", "# A __future__ import must appear before any other statement in the cell.\n", "from __future__ import annotations\n", "\n", "import tvm\n", "from tvm.ir.module import IRModule\n", "from tvm.script import tir as T, relax as R\n", "from tvm import relax\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "6saIbYSCrZF7" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch import fx\n", "from torch.nn import functional as F" ] }, { "cell_type": "markdown", "metadata": { "id": "8yH4IMSMvF9o" }, "source": [ "## Build an IRModule Through a Builder\n", "\n", "In the past chapters, we built IRModules by writing TVMScript directly. As models get larger, we need a programmatic way to build an IRModule. In this section, let us review some of the tools that support this process.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "uahNKehVr2gg" }, "source": [ "### Tensor Expression for TensorIR Creation\n", "\n", "First, we review the tensor expression domain-specific language for building TensorIR functions.\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "rxjwYscu4ukP" }, "outputs": [], "source": [ "from tvm import te" ] }, { "cell_type": "markdown", "metadata": { "id": "yk63AG3t6Kdu" }, "source": [ "We begin by creating a placeholder object, which represents an input to a TensorIR function." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "s_IUriaf590M" }, "outputs": [], "source": [ "A = te.placeholder((128, 128), name=\"A\", dtype=\"float32\")\n", "B = te.placeholder((128, 128), name=\"B\", dtype=\"float32\")" ] }, { "cell_type": "markdown", "metadata": { "id": "bOfyFOkS6YVO" }, "source": [ "Each input and intermediate result here is represented as a `te.Tensor` object." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "caakIb3B59_Q", "outputId": "222c77b2-1d65-43ea-d079-4e67743c7d50" }, "outputs": [ { "data": { "text/plain": [ "tvm.te.tensor.Tensor" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(A)" ] }, { "cell_type": "markdown", "metadata": { "id": "yN6-WogI6l8k" }, "source": [ "Each `te.Tensor` has a `shape` field and a `dtype` field that track the shape\n", "and data type of the computation." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_5SiENj96VMa", "outputId": "9f77fccf-4467-4d82-cc67-e3c8afd07d35" }, "outputs": [ { "data": { "text/plain": [ "[128, 128]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "GWg6lEgR5P7i" }, "source": [ "We can describe computations through a sequence of tensor expressions. Here `te.compute` takes the signature `te.compute(output_shape, fcompute)`, where the `fcompute` function describes how to compute the value of each element `[i, j]` from its index.\n", "\n", "The `te_matmul` function takes in two objects of type `te.Tensor` and returns the matrix multiplication result. Note how we build up the computation based on the input shapes of A and B. 
The `te_matmul` function therefore works for A and B with different input shapes.\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "zlvL1Zfkt9A4" }, "outputs": [], "source": [ "def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:\n", "    assert A.shape[1] == B.shape[0]\n", "    n = A.shape[0]\n", "    m = B.shape[1]\n", "    k = te.reduce_axis((0, A.shape[1]), name=\"k\")\n", "    return te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name=\"matmul\")" ] }, { "cell_type": "markdown", "metadata": { "id": "dkqrFnHA7er4" }, "source": [ "We can create the result of matmul by calling `te_matmul` with A and B." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "ewHuGnPv7EyV" }, "outputs": [], "source": [ "C = te_matmul(A, B)" ] }, { "cell_type": "markdown", "metadata": { "id": "UJKUFKwq9D5d" }, "source": [ "To create a TensorIR function, we can call `te.create_prim_func` and pass in the input and output values." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ibJ1f-TI8k2F", "outputId": "e7e005fc-d6bf-49e7-8fba-45bbf6b5a719" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;30;03m# from tvm.script import tir as T\u001b[39;00m\n", "\u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", "\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], matmul: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1, i2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i, j, k \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1, i2])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[i, k], B[k, j])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(matmul[i, j])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m matmul[i, j] \u001b[38;5;129;01m+\u001b[39;00m A[i, k] \u001b[38;5;129;01m*\u001b[39;00m B[k, j]\n", "\n" ] } ], "source": [ "te.create_prim_func([A, B, C]).show()" ] }, { "cell_type": "markdown", "metadata": { "id": "TeGKASD09NIJ" }, "source": [ "We can create a tensor 
expression for the relu computation in a similar fashion. Here we write it so that the `te_relu` function works for a `te.Tensor` of any dimension and shape." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "GSeib3RT778t" }, "outputs": [], "source": [ "def te_relu(A: te.Tensor) -> te.Tensor:\n", "    return te.compute(A.shape, lambda *i: te.max(A(*i), 0), name=\"relu\")" ] }, { "cell_type": "markdown", "metadata": { "id": "CwsiEWlU91UY" }, "source": [ "Let us try out `te_relu` on inputs of two different shapes and dimensions. First, `X1` with shape `(10,)`." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PYRH0bCV8k9B", "outputId": "f4a27fd7-173f-438f-eb49-e83b53360ba6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;30;03m# from tvm.script import tir as T\u001b[39;00m\n", "\u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", "\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc\u001b[39m(X1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m10\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], relu: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m10\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with 
T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m10\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m10\u001b[39m, i0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X1[i0_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(relu[i0_1])\n", " relu[i0_1] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmax(X1[i0_1], T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m))\n", "\n" ] } ], "source": [ "X1 = te.placeholder((10,), name=\"X1\", dtype=\"float32\")\n", "Y1 = te_relu(X1)\n", "te.create_prim_func([X1, Y1]).show()" ] }, { "cell_type": "markdown", "metadata": { "id": "DDPuxjGy-AiV" }, "source": [ "Then `X2` with shape `(10, 20)`." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CXN_tu5k9o4u", "outputId": "16099283-d781-4635-d173-ba915cd60868" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;30;03m# from tvm.script import tir as T\u001b[39;00m\n", "\u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", "\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc\u001b[39m(X1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m10\u001b[39m, \u001b[38;5;28m20\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], relu: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m10\u001b[39m, \u001b[38;5;28m20\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m10\u001b[39m, \u001b[38;5;28m20\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1, i1_1 \u001b[38;5;129;01m=\u001b[39;00m 
T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X1[i0_1, i1_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(relu[i0_1, i1_1])\n", " relu[i0_1, i1_1] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmax(X1[i0_1, i1_1], T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m))\n", "\n" ] } ], "source": [ "X2 = te.placeholder((10, 20), name=\"X1\", dtype=\"float32\")\n", "Y2 = te_relu(X2)\n", "te.create_prim_func([X2, Y2]).show()" ] }, { "cell_type": "markdown", "metadata": { "id": "1slxb7C6-FT3" }, "source": [ "One final thing the `te` API allows us to do is compose operations to create \"fused\" operators. For example, we can take the result of matmul and then apply relu to it. " ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "5h7i20Qd9pEA" }, "outputs": [], "source": [ "C = te_matmul(A, B)\n", "D = te_relu(C)" ] }, { "cell_type": "markdown", "metadata": { "id": "Q2vxDWDq-keD" }, "source": [ "We can create a TensorIR function by passing only the input and output values of interest and skipping the intermediate values. In that case, the result of matmul is allocated as a temporary buffer inside the TensorIR function." 
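, "\n",
"For intuition, the fused function can be mirrored by a small NumPy reference. This is a sketch written for illustration, not something generated by TVM, and the helper name `fused_matmul_relu_ref` is hypothetical: the intermediate `matmul` array plays the role of the temporary buffer that the TensorIR function allocates internally, so it never appears in the argument list.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def fused_matmul_relu_ref(a: np.ndarray, b: np.ndarray) -> np.ndarray:\n",
"    # Hypothetical reference helper, not part of TVM.\n",
"    assert a.shape[1] == b.shape[0]\n",
"    # Temporary result of the matmul stage (internal to the fused function).\n",
"    matmul = a @ b\n",
"    # The relu stage reads the temporary and writes the only output.\n",
"    return np.maximum(matmul, 0)\n",
"\n",
"\n",
"a_np = np.random.uniform(-1, 1, size=(128, 128)).astype(np.float32)\n",
"b_np = np.random.uniform(-1, 1, size=(128, 128)).astype(np.float32)\n",
"relu_np = fused_matmul_relu_ref(a_np, b_np)\n",
"print(relu_np.shape, relu_np.min() >= 0)\n",
"```\n",
"\n",
"A reference like this can be used to sanity-check results once the function is compiled and run."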
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kYqc5Ft--jHR", "outputId": "c4b17b9f-d426-4372-f976-fe8e475b676d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;30;03m# from tvm.script import tir as T\u001b[39;00m\n", "\u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", "\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], relu: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " matmul \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " 
\u001b[38;5;28;01mfor\u001b[39;00m i0, i1, i2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i, j, k \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1, i2])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[i, k], B[k, j])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(matmul[i, j])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m matmul[i, j] \u001b[38;5;129;01m+\u001b[39;00m A[i, k] \u001b[38;5;129;01m*\u001b[39;00m B[k, j]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1, i1_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(matmul[i0_1, i1_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(relu[i0_1, i1_1])\n", " relu[i0_1, i1_1] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmax(matmul[i0_1, i1_1], T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m))\n", "\n" ] } 
], "source": [ "te.create_prim_func([A, B, D]).show()" ] }, { "cell_type": "markdown", "metadata": { "id": "3DVxa0rO_z26" }, "source": [ "We can also pass the intermediate result C into the argument list. In this case, the TensorIR function expects the caller to also pass in the buffer for C. We normally recommend passing in only the inputs and outputs, so that more advanced fusion can happen inside the function." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rwzFD4bt-wm3", "outputId": "359f4aa0-5f08-4ee6-94ac-33b5e4e17898" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;30;03m# from tvm.script import tir as T\u001b[39;00m\n", "\u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", "\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc\u001b[39m(A: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], B: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], matmul: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], relu: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1, i2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i, j, k \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1, i2])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(A[i, k], B[k, j])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(matmul[i, j])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m matmul[i, j] \u001b[38;5;129;01m+\u001b[39;00m A[i, k] \u001b[38;5;129;01m*\u001b[39;00m B[k, j]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1, i1_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, 
i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(matmul[i0_1, i1_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(relu[i0_1, i1_1])\n", " relu[i0_1, i1_1] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmax(matmul[i0_1, i1_1], T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m))\n", "\n" ] } ], "source": [ "te.create_prim_func([A, B, C, D]).show()" ] }, { "cell_type": "markdown", "metadata": { "id": "oOTcJQDB7EeJ" }, "source": [ "### Use BlockBuilder to Create an IRModule" ] }, { "cell_type": "markdown", "metadata": { "id": "Vm71Tp87Qz8M" }, "source": [ "So far, we have created individual TensorIR functions. To build end-to-end model execution, we also need to connect multiple TensorIR functions through a computational graph.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "SnaE0Szgu5cq" }, "source": [ "Let us first create the inputs of the function as `relax.Var` objects. They will become the parameters of the `relax.Function` that we build incrementally with a block builder." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "0xwqttTtAESi" }, "outputs": [], "source": [ "A = relax.Var(\"A\", (128, 128), relax.DynTensorType(2, \"float32\"))\n", "B = relax.Var(\"B\", (128, 128), relax.DynTensorType(2, \"float32\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "IlXZjFsT02-Y" }, "source": [ "We construct the relax function by creating a block builder and then emitting a sequence of primitive tensor operations." 
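, "\n",
"To make the idea of incremental construction concrete, here is a toy, pure-Python analogue of the block builder pattern. It is only an illustration under simplifying assumptions: `ToyBuilder`, `emit`, and `run` are hypothetical names, not TVM APIs, and NumPy functions stand in for TensorIR functions. Each `emit` call records one binding and returns the name of a new variable, much as `bb.emit_te` returns a `relax.Var` that later calls can consume.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"class ToyBuilder:\n",
"    # Hypothetical toy, not TVM's BlockBuilder.\n",
"    def __init__(self):\n",
"        # Ordered list of bindings: (variable name, function, argument names).\n",
"        self.bindings = []\n",
"\n",
"    def emit(self, fn, *args):\n",
"        # Name new variables lv, lv1, lv2, ... like the printed relax function.\n",
"        name = 'lv' if not self.bindings else 'lv%d' % len(self.bindings)\n",
"        self.bindings.append((name, fn, args))\n",
"        return name\n",
"\n",
"    def run(self, params, inputs, output):\n",
"        # Interpret the recorded graph: bind parameters, then evaluate in order.\n",
"        env = dict(zip(params, inputs))\n",
"        for name, fn, args in self.bindings:\n",
"            env[name] = fn(*[env[a] for a in args])\n",
"        return env[output]\n",
"\n",
"\n",
"toy_bb = ToyBuilder()\n",
"c = toy_bb.emit(np.matmul, 'A', 'B')\n",
"d = toy_bb.emit(lambda x: np.maximum(x, 0), c)\n",
"a_np = np.random.uniform(-1, 1, size=(128, 128)).astype(np.float32)\n",
"b_np = np.random.uniform(-1, 1, size=(128, 128)).astype(np.float32)\n",
"out = toy_bb.run(['A', 'B'], [a_np, b_np], d)\n",
"print(c, d, out.shape)\n",
"```\n",
"\n",
"The real block builder additionally creates TensorIR functions, carries shape and type annotations, and groups bindings into dataflow blocks; the toy only captures the emit-and-bind pattern."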
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Hz4UojV8vT8Y", "outputId": "e7cb31e3-d7f3-4cba-9042-f1ea5a140d36" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mte_matmul\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], rxplaceholder_1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], matmul: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mte_matmul\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1, i2 \u001b[38;5;28;01min\u001b[39;00m 
T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i, j, k \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1, i2])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[i, k], rxplaceholder_1[k, j])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(matmul[i, j])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m matmul[i, j] \u001b[38;5;129;01m+\u001b[39;00m rxplaceholder[i, k] \u001b[38;5;129;01m*\u001b[39;00m rxplaceholder_1[k, j]\n", " \n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mte_relu\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], relu: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mte_relu\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1, i1_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[i0_1, i1_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(relu[i0_1, i1_1])\n", " relu[i0_1, i1_1] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmax(rxplaceholder[i0_1, i1_1], T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m))\n", " \n", " \u001b[38;5;129m@R\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mfunction\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(A: Tensor((\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m), B: Tensor((\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m Tensor(\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, ndim \u001b[38;5;129;01m=\u001b[39;00m \u001b[38;5;28m2\u001b[39m):\n", " 
\u001b[38;5;30;03m# block 0\u001b[39;00m\n", " \u001b[38;5;28;01mwith\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mdataflow():\n", " lv \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(te_matmul, (A, B), (\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " lv1 \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(te_relu, (lv,), (\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " gv: Tensor((\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m lv1\n", " R\u001b[38;5;129;01m.\u001b[39;00moutput(gv)\n", " \u001b[38;5;28;01mreturn\u001b[39;00m gv\n", " \n", "\n" ] } ], "source": [ "bb = relax.BlockBuilder()\n", "\n", "with bb.function(\"main\"):\n", "    with bb.dataflow():\n", "        C = bb.emit_te(te_matmul, A, B)\n", "        D = bb.emit_te(te_relu, C)\n", "        R = bb.emit_output(D)\n", "    bb.emit_func_output(R, params=[A, B])\n", "\n", "MyModule = bb.get()\n", "MyModule.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "2yBnEX5C2lG5" }, "source": [ "### Deep Dive into Block Builder APIs\n", "\n", "Now let us take a deep dive into each block builder API. It helps to put the block builder code and the resulting module side by side." ] }, { "cell_type": "markdown", "metadata": { "id": "erv3pnzy25-7" }, "source": [ "The block builder comes with scopes that correspond to the scopes in the relax function. 
For example, `bb.dataflow()` creates a dataflow\n", "block, and every block builder call inside that scope belongs to the dataflow block.\n", "\n", "```python\n", "with bb.function(\"main\"):\n", "    with bb.dataflow():\n", "        # every emit call generates a variable inside a dataflow block.\n", "        ...\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "LA-0uerSAnUC" }, "source": [ "Each intermediate result is a `relax.Var` corresponding to a variable that stores the result of the computation. `DataflowVar`, a subclass of `relax.Var`, indicates that the variable is an intermediate step inside a dataflow block (computational graph)." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "H8l9fbJQAlY_", "outputId": "2eca5a8b-7290-447c-854d-99f99d8d5ff6" }, "outputs": [ { "data": { "text/plain": [ "tvm.relax.expr.DataflowVar" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(C)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ufvwcuD5ytYk", "outputId": "fd34141a-0d32-4a80-8e1f-77e890a55c1b" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "isinstance(C, relax.Var)" ] }, { "cell_type": "markdown", "metadata": { "id": "f61-biEx82nr" }, "source": [ "Each line in the relax function is generated by an `emit_te` call. 
For example,\n", "\n", "```python\n", "lv = R.call_tir(te_matmul, (A, B), (128, 128), dtype=\"float32\")\n", "```\n", "is generated by\n", "```python\n", "C = bb.emit_te(te_matmul, A, B)\n", "```\n", "Under the hood, `bb.emit_te` does the following:\n", "- Create an input `te.placeholder` for each of `A` and `B`.\n", "- Run the placeholders through the `te_matmul` function.\n", "- Call `te.create_prim_func` to create a TensorIR function.\n", "- Generate a call into that function via `call_tir`.\n", "\n", "The result is a computational graph with two intermediate values: one node corresponds to the `te_matmul` operation and the other to `te_relu`." ] }, { "cell_type": "markdown", "metadata": { "id": "hWd1GbGHBZlt" }, "source": [ "We can create the output variables of each dataflow block through `bb.emit_output`.\n", "\n", "```python\n", "with bb.dataflow():\n", "    ...\n", "    R = bb.emit_output(D)\n", "```\n", "The above code marks `D` as a variable that can be referred to outside of the dataflow block.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FrqQCHu3BrGC" }, "source": [ "Finally, the function output is marked by `bb.emit_func_output`. We can only call `emit_func_output` once in each function scope.\n", "\n", "Notably, we can specify the list of parameters of the function in the output emission stage. 
Doing so helps in cases where we collect the list of parameters on the fly.\n", "\n", "```python\n", "with bb.function(\"main\"):\n", "    ...\n", "    # specify parameters at the end\n", "    bb.emit_func_output(R, params=[A, B])\n", "```\n", "Alternatively, we can specify the list of parameters at the beginning of the function scope.\n", "```python\n", "# specify parameters at the beginning\n", "with bb.function(\"main\", params=[A, B]):\n", "    ...\n", "    bb.emit_func_output(R)\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "id": "a-c8N-7sEa1M" }, "source": [ "## Import Model From PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "PUcCRU2IQPm-" }, "source": [ "Now that we have learned the tools to construct an IRModule programmatically, let us use them to bring a model from PyTorch into the IRModule format.\n", "\n", "Most machine learning frameworks come with computational graph abstractions, where each node corresponds to an operation and the edges correspond to the dependencies among them. We will take a PyTorch model, obtain a computational graph in PyTorch's native format, and translate that graph into an IRModule.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HJFs5zDkGAOU" }, "source": [ "Let us begin by defining a model in PyTorch. To keep the example consistent, we will reuse the matmul + ReLU example." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "F6gzwCFfF6V9" }, "outputs": [], "source": [ "class MyModel(nn.Module):\n", "    def __init__(self):\n", "        super(MyModel, self).__init__()\n", "        self.weight = nn.Parameter(torch.randn(128, 128))\n", "\n", "    def forward(self, x):\n", "        x = torch.matmul(x, self.weight)\n", "        x = torch.relu(x)\n", "        return x" ] }, { "cell_type": "markdown", "metadata": { "id": "PoX9dyqVG0ZB" }, "source": [ "### Create TorchFX GraphModule\n", "\n", "We use TorchFX to trace a graph from the PyTorch module."
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4vGt8XleGH5q", "outputId": "1ffd6f09-38c4-4902-b4bf-76ccf99d6ad4" }, "outputs": [ { "data": { "text/plain": [ "torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = MyModel()\n", "fx_module = fx.symbolic_trace(model)\n", "type(fx_module)" ] }, { "cell_type": "markdown", "metadata": { "id": "eYJUSv2qHm5x" }, "source": [ "The `fx_module` contains a simple computation graph view that can be printed as tabular data. Our goal is to translate this graph into an IRModule." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oxvS3jjmGztF", "outputId": "970a7536-c672-48fa-9b8e-fadeffb8edf1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "opcode name target args kwargs\n", "------------- ------ --------------------------------------------------------- ----------- --------\n", "placeholder x x () {}\n", "get_attr weight weight () {}\n", "call_function matmul (x, weight) {}\n", "call_function relu (matmul,) {}\n", "output output output (relu,) {}\n" ] } ], "source": [ "fx_module.graph.print_tabular()" ] }, { "cell_type": "markdown", "metadata": { "id": "RjVbP-mmIdUs" }, "source": [ "### Create Map Function\n", "\n", "Let us define the overall high-level translation logic. The main flow is as follows:\n", "- Create a `node_map` that maps each `fx.Node` to the corresponding `relax.Var` that represents the translated node in the IRModule.\n", "- Iterate over the nodes in the fx graph in topological order.\n", "- Compute the mapped output of each node given its mapped inputs."
] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "U6t1bQ7kOmLs" }, "outputs": [], "source": [ "def map_param(param: nn.Parameter):\n", "    ndim = len(param.data.shape)\n", "    return relax.const(\n", "        param.data.cpu().numpy(), relax.DynTensorType(ndim, \"float32\")\n", "    )\n", "\n", "def fetch_attr(fx_mod, target: str):\n", "    \"\"\"Helper function to fetch an attr\"\"\"\n", "    target_atoms = target.split('.')\n", "    attr_itr = fx_mod\n", "    for i, atom in enumerate(target_atoms):\n", "        if not hasattr(attr_itr, atom):\n", "            raise RuntimeError(f\"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}\")\n", "        attr_itr = getattr(attr_itr, atom)\n", "    return attr_itr\n", "\n", "def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):\n", "    input_index = 0\n", "    node_map = {}\n", "    named_modules = dict(fx_mod.named_modules())\n", "\n", "    bb = relax.BlockBuilder()\n", "\n", "    fn_inputs = []\n", "    fn_output = None\n", "    with bb.function(\"main\"):\n", "        with bb.dataflow():\n", "            for node in fx_mod.graph.nodes:\n", "                if node.op == \"placeholder\":\n", "                    # create input placeholder\n", "                    shape = input_shapes[input_index]\n", "                    input_index += 1\n", "                    input_var = relax.Var(\n", "                        node.target, shape, relax.DynTensorType(len(shape), \"float32\")\n", "                    )\n", "                    fn_inputs.append(input_var)\n", "                    node_map[node] = input_var\n", "                elif node.op == \"get_attr\":\n", "                    node_map[node] = map_param(fetch_attr(fx_mod, node.target))\n", "                elif node.op == \"call_function\":\n", "                    node_map[node] = call_function_map[node.target](bb, node_map, node)\n", "                elif node.op == \"call_module\":\n", "                    named_module = named_modules[node.target]\n", "                    node_map[node] = call_module_map[type(named_module)](bb, node_map, node, named_module)\n", "                elif node.op == \"output\":\n", "                    output = node_map[node.args[0]]\n", "                    assert fn_output is None\n", "                    # mark the result as visible outside the dataflow block\n", "                    fn_output = bb.emit_output(output)\n", "        # output and finalize the function\n", "        bb.emit_func_output(fn_output, fn_inputs)\n", "    return bb.get()" ] }, { "cell_type": "markdown", "metadata": { "id": "HBlaVkysfrCu" }, "source": [ "Note that `from_fx` does not hard-code how each operator is translated. Instead, we supply the translation rule for each torch function via a map. The following code block shows how to do that through the `emit_te` API." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CKqkz9jrfrbh", "outputId": "cbd008cf-6862-48fe-db45-59931c0fd68c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mte_matmul\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], rxplaceholder_1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], matmul: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mte_matmul\u001b[39m\u001b[38;5;124m\"\u001b[39m, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1, i2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i, j, k \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1, i2])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[i, k], rxplaceholder_1[k, j])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(matmul[i, j])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " matmul[i, j] \u001b[38;5;129;01m=\u001b[39;00m matmul[i, j] \u001b[38;5;129;01m+\u001b[39;00m rxplaceholder[i, k] \u001b[38;5;129;01m*\u001b[39;00m rxplaceholder_1[k, j]\n", " \n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mte_relu\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], relu: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) 
\u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mte_relu\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1, i1_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[i0_1, i1_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(relu[i0_1, i1_1])\n", " relu[i0_1, i1_1] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmax(rxplaceholder[i0_1, i1_1], T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m))\n", " \n", " \u001b[38;5;129m@R\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mfunction\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(x: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m Tensor(\u001b[38;5;28;01mNone\u001b[39;00m, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, ndim \u001b[38;5;129;01m=\u001b[39;00m \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;30;03m# block 0\u001b[39;00m\n", " \u001b[38;5;28;01mwith\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mdataflow():\n", " lv \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(te_matmul, (x, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m0\u001b[39m]), (\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " lv1 \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(te_relu, (lv,), (\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " gv: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m lv1\n", " R\u001b[38;5;129;01m.\u001b[39;00moutput(gv)\n", " \u001b[38;5;28;01mreturn\u001b[39;00m lv1\n", " \n", "\n" ] } ], "source": [ "def map_matmul(bb, node_map, node: fx.Node):\n", " A = node_map[node.args[0]]\n", " B = node_map[node.args[1]]\n", " return bb.emit_te(te_matmul, A, B)\n", "\n", "def map_relu(bb, node_map, node: fx.Node):\n", " A = node_map[node.args[0]]\n", " return bb.emit_te(te_relu, A)\n", "\n", "MyModule = from_fx(\n", " fx_module, \n", " input_shapes = [(1, 128)], \n", " call_function_map = {\n", " torch.matmul: map_matmul,\n", " torch.relu: map_relu, \n", " },\n", " call_module_map={},\n", ")\n", "\n", "MyModule.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "6OBLhELUgpBT" }, "source": [ "## Coming back to FashionMNIST Example" ] }, { "cell_type": "code", "execution_count": 26, 
"metadata": { "id": "o1JM175Ri_hq" }, "outputs": [], "source": [ "import torch\n", "import torchvision\n", "\n", "test_data = torchvision.datasets.FashionMNIST(\n", " root=\"data\",\n", " train=False,\n", " download=True,\n", " transform=torchvision.transforms.ToTensor()\n", ")\n", "test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)\n", "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n", " 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n", "\n", "img, label = next(iter(test_loader))\n", "img = img.reshape(1, 28, 28).numpy()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 287 }, "id": "VOoQL7r7i_qf", "outputId": "3685c1c1-2121-47e1-c297-40624727763f" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAD8CAYAAADJwUnTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdu0lEQVR4nO3dfZQc5XXn8e+dntGMNHpBQkgokjAsCNuyHQMWwgRnDYYEwdmAvfYB5BMHr8HynljrkNi7Zr17bA45ew7JBrzOBpMMRgGy2JgY2ygbBcKyJITsgpEwBkmsQBbCktC70At6mZfuu390C/e81K2e6Z6uLvH7nNNH3X2rqp+pad2peurW85i7IyKSJ21ZN0BEZLSUuEQkd5S4RCR3lLhEJHeUuEQkd5S4RCR3lLhEZNyY2Qoz22VmaxPiZmZ/amYbzexFMzuvlu0qcYnIeLoXWBLErwAWVB7LgLtq2agSl4iMG3d/CtgXLHI1cL+XPQOcZGZz0rbb3qgG1mKCdXoX3c38SJF3lGMcps97rZ5tXH5Jt+/dV6xp2TUv9q4DjlW91ePuPaP4uLnAlqrXWyvvbY9WqitxmdkS4FtAAfiOu98WLd9FNxfYpfV8pIgEnvUn6t7Gnn1Fnn1sXk3Ldsz5+TF3X1T3h47SmBOXmRWAO4HfoJwlnzOzle6+vlGNE5EsOEUvNevDtgHzq17Pq7wXqqePazGw0d03uXsf8CDl81URyTEHSnhNjwZYCfxO5erih4ED7h6eJkJ9p4ojnZteMHQhM1tG+WoBXUyq4+NEpFlKNOaIy8y+B1wMzDSzrcA3gA4Ad/9zYBVwJbAROAL8m1q2O+6d85WOuh6AqTZDY+iItDjH6W/QqaK7L02JO/DF0W63nsQ1pnNTEWltDhQbcxo4burp43oOWGBmZ5jZBOA6yuerIpJzTezjGpMxH3G5+4CZLQceo1wOscLd1zWsZdIQB//uzDB+yamvhvHeUvwV+eHP4js0zv7c6jAurceBYouPjFxXH5e7r6LcuSYiJ5CmFUOMUVMr50Wk9Tne8n1cSlwiMog79Ld23lLiEpGhjCJ13e447pS4RGQQB0o64hKRvNERl4jkSrkAVYnrHc86O8O49/aG8cIpp4TxS57clBi7+6XJ4bprrkirQY4vjJ//9Gth/
IFtaxJj/2ruh1I+O0VbIY6XahtTSgZzoN9be4xRJS4RGcQxii0+OLISl4gMU3KdKopIjqiPS0RyyCiqj0tE8qQ8AqoSl4jkiLvR5ylXbDOmxNUIKZfl08od0rz30b1h/NvPXJIYO/vG8R1W5sBH4rbdvOb8xNgr30kZEiel7W0Tu8J46ciR5GCLD9uStZL6uEQkT8qd8zpVFJFcUee8iOSMOudFJJeKKkAVkTxxjH5v7dTQ2q0TkaZT57yI5I5jOlWUdHs/f2EY/62T/iKMr71x7HOypA6509cXbyClHmrth5LbNvEHx+JtpygdPlzX+pJMnfMikivuqBxCRPKl3DmvW35EJGfUOS8iueKYBhIUkfzREZeI5Ep5XkUlLhHJFc1k/c5Q5zRYXZ/cGca/9NK1YXwOL4/5s+sdK6webW1xDdjRqxeH8YmP/CTe/qRJibFwrK53uPL0ZCfwVUUz2wwcAorAgLsvakSjRCQ77tbyp4qNaN0l7n6OkpbIiaPobTU9amFmS8xsg5ltNLObR4ifZmZPmtlPzexFM7sybZutnVZFpOnK43FZTY80ZlYA7gSuABYCS81s4ZDF/jPwkLufC1wHfDttu/UmLgf+3szWmNmykRYws2VmttrMVveTXX+KiNTKGnnEtRjY6O6b3L0PeBC4esgyDkytPJ8GvJG20Xo75z/i7tvMbBbwuJn9P3d/alCL3HuAHoCpNkMzFIi0uHI5RM1XFWeaWfWsJj2V//PHzQW2VL3eClwwZBu3UD4A+ndAN3BZ2ofWlbjcfVvl311m9iPK2fWpeC0RaWWjvFdxTwP6t5cC97r77WZ2IfBXZvZ+d08cWmTMp4pm1m1mU44/B34TWDvW7YlI6yjRVtOjBtuA+VWv51Xeq3YD8BCAu/9foAuYGW20niOu2cCPzOz4dr7r7o/Wsb3WZsGhc51z9F0972dh/P57Lx/ztq1jQhj3gf6UDcRfzrYJHWG8dCwYc+vZaeG6OxfH+/X0R8KwjFF5WJuGFaA+BywwszMoJ6zrgE8PWeYXwKXAvWb2XsqJa3e00TEnLnffBHxwrOuLSOtq1E3W7j5gZsuBx4ACsMLd15nZrcBqd18JfBm428x+n3IX22fd46MBVc6LyCDl0SEaVynl7quAVUPe+3rV8/XARaPZphKXiAxSvuWntUs8lbhEZIjWv+VHiUtEhqmlKj5LSlwiMkiDryqOCyWuWkVlAV7fsDbv7twexn/l6bfq2n5dUofsicshItNfibe99bfq26/h0DVReQvUXeKSdzpVFJFc0ZjzIpI7DgzoiEtE8kaniiKSL65TRRHJmeMDCbYyJS4RGUZHXCKSK6McSDATSlzNkFIzVLDE8dLK8YPB0DCUp1hKlDwWWyWeXb1SoTdu24xTDo7fh4/3z53jOjHHGCipc15EckZ9XCKSL65TRRHJGfVxiUguKXGJSK44RlGd8yKSN+qcF5FccXXON1FbPIGlFWqe4DJh+8m/SO+Nx41qmzQpjB8rpUwh1j72w3abkLLtUp31RPXu18D/+MC9Yfwmfm3M27bOzjDu/QNj3jaABd8XAB+ob/vjzZW4RCRfdJO1iOSQjrhEJFfcoVhS4hKRnNFVRRHJFUeniiKSO+qcF5EcauFRd4ATKXGlzP/nqfMDjp/S4cNh/JOT43GnVhwM5gcEolGtwrkFGyDtZ4t07Y7HGdvQP2vM207jvb3jtm1IHwYt1AJjebX6qWJqZaOZrTCzXWa2tuq9GWb2uJm9Wvl3+vg2U0SapXxVsa2mR1Zq+eR7gSVD3rsZeMLdFwBPVF6LyAnCvbZHVlITl7s/Bewb8vbVwH2V5/cBH29ss0QkS+5W0yMrY+3jmu3u2yvPdwCzkxY0s2XAMoAu4nv2RCR7TrZJqRZ1n6S6u1Mu/UiK97j7Indf1EF8Y6uItAav8ZGVsSaunWY2B6Dy767GNUlEMuXgJavpUQszW2JmG8xso5mN2B9uZteY2XozW2dm303b5
lgT10rg+srz64FHxrgdEWlBjerjMrMCcCdwBbAQWGpmC4csswD4j8BF7v4+4Ka07ab2cZnZ94CLgZlmthX4BnAb8JCZ3QC8DlyT+hNkbN/nLgzjE95KOfCNLqGk1N1MfSWu09pV/OcwvuPyuWH81L9NrlHrPSux+7Es5btX7IzH2/KU4bg69ybXS73y6Ynhuj87cloYt3PfF8YPnTU5MeaFlLku++LvQ//EeP1pDzwTxkMtUP3ZwCYsBja6+yYAM3uQ8sW99VXLfB64093fLH+2p57BpSYud1+aELo0bV0RyZ9R3qs408xWV73ucfeeqtdzgS1Vr7cCFwzZxtkAZvbPQAG4xd0fjT70xKmcF5HGcKD2xLXH3RfV+YntwALKZ3bzgKfM7APuvj9phdaeykNEMtHAAtRtwPyq1/Mq71XbCqx09353fw14hXIiS6TEJSJD1HZFscaris8BC8zsDDObAFxH+eJetR9TPtrCzGZSPnXcFG1UiUtEhmtQIZe7DwDLgceAl4GH3H2dmd1qZldVFnsM2Gtm64EngX/v7nuj7aqPS0QG88aODuHuq4BVQ977etVzB/6g8qhJUxNX/6xudnw6eUqpto8NvSVysENvJV8+f8/cHeG6//20PwvjKw+cG8bbLPnPy8GBrnDdND1vfiiMz7xmSxif8jtHE2PTeC1c92B/3Pbo5waY1N4Xxo8MJE+P9oenPheu+/0d54fxD9yzPoz3lpK/3pML8bA2xZSTkU9MWxPG/+yL8UX359+YlxgrFOIxcSY/NDUxVny0jjKMatlXZIR0xCUiI2jtexWVuERkuHoGQmwCJS4RGWx0dVyZUOISkWFa4K6jkBKXiAynxCUiuaNTRRHJm5QqmMw1NXGVukscviB5uqyPzt4arn9gRnLNUVq90e9viEfeOdKbXG8E8MHZQ2+v+qWTOpLrqGpx2oQ9YfxXp8fDv2w4lDx0zeH++Oc60t8RxtMKEad0xvVQBUu+PPVG/0nhuu+f+kYYn9Ye7/d1b81JjKV9XzYeOiWMbzlyWRjffSx5SB2A7q7k+rfzZsV1e49fvjAxNvB/GpBx3KDGQQKzoiMuERlOR1wikjtKXCKSO0pcIpIrKkAVkTzSVUURyR8lLhHJGx1xVenY38asR5JrsSZ+pT9c/yfbk6erOvfUuAbs0jmvhPGdvcljHAF0tyfXK3VY8vRgAJMK8ZhVpZSxn943KbmGDGBu55uJsU1H43qkenW2DYTxaNyrd6XUr/UH42kBtAU1YgBnTkrefjRWF8CMzsNh/D3dO8P4a+0zw/j/+nlyLdYp814O153/g+Q54fYmfxVGR31cIpIrNQ7LnCUlLhEZTolLRPIm5Sw8c0pcIjKcjrhEJE/MdVVRRPJIVxVFJHd0xPVLNuBM3J1cq7V4ys/D9VcdfF9irDAnZf6/triWKs2hYP7BtLGd0mqGZnV0hvGTCsljmAE8ve+sxFgp5S/nhEJcg9aW8g0upUxjNavzUGJsX0c8ZtWBYjwO2SSPf6c7+5Jr8/pLybVQAHO6Dobxvf3dYfyMiXGN2ux5yQVXZ3dtD9ddvS25xsz6G9Or3uqninHlI2BmK8xsl5mtrXrvFjPbZmYvVB5Xjm8zRaRpvHxVsZZHVlITF3AvsGSE97/p7udUHqtGiItIXnmNj4ykJi53fwrY14S2iEiryHviCiw3sxcrp5LTkxYys2VmttrMVvf3x/d/iUhrOF4SkfbIylgT113AmcA5wHbg9qQF3b3H3Re5+6KOjrhDU0SkFmNKXO6+092L7l4C7gYWN7ZZIpKpE/FU0cyq5336BLA2aVkRyZkcXFVMreMys+8BFwMzzWwr8A3gYjM7h3LO3Qx8oZYPazvSS+fzGxPj2/sTu8rK9ifPEbh2T/IcegCXLVgfxl89MiuMTywk158VU2qljhbjuQ1fPZo8LyLAr0+JxxJbOvsnibGNvfG2p6XUiE0pHAvjx0rxvIxnTtiVGOuyePy1Q6Xk2
jmAnx2cH8YXdCd/dlTjBem1d9Pb4/2WWt/W/VZibN3ReeG6rE/+P0RvPM9lzVq8jis1cbn70hHevmcc2iIiLcBo/QJU3fIjIsO1eOKqpxxCRE5ENZZC1HpUZmZLzGyDmW00s5uD5T5pZm5mi9K2qcQlIsOVanykMLMCcCdwBbAQWGpmwwbcN7MpwO8Bz9bSPCUuERmmgUdci4GN7r7J3fuAB4GrR1juD4E/AuKrQRVKXCIyXO11XDOP3xlTeSwbsqW5wJaq11sr773NzM4D5rv739bavOZ2zrcXsBnJJQ9fPfkfw9VfuSD50v7e3rgqf/fAlDA+t2t/GD84EEyr1hYPDbPuQFyqsWlnPJXVs38Tn/LvvDx5eJeF74qHSNlzJN5vlvJndaAYDw9ztC+5XKJ/XVySMNCd8if9lPjS/z0fezwxtrE/nlZtSsrvtMvicodHD78rjHfMSN7+V09+NVz3yrOuTYzZxrj0piajKy7d4+6pfVJJzKwNuAP47GjW01VFERmmgeUQ24Dqgrt5lfeOmwK8H/gHK/8xOBVYaWZXufvqpI0qcYnIcI1LXM8BC8zsDMoJ6zrg029/jPsB4O1TDjP7B+ArUdIC9XGJyAgadcuPuw8Ay4HHgJeBh9x9nZndamZXjbV9OuISkcEafAN1ZaDRVUPe+3rCshfXsk0lLhEZxCqPVqbEJSLDtfgtP0pcIjKMbrKu4r19DGzanBhf0xtPN3XaxOSh7w+nDB2z5diMMD69Ix6m5FMnJ1/k2FGM65Fun/N8GL/ornhUoMl//UwYP+mR5FosnxJPATZjIHmaLABLWd/3x9MR2PRpibHS1JSpz16IhyLa+/kLw/jGX0+u1br2+RvDdT3lP+6R3XH9278+P7woxuT25Bq039324XDd4roNiTF3DWsjIu9Enu0ggbVQ4hKR4XTEJSJ5oz4uEckfJS4RyRsdcYlIvjg1DRKYJSUuERlEk2WM0jU//lIYP39x8jRdu4/G9UZb71gQxrsfjkeM/SdLrq1pnzc3MQbwp1MmhfG+D8f3uu/7XFyvdHRW8g0aa7/07XDd8RbV5n3l3/5uuO7rd14Qxk9eE3/2f7jkusTYvC3JtVAAbZPjOq3iweTpxQBe/8e4bnDepP2JsUef+WC47oLaRjeujxKXiOSNpVXgZkyJS0QGa/DoEONBiUtEhlEfl4jkjm75EZH80RGXiOTKKGapzooSl4gMl/fEZWbzgfuB2ZR/nB53/5aZzQC+D5wObAaucfd4cKcUk1+P65m6L0yuCdqb8ifiyKyUbYdRKLw3qAPbdyBct7g+uf4M4OFHnw7jjx4+O17/vbMSY1fe/bFw3eLeeDytVCnzCxamJY9VNmF/PGbVpr98IYx/cO7SMD6w4vXEWP9lHwrXnfBmPKFy2tDGA6U9YbwQlKZP3hzPVTne8lCAWsssPwPAl919IfBh4ItmthC4GXjC3RcAT1Rei8gJwEpe0yMrqYnL3be7+/OV54coTzE0F7gauK+y2H3Ax8epjSLSTD6KR0ZG1cdlZqcD5wLPArPd/fj87json0qKyAnghCmHMLPJwMPATe5+0Kr6NtzdzUY+KzazZcAygC7ie/ZEpEWcAH1cmFkH5aT1gLv/sPL2TjObU4nPAXaNtK6797j7Indf1EFnI9osIuPMvLZHVlITl5UPre4BXnb3O6pCK4HrK8+vBx5pfPNEpOmc8jRHtTwyUsup4kXAZ4CXzOyFyntfA24DHjKzG4DXgWvqbcycpw+G8f5rk/PsxPb+cN1iZ8oF7JTL+v0zkk9zbXp8Cmw7dobxa9ddH8Z374unPzuTnybGSqfPCddtD8oVAHxifJTsKfvNoz+N++Mykt/efHEYP7QlbvupQaxvWvzVL/TG0915e/xzt7cVU+LJnUjTNyRPq9Ysue/jcvenSS5bubSxzRGRrOWhjkuV8yIyWMangbVQ4hKRYXTEJSL5o8QlInmjIy4RyRcHiq2duZS4RGQYHXGNg
q9eG8Z/cej0xNip3XEN2EDa3UZ1XEVJq+lJGwKlsxDX/BT7arrBYUT907rCuBfibXtHPMSKp/1wQZ1X2k/VkVILVTi5N/7o9uSvd1qdkg2kLNAWt36gFO+37ceSa9Amrx/xJpRfbjuMNkgDryqa2RLgW0AB+I673zYk/gfAjZR/tN3A59w9eUwiarzlR0TeWRp1y4+ZFYA7gSuAhcDSyrBY1X4KLHL3XwV+APxx2naVuERksMYOa7MY2Ojum9y9D3iQ8pBYv/w49yfd/Ujl5TPAvLSNttSpoohkzwCrvXN+pplVD2Xb4+49Va/nAluqXm8FoinKbwD+Lu1DlbhEZJhRzGS9x90XNeQzzX4bWAR8NG1ZJS4RGayxo5tuA+ZXvZ5XeW8QM7sM+E/AR909vuqC+rhEZJgah7Sp7ajsOWCBmZ1hZhOA6ygPifU2MzsX+AvgKnePL6lW6IhLRIZpVB2Xuw+Y2XLgMcrlECvcfZ2Z3QqsdveVwH8FJgN/XRlZ+RfuflW03Vwlri3rk0dYevevxWNepdYbpYg6K8e7Vq+tvY7BkVJ+7tQ6rbaUGrWUmV6KnckH9WmH+73F+OtZ7Bu/abzS+ngKB+KzmQmFuNpq7e7kcdJO2bQhXLcpGljH5e6rgFVD3vt61fPLRrvNXCUuEWkCH9VVxUwocYnIcK2dt5S4RGS4UZRDZEKJS0SGU+ISkVxxIO+TZYjIO4vhOlUUkRwqtfYhV64S18znk2uKChfFfyFShnZKnVfRism/SKuzSMxSqv1sHO9vSBtLLFXK97tUx/b7Usa0oi3eb55SYxYptcc7vX3nvnj9lO/E/tdPSoydEq7ZBDpVFJE80qmiiOSPEpeI5IsmhBWRvNEsPyKSR+rjEpH8UeISkVxxoI5SkmZITVxmNh+4H5hN+UfqcfdvmdktwOcpz4MG8LXKuDvRxrCOCYlh7+8LV5/+3ecSY2/ccHK4blpJUNpfGA/qetLmJkwrw9p/ZGK8QB2jurWlzA/YdjRllr6U8bjSvuAdh8be9iMDyd8VAD+a8kstJRfvFY7F+6V/WkcYbz8Qz+NZ8u4w3v2L8RtLrH4nRuf8APBld3/ezKYAa8zs8Ursm+7+J+PXPBHJRN4Tl7tvB7ZXnh8ys5cpTzkkIiciB4I7RVrBqG4mMbPTgXOBZytvLTezF81shZlNT1hnmZmtNrPV/X6svtaKSBM4eKm2R0ZqTlxmNhl4GLjJ3Q8CdwFnAudQPiK7faT13L3H3Re5+6IO66q/xSIy/ho3y8+4qOmqopl1UE5aD7j7DwHcfWdV/G7gf45LC0WkuXJwVTH1iMvK8wXdA7zs7ndUvV89TckngLWNb56IZOIEOOK6CPgM8JKZvVB572vAUjM7h3J+3gx8IXVL7qklD+HqA8mX7g/0pZyGpqTowtlnhvHDJ3cmxoqdccnA1KlTw/hJk46G8a4J/WHcOpPb1js9vqxfmJjyFajzyxlNbzZp9qxw3fOmbwnjO+ZMCePtc38lMXZoalyO0DaQUh7z7jPC+OSOuO2z1qRO1pwsGoKpUbnkBLiq+DQjz84X12yJSD65QzFtALtsqXJeRIbL+xGXiLwDKXGJSL54y19VVOISkcEcPMPi0loocYnIcC1+y48Sl4gM5q7pyYYJa1DGfl49aVkcP6P/9TDuU+NhSLp2JtdalSbENUF+enI9EcD+H88I44W+lOnL3pM8/Eva9GDFKSnTsqV8f9NG3LGgr2TgzDmJMYC/ufesMD7zpbgWamBu8n6Z8trhcN1oOjqAwr54WJu9100L4+2vrwnjoWZ0nKtzXkTyxnXEJSL5cmIMJCgi7yQ5uMlaiUtEBnHAW/yWn1ENJCgi7wDe2IEEzWyJmW0ws41mdvMI8U4z+34l/mxlwNKQEpeIDOMlr+mRxswKwJ3AFcBCyqPKLByy2A3Am+5+FvBN4I/StqvEJSLDNe6IazGw0d03uXsf8CBw9ZBlrgbuqzz/AXBpZRzAROZNv
HpgZruB6oKqmcCepjVgdFq1ba3aLlDbxqqRbXuXu59SzwbM7FHKbapFF1A9mUSPu/dUbetTwBJ3v7Hy+jPABe6+vGqZtZVltlZe/7yyTOI+aWrn/NAdamar3X1RM9tQq1ZtW6u2C9S2sWq1trn7kqzbkEaniiIynrYB86tez6u8N+IyZtYOTAP2RhtV4hKR8fQcsMDMzjCzCcB1wMohy6wErq88/xTwvz2lDyvrOq6e9EUy06pta9V2gdo2Vq3ctrq4+4CZLQceAwrACndfZ2a3AqvdfSXlyXj+ysw2AvsoJ7dQUzvnRUQaQaeKIpI7SlwikjuZJK60WwCyZGabzewlM3vBzFZn3JYVZrarUudy/L0ZZva4mb1a+Xd6C7XtFjPbVtl3L5jZlRm1bb6ZPWlm681snZn9XuX9TPdd0K6W2G950vQ+rsotAK8AvwFspXzVYam7r29qQxKY2WZgUVT81sS2/EvgLeB+d39/5b0/Bva5+22VpD/d3b/aIm27BXjL3f+k2e0Z0rY5wBx3f97MpgBrgI8DnyXDfRe06xpaYL/lSRZHXLXcAiCAuz9F+SpLterbI+6j/MVvuoS2tQR33+7uz1eeHwJeBuaS8b4L2iWjlEXimgtUz0++ldb65Tnw92a2xsxSBoTOxGx33155vgOYnWVjRrDczF6snEpmchpbrTLSwLnAs7TQvhvSLmix/dbq1Dk/3Efc/TzKd7N/sXJK1JIqRXqtVM9yF3AmcA6wHbg9y8aY2WTgYeAmdx80SHyW+26EdrXUfsuDLBJXLbcAZMbdt1X+3QX8iPKpbSvZWekrOd5nsivj9rzN3Xe6e9HLk/LdTYb7zsw6KCeHB9z9h5W3M993I7WrlfZbXmSRuGq5BSATZtZd6TTFzLqB3wTWxms1XfXtEdcDj2TYlkGOJ4WKT5DRvqsMiXIP8LK731EVynTfJbWrVfZbnmRSOV+53Pvf+OUtAP+l6Y0YgZn9C8pHWVC+Heq7WbbNzL4HXEx5iJGdwDeAHwMPAadRHiLoGndveid5Qtsupny648Bm4AtVfUrNbNtHgH8CXgKODxr1Ncr9SZntu6BdS2mB/ZYnuuVHRHJHnfMikjtKXCKSO0pcIpI7SlwikjtKXCKSO0pcIpI7Slwikjv/H/gT2ZAo1j/NAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class: Bag\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.figure()\n", "plt.imshow(img[0])\n", "plt.colorbar()\n", "plt.grid(False)\n", "plt.show()\n", "\n", "print(\"Class:\", class_names[label[0]])" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_-xEQ84Fi_yR", "outputId": "3ecdeebe-f28a-406e-921d-83c3a7072559" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "File ‘fasionmnist_mlp_params.pkl’ already there; not retrieving.\n", "\n" ] } ], "source": [ "# Hide outputs\n", "!wget -nc https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl" ] }, { "cell_type": "markdown", "metadata": { "id": "mBwreM4qjn7x" }, "source": [ "![image.png]()" ] }, { "cell_type": "markdown", "metadata": { "id": "Vqh3yklvjquJ" }, "source": [ "The model of interest is shown above. We can build it in PyTorch as follows:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "OqbgbCtEgfoV" }, "outputs": [], "source": [ "class MLP(nn.Module):\n", "    def __init__(self):\n", "        super(MLP, self).__init__()\n", "        self.linear0 = nn.Linear(784, 128, bias=True)\n", "        self.relu = nn.ReLU()\n", "        self.linear1 = nn.Linear(128, 10, bias=True)\n", "\n", "    def forward(self, x):\n", "        x = self.linear0(x)\n", "        x = self.relu(x)\n", "        x = self.linear1(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "usBESp0ojUdc" }, "outputs": [], "source": [ "import pickle as pkl\n", "mlp_model = MLP()\n", "\n", "with open(\"fasionmnist_mlp_params.pkl\", \"rb\") as f:\n", "    mlp_params = pkl.load(f)\n", "mlp_model.linear0.weight.data = torch.from_numpy(mlp_params[\"w0\"])\n", "mlp_model.linear0.bias.data = torch.from_numpy(mlp_params[\"b0\"])\n", "mlp_model.linear1.weight.data = 
torch.from_numpy(mlp_params[\"w1\"])\n", "mlp_model.linear1.bias.data = torch.from_numpy(mlp_params[\"b1\"])" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MDKt0OAWk3Jr", "outputId": "8686b68c-bc0e-4eb9-8916-7d96b86c3ed4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Torch Prediction: Bag\n" ] } ], "source": [ "torch_res = mlp_model(torch.from_numpy(img.reshape(1, 784)))\n", "\n", "pred_kind = np.argmax(torch_res.detach().numpy(), axis=1)\n", "print(\"Torch Prediction:\", class_names[pred_kind[0]])" ] }, { "cell_type": "markdown", "metadata": { "id": "6BEuhKdzqlUW" }, "source": [ "Let us translate the fx graph by defining mapping functions for the corresponding `nn.Module` types. Here we reuse predefined TE libraries from TVM `topi` instead of defining our own tensor expressions.\n", "\n", "- `topi.nn.dense(x, w)` performs transposed matrix multiplication `x @ w.T`.\n", "- `topi.add` performs broadcast addition.\n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0YmBlh3ClbNX", "outputId": "a498fa9b-6ea7-4c2f-8a91-146a30cf5940" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdense1\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], rxplaceholder_1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m10\u001b[39m, \u001b[38;5;28m128\u001b[39m), 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], T_matmul_NT: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdense1\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayout_free_buffers\u001b[39m\u001b[38;5;124m\"\u001b[39m: [\u001b[38;5;28m1\u001b[39m]})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1, i2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mT_matmul_NT\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i, j, k \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1, i2])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[i, k], rxplaceholder_1[j, k])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(T_matmul_NT[i, j])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " T_matmul_NT[i, j] \u001b[38;5;129;01m=\u001b[39;00m 
T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " T_matmul_NT[i, j] \u001b[38;5;129;01m=\u001b[39;00m T_matmul_NT[i, j] \u001b[38;5;129;01m+\u001b[39;00m rxplaceholder[i, k] \u001b[38;5;129;01m*\u001b[39;00m rxplaceholder_1[j, k]\n", " \n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mte_relu\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], relu: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mte_relu\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i0_1, i1_1 \u001b[38;5;129;01m=\u001b[39;00m 
T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[i0_1, i1_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(relu[i0_1, i1_1])\n", " relu[i0_1, i1_1] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmax(rxplaceholder[i0_1, i1_1], T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m))\n", " \n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21madd\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], rxplaceholder_1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m128\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], T_add: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124madd\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m 
T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mT_add\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " ax0, ax1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(T_add[ax0, ax1])\n", " T_add[ax0, ax1] \u001b[38;5;129;01m=\u001b[39;00m rxplaceholder[ax0, ax1] \u001b[38;5;129;01m+\u001b[39;00m rxplaceholder_1[ax1]\n", " \n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21madd1\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], rxplaceholder_1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[\u001b[38;5;28m10\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], T_add: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124madd1\u001b[39m\u001b[38;5;124m\"\u001b[39m, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mT_add\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " ax0, ax1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(T_add[ax0, ax1])\n", " T_add[ax0, ax1] \u001b[38;5;129;01m=\u001b[39;00m rxplaceholder[ax0, ax1] \u001b[38;5;129;01m+\u001b[39;00m rxplaceholder_1[ax1]\n", " \n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdense\u001b[39m(rxplaceholder: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m784\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], rxplaceholder_1: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m784\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m], T_matmul_NT: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " 
\u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdense\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayout_free_buffers\u001b[39m\u001b[38;5;124m\"\u001b[39m: [\u001b[38;5;28m1\u001b[39m]})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i0, i1, i2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m, \u001b[38;5;28m784\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mT_matmul_NT\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " i, j, k \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i0, i1, i2])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(rxplaceholder[i, k], rxplaceholder_1[j, k])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(T_matmul_NT[i, j])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " T_matmul_NT[i, j] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " T_matmul_NT[i, j] \u001b[38;5;129;01m=\u001b[39;00m T_matmul_NT[i, j] \u001b[38;5;129;01m+\u001b[39;00m rxplaceholder[i, k] \u001b[38;5;129;01m*\u001b[39;00m rxplaceholder_1[j, k]\n", " \n", " \u001b[38;5;129m@R\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mfunction\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(x: 
Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m784\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m Tensor(\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, ndim \u001b[38;5;129;01m=\u001b[39;00m \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;30;03m# block 0\u001b[39;00m\n", " \u001b[38;5;28;01mwith\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mdataflow():\n", " lv \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(dense, (x, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m0\u001b[39m]), (\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " lv1 \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(add, (lv, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m1\u001b[39m]), (\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " lv2 \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(te_relu, (lv1,), (\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " lv3 \u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(dense1, (lv2, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m2\u001b[39m]), (\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " lv4 
\u001b[38;5;129;01m=\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mcall_tir(add1, (lv3, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m3\u001b[39m]), (\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " gv: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m lv4\n", " R\u001b[38;5;129;01m.\u001b[39;00moutput(gv)\n", " \u001b[38;5;28;01mreturn\u001b[39;00m lv4\n", " \n", "\n" ] } ], "source": [ "from tvm import topi\n", "\n", "def map_nn_linear(bb, node_map, node, nn_mod):\n", "    x = node_map[node.args[0]]\n", "    w = map_param(nn_mod.weight)\n", "    # topi.nn.dense(x, w) computes x @ w.T, matching the (out, in) layout of nn.Linear weights\n", "    y = bb.emit_te(topi.nn.dense, x, w)\n", "    # nn.Linear(..., bias=False) sets bias to None, so only emit the add when a bias exists\n", "    if nn_mod.bias is not None:\n", "        b = map_param(nn_mod.bias)\n", "        y = bb.emit_te(topi.add, y, b)\n", "    return y\n", "\n", "def map_nn_relu(bb, node_map, node, nn_mod):\n", "    return map_relu(bb, node_map, node)\n", "\n", "\n", "MLPModule = from_fx(\n", "    fx.symbolic_trace(mlp_model), \n", "    input_shapes = [(1, 784)], \n", "    call_function_map={\n", "    },\n", "    call_module_map={\n", "        torch.nn.Linear: map_nn_linear,\n", "        torch.nn.ReLU: map_nn_relu,\n", "    },\n", ")\n", "\n", "MLPModule.show()" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WAg9RZV2lbbK", "outputId": "476cd64c-6dca-4e3f-e9da-94cf81eb02be" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLPModule Prediction: Bag\n" ] } ], "source": [ "ex = relax.vm.build(MLPModule, target=\"llvm\")\n", "vm = relax.VirtualMachine(ex, tvm.cpu())\n", "data_nd = tvm.nd.array(img.reshape(1, 784))\n", "\n", "nd_res = vm[\"main\"](data_nd)\n", "\n", "pred_kind = np.argmax(nd_res.numpy(), axis=1)\n", "print(\"MLPModule Prediction:\", class_names[pred_kind[0]])" ]
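}, { "cell_type": "markdown", "metadata": {}, "source": [ "As a sanity check on the `nn.Linear` mapping, the following cell is a minimal NumPy sketch (not part of the TVM flow) of what the generated `dense` and `add` functions compute. The helper `dense_nt_loop` is hypothetical. It transcribes the `T_matmul_NT` loop nest printed earlier, where `i` and `j` are spatial axes and `k` is the reduction axis. The inputs are assumed random data with the same shapes as the first linear layer. Under these assumptions, dense followed by a broadcast add should agree with `x @ w.T + b`, which is what `torch.nn.Linear` computes.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical helper: a direct NumPy transcription of the T_matmul_NT loop nest\n", "# (i and j are spatial axes, k is the reduction axis).\n", "def dense_nt_loop(x, w):\n", "    n, kdim = x.shape\n", "    m = w.shape[0]\n", "    out = np.zeros((n, m), dtype=x.dtype)  # plays the role of T.init()\n", "    for i in range(n):\n", "        for j in range(m):\n", "            for k in range(kdim):\n", "                out[i, j] += x[i, k] * w[j, k]  # reads x[i, k] and w[j, k]\n", "    return out\n", "\n", "# Assumed random inputs with the shapes of the first linear layer.\n", "rng = np.random.default_rng(0)\n", "x_np = rng.standard_normal((1, 784)).astype(np.float32)\n", "w_np = rng.standard_normal((128, 784)).astype(np.float32)  # nn.Linear weight layout: (out, in)\n", "b_np = rng.standard_normal((128,)).astype(np.float32)\n", "\n", "y_loop = dense_nt_loop(x_np, w_np) + b_np  # dense followed by a broadcast add\n", "y_ref = x_np @ w_np.T + b_np  # what torch.nn.Linear computes\n", "print(np.allclose(y_loop, y_ref, rtol=1e-4, atol=1e-3))" ]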
}, { "cell_type": "markdown", "metadata": { "id": "3BSOnEbUgu60" }, "source": [ "## Remark: Translating into High-level Operators\n", "\n", "When importing models from machine learning frameworks, it is sometimes helpful to first translate them into high-level built-in primitive operators instead of TensorIR functions. The following code block gives an example that maps `nn.Linear` and `nn.ReLU` to the built-in `relax.op.dense`, `relax.op.add`, and `relax.op.relu` operators.\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YOo0H-5mhOGl", "outputId": "e3bca65f-6504-41c2-b41a-4120bf4e5ab9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@R\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mfunction\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(x: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m784\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m Tensor(\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, ndim \u001b[38;5;129;01m=\u001b[39;00m \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;30;03m# block 0\u001b[39;00m\n", " \u001b[38;5;28;01mwith\u001b[39;00m R\u001b[38;5;129;01m.\u001b[39;00mdataflow():\n", " lv: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m relax\u001b[38;5;129;01m.\u001b[39;00mnn\u001b[38;5;129;01m.\u001b[39;00mdense(x, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m0\u001b[39m])\n", " lv1: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m relax\u001b[38;5;129;01m.\u001b[39;00madd(lv, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m1\u001b[39m])\n", " lv2: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m128\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m relax\u001b[38;5;129;01m.\u001b[39;00mnn\u001b[38;5;129;01m.\u001b[39;00mrelu(lv1)\n", " lv3: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m relax\u001b[38;5;129;01m.\u001b[39;00mnn\u001b[38;5;129;01m.\u001b[39;00mdense(lv2, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m2\u001b[39m])\n", " lv4: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m relax\u001b[38;5;129;01m.\u001b[39;00madd(lv3, meta[relay\u001b[38;5;129;01m.\u001b[39;00mConstant][\u001b[38;5;28m3\u001b[39m])\n", " gv: Tensor((\u001b[38;5;28m1\u001b[39m, \u001b[38;5;28m10\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m=\u001b[39;00m lv4\n", " R\u001b[38;5;129;01m.\u001b[39;00moutput(gv)\n", " \u001b[38;5;28;01mreturn\u001b[39;00m lv4\n", " \n", "\n" ] } ], "source": [ "def map_nn_relu_op(bb, node_map, node, nn_mod):\n", "    A = node_map[node.args[0]]\n", "    return bb.emit(relax.op.relu(A))\n", "\n", "def map_nn_linear_op(bb, node_map, node, nn_mod):\n", "    x = node_map[node.args[0]]\n", "    w = map_param(nn_mod.weight)\n", "    y = bb.emit(relax.op.dense(x, w))\n", "    # only emit the add when the layer has a bias (it is None for bias=False)\n", "    if nn_mod.bias is not None:\n", "        b = map_param(nn_mod.bias)\n", "        y = bb.emit(relax.op.add(y, b))\n", "    return y\n", "\n", 
"MLPModuleHighLevel = from_fx(\n", "    fx.symbolic_trace(mlp_model), \n", "    input_shapes = [(1, 784)], \n", "    call_function_map={\n", "    },\n", "    call_module_map={\n", "        torch.nn.Linear: map_nn_linear_op,\n", "        torch.nn.ReLU: map_nn_relu_op,\n", "    },\n", ")\n", "\n", "MLPModuleHighLevel.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "eQvtGR0rqci1" }, "source": [ "After this translation, the model becomes an IRModule that calls built-in operators.\n", "These built-in operators are a **higher-level abstraction** than TensorIR functions, and there are different opportunities to further translate them into either library calls or TensorIR functions.\n", "\n", "In most cases, it is helpful to translate into high-level builtins when they are available. However, in many cases we either cannot find a corresponding high-level builtin or want to specify the TensorIR function directly. In those cases, we can customize the translation logic or transformation to generate `call_tir` or call into library functions. Usually, we get the best result by combining the high-level op, TensorIR, and library abstractions. We will discuss the tradeoffs in follow-up lectures." ] }, { "cell_type": "markdown", "metadata": { "id": "bMHXI_ccFS-z" }, "source": [ "## Discussions\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "VkB5oHttOY0U" }, "source": [ "In this chapter, we focused on the **develop** part of the MLC flow. We studied different ways to get models from machine learning frameworks into an IRModule. We also briefly touched upon high-level primitive operators.\n", "\n", "Once we get the model into an IRModule, we can introduce more kinds of transformations on primitive functions and computational graph functions. 
A good MLC process composes these transformations to produce the final deployment form.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pZGWq5EjJ0BA" }, "source": [ "## Summary\n", "\n", "- The tensor expression API lets us create primitive TensorIR functions.\n", "- The BlockBuilder API creates an IRModule through `emit_te` and other functions.\n", "- We can integrate with existing machine learning frameworks by transforming their models into an IRModule." ] } ], "metadata": { "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "6 Integration with Machine Learning Frameworks.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3.10.4 ('mlc': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "vscode": { "interpreter": { "hash": "d8a760899c905ec5a15e0d212432af25d7f0b614c7ae634224dffa77837bb03c" } } }, "nbformat": 4, "nbformat_minor": 0 }