Пример #1
0
    def make_control_flow_handlers(self, cont_n, status_n, expected_return,
                                   has_cont, has_break):
        '''
        Build the statements that replay the control flow recorded by a
        static_if: optionally unpack the returned values, then test the status
        variable and emit ``continue`` / ``break`` accordingly.

        :param cont_n: name of the variable holding the returned values.
        :param status_n: name of the variable holding the control-flow status.
        :param expected_return: target nodes ``cont_n`` is unpacked into, or a
            falsy value when there is nothing to unpack.
        :param has_cont: whether a ``continue`` dispatch must be generated.
        :param has_break: whether a ``break`` dispatch must be generated.
        :return: the list of statements to emit.
        '''
        unpack = []
        if expected_return:
            unpack.append(
                ast.Assign([ast.Tuple(expected_return, ast.Store())],
                           ast.Name(cont_n, ast.Load(), None, None)))

        def guarded(status, jump, fallback):
            # if status_n == status: <unpack copy>; jump  else: fallback
            test = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(status, None)])
            return [ast.If(test, deepcopy(unpack) + [jump], fallback)]

        # Each enabled dispatch wraps what was built so far in its else
        # branch, so the break test (when present) is the outermost one.
        handlers = unpack
        if has_cont:
            handlers = guarded(LOOP_CONT, ast.Continue(), handlers)
        if has_break:
            handlers = guarded(LOOP_BREAK, ast.Break(), handlers)
        return handlers
Пример #2
0
def create_assign_node(name, node):
    """
    Build a `gast.Assign` that stores `node` into the variable `name`.

    Returns a ``(target, assign)`` pair: the Store-context name node built by
    `generate_name_node` and the assignment statement that uses it.
    """
    store_target = generate_name_node(name, ctx=gast.Store())
    return store_target, gast.Assign(targets=[store_target], value=node)
Пример #3
0
    def test_replace_code_block(self):
        """`block` in a string template is replaced by a list of statements."""
        template = """
      def test_fn(a):
        block
        return a
    """

        class ShouldBeReplaced(object):
            pass

        def name_a():
            # The ctx is a placeholder class; presumably templates.replace
            # substitutes the proper context (the load below would fail
            # otherwise).
            return gast.Name('a',
                             ctx=ShouldBeReplaced,
                             annotation=None,
                             type_comment=None)

        increment = gast.Assign(
            [name_a()],
            gast.BinOp(name_a(), gast.Add(), gast.Constant(1, kind=None)),
        )
        node = templates.replace(template, block=[increment] * 2)[0]
        result, _, _ = loader.load_ast(node)
        # `a = a + 1` runs twice: 1 -> 3.
        self.assertEqual(3, result.test_fn(1))
Пример #4
0
 def visit_loop(self, node, update_mask=None):
     """Rewrite loop-carried dependencies (LCDs) in a loop body.

     Each top-level assignment ``$var = $expr`` in ``node.body`` whose target
     name has been loaded somewhere in the body up to and including that
     statement becomes ``$var = $var._update($expr, update_mask)``. For every
     rewritten name a ``$var = $var._synchronize()`` statement is appended to
     the body, in the order the names were first rewritten.

     Args:
       node: loop AST node with a ``body`` list (gast).
       update_mask: AST expression passed as the second argument of each
         ``_update`` call. When omitted (or None) a fresh
         ``gast.NameConstant(value=None)`` node is built for this call.

     Returns:
       The transformed loop node.

     Raises:
       NotImplementedError: if a top-level assignment has several targets, or
         if the same LCD is assigned more than once.
     """
     if update_mask is None:
         # Built per call: a node created in the signature would be a single
         # object shared by every tree this visitor ever produced.
         update_mask = gast.NameConstant(value=None)
     node = FuseAttributes().visit(node)
     loads = defaultdict(list)
     # dict used as an insertion-ordered set: iterating a set of str depends
     # on hash randomization and would make the generated code order vary.
     stores = {}
     for child in node.body:
         for n in gast.walk(child):
             if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load):
                 loads[n.id].append(n)
         if isinstance(child, gast.Assign):
             if len(child.targets) > 1:
                 raise NotImplementedError("cannot process LCD that is "
                                           "part of multiple assignment")
             name = child.targets[0].id
             if name in loads:
                 if name in stores:
                     raise NotImplementedError("cannot process LCD "
                                               "stored to twice")
                 # $var = $expr -> $var = $var._update($expr)
                 child.value = gast.Call(
                     gast.Attribute(gast.Name(name, gast.Load(), None),
                                    gast.Name('_update', gast.Load(), None),
                                    None), [child.value, update_mask], [])
                 stores[name] = None
     node = SplitAttributes().visit(node)
     synchronizes = []
     for name in stores:
         # $var = $var._synchronize()
         synchronize = gast.Assign(
             [gast.Name(name, gast.Store(), None)],
             gast.Call(
                 gast.Attribute(
                     gast.Name(name, gast.Load(), None),
                     gast.Name('_synchronize', gast.Load(), None), None),
                 [], []))
         synchronizes.append(synchronize)
     node.body.extend(synchronizes)
     return node
