Exemplo n.º 1
0
    def visit_Lambda(self, node):
        """Replace a lambda with a reference to a module-level function.

        * If ``issimpleoperator`` recognises the lambda as a plain operator,
          return ``<mangled operator>.<op>``, adding the ``operator`` import
          (and its global declaration) the first time it is needed.
        * Otherwise the lambda (whose own nested lambdas are lifted first) is
          lifted into a ``FunctionDef`` named ``<prefix>_lambda<N>``, or an
          equivalent lambda already recorded in ``self.patterns`` is reused.
          Free identifiers gathered by ``ImportedIds`` become leading
          parameters of that function.
        * With free identifiers, the result is
          ``functools.partial(<function>, *free_ids)`` (through the mangled
          ``functools`` name, imported on demand); without them, a bare
          ``Name`` referencing the function.
        """
        op = issimpleoperator(node)
        if op is not None:
            if mangle('operator') not in self.global_declarations:
                import_ = ast.Import(
                    [ast.alias('operator', mangle('operator'))])
                self.imports.append(import_)
                operator_module = MODULES['operator']
                self.global_declarations[mangle('operator')] = operator_module
            return ast.Attribute(
                ast.Name(mangle('operator'), ast.Load(), None, None), op,
                ast.Load())

        # Lift nested lambdas first so this body only references their names.
        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        ii = self.gather(ImportedIds, node)
        # ``ii`` holds identifier strings, so already-lifted lambdas must be
        # removed by *name*; subtracting the FunctionDef nodes was a no-op.
        ii.difference_update(fdef.name for fdef in self.lambda_functions)
        free_ids = sorted(ii)

        binded_args = [
            ast.Name(iin, ast.Load(), None, None) for iin in free_ids
        ]
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None)
             for iin in free_ids] + node.args.args)
        for patternname, pattern in self.patterns.items():
            if issamelambda(pattern, node):
                # An identical lambda was already lifted: reuse it.
                proxy_call = ast.Name(patternname, ast.Load(), None, None)
                break
        else:
            duc = ExtendedDefUseChains()
            nodepattern = deepcopy(node)
            duc.visit(ast.Module([ast.Expr(nodepattern)], []))
            self.patterns[forged_name] = nodepattern, duc

            forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                          [ast.Return(node.body)], [], None,
                                          None)
            metadata.add(forged_fdef, metadata.Local())
            self.lambda_functions.append(forged_fdef)
            self.global_declarations[forged_name] = forged_fdef
            proxy_call = ast.Name(forged_name, ast.Load(), None, None)

        if not binded_args:
            return proxy_call

        # functools is only needed when a partial is actually emitted.
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import(
                [ast.alias('functools', mangle('functools'))])
            self.imports.append(import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle(
                'functools')] = functools_module

        return ast.Call(
            ast.Attribute(
                ast.Name(mangle('functools'), ast.Load(), None, None),
                "partial", ast.Load()), [proxy_call] + binded_args, [])
Exemplo n.º 2
0
def insert_import_module_with_postion(node,
                                      mdl_name="paddle",
                                      pos=IMPROT_PADDLE_POS):
    """Insert ``import <mdl_name>`` unless the module already imports it.

    Only top-level ``import X`` statements are inspected; ``from X import``
    statements are not considered.

    Args:
        node (gast.Module): module to modify in place.
        mdl_name (str, optional): module to import. Defaults to "paddle".
        pos (int, optional): the new import is inserted at index ``pos + 1``
            of ``node.body``. Defaults to IMPROT_PADDLE_POS.

    Returns:
        gast.Module: the same ``node`` object.
    """
    assert isinstance(node, gast.gast.Module)
    # Look for the requested module rather than a hard-coded "paddle".
    already_imported = any(
        alias.name == mdl_name
        for stmt in node.body
        if isinstance(stmt, gast.gast.Import)
        for alias in stmt.names)
    if not already_imported:
        def_alias = gast.alias(name=mdl_name, asname=None)
        new_import_node = gast.gast.Import(names=[def_alias])
        print("insert %s with position: %s" % (mdl_name, pos + 1))
        node.body.insert(pos + 1, new_import_node)
    return node
