def test_recursion():
    """
    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        relay.Call(f, [subtract(n, relay.const(1)), log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    dced = run_opt_pass(orig, transform.DeadCodeElimination())
    orig = run_opt_pass(orig, transform.InferType())
    assert graph_equal(dced, orig)
    dced = run_opt_pass(relay.Let(f, value, e.three),
                        transform.DeadCodeElimination())
    assert alpha_equal(dced, e.three)

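# ---------------------------------------------------------------------------
# The tests collected here (several of which appear in both an older and a
# newer API variant) lean on a shared fixture `e` and on `run_opt_pass` /
# `run_infer_type` helpers defined elsewhere in the original file. The sketch
# below is a plausible reconstruction, not the canonical definitions: the
# exact types, constants, and variable names in `_Env` are assumptions
# inferred from how `e` is used. The older variants also call since-removed
# comparators such as `alpha_equal`/`graph_equal` and
# `transform.OptimizeOnExpr`, which are kept verbatim and not reconstructed.
# ---------------------------------------------------------------------------
import numpy as np
import tvm
from tvm import relay
from tvm.relay import Function, transform
from tvm.relay.analysis import free_vars
from tvm.relay.op import add, equal, log, subtract


class _Env:
    """Hypothetical stand-in for the `e` fixture used throughout these tests."""

    def __init__(self):
        self.int32 = relay.TensorType([], "int32")
        self.float32 = relay.TensorType([], "float32")
        self.one = relay.const(1.0)
        self.two = relay.const(2.0)
        self.three = relay.const(3.0)
        # Scalar float variables referenced as e.a, e.b, ..., e.x, e.y, e.z.
        for name in "abcdexyz":
            setattr(self, name, relay.Var(name, self.float32))


e = _Env()


def run_opt_pass(expr, passes):
    """Wrap `expr` in a fresh module, run `passes`, and unwrap the result."""
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
    mod = tvm.transform.Sequential(passes)(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def run_infer_type(expr):
    return run_opt_pass(expr, transform.InferType())
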
def test_tuple_get_item():
    tt = relay.TupleType([e.float32, e.float32])
    t = relay.Var('t', tt)
    a = relay.Var('a')
    g = relay.TupleGetItem(t, 0)
    dced = run_opt_pass(g, transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(Function(free_vars(dced), dced),
                                   Function(free_vars(g), g))
    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
    dced = run_opt_pass(orig, transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(Function(free_vars(dced), dced),
                                   Function(free_vars(g), g))

def test_tuple_get_item():
    tt = relay.TupleType([e.float32, e.float32])
    t = relay.Var('t', tt)
    a = relay.Var('a')
    g = relay.TupleGetItem(t, 0)
    dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination())
    assert alpha_equal(Function(free_vars(dced), dced),
                       Function(free_vars(g), g))
    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
    dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
    assert alpha_equal(Function(free_vars(dced), dced),
                       Function(free_vars(g), g))

def test_refs():
    """Don't elide expressions with reference create/read/write side effects."""
    before_program = """
    #[version = "0.0.5"]
    def @f(%r) -> int {
        let %v = ref_read(%r);
        let %u = ref_write(%r, %v + 1);
        %v
    }
    def @main() -> int {
        let %r = ref(0);
        let %y = @f(%r);
        let %z = @f(%r);
        %z
    }
    """
    after_program = before_program
    optimize_and_check(
        before_program,
        after_program,
        [transform.InferType(), transform.DeadCodeElimination(inline_once=True)],
    )

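# `test_refs` above, and most of the string-based tests below, go through an
# `optimize_and_check` helper that is not part of this excerpt. A minimal
# sketch of what it plausibly does, assuming text inputs are parsed with
# tvm.parser.parse and results are compared structurally:
def optimize_and_check(before_program, after_program, passes):
    if isinstance(before_program, str):
        before_program = tvm.parser.parse(before_program)
    if isinstance(after_program, str):
        after_program = tvm.parser.parse(after_program)
    if not isinstance(passes, list):
        passes = [passes]
    optimized_program = tvm.transform.Sequential(passes)(before_program)
    assert tvm.ir.structural_equal(optimized_program, after_program, map_free_vars=True)
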
def test_dead_recursion():
    before_program = """
    #[version = "0.0.5"]
    def @main() {
        let %f = fn (%n: int, %data: int) -> int {
            if (%n == 0) {
                %data
            } else {
                %f(%n - 1, log(%data))
            }
        };
        ()
    }
    """
    after_program = """
    #[version = "0.0.5"]
    def @main() {
        ()
    }
    """
    optimize_and_check(
        before_program, after_program,
        [transform.DeadCodeElimination(), transform.InferType()])

def test_impure_op():
    """Don't elide calls to side-effecting operators."""
    before_program = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main() {
            let %size: int64 = cast(1024, dtype="int64");
            let %alignment: int64 = cast(64, dtype="int64");
            let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]);
            0
        }
        """,
        "from_string",
        core,
        metatable,
    )
    after_program = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main() {
            let %x = memory.alloc_storage(cast(1024, dtype="int64"),
                                          cast(64, dtype="int64"),
                                          virtual_device=meta[VirtualDevice][0]);
            0
        }
        """,
        "from_string",
        core,
        metatable,
    )
    optimize_and_check(before_program, after_program,
                       transform.DeadCodeElimination(inline_once=True))

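# `test_impure_op` above (and `test_impure_func` below) parse against a `core`
# module and a `metatable` defined elsewhere in the original file. A sketch of
# the required setup, assuming a single CPU host device; the exact target and
# device choices here are assumptions:
HOST = tvm.target.VirtualDevice(tvm.device("cpu"), tvm.target.Target("llvm"))
metatable = {"VirtualDevice": [HOST]}
core = tvm.IRModule()
core.import_from_std("core.rly")  # standard prelude used as the initial module for parsing
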
def test_checkpoint_alpha_equal():
    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
    f = relay.Function(
        xs,
        relay.annotation.checkpoint(
            relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))),
    )
    df = transform.gradient(run_infer_type(f))

    # run PE and DCE
    with tvm.transform.PassContext(opt_level=3):
        passes = [
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ]
        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
        df = mod["main"]

    df_parsed = tvm.parser.parse_expr(
        """
        #[version = "0.0.5"]
        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
            -> (Tensor[(1), float32],
                (Tensor[(1), float32], Tensor[(1), float32],
                 Tensor[(1), float32], Tensor[(1), float32])) {
            %0 = add(%x, %y);
            %1 = add(%z, %w);
            let %x1: Tensor[(1), float32] = multiply(%0, %1);
            let %x2: Tensor[(1), float32] = ones_like(%x1);
            let %x3: Tensor[(1), float32] = add(%x, %y);
            let %x4: Tensor[(1), float32] = add(%z, %w);
            %2 = zeros_like(%x3);
            %3 = multiply(%x2, %x4);
            %4 = collapse_sum_like(%3, %x3);
            let %x5: Tensor[(1), float32] = add(%2, %4);
            %5 = zeros_like(%x4);
            %6 = multiply(%x2, %x3);
            %7 = collapse_sum_like(%6, %x4);
            let %x6: Tensor[(1), float32] = add(%5, %7);
            %8 = zeros_like(%x);
            %9 = collapse_sum_like(%x5, %x);
            %10 = add(%8, %9);
            %11 = zeros_like(%y);
            %12 = collapse_sum_like(%x5, %y);
            %13 = add(%11, %12);
            %14 = zeros_like(%z);
            %15 = collapse_sum_like(%x6, %z);
            %16 = add(%14, %15);
            %17 = zeros_like(%w);
            %18 = collapse_sum_like(%x6, %w);
            %19 = add(%17, %18);
            %20 = (%10, %13, %16, %19);
            (%x1, %20)
        }
        """
    )
    tvm.ir.assert_structural_equal(df, df_parsed)

def test_inline_into_function():
    """Don't inline across function boundaries."""
    before_program = """
    #[version = "0.0.5"]
    def @main() {
        let %x = 1 + 1;
        let %f = fn (%y: int) -> int {
            let %z = %y + %y;
            %x + %z
        };
        (%f(2), %f(3))
    }
    """
    after_program = """
    #[version = "0.0.5"]
    def @main() {
        let %x = 1 + 1;
        let %f = fn (%y: int) -> int {
            %x + (%y + %y)
        };
        (%f(2), %f(3))
    }
    """
    optimize_and_check(before_program, after_program,
                       transform.DeadCodeElimination(inline_once=True))

def test_before_partial_eval():
    """Test transformation before PartialEval"""
    mod = tvm.IRModule()
    shape = (10, 10)
    dtype = "float32"
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    y = relay.var("y", t)
    func = relay.Function([x, y], x * y)
    func = run_infer_type(func)
    back_func = transform.gradient(func)
    back_func = run_infer_type(back_func)
    mod["main"] = back_func
    seq = tvm.transform.Sequential(
        [transform.LazyGradientInit(), transform.PartialEvaluate(), transform.DeadCodeElimination()]
    )
    mod = seq(mod)
    back_func = mod["main"]
    assert mod["main"].checked_type == relay.FuncType(
        [t, t], relay.TupleType([t, relay.TupleType([t, t])])
    )
    ex = create_executor(mod=mod)
    x = rand(dtype, *shape)
    y = rand(dtype, *shape)

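# `test_before_partial_eval` above and `test_after_partial_eval` below use a
# small `rand` helper plus `create_executor`/`assert_allclose`, none of which
# appear in this excerpt. A plausible reconstruction:
from numpy.testing import assert_allclose
from tvm.relay import create_executor


def rand(dtype, *shape):
    """Random NDArray of the given shape/dtype for numeric checks."""
    return tvm.nd.array(np.random.rand(*shape).astype(dtype))
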
def destroy_ref(x):
    x = run_infer_type(x)
    x = to_cps(x)
    x = run_infer_type(x)
    y = un_cps(x)
    y = run_infer_type(y)
    x = run_opt_pass(
        x,
        transform.Sequential(
            [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]
        ),
    )
    assert Feature.fRefCreate not in detect_feature(x)

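# `destroy_ref` exercises CPS conversion and feature detection; the imports
# below assume these utilities live in their usual relay namespaces (an
# assumption, since the original imports are not shown in this excerpt):
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.analysis import detect_feature
from tvm.relay.analysis.feature import Feature
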
def test_checkpoint_alpha_equal_tuple():
    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
    f = relay.Function(
        xs,
        relay.annotation.checkpoint(
            relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])),
    )
    df = transform.gradient(run_infer_type(f))

    # run PE and DCE
    with tvm.transform.PassContext(opt_level=3):
        # See comment in test_checkpoint_alpha_equal above.
        # TODO(mbs): Revisit once DCE supports dead reference writes.
        passes = [
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
        ]
        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
        df = mod["main"]

    df_parsed = tvm.parser.parse_expr(
        """
        #[version = "0.0.5"]
        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
            -> ((Tensor[(1), float32], Tensor[(1), float32]),
                (Tensor[(1), float32], Tensor[(1), float32],
                 Tensor[(1), float32], Tensor[(1), float32])) {
            let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */;
            let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */;
            let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */;
            let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */;
            %0 = (%x1, %x2);
            %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */;
            %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */;
            %3 = add(%1, %2) /* ty=Tensor[(1), float32] */;
            %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */;
            %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */;
            %6 = add(%4, %5) /* ty=Tensor[(1), float32] */;
            %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */;
            %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */;
            %9 = add(%7, %8) /* ty=Tensor[(1), float32] */;
            %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */;
            %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */;
            %12 = add(%10, %11) /* ty=Tensor[(1), float32] */;
            %13 = (%3, %6, %9, %12);
            (%0, %13)
        }
        """
    )
    tvm.ir.assert_structural_equal(df, df_parsed)

def dcpe(expr, mod=None, grad=False):
    passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]
    if grad:
        expr = gradient(run_infer_type(expr))
    if mod:
        assert isinstance(expr, Function)
        mod["main"] = expr
        seq = transform.Sequential(passes)
        mod = seq(mod)
        return mod["main"]
    return run_opt_pass(expr, passes)

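# Illustrative (hypothetical) use of `dcpe`: partial evaluation plus DCE should
# reduce a dead let-binding to its body. Relies on the `e` fixture sketched
# near the top of this excerpt.
def _example_dcpe_usage():
    dced = dcpe(relay.Let(e.a, e.one, e.three))
    assert tvm.ir.structural_equal(dced, e.three)
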
def test_checkpoint_alpha_equal_tuple():
    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
    f = relay.Function(
        xs,
        relay.annotation.checkpoint(
            relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])))
    df = transform.gradient(run_infer_type(f))

    # run PE and DCE
    with transform.PassContext(opt_level=3):
        passes = [
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ]
        mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
        df = mod["main"]

    df_parsed = relay.parser.fromtext(
        """
        v0.0.4
        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
            -> ((Tensor[(1), float32], Tensor[(1), float32]),
                (Tensor[(1), float32], Tensor[(1), float32],
                 Tensor[(1), float32], Tensor[(1), float32])) {
            let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */;
            let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */;
            let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */;
            let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */;
            %0 = (%x1, %x2);
            %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */;
            %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */;
            %3 = add(%1, %2) /* ty=Tensor[(1), float32] */;
            %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */;
            %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */;
            %6 = add(%4, %5) /* ty=Tensor[(1), float32] */;
            %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */;
            %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */;
            %9 = add(%7, %8) /* ty=Tensor[(1), float32] */;
            %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */;
            %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */;
            %12 = add(%10, %11) /* ty=Tensor[(1), float32] */;
            %13 = (%3, %6, %9, %12);
            (%0, %13)
        }
        """
    )
    relay.analysis.assert_alpha_equal(df, df_parsed)

def dcpe(expr, mod=None, grad=False):
    passes = [
        transform.PartialEvaluate(),
        transform.DeadCodeElimination(inline_once=True),
    ]
    if grad:
        expr = gradient(expr)
    if mod:
        assert isinstance(expr, Function)
        mod[mod.entry_func] = expr
        seq = transform.Sequential(passes)
        mod = seq(mod)
        return mod[mod.entry_func]
    return transform.OptimizeOnExpr(expr, passes)

def test_dead_let():
    before_program = """
    #[version = "0.0.5"]
    def @main(%z: int) {
        let %x = 1;
        %z
    }
    """
    after_program = """
    #[version = "0.0.5"]
    def @main(%z: int) {
        %z
    }
    """
    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())

def test_recursion():
    """
    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    dced = run_opt_pass(orig, transform.DeadCodeElimination())
    orig = run_opt_pass(orig, transform.InferType())
    tvm.ir.assert_structural_equal(dced, orig)

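# `test_recursion` above and `test_recursion_dead` below construct their
# programs via a `use_f` helper that is not shown in this excerpt. It
# presumably builds the recursive `f` from the docstring and passes it to a
# callback; a sketch consistent with how it is called here:
def use_f(func):
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        relay.Call(f, [subtract(n, relay.const(1)), log(data)]),
    )
    value = relay.Function([n, data], funcbody, e.float32, [])
    return relay.Let(f, value, func(f))
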
def test_tuple_get_item():
    before_program = """
    #[version = "0.0.5"]
    def @main() {
        let %a = 100;
        (1, 2, 3, 4).0
    }
    """
    after_program = """
    #[version = "0.0.5"]
    def @main() {
        (1, 2, 3, 4).0
    }
    """
    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())

def test_nested_let():
    before_program = """
    #[version = "0.0.5"]
    def @main(%d: int, %b: int) {
        let %a = %b;
        let %c = %d;
        %c
    }
    """
    after_program = """
    #[version = "0.0.5"]
    def @main(%d: int, %b: int) {
        let %c = %d;
        %c
    }
    """
    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())

def destroy_ref(x):
    x = run_infer_type(x)
    x = to_cps(x)
    x = run_infer_type(x)
    y = un_cps(x)
    y = run_infer_type(y)
    # TODO(mbs): Revisit once DCE can eliminate dead writes.
    x = run_opt_pass(
        x,
        tvm.transform.Sequential(
            [
                transform.PartialEvaluate(),
                transform.InferType(),
                transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
            ]
        ),
    )
    assert Feature.fRefCreate not in detect_feature(x)

def test_add_with_let():
    before_program = """
    #[version = "0.0.5"]
    def @main() {
        (let %a = 1; 3) + 2
    }
    """
    after_program = """
    #[version = "0.0.5"]
    def @main() {
        3 + 2
    }
    """
    optimize_and_check(
        before_program, after_program,
        [transform.DeadCodeElimination(), transform.InferType()]
    )

def test_after_partial_eval():
    """Test transformation following reverse mode ad and PartialEval"""
    mod = tvm.IRModule()
    shape = (10, 10)
    dtype = "float32"
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    y = relay.var("y", t)
    func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype)))
    func = run_infer_type(func)
    back_func = transform.gradient(func)
    back_func = run_infer_type(back_func)
    mod["main"] = back_func
    back_func = mod["main"]
    seq = tvm.transform.Sequential(
        [
            transform.PartialEvaluate(),
            transform.InferType(),
            transform.LazyGradientInit(),
            transform.InferType(),
            transform.DeadCodeElimination(),
            transform.InferType(),
        ]
    )
    mod = seq(mod)
    assert mod["main"].checked_type == relay.FuncType(
        [t, t], relay.TupleType([t, relay.TupleType([t, t])])
    )
    x = rand(dtype, *shape)
    y = rand(dtype, *shape)
    forward, (grad_x, grad_y) = create_executor(mod=mod).evaluate(back_func)(x, y)
    assert_allclose(forward.numpy(), x.numpy() * y.numpy())
    assert_allclose(grad_x.numpy(), y.numpy())
    assert_allclose(grad_y.numpy(), x.numpy())

def test_impure_func():
    """Don't elide calls to side-effecting functions."""
    before_program = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @f() -> int {
            let %size: int64 = cast(1024, dtype="int64");
            let %alignment: int64 = cast(64, dtype="int64");
            let %x = memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][0]);
            0
        }
        def @main() -> int {
            let %y = @f();
            0
        }
        """,
        "from_string",
        core,
        metatable,
    )
    after_program = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @f() -> int {
            let %x = memory.alloc_storage(cast(1024, dtype="int64"),
                                          cast(64, dtype="int64"),
                                          se_scope=meta[SEScope][0]);
            0
        }
        def @main() -> int {
            let %y = @f();
            0
        }
        """,
        "from_string",
        core,
        metatable,
    )
    optimize_and_check(before_program, after_program,
                       transform.DeadCodeElimination(inline_once=True))

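# Note: this variant predates the SEScope -> VirtualDevice rename, so the
# `metatable` it parses against would need an "SEScope" entry rather than the
# "VirtualDevice" entry sketched earlier in this file.
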
def test_used_let():
    orig = relay.Let(e.c, e.one, e.c + e.c)
    orig = run_opt_pass(orig, transform.DeadCodeElimination())
    expected = relay.Let(e.c, e.one, e.c + e.c)
    assert tvm.ir.structural_equal(Function([], orig), Function([], expected))

def test_let():
    orig = relay.Let(e.x, e.y, e.z)
    orig = run_opt_pass(orig, transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))

def test_complexity():
    g = inception_v3.get_net(1, 1000, (3, 299, 299), "float32")
    run_opt_pass(g, transform.DeadCodeElimination())

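# `inception_v3` is the reference model from relay's testing package
# (presumably `from tvm.relay.testing import inception_v3`); both complexity
# tests are smoke checks that DCE terminates quickly on a large real graph.
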
def test_op_let():
    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two),
                        transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(dced, add(e.three, e.two))

def test_recursion_dead():
    x = relay.Let(e.a, e.one, e.three)
    dced_f = lambda f: x
    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(dced, e.three)

def test_inline():
    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
    orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
    tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d))

def test_chain_unused_let():
    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
    orig = run_opt_pass(orig, transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))

def test_complexity():
    mod = transform.InferType()(
        tvm.IRModule.from_expr(inception_v3.get_net(1, 1000, (3, 299, 299), "float32"))
    )
    optimize_and_check(mod, mod, transform.DeadCodeElimination(inline_once=True))