def visit_Expr(self, node):
    """Visit a bare expression statement and build its adjoint.

    Only `utils.push(_stack, x, op_id)` calls get a non-empty adjoint:
    the push is mirrored by the adjoint template registered for
    `utils.push` in `grads.adjoints`, followed by a gradient
    accumulation of the pushed value.

    Returns:
      A tuple `(node, adjoint)` where `node` is the unchanged primal
      statement and `adjoint` is a (possibly empty) list of adjoint
      statements, each annotated with a 'Grad of:' comment.
    """
    # We need to special-case pushes, e.g. utils.push(_stack,x,op_id)
    adjoint = []
    if (isinstance(node.value, gast.Call) and
            anno.getanno(node.value, 'func') == utils.push):
        # Capture source before any modification, for the comment below
        orig_src = quoting.unquote(node)
        stack, val, op_id = node.value.args
        push_template = grads.adjoints[utils.push]
        adjoint_rhs = template.replace(
            push_template,
            namer=self.namer,
            stack=stack,
            val=val,
            op_id=op_id)
        # Walk the RHS adjoint AST to find temporary adjoint variables to
        # sum
        accumulation = template.replace(
            'd[xi] = add_grad(d[xi], dxi_partial)',
            namer=self.namer,
            replace_grad=template.Replace.FULL,
            xi=val,
            # The partial is the target of the first adjoint statement
            dxi_partial=ast_.copy_node(adjoint_rhs[0].targets[0]),
            add_grad=utils.ADD_GRAD)
        adjoint = adjoint_rhs + [accumulation]
    # Annotate each adjoint statement with the primal source it came from
    for i, adj in enumerate(adjoint):
        adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)
    return node, adjoint
def visit_Assign(self, node): """Visit assignment statement. Notes ----- This method sets the `self.target` attribute to the first assignment target for callees to use. """ # Generate tangent name of the assignment self.target = node.targets[0] self.value = node.value # Get the tangent node tangent_node = self.visit(self.value) # If no forward-mode statement was created, then no extra work is needed. if tangent_node == self.value: self.target = None return node if self.value: new_node = template.replace( tangents.tangents[gast.Assign], replace_grad=template.Replace.TANGENT, namer=self.namer, temp=self.tmp_node, tangent=tangent_node, target=self.target, value=self.value) # Ensure that we use a unique tmp node per primal/tangent pair self.reset_tmp_node() else: # We're already in ANF form right now, # so we can assume LHS is a single Name "z" def template_(z, f): tmp = f d[z] = tmp[0] z = tmp[1] new_node = template.replace( template_, replace_grad=template.Replace.TANGENT, namer=self.namer, z=self.target, f=tangent_node[0]) # Add it after the original statement self.target = None # Add in some cool comments for i in range(len(new_node)): new_node[i] = comments.add_comment(new_node[i], 'Primal and tangent of: ' '%s' % quoting.unquote(node)) return new_node
def _create_joint(fwdbwd, func, wrt, input_derivative): """Create a user-friendly gradient function. By default, gradient functions expect the stack to be passed to them explicitly. This function modifies the function so that the stack doesn't need to be passed and gets initialized in the function body instead. For consistency, gradient functions always return a tuple, even if the gradient of only one input was required. We unpack the tuple if it is of length one. Args: fwdbwd: An AST. The function definition of the joint primal and adjoint. func: A function handle. The original function that was differentiated. wrt: A tuple of integers. The arguments with respect to which we differentiated. Returns: The function definition of the new function. """ # Correct return to be a non-tuple if there's only one element retval = fwdbwd.body[-1] if len(retval.value.elts) == 1: retval.value = retval.value.elts[0] # Make a stack init statement init_stack = quoting.quote('%s = tangent.Stack()' % fwdbwd.args.args[0].id) init_stack = comments.add_comment(init_stack, 'Initialize the tape') # Prepend the stack init to the top of the function fwdbwd.body = [init_stack] + fwdbwd.body # Replace the function arguments with the original ones grad_name = fwdbwd.args.args[1].id fwdbwd.args = quoting.parse_function(func).body[0].args # Give the function a nice name fwdbwd.name = naming.joint_name(func, wrt) # Allow the initial gradient to be passed as a keyword argument fwdbwd = ast_.append_args(fwdbwd, [grad_name]) if input_derivative == INPUT_DERIVATIVE.DefaultOne: fwdbwd.args.defaults.append(quoting.quote('1.0')) return fwdbwd
def visit_With(self, node):
    """Deal with the special with insert_grad_of(x) statement.

    The body of such a `with` block is user-written adjoint code that
    gets spliced into the backward pass verbatim, with each
    `as`-variable renamed to the gradient of the corresponding
    `insert_grad_of` argument.

    Returns:
      A `(primal, adjoint)` pair of statement lists. For an
      `insert_grad_of` block the primal is empty and the adjoint is the
      (renamed) block body; any other `with` statement is passed through
      unchanged with an empty adjoint.

    Raises:
      ValueError: If the `insert_grad_of` argument or its `as` target is
        not a plain variable name.
    """
    if ast_.is_insert_grad_of_statement(node):
        primal = []
        adjoint = node.body
        # Nested insert_grad_of blocks are processed recursively
        if isinstance(adjoint[0], gast.With):
            _, adjoint = self.visit(adjoint[0])
        node.body[0] = comments.add_comment(node.body[0], 'Inserted code')
        # Rename the gradients
        replacements = {}
        for item in node.items:
            if (not isinstance(item.context_expr.args[0], gast.Name) or
                    not isinstance(item.optional_vars, gast.Name)):
                # BUGFIX: previously raised a bare ValueError with no
                # diagnostic, leaving users with no hint about the cause
                raise ValueError(
                    'insert_grad_of requires a variable name as both its '
                    'argument and its "as" target, e.g. '
                    '"with insert_grad_of(x) as dx:"')
            replacements[item.optional_vars.id] = create.create_grad(
                item.context_expr.args[0], self.namer)
        template.ReplaceTransformer(replacements).visit(node)
        return primal, adjoint
    else:
        return node, []
