# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Tensor and Operation class for computation declaration."""# pylint: disable=invalid-nameimportwarningsimportnumpyas_npfromtvm.runtimeimportndarrayas_ndfromtvmimporttefromtvm.tirimportexpras_exprfromtvm.teimporttensoras_tensorfloat32="float32"itype="int32"
class CSRNDArray(object):
    """Sparse tensor object in CSR format.

    Attributes
    ----------
    data : NDArray
        The stored (non-zero) values, row by row.
    indices : NDArray
        Column index of each stored value (int32 or int64).
    indptr : NDArray
        Row pointers: row ``i`` owns ``data[indptr[i]:indptr[i + 1]]``
        (int32 or int64).
    shape : tuple of int
        The dense shape of the matrix.
    stype : str
        Always ``"csr"``.
    dtype : str
        The dtype of ``data``.
    """

    def __init__(self, arg1, device=None, shape=None):
        """Construct a sparse matrix in CSR format.

        Parameters
        ----------
        arg1 : numpy.ndarray or a tuple with (data, indices, indptr)
            Either a dense 2-D numpy array to compress, or a
            ``(data, indices, indptr)`` tuple of NDArrays used as-is.

        device : Device, optional
            Device passed to ``tvm.runtime.ndarray.array`` when compressing a
            numpy array. Unused when ``arg1`` is a tuple.

        shape : tuple of int, optional
            The shape of the matrix. Required when ``arg1`` is a tuple; ignored
            when ``arg1`` is a numpy array (its own shape is used).

        Raises
        ------
        RuntimeError
            If ``arg1`` is neither a tuple nor a numpy array.
        ValueError
            If ``arg1`` is a numpy array that is not 2-D.
        AssertionError
            If the resulting components fail the consistency checks below.
        """
        if isinstance(arg1, tuple):
            assert len(arg1) == 3
            self.data, self.indices, self.indptr = arg1
            self.shape = shape
        elif isinstance(arg1, _np.ndarray):
            source_array = arg1
            if source_array.ndim != 2:
                raise ValueError(
                    "CSRNDArray expects a 2-D numpy array, got %d dimension(s)."
                    % (source_array.ndim,)
                )
            # nonzero() returns coordinates in row-major order, which is exactly
            # the CSR storage order, so one call yields both data and indices.
            ridx, cidx = _np.nonzero(source_array)
            self.data = _nd.array(source_array[ridx, cidx], device)
            self.indices = _nd.array(cidx.astype(itype), device)
            # indptr is the exclusive prefix sum of per-row non-zero counts.
            # count_nonzero(axis=1) also handles zero-row arrays, unlike
            # apply_along_axis.
            row_nnz = _np.count_nonzero(source_array, axis=1)
            indptr = _np.concatenate(([0], _np.cumsum(row_nnz))).astype(itype)
            self.indptr = _nd.array(indptr, device)
            self.shape = source_array.shape
        else:
            raise RuntimeError(
                "Construct CSRNDArray with either a tuple (data, indices, indptr) "
                "or a numpy.array, can't handle type %s." % (type(arg1),)
            )
        self.stype = "csr"
        self.dtype = self.data.dtype
        assert self.shape is not None
        assert isinstance(self.data, _nd.NDArray)
        assert isinstance(self.indices, _nd.NDArray)
        assert str(self.indices.dtype) == "int32" or str(self.indices.dtype) == "int64", str(
            self.indices.dtype
        )
        assert isinstance(self.indptr, _nd.NDArray)
        assert str(self.indptr.dtype) == "int32" or str(self.indptr.dtype) == "int64", str(
            self.indptr.dtype
        )

    def asnumpy(self):
        """Construct a full matrix and convert it to numpy array. This API will be deprecated
        in TVM v0.8 release. Please use `numpy` instead."""
        warnings.warn(
            "CSRNDArray.asnumpy() will be deprecated in TVM v0.8 release. "
            "Please use CSRNDArray.numpy() instead.",
            DeprecationWarning,
            # Attribute the warning to the caller, not to this wrapper.
            stacklevel=2,
        )
        return self.numpy()

    def numpy(self):
        """Construct a full (dense) matrix and convert it to a numpy array.

        Returns
        -------
        full : numpy.ndarray
            Dense array of shape ``self.shape`` and dtype ``self.dtype``, with
            zeros everywhere except at the stored positions.
        """
        full = _np.zeros(self.shape, self.dtype)
        # Number of stored entries in each row.
        row_nnz = _np.diff(self.indptr.numpy())
        # Expand the per-row counts into one row index per stored entry.
        # np.repeat replaces np.hstack over a generator. Recent NumPy
        # releases reject generators there, and hstack failed outright
        # when there were no rows.
        ridx = _np.repeat(_np.arange(row_nnz.size, dtype=itype), row_nnz)
        full[ridx, self.indices.numpy().astype(itype)] = self.data.numpy()
        return full
