def can_infer_type_of_function_with_explicit_signature_of_aliased_function_type():
    """A function annotated with a name bound to a meta type of a function
    type ("Action") is inferred to have that underlying function type."""
    func_node = nodes.func(
        "f",
        args=nodes.arguments([]),
        body=[],
        type=nodes.ref("Action"),
    )
    bindings = {
        "Action": types.meta_type(types.func([], types.none_type)),
    }
    expected = types.func([], types.none_type)
    assert_equal(expected, _infer_func_type(func_node, bindings))
def callee_can_be_overloaded_func_type_where_choice_is_unambiguous_given_args():
    """Only the str overload accepts a str literal, so its return type (int)
    is the type of the call."""
    overloaded = types.overloaded_func(
        types.func([types.str_type], types.int_type),
        types.func([types.int_type], types.str_type),
    )
    call_node = nodes.call(nodes.ref("f"), [nodes.str_literal("")])
    assert_equal(types.int_type, infer(call_node, type_bindings={"f": overloaded}))
def rescursive_structural_types_do_not_cause_stack_overflow(self):
    """is_sub_type on two distinct self-referential structural types must
    return a falsy result. Per the test name, the point is that the check
    finishes rather than overflowing the stack.

    NOTE(review): "rescursive" in the name is a typo for "recursive". It is
    kept as-is so the test's identifier does not change.
    """
    first = types.structural_type("recursive1")
    first.attrs.add("uh_oh", types.func([], first))

    second = types.structural_type("recursive2")
    second.attrs.add("uh_oh", types.func([], second))

    assert not types.is_sub_type(first, second)
def return_type_is_common_super_type_of_possible_return_types_of_overloaded_function():
    """A str argument fits both overloads (object and str), so the call's
    type is the common super type of both return types."""
    f_type = types.overloaded_func(
        types.func([types.object_type], types.int_type),
        types.func([types.str_type], types.str_type),
    )
    call_node = nodes.call(nodes.ref("f"), [nodes.str_literal("")])
    expected = types.common_super_type([types.int_type, types.str_type])
    assert_equal(expected, infer(call_node, type_bindings={"f": f_type}))
def error_in_inferring_actual_argument_to_overloaded_function_is_not_failure_to_find_matching_overload():
    """A failure while inferring an argument is reported against that
    argument's node, not as "no overload matched" on the call."""
    bindings = {"f": types.overloaded_func(
        types.func([types.object_type], types.any_type),
        types.func([types.str_type], types.any_type),
    )}
    # "x" has no entry in the bindings, so inferring it is expected to fail.
    unbound_ref = nodes.ref("x")
    call_node = nodes.call(nodes.ref("f"), [unbound_ref])
    try:
        infer(call_node, type_bindings=bindings)
    except errors.TypeCheckError as error:
        assert_equal(unbound_ref, error.node)
    else:
        assert False, "Expected error"
def instantiated_generic_structural_type_is_sub_type_of_other_instantiated_generic_structural_type_if_it_has_matching_attributes(self):
    """is_sub_type holds between iterable(int) and iterator(int). iterator
    declares every attribute iterable does (__iter__) plus __next__."""
    it_type = types.generic_structural_type(
        "iterator",
        [types.covariant("T")],
        lambda T: [
            types.attr("__iter__", types.func([], it_type(T))),
            types.attr("__next__", types.func([], T)),
        ],
    )
    iterable_type = types.generic_structural_type(
        "iterable",
        [types.covariant("T")],
        lambda T: [
            types.attr("__iter__", types.func([], it_type(T))),
        ],
    )
    assert types.is_sub_type(
        iterable_type(types.int_type),
        it_type(types.int_type),
    )
def type_of_add_method_argument_allows_super_type():
    """`x + y` type-checks when __add__ takes object (a super type of y's
    class). The result is __add__'s return type."""
    addable = types.class_type("Addable", {})
    addable.attrs.add("__add__", types.func([types.object_type], addable))
    bindings = {"x": addable, "y": addable}
    expr = nodes.add(nodes.ref("x"), nodes.ref("y"))
    assert_equal(addable, infer(expr, type_bindings=bindings))
