def outline(name, formal_parameters, out_parameters, stmts, has_return):
    """Wrap ``stmts`` into a new ``FunctionDef`` called ``name``.

    Args:
        name: identifier of the forged function.
        formal_parameters: identifiers that become its positional parameters.
        out_parameters: identifiers returned as a tuple at the end of the
            forged body (must be empty when ``stmts`` is an expression).
        stmts: either a single ``ast.expr`` (the function just returns it) or
            a list of statements used *in place* as the function body.
        has_return: whether ``stmts`` contains early returns that must be
            rewritten by ``PatchReturn``.

    Returns:
        The new ``FunctionDef`` node.

    Note: in the statement-list case ``stmts`` is mutated (a trailing
    ``Return`` is appended to it).
    """
    params = [ast.Name(fp, ast.Param(), None) for fp in formal_parameters]
    args = ast.arguments(params, None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        return ast.FunctionDef(name, args, [ast.Return(stmts)], [], None)

    fdef = ast.FunctionDef(name, args, stmts, [], None)

    # Part of a delayed-type-inference trick: the return type is computed
    # from the out parameters, and this return is added unconditionally.
    # When other (early) returns exist, the output type is the __combined
    # of their types and this one. PatchReturn rewrites those early returns
    # to a type that cunningly combines with this output tuple. It is the
    # only way found to let pythran compute both the output variables' type
    # and the early return type -- admittedly a dirty one.
    outputs = ast.Tuple(
        [ast.Name(fp, ast.Load(), None) for fp in out_parameters],
        ast.Load())
    trailing_return = ast.Return(outputs)
    stmts.append(trailing_return)

    if has_return:
        PatchReturn(trailing_return).visit(fdef)
    return fdef
def visit_Compare(self, node):
    """Outline a chained comparison into a short-circuiting helper function.

    A comparison with a single operator is returned unchanged (after its
    children are visited). A chain such as ``a < b < c`` is replaced by a
    call to a freshly forged function. That function receives every
    identifier the chain reads from the enclosing scope and checks each
    adjacent pair in turn. It returns ``__builtin__.False`` on the first
    failing pair and ``__builtin__.True`` otherwise. Operands that are not
    trivially copied are first stored in ``$<k>`` temporaries, so each one
    is evaluated at most once. The forged function is appended to
    ``self.compare_functions``.
    """
    node = self.generic_visit(node)
    if len(node.ops) <= 1:
        return node

    captured = sorted(self.passmanager.gather(ImportedIds, node, self.ctx))
    forged_name = "{0}_compare{1}".format(self.prefix,
                                          len(self.compare_functions))

    body = []

    def hold(operand, slot_name):
        # Reuse cheap operands as-is; spill the others into a temporary.
        if is_trivially_copied(operand):
            return operand
        body.append(ast.Assign([ast.Name(slot_name, ast.Store(), None)],
                               operand))
        return ast.Name(slot_name, ast.Load(), None)

    lhs = hold(node.left, '$0')
    for position, comparator in enumerate(node.comparators):
        rhs = hold(comparator, '${}'.format(position + 1))
        check = ast.Compare(lhs, [node.ops[position]], [rhs])
        on_failure = [ast.Return(path_to_attr(('__builtin__', 'False')))]
        body.append(ast.If(check, [ast.Pass()], on_failure))
        lhs = rhs
    body.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

    params = ast.arguments(
        [ast.Name(ident, ast.Param(), None) for ident in captured],
        None, [], [], None, [])
    self.compare_functions.append(
        ast.FunctionDef(forged_name, params, body, [], None))

    return ast.Call(ast.Name(forged_name, ast.Load(), None),
                    [ast.Name(ident, ast.Load(), None) for ident in captured],
                    [])
def visit_Lambda(self, node):
    """Lift a lambda expression into a forged module-level function.

    The lambda body becomes the body of a new ``FunctionDef`` named
    ``<prefix>_lambda<k>``. That function is appended to
    ``self.lambda_functions`` and registered in ``self.global_declarations``.
    Identifiers the lambda reads from its enclosing scope are prepended to
    its parameters. If there are any, the lambda is replaced by
    ``functools.partial(forged, *captured)``; otherwise it is replaced by a
    plain reference to the forged function.

    This version registers the ``functools`` import for every lambda,
    whether or not a partial is eventually needed.
    """
    if MODULES['functools'] not in self.global_declarations.values():
        import_ = ast.Import([ast.alias('functools', mangle('functools'))])
        self.imports.append(import_)
        functools_module = MODULES['functools']
        self.global_declarations[mangle('functools')] = functools_module

    # Visit children first so nested lambdas are already lifted here.
    self.generic_visit(node)
    forged_name = "{0}_lambda{1}".format(self.prefix,
                                         len(self.lambda_functions))

    ii = self.passmanager.gather(ImportedIds, node, self.ctx)
    # Lambdas lifted so far are module-level functions, not captured
    # variables. ``ii`` holds identifier strings while ``lambda_functions``
    # holds FunctionDef nodes, so subtract their *names*; subtracting the
    # nodes themselves never removes anything.
    ii.difference_update(fdef.name for fdef in self.lambda_functions)
    captured = sorted(ii)

    binded_args = [ast.Name(iin, ast.Load(), None) for iin in captured]
    node.args.args = ([ast.Name(iin, ast.Param(), None) for iin in captured]
                      + node.args.args)
    forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                  [ast.Return(node.body)], [], None)
    self.lambda_functions.append(forged_fdef)
    self.global_declarations[forged_name] = forged_fdef
    proxy_call = ast.Name(forged_name, ast.Load(), None)
    if binded_args:
        return ast.Call(
            ast.Attribute(ast.Name(mangle('functools'), ast.Load(), None),
                          "partial", ast.Load()),
            [proxy_call] + binded_args, [])
    return proxy_call
def visit_FunctionDef(self, node):
    """Visit a function, then add early-return bookkeeping if it needs it.

    When the entry popped from ``self.func_returned_stack`` is truthy, the
    function body is prefixed with::

        <returned_value_key> = None
        <returned_flag><depth> = False

    and suffixed with ``return <returned_value_key>``. ``depth`` is the
    stack length before the pop.

    Returns:
        Whatever ``generic_visit`` returned for ``node``.
    """
    modified_node = self.generic_visit(node)
    # The stack depth *before* popping identifies this function's flag.
    depth = len(self.func_returned_stack)
    if not self.func_returned_stack.pop():
        return modified_node

    def name_node(ident, ctx):
        return gast.Name(id=ident, ctx=ctx, annotation=None,
                         type_comment=None)

    value_init = gast.Assign(
        targets=[name_node(self.returned_value_key, gast.Store())],
        value=gast.Constant(value=None, kind=None))
    flag_init = gast.Assign(
        targets=[name_node(self.returned_flag + str(depth), gast.Store())],
        value=gast.Constant(value=False, kind=None))
    # Same final order as inserting flag_init then value_init at index 0.
    node.body[:0] = [value_init, flag_init]
    node.body.append(
        gast.Return(value=name_node(self.returned_value_key, gast.Load())))
    return modified_node
