def test_global_var():
    """Check that a ``relay.GlobalVar`` keeps its name hint and can be stringified."""
    name_hint = 'g'
    gv = relay.GlobalVar(name_hint)
    # Assert the comparison: a bare ``==`` expression is evaluated and then
    # discarded, so a mismatch would never fail the test.
    assert gv.name_hint == name_hint
    # TODO(@jroesch): what do we do about spans (e.g. asserting the span is None)?
    # Smoke-check that the string conversion does not raise.
    str(gv)
def test_function_pass():
    """Exercise registration and execution of a function-level pass.

    A module holding a single ``myLog`` function (``log(x)``) is built, a
    function pass is registered both through the decorator and through a
    direct call, and the decorated pass is then run on the module. The
    rewritten function is compared against ``log(x + x)`` and evaluated with
    the "graph" and "debug" executors.
    """
    shape = (10, )
    dtype = 'float32'
    tensor_type = relay.TensorType(shape, dtype)
    x = relay.var("x", tensor_type)
    log_gvar = relay.GlobalVar("myLog")
    log_func = relay.Function([x], relay.log(x))
    mod = relay.Module({log_gvar: log_func})

    pass_name = "function_pass_test"
    opt_level = 1
    opt_tester = OptTester(mod)
    pass_ctx = None

    @_transform.function_pass(opt_level=opt_level, name=pass_name)
    def transform(expr, mod, ctx):
        return opt_tester.transform(expr, ctx)

    def check_decorated_registration():
        """The decorator must yield a FunctionPass carrying the given metadata."""
        assert isinstance(transform, _transform.FunctionPass)
        info = transform.info
        assert info.name == pass_name
        assert info.opt_level == opt_level

    def check_direct_registration():
        """Registering without the decorator falls back to the function name."""
        def direct_transform(expr, ctx):
            return opt_tester.transform(expr, ctx)

        registered = _transform.function_pass(direct_transform, opt_level=0)
        assert isinstance(registered, _transform.FunctionPass)
        info = registered.info
        assert info.name == "direct_transform"
        assert info.opt_level == 0

    def check_run():
        """Run the pass and validate the rewritten log function."""
        assert pass_name in transform.astext()

        updated_mod = transform(mod)
        assert isinstance(updated_mod, relay.Module)

        # The log function stored in the updated module must match log(x + x).
        rewritten_log = updated_mod[
            updated_mod.get_global_var(log_gvar.name_hint)]
        expected_log = relay.Function([x], relay.log(relay.add(x, x)))
        check_func(rewritten_log, expected_log)

        # It must also match what the Python-side transform produces directly.
        check_func(rewritten_log, opt_tester.transform(log_func, pass_ctx))

        # Execute the rewritten log function on every available target.
        x_nd = get_rand(shape, dtype)
        ref_res = np.log(x_nd.asnumpy() * 2)
        for target, dev in ctx_list():
            # Both executors are created before either one is evaluated.
            executors = [
                relay.create_executor(kind, ctx=dev, target=target)
                for kind in ("graph", "debug")
            ]
            for executor in executors:
                res = executor.evaluate(rewritten_log)(x_nd)
                tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5)

    check_decorated_registration()
    check_direct_registration()
    check_run()
