def visit(self, node):
    """Propagate dataflow annotations through ``node`` until they settle.

    For a node that carries a value, the incoming set is the reduction
    (using ``self.op``) of the ``out`` annotations of every predecessor
    that has already been annotated, or an empty ``frozenset`` when none
    has. The ``in``, ``gen``, ``kill`` and ``out`` annotations are then
    written onto ``node.value``; if ``out`` changed, all successors are
    visited recursively.

    For a node without a value, the ``out`` annotations of all its
    predecessors are reduced with ``self.op`` and stored in ``self.exit``.

    Args:
        node: A graph node exposing ``value``, ``prev`` and ``next``.
    """
    if not node.value:
        # NOTE(review): presumably this is the graph's exit sentinel, whose
        # predecessors have all been annotated by now -- confirm.
        preds = [anno.getanno(pred.value, self.out_label)
                 for pred in node.prev]
        self.exit = functools.reduce(self.op, preds[1:], preds[0])
        return

    had_out = anno.hasanno(node.value, self.out_label)
    before = anno.getanno(node.value, self.out_label) if had_out else None

    preds = [anno.getanno(pred.value, self.out_label)
             for pred in node.prev
             if anno.hasanno(pred.value, self.out_label)]
    if preds:
        incoming = functools.reduce(self.op, preds[1:], preds[0])
    else:
        incoming = frozenset()
    anno.setanno(node.value, self.in_label, incoming, safe=False)

    gen, kill = self.gen(node, incoming)
    anno.setanno(node.value, self.gen_label, gen, safe=False)
    anno.setanno(node.value, self.kill_label, kill, safe=False)

    outgoing = (incoming - kill) | gen
    anno.setanno(node.value, self.out_label, outgoing, safe=False)

    # Compare the sets themselves rather than their hashes: two different
    # sets may share a hash, which would end the propagation too early.
    if not had_out or outgoing != before:
        for succ in node.next:
            self.visit(succ)
def visit(self, node):
    """Collect generated definitions and mark the ones that are read.

    A node annotated with ``definitions_gen`` contributes those definitions
    to ``self.definitions`` and makes its ``definitions_in`` annotation the
    current set of reaching definitions while it and its children are
    visited. Every name that is loaded marks the reaching definitions with
    the same identifier as used.

    Args:
        node: The AST node to visit.
    """
    if anno.hasanno(node, 'definitions_gen'):
        generated = anno.getanno(node, 'definitions_gen')
        self.definitions.update(generated)
        self.reaching_definitions = anno.getanno(node, 'definitions_in')

    is_read = isinstance(node, gast.Name) and isinstance(node.ctx, gast.Load)
    if is_read:
        # The first element of each reaching definition is compared
        # against the identifier being read.
        matching = [definition
                    for definition in self.reaching_definitions
                    if definition[0] == node.id]
        self.used.update(matching)

    super(Unused, self).visit(node)

    # Leaving the annotated node: its reaching definitions no longer apply.
    if anno.hasanno(node, 'definitions_gen'):
        self.reaching_definitions = None
def visit(self, node):
    """Drop AD-generated pushes (and their pops) of undefined variables.

    A node annotated with ``push_var``, ``pop`` and ``gen_push`` is scheduled
    for removal, together with the node stored in its ``pop`` annotation,
    when the pushed variable's name is not among the names in the node's
    ``definitions_in`` annotation.

    Args:
        node: The AST node to visit.

    Returns:
        Whatever the parent class's ``visit`` returns for ``node``.
    """
    is_generated_push = (anno.hasanno(node, 'push_var')
                         and anno.hasanno(node, 'pop')
                         and anno.hasanno(node, 'gen_push'))
    if is_generated_push:
        reaching = anno.getanno(node, 'definitions_in')
        defined_names = frozenset(name for name, _ in reaching)
        pushed_name = ast_.get_name(anno.getanno(node, 'push_var'))
        if pushed_name not in defined_names:
            self.remove(node)
            self.remove(anno.getanno(node, 'pop'))
    return super(CleanStack, self).visit(node)