Пример #5
0
 def synchronize_lcds(self, node):
     """Append ``$var = $var._synchronize()`` for each loop-carried dependency.

     A name counts as a loop-carried dependency (LCD) when a top-level
     assignment in ``node.body`` stores to it (only the first target is
     inspected) and the name has been loaded somewhere in the body up to and
     including that statement. Assigned values are left untouched.

     Args:
       node: loop AST node with a ``body`` list (gast).

     Returns:
       The node after the FuseAttributes/SplitAttributes round trip, with the
       synchronization statements appended to its body in the order the LCDs
       were first found.

     Raises:
       NotImplementedError: if an LCD is assigned more than once.
     """
     node = FuseAttributes().visit(node)
     loads = defaultdict(list)
     # dict used as an insertion-ordered set: iterating a set of str depends
     # on hash randomization and would make the generated code order vary.
     lcds = {}
     for child in node.body:
         for n in gast.walk(child):
             if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load):
                 loads[n.id].append(n)
         if isinstance(child, gast.Assign):
             name = child.targets[0].id
             if name in loads:
                 if name in lcds:
                     raise NotImplementedError("cannot process LCD "
                                               "stored to twice")
                 lcds[name] = None
     node = SplitAttributes().visit(node)
     synchronizes = []
     for name in lcds:
         # $var = $var._synchronize()
         synchronize = gast.Assign(
             [gast.Name(name, gast.Store(), None)],
             gast.Call(
                 gast.Attribute(
                     gast.Name(name, gast.Load(), None),
                     gast.Name('_synchronize', gast.Load(), None), None),
                 [], []))
         synchronizes.append(synchronize)
     node.body.extend(synchronizes)
     return node
Пример #6
0
    def _parse_sequence_assign(self, node):
        """
        Split a parallel assignment into one assignment per element:

        a, b = c, d
        ->
        a = c
        b = d

        `node` is returned unchanged whenever the split would not be
        equivalent. That covers several chained targets
        (``a, b = c, d = e, f``), a target or value that is not a tuple/list
        literal, a length mismatch, and starred elements. It also covers a
        value that reads a name appearing in a target: ``a, b = b, a`` must
        swap, but ``a = b; b = a`` does not.

        Returns:
            A list of `gast.Assign` nodes, or `node` itself.
        """
        assert isinstance(node, gast.Assign)

        target_nodes = node.targets
        value_node = node.value
        # With chained targets only the first would be split, silently
        # dropping the bindings of the others.
        if len(target_nodes) != 1:
            return node
        if not isinstance(target_nodes[0], (gast.List, gast.Tuple)):
            return node
        if not isinstance(value_node, (gast.List, gast.Tuple)):
            return node

        targets = target_nodes[0].elts
        values = value_node.elts
        if len(targets) != len(values):
            return node
        # `a, *b = c, d` does not bind element-by-element.
        if any(isinstance(elt, gast.Starred) for elt in targets + values):
            return node
        # Python evaluates the whole right-hand side before binding any
        # target, whereas sequential assignments would let a later value see
        # an already-rebound target. Bail out conservatively on any name
        # shared between the two sides.
        target_names = {
            sub.id
            for target in targets for sub in gast.walk(target)
            if isinstance(sub, gast.Name)
        }
        if any(isinstance(sub, gast.Name) and sub.id in target_names
               for value in values for sub in gast.walk(value)):
            return node

        return [
            gast.Assign(targets=[target], value=value)
            for target, value in zip(targets, values)
        ]
