def test_typecall_kind():
    """A TypeCall to a registered ADT has kind Type, with or without type arguments."""
    gtv = relay.GlobalTypeVar("gtv")

    # Nullary ADT: calling it with no type arguments is a Type.
    nullary_mod = tvm.IRModule()
    nullary_mod[gtv] = relay.TypeData(gtv, [], [])
    nullary_call = relay.TypeCall(gtv, [])
    assert check_kind(nullary_call, nullary_mod) == relay.TypeKind.Type

    # Unary ADT: instantiate its single type parameter with the unit tuple type.
    unary_mod = tvm.IRModule()
    param = relay.TypeVar("tv")
    unary_mod[gtv] = relay.TypeData(gtv, [param], [])
    unary_call = relay.TypeCall(gtv, [relay.TupleType([])])
    assert check_kind(unary_call, unary_mod) == relay.TypeKind.Type
def initialize_box_adt(mod):
    """Register a one-constructor ``box`` ADT in ``mod``.

    Returns a ``(box_type_var, constructor)`` pair.
    """
    param = relay.TypeVar('tv')
    box = relay.GlobalTypeVar('box')
    ctor = relay.Constructor('constructor', [param], box)
    mod[box] = relay.TypeData(box, [param], [ctor])
    return box, ctor
def test_typecall_invalid_num_args():
    """Kind checking must reject a TypeCall whose arity does not match its ADT.

    ``v1`` declares one type parameter but is called with none, so
    ``check_kind`` is expected to raise. NOTE(review): presumably this test is
    marked as an expected failure elsewhere; confirm.
    """
    mod = tvm.IRModule()
    gtv = relay.GlobalTypeVar('v1')
    tv = relay.TypeVar('tv')
    data = relay.TypeData(gtv, [tv], [])
    mod[gtv] = data
    # Pass the module that defines ``v1`` so the checker can resolve the type
    # and fail on the argument-count mismatch, rather than failing only
    # because ``v1`` is not defined anywhere.
    check_kind(relay.TypeCall(gtv, []), mod)
def test_typecall_invalid_num_args():
    """Kind checking must reject a TypeCall whose arity does not match its ADT.

    ``v1`` declares one type parameter but is called with none, so
    ``check_kind`` is expected to raise. NOTE(review): presumably this test is
    marked as an expected failure elsewhere; confirm.
    """
    mod = tvm.IRModule()
    gtv = relay.GlobalTypeVar("v1")
    tv = relay.TypeVar("tv")
    data = relay.TypeData(gtv, [tv], [])
    mod[gtv] = data
    # Pass the module that defines ``v1`` so the checker can resolve the type
    # and fail on the argument-count mismatch, rather than failing only
    # because ``v1`` is not defined anywhere.
    check_kind(relay.TypeCall(gtv, []), mod)
def test_single_constructor_adt():
    """One clause is exhaustive for an ADT with a single constructor.

    This holds however deeply the constructor pattern is nested.
    """
    mod = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    elem = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [elem], box)
    mod[box] = relay.TypeData(box, [elem], [box_ctor])

    scrutinee = relay.Var("v")

    def nested_box(depth):
        """Build box(box(...box(_)...)) containing ``depth`` constructors."""
        pattern = relay.PatternWildcard()
        for _ in range(depth):
            pattern = relay.PatternConstructor(box_ctor, [pattern])
        return pattern

    # with one constructor, having one pattern constructor case is exhaustive
    flat = relay.Match(scrutinee, [relay.Clause(nested_box(1), scrutinee)])
    assert len(unmatched_cases(flat, mod)) == 0

    # this will be so if we nest the constructors too
    deep = relay.Match(scrutinee, [relay.Clause(nested_box(3), scrutinee)])
    assert len(unmatched_cases(deep, mod)) == 0
