Exemplo n.º 1
0
def test_multiple_constructor_clauses():
    """Clauses for lengths 1, 2, 0 and >= 2 together leave no list unmatched."""
    mod = tvm.IRModule()
    p = Prelude(mod)

    def nil_pat():
        return relay.PatternConstructor(p.nil, [])

    def cons_pat(tail):
        # a cons cell whose head is ignored and whose tail must match `tail`
        return relay.PatternConstructor(p.cons, [relay.PatternWildcard(), tail])

    v = relay.Var('v')
    clauses = [
        relay.Clause(cons_pat(nil_pat()), v),                          # length exactly 1
        relay.Clause(cons_pat(cons_pat(nil_pat())), v),                # length exactly 2
        relay.Clause(nil_pat(), v),                                    # empty list
        relay.Clause(cons_pat(cons_pat(relay.PatternWildcard())), v),  # length 2 or more
    ]
    match = relay.Match(v, clauses)
    assert len(unmatched_cases(match, mod)) == 0
Exemplo n.º 2
0
def test_nested_matches():
    """`flatten`, written as a match nested inside a match, concatenates a list of lists."""
    a = relay.TypeVar("a")
    # TODO(@jroesch): inference should be able to handle this one
    x = relay.Var("x", type_annotation=rlist(rlist(a)))
    y = relay.Var("y")
    w = relay.Var("w")
    h = relay.Var("h")
    t = relay.Var("t")
    flatten = relay.GlobalVar("flatten")

    def split(head, tail):
        # cons(head, tail) pattern that binds both fields
        return relay.PatternConstructor(cons, [relay.PatternVar(head), relay.PatternVar(tail)])

    # flatten could be written using a fold, but this way has nested matches.
    # Outer match: x = cons(y, w). Inner match peels y: cons(h, t) -> cons(h, flatten(cons(t, w))).
    inner_match = relay.Match(y, [
        relay.Clause(relay.PatternConstructor(nil), flatten(w)),
        relay.Clause(split(h, t), cons(h, flatten(cons(t, w)))),
    ])
    outer_match = relay.Match(x, [
        relay.Clause(relay.PatternConstructor(nil), nil()),
        relay.Clause(split(y, w), inner_match),
    ])
    prelude.mod[flatten] = relay.Function([x], outer_match, rlist(a), [a])

    def nat_list(*values):
        # build the nats in order, then fold from the right so values[0] is the head
        nats = [make_nat_expr(prelude, value) for value in values]
        result = nil()
        for nat_expr in reversed(nats):
            result = cons(nat_expr, result)
        return result

    final_list = cons(nat_list(1, 2, 3), cons(nat_list(4, 5, 6), nil()))
    flat = to_list(intrp.evaluate(flatten(final_list)))

    assert len(flat) == 6
    for expected, item in enumerate(flat, start=1):
        assert count(item) == expected
