Exemplo n.º 1
0
 def innermost():
     """Parse the current node and, under a ``threadIdx`` mark, wrap it with copies.

     When the top entry of ``self.mark_stack`` does not contain
     ``'threadIdx'`` the parsed body is returned untouched.  Otherwise the
     body is surrounded by the statements produced by
     ``build_copy_from_host`` / ``build_copy_to_host`` of the selected
     shared tensors, with a ``tir_cuda_shared_sync()`` evaluation between
     each non-empty copy group and the body.
     """
     parsed = self.parse(cur, cur_p)
     if 'threadIdx' not in self.mark_stack[-1]:
         return parsed

     def needs_copy_in(tensor):
         # With side effects only tensors that are read are copied in;
         # without side effects every tensor that is not written is.
         if self.has_side_effect:
             return 'read' in tensor.access_types
         return 'write' not in tensor.access_types

     copy_in = [t for t in self.shared_tensors if needs_copy_in(t)]
     copy_out = [t for t in self.shared_tensors if 'write' in t.access_types]

     sequence = [
         t.build_copy_from_host(self.cuda_iter_var_table, self.iter_var_table)
         for t in copy_in
     ]
     if copy_in:
         sequence.append(tir.Evaluate(tir_cuda_shared_sync()))
     sequence.append(parsed)
     if copy_out:
         sequence.append(tir.Evaluate(tir_cuda_shared_sync()))
     sequence.extend(
         t.build_copy_to_host(self.cuda_iter_var_table, self.iter_var_table)
         for t in copy_out)

     # A lone body needs no SeqStmt wrapper.
     return tir.SeqStmt(sequence) if len(sequence) >= 2 else parsed
Exemplo n.º 2
0
def test_control_flow_jump():
    """Check that a ``ret`` inside an always-true ``if`` decides the result.

    The function returns ``a`` inside ``if True`` and ``b`` afterwards;
    calling it with ``(1.0, 2.0)`` is expected to give ``1.0``.
    """
    first = tir.Var("a", "float32")
    second = tir.Var("b", "float32")

    builder = tvm.tir.ir_builder.create()
    with builder.if_scope(True):
        builder.emit(tir.Evaluate(tir.ret(first)))
    # Emitted after the if-scope; the assertion expects it not to win.
    builder.emit(tir.Evaluate(tir.ret(second)))

    compiled = build_tir_func(tir.PrimFunc([first, second], builder.get()))
    assert compiled(1.0, 2.0) == 1.0
Exemplo n.º 3
0
def test_convert_ssa():
    """Run InjectVirtualThread on a body that reuses one ``For`` node twice.

    The test only checks that the pass completes; it makes no assertion
    about the resulting module.
    """
    zero = tir.const(0)
    loop_var = tir.Var("i1", "int32")
    # One For node, placed twice in the sequence below on purpose.
    loop = tir.For(loop_var, zero, zero, tir.ForKind.SERIAL, tir.Evaluate(zero))
    use_after = tir.Evaluate(tir.Load("int32", loop_var, zero))
    body = tir.SeqStmt([loop, loop, use_after])

    module = tvm.IRModule({"main": tir.PrimFunc([], body)})
    # Use pass InjectVirtualThread to invoke ConvertSSA
    module = tir.transform.InjectVirtualThread()(module)
def test_convert_ssa():
    """Run InjectVirtualThread on a body that reuses one ``LetStmt`` node twice.

    The let-bound variable is a pointer used as the data of a buffer that
    is loaded from after the two lets.  The test only checks that the pass
    completes; it makes no assertion about the resulting module.
    """
    elem_dtype = "int32"
    zero = tir.const(0)
    data_ptr = tir.Var("i1", ir.PointerType(ir.PrimType(elem_dtype)))
    buffer = tir.decl_buffer([16], dtype=elem_dtype, data=data_ptr)
    # One LetStmt node, placed twice in the sequence below on purpose.
    binding = tir.LetStmt(data_ptr, data_ptr, tir.Evaluate(zero))
    use_after = tir.Evaluate(tir.BufferLoad(buffer, [zero]))
    body = tir.SeqStmt([binding, binding, use_after])

    module = tvm.IRModule({"main": tir.PrimFunc([], body)})
    # Use pass InjectVirtualThread to invoke ConvertSSA
    module = tir.transform.InjectVirtualThread()(module)
Exemplo n.º 5
0
def test_ret_const():
    """Check that a function whose body returns the constant 0 yields 0."""
    ret_stmt = tir.Evaluate(tir.ret(tir.const(0)))
    compiled = build_tir_func(tir.PrimFunc([], ret_stmt))
    assert compiled() == 0
Exemplo n.º 6
0
def test_scalar_add():
    """Check that returning ``a + b`` for float32 inputs 1.0 and 2.0 gives 3.0."""
    lhs = tir.Var("a", "float32")
    rhs = tir.Var("b", "float32")
    ret_stmt = tir.Evaluate(tir.ret(lhs + rhs))
    compiled = build_tir_func(tir.PrimFunc([lhs, rhs], ret_stmt))
    assert compiled(1.0, 2.0) == 3.0
Exemplo n.º 7
0
def test_replace_block_in_opaque_block():
    """Replace a loop nested in opaque blocks and check the schedule afterwards.

    After the replacement the hash of ``main`` must equal the hash taken
    beforehand, and the sref obtained for the old loop must point at a
    statement structurally equal to the new loop.
    """
    sch = replace_ir_builder_with_opaque()
    hash_before = sch.mod["main"].__hash__()

    # Walk down to the loop inside the `then` branch of the inner block.
    inner_block = sch.mod["main"].body.block.body.body.block
    old_loop = inner_block.body[1].then_case.block.body
    loop_sref = sch.get_sref(old_loop)

    replacement = tir.For(
        loop_var=old_loop.loop_var,
        min_val=0,
        extent=128,
        kind=tir.ForKind.SERIAL,
        body=tir.Evaluate(0),
        thread_binding=None,
        annotations=None,
    )
    sch.replace(loop_sref, replacement)

    assert hash_before == sch.mod["main"].__hash__()
    tvm.ir.assert_structural_equal(loop_sref.stmt, replacement)
Exemplo n.º 8
0
def test_scalar_add():
    """Check ``lhs + rhs`` across mixed operand dtypes.

    Every combination of the listed dtypes should be addable:
    e.g. float16 + float32 upconverts the float16 to float32, and when an
    int and a float are added together the int is cast to the float type.
    """
    lhs_types = ["float32", "float16", "int32", "int64"]
    rhs_types = ["float32", "float16"]
    for lhs_type in lhs_types:
        for rhs_type in rhs_types:
            # Inputs are always float32; the casts below create the dtype
            # mix so that upcasting between the operands is exercised.
            lhs_input = tir.Var("lhs", "float32")
            rhs_input = tir.Var("rhs", "float32")
            total = tir.Cast(lhs_type, lhs_input) + tir.Cast(rhs_type, rhs_input)
            ret_stmt = tir.Evaluate(tir.ret(total))
            compiled = build_tir_func(
                tir.PrimFunc([lhs_input, rhs_input], ret_stmt))
            assert compiled(1.0, 2.0) == 3.0