示例#1
0
def test_typecall_kind():
    """A TypeCall on a registered ADT must kind-check as ``TypeKind.Type``.

    Covers a nullary ADT called with no arguments and a unary ADT called
    with a single (empty tuple) type argument.
    """
    adt_var = relay.GlobalTypeVar("gtv")

    # Nullary ADT: no type parameters, no constructors.
    nullary_mod = tvm.IRModule()
    nullary_mod[adt_var] = relay.TypeData(adt_var, [], [])
    assert check_kind(relay.TypeCall(adt_var, []), nullary_mod) == relay.TypeKind.Type

    # Unary ADT: one type parameter, instantiated with the unit tuple type.
    unary_mod = tvm.IRModule()
    param = relay.TypeVar("tv")
    unary_mod[adt_var] = relay.TypeData(adt_var, [param], [])
    unit_call = relay.TypeCall(adt_var, [relay.TupleType([])])
    assert check_kind(unit_call, unary_mod) == relay.TypeKind.Type
示例#2
0
def initialize_box_adt(mod):
    """Register a single-constructor ``box`` ADT in *mod*.

    Returns:
        A ``(global_type_var, constructor)`` pair for the new ADT.
    """
    box_var = relay.GlobalTypeVar('box')
    type_param = relay.TypeVar('tv')
    ctor = relay.Constructor('constructor', [type_param], box_var)
    mod[box_var] = relay.TypeData(box_var, [type_param], [ctor])
    return box_var, ctor
示例#3
0
def test_typecall_invalid_num_args():
    """Kind-checking a TypeCall with too few type arguments must fail.

    ``v1`` is declared with one type parameter but is called with none.
    The module holding that declaration is passed to ``check_kind`` so the
    checker can see the arity of ``v1``. Without it, the checker presumably
    has no definition of ``v1`` at all, and the test would fail for an
    unrelated reason.
    """
    mod = relay.Module()
    gtv = relay.GlobalTypeVar('v1')
    tv = relay.TypeVar('tv')
    data = relay.TypeData(gtv, [tv], [])
    mod[gtv] = data
    # The call itself is the check: it is expected to raise. Presumably the
    # test is marked as an expected failure where it is declared (TODO confirm).
    check_kind(relay.TypeCall(gtv, []), mod)
示例#4
0
def test_typecall_invalid_num_args():
    """Kind-checking a TypeCall with too few type arguments must fail.

    ``v1`` is declared with one type parameter but is called with none.
    The module holding that declaration is passed to ``check_kind`` so the
    checker can see the arity of ``v1``. Without it, the checker presumably
    has no definition of ``v1`` at all, and the test would fail for an
    unrelated reason.
    """
    mod = tvm.IRModule()
    gtv = relay.GlobalTypeVar("v1")
    tv = relay.TypeVar("tv")
    data = relay.TypeData(gtv, [tv], [])
    mod[gtv] = data
    # The call itself is the check: it is expected to raise. Presumably the
    # test is marked as an expected failure where it is declared (TODO confirm).
    check_kind(relay.TypeCall(gtv, []), mod)
def test_single_constructor_adt():
    """A single clause is exhaustive for a one-constructor ADT, even nested."""
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    a = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    v = relay.Var("v")

    def single_clause_match(pattern):
        # One clause that matches ``pattern`` and returns the scrutinee.
        return relay.Match(v, [relay.Clause(pattern, v)])

    # With one constructor, a lone box(_) clause is exhaustive.
    flat = single_clause_match(
        relay.PatternConstructor(box_ctor, [relay.PatternWildcard()])
    )
    assert len(unmatched_cases(flat, mod)) == 0

    # Nesting the constructor keeps it exhaustive: box(box(box(_))).
    nested = relay.PatternWildcard()
    for _ in range(3):
        nested = relay.PatternConstructor(box_ctor, [nested])
    assert len(unmatched_cases(single_clause_match(nested), mod)) == 0
示例#6
0
def initialize_box_adt(mod):
    """Add a ``box`` ADT with one constructor to *mod*.

    Returns:
        Tuple of the ADT's global type variable and its constructor.
    """
    adt = relay.GlobalTypeVar("box")
    param = relay.TypeVar("tv")
    only_ctor = relay.Constructor("constructor", [param], adt)
    type_data = relay.TypeData(adt, [param], [only_ctor])
    mod[adt] = type_data
    return adt, only_ctor
