예제 #1
0
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Do lower while keeping all axes in IR.

    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
    inject virtual threads.

    Parameters
    ----------
    sch : schedule
        The schedule to lower; ``normalize()`` is called on it first.
    args : list
        Tensors passed to ``build_module.get_binds``.
    binds : optional
        Existing binds forwarded to ``build_module.get_binds``.
    simple_mode : bool
        Must be truthy; only the simple lowering mode is supported.

    Returns
    -------
    stmt
        The statement after ``StorageFlatten`` and ``CanonicalSimplify``.

    Raises
    ------
    ValueError
        If ``simple_mode`` is falsy.
    """
    # Validate up front (before any lowering work) and with an explicit raise,
    # so the check is not silently stripped when Python runs with -O.
    if not simple_mode:
        raise ValueError("ana_lower only supports simple_mode=True")
    binds, _ = build_module.get_binds(args, binds)
    sch = sch.normalize()
    # Phase 0
    bounds = schedule.InferBound(sch)
    # NOTE(review): the third argument is presumably ScheduleOps'
    # "keep trivial loop" flag, which is what keeps extent-1 loops in the IR
    # as the docstring promises -- confirm against this TVM version.
    stmt = schedule.ScheduleOps(sch, bounds, True)
    # 64 is presumably StorageFlatten's cache-line-size argument -- confirm.
    stmt = ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = ir_pass.CanonicalSimplify(stmt)
    return stmt
예제 #2
0
def _lower_schedule(sch, args):
    """Lower a TE schedule into an ``IRModule`` holding a single ``main`` PrimFunc.

    The schedule is normalized, bounds are inferred, and the scheduled
    statement is converted to a PrimFunc whose parameters come from
    ``get_binds``. The function is tagged with ``global_symbol="main"`` and
    ``tir.noalias=True``. No further TIR passes are run here.
    """
    te_sched = tvm.te.schedule
    normalized = sch.normalize()
    body = te_sched.ScheduleOps(normalized, te_sched.InferBound(normalized))

    buffer_binds, params = get_binds(args, te_sched.VerifyCompactBuffer(body), None)
    prim_func = (
        te_sched.SchedulePostProcToPrimFunc(params, body, buffer_binds)
        .with_attr("global_symbol", "main")
        .with_attr("tir.noalias", True)
    )
    return tvm.IRModule({"main": prim_func})
예제 #3
0
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Do lower while keeping all axes in IR.

    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
    inject virtual threads.

    Parameters
    ----------
    sch : schedule
        The schedule to lower; ``normalize()`` is called on it first.
    args : list
        Tensors used as the PrimFunc parameters.
    binds : optional
        Forwarded to ``build_module.get_binds`` (see the note in the body).
    simple_mode : bool
        Must be truthy; only the simple lowering mode is supported.

    Returns
    -------
    stmt
        The body of the ``"main"`` function after ``StorageFlatten`` and
        ``Simplify``.

    Raises
    ------
    ValueError
        If ``simple_mode`` is falsy.
    """
    # Validate up front (before any lowering work) and with an explicit raise,
    # so the check is not silently stripped when Python runs with -O.
    if not simple_mode:
        raise ValueError("ana_lower only supports simple_mode=True")
    # NOTE(review): the binds computed here are not forwarded --
    # SchedulePostProcToPrimFunc below receives None -- so caller-supplied
    # ``binds`` do not affect the result. Confirm this is intended.
    binds, _ = build_module.get_binds(args, binds)
    sch = sch.normalize()
    # Phase 0
    bounds = schedule.InferBound(sch)
    # NOTE(review): the third argument is presumably ScheduleOps'
    # "keep trivial loop" flag that keeps extent-1 loops -- confirm.
    stmt = schedule.ScheduleOps(sch, bounds, True)
    func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
    # ``_move()`` presumably hands the underlying object to the callee; each
    # handle is rebound immediately afterwards and never reused.
    mod = tvm.IRModule.from_expr(func._move())
    mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
    mod = tvm.tir.transform.Simplify()(mod._move())
    return mod["main"].body
예제 #4
0
def lower_tvm_stmt(stmt, args, binds=None, name='main'):
    """Wrap a scheduled TIR statement in an IRModule and run ``build_module.lower`` on it.

    ``name`` is passed through ``slugify`` and used both as the function's
    ``global_symbol`` and as its key in the module. ``tir.noalias`` is set on
    the function unless the current PassContext config disables it (the
    default when the key is absent is True).

    Returns whatever ``build_module.lower`` returns for the built module.
    """
    symbol = slugify(name)
    is_compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = build_module.get_binds(args, compact=is_compact, binds=binds)

    prim_func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds).with_attr(
        'global_symbol', symbol)

    active_config = tvm.ir.transform.PassContext.current().config
    noalias_enabled = active_config.get('tir.noalias', True)
    if noalias_enabled:
        prim_func = prim_func.with_attr('tir.noalias', True)

    return build_module.lower(tvm.IRModule({symbol: prim_func}), args, name=symbol)
예제 #5
0
def lower_ethosu(sch, args, const_dict, name="main"):
    """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.

    The resulting TIR module will contain a single function
    that comprises a sequence of tir.extern_calls to NPU
    operations.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to be lowered.
    args : Union[list or tuple of tvm.te.Tensor, TEGraph]
        The input/output tensors. A TEGraph is flattened into its inputs
        followed by its outputs.
    const_dict : dict of int to numpy.ndarray
        The constant dictionary.
    name : str, optional
        The name of the lowered primitive function.

    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.
    const_dict : dict of int to numpy.ndarray
        The modified constant dictionary.

    """
    # Accept any plain sequence of tensors; anything else is treated as a
    # TEGraph and flattened to inputs + outputs.
    if isinstance(args, (list, tuple)):
        args = list(args)
    else:
        args = list(args.inputs) + list(args.outputs)

    # Config setup: start from the active PassContext's config and override
    # the loop partitioning/unrolling options this pipeline relies on. The
    # merge is shallow -- these entries replace any existing ones wholesale.
    curr_cfg = dict(tvm.ir.transform.PassContext.current().config.items())
    curr_cfg.update(
        {
            "tir.LoopPartition": {
                "partition_const_loop": True,
                "no_unroll_loop_with_extent_one": True,
            },
            "tir.UnrollLoop": {"auto_max_depth": -1},
        }
    )

    sch = sch.normalize()
    bounds = tvm.te.schedule.InferBound(sch)
    # NOTE(review): True presumably keeps trivial (extent-1) loops -- confirm.
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)

    compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, None)
    func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

    func = func.with_attr("global_symbol", name)
    func = func.with_attr("tir.noalias", True)
    mod = tvm.IRModule({name: func})
    # The pass order below is significant; RemoveZeroStores, ReplaceOperators
    # and EncodeConstants are Ethos-U specific passes defined elsewhere.
    with tvm.transform.PassContext(config=curr_cfg):
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.StorageFlatten(64)(mod)
        mod = tvm.tir.transform.UnrollLoop()(mod)
        mod = tvm.tir.transform.LoopPartition()(mod)
        mod = RemoveZeroStores()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ReplaceOperators()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod, const_dict = EncodeConstants(const_dict)(mod)
        mod = tvm.tir.transform.StorageRewrite()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
    return mod, const_dict