Example #1
0
def test_match_alpha_equal():
    """Compare ``relay.Match`` expressions over prelude lists with ``alpha_equal``.

    A reference match (one nil clause, one cons clause) must be alpha-equal to
    itself and to a copy whose cons clause binds differently named variables.
    It must not be alpha-equal to any of the other variants built below.
    """
    mod = relay.Module()
    p = relay.prelude.Prelude(mod)

    def rebuild_cons_clause(head, tail):
        # Clause that binds both cons fields and returns the same cons cell.
        fields = [relay.PatternVar(head), relay.PatternVar(tail)]
        return relay.Clause(relay.PatternConstructor(p.cons, fields),
                            p.cons(head, tail))

    x = relay.Var('x')
    y = relay.Var('y')
    nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil())
    cons_case = rebuild_cons_clause(x, y)

    # Same shape as cons_case; only the bound variables differ.
    z = relay.Var('z')
    a = relay.Var('a')
    equivalent_cons = rebuild_cons_clause(z, a)

    # Scrutinee shared by most of the matches below.
    first_item = relay.const(1)
    remaining = p.cons(relay.const(2), p.nil())
    data = p.cons(first_item, remaining)

    match = relay.Match(data, [nil_case, cons_case])
    equivalent = relay.Match(data, [nil_case, equivalent_cons])

    # Variants with clauses removed, reordered, or a different scrutinee.
    empty = relay.Match(data, [])
    no_cons = relay.Match(data, [nil_case])
    no_nil = relay.Match(data, [cons_case])
    different_data = relay.Match(p.nil(), [nil_case, cons_case])
    different_order = relay.Match(data, [cons_case, nil_case])

    # Variants whose clauses have the same constructors but other contents.
    nil_with_other_body = relay.Clause(
        relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil()))
    different_nil = relay.Match(data, [nil_with_other_body, cons_case])

    cons_with_wildcards = relay.Clause(
        relay.PatternConstructor(
            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]),
        p.nil())
    different_cons = relay.Match(data, [nil_case, cons_with_wildcards])

    trailing_wildcard = relay.Clause(relay.PatternWildcard(), p.nil())
    another_case = relay.Match(data, [nil_case, cons_case, trailing_wildcard])

    # Variant whose patterns use the prelude's none/some constructors instead.
    none_clause = relay.Clause(relay.PatternConstructor(p.none), p.nil())
    some_clause = relay.Clause(
        relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
        p.cons(x, p.nil()))
    wrong_constructors = relay.Match(data, [none_clause, some_clause])

    assert alpha_equal(match, match)
    assert alpha_equal(match, equivalent)
    assert not alpha_equal(match, no_cons)
    assert not alpha_equal(match, no_nil)
    assert not alpha_equal(match, empty)
    assert not alpha_equal(match, different_data)
    assert not alpha_equal(match, different_order)
    assert not alpha_equal(match, different_nil)
    assert not alpha_equal(match, different_cons)
    assert not alpha_equal(match, another_case)
    assert not alpha_equal(match, wrong_constructors)
