示例#1
0
    def visit_arguments(self, node):
        """Convert an ``arguments`` node into a ``gast.arguments`` node.

        ``node.vararg`` and ``node.kwarg`` are read as bare identifiers
        (presumably a Python 2 style AST where they are strings -- confirm)
        and wrapped into ``ast.Name`` nodes with a ``Param`` context before
        being visited. ``posonlyargs``, ``kwonlyargs`` and ``kw_defaults``
        are always emitted empty.

        Returns:
            The new ``gast.arguments`` node; no location is copied here.
        """
        # missing locations for vararg and kwarg set at function level:
        # the Name nodes synthesized below carry no position of their own.
        vararg = ast.Name(node.vararg, ast.Param()) if node.vararg else None
        kwarg = ast.Name(node.kwarg, ast.Param()) if node.kwarg else None

        new_node = gast.arguments(
            self._visit(node.args),
            [],  # posonlyargs
            self._visit(vararg),
            [],  # kwonlyargs
            [],  # kw_defaults
            self._visit(kwarg),
            self._visit(node.defaults),
        )
        return new_node
示例#2
0
def outline(name, formal_parameters, out_parameters, stmts, has_return):
    """Wrap ``stmts`` into a freshly built FunctionDef called ``name``.

    Args:
        name: name given to the generated function.
        formal_parameters: identifiers turned into Param-context parameters.
        out_parameters: identifiers returned together, as a tuple, by a
            return statement appended to the body. Must be empty when
            ``stmts`` is an expression.
        stmts: either a single ``ast.expr`` (the function then just returns
            it) or a list of statements used as the body. That list is the
            very object stored in the FunctionDef and is appended to in place.
        has_return: when true, a ``PatchReturn`` pass (defined elsewhere) is
            built from the appended return and run over the new function.

    Returns:
        The new FunctionDef node.
    """
    params = [ast.Name(fp, ast.Param(), None) for fp in formal_parameters]
    args = ast.arguments(params, None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        return ast.FunctionDef(name, args, [ast.Return(stmts)], [], None)

    fdef = ast.FunctionDef(name, args, stmts, [], None)

    # This is part of a huge trick that plays with delayed type inference.
    # It computes the return type from the out parameters: the return below
    # is added unconditionally, so if the body has other returns the output
    # type is computed from the __combined of the regular return types and
    # this one. The original returns have been patched above to have a
    # different type that cunningly combines with this output tuple.
    #
    # This is the only trick found to let pythran compute both the output
    # variable type and the early return type. A dirty one :-/
    out_values = [ast.Name(fp, ast.Load(), None) for fp in out_parameters]
    stmts.append(ast.Return(ast.Tuple(out_values, ast.Load())))

    if has_return:
        PatchReturn(stmts[-1]).visit(fdef)

    return fdef
  def test_ast_to_object(self):
    # Build `def f(a): return a + 1` by hand, compile it, and check the
    # returned source, the loaded module and the backing file all agree.
    param_a = gast.Name('a', gast.Param(), None)
    f_args = gast.arguments(
        args=[param_a],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=f_args,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    expected = textwrap.dedent("""
      # coding=utf-8
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
示例#4
0
def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
    """
    Find out the ast.Name.id list of input by analyzing node's AST information.

    Args:
        var_ids_dict: mapping from a variable id to a sequence whose first
            item is the ctx node recorded for that id.
        return_ids: optional iterable of extra ids that are always included.
        ctx: ctx class an entry's first ctx must be an instance of for the
            id to be selected.

    Returns:
        A ``gast.arguments`` node whose ``args`` are ``gast.Name`` nodes for
        the selected ids, sorted by id; every other field is empty.
    """
    name_ids = [
        var_id for var_id, var_ctx in var_ids_dict.items()
        if isinstance(var_ctx[0], ctx)
    ]
    if return_ids:
        # only add ids not already selected, to avoid duplicate parameters
        name_ids.extend(set(return_ids) - set(name_ids))
    name_ids.sort()
    args = [
        gast.Name(id=name_id,
                  ctx=gast.Load(),
                  annotation=None,
                  type_comment=None) for name_id in name_ids
    ]
    arguments = gast.arguments(args=args,
                               posonlyargs=[],
                               vararg=None,
                               kwonlyargs=[],
                               # kw_defaults is a list field in the AST
                               # grammar; None makes compile() of the
                               # converted tree raise TypeError.
                               kw_defaults=[],
                               kwarg=None,
                               defaults=[])
    return arguments
示例#5
0
  def test_ast_to_object(self):
    # Build `def f(a): return a + 1` by hand, compile it, and check the
    # returned source, the loaded module and the backing file all agree.
    param_a = gast.Name('a', gast.Param(), None)
    f_args = gast.arguments(
        args=[param_a],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=f_args,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    module, source = compiler.ast_to_object(node)

    expected = textwrap.dedent("""
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
示例#6
0
    def visit_ListComp(self, node):
        """Rewrite an optimizable list comprehension into a ``map`` call.

        Nodes not in ``self.optimizable_comprehension`` are only
        generic-visited. Otherwise ``self.update`` is set and the result is
        ``__builtin__.map(lambda <params>: <elt>, <iterable>)``. With one
        generator, the lambda takes its loop variable and maps over
        ``self.make_Iterator(gen)``. With several, it maps over
        ``itertools.product`` of those iterables and takes a single
        parameter. ``node.elt`` is then rewritten by ``ConvertToTuple``
        (defined elsewhere; presumably it turns each loop variable into an
        index into that parameter -- confirm).

        Generator targets are assumed to be plain Names (``.id`` is read).
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterables = [self.make_Iterator(gen) for gen in node.generators]
        loop_vars = [ast.Name(gen.target.id, ast.Param(), None)
                     for gen in node.generators]

        if len(iterables) == 1:
            # dim = 1: itertools.product would be useless
            iter_ast = iterables[0]
            lambda_args = ast.arguments([loop_vars[0]], None, [], [], None,
                                        [])
        else:
            self.use_itertools = True
            product = ast.Attribute(
                value=ast.Name(id=mangle('itertools'), ctx=ast.Load(),
                               annotation=None),
                attr='product',
                ctx=ast.Load())

            tuple_id = loop_vars[0].id  # retarget this id, it's free
            renamings = {var.id: (pos, ) for pos, var in enumerate(loop_vars)}
            node.elt = ConvertToTuple(tuple_id, renamings).visit(node.elt)
            iter_ast = ast.Call(product, iterables, [])
            lambda_args = ast.arguments(
                [ast.Name(tuple_id, ast.Param(), None)],
                None, [], [], None, [])

        map_func = ast.Attribute(
            value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None),
            attr='map',
            ctx=ast.Load())
        mapper = ast.Lambda(lambda_args, node.elt)
        return ast.Call(map_func, [mapper, iter_ast], [])
示例#7
0
def _make_arguments(*args):
    """Build a gast ``arguments`` node whose positional parameters are ``args``.

    Positional-only, ``*args``, keyword-only and ``**kwargs`` parameters are
    absent, and there are no default values.
    """
    positional = list(args)
    return gast.arguments(
        posonlyargs=[],
        args=positional,
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    def visit_Compare(self, node):
        """Replace a chained comparison by a call to a forged helper function.

        Children are visited first. A comparison with a single operator is
        returned as is. For ``a op1 b op2 c ...`` a FunctionDef named
        ``<prefix>_compare<k>`` is appended to ``self.compare_functions``.
        It takes the (sorted) ids gathered by ``ImportedIds`` as parameters
        and checks each link in turn. Operands that are not trivially copied
        are stored in ``$<i>`` temporaries just before their first use, so
        later operands are only evaluated if earlier links hold. It returns
        ``__builtin__.False`` on the first failing link and
        ``__builtin__.True`` otherwise.

        Returns:
            The call to the forged function, or ``node`` when not chained.
        """
        node = self.generic_visit(node)
        if len(node.ops) <= 1:
            return node

        imported_ids = sorted(self.gather(ImportedIds, node))
        binded_args = [ast.Name(i, ast.Load(), None) for i in imported_ids]

        forged_name = "{0}_compare{1}".format(self.prefix,
                                              len(self.compare_functions))
        call = ast.Call(ast.Name(forged_name, ast.Load(), None),
                        binded_args, [])

        params = [ast.Name(i, ast.Param(), None) for i in imported_ids]
        args = ast.arguments(params, None, [], [], None, [])

        body = []  # iteratively fill the body (yeah, feel your body!)

        def hold(expr, tmp_name):
            """Return a node re-reading ``expr``, spilling it if needed."""
            if is_trivially_copied(expr):
                return expr
            body.append(ast.Assign([ast.Name(tmp_name, ast.Store(), None)],
                                   expr))
            return ast.Name(tmp_name, ast.Load(), None)

        prev_holder = hold(node.left, '$0')
        for i, exp in enumerate(node.comparators):
            holder = hold(exp, '${}'.format(i+1))
            link = ast.Compare(prev_holder, [node.ops[i]], [holder])
            fail = ast.Return(path_to_attr(('__builtin__', 'False')))
            body.append(ast.If(link, [ast.Pass()], [fail]))
            prev_holder = holder

        body.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

        self.compare_functions.append(
            ast.FunctionDef(forged_name, args, body, [], None))
        return call
示例#9
0
def parse_cond_args(parent_ids_dict,
                    var_ids_dict,
                    modified_ids_dict=None,
                    ctx=gast.Load):
    """
    Find out the ast.Name.id list of input by analyzing node's AST information.

    Args:
        parent_ids_dict: ids visible in the parent scope (iterated for its
            ids, i.e. its keys if it is a dict); only these may end up in
            the returned arguments.
        var_ids_dict: mapping from a variable id to a sequence whose first
            item is the ctx node recorded for that id.
        modified_ids_dict: optional collection of ids modified in the if-body
            or else-body; they are added to the candidates.
        ctx: ctx class an entry's first ctx must be an instance of for the
            id to be selected from ``var_ids_dict``.

    Returns:
        A ``gast.arguments`` node whose ``args`` are ``gast.Name`` nodes for
        the selected ids, sorted by id; every other field is empty.
    """

    # 1. keep the vars whose recorded ctx matches `ctx`
    arg_name_ids = {
        var_id for var_id, var_ctx in six.iteritems(var_ids_dict)
        if isinstance(var_ctx[0], ctx)
    }

    # 2. args should contain modified var ids in if-body or else-body
    #  case:
    #
    #   ```
    #   if b < 1:
    #     z = y
    #   else:
    #     z = x
    #   ```
    #
    #   In the above case, `z` should be in the args of cond()
    if modified_ids_dict:
        arg_name_ids |= set(modified_ids_dict)

    # 3. args should not contain the vars not in parent ids
    #  case :
    #
    #   ```
    #   x = 1
    #   if x > y:
    #     z = [v for v in range(i)]
    #   ```
    #
    #   In the above case, `v` should not be in the args of cond()
    arg_name_ids = sorted(arg_name_ids & set(parent_ids_dict))

    args = [
        gast.Name(id=name_id,
                  ctx=gast.Load(),
                  annotation=None,
                  type_comment=None) for name_id in arg_name_ids
    ]
    arguments = gast.arguments(args=args,
                               posonlyargs=[],
                               vararg=None,
                               kwonlyargs=[],
                               # kw_defaults is a list field in the AST
                               # grammar; None makes compile() of the
                               # converted tree raise TypeError.
                               kw_defaults=[],
                               kwarg=None,
                               defaults=[])

    return arguments
示例#10
0
 def visit_arguments(self, node):
     """Translate an ``arguments`` node into a ``gast.arguments`` node.

     ``args``, ``vararg``, ``kwarg`` and ``defaults`` are visited, while
     keyword-only arguments and their defaults are always emitted empty.
     NOTE(review): six positional fields are passed, which assumes a gast
     version without ``posonlyargs`` -- confirm against the pinned gast.
     """
     args = self._visit(node.args)
     vararg = self._visit(node.vararg)
     kwarg = self._visit(node.kwarg)
     defaults = self._visit(node.defaults)
     kwonlyargs, kw_defaults = [], []
     return gast.arguments(args, vararg, kwonlyargs, kw_defaults, kwarg,
                           defaults)
示例#11
0
 def visit_arguments(self, node):
     """Translate an ``arguments`` node into a ``gast.arguments`` node.

     ``args``, ``vararg``, ``kwarg`` and ``defaults`` are visited, while
     keyword-only arguments and their defaults are always emitted empty.
     NOTE(review): six positional fields are passed, which assumes a gast
     version without ``posonlyargs`` -- confirm against the pinned gast.
     """
     args = self._visit(node.args)
     vararg = self._visit(node.vararg)
     kwarg = self._visit(node.kwarg)
     defaults = self._visit(node.defaults)
     kwonlyargs, kw_defaults = [], []
     return gast.arguments(args, vararg, kwonlyargs, kw_defaults, kwarg,
                           defaults)
示例#12
0
    def visit_Compare(self, node):
        """Replace a chained comparison by a call to a forged helper function.

        Children are visited first. A comparison with a single operator is
        returned as is. For ``a op1 b op2 c ...`` a FunctionDef named
        ``<prefix>_compare<k>`` is appended to ``self.compare_functions``.
        Its parameters are the sorted ids gathered by ``ImportedIds``
        through ``self.passmanager`` in ``self.ctx``. It checks each link in
        turn; operands that are not trivially copied are stored in ``$<i>``
        temporaries just before their first use, so later operands are only
        evaluated if earlier links hold. It returns ``__builtin__.False`` on
        the first failing link and ``__builtin__.True`` otherwise.

        Returns:
            The call to the forged function, or ``node`` when not chained.
        """
        node = self.generic_visit(node)
        if len(node.ops) <= 1:
            return node

        imported_ids = sorted(
            self.passmanager.gather(ImportedIds, node, self.ctx))
        binded_args = [ast.Name(i, ast.Load(), None) for i in imported_ids]

        forged_name = "{0}_compare{1}".format(self.prefix,
                                              len(self.compare_functions))
        call = ast.Call(ast.Name(forged_name, ast.Load(), None),
                        binded_args, [])

        params = [ast.Name(i, ast.Param(), None) for i in imported_ids]
        args = ast.arguments(params, None, [], [], None, [])

        body = []  # iteratively fill the body (yeah, feel your body!)

        def hold(expr, tmp_name):
            """Return a node re-reading ``expr``, spilling it if needed."""
            if is_trivially_copied(expr):
                return expr
            body.append(
                ast.Assign([ast.Name(tmp_name, ast.Store(), None)], expr))
            return ast.Name(tmp_name, ast.Load(), None)

        prev_holder = hold(node.left, '$0')
        for i, exp in enumerate(node.comparators):
            holder = hold(exp, '${}'.format(i + 1))
            link = ast.Compare(prev_holder, [node.ops[i]], [holder])
            fail = ast.Return(path_to_attr(('__builtin__', 'False')))
            body.append(ast.If(link, [ast.Pass()], [fail]))
            prev_holder = holder

        body.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

        self.compare_functions.append(
            ast.FunctionDef(forged_name, args, body, [], None))
        return call
示例#13
0
    def visit_ListComp(self, node):
        """Rewrite an optimizable list comprehension into a ``map`` call.

        Nodes not in ``self.optimizable_comprehension`` are only
        generic-visited. Otherwise ``self.update`` is set and the result is
        ``__builtin__.map(lambda <params>: <elt>, <iterable>)``. With one
        generator, the lambda takes its loop variable and maps over
        ``self.make_Iterator(gen)``. With several, it maps over
        ``itertools.product`` of those iterables and takes a single tuple
        parameter unpacking all loop variables.

        Generator targets are assumed to be plain Names (``.id`` is read).
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterables = [self.make_Iterator(gen) for gen in node.generators]
        loop_vars = [ast.Name(gen.target.id, ast.Param(), None)
                     for gen in node.generators]

        if len(iterables) == 1:
            # dim = 1: itertools.product would be useless
            iter_ast = iterables[0]
            lambda_params = [loop_vars[0]]
        else:
            self.use_itertools = True
            product = ast.Attribute(
                value=ast.Name(id='itertools', ctx=ast.Load(),
                               annotation=None),
                attr='product', ctx=ast.Load())
            iter_ast = ast.Call(product, iterables, [])
            lambda_params = [ast.Tuple(loop_vars, ast.Store())]

        lambda_args = ast.arguments(lambda_params, None, [], [], None, [])

        map_func = ast.Attribute(
            value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None),
            attr='map', ctx=ast.Load())
        mapper = ast.Lambda(lambda_args, node.elt)
        return ast.Call(map_func, [mapper, iter_ast], [])
示例#14
0
文件: ast3.py 项目: pombredanne/gast
 def visit_arguments(self, node):
     """Translate a Python 3 ``arguments`` node into ``gast.arguments``.

     Every field is visited except ``posonlyargs``, which is always emitted
     empty. The source location of ``node`` is copied onto the result.
     """
     fields = (
         self._visit(node.args),
         [],  # posonlyargs
         self._visit(node.vararg),
         self._visit(node.kwonlyargs),
         self._visit(node.kw_defaults),
         self._visit(node.kwarg),
         self._visit(node.defaults),
     )
     converted = gast.arguments(*fields)
     return gast.copy_location(converted, node)
示例#15
0
    def visitComp(self, node, make_attr):
        """Rewrite an optimizable comprehension into a mapped lambda.

        Nodes not in ``self.optimizable_comprehension`` are only
        generic-visited. Otherwise ``self.update`` is set and
        ``make_attr(lambda, iterable)`` builds the replacement node. With one
        generator, the lambda takes its loop variable and the iterable is
        ``self.make_Iterator(gen)``. With several, the iterable is
        ``itertools.product`` of those iterators and the lambda takes a
        single parameter. ``node.elt`` is then rewritten by
        ``ConvertToTuple`` (defined elsewhere; presumably it turns each loop
        variable into an index into that parameter -- confirm).

        Generator targets are assumed to be plain Names (``.id`` is read).
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterables = [self.make_Iterator(gen) for gen in node.generators]
        loop_vars = [ast.Name(gen.target.id, ast.Param(), None, None)
                     for gen in node.generators]

        if len(iterables) == 1:
            # dim = 1: itertools.product would be useless
            iter_ast = iterables[0]
            lambda_param = loop_vars[0]
        else:
            self.use_itertools = True
            product = ast.Attribute(
                value=ast.Name(id=mangle('itertools'), ctx=ast.Load(),
                               annotation=None, type_comment=None),
                attr='product',
                ctx=ast.Load())

            tuple_id = loop_vars[0].id  # retarget this id, it's free
            renamings = {var.id: (pos, ) for pos, var in enumerate(loop_vars)}
            node.elt = ConvertToTuple(tuple_id, renamings).visit(node.elt)
            iter_ast = ast.Call(product, iterables, [])
            lambda_param = ast.Name(tuple_id, ast.Param(), None, None)

        lambda_args = ast.arguments([lambda_param], [], None, [], [], None,
                                    [])
        return make_attr(ast.Lambda(lambda_args, node.elt), iter_ast)
示例#16
0
 def visit_arguments(self, node):
     """Translate an ``arguments`` node into a ``gast.arguments`` node.

     ``vararg`` and ``kwarg`` are combined with their separate annotation
     fields (``varargannotation`` / ``kwargannotation``) through
     ``self._make_annotated_arg``. Every other field is visited as is.
     """
     positional = [self._visit(n) for n in node.args]
     vararg = self._make_annotated_arg(
         node, node.vararg, self._visit(node.varargannotation))
     kwonly = [self._visit(n) for n in node.kwonlyargs]
     kw_defaults = self._visit(node.kw_defaults)
     kwarg = self._make_annotated_arg(
         node, node.kwarg, self._visit(node.kwargannotation))
     defaults = self._visit(node.defaults)
     return gast.arguments(positional, vararg, kwonly, kw_defaults, kwarg,
                           defaults)
示例#17
0
    def visit_GeneratorExp(self, node):
        """Rewrite an optimizable generator expression into an IMAP call.

        Nodes not in ``self.optimizable_comprehension`` are only
        generic-visited. Otherwise ``self.update`` is set and the result is
        ``MODULE.IMAP(lambda <params>: <elt>, <iterable>)``. With one
        generator, the lambda takes its loop variable and maps over
        ``self.make_Iterator(gen)``. With several, it maps over
        ``itertools.product`` of those iterables and takes one tuple
        parameter unpacking all loop variables.
        NOTE(review): unlike the list-comprehension rewrite, this does not
        set ``self.use_itertools`` -- confirm the itertools import is
        otherwise guaranteed.
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterables = [self.make_Iterator(gen) for gen in node.generators]
        loop_vars = [ast.Name(gen.target.id, ast.Param(), None)
                     for gen in node.generators]

        if len(iterables) == 1:
            # dim = 1: itertools.product would be useless
            iter_ast = iterables[0]
            lambda_param = loop_vars[0]
        else:
            product = ast.Attribute(
                value=ast.Name(id='itertools', ctx=ast.Load(),
                               annotation=None),
                attr='product',
                ctx=ast.Load())
            iter_ast = ast.Call(product, iterables, [])
            lambda_param = ast.Tuple(loop_vars, ast.Store())

        lambda_args = ast.arguments([lambda_param], None, [], [], None, [])

        imap = ast.Attribute(
            value=ast.Name(id=MODULE, ctx=ast.Load(), annotation=None),
            attr=IMAP,
            ctx=ast.Load())
        mapper = ast.Lambda(lambda_args, node.elt)
        return ast.Call(imap, [mapper, iter_ast], [])
示例#18
0
    def visit_GeneratorExp(self, node):
        """Rewrite an optimizable generator expression into an IMAP call.

        Nodes not in ``self.optimizable_comprehension`` are only
        generic-visited. Otherwise ``self.update`` is set and the result is
        ``ASMODULE.IMAP(lambda <params>: <elt>, <iterable>)``. With one
        generator, the lambda takes its loop variable and maps over
        ``self.make_Iterator(gen)``. With several, it maps over the mangled
        ``itertools.product`` of those iterables and takes one tuple
        parameter unpacking all loop variables.
        NOTE(review): this does not set ``self.use_itertools`` -- confirm
        the itertools import is otherwise guaranteed.
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterables = [self.make_Iterator(gen) for gen in node.generators]
        loop_vars = [ast.Name(gen.target.id, ast.Param(), None)
                     for gen in node.generators]

        if len(iterables) == 1:
            # dim = 1: itertools.product would be useless
            iter_ast = iterables[0]
            lambda_param = loop_vars[0]
        else:
            product = ast.Attribute(
                value=ast.Name(id=mangle('itertools'), ctx=ast.Load(),
                               annotation=None),
                attr='product', ctx=ast.Load())
            iter_ast = ast.Call(product, iterables, [])
            lambda_param = ast.Tuple(loop_vars, ast.Store())

        lambda_args = ast.arguments([lambda_param], None, [], [], None, [])

        imap = ast.Attribute(
            value=ast.Name(id=ASMODULE, ctx=ast.Load(), annotation=None),
            attr=IMAP, ctx=ast.Load())
        mapper = ast.Lambda(lambda_args, node.elt)
        return ast.Call(imap, [mapper, iter_ast], [])
示例#19
0
文件: intrinsic.py 项目: yws/pythran
 def __init__(self, **kwargs):
     """Store the intrinsic's descriptors, taken from ``kwargs``.

     Recognized keys: ``argument_effects`` (default: 11 references to one
     shared ``UpdateEffect()``), ``global_effects`` (default False),
     ``return_alias`` (default: always ``{UnboundValue}``), ``args`` and
     ``defaults``, and ``return_range`` / ``return_range_content`` (default:
     always ``UNKNOWN_RANGE``). ``args`` and ``defaults`` are turned into an
     ``ast.arguments`` node with Param-context names and ``to_ast``-converted
     default values.
     """
     # Built eagerly, exactly like a kwargs.get default argument.
     default_effects = (UpdateEffect(), ) * 11
     self.argument_effects = kwargs.get('argument_effects', default_effects)
     self.global_effects = kwargs.get('global_effects', False)
     self.return_alias = kwargs.get('return_alias',
                                    lambda x: {UnboundValue})
     params = [ast.Name(n, ast.Param(), None)
               for n in kwargs.get('args', [])]
     defaults = [to_ast(d) for d in kwargs.get('defaults', [])]
     self.args = ast.arguments(params, None, [], [], None, defaults)
     self.return_range = kwargs.get("return_range",
                                    lambda call: UNKNOWN_RANGE)
     self.return_range_content = kwargs.get("return_range_content",
                                            lambda c: UNKNOWN_RANGE)
示例#20
0
 def __init__(self, **kwargs):
     """Store the intrinsic's descriptors, taken from ``kwargs``.

     Recognized keys: ``argument_effects`` (default: 11 references to one
     shared ``UpdateEffect()``), ``global_effects`` (default False),
     ``return_alias`` (default: always ``{UnboundValue}``), ``args`` and
     ``defaults``, and ``return_range`` / ``return_range_content`` (default:
     always ``UNKNOWN_RANGE``). ``args`` and ``defaults`` are turned into an
     ``ast.arguments`` node with Param-context names and ``to_ast``-converted
     default values.
     """
     # Built eagerly, exactly like a kwargs.get default argument.
     default_effects = (UpdateEffect(),) * 11
     self.argument_effects = kwargs.get('argument_effects', default_effects)
     self.global_effects = kwargs.get('global_effects', False)
     self.return_alias = kwargs.get('return_alias',
                                    lambda x: {UnboundValue})
     params = [ast.Name(n, ast.Param(), None)
               for n in kwargs.get('args', [])]
     defaults = [to_ast(d) for d in kwargs.get('defaults', [])]
     self.args = ast.arguments(params, None, [], [], None, defaults)
     self.return_range = kwargs.get("return_range",
                                    lambda call: UNKNOWN_RANGE)
     self.return_range_content = kwargs.get("return_range_content",
                                            lambda c: UNKNOWN_RANGE)
示例#21
0
    def visit_AnyComp(self, node, comp_type, *path):
        """Outline a comprehension into a new module-level function.

        The generated function is named ``<comp_type>_comprehension<count>``
        and takes the ids gathered by ``ImportedIds`` on ``node`` as
        parameters. It does the following:
        - initializes ``__target`` with ``builtins.<comp_type>()``;
        - runs the generator loops (built by ``self.nest_reducer``), whose
          innermost statement calls ``path[0].path[1]...(__target, elt)``;
        - returns ``__target``.
        The function is appended to ``self.ctx.module.body`` and tagged with
        ``metadata.Local()``; the comprehension is replaced by a call to it.

        Args:
            node: the comprehension node being outlined.
            comp_type: name of the ``builtins`` attribute called to create
                the accumulator (presumably e.g. 'list' / 'set' -- confirm
                with callers).
            *path: dotted path of the callable invoked for every element,
                e.g. the module name followed by attribute names.

        Returns:
            An ``ast.Call`` to the generated function.
        """
        self.update = True
        # Transform the element expression before it is embedded in the body.
        node.elt = self.visit(node.elt)
        # Fresh function name; bump the counter for the next comprehension.
        name = "{0}_comprehension{1}".format(comp_type, self.count)
        self.count += 1
        # NOTE(review): presumably the free identifiers the comprehension
        # reads from the enclosing scope -- they become the parameters.
        args = self.gather(ImportedIds, node)
        # NOTE(review): presumably a per-comprehension counter consumed by
        # nest_reducer -- confirm.
        self.count_iter = 0

        starget = "__target"
        # Innermost statement `<path>(__target, elt)`, wrapped by
        # nest_reducer in one loop per generator, folding from the last
        # generator outward.
        body = reduce(self.nest_reducer,
                      reversed(node.generators),
                      ast.Expr(
                          ast.Call(
                              reduce(lambda x, y: ast.Attribute(x, y,
                                                                ast.Load()),
                                     path[1:],
                                     ast.Name(path[0], ast.Load(),
                                              None, None)),
                              [ast.Name(starget, ast.Load(), None, None),
                               node.elt],
                              [],
                              )
                          )
                      )
        # add extra metadata to this node
        metadata.add(body, metadata.Comprehension(starget))
        # __target = builtins.<comp_type>()
        init = ast.Assign(
            [ast.Name(starget, ast.Store(), None, None)],
            ast.Call(
                ast.Attribute(
                    ast.Name('builtins', ast.Load(), None, None),
                    comp_type,
                    ast.Load()
                    ),
                [], [],)
            )
        # return __target
        result = ast.Return(ast.Name(starget, ast.Load(), None, None))
        sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
        fd = ast.FunctionDef(name,
                             ast.arguments(sargs, [], None, [], [], None, []),
                             [init, body, result],
                             [], None, None)
        metadata.add(fd, metadata.Local())
        # Register the outlined function at module level.
        self.ctx.module.body.append(fd)
        # Call site: fresh Load-context Name nodes are built rather than
        # reusing the Param nodes stored in the function signature.
        return ast.Call(
            ast.Name(name, ast.Load(), None, None),
            [ast.Name(arg.id, ast.Load(), None, None) for arg in sargs],
            [],
            )  # no sharing !
    def visitComp(self, node, make_attr):
        """Rewrite an optimizable comprehension into a mapped lambda.

        Nodes not in ``self.optimizable_comprehension`` are only
        generic-visited. Otherwise ``self.update`` is set and
        ``make_attr(lambda, iterable)`` builds the replacement node. With one
        generator, the lambda takes its loop variable and the iterable is
        ``self.make_Iterator(gen)``. With several, the iterable is
        ``itertools.product`` of those iterators and the lambda takes a
        single parameter. ``node.elt`` is then rewritten by
        ``ConvertToTuple`` (defined elsewhere; presumably it turns each loop
        variable into an index into that parameter -- confirm).

        Generator targets are assumed to be plain Names (``.id`` is read).
        """
        if node not in self.optimizable_comprehension:
            return self.generic_visit(node)

        self.update = True
        self.generic_visit(node)

        iterables = [self.make_Iterator(gen) for gen in node.generators]
        loop_vars = [ast.Name(gen.target.id, ast.Param(), None)
                     for gen in node.generators]

        if len(iterables) == 1:
            # dim = 1: itertools.product would be useless
            iter_ast = iterables[0]
            lambda_param = loop_vars[0]
        else:
            self.use_itertools = True
            product = ast.Attribute(
                value=ast.Name(id=mangle('itertools'),
                               ctx=ast.Load(),
                               annotation=None),
                attr='product', ctx=ast.Load())

            tuple_id = loop_vars[0].id  # retarget this id, it's free
            renamings = {var.id: (pos,) for pos, var in enumerate(loop_vars)}
            node.elt = ConvertToTuple(tuple_id, renamings).visit(node.elt)
            iter_ast = ast.Call(product, iterables, [])
            lambda_param = ast.Name(tuple_id, ast.Param(), None)

        lambda_args = ast.arguments([lambda_param], None, [], [], None, [])
        return make_attr(ast.Lambda(lambda_args, node.elt), iter_ast)
示例#23
0
 def make_Iterator(self, gen):
     """Return the iterable a mapped comprehension generator iterates over.

     Without ``if`` clauses this is ``gen.iter`` itself. Otherwise it is
     ``ASMODULE.IFILTER(lambda <target>: <cond>, gen.iter)``, where
     ``<cond>`` is the single clause or an ``and`` of all of them.
     ``gen.target`` is assumed to be a plain Name (``.id`` is read).
     """
     if not gen.ifs:
         return gen.iter

     if len(gen.ifs) == 1:
         condition = gen.ifs[0]
     else:
         condition = ast.BoolOp(ast.And(), gen.ifs)
     target = ast.Name(gen.target.id, ast.Param(), None)
     predicate = ast.Lambda(ast.arguments([target], None, [], [], None, []),
                            condition)
     ifilter = ast.Attribute(
         value=ast.Name(id=ASMODULE, ctx=ast.Load(), annotation=None),
         attr=IFILTER, ctx=ast.Load())
     return ast.Call(ifilter, [predicate, gen.iter], [])
示例#24
0
  def generate_FunctionDef(self):
    """Generate a random FunctionDef node.

    Draws, in this order: Param-context argument names (``low=2``,
    ``high=10`` as interpreted by ``sample_node_list``), body statements
    (``low=1``, ``high=N_FUNCTIONDEF_STATEMENTS``) followed by a generated
    return statement, and finally the function name.
    """
    # Arguments come first so they are registered as available names
    # before the body is generated.
    params = self.sample_node_list(
        low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
    signature = gast.arguments(params, None, [], [], None, [])

    statements = self.sample_node_list(
        low=1, high=N_FUNCTIONDEF_STATEMENTS,
        generator=self.generate_statement)
    statements.append(self.generate_Return())

    return gast.FunctionDef(self.generate_Name().id, signature, statements,
                            (), None)
示例#25
0
文件: codegen.py 项目: Harryi0/tinyML
  def generate_FunctionDef(self):
    """Generate a random FunctionDef node.

    Draws, in this order: Param-context argument names (``low=2``,
    ``high=10`` as interpreted by ``sample_node_list``), body statements
    (``low=1``, ``high=N_FUNCTIONDEF_STATEMENTS``) followed by a generated
    return statement, and finally the function name.
    """
    # Arguments come first so they are registered as available names
    # before the body is generated.
    params = self.sample_node_list(
        low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
    signature = gast.arguments(params, None, [], [], None, [])

    statements = self.sample_node_list(
        low=1, high=N_FUNCTIONDEF_STATEMENTS,
        generator=self.generate_statement)
    statements.append(self.generate_Return())

    return gast.FunctionDef(self.generate_Name().id, signature, statements,
                            (), None)
示例#26
0
 def make_Iterator(self, gen):
     """Return the iterable a mapped comprehension generator iterates over.

     Without ``if`` clauses this is ``gen.iter`` itself. Otherwise it is
     ``MODULE.IFILTER(lambda <target>: <cond>, gen.iter)``, where ``<cond>``
     is the single clause or an ``and`` of all of them. ``gen.target`` is
     assumed to be a plain Name (``.id`` is read).
     """
     if not gen.ifs:
         return gen.iter

     if len(gen.ifs) == 1:
         condition = gen.ifs[0]
     else:
         condition = ast.BoolOp(ast.And(), gen.ifs)
     target = ast.Name(gen.target.id, ast.Param(), None)
     predicate = ast.Lambda(ast.arguments([target], None, [], [], None, []),
                            condition)
     ifilter = ast.Attribute(
         value=ast.Name(id=MODULE, ctx=ast.Load(), annotation=None),
         attr=IFILTER,
         ctx=ast.Load())
     return ast.Call(ifilter, [predicate, gen.iter], [])
示例#27
0
def wrap_func_ast(
    name: str,
    args: List[str],
    block: List[AST],
    returns: List[str] = [],
    return_tuple: bool = False,
) -> FunctionDef:
    """Wrap the given code block in a function as a FunctionDef AST node.

    Args:
        name: The name of the function wrapping the block of code.
        args: List of argument names which the wrapping function accepts.
        block: List of AST nodes representing the code block being wrapped by
            the wrapping function. The code block should not contain `return`
            statements. The list itself is not mutated.
        returns: List of variable names to return from the wrapping function.
            Only read, never mutated, so the shared default is harmless.
        return_tuple: Whether to force the wrapping function to return a
            tuple, regardless of whether multiple values are actually returned.
    Returns:
        The created function wrapping the given code block.
    """
    # append return statement if actually returning variables
    if returns:
        if len(returns) > 1 or return_tuple:
            # Return.value is a single expression: the Tuple node itself,
            # not a list containing it.
            return_value = TupleAST(elts=[name_ast(r) for r in returns], ctx=Load())
        else:
            return_value = name_ast(returns[0])
        # build a new list so the caller's block is left untouched
        block = block + [Return(value=return_value)]

    return FunctionDef(
        name=name,
        args=arguments(
            args=[name_ast(a, Param()) for a in args],
            defaults=[],
            posonlyargs=[],
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            vararg=None,
        ),
        body=block,
        decorator_list=[],
        # None is the ast module's value for "no return annotation" and
        # "no type comment"; "" is not an expression node for `returns`.
        returns=None,
        type_comment=None,
    )
示例#28
0
    def visit_Compare(self, node):
        """Replace a chained comparison with a call to a generated helper.

        After visiting the children, a comparison with a single operator is
        returned unchanged. For a chain ``a op0 b op1 c ...``, a FunctionDef
        named ``"<prefix>_compare<k>"`` is appended to
        ``self.compare_functions``, where ``k`` is the number of helpers
        already recorded. The helper takes the sorted identifiers gathered by
        ``ImportedIds`` as parameters. It stores each operand in a temporary
        ``$i`` just before using it, returns 0 at the first failing
        comparison and returns 1 otherwise, so later operands are evaluated
        lazily. The comparison is replaced by a call to that helper.
        """
        node = self.generic_visit(node)
        if len(node.ops) <= 1:
            return node

        captured_ids = sorted(
            self.passmanager.gather(ImportedIds, node, self.ctx))
        helper_name = "{0}_compare{1}".format(self.prefix,
                                              len(self.compare_functions))

        def slot(index, ctx):
            # Temporary holding the index-th operand of the chain.
            return ast.Name('${}'.format(index), ctx, None)

        helper_body = [ast.Assign([slot(0, ast.Store())], node.left)]
        for position, operand in enumerate(node.comparators, start=1):
            helper_body.append(ast.Assign([slot(position, ast.Store())],
                                          operand))
            check = ast.Compare(slot(position - 1, ast.Load()),
                                [node.ops[position - 1]],
                                [slot(position, ast.Load())])
            helper_body.append(ast.If(check,
                                      [ast.Pass()],
                                      [ast.Return(ast.Num(0))]))
        helper_body.append(ast.Return(ast.Num(1)))

        params = ast.arguments(
            [ast.Name(i, ast.Param(), None) for i in captured_ids],
            None, [], [], None, [])
        self.compare_functions.append(
            ast.FunctionDef(helper_name, params, helper_body, [], None))

        return ast.Call(ast.Name(helper_name, ast.Load(), None),
                        [ast.Name(i, ast.Load(), None) for i in captured_ids],
                        [])
示例#29
0
    def visit_AnyComp(self, node, comp_type, *path):
        """Outline a comprehension into a new module-level function.

        The element expression is visited first. A FunctionDef named
        ``"<comp_type>_comprehension<count>"`` is then appended to
        ``self.ctx.module.body``. Its body:

        - initialises ``__target`` with ``__builtin__.<comp_type>()``;
        - runs the generators, nested through ``self.nest_reducer``, and on
          each iteration calls the dotted function named by ``path`` (e.g.
          ``path[0].path[1]...``) with ``(__target, elt)``;
        - returns ``__target``.

        Args:
            node: the comprehension node.
            comp_type: name used in the helper's name and as the
                ``__builtin__`` attribute called to create the result.
            *path: dotted name of the function called for each element.

        Returns:
            A call to the new function. The identifiers gathered by
            ``ImportedIds`` are passed as arguments, in sorted order.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "{0}_comprehension{1}".format(comp_type, self.count)
        self.count += 1
        args = self.passmanager.gather(ImportedIds, node, self.ctx)
        self.count_iter = 0

        starget = "__target"
        # path[0].path[1]...path[-1] as an attribute chain
        callee = reduce(lambda x, y: ast.Attribute(x, y, ast.Load()),
                        path[1:],
                        ast.Name(path[0], ast.Load(), None))
        body = reduce(self.nest_reducer,
                      reversed(node.generators),
                      ast.Expr(
                          ast.Call(
                              callee,
                              [ast.Name(starget, ast.Load(), None), node.elt],
                              [],
                              )
                          )
                      )
        # add extra metadata to this node
        metadata.add(body, metadata.Comprehension(starget))
        init = ast.Assign(
            [ast.Name(starget, ast.Store(), None)],
            ast.Call(
                ast.Attribute(
                    ast.Name('__builtin__', ast.Load(), None),
                    comp_type,
                    ast.Load()
                    ),
                [], [],)
            )
        result = ast.Return(ast.Name(starget, ast.Load(), None))
        # Sort the identifier strings, not the Name nodes: AST nodes define
        # no ordering, so sorting them raises TypeError on Python 3 as soon
        # as two identifiers are captured.
        sargs = [ast.Name(arg, ast.Param(), None) for arg in sorted(args)]
        fd = ast.FunctionDef(name,
                             ast.arguments(sargs, None, [], [], None, []),
                             [init, body, result],
                             [], None)
        self.ctx.module.body.append(fd)
        return ast.Call(
            ast.Name(name, ast.Load(), None),
            [ast.Name(arg.id, ast.Load(), None) for arg in sargs],
            [],
            )  # no sharing !
def outline(name, formal_parameters, out_parameters, stmts,
            has_return, has_break, has_cont):
    """Wrap ``stmts`` into a new ``ast.FunctionDef`` named ``name``.

    Args:
        name: name of the generated function.
        formal_parameters: identifiers that become the function's parameters.
        out_parameters: identifiers returned as a tuple by a ``return`` that
            is appended to the body. Must be empty when ``stmts`` is an
            expression (asserted).
        stmts: either a single ``ast.expr``, in which case the function just
            returns it, or a list of statements used as the function body.
            A list is mutated in place: a trailing ``return`` is appended.
        has_return: when true, ``PatchReturn`` is run over the function.
        has_break: when true (or ``has_cont``), ``PatchBreakContinue`` is run
            over the function.
        has_cont: see ``has_break``.

    Returns:
        The generated ``ast.FunctionDef``.
    """

    args = ast.arguments(
        [ast.Name(fp, ast.Param(), None) for fp in formal_parameters],
        None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        fdef = ast.FunctionDef(name, args, [ast.Return(stmts)], [], None)
    else:
        # NOTE: fdef.body *is* the stmts list (no copy), so the append below
        # also extends the function body.
        fdef = ast.FunctionDef(name, args, stmts, [], None)

        # this is part of a huge trick that plays with delayed type inference
        # it basically computes the return type based on out parameters, and
        # the return statement is unconditionally added so if we have other
        # returns, there will be a computation of the output type based on the
        # __combined of the regular return types and this one The original
        # returns have been patched above to have a different type that
        # cunningly combines with this output tuple
        #
        # This is the only trick I found to let pythran compute both the output
        # variable type and the early return type. But hey, a dirty one :-/

        stmts.append(
            ast.Return(
                ast.Tuple(
                    [ast.Name(fp, ast.Load(), None) for fp in out_parameters],
                    ast.Load()
                )
            )
        )
        if has_return:
            # the appended tuple return is passed to PatchReturn as reference
            pr = PatchReturn(stmts[-1], has_break or has_cont)
            pr.visit(fdef)

        if has_break or has_cont:
            if not has_return:
                # prefix the out-parameter tuple with the LOOP_NONE marker,
                # giving (LOOP_NONE, (out parameters...))
                stmts[-1].value = ast.Tuple([ast.Num(LOOP_NONE),
                                             stmts[-1].value],
                                            ast.Load())
            pbc = PatchBreakContinue(stmts[-1])
            pbc.visit(fdef)

    return fdef
  def test_load_ast(self):
    """loader.load_ast turns a gast tree into source, a module and a file.

    Builds ``def f(a): return a + 1`` as a gast FunctionDef, loads it, and
    checks three things: the returned source text, the loaded function's
    result, and the text stored in ``module.__file__``.
    """
    param = gast.Name('a', ctx=gast.Param(), annotation=None, type_comment=None)
    signature = gast.arguments(
        args=[param],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', ctx=gast.Load(), annotation=None, type_comment=None),
        right=gast.Constant(1, kind=None))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected_source = """
      # coding=utf-8
      def f(a):
          return (a + 1)
    """
    expected = textwrap.dedent(expected_source).strip()

    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
示例#32
0
    def create_lambda_node(func_or_expr_node, is_if_expr=False):
        """Wrap a node in a zero-parameter ``gast.Lambda``.

        Args:
            func_or_expr_node: when ``is_if_expr`` is true, an expression used
                directly as the lambda body. Otherwise a node exposing
                ``.name`` and ``.args`` (presumably a ``gast.FunctionDef`` —
                confirm with callers); the lambda body then becomes the call
                ``name(args)``.
            is_if_expr: selects between the two behaviours above.

        Returns:
            The new ``gast.Lambda`` node.
        """
        body = func_or_expr_node
        if not is_if_expr:
            # NOTE(review): the node's whole ``args`` object is passed as the
            # single call argument; this looks intended for source-level code
            # generation, where it renders as the parameter list — confirm.
            body = gast.Call(func=gast.Name(id=func_or_expr_node.name,
                                            ctx=gast.Load(),
                                            annotation=None,
                                            type_comment=None),
                             args=[func_or_expr_node.args],
                             keywords=[])

        lambda_node = gast.Lambda(args=gast.arguments(args=[],
                                                      posonlyargs=[],
                                                      vararg=None,
                                                      kwonlyargs=[],
                                                      # list-typed field (expr*),
                                                      # like kwonlyargs: use []
                                                      # rather than None
                                                      kw_defaults=[],
                                                      kwarg=None,
                                                      defaults=[]),
                                  body=body)
        return lambda_node
示例#33
0
    def visit_GeneratorExp(self, node):
        """Outline a generator expression into a module-level function.

        The element is visited, then a FunctionDef named
        ``"generator_expression<count>"`` is appended to
        ``self.ctx.module.body``. Its body is the generators, nested through
        ``self.nest_reducer``, around ``yield elt``. The expression is
        replaced by a call to that function, passing the identifiers gathered
        by ``ImportedIds`` in sorted order.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "generator_expression{0}".format(self.count)
        self.count += 1
        args = self.passmanager.gather(ImportedIds, node, self.ctx)
        self.count_iter = 0

        body = reduce(self.nest_reducer, reversed(node.generators),
                      ast.Expr(ast.Yield(node.elt)))

        # Sort the captured identifiers so the generated signature (and the
        # matching call) has a deterministic parameter order; the gathered
        # collection may be unordered.
        sargs = [ast.Name(arg, ast.Param(), None) for arg in sorted(args)]
        fd = ast.FunctionDef(name, ast.arguments(sargs, None, [], [], None,
                                                 []), [body], [], None)
        self.ctx.module.body.append(fd)
        return ast.Call(
            ast.Name(name, ast.Load(), None),
            [ast.Name(arg.id, ast.Load(), None) for arg in sargs],
            [],
        )  # no sharing !
示例#34
0
    def visit_arguments(self, node):
        """Convert an ``arguments`` node into a ``gast.arguments`` node.

        Here ``node.vararg`` and ``node.kwarg`` are bare identifiers (they
        are passed straight to ``ast.Name``). Each one that is set becomes an
        ``ast.Name`` in Param context, with its location copied from
        ``node``, before being visited. The keyword-only argument and
        keyword-default lists are always empty.
        """
        def as_param(identifier):
            # None for a missing *args / **kwargs, else a located Name node.
            if not identifier:
                return None
            param = ast.Name(identifier, ast.Param())
            ast.copy_location(param, node)
            return param

        vararg = as_param(node.vararg)
        kwarg = as_param(node.kwarg)
        return gast.arguments(
            self._visit(node.args),
            self._visit(vararg),
            [],  # kwonlyargs
            [],  # kw_defaults
            self._visit(kwarg),
            self._visit(node.defaults),
        )
示例#35
0
    def visit_Module(self, node):
        """Turn globals assignment to functionDef and visit function defs.

        First pass: every top-level ``Assign`` target is recorded in
        ``self.to_expand``. ``PythranSyntaxError`` is raised if a target is
        not a plain name or is assigned more than once.

        Second pass: each assignment becomes one zero-argument FunctionDef
        per target, returning the visited value. Its return is tagged with
        ``metadata.StaticReturn``. Every other statement is visited with
        ``self.local_decl`` set to its ``LocalNameDeclarations``.
        """
        assignments = [s for s in node.body if isinstance(s, ast.Assign)]
        for assign in assignments:
            for target in assign.targets:
                if not isinstance(target, ast.Name):
                    raise PythranSyntaxError(
                        "Top-level assignment to an expression.",
                        target)
                if target.id in self.to_expand:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % target.id,
                        target)
                self.to_expand.add(target.id)

        new_body = []
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                self.local_decl = self.passmanager.gather(
                    LocalNameDeclarations, stmt,
                    self.ctx)
                new_body.append(self.visit(stmt))
                continue

            self.local_decl = set()
            value = self.visit(stmt.value)
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                ret = ast.Return(value=value)
                new_body.append(
                    ast.FunctionDef(target.id,
                                    ast.arguments([], None,
                                                  [], [], None, []),
                                    [ret],
                                    [], None))
                metadata.add(ret, metadata.StaticReturn())

        node.body = new_body
        return node
示例#36
0
    def visit_GeneratorExp(self, node):
        """Outline a generator expression into a module-level function.

        Visits the element, appends a FunctionDef named
        ``"generator_expression<count>"`` (the generators nested through
        ``self.nest_reducer`` around ``yield elt``) to
        ``self.ctx.module.body``, and returns a call to it. The identifiers
        gathered by ``ImportedIds`` are passed in sorted order.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "generator_expression{0}".format(self.count)
        self.count += 1
        args = self.passmanager.gather(ImportedIds, node, self.ctx)
        self.count_iter = 0

        body = reduce(self.nest_reducer,
                      reversed(node.generators),
                      ast.Expr(ast.Yield(node.elt))
                      )

        # deterministic parameter order: the gathered identifiers may come
        # from an unordered collection
        sargs = [ast.Name(arg, ast.Param(), None) for arg in sorted(args)]
        fd = ast.FunctionDef(name,
                             ast.arguments(sargs, None, [], [], None, []),
                             [body], [], None)
        self.ctx.module.body.append(fd)
        return ast.Call(
            ast.Name(name, ast.Load(), None),
            [ast.Name(arg.id, ast.Load(), None) for arg in sargs],
            [],
            )  # no sharing !
示例#37
0
    def visit_Module(self, node):
        """Turn globals assignment to functionDef and visit function defs.

        Top-level assignment targets are registered in ``self.to_expand``,
        rejecting non-name targets and duplicates with
        ``PythranSyntaxError``. Each assignment is then replaced by one
        zero-argument FunctionDef per target that returns the visited value
        (tagged ``metadata.StaticReturn``). Other statements are visited with
        ``self.local_decl`` set to their ``LocalNameDeclarations``.
        """
        def declare(target):
            # Register one top-level assignment target.
            if not isinstance(target, ast.Name):
                raise PythranSyntaxError(
                    "Top-level assignment to an expression.", target)
            if target.id in self.to_expand:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % target.id,
                    target)
            self.to_expand.add(target.id)

        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                for target in stmt.targets:
                    declare(target)

        rewritten = []
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                self.local_decl = set()
                value = self.visit(stmt.value)
                for target in stmt.targets:
                    assert isinstance(target, ast.Name)
                    getter = ast.FunctionDef(
                        target.id, ast.arguments([], None, [], [], None, []),
                        [ast.Return(value=value)], [], None)
                    rewritten.append(getter)
                    metadata.add(getter.body[0], metadata.StaticReturn())
            else:
                self.local_decl = self.passmanager.gather(
                    LocalNameDeclarations, stmt, self.ctx)
                rewritten.append(self.visit(stmt))

        node.body = rewritten
        return node
示例#38
0
    def get_while_stmt_nodes(self, node):
        """Lower a ``while`` loop into condition/body functions and loop nodes.

        Returns a list of statements, in order:

        - static-variable declarations for non-dotted variables created in
          the loop;
        - a condition function returning ``node.test``;
        - a body function running ``node.body`` (extended in place with a
          return of the loop variables);
        - the nodes produced by ``create_while_nodes``.
        """
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        def make_loop_function(prefix, body):
            # FunctionDef named unique_name.generate(prefix) taking every loop
            # variable as a parameter; dotted loop-variable names are then
            # renamed inside it to freshly generated identifiers.
            func_node = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(
                            id=name,
                            ctx=gast.Param(),
                            annotation=None,
                            type_comment=None) for name in loop_var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # list-typed field (expr*), like kwonlyargs: [] not None
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for name in loop_var_names:
                if "." in name:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        condition_func_node = make_loop_function(
            WHILE_CONDITION_PREFIX, [gast.Return(value=node.test)])
        new_stmts.append(condition_func_node)

        # node.body is extended in place with the loop-variable return
        new_body = node.body
        new_body.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load())))
        body_func_node = make_loop_function(WHILE_BODY_PREFIX, new_body)
        new_stmts.append(body_func_node)

        while_loop_nodes = create_while_nodes(
            condition_func_node.name, body_func_node.name, loop_var_names)
        new_stmts.extend(while_loop_nodes)
        return new_stmts