def return_type_of_add_can_differ_from_original_type():
    """The type of `x + y` is whatever __add__ returns (here object), even
    when that differs from the operands' class."""
    addable = types.class_type("Addable", {})
    addable.attrs.add("__add__", types.func([types.object_type], types.object_type))
    bindings = {"x": addable, "y": addable}
    expr = nodes.add(nodes.ref("x"), nodes.ref("y"))
    assert_equal(types.object_type, infer(expr, type_bindings=bindings))
def can_infer_type_of_subscript_using_getitem():
    """`x[4]` takes its type from the return type of x's __getitem__."""
    getitem_type = types.func([types.int_type], types.str_type)
    blah = types.class_type("Blah", [types.attr("__getitem__", getitem_type)])
    subscript = nodes.subscript(nodes.ref("x"), nodes.int_literal(4))
    assert_equal(types.str_type, infer(subscript, type_bindings={"x": blah}))
def if_positional_has_name_then_that_name_is_used_in_missing_argument_message():
    """A missing positional argument that has a name is reported by that
    name, and the error is attached to the call node."""
    call_node = _create_call([])
    func_type = types.func([types.func_arg("message", types.str_type)], types.int_type)
    try:
        _infer_function_call(func_type, call_node)
    except errors.ArgumentsError as error:
        assert_is(call_node, error.node)
        assert_equal("missing argument 'message'", str(error))
    else:
        assert False, "Expected error"
def error_if_extra_positional_argument():
    """Passing one positional argument to a zero-argument function raises
    ArgumentsError on the call node."""
    call_node = _create_call([nodes.str_literal("hello")])
    no_arg_func = types.func([], types.int_type)
    try:
        _infer_function_call(no_arg_func, call_node)
    except errors.ArgumentsError as error:
        assert_is(call_node, error.node)
        assert_equal("function takes 0 positional arguments but 1 was given", str(error))
    else:
        assert False, "Expected error"
def error_if_extra_keyword_argument():
    """Passing an unknown keyword argument raises ArgumentsError naming it."""
    call_node = _create_call([], {"message": nodes.str_literal("hello")})
    no_arg_func = types.func([], types.int_type)
    try:
        _infer_function_call(no_arg_func, call_node)
    except errors.ArgumentsError as error:
        assert_is(call_node, error.node)
        assert_equal("unexpected keyword argument 'message'", str(error))
    else:
        assert False, "Expected error"
def attributes_with_function_type_defined_in_class_definition_body_are_not_present_on_meta_type():
    """A function-typed attribute assigned in the class body does not
    appear in the inferred meta type's attrs."""
    class_node = nodes.class_("User", [
        nodes.assign([nodes.ref("is_person")], nodes.ref("true_func")),
    ])
    bindings = {"true_func": types.func([types.object_type], types.none_type)}
    meta = _infer_meta_type(class_node, ["is_person"], type_bindings=bindings)
    assert "is_person" not in meta.attrs
def error_if_positional_argument_is_missing():
    """Omitting an unnamed positional argument is reported by its ordinal."""
    call_node = _create_call([])
    str_to_int = types.func([types.str_type], types.int_type)
    try:
        _infer_function_call(str_to_int, call_node)
    except errors.ArgumentsError as error:
        assert_is(call_node, error.node)
        assert_equal("missing 1st positional argument", str(error))
    else:
        assert False, "Expected error"
def object_can_be_called_if_it_has_call_magic_method_that_returns_callable():
    """First's __call__ attribute is itself a callable class (Second).
    Calling an instance of First with a str therefore gives Second.__call__'s
    return type (int)."""
    inner = types.class_type("Second", [
        types.attr("__call__", types.func([types.str_type], types.int_type)),
    ])
    outer = types.class_type("First", [
        types.attr("__call__", inner),
    ])
    call_node = nodes.call(nodes.ref("f"), [nodes.str_literal("")])
    assert_equal(types.int_type, infer(call_node, type_bindings={"f": outer}))
def recursive_instantiated_generic_structural_type_is_sub_type_of_same_instantiated_generic_structural_type_if_it_has_matching_attributes(self):
    """is_sub_type holds between two identical instantiations (with int)
    of a generic structural type whose __iter__ returns the same type."""
    rec = types.generic_structural_type(
        "recursive",
        [types.covariant("T")],
        lambda T: [types.attr("__iter__", types.func([], rec(T)))],
    )
    assert types.is_sub_type(rec(types.int_type), rec(types.int_type))