示例#7
0
def test_adt_cons_expr():
    """Parse a function that builds a one-element list via ``Cons``/``Nil``."""
    mod = tvm.IRModule()

    # type List[A] { Cons(A, List[A]), Nil }
    list_var = relay.GlobalTypeVar("List")
    elem_ty = relay.TypeVar("A")
    cons = relay.Constructor("Cons", [elem_ty, list_var(elem_ty)], list_var)
    nil = relay.Constructor("Nil", [], list_var)
    mod[list_var] = relay.TypeData(list_var, [elem_ty], [cons, nil])

    # def @make_singleton(%x: int32) -> List[int32] { Cons(%x, Nil) }
    singleton_gv = relay.GlobalVar("make_singleton")
    x = relay.Var("x", int32)
    singleton_body = cons(x, nil())
    mod[singleton_gv] = relay.Function([x], singleton_body, list_var(int32))

    program = """
        %s

        def @make_singleton(%%x: int32) -> List[int32] {
          Cons(%%x, Nil)
        }
        """ % LIST_DEFN
    assert parses_as(program, mod)
示例#8
0
def test_match():
    """Parse ``match`` (complete) and ``match?`` (incomplete) expressions."""
    # Each keyword is paired with the ``complete`` flag it should produce.
    for match_keyword, is_complete in (("match", True), ("match?", False)):
        mod = tvm.IRModule()

        # type List[A] { Cons(A, List[A]), Nil }
        list_var = relay.GlobalTypeVar("List")
        adt_param = relay.TypeVar("A")
        cons = relay.Constructor("Cons", [adt_param, list_var(adt_param)], list_var)
        nil = relay.Constructor("Nil", [], list_var)
        mod[list_var] = relay.TypeData(list_var, [adt_param], [cons, nil])

        # @length gets its own TypeVar named "A", separate from the ADT's one.
        length_var = relay.GlobalVar("length")
        fn_param = relay.TypeVar("A")
        xs = relay.Var("xs", list_var(fn_param))
        rest = relay.Var("rest")

        # Cons(_, %rest) => { (); 1 + @length(%rest) }
        unit_binder = relay.var("", type_annotation=None)
        recurse = relay.add(relay.const(1), relay.Call(length_var, [rest]))
        cons_clause = relay.Clause(
            relay.PatternConstructor(
                cons, [relay.PatternWildcard(), relay.PatternVar(rest)]
            ),
            relay.Let(unit_binder, UNIT, recurse),
        )
        # Nil => 0
        nil_clause = relay.Clause(relay.PatternConstructor(nil, []), relay.const(0))

        body = relay.Match(xs, [cons_clause, nil_clause], complete=is_complete)
        mod[length_var] = relay.Function([xs], body, int32, [fn_param])

        expected_src = """
            %s

            def @length[A](%%xs: List[A]) -> int32 {
              %s (%%xs) {
                Cons(_, %%rest : List[A]) => {
                  ();
                  1 + @length(%%rest)
                },
                Nil => 0,
              }
            }
            """ % (LIST_DEFN, match_keyword)
        assert_parse_module_as(expected_src, mod)
示例#9
0
def test_typecall_invalid_args():
    """Kind-checking a TypeCall whose argument is not of Type kind must fail.

    The TypeData ``data`` is passed as a type argument, but TypeCall
    arguments must all be of Type kind. The module declaring ``v1`` is passed
    to ``check_kind``, as in the other TypeCall tests, instead of relying on
    ``check_kind``'s default (presumably empty) module.
    """
    # args must all be type kind
    mod = tvm.IRModule()
    gtv = relay.GlobalTypeVar("v1")
    data = relay.TypeData(gtv, [], [])
    mod[gtv] = data

    # Expected to raise; presumably marked as an expected failure where the
    # test is declared (TODO confirm).
    check_kind(relay.TypeCall(gtv, [data]), mod)