def test_ast_to_object(self):
    """ast_to_object compiles a gast FunctionDef into a loadable module.

    Checks three things: the returned source text, the behaviour of the
    compiled function, and the contents of the file backing the returned
    module.
    """
    a_param = gast.Name('a', gast.Param(), None)
    a_load = gast.Name('a', gast.Load(), None)
    signature = gast.arguments(
        args=[a_param],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    add_one = gast.BinOp(op=gast.Add(), left=a_load, right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(add_one)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    expected = textwrap.dedent("""
        # coding=utf-8
        def f(a):
          return a + 1
        """).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as emitted:
        self.assertEqual(expected, emitted.read().strip())
def visit_Lambda(self, node):
    """Replace a lambda by an operator reference or a lifted function.

    If ``issimpleoperator`` recognizes the lambda, it is replaced by
    ``<mangled operator module>.<op>``, and the ``operator`` import is
    registered on first use.

    Otherwise the lambda is lifted:

    * identifiers it captures from the enclosing scope are prepended to its
      parameters;
    * if a structurally identical lambda was already lifted (checked with
      ``issamelambda`` against ``self.patterns``), that function is reused;
    * otherwise a new ``<prefix>_lambd a<k>`` ``FunctionDef`` is forged,
      tagged ``metadata.Local()``, recorded in ``self.lambda_functions`` and
      ``self.global_declarations``, and its pattern is stored.

    When arguments were captured, the result is
    ``functools.partial(lifted, *captured)`` (registering the ``functools``
    import on first use); otherwise it is a plain reference to the lifted
    function.
    """
    op = issimpleoperator(node)
    if op is not None:
        if mangle('operator') not in self.global_declarations:
            import_ = ast.Import(
                [ast.alias('operator', mangle('operator'))])
            self.imports.append(import_)
            operator_module = MODULES['operator']
            self.global_declarations[mangle('operator')] = operator_module
        return ast.Attribute(
            ast.Name(mangle('operator'), ast.Load(), None, None),
            op, ast.Load())

    # Visit children first so nested lambdas are already lifted here.
    self.generic_visit(node)
    forged_name = "{0}_lambda{1}".format(self.prefix,
                                         len(self.lambda_functions))

    ii = self.gather(ImportedIds, node)
    # Lambdas lifted so far are module-level functions, not captured
    # variables. ``ii`` holds identifier strings while ``lambda_functions``
    # holds FunctionDef nodes, so subtract their *names*; subtracting the
    # nodes themselves never removes anything.
    ii.difference_update(fdef.name for fdef in self.lambda_functions)
    captured = sorted(ii)

    binded_args = [ast.Name(iin, ast.Load(), None, None) for iin in captured]
    node.args.args = (
        [ast.Name(iin, ast.Param(), None, None) for iin in captured]
        + node.args.args)

    for patternname, pattern in self.patterns.items():
        if issamelambda(pattern, node):
            # An identical lambda was already lifted: reuse it.
            proxy_call = ast.Name(patternname, ast.Load(), None, None)
            break
    else:
        duc = ExtendedDefUseChains()
        nodepattern = deepcopy(node)
        duc.visit(ast.Module([ast.Expr(nodepattern)], []))
        self.patterns[forged_name] = nodepattern, duc

        forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                      [ast.Return(node.body)], [], None, None)
        metadata.add(forged_fdef, metadata.Local())
        self.lambda_functions.append(forged_fdef)
        self.global_declarations[forged_name] = forged_fdef
        proxy_call = ast.Name(forged_name, ast.Load(), None, None)

    if not binded_args:
        return proxy_call

    if MODULES['functools'] not in self.global_declarations.values():
        import_ = ast.Import(
            [ast.alias('functools', mangle('functools'))])
        self.imports.append(import_)
        functools_module = MODULES['functools']
        self.global_declarations[mangle('functools')] = functools_module

    return ast.Call(
        ast.Attribute(
            ast.Name(mangle('functools'), ast.Load(), None, None),
            "partial", ast.Load()),
        [proxy_call] + binded_args, [])
def create_funcDef_node(nodes, name, input_args, return_name_ids):
    """Wrap ``nodes`` into one ``gast.FunctionDef`` callable via ``ast.Call``.

    Args:
        nodes: statements forming the function body. A shallow copy is
            made, so the caller's sequence is not modified.
        name: name of the generated function.
        input_args: the ``gast.arguments`` node used as its signature.
        return_name_ids: names to return; if falsy, a bare ``return`` is
            appended instead.

    Returns:
        The new ``gast.FunctionDef`` node.
    """
    if return_name_ids:
        returned = generate_name_node(return_name_ids)
    else:
        returned = None

    body = copy.copy(nodes)
    body.append(gast.Return(value=returned))

    return gast.FunctionDef(
        name=name,
        args=input_args,
        body=body,
        decorator_list=[],
        returns=None,
        type_comment=None)
def visit_FunctionDef(self, node):
    """Make sure every control path through the function ends explicitly.

    After visiting the body, the predecessors of the CFG's ``NIL`` node are
    checked. If any of them is neither a ``Return`` nor a ``Raise``, control
    can fall off the end. In that case a terminating return is appended:
    a bare ``return`` for generators (functions with yield points), or
    ``return __builtin__.None`` otherwise. ``self.update`` is then set.
    """
    self.yield_points = self.gather(YieldPoints, node)
    for stmt in node.body:
        self.visit(stmt)

    falls_through = any(
        not isinstance(pred, (ast.Return, ast.Raise))
        for pred in self.cfg.predecessors(CFG.NIL))
    if not falls_through:
        return node

    self.update = True
    if self.yield_points:
        node.body.append(ast.Return(None))
    else:
        builtin_none = ast.Attribute(
            ast.Name("__builtin__", ast.Load(), None, None),
            'None', ast.Load())
        node.body.append(ast.Return(builtin_none))
    return node
def visit_Return(self, node):
    """Wrap a returned value in a ``__builtin__.pythran.StaticIf*`` marker.

    The guard return (``self.guard``) is wrapped in ``StaticIfNoReturn``;
    every other return is wrapped in ``StaticIfReturn``.

    A bare ``return`` has ``node.value`` set to ``None``. Passing that
    straight through would put ``None`` into ``Call.args``, which is an
    invalid AST. An explicit ``__builtin__.None`` is used instead.
    """
    if node is self.guard:
        holder = "StaticIfNoReturn"
    else:
        holder = "StaticIfReturn"

    value = node.value
    if value is None:
        value = ast.Attribute(ast.Name("__builtin__", ast.Load(), None),
                              "None", ast.Load())

    return ast.Return(
        ast.Call(
            ast.Attribute(
                ast.Attribute(ast.Name("__builtin__", ast.Load(), None),
                              "pythran", ast.Load()),
                holder,
                ast.Load()),
            [value], []))
def __init__(self, astc, args, func_field):
    """Capture a function or lambda AST together with its call arguments.

    Args:
        astc: container exposing ``nast`` (a ``gast.FunctionDef`` or
            ``gast.Lambda``), ``gast``, ``filename`` and ``lineno``.
        args: stored as-is on ``self.args``.
        func_field: stored as-is on ``self.func_field``.

    For a lambda, ``self.name`` is ``'<lambda>'`` and the lambda's body is
    wrapped in a ``gast.Return`` in place.
    """
    super().__init__()
    tree = astc.nast
    assert isinstance(tree, (gast.FunctionDef, gast.Lambda))

    if isinstance(tree, gast.FunctionDef):
        self.name = astc.gast.name
    else:
        self.name = (lambda: None).__name__  # the string '<lambda>'

    self.args = args
    self.func_field = func_field

    if isinstance(tree, gast.Lambda):
        # NOTE(review): body becomes a single Return node rather than a
        # list, unlike FunctionDef.body -- confirm consumers expect this.
        tree.body = gast.Return(value=tree.body)

    self.ast = tree
    self.filename = astc.filename
    self.lineno = astc.lineno
def visit_Return(self, node):
    """Wrap a returned value in a ``builtins.pythran.StaticIf*`` marker.

    The guard return (``self.guard``) gets ``StaticIfNoReturn``; all other
    returns get ``StaticIfReturn``. A missing return value is replaced by
    an explicit ``None`` constant, so the call always has one argument.
    """
    holder = "StaticIfNoReturn" if node is self.guard else "StaticIfReturn"

    pythran_ns = ast.Attribute(
        ast.Name("builtins", ast.Load(), None, None), "pythran", ast.Load())
    marker = ast.Attribute(pythran_ns, holder, ast.Load())

    payload = node.value if node.value else ast.Constant(None, None)
    return ast.Return(ast.Call(marker, [payload], []))
def visit_AnyComp(self, node, comp_type, *path):
    """Outline a comprehension into a dedicated module-level function.

    The forged function ``<comp_type>_comprehension<k>``:

    1. initialises ``__target = builtins.<comp_type>()``;
    2. runs the comprehension's generators as nested loops (built by
       ``self.nest_reducer``), calling ``path[0].path[1]...(__target, elt)``
       for every element;
    3. returns ``__target``.

    Its parameters are the identifiers the comprehension imports from the
    enclosing scope. The function is appended to ``self.ctx.module.body``
    and the comprehension is replaced by a call to it.
    """
    self.update = True
    node.elt = self.visit(node.elt)
    name = "{0}_comprehension{1}".format(comp_type, self.count)
    self.count += 1
    args = self.gather(ImportedIds, node)
    self.count_iter = 0

    starget = "__target"

    # Build the `path[0].path[1]...` accessor used to insert one element.
    inserter = ast.Name(path[0], ast.Load(), None, None)
    for attr in path[1:]:
        inserter = ast.Attribute(inserter, attr, ast.Load())
    insertion = ast.Expr(
        ast.Call(inserter,
                 [ast.Name(starget, ast.Load(), None, None), node.elt],
                 []))

    # Wrap the insertion in one loop per generator, innermost first.
    body = insertion
    for generator in reversed(node.generators):
        body = self.nest_reducer(body, generator)
    # add extra metadata to this node
    metadata.add(body, metadata.Comprehension(starget))

    container_ctor = ast.Attribute(
        ast.Name('builtins', ast.Load(), None, None), comp_type, ast.Load())
    init = ast.Assign([ast.Name(starget, ast.Store(), None, None)],
                      ast.Call(container_ctor, [], []))
    result = ast.Return(ast.Name(starget, ast.Load(), None, None))

    sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
    signature = ast.arguments(sargs, [], None, [], [], None, [])
    fd = ast.FunctionDef(name, signature, [init, body, result],
                         [], None, None)
    metadata.add(fd, metadata.Local())
    self.ctx.module.body.append(fd)

    # Fresh Name nodes at the call site: AST nodes must not be shared.
    call_args = [ast.Name(arg.id, ast.Load(), None, None) for arg in sargs]
    return ast.Call(ast.Name(name, ast.Load(), None, None), call_args, [])
def test_load_ast(self):
    """load_ast turns a gast FunctionDef into an importable module.

    Checks three things: the generated source, the behaviour of the loaded
    function, and the contents of the module's backing file.
    """
    def make_name(ident, ctx):
        return gast.Name(ident, ctx=ctx, annotation=None, type_comment=None)

    signature = gast.arguments(
        args=[make_name('a', gast.Param())],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    add_one = gast.BinOp(
        op=gast.Add(),
        left=make_name('a', gast.Load()),
        right=gast.Constant(1, kind=None))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(add_one)],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected = textwrap.dedent("""
        # coding=utf-8
        def f(a):
            return (a + 1)
        """).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as emitted:
        self.assertEqual(expected, emitted.read().strip())
def visit_Module(self, node):
    """Turn globals assignment to functionDef and visit function defs.

    Every top-level ``name = value`` becomes a zero-argument function
    ``def name(): return value``. Its return is tagged with
    ``metadata.StaticReturn()``. All other top-level statements are visited
    normally.

    Raises:
        PythranSyntaxError: on a top-level assignment to something other
            than a plain name, or when a name is assigned more than once.
    """
    # First pass: collect (and validate) every top-level assigned name.
    for stmt in (s for s in node.body if isinstance(s, ast.Assign)):
        for target in stmt.targets:
            if not isinstance(target, ast.Name):
                raise PythranSyntaxError(
                    "Top-level assignment to an expression.", target)
            if target.id in self.to_expand:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % target.id,
                    target)
            self.to_expand.add(target.id)

    # Second pass: rebuild the module body.
    new_body = []
    for stmt in node.body:
        if not isinstance(stmt, ast.Assign):
            self.local_decl = self.passmanager.gather(
                LocalNameDeclarations, stmt, self.ctx)
            new_body.append(self.visit(stmt))
            continue

        self.local_decl = set()
        cst_value = self.visit(stmt.value)
        for target in stmt.targets:
            assert isinstance(target, ast.Name)
            # NOTE(review): with several targets (a = b = ...) every getter
            # shares the same cst_value node -- confirm this is safe.
            getter = ast.FunctionDef(
                target.id,
                ast.arguments([], None, [], [], None, []),
                [ast.Return(value=cst_value)], [], None)
            metadata.add(getter.body[0], metadata.StaticReturn())
            new_body.append(getter)

    node.body = new_body
    return node
