def visit_Subscript(self, node):
    """Replace a ``d[...]`` gradient placeholder with a concrete node.

    Any subscript whose value is not the name ``d`` is returned untouched.

    Args:
      node: The subscript node being visited.

    Returns:
      The node that replaces the ``d[...]`` placeholder, or ``node`` itself
      when it is not such a placeholder.

    Raises:
      ValueError: If the placeholder must be resolved by context but its
        context is neither ``gast.Load`` nor ``gast.Store``.
    """
    # Only a Name carries an ``id``. Checking the type first keeps a numeric
    # subscript value (e.g. ``1[x]``) from raising AttributeError.
    if not (isinstance(node.value, gast.Name) and node.value.id == 'd'):
        return node

    index = node.slice
    if (not isinstance(index, gast.Index) or
            not isinstance(index.value,
                           (gast.Subscript, gast.Name, gast.Str))):
        # This happens when the gradient of a constant is taken
        if self.replace_grad == Replace.TANGENT:
            new_node = gast.Num(0)
        else:
            new_node = gast.Name(id='_', ctx=None, annotation=None)
            # NOTE(review): presumably marks the statement holding the
            # placeholder for removal -- confirm against self.remove.
            self.remove(new_node)
    elif (self.replace_grad in (Replace.FULL, Replace.TANGENT) or
          isinstance(node.ctx, gast.Load)):
        new_node = create.create_grad(index.value, self.namer, self.tangent)
    elif isinstance(node.ctx, gast.Store):
        new_node = create.create_temp_grad(index.value, self.namer,
                                           self.tangent)
    else:
        raise ValueError(
            'Cannot replace gradient placeholder with context %r' %
            (node.ctx,))

    # The replacement takes over the context of the placeholder; for a
    # tuple replacement every element gets the same context.
    new_node.ctx = node.ctx
    if isinstance(new_node, gast.Tuple):
        for elt in new_node.elts:
            elt.ctx = node.ctx
    return new_node
def primal_and_adjoint_for_tracing(self, node):
    """Construct the primal and adjoint statements for a traced call.

    Args:
      node: ast.Call node of a function that should be traced rather than
        source-transformed. It must not carry keyword arguments.

    Returns:
      primal: replacement for the original primal call, filled in from the
        primal template registered for ``tracing.Traceable``.
      adjoint: statement(s) that use the VJP produced by the primal to
        compute the adjoint, filled in from the matching adjoint template.
    """
    traceable = tracing.Traceable
    primal_template = grads.primals[traceable]
    adjoint_template = grads.adjoints[traceable]

    # Prep. The namer is consulted in a fixed order (VJP name first, then
    # the temporary) so the generated names stay the same.
    call_args = node.args
    result = ast_.copy_node(self.orig_target)
    vjp_name = self.namer.unique('%s_grad' % node.func.id)
    vjp = quoting.quote(vjp_name)
    tmp = create.create_temp(quoting.quote('tmp'), self.namer)
    assert len(node.keywords) == 0

    # Full replacement of primal
    # TODO: do we need to set 'pri_call' on this?
    packed_args = gast.Tuple(elts=call_args, ctx=gast.Load())
    primal = template.replace(
        primal_template,
        namer=self.namer,
        result=result,
        fn=node.func,
        tmp=tmp,
        vjp=vjp,
        args=packed_args)

    # Building adjoint using the vjp generated with the primal
    grad_targets = []
    for arg in call_args:
        grad_targets.append(create.create_temp_grad(arg, self.namer))
    packed_grads = gast.Tuple(elts=grad_targets, ctx=gast.Store())
    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        result=result,
        vjp=vjp,
        dargs=packed_grads)

    return primal, adjoint
