def setUp(self):
    """Build ``self.mini_builtins``: a tiny builtins AST plus a stub typing."""
    version = self.PYTHON_VERSION
    builtins_ast = parser.parse_string(
        textwrap.dedent(_BUILTINS),
        name="__builtin__",
        python_version=version,
    )
    typing_ast = parser.parse_string(
        "class Generic: ...",
        name="typing",
        python_version=version,
    )
    self.mini_builtins = pytd_utils.Concat(builtins_ast, typing_ast)
def test_concat(self):
    """Test for concatenating two pytd ASTs.

    Each snippet keeps one declaration per line. In pytd source a ``#``
    starts a type comment that runs to the end of the line, so a snippet
    collapsed onto one line would hide every declaration after the first
    constant.
    """
    ast1 = self.Parse(textwrap.dedent("""
        c1 = ... # type: int
        def f1() -> int: ...
        class Class1(object):
            pass
    """))
    ast2 = self.Parse(textwrap.dedent("""
        c2 = ... # type: int
        def f2() -> int: ...
        class Class2(object):
            pass
    """))
    expected = textwrap.dedent("""
        c1 = ... # type: int
        c2 = ... # type: int
        def f1() -> int: ...
        def f2() -> int: ...
        class Class1(object):
            pass
        class Class2(object):
            pass
    """)
    combined = pytd_utils.Concat(ast1, ast2)
    self.AssertSourceEquals(combined, expected)
def concat_all(self):
    """Return one AST concatenating the AST of every loaded module.

    Modules whose ``ast`` is falsy are skipped. The result is stored on
    ``self._concatenated`` and returned directly on later calls for as long
    as that attribute is truthy.
    """
    if self._concatenated:
        return self._concatenated
    module_asts = [mod.ast for mod in self._modules.values() if mod.ast]
    self._concatenated = pytd_utils.Concat(*module_asts, name="<all>")
    return self._concatenated
def testConcatTypeParameters(self):
    """Test for concatenating ASTs with type parameters.

    The same ``T`` declared in two modules must stay distinguishable by its
    scope after concatenation.
    """
    typevar_src = """T = TypeVar("T")"""
    scoped_ast = self.Parse(typevar_src, name="__builtin__")
    unscoped_ast = self.Parse(typevar_src)
    combined = pytd_utils.Concat(scoped_ast, unscoped_ast)
    expectations = (
        ("__builtin__.T", pytd.TypeParameter("T", scope="__builtin__")),
        ("T", pytd.TypeParameter("T", scope=None)),
    )
    for lookup_name, expected_param in expectations:
        self.assertEqual(combined.Lookup(lookup_name), expected_param)
def setUp(self):
    """Build ``self.mini_builtins`` from a tiny builtins and typing stub."""
    super().setUp()
    # Read options only after the base setUp has run.
    opts = self.options
    mini_sources = (
        (textwrap.dedent(_BUILTINS), "builtins"),
        ("class Generic: ...", "typing"),
    )
    parsed = [
        parser.parse_string(src, name=module_name, options=opts)
        for src, module_name in mini_sources
    ]
    self.mini_builtins = pytd_utils.Concat(*parsed)
def setUp(self):
    """Build ``self.mini_builtins`` from a tiny builtins and typing stub."""
    super(TestTypeMatch, self).setUp()
    version = self.python_version

    def _parse(src, module_name):
        # Parse a stub for the configured Python version.
        return parser.parse_string(
            src, name=module_name, python_version=version)

    self.mini_builtins = pytd_utils.Concat(
        _parse(textwrap.dedent(_BUILTINS), "__builtin__"),
        _parse("class Generic: ...", "typing"),
    )
def convert_pytd(ast, builtins_pytd, protocols_pytd):
    """Convert pytd with unknowns (structural types) to one with nominal types.

    Args:
      ast: The pytd to solve.
      builtins_pytd: Builtins AST. ClassType references in it are rewritten
        to NamedType before solving.
      protocols_pytd: Protocols AST, forwarded unchanged to ``solve``.

    Returns:
      The solver's result with the solution mapping inserted.
    """
    named_builtins = builtins_pytd.Visit(visitors.ClassTypeToNamedType())
    mapping, solved = solve(ast, named_builtins, protocols_pytd)
    log_info_mapping(mapping)
    # The solution is inserted using a lookup over builtins plus the result.
    lookup = pytd_utils.Concat(named_builtins, solved)
    solved = insert_solution(solved, mapping, lookup)
    # Only pay for printing the AST when INFO logging is actually enabled.
    if log.isEnabledFor(logging.INFO):
        log.info("=========== solve result =============\n%s",
                 pytd.Print(solved))
        log.info("=========== solve result (end) =============")
    return solved
