def test_pass_registration():
    """Check the metadata recorded on a newly created sequential pass.

    A sequential pass is built from ``module_pass`` and ``function_pass``
    without passing an explicit name; the test then verifies the object's
    type and the name / optimization level reported by its ``info``.
    """
    expected_name = "sequential_pass"
    expected_level = 2

    # No name is supplied here on purpose: the name assertion below relies on
    # whatever name ``ir_pass.sequential_pass`` assigns by itself.
    seq = ir_pass.sequential_pass(
        passes=[module_pass, function_pass], opt_level=expected_level
    )
    assert isinstance(seq, ir_pass.SequentialPass)

    info = seq.info
    assert info.name == expected_name
    assert info.opt_level == expected_level
def test_only_function_pass():
    """Run a sequential pass that contains only ``function_pass`` on ``mod``.

    The subtract and log functions are pulled out of the returned module by
    name and compared against ``get_ref_sub()`` and ``get_ref_log()``.
    """
    seq = ir_pass.sequential_pass(opt_level=1, passes=[function_pass])
    transformed = seq(mod)

    # Check the subtract function.
    _, sub_after = extract_var_func(transformed, v_sub.name_hint)
    check_func(sub_after, get_ref_sub())

    # Check the log function.
    _, log_after = extract_var_func(transformed, v_log.name_hint)
    check_func(log_after, get_ref_log())
def test_only_module_pass():
    """Run a sequential pass that contains only ``module_pass`` on ``mod``.

    The subtract function in the returned module is compared against the
    original ``sub``, and the function named after the variable returned by
    ``get_var_func()`` is compared against the function returned alongside it.
    """
    seq = ir_pass.sequential_pass(opt_level=1, passes=[module_pass])
    transformed = seq(mod)

    # Check the subtract function.
    _, sub_after = extract_var_func(transformed, v_sub.name_hint)
    check_func(sub_after, sub)

    # Check the abs function is added.
    abs_gvar, abs_expected = get_var_func()
    _, abs_after = extract_var_func(transformed, abs_gvar.name_hint)
    check_func(abs_after, abs_expected)
def test_multiple_passes():
    """Run ``module_pass`` followed by ``function_pass`` and verify the result.

    The returned module is checked structurally (abs, subtract and log
    functions against their ``get_ref_*`` references) and then the updated
    subtract and abs functions are executed with the "graph" and "debug"
    executors on every ``(target, ctx)`` pair from ``ctx_list()``.
    """

    def run_and_compare(func, inputs, expected):
        """Evaluate *func* on *inputs* with both executors on every target."""
        for target, ctx in ctx_list():
            for kind in ("graph", "debug"):
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                result = executor.evaluate(func)(*inputs)
                tvm.testing.assert_allclose(
                    result.asnumpy(), expected, rtol=1e-5
                )

    # Reset the current module since mod has been polluted by the previous
    # function pass. This binds a fresh local ``mod`` for this test only.
    mod = relay.Module({v_sub: sub, v_log: log})
    seq = ir_pass.sequential_pass(
        opt_level=1, passes=[module_pass, function_pass]
    )
    ret_mod = seq(mod)

    # Check the abs function is added.
    abs_var, _ = get_var_func()
    _, new_abs = extract_var_func(ret_mod, abs_var.name_hint)
    check_func(new_abs, get_ref_abs())

    # Check the subtract function is modified correctly.
    _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
    check_func(new_sub, get_ref_sub())

    # Check the log function is modified correctly.
    _, new_log = extract_var_func(ret_mod, v_log.name_hint)
    check_func(new_log, get_ref_log())

    # Execute the updated subtract function. The reference doubles both
    # operands before subtracting.
    x_nd = get_rand(shape, dtype)
    y_nd = get_rand(shape, dtype)
    sub_expected = np.subtract(x_nd.asnumpy() * 2, y_nd.asnumpy() * 2)
    run_and_compare(new_sub, (x_nd, y_nd), sub_expected)

    # Execute the updated abs function. The input is drawn only after the
    # subtract check so the sequence of get_rand calls is unchanged.
    x_nd = get_rand((5, 10), dtype)
    abs_expected = np.abs(x_nd.asnumpy() * 2)
    run_and_compare(new_abs, (x_nd,), abs_expected)
def test_no_pass():
    """Apply a sequential pass built from an empty pass list to ``mod``.

    The function stored under ``v_sub`` in the returned module is compared
    with the original ``sub`` via ``check_func``.
    """
    empty_seq = ir_pass.sequential_pass(opt_level=1, passes=[])
    result_mod = empty_seq(mod)
    # NOTE(review): the other tests in this group pass the transformed
    # function first and the reference second; the order here is the
    # reverse and is kept as-is -- confirm check_func is order-insensitive.
    check_func(sub, result_mod[v_sub])