def ana_lower(sch, args, binds=None, simple_mode=True):
    """Lower a schedule while keeping every axis in the resulting IR.

    Loops with an extent of 1 are not eliminated, and no vectorization,
    unrolling or virtual-thread injection is performed.

    Parameters
    ----------
    sch : schedule
        The schedule to lower. It is normalized before bounds are inferred.
    args : list
        Forwarded unchanged to ``build_module.get_binds``.
    binds : optional
        Forwarded to ``build_module.get_binds``; the first element of its
        result is what ``StorageFlatten`` receives.
    simple_mode : bool
        Must be truthy. It is only checked with ``assert`` after the
        lowering steps have already run.

    Returns
    -------
    The statement produced by ``ir_pass.CanonicalSimplify``.
    """
    bind_map, _ = build_module.get_binds(args, binds)
    normalized = sch.normalize()

    # Phase 0
    inferred_bounds = schedule.InferBound(normalized)
    # NOTE(review): the trailing ``True`` presumably asks ScheduleOps to keep
    # trivial loops, matching the docstring -- confirm against its signature.
    lowered = schedule.ScheduleOps(normalized, inferred_bounds, True)
    flattened = ir_pass.StorageFlatten(lowered, bind_map, 64)
    simplified = ir_pass.CanonicalSimplify(flattened)

    # Only the simple mode is supported by this helper.
    assert simple_mode
    return simplified
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Lower a schedule while keeping every axis in the resulting IR.

    Loops with an extent of 1 are not eliminated, and no vectorization,
    unrolling or virtual-thread injection is performed.

    Parameters
    ----------
    sch : schedule
        The schedule to lower. It is normalized before bounds are inferred.
    args : list
        Forwarded to ``build_module.get_binds`` and used as the argument
        list of the generated PrimFunc.
    binds : optional
        Forwarded to ``build_module.get_binds``.
    simple_mode : bool
        Must be truthy. It is only checked with ``assert`` after the
        lowering steps have already run.

    Returns
    -------
    The ``body`` of the ``"main"`` function of the lowered module.
    """
    # The returned bindings are not consumed below (the PrimFunc is built
    # with ``None`` binds); the call is kept as-is.
    # NOTE(review): presumably retained for its argument checking -- confirm.
    _bind_map, _ = build_module.get_binds(args, binds)
    normalized = sch.normalize()

    # Phase 0
    inferred_bounds = schedule.InferBound(normalized)
    # NOTE(review): the trailing ``True`` presumably asks ScheduleOps to keep
    # trivial loops, matching the docstring -- confirm against its signature.
    lowered = schedule.ScheduleOps(normalized, inferred_bounds, True)
    prim_func = schedule.SchedulePostProcToPrimFunc(args, lowered, None)
    module = tvm.IRModule.from_expr(prim_func._move())

    # Build each pass right before applying it to the module.
    flatten_pass = tvm.tir.transform.StorageFlatten(64)
    module = flatten_pass(module._move())
    simplify_pass = tvm.tir.transform.Simplify()
    module = simplify_pass(module._move())

    # Only the simple mode is supported by this helper.
    assert simple_mode
    return module["main"].body
def form_body(sch):
    """Form the raw body described by the given schedule.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule from which the raw body is formed.

    Returns
    -------
    The statement obtained by running ``ScheduleOps`` on the normalized
    schedule and then applying ``ir_pass.InjectPrefetch`` to it.
    """
    # The schedule has to be normalized before bounds can be inferred.
    normalized = sch.normalize()
    inferred_bounds = schedule.InferBound(normalized)
    raw_body = schedule.ScheduleOps(normalized, inferred_bounds)
    body_with_prefetch = ir_pass.InjectPrefetch(raw_body)
    return body_with_prefetch
def form_irmodule(sch, args, name, binds):
    """Form an IRModule holding one function built from the given schedule.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule from which the function body is formed.

    args : list of Buffer or Tensor or Var
        The argument list of the function.

    name : str
        The name of the resulting function. It is used both as the
        ``global_symbol`` attribute and as the key in the returned module.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information.

    Returns
    -------
    tvm.IRModule
        A module mapping ``name`` to the generated function.
    """
    current_ctx = PassContext.current()

    # Normalize the schedule first, then lower it to a statement.
    normalized = sch.normalize()
    inferred_bounds = schedule.InferBound(normalized)
    body = schedule.ScheduleOps(normalized, inferred_bounds)

    # Whether the buffers are compact decides how the binds are created.
    is_compact = schedule.VerifyCompactBuffer(body)
    bind_map, func_args = get_binds(args, is_compact, binds)

    body = schedule.SchedulePostProcRewriteForTensorCore(body, normalized, bind_map)
    prim_func = schedule.SchedulePostProcToPrimFunc(func_args, body, bind_map)
    prim_func = prim_func.with_attr("global_symbol", name)

    # "tir.noalias" defaults to True when the pass context does not set it.
    wants_noalias = current_ctx.config.get("tir.noalias", True)
    if wants_noalias:
        prim_func = prim_func.with_attr("tir.noalias", True)

    return tvm.IRModule({name: prim_func})