def testAddInheritedMethods(self):
  """B gains A's `foo` and `f`, while keeping its own override of `h`."""
  source = textwrap.dedent("""
      class A():
          foo = ... # type: bool
          def f(self, x: int) -> float
          def h(self) -> complex

      class B(A):
          bar = ... # type: int
          def g(self, y: int) -> bool
          def h(self, z: float) -> ?
  """)
  want = textwrap.dedent("""
      class A():
          foo = ... # type: bool
          def f(self, x: int) -> float
          def h(self) -> complex

      class B(A):
          bar = ... # type: int
          foo = ... # type: bool
          def g(self, y: int) -> bool
          def h(self, z: float) -> ?
          def f(self, x: int) -> float
  """)
  resolved = visitors.LookupClasses(self.Parse(source),
                                    builtins.GetBuiltinsPyTD())
  self.AssertSourceEquals(resolved.Visit(optimize.AddInheritedMethods()),
                          want)
def testRemoveInheritedMethodsWithOverride(self):
  """A method is dropped only when it duplicates the nearest inherited one.

  C.f matches B.f and is removed. D.f differs from B.f (even though it
  matches A.f), so it is kept.
  """
  source = textwrap.dedent("""
      class A(object):
          def f(self, x) -> Any
      class B(A):
          def f(self) -> Any
      class C(B):
          def f(self) -> Any
      class D(B):
          def f(self, x) -> Any
  """)
  want = textwrap.dedent("""
      class A(object):
          def f(self, x) -> Any
      class B(A):
          def f(self) -> Any
      class C(B):
          pass
      class D(B):
          def f(self, x) -> Any
  """)
  resolved = visitors.LookupClasses(self.Parse(source),
                                    builtins.GetBuiltinsPyTD())
  self.AssertSourceEquals(resolved.Visit(optimize.RemoveInheritedMethods()),
                          want)
def ParsePyTD(src=None, filename=None, python_version=None, module=None):
  """Parse pytd source code and resolve names against the builtins.

  The parsed AST is passed through visitors.LookupClasses with the builtins
  pytd, so primitive types end up as ClassType rather than NameType.

  Args:
    src: PyTD source code. When None, it is read (as bytes) from `filename`.
    filename: The filename the source code is from.
    python_version: The Python version to parse the pytd for. Must be truthy
      (checked with an assert).
    module: The name of the module we're parsing. When not None (the empty
      string included), every name in the AST is prefixed with
      "<ast name>.".

  Returns:
    A pytd.TypeDeclUnit.
  """
  assert python_version
  if src is None:
    with open(filename, "rb") as handle:
      src = handle.read()
  tree = parser.parse_string(src, filename=filename, name=module,
                             python_version=python_version)
  # Compare against None explicitly: "" is a legitimate module name.
  if module is not None:
    tree = tree.Visit(visitors.AddNamePrefix(tree.name + "."))
  return visitors.LookupClasses(tree, builtins.GetBuiltinsPyTD())
def ParsePyTD(src=None,
              filename=None,
              python_version=None,
              module=None,
              lookup_classes=False):
  """Parse pytd source code, optionally resolving class references.

  Names are resolved only when `lookup_classes` is True, i.e. primitive
  types become ClassType instead of NamedType. By default the unresolved
  parse result is returned.

  Args:
    src: PyTD source code. If None, the source is read (as bytes) from
      `filename`.
    filename: The filename the source code is from.
    python_version: The Python version to parse the pytd for. Must be truthy
      (checked with an assert).
    module: The name of the module we're parsing. It is passed to the parser
      as the AST name.
    lookup_classes: If True, also look up the class of every ClassType by
      running visitors.LookupClasses against the builtins for
      `python_version`.

  Returns:
    A pytd.TypeDeclUnit.
  """
  assert python_version
  if src is None:
    with open(filename, "rb") as fi:
      src = fi.read()
  ast = parser.parse_string(src, filename=filename, name=module,
                            python_version=python_version)
  if lookup_classes:
    ast = visitors.LookupClasses(ast, GetBuiltinsPyTD(python_version))
  return ast