Exemplo n.º 3
0
    def visit_Lambda(self, node):
        """Lift a lambda into a module-level function.

        Nested lambdas are lifted first (``generic_visit``). The lambda then
        becomes a ``FunctionDef`` named ``<prefix>_lambda<N>`` whose leading
        parameters are the free identifiers gathered by ``ImportedIds``. The
        returned node is a ``Name`` referencing that function or, when free
        identifiers exist, ``functools.partial(<function>, *free_ids)``
        through the mangled ``functools`` name. The ``functools`` import is
        added only in that second case.
        """
        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        ii = self.passmanager.gather(ImportedIds, node, self.ctx)
        # ``ii`` holds identifier strings: remove already-lifted lambdas by
        # name (subtracting the FunctionDef nodes themselves was a no-op).
        ii.difference_update(fdef.name for fdef in self.lambda_functions)
        free_ids = sorted(ii)

        binded_args = [ast.Name(iin, ast.Load(), None) for iin in free_ids]
        node.args.args = (
            [ast.Name(iin, ast.Param(), None)
             for iin in free_ids] + node.args.args)
        forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                      [ast.Return(node.body)], [], None)
        self.lambda_functions.append(forged_fdef)
        self.global_declarations[forged_name] = forged_fdef
        proxy_call = ast.Name(forged_name, ast.Load(), None)
        if not binded_args:
            return proxy_call

        # Only import functools when a partial is actually emitted.
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias('functools', mangle('functools'))])
            self.imports.append(import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        return ast.Call(
            ast.Attribute(ast.Name(mangle('functools'), ast.Load(), None),
                          "partial", ast.Load()),
            [proxy_call] + binded_args, [])
 def visit_ImportFrom(self, node):
     """Expand ``from <module> import *`` into explicit imported names.

     Every name listed in ``MODULES[node.module]`` becomes its own alias and
     ``self.update`` is set. Any other ``ImportFrom`` is returned unchanged.
     """
     # ``*`` can only appear as the sole imported name, so rebuild the alias
     # list in place instead of popping/extending it while iterating over it.
     if any(alias.name == "*" for alias in node.names):
         self.update = True
         node.names[:] = [ast.alias(fname, None)
                          for fname in MODULES[node.module]]
     return node
Exemplo n.º 5
0
 def visit_Module(self, node):
     """Visit the module; prepend ``import numpy`` if a child visit asked.

     Child visitors request the import by setting ``self.need_import``.
     """
     self.need_import = False
     self.generic_visit(node)
     if not self.need_import:
         return node
     numpy_import = ast.Import(names=[ast.alias(name='numpy', asname=None)])
     node.body.insert(0, numpy_import)
     return node
Exemplo n.º 6
0
    def visit_Module(self, node):
        """
            When we normalize call, we need to add correct import for method
            to function transformation.

            a.max()

            for numpy array will become:

            numpy.max(a)

            so we have to import numpy.

            The module is visited twice, first with ``self.skip_functions``
            set and then with it cleared. Modules in ``self.to_import`` that
            are not in ``self.globals`` are then imported at the top, in
            sorted order, and ``self.update`` records whether any was added.
        """
        self.skip_functions = True
        self.generic_visit(node)
        self.skip_functions = False
        self.generic_visit(node)
        # Sorted so the generated import order is reproducible; iterating a
        # set of strings depends on per-process hash randomisation.
        new_imports = sorted(self.to_import - self.globals)
        # Entries are mangled names; the first 17 characters are stripped to
        # recover the real module name (presumably mangle()'s prefix - confirm).
        mangling_prefix_len = 17
        imports = [
            ast.Import(names=[ast.alias(name=mod[mangling_prefix_len:],
                                        asname=mod)])
            for mod in new_imports
        ]
        node.body = imports + node.body
        self.update |= bool(imports)
        return node
Exemplo n.º 7
0
 def visit_Module(self, node):
     """Visit every child; if one of them set ``self.need_import``, put an
     ``import numpy`` statement at the very top of the module body."""
     self.need_import = False
     self.generic_visit(node)
     if self.need_import:
         numpy_alias = ast.alias(name='numpy', asname=None)
         node.body.insert(0, ast.Import(names=[numpy_alias]))
     return node
Exemplo n.º 8
0
    def visit_FunctionDef(self, node):
        """Lift a nested function to module level and rebind it via a partial.

        The function is appended to ``self.ctx.module.body`` under a fresh
        ``pythran_<name><seed>`` identifier. Its free identifiers (gathered
        by ``ImportedIds``) are prepended to its parameters, and direct calls
        to its former name inside its own body are rewritten to the new name
        with those identifiers passed explicitly. The returned node replaces
        the original definition with
        ``<former_name> = functools.partial(<new_name>, *free_ids)`` using
        the mangled ``functools`` name.
        """
        self.update = True
        # Ensure a mangled functools import exists at the top of the module.
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias('functools', mangle('functools'))])
            self.ctx.module.body.insert(0, import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        # The same node object is later renamed and extended in place.
        self.ctx.module.body.append(node)

        former_name = node.name
        seed = 0
        new_name = "pythran_{}{}"

        # Pick the first "pythran_<former_name><seed>" not already in use.
        while new_name.format(former_name, seed) in self.identifiers:
            seed += 1

        new_name = new_name.format(former_name, seed)
        self.identifiers.add(new_name)

        # Free identifiers become extra leading parameters of the function.
        ii = self.gather(ImportedIds, node)
        binded_args = [
            ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)
        ]
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None)
             for iin in sorted(ii)] + node.args.args)

        metadata.add(node, metadata.Local())

        # Rewrite direct (e.g. recursive) calls to the former name so they
        # target the new name and forward the free identifiers explicitly.
        class Renamer(ast.NodeTransformer):
            def visit_Call(self, node):
                self.generic_visit(node)
                if (isinstance(node.func, ast.Name)
                        and node.func.id == former_name):
                    node.func.id = new_name
                    node.args = ([
                        ast.Name(iin, ast.Load(), None, None)
                        for iin in sorted(ii)
                    ] + node.args)
                return node

        Renamer().visit(node)

        node.name = new_name
        self.global_declarations[node.name] = node
        proxy_call = ast.Name(new_name, ast.Load(), None, None)

        # former_name = functools.partial(new_name, *free identifiers)
        new_node = ast.Assign([ast.Name(former_name, ast.Store(), None, None)],
                              ast.Call(
                                  ast.Attribute(
                                      ast.Name(mangle('functools'), ast.Load(),
                                               None, None), "partial",
                                      ast.Load()),
                                  [proxy_call] + binded_args,
                                  [],
                              ))

        self.generic_visit(node)
        return new_node
