예제 #1
0
def outline(name, formal_parameters, out_parameters, stmts, has_return):
    """Build a ``FunctionDef`` called ``name`` that wraps ``stmts``.

    ``formal_parameters`` become the new function's positional parameters.

    * If ``stmts`` is an expression, the function simply returns it; out
      parameters are not allowed in that case.
    * Otherwise ``stmts`` is used as the body, and a trailing
      ``return (out_0, out_1, ...)`` built from ``out_parameters`` is appended
      to that same list (the caller's ``stmts`` is mutated).  When
      ``has_return`` is set, ``PatchReturn`` is run over the new function,
      seeded with that trailing return.

    Returns the new ``FunctionDef`` node.
    """
    params = [ast.Name(p, ast.Param(), None) for p in formal_parameters]
    signature = ast.arguments(params, None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        return ast.FunctionDef(name, signature, [ast.Return(stmts)], [], None)

    fdef = ast.FunctionDef(name, signature, stmts, [], None)

    # Part of a deliberate trick relying on delayed type inference: the
    # output type is computed from the out parameters through this
    # unconditional trailing return.  If the body has other (early) returns,
    # the result type is the __combined of those returns and this one;
    # PatchReturn (applied below when has_return is set) rewrites the early
    # returns so that their type combines with this output tuple.
    #
    # This is the only way found to let pythran compute both the output
    # variable type and the early-return type -- admittedly a dirty one.
    outputs = [ast.Name(p, ast.Load(), None) for p in out_parameters]
    final_return = ast.Return(ast.Tuple(outputs, ast.Load()))
    stmts.append(final_return)
    if has_return:
        PatchReturn(final_return).visit(fdef)

    return fdef
예제 #2
0
    def visit_Compare(self, node):
        """Outline chained comparisons into a lazily evaluating helper.

        A comparison with more than one operator (``a < b < c``) is replaced
        by a call to a freshly forged function, which is appended to
        ``self.compare_functions``.  That function assigns each non-trivial
        operand to a ``$<i>`` temporary once, checks the links of the chain
        in order, and returns ``False`` at the first link that fails
        (``True`` otherwise).  Single-operator comparisons are returned
        unchanged after their children are visited.
        """
        node = self.generic_visit(node)
        if len(node.ops) <= 1:
            return node

        # The comparison's free identifiers become the helper's parameters,
        # in sorted order.
        captured = sorted(self.passmanager.gather(ImportedIds, node, self.ctx))
        helper_name = "{0}_compare{1}".format(self.prefix,
                                              len(self.compare_functions))
        call_site = ast.Call(ast.Name(helper_name, ast.Load(), None),
                             [ast.Name(n, ast.Load(), None) for n in captured],
                             [])
        signature = ast.arguments(
            [ast.Name(n, ast.Param(), None) for n in captured],
            None, [], [], None, [])

        body = []

        def hold(expr, slot):
            # Trivially copied operands are used directly; anything else is
            # evaluated once into the temporary "$<slot>".
            if is_trivially_copied(expr):
                return expr
            tmp = '${}'.format(slot)
            body.append(ast.Assign([ast.Name(tmp, ast.Store(), None)], expr))
            return ast.Name(tmp, ast.Load(), None)

        lhs = hold(node.left, 0)
        for slot, operand in enumerate(node.comparators, start=1):
            rhs = hold(operand, slot)
            link = ast.Compare(lhs, [node.ops[slot - 1]], [rhs])
            bail_out = ast.Return(path_to_attr(('__builtin__', 'False')))
            body.append(ast.If(link, [ast.Pass()], [bail_out]))
            lhs = rhs

        body.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

        self.compare_functions.append(
            ast.FunctionDef(helper_name, signature, body, [], None))
        return call_site
예제 #3
0
    def visit_Lambda(self, node):
        """Lift a lambda into a module-level function.

        Ensures the mangled ``functools`` module is imported, then creates
        ``<prefix>_lambda<N>`` whose parameters are the lambda's captured
        identifiers (sorted) followed by its own parameters, and whose body
        returns the lambda expression.  The lambda is replaced by a reference
        to that function, or by ``functools.partial(func, *captured)`` when
        it captures identifiers.
        """
        functools_mod = MODULES['functools']
        if functools_mod not in self.global_declarations.values():
            alias = mangle('functools')
            self.imports.append(ast.Import([ast.alias('functools', alias)]))
            self.global_declarations[alias] = functools_mod

        self.generic_visit(node)
        lifted_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        captured = self.passmanager.gather(ImportedIds, node, self.ctx)
        # Meant to remove the current lambdas.  NOTE(review):
        # lambda_functions holds FunctionDef nodes while `captured` holds
        # identifiers, so this appears to remove nothing -- confirm intent.
        captured.difference_update(self.lambda_functions)
        captured = sorted(captured)

        bound = [ast.Name(n, ast.Load(), None) for n in captured]
        node.args.args = ([ast.Name(n, ast.Param(), None) for n in captured]
                          + node.args.args)

        lifted = ast.FunctionDef(lifted_name, copy(node.args),
                                 [ast.Return(node.body)], [], None)
        self.lambda_functions.append(lifted)
        self.global_declarations[lifted_name] = lifted

        reference = ast.Name(lifted_name, ast.Load(), None)
        if not bound:
            return reference
        partial = ast.Attribute(ast.Name(mangle('functools'), ast.Load(), None),
                                "partial", ast.Load())
        return ast.Call(partial, [reference] + bound, [])
 def visit_FunctionDef(self, node):
     """Add return-value bookkeeping around a function body.

     After the children are visited, this function's entry is popped from
     ``self.func_returned_stack``.  The body is then prefixed with
     ``<returned_value_key> = None``.  When the popped entry is truthy, that
     assignment is followed by ``<returned_flag><depth> = False``, where
     ``depth`` is the stack size before popping.  Finally the body is
     suffixed with ``return <returned_value_key>``.
     """
     visited = self.generic_visit(node)
     depth = len(self.func_returned_stack)
     needs_flag = self.func_returned_stack.pop()

     def store(var_name):
         return gast.Name(id=var_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)

     prologue = [
         gast.Assign(targets=[store(self.returned_value_key)],
                     value=gast.Constant(value=None, kind=None))
     ]
     if needs_flag:
         prologue.append(
             gast.Assign(targets=[store(self.returned_flag + str(depth))],
                         value=gast.Constant(value=False, kind=None)))
     node.body[0:0] = prologue

     node.body.append(
         gast.Return(value=gast.Name(id=self.returned_value_key,
                                     ctx=gast.Load(),
                                     annotation=None,
                                     type_comment=None)))
     # The edits above go to ``node`` while the generic_visit result is
     # returned; these coincide when generic_visit returns ``node`` itself.
     return visited
  def test_ast_to_object(self):
    """compiler.ast_to_object renders, writes and loads ``f(a) = a + 1``."""
    param_a = gast.Name('a', gast.Param(), None)
    plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    signature = gast.arguments(
        args=[param_a],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(plus_one)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    expected_source = """
      # coding=utf-8
      def f(a):
        return a + 1
    """
    expected = textwrap.dedent(expected_source).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    # The generated module file on disk must hold the same source.
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
예제 #6
0
    def visit_Lambda(self, node):
        """Replace a lambda with a reference to an equivalent named function.

        Lambdas recognised by ``issimpleoperator`` become an attribute of the
        mangled ``operator`` module.  Otherwise the lambda is lifted into a
        module-level ``FunctionDef`` named ``<prefix>_lambda<N>``.  An existing
        lifted function is reused instead when ``issamelambda`` matches one
        of the recorded ``self.patterns``.  Captured identifiers are prepended
        to the parameters and bound through ``functools.partial``.
        """
        op = issimpleoperator(node)
        if op is not None:
            # Import the mangled ``operator`` module once, then reference it.
            if mangle('operator') not in self.global_declarations:
                import_ = ast.Import(
                    [ast.alias('operator', mangle('operator'))])
                self.imports.append(import_)
                operator_module = MODULES['operator']
                self.global_declarations[mangle('operator')] = operator_module
            return ast.Attribute(
                ast.Name(mangle('operator'), ast.Load(), None, None), op,
                ast.Load())

        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        ii = self.gather(ImportedIds, node)
        # NOTE(review): lambda_functions holds FunctionDef nodes while ii
        # holds identifiers, so this appears to remove nothing -- confirm.
        ii.difference_update(self.lambda_functions)  # remove current lambdas

        # Captured identifiers, in sorted order, are passed at the call site
        # and prepended to the lambda's own parameters.
        binded_args = [
            ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)
        ]
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None)
             for iin in sorted(ii)] + node.args.args)
        # Reuse a previously lifted function if an equivalent lambda (with
        # the extended parameter list above) was already seen.
        for patternname, pattern in self.patterns.items():
            if issamelambda(pattern, node):
                proxy_call = ast.Name(patternname, ast.Load(), None, None)
                break
        else:
            # No match: record a deep copy of this lambda plus its def-use
            # chains as a new pattern, then lift it into a new function.
            duc = ExtendedDefUseChains()
            nodepattern = deepcopy(node)
            duc.visit(ast.Module([ast.Expr(nodepattern)], []))
            self.patterns[forged_name] = nodepattern, duc

            forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                          [ast.Return(node.body)], [], None,
                                          None)
            metadata.add(forged_fdef, metadata.Local())
            self.lambda_functions.append(forged_fdef)
            self.global_declarations[forged_name] = forged_fdef
            proxy_call = ast.Name(forged_name, ast.Load(), None, None)

        if binded_args:
            # Captured values are bound through functools.partial, which is
            # imported (under its mangled name) on first use.
            if MODULES['functools'] not in self.global_declarations.values():
                import_ = ast.Import(
                    [ast.alias('functools', mangle('functools'))])
                self.imports.append(import_)
                functools_module = MODULES['functools']
                self.global_declarations[mangle(
                    'functools')] = functools_module

            return ast.Call(
                ast.Attribute(
                    ast.Name(mangle('functools'), ast.Load(), None, None),
                    "partial", ast.Load()), [proxy_call] + binded_args, [])
        else:
            return proxy_call
