def testInheritedMutation(self):
  """An alias of list.append on a list subclass survives a pyi round trip."""
  foo_ast = self.Infer("""
    class MyList(list):
      write = list.append
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ast))
    inferred = self.Infer("""
      import foo
      lst = foo.MyList()
      lst.write(42)
    """, pythonpath=[tmp.path])
    # MyList is not parameterized because it inherits from List[Any].
    self.assertTypesMatchPytd(inferred, """
      foo = ...  # type: module
      lst = ...  # type: foo.MyList
    """)
def test_lookup_star_alias_in_unnamed_module(self):
  """A star import into a module without a name resolves to foo.A aliases."""
  foo_src = textwrap.dedent("""
    class A(object): ...
  """)
  unnamed_src = "from foo import *"
  foo_ast = self.Parse(foo_src).Replace(name="foo")
  foo_ast = foo_ast.Visit(visitors.AddNamePrefix())
  unnamed_ast = self.Parse(unnamed_src)
  name_before_lookup = unnamed_ast.name
  unnamed_ast = unnamed_ast.Visit(
      visitors.LookupExternalTypes({"foo": foo_ast}, self_name=None))
  # The lookup must not rename the module.
  self.assertEqual(name_before_lookup, unnamed_ast.name)
  printed = pytd_utils.Print(unnamed_ast)
  expected = textwrap.dedent("""
    import foo

    A = foo.A
  """).strip()
  self.assertMultiLineEqual(printed, expected)
def testInstantiatePyiClass(self):
  """Instantiating an abstract class that was loaded from a pyi is an error."""
  foo_ast = self.Infer("""
    import abc
    class Foo(metaclass=abc.ABCMeta):
      @abc.abstractmethod
      def foo(self):
        pass
    class Bar(Foo):
      def foo(self):
        pass
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ast))
    # The backslash keeps "import foo" on line 1, so the error is on line 2.
    _, errorlog = self.InferWithErrors("""\
      import foo
      foo.Foo()
      foo.Bar()
    """, pythonpath=[tmp.path])
    expected_errors = [(2, "not-instantiable", r"foo\.Foo.*foo")]
    self.assertErrorLogIs(errorlog, expected_errors)
def test_name_conflict(self):
  """Same-named namedtuples with equal fields share a generated class."""
  inferred = self.Infer("""
    import collections
    X = collections.namedtuple("_", [])
    Y = collections.namedtuple("_", [])
    Z = collections.namedtuple("_", "a")
  """, deep=False)
  name_x = collections_overlay.namedtuple_name("_", [])
  name_z = collections_overlay.namedtuple_name("_", ["a"])
  combined_ast = pytd_utils.Concat(
      self._namedtuple_ast(name_x, []),
      self._namedtuple_ast(name_z, ["a"]))
  header = pytd_utils.Print(combined_ast)
  aliases = textwrap.dedent("""
    collections = ...  # type: module
    X = {name_x}
    Y = {name_x}
    Z = {name_z}""").format(name_x=name_x, name_z=name_z)
  self.assertTypesMatchPytd(inferred, header + aliases)