def fill(self, hole, rng):
    """Fill ``hole`` with "<statements>; return <number>".

    Produces an ``ASTWithHoles`` of cost 1 with two new holes: a statement
    list and a number. Both inherit ``hole.metadata``. ``rng`` is unused:
    this expansion is deterministic.
    """
    meta = hole.metadata
    sub_holes = [Hole(ASTHoleType.STMTS, meta),
                 Hole(ASTHoleType.NUMBER, meta)]

    def build(stmts, v):
        # Append a return of the filled number to the filled statements.
        return stmts + [gast.Return(value=v)]

    return ASTWithHoles(1, sub_holes, build)
class PyASTGraphsTest(parameterized.TestCase):
    """Tests for converting Python (gast) ASTs into graphs with py_ast_graphs."""

    def test_schema(self):
        """Check that the generated schema is reasonable."""
        # Building should succeed
        schema = py_ast_graphs.SCHEMA

        # Should have one graph node type for each AST node type, plus one
        # sequence-helper type per sequence field (listed in sorted order).
        seq_helpers = [
            "BoolOp_values-seq-helper",
            "Call_args-seq-helper",
            "For_body-seq-helper",
            "FunctionDef_body-seq-helper",
            "If_body-seq-helper",
            "If_orelse-seq-helper",
            "Module_body-seq-helper",
            "While_body-seq-helper",
            "arguments_args-seq-helper",
        ]
        expected_keys = [
            *py_ast_graphs.PY_AST_SPECS.keys(),
            *seq_helpers,
        ]
        self.assertEqual(list(schema.keys()), expected_keys)

        # Check a few elements
        expected_fundef_in_edges = {
            "parent_in",
            "returns_in",
            "returns_missing",
            "body_in",
            "args_in",
        }
        expected_fundef_out_edges = {
            "parent_out",
            "returns_out",
            "body_out_all",
            "body_out_first",
            "body_out_last",
            "args_out",
        }
        self.assertEqual(set(schema["FunctionDef"].in_edges),
                         expected_fundef_in_edges)
        self.assertEqual(set(schema["FunctionDef"].out_edges),
                         expected_fundef_out_edges)

        expected_expr_in_edges = {"parent_in", "value_in"}
        expected_expr_out_edges = {"parent_out", "value_out"}
        self.assertEqual(set(schema["Expr"].in_edges), expected_expr_in_edges)
        self.assertEqual(set(schema["Expr"].out_edges), expected_expr_out_edges)

        # Every sequence helper shares the same edge vocabulary.
        expected_seq_helper_in_edges = {
            "parent_in",
            "item_in",
            "next_in",
            "next_missing",
            "prev_in",
            "prev_missing",
        }
        expected_seq_helper_out_edges = {
            "parent_out",
            "item_out",
            "next_out",
            "prev_out",
        }
        for seq_helper in seq_helpers:
            self.assertEqual(set(schema[seq_helper].in_edges),
                             expected_seq_helper_in_edges)
            self.assertEqual(set(schema[seq_helper].out_edges),
                             expected_seq_helper_out_edges)

    def test_ast_graph_conforms_to_schema(self):
        """A graph built from varied example code must match SCHEMA."""
        # Some example code using a few different syntactic constructs, to cover
        # a large set of nodes in the schema
        root = gast.parse(
            textwrap.dedent("""\
                def foo(n):
                  if n <= 1:
                    return 1
                  else:
                    return foo(n-1) + foo(n-2)
                def bar(m, n) -> int:
                  x = n
                  for i in range(m):
                    if False:
                      continue
                    x = x + i
                  while True:
                    break
                  return x
                x0 = 1 + 2 - 3 * 4 / 5
                x1 = (1 == 2) and (3 < 4) and (5 > 6)
                x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
                x2 = bar(13, 14 + 15)
                """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Graph should match the schema
        schema_util.assert_conforms_to_schema(graph, py_ast_graphs.SCHEMA)

    def test_ast_graph_nodes(self):
        """Check node IDs, node types, and forward mapping."""
        root = gast.parse(
            textwrap.dedent("""\
                pass
                def foo(n):
                  if n <= 1:
                    return 1
                """))

        graph, forward_map = py_ast_graphs.py_ast_to_graph(root)

        # Node IDs encode the path from the root (field names and indices),
        # followed by "__" and the node type; forward_map is keyed by id().
        # pytype: disable=attribute-error
        self.assertIn("root__Module", graph)
        self.assertEqual(graph["root__Module"].node_type, "Module")
        self.assertEqual(forward_map[id(root)], "root__Module")

        self.assertIn("root_body_1__Module_body-seq-helper", graph)
        self.assertEqual(
            graph["root_body_1__Module_body-seq-helper"].node_type,
            "Module_body-seq-helper")

        self.assertIn("root_body_1_item_body_0_item__If", graph)
        self.assertEqual(graph["root_body_1_item_body_0_item__If"].node_type,
                         "If")
        self.assertEqual(forward_map[id(root.body[1].body[0])],
                         "root_body_1_item_body_0_item__If")

        self.assertIn("root_body_1_item_body_0_item_test_left__Name", graph)
        self.assertEqual(
            graph["root_body_1_item_body_0_item_test_left__Name"].node_type,
            "Name")
        self.assertEqual(forward_map[id(root.body[1].body[0].test.left)],
                         "root_body_1_item_body_0_item_test_left__Name")
        # pytype: enable=attribute-error

    def test_ast_graph_unique_field_edges(self):
        """Test that edges for unique fields are correct."""
        root = gast.parse("print(1)")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Parent -> child along the field, and child -> parent back.
        self.assertEqual(graph["root_body_0_item__Expr"].out_edges["value_out"],
                         [
                             graph_types.InputTaggedNode(
                                 node_id=graph_types.NodeId(
                                     "root_body_0_item_value__Call"),
                                 in_edge=graph_types.InEdgeType("parent_in"))
                         ])

        self.assertEqual(
            graph["root_body_0_item_value__Call"].out_edges["parent_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Expr"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

    def test_ast_graph_optional_field_edges(self):
        """Test that edges for optional fields are correct."""
        root = gast.parse("return 1\nreturn")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Present optional value: ordinary parent/child edges.
        self.assertEqual(
            graph["root_body_0_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId(
                        "root_body_0_item_value__Constant"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])

        self.assertEqual(
            graph["root_body_0_item_value__Constant"].out_edges["parent_out"],
            [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Return"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

        # Missing optional value: the edge loops back to the node itself
        # through the "value_missing" input.
        self.assertEqual(
            graph["root_body_1_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_1_item__Return"),
                    in_edge=graph_types.InEdgeType("value_missing"))
            ])

    def test_ast_graph_sequence_field_edges(self):
        """Test that edges for sequence fields are correct.

        Note that sequence fields produce connections between three nodes: the
        parent, the helper node, and the child.
        """
        root = gast.parse(
            textwrap.dedent("""\
                print(1)
                print(2)
                print(3)
                print(4)
                print(5)
                print(6)
                """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Child edges from the parent node
        node = graph["root__Module"]
        self.assertLen(node.out_edges["body_out_all"], 6)
        self.assertEqual(node.out_edges["body_out_first"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["body_out_last"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_5__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])

        # Edges from the sequence helper
        node = graph["root_body_0__Module_body-seq-helper"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root__Module"),
                in_edge=graph_types.InEdgeType("body_in"))
        ])
        self.assertEqual(node.out_edges["item_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root_body_0_item__Expr"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        # The first helper has no predecessor: prev_out loops back to itself
        # through "prev_missing".
        self.assertEqual(node.out_edges["prev_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_missing"))
        ])
        self.assertEqual(node.out_edges["next_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_1__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_in"))
        ])

        # Parent edge of the item
        node = graph["root_body_0_item__Expr"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("item_in"))
        ])

    @parameterized.named_parameters(
        {
            "testcase_name": "unexpected_type",
            "ast": gast.Subscript(value=None, slice=None, ctx=None),
            "expected_error": "Unknown AST node type 'Subscript'",
        }, {
            "testcase_name":
                "too_many_unique",
            "ast":
                gast.Assign(
                    targets=[
                        gast.Name("foo", gast.Store(), None, None),
                        gast.Name("bar", gast.Store(), None, None)
                    ],
                    value=gast.Constant(True, None)),
            "expected_error":
                "Expected 1 child for field 'targets' of node .*; got 2",
        }, {
            "testcase_name":
                "missing_unique",
            "ast":
                gast.Assign(targets=[], value=gast.Constant(True, None)),
            "expected_error":
                "Expected 1 child for field 'targets' of node .*; got 0",
        }, {
            "testcase_name":
                "too_many_optional",
            "ast":
                gast.Return(value=[
                    gast.Name("foo", gast.Load(), None, None),
                    gast.Name("bar", gast.Load(), None, None)
                ]),
            "expected_error":
                "Expected at most 1 child for field 'value' of node .*; got 2",
        })
    def test_invalid_graphs(self, ast, expected_error):
        """Malformed or unsupported ASTs must raise ValueError."""
        with self.assertRaisesRegex(ValueError, expected_error):
            py_ast_graphs.py_ast_to_graph(ast)
def visit_FunctionDef(self, node):
    """Build the primal and adjoint functions for reverse-mode AD.

    The function must end with its only ``return`` statement. The primal is
    ``node`` itself, modified in place: it gets ``self.stack`` as its first
    argument, a primal name, and the forward-pass body. The adjoint is
    instantiated from ``grads.adjoints[gast.FunctionDef]``. Its arguments
    are ``self.stack``, the gradient of the cost (``dy``), and the original
    arguments. It returns the gradients w.r.t. ``self.wrt`` (plus the stored
    result when ``self.preserve_result`` is set).

    Returns:
        A ``(primal, adjoint)`` tuple of FunctionDef nodes.

    Raises:
        ValueError: if the function has more than one return statement or
            does not end with one.
    """
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    # (gast.walk also descends into nested definitions, so their returns
    # count towards this limit too).
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or
            not isinstance(node.body[-1], gast.Return)):
        raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
        body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
        adjoint_body[0] = comments.add_comment(
            adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt],
                    ctx=gast.Load())

    if self.preserve_result:
        # Append an extra Assign operation to the primal body
        # that saves the original output value
        stored_result_node = quoting.quote(self.namer.unique('result'))
        assign_stored_result = template.replace(
            'result=orig_result',
            result=stored_result_node,
            orig_result=return_node.value)
        body.append(assign_stored_result)
        dx.elts.append(stored_result_node)

    # The returned gradients are read, so force a Load context on each.
    for _dx in dx.elts:
        _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return values
    # The adjoint will receive a value for the initial gradient of the cost
    # NOTE(review): assumes the returned value (or its first element) is a
    # plain gast.Name, since `.id` is read -- confirm upstream guarantees it.
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
        y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body, return_dx=return_dx)
    adjoint.args.args.extend([self.stack, dy])
    # [1:] skips the stack that was prepended to the primal's arguments.
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
def visit_Module(self, node):
    """Turn globals assignment to functionDef and visit function defs.

    First pass: records imported names and function names in ``symbols``,
    and collects every top-level assigned name in ``self.to_expand``. An
    assignment of the form ``a = <symbol>`` is treated as aliasing and is
    not collected.

    Second pass: an assignment whose targets were all skipped is kept
    verbatim. Other assignments become zero-argument getter functions
    ``def name(): return value``, with the return tagged
    ``metadata.StaticReturn()``. Any other statement is visited normally.
    ``self.update`` is set if anything was expanded.

    Raises:
        PythranSyntaxError: on a duplicate top-level function definition, a
            top-level assignment to a non-name target, or a name assigned
            more than once.
    """
    module_body = list()
    symbols = set()
    # Gather top level assigned variables.
    for stmt in node.body:
        if isinstance(stmt, (ast.Import, ast.ImportFrom)):
            for alias in stmt.names:
                name = alias.asname or alias.name
                symbols.add(name)  # no warning here
        elif isinstance(stmt, ast.FunctionDef):
            if stmt.name in symbols:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % stmt.name,
                    stmt)
            else:
                symbols.add(stmt.name)

        if not isinstance(stmt, ast.Assign):
            continue

        for target in stmt.targets:
            if not isinstance(target, ast.Name):
                raise PythranSyntaxError(
                    "Top-level assignment to an expression.",
                    target)
            if target.id in self.to_expand:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % target.id,
                    target)
            if isinstance(stmt.value, ast.Name):
                if stmt.value.id in symbols:
                    # This `continue` skips only this target (inner loop).
                    continue  # create aliasing between top level symbols
            self.to_expand.add(target.id)

    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            # that's not a global var, but a module/function aliasing
            if all(isinstance(t, ast.Name) and t.id not in self.to_expand
                   for t in stmt.targets):
                module_body.append(stmt)
                continue

            self.local_decl = set()
            cst_value = GlobalTransformer().visit(self.visit(stmt.value))
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                # NOTE(review): with several targets every getter shares the
                # same cst_value node -- confirm this is safe downstream.
                module_body.append(
                    ast.FunctionDef(target.id,
                                    ast.arguments([], [], None, [], [], None, []),
                                    [ast.Return(value=cst_value)],
                                    [], None, None))
                metadata.add(module_body[-1].body[0],
                             metadata.StaticReturn())
        else:
            self.local_decl = self.gather(
                LocalNameDeclarations, stmt)
            module_body.append(self.visit(stmt))

    self.update |= bool(self.to_expand)

    node.body = module_body
    return node
