Exemplo n.º 1
0
    def test_strict(self):
        """Match a function returning an unknown against one returning list[A].

        The expected result is the conjunction of two equalities: the unknown
        equals "list", and the unknown's "list.T" parameter equals "A".
        """
        source = pytd_src("""
      import typing

      T = TypeVar('T')
      class list(typing.Generic[T], object):
        pass
      class A():
        pass
      class B(A):
        pass
      class `~unknown0`():
        pass
      a = ...  # type: A
      def left() -> `~unknown0`: ...
      def right() -> list[A]: ...
    """)
        parsed = parser.parse_string(source, options=self.options)
        linked = self.LinkAgainstSimpleBuiltins(parsed)
        matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
        left = linked.Lookup("left")
        right = linked.Lookup("right")
        unknown0 = escape.unknown(0)
        # Compute the actual result before building the expected equation.
        actual = matcher.match(left, right, {})
        expected = booleq.And((
            booleq.Eq(unknown0, "list"),
            booleq.Eq(f"{unknown0}.list.T", "A"),
        ))
        self.assertEqual(actual, expected)
Exemplo n.º 2
0
    def testStrict(self):
        """Match a function returning an unknown against one returning list[A].

        The expected result is the conjunction of two equalities:
        "~unknown0" == "list" and "~unknown0.list.T" == "A".
        """
        ast = parser.parse_string(
            textwrap.dedent("""

      T = TypeVar('T')
      class list(typing.Generic[T], object):
        pass
      class A():
        pass
      class B(A):
        pass
      class `~unknown0`():
        pass
      a = ...  # type: A
      def left() -> `~unknown0`
      def right() -> list[A]
    """))
        ast = visitors.LookupClasses(ast, self.mini_builtins)
        m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
        left, right = ast.Lookup("left"), ast.Lookup("right")
        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(
            m.match(left, right, {}),
            booleq.And((booleq.Eq("~unknown0", "list"),
                        booleq.Eq("~unknown0.list.T", "A"))))
Exemplo n.º 3
0
    def test_strict(self):
        """Match a function returning an unknown against one returning list[A].

        The expected result is the conjunction of two equalities:
        "~unknown0" == "list" and "~unknown0.list.T" == "A".
        """
        source = textwrap.dedent("""
      import typing

      T = TypeVar('T')
      class list(typing.Generic[T], object):
        pass
      class A():
        pass
      class B(A):
        pass
      class `~unknown0`():
        pass
      a = ...  # type: A
      def left() -> `~unknown0`
      def right() -> list[A]
    """)
        parsed = parser.parse_string(source,
                                     python_version=self.python_version)
        linked = self.LinkAgainstSimpleBuiltins(parsed)
        matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
        left = linked.Lookup("left")
        right = linked.Lookup("right")
        # Compute the actual result before building the expected equation.
        actual = matcher.match(left, right, {})
        expected = booleq.And((
            booleq.Eq("~unknown0", "list"),
            booleq.Eq("~unknown0.list.T", "A"),
        ))
        self.assertEqual(actual, expected)
Exemplo n.º 4
0
 def testCallableWithArguments(self):
   """Check Callable-vs-Callable matching when argument lists are given.

   Each entry of the table below is (left, right, expected result of
   match_Generic_against_Generic(left, right, {})).
   """
   ast = self.ParseWithBuiltins("""
     from typing import Callable
     v1 = ...  # type: Callable[[int], int]
     v2 = ...  # type: Callable[[bool], int]
     v3 = ...  # type: Callable[[int], bool]
     v4 = ...  # type: Callable[[int, str], int]
     v5 = ...  # type: Callable[[bool, str], int]
     v6 = ...  # type: Callable[[], int]
   """)
   v1, v2, v3, v4, v5, v6 = (
       ast.Lookup(name).type for name in ("v1", "v2", "v3", "v4", "v5", "v6"))
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   expectations = (
       # Argument types are contravariant.
       (v1, v2, booleq.TRUE),
       (v2, v1, booleq.FALSE),
       (v1, v4, booleq.FALSE),
       (v4, v1, booleq.FALSE),
       (v4, v5, booleq.TRUE),
       (v5, v4, booleq.FALSE),
       (v1, v6, booleq.FALSE),
       (v6, v1, booleq.FALSE),
       (v6, v6, booleq.TRUE),
       # Return type is covariant.
       (v1, v3, booleq.FALSE),
       (v3, v1, booleq.TRUE),
   )
   for left, right, expected in expectations:
     self.assertEqual(
         matcher.match_Generic_against_Generic(left, right, {}), expected)
