Ejemplo n.º 1
0
 def testFindUnknownVisitor(self):
   """RaiseIfContainsUnknown must fire for classes that touch `~unknown*`.

   A, C and D reference unknown classes (as a parameter type, a constant
   type and a base class respectively); B only references `~int` and is
   expected not to raise.
   """
   src = textwrap.dedent("""
       class classobj:
         pass
       class `~unknown1`():
         pass
       class `~unknown_foobar`():
         pass
       class `~int`():
         pass
       class A():
         def foobar(self, x: `~unknown1`) -> ?
       class B():
         def foobar(self, x: `~int`) -> ?
       class C():
         x = ... # type: `~unknown_foobar`
       class D(`~unknown1`):
         pass
   """)
   resolved = visitors.LookupClasses(self.Parse(src))
   has_unknown = visitors.RaiseIfContainsUnknown.HasUnknown

   def find_on(name):
     return resolved.Lookup(name).Visit(visitors.RaiseIfContainsUnknown())

   with self.assertRaises(has_unknown):
     find_on("A")  # parameter typed `~unknown1`
   find_on("B")  # shouldn't raise
   with self.assertRaises(has_unknown):
     find_on("C")  # constant typed `~unknown_foobar`
   with self.assertRaises(has_unknown):
     find_on("D")  # base class `~unknown1`
Ejemplo n.º 2
0
 def testSuperClasses(self):
     """ExtractSuperClasses maps each class to the bases it lists.

     Classes that list no bases are expected to map to `classobj`.
     """
     src = textwrap.dedent("""
   class classobj:
       pass
   class A():
       pass
   class B():
       pass
   class C(A):
       pass
   class D(A,B):
       pass
   class E(C,D,A):
       pass
 """)
     unit = visitors.LookupClasses(self.Parse(src))
     bases_of = unit.Visit(visitors.ExtractSuperClasses())
     expected_bases = (
         ("A", ["classobj"]),
         ("B", ["classobj"]),
         ("C", ["A"]),
         ("D", ["A", "B"]),
         ("E", ["C", "D", "A"]),
     )
     for class_name, bases in expected_bases:
         actual = [base.name for base in bases_of[unit.Lookup(class_name)]]
         self.assertItemsEqual(bases, actual)
Ejemplo n.º 3
0
 def test_find_unknown_visitor(self):
     """RaiseIfContainsUnknown fires for A, C and D but not B (`~int` only)."""
     src = pytd_src("""
     from typing import Any
     class object:
       pass
     class `~unknown1`():
       pass
     class `~unknown_foobar`():
       pass
     class `~int`():
       pass
     class A():
       def foobar(self, x: `~unknown1`) -> Any: ...
     class B():
       def foobar(self, x: `~int`) -> Any: ...
     class C():
       x = ... # type: `~unknown_foobar`
     class D(`~unknown1`):
       pass
 """)
     resolved = visitors.LookupClasses(self.Parse(src))
     has_unknown = visitors.RaiseIfContainsUnknown.HasUnknown

     def find_on(name):
         return resolved.Lookup(name).Visit(
             visitors.RaiseIfContainsUnknown())

     # Checked in the original order: A, B, C, D.
     cases = (("A", True), ("B", False), ("C", True), ("D", True))
     for name, should_raise in cases:
         if should_raise:
             with self.assertRaises(has_unknown):
                 find_on(name)
         else:
             find_on(name)  # shouldn't raise
Ejemplo n.º 4
0
 def test_superclasses(self):
   """ExtractSuperClasses maps each class to its listed bases (or `object`)."""
   src = textwrap.dedent("""
      class object:
          pass
      class A():
          pass
      class B():
          pass
      class C(A):
          pass
      class D(A,B):
          pass
      class E(C,D,A):
          pass
    """)
   unit = visitors.LookupClasses(self.Parse(src))
   bases_of = unit.Visit(pytd_visitors.ExtractSuperClasses())
   for class_name, bases in (("A", ["object"]),
                             ("B", ["object"]),
                             ("C", ["A"]),
                             ("D", ["A", "B"]),
                             ("E", ["C", "D", "A"])):
     actual = [base.name for base in bases_of[unit.Lookup(class_name)]]
     six.assertCountEqual(self, bases, actual)
