def visit_For(self, node):
    iter_var, low, ext, for_type = self.visit(node.iter)
    _internal_assert(isinstance(node.target, ast.Name),
                     "The loop iterator should be a variable!")
    _name = node.target.id

    if isinstance(for_type, tuple):
        # const_range: the loop is unrolled at parse time, so both bounds
        # must fold to constants.
        low = _ir_pass.CanonicalSimplify(low)
        ext = _ir_pass.CanonicalSimplify(ext)
        _internal_assert(isinstance(low, _expr.ConstExpr) and
                         isinstance(ext, _expr.ConstExpr),
                         "A const range should start from a const "
                         "and iterate a const number of times")
        low, ext = low.value, ext.value
        if ext > 114514:
            logging.log(logging.CRITICAL,
                        '[Warning] Are you sure you want to unroll a large loop in Python?')
        bodies = []
        for i in range(low, low + ext):
            self.add_symbol(_name, Symbol.ConstLoopVar, i)
            body = visit_list_to_block(self.visit, node.body)
            body = self.wrap_up_realize(node, body)
            bodies.append(body)
            self.symbols.pop(_name)
        return concat_list_to_block(bodies)

    if iter_var is None:
        # range: create a fresh loop variable; offset it when the range
        # does not start at zero.
        _internal_assert(for_type is not None,
                         "Failed to parse the loop iterating function!")
        offset = iter_var = tvm.te.var(_name)
        if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, 'int32')):
            offset = iter_var + low
        self.add_symbol(_name, Symbol.LoopVar, offset)
        _body = visit_list_to_block(self.visit, node.body)
    else:
        # bind: the iteration is attached to a thread axis instead of a loop.
        _internal_assert(for_type is None,
                         "Failed to parse the loop bind function!")
        self.add_symbol(_name, Symbol.ThreadBind, iter_var)
        self.device += 1
        _body = visit_list_to_block(self.visit, node.body)
        self.device -= 1

    _body = self.wrap_up_realize(node, _body)

    if for_type is None:
        res = _body
    else:
        _internal_assert(not isinstance(for_type, tuple),
                         "Macro expansion should be handled before!")
        res = tvm.tir.For(iter_var, tvm.runtime.const(0, 'int32'),
                          ext, for_type, 0, _body)

    self.symbols.pop(_name)
    return res
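
# Hedged usage sketch: the loop forms visit_For distinguishes, written in TVM
# hybrid script (assuming the tvm.te.hybrid frontend of the same TVM
# generation as this parser). A plain `range` becomes a serial tir.For via
# Symbol.LoopVar, while `const_range` takes the tuple-typed for_type branch
# and is unrolled at parse time; `bind(...)` would take the ThreadBind branch.
# Illustrative only, not a verified test case; `b` is assumed to have shape (4,).
from tvm.te import hybrid

@hybrid.script
def outer_product(a, b):
    c = output_tensor((a.shape[0], 4), 'float32')
    for i in range(a.shape[0]):      # Symbol.LoopVar -> serial tir.For
        for j in const_range(4):     # Symbol.ConstLoopVar -> unrolled at parse time
            c[i, j] = a[i] * b[j]
    return c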
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Lower the schedule while keeping all axes in the IR,
    i.e. do not eliminate loops with an extent of 1, and do not
    vectorize, unroll, or inject virtual threads.
    """
    binds, _ = build_module.get_binds(args, binds=binds)
    sch = sch.normalize()
    # Phase 0
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds, True)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = ir_pass.CanonicalSimplify(stmt)
    assert simple_mode
    return stmt
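
# Hedged usage sketch for ana_lower: because nothing beyond CanonicalSimplify
# runs, even a unit-extent axis survives in the returned Stmt, which is what
# "keeping all axes in IR" buys when re-discovering loop structure. The te
# workload below is illustrative; `ana_lower` is assumed to be in scope.
import tvm
from tvm import te

n = te.size_var('n')
A = te.placeholder((n, 1), name='A')                          # note the extent-1 axis
B = te.compute((n, 1), lambda i, j: A[i, j] + 1.0, name='B')
s = te.create_schedule(B.op)
# stmt = ana_lower(s, [A, B])  # the extent-1 loop over j is kept in stmt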
def visit_If(self, node):
    cond = _ir_pass.CanonicalSimplify(self.visit(node.test))

    # Return no IfThenElse if proven
    if isinstance(cond, _expr.IntImm):
        if cond.value:
            return visit_list_to_block(self.visit, node.body)
        if node.orelse:
            return visit_list_to_block(self.visit, node.orelse)
        return util.make_nop()

    if_body = visit_list_to_block(self.visit, node.body)

    if node.orelse:
        else_body = visit_list_to_block(self.visit, node.orelse)
    else:
        else_body = None
    return tvm.tir.IfThenElse(cond, if_body, else_body)
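
# Hedged sketch tying visit_For and visit_If together: inside a const_range
# loop the iterator is a Python int (Symbol.ConstLoopVar), so the test below
# canonicalizes to an IntImm for each unrolled iteration and visit_If emits
# only the taken branch -- no tir.IfThenElse appears in the output. Assumes
# the tvm.te.hybrid frontend; illustrative, not a verified test case.
from tvm.te import hybrid

@hybrid.script
def even_only(a):
    b = output_tensor((4,), 'float32')
    for i in const_range(4):
        if i % 2 == 0:       # constant per iteration after unrolling
            b[i] = a[i]
        else:
            b[i] = 0.0
    return b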
def lower(sch, args, name="default_function", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built.

    args : list of Buffer or Tensor or Var
        The argument list of the function.

    name : str, optional
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps each Tensor to a Buffer, specifying the
        data layout requirements of the function. By default, a new
        compact buffer is created for each tensor in the arguments.

    simple_mode : bool, optional
        Whether to output only a simple and compact statement; this
        skips LoopPartition, API wrapper generation, and unrolling.

    Returns
    -------
    m : IRModule or Stmt
        The result IRModule if simple_mode=False; otherwise the Stmt
        before the packed-API wrapper is generated.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.NarrowDataType(stmt, 32)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(stmt,
                              cfg.auto_unroll_max_step,
                              cfg.auto_unroll_max_depth,
                              cfg.auto_unroll_max_extent,
                              cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)

    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)

    if simple_mode:
        return stmt

    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
        "global_symbol", tvm.runtime.String(name))
    if cfg.restricted_func:
        f = f.with_attr("tir.noalias", True)
    mod = tvm.IRModule({name: f})
    return tvm.tir.transform.MakePackedAPI()(mod)
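
# Hedged usage sketch: driving the pipeline above on a trivial elementwise
# schedule. With simple_mode=True the Stmt from Phase 3 is returned, which is
# the usual way to inspect what the passes produced; with the default
# simple_mode=False an IRModule processed by MakePackedAPI comes back.
# Assumes this function is exposed as tvm.lower, as in mainline TVM.
import tvm
from tvm import te

n = te.size_var('n')
A = te.placeholder((n,), name='A')
B = te.compute((n,), lambda i: A[i] * 2.0, name='B')
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))   # dumps the lowered Stmt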