Exemplo n.º 9
0
    def visit_Lambda(self, node):
        """Lift a lambda into a module-level function.

        Nested lambdas are lifted first. The lambda becomes a ``FunctionDef``
        named ``<prefix>_lambda<N>`` whose leading parameters are the free
        identifiers gathered by ``ImportedIds``. Returns a ``Name`` for that
        function or, when free identifiers exist,
        ``functools.partial(<function>, *free_ids)``. The ``functools`` import
        is added on demand, only in that second case.
        """
        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix, len(self.lambda_functions))

        ii = self.passmanager.gather(ImportedIds, node, self.ctx)
        # ``ii`` holds identifier strings: remove already-lifted lambdas by name
        # (subtracting the FunctionDef nodes themselves was a no-op).
        ii.difference_update(fdef.name for fdef in self.lambda_functions)
        free_ids = sorted(ii)

        binded_args = [ast.Name(iin, ast.Load(), None) for iin in free_ids]
        node.args.args = [ast.Name(iin, ast.Param(), None) for iin in free_ids] + node.args.args
        forged_fdef = ast.FunctionDef(forged_name, copy(node.args), [ast.Return(node.body)], [], None)
        self.lambda_functions.append(forged_fdef)
        self.global_declarations[forged_name] = forged_fdef
        proxy_call = ast.Name(forged_name, ast.Load(), None)
        if not binded_args:
            return proxy_call

        # Only import functools when a partial is actually emitted.
        if MODULES["functools"] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias("functools", None)])
            self.imports.append(import_)
            self.global_declarations["functools"] = MODULES["functools"]

        return ast.Call(
            ast.Attribute(ast.Name("functools", ast.Load(), None), "partial", ast.Load()),
            [proxy_call] + binded_args,
            [],
        )