def test_concat3(self):
    """Test for concatenating three pytd ASTs.

    The expected source keeps one constant per line. In pytd source a ``#``
    starts a type comment that runs to the end of the line, so collapsing
    the lines would fold ``c2``/``c3`` into ``c1``'s type comment.
    """
    ast1 = self.Parse("""c1 = ... # type: int""")
    ast2 = self.Parse("""c2 = ... # type: float""")
    ast3 = self.Parse("""c3 = ... # type: bool""")
    combined = pytd_utils.Concat(ast1, ast2, ast3)
    expected = textwrap.dedent("""
        c1 = ... # type: int
        c2 = ... # type: float
        c3 = ... # type: bool
    """)
    self.AssertSourceEquals(combined, expected)
def GetBuiltinsPyTD(python_version):  # Deprecated. Use Loader.concat_all.
    """Get the "default" AST used to lookup built in types.

    The AST covers all Python builtins plus the most commonly used standard
    libraries. Deprecated: prefer Loader.concat_all.

    Args:
      python_version: The python version tuple. Must be truthy.

    Returns:
      A pytd.TypeDeclUnit instance. It directly contains the builtin classes
      and functions, and submodules for each of the standard library modules.
    """
    assert python_version
    builtins_and_typing = GetBuiltinsAndTyping(python_version)
    return pytd_utils.Concat(*builtins_and_typing)
def compute_types(self, defs):
    """Turn the traced program into a single post-processed pytd module.

    The module's own definitions are concatenated with a synthetic
    "unknowns" module. That module holds classes for unknowns, call traces
    and namedtuple instances, plus traced functions and aliases. The result
    then runs through a fixed sequence of visitors.
    """
    # Keep these helper calls in this order: they run before
    # pytd_for_types(defs), exactly as the block has always done.
    classes = (tuple(self.pytd_classes_for_unknowns())
               + tuple(self.pytd_classes_for_call_traces())
               + self.pytd_classes_for_namedtuple_instances())
    traced_functions = tuple(self.pytd_functions_for_call_traces())
    traced_aliases = tuple(self.pytd_aliases())
    module_types = self.pytd_for_types(defs)
    unknowns_module = pytd_utils.CreateModule(
        "unknowns",
        classes=classes,
        functions=traced_functions,
        aliases=traced_aliases,
    )
    result = pytd_utils.Concat(module_types, unknowns_module)
    result = result.Visit(optimize.CombineReturnsAndExceptions())
    result = result.Visit(optimize.PullInMethodClasses())
    # Checked against both this AST and the loader's combined AST; presumably
    # names unresolved in either become "~unknown" (verify DefaceUnresolved).
    known_asts = [result, self.loader.concat_all()]
    result = result.Visit(visitors.DefaceUnresolved(known_asts, "~unknown"))
    return result.Visit(visitors.AdjustTypeParameters())
def test_name_conflict(self):
    """Namedtuples sharing a typename but differing in fields get distinct names.

    Both the inferred source and the expected pyi need one statement per
    line. Collapsed onto one line, the Python source
    (``import collections X = ...``) is a syntax error, and the ``#`` type
    comment would swallow the X/Y/Z aliases in the expected text.
    """
    ty = self.Infer(textwrap.dedent("""
        import collections
        X = collections.namedtuple("_", [])
        Y = collections.namedtuple("_", [])
        Z = collections.namedtuple("_", "a")
    """), deep=False)
    name_x = escape.pack_namedtuple("_", [])
    name_z = escape.pack_namedtuple("_", ["a"])
    ast_x = self._namedtuple_ast(name_x, [])
    ast_z = self._namedtuple_ast(name_z, ["a"])
    ast = pytd_utils.Concat(ast_x, ast_z)
    # X and Y have the same (typename, fields) and so share one class.
    expected = pytd_utils.Print(ast) + textwrap.dedent("""
        collections = ... # type: module
        X = {name_x}
        Y = {name_x}
        Z = {name_z}""").format(name_x=name_x, name_z=name_z)
    self.assertTypesMatchPytd(ty, expected)