예제 #7
0
def create_funcDef_node(nodes, name, input_args, return_name_ids):
    """
    Wrap all statements in ``nodes`` into one ``gast.FunctionDef`` that can
    be invoked through an ``ast.Call``.

    A trailing ``return`` is appended to a shallow copy of ``nodes``, so the
    caller's list is left untouched.  The return value is
    ``generate_name_node(return_name_ids)`` when ``return_name_ids`` is
    truthy, and nothing (a bare ``return``) otherwise.
    """
    body = copy.copy(nodes)
    if return_name_ids:
        returned = generate_name_node(return_name_ids)
    else:
        returned = None
    body.append(gast.Return(value=returned))
    return gast.FunctionDef(name=name,
                            args=input_args,
                            body=body,
                            decorator_list=[],
                            returns=None,
                            type_comment=None)
예제 #8
0
    def visit_FunctionDef(self, node):
        """Append an explicit return when control can fall off the end.

        Records the function's yield points, then visits every body
        statement.  If some predecessor of ``CFG.NIL`` in ``self.cfg`` is
        neither a ``Return`` nor a ``Raise``, the function is marked updated
        and gets a trailing return.  That return is a bare ``return`` when
        yield points exist, and ``return __builtin__.None`` otherwise.
        """
        self.yield_points = self.gather(YieldPoints, node)
        for stmt in node.body:
            self.visit(stmt)

        # Predecessors of the special NIL node are the AST nodes that end
        # control flow; any of them other than Return/Raise means the
        # function can reach its end without an explicit return.
        falls_off_end = any(
            not isinstance(last, (ast.Return, ast.Raise))
            for last in self.cfg.predecessors(CFG.NIL))
        if not falls_off_end:
            return node

        self.update = True
        if self.yield_points:
            terminator = ast.Return(None)
        else:
            builtin_none = ast.Attribute(
                ast.Name("__builtin__", ast.Load(), None, None),
                'None', ast.Load())
            terminator = ast.Return(builtin_none)
        node.body.append(terminator)
        return node
예제 #9
0
    def visit_Return(self, node):
        """Wrap the returned value in a ``__builtin__.pythran`` marker call.

        The return that is ``self.guard`` is wrapped in ``StaticIfNoReturn``;
        every other return is wrapped in ``StaticIfReturn``.
        """
        holder = "StaticIfNoReturn" if node is self.guard else "StaticIfReturn"

        builtin = ast.Name("__builtin__", ast.Load(), None)
        pythran_ns = ast.Attribute(builtin, "pythran", ast.Load())
        marker = ast.Attribute(pythran_ns, holder, ast.Load())
        return ast.Return(ast.Call(marker, [node.value], []))
