def test_adt():
    """Check `well_formed` on two matches built from Prelude's `s`/`z` constructors.

    The match holding only the catch-all clause is expected to be accepted; the
    match whose two clauses both bind the same `relay.Var` object is expected
    to be rejected.
    """
    module = relay.Module()
    prel = Prelude(module)
    shared_var = relay.Var("x")

    # Both clauses below bind the very same Var object in their patterns.
    succ_pattern = relay.PatternConstructor(prel.s, [relay.PatternVar(shared_var)])
    succ_clause = relay.Clause(succ_pattern, shared_var)
    catch_all_clause = relay.Clause(relay.PatternVar(shared_var), shared_var)

    single_clause_match = relay.Match(prel.z(), [catch_all_clause])
    double_binding_match = relay.Match(prel.z(), [succ_clause, catch_all_clause])

    assert well_formed(single_clause_match)
    assert not well_formed(double_binding_match)
def test_nested_matches():
    """Register a recursive `flatten` built from nested matches and evaluate it.

    Two three-element lists (values 1..3 and 4..6, built with `make_nat_expr`)
    are wrapped in an outer list, flattened, and the six results are checked
    to count 1 through 6 in order.
    """
    elem_ty = relay.TypeVar("a")
    # TODO(@jroesch): inference should be able to handle this one
    outer = relay.Var("x", type_annotation=rlist(rlist(elem_ty)))
    head_list = relay.Var("y")
    remaining = relay.Var("w")
    head = relay.Var("h")
    tail = relay.Var("t")
    flatten = relay.GlobalVar("flatten")

    # flatten could be written using a fold, but this way has nested matches
    inner_nil_clause = relay.Clause(relay.PatternConstructor(nil), flatten(remaining))
    inner_cons_clause = relay.Clause(
        relay.PatternConstructor(cons, [relay.PatternVar(head), relay.PatternVar(tail)]),
        cons(head, flatten(cons(tail, remaining))),
    )
    inner_match = relay.Match(head_list, [inner_nil_clause, inner_cons_clause])

    outer_nil_clause = relay.Clause(relay.PatternConstructor(nil), nil())
    outer_cons_clause = relay.Clause(
        relay.PatternConstructor(cons, [relay.PatternVar(head_list), relay.PatternVar(remaining)]),
        inner_match,
    )
    outer_match = relay.Match(outer, [outer_nil_clause, outer_cons_clause])
    prelude.mod[flatten] = relay.Function([outer], outer_match, rlist(elem_ty), [elem_ty])

    def nat_list(numbers):
        # Build the elements left to right, then cons them up from the tail.
        elements = [make_nat_expr(prelude, number) for number in numbers]
        built = nil()
        for element in reversed(elements):
            built = cons(element, built)
        return built

    first_list = nat_list([1, 2, 3])
    second_list = nat_list([4, 5, 6])
    final_list = cons(first_list, cons(second_list, nil()))

    flat = to_list(intrp.evaluate(flatten(final_list)))

    assert len(flat) == 6
    for position, item in enumerate(flat, start=1):
        assert count(item) == position
def test_tuple_match():
    """Two separately built, identically shaped tuple matches compare equal.

    Asserts both `analysis.alpha_equal` and equality of
    `analysis.structural_hash` on the two expressions.
    """

    def build_match():
        # Fresh Vars on every call, so the two matches share no Var objects.
        lhs = relay.Var("a")
        rhs = relay.Var("b")
        pattern = relay.PatternTuple([relay.PatternVar(lhs), relay.PatternVar(rhs)])
        clause = relay.Clause(pattern, lhs + rhs)
        scrutinee = relay.Tuple([relay.const(1), relay.const(1)])
        return relay.Match(scrutinee, [clause])

    x = build_match()
    y = build_match()

    assert analysis.alpha_equal(x, y)
    assert analysis.structural_hash(x) == analysis.structural_hash(y)