Пример #7
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Specialization of `convert_entity_to_ast` for callable functions.

    Parses the source of `f` (honoring its future imports), isolates its
    definition (disambiguating lambdas that share a source line), runs
    `node_to_graph` on it and gives the result its final name.

    Args:
      f: the function or lambda to convert.
      program_ctx: conversion context; its `autograph_module` is added to the
        namespace through `_add_self_references`.
      do_rename: if True, a non-lambda function is renamed with
        `namer.function_name(f.__name__)`; otherwise it keeps `f.__name__`.

    Returns:
      A tuple `((node,), new_name, entity_info)`. For lambdas `node` is a
      `gast.Assign` binding the converted lambda to `new_name`; otherwise it
      is the converted function definition.

    Raises:
      ValueError: if `f` is a lambda and its source does not yield exactly one
        matching lambda definition.
    """

    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Parsed AST should contain future imports and one function def node.

    # In general, the output of inspect.getsource is inexact for lambdas because
    # it uses regex matching to adjust the exact location around the line number
    # that CPython records. Then, the entire containing line is returned, which
    # we may have trouble disambiguating. For example:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        nodes = ast_util.find_matching_definitions(node, f)
        if len(nodes) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        node, = nodes

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    # Lambdas have no usable name of their own, so they always get a fresh
    # 'tf__lambda' symbol; regular functions are renamed only if do_rename.
    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    elif do_rename:
        new_name = namer.function_name(f.__name__)
    else:
        new_name = f.__name__

    entity_info = transformer.EntityInfo(source_code=source,
                                         source_file='<fragment>',
                                         future_features=future_features,
                                         namespace=namespace)
    context = converter.EntityContext(namer, entity_info, program_ctx,
                                      new_name)
    node = node_to_graph(node, context)

    # A lambda is an expression, so it is wrapped in `new_name = <lambda>` to
    # make it a named statement.
    if isinstance(node, gast.Lambda):
        node = gast.Assign(targets=[
            gast.Name(new_name,
                      ctx=gast.Store(),
                      annotation=None,
                      type_comment=None)
        ],
                           value=node)
    elif do_rename:
        node.name = new_name
    else:
        assert node.name == new_name

    return (node, ), new_name, entity_info
Пример #8
0
    def visit_FunctionDef(self, node):
        """Move a function definition to module level under a fresh name.

        The definition is appended to the module body and renamed to
        ``pythran_<name><n>``, where ``n`` is the first seed whose formatted
        name is not in `self.identifiers`. The identifiers gathered by
        `ImportedIds` for the function are prepended, sorted, to its
        parameters. Calls to the old name inside the function are renamed and
        receive those identifiers as leading arguments. A ``functools``
        import is inserted at the top of the module if it is not already
        declared.

        Returns:
          ``<name> = functools.partial(<new name>, <imported ids...>)``, the
          statement that replaces the definition at its original position.
        """
        # Record that the tree was changed (presumably read by the pass
        # manager to decide whether to run again).
        self.update = True
        # Make `functools` available at module level for the partial() below.
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias('functools', mangle('functools'))])
            self.ctx.module.body.insert(0, import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        self.ctx.module.body.append(node)

        # Pick the first "pythran_<name><seed>" that is not already taken.
        former_name = node.name
        seed = 0
        new_name = "pythran_{}{}"

        while new_name.format(former_name, seed) in self.identifiers:
            seed += 1

        new_name = new_name.format(former_name, seed)
        self.identifiers.add(new_name)

        # Names the function uses from its surroundings become explicit
        # leading parameters.
        ii = self.passmanager.gather(ImportedIds, node, self.ctx)
        binded_args = [ast.Name(iin, ast.Load(), None) for iin in sorted(ii)]
        node.args.args = ([ast.Name(iin, ast.Param(), None)
                           for iin in sorted(ii)] +
                          node.args.args)

        # Calls to the old name inside the function must target the new name
        # and forward the captured identifiers as leading arguments.
        class Renamer(ast.NodeTransformer):
            def visit_Call(self, node):
                self.generic_visit(node)
                if (isinstance(node.func, ast.Name) and
                        node.func.id == former_name):
                    node.func.id = new_name
                    node.args = (
                        [ast.Name(iin, ast.Load(), None)
                         for iin in sorted(ii)] +
                        node.args
                        )
                return node
        Renamer().visit(node)

        node.name = new_name
        self.global_declarations[node.name] = node
        proxy_call = ast.Name(new_name, ast.Load(), None)

        # former_name = functools.partial(new_name, *captured identifiers)
        new_node = ast.Assign(
            [ast.Name(former_name, ast.Store(), None)],
            ast.Call(
                ast.Attribute(
                    ast.Name(mangle('functools'), ast.Load(), None),
                    "partial",
                    ast.Load()
                    ),
                [proxy_call] + binded_args,
                [],
                )
            )

        # Continue the transformation inside the moved function's body.
        self.generic_visit(node)
        return new_node
Пример #9
0
 def visit_Break(self, node):
     """Replace ``break`` with an assignment setting a fresh flag to True.

     The flag name is ``self.breaked_flag`` followed by the value returned by
     ``self.getflag()``; it is also appended to the innermost list of
     ``self.for_breaked_stack``. The replacement takes the location of the
     visited node.
     """
     visited = self.generic_visit(node)
     flag_name = self.breaked_flag + str(self.getflag())
     self.for_breaked_stack[-1].append(flag_name)
     flag_store = gast.Name(id=flag_name, ctx=gast.Store(), annotation=None)
     set_flag = gast.Assign(targets=[flag_store],
                            value=gast.NameConstant(value=True))
     return gast.copy_location(set_flag, visited)
Пример #10
0
 def visit_AugAssign(self, node):
   """Rewrite ``target op= value`` as ``target = <left> op <right>``.

   Both operands are trivialized while ``self.trivializing`` is set; the
   original target node is reused as the assignment target.
   """
   self.trivializing = True
   lhs, rhs = self.trivialize(node.target), self.trivialize(node.value)
   self.trivializing = False
   expanded = gast.BinOp(left=lhs, op=node.op, right=rhs)
   return gast.Assign(targets=[node.target], value=expanded)
Пример #11
0
 def _create_assign_node(self, node):
     """Return ``<var> = <print call on var>`` for ``node``, or None.

     None is returned when ``_get_print_var_node`` yields no variable or when
     ``_need_transform`` rejects it.
     """
     target = self._get_print_var_node(node)
     if target is not None and self._need_transform(target, node):
         return gast.Assign(
             targets=[target], value=self._build_print_call_node(target))
     return None
Пример #12
0
 def make_unrolled_loop(self, node, value_node_list):
     """Unroll a for loop over a known list of value nodes.

     For every value node this emits ``<loop target> = <value>`` followed by
     the loop body. Each repetition gets its own deep copy of the target and
     of the body. Without the copies, the same AST node objects would appear
     several times in the tree, and an in-place transformation of one
     repetition would silently change the others.

     Args:
       node: the for-loop node (its ``target`` and ``body`` are used).
       value_node_list: expression nodes, one per iteration.

     Returns:
       The flat list of statements replacing the loop.
     """
     import copy

     unrolled_nodes = []
     for value_node in value_node_list:
         unrolled_nodes.append(
             gast.Assign(targets=[copy.deepcopy(node.target)],
                         value=value_node))
         unrolled_nodes.extend(copy.deepcopy(node.body))
     return unrolled_nodes
Пример #13
0
 def generate_Assign(self):
   """Generate an Assign node.

   The target is a Store-context name from ``generate_Name`` and the value an
   expression from ``generate_expression``, generated in that order.
   """
   return gast.Assign(
       targets=[self.generate_Name(gast.Store())],
       value=self.generate_expression())
Пример #14
0
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function or lambda to convert.
    program_ctx: program-wide conversion context; it creates the namer and
      receives the updated name map at the end.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to
      `namer.compiled_function_name`.

  Returns:
    A tuple `([node], new_name, namespace)`. For lambdas the node is a
    `gast.Assign` binding the converted lambda to a fresh 'tf__lambda' symbol;
    otherwise it is the converted function definition.

  Raises:
    ValueError: if `f` is a lambda and its source does not yield exactly one
      matching lambda definition.
  """

  node, source = parser.parse_entity(f)
  # Keep the first statement of the parsed source (presumably the module
  # wrapper around the definition).
  node = node.body[0]

  # TODO(mdan): Can we convert everything and scoop the lambda afterwards?
  if f.__name__ == '<lambda>':
    nodes = ast_util.find_matching_lambda_definitions(node, f)
    if len(nodes) != 1:
      # NOTE(review): this message is also raised when no lambda matched
      # (len(nodes) == 0), where "multiple lambdas" is misleading.
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which contains multiple lambdas with'
          ' identical argument names. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = nodes

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  node = node_to_graph(node, context)

  # A lambda is an expression, so it is wrapped in `new_name = <lambda>`.
  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    node = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=node)

  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(f.__name__, f,
                                                        owner_type)
    if did_rename:
      node.name = new_name
    else:
      new_name = f.__name__
      assert node.name == new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [node], new_name, namespace