def test_too_specific_match():
    """Check the unmatched cases reported for a match with one nested cons clause.

    The only clause is ``cons(_, cons(_, _))``. The test expects
    ``unmatched_cases`` to report exactly two missing patterns: a nil pattern,
    and a cons pattern whose second field is a nil pattern.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)

    v = relay.Var('v')
    match = relay.Match(v, [
        relay.Clause(
            relay.PatternConstructor(p.cons, [
                relay.PatternWildcard(),
                relay.PatternConstructor(
                    p.cons, [relay.PatternWildcard(),
                             relay.PatternWildcard()])
            ]), v)
    ])

    unmatched = unmatched_cases(match, mod)

    # will not match nil or a list of length 1
    nil_found = False
    single_length_found = False
    assert len(unmatched) == 2
    for case in unmatched:
        assert isinstance(case, relay.PatternConstructor)
        if case.constructor == p.nil:
            nil_found = True
        if case.constructor == p.cons:
            assert isinstance(case.patterns[1], relay.PatternConstructor)
            assert case.patterns[1].constructor == p.nil
            single_length_found = True
    # The flags were previously set but never checked, so two reported cases
    # of the same kind would have passed; require one of each kind.
    assert nil_found and single_length_found
Example #3
0
def test_filter():
    """Check the type of ``filter`` and its result on a six-element list.

    The predicate returns True for values matching ``s(s(_))`` and False
    otherwise. Filtering the list built from 1, 1, 3, 1, 5, 1 is expected to
    leave two elements whose counts are 3 and 5.
    """
    a = relay.TypeVar("a")
    predicate_type = relay.FuncType([a], relay.scalar_type("bool"))
    expected_type = relay.FuncType([predicate_type, l(a)], l(a), [a])
    assert mod[filter].checked_type == expected_type

    x = relay.Var("x", nat())
    two_successors = relay.PatternConstructor(
        s, [relay.PatternConstructor(s, [relay.PatternWildcard()])])
    greater_than_one = relay.Function(
        [x],
        relay.Match(x, [
            relay.Clause(two_successors, relay.const(True)),
            relay.Clause(relay.PatternWildcard(), relay.const(False)),
        ]))

    # Build the input list back to front from its elements.
    elements = [make_nat_expr(value) for value in (1, 1, 3, 1, 5, 1)]
    input_list = nil()
    for element in reversed(elements):
        input_list = cons(element, input_list)

    res = intrp.evaluate(filter(greater_than_one, input_list))
    filtered = to_list(res)
    assert len(filtered) == 2
    for position, expected_count in enumerate((3, 5)):
        assert count(filtered[position]) == expected_count
def test_match_vars():
    """Check ``bound_vars``, ``free_vars`` and ``all_vars`` on two matches.

    ``match1`` returns ``z`` from its nil clause and binds ``x`` and ``y`` in
    its cons clause. ``match2`` binds only ``x`` in its cons clause and
    returns ``y`` and ``z`` from its clause bodies.
    """
    mod = tvm.IRModule()
    p = relay.prelude.Prelude(mod)

    x, y, z = (relay.Var(name) for name in ('x', 'y', 'z'))

    # match1: z is only used as a body; x and y are bound by the cons pattern.
    scrutinee1 = p.nil()
    nil_returns_z = relay.Clause(relay.PatternConstructor(p.nil), z)
    cons_rebuilt = relay.Clause(
        relay.PatternConstructor(
            p.cons, [relay.PatternVar(x), relay.PatternVar(y)]),
        p.cons(x, y))
    match1 = relay.Match(scrutinee1, [nil_returns_z, cons_rebuilt])

    # match2: only x is bound by a pattern; y and z appear as clause bodies.
    scrutinee2 = p.nil()
    cons_returns_y = relay.Clause(
        relay.PatternConstructor(
            p.cons, [relay.PatternWildcard(), relay.PatternVar(x)]),
        y)
    wildcard_returns_z = relay.Clause(relay.PatternWildcard(), z)
    match2 = relay.Match(scrutinee2, [cons_returns_y, wildcard_returns_z])

    assert_vars_match(bound_vars(match1), [x, y])
    assert_vars_match(free_vars(match1), [z])
    assert_vars_match(all_vars(match1), [z, x, y])

    assert_vars_match(bound_vars(match2), [x])
    assert_vars_match(free_vars(match2), [y, z])
    assert_vars_match(all_vars(match2), [x, y, z])
def test_single_constructor_adt():
    """Check that matches on a one-constructor ADT report no unmatched cases.

    A ``box`` ADT with a single constructor is registered in a fresh module.
    The test covers a single constructor pattern and a three-deep nesting of
    that pattern.
    """
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    a = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    def boxed(inner):
        # Pattern for the box constructor applied to a single sub-pattern.
        return relay.PatternConstructor(box_ctor, [inner])

    v = relay.Var("v")
    match = relay.Match(v, [relay.Clause(boxed(relay.PatternWildcard()), v)])

    # with one constructor, having one pattern constructor case is exhaustive
    assert len(unmatched_cases(match, mod)) == 0

    # this will be so if we nest the constructors too
    triple_box = boxed(boxed(boxed(relay.PatternWildcard())))
    nested_pattern = relay.Match(v, [relay.Clause(triple_box, v)])
    assert len(unmatched_cases(nested_pattern, mod)) == 0
def test_match_effect_exactly_once():
    """Check that a match evaluates its effectful scrutinee only once.

    The scrutinee prepends a unit tuple to the list held in reference ``r``
    and then reads ``r``. The clauses return 0 for nil, 1 for a one-element
    list and 2 otherwise; the test expects the value 1.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    _, cons, nil = p.mod.get_type("List")

    # the list should be of length 1!
    # Unless we mistakenly execute the data clause more than once
    r = relay.Var("r")
    prepend_unit = relay.RefWrite(r, cons(relay.Tuple([]), relay.RefRead(r)))
    data = seq(prepend_unit, relay.RefRead(r))

    initial_ref = relay.RefCreate(nil())
    on_empty = relay.Clause(relay.PatternConstructor(nil, []), relay.const(0))
    on_single = relay.Clause(
        relay.PatternConstructor(
            cons,
            [relay.PatternWildcard(), relay.PatternConstructor(nil, [])]),
        relay.const(1))
    on_other = relay.Clause(relay.PatternWildcard(), relay.const(2))

    match = relay.Let(
        r, initial_ref, relay.Match(data, [on_empty, on_single, on_other]))

    match_val = run_as_python(match, mod)
    assert_tensor_value(match_val, 1)