def visit_FunctionDef(self, node):
    """Rewrite a function so that it has a single final ``return``.

    ``self.generic_visit`` is re-run until ``ReturnAnalysisVisitor`` reports
    at most one return. Then, unless the function returns no values, the
    body is wrapped so that it looks like::

        <no-value placeholders>   # one per self.return_no_value_name[node]
        <return flags = False>    # one per self.return_name[node]
        <init_k = 0.0 ...>        # only if a return value name was assigned
        <value_name = init(s)>
        ...original (rewritten) body...
        return <value_name>

    Each prologue group is built with repeated ``insert(0, ...)``, so each
    group ends up in reverse iteration order.

    Returns:
        ``node``, modified in place.
    """
    self.function_def.append(node)
    self.return_value_name[node] = None
    self.return_name[node] = []
    self.return_no_value_name[node] = []

    self.pre_analysis = ReturnAnalysisVisitor(node)
    # Computed once, before the rewriting loop below.
    max_return_length = self.pre_analysis.get_func_max_return_length(node)
    # NOTE(review): relies on each generic_visit pass reducing the return
    # count; otherwise this loop does not terminate.
    while self.pre_analysis.get_func_return_count(node) > 1:
        self.generic_visit(node)
        self.pre_analysis = ReturnAnalysisVisitor(node)

    if max_return_length == 0:
        self.function_def.pop()
        return node

    # Prepend initialization of final return and append final return statement
    value_name = self.return_value_name[node]
    if value_name is not None:
        node.body.append(
            gast.Return(value=gast.Name(
                id=value_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None)))
        init_names = [
            unique_name.generate(RETURN_VALUE_INIT_NAME)
            for i in range(max_return_length)
        ]
        assign_zero_nodes = [
            create_fill_constant_node(iname, 0.0) for iname in init_names
        ]
        if len(init_names) == 1:
            return_value_nodes = gast.Name(
                id=init_names[0],
                ctx=gast.Load(),
                annotation=None,
                type_comment=None)
        else:
            # We need to initialize return value as a tuple because control
            # flow requires some inputs or outputs have same structure
            return_value_nodes = gast.Tuple(
                elts=[
                    gast.Name(
                        id=iname,
                        ctx=gast.Load(),
                        annotation=None,
                        type_comment=None) for iname in init_names
                ],
                ctx=gast.Load())
        assign_return_value_node = gast.Assign(
            targets=[
                gast.Name(
                    id=value_name,
                    ctx=gast.Store(),
                    annotation=None,
                    type_comment=None)
            ],
            value=return_value_nodes)
        node.body.insert(0, assign_return_value_node)
        # The zero initialisations go before the value assignment that reads them.
        node.body[:0] = assign_zero_nodes

    # Prepend control flow boolean nodes such as '__return@1 = False'
    for name in self.return_name[node]:
        assign_false_node = create_fill_constant_node(name, False)
        node.body.insert(0, assign_false_node)
    # Prepend no value placeholders
    for name in self.return_no_value_name[node]:
        assign_no_value_node = create_fill_constant_node(
            name, RETURN_NO_VALUE_MAGIC_NUM)
        node.body.insert(0, assign_no_value_node)

    self.function_def.pop()
    return node