def test_sequential_pass():
    """Exercise ``_transform.Sequential`` with module and function passes.

    A module containing ``mySub`` (``x - y``) and ``myLog`` (``log(z)``) is
    built, one module pass and one function pass are registered, and the
    nested checks below run a ``Sequential`` with no passes, with each pass
    alone, and with both together. The nested checks share ``mod`` and are
    invoked in a fixed order at the bottom of this function.
    """
    shape = (10, )
    dtype = 'float32'
    tp = relay.TensorType(shape, dtype)
    x = relay.var("x", tp)
    y = relay.var("y", tp)
    v_sub = relay.GlobalVar("mySub")
    sub = relay.Function([x, y], relay.subtract(x, y))

    z = relay.var("z", tp)
    v_log = relay.GlobalVar("myLog")
    log = relay.Function([z], relay.log(z))

    mod = relay.Module({v_sub: sub, v_log: log})

    def get_ref_log():
        """Return the reference log function, ``log(x + x)``."""
        ref_log = relay.Function([x], relay.log(relay.add(x, x)))
        return ref_log

    def get_ref_sub():
        """Return the reference subtract function, ``(x + x) - (y + y)``."""
        ref_sub = relay.Function([x, y],
                                 relay.subtract(
                                     relay.add(x, x),
                                     relay.add(y, y)))
        return ref_sub

    def get_ref_abs():
        """Return the reference abs function, ``abs(a + a)`` on a (5, 10) tensor."""
        shape = (5, 10)
        tp = relay.TensorType(shape, "float32")
        a = relay.var("a", tp)
        ref_abs = relay.Function([a], relay.abs(relay.add(a, a)))
        return ref_abs

    # Register a module pass.
    opt_tester = OptTester(mod)
    pass_ctx = None

    @_transform.module_pass(opt_level=1)
    def mod_transform(expr, ctx):
        return opt_tester.transform(expr, ctx)

    module_pass = mod_transform

    # Register a function pass.
    @_transform.function_pass(opt_level=1)
    def func_transform(expr, mod, ctx):
        return opt_tester.transform(expr, ctx)

    function_pass = func_transform

    def test_pass_registration():
        """A Sequential must report the name "sequential" and its opt level."""
        passes = [module_pass, function_pass]
        opt_level = 2
        pass_name = "sequential"
        sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
        pass_info = sequential.info
        assert pass_info.name == pass_name
        assert pass_info.opt_level == opt_level

    def test_no_pass():
        """An empty Sequential must leave the subtract function as it was."""
        passes = []
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        ret_mod = sequential(mod)
        mod_func = ret_mod[v_sub]
        check_func(sub, mod_func)

    def test_only_module_pass():
        """Run only the module pass, forced on through ``required_pass``."""
        passes = [module_pass]
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        with relay.build_config(required_pass=["mod_transform"]):
            ret_mod = sequential(mod)
        # Check the subtract function.
        sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
        check_func(new_sub, sub)

        # Check the abs function is added.
        abs_var, abs_func = get_var_func()
        abs_var, new_abs = extract_var_func(ret_mod, abs_var.name_hint)
        check_func(new_abs, abs_func)

    def test_only_function_pass():
        """Run only the function pass, forced on through ``required_pass``."""
        # Check the subtract function.
        passes = [function_pass]
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        with relay.build_config(required_pass=["func_transform"]):
            ret_mod = sequential(mod)
        _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
        check_func(new_sub, get_ref_sub())

        # Check the log function.
        log_var, new_log = extract_var_func(ret_mod, v_log.name_hint)
        check_func(new_log, get_ref_log())

    def test_multiple_passes():
        """Run the module pass followed by the function pass, then execute."""
        # Reset the current module since mod has been polluted by the previous
        # function pass.
        # This assignment makes ``mod`` local to this nested function only.
        mod = relay.Module({v_sub: sub, v_log: log})
        passes = [module_pass, function_pass]
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        required = ["mod_transform", "func_transform"]
        with relay.build_config(required_pass=required):
            ret_mod = sequential(mod)

        # Check the abs function is added.
        abs_var, abs_func = get_var_func()
        abs_var, new_abs = extract_var_func(ret_mod, abs_var.name_hint)
        check_func(new_abs, get_ref_abs())

        # Check the subtract function is modified correctly.
        _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
        check_func(new_sub, get_ref_sub())

        # Check the log function is modified correctly.
        _, new_log = extract_var_func(ret_mod, v_log.name_hint)
        check_func(new_log, get_ref_log())

        # Execute the updated subtract function.
        x_nd = get_rand(shape, dtype)
        y_nd = get_rand(shape, dtype)
        ref_res = np.subtract(x_nd.asnumpy() * 2, y_nd.asnumpy() * 2)
        for target, ctx in ctx_list():
            exe1 = relay.create_executor("graph", ctx=ctx, target=target)
            exe2 = relay.create_executor("debug", ctx=ctx, target=target)
            res1 = exe1.evaluate(new_sub)(x_nd, y_nd)
            tvm.testing.assert_allclose(res1.asnumpy(), ref_res, rtol=1e-5)
            res2 = exe2.evaluate(new_sub)(x_nd, y_nd)
            tvm.testing.assert_allclose(res2.asnumpy(), ref_res, rtol=1e-5)

        # Execute the updated abs function.
        x_nd = get_rand((5, 10), dtype)
        ref_res = np.abs(x_nd.asnumpy() * 2)
        for target, ctx in ctx_list():
            exe1 = relay.create_executor("graph", ctx=ctx, target=target)
            exe2 = relay.create_executor("debug", ctx=ctx, target=target)
            res1 = exe1.evaluate(new_abs)(x_nd)
            tvm.testing.assert_allclose(res1.asnumpy(), ref_res, rtol=1e-5)
            res2 = exe2.evaluate(new_abs)(x_nd)
            tvm.testing.assert_allclose(res2.asnumpy(), ref_res, rtol=1e-5)

    # Order matters: the checks above share ``mod`` (see the reset comment in
    # test_multiple_passes).
    test_pass_registration()
    test_no_pass()
    test_only_module_pass()
    test_only_function_pass()
    test_multiple_passes()