示例#39
0
    def visit_Module(self, node):
        """Turn globals assignment to functionDef and visit function defs.

        First pass, in statement order:

        - imported names and function names are collected as top-level
          symbols; a function name that is already a symbol raises
          ``PythranSyntaxError``;
        - assignment targets must be plain, not-yet-expanded names. They are
          recorded in ``self.to_expand`` unless the assigned value is a bare
          name of an already-collected symbol (pure aliasing).

        Second pass: assignments whose targets are all unexpanded names are
        kept as-is. Other assignments become one zero-argument FunctionDef
        per target returning the visited, ``GlobalTransformer``-processed
        value, tagged ``metadata.StaticReturn``. Remaining statements are
        visited with ``self.local_decl`` set from ``LocalNameDeclarations``.
        """
        symbols = set()
        # Gather top level assigned variables.
        for stmt in node.body:
            if isinstance(stmt, (ast.Import, ast.ImportFrom)):
                # no warning here
                symbols.update(alias.asname or alias.name
                               for alias in stmt.names)
                continue
            if isinstance(stmt, ast.FunctionDef):
                if stmt.name in symbols:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % stmt.name,
                        stmt)
                symbols.add(stmt.name)
                continue
            if not isinstance(stmt, ast.Assign):
                continue

            # `x = some_symbol` only creates aliasing between top-level symbols
            is_symbol_alias = (isinstance(stmt.value, ast.Name) and
                               stmt.value.id in symbols)
            for target in stmt.targets:
                if not isinstance(target, ast.Name):
                    raise PythranSyntaxError(
                        "Top-level assignment to an expression.", target)
                if target.id in self.to_expand:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % target.id,
                        target)
                if not is_symbol_alias:
                    self.to_expand.add(target.id)

        module_body = []
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                self.local_decl = self.gather(LocalNameDeclarations, stmt)
                module_body.append(self.visit(stmt))
                continue

            # that's not a global var, but a module/function aliasing
            if all(isinstance(t, ast.Name) and t.id not in self.to_expand
                   for t in stmt.targets):
                module_body.append(stmt)
                continue

            self.local_decl = set()
            cst_value = GlobalTransformer().visit(self.visit(stmt.value))
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                ret = ast.Return(value=cst_value)
                module_body.append(
                    ast.FunctionDef(
                        target.id,
                        ast.arguments([], [], None, [], [], None, []),
                        [ret], [], None, None))
                metadata.add(ret, metadata.StaticReturn())

        self.update |= bool(self.to_expand)

        node.body = module_body
        return node
