Ejemplo n.º 1
0
def test_adt_defn():
    """Build a module holding ADT ``Ayy`` with one nullary constructor ``Nil``
    and pass it, with the matching program text, to ``assert_parses_as``."""
    expected = tvm.IRModule()
    ayy = relay.GlobalTypeVar("Ayy")
    nil_ctor = relay.Constructor("Nil", [], ayy)
    expected[ayy] = relay.TypeData(ayy, [], [nil_ctor])

    source = """
        type Ayy { Nil }
        """
    assert_parses_as(source, expected)
Ejemplo n.º 2
0
def test_id_type():
    """Infer the type of calling ``make_id`` on a float32 variable and compare
    it with the ``id`` ADT applied to float32."""
    mod = relay.Module()
    id_type = relay.GlobalTypeVar("id")
    adt_param = relay.TypeVar("a")
    mod[id_type] = relay.TypeData(id_type, [adt_param], [])

    # make_id takes one argument of type b, returns id[b], and is generic in b.
    func_param = relay.TypeVar("b")
    make_id_type = relay.FuncType([func_param], id_type(func_param), [func_param])
    make_id = relay.Var("make_id", make_id_type)

    float_type = relay.scalar_type("float32")
    arg = relay.Var("b", float_type)
    inferred = relay.ir_pass.infer_type(make_id(arg), mod).checked_type
    assert inferred == id_type(float_type)
Ejemplo n.º 3
0
def test_multiple_cons_defn():
    """Build the two-constructor ``List[A]`` ADT (``Cons`` and ``Nil``) and
    assert that ``LIST_DEFN`` parses as the resulting module."""
    mod = tvm.IRModule()

    list_var = relay.GlobalTypeVar("List")
    elem_type = relay.TypeVar("A")
    cons_ctor = relay.Constructor("Cons", [elem_type, list_var(elem_type)], list_var)
    nil_ctor = relay.Constructor("Nil", [], list_var)
    mod[list_var] = relay.TypeData(list_var, [elem_type], [cons_ctor, nil_ctor])

    assert parses_as(LIST_DEFN, mod)
Ejemplo n.º 4
0
def test_extern_adt_defn():
    """Build a module where ``T[A]`` is an ADT with no constructors and pass it,
    with an ``extern type`` declaration, to ``assert_parse_module_as``."""
    expected = tvm.IRModule()
    extern_var = relay.GlobalTypeVar("T")
    type_param = relay.TypeVar("A")
    expected[extern_var] = relay.TypeData(extern_var, [type_param], [])

    source = """
        extern type T[A]
        """
    assert_parse_module_as(source, expected)
Ejemplo n.º 5
0
def test_empty_adt_defn():
    """Build a module where ``Ayy`` is an ADT with no type parameters and no
    constructors, and pass it with ``type Ayy { }`` to
    ``assert_parse_module_as``."""
    expected = tvm.IRModule()
    ayy = relay.GlobalTypeVar("Ayy")
    expected[ayy] = relay.TypeData(ayy, [], [])

    source = """
        type Ayy { }
        """
    assert_parse_module_as(source, expected)
Ejemplo n.º 6
0
def test_extern_adt_defn():
    """Assert that an ``extern type T[A]`` declaration parses as a module in
    which ``T`` is an ADT with one type parameter and no constructors."""
    # TODO(weberlo): update this test once extern is implemented
    expected = tvm.IRModule()
    extern_var = relay.GlobalTypeVar("T")
    type_param = relay.TypeVar("A")
    expected[extern_var] = relay.TypeData(extern_var, [type_param], [])

    source = """
        extern type T[A]
        """
    assert parses_as(source, expected)
