def testRemoveUnknownClasses(self):
    """Check that RemoveUnknownClasses drops the `~unknown` classes.

    After the visitor (and DropBuiltinPrefix) run, the unknown classes are
    gone, the parameter annotations that used them disappear, and the return
    type's unknown member prints as `?`.
    """
    src = textwrap.dedent("""
        class `~unknown1`():
            pass
        class `~unknown2`():
            pass
        class A(object):
            def foobar(x: `~unknown1`, y: `~unknown2`) -> `~unknown1` or int
    """)
    expected = textwrap.dedent("""
        class A(object):
            def foobar(x, y) -> ? or int
    """)
    cleaned = self.Parse(src).Visit(visitors.RemoveUnknownClasses()).Visit(
        visitors.DropBuiltinPrefix())
    self.AssertSourceEquals(cleaned, expected)
def test_remove_unknown_classes(self):
    """RemoveUnknownClasses deletes `~unknown` classes and replaces uses with Any.

    Parameter annotations that referred to the removed classes disappear,
    and the unknown member of the return Union becomes `Any`, which pulls
    `Any` into the `typing` import.
    """
    src = pytd_src("""
        from typing import Union
        class `~unknown1`():
            pass
        class `~unknown2`():
            pass
        class A:
            def foobar(x: `~unknown1`, y: `~unknown2`) -> Union[`~unknown1`, int]: ...
    """)
    expected = textwrap.dedent("""
        from typing import Any, Union
        class A:
            def foobar(x, y) -> Union[Any, int]: ...
    """)
    self.AssertSourceEquals(
        self.Parse(src).Visit(visitors.RemoveUnknownClasses()), expected)
def _pytd_print(self, pytd_type):
    """Print the name of the pytd type.

    Unknown classes are removed from the type, which is then optimized and
    canonically ordered before printing. The printed name is post-processed
    for readability in error messages.

    Args:
      pytd_type: The pytd type to print.

    Returns:
      The printed, cleaned-up type name.
    """
    name = pytd_utils.Print(pytd_utils.CanonicalOrdering(optimize.Optimize(
        pytd_type.Visit(visitors.RemoveUnknownClasses()))))
    # Clean up autogenerated namedtuple names, e.g. "namedtuple-X-a-_0-c"
    # becomes just "X", by extracting out just the type name.
    if "namedtuple" in name:
        return escape.unpack_namedtuple(name)
    # Pytype doesn't have true support for nested classes. Instead, for
    #   class Foo:
    #     class Bar: ...
    # it outputs:
    #   class _Foo_DOT_Bar: ...
    #   class Foo:
    #     Bar = ...  # type: Type[_Foo_DOT_Bar]
    # Replace _Foo_DOT_Bar with Foo.Bar in error messages for readability.
    # Each identifier containing "_DOT_" is rewritten on its own, so every
    # nested-class name in a compound type loses its leading underscore
    # (e.g. "Union[_A_DOT_B, _C_DOT_D]" -> "Union[A.B, C.D]"). The leading
    # "\b" keeps the match from starting in the middle of an identifier.
    # TODO(b/35138984): Get rid of this hack.
    return re.sub(r"\b_?(\w+_DOT_\w*)",
                  lambda m: m.group(1).replace("_DOT_", "."), name)
def _to_pytd(datum, loader, ast):
    """Convert the values in `datum` into a single resolved pytd type.

    Args:
      datum: The values to convert. Each one must provide `to_type()`.
      loader: The loader whose `resolve_type` resolves the joined type.
      ast: The AST handed to `loader.resolve_type`.

    Returns:
      `pytd.AnythingType()` when `datum` is empty. Otherwise, the join of the
      values' types with unknown classes removed, resolved against `ast`.
    """
    if datum:
        joined = pytd_utils.JoinTypes(v.to_type() for v in datum)
        cleaned = joined.Visit(visitors.RemoveUnknownClasses())
        return loader.resolve_type(cleaned, ast)
    return pytd.AnythingType()