Exemplo n.º 5
0
 def test_heterogeneous_tuple(self):
     """Check every pairing of three fixed-length Tuple types.

     The expectation matrix is indexed [left][right] over (x1, x2, x3) and is
     walked row by row, calling match_Generic_against_Generic(left, right, {}).
     """
     ast = self.ParseWithBuiltins("""
   from typing import Tuple
   x1 = ...  # type: Tuple[int]
   x2 = ...  # type: Tuple[bool, str]
   x3 = ...  # type: Tuple[int, str]
 """)
     matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
     operands = tuple(ast.Lookup(name).type for name in ("x1", "x2", "x3"))
     expected_matrix = (
         (booleq.TRUE, booleq.FALSE, booleq.FALSE),   # x1 against x1, x2, x3
         (booleq.FALSE, booleq.TRUE, booleq.TRUE),    # x2 against x1, x2, x3
         (booleq.FALSE, booleq.FALSE, booleq.TRUE),   # x3 against x1, x2, x3
     )
     for left, expected_row in zip(operands, expected_matrix):
         for right, expected in zip(operands, expected_row):
             self.assertEqual(
                 matcher.match_Generic_against_Generic(left, right, {}),
                 expected)
Exemplo n.º 6
0
 def testCallableAndType(self):
   """Check matching between Callable[...] and Type[...] generics.

   Each entry of the table below is (left, right, expected result of
   match_Generic_against_Generic(left, right, {})).
   """
   ast = self.ParseWithBuiltins("""
     from typing import Callable, Type
     v1 = ...  # type: Callable[..., int]
     v2 = ...  # type: Callable[..., bool]
     v3 = ...  # type: Callable[[], int]
     v4 = ...  # type: Callable[[], bool]
     v5 = ...  # type: Type[int]
     v6 = ...  # type: Type[bool]
   """)
   v1, v2, v3, v4, v5, v6 = (
       ast.Lookup(name).type for name in ("v1", "v2", "v3", "v4", "v5", "v6"))
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   expectations = (
       (v1, v6, booleq.FALSE),
       (v6, v1, booleq.TRUE),
       (v2, v5, booleq.TRUE),
       (v5, v2, booleq.FALSE),
       (v3, v6, booleq.FALSE),
       (v6, v3, booleq.TRUE),
       (v4, v5, booleq.TRUE),
       (v5, v4, booleq.FALSE),
   )
   for left, right, expected in expectations:
     self.assertEqual(
         matcher.match_Generic_against_Generic(left, right, {}), expected)
Exemplo n.º 7
0
 def testCallableNoArguments(self):
   """Check Callable[..., X] matching, where only the return type differs."""
   ast = self.ParseWithBuiltins("""
     from typing import Callable
     v1 = ...  # type: Callable[..., int]
     v2 = ...  # type: Callable[..., bool]
   """)
   returns_int = ast.Lookup("v1").type
   returns_bool = ast.Lookup("v2").type
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   # Return type is covariant.
   int_vs_bool = matcher.match_Generic_against_Generic(
       returns_int, returns_bool, {})
   self.assertEqual(int_vs_bool, booleq.FALSE)
   bool_vs_int = matcher.match_Generic_against_Generic(
       returns_bool, returns_int, {})
   self.assertEqual(bool_vs_int, booleq.TRUE)
