def testNoReturn(self):
  """A call to an always-raising function guards the later `x.upper()`.

  `foo.fail` is inferred from a body that only raises ValueError. The checked
  snippet calls it on the `x is None` branch before calling `x.upper()`.
  """
  fail_src = """
    def fail():
      raise ValueError()
  """
  fail_ast = self.Infer(fail_src)
  with file_utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(fail_ast))
    self.Check("""
      import foo
      def g():
        x = "hello" if __random__ else None
        if x is None:
          foo.fail()
        return x.upper()
    """, pythonpath=[d.path])
def test_subclass(self):
  """A namedtuple subclass overriding __new__ gets a bound TypeVar for cls."""
  ty = self.Infer("""
    import collections
    class X(collections.namedtuple("X", [])):
      def __new__(cls, _):
        return super(X, cls).__new__(cls)
  """, deep=True)
  generated_name = collections_overlay.namedtuple_name("X", [])
  generated_ast = collections_overlay.namedtuple_ast(generated_name, [])
  # The printed namedtuple base class comes first, followed by the subclass.
  subclass_pyi = textwrap.dedent("""\
    collections = ...  # type: module
    _TX = TypeVar("_TX", bound=X)
    class X({name}):
      def __new__(cls: Type[_TX], _) -> _TX: ...""").format(name=generated_name)
  self.assertTypesMatchPytd(ty, pytd.Print(generated_ast) + subclass_pyi)
def testPrintImports(self):
  """Printing a parsed AST regenerates exactly the imports it references."""
  src = textwrap.dedent("""
    from typing import List, Tuple, Union
    def f(x: Union[int, slice]) -> List[?]: ...
    def g(x: foo.C.C2) -> None: ...
  """)
  expected = textwrap.dedent("""\
    import foo.C
    from typing import Any, List, Union

    def f(x: Union[int, slice]) -> List[Any]: ...
    def g(x: foo.C.C2) -> None: ...""")
  printed = pytd.Print(self.Parse(src))
  self.AssertSourceEquals(printed, src)
  self.assertMultiLineEqual(printed, expected)
def testDefaultArgumentType(self):
  """Passing `foo.f` to a generic `foo.g` lets `.upper()` be called on the result."""
  foo_ast = self.Infer("""
    from __future__ import google_type_annotations
    from typing import Any, Callable, TypeVar
    T = TypeVar("T")
    def f(x):
      return True
    def g(x: Callable[[T], Any]) -> T: ...
  """)
  with utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(foo_ast))
    caller_src = """
      import foo
      foo.g(foo.f).upper()
    """
    self.Check(caller_src, pythonpath=[d.path])
def testLookup(self):
  """`type(self)()` in a method yields a self-typed TypeVar signature."""
  ty = self.Infer("""
    class Cloneable(object):
      def __init__(self):
        pass
      def clone(self):
        return type(self)()
    Cloneable().clone()
  """, deep=False, show_library_calls=True)
  clone_method = ty.Lookup("Cloneable").Lookup("clone")
  expected = "def clone(self: _TCloneable) -> _TCloneable: ..."
  self.assertEqual(pytd.Print(clone_method), expected)
def testResolveAlias(self):
  """A return type naming another module's alias resolves to the aliased type."""
  with utils.Tempdir() as d:
    d.create_file("module1.pyi", """
        from typing import List
        x = List[int]
    """)
    d.create_file("module2.pyi", """
        def f() -> module1.x
    """)
    loader = load_pytd.Loader(
        "base", self.PYTHON_VERSION, pythonpath=[d.path])
    module2 = loader.import_name("module2")
    # Exactly one signature is expected; unpacking enforces that.
    [signature] = module2.Lookup("module2.f").signatures
    self.assertEqual("List[int]", pytd.Print(signature.return_type))
def testRedefineTypeVar(self):
  """Type parameters are still created when a class named TypeVar is defined."""
  src = textwrap.dedent("""
    def f(x: `~unknown1`) -> `~unknown1`: ...
    class `TypeVar`(object): ...
  """)
  expected = textwrap.dedent("""\
    import typing

    _T0 = TypeVar('_T0')

    class `TypeVar`(object):
        pass

    def f(x: _T0) -> _T0: ...""")
  ast = self.Parse(src).Visit(visitors.CreateTypeParametersForSignatures())
  self.assertMultiLineEqual(pytd.Print(ast), expected)