def test_global_var():
    """Check that a ``relay.GlobalVar`` keeps its name hint and can be shown."""
    name_hint = 'g'
    gv = relay.GlobalVar(name_hint)
    # Assert the comparison: a bare ``==`` expression is evaluated and then
    # discarded, so a mismatch would never fail the test.
    assert gv.name_hint == name_hint
    show(gv)
def test_module_pass():
    """Exercise registration and execution of a module-level pass.

    A module holding a single ``myAdd`` function (``x + y``) is built, a
    module pass is registered both through the decorator and through a direct
    call, and the decorated pass is then run on the module. The resulting
    functions are compared structurally and the add function is evaluated
    with the "graph" and "debug" executors.
    """
    shape = (5, 10)
    dtype = 'float32'
    tp = relay.TensorType(shape, dtype)
    x = relay.var("x", tp)
    y = relay.var("y", tp)
    v_add = relay.GlobalVar("myAdd")
    func = relay.Function([x, y], x + y)
    mod = relay.Module({v_add: func})

    pass_name = "module_pass_test"
    opt_level = 0
    opt_tester = OptTester(mod)
    pass_ctx = None

    @_transform.module_pass(opt_level=opt_level, name=pass_name)
    def transform(expr, ctx):
        return opt_tester.transform(expr, ctx)

    def test_pass_registration():
        """The decorator must yield a ModulePass carrying the given metadata."""
        mod_pass = transform
        assert isinstance(mod_pass, _transform.ModulePass)
        pass_info = mod_pass.info
        assert pass_info.name == pass_name
        assert pass_info.opt_level == opt_level

    def test_pass_registration_no_decorator():
        """Registering without the decorator falls back to the function name."""
        def direct_transform(expr, ctx):
            return opt_tester.transform(expr, ctx)
        mod_pass = _transform.module_pass(direct_transform, opt_level=3)
        assert isinstance(mod_pass, _transform.ModulePass)
        pass_info = mod_pass.info
        assert pass_info.name == "direct_transform"
        assert pass_info.opt_level == 3

    def test_pass_run():
        """Run the pass and validate the functions in the updated module."""
        module_pass = transform
        assert pass_name in module_pass.astext()

        updated_mod = module_pass(mod)
        assert isinstance(updated_mod, relay.Module)

        # Check the abs function in the updated module.
        v_abs, myabs = get_var_func()
        new_v_add = updated_mod.get_global_var(v_abs.name_hint)
        new_abs = updated_mod[new_v_add]
        check_func(new_abs, myabs)

        # Check the add function in the updated module.
        v_abs, myabs = get_var_func()
        new_v_add = updated_mod.get_global_var(v_add.name_hint)
        new_add = updated_mod[new_v_add]
        check_func(new_add, func)

        # Check the add function in the python transformed module.
        ret = opt_tester.transform(mod, pass_ctx)
        transformed_v_add = ret.get_global_var(v_add.name_hint)
        # NOTE(review): the variable is resolved on ``ret`` but the function is
        # read from ``mod`` - confirm this is intended rather than ``ret[...]``.
        transformed_add = mod[transformed_v_add]
        check_func(new_add, transformed_add)

        # Execute the add function.
        x_nd = get_rand(shape, dtype)
        y_nd = get_rand(shape, dtype)
        ref_res = x_nd.asnumpy() + y_nd.asnumpy()
        for target, ctx in ctx_list():
            exe1 = relay.create_executor("graph", ctx=ctx, target=target)
            exe2 = relay.create_executor("debug", ctx=ctx, target=target)
            res1 = exe1.evaluate(new_add)(x_nd, y_nd)
            tvm.testing.assert_allclose(res1.asnumpy(), ref_res, rtol=1e-5)
            res2 = exe2.evaluate(new_add)(x_nd, y_nd)
            tvm.testing.assert_allclose(res2.asnumpy(), ref_res, rtol=1e-5)

    test_pass_registration()
    # The call parentheses are required: naming the function without them is
    # a no-op expression and the check would silently never run.
    test_pass_registration_no_decorator()
    test_pass_run()