Exemplo n.º 8
0
 def testFunctionAgainstTupleSubclass(self):
   """Match a function against a class that derives from Tuple[int, str]."""
   ast = self.ParseWithBuiltins("""
     from typing import Tuple
     class A(Tuple[int, str]): ...
     def f(x): ...
   """)
   tuple_subclass = ast.Lookup("A")
   function = ast.Lookup("f")
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   # Smoke test for the TupleType logic in match_Function_against_Class
   result = matcher.match_Function_against_Class(
       function, tuple_subclass, {}, {})
   self.assertEqual(result, booleq.FALSE)
Exemplo n.º 9
0
 def testExternal(self):
   """An ExternalType wrapping Foo matches a variable of its base type Base."""
   ast = parser.parse_string(textwrap.dedent("""
     class Base():
       pass
     class Foo(Base):
       pass
     base = ...  # type: Base
   """))
   ast = visitors.LookupClasses(ast, self.mini_builtins)
   m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   mod1_foo = pytd.ExternalType("Foo", module="mod1", cls=ast.Lookup("Foo"))
   eq = m.match_type_against_type(mod1_foo, ast.Lookup("base").type, {})
   # assertEqual replaces the deprecated assertEquals alias, which was
   # removed from unittest in Python 3.12.
   self.assertEqual(eq, booleq.TRUE)
Exemplo n.º 10
0
  def testBaseClass(self):
    """Match a class against one that only inherits the same method.

    Match defines f itself; Foo has an empty body and derives from Base, which
    defines f with the same signature. The expected result is booleq.TRUE.
    """
    source = textwrap.dedent("""
      class Base():
        def f(self, x:Base) -> Base
      class Foo(Base):
        pass

      class Match():
        def f(self, x:Base) -> Base
    """)
    parsed = parser.parse_string(source, python_version=self.PYTHON_VERSION)
    linked = self.LinkAgainstSimpleBuiltins(parsed)
    matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
    match_cls = linked.Lookup("Match")
    foo_cls = linked.Lookup("Foo")
    result = matcher.match_Class_against_Class(match_cls, foo_cls, {})
    self.assertEqual(result, booleq.TRUE)
Exemplo n.º 11
0
  def testBaseClass(self):
    """Match a class against one that only inherits the same method.

    Match defines f itself; Foo has an empty body and derives from Base, which
    defines f with the same signature. The expected result is booleq.TRUE.
    """
    ast = parser.parse_string(textwrap.dedent("""
      class Base():
        def f(self, x:Base) -> Base
      class Foo(Base):
        pass

      class Match():
        def f(self, x:Base) -> Base
    """))
    ast = visitors.LookupClasses(ast, self.mini_builtins)
    m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
    eq = m.match_Class_against_Class(ast.Lookup("Match"), ast.Lookup("Foo"), {})
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(eq, booleq.TRUE)
Exemplo n.º 12
0
    def solve(self):
        """Solve the equations generated from the pytd.

        Classes of ``self.ast`` are sorted into unknown, partial and complete.
        Every unknown class is registered as a solver variable and matched
        against every complete class (from ``self.ast`` and ``self.builtins``).
        Partial classes and partial functions are matched only against the
        complete definition whose name equals their unpacked partial name.

        Returns:
          The result of ``solver.solve()``; documented as a dictionary
          (str->str), mapping unknown class names to known class names.
        Raises:
          AssertionError: If we detect an internal error.
        """
        hierarchy = type_match.get_all_subclasses([self.ast, self.builtins])
        factory = type_match.TypeMatch(hierarchy)
        solver = factory.solver

        unknown_classes = set()
        partial_classes = set()
        complete_classes = set()
        for cls in self.ast.classes:
            if is_unknown(cls):
                solver.register_variable(cls.name)
                unknown_classes.add(cls)
            elif is_partial(cls):
                partial_classes.add(cls)
            else:
                complete_classes.add(cls)

        for complete in complete_classes.union(self.builtins.classes):
            for unknown in unknown_classes:
                self.match_unknown_against_complete(factory, solver, unknown,
                                                    complete)
            for partial in partial_classes:
                if type_match.unpack_name_of_partial(
                        partial.name) == complete.name:
                    self.match_partial_against_complete(
                        factory, solver, partial, complete)

        partial_functions = set()
        complete_functions = set()
        for f in self.ast.functions:
            if is_partial(f):
                partial_functions.add(f)
            else:
                complete_functions.add(f)

        # The candidate set does not depend on the partial function being
        # examined, so build it once instead of once per partial function.
        candidate_functions = complete_functions.union(self.builtins.functions)
        for partial in partial_functions:
            for complete in candidate_functions:
                if type_match.unpack_name_of_partial(
                        partial.name) == complete.name:
                    self.match_call_record(factory, solver, partial, complete)

        log.info("=========== Equations to solve =============\n%s", solver)
        log.info("=========== Equations to solve (end) =======")
        return solver.solve()
