def ana_lower(sch, args, binds=None, simple_mode=True):
    """Do lower while keeping all axes in IR.

    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
    inject virtual threads.

    Parameters
    ----------
    sch : schedule
        The schedule to lower; it is normalized before bound inference.
    args : list
        The arguments (tensors/buffers) handed to ``build_module.get_binds``.
    binds : dict, optional
        Caller-supplied bindings forwarded to ``get_binds``; the resolved
        bindings are then used by ``StorageFlatten``.
    simple_mode : bool
        Only ``True`` is supported; any falsy value trips the assertion.

    Returns
    -------
    stmt
        The flattened and canonically simplified statement.
    """
    # Only the simple mode is supported; check it before doing any lowering
    # work instead of after it.
    assert simple_mode
    # Pass ``binds`` by keyword. With a ``get_binds(args, compact, binds)``
    # signature, a second positional argument lands in ``compact`` and the
    # caller's bindings would be silently dropped.
    binds, _ = build_module.get_binds(args, binds=binds)
    sch = sch.normalize()
    # Phase 0
    bounds = schedule.InferBound(sch)
    # Third argument ``True``: presumably the "keep trivial loops" flag that
    # implements the "do not eliminate loop with extent of 1" contract above
    # -- TODO confirm against the ScheduleOps signature in use.
    stmt = schedule.ScheduleOps(sch, bounds, True)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = ir_pass.CanonicalSimplify(stmt)
    return stmt
