def test_length():
    a = relay.TypeVar("a")
    assert mod[length].checked_type == relay.FuncType(
        [l(a)], relay.scalar_type('int32'), [a])
    res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil())))))
    assert get_scalar(res) == 3

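# The list/nat tests in this listing rely on a shared fixture that is not
# shown here: `mod`, `intrp`, the prelude constructors (`l`, `cons`, `nil`,
# `nat`, `z`, `s`), the functions under test (`length`, `sum`, `add`,
# `double`, `build_nat`), and the `count`/`get_scalar` helpers. A minimal
# sketch of one plausible setup, modeled on TVM's own ADT tests; the exact
# wiring is an assumption, not the original fixture:
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude

mod = tvm.IRModule()
prelude = Prelude(mod)
intrp = relay.create_executor(mod=mod)


def get_scalar(res):
    # Pull the Python scalar out of a 0-d tensor result.
    return res.numpy().item()


def count(res):
    # Count the depth of a Peano nat value s(s(...z())) as a Python int.
    # Assumes the nat definitions have been loaded into the prelude
    # (e.g. via mod.import_from_std("nat.rly") in recent TVM versions).
    n = 0
    while res.constructor.name_hint == "s":
        n += 1
        res = res.fields[0]
    return n
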
def test_double():
    assert mod[double].checked_type == relay.FuncType([nat()], nat())
    res = intrp.evaluate(double(s(z())))
    assert count(res) == 2

def test_add():
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    res = intrp.evaluate(add(s(z()), s(z())))
    assert count(res) == 2

def test_func_type_alpha_equal():
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")

    tp1 = relay.TypeVar("v1", relay.TypeKind.Type)
    tp2 = relay.TypeVar("v2", relay.TypeKind.Type)
    tp3 = relay.TypeVar("v3", relay.TypeKind.ShapeVar)
    tp4 = relay.TypeVar("v3", relay.TypeKind.ShapeVar)

    broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")

    tr1 = relay.TypeRelation(broadcast, tvm.runtime.convert([tp1, tp3]), 1, None)
    tr2 = relay.TypeRelation(broadcast, tvm.runtime.convert([tp2, tp4]), 1, None)
    tr3 = relay.TypeRelation(identity, tvm.runtime.convert([tp1, tp3]), 1, None)

    ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                        tvm.runtime.convert([tp1, tp3]),
                        tvm.runtime.convert([tr1]))

    # Renaming the bound type vars (tp1 -> tp2, tp3 -> tp4) preserves
    # alpha equality.
    translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                                    tvm.runtime.convert([tp2, tp4]),
                                    tvm.runtime.convert([tr2]))
    assert ft == translate_vars

    # Dropping an argument type breaks equality.
    different_args = relay.FuncType(tvm.runtime.convert([t1]), tp1,
                                    tvm.runtime.convert([tp1, tp3]),
                                    tvm.runtime.convert([tr1]))
    assert ft != different_args

    # Argument order matters.
    different_order = relay.FuncType(tvm.runtime.convert([t2, t1]), tp1,
                                     tvm.runtime.convert([tp1, tp3]),
                                     tvm.runtime.convert([tr1]))
    assert ft != different_order

    # Omitting the type relation breaks equality.
    no_rel = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                            tvm.runtime.convert([tp1, tp3]),
                            tvm.runtime.convert([]))
    assert ft != no_rel

    # Extra type params break equality.
    more_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2,
                               tvm.runtime.convert([tp1, tp2, tp3]),
                               tvm.runtime.convert([tr1]))
    assert ft != more_vars

    all_the_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                                  tvm.runtime.convert([tp1, tp2, tp3, tp4]),
                                  tvm.runtime.convert([tr1, tr2]))
    assert ft != all_the_vars

    # A different relation function (Identity vs. Broadcast) breaks equality,
    # as does an extra relation.
    different_rel = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                                   tvm.runtime.convert([tp1, tp3]),
                                   tvm.runtime.convert([tr3]))
    assert ft != different_rel

    more_rels = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                               tvm.runtime.convert([tp1, tp3]),
                               tvm.runtime.convert([tr1, tr3]))
    assert ft != more_rels

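# In more recent TVM the same kind of check is usually written with explicit
# structural equality rather than `==`. A self-contained sketch (the names
# here are illustrative, not from the original test):
def _structural_equal_sketch():
    t = relay.TensorType((1, 2), "float32")
    v1 = relay.TypeVar("v1", relay.TypeKind.Type)
    v2 = relay.TypeVar("v2", relay.TypeKind.Type)
    # Same function type up to renaming of the bound var: structurally equal.
    f1 = relay.FuncType([t], v1, [v1], [])
    f2 = relay.FuncType([t], v2, [v2], [])
    assert tvm.ir.structural_equal(f1, f2)
    # Different arity: not equal.
    f3 = relay.FuncType([t, t], v1, [v1], [])
    assert not tvm.ir.structural_equal(f1, f3)
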
def test_fused_reshape():
    mod = tvm.ir.IRModule()

    @T.prim_func
    def mul_primfunc(a: T.handle, b: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, [128, 128])
        B = T.match_buffer(b, [128, 128])
        D = T.match_buffer(d, [128, 128])
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                D[vi, vj] = A[vi, vk] * B[vj, vk]

    @T.prim_func
    def fused_reshape_primfunc(a: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, [128, 128])
        D = T.match_buffer(d, [128, 128])
        for i, j in T.grid(128, 128):
            D[i, j] = A[i, j]

    metatable = {"VirtualDevice": [CPU]}

    mul_ty = relay.FuncType(
        [
            relay.TensorType((128, 128), "float32"),
            relay.TensorType((128, 128), "float32"),
            relay.TensorType((128, 128), "float32"),
        ],
        relay.TensorType((128, 128), "float32"),
    )
    mul_gv = relay.GlobalVar("multiply", type_annot=mul_ty)
    mod[mul_gv] = mul_primfunc

    reshape_ty = relay.FuncType(
        [
            relay.TensorType((128, 128), "float32"),
        ],
        relay.TensorType((128, 128), "float32"),
    )
    reshape_gv = relay.GlobalVar("fused_reshape", type_annot=reshape_ty)
    mod[reshape_gv] = fused_reshape_primfunc

    mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32],
                  %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32],
                  %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32],
                  virtual_device=meta[VirtualDevice][0]) {
          %0 = call_lowered(@multiply, (%x, %y, %z));
          let %x_12: Tensor[(128, 128), float32] = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
          %1 = call_lowered(@fused_reshape, (%x_12,));
          let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True);
          %x_14
        }
        """,
        "from_string",
        mod,
        metatable,
    )

    # Expected main after RemoveStandaloneReshapes:
    # #[version = "0.0.5"]
    # def @main(%x /* ty=Tensor[(128, 128), float32] */) -> Tensor[(128, 128), float32] {
    #   %0 = (%x, %y, %z);
    #   %1 = call_lowered(@multiply, %0);
    #   let %x_12: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True);
    #   let %x_14: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True);
    #   %x_14
    # }
    mod = RemoveStandaloneReshapes()(mod)
    reshapes_present = any("reshape" in gv.name_hint for gv in mod.get_global_vars())
    assert not reshapes_present, "Reshape should have been removed."

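# `CPU`, `T` (tvm.script.tir), and the `RemoveStandaloneReshapes` pass are
# defined/imported elsewhere in the original test file. One plausible sketch
# of the `CPU` constant, assuming an LLVM CPU target:
from tvm.script import tir as T

CPU = tvm.target.VirtualDevice(tvm.device("cpu"), tvm.target.Target("llvm"))
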
def test_sum():
    assert mod[sum].checked_type == relay.FuncType(
        [l(relay.scalar_type('int32'))], relay.scalar_type('int32'))
    res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
    assert get_scalar(res) == 3

def test_single_op(): "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }" x = relay.var("x", shape=[]) func = relay.Function([x], op.log(x)) ttype = relay.TensorType([], dtype="float32") assert_has_type(func, relay.FuncType([ttype], ttype))
def test_length():
    a = relay.TypeVar("a")
    assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])
    res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil())))))
    assert count(res) == 3

def test_sum():
    assert mod[sum].checked_type == relay.FuncType([l(nat())], nat())
    res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil()))))
    assert count(res) == 3

def test_sum():
    assert prelude.mod[sum].checked_type == relay.FuncType(
        [rlist(relay.scalar_type("int32"))], relay.scalar_type("int32"))
    res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
    assert get_scalar(res) == 3

def test_func_with_invalid_arg_types():
    # tp1 has kind Shape, so using it as an argument type is ill-kinded;
    # constructing this FuncType is expected to be rejected by the kind check.
    tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
    tp2 = relay.TypeParam('tp2', relay.Kind.Type)
    tf = relay.FuncType(tvm.convert([tp1]), tp2,
                        tvm.convert([tp1, tp2]), tvm.convert([]))

def test_double():
    t = relay.TypeVar("t")
    x = relay.var("x", t)
    f = relay.var("f", relay.FuncType([t], t))
    double = run_infer_type(relay.Function([f, x], f(f(x)), t, [t]))
    double_cps = run_infer_type(to_cps(double))

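# `run_infer_type` and `to_cps` come from elsewhere in the original file
# (`to_cps` is tvm.relay.transform.to_cps). A minimal sketch of the usual
# run_infer_type helper:
def run_infer_type(expr):
    mod = tvm.IRModule.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    return mod["main"]
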
def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" x = relay.var('x', shape=[]) func = relay.Function([x], op.log(x)) ttype = relay.TensorType([], dtype='float32') assert_has_type(func, relay.FuncType([ttype], ttype))
def test_length():
    a = relay.TypeVar("a")
    assert prelude.mod[length].checked_type == relay.FuncType(
        [rlist(a)], relay.scalar_type("int32"), [a])
    res = eval(length(cons(z(), cons(z(), cons(z(), nil())))))
    assert get_scalar(res) == 3

def test_add():
    assert prelude.mod[add].checked_type == relay.FuncType(
        [nat(), nat()], nat())
    res = eval(add(s(z()), s(z())))
    assert count(res) == 2

def test_double():
    assert prelude.mod[double].checked_type == relay.FuncType([nat()], nat())
    res = eval(double(s(z())))
    assert count(res) == 2

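# The `prelude.mod[...]`/`eval(...)` variants above assume a nat/list fixture
# along these lines. The std import and the `get_type` wiring are assumptions
# modeled on TVM's ADT tests (and vary across versions), not the original
# setup; note the tests deliberately shadow the builtin `eval`:
prelude = Prelude(tvm.IRModule())
prelude.mod.import_from_std("nat.rly")
nat, z, s = prelude.mod.get_type("nat")
rlist, cons, nil = prelude.mod.get_type("List")
eval = relay.create_executor(mod=prelude.mod).evaluate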