def innermost():
    # Closure over the enclosing visitor's state (`self`, `cur`, `cur_p`):
    # parses the innermost loop body and, when we are inside a threadIdx
    # binding, wraps it with host<->shared-memory copy statements.
    # NOTE(review): assumes `self.shared_tensors` items expose
    # `access_types` plus `build_copy_from_host`/`build_copy_to_host` —
    # declared outside this view; confirm against the enclosing class.
    body = self.parse(cur, cur_p)
    if 'threadIdx' in self.mark_stack[-1]:
        # Tensors that must be staged in from host memory before the body:
        # with side effects we need anything that is read; without side
        # effects, anything that is not write-only.
        tensors_from_host = []
        for i in self.shared_tensors:
            if self.has_side_effect and 'read' in i.access_types \
                    or not self.has_side_effect and 'write' not in i.access_types:
                tensors_from_host.append(i)
        # Tensors whose results must be copied back out after the body.
        tensors_to_host = []
        for i in self.shared_tensors:
            if 'write' in i.access_types:
                tensors_to_host.append(i)
        stmts = []
        for i in tensors_from_host:
            stmts.append(
                i.build_copy_from_host(self.cuda_iter_var_table, self.iter_var_table))
        if tensors_from_host:
            # Barrier so every thread sees the staged-in data before the body.
            stmts.append(tir.Evaluate(tir_cuda_shared_sync()))
        stmts.append(body)
        if tensors_to_host:
            # Barrier so the body's writes are visible before copying back.
            stmts.append(tir.Evaluate(tir_cuda_shared_sync()))
        for i in tensors_to_host:
            stmts.append(
                i.build_copy_to_host(self.cuda_iter_var_table, self.iter_var_table))
        # Only wrap in SeqStmt when copies/syncs were actually added;
        # stmts always contains `body`, so >= 2 means "something extra".
        if len(stmts) >= 2:
            body = tir.SeqStmt(stmts)
    return body
def test_control_flow_jump():
    """A `ret` emitted inside a taken branch must return before the later `ret`."""
    builder = tvm.tir.ir_builder.create()
    first = tir.Var("a", "float32")
    second = tir.Var("b", "float32")
    with builder.if_scope(True):
        builder.emit(tir.Evaluate(tir.ret(first)))
    builder.emit(tir.Evaluate(tir.ret(second)))
    prim = tir.PrimFunc([first, second], builder.get())
    compiled = build_tir_func(prim)
    # The branch is always taken, so the first argument comes back.
    assert compiled(1.0, 2.0) == 1.0
def test_convert_ssa_for_loop():
    """Reusing one For node twice must be de-duplicated by ConvertSSA.

    Renamed from ``test_convert_ssa``: a later function in this module has
    that exact name, so this definition was shadowed and pytest never
    collected or ran it.
    """
    zero = tir.const(0)
    nop = tir.Evaluate(zero)
    v = tir.Var("i1", "int32")
    # The same For (and thus the same loop var) appears twice in the body.
    for_stmt = tir.For(v, zero, zero, tir.ForKind.SERIAL, nop)
    load = tir.Evaluate(tir.Load("int32", v, zero))
    seq = tir.SeqStmt([for_stmt, for_stmt, load])
    func = tir.PrimFunc([], seq)
    mod = tvm.IRModule({"main": func})
    # Use pass InjectVirtualThread to invoke ConvertSSA
    mod = tir.transform.InjectVirtualThread()(mod)
def test_convert_ssa():
    """A LetStmt reused twice (duplicate var binding) must survive ConvertSSA."""
    dtype = "int32"
    const_zero = tir.const(0)
    noop = tir.Evaluate(const_zero)
    pointer_ty = ir.PointerType(ir.PrimType(dtype))
    var = tir.Var("i1", pointer_ty)
    buffer_obj = tir.decl_buffer([16], dtype=dtype, data=var)
    binding = tir.LetStmt(var, var, noop)
    read_back = tir.Evaluate(tir.BufferLoad(buffer_obj, [const_zero]))
    body = tir.SeqStmt([binding, binding, read_back])
    mod = tvm.IRModule({"main": tir.PrimFunc([], body)})
    # Use pass InjectVirtualThread to invoke ConvertSSA
    mod = tir.transform.InjectVirtualThread()(mod)
def test_ret_const():
    """A PrimFunc whose whole body is `ret(const 0)` evaluates to 0."""
    const_zero = tir.const(0)
    body = tir.Evaluate(tir.ret(const_zero))
    compiled = build_tir_func(tir.PrimFunc([], body))
    assert compiled() == 0
def test_scalar_add_float32():
    """Adding two float32 scalars through a built PrimFunc.

    Renamed from ``test_scalar_add``: a later, more general function in this
    module has that exact name, so this definition was shadowed and pytest
    never collected or ran it.
    """
    a = tir.Var("a", "float32")
    b = tir.Var("b", "float32")
    body = tir.Evaluate(tir.ret(a + b))
    func = build_tir_func(tir.PrimFunc([a, b], body))
    assert func(1.0, 2.0) == 3.0
def test_replace_block_in_opaque_block():
    """Replacing a For nested under an opaque block reuses the root in place."""
    sch = replace_ir_builder_with_opaque()
    original_root_hash = sch.mod["main"].__hash__()
    old_loop = sch.mod["main"].body.block.body.body.block.body[1].then_case.block.body
    loop_sref = sch.get_sref(old_loop)
    replacement = tir.For(
        loop_var=old_loop.loop_var,
        min_val=0,
        extent=128,
        kind=tir.ForKind.SERIAL,
        body=tir.Evaluate(0),
        thread_binding=None,
        annotations=None,
    )
    sch.replace(loop_sref, replacement)
    # Copy-on-write: the module root must be the same object (hash unchanged).
    assert original_root_hash == sch.mod["main"].__hash__()
    tvm.ir.assert_structural_equal(loop_sref.stmt, replacement)
def test_scalar_add():
    # Mixed-dtype addition: operands are promoted to a common type
    # (e.g. float16 + float32 upconverts the float16 to float32, and an
    # int added to a float is cast to that float type), so every pairing
    # below still computes 1.0 + 2.0 == 3.0.
    lhs_dtypes = ["float32", "float16", "int32", "int64"]
    rhs_dtypes = ["float32", "float16"]
    for lhs_dtype, rhs_dtype in itertools.product(lhs_dtypes, rhs_dtypes):
        # Inputs arrive as float32; the casts exercise the promotion rules.
        lhs_var = tir.Var("lhs", "float32")
        rhs_var = tir.Var("rhs", "float32")
        total = tir.Cast(lhs_dtype, lhs_var) + tir.Cast(rhs_dtype, rhs_var)
        body = tir.Evaluate(tir.ret(total))
        compiled = build_tir_func(tir.PrimFunc([lhs_var, rhs_var], body))
        assert compiled(1.0, 2.0) == 3.0