Exemplo n.º 1
0
def outline(name, formal_parameters, out_parameters, stmts, has_return):
    """Build a function definition named ``name`` around ``stmts``.

    Args:
        name: identifier of the generated function.
        formal_parameters: identifiers that become the function's positional
            parameters, in the given order.
        out_parameters: identifiers returned together as a tuple at the end
            of the body. Must be empty when ``stmts`` is an expression.
        stmts: either a single ``ast.expr`` (its value is returned) or a list
            of statements used directly as the body. A list is mutated in
            place: a trailing ``return (out_parameters...)`` is appended.
        has_return: when true, ``PatchReturn`` (parameterized by the appended
            return statement) is run over the new function.

    Returns:
        The new ``ast.FunctionDef`` node.
    """
    args = ast.arguments(
        [ast.Name(fp, ast.Param(), None) for fp in formal_parameters], None,
        [], [], None, [])

    if isinstance(stmts, ast.expr):
        # An expression body can only produce its own value.
        assert not out_parameters, "no out parameters with expr"
        fdef = ast.FunctionDef(name, args, [ast.Return(stmts)], [], None)
    else:
        # ``stmts`` is the very list used as the function body, so the append
        # below also extends ``fdef.body``.
        fdef = ast.FunctionDef(name, args, stmts, [], None)

        # This is part of a huge trick that plays with delayed type inference.
        # It basically computes the return type based on out parameters, and
        # the return statement is unconditionally added, so if we have other
        # returns, there will be a computation of the output type based on the
        # __combined of the regular return types and this one. The original
        # returns have been patched (elsewhere) to have a different type that
        # cunningly combines with this output tuple.
        #
        # This is the only trick I found to let pythran compute both the output
        # variable type and the early return type. But hey, a dirty one :-/

        stmts.append(
            ast.Return(
                ast.Tuple(
                    [ast.Name(fp, ast.Load(), None) for fp in out_parameters],
                    ast.Load())))
        if has_return:
            # Rewrite the function's other returns relative to the return
            # statement that was just appended.
            pr = PatchReturn(stmts[-1])
            pr.visit(fdef)

    return fdef
