{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Compile Tensorflow Models\nThis article is an introductory tutorial to deploy tensorflow models with TVM.\n\nFor us to begin with, tensorflow python module is required to be installed.\n\nPlease refer to https://www.tensorflow.org/install\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# tvm, relay\nimport tvm\nfrom tvm import te\nfrom tvm import relay\n\n# os and numpy\nimport numpy as np\nimport os.path\n\n# Tensorflow imports\nimport tensorflow as tf\n\n\n# Ask tensorflow to limit its GPU memory to what's actually needed\n# instead of gobbling everything that's available.\n# https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth\n# This way this tutorial is a little more friendly to sphinx-gallery.\ngpus = tf.config.list_physical_devices(\"GPU\")\nif gpus:\n try:\n for gpu in gpus:\n tf.config.experimental.set_memory_growth(gpu, True)\n print(\"tensorflow will use experimental.set_memory_growth(True)\")\n except RuntimeError as e:\n print(\"experimental.set_memory_growth option is not available: {}\".format(e))\n\n\ntry:\n tf_compat_v1 = tf.compat.v1\nexcept ImportError:\n tf_compat_v1 = tf\n\n# Tensorflow utility functions\nimport tvm.relay.testing.tf as tf_testing\n\n# Base location for model related files.\nrepo_base = \"https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/\"\n\n# Test image\nimg_name = \"elephant-299.jpg\"\nimage_url = os.path.join(repo_base, img_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorials\nPlease refer docs/frontend/tensorflow.md for more details for various models\nfrom tensorflow.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model_name = \"classify_image_graph_def-with_shapes.pb\"\nmodel_url = os.path.join(repo_base, model_name)\n\n# Image label map\nmap_proto = \"imagenet_2012_challenge_label_map_proto.pbtxt\"\nmap_proto_url = os.path.join(repo_base, map_proto)\n\n# Human readable text for labels\nlabel_map = \"imagenet_synset_to_human_label_map.txt\"\nlabel_map_url = os.path.join(repo_base, label_map)\n\n# Target settings\n# Use these commented settings to build for cuda.\n# target = tvm.target.Target(\"cuda\", host=\"llvm\")\n# layout = \"NCHW\"\n# dev = tvm.cuda(0)\ntarget = tvm.target.Target(\"llvm\", host=\"llvm\")\nlayout = None\ndev = tvm.cpu(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download required files\nDownload files listed above.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from tvm.contrib.download import download_testdata\n\nimg_path = download_testdata(image_url, img_name, module=\"data\")\nmodel_path = download_testdata(model_url, model_name, module=[\"tf\", \"InceptionV1\"])\nmap_proto_path = download_testdata(map_proto_url, map_proto, module=\"data\")\nlabel_path = download_testdata(label_map_url, label_map, module=\"data\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import model\nCreates tensorflow graph definition from protobuf file.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with tf_compat_v1.gfile.GFile(model_path, \"rb\") 
as f:\n    graph_def = tf_compat_v1.GraphDef()\n    graph_def.ParseFromString(f.read())\n    graph = tf.import_graph_def(graph_def, name=\"\")\n    # Call the utility to import the graph definition into default graph.\n    graph_def = tf_testing.ProcessGraphDefParam(graph_def)\n    # Add shapes to the graph.\n    with tf_compat_v1.Session() as sess:\n        graph_def = tf_testing.AddShapesToGraphDef(sess, \"softmax\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decode image\n
The TensorFlow frontend import doesn't support preprocessing ops like JpegDecode.\nJpegDecode is bypassed (it just returns the source node).\nHence we supply the decoded frame to TVM instead.
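\n\nAs a rough sketch, assuming PIL is available and that the model expects a 299x299 input (inferred here only from the test image name), the decoded frame might be prepared along these lines:\n\n```python\nfrom PIL import Image\n\n# Illustrative sketch: decode and resize the test image outside of TensorFlow,\n# then convert it to a NumPy array that can be fed to the model as raw pixels.\n# The 299x299 size is an assumption based on the test image name.\nimage = Image.open(img_path).resize((299, 299))\nx = np.array(image)\n```\n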