Exemplo n.º 10
0
    def visit_FunctionDef(self, node):
        """Lift a nested function to module level and rebind it via a partial.

        The function is appended to ``self.ctx.module.body`` under a fresh
        ``pythran_<name><seed>`` identifier. Its free identifiers (gathered
        by ``ImportedIds``) are prepended to its parameters, and direct calls
        to its former name inside its own body are redirected to the new
        name with those identifiers passed explicitly. Returns
        ``<former_name> = functools.partial(<new_name>, *free_ids)`` built on
        the mangled ``functools`` name.
        """
        self.update = True
        # Ensure a mangled functools import exists at the top of the module.
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias('functools', mangle('functools'))])
            self.ctx.module.body.insert(0, import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        # The same node object is renamed and extended in place below.
        self.ctx.module.body.append(node)

        former_name = node.name
        seed = 0
        new_name = "pythran_{}{}"

        # Pick the first "pythran_<former_name><seed>" not already in use.
        while new_name.format(former_name, seed) in self.identifiers:
            seed += 1

        new_name = new_name.format(former_name, seed)
        self.identifiers.add(new_name)

        # Free identifiers become extra leading parameters of the function.
        ii = self.passmanager.gather(ImportedIds, node, self.ctx)
        binded_args = [ast.Name(iin, ast.Load(), None) for iin in sorted(ii)]
        node.args.args = ([ast.Name(iin, ast.Param(), None)
                           for iin in sorted(ii)] +
                          node.args.args)

        # Redirect direct (e.g. recursive) calls to the former name and
        # forward the free identifiers explicitly.
        class Renamer(ast.NodeTransformer):
            def visit_Call(self, node):
                self.generic_visit(node)
                if (isinstance(node.func, ast.Name) and
                        node.func.id == former_name):
                    node.func.id = new_name
                    node.args = (
                        [ast.Name(iin, ast.Load(), None)
                         for iin in sorted(ii)] +
                        node.args
                        )
                return node
        Renamer().visit(node)

        node.name = new_name
        self.global_declarations[node.name] = node
        proxy_call = ast.Name(new_name, ast.Load(), None)

        # former_name = functools.partial(new_name, *free identifiers)
        new_node = ast.Assign(
            [ast.Name(former_name, ast.Store(), None)],
            ast.Call(
                ast.Attribute(
                    ast.Name(mangle('functools'), ast.Load(), None),
                    "partial",
                    ast.Load()
                    ),
                [proxy_call] + binded_args,
                [],
                )
            )

        self.generic_visit(node)
        return new_node
Exemplo n.º 11
0
 def visit_Module(self, node):
     """Visit the module, then prepend a mangled ``import itertools`` when a
     child visit set ``self.use_itertools`` (imap/izip/ifilter rewrites)."""
     self.generic_visit(node)
     itertools_alias = ast.alias(name='itertools', asname=mangle('itertools'))
     if not self.use_itertools:
         return node
     node.body.insert(0, ast.Import(names=[itertools_alias]))
     return node
Exemplo n.º 12
0
 def visit_alias(self, node):
     """Rebuild an alias from the visited name/asname with positions cleared.

     All four location attributes (start/end line and column) are set to
     None on the new node.
     """
     visited_name = self._visit(node.name)
     visited_asname = self._visit(node.asname)
     rebuilt = gast.alias(visited_name, visited_asname)
     for attr in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset'):
         setattr(rebuilt, attr, None)
     return rebuilt
Exemplo n.º 13
0
 def visit_ImportFrom(self, node):
     """Expand ``from <module> import *`` into explicit imported names.

     Every name listed in ``MODULES[node.module]`` becomes its own alias and
     ``self.update`` is set. Any other ``ImportFrom`` is returned unchanged.
     """
     # ``*`` can only be the sole imported name, so rebuild the list in place
     # rather than popping/extending it while iterating over it.
     if any(alias.name == '*' for alias in node.names):
         self.update = True
         node.names[:] = [
             ast.alias(fname, None) for fname in MODULES[node.module]]
     return node
Exemplo n.º 14
0
 def visit_Module(self, node):
     """Visit the module; if any child set ``self.use_itertools``, add an
     ``import itertools`` statement at the top of the body."""
     self.use_itertools = False
     self.generic_visit(node)
     if not self.use_itertools:
         return node
     itertools_import = ast.Import(
         names=[ast.alias(name='itertools', asname=None)])
     node.body.insert(0, itertools_import)
     return node
Exemplo n.º 15
0
 def visit_Module(self, node):
     """Prepend ``import itertools`` to the module body when any visited
     child flagged ``self.use_itertools``."""
     self.use_itertools = False
     self.generic_visit(node)
     if self.use_itertools:
         plain_alias = ast.alias(name='itertools', asname=None)
         node.body.insert(0, ast.Import(names=[plain_alias]))
     return node
Exemplo n.º 16
0
 def call_function(self, _, func_name):
     """Record that ``func_name`` is called directly from this builtin module.

     Such a call means the caller did ``from <self.name> import <func_name>``,
     so a matching ImportFrom node is stored in ``self.exported_functions``.
     Returns ``func_name`` unchanged.
     """
     imported = ast.alias(name=func_name, asname=None)
     # FIXME what is level?
     self.exported_functions[func_name] = ast.ImportFrom(
         module=self.name, names=[imported], level=0)
     return func_name
Exemplo n.º 17
0
 def call_function(self, _, func_name):
     """Register an ``from <self.name> import <func_name>`` node for a direct
     call into this builtin module and return ``func_name``."""
     # FIXME what is level?
     node = ast.ImportFrom(module=self.name,
                           names=[ast.alias(name=func_name, asname=None)],
                           level=0)
     self.exported_functions[func_name] = node
     return func_name
Exemplo n.º 18
0
 def import_fct(mod):
     # Append to the enclosing ``import_list`` the nodes needed to use ``mod``
     # from the main module: one aliased import per dependent module, then
     # the module's exported function nodes (FunctionDef, or ImportFrom for
     # builtin modules). The main module itself contributes nothing.
     if mod.is_main_module:
         return
     import_list.extend(
         ast.Import(names=[ast.alias(name=module_name, asname=alias)])
         for alias, module_name in mod.dependent_modules.items())
     import_list.extend(mod.exported_functions.values())
Exemplo n.º 19
0
 def generate_ImportList(self):
     """List of imported functions to be added to the main module.

     For every non-main module: one ``import <module> as <alias>`` per
     dependent module, followed by that module's exported function nodes.
     """
     import_list = []
     for mod in self.modules.values():
         if not mod.is_main_module:
             import_list += [
                 ast.Import(names=[ast.alias(name=module_name, asname=alias)])
                 for alias, module_name in mod.dependent_modules.items()]
             # FunctionDef nodes, or ImportFrom nodes for builtin modules.
             import_list += mod.exported_functions.values()
     return import_list
Exemplo n.º 20
0
class CbrtPattern(Pattern):
    """Rewrite ``X ** (1./3.)`` into ``numpy.cbrt(X)``.

    ``pattern`` matches the power expression, ``sub`` builds the replacement
    call (through the mangled numpy name) and ``extra_imports`` lists the
    import that replacement needs.
    """

    pattern = ast.BinOp(Placeholder(0), ast.Pow(), ast.Constant(1./3., None))

    @staticmethod
    def sub():
        numpy_ref = ast.Name(id=mangle('numpy'), ctx=ast.Load(),
                             annotation=None, type_comment=None)
        cbrt_func = ast.Attribute(value=numpy_ref, attr="cbrt",
                                  ctx=ast.Load())
        return ast.Call(func=cbrt_func, args=[Placeholder(0)], keywords=[])

    extra_imports = [ast.Import([ast.alias('numpy', mangle('numpy'))])]
Exemplo n.º 21
0
 def generate_ImportList(self):
     """List of imported functions to be added to the main module.

     Non-main modules contribute one aliased import per dependent module,
     followed by their exported function nodes.
     """
     import_list = []
     for mod in self.modules.values():
         if mod.is_main_module:
             continue  # nothing to import from the main module itself
         import_list.extend(
             ast.Import(names=[ast.alias(name=module_name, asname=alias)])
             for alias, module_name in mod.dependent_modules.items())
         # FunctionDef nodes, or ImportFrom nodes for builtin modules.
         import_list.extend(mod.exported_functions.values())
     return import_list
Exemplo n.º 22
0
def import_from_ast(module: str, names: List[str]) -> ImportFrom:
    """Build the AST node for ``from <module> import <name>, ...``.

    Args:
        module: The module to import from.
        names: Symbols to import from ``module``, in order; none is aliased.
    Returns:
        A new absolute (``level=0``) ImportFrom node.
    """
    aliases = []
    for symbol in names:
        aliases.append(alias(name=symbol, asname=None))
    return ImportFrom(module=module, names=aliases, level=0)
Exemplo n.º 23
0
    def visit_Module(self, node):
        """
        Visit every top-level statement, drop those the visit returns falsy
        for, then put one ``import <name>`` per entry of ``self.imports`` at
        the top of the module and fill in missing source locations.

        >> import numpy.linalg

        Becomes

        >> import numpy

        """
        kept = []
        for stmt in node.body:
            visited = self.visit(stmt)
            if visited:
                kept.append(visited)
        header = [ast.Import([ast.alias(name, None)]) for name in self.imports]
        node.body = header + kept
        ast.fix_missing_locations(node)
        return node
Exemplo n.º 24
0
    def visit_Module(self, node):
        """
        Visit every top-level statement, keep the truthy results, and put a
        mangled ``import <name> as <mangle(name)>`` for each entry of
        ``self.imports`` at the top; missing locations are then filled in.

        >> import numpy.linalg

        Becomes

        >> import numpy

        """
        visited = (self.visit(stmt) for stmt in node.body)
        remaining = list(filter(None, visited))
        header = [ast.Import([ast.alias(name, mangle(name))])
                  for name in self.imports]
        node.body = header + remaining
        ast.fix_missing_locations(node)
        return node
Exemplo n.º 25
0
def insert_import_module(node, mdl_name="paddle"):
    """Insert ``import <mdl_name>`` at index IMPROT_PADDLE_POS of a module.

    Nothing is inserted when a top-level ``import <mdl_name>`` statement is
    already present (``from ... import`` statements are not inspected).

    Args:
        node (gast.Module): module to modify in place.
        mdl_name (str, optional): module to import. Defaults to "paddle".

    Returns:
        gast.Module: the same ``node`` object.
    """
    assert isinstance(node, gast.gast.Module)
    # Look for the requested module rather than a hard-coded "paddle".
    already_imported = any(
        alias.name == mdl_name
        for stmt in node.body
        if isinstance(stmt, gast.gast.Import)
        for alias in stmt.names)
    if not already_imported:
        def_alias = gast.alias(name=mdl_name, asname=None)
        new_import_node = gast.gast.Import(names=[def_alias])
        node.body.insert(IMPROT_PADDLE_POS, new_import_node)
    return node
Exemplo n.º 26
0
def insert_import_module_with_postion(node, mdl_name="paddle", pos=IMPROT_PADDLE_POS):
    """Insert ``import <mdl_name>`` at index ``pos + 1`` unless already imported.

    Only top-level ``import X`` statements are inspected.

    Args:
        node (gast.Module): module to modify in place.
        mdl_name (str, optional): module to import. Defaults to "paddle".
        pos (int, optional): insertion happens at ``pos + 1``. Defaults to
            IMPROT_PADDLE_POS.

    Returns:
        gast.Module: the same ``node`` object.
    """
    assert isinstance(node, gast.gast.Module)
    # Look for the requested module rather than a hard-coded "paddle".
    already_imported = any(
        alias.name == mdl_name
        for stmt in node.body
        if isinstance(stmt, gast.gast.Import)
        for alias in stmt.names)
    if not already_imported:
        def_alias = gast.alias(name=mdl_name, asname=None)
        new_import_node = gast.gast.Import(names=[def_alias])
        print("insert %s into %s" % (mdl_name, pos + 1))
        node.body.insert(pos + 1, new_import_node)
    return node
Exemplo n.º 27
0
    def visit_Module(self, node):
        """
            When we normalize call, we need to add correct import for method
            to function transformation.

            a.max()

            for numpy array will become:

            numpy.max(a)

            so we have to import numpy.

            Modules in ``self.to_import`` but not in ``self.globals`` are
            imported at the top of the module in sorted order, and
            ``self.update`` records whether any import was added.
        """
        self.generic_visit(node)
        # Sorted so the generated import order is reproducible; iterating a
        # set of strings depends on per-process hash randomisation.
        new_imports = sorted(self.to_import - self.globals)
        imports = [ast.Import(names=[ast.alias(name=mod, asname=None)])
                   for mod in new_imports]
        node.body = imports + node.body
        self.update |= bool(imports)
        return node
Exemplo n.º 28
0
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
                               factory_name, closure_vars, future_features):
  """Wraps `nodes` in a factory of factories so the loaded code gets a closure.

  The generated code has this shape:

   1. An optional `from __future__ import ...` for `future_features`.
   2. An outer factory, `factory_factory_name`, that declares each name in
      `closure_vars` as a local variable (set to None). Because of this, the
      inner factory treats those names as free, non-global variables, so the
      caller can later substitute the real closure cells.
   3. An inner factory, `factory_name(ag__, ag_source_map__, ag_module__)`,
      that runs `nodes`, attaches the source map, module and an empty
      `autograph_info__` dict to `entity_name`, and returns it.

  `nodes` is expected to define the symbol named by `entity_name`.

  Args:
    nodes: ast.AST, or a list/tuple of them.
    entity_name: Union[Text, ast.AST]
    factory_factory_name: Text
    factory_name: Text
    closure_vars: Iterable[Text]
    future_features: Iterable[Text], see EntityInfo.future_features.

  Returns:
    ast.AST
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  # One `<name> = None` statement per closure variable. Only the cells they
  # create matter; they are swapped for the original entity's cells after
  # the code has been loaded.
  closure_template = """
      var_name = None
    """
  dummy_closure_defs = []
  for var_name in closure_vars:
    dummy_closure_defs.extend(
        templates.replace(closure_template, var_name=var_name))

  future_imports = []
  if future_features:
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[gast.alias(name=name, asname=None) for name in future_features],
        level=0)

  template = """
    future_imports
    def factory_factory_name():
      dummy_closure_defs
      def factory_name(ag__, ag_source_map__, ag_module__):
        entity_defs
        entity_name.ag_source_map = ag_source_map__
        entity_name.ag_module = ag_module__
        entity_name.autograph_info__ = {}
        return entity_name
      return factory_name
  """
  return templates.replace(
      template,
      future_imports=future_imports,
      factory_factory_name=factory_factory_name,
      factory_name=factory_name,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name)
Exemplo n.º 29
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method defined directly on `c` with `function_to_graph`
    and assembles the results into a new `gast.ClassDef`. A base class other
    than `object` must pass `is_whitelisted_for_graph`; each such base is
    brought in with a `from <module> import <name> as <alias>` node.

    Args:
      c: the class to convert.
      program_ctx: conversion context forwarded to `function_to_graph`.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`: the base-class
      import nodes followed by the converted class definition, the new class
      name, and the merged namespace of all converted methods.

    Raises:
      ValueError: if `c` has no member methods.
      NotImplementedError: if a base class is neither `object` nor
        whitelisted.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}

    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        class_namespace.update(namespace)
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: `object` is kept as is; a base of a
    # whitelisted type gets an absolute import line.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Identity test: `isinstance(object, base)` is also true for `type`
        # and for ABCs such as `Hashable`, which then silently became `object`.
        if base is object:
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
Exemplo n.º 30
0
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into a pair of nested factories with a controlled scope.

  `nodes` must define a symbol whose name is given by `entity_name`. The
  wrapping lets the transformed entity keep the same lexical scoping as the
  source entity while still receiving extra parameters.

  Properties of the generated code:

   1. The inner factory runs `nodes`, thereby creating the entity.
   2. The inner factory takes the parameters named in `factory_args`, which
      `nodes` may use (typically modules).
   3. The inner factory has the same closure as the transformed entity.
   4. The inner factory returns the entity named by `entity_name`.
   5. The outer factory takes no arguments and has no closure.
   6. The outer factory builds the lexical scope (closure/globals layout)
      needed by the inner factory.
   7. The outer factory returns the inner factory.

  Roughly, the generated code is:

      from __future__ import future_feature_1
      from __future__ import future_feature_2
      ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The `closure_var_N = None` statements create local variables in the outer
  factory. As a result, the Python parser marks those names as free,
  non-global variables of the inner factory when the code is loaded (that
  is, it creates a cell slot for each of them). Their None values are never
  meant to be used: the caller is expected to replace the cells with those
  of the source entity. See
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional parameters of the inner
      factory.
    future_features: Iterable[Text], names of future statements to associate
      the code with.

  Returns:
    ast.AST
  """
  closure_template = """
      var_name = None
    """
  dummy_closure_defs = []
  for var_name in closure_vars:
    dummy_closure_defs.extend(
        templates.replace(closure_template, var_name=var_name))

  future_imports = []
  if future_features:
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[gast.alias(name=name, asname=None) for name in future_features],
        level=0)

  param_nodes = []
  for arg_name in factory_args:
    param_nodes.append(
        gast.Name(arg_name, ctx=gast.Param(), annotation=None,
                  type_comment=None))

  template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=param_nodes,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)