def array(source_array, device=None, shape=None, stype="csr"):
    """Construct a sparse NDArray from a numpy.ndarray.

    Parameters
    ----------
    source_array : numpy.ndarray or tuple
        Forwarded unchanged to the storage-specific constructor.
    device : Device, optional
        Forwarded unchanged to the storage-specific constructor.
    shape : tuple of int, optional
        Forwarded unchanged to the storage-specific constructor.
    stype : str, optional
        Storage type of the result; only ``"csr"`` is currently supported.

    Returns
    -------
    ret : CSRNDArray
        The constructed sparse array.

    Raises
    ------
    NotImplementedError
        If ``stype`` is anything other than ``"csr"``.
    """
    if stype == "csr":
        return CSRNDArray(source_array, shape=shape, device=device)
    raise NotImplementedError("stype=%s is not supported yet." % (stype,))
class SparsePlaceholderOp(object):
    """Placeholder class for sparse tensor representations.

    This generic base only records shape, dtype and name; format-specific
    subclasses set ``stype`` and add their component tensors.
    """

    def __init__(self, shape, nonzeros, dtype, name):  # pylint: disable=unused-argument
        """Construct a bare-bones structure for a sparse matrix.

        Parameters
        ----------
        shape: Tuple of Expr
            The shape of the tensor
        nonzeros: int
            The number of non-zero values (accepted for signature parity with
            subclasses; not stored by this base class)
        dtype: str
            The data type of the tensor
        name: str
            The name hint of the tensor
        """
        self.shape, self.dtype, self.name = shape, dtype, name
        # The generic base has no concrete storage format.
        self.stype = "unknown"
class CSRPlaceholderOp(SparsePlaceholderOp):
    """Placeholder class for CSR based sparse tensor representation.

    Exposes three symbolic component tensors: ``data`` and ``indices`` with
    ``nonzeros`` elements each, and ``indptr`` with ``shape[0] + 1`` elements.
    """

    def __init__(self, shape, nonzeros, dtype, name):
        """Construct a bare-bones structure for a csr_matrix.

        Parameters
        ----------
        shape: Tuple of Expr
            The shape of the tensor; ``shape[0]`` sizes ``indptr``.
        nonzeros: int
            The number of non-zero values; sizes ``data`` and ``indices``.
        dtype: str
            The data type of the ``data`` tensor.
        name: str
            The name hint; each component tensor's name is derived from it.
        """
        SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
        self.stype = "csr"
        prefix = self.name
        nnz_shape = (nonzeros,)
        self.data = te.placeholder(nnz_shape, dtype=dtype, name=prefix + "_data")
        self.indices = te.placeholder(nnz_shape, dtype=itype, name=prefix + "_indices")
        # One row pointer per row plus a trailing end-of-data pointer.
        self.indptr = te.placeholder(
            (self.shape[0] + 1,), dtype=itype, name=prefix + "_indptr"
        )
        for component in (self.data, self.indices, self.indptr):
            assert isinstance(component, _tensor.Tensor)
def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None):
    """Construct an empty sparse tensor object.

    Parameters
    ----------
    shape: int, PrimExpr or Tuple of Expr
        The shape of the tensor. A single int (Python or NumPy) or PrimExpr is
        treated as a one-element shape tuple.
    nonzeros: int, optional
        The number of non-zero values. Defaults to 0.
    dtype: str, optional
        The data type of the tensor. Defaults to "float32".
    name: str, optional
        The name hint of the tensor
    stype: str, optional
        The name storage type of the sparse tensor (e.g. csr, coo, ell).
        Defaults to "csr", which is currently the only supported type.

    Returns
    -------
    tensor: SparsePlaceholderOp
        The created sparse tensor placeholder

    Raises
    ------
    NotImplementedError
        If ``stype`` is not ``"csr"``.
    """
    # Wrap any scalar extent into a 1-tuple. Previously only PrimExpr was
    # wrapped, so a plain int reached CSRPlaceholderOp and failed on shape[0].
    if isinstance(shape, (int, _np.integer, _expr.PrimExpr)):
        shape = (shape,)
    nonzeros = 0 if nonzeros is None else nonzeros
    dtype = float32 if dtype is None else dtype
    stype = "csr" if stype is None else stype
    if stype == "csr":
        return CSRPlaceholderOp(shape=shape, nonzeros=nonzeros, dtype=dtype, name=name)
    raise NotImplementedError("stype=%s is not supported yet." % (stype,))