def parse_for(self, node, parent):
    """Translate a for-loop node into a zero-based ``tir.For`` wrapped in lets.

    The loop variables come from ``self._for_loop_vars(node)``. The result is::

        let extent_var = floordiv(upper - lower, step) in
          for iter_var in [0, extent_var) with kind ``for_type``:
            let c_var = iter_var * step + lower in <parsed body>

    Args:
        node: Loop node; its ``body()`` is parsed with ``node`` as parent.
        parent: Not used by this method.

    Returns:
        The outer ``tir.LetStmt`` that binds the extent and contains the loop.
    """
    # NOTE(review): floor division equals the iteration count only when
    # (upper - lower) is a multiple of step -- confirm this is intended.
    with self._for_loop_vars(node) as loop_vars:
        iter_var, c_var, extent_var, lower, upper, step, for_type = loop_vars

        span = tir.Sub(upper, lower)
        trip_count = tir.FloorDiv(span, step)
        start = tir.IntImm('int32', 0)

        # Map the normalized counter back onto the original index space.
        scaled = tir.Mul(iter_var, step)
        c_value = tir.Add(scaled, lower)

        # The body is parsed while the loop-variable context is still active.
        parsed_body = self.parse(node.body(), node)
        bound_body = tir.LetStmt(c_var, c_value, parsed_body)

        loop = tir.For(iter_var, start, extent_var, for_type, bound_body)
        return tir.LetStmt(extent_var, trip_count, loop)
def _build_copy_schedule(self, cuda_var_table: CUDAIterVarTable, iter_var_table: IterVarTable, stmt: Statement):
    """Wrap ``stmt`` in a serial loop strided by the ``threadIdx`` extent.

    Iteration ``i`` lowers ``stmt`` at the flat position
    ``i * num_threads + idx``, where ``num_threads`` and ``idx`` are the
    ``'threadIdx'`` entries of ``cuda_var_table.axis_extents`` and
    ``cuda_var_table.axis_idx``. The loop extent is bound by a ``tir.LetStmt``
    to ``(total - 1 - idx) // num_threads + 1``, with ``total`` the product of
    ``self.usage_extents(True)``.

    Args:
        cuda_var_table: Source of the ``'threadIdx'`` extent and index.
        iter_var_table: Allocator for the loop variable and extent variable.
        stmt: Statement lowered through ``stmt.to_tvm(None, position)``.

    Returns:
        The ``tir.LetStmt`` binding the extent around the generated loop.
    """
    thread_axis = 'threadIdx'
    num_threads = cuda_var_table.axis_extents[thread_axis]
    idx = cuda_var_table.axis_idx[thread_axis]
    total = tir_imm(reduce(tir.Mul, self.usage_extents(True)))

    with iter_var_table.var() as iter_var:
        with iter_var_table.var() as extent_var:
            start = tir_imm(0)
            kind = tir.For.Serial
            position = iter_var * num_threads + idx
            lowered = stmt.to_tvm(None, position)
            # The positional 0 is presumably the legacy device_api
            # argument of tir.For -- TODO confirm against the TVM version.
            loop = tir.For(iter_var, start, extent_var, kind, 0, lowered)

            trip_count = (total - 1 - idx) // num_threads + 1
            body = tir.LetStmt(extent_var, trip_count, loop)
    return body
def test_convert_ssa():
    """Run InjectVirtualThread on a body that reuses one loop statement twice.

    The same ``tir.For`` object (and so the same loop variable ``i1``) appears
    twice in the sequence, followed by a ``tir.Load`` that uses that variable
    as its buffer variable. The test only checks that the pass runs.
    """
    const_zero = tir.const(0)
    no_op = tir.Evaluate(const_zero)
    loop_var = tir.Var("i1", "int32")

    shared_loop = tir.For(loop_var, const_zero, const_zero, tir.ForKind.SERIAL, no_op)
    load_stmt = tir.Evaluate(tir.Load("int32", loop_var, const_zero))

    # The identical loop object is deliberately listed twice.
    body = tir.SeqStmt([shared_loop, shared_loop, load_stmt])
    prim_func = tir.PrimFunc([], body)
    module = tvm.IRModule({"main": prim_func})

    # Use pass InjectVirtualThread to invoke ConvertSSA
    run_pass = tir.transform.InjectVirtualThread()
    run_pass(module)
def test_convert_ssa():
    """Run InjectVirtualThread on a body that reuses one loop statement twice.

    A pointer-typed variable ``i1`` is used as a loop variable and as the data
    handle of a 16-element ``int32`` buffer. The same ``tir.For`` object
    appears twice in the sequence, followed by a ``tir.BufferLoad`` from that
    buffer. The test only checks that the pass runs.
    """
    dtype = "int32"
    const_zero = tir.const(0)
    no_op = tir.Evaluate(const_zero)

    handle_type = ir.PointerType(ir.PrimType(dtype))
    handle = tir.Var("i1", handle_type)
    buffer = tir.decl_buffer([16], dtype=dtype, data=handle)

    shared_loop = tir.For(handle, const_zero, const_zero, tir.ForKind.SERIAL, no_op)
    load_stmt = tir.Evaluate(tir.BufferLoad(buffer, [const_zero]))

    # The identical loop object is deliberately listed twice.
    body = tir.SeqStmt([shared_loop, shared_loop, load_stmt])
    prim_func = tir.PrimFunc([], body)
    module = tvm.IRModule({"main": prim_func})

    # Use pass InjectVirtualThread to invoke ConvertSSA
    run_pass = tir.transform.InjectVirtualThread()
    run_pass(module)
def test_replace_block_in_opaque_block():
    """Replace a nested loop through its sref and check the root is unchanged.

    The schedule comes from ``replace_ir_builder_with_opaque()``. After the
    target loop is replaced with a fresh serial loop over ``[0, 128)``, the
    hash of ``s.mod["main"]`` must equal its value before the replacement,
    and the sref must point at the new loop.
    """
    sch = replace_ir_builder_with_opaque()
    hash_before = sch.mod["main"].__hash__()

    # Walk down to the loop inside the then-branch of the second statement.
    outer_block = sch.mod["main"].body.block
    inner_block = outer_block.body.body.block
    target_loop = inner_block.body[1].then_case.block.body
    loop_sref = sch.get_sref(target_loop)

    replacement = tir.For(
        loop_var=target_loop.loop_var,
        min_val=0,
        extent=128,
        kind=tir.ForKind.SERIAL,
        body=tir.Evaluate(0),
        thread_binding=None,
        annotations=None,
    )
    sch.replace(loop_sref, replacement)

    hash_after = sch.mod["main"].__hash__()
    assert hash_before == hash_after
    tvm.ir.assert_structural_equal(loop_sref.stmt, replacement)