Example #7
0
def test_missing_in_the_middle():
    """Check the single unmatched case when only length-2 lists are uncovered.

    The clauses cover a list of length exactly 1, the empty list, and a list
    of length 3 or more. The test expects one unmatched pattern of the shape
    ``cons(_, cons(_, nil))``.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    _, cons, nil = mod.get_type("List")

    def cons_pattern(tail):
        # cons pattern with a wildcard head and the given tail pattern.
        return relay.PatternConstructor(cons, [relay.PatternWildcard(), tail])

    def nil_pattern():
        return relay.PatternConstructor(nil, [])

    v = relay.Var("v")
    clauses = [
        # list of length exactly 1
        relay.Clause(cons_pattern(nil_pattern()), v),
        # empty list
        relay.Clause(nil_pattern(), v),
        # list of length 3 or more
        relay.Clause(
            cons_pattern(cons_pattern(cons_pattern(relay.PatternWildcard()))),
            v),
    ]
    match = relay.Match(v, clauses)

    # fails to match a list of length exactly two
    unmatched = unmatched_cases(match, mod)
    assert len(unmatched) == 1

    outer = unmatched[0]
    assert isinstance(outer, relay.PatternConstructor)
    assert outer.constructor == cons

    middle = outer.patterns[1]
    assert isinstance(middle, relay.PatternConstructor)
    assert middle.constructor == cons

    innermost = middle.patterns[1]
    assert isinstance(innermost, relay.PatternConstructor)
    assert innermost.constructor == nil
def test_match_order():
    """Check that a leading wildcard clause is taken over a later clause.

    The scrutinee is a doubly boxed constant 2. The wildcard clause returns 1
    and the later constructor clause would return the boxed value; the test
    expects 1.
    """
    mod = tvm.IRModule()
    _, box_ctor = init_box_adt(mod)
    v = relay.Var("v")
    w = relay.Var("w")

    # wildcard pattern goes first
    wildcard_clause = relay.Clause(relay.PatternWildcard(), relay.const(1))
    double_box = relay.PatternConstructor(
        box_ctor, [relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])])
    unwrap_clause = relay.Clause(double_box, w)

    boxed_value = box_ctor(box_ctor(relay.const(2)))
    match = relay.Let(
        v, boxed_value, relay.Match(v, [wildcard_clause, unwrap_clause]))

    match_val = run_as_python(match, mod)
    assert_tensor_value(match_val, 1)
Example #9
0
def test_nested_pattern_match():
    """Check a nested cons pattern that returns the second list element.

    The function matches ``cons(h1, cons(h2, t))`` and returns ``h2``, falling
    back to ``z()`` through a wildcard clause. Applied to a list built from
    ``s(z())`` and ``s(s(z()))``, the result is expected to have count 2.
    """
    x = relay.Var("x", l(nat()))
    h1 = relay.Var("h1")
    h2 = relay.Var("h2")
    t = relay.Var("t")

    first_pattern = relay.PatternVar(h1)
    rest_pattern = relay.PatternConstructor(
        cons, [relay.PatternVar(h2), relay.PatternVar(t)])
    take_second = relay.Clause(
        relay.PatternConstructor(cons, [first_pattern, rest_pattern]), h2)
    fallback = relay.Clause(relay.PatternWildcard(), z())

    match = relay.Match(x, [take_second, fallback])
    get_second = relay.Function([x], match)

    two_items = cons(s(z()), cons(s(s(z())), nil()))
    res = intrp.evaluate(get_second(two_items))

    assert count(res) == 2
Example #10
0
def test_match():
    """Check that the ``match`` and ``match?`` keywords parse to the expected module.

    For each keyword, a module containing a ``List`` ADT and a recursive
    ``length`` function is built by hand, with the ``complete`` flag of the
    match set accordingly, and passed with the program text to
    ``assert_parse_module_as``.
    """
    # pair each match keyword with whether it specifies a complete match or not
    for match_keyword, is_complete in (("match", True), ("match?", False)):
        mod = tvm.IRModule()

        # ADT: List[A] with constructors Cons(A, List[A]) and Nil.
        list_var = relay.GlobalTypeVar("List")
        adt_param = relay.TypeVar("A")
        cons_constructor = relay.Constructor(
            "Cons", [adt_param, list_var(adt_param)], list_var)
        nil_constructor = relay.Constructor("Nil", [], list_var)
        mod[list_var] = relay.TypeData(
            list_var, [adt_param], [cons_constructor, nil_constructor])

        # The function uses its own type variable, distinct from the ADT's.
        length_var = relay.GlobalVar("length")
        func_param = relay.TypeVar("A")
        input_var = relay.Var("xs", list_var(func_param))
        rest_var = relay.Var("rest")
        cons_case = relay.Let(
            relay.var("", type_annotation=None),
            UNIT,
            relay.add(relay.const(1), relay.Call(length_var, [rest_var])),
        )

        cons_pattern = relay.PatternConstructor(
            cons_constructor,
            [relay.PatternWildcard(), relay.PatternVar(rest_var)])
        nil_pattern = relay.PatternConstructor(nil_constructor, [])
        body = relay.Match(
            input_var,
            [
                relay.Clause(cons_pattern, cons_case),
                relay.Clause(nil_pattern, relay.const(0)),
            ],
            complete=is_complete,
        )
        mod[length_var] = relay.Function(
            [input_var], body, int32, [func_param])

        expected_text = """
            %s

            def @length[A](%%xs: List[A]) -> int32 {
              %s (%%xs) {
                Cons(_, %%rest : List[A]) => {
                  ();
                  1 + @length(%%rest)
                },
                Nil => 0,
              }
            }
            """ % (LIST_DEFN, match_keyword)
        assert_parse_module_as(expected_text, mod)
Example #11
0
def test_wildcard_match_solo():
    """Check a match whose only clause is a wildcard returning its input.

    The function is applied to ``s(s(s(z())))`` and the result is expected to
    have count 3.
    """
    x = relay.Var('x', nat())
    wildcard_only = relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)])
    copy = relay.Function([x], wildcard_only, nat())

    three = s(s(s(z())))
    res = intrp.evaluate(copy(three))
    assert count(res) == 3
def test_trivial_matches():
    """Check that a lone wildcard or pattern-var clause leaves nothing unmatched."""
    v = relay.Var("v")

    # a match clause with a wildcard will match anything
    wildcard_clause = relay.Clause(relay.PatternWildcard(), v)
    match = relay.Match(v, [wildcard_clause])
    assert len(unmatched_cases(match)) == 0

    # same with a pattern var
    w = relay.Var("w")
    var_clause = relay.Clause(relay.PatternVar(w), w)
    match = relay.Match(v, [var_clause])
    assert len(unmatched_cases(match)) == 0
Example #13
0
def test_match_wildcard():
    """Check that a wildcard-only match on a boxed value yields its clause body.

    ``v`` is bound to a box holding an empty tuple; the single wildcard clause
    returns the constant 1, which is the expected result.
    """
    mod = relay.Module()
    _, box_ctor = init_box_adt(mod)
    v = relay.Var('v')

    boxed_unit = box_ctor(relay.Tuple([]))
    always_one = relay.Clause(relay.PatternWildcard(), relay.const(1))
    match = relay.Let(v, boxed_unit, relay.Match(v, [always_one]))

    match_val = run_as_python(match, mod)
    assert_tensor_value(match_val, 1)
def test_multiple_constructor_clauses():
    """Check that overlapping list clauses together leave nothing unmatched.

    The clauses cover length exactly 1, length exactly 2, the empty list, and
    length 2 or more; ``unmatched_cases`` is expected to return no patterns.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)

    def cons_pattern(tail):
        # cons pattern with a wildcard head and the given tail pattern.
        return relay.PatternConstructor(
            p.cons, [relay.PatternWildcard(), tail])

    def nil_pattern():
        return relay.PatternConstructor(p.nil, [])

    v = relay.Var('v')
    clauses = [
        # list of length exactly 1
        relay.Clause(cons_pattern(nil_pattern()), v),
        # list of length exactly 2
        relay.Clause(cons_pattern(cons_pattern(nil_pattern())), v),
        # empty list
        relay.Clause(nil_pattern(), v),
        # list of length 2 or more
        relay.Clause(cons_pattern(cons_pattern(relay.PatternWildcard())), v),
    ]
    match = relay.Match(v, clauses)
    assert len(unmatched_cases(match, mod)) == 0