예제 #10
0
    def __init__(self, astc, args, func_field):
        """Wrap a user function (or lambda) AST for later evaluation.

        Args:
            astc: context whose ``nast`` attribute is a ``gast.FunctionDef``
                or ``gast.Lambda``; its ``filename`` and ``lineno`` are
                copied onto this object.
            args: stored unchanged as ``self.args``.
            func_field: stored unchanged as ``self.func_field``.
        """
        super().__init__()
        nast = astc.nast
        assert isinstance(nast, (gast.FunctionDef, gast.Lambda))

        if isinstance(nast, gast.FunctionDef):
            # Read the name from the node that was type-checked above.
            self.name = nast.name
        else:
            self.name = (lambda: None).__name__  # i.e. '<lambda>'
        self.args = args
        self.func_field = func_field
        if isinstance(nast, gast.Lambda):
            # Add a return to the body.  This mutates astc.nast in place;
            # NOTE(review): the body becomes a single Return node rather
            # than a statement list -- confirm consumers expect that.
            nast.body = gast.Return(value=nast.body)
        self.ast = nast
        self.filename = astc.filename
        self.lineno = astc.lineno
예제 #11
0
    def visit_Return(self, node):
        """Wrap the returned value in a ``builtins.pythran`` marker call.

        The return that is ``self.guard`` is wrapped in ``StaticIfNoReturn``;
        every other return is wrapped in ``StaticIfReturn``.  A bare
        ``return`` is given an explicit ``None`` constant as its argument.
        """
        holder = "StaticIfNoReturn" if node is self.guard else "StaticIfReturn"

        value = node.value
        # Test for a missing value explicitly instead of relying on the
        # truthiness of an AST node.
        call_args = [value] if value is not None else [ast.Constant(None, None)]

        marker = ast.Attribute(
            ast.Attribute(ast.Name("builtins", ast.Load(), None, None),
                          "pythran", ast.Load()),
            holder, ast.Load())
        return ast.Return(ast.Call(marker, call_args, []))