def testPrintNoneUnion(self):
  """Unions containing None print as Optional; Union[None] prints as None."""
  src = textwrap.dedent("""
    from typing import Union
    def f(x: Union[str, None]) -> None: ...
    def g(x: Union[str, int, None]) -> None: ...
    def h(x: Union[None]) -> None: ...
  """)
  expected = textwrap.dedent("""
    from typing import Optional, Union

    def f(x: Optional[str]) -> None: ...
    def g(x: Optional[Union[str, int]]) -> None: ...
    def h(x: None) -> None: ...
  """)
  printed = pytd.Print(self.ToAST(src))
  self.assertMultiLineEqual(expected.strip(), printed.strip())
def testLookupTwoStarAliasesWithDefaultPyi(self):
  """Star-importing two modules that both define __getattr__ keeps one copy."""

  def prefixed(source, module_name):
    # Parse `source` as module `module_name` with its names prefixed.
    return self.Parse(source).Replace(name=module_name).Visit(
        visitors.AddNamePrefix())

  getattr_src = "def __getattr__(name) -> ?"
  star_src = textwrap.dedent("""
    from foo import *
    from bar import *
  """)
  foo_ast = prefixed(getattr_src, "foo")
  bar_ast = prefixed(getattr_src, "bar")
  baz_ast = prefixed(star_src, "baz")
  modules = {"foo": foo_ast, "bar": bar_ast, "baz": baz_ast}
  baz_ast = baz_ast.Visit(
      visitors.LookupExternalTypes(modules, self_name="baz"))
  self.assertMultiLineEqual(pytd.Print(baz_ast), textwrap.dedent("""\
    from typing import Any

    def baz.__getattr__(name) -> Any: ..."""))
def testNamedTupleSubclass(self):
  """Calling a namedtuple subclass without __new__'s extra arg is an error."""
  x_ast = self.Infer("""
    import collections
    class X(collections.namedtuple("X", ["a"])):
      def __new__(cls, a, b):
        print b
        return super(X, cls).__new__(cls, a)
  """)
  with utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(x_ast))
    # The leading backslash keeps `import foo` on line 1 for the error check.
    _, errors = self.InferWithErrors("""\
      import foo
      foo.X("hello", "world")
      foo.X(42)  # missing parameters
    """, pythonpath=[d.path])
    self.assertErrorLogIs(errors, [(3, "missing-parameter", "b.*__new__")])
def testContainer(self):
  """A method inherited by a class in a generated pyi is callable."""
  container_ast = self.Infer("""
    class Container:
      def Add(self):
        pass
    class A(Container):
      pass
  """)
  with utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(container_ast))
    user_src = """
      # u.py
      from foo import A
      A().Add()
    """
    self.assertNoErrors(user_src, pythonpath=[d.path])
def testAdjustInheritedMethodSelf(self):
  """Copying an inherited method into B drops the `self: object` annotation."""
  src = textwrap.dedent("""
    class A():
        def f(self: object) -> float
    class B(A):
        pass
  """)
  resolved = visitors.LookupClasses(self.Parse(src), self.builtins)
  with_inherited = resolved.Visit(optimize.AddInheritedMethods())
  expected = textwrap.dedent("""\
    class B(A):
        def f(self) -> float: ...
    """)
  self.assertMultiLineEqual(pytd.Print(with_inherited.Lookup("B")), expected)
def testStarImport(self):
  """A pickled module that star-imports another re-exports its names."""
  with utils.Tempdir() as d:
    d.create_file("foo.pyi", "class A(object): ...")
    d.create_file("bar.pyi", "from foo import *")
    foo_module = _Module(module_name="foo", file_name="foo.pyi")
    bar_module = _Module(module_name="bar", file_name="bar.pyi")
    loader, _unused_ast = self._LoadAst(d, module=bar_module)
    self._PickleModules(loader, d, foo_module, bar_module)
    loaded_ast = self._LoadPickledModule(d, bar_module)
    loaded_ast.Visit(visitors.VerifyLookup())
    expected = textwrap.dedent("""\
      import foo

      bar.A = foo.A""")
    self.assertMultiLineEqual(pytd.Print(loaded_ast), expected)
def testTypeParameterBound(self):
  """Passing a str where a TypeVar is bound to float reports wrong-arg-types."""
  bound_ast = self.Infer("""
    from typing import TypeVar
    T = TypeVar("T", bound=float)
    def f(x: T) -> T:
      return x
  """, deep=False)
  with utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(bound_ast))
    # The leading backslash makes `foo.f("")` line 2 of the snippet.
    _, errors = self.InferWithErrors("""\
      import foo
      foo.f("")
    """, pythonpath=[d.path])
    self.assertErrorLogIs(errors, [(2, "wrong-arg-types", r"float.*str")])
