def visit_Expr(self, node):
    """Build the adjoint of a bare expression statement.

    Only stack pushes (`utils.push(_stack, x, op_id)`) get an adjoint;
    any other expression statement is returned with an empty adjoint.

    Args:
      node: gast.Expr node of the primal.

    Returns:
      Tuple of (the unchanged primal node, list of adjoint statements).
    """
    # We need to special-case pushes, e.g. utils.push(_stack,x,op_id)
    adjoint = []
    if (isinstance(node.value, gast.Call) and
            anno.getanno(node.value, 'func') == utils.push):
        orig_src = quoting.unquote(node)
        stack, val, op_id = node.value.args
        push_template = grads.adjoints[utils.push]
        # Instantiate the pop/gradient template for this push.
        adjoint_rhs = template.replace(
            push_template,
            namer=self.namer,
            stack=stack,
            val=val,
            op_id=op_id)
        # Walk the RHS adjoint AST to find temporary adjoint variables to
        # sum. The first statement's target is assumed to hold the partial
        # gradient produced by the template — TODO confirm against the
        # push adjoint template's shape.
        accumulation = template.replace(
            'd[xi] = add_grad(d[xi], dxi_partial)',
            namer=self.namer,
            replace_grad=template.Replace.FULL,
            xi=val,
            dxi_partial=ast_.copy_node(adjoint_rhs[0].targets[0]),
            add_grad=utils.ADD_GRAD)
        adjoint = adjoint_rhs + [accumulation]
        # Tag each adjoint statement with the primal source it came from.
        for i, adj in enumerate(adjoint):
            adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)
    return node, adjoint
def primal_and_adjoint_for_tracing(self, node):
    """Build the primal and adjoint of a traceable function.

    Args:
      node: ast.Call node of a function we wish to trace, instead of transform

    Returns:
      primal: new ast.Assign node to replace the original primal call
      adjoint: new ast.Assign node using the VJP generated in primal to
        calculate the adjoint.
    """
    primal_template = grads.primals[tracing.Traceable]
    adjoint_template = grads.adjoints[tracing.Traceable]

    # Prep
    to_pack = node.args
    target = ast_.copy_node(self.orig_target)
    # A fresh unique name for the VJP function the traced primal returns.
    vjp = quoting.quote(self.namer.unique('%s_grad' % node.func.id))
    tmp = create.create_temp(quoting.quote('tmp'), self.namer)
    # NOTE(review): keyword arguments to traced calls are unsupported;
    # `assert` is stripped under -O, so a ValueError might be preferable.
    assert len(node.keywords) == 0

    # Full replacement of primal
    # TODO: do we need to set 'pri_call' on this?
    primal = template.replace(
        primal_template,
        namer=self.namer,
        result=target,
        fn=node.func,
        tmp=tmp,
        vjp=vjp,
        args=gast.Tuple(elts=to_pack, ctx=gast.Load()))

    # Building adjoint using the vjp generated with the primal
    dto_pack = gast.Tuple(
        elts=[create.create_temp_grad(arg, self.namer) for arg in to_pack],
        ctx=gast.Store())

    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        result=target,
        vjp=vjp,
        dargs=dto_pack)

    return primal, adjoint