示例#40
0
    def get_for_stmt_nodes(self, node):
        """Lower a ``for`` loop into init statements, loop functions and a while node.

        Returns ``[node]`` unchanged when ``ForNodeVisitor`` cannot parse the
        loop. Otherwise returns, in order:

        - static-variable declarations for non-dotted variables created in
          the loop;
        - the parser's init statements;
        - a condition function returning the parser's condition;
        - a body function running the parser's body (extended in place with
          a return of the loop variables);
        - the node produced by ``create_while_node``.
        """
        # TODO: consider for - else in python

        # 1. get key statements for different cases
        # NOTE 1: three key statements:
        #   1). init_stmts: list[node], prepare nodes of for loop, may not only one
        #   2). cond_stmt: node, condition node to judge whether continue loop
        #   3). body_stmts: list[node], updated loop body, sometimes we should change
        #       the original statement in body, not just append new statement
        #
        # NOTE 2: The following `for` statements will be transformed to `while` statements:
        #   1). for x in range(*)
        #   2). for x in iter_var
        #   3). for i, x in enumerate(*)

        current_for_node_parser = ForNodeVisitor(node)
        stmts_tuple = current_for_node_parser.parse()
        if stmts_tuple is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = stmts_tuple

        # 2. get original loop vars
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
        # we need append new loop var & remove useless loop var
        #   1. for x in var -> x is no need
        #   2. for i, x in enumerate(var) -> x is no need
        if (current_for_node_parser.is_for_iter()
                or current_for_node_parser.is_for_enumerate_iter()):
            iter_var_name = current_for_node_parser.iter_var_name
            iter_idx_name = current_for_node_parser.iter_idx_name
            loop_var_names.add(iter_idx_name)
            if iter_var_name not in create_var_names:
                loop_var_names.remove(iter_var_name)

        # 3. prepare result statement list
        new_stmts = []
        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # 4. append init statements
        new_stmts.extend(init_stmts)

        def make_loop_function(prefix, body):
            # FunctionDef named unique_name.generate(prefix) taking every loop
            # variable as a parameter; dotted loop-variable names are then
            # renamed inside it to freshly generated identifiers.
            func_node = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(id=name,
                                  ctx=gast.Param(),
                                  annotation=None,
                                  type_comment=None) for name in loop_var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # list-typed field (expr*), like kwonlyargs: [] not None
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for name in loop_var_names:
                if "." in name:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        # 5. create & append condition function node
        condition_func_node = make_loop_function(
            FOR_CONDITION_PREFIX, [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # 6. create & append loop body function node
        # append return values for loop body
        body_stmts.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = make_loop_function(FOR_BODY_PREFIX, body_stmts)
        new_stmts.append(body_func_node)

        # 7. create & append while loop node
        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)

        return new_stmts
示例#41
0
    def get_while_stmt_nodes(self, node):
        """Lower a ``while`` loop into condition/body functions and a while node.

        Returns ``[node]`` unchanged when the name visitor does not report
        it as a control-flow loop. Otherwise returns, in order:

        - static-variable declarations for non-dotted created variables;
        - ``to_static_variable_gast_node`` statements for every loop variable;
        - a condition function returning the ``LogicalOpTransformer`` output
          for ``node.test``;
        - a body function running ``node.body`` (extended in place with a
          return of the loop variables);
        - the node produced by ``create_while_node``.
        """
        # TODO: consider while - else in python
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # while x < 10 in dygraph should be converted into static tensor < 10
        for name in loop_var_names:
            new_stmts.append(to_static_variable_gast_node(name))

        logical_op_transformer = LogicalOpTransformer(node.test)
        cond_value_node = logical_op_transformer.transform()

        def make_loop_function(prefix, body):
            # FunctionDef named unique_name.generate(prefix) taking every loop
            # variable as a parameter; dotted loop-variable names are then
            # renamed inside it to freshly generated identifiers.
            func_node = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(id=name,
                                  ctx=gast.Param(),
                                  annotation=None,
                                  type_comment=None) for name in loop_var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # list-typed field (expr*), like kwonlyargs: [] not None
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for name in loop_var_names:
                if "." in name:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        condition_func_node = make_loop_function(
            WHILE_CONDITION_PREFIX, [gast.Return(value=cond_value_node)])
        new_stmts.append(condition_func_node)

        # node.body is extended in place with the loop-variable return
        new_body = node.body
        new_body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = make_loop_function(WHILE_BODY_PREFIX, new_body)
        new_stmts.append(body_func_node)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)
        return new_stmts