Exemplo n.º 3
0
def test_match_effect_exactly_once():
    """The match scrutinee (which has a side effect) is evaluated exactly once."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    _, cons, nil = p.mod.get_type("List")

    # the list should be of length 1!
    # Unless we mistakenly execute the data clause more than once
    r = relay.Var("r")
    push_unit = relay.RefWrite(r, cons(relay.Tuple([]), relay.RefRead(r)))
    data = seq(push_unit, relay.RefRead(r))

    empty_clause = relay.Clause(relay.PatternConstructor(nil, []), relay.const(0))
    singleton = relay.PatternConstructor(
        cons, [relay.PatternWildcard(), relay.PatternConstructor(nil, [])])
    singleton_clause = relay.Clause(singleton, relay.const(1))
    fallback_clause = relay.Clause(relay.PatternWildcard(), relay.const(2))

    body = relay.Match(data, [empty_clause, singleton_clause, fallback_clause])
    match = relay.Let(r, relay.RefCreate(nil()), body)

    assert_tensor_value(run_as_python(match, mod), 1)
Exemplo n.º 4
0
def test_global_recursion():
    """A self-recursive global `copy` function rebuilds its input list."""
    mod = relay.Module()
    p = Prelude(mod)
    copy = relay.GlobalVar('copy')
    # it copies the given list
    a = relay.TypeVar('a')
    v = relay.Var('v', p.l(a))
    h = relay.Var('h')
    t = relay.Var('t')

    cons_case = relay.Clause(
        relay.PatternConstructor(p.cons, [relay.PatternVar(h), relay.PatternVar(t)]),
        p.cons(h, copy(t)))
    nil_case = relay.Clause(relay.PatternConstructor(p.nil, []), p.nil())
    copy_def = relay.Function([v], relay.Match(v, [cons_case, nil_case]), p.l(a), [a])
    mod[copy] = copy_def

    # copying [1, 2] yields cons(1, cons(2, nil))
    node = run_as_python(copy_def(p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))), mod)
    for expected in (1, 2):
        assert_constructor_value(node, p.cons, 2)
        assert_tensor_value(node.fields[0], expected)
        node = node.fields[1]
    assert_constructor_value(node, p.nil, 0)

    # copying [()] keeps the empty-tuple head
    single = run_as_python(copy_def(p.cons(relay.Tuple([]), p.nil())), mod)
    assert_constructor_value(single, p.cons, 2)
    assert_adt_len(single.fields[0], 0)
    assert_constructor_value(single.fields[1], p.nil, 0)
Exemplo n.º 5
0
def test_optional_matching():
    """foldr with an Option-matching step unwraps `some` values and skips `none`."""
    x = relay.Var("x")
    y = relay.Var("y")
    v = relay.Var("v")
    keep_payload = relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y))
    skip_none = relay.Clause(relay.PatternConstructor(none), y)
    condense = relay.Function([x, y], relay.Match(x, [keep_payload, skip_none]))

    options = cons(some(make_nat_expr(3)), cons(none(), cons(some(make_nat_expr(1)), nil())))
    reduced = to_list(intrp.evaluate(foldr(condense, nil(), options)))

    assert len(reduced) == 2
    assert count(reduced[0]) == 3
    assert count(reduced[1]) == 1
Exemplo n.º 6
0
def test_match_order():
    """Clauses are tried in order, so a leading wildcard clause wins."""
    mod = tvm.IRModule()
    box, box_ctor = init_box_adt(mod)
    v = relay.Var("v")
    w = relay.Var("w")

    # wildcard pattern goes first
    wildcard_first = relay.Clause(relay.PatternWildcard(), relay.const(1))
    nested_box = relay.PatternConstructor(
        box_ctor, [relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])])
    unwrap_twice = relay.Clause(nested_box, w)

    match = relay.Let(
        v,
        box_ctor(box_ctor(relay.const(2))),
        relay.Match(v, [wildcard_first, unwrap_twice]),
    )
    assert_tensor_value(run_as_python(match, mod), 1)
Exemplo n.º 7
0
def test_nested_pattern_match():
    """A nested cons pattern extracts the second element of a list."""
    x = relay.Var("x", l(nat()))
    h1 = relay.Var("h1")
    h2 = relay.Var("h2")
    t = relay.Var("t")

    tail_pattern = relay.PatternConstructor(cons, [relay.PatternVar(h2), relay.PatternVar(t)])
    two_or_more = relay.PatternConstructor(cons, [relay.PatternVar(h1), tail_pattern])
    get_second = relay.Function(
        [x],
        relay.Match(x, [
            relay.Clause(two_or_more, h2),
            relay.Clause(relay.PatternWildcard(), z()),
        ]),
    )

    one = s(z())
    two = s(s(z()))
    res = intrp.evaluate(get_second(cons(one, cons(two, nil()))))
    assert count(res) == 2
Exemplo n.º 8
0
def test_single_constructor_adt():
    """For a one-constructor ADT, a constructor pattern (even nested) is exhaustive."""
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    a = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    def boxed(inner):
        return relay.PatternConstructor(box_ctor, [inner])

    v = relay.Var("v")

    # with one constructor, having one pattern constructor case is exhaustive
    match = relay.Match(v, [relay.Clause(boxed(relay.PatternWildcard()), v)])
    assert len(unmatched_cases(match, mod)) == 0

    # this will be so if we nest the constructors too: box(box(box(_)))
    pattern = relay.PatternWildcard()
    for _ in range(3):
        pattern = boxed(pattern)
    nested_pattern = relay.Match(v, [relay.Clause(pattern, v)])
    assert len(unmatched_cases(nested_pattern, mod)) == 0
Exemplo n.º 9
0
def test_filter():
    """`filter` has the expected polymorphic type and keeps only nats greater than one."""
    a = relay.TypeVar("a")
    predicate_type = relay.FuncType([a], relay.scalar_type("bool"))
    expected_type = relay.FuncType([predicate_type, l(a)], l(a), [a])
    assert mod[filter].checked_type == expected_type

    x = relay.Var("x", nat())
    # s(s(_)) matches every nat >= 2
    at_least_two = relay.PatternConstructor(
        s, [relay.PatternConstructor(s, [relay.PatternWildcard()])])
    greater_than_one = relay.Function(
        [x],
        relay.Match(x, [
            relay.Clause(at_least_two, relay.const(True)),
            relay.Clause(relay.PatternWildcard(), relay.const(False)),
        ]))

    items = nil()
    for value in reversed([1, 1, 3, 1, 5, 1]):
        items = cons(make_nat_expr(value), items)

    filtered = to_list(intrp.evaluate(filter(greater_than_one, items)))
    assert len(filtered) == 2
    assert count(filtered[0]) == 3
    assert count(filtered[1]) == 5
Exemplo n.º 10
0
def test_local_recursion():
    """A let-bound recursive function rebuilds its input list unchanged."""
    mod = relay.Module()
    p = Prelude(mod)

    v = relay.Var('v')
    h = relay.Var('h')
    t = relay.Var('t')
    f = relay.Var('f')

    # just returns the same list
    body = relay.Match(v, [
        relay.Clause(
            relay.PatternConstructor(p.cons, [relay.PatternVar(h), relay.PatternVar(t)]),
            p.cons(h, f(t))),
        relay.Clause(relay.PatternConstructor(p.nil, []), p.nil()),
    ])
    input_list = p.cons(relay.const(1),
                        p.cons(relay.const(2), p.cons(relay.const(3), p.nil())))
    let = relay.Let(f, relay.Function([v], body), f(input_list))

    # walk the result: three cons cells holding 1, 2, 3, then nil
    node = run_as_python(let, mod)
    for expected in (1, 2, 3):
        assert_constructor_value(node, p.cons, 2)
        assert_tensor_value(node.fields[0], expected)
        node = node.fields[1]
    assert_constructor_value(node, p.nil, 0)
def test_match_vars():
    """bound_vars / free_vars / all_vars account for pattern-bound variables in a Match."""
    mod = tvm.IRModule()
    p = relay.prelude.Prelude(mod)

    x, y, z = relay.Var('x'), relay.Var('y'), relay.Var('z')

    # x and y are bound by the cons pattern; z is only referenced
    match1 = relay.Match(p.nil(), [
        relay.Clause(relay.PatternConstructor(p.nil), z),
        relay.Clause(
            relay.PatternConstructor(p.cons, [relay.PatternVar(x), relay.PatternVar(y)]),
            p.cons(x, y)),
    ])
    # only x is bound; y and z are referenced in clause bodies
    match2 = relay.Match(p.nil(), [
        relay.Clause(
            relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternVar(x)]),
            y),
        relay.Clause(relay.PatternWildcard(), z),
    ])

    expectations = [
        (match1, [x, y], [z], [z, x, y]),
        (match2, [x], [y, z], [x, y, z]),
    ]
    for expr, bound, free, every in expectations:
        assert_vars_match(bound_vars(expr), bound)
        assert_vars_match(free_vars(expr), free)
        assert_vars_match(all_vars(expr), every)
Exemplo n.º 12
0
def test_unfoldl():
    """`unfoldl` has the expected polymorphic type and unfolds seed 3 into [1, 2, 3].

    `count_down` maps s(n) to some((n, x)), which emits x and continues from n.
    It maps z to none, which stops the unfold.
    """
    a = relay.TypeVar("a")
    b = relay.TypeVar("b")
    # unfoldl : fn[a, b](fn(a) -> Option[(a, b)], a) -> List[b]
    expected_type = relay.FuncType(
        [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], l(b), [a, b]
    )
    # expected_type was previously built but never checked; verify the inferred signature
    assert mod[unfoldl].checked_type == expected_type

    x = relay.Var("x", nat())
    n = relay.Var("n", nat())
    count_down = relay.Function(
        [x],
        relay.Match(
            x,
            [
                relay.Clause(
                    relay.PatternConstructor(s, [relay.PatternVar(n)]), some(relay.Tuple([n, x]))
                ),
                relay.Clause(relay.PatternConstructor(z, []), none()),
            ],
        ),
    )

    res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3)))
    unfolded = to_list(res)

    assert len(unfolded) == 3
    assert count(unfolded[0]) == 1
    assert count(unfolded[1]) == 2
    assert count(unfolded[2]) == 3
Exemplo n.º 13
0
def test_match():
    """Parsing `match` / `match?` yields a Match whose `complete` flag fits the keyword."""
    # pair each match keyword with whether it specifies a complete match or not
    for match_keyword, is_complete in (("match", True), ("match?", False)):
        mod = tvm.IRModule()

        # List[A] = Cons(A, List[A]) | Nil
        list_var = relay.GlobalTypeVar("List")
        elem_var = relay.TypeVar("A")
        cons_constructor = relay.Constructor("Cons", [elem_var, list_var(elem_var)], list_var)
        nil_constructor = relay.Constructor("Nil", [], list_var)
        mod[list_var] = relay.TypeData(list_var, [elem_var], [cons_constructor, nil_constructor])

        # @length[A](%xs: List[A]) -> int32, using its own fresh type variable A
        length_var = relay.GlobalVar("length")
        typ_var = relay.TypeVar("A")
        input_var = relay.Var("xs", list_var(typ_var))
        rest_var = relay.Var("rest")

        # Cons(_, %rest) => { (); 1 + @length(%rest) }
        recurse = relay.add(relay.const(1), relay.Call(length_var, [rest_var]))
        cons_case = relay.Let(relay.var("", type_annotation=None), UNIT, recurse)
        cons_clause = relay.Clause(
            relay.PatternConstructor(
                cons_constructor, [relay.PatternWildcard(), relay.PatternVar(rest_var)]),
            cons_case,
        )
        # Nil => 0
        nil_clause = relay.Clause(relay.PatternConstructor(nil_constructor, []), relay.const(0))

        body = relay.Match(input_var, [cons_clause, nil_clause], complete=is_complete)
        mod[length_var] = relay.Function([input_var], body, int32, [typ_var])

        assert_parse_module_as(
            """
            %s

            def @length[A](%%xs: List[A]) -> int32 {
              %s (%%xs) {
                Cons(_, %%rest : List[A]) => {
                  ();
                  1 + @length(%%rest)
                },
                Nil => 0,
              }
            }
            """ % (LIST_DEFN, match_keyword),
            mod,
        )
Exemplo n.º 14
0
def test_trivial_matches():
    """A single wildcard clause, or a single pattern-variable clause, is exhaustive."""
    v = relay.Var("v")
    w = relay.Var("w")
    # a match clause with a wildcard will match anything; same with a pattern var
    trivial_clauses = (
        relay.Clause(relay.PatternWildcard(), v),
        relay.Clause(relay.PatternVar(w), w),
    )
    for clause in trivial_clauses:
        assert len(unmatched_cases(relay.Match(v, [clause]))) == 0
def test_tuple_match():
    """Two identically shaped tuple matches built from fresh variables compare equal."""
    def build():
        # new Vars on every call, so equality has to hold up to variable identity
        a = relay.Var("a")
        b = relay.Var("b")
        pair = relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)])
        return relay.Match(relay.Tuple([relay.const(1), relay.const(1)]),
                           [relay.Clause(pair, a + b)])

    x = build()
    y = build()
    assert consistent_equal(x, y)
Exemplo n.º 16
0
def test_adt():
    """A match whose two clauses both bind `x` is expected to be ill-formed."""
    mod = relay.Module()
    p = Prelude(mod)
    x = relay.Var("x")
    binds_succ = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x)
    binds_all = relay.Clause(relay.PatternVar(x), x)

    only_default = relay.Match(p.z(), [binds_all])
    both_bind_x = relay.Match(p.z(), [binds_succ, binds_all])
    assert well_formed(only_default)
    assert not well_formed(both_bind_x)
Exemplo n.º 17
0
def test_tuple_match():
    """Identically shaped tuple matches are alpha-equal and hash identically."""
    def build():
        # new Vars on every call, so equality has to hold up to renaming
        a = relay.Var("a")
        b = relay.Var("b")
        pair = relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)])
        return relay.Match(relay.Tuple([relay.const(1), relay.const(1)]),
                           [relay.Clause(pair, a + b)])

    first = build()
    second = build()
    assert analysis.alpha_equal(first, second)
    assert analysis.structural_hash(first) == analysis.structural_hash(second)
Exemplo n.º 18
0
def test_missing_in_the_middle():
    """Clauses for lengths 1, 0 and >= 3 leave exactly the length-2 list unmatched."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    _, cons, nil = mod.get_type("List")

    def skip_head(tail):
        # cons cell with an ignored head whose tail must match `tail`
        return relay.PatternConstructor(cons, [relay.PatternWildcard(), tail])

    v = relay.Var("v")
    patterns = [
        skip_head(relay.PatternConstructor(nil, [])),             # list of length exactly 1
        relay.PatternConstructor(nil, []),                        # empty list
        skip_head(skip_head(skip_head(relay.PatternWildcard()))),  # list of length 3 or more
    ]
    match = relay.Match(v, [relay.Clause(pattern, v) for pattern in patterns])

    # fails to match a list of length exactly two: cons(_, cons(_, nil))
    unmatched = unmatched_cases(match, mod)
    assert len(unmatched) == 1
    first_cell = unmatched[0]
    assert isinstance(first_cell, relay.PatternConstructor)
    assert first_cell.constructor == cons
    second_cell = first_cell.patterns[1]
    assert isinstance(second_cell, relay.PatternConstructor)
    assert second_cell.constructor == cons
    terminator = second_cell.patterns[1]
    assert isinstance(terminator, relay.PatternConstructor)
    assert terminator.constructor == nil