def test_convert_with_type_params(self):
  """An unknown used with __setitem__/update/append converts to a MutableSequence."""
  parsed = self.parse("""
    from typing import Dict
    class A(object):
        def foo(self, x: `~unknown1`) -> bool
    class `~unknown1`():
        def __setitem__(self, _1: str, _2: `~unknown2`) -> ?
        def update(self, _1: NoneType or Dict[nothing, nothing]) -> ?
    class `~unknown2`():
        def append(self, _1:NoneType) -> NoneType
  """)
  converted = convert_structural.convert_pytd(parsed, self.builtins_pytd)
  foo_signature = converted.Lookup("A").Lookup("foo").signatures[0]
  # params[0] is self; params[1] is the converted `x`.
  x_type = foo_signature.params[1].type
  self.assertIn("MutableSequence", pytd.Print(x_type))
def testNewChain(self):
  """A subclass __new__ forwarding extra args to a pyi base's __new__ checks."""
  base_ast = self.Infer("""
    class X(object):
      def __new__(cls, x):
        return super(X, cls).__new__(cls)
  """)
  with utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(base_ast))
    subclass_src = """
      import foo
      class Y(foo.X):
        def __new__(cls, x):
          return super(Y, cls).__new__(cls, x)
        def __init__(self, x):
          self.x = x
      Y("x").x
    """
    self.Check(subclass_src, pythonpath=[d.path])
def match_Function_against_Class(self, f1, cls2, subst, cache):
  """Match a function against the same-named method of a class or its bases.

  Args:
    f1: The function whose name is looked up and matched.
    cls2: The class searched first; its `parents` are searched if it has no
      method named `f1.name`.
    subst: Type-parameter substitution dict. It is never mutated; bindings for
      a base class's template are made on a private copy per base.
    cache: Dict mapping id(cls) to a {method name: method} table. Tables are
      filled lazily on first use.

  Returns:
    A booleq formula. This is the result of matching against the method that
    was found, or booleq.TRUE when a base of an unsupported kind is
    encountered. It is booleq.FALSE when no base can supply the method, or
    when an AnythingType base is encountered.
  """
  cls2_methods = cache.get(id(cls2))
  if cls2_methods is None:
    cls2_methods = cache[id(cls2)] = {f.name: f for f in cls2.methods}
  if f1.name in cls2_methods:
    return self.match_Function_against_Function(
        f1, cls2_methods[f1.name], subst, skip_self=True)
  # The class itself doesn't have this method, but base classes might.
  # TODO(kramm): This should do MRO order, not depth-first.
  for base in cls2.parents:
    if isinstance(base, pytd.AnythingType):
      # AnythingType can contain any method. However, that would mean that
      # a class that inherits from AnythingType contains any method
      # imaginable, and hence is a match for anything. To prevent the bad
      # results caused by that, return FALSE here.
      return booleq.FALSE
    elif isinstance(base, (pytd.ClassType, pytd.GenericType)):
      if isinstance(base, pytd.ClassType):
        cls = base.cls
        values = tuple(pytd.AnythingType() for _ in cls.template)
      elif isinstance(base, pytd.TupleType):
        cls = base.base_type.cls
        values = (pytd.UnionType(type_list=base.parameters),)
      else:
        cls = base.base_type.cls
        values = base.parameters
      base_subst = subst
      if values:
        # Bind this base's template parameters on a fresh copy of the
        # caller's substitution. Reassigning `subst` itself here (as before)
        # leaked these bindings into the matching of later sibling bases.
        base_subst = subst.copy()
        for param, value in zip(cls.template, values):
          base_subst[param.type_param] = value
      implication = self.match_Function_against_Class(
          f1, cls, base_subst, cache)
      if implication is not booleq.FALSE:
        return implication
    else:
      # Funky types like UnionType are hard to match against (and shouldn't
      # appear as a base class) so we treat them as catch-all.
      log.warning("Assuming that %s has method %s",
                  pytd.Print(base), f1.name)
      return booleq.TRUE
  return booleq.FALSE
def testPrintImportsNamedType(self):
  """A NamedType constant referencing typing.List prints a typing import."""
  # The parser cannot produce this tree, so it is assembled by hand.
  list_constant = pytd.Constant("x", pytd.NamedType("typing.List"))
  tree = pytd.TypeDeclUnit(
      name=None,
      constants=(list_constant,),
      type_params=(),
      functions=(),
      classes=(),
      aliases=())
  expected_src = textwrap.dedent("""
    from typing import List

    x = ...  # type: List
  """).strip()
  self.assertMultiLineEqual(pytd.Print(tree), expected_src)
