def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
    """Helper to assign pool infos to the allocate nodes of a tir.PrimFunc.

    Each ``tvm.tir.Allocate`` found in ``primfunc.body`` is rebuilt with the
    same buffer var, dtype, extents, condition and body. Its annotations are
    set to a dict holding only the candidate-memory-pool attribute, mapped to
    ``pool_infos``; annotations already on the node are not carried over.

    Parameters
    ----------
    primfunc
        The function whose body is transformed.
    pool_infos
        The value stored under ``CANDIDATE_MEMORY_POOL_ATTR`` on every
        rebuilt allocate node.

    Returns
    -------
    The result of ``primfunc.with_body`` applied to the transformed body.
    """

    def _annotate_allocate(node):
        # Only allocate nodes are rewritten; every other node yields None.
        if not isinstance(node, tvm.tir.Allocate):
            return None
        pool_annotation = {tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}
        return tvm.tir.Allocate(
            buffer_var=node.buffer_var,
            dtype=node.dtype,
            extents=node.extents,
            condition=node.condition,
            body=node.body,
            annotations=pool_annotation,
        )

    # The rewrite callback is passed as the post-order argument (pre-order is None).
    annotated_body = stmt_functor.ir_transform(primfunc.body, None, _annotate_allocate)
    return primfunc.with_body(annotated_body)
def replace_io(body, rmap):
    """Replace tensor usage in ``body`` according to the given dict.

    Parameters
    ----------
    body
        The statement passed to ``stmt_functor.ir_transform``.
    rmap
        Dict keyed by the ``op`` of a producer. A ``ProducerStore`` or
        ``ProducerLoad`` whose ``producer.op`` is a key is rebuilt with the
        mapped value as its producer.

    Returns
    -------
    The result of ``stmt_functor.ir_transform`` on ``body``.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.tir import stmt_functor

    def _swap_producer(node):
        # Stores keep their value and indices; only the producer is swapped.
        if isinstance(node, _stmt.ProducerStore):
            if node.producer.op in rmap:
                return _stmt.ProducerStore(rmap[node.producer.op], node.value, node.indices)
        # Loads keep their indices; only the producer is swapped.
        if isinstance(node, _expr.ProducerLoad):
            if node.producer.op in rmap:
                return _expr.ProducerLoad(rmap[node.producer.op], node.indices)
        return None

    # Node type names handed to ir_transform as its fourth argument.
    enabled_nodes = ["tir.ProducerStore", "tir.ProducerLoad"]
    return stmt_functor.ir_transform(body, None, _swap_producer, enabled_nodes)