def ana_lower(sch, args,
              binds=None,
              simple_mode=True):
    """Lower the schedule while keeping all axes in the IR,
    i.e. do not eliminate loops with an extent of 1, and do not
    vectorize, unroll, or inject virtual threads.
    """
    binds, _ = build_module.get_binds(args, binds)
    sch = sch.normalize()
    # Phase 0
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds, True)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = ir_pass.CanonicalSimplify(stmt)
    assert simple_mode
    return stmt
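
# Illustrative sketch (comment only): how `ana_lower` might be called on a
# trivial elementwise stage. The tensors, shapes, and names below are
# assumptions chosen for the example and are not part of this module:
#
#     from tvm import te
#     n = te.var("n")
#     A = te.placeholder((n,), name="A")
#     B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
#     s = te.create_schedule(B.op)
#     stmt = ana_lower(s, [A, B])   # loop IR with every axis kept
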
def lower(sch, args, name="default_function", binds=None, simple_mode=False): """Lowering step before build into target. Parameters ---------- sch : tvm.te.schedule.Schedule The schedule to be built args : list of Buffer or Tensor or Var The argument lists to the function. name : str, optional The name of result function. binds : dict of :any:`Tensor` to :any:`Buffer`, optional Dictionary that maps the Tensor to Buffer which specified the data layout requirement of the function. By default, a new compact buffer is created for each tensor in the argument. simple_mode : bool, optional Whether only output simple and compact statement, this will skip LoopPartition, api wrapper generation and Unrolling. Returns ------- m : IRModule or Stmt The result IRModule, if simple_mode=False Then the Stmt before make api is returned. """ cfg = BuildConfig.current() add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] if cfg.dump_pass_ir: add_lower_pass = BuildConfig._dump_ir.decorate_custompass( add_lower_pass) lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] # Phase 0 if isinstance(sch, schedule.Schedule): stmt = form_body(sch) for f in lower_phase0: stmt = f(stmt) compact = ir_pass.VerifyCompactBuffer(stmt) binds, arg_list = get_binds(args, compact, binds) # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) stmt = ir_pass.NarrowDataType(stmt, 32) stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase1: stmt = f(stmt) # Phase 2 if not simple_mode: stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) if cfg.disable_vectorize: stmt = ir_pass.SkipVectorize(stmt) else: stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) stmt = ir_pass.StorageRewrite(stmt) stmt = ir_pass.UnrollLoop(stmt, cfg.auto_unroll_max_step, cfg.auto_unroll_max_depth, cfg.auto_unroll_max_extent, cfg.unroll_explicit) for f in lower_phase2: stmt = f(stmt) # Phase 3 stmt = ir_pass.Simplify(stmt) stmt = ir_pass.RemoveNoOp(stmt) if not cfg.disable_select_rewriting: stmt = ir_pass.RewriteUnsafeSelect(stmt) for f in lower_phase3: stmt = f(stmt) # Instrument BoundCheckers if cfg.instrument_bound_checkers: stmt = ir_pass.InstrumentBoundCheckers(stmt) if simple_mode: return stmt f = tvm.tir.PrimFunc(arg_list, stmt).with_attr("global_symbol", tvm.runtime.String(name)) if cfg.restricted_func: f = f.with_attr("tir.noalias", True) mod = tvm.IRModule({name: f}) return tvm.tir.transform.MakePackedAPI()(mod)