def Optimize(node,
             builtins=None,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True):
  """Run the optimization passes over a PYTD tree and return the result.

  The passes try to shrink the tree. The input node is not modified; a new
  node is returned.

  Args:
    node: The pytd node to optimize.
    builtins: Definitions of all of the external types in node. The
      superclass-based passes and the final lookup only run when this is
      given.
    lossy: Allow optimizations that change the meaning of the pytd.
    use_abcs: Use abstract base classes to represent unions, e.g. turn
      "float or int" into "Real".
    max_union: How many types a union may hold before it is simplified to
      just "object". A falsy value skips that pass.
    remove_mutable: Whether to simplify mutable parameters to normal
      parameters.
    can_do_lookup: True if NamedType instances in the AST may be resolved (or
      the AST is already resolved). False skips the optimizations that would
      require NamedTypes to be resolved.

  Returns:
    An optimized node.
  """
  # Passes that need nothing beyond the tree itself. Every visitor is created
  # right before it is applied.
  tree_only_passes = (
      RemoveDuplicates,
      SimplifyUnions,
      CombineReturnsAndExceptions,
      Factorize,
      ApplyOptionalArguments,
      CombineContainers,
      SimplifyContainers,
  )
  for visitor_class in tree_only_passes:
    node = node.Visit(visitor_class())

  if builtins:
    # Superclass table: external types first, then this tree's own classes,
    # then (optionally) the ABC hierarchy.
    by_name = builtins.Visit(visitors.ExtractSuperClassesByName())
    own_classes = node.Visit(visitors.ExtractSuperClassesByName())
    by_name.update(own_classes)
    if use_abcs:
      by_name.update(abc_hierarchy.GetSuperClasses())
    hierarchy = SuperClassHierarchy(by_name)
    node = node.Visit(SimplifyUnionsWithSuperclasses(hierarchy))
    if lossy:
      node = node.Visit(FindCommonSuperClasses(hierarchy))

  if max_union:
    node = node.Visit(CollapseLongUnions(max_union))
  node = node.Visit(AdjustReturnAndConstantGenericType())

  # NOTE(review): all four passes below are assumed to belong to the
  # remove_mutable branch -- confirm against the upstream layout.
  if remove_mutable:
    for visitor_class in (AbsorbMutableParameters, CombineContainers,
                          MergeTypeParameters):
      node = node.Visit(visitor_class())
    node = node.Visit(visitors.AdjustSelf(force=True))
  node = node.Visit(SimplifyContainers())

  # The remaining passes need resolved classes and the hierarchy built above,
  # which only exists when builtins were supplied.
  if not (builtins and can_do_lookup):
    return node
  node = visitors.LookupClasses(node, builtins)
  node = node.Visit(RemoveInheritedMethods())
  return node.Visit(RemoveRedundantSignatures(hierarchy))
def __init__(self, superclasses=None, use_abcs=True):
  """Set up the superclass and subclass tables for this visitor.

  Args:
    superclasses: Optional mapping that is merged on top of the table
      extracted from the builtins pytd.
    use_abcs: If true, the result of abc_hierarchy.GetSuperClasses() is
      merged in as well.
  """
  super(FindCommonSuperClasses, self).__init__()

  # Start from the builtins, then layer the caller's table and the ABCs on
  # top of it.
  builtins_pytd = builtins.GetBuiltinsPyTD()
  self._superclasses = builtins_pytd.Visit(
      visitors.ExtractSuperClassesByName())
  if superclasses:
    self._superclasses.update(superclasses)
  else:
    self._superclasses.update({})
  if use_abcs:
    abc_superclasses = abc_hierarchy.GetSuperClasses()
    self._superclasses.update(abc_superclasses)

  # Reverse mapping, derived from the merged table.
  self._subclasses = abc_hierarchy.Invert(self._superclasses)
def testBuiltinSuperClasses(self):
  """Check FindCommonSuperClasses on unions of builtin classes.

  The hierarchy is built from both self.builtins and self.typing. The
  expected output has no annotation left on x and y, and "int or bool" has
  become "int".
  """
  src = textwrap.dedent(""" def f(x: list or object, y: complex or memoryview) -> int or bool """)
  expected = textwrap.dedent(""" def f(x, y) -> int """)

  superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
  typing_superclasses = self.typing.Visit(
      visitors.ExtractSuperClassesByName())
  superclasses.update(typing_superclasses)
  finder = optimize.FindCommonSuperClasses(
      optimize.SuperClassHierarchy(superclasses))

  ast = self.ParseAndResolve(src).Visit(finder)
  # Strip the builtin prefix and canonicalize before comparing.
  ast = ast.Visit(visitors.DropBuiltinPrefix())
  ast = ast.Visit(visitors.CanonicalOrderingVisitor())
  self.AssertSourceEquals(ast, expected)