Exemplo n.º 13
0
 def testSubclasses(self):
   """Matching left(a: B) -> B against right(a: A) -> A is one-directional.

   B derives from A. Matching left against right must be booleq.TRUE, while
   matching right against left must not be booleq.TRUE.
   """
   ast = parser.parse_string(textwrap.dedent("""
     class A():
       pass
     class B(A):
       pass
     a = ...  # type: A
     def left(a: B) -> B
     def right(a: A) -> A
   """))
   ast = visitors.LookupClasses(ast, self.mini_builtins)
   m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   left, right = ast.Lookup("left"), ast.Lookup("right")
   # assertEqual / assertNotEqual replace the deprecated assertEquals /
   # assertNotEquals aliases, which were removed from unittest in Python 3.12.
   self.assertEqual(m.match(left, right, {}), booleq.TRUE)
   self.assertNotEqual(m.match(right, left, {}), booleq.TRUE)
Exemplo n.º 14
0
 def testSubclasses(self):
   """Matching left(a: B) -> B against right(a: A) -> A is one-directional.

   B derives from A. Matching left against right must be booleq.TRUE, while
   matching right against left must not be booleq.TRUE.
   """
   source = textwrap.dedent("""
     class A():
       pass
     class B(A):
       pass
     a = ...  # type: A
     def left(a: B) -> B
     def right(a: A) -> A
   """)
   parsed = parser.parse_string(source, python_version=self.PYTHON_VERSION)
   resolved = visitors.LookupClasses(parsed, self.mini_builtins)
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([resolved]))
   left = resolved.Lookup("left")
   right = resolved.Lookup("right")
   forward = matcher.match(left, right, {})
   self.assertEqual(forward, booleq.TRUE)
   backward = matcher.match(right, left, {})
   self.assertNotEqual(backward, booleq.TRUE)
Exemplo n.º 15
0
 def testUnknownAgainstTuple(self):
   """Match an unknown class against Tuple[int, str] and check the equalities.

   The sorted equalities extracted from the match must bind the unknown to
   "__builtin__.tuple" and its "_T" parameter to both "int" and "str".
   """
   ast = self.ParseWithBuiltins("""
     from typing import Tuple
     class `~unknown0`():
       pass
     x = ...  # type: Tuple[int, str]
   """)
   unknown_cls = ast.Lookup("~unknown0")
   tuple_type = ast.Lookup("x").type
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   result = matcher.match_Unknown_against_Generic(unknown_cls, tuple_type, {})
   actual = sorted(result.extract_equalities())
   expected = [
       ("~unknown0", "__builtin__.tuple"),
       ("~unknown0.__builtin__.tuple._T", "int"),
       ("~unknown0.__builtin__.tuple._T", "str"),
   ]
   self.assertListEqual(actual, expected)