Пример #15
0
    def visit_Compare(self, node):
        """Replace a chained comparison with a call to a generated helper.

        A comparison with a single operator is returned unchanged. For
        ``a < b < c`` a function ``<prefix>_compare<k>`` is appended to
        ``self.compare_functions``. It evaluates the operands one by one into
        ``$0``, ``$1``, ..., returns 0 at the first comparison that fails and
        1 otherwise, so later operands are only evaluated when needed. The
        comparison is replaced by a call to it, passing the identifiers
        gathered by ``ImportedIds`` (sorted).
        """
        node = self.generic_visit(node)
        if len(node.ops) <= 1:
            return node

        params = sorted(self.passmanager.gather(ImportedIds, node, self.ctx))
        helper_name = "{0}_compare{1}".format(self.prefix,
                                              len(self.compare_functions))

        def slot(index, ctx):
            # Temporary holding the index-th operand inside the helper.
            return ast.Name('${}'.format(index), ctx, None)

        stmts = [ast.Assign([slot(0, ast.Store())], node.left)]
        for idx, operand in enumerate(node.comparators):
            stmts.append(ast.Assign([slot(idx + 1, ast.Store())], operand))
            check = ast.Compare(slot(idx, ast.Load()), [node.ops[idx]],
                                [slot(idx + 1, ast.Load())])
            stmts.append(
                ast.If(check, [ast.Pass()], [ast.Return(ast.Num(0))]))
        stmts.append(ast.Return(ast.Num(1)))

        signature = ast.arguments(
            [ast.Name(p, ast.Param(), None) for p in params],
            None, [], [], None, [])
        self.compare_functions.append(
            ast.FunctionDef(helper_name, signature, stmts, [], None))

        return ast.Call(ast.Name(helper_name, ast.Load(), None),
                        [ast.Name(p, ast.Load(), None) for p in params], [])