def initialize_box_adt(mod):
    """Register a one-constructor ``box`` ADT in ``mod``.

    Returns a ``(box_type_var, constructor)`` pair.
    """
    param = relay.TypeVar("tv")
    box = relay.GlobalTypeVar("box")
    ctor = relay.Constructor("constructor", [param], box)
    mod[box] = relay.TypeData(box, [param], [ctor])
    return box, ctor
def test_adt_cons_expr():
    """Parsing a Cons/Nil constructor application matches the hand-built module."""
    expected = tvm.IRModule()

    list_gtv = relay.GlobalTypeVar("List")
    elem = relay.TypeVar("A")
    cons = relay.Constructor("Cons", [elem, list_gtv(elem)], list_gtv)
    nil = relay.Constructor("Nil", [], list_gtv)
    expected[list_gtv] = relay.TypeData(list_gtv, [elem], [cons, nil])

    x = relay.Var("x", int32)
    singleton_fn = relay.Function([x], cons(x, nil()), list_gtv(int32))
    expected[relay.GlobalVar("make_singleton")] = singleton_fn

    program = """
    %s
    def @make_singleton(%%x: int32) -> List[int32] {
      Cons(%%x, Nil)
    }
    """ % LIST_DEFN
    assert parses_as(program, expected)
def test_match():
    """Parse ``match`` / ``match?`` over a list and compare to a hand-built module.

    ``match`` must yield a Match with ``complete=True``; ``match?`` one with
    ``complete=False``.
    """
    # each surface keyword, paired with the ``complete`` flag it should produce
    for keyword, complete in (("match", True), ("match?", False)):
        expected = tvm.IRModule()

        list_gtv = relay.GlobalTypeVar("List")
        list_param = relay.TypeVar("A")
        cons = relay.Constructor("Cons", [list_param, list_gtv(list_param)], list_gtv)
        nil = relay.Constructor("Nil", [], list_gtv)
        expected[list_gtv] = relay.TypeData(list_gtv, [list_param], [cons, nil])

        length_gv = relay.GlobalVar("length")
        # @length gets its own type parameter (also named "A"), distinct from
        # the one used in the List definition above.
        fn_param = relay.TypeVar("A")
        xs = relay.Var("xs", list_gtv(fn_param))
        rest = relay.Var("rest")

        # Corresponds to `(); 1 + @length(%rest)` in the text below: unit is
        # bound to an unnamed variable, then the recursive sum is evaluated.
        cons_branch = relay.Let(
            relay.var("", type_annotation=None),
            UNIT,
            relay.add(relay.const(1), relay.Call(length_gv, [rest])),
        )
        clauses = [
            relay.Clause(
                relay.PatternConstructor(
                    cons, [relay.PatternWildcard(), relay.PatternVar(rest)]
                ),
                cons_branch,
            ),
            relay.Clause(relay.PatternConstructor(nil, []), relay.const(0)),
        ]
        body = relay.Match(xs, clauses, complete=complete)
        expected[length_gv] = relay.Function([xs], body, int32, [fn_param])

        program = """
        %s
        def @length[A](%%xs: List[A]) -> int32 {
          %s (%%xs) {
            Cons(_, %%rest : List[A]) => {
              ();
              1 + @length(%%rest)
            },
            Nil => 0,
          }
        }
        """ % (LIST_DEFN, keyword)
        assert_parse_module_as(program, expected)
def test_typecall_invalid_args():
    """A TypeCall whose argument is not of kind Type must be rejected.

    args must all be type kind
    """
    gtv = relay.GlobalTypeVar("v1")
    adt_def = relay.TypeData(gtv, [], [])
    module = tvm.IRModule()
    module[gtv] = adt_def
    # The offending argument is the TypeData definition itself, not a Type.
    bad_call = relay.TypeCall(gtv, [adt_def])
    check_kind(bad_call)
