def test_match_vars():
    """Check variable analysis on two ``relay.Match`` expressions.

    Uses the list constructors exposed as attributes of the Prelude
    (``p.nil`` / ``p.cons``).

    Expectations asserted below:
      * first match:  bound ``[x, y]``, free ``[z]``,    all ``[z, x, y]``
      * second match: bound ``[x]``,    free ``[y, z]``, all ``[x, y, z]``

    NOTE(review): a second function with this exact name is defined later in
    the file; at import time the later definition replaces this one, so this
    body is presumably never collected by the test runner -- confirm whether
    it should be renamed or removed.
    """
    ir_module = tvm.IRModule()
    prelude = relay.prelude.Prelude(ir_module)

    var_x = relay.Var('x')
    var_y = relay.Var('y')
    var_z = relay.Var('z')

    # First match: a nil clause that returns z, and a cons clause that binds
    # both of its fields (x, y) and uses them in its body.
    first_scrutinee = prelude.nil()
    nil_clause = relay.Clause(relay.PatternConstructor(prelude.nil), var_z)
    cons_pattern = relay.PatternConstructor(
        prelude.cons,
        [relay.PatternVar(var_x), relay.PatternVar(var_y)],
    )
    cons_clause = relay.Clause(cons_pattern, prelude.cons(var_x, var_y))
    match1 = relay.Match(first_scrutinee, [nil_clause, cons_clause])

    # Second match: the cons pattern binds only x (the head is a wildcard)
    # while the clause body is y; the fallback wildcard clause returns z.
    second_scrutinee = prelude.nil()
    tail_only_pattern = relay.PatternConstructor(
        prelude.cons,
        [relay.PatternWildcard(), relay.PatternVar(var_x)],
    )
    tail_only_clause = relay.Clause(tail_only_pattern, var_y)
    fallback_clause = relay.Clause(relay.PatternWildcard(), var_z)
    match2 = relay.Match(second_scrutinee, [tail_only_clause, fallback_clause])

    assert_vars_match(bound_vars(match1), [var_x, var_y])
    assert_vars_match(free_vars(match1), [var_z])
    assert_vars_match(all_vars(match1), [var_z, var_x, var_y])

    assert_vars_match(bound_vars(match2), [var_x])
    assert_vars_match(free_vars(match2), [var_y, var_z])
    assert_vars_match(all_vars(match2), [var_x, var_y, var_z])
def test_all_vars():
    """Check that ``all_vars`` yields ``[x, y, z]`` for four expression shapes.

    The shapes are: a two-parameter function returning ``z``; a one-parameter
    function whose body is a ``Let`` binding ``y`` to an empty tuple with body
    ``z``; a one-parameter function returning the tuple ``(y, z)``; and a bare
    tuple ``(x, y, z)``.
    """
    var_x = relay.Var("x")
    var_y = relay.Var("y")
    var_z = relay.Var("z")

    expressions = [
        # x and y are parameters, z appears only in the body.
        relay.Function([var_x, var_y], var_z),
        # x is a parameter, y is let-bound, z appears only in the let body.
        relay.Function([var_x], relay.Let(var_y, relay.Tuple([]), var_z)),
        # x is a parameter, y and z appear only in the body tuple.
        relay.Function([var_x], relay.Tuple([var_y, var_z])),
        # No binder at all: the three variables are just tuple fields.
        relay.Tuple([var_x, var_y, var_z]),
    ]

    for expr in expressions:
        # A fresh expected list is built for every call.
        assert_vars_match(all_vars(expr), [var_x, var_y, var_z])
def test_match_vars():
    """Check variable analysis on two ``relay.Match`` expressions.

    The ``cons`` / ``nil`` constructors are obtained by unpacking
    ``p.mod.get_type("List")`` on a freshly built Prelude.

    Expectations asserted below:
      * first match:  bound ``[x, y]``, free ``[z]``,    all ``[z, x, y]``
      * second match: bound ``[x]``,    free ``[y, z]``, all ``[x, y, z]``

    NOTE(review): an earlier function in this file has the same name; this
    later definition is the one that remains bound after import.
    """
    ir_module = tvm.IRModule()
    prelude = relay.prelude.Prelude(ir_module)
    # The first element of the triple is not used by this test.
    _list_type, cons, nil = prelude.mod.get_type("List")

    var_x = relay.Var("x")
    var_y = relay.Var("y")
    var_z = relay.Var("z")

    # First match: a nil clause that returns z, and a cons clause that binds
    # both of its fields (x, y) and uses them in its body.
    first_scrutinee = nil()
    nil_clause = relay.Clause(relay.PatternConstructor(nil), var_z)
    cons_pattern = relay.PatternConstructor(
        cons,
        [relay.PatternVar(var_x), relay.PatternVar(var_y)],
    )
    cons_clause = relay.Clause(cons_pattern, cons(var_x, var_y))
    match1 = relay.Match(first_scrutinee, [nil_clause, cons_clause])

    # Second match: the cons pattern binds only x (the head is a wildcard)
    # while the clause body is y; the fallback wildcard clause returns z.
    second_scrutinee = nil()
    tail_only_pattern = relay.PatternConstructor(
        cons,
        [relay.PatternWildcard(), relay.PatternVar(var_x)],
    )
    tail_only_clause = relay.Clause(tail_only_pattern, var_y)
    fallback_clause = relay.Clause(relay.PatternWildcard(), var_z)
    match2 = relay.Match(second_scrutinee, [tail_only_clause, fallback_clause])

    assert_vars_match(bound_vars(match1), [var_x, var_y])
    assert_vars_match(free_vars(match1), [var_z])
    assert_vars_match(all_vars(match1), [var_z, var_x, var_y])

    assert_vars_match(bound_vars(match2), [var_x])
    assert_vars_match(free_vars(match2), [var_y, var_z])
    assert_vars_match(all_vars(match2), [var_x, var_y, var_z])