Beispiel #1
0
def check_or_generate_pyi(options, loader=None):
    """Check the input file, or generate a pyi for it, depending on options.

    When options.check is set the source is only checked; otherwise a pyi is
    generated. Compiler-style failures (compile, indentation and tokenizer
    errors) are recorded in the error log instead of being raised.

    Args:
      options: config.Options object.
      loader: load_pytd.Loader object.

    Returns:
      A tuple, (errors.ErrorLog, PYI Ast as string or None, AST or None).
      The last two entries are None whenever options.check is set.

    Raises:
      utils.UsageError: Re-raised unchanged.
      Exception: Any other unexpected error when options.nofail is not set;
        the input file name is appended to its first argument first.
    """
    # Defaults handed back when analysis fails before producing real output.
    errorlog = errors.ErrorLog()
    result = pytd_builtins.DEFAULT_SRC
    ast = pytd_builtins.GetDefaultAst(options.python_version)
    try:
        src = read_source_file(options.input, options.open_function)
        if options.check:
            checked = check_py(src=src, options=options, loader=loader)
            return checked, None, None
        errorlog, result, ast = generate_pyi(
            src=src, options=options, loader=loader)
    except utils.UsageError:
        # Usage errors are never swallowed, not even with options.nofail.
        raise
    except pyc.CompileError as e:
        errorlog.python_compiler_error(options.input, e.lineno, e.error)
    except IndentationError as e:
        errorlog.python_compiler_error(options.input, e.lineno, e.msg)
    except tokenize.TokenError as e:
        msg, (lineno, unused_column) = e.args  # pylint: disable=unbalanced-tuple-unpacking
        errorlog.python_compiler_error(options.input, lineno, msg)
    except directors.SkipFileError:
        result += "# skip-file found, file not analyzed"
    except Exception as e:  # pylint: disable=broad-except
        if not options.nofail:
            # Attach the file name to the message, keep the remaining args.
            annotated = str(utils.message(e)) + "\nFile: %s" % options.input
            e.args = (annotated, ) + e.args[1:]
            raise
        log.warning("***Caught exception: %s", str(e), exc_info=True)
        if not options.check:
            # Embed the error and its traceback as comments in the output.
            error_text = str(e).replace("\n", "\n#")
            trace_text = "\n# ".join(traceback.format_exc().splitlines())
            result += (  # pytype: disable=name-error
                "# Caught error in pytype: " + error_text + "\n# " +
                trace_text)

    if options.check:
        return errorlog, None, None
    return errorlog, result, ast
Beispiel #2
0
def write_pickle(ast, options, loader=None):
    """Dump a pickle of the ast to a file.

    Args:
      ast: The AST to prepare for export and store.
      options: Options object. This function reads module_name, nofail,
        python_version, verify_pickle, output and open_function from it.
      loader: Loader used for export and verification. When falsy, one is
        created from options via load_pytd.create_loader.

    Raises:
      parser.ParseError: If preparing the AST for export fails and
        options.nofail is not set.
      AssertionError: If options.verify_pickle is set and the prepared AST
        does not compare equal to the AST loaded from that file.
    """
    loader = loader or load_pytd.create_loader(options)
    try:
        ast = serialize_ast.PrepareForExport(options.module_name, ast, loader)
    except parser.ParseError as e:
        if not options.nofail:
            raise
        # Log first so the original failure is recorded even if building the
        # fallback AST below raises as well.
        log.warning("***Caught exception: %s", str(e), exc_info=True)
        ast = serialize_ast.PrepareForExport(
            options.module_name,
            pytd_builtins.GetDefaultAst(options.python_version), loader)
    if options.verify_pickle:
        # Compare both ASTs after ClearClassPointers has been applied.
        ast1 = ast.Visit(visitors.LateTypeToClassType())
        ast1 = ast1.Visit(visitors.ClearClassPointers())
        ast2 = loader.load_file(options.module_name, options.verify_pickle)
        ast2 = ast2.Visit(visitors.ClearClassPointers())
        if not pytd_utils.ASTeq(ast1, ast2):
            raise AssertionError(
                "Pickle verification failed: AST for module %r differs from "
                "the one loaded from %r" %
                (options.module_name, options.verify_pickle))
    serialize_ast.StoreAst(ast, options.output, options.open_function)