def test_instantiate_pyi_class(self):
  """Only the abstract pyi class (not its concrete subclass) is rejected."""
  foo_ast = self.Infer("""
    import abc
    class Foo(metaclass=abc.ABCMeta):
      @abc.abstractmethod
      def foo(self):
        pass
    class Bar(Foo):
      def foo(self):
        pass
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ast))
    _, errorlog = self.InferWithErrors("""
      import foo
      foo.Foo()  # not-instantiable[e]
      foo.Bar()
    """, pythonpath=[tmp.path])
    self.assertErrorRegexes(errorlog, {"e": r"foo\.Foo.*foo"})
def PrepareForExport(module_name, python_version, ast, loader):
  """Prepare an ast as if it was parsed and loaded.

  External dependencies will not be resolved, as the ast generated by this
  method is supposed to be exported.

  Args:
    module_name: The module_name as a string for the returned ast.
    python_version: The (major, minor) Python version tuple used when
      re-parsing the printed ast (see config.python_version).
    ast: pytd.TypeDeclUnit to prepare. It is printed to source and re-parsed,
      so the returned value is a new ast rather than a modified `ast`.
    loader: A load_pytd.Loader instance; its `builtins` are used to resolve
      builtin names and expand compatible builtins.

  Returns:
    A pytd.TypeDeclUnit representing the supplied AST as it would look after
    being written to a file and parsed.
  """
  # This is a workaround for functionality which crept into places it doesn't
  # belong. Ideally this would call some transformation Visitors on ast to
  # transform it into the same ast we get after parsing and loading (compare
  # load_pytd.Loader.load_file). Unfortunately parsing has some special cases,
  # e.g. '__init__' return type and '__new__' being a 'staticmethod', which
  # need to be moved to visitors before we can do this. Printing an ast also
  # applies transformations,
  # e.g. visitors.PrintVisitor._FormatContainerContents, which need to move to
  # their own visitors so they can be applied without printing.
  src = pytd_utils.Print(ast)
  ast = parser.parse_string(src=src, name=module_name,
                            python_version=python_version)
  ast = ast.Visit(visitors.LookupBuiltins(loader.builtins, full_names=False))
  ast = ast.Visit(
      visitors.ExpandCompatibleBuiltins(loader.builtins, python_version))
  ast = ast.Visit(visitors.LookupLocalTypes())
  ast = ast.Visit(visitors.AdjustTypeParameters())
  ast = ast.Visit(visitors.NamedTypeToClassType())
  # Both the empty name and module_name refer to this module itself.
  ast = ast.Visit(visitors.FillInLocalPointers({"": ast, module_name: ast}))
  ast = ast.Visit(visitors.CanonicalOrderingVisitor())
  # Class types outside this module, __builtin__ and typing become late
  # types, i.e. external dependencies are left unresolved for export.
  ast = ast.Visit(
      visitors.ClassTypeToLateType(
          ignore=[module_name + ".", "__builtin__.", "typing."]))
  return ast
def test_module(self):
  """A module object is matched structurally against protocol annotations."""
  foo_ty = self.Infer("""
    x: int
    def f() -> str:
      return 'hello world'
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ty))
    errorlog = self.CheckWithErrors("""
      import foo
      from typing import Protocol

      class ShouldMatch(Protocol):
        x: int
        def f(self) -> str: ...

      class ExtraAttribute(Protocol):
        x: int
        y: str

      class ExtraMethod(Protocol):
        def f(self) -> str: ...
        def g(self) -> int: ...

      class WrongType(Protocol):
        x: str

      def should_match(x: ShouldMatch):
        pass
      def extra_attribute(x: ExtraAttribute):
        pass
      def extra_method(x: ExtraMethod):
        pass
      def wrong_type(x: WrongType):
        pass

      should_match(foo)
      extra_attribute(foo)  # wrong-arg-types[e1]
      extra_method(foo)  # wrong-arg-types[e2]
      wrong_type(foo)  # wrong-arg-types[e3]
    """, pythonpath=[tmp.path])
    expected_regexes = {
        "e1": r"not implemented on module: y",
        "e2": r"not implemented on module: g",
        "e3": r"x.*expected str, got int",
    }
    self.assertErrorRegexes(errorlog, expected_regexes)