Exemplo n.º 19
0
def test_adt():
    """A match whose two clauses both bind `x` is expected to be ill-formed."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    _, none, some = p.mod.get_type("Option")
    x = relay.Var("x")
    binds_some = relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(x)]), x)
    binds_all = relay.Clause(relay.PatternVar(x), x)

    only_default = relay.Match(none(), [binds_all])
    both_bind_x = relay.Match(none(), [binds_some, binds_all])
    assert well_formed(only_default)
    assert not well_formed(both_bind_x)
Exemplo n.º 20
0
def test_multiple_constructor_clauses():
    """Clauses for lengths 1, 2, 0 and >= 2 together leave no list unmatched."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    _, cons, nil = mod.get_type("List")

    def skip_head(tail):
        # cons cell with an ignored head whose tail must match `tail`
        return relay.PatternConstructor(cons, [relay.PatternWildcard(), tail])

    def end():
        return relay.PatternConstructor(nil, [])

    v = relay.Var("v")
    patterns = [
        skip_head(end()),                               # list of length exactly 1
        skip_head(skip_head(end())),                    # list of length exactly 2
        end(),                                          # empty list
        skip_head(skip_head(relay.PatternWildcard())),  # list of length 2 or more
    ]
    match = relay.Match(v, [relay.Clause(pattern, v) for pattern in patterns])
    assert len(unmatched_cases(match, mod)) == 0