Ejemplo n.º 5
0
 def testRemoveInheritedMethodsWithOverride(self):
     """RemoveInheritedMethods drops C.f, which repeats B.f exactly.

     D.f differs from B.f (it matches A.f instead) and is expected to stay.
     """
     source = textwrap.dedent("""
     class A(object):
         def f(self, x) -> ?
     class B(A):
         def f(self) -> ?
     class C(B):
         def f(self) -> ?
     class D(B):
         def f(self, x) -> ?
 """)
     expected = textwrap.dedent("""
     class A(object):
         def f(self, x) -> ?
     class B(A):
         def f(self) -> ?
     class C(B):
         pass
     class D(B):
         def f(self, x) -> ?
 """)
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     self.AssertSourceEquals(
         unit.Visit(optimize.RemoveInheritedMethods()), expected)
Ejemplo n.º 6
0
def Optimize(node,
             builtins=None,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True):
    """Optimize a PYTD tree.

    Tries to shrink a PYTD tree by applying a fixed sequence of optimization
    visitors.

    Arguments:
      node: A pytd node to be optimized. It won't be modified - this function
          will return a new node.
      builtins: Definitions of all of the external types in node. The passes
          that need a superclass hierarchy (union simplification via
          superclasses, lossy common-superclass merging, inherited-method and
          redundant-signature removal) only run when this is truthy.
      lossy: Allow optimizations that change the meaning of the pytd. Only
          has an effect when builtins is given.
      use_abcs: Use abstract base classes to represent unions like
          e.g. "float or int" as "Real". Only has an effect when builtins is
          given.
      max_union: How many types we allow in a union before we simplify
          it to just "object". A falsy value skips this pass.
      remove_mutable: Whether to simplify mutable parameters to normal
          parameters.
      can_do_lookup: True: We're either allowed to try to resolve NamedType
          instances in the AST, or the AST is already resolved. False: Skip any
          optimizations that would require NamedTypes to be resolved.

    Returns:
      An optimized node.
    """
    # Hierarchy-independent passes. Each visitor is instantiated just before
    # it is applied; the order of this tuple is significant.
    for make_visitor in (RemoveDuplicates,
                         SimplifyUnions,
                         CombineReturnsAndExceptions,
                         Factorize,
                         ApplyOptionalArguments,
                         CombineContainers,
                         SimplifyContainers):
        node = node.Visit(make_visitor())

    hierarchy = None
    if builtins:
        # Superclass map gathered from the builtins and from node itself.
        superclasses = builtins.Visit(visitors.ExtractSuperClassesByName())
        superclasses.update(node.Visit(visitors.ExtractSuperClassesByName()))
        if use_abcs:
            superclasses.update(abc_hierarchy.GetSuperClasses())
        hierarchy = SuperClassHierarchy(superclasses)
        node = node.Visit(SimplifyUnionsWithSuperclasses(hierarchy))
        if lossy:
            node = node.Visit(FindCommonSuperClasses(hierarchy))

    if max_union:
        node = node.Visit(CollapseLongUnions(max_union))
    node = node.Visit(AdjustReturnAndConstantGenericType())

    if remove_mutable:
        for make_visitor in (AbsorbMutableParameters,
                             CombineContainers,
                             MergeTypeParameters,
                             visitors.AdjustSelf):
            node = node.Visit(make_visitor())

    node = node.Visit(SimplifyContainers())

    if builtins and can_do_lookup:
        # These passes run on a tree whose class references were looked up.
        node = visitors.LookupClasses(node, builtins, ignore_late_types=True)
        node = node.Visit(RemoveInheritedMethods())
        node = node.Visit(RemoveRedundantSignatures(hierarchy))
    return node
Ejemplo n.º 7
0
def solve(ast, builtins_pytd, protocols_pytd):
    """Solve the unknowns in a pytd AST using the standard Python builtins.

    Args:
      ast: A pytd.TypeDeclUnit, containing classes named ~unknownXX.
      builtins_pytd: A pytd for builtins. Mutable parameters are stripped from
        it before its classes are looked up.
      protocols_pytd: A pytd for protocols.

    Returns:
      A tuple of (1) a dictionary (str->str) mapping unknown class names to known
      class names and (2) a pytd.TypeDeclUnit of the complete classes in ast.
    """
    linked_builtins = visitors.LookupClasses(
        transforms.RemoveMutableParameters(builtins_pytd))
    linked_protocols = visitors.LookupClasses(protocols_pytd)
    linked_ast = visitors.LookupClasses(ast, linked_builtins)
    solver = TypeSolver(linked_ast, linked_builtins, linked_protocols)
    # Solve first, then extract the local classes, matching the tuple order.
    solution = solver.solve()
    return solution, extract_local(linked_ast)
