コード例 #1
0
ファイル: test_adt.py プロジェクト: xiaoliyang1/tvm
def test_length():
    """Check length's inferred type and that a three-element cons list has length 3.

    The expected type takes ``l(a)`` to an int32 scalar and is parameterised
    over the type variable ``a``.
    """
    elem_ty = relay.TypeVar("a")
    expected_ty = relay.FuncType(
        [l(elem_ty)], relay.scalar_type('int32'), [elem_ty])
    assert mod[length].checked_type == expected_ty

    three_items = cons(z(), cons(z(), cons(z(), nil())))
    result = intrp.evaluate(length(three_items))
    assert get_scalar(result) == 3
コード例 #2
0
ファイル: test_adt.py プロジェクト: xiaoliyang1/tvm
def test_double():
    """Check double's type is ``nat -> nat`` and that double(s(z())) counts to 2."""
    nat_to_nat = relay.FuncType([nat()], nat())
    assert mod[double].checked_type == nat_to_nat

    doubled = intrp.evaluate(double(s(z())))
    assert count(doubled) == 2
コード例 #3
0
ファイル: test_adt.py プロジェクト: xiaoliyang1/tvm
def test_add():
    """Check add's type is ``(nat, nat) -> nat`` and that add(s(z()), s(z())) counts to 2."""
    nat_binop = relay.FuncType([nat(), nat()], nat())
    assert mod[add].checked_type == nat_binop

    total = intrp.evaluate(add(s(z()), s(z())))
    assert count(total) == 2
コード例 #4
0
def test_func_type_alpha_equal():
    """FuncType ``==`` compares up to consistent renaming of type parameters.

    ``ft`` must equal a copy in which every type parameter is consistently
    renamed (tp1 -> tp2, tp3 -> tp4), including its uses in the return type
    and in the type relations. It must differ from variants that change the
    argument list, the argument order, the relations, or the type-parameter
    list.
    """
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")

    tp1 = relay.TypeVar("v1", relay.TypeKind.Type)
    tp2 = relay.TypeVar("v2", relay.TypeKind.Type)
    tp3 = relay.TypeVar("v3", relay.TypeKind.ShapeVar)
    # Give tp4 its own hint so printed types are unambiguous. Equality itself
    # does not depend on hints: tp1 ("v1") and tp2 ("v2") are expected to
    # match under renaming below.
    tp4 = relay.TypeVar("v4", relay.TypeKind.ShapeVar)

    broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")

    def type_rel(func, args):
        """Build a TypeRelation over ``args`` with the trailing ``1, None`` used throughout."""
        return relay.TypeRelation(func, tvm.runtime.convert(args), 1, None)

    def func_type(arg_types, ret_type, type_params, relations):
        """Build a FuncType, converting the Python lists to TVM containers."""
        return relay.FuncType(tvm.runtime.convert(arg_types), ret_type,
                              tvm.runtime.convert(type_params),
                              tvm.runtime.convert(relations))

    tr1 = type_rel(broadcast, [tp1, tp3])
    tr2 = type_rel(broadcast, [tp2, tp4])
    tr3 = type_rel(identity, [tp1, tp3])

    ft = func_type([t1, t2], tp1, [tp1, tp3], [tr1])

    # Consistent renaming tp1 -> tp2, tp3 -> tp4. The return type must be
    # renamed too: once tp1 is bound to tp2, keeping tp1 as the right-hand
    # return type would not be a consistent renaming.
    translate_vars = func_type([t1, t2], tp2, [tp2, tp4], [tr2])
    assert ft == translate_vars

    # Fewer arguments.
    different_args = func_type([t1], tp1, [tp1, tp3], [tr1])
    assert ft != different_args

    # Same arguments, swapped order.
    different_order = func_type([t2, t1], tp1, [tp1, tp3], [tr1])
    assert ft != different_order

    # Relation dropped.
    no_rel = func_type([t1, t2], tp1, [tp1, tp3], [])
    assert ft != no_rel

    # Extra type parameter (and a different return type).
    more_vars = func_type([t1, t2], tp2, [tp1, tp2, tp3], [tr1])
    assert ft != more_vars

    # Every type variable bound, two relations.
    all_the_vars = func_type([t1, t2], tp1, [tp1, tp2, tp3, tp4], [tr1, tr2])
    assert ft != all_the_vars

    # Same relation arguments, different relation function.
    different_rel = func_type([t1, t2], tp1, [tp1, tp3], [tr3])
    assert ft != different_rel

    # Extra relation.
    more_rels = func_type([t1, t2], tp1, [tp1, tp3], [tr1, tr3])
    assert ft != more_rels