def test_empty_adt_defn():
    """An ADT declared with no constructors parses to an empty TypeData."""
    ayy = relay.GlobalTypeVar("Ayy")
    expected = tvm.IRModule()
    expected[ayy] = relay.TypeData(ayy, [], [])
    text = """
    type Ayy { }
    """
    assert parses_as(text, expected)
def test_adt_defn():
    """An ADT with a single nullary constructor parses to the expected TypeData."""
    ayy = relay.GlobalTypeVar("Ayy")
    nil_ctor = relay.Constructor("Nil", [], ayy)
    expected = tvm.IRModule()
    expected[ayy] = relay.TypeData(ayy, [], [nil_ctor])
    text = """
    type Ayy {
      Nil
    }
    """
    assert parses_as(text, expected)
def test_multiple_cons_defn():
    """``LIST_DEFN`` parses to a polymorphic List with Cons and Nil constructors."""
    list_gtv = relay.GlobalTypeVar("List")
    elem = relay.TypeVar("A")
    ctors = [
        relay.Constructor("Cons", [elem, list_gtv(elem)], list_gtv),
        relay.Constructor("Nil", [], list_gtv),
    ]
    expected = tvm.IRModule()
    expected[list_gtv] = relay.TypeData(list_gtv, [elem], ctors)
    assert parses_as(LIST_DEFN, expected)
def test_id_type():
    """Applying ``make_id : fn[b](b) -> id[b]`` to a float32 var infers ``id[float32]``."""
    mod = relay.Module()
    id_gtv = relay.GlobalTypeVar("id")
    mod[id_gtv] = relay.TypeData(id_gtv, [relay.TypeVar("a")], [])

    # make_id is polymorphic over `poly` (passed as FuncType's type params).
    poly = relay.TypeVar("b")
    make_id = relay.Var("make_id", relay.FuncType([poly], id_gtv(poly), [poly]))

    f32 = relay.scalar_type("float32")
    arg = relay.Var("b", f32)
    inferred = relay.ir_pass.infer_type(make_id(arg), mod)
    assert inferred.checked_type == id_gtv(f32)
def test_extern_adt_defn():
    """``extern type T[A]`` parses to a TypeData with one parameter and no constructors."""
    ext = relay.GlobalTypeVar("T")
    param = relay.TypeVar("A")
    expected = tvm.IRModule()
    expected[ext] = relay.TypeData(ext, [param], [])
    text = """
    extern type T[A]
    """
    assert_parse_module_as(text, expected)
def test_extern_adt_defn():
    """``extern type T[A]`` parses to a TypeData with one parameter and no constructors."""
    # TODO(weberlo): update this test once extern is implemented
    ext = relay.GlobalTypeVar("T")
    param = relay.TypeVar("A")
    expected = tvm.IRModule()
    expected[ext] = relay.TypeData(ext, [param], [])
    text = """
    extern type T[A]
    """
    assert parses_as(text, expected)
def test_id_type():
    """Type inference on ``make_id(b: float32)`` yields ``id[float32]``.

    ``make_id`` has type ``fn[b](b) -> id[b]``.
    """
    mod = relay.Module()
    id_gtv = relay.GlobalTypeVar("id")
    mod[id_gtv] = relay.TypeData(id_gtv, [relay.TypeVar("a")], [])

    # make_id is polymorphic over `poly` (passed as FuncType's type params).
    poly = relay.TypeVar("b")
    make_id = relay.Var("make_id", relay.FuncType([poly], id_gtv(poly), [poly]))
    f32 = relay.scalar_type("float32")
    arg = relay.Var("b", f32)

    mod[mod.entry_func] = relay.Function([], make_id(arg))
    mod = transform.InferType()(mod)
    # Re-read entry_func from the module returned by the pass.
    call = mod[mod.entry_func].body
    assert call.checked_type == id_gtv(f32)