Example #15
0
def test_match_vars():
    """Check ``bound_vars``, ``free_vars`` and ``all_vars`` on two matches.

    The list constructors come from the prelude's ``List`` type. ``match1``
    binds ``x`` and ``y`` in its cons clause and returns ``z`` from its nil
    clause. ``match2`` binds only ``x`` and returns ``y`` and ``z`` from its
    clause bodies.
    """
    mod = tvm.IRModule()
    p = relay.prelude.Prelude(mod)
    _, cons, nil = p.mod.get_type("List")

    x, y, z = (relay.Var(name) for name in ("x", "y", "z"))

    # match1: z is only used as a body; x and y are bound by the cons pattern.
    scrutinee1 = nil()
    nil_returns_z = relay.Clause(relay.PatternConstructor(nil), z)
    cons_rebuilt = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternVar(x), relay.PatternVar(y)]),
        cons(x, y))
    match1 = relay.Match(scrutinee1, [nil_returns_z, cons_rebuilt])

    # match2: only x is bound by a pattern; y and z appear as clause bodies.
    scrutinee2 = nil()
    cons_returns_y = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternWildcard(), relay.PatternVar(x)]),
        y)
    wildcard_returns_z = relay.Clause(relay.PatternWildcard(), z)
    match2 = relay.Match(scrutinee2, [cons_returns_y, wildcard_returns_z])

    assert_vars_match(bound_vars(match1), [x, y])
    assert_vars_match(free_vars(match1), [z])
    assert_vars_match(all_vars(match1), [z, x, y])

    assert_vars_match(bound_vars(match2), [x])
    assert_vars_match(free_vars(match2), [y, z])
    assert_vars_match(all_vars(match2), [x, y, z])
