tvm.relay.backend.te_compiler 源代码

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=len-as-condition,no-else-return,invalid-name
"""TE compiler engine (replacing legacy compile_engine)."""
from __future__ import absolute_import

import logging

import numpy as np
import tvm
from tvm import autotvm, te
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.meta_schedule import is_meta_schedule_enabled
from tvm.runtime import Object
from tvm.support import libinfo
from tvm.target import Target

from .. import function as _function
from .. import ty as _ty
from ..backend.utils import mangle_module_name
from . import _backend

# Logger for this module's own messages (e.g. which implementation was selected).
logger = logging.getLogger("te_compiler")
# Logger used for the "operators have not been tuned" warning and its
# per-workload debug details emitted by select_implementation.
autotvm_logger = logging.getLogger("autotvm")

# True until select_implementation emits the generic "not tuned" warning; it is
# flipped to False there so that the warning is logged only once.
_first_warning = True


@tvm._ffi.register_object("relay.LoweredOutput")
class LoweredOutput(Object):
    """Lowered output: output tensors together with the chosen implementation.

    Parameters
    ----------
    outputs : object
        The outputs of the lowered call (``lower_call`` passes the tensors
        returned by the selected implementation).

    implement : object
        The implementation that produced ``outputs``.
    """

    def __init__(self, outputs, implement):
        constructor = _backend._make_LoweredOutput
        self.__init_handle_by_constructor__(constructor, outputs, implement)
@tvm._ffi.register_object("relay.CCacheKey")
class CCacheKey(Object):
    """Key used to look up entries in the TE Compiler.

    Parameters
    ----------
    source_func : tvm.relay.Function
        The source function.

    target : tvm.Target
        The target we want to run the function on.
    """

    def __init__(self, source_func, target):
        constructor = _backend._make_CCacheKey
        self.__init_handle_by_constructor__(constructor, source_func, target)
@tvm._ffi.register_object("relay.CCacheValue")
class CCacheValue(Object):
    """Value in the TE Compiler, including usage statistics.

    No Python-side members are defined here; the class only binds the
    ``relay.CCacheValue`` object type to a Python class.
    """
def _get_cache_key(source_func, target):
    """Normalize ``(source_func, target)`` into a :class:`CCacheKey`.

    Parameters
    ----------
    source_func : Union[tvm.relay.Function, CCacheKey]
        Either a relay function (a new key is built from it and ``target``)
        or an existing key (returned unchanged; ``target`` is ignored).

    target : Union[str, tvm.target.Target, None]
        The target to compile for. Strings are converted with ``Target(...)``.
        Required when ``source_func`` is a relay function.

    Returns
    -------
    key : CCacheKey
        The cache key describing the function/target pair.

    Raises
    ------
    ValueError
        If ``source_func`` is a relay function and no target was supplied.
    TypeError
        If ``source_func`` is neither a relay function nor a ``CCacheKey``.
    """
    if isinstance(source_func, _function.Function):
        if isinstance(target, str):
            target = Target(target)
        # Validate the target for every input kind, not only for strings:
        # a missing target (e.g. the ``target=None`` default of
        # ``TECompiler.lower``) must be rejected here with a clear message
        # instead of being forwarded to the backend constructor.
        if not target:
            raise ValueError("Need target when source_func is a Function")
        return CCacheKey(source_func, target)
    if not isinstance(source_func, CCacheKey):
        raise TypeError("Expect source_func to be CCacheKey")
    return source_func