Ejemplo n.º 7
0
def test_type_call_alpha_equal():
    """Check equality of ``relay.TypeCall`` against variations of itself.

    A reference call ``h1[t1, t2, t3]`` must compare equal to an identically
    built call and unequal to calls that differ in the callee, in one
    argument, in the number of arguments, or in argument order.
    """
    h1 = relay.GlobalTypeVar("h1")
    h2 = relay.GlobalTypeVar("h2")
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")
    t3 = relay.TensorType((1, 2, 3, 4), "float32")
    t4 = relay.TensorType((), "float32")

    tc = relay.TypeCall(h1, [t1, t2, t3])
    same = relay.TypeCall(h1, [t1, t2, t3])

    different_func = relay.TypeCall(h2, [t1, t2, t3])
    different_arg = relay.TypeCall(h1, [t1, t2, t4])
    fewer_args = relay.TypeCall(h1, [t1, t2])
    more_args = relay.TypeCall(h1, [t1, t2, t3, t4])
    different_order_args = relay.TypeCall(h1, [t3, t2, t1])

    assert tc == same
    assert tc != different_func
    # Previously built but never checked: same callee and arity, but the last
    # argument is a different tensor type.
    assert tc != different_arg
    assert tc != fewer_args
    assert tc != more_args
    assert tc != different_order_args
Ejemplo n.º 8
0
def test_incompatible_typecall_args_unification():
    """Ask the solver to unify two type calls on the same global type var.

    The left call's first argument is a tuple that repeats one incomplete
    type; the right call's first argument is a tuple of two tensor types with
    different shapes.
    NOTE(review): presumably this is expected to raise; any expected-failure
    marker would sit outside the visible code -- confirm.
    """
    solver = make_solver()
    gtv = relay.GlobalTypeVar("gtv1")
    hole_a = relay.IncompleteType()
    hole_b = relay.IncompleteType()

    rank3 = relay.TensorType((1, 2, 3), "float32")
    rank2 = relay.TensorType((2, 3), "float32")
    rank1 = relay.TensorType((3, ), "float32")

    lhs_tuple = relay.TupleType([hole_a, hole_a])
    lhs = relay.TypeCall(gtv, [lhs_tuple, hole_b])
    rhs_tuple = relay.TupleType([rank3, rank2])
    rhs = relay.TypeCall(gtv, [rhs_tuple, rank1])
    solver.Unify(lhs, rhs)
Ejemplo n.º 9
0
def test_id_type():
    """Run ``InferType`` on a module whose entry function calls ``make_id`` on
    a float32 variable, and compare the body's checked type with the ``id``
    ADT applied to float32."""
    mod = relay.Module()
    id_type = relay.GlobalTypeVar("id")
    adt_param = relay.TypeVar("a")
    mod[id_type] = relay.TypeData(id_type, [adt_param], [])

    # make_id takes one argument of type b, returns id[b], and is generic in b.
    func_param = relay.TypeVar("b")
    make_id_type = relay.FuncType([func_param], id_type(func_param), [func_param])
    make_id = relay.Var("make_id", make_id_type)

    float_type = relay.scalar_type("float32")
    arg = relay.Var("b", float_type)
    mod[mod.entry_func] = relay.Function([], make_id(arg))

    mod = transform.InferType()(mod)
    inferred = mod[mod.entry_func].body.checked_type
    assert inferred == id_type(float_type)
Ejemplo n.º 10
0
def test_unify_typecall():
    """Unify a type call over two incomplete types with a type call over two
    copies of a concrete tensor type; the result must equal the concrete call."""
    solver = make_solver()
    gtv = relay.GlobalTypeVar("gtv")

    # Type calls are shaped like tuples, so the tuple-style unification
    # scenario is reused here.
    hole_a = relay.ty.IncompleteType()
    hole_b = relay.ty.IncompleteType()
    tensor = relay.ty.TensorType((10, 20), "float32")

    open_call = relay.ty.TypeCall(gtv, [hole_a, hole_b])
    concrete_call = relay.ty.TypeCall(gtv, [tensor, tensor])
    assert solver.Unify(open_call, concrete_call) == concrete_call
