def test_strict(self):
  """Matching an unknown return type against list[A] yields two equalities.

  `left` returns the unknown class and `right` returns list[A], so the match
  is expected to bind the unknown to "list" and its `T` parameter to "A".
  """
  pyi = pytd_src("""
    import typing
    T = TypeVar('T')
    class list(typing.Generic[T], object): pass
    class A(): pass
    class B(A): pass
    class `~unknown0`(): pass
    a = ... # type: A
    def left() -> `~unknown0`: ...
    def right() -> list[A]: ...
  """)
  linked = self.LinkAgainstSimpleBuiltins(
      parser.parse_string(pyi, options=self.options))
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
  left_fn = linked.Lookup("left")
  right_fn = linked.Lookup("right")
  unknown0 = escape.unknown(0)
  actual = matcher.match(left_fn, right_fn, {})
  expected = booleq.And((
      booleq.Eq(unknown0, "list"),
      booleq.Eq(f"{unknown0}.list.T", "A"),
  ))
  self.assertEqual(actual, expected)
def testStrict(self):
  """Matching an unknown return type against list[A] yields two equalities.

  `left` returns `~unknown0` and `right` returns list[A]; the match must bind
  the unknown to "list" and its `T` parameter to "A".
  """
  ast = parser.parse_string(
      textwrap.dedent("""
        T = TypeVar('T')
        class list(typing.Generic[T], object): pass
        class A(): pass
        class B(A): pass
        class `~unknown0`(): pass
        a = ... # type: A
        def left() -> `~unknown0`
        def right() -> list[A]
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  left, right = ast.Lookup("left"), ast.Lookup("right")
  # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
  self.assertEqual(
      m.match(left, right, {}),
      booleq.And((booleq.Eq("~unknown0", "list"),
                  booleq.Eq("~unknown0.list.T", "A"))))
def test_strict(self):
  """An unknown return type matched against list[A] binds class and param.

  The expected result is a conjunction: the unknown equals "list" and the
  unknown's `list.T` parameter equals "A".
  """
  source = textwrap.dedent("""
    import typing
    T = TypeVar('T')
    class list(typing.Generic[T], object): pass
    class A(): pass
    class B(A): pass
    class `~unknown0`(): pass
    a = ... # type: A
    def left() -> `~unknown0`
    def right() -> list[A]
  """)
  parsed = parser.parse_string(source, python_version=self.python_version)
  linked = self.LinkAgainstSimpleBuiltins(parsed)
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
  left_fn = linked.Lookup("left")
  right_fn = linked.Lookup("right")
  actual = matcher.match(left_fn, right_fn, {})
  expected = booleq.And((
      booleq.Eq("~unknown0", "list"),
      booleq.Eq("~unknown0.list.T", "A"),
  ))
  self.assertEqual(actual, expected)
def testCallableWithArguments(self):
  """Callable parameters match contravariantly, return types covariantly.

  Each expectation row is (left, right, expected result of
  match_Generic_against_Generic(left, right)); rows are checked in order.
  """
  ast = self.ParseWithBuiltins("""
    from typing import Callable
    v1 = ... # type: Callable[[int], int]
    v2 = ... # type: Callable[[bool], int]
    v3 = ... # type: Callable[[int], bool]
    v4 = ... # type: Callable[[int, str], int]
    v5 = ... # type: Callable[[bool, str], int]
    v6 = ... # type: Callable[[], int]
  """)
  v1, v2, v3, v4, v5, v6 = [
      ast.Lookup(name).type
      for name in ("v1", "v2", "v3", "v4", "v5", "v6")]
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  expectations = [
      # Argument types are contravariant.
      (v1, v2, booleq.TRUE),
      (v2, v1, booleq.FALSE),
      (v1, v4, booleq.FALSE),
      (v4, v1, booleq.FALSE),
      (v4, v5, booleq.TRUE),
      (v5, v4, booleq.FALSE),
      (v1, v6, booleq.FALSE),
      (v6, v1, booleq.FALSE),
      (v6, v6, booleq.TRUE),
      # Return type is covariant.
      (v1, v3, booleq.FALSE),
      (v3, v1, booleq.TRUE),
  ]
  for left, right, expected in expectations:
    self.assertEqual(
        matcher.match_Generic_against_Generic(left, right, {}), expected)
def test_heterogeneous_tuple(self):
  """Fixed-length tuples: check every ordered pair of three tuple types.

  Tuples of different arity never match; same-arity tuples match when each
  left element matches the corresponding right element.
  """
  ast = self.ParseWithBuiltins("""
    from typing import Tuple
    x1 = ... # type: Tuple[int]
    x2 = ... # type: Tuple[bool, str]
    x3 = ... # type: Tuple[int, str]
  """)
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  names = ("x1", "x2", "x3")
  types = {name: ast.Lookup(name).type for name in names}
  # (left, right) pairs expected to match; every other pair must be FALSE.
  matching_pairs = {("x1", "x1"), ("x2", "x2"), ("x2", "x3"), ("x3", "x3")}
  for left in names:
    for right in names:
      expected = (
          booleq.TRUE if (left, right) in matching_pairs else booleq.FALSE)
      self.assertEqual(
          matcher.match_Generic_against_Generic(
              types[left], types[right], {}),
          expected)
def testCallableAndType(self):
  """Check Callable[...]/Callable[[]] versus Type[...] matching both ways.

  Each row is (left, right, expected result of
  match_Generic_against_Generic(left, right)).
  """
  ast = self.ParseWithBuiltins("""
    from typing import Callable, Type
    v1 = ... # type: Callable[..., int]
    v2 = ... # type: Callable[..., bool]
    v3 = ... # type: Callable[[], int]
    v4 = ... # type: Callable[[], bool]
    v5 = ... # type: Type[int]
    v6 = ... # type: Type[bool]
  """)
  v1, v2, v3, v4, v5, v6 = [
      ast.Lookup(name).type
      for name in ("v1", "v2", "v3", "v4", "v5", "v6")]
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  rows = [
      (v1, v6, booleq.FALSE),
      (v6, v1, booleq.TRUE),
      (v2, v5, booleq.TRUE),
      (v5, v2, booleq.FALSE),
      (v3, v6, booleq.FALSE),
      (v6, v3, booleq.TRUE),
      (v4, v5, booleq.TRUE),
      (v5, v4, booleq.FALSE),
  ]
  for left, right, expected in rows:
    self.assertEqual(
        matcher.match_Generic_against_Generic(left, right, {}), expected)
def testCallableNoArguments(self):
  """Callable[..., R] matching follows the covariance of the return type R."""
  ast = self.ParseWithBuiltins("""
    from typing import Callable
    v1 = ... # type: Callable[..., int]
    v2 = ... # type: Callable[..., bool]
  """)
  returns_int = ast.Lookup("v1").type
  returns_bool = ast.Lookup("v2").type
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  # Return type is covariant: the bool-returning callable matches the
  # int-returning one, but not the other way around.
  self.assertEqual(
      matcher.match_Generic_against_Generic(returns_int, returns_bool, {}),
      booleq.FALSE)
  self.assertEqual(
      matcher.match_Generic_against_Generic(returns_bool, returns_int, {}),
      booleq.TRUE)
def testFunctionAgainstTupleSubclass(self):
  """A function must not match a class that subclasses Tuple[int, str]."""
  ast = self.ParseWithBuiltins("""
    from typing import Tuple
    class A(Tuple[int, str]): ...
    def f(x): ...
  """)
  tuple_subclass = ast.Lookup("A")
  func = ast.Lookup("f")
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  # Smoke test for the TupleType logic in match_Function_against_Class.
  result = matcher.match_Function_against_Class(func, tuple_subclass, {}, {})
  self.assertEqual(result, booleq.FALSE)
def testExternal(self):
  """An ExternalType wrapping Foo matches the Base-typed constant `base`.

  `mod1.Foo` is an ExternalType whose `cls` is the local Foo (a Base
  subclass), so matching it against Base must succeed unconditionally.
  """
  ast = parser.parse_string(textwrap.dedent("""
    class Base(): pass
    class Foo(Base): pass
    base = ... # type: Base
  """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  mod1_foo = pytd.ExternalType("Foo", module="mod1", cls=ast.Lookup("Foo"))
  eq = m.match_type_against_type(mod1_foo, ast.Lookup("base").type, {})
  # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
  self.assertEqual(eq, booleq.TRUE)
def testBaseClass(self):
  """A class mirroring Base's method matches Foo, which inherits it.

  `Match.f` has the same signature as `Base.f`, and Foo gets `f` from Base,
  so matching Match against Foo is expected to be TRUE.
  """
  source = textwrap.dedent("""
    class Base():
      def f(self, x:Base) -> Base
    class Foo(Base): pass
    class Match():
      def f(self, x:Base) -> Base
  """)
  linked = self.LinkAgainstSimpleBuiltins(
      parser.parse_string(source, python_version=self.PYTHON_VERSION))
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
  result = matcher.match_Class_against_Class(
      linked.Lookup("Match"), linked.Lookup("Foo"), {})
  self.assertEqual(result, booleq.TRUE)
def testBaseClass(self):
  """A class mirroring Base's method matches Foo, which inherits it.

  `Match.f` has the same signature as `Base.f`, and Foo gets `f` from Base,
  so matching Match against Foo must be TRUE.
  """
  ast = parser.parse_string(textwrap.dedent("""
    class Base():
      def f(self, x:Base) -> Base
    class Foo(Base): pass
    class Match():
      def f(self, x:Base) -> Base
  """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  eq = m.match_Class_against_Class(ast.Lookup("Match"), ast.Lookup("Foo"), {})
  # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
  self.assertEqual(eq, booleq.TRUE)
def solve(self):
  """Solve the equations generated from the pytd.

  Unknown classes become solver variables and are matched against every
  complete class (from the AST and from the builtins). Partial classes and
  partial functions are matched only against the complete definition whose
  name equals their unpacked partial name.

  Returns:
    A dictionary mapping unknown class names to known class names, as
    produced by the solver.

  Raises:
    AssertionError: If we detect an internal error.
  """
  factory = type_match.TypeMatch(
      type_match.get_all_subclasses([self.ast, self.builtins]))
  solver = factory.solver

  # Bucket the AST's classes; unknowns are also registered as variables.
  unknown_classes, partial_classes, complete_classes = set(), set(), set()
  for cls in self.ast.classes:
    if is_unknown(cls):
      solver.register_variable(cls.name)
      unknown_classes.add(cls)
      continue
    (partial_classes if is_partial(cls) else complete_classes).add(cls)

  for complete in complete_classes.union(self.builtins.classes):
    for unknown in unknown_classes:
      self.match_unknown_against_complete(factory, solver, unknown, complete)
    # A partial class only constrains the complete class it is named after.
    for partial in partial_classes:
      if type_match.unpack_name_of_partial(partial.name) != complete.name:
        continue
      self.match_partial_against_complete(factory, solver, partial, complete)

  partial_functions, complete_functions = set(), set()
  for func in self.ast.functions:
    (partial_functions if is_partial(func) else complete_functions).add(func)

  for partial in partial_functions:
    for complete in complete_functions.union(self.builtins.functions):
      if type_match.unpack_name_of_partial(partial.name) == complete.name:
        self.match_call_record(factory, solver, partial, complete)

  log.info("=========== Equations to solve =============\n%s", solver)
  log.info("=========== Equations to solve (end) =======")
  return solver.solve()
def testSubclasses(self):
  """Function matching is one-directional across a subclass relationship.

  `left(a: B) -> B` must match `right(a: A) -> A` (B subclasses A), while the
  reverse match must not be unconditionally TRUE.
  """
  ast = parser.parse_string(textwrap.dedent("""
    class A(): pass
    class B(A): pass
    a = ... # type: A
    def left(a: B) -> B
    def right(a: A) -> A
  """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  left, right = ast.Lookup("left"), ast.Lookup("right")
  # assertEqual/assertNotEqual replace the deprecated assertEquals /
  # assertNotEquals aliases, which were removed in Python 3.12.
  self.assertEqual(m.match(left, right, {}), booleq.TRUE)
  self.assertNotEqual(m.match(right, left, {}), booleq.TRUE)
def testSubclasses(self):
  """Function matching is one-directional across a subclass relationship.

  `left(a: B) -> B` is expected to match `right(a: A) -> A` (B subclasses A),
  while the reverse match is expected not to be unconditionally TRUE.
  """
  source = textwrap.dedent("""
    class A(): pass
    class B(A): pass
    a = ... # type: A
    def left(a: B) -> B
    def right(a: A) -> A
  """)
  parsed = parser.parse_string(source, python_version=self.PYTHON_VERSION)
  resolved = visitors.LookupClasses(parsed, self.mini_builtins)
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([resolved]))
  narrow_fn = resolved.Lookup("left")
  wide_fn = resolved.Lookup("right")
  self.assertEqual(matcher.match(narrow_fn, wide_fn, {}), booleq.TRUE)
  self.assertNotEqual(matcher.match(wide_fn, narrow_fn, {}), booleq.TRUE)
def testUnknownAgainstTuple(self):
  """Matching an unknown against Tuple[int, str] constrains it to tuple.

  The extracted equalities bind the unknown to `__builtin__.tuple` and its
  `_T` parameter to each element type.
  """
  ast = self.ParseWithBuiltins("""
    from typing import Tuple
    class `~unknown0`(): pass
    x = ... # type: Tuple[int, str]
  """)
  unknown_cls = ast.Lookup("~unknown0")
  tuple_type = ast.Lookup("x").type
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  result = matcher.match_Unknown_against_Generic(unknown_cls, tuple_type, {})
  expected = [
      ("~unknown0", "__builtin__.tuple"),
      ("~unknown0.__builtin__.tuple._T", "int"),
      ("~unknown0.__builtin__.tuple._T", "str"),
  ]
  self.assertListEqual(sorted(result.extract_equalities()), expected)
def testExternal(self):
  """An ExternalType wrapping Foo matches the Base-typed constant `base`.

  `mod1.Foo` is an ExternalType whose `cls` is the local Foo (a Base
  subclass), so matching it against Base must succeed unconditionally.
  """
  ast = parser.parse_string(
      textwrap.dedent("""
        class Base(): pass
        class Foo(Base): pass
        base = ... # type: Base
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  mod1_foo = pytd.ExternalType("Foo", module="mod1", cls=ast.Lookup("Foo"))
  eq = m.match_type_against_type(mod1_foo, ast.Lookup("base").type, {})
  # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
  self.assertEqual(eq, booleq.TRUE)
def testBaseClass(self):
  """A class mirroring Base's method matches Foo, which inherits it.

  `Match.f` has the same signature as `Base.f`, and Foo gets `f` from Base,
  so matching Match against Foo must be TRUE.
  """
  ast = parser.parse_string(
      textwrap.dedent("""
        class Base():
          def f(self, x:Base) -> Base
        class Foo(Base): pass
        class Match():
          def f(self, x:Base) -> Base
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  eq = m.match_Class_against_Class(ast.Lookup("Match"), ast.Lookup("Foo"), {})
  # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
  self.assertEqual(eq, booleq.TRUE)
def test_base_class(self):
  """A class mirroring Base's method matches Foo, which inherits it.

  `Match.f` has the same signature as `Base.f`, and Foo gets `f` from Base,
  so matching Match against Foo is expected to be TRUE.
  """
  source = textwrap.dedent("""
    class Base():
      def f(self, x:Base) -> Base: ...
    class Foo(Base): pass
    class Match():
      def f(self, x:Base) -> Base: ...
  """)
  linked = self.LinkAgainstSimpleBuiltins(
      parser.parse_string(source, options=self.options))
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
  match_cls = linked.Lookup("Match")
  foo_cls = linked.Lookup("Foo")
  self.assertEqual(
      matcher.match_Class_against_Class(match_cls, foo_cls, {}), booleq.TRUE)
def test_unknown_against_tuple(self):
  """Matching an unknown against Tuple[int, str] constrains it to tuple.

  The extracted equalities bind the unknown to `__builtin__.tuple` and its
  `_T` parameter to each element type. Order is irrelevant, so the equalities
  are compared as a multiset.
  """
  ast = self.ParseWithBuiltins(pytd_src("""
    from typing import Tuple
    class `~unknown0`(): pass
    x = ... # type: Tuple[int, str]
  """))
  unknown0 = escape.unknown(0)
  unk = ast.Lookup(unknown0)
  tup = ast.Lookup("x").type
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  match = m.match_Unknown_against_Generic(unk, tup, {})
  # assertCountEqual already ignores order; sorting first was redundant.
  self.assertCountEqual(
      match.extract_equalities(),
      [(unknown0, "__builtin__.tuple"),
       (f"{unknown0}.__builtin__.tuple._T", "int"),
       (f"{unknown0}.__builtin__.tuple._T", "str")])
def solve(self):
  """Solve the equations generated from the pytd.

  Every unknown class in the AST is registered as a solver variable and
  matched against all complete classes (local and builtin). Partial classes
  and partial functions are matched against the complete definition carrying
  their unpacked name.

  Returns:
    A dictionary mapping unknown class names to known class names, as
    produced by the solver.

  Raises:
    AssertionError: If we detect an internal error.
  """
  class_hierarchy = type_match.get_all_subclasses([self.ast, self.builtins])
  factory = type_match.TypeMatch(class_hierarchy)
  solver = factory.solver

  unknowns = set()
  partials = set()
  completes = set()
  for node in self.ast.classes:
    if is_unknown(node):
      solver.register_variable(node.name)
      unknowns.add(node)
    elif is_partial(node):
      partials.add(node)
    else:
      completes.add(node)

  for target in completes.union(self.builtins.classes):
    for unknown in unknowns:
      self.match_unknown_against_complete(factory, solver, unknown, target)
    for partial in partials:
      same_name = (
          type_match.unpack_name_of_partial(partial.name) == target.name)
      if same_name:
        self.match_partial_against_complete(factory, solver, partial, target)

  partial_funcs = set()
  complete_funcs = set()
  for func in self.ast.functions:
    bucket = partial_funcs if is_partial(func) else complete_funcs
    bucket.add(func)

  for partial in partial_funcs:
    for target in complete_funcs.union(self.builtins.functions):
      if type_match.unpack_name_of_partial(partial.name) == target.name:
        self.match_call_record(factory, solver, partial, target)

  log.info("=========== Equations to solve =============\n%s", solver)
  log.info("=========== Equations to solve (end) =======")
  return solver.solve()
def testTuple(self):
  """Homogeneous Tuple[T, ...] versus fixed-length Tuple[U, V] matching.

  Each case is (left, right, expected result of
  match_Generic_against_Generic(left, right)).
  """
  ast = self.ParseWithBuiltins("""
    from typing import Tuple
    x1 = ... # type: Tuple[bool, ...]
    x2 = ... # type: Tuple[int, ...]
    y1 = ... # type: Tuple[bool, int]
  """)
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  x1, x2, y1 = [ast.Lookup(name).type for name in ("x1", "x2", "y1")]
  cases = (
      # Tuple[T, ...] matches Tuple[U, V] when T matches both U and V.
      (x1, y1, booleq.TRUE),
      (x2, y1, booleq.FALSE),
      # Tuple[U, V] matches Tuple[T, ...] when Union[U, V] matches T.
      (y1, x1, booleq.FALSE),
      (y1, x2, booleq.TRUE),
  )
  for left, right, expected in cases:
    self.assertEqual(
        matcher.match_Generic_against_Generic(left, right, {}), expected)
def testStrict(self):
  """Matching an unknown return type against list<A> yields two equalities.

  `left` returns `~unknown0` and `right` returns list<A>; the match must bind
  the unknown to "list" and its `T` parameter to "A".
  """
  ast = parser.parse_string(textwrap.dedent("""
    class list<T>(nothing): pass
    class A(nothing): pass
    class B(A): pass
    class `~unknown0`(nothing): pass
    a : A
    def left() -> `~unknown0`
    def right() -> list<A>
  """))
  ast = visitors.LookupClasses(ast)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  left, right = ast.Lookup("left"), ast.Lookup("right")
  # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
  self.assertEqual(m.match(left, right, {}),
                   booleq.And((booleq.Eq("~unknown0", "list"),
                               booleq.Eq("~unknown0.list.T", "A"))))
def testStrict(self):
  """Matching an unknown return type against list[A] yields two equalities.

  `left` returns `~unknown0` and `right` returns list[A]; the match must bind
  the unknown to "list" and its `T` parameter to "A".
  """
  ast = parser.parse_string(textwrap.dedent("""
    T = TypeVar('T')
    class list(typing.Generic[T], object): pass
    class A(): pass
    class B(A): pass
    class `~unknown0`(): pass
    a = ... # type: A
    def left() -> `~unknown0`
    def right() -> list[A]
  """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  left, right = ast.Lookup("left"), ast.Lookup("right")
  # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
  self.assertEqual(m.match(left, right, {}),
                   booleq.And((booleq.Eq("~unknown0", "list"),
                               booleq.Eq("~unknown0.list.T", "A"))))
def solve(self):
  """Solve the equations generated from the pytd.

  Two independent equation systems are built, each with its own TypeMatch
  factory and solver:
    * protocol equations: every unknown class is matched against each
      protocol class, and against each protocol alias (other than
      "protocols.Protocol" or an Any alias) via the alias's class;
    * call-trace equations: partial classes and partial functions are
      matched against the complete definition whose name equals their
      unpacked partial name.
  The two solutions are then merged per unknown.

  Returns:
    A dictionary mapping each unknown class name to the solver's solution
    for it, a set of known class names. When both systems solved an unknown,
    the two sets are unioned and "?" (Any) is removed from the result.

  Raises:
    AssertionError: If we detect an internal error.
  """
  hierarchy = type_match.get_all_subclasses([self.ast, self.builtins])
  factory_protocols = type_match.TypeMatch(hierarchy)
  factory_partial = type_match.TypeMatch(hierarchy)
  solver_protocols = factory_protocols.solver
  solver_partial = factory_partial.solver

  unknown_classes = set()
  partial_classes = set()
  complete_classes = set()
  for cls in self.ast.classes:
    if is_unknown(cls):
      # Unknowns are variables in both equation systems.
      solver_protocols.register_variable(cls.name)
      solver_partial.register_variable(cls.name)
      unknown_classes.add(cls)
    elif is_partial(cls):
      partial_classes.add(cls)
    else:
      complete_classes.add(cls)

  protocol_classes_and_aliases = set(self.protocols.classes)
  for alias in self.protocols.aliases:
    if (not isinstance(alias.type, pytd.AnythingType) and
        alias.name != "protocols.Protocol"):
      protocol_classes_and_aliases.add(alias.type.cls)

  # solve equations from protocols first
  for protocol in protocol_classes_and_aliases:
    for unknown in unknown_classes:
      self.match_unknown_against_protocol(
          factory_protocols, solver_protocols, unknown, protocol)

  # also solve partial equations
  for complete in complete_classes.union(self.builtins.classes):
    for partial in partial_classes:
      if type_match.unpack_name_of_partial(partial.name) == complete.name:
        self.match_partial_against_complete(
            factory_partial, solver_partial, partial, complete)

  partial_functions = set()
  complete_functions = set()
  for f in self.ast.functions:
    if is_partial(f):
      partial_functions.add(f)
    else:
      complete_functions.add(f)
  for partial in partial_functions:
    for complete in complete_functions.union(self.builtins.functions):
      if type_match.unpack_name_of_partial(partial.name) == complete.name:
        self.match_call_record(
            factory_partial, solver_partial, partial, complete)

  log.info("=========== Equations to solve =============\n%s",
           solver_protocols)
  log.info("=========== Equations to solve (end) =======")
  solved_protocols = solver_protocols.solve()
  log.info("=========== Call trace equations to solve =============\n%s",
           solver_partial)
  log.info("=========== Call trace equations to solve (end) =======")
  solved_partial = solver_partial.solve()

  # Merge each unknown exactly once. Protocol unknowns come first, followed
  # by unknowns solved only from call traces; this is the same key order the
  # previous chain()-based loop produced, without recomputing the union for
  # keys present in both solutions.
  merged_solution = {}
  for unknown in solved_protocols:
    if unknown in solved_partial:
      combined = solved_protocols[unknown].union(solved_partial[unknown])
      # remove Any from set if present
      # if no restrictions are present, it will be labeled Any later
      # otherwise, Any will override other restrictions that were found
      combined.discard("?")
      merged_solution[unknown] = combined
    else:
      merged_solution[unknown] = solved_protocols[unknown]
  for unknown in solved_partial:
    if unknown not in solved_protocols:
      merged_solution[unknown] = solved_partial[unknown]
  return merged_solution