def replace(op):
    """Retarget a producer store/load to the replacement tensor found in ``rmap``.

    ``rmap`` is not a parameter; it is read from the enclosing scope. Its keys
    are compared against ``op.producer.op``, and the matching value is used as
    the producer of the rebuilt node.

    Parameters
    ----------
    op : object
        An IR node. Only ``_stmt.ProducerStore`` and ``_expr.ProducerLoad``
        nodes are rewritten; anything else is ignored.

    Returns
    -------
    A new ``ProducerStore`` (same value and indices) or ``ProducerLoad`` (same
    indices) whose producer is the mapped tensor. Returns ``None`` when ``op``
    is of another type or its producer's operation is not a key of ``rmap``.
    NOTE(review): presumably the calling IR transform treats ``None`` as
    "leave the node unchanged" -- confirm against the caller.
    """
    if not isinstance(op, (_stmt.ProducerStore, _expr.ProducerLoad)):
        return None
    # Both node kinds key on the producer's operation; compute it once and use
    # direct mapping membership instead of going through ``rmap.keys()``.
    key = op.producer.op
    if key not in rmap:
        return None
    buf = rmap[key]
    if isinstance(op, _stmt.ProducerStore):
        return _stmt.ProducerStore(buf, op.value, op.indices)
    return _expr.ProducerLoad(buf, op.indices)
def replace(op):
    """Retarget a ``Provide`` statement or ``ProducerLoad`` to a tensor from ``rmap``.

    ``rmap`` is not a parameter; it is read from the enclosing scope.
    ``Provide`` nodes are looked up by ``op.func``; ``ProducerLoad`` nodes are
    looked up by ``op.producer.op``.

    Parameters
    ----------
    op : object
        An IR node. Only ``_stmt.Provide`` and ``_expr.ProducerLoad`` nodes are
        rewritten; anything else is ignored.

    Returns
    -------
    A new ``Provide`` (writing to ``buf.op``) or ``ProducerLoad`` (reading
    ``buf``) when ``op`` matches an entry in ``rmap``; otherwise ``None``.
    NOTE(review): presumably the calling IR transform treats ``None`` as
    "leave the node unchanged" -- confirm against the caller.
    """
    if isinstance(op, _stmt.Provide):
        func = op.func
        if func in rmap:
            buf = rmap[func]
            # Provide takes the replacement's operation, not the tensor itself.
            # NOTE(review): value_index is taken from the original node rather
            # than from ``buf``; confirm this is intended when ``buf`` is a
            # different output of a multi-output operation.
            return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
    if isinstance(op, _expr.ProducerLoad):
        key = op.producer.op
        if key in rmap:
            return _expr.ProducerLoad(rmap[key], op.indices)
    return None
def __call__(self, *indices):
    """Index this tensor, producing a ``ProducerLoad`` expression.

    Parameters
    ----------
    *indices : PrimExpr or IterVar
        Exactly ``self.ndim`` indices. ``IterVar`` arguments are replaced by
        their ``.var`` before building the load.

    Returns
    -------
    _expr.ProducerLoad
        A load of this tensor at the given indices.

    Raises
    ------
    ValueError
        If the number of indices differs from ``self.ndim``, or if an index
        (after ``convert_to_object``) is neither a ``PrimExpr`` nor an
        ``IterVar``.
    """
    ndim = self.ndim
    if len(indices) != ndim:
        # Keep the historical message prefix (callers may match on it) and
        # report how many indices were actually supplied.
        raise ValueError(
            "Need to provide %d index in tensor slice, got %d" % (ndim, len(indices))
        )
    indices = convert_to_object(indices)
    args = []
    for x in indices:
        if isinstance(x, _expr.PrimExpr):
            args.append(x)
        elif isinstance(x, _expr.IterVar):
            args.append(x.var)
        else:
            # Name the offending type so a bad index is easy to locate.
            raise ValueError(
                "The indices must be expression, got %s" % type(x).__name__
            )
    return _expr.ProducerLoad(self, args)
def tensor_no_check_call(self, *indices):
    """Index a tensor without validating how many indices were given.

    Mirrors ``tvm.te.Tensor.__call__`` but skips the dimensionality check.
    The indices are first passed through ``convert_to_object``. Each one must
    then be a ``PrimExpr``, or an ``IterVar``, which is unwrapped to its
    ``.var``.

    Returns
    -------
    _expr.ProducerLoad
        A load of ``self`` at the given indices.

    Raises
    ------
    ValueError
        If any index is neither a ``PrimExpr`` nor an ``IterVar``.
    """

    def _to_index(item):
        # Expressions pass through; IterVars contribute their loop variable.
        if isinstance(item, _expr.PrimExpr):
            return item
        if isinstance(item, _expr.IterVar):
            return item.var
        raise ValueError("The indices must be expression")

    converted = convert_to_object(indices)
    return _expr.ProducerLoad(self, [_to_index(item) for item in converted])