def testSimplifyUnionsWithSuperclassesGeneric(self):
  """Check SimplifyUnionsWithSuperclasses on a union of generic types.

  "frozenset[int] or AbstractSet[int]" is expected to come out as
  "AbstractSet[int]".
  """
  src = textwrap.dedent(""" x = ... # type: frozenset[int] or AbstractSet[int] """)
  expected = textwrap.dedent(""" x = ... # type: AbstractSet[int] """)

  superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
  simplifier = optimize.SimplifyUnionsWithSuperclasses(
      optimize.SuperClassHierarchy(superclasses))

  # Resolve classes against the builtins before running the visitor.
  parsed = self.Parse(src)
  resolved = visitors.LookupClasses(parsed, self.builtins)
  self.AssertSourceEquals(resolved.Visit(simplifier), expected)
def testBuiltinSuperClasses(self):
  """Check FindCommonSuperClasses with a hierarchy from the builtins pytd.

  The expected output has no annotation left on x and y, and "int or bool"
  has become "int".
  """
  src = textwrap.dedent(""" def f(x: list or object, y: int or float) -> int or bool """)
  expected = textwrap.dedent(""" def f(x, y) -> int """)

  builtins_pytd = builtins.GetBuiltinsPyTD()
  superclasses = builtins_pytd.Visit(visitors.ExtractSuperClassesByName())
  finder = optimize.FindCommonSuperClasses(
      optimize.SuperClassHierarchy(superclasses))

  ast = self.ParseAndResolve(src).Visit(finder)
  # Strip the builtin prefix before comparing.
  ast = ast.Visit(visitors.DropBuiltinPrefix())
  self.AssertSourceEquals(ast, expected)
def testFindCommonSuperClasses(self):
  """Check that a union containing a LateType passes through unchanged.

  "other.Bar" is replaced by a pytd.LateType before FindCommonSuperClasses
  runs; the expected source is identical to the input.
  """
  src = textwrap.dedent(""" x = ... # type: int or other.Bar """)
  expected = textwrap.dedent(""" x = ... # type: int or other.Bar """)

  ast = self.Parse(src)
  late_types = {"other.Bar": pytd.LateType("other.Bar")}
  ast = ast.Visit(visitors.ReplaceTypes(late_types))

  # The hierarchy comes from the tree itself, after the replacement.
  superclasses = ast.Visit(visitors.ExtractSuperClassesByName())
  finder = optimize.FindCommonSuperClasses(
      optimize.SuperClassHierarchy(superclasses))
  ast = ast.Visit(finder)

  ast = ast.Visit(visitors.LateTypeToClassType())
  self.AssertSourceEquals(ast, expected)
def testSimplifyUnionsWithSuperclasses(self):
  """Check SimplifyUnionsWithSuperclasses on unions of builtin classes.

  Expected results, using a hierarchy extracted from self.builtins:
    * "int or bool" becomes "int".
    * "int or bool or float" becomes "int or float".
    * "list[int] or int" is left as it is.
  """
  # Each constant has to sit on its own line: a "# type:" comment runs to the
  # end of the line, so definitions sharing a line with an earlier one would
  # be swallowed by that comment.
  src = textwrap.dedent("""
      x = ... # type: int or bool
      y = ... # type: int or bool or float
      z = ... # type: list[int] or int
  """)
  expected = textwrap.dedent("""
      x = ... # type: int
      y = ... # type: int or float
      z = ... # type: list[int] or int
  """)
  hierarchy = self.builtins.Visit(visitors.ExtractSuperClassesByName())
  visitor = optimize.SimplifyUnionsWithSuperclasses(
      optimize.SuperClassHierarchy(hierarchy))
  ast = self.Parse(src)
  # Resolve classes against the builtins before running the visitor.
  ast = visitors.LookupClasses(ast, self.builtins)
  ast = ast.Visit(visitor)
  self.AssertSourceEquals(ast, expected)