def _lower_schedule(sch, args):
    """Lower a TE schedule into an IRModule holding one PrimFunc named "main".

    The schedule is normalized, bounds are inferred, and the resulting
    statement is wrapped into a PrimFunc whose buffers come from
    ``get_binds`` (compaction decided by ``VerifyCompactBuffer``).  The
    function is tagged with ``global_symbol="main"`` and ``tir.noalias``.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to lower.
    args : list
        The arguments forwarded to ``get_binds``.

    Returns
    -------
    tvm.IRModule
        A module whose only entry is ``"main"``.
    """
    te_sched = tvm.te.schedule
    normalized = sch.normalize()
    body = te_sched.ScheduleOps(normalized, te_sched.InferBound(normalized))
    is_compact = te_sched.VerifyCompactBuffer(body)
    buffer_binds, params = get_binds(args, is_compact, None)
    prim_func = (
        te_sched.SchedulePostProcToPrimFunc(params, body, buffer_binds)
        .with_attr("global_symbol", "main")
        .with_attr("tir.noalias", True)
    )
    return tvm.IRModule({"main": prim_func})
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Do lower while keeping all axes in IR.

    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
    inject virtual threads.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to lower; it is normalized before bound inference.
    args : list
        The arguments (tensors/buffers/vars) of the produced PrimFunc.
    binds : dict, optional
        Caller-supplied tensor-to-buffer bindings.  They are now honored:
        previously they were passed into ``get_binds``'s ``compact`` slot and
        the result was thrown away.
    simple_mode : bool
        Only ``True`` is supported; any falsy value trips the assertion.

    Returns
    -------
    tvm.tir.Stmt
        The body of the flattened and simplified ``"main"`` PrimFunc.
    """
    # Only the simple mode is supported; check it before doing any work.
    assert simple_mode
    sch = sch.normalize()
    # Phase 0
    bounds = schedule.InferBound(sch)
    # Third argument ``True``: presumably the "keep trivial loops" flag that
    # implements the contract in the docstring -- TODO confirm.
    stmt = schedule.ScheduleOps(sch, bounds, True)
    # Resolve buffers the same way the regular lowering path does: decide
    # compaction from the statement, then bind by keyword so ``binds`` does
    # not end up in the ``compact`` parameter.
    compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = build_module.get_binds(args, compact=compact, binds=binds)
    # ``arg_list`` (not the raw ``args``) must accompany ``binds`` so tensors
    # are mapped to the buffers that ``get_binds`` resolved.
    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
    # No global_symbol attr is set, so the module entry is named "main".
    mod = tvm.IRModule.from_expr(func._move())
    mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
    mod = tvm.tir.transform.Simplify()(mod._move())
    return mod["main"].body
def lower_tvm_stmt(stmt, args, binds=None, name='main'):
    """Wrap an already-scheduled TIR statement into a module and lower it.

    The function name is passed through ``slugify`` and used both as the
    PrimFunc's ``global_symbol`` and as the module key.  ``tir.noalias`` is
    attached unless the current PassContext config sets ``'tir.noalias'`` to
    a falsy value (the config lookup defaults to ``True``).

    Parameters
    ----------
    stmt : tvm.tir.Stmt
        The statement forming the function body.
    args : list
        The arguments forwarded to ``get_binds`` and ``build_module.lower``.
    binds : dict, optional
        Caller-supplied bindings forwarded to ``get_binds``.
    name : str
        Requested function name (slugified before use).

    Returns
    -------
    The result of ``build_module.lower`` on the constructed module.
    """
    symbol = slugify(name)
    is_compact = schedule.VerifyCompactBuffer(stmt)
    buffer_binds, params = build_module.get_binds(args, compact=is_compact, binds=binds)
    prim_func = schedule.SchedulePostProcToPrimFunc(params, stmt, buffer_binds).with_attr(
        'global_symbol', symbol
    )
    current_cfg = tvm.ir.transform.PassContext.current().config
    if current_cfg.get('tir.noalias', True):
        prim_func = prim_func.with_attr('tir.noalias', True)
    return build_module.lower(tvm.IRModule({symbol: prim_func}), args, name=symbol)
def lower_ethosu(sch, args, const_dict, name="main"):
    """Lower a TE schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.

    The returned module holds a single primitive function made up of a
    sequence of tir.extern_calls to NPU operations.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to be lowered.
    args : Union[list of tvm.te.Tensor, TEGraph]
        The input/output tensors.  Anything that is not a ``list`` is
        flattened to ``list(args.inputs) + list(args.outputs)``.
    const_dict : dict of int to numpy.ndarray
        The constant dictionary.
    name : str, optional
        The name (and global symbol) of the lowered primitive function.

    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.
    const_dict : dict of int to numpy.ndarray
        The constant dictionary as returned by ``EncodeConstants``.
    """
    if not isinstance(args, list):
        args = list(args.inputs) + list(args.outputs)

    # Start from the ambient pass configuration and force the loop-partition /
    # unroll settings below; on a key clash the NPU-specific entry wins.
    ambient_cfg = dict(tvm.ir.transform.PassContext.current().config.items())
    npu_overrides = {
        "tir.LoopPartition": {
            "partition_const_loop": True,
            "no_unroll_loop_with_extent_one": True,
        },
        "tir.UnrollLoop": {"auto_max_depth": -1},
    }
    merged_cfg = {**ambient_cfg, **npu_overrides}

    # Schedule -> single PrimFunc -> IRModule keyed by ``name``.
    te_sched = tvm.te.schedule
    normalized = sch.normalize()
    body = te_sched.ScheduleOps(normalized, te_sched.InferBound(normalized), True)
    buffer_binds, params = get_binds(args, te_sched.VerifyCompactBuffer(body), None)
    prim_func = (
        te_sched.SchedulePostProcToPrimFunc(params, body, buffer_binds)
        .with_attr("global_symbol", name)
        .with_attr("tir.noalias", True)
    )
    mod = tvm.IRModule({name: prim_func})

    # The pass order is significant; every pass runs under the merged config.
    tir = tvm.tir.transform
    with tvm.transform.PassContext(config=merged_cfg):
        for tir_pass in (
            tir.Simplify(),
            tir.StorageFlatten(64),
            tir.UnrollLoop(),
            tir.LoopPartition(),
            RemoveZeroStores(),
            tir.Simplify(),
            tir.RemoveNoOp(),
            ReplaceOperators(),
            tir.RemoveNoOp(),
        ):
            mod = tir_pass(mod)
        # EncodeConstants also yields the updated constant dictionary, so it
        # cannot join the uniform pass loops.
        mod, const_dict = EncodeConstants(const_dict)(mod)
        for tir_pass in (tir.StorageRewrite(), tir.RemoveNoOp()):
            mod = tir_pass(mod)

    return mod, const_dict