Exemplo n.º 2
0
    def visit_Lambda(self, node):
        """Replace a lambda with a reference to an equivalent forged function.

        The lambda becomes a ``FunctionDef`` named ``<prefix>_lambda<N>``,
        recorded in ``self.lambda_functions`` and ``self.global_declarations``.
        Identifiers the lambda reads from its enclosing scope (as reported by
        ``ImportedIds``) are prepended to its parameters. They are bound at
        the use site through ``functools.partial``.

        Returns:
            ``functools.partial(<forged>, *captured)`` when the lambda captures
            names, otherwise a plain ``Name`` referring to the forged function.
        """
        # Make sure a (mangled) functools import exists for partial().
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias('functools', mangle('functools'))])
            self.imports.append(import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        # Transform children first, so nested lambdas are already outlined.
        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        ii = self.passmanager.gather(ImportedIds, node, self.ctx)
        # Remove current lambdas. ``ii`` holds identifier strings while
        # ``lambda_functions`` holds FunctionDef nodes, so the forged *names*
        # must be subtracted (subtracting the nodes never removed anything).
        ii.difference_update(fdef.name for fdef in self.lambda_functions)

        # Captured identifiers become leading parameters of the forged
        # function; fresh Name nodes are built for the call site.
        binded_args = [ast.Name(iin, ast.Load(), None) for iin in sorted(ii)]
        node.args.args = (
            [ast.Name(iin, ast.Param(), None)
             for iin in sorted(ii)] + node.args.args)
        forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                      [ast.Return(node.body)], [], None)
        self.lambda_functions.append(forged_fdef)
        self.global_declarations[forged_name] = forged_fdef
        proxy_call = ast.Name(forged_name, ast.Load(), None)
        if binded_args:
            return ast.Call(
                ast.Attribute(ast.Name(mangle('functools'), ast.Load(), None),
                              "partial", ast.Load()),
                [proxy_call] + binded_args, [])
        else:
            return proxy_call
  def test_ast_to_object(self):
    # Hand-built gast tree for:  def f(a): return a + 1
    increment = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    signature = gast.arguments(
        args=[gast.Name('a', gast.Param(), None)],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(increment)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    expected_source = """
      # coding=utf-8
      def f(a):
        return a + 1
    """
    expected = textwrap.dedent(expected_source).strip()
    # The returned source, the loaded module and the file on disk must agree.
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
Exemplo n.º 4
0
    def visit_Lambda(self, node):
        """Replace a lambda with an operator, or with a forged function.

        If ``issimpleoperator`` recognizes the lambda, it is replaced by the
        matching attribute of a mangled ``operator`` import. Otherwise the
        lambda is outlined into a ``FunctionDef`` named
        ``<prefix>_lambda<N>``. The function is reused when ``issamelambda``
        matches a previously recorded pattern. Identifiers captured from the
        enclosing scope (``ImportedIds``) become leading parameters and are
        bound with ``functools.partial``.

        Returns:
            An ``operator.<op>`` attribute, a ``functools.partial(...)`` call,
            or a plain ``Name`` referring to the forged function.
        """
        op = issimpleoperator(node)
        if op is not None:
            if mangle('operator') not in self.global_declarations:
                import_ = ast.Import(
                    [ast.alias('operator', mangle('operator'))])
                self.imports.append(import_)
                operator_module = MODULES['operator']
                self.global_declarations[mangle('operator')] = operator_module
            return ast.Attribute(
                ast.Name(mangle('operator'), ast.Load(), None, None), op,
                ast.Load())

        # Transform children first, so nested lambdas are already outlined.
        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        ii = self.gather(ImportedIds, node)
        # Remove current lambdas. ``ii`` holds identifier strings while
        # ``lambda_functions`` holds FunctionDef nodes, so the forged *names*
        # must be subtracted (subtracting the nodes never removed anything).
        ii.difference_update(fdef.name for fdef in self.lambda_functions)

        binded_args = [
            ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)
        ]
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None)
             for iin in sorted(ii)] + node.args.args)
        # Reuse an already forged function for an identical lambda.
        for patternname, pattern in self.patterns.items():
            if issamelambda(pattern, node):
                proxy_call = ast.Name(patternname, ast.Load(), None, None)
                break
        else:
            # Record a def-use-annotated copy of this lambda for later
            # comparisons, then forge its function.
            duc = ExtendedDefUseChains()
            nodepattern = deepcopy(node)
            duc.visit(ast.Module([ast.Expr(nodepattern)], []))
            self.patterns[forged_name] = nodepattern, duc

            forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                          [ast.Return(node.body)], [], None,
                                          None)
            metadata.add(forged_fdef, metadata.Local())
            self.lambda_functions.append(forged_fdef)
            self.global_declarations[forged_name] = forged_fdef
            proxy_call = ast.Name(forged_name, ast.Load(), None, None)

        if binded_args:
            # functools is only imported when partial() is actually needed.
            if MODULES['functools'] not in self.global_declarations.values():
                import_ = ast.Import(
                    [ast.alias('functools', mangle('functools'))])
                self.imports.append(import_)
                functools_module = MODULES['functools']
                self.global_declarations[mangle(
                    'functools')] = functools_module

            return ast.Call(
                ast.Attribute(
                    ast.Name(mangle('functools'), ast.Load(), None, None),
                    "partial", ast.Load()), [proxy_call] + binded_args, [])
        else:
            return proxy_call
Exemplo n.º 5
0
 def visit_FunctionDef(self, node):
     """Translate a FunctionDef into a ``gast.FunctionDef``.

     The name, arguments, body and decorators are converted with
     ``self._visit`` in that order. The return annotation is not carried
     over (``returns`` is always None). The source location of ``node`` is
     copied onto the result.
     """
     convert = self._visit
     converted_name = convert(node.name)
     converted_args = convert(node.args)
     converted_body = convert(node.body)
     converted_decorators = convert(node.decorator_list)
     translated = gast.FunctionDef(converted_name, converted_args,
                                   converted_body, converted_decorators,
                                   None)
     ast.copy_location(translated, node)
     return translated