def match_Function_against_Class(self, f1, cls2, subst, cache):
  """Match a function against the same-named method of a class or its bases.

  Args:
    f1: The function to match; it is looked up in cls2 by f1.name.
    cls2: The class whose methods (and, failing that, bases) are searched.
    subst: The current type-parameter substitution. It is not mutated; a base
      with template values gets a copy extended with that base's bindings.
    cache: A dict from id(class) to a {method name: method} mapping, filled in
      lazily so repeated lookups on the same class reuse the mapping.

  Returns:
    A booleq formula: the result of matching f1 against the found method,
    booleq.TRUE for a base that can't be analyzed, or booleq.FALSE.
  """
  cls2_methods = cache.get(id(cls2))
  if cls2_methods is None:
    cls2_methods = cache[id(cls2)] = {f.name: f for f in cls2.methods}
  if f1.name in cls2_methods:
    f2 = cls2_methods[f1.name]
    return self.match_Function_against_Function(
        f1, f2, subst, skip_self=True)
  # The class itself doesn't have this method, but base classes might.
  # TODO(b/159058933): This should do MRO order, not depth-first.
  for base in cls2.bases:
    if isinstance(base, pytd.AnythingType):
      # AnythingType can contain any method. However, that would mean that
      # a class that inherits from AnythingType contains any method
      # imaginable, and hence is a match for anything. To prevent the bad
      # results caused by that, return FALSE here.
      return booleq.FALSE
    elif isinstance(base, (pytd.ClassType, pytd.GenericType)):
      if isinstance(base, pytd.ClassType):
        cls = base.cls
        values = tuple(pytd.AnythingType() for _ in cls.template)
      elif isinstance(base, pytd.TupleType):
        # Reached via the GenericType check above (presumably TupleType
        # subclasses GenericType); the element types are joined into the
        # single template value.
        cls = base.base_type.cls
        values = (pytd_utils.JoinTypes(base.parameters),)
      else:
        cls = base.base_type.cls
        values = base.parameters
      # Bind this base's template parameters in a per-base copy, so the
      # bindings do not leak into the matching of later sibling bases.
      base_subst = subst
      if values:
        base_subst = subst.copy()
        for param, value in zip(cls.template, values):
          base_subst[param.type_param] = value
      implication = self.match_Function_against_Class(
          f1, cls, base_subst, cache)
      if implication is not booleq.FALSE:
        return implication
    else:
      # Funky types like UnionType are hard to match against (and shouldn't
      # appear as a base class) so we treat them as catch-all.
      log.warning("Assuming that %s has method %s",
                  pytd_utils.Print(base), f1.name)
      return booleq.TRUE
  return booleq.FALSE
def _convert_member(self, member, subst=None):
  """Lazily convert a pytd class member into an abstract variable.

  Constants become instances of their declared type, functions become values
  parented to this class, and nested classes are converted directly. Any
  other member kind is an AssertionError.
  """
  subst = subst or datatypes.AliasingDict()
  node = self.ctx.root_node
  if isinstance(member, pytd.Constant):
    instance_type = abstract_utils.AsInstance(member.type)
    return self.ctx.convert.constant_to_var(instance_type, subst, node)
  if isinstance(member, pytd.Function):
    fn_value = self.ctx.convert.constant_to_value(
        member, subst=subst, node=node)
    fn_value.parent = self
    return fn_value.to_variable(node)
  if isinstance(member, pytd.Class):
    return self.ctx.convert.constant_to_var(member, subst=subst, node=node)
  raise AssertionError("Invalid class member %s" % pytd_utils.Print(member))
def test_instantiate_imported_generic(self):
  """Subscripting an imported generic class and calling it yields Foo[int]."""
  foo_ast = self.Infer("""
    from typing import Generic, TypeVar
    T = TypeVar('T')
    class Foo(Generic[T]):
      def __init__(self):
        pass
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ast))
    inferred = self.Infer("""
      import foo
      x = foo.Foo[int]()
    """, pythonpath=[tmp.path])
    self.assertTypesMatchPytd(inferred, """
      foo: module
      x: foo.Foo[int]
    """)
def test_reingest_generic(self):
  """Type parameters of a reingested generic flow into attribute types."""
  foo_ast = self.Infer("""
    from typing import Generic, TypeVar
    T = TypeVar('T')
    class Foo(Generic[T]):
      def __init__(self, x: T):
        self.x = x
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ast))
    inferred = self.Infer("""
      import foo
      x1 = foo.Foo(0).x
      x2 = foo.Foo[str](__any_object__).x
    """, pythonpath=[tmp.path])
    self.assertTypesMatchPytd(inferred, """
      foo: module
      x1: int
      x2: str
    """)