Example #16
0
def test_missing_in_the_middle():
    """Check the single unmatched case when only length-2 lists are uncovered.

    The clauses cover a list of length exactly 1, the empty list, and a list
    of length 3 or more. The test expects one unmatched pattern of the shape
    ``cons(_, cons(_, nil))``.
    """
    mod = relay.Module()
    p = Prelude(mod)

    def cons_pattern(tail):
        # cons pattern with a wildcard head and the given tail pattern.
        return relay.PatternConstructor(
            p.cons, [relay.PatternWildcard(), tail])

    def nil_pattern():
        return relay.PatternConstructor(p.nil, [])

    v = relay.Var('v')
    clauses = [
        # list of length exactly 1
        relay.Clause(cons_pattern(nil_pattern()), v),
        # empty list
        relay.Clause(nil_pattern(), v),
        # list of length 3 or more
        relay.Clause(
            cons_pattern(cons_pattern(cons_pattern(relay.PatternWildcard()))),
            v),
    ]
    match = relay.Match(v, clauses)

    # fails to match a list of length exactly two
    unmatched = unmatched_cases(match, mod)
    assert len(unmatched) == 1

    outer = unmatched[0]
    assert isinstance(outer, relay.PatternConstructor)
    assert outer.constructor == p.cons

    middle = outer.patterns[1]
    assert isinstance(middle, relay.PatternConstructor)
    assert middle.constructor == p.cons

    innermost = middle.patterns[1]
    assert isinstance(innermost, relay.PatternConstructor)
    assert innermost.constructor == p.nil