def testAliasPrinting(self):
  """A module-level alias to List[Any] prints with its typing imports."""
  list_of_any = pytd.GenericType(
      pytd.NamedType("typing.List"), (pytd.AnythingType(),))
  my_list = pytd.Alias("MyList", list_of_any)
  unit = pytd.TypeDeclUnit(
      name="test",
      is_package=False,
      constants=(),
      type_params=(),
      classes=(),
      functions=(),
      aliases=(my_list,))
  expected = textwrap.dedent("""
    from typing import Any, List

    MyList = List[Any]""")
  self.assertMultiLineEqual(expected.strip(), pytd.Print(unit).strip())
def testAddNamePrefixOnNestedClassAlias(self):
  """Name prefixing reaches nested classes and an alias to a nested class."""
  src = textwrap.dedent("""
    class A:
      class B:
        class C: ...
        D = A.B.C
  """)
  expected = textwrap.dedent("""
    from typing import Type

    class foo.A:
        class foo.A.B:
            class foo.A.B.C: ...
            D: Type[foo.A.B.C]
  """).strip()
  prefixed = self.Parse(src).Replace(name="foo").Visit(
      visitors.AddNamePrefix())
  self.assertMultiLineEqual(expected, pytd.Print(prefixed))
def testPrintImports(self):
  """Printing an import-free parsed AST prepends the imports it needs."""
  no_import_src = textwrap.dedent("""
    def f(x: Union[int, slice]) -> List[Any]: ...
    def g(x: foo.C.C2) -> None: ...
  """)
  imports = textwrap.dedent("""
    import foo.C
    from typing import Any, List, Union
  """)
  # Concatenation leaves extra newlines at both ends; strip them.
  expected_src = (imports + no_import_src).strip()
  printed = pytd.Print(self.Parse(no_import_src))
  # AssertSourceEquals strips imports, so it is paired with an exact check.
  self.AssertSourceEquals(printed, no_import_src)
  self.assertMultiLineEqual(printed, expected_src)
def test_convert(self):
  """An unknown whose only method matches A's is solved to A."""
  parsed = self.parse("""
    class A(object):
        def foo(self, x: `~unknown1`) -> ?
        def foobaz(self, x: int) -> int
    class `~unknown1`(object):
        def foobaz(self, x: int) -> int
  """)
  expected = textwrap.dedent("""
    from typing import Any

    class A(object):
        def foo(self, x: A) -> Any: ...
        def foobaz(self, x: int) -> int: ...
  """).lstrip()
  converted = convert_structural.convert_pytd(parsed, self.builtins_pytd)
  self.assertMultiLineEqual(pytd.Print(converted), expected)
def testContainer(self):
  """A method inherited by a class in a generated pyi is callable."""
  container_ast = self.Infer("""
    class Container:
      def Add(self):
        pass
    class A(Container):
      pass
  """, deep=False)
  with file_utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(container_ast))
    user_src = """
      # u.py
      from foo import A
      A().Add()
    """
    self.Check(user_src, pythonpath=[d.path])
def testLookupStarAliasInUnnamedModule(self):
  """A star import into an unnamed module yields an alias and keeps the name."""
  foo_src = textwrap.dedent("""
    class A(object): ...
  """)
  star_src = "from foo import *"
  foo_ast = self.Parse(foo_src).Replace(name="foo").Visit(
      visitors.AddNamePrefix())
  unnamed_ast = self.Parse(star_src)
  name_before = unnamed_ast.name
  unnamed_ast = unnamed_ast.Visit(
      visitors.LookupExternalTypes({"foo": foo_ast}, self_name=None))
  self.assertEqual(name_before, unnamed_ast.name)
  expected = textwrap.dedent("""\
    import foo

    A = foo.A""")
  self.assertMultiLineEqual(pytd.Print(unnamed_ast), expected)
def testContextManagerSubclass(self):
  """A subclass of a pyi context manager whose __enter__ returns self works in `with`."""
  manager_ast = self.Infer("""
    class Foo(object):
      def __enter__(self):
        return self
      def __exit__(self, type, value, traceback):
        return None
  """)
  with file_utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(manager_ast))
    subclass_src = """
      import foo
      class Bar(foo.Foo):
        x = None
      with Bar() as bar:
        bar.x
    """
    self.Check(subclass_src, pythonpath=[d.path])