示例#10
0
def test_empty_adt_defn():
    """An ADT with no type params and no constructors parses from ``type Ayy { }``."""
    ayy = relay.GlobalTypeVar("Ayy")
    mod = tvm.IRModule()
    mod[ayy] = relay.TypeData(ayy, [], [])
    source = """
        type Ayy { }
        """
    assert parses_as(source, mod)
示例#11
0
def test_adt_defn():
    """A parameterless ADT with a single ``Nil`` constructor parses correctly."""
    ayy = relay.GlobalTypeVar("Ayy")
    nil = relay.Constructor("Nil", [], ayy)
    mod = tvm.IRModule()
    mod[ayy] = relay.TypeData(ayy, [], [nil])
    source = """
        type Ayy { Nil }
        """
    assert parses_as(source, mod)
示例#12
0
def test_multiple_cons_defn():
    """The ``List`` ADT with ``Cons`` and ``Nil`` parses from ``LIST_DEFN``."""
    list_var = relay.GlobalTypeVar("List")
    elem = relay.TypeVar("A")
    constructors = [
        relay.Constructor("Cons", [elem, list_var(elem)], list_var),
        relay.Constructor("Nil", [], list_var),
    ]
    mod = tvm.IRModule()
    mod[list_var] = relay.TypeData(list_var, [elem], constructors)
    assert parses_as(LIST_DEFN, mod)
示例#13
0
def test_id_type():
    """Applying a polymorphic ``make_id`` to a float32 var infers ``id[float32]``."""
    mod = relay.Module()
    id_type = relay.GlobalTypeVar("id")
    adt_param = relay.TypeVar("a")
    mod[id_type] = relay.TypeData(id_type, [adt_param], [])

    # make_id : fn<b>(b) -> id[b]
    poly = relay.TypeVar("b")
    make_id = relay.Var("make_id", relay.FuncType([poly], id_type(poly), [poly]))

    float_ty = relay.scalar_type("float32")
    arg = relay.Var("b", float_ty)
    inferred = relay.ir_pass.infer_type(make_id(arg), mod)
    assert inferred.checked_type == id_type(float_ty)
示例#14
0
def test_extern_adt_defn():
    """``extern type T[A]`` parses to a one-parameter ADT with no constructors."""
    extern_var = relay.GlobalTypeVar("T")
    param = relay.TypeVar("A")
    mod = tvm.IRModule()
    mod[extern_var] = relay.TypeData(extern_var, [param], [])
    source = """
        extern type T[A]
        """
    assert_parse_module_as(source, mod)
示例#15
0
def test_extern_adt_defn():
    """``extern type T[A]`` parses to a one-parameter ADT with no constructors."""
    # TODO(weberlo): update this test once extern is implemented
    extern_var = relay.GlobalTypeVar("T")
    param = relay.TypeVar("A")
    mod = tvm.IRModule()
    mod[extern_var] = relay.TypeData(extern_var, [param], [])
    source = """
        extern type T[A]
        """
    assert parses_as(source, mod)
示例#16
0
def test_id_type():
    """Type inference on ``make_id(b: float32)`` in the entry function yields ``id[float32]``."""
    mod = relay.Module()
    id_type = relay.GlobalTypeVar("id")
    adt_param = relay.TypeVar("a")
    mod[id_type] = relay.TypeData(id_type, [adt_param], [])

    # make_id : fn<b>(b) -> id[b]
    poly = relay.TypeVar("b")
    make_id = relay.Var("make_id", relay.FuncType([poly], id_type(poly), [poly]))
    float_ty = relay.scalar_type("float32")
    arg = relay.Var("b", float_ty)

    main_fn = relay.Function([], make_id(arg))
    mod[mod.entry_func] = main_fn
    # InferType returns a new module; read the entry function back from it.
    mod = transform.InferType()(mod)
    assert mod[mod.entry_func].body.checked_type == id_type(float_ty)
示例#17
0
def test_multiple_type_param_defn():
    """A two-parameter ``Either`` ADT with one constructor per parameter parses."""
    either = relay.GlobalTypeVar("Either")
    a_var, b_var = relay.TypeVar("A"), relay.TypeVar("B")
    left = relay.Constructor("Left", [a_var], either)
    right = relay.Constructor("Right", [b_var], either)
    mod = tvm.IRModule()
    mod[either] = relay.TypeData(either, [a_var, b_var], [left, right])
    source = """
        type Either[A, B] {
          Left(A),
          Right(B),
        }
        """
    assert parses_as(source, mod)