def test_adt():
    """Check `well_formed` on two matches over the Prelude `Option` constructors.

    A match with only the catch-all clause is expected to pass; a match whose
    `some` clause and catch-all clause bind the same `relay.Var` object is
    expected to fail.
    """
    module = tvm.IRModule()
    prel = Prelude(module)
    _, none_ctor, some_ctor = prel.mod.get_type("Option")

    bound = relay.Var("x")
    # The same Var object appears in the pattern of both clauses.
    some_clause = relay.Clause(
        relay.PatternConstructor(some_ctor, [relay.PatternVar(bound)]),
        bound,
    )
    fallback_clause = relay.Clause(relay.PatternVar(bound), bound)

    fallback_only = relay.Match(none_ctor(), [fallback_clause])
    rebinding = relay.Match(none_ctor(), [some_clause, fallback_clause])

    assert well_formed(fallback_only)
    assert not well_formed(rebinding)
def test_tuple_match():
    """Two independently constructed tuple matches must satisfy `consistent_equal`."""
    # First expression.
    a1 = relay.Var("a")
    b1 = relay.Var("b")
    pattern1 = relay.PatternTuple([relay.PatternVar(a1), relay.PatternVar(b1)])
    clause1 = relay.Clause(pattern1, a1 + b1)
    x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause1])

    # Second expression: same shape, built from brand-new Vars.
    a2 = relay.Var("a")
    b2 = relay.Var("b")
    pattern2 = relay.PatternTuple([relay.PatternVar(a2), relay.PatternVar(b2)])
    clause2 = relay.Clause(pattern2, a2 + b2)
    y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause2])

    assert consistent_equal(x, y)
def test_unfoldl():
    """Evaluate `unfoldl` with a count-down step function starting from 3.

    The step function maps `s(n)` to `some((n, x))` and `z` to `none()`. The
    unfolded list is expected to hold three items counting 1, 2, 3.
    """
    ty_a = relay.TypeVar("a")
    ty_b = relay.TypeVar("b")

    # NOTE(review): expected_type is constructed but never compared against
    # anything in this test.
    step_ty = relay.FuncType([ty_a], optional(relay.TupleType([ty_a, ty_b])))
    expected_type = relay.FuncType([step_ty, ty_a], l(ty_b), [ty_a, ty_b])

    arg = relay.Var("x", nat())
    pred = relay.Var("n", nat())

    succ_clause = relay.Clause(
        relay.PatternConstructor(s, [relay.PatternVar(pred)]),
        some(relay.Tuple([pred, arg])),
    )
    zero_clause = relay.Clause(relay.PatternConstructor(z, []), none())
    count_down = relay.Function([arg], relay.Match(arg, [succ_clause, zero_clause]))

    unfolded = to_list(intrp.evaluate(unfoldl(count_down, make_nat_expr(3))))

    assert len(unfolded) == 3
    for index, expected in enumerate((1, 2, 3)):
        assert count(unfolded[index]) == expected
def test_optional_matching():
    """Fold a list of optionals with a function that keeps only `some` payloads.

    The input list is `[some(3), none, some(1)]`; the folded result is expected
    to have two items counting 3 and 1.
    """
    option = relay.Var("x")
    acc = relay.Var("y")
    payload = relay.Var("v")

    some_clause = relay.Clause(
        relay.PatternConstructor(some, [relay.PatternVar(payload)]),
        cons(payload, acc),
    )
    none_clause = relay.Clause(relay.PatternConstructor(none), acc)
    condense = relay.Function([option, acc], relay.Match(option, [some_clause, none_clause]))

    optionals = cons(
        some(make_nat_expr(3)),
        cons(none(), cons(some(make_nat_expr(1)), nil())),
    )
    res = intrp.evaluate(foldr(condense, nil(), optionals))

    reduced = to_list(res)
    assert len(reduced) == 2
    first, second = reduced[0], reduced[1]
    assert count(first) == 3
    assert count(second) == 1
def test_match_order():
    """A leading wildcard clause is expected to win over a later constructor clause.

    The scrutinee is a doubly boxed constant 2; the wildcard clause yields 1
    and the nested-constructor clause would yield the boxed value. The test
    expects 1.
    """
    module = tvm.IRModule()
    box, box_ctor = init_box_adt(module)
    scrutinee = relay.Var("v")
    inner = relay.Var("w")

    # wildcard pattern goes first
    wildcard_clause = relay.Clause(relay.PatternWildcard(), relay.const(1))
    nested_pattern = relay.PatternConstructor(
        box_ctor,
        [relay.PatternConstructor(box_ctor, [relay.PatternVar(inner)])],
    )
    nested_clause = relay.Clause(nested_pattern, inner)

    match = relay.Let(
        scrutinee,
        box_ctor(box_ctor(relay.const(2))),
        relay.Match(scrutinee, [wildcard_clause, nested_clause]),
    )

    assert_tensor_value(run_as_python(match, module), 1)
