Ejemplo n.º 1
0
def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
    """Helper to assign pool infos to every Allocate node in a tir.PrimFunc.

    Parameters
    ----------
    primfunc : tvm.tir.PrimFunc
        The function whose body is rewritten.
    pool_infos
        Value stored under ``tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR``
        on each ``Allocate`` node (presumably a list of pool info objects --
        confirm against callers).

    Returns
    -------
    tvm.tir.PrimFunc
        ``primfunc`` with a body in which every ``Allocate`` carries the
        candidate-pool annotation.
    """

    def set_poolinfos(stmt):
        if not isinstance(stmt, tvm.tir.Allocate):
            # Only Allocate nodes are rewritten; None means "no replacement".
            return None
        # Keep any annotations already present on the node and only (re)assign
        # the candidate-pool key, rather than dropping the others.
        annotations = dict(stmt.annotations)
        annotations[tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR] = pool_infos
        return tvm.tir.Allocate(
            buffer_var=stmt.buffer_var,
            dtype=stmt.dtype,
            extents=stmt.extents,
            condition=stmt.condition,
            body=stmt.body,
            annotations=annotations,
            # Preserve the source location of the original node.
            span=stmt.span,
        )

    return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
Ejemplo n.º 2
0
def replace_io(body, rmap):
    """Rewrite producer stores/loads in ``body`` according to ``rmap``.

    Any ``ProducerStore`` or ``ProducerLoad`` whose ``producer.op`` is a key of
    ``rmap`` is rebuilt with ``rmap[producer.op]`` as its producer. For every
    other node the callback returns ``None``, i.e. provides no replacement.

    Parameters
    ----------
    body
        The IR statement to transform.
    rmap
        Mapping from an original producer operation to the replacement
        producer to use in its place.

    Returns
    -------
    The result of ``stmt_functor.ir_transform`` on ``body``.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.tir import stmt_functor

    # Restrict the traversal callback to producer stores and loads.
    visited_kinds = ["tir.ProducerStore", "tir.ProducerLoad"]

    def _swap_producer(node):
        if isinstance(node, _stmt.ProducerStore):
            if node.producer.op in rmap.keys():
                new_producer = rmap[node.producer.op]
                return _stmt.ProducerStore(new_producer, node.value, node.indices)
        elif isinstance(node, _expr.ProducerLoad):
            if node.producer.op in rmap.keys():
                new_producer = rmap[node.producer.op]
                return _expr.ProducerLoad(new_producer, node.indices)
        # Nothing to replace for this node.
        return None

    return stmt_functor.ir_transform(body, None, _swap_producer, visited_kinds)