Exemplo n.º 31
0
 def visit_Module(self, node):
     """Visit the module, then prepend ``import MODULE`` unless MODULE is
     ``'__builtin__'``."""
     self.generic_visit(node)
     if MODULE == '__builtin__':
         return node
     node.body.insert(0, ast.Import(names=[ast.alias(name=MODULE, asname=None)]))
     return node
Exemplo n.º 32
0
 def visit_Module(self, node):
     """After visiting the module, import MODULE at the top of its body
     (skipped for the ``'__builtin__'`` module)."""
     self.generic_visit(node)
     needs_import = MODULE != '__builtin__'
     if needs_import:
         module_alias = ast.alias(name=MODULE, asname=None)
         node.body.insert(0, ast.Import(names=[module_alias]))
     return node
Exemplo n.º 33
0
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
                               factory_name, closure_vars, future_features):
    """Wraps an AST into the body of a dynamic factory.

    This uses the dynamic factory (factory of factory) pattern to achieve the
    following:

     1. The inner factory dynamically creates the entity represented by nodes.
     2. The entity is parametrized by `ag__`, the internal AutoGraph module.
     3. The outer factory creates the inner factory with a lexical scope
        in which `closure_vars` are bound local variables. This in turn allows
        the caller to control the exact closure (i.e. non-global free
        variables) for the inner factory.

    The AST is expected to define some symbol named by `entity_name`. The
    inner factory also sets `ag_source_map`, `ag_module` and
    `autograph_info__` attributes on that symbol before returning it.

    Args:
      nodes: ast.AST, or a list/tuple of them.
      entity_name: Union[Text, ast.AST]
      factory_factory_name: Text
      factory_name: Text
      closure_vars: Iterable[Text]
      future_features: Iterable[Text], see EntityInfo.future_features. May be
        None or empty, in which case no `__future__` import is generated.

    Returns:
      ast.AST
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes, )

    # These dummy symbol declarations create local variables in a function
    # scope, so that the Python parser correctly marks them as free non-global
    # variables upon load (that is, it creates cell slots for each symbol).
    # Their values are not used, as the cells are swapped with the original
    # entity's cells after the code has been loaded.
    dummy_closure_defs = []
    for var_name in closure_vars:
        template = """
      var_name = None
    """
        dummy_closure_defs.extend(
            templates.replace(template, var_name=var_name))

    # Materialize once: a one-shot iterator (e.g. a generator) is always
    # truthy, so an empty one would otherwise yield an ImportFrom with no
    # names, which is not a valid statement.
    future_features = tuple(future_features or ())
    if future_features:
        future_imports = gast.ImportFrom(module='__future__',
                                         names=[
                                             gast.alias(name=name, asname=None)
                                             for name in future_features
                                         ],
                                         level=0)
    else:
        future_imports = []

    template = """
    future_imports
    def factory_factory_name():
      dummy_closure_defs
      def factory_name(ag__, ag_source_map__, ag_module__):
        entity_defs
        entity_name.ag_source_map = ag_source_map__
        entity_name.ag_module = ag_module__
        entity_name.autograph_info__ = {}
        return entity_name
      return factory_name
  """
    return templates.replace(template,
                             future_imports=future_imports,
                             factory_factory_name=factory_factory_name,
                             factory_name=factory_name,
                             dummy_closure_defs=dummy_closure_defs,
                             entity_defs=nodes,
                             entity_name=entity_name)
Exemplo n.º 34
0
 def visit_Module(self, node):
     """ Add itertools import for imap, izip or ifilter iterator.

     The module is visited first. Then ``import itertools`` is inserted into
     the existing module body and the same node is returned. Building a new
     ``ast.Module`` instead would drop the original node's other fields
     (e.g. ``type_ignores`` on Python 3.8+).

     The import goes right after the last ``from __future__`` import, or
     first when there is none. Placing it ahead of future imports would make
     the module a SyntaxError.
     """
     self.generic_visit(node)
     importIt = ast.Import(names=[ast.alias(name='itertools', asname=None)])
     position = 0
     for index, stmt in enumerate(node.body):
         if isinstance(stmt, ast.ImportFrom) and stmt.module == '__future__':
             position = index + 1
     node.body.insert(position, importIt)
     return node
Exemplo n.º 35
0
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method defined directly on `c` (inherited methods are
  skipped) and wraps the results in a new class definition whose name is
  produced by the program context's namer.

  Args:
    c: the class to convert.
    program_ctx: the program context; it supplies the namer and receives the
      updated name map.

  Returns:
    A tuple (output_nodes, class_name, class_namespace):
    - output_nodes: the AST nodes to emit, i.e. import statements for
      whitelisted base classes followed by the converted class definition.
    - class_name: the name of the converted class.
    - class_namespace: the merged namespace of all converted methods.

  Raises:
    ValueError: if `c` has no member methods.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  # All converted methods share one namespace; later methods win on clashes.
  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity: `isinstance(object, base)` is also true for `type`
    # and for ABCs that the `object` class itself satisfies (e.g.
    # collections.abc.Hashable or Callable), which would silently replace
    # those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
Exemplo n.º 36
0
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Converts every method defined directly on `c` (inherited methods are
  skipped) and wraps them in a new class definition.

  Args:
    c: the class to convert.
    program_ctx: the program context passed through to convert_func_to_ast.

  Returns:
    A tuple (output_nodes, class_name, entity_info):
    - output_nodes: import statements for whitelisted base classes followed
      by the converted class definition.
    - class_name: the name of the converted class.
    - entity_info: a transformer.EntityInfo with no source code/file, the
      shared future features of the methods (None if no method is defined
      directly on `c`) and the merged method namespace.

  Raises:
    ValueError: if `c` has no member methods, or its methods were built with
      different future features.
    NotImplementedError: if a base class other than `object` is not
      whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m,
        program_ctx=program_ctx,
        do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: it has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))
  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity: `isinstance(object, base)` is also true for `type`
    # and for ABCs that the `object` class itself satisfies (e.g.
    # collections.abc.Hashable or Callable), which would silently replace
    # those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info
Exemplo n.º 37
0
 def visit_Module(self, node):
     """Visit the module, then add ``import MODULE as ASMODULE`` in place.

     The import goes right after the last ``from __future__`` import, or
     first when there is none. Future imports must precede every other
     statement except the docstring, so putting the import ahead of them
     would make the module a SyntaxError.

     Returns the same (mutated) module node.
     """
     self.generic_visit(node)
     importIt = ast.Import(names=[ast.alias(name=MODULE, asname=ASMODULE)])
     position = 0
     for index, stmt in enumerate(node.body):
         if isinstance(stmt, ast.ImportFrom) and stmt.module == '__future__':
             position = index + 1
     node.body.insert(position, importIt)
     return node
Exemplo n.º 38
0
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Converts every method defined directly on `c` (inherited methods are
    skipped) and wraps them in a new class definition.

    Args:
      c: the class to convert.
      program_ctx: the program context passed through to convert_func_to_ast.

    Returns:
      A tuple (output_nodes, class_name, entity_info):
      - output_nodes: import statements for whitelisted base classes followed
        by the converted class definition.
      - class_name: the name of the converted class.
      - entity_info: a transformer.EntityInfo with no source code/file, the
        shared future features of the methods (None if no method is defined
        directly on `c`) and the merged method namespace.

    Raises:
      ValueError: if `c` has no member methods, or its methods were built
        with different future features.
      NotImplementedError: if a base class other than `object` is not
        whitelisted.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # The assumption that one namespace suffices for all methods only holds if
    # all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then spliced
    # into the class, then each function has its own globals and __future__
    # imports that need to stay separate.

    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2

    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node, ), _, entity_info = convert_func_to_ast(m,
                                                       program_ctx=program_ctx,
                                                       do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: it has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Compare by identity: `isinstance(object, base)` is also true for
        # `type` and for ABCs that the `object` class itself satisfies (e.g.
        # collections.abc.Hashable or Callable), which would silently replace
        # those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)

    return output_nodes, class_name, entity_info
Exemplo n.º 39
0
 def visit_Module(self, node):
     """Visit the module, then add ``import MODULE as ASMODULE`` in place.

     The import is inserted right after the last ``from __future__`` import,
     or first when there is none. Python rejects future imports that follow
     any statement other than the docstring, so the import must not precede
     them.

     Returns the same (mutated) module node.
     """
     self.generic_visit(node)
     importIt = ast.Import(names=[ast.alias(name=MODULE, asname=ASMODULE)])
     position = 0
     for index, stmt in enumerate(node.body):
         if isinstance(stmt, ast.ImportFrom) and stmt.module == '__future__':
             position = index + 1
     node.body.insert(position, importIt)
     return node
Exemplo n.º 40
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method defined directly on `c` (inherited methods are
    skipped) and wraps the results in a new class definition whose name is
    produced by the program context's namer.

    Args:
      c: the class to convert.
      program_ctx: the program context; it supplies the namer and receives
        the updated name map.

    Returns:
      A tuple (output_nodes, class_name, class_namespace):
      - output_nodes: the AST nodes to emit, i.e. import statements for
        whitelisted base classes followed by the converted class definition.
      - class_name: the name of the converted class.
      - class_namespace: the merged namespace of all converted methods.

    Raises:
      ValueError: if `c` has no member methods.
    """
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    # All converted methods share one namespace; later methods win on clashes.
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        class_namespace.update(namespace)
        converted_members[m] = node[0]
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Compare by identity: `isinstance(object, base)` is also true for
        # `type` and for ABCs that the `object` class itself satisfies (e.g.
        # collections.abc.Hashable or Callable), which would silently replace
        # those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
Exemplo n.º 41
0
 def visit_Module(self, node):
     """ Add itertools import for imap, izip or ifilter iterator.

     After visiting the module, ``import itertools`` is inserted into the
     existing module body and the same node is returned. Constructing a new
     ``ast.Module`` would lose the original node's other fields (e.g.
     ``type_ignores`` on Python 3.8+).

     The import goes right after the last ``from __future__`` import, or
     first when there is none, because future imports must precede every
     other statement except the docstring.
     """
     self.generic_visit(node)
     importIt = ast.Import(names=[ast.alias(name='itertools', asname=None)])
     position = 0
     for index, stmt in enumerate(node.body):
         if isinstance(stmt, ast.ImportFrom) and stmt.module == '__future__':
             position = index + 1
     node.body.insert(position, importIt)
     return node