def test_create_type_parameters_for_new(self):
  """__new__ signatures get a bound TypeVar in place of the class type."""
  src = textwrap.dedent("""
    class Foo:
      def __new__(cls: Type[Foo]) -> Foo
    class Bar:
      def __new__(cls: Type[Bar], x, y, z) -> Bar
  """)
  rewritten = self.Parse(src).Visit(
      visitors.CreateTypeParametersForSignatures())
  printed = pytd_utils.Print(rewritten)
  expected = textwrap.dedent("""
    from typing import TypeVar

    _TBar = TypeVar('_TBar', bound=Bar)
    _TFoo = TypeVar('_TFoo', bound=Foo)

    class Foo:
        def __new__(cls: Type[_TFoo]) -> _TFoo: ...

    class Bar:
        def __new__(cls: Type[_TBar], x, y, z) -> _TBar: ...
  """).strip()
  self.assertMultiLineEqual(printed, expected)
def test_maybe_identity_decorators(self):
  """A decorator that may return its argument keeps the function's type."""
  foo_ast = self.Infer("""
    def maybe_decorate(f):
      return f or (lambda *args: 42)
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ast))
    inferred = self.Infer("""
      import foo
      @foo.maybe_decorate
      def f():
        return 3
      def g():
        return f()
    """, pythonpath=[tmp.path])
    self.assertTypesMatchPytd(inferred, """
      foo = ...  # type: module
      def f() -> int
      def g() -> int
    """)
def test_reingest_custom_protocol(self):
  """A user-defined protocol still matches structurally after reingestion."""
  protocol_ast = self.Infer("""
    from typing_extensions import Protocol
    class Appendable(Protocol):
      def append(self) -> None:
        pass
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(protocol_ast))
    self.Check("""
      import foo
      class MyAppendable(object):
        def append(self):
          pass
      def f(x: foo.Appendable):
        pass
      f([])
      f(MyAppendable())
    """, pythonpath=[tmp.path])
