def visit_Subscript(self, node):
  if isinstance(node.value, gast.Name) and node.value.id == 'd':
    if (not isinstance(node.slice, gast.Index) or
        not isinstance(node.slice.value,
                       (gast.Subscript, gast.Name, gast.Str))):
      # This happens when the gradient of a constant is taken
      if self.replace_grad == Replace.TANGENT:
        new_node = gast.Num(0)
      else:
        new_node = gast.Name(id='_', ctx=None, annotation=None)
        self.remove(new_node)
    elif (self.replace_grad in (Replace.FULL, Replace.TANGENT) or
          isinstance(node.ctx, gast.Load)):
      new_node = create.create_grad(node.slice.value, self.namer, self.tangent)
    elif isinstance(node.ctx, gast.Store):
      new_node = create.create_temp_grad(node.slice.value, self.namer,
                                         self.tangent)
    else:
      raise ValueError
    new_node.ctx = node.ctx
    if isinstance(new_node, gast.Tuple):
      for elt in new_node.elts:
        elt.ctx = node.ctx
    node = new_node
  return node
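# Illustrative sketch (hedged; concrete names such as `bx` come from the
# Namer and are hypothetical here): this pass maps subscripts of the symbolic
# gradient dict `d` onto real variables, so a template statement such as
#
#   d[y] = d[x] * x + x * d[x]
#
# becomes
#
#   by = bx * x + x * bx
#
# while a gradient of a constant, e.g. `d[2.0]`, is replaced by the literal 0
# in tangent mode or removed entirely otherwise.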
def visit_FunctionDef(self, node):
  self.namer = naming.Namer.build(node)

  # Get the tangent of the body
  new_body = []
  for n in node.body:
    new = self.visit(n)
    if isinstance(new, (list, tuple)):
      new_body.extend(new)
    else:
      new_body.append(new)
  node.body = new_body

  # Add in the initial gradient argument
  grad_args = [
      create.create_grad(arg, self.namer, tangent=True)
      for i, arg in enumerate(node.args.args) if i in self.wrt
  ]

  if len(self.wrt) != len(grad_args):
    raise ValueError(
        'Mismatch between requested and retrieved derivative arguments. '
        'Requested %d, found %d' % (len(self.wrt), len(grad_args)))

  node.args.args += grad_args

  if self.check_dims:

    # Define the shape check code quote
    def shape_match_template(primal, tangent_):
      if not tangent.shapes_match(primal, tangent_):
        raise ValueError(
            'Shape mismatch between argument value (%s) and seed derivative '
            '(%s)' % (numpy.shape(primal), numpy.shape(tangent_)))

    # Add a shape check for each seed derivative & primal pair.
    shape_check_nodes = []
    for iwrt, tangent_var in zip(self.wrt, grad_args):
      primal = node.args.args[iwrt]
      shape_check = template.replace(shape_match_template, primal=primal,
                                     tangent_=tangent_var)[0]
      shape_check_nodes.append(shape_check)
    node.body = shape_check_nodes + node.body

  # Add in gradient initialization statements for everything else
  grad_init_nodes = [
      template.replace('d[z] = init_grad(z)',
                       replace_grad=template.Replace.TANGENT,
                       namer=self.namer,
                       z=arg,
                       init_grad=utils.INIT_GRAD)
      for i, arg in enumerate(node.args.args) if i not in self.wrt
  ]
  node.body = grad_init_nodes + node.body

  # Rename the function
  func = anno.getanno(node, 'func')
  node.name = naming.tangent_name(func, self.wrt)
  return node
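# Rough picture of the result (signatures are illustrative, not the exact
# Namer output): differentiating `def f(x, y): return x * y` w.r.t. argument
# 0 yields approximately
#
#   def df(x, y, bx):
#     # shape check inserted when check_dims is set
#     if not shapes_match(x, bx):
#       raise ValueError(...)
#     by = init_grad(y)  # zero tangent for the inactive argument
#     ...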
def visit_Return(self, node):
  orig_retval = ast_.copy_node(node.value)
  retval = node.value
  if isinstance(retval, (gast.Name, gast.Subscript)):
    retval = gast.Tuple(
        elts=[create.create_grad(retval, self.namer, tangent=True)],
        ctx=gast.Load())
  elif isinstance(retval, gast.Tuple):
    retval.elts = [
        create.create_grad(elt, self.namer, tangent=True)
        for elt in retval.elts
    ]
  else:
    raise ValueError
  for n in retval.elts:
    n.ctx = gast.Load()
  if self.preserve_result:
    retval.elts.append(orig_retval)
  node.value = retval
  return node
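# Example (hypothetical names): `return z` is rewritten to return the tangent
# tuple, i.e. `return (bz,)`, and with preserve_result set the primal value
# is appended as well: `return (bz, z)`.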
def create_grad_list(self, node):
  assert isinstance(node, (gast.List, gast.Tuple)), 'Must be list or tuple'
  list_of_nodes = node.elts
  elts = []
  for _node in list_of_nodes:
    if isinstance(_node, (gast.Name, gast.Subscript)):
      grad_node = create.create_grad(_node, self.namer, tangent=True)
      grad_node.ctx = node.ctx
      elts.append(grad_node)
    elif isinstance(_node, gast.Num):
      elts.append(gast.Num(0))
    elif isinstance(_node, (gast.List, gast.Tuple)):
      # Recurse on the container node itself; passing _node.elts would
      # fail the isinstance assertion above
      elts.append(self.create_grad_list(_node))
    else:
      raise ValueError('Cannot handle node type %s' % type(_node))
  return node.__class__(elts=elts, ctx=node.ctx)
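# Example (hypothetical names): the literal `[x, 2, (y,)]` maps to the
# tangent container `[bx, 0, (by,)]`: names become gradient variables,
# numeric constants become a zero gradient, and nested containers recurse.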
def visit_With(self, node):
  """Deal with the special `with insert_grad_of(x)` statement."""
  if ast_.is_insert_grad_of_statement(node):
    primal = []
    adjoint = node.body
    if isinstance(adjoint[0], gast.With):
      _, adjoint = self.visit(adjoint[0])
    node.body[0] = comments.add_comment(node.body[0], 'Inserted code')
    # Rename the gradients
    replacements = {}
    for item in node.items:
      if (not isinstance(item.context_expr.args[0], gast.Name) or
          not isinstance(item.optional_vars, gast.Name)):
        raise ValueError
      replacements[item.optional_vars.id] = create.create_grad(
          item.context_expr.args[0], self.namer)
    template.ReplaceTransformer(replacements).visit(node)
    return primal, adjoint
  else:
    return node, []
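# Usage sketch (the alias name `dx` is arbitrary): user code can splice
# statements into the generated derivative, e.g.
#
#   with insert_grad_of(x) as dx:
#     dx = dx / 2.0  # e.g. scale the incoming gradient
#
# The body is emitted as adjoint code (the primal part stays empty) and `dx`
# is renamed to the actual gradient variable of `x`.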
def visit_Call(self, node):
  if not self.target:
    return node
  func = anno.getanno(node, 'func')

  if func in tangents.UNIMPLEMENTED_TANGENTS:
    raise errors.ForwardNotImplementedError(func)

  if func == tracing.Traceable:
    raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                              quoting.unquote(node))

  if func not in tangents.tangents:
    try:
      quoting.parse_function(func)
    except:
      raise ValueError('No tangent found for %s, and could not get source.' %
                       func.__name__)

    # z = f(x, y) -> d[z], z = df(x, y, dx=dx, dy=dy)
    active_args = tuple(i for i, arg in enumerate(node.args)
                        if isinstance(arg, gast.Name))
    # TODO: Stack arguments are currently not considered
    # active, but for forward-mode applied to call trees,
    # they have to be. When we figure out how to update activity
    # analysis to do the right thing, we'll want to add the extra check:
    # `and arg.id in self.active_variables`

    # TODO: Duplicate of code in reverse_ad.
    already_counted = False
    for f, a in self.required:
      if f.__name__ == func.__name__ and set(a) == set(active_args):
        already_counted = True
        break
    if not already_counted:
      self.required.append((func, active_args))

    fn_name = naming.tangent_name(func, active_args)
    orig_args = quoting.parse_function(func).body[0].args
    tangent_keywords = []
    for i in active_args:
      grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
      arg_grad_node = create.create_grad(
          orig_args.args[i], self.namer, tangent=True)
      grad_node.ctx = gast.Load()
      tangent_keywords.append(
          gast.keyword(arg=arg_grad_node.id, value=grad_node))

    # Update the original call
    rhs = gast.Call(
        func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
        args=node.args,
        keywords=tangent_keywords + node.keywords)
    # Set self.value to False to trigger whole primal replacement
    self.value = False
    return [rhs]

  template_ = tangents.tangents[func]

  # Match the function call to the template
  sig = funcsigs.signature(template_)
  sig = sig.replace(parameters=list(sig.parameters.values())[1:])
  kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
  bound_args = sig.bind(*node.args, **kwargs)
  bound_args.apply_defaults()

  # If any keyword arguments weren't passed, we fill them using the
  # defaults of the original function
  if grads.DEFAULT in bound_args.arguments.values():
    # Build a mapping from names to defaults
    args = quoting.parse_function(func).body[0].args
    defaults = {}
    for arg, default in zip(*map(reversed, [args.args, args.defaults])):
      defaults[arg.id] = default
    for arg, default in zip(args.kwonlyargs, args.kw_defaults):
      if default is not None:
        defaults[arg.id] = default
    for name, value in bound_args.arguments.items():
      if value is grads.DEFAULT:
        bound_args.arguments[name] = defaults[name]

  # Let's fill in the template. The first argument is the output, which
  # was stored in a temporary variable
  output_name = six.get_function_code(template_).co_varnames[0]
  arg_replacements = {output_name: self.tmp_node}
  arg_replacements.update(bound_args.arguments)

  # If the template uses *args, then we pack the corresponding inputs
  flags = six.get_function_code(template_).co_flags
  if flags & inspect.CO_VARARGS:
    to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
    vararg_name = six.get_function_code(template_).co_varnames[-1]
    target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
    value = gast.Tuple(elts=to_pack, ctx=gast.Load())
    # And we fill in the packed tuple into the template
    arg_replacements[vararg_name] = target

  tangent_node = template.replace(
      template_,
      replace_grad=template.Replace.TANGENT,
      namer=self.namer,
      **arg_replacements)

  # If the template uses the answer in the RHS of the tangent,
  # we need to make sure that the regular answer is replaced
  # with self.tmp_node, but that the gradient is not. We have
  # to be extra careful for statements like a = exp(a), because
  # both the target and RHS variables have the same name.
  tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
  tmp_grad_name = tmp_grad_node.id
  ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
  for _node in tangent_node:
    for succ in gast.walk(_node):
      if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
        succ.id = ans_grad_node.id

  if flags & inspect.CO_VARARGS:
    # If the template packs arguments, then we have to unpack the
    # derivatives afterwards
    # We also have to update the replacements tuple then
    dto_pack = [
        create.create_temp_grad(arg, self.namer, True) for arg in to_pack
    ]
    value = create.create_grad(target, self.namer, tangent=True)
    target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

  # Stack pops have to be special-cased, we have
  # to set the 'push' attribute, so we know that if we
  # remove this pop, we have to remove the equivalent push.
  # NOTE: this only works if we're doing forward-over-reverse,
  # where reverse is applied in joint mode, with no call tree.
  # Otherwise, the pushes and pops won't be matched within a single
  # function call.
  if func == tangent.pop:
    if len(self.metastack):
      anno.setanno(tangent_node[0], 'push', self.metastack.pop())
    else:
      anno.setanno(tangent_node[0], 'push', None)

  return tangent_node
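# End-to-end sketch of the two branches above (names hypothetical). With a
# registered tangent template for `mul`, `z = mul(x, y)` expands in place,
# e.g. to `bz = bx * y + x * by`. For a user function `g` without a
# template, the primal call is replaced wholesale, mirroring the comment
# `z = f(x, y) -> d[z], z = df(x, y, dx=dx, dy=dy)`:
#
#   bz, z = dg(x, y, bx=bx, by=by)
#
# and (g, active_args) is queued in self.required so that `dg` itself gets
# generated afterwards.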
def visit_Call(self, node):
  """Create adjoint for call.

  We don't allow unpacking of parameters, so we know that each argument
  gets passed in explicitly, allowing us to create partials for each.
  However, templates might perform parameter unpacking (for cases where
  the number of arguments is variable) and express their gradient as a
  tuple. In this case, we have to unpack this tuple of partials.
  """
  # Find the function we are differentiating
  func = anno.getanno(node, 'func')

  if func in non_differentiable.NON_DIFFERENTIABLE:
    return node, []

  if func == tracing.Traceable:
    return self.primal_and_adjoint_for_tracing(node)

  if func in grads.UNIMPLEMENTED_ADJOINTS:
    raise errors.ReverseNotImplementedError(func)

  # If we don't have an adjoint, we will have to step into the called
  # function and differentiate it
  if func not in grads.adjoints:
    active_args = tuple(i for i, arg in enumerate(node.args)
                        if arg.id in self.active_variables)

    already_counted = False
    for f, a in self.required:
      if f.__name__ == func.__name__ and set(a) == set(active_args):
        already_counted = True
        break
    if not already_counted:
      self.required.append((func, active_args))

    pri_name = naming.primal_name(func, active_args)
    pri_call = gast.Call(
        func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
        args=[self.substack] + node.args,
        keywords=node.keywords)
    anno.setanno(pri_call, 'pri_call', True)

    dy = create.create_grad(self.target, self.namer)
    dy.ctx = gast.Load()
    dx = create.create_grad(node.args[0], self.namer)
    dx.ctx = gast.Store()

    adj_name = naming.adjoint_name(func, active_args)
    adj_call = gast.Call(
        func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
        args=[self.substack, dy] + node.args,
        keywords=node.keywords)
    anno.setanno(adj_call, 'adj_call', True)

    adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
    for j, i in enumerate(active_args):
      adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                      x=node.args[i].id, i=gast.Num(n=j)))

    return pri_call, adjoint

  # We have a template for the gradient that we need to fill in
  template_ = grads.adjoints[func]

  # Match the function call to the template
  sig = funcsigs.signature(template_)
  sig = sig.replace(parameters=list(sig.parameters.values())[1:])
  kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
  bound_args = sig.bind(*node.args, **kwargs)

  # Fill in any missing kwargs with the defaults from the template
  args = quoting.parse_function(template_).body[0].args
  kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
  kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
  for arg, val in kwargs.items():
    if arg.id not in bound_args.arguments:
      bound_args.arguments[arg.id] = val

  # Let's fill in the template. The first argument is the output, which
  # was stored in a temporary variable
  output_name = six.get_function_code(template_).co_varnames[0]
  arg_replacements = {output_name: ast_.copy_node(self.target)}
  arg_replacements.update(bound_args.arguments)

  # If the template uses *args, then we pack the corresponding inputs
  packing = []
  flags = six.get_function_code(template_).co_flags
  if flags & inspect.CO_VARARGS:
    to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
    vararg_name = six.get_function_code(template_).co_varnames[-1]
    target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
    value = gast.Tuple(elts=to_pack, ctx=gast.Load())
    packing = [gast.Assign(targets=[target], value=value)]
    # And we fill in the packed tuple into the template
    arg_replacements[vararg_name] = target

  adjoint = template.replace(template_, namer=self.namer, **arg_replacements)

  unpacking = []
  if flags & inspect.CO_VARARGS:
    # If the template packs arguments, then we have to unpack the
    # derivatives afterwards
    # We also have to update the replacements tuple then
    dto_pack = [create.create_temp_grad(arg, self.namer)
                for arg in to_pack]
    value = create.create_grad(target, self.namer)
    target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
    unpacking = [gast.Assign(targets=[target], value=value)]

  return node, packing + adjoint + unpacking
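# Sketch of the two paths above (names hypothetical). With a registered
# adjoint template, `z = mul(x, y)` yields adjoint statements roughly like
#
#   bx = bz * y
#   by = x * bz
#
# Without a template, the call is split into a primal that threads a tape
# substack and an adjoint that consumes it:
#
#   z = pri_g(_substack, x, y)          # primal, records intermediates
#   dxs = adj_g(_substack, bz, x, y)    # adjoint, one partial per arg
#   bx = dxs[0]
#   by = dxs[1]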
def visit_FunctionDef(self, node):
  # Construct a namer to guarantee we create unique names that don't
  # override existing names
  self.namer = naming.Namer.build(node)

  # Check that this function has exactly one return statement at the end
  return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
  if len(return_nodes) > 1 or not isinstance(node.body[-1], gast.Return):
    raise ValueError('function must have exactly one return statement')
  return_node = ast_.copy_node(return_nodes[0])

  # Perform AD on the function body
  body, adjoint_body = self.visit_statements(node.body[:-1])

  # Annotate the first statement of the primal and adjoint as such
  if body:
    body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
  if adjoint_body:
    adjoint_body[0] = comments.add_comment(adjoint_body[0],
                                           'Beginning of backward pass')

  # Before updating the primal arguments, extract the arguments we want
  # to differentiate with respect to
  dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                   for i in self.wrt], ctx=gast.Load())

  if self.preserve_result:
    # Append an extra Assign operation to the primal body
    # that saves the original output value
    stored_result_node = quoting.quote(self.namer.unique('result'))
    assign_stored_result = template.replace(
        'result=orig_result',
        result=stored_result_node,
        orig_result=return_node.value)
    body.append(assign_stored_result)
    dx.elts.append(stored_result_node)

  for _dx in dx.elts:
    _dx.ctx = gast.Load()
  return_dx = gast.Return(value=dx)

  # We add the stack as first argument of the primal
  node.args.args = [self.stack] + node.args.args

  # Rename the function to its primal name
  func = anno.getanno(node, 'func')
  node.name = naming.primal_name(func, self.wrt)

  # The new body is the primal body plus the return statement
  node.body = body + node.body[-1:]

  # Find the cost; the first variable of potentially multiple return values
  # The adjoint will receive a value for the initial gradient of the cost
  y = node.body[-1].value
  if isinstance(y, gast.Tuple):
    y = y.elts[0]
  dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                 annotation=None)

  # Construct the adjoint
  adjoint_template = grads.adjoints[gast.FunctionDef]
  adjoint, = template.replace(adjoint_template, namer=self.namer,
                              adjoint_body=adjoint_body, return_dx=return_dx)
  adjoint.args.args.extend([self.stack, dy])
  adjoint.args.args.extend(node.args.args[1:])
  adjoint.name = naming.adjoint_name(func, self.wrt)

  return node, adjoint
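# Overall shape of the output (names illustrative). For
#
#   def f(x):
#     y = x * x
#     return y
#
# this produces a primal that pushes intermediates onto a stack and an
# adjoint that pops them, seeded with the gradient of the cost:
#
#   def pri_f(_stack, x):      # forward pass
#     ...
#     return y
#
#   def adj_f(_stack, by, x):  # backward pass
#     ...
#     return (bx,)             # gradients w.r.t. self.wrt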