Ejemplo n.º 1
0
def Optimize(node,
             builtins=None,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True):
    """Run the pytd optimization pipeline over a tree and return the result.

    The passes are applied in a fixed order; each pass receives the tree
    produced by the previous one. The input node is never modified.

    Args:
      node: The pytd node to optimize.
      builtins: Definitions of all of the external types in node. When this
          is falsy, the superclass-based passes and the final lookup-based
          passes are skipped.
      lossy: Allow optimizations that change the meaning of the pytd.
      use_abcs: Use abstract base classes to represent unions like
          e.g. "float or int" as "Real".
      max_union: How many types we allow in a union before we simplify
          it to just "object". A falsy value skips that pass.
      remove_mutable: Whether to simplify mutable parameters to normal
          parameters.
      can_do_lookup: True: We're either allowed to try to resolve NamedType
          instances in the AST, or the AST is already resolved. False: Skip
          any optimizations that would require NamedTypes to be resolved.

    Returns:
      An optimized node.
    """
    # Unconditional clean-up passes. Each visitor is created right before it
    # is applied, in exactly this order.
    for make_visitor in (RemoveDuplicates,
                         SimplifyUnions,
                         CombineReturnsAndExceptions,
                         Factorize,
                         ApplyOptionalArguments,
                         CombineContainers,
                         SimplifyContainers):
        node = node.Visit(make_visitor())

    if builtins:
        # Merge the superclass tables of the builtins and of the tree itself
        # (plus the ABC table on request) into one hierarchy.
        parents_by_name = builtins.Visit(visitors.ExtractSuperClassesByName())
        parents_by_name.update(
            node.Visit(visitors.ExtractSuperClassesByName()))
        if use_abcs:
            parents_by_name.update(abc_hierarchy.GetSuperClasses())
        class_hierarchy = SuperClassHierarchy(parents_by_name)
        node = node.Visit(SimplifyUnionsWithSuperclasses(class_hierarchy))
        if lossy:
            node = node.Visit(FindCommonSuperClasses(class_hierarchy))

    if max_union:
        node = node.Visit(CollapseLongUnions(max_union))
    node = node.Visit(AdjustReturnAndConstantGenericType())

    if remove_mutable:
        for make_visitor in (AbsorbMutableParameters,
                             CombineContainers,
                             MergeTypeParameters,
                             lambda: visitors.AdjustSelf(force=True)):
            node = node.Visit(make_visitor())

    node = node.Visit(SimplifyContainers())

    if not (builtins and can_do_lookup):
        return node

    # class_hierarchy is always bound here because builtins is truthy.
    node = visitors.LookupClasses(node, builtins)
    node = node.Visit(RemoveInheritedMethods())
    return node.Visit(RemoveRedundantSignatures(class_hierarchy))
Ejemplo n.º 2
0
 def __init__(self, superclasses=None, use_abcs=True):
     """Set up the superclass and subclass tables for this visitor.

     Args:
       superclasses: Optional mapping merged on top of the table extracted
           from the builtins pytd.
       use_abcs: If true, also merge in abc_hierarchy.GetSuperClasses().
     """
     super(FindCommonSuperClasses, self).__init__()
     builtins_ast = builtins.GetBuiltinsPyTD()
     parent_table = builtins_ast.Visit(visitors.ExtractSuperClassesByName())
     if superclasses:
         parent_table.update(superclasses)
     if use_abcs:
         parent_table.update(abc_hierarchy.GetSuperClasses())
     self._superclasses = parent_table
     # The subclass table is the inverse of the merged superclass table.
     self._subclasses = abc_hierarchy.Invert(parent_table)