Ejemplo n.º 8
0
 def testRemoveInheritedMethodsWithCircle(self):
     """With cyclic inheritance (A <-> B) the source must come back unchanged."""
     source = textwrap.dedent("""
     class A(B):
         def f(self) -> ?
     class B(A):
         def f(self) -> ?
 """)
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     result = unit.Visit(optimize.RemoveInheritedMethods())
     self.AssertSourceEquals(result, source)
Ejemplo n.º 9
0
 def testDontRemoveAbstractMethodImplementation(self):
     """B.foo implements the abstract A.foo, so it must survive the pass."""
     source = textwrap.dedent("""
   class A(object):
     @abstractmethod
     def foo(self): ...
   class B(A):
     def foo(self): ...
 """)
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     result = unit.Visit(optimize.RemoveInheritedMethods())
     self.AssertSourceEquals(result, source)
Ejemplo n.º 10
0
 def testClassMatch(self):
   """Left matches Right (which has Left's method plus method2), not reverse."""
   pyi = textwrap.dedent("""
      class Left():
        def method(self) -> ?
      class Right():
        def method(self) -> ?
        def method2(self) -> ?
    """)
   unit = visitors.LookupClasses(
       parser.parse_string(pyi, python_version=self.PYTHON_VERSION),
       self.mini_builtins)
   matcher = type_match.TypeMatch()
   left = unit.Lookup("Left")
   right = unit.Lookup("Right")
   self.assertEqual(matcher.match(left, right, {}), booleq.TRUE)
   self.assertNotEqual(matcher.match(right, left, {}), booleq.TRUE)
Ejemplo n.º 11
0
 def test_simplify_unions_with_superclasses_generic(self):
     """Union[frozenset[int], AbstractSet[int]] becomes AbstractSet[int]."""
     source = pytd_src("""
     x = ...  # type: Union[frozenset[int], AbstractSet[int]]
 """)
     expected = pytd_src("""
     x = ...  # type: AbstractSet[int]
 """)
     superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
     simplify = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(superclasses))
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     self.AssertSourceEquals(unit.Visit(simplify), expected)
Ejemplo n.º 12
0
 def test_class_match(self):
   """Left matches Right (which has Left's method plus method2), not reverse."""
   pyi = textwrap.dedent("""
      from typing import Any
      class Left():
        def method(self) -> Any: ...
      class Right():
        def method(self) -> Any: ...
        def method2(self) -> Any: ...
    """)
   unit = visitors.LookupClasses(
       parser.parse_string(pyi, python_version=self.python_version),
       self.mini_builtins)
   matcher = type_match.TypeMatch()
   left = unit.Lookup("Left")
   right = unit.Lookup("Right")
   self.assertEqual(matcher.match(left, right, {}), booleq.TRUE)
   self.assertNotEqual(matcher.match(right, left, {}), booleq.TRUE)
Ejemplo n.º 13
0
 def testSimplifyUnionsWithSuperclassesGeneric(self):
     """`frozenset[int] or AbstractSet[int]` becomes `AbstractSet[int]`."""
     source = textwrap.dedent("""
     x = ...  # type: frozenset[int] or AbstractSet[int]
 """)
     expected = textwrap.dedent("""
     x = ...  # type: AbstractSet[int]
 """)
     superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
     simplify = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(superclasses))
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     self.AssertSourceEquals(unit.Visit(simplify), expected)
Ejemplo n.º 14
0
 def testAdjustInheritedMethodSelf(self):
   """AddInheritedMethods copies A.f into B without the `self: object` type."""
   source = textwrap.dedent("""
      class A():
        def f(self: object) -> float
      class B(A):
        pass
    """)
   unit = visitors.LookupClasses(self.Parse(source), self.builtins)
   unit = unit.Visit(optimize.AddInheritedMethods())
   printed_b = pytd.Print(unit.Lookup("B"))
   want = textwrap.dedent("""\
       class B(A):
           def f(self) -> float: ...
   """)
   self.assertMultiLineEqual(printed_b, want)
