def visit_arguments(self, node):
    """Convert an ``arguments`` node into a ``gast.arguments`` node.

    ``node.vararg`` and ``node.kwarg`` are treated as bare identifiers
    (presumably the Python 2 AST layout) and wrapped into ``Name`` nodes
    with a ``Param`` context. Positional-only and keyword-only parameters
    (and keyword defaults) are always emitted as empty lists.
    """
    # missing locations for vararg and kwarg set at function level
    vararg = ast.Name(node.vararg, ast.Param()) if node.vararg else None
    kwarg = ast.Name(node.kwarg, ast.Param()) if node.kwarg else None

    new_node = gast.arguments(
        self._visit(node.args),
        [],  # posonlyargs
        self._visit(vararg),
        [],  # kwonlyargs
        [],  # kw_defaults
        self._visit(kwarg),
        self._visit(node.defaults),
    )
    return new_node
def outline(name, formal_parameters, out_parameters, stmts, has_return):
    """Build a ``FunctionDef`` named ``name`` that wraps ``stmts``.

    Args:
        name: name of the generated function.
        formal_parameters: identifiers used as the positional parameters.
        out_parameters: identifiers packed into a tuple and returned at the
            end of the generated body; must be empty when ``stmts`` is an
            expression.
        stmts: either an ``ast.expr`` (the generated function just returns
            it) or a list of statements used as the body. A list is mutated
            in place: a trailing ``Return`` is appended to it.
        has_return: whether the statements contain ``return`` statements
            that must be rewritten by ``PatchReturn``.

    Returns:
        The generated ``ast.FunctionDef``.
    """
    args = ast.arguments(
        [ast.Name(fp, ast.Param(), None) for fp in formal_parameters],
        None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        # An expression body has no statement list to extend and no
        # ``return`` to patch: the function is complete as is.
        return ast.FunctionDef(name, args, [ast.Return(stmts)], [], None)

    fdef = ast.FunctionDef(name, args, stmts, [], None)

    # this is part of a huge trick that plays with delayed type inference
    # it basically computes the return type based on out parameters, and
    # the return statement is unconditionally added so if we have other
    # returns, there will be a computation of the output type based on the
    # __combined of the regular return types and this one. The original
    # returns have been patched above to have a different type that
    # cunningly combines with this output tuple
    #
    # This is the only trick I found to let pythran compute both the output
    # variable type and the early return type. But hey, a dirty one :-/
    stmts.append(
        ast.Return(
            ast.Tuple(
                [ast.Name(fp, ast.Load(), None) for fp in out_parameters],
                ast.Load())))
    if has_return:
        pr = PatchReturn(stmts[-1])
        pr.visit(fdef)

    return fdef
def test_ast_to_object(self):
    """``compiler.ast_to_object`` turns a hand-built gast tree into source
    (with an encoding marker) and an importable module."""
    # f(a) -> a + 1
    signature = gast.arguments(
        args=[gast.Name('a', gast.Param(), None)],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    expected = textwrap.dedent("""
      # coding=utf-8
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
        self.assertEqual(expected, temp_output.read().strip())
def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
    """
    Find out the ast.Name.id list of input by analyzing node's AST information.

    Args:
        var_ids_dict: mapping from a variable id to a sequence whose first
            element is the context node the variable was recorded with.
        return_ids: optional ids that must always appear in the result.
        ctx: context class used for filtering; only ids whose first recorded
            context is an instance of ``ctx`` are kept.

    Returns:
        A ``gast.arguments`` node with one ``gast.Name`` per selected id, in
        sorted order, and no other parameter kinds.
    """
    name_ids = [
        var_id for var_id, var_ctx in var_ids_dict.items()
        if isinstance(var_ctx[0], ctx)
    ]
    if return_ids:
        name_ids.extend(set(return_ids) - set(name_ids))
    name_ids.sort()

    args = [
        gast.Name(
            id=name_id, ctx=gast.Load(), annotation=None, type_comment=None)
        for name_id in name_ids
    ]
    # kw_defaults is a list field in the AST grammar: use [] rather than None
    # so the tree can also be fed to compile() after conversion.
    arguments = gast.arguments(
        args=args,
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    return arguments
def test_ast_to_object(self):
    """``compiler.ast_to_object`` turns a hand-built gast tree into source and
    an importable module whose file holds the same source."""
    param_a = gast.Name('a', gast.Param(), None)
    body_expr = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=gast.arguments(
            args=[param_a],
            vararg=None,
            kwonlyargs=[],
            kwarg=None,
            defaults=[],
            kw_defaults=[]),
        body=[gast.Return(body_expr)],
        decorator_list=[],
        returns=None)

    module, source = compiler.ast_to_object(node)

    want = textwrap.dedent("""
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(want, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
        self.assertEqual(want, temp_output.read().strip())
def visit_ListComp(self, node):
    """Rewrite an optimizable list comprehension into a ``map`` call.

    Comprehensions not in ``self.optimizable_comprehension`` are only
    visited recursively. Otherwise the result is
    ``__builtin__.map(lambda <param>: elt, <iterable>)``. With a single
    generator, the iterable comes from ``self.make_Iterator``. With several
    generators, the iterables are combined with ``itertools.product`` and
    the lambda takes one packed parameter.
    """
    if node not in self.optimizable_comprehension:
        return self.generic_visit(node)

    self.update = True
    self.generic_visit(node)

    sources = [self.make_Iterator(gen) for gen in node.generators]
    targets = [ast.Name(gen.target.id, ast.Param(), None)
               for gen in node.generators]

    if len(sources) == 1:
        # If dim = 1, product is useless
        iterable = sources[0]
        lambda_args = ast.arguments([targets[0]], None, [], [], None, [])
    else:
        self.use_itertools = True
        product = ast.Attribute(
            value=ast.Name(id=mangle('itertools'), ctx=ast.Load(),
                           annotation=None),
            attr='product', ctx=ast.Load())
        # retarget the first id as the single lambda parameter, it's free;
        # ConvertToTuple (defined elsewhere) presumably rewrites each target
        # use in elt into an index of that parameter, per ``renamings``.
        packed_id = targets[0].id
        renamings = {target.id: (position,)
                     for position, target in enumerate(targets)}
        node.elt = ConvertToTuple(packed_id, renamings).visit(node.elt)
        iterable = ast.Call(product, sources, [])
        lambda_args = ast.arguments([ast.Name(packed_id, ast.Param(), None)],
                                    None, [], [], None, [])

    map_func = ast.Attribute(
        value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None),
        attr='map', ctx=ast.Load())
    mapper = ast.Lambda(lambda_args, node.elt)
    return ast.Call(map_func, [mapper, iterable], [])
def _make_arguments(*args):
    """Build a ``gast.arguments`` node whose positional parameters are ``args``.

    Every other parameter kind (positional-only, ``*args``, keyword-only,
    keyword defaults, ``**kwargs``, defaults) is left empty.
    """
    # Built inside the call so each node gets its own fresh empty lists.
    empty_fields = dict(
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    return gast.arguments(args=list(args), **empty_fields)
def visit_Compare(self, node):
    """Outline a chained comparison into a lazily evaluated helper function.

    A comparison with more than one operator (``a < b < c``) is replaced by
    a call to a new function ``<prefix>_compare<N>``, which is appended to
    ``self.compare_functions``. The helper takes every identifier gathered
    by ``ImportedIds`` as a parameter (sorted). Operands that are not
    trivially copied are bound to ``$<k>`` temporaries right before they are
    needed. The helper returns ``__builtin__.False`` at the first failing
    comparison and ``__builtin__.True`` otherwise. Single-operator
    comparisons are returned unchanged, after their children are visited.
    """
    node = self.generic_visit(node)
    if len(node.ops) <= 1:
        return node

    # in case we have more than one compare operator
    # we generate an auxiliary function
    # that lazily evaluates the needed parameters
    params = sorted(self.gather(ImportedIds, node))
    helper_name = "{0}_compare{1}".format(self.prefix,
                                          len(self.compare_functions))
    call_site = ast.Call(
        ast.Name(helper_name, ast.Load(), None),
        [ast.Name(p, ast.Load(), None) for p in params], [])
    helper_args = ast.arguments(
        [ast.Name(p, ast.Param(), None) for p in params],
        None, [], [], None, [])

    statements = []

    def hold(operand, slot):
        # Trivially copied operands are used in place; others are stored in
        # a ``$<slot>`` temporary, emitted after the previous comparison.
        if is_trivially_copied(operand):
            return operand
        temp = '${}'.format(slot)
        statements.append(
            ast.Assign([ast.Name(temp, ast.Store(), None)], operand))
        return ast.Name(temp, ast.Load(), None)

    left = hold(node.left, 0)
    for index, comparator in enumerate(node.comparators):
        right = hold(comparator, index + 1)
        test = ast.Compare(left, [node.ops[index]], [right])
        statements.append(
            ast.If(test, [ast.Pass()],
                   [ast.Return(path_to_attr(('__builtin__', 'False')))]))
        left = right
    statements.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

    helper = ast.FunctionDef(helper_name, helper_args, statements, [], None)
    self.compare_functions.append(helper)
    return call_site
def parse_cond_args(parent_ids_dict,
                    var_ids_dict,
                    modified_ids_dict=None,
                    ctx=gast.Load):
    """
    Find out the ast.Name.id list of input by analyzing node's AST information.

    Args:
        parent_ids_dict: ids visible in the enclosing scope; only these can
            become arguments.
        var_ids_dict: mapping from a variable id to a sequence whose first
            element is the context node the variable was recorded with.
        modified_ids_dict: optional collection of ids modified in the
            if/else bodies; they are added to the candidates.
        ctx: context class used to filter ``var_ids_dict``.

    Returns:
        A ``gast.arguments`` node with one ``gast.Name`` per selected id, in
        sorted order, and no other parameter kinds.
    """
    # 1. filter the var fit the ctx
    arg_name_ids = [
        var_id for var_id, var_ctx in six.iteritems(var_ids_dict)
        if isinstance(var_ctx[0], ctx)
    ]

    # 2. args should contain modified var ids in if-body or else-body
    #  case:
    #
    #   ```
    #   if b < 1:
    #     z = y
    #   else:
    #     z = x
    #   ```
    #
    #  In the above case, `z` should be in the args of cond()
    if modified_ids_dict:
        arg_name_ids = set(arg_name_ids) | set(modified_ids_dict)

    # 3. args should not contain the vars not in parent ids
    #  case :
    #
    #   ```
    #   x = 1
    #   if x > y:
    #     z = [v for v in range(i)]
    #   ```
    #
    #  In the above case, `v` should not be in the args of cond()
    arg_name_ids = sorted(set(arg_name_ids) & set(parent_ids_dict))

    args = [
        gast.Name(
            id=name_id, ctx=gast.Load(), annotation=None, type_comment=None)
        for name_id in arg_name_ids
    ]
    # kw_defaults is a list field in the AST grammar: use [] rather than None.
    arguments = gast.arguments(
        args=args,
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    return arguments
def visit_arguments(self, node):
    """Convert an ``arguments`` node into a six-field ``gast.arguments``.

    Positional parameters, ``*args``, ``**kwargs`` and defaults are converted
    with ``self._visit``. Keyword-only parameters and their defaults are not
    read from ``node``: both fields are emitted as empty lists.
    """
    positional = self._visit(node.args)
    star_args = self._visit(node.vararg)
    star_kwargs = self._visit(node.kwarg)
    default_values = self._visit(node.defaults)
    return gast.arguments(
        positional,
        star_args,
        [],  # kwonlyargs
        [],  # kw_defaults
        star_kwargs,
        default_values,
    )
def visit_Compare(self, node):
    """Outline a chained comparison into a lazily evaluated helper function.

    A comparison with more than one operator (``a < b < c``) is replaced by
    a call to a new function ``<prefix>_compare<N>``, which is appended to
    ``self.compare_functions``. The helper's parameters are the sorted
    identifiers gathered by ``ImportedIds`` through the pass manager.
    Non-trivially-copied operands go into ``$<k>`` temporaries just before
    use. The helper returns ``__builtin__.False`` at the first failing
    comparison and ``__builtin__.True`` otherwise. Single-operator
    comparisons are returned unchanged, after their children are visited.
    """
    node = self.generic_visit(node)
    if len(node.ops) <= 1:
        return node

    # in case we have more than one compare operator
    # we generate an auxiliary function
    # that lazily evaluates the needed parameters
    params = sorted(self.passmanager.gather(ImportedIds, node, self.ctx))
    helper_name = "{0}_compare{1}".format(self.prefix,
                                          len(self.compare_functions))
    call_site = ast.Call(ast.Name(helper_name, ast.Load(), None),
                         [ast.Name(p, ast.Load(), None) for p in params],
                         [])
    helper_args = ast.arguments(
        [ast.Name(p, ast.Param(), None) for p in params],
        None, [], [], None, [])

    statements = []

    def hold(operand, slot):
        # Use trivially copied operands in place; bind the others to a
        # ``$<slot>`` temporary emitted after the previous comparison.
        if is_trivially_copied(operand):
            return operand
        temp = '${}'.format(slot)
        statements.append(
            ast.Assign([ast.Name(temp, ast.Store(), None)], operand))
        return ast.Name(temp, ast.Load(), None)

    left = hold(node.left, 0)
    for index, comparator in enumerate(node.comparators):
        right = hold(comparator, index + 1)
        test = ast.Compare(left, [node.ops[index]], [right])
        statements.append(
            ast.If(test, [ast.Pass()],
                   [ast.Return(path_to_attr(('__builtin__', 'False')))]))
        left = right
    statements.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

    helper = ast.FunctionDef(helper_name, helper_args, statements, [], None)
    self.compare_functions.append(helper)
    return call_site
def visit_ListComp(self, node):
    """Rewrite an optimizable list comprehension into a ``map`` call.

    Comprehensions not in ``self.optimizable_comprehension`` are only
    visited recursively. Otherwise the result is
    ``__builtin__.map(lambda <params>: elt, <iterable>)``. With several
    generators, the iterables are combined with ``itertools.product`` and
    the lambda takes a tuple of all the targets.
    """
    if node not in self.optimizable_comprehension:
        return self.generic_visit(node)

    self.update = True
    self.generic_visit(node)

    sources = [self.make_Iterator(gen) for gen in node.generators]
    params = [ast.Name(gen.target.id, ast.Param(), None)
              for gen in node.generators]

    if len(sources) == 1:
        # If dim = 1, product is useless
        iterable = sources[0]
        signature = ast.arguments([params[0]], None, [], [], None, [])
    else:
        self.use_itertools = True
        product = ast.Attribute(
            value=ast.Name(id='itertools', ctx=ast.Load(), annotation=None),
            attr='product', ctx=ast.Load())
        iterable = ast.Call(product, sources, [])
        signature = ast.arguments([ast.Tuple(params, ast.Store())],
                                  None, [], [], None, [])

    map_func = ast.Attribute(
        value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None),
        attr='map', ctx=ast.Load())
    return ast.Call(map_func, [ast.Lambda(signature, node.elt), iterable], [])
def visit_arguments(self, node):
    """Convert an ``arguments`` node into a ``gast.arguments`` node.

    Every field is converted with ``self._visit``, except positional-only
    parameters, which are always emitted empty (presumably for Python
    versions without them). The source location of ``node`` is copied onto
    the result.
    """
    fields = (
        self._visit(node.args),
        [],  # posonlyargs
        self._visit(node.vararg),
        self._visit(node.kwonlyargs),
        self._visit(node.kw_defaults),
        self._visit(node.kwarg),
        self._visit(node.defaults),
    )
    return gast.copy_location(gast.arguments(*fields), node)
def visitComp(self, node, make_attr):
    """Rewrite an optimizable comprehension into a map-like call.

    Comprehensions not in ``self.optimizable_comprehension`` are only
    visited recursively. For the others, a lambda over ``node.elt`` and the
    iterable are built and handed to ``make_attr(lambda, iterable)``, whose
    result replaces the node. Several generators are combined with
    ``itertools.product``, and the lambda takes one packed parameter.
    """
    if node not in self.optimizable_comprehension:
        return self.generic_visit(node)

    self.update = True
    self.generic_visit(node)

    sources = [self.make_Iterator(gen) for gen in node.generators]
    targets = [ast.Name(gen.target.id, ast.Param(), None, None)
               for gen in node.generators]

    if len(sources) == 1:
        # If dim = 1, product is useless
        iterable = sources[0]
        lambda_args = ast.arguments([targets[0]], [], None, [], [], None, [])
    else:
        self.use_itertools = True
        product = ast.Attribute(
            value=ast.Name(id=mangle('itertools'), ctx=ast.Load(),
                           annotation=None, type_comment=None),
            attr='product', ctx=ast.Load())
        # retarget the first id as the single lambda parameter, it's free;
        # ConvertToTuple (defined elsewhere) rewrites elt per ``renamings``.
        packed_id = targets[0].id
        renamings = {target.id: (position,)
                     for position, target in enumerate(targets)}
        node.elt = ConvertToTuple(packed_id, renamings).visit(node.elt)
        iterable = ast.Call(product, sources, [])
        lambda_args = ast.arguments(
            [ast.Name(packed_id, ast.Param(), None, None)],
            [], None, [], [], None, [])

    mapper = ast.Lambda(lambda_args, node.elt)
    return make_attr(mapper, iterable)
def visit_arguments(self, node):
    """Convert an ``arguments`` node into a six-field ``gast.arguments``.

    The source node keeps ``*args`` / ``**kwargs`` annotations in the
    separate ``varargannotation`` / ``kwargannotation`` fields (presumably
    an early Python 3 layout). They are combined with the parameter names
    by ``self._make_annotated_arg``.
    """
    positional = [self._visit(arg) for arg in node.args]
    star_args = self._make_annotated_arg(
        node, node.vararg, self._visit(node.varargannotation))
    keyword_only = [self._visit(arg) for arg in node.kwonlyargs]
    keyword_defaults = self._visit(node.kw_defaults)
    star_kwargs = self._make_annotated_arg(
        node, node.kwarg, self._visit(node.kwargannotation))
    default_values = self._visit(node.defaults)
    return gast.arguments(
        positional,
        star_args,
        keyword_only,
        keyword_defaults,
        star_kwargs,
        default_values,
    )
def visit_GeneratorExp(self, node):
    """Rewrite an optimizable generator expression into an ``IMAP`` call.

    Generator expressions not in ``self.optimizable_comprehension`` are only
    visited recursively. Otherwise the result is
    ``MODULE.IMAP(lambda <params>: elt, <iterable>)``. With several
    generators, the iterables are combined with ``itertools.product`` and
    the lambda takes a tuple of all the targets.
    """
    if node in self.optimizable_comprehension:
        self.update = True
        self.generic_visit(node)

        iters = [self.make_Iterator(gen) for gen in node.generators]
        variables = [ast.Name(gen.target.id, ast.Param(), None)
                     for gen in node.generators]

        # If dim = 1, product is useless
        if len(iters) == 1:
            iterAST = iters[0]
            varAST = ast.arguments([variables[0]], None, [], [], None, [])
        else:
            # itertools.product is referenced below: raise the same flag the
            # list-comprehension rewrite raises so the module import is added.
            self.use_itertools = True
            prodName = ast.Attribute(
                value=ast.Name(id='itertools', ctx=ast.Load(),
                               annotation=None),
                attr='product', ctx=ast.Load())
            iterAST = ast.Call(prodName, iters, [])
            varAST = ast.arguments([ast.Tuple(variables, ast.Store())],
                                   None, [], [], None, [])

        imapName = ast.Attribute(
            value=ast.Name(id=MODULE, ctx=ast.Load(), annotation=None),
            attr=IMAP, ctx=ast.Load())

        ldBodyimap = node.elt
        ldimap = ast.Lambda(varAST, ldBodyimap)

        return ast.Call(imapName, [ldimap, iterAST], [])
    else:
        return self.generic_visit(node)
def visit_GeneratorExp(self, node):
    """Rewrite an optimizable generator expression into an ``IMAP`` call.

    Generator expressions not in ``self.optimizable_comprehension`` are only
    visited recursively. Otherwise the result is
    ``ASMODULE.IMAP(lambda <params>: elt, <iterable>)``. With several
    generators, the iterables are combined with ``itertools.product``
    (through its mangled name) and the lambda takes a tuple of the targets.
    """
    if node in self.optimizable_comprehension:
        self.update = True
        self.generic_visit(node)

        iters = [self.make_Iterator(gen) for gen in node.generators]
        variables = [ast.Name(gen.target.id, ast.Param(), None)
                     for gen in node.generators]

        # If dim = 1, product is useless
        if len(iters) == 1:
            iterAST = iters[0]
            varAST = ast.arguments([variables[0]], None, [], [], None, [])
        else:
            # mangle('itertools') is referenced below: raise the same flag the
            # list-comprehension rewrite raises so the module import is added.
            self.use_itertools = True
            prodName = ast.Attribute(
                value=ast.Name(id=mangle('itertools'), ctx=ast.Load(),
                               annotation=None),
                attr='product', ctx=ast.Load())
            iterAST = ast.Call(prodName, iters, [])
            varAST = ast.arguments([ast.Tuple(variables, ast.Store())],
                                   None, [], [], None, [])

        imapName = ast.Attribute(
            value=ast.Name(id=ASMODULE, ctx=ast.Load(), annotation=None),
            attr=IMAP, ctx=ast.Load())

        ldBodyimap = node.elt
        ldimap = ast.Lambda(varAST, ldBodyimap)

        return ast.Call(imapName, [ldimap, iterAST], [])
    else:
        return self.generic_visit(node)
def __init__(self, **kwargs):
    """Store the properties of this entity, read from keyword options.

    Recognized keys, all optional:
        argument_effects: defaults to a tuple holding the same
            ``UpdateEffect()`` instance 11 times.
        global_effects: defaults to ``False``.
        return_alias: callable, defaults to one returning ``{UnboundValue}``.
        args / defaults: parameter names and default values used to build
            ``self.args`` (defaults converted with ``to_ast``).
        return_range / return_range_content: callables, default to ones
            returning ``UNKNOWN_RANGE``.
    """
    option = kwargs.get
    self.argument_effects = option('argument_effects',
                                   (UpdateEffect(),) * 11)
    self.global_effects = option('global_effects', False)
    self.return_alias = option('return_alias', lambda x: {UnboundValue})

    param_names = option('args', [])
    params = [ast.Name(n, ast.Param(), None) for n in param_names]
    default_values = option('defaults', [])
    self.args = ast.arguments(params, None, [], [], None,
                              [to_ast(d) for d in default_values])

    self.return_range = option("return_range",
                               lambda call: UNKNOWN_RANGE)
    self.return_range_content = option("return_range_content",
                                       lambda c: UNKNOWN_RANGE)
def __init__(self, **kwargs):
    """Configure the entity from optional keyword options.

    Keys read from ``kwargs``, each with the default used when absent:
    ``argument_effects`` (11 references to one ``UpdateEffect()``),
    ``global_effects`` (``False``), ``return_alias`` (returns
    ``{UnboundValue}``), ``args`` / ``defaults`` (empty, used to build
    ``self.args``), and ``return_range`` / ``return_range_content`` (both
    return ``UNKNOWN_RANGE``).
    """
    fallback_effects = (UpdateEffect(),) * 11
    self.argument_effects = kwargs.get('argument_effects', fallback_effects)
    self.global_effects = kwargs.get('global_effects', False)
    self.return_alias = kwargs.get('return_alias',
                                   lambda x: {UnboundValue})

    positional = [ast.Name(n, ast.Param(), None)
                  for n in kwargs.get('args', [])]
    converted_defaults = [to_ast(d) for d in kwargs.get('defaults', [])]
    self.args = ast.arguments(positional, None, [], [], None,
                              converted_defaults)

    self.return_range = kwargs.get("return_range",
                                   lambda call: UNKNOWN_RANGE)
    self.return_range_content = kwargs.get("return_range_content",
                                           lambda c: UNKNOWN_RANGE)
def visit_AnyComp(self, node, comp_type, *path):
    """Outline a comprehension into a new module-level function.

    The generated ``<comp_type>_comprehension<N>`` function assigns
    ``builtins.<comp_type>()`` to ``__target``. It wraps a call of the
    attribute chain named by ``path`` (``path[0]`` is a name and the rest
    are attributes) with ``(__target, elt)`` in the generator clauses via
    ``self.nest_reducer``, then returns ``__target``. Its parameters are the
    identifiers gathered by ``ImportedIds``. The comprehension is replaced
    by a call passing those identifiers.
    """
    self.update = True
    node.elt = self.visit(node.elt)

    helper_name = "{0}_comprehension{1}".format(comp_type, self.count)
    self.count += 1
    free_ids = self.gather(ImportedIds, node)
    self.count_iter = 0

    accumulator = "__target"
    add_element = reduce(
        lambda base, attr: ast.Attribute(base, attr, ast.Load()),
        path[1:],
        ast.Name(path[0], ast.Load(), None, None))
    innermost = ast.Expr(
        ast.Call(
            add_element,
            [ast.Name(accumulator, ast.Load(), None, None), node.elt],
            [],
        )
    )
    loops = reduce(self.nest_reducer, reversed(node.generators), innermost)
    # add extra metadata to this node
    metadata.add(loops, metadata.Comprehension(accumulator))

    create_accumulator = ast.Assign(
        [ast.Name(accumulator, ast.Store(), None, None)],
        ast.Call(
            ast.Attribute(ast.Name('builtins', ast.Load(), None, None),
                          comp_type, ast.Load()),
            [], [],
        )
    )
    give_back = ast.Return(ast.Name(accumulator, ast.Load(), None, None))

    params = [ast.Name(free_id, ast.Param(), None, None)
              for free_id in free_ids]
    helper = ast.FunctionDef(helper_name,
                             ast.arguments(params, [], None, [], [], None, []),
                             [create_accumulator, loops, give_back],
                             [], None, None)
    metadata.add(helper, metadata.Local())
    self.ctx.module.body.append(helper)

    # fresh Name nodes at the call site: no sharing !
    call_args = [ast.Name(param.id, ast.Load(), None, None)
                 for param in params]
    return ast.Call(ast.Name(helper_name, ast.Load(), None, None),
                    call_args, [])
def visitComp(self, node, make_attr):
    """Rewrite an optimizable comprehension into a map-like call.

    Comprehensions not in ``self.optimizable_comprehension`` are only
    visited recursively. For the others, the lambda over ``node.elt`` and
    the iterable are passed to ``make_attr(lambda, iterable)``, whose result
    replaces the node. Several generators are combined with
    ``itertools.product`` (through its mangled name), and the lambda takes
    one packed parameter.
    """
    if node not in self.optimizable_comprehension:
        return self.generic_visit(node)

    self.update = True
    self.generic_visit(node)

    sources = [self.make_Iterator(gen) for gen in node.generators]
    targets = [ast.Name(gen.target.id, ast.Param(), None)
               for gen in node.generators]

    if len(sources) == 1:
        # If dim = 1, product is useless
        iterable = sources[0]
        lambda_args = ast.arguments([targets[0]], None, [], [], None, [])
    else:
        self.use_itertools = True
        product = ast.Attribute(
            value=ast.Name(id=mangle('itertools'), ctx=ast.Load(),
                           annotation=None),
            attr='product', ctx=ast.Load())
        # retarget the first id as the single lambda parameter, it's free;
        # ConvertToTuple (defined elsewhere) rewrites elt per ``renamings``.
        packed_id = targets[0].id
        renamings = {target.id: (position,)
                     for position, target in enumerate(targets)}
        node.elt = ConvertToTuple(packed_id, renamings).visit(node.elt)
        iterable = ast.Call(product, sources, [])
        lambda_args = ast.arguments([ast.Name(packed_id, ast.Param(), None)],
                                    None, [], [], None, [])

    return make_attr(ast.Lambda(lambda_args, node.elt), iterable)
def make_Iterator(self, gen):
    """Return the iterable to use for the comprehension clause ``gen``.

    Without ``if`` filters this is ``gen.iter`` itself. Otherwise it is
    ``ASMODULE.IFILTER(lambda <target>: <cond>, gen.iter)``, where
    ``<cond>`` is the single filter or the ``and`` of all of them.
    """
    if not gen.ifs:
        return gen.iter

    predicate_args = ast.arguments(
        [ast.Name(gen.target.id, ast.Param(), None)],
        None, [], [], None, [])
    if len(gen.ifs) > 1:
        condition = ast.BoolOp(ast.And(), gen.ifs)
    else:
        condition = gen.ifs[0]
    predicate = ast.Lambda(predicate_args, condition)

    filter_func = ast.Attribute(
        value=ast.Name(id=ASMODULE, ctx=ast.Load(), annotation=None),
        attr=IFILTER, ctx=ast.Load())
    return ast.Call(filter_func, [predicate, gen.iter], [])
def generate_FunctionDef(self):
    """Generate a random ``gast.FunctionDef`` node.

    The function gets 2-10 sampled ``Param`` names as positional parameters
    and 1-``N_FUNCTIONDEF_STATEMENTS`` sampled statements, followed by a
    generated ``return``. Its name comes from ``self.generate_Name()``.
    """
    # Generate the arguments, register them as available
    params = self.sample_node_list(
        low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
    signature = gast.arguments(params, None, [], [], None, [])

    # Generate the function body
    statements = self.sample_node_list(
        low=1,
        high=N_FUNCTIONDEF_STATEMENTS,
        generator=self.generate_statement)
    statements.append(self.generate_Return())

    function_name = self.generate_Name().id
    return gast.FunctionDef(function_name, signature, statements, (), None)
def make_Iterator(self, gen):
    """Return the iterable for comprehension clause ``gen``.

    ``gen.iter`` is returned as is when there are no ``if`` filters.
    Otherwise it is wrapped into ``MODULE.IFILTER(lambda <target>: <cond>,
    gen.iter)``, where ``<cond>`` is the only filter or the ``and`` of all
    of them.
    """
    if not gen.ifs:
        return gen.iter

    target_param = ast.Name(gen.target.id, ast.Param(), None)
    condition = (gen.ifs[0] if len(gen.ifs) == 1
                 else ast.BoolOp(ast.And(), gen.ifs))
    predicate = ast.Lambda(
        ast.arguments([target_param], None, [], [], None, []),
        condition)
    filter_func = ast.Attribute(
        value=ast.Name(id=MODULE, ctx=ast.Load(), annotation=None),
        attr=IFILTER, ctx=ast.Load())
    return ast.Call(filter_func, [predicate, gen.iter], [])
def wrap_func_ast(
    name: str,
    args: List[str],
    block: List[AST],
    returns: List[str] = [],
    return_tuple: bool = False,
) -> FunctionDef:
    """Wrap the given code block in a function as a FunctionDef AST node.

    Args:
        name: The name of the function wrapping the block of code.
        args: List of argument names which the wrapping function accepts.
        block: List of AST nodes representing the code block being wrapped
            by the wrapping function. The code block should not contain
            `return` statements. The list itself is not modified.
        returns: List of variable names to return from the wrapping function.
            Only read, never mutated, so the shared default is safe.
        return_tuple: Whether to force the wrapping function to return a
            tuple, regardless of whether multiple values are actually
            returned.
    Returns:
        The created function wrapping the given code block.
    """
    # append return statement if actually returning variables
    if returns:
        # Return.value is a single expression: a Tuple node for several
        # (or forced-tuple) values, otherwise the lone name.
        if len(returns) > 1 or return_tuple:
            return_value = TupleAST(elts=[name_ast(r) for r in returns],
                                    ctx=Load())
        else:
            return_value = name_ast(returns[0])
        block = block + [Return(value=return_value)]

    return FunctionDef(
        name=name,
        args=arguments(
            args=[name_ast(a, Param()) for a in args],
            defaults=[],
            posonlyargs=[],
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            vararg=None,
        ),
        body=block,
        decorator_list=[],
        returns="",
        type_comment="",
    )
def visit_Compare(self, node):
    """Outline a chained comparison into a lazily evaluated helper function.

    A comparison with more than one operator is replaced by a call to a new
    function ``<prefix>_compare<N>``, which is appended to
    ``self.compare_functions``. The helper's parameters are the sorted
    identifiers gathered by ``ImportedIds``. The helper stores the left
    operand in ``$0`` and each comparator ``k`` in ``$k`` just before
    comparing it. It returns ``0`` at the first failing comparison and
    ``1`` otherwise. Single-operator comparisons are returned unchanged,
    after their children are visited.
    """
    node = self.generic_visit(node)
    if len(node.ops) <= 1:
        return node

    # in case we have more than one compare operator
    # we generate an auxiliary function
    # that lazily evaluates the needed parameters
    params = sorted(self.passmanager.gather(ImportedIds, node, self.ctx))
    helper_name = "{0}_compare{1}".format(self.prefix,
                                          len(self.compare_functions))
    call_site = ast.Call(
        ast.Name(helper_name, ast.Load(), None),
        [ast.Name(p, ast.Load(), None) for p in params], [])
    helper_args = ast.arguments(
        [ast.Name(p, ast.Param(), None) for p in params],
        None, [], [], None, [])

    def slot(index):
        return '${}'.format(index)

    statements = [ast.Assign([ast.Name(slot(0), ast.Store(), None)],
                             node.left)]
    for index, comparator in enumerate(node.comparators, 1):
        statements.append(
            ast.Assign([ast.Name(slot(index), ast.Store(), None)],
                       comparator))
        test = ast.Compare(ast.Name(slot(index - 1), ast.Load(), None),
                           [node.ops[index - 1]],
                           [ast.Name(slot(index), ast.Load(), None)])
        statements.append(
            ast.If(test, [ast.Pass()], [ast.Return(ast.Num(0))]))
    statements.append(ast.Return(ast.Num(1)))

    helper = ast.FunctionDef(helper_name, helper_args, statements, [], None)
    self.compare_functions.append(helper)
    return call_site
def visit_AnyComp(self, node, comp_type, *path):
    """Outline a comprehension into a new module-level function.

    The generated ``<comp_type>_comprehension<N>`` function assigns
    ``__builtin__.<comp_type>()`` to ``__target``. It wraps a call of the
    attribute chain named by ``path`` with ``(__target, elt)`` in the
    generator clauses via ``self.nest_reducer``, then returns ``__target``.
    Its parameters are the identifiers gathered by ``ImportedIds``, in
    sorted order. The comprehension is replaced by a call passing those
    identifiers.
    """
    self.update = True
    node.elt = self.visit(node.elt)
    name = "{0}_comprehension{1}".format(comp_type, self.count)
    self.count += 1
    args = self.passmanager.gather(ImportedIds, node, self.ctx)
    self.count_iter = 0

    starget = "__target"
    body = reduce(self.nest_reducer,
                  reversed(node.generators),
                  ast.Expr(
                      ast.Call(
                          reduce(lambda x, y: ast.Attribute(x, y, ast.Load()),
                                 path[1:],
                                 ast.Name(path[0], ast.Load(), None)),
                          [ast.Name(starget, ast.Load(), None), node.elt],
                          [],
                      )
                  )
                  )
    # add extra metadata to this node
    metadata.add(body, metadata.Comprehension(starget))
    init = ast.Assign(
        [ast.Name(starget, ast.Store(), None)],
        ast.Call(
            ast.Attribute(
                ast.Name('__builtin__', ast.Load(), None),
                comp_type,
                ast.Load()
            ),
            [], [],)
    )
    result = ast.Return(ast.Name(starget, ast.Load(), None))
    # Sort the identifiers, not the Name nodes: AST nodes define no ordering
    # (TypeError on Python 3 as soon as there are two of them).
    sargs = [ast.Name(arg, ast.Param(), None) for arg in sorted(args)]
    fd = ast.FunctionDef(name,
                         ast.arguments(sargs, None, [], [], None, []),
                         [init, body, result],
                         [], None)
    self.ctx.module.body.append(fd)
    return ast.Call(
        ast.Name(name, ast.Load(), None),
        [ast.Name(arg.id, ast.Load(), None) for arg in sargs],
        [],
    )  # no sharing !
def outline(name, formal_parameters, out_parameters, stmts,
            has_return, has_break, has_cont):
    """Build a ``FunctionDef`` named ``name`` that wraps ``stmts``.

    Args:
        name: name of the generated function.
        formal_parameters: identifiers used as the positional parameters.
        out_parameters: identifiers packed into a tuple and returned at the
            end of the generated body; must be empty when ``stmts`` is an
            expression.
        stmts: either an ``ast.expr`` (the generated function just returns
            it) or a list of statements used as the body. A list is mutated
            in place: a trailing ``Return`` is appended to it.
        has_return: whether the statements contain ``return`` statements
            rewritten by ``PatchReturn``.
        has_break, has_cont: whether the statements contain ``break`` /
            ``continue`` statements rewritten by ``PatchBreakContinue``.

    Returns:
        The generated ``ast.FunctionDef``.
    """
    args = ast.arguments(
        [ast.Name(fp, ast.Param(), None) for fp in formal_parameters],
        None, [], [], None, [])

    if isinstance(stmts, ast.expr):
        assert not out_parameters, "no out parameters with expr"
        # An expression body has no statement list to extend, and cannot
        # contain return/break/continue: the function is complete as is.
        return ast.FunctionDef(name, args, [ast.Return(stmts)], [], None)

    fdef = ast.FunctionDef(name, args, stmts, [], None)

    # this is part of a huge trick that plays with delayed type inference
    # it basically computes the return type based on out parameters, and
    # the return statement is unconditionally added so if we have other
    # returns, there will be a computation of the output type based on the
    # __combined of the regular return types and this one. The original
    # returns have been patched above to have a different type that
    # cunningly combines with this output tuple
    #
    # This is the only trick I found to let pythran compute both the output
    # variable type and the early return type. But hey, a dirty one :-/
    stmts.append(
        ast.Return(
            ast.Tuple(
                [ast.Name(fp, ast.Load(), None) for fp in out_parameters],
                ast.Load()
            )
        )
    )
    if has_return:
        pr = PatchReturn(stmts[-1], has_break or has_cont)
        pr.visit(fdef)

    if has_break or has_cont:
        if not has_return:
            # Prefix the out-parameter tuple with the LOOP_NONE marker.
            stmts[-1].value = ast.Tuple([ast.Num(LOOP_NONE),
                                         stmts[-1].value],
                                        ast.Load())
        pbc = PatchBreakContinue(stmts[-1])
        pbc.visit(fdef)

    return fdef
def test_load_ast(self):
    """``loader.load_ast`` turns a hand-built gast tree into source (with an
    encoding marker) and an importable module backed by that source."""
    param_a = gast.Name('a', ctx=gast.Param(), annotation=None,
                        type_comment=None)
    signature = gast.arguments(
        args=[param_a],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', ctx=gast.Load(), annotation=None,
                       type_comment=None),
        right=gast.Constant(1, kind=None))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected = textwrap.dedent("""
      # coding=utf-8
      def f(a):
          return (a + 1)
    """).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
        self.assertEqual(expected, temp_output.read().strip())
def create_lambda_node(func_or_expr_node, is_if_expr=False):
    """Wrap a function definition or an expression into a no-argument lambda.

    Args:
        func_or_expr_node: when ``is_if_expr`` is false, a function
            definition node whose ``name`` and ``args`` are read; otherwise
            an expression used directly as the lambda body.
        is_if_expr: whether ``func_or_expr_node`` is an expression
            (presumably a branch of a conditional expression).

    Returns:
        A ``gast.Lambda`` without parameters.
    """
    body = func_or_expr_node
    if not is_if_expr:
        # NOTE(review): the function's ``arguments`` node is placed directly
        # in ``Call.args``; this relies on the code generator printing it as
        # a parameter list — confirm before compiling this AST directly.
        body = gast.Call(
            func=gast.Name(
                id=func_or_expr_node.name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None),
            args=[func_or_expr_node.args],
            keywords=[])

    # kw_defaults is a list field in the AST grammar: use [] rather than None.
    lambda_node = gast.Lambda(
        args=gast.arguments(
            args=[],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[]),
        body=body)
    return lambda_node
def visit_GeneratorExp(self, node):
    """Outline a generator expression into a module-level generator function.

    The new ``generator_expression<N>`` function yields ``node.elt`` inside
    the generator clauses (combined via ``self.nest_reducer``). It takes the
    identifiers gathered by ``ImportedIds`` as parameters. The expression is
    replaced by a call passing those identifiers.
    """
    self.update = True
    node.elt = self.visit(node.elt)

    func_name = "generator_expression{0}".format(self.count)
    self.count += 1
    free_ids = list(self.passmanager.gather(ImportedIds, node, self.ctx))
    self.count_iter = 0

    body = reduce(self.nest_reducer, reversed(node.generators),
                  ast.Expr(ast.Yield(node.elt)))
    signature = ast.arguments(
        [ast.Name(free_id, ast.Param(), None) for free_id in free_ids],
        None, [], [], None, [])
    self.ctx.module.body.append(
        ast.FunctionDef(func_name, signature, [body], [], None))

    # fresh Name nodes at the call site: no sharing !
    return ast.Call(
        ast.Name(func_name, ast.Load(), None),
        [ast.Name(free_id, ast.Load(), None) for free_id in free_ids],
        [],
    )
def visit_arguments(self, node):
    """Convert an ``arguments`` node into a six-field ``gast.arguments``.

    ``node.vararg`` and ``node.kwarg`` are treated as bare identifiers: each
    non-empty one is wrapped into a ``Name`` with a ``Param`` context and
    given the location of ``node``. Keyword-only parameters and their
    defaults are emitted as empty lists.
    """
    def as_param(identifier):
        if not identifier:
            return None
        param = ast.Name(identifier, ast.Param())
        ast.copy_location(param, node)
        return param

    star_args = as_param(node.vararg)
    star_kwargs = as_param(node.kwarg)

    return gast.arguments(
        self._visit(node.args),
        self._visit(star_args),
        [],  # kwonlyargs
        [],  # kw_defaults
        self._visit(star_kwargs),
        self._visit(node.defaults),
    )
def visit_Module(self, node):
    """Replace top-level assignments by functions returning the value.

    Every name bound by a top-level ``Assign`` is added to
    ``self.to_expand``. Each such assignment becomes one ``FunctionDef`` per
    target, whose ``return`` of the visited value is tagged with
    ``metadata.StaticReturn``. Other statements are visited with
    ``self.local_decl`` set to their local name declarations.

    Raises:
        PythranSyntaxError: if a top-level assignment target is not a name,
            or a name is already in ``self.to_expand``.
    """
    # Gather top level assigned variables.
    top_level_assigns = [stmt for stmt in node.body
                         if isinstance(stmt, ast.Assign)]
    for assign in top_level_assigns:
        for target in assign.targets:
            if not isinstance(target, ast.Name):
                raise PythranSyntaxError(
                    "Top-level assignment to an expression.",
                    target)
            if target.id in self.to_expand:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % target.id,
                    target)
            self.to_expand.add(target.id)

    new_body = []
    for stmt in node.body:
        if not isinstance(stmt, ast.Assign):
            self.local_decl = self.passmanager.gather(
                LocalNameDeclarations, stmt, self.ctx)
            new_body.append(self.visit(stmt))
            continue

        self.local_decl = set()
        value = self.visit(stmt.value)
        for target in stmt.targets:
            assert isinstance(target, ast.Name)
            getter = ast.FunctionDef(target.id,
                                     ast.arguments([], None, [], [], None, []),
                                     [ast.Return(value=value)],
                                     [], None)
            new_body.append(getter)
            metadata.add(getter.body[0], metadata.StaticReturn())

    node.body = new_body
    return node
def visit_GeneratorExp(self, node):
    """Turn a generator expression into a call of a new generator function.

    A module-level ``generator_expression<N>`` function is appended to
    ``self.ctx.module.body``. Its body yields ``node.elt`` within the
    generator clauses (nested through ``self.nest_reducer``), and its
    parameters are the identifiers gathered by ``ImportedIds``.
    """
    self.update = True
    node.elt = self.visit(node.elt)

    generator_name = "generator_expression{0}".format(self.count)
    self.count += 1
    captured = self.passmanager.gather(ImportedIds, node, self.ctx)
    self.count_iter = 0

    yield_elt = ast.Expr(ast.Yield(node.elt))
    loops = reduce(self.nest_reducer, reversed(node.generators), yield_elt)

    params = []
    for identifier in captured:
        params.append(ast.Name(identifier, ast.Param(), None))
    generator_def = ast.FunctionDef(
        generator_name,
        ast.arguments(params, None, [], [], None, []),
        [loops], [], None)
    self.ctx.module.body.append(generator_def)

    # new Name nodes for the call arguments: no sharing !
    call_args = [ast.Name(param.id, ast.Load(), None) for param in params]
    return ast.Call(ast.Name(generator_name, ast.Load(), None),
                    call_args, [])
def visit_Module(self, node):
    """Turn top-level assignments into constant-returning function defs.

    Pass one registers every top-level assignment target in
    ``self.to_expand``. Pass two replaces each assignment by one
    ``FunctionDef`` per target returning the visited value (tagged with
    ``metadata.StaticReturn``) and visits every other statement.

    Raises:
        PythranSyntaxError: for a top-level assignment to a non-name, or a
            target already present in ``self.to_expand``.
    """
    def register(target):
        if not isinstance(target, ast.Name):
            raise PythranSyntaxError(
                "Top-level assignment to an expression.",
                target)
        if target.id in self.to_expand:
            raise PythranSyntaxError(
                "Multiple top-level definition of %s." % target.id,
                target)
        self.to_expand.add(target.id)

    # Gather top level assigned variables.
    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            for target in stmt.targets:
                register(target)

    rewritten = []
    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            self.local_decl = set()
            constant = self.visit(stmt.value)
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                rewritten.append(
                    ast.FunctionDef(
                        target.id,
                        ast.arguments([], None, [], [], None, []),
                        [ast.Return(value=constant)],
                        [], None))
                metadata.add(rewritten[-1].body[0], metadata.StaticReturn())
        else:
            self.local_decl = self.passmanager.gather(
                LocalNameDeclarations, stmt, self.ctx)
            rewritten.append(self.visit(stmt))

    node.body = rewritten
    return node
def get_while_stmt_nodes(self, node):
    """Convert a ``while`` loop into condition/body functions and loop nodes.

    Returns the list of statements replacing ``node``:
      * static-variable declarations for non-dotted names created in the
        loop (``create_static_variable_gast_node``),
      * a condition function returning ``node.test``,
      * a body function running ``node.body`` and returning the loop
        variables,
      * the nodes produced by ``create_while_nodes`` for those functions.

    Both functions take the loop variables as parameters. Dotted loop
    variable names are renamed to generated identifiers inside each
    function. ``node.body`` is mutated in place (a ``Return`` is appended).
    """
    loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
        node)
    new_stmts = []

    # Python can create variable in loop and use it out of loop, E.g.
    #
    # while x < 10:
    #     x += 1
    #     y = x
    # z = y
    #
    # We need to create static variable for those variables
    for name in create_var_names:
        if "." not in name:
            new_stmts.append(create_static_variable_gast_node(name))

    def make_loop_var_arguments():
        # A fresh node per function so the two FunctionDefs share no nodes.
        # kw_defaults is a list field in the AST grammar: [] rather than None.
        return gast.arguments(
            args=[
                gast.Name(
                    id=name,
                    ctx=gast.Param(),
                    annotation=None,
                    type_comment=None) for name in loop_var_names
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[])

    def rename_dotted_loop_vars(func_node):
        # Dotted names cannot be plain parameters, presumably why each one
        # is replaced by a generated identifier inside ``func_node``.
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))

    condition_func_node = gast.FunctionDef(
        name=unique_name.generate(WHILE_CONDITION_PREFIX),
        args=make_loop_var_arguments(),
        body=[gast.Return(value=node.test)],
        decorator_list=[],
        returns=None,
        type_comment=None)
    rename_dotted_loop_vars(condition_func_node)
    new_stmts.append(condition_func_node)

    new_body = node.body
    new_body.append(
        gast.Return(value=generate_name_node(
            loop_var_names, ctx=gast.Load())))
    body_func_node = gast.FunctionDef(
        name=unique_name.generate(WHILE_BODY_PREFIX),
        args=make_loop_var_arguments(),
        body=new_body,
        decorator_list=[],
        returns=None,
        type_comment=None)
    rename_dotted_loop_vars(body_func_node)
    new_stmts.append(body_func_node)

    while_loop_nodes = create_while_nodes(
        condition_func_node.name, body_func_node.name, loop_var_names)
    new_stmts.extend(while_loop_nodes)

    return new_stmts
def visit_Module(self, node):
    """Turn globals assignment to functionDef and visit function defs.

    First pass over ``node.body``: imported names and top-level function
    names are collected in ``symbols``. A function whose name is already
    in ``symbols`` raises. Every top-level assignment target is recorded in
    ``self.to_expand``, unless the assigned value is a plain name already
    in ``symbols`` (an alias).

    Second pass: assignments whose targets are all names absent from
    ``self.to_expand`` are kept as is (aliases). Other assignments become
    one ``FunctionDef`` per target that returns the visited value (wrapped
    by ``GlobalTransformer``), with the ``return`` tagged
    ``metadata.StaticReturn``. All remaining statements are visited.

    Raises:
        PythranSyntaxError: on a function name already imported or defined,
            a top-level assignment to something other than a name, or a
            target already present in ``self.to_expand``.
    """
    module_body = list()
    symbols = set()
    # Gather top level assigned variables.
    # NOTE: ``symbols`` only holds names seen earlier in node.body, so an
    # assignment aliasing a later function is expanded like a global.
    for stmt in node.body:
        if isinstance(stmt, (ast.Import, ast.ImportFrom)):
            for alias in stmt.names:
                name = alias.asname or alias.name
                symbols.add(name)  # no warning here
        elif isinstance(stmt, ast.FunctionDef):
            if stmt.name in symbols:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % stmt.name,
                    stmt)
            else:
                symbols.add(stmt.name)

        if not isinstance(stmt, ast.Assign):
            continue

        for target in stmt.targets:
            if not isinstance(target, ast.Name):
                raise PythranSyntaxError(
                    "Top-level assignment to an expression.",
                    target)
            if target.id in self.to_expand:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % target.id,
                    target)
            if isinstance(stmt.value, ast.Name):
                if stmt.value.id in symbols:
                    continue  # create aliasing between top level symbols
            self.to_expand.add(target.id)

    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            # that's not a global var, but a module/function aliasing
            if all(isinstance(t, ast.Name) and t.id not in self.to_expand
                   for t in stmt.targets):
                module_body.append(stmt)
                continue

            self.local_decl = set()
            cst_value = GlobalTransformer().visit(self.visit(stmt.value))
            for target in stmt.targets:
                assert isinstance(target, ast.Name)
                # One zero-argument function per target; all of them share
                # the same ``cst_value`` node.
                module_body.append(
                    ast.FunctionDef(
                        target.id,
                        ast.arguments([], [], None, [], [], None, []),
                        [ast.Return(value=cst_value)],
                        [], None, None))
                metadata.add(module_body[-1].body[0],
                             metadata.StaticReturn())
        else:
            self.local_decl = self.gather(
                LocalNameDeclarations, stmt)
            module_body.append(self.visit(stmt))

    # Report a change whenever ``self.to_expand`` is non-empty.
    self.update |= bool(self.to_expand)

    node.body = module_body
    return node