def prepend_uninitialized_grads(self, node):
    """Insert initializers for gradient names read before being defined.

    Only nodes annotated with ``defined_in`` are processed. Every loaded
    name beneath ``node`` that carries an ``adjoint_var`` or
    ``temp_adjoint_var`` annotation, is absent from ``defined_in`` and has
    not been handled before gets an initializer (built by ``self._init``)
    passed to ``self.insert_top``.

    Args:
        node: The AST node to inspect.

    Returns:
        The same ``node`` that was passed in.
    """
    if not anno.hasanno(node, 'defined_in'):
        return node

    for candidate in gast.walk(node):
        if not isinstance(candidate, gast.Name):
            continue
        if not isinstance(candidate.ctx, gast.Load):
            continue
        is_gradient = (anno.hasanno(candidate, 'adjoint_var')
                       or anno.hasanno(candidate, 'temp_adjoint_var'))
        if not is_gradient:
            continue
        if candidate.id in anno.getanno(node, 'defined_in'):
            continue
        if candidate.id in self.added:
            continue
        # Remember the name so that it is initialized only once.
        self.added.add(candidate.id)
        self.insert_top(self._init(candidate))
    return node
def visit(self, node, abort=astor.codegen.SourceGenerator.abort_visit):
    """Emit source for ``node`` together with its ``comment`` annotation.

    The annotation is a dict with ``text`` and ``location`` keys, where
    ``location`` is ``'above'``, ``'below'`` or ``'right'``. Comments placed
    above or below are truncated to 78 characters, and the truncated text
    is written back into the annotation.

    Args:
        node: The AST node to generate source for.
        abort: Accepted for signature compatibility; not used here.

    Raises:
        TangentParseError: If the comment location is not one of the three
            supported values.
    """
    if not anno.hasanno(node, 'comment'):
        self.new_indentation = False
        super(SourceWithCommentGenerator, self).visit(node)
        return

    comment = anno.getanno(node, 'comment')
    location = comment['location']

    # Keep full-line comments within a maximum line width of 80 characters.
    linewidth = 78
    if location in ('above', 'below'):
        comment['text'] = comment['text'][:linewidth]

    if self.new_indentation:
        n_newlines = 1
    else:
        n_newlines = 2

    if location == 'above':
        self.result.append('\n' * n_newlines)
        self.result.append(self.indent_with * self.indentation)
        self.result.append('# %s' % comment['text'])
        super(SourceWithCommentGenerator, self).visit(node)
    elif location == 'below':
        super(SourceWithCommentGenerator, self).visit(node)
        self.result.append('\n')
        self.result.append(self.indent_with * self.indentation)
        self.result.append('# %s' % comment['text'])
        self.result.append('\n' * (n_newlines - 1))
    elif location == 'right':
        super(SourceWithCommentGenerator, self).visit(node)
        self.result.append(' # %s' % comment['text'])
    else:
        raise TangentParseError('Only valid comment locations are '
                                'above, below, right')
def explicit_loop_indexes(node):
    """Run ``ExplicitLoopIndexes`` over ``node`` and drop activity annotations.

    After the transformation, the ``active_in``, ``active_out``,
    ``active_gen`` and ``active_kill`` annotations are deleted from every
    node of the resulting tree that has them.

    Args:
        node: The AST to transform.

    Returns:
        The transformed AST.
    """
    activity_keys = ('active_in', 'active_out', 'active_gen', 'active_kill')
    transformed = ExplicitLoopIndexes().visit(node)
    for child in gast.walk(transformed):
        present = [key for key in activity_keys if anno.hasanno(child, key)]
        for key in present:
            anno.delanno(child, key)
    return transformed
def visit_Assign(self, node):
    """Simplify an ``add_grad`` call whose target has no reaching definition.

    When the assigned value is a call whose function node carries the
    ``add_grad`` annotation, and the name of the first target is not among
    the names in the node's ``definitions_in`` annotation, the call is
    replaced by its second argument.

    Args:
        node: The assignment node.

    Returns:
        The same ``node``, possibly with ``node.value`` replaced.
    """
    call = node.value
    if not isinstance(call, gast.Call):
        return node
    if not anno.hasanno(call.func, 'add_grad'):
        return node

    reaching = anno.getanno(node, 'definitions_in')
    defined_names = frozenset(name for name, _ in reaching)
    if ast_.get_name(node.targets[0]) not in defined_names:
        node.value = call.args[1]
    return node