Ejemplo n.º 11
0
def test_match():
    """For each match keyword, build a module with the ``List`` ADT and a
    recursive ``@length`` function, then assert the program text parses as it."""
    # Each keyword is paired with the ``complete`` flag handed to relay.Match.
    for match_keyword, is_complete in (("match", True), ("match?", False)):
        mod = tvm.IRModule()

        # ADT List[A] with constructors Cons(A, List[A]) and Nil.
        list_var = relay.GlobalTypeVar("List")
        elem_type = relay.TypeVar("A")
        cons_constructor = relay.Constructor(
            "Cons", [elem_type, list_var(elem_type)], list_var)
        nil_constructor = relay.Constructor("Nil", [], list_var)
        mod[list_var] = relay.TypeData(
            list_var, [elem_type], [cons_constructor, nil_constructor])

        # @length gets its own type variable, distinct from the ADT's.
        length_var = relay.GlobalVar("length")
        func_type_param = relay.TypeVar("A")
        input_var = relay.Var("xs", list_var(func_type_param))
        rest_var = relay.Var("rest")

        # Cons branch: bind the unit value, then 1 + @length(rest).
        one = relay.const(1)
        recursive_call = relay.Call(length_var, [rest_var])
        cons_case = relay.Let(_, UNIT, relay.add(one, recursive_call))

        cons_pattern = relay.PatternConstructor(
            cons_constructor,
            [relay.PatternWildcard(), relay.PatternVar(rest_var)])
        nil_pattern = relay.PatternConstructor(nil_constructor, [])
        clauses = [
            relay.Clause(cons_pattern, cons_case),
            relay.Clause(nil_pattern, relay.const(0)),
        ]
        body = relay.Match(input_var, clauses, complete=is_complete)
        mod[length_var] = relay.Function(
            [input_var], body, int32, [func_type_param])

        program_text = """
            %s

            def @length[A](%%xs: List[A]) -> int32 {
              %s (%%xs) {
                Cons(_, %%rest) => {
                  ();;
                  1 + @length(%%rest)
                },
                Nil => 0,
              }
            }
            """ % (LIST_DEFN, match_keyword)
        assert parses_as(program_text, mod)
Ejemplo n.º 12
0
def test_typecall_kind():
    """Check that type calls on a registered ADT handle have kind ``Type``,
    both with zero type parameters and with one."""
    gtv = relay.GlobalTypeVar("gtv")

    # ADT without type parameters, called with no arguments.
    nullary_mod = tvm.IRModule()
    nullary_mod[gtv] = relay.TypeData(gtv, [], [])
    nullary_call = relay.TypeCall(gtv, [])
    assert check_kind(nullary_call, nullary_mod) == relay.TypeKind.Type

    # Fresh module: same handle, one type parameter, called with the empty tuple type.
    unary_mod = tvm.IRModule()
    type_param = relay.TypeVar("tv")
    unary_mod[gtv] = relay.TypeData(gtv, [type_param], [])
    unary_call = relay.TypeCall(gtv, [relay.TupleType([])])
    assert check_kind(unary_call, unary_mod) == relay.TypeKind.Type
Ejemplo n.º 13
0
def test_multiple_type_param_defn():
    """Assert that the two-parameter ``Either[A, B]`` declaration parses as a
    module holding the ADT with constructors ``Left(A)`` and ``Right(B)``."""
    either = relay.GlobalTypeVar("Either")
    left_type = relay.TypeVar("A")
    right_type = relay.TypeVar("B")
    constructors = [
        relay.Constructor("Left", [left_type], either),
        relay.Constructor("Right", [right_type], either),
    ]
    either_def = relay.TypeData(either, [left_type, right_type], constructors)

    mod = tvm.IRModule()
    mod[either] = either_def

    source = """
        type Either[A, B] {
          Left(A),
          Right(B),
        }
        """
    assert parses_as(source, mod)
Ejemplo n.º 14
0
def test_single_constructor_adt():
    """Check that matching the sole constructor of ``box`` leaves no unmatched
    cases, both for a flat pattern and for a triply nested one."""
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    a = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    v = relay.Var("v")

    def boxed(inner):
        # Pattern for box(<inner>).
        return relay.PatternConstructor(box_ctor, [inner])

    # with one constructor, having one pattern constructor case is exhaustive
    flat_pattern = boxed(relay.PatternWildcard())
    match = relay.Match(v, [relay.Clause(flat_pattern, v)])
    assert len(unmatched_cases(match, mod)) == 0

    # this will be so if we nest the constructors too
    triple_box = boxed(boxed(boxed(relay.PatternWildcard())))
    nested_pattern = relay.Match(v, [relay.Clause(triple_box, v)])
    assert len(unmatched_cases(nested_pattern, mod)) == 0
