def test_recursion():
    """DCE keeps a used function binding and drops an unused one.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        # n is declared i32, so decrement it with an integer constant.
        relay.Call(f, [subtract(n, relay.const(1)), log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    # Bind f to the function itself (not merely its body), as in the program above;
    # the i32 argument is an integer literal, the f32 argument a float.
    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    # f is referenced by the let body, so DCE must leave the program unchanged.
    assert alpha_equal(dead_code_elimination(orig), orig)
    # Here f is never referenced, so the whole binding is removed.
    assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
def test_call_alpha_equal():
    """Call nodes are equal only if callee, args, attrs and type args all match."""
    callee = relay.Var("v1")
    other = relay.Var("v2")
    # Attributes are compared by pointer only, so these two structurally
    # identical attribute nodes still count as different.
    attrs_a = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3, 4))
    attrs_b = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3, 4))
    ty_f32 = relay.TensorType((1, 2, 3), "float32")
    ty_i8 = relay.TensorType((), "int8")
    args = [relay.const(1), relay.const(2), other, relay.Tuple([])]

    # Spell the arguments out again so equality cannot come from sharing
    # the very same argument objects.
    call = relay.Call(
        callee,
        [relay.const(1), relay.const(2), other, relay.Tuple([])],
        attrs_a,
        [ty_f32])
    assert alpha_equal(call, relay.Call(callee, args, attrs_a, [ty_f32]))

    mismatches = [
        # different callee
        relay.Call(other, args, attrs_a, [ty_f32]),
        # one argument fewer
        relay.Call(callee, [relay.const(1), relay.const(2), other], attrs_a, [ty_f32]),
        # same arguments, reordered
        relay.Call(callee, [relay.const(2), relay.const(1), relay.Tuple([]), other],
                   attrs_a, [ty_f32]),
        # different argument values
        relay.Call(callee, [relay.const(1), relay.const(2), relay.const(3)], attrs_a, [ty_f32]),
        # extra trailing arguments
        relay.Call(callee,
                   [relay.const(1), relay.const(2), other, relay.Tuple([]),
                    relay.const(3), relay.const(4)],
                   attrs_a, [ty_f32]),
        # distinct attrs object
        relay.Call(callee, args, attrs_b, [ty_f32]),
        # no type args
        relay.Call(callee, args, attrs_a),
        # an extra type arg
        relay.Call(callee, args, attrs_a, [ty_f32, ty_i8]),
        # a different type arg
        relay.Call(callee, args, attrs_a, [ty_i8]),
    ]
    for candidate in mismatches:
        assert not alpha_equal(call, candidate)
def test_tuple_get_item():
    """DCE leaves a TupleGetItem intact and strips an unused let from its operand."""
    tup = relay.Var('t')
    expected = relay.TupleGetItem(tup, 0)
    assert alpha_equal(dead_code_elimination(expected), expected)

    # The binding of e.a is never read, so only the projection should survive.
    wrapped = relay.TupleGetItem(relay.Let(e.a, e.one, tup), 0)
    assert alpha_equal(dead_code_elimination(wrapped), expected)
def test_recursion():
    """DCE on a let-bound recursive program (legacy Param/convert API).

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n")
    data = relay.Var("data")
    n_param = relay.Param(n, e.int32)
    data_param = relay.Param(data, e.float32)
    funcbody = relay.If(
        equal(n, convert(0)),
        data,
        f(subtract(n, convert(1.0)), log(data)))
    # NOTE(review): `value` is constructed but never used; both lets below bind
    # `f` to `funcbody` rather than to this function. Left unchanged because the
    # legacy Let's explicit e.float32 annotation presumably describes funcbody,
    # not a function type -- confirm before switching the binding.
    value = relay.Function([n_param, data_param], e.float32, funcbody, [])

    program = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32)
    assert alpha_equal(dead_code_elimination(program), program)

    unused = relay.Let(f, funcbody, e.three, e.float32)
    assert alpha_equal(dead_code_elimination(unused), e.three)
def test_global_var_alpha_equal():
    """Global vars compare by identity: a var equals itself but not another."""
    first = relay.GlobalVar("v1")
    second = relay.GlobalVar("v2")
    # Pointer equality alone decides here (smoke test).
    assert alpha_equal(first, first)
    assert not alpha_equal(first, second)