def compute_types(self, defs):
    """Turn the traced program into a single post-processed pytd module.

    The module's own definitions are concatenated with a synthetic
    "unknowns" TypeDeclUnit. That unit holds classes for unknowns, call
    traces and namedtuple instances, plus traced functions and aliases. The
    result then runs through a fixed sequence of visitors.
    """
    # pytd_for_types(defs) runs first, before the helper calls below; keep
    # that order.
    module_types = self.pytd_for_types(defs)
    classes = (tuple(self.pytd_classes_for_unknowns())
               + tuple(self.pytd_classes_for_call_traces())
               + self.pytd_classes_for_namedtuple_instances())
    traced_functions = tuple(self.pytd_functions_for_call_traces())
    traced_aliases = tuple(self.pytd_aliases())
    unknowns_unit = pytd.TypeDeclUnit(
        name="unknowns",
        is_package=False,
        constants=(),
        type_params=(),
        classes=classes,
        functions=traced_functions,
        aliases=traced_aliases,
    )
    result = pytd_utils.Concat(module_types, unknowns_unit)
    result = result.Visit(optimize.CombineReturnsAndExceptions())
    result = result.Visit(optimize.PullInMethodClasses())
    # Checked against both this AST and the loader's combined AST; presumably
    # names unresolved in either become "~unknown" (verify DefaceUnresolved).
    known_asts = [result, self.loader.concat_all()]
    result = result.Visit(visitors.DefaceUnresolved(known_asts, "~unknown"))
    return result.Visit(visitors.AdjustTypeParameters())
def main():
    """Command-line entry point: parse a .pyi file, optionally optimize it,
    and print it to stdout ("-") or to the --output file.

    Exits with status 1 on an invalid Python version or a parse error.
    """
    opts = make_parser().parse_args()

    # Target the requested version, or the running interpreter's
    # (major, minor) when none was given.
    if not opts.python_version:
        python_version = sys.version_info[:2]
    else:
        python_version = utils.version_from_string(opts.python_version)
    try:
        utils.validate_version(python_version)
    except utils.UsageError as e:
        sys.stderr.write("Usage error: %s\n" % utils.message(e))
        sys.exit(1)
    options = parser.PyiOptions(python_version=python_version)

    with open(opts.input) as source_file:
        sourcecode = source_file.read()
    try:
        parsed = parser.parse_string(
            sourcecode, filename=opts.input, options=options)
    except parser.ParseError as e:
        sys.stderr.write(str(e))
        sys.exit(1)

    if opts.optimize:
        builtins_ast = pytd_utils.Concat(
            *builtin_stubs.GetBuiltinsAndTyping(options))
        parsed = optimize.Optimize(
            parsed,
            builtins_ast,
            lossy=opts.lossy,
            use_abcs=opts.use_abcs,
            max_union=opts.max_union,
            remove_mutable=opts.remove_mutable,
            can_do_lookup=False,
        )

    if opts.output is None:
        return
    out_text = pytd_utils.Print(parsed, opts.multiline_args)
    if opts.output == "-":
        sys.stdout.write(out_text)
    else:
        with open(opts.output, "w") as out:
            out.write(out_text)