def get_for_stmt_nodes(self, node):
    """Lower a supported ``for`` loop into condition/body functions plus a while.

    Returns a list of statements:

    1. static-variable declarations for names created inside the loop;
    2. the loop's init statements;
    3. a condition ``FunctionDef`` returning ``cond_stmt``;
    4. a body ``FunctionDef`` returning the loop variables;
    5. the while-loop node built by ``create_while_node``.

    Both functions take the loop variables as parameters. Dotted names are
    renamed to fresh identifiers inside them. If ``ForNodeVisitor`` cannot
    parse the loop, ``[node]`` is returned unchanged.
    """
    # TODO: consider for - else in python

    # 1. get key statements for different cases
    # NOTE 1: three key statements:
    #   1). init_stmts: list[node], prepare nodes of for loop, may not only one
    #   2). cond_stmt: node, condition node to judge whether continue loop
    #   3). body_stmts: list[node], updated loop body, sometimes we should change
    #       the original statement in body, not just append new statement
    #
    # NOTE 2: The following `for` statements will be transformed to `while` statements:
    #   1). for x in range(*)
    #   2). for x in iter_var
    #   3). for i, x in enumerate(*)

    current_for_node_parser = ForNodeVisitor(node)
    stmts_tuple = current_for_node_parser.parse()
    if stmts_tuple is None:
        return [node]
    init_stmts, cond_stmt, body_stmts = stmts_tuple

    # 2. get original loop vars
    loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
        node)
    # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
    # we need append new loop var & remove useless loop var
    #   1. for x in var -> x is no need
    #   2. for i, x in enumerate(var) -> x is no need
    if current_for_node_parser.is_for_iter(
    ) or current_for_node_parser.is_for_enumerate_iter():
        iter_var_name = current_for_node_parser.iter_var_name
        iter_idx_name = current_for_node_parser.iter_idx_name
        loop_var_names.add(iter_idx_name)
        # NOTE(review): set.remove raises KeyError if iter_var_name is not
        # among the loop vars -- confirm get_loop_var_names always adds it.
        if iter_var_name not in create_var_names:
            loop_var_names.remove(iter_var_name)

    # 3. prepare result statement list
    new_stmts = []
    # Python can create variable in loop and use it out of loop, E.g.
    #
    # for x in range(10):
    #     y += x
    # print(x) # x = 10
    #
    # We need to create static variable for those variables
    for name in create_var_names:
        if "." not in name:
            new_stmts.append(create_static_variable_gast_node(name))

    # 4. append init statements
    new_stmts.extend(init_stmts)

    # 5. create & append condition function node
    condition_func_node = gast.FunctionDef(
        name=unique_name.generate(FOR_CONDITION_PREFIX),
        args=gast.arguments(
            args=[
                gast.Name(
                    id=name,
                    ctx=gast.Param(),
                    annotation=None,
                    type_comment=None) for name in loop_var_names
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=None,
            kwarg=None,
            defaults=[]),
        body=[gast.Return(value=cond_stmt)],
        decorator_list=[],
        returns=None,
        type_comment=None)
    # Dotted names (e.g. attributes) cannot be parameters; rename them.
    for name in loop_var_names:
        if "." in name:
            rename_transformer = RenameTransformer(condition_func_node)
            rename_transformer.rename(
                name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
    new_stmts.append(condition_func_node)

    # 6. create & append loop body function node
    # append return values for loop body
    body_stmts.append(
        gast.Return(value=generate_name_node(
            loop_var_names, ctx=gast.Load())))
    body_func_node = gast.FunctionDef(
        name=unique_name.generate(FOR_BODY_PREFIX),
        args=gast.arguments(
            args=[
                gast.Name(
                    id=name,
                    ctx=gast.Param(),
                    annotation=None,
                    type_comment=None) for name in loop_var_names
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=None,
            kwarg=None,
            defaults=[]),
        body=body_stmts,
        decorator_list=[],
        returns=None,
        type_comment=None)
    for name in loop_var_names:
        if "." in name:
            rename_transformer = RenameTransformer(body_func_node)
            rename_transformer.rename(
                name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
    new_stmts.append(body_func_node)

    # 7. create & append while loop node
    while_loop_node = create_while_node(condition_func_node.name,
                                        body_func_node.name, loop_var_names)
    new_stmts.append(while_loop_node)

    return new_stmts
def visit_If(self, node):
    """Outline an ``if`` whose test is a static expression.

    Tests not in ``self.static_expressions`` are visited normally and the
    node is returned as-is. Otherwise both branches are outlined into
    generated functions via ``outline``. The ``if`` is then replaced by a
    call to ``self.make_dispatcher(...)`` whose result is unpacked into:

    - ``$statusN, $returnN, $contN`` when a branch contains ``return``;
    - ``$statusN, $contN`` when a branch contains ``break``/``continue``;
    - the names assigned by the branches, otherwise;
    - nothing (bare expression statement) when no name is assigned.

    ``N`` is ``len(self.new_functions)`` after the two new functions have
    been recorded.

    Returns:
        A single replacement node or a list of replacement statements.
    """
    if node.test not in self.static_expressions:
        return self.generic_visit(node)

    imported_ids = self.gather(ImportedIds, node)

    # names reported by ``escaping_ids`` for each branch (presumably
    # assignments that are still visible after the ``if``)
    assigned_ids_left = self.escaping_ids(node, node.body)
    assigned_ids_right = self.escaping_ids(node, node.orelse)
    assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

    # A name assigned in only one branch is also passed in as a parameter,
    # presumably so the other branch can hand back its incoming value.
    imported_ids.update(i for i in assigned_ids_left
                        if i not in assigned_ids_right)
    imported_ids.update(i for i in assigned_ids_right
                        if i not in assigned_ids_left)
    imported_ids = sorted(imported_ids)

    assigned_ids = sorted(assigned_ids_both)

    # control-flow analyses on wrapped copies of each branch
    fbody = self.make_fake(node.body)
    true_has_return = self.gather(HasReturn, fbody)
    true_has_break = self.gather(HasBreak, fbody)
    true_has_cont = self.gather(HasContinue, fbody)

    felse = self.make_fake(node.orelse)
    false_has_return = self.gather(HasReturn, felse)
    false_has_break = self.gather(HasBreak, felse)
    false_has_cont = self.gather(HasContinue, felse)

    has_return = true_has_return or false_has_return
    has_break = true_has_break or false_has_break
    has_cont = true_has_cont or false_has_cont

    # Visit children (e.g. nested static ifs) only after the analyses above,
    # and before the branch bodies are outlined.
    self.generic_visit(node)

    func_true = outline(self.true_name(), imported_ids, assigned_ids,
                        node.body, has_return, has_break, has_cont)
    func_false = outline(self.false_name(), imported_ids, assigned_ids,
                         node.orelse, has_return, has_break, has_cont)
    self.new_functions.extend((func_true, func_false))

    actual_call = self.make_dispatcher(node.test,
                                       func_true, func_false, imported_ids)

    # variable modified within the static_if
    expected_return = [ast.Name(ii, ast.Store(), None, None)
                       for ii in assigned_ids]

    # NOTE(review): presumably marks that this pass modified the tree.
    self.update = True

    # name for various variables resulting from the static_if
    n = len(self.new_functions)
    status_n = "$status{}".format(n)
    return_n = "$return{}".format(n)
    cont_n = "$cont{}".format(n)

    if has_return:
        cfg = self.cfgs[-1]
        # Every CFG successor of this ``if`` is a Return/Yield and both
        # branches return: the outlined return value can be returned
        # unconditionally after the call.
        always_return = all(isinstance(x, (ast.Return, ast.Yield))
                            for x in cfg[node])
        always_return &= true_has_return and false_has_return

        fast_return = [ast.Name(status_n, ast.Store(), None, None),
                       ast.Name(return_n, ast.Store(), None, None),
                       ast.Name(cont_n, ast.Store(), None, None)]

        if always_return:
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None),
                    ast.Return(ast.Name(return_n, ast.Load(), None, None))]
        else:
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)

            # early return when the status equals EARLY_RET, otherwise
            # continue with the handlers built above
            cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(EARLY_RET, None)])
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None),
                    ast.If(cmpr,
                           [ast.Return(ast.Name(return_n, ast.Load(),
                                                None, None))],
                           cont_ass)]
    elif has_break or has_cont:
        cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                   expected_return,
                                                   has_cont, has_break)

        fast_return = [ast.Name(status_n, ast.Store(), None, None),
                       ast.Name(cont_n, ast.Store(), None, None)]
        return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                           actual_call, None)] + cont_ass
    elif expected_return:
        return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                          actual_call, None)
    else:
        return ast.Expr(actual_call)