Exemplo n.º 16
0
 def testExternal(self):
     """An ExternalType wrapping Foo matches a variable of its base type Base."""
     ast = parser.parse_string(
         textwrap.dedent("""
   class Base():
     pass
   class Foo(Base):
     pass
   base = ...  # type: Base
 """))
     ast = visitors.LookupClasses(ast, self.mini_builtins)
     m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
     mod1_foo = pytd.ExternalType("Foo",
                                  module="mod1",
                                  cls=ast.Lookup("Foo"))
     eq = m.match_type_against_type(mod1_foo, ast.Lookup("base").type, {})
     # assertEqual replaces the deprecated assertEquals alias, which was
     # removed from unittest in Python 3.12.
     self.assertEqual(eq, booleq.TRUE)
Exemplo n.º 17
0
    def testBaseClass(self):
        """Match a class against one that only inherits the same method.

        Match defines f itself; Foo has an empty body and derives from Base,
        which defines f with the same signature. The expected result is
        booleq.TRUE.
        """
        ast = parser.parse_string(
            textwrap.dedent("""
      class Base():
        def f(self, x:Base) -> Base
      class Foo(Base):
        pass

      class Match():
        def f(self, x:Base) -> Base
    """))
        ast = visitors.LookupClasses(ast, self.mini_builtins)
        m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
        eq = m.match_Class_against_Class(ast.Lookup("Match"),
                                         ast.Lookup("Foo"), {})
        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(eq, booleq.TRUE)
Exemplo n.º 18
0
    def test_base_class(self):
        """Match a class against one that only inherits the same method.

        Match defines f itself; Foo has an empty body and derives from Base,
        which defines f with the same signature. The expected result is
        booleq.TRUE.
        """
        source = textwrap.dedent("""
      class Base():
        def f(self, x:Base) -> Base: ...
      class Foo(Base):
        pass

      class Match():
        def f(self, x:Base) -> Base: ...
    """)
        parsed = parser.parse_string(source, options=self.options)
        linked = self.LinkAgainstSimpleBuiltins(parsed)
        matcher = type_match.TypeMatch(type_match.get_all_subclasses([linked]))
        match_cls = linked.Lookup("Match")
        foo_cls = linked.Lookup("Foo")
        result = matcher.match_Class_against_Class(match_cls, foo_cls, {})
        self.assertEqual(result, booleq.TRUE)
Exemplo n.º 19
0
 def test_unknown_against_tuple(self):
   """Match an unknown class against Tuple[int, str] and check the equalities.

   The extracted equalities must bind the unknown to "__builtin__.tuple" and
   its "_T" parameter to both "int" and "str".
   """
   ast = self.ParseWithBuiltins(pytd_src("""
     from typing import Tuple
     class `~unknown0`():
       pass
     x = ...  # type: Tuple[int, str]
   """))
   unknown0 = escape.unknown(0)
   unknown_cls = ast.Lookup(unknown0)
   tuple_type = ast.Lookup("x").type
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   result = matcher.match_Unknown_against_Generic(unknown_cls, tuple_type, {})
   actual = sorted(result.extract_equalities())
   tuple_param = f"{unknown0}.__builtin__.tuple._T"
   expected = [
       (unknown0, "__builtin__.tuple"),
       (tuple_param, "int"),
       (tuple_param, "str"),
   ]
   self.assertCountEqual(actual, expected)
