def Optimize(node,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True):
    """Shrink a PYTD tree by running a fixed pipeline of optimization visitors.

    The input node is not modified; each pass produces a new node and the
    final one is returned.

    Args:
      node: The pytd node to optimize.
      lossy: Allow optimizations that change the meaning of the pytd.
      use_abcs: Also feed the abstract base class hierarchy into the
        superclass table, so unions such as "float or int" can be
        represented as "Real".
      max_union: How many types a union may hold before it is simplified
        to just "object". A falsy value skips this pass entirely.
      remove_mutable: Whether to simplify mutable parameters to normal
        parameters.
      can_do_lookup: True if NamedType instances may be resolved in the
        AST (or the AST is already resolved). False skips every
        optimization that needs resolved NamedTypes.

    Returns:
      An optimized node.
    """
    # NOTE(review): this function was restored from a whitespace-collapsed
    # source; the extent of each `if` body below follows the data flow
    # (`hierarchy` is needed unconditionally) -- confirm against upstream.

    # Passes that need no class hierarchy. Each visitor is instantiated right
    # before it is applied, in this exact order.
    for visitor_factory in (RemoveDuplicates,
                            SimplifyUnions,
                            CombineReturnsAndExceptions,
                            Factorize,
                            ApplyOptionalArguments,
                            CombineContainers,
                            SimplifyContainers):
        node = node.Visit(visitor_factory())

    # Superclass table: builtins first, then the classes of the tree being
    # optimized, then (optionally) the ABC hierarchy.
    bases_by_name = builtins.GetBuiltinsPyTD().Visit(
        visitors.ExtractSuperClassesByName())
    bases_by_name.update(node.Visit(visitors.ExtractSuperClassesByName()))
    if use_abcs:
        bases_by_name.update(abc_hierarchy.GetSuperClasses())
    hierarchy = SuperClassHierarchy(bases_by_name)

    node = node.Visit(SimplifyUnionsWithSuperclasses(hierarchy))
    if lossy:
        node = node.Visit(FindCommonSuperClasses(hierarchy))
    if max_union:
        node = node.Visit(CollapseLongUnions(max_union))
    node = node.Visit(AdjustReturnAndConstantGenericType())

    if remove_mutable:
        node = node.Visit(AbsorbMutableParameters())
        node = node.Visit(CombineContainers())
        node = node.Visit(MergeTypeParameters())
        node = node.Visit(visitors.AdjustSelf(force=True))
    node = node.Visit(SimplifyContainers())

    if can_do_lookup:
        # These passes operate on an AST whose classes have been looked up.
        node = visitors.LookupClasses(node, builtins.GetBuiltinsPyTD())
        node = node.Visit(RemoveInheritedMethods())
        node = node.Visit(RemoveRedundantSignatures(hierarchy))
    return node
def testSimplifyUnionsWithSuperclassesGeneric(self):
    """Check SimplifyUnionsWithSuperclasses on a union of generic types.

    The constant typed `frozenset[int] or AbstractSet[int]` is expected to
    be printed as `AbstractSet[int]` after the visitor runs with a hierarchy
    extracted from the builtins.
    """
    # NOTE(review): line breaks inside the pytd snippets were restored from a
    # whitespace-collapsed source; confirm the layout against upstream.
    src = textwrap.dedent("""
        x = ... # type: frozenset[int] or AbstractSet[int]
    """)
    expected = textwrap.dedent("""
        x = ... # type: AbstractSet[int]
    """)
    hierarchy = builtins.GetBuiltinsPyTD().Visit(
        visitors.ExtractSuperClassesByName())
    visitor = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(hierarchy))
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(visitor)
    self.AssertSourceEquals(ast, expected)
def testRemoveInheritedMethodsWithOverride(self):
    """Check RemoveInheritedMethods when an intermediate class overrides.

    B overrides A.f with a different signature. C.f repeats B.f and is
    expected to be dropped (leaving `pass`); D.f repeats A.f rather than
    B.f and is expected to stay.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippets were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class A(object):
            def f(self, x) -> Any
        class B(A):
            def f(self) -> Any
        class C(B):
            def f(self) -> Any
        class D(B):
            def f(self, x) -> Any
    """)
    expected = textwrap.dedent("""
        class A(object):
            def f(self, x) -> Any
        class B(A):
            def f(self) -> Any
        class C(B):
            pass
        class D(B):
            def f(self, x) -> Any
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(ast, expected)
def testAddInheritedMethods(self):
    """Check that AddInheritedMethods copies parent members into a subclass.

    After the visitor runs, B is expected to additionally list the constant
    `foo` and the method `f` from A, while keeping its own `h` instead of
    A's.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippets were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class A():
            foo = ... # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex
        class B(A):
            bar = ... # type: int
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
    """)
    expected = textwrap.dedent("""
        class A():
            foo = ... # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex
        class B(A):
            bar = ... # type: int
            foo = ... # type: bool
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
            def f(self, x: int) -> float
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(optimize.AddInheritedMethods())
    self.AssertSourceEquals(ast, expected)
def testSimplifyUnionsWithSuperclasses(self):
    """Check SimplifyUnionsWithSuperclasses on unions of builtin types.

    Expected results: `int or bool` becomes `int`, `int or bool or float`
    becomes `int or float`, and `list[int] or int` is left as is.
    """
    # NOTE(review): line breaks inside the pytd snippets were restored from a
    # whitespace-collapsed source; confirm the layout against upstream.
    src = textwrap.dedent("""
        x = ... # type: int or bool
        y = ... # type: int or bool or float
        z = ... # type: list[int] or int
    """)
    expected = textwrap.dedent("""
        x = ... # type: int
        y = ... # type: int or float
        z = ... # type: list[int] or int
    """)
    hierarchy = builtins.GetBuiltinsPyTD().Visit(
        visitors.ExtractSuperClassesByName())
    visitor = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(hierarchy))
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(visitor)
    self.AssertSourceEquals(ast, expected)
def testRemoveInheritedMethodsWithCircle(self):
    """Check RemoveInheritedMethods on two classes that inherit from each other.

    A derives from B and B derives from A; the output is expected to equal
    the input source.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippet were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class A(B):
            def f(self) -> Any
        class B(A):
            def f(self) -> Any
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(ast, src)
def testDontRemoveAbstractMethodImplementation(self):
    """Check that an implementation of an abstract method is not removed.

    A.foo is decorated with @abstractmethod and B.foo has the same
    signature; the output of RemoveInheritedMethods is expected to equal
    the input source.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippet were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class A(object):
            @abstractmethod
            def foo(self): ...
        class B(A):
            def foo(self): ...
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(ast, src)
def testBuiltinSuperClasses(self):
    """Check FindCommonSuperClasses against the builtins hierarchy.

    For `f(x: list or object, y: int or float) -> int or bool` the expected
    output is `f(x, y) -> int`, compared after dropping the builtin prefix.
    """
    # NOTE(review): line breaks inside the pytd snippets were restored from a
    # whitespace-collapsed source; confirm the layout against upstream.
    src = textwrap.dedent("""
        def f(x: list or object, y: int or float) -> int or bool
    """)
    expected = textwrap.dedent("""
        def f(x, y) -> int
    """)
    b = builtins.GetBuiltinsPyTD()
    hierarchy = b.Visit(visitors.ExtractSuperClassesByName())
    visitor = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(hierarchy))
    ast = self.ParseAndResolve(src)
    ast = ast.Visit(visitor)
    ast = ast.Visit(visitors.DropBuiltinPrefix())
    self.AssertSourceEquals(ast, expected)
def testRemoveInheritedMethodsWithDiamond(self):
    """Check RemoveInheritedMethods on a diamond-shaped hierarchy.

    D derives from B and C, both of which derive from A. D.f repeats A.f,
    but C declares f with a different signature; the output is expected to
    equal the input source.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippet were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class A(object):
            def f(self, x) -> ?
        class B(A):
            pass
        class C(A):
            def f(self, x, y) -> ?
        class D(B, C):
            def f(self, x) -> ?
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(ast, src)
def testBuiltinSuperClasses(self):
    """Check FindCommonSuperClasses against the versioned builtins hierarchy.

    The builtins are loaded for `self.PYTHON_VERSION`. For
    `f(x: list or object, y: complex or memoryview) -> int or bool` the
    expected output is `f(x, y) -> int`, compared after dropping the builtin
    prefix and applying canonical ordering.
    """
    # NOTE(review): line breaks inside the pytd snippets were restored from a
    # whitespace-collapsed source; confirm the layout against upstream.
    src = textwrap.dedent("""
        def f(x: list or object, y: complex or memoryview) -> int or bool
    """)
    expected = textwrap.dedent("""
        def f(x, y) -> int
    """)
    b = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
    hierarchy = b.Visit(visitors.ExtractSuperClassesByName())
    visitor = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(hierarchy))
    ast = self.ParseAndResolve(src)
    ast = ast.Visit(visitor)
    ast = ast.Visit(visitors.DropBuiltinPrefix())
    ast = ast.Visit(visitors.CanonicalOrderingVisitor())
    self.AssertSourceEquals(ast, expected)
def testAddInheritedMethods(self):
    """Check which method names B exposes before and after AddInheritedMethods.

    Before the visitor runs, B is expected to have methods `g` and `h`;
    afterwards it is expected to have `f`, `g` and `h`.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippet were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class A():
            foo = ... # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex
        class B(A):
            bar = ... # type: int
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    self.assertItemsEqual(("g", "h"),
                          [m.name for m in ast.Lookup("B").methods])
    ast = ast.Visit(optimize.AddInheritedMethods())
    self.assertItemsEqual(("f", "g", "h"),
                          [m.name for m in ast.Lookup("B").methods])
def testRemoveInheritedMethodsWithoutSelf(self):
    """Check RemoveInheritedMethods on a class holding a method without self.

    Foo.baz repeats Bar.baz and is expected to be removed; Foo.bar, which
    takes no parameters at all, is expected to stay.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippets were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class Bar(object):
            def baz(self) -> int
        class Foo(Bar):
            def baz(self) -> int
            def bar() -> float
    """)
    expected = textwrap.dedent("""
        class Bar(object):
            def baz(self) -> int
        class Foo(Bar):
            def bar() -> float
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, builtins.GetBuiltinsPyTD())
    ast = ast.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(ast, expected)
def testRemoveInheritedMethods(self):
    """Check that methods repeating an ancestor's signature are removed.

    B.f and C.f repeat A.f, and D.g repeats A.g (reached through B); all
    three are expected to be dropped, while b, c and d stay. Classes are
    looked up against the builtins for `self.PYTHON_VERSION`.
    """
    # NOTE(review): line breaks and indentation inside the pytd snippets were
    # restored from a whitespace-collapsed source; confirm against upstream.
    src = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float
        class B(A):
            def b(self) -> B
            def f(self, y: int) -> bool
        class C(A):
            def c(self) -> C
            def f(self, y: int) -> bool
        class D(B):
            def g(self) -> float
            def d(self) -> D
    """)
    expected = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float
        class B(A):
            def b(self) -> B
        class C(A):
            def c(self) -> C
        class D(B):
            def d(self) -> D
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(
        ast, builtins.GetBuiltinsPyTD(self.PYTHON_VERSION))
    ast = ast.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(ast, expected)
def setUpClass(cls):
    """Run the parent class setup, then cache the builtins on the class.

    The value returned by `builtins.GetBuiltinsPyTD(cls.python_version)` is
    stored as `cls.builtins`.
    """
    # NOTE(review): no @classmethod decorator is visible on this definition;
    # confirm it is present on the line above in the full file.
    super(UtilsTest, cls).setUpClass()
    loaded_builtins = builtins.GetBuiltinsPyTD(cls.python_version)
    cls.builtins = loaded_builtins
def Optimize(self, ast, **kwargs):
    """Call optimize.Optimize on `ast`, passing the builtins as second argument.

    Args:
      ast: The tree handed to optimize.Optimize as its first argument.
      **kwargs: Forwarded unchanged to optimize.Optimize.

    Returns:
      Whatever optimize.Optimize returns.
    """
    # NOTE(review): an Optimize definition visible in this file takes `lossy`
    # as its second positional parameter; confirm that the optimize.Optimize
    # targeted here accepts a builtins tree in that position.
    builtins_tree = builtins.GetBuiltinsPyTD()
    return optimize.Optimize(ast, builtins_tree, **kwargs)
def Optimize(self, ast, **kwargs):
    """Call optimize.Optimize on `ast` with the versioned builtins.

    The builtins are loaded for `self.PYTHON_VERSION` and passed as the
    second positional argument.

    Args:
      ast: The tree handed to optimize.Optimize as its first argument.
      **kwargs: Forwarded unchanged to optimize.Optimize.

    Returns:
      Whatever optimize.Optimize returns.
    """
    # NOTE(review): an Optimize definition visible in this file takes `lossy`
    # as its second positional parameter; confirm that the optimize.Optimize
    # targeted here accepts a builtins tree in that position.
    builtins_tree = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
    return optimize.Optimize(ast, builtins_tree, **kwargs)