def innermost():
    """Parse the current node and, when applicable, surround it with tensor copies.

    ``self``, ``cur`` and ``cur_p`` are free variables taken from the
    enclosing scope.  When the top entry of ``self.mark_stack`` contains
    ``'threadIdx'``, the parsed body is wrapped as follows:

    1. one ``build_copy_from_host`` statement per selected shared tensor,
       followed by a shared-sync ``Evaluate`` if any were selected;
    2. the parsed body itself;
    3. a shared-sync ``Evaluate`` followed by one ``build_copy_to_host``
       statement per shared tensor whose access types include ``'write'``.

    Returns:
        The parsed body unchanged, or a ``tir.SeqStmt`` when at least two
        statements were produced.
    """
    parsed = self.parse(cur, cur_p)
    if 'threadIdx' not in self.mark_stack[-1]:
        return parsed

    def needs_copy_in(tensor):
        # With side effects: copy in every tensor that is read.
        # Without side effects: copy in every tensor that is never written.
        if self.has_side_effect:
            return 'read' in tensor.access_types
        return 'write' not in tensor.access_types

    inbound = [t for t in self.shared_tensors if needs_copy_in(t)]
    outbound = [t for t in self.shared_tensors if 'write' in t.access_types]

    sequence = [
        t.build_copy_from_host(self.cuda_iter_var_table, self.iter_var_table)
        for t in inbound
    ]
    if inbound:
        # Sync emitted between the copy-in statements and the body.
        sequence.append(tir.Evaluate(tir_cuda_shared_sync()))

    sequence.append(parsed)

    if outbound:
        # Sync emitted between the body and the copy-out statements.
        sequence.append(tir.Evaluate(tir_cuda_shared_sync()))
        sequence.extend(
            t.build_copy_to_host(self.cuda_iter_var_table, self.iter_var_table)
            for t in outbound
        )

    # A lone body needs no SeqStmt wrapper.
    return tir.SeqStmt(sequence) if len(sequence) >= 2 else parsed
def parse_block(self, node, parent):
    """Parse every child of ``node`` and return them as one ``tir.SeqStmt``.

    Args:
        node: Object whose ``children()`` result exposes ``size()`` and
            ``at(index)``.
        parent: Forwarded unchanged to ``self.parse`` for each child.

    Returns:
        A ``tir.SeqStmt`` built from the parsed children, in order, with
        children that parsed to ``None`` left out.
    """
    kids = node.children()
    # Lazily parse children in index order; size() is queried exactly once.
    parsed = (self.parse(kids.at(idx), parent) for idx in range(kids.size()))
    statements = [stmt for stmt in parsed if stmt is not None]
    return tir.SeqStmt(statements)
def to_stmt_tvm(self):
    """Return the most recent record as a single statement.

    The last entry of ``self.record`` is taken.  If it holds more than one
    item the items are wrapped in a ``tir.SeqStmt``; otherwise its first
    item is returned directly.

    Raises:
        AssertionError: If ``self.record`` is empty.
        IndexError: If the last entry of ``self.record`` is itself empty.
    """
    assert self.record
    latest = self.record[-1]
    if len(latest) > 1:
        return tir.SeqStmt(latest)
    # Zero or one item: index directly (an empty entry raises IndexError).
    return latest[0]
def test_convert_ssa():
    """Run InjectVirtualThread over a body that repeats the same ``For`` node.

    The function body is a ``SeqStmt`` containing one ``For`` statement
    twice, followed by a ``Load`` indexed through the loop variable.  The
    test makes no assertions; it only requires that the pass runs.
    """
    const_zero = tir.const(0)
    no_op = tir.Evaluate(const_zero)
    loop_var = tir.Var("i1", "int32")
    loop = tir.For(loop_var, const_zero, const_zero, tir.ForKind.SERIAL, no_op)
    read = tir.Evaluate(tir.Load("int32", loop_var, const_zero))

    # The same loop node appears twice in the sequence.
    body = tir.SeqStmt([loop, loop, read])
    prim_func = tir.PrimFunc([], body)
    module = tvm.IRModule({"main": prim_func})

    # Use pass InjectVirtualThread to invoke ConvertSSA
    inject_vthread = tir.transform.InjectVirtualThread()
    module = inject_vthread(module)
def test_convert_ssa():
    """Run InjectVirtualThread over a body that repeats the same ``LetStmt``.

    A pointer-typed variable backs a 16-element ``int32`` buffer.  The
    function body is a ``SeqStmt`` containing one ``LetStmt`` (binding the
    variable to itself) twice, followed by a ``BufferLoad`` from the buffer.
    The test makes no assertions; it only requires that the pass runs.
    """
    elem_dtype = "int32"
    const_zero = tir.const(0)
    no_op = tir.Evaluate(const_zero)

    pointer_type = ir.PointerType(ir.PrimType(elem_dtype))
    data_var = tir.Var("i1", pointer_type)
    buffer = tir.decl_buffer([16], dtype=elem_dtype, data=data_var)

    binding = tir.LetStmt(data_var, data_var, no_op)
    read = tir.Evaluate(tir.BufferLoad(buffer, [const_zero]))

    # The same let-binding node appears twice in the sequence.
    body = tir.SeqStmt([binding, binding, read])
    prim_func = tir.PrimFunc([], body)
    module = tvm.IRModule({"main": prim_func})

    # Use pass InjectVirtualThread to invoke ConvertSSA
    inject_vthread = tir.transform.InjectVirtualThread()
    module = inject_vthread(module)