Exemplo n.º 21
0
def test_match_alpha_equal():
    """alpha_equal on Match ignores bound-variable names but not structure or order."""
    mod = relay.Module()
    p = relay.prelude.Prelude(mod)

    def cons_clause(head, tail):
        # cons(head, tail) => cons(head, tail)
        pattern = relay.PatternConstructor(
            p.cons, [relay.PatternVar(head), relay.PatternVar(tail)])
        return relay.Clause(pattern, p.cons(head, tail))

    x = relay.Var('x')
    y = relay.Var('y')
    nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil())
    cons_case = cons_clause(x, y)

    z = relay.Var('z')
    a = relay.Var('a')
    equivalent_cons = cons_clause(z, a)

    data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))

    match = relay.Match(data, [nil_case, cons_case])
    equivalent = relay.Match(data, [nil_case, equivalent_cons])

    nil_returns_nested = relay.Clause(
        relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil()))
    cons_discards = relay.Clause(
        relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]),
        p.nil())
    option_clauses = [
        relay.Clause(relay.PatternConstructor(p.none), p.nil()),
        relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
                     p.cons(x, p.nil())),
    ]
    not_equivalent = [
        relay.Match(data, [nil_case]),                          # no cons clause
        relay.Match(data, [cons_case]),                         # no nil clause
        relay.Match(data, []),                                  # no clauses at all
        relay.Match(p.nil(), [nil_case, cons_case]),            # different scrutinee
        relay.Match(data, [cons_case, nil_case]),               # clauses reordered
        relay.Match(data, [nil_returns_nested, cons_case]),     # different nil body
        relay.Match(data, [nil_case, cons_discards]),           # different cons clause
        relay.Match(data, [nil_case, cons_case,
                           relay.Clause(relay.PatternWildcard(), p.nil())]),  # extra clause
        relay.Match(data, option_clauses),                      # Option constructors instead
    ]

    assert alpha_equal(match, match)
    assert alpha_equal(match, equivalent)
    for other in not_equivalent:
        assert not alpha_equal(match, other)