コード例 #5
0
def test_fused_reshape():
    """RemoveStandaloneReshapes drops the call to a standalone reshape PrimFunc.

    Builds a module whose ``@main`` calls ``@multiply`` and then
    ``@fused_reshape`` (an element-wise copy) through ``call_lowered``. It
    then runs the pass and checks two things: ``@main`` no longer calls
    ``@fused_reshape``, and it still calls ``@multiply``.
    """
    mod = tvm.ir.IRModule()

    @T.prim_func
    def mul_primfunc(a: T.handle, b: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, [128, 128])
        B = T.match_buffer(b, [128, 128])
        D = T.match_buffer(d, [128, 128])

        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                D[vi, vj] = A[vi, vk] * B[vj, vk]

    @T.prim_func
    def fused_reshape_primfunc(a: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, [128, 128])
        D = T.match_buffer(d, [128, 128])

        for i, j in T.grid(128, 128):
            D[i, j] = A[i, j]

    metatable = {"VirtualDevice": [CPU]}
    mul_ty = relay.FuncType(
        [
            relay.TensorType((128, 128), "float32"),
            relay.TensorType((128, 128), "float32"),
            relay.TensorType((128, 128), "float32"),
        ],
        relay.TensorType((128, 128), "float32"),
    )

    mul_gv = relay.GlobalVar("multiply", type_annot=mul_ty)
    mod[mul_gv] = mul_primfunc
    reshape_ty = relay.FuncType(
        [
            relay.TensorType((128, 128), "float32"),
        ],
        relay.TensorType((128, 128), "float32"),
    )

    reshape_gv = relay.GlobalVar("fused_reshape", type_annot=reshape_ty)
    mod[reshape_gv] = fused_reshape_primfunc
    mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32],
                  %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32],
                  %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32],
                  virtual_device=meta[VirtualDevice][0]) {
          %0 = call_lowered(@multiply, (%x, %y, %z));
          let %x_12: Tensor[(128, 128), float32] = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
          %1 = call_lowered(@fused_reshape, (%x_12,) );
          let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True);
          %x_14
        }
        """,
        "from_string",
        mod,
        metatable,
    )

    # Expected @main after the pass (exact printing may differ): the
    # @fused_reshape call is gone and both lets bind the @multiply result.
    ##[version = "0.0.5"]
    # def @main(%x, %y, %z) -> Tensor[(128, 128), float32] {
    #  %0 = (%x, %y, %z);
    #  %1 = call_lowered(@multiply, %0);
    #  let %x_12: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True);
    #  let %x_14: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True);
    #  %x_14
    # }

    mod = RemoveStandaloneReshapes()(mod)

    # The reshape PrimFunc definition is expected to stay registered in the
    # module; only its call site in @main should disappear. (This check alone
    # cannot detect a failed removal, hence the call-site check below.)
    reshapes_present = any(
        "reshape" in gv.name_hint for gv in mod.get_global_vars())
    assert reshapes_present, "The fused_reshape PrimFunc should still be in the module."

    # Collect every GlobalVar referenced from @main. call_lowered takes the
    # callee as its first argument, so callees appear as GlobalVar nodes.
    called_names = set()

    def _collect_global_vars(expr):
        if isinstance(expr, relay.GlobalVar):
            called_names.add(expr.name_hint)

    relay.analysis.post_order_visit(mod["main"], _collect_global_vars)
    assert "multiply" in called_names, "@multiply should still be called from @main."
    assert "fused_reshape" not in called_names, "Reshape should have been removed."
コード例 #6
0
def test_sum():
    """Check sum's type is ``l(int32) -> int32`` and that summing cons(1, cons(2, nil)) gives 3."""
    int32 = relay.scalar_type('int32')
    assert mod[sum].checked_type == relay.FuncType([l(int32)], int32)

    numbers = cons(relay.const(1), cons(relay.const(2), nil()))
    total = intrp.evaluate(sum(numbers))
    assert get_scalar(total) == 3
コード例 #7
0
def test_single_op():
    "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
    # A one-argument function applying op.log must type-check as a
    # scalar float32 -> scalar float32 function.
    scalar_f32 = relay.TensorType([], dtype="float32")
    arg = relay.var("x", shape=[])
    log_fn = relay.Function([arg], op.log(arg))
    assert_has_type(log_fn, relay.FuncType([scalar_f32], scalar_f32))