Example #17
0
def test_adt_match():
    """Check the inferred type of a match over a boxed float32 constant.

    Both clauses return an empty tuple, so the checked type of the match is
    expected to be the empty tuple type.
    """
    mod = relay.Module()
    box, constructor = initialize_box_adt(mod)

    v = relay.Var('v', relay.TensorType((), 'float32'))
    scrutinee = constructor(relay.const(0, 'float32'))
    unwrap_clause = relay.Clause(
        relay.PatternConstructor(constructor, [relay.PatternVar(v)]),
        relay.Tuple([]))
    # redundant but shouldn't matter to typechecking
    wildcard_clause = relay.Clause(relay.PatternWildcard(), relay.Tuple([]))
    match = relay.Match(scrutinee, [unwrap_clause, wildcard_clause])

    mt = relay.ir_pass.infer_type(match, mod)
    assert mt.checked_type == relay.TupleType([])
Example #18
0
def test_wildcard_match_order():
    """Check that a leading wildcard clause wins over later constructor clauses.

    The wildcard clause returns ``z()``. Applied to a one-element list, the
    function's result is expected to have count 0.
    """
    x = relay.Var('x', l(nat()))
    y = relay.Var('y')
    a = relay.Var('a')

    wildcard_clause = relay.Clause(relay.PatternWildcard(), z())
    cons_clause = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternVar(y), relay.PatternVar(a)]),
        y)
    nil_clause = relay.Clause(relay.PatternConstructor(nil), s(z()))

    body = relay.Match(x, [wildcard_clause, cons_clause, nil_clause])
    return_zero = relay.Function([x], body, nat())

    single_item = cons(s(z()), nil())
    res = intrp.evaluate(return_zero(single_item))
    # wildcard pattern is evaluated first
    assert count(res) == 0