def get_while_stmt_nodes(self, node):
    """Lower a control-flow ``while`` loop into generated functions and a loop node.

    Loops for which ``self.name_visitor.is_control_flow_loop`` is false are
    returned untouched as ``[node]``. Otherwise the result contains, in
    order:

    - static-variable declarations for non-dotted names created in the loop;
    - ``to_static_variable_gast_node`` for every loop variable;
    - a generated condition function returning the test after
      ``LogicalOpTransformer``;
    - a generated body function;
    - the node built by ``create_while_node``.

    Note: ``node.body`` is extended in place with a ``return`` of the loop
    variables and reused as the body function's statement list.
    """
    # TODO: consider while - else in python
    if not self.name_visitor.is_control_flow_loop(node):
        return [node]

    loop_vars, created_vars = self.name_visitor.get_loop_var_names(node)

    def build_function(prefix, statements):
        # One parameter per loop variable; dotted names are afterwards renamed
        # inside the function to fresh generated identifiers.
        func = gast.FunctionDef(
            name=unique_name.generate(prefix),
            args=gast.arguments(
                args=[
                    gast.Name(id=var, ctx=gast.Param(), annotation=None,
                              type_comment=None)
                    for var in loop_vars
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=None,
                kwarg=None,
                defaults=[]),
            body=statements,
            decorator_list=[],
            returns=None,
            type_comment=None)
        for var in loop_vars:
            if "." in var:
                RenameTransformer(func).rename(
                    var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func

    # Python can create variable in loop and use it out of loop, E.g.
    #
    #   while x < 10:
    #       x += 1
    #       y = x
    #   z = y
    #
    # so such variables are declared as static variables up front.
    stmts = [
        create_static_variable_gast_node(var)
        for var in created_vars if "." not in var
    ]
    # `while x < 10` in dygraph should become a comparison on a static tensor.
    stmts.extend(to_static_variable_gast_node(var) for var in loop_vars)

    cond_value = LogicalOpTransformer(node.test).transform()
    cond_func = build_function(WHILE_CONDITION_PREFIX,
                               [gast.Return(value=cond_value)])
    stmts.append(cond_func)

    node.body.append(
        gast.Return(value=generate_name_node(loop_vars, ctx=gast.Load())))
    body_func = build_function(WHILE_BODY_PREFIX, node.body)
    stmts.append(body_func)

    stmts.append(create_while_node(cond_func.name, body_func.name, loop_vars))
    return stmts
def get_for_stmt_nodes(self, node):
    """Lower a ``for <name> in range(...)`` control-flow loop into a loop node.

    ``[node]`` is returned unchanged when any of these holds:

    - the loop is not a control-flow loop;
    - ``self.get_for_range_node`` finds no range call;
    - the loop target is not a plain ``gast.Name``.

    Otherwise the result contains, in order:

    - static-variable declarations for non-dotted names created in the loop;
    - the range init statement;
    - ``to_static_variable_gast_node`` for every loop variable;
    - a generated condition function;
    - a generated body function;
    - the node built by ``create_while_node``.

    Note: ``node.body`` is extended in place with the range step statement
    and a ``return`` of the loop variables, then reused as the body function.
    """
    # TODO: consider for - else in python
    if not self.name_visitor.is_control_flow_loop(node):
        return [node]
    # TODO: support non-range case
    range_call_node = self.get_for_range_node(node)
    if range_call_node is None or not isinstance(node.target, gast.Name):
        return [node]

    init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
        node.target.id, range_call_node.args)
    loop_vars, created_vars = self.name_visitor.get_loop_var_names(node)

    def build_function(prefix, statements):
        # One parameter per loop variable; dotted names are afterwards renamed
        # inside the function to fresh generated identifiers.
        func = gast.FunctionDef(
            name=unique_name.generate(prefix),
            args=gast.arguments(
                args=[
                    gast.Name(id=var, ctx=gast.Param(), annotation=None,
                              type_comment=None)
                    for var in loop_vars
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=None,
                kwarg=None,
                defaults=[]),
            body=statements,
            decorator_list=[],
            returns=None,
            type_comment=None)
        for var in loop_vars:
            if "." in var:
                RenameTransformer(func).rename(
                    var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func

    # Python can create variable in loop and use it out of loop, E.g.
    #
    #   for x in range(10):
    #       y += x
    #   print(x) # x = 10
    #
    # so such variables are declared as static variables up front.
    stmts = [
        create_static_variable_gast_node(var)
        for var in created_vars if "." not in var
    ]
    stmts.append(init_stmt)
    # `for x in range(10)` in dygraph should become static `tensor + 1 <= 10`.
    stmts.extend(to_static_variable_gast_node(var) for var in loop_vars)

    cond_func = build_function(FOR_CONDITION_PREFIX,
                               [gast.Return(value=cond_stmt)])
    stmts.append(cond_func)

    node.body.append(change_stmt)
    node.body.append(
        gast.Return(value=generate_name_node(loop_vars, ctx=gast.Load())))
    body_func = build_function(FOR_BODY_PREFIX, node.body)
    stmts.append(body_func)

    stmts.append(create_while_node(cond_func.name, body_func.name, loop_vars))
    return stmts
def generate_Return(self):
    """Return a ``gast.Return`` node wrapping ``self.generate_expression()``."""
    returned_value = self.generate_expression()
    return gast.Return(value=returned_value)
def visit_If(self, node):
    """Outline an ``if`` whose test is a static expression.

    Children are visited first. A node whose test is not in
    ``self.static_expressions`` is returned unchanged.

    Otherwise both branches are turned into generated functions by
    ``outline``, and the ``if`` is replaced by the call from
    ``self.make_dispatcher``. The call's result is handled as follows:

    - If a branch contains ``return``, the result is unpacked into
      ``$statusN, $returnN, $contN``. A truthy status returns ``$returnN``;
      otherwise ``$contN`` is unpacked into the names the branches assign.
    - Else, if the branches assign names, the result is unpacked into them.
    - Else, the call is kept as a bare expression statement.

    ``N`` is ``len(self.new_functions)`` after the two new functions are
    recorded.
    """
    self.generic_visit(node)
    if node.test not in self.static_expressions:
        return node

    imported_ids = self.passmanager.gather(ImportedIds, node, self.ctx)

    assigned_ids_left = set(
        self.passmanager.gather(IsAssigned, self.make_fake(node.body),
                                self.ctx).keys())
    assigned_ids_right = set(
        self.passmanager.gather(IsAssigned, self.make_fake(node.orelse),
                                self.ctx).keys())
    assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

    # Names assigned in exactly one branch are also passed in as parameters,
    # presumably so the other branch can hand back the incoming value.
    imported_ids.update(assigned_ids_left ^ assigned_ids_right)
    imported_ids = sorted(imported_ids)

    assigned_ids = sorted(assigned_ids_both)

    true_has_return = self.passmanager.gather(HasReturn,
                                              self.make_fake(node.body),
                                              self.ctx)
    false_has_return = self.passmanager.gather(HasReturn,
                                               self.make_fake(node.orelse),
                                               self.ctx)
    has_return = true_has_return or false_has_return

    func_true = outline(self.true_name(), imported_ids, assigned_ids,
                        node.body, has_return)
    func_false = outline(self.false_name(), imported_ids, assigned_ids,
                         node.orelse, has_return)
    self.new_functions.extend((func_true, func_false))

    actual_call = self.make_dispatcher(node.test, func_true, func_false,
                                       imported_ids)

    # These names are only ever used as assignment targets below, so they
    # need a Store context (a Load context is invalid for a target).
    expected_return = [ast.Name(ii, ast.Store(), None) for ii in assigned_ids]

    if has_return:
        n = len(self.new_functions)
        status_n = "$status{}".format(n)
        return_n = "$return{}".format(n)
        cont_n = "$cont{}".format(n)

        # targets of the tuple assignment from the dispatcher call
        fast_return = [ast.Name(status_n, ast.Store(), None),
                       ast.Name(return_n, ast.Store(), None),
                       ast.Name(cont_n, ast.Store(), None)]

        if expected_return:
            cont_ass = [
                ast.Assign([ast.Tuple(expected_return, ast.Store())],
                           ast.Name(cont_n, ast.Load(), None))
            ]
        else:
            cont_ass = []

        return [
            ast.Assign([ast.Tuple(fast_return, ast.Store())], actual_call),
            ast.If(ast.Name(status_n, ast.Load(), None),
                   [ast.Return(ast.Name(return_n, ast.Load(), None))],
                   cont_ass)
        ]
    elif expected_return:
        return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                          actual_call)
    else:
        return ast.Expr(actual_call)
def get_while_stmt_nodes(self, node):
    """Lower a ``while`` loop into generated condition/body functions.

    The result contains, in order:

    - static-variable declarations for non-dotted names that
      ``get_loop_var_names`` reports as created in the loop;
    - a generated condition function returning ``node.test``;
    - a generated body function;
    - every node returned by ``create_while_nodes``.

    Note: ``node.body`` is extended in place with a ``return`` of the loop
    variables and reused as the body function's statement list.
    """
    loop_vars, created_vars = self.name_visitor.get_loop_var_names(node)

    def build_function(prefix, statements):
        # One parameter per loop variable; dotted names are afterwards renamed
        # inside the function to fresh generated identifiers.
        func = gast.FunctionDef(
            name=unique_name.generate(prefix),
            args=gast.arguments(
                args=[
                    gast.Name(id=var, ctx=gast.Param(), annotation=None,
                              type_comment=None)
                    for var in loop_vars
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=None,
                kwarg=None,
                defaults=[]),
            body=statements,
            decorator_list=[],
            returns=None,
            type_comment=None)
        for var in loop_vars:
            if "." in var:
                RenameTransformer(func).rename(
                    var, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func

    # Python can create variable in loop and use it out of loop, E.g.
    #
    #   while x < 10:
    #       x += 1
    #       y = x
    #   z = y
    #
    # so such variables are declared as static variables up front.
    stmts = []
    for var in created_vars:
        if "." in var:
            continue
        stmts.append(create_static_variable_gast_node(var))

    cond_func = build_function(WHILE_CONDITION_PREFIX,
                               [gast.Return(value=node.test)])
    stmts.append(cond_func)

    node.body.append(
        gast.Return(value=generate_name_node(loop_vars, ctx=gast.Load())))
    body_func = build_function(WHILE_BODY_PREFIX, node.body)
    stmts.append(body_func)

    stmts += create_while_nodes(cond_func.name, body_func.name, loop_vars)
    return stmts
def fill(self, hole, rng):
    """Fill ``hole`` with a statement-list hole followed by ``return None``.

    The result is an ``ASTWithHoles`` holding one new ``STMTS`` hole that
    reuses ``hole.metadata``. Its builder appends a ``gast.Return(value=None)``
    to whatever statements later fill that hole. ``rng`` is accepted for
    interface compatibility but is not used here.
    """
    def with_trailing_return(stmts):
        # concatenation builds a new sequence; ``stmts`` itself is not mutated
        return stmts + [gast.Return(value=None)]

    inner_hole = Hole(ASTHoleType.STMTS, hole.metadata)
    return ASTWithHoles(1, [inner_hole], with_trailing_return)