Exemplo n.º 1
0
def Optimize(node,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True,
             builtins_ast=None):
    """Optimize a PYTD tree.

    Tries to shrink a PYTD tree by applying various optimizations.

    Arguments:
      node: A pytd node to be optimized. It won't be modified - this function
          will return a new node.
      lossy: Allow optimizations that change the meaning of the pytd.
      use_abcs: Use abstract base classes to represent unions like
          e.g. "float or int" as "Real"
      max_union: How many types we allow in a union before we simplify
          it to just "object". A falsy value (0 or None) skips this step.
      remove_mutable: Whether to simplify mutable parameters to normal
          parameters.
      can_do_lookup: True: We're either allowed to try to resolve NamedType
          instances in the AST, or the AST is already resolved. False: Skip any
          optimizations that would require NamedTypes to be resolved.
      builtins_ast: Optional builtins AST to optimize against (for example a
          version-specific one). It is used both to build the superclass
          hierarchy and for class lookup. Defaults to
          builtins.GetBuiltinsPyTD().

    Returns:
      An optimized node.
    """
    # Load the builtins once and reuse the same AST for both the superclass
    # hierarchy and the class lookup below.
    if builtins_ast is None:
        builtins_ast = builtins.GetBuiltinsPyTD()

    node = node.Visit(RemoveDuplicates())
    node = node.Visit(SimplifyUnions())
    node = node.Visit(CombineReturnsAndExceptions())
    node = node.Visit(Factorize())
    node = node.Visit(ApplyOptionalArguments())
    node = node.Visit(CombineContainers())
    node = node.Visit(SimplifyContainers())

    # Superclass map built from the builtins, the classes defined in `node`,
    # and (optionally) the abstract base classes.
    superclasses = builtins_ast.Visit(visitors.ExtractSuperClassesByName())
    superclasses.update(node.Visit(visitors.ExtractSuperClassesByName()))
    if use_abcs:
        superclasses.update(abc_hierarchy.GetSuperClasses())
    hierarchy = SuperClassHierarchy(superclasses)

    node = node.Visit(SimplifyUnionsWithSuperclasses(hierarchy))
    if lossy:
        node = node.Visit(FindCommonSuperClasses(hierarchy))
    if max_union:
        node = node.Visit(CollapseLongUnions(max_union))
    node = node.Visit(AdjustReturnAndConstantGenericType())
    if remove_mutable:
        node = node.Visit(AbsorbMutableParameters())
        node = node.Visit(CombineContainers())
        node = node.Visit(MergeTypeParameters())
        node = node.Visit(visitors.AdjustSelf(force=True))
    # Run again: the passes since the first SimplifyContainers may have
    # produced new containers that can be simplified.
    node = node.Visit(SimplifyContainers())
    if can_do_lookup:
        node = visitors.LookupClasses(node, builtins_ast)
        node = node.Visit(RemoveInheritedMethods())
        node = node.Visit(RemoveRedundantSignatures(hierarchy))
    return node
Exemplo n.º 2
0
 def testSimplifyUnionsWithSuperclassesGeneric(self):
     """frozenset[int] is absorbed by its superclass AbstractSet[int]."""
     builtins_ast = builtins.GetBuiltinsPyTD()
     superclasses = builtins_ast.Visit(visitors.ExtractSuperClassesByName())
     simplifier = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(superclasses))
     src = textwrap.dedent("""
     x = ...  # type: frozenset[int] or AbstractSet[int]
 """)
     expected = textwrap.dedent("""
     x = ...  # type: AbstractSet[int]
 """)
     resolved = visitors.LookupClasses(self.Parse(src), builtins_ast)
     self.AssertSourceEquals(resolved.Visit(simplifier), expected)
Exemplo n.º 3
0
 def testRemoveInheritedMethodsWithOverride(self):
     """C.f repeats B.f and is dropped; D.f differs from B.f and is kept."""
     src = textwrap.dedent("""
     class A(object):
         def f(self, x) -> Any
     class B(A):
         def f(self) -> Any
     class C(B):
         def f(self) -> Any
     class D(B):
         def f(self, x) -> Any
 """)
     expected = textwrap.dedent("""
     class A(object):
         def f(self, x) -> Any
     class B(A):
         def f(self) -> Any
     class C(B):
         pass
     class D(B):
         def f(self, x) -> Any
 """)
     resolved = visitors.LookupClasses(self.Parse(src),
                                       builtins.GetBuiltinsPyTD())
     pruned = resolved.Visit(optimize.RemoveInheritedMethods())
     self.AssertSourceEquals(pruned, expected)