Ejemplo n.º 3
0
 def testBuiltinSuperClasses(self):
     # FindCommonSuperClasses with a hierarchy built from the builtins and
     # typing ASTs: the unions on x and y should disappear from the printed
     # signature and "int or bool" should become "int".
     source = textwrap.dedent("""
     def f(x: list or object, y: complex or memoryview) -> int or bool
 """)
     want = textwrap.dedent("""
     def f(x, y) -> int
 """)
     superclass_map = self.builtins.Visit(
         visitors.ExtractSuperClassesByName())
     typing_superclasses = self.typing.Visit(
         visitors.ExtractSuperClassesByName())
     superclass_map.update(typing_superclasses)
     common_super = optimize.FindCommonSuperClasses(
         optimize.SuperClassHierarchy(superclass_map))
     tree = (self.ParseAndResolve(source)
             .Visit(common_super)
             .Visit(visitors.DropBuiltinPrefix())
             .Visit(visitors.CanonicalOrderingVisitor()))
     self.AssertSourceEquals(tree, want)
Ejemplo n.º 4
0
 def testSimplifyUnionsWithSuperclassesGeneric(self):
     # A union of two generic types with the same parameter should be
     # reduced to the AbstractSet[int] member alone.
     source = textwrap.dedent("""
     x = ...  # type: frozenset[int] or AbstractSet[int]
 """)
     want = textwrap.dedent("""
     x = ...  # type: AbstractSet[int]
 """)
     superclass_map = self.builtins.Visit(
         visitors.ExtractSuperClassesByName())
     simplifier = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(superclass_map))
     resolved = visitors.LookupClasses(self.Parse(source), self.builtins)
     self.AssertSourceEquals(resolved.Visit(simplifier), want)
Ejemplo n.º 5
0
 def testBuiltinSuperClasses(self):
     # FindCommonSuperClasses with a hierarchy taken from the builtins pytd:
     # the unions on x and y should disappear from the printed signature and
     # "int or bool" should become "int".
     source = textwrap.dedent("""
     def f(x: list or object, y: int or float) -> int or bool
 """)
     want = textwrap.dedent("""
     def f(x, y) -> int
 """)
     superclass_map = builtins.GetBuiltinsPyTD().Visit(
         visitors.ExtractSuperClassesByName())
     common_super = optimize.FindCommonSuperClasses(
         optimize.SuperClassHierarchy(superclass_map))
     tree = (self.ParseAndResolve(source)
             .Visit(common_super)
             .Visit(visitors.DropBuiltinPrefix()))
     self.AssertSourceEquals(tree, want)
Ejemplo n.º 6
0
 def testFindCommonSuperClasses(self):
     # With "other.Bar" replaced by a LateType, FindCommonSuperClasses is
     # expected to leave the union exactly as written.
     source = textwrap.dedent("""
     x = ...  # type: int or other.Bar
 """)
     want = textwrap.dedent("""
     x = ...  # type: int or other.Bar
 """)
     parsed = self.Parse(source)
     late_types = {"other.Bar": pytd.LateType("other.Bar")}
     with_late = parsed.Visit(visitors.ReplaceTypes(late_types))
     superclass_map = with_late.Visit(visitors.ExtractSuperClassesByName())
     common_super = optimize.FindCommonSuperClasses(
         optimize.SuperClassHierarchy(superclass_map))
     tree = with_late.Visit(common_super).Visit(
         visitors.LateTypeToClassType())
     self.AssertSourceEquals(tree, want)
Ejemplo n.º 7
0
 def testSimplifyUnionsWithSuperclasses(self):
     # "bool" should be dropped from unions that also contain "int", while
     # "list[int] or int" should stay unchanged.
     source = textwrap.dedent("""
     x = ...  # type: int or bool
     y = ...  # type: int or bool or float
     z = ...  # type: list[int] or int
 """)
     want = textwrap.dedent("""
     x = ...  # type: int
     y = ...  # type: int or float
     z = ...  # type: list[int] or int
 """)
     superclass_map = self.builtins.Visit(
         visitors.ExtractSuperClassesByName())
     simplifier = optimize.SimplifyUnionsWithSuperclasses(
         optimize.SuperClassHierarchy(superclass_map))
     resolved = visitors.LookupClasses(self.Parse(source), self.builtins)
     self.AssertSourceEquals(resolved.Visit(simplifier), want)