示例#42
0
    def get_for_stmt_nodes(self, node):
        """Lower a ``for ... in range(...)`` loop into functions and a while node.

        Returns ``[node]`` unchanged when any of these holds:

        - the loop is not reported as a control-flow loop;
        - ``get_for_range_node`` finds no range call;
        - the loop target is not a plain ``gast.Name``.

        Otherwise returns, in order:

        - static-variable declarations for non-dotted created variables;
        - the init statement;
        - ``to_static_variable_gast_node`` statements for every loop variable;
        - a condition function;
        - a body function (``node.body`` extended in place with the change
          statement and a return of the loop variables);
        - the node produced by ``create_while_node``.
        """
        # TODO: consider for - else in python
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        # TODO: support non-range case
        range_call_node = self.get_for_range_node(node)
        if range_call_node is None:
            return [node]

        if not isinstance(node.target, gast.Name):
            return [node]
        iter_var_name = node.target.id

        init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
            iter_var_name, range_call_node.args)

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []
        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        new_stmts.append(init_stmt)

        # for x in range(10) in dygraph should be converted into static tensor + 1 <= 10
        for name in loop_var_names:
            new_stmts.append(to_static_variable_gast_node(name))

        def make_loop_function(prefix, body):
            # FunctionDef named unique_name.generate(prefix) taking every loop
            # variable as a parameter; dotted loop-variable names are then
            # renamed inside it to freshly generated identifiers.
            func_node = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(id=name,
                                  ctx=gast.Param(),
                                  annotation=None,
                                  type_comment=None) for name in loop_var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # list-typed field (expr*), like kwonlyargs: [] not None
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for name in loop_var_names:
                if "." in name:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        condition_func_node = make_loop_function(
            FOR_CONDITION_PREFIX, [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # node.body is extended in place with the step and the return
        new_body = node.body
        new_body.append(change_stmt)
        new_body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = make_loop_function(FOR_BODY_PREFIX, new_body)
        new_stmts.append(body_func_node)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)

        return new_stmts
示例#43
0
    def visit_Module(self, node):
        """Turn globals assignment to functionDef and visit function defs.

        Works in two passes over ``node.body``.

        Pass 1 collects the names bound by top-level imports and function
        definitions. It rejects a function whose name was already collected.
        It also records in ``self.to_expand`` every top-level assignment
        target that is not a plain alias of an already-collected symbol.

        Pass 2 rebuilds the body:

        * Assignments whose targets are all plain names absent from
          ``self.to_expand`` (module/function aliasing) are kept unchanged.
        * Every other assignment ``x = value`` becomes a zero-argument
          ``FunctionDef`` named ``x`` whose single statement is
          ``return <visited value>``. That ``Return`` is tagged with
          ``metadata.StaticReturn()``.
        * Any other statement is visited, with ``self.local_decl`` set to
          what ``LocalNameDeclarations`` gathers for it.

        Raises:
            PythranSyntaxError: on a repeated top-level definition, or on a
                top-level assignment whose target is not a plain name.

        Returns:
            ``node`` itself, with its ``body`` replaced by the rebuilt list.
        """
        seen = set()

        # Pass 1: gather top-level symbols and the globals to expand.
        for stmt in node.body:
            if isinstance(stmt, (ast.Import, ast.ImportFrom)):
                # Imported names are recorded without a duplicate check.
                seen.update(alias.asname or alias.name for alias in stmt.names)
            elif isinstance(stmt, ast.FunctionDef):
                if stmt.name in seen:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % stmt.name,
                        stmt)
                seen.add(stmt.name)
            elif isinstance(stmt, ast.Assign):
                # `x = some_known_symbol` only aliases that symbol; such
                # targets are not expanded. `seen` is not modified in the
                # target loop below, so this test can be computed once.
                is_alias = (isinstance(stmt.value, ast.Name) and
                            stmt.value.id in seen)
                for target in stmt.targets:
                    if not isinstance(target, ast.Name):
                        raise PythranSyntaxError(
                            "Top-level assignment to an expression.",
                            target)
                    if target.id in self.to_expand:
                        raise PythranSyntaxError(
                            "Multiple top-level definition of %s." % target.id,
                            target)
                    if not is_alias:
                        self.to_expand.add(target.id)

        # Pass 2: rebuild the module body.
        new_body = []
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                self.local_decl = self.passmanager.gather(
                    LocalNameDeclarations, stmt,
                    self.ctx)
                new_body.append(self.visit(stmt))
                continue

            # That's not a global var, but a module/function aliasing:
            # keep the assignment as written.
            if all(isinstance(t, ast.Name) and t.id not in self.to_expand
                   for t in stmt.targets):
                new_body.append(stmt)
                continue

            self.local_decl = set()
            # The value is visited once; the resulting node is shared by the
            # bodies of every function generated for this statement.
            value = self.visit(stmt.value)
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                ret = ast.Return(value=value)
                new_body.append(
                    ast.FunctionDef(target.id,
                                    ast.arguments([], None,
                                                  [], [], None, []),
                                    [ret],
                                    [], None))
                metadata.add(ret, metadata.StaticReturn())

        node.body = new_body
        return node