def test_match_vars():
    """Check bound_vars / free_vars / all_vars on Match over prelude lists.

    NOTE(review): a later ``def test_match_vars`` in this file rebinds this
    name, so this definition is unreachable under it and is not collected.
    """
    mod = relay.Module()
    prelude = relay.prelude.Prelude(mod)
    x, y, z = relay.Var('x'), relay.Var('y'), relay.Var('z')

    # match1: `nil -> z` and `cons(x, y) -> cons(x, y)`.
    # Expected: x and y are bound by the cons pattern; z is free.
    cons_xy = relay.PatternConstructor(
        prelude.cons, [relay.PatternVar(x), relay.PatternVar(y)])
    match1 = relay.Match(prelude.nil(), [
        relay.Clause(relay.PatternConstructor(prelude.nil), z),
        relay.Clause(cons_xy, prelude.cons(x, y)),
    ])

    # match2: `cons(_, x) -> y` and `_ -> z`.
    # Expected: only x is bound; y and z are free.
    cons_wild_x = relay.PatternConstructor(
        prelude.cons, [relay.PatternWildcard(), relay.PatternVar(x)])
    match2 = relay.Match(prelude.nil(), [
        relay.Clause(cons_wild_x, y),
        relay.Clause(relay.PatternWildcard(), z),
    ])

    cases = [
        (bound_vars, match1, [x, y]),
        (free_vars, match1, [z]),
        (all_vars, match1, [z, x, y]),
        (bound_vars, match2, [x]),
        (free_vars, match2, [y, z]),
        (all_vars, match2, [x, y, z]),
    ]
    for analysis, expr, expected in cases:
        assert_vars_match(analysis(expr), expected)
def test_match_vars():
    """Verify variable analyses for two Match expressions on prelude lists."""
    module = relay.Module()
    lists = relay.prelude.Prelude(module)
    var_x = relay.Var('x')
    var_y = relay.Var('y')
    var_z = relay.Var('z')

    # First match: nil yields the free var z; cons binds x and y and
    # rebuilds the list from them.
    nil_then_z = relay.Clause(relay.PatternConstructor(lists.nil), var_z)
    cons_then_rebuild = relay.Clause(
        relay.PatternConstructor(
            lists.cons, [relay.PatternVar(var_x), relay.PatternVar(var_y)]),
        lists.cons(var_x, var_y))
    first = relay.Match(lists.nil(), [nil_then_z, cons_then_rebuild])

    # Second match: cons binds only its tail (x) and returns the free var y;
    # the wildcard arm returns the free var z.
    tail_then_y = relay.Clause(
        relay.PatternConstructor(
            lists.cons, [relay.PatternWildcard(), relay.PatternVar(var_x)]),
        var_y)
    anything_then_z = relay.Clause(relay.PatternWildcard(), var_z)
    second = relay.Match(lists.nil(), [tail_then_y, anything_then_z])

    assert_vars_match(bound_vars(first), [var_x, var_y])
    assert_vars_match(free_vars(first), [var_z])
    assert_vars_match(all_vars(first), [var_z, var_x, var_y])

    assert_vars_match(bound_vars(second), [var_x])
    assert_vars_match(free_vars(second), [var_y, var_z])
    assert_vars_match(all_vars(second), [var_x, var_y, var_z])
def test_tuple():
    """free_vars reports a var used in Tuple / TupleGetItem exactly once.

    NOTE(review): a later ``def test_tuple`` in this file rebinds this name,
    so this definition is unreachable under it and is not collected.
    """
    t = relay.Var('t')
    # The Tuple mentions t twice, yet a single free var is expected.
    for expr in (relay.Tuple([t, t]), relay.TupleGetItem(t, 123)):
        fv = free_vars(expr)
        assert len(fv) == 1
        assert fv[0] == t
def test_tuple():
    """Tuple and TupleGetItem over one var each have that var as sole free var."""
    t = relay.Var('t')

    def check_only_t_is_free(expr):
        found = free_vars(expr)
        assert len(found) == 1
        assert found[0] == t

    check_only_t_is_free(relay.Tuple([t, t]))
    check_only_t_is_free(relay.TupleGetItem(t, 123))