def conv2d_bias_relu():
    """Build a partitioned conv2d + bias + relu module and its reference.

    Returns a 4-tuple: the module whose ``main`` calls the ``dnnl_0``
    partition function (which wraps the composite function), a reference
    module computing the same expression inline, a dict of random inputs
    keyed by ``data``/``weight``/``bias``, and the shape ``(1, 32, 14, 14)``.
    ``dtype`` is read from the enclosing scope.
    """
    input_shape = (1, 32, 14, 14)
    kernel_shape = (32, 32, 3, 3)
    bias_shape = (32, 1, 1)
    shapes = (input_shape, kernel_shape, bias_shape)

    def declare(names):
        # One variable per (name, shape) pair, always in input/kernel/bias order.
        return [relay.var(name, shape=shp, dtype=dtype)
                for name, shp in zip(names, shapes)]

    def conv_add_relu(inp, kernel, offset):
        conv = relay.nn.conv2d(inp, kernel, kernel_size=(3, 3), padding=(1, 1))
        return relay.nn.relu(relay.add(conv, offset))

    # Composite function
    inner_params = declare(("in_1", "in_2", "in_3"))
    composite = relay.Function(inner_params, conv_add_relu(*inner_params))
    composite = composite.with_attr('Composite', 'dnnl.conv2d_bias_relu')
    composite = composite.with_attr('PartitionedFromPattern', 'nn.conv2d_add_nn.relu_')

    # Partition function
    outer_params = declare(("arg_1", "arg_2", "arg_3"))
    partition = relay.Function(outer_params, relay.Call(composite, outer_params))
    partition = set_func_attr(partition, "dnnl", "dnnl_0")
    partition_gvar = relay.GlobalVar("dnnl_0")
    mod = tvm.IRModule()
    mod[partition_gvar] = partition

    # Main function
    main_params = declare(("data", "weight", "bias"))
    mod["main"] = relay.Function(main_params, partition_gvar(*main_params))

    # Reference module
    ref_params = declare(("data", "weight", "bias"))
    ref_main = relay.Function(ref_params, conv_add_relu(*ref_params))
    ref_mod = tvm.IRModule()
    ref_mod["main"] = ref_main

    # Draw the random inputs in data, weight, bias order.
    i_data, w1_data, b_data = (
        np.random.uniform(0, 1, shp).astype(dtype) for shp in shapes
    )
    inputs = {'data': i_data, 'weight': w1_data, 'bias': b_data}
    return mod, ref_mod, inputs, (1, 32, 14, 14)
