Exemplo n.º 1
0
def test_scalar_add():
    """Build a PrimFunc over two float32 scalars whose body is ``ret(a + b)``
    and check that calling it with 1.0 and 2.0 gives 3.0."""
    lhs = tir.Var("a", "float32")
    rhs = tir.Var("b", "float32")
    # The sum is wrapped in tir.ret, then in tir.Evaluate, to form the body.
    body = tir.Evaluate(tir.ret(lhs + rhs))
    compiled = build_tir_func(tir.PrimFunc([lhs, rhs], body))
    result = compiled(1.0, 2.0)
    assert result == 3.0
Exemplo n.º 2
0
def test_control_flow_jump():
    """Emit ``ret(a)`` inside an ``if_scope(True)`` followed by ``ret(b)``,
    and check that the built function yields its first argument."""
    builder = tvm.tir.ir_builder.create()
    first = tir.Var("a", "float32")
    second = tir.Var("b", "float32")
    with builder.if_scope(True):
        builder.emit(tir.Evaluate(tir.ret(first)))
    # Emitted after the if scope; the assertion below expects the value of
    # the first argument, not this one.
    builder.emit(tir.Evaluate(tir.ret(second)))
    prim_func = tir.PrimFunc([first, second], builder.get())
    runnable = build_tir_func(prim_func)
    result = runnable(1.0, 2.0)
    assert result == 1.0
Exemplo n.º 3
0
 def push(self, name=None, var=None):
     """Record a variable on ``self.var_stack`` and in ``self.vars``.

     When ``name`` is None it defaults to ``_i<current stack length>``.
     When ``var`` is None a new int32 ``tir.Var`` with that name is created.
     Returns the variable that was recorded.
     """
     label = f'_i{len(self.var_stack)}' if name is None else name
     entry = tir.Var(name=label, dtype='int32') if var is None else var
     self.var_stack.append((entry, label))
     self.vars[label] = entry
     return entry
Exemplo n.º 4
0
def test_specialize_matmul():
    """Specialize ``matmul`` in three ways and compare each result
    structurally against its reference function."""
    first_param, _, _, last_param = matmul.params

    # Fully specialized: bind the first parameter to a 128x128 buffer.
    fully = matmul.specialize({first_param: tir.decl_buffer((128, 128))})
    tvm.ir.assert_structural_equal(fully, matmul_128)

    # Partially specialized: bind only the last parameter to the constant 128.
    partially = matmul.specialize({last_param: 128})
    tvm.ir.assert_structural_equal(partially, matmul_m_128)

    # Symbolically specialized: bind the last parameter to the expression x * 8.
    symbolic = matmul.specialize({last_param: tir.Var("x", "int32") * 8})
    tvm.ir.assert_structural_equal(symbolic, matmul_m_8x)
Exemplo n.º 5
0
def test_scalar_add():
    """Add two scalars across a grid of left/right dtypes and check 1.0 + 2.0.

    The listed types are expected to be interchangeable: e.g. float16 +
    float32 upconverts the float16 to float32, and when an int and a float
    are combined the int is cast to the float type.
    """
    for lhs_dtype in ("float32", "float16", "int32", "int64"):
        for rhs_dtype in ("float32", "float16"):
            # The inputs are always float32; the casts below create the
            # mixed-dtype operands whose upcasting is under test.
            x = tir.Var("lhs", "float32")
            y = tir.Var("rhs", "float32")
            total = tir.Cast(lhs_dtype, x) + tir.Cast(rhs_dtype, y)
            body = tir.Evaluate(tir.ret(total))
            compiled = build_tir_func(tir.PrimFunc([x, y], body))
            result = compiled(1.0, 2.0)
            assert result == 3.0