def test_tuple_get_item():
    """DeadCodeElimination on TupleGetItem.

    A bare TupleGetItem is expected to survive unchanged. When the tuple
    operand is wrapped in a Let whose bound var is unused, the result is
    expected to equal the bare TupleGetItem.
    """
    def run_dce(expr):
        return transform.OptimizeOnExpr(expr, transform.DeadCodeElimination())

    pair_ty = relay.TupleType([e.float32, e.float32])
    t = relay.Var('t', pair_ty)
    a = relay.Var('a')

    get_first = relay.TupleGetItem(t, 0)
    expected = Function(free_vars(get_first), get_first)

    kept = run_dce(get_first)
    assert alpha_equal(Function(free_vars(kept), kept), expected)

    wrapped = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
    stripped = run_dce(wrapped)
    assert alpha_equal(Function(free_vars(stripped), stripped), expected)
def test_free_vars():
    """free_vars on a bare Var, a Let binding it, and a Function over it.

    NOTE(review): this uses the four-argument ``relay.Let(var, value, body,
    ty)`` form and a ``relay.Param``-based ``relay.Function`` call, which
    differ from the later ``def test_free_vars`` in this file. That later
    definition rebinds the name, so this version is not collected.
    """
    x = relay.Var("x")
    fvx = free_vars(x)
    assert len(fvx) == 1
    assert fvx[0] == x

    v = relay.Constant(tvm.nd.array(10))
    ty = relay.TensorType([], "int32")
    let = relay.Let(x, v, x, ty)
    # The Let binds x, so no variable is expected to be free in it.
    fv_let = free_vars(let)
    assert len(fv_let) == 0

    f = relay.Function([relay.Param(x, ty)], ty, x)
    # x is a parameter of f, so f is expected to have no free variables.
    assert len(free_vars(f)) == 0
def test_free_vars():
    """free_vars on a typed Var, a Let binding it, and a Function over it."""
    ty = relay.TensorType([], "int32")
    x = relay.Var("x", ty)
    fvx = free_vars(x)
    assert len(fvx) == 1
    assert fvx[0] == x

    v = relay.Constant(tvm.nd.array(10))
    let = relay.Let(x, v, x)
    # The Let binds x, so no variable is expected to be free in it.
    # (Previously the result was stored and then recomputed; assert on it.)
    fv_let = free_vars(let)
    assert len(fv_let) == 0

    f = relay.Function([x], x, ty)
    # x is a parameter of f, so f is expected to have no free variables.
    assert len(free_vars(f)) == 0
def test_free_type_vars():
    """A Let whose bound var's type contains a type var: check both analyses.

    NOTE(review): this uses ``relay.TypeParam``, whereas the later
    ``def test_free_type_vars`` in this file uses ``relay.TypeVar``. That
    later definition rebinds the name, so this version is not collected.
    """
    tp = relay.TypeParam("")
    tuple_ty = relay.TupleType([tp, relay.TensorType([], "int32")])
    x = relay.Var("x", tuple_ty)
    y = relay.Var("y")
    let = relay.Let(x, y, x)

    def assert_single(found, expected):
        assert len(found) == 1
        assert found[0] == expected

    # y is the only unbound value var; tp is the only free type var.
    assert_single(free_vars(let), y)
    assert_single(free_type_vars(let), tp)
def test_free_type_vars():
    """free_vars and free_type_vars on `let x: (tp, int32) = y in x`."""
    type_var = relay.TypeVar("")
    bound = relay.Var(
        "x", relay.TupleType([type_var, relay.TensorType([], "int32")]))
    unbound = relay.Var("y")
    let_expr = relay.Let(bound, unbound, bound)

    value_vars = free_vars(let_expr)
    assert len(value_vars) == 1
    assert value_vars[0] == unbound

    type_vars = free_type_vars(let_expr)
    assert len(type_vars) == 1
    assert type_vars[0] == type_var
def test_chain_unused_let():
    """DCE is expected to reduce two nested Lets around `e.e` to `e.e`."""
    inner = relay.Let(e.c, e.d, e.e)
    outer = relay.Let(e.a, e.b, inner)
    pruned = transform.OptimizeOnExpr(outer, transform.DeadCodeElimination())
    assert alpha_equal(Function(free_vars(pruned), pruned),
                       Function([e.e], e.e))
def test_inline():
    """DCE on `let a = b in let c = d in c` is expected to leave just `e.d`."""
    dce_pass = transform.DeadCodeElimination()
    expr = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
    result = transform.OptimizeOnExpr(expr, dce_pass)
    expected = Function([e.d], e.d)
    assert alpha_equal(Function(free_vars(result), result), expected)
def test_let():
    """DCE on `let x = y in z` is expected to leave just `e.z`."""
    unused_binding = relay.Let(e.x, e.y, e.z)
    simplified = transform.OptimizeOnExpr(
        unused_binding, transform.DeadCodeElimination())
    actual_fn = Function(free_vars(simplified), simplified)
    assert alpha_equal(actual_fn, Function([e.z], e.z))