def test_list_constructor():
    """Check the type of ``cons(z(), nil())`` once stored in ``mod``."""
    consz_gvar = relay.GlobalVar("test_consz")
    mod[consz_gvar] = relay.Function([], cons(z(), nil()))
    # Read the function back from the module and inspect its body's type.
    body_type = mod[consz_gvar].body.checked_type
    assert body_type == l(nat())
def test_global():
    """Run ``check_visit`` on a freshly created global variable."""
    check_visit(relay.GlobalVar('f'))
def test_list_constructor():
    """Check the inferred type of ``cons(z(), nil())`` in the prelude module."""
    consz_gvar = relay.GlobalVar("test_consz")
    prelude.mod[consz_gvar] = relay.Function([], cons(z(), nil()))
    # Type inference returns a module; the checked type is read from it.
    typed_mod = relay.transform.InferType()(prelude.mod)
    body_type = typed_mod[consz_gvar].body.checked_type
    assert body_type == rlist(nat())
def test_match():
    """Parse a recursive ``@length`` over ``List[A]`` for both match keywords.

    For each keyword the expected module is built by hand (the ``List`` ADT
    plus a ``length`` function whose body is a ``Match``) and compared with
    the parsed text through ``assert_parse_module_as``.
    """
    program_template = """ %s def @length[A](%%xs: List[A]) -> int32 { %s (%%xs) { Cons(_, %%rest : List[A]) => { (); 1 + @length(%%rest) }, Nil => 0, } } """

    # Each keyword is paired with whether it denotes a complete match.
    for keyword, complete in (("match", True), ("match?", False)):
        expected_mod = tvm.IRModule()

        # ADT definition: List[A] = Cons(A, List[A]) | Nil
        list_gtv = relay.GlobalTypeVar("List")
        adt_param = relay.TypeVar("A")
        cons_ctor = relay.Constructor(
            "Cons", [adt_param, list_gtv(adt_param)], list_gtv)
        nil_ctor = relay.Constructor("Nil", [], list_gtv)
        expected_mod[list_gtv] = relay.TypeData(
            list_gtv, [adt_param], [cons_ctor, nil_ctor])

        # The function gets its own, separate type parameter named "A".
        length_gvar = relay.GlobalVar("length")
        func_param = relay.TypeVar("A")
        xs = relay.Var("xs", list_gtv(func_param))
        rest = relay.Var("rest")

        cons_body = relay.Let(
            relay.var("", type_annotation=None),
            UNIT,
            relay.add(relay.const(1), relay.Call(length_gvar, [rest])))
        cons_clause = relay.Clause(
            relay.PatternConstructor(
                cons_ctor,
                [relay.PatternWildcard(), relay.PatternVar(rest)]),
            cons_body)
        nil_clause = relay.Clause(
            relay.PatternConstructor(nil_ctor, []),
            relay.const(0))
        match_expr = relay.Match(
            xs, [cons_clause, nil_clause], complete=complete)

        expected_mod[length_gvar] = relay.Function(
            [xs], match_expr, int32, [func_param])

        assert_parse_module_as(
            program_template % (LIST_DEFN, keyword),
            expected_mod
        )
