def test_recursion():
    """Dead code elimination must keep a let-bound recursive function that is used.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        relay.Call(f, [subtract(n, relay.const(1.0)), log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    # Bind f to the Function itself, not to its bare body. Otherwise the
    # program does not match the docstring (f would not be a function) and
    # `value` would be built but never used.
    orig = relay.Let(
        f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
    # f is referenced by the let body, so DCE must leave the program intact.
    assert alpha_equal(dead_code_elimination(orig), orig)
    # When the let body never references f, the binding is dropped entirely.
    assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
def test_tuple_get_item():
    """A TupleGetItem is kept by DCE; an unused let in its tuple operand is removed."""
    tup = relay.Var('t')
    expected = relay.TupleGetItem(tup, 0)
    # Nothing is dead here, so DCE must return an alpha-equal expression.
    assert alpha_equal(dead_code_elimination(expected), expected)
    # `e.a` is bound but never read, so the let should collapse to its body.
    dead_binding = relay.Let(e.a, e.one, tup)
    projected = relay.TupleGetItem(dead_binding, 0)
    assert alpha_equal(dead_code_elimination(projected), expected)
def test_recursion():
    """Dead code elimination must keep a let-bound recursive function that is used.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n")
    np = relay.Param(n, e.int32)
    data = relay.Var("data")
    datap = relay.Param(data, e.float32)
    funcbody = relay.If(equal(n, convert(0)),
                        data,
                        f(subtract(n, convert(1.0)), log(data)))
    value = relay.Function([np, datap], e.float32, funcbody, [])
    # Bind f to the Function itself, not to its bare body. Previously `value`
    # was built but never used, so f was not a function as the docstring says.
    # NOTE(review): the trailing e.float32 (presumably the binding's value
    # type) is kept from the original. f is a function, so confirm whether
    # a function type should be passed here instead.
    orig = relay.Let(f, value, f(convert(2.0), convert(10000.0)), e.float32)
    # f is called in the let body, so DCE must keep the binding.
    assert alpha_equal(dead_code_elimination(orig), orig)
    # With a body that never mentions f, the binding is dead and removed.
    assert alpha_equal(
        dead_code_elimination(relay.Let(f, value, e.three, e.float32)), e.three)
def test_recursion():
    """A called recursive let-bound function survives DCE; an uncalled one is dropped.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    rec_fn = relay.Var("f")
    counter = relay.Var("n", e.int32)
    acc = relay.Var("data", e.float32)
    # Base case test and recursive step of the function body.
    is_base = equal(counter, relay.const(0))
    recurse = relay.Call(rec_fn, [subtract(counter, relay.const(1.0)), log(acc)])
    body = relay.If(is_base, acc, recurse)
    fn = relay.Function([counter, acc], body, e.float32, [])
    entry = relay.Call(rec_fn, [relay.const(2.0), relay.const(10000.0)])
    program = relay.Let(rec_fn, fn, entry)
    # The binding is used by the let body, so nothing may be removed.
    assert alpha_equal(dead_code_elimination(program), program)
    # A let body that never mentions the function makes the binding dead.
    unused = relay.Let(rec_fn, fn, e.three)
    assert alpha_equal(dead_code_elimination(unused), e.three)
def test_tuple_get_item():
    """DCE leaves a TupleGetItem alone and strips a dead let from its tuple operand."""
    source = relay.Var('t')
    item = relay.TupleGetItem(source, 0)
    simplified = dead_code_elimination(item)
    assert alpha_equal(simplified, item)
    # The let binds `e.a`, which the body `t` never reads.
    with_dead_let = relay.TupleGetItem(relay.Let(e.a, e.one, source), 0)
    simplified = dead_code_elimination(with_dead_let)
    assert alpha_equal(simplified, item)
def test_op_let():
    """An unused let appearing as an operator argument is replaced by its body."""
    dead_let = relay.Let(e.a, e.one, e.three)
    result = dead_code_elimination(add(dead_let, e.two))
    assert alpha_equal(result, add(e.three, e.two))
def dcpe(expr):
    """Partially evaluate ``expr`` and run dead code elimination on the result."""
    evaluated = partial_evaluate(expr)
    return dead_code_elimination(evaluated)
def test_let():
    """A let whose bound variable is never read is reduced to its body."""
    unused_binding = relay.Let(e.x, e.y, e.z)
    reduced = dead_code_elimination(unused_binding)
    assert alpha_equal(reduced, e.z)
def test_used_let():
    """A let whose variable is read (here twice) by its body is left unchanged."""
    doubled = relay.Let(e.c, e.one, e.c + e.c)
    reduced = dead_code_elimination(doubled)
    expected = relay.Let(e.c, e.one, e.c + e.c)
    assert alpha_equal(reduced, expected)
def test_if():
    """DCE reduces an If with a constant-true condition to its true branch."""
    branch = relay.If(convert(True), e.a, e.b)
    reduced = dead_code_elimination(branch)
    assert alpha_equal(reduced, e.a)
def test_used_let():
    """Only the unused outer let is removed; the inner let whose variable is read survives."""
    inner = relay.Let(e.c, e.d, e.c, e.tt)
    outer = relay.Let(e.a, e.b, inner, e.tt)
    reduced = dead_code_elimination(outer)
    assert alpha_equal(reduced, relay.Let(e.c, e.d, e.c, e.tt))
def test_inline():
    """Nested lets: the unused outer one is dropped and the inner one reduces to its value.

    NOTE(review): the inner binding `c` is read exactly once, which is presumably
    why it is replaced by `e.d` here, while the twice-read binding in
    test_used_let is kept. Confirm against the pass's inlining rule.
    """
    inner = relay.Let(e.c, e.d, e.c)
    nested = relay.Let(e.a, e.b, inner)
    assert alpha_equal(dead_code_elimination(nested), e.d)
def test_chain_unused_let():
    """A chain of lets whose variables are all unused reduces to the innermost body."""
    innermost = relay.Let(e.c, e.d, e.e)
    chain = relay.Let(e.a, e.b, innermost)
    reduced = dead_code_elimination(chain)
    assert alpha_equal(reduced, e.e)
def test_if():
    """DCE reduces an If guarded by the constant ``True`` to its true branch."""
    conditional = relay.If(relay.const(True), e.a, e.b)
    assert alpha_equal(dead_code_elimination(conditional), e.a)
def dcpe(expr, mod=None):
    """Partially evaluate ``expr`` (forwarding ``mod``), then run DCE with ``inline_once=True``.

    Args:
        expr: the Relay expression to simplify.
        mod: passed through unchanged to ``partial_evaluate`` as ``mod=``.
    """
    pe_result = partial_evaluate(expr, mod=mod)
    return dead_code_elimination(pe_result, inline_once=True)