def visit_For(self, node):
    """Build the primal and adjoint of a `for` loop.

    The primal counts iterations and pushes the loop target onto the stack
    each iteration; the adjoint pops the count and target to replay the
    loop in reverse.

    Args:
      node: gast.For node of the primal.

    Returns:
      Tuple of (primal statements, adjoint statements), both produced by
      instantiating the `gast.For` templates from `grads`.

    Raises:
      ValueError: If the loop has an `else` clause, which is unsupported.
    """
    if node.orelse:
        # Fix: previously a bare `raise ValueError` with no explanation.
        raise ValueError('for loops with an else clause are not supported')
    # Construct the primal and adjoint of the loop
    body, adjoint_body = self.visit_statements(node.body)
    # We create a loop counter which will be pushed on the stack
    push, pop, op_id = get_push_pop()
    counter = self.namer.counter()
    # In `for i in range ...` the variable `i` is the target, which we
    # temporarily set aside each iteration to push to the stack later
    push_target, pop_target, op_id_target = get_push_pop()
    tmp_target = create.create_temp(node.target, self.namer)

    primal_template = grads.primals[gast.For]
    primal = template.replace(
        primal_template,
        body=body,
        i=counter,
        push=push,
        target=node.target,
        iter_=node.iter,
        push_target=push_target,
        _target=tmp_target,
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    adjoint_template = grads.adjoints[gast.For]
    adjoint = template.replace(
        adjoint_template,
        adjoint_body=adjoint_body,
        i=counter,
        pop=pop,
        pop_target=pop_target,
        # Copy the target so primal and adjoint don't share the same node.
        target=ast_.copy_node(node.target),
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    return primal, adjoint
def visit_Return(self, node):
    """Rewrite a return statement to return gradients instead.

    The return value is replaced by a tuple of the tangents/gradients of
    the original return expression(s). If `self.preserve_result` is set,
    the original return value is appended as the last element.

    Args:
      node: gast.Return node whose value is a Name, Subscript, or Tuple.

    Returns:
      The mutated gast.Return node.

    Raises:
      ValueError: If the return expression is not a Name, Subscript, or
        Tuple of those.
    """
    orig_retval = ast_.copy_node(node.value)
    retval = node.value
    if isinstance(retval, (gast.Name, gast.Subscript)):
        # Single value: wrap its gradient in a one-element tuple.
        retval = gast.Tuple(
            elts=[create.create_grad(retval, self.namer, tangent=True)],
            ctx=gast.Load())
    elif isinstance(retval, gast.Tuple):
        # Multiple values: replace each element by its gradient.
        retval.elts = [
            create.create_grad(elt, self.namer, tangent=True)
            for elt in retval.elts
        ]
    else:
        # Fix: previously a bare `raise ValueError` with no explanation.
        raise ValueError(
            'Unsupported return expression of type %s; expected Name, '
            'Subscript or Tuple' % type(retval).__name__)
    # Gradient nodes are created in Store context by default; we read them.
    for n in retval.elts:
        n.ctx = gast.Load()
    if self.preserve_result:
        retval.elts.append(orig_retval)
    node.value = retval
    return node
def visit_Name(self, node):
    """Swap a Name node for its configured replacement, if any.

    The first time a name is replaced we hand out the stored nodes
    directly; subsequent uses get copies so the same AST node never
    appears twice in the tree. Each replacement is annotated with the
    node it replaced and inherits its expression context.
    """
    if node.id not in self.replacements:
        return node
    # NOTE In principle we don't want to copy, because it might break
    # references held in annotations, but we will copy if we have to to
    # avoid duplicate nodes
    if node.id in self.seen:
        replacement = ast_.copy_node(self.replacements[node.id])
    else:
        self.seen.add(node.id)
        replacement = self.replacements[node.id]
    # Normalize to a list so a single node and a node list are handled
    # uniformly below.
    nodes = [replacement] if isinstance(replacement, gast.AST) else replacement
    for repl in nodes:
        anno.setanno(repl, 'replacement', node, safe=False)
        if 'ctx' in repl._fields:
            repl.ctx = node.ctx
    # Unwrap a singleton list back to the bare node.
    return nodes[0] if len(nodes) == 1 else nodes
def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
        # Nothing to differentiate: keep the call, no adjoint statements.
        return node, []

    if func == tracing.Traceable:
        return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
        raise errors.ReverseNotImplementedError(func)

    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
        active_args = tuple(i for i, arg in enumerate(node.args)
                            if arg.id in self.active_variables)
        already_counted = False
        # Avoid registering the same (function, active-arg-set) twice.
        for f, a in self.required:
            if f.__name__ == func.__name__ and set(a) == set(active_args):
                already_counted = True
                break
        if not already_counted:
            self.required.append((func, active_args))

        # Call the generated primal, threading the substack through it.
        pri_name = naming.primal_name(func, active_args)
        pri_call = gast.Call(
            func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
            args=[self.substack] + node.args,
            keywords=node.keywords)
        anno.setanno(pri_call, 'pri_call', True)

        # dy: gradient of the call's result; dx built but only used via
        # the per-argument unpacking below — TODO confirm dx is needed.
        dy = create.create_grad(self.target, self.namer)
        dy.ctx = gast.Load()
        dx = create.create_grad(node.args[0], self.namer)
        dx.ctx = gast.Store()

        # The adjoint call returns gradients for all active arguments as a
        # tuple, which we unpack into the individual gradient slots.
        adj_name = naming.adjoint_name(func, active_args)
        adj_call = gast.Call(
            func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
            args=[self.substack, dy] + node.args,
            keywords=node.keywords)
        anno.setanno(adj_call, 'adj_call', True)
        adjoint = [template.replace('dxs = dfx', namer=self.namer,
                                    dfx=adj_call)]
        for j, i in enumerate(active_args):
            adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                            x=node.args[i].id,
                                            i=gast.Num(n=j)))
        return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    # Drop the first template parameter (the output) before binding.
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    # Defaults align with the tail of the arg list; zip reversed pairs them.
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
        if arg.id not in bound_args.arguments:
            bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
        to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
        vararg_name = six.get_function_code(template_).co_varnames[-1]
        target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
        value = gast.Tuple(elts=to_pack, ctx=gast.Load())
        packing = [gast.Assign(targets=[target], value=value)]

        # And we fill in the packed tuple into the template
        arg_replacements[six.get_function_code(
            template_).co_varnames[-1]] = target

    adjoint = template.replace(template_, namer=self.namer,
                               **arg_replacements)

    unpacking = []
    if flags & inspect.CO_VARARGS:
        # If the template packs arguments, then we have to unpack the
        # derivatives afterwards
        # We also have to update the replacements tuple then
        dto_pack = [create.create_temp_grad(arg, self.namer)
                    for arg in to_pack]
        value = create.create_grad(target, self.namer)
        target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
        unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking
def visit_Assign(self, node):
    """Visit assignment statement.

    Builds the primal (store state, run the assignment) and adjoint
    (restore state, run the RHS adjoint, reset and accumulate gradients)
    for a single assignment.

    Args:
      node: gast.Assign node with exactly one target.

    Returns:
      Tuple of (primal statements, adjoint statements).

    Raises:
      ValueError: On chained assignment, or if the RHS fails to process.
    """
    if len(node.targets) != 1:
        raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
        orig_src = anno.getanno(node, 'pre_anf')
    else:
        orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if (isinstance(node.value, gast.Call) and
            anno.hasanno(node.value, 'func') and
            anno.getanno(node.value, 'func') in (utils.Stack,
                                                 utils.pop_stack)):
        push, pop, op_id = get_push_pop_stack()
    else:
        push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in
    # the adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()',
        substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
        return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
        fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
        context = [t.id if hasattr(t, 'id') else t for t in node.targets]
        raise ValueError(
            'Failed to process assignment to: %s. Error: %s' % (context, e))

    if not isinstance(adjoint_rhs, list):
        adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
        for succ in gast.walk(n):
            if anno.hasanno(succ, 'temp_adjoint_var'):
                xi = anno.getanno(succ, 'temp_adjoint_var')
                dxi_partial = ast_.copy_node(succ)
                accumulations.append(template.replace(
                    'd[xi] = add_grad(d[xi], dxi_partial)',
                    namer=self.namer,
                    replace_grad=template.Replace.FULL,
                    xi=xi,
                    dxi_partial=dxi_partial,
                    add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to
    # change. Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
        assign = [fx]
    elif (isinstance(fx, list) and
          any(isinstance(ifx, gast.Assign) for ifx in fx)):
        assign = fx
    else:
        assign = template.replace(
            'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
        assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally
    # accumulating the partials
    adjoint = ([create_tmp, restore_substack, restore] + adjoint_rhs +
               [reset] + accumulations)

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
            isinstance(self.orig_target.slice.value, gast.Name)):
        push, pop, op_id = get_push_pop()
        i = self.orig_target.slice.value
        push_index = template.replace(
            'push(_stack, i, op_id)',
            push=push,
            i=i,
            _stack=self.stack,
            op_id=op_id)
        # Fix: the keyword was `_stack_=self.stack`, which doesn't match the
        # `_stack` placeholder in the template, so the stack variable was
        # never substituted into the generated pop statement.
        pop_index = template.replace(
            'i = pop(_stack, op_id)',
            pop=pop,
            i=i,
            _stack=self.stack,
            op_id=op_id)
        primal.append(push_index)
        adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for i, adj in enumerate(adjoint):
        adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)

    return primal, adjoint
def visit_FunctionDef(self, node):
    """Build the primal and adjoint function definitions.

    The primal is the original function renamed, taking the stack as an
    extra first argument; the adjoint is a new function built from the
    `gast.FunctionDef` adjoint template.

    Args:
      node: gast.FunctionDef of the function being differentiated; it is
        mutated in place to become the primal.

    Returns:
      Tuple of (primal FunctionDef, adjoint FunctionDef).

    Raises:
      ValueError: If the function does not have exactly one return
        statement, located at the end of its body.
    """
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node)
                    if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or
            not isinstance(node.body[-1], gast.Return)):
        raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
        body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
        adjoint_body[0] = comments.add_comment(
            adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt], ctx=gast.Load())

    if self.preserve_result:
        # Append an extra Assign operation to the primal body
        # that saves the original output value
        stored_result_node = quoting.quote(self.namer.unique('result'))
        assign_stored_result = template.replace(
            'result=orig_result',
            result=stored_result_node,
            orig_result=return_node.value)
        body.append(assign_stored_result)
        dx.elts.append(stored_result_node)

    # The gradients are read when returned from the adjoint.
    for _dx in dx.elts:
        _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return
    # values. The adjoint will receive a value for the initial gradient of
    # the cost
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
        y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body,
                                return_dx=return_dx)

    # The adjoint takes the stack, the seed gradient, and the primal's
    # original arguments (everything after the prepended stack).
    adjoint.args.args.extend([self.stack, dy])
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
def substack(self):
    """Return a fresh copy of the (lazily created) substack name node.

    The underlying Name node is created once per instance; every access
    hands out a copy so the same AST node is never shared between trees.
    """
    try:
        cached = self._substack
    except AttributeError:
        cached = self._substack = quoting.quote(
            self.namer.unique(naming.SUBSTACK_NAME))
    return ast_.copy_node(cached)