def test_fused_reshape():
    """Run ``RemoveStandaloneReshapes`` on a module with a ``fused_reshape`` primfunc.

    The module holds two TIR primfuncs (``multiply`` and ``fused_reshape``)
    and a parsed Relay ``@main`` that calls both through ``call_lowered``.
    After the pass, the test asserts that some global variable whose name
    contains "reshape" is still present.
    """
    mod = tvm.ir.IRModule()

    @T.prim_func
    def mul_primfunc(a: T.handle, b: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, [128, 128])
        B = T.match_buffer(b, [128, 128])
        D = T.match_buffer(d, [128, 128])

        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                D[vi, vj] = A[vi, vk] * B[vj, vk]

    @T.prim_func
    def fused_reshape_primfunc(a: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, [128, 128])
        D = T.match_buffer(d, [128, 128])

        for i, j in T.grid(128, 128):
            D[i, j] = A[i, j]

    # Resolves the ``meta[VirtualDevice][0]`` references in the program text.
    metatable = {"VirtualDevice": [CPU]}
    mul_ty = relay.FuncType(
        [
            relay.TensorType((128, 128), "float32"),
            relay.TensorType((128, 128), "float32"),
            relay.TensorType((128, 128), "float32"),
        ],
        relay.TensorType((128, 128), "float32"),
    )
    mul_gv = relay.GlobalVar("multiply", type_annot=mul_ty)
    mod[mul_gv] = mul_primfunc
    reshape_ty = relay.FuncType(
        [
            relay.TensorType((128, 128), "float32"),
        ],
        relay.TensorType((128, 128), "float32"),
    )
    reshape_gv = relay.GlobalVar("fused_reshape", type_annot=reshape_ty)
    mod[reshape_gv] = fused_reshape_primfunc
    # The module built so far is passed to the parser together with the text.
    mod = tvm.parser.parse(
        """ #[version = "0.0.5"] def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], virtual_device=meta[VirtualDevice][0]) { %0 = call_lowered(@multiply, (%x, %y, %z)); let %x_12: Tensor[(128, 128), float32] = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %1 = call_lowered(@fused_reshape, (%x_12,) ); let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); %x_14 } """,
        "from_string",
        mod,
        metatable,
    )

    # Expected main:
    ##[version = "0.0.5"]
    # def @main(%x /* ty=Tensor[(128, 128), float32] */) -> Tensor[(128, 128), float32] {
    #  %0 = (%x, %y, %z);
    #  %1 = call_lowered(@multiply, %0);
    #  let %x_12: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True);
    #  let %x_14: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True);
    #  %x_14
    # }
    # NOTE(review): the "Expected main" sketch above shows no call to
    # @fused_reshape, yet the assertion below requires a "reshape" global to
    # remain - confirm which of the two reflects the intended behavior.

    mod = RemoveStandaloneReshapes()(mod)
    reshapes_present = any(["reshape" in gv.name_hint for gv in mod.get_global_vars()])
    # NOTE(review): this fails only when NO "reshape" global remains, so the
    # message text contradicts the condition - confirm and reword one of them.
    assert reshapes_present, "Reshape should have been removed."
    return
from tvm.contrib.relay_viz.terminal import ( TermGraph, TermPlotter, TermVizParser, ) ###################################################################### # Define a Relay IR Module with multiple GlobalVar # ------------------------------------------------ # Let's build an example Relay IR Module containing multiple ``GlobalVar``. # We define an ``add`` function and call it in the main function. data = relay.var("data") bias = relay.var("bias") add_op = relay.add(data, bias) add_func = relay.Function([data, bias], add_op) add_gvar = relay.GlobalVar("AddFunc") input0 = relay.var("input0") input1 = relay.var("input1") input2 = relay.var("input2") add_01 = relay.Call(add_gvar, [input0, input1]) add_012 = relay.Call(add_gvar, [input2, add_01]) main_func = relay.Function([input0, input1, input2], add_012) main_gvar = relay.GlobalVar("main") mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func}) ###################################################################### # Render the graph with Relay Visualizer on the terminal # ------------------------------------------------------ # The terminal can show a Relay IR module in text similar to clang AST-dump.