Пример #16
0
    def visit_Assign(self, node):
        """Convert an Assign node via ``self._visit`` on its targets and value.

        The result has no type comment, takes ``node``'s location, and has its
        end line / end column cleared.
        """
        targets = self._visit(node.targets)
        value = self._visit(node.value)
        converted = gast.Assign(targets, value, None)  # None: type_comment

        gast.copy_location(converted, node)
        converted.end_lineno = None
        converted.end_col_offset = None
        return converted
Пример #17
0
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: the function or lambda to convert.
      program_ctx: conversion context; its `autograph_module` is added to the
        namespace through `_add_self_references`.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      do_rename: if True, a non-lambda function is renamed with
        `namer.function_name(f.__name__)`; otherwise it keeps `f.__name__`.

    Returns:
      A tuple `([node], new_name, namespace)`. For lambdas the node is a
      `gast.Assign` binding the converted lambda to a fresh 'tf__lambda'
      symbol; otherwise it is the converted function definition.

    Raises:
      ValueError: if `f` is a lambda and its source does not yield exactly one
        matching lambda definition.
      errors.InternalError: if `node_to_graph` raises ValueError,
        AttributeError, KeyError or NotImplementedError (logged first).
    """

    node, source = parser.parse_entity(f)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Keep the first statement of the parsed source (presumably the module
    # wrapper around the definition).
    node = node.body[0]

    # In general, the output of inspect.getsource is inexact for lambdas because
    # it uses regex matching to adjust the exact location around the line number
    # that CPython records. Then, the entire containing line is returned, which
    # we may have trouble disambiguating. For example:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        nodes = ast_util.find_matching_definitions(node, f)
        if len(nodes) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        node, = nodes

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    entity_info = transformer.EntityInfo(source_code=source,
                                         source_file='<fragment>',
                                         namespace=namespace,
                                         arg_values=arg_values,
                                         arg_types=arg_types)
    context = converter.EntityContext(namer, entity_info, program_ctx)
    try:
        node = node_to_graph(node, context)
    except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
        # Errors escaping the converters are reported as internal conversion
        # errors, after logging the traceback.
        logging.error(1, 'Error converting %s', f, exc_info=True)
        raise errors.InternalError('conversion', e)
        # TODO(mdan): Catch and rethrow syntax errors.

    # A lambda is an expression, so it is wrapped in `new_name = <lambda>`.
    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
        node = gast.Assign(targets=[gast.Name(new_name, gast.Store(), None)],
                           value=node)

    elif do_rename:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name = namer.function_name(f.__name__)
        node.name = new_name
    else:
        new_name = f.__name__
        assert node.name == new_name

    return [node], new_name, namespace
Пример #18
0
 def _build_var_shape_assign_node(self):
     """Build ``<self.iter_var_shape_name> = <shape node>``.

     The value is produced by ``create_api_shape_node(self.iter_node.func)``
     and is used as the iteration length (per the original comment). The
     target is a Store-context name, as every assignment target must be:
     ``compile()`` rejects an Assign whose target has a Load context.
     """
     shape_target = gast.Name(
         id=self.iter_var_shape_name,
         ctx=gast.Store(),
         annotation=None,
         type_comment=None)
     return gast.Assign(
         targets=[shape_target],
         value=create_api_shape_node(self.iter_node.func))
Пример #19
0
 def trivialize(self, node):
     """Return a trivial expression standing for ``node``.

     Names, ``None`` and literal nodes from ``grammar.LITERALS`` are returned
     as they are. Anything else is bound to a fresh name from
     ``self.namer``: a marked ``name = ...`` statement is prepended, its value
     is then filled in by visiting ``node``, and a Load-context reference to
     the name is returned.
     """
     trivial_types = (gast.Name, type(None)) + grammar.LITERALS
     if isinstance(node, trivial_types):
         return node
     temp = self.namer.name(node)
     binding = gast.Assign(
         targets=[gast.Name(annotation=None, id=temp, ctx=gast.Store())],
         value=None)
     self.mark(binding)
     self.prepend(binding)
     # The value is visited only after the binding is prepended. Keep this
     # order: it presumably decides where statements prepended during the
     # visit end up relative to the binding.
     binding.value = self.visit(node)
     return gast.Name(annotation=None, id=temp, ctx=gast.Load())
Пример #20
0
    def test_ast_to_source(self):
        """An If/else AST is rendered with the requested indentation."""
        def store_a():
            return gast.Name('a', gast.Store(), None)

        then_branch = [
            gast.Assign(targets=[store_a()],
                        value=gast.Name('b', gast.Load(), None))
        ]
        else_branch = [
            gast.Assign(targets=[store_a()], value=gast.Str('c'))
        ]
        node = gast.If(test=gast.Num(1), body=then_branch, orelse=else_branch)

        source = compiler.ast_to_source(node, indentation='  ')
        expected = textwrap.dedent("""
            if 1:
              a = b
            else:
              a = 'c'
        """).strip()
        self.assertEqual(expected, source.strip())
Пример #21
0
 def visit_Break(self, node):
     """Replace ``break`` with ``<self.breaked_flag><n> = True``.

     The innermost entry of ``self.for_breaked_stack`` is set to True, and
     ``n`` is the number of entries in that stack. The replacement takes the
     location of ``node``.
     """
     self.generic_visit(node)
     self.for_breaked_stack[-1] = True
     depth = len(self.for_breaked_stack)
     flag_target = gast.Name(id=self.breaked_flag + str(depth),
                             ctx=gast.Store(),
                             annotation=None,
                             type_comment=None)
     set_flag = gast.Assign(targets=[flag_target],
                            value=gast.Constant(value=True, kind=None))
     return gast.copy_location(set_flag, node)
Пример #22
0
 def _init(self, node):
     """Build ``<gradient name> = INIT_GRAD(<adjoint var>)`` for ``node``.

     The adjoint variable comes from the node's 'adjoint_var' annotation when
     present, otherwise from its 'temp_adjoint_var' annotation.
     """
     gradname = ast_.get_name(node)
     if anno.hasanno(node, 'adjoint_var'):
         key = 'adjoint_var'
     else:
         key = 'temp_adjoint_var'
     var = anno.getanno(node, key)
     target = gast.Name(id=gradname, ctx=gast.Store(), annotation=None)
     init_call = gast.Call(func=utils.INIT_GRAD, args=[var], keywords=[])
     return gast.Assign(targets=[target], value=init_call)
Пример #23
0
    def _transform_function(self, fn, user_context):
        """Performs source code transformation on a function.

        Parses the source of `fn` (honoring its future imports), isolates its
        definition, resolves origin information, erases argument defaults
        through `_erase_arg_defaults`, runs `transform_ast`, and names the
        result with a fresh symbol derived from `get_transformed_name(node)`.

        Args:
          fn: the function or lambda to transform.
          user_context: forwarded to `transformer.Context`.

        Returns:
          A `(node, context)` pair. For lambdas `node` is a `gast.Assign`
          binding the transformed lambda to the new name; otherwise it is the
          transformed definition, renamed in place.

        Raises:
          ValueError: if `fn` is a lambda and exactly one matching definition
            cannot be found in its source.
        """
        future_features = inspect_utils.getfutureimports(fn)
        node, source = parser.parse_entity(fn, future_features=future_features)
        logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

        # In general, the output of inspect.getsource is inexact for lambdas
        # because it uses regex matching to adjust the exact location around
        # the line number that CPython records. Then, the entire containing line
        # is returned, which we may have trouble disambiguating.
        # For example:
        #   x, y = lambda: 1, lambda: 2
        is_lambda = fn.__name__ == '<lambda>'
        if is_lambda:
            nodes = ast_util.find_matching_definitions(node, fn)
            if len(nodes) != 1:
                raise ValueError(
                    'Unable to identify source code of lambda function {}.'
                    ' It was defined in this code:\n'
                    '{}\n'
                    'This code must contain a single distinguishable lambda.'
                    ' To avoid this problem, define each lambda in a separate'
                    ' expression.'.format(fn, source))
            node, = nodes

        origin_info.resolve_entity(node, source, fn)

        namespace = inspect_utils.getnamespace(fn)
        namer = naming.Namer(namespace)
        # The new name is allocated before transformation so it can be part of
        # the EntityInfo handed to the transformers.
        new_name = namer.new_symbol(self.get_transformed_name(node), ())
        entity_info = transformer.EntityInfo(name=new_name,
                                             source_code=source,
                                             source_file='<fragment>',
                                             future_features=future_features,
                                             namespace=namespace)
        context = transformer.Context(entity_info, namer, user_context)

        node = self._erase_arg_defaults(node)
        node = self.transform_ast(node, context)

        # A lambda is an expression, so it is wrapped in `new_name = <lambda>`.
        if is_lambda:
            node = gast.Assign(targets=[
                gast.Name(new_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)
            ],
                               value=node)
        else:
            node.name = new_name

        return node, context
Пример #24
0
 def visit_FunctionDef(self, node):
     """Visit a function body and bind renamed names at its entry.

     ``self.argsid`` is set to the ids of the positional parameters and
     ``self.renaming`` is reset before the body is visited twice. For every
     ``old -> new`` pair collected in ``self.renaming``, ``new = old`` is
     inserted at the top of the body. ``self.update`` is set when any
     renaming happened.
     """
     self.argsid = {arg.id for arg in node.args.args}
     self.renaming = {}
     # Two passes to make sure all renamings are done. Plain loops are used
     # rather than list comprehensions, since visit() is called only for its
     # side effects.
     for _ in range(2):
         for stmt in node.body:
             self.visit(stmt)
     for k, v in self.renaming.items():
         node.body.insert(
             0,
             ast.Assign([ast.Name(v, ast.Store(), None, None)],
                        ast.Name(k, ast.Load(), None, None), None))
     self.update |= bool(self.renaming)
     return node
Пример #25
0
    def get_for_args_stmts(self, iter_name, args_list):
        '''
        Returns 3 gast nodes emulating ``for <iter_name> in range(*args_list)``.
        1. Initialization of the iteration variable (``start``, or 0)
        2. Condition for the loop: ``iter < stop``, or ``iter > stop`` when
           the step is a negative numeric literal (a non-literal step is
           assumed to be positive)
        3. Statement advancing the iteration variable: ``iter += step``
        NOTE(TODO): Python allows to access iteration variable after loop, such
           as "for i in range(10)" will create i = 9 after the loop. But using
           current conversion will make i = 10. We should find a way to change it
        '''
        len_range_args = len(args_list)
        assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
        if len_range_args == 1:
            init_stmt = get_constant_variable_node(iter_name, 0)
        else:
            init_stmt = gast.Assign(
                targets=[
                    gast.Name(
                        id=iter_name,
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=args_list[0])

        range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
        step_node = args_list[2] if len_range_args == 3 else gast.Constant(
            value=1, kind=None)

        def is_negative_literal(expr):
            # Matches `-3` (UnaryOp(USub, Constant(3))) and Constant(-3).
            def is_int_constant(n):
                return (isinstance(n, gast.Constant)
                        and isinstance(n.value, int)
                        and not isinstance(n.value, bool))
            if isinstance(expr, gast.UnaryOp) and isinstance(expr.op, gast.USub):
                return is_int_constant(expr.operand) and expr.operand.value > 0
            return is_int_constant(expr) and expr.value < 0

        # range() stops as soon as the variable reaches `stop`. The check
        # `iter + step <= stop` only matches that for step == 1:
        # range(0, 5, 2) must visit 0, 2, 4 but that check stops after 2.
        cmp_op = gast.Gt() if is_negative_literal(step_node) else gast.Lt()
        cond_stmt = gast.Compare(
            left=gast.Name(
                id=iter_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None),
            ops=[cmp_op],
            comparators=[range_max_node])

        change_stmt = gast.AugAssign(
            target=gast.Name(
                id=iter_name,
                ctx=gast.Store(),
                annotation=None,
                type_comment=None),
            op=gast.Add(),
            value=step_node)

        return init_stmt, cond_stmt, change_stmt
Пример #26
0
 def _build_enum_init_node(self):
     """Build the initializer of the enumerate index variable.

     The index starts at ``self.iter_args[1]`` when iterating an enumerate
     call given more than one argument, and at the constant 0 otherwise.
     """
     zero_init = get_constant_variable_node(
         name=self.enum_idx_name, value=0)
     has_custom_start = (self.is_for_enumerate_iter()
                         and self.args_length != 1)
     if not has_custom_start:
         return zero_init
     index_target = gast.Name(
         id=self.enum_idx_name,
         ctx=gast.Store(),
         annotation=None,
         type_comment=None)
     return gast.Assign(targets=[index_target], value=self.iter_args[1])
Пример #27
0
    def visit_For(self, node):
        """Replace a destructuring loop target with a single fresh name.

        When the target is a tuple or list and ``traverse_tuples`` collects
        renamings for it, the target becomes a new identifier. For each
        collected ``(rename, path)`` pair, an assignment from the matching
        nested subscript of that identifier is inserted at the top of the
        body. The node is then visited generically and returned.
        """
        target = node.target
        if isinstance(target, (ast.Tuple, ast.List)):
            renamings = OrderedDict()
            self.traverse_tuples(target, (), renamings)
            if renamings:
                fresh = self.get_new_id()
                node.target = ast.Name(fresh, node.target.ctx, None)
                for rename, path in renamings.items():
                    # fresh[path[0]][path[1]]...
                    access = ast.Name(fresh, ast.Load(), None)
                    for index in path:
                        access = ast.Subscript(access,
                                               ast.Index(ast.Num(index)),
                                               ast.Load())
                    if isinstance(rename, str):
                        lhs = ast.Name(rename, ast.Store(), None)
                    else:
                        lhs = rename
                    node.body.insert(0, ast.Assign([lhs], access))

        self.generic_visit(node)
        return node
Пример #28
0
    def visit_For(self, node):
        """
        Handle the iteration variable of a for loop.

        The iteration variable may still hold the correct value once the
        loop is over. If the target name already has an entry in
        ``self.naming``, the body is analysed as if it started with
        ``target = iter``, with the ``orelse`` block as the alternative.
        Otherwise the result of visiting ``node.iter`` is recorded for the
        target, and the body is analysed against ``body + orelse``.
        """
        if node.target.id not in self.naming:
            iter_dep = self.visit(node.iter)
            self.naming[node.target.id] = iter_dep
            self.visit_any_conditionnal(node.body, node.body + node.orelse)
            return
        prefixed = [ast.Assign(targets=[node.target], value=node.iter)]
        self.visit_any_conditionnal(prefixed + node.body, node.orelse)
Пример #29
0
    def visit_AnyComp(self, node, comp_type, *path):
        """Turn a comprehension into a call to a module-level helper function.

        A function named ``<comp_type>_comprehension<count>`` is appended to
        the module and marked with `metadata.Local()`. Its parameters are the
        identifiers gathered by `ImportedIds` for `node`. Its body first sets
        ``__target`` to ``builtins.<comp_type>()`` and then runs the
        generators, folded with `self.nest_reducer` so the last generator is
        innermost. The innermost statement calls the callable at `path` with
        ``(__target, elt)``. The helper finally returns ``__target``. The
        comprehension is replaced by a call to that helper.

        Args:
          node: the comprehension node; its `elt` is visited first.
          comp_type: name of the `builtins` container constructor to call
            (presumably e.g. 'list' or 'set' — confirm against callers).
          *path: access path of the element-adding callable: path[0] is a
            name and path[1:] are successive attributes.

        Returns:
          The `ast.Call` to the generated helper.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "{0}_comprehension{1}".format(comp_type, self.count)
        self.count += 1
        args = self.gather(ImportedIds, node)
        self.count_iter = 0

        starget = "__target"
        # Innermost statement: path[0].path[1]...(__target, elt), wrapped by
        # the generators from the last one outward.
        body = reduce(self.nest_reducer,
                      reversed(node.generators),
                      ast.Expr(
                          ast.Call(
                              reduce(lambda x, y: ast.Attribute(x, y,
                                                                ast.Load()),
                                     path[1:],
                                     ast.Name(path[0], ast.Load(),
                                              None, None)),
                              [ast.Name(starget, ast.Load(), None, None),
                               node.elt],
                              [],
                              )
                          )
                      )
        # add extra metadata to this node
        metadata.add(body, metadata.Comprehension(starget))
        # __target = builtins.<comp_type>()
        init = ast.Assign(
            [ast.Name(starget, ast.Store(), None, None)],
            ast.Call(
                ast.Attribute(
                    ast.Name('builtins', ast.Load(), None, None),
                    comp_type,
                    ast.Load()
                    ),
                [], [],)
            )
        result = ast.Return(ast.Name(starget, ast.Load(), None, None))
        sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
        fd = ast.FunctionDef(name,
                             ast.arguments(sargs, [], None, [], [], None, []),
                             [init, body, result],
                             [], None, None)
        metadata.add(fd, metadata.Local())
        self.ctx.module.body.append(fd)
        return ast.Call(
            ast.Name(name, ast.Load(), None, None),
            [ast.Name(arg.id, ast.Load(), None, None) for arg in sargs],
            [],
            )  # no sharing !
Пример #30
0
    def test_code_block(self):
        """Replacing `block` with a statement list splices every statement."""
        def template(block):  # pylint:disable=unused-argument
            def test_fn(a):  # pylint:disable=unused-variable
                block  # pylint:disable=pointless-statement
                return a

        node = templates.replace(
            template,
            block=[
                gast.Assign([gast.Name('a', gast.Store(), None)],
                            gast.BinOp(gast.Name('a', gast.Load(), None),
                                       gast.Add(), gast.Num(1))),
            ] * 2)[0]
        result = compiler.ast_to_object(node)
        # `a = a + 1` runs twice: 1 -> 3. `assertEqual` is used because the
        # deprecated alias `assertEquals` was removed in Python 3.12.
        self.assertEqual(3, result.test_fn(1))