def ana_lower(sch, args, binds=None, simple_mode=True):
    """Lower a schedule for analysis while keeping every axis in the IR.

    Only StorageFlatten and Simplify run after scheduling. No vectorize,
    unroll or virtual-thread injection pass is applied, and (per the
    original intent) loops with extent 1 are not eliminated.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        Schedule to lower.
    args : list
        Arguments of the function being lowered.
    binds : dict, optional
        Forwarded to ``build_module.get_binds``; the binding it returns is
        not used afterwards.
    simple_mode : bool
        Must be truthy; a falsy value fails an ``assert`` after the
        lowering work has been done.

    Returns
    -------
    The body statement of the lowered ``main`` function.
    """
    # NOTE(review): the result of get_binds is discarded and
    # SchedulePostProcToPrimFunc below receives None for binds -- confirm
    # this is intentional.
    build_module.get_binds(args, binds)

    # Phase 0: normalize, infer bounds, schedule ops.
    normalized = sch.normalize()
    ranges = schedule.InferBound(normalized)
    # Third argument True presumably keeps trivial loops
    # (debug_keep_trivial_loop) -- TODO confirm against TVM signature.
    body = schedule.ScheduleOps(normalized, ranges, True)
    prim_func = schedule.SchedulePostProcToPrimFunc(args, body, None)

    ir_mod = tvm.IRModule.from_expr(prim_func._move())
    # 64 is presumably the cache line size for StorageFlatten -- verify.
    flatten = tvm.tir.transform.StorageFlatten(64)
    ir_mod = flatten(ir_mod._move())
    simplify = tvm.tir.transform.Simplify()
    ir_mod = simplify(ir_mod._move())

    assert simple_mode
    return ir_mod["main"].body
def lower_tvm_stmt(stmt, args, binds=None, name='main'):
    """Wrap an already-scheduled statement in a PrimFunc and lower it.

    Parameters
    ----------
    stmt :
        Statement that becomes the function body.
    args : list
        Arguments of the function.
    binds : dict, optional
        Binding information forwarded to ``build_module.get_binds``.
    name : str
        Function name. It is passed through ``slugify`` before being used
        as the ``global_symbol`` attribute and as the module key.

    Returns
    -------
    Whatever ``build_module.lower`` returns for the single-function
    IRModule built here.
    """
    symbol = slugify(name)
    is_compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = build_module.get_binds(
        args, compact=is_compact, binds=binds
    )

    prim_func = schedule.SchedulePostProcToPrimFunc(
        arg_list, stmt, binds
    ).with_attr('global_symbol', symbol)

    # noalias is on by default; the active PassContext config may turn it off.
    config = tvm.ir.transform.PassContext.current().config
    if config.get('tir.noalias', True):
        prim_func = prim_func.with_attr('tir.noalias', True)

    ir_mod = tvm.IRModule({symbol: prim_func})
    return build_module.lower(ir_mod, args, name=symbol)
def form_irmodule(sch, args, name, binds=None):
    """According to the given schedule, form an IRModule with one function.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The given schedule to form the raw body from.
    args : list of Buffer or Tensor or Var
        The argument list of the function.
    name : str
        Name of the resulting function. It is also used as the function's
        ``global_symbol`` attribute and as its key in the returned module.
    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information. Defaults to ``None``, which is forwarded to
        ``get_binds`` unchanged.

    Returns
    -------
    tvm.IRModule
        A module mapping ``name`` to the formed PrimFunc. The function is
        tagged ``tir.noalias=True`` unless the current PassContext config
        sets ``tir.noalias`` to a falsy value.
    """
    pass_ctx = PassContext.current()
    # Normalize the schedule first, before bound inference.
    sch = sch.normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)
    compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)
    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
    func = func.with_attr("global_symbol", name)
    # noalias is on by default; the active PassContext may switch it off.
    if pass_ctx.config.get("tir.noalias", True):
        func = func.with_attr("tir.noalias", True)
    return tvm.IRModule({name: func})