def visit(self, node):
    """Dispatch ``node`` to its ``visit_<ClassName>`` handler.

    Falls back to ``self.generic_visit`` when no dedicated handler exists.
    If the node has an ``active_in`` annotation, it is stored in
    ``self.active_variables`` before the handler runs.

    Args:
        node: The AST node to visit.

    Returns:
        Whatever the selected handler returns.
    """
    node_type = node.__class__.__name__
    handler = getattr(self, 'visit_' + node_type, self.generic_visit)
    # Set certain attributes for child nodes
    if anno.hasanno(node, 'active_in'):
        self.active_variables = anno.getanno(node, 'active_in')
    return handler(node)
def _init(self, node):
    """Build an assignment that initializes the gradient named by ``node``.

    The variable passed to the initializer is taken from the node's
    ``adjoint_var`` annotation when present, otherwise from its
    ``temp_adjoint_var`` annotation.

    Args:
        node: A node naming a gradient variable.

    Returns:
        A ``gast.Assign`` that stores a call to ``utils.INIT_GRAD`` on that
        variable into the gradient name.
    """
    gradname = ast_.get_name(node)
    if anno.hasanno(node, 'adjoint_var'):
        source_key = 'adjoint_var'
    else:
        source_key = 'temp_adjoint_var'
    var = anno.getanno(node, source_key)

    target = gast.Name(id=gradname, ctx=gast.Store(), annotation=None)
    init_call = gast.Call(func=utils.INIT_GRAD, args=[var], keywords=[])
    return gast.Assign(targets=[target], value=init_call)
def visit(self, node):
    """Visit ``node``, dropping statements that were marked for removal.

    ``self.remove`` and ``self.is_call`` are flags that may be raised while
    visiting any node (including nodes nested inside a statement). They are
    evaluated and reset once the visit reaches a node that is an instance of
    ``grammar.STATEMENTS``, after its children have been visited.

    Args:
        node: The AST node to visit.

    Returns:
        ``None`` when the statement is removed or when an ``If`` ends up
        with neither body nor ``orelse``; the ``orelse`` list when a
        ``While``/``For`` ends up with an empty body; otherwise the result
        of the parent class's ``visit`` (or ``None`` if the statement was
        flagged for removal).
    """
    if node in self.to_remove:
        self.remove = True
    if anno.hasanno(node, 'pri_call') or anno.hasanno(node, 'adj_call'):
        # We don't remove function calls for now; removing them also
        # removes the push statements inside of them, but not the
        # corresponding pop statements
        self.is_call = True
    # Children are visited first, so the flags checked below also reflect
    # anything raised by nested nodes.
    new_node = super(Remove, self).visit(node)
    if isinstance(node, grammar.STATEMENTS):
        if self.remove and not self.is_call:
            new_node = None
        # Reset both flags so they do not leak into the next statement.
        self.remove = self.is_call = False
    if isinstance(node, gast.If) and not node.body:
        # If we optimized away an entire if block, we need to handle that
        if not node.orelse:
            # Nothing is left in either branch: drop the whole `if`.
            return
        else:
            # Only the else branch survived: negate the test and make the
            # old else branch the body.
            node.test = gast.UnaryOp(op=gast.Not(), operand=node.test)
            node.body, node.orelse = node.orelse, node.body
    elif isinstance(node, (gast.While, gast.For)) and not node.body:
        # The loop body is empty; what is left is its `orelse` list.
        return node.orelse
    return new_node