def testStrict(self):
  """Matching an unknown against list[A] binds it to list and its T to A."""
  ast = parser.parse_string(
      textwrap.dedent("""
        T = TypeVar('T')
        class list(typing.Generic[T], object):
            pass
        class A():
            pass
        class B(A):
            pass
        class `~unknown0`():
            pass
        a = ... # type: A
        def left() -> `~unknown0`
        def right() -> list[A]
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  left, right = ast.Lookup("left"), ast.Lookup("right")
  # assertEquals is a deprecated alias (removed in Python 3.12).
  self.assertEqual(
      m.match(left, right, {}),
      booleq.And((booleq.Eq("~unknown0", "list"),
                  booleq.Eq("~unknown0.list.T", "A"))))
def testFindUnknownVisitor(self):
  """RaiseIfContainsUnknown fires for classes that reference ~unknown types."""
  source = textwrap.dedent("""
      class classobj:
          pass
      class `~unknown1`():
          pass
      class `~unknown_foobar`():
          pass
      class `~int`():
          pass
      class A():
          def foobar(self, x: `~unknown1`) -> ?
      class B():
          def foobar(self, x: `~int`) -> ?
      class C():
          x = ... # type: `~unknown_foobar`
      class D(`~unknown1`):
          pass
  """)
  tree = visitors.LookupClasses(self.Parse(source))

  def check(class_name):
    return tree.Lookup(class_name).Visit(visitors.RaiseIfContainsUnknown())

  has_unknown = visitors.RaiseIfContainsUnknown.HasUnknown
  self.assertRaises(has_unknown, check, "A")
  # B only references `~int`, which must not be treated as an unknown.
  check("B")
  self.assertRaises(has_unknown, check, "C")
  self.assertRaises(has_unknown, check, "D")
def solve(ast, builtins_pytd):
  """Solve the unknowns of a pytd AST against the Python builtins.

  Args:
    ast: A pytd.TypeDeclUnit, containing classes named ~unknownXX.
    builtins_pytd: A pytd for builtins.

  Returns:
    A tuple of (1) a dictionary (str->str) mapping unknown class names to
    known class names and (2) a pytd.TypeDeclUnit of the complete classes
    in ast.
  """
  resolved_builtins = visitors.LookupClasses(
      pytd_utils.RemoveMutableParameters(builtins_pytd), overwrite=True)
  ast = visitors.LookupClasses(ast, resolved_builtins, overwrite=True)
  # The result is discarded; judging by its name, the visitor updates the
  # AST in place (NOTE(review): confirm).
  ast.Visit(visitors.InPlaceFillInExternalTypes(resolved_builtins))
  solution = TypeSolver(ast, resolved_builtins).solve()
  return solution, extract_local(ast)
def testSuperClasses(self):
  """ExtractSuperClasses maps each class to its direct base classes."""
  source = textwrap.dedent("""
      class classobj:
          pass
      class A():
          pass
      class B():
          pass
      class C(A):
          pass
      class D(A,B):
          pass
      class E(C,D,A):
          pass
  """)
  tree = visitors.LookupClasses(self.Parse(source))
  supers = tree.Visit(visitors.ExtractSuperClasses())
  expectations = [
      ("A", ["classobj"]),
      ("B", ["classobj"]),
      ("C", ["A"]),
      ("D", ["A", "B"]),
      ("E", ["C", "D", "A"]),
  ]
  for class_name, expected_names in expectations:
    actual_names = [t.name for t in supers[tree.Lookup(class_name)]]
    self.assertItemsEqual(expected_names, actual_names)
def Optimize(node,
             builtins=None,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True):
  """Optimize a PYTD tree.

  Applies a sequence of shrinking passes to the tree.

  Arguments:
    node: A pytd node to be optimized. It won't be modified; a new node is
      returned.
    builtins: Definitions of all of the external types in node. The
      superclass-based passes (including `lossy` and `use_abcs`) and the
      final class lookup only run when this is given.
    lossy: Allow optimizations that change the meaning of the pytd.
    use_abcs: Use abstract base classes to represent unions like e.g.
      "float or int" as "Real".
    max_union: How many types we allow in a union before we simplify it to
      just "object".
    remove_mutable: Whether to simplify mutable parameters to normal
      parameters.
    can_do_lookup: True if we may resolve NamedType instances in the AST (or
      it is already resolved). False skips every optimization that needs
      resolved NamedTypes.

  Returns:
    An optimized node.
  """
  # Each pass class is instantiated right before its visit, as before.
  for pass_cls in (RemoveDuplicates, SimplifyUnions,
                   CombineReturnsAndExceptions, Factorize,
                   ApplyOptionalArguments, CombineContainers,
                   SimplifyContainers):
    node = node.Visit(pass_cls())
  if builtins:
    superclasses = builtins.Visit(visitors.ExtractSuperClassesByName())
    superclasses.update(node.Visit(visitors.ExtractSuperClassesByName()))
    if use_abcs:
      superclasses.update(abc_hierarchy.GetSuperClasses())
    hierarchy = SuperClassHierarchy(superclasses)
    node = node.Visit(SimplifyUnionsWithSuperclasses(hierarchy))
    if lossy:
      node = node.Visit(FindCommonSuperClasses(hierarchy))
  if max_union:
    node = node.Visit(CollapseLongUnions(max_union))
  node = node.Visit(AdjustReturnAndConstantGenericType())
  if remove_mutable:
    node = node.Visit(AbsorbMutableParameters())
    node = node.Visit(CombineContainers())
    node = node.Visit(MergeTypeParameters())
    node = node.Visit(visitors.AdjustSelf(force=True))
  node = node.Visit(SimplifyContainers())
  if builtins and can_do_lookup:
    # `hierarchy` is always bound here because `builtins` is truthy.
    node = visitors.LookupClasses(node, builtins)
    node = node.Visit(RemoveInheritedMethods())
    node = node.Visit(RemoveRedundantSignatures(hierarchy))
  return node
def solve(ast, builtins_pytd, protocols_pytd):
  """Solve the unknowns of a pytd AST against the builtins and protocols.

  Args:
    ast: A pytd.TypeDeclUnit, containing classes named ~unknownXX.
    builtins_pytd: A pytd for builtins.
    protocols_pytd: A pytd for protocols.

  Returns:
    A tuple of (1) a dictionary (str->str) mapping unknown class names to
    known class names and (2) a pytd.TypeDeclUnit of the complete classes
    in ast.
  """
  resolved_builtins = visitors.LookupClasses(
      transforms.RemoveMutableParameters(builtins_pytd))
  resolved_protocols = visitors.LookupClasses(protocols_pytd)
  ast = visitors.LookupClasses(ast, resolved_builtins)
  solver = TypeSolver(ast, resolved_builtins, resolved_protocols)
  return solver.solve(), extract_local(ast)
def testRemoveInheritedMethodsWithCircle(self):
  """A cyclic class hierarchy is left untouched by RemoveInheritedMethods."""
  source = textwrap.dedent("""
      class A(B):
          def f(self) -> Any
      class B(A):
          def f(self) -> Any
  """)
  resolved = visitors.LookupClasses(self.Parse(source),
                                    builtins.GetBuiltinsPyTD())
  self.AssertSourceEquals(resolved.Visit(optimize.RemoveInheritedMethods()),
                          source)
def testDontRemoveAbstractMethodImplementation(self):
  """A concrete implementation of an abstract method must be kept."""
  source = textwrap.dedent("""
      class A(object):
          @abstractmethod
          def foo(self): ...

      class B(A):
          def foo(self): ...
  """)
  resolved = visitors.LookupClasses(self.Parse(source), self.builtins)
  self.AssertSourceEquals(resolved.Visit(optimize.RemoveInheritedMethods()),
                          source)
def testClassMatch(self):
  """Left matches Right, but Right (with an extra method) doesn't match Left."""
  source = textwrap.dedent("""
      class Left():
          def method(self) -> ?
      class Right():
          def method(self) -> ?
          def method2(self) -> ?
  """)
  tree = parser.parse_string(source, python_version=self.PYTHON_VERSION)
  tree = visitors.LookupClasses(tree, self.mini_builtins)
  matcher = type_match.TypeMatch()
  left_cls = tree.Lookup("Left")
  right_cls = tree.Lookup("Right")
  self.assertEqual(matcher.match(left_cls, right_cls, {}), booleq.TRUE)
  self.assertNotEqual(matcher.match(right_cls, left_cls, {}), booleq.TRUE)
def testSimplifyUnionsWithSuperclassesGeneric(self):
  """frozenset[int] is absorbed by AbstractSet[int] in a union."""
  source = textwrap.dedent("""
      x = ... # type: frozenset[int] or AbstractSet[int]
  """)
  want = textwrap.dedent("""
      x = ... # type: AbstractSet[int]
  """)
  superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
  simplifier = optimize.SimplifyUnionsWithSuperclasses(
      optimize.SuperClassHierarchy(superclasses))
  resolved = visitors.LookupClasses(self.Parse(source), self.builtins)
  self.AssertSourceEquals(resolved.Visit(simplifier), want)
def testLookupClasses(self):
  """LookupClasses keeps the printed source and yields a verifiable tree."""
  source = textwrap.dedent("""
      class object(object):
          pass

      class A(object):
          def a(self, a: A, b: B) -> A or B raises A, B

      class B(object):
          def b(self, a: A, b: B) -> A or B raises A, B
  """)
  resolved = visitors.LookupClasses(self.Parse(source))
  self.AssertSourceEquals(resolved, source)
  resolved.Visit(visitors.VerifyLookup())
def testGeneric(self):
  """Two A[?] types match each other unconditionally."""
  ast = parser.parse_string(
      textwrap.dedent("""
        T = TypeVar('T')
        class A(typing.Generic[T], object):
            pass
        left = ... # type: A[?]
        right = ... # type: A[?]
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch()
  # assertEquals is a deprecated alias (removed in Python 3.12).
  self.assertEqual(
      m.match_type_against_type(
          ast.Lookup("left").type, ast.Lookup("right").type, {}),
      booleq.TRUE)
def testRemoveInheritedMethodsWithDiamond(self):
  """D.f is kept: along the B/C diamond it differs from C.f."""
  source = textwrap.dedent("""
      class A(object):
          def f(self, x) -> ?
      class B(A):
          pass
      class C(A):
          def f(self, x, y) -> ?
      class D(B, C):
          def f(self, x) -> ?
  """)
  resolved = visitors.LookupClasses(self.Parse(source), self.builtins)
  self.AssertSourceEquals(resolved.Visit(optimize.RemoveInheritedMethods()),
                          source)
def testSubclasses(self):
  """left (B -> B) matches right (A -> A), but not the other way around."""
  source = textwrap.dedent("""
      class A():
          pass
      class B(A):
          pass
      a = ... # type: A
      def left(a: B) -> B
      def right(a: A) -> A
  """)
  tree = parser.parse_string(source, python_version=self.PYTHON_VERSION)
  tree = visitors.LookupClasses(tree, self.mini_builtins)
  matcher = type_match.TypeMatch(type_match.get_all_subclasses([tree]))
  left_fn = tree.Lookup("left")
  right_fn = tree.Lookup("right")
  self.assertEqual(matcher.match(left_fn, right_fn, {}), booleq.TRUE)
  self.assertNotEqual(matcher.match(right_fn, left_fn, {}), booleq.TRUE)
def testBaseClass(self):
  """Match matches Foo through the method Foo inherits from Base."""
  ast = parser.parse_string(
      textwrap.dedent("""
        class Base():
            def f(self, x:Base) -> Base
        class Foo(Base):
            pass
        class Match():
            def f(self, x:Base) -> Base
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  eq = m.match_Class_against_Class(ast.Lookup("Match"), ast.Lookup("Foo"), {})
  # assertEquals is a deprecated alias (removed in Python 3.12).
  self.assertEqual(eq, booleq.TRUE)
def testExternal(self):
  """An ExternalType wrapping Foo matches a constant of base type Base."""
  ast = parser.parse_string(
      textwrap.dedent("""
        class Base():
            pass
        class Foo(Base):
            pass
        base = ... # type: Base
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch(type_match.get_all_subclasses([ast]))
  mod1_foo = pytd.ExternalType("Foo", module="mod1", cls=ast.Lookup("Foo"))
  eq = m.match_type_against_type(mod1_foo, ast.Lookup("base").type, {})
  # assertEquals is a deprecated alias (removed in Python 3.12).
  self.assertEqual(eq, booleq.TRUE)
def GetBuiltinsPyTD():
  """Return the "default" AST used to look up built-in types.

  The AST covers the Python builtins and the most commonly used standard
  libraries. It is parsed and class-resolved on first use, then cached in
  the module-level `_cached_builtins_pytd`.

  Returns:
    A pytd.TypeDeclUnit instance. It directly contains the builtin classes
    and functions, plus submodules for each of the standard library modules.
  """
  global _cached_builtins_pytd
  if _cached_builtins_pytd:
    return _cached_builtins_pytd
  parsed = parser.TypeDeclParser().Parse(
      _FindBuiltinFile(_BUILTIN_NAME), name=_BUILTIN_NAME)
  _cached_builtins_pytd = visitors.LookupClasses(parsed)
  return _cached_builtins_pytd
def testAddInheritedMethods(self):
  """AddInheritedMethods adds A.f to B's methods (which had only g and h)."""
  source = textwrap.dedent("""
      class A():
          foo = ... # type: bool
          def f(self, x: int) -> float
          def h(self) -> complex

      class B(A):
          bar = ... # type: int
          def g(self, y: int) -> bool
          def h(self, z: float) -> ?
  """)
  tree = visitors.LookupClasses(self.Parse(source),
                                builtins.GetBuiltinsPyTD())

  def method_names(t):
    return [method.name for method in t.Lookup("B").methods]

  self.assertItemsEqual(("g", "h"), method_names(tree))
  tree = tree.Visit(optimize.AddInheritedMethods())
  self.assertItemsEqual(("f", "g", "h"), method_names(tree))
def testSimplifyUnionsWithSuperclasses(self):
  """bool is dropped from unions that already contain int.

  Unions without such a relationship (list[int] or int) are left alone.
  """
  source = textwrap.dedent("""
      x = ... # type: int or bool
      y = ... # type: int or bool or float
      z = ... # type: list[int] or int
  """)
  want = textwrap.dedent("""
      x = ... # type: int
      y = ... # type: int or float
      z = ... # type: list[int] or int
  """)
  superclasses = self.builtins.Visit(visitors.ExtractSuperClassesByName())
  simplifier = optimize.SimplifyUnionsWithSuperclasses(
      optimize.SuperClassHierarchy(superclasses))
  resolved = visitors.LookupClasses(self.Parse(source), self.builtins)
  self.AssertSourceEquals(resolved.Visit(simplifier), want)
def Optimize(node,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False):
  """Optimize a PYTD tree.

  Applies a sequence of shrinking passes to the tree.

  Arguments:
    node: A pytd node to be optimized. It won't be modified; a new node is
      returned.
    lossy: Allow optimizations that change the meaning of the pytd.
    use_abcs: Use abstract base classes to represent unions like e.g.
      "float or int" as "Real". Only used when `lossy` is set.
    max_union: How many types we allow in a union before we simplify it to
      just "object".
    remove_mutable: Whether to simplify mutable parameters to normal
      parameters.

  Returns:
    An optimized node.
  """
  # Each pass class is instantiated right before its visit, as before.
  for pass_cls in (RemoveDuplicates, SimplifyUnions,
                   CombineReturnsAndExceptions, Factorize,
                   ApplyOptionalArguments, CombineContainers):
    node = node.Visit(pass_cls())
  if lossy:
    hierarchy = node.Visit(visitors.ExtractSuperClassesByName())
    node = node.Visit(FindCommonSuperClasses(hierarchy, use_abcs))
  if max_union:
    for collapse_cls in (CollapseLongParameterUnions,
                         CollapseLongReturnUnions,
                         CollapseLongConstantUnions):
      node = node.Visit(collapse_cls(max_union))
  if remove_mutable:
    node = node.Visit(AbsorbMutableParameters())
    node = node.Visit(CombineContainers())
    node = node.Visit(MergeTypeParameters())
    node = node.Visit(visitors.AdjustSelf(force=True))
  node = visitors.LookupClasses(node, builtins.GetBuiltinsPyTD())
  return node.Visit(RemoveInheritedMethods())
def testRemoveInheritedMethodsWithoutSelf(self):
  """The duplicate baz is removed; the self-less bar is kept."""
  source = textwrap.dedent("""
      class Bar(object):
          def baz(self) -> int

      class Foo(Bar):
          def baz(self) -> int
          def bar() -> float
  """)
  want = textwrap.dedent("""
      class Bar(object):
          def baz(self) -> int

      class Foo(Bar):
          def bar() -> float
  """)
  resolved = visitors.LookupClasses(self.Parse(source),
                                    builtins.GetBuiltinsPyTD())
  self.AssertSourceEquals(resolved.Visit(optimize.RemoveInheritedMethods()),
                          want)
def _TestTypeParameters(self, reverse=False):
  """An unknown with `next` matches A[B], binding it to A and A.T to B.

  Args:
    reverse: If True, match right against left instead of left against right.
  """
  ast = parser.parse_string(
      textwrap.dedent("""
        class `~unknown0`():
            def next(self) -> ?
        T = TypeVar('T')
        class A(typing.Generic[T], object):
            def next(self) -> ?
        class B():
            pass
        def left(x: `~unknown0`) -> ?
        def right(x: A[B]) -> ?
      """))
  ast = visitors.LookupClasses(ast, self.mini_builtins)
  m = type_match.TypeMatch()
  left, right = ast.Lookup("left"), ast.Lookup("right")
  match = m.match(right, left, {}) if reverse else m.match(left, right, {})
  # assertEquals is a deprecated alias (removed in Python 3.12).
  self.assertEqual(
      match,
      booleq.And((booleq.Eq("~unknown0", "A"),
                  booleq.Eq("~unknown0.A.T", "B"))))
  self.assertIn("~unknown0.A.T", m.solver.variables)
def testRemoveInheritedMethods(self):
  """Methods identical to an inherited definition are stripped from subclasses."""
  source = textwrap.dedent("""
      class A():
          def f(self, y: int) -> bool
          def g(self) -> float

      class B(A):
          def b(self) -> B
          def f(self, y: int) -> bool

      class C(A):
          def c(self) -> C
          def f(self, y: int) -> bool

      class D(B):
          def g(self) -> float
          def d(self) -> D
  """)
  want = textwrap.dedent("""
      class A():
          def f(self, y: int) -> bool
          def g(self) -> float

      class B(A):
          def b(self) -> B

      class C(A):
          def c(self) -> C

      class D(B):
          def d(self) -> D
  """)
  resolved = visitors.LookupClasses(
      self.Parse(source), builtins.GetBuiltinsPyTD(self.PYTHON_VERSION))
  self.AssertSourceEquals(resolved.Visit(optimize.RemoveInheritedMethods()),
                          want)
def ParseWithLookup(self, src):
  """Parse `src` and resolve its classes against the versioned builtins."""
  parsed = self.Parse(src)
  builtins_ast = builtins.GetBuiltinsPyTD(self.PYTHON_VERSION)
  return visitors.LookupClasses(parsed, builtins_ast)
def LinkAgainstSimpleBuiltins(self, ast):
  """Adjust type parameters in `ast`, then resolve it against mini_builtins."""
  adjusted = ast.Visit(visitors.AdjustTypeParameters())
  return visitors.LookupClasses(adjusted, self.mini_builtins)
def ParseWithLookup(self, src):
  """Parse `src` and resolve its classes against class-resolved builtins."""
  parsed = self.Parse(src)
  resolved_builtins = visitors.LookupClasses(builtins.GetBuiltinsPyTD())
  return visitors.LookupClasses(parsed, resolved_builtins)