def test_function_type():
    """Parser accepts `fn (...) -> int32` types as let-binding annotations."""
    def _expected(param_names):
        # let %_: fn (int32, ...) -> int32 = fn (%p: int32, ...) -> int32 { 0 }; ()
        func_ty = relay.FuncType([int32] * len(param_names), int32, [], [])
        params = [relay.Var(name, int32) for name in param_names]
        return relay.Let(
            relay.Var("_", func_ty),
            relay.Function(params, relay.const(0), int32, []),
            UNIT)

    cases = (
        (""" let %_: fn () -> int32 = fn () -> int32 { 0 }; () """, []),
        (""" let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () """, ["x"]),
        (""" let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () """,
         ["x", "y"]),
    )
    for src, names in cases:
        assert alpha_equal(relay.fromtext(src), _expected(names))
def test_tensor_type():
    """Parser handles Tensor[...] annotations of rank 0, 1 and 2."""
    cases = (
        ("let %_ : Tensor[(), float32] = (); ()", ()),
        ("let %_ : Tensor[(1,), float32] = (); ()", (1,)),
        ("let %_ : Tensor[(1, 1), float32] = (); ()", (1, 1)),
    )
    for src, shape in cases:
        parsed = relay.fromtext(src)
        binder = relay.Var("_", relay.TensorType(shape, "float32"))
        assert alpha_equal(parsed, relay.Let(binder, UNIT, UNIT))
def test_op_alpha_equal():
    """Operators compare by name only."""
    add_a = relay.op.get("add")
    add_b = relay.op.get("add")
    assert alpha_equal(add_a, add_b)

    take = relay.op.get("take")
    assert not alpha_equal(add_a, take)
def test_tuple():
    """Parser builds Tuple nodes of arity 0 through 3."""
    assert alpha_equal(relay.fromtext("()"), relay.Tuple([]))
    # (0,), (0, 1), (0, 1, 2): fields are consecutive integer constants.
    for src, arity in (("(0,)", 1), ("(0, 1)", 2), ("(0, 1, 2)", 3)):
        parsed = relay.fromtext(src)
        expected = relay.Tuple([relay.const(i) for i in range(arity)])
        assert alpha_equal(parsed, expected)
def test_tuple_get_item():
    """DCE keeps a TupleGetItem and strips an unused let from its tuple operand."""
    tuple_ty = relay.TupleType([e.float32, e.float32])
    tup = relay.Var('t', tuple_ty)
    unused = relay.Var('a')
    expected = relay.TupleGetItem(tup, 0)

    def _dce(expr):
        return transform.OptimizeOnExpr(expr, transform.DeadCodeElimination())

    def _close(expr):
        # Wrap in a function over its free vars so whole closed terms are compared.
        return Function(free_vars(expr), expr)

    assert alpha_equal(_close(_dce(expected)), _close(expected))

    wrapped = relay.TupleGetItem(relay.Let(unused, e.one, tup), 0)
    assert alpha_equal(_close(_dce(wrapped)), _close(expected))
def test_tuple_alpha_equal():
    """Tuple equality is field-wise and honors let-bound var mapping."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    def _build(head, nested_values=(4,), tail=()):
        # (head, 2, 3, (nested_values...), *tail)
        nested = relay.Tuple([relay.const(val) for val in nested_values])
        return relay.Tuple([head, relay.const(2), relay.const(3), nested] + list(tail))

    # The unit value (empty tuple) is a valid tuple.
    assert alpha_equal(relay.Tuple([]), relay.Tuple([]))

    tup = _build(v1)
    assert alpha_equal(tup, _build(v1))

    # Binding through a let exercises the var equivalence map (v1 <-> v2).
    assert alpha_equal(relay.Let(v1, tup, v1), relay.Let(v2, _build(v2), v2))

    mismatches = [
        _build(v1, tail=(v2,)),                            # extra trailing field
        relay.Tuple([v1, relay.const(2), relay.const(3)]),  # nested tuple missing
        _build(v1, nested_values=(5,)),                    # different nested value
        _build(v2),                                        # different first field
        _build(v1, nested_values=(4, 5)),                  # longer nested tuple
    ]
    for candidate in mismatches:
        assert not alpha_equal(tup, candidate)
def test_comments():
    """Line and block comments are skipped by the parser.

    The programs must keep real line breaks: a `//` comment runs to the end
    of its line, so `()` has to sit on a following line, or the parser sees
    an empty program.
    """
    line_comment_src = """
        // This is a line comment!
        ()
    """
    assert alpha_equal(relay.fromtext(line_comment_src), UNIT)

    block_comment_src = """
        /* This is a block comment!
            This is still a block comment!
        */
        ()
    """
    assert alpha_equal(relay.fromtext(block_comment_src), UNIT)
def test_param_alpha_equal():
    """Params compare only by their type annotations, not by their vars."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    def _f32_param(var):
        # A fresh (1, 2, 3) float32 type each time, so pointer equality cannot help.
        return relay.Param(var, relay.TensorType((1, 2, 3), "float32"))

    base = _f32_param(v1)
    assert alpha_equal(base, _f32_param(v2))

    other_tensor = relay.Param(v1, relay.TensorType((4, 5, 6), "int8"))
    assert not alpha_equal(base, other_tensor)

    tuple_wrapped = relay.Param(
        v1, relay.TupleType([relay.TensorType((1, 2, 3), "float32")]))
    assert not alpha_equal(base, tuple_wrapped)