Exemplo n.º 4
0
    def testAddInheritedMethods(self):
        """B gains A's foo and f; B's own override of h is kept."""
        src = textwrap.dedent("""
        class A():
            foo = ...  # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex

        class B(A):
            bar = ...  # type: int
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
    """)
        expected = textwrap.dedent("""
        class A():
            foo = ...  # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex

        class B(A):
            bar = ...  # type: int
            foo = ...  # type: bool
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
            def f(self, x: int) -> float
    """)
        resolved = visitors.LookupClasses(self.Parse(src),
                                          builtins.GetBuiltinsPyTD())
        self.AssertSourceEquals(
            resolved.Visit(optimize.AddInheritedMethods()), expected)
Exemplo n.º 5
0
 def testSimplifyUnionsWithSuperclasses(self):
     """bool is absorbed by int; list[int] and float are kept."""
     builtins_ast = builtins.GetBuiltinsPyTD()
     simplifier = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(
             builtins_ast.Visit(visitors.ExtractSuperClassesByName())))
     src = textwrap.dedent("""
     x = ...  # type: int or bool
     y = ...  # type: int or bool or float
     z = ...  # type: list[int] or int
 """)
     expected = textwrap.dedent("""
     x = ...  # type: int
     y = ...  # type: int or float
     z = ...  # type: list[int] or int
 """)
     resolved = visitors.LookupClasses(self.Parse(src), builtins_ast)
     self.AssertSourceEquals(resolved.Visit(simplifier), expected)
Exemplo n.º 6
0
 def testRemoveInheritedMethodsWithCircle(self):
     """A cyclic class hierarchy comes out unchanged."""
     src = textwrap.dedent("""
     class A(B):
         def f(self) -> Any
     class B(A):
         def f(self) -> Any
 """)
     resolved = visitors.LookupClasses(self.Parse(src),
                                       builtins.GetBuiltinsPyTD())
     self.AssertSourceEquals(
         resolved.Visit(optimize.RemoveInheritedMethods()), src)
Exemplo n.º 7
0
 def testDontRemoveAbstractMethodImplementation(self):
     """A concrete override of an @abstractmethod is not removed."""
     src = textwrap.dedent("""
   class A(object):
     @abstractmethod
     def foo(self): ...
   class B(A):
     def foo(self): ...
 """)
     resolved = visitors.LookupClasses(self.Parse(src),
                                       builtins.GetBuiltinsPyTD())
     result = resolved.Visit(optimize.RemoveInheritedMethods())
     self.AssertSourceEquals(result, src)
Exemplo n.º 8
0
 def testBuiltinSuperClasses(self):
     """Unions collapse to their common superclass in the builtins."""
     builtins_ast = builtins.GetBuiltinsPyTD()
     finder = optimize.FindCommonSuperClasses(
         optimize.SuperClassHierarchy(
             builtins_ast.Visit(visitors.ExtractSuperClassesByName())))
     src = textwrap.dedent("""
     def f(x: list or object, y: int or float) -> int or bool
 """)
     expected = textwrap.dedent("""
     def f(x, y) -> int
 """)
     result = self.ParseAndResolve(src).Visit(finder)
     self.AssertSourceEquals(result.Visit(visitors.DropBuiltinPrefix()),
                             expected)
Exemplo n.º 9
0
 def testRemoveInheritedMethodsWithDiamond(self):
     """D.f survives: it matches A.f, but C.f, also inherited, differs."""
     src = textwrap.dedent("""
     class A(object):
         def f(self, x) -> ?
     class B(A):
         pass
     class C(A):
         def f(self, x, y) -> ?
     class D(B, C):
         def f(self, x) -> ?
 """)
     resolved = visitors.LookupClasses(self.Parse(src),
                                       builtins.GetBuiltinsPyTD())
     result = resolved.Visit(optimize.RemoveInheritedMethods())
     self.AssertSourceEquals(result, src)