Ejemplo n.º 8
0
def Optimize(node,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False):
    """Run the pytd optimization pipeline over a tree and return the result.

    The passes are applied in a fixed order; each pass receives the tree
    produced by the previous one. The input node is never modified.

    Args:
      node: The pytd node to optimize.
      lossy: Allow optimizations that change the meaning of the pytd.
      use_abcs: Use abstract base classes to represent unions like
          e.g. "float or int" as "Real". Only consulted when lossy is set.
      max_union: How many types we allow in a union before we simplify
          it to just "object". A falsy value skips the collapse passes.
      remove_mutable: Whether to simplify mutable parameters to normal
          parameters.

    Returns:
      An optimized node.
    """
    # Unconditional clean-up passes. Each visitor is created right before it
    # is applied, in exactly this order.
    for make_visitor in (RemoveDuplicates,
                         SimplifyUnions,
                         CombineReturnsAndExceptions,
                         Factorize,
                         ApplyOptionalArguments,
                         CombineContainers):
        node = node.Visit(make_visitor())

    if lossy:
        superclass_map = node.Visit(visitors.ExtractSuperClassesByName())
        node = node.Visit(FindCommonSuperClasses(superclass_map, use_abcs))

    if max_union:
        for make_collapser in (CollapseLongParameterUnions,
                               CollapseLongReturnUnions,
                               CollapseLongConstantUnions):
            node = node.Visit(make_collapser(max_union))

    if remove_mutable:
        for make_visitor in (AbsorbMutableParameters,
                             CombineContainers,
                             MergeTypeParameters,
                             lambda: visitors.AdjustSelf(force=True)):
            node = node.Visit(make_visitor())

    # Resolve classes against the builtins before the inheritance-based pass.
    resolved = visitors.LookupClasses(node, builtins.GetBuiltinsPyTD())
    return resolved.Visit(RemoveInheritedMethods())
Ejemplo n.º 9
0
  def testUserSuperClassHierarchy(self):
    # FindCommonSuperClasses over a user-defined class hierarchy. The class
    # definitions are appended unchanged to both the input and the expected
    # output; only the function signatures differ.
    shared_classes = textwrap.dedent("""
        class AB(object):
            pass

        class EFG(object):
            pass

        class A(AB, EFG):
            pass

        class B(AB):
            pass

        class E(EFG, AB):
            pass

        class F(EFG):
            pass

        class G(EFG):
            pass
    """)

    signatures_before = textwrap.dedent("""
        def f(x: A or B, y: A, z: B) -> E or F or G
        def g(x: E or F or G or B) -> E or F
        def h(x) -> ?
    """)
    signatures_after = textwrap.dedent("""
        def f(x: AB, y: A, z: B) -> EFG
        def g(x) -> EFG
        def h(x) -> ?
    """)
    source = signatures_before + shared_classes
    want = signatures_after + shared_classes

    # The hierarchy comes from the classes declared in the input itself.
    parsed = self.Parse(source)
    superclass_map = parsed.Visit(visitors.ExtractSuperClassesByName())
    common_super = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(superclass_map))
    self.AssertSourceEquals(
        self.ApplyVisitorToString(source, common_super), want)
Ejemplo n.º 10
0
 def testSuperClassesByName(self):
   # ExtractSuperClassesByName should map every class name to the names of
   # its direct bases; classes declared without bases map to "classobj".
   source = textwrap.dedent("""
     class A():
         pass
     class B():
         pass
     class C(A):
         pass
     class D(A,B):
         pass
     class E(C,D,A):
         pass
   """)
   bases_by_class = self.Parse(source).Visit(
       visitors.ExtractSuperClassesByName())
   # Checked in declaration order so the first failure matches the original.
   expectations = (
       ("A", ("classobj",)),
       ("B", ("classobj",)),
       ("C", ("A",)),
       ("D", ("A", "B")),
       ("E", ("A", "C", "D")),
   )
   for class_name, want_bases in expectations:
     self.assertItemsEqual(want_bases, bases_by_class[class_name])