def infer_types(src, errorlog, options, loader, filename=None, deep=True,
                init_maximum_depth=INIT_MAXIMUM_DEPTH, show_library_calls=False,
                maximum_depth=None, tracer_vm=None, **kwargs):
    """Given Python source return its types.

    Args:
      src: A string containing Python source code.
      errorlog: Where error messages go. Instance of errors.ErrorLog.
      options: config.Options object
      loader: A load_pytd.Loader instance to load PYI information.
      filename: Filename of the program we're parsing.
      deep: If True, analyze all functions, even the ones not called by the
        main execution flow.
      init_maximum_depth: Depth of analysis during module loading.
      show_library_calls: If True, call traces are kept in the output. Only
        consulted when `options.protocols` is off.
      maximum_depth: Depth of the analysis when `deep` is True. If None, a
        default is chosen from `options`: MAXIMUM_DEPTH when `options.quick`
        is off, otherwise QUICK_CHECK_MAXIMUM_DEPTH when
        `options.analyze_annotated` is set, else QUICK_INFER_MAXIMUM_DEPTH.
        Unused when `deep` is False.
      tracer_vm: An instance of CallTracer, in case the caller wants to
        instantiate and retain the vm used for type inference. When given,
        `errorlog`, `loader` and `kwargs` are not used.
      **kwargs: Additional keyword arguments forwarded to the CallTracer
        constructor (only when `tracer_vm` is not given).
    Returns:
      A tuple of (ast: TypeDeclUnit, builtins: TypeDeclUnit)
    Raises:
      AssertionError: In case of a bad parameter combination.
    """
    # If the caller has passed in a vm, use that.
    if tracer_vm is not None:
        assert isinstance(tracer_vm, CallTracer)
        tracer = tracer_vm
    else:
        tracer = CallTracer(errorlog=errorlog, options=options,
                            generate_unknowns=options.protocols,
                            store_all_calls=not deep, loader=loader, **kwargs)
    loc, defs = tracer.run_program(src, filename, init_maximum_depth)
    log.info("===Done running definitions and module-level code===")
    snapshotter = metrics.get_metric("memory", metrics.Snapshot)
    snapshotter.take_snapshot("analyze:infer_types:tracer")
    if deep:
        if maximum_depth is None:
            if not options.quick:
                maximum_depth = MAXIMUM_DEPTH
            elif options.analyze_annotated:
                # Since there's no point in analyzing annotated functions for
                # inference, the presence of this option means that the user
                # wants checking, too.
                maximum_depth = QUICK_CHECK_MAXIMUM_DEPTH
            else:
                maximum_depth = QUICK_INFER_MAXIMUM_DEPTH
        tracer.exitpoint = tracer.analyze(loc, defs, maximum_depth)
    else:
        # No deep analysis: the location returned by run_program is used
        # directly as the exit point.
        tracer.exitpoint = loc
    snapshotter.take_snapshot("analyze:infer_types:post")
    ast = tracer.compute_types(defs)
    ast = tracer.loader.resolve_ast(ast)
    if tracer.has_unknown_wildcard_imports or any(
            a in defs for a in abstract_utils.DYNAMIC_ATTRIBUTE_MARKERS):
        # Append the default AST only if the module does not already define
        # __getattr__. NOTE(review): presumably GetDefaultAst supplies a
        # module-level __getattr__ for dynamically-populated modules; confirm.
        try:
            ast.Lookup("__getattr__")
        except KeyError:
            ast = pytd_utils.Concat(
                ast, builtins.GetDefaultAst(options.python_version))
    # Keep this separate from the later `if options.protocols:` check: merging
    # the two triggers a ValueError (Unresolved class) when loading from the
    # protocols file.
    if options.protocols:
        protocols_pytd = tracer.loader.import_name("protocols")
    else:
        protocols_pytd = None
    builtins_pytd = tracer.loader.concat_all()
    # Insert type parameters, where appropriate
    ast = ast.Visit(visitors.CreateTypeParametersForSignatures())
    if options.protocols:
        log.info("=========== PyTD to solve =============\n%s",
                 pytd_utils.Print(ast))
        ast = convert_structural.convert_pytd(ast, builtins_pytd, protocols_pytd)
    elif not show_library_calls:
        log.info("Solving is turned off. Discarding call traces.")
        # Rename remaining "~unknown" to "?"
        ast = ast.Visit(visitors.RemoveUnknownClasses())
        # Remove "~list" etc.:
        ast = convert_structural.extract_local(ast)
    _maybe_output_debug(options, tracer.program)
    return ast, builtins_pytd
def get_pytd(self, datum):
    """Join the types of the values in `datum` and resolve them.

    Args:
      datum: The values to convert. Each one must provide `to_type()`.

    Returns:
      `pytd.AnythingType()` for an empty `datum`. Otherwise, the joined type
      with unknown classes removed, resolved by `self.loader` against
      `self.pytd_module`.
    """
    if datum:
        member_types = (value.to_type() for value in datum)
        known = pytd_utils.JoinTypes(member_types).Visit(
            visitors.RemoveUnknownClasses())
        return self.loader.resolve_type(known, self.pytd_module)
    return pytd.AnythingType()
def _join_types(vals):
    """Join the pytd types of the truthy values in `vals`.

    Falsy entries in `vals` are skipped. Unknown classes are removed from the
    joined result.
    """
    member_types = (val.to_type() for val in vals if val)
    joined = pytd_utils.JoinTypes(member_types)
    return joined.Visit(visitors.RemoveUnknownClasses())