def infer_types(src, errorlog, options, loader,
                filename=None, deep=True,
                init_maximum_depth=INIT_MAXIMUM_DEPTH,
                show_library_calls=False, maximum_depth=None,
                tracer_vm=None, **kwargs):
    """Given Python source return its types.

    Runs the module-level code in a CallTracer and optionally analyzes the
    functions it defines (deep=True). It then converts the traced state to
    pytd and resolves it with the tracer's loader. Finally it either solves
    the remaining structural types (options.protocols) or, unless
    show_library_calls is set, strips the leftover "~unknown" classes.

    Args:
      src: A string containing Python source code.
      errorlog: Where error messages go. Instance of errors.ErrorLog.
      options: config.Options object
      loader: A load_pytd.Loader instance to load PYI information. Only used
        to construct a new CallTracer. All later loading goes through
        tracer.loader, including when tracer_vm is supplied.
      filename: Filename of the program we're parsing.
      deep: If True, analyze all functions, even the ones not called by the
        main execution flow. If False, no analysis pass runs and the exit
        point is the end of the module-level code.
      init_maximum_depth: Depth of analysis during module loading.
      show_library_calls: If True, call traces are kept in the output. Only
        consulted when options.protocols is off.
      maximum_depth: Depth of the analysis; only used when deep=True. If
        None, it is derived from options: MAXIMUM_DEPTH unless options.quick;
        otherwise QUICK_CHECK_MAXIMUM_DEPTH when options.analyze_annotated,
        else QUICK_INFER_MAXIMUM_DEPTH.
      tracer_vm: An instance of CallTracer, in case the caller wants to
        instantiate and retain the vm used for type inference.
      **kwargs: Additional parameters passed to the CallTracer constructor.
        Ignored when tracer_vm is given.

    Returns:
      A tuple of (ast: TypeDeclUnit, builtins: TypeDeclUnit), where builtins
      is tracer.loader.concat_all().

    Raises:
      AssertionError: If tracer_vm is given but is not a CallTracer.
    """
    # If the caller has passed in a vm, use that.
    if tracer_vm:
        assert isinstance(tracer_vm, CallTracer)
        tracer = tracer_vm
    else:
        tracer = CallTracer(errorlog=errorlog, options=options,
                            generate_unknowns=options.protocols,
                            store_all_calls=not deep, loader=loader, **kwargs)
    loc, defs = tracer.run_program(src, filename, init_maximum_depth)
    log.info("===Done running definitions and module-level code===")
    snapshotter = metrics.get_metric("memory", metrics.Snapshot)
    snapshotter.take_snapshot("analyze:infer_types:tracer")
    if deep:
        if maximum_depth is None:
            if not options.quick:
                maximum_depth = MAXIMUM_DEPTH
            elif options.analyze_annotated:
                # Since there's no point in analyzing annotated functions for inference,
                # the presence of this option means that the user wants checking, too.
                maximum_depth = QUICK_CHECK_MAXIMUM_DEPTH
            else:
                maximum_depth = QUICK_INFER_MAXIMUM_DEPTH
        tracer.exitpoint = tracer.analyze(loc, defs, maximum_depth)
    else:
        # No function analysis: the program ends where module-level code ended.
        tracer.exitpoint = loc
    snapshotter.take_snapshot("analyze:infer_types:post")
    ast = tracer.compute_types(defs)
    ast = tracer.loader.resolve_ast(ast)
    # The module has dynamic attributes: unresolved wildcard imports, or one
    # of the DYNAMIC_ATTRIBUTE_MARKERS among its definitions. Unless it
    # already defines __getattr__, append the default AST (presumably this
    # supplies a catch-all __getattr__; confirm in builtins.GetDefaultAst).
    if tracer.has_unknown_wildcard_imports or any(
            a in defs for a in abstract_utils.DYNAMIC_ATTRIBUTE_MARKERS):
        try:
            ast.Lookup("__getattr__")
        except KeyError:
            ast = pytd_utils.Concat(
                ast, builtins.GetDefaultAst(options.python_version))
    # If merged with other if statement, triggers a ValueError: Unresolved class
    # when attempts to load from the protocols file
    if options.protocols:
        protocols_pytd = tracer.loader.import_name("protocols")
    else:
        protocols_pytd = None
    builtins_pytd = tracer.loader.concat_all()
    # Insert type parameters, where appropriate
    ast = ast.Visit(visitors.CreateTypeParametersForSignatures())
    if options.protocols:
        log.info("=========== PyTD to solve =============\n%s",
                 pytd_utils.Print(ast))
        ast = convert_structural.convert_pytd(ast, builtins_pytd, protocols_pytd)
    elif not show_library_calls:
        log.info("Solving is turned off. Discarding call traces.")
        # Rename remaining "~unknown" to "?"
        ast = ast.Visit(visitors.RemoveUnknownClasses())
        # Remove "~list" etc.:
        ast = convert_structural.extract_local(ast)
    _maybe_output_debug(options, tracer.program)
    return ast, builtins_pytd
def setUpClass(cls):
    """Load the full builtins and typing stubs once for the whole class."""
    super(UtilsTest, cls).setUpClass()
    pyi_options = parser.PyiOptions(python_version=cls.python_version)
    stubs = builtin_stubs.GetBuiltinsAndTyping(pyi_options)
    cls.builtins = pytd_utils.Concat(*stubs)
def concat_all(self):
    """Return a single AST concatenating every AST from ``defined_asts()``.

    The result is stored on ``self._concatenated`` and returned directly on
    later calls for as long as that attribute is truthy.
    """
    combined = self._concatenated
    if not combined:
        combined = pytd_utils.Concat(*self.defined_asts(), name="<all>")
        self._concatenated = combined
    return combined