def visit_Assign(self, node):
    rhs = self.visit(node.value)
    if isinstance(rhs, Operation):
        # Multiple-output operation: bind every output tensor to a pure name.
        rmap = {}
        _internal_assert(len(node.targets) == rhs.num_outputs, \
                         "Unable to detuple the outs to targets")
        for i in range(rhs.num_outputs):
            _internal_assert(isinstance(node.targets[i], ast.Name),
                             "You should bind a pure name to the tensors")
            self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
            rmap[rhs.outputs[i].op] = rhs.output(i)
        return util.replace_io(rhs.body, rmap)

    _internal_assert(len(node.targets) == 1,
                     "So far only one-valued assignment is supported!")
    lhs = node.targets[0]
    if isinstance(rhs, _expr.PrimExpr):
        rhs = _ir_pass.Simplify(rhs)

    if isinstance(lhs, ast.Name):
        #TODO: support defined intermediate buffer later
        lhs_ = lhs
        lhs = lhs.id
        if lhs in self.symbols.keys():
            ty, _ = self.symbols[lhs]
            _internal_assert(ty != Symbol.LoopVar, \
                             "Loop variable cannot be overwritten!")
        decl, _, rw = self.usage[lhs]
        if decl == lhs_:
            # First definition of this name: decide how to materialize it.
            _internal_assert(lhs not in self.symbols.keys(),
                             "This value should not be defined before this point!")
            if isinstance(rhs, tuple):
                # A (shape, dtype, scope) tuple on the RHS declares a buffer.
                shape, dtype, scope = rhs
                ph = tvm.te.placeholder(shape, dtype=dtype, name=lhs)
                self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
                if scope == 'output':
                    self.outputs.append(lhs)
                return util.make_nop()
            if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
                # A read-only immediate can be kept as a constant.
                self.add_symbol(lhs, Symbol.ConstVar, rhs)
            else:
                _internal_assert(self.device == 0,
                                 "Single variable not supported in devices' side!\n" + \
                                 "If you are using GPU, please allocate a 'local' spad " + \
                                 "outside the bind body")
                ph = tvm.te.placeholder((1, ), dtype=rhs.dtype, name=lhs)
                self.add_symbol(lhs, Symbol.BufferVar, ph)
        lhs = self.visit(lhs_)
        if lhs is not None:
            buf, args = lhs
            return tvm.tir.Provide(buf.op, 0, rhs, args)
        return util.make_nop()

    # Subscript LHS: store into an element of an existing tensor.
    lhs, args = self.visit(lhs)
    _internal_assert(isinstance(lhs, Tensor), \
                     "An array access's LHS is expected to be an expr.Call!")
    res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args)
    return res
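
# A minimal sketch of the assignment forms the visitor above distinguishes,
# assuming the standard tvm.te.hybrid frontend that this parser backs (the
# function name, shapes, and dtype below are hypothetical):
#
# from tvm.te import hybrid
#
# @hybrid.script
# def add_bias(a):
#     out = output_tensor((16,), 'float32')  # tuple RHS: declares a buffer
#     bias = 1.0                             # immediate RHS: bound as ConstVar
#     for i in range(16):
#         out[i] = a[i] + bias               # subscript LHS: becomes a Provide
#     return out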
def get_const_int(exp):
    """Verifies expr is an integer and gets the constant value.

    Parameters
    ----------
    exp : tvm.Expr or int
        The input expression.

    Returns
    -------
    out_value : int
        The output.
    """
    if isinstance(exp, int):
        return exp
    if not isinstance(exp, (expr.IntImm,)):
        # Try to fold the expression down to a constant before giving up.
        exp = ir_pass.Simplify(exp)
    if not isinstance(exp, (expr.IntImm,)):
        raise ValueError("Expect value to be constant int")
    return exp.value
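
# A brief usage sketch (the values are hypothetical): both a plain Python int
# and a TVM expression that folds to a constant are accepted.
#
# import tvm
#
# assert get_const_int(8) == 8
# assert get_const_int(tvm.tir.const(4, "int32") + 4) == 8  # folded to IntImm(8)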
def get_const_tuple(in_tuple):
    """Verifies the input tuple is IntImm or Var, returns a tuple of int or Var.

    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of int
        The output.
    """
    ret = []
    for elem in in_tuple:
        if isinstance(elem, expr.Var):
            # Symbolic dimension: keep the Var as-is.
            ret.append(elem)
        elif not isinstance(elem, (expr.IntImm, int)):
            # Try to fold the expression down to a constant first.
            elem = ir_pass.Simplify(elem)
            if not isinstance(elem, expr.IntImm):
                ret.append(elem)
            else:
                ret.append(get_const_int(elem))
        else:
            ret.append(get_const_int(elem))
    return tuple(ret)
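
# A brief usage sketch (the shape is hypothetical): constant dimensions fold
# to Python ints while a symbolic dimension passes through as a Var.
#
# import tvm
#
# n = tvm.te.var("n")
# A = tvm.te.placeholder((n, 4, 16), name="A")
# print(get_const_tuple(A.shape))  # -> (n, 4, 16), with n kept as a Var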
def lower(sch, args, name="default_function", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built.

    args : list of Buffer or Tensor or Var
        The argument list of the function.

    name : str, optional
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps a Tensor to the Buffer that specifies the data
        layout requirement of the function. By default, a new compact buffer
        is created for each tensor in the arguments.

    simple_mode : bool, optional
        Whether to output only a simple and compact statement; this skips
        LoopPartition, unrolling and the API wrapper generation.

    Returns
    -------
    m : IRModule or Stmt
        The result IRModule if simple_mode=False; otherwise the Stmt before
        the packed-API step is returned.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
    # Custom passes are grouped by the phase after which they run.
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.NarrowDataType(stmt, 32)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(
        stmt,
        cfg.auto_unroll_max_step,
        cfg.auto_unroll_max_depth,
        cfg.auto_unroll_max_extent,
        cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)

    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)

    if simple_mode:
        return stmt

    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
        "global_symbol", tvm.runtime.String(name))
    if cfg.restricted_func:
        f = f.with_attr("tir.noalias", True)
    mod = tvm.IRModule({name: f})
    return tvm.tir.transform.MakePackedAPI()(mod)
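
# A short usage sketch, assuming the te-based flow this lower() serves
# (tensor names and shapes are hypothetical):
#
# import tvm
# from tvm import te
#
# n = te.var("n")
# A = te.placeholder((n,), name="A")
# B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
# s = te.create_schedule(B.op)
#
# # simple_mode=True returns the Stmt right before the packed-API step,
# # which is convenient for inspecting what the lowering phases produced.
# print(lower(s, [A, B], simple_mode=True))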