Ejemplo n.º 15
0
def init_box_adt(mod):
    """Register a one-parameter ``box`` ADT in ``mod``.

    The ADT has a single constructor, also named ``box``, with one field of
    the type parameter. Returns ``(global type var, constructor)``.
    """
    handle = relay.GlobalTypeVar('box')
    type_param = relay.TypeVar('a')
    constructor = relay.Constructor('box', [type_param], handle)
    mod[handle] = relay.TypeData(handle, [type_param], [constructor])
    return handle, constructor
Ejemplo n.º 16
0
def test_mixed_adt_constructors():
    """Check ``unmatched_cases`` on matches mixing the ``box`` ADT with the
    Prelude list constructors, for incomplete and complete clause sets."""
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    a = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    p = Prelude(mod)

    v = relay.Var("v")

    # Small pattern builders; each call creates fresh pattern nodes.
    def wildcard():
        return relay.PatternWildcard()

    def nil_pat():
        return relay.PatternConstructor(p.nil, [])

    def cons_pat(head, tail):
        return relay.PatternConstructor(p.cons, [head, tail])

    def box_pat(inner):
        return relay.PatternConstructor(box_ctor, [inner])

    def match_on_v(patterns):
        # Every clause scrutinises v and also returns v.
        return relay.Match(v, [relay.Clause(pattern, v) for pattern in patterns])

    box_of_lists_inc = match_on_v([box_pat(cons_pat(wildcard(), wildcard()))])

    # will fail to match a box containing an empty list
    unmatched = unmatched_cases(box_of_lists_inc, mod)
    assert len(unmatched) == 1
    missing = unmatched[0]
    assert isinstance(missing, relay.PatternConstructor)
    assert missing.constructor == box_ctor
    assert len(missing.patterns) == 1
    assert missing.patterns[0].constructor == p.nil

    box_of_lists_comp = match_on_v([
        box_pat(nil_pat()),
        box_pat(cons_pat(wildcard(), wildcard())),
    ])
    assert len(unmatched_cases(box_of_lists_comp, mod)) == 0

    list_of_boxes_inc = match_on_v([cons_pat(box_pat(wildcard()), wildcard())])

    # fails to match empty list of boxes
    unmatched = unmatched_cases(list_of_boxes_inc, mod)
    assert len(unmatched) == 1
    missing = unmatched[0]
    assert isinstance(missing, relay.PatternConstructor)
    assert missing.constructor == p.nil

    list_of_boxes_comp = match_on_v([
        # exactly one box
        cons_pat(box_pat(wildcard()), nil_pat()),
        # exactly two boxes
        cons_pat(
            box_pat(wildcard()),
            cons_pat(box_pat(wildcard()), nil_pat()),
        ),
        # exactly three boxes
        cons_pat(
            box_pat(wildcard()),
            cons_pat(
                box_pat(wildcard()),
                cons_pat(box_pat(wildcard()), nil_pat()),
            ),
        ),
        # one or more boxes
        cons_pat(wildcard(), wildcard()),
        # no boxes
        nil_pat(),
    ])
    assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
Ejemplo n.º 17
0
from ...abstract import (
    AbstractArray,
    AbstractError,
    AbstractFunction,
    AbstractScalar,
    AbstractTaggedUnion,
    AbstractTuple,
    TypedPrimitive,
    VirtualFunction,
    broaden,
)
from ...utils import overload
from ...xtype import Bool, EnvType, Nil, type_to_np_dtype

# Global type variable that the union constructors below are attached to.
union_type = relay.GlobalTypeVar('$_union_adt')
# Zero-field constructor named "empty" belonging to union_type.
empty_union = adt.Constructor("empty", [], union_type)
# Cache: union tag -> constructor; filled on demand by get_union_ctr.
tag_map = {}
# Reverse of tag_map: constructor -> union tag; also filled by get_union_ctr.
rev_tag_map = {}


def get_union_ctr(tag, t):
    """Return the relay constructor for union tag ``tag``, creating it if needed.

    On first sight of a tag, ``t`` (which must not be None) is converted with
    ``to_relay_type`` and a constructor named ``c<tag>`` is recorded in both
    ``tag_map`` and ``rev_tag_map``. Later calls return the cached constructor
    and ignore ``t``.
    """
    if tag in tag_map:
        return tag_map[tag]

    assert t is not None
    field_type = to_relay_type(t)
    constructor = adt.Constructor(f"c{tag}", [field_type], union_type)
    tag_map[tag] = constructor
    rev_tag_map[constructor] = tag
    return constructor
