# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.# pylint: disable=invalid-name, unused-import"""FFI registry to register function and objects."""importsysimportctypesfrom.baseimport_LIB,check_call,py_str,c_str,string_types,_FFI_MODE,_RUNTIME_ONLYtry:# pylint: disable=wrong-import-position,unused-importif_FFI_MODE=="ctypes":raiseImportError()from._cy3.coreimport_register_object,_get_object_type_indexfrom._cy3.coreimport_reg_extensionfrom._cy3.coreimportconvert_to_tvm_func,_get_global_func,PackedFuncBaseexcept(RuntimeError,ImportError)aserror:# pylint: disable=wrong-import-position,unused-importif_FFI_MODE=="cython":raiseerrorfrom._ctypes.objectimport_register_object,_get_object_type_indexfrom._ctypes.ndarrayimport_reg_extensionfrom._ctypes.packed_funcimportconvert_to_tvm_func,_get_global_func,PackedFuncBase
def register_object(type_key=None):
    """Register an object type.

    Can be used in three forms:

    - ``@register_object("key")`` registers the class under ``"key"``;
    - ``@register_object`` (bare decorator) uses the class ``__name__``;
    - ``@register_object()`` (called with no argument) also uses the class
      ``__name__``.

    Parameters
    ----------
    type_key : str or cls, optional
        The type key of the node. When a class is given (bare decorator use)
        or the argument is omitted, the class's ``__name__`` is used as the
        type key.

    Returns
    -------
    result : function or cls
        A decorator when ``type_key`` is a string or None; otherwise the
        registered class itself.

    Examples
    --------
    The following code registers MyObject using type key "test.MyObject"

    .. code-block:: python

      @tvm.register_object("test.MyObject")
      class MyObject(Object):
          pass
    """

    def register(cls):
        """internal register function"""
        # An explicit string names the object; otherwise (bare decorator or
        # no argument at all) fall back to the class name. The previous code
        # evaluated ``type_key.__name__`` eagerly, which crashed for None.
        object_name = type_key if isinstance(type_key, str) else cls.__name__
        if hasattr(cls, "_type_index"):
            tindex = cls._type_index
        else:
            tidx = ctypes.c_uint()
            if not _RUNTIME_ONLY:
                check_call(_LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx)))
            else:
                # directly skip unknown objects during runtime.
                ret = _LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx))
                if ret != 0:
                    return cls
            tindex = tidx.value
        _register_object(tindex, cls)
        return cls

    # A string or missing key means we are being called to build a decorator;
    # anything else is the class itself (bare ``@register_object`` use).
    if type_key is None or isinstance(type_key, str):
        return register
    return register(type_key)
def get_object_type_index(cls):
    """Look up the type index registered for an object class.

    Parameters
    ----------
    cls : type
        The object class whose type index is requested.

    Returns
    -------
    type_index : Optional[int]
        The registered type index, or None when the class is not found in
        the registry.
    """
    type_index = _get_object_type_index(cls)
    return type_index