Exemplo n.º 6
0
    def visit_Compare(self, node):
        """Outline chained comparisons into an auxiliary function.

        A comparison with a single operator is returned unchanged (after
        visiting its children). For ``a op1 b op2 c ...``, a function named
        ``<prefix>_compare<N>`` is recorded in ``self.compare_functions``.
        Its parameters are the identifiers ``ImportedIds`` reports for the
        expression. The comparison node is replaced by a call to it.

        The generated body evaluates each operand at most once, and only
        when needed: each non-trivial comparator is assigned to a temporary
        right before its comparison, and the function returns False at the
        first failing comparison. It returns True if all comparisons hold.
        """
        node = self.generic_visit(node)
        if len(node.ops) > 1:
            # in case we have more than one compare operator
            # we generate an auxiliary function
            # that lazily evaluates the needed parameters
            imported_ids = self.passmanager.gather(ImportedIds, node, self.ctx)
            # Sorted so that parameter order is deterministic.
            imported_ids = sorted(imported_ids)
            binded_args = [ast.Name(i, ast.Load(), None) for i in imported_ids]

            # name of the new function
            forged_name = "{0}_compare{1}".format(self.prefix,
                                                  len(self.compare_functions))

            # call site
            call = ast.Call(ast.Name(forged_name, ast.Load(), None),
                            binded_args, [])

            # new function
            arg_names = [ast.Name(i, ast.Param(), None) for i in imported_ids]
            args = ast.arguments(arg_names, None, [], [], None, [])

            body = []  # iteratively fill the body (yeah, feel your body!)

            # Temporaries are named '$0', '$1', ... — presumably chosen because
            # '$' cannot appear in user identifiers; TODO confirm.
            if is_trivially_copied(node.left):
                prev_holder = node.left
            else:
                body.append(
                    ast.Assign([ast.Name('$0', ast.Store(), None)], node.left))
                prev_holder = ast.Name('$0', ast.Load(), None)

            for i, exp in enumerate(node.comparators):
                # NOTE(review): a trivially-copied operand is used as the same
                # AST node object in two Compare nodes (here and as the next
                # left-hand side); confirm that node sharing is acceptable.
                if is_trivially_copied(exp):
                    holder = exp
                else:
                    body.append(
                        ast.Assign(
                            [ast.Name('${}'.format(i + 1), ast.Store(), None)],
                            exp))
                    holder = ast.Name('${}'.format(i + 1), ast.Load(), None)
                cond = ast.Compare(prev_holder, [node.ops[i]], [holder])
                # if cond: pass / else: return False — short-circuits the chain.
                body.append(
                    ast.If(
                        cond, [ast.Pass()],
                        [ast.Return(path_to_attr(('__builtin__', 'False')))]))
                prev_holder = holder

            body.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

            forged_fdef = ast.FunctionDef(forged_name, args, body, [], None)
            self.compare_functions.append(forged_fdef)

            return call
        else:
            return node
Exemplo n.º 7
0
 def visit_FunctionDef(self, node):
     """Translate a FunctionDef into a ``gast.FunctionDef``.

     The name, arguments, body and decorators are converted with
     ``self._visit`` in that order. Both ``returns`` and ``type_comment``
     are set to None. The location is copied from ``node``, while the end
     position fields are cleared.
     """
     convert = self._visit
     converted_name = convert(node.name)
     converted_args = convert(node.args)
     converted_body = convert(node.body)
     converted_decorators = convert(node.decorator_list)
     translated = gast.FunctionDef(
         converted_name,
         converted_args,
         converted_body,
         converted_decorators,
         None,  # returns
         None,  # type_comment
     )
     gast.copy_location(translated, node)
     translated.end_lineno = None
     translated.end_col_offset = None
     return translated
Exemplo n.º 8
0
 def root_build(body):
     """Wrap a list of statements as the body of a function inside a module.

     The function is named ``random_function`` and takes the two arguments
     ``a`` and ``b``.
     """
     make_name = python_numbers_control_flow.make_name
     arguments = _make_arguments(make_name("a"), make_name("b"))
     wrapper = gast.FunctionDef(name="random_function",
                                args=arguments,
                                body=body,
                                decorator_list=[],
                                returns=None,
                                type_comment=None)
     return gast.Module(body=[wrapper], type_ignores=[])