def test_var_alpha_equal():
    """Vars compare by identity unless a let binder maps one onto the other."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")
    # Outside any binder only pointer equality applies.
    assert alpha_equal(v1, v1)
    assert not alpha_equal(v1, v2)

    # A let binder records the v1 <-> v2 correspondence.
    bind_v1 = relay.Let(v1, convert(1), v1)
    bind_v2 = relay.Let(v2, convert(1), v2)
    returns_unbound = relay.Let(v1, convert(1), v2)
    assert alpha_equal(bind_v1, bind_v2)
    assert not alpha_equal(bind_v1, returns_unbound)
def test_recursion():
    """DCE keeps a used function binding and removes an unused one.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    # NOTE(review): the body calls the separate var `f1`, not `f`, so the let
    # is not actually recursive as the program above suggests -- confirm intent.
    f1 = relay.Var("f1")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    else_branch = relay.Call(f1, [subtract(n, relay.const(1)), log(data)])
    funcbody = relay.If(equal(n, relay.const(0)), data, else_branch)
    value = relay.Function([n, data], funcbody, e.float32, [])
    program = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))

    dced = transform.OptimizeOnExpr(program, transform.DeadCodeElimination())
    typed = transform.OptimizeOnExpr(program, transform.InferType())
    assert graph_equal(dced, typed)

    unused = relay.Let(f, value, e.three)
    assert alpha_equal(
        transform.OptimizeOnExpr(unused, transform.DeadCodeElimination()), e.three)
def test_if_alpha_equal():
    """If nodes compare condition, true branch and false branch."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    def _if(cond, then_value=1, else_fields=(2, 3)):
        # If(cond, then_value, (else_fields...))
        return relay.If(cond,
                        relay.const(then_value),
                        relay.Tuple([relay.const(val) for val in else_fields]))

    sample = _if(v1)
    assert alpha_equal(sample, _if(v1))
    assert not alpha_equal(sample, _if(v2))                  # different condition
    assert not alpha_equal(sample, _if(v1, then_value=2))    # different true branch
    assert not alpha_equal(sample, _if(v1, else_fields=()))  # different false branch
def test_callback():
    """With fskip rejecting `add` calls, CSE merges only the duplicated relu."""
    def before():
        x = relay.var("x", shape=(1, 16))
        lhs = relay.add(relay.nn.relu(x), relay.const(1.0, "float32"))
        rhs = relay.add(relay.nn.relu(x), relay.const(1.0, "float32"))
        return relay.Function([x], relay.add(lhs, rhs))

    def expected():
        x = relay.var("x", shape=(1, 16))
        shared = relay.nn.relu(x)
        lhs = relay.add(shared, relay.const(1.0, "float32"))
        rhs = relay.add(shared, relay.const(1.0, "float32"))
        return relay.Function([x], relay.add(lhs, rhs))

    def fskip(expr):
        # Report True (skip) for every call to the `add` operator.
        return bool(isinstance(expr, relay.expr.Call) and expr.op.name == 'add')

    optimized = ir_pass.eliminate_common_subexpr(before(), fskip)
    assert ir_pass.alpha_equal(optimized, expected())
def test_double():
    """Partial evaluation reduces the prelude's `double(3)` nat to the nat 6."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    call = prelude.double(make_nat_expr(prelude, 3))
    result = dcpe(call, mod=mod)
    assert alpha_equal(result, make_nat_expr(prelude, 6))