def register_extension(cls, fcreate=None):
    """Register an extension class to TVM.

    After the class is registered, instances of it can be passed directly
    as arguments to functions generated by TVM.

    Parameters
    ----------
    cls : class
        The class object to be registered as extension.

    fcreate : function, optional
        No longer supported; passing a truthy value raises ValueError.

    Note
    ----
    The registered class requires one property, ``_tvm_handle``, and one
    class attribute, ``_tvm_tcode``:

    - ```_tvm_handle``` returns an integer representing the address of the handle.
    - ```_tvm_tcode``` gives an integer representing the type code of the class.

    Returns
    -------
    cls : class
        The class being registered.

    Raises
    ------
    TypeError
        If ``cls`` does not define ``_tvm_tcode``.
    ValueError
        If ``fcreate`` is provided.

    Example
    -------
    The following code registers user defined class
    MyTensor to be DLTensor compatible.

    .. code-block:: python

       @tvm.register_extension
       class MyTensor(object):
           _tvm_tcode = tvm.ArgTypeCode.ARRAY_HANDLE

           def __init__(self):
               self.handle = _LIB.NewDLTensor()

           @property
           def _tvm_handle(self):
               return self.handle.value
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with ``-O``.
    if not hasattr(cls, "_tvm_tcode"):
        raise TypeError(
            "register_extension: %s must define class attribute _tvm_tcode" % cls
        )
    if fcreate:
        raise ValueError("Extension with fcreate is no longer supported")
    _reg_extension(cls, fcreate)
    return cls
def register_func(func_name, f=None, override=False):
    """Register a global function.

    Works both as a bare decorator (``@register_func``) and as a decorator
    factory (``@register_func("name")``), or as a direct call with ``f``.

    Parameters
    ----------
    func_name : str or function
        The registry name, or the function itself when used as a bare decorator
        (its ``__name__`` then becomes the registry name).

    f : function, optional
        The function to be registered.

    override: boolean optional
        Whether to override an existing entry.

    Returns
    -------
    fregister : function
        The register function if f is not specified, otherwise the
        registered PackedFunc.

    Raises
    ------
    ValueError
        If the resolved function name is not a string.

    Examples
    --------
    The following code registers my_packed_func as global function.
    Note that we simply get it back from global function table to invoke
    it from python side. However, we can also invoke the same function
    from C++ backend, or in the compiled TVM code.

    .. code-block:: python

      targs = (10, 10.0, "hello")
      @tvm.register_func
      def my_packed_func(*args):
          assert(tuple(args) == targs)
          return 10
      # Get it out from global function table
      f = tvm.get_global_func("my_packed_func")
      assert isinstance(f, tvm.PackedFunc)
      y = f(*targs)
      assert y == 10
    """
    if callable(func_name):
        # Bare decorator use: the first positional argument is the function.
        f, func_name = func_name, func_name.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    c_override = ctypes.c_int(override)

    def _register(target):
        """internal register function"""
        packed = target if isinstance(target, PackedFuncBase) else convert_to_tvm_func(target)
        check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), packed.handle, c_override))
        return packed

    return _register(f) if f else _register
def get_global_func(name, allow_missing=False):
    """Fetch a global function by its registered name.

    Parameters
    ----------
    name : str
        The name of the global function.

    allow_missing : bool
        Whether a missing function is allowed (instead of raising an error).

    Returns
    -------
    func : PackedFunc
        The requested function, or None if the function is missing.
    """
    func = _get_global_func(name, allow_missing)
    return func
def list_global_func_names():
    """Collect the names of all registered global functions.

    Returns
    -------
    names : list
        List of global function names.
    """
    names_ptr = ctypes.POINTER(ctypes.c_char_p)()
    count = ctypes.c_uint()
    check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(count), ctypes.byref(names_ptr)))
    # Decode each C string returned by the runtime into a Python str.
    return [py_str(names_ptr[idx]) for idx in range(count.value)]
def extract_ext_funcs(finit):
    """Extract the extension PackedFuncs from a C module.

    Parameters
    ----------
    finit : ctypes function
        A ctypes function that takes the signature of TVMExtensionDeclarer.

    Returns
    -------
    fdict : dict of str to Function
        The extracted functions, keyed by name.

    Raises
    ------
    RuntimeError
        If ``finit`` returns a non-zero status.
    """
    collected = {}

    def _collect(name, func):
        collected[name] = func

    collector = convert_to_tvm_func(_collect)
    status = finit(collector.handle)
    # Reference the PackedFunc after the call, as the original does,
    # so it is clearly still held while finit runs.
    _ = collector
    if status != 0:
        raise RuntimeError("cannot initialize with %s" % finit)
    return collected
def remove_global_func(name):
    """Unregister a global function by name.

    Parameters
    ----------
    name : str
        The name of the global function to remove.
    """
    encoded_name = c_str(name)
    check_call(_LIB.TVMFuncRemoveGlobal(encoded_name))
def _init_api(namespace, target_module_name=None):
    """Initialize the API for a given module name.

    Parameters
    ----------
    namespace : str
        The namespace of the source registry. A leading ``"tvm."`` is
        stripped before it is used as the registry prefix.

    target_module_name : str, optional
        The target module name, if different from ``namespace``.
    """
    module_name = target_module_name or namespace
    prefix = namespace[len("tvm."):] if namespace.startswith("tvm.") else namespace
    _init_api_prefix(module_name, prefix)