def test_nested_matches():
    """Register a recursive `flatten` built from nested matches in `mod` and run it.

    Flattening `[[1, 2, 3], [4, 5, 6]]` (elements built with `make_nat_expr`)
    is expected to give six items counting 1 through 6 in order.
    """
    elem_ty = relay.TypeVar("a")
    outer = relay.Var("x")
    head_list = relay.Var("y")
    remaining = relay.Var("w")
    head = relay.Var("h")
    tail = relay.Var("t")
    flatten = relay.GlobalVar("flatten")

    # flatten could be written using a fold, but this way has nested matches
    inner_clauses = [
        relay.Clause(relay.PatternConstructor(nil), flatten(remaining)),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(head), relay.PatternVar(tail)]),
            cons(head, flatten(cons(tail, remaining))),
        ),
    ]
    inner_match = relay.Match(head_list, inner_clauses)

    outer_clauses = [
        relay.Clause(relay.PatternConstructor(nil), nil()),
        relay.Clause(
            relay.PatternConstructor(
                cons, [relay.PatternVar(head_list), relay.PatternVar(remaining)]
            ),
            inner_match,
        ),
    ]
    mod[flatten] = relay.Function(
        [outer], relay.Match(outer, outer_clauses), l(elem_ty), [elem_ty]
    )

    def nat_list(numbers):
        # Elements are created left to right, then consed up from the tail.
        elements = [make_nat_expr(number) for number in numbers]
        built = nil()
        for element in reversed(elements):
            built = cons(element, built)
        return built

    first_list = nat_list([1, 2, 3])
    second_list = nat_list([4, 5, 6])
    final_list = cons(first_list, cons(second_list, nil()))

    flat = to_list(intrp.evaluate(flatten(final_list)))

    assert len(flat) == 6
    for position, item in enumerate(flat, start=1):
        assert count(item) == position
def test_match():
    """Compare parsed `@length` programs against a hand-built module for both match keywords.

    For each of `match` and `match?`, the expected module (a List ADT plus a
    recursive `length` function) is built by hand and handed to
    `assert_parse_module_as` together with the program text.
    """
    # pair each match keyword with whether it specifies a complete match or not
    keyword_table = [("match", True), ("match?", False)]

    for keyword, complete_flag in keyword_table:
        expected_mod = tvm.IRModule()

        # List ADT with a Cons and a Nil constructor.
        list_gtv = relay.GlobalTypeVar("List")
        adt_param = relay.TypeVar("A")
        cons_ctor = relay.Constructor("Cons", [adt_param, list_gtv(adt_param)], list_gtv)
        nil_ctor = relay.Constructor("Nil", [], list_gtv)
        expected_mod[list_gtv] = relay.TypeData(list_gtv, [adt_param], [cons_ctor, nil_ctor])

        # The function gets its own TypeVar object, separate from the ADT's.
        length_gv = relay.GlobalVar("length")
        fn_param = relay.TypeVar("A")
        xs = relay.Var("xs", list_gtv(fn_param))
        rest = relay.Var("rest")

        cons_body = relay.Let(
            relay.var("", type_annotation=None),
            UNIT,
            relay.add(relay.const(1), relay.Call(length_gv, [rest])),
        )
        cons_clause = relay.Clause(
            relay.PatternConstructor(
                cons_ctor, [relay.PatternWildcard(), relay.PatternVar(rest)]
            ),
            cons_body,
        )
        nil_clause = relay.Clause(relay.PatternConstructor(nil_ctor, []), relay.const(0))

        body = relay.Match(xs, [cons_clause, nil_clause], complete=complete_flag)
        expected_mod[length_gv] = relay.Function([xs], body, int32, [fn_param])

        program_text = """ %s def @length[A](%%xs: List[A]) -> int32 { %s (%%xs) { Cons(_, %%rest : List[A]) => { (); 1 + @length(%%rest) }, Nil => 0, } } """ % (
            LIST_DEFN,
            keyword,
        )
        assert_parse_module_as(program_text, expected_mod)