def test_func():
    """Parser handles functions with 0, 1 and 2 params and with type annotations."""
    cases = (
        # no parameters
        ("fn () { 0 }", relay.Function([], relay.const(0), None, [])),
        # one parameter
        ("fn (%x) { %x }", relay.Function([X], X, None, [])),
        # two parameters
        ("fn (%x, %y) { %x + %y }", relay.Function([X, Y], relay.add(X, Y), None, [])),
        # annotated parameter and return type
        ("fn (%x: int32) -> int32 { %x }", relay.Function([X_ANNO], X_ANNO, int32, [])),
    )
    for src, expected in cases:
        assert alpha_equal(relay.fromtext(src), expected)
def alpha_equal(x, y):
    """Alpha-equality check that also verifies the structural hashes agree.

    Returns the falsy result of ir_pass.alpha_equal unchanged when the terms
    differ; otherwise returns whether both terms hash to the same value, so
    every equality assertion also checks hash consistency.
    """
    equal = ir_pass.alpha_equal(x, y)
    if not equal:
        return equal
    return ir_pass.structural_hash(x) == ir_pass.structural_hash(y)
def test_seq():
    """`a; b` parses as a let bound to `_`; a braced block parses as its value."""
    sequenced = relay.fromtext("(); ()")
    assert alpha_equal(sequenced, relay.Let(_, UNIT, UNIT))

    scoped = relay.fromtext("let %_ = { 1 }; ()")
    assert alpha_equal(scoped, relay.Let(X, relay.const(1), UNIT))
def test_tuple():
    """Partial evaluation projects field 1 of a literal tuple, leaving identity."""
    t = TypeVar("t")
    x = Var("x", t)
    projection = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
    func = Function([x], projection, None, [t])

    identity = relay.Function([x], x, None, [t])
    identity = transform.OptimizeOnExpr(identity, transform.InferType())
    assert alpha_equal(dcpe(func), identity)
def test_let_alpha_equal():
    """Let equality maps binders to each other; value, body and binder types must match."""
    tt1 = relay.TensorType((), "float32")
    tt2 = relay.TensorType((), "int8")
    v1 = relay.Var("v1")
    v1_wtype = relay.Var("v1", tt1)
    v2 = relay.Var("v2")
    v3 = relay.Var("v3")

    let = relay.Let(v1, convert(2), v1)
    mapped = relay.Let(v2, convert(2), v2)
    assert alpha_equal(let, mapped)

    # The body refers to an unrelated var instead of the binder.
    mismatched_var = relay.Let(v2, convert(2), v3)
    assert not alpha_equal(let, mismatched_var)

    different_value = relay.Let(v2, convert(3), v2)
    assert not alpha_equal(let, different_value)

    # Keep the bound value identical to `let` so only the body differs
    # (previously the value also changed, so this case did not isolate the body).
    different_body = relay.Let(v2, convert(2), convert(12))
    assert not alpha_equal(let, different_body)

    # Specified binder types must match.
    let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype)
    same_type = relay.Let(v1_wtype, convert(2), v1_wtype)
    assert alpha_equal(let_with_type, same_type)
    assert not alpha_equal(let, let_with_type)

    # Same name hint but a different annotated type is still unequal.
    v1_other_type = relay.Var("v1", tt2)
    different_type = relay.Let(v1_other_type, convert(2), v1_other_type)
    assert not alpha_equal(let_with_type, different_type)