示例#18
0
def init_box_adt(mod):
    """Add a ``box`` ADT with the single constructor ``box(a)`` to *mod*.

    Returns:
        Tuple ``(box_type_var, box_constructor)``.
    """
    box_var = relay.GlobalTypeVar('box')
    elem = relay.TypeVar('a')
    ctor = relay.Constructor('box', [elem], box_var)
    type_data = relay.TypeData(box_var, [elem], [ctor])
    mod[box_var] = type_data
    return box_var, ctor
示例#19
0
# v0.7 Type Problem High-Level IR Transformation fixed
import tvm
from tvm import relay

# Module-level script that builds a one-parameter ADT named "box" with no
# constructors and registers the same TypeData under the same GlobalTypeVar
# twice. It looks like a bug reproduction (see the "v0.7 Type Problem" comment
# at the top of this snippet); TODO confirm the intended failure mode.
M7BIu = tvm.IRModule()
YDc6p = relay.GlobalTypeVar('box')
# NOTE(review): the type variable has an empty name; presumably deliberate
# for the repro, so confirm before "fixing" it.
fzeFg = relay.TypeVar('')
ErplN = relay.TypeData(YDc6p, [fzeFg], [])
M7BIu[YDc6p] = ErplN
# Second, identical registration of an already-present type definition;
# this repeated add appears to be the case the script exercises.
M7BIu[YDc6p] = ErplN
def test_mixed_adt_constructors():
    """Exhaustiveness checks for a ``box`` ADT mixed with Prelude lists.

    Builds a one-constructor ``box`` ADT next to the Prelude ``List`` and
    checks ``unmatched_cases`` on boxes of lists and on lists of boxes. Each
    shape is checked with an incomplete clause set and a complete one.
    """
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    a = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    p = Prelude(mod)
    v = relay.Var("v")

    def ctor(constructor, *sub_patterns):
        # Constructor pattern with the given positional sub-patterns.
        return relay.PatternConstructor(constructor, list(sub_patterns))

    def boxed(inner):
        # box(<inner>)
        return ctor(box_ctor, inner)

    def any_cons():
        # Cons(_, _): any non-empty list.
        return ctor(p.cons, relay.PatternWildcard(), relay.PatternWildcard())

    def exactly_n_boxes(n):
        # Cons(box(_), Cons(box(_), ... Nil)) containing exactly n boxes.
        pattern = ctor(p.nil)
        for _ in range(n):
            pattern = ctor(p.cons, boxed(relay.PatternWildcard()), pattern)
        return pattern

    def match_v(*patterns):
        # Match on ``v`` with one clause per pattern, each returning ``v``.
        return relay.Match(v, [relay.Clause(pat, v) for pat in patterns])

    # A box of non-empty lists fails to match a box containing an empty list.
    unmatched = unmatched_cases(match_v(boxed(any_cons())), mod)
    assert len(unmatched) == 1
    assert isinstance(unmatched[0], relay.PatternConstructor)
    assert unmatched[0].constructor == box_ctor
    assert len(unmatched[0].patterns) == 1 and unmatched[0].patterns[0].constructor == p.nil

    # Adding a box(Nil) clause makes the match exhaustive.
    box_of_lists_comp = match_v(boxed(ctor(p.nil)), boxed(any_cons()))
    assert len(unmatched_cases(box_of_lists_comp, mod)) == 0

    # Cons(box(_), _) alone fails to match the empty list of boxes.
    list_of_boxes_inc = match_v(
        ctor(p.cons, boxed(relay.PatternWildcard()), relay.PatternWildcard())
    )
    unmatched = unmatched_cases(list_of_boxes_inc, mod)
    assert len(unmatched) == 1
    assert isinstance(unmatched[0], relay.PatternConstructor)
    assert unmatched[0].constructor == p.nil

    # Exactly one, two, or three boxes, then one or more boxes, then no boxes.
    list_of_boxes_comp = match_v(
        exactly_n_boxes(1),
        exactly_n_boxes(2),
        exactly_n_boxes(3),
        any_cons(),
        ctor(p.nil),
    )
    assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0