def test_attrib_wrapper(self):
  """attr.ib() returned through an imported wrapper still declares fields."""
  wrapper_ty = self.Infer("""
    import attr
    def attrib_wrapper(*args, **kwargs):
      return attr.ib(*args, **kwargs)
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(wrapper_ty))
    self.CheckWithErrors("""
      import attr
      import foo
      @attr.s()
      class Foo:
        x: int = foo.attrib_wrapper()
        y = foo.attrib_wrapper(type=int)
      a = Foo(10, 10)
      b = Foo(10, '10')  # The wrapper returns attr.ib(Any) so y.type is lost
      c = Foo(10, 20, 30)  # wrong-arg-count
      d = Foo('10', 20)  # wrong-arg-types
    """, pythonpath=[tmp.path])
def testPrintMergeTypes(self):
  """Printing simplifies int|float to float and Union[X, None] to Optional."""
  src = textwrap.dedent("""
    from typing import Union
    def a(a: float) -> int: ...
    def b(a: Union[int, float]) -> int: ...
    def c(a: object) -> Union[float, int]: ...
    def d(a: float) -> int: ...
    def e(a: Union[bool, None]) -> Union[bool, None]: ...
  """)
  expected = textwrap.dedent("""
    from typing import Optional, Union

    def a(a: float) -> int: ...
    def b(a: float) -> int: ...
    def c(a: object) -> Union[float, int]: ...
    def d(a: float) -> int: ...
    def e(a: bool) -> Optional[bool]: ...
  """)
  printed = pytd_utils.Print(self.ToAST(src))
  self.assertMultiLineEqual(expected.strip(), printed.strip())
def testLookupStarAliasWithDifferentGetAttr(self):
  """A module's own __getattr__ wins over one pulled in by a star import."""
  foo_src = "def __getattr__(name) -> int"
  bar_src = textwrap.dedent("""
    from foo import *
    def __getattr__(name) -> str
  """)
  foo_ast = self.Parse(foo_src).Replace(name="foo")
  foo_ast = foo_ast.Visit(visitors.AddNamePrefix())
  bar_ast = self.Parse(bar_src).Replace(name="bar")
  bar_ast = bar_ast.Visit(visitors.AddNamePrefix())
  modules = {"foo": foo_ast, "bar": bar_ast}
  bar_ast = bar_ast.Visit(
      visitors.LookupExternalTypes(modules, self_name="bar"))
  self.assertMultiLineEqual(
      pytd_utils.Print(bar_ast),
      textwrap.dedent("""\
        def bar.__getattr__(name) -> str: ..."""))
def _namedtuple_def(self, suffix="", **kws):
  """Build the expected pyi text for a single namedtuple definition.

  Args:
    suffix: Optional extra pyi text, placed after the printed namedtuple class
      and before the generated `collections`/alias lines.
    **kws: Exactly one entry of the form alias=(name, [<fields>]). For
      example, X = namedtuple("_X", "y z") corresponds to
      _namedtuple_def(X=("_X", ["y", "z"])).

  Returns:
    The expected pyi for the namedtuple instance.
  """
  [(alias, (name, fields))] = kws.items()  # pylint: disable=unbalanced-tuple-unpacking
  generated_name = collections_overlay.namedtuple_name(name, fields)
  alias_lines = textwrap.dedent("""
    collections = ...  # type: module
    {alias} = {name}""").format(alias=alias, name=generated_name)
  class_text = pytd_utils.Print(self._namedtuple_ast(generated_name, fields))
  return "\n".join([class_text, suffix + alias_lines])
def test_inherited_mutation_in_generic_class(self):
  """Calling an inherited append alias parameterizes a generic list subclass."""
  foo_ast = self.Infer("""
    from typing import List, TypeVar
    T = TypeVar("T")
    class MyList(List[T]):
      write = list.append
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ast))
    inferred = self.Infer("""
      import foo
      lst = foo.MyList()
      lst.write(42)
    """, pythonpath=[tmp.path])
    self.assertTypesMatchPytd(inferred, """
      foo = ...  # type: module
      lst = ...  # type: foo.MyList[int]
    """)
def testAddNamePrefixOnNestedClassAlias(self):
  """AddNamePrefix qualifies nested classes and aliases pointing at them."""
  src = textwrap.dedent("""
    class A:
      class B:
        class C: ...
        D = A.B.C
  """)
  expected = textwrap.dedent("""
    from typing import Type

    class foo.A:
        class foo.A.B:
            class foo.A.B.C: ...
            D: Type[foo.A.B.C]
  """).strip()
  prefixed = self.Parse(src).Replace(name="foo").Visit(
      visitors.AddNamePrefix())
  self.assertMultiLineEqual(expected, pytd_utils.Print(prefixed))
def main():
  """Command-line entry point: parse a .pyi file and optionally re-print it.

  The input file is parsed for the requested Python version, or for the
  running interpreter's version when none is given. When requested, the
  parsed ast is optimized. If an output is given, the printed result is
  written to that file, or to stdout when the output is "-". Exits with
  status 1 on a version usage error or a parse error.
  """
  opts = make_parser().parse_args()
  python_version = (
      utils.version_from_string(opts.python_version)
      if opts.python_version else sys.version_info[:2])
  try:
    utils.validate_version(python_version)
  except utils.UsageError as e:
    sys.stderr.write("Usage error: %s\n" % utils.message(e))
    sys.exit(1)

  options = parser.PyiOptions(python_version=python_version)
  with open(opts.input) as source_file:
    sourcecode = source_file.read()
  try:
    parsed = parser.parse_string(
        sourcecode, filename=opts.input, options=options)
  except parser.ParseError as e:
    sys.stderr.write(str(e))
    sys.exit(1)

  if opts.optimize:
    builtins_ast = pytd_utils.Concat(
        *builtin_stubs.GetBuiltinsAndTyping(options))
    parsed = optimize.Optimize(
        parsed,
        builtins_ast,
        lossy=opts.lossy,
        use_abcs=opts.use_abcs,
        max_union=opts.max_union,
        remove_mutable=opts.remove_mutable,
        can_do_lookup=False)

  if opts.output is None:
    return
  out_text = pytd_utils.Print(parsed, opts.multiline_args)
  if opts.output == "-":
    sys.stdout.write(out_text)
    return
  with open(opts.output, "w") as out:
    out.write(out_text)
def _parameterized_type(self, base_type, parameters):
  """Return a parameterized type.

  The branches are checked in order: Literal and Annotated are built from the
  raw parameters; any other constant parameter is rejected; an Any base
  collapses to AnythingType; a non-Callable `X[T, ...]` becomes the
  single-parameter `X[T]`; otherwise remaining ellipses are replaced by
  AnythingType and a tuple, callable, or plain generic type is built.

  Raises:
    ParseError: for constant parameters outside Literal/Annotated, and for
      `X[..., ...]`.
  """
  if self._matches_named_type(base_type, _LITERAL_TYPES):
    return types.pytd_literal(parameters)
  elif self._matches_named_type(base_type, _ANNOTATED_TYPES):
    return types.pytd_annotated(parameters)
  elif any(isinstance(p, types.Constant) for p in parameters):
    # Render constants for the message; non-constant parameters show as "_".
    parameters = ", ".join(
        p.repr_str() if isinstance(p, types.Constant) else "_"
        for p in parameters)
    raise ParseError(
        "%s[%s] not supported" % (pytd_utils.Print(base_type), parameters))
  elif pytdgen.is_any(base_type):
    return pytd.AnythingType()
  elif len(parameters) == 2 and parameters[-1] is self.ELLIPSIS and (
      not self._matches_named_type(base_type, _CALLABLE_TYPES)):
    # `X[T, ...]` (e.g. a homogeneous tuple) keeps only the element type.
    element_type = parameters[0]
    if element_type is self.ELLIPSIS:
      raise ParseError("[..., ...] not supported")
    return pytd.GenericType(base_type=base_type, parameters=(element_type,))
  else:
    # Any remaining ellipsis parameter is treated as Any.
    parameters = tuple(pytd.AnythingType() if p is self.ELLIPSIS else p
                       for p in parameters)
    if self._matches_named_type(base_type, _TUPLE_TYPES):
      return pytdgen.heterogeneous_tuple(base_type, parameters)
    elif self._matches_named_type(base_type, _CALLABLE_TYPES):
      callable_parameters = []
      for p in parameters:
        # We do not yet support PEP 612, Parameter Specification Variables.
        # To avoid blocking typeshed from adopting this PEP, we convert new
        # features to Any.
        if p in self.param_specs or (
            isinstance(p, pytd.GenericType) and self._matches_full_name(
                p, ("typing.Concatenate", "typing_extensions.Concatenate"))):
          callable_parameters.append(pytd.AnythingType())
        else:
          callable_parameters.append(p)
      return pytdgen.pytd_callable(base_type, tuple(callable_parameters))
    else:
      # NOTE(review): an assert is stripped under -O; an empty parameter list
      # would then produce a GenericType with no parameters.
      assert parameters
      return pytd.GenericType(base_type=base_type, parameters=parameters)
def testDeterminism(self):
  """Inferring the same code repeatedly must print identical pytd."""
  # Regression test for code on which pytype used to be non-deterministic.
  reference_output = None
  attempts = 10  # increase the chance of finding non-determinism
  for _ in range(attempts):
    inferred = self.Infer("""
      class Foo(object):
        def __init__(self, filenames):
          self._dict = {}
          for filename in filenames:
            d = self._dict
            if __random__:
              d[__any_object__] = {}
              d = d[__any_object__]
            if __random__:
              d[__any_object__] = None
    """)
    printed = pytd_utils.Print(inferred)
    if reference_output is None:
      reference_output = printed
      continue
    self.assertMultiLineEqual(reference_output, printed)
def test_reingest_custom_protocol_error(self):
  """Types missing the protocol method are rejected after reingestion."""
  protocol_ast = self.Infer("""
    from typing_extensions import Protocol
    class Appendable(Protocol):
      def append(self) -> None:
        pass
  """)
  with file_utils.Tempdir() as tmp:
    tmp.create_file("foo.pyi", pytd_utils.Print(protocol_ast))
    errorlog = self.CheckWithErrors("""
      import foo
      class NotAppendable:
        pass
      def f(x: foo.Appendable):
        pass
      f(42)  # wrong-arg-types[e1]
      f(NotAppendable())  # wrong-arg-types[e2]
    """, pythonpath=[tmp.path])
    expected_regexes = {
        "e1": r"Appendable.*int.*append",
        "e2": r"Appendable.*NotAppendable.*append",
    }
    self.assertErrorRegexes(errorlog, expected_regexes)
def test_literal(self):
  """Literal values print with canonical quoting; Literal[None] is None."""
  tree = self.Parse("""
    from typing import Literal
    x1: Literal[""]
    x2: Literal[b""]
    x3: Literal[0]
    x4: Literal[True]
    x5: Literal[None]
  """)
  tree = tree.Visit(visitors.LookupBuiltins(self.loader.builtins))
  expected = textwrap.dedent("""
    from typing import Literal

    x1: Literal['']
    x2: Literal[b'']
    x3: Literal[0]
    x4: Literal[True]
    x5: None
  """).strip()
  self.assertMultiLineEqual(pytd_utils.Print(tree), expected)
def substitute_formal_args(self, node, args, view, alias_map):
  """Substitute matching args into this signature. Used by PyTDFunction.

  Args:
    node: The current node, passed through to the matcher.
    args: The call arguments to map onto this signature's parameters.
    view: The view passed to argument mapping and the matcher.
    alias_map: Passed through to the matcher's compute_subst.

  Returns:
    A (arg_dict, subst) tuple: the mapping from parameter names to argument
    variables, and the substitution computed by the matcher.

  Raises:
    WrongArgTypes: if the matcher could not compute a substitution.
  """
  formal_args, arg_dict = self._map_args(args, view)
  self._fill_in_missing_parameters(node, args, arg_dict)
  subst, bad_arg = self.vm.matcher.compute_subst(
      node, formal_args, arg_dict, view, alias_map)
  if subst is None:
    # Report against this signature if it declares the bad parameter;
    # otherwise use a variant with *args/**kwargs inserted for arg_dict.
    if self.signature.has_param(bad_arg.name):
      signature = self.signature
    else:
      signature = self.signature.insert_varargs_and_kwargs(arg_dict)
    raise WrongArgTypes(signature, args, self.vm, bad_param=bad_arg)
  if log.isEnabledFor(logging.DEBUG):
    log.debug("Matched arguments against sig%s",
              pytd_utils.Print(self.pytd_sig))
    for nr, p in enumerate(self.pytd_sig.params):
      # Emitted at DEBUG like the rest of this guarded section (was INFO).
      log.debug("param %d) %s: %s <=> %s", nr, p.name, p.type,
                arg_dict[p.name])
    for name, var in sorted(subst.items()):
      log.debug("Using %s=%r %r", name, var, var.data)
  return arg_dict, subst
def test_store_and_load_from_namedtuple(self):
  """Class attributes stored on a namedtuple class can be read back."""
  inferred = self.Infer("""
    import collections
    t = collections.namedtuple('t', ['x', 'y', 'z'])
    t.x = 3
    t.y = "foo"
    t.z = 1j
    x = t.x
    y = t.y
    z = t.z
  """)
  packed_name = escape.pack_namedtuple("t", ["x", "y", "z"])
  tuple_ast = named_tuple.namedtuple_ast(
      packed_name, ["x", "y", "z"], [False] * 3, self.options)
  header = pytd_utils.Print(tuple_ast)
  body = textwrap.dedent("""
    import collections
    t = {name}
    x = ...  # type: int
    y = ...  # type: str
    z = ...  # type: complex""").format(name=packed_name)
  self.assertTypesMatchPytd(inferred, header + body)
def test_store_and_load_from_namedtuple(self):
  """Class attributes stored on a namedtuple class can be read back."""
  inferred = self.Infer("""
    import collections
    t = collections.namedtuple('t', ['x', 'y', 'z'])
    t.x = 3
    t.y = "foo"
    t.z = 1j
    x = t.x
    y = t.y
    z = t.z
  """)
  generated_name = collections_overlay.namedtuple_name("t", ["x", "y", "z"])
  tuple_ast = collections_overlay.namedtuple_ast(
      generated_name, ["x", "y", "z"], self.python_version)
  header = pytd_utils.Print(tuple_ast)
  body = textwrap.dedent("""
    collections = ...  # type: module
    t = {name}
    x = ...  # type: int
    y = ...  # type: str
    z = ...  # type: complex""").format(name=generated_name)
  self.assertTypesMatchPytd(inferred, header + body)
def test_reingest_and_subclass(self):
  """Subclasses of a reingested linen.Module subclass become dataclasses."""
  with file_utils.Tempdir() as tmp:
    self._setup_linen_pyi(tmp)
    foo_ty = self.Infer("""
      from flax import linen
      class Foo(linen.Module):
        pass
    """, pythonpath=[tmp.path])
    tmp.create_file("foo.pyi", pytd_utils.Print(foo_ty))
    inferred = self.Infer("""
      import foo
      class Bar(foo.Foo):
        pass
      class Baz(Bar):
        x: int
    """, pythonpath=[tmp.path])
    self.assertTypesMatchPytd(inferred, """
      import dataclasses
      import foo
      from typing import Any, Dict, TypeVar

      _TBar = TypeVar('_TBar', bound=Bar)

      @dataclasses.dataclass
      class Bar(foo.Foo):
        __dataclass_fields__: Dict[str, dataclasses.Field]
        def __init__(self, name: str = ..., parent: Any = ...) -> None: ...
        def replace(self: _TBar, **kwargs) -> _TBar: ...

      _TBaz = TypeVar('_TBaz', bound=Baz)

      @dataclasses.dataclass
      class Baz(Bar):
        x: int
        __dataclass_fields__: Dict[str, dataclasses.Field]
        def __init__(
            self, x: int, name: str = ..., parent: Any = ...) -> None: ...
        def replace(self: _TBaz, **kwargs) -> _TBaz: ...
    """)
def _pytd_print(self, pytd_type):
  """Print the name of the pytd type.

  The type is stripped of unknown classes, optimized and canonically ordered
  before printing. Autogenerated namedtuple names and flattened nested-class
  names are then rewritten into their user-facing forms.
  """
  name = pytd_utils.Print(pytd_utils.CanonicalOrdering(optimize.Optimize(
      pytd_type.Visit(visitors.RemoveUnknownClasses()))))
  # Clean up autogenerated namedtuple names, e.g. "namedtuple-X-a-_0-c"
  # becomes just "X", by extracting out just the type name.
  if "namedtuple" in name:
    return escape.unpack_namedtuple(name)
  # Pytype doesn't have true support for nested classes. Instead, for
  #   class Foo:
  #     class Bar: ...
  # it outputs:
  #   class _Foo_DOT_Bar: ...
  #   class Foo:
  #     Bar = ...  # type: Type[_Foo_DOT_Bar]
  # Replace _Foo_DOT_Bar with Foo.Bar in error messages for readability.
  # Every occurrence is rewritten, so e.g. "Tuple[_A_DOT_B, _C_DOT_D]"
  # becomes "Tuple[A.B, C.D]" rather than keeping the second underscore.
  # TODO(b/35138984): Get rid of this hack.
  return re.sub(r"_(\w+?_DOT_\w*)",
                lambda match: match.group(1).replace("_DOT_", "."), name)