def create_grad(node, namer, tangent=False):
    """Create the gradient counterpart of a variable node.

    Args:
        node: The node to create a gradient for: a plain variable (`x`), a
            subscript (`x[i]`) or a string holding a variable name.
        namer: The namer object whose `grad` method chooses the gradient's
            name.
        tangent: Whether a tangent (instead of an adjoint) is created.

    Returns:
        A node representing the gradient, e.g. the gradient of `x[i]` is
        `dx[i]`. For a name or subscript the outermost returned node has
        `ctx` set to `None`; it is assumed that this attribute is filled in
        later. A returned name carries an `adjoint_var` annotation referring
        to the node it is the gradient of. For a string, a new string node
        holding the gradient's name is returned.

    Raises:
        TypeError: If `node` is not a subscript, name or string node.
    """
    if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)):
        raise TypeError

    # A node standing in for a temporary defers to the annotated node.
    if anno.hasanno(node, 'temp_var'):
        return create_grad(anno.getanno(node, 'temp_var'), namer, tangent)

    if isinstance(node, gast.Subscript):
        base_grad = create_grad(node.value, namer, tangent=tangent)
        # The subscripted gradient is read in order to be indexed.
        base_grad.ctx = gast.Load()
        return gast.Subscript(value=base_grad, slice=node.slice, ctx=None)

    if isinstance(node, gast.Str):
        placeholder = gast.Name(id=node.s, ctx=None, annotation=None)
        named_grad = create_grad(placeholder, namer, tangent=tangent)
        return gast.Str(named_grad.id)

    # Only a plain name can reach this point.
    gradient = gast.Name(
        id=namer.grad(node.id, tangent), ctx=None, annotation=None)
    anno.setanno(gradient, 'adjoint_var', node)
    return gradient
def remove_repeated_comments(node):
    """Strip comments that merely repeat the previously seen comment.

    Several statements may carry the same comment so that it survives when
    one of them is deleted by an optimization pass. This pass walks the tree
    with `gast.walk` and deletes the `comment` annotation from every node
    whose comment text equals the text of the commented node seen just
    before it, so only the first of a run remains.

    Args:
        node: An AST.

    Returns:
        The same AST, with repeated comments removed.
    """
    previous_text = None
    for child in gast.walk(node):
        if not anno.hasanno(child, 'comment'):
            continue
        text = anno.getanno(child, 'comment')['text']
        if text == previous_text:
            anno.delanno(child, 'comment')
        previous_text = text
    return node
def visit(self, node):
    """Visit a node and cross-annotate its primal and adjoint.

    Dispatches to the `visit_<ClassName>` method, in the manner of
    `ast.NodeTransformer`, but refuses node types without a handler.

    Args:
        node: The node to visit.

    Returns:
        A tuple of the primal and adjoint, each of which is a node or a list
        of nodes.

    Raises:
        ValueError: If there is no visitor method for this node type.
    """
    node_type = node.__class__.__name__
    handler_name = 'visit_' + node_type
    if not hasattr(self, handler_name):
        raise ValueError('Unknown node type: %s' % node_type)
    handler = getattr(self, handler_name)

    # If this node is a statement, inform all child nodes what the active
    # variables in this statement are
    if anno.hasanno(node, 'active_in'):
        self.active_variables = anno.getanno(node, 'active_in')
    pri, adj = handler(node)

    # Link every primal statement to the adjoint and vice versa, treating a
    # single node as a one-element sequence.
    primal_nodes = [pri] if isinstance(pri, gast.AST) else pri
    for statement in primal_nodes:
        anno.setdefaultanno(statement, 'adj', adj)

    adjoint_nodes = [adj] if isinstance(adj, gast.AST) else adj
    for statement in adjoint_nodes:
        anno.setdefaultanno(statement, 'pri', pri)

    return pri, adj
def visit(self, node):
    """Define pushed variables that are not defined at the push.

    For a node annotated with ``push_var``, if the pushed variable's name is
    missing from the node's ``defined_in`` annotation, a ``<name> = None``
    statement is passed to ``self.insert_top``.

    Args:
        node: The AST node to visit.

    Returns:
        Whatever the parent class's ``visit`` returns for ``node``.
    """
    if anno.hasanno(node, 'push_var'):
        pushed_name = ast_.get_name(anno.getanno(node, 'push_var'))
        already_defined = pushed_name in anno.getanno(node, 'defined_in')
        if not already_defined:
            placeholder = quoting.quote('{} = None'.format(pushed_name))
            self.insert_top(placeholder)
    return super(FixStack, self).visit(node)