def test_wildcard_match_order():
    """A wildcard clause listed first is expected to be the one that fires.

    The function is applied to a one-element list; the wildcard clause returns
    `z()`, so the result is expected to count 0.
    """
    lst = relay.Var('x', l(nat()))
    head = relay.Var('y')
    tail = relay.Var('a')

    wildcard_clause = relay.Clause(relay.PatternWildcard(), z())
    cons_clause = relay.Clause(
        relay.PatternConstructor(cons, [relay.PatternVar(head), relay.PatternVar(tail)]),
        head,
    )
    nil_clause = relay.Clause(relay.PatternConstructor(nil), s(z()))

    body = relay.Match(lst, [wildcard_clause, cons_clause, nil_clause])
    return_zero = relay.Function([lst], body, nat())

    res = intrp.evaluate(return_zero(cons(s(z()), nil())))

    # wildcard pattern is evaluated first
    assert count(res) == 0
def test_trivial_matches():
    """`unmatched_cases` is expected to report nothing for catch-all clauses."""
    scrutinee = relay.Var("v")

    # a match clause with a wildcard will match anything
    wildcard_match = relay.Match(
        scrutinee, [relay.Clause(relay.PatternWildcard(), scrutinee)]
    )
    assert len(unmatched_cases(wildcard_match)) == 0

    # same with a pattern var
    binder = relay.Var("w")
    var_match = relay.Match(
        scrutinee, [relay.Clause(relay.PatternVar(binder), binder)]
    )
    assert len(unmatched_cases(var_match)) == 0
def test_global_recursion():
    """Run a list-copying function that recurses through a GlobalVar.

    The function is called on `[1, 2]` and on `[()]`; each result is inspected
    constructor by constructor.
    """
    module = tvm.IRModule()
    prel = Prelude(module)
    list_ty, cons_ctor, nil_ctor = prel.mod.get_type("List")

    # copy rebuilds the list it is given, recursing via the global var
    copy = relay.GlobalVar("copy")
    elem_ty = relay.TypeVar("a")
    lst = relay.Var("v", list_ty(elem_ty))
    head = relay.Var("h")
    tail = relay.Var("t")

    cons_clause = relay.Clause(
        relay.PatternConstructor(
            cons_ctor, [relay.PatternVar(head), relay.PatternVar(tail)]
        ),
        cons_ctor(head, copy(tail)),
    )
    nil_clause = relay.Clause(relay.PatternConstructor(nil_ctor, []), nil_ctor())
    copy_def = relay.Function(
        [lst],
        relay.Match(lst, [cons_clause, nil_clause]),
        list_ty(elem_ty),
        [elem_ty],
    )
    module[copy] = copy_def

    # First call: a two-element list of constants.
    call1 = copy_def(cons_ctor(relay.const(1), cons_ctor(relay.const(2), nil_ctor())))
    val1 = run_as_python(call1, module)
    node = val1
    for expected in (1, 2):
        assert_constructor_value(node, cons_ctor, 2)
        assert_tensor_value(node.fields[0], expected)
        node = node.fields[1]
    assert_constructor_value(node, nil_ctor, 0)

    # Second call: a one-element list holding an empty tuple.
    call2 = copy_def(cons_ctor(relay.Tuple([]), nil_ctor()))
    val2 = run_as_python(call2, module)
    assert_constructor_value(val2, cons_ctor, 2)
    assert_adt_len(val2.fields[0], 0)
    assert_constructor_value(val2.fields[1], nil_ctor, 0)
def test_nested_pattern_match():
    """Pick the second list element through a cons-inside-cons pattern.

    Applied to a two-element list whose second element is `s(s(z()))`, the
    result is expected to count 2.
    """
    lst = relay.Var('x', l(nat()))
    first = relay.Var('h1')
    second = relay.Var('h2')
    rest = relay.Var('t')

    tail_pattern = relay.PatternConstructor(
        cons, [relay.PatternVar(second), relay.PatternVar(rest)]
    )
    two_elem_pattern = relay.PatternConstructor(
        cons, [relay.PatternVar(first), tail_pattern]
    )
    clauses = [
        relay.Clause(two_elem_pattern, second),
        relay.Clause(relay.PatternWildcard(), z()),
    ]
    get_second = relay.Function([lst], relay.Match(lst, clauses))

    argument = cons(s(z()), cons(s(s(z())), nil()))
    res = intrp.evaluate(get_second(argument))

    assert count(res) == 2
