def test_match_vars():
    """Check bound/free/all variable analysis on two ``relay.Match`` expressions.

    NOTE(review): a second ``test_match_vars`` is defined later in this file
    and rebinds this name at module level, so this version is shadowed and
    presumably never collected/run by the test runner. Consider deleting it
    (or renaming it if it is meant to run).

    NOTE(review): this version reads ``p.nil`` / ``p.cons`` as attributes of
    the Prelude, while the later version obtains them via
    ``p.mod.get_type("List")`` -- verify these attributes still exist on
    ``relay.prelude.Prelude`` before reviving this test.
    """
    mod = tvm.IRModule()
    p = relay.prelude.Prelude(mod)
    x = relay.Var('x')
    y = relay.Var('y')
    z = relay.Var('z')
    # match1: the nil clause binds nothing and returns the outer var z; the
    # cons clause binds x and y via pattern vars and uses them in its body.
    match1 = relay.Match(p.nil(), [
        relay.Clause(relay.PatternConstructor(p.nil), z),
        relay.Clause(
            relay.PatternConstructor(
                p.cons, [relay.PatternVar(x), relay.PatternVar(y)]),
            p.cons(x, y))
    ])
    # match2: only x is bound (the head is a wildcard); the clause bodies
    # refer to y and z, which no pattern binds.
    match2 = relay.Match(p.nil(), [
        relay.Clause(
            relay.PatternConstructor(
                p.cons, [relay.PatternWildcard(), relay.PatternVar(x)]),
            y),
        relay.Clause(relay.PatternWildcard(), z)
    ])
    assert_vars_match(bound_vars(match1), [x, y])
    assert_vars_match(free_vars(match1), [z])
    assert_vars_match(all_vars(match1), [z, x, y])
    assert_vars_match(bound_vars(match2), [x])
    assert_vars_match(free_vars(match2), [y, z])
    assert_vars_match(all_vars(match2), [x, y, z])
def test_bound_vars():
    """Function parameters and let-bound vars count as bound; bare uses do not.

    * ``f1`` binds x, y, z as parameters and ``a`` through a ``Let``.
    * A tuple that merely references variables binds none of them.
    * ``f2`` binds only its parameters x and y, even though its body
      references z and a as well.
    """
    x, y, z, a = (relay.Var(name) for name in ("x", "y", "z", "a"))

    let_body = relay.Let(a, x, relay.Tuple([]))
    f1 = relay.Function([x, y, z], let_body)
    assert_vars_match(bound_vars(f1), [x, y, z, a])

    tup = relay.Tuple([x, y, z, a])
    assert len(bound_vars(tup)) == 0

    f2_body = relay.Tuple([x, y, z, a])
    f2 = relay.Function([x, y], f2_body)
    assert_vars_match(bound_vars(f2), [x, y])
def test_match_vars():
    """Variables introduced by match patterns are bound; others used in clause bodies are free.

    Two matches over the prelude ``List`` ADT are checked with the
    bound/free/all variable analyses.
    """
    mod = tvm.IRModule()
    prelude = relay.prelude.Prelude(mod)
    _, cons, nil = prelude.mod.get_type("List")
    x = relay.Var("x")
    y = relay.Var("y")
    z = relay.Var("z")

    # match1:  nil => z  |  cons(x, y) => cons(x, y)
    scrutinee1 = nil()
    nil_clause = relay.Clause(relay.PatternConstructor(nil), z)
    cons_xy = relay.PatternConstructor(
        cons, [relay.PatternVar(x), relay.PatternVar(y)])
    cons_clause = relay.Clause(cons_xy, cons(x, y))
    match1 = relay.Match(scrutinee1, [nil_clause, cons_clause])

    # match2:  cons(_, x) => y  |  _ => z
    scrutinee2 = nil()
    cons_wild_x = relay.PatternConstructor(
        cons, [relay.PatternWildcard(), relay.PatternVar(x)])
    tail_clause = relay.Clause(cons_wild_x, y)
    fallback_clause = relay.Clause(relay.PatternWildcard(), z)
    match2 = relay.Match(scrutinee2, [tail_clause, fallback_clause])

    expectations = [
        (bound_vars, match1, [x, y]),
        (free_vars, match1, [z]),
        (all_vars, match1, [z, x, y]),
        (bound_vars, match2, [x]),
        (free_vars, match2, [y, z]),
        (all_vars, match2, [x, y, z]),
    ]
    for analysis, expr, expected in expectations:
        assert_vars_match(analysis(expr), expected)