Exemplo n.º 9
0
    def visit_AnyComp(self, node, comp_type, *path):
        """Outline a comprehension into a helper function appended to the module.

        The generated function is named ``<comp_type>_comprehension<N>`` and
        takes the identifiers ``ImportedIds`` reports for ``node`` as
        parameters. Its body:
          * initializes ``__target = builtins.<comp_type>()``,
          * runs the generator loops nested by ``self.nest_reducer``, whose
            innermost statement calls ``<path[0]>.<path[1]>...(__target, elt)``,
          * returns ``__target``.
        The comprehension is replaced by a call to that function.

        Args:
            node: the comprehension node (uses ``elt`` and ``generators``).
            comp_type: builtin name used both in the function name and as the
                ``builtins`` attribute called to create the accumulator
                (presumably e.g. 'list' or 'set' — confirm with callers).
            path: components of the dotted callable invoked once per element
                with ``(__target, elt)``.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "{0}_comprehension{1}".format(comp_type, self.count)
        self.count += 1
        args = self.gather(ImportedIds, node)
        # Reset before nest_reducer runs below (it presumably uses this counter).
        self.count_iter = 0

        starget = "__target"
        # Innermost statement: <path...>(__target, elt), wrapped in one loop
        # per generator (reversed so the first generator ends up outermost).
        body = reduce(self.nest_reducer,
                      reversed(node.generators),
                      ast.Expr(
                          ast.Call(
                              reduce(lambda x, y: ast.Attribute(x, y,
                                                                ast.Load()),
                                     path[1:],
                                     ast.Name(path[0], ast.Load(),
                                              None, None)),
                              [ast.Name(starget, ast.Load(), None, None),
                               node.elt],
                              [],
                              )
                          )
                      )
        # add extra metadata to this node
        metadata.add(body, metadata.Comprehension(starget))
        # __target = builtins.<comp_type>()
        init = ast.Assign(
            [ast.Name(starget, ast.Store(), None, None)],
            ast.Call(
                ast.Attribute(
                    ast.Name('builtins', ast.Load(), None, None),
                    comp_type,
                    ast.Load()
                    ),
                [], [],)
            )
        result = ast.Return(ast.Name(starget, ast.Load(), None, None))
        sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
        fd = ast.FunctionDef(name,
                             ast.arguments(sargs, [], None, [], [], None, []),
                             [init, body, result],
                             [], None, None)
        metadata.add(fd, metadata.Local())
        self.ctx.module.body.append(fd)
        # Fresh Name nodes for the call, so no node is shared with ``sargs``.
        return ast.Call(
            ast.Name(name, ast.Load(), None, None),
            [ast.Name(arg.id, ast.Load(), None, None) for arg in sargs],
            [],
            )  # no sharing !
Exemplo n.º 10
0
  def generate_FunctionDef(self):
    """Generate a random FunctionDef node.

    The function has 2-10 freshly generated parameters. Its body holds
    between 1 and ``N_FUNCTIONDEF_STATEMENTS`` random statements followed by
    a generated return, and it has no decorators and no return annotation.
    """

    # Generate the arguments, register them as available
    arg_vars = self.sample_node_list(
        low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
    args = gast.arguments(arg_vars, None, [], [], None, [])

    # Generate the function body
    body = self.sample_node_list(
        low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement)
    body.append(self.generate_Return())
    fn_name = self.generate_Name().id
    # decorator_list is a list field (expr*) in the AST grammar; CPython's
    # compile() rejects a tuple there, so use an empty list.
    node = gast.FunctionDef(fn_name, args, body, [], None)
    return node
Exemplo n.º 11
0
def create_funcDef_node(nodes, name, input_args, return_name_ids):
    """
    Wrap all statements of ``nodes`` into a single gast.FunctionDef that can
    later be invoked through an ast.Call.

    A ``return`` statement is appended to a shallow copy of ``nodes``. It
    returns ``generate_name_node(return_name_ids)`` when ``return_name_ids``
    is truthy, and is a bare ``return`` otherwise.
    """
    body = copy.copy(nodes)
    if not return_name_ids:
        return_stmt = gast.Return(value=None)
    else:
        return_stmt = gast.Return(value=generate_name_node(return_name_ids))
    body.append(return_stmt)
    return gast.FunctionDef(name=name,
                            args=input_args,
                            body=body,
                            decorator_list=[],
                            returns=None,
                            type_comment=None)
Exemplo n.º 12
0
  def test_load_ast(self):
    # Hand-built gast tree for:  def f(a): return a + 1
    param = gast.Name(
        'a', ctx=gast.Param(), annotation=None, type_comment=None)
    increment = gast.BinOp(
        op=gast.Add(),
        left=gast.Name(
            'a', ctx=gast.Load(), annotation=None, type_comment=None),
        right=gast.Constant(1, kind=None))
    signature = gast.arguments(
        args=[param],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(increment)],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected_source = """
      # coding=utf-8
      def f(a):
          return (a + 1)
    """
    expected = textwrap.dedent(expected_source).strip()
    # The returned source, the loaded module and the file on disk must agree.
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
Exemplo n.º 13
0
def joint(node):
  """Merge the bodies of primal and adjoint into a single function.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`.

  Returns:
    func: A `Module` node with a single function definition. It keeps the
        primal's name and takes the adjoint's arguments. Its body is every
        primal statement except the last one, followed by the adjoint's body.
  """
  node, _, _ = _fix(node)
  primal, adjoint = node.body[0], node.body[1]
  merged_body = primal.body[:-1] + adjoint.body
  merged_def = gast.FunctionDef(
      name=primal.name, args=adjoint.args, body=merged_body,
      decorator_list=[], returns=None)
  func = gast.Module(body=[merged_def])
  # Drop any annotations carried over from the original trees.
  anno.clearanno(func)
  return func
Exemplo n.º 14
0
    def visit_GeneratorExp(self, node):
        """Outline a generator expression into a generator function.

        A ``generator_expression<N>`` function is appended to the module. Its
        parameters are the identifiers ``ImportedIds`` reports for ``node``,
        and its body is the generator loops (nested by ``self.nest_reducer``)
        around ``yield elt``. The expression is replaced by a call to it.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        func_name = "generator_expression{0}".format(self.count)
        self.count += 1
        captured = self.passmanager.gather(ImportedIds, node, self.ctx)
        self.count_iter = 0

        innermost = ast.Expr(ast.Yield(node.elt))
        loop_nest = reduce(self.nest_reducer, reversed(node.generators),
                           innermost)

        params = [ast.Name(arg, ast.Param(), None) for arg in captured]
        signature = ast.arguments(params, None, [], [], None, [])
        self.ctx.module.body.append(
            ast.FunctionDef(func_name, signature, [loop_nest], [], None))
        # Build fresh Name nodes for the call site so no AST node is shared
        # with the function's parameter list.
        call_args = [ast.Name(param.id, ast.Load(), None) for param in params]
        return ast.Call(ast.Name(func_name, ast.Load(), None), call_args, [])
Exemplo n.º 15
0
    def visit_Module(self, node):
        """Turn globals assignment to functionDef and visit function defs.

        First pass: every top-level assignment target must be a plain name
        that is assigned only once; those names are recorded in
        ``self.to_expand``.
        Second pass: each such assignment is replaced by one zero-argument
        ``FunctionDef`` per target, whose body returns the visited value.
        The return is tagged with ``metadata.StaticReturn``. Every other
        statement is visited with ``self.local_decl`` set to its local
        declarations.

        Raises:
            PythranSyntaxError: on a top-level assignment to a non-name
                target, or on a name assigned more than once at top level.
        """
        module_body = list()
        # Gather top level assigned variables.
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                continue
            for target in stmt.targets:
                if not isinstance(target, ast.Name):
                    raise PythranSyntaxError(
                        "Top-level assignment to an expression.", target)
                if target.id in self.to_expand:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % target.id,
                        target)
                self.to_expand.add(target.id)

        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                self.local_decl = set()
                cst_value = self.visit(stmt.value)
                # NOTE(review): with several targets (``a = b = expr``) the
                # same ``cst_value`` node is used in each generated function.
                for target in stmt.targets:
                    assert isinstance(target, ast.Name)
                    module_body.append(
                        ast.FunctionDef(
                            target.id, ast.arguments([], None, [], [], None,
                                                     []),
                            [ast.Return(value=cst_value)], [], None))
                    metadata.add(module_body[-1].body[0],
                                 metadata.StaticReturn())
            else:
                self.local_decl = self.passmanager.gather(
                    LocalNameDeclarations, stmt, self.ctx)
                module_body.append(self.visit(stmt))

        node.body = module_body
        return node
Exemplo n.º 16
0
    def get_for_stmt_nodes(self, node):
        """Convert a ``for`` loop into statements driving a static while loop.

        Returns the list of statements that replaces ``node``, in this order:
        static variable declarations for variables created inside the loop,
        the loop's init statements, a condition function, a body function,
        and the node built by ``create_while_node``. ``[node]`` is returned
        unchanged when ``ForNodeVisitor`` cannot parse the loop.
        """
        # TODO: consider for - else in python

        def _make_loop_func(name_prefix, var_names, body):
            # Build ``def <unique name>(*var_names): <body>``, then give every
            # dotted variable name (e.g. ``self.x``) a fresh identifier, since
            # a dotted name is not a valid parameter identifier.
            func_node = gast.FunctionDef(
                name=unique_name.generate(name_prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(id=var,
                                  ctx=gast.Param(),
                                  annotation=None,
                                  type_comment=None) for var in var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # kw_defaults is a list field in the AST grammar (one
                    # entry per keyword-only argument), not None.
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for var in var_names:
                if "." in var:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        # 1. get key statements for different cases
        # NOTE 1: three key statements:
        #   1). init_stmts: list[node], prepare nodes of for loop, may not only one
        #   2). cond_stmt: node, condition node to judge whether continue loop
        #   3). body_stmts: list[node], updated loop body, sometimes we should change
        #       the original statement in body, not just append new statement
        #
        # NOTE 2: The following `for` statements will be transformed to `while` statements:
        #   1). for x in range(*)
        #   2). for x in iter_var
        #   3). for i, x in enumerate(*)

        current_for_node_parser = ForNodeVisitor(node)
        stmts_tuple = current_for_node_parser.parse()
        if stmts_tuple is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = stmts_tuple

        # 2. get original loop vars
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
        # we need append new loop var & remove useless loop var
        #   1. for x in var -> x is no need
        #   2. for i, x in enumerate(var) -> x is no need
        if (current_for_node_parser.is_for_iter()
                or current_for_node_parser.is_for_enumerate_iter()):
            iter_var_name = current_for_node_parser.iter_var_name
            iter_idx_name = current_for_node_parser.iter_idx_name
            loop_var_names.add(iter_idx_name)
            if iter_var_name not in create_var_names:
                loop_var_names.remove(iter_var_name)

        # 3. prepare result statement list
        new_stmts = []
        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # 4. append init statements
        new_stmts.extend(init_stmts)

        # 5. create & append condition function node
        condition_func_node = _make_loop_func(
            FOR_CONDITION_PREFIX, loop_var_names,
            [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # 6. create & append loop body function node
        # append return values for loop body
        body_stmts.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = _make_loop_func(FOR_BODY_PREFIX, loop_var_names,
                                         body_stmts)
        new_stmts.append(body_func_node)

        # 7. create & append while loop node
        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)

        return new_stmts
Exemplo n.º 17
0
    def get_while_stmt_nodes(self, node):
        """Convert a control-flow ``while`` loop into a static while loop.

        Returns the list of statements that replaces ``node``, in this order:
        static variable declarations for variables created inside the loop,
        conversions of the loop variables to static variables, a condition
        function, a body function, and the node built by
        ``create_while_node``. Loops rejected by ``is_control_flow_loop`` are
        returned unchanged as ``[node]``.
        """
        # TODO: consider while - else in python

        def _make_loop_func(name_prefix, var_names, body):
            # Build ``def <unique name>(*var_names): <body>``, then give every
            # dotted variable name (e.g. ``self.x``) a fresh identifier, since
            # a dotted name is not a valid parameter identifier.
            func_node = gast.FunctionDef(
                name=unique_name.generate(name_prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(id=var,
                                  ctx=gast.Param(),
                                  annotation=None,
                                  type_comment=None) for var in var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # kw_defaults is a list field in the AST grammar (one
                    # entry per keyword-only argument), not None.
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for var in var_names:
                if "." in var:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # while x < 10 in dygraph should be convert into static tensor < 10
        for name in loop_var_names:
            new_stmts.append(to_static_variable_gast_node(name))

        logical_op_transformer = LogicalOpTransformer(node.test)
        cond_value_node = logical_op_transformer.transform()

        condition_func_node = _make_loop_func(
            WHILE_CONDITION_PREFIX, loop_var_names,
            [gast.Return(value=cond_value_node)])
        new_stmts.append(condition_func_node)

        # NOTE: the return is appended to ``node.body`` in place.
        new_body = node.body
        new_body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = _make_loop_func(WHILE_BODY_PREFIX, loop_var_names,
                                         new_body)
        new_stmts.append(body_func_node)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)
        return new_stmts
Exemplo n.º 18
0
    def get_for_stmt_nodes(self, node):
        """Convert a ``for x in range(...)`` loop into a static while loop.

        Returns the list of statements that replaces ``node``, in this order:
        static variable declarations for variables created inside the loop,
        the loop-variable init statement, conversions of the loop variables
        to static variables, a condition function, a body function (the
        original body plus the range step), and the node built by
        ``create_while_node``. ``[node]`` is returned unchanged in any of
        these cases: the loop is not a control-flow loop, it does not iterate
        over a range call, or its target is not a plain name.
        """
        # TODO: consider for - else in python

        def _make_loop_func(name_prefix, var_names, body):
            # Build ``def <unique name>(*var_names): <body>``, then give every
            # dotted variable name (e.g. ``self.x``) a fresh identifier, since
            # a dotted name is not a valid parameter identifier.
            func_node = gast.FunctionDef(
                name=unique_name.generate(name_prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(id=var,
                                  ctx=gast.Param(),
                                  annotation=None,
                                  type_comment=None) for var in var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # kw_defaults is a list field in the AST grammar (one
                    # entry per keyword-only argument), not None.
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for var in var_names:
                if "." in var:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        # TODO: support non-range case
        range_call_node = self.get_for_range_node(node)
        if range_call_node is None:
            return [node]

        if not isinstance(node.target, gast.Name):
            return [node]
        iter_var_name = node.target.id

        init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
            iter_var_name, range_call_node.args)

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []
        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        new_stmts.append(init_stmt)

        # for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
        for name in loop_var_names:
            new_stmts.append(to_static_variable_gast_node(name))

        condition_func_node = _make_loop_func(
            FOR_CONDITION_PREFIX, loop_var_names,
            [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # NOTE: the step statement and return are appended to ``node.body``
        # in place.
        new_body = node.body
        new_body.append(change_stmt)
        new_body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = _make_loop_func(FOR_BODY_PREFIX, loop_var_names,
                                         new_body)
        new_stmts.append(body_func_node)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)

        return new_stmts
Exemplo n.º 19
0
    def visit_Module(self, node):
        """Turn globals assignment to functionDef and visit function defs.

        First pass:
          * collects top-level symbols (import aliases and function names);
          * rejects duplicate function names;
          * validates top-level assignments. Each target must be a plain
            name assigned only once. It is recorded in ``self.to_expand``
            unless the assigned value is a name already in the collected
            symbols, in which case it is a simple alias and is kept.

        Second pass:
          * keeps alias assignments as-is;
          * replaces every other assignment by one zero-argument
            ``FunctionDef`` per target. Its body returns the visited value
            passed through ``GlobalTransformer``, and the return is tagged
            with ``metadata.StaticReturn``;
          * visits every other statement with ``self.local_decl`` set to its
            local declarations.

        ``self.update`` is set when any global was expanded.

        Raises:
            PythranSyntaxError: on duplicate top-level definitions or on a
                top-level assignment to a non-name target.
        """
        module_body = list()
        symbols = set()
        # Gather top level assigned variables.
        for stmt in node.body:
            if isinstance(stmt, (ast.Import, ast.ImportFrom)):
                for alias in stmt.names:
                    name = alias.asname or alias.name
                    symbols.add(name)  # no warning here
            elif isinstance(stmt, ast.FunctionDef):
                if stmt.name in symbols:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % stmt.name,
                        stmt)
                else:
                    symbols.add(stmt.name)

            if not isinstance(stmt, ast.Assign):
                continue

            for target in stmt.targets:
                if not isinstance(target, ast.Name):
                    raise PythranSyntaxError(
                        "Top-level assignment to an expression.", target)
                if target.id in self.to_expand:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % target.id,
                        target)
                if isinstance(stmt.value, ast.Name):
                    if stmt.value.id in symbols:
                        continue  # create aliasing between top level symbols
                self.to_expand.add(target.id)

        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                # that's not a global var, but a module/function aliasing
                if all(
                        isinstance(t, ast.Name) and t.id not in self.to_expand
                        for t in stmt.targets):
                    module_body.append(stmt)
                    continue

                self.local_decl = set()
                cst_value = GlobalTransformer().visit(self.visit(stmt.value))
                # NOTE(review): with several targets (``a = b = expr``) the
                # same ``cst_value`` node is used in each generated function.
                for target in stmt.targets:
                    assert isinstance(target, ast.Name)
                    module_body.append(
                        ast.FunctionDef(
                            target.id,
                            ast.arguments([], [], None, [], [], None, []),
                            [ast.Return(value=cst_value)], [], None, None))
                    metadata.add(module_body[-1].body[0],
                                 metadata.StaticReturn())
            else:
                self.local_decl = self.gather(LocalNameDeclarations, stmt)
                module_body.append(self.visit(stmt))

        self.update |= bool(self.to_expand)

        node.body = module_body
        return node
Exemplo n.º 20
0
    def get_while_stmt_nodes(self, node):
        """Rewrite a ``while`` loop as generated condition/body functions.

        The loop test becomes the return value of a generated condition
        function. The loop body becomes a generated body function that ends
        by returning the loop variables (built by ``generate_name_node``).
        Both functions take every loop variable as a positional parameter,
        and dotted loop-variable names are renamed to generated identifiers
        inside each function.

        Args:
            node: the ``gast.While`` node to rewrite. Its ``test`` expression
                and body statements are reused (not copied) inside the
                generated functions; the ``node.body`` list itself is not
                modified.

        Returns:
            list: the statements replacing ``node``, in order: static
            variable declarations for variables first created inside the
            loop, the condition function, the body function, then the nodes
            returned by ``create_while_nodes``.
        """
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []

        # Python can create a variable inside a loop and use it after the
        # loop, e.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create a static variable for each such variable.
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        def make_loop_func(name_prefix, body):
            # Build a FunctionDef whose parameters are the loop variables.
            # Parameter nodes are created afresh on every call, so the two
            # generated functions never share them.
            func_node = gast.FunctionDef(
                name=unique_name.generate(name_prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(
                            id=name,
                            ctx=gast.Param(),
                            annotation=None,
                            type_comment=None) for name in loop_var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    # ``kw_defaults`` is a list field in Python's AST
                    # grammar; ``None`` is rejected by ``compile()``.
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[]),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            # A dotted name such as "self.x" cannot be a parameter, so give
            # each one a freshly generated identifier within func_node.
            for name in loop_var_names:
                if "." in name:
                    rename_transformer = RenameTransformer(func_node)
                    rename_transformer.rename(
                        name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        condition_func_node = make_loop_func(
            WHILE_CONDITION_PREFIX, [gast.Return(value=node.test)])
        new_stmts.append(condition_func_node)

        # Copy the statement list so appending the trailing return does not
        # mutate the caller's ``node.body``.
        new_body = list(node.body)
        new_body.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load())))
        body_func_node = make_loop_func(WHILE_BODY_PREFIX, new_body)
        new_stmts.append(body_func_node)

        while_loop_nodes = create_while_nodes(
            condition_func_node.name, body_func_node.name, loop_var_names)
        new_stmts.extend(while_loop_nodes)
        return new_stmts