Exemplo n.º 20
0
  def solve(self):
    """Solve the equations generated from the pytd.

    Classes of ``self.ast`` are sorted into unknown, partial and complete.
    Every unknown class is registered as a solver variable and matched against
    every complete class (from ``self.ast`` and ``self.builtins``). Partial
    classes and partial functions are matched only against the complete
    definition whose name equals their unpacked partial name.

    Returns:
      The result of ``solver.solve()``; documented as a dictionary (str->str),
      mapping unknown class names to known class names.
    Raises:
      AssertionError: If we detect an internal error.
    """
    hierarchy = type_match.get_all_subclasses([self.ast, self.builtins])
    factory = type_match.TypeMatch(hierarchy)
    solver = factory.solver

    unknown_classes = set()
    partial_classes = set()
    complete_classes = set()
    for cls in self.ast.classes:
      if is_unknown(cls):
        solver.register_variable(cls.name)
        unknown_classes.add(cls)
      elif is_partial(cls):
        partial_classes.add(cls)
      else:
        complete_classes.add(cls)

    for complete in complete_classes.union(self.builtins.classes):
      for unknown in unknown_classes:
        self.match_unknown_against_complete(factory, solver, unknown, complete)
      for partial in partial_classes:
        if type_match.unpack_name_of_partial(partial.name) == complete.name:
          self.match_partial_against_complete(
              factory, solver, partial, complete)

    partial_functions = set()
    complete_functions = set()
    for f in self.ast.functions:
      if is_partial(f):
        partial_functions.add(f)
      else:
        complete_functions.add(f)

    # The candidate set does not depend on the partial function being
    # examined, so build it once instead of once per partial function.
    candidate_functions = complete_functions.union(self.builtins.functions)
    for partial in partial_functions:
      for complete in candidate_functions:
        if type_match.unpack_name_of_partial(partial.name) == complete.name:
          self.match_call_record(factory, solver, partial, complete)

    log.info("=========== Equations to solve =============\n%s", solver)
    log.info("=========== Equations to solve (end) =======")
    return solver.solve()
Exemplo n.º 21
0
 def testTuple(self):
   """Check matching between homogeneous and fixed-length Tuple types."""
   ast = self.ParseWithBuiltins("""
     from typing import Tuple
     x1 = ...  # type: Tuple[bool, ...]
     x2 = ...  # type: Tuple[int, ...]
     y1 = ...  # type: Tuple[bool, int]
   """)
   matcher = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   x1, x2, y1 = (ast.Lookup(name).type for name in ("x1", "x2", "y1"))
   expectations = (
       # Tuple[T, ...] matches Tuple[U, V] when T matches both U and V.
       (x1, y1, booleq.TRUE),
       (x2, y1, booleq.FALSE),
       # Tuple[U, V] matches Tuple[T, ...] when Union[U, V] matches T.
       (y1, x1, booleq.FALSE),
       (y1, x2, booleq.TRUE),
   )
   for left, right, expected in expectations:
     self.assertEqual(
         matcher.match_Generic_against_Generic(left, right, {}), expected)
Exemplo n.º 22
0
 def testStrict(self):
   """Match a function returning an unknown against one returning list<A>.

   The expected result is the conjunction of two equalities:
   "~unknown0" == "list" and "~unknown0.list.T" == "A".
   """
   ast = parser.parse_string(textwrap.dedent("""
     class list<T>(nothing):
       pass
     class A(nothing):
       pass
     class B(A):
       pass
     class `~unknown0`(nothing):
       pass
     a : A
     def left() -> `~unknown0`
     def right() -> list<A>
   """))
   ast = visitors.LookupClasses(ast)
   m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
   left, right = ast.Lookup("left"), ast.Lookup("right")
   # assertEqual replaces the deprecated assertEquals alias, which was
   # removed from unittest in Python 3.12.
   self.assertEqual(m.match(left, right, {}),
                    booleq.And((booleq.Eq("~unknown0", "list"),
                                booleq.Eq("~unknown0.list.T", "A"))))
Exemplo n.º 23
0
  def testStrict(self):
    """Match a function returning an unknown against one returning list[A].

    The expected result is the conjunction of two equalities:
    "~unknown0" == "list" and "~unknown0.list.T" == "A".
    """
    ast = parser.parse_string(textwrap.dedent("""

      T = TypeVar('T')
      class list(typing.Generic[T], object):
        pass
      class A():
        pass
      class B(A):
        pass
      class `~unknown0`():
        pass
      a = ...  # type: A
      def left() -> `~unknown0`
      def right() -> list[A]
    """))
    ast = visitors.LookupClasses(ast, self.mini_builtins)
    m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
    left, right = ast.Lookup("left"), ast.Lookup("right")
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(m.match(left, right, {}),
                     booleq.And((booleq.Eq("~unknown0", "list"),
                                 booleq.Eq("~unknown0.list.T", "A"))))
