예제 #1
0
 def testRemoveUnknownClasses(self):
   # The `~unknownN` class definitions disappear, the parameter annotations
   # that referenced them are dropped, and the `~unknown1` in the return type
   # is printed as `?`. DropBuiltinPrefix is also applied before comparing.
   source = textwrap.dedent("""
       class `~unknown1`():
           pass
       class `~unknown2`():
           pass
       class A(object):
           def foobar(x: `~unknown1`, y: `~unknown2`) -> `~unknown1` or int
   """)
   want = textwrap.dedent("""
       class A(object):
           def foobar(x, y) -> ? or int
   """)
   result = (self.Parse(source)
             .Visit(visitors.RemoveUnknownClasses())
             .Visit(visitors.DropBuiltinPrefix()))
   self.AssertSourceEquals(result, want)
예제 #2
0
 def test_remove_unknown_classes(self):
   # RemoveUnknownClasses should delete the `~unknownN` classes, drop the
   # parameter annotations that used them, and rewrite the remaining reference
   # inside the Union as Any (which then appears in the typing import).
   pyi = pytd_src("""
       from typing import Union
       class `~unknown1`():
           pass
       class `~unknown2`():
           pass
       class A:
           def foobar(x: `~unknown1`, y: `~unknown2`) -> Union[`~unknown1`, int]: ...
   """)
   expected_pyi = textwrap.dedent("""
       from typing import Any, Union
       class A:
           def foobar(x, y) -> Union[Any, int]: ...
   """)
   cleaned = self.Parse(pyi).Visit(visitors.RemoveUnknownClasses())
   self.AssertSourceEquals(cleaned, expected_pyi)
예제 #3
0
파일: errors.py 프로젝트: ghostdart/pytype
 def _pytd_print(self, pytd_type):
   """Print the name of the pytd type.

   Unknown classes are removed from the type, which is then optimized and
   canonically ordered before printing. The printed string is cleaned up for
   use in error messages: autogenerated namedtuple names are reduced to the
   type name, and flattened nested-class names (e.g. `_Foo_DOT_Bar`) are
   rendered as dotted names (`Foo.Bar`).

   Args:
     pytd_type: The pytd type to print.

   Returns:
     A human-readable string for the type.
   """
   name = pytd_utils.Print(pytd_utils.CanonicalOrdering(optimize.Optimize(
       pytd_type.Visit(visitors.RemoveUnknownClasses()))))
   # Clean up autogenerated namedtuple names, e.g. "namedtuple-X-a-_0-c"
   # becomes just "X", by extracting out just the type name.
   if "namedtuple" in name:
     return escape.unpack_namedtuple(name)
   # Pytype doesn't have true support for nested classes. Instead, for
   #   class Foo:
   #     class Bar: ...
   # it outputs:
   #   class _Foo_DOT_Bar: ...
   #   class Foo:
   #     Bar = ...  # type: Type[_Foo_DOT_Bar]
   # Replace _Foo_DOT_Bar with Foo.Bar in error messages for readability.
   # Every such identifier is rewritten (a printed Union can mention several
   # nested classes), and `\b` ensures only identifiers that *start* with the
   # underscore are touched, so e.g. the "_" inside `my_var` is never removed.
   # Deeper nesting (`_A_DOT_B_DOT_C`) becomes `A.B.C`.
   # TODO(b/35138984): Get rid of this hack.
   return re.sub(
       r"\b_(\w+?_DOT_\w*)",
       lambda match: match.group(1).replace("_DOT_", "."),
       name)
예제 #4
0
def _to_pytd(datum, loader, ast):
  """Convert the values in `datum` to one pytd type resolved against `ast`.

  A falsy `datum` yields pytd.AnythingType(). Otherwise the types of all
  values are joined, unknown classes are removed, and the result is resolved
  through `loader`.
  """
  if datum:
    joined = pytd_utils.JoinTypes(value.to_type() for value in datum)
    cleaned = joined.Visit(visitors.RemoveUnknownClasses())
    return loader.resolve_type(cleaned, ast)
  return pytd.AnythingType()
예제 #5
0
def _default_maximum_depth(options):
    """Pick the analysis depth to use when the caller did not give one.

    Args:
      options: config.Options object; `quick` and `analyze_annotated` are read.

    Returns:
      MAXIMUM_DEPTH for a non-quick run; for a quick run,
      QUICK_CHECK_MAXIMUM_DEPTH if `options.analyze_annotated` is set, else
      QUICK_INFER_MAXIMUM_DEPTH.
    """
    if not options.quick:
        return MAXIMUM_DEPTH
    if options.analyze_annotated:
        # Since there's no point in analyzing annotated functions for
        # inference, the presence of this option means that the user wants
        # checking, too.
        return QUICK_CHECK_MAXIMUM_DEPTH
    return QUICK_INFER_MAXIMUM_DEPTH