def get_valid_implementations(op, attrs, inputs, out_type, target):
    """Collect every implementation of ``op`` whose specialization applies.

    Note that this function doesn't support op with symbolic input shapes.

    Parameters
    ----------
    op : tvm.ir.Op
        Relay operator.

    attrs : object
        The op attribute.

    inputs : List[tvm.te.Tensor]
        Input tensors to the op.

    out_type : relay.Type
        The output type.

    target : tvm.target.Target
        The target to compile the op.

    Returns
    -------
    ret : List[relay.op.OpImplementation]
        The list of all valid op implementations.
    """
    fstrategy = op.get_attr("FTVMStrategy")
    assert fstrategy is not None, (
        "%s doesn't have an FTVMStrategy registered. You can register "
        "one in python with `tvm.relay.op.register_strategy`." % op.name
    )
    # The strategy function is evaluated with ``target`` as the active scope.
    with target:
        strategy = fstrategy(attrs, inputs, out_type, target)
    analyzer = tvm.arith.Analyzer()

    def _clause_holds(clause):
        # A clause counts as satisfied only if it simplifies to a non-zero
        # integer constant; anything else (including symbolic results) fails.
        simplified = analyzer.canonical_simplify(clause)
        return isinstance(simplified, tvm.tir.IntImm) and bool(simplified.value)

    valid_impls = []
    for spec in strategy.specializations:
        condition = spec.condition
        # Unconditional specializations are always accepted; conditional ones
        # require every clause to hold.
        if condition and not all(_clause_holds(clause) for clause in condition.clauses):
            continue
        valid_impls.extend(spec.implementations)
    return valid_impls
def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True):
    """Select the best implementation from the op strategy.

    If use_autotvm is True, it'll first try to find the best implementation
    based on AutoTVM profile results. If no AutoTVM profile result is found,
    it'll choose the implementation with highest plevel.

    If use_autotvm is False, it'll directly choose the implementation with
    highest plevel.

    Note that this function doesn't support op with symbolic input shapes.

    Parameters
    ----------
    op : tvm.ir.Op
        Relay operator.

    attrs : object
        The op attribute.

    inputs : List[tvm.te.Tensor]
        Input tensors to the op.

    out_type : relay.Type
        The output type.

    target : tvm.target.Target
        The target to compile the op.

    use_autotvm : bool
        Whether query AutoTVM to pick the best. Forced to False when the
        auto-scheduler or meta-schedule is enabled.

    Returns
    -------
    ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor])
        The best op implementation and the corresponding output tensors.

    Raises
    ------
    RuntimeError
        If the op strategy yields no valid implementation for ``target``.
    """
    global _first_warning

    all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
    if not all_impls:
        raise RuntimeError(f"No valid {op} implementations for {target}")
    best_plevel_impl = max(all_impls, key=lambda x: x.plevel)

    # Disable autotvm if auto_scheduler or meta_schedule is enabled
    # (i.e., always return the implementation with the highest priority).
    if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
        use_autotvm = False

    # If not use autotvm, always return the implementation with the highest priority
    if not use_autotvm:
        logger.info(
            "Using %s for %s based on highest priority (%d)",
            best_plevel_impl.name,
            op.name,
            best_plevel_impl.plevel,
        )
        outs = best_plevel_impl.compute(attrs, inputs, out_type)
        return best_plevel_impl, outs

    # Otherwise, try autotvm templates
    outputs = {}
    workloads = {}
    best_autotvm_impl = None
    best_cfg = None
    dispatch_ctx = autotvm.task.DispatchContext.current
    old_silent = autotvm.GLOBAL_SCOPE.silent
    autotvm.GLOBAL_SCOPE.silent = True
    try:
        for impl in all_impls:
            outs = impl.compute(attrs, inputs, out_type)
            outputs[impl] = outs
            workload = autotvm.task.get_workload(outs)
            workloads[impl] = workload
            if workload is None:
                # Not an AutoTVM tunable implementation
                continue
            cfg = dispatch_ctx.query(target, workload)
            if cfg.is_fallback:
                # Skip fallback config
                continue
            logger.info("Implementation %s for %s has cost %.2e", impl.name, op.name, cfg.cost)
            if best_cfg is None or best_cfg.cost > cfg.cost:
                best_autotvm_impl = impl
                best_cfg = cfg
    finally:
        # Restore the global flag even if compute()/query() raised, so a
        # failure here cannot leave AutoTVM permanently silenced.
        autotvm.GLOBAL_SCOPE.silent = old_silent

    if best_autotvm_impl:
        # The best autotvm implementation definitely doesn't use fallback config
        logger.info(
            "Using %s for %s based on lowest cost (%.2e)",
            best_autotvm_impl.name,
            op.name,
            best_cfg.cost,
        )
        return best_autotvm_impl, outputs[best_autotvm_impl]

    # Use the implementation with highest plevel
    if workloads[best_plevel_impl] is not None:
        msg = (
            "Cannot find tuning records for:\n    target=%s\n    key=%s\n"
            "TVM will apply a default schedule which may negatively impact performance."
            % (target, workloads[best_plevel_impl])
        )
        if (
            not autotvm.env.GLOBAL_SCOPE.silent
            and msg not in autotvm.task.DispatchContext.warning_messages
        ):
            # Remember the message so each distinct workload is reported once.
            autotvm.task.DispatchContext.warning_messages.add(msg)
            if _first_warning:
                # The generic hint is emitted only for the first untuned workload.
                _first_warning = False
                info_msg = (
                    "One or more operators have not been tuned. Please tune your model "
                    "for better performance. Use DEBUG logging level to see more details."
                )
                autotvm_logger.warning(info_msg)
            autotvm_logger.debug(msg)

    logger.info(
        "Using %s for %s based on highest priority (%s)",
        best_plevel_impl.name,
        op.name,
        best_plevel_impl.plevel,
    )
    return best_plevel_impl, outputs[best_plevel_impl]