def can_infer_type_of_call_with_optional_argument_specified():
    """Supplying a value for an optional, unnamed positional argument is
    accepted, and the call has the function's return type."""
    optional_str = types.func_arg(None, types.str_type, optional=True)
    f_type = types.func(args=[optional_str], return_type=types.bool_type)
    call_node = nodes.call(nodes.ref("f"), [nodes.str_literal("blah")])
    assert_equal(types.bool_type, infer(call_node, type_bindings={"f": f_type}))
def function_adds_arguments_to_context():
    """The argument `x` is in scope in the body. `return x` therefore
    type-checks with x typed from the signature (int -> int)."""
    sig = nodes.signature(
        args=[nodes.signature_arg(nodes.ref("int"))],
        returns=nodes.ref("int"),
    )
    params = nodes.arguments([nodes.argument("x")])
    func_node = nodes.func("f", params, [nodes.ret(nodes.ref("x"))], type=sig)
    expected = types.func([types.int_type], types.int_type)
    assert_equal(expected, _infer_func_type(func_node))
def generic_type_arguments_are_covariant():
    """When T appears in several argument positions, it is inferred as the
    common super type of the actual argument types."""
    f_type = types.generic_func(["T"], lambda T: types.func([T, T], T))
    call_node = nodes.call(nodes.ref("f"), [nodes.str_literal(""), nodes.none()])
    expected = types.common_super_type([types.str_type, types.none_type])
    assert_equal(expected, infer(call_node, type_bindings={"f": f_type}))
def can_infer_type_of_call_with_keyword_arguments():
    """Named arguments can be supplied entirely by keyword."""
    name_arg = types.func_arg("name", types.str_type)
    hats_arg = types.func_arg("hats", types.int_type)
    f_type = types.func(args=[name_arg, hats_arg], return_type=types.bool_type)
    call_node = nodes.call(
        nodes.ref("f"),
        [],
        {"name": nodes.str_literal("Bob"), "hats": nodes.int_literal(42)},
    )
    assert_equal(types.bool_type, infer(call_node, type_bindings={"f": f_type}))
def init_must_be_function_definition():
    """Binding __init__ by plain assignment (rather than a def) raises
    InitAttributeMustBeFunctionDefinitionError on that assignment."""
    init_assign = nodes.assign([nodes.ref("__init__")], nodes.ref("f"))
    class_node = nodes.class_("User", [init_assign])
    bindings = {"f": types.func([types.object_type], types.str_type)}
    try:
        _infer_class_type(class_node, ["__init__"], type_bindings=bindings)
    except errors.InitAttributeMustBeFunctionDefinitionError as error:
        assert_equal(init_assign, error.node)
    else:
        assert False, "Expected error"
def test_transform_call_with_positional_arguments(self):
    """A call with positional arguments transforms to a cc.call of the same
    shape. The callee's type is supplied via type_lookup."""
    callee = nodes.ref("f")
    lookup = [(callee, types.func([types.str_type], types.none_type))]
    source_call = nodes.call(callee, [nodes.ref("x")])
    expected_call = cc.call(cc.ref("f"), [cc.ref("x")])
    _assert_transform(source_call, expected_call, type_lookup=lookup)
def can_infer_type_of_function_with_no_args_and_return_annotation():
    """A zero-argument function whose signature only gives a return type
    is inferred as () -> int."""
    no_params = nodes.arguments([])
    body = [nodes.ret(nodes.int_literal(4))]
    sig = nodes.signature(returns=nodes.ref("int"))
    func_node = nodes.func("f", args=no_params, body=body, type=sig)
    assert_equal(types.func([], types.int_type), _infer_func_type(func_node))
def type_of_positional_arguments_must_match():
    """Passing an int where str is expected is a type mismatch reported on
    the offending argument node."""
    bindings = {"f": types.func([types.str_type], types.int_type)}
    int_arg = nodes.int_literal(4)
    call_node = nodes.call(nodes.ref("f"), [int_arg])
    assert_type_mismatch(
        lambda: infer(call_node, type_bindings=bindings),
        expected=types.str_type,
        actual=types.int_type,
        node=int_arg,
    )
