def __init__(self):
    """Set up a fresh PyStochCompiler instance.

    The underlying SourceGenerator is configured to indent with four
    spaces (False is passed as its second constructor argument).  After
    that the compiler's own bookkeeping is reset: a compile-time line
    stack, an empty list of identifiers handed out so far (consulted to
    avoid collisions), and the loop/class/function context flags.
    """
    indent_with = ' ' * 4
    super(PyStochCompiler, self).__init__(indent_with, False)
    self.line = IntegerStack()
    self.idens = []
    # context flags toggled by the visitor while it walks nested scopes
    self.infunc = self.inclass = self.inloop = False
class PyStochCompiler(codegen.SourceGenerator):
    """A visitor class to transform a python abstract syntax tree into pystoch.

    This class inherits from pystoch.codegen.SourceGenerator, which is a
    NodeVisitor that transforms a python abstract syntax tree (AST) into
    python code.  The PyStochCompiler takes it one step further,
    overriding the appropriate functions from SourceGenerator in order to
    insert PyStoch bookkeeping code (pushes and pops of the line, loop and
    function stacks on PYSTOCHOBJ) and to perform PyStoch transformations
    (hoisting calls, comprehensions and lambdas into temporaries).

    See Also
    --------
    pystoch.codegen
    pystoch.ast
    _ast
    ast

    """

    def __init__(self):
        """Initialize the PyStochCompiler.

        This initializes the SourceGenerator (four-space indentation, with
        False passed as its second argument), and creates a new list for
        pystoch identifiers (in order to avoid hash collisions), the
        loop/class/function context flags and the compile-time line stack.

        `inloop` is used as a counter of enclosing loops within the current
        function (0/False when not in a loop).

        """
        super(PyStochCompiler, self).__init__(' ' * 4, False)
        self.idens = []
        self.inloop = False
        self.inclass = False
        self.infunc = False
        self.line = IntegerStack()

    def _gen_iden(self, node):
        """Generate a random unique PyStoch identifier.

        All PyStoch identifiers are prefixed by 'PYSTOCH_', followed by an
        eight-character hexadecimal string.  The hexadecimal is the first
        eight characters of the md5 digest of the current date and time
        concatenated with the hash of `node`.  All generated ids are stored
        (without the prefix), and if a collision is detected, a new id is
        generated from a freshly read date and time until a unique one is
        found.

        Parameters
        ----------
        node : ast.AST
            The node to generate an id for

        Returns
        -------
        out : string
            The identifier for `node`

        """
        nodeid = str(hash(node))
        while True:
            # re-read the clock on every attempt so a retry after a
            # collision hashes different input
            now = str(datetime.datetime.now())
            iden = hashlib.md5(now + nodeid).hexdigest()[:8]
            if iden not in self.idens:
                break
        # store the bare digest and add the prefix exactly once
        self.idens.append(iden)
        return "PYSTOCH_%s" % iden

    @property
    def source(self):
        """The source generated by the PyStochCompiler after `compile` has
        been called.

        If any item of the generated output is not a string, diagnostic
        information about that item and its neighbours is printed and an
        empty string is returned instead.

        """
        for i, item in enumerate(self.result):
            if not isinstance(item, str):
                print("Something went wrong! Expected a string, but got %s."
                      % type(item))
                if i > 0:
                    print("Previous item: %s" % repr(self.result[i - 1]))
                print("This item: %s" % repr(item))
                if i < len(self.result) - 1:
                    print("Next item: %s" % repr(self.result[i + 1]))
                return ''
        return ''.join(self.result)

    def insert(self, statements):
        """Insert non-node statements into the source.

        This is used for inserting non-node statements into the source
        compilation.  These generally should only be PyStoch-specific
        statements; if they are a statement that needs to be evaluated by
        the PyStoch compiler then this function should NOT be used.

        You can pass in either a list/tuple of statements, or a single
        statement.  Each statement is written on its own line.

        Parameters
        ----------
        statements : string or list or tuple
            The statement or statements to be inserted

        Raises
        ------
        ValueError
            If any statement is not a string (nothing is written then).

        """
        # turn it into a list if it's not already
        if not isinstance(statements, (list, tuple)):
            statements = [statements]

        # validate everything before writing anything
        for statement in statements:
            if not isinstance(statement, str):
                raise ValueError("statement is not a string")

        # write each statement, separated by a newline
        for statement in statements:
            super(PyStochCompiler, self).newline()
            self.write(statement)

    def compile(self, src):
        """Compile python source to pystoch source.

        Parameters
        ----------
        src : string
            If `src` is a path to an existing file, the source is loaded
            from that file and compiled.  Otherwise `src` is treated as the
            text of the source itself and compiled.

        Returns
        -------
        out : string
            The compiled source (the value of the `source` property).

        Raises
        ------
        ValueError
            If `src` is not a string.

        """
        if not isinstance(src, str):
            raise ValueError("src must be a string")

        # read in the source from a file...
        if os.path.exists(src):
            with open(src, 'r') as fh:
                source = fh.read()
        # ... or just treat src as the actual source
        else:
            source = src

        # parse the source into an AST
        node = ast.parse(source)

        # push a 0 onto the line stack, both at compile time and in the
        # generated code
        self.line.push(0)
        self.insert("PYSTOCHOBJ.line_stack.push(0)")

        # compile the rest of the module
        self.visit(node)

        # and finally, pop the line stack again
        self.write('\n')
        self.insert("PYSTOCHOBJ.line_stack.pop()")
        self.line.pop()

        return self.source

    def newline(self, node=None, extra=0):
        """Insert a newline.

        This inserts a newline in the same way as SourceGenerator, with the
        additional catch of incrementing the line stack if the node asking
        for the newline is non-null (if it's null, then incrementing the
        line stack is pointless because nothing will happen between now
        and the next time a newline occurs).

        Parameters
        ----------
        node : ast.AST (default=None)
            The ast node to insert a newline for
        extra : integer (default=0)
            The number of extra newlines to insert

        Notes
        -----
        This function doesn't actually immediately insert a newline, it
        increments the number of newlines to insert and then inserts them
        when write() is called.

        """
        # call the parent newline method
        super(PyStochCompiler, self).newline(node=node, extra=extra)

        # return if the node is null
        if node is None:
            return

        # otherwise, increment the line stack, emit the runtime update of
        # the line stack, and then request another newline
        self.line.increment()
        self.write("PYSTOCHOBJ.line_stack.set(%s)" % self.line.peek())
        super(PyStochCompiler, self).newline(node=node, extra=extra)

    def body(self, statements, write_before=None, write_after=None):
        """Write the body statements.

        This is the same as the SourceGenerator body function, with the
        additional parameters of write_before and write_after.  These
        parameters allow you to insert extra stuff before and after the
        rest of the statements, at the body's indentation level.

        Parameters
        ----------
        statements : list of ast.AST nodes
            The statements to be written in the body
        write_before : list or string
            The statements to write before the body
        write_after : list or string
            The statements to write after the body

        """
        # increment the level of indentation
        self.indentation += 1

        # insert the write_before statements
        if write_before is not None:
            self.insert(write_before)

        # write the actual body statements
        for stmt in statements:
            self.visit(stmt)

        # insert the write_after statements
        if write_after is not None:
            self.insert(write_after)

        # decrement the level of indentation
        self.indentation -= 1

    def body_or_else(self, node, write_before=None, write_after=None):
        """Write a body as well as an else statement, if it exists.

        Parameters
        ----------
        node : ast.AST
            node that has a body and optionally an orelse
        write_before : list or string
            The statements to write before the body
        write_after : list or string
            The statements to write after the body

        See Also
        --------
        pystoch.compile.PyStochCompiler.body

        """
        self.body(node.body, write_before=write_before,
                  write_after=write_after)
        if node.orelse:
            self.newline()
            self.write('else:')
            self.body(node.orelse)

    def to_assign(self, value):
        """Takes a value, creates a random temporary identifier for it, and
        creates an Assign node, assigning the value to the identifier.

        Parameters
        ----------
        value : ast.AST
            node that is to be the value of the Assign node

        Returns
        -------
        out : tuple of string, _ast.Assign
            the string is the identifier of the node, and the _ast.Assign
            is the node that was created

        Raises
        ------
        ValueError
            If `value` is not an _ast.AST instance.

        """
        if not isinstance(value, _ast.AST):
            raise ValueError("value is not an instance of _ast.AST")

        iden = self._gen_iden(value)
        node = _ast.Assign(targets=[ast.parse(iden).body[0].value],
                           value=value)
        return iden, node

    def _count_calls(self, node):
        """Counts the number of Call nodes in a node.

        Note that this does NOT count Call nodes that are children of
        ListComp or GeneratorExp nodes.

        Parameters
        ----------
        node : _ast.AST
            The node to count Call nodes in

        Returns
        -------
        out : integer
            The number of Call nodes in `node`

        """
        class CountCalls(NodeCounter):
            def visit_Call(self, node):
                # this call, plus any calls nested anywhere inside it
                total = 1
                total += self.visit(node.func)
                for arg in node.args:
                    total += self.visit(arg)
                for keyword in node.keywords:
                    total += self.visit(keyword)
                if node.starargs is not None:
                    total += self.visit(node.starargs)
                if node.kwargs is not None:
                    total += self.visit(node.kwargs)
                return total

            def visit_ListComp(self, node):
                return 0

            def visit_GeneratorExp(self, node):
                return 0

        return CountCalls().visit(node)

    def _count_listcomps(self, node):
        """Counts the ListComp, GeneratorExp and Lambda nodes in a node.

        The visitor does not descend into these nodes, so ones nested
        inside another are not counted separately.

        Parameters
        ----------
        node : _ast.AST
            The node to count ListComp/GeneratorExp/Lambda nodes in

        Returns
        -------
        out : integer
            The number of outermost such nodes (presumably summed over
            children by NodeCounter -- TODO confirm); 0 if there are none

        """
        class CountListComps(NodeCounter):
            def visit_ListComp(self, node):
                return 1

            def visit_GeneratorExp(self, node):
                return 1

            def visit_Lambda(self, node):
                return 1

        return CountListComps().visit(node)

    def should_rewrite(self, node, threshold=1, cthreshold=0):
        """Checks whether or not a node (_ast.AST) contains more than the
        allowed number of Call nodes or ListComp/GeneratorExp/Lambda nodes.
        If so, this indicates that the node should probably be rewritten.

        Parameters
        ----------
        node : _ast.AST
            The node to check for Call/ListComp nodes
        threshold : int (default=1)
            How many Call nodes are allowed before the node should be
            rewritten
        cthreshold : int (default=0)
            How many ListComp/GeneratorExp/Lambda nodes are allowed before
            the node should be rewritten

        Returns
        -------
        out : boolean
            True if either count exceeds its threshold

        """
        ctotal = self._count_calls(node)
        lctotal = self._count_listcomps(node)
        return (ctotal > threshold) or (lctotal > cthreshold)

    def contains_return(self, node):
        """Checks whether or not a node (_ast.AST) contains any Return
        nodes (including ones inside nested function definitions).

        Parameters
        ----------
        node : _ast.AST
            The node to check for Return nodes

        Returns
        -------
        out : integer
            The number of Return nodes found; truthy when there are any

        """
        class ReturnChecker(NodeCounter):
            def visit_Return(self, node):
                return 1

        return ReturnChecker().visit(node)

    def _extract(self, node):
        """Takes a node, checks to see if there are any Call nodes or
        ListComp nodes in it, and if so, extracts either the first ListComp
        or deepest Call child of the first Call node, places it inside a
        temporary variable, and then replaces the Call/ListComp node in
        `node` with the temporary variable.  Returns the new, modified node.

        The assignment to the temporary is compiled immediately (through
        `self.visit`), so it is written to the output before the statement
        that uses it.

        Parameters
        ----------
        node : _ast.AST
            The node to extract any Call/ListComp nodes from

        Returns
        -------
        out : _ast.AST
            The modified node, with a Call/ListComp node removed

        """
        class Replace(ast.NodeTransformer):
            # we need to pass in various parent functions
            def __init__(self, callback, should_rewrite, to_assign):
                self.callback = callback
                self.to_assign = to_assign
                self.should_rewrite = should_rewrite
                self.done = False

            def _hoist(self, node):
                # assign `node` to a temporary, compile that assignment,
                # and return a Name node referring to the temporary
                iden, assignment = self.to_assign(node)
                self.done = True
                self.callback(assignment)
                return ast.parse(iden).body[0].value

            def _hoist_once(self, node):
                if self.done:
                    return node
                return self._hoist(node)

            visit_ListComp = _hoist_once
            visit_GeneratorExp = _hoist_once
            visit_Lambda = _hoist_once

            def visit_Call(self, node):
                if self.done:
                    return node

                # check function
                new = self.visit(node.func)
                if new is not node.func:
                    node.func = new
                    self.done = True
                    return node

                # check arguments
                for idx, arg in enumerate(node.args):
                    new = self.visit(arg)
                    if new is not arg:
                        node.args[idx] = new
                        self.done = True
                        return node

                # check keyword arguments
                for idx, keyword in enumerate(node.keywords):
                    new = self.visit(keyword)
                    if new is not keyword:
                        node.keywords[idx] = new
                        self.done = True
                        return node

                # check *args
                if node.starargs is not None:
                    new = self.visit(node.starargs)
                    if new is not node.starargs:
                        node.starargs = new
                        self.done = True
                        return node

                # check **kwargs
                if node.kwargs is not None:
                    new = self.visit(node.kwargs)
                    if new is not node.kwargs:
                        node.kwargs = new
                        self.done = True
                        return node

                # no child was replaced by a new node, so rewrite this call
                # itself (a child rewritten in place by generic_visit does
                # not stop this; the outer call is then hoisted as well)
                return self._hoist(node)

        return Replace(self.visit, self.should_rewrite,
                       self.to_assign).visit(node)

    def extract(self, node, threshold=1, cthreshold=0):
        """Extract all Call/ListComp nodes above a certain threshold from a
        node.  This repeatedly calls self._extract(node) until
        self.should_rewrite(node, threshold, cthreshold) is False.

        Parameters
        ----------
        node : _ast.AST
            The node to extract Call/ListComp nodes from.
        threshold : int (default=1)
            The maximum number of Call nodes allowed to remain
        cthreshold : int (default=0)
            The maximum number of ListComp/GeneratorExp/Lambda nodes
            allowed to remain

        Returns
        -------
        out : _ast.AST
            The modified `node`, minus Call/ListComp nodes.

        """
        while self.should_rewrite(node, threshold=threshold,
                                  cthreshold=cthreshold):
            node = self._extract(node)
        return node

    def _guarded_body(self, ifs, body):
        """Wrap `body` in an If testing the comprehension conditions `ifs`
        (combined with `and` when there are several); return it unchanged
        when there are none."""
        if len(ifs) == 1:
            return [_ast.If(test=ifs[0], body=body, orelse=[])]
        elif len(ifs) > 1:
            return [_ast.If(test=_ast.BoolOp(op=_ast.And(), values=ifs),
                            body=body, orelse=[])]
        return body

    ########### Visitor Functions ###########

    # 1) Statements

    def visit_FunctionDef(self, node):
        """Rewrite the FunctionDef visitor to push a new value onto the
        function and line stacks at the beginning of the function, and
        then pop those values at the end of the function.

        Every compiled function gets an extra trailing `PYSTOCHOBJ`
        parameter defaulting to None, and is marked with `.random = True`
        after its definition.

        """
        self.newline(extra=1)
        self.decorators(node)
        super(PyStochCompiler, self).newline()
        self.write('def %s(' % node.name)
        node.args.args.append(ast.parse('PYSTOCHOBJ').body[0].value)
        node.args.defaults.append(ast.parse('None').body[0].value)
        self.signature(node.args)
        self.write('):')

        if len(node.body) == 1 and isinstance(node.body[0], _ast.Pass):
            # a body that is only `pass` gets no stack bookkeeping
            self.body(node.body)
        else:
            self.line.push(0)
            write_before = [
                "PYSTOCHOBJ.func_stack.push('%s')" % self._gen_iden(node),
                "PYSTOCHOBJ.line_stack.push(0)"
            ]
            # visit_Return pops the stacks before each explicit return, but
            # control can still fall off the end of the body (e.g. when the
            # only return is inside an `if`), so pop at the end too --
            # unless the final statement is a return, which would make the
            # trailing pops unreachable
            if node.body and isinstance(node.body[-1], _ast.Return):
                write_after = None
            else:
                write_after = [
                    "PYSTOCHOBJ.line_stack.pop()",
                    "PYSTOCHOBJ.func_stack.pop()"
                ]

            # loops enclosing the definition do not enclose the body at
            # run time, so the loop depth restarts at zero
            infunc, inloop = self.infunc, self.inloop
            self.infunc, self.inloop = True, 0
            self.body(node.body, write_before=write_before,
                      write_after=write_after)
            self.infunc, self.inloop = infunc, inloop
            self.line.pop()

        self.insert("%s.random = True" % node.name)

    def visit_ClassDef(self, node):
        """Rewrite the ClassDef visitor to mark the class as random by
        writing `random = True` at the start of the class body.
        """
        have_args = []

        def paren_or_comma():
            if have_args:
                self.write(', ')
            else:
                have_args.append(True)
                self.write('(')

        self.newline(extra=2)
        self.decorators(node)
        self.newline(node)
        self.write('class %s' % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        self.write('):' if have_args else ':')

        write_before = ["random = True"]
        inclass = self.inclass
        self.inclass = True
        self.body(node.body, write_before=write_before)
        self.write('\n')
        self.inclass = inclass

    def visit_Return(self, node):
        """Rewrite the Return visitor function to first store the return
        value of the function, then pop the line stack, one loop stack
        entry per enclosing loop, and the function stack, then return the
        stored value.
        """
        # threshold 0: every call must run before the stacks are popped
        node = self.extract(node, threshold=0)

        pops = ["PYSTOCHOBJ.line_stack.pop()"]
        pops.extend(["PYSTOCHOBJ.loop_stack.pop()"] * int(self.inloop))
        pops.append("PYSTOCHOBJ.func_stack.pop()")
        self.insert(pops)

        super(PyStochCompiler, self).newline()
        self.write("return ")
        if node.value is not None:
            self.visit(node.value)

    def visit_Delete(self, node):
        """Calls the superclass' visit_Delete method, while additionally
        checking to make sure that no Call nodes are present in the node
        being visited.

        See Also
        --------
        codegen.SourceCompiler#visit_Delete

        """
        if self.should_rewrite(node):
            raise UnexpectedCallException
        super(PyStochCompiler, self).visit_Delete(node)

    def visit_Assign(self, node):
        """Rewrite the Assign visitor function to deal with list
        comprehensions and function calls.
        """
        # list comprehensions, generator expressions and lambdas are
        # compiled into separate statements first; the assignment then
        # refers to the temporary holding the result
        if isinstance(node.value,
                      (_ast.ListComp, _ast.GeneratorExp, _ast.Lambda)):
            node.value = self.extract(node.value, cthreshold=1)
            iden = self.visit(node.value)
            if isinstance(node.value, _ast.GeneratorExp):
                # the temporary is a generator function: call it
                val = _ast.Call(func=_ast.Name(id=iden, ctx=_ast.Load()),
                                args=[], keywords=[], starargs=None,
                                kwargs=None)
            else:
                val = _ast.Name(id=iden, ctx=_ast.Load())
            node = ast.Assign(value=val, targets=node.targets)

        elif isinstance(node.value, (_ast.Dict, _ast.List, _ast.Tuple)):
            node.value = self.extract(node.value, threshold=0)

        else:
            # do Call/ListComp extraction on the node's value
            node.value = self.extract(node.value, cthreshold=0)

        self.newline(node)
        for idx, target in enumerate(node.targets):
            if idx:
                self.write(' = ')
            self.visit(target)
        self.write(' = ')
        self.visit(node.value)

    def visit_AugAssign(self, node):
        """Rewrite the AugAssign visitor function to deal with list
        comprehensions and function calls.

        ListComp, GeneratorExp and Lambda values are compiled into separate
        statements first, and the augmented assignment uses the temporary
        holding the result (called, for a GeneratorExp).  Other values
        have their Call/ListComp nodes extracted as usual.

        """
        if isinstance(node.value,
                      (_ast.ListComp, _ast.GeneratorExp, _ast.Lambda)):
            node.value = self.extract(node.value, cthreshold=1)
            # dispatch on the actual node type; each visitor returns the
            # identifier holding its result
            iden = self.visit(node.value)
            if isinstance(node.value, _ast.GeneratorExp):
                val = _ast.Call(func=_ast.Name(id=iden, ctx=_ast.Load()),
                                args=[], keywords=[], starargs=None,
                                kwargs=None)
            else:
                val = _ast.Name(id=iden, ctx=_ast.Load())
            # an AugAssign has a single `target`, so it stays an AugAssign
            node = _ast.AugAssign(target=node.target, op=node.op, value=val)

        elif isinstance(node.value, (_ast.Dict, _ast.List, _ast.Tuple)):
            node.value = self.extract(node.value, threshold=0)

        else:
            node.value = self.extract(node.value, cthreshold=0)

        super(PyStochCompiler, self).visit_AugAssign(node)

    def visit_Print(self, node):
        """Rewrite the Print visitor function to deal with possible Call
        nodes.  Any children with Call nodes are stored in temporary
        variables, and then the temporary variable is used in the actual
        print statement.

        See Also
        --------
        codegen.SourceCompiler#visit_Print

        """
        # only extract when there are too many calls; otherwise the node
        # can be handed to the superclass as-is
        if self.should_rewrite(node):
            node = self.extract(node)
        super(PyStochCompiler, self).visit_Print(node)

    def visit_For(self, node):
        """Rewrite the For visitor function to first store the iterator of
        the for loop in a temporary variable, and then to loop over the
        contents of that variable.

        Additionally, push a new value onto the loop stack before entering
        the for loop, increment that value at the start of each pass of
        the loop, and pop the value after the loop has terminated.

        """
        node.iter = self.extract(node.iter)

        # push a new value onto the loop stack
        self.newline(node)
        super(PyStochCompiler, self).newline()
        self.insert(["PYSTOCHOBJ.loop_stack.push(0)"])
        super(PyStochCompiler, self).newline()

        # iterate over the stored value for the for loop iterator
        self.write('for ')
        self.visit(node.target)
        self.write(' in ')
        self.visit(node.iter)
        self.write(':')

        # increment the loop stack at the start of every pass (so a
        # `continue` cannot skip it)
        inloop = self.inloop
        self.inloop = inloop + 1
        self.body_or_else(node,
                          write_before="PYSTOCHOBJ.loop_stack.increment()")
        self.inloop = inloop

        # and finally, pop the loop stack after the for loop is over
        self.insert("PYSTOCHOBJ.loop_stack.pop()")

    def visit_While(self, node):
        """Rewrite the While visitor function to first store the test of
        the while loop in a temporary variable, and then to loop over the
        contents of that variable.

        Additionally, push a new value onto the loop stack before entering
        the while loop, increment that value at the start of each pass of
        the loop, and pop the value after the loop has terminated.

        NOTE(review): extraction happens once, before the loop, so calls
        moved into temporaries are not re-evaluated on later iterations.

        """
        node.test = self.extract(node.test)

        # push a new value onto the loop stack
        self.newline(node)
        super(PyStochCompiler, self).newline()
        self.insert(["PYSTOCHOBJ.loop_stack.push(0)"])
        super(PyStochCompiler, self).newline()

        self.write('while ')
        self.visit(node.test)
        self.write(':')

        # increment the loop stack at the start of every pass
        inloop = self.inloop
        self.inloop = inloop + 1
        self.body_or_else(node,
                          write_before="PYSTOCHOBJ.loop_stack.increment()")
        self.inloop = inloop

        # and finally, pop the loop stack after the loop is over
        self.insert("PYSTOCHOBJ.loop_stack.pop()")

    def visit_If(self, node):
        """Rewrite the If visitor function to assign the if and elif tests
        to temporary variables, and then check these variables in the
        actual if and elif statements.

        The `if` test is extracted before the statement.  An `elif` test
        must only be evaluated when the earlier tests fail, so an `elif`
        whose test needs extraction is written as `else:` plus a nested
        `if`, which extracts its test inside the else block.

        """
        node.test = self.extract(node.test)

        self.newline(node)
        self.write('if ')
        self.visit(node.test)
        self.write(':')
        self.body(node.body)

        while True:
            else_ = node.orelse
            if (len(else_) == 1 and isinstance(else_[0], _ast.If)
                    and not self.should_rewrite(else_[0].test)):
                # a simple elif whose test can be written inline
                node = else_[0]
                self.newline()
                self.write('elif ')
                self.visit(node.test)
                self.write(':')
                self.body(node.body)

            # handle the case where there is no else statement...
            elif not else_:
                break

            # ... or a plain else / an elif that needs extraction
            else:
                self.newline()
                self.write('else:')
                self.body(else_)
                break

    def visit_With(self, node):
        """With statements are not supported at this time.
        """
        raise NotImplementedError(
            "With statements are not supported at this time")

    def visit_Raise(self, node):
        """Rewrite the Raise visitor function to deal with potential Call
        nodes.  Any children with Call nodes are stored in temporary
        variables, and then the variable is used in the actual raise
        statement.  A bare `raise` (no type) is passed through unchanged.

        See Also
        --------
        codegen.SourceCompiler#visit_Raise

        """
        if node.type is not None:
            node.type = self.extract(node.type)
        if node.inst is not None:
            node.inst = self.extract(node.inst)
        if node.tback is not None:
            node.tback = self.extract(node.tback)
        super(PyStochCompiler, self).visit_Raise(node)

    def visit_TryExcept(self, node):
        """The superclass' visit_TryExcept method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_TryExcept

        """
        super(PyStochCompiler, self).visit_TryExcept(node)

    def visit_TryFinally(self, node):
        """The superclass' visit_TryFinally method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_TryFinally

        """
        super(PyStochCompiler, self).visit_TryFinally(node)

    def visit_Assert(self, node):
        """The Assert statement visitor function.

        This function is not implemented in codegen.  It prints the assert
        statement as normal, additionally rewriting the test case and/or
        the message if they contain Call nodes.

        Parameters
        ----------
        node : _ast.Assert
            The Assert node to transform into source code

        """
        node.test = self.extract(node.test)
        if node.msg is not None:
            node.msg = self.extract(node.msg)

        # write the test case of the assert statement
        self.newline(node)
        self.write('assert ')
        self.visit(node.test)

        # if the message exists, then write it too
        if node.msg is not None:
            self.write(', ')
            self.visit(node.msg)

    def visit_Import(self, node):
        """Write one `import` statement per imported name (the superclass
        implementation is not called because it is wrong).
        """
        for item in node.names:
            self.newline(node)
            self.write('import ')
            self.visit(item)

    def visit_ImportFrom(self, node):
        """The superclass' visit_ImportFrom method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_ImportFrom

        """
        super(PyStochCompiler, self).visit_ImportFrom(node)

    def visit_Exec(self, node):
        """Exec statements are not supported at this time.
        """
        raise NotImplementedError(
            "Exec statements are not supported at this time")

    def visit_Expr(self, node):
        """Rewrite the Expr visitor function to deal with Call nodes that
        are on a single line by themselves.  These must be handled
        specially, because we want to make sure that the first Call node
        remains on a line by itself, e.g.:

        PYSTOCH_AAAAAAAA = bar()
        foo(PYSTOCH_AAAAAAAA)

        And not:

        PYSTOCH_AAAAAAAA = bar()
        PYSTOCH_BBBBBBBB = foo(PYSTOCH_AAAAAAAA)
        PYSTOCH_BBBBBBBB

        If it is not a call function, then it extracts the node's value as
        per usual and then calls the super visit_Expr on the new node.

        """
        if isinstance(node.value, _ast.Call):
            self.visit_Call(node.value, True)
        else:
            node.value = self.extract(node.value)
            super(PyStochCompiler, self).visit_Expr(node)

    def visit_Pass(self, node):
        """Write a `pass` statement.

        Raises
        ------
        ValueError
            If `node` is not an _ast.AST instance.

        """
        if not isinstance(node, _ast.AST):
            raise ValueError("node must be an instance of _ast.AST")
        self.insert('pass')

    # 2) Expressions

    def visit_BoolOp(self, node):
        """Calls the superclass' visit_BoolOp method.

        See Also
        --------
        codegen.SourceCompiler#visit_BoolOp

        """
        super(PyStochCompiler, self).visit_BoolOp(node)

    def visit_BinOp(self, node):
        """Calls the superclass' visit_BinOp method, wrapping the result in
        parentheses.

        See Also
        --------
        codegen.SourceCompiler#visit_BinOp

        """
        self.write('(')
        super(PyStochCompiler, self).visit_BinOp(node)
        self.write(')')

    def visit_UnaryOp(self, node):
        """Calls the superclass' visit_UnaryOp method.

        See Also
        --------
        codegen.SourceCompiler#visit_UnaryOp

        """
        super(PyStochCompiler, self).visit_UnaryOp(node)

    def visit_Lambda(self, node):
        """Rewrite the Lambda visitor function to transform the lambda into
        a real named function (whose body returns the lambda's expression).
        The function name is returned.
        """
        iden = self._gen_iden(node)
        funcnode = _ast.FunctionDef(name=iden,
                                    args=node.args,
                                    body=[_ast.Return(value=node.body)],
                                    decorator_list=[])
        self.visit(funcnode)
        return iden

    def visit_IfExp(self, node):
        """IfExps are not supported at this time.
        """
        raise NotImplementedError(
            "IfExp nodes are not supported at this time.")

    def visit_Dict(self, node):
        """Rewrite the Dict visitor function to extract any list
        comprehensions or calls out of the keys or values (all keys first,
        then all values).
        """
        node.keys[:] = [self.extract(key, threshold=0) for key in node.keys]
        node.values[:] = [self.extract(value, threshold=0)
                          for value in node.values]
        super(PyStochCompiler, self).visit_Dict(node)

    def visit_Set(self, node):
        """Set nodes are not supported at this time.
        """
        raise NotImplementedError("Set nodes are not supported at this time")

    def visit_ListComp(self, node):
        """Rewrite the ListComp visitor function to turn the list
        comprehension into a real for loop.  This is necessary to be able
        to correctly label any random functions that get called from
        within the list comprehension.

        Basically, this function creates a temporary variable for the
        list, and transforms the comprehension into a for loop that
        appends values onto this list.  The list name is then returned, so
        that whatever element called the for loop can handle the
        assignment properly.

        """
        # make an identifier for the list
        self.newline(node)
        iden = self._gen_iden(node)
        self.write("%s = []" % iden)

        elt = node.elt

        def parse_generator(nodes):
            """Transform the generators into (nested) for loops."""
            gen = nodes[0]
            tempnode = ast.For()
            tempnode.target = gen.target
            tempnode.iter = gen.iter

            if len(nodes) == 1:
                append_node = ast.parse("%s.append(foo)" % iden).body[0]
                append_node.value.args[0] = elt
                body = [append_node]
            else:
                body = [parse_generator(nodes[1:])]

            tempnode.body = self._guarded_body(gen.ifs, body)
            tempnode.orelse = []
            return tempnode

        # visit the for loop
        self.visit(parse_generator(node.generators))
        return iden

    def visit_SetComp(self, node):
        """Set comprehensions are not supported at this time.
        """
        raise NotImplementedError(
            "Set comprehensions are not supported at this time")

    def visit_DictComp(self, node):
        """Dictionary comprehensions are not supported at this time.
        """
        raise NotImplementedError(
            "Dictionary comprehensions are not supported at this time")

    def visit_GeneratorExp(self, node):
        """Rewrite the GeneratorExp visitor function to turn the generator
        expression into a generator function.  This is necessary to be
        able to correctly label any random functions that get called from
        within the generator expression.

        Basically, this function creates a function and transforms the
        generator into for loops that yield values from the expression.
        As in Python, only the outermost iterable is evaluated immediately
        (stored in a temporary); inner iterables are evaluated inside the
        function.  The function name is then returned, so that the parent
        node can handle the assignment properly.

        """
        # make an identifier for the generator function
        self.newline(node)
        iden = self._gen_iden(node)

        # evaluate the outermost iterable now, at the point of definition
        outer = node.generators[0]
        outer_id = self._gen_iden(outer.iter)
        self.visit(
            _ast.Assign(targets=[_ast.Name(id=outer_id, ctx=_ast.Store())],
                        value=outer.iter))

        elt = node.elt

        def parse_generator(nodes, iterable):
            """Transform the generators into (nested) for loops."""
            gen = nodes[0]
            tempnode = _ast.For()
            tempnode.target = gen.target
            tempnode.iter = iterable

            if len(nodes) == 1:
                body = [_ast.Expr(value=_ast.Yield(value=elt))]
            else:
                # inner iterables stay inline so they are re-evaluated per
                # outer item and may refer to outer loop targets
                body = [parse_generator(nodes[1:], nodes[1].iter)]

            tempnode.body = self._guarded_body(gen.ifs, body)
            tempnode.orelse = []
            return tempnode

        funcnode = _ast.FunctionDef(
            name=iden,
            args=_ast.arguments(args=[], vararg=None, kwarg=None,
                                defaults=[]),
            body=[parse_generator(node.generators,
                                  _ast.Name(id=outer_id, ctx=_ast.Load()))],
            decorator_list=[])
        self.visit(funcnode)
        return iden

    def visit_Yield(self, node):
        """Rewrite the Yield visitor function to extract calls/list
        comprehensions, and additionally pop the line, loop (one entry per
        enclosing loop) and function stacks before yielding, and then push
        the same values back after returning from the yield.
        """
        # threshold 0: every call must run before the stacks are popped
        node = self.extract(node, threshold=0)

        lineiden = self._gen_iden(
            ast.parse("PYSTOCHOBJ.line_stack.pop()").body[0])
        funciden = self._gen_iden(
            ast.parse("PYSTOCHOBJ.func_stack.pop()").body[0])
        loopidens = [
            self._gen_iden(ast.parse("PYSTOCHOBJ.loop_stack.pop()").body[0])
            for _ in range(int(self.inloop))
        ]

        # pop the line, loop and function stacks, remembering the values
        pops = ["%s = PYSTOCHOBJ.line_stack.pop()" % lineiden]
        pops.extend("%s = PYSTOCHOBJ.loop_stack.pop()" % loopiden
                    for loopiden in loopidens)
        pops.append("%s = PYSTOCHOBJ.func_stack.pop()" % funciden)
        self.insert(pops)

        super(PyStochCompiler, self).newline()
        self.write("yield ")
        if node.value is not None:
            self.visit(node.value)

        # push the remembered values back in reverse order
        pushes = ["PYSTOCHOBJ.func_stack.push(%s)" % funciden]
        pushes.extend("PYSTOCHOBJ.loop_stack.push(%s)" % loopiden
                      for loopiden in reversed(loopidens))
        pushes.append("PYSTOCHOBJ.line_stack.push(%s)" % lineiden)
        self.insert(pushes)

    def visit_Compare(self, node):
        """Rewrite the Compare visitor function to deal with extraction.
        The comparison is wrapped in parentheses.
        """
        node.left = self.extract(node.left)
        node.comparators[:] = [self.extract(right)
                               for right in node.comparators]

        self.write('(')
        self.visit(node.left)
        for op, right in zip(node.ops, node.comparators):
            self.write(' %s ' % ast.CMPOP_SYMBOLS[type(op)])
            self.visit(right)
        self.write(')')

    def visit_Call(self, node, newline=False):
        """Rewrite the Call visitor function to extract any child Call
        nodes, and to route the call through `PYSTOCHOBJ.call` (the
        original callee becomes its first argument).

        Parameters
        ----------
        node : _ast.Call
            The Call node to rewrite.
        newline : boolean (default=False)
            Whether or not to insert a newline before transforming the
            Call node to source.  This should be set to True when, for
            example, the Call node is by itself on a line.

        """
        node.args.insert(0, node.func)
        node.func = ast.parse('PYSTOCHOBJ.call').body[0].value

        # extract each of the children of the Call node with threshold
        # zero, that is, we already know that we have one call node
        # (because we're visiting it), so we don't want any of its
        # children to also be call nodes
        threshold = 0
        node.args[:] = [self.extract(arg, threshold=threshold)
                        for arg in node.args]
        for keyword in node.keywords:
            keyword.value = self.extract(keyword.value, threshold=threshold)
        if node.starargs is not None:
            node.starargs = self.extract(node.starargs, threshold=threshold)
        if node.kwargs is not None:
            node.kwargs = self.extract(node.kwargs, threshold=threshold)

        if newline:
            self.newline(node)
        super(PyStochCompiler, self).visit_Call(node)

    def visit_Attribute(self, node):
        """Rewrite the Attribute visitor function to deal with extraction.
        """
        node.value = self.extract(node.value)
        super(PyStochCompiler, self).visit_Attribute(node)

    def visit_Subscript(self, node):
        """Rewrite the Subscript visitor function to deal with extraction.
        """
        node.value = self.extract(node.value)
        node.slice = self.extract(node.slice)
        super(PyStochCompiler, self).visit_Subscript(node)

    def visit_List(self, node):
        """Rewrite the List visitor function to extract any list
        comprehensions or calls out of the elements.
        """
        node.elts[:] = [self.extract(elt, threshold=0) for elt in node.elts]
        super(PyStochCompiler, self).visit_List(node)

    def visit_Tuple(self, node):
        """Rewrite the Tuple visitor function to extract any list
        comprehensions or calls out of the elements.
        """
        node.elts[:] = [self.extract(elt, threshold=0) for elt in node.elts]
        super(PyStochCompiler, self).visit_Tuple(node)

    # 3) Misc

    def visit_Slice(self, node):
        """Rewrite the Slice visitor function to deal with extraction.
        """
        if node.lower is not None:
            node.lower = self.extract(node.lower)
        if node.upper is not None:
            node.upper = self.extract(node.upper)
        if node.step is not None:
            node.step = self.extract(node.step)
        super(PyStochCompiler, self).visit_Slice(node)

    def visit_Index(self, node):
        """Rewrite the Index visitor function to deal with extraction.
        """
        node.value = self.extract(node.value)
        self.visit(node.value)

    def visit_ExceptHandler(self, node):
        """The superclass' visit_ExceptHandler method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_ExceptHandler

        """
        super(PyStochCompiler, self).visit_ExceptHandler(node)
class PyStochCompiler(codegen.SourceGenerator):
    """A visitor class to transform a python abstract syntax tree into
    pystoch.

    This class inherits from pystoch.codegen.SourceGenerator, which is
    a NodeVisitor that transforms a python abstract syntax tree (AST)
    into python code.  The PyStochCompiler takes it one step further,
    overriding the appropriate functions from SourceGenerator in order
    to insert PyStoch necessary code and perform PyStoch
    transformations.

    The generated code maintains runtime stacks on PYSTOCHOBJ
    (line_stack, func_stack and loop_stack) and routes every function
    call through PYSTOCHOBJ.call.  Calls, list comprehensions,
    generator expressions and lambdas nested inside larger expressions
    are hoisted into temporary PYSTOCH_XXXXXXXX variables (see
    `extract`).

    See Also
    --------
    pystoch.codegen
    pystoch.ast
    _ast
    ast

    """

    def __init__(self):
        """Initialize the PyStochCompiler.

        This initializes the SourceGenerator (indenting with four
        spaces), and creates a new list for pystoch identifiers (in
        order to avoid hash collisions).

        """
        super(PyStochCompiler, self).__init__(' ' * 4, False)
        # raw 8-character hex ids already handed out by _gen_iden
        self.idens = []
        # True while the body of a for/while loop is being visited
        self.inloop = False
        # True while the body of a class definition is being visited
        self.inclass = False
        # True while the body of a function definition is being visited
        self.infunc = False
        # compile-time counter mirroring the values written into the
        # generated PYSTOCHOBJ.line_stack calls
        self.line = IntegerStack()

    def _gen_iden(self, node):
        """Generate a random unique PyStoch identifier.

        All PyStoch identifiers are prefixed by 'PYSTOCH_', followed by
        an eight-character hexadecimal string.  The hexadecimal is the
        first eight characters of the md5 digest of the current date
        and time concatenated with the hash of `node`.  All generated
        ids are stored, and if a collision is detected a new digest is
        computed (with a later date and time) until an unused id is
        found.

        Parameters
        ----------
        node : ast.AST
            The node to generate an id for

        Returns
        -------
        out : string
            The identifier for `node`

        """
        nodeid = str(hash(node))
        while True:
            now = str(datetime.datetime.now())
            # md5 needs bytes on Python 3; on Python 2 encoding these
            # ASCII strings yields the same str
            digest = hashlib.md5((now + nodeid).encode('utf-8'))
            iden = digest.hexdigest()[:8]
            if iden not in self.idens:
                break
        # only the raw hex part is recorded; the prefix is added here,
        # exactly once
        self.idens.append(iden)
        return "PYSTOCH_%s" % iden

    @property
    def source(self):
        """The source generated by the PyStochCompiler after `compile`
        has been called.

        If any item of the result buffer is not a string, diagnostic
        information about it and its neighbours is printed and an empty
        string is returned instead.

        """
        result = self.result
        for i, item in enumerate(result):
            if not isinstance(item, str):
                print("Something went wrong! Expected a string, but got %s." % type(item))
                if i > 0:
                    print("Previous item: %s" % repr(result[i - 1]))
                print("This item: %s" % repr(item))
                if i < len(result) - 1:
                    print("Next item: %s" % repr(result[i + 1]))
                return ''
        return ''.join(result)

    def insert(self, statements):
        """Insert non-node statements into the source.

        This is used for inserting non-node statements into the source
        compilation.  These generally should only be PyStoch-specific
        statements; if they are a statement that needs to be evaluated
        by the PyStoch compiler then this function should NOT be used.

        You can pass in either a list/tuple of statements, or a single
        statement.

        Parameters
        ----------
        statements : string or list or tuple
            The statement or statements to be inserted

        Raises
        ------
        ValueError
            If any statement is not a string (nothing is written then)

        """
        # turn it into a list if it's not already
        if not isinstance(statements, (list, tuple)):
            statements = [statements]

        # validate everything first, so a bad entry never leaves a
        # partially written group of statements behind
        for statement in statements:
            if not isinstance(statement, str):
                raise ValueError("statement is not a string")

        # write each statement on its own line; the parent newline is
        # used so these bookkeeping lines do not bump the line stack
        for statement in statements:
            super(PyStochCompiler, self).newline()
            self.write(statement)

    def compile(self, src):
        """Compile python source to pystoch source.

        Parameters
        ----------
        src : string
            If `src` is a valid path, the source is loaded from that
            path and compiled.  Otherwise it is treated as the text of
            the source itself and compiled.

        Returns
        -------
        out : string
            The compiled source (the same value as `self.source`)

        Raises
        ------
        ValueError
            If `src` is not a string

        """
        if not isinstance(src, str):
            raise ValueError("src must be a string")

        # read in the source from a file ...
        if os.path.exists(src):
            with open(src, 'r') as fh:
                source = fh.read()
        # ... or just treat src as the actual source
        else:
            source = src

        # parse the source into an AST
        node = ast.parse(source)

        # push a 0 onto the line stack, both in the compiler and in the
        # generated code
        self.line.push(0)
        self.insert("PYSTOCHOBJ.line_stack.push(0)")

        # compile the rest of the module
        self.visit(node)

        # and finally, pop the line stack again
        self.write('\n')
        self.insert("PYSTOCHOBJ.line_stack.pop()")
        self.line.pop()

        return self.source

    def newline(self, node=None, extra=0):
        """Insert a newline.

        This inserts a newline in the same way as SourceGenerator, with
        the additional catch of incrementing the line stack if the node
        asking for the newline is non-null (if it's null, then
        incrementing the line stack is pointless because nothing will
        happen between now and the next time a newline occurs).

        Parameters
        ----------
        node : ast.AST (default=None)
            The ast node to insert a newline for
        extra : integer (default=0)
            The number of extra newlines to insert

        Notes
        -----
        This function doesn't actually immediately insert a newline, it
        increments the number of newlines to insert and then inserts
        them when write() is called.

        """
        # call the parent newline method
        super(PyStochCompiler, self).newline(node=node, extra=extra)

        # return if the node is null
        if node is None:
            return

        # otherwise, increment the line stack, record the new value in
        # the generated code, and start another line for the node
        self.line.increment()
        self.write("PYSTOCHOBJ.line_stack.set(%s)" % self.line.peek())
        super(PyStochCompiler, self).newline(node=node, extra=extra)

    def body(self, statements, write_before=None, write_after=None):
        """Write the body statements.

        This is the same as the SourceGenerator body function, with the
        additional parameters of write_before and write_after.  These
        parameters allow you to insert extra stuff before and after the
        rest of the statements.

        Parameters
        ----------
        statements : list of ast.AST nodes
            The statements to be written in the body
        write_before : list or string
            The statements to write before the body
        write_after : list or string
            The statements to write after the body

        """
        self.indentation += 1

        if write_before is not None:
            self.insert(write_before)

        for stmt in statements:
            self.visit(stmt)

        if write_after is not None:
            self.insert(write_after)

        self.indentation -= 1

    def body_or_else(self, node, write_before=None, write_after=None):
        """Write a body as well as an else statement, if it exists.

        Parameters
        ----------
        node : ast.AST
            node that has a body and optionally an orelse
        write_before : list or string
            The statements to write before the body
        write_after : list or string
            The statements to write after the body

        See Also
        --------
        pystoch.compile.PyStochCompiler.body

        """
        self.body(node.body, write_before=write_before, write_after=write_after)
        if node.orelse:
            self.newline()
            self.write('else:')
            self.body(node.orelse)

    def to_assign(self, value):
        """Takes a value, creates a random temporary identifier for it,
        and creates an Assign node, assigning the value to the
        identifier.

        Parameters
        ----------
        value : ast.AST
            node that is to be the value of the Assign node

        Returns
        -------
        out : tuple of string, _ast.Assign
            the string is the identifier of the node, and the
            _ast.Assign is the node that was created

        Raises
        ------
        ValueError
            If `value` is not an _ast.AST instance

        """
        if not isinstance(value, _ast.AST):
            raise ValueError("value is not an instance of _ast.AST")

        iden = self._gen_iden(value)
        node = _ast.Assign(
            targets=[ast.parse(iden).body[0].value],
            value=value)
        return iden, node

    def _count_calls(self, node):
        """Counts the number of Call nodes in a node.

        Note that this does NOT count Call nodes that are children of
        ListComp or GeneratorExp nodes.

        Parameters
        ----------
        node : _ast.AST
            The node to count Call nodes in

        Returns
        -------
        out : integer
            The number of Call nodes in `node`

        """
        class CountCalls(NodeCounter):

            def visit_Call(self, node):
                # this call, plus every call found beneath it
                total = 1
                total += self.visit(node.func)
                for arg in node.args:
                    total += self.visit(arg)
                for keyword in node.keywords:
                    total += self.visit(keyword)
                if node.starargs is not None:
                    total += self.visit(node.starargs)
                if node.kwargs is not None:
                    total += self.visit(node.kwargs)
                return total

            def visit_ListComp(self, node):
                return 0

            def visit_GeneratorExp(self, node):
                return 0

        return CountCalls().visit(node)

    def _count_listcomps(self, node):
        """Counts the ListComp-like nodes in a node.

        ListComp, GeneratorExp and Lambda nodes each count as 1 and
        their children are not inspected.  How counts from sibling
        subtrees are combined is up to NodeCounter (presumably summed
        -- TODO confirm).

        Parameters
        ----------
        node : _ast.AST
            The node to count ListComp nodes in

        Returns
        -------
        out : integer
            0 if there are no ListComp/GeneratorExp/Lambda nodes,
            positive otherwise

        """
        class CountListComps(NodeCounter):

            def visit_ListComp(self, node):
                return 1

            def visit_GeneratorExp(self, node):
                return 1

            def visit_Lambda(self, node):
                return 1

        return CountListComps().visit(node)

    def should_rewrite(self, node, threshold=1, cthreshold=0):
        """Checks whether or not a node (_ast.AST) contains more than
        the allowed number of Call nodes or ListComp-like nodes.

        If so, this indicates that the node should probably be
        rewritten.

        Parameters
        ----------
        node : _ast.AST
            The node to check for Call/ListComp nodes
        threshold : int (default=1)
            The maximum number of Call nodes allowed before the node
            should be rewritten
        cthreshold : int (default=0)
            The maximum number of ListComp/GeneratorExp/Lambda nodes
            allowed before the node should be rewritten

        Returns
        -------
        out : boolean
            Whether either count exceeds its threshold

        """
        ctotal = self._count_calls(node)
        lctotal = self._count_listcomps(node)
        return (ctotal > threshold) or (lctotal > cthreshold)

    def contains_return(self, node):
        """Checks whether or not a node (_ast.AST) contains any Return
        nodes.

        Parameters
        ----------
        node : _ast.AST
            The node to check for Return nodes

        Returns
        -------
        out : integer
            Truthy if any Return node is found, 0 otherwise

        """
        class ReturnChecker(NodeCounter):

            def visit_Return(self, node):
                return 1

        return ReturnChecker().visit(node)

    def _extract(self, node):
        """Hoist exactly one Call/ListComp-like node out of `node`.

        The tree is walked in source order.  The first ListComp,
        GeneratorExp or Lambda encountered is hoisted.  A Call node is
        hoisted only if nothing beneath it could be hoisted first, so
        the innermost call of the first call chain comes out first.
        The hoisted node is assigned to a temporary variable (the
        assignment is visited, i.e. written out, immediately) and
        replaced in `node` by a Name referring to that temporary.

        Parameters
        ----------
        node : _ast.AST
            The node to extract any Call/ListComp nodes from

        Returns
        -------
        out : _ast.AST
            The modified node, with a single Call/ListComp node removed

        """
        class Replace(ast.NodeTransformer):

            def __init__(self, callback, to_assign):
                # callback writes out the generated Assign node
                self.callback = callback
                self.to_assign = to_assign
                # set once a node has been hoisted; afterwards the tree
                # is left untouched
                self.done = False

            def hoist(self, node):
                """Assign `node` to a temporary, emit the assignment and
                return a Name node for the temporary."""
                iden, assignment = self.to_assign(node)
                self.done = True
                self.callback(assignment)
                return ast.parse(iden).body[0].value

            def visit_Call(self, node):
                if self.done:
                    return node

                # Look beneath this call first, in source order.  A
                # replacement may have happened in place deep inside a
                # child (generic_visit returns the same parent object),
                # so test self.done rather than comparing nodes.
                node.func = self.visit(node.func)
                if self.done:
                    return node
                for i, arg in enumerate(node.args):
                    node.args[i] = self.visit(arg)
                    if self.done:
                        return node
                for i, keyword in enumerate(node.keywords):
                    node.keywords[i] = self.visit(keyword)
                    if self.done:
                        return node
                if node.starargs is not None:
                    node.starargs = self.visit(node.starargs)
                    if self.done:
                        return node
                if node.kwargs is not None:
                    node.kwargs = self.visit(node.kwargs)
                    if self.done:
                        return node

                # nothing beneath needed hoisting: hoist this call
                return self.hoist(node)

            def visit_ListComp(self, node):
                if self.done:
                    return node
                return self.hoist(node)

            # generator expressions and lambdas are hoisted whole too
            visit_GeneratorExp = visit_ListComp
            visit_Lambda = visit_ListComp

        return Replace(self.visit, self.to_assign).visit(node)

    def extract(self, node, threshold=1, cthreshold=0):
        """Extract all Call/ListComp nodes above a certain threshold
        from a node.

        This repeatedly calls self._extract(node) until
        self.should_rewrite(node, threshold=threshold,
        cthreshold=cthreshold) is False.

        Parameters
        ----------
        node : _ast.AST
            The node to extract Call/ListComp nodes from.
        threshold : int (default=1)
            The maximum number of Call nodes that may remain in `node`
        cthreshold : int (default=0)
            The maximum number of ListComp/GeneratorExp/Lambda nodes
            that may remain in `node`

        Returns
        -------
        out : _ast.AST
            The modified `node`, minus Call/ListComp nodes.

        """
        while self.should_rewrite(node, threshold=threshold, cthreshold=cthreshold):
            node = self._extract(node)
        return node

    def _comprehension_body(self, ifs, body):
        """Wrap `body` in an If testing the comprehension conditions
        `ifs` (and-ed together when there are several); return `body`
        unchanged when there are no conditions."""
        if not ifs:
            return body
        if len(ifs) == 1:
            test = ifs[0]
        else:
            test = _ast.BoolOp(op=_ast.And(), values=ifs)
        return [_ast.If(test=test, body=body, orelse=[])]

    ########### Visitor Functions ###########

    # 1) Statements

    def visit_FunctionDef(self, node):
        """Rewrite the FunctionDef visitor to push a new value onto the
        function and line stacks at the beginning of the function, and
        then pop those values at the end of the function.

        Every compiled function also gets a trailing `PYSTOCHOBJ=None`
        parameter and is followed by a `<name>.random = True`
        statement.

        """
        self.newline(extra=1)
        self.decorators(node)
        super(PyStochCompiler, self).newline()
        self.write('def %s(' % node.name)
        node.args.args.append(ast.parse('PYSTOCHOBJ').body[0].value)
        node.args.defaults.append(ast.parse('None').body[0].value)
        self.signature(node.args)
        self.write('):')

        if len(node.body) == 1 and isinstance(node.body[0], _ast.Pass):
            # a bare `pass` body gets no stack bookkeeping
            self.body(node.body)
        else:
            self.line.push(0)
            write_before = [
                "PYSTOCHOBJ.func_stack.push('%s')" % self._gen_iden(node),
                "PYSTOCHOBJ.line_stack.push(0)",
            ]

            # only pop the line and function stacks if there is no
            # return statement (visit_Return emits its own pops)
            if self.contains_return(node):
                write_after = None
            else:
                write_after = [
                    "PYSTOCHOBJ.line_stack.pop()",
                    "PYSTOCHOBJ.func_stack.pop()",
                ]

            infunc = self.infunc
            self.infunc = True
            self.body(node.body, write_before=write_before, write_after=write_after)
            self.infunc = infunc
            self.line.pop()

        self.insert("%s.random = True" % node.name)

    def visit_ClassDef(self, node):
        """Rewrite the ClassDef visitor to emit a `random = True` class
        attribute as the first statement of the class body.

        """
        # a list is used as a mutable flag the closure can set
        # (Python 2 has no `nonlocal`)
        have_args = []

        def paren_or_comma():
            if have_args:
                self.write(', ')
            else:
                have_args.append(True)
                self.write('(')

        self.newline(extra=2)
        self.decorators(node)
        self.newline(node)
        self.write('class %s' % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        self.write('):' if have_args else ':')

        write_before = ["random = True"]
        inclass = self.inclass
        self.inclass = True
        self.body(node.body, write_before=write_before)
        self.write('\n')
        self.inclass = inclass

    def visit_Return(self, node):
        """Rewrite the Return visitor function to first hoist any calls
        out of the return value, then pop the line and function stacks
        (and the loop stack inside a loop), then return the value.

        """
        node = self.extract(node, threshold=0)

        # pop the line and function stacks
        if self.inloop:
            self.insert([
                "PYSTOCHOBJ.line_stack.pop()",
                "PYSTOCHOBJ.loop_stack.pop()",
                "PYSTOCHOBJ.func_stack.pop()",
            ])
        else:
            self.insert([
                "PYSTOCHOBJ.line_stack.pop()",
                "PYSTOCHOBJ.func_stack.pop()",
            ])

        super(PyStochCompiler, self).newline()
        self.write("return ")
        if node.value is not None:
            self.visit(node.value)

    def visit_Delete(self, node):
        """Calls the superclass' visit_Delete method, after checking
        the node with should_rewrite (default thresholds: more than one
        call, or any ListComp/GeneratorExp/Lambda, is rejected).

        Raises
        ------
        UnexpectedCallException
            If should_rewrite(node) is True

        See Also
        --------
        codegen.SourceCompiler#visit_Delete

        """
        if self.should_rewrite(node):
            raise UnexpectedCallException
        super(PyStochCompiler, self).visit_Delete(node)

    def visit_Assign(self, node):
        """Rewrite the Assign visitor function to deal with list
        comprehensions and function calls.

        """
        # if the value is a list comprehension (or generator / lambda),
        # it is compiled to a named temporary first
        if isinstance(node.value, (_ast.ListComp, _ast.GeneratorExp, _ast.Lambda)):
            node.value = self.extract(node.value, cthreshold=1)
            iden = self.visit(node.value)
            if isinstance(node.value, _ast.GeneratorExp):
                # the generator became a generator function: call it
                val = _ast.Call(
                    func=_ast.Name(id=iden, ctx=_ast.Load()),
                    args=[], keywords=[], starargs=None, kwargs=None)
            else:
                val = _ast.Name(id=iden, ctx=_ast.Load())
            node = ast.Assign(
                value=val,
                targets=node.targets)

        elif isinstance(node.value, (_ast.Dict, _ast.List, _ast.Tuple)):
            node.value = self.extract(node.value, threshold=0)

        else:
            # do Call/ListComp extraction on the node's value
            node.value = self.extract(node.value, cthreshold=0)

        self.newline(node)
        for idx, target in enumerate(node.targets):
            if idx:
                self.write(' = ')
            self.visit(target)
        self.write(' = ')
        self.visit(node.value)

    def visit_AugAssign(self, node):
        """Rewrite the AugAssign visitor function to deal with list
        comprehensions and function calls.

        List comprehensions, generator expressions and lambdas are
        compiled to a named temporary first (as in visit_Assign) and
        the augmented assignment then uses that temporary as its value.

        """
        if isinstance(node.value, (_ast.ListComp, _ast.GeneratorExp, _ast.Lambda)):
            node.value = self.extract(node.value, cthreshold=1)
            # dispatch to visit_ListComp / visit_GeneratorExp /
            # visit_Lambda, each of which returns the temporary's name
            iden = self.visit(node.value)
            if isinstance(node.value, _ast.GeneratorExp):
                val = _ast.Call(
                    func=_ast.Name(id=iden, ctx=_ast.Load()),
                    args=[], keywords=[], starargs=None, kwargs=None)
            else:
                val = _ast.Name(id=iden, ctx=_ast.Load())
            # stay an AugAssign (single `target`, no `targets`)
            node.value = val

        elif isinstance(node.value, (_ast.Dict, _ast.List, _ast.Tuple)):
            node.value = self.extract(node.value, threshold=0)

        else:
            node.value = self.extract(node.value, cthreshold=0)

        super(PyStochCompiler, self).visit_AugAssign(node)

    def visit_Print(self, node):
        """Rewrite the Print visitor function to deal with possible Call
        nodes.  Any children with Call nodes are stored in temporary
        variables, and then the temporary variable is used in the
        actual print statement.

        See Also
        --------
        codegen.SourceCompiler#visit_Print

        """
        # extract() is a no-op when nothing needs rewriting
        node = self.extract(node)
        super(PyStochCompiler, self).visit_Print(node)

    def visit_For(self, node):
        """Rewrite the For visitor function to first store the iterator
        of the for loop in a temporary variable, and then to loop over
        the contents of that variable.

        Additionally, push a new value onto the loop stack before
        entering the for loop, increment that value at the start of
        each pass of the loop, and pop the value after the loop has
        terminated.

        """
        node.iter = self.extract(node.iter)

        # push a new value onto the loop stack
        self.newline(node)
        super(PyStochCompiler, self).newline()
        self.insert(["PYSTOCHOBJ.loop_stack.push(0)"])
        super(PyStochCompiler, self).newline()

        # iterate over the stored value for the for loop iterator
        self.write('for ')
        self.visit(node.target)
        self.write(' in ')
        self.visit(node.iter)
        self.write(':')

        # increment the loop stack at the start of every pass
        inloop = self.inloop
        self.inloop = True
        self.body_or_else(node, write_before="PYSTOCHOBJ.loop_stack.increment()")
        self.inloop = inloop

        # and finally, pop the loop stack after the for loop is over
        self.insert("PYSTOCHOBJ.loop_stack.pop()")

    def visit_While(self, node):
        """Rewrite the While visitor function to first extract calls out
        of the test of the while loop.

        Additionally, push a new value onto the loop stack before
        entering the while loop, increment that value at the start of
        each pass of the loop, and pop the value after the loop has
        terminated.

        """
        node.test = self.extract(node.test)

        # push a new value onto the loop stack
        self.newline(node)
        super(PyStochCompiler, self).newline()
        self.insert(["PYSTOCHOBJ.loop_stack.push(0)"])
        super(PyStochCompiler, self).newline()

        self.write('while ')
        self.visit(node.test)
        self.write(':')

        # increment the loop stack at the start of every pass
        inloop = self.inloop
        self.inloop = True
        self.body_or_else(node, write_before="PYSTOCHOBJ.loop_stack.increment()")
        self.inloop = inloop

        # and finally, pop the loop stack after the loop is over
        self.insert("PYSTOCHOBJ.loop_stack.pop()")

    def visit_If(self, node):
        """Rewrite the If visitor function to extract calls out of the
        `if` test before writing the statement.

        """
        # NOTE: elif tests are not run through extract(): a hoisted
        # assignment would have to be written between the preceding
        # body and the `elif` keyword, which is not valid syntax.
        node.test = self.extract(node.test)

        self.newline(node)
        self.write('if ')
        self.visit(node.test)
        self.write(':')
        self.body(node.body)

        while True:
            else_ = node.orelse
            if len(else_) == 1 and isinstance(else_[0], _ast.If):
                node = else_[0]
                self.newline()
                self.write('elif ')
                self.visit(node.test)
                self.write(':')
                self.body(node.body)
            elif len(else_) == 0:
                # no else clause at all
                break
            else:
                self.newline()
                self.write('else:')
                self.body(else_)
                break

    def visit_With(self, node):
        """With statements are not supported at this time.

        """
        raise NotImplementedError("With statements are not supported at this time")

    def visit_Raise(self, node):
        """Rewrite the Raise visitor function to deal with potential
        Call nodes.  Any children with Call nodes are stored in
        temporary variables, and then the variable is used in the
        actual raise statement.

        See Also
        --------
        codegen.SourceCompiler#visit_Raise

        """
        # a bare `raise` has no type
        if node.type is not None:
            node.type = self.extract(node.type)
        if node.inst is not None:
            node.inst = self.extract(node.inst)
        if node.tback is not None:
            node.tback = self.extract(node.tback)
        super(PyStochCompiler, self).visit_Raise(node)

    def visit_TryExcept(self, node):
        """The superclass' visit_TryExcept method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_TryExcept

        """
        super(PyStochCompiler, self).visit_TryExcept(node)

    def visit_TryFinally(self, node):
        """The superclass' visit_TryFinally method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_TryFinally

        """
        super(PyStochCompiler, self).visit_TryFinally(node)

    def visit_Assert(self, node):
        """The Assert statement visitor function.

        This function is not implemented in codegen.  It prints the
        assert statement as normal, additionally rewriting the test
        case and/or the message if they contain Call nodes.

        Parameters
        ----------
        node : _ast.Assert
            The Assert node to transform into source code

        """
        node.test = self.extract(node.test)
        if node.msg is not None:
            node.msg = self.extract(node.msg)

        # write the test case of the assert statement
        self.newline(node)
        self.write('assert ')
        self.visit(node.test)

        # if the message exists, then write it too
        if node.msg is not None:
            self.write(', ')
            self.visit(node.msg)

    def visit_Import(self, node):
        """Visit an import node, writing one `import` statement per
        imported name (the superclass version is not used).

        """
        for item in node.names:
            self.newline(node)
            self.write('import ')
            self.visit(item)

    def visit_ImportFrom(self, node):
        """The superclass' visit_ImportFrom method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_ImportFrom

        """
        super(PyStochCompiler, self).visit_ImportFrom(node)

    def visit_Exec(self, node):
        """Exec statements are not supported at this time.

        """
        raise NotImplementedError("Exec statements are not supported at this time")

    def visit_Expr(self, node):
        """Rewrite the Expr visitor function to deal with Call nodes
        that are on a single line by themselves.

        These must be handled specially, because we want to make sure
        that the first Call node remains on a line by itself, e.g.:

            PYSTOCH_AAAAAAAA = bar()
            foo(PYSTOCH_AAAAAAAA)

        And not:

            PYSTOCH_AAAAAAAA = bar()
            PYSTOCH_BBBBBBBB = foo(PYSTOCH_AAAAAAAA)
            PYSTOCH_BBBBBBBB

        If it is not a call, then it extracts the node's value as per
        usual and then calls the super visit_Expr on the new node.

        """
        if isinstance(node.value, _ast.Call):
            self.visit_Call(node.value, True)
        else:
            node.value = self.extract(node.value)
            super(PyStochCompiler, self).visit_Expr(node)

    def visit_Pass(self, node):
        """Write a `pass` statement via insert() (so it does not bump
        the line stack).

        Raises
        ------
        ValueError
            If `node` is not an _ast.AST instance

        """
        if not isinstance(node, _ast.AST):
            raise ValueError("node must be an instance of _ast.AST")
        self.insert('pass')

    # 2) Expressions

    def visit_BoolOp(self, node):
        """Calls the superclass' visit_BoolOp method.

        See Also
        --------
        codegen.SourceCompiler#visit_BoolOp

        """
        super(PyStochCompiler, self).visit_BoolOp(node)

    def visit_BinOp(self, node):
        """Calls the superclass' visit_BinOp method, wrapping its output
        in parentheses.

        See Also
        --------
        codegen.SourceCompiler#visit_BinOp

        """
        self.write('(')
        super(PyStochCompiler, self).visit_BinOp(node)
        self.write(')')

    def visit_UnaryOp(self, node):
        """Calls the superclass' visit_UnaryOp method.

        See Also
        --------
        codegen.SourceCompiler#visit_UnaryOp

        """
        super(PyStochCompiler, self).visit_UnaryOp(node)

    def visit_Lambda(self, node):
        """Rewrite the Lambda visitor function to transform the lambda
        into a named function definition.

        Returns
        -------
        out : string
            The name of the generated function

        """
        iden = self._gen_iden(node)
        funcnode = _ast.FunctionDef(
            name=iden,
            args=node.args,
            body=[_ast.Return(value=node.body)],
            decorator_list=[])
        self.visit(funcnode)
        return iden

    def visit_IfExp(self, node):
        """IfExps are not supported at this time.

        """
        raise NotImplementedError("IfExp nodes are not supported at this time.")

    def visit_Dict(self, node):
        """Rewrite the Dict visitor function to extract any list
        comprehensions or calls out of the keys or values.

        """
        # all keys are extracted first, then all values
        node.keys[:] = [self.extract(key, threshold=0) for key in node.keys]
        node.values[:] = [self.extract(value, threshold=0) for value in node.values]
        super(PyStochCompiler, self).visit_Dict(node)

    def visit_Set(self, node):
        """Set nodes are not supported at this time.

        """
        raise NotImplementedError("Set nodes are not supported at this time")

    def visit_ListComp(self, node):
        """Rewrite the ListComp visitor function to turn the list
        comprehension into a real for loop.

        This is necessary to be able to correctly label any random
        functions that get called from within the list comprehension.
        Basically, this function creates a temporary variable for the
        list, and transforms the comprehension into a for loop that
        appends values onto this list.  The list name is then returned,
        so that whatever element called the for loop can handle the
        assignment properly.

        Returns
        -------
        out : string
            The name of the temporary list

        """
        # make an identifier for the list
        self.newline(node)
        iden = self._gen_iden(node)
        self.write("%s = []" % iden)

        elt = node.elt

        def parse_generator(nodes):
            """Transform the generators into nested for loops."""
            gen = nodes[0]
            loop = ast.For()
            loop.target = gen.target
            loop.iter = gen.iter

            if len(nodes) == 1:
                append_node = ast.parse("%s.append(foo)" % iden).body[0]
                append_node.value.args[0] = elt
                body = [append_node]
            else:
                body = [parse_generator(nodes[1:])]

            loop.body = self._comprehension_body(gen.ifs, body)
            loop.orelse = None
            return loop

        # visit the for loop
        self.visit(parse_generator(node.generators))
        return iden

    def visit_SetComp(self, node):
        """Set comprehensions are not supported at this time.

        """
        raise NotImplementedError("Set comprehensions are not supported at this time")

    def visit_DictComp(self, node):
        """Dictionary comprehensions are not supported at this time.

        """
        raise NotImplementedError("Dictionary comprehensions are not supported at this time")

    def visit_GeneratorExp(self, node):
        """Rewrite the GeneratorExp visitor function to turn the
        generator expression into a generator function.

        This is necessary to be able to correctly label any random
        functions that get called from within the generator expression.
        The outermost iterable is evaluated into a temporary variable
        before the function is defined (matching Python, which
        evaluates only that iterable eagerly); the generators become
        nested for loops inside a function that yields the values.  The
        function name is then returned, so that the parent node can
        handle the assignment properly.

        Returns
        -------
        out : string
            The name of the generated function

        """
        # make an identifier for the generator function
        self.newline(node)
        iden = self._gen_iden(node)

        # evaluate the outermost iterable eagerly; inner iterables may
        # refer to the targets of enclosing loops, so they stay inside
        first = node.generators[0]
        argid = self._gen_iden(first.iter)
        self.visit(_ast.Assign(
            targets=[_ast.Name(id=argid, ctx=_ast.Store())],
            value=first.iter))

        elt = node.elt

        def parse_generator(nodes, iterable):
            """Transform the generators into nested for loops."""
            gen = nodes[0]
            loop = _ast.For()
            loop.target = gen.target
            loop.iter = iterable

            if len(nodes) == 1:
                body = [_ast.Expr(value=_ast.Yield(value=elt))]
            else:
                body = [parse_generator(nodes[1:], nodes[1].iter)]

            loop.body = self._comprehension_body(gen.ifs, body)
            loop.orelse = None
            return loop

        funcnode = _ast.FunctionDef(
            name=iden,
            args=_ast.arguments(args=[], vararg=None, kwarg=None, defaults=[]),
            body=[parse_generator(node.generators,
                                  _ast.Name(id=argid, ctx=_ast.Load()))],
            decorator_list=[])
        self.visit(funcnode)
        return iden

    def visit_Yield(self, node):
        """Rewrite the Yield visitor function to extract calls/list
        comprehensions, and additionally pop the line and function
        stacks (and the loop stack inside a loop) into temporaries
        before yielding, then push the saved values back afterwards.

        """
        node = self.extract(node)

        lineiden = self._gen_iden(ast.parse("PYSTOCHOBJ.line_stack.pop()").body[0])
        funciden = self._gen_iden(ast.parse("PYSTOCHOBJ.func_stack.pop()").body[0])

        if self.inloop:
            loopiden = self._gen_iden(ast.parse("PYSTOCHOBJ.loop_stack.pop()").body[0])
            pops = [
                "%s = PYSTOCHOBJ.line_stack.pop()" % lineiden,
                "%s = PYSTOCHOBJ.loop_stack.pop()" % loopiden,
                "%s = PYSTOCHOBJ.func_stack.pop()" % funciden,
            ]
            pushes = [
                "PYSTOCHOBJ.func_stack.push(%s)" % funciden,
                "PYSTOCHOBJ.loop_stack.push(%s)" % loopiden,
                "PYSTOCHOBJ.line_stack.push(%s)" % lineiden,
            ]
        else:
            pops = [
                "%s = PYSTOCHOBJ.line_stack.pop()" % lineiden,
                "%s = PYSTOCHOBJ.func_stack.pop()" % funciden,
            ]
            pushes = [
                "PYSTOCHOBJ.func_stack.push(%s)" % funciden,
                "PYSTOCHOBJ.line_stack.push(%s)" % lineiden,
            ]

        # save and pop the stacks
        self.insert(pops)

        super(PyStochCompiler, self).newline()
        self.write("yield ")
        if node.value is not None:
            self.visit(node.value)

        # restore the saved values, in reverse order
        self.insert(pushes)

    def visit_Compare(self, node):
        """Rewrite the Compare visitor function to deal with extraction,
        wrapping the comparison in parentheses.

        """
        node.left = self.extract(node.left)
        node.comparators[:] = [self.extract(right) for right in node.comparators]

        self.write('(')
        self.visit(node.left)
        for op, right in zip(node.ops, node.comparators):
            self.write(' %s ' % ast.CMPOP_SYMBOLS[type(op)])
            self.visit(right)
        self.write(')')

    def visit_Call(self, node, newline=False):
        """Rewrite the Call visitor function to route the call through
        PYSTOCHOBJ.call and extract any child Call nodes.

        Parameters
        ----------
        node : _ast.Call
            The Call node to rewrite.
        newline : boolean (default=False)
            Whether or not to insert a newline before transforming the
            Call node to source.  This should be set to True when, for
            example, the Call node is by itself on a line.

        """
        # the original callee becomes the first positional argument of
        # PYSTOCHOBJ.call
        node.args.insert(0, node.func)
        node.func = ast.parse('PYSTOCHOBJ.call').body[0].value

        # extract each of the children of the Call node with threshold
        # zero, that is, we already know that we have one call node
        # (because we're visiting it), so we don't want any of its
        # children to also be call nodes
        threshold = 0
        node.args[:] = [self.extract(arg, threshold=threshold) for arg in node.args]
        for keyword in node.keywords:
            keyword.value = self.extract(keyword.value, threshold=threshold)
        if node.starargs is not None:
            node.starargs = self.extract(node.starargs, threshold=threshold)
        if node.kwargs is not None:
            node.kwargs = self.extract(node.kwargs, threshold=threshold)

        if newline:
            self.newline(node)
        super(PyStochCompiler, self).visit_Call(node)

    def visit_Attribute(self, node):
        """Rewrite the Attribute visitor function to deal with
        extraction.

        """
        node.value = self.extract(node.value)
        super(PyStochCompiler, self).visit_Attribute(node)

    def visit_Subscript(self, node):
        """Rewrite the Subscript visitor function to deal with
        extraction.

        """
        node.value = self.extract(node.value)
        node.slice = self.extract(node.slice)
        super(PyStochCompiler, self).visit_Subscript(node)

    def visit_List(self, node):
        """Rewrite the List visitor function to extract any list
        comprehensions or calls out of the elements.

        """
        node.elts[:] = [self.extract(elt, threshold=0) for elt in node.elts]
        super(PyStochCompiler, self).visit_List(node)

    def visit_Tuple(self, node):
        """Rewrite the Tuple visitor function to extract any list
        comprehensions or calls out of the elements.

        """
        node.elts[:] = [self.extract(elt, threshold=0) for elt in node.elts]
        super(PyStochCompiler, self).visit_Tuple(node)

    # 3) Misc

    def visit_Slice(self, node):
        """Rewrite the Slice visitor function to deal with extraction.

        """
        if node.lower is not None:
            node.lower = self.extract(node.lower)
        if node.upper is not None:
            node.upper = self.extract(node.upper)
        if node.step is not None:
            node.step = self.extract(node.step)
        super(PyStochCompiler, self).visit_Slice(node)

    def visit_Index(self, node):
        """Rewrite the Index visitor function to deal with extraction.

        """
        node.value = self.extract(node.value)
        self.visit(node.value)

    def visit_ExceptHandler(self, node):
        """The superclass' visit_ExceptHandler method is called.

        See Also
        --------
        codegen.SourceCompiler#visit_ExceptHandler

        """
        super(PyStochCompiler, self).visit_ExceptHandler(node)