def get_shape(shape):
    """Convert the shape to correct dtype and vars.

    Parameters
    ----------
    shape : iterable
        The dimensions to convert.

    Returns
    -------
    ret : list
        One entry per input dimension: constant integer dimensions are kept
        as-is when the ``INDEX_DEFAULT_I64`` build flag is ``"ON"`` and are
        otherwise rebuilt as ``int32`` constants; ``tvm.tir.Any`` dimensions
        become fresh ``int32`` size variables named ``any_dim``; everything
        else is passed through unchanged.
    """

    def _convert(dim):
        if isinstance(dim, tvm.tir.IntImm):
            if libinfo()["INDEX_DEFAULT_I64"] == "ON":
                return dim
            value = int(dim)
            # The constant is about to be narrowed to int32, so it must fit.
            assert value <= np.iinfo(np.int32).max
            return tvm.tir.IntImm("int32", value)
        if isinstance(dim, tvm.tir.Any):
            return te.size_var("any_dim", "int32")
        return dim

    return [_convert(dim) for dim in shape]
@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target, otype=None):
    """Lower the call expression to op implementation and tensor outputs.

    Parameters
    ----------
    call : object
        The call to lower; ``call.op`` must be a ``tvm.ir.Op``.

    inputs : List[tvm.te.Tensor]
        Input tensors forwarded to the selected implementation.

    target : tvm.target.Target
        The target to compile the op for.

    otype : relay.Type, optional
        Output type to use. When omitted, it is derived from
        ``call.checked_type`` with tensor shapes normalized by ``get_shape``.

    Returns
    -------
    ret : LoweredOutput
        The output tensors together with the implementation that produced them.
    """
    assert isinstance(call.op, tvm.ir.Op)
    op = call.op

    if otype is not None:
        ret_type = otype
    else:
        # Prepare the call_node->checked_type(). For the call node inputs, we ensure that
        # the shape is Int32. Following code ensures the same for the output as well.
        # TODO(@icemelon9): Support recursive tuple
        ret_type = call.checked_type
        if isinstance(ret_type, _ty.TensorType):
            ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype)
        elif isinstance(ret_type, _ty.TupleType):
            new_fields = []
            for field in ret_type.fields:
                if isinstance(field, _ty.TensorType):
                    new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype))
                else:
                    new_fields.append(field)
            ret_type = _ty.TupleType(new_fields)

    # The call is dynamic if its result type or any argument type is dynamic.
    is_dyn = _ty.is_dynamic(call.checked_type) or any(
        _ty.is_dynamic(arg.checked_type) for arg in call.args
    )

    # check if in the AutoTVM tracing mode, and disable if op is not in wanted list
    env = autotvm.task.TaskExtractEnv.current
    reenable_tracing = False
    if env is not None and env.tracing:
        if env.wanted_relay_ops is not None and op not in env.wanted_relay_ops:
            env.tracing = False
            reenable_tracing = True

    try:
        # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes.
        # Until then, dynamic calls skip AutoTVM and use the highest-priority
        # implementation.
        best_impl, outputs = select_implementation(
            op, call.attrs, inputs, ret_type, target, use_autotvm=not is_dyn
        )
    finally:
        # re-enable AutoTVM tracing, also when implementation selection failed,
        # so an error here cannot leave task extraction switched off.
        if reenable_tracing:
            env.tracing = True
    return LoweredOutput(outputs, best_impl)
