{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Scan and Recurrent Kernel\n**Author**: [Tianqi Chen](https://tqchen.github.io)\n\nThis is introductory material on how to do recurrent computation in TVM.\nRecurrent computation is a common pattern in neural networks.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import absolute_import, print_function\n\n\nimport tvm\nimport tvm.testing\nfrom tvm import te\nimport numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TVM supports a scan operator to describe symbolic loops.\nThe following scan op computes the cumulative sum (cumsum) of X along its\nfirst axis, independently for each column.\n\nThe scan is carried over the first (outermost) dimension of the tensor.\n:code:`s_state` is a placeholder that describes the transition state of the scan.\n:code:`s_init` describes how to initialize the first k timesteps.\nHere, since the first dimension of :code:`s_init` is 1, it describes how to initialize\nthe state at the first timestep.\n\n:code:`s_update` describes how to update the value at timestep t. 
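Concretely, the scan fills the state one timestep at a time. Below is a minimal plain-Python sketch of these semantics, not TVM code: the function names are hypothetical, rows are plain lists rather than tensors, and the init part is assumed to produce exactly one timestep (k = 1, as in this example).

```python
# Plain-Python sketch of scan semantics (not TVM code).
# Assumptions: X is a list of m rows, each a list of n floats, and the
# init part produces exactly one timestep (k = 1).


def reference_scan(X, init, update):
    # Timestep 0 comes from the init part.
    state = [init(X)]
    for t in range(1, len(X)):
        # The update reads only the previous timestep, never state[t] itself.
        state.append(update(state[t - 1], X[t]))
    # Like the scan result, this holds the state at every timestep.
    return state


def cumsum_init(X):
    # Mirrors s_init: copy the first row of X.
    return list(X[0])


def cumsum_update(prev, x_t):
    # Mirrors s_update: s_state[t - 1, i] + X[t, i] for every column i.
    return [p + x for p, x in zip(prev, x_t)]


rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(reference_scan(rows, cumsum_init, cumsum_update))
# prints [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]]
```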
The update\nvalue can refer back to the values of previous timesteps via the state placeholder.\nNote that it is invalid to refer to :code:`s_state` at the current or a later timestep.\n\nThe scan takes the state placeholder, the initial value and the update description.\nIt is also recommended (although not necessary) to list the inputs to the scan cell.\nThe result of the scan is a tensor that holds the value of :code:`s_state` at every\ntimestep, i.e. after the update has been applied over the whole time domain.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "m = te.var(\"m\")\nn = te.var(\"n\")\nX = te.placeholder((m, n), name=\"X\")\n# Placeholder for the recurrent state: one row of length n per timestep.\ns_state = te.placeholder((m, n))\n# Timestep 0 is initialized from the first row of X.\ns_init = te.compute((1, n), lambda _, i: X[0, i])\n# Timestep t adds row t of X to the state from timestep t - 1.\ns_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])\ns_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Schedule the Scan Cell\nWe can schedule the body of the scan by scheduling the update and\ninit parts separately. 
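To see why only the non-time axis is scheduled, here is a plain-Python sketch (not TVM code; the function name is hypothetical) of the loop nest that the cumsum schedule below expresses. The time loop runs in order, while the non-time axis of each timestep is split into chunks of num_thread elements (xo, xi), which the CUDA schedule binds to blockIdx.x and threadIdx.x. Each element reads only the previous timestep, so chunks within one timestep are independent.

```python
# Plain-Python sketch of the scheduled cumsum loop nest (not TVM code).
# Assumption: X is a list of m rows of n floats; num_thread mirrors the
# split factor used in the schedule.


def scheduled_cumsum(X, num_thread):
    n = len(X[0])
    state = [list(X[0])]  # init part: timestep 0
    num_chunks = (n + num_thread - 1) // num_thread  # ceil(n / num_thread)
    for t in range(1, len(X)):  # time axis: must run in order
        row = [0.0] * n
        for xo in range(num_chunks):  # bound to blockIdx.x on the GPU
            for xi in range(num_thread):  # bound to threadIdx.x on the GPU
                i = xo * num_thread + xi
                if i < n:  # skip the padding of the last, partial chunk
                    # Reads only timestep t - 1, so chunks can run in parallel.
                    row[i] = state[t - 1][i] + X[t][i]
        state.append(row)
    return state


print(scheduled_cumsum([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], num_thread=2))
# prints [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
```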
Note that it is invalid to schedule the\nfirst iteration dimension of the update part.\nTo split on the time iteration, we can schedule on :code:`scan_op.scan_axis` instead.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "s = te.create_schedule(s_scan.op)\nnum_thread = 256\nblock_x = te.thread_axis(\"blockIdx.x\")\nthread_x = te.thread_axis(\"threadIdx.x\")\n# Split the non-time axis of the init part and bind it to GPU blocks and threads.\nxo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)\ns[s_init].bind(xo, block_x)\ns[s_init].bind(xi, thread_x)\n# Do the same for the update part; its axis[0] (time) must not be scheduled.\nxo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)\ns[s_update].bind(xo, block_x)\ns[s_update].bind(xi, thread_x)\nprint(tvm.lower(s, [X, s_scan], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build and Verify\nWe can build the scan kernel like any other TVM kernel; here we use\nnumpy to verify the correctness of the result.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fscan = tvm.build(s, [X, s_scan], \"cuda\", name=\"myscan\")\ndev = tvm.cuda(0)\n# Concrete sizes for the test input.\nn = 1024\nm = 10\na_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)\na = tvm.nd.array(a_np, dev)\nb = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev)\nfscan(a, b)\n# Each output row t should equal the sum of input rows 0..t.\ntvm.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi-Stage Scan Cell\nIn the above example we described the scan cell using a single Tensor\ncomputation stage in :code:`s_update`. 
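A scan cell with several stages still advances one timestep at a time. As a plain-Python sketch (not TVM code; the function name is hypothetical), the two-stage recurrence used in this section, in which stage s1 doubles the previous state and stage s2 adds the current row of X, unrolls as follows:

```python
# Plain-Python sketch of the two-stage scan cell (not TVM code).
# Assumption: X is a list of m rows, each a list of n floats.


def two_stage_scan(X):
    state = [list(X[0])]  # s_init: copy the first row of X
    for t in range(1, len(X)):
        # Stage s1: s1[t, i] = s_state[t - 1, i] * 2
        s1 = [2 * v for v in state[t - 1]]
        # Stage s2: s2[t, i] = s1[t, i] + X[t, i]; this becomes the new state.
        s2 = [a + x for a, x in zip(s1, X[t])]
        state.append(s2)
    return state


print(two_stage_scan([[1.0], [1.0], [1.0]]))
# prints [[1.0], [3.0], [7.0]]
```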
It is possible to use multiple\nTensor stages in the scan cell.\n\nThe following lines demonstrate a scan with two stages\nin the scan cell.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "m = te.var(\"m\")\nn = te.var(\"n\")\nX = te.placeholder((m, n), name=\"X\")\ns_state = te.placeholder((m, n))\ns_init = te.compute((1, n), lambda _, i: X[0, i])\n# Stage s1 doubles the previous state; stage s2 adds the current row of X.\ns_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name=\"s1\")\ns_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name=\"s2\")\ns_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These intermediate tensors can also be scheduled normally.\nTo ensure correctness, TVM creates a group constraint that prevents\nstages in the body of the scan from being placed (via compute_at) at locations outside the scan loop.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "s = te.create_schedule(s_scan.op)\nxo, xi = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)\n# Compute s1 inside the outer (xo) loop of s2, i.e. one chunk at a time.\ns[s_update_s1].compute_at(s[s_update_s2], xo)\nprint(tvm.lower(s, [X, s_scan], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multiple States\nFor complicated applications like RNNs, we might need more than one\nrecurrent state. 
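With several states, all of them advance together, and each update may read any state from the previous timestep. A plain-Python sketch (not TVM code; the function name is hypothetical) of the two-state recurrence built in this section, where state2 has l columns and accumulates column 0 of state1:

```python
# Plain-Python sketch of a scan with two recurrent states (not TVM code).
# Assumptions: X is a list of m rows of n floats; l is the width of state2.


def two_state_scan(X, l):
    state1 = [list(X[0])]  # s_init1: first row of X
    state2 = [[0.0] * l]  # s_init2: all zeros
    for t in range(1, len(X)):
        prev1, prev2 = state1[t - 1], state2[t - 1]
        # s_update1: s_state1[t - 1, i] + X[t, i]
        state1.append([p + x for p, x in zip(prev1, X[t])])
        # s_update2: s_state2[t - 1, i] + s_state1[t - 1, 0]
        state2.append([p + prev1[0] for p in prev2])
    return state1, state2


out1, out2 = two_state_scan([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], l=3)
print(out1)  # prints [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]]
print(out2)  # prints [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [5.0, 5.0, 5.0]]
```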
Scan supports multiple recurrent states.\nThe following example demonstrates how to build a recurrence with two states.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "m = te.var(\"m\")\nn = te.var(\"n\")\nl = te.var(\"l\")\nX = te.placeholder((m, n), name=\"X\")\ns_state1 = te.placeholder((m, n))\ns_state2 = te.placeholder((m, l))\ns_init1 = te.compute((1, n), lambda _, i: X[0, i])\ns_init2 = te.compute((1, l), lambda _, i: 0.0)\ns_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])\n# state2 at step t adds column 0 of state1 from step t - 1.\ns_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])\ns_scan1, s_scan2 = tvm.te.scan(\n [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]\n)\ns = te.create_schedule(s_scan1.op)\nprint(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\nThis tutorial provides a walkthrough of the scan primitive.\n\n- Describe a scan with init and update.\n- Schedule the scan cell like any normal schedule.\n- For complicated workloads, use multiple states and stages in the scan cell.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 0 }