Exemplo n.º 22
0
def test_adt_match():
    """Type inference gives a match over a box ADT the type of its clause bodies, ()."""
    mod = relay.Module()
    box, constructor = initialize_box_adt(mod)

    v = relay.Var('v', relay.TensorType((), 'float32'))
    unbox = relay.Clause(
        relay.PatternConstructor(constructor, [relay.PatternVar(v)]),
        relay.Tuple([]))
    # redundant but shouldn't matter to typechecking
    fallback = relay.Clause(relay.PatternWildcard(), relay.Tuple([]))
    scrutinee = constructor(relay.const(0, 'float32'))

    typed = relay.ir_pass.infer_type(relay.Match(scrutinee, [unbox, fallback]), mod)
    assert typed.checked_type == relay.TupleType([])
def test_too_specific_match():
    """A lone cons(_, cons(_, _)) clause misses exactly nil and the length-1 list."""
    mod = tvm.IRModule()
    p = Prelude(mod)

    v = relay.Var('v')
    match = relay.Match(v, [
        relay.Clause(
            relay.PatternConstructor(p.cons, [
                relay.PatternWildcard(),
                relay.PatternConstructor(
                    p.cons, [relay.PatternWildcard(),
                             relay.PatternWildcard()])
            ]), v)
    ])

    unmatched = unmatched_cases(match, mod)

    # will not match nil or a list of length 1
    nil_found = False
    single_length_found = False
    assert len(unmatched) == 2
    for case in unmatched:
        assert isinstance(case, relay.PatternConstructor)
        if case.constructor == p.nil:
            nil_found = True
        if case.constructor == p.cons:
            # the length-1 witness is cons(_, nil)
            assert isinstance(case.patterns[1], relay.PatternConstructor)
            assert case.patterns[1].constructor == p.nil
            single_length_found = True
    # Both witnesses must be present. Without this check, two cons(_, nil)
    # cases (and no nil case) would also have passed.
    assert nil_found and single_length_found
