{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Intrinsics and Math Functions\n**Author**: [Tianqi Chen](https://tqchen.github.io)\n\nTVM supports basic arithmetic operations, but in many cases\nwe need more complicated built-in functions,\nfor example :code:`exp` to take the exponential of a value.\n\nThese functions are target-system dependent and may have different\nnames on different target platforms. In this tutorial, we will learn\nhow to invoke these target-specific functions, and how to unify\nthe interface via TVM's intrinsic API.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import absolute_import, print_function\n\n\nimport numpy as np\n\nimport tvm\nfrom tvm import te\nfrom tvm.ir import register_op_attr, register_intrin_lowering" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Direct Declare Extern Math Call\nThe most straightforward way to call a target-specific function is via\nthe extern function call construct in TVM.\nIn the following example, we use :any:`tvm.tir.call_pure_extern` to call\nthe :code:`__expf` function, which is only available under CUDA.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = te.var(\"n\")\nA = te.placeholder((n,), name=\"A\")\nB = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern(\"float32\", \"__expf\", A[i]), name=\"B\")\ns = te.create_schedule(B.op)\nnum_thread = 64\nbx, tx = s[B].split(B.op.axis[0], factor=num_thread)\ns[B].bind(bx, te.thread_axis(\"blockIdx.x\"))\ns[B].bind(tx, te.thread_axis(\"threadIdx.x\"))\nf = tvm.build(s, [A, B], \"cuda\", name=\"myexp\")\nprint(f.imported_modules[0].get_source())" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "## Unified Intrinsic Call\nThe above code verifies that direct external call can be used to\ncall into device specific functions.\nHowever, the above way only works for CUDA target with float type.\nIdeally, we want to write same code for any device and any data type.\n\nTVM intrinsic provides the user a mechanism to achieve this, and this\nis the recommended way to solve the problem.\nThe following code use te.exp instead, which create an intrinsic call\n:py::func:`tvm.te.exp` to do the exponential.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = te.var(\"n\")\nA = te.placeholder((n,), name=\"A\")\nB = te.compute(A.shape, lambda i: te.exp(A[i]), name=\"B\")\ns = te.create_schedule(B.op)\nnum_thread = 64\nbx, tx = s[B].split(B.op.axis[0], factor=num_thread)\ns[B].bind(bx, te.thread_axis(\"blockIdx.x\"))\ns[B].bind(tx, te.thread_axis(\"threadIdx.x\"))\nfcuda = tvm.build(s, [A, B], \"cuda\", name=\"myexp\")\nprint(fcuda.imported_modules[0].get_source())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can find that the code works for both CUDA and opencl.\nThe same te.exp can also be used for float64 data types.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fopencl = tvm.build(s, [A, B], \"opencl\", name=\"myexp\")\nprint(fopencl.imported_modules[0].get_source())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Intrinsic Lowering Rule\nWhen :py:func:`tvm.te.exp` is called, TVM creates an intrinsic Call Expr.\nTVM uses transformation rules to transform the intrinsic\ncall to device specific extern calls.\n\nTVM also allows user to customize the rules during runtime.\nThe following example customizes CUDA lowering rule for :code:`exp`.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def 
my_cuda_math_rule(op):\n    \"\"\"Customized CUDA intrinsic lowering rule\"\"\"\n    assert isinstance(op, tvm.tir.Call)\n    name = op.op.name\n    assert name.startswith(\"tir.\")\n    # strip the \"tir.\" prefix to get the math function name, e.g. \"exp\"\n    dispatch_name = name[4:]\n    if op.dtype == \"float32\":\n        # call float function, e.g. expf\n        return tvm.tir.call_pure_extern(\"float32\", \"%sf\" % dispatch_name, op.args[0])\n    elif op.dtype == \"float64\":\n        # call double function, e.g. exp\n        return tvm.tir.call_pure_extern(\"float64\", dispatch_name, op.args[0])\n    else:\n        # cannot do translation, return self.\n        return op\n\n\nregister_intrin_lowering(\"tir.exp\", target=\"cuda\", f=my_cuda_math_rule, level=99)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The rule is registered with TVM at a high priority (:code:`level=99`) so that it overrides the existing rule.\nNotice the difference between this printed code and the previous one:\nour new rule uses the math function :code:`expf` instead of\nthe fast math version :code:`__expf`.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fcuda = tvm.build(s, [A, B], \"cuda\", name=\"myexp\")\nprint(fcuda.imported_modules[0].get_source())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Add Your Own Intrinsic\nIf there is an intrinsic that is not provided by TVM,\nusers can easily add a new one by using the intrinsic rule system.\nThe following example adds an intrinsic :code:`mylog` to the system.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def mylog(x):\n    \"\"\"customized log intrinsic function\"\"\"\n    return tvm.tir.call_intrin(x.dtype, \"tir.mylog\", x)\n\n\ndef my_cuda_mylog_rule(op):\n    \"\"\"CUDA lowering rule for log\"\"\"\n    if op.dtype == \"float32\":\n        return tvm.tir.call_pure_extern(\"float32\", \"logf\", op.args[0])\n    elif op.dtype == \"float64\":\n        return tvm.tir.call_pure_extern(\"float64\", \"log\", op.args[0])\n    else:\n        return op\n\n\n# new op registration is triggered by 
registering an attribute of the op\nregister_op_attr(\"tir.mylog\", \"TCallEffectKind\", tvm.tir.CallEffectKind.Pure)\nregister_intrin_lowering(\"tir.mylog\", target=\"cuda\", f=my_cuda_mylog_rule, level=99)\n\nn = te.var(\"n\")\nA = te.placeholder((n,), name=\"A\")\nB = te.compute(A.shape, lambda i: mylog(A[i]), name=\"B\")\ns = te.create_schedule(B.op)\nnum_thread = 64\nbx, tx = s[B].split(B.op.axis[0], factor=num_thread)\ns[B].bind(bx, te.thread_axis(\"blockIdx.x\"))\ns[B].bind(tx, te.thread_axis(\"threadIdx.x\"))\nfcuda = tvm.build(s, [A, B], \"cuda\", name=\"mylog\")\nprint(fcuda.imported_modules[0].get_source())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n- TVM can call extern, target-dependent math functions.\n- Use intrinsics to define a unified interface for these functions.\n- For more intrinsics available in TVM, take a look at :any:`tvm.tir`.\n- You can customize the intrinsic behavior by defining your own rules.\n\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 0 }