def primal_and_adjoint_for_tracing(self, node):
    """Build the primal and adjoint of a traceable function.

    Args:
        node: ast.Call node of a function we wish to trace, instead of
            transform. Only positional arguments are supported.

    Returns:
        primal: new ast.Assign node to replace the original primal call
        adjoint: new ast.Assign node using the VJP generated in primal to
            calculate the adjoint.

    Raises:
        ValueError: If the call passes any keyword arguments.
    """
    # Validate up front with a real exception rather than an `assert`:
    # asserts are stripped under `python -O`, and since only `node.args` is
    # forwarded to the templates below, keyword arguments would otherwise be
    # dropped without any error.
    if node.keywords:
        raise ValueError('no support for keyword arguments in traced calls')

    primal_template = grads.primals[tracing.Traceable]
    adjoint_template = grads.adjoints[tracing.Traceable]

    # Prep
    to_pack = node.args
    target = ast_.copy_node(self.orig_target)
    vjp = quoting.quote(self.namer.unique('%s_grad' % node.func.id))
    tmp = create.create_temp(quoting.quote('tmp'), self.namer)

    # Full replacement of primal
    # TODO: do we need to set 'pri_call' on this?
    primal = template.replace(
        primal_template,
        namer=self.namer,
        result=target,
        fn=node.func,
        tmp=tmp,
        vjp=vjp,
        args=gast.Tuple(elts=to_pack, ctx=gast.Load()))

    # Building adjoint using the vjp generated with the primal: one
    # temporary gradient name per positional argument, packed as a
    # store-context tuple.
    dto_pack = gast.Tuple(
        elts=[create.create_temp_grad(arg, self.namer) for arg in to_pack],
        ctx=gast.Store())
    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        result=target,
        vjp=vjp,
        dargs=dto_pack)

    return primal, adjoint
def visit_For(self, node):
    """Build the primal and adjoint of a `for` loop.

    Args:
        node: The `for` loop node. Its `body`, `target` and `iter` are used;
            an `else` clause (`node.orelse`) is not supported.

    Returns:
        A `(primal, adjoint)` pair: the results of filling in the
        `grads.primals[gast.For]` and `grads.adjoints[gast.For]` templates.

    Raises:
        ValueError: If the loop has an `else` clause.
    """
    if node.orelse:
        raise ValueError('no support for else clause on for loops')

    # Construct the primal and adjoint of the loop
    body, adjoint_body = self.visit_statements(node.body)

    # We create a loop counter which will be pushed on the stack
    push, pop, op_id = get_push_pop()
    counter = self.namer.counter()

    # In `for i in range ...` the variable `i` is the target, which we
    # temporarily set aside each iteration to push to the stack later
    push_target, pop_target, op_id_target = get_push_pop()
    tmp_target = create.create_temp(node.target, self.namer)

    primal_template = grads.primals[gast.For]
    primal = template.replace(
        primal_template,
        body=body,
        i=counter,
        push=push,
        target=node.target,
        iter_=node.iter,
        push_target=push_target,
        _target=tmp_target,
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    # The adjoint gets its own copy of the loop target node rather than
    # sharing the node already placed in the primal.
    adjoint_template = grads.adjoints[gast.For]
    adjoint = template.replace(
        adjoint_template,
        adjoint_body=adjoint_body,
        i=counter,
        pop=pop,
        pop_target=pop_target,
        target=ast_.copy_node(node.target),
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    return primal, adjoint
def visit_Assign(self, node):
    """Visit assignment statement.

    Builds the forward-pass (primal) and backward-pass (adjoint) statement
    lists for an assignment with a single target. As a side effect this sets
    `self.orig_target` (a copy of the target) and, for active nodes,
    `self.target` (a temporary for the target) so the RHS visitor can use
    them.

    Args:
        node: The assignment node; must have exactly one entry in
            `node.targets`.

    Returns:
        A `(primal, adjoint)` pair of statement lists.

        If `self.is_active(node)` is false, the pair is
        `[store, node], [restore, reset]`.

        Otherwise the primal is: store the target, create and store the
        substack, then the assignment statement(s). The adjoint is: create
        the temporary, restore the substack, restore the target, the RHS
        adjoint, the gradient reset, then the accumulations. When the target
        is a subscript indexed by a name, a push of the index is appended to
        the primal and the matching pop is prepended to the adjoint. Every
        adjoint statement is annotated with a 'Grad of: ...' comment.

    Raises:
        ValueError: If the assignment does not have exactly one target, or if
            visiting the right-hand side raises ValueError (re-raised with
            the assignment targets added to the message).
    """
    if len(node.targets) != 1:
        raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
        orig_src = anno.getanno(node, 'pre_anf')
    else:
        orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    rhs_is_stack_call = (
        isinstance(node.value, gast.Call)
        and anno.hasanno(node.value, 'func')
        and anno.getanno(node.value, 'func') in (utils.Stack, utils.pop_stack))
    if rhs_is_stack_call:
        push, pop, op_id = get_push_pop_stack()
    else:
        push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in the
    # adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push, y=self.orig_target, _stack=self.stack, op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack, stack=self.stack,
        substack=self.substack, op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack, pop=pop, y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack, stack=self.stack,
        substack=self.substack, op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target, init_grad=utils.INIT_GRAD,
        namer=self.namer, replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
        return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
        fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
        context = [t.id if hasattr(t, 'id') else t for t in node.targets]
        raise ValueError(
            'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
        adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for rhs_stmt in adjoint_rhs:
        for succ in gast.walk(rhs_stmt):
            if anno.hasanno(succ, 'temp_adjoint_var'):
                xi = anno.getanno(succ, 'temp_adjoint_var')
                dxi_partial = ast_.copy_node(succ)
                accumulations.append(template.replace(
                    'd[xi] = add_grad(d[xi], dxi_partial)',
                    namer=self.namer, replace_grad=template.Replace.FULL,
                    xi=xi, dxi_partial=dxi_partial, add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to change.
    # Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
        assign = [fx]
    elif (isinstance(fx, list)
          and any(isinstance(ifx, gast.Assign) for ifx in fx)):
        assign = fx
    else:
        assign = [template.replace(
            'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally accumulating
    # the partials
    adjoint = ([create_tmp, restore_substack, restore]
               + adjoint_rhs + [reset] + accumulations)

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript)
            and isinstance(self.orig_target.slice.value, gast.Name)):
        push, pop, op_id = get_push_pop()
        index = self.orig_target.slice.value
        push_index = template.replace(
            'push(_stack, i, op_id)',
            push=push, i=index, _stack=self.stack, op_id=op_id)
        # The keyword must be `_stack`, matching the placeholder in the
        # template string (as in the push above); a differently spelled
        # keyword would presumably leave `_stack` unsubstituted.
        pop_index = template.replace(
            'i = pop(_stack, op_id)',
            pop=pop, i=index, _stack=self.stack, op_id=op_id)
        primal.append(push_index)
        adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    grad_comment = 'Grad of: %s' % orig_src
    adjoint = [comments.add_comment(adj, grad_comment) for adj in adjoint]

    return primal, adjoint