Exemplo n.º 24
0
def test_nested_matches():
    """`flatten`, written as a match nested inside a match, concatenates a list of lists."""
    a = relay.TypeVar("a")
    x = relay.Var("x")
    y = relay.Var("y")
    w = relay.Var("w")
    h = relay.Var("h")
    t = relay.Var("t")
    flatten = relay.GlobalVar("flatten")

    # flatten could be written using a fold, but this way has nested matches:
    # the outer match splits x into its first inner list y and the rest w;
    # the inner match moves y's head out front and recurses on cons(t, w).
    inner_clauses = [
        relay.Clause(relay.PatternConstructor(nil), flatten(w)),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(h), relay.PatternVar(t)]),
            cons(h, flatten(cons(t, w))),
        ),
    ]
    outer_clauses = [
        relay.Clause(relay.PatternConstructor(nil), nil()),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(y), relay.PatternVar(w)]),
            relay.Match(y, inner_clauses),
        ),
    ]
    mod[flatten] = relay.Function([x], relay.Match(x, outer_clauses), l(a), [a])

    def nat_list(values):
        # build the nats in order, then fold from the right so values[0] is the head
        nats = [make_nat_expr(value) for value in values]
        result = nil()
        for nat_expr in reversed(nats):
            result = cons(nat_expr, result)
        return result

    final_list = cons(nat_list([1, 2, 3]), cons(nat_list([4, 5, 6]), nil()))
    flat = to_list(intrp.evaluate(flatten(final_list)))

    assert len(flat) == 6
    for index, item in enumerate(flat):
        assert count(item) == index + 1