def test_match_var():
    """A bare pattern var is expected to bind the whole scrutinee.

    Matching a boxed constant 1 against a single `PatternVar` clause should
    return the box itself, with its one field holding 1.
    """
    module = relay.Module()
    box, box_ctor = init_box_adt(module)
    scrutinee = relay.Var('v')
    binder = relay.Var('w')

    identity_clause = relay.Clause(relay.PatternVar(binder), binder)
    match = relay.Let(
        scrutinee,
        box_ctor(relay.const(1)),
        relay.Match(scrutinee, [identity_clause]),
    )

    match_val = run_as_python(match, module)
    assert_constructor_value(match_val, box_ctor, 1)
    assert_tensor_value(match_val.fields[0], 1)
def test_local_recursion():
    """Run a let-bound recursive function that rebuilds its list argument.

    The function is applied to `[1, 2, 3]`; the result is walked cell by cell
    and expected to be three cons cells holding 1, 2, 3 followed by nil.
    """
    module = tvm.IRModule()
    prel = Prelude(module)
    _, cons_ctor, nil_ctor = prel.mod.get_type("List")

    lst = relay.Var("v")
    head = relay.Var("h")
    tail = relay.Var("t")
    fn = relay.Var("f")

    # just returns the same list
    cons_clause = relay.Clause(
        relay.PatternConstructor(
            cons_ctor, [relay.PatternVar(head), relay.PatternVar(tail)]
        ),
        cons_ctor(head, fn(tail)),
    )
    nil_clause = relay.Clause(relay.PatternConstructor(nil_ctor, []), nil_ctor())
    fn_def = relay.Function([lst], relay.Match(lst, [cons_clause, nil_clause]))

    argument = cons_ctor(
        relay.const(1),
        cons_ctor(relay.const(2), cons_ctor(relay.const(3), nil_ctor())),
    )
    let = relay.Let(fn, fn_def, fn(argument))

    val = run_as_python(let, module)

    # Walk the returned list one cons cell at a time.
    node = val
    for expected in (1, 2, 3):
        assert_constructor_value(node, cons_ctor, 2)
        assert_tensor_value(node.fields[0], expected)
        node = node.fields[1]
    assert_constructor_value(node, nil_ctor, 0)
def test_nested_matches():
    """Register a recursive `flatten` built from nested matches in `mod` and run it.

    Flattening `[[1, 2, 3], [4, 5, 6]]` (elements built with `build_nat`) is
    expected to give six items counting 1 through 6 in order.
    """
    elem_ty = relay.TypeVar('a')
    outer = relay.Var('x')
    head_list = relay.Var('y')
    remaining = relay.Var('w')
    head = relay.Var('h')
    tail = relay.Var('t')
    flatten = relay.GlobalVar('flatten')

    # flatten could be written using a fold, but this way has nested matches
    inner_nil = relay.Clause(relay.PatternConstructor(nil), flatten(remaining))
    inner_cons = relay.Clause(
        relay.PatternConstructor(cons, [relay.PatternVar(head), relay.PatternVar(tail)]),
        cons(head, flatten(cons(tail, remaining))),
    )
    inner_match = relay.Match(head_list, [inner_nil, inner_cons])

    outer_nil = relay.Clause(relay.PatternConstructor(nil), nil())
    outer_cons = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternVar(head_list), relay.PatternVar(remaining)]
        ),
        inner_match,
    )
    mod[flatten] = relay.Function(
        [outer], relay.Match(outer, [outer_nil, outer_cons]), l(elem_ty), [elem_ty]
    )

    first_list = cons(
        build_nat(1),
        cons(build_nat(2), cons(build_nat(3), nil())),
    )
    second_list = cons(
        build_nat(4),
        cons(build_nat(5), cons(build_nat(6), nil())),
    )
    final_list = cons(first_list, cons(second_list, nil()))

    flat = to_list(intrp.evaluate(flatten(final_list)))

    assert len(flat) == 6
    for position, item in enumerate(flat, start=1):
        assert count(item) == position