def test_multiple_type_param_defn():
    """An ADT with two type parameters (``Either[A, B]``) parses as expected."""
    either = relay.GlobalTypeVar("Either")
    a_param, b_param = relay.TypeVar("A"), relay.TypeVar("B")
    left = relay.Constructor("Left", [a_param], either)
    right = relay.Constructor("Right", [b_param], either)

    expected = tvm.IRModule()
    expected[either] = relay.TypeData(either, [a_param, b_param], [left, right])

    text = """
    type Either[A, B] {
      Left(A),
      Right(B),
    }
    """
    assert parses_as(text, expected)
def init_box_adt(mod):
    """Add a one-constructor ``box[a]`` ADT to ``mod``.

    Returns a ``(box_type_var, box_ctor)`` pair.
    """
    elem = relay.TypeVar('a')
    box = relay.GlobalTypeVar('box')
    ctor = relay.Constructor('box', [elem], box)
    mod[box] = relay.TypeData(box, [elem], [ctor])
    return box, ctor
# v0.7 Type Problem High-Level IR Transformation fixed import tvm from tvm import relay M7BIu = tvm.IRModule() YDc6p = relay.GlobalTypeVar('box') fzeFg = relay.TypeVar('') ErplN = relay.TypeData(YDc6p, [fzeFg], []) M7BIu[YDc6p] = ErplN M7BIu[YDc6p] = ErplN
def test_mixed_adt_constructors():
    """Exhaustiveness checking over patterns that mix a user ADT with Prelude lists.

    Covers a box of lists and a list of boxes, each in an incomplete form
    (exactly one case unmatched) and a complete form (nothing unmatched).
    """
    module = tvm.IRModule()
    box = relay.GlobalTypeVar("box")
    elem = relay.TypeVar("a")
    box_ctor = relay.Constructor("box", [elem], box)
    module[box] = relay.TypeData(box, [elem], [box_ctor])

    prelude = Prelude(module)
    scrutinee = relay.Var("v")

    # Small pattern builders to keep the nested patterns readable.
    def wild():
        return relay.PatternWildcard()

    def boxed(inner):
        return relay.PatternConstructor(box_ctor, [inner])

    def cons(head, tail):
        return relay.PatternConstructor(prelude.cons, [head, tail])

    def nil():
        return relay.PatternConstructor(prelude.nil, [])

    def match_on(*patterns):
        """Match ``scrutinee`` with one clause per pattern, each returning ``scrutinee``."""
        return relay.Match(scrutinee, [relay.Clause(pat, scrutinee) for pat in patterns])

    # box(cons(_, _)) only: will fail to match a box containing an empty list
    box_of_lists_inc = match_on(boxed(cons(wild(), wild())))
    missing = unmatched_cases(box_of_lists_inc, module)
    assert len(missing) == 1
    assert isinstance(missing[0], relay.PatternConstructor)
    assert missing[0].constructor == box_ctor
    assert len(missing[0].patterns) == 1 and missing[0].patterns[0].constructor == prelude.nil

    # box(nil) plus box(cons(_, _)) covers every box of lists
    box_of_lists_comp = match_on(boxed(nil()), boxed(cons(wild(), wild())))
    assert len(unmatched_cases(box_of_lists_comp, module)) == 0

    # cons(box(_), _) only: fails to match empty list of boxes
    list_of_boxes_inc = match_on(cons(boxed(wild()), wild()))
    missing = unmatched_cases(list_of_boxes_inc, module)
    assert len(missing) == 1
    assert isinstance(missing[0], relay.PatternConstructor)
    assert missing[0].constructor == prelude.nil

    list_of_boxes_comp = match_on(
        # exactly one box
        cons(boxed(wild()), nil()),
        # exactly two boxes
        cons(boxed(wild()), cons(boxed(wild()), nil())),
        # exactly three boxes
        cons(boxed(wild()), cons(boxed(wild()), cons(boxed(wild()), nil()))),
        # one or more boxes
        cons(wild(), wild()),
        # no boxes
        nil(),
    )
    assert len(unmatched_cases(list_of_boxes_comp, module)) == 0