Exemplo n.º 10
0
 def testBuiltinSuperClasses(self):
     """Unions collapse to their common superclass (versioned builtins)."""
     builtins_ast = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
     finder = optimize.FindCommonSuperClasses(
         optimize.SuperClassHierarchy(
             builtins_ast.Visit(visitors.ExtractSuperClassesByName())))
     src = textwrap.dedent("""
     def f(x: list or object, y: complex or memoryview) -> int or bool
 """)
     expected = textwrap.dedent("""
     def f(x, y) -> int
 """)
     result = (self.ParseAndResolve(src)
               .Visit(finder)
               .Visit(visitors.DropBuiltinPrefix())
               .Visit(visitors.CanonicalOrderingVisitor()))
     self.AssertSourceEquals(result, expected)
Exemplo n.º 11
0
  def testAddInheritedMethods(self):
    """After AddInheritedMethods, B also lists A's method f."""
    src = textwrap.dedent("""
        class A():
            foo = ...  # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex

        class B(A):
            bar = ...  # type: int
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
    """)
    ast = visitors.LookupClasses(self.Parse(src), builtins.GetBuiltinsPyTD())

    def method_names(tree):
      # Names of the methods declared on class B in `tree`.
      return [m.name for m in tree.Lookup("B").methods]

    self.assertItemsEqual(("g", "h"), method_names(ast))
    ast = ast.Visit(optimize.AddInheritedMethods())
    self.assertItemsEqual(("f", "g", "h"), method_names(ast))
Exemplo n.º 12
0
    def testRemoveInheritedMethodsWithoutSelf(self):
        """Redundant baz is dropped; the self-less bar() is kept."""
        src = textwrap.dedent("""
        class Bar(object):
          def baz(self) -> int

        class Foo(Bar):
          def baz(self) -> int
          def bar() -> float
    """)
        expected = textwrap.dedent("""
        class Bar(object):
          def baz(self) -> int

        class Foo(Bar):
          def bar() -> float
    """)
        resolved = visitors.LookupClasses(self.Parse(src),
                                          builtins.GetBuiltinsPyTD())
        pruned = resolved.Visit(optimize.RemoveInheritedMethods())
        self.AssertSourceEquals(pruned, expected)
Exemplo n.º 13
0
    def testRemoveInheritedMethods(self):
        """Methods that repeat an inherited signature are removed."""
        src = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B
            def f(self, y: int) -> bool

        class C(A):
            def c(self) -> C
            def f(self, y: int) -> bool

        class D(B):
            def g(self) -> float
            def d(self) -> D
    """)
        expected = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B

        class C(A):
            def c(self) -> C

        class D(B):
            def d(self) -> D
    """)
        parsed = self.Parse(src)
        builtins_ast = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
        pruned = visitors.LookupClasses(parsed, builtins_ast).Visit(
            optimize.RemoveInheritedMethods())
        self.AssertSourceEquals(pruned, expected)
Exemplo n.º 14
0
 def setUpClass(cls):
     """Load the builtins AST for cls.python_version once per test class."""
     parent = super(UtilsTest, cls)
     parent.setUpClass()
     cls.builtins = builtins.GetBuiltinsPyTD(cls.python_version)
Exemplo n.º 15
0
 def Optimize(self, ast, **kwargs):
     """Run optimize.Optimize on `ast` with the default builtins AST.

     NOTE(review): the builtins AST is passed as the second positional
     argument; this assumes optimize.Optimize takes it in that position.
     """
     builtins_ast = builtins.GetBuiltinsPyTD()
     return optimize.Optimize(ast, builtins_ast, **kwargs)
Exemplo n.º 16
0
 def Optimize(self, ast, **kwargs):
     """Run optimize.Optimize on `ast` with builtins for PYTHON_VERSION.

     NOTE(review): the builtins AST is passed as the second positional
     argument; this assumes optimize.Optimize takes it in that position.
     """
     versioned_builtins = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
     return optimize.Optimize(ast, versioned_builtins, **kwargs)