Exemplo n.º 6
0
def test_convert_ssa():
    """Run InjectVirtualThread on a body that uses the same ``For`` statement
    twice, followed by a load indexed through the loop variable.

    There is no assertion; the test only requires that the pass completes.
    """
    const_zero = tir.const(0)
    no_op = tir.Evaluate(const_zero)
    loop_var = tir.Var("i1", "int32")
    loop = tir.For(loop_var, const_zero, const_zero, tir.ForKind.SERIAL, no_op)
    read = tir.Evaluate(tir.Load("int32", loop_var, const_zero))
    # The same For node appears twice on purpose.
    body = tir.SeqStmt([loop, loop, read])
    module = tvm.IRModule({"main": tir.PrimFunc([], body)})
    # InjectVirtualThread is used as the way to invoke ConvertSSA.
    inject = tir.transform.InjectVirtualThread()
    inject(module)
def test_convert_ssa():
    """Run InjectVirtualThread on a body that uses the same ``LetStmt`` twice,
    followed by a buffer load whose data pointer is the let-bound variable.

    There is no assertion; the test only requires that the pass completes.
    """
    elem_dtype = "int32"
    const_zero = tir.const(0)
    no_op = tir.Evaluate(const_zero)
    pointer_type = ir.PointerType(ir.PrimType(elem_dtype))
    data_var = tir.Var("i1", pointer_type)
    buffer = tir.decl_buffer([16], dtype=elem_dtype, data=data_var)
    binding = tir.LetStmt(data_var, data_var, no_op)
    read = tir.Evaluate(tir.BufferLoad(buffer, [const_zero]))
    # The same LetStmt node appears twice on purpose.
    body = tir.SeqStmt([binding, binding, read])
    module = tvm.IRModule({"main": tir.PrimFunc([], body)})
    # InjectVirtualThread is used as the way to invoke ConvertSSA.
    inject = tir.transform.InjectVirtualThread()
    inject(module)
Exemplo n.º 8
0
def test_tir_op_call_likely():
    """``tir.likely`` should build a call whose op is named "tir.likely"."""
    cond_var = tir.Var("x", dtype="int32")
    call = tir.likely(cond=cond_var)
    op_name = call.op.name
    assert op_name == "tir.likely"
Exemplo n.º 9
0
def test_tir_op_call_assume():
    """``tir.assume`` should build a call whose op is named "tir.assume"."""
    cond_var = tir.Var("x", dtype="int32")
    call = tir.assume(cond=cond_var)
    op_name = call.op.name
    assert op_name == "tir.assume"
Exemplo n.º 10
0
def test_tir_op_tvm_struct_set():
    """``tir.tvm_struct_set`` should build a call whose op is named
    "tir.tvm_struct_set"."""
    handle_var = tir.Var("x", dtype="handle")
    call = tir.tvm_struct_set(handle_var, 1, 2, 3)
    op_name = call.op.name
    assert op_name == "tir.tvm_struct_set"
Exemplo n.º 11
0
def test_tir_op_tvm_tuple():
    """``tir.tvm_tuple`` over three float32 vars and three int literals should
    build a call whose op is named "tir.tvm_tuple"."""
    first = tir.Var("x", dtype="float32")
    second = tir.Var("y", dtype="float32")
    third = tir.Var("z", dtype="float32")
    call = tir.tvm_tuple(first, second, third, 1, 2, 3)
    op_name = call.op.name
    assert op_name == "tir.tvm_tuple"
Exemplo n.º 12
0
def test_exception():
    """Constructing ``tir.Var`` with ``name=1`` and ``dtype="int"`` is
    expected to raise ``tvm.TVMError``."""
    with pytest.raises(tvm.TVMError):
        # The result is never used; only the raised error matters.
        tir.Var(name=1, dtype="int")
Exemplo n.º 13
0
def assignment_helper(store_dtype, value_dtype):
    """Construct a ``tir.Let`` that binds a var of ``store_dtype`` to a var of
    ``value_dtype``, with the bound var itself as the body.

    The Let node is discarded and nothing is returned.
    NOTE(review): presumably callers use this to check whether the
    construction raises for a given dtype pair - confirm against call sites.
    """
    target = tir.Var("store", dtype=store_dtype)
    source = tir.Var("value", dtype=value_dtype)
    tir.Let(target, source, body=target)