예제 #12
0
    def visit_AnyComp(self, node, comp_type, *path):
        """Outline a comprehension into a module-level function.

        Args:
            node: the comprehension node; its ``elt`` is visited first.
            comp_type: name of the ``builtins`` container to build.  It is
                used in the generated function's name and in the
                ``__target = builtins.<comp_type>()`` initializer.
            *path: attribute path (first element is a name) of the callable
                invoked as ``<path>(__target, elt)`` inside the loop nest
                built by ``self.nest_reducer`` over ``node.generators``.

        The generated function takes the comprehension's captured identifiers
        as parameters.  It is tagged ``metadata.Local``, appended to the
        module body, and returns ``__target``.  A call to it replaces the
        comprehension.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "{0}_comprehension{1}".format(comp_type, self.count)
        self.count += 1
        # ImportedIds yields a set of identifiers.  Iterating a set of
        # strings follows hash order, which can change between interpreter
        # runs, so sort the names to keep the generated signature and call
        # site deterministic.
        args = sorted(self.gather(ImportedIds, node))
        self.count_iter = 0

        starget = "__target"
        body = reduce(self.nest_reducer,
                      reversed(node.generators),
                      ast.Expr(
                          ast.Call(
                              reduce(lambda x, y: ast.Attribute(x, y,
                                                                ast.Load()),
                                     path[1:],
                                     ast.Name(path[0], ast.Load(),
                                              None, None)),
                              [ast.Name(starget, ast.Load(), None, None),
                               node.elt],
                              [],
                              )
                          )
                      )
        # add extra metadata to this node
        metadata.add(body, metadata.Comprehension(starget))
        init = ast.Assign(
            [ast.Name(starget, ast.Store(), None, None)],
            ast.Call(
                ast.Attribute(
                    ast.Name('builtins', ast.Load(), None, None),
                    comp_type,
                    ast.Load()
                    ),
                [], [],)
            )
        result = ast.Return(ast.Name(starget, ast.Load(), None, None))
        sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
        fd = ast.FunctionDef(name,
                             ast.arguments(sargs, [], None, [], [], None, []),
                             [init, body, result],
                             [], None, None)
        metadata.add(fd, metadata.Local())
        self.ctx.module.body.append(fd)
        return ast.Call(
            ast.Name(name, ast.Load(), None, None),
            [ast.Name(arg.id, ast.Load(), None, None) for arg in sargs],
            [],
            )  # no sharing !
예제 #13
0
  def test_load_ast(self):
    """loader.load_ast renders, writes and imports ``f(a) = a + 1``."""
    def name_a(ctx):
      return gast.Name('a', ctx=ctx, annotation=None, type_comment=None)

    signature = gast.arguments(
        args=[name_a(gast.Param())],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    returned = gast.BinOp(
        op=gast.Add(),
        left=name_a(gast.Load()),
        right=gast.Constant(1, kind=None))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(returned)],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected_source = """
      # coding=utf-8
      def f(a):
          return (a + 1)
    """
    expected = textwrap.dedent(expected_source).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    # The module file written to disk must hold the same source.
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
예제 #14
0
    def visit_Module(self, node):
        """Rewrite top-level assignments into constant-returning functions.

        Every module-level ``name = value`` becomes ``def name(): return
        value``, with the return tagged ``metadata.StaticReturn``.  Other
        top-level statements are visited with ``self.local_decl`` set to
        their local name declarations.

        Raises:
            PythranSyntaxError: if a top-level assignment target is not a
                plain name, or a name is already present in
                ``self.to_expand``.
        """
        # Pass 1: register every top-level assigned name, rejecting
        # non-name targets and repeated definitions.
        for assign in (s for s in node.body if isinstance(s, ast.Assign)):
            for target in assign.targets:
                if not isinstance(target, ast.Name):
                    raise PythranSyntaxError(
                        "Top-level assignment to an expression.", target)
                if target.id in self.to_expand:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % target.id,
                        target)
                self.to_expand.add(target.id)

        # Pass 2: rebuild the module body.
        new_body = []
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                self.local_decl = self.passmanager.gather(
                    LocalNameDeclarations, stmt, self.ctx)
                new_body.append(self.visit(stmt))
                continue

            self.local_decl = set()
            constant = self.visit(stmt.value)
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                no_params = ast.arguments([], None, [], [], None, [])
                getter = ast.FunctionDef(target.id, no_params,
                                         [ast.Return(value=constant)],
                                         [], None)
                new_body.append(getter)
                metadata.add(getter.body[0], metadata.StaticReturn())

        node.body = new_body
        return node
 def fill(self, hole, rng):
     """Fill ``hole`` with a statement list followed by ``return <number>``.

     Returns ``ASTWithHoles(1, ...)`` with two sub-holes sharing
     ``hole.metadata``: a STMTS hole and a NUMBER hole.  Its builder appends
     ``gast.Return(value=<number>)`` to the filled statements.  ``rng`` is
     not used.
     """
     def build(stmts, v):
         return stmts + [gast.Return(value=v)]

     children = [
         Hole(ASTHoleType.STMTS, hole.metadata),
         Hole(ASTHoleType.NUMBER, hole.metadata),
     ]
     return ASTWithHoles(1, children, build)
예제 #16
0
class PyASTGraphsTest(parameterized.TestCase):
    def test_schema(self):
        """Check that the generated schema is reasonable."""
        # Building should succeed
        schema = py_ast_graphs.SCHEMA

        # Should have one graph node for each AST node, plus sorted helpers.
        seq_helpers = [
            "BoolOp_values-seq-helper",
            "Call_args-seq-helper",
            "For_body-seq-helper",
            "FunctionDef_body-seq-helper",
            "If_body-seq-helper",
            "If_orelse-seq-helper",
            "Module_body-seq-helper",
            "While_body-seq-helper",
            "arguments_args-seq-helper",
        ]
        expected_keys = [
            *py_ast_graphs.PY_AST_SPECS.keys(),
            *seq_helpers,
        ]
        self.assertEqual(list(schema.keys()), expected_keys)

        # Check a few elements
        expected_fundef_in_edges = {
            "parent_in",
            "returns_in",
            "returns_missing",
            "body_in",
            "args_in",
        }
        expected_fundef_out_edges = {
            "parent_out",
            "returns_out",
            "body_out_all",
            "body_out_first",
            "body_out_last",
            "args_out",
        }
        self.assertEqual(set(schema["FunctionDef"].in_edges),
                         expected_fundef_in_edges)
        self.assertEqual(set(schema["FunctionDef"].out_edges),
                         expected_fundef_out_edges)

        expected_expr_in_edges = {"parent_in", "value_in"}
        expected_expr_out_edges = {"parent_out", "value_out"}
        self.assertEqual(set(schema["Expr"].in_edges), expected_expr_in_edges)
        self.assertEqual(set(schema["Expr"].out_edges),
                         expected_expr_out_edges)

        expected_seq_helper_in_edges = {
            "parent_in",
            "item_in",
            "next_in",
            "next_missing",
            "prev_in",
            "prev_missing",
        }
        expected_seq_helper_out_edges = {
            "parent_out",
            "item_out",
            "next_out",
            "prev_out",
        }
        for seq_helper in seq_helpers:
            self.assertEqual(set(schema[seq_helper].in_edges),
                             expected_seq_helper_in_edges)
            self.assertEqual(set(schema[seq_helper].out_edges),
                             expected_seq_helper_out_edges)

    def test_ast_graph_conforms_to_schema(self):
        # Some example code using a few different syntactic constructs, to cover
        # a large set of nodes in the schema
        root = gast.parse(
            textwrap.dedent("""\
        def foo(n):
          if n <= 1:
            return 1
          else:
            return foo(n-1) + foo(n-2)

        def bar(m, n) -> int:
          x = n
          for i in range(m):
            if False:
              continue
            x = x + i
          while True:
            break
          return x

        x0 = 1 + 2 - 3 * 4 / 5
        x1 = (1 == 2) and (3 < 4) and (5 > 6)
        x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
        x2 = bar(13, 14 + 15)
        """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Graph should match the schema
        schema_util.assert_conforms_to_schema(graph, py_ast_graphs.SCHEMA)

    def test_ast_graph_nodes(self):
        """Check node IDs, node types, and forward mapping."""
        root = gast.parse(
            textwrap.dedent("""\
        pass
        def foo(n):
            if n <= 1:
              return 1
        """))

        graph, forward_map = py_ast_graphs.py_ast_to_graph(root)

        # pytype: disable=attribute-error
        self.assertIn("root__Module", graph)
        self.assertEqual(graph["root__Module"].node_type, "Module")
        self.assertEqual(forward_map[id(root)], "root__Module")

        self.assertIn("root_body_1__Module_body-seq-helper", graph)
        self.assertEqual(
            graph["root_body_1__Module_body-seq-helper"].node_type,
            "Module_body-seq-helper")

        self.assertIn("root_body_1_item_body_0_item__If", graph)
        self.assertEqual(graph["root_body_1_item_body_0_item__If"].node_type,
                         "If")
        self.assertEqual(forward_map[id(root.body[1].body[0])],
                         "root_body_1_item_body_0_item__If")

        self.assertIn("root_body_1_item_body_0_item_test_left__Name", graph)
        self.assertEqual(
            graph["root_body_1_item_body_0_item_test_left__Name"].node_type,
            "Name")
        self.assertEqual(forward_map[id(root.body[1].body[0].test.left)],
                         "root_body_1_item_body_0_item_test_left__Name")
        # pytype: enable=attribute-error

    def test_ast_graph_unique_field_edges(self):
        """Test that edges for unique fields are correct."""
        root = gast.parse("print(1)")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        self.assertEqual(
            graph["root_body_0_item__Expr"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item_value__Call"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])

        self.assertEqual(
            graph["root_body_0_item_value__Call"].out_edges["parent_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Expr"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

    def test_ast_graph_optional_field_edges(self):
        """Test that edges for optional fields are correct."""
        root = gast.parse("return 1\nreturn")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        self.assertEqual(
            graph["root_body_0_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId(
                        "root_body_0_item_value__Constant"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])

        self.assertEqual(
            graph["root_body_0_item_value__Constant"].out_edges["parent_out"],
            [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Return"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

        self.assertEqual(
            graph["root_body_1_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_1_item__Return"),
                    in_edge=graph_types.InEdgeType("value_missing"))
            ])

    def test_ast_graph_sequence_field_edges(self):
        """Test that edges for sequence fields are correct.

    Note that sequence fields produce connections between three nodes: the
    parent, the helper node, and the child.
    """
        root = gast.parse(
            textwrap.dedent("""\
        print(1)
        print(2)
        print(3)
        print(4)
        print(5)
        print(6)
        """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Child edges from the parent node
        node = graph["root__Module"]
        self.assertLen(node.out_edges["body_out_all"], 6)
        self.assertEqual(node.out_edges["body_out_first"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["body_out_last"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_5__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])

        # Edges from the sequence helper
        node = graph["root_body_0__Module_body-seq-helper"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root__Module"),
                in_edge=graph_types.InEdgeType("body_in"))
        ])
        self.assertEqual(node.out_edges["item_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root_body_0_item__Expr"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["prev_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_missing"))
        ])
        self.assertEqual(node.out_edges["next_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_1__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_in"))
        ])

        # Parent edge of the item
        node = graph["root_body_0_item__Expr"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("item_in"))
        ])

    @parameterized.named_parameters(
        {
            "testcase_name": "unexpected_type",
            "ast": gast.Subscript(value=None, slice=None, ctx=None),
            "expected_error": "Unknown AST node type 'Subscript'",
        }, {
            "testcase_name":
            "too_many_unique",
            "ast":
            gast.Assign(targets=[
                gast.Name("foo", gast.Store(), None, None),
                gast.Name("bar", gast.Store(), None, None)
            ],
                        value=gast.Constant(True, None)),
            "expected_error":
            "Expected 1 child for field 'targets' of node .*; got 2",
        }, {
            "testcase_name":
            "missing_unique",
            "ast":
            gast.Assign(targets=[], value=gast.Constant(True, None)),
            "expected_error":
            "Expected 1 child for field 'targets' of node .*; got 0",
        }, {
            "testcase_name":
            "too_many_optional",
            "ast":
            gast.Return(value=[
                gast.Name("foo", gast.Load(), None, None),
                gast.Name("bar", gast.Load(), None, None)
            ]),
            "expected_error":
            "Expected at most 1 child for field 'value' of node .*; got 2",
        })
    def test_invalid_graphs(self, ast, expected_error):
        with self.assertRaisesRegex(ValueError, expected_error):
            py_ast_graphs.py_ast_to_graph(ast)
예제 #17
0
  def visit_FunctionDef(self, node):
    """Split a function definition into its primal and adjoint functions.

    The function must end with its only return statement.  The primal is
    ``node`` itself, renamed with ``naming.primal_name`` and given
    ``self.stack`` as a new first argument.  The adjoint is instantiated
    from ``grads.adjoints[gast.FunctionDef]`` and named with
    ``naming.adjoint_name``.

    Returns:
      A ``(primal, adjoint)`` pair of FunctionDef nodes.

    Raises:
      ValueError: if the function does not have exactly one return
        statement, located at the end of its body.
    """
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or not isinstance(node.body[-1], gast.Return)):
      raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body (everything except the final return)
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
      body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
      adjoint_body[0] = comments.add_comment(
          adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt], ctx=gast.Load())

    if self.preserve_result:
      # Append an extra Assign operation to the primal body
      # that saves the original output value
      stored_result_node = quoting.quote(self.namer.unique('result'))
      assign_stored_result = template.replace(
          'result=orig_result',
          result=stored_result_node,
          orig_result=return_node.value)
      body.append(assign_stored_result)
      dx.elts.append(stored_result_node)

    # Every element of the returned gradient tuple is read, so force Load
    # context (stored_result_node was just used as an assignment target).
    for _dx in dx.elts:
      _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return values
    # The adjoint will receive a value for the initial gradient of the cost
    # NOTE(review): y.id assumes the cost is a plain Name -- confirm that
    # earlier passes guarantee this.
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
      y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint (template.replace yields exactly one node here,
    # unpacked below), then append stack, dy and the original arguments.
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body, return_dx=return_dx)
    adjoint.args.args.extend([self.stack, dy])
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
예제 #18
0
    def visit_Module(self, node):
        """Turn globals assignment to functionDef and visit function defs.

        Works in two passes over ``node.body``:

        1. Record the names bound by imports and function definitions. Reject
           duplicate function definitions and top-level assignments whose
           target is not a plain name. Add to ``self.to_expand`` every
           assigned name, unless the assignment merely aliases an
           already-known symbol (``x = some_function``).
        2. Rebuild the body. Each assignment that binds a name in
           ``self.to_expand`` becomes one argument-less ``FunctionDef`` per
           target, returning the visited value and tagged with
           ``metadata.StaticReturn``. Pure aliasing assignments are kept
           verbatim. Every other statement is visited, with its local
           declarations stored in ``self.local_decl``.

        ``self.update`` is set when at least one global was expanded.

        Raises:
            PythranSyntaxError: on a duplicate top-level definition or a
                top-level assignment to something other than a name.
        """
        known_symbols = set()

        # Pass 1: gather top-level symbols and the globals to expand.
        for stmt in node.body:
            if isinstance(stmt, (ast.Import, ast.ImportFrom)):
                # imported names are plain symbols, no warning here
                known_symbols.update(alias.asname or alias.name
                                     for alias in stmt.names)
            elif isinstance(stmt, ast.FunctionDef):
                if stmt.name in known_symbols:
                    raise PythranSyntaxError(
                        "Multiple top-level definition of %s." % stmt.name,
                        stmt)
                known_symbols.add(stmt.name)
            elif isinstance(stmt, ast.Assign):
                # `x = known_symbol` only creates an alias between top-level
                # symbols; such targets are not expanded.
                aliases_symbol = (isinstance(stmt.value, ast.Name)
                                  and stmt.value.id in known_symbols)
                for target in stmt.targets:
                    if not isinstance(target, ast.Name):
                        raise PythranSyntaxError(
                            "Top-level assignment to an expression.", target)
                    if target.id in self.to_expand:
                        raise PythranSyntaxError(
                            "Multiple top-level definition of %s." % target.id,
                            target)
                    if not aliases_symbol:
                        self.to_expand.add(target.id)

        # Pass 2: rebuild the module body.
        new_body = []
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                self.local_decl = self.gather(LocalNameDeclarations, stmt)
                new_body.append(self.visit(stmt))
                continue

            is_alias = all(isinstance(t, ast.Name) and t.id not in self.to_expand
                           for t in stmt.targets)
            if is_alias:
                # not a global variable, just module/function aliasing
                new_body.append(stmt)
                continue

            self.local_decl = set()
            value = GlobalTransformer().visit(self.visit(stmt.value))
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                getter = ast.FunctionDef(
                    target.id,
                    ast.arguments([], [], None, [], [], None, []),
                    [ast.Return(value=value)], [], None, None)
                new_body.append(getter)
                metadata.add(getter.body[0], metadata.StaticReturn())

        self.update |= bool(self.to_expand)

        node.body = new_body
        return node
예제 #19
0
    def visit_FunctionDef(self, node):
        """Rewrite ``node`` so that it ends with a single return statement.

        ``node`` is pushed on ``self.function_def``. Fresh per-function
        entries are created in ``self.return_value_name`` (``None``) and in
        ``self.return_name`` and ``self.return_no_value_name`` (empty lists);
        these may be filled while the children are visited. The children are
        re-visited until ``ReturnAnalysisVisitor`` reports at most one return
        statement.

        If the maximum return length measured before the rewrite is 0,
        ``node`` is returned as-is. Otherwise the body is framed as follows:

        * When a return value name was recorded, ``fill_constant`` nodes
          initialise one slot per returned value. Then
          ``<value name> = <slot or tuple of slots>`` is placed at the top
          and ``return <value name>`` is appended at the end.
        * ``False`` flags for the names in ``self.return_name[node]`` are
          prepended.
        * ``RETURN_NO_VALUE_MAGIC_NUM`` placeholders for the names in
          ``self.return_no_value_name[node]`` are prepended in front of
          everything else.

        Returns:
            The mutated ``node``.
        """
        self.function_def.append(node)
        self.return_value_name[node] = None
        self.return_name[node] = []
        self.return_no_value_name[node] = []

        self.pre_analysis = ReturnAnalysisVisitor(node)
        max_return_length = self.pre_analysis.get_func_max_return_length(node)
        # Keep rewriting until at most one return statement remains.
        while self.pre_analysis.get_func_return_count(node) > 1:
            self.generic_visit(node)
            self.pre_analysis = ReturnAnalysisVisitor(node)

        if max_return_length == 0:
            self.function_def.pop()
            return node

        def load(name):
            return gast.Name(id=name, ctx=gast.Load(), annotation=None,
                             type_comment=None)

        value_name = self.return_value_name[node]
        if value_name is not None:
            node.body.append(gast.Return(value=load(value_name)))
            slot_names = [unique_name.generate(RETURN_VALUE_INIT_NAME)
                          for _ in range(max_return_length)]
            slot_inits = [create_fill_constant_node(slot, 0.0)
                          for slot in slot_names]
            if len(slot_names) == 1:
                init_value = load(slot_names[0])
            else:
                # Initialised as a tuple because control flow requires some
                # inputs/outputs to have the same structure.
                init_value = gast.Tuple(elts=[load(s) for s in slot_names],
                                        ctx=gast.Load())
            init_assign = gast.Assign(
                targets=[
                    gast.Name(id=value_name, ctx=gast.Store(),
                              annotation=None, type_comment=None)
                ],
                value=init_value)
            node.body[:0] = slot_inits + [init_assign]

        # Control-flow flags such as '__return@1 = False'. Each group is
        # prepended in reverse order, just like repeated insert(0, ...).
        flag_inits = [create_fill_constant_node(name, False)
                      for name in self.return_name[node]]
        node.body[:0] = flag_inits[::-1]

        # No-value placeholders go in front of everything else.
        placeholder_inits = [
            create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)
            for name in self.return_no_value_name[node]
        ]
        node.body[:0] = placeholder_inits[::-1]

        self.function_def.pop()
        return node
예제 #20
0
    def get_for_stmt_nodes(self, node):
        """Convert a supported ``for`` loop into static-graph ``while`` nodes.

        Loop forms recognised by ``ForNodeVisitor``:
          1. ``for x in range(*)``
          2. ``for x in iter_var``
          3. ``for i, x in enumerate(*)``

        Returns:
            ``[node]`` unchanged if ``ForNodeVisitor.parse()`` returns
            ``None``. Otherwise a list containing, in order:
            static-variable declarations for undotted names first created in
            the loop, the loop init statements, the condition function, the
            body function (which returns the loop vars) and the ``while``
            node.
        """
        # TODO: consider for - else in python

        # 1. Key statements:
        #   init_stmts: list of nodes preparing the loop (possibly several)
        #   cond_stmt: condition node deciding whether the loop continues
        #   body_stmts: updated loop body; original statements may be
        #       rewritten, not only appended to
        current_for_node_parser = ForNodeVisitor(node)
        stmts_tuple = current_for_node_parser.parse()
        if stmts_tuple is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = stmts_tuple

        # 2. Original loop vars.
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        # For 'for x in var' and 'for i, x in enumerate(var)' the index var
        # becomes a loop var. The element var ``x`` is not needed unless it
        # is first created by the loop.
        if (current_for_node_parser.is_for_iter()
                or current_for_node_parser.is_for_enumerate_iter()):
            iter_var_name = current_for_node_parser.iter_var_name
            iter_idx_name = current_for_node_parser.iter_idx_name
            loop_var_names.add(iter_idx_name)
            if iter_var_name not in create_var_names:
                # discard, not remove: a missing name must not raise KeyError.
                loop_var_names.discard(iter_var_name)

        def make_loop_func(prefix, func_body):
            # Builds `def <fresh name>(<loop vars>): <func_body>`. Dotted loop
            # var names (e.g. attributes) are then renamed to fresh plain
            # names inside the function.
            func_node = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(
                    args=[
                        gast.Name(id=name,
                                  ctx=gast.Param(),
                                  annotation=None,
                                  type_comment=None)
                        for name in loop_var_names
                    ],
                    posonlyargs=[],
                    vararg=None,
                    kwonlyargs=[],
                    kw_defaults=None,
                    kwarg=None,
                    defaults=[]),
                body=func_body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for name in loop_var_names:
                if "." in name:
                    RenameTransformer(func_node).rename(
                        name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        # 3. Python can create a variable in the loop and use it afterwards:
        #
        #   for x in range(10):
        #       y += x
        #   print(x)  # x = 10
        #
        # so such variables are declared as static variables up front.
        new_stmts = [
            create_static_variable_gast_node(name)
            for name in create_var_names if "." not in name
        ]

        # 4. Loop init statements.
        new_stmts.extend(init_stmts)

        # 5. Condition function.
        condition_func_node = make_loop_func(
            FOR_CONDITION_PREFIX, [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # 6. Body function, returning the updated loop vars.
        body_stmts.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = make_loop_func(FOR_BODY_PREFIX, body_stmts)
        new_stmts.append(body_func_node)

        # 7. While loop wiring both functions together.
        new_stmts.append(
            create_while_node(condition_func_node.name, body_func_node.name,
                              loop_var_names))

        return new_stmts
예제 #21
0
    def visit_If(self, node):
        """Outline an ``if`` whose test is one of ``self.static_expressions``.

        Both branches are turned into functions via ``outline``; these are
        appended to ``self.new_functions``. The ``if`` is replaced by a call
        to the dispatcher built by ``self.make_dispatcher``. That call is
        followed by whatever statements are needed to write back the assigned
        variables and to replay any ``return``/``break``/``continue`` found
        in the branches.

        Other ``if`` statements are only visited recursively.

        Returns:
            The replacement statement, or a list of replacement statements.
        """
        if node.test not in self.static_expressions:
            return self.generic_visit(node)

        imported_ids = self.gather(ImportedIds, node)

        assigned_ids_left = self.escaping_ids(node, node.body)
        assigned_ids_right = self.escaping_ids(node, node.orelse)
        assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

        # Names assigned in only one branch are also passed in as parameters,
        # presumably so the other branch can hand back their incoming value.
        imported_ids.update(i for i in assigned_ids_left
                            if i not in assigned_ids_right)
        imported_ids.update(i for i in assigned_ids_right
                            if i not in assigned_ids_left)
        imported_ids = sorted(imported_ids)

        assigned_ids = sorted(assigned_ids_both)

        # Detect early exits in each branch (analysed on a fake wrapper).
        fbody = self.make_fake(node.body)
        true_has_return = self.gather(HasReturn, fbody)
        true_has_break = self.gather(HasBreak, fbody)
        true_has_cont = self.gather(HasContinue, fbody)

        felse = self.make_fake(node.orelse)
        false_has_return = self.gather(HasReturn, felse)
        false_has_break = self.gather(HasBreak, felse)
        false_has_cont = self.gather(HasContinue, felse)

        has_return = true_has_return or false_has_return
        has_break = true_has_break or false_has_break
        has_cont = true_has_cont or false_has_cont

        # Visit the branches before they are outlined.
        self.generic_visit(node)

        func_true = outline(self.true_name(), imported_ids, assigned_ids,
                            node.body, has_return, has_break, has_cont)
        func_false = outline(self.false_name(), imported_ids, assigned_ids,
                             node.orelse, has_return, has_break, has_cont)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test,
                                           func_true, func_false, imported_ids)

        # variable modified within the static_if
        expected_return = [ast.Name(ii, ast.Store(), None, None)
                           for ii in assigned_ids]

        self.update = True

        # name for various variables resulting from the static_if; numbered
        # by the current count of outlined functions
        n = len(self.new_functions)
        status_n = "$status{}".format(n)
        return_n = "$return{}".format(n)
        cont_n = "$cont{}".format(n)

        if has_return:
            cfg = self.cfgs[-1]
            # Return straight from the call only if every node in cfg[node]
            # (presumably the CFG successors of this if) is a Return/Yield
            # and both branches contain a return.
            always_return = all(isinstance(x, (ast.Return, ast.Yield))
                                for x in cfg[node])
            always_return &= true_has_return and false_has_return

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(return_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]

            if always_return:
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.Return(ast.Name(return_n, ast.Load(), None, None))]
            else:
                cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                           expected_return,
                                                           has_cont, has_break)

                # Return $return<n> when the status is EARLY_RET; otherwise
                # run the break/continue/assignment handlers.
                cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                                   [ast.Eq()], [ast.Constant(EARLY_RET, None)])
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.If(cmpr,
                               [ast.Return(ast.Name(return_n, ast.Load(),
                                                    None, None))],
                               cont_ass)]
        elif has_break or has_cont:
            # No return: the call yields (status, cont) and the handlers
            # replay break/continue and write back the assigned variables.
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None)] + cont_ass
        elif expected_return:
            # Plain case: unpack the assigned variables from the call.
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call, None)
        else:
            return ast.Expr(actual_call)
예제 #22
0
    def get_while_stmt_nodes(self, node):
        """Lower a control-flow ``while`` loop into static-graph nodes.

        Returns ``[node]`` unchanged when ``self.name_visitor`` does not
        consider it a control-flow loop. Otherwise returns, in order:

        * static-variable declarations for undotted names first created in
          the loop;
        * static-variable conversions of every loop var;
        * a condition function returning the ``LogicalOpTransformer``-
          transformed test;
        * a body function: the original body, extended in place with a
          return of the loop vars;
        * the ``while`` node calling both.
        """
        # TODO: consider while - else in python
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)

        def build_loop_function(prefix, stmts):
            # `def <fresh name>(<loop vars>): <stmts>`; dotted loop var names
            # are then renamed to fresh plain names inside the function.
            params = [
                gast.Name(id=var, ctx=gast.Param(), annotation=None,
                          type_comment=None)
                for var in loop_var_names
            ]
            func = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(args=params,
                                    posonlyargs=[],
                                    vararg=None,
                                    kwonlyargs=[],
                                    kw_defaults=None,
                                    kwarg=None,
                                    defaults=[]),
                body=stmts,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for var in loop_var_names:
                if "." in var:
                    RenameTransformer(func).rename(
                        var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func

        # A variable created inside the loop may be used after it:
        #
        #   while x < 10:
        #       x += 1
        #       y = x
        #   z = y
        #
        # so such variables are declared as static variables first.
        result = [
            create_static_variable_gast_node(var)
            for var in create_var_names if "." not in var
        ]

        # `while x < 10` must compare a static tensor, not a Python value.
        result.extend(to_static_variable_gast_node(var)
                      for var in loop_var_names)

        cond_value_node = LogicalOpTransformer(node.test).transform()
        cond_func = build_loop_function(WHILE_CONDITION_PREFIX,
                                        [gast.Return(value=cond_value_node)])
        result.append(cond_func)

        # The original body list is reused and extended in place.
        node.body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func = build_loop_function(WHILE_BODY_PREFIX, node.body)
        result.append(body_func)

        result.append(
            create_while_node(cond_func.name, body_func.name, loop_var_names))
        return result
예제 #23
0
    def get_for_stmt_nodes(self, node):
        """Lower a control-flow ``for x in range(...)`` loop into static nodes.

        Returns ``[node]`` unchanged when any of these holds:

        * it is not a control-flow loop;
        * ``self.get_for_range_node`` finds no ``range`` call;
        * the loop target is not a plain name.

        Otherwise returns, in order:

        * static-variable declarations for undotted names first created in
          the loop;
        * the init statement;
        * static-variable conversions of the loop vars;
        * the condition function;
        * the body function (the original body, extended in place with the
          change statement and a return of the loop vars);
        * the ``while`` node.
        """
        # TODO: consider for - else in python
        # TODO: support non-range case
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]
        range_call_node = self.get_for_range_node(node)
        if range_call_node is None or not isinstance(node.target, gast.Name):
            return [node]

        init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
            node.target.id, range_call_node.args)

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)

        def build_loop_function(prefix, stmts):
            # `def <fresh name>(<loop vars>): <stmts>`; dotted loop var names
            # are then renamed to fresh plain names inside the function.
            params = [
                gast.Name(id=var, ctx=gast.Param(), annotation=None,
                          type_comment=None)
                for var in loop_var_names
            ]
            func = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(args=params,
                                    posonlyargs=[],
                                    vararg=None,
                                    kwonlyargs=[],
                                    kw_defaults=None,
                                    kwarg=None,
                                    defaults=[]),
                body=stmts,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for var in loop_var_names:
                if "." in var:
                    RenameTransformer(func).rename(
                        var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func

        # A variable created inside the loop may be used after it:
        #
        #   for x in range(10):
        #       y += x
        #   print(x)  # x = 10
        #
        # so such variables are declared as static variables first.
        result = [
            create_static_variable_gast_node(var)
            for var in create_var_names if "." not in var
        ]
        result.append(init_stmt)

        # The range condition must be evaluated on static tensors, so the
        # loop vars are converted to static variables.
        result.extend(to_static_variable_gast_node(var)
                      for var in loop_var_names)

        cond_func = build_loop_function(FOR_CONDITION_PREFIX,
                                        [gast.Return(value=cond_stmt)])
        result.append(cond_func)

        # The original body list is reused and extended in place.
        node.body.append(change_stmt)
        node.body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func = build_loop_function(FOR_BODY_PREFIX, node.body)
        result.append(body_func)

        result.append(
            create_while_node(cond_func.name, body_func.name, loop_var_names))

        return result
예제 #24
0
파일: codegen.py 프로젝트: Harryi0/tinyML
 def generate_Return(self):
   """Build a ``return`` node whose value comes from ``generate_expression``."""
   returned_expr = self.generate_expression()
   return gast.Return(value=returned_expr)
예제 #25
0
    def visit_If(self, node):
        """Outline an ``if`` whose test is one of ``self.static_expressions``.

        Children are visited first. If ``node.test`` is not static, the node
        is returned unchanged. Otherwise each branch is outlined into a
        function (appended to ``self.new_functions``), and the ``if`` is
        replaced by a call to the dispatcher from ``self.make_dispatcher``:

        * If a branch contains a ``return``, the call result is unpacked into
          ``$status<n>, $return<n>, $cont<n>``. A truthy status returns
          ``$return<n>``; otherwise the assigned variables are unpacked from
          ``$cont<n>``.
        * Otherwise, if variables are assigned, they are unpacked directly
          from the call.
        * Otherwise the call becomes an expression statement.

        Returns:
            ``node``, a single replacement statement, or a list of them.
        """
        self.generic_visit(node)
        if node.test not in self.static_expressions:
            return node

        imported_ids = self.passmanager.gather(ImportedIds, node, self.ctx)

        # Build each branch's fake wrapper once and reuse it for every
        # analysis below.
        fake_body = self.make_fake(node.body)
        fake_orelse = self.make_fake(node.orelse)

        assigned_ids_left = set(
            self.passmanager.gather(IsAssigned, fake_body, self.ctx).keys())
        assigned_ids_right = set(
            self.passmanager.gather(IsAssigned, fake_orelse, self.ctx).keys())
        assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

        # Names assigned in only one branch are also passed in, presumably so
        # the other branch can hand back their incoming value.
        imported_ids.update(i for i in assigned_ids_left
                            if i not in assigned_ids_right)
        imported_ids.update(i for i in assigned_ids_right
                            if i not in assigned_ids_left)
        imported_ids = sorted(imported_ids)

        assigned_ids = sorted(assigned_ids_both)

        true_has_return = self.passmanager.gather(HasReturn, fake_body,
                                                  self.ctx)
        false_has_return = self.passmanager.gather(HasReturn, fake_orelse,
                                                   self.ctx)

        has_return = true_has_return or false_has_return

        func_true = outline(self.true_name(), imported_ids, assigned_ids,
                            node.body, has_return)
        func_false = outline(self.false_name(), imported_ids, assigned_ids,
                             node.orelse, has_return)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test, func_true, func_false,
                                           imported_ids)

        # These names are assignment targets, so they need a Store context.
        expected_return = [
            ast.Name(ii, ast.Store(), None) for ii in assigned_ids
        ]

        if has_return:
            n = len(self.new_functions)
            status_n = "$status{}".format(n)
            return_n = "$return{}".format(n)
            cont_n = "$cont{}".format(n)
            # Tuple-unpacking targets: Store context as well.
            fast_return = [
                ast.Name(status_n, ast.Store(), None),
                ast.Name(return_n, ast.Store(), None),
                ast.Name(cont_n, ast.Store(), None),
            ]

            if expected_return:
                cont_ass = [
                    ast.Assign([ast.Tuple(expected_return, ast.Store())],
                               ast.Name(cont_n, ast.Load(), None))
                ]
            else:
                cont_ass = []

            return [
                ast.Assign([ast.Tuple(fast_return, ast.Store())], actual_call),
                ast.If(ast.Name(status_n, ast.Load(), None), [
                    ast.Return(ast.Name(return_n, ast.Load(), None))
                ], cont_ass)
            ]
        elif expected_return:
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call)
        else:
            return ast.Expr(actual_call)
예제 #26
0
    def get_while_stmt_nodes(self, node):
        """Lower a ``while`` loop into static-graph nodes.

        Returns, in order:

        * static-variable declarations for undotted names first created in
          the loop;
        * a condition function returning ``node.test``;
        * a body function (the original body, extended in place with a
          return of the loop vars);
        * the nodes produced by ``create_while_nodes``.
        """
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)

        def build_loop_function(prefix, stmts):
            # `def <fresh name>(<loop vars>): <stmts>`; dotted loop var names
            # are then renamed to fresh plain names inside the function.
            params = [
                gast.Name(id=var, ctx=gast.Param(), annotation=None,
                          type_comment=None)
                for var in loop_var_names
            ]
            func = gast.FunctionDef(
                name=unique_name.generate(prefix),
                args=gast.arguments(args=params,
                                    posonlyargs=[],
                                    vararg=None,
                                    kwonlyargs=[],
                                    kw_defaults=None,
                                    kwarg=None,
                                    defaults=[]),
                body=stmts,
                decorator_list=[],
                returns=None,
                type_comment=None)
            for var in loop_var_names:
                if "." in var:
                    RenameTransformer(func).rename(
                        var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func

        # A variable created inside the loop may be used after it:
        #
        #   while x < 10:
        #       x += 1
        #       y = x
        #   z = y
        #
        # so such variables are declared as static variables first.
        result = [
            create_static_variable_gast_node(var)
            for var in create_var_names if "." not in var
        ]

        cond_func = build_loop_function(WHILE_CONDITION_PREFIX,
                                        [gast.Return(value=node.test)])
        result.append(cond_func)

        # The original body list is reused and extended in place.
        node.body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func = build_loop_function(WHILE_BODY_PREFIX, node.body)
        result.append(body_func)

        result.extend(
            create_while_nodes(cond_func.name, body_func.name, loop_var_names))
        return result
 def fill(self, hole, rng):
     """Fill ``hole`` with a statement list terminated by ``return None``.

     A single STMTS sub-hole, sharing ``hole.metadata``, is created. The
     callback passed to ``ASTWithHoles`` appends ``return None`` to the
     statements it receives (presumably the filled sub-hole). ``rng`` is
     not used here.
     """
     def append_bare_return(stmts):
         return stmts + [gast.Return(value=None)]

     body_hole = Hole(ASTHoleType.STMTS, hole.metadata)
     return ASTWithHoles(1, [body_hole], append_bare_return)