Ejemplo n.º 18
0
# v0.7 Type Problem High-Level IR Transformation fixed
import tvm
from tvm import relay

# Build an empty module and an ADT handle named 'box'.
M7BIu = tvm.IRModule()
YDc6p = relay.GlobalTypeVar('box')
# Type parameter with an empty name.
fzeFg = relay.TypeVar('')
# ADT definition with one type parameter and no constructors.
ErplN = relay.TypeData(YDc6p, [fzeFg], [])
# The same definition is registered twice under the same handle.
# NOTE(review): the repeated assignment is presumably the point of this
# reproducer -- confirm before removing it.
M7BIu[YDc6p] = ErplN
M7BIu[YDc6p] = ErplN
Ejemplo n.º 19
0
def test_unify_invalid_global_typevars():
    """Ask the solver to unify two global type vars with different names.

    NOTE(review): presumably this is expected to raise; any expected-failure
    marker would sit outside the visible code -- confirm.
    """
    solver = make_solver()
    first, second = (relay.GlobalTypeVar(name) for name in ("gtv1", "gtv2"))
    solver.Unify(first, second)
Ejemplo n.º 20
0
def test_unify_global_type_var():
    """Unifying a global type var with itself must give back that same var."""
    # Global type vars should only unify when they are the same var.
    solver = make_solver()
    gtv = relay.GlobalTypeVar("gtv")
    assert solver.Unify(gtv, gtv) == gtv
Ejemplo n.º 21
0
    AbstractArray,
    AbstractError,
    AbstractFunctionUnique,
    AbstractHandle,
    AbstractRandomState,
    AbstractScalar,
    AbstractTaggedUnion,
    AbstractTuple,
    AbstractType,
    TypedPrimitive,
)
from myia.operations import primitives as P
from myia.utils import overload
from myia.xtype import Bool, EnvType, Nil, UniverseType, type_to_np_dtype

# Global type variable that the union constructors below are attached to.
union_type = relay.GlobalTypeVar("$_union_adt")
# Zero-field constructor named "empty" belonging to union_type.
empty_union = adt.Constructor("empty", [], union_type)
# Cache: union tag -> constructor; pre-seeded so tag None maps to empty_union,
# and extended on demand by get_union_ctr.
tag_map = {None: empty_union}
# NOTE(review): get_union_ctr does not write to this dict; presumably it is
# populated elsewhere -- confirm.
rev_tag_map = {}


def get_union_ctr(tag, t):
    """Return the relay constructor for union tag ``tag``, creating it if needed.

    On first sight of a tag, ``t`` (which must not be None) is converted with
    ``to_relay_type`` and a constructor named ``c<tag>`` is cached in
    ``tag_map``. Later calls return the cached constructor and ignore ``t``.
    """
    if tag in tag_map:
        return tag_map[tag]

    assert t is not None
    field_type = to_relay_type(t)
    constructor = adt.Constructor(f"c{tag}", [field_type], union_type)
    tag_map[tag] = constructor
    return constructor

Ejemplo n.º 22
0
def test_global_typevar_kind():
    """``check_kind`` on a global type var must return the kind it was built with."""
    adt_handle_var = relay.GlobalTypeVar("gtv1", relay.TypeKind.AdtHandle)
    plain_type_var = relay.GlobalTypeVar("gtv2", relay.TypeKind.Type)

    expectations = (
        (adt_handle_var, relay.TypeKind.AdtHandle),
        (plain_type_var, relay.TypeKind.Type),
    )
    for type_var, expected_kind in expectations:
        assert check_kind(type_var) == expected_kind
Ejemplo n.º 23
0
def test_typecall_invalid_callee():
    """Run ``check_kind`` on a type call whose callee has kind ``Type``.

    NOTE(review): presumably this is expected to raise; any expected-failure
    marker would sit outside the visible code -- confirm.
    """
    # The callee of a type call must be an ADT handle; this one is not.
    non_handle = relay.GlobalTypeVar("v1", relay.TypeKind.Type)
    bad_call = relay.TypeCall(non_handle, [])
    check_kind(bad_call)