def visit_Call(self, node):
    """Build the forward-mode (tangent) replacement for a call.

    Args:
      node: The call node, annotated with the called function under 'func'.

    Returns:
      ``node`` unchanged when there is no current target; a one-element list
      holding a call to the generated tangent function when no tangent
      template is registered for the callee; otherwise the nodes produced by
      filling in the registered tangent template.

    Raises:
      errors.ForwardNotImplementedError: If the callee is listed in
        ``tangents.UNIMPLEMENTED_TANGENTS``.
      NotImplementedError: If the callee is ``tracing.Traceable``.
      ValueError: If there is no tangent template and the source of the
        callee cannot be parsed.
    """
    if not self.target:
        return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
        raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
        raise NotImplementedError(
            'Tracing of %s is not enabled in forward mode' %
            quoting.unquote(node))

    if func not in tangents.tangents:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit are not rewritten into a ValueError.
        try:
            func_def = quoting.parse_function(func)
        except Exception:
            raise ValueError(
                'No tangent found for %s, and could not get source.' %
                func.__name__)

        # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
        active_args = tuple(i for i, arg in enumerate(node.args)
                            if isinstance(arg, gast.Name))
        # TODO: Stack arguments are currently not considered
        # active, but for forward-mode applied to call trees,
        # they have to be. When we figure out how to update activity
        # analysis to do the right thing, we'll want to add the extra check:
        # `and arg.id in self.active_variables`

        # TODO: Duplicate of code in reverse_ad.
        active_set = set(active_args)
        already_counted = any(
            f.__name__ == func.__name__ and set(a) == active_set
            for f, a in self.required)
        if not already_counted:
            self.required.append((func, active_args))

        fn_name = naming.tangent_name(func, active_args)
        # Reuse the tree parsed above instead of parsing the source again.
        orig_args = func_def.body[0].args
        tangent_keywords = []
        for i in active_args:
            grad_node = create.create_grad(
                node.args[i], self.namer, tangent=True)
            arg_grad_node = create.create_grad(
                orig_args.args[i], self.namer, tangent=True)
            grad_node.ctx = gast.Load()
            tangent_keywords.append(
                gast.keyword(arg=arg_grad_node.id, value=grad_node))
        # Update the original call
        rhs = gast.Call(
            func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
            args=node.args,
            keywords=tangent_keywords + node.keywords)
        # Set self.value to False to trigger whole primal replacement
        self.value = False
        return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template. The first template parameter
    # is the output, so it is dropped before binding the call arguments.
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
        # Build a mapping from names to defaults
        args = quoting.parse_function(func).body[0].args
        defaults = {}
        # Defaults align with the trailing positional parameters, hence the
        # pairing from the end of both lists.
        for arg, default in zip(*map(reversed, [args.args, args.defaults])):
            defaults[arg.id] = default
        for arg, default in zip(args.kwonlyargs, args.kw_defaults):
            if default is not None:
                defaults[arg.id] = default
        # Only values of existing keys are replaced, so the mapping keeps
        # its size while it is being iterated.
        for name, value in bound_args.arguments.items():
            if value is grads.DEFAULT:
                bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    code = six.get_function_code(template_)
    output_name = code.co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = code.co_flags
    if flags & inspect.CO_VARARGS:
        to_pack = node.args[code.co_argcount - 1:]
        vararg_name = code.co_varnames[-1]
        target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
        value = gast.Tuple(elts=to_pack, ctx=gast.Load())

        # And we fill in the packed tuple into the template
        arg_replacements[vararg_name] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(
        self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
        for succ in gast.walk(_node):
            if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
                succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
        # If the template packs arguments, then we have to unpack the
        # derivatives afterwards
        # We also have to update the replacements tuple then
        # NOTE(review): dto_pack, value and target are built here but never
        # added to the returned nodes. They are kept because the create.*
        # calls may touch the namer -- confirm whether the packing and
        # unpacking statements should be emitted.
        dto_pack = [
            create.create_temp_grad(arg, self.namer, True) for arg in to_pack
        ]
        value = create.create_grad(target, self.namer, tangent=True)
        target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
        if len(self.metastack):
            anno.setanno(tangent_node[0], 'push', self.metastack.pop())
        else:
            anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.

    Args:
      node: The call node, annotated with the called function under 'func'.

    Returns:
      A ``(primal, adjoint)`` pair. For non-differentiable callees this is
      ``(node, [])``; for traced callees it is whatever
      ``primal_and_adjoint_for_tracing`` returns; for callees without a
      registered adjoint it is the generated primal call plus the list of
      adjoint statements; otherwise it is ``node`` plus the filled-in
      template surrounded by any packing/unpacking assignments.

    Raises:
      errors.ReverseNotImplementedError: If the callee is listed in
        ``grads.UNIMPLEMENTED_ADJOINTS``.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
        return node, []

    if func == tracing.Traceable:
        return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
        raise errors.ReverseNotImplementedError(func)

    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
        # Only names can be active variables. The isinstance guard keeps
        # literal (or other non-Name) arguments from raising AttributeError
        # on ``arg.id``.
        active_args = tuple(
            i for i, arg in enumerate(node.args)
            if isinstance(arg, gast.Name) and arg.id in self.active_variables)

        active_set = set(active_args)
        already_counted = any(
            f.__name__ == func.__name__ and set(a) == active_set
            for f, a in self.required)
        if not already_counted:
            self.required.append((func, active_args))

        pri_name = naming.primal_name(func, active_args)
        pri_call = gast.Call(
            func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
            args=[self.substack] + node.args,
            keywords=node.keywords)
        anno.setanno(pri_call, 'pri_call', True)

        dy = create.create_grad(self.target, self.namer)
        dy.ctx = gast.Load()
        # NOTE(review): dx is not used below. It is kept because
        # create_grad may touch the namer -- confirm whether it can go.
        dx = create.create_grad(node.args[0], self.namer)
        dx.ctx = gast.Store()
        adj_name = naming.adjoint_name(func, active_args)
        adj_call = gast.Call(
            func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
            args=[self.substack, dy] + node.args,
            keywords=node.keywords)
        anno.setanno(adj_call, 'adj_call', True)
        adjoint = [
            template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)
        ]
        # The j-th entry of dxs belongs to the j-th active argument.
        for j, i in enumerate(active_args):
            adjoint.append(
                template.replace(
                    'd[x] = dxs[i]',
                    namer=self.namer,
                    x=node.args[i].id,
                    i=gast.Num(n=j)))
        return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]

    # Match the function call to the template. The first template parameter
    # is the output, so it is dropped before binding the call arguments.
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    # Defaults align with the trailing positional parameters, hence the
    # pairing from the end of both lists.
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
        if arg.id not in bound_args.arguments:
            bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    code = six.get_function_code(template_)
    output_name = code.co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = code.co_flags
    if flags & inspect.CO_VARARGS:
        to_pack = node.args[code.co_argcount - 1:]
        vararg_name = code.co_varnames[-1]
        target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
        value = gast.Tuple(elts=to_pack, ctx=gast.Load())
        packing = [gast.Assign(targets=[target], value=value)]

        # And we fill in the packed tuple into the template
        arg_replacements[vararg_name] = target
    adjoint = template.replace(
        template_, namer=self.namer, **arg_replacements)
    unpacking = []
    if flags & inspect.CO_VARARGS:
        # If the template packs arguments, then we have to unpack the
        # derivatives afterwards
        # We also have to update the replacements tuple then
        dto_pack = [
            create.create_temp_grad(arg, self.namer) for arg in to_pack
        ]
        value = create.create_grad(target, self.namer)
        target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
        unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking