def testFindUnknownVisitor(self):
    """RaiseIfContainsUnknown fires for A, C and D but not for B."""
    pytd_source = textwrap.dedent("""
        class classobj:
            pass
        class `~unknown1`():
            pass
        class `~unknown_foobar`():
            pass
        class `~int`():
            pass
        class A():
            def foobar(self, x: `~unknown1`) -> ?
        class B():
            def foobar(self, x: `~int`) -> ?
        class C():
            x = ... # type: `~unknown_foobar`
        class D(`~unknown1`):
            pass
    """)
    resolved = visitors.LookupClasses(self.Parse(pytd_source))
    has_unknown = visitors.RaiseIfContainsUnknown.HasUnknown

    def check(class_name):
        # A fresh visitor per lookup, exactly as the original lambda did.
        return resolved.Lookup(class_name).Visit(
            visitors.RaiseIfContainsUnknown())

    self.assertRaises(has_unknown, check, "A")
    check("B")  # B only references `~int`; this must not raise.
    self.assertRaises(has_unknown, check, "C")
    self.assertRaises(has_unknown, check, "D")
def testSuperClasses(self):
    """ExtractSuperClasses maps each class to its direct base classes."""
    src = textwrap.dedent("""
        class classobj:
            pass
        class A():
            pass
        class B():
            pass
        class C(A):
            pass
        class D(A,B):
            pass
        class E(C,D,A):
            pass
    """)
    ast = visitors.LookupClasses(self.Parse(src))
    data = ast.Visit(visitors.ExtractSuperClasses())
    # `assertItemsEqual` only exists on Python 2's TestCase; Python 3 renamed
    # it to `assertCountEqual`. Prefer the new name and fall back to the old
    # one so the test runs under either interpreter.
    assert_count_equal = (getattr(self, "assertCountEqual", None)
                          or self.assertItemsEqual)
    assert_count_equal(["classobj"], [t.name for t in data[ast.Lookup("A")]])
    assert_count_equal(["classobj"], [t.name for t in data[ast.Lookup("B")]])
    assert_count_equal(["A"], [t.name for t in data[ast.Lookup("C")]])
    assert_count_equal(["A", "B"], [t.name for t in data[ast.Lookup("D")]])
    assert_count_equal(["C", "D", "A"],
                       [t.name for t in data[ast.Lookup("E")]])
def test_find_unknown_visitor(self):
    """RaiseIfContainsUnknown fires for A, C and D but not for B."""
    source = pytd_src("""
        from typing import Any
        class object:
            pass
        class `~unknown1`():
            pass
        class `~unknown_foobar`():
            pass
        class `~int`():
            pass
        class A():
            def foobar(self, x: `~unknown1`) -> Any: ...
        class B():
            def foobar(self, x: `~int`) -> Any: ...
        class C():
            x = ... # type: `~unknown_foobar`
        class D(`~unknown1`):
            pass
    """)
    tree = visitors.LookupClasses(self.Parse(source))

    def find_on(name):
        return tree.Lookup(name).Visit(visitors.RaiseIfContainsUnknown())

    has_unknown = visitors.RaiseIfContainsUnknown.HasUnknown
    self.assertRaises(has_unknown, find_on, "A")
    find_on("B")  # shouldn't raise
    self.assertRaises(has_unknown, find_on, "C")
    self.assertRaises(has_unknown, find_on, "D")
def test_superclasses(self):
    """ExtractSuperClasses maps each class to its direct base classes."""
    src = textwrap.dedent("""
        class object:
            pass
        class A():
            pass
        class B():
            pass
        class C(A):
            pass
        class D(A,B):
            pass
        class E(C,D,A):
            pass
    """)
    ast = visitors.LookupClasses(self.Parse(src))
    data = ast.Visit(pytd_visitors.ExtractSuperClasses())
    expected_bases = (
        ("A", ["object"]),
        ("B", ["object"]),
        ("C", ["A"]),
        ("D", ["A", "B"]),
        ("E", ["C", "D", "A"]),
    )
    for class_name, bases in expected_bases:
        actual = [base.name for base in data[ast.Lookup(class_name)]]
        six.assertCountEqual(self, bases, actual)
def testRemoveInheritedMethodsWithOverride(self):
    """Only C.f, identical to the B.f it inherits, is removed.

    D.f matches A.f but overrides the different B.f, so it is kept.
    """
    original = textwrap.dedent("""
        class A(object):
            def f(self, x) -> ?
        class B(A):
            def f(self) -> ?
        class C(B):
            def f(self) -> ?
        class D(B):
            def f(self, x) -> ?
    """)
    expected = textwrap.dedent("""
        class A(object):
            def f(self, x) -> ?
        class B(A):
            def f(self) -> ?
        class C(B):
            pass
        class D(B):
            def f(self, x) -> ?
    """)
    linked = visitors.LookupClasses(self.Parse(original), self.builtins)
    pruned = linked.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(pruned, expected)
def Optimize(node,
             builtins=None,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True):
    """Optimize a PYTD tree.

    Tries to shrink a PYTD tree by applying various optimizations.

    Arguments:
      node: A pytd node to be optimized. It won't be modified - this function
        will return a new node.
      builtins: Definitions of all of the external types in node.
      lossy: Allow optimizations that change the meaning of the pytd.
      use_abcs: Use abstract base classes to represent unions like
        e.g. "float or int" as "Real".
      max_union: How many types we allow in a union before we simplify
        it to just "object".
      remove_mutable: Whether to simplify mutable parameters to normal
        parameters.
      can_do_lookup: True: We're either allowed to try to resolve NamedType
        instances in the AST, or the AST is already resolved. False: Skip any
        optimizations that would require NamedTypes to be resolved.

    Returns:
      An optimized node.
    """
    # Simplifications that do not use `builtins`. The order of the visitors
    # is significant: each pass consumes the previous pass's output.
    node = node.Visit(RemoveDuplicates())
    node = node.Visit(SimplifyUnions())
    node = node.Visit(CombineReturnsAndExceptions())
    node = node.Visit(Factorize())
    node = node.Visit(ApplyOptionalArguments())
    node = node.Visit(CombineContainers())
    node = node.Visit(SimplifyContainers())
    if builtins:
        # Build one superclass map from the builtins plus the classes in
        # `node`, optionally extended with the ABC hierarchy.
        superclasses = builtins.Visit(visitors.ExtractSuperClassesByName())
        superclasses.update(node.Visit(visitors.ExtractSuperClassesByName()))
        if use_abcs:
            superclasses.update(abc_hierarchy.GetSuperClasses())
        hierarchy = SuperClassHierarchy(superclasses)
        node = node.Visit(SimplifyUnionsWithSuperclasses(hierarchy))
        if lossy:
            # Lossy: replaces types with common superclasses (see `lossy`).
            node = node.Visit(FindCommonSuperClasses(hierarchy))
    if max_union:
        # A falsy max_union (0 / None) disables union collapsing entirely.
        node = node.Visit(CollapseLongUnions(max_union))
    node = node.Visit(AdjustReturnAndConstantGenericType())
    if remove_mutable:
        node = node.Visit(AbsorbMutableParameters())
        node = node.Visit(CombineContainers())
        node = node.Visit(MergeTypeParameters())
        node = node.Visit(visitors.AdjustSelf())
    node = node.Visit(SimplifyContainers())
    if builtins and can_do_lookup:
        # `hierarchy` is only bound inside the `if builtins:` branch above.
        # This branch is also guarded by `builtins`, so it is always defined
        # here.
        node = visitors.LookupClasses(node, builtins, ignore_late_types=True)
        node = node.Visit(RemoveInheritedMethods())
        node = node.Visit(RemoveRedundantSignatures(hierarchy))
    return node
def solve(ast, builtins_pytd, protocols_pytd):
    """Solve the unknowns in a pytd AST using the standard Python builtins.

    Args:
      ast: A pytd.TypeDeclUnit, containing classes named ~unknownXX.
      builtins_pytd: A pytd for builtins.
      protocols_pytd: A pytd for protocols.

    Returns:
      A tuple of (1) a dictionary (str->str) mapping unknown class names to
      known class names and (2) a pytd.TypeDeclUnit of the complete classes in
      ast.
    """
    # Strip mutable parameters from the builtins before resolving them.
    resolved_builtins = visitors.LookupClasses(
        transforms.RemoveMutableParameters(builtins_pytd))
    resolved_protocols = visitors.LookupClasses(protocols_pytd)
    # The AST's class references are resolved against the prepared builtins.
    resolved_ast = visitors.LookupClasses(ast, resolved_builtins)
    solver = TypeSolver(resolved_ast, resolved_builtins, resolved_protocols)
    solution = solver.solve()
    return solution, extract_local(resolved_ast)
def testRemoveInheritedMethodsWithCircle(self):
    """A cyclic A <-> B hierarchy comes out of the pass unchanged."""
    source = textwrap.dedent("""
        class A(B):
            def f(self) -> ?
        class B(A):
            def f(self) -> ?
    """)
    linked = visitors.LookupClasses(self.Parse(source), self.builtins)
    result = linked.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(result, source)
def testDontRemoveAbstractMethodImplementation(self):
    """B.foo implements abstract A.foo and must survive the pass."""
    source = textwrap.dedent("""
        class A(object):
            @abstractmethod
            def foo(self): ...
        class B(A):
            def foo(self): ...
    """)
    linked = visitors.LookupClasses(self.Parse(source), self.builtins)
    result = linked.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(result, source)
def testClassMatch(self):
    """Left matches Right, but Right (with an extra method) doesn't match Left."""
    source = textwrap.dedent("""
        class Left():
            def method(self) -> ?
        class Right():
            def method(self) -> ?
            def method2(self) -> ?
    """)
    parsed = parser.parse_string(source, python_version=self.PYTHON_VERSION)
    resolved = visitors.LookupClasses(parsed, self.mini_builtins)
    matcher = type_match.TypeMatch()
    left = resolved.Lookup("Left")
    right = resolved.Lookup("Right")
    self.assertEqual(matcher.match(left, right, {}), booleq.TRUE)
    self.assertNotEqual(matcher.match(right, left, {}), booleq.TRUE)
def test_simplify_unions_with_superclasses_generic(self):
    """frozenset[int] is absorbed by its superclass AbstractSet[int]."""
    src = pytd_src("""
        x = ... # type: Union[frozenset[int], AbstractSet[int]]
    """)
    expected = pytd_src("""
        x = ... # type: AbstractSet[int]
    """)
    superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(superclasses))
    resolved = visitors.LookupClasses(self.Parse(src), self.builtins)
    self.AssertSourceEquals(resolved.Visit(simplifier), expected)
def test_class_match(self):
    """Left matches Right, but Right (with an extra method) doesn't match Left."""
    source = textwrap.dedent("""
        from typing import Any
        class Left():
            def method(self) -> Any: ...
        class Right():
            def method(self) -> Any: ...
            def method2(self) -> Any: ...
    """)
    parsed = parser.parse_string(source, python_version=self.python_version)
    resolved = visitors.LookupClasses(parsed, self.mini_builtins)
    matcher = type_match.TypeMatch()
    left = resolved.Lookup("Left")
    right = resolved.Lookup("Right")
    self.assertEqual(matcher.match(left, right, {}), booleq.TRUE)
    self.assertNotEqual(matcher.match(right, left, {}), booleq.TRUE)
def testSimplifyUnionsWithSuperclassesGeneric(self):
    """frozenset[int] is absorbed by its superclass AbstractSet[int]."""
    src = textwrap.dedent("""
        x = ... # type: frozenset[int] or AbstractSet[int]
    """)
    expected = textwrap.dedent("""
        x = ... # type: AbstractSet[int]
    """)
    superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(superclasses))
    resolved = visitors.LookupClasses(self.Parse(src), self.builtins)
    self.AssertSourceEquals(resolved.Visit(simplifier), expected)
def testAdjustInheritedMethodSelf(self):
    """An inherited method's explicit `self: object` annotation is dropped."""
    src = textwrap.dedent("""
        class A():
            def f(self: object) -> float
        class B(A):
            pass
    """)
    tree = visitors.LookupClasses(self.Parse(src), self.builtins)
    tree = tree.Visit(optimize.AddInheritedMethods())
    printed = pytd.Print(tree.Lookup("B"))
    expected = textwrap.dedent("""\
        class B(A):
            def f(self) -> float: ...
        """)
    self.assertMultiLineEqual(printed, expected)
def test_remove_inherited_methods_with_diamond(self):
    """In a diamond hierarchy, D.f is kept even though it matches A.f."""
    source = textwrap.dedent("""
        class A(object):
            def f(self, x) -> ?
        class B(A):
            pass
        class C(A):
            def f(self, x, y) -> ?
        class D(B, C):
            def f(self, x) -> ?
    """)
    linked = visitors.LookupClasses(self.Parse(source), self.builtins)
    result = linked.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(result, source)
def testSubclasses(self):
    """With subclass info, left(B)->B matches right(A)->A but not vice versa."""
    source = textwrap.dedent("""
        class A():
            pass
        class B(A):
            pass
        a = ... # type: A
        def left(a: B) -> B
        def right(a: A) -> A
    """)
    parsed = parser.parse_string(source, python_version=self.PYTHON_VERSION)
    resolved = visitors.LookupClasses(parsed, self.mini_builtins)
    matcher = type_match.TypeMatch(type_match.get_all_subclasses([resolved]))
    left_fn = resolved.Lookup("left")
    right_fn = resolved.Lookup("right")
    self.assertEqual(matcher.match(left_fn, right_fn, {}), booleq.TRUE)
    self.assertNotEqual(matcher.match(right_fn, left_fn, {}), booleq.TRUE)
def test_adjust_inherited_method_self(self):
    """An inherited method's explicit `self: object` annotation is dropped."""
    src = pytd_src("""
        class A():
            def f(self: object) -> float: ...
        class B(A):
            pass
    """)
    tree = visitors.LookupClasses(self.Parse(src), self.builtins)
    tree = tree.Visit(optimize.AddInheritedMethods())
    printed_b = pytd_utils.Print(tree.Lookup("B"))
    expected_b = pytd_src("""
        class B(A):
            def f(self) -> float: ...
    """).lstrip()
    self.assertMultiLineEqual(printed_b, expected_b)
def testAddInheritedMethods(self):
    """AddInheritedMethods copies A.f into B; B keeps its own g and h."""
    src = textwrap.dedent("""
        class A():
            foo = ... # type: bool
            def f(self, x: int) -> float
            def h(self) -> complex
        class B(A):
            bar = ... # type: int
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
    """)
    ast = self.Parse(src)
    ast = visitors.LookupClasses(ast, self.builtins)
    # `assertItemsEqual` only exists on Python 2's TestCase; Python 3 renamed
    # it to `assertCountEqual`. Prefer the new name and fall back to the old
    # one so the test runs under either interpreter.
    assert_count_equal = (getattr(self, "assertCountEqual", None)
                          or self.assertItemsEqual)
    assert_count_equal(("g", "h"), [m.name for m in ast.Lookup("B").methods])
    ast = ast.Visit(optimize.AddInheritedMethods())
    assert_count_equal(("f", "g", "h"),
                       [m.name for m in ast.Lookup("B").methods])
def test_simplify_unions_with_superclasses(self):
    """Union members subsumed by another member's superclass are dropped."""
    src = pytd_src("""
        x = ... # type: Union[int, bool]
        y = ... # type: Union[int, bool, float]
        z = ... # type: Union[list[int], int]
    """)
    expected = pytd_src("""
        x = ... # type: int
        y = ... # type: Union[int, float]
        z = ... # type: Union[list[int], int]
    """)
    superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(superclasses))
    resolved = visitors.LookupClasses(self.Parse(src), self.builtins)
    self.AssertSourceEquals(resolved.Visit(simplifier), expected)
def testSimplifyUnionsWithSuperclasses(self):
    """Union members subsumed by another member's superclass are dropped."""
    src = textwrap.dedent("""
        x = ... # type: int or bool
        y = ... # type: int or bool or float
        z = ... # type: list[int] or int
    """)
    expected = textwrap.dedent("""
        x = ... # type: int
        y = ... # type: int or float
        z = ... # type: list[int] or int
    """)
    superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(superclasses))
    resolved = visitors.LookupClasses(self.Parse(src), self.builtins)
    self.AssertSourceEquals(resolved.Visit(simplifier), expected)
def testLookupClasses(self):
    """LookupClasses preserves the printed source and passes VerifyLookup."""
    src = textwrap.dedent("""
        class object(object):
            pass

        class A(object):
            def a(self, a: A, b: B) -> A or B:
                raise A()
                raise B()

        class B(object):
            def b(self, a: A, b: B) -> A or B:
                raise A()
                raise B()
    """)
    resolved = visitors.LookupClasses(self.Parse(src))
    self.AssertSourceEquals(resolved, src)
    resolved.Visit(visitors.VerifyLookup())
def test_lookup_classes(self):
    """LookupClasses preserves the printed source and passes VerifyLookup."""
    src = textwrap.dedent("""
        from typing import Union
        class object:
            pass

        class A:
            def a(self, a: A, b: B) -> Union[A, B]:
                raise A()
                raise B()

        class B:
            def b(self, a: A, b: B) -> Union[A, B]:
                raise A()
                raise B()
    """)
    resolved = visitors.LookupClasses(self.Parse(src))
    self.AssertSourceEquals(resolved, src)
    resolved.Visit(visitors.VerifyLookup())
def test_add_inherited_methods(self):
    """AddInheritedMethods copies A.f into B; B keeps its own g and h."""
    src = pytd_src("""
        from typing import Any
        class A():
            foo = ... # type: bool
            def f(self, x: int) -> float: ...
            def h(self) -> complex: ...
        class B(A):
            bar = ... # type: int
            def g(self, y: int) -> bool: ...
            def h(self, z: float) -> Any: ...
    """)

    def b_method_names(tree):
        return [method.name for method in tree.Lookup("B").methods]

    tree = visitors.LookupClasses(self.Parse(src), self.builtins)
    self.assertCountEqual(("g", "h"), b_method_names(tree))
    tree = tree.Visit(optimize.AddInheritedMethods())
    self.assertCountEqual(("f", "g", "h"), b_method_names(tree))
def testRemoveInheritedMethodsWithoutSelf(self):
    """Foo.baz duplicates Bar.baz and is removed; the self-less bar stays."""
    original = textwrap.dedent("""
        class Bar(object):
            def baz(self) -> int
        class Foo(Bar):
            def baz(self) -> int
            def bar() -> float
    """)
    expected = textwrap.dedent("""
        class Bar(object):
            def baz(self) -> int
        class Foo(Bar):
            def bar() -> float
    """)
    linked = visitors.LookupClasses(self.Parse(original), self.builtins)
    pruned = linked.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(pruned, expected)
def testRemoveInheritedMethods(self):
    """Methods identical to an inherited definition are removed."""
    original = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B
            def f(self, y: int) -> bool

        class C(A):
            def c(self) -> C
            def f(self, y: int) -> bool

        class D(B):
            def g(self) -> float
            def d(self) -> D
    """)
    expected = textwrap.dedent("""
        class A():
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B

        class C(A):
            def c(self) -> C

        class D(B):
            def d(self) -> D
    """)
    linked = visitors.LookupClasses(self.Parse(original), self.builtins)
    pruned = linked.Visit(optimize.RemoveInheritedMethods())
    self.AssertSourceEquals(pruned, expected)
def testLookupTypingClass(self):
    """A NamedType for typing.Sequence is resolved to a class.

    The lookup runs against `self.loader.concat_all()`.
    """
    node = visitors.LookupClasses(pytd.NamedType("typing.Sequence"),
                                  self.loader.concat_all())
    # Use the unittest assertion rather than a bare `assert`: bare asserts are
    # stripped when Python runs with -O, which would make this test vacuous.
    self.assertTrue(node.cls)
def ParseWithLookup(self, src):
    """Parse `src` and resolve its class references against the builtins.

    The builtins are loaded for `self.PYTHON_VERSION`.
    """
    tree = self.Parse(src)
    builtins_pytd = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
    return visitors.LookupClasses(tree, builtins_pytd)
def LinkAgainstSimpleBuiltins(self, ast):
    """Adjust type parameters in `ast`, then resolve it against mini builtins."""
    adjusted = ast.Visit(visitors.AdjustTypeParameters())
    return visitors.LookupClasses(adjusted, self.mini_builtins)