def visit_Assign(self, node):
    """Visit assignment statement.

    Builds the statements for the forward (primal) and backward (adjoint)
    pass of a single-target assignment. The target's previous value is
    pushed in the primal and popped in the adjoint, and the target's
    gradient is reset in the adjoint.

    Args:
        node: The assignment node; it must have exactly one target.

    Returns:
        A `(primal, adjoint)` tuple of statement lists. When the node is
        not active this is just `[store, node], [restore, reset]`.

    Raises:
        ValueError: If the assignment has more than one target, or if
            visiting the right-hand side raised a `ValueError` (re-raised
            with the assignment targets added to the message).
    """
    if len(node.targets) != 1:
        raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
        orig_src = anno.getanno(node, 'pre_anf')
    else:
        orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if isinstance(node.value, gast.Call) and \
            anno.hasanno(node.value, 'func') and \
            anno.getanno(node.value, 'func') in (utils.Stack,
                                                 utils.pop_stack):
        push, pop, op_id = get_push_pop_stack()
    else:
        push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in
    # the adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
        return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
        fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
        context = [t.id if hasattr(t, 'id') else t for t in node.targets]
        raise ValueError(
            'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
        adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
        for succ in gast.walk(n):
            if anno.hasanno(succ, 'temp_adjoint_var'):
                xi = anno.getanno(succ, 'temp_adjoint_var')
                dxi_partial = ast_.copy_node(succ)
                accumulations.append(template.replace(
                    'd[xi] = add_grad(d[xi], dxi_partial)',
                    namer=self.namer, replace_grad=template.Replace.FULL,
                    xi=xi, dxi_partial=dxi_partial,
                    add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to
    # change. Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
        assign = [fx]
    elif (isinstance(fx, list) and
          any(isinstance(ifx, gast.Assign) for ifx in fx)):
        assign = fx
    else:
        assign = template.replace(
            'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
        assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally
    # accumulating the partials
    adjoint = [create_tmp, restore_substack, restore
               ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
            isinstance(self.orig_target.slice.value, gast.Name)):
        push, pop, op_id = get_push_pop()
        i = self.orig_target.slice.value
        push_index = template.replace(
            'push(_stack, i, op_id)',
            push=push,
            i=i,
            _stack=self.stack,
            op_id=op_id)
        # The keyword must match the `_stack` placeholder in the template,
        # otherwise the stack name is not substituted in the index pop.
        pop_index = template.replace(
            'i = pop(_stack, op_id)',
            pop=pop,
            i=i,
            _stack=self.stack,
            op_id=op_id)
        # The index is pushed last in the primal and popped first in the
        # adjoint.
        primal.append(push_index)
        adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for idx, adj in enumerate(adjoint):
        adjoint[idx] = comments.add_comment(adj, 'Grad of: %s' % orig_src)

    return primal, adjoint
def mark(self, node):
    """Annotate ``node`` with the stored source text, if it has none yet.

    The ``pre_anf`` annotation is set to ``self.src`` only when the node
    lacks that annotation and ``self.src`` is truthy.

    Args:
        node: The AST node to annotate.
    """
    if anno.hasanno(node, 'pre_anf'):
        # Never overwrite an annotation that is already there.
        return
    if not self.src:
        return
    anno.setanno(node, 'pre_anf', self.src)
def visit(self, node):
    """Visit ``node`` with its reaching definitions exposed on ``self``.

    While a node annotated with ``definitions_in`` and its children are
    visited, ``self.reaching_definitions`` holds that annotation; afterwards
    it is set back to ``None``.

    Args:
        node: The AST node to visit.
    """
    key = 'definitions_in'
    if anno.hasanno(node, key):
        self.reaching_definitions = anno.getanno(node, key)

    super(ReadCounts, self).visit(node)

    # Leaving the annotated node: its reaching definitions no longer apply.
    if anno.hasanno(node, key):
        self.reaching_definitions = None