def infer_types(src,
                errorlog,
                options,
                loader,
                filename=None,
                deep=True,
                init_maximum_depth=INIT_MAXIMUM_DEPTH,
                show_library_calls=False,
                maximum_depth=None,
                tracer_vm=None,
                **kwargs):
    """Given Python source return its types.

    Args:
      src: A string containing Python source code.
      errorlog: Where error messages go. Instance of errors.ErrorLog.
      options: config.Options object
      loader: A load_pytd.Loader instance to load PYI information.
      filename: Filename of the program we're parsing.
      deep: If True, analyze all functions, even the ones not called by the
        main execution flow.
      init_maximum_depth: Depth of analysis during module loading.
      show_library_calls: If True, call traces are kept in the output.
      maximum_depth: Depth of the analysis. Only used when `deep` is True.
        If None, it is chosen from `options` by _default_maximum_depth.
      tracer_vm: An instance of CallTracer, in case the caller wants to
        instantiate and retain the vm used for type inference.
      **kwargs: Additional keyword arguments for the CallTracer constructor;
        only used when `tracer_vm` is None.
    Returns:
      A tuple of (ast: TypeDeclUnit, builtins: TypeDeclUnit)
    Raises:
      AssertionError: In case of a bad parameter combination.
    """
    # If the caller has passed in a vm, use that.
    if tracer_vm is not None:
        # Explicit check rather than `assert`, so the documented
        # AssertionError is still raised when running under `python -O`.
        if not isinstance(tracer_vm, CallTracer):
            raise AssertionError(
                "tracer_vm must be a CallTracer, got %r" % type(tracer_vm))
        tracer = tracer_vm
    else:
        tracer = CallTracer(errorlog=errorlog,
                            options=options,
                            generate_unknowns=options.protocols,
                            store_all_calls=not deep,
                            loader=loader,
                            **kwargs)
    loc, defs = tracer.run_program(src, filename, init_maximum_depth)
    log.info("===Done running definitions and module-level code===")
    snapshotter = metrics.get_metric("memory", metrics.Snapshot)
    snapshotter.take_snapshot("analyze:infer_types:tracer")
    if deep:
        if maximum_depth is None:
            maximum_depth = _default_maximum_depth(options)
        tracer.exitpoint = tracer.analyze(loc, defs, maximum_depth)
    else:
        tracer.exitpoint = loc
    snapshotter.take_snapshot("analyze:infer_types:post")
    ast = tracer.compute_types(defs)
    ast = tracer.loader.resolve_ast(ast)
    if tracer.has_unknown_wildcard_imports or any(
            a in defs for a in abstract_utils.DYNAMIC_ATTRIBUTE_MARKERS):
        # Only add the default AST when the module does not already define
        # __getattr__. NOTE(review): presumably GetDefaultAst supplies a
        # module-level __getattr__ - confirm against builtins.
        try:
            ast.Lookup("__getattr__")
        except KeyError:
            ast = pytd_utils.Concat(
                ast, builtins.GetDefaultAst(options.python_version))
    # Kept separate from the `if options.protocols:` below: merging the two
    # triggers "ValueError: Unresolved class" when loading from the
    # protocols file.
    if options.protocols:
        protocols_pytd = tracer.loader.import_name("protocols")
    else:
        protocols_pytd = None
    builtins_pytd = tracer.loader.concat_all()
    # Insert type parameters, where appropriate
    ast = ast.Visit(visitors.CreateTypeParametersForSignatures())
    if options.protocols:
        log.info("=========== PyTD to solve =============\n%s",
                 pytd_utils.Print(ast))
        ast = convert_structural.convert_pytd(ast, builtins_pytd,
                                              protocols_pytd)
    elif not show_library_calls:
        log.info("Solving is turned off. Discarding call traces.")
        # Rename remaining "~unknown" to "?"
        ast = ast.Visit(visitors.RemoveUnknownClasses())
        # Remove "~list" etc.:
        ast = convert_structural.extract_local(ast)
    _maybe_output_debug(options, tracer.program)
    return ast, builtins_pytd
예제 #6
0
 def get_pytd(self, datum):
     """Return the pytd type for `datum`, resolved against self.pytd_module.

     A falsy `datum` gives pytd.AnythingType(); otherwise the types of its
     values are joined, stripped of unknown classes, and resolved via
     self.loader.
     """
     if datum:
         value_types = (v.to_type() for v in datum)
         joined = pytd_utils.JoinTypes(value_types).Visit(
             visitors.RemoveUnknownClasses())
         return self.loader.resolve_type(joined, self.pytd_module)
     return pytd.AnythingType()
예제 #7
0
def _join_types(vals):
  """Join the types of the truthy entries of `vals`, removing unknown classes."""
  joined = pytd_utils.JoinTypes(val.to_type() for val in filter(None, vals))
  return joined.Visit(visitors.RemoveUnknownClasses())