コード例 #8
0
ファイル: test_adt.py プロジェクト: yTakatsukasa/tvm
def test_length():
    """Check length's type is ``l(a) -> nat`` (generic in ``a``) and that length of three items is 3."""
    a = relay.TypeVar("a")
    expected_ty = relay.FuncType([l(a)], nat(), [a])
    assert mod[length].checked_type == expected_ty

    items = cons(z(), cons(z(), cons(z(), nil())))
    assert count(intrp.evaluate(length(items))) == 3
コード例 #9
0
ファイル: test_adt.py プロジェクト: yTakatsukasa/tvm
def test_sum():
    """Check sum's type is ``l(nat) -> nat`` and that summing build_nat(1), build_nat(2) counts to 3."""
    assert mod[sum].checked_type == relay.FuncType([l(nat())], nat())

    one, two = build_nat(1), build_nat(2)
    total = intrp.evaluate(sum(cons(one, cons(two, nil()))))
    assert count(total) == 3
コード例 #10
0
def test_sum():
    """Check the prelude sum has type ``rlist(int32) -> int32`` and that sum of 1, 2 is 3."""
    i32 = relay.scalar_type("int32")
    expected = relay.FuncType([rlist(i32)], i32)
    assert prelude.mod[sum].checked_type == expected

    values = cons(relay.const(1), cons(relay.const(2), nil()))
    assert get_scalar(intrp.evaluate(sum(values))) == 3
コード例 #11
0
def test_func_with_invalid_arg_types():
    # NOTE(review): despite its name, this test asserts nothing. It only checks
    # that constructing the FuncType (a Shape-kind param used as an argument
    # type) does not raise. It uses the legacy TypeParam/Kind API.
    shape_param = relay.TypeParam('tp1', relay.Kind.Shape)
    type_param = relay.TypeParam('tp2', relay.Kind.Type)
    relay.FuncType(
        tvm.convert([shape_param]),
        type_param,
        tvm.convert([shape_param, type_param]),
        tvm.convert([]),
    )
コード例 #12
0
def test_double():
    # Build fn (f: t -> t, x: t) -> t { f(f(x)) }, generic over t, type-check it,
    # then type-check its CPS conversion. No assertion is made beyond both
    # steps completing without raising.
    ty = relay.TypeVar("t")
    arg = relay.var("x", ty)
    fn = relay.var("f", relay.FuncType([ty], ty))
    double_fn = run_infer_type(relay.Function([fn, arg], fn(fn(arg)), ty, [ty]))
    run_infer_type(to_cps(double_fn))
コード例 #13
0
def test_single_op():
    "Program: fn (x : float32) { let t1 = f(x); t1 }"
    # op.log applied to a scalar var must yield a scalar float32 -> float32 function.
    scalar_ty = relay.TensorType([], dtype='float32')
    param = relay.var('x', shape=[])
    log_func = relay.Function([param], op.log(param))
    assert_has_type(log_func, relay.FuncType([scalar_ty], scalar_ty))
コード例 #14
0
def test_length():
    """Check the prelude length has type ``rlist(a) -> int32`` (generic in ``a``) and that length of three items is 3."""
    tv = relay.TypeVar("a")
    want = relay.FuncType([rlist(tv)], relay.scalar_type("int32"), [tv])
    assert prelude.mod[length].checked_type == want

    three = cons(z(), cons(z(), cons(z(), nil())))
    # NOTE(review): ``eval`` is presumably a module-level helper that shadows
    # the builtin; confirm against the top of the file.
    assert get_scalar(eval(length(three))) == 3
コード例 #15
0
def test_add():
    """Check the prelude add has type ``(nat, nat) -> nat`` and that add(s(z()), s(z())) counts to 2."""
    expected = relay.FuncType([nat(), nat()], nat())
    assert prelude.mod[add].checked_type == expected

    result = eval(add(s(z()), s(z())))
    assert count(result) == 2
コード例 #16
0
def test_double():
    """Check the prelude double has type ``nat -> nat`` and that double(s(z())) counts to 2."""
    expected = relay.FuncType([nat()], nat())
    assert prelude.mod[double].checked_type == expected

    result = eval(double(s(z())))
    assert count(result) == 2