def test_var_alpha_equal():
    """Var equality: identity outside binders, let-mapped inside, annotations compared."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")
    # Outside any binder only pointer equality applies.
    assert alpha_equal(v1, v1)
    assert not alpha_equal(v1, v2)

    def _let_identity(var):
        # let var = 1; var
        return relay.Let(var, relay.const(1), var)

    # A let binder lets v1 and v2 correspond.
    l1 = _let_identity(v1)
    assert alpha_equal(l1, _let_identity(v2))
    assert not alpha_equal(l1, relay.Let(v1, relay.const(1), v2))

    # Annotated binders: two distinct-but-equal int32 types and one int64.
    l4 = _let_identity(relay.Var("v3", relay.TensorType([], "int32")))
    l5 = _let_identity(relay.Var("v4", relay.TensorType([], "int32")))
    l6 = _let_identity(relay.Var("v5", relay.TensorType([], "int64")))
    assert alpha_equal(l4, l5)      # same annotations
    assert not alpha_equal(l4, l6)  # different annotations
    assert not alpha_equal(l1, l4)  # only one side annotated
def test_empty_ad():
    """Partially evaluated gradient of identity is (d, (ones_like(d),))."""
    d = Var("d", TensorType((10, 10), "float32"))
    grad_fn = dcpe(gradient(Function([d], d)))
    expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
    assert alpha_equal(grad_fn, expected)
def test_ifelse():
    """Parser turns `if (c) { a } else { b }` into an If node."""
    parsed = relay.fromtext(""" if (True) { 0 } else { 1 } """)
    expected = relay.If(relay.const(True), relay.const(0), relay.const(1))
    assert alpha_equal(parsed, expected)
def test_graph_equal():
    """alpha_equal distinguishes a shared subexpression from duplicated copies."""
    x = relay.var("x")

    def _shared_double():
        # One add(x, x) node feeds both operands of the outer add.
        inner = relay.add(x, x)
        return relay.add(inner, inner)

    shared_a = _shared_double()
    shared_b = _shared_double()
    # Same arithmetic, but each add(x, x) is its own node (no sharing).
    unshared = relay.add(relay.add(x, x), relay.add(x, x))

    assert alpha_equal(shared_a, shared_b)
    # The dataflow graphs differ, so Relay treats these as different programs.
    assert not alpha_equal(shared_a, unshared)
def test_let():
    """Parser turns `let %x = 1; ()` into a Let node."""
    parsed = relay.fromtext("let %x = 1; ()")
    expected = relay.Let(X, relay.const(1), UNIT)
    assert alpha_equal(parsed, expected)
def test_incomplete_type():
    """Parser accepts `_` (an incomplete type) as a let-binding annotation."""
    parsed = relay.fromtext("let %_ : _ = (); ()")
    expected = relay.Let(_, UNIT, UNIT)
    assert alpha_equal(parsed, expected)
def test_ref():
    """Partial evaluation eliminates a local reference cell, leaving d * d."""
    d = relay.Var("d")
    r = relay.Var("r")
    x = relay.Var("x")
    # r := !r * !r; then read r back.
    squared = relay.RefRead(r) * relay.RefRead(r)
    inner = relay.Let(x, relay.RefWrite(r, squared), relay.RefRead(r))
    body = relay.Let(r, relay.RefCreate(d), inner)
    square = relay.Function([d], body)
    assert alpha_equal(dcpe(square), relay.Function([d], d * d))
def test_ad():
    """Partially evaluated gradient of d * d matches a hand-written expansion."""
    # TODO(MK): fix me
    d = relay.Var("d", relay.TensorType((10, 10), "float32"))
    grad_fn = dcpe(gradient(relay.Function([d], d * d)))

    forward = d * d
    ones = relay.op.ones_like(forward)

    def _term():
        # One collapse_sum_like term; the expansion below contains two of them.
        return relay.op.collapse_sum_like(ones * d, d)

    grad = relay.op.zeros_like(d) + _term() + _term()
    expected = relay.Function([d], relay.Tuple([forward, relay.Tuple([grad])]))
    assert alpha_equal(grad_fn, expected)
def test_recursion():
    """DCE keeps a used function binding and drops an unused one.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        # n is declared i32, so decrement it with an integer (not 1.0) constant.
        relay.Call(f, [subtract(n, relay.const(1)), log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    # The i32 argument is an integer literal; the f32 argument stays a float.
    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    # f is referenced by the let body, so DCE must leave the program unchanged.
    assert alpha_equal(dead_code_elimination(orig), orig)
    # f is never referenced here, so the binding is removed entirely.
    assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
def test_head_cons():
    """Partial evaluation reduces `hd(cons(x, nil))` to `x`."""
    mod = relay.Module()
    p = Prelude(mod)

    def hd_impl():
        # Polymorphic head: match cons(y, z) and return y (no nil clause).
        a = relay.TypeVar("a")
        lst = relay.Var("x", p.l(a))
        head = relay.Var("y")
        tail = relay.Var("z")
        pattern = relay.PatternConstructor(
            p.cons, [relay.PatternVar(head), relay.PatternVar(tail)])
        clause = relay.Clause(pattern, head)
        return relay.Function([lst], relay.Match(lst, [clause]), a, [a])

    t = relay.TypeVar("t")
    x = relay.Var("x", t)
    hd = relay.Var("hd")
    body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
    func = infer_type(relay.Function([x], body, None, [t]), mod=mod)
    assert alpha_equal(dcpe(func), relay.Function([x], x, t, [t]))
def test_simple():
    """CSE merges the duplicated relu and add-1.0 subexpressions into one chain."""
    def before():
        x = relay.var("x", shape=(1, 16))
        lhs = relay.add(relay.nn.relu(x), relay.const(1.0, "float32"))
        rhs = relay.add(relay.nn.relu(x), relay.const(1.0, "float32"))
        return relay.Function([x], relay.add(lhs, rhs))

    def expected():
        x = relay.var("x", shape=(1, 16))
        shifted = relay.add(relay.nn.relu(x), relay.const(1.0, "float32"))
        return relay.Function([x], relay.add(shifted, shifted))

    optimized = ir_pass.eliminate_common_subexpr(before())
    assert ir_pass.alpha_equal(optimized, expected())
def test_chain_unused_let():
    """DCE removes a chain of nested lets, leaving only the innermost body."""
    inner = relay.Let(e.c, e.d, e.e)
    program = relay.Let(e.a, e.b, inner)
    assert alpha_equal(dead_code_elimination(program), e.e)
def test_used_let():
    """DCE keeps a let whose binder is referenced in the body."""
    def _program():
        # let c = 1; c + c
        return relay.Let(e.c, e.one, e.c + e.c)

    assert alpha_equal(dead_code_elimination(_program()), _program())
def test_tuple():
    """Partial evaluation projects field 1 of a literal tuple, leaving x."""
    t = relay.TypeVar("t")
    x = relay.Var("x", t)
    pair = relay.Tuple([relay.const(4.0), x])
    func = relay.Function([x], relay.TupleGetItem(pair, 1), None, [t])
    assert alpha_equal(dcpe(func), relay.Function([x], x, None, [t]))
def alpha_equal(x, y):
    """Alpha-equality check that also verifies the structural hashes agree.

    When the terms are not alpha-equal, the falsy result of
    ir_pass.alpha_equal is returned as-is; otherwise the result is whether
    both terms have the same structural hash.
    """
    same = ir_pass.alpha_equal(x, y)
    if not same:
        return same
    lhs_hash = ir_pass.structural_hash(x)
    rhs_hash = ir_pass.structural_hash(y)
    return lhs_hash == rhs_hash
def test_let():
    """DCE drops a let whose body is just e.z."""
    unused_binding = relay.Let(e.x, e.y, e.z)
    assert alpha_equal(dead_code_elimination(unused_binding), e.z)
def test_const_inline():
    """Partial evaluation folds `double(4.0)` (d + d) into the constant 8.0."""
    # TODO(MK): fix me
    d = relay.Var("d")
    double = relay.Function([d], d + d)
    orig = double(relay.const(4.0))
    # Evaluate the call built above instead of constructing an identical one.
    assert alpha_equal(dcpe(orig), relay.const(8.0))
def test_tuple_get_item():
    """A bare TupleGetItem survives DCE; an unused let in its operand does not."""
    tup = relay.Var('t')
    projection = relay.TupleGetItem(tup, 0)
    assert alpha_equal(dead_code_elimination(projection), projection)

    # e.a is bound but never read inside the tuple operand.
    with_dead_let = relay.TupleGetItem(relay.Let(e.a, e.one, tup), 0)
    assert alpha_equal(dead_code_elimination(with_dead_let), projection)
def test_inline():
    """DCE reduces `let a = b; let c = d; c` to `d`."""
    inner = relay.Let(e.c, e.d, e.c)
    program = relay.Let(e.a, e.b, inner)
    assert alpha_equal(dead_code_elimination(program), e.d)
def parses_as(code, expr):
    # type: (str, relay.Expr) -> bool
    """Return whether `code`, preceded by the SEMVER header line, parses to `expr`."""
    source = "\n".join([SEMVER, code])
    return alpha_equal(relay.fromtext(source), expr)
def test_op_assoc():
    """Binary operator precedence: `*` over `+` over `<` over `==`."""
    def _parse(src):
        return relay.fromtext(SEMVER + src)

    assert alpha_equal(_parse("1 * 1 + 1 < 1 == 1"),
                       _parse("(((1 * 1) + 1) < 1) == 1"))
    assert alpha_equal(_parse("1 == 1 < 1 + 1 * 1"),
                       _parse("1 == (1 < (1 + (1 * 1)))"))
def test_parens():
    """Redundant parentheses change nothing; precedence-overriding ones do."""
    def _parse(src):
        return relay.fromtext(SEMVER + src)

    assert alpha_equal(_parse("1 * 1 + 1"), _parse("(1 * 1) + 1"))
    assert not alpha_equal(_parse("1 * 1 + 1"), _parse("1 * (1 + 1)"))
def test_op_let():
    """DCE removes an unused let nested inside an operator argument."""
    lhs = relay.Let(e.a, e.one, e.three)
    optimized = dead_code_elimination(add(lhs, e.two))
    assert alpha_equal(optimized, add(e.three, e.two))