Example #19
0
def test_multiple_constructor_clauses():
    """Check that overlapping list clauses together leave nothing unmatched.

    The clauses cover length exactly 1, length exactly 2, the empty list, and
    length 2 or more; ``unmatched_cases`` is expected to return no patterns.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    _, cons, nil = mod.get_type("List")

    def cons_pattern(tail):
        # cons pattern with a wildcard head and the given tail pattern.
        return relay.PatternConstructor(cons, [relay.PatternWildcard(), tail])

    def nil_pattern():
        return relay.PatternConstructor(nil, [])

    v = relay.Var("v")
    clauses = [
        # list of length exactly 1
        relay.Clause(cons_pattern(nil_pattern()), v),
        # list of length exactly 2
        relay.Clause(cons_pattern(cons_pattern(nil_pattern())), v),
        # empty list
        relay.Clause(nil_pattern(), v),
        # list of length 2 or more
        relay.Clause(cons_pattern(cons_pattern(relay.PatternWildcard())), v),
    ]
    match = relay.Match(v, clauses)
    assert len(unmatched_cases(match, mod)) == 0
Example #20
0
def test_adt_match():
    """Check the inferred return type of a function wrapping a box match.

    Both clauses return an empty tuple. The match is wrapped in a
    zero-argument ``main`` function, and after ``infer_mod`` its return type
    is expected to be the empty tuple type.
    """
    mod = tvm.IRModule()
    box, constructor = initialize_box_adt(mod)

    v = relay.Var("v", relay.TensorType((), "float32"))
    scrutinee = constructor(relay.const(0, "float32"))
    unwrap_clause = relay.Clause(
        relay.PatternConstructor(constructor, [relay.PatternVar(v)]),
        relay.Tuple([]))
    # redundant but shouldn't matter to typechecking
    wildcard_clause = relay.Clause(relay.PatternWildcard(), relay.Tuple([]))
    match = relay.Match(scrutinee, [unwrap_clause, wildcard_clause])

    mod["main"] = relay.Function([], match)
    mod = infer_mod(mod)
    actual = mod["main"].checked_type.ret_type
    assert actual == relay.TupleType([])
Example #21
0
def test_wildcard_match_order():
    """Check that a leading wildcard clause wins over later constructor clauses.

    The wildcard clause returns ``z()``. Applied to a one-element list, the
    function's result is expected to have count 0.
    """
    x = relay.Var("x", rlist(nat()))
    y = relay.Var("y")
    a = relay.Var("a")

    wildcard_clause = relay.Clause(relay.PatternWildcard(), z())
    cons_clause = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternVar(y), relay.PatternVar(a)]),
        y)
    nil_clause = relay.Clause(relay.PatternConstructor(nil), s(z()))

    body = relay.Match(x, [wildcard_clause, cons_clause, nil_clause])
    return_zero = relay.Function([x], body, nat())

    single_item = cons(s(z()), nil())
    res = eval(return_zero(single_item))
    # wildcard pattern is evaluated first
    assert count(res) == 0
def test_mixed_adt_constructors():
    """Check unmatched-case reporting for patterns mixing box and list ADTs.

    Four matches are examined: a box of lists and a list of boxes, each in an
    incomplete and a complete form. The incomplete forms are expected to
    report exactly one unmatched pattern, the complete forms none.
    """
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    a = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    p = Prelude(mod)

    def box_of(inner):
        # Box constructor pattern wrapping a single sub-pattern.
        return relay.PatternConstructor(box_ctor, [inner])

    def any_box():
        return box_of(relay.PatternWildcard())

    def cons_of(head, tail):
        # cons constructor pattern with explicit head and tail patterns.
        return relay.PatternConstructor(p.cons, [head, tail])

    def any_cons():
        return cons_of(relay.PatternWildcard(), relay.PatternWildcard())

    def empty_list():
        return relay.PatternConstructor(p.nil, [])

    v = relay.Var("v")

    box_of_lists_inc = relay.Match(v, [relay.Clause(box_of(any_cons()), v)])

    # will fail to match a box containing an empty list
    unmatched = unmatched_cases(box_of_lists_inc, mod)
    assert len(unmatched) == 1
    missing = unmatched[0]
    assert isinstance(missing, relay.PatternConstructor)
    assert missing.constructor == box_ctor
    assert len(missing.patterns) == 1
    assert missing.patterns[0].constructor == p.nil

    box_of_lists_comp = relay.Match(
        v,
        [
            relay.Clause(box_of(empty_list()), v),
            relay.Clause(box_of(any_cons()), v),
        ],
    )
    assert len(unmatched_cases(box_of_lists_comp, mod)) == 0

    list_of_boxes_inc = relay.Match(
        v, [relay.Clause(cons_of(any_box(), relay.PatternWildcard()), v)]
    )

    # fails to match empty list of boxes
    unmatched = unmatched_cases(list_of_boxes_inc, mod)
    assert len(unmatched) == 1
    missing = unmatched[0]
    assert isinstance(missing, relay.PatternConstructor)
    assert missing.constructor == p.nil

    list_of_boxes_comp = relay.Match(
        v,
        [
            # exactly one box
            relay.Clause(cons_of(any_box(), empty_list()), v),
            # exactly two boxes
            relay.Clause(
                cons_of(any_box(), cons_of(any_box(), empty_list())), v
            ),
            # exactly three boxes
            relay.Clause(
                cons_of(
                    any_box(),
                    cons_of(any_box(), cons_of(any_box(), empty_list())),
                ),
                v,
            ),
            # one or more boxes
            relay.Clause(any_cons(), v),
            # no boxes
            relay.Clause(empty_list(), v),
        ],
    )
    assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
            nil_found = True
        if case.constructor == p.cons:
            assert isinstance(case.patterns[1], relay.PatternConstructor)
            assert case.patterns[1].constructor == p.nil
            single_length_found = True
    assert nil_found and single_length_found

    # if we add a wildcard, this should work
    new_match = relay.Match(
        v,
        [
            relay.Clause(
                relay.PatternConstructor(
                    p.cons,
                    [
                        relay.PatternWildcard(),
                        relay.PatternConstructor(
                            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
                        ),
                    ],
                ),
                v,
            ),
            relay.Clause(relay.PatternWildcard(), v),
        ],
    )
    assert len(unmatched_cases(new_match, mod)) == 0


def test_multiple_constructor_clauses():
    mod = tvm.IRModule()
Example #24
0
            nil_found = True
        if case.constructor == cons:
            assert isinstance(case.patterns[1], relay.PatternConstructor)
            assert case.patterns[1].constructor == nil
            single_length_found = True
    assert nil_found and single_length_found

    # if we add a wildcard, this should work
    new_match = relay.Match(
        v,
        [
            relay.Clause(
                relay.PatternConstructor(
                    cons,
                    [
                        relay.PatternWildcard(),
                        relay.PatternConstructor(
                            cons,
                            [relay.PatternWildcard(),
                             relay.PatternWildcard()]),
                    ],
                ),
                v,
            ),
            relay.Clause(relay.PatternWildcard(), v),
        ],
    )
    assert len(unmatched_cases(new_match, mod)) == 0


def test_multiple_constructor_clauses():
def test_match_sequal():
    """Compare ``relay.Match`` expressions over prelude lists for equality.

    A reference match (one nil clause, one cons clause) must be equal under
    ``consistent_equal`` to itself and to a copy whose cons clause binds
    differently named variables. It must not be equal to any of the other
    variants built below.
    """
    mod = tvm.IRModule()
    p = relay.prelude.Prelude(mod)

    def rebuild_cons_clause(head, tail):
        # Clause that binds both cons fields and returns the same cons cell.
        fields = [relay.PatternVar(head), relay.PatternVar(tail)]
        return relay.Clause(
            relay.PatternConstructor(p.cons, fields), p.cons(head, tail)
        )

    x = relay.Var("x")
    y = relay.Var("y")
    nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil())
    cons_case = rebuild_cons_clause(x, y)

    # Same shape as cons_case; only the bound variables differ.
    z = relay.Var("z")
    a = relay.Var("a")
    equivalent_cons = rebuild_cons_clause(z, a)

    # Scrutinee shared by most of the matches below.
    first_item = relay.const(1)
    remaining = p.cons(relay.const(2), p.nil())
    data = p.cons(first_item, remaining)

    match = relay.Match(data, [nil_case, cons_case])
    equivalent = relay.Match(data, [nil_case, equivalent_cons])

    # Variants with clauses removed, reordered, or a different scrutinee.
    empty = relay.Match(data, [])
    no_cons = relay.Match(data, [nil_case])
    no_nil = relay.Match(data, [cons_case])
    different_data = relay.Match(p.nil(), [nil_case, cons_case])
    different_order = relay.Match(data, [cons_case, nil_case])

    # Variants whose clauses have the same constructors but other contents.
    nil_with_other_body = relay.Clause(
        relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())
    )
    different_nil = relay.Match(data, [nil_with_other_body, cons_case])

    cons_with_wildcards = relay.Clause(
        relay.PatternConstructor(
            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
        ),
        p.nil(),
    )
    different_cons = relay.Match(data, [nil_case, cons_with_wildcards])

    trailing_wildcard = relay.Clause(relay.PatternWildcard(), p.nil())
    another_case = relay.Match(data, [nil_case, cons_case, trailing_wildcard])

    # Variant whose patterns use the prelude's none/some constructors instead.
    none_clause = relay.Clause(relay.PatternConstructor(p.none), p.nil())
    some_clause = relay.Clause(
        relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
        p.cons(x, p.nil()),
    )
    wrong_constructors = relay.Match(data, [none_clause, some_clause])

    tvm.ir.assert_structural_equal(match, match)
    assert consistent_equal(match, match)
    assert consistent_equal(match, equivalent)
    assert not consistent_equal(match, no_cons)
    assert not consistent_equal(match, no_nil)
    assert not consistent_equal(match, empty)
    assert not consistent_equal(match, different_data)
    assert not consistent_equal(match, different_order)
    assert not consistent_equal(match, different_nil)
    assert not consistent_equal(match, different_cons)
    assert not consistent_equal(match, another_case)
    assert not consistent_equal(match, wrong_constructors)
    assert len(unmatched) == 2
    for case in unmatched:
        assert isinstance(case, relay.PatternConstructor)
        if case.constructor == p.nil:
            nil_found = True
        if case.constructor == p.cons:
            assert isinstance(case.patterns[1], relay.PatternConstructor)
            assert case.patterns[1].constructor == p.nil
            single_length_found = True
    assert nil_found and single_length_found

    # if we add a wildcard, this should work
    new_match = relay.Match(v, [
        relay.Clause(
            relay.PatternConstructor(
                p.cons, [relay.PatternWildcard(),
                         relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
                                                           relay.PatternWildcard()])]), v),
        relay.Clause(relay.PatternWildcard(), v)
    ])
    assert len(unmatched_cases(new_match, mod)) == 0


def test_multiple_constructor_clauses():
    mod = tvm.IRModule()
    p = Prelude(mod)

    v = relay.Var('v')
    match = relay.Match(v, [
        # list of length exactly 1
        relay.Clause(