def init_method_is_used_as_call_method_on_meta_type():
    """The meta type's __call__ takes __init__'s arguments without the
    leading Self argument, and returns the class type."""
    init_signature = nodes.signature(
        args=[
            nodes.signature_arg(nodes.ref("Self")),
            nodes.signature_arg(nodes.ref("str")),
        ],
        returns=nodes.ref("none"),
    )
    class_node = _create_class_with_init(
        signature=init_signature,
        args=nodes.args([nodes.arg("self"), nodes.arg("name")]),
        body=[],
    )
    meta = _infer_meta_type(class_node, ["__init__"])
    expected_call = types.func([types.str_type], meta.type)
    assert_equal(expected_call, meta.attrs.type_of("__call__"))
def invariant_type_parameter_can_be_unified_when_part_of_recursive_structural_type(self):
    """An invariant type parameter listed in `unify` can be matched against
    int inside a self-referential generic structural type."""
    param = types.invariant("T")
    rec = types.generic_structural_type(
        "recursive",
        [types.covariant("T")],
        lambda T: [types.attr("__iter__", types.func([], rec(T)))],
    )
    assert types.is_sub_type(
        rec(param),
        rec(types.int_type),
        unify=[param],
    )
def error_if_generic_func_is_passed_wrong_arguments():
    """Passing None where a generic function expects int raises
    ArgumentsError. The message shows both the signature and the actual
    argument types."""
    f_type = types.generic_func(["T"], lambda T: types.func([T, types.int_type], T))
    call_node = nodes.call(nodes.ref("f"), [nodes.str_literal(""), nodes.none()])
    try:
        infer(call_node, type_bindings={"f": f_type})
    except errors.ArgumentsError as error:
        assert_is(call_node, error.node)
        assert_equal("cannot call function of type: T => T, int -> T\nwith arguments: str, NoneType", str(error))
    else:
        assert False, "Expected error"
def method_signature_is_checked_when_defined_by_assignment():
    """A method bound by assignment to a zero-argument function raises
    MethodHasNoArgumentsError. The error carries the assignment node and
    the attribute name."""
    method_assign = nodes.assign([nodes.ref("is_person")], nodes.ref("f"))
    class_node = nodes.class_("User", [method_assign])
    bindings = {"f": types.func([], types.bool_type)}
    try:
        _infer_class_type(class_node, ["is_person"], type_bindings=bindings)
    except errors.MethodHasNoArgumentsError as error:
        assert_equal(method_assign, error.node)
        assert_equal("is_person", error.attr_name)
    else:
        assert False, "Expected error"
def add_method_should_only_accept_one_argument():
    """An __add__ taking two arguments cannot serve `x + y`. Inference
    raises a TypeCheckError that traces back to the addition node."""
    not_addable = types.class_type("NotAddable", {})
    not_addable.attrs.add(
        "__add__",
        types.func([types.object_type, types.object_type], not_addable),
    )
    bindings = {"x": not_addable, "y": not_addable}
    expr = nodes.add(nodes.ref("x"), nodes.ref("y"))
    try:
        infer(expr, type_bindings=bindings)
    except errors.TypeCheckError as error:
        # error.node is mapped through ephemeral.root_node before comparison,
        # presumably because the error is raised on a node derived from `expr`.
        assert_equal(expr, ephemeral.root_node(error.node))
    else:
        assert False, "Expected error"
def iter_method_must_take_no_arguments():
    """Iterating over a value whose __iter__ requires an argument raises a
    TypeCheckError that traces back to the iterated expression."""
    blah = types.class_type("Blah")
    blah.attrs.add("__iter__", types.func([types.str_type], types.iterable(types.str_type)))
    iterated_ref = nodes.ref("xs")
    loop = nodes.for_(nodes.ref("x"), iterated_ref, [])
    try:
        update_context(loop, type_bindings={"x": None, "xs": blah})
    except errors.TypeCheckError as error:
        # Compare against the root of whatever node the error was raised on.
        assert_equal(iterated_ref, ephemeral.root_node(error.node))
    else:
        assert False, "Expected error"