Ejemplo n.º 1
0
def StoreAst(ast, filename=None, open_function=open):
    """Serialize a pytd AST, together with its dependency metadata, via pickle.

    A module whose name ends in ".__init__" is first renamed to its package
    name. The dependencies, late dependencies and class-type nodes collected
    from the AST are stored in sorted order alongside a canonically ordered
    copy of the AST.

    Args:
      ast: The pytd.TypeDeclUnit to save to disk.
      filename: Where to write the pickled output. When None, nothing is
        written and the pickled string is returned instead.
      open_function: A custom file opening function, forwarded to
        pytd_utils.SavePickle.

    Returns:
      The pickled string when no filename was given; None otherwise.
    """
    init_suffix = ".__init__"
    if ast.name.endswith(init_suffix):
        # Store "pkg.__init__" under the plain package name "pkg".
        package_name = ast.name.rsplit(init_suffix, 1)[0]
        ast = ast.Visit(visitors.RenameModuleVisitor(ast.name, package_name))

    # Gather immediate and late dependencies before anything else touches ast.
    dep_collector = visitors.CollectDependencies()
    ast.Visit(dep_collector)

    # NOTE(review): the Visit result is discarded, so this presumably relies
    # on ClearClassPointers mutating the nodes in place -- confirm.
    ast.Visit(visitors.ClearClassPointers())

    class_indexer = FindClassTypesVisitor()
    ast.Visit(class_indexer)

    ordered_ast = ast.Visit(visitors.CanonicalOrderingVisitor())
    payload = SerializableAst(
        ordered_ast,
        sorted(dep_collector.dependencies.items()),
        sorted(dep_collector.late_dependencies.items()),
        sorted(class_indexer.class_type_nodes))
    return pytd_utils.SavePickle(payload, filename,
                                 open_function=open_function)
Ejemplo n.º 2
0
 def testCollectDependenciesOnFunctionType(self):
   # A FunctionType named "foo.bar" is expected to yield exactly one
   # dependency: the module "foo".
   parsed = self.Parse("""
     def f(): ...
   """)
   func_type = pytd.FunctionType("foo.bar", parsed.Lookup("f"))
   collector = visitors.CollectDependencies()
   func_type.Visit(collector)
   expected = {"foo"}
   six.assertCountEqual(self, expected, collector.dependencies)
Ejemplo n.º 3
0
 def testCollectDependencies(self):
   # Qualified names used in annotations (baz.BigInt, bar.Bar, foo.bar.Baz)
   # are expected to be reported as dependencies on their containing modules.
   source = textwrap.dedent("""
     l = ... # type: list[int or baz.BigInt]
     def f1() -> bar.Bar
     def f2() -> foo.bar.Baz
   """)
   collector = visitors.CollectDependencies()
   unit = self.Parse(source)
   unit.Visit(collector)
   expected = {"baz", "bar", "foo.bar"}
   six.assertCountEqual(self, expected, collector.dependencies)
Ejemplo n.º 4
0
 def _verifyDeps(self, module, immediate_deps, late_deps):
     """Assert that `module` has the expected immediate and late dependencies.

     Args:
       module: Either pickled bytes (unpickled and read via its
         `dependencies` / `late_dependencies` attributes) or an object with a
         `Visit` method, which is walked with visitors.CollectDependencies.
       immediate_deps: The expected immediate dependencies.
       late_deps: The expected late dependencies.
     """
     if not isinstance(module, bytes):
         collector = visitors.CollectDependencies()
         module.Visit(collector)
         six.assertCountEqual(self, collector.dependencies, immediate_deps)
         six.assertCountEqual(self, collector.late_dependencies, late_deps)
         return
     # The pickled form stores dependencies as sequences of pairs; dict() turns
     # them back into a mapping so only the keys are compared.
     unpickled = cPickle.loads(module)
     six.assertCountEqual(self, dict(unpickled.dependencies), immediate_deps)
     six.assertCountEqual(self, dict(unpickled.late_dependencies), late_deps)
Ejemplo n.º 5
0
 def test_collect_dependencies(self):
   # Qualified names used in annotations (baz.BigInt, bar.Bar, foo.bar.Baz)
   # are expected to be reported as dependencies on their containing modules.
   source = textwrap.dedent("""
     from typing import Union
     l = ... # type: list[Union[int, baz.BigInt]]
     def f1() -> bar.Bar: ...
     def f2() -> foo.bar.Baz: ...
   """)
   collector = visitors.CollectDependencies()
   unit = self.Parse(source)
   unit.Visit(collector)
   expected = {"baz", "bar", "foo.bar"}
   six.assertCountEqual(self, expected, collector.dependencies)
Ejemplo n.º 6
0
 def _verifyDeps(self, module, immediate_deps, late_deps):
     """Assert that `module` has the expected immediate and late dependencies.

     Args:
       module: Either pickled data (bytes) exposing `dependencies` and `ast`
         attributes once unpickled, or an AST object with a `Visit` method.
       immediate_deps: The expected immediate dependency modules.
       late_deps: The expected late dependency modules.
     """
     # Pickled data is `bytes` on Python 3, so the previous `str` check never
     # matched there. On Python 2 `bytes` is an alias of `str`, so behavior is
     # unchanged.
     if isinstance(module, bytes):
         data = cPickle.loads(module)
         # six.assertCountEqual maps to assertItemsEqual on Python 2 and to
         # assertCountEqual on Python 3 (assertItemsEqual was removed there).
         six.assertCountEqual(self, data.dependencies, immediate_deps)
         ast = data.ast
     else:
         c = visitors.CollectDependencies()
         module.Visit(c)
         six.assertCountEqual(self, c.modules, immediate_deps)
         ast = module
     # Late dependencies are always recomputed from the (possibly unpickled) AST.
     c = visitors.CollectLateDependencies()
     ast.Visit(c)
     six.assertCountEqual(self, c.modules, late_deps)
Ejemplo n.º 7
0
 def _verifyDeps(self, module, immediate_deps, late_deps):
     """Assert that `module` has the expected immediate and late dependencies.

     Args:
       module: Either pickled bytes (unpickled and read via its
         `dependencies` and `ast` attributes) or an AST object with a `Visit`
         method.
       immediate_deps: The expected immediate dependency modules.
       late_deps: The expected late dependency modules.
     """
     if isinstance(module, bytes):
         unpickled = cPickle.loads(module)
         six.assertCountEqual(self, unpickled.dependencies, immediate_deps)
         tree = unpickled.ast
     else:
         tree = module
         immediate = visitors.CollectDependencies()
         tree.Visit(immediate)
         six.assertCountEqual(self, immediate.modules, immediate_deps)
     # Late dependencies are computed from the AST in both cases.
     late = visitors.CollectLateDependencies()
     tree.Visit(late)
     six.assertCountEqual(self, late.modules, late_deps)
Ejemplo n.º 8
0
 def _collect_ast_dependencies(self, ast):
   """Return the dependencies of `ast` as gathered by CollectDependencies.

   Args:
     ast: An object with a `Visit` method (a pytd AST).

   Returns:
     The collector's `dependencies` attribute after visiting `ast`.
   """
   collector = visitors.CollectDependencies()
   ast.Visit(collector)
   return collector.dependencies