Ejemplo n.º 15
0
 def test_remove_inherited_methods_with_diamond(self):
     """In the diamond D(B, C), D.f is expected to be kept.

     D.f has A.f's signature, but C (also a base of D) redefines f with a
     different signature, so the source must come back unchanged.
     """
     source = textwrap.dedent("""
     class A(object):
         def f(self, x) -> ?
     class B(A):
         pass
     class C(A):
         def f(self, x, y) -> ?
     class D(B, C):
         def f(self, x) -> ?
 """)
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     result = unit.Visit(optimize.RemoveInheritedMethods())
     self.AssertSourceEquals(result, source)
Ejemplo n.º 16
0
 def testSubclasses(self):
   """With B subclassing A, left(B) -> B matches right(A) -> A, not reverse."""
   pyi = textwrap.dedent("""
      class A():
        pass
      class B(A):
        pass
      a = ...  # type: A
      def left(a: B) -> B
      def right(a: A) -> A
    """)
   unit = visitors.LookupClasses(
       parser.parse_string(pyi, python_version=self.PYTHON_VERSION),
       self.mini_builtins)
   subclasses = type_match.get_all_subclasses([unit])
   matcher = type_match.TypeMatch(subclasses)
   left = unit.Lookup("left")
   right = unit.Lookup("right")
   self.assertEqual(matcher.match(left, right, {}), booleq.TRUE)
   self.assertNotEqual(matcher.match(right, left, {}), booleq.TRUE)
Ejemplo n.º 17
0
 def test_adjust_inherited_method_self(self):
     """AddInheritedMethods copies A.f into B without the `self: object` type."""
     source = pytd_src("""
   class A():
     def f(self: object) -> float: ...
   class B(A):
     pass
 """)
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     unit = unit.Visit(optimize.AddInheritedMethods())
     printed_b = pytd_utils.Print(unit.Lookup("B"))
     want = pytd_src("""
     class B(A):
         def f(self) -> float: ...
 """).lstrip()
     self.assertMultiLineEqual(printed_b, want)
Ejemplo n.º 18
0
  def testAddInheritedMethods(self):
    """AddInheritedMethods gives B A's f alongside B's own g and h."""
    def b_method_names(tree):
      return [method.name for method in tree.Lookup("B").methods]

    src = textwrap.dedent("""
        class A():
            foo = ...  # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex

        class B(A):
            bar = ...  # type: int
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
    """)
    unit = visitors.LookupClasses(self.Parse(src), self.builtins)
    self.assertItemsEqual(("g", "h"), b_method_names(unit))
    unit = unit.Visit(optimize.AddInheritedMethods())
    self.assertItemsEqual(("f", "g", "h"), b_method_names(unit))
Ejemplo n.º 19
0
 def test_simplify_unions_with_superclasses(self):
     """Union members subsumed by another member (bool by int) are dropped."""
     source = pytd_src("""
     x = ...  # type: Union[int, bool]
     y = ...  # type: Union[int, bool, float]
     z = ...  # type: Union[list[int], int]
 """)
     expected = pytd_src("""
     x = ...  # type: int
     y = ...  # type: Union[int, float]
     z = ...  # type: Union[list[int], int]
 """)
     superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
     simplify = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(superclasses))
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     self.AssertSourceEquals(unit.Visit(simplify), expected)
Ejemplo n.º 20
0
 def testSimplifyUnionsWithSuperclasses(self):
     """Union members subsumed by another member (bool by int) are dropped."""
     source = textwrap.dedent("""
     x = ...  # type: int or bool
     y = ...  # type: int or bool or float
     z = ...  # type: list[int] or int
 """)
     expected = textwrap.dedent("""
     x = ...  # type: int
     y = ...  # type: int or float
     z = ...  # type: list[int] or int
 """)
     superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
     simplify = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(superclasses))
     unit = visitors.LookupClasses(self.Parse(source), self.builtins)
     self.AssertSourceEquals(unit.Visit(simplify), expected)
Ejemplo n.º 21
0
  def testLookupClasses(self):
    """LookupClasses keeps the printed source unchanged; VerifyLookup is run."""
    source = textwrap.dedent("""
        class object(object):
            pass

        class A(object):
            def a(self, a: A, b: B) -> A or B:
                raise A()
                raise B()

        class B(object):
            def b(self, a: A, b: B) -> A or B:
                raise A()
                raise B()
    """)
    resolved = visitors.LookupClasses(self.Parse(source))
    self.AssertSourceEquals(resolved, source)
    # Visited only for its checks; the returned tree is discarded.
    resolved.Visit(visitors.VerifyLookup())
