def parse_for(self, node, parent):
    with self._for_loop_vars(node) as (iter_var, c_var, extent_var,
                                       lower, upper, step, for_type):
        # Normalize the loop to start at zero: iterate iter_var over
        # [0, extent) and rebind the original loop variable as
        # c_var = iter_var * step + lower.
        extent = tir.FloorDiv(tir.Sub(upper, lower), step)
        return tir.LetStmt(
            extent_var, extent,
            tir.For(
                iter_var, tir.IntImm('int32', 0), extent_var, for_type,
                tir.LetStmt(
                    c_var, tir.Add(tir.Mul(iter_var, step), lower),
                    self.parse(node.body(), node))))
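
# A minimal sketch (plain Python, no TVM; `normalized` is a hypothetical helper)
# of the normalization parse_for emits: `for c in range(lower, upper, step)`
# becomes a zero-based loop over extent = (upper - lower) // step with
# c = it * step + lower.
def normalized(lower, upper, step):
    extent = (upper - lower) // step  # mirrors tir.FloorDiv(tir.Sub(upper, lower), step)
    return [it * step + lower for it in range(extent)]

assert normalized(2, 11, 3) == list(range(2, 11, 3))  # [2, 5, 8]
# Caveat: floor division drops the final iteration when step does not divide the
# span evenly, e.g. normalized(2, 10, 3) == [2, 5] while range(2, 10, 3) gives
# [2, 5, 8]; whether that matters depends on the DSL's loop-bound semantics.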
def _build_copy_schedule(self, cuda_var_table: CUDAIterVarTable,
                         iter_var_table: IterVarTable, stmt: Statement):
    num_threads = cuda_var_table.axis_extents['threadIdx']
    idx = cuda_var_table.axis_idx['threadIdx']
    # Total element count to copy (`reduce` is functools.reduce).
    total = tir_imm(reduce(tir.Mul, self.usage_extents(True)))
    with iter_var_table.var() as iter_var, iter_var_table.var() as extent_var:
        # Thread `idx` handles elements idx, idx + num_threads, idx + 2 * num_threads, ...
        # (older-style tir.For signature: loop_var, min, extent, for_type, device_api, body).
        body = tir.For(
            iter_var, tir_imm(0), extent_var, tir.For.Serial, 0,
            stmt.to_tvm(None, iter_var * num_threads + idx))
        # Per-thread trip count: ceil((total - idx) / num_threads).
        body = tir.LetStmt(extent_var, (total - 1 - idx) // num_threads + 1, body)
        return body
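
# A minimal sketch (plain Python; `per_thread_indices` is a hypothetical helper)
# of the strided work split above: thread `idx` copies elements idx,
# idx + num_threads, idx + 2 * num_threads, ..., and
# (total - 1 - idx) // num_threads + 1 is exactly its trip count.
def per_thread_indices(idx, num_threads, total):
    extent = (total - 1 - idx) // num_threads + 1
    return [it * num_threads + idx for it in range(extent)]

total, num_threads = 10, 4
covered = sorted(i for idx in range(num_threads)
                 for i in per_thread_indices(idx, num_threads, total))
assert covered == list(range(total))  # every element is copied exactly once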
def expand_axis_var(num, axis_var):
    # Base case: every bound dimension has been peeled off.
    if num >= len(bounds):
        if 'threadIdx' in self.mark_stack[-1]:
            for tensor_usage in self.shared_tensors:
                tensor_usage.gen_offset_tvm_repr(self.expr_parser)
            return _under_shared(0)
        return innermost()
    # Peel off one dimension: recover this dimension's coordinate from the
    # flattened axis variable, then recurse on the remainder.
    c_var_name, lower, _, step = bounds[num]
    with self.iter_var_table.var(c_var_name) as c_var:
        val = axis_var // anchors[num] % extents[num]
        return tir.LetStmt(c_var, val * step + lower,
                           body=expand_axis_var(num + 1, axis_var))
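
# A minimal sketch (plain Python) of the decomposition expand_axis_var applies:
# a flat index is split into per-dimension coordinates via
# // anchors[d] % extents[d], where `anchors` are assumed to be the row-major
# strides over `extents`.
extents = [3, 4, 5]
anchors = [20, 5, 1]  # anchors[d] = product(extents[d + 1:])
flat = 2 * 20 + 3 * 5 + 4  # coordinates (2, 3, 4) flattened row-major
coords = [flat // anchors[d] % extents[d] for d in range(len(extents))]
assert coords == [2, 3, 4]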
import tvm
from tvm import ir, tir


def test_convert_ssa():
    dtype = "int32"
    zero = tir.const(0)
    nop = tir.Evaluate(zero)
    var_type = ir.PointerType(ir.PrimType(dtype))
    v = tir.Var("i1", var_type)
    buf = tir.decl_buffer([16], dtype=dtype, data=v)
    # The same LetStmt is deliberately reused twice below, so `v` is bound
    # more than once and the body is not in SSA form.
    let = tir.LetStmt(v, v, nop)
    load = tir.Evaluate(tir.BufferLoad(buf, [zero]))
    seq = tir.SeqStmt([let, let, load])
    func = tir.PrimFunc([], seq)
    mod = tvm.IRModule({"main": func})
    # Use pass InjectVirtualThread to invoke ConvertSSA
    mod = tir.transform.InjectVirtualThread()(mod)
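
# The duplicated `let` above binds the same Var `v` twice in one function body,
# so the IR is deliberately not in SSA form; ConvertSSA, invoked internally by
# InjectVirtualThread, must rename one of the bindings. To run the check
# standalone (assuming a local TVM build is available):
if __name__ == "__main__":
    test_convert_ssa()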