def test_name_conflict(self):
  """Namedtuples sharing a typename but not fields get distinct generated names."""
  ty = self.Infer("""
    import collections
    X = collections.namedtuple("_", [])
    Y = collections.namedtuple("_", [])
    Z = collections.namedtuple("_", "a")
  """)
  name_x = collections_overlay.namedtuple_name("_", [])
  name_z = collections_overlay.namedtuple_name("_", ["a"])
  # X and Y share a definition; Z gets its own.
  combined = pytd_utils.Concat(
      collections_overlay.namedtuple_ast(name_x, []),
      collections_overlay.namedtuple_ast(name_z, ["a"]))
  aliases = textwrap.dedent("""\
    collections = ...  # type: module
    X = {name_x}
    Y = {name_x}
    Z = {name_z}""").format(name_x=name_x, name_z=name_z)
  self.assertTypesMatchPytd(ty, pytd.Print(combined) + aliases)
def match_call_record(self, matcher, solver, call_record, complete):
  """Require that a recorded call agrees with a complete function signature.

  The call record is first matched as a whole against `complete`. If that
  match is impossible, the record is expanded into individual signatures to
  find the first one that cannot match, and FlawedQuery is raised naming it.

  Args:
    matcher: Supplies the match_* methods that build booleq formulas.
    solver: Receives the resulting formula via `always_true`.
    call_record: The recorded function; must satisfy `is_partial`.
    complete: The function to match against; must satisfy `is_complete`.

  Raises:
    FlawedQuery: if `call_record` cannot match `complete`.
  """
  assert is_partial(call_record)
  assert is_complete(complete)
  formula = matcher.match_Function_against_Function(call_record, complete, {})
  if formula is not booleq.FALSE:
    solver.always_true(formula)
    return
  # Narrow the failure to one offending signature for the error message.
  faulty_signature = ""
  expanded = call_record.Visit(optimize.ExpandSignatures())
  for candidate in expanded.signatures:
    candidate_formula = matcher.match_Signature_against_Function(
        candidate, complete, {})
    if candidate_formula is booleq.FALSE:
      faulty_signature = pytd.Print(candidate)
      break
  raise FlawedQuery("Bad call %s%s" % (call_record.name, faulty_signature))
def testInheritedMutation(self):
  """Calling a list method aliased on a list subclass keeps the subclass type."""
  mylist_ast = self.Infer("""
    class MyList(list):
      write = list.append
  """)
  with file_utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd.Print(mylist_ast))
    ty = self.Infer("""
      import foo
      lst = foo.MyList()
      lst.write(42)
    """, pythonpath=[d.path])
    # MyList is not parameterized because it inherits from List[Any].
    self.assertTypesMatchPytd(ty, """
      foo = ...  # type: module
      lst = ...  # type: foo.MyList
    """)
def _namedtuple_def(self, suffix="", **kws):
  """Build the pyi text expected for one simple namedtuple definition.

  Args:
    suffix: Optional extra pyi text. It is placed after the printed namedtuple
      class and before the trailing `collections` and alias lines.
    **kws: Exactly one entry of the form alias=(name, [<fields>]). For
      example, X = namedtuple("_X", "y z") is described by the call
      _namedtuple_def(X=("_X", ["y", "z"])).

  Returns:
    The expected pyi text for the namedtuple instance.
  """
  [(alias, (typename, fields))] = kws.items()
  generated_name = collections_overlay.namedtuple_name(typename, fields)
  trailer = textwrap.dedent("""
    collections = ...  # type: module
    {alias} = {name}""").format(alias=alias, name=generated_name)
  printed = pytd.Print(self._namedtuple_ast(generated_name, fields))
  return printed + suffix + trailer
def testCreateTypeParametersForNew(self):
  """__new__ signatures returning their own class get bound TypeVars."""
  src = textwrap.dedent("""
    class Foo:
        def __new__(cls: Type[Foo]) -> Foo
    class Bar:
        def __new__(cls: Type[Bar], x, y, z) -> Bar
  """)
  expected = textwrap.dedent("""
    from typing import TypeVar

    _TBar = TypeVar('_TBar', bound=Bar)
    _TFoo = TypeVar('_TFoo', bound=Foo)

    class Foo:
        def __new__(cls: Type[_TFoo]) -> _TFoo: ...

    class Bar:
        def __new__(cls: Type[_TBar], x, y, z) -> _TBar: ...
  """).strip()
  ast = self.Parse(src).Visit(visitors.CreateTypeParametersForSignatures())
  self.assertMultiLineEqual(pytd.Print(ast), expected)