def test_match_vars():
    """Check `bound_vars`, `free_vars` and `all_vars` on two match expressions.

    The first match binds `x` and `y` in a cons pattern and uses `z` unbound;
    the second binds only `x` and uses `y` and `z` unbound.
    """
    module = tvm.IRModule()
    prel = relay.prelude.Prelude(module)
    list_ty, cons_ctor, nil_ctor = prel.mod.get_type("List")

    var_x = relay.Var("x")
    var_y = relay.Var("y")
    var_z = relay.Var("z")

    match1 = relay.Match(
        nil_ctor(),
        [
            relay.Clause(relay.PatternConstructor(nil_ctor), var_z),
            relay.Clause(
                relay.PatternConstructor(
                    cons_ctor, [relay.PatternVar(var_x), relay.PatternVar(var_y)]
                ),
                cons_ctor(var_x, var_y),
            ),
        ],
    )

    match2 = relay.Match(
        nil_ctor(),
        [
            relay.Clause(
                relay.PatternConstructor(
                    cons_ctor, [relay.PatternWildcard(), relay.PatternVar(var_x)]
                ),
                var_y,
            ),
            relay.Clause(relay.PatternWildcard(), var_z),
        ],
    )

    # Expectations for the first match.
    assert_vars_match(bound_vars(match1), [var_x, var_y])
    assert_vars_match(free_vars(match1), [var_z])
    assert_vars_match(all_vars(match1), [var_z, var_x, var_y])

    # Expectations for the second match.
    assert_vars_match(bound_vars(match2), [var_x])
    assert_vars_match(free_vars(match2), [var_y, var_z])
    assert_vars_match(all_vars(match2), [var_x, var_y, var_z])
def test_wildcard_match_order():
    """A wildcard clause listed first is expected to be the one that fires.

    The function is applied to a one-element list; the wildcard clause returns
    `z()`, so the result is expected to count 0.
    """
    lst = relay.Var("x", rlist(nat()))
    head = relay.Var("y")
    tail = relay.Var("a")

    clauses = [
        relay.Clause(relay.PatternWildcard(), z()),
        relay.Clause(
            relay.PatternConstructor(
                cons, [relay.PatternVar(head), relay.PatternVar(tail)]
            ),
            head,
        ),
        relay.Clause(relay.PatternConstructor(nil), s(z())),
    ]
    return_zero = relay.Function([lst], relay.Match(lst, clauses), nat())

    one_element = cons(s(z()), nil())
    res = eval(return_zero(one_element))

    # wildcard pattern is evaluated first
    assert count(res) == 0
def test_match_full_var():
    """A single pattern-var clause is expected to act as the identity on lists.

    The function is applied to an empty list and to a list of two `z()`
    values; both results are converted with `to_list` and inspected.
    """
    param = relay.Var("x")
    binder = relay.Var("v")
    identity_clause = relay.Clause(relay.PatternVar(binder), binder)
    id_func = relay.Function([param], relay.Match(param, [identity_clause]))

    res1 = intrp.evaluate(id_func(nil()))
    res2 = intrp.evaluate(id_func(cons(z(), cons(z(), nil()))))

    empty = to_list(res1)
    assert len(empty) == 0

    zeroes = to_list(res2)
    assert len(zeroes) == 2
    for index in (0, 1):
        assert count(zeroes[index]) == 0
def test_nested_match_pattern():
    """A box-inside-box pattern is expected to extract the innermost value.

    The scrutinee is a doubly boxed constant 2; the result is expected to be
    the tensor value 2.
    """
    module = relay.Module()
    box, box_ctor = init_box_adt(module)
    scrutinee = relay.Var('v')
    inner = relay.Var('w')

    nested_pattern = relay.PatternConstructor(
        box_ctor,
        [relay.PatternConstructor(box_ctor, [relay.PatternVar(inner)])],
    )
    unwrap_twice = relay.Match(scrutinee, [relay.Clause(nested_pattern, inner)])
    match = relay.Let(scrutinee, box_ctor(box_ctor(relay.const(2))), unwrap_twice)

    match_val = run_as_python(match, module)
    assert_tensor_value(match_val, 2)
def test_match_pattern():
    """A constructor pattern is expected to extract the boxed value.

    The scrutinee is a boxed constant 1; the result is expected to be the
    tensor value 1.
    """
    module = tvm.IRModule()
    box, box_ctor = init_box_adt(module)
    scrutinee = relay.Var("v")
    inner = relay.Var("w")

    unbox_clause = relay.Clause(
        relay.PatternConstructor(box_ctor, [relay.PatternVar(inner)]),
        inner,
    )
    match = relay.Let(
        scrutinee,
        box_ctor(relay.const(1)),
        relay.Match(scrutinee, [unbox_clause]),
    )

    assert_tensor_value(run_as_python(match, module), 1)