Ejemplo n.º 22
0
    def test_lookup_classes(self):
        """LookupClasses keeps the printed source unchanged; VerifyLookup is run."""
        source = textwrap.dedent("""
        from typing import Union
        class object:
            pass

        class A:
            def a(self, a: A, b: B) -> Union[A, B]:
                raise A()
                raise B()

        class B:
            def b(self, a: A, b: B) -> Union[A, B]:
                raise A()
                raise B()
    """)
        resolved = visitors.LookupClasses(self.Parse(source))
        self.AssertSourceEquals(resolved, source)
        # Visited only for its checks; the returned tree is discarded.
        resolved.Visit(visitors.VerifyLookup())
Ejemplo n.º 23
0
    def test_add_inherited_methods(self):
        """AddInheritedMethods gives B A's f alongside B's own g and h."""
        def b_method_names(tree):
            return [method.name for method in tree.Lookup("B").methods]

        src = pytd_src("""
        from typing import Any
        class A():
            foo = ...  # type: bool
            def f(self, x: int) -> float: ...
            def h(self) -> complex: ...

        class B(A):
            bar = ...  # type: int
            def g(self, y: int) -> bool: ...
            def h(self, z: float) -> Any: ...
    """)
        unit = visitors.LookupClasses(self.Parse(src), self.builtins)
        self.assertCountEqual(("g", "h"), b_method_names(unit))
        unit = unit.Visit(optimize.AddInheritedMethods())
        self.assertCountEqual(("f", "g", "h"), b_method_names(unit))
Ejemplo n.º 24
0
    def testRemoveInheritedMethodsWithoutSelf(self):
        """Foo.baz repeats Bar.baz and is dropped; self-less Foo.bar is kept."""
        source = textwrap.dedent("""
        class Bar(object):
          def baz(self) -> int

        class Foo(Bar):
          def baz(self) -> int
          def bar() -> float
    """)
        expected = textwrap.dedent("""
        class Bar(object):
          def baz(self) -> int

        class Foo(Bar):
          def bar() -> float
    """)
        unit = visitors.LookupClasses(self.Parse(source), self.builtins)
        pruned = unit.Visit(optimize.RemoveInheritedMethods())
        self.AssertSourceEquals(pruned, expected)
Ejemplo n.º 25
0
    def testRemoveInheritedMethods(self):
        """RemoveInheritedMethods drops methods that repeat an ancestor's.

        B.f and C.f repeat A.f, and D.g repeats A.g (reached through B); the
        classes' own methods b, c and d are kept.
        """
        source = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B
            def f(self, y: int) -> bool

        class C(A):
            def c(self) -> C
            def f(self, y: int) -> bool

        class D(B):
            def g(self) -> float
            def d(self) -> D
    """)
        expected = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B

        class C(A):
            def c(self) -> C

        class D(B):
            def d(self) -> D
    """)
        unit = visitors.LookupClasses(self.Parse(source), self.builtins)
        pruned = unit.Visit(optimize.RemoveInheritedMethods())
        self.AssertSourceEquals(pruned, expected)
Ejemplo n.º 26
0
 def testLookupTypingClass(self):
   """Looking up typing.Sequence via LookupClasses must populate `node.cls`."""
   node = visitors.LookupClasses(pytd.NamedType("typing.Sequence"),
                                 self.loader.concat_all())
   # A bare `assert` is stripped when Python runs with -O, which would make
   # this test pass vacuously; assertTrue performs the same truthiness check
   # unconditionally.
   self.assertTrue(node.cls)
Ejemplo n.º 27
0
 def ParseWithLookup(self, src):
   """Parse `src` and resolve its classes against the builtins pytd.

   The builtins are fetched for `self.PYTHON_VERSION`.
   """
   parsed = self.Parse(src)
   builtins_pytd = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
   return visitors.LookupClasses(parsed, builtins_pytd)
Ejemplo n.º 28
0
 def LinkAgainstSimpleBuiltins(self, ast):
     """Adjust type parameters in `ast`, then resolve it against mini_builtins.

     Returns the resolved tree.
     """
     adjusted = ast.Visit(visitors.AdjustTypeParameters())
     return visitors.LookupClasses(adjusted, self.mini_builtins)