def Optimize(node, lossy=False, use_abcs=False, max_union=7,
             remove_mutable=False):
  """Run the optimization passes over a PYTD tree and return the result.

  The passes try to shrink the tree. The input node is not modified; a new
  node is returned.

  Args:
    node: The pytd node to optimize.
    lossy: Allow optimizations that change the meaning of the pytd.
    use_abcs: Use abstract base classes to represent unions, e.g. turn
      "float or int" into "Real".
    max_union: How many types a union may hold before it is simplified to
      just "object". A falsy value skips those passes.
    remove_mutable: Whether to simplify mutable parameters to normal
      parameters.

  Returns:
    An optimized node.
  """
  # Every visitor is created right before it is applied.
  for visitor_class in (RemoveDuplicates, SimplifyUnions,
                        CombineReturnsAndExceptions, Factorize,
                        ApplyOptionalArguments, CombineContainers):
    node = node.Visit(visitor_class())

  if lossy:
    # The hierarchy is taken from the tree as it looks at this point.
    hierarchy = node.Visit(visitors.ExtractSuperClassesByName())
    node = node.Visit(FindCommonSuperClasses(hierarchy, use_abcs))

  if max_union:
    for collapser_class in (CollapseLongParameterUnions,
                            CollapseLongReturnUnions,
                            CollapseLongConstantUnions):
      node = node.Visit(collapser_class(max_union))

  # NOTE(review): all four passes below are assumed to belong to the
  # remove_mutable branch -- confirm against the upstream layout.
  if remove_mutable:
    for visitor_class in (AbsorbMutableParameters, CombineContainers,
                          MergeTypeParameters):
      node = node.Visit(visitor_class())
    node = node.Visit(visitors.AdjustSelf(force=True))

  # Resolve classes against the builtins, then drop inherited methods.
  resolved = visitors.LookupClasses(node, builtins.GetBuiltinsPyTD())
  return resolved.Visit(RemoveInheritedMethods())
def testUserSuperClassHierarchy(self):
  """Check FindCommonSuperClasses on a user-defined class hierarchy.

  The hierarchy is extracted from the parsed source itself. Expected results:
    * "A or B" becomes "AB" and "E or F or G" becomes "EFG" in f.
    * In g, "E or F or G or B" loses its annotation and "E or F" becomes
      "EFG".
    * h, and parameters with a single type, are left alone.
  """
  # One class per line with an indented body, so that each class is a
  # separate definition.
  class_data = textwrap.dedent("""
      class AB(object):
          pass
      class EFG(object):
          pass
      class A(AB, EFG):
          pass
      class B(AB):
          pass
      class E(EFG, AB):
          pass
      class F(EFG):
          pass
      class G(EFG):
          pass
  """)
  # Likewise one function signature per line.
  src = textwrap.dedent("""
      def f(x: A or B, y: A, z: B) -> E or F or G
      def g(x: E or F or G or B) -> E or F
      def h(x) -> ?
  """) + class_data
  expected = textwrap.dedent("""
      def f(x: AB, y: A, z: B) -> EFG
      def g(x) -> EFG
      def h(x) -> ?
  """) + class_data

  hierarchy = self.Parse(src).Visit(
      visitors.ExtractSuperClassesByName())
  visitor = optimize.FindCommonSuperClasses(
      optimize.SuperClassHierarchy(hierarchy))
  new_src = self.ApplyVisitorToString(src, visitor)
  self.AssertSourceEquals(new_src, expected)
def testSuperClassesByName(self):
  """Check the mapping produced by ExtractSuperClassesByName.

  Classes declared without a parent ("class A():") are expected to map to
  ("classobj",); the others map to the names of their declared parents,
  compared without regard to order.
  """
  # One class per line with an indented body, so that each class is a
  # separate definition.
  src = textwrap.dedent("""
      class A():
          pass
      class B():
          pass
      class C(A):
          pass
      class D(A,B):
          pass
      class E(C,D,A):
          pass
  """)
  tree = self.Parse(src)
  data = tree.Visit(visitors.ExtractSuperClassesByName())
  self.assertItemsEqual(("classobj",), data["A"])
  self.assertItemsEqual(("classobj",), data["B"])
  self.assertItemsEqual(("A",), data["C"])
  self.assertItemsEqual(("A", "B"), data["D"])
  self.assertItemsEqual(("A", "C", "D"), data["E"])