def test_adt_match():
    """Type-infer a match over a boxed float32 scalar; expect the empty tuple type."""
    module = relay.Module()
    box, constructor = initialize_box_adt(module)
    payload = relay.Var('v', relay.TensorType((), 'float32'))

    unbox_clause = relay.Clause(
        relay.PatternConstructor(constructor, [relay.PatternVar(payload)]),
        relay.Tuple([]),
    )
    # redundant but shouldn't matter to typechecking
    wildcard_clause = relay.Clause(relay.PatternWildcard(), relay.Tuple([]))

    scrutinee = constructor(relay.const(0, 'float32'))
    match = relay.Match(scrutinee, [unbox_clause, wildcard_clause])

    mt = relay.ir_pass.infer_type(match, module)
    assert mt.checked_type == relay.TupleType([])
def test_adt_match_type_annotations():
    """Type-infer a function whose only annotation sits on the pattern var.

    The inferred function type is expected to take the annotated (2, 2)
    float32 tensor type and return the empty tuple type.
    """
    module = relay.Module()
    box, constructor = initialize_box_adt(module)

    # the only type annotation is inside the match pattern var
    # but that should be enough info
    tt = relay.TensorType((2, 2), 'float32')
    param = relay.Var('x')
    annotated = relay.Var('mv', tt)

    unbox_clause = relay.Clause(
        relay.PatternConstructor(constructor, [relay.PatternVar(annotated)]),
        relay.Tuple([]),
    )
    match = relay.Match(constructor(param), [unbox_clause])

    func = relay.Function([param], match)
    ft = relay.ir_pass.infer_type(func, module)
    expected = relay.FuncType([tt], relay.TupleType([]))
    assert ft.checked_type == expected
def test_optional_matching():
    """Fold a list of optionals with a function that keeps only `some` payloads.

    The input list is `[some(3), none, some(1)]` (values built with
    `build_nat`); the folded result is expected to have two items counting 3
    and 1.
    """
    option = relay.Var('x')
    acc = relay.Var('y')
    payload = relay.Var('v')

    clauses = [
        relay.Clause(
            relay.PatternConstructor(some, [relay.PatternVar(payload)]),
            cons(payload, acc),
        ),
        relay.Clause(relay.PatternConstructor(none), acc),
    ]
    condense = relay.Function([option, acc], relay.Match(option, clauses))

    optionals = cons(
        some(build_nat(3)),
        cons(none(), cons(some(build_nat(1)), nil())),
    )
    res = intrp.evaluate(foldr(condense, nil(), optionals))

    reduced = to_list(res)
    assert len(reduced) == 2
    first, second = reduced[0], reduced[1]
    assert count(first) == 3
    assert count(second) == 1
def test_adt_match():
    """Infer types for a `main` that matches on a boxed float32 scalar.

    The return type of the inferred `main` is expected to be the empty tuple
    type.
    """
    module = tvm.IRModule()
    box, constructor = initialize_box_adt(module)
    payload = relay.Var("v", relay.TensorType((), "float32"))

    unbox_clause = relay.Clause(
        relay.PatternConstructor(constructor, [relay.PatternVar(payload)]),
        relay.Tuple([]),
    )
    # redundant but shouldn't matter to typechecking
    wildcard_clause = relay.Clause(relay.PatternWildcard(), relay.Tuple([]))

    scrutinee = constructor(relay.const(0, "float32"))
    match = relay.Match(scrutinee, [unbox_clause, wildcard_clause])

    module["main"] = relay.Function([], match)
    module = infer_mod(module)

    actual = module["main"].checked_type.ret_type
    assert actual == relay.TupleType([])