@tvm._ffi.register_object("relay.TECompiler")
class TECompiler(Object):
    """TECompiler to get lowered code.

    Instances cannot be created from Python: ``__init__`` always raises.
    Use the module-level ``get()`` to obtain the global compiler.
    """

    def __init__(self):
        raise RuntimeError("Cannot construct a TECompiler")

    @staticmethod
    def _describe_source(source_func):
        """Render ``source_func`` as text for an error report; never raises."""
        # pylint: disable=broad-except
        try:
            # NOTE(review): a CCacheKey presumably exposes the wrapped function
            # as ``source_func``; if it does not, the object itself is used.
            func = getattr(source_func, "source_func", source_func)
            astext = getattr(func, "astext", None)
            if callable(astext):
                return astext(show_meta_data=False)
            return repr(source_func) + "\n"
        except Exception:
            # Rendering is best-effort: it must not hide the original error.
            return "<unprintable source_func of type %s>\n" % type(source_func).__name__

    def lower(self, source_func, target=None, mod_name="default"):
        """Lower a source_func to a CachedFunc.

        Parameters
        ----------
        source_func : Union[tvm.relay.Function, CCacheKey]
            The source relay function.

        target : tvm.Target
            The target platform.

        mod_name : str
            Module name; it is passed through ``mangle_module_name`` before
            being handed to the backend.

        Returns
        -------
        cached_func: CachedFunc
            The result of lowering.

        Raises
        ------
        RuntimeError
            If anything fails during lowering. The message contains the
            original traceback followed by a text dump of ``source_func``.
        """
        # pylint: disable=broad-except, import-outside-toplevel
        try:
            mod_name = mangle_module_name(mod_name)
            key = _get_cache_key(source_func, target)
            return _backend._TECompilerLower(self, key, mod_name)
        except Exception as err:
            import traceback

            msg = traceback.format_exc()
            msg += "Error during compile func\n"
            msg += "--------------------------\n"
            # source_func may be a CCacheKey or an object of an unexpected type,
            # neither of which is guaranteed to provide ``astext``; describe it
            # defensively so the real failure is always reported.
            msg += self._describe_source(source_func)
            msg += "--------------------------\n"
            raise RuntimeError(msg) from err

    def jit(self, source_func, target=None):
        """JIT a source_func to a tvm.runtime.PackedFunc.

        Parameters
        ----------
        source_func : Union[tvm.relay.Function, CCacheKey]
            The source relay function.

        target : tvm.Target
            The target platform.

        Returns
        -------
        jited_func: tvm.runtime.PackedFunc
            The result of jited function.
        """
        key = _get_cache_key(source_func, target)
        return _backend._TECompilerJIT(self, key)

    def clear(self):
        """clear the existing cached functions"""
        _backend._TECompilerClear(self)

    def items(self):
        """List items in the cache.

        Returns
        -------
        item_list : List[Tuple[CCacheKey, CCacheValue]]
            The list of items.
        """
        # The backend returns a flat sequence: key0, value0, key1, value1, ...
        res = _backend._TECompilerListItems(self)
        assert len(res) % 2 == 0
        return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)]
def get():
    """Return the global TE Compiler.

    Returns
    -------
    engine : tvm.relay.backend.TECompiler
        The TE Compiler.
    """
    engine = _backend._TECompilerGlobal()
    return engine