Beispiel #3
0
def infer_types(src, errorlog, options, loader,
                filename=None, deep=True, init_maximum_depth=INIT_MAXIMUM_DEPTH,
                show_library_calls=False, maximum_depth=None, tracer_vm=None,
                **kwargs):
  """Run the call tracer over Python source and return the inferred types.

  Args:
    src: A string containing Python source code.
    errorlog: Where error messages go. Instance of errors.ErrorLog.
    options: config.Options object
    loader: A load_pytd.Loader instance to load PYI information.
    filename: Filename of the program we're parsing.
    deep: If True, analyze all functions, even the ones not called by the main
      execution flow.
    init_maximum_depth: Depth of analysis during module loading.
    show_library_calls: If True, call traces are kept in the output.
    maximum_depth: Depth of the analysis. When None and deep is set, a default
      is picked from options.quick and options.analyze_annotated.
    tracer_vm: An instance of CallTracer, in case the caller wants to
      instantiate and retain the vm used for type inference.
    **kwargs: Additional parameters passed to the CallTracer constructor when
      no tracer_vm is given.
  Returns:
    A tuple of (ast: TypeDeclUnit, builtins: TypeDeclUnit)
  Raises:
    AssertionError: In case of a bad parameter combination.
  """
  # Build a tracer unless the caller handed one in.
  if not tracer_vm:
    tracer = CallTracer(
        errorlog=errorlog,
        options=options,
        generate_unknowns=options.protocols,
        store_all_calls=not deep,
        loader=loader,
        **kwargs)
  else:
    assert isinstance(tracer_vm, CallTracer)
    tracer = tracer_vm

  loc, defs = tracer.run_program(src, filename, init_maximum_depth)
  log.info("===Done running definitions and module-level code===")
  memory_snapshots = metrics.get_metric("memory", metrics.Snapshot)
  memory_snapshots.take_snapshot("analyze:infer_types:tracer")

  if not deep:
    tracer.exitpoint = loc
  else:
    if maximum_depth is None:
      if options.quick:
        # Since there's no point in analyzing annotated functions for
        # inference, the presence of analyze_annotated means that the user
        # wants checking, too.
        maximum_depth = (QUICK_CHECK_MAXIMUM_DEPTH
                         if options.analyze_annotated
                         else QUICK_INFER_MAXIMUM_DEPTH)
      else:
        maximum_depth = MAXIMUM_DEPTH
    tracer.exitpoint = tracer.analyze(loc, defs, maximum_depth)
  memory_snapshots.take_snapshot("analyze:infer_types:post")

  ast = tracer.loader.resolve_ast(tracer.compute_types(defs))
  has_dynamic_attributes = tracer.has_unknown_wildcard_imports or any(
      marker in defs for marker in abstract_utils.DYNAMIC_ATTRIBUTE_MARKERS)
  if has_dynamic_attributes and "__getattr__" not in ast:
    ast = pytd_utils.Concat(
        ast, builtins.GetDefaultAst(options.python_version))

  # Keep this lookup separate from the options.protocols branch further down:
  # merging the two triggers a ValueError: Unresolved class when attempting
  # to load from the protocols file.
  protocols_pytd = (
      tracer.loader.import_name("protocols") if options.protocols else None)
  builtins_pytd = tracer.loader.concat_all()

  # Insert type parameters, where appropriate
  ast = ast.Visit(visitors.CreateTypeParametersForSignatures())
  if options.protocols:
    log.info("=========== PyTD to solve =============\n%s",
             pytd_utils.Print(ast))
    ast = convert_structural.convert_pytd(ast, builtins_pytd, protocols_pytd)
  elif not show_library_calls:
    log.info("Solving is turned off. Discarding call traces.")
    # Rename remaining "~unknown" to "?"
    ast = ast.Visit(visitors.RemoveUnknownClasses())
    # Remove "~list" etc.:
    ast = convert_structural.extract_local(ast)

  _maybe_output_debug(options, tracer.program)
  return ast, builtins_pytd