def test_comment():
    """Comments attached above/right/below a statement survive rendering."""
    node = quoting.parse_function(f).body[0]
    # (location, rendered line index, expected stripped text) — applied in
    # order, since each add_comment call mutates the same statement node.
    expectations = (
        ('above', 1, '# foo'),
        ('right', 1, 'y = x # foo'),
        ('below', 2, '# foo'),
    )
    for location, line_no, expected in expectations:
        comments.add_comment(node.body[0], 'foo', location)
        rendered = quoting.to_source(node).split('\n')
        assert rendered[line_no].strip() == expected
def visit_Assign(self, node):
    """Visit assignment statement.

    Builds the primal (store state, perform assignment) and adjoint
    (restore state, run adjoint of the RHS, reset and accumulate
    gradients) statement lists for a single-target assignment.

    Returns:
      A `(primal, adjoint)` pair of statement lists.

    Raises:
      ValueError: If the assignment has multiple (chained) targets, or
        if visiting the RHS fails.
    """
    if len(node.targets) != 1:
        raise ValueError('no support for chained assignment')
    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
        orig_src = anno.getanno(node, 'pre_anf')
    else:
        orig_src = quoting.unquote(node)
    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if isinstance(node.value, gast.Call) and \
            anno.hasanno(node.value, 'func') and \
            anno.getanno(node.value, 'func') in (utils.Stack,
                                                 utils.pop_stack):
        push, pop, op_id = get_push_pop_stack()
    else:
        push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in
    # the adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()',
        substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
        return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
        fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
        context = [t.id if hasattr(t, 'id') else t for t in node.targets]
        raise ValueError(
            'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
        adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
        for succ in gast.walk(n):
            if anno.hasanno(succ, 'temp_adjoint_var'):
                xi = anno.getanno(succ, 'temp_adjoint_var')
                dxi_partial = ast_.copy_node(succ)
                accumulations.append(template.replace(
                    'd[xi] = add_grad(d[xi], dxi_partial)',
                    namer=self.namer,
                    replace_grad=template.Replace.FULL,
                    xi=xi,
                    dxi_partial=dxi_partial,
                    add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to
    # change. Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
        assign = [fx]
    elif (isinstance(fx, list) and
          any([isinstance(ifx, gast.Assign) for ifx in fx])):
        assign = fx
    else:
        assign = template.replace(
            'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
        assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally
    # accumulating the partials
    adjoint = [create_tmp, restore_substack, restore
               ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
            isinstance(self.orig_target.slice.value, gast.Name)):
        push, pop, op_id = get_push_pop()
        i = self.orig_target.slice.value
        push_index = template.replace(
            'push(_stack, i, op_id)',
            push=push,
            i=i,
            _stack=self.stack,
            op_id=op_id)
        pop_index = template.replace(
            'i = pop(_stack, op_id)',
            pop=pop,
            i=i,
            # BUGFIX: the keyword was `_stack_` (trailing underscore), which
            # does not match the `_stack` placeholder in the template, so the
            # stack name was never substituted into the restore statement.
            _stack=self.stack,
            op_id=op_id)
        primal.append(push_index)
        adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for i, adj in enumerate(adjoint):
        adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)
    return primal, adjoint
def visit_FunctionDef(self, node):
    """Differentiate a function definition into primal and adjoint functions.

    The primal gets the tape stack prepended to its arguments and is
    renamed via `naming.primal_name`; the adjoint is built from the
    `gast.FunctionDef` adjoint template and receives the stack, the
    initial output gradient, and the primal's original arguments.

    Returns:
      A `(node, adjoint)` pair: the rewritten primal FunctionDef and the
      newly constructed adjoint FunctionDef.

    Raises:
      ValueError: If the function does not end in its single return
        statement.
    """
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node)
                    if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or
            not isinstance(node.body[-1], gast.Return)):
        raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
        body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
        adjoint_body[0] = comments.add_comment(
            adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt], ctx=gast.Load())

    if self.preserve_result:
        # Append an extra Assign operation to the primal body
        # that saves the original output value
        stored_result_node = quoting.quote(self.namer.unique('result'))
        assign_stored_result = template.replace(
            'result=orig_result',
            result=stored_result_node,
            orig_result=return_node.value)
        body.append(assign_stored_result)
        dx.elts.append(stored_result_node)

    # All returned gradients are loaded, not stored
    for _dx in dx.elts:
        _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return
    # values. The adjoint will receive a value for the initial gradient of
    # the cost
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
        y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template,
                                namer=self.namer,
                                adjoint_body=adjoint_body,
                                return_dx=return_dx)
    # The adjoint takes the stack, the initial gradient, and then the
    # primal's original (non-stack) arguments
    adjoint.args.args.extend([self.stack, dy])
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)
    return node, adjoint