Exemplo n.º 24
0
  def solve(self):
    """Solve the equations generated from the pytd.

    Two independent solvers are used. The "protocols" solver receives the
    equations from matching every unknown class against every protocol class
    (and protocol alias target). The "partial" solver receives the equations
    from matching partial classes and partial functions against the complete
    definition whose name equals their unpacked partial name. Both solutions
    are then merged per unknown.

    Returns:
      A dictionary keyed by unknown name. When only one solver produced an
      entry, its value is used as is; when both did, the value is the union of
      the two with "?" removed (the values are handled as sets here).
    Raises:
      AssertionError: If we detect an internal error.
    """
    hierarchy = type_match.get_all_subclasses([self.ast, self.builtins])
    factory_protocols = type_match.TypeMatch(hierarchy)
    factory_partial = type_match.TypeMatch(hierarchy)
    solver_protocols = factory_protocols.solver
    solver_partial = factory_partial.solver

    unknown_classes = set()
    partial_classes = set()
    complete_classes = set()
    for cls in self.ast.classes:
      if is_unknown(cls):
        # Unknowns are variables in both equation systems.
        solver_protocols.register_variable(cls.name)
        solver_partial.register_variable(cls.name)
        unknown_classes.add(cls)
      elif is_partial(cls):
        partial_classes.add(cls)
      else:
        complete_classes.add(cls)

    protocol_classes_and_aliases = set(self.protocols.classes)
    for alias in self.protocols.aliases:
      if (not isinstance(alias.type, pytd.AnythingType)
          and alias.name != "protocols.Protocol"):
        protocol_classes_and_aliases.add(alias.type.cls)

    # solve equations from protocols first
    for protocol in protocol_classes_and_aliases:
      for unknown in unknown_classes:
        self.match_unknown_against_protocol(
            factory_protocols, solver_protocols, unknown, protocol)

    # also solve partial equations
    for complete in complete_classes.union(self.builtins.classes):
      for partial in partial_classes:
        if type_match.unpack_name_of_partial(partial.name) == complete.name:
          self.match_partial_against_complete(
              factory_partial, solver_partial, partial, complete)

    partial_functions = set()
    complete_functions = set()
    for f in self.ast.functions:
      if is_partial(f):
        partial_functions.add(f)
      else:
        complete_functions.add(f)

    # The candidate set does not depend on the partial function being
    # examined, so build it once instead of once per partial function.
    candidate_functions = complete_functions.union(self.builtins.functions)
    for partial in partial_functions:
      for complete in candidate_functions:
        if type_match.unpack_name_of_partial(partial.name) == complete.name:
          self.match_call_record(
              factory_partial, solver_partial, partial, complete)

    log.info("=========== Equations to solve =============\n%s",
             solver_protocols)
    log.info("=========== Equations to solve (end) =======")
    solved_protocols = solver_protocols.solve()
    log.info("=========== Call trace equations to solve =============\n%s",
             solver_partial)
    log.info("=========== Call trace equations to solve (end) =======")
    solved_partial = solver_partial.solve()

    merged_solution = {}
    for unknown in itertools.chain(solved_protocols, solved_partial):
      if unknown in merged_solution:
        # Present in both solutions and already merged on the first pass;
        # recomputing the same union would be redundant.
        continue
      if unknown in solved_protocols and unknown in solved_partial:
        combined = solved_protocols[unknown].union(solved_partial[unknown])
        # remove Any from set if present
        # if no restrictions are present, it will be labeled Any later
        # otherwise, Any will override other restrictions that were found
        combined.discard("?")
        merged_solution[unknown] = combined
      elif unknown in solved_protocols:
        merged_solution[unknown] = solved_protocols[unknown]
      else:
        merged_solution[unknown] = solved_partial[unknown]
    return merged_solution