def get_for_stmt_nodes(self, node):
    """Rewrite a supported ``for`` loop as condition/body functions plus a while node.

    ``ForNodeVisitor(node).parse()`` supplies the init statements, the loop
    condition and the body. The handled forms are:

    1. ``for x in range(*)``
    2. ``for x in iter_var``
    3. ``for i, x in enumerate(*)``

    Args:
        node: the gast ``For`` node to convert.

    Returns:
        list: ``[node]`` unchanged when the parser returns ``None``.
        Otherwise the replacement statements, in order:
          - static variable declarations for names first created inside
            the loop;
          - the init statements;
          - the condition function;
          - the body function;
          - the node built by ``create_while_node``.
        The parser's body statement list gets a trailing ``Return`` of the
        loop variables appended in place.
    """
    # TODO: consider for - else in python
    # 1. Key statements from the parser:
    #    init_stmts: list of nodes preparing the loop (possibly several);
    #    cond_stmt:  node deciding whether the loop continues;
    #    body_stmts: list of nodes forming the (possibly rewritten) body.
    current_for_node_parser = ForNodeVisitor(node)
    stmts_tuple = current_for_node_parser.parse()
    if stmts_tuple is None:
        return [node]
    init_stmts, cond_stmt, body_stmts = stmts_tuple

    # 2. Original loop variables.
    loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
        node)

    # For 'for x in var' and 'for i, x in enumerate(var)', the generated
    # index variable becomes a loop variable. x itself is not needed unless
    # it is first created inside the loop.
    if (current_for_node_parser.is_for_iter()
            or current_for_node_parser.is_for_enumerate_iter()):
        iter_var_name = current_for_node_parser.iter_var_name
        iter_idx_name = current_for_node_parser.iter_idx_name
        loop_var_names.add(iter_idx_name)
        if iter_var_name not in create_var_names:
            # discard rather than remove: nothing here guarantees x is in
            # loop_var_names, and set.remove would raise KeyError if not.
            loop_var_names.discard(iter_var_name)

    def build_loop_func(prefix, body):
        # Build ``def <unique name>(<loop vars>): <body>``. ``kw_defaults``
        # is an empty list because the ``ast.arguments`` schema requires a
        # list there; ``None`` is rejected by ``compile()`` once the tree is
        # converted back to ``ast``.
        func_node = gast.FunctionDef(
            name=unique_name.generate(prefix),
            args=gast.arguments(
                args=[
                    gast.Name(
                        id=name,
                        ctx=gast.Param(),
                        annotation=None,
                        type_comment=None) for name in loop_var_names
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[]),
            body=body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        # A dotted name such as ``self.x`` cannot be a parameter, so each one
        # is renamed to a freshly generated identifier inside this function.
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func_node

    # 3. Result statement list.
    new_stmts = []
    # Python lets a variable created inside a loop be used after it, e.g.
    #
    #   for x in range(10):
    #       y += x
    #   print(x)  # x = 10
    #
    # so a static variable is declared up front for each such name.
    for name in create_var_names:
        if "." not in name:
            new_stmts.append(create_static_variable_gast_node(name))

    # 4. Init statements.
    new_stmts.extend(init_stmts)

    # 5. Condition function.
    condition_func_node = build_loop_func(FOR_CONDITION_PREFIX,
                                          [gast.Return(value=cond_stmt)])
    new_stmts.append(condition_func_node)

    # 6. Body function; it returns the updated loop variables.
    body_stmts.append(
        gast.Return(
            value=generate_name_node(loop_var_names, ctx=gast.Load())))
    body_func_node = build_loop_func(FOR_BODY_PREFIX, body_stmts)
    new_stmts.append(body_func_node)

    # 7. While loop node.
    while_loop_node = create_while_node(condition_func_node.name,
                                        body_func_node.name, loop_var_names)
    new_stmts.append(while_loop_node)
    return new_stmts
def get_while_stmt_nodes(self, node):
    """Rewrite a control-flow ``while`` loop as condition/body functions plus a while node.

    Args:
        node: the gast ``While`` node to convert. Its ``body`` list is
            mutated in place: a ``Return`` of the loop variables is appended
            and that same list becomes the body of the generated function.

    Returns:
        list: ``[node]`` unchanged when
        ``self.name_visitor.is_control_flow_loop(node)`` is false.
        Otherwise the replacement statements, in order:
          - static variable declarations for names first created inside
            the loop;
          - the ``to_static_variable_gast_node`` statements for each loop
            variable;
          - the condition function;
          - the body function;
          - the node built by ``create_while_node``.
    """
    # TODO: consider while - else in python
    if not self.name_visitor.is_control_flow_loop(node):
        return [node]

    loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
        node)

    def build_loop_func(prefix, body):
        # Build ``def <unique name>(<loop vars>): <body>``. ``kw_defaults``
        # is an empty list because the ``ast.arguments`` schema requires a
        # list there; ``None`` is rejected by ``compile()`` once the tree is
        # converted back to ``ast``.
        func_node = gast.FunctionDef(
            name=unique_name.generate(prefix),
            args=gast.arguments(
                args=[
                    gast.Name(
                        id=name,
                        ctx=gast.Param(),
                        annotation=None,
                        type_comment=None) for name in loop_var_names
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[]),
            body=body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        # A dotted name such as ``self.x`` cannot be a parameter, so each one
        # is renamed to a freshly generated identifier inside this function.
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func_node

    new_stmts = []
    # Python lets a variable created inside a loop be used after it, e.g.
    #
    #   while x < 10:
    #       x += 1
    #       y = x
    #   z = y
    #
    # so a static variable is declared up front for each such name.
    for name in create_var_names:
        if "." not in name:
            new_stmts.append(create_static_variable_gast_node(name))

    # In dygraph, a test like ``x < 10`` must be evaluated on a static
    # tensor, so every loop variable is converted to a static variable.
    for name in loop_var_names:
        new_stmts.append(to_static_variable_gast_node(name))

    logical_op_transformer = LogicalOpTransformer(node.test)
    cond_value_node = logical_op_transformer.transform()

    condition_func_node = build_loop_func(WHILE_CONDITION_PREFIX,
                                          [gast.Return(value=cond_value_node)])
    new_stmts.append(condition_func_node)

    # The body function hands the updated loop variables back.
    new_body = node.body
    new_body.append(
        gast.Return(
            value=generate_name_node(loop_var_names, ctx=gast.Load())))
    body_func_node = build_loop_func(WHILE_BODY_PREFIX, new_body)
    new_stmts.append(body_func_node)

    while_loop_node = create_while_node(condition_func_node.name,
                                        body_func_node.name, loop_var_names)
    new_stmts.append(while_loop_node)
    return new_stmts
def get_for_stmt_nodes(self, node):
    """Rewrite a control-flow ``for x in range(...)`` loop as condition/body functions plus a while node.

    Args:
        node: the gast ``For`` node to convert. Its ``body`` list is mutated
            in place: the range update statement and a ``Return`` of the
            loop variables are appended, and that same list becomes the body
            of the generated function.

    Returns:
        list: ``[node]`` unchanged when any of these hold:
          - the loop is not a control-flow loop;
          - ``self.get_for_range_node`` returns ``None``;
          - the target is not a plain name.
        Otherwise the replacement statements, in order:
          - static variable declarations;
          - the init statement;
          - the ``to_static_variable_gast_node`` statements;
          - the condition function;
          - the body function;
          - the node built by ``create_while_node``.
    """
    # TODO: consider for - else in python
    if not self.name_visitor.is_control_flow_loop(node):
        return [node]
    # TODO: support non-range case
    range_call_node = self.get_for_range_node(node)
    if range_call_node is None:
        return [node]
    if not isinstance(node.target, gast.Name):
        return [node]
    iter_var_name = node.target.id

    init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
        iter_var_name, range_call_node.args)

    loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
        node)

    def build_loop_func(prefix, body):
        # Build ``def <unique name>(<loop vars>): <body>``. ``kw_defaults``
        # is an empty list because the ``ast.arguments`` schema requires a
        # list there; ``None`` is rejected by ``compile()`` once the tree is
        # converted back to ``ast``.
        func_node = gast.FunctionDef(
            name=unique_name.generate(prefix),
            args=gast.arguments(
                args=[
                    gast.Name(
                        id=name,
                        ctx=gast.Param(),
                        annotation=None,
                        type_comment=None) for name in loop_var_names
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[]),
            body=body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        # A dotted name such as ``self.x`` cannot be a parameter, so each one
        # is renamed to a freshly generated identifier inside this function.
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func_node

    new_stmts = []
    # Python lets a variable created inside a loop be used after it, e.g.
    #
    #   for x in range(10):
    #       y += x
    #   print(x)  # x = 10
    #
    # so a static variable is declared up front for each such name.
    for name in create_var_names:
        if "." not in name:
            new_stmts.append(create_static_variable_gast_node(name))

    new_stmts.append(init_stmt)

    # In dygraph, ``for x in range(10)`` must become a condition on a static
    # tensor (``x + 1 <= 10``), so every loop variable is converted to a
    # static variable first.
    for name in loop_var_names:
        new_stmts.append(to_static_variable_gast_node(name))

    condition_func_node = build_loop_func(FOR_CONDITION_PREFIX,
                                          [gast.Return(value=cond_stmt)])
    new_stmts.append(condition_func_node)

    # The body advances the range variable, then returns the loop variables.
    new_body = node.body
    new_body.append(change_stmt)
    new_body.append(
        gast.Return(
            value=generate_name_node(loop_var_names, ctx=gast.Load())))
    body_func_node = build_loop_func(FOR_BODY_PREFIX, new_body)
    new_stmts.append(body_func_node)

    while_loop_node = create_while_node(condition_func_node.name,
                                        body_func_node.name, loop_var_names)
    new_stmts.append(while_loop_node)
    return new_stmts
def visit_Module(self, node):
    """Turn top-level assignments into functions and visit everything else.

    First pass: collect imported names and top-level function names into a
    local symbol table. Every top-level assignment target that is not a
    plain alias of one of those symbols is recorded in ``self.to_expand``.

    Second pass: rebuild the module body.
      - An assignment touching a name in ``self.to_expand`` becomes one
        zero-argument ``FunctionDef`` per target. Each function returns the
        visited value and is tagged with ``metadata.StaticReturn``.
      - Any other assignment is kept untouched.
      - Every non-assignment statement is visited.

    Raises:
        PythranSyntaxError: for a repeated top-level definition or a
            top-level assignment to something other than a plain name.
    """
    known = set()

    # Pass 1: symbol table and the set of globals that need expansion.
    for stmt in node.body:
        if isinstance(stmt, (ast.Import, ast.ImportFrom)):
            # Importing the same name twice is tolerated.
            known.update(alias.asname or alias.name for alias in stmt.names)
        elif isinstance(stmt, ast.FunctionDef):
            if stmt.name in known:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % stmt.name,
                    stmt)
            known.add(stmt.name)

        if not isinstance(stmt, ast.Assign):
            continue

        # ``x = some_known_symbol`` merely aliases a module/function.
        aliases_symbol = (isinstance(stmt.value, ast.Name)
                          and stmt.value.id in known)
        for target in stmt.targets:
            if not isinstance(target, ast.Name):
                raise PythranSyntaxError(
                    "Top-level assignment to an expression.",
                    target)
            if target.id in self.to_expand:
                raise PythranSyntaxError(
                    "Multiple top-level definition of %s." % target.id,
                    target)
            if not aliases_symbol:
                self.to_expand.add(target.id)

    # Pass 2: rebuild the body.
    rewritten = []
    for stmt in node.body:
        if not isinstance(stmt, ast.Assign):
            self.local_decl = self.passmanager.gather(
                LocalNameDeclarations, stmt, self.ctx)
            rewritten.append(self.visit(stmt))
            continue

        keep_as_is = all(isinstance(t, ast.Name) and t.id not in self.to_expand
                         for t in stmt.targets)
        if keep_as_is:
            rewritten.append(stmt)
            continue

        self.local_decl = set()
        value = self.visit(stmt.value)
        for target in stmt.targets:
            assert isinstance(target, ast.Name)
            # NOTE(review): with several targets (``a = b = ...``) every
            # generated function shares this same ``value`` node; confirm
            # later passes tolerate shared subtrees.
            getter = ast.FunctionDef(target.id,
                                     ast.arguments([], None, [], [], None, []),
                                     [ast.Return(value=value)],
                                     [], None)
            rewritten.append(getter)
            metadata.add(getter.body[0], metadata.StaticReturn())

    node.body = rewritten
    return node