Exemplo n.º 25
0
def test_tuple_match():
    """`dcpe` of a tuple match on the constant (1, 1) is structurally the constant 2."""
    a, b = relay.Var("a"), relay.Var("b")
    destructure = relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)])
    scrutinee = relay.Tuple([relay.const(1), relay.const(1)])
    x = relay.Match(scrutinee, [relay.Clause(destructure, a + b)])
    tvm.ir.assert_structural_equal(dcpe(x), const(2))
def test_tuple_match():
    """A tuple pattern made only of variables is exhaustive."""
    a, b = relay.Var("a"), relay.Var("b")
    pair = relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)])
    scrutinee = relay.Tuple([relay.const(1), relay.const(1)])
    x = relay.Match(scrutinee, [relay.Clause(pair, a + b)])
    assert len(unmatched_cases(x)) == 0
Exemplo n.º 27
0
def test_wildcard_match_solo():
    """A function whose only clause is a wildcard returns its argument."""
    x = relay.Var('x', nat())
    passthrough = relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)])
    copy = relay.Function([x], passthrough, nat())

    three = s(s(s(z())))
    assert count(intrp.evaluate(copy(three))) == 3
Exemplo n.º 28
0
def test_wildcard_match_order():
    """A leading wildcard clause shadows the later cons/nil clauses."""
    x = relay.Var('x', l(nat()))
    y = relay.Var('y')
    a = relay.Var('a')
    clauses = [
        relay.Clause(relay.PatternWildcard(), z()),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(y), relay.PatternVar(a)]), y),
        relay.Clause(relay.PatternConstructor(nil), s(z())),
    ]
    return_zero = relay.Function([x], relay.Match(x, clauses), nat())

    res = intrp.evaluate(return_zero(cons(s(z()), nil())))
    # wildcard pattern is evaluated first
    assert count(res) == 0
Exemplo n.º 29
0
def test_global_recursion():
    """A self-recursive global `copy` function rebuilds its input list."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    rlist, cons, nil = p.mod.get_type("List")

    copy = relay.GlobalVar("copy")
    # it copies the given list
    a = relay.TypeVar("a")
    v = relay.Var("v", rlist(a))
    h = relay.Var("h")
    t = relay.Var("t")
    split = relay.PatternConstructor(cons, [relay.PatternVar(h), relay.PatternVar(t)])
    body = relay.Match(v, [
        relay.Clause(split, cons(h, copy(t))),
        relay.Clause(relay.PatternConstructor(nil, []), nil()),
    ])
    copy_def = relay.Function([v], body, rlist(a), [a])
    mod[copy] = copy_def

    # copying [1, 2] yields cons(1, cons(2, nil))
    node = run_as_python(copy_def(cons(relay.const(1), cons(relay.const(2), nil()))), mod)
    for expected in (1, 2):
        assert_constructor_value(node, cons, 2)
        assert_tensor_value(node.fields[0], expected)
        node = node.fields[1]
    assert_constructor_value(node, nil, 0)

    # copying [()] keeps the empty-tuple head
    single = run_as_python(copy_def(cons(relay.Tuple([]), nil())), mod)
    assert_constructor_value(single, cons, 2)
    assert_adt_len(single.fields[0], 0)
    assert_constructor_value(single.fields[1], nil, 0)
Exemplo n.º 30
0
def test_optional_matching():
    """foldr with an Option-matching step unwraps `some` values and skips `none`."""
    x = relay.Var('x')
    y = relay.Var('y')
    v = relay.Var('v')
    clauses = [
        relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y)),
        relay.Clause(relay.PatternConstructor(none), y),
    ]
    condense = relay.Function([x, y], relay.Match(x, clauses))

    options = [some(build_nat(3)), none(), some(build_nat(1))]
    input_list = nil()
    for option in reversed(options):
        input_list = cons(option, input_list)

    reduced = to_list(intrp.evaluate(foldr(condense, nil(), input_list)))
    assert len(reduced) == 2
    assert count(reduced[0]) == 3
    assert count(reduced[1]) == 1