コード例 #1
0
ファイル: isl_to_tir.py プロジェクト: dawnpower/pyvlova
 def innermost():
     """Parse ``cur`` and, under a ``threadIdx`` mark, surround it with shared-tensor copies.

     The statement produced by ``self.parse(cur, cur_p)`` is returned as-is
     unless ``'threadIdx'`` is in the top entry of ``self.mark_stack``.
     In that case:

     * Copy-in statements (``build_copy_from_host``) are emitted first. With
       ``self.has_side_effect`` set, they cover every shared tensor whose
       ``access_types`` include ``'read'``. Otherwise they cover every shared
       tensor whose ``access_types`` do not include ``'write'``.
     * A ``tir_cuda_shared_sync()`` call (presumably a block-level barrier;
       confirm against its definition) follows the copy-ins, if there are any.
     * The parsed body comes next.
     * A second sync precedes the copy-out statements (``build_copy_to_host``)
       for every shared tensor whose ``access_types`` include ``'write'``.

     Returns:
         A ``tir.SeqStmt`` when anything was added around the body, otherwise
         the parsed body itself.
     """
     parsed = self.parse(cur, cur_p)
     if 'threadIdx' not in self.mark_stack[-1]:
         return parsed

     # Parenthesised for readability; same precedence as ``a and b or c and d``.
     from_host = [
         t for t in self.shared_tensors
         if (self.has_side_effect and 'read' in t.access_types)
         or (not self.has_side_effect and 'write' not in t.access_types)
     ]
     to_host = [t for t in self.shared_tensors if 'write' in t.access_types]

     seq = [
         t.build_copy_from_host(self.cuda_iter_var_table, self.iter_var_table)
         for t in from_host
     ]
     if from_host:
         # Copies must complete before the body reads the shared buffers.
         seq.append(tir.Evaluate(tir_cuda_shared_sync()))
     seq.append(parsed)
     if to_host:
         # The body's writes must complete before they are copied back out.
         seq.append(tir.Evaluate(tir_cuda_shared_sync()))
     seq.extend(
         t.build_copy_to_host(self.cuda_iter_var_table, self.iter_var_table)
         for t in to_host
     )
     return tir.SeqStmt(seq) if len(seq) > 1 else parsed
コード例 #2
0
ファイル: isl_to_tir.py プロジェクト: dawnpower/pyvlova
 def parse_block(self, node, parent):
     """Parse every child of ``node`` and combine the results into one statement.

     Args:
         node: Node exposing ``children()``, whose result supports ``size()``
             and ``at(i)`` (presumably an isl schedule-tree node list; confirm).
         parent: Passed unchanged to ``self.parse`` for each child.

     Returns:
         The single statement when exactly one child parses to a non-``None``
         value. A ``tir.SeqStmt`` when several do. A no-op
         ``tir.Evaluate(0)`` when none do. This matches the
         single-statement unwrapping used elsewhere in this module and avoids
         the empty/length-1 ``SeqStmt`` that newer TVM rejects.
     """
     children = node.children()
     parsed = (self.parse(children.at(i), parent) for i in range(children.size()))
     # ``self.parse`` may yield None for children that produce no statement.
     body = [stmt for stmt in parsed if stmt is not None]
     if not body:
         return tir.Evaluate(tir.const(0))
     if len(body) == 1:
         return body[0]
     return tir.SeqStmt(body)
コード例 #3
0
 def to_stmt_tvm(self):
     """Convert the innermost recorded statement list into a single TIR statement.

     Requires ``self.record`` to be non-empty; its last entry is converted.

     Returns:
         The sole statement when that entry holds exactly one. A
         ``tir.SeqStmt`` when it holds several. A no-op
         ``tir.Evaluate(0)`` when it is empty.
     """
     assert self.record
     record = self.record[-1]
     if not record:
         # The old ``len(record) <= 1`` test sent this case to ``record[0]``,
         # which raised IndexError; emit TVM's idiomatic no-op instead.
         return tir.Evaluate(tir.const(0))
     if len(record) == 1:
         # A length-1 SeqStmt is redundant (and rejected by newer TVM).
         return record[0]
     return tir.SeqStmt(record)
コード例 #4
0
def test_convert_ssa():
    """Run ConvertSSA on a loop variable defined by two identical loops.

    InjectVirtualThread is used to invoke ConvertSSA. The test passes as
    long as the pass runs without raising.
    """
    zero = tir.const(0)
    loop_var = tir.Var("i1", "int32")
    empty_loop = tir.For(loop_var, zero, zero, tir.ForKind.SERIAL, tir.Evaluate(zero))
    # The same For node appears twice, so ``loop_var`` is defined twice.
    read_after = tir.Evaluate(tir.Load("int32", loop_var, zero))
    func = tir.PrimFunc([], tir.SeqStmt([empty_loop, empty_loop, read_after]))
    module = tvm.IRModule({"main": func})
    # InjectVirtualThread invokes ConvertSSA, which is what this exercises.
    inject_vthread = tir.transform.InjectVirtualThread()
    inject_vthread(module)
コード例 #5
0
def test_convert_ssa():
    """Run ConvertSSA on a pointer variable re-bound by two identical LetStmts.

    A buffer backed by the pointer is read afterwards. InjectVirtualThread
    is used to invoke ConvertSSA. The test passes as long as the pass runs
    without raising.
    """
    elem_dtype = "int32"
    zero = tir.const(0)
    ptr = tir.Var("i1", ir.PointerType(ir.PrimType(elem_dtype)))
    buf = tir.decl_buffer([16], dtype=elem_dtype, data=ptr)
    # Binds ``ptr`` to itself; using the node twice defines ``ptr`` twice.
    rebind = tir.LetStmt(ptr, ptr, tir.Evaluate(zero))
    read_back = tir.Evaluate(tir.BufferLoad(buf, [zero]))
    body = tir.SeqStmt([rebind, rebind, read_back])
    module = tvm.IRModule({"main": tir.PrimFunc([], body)})
    # InjectVirtualThread invokes ConvertSSA, which is what this exercises.
    tir.transform.InjectVirtualThread()(module)