def test_adt_match_type_annotations():
    """Infer types for a `main` whose only annotation sits on the pattern var.

    The inferred type of `main` is expected to take the annotated (2, 2)
    float32 tensor type and return the empty tuple type.
    """
    module = tvm.IRModule()
    box, constructor = initialize_box_adt(module)

    # the only type annotation is inside the match pattern var
    # but that should be enough info
    tt = relay.TensorType((2, 2), "float32")
    param = relay.Var("x")
    annotated = relay.Var("mv", tt)

    unbox_clause = relay.Clause(
        relay.PatternConstructor(constructor, [relay.PatternVar(annotated)]),
        relay.Tuple([]),
    )
    match = relay.Match(constructor(param), [unbox_clause])

    module["main"] = relay.Function([param], match)
    module = infer_mod(module)

    ft = module["main"].checked_type
    expected = relay.FuncType([tt], relay.TupleType([]))
    assert ft == expected
def test_match_alpha_equal():
    """Exercise `alpha_equal` on a reference list match and several variants.

    The reference match must equal itself and a copy whose cons clause uses
    differently named vars. Every other variant (missing, extra, reordered or
    altered clauses, different scrutinee, different constructors) must compare
    unequal.
    """
    module = relay.Module()
    prel = relay.prelude.Prelude(module)

    head = relay.Var('x')
    tail = relay.Var('y')
    nil_case = relay.Clause(relay.PatternConstructor(prel.nil), prel.nil())
    cons_case = relay.Clause(
        relay.PatternConstructor(
            prel.cons, [relay.PatternVar(head), relay.PatternVar(tail)]
        ),
        prel.cons(head, tail),
    )

    # Same cons clause, but over freshly named variables.
    other_head = relay.Var('z')
    other_tail = relay.Var('a')
    equivalent_cons = relay.Clause(
        relay.PatternConstructor(
            prel.cons, [relay.PatternVar(other_head), relay.PatternVar(other_tail)]
        ),
        prel.cons(other_head, other_tail),
    )

    data = prel.cons(relay.const(1), prel.cons(relay.const(2), prel.nil()))

    match = relay.Match(data, [nil_case, cons_case])
    equivalent = relay.Match(data, [nil_case, equivalent_cons])
    empty = relay.Match(data, [])
    no_cons = relay.Match(data, [nil_case])
    no_nil = relay.Match(data, [cons_case])
    different_data = relay.Match(prel.nil(), [nil_case, cons_case])
    different_order = relay.Match(data, [cons_case, nil_case])

    altered_nil_clause = relay.Clause(
        relay.PatternConstructor(prel.nil), prel.cons(prel.nil(), prel.nil())
    )
    different_nil = relay.Match(data, [altered_nil_clause, cons_case])

    altered_cons_clause = relay.Clause(
        relay.PatternConstructor(
            prel.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
        ),
        prel.nil(),
    )
    different_cons = relay.Match(data, [nil_case, altered_cons_clause])

    another_case = relay.Match(
        data,
        [nil_case, cons_case, relay.Clause(relay.PatternWildcard(), prel.nil())],
    )

    wrong_constructors = relay.Match(
        data,
        [
            relay.Clause(relay.PatternConstructor(prel.none), prel.nil()),
            relay.Clause(
                relay.PatternConstructor(prel.some, [relay.PatternVar(head)]),
                prel.cons(head, prel.nil()),
            ),
        ],
    )

    assert alpha_equal(match, match)
    assert alpha_equal(match, equivalent)

    # Checked in the same order as the individual assertions they replace.
    unequal_variants = [
        no_cons,
        no_nil,
        empty,
        different_data,
        different_order,
        different_nil,
        different_cons,
        another_case,
        wrong_constructors,
    ]
    for variant in unequal_variants:
        assert not alpha_equal(match, variant)
def test_tuple_match():
    """`unmatched_cases` is expected to report nothing for a two-var tuple pattern."""
    lhs = relay.Var("a")
    rhs = relay.Var("b")

    pattern = relay.PatternTuple([relay.PatternVar(lhs), relay.PatternVar(rhs)])
    clause = relay.Clause(pattern, lhs + rhs)
    scrutinee = relay.Tuple([relay.const(1), relay.const(1)])
    x = relay.Match(scrutinee, [clause])

    missing = unmatched_cases(x)
    assert len(missing) == 0
def test_expr_node_incompatible_sequal():
    """A `relay.Var` and a `relay.PatternVar` must not satisfy `consistent_equal`."""
    plain_var = relay.Var("v1")
    pattern_var = relay.PatternVar(relay.Var("v2"))

    are_equal = consistent_equal(plain_var, pattern_var)
    assert not are_equal