def visit_Call(self, node):
  """Resolve the callee of `node` and store it as the `func` annotation.

  Child nodes are visited first. The callee expression is then evaluated
  statically: a bare name is looked up in `self.namespace`, falling back to
  `builtins`, and any attribute accesses on top of it are applied with
  `getattr`. A callee that is neither a name nor an attribute chain resolves
  to `None`.

  Args:
    node: A `gast.Call` node.
  """
  self.generic_visit(node)

  def resolve(expr):
    # Peel attribute accesses off from the outside in; for `a.b.c` this
    # collects ['c', 'b'] and leaves `expr` pointing at `a`.
    attr_chain = []
    while isinstance(expr, gast.Attribute):
      attr_chain.append(expr.attr)
      expr = expr.value
    if isinstance(expr, gast.Name):
      if expr.id in self.namespace:
        obj = self.namespace[expr.id]
      else:
        # TODO: we should detect when tracing is a fallback.
        obj = getattr(builtins, expr.id)
    else:
      obj = None
    # Re-apply the attribute accesses from the inside out.
    for attr in reversed(attr_chain):
      obj = getattr(obj, attr)
    return obj

  func = resolve(node.func)
  if hasattr(func, 'should_trace'):
    # The user applied the @tangent.trace decorator, so this call is
    # handled by tracing instead.
    func = tracing.Traceable
  elif hasattr(func, 'fun'):
    # TODO: use a less dicey API to check if a function is autograd-wrapped
    # Autograd primitives keep around their original wrapped function.
    # We need that to be the func annotation, otherwise we'd have to
    # redefine derivatives for all autograd wrapped versions of NumPy.
    # Beyond that, autograd wrapped functions only have fn(*args,**kwargs)
    # for their signature. We need access to the default values of functions
    # for proper code generation.
    func = func.fun
  anno.setanno(node, 'func', func)
def create_temp_grad(node, namer, tangent=False):
  """Create a variable to store partial gradients.

  Args:
    node: See `create_grad`. Must be a `Name` or a (possibly nested)
        `Subscript` of a `Name`, e.g. `x`, `x[i]` or `x[i][j]`.
    namer: See `create_grad`.
    tangent: See `create_grad`.

  Returns:
    node: See `create_grad`. Returns a node representing the partial gradient.
        Note that this is always a simple variable e.g. the temporary partial
        of `x[i]` can be something like `_dxi`.

        Nodes are given an annotation `temp_adjoint_var`.

  Raises:
    TypeError: if `node` is not a `Name` or `Subscript`, or if the
      subscripted value is not ultimately a plain `Name`.
  """
  if not isinstance(node, (gast.Subscript, gast.Name)):
    raise TypeError('Expected a Name or Subscript node, got %s' %
                    type(node).__name__)

  # The temporary is named after the underlying variable. Walk through
  # nested subscripts (e.g. `x[i][j]`) to reach it; reading `.id` directly
  # off `node.value` fails when that value is itself a Subscript.
  base = node
  while isinstance(base, gast.Subscript):
    base = base.value
  if not isinstance(base, gast.Name):
    raise TypeError('Cannot create a temporary gradient for a subscript of '
                    '%s' % type(base).__name__)

  name = namer.temp_grad(base.id, tangent)
  temp_node = gast.Name(id=name, annotation=None, ctx=None)
  anno.setanno(temp_node, 'temp_adjoint_var', node)
  return temp_node
def _name_grad(node):
  """Return a new `gast.Name` holding the gradient of the variable `node`.

  Relies on `namer` and `tangent` being in scope (presumably the arguments
  of the enclosing `create_grad`; not visible here). The returned node is
  annotated with `adjoint_var`, pointing back at `node`.

  Raises:
    TypeError: if `node` is not a `gast.Name`.
  """
  if isinstance(node, gast.Name):
    grad_node = gast.Name(
        id=namer.grad(node.id, tangent), ctx=None, annotation=None)
    anno.setanno(grad_node, 'adjoint_var', node)
    return grad_node
  raise TypeError
def visit(self, node):
  """Propagate dataflow facts through the CFG, starting at `node`.

  For a node with a `value`, the incoming facts are the join (`self.op`) of
  the `out` facts of the predecessors visited so far, or an empty frozenset
  if none have been visited. `self.gen` then yields the gen/kill sets. The
  `in`, `gen`, `kill` and `out` annotations are stored on `node.value`.
  Whenever `out` changes, all successors are revisited, so the walk stops
  once a fixed point is reached.

  A node without a `value` is the exit node: the join of its predecessors'
  `out` facts is stored in `self.exit`.

  Args:
    node: A CFG node with `value`, `prev` and `next` attributes.
  """
  def join_predecessors():
    # Only predecessors that were already visited carry an `out` annotation.
    # An unvisited predecessor will revisit `node` once its own `out` is set
    # (its first visit always counts as a change), so skipping it is safe.
    preds = [
        anno.getanno(pred.value, self.out_label)
        for pred in node.prev
        if anno.hasanno(pred.value, self.out_label)
    ]
    if preds:
      return functools.reduce(self.op, preds[1:], preds[0])
    return frozenset()

  if node.value:
    if anno.hasanno(node.value, self.out_label):
      before = anno.getanno(node.value, self.out_label)
    else:
      before = None
    incoming = join_predecessors()
    anno.setanno(node.value, self.in_label, incoming, safe=False)
    gen, kill = self.gen(node, incoming)
    anno.setanno(node.value, self.gen_label, gen, safe=False)
    anno.setanno(node.value, self.kill_label, kill, safe=False)
    out = (incoming - kill) | gen
    anno.setanno(node.value, self.out_label, out, safe=False)
    # Compare the facts themselves, not their hashes: equal hashes do not
    # imply equal sets, and a collision would halt propagation too early.
    if before is None or out != before:
      for succ in node.next:
        self.visit(succ)
  else:
    self.exit = join_predecessors()
def visit_For(self, node):
  """Rewrite a loop over an active iterable into explicit integer indexing.

  `for a in x: f(a)` becomes `for _idx in range(len(x)): a = x[_idx]; f(a)`
  when the iterable is a `Name`, `Subscript` or `Attribute` whose variable
  is listed in the node's `active_in` annotation. Other loops are returned
  untouched.

  Args:
    node: A `gast.For` node carrying an `active_in` annotation.

  Returns:
    The (possibly rewritten) loop node.
  """
  # If the iter is a Name that is active,
  # we need to rewrite the loop.
  # Iterators of the form `for a in x` rely on an implicit
  # indexing operation, which Tangent cannot reverse without
  # more information. So, we will create an explicit
  # indexing operation. Note that we will use
  # integer indexes, which will cause strange behavior if
  # the iterator's `next()` behavior deviates from
  # a plain incrementing index.
  # The right thing to do (eventually) is to write a multiple-dispatch
  # version of the `next` operator, and its adjoint, so that
  # we can handle e.g. dicts.
  if isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)):
    iter_name = ast.get_name(node.iter)
    if iter_name in anno.getanno(node, 'active_in'):
      # for a in x:
      #   f(a)
      # # becomes
      # for i in range(len(x)):
      #   a = x[i]
      #   f(a)

      # Get a unique iterator name
      old_target = copy.deepcopy(node.target)
      new_target = quoting.quote(self.namer.unique('_idx'))
      old_iter = copy.deepcopy(node.iter)

      item_access = template.replace(
          'old_target = x[i]', old_target=old_target, x=old_iter,
          i=new_target)

      node.target = gast.Name(id=new_target.id, ctx=gast.Store(),
                              annotation=None)
      # Put a copy of the *whole* iterable expression inside `len(...)`.
      # `iter_name` is a single variable name (apparently the base variable
      # for subscripts/attributes), so formatting it into the source would
      # turn `for a in x[0]` into `range(len(x))`, which has the wrong length.
      iter_len_arg = copy.deepcopy(node.iter)
      node.iter = quoting.quote('range(len(_iterable))')
      node.iter.args[0].args = [iter_len_arg]
      anno.setanno(node.iter, 'func', range)
      anno.setanno(node.iter.args[0], 'func', len)
      node.body = [item_access] + node.body
  return node
def create_temp(node, namer):
  """Create a temporary variable.

  Args:
    node: Create a temporary variable to store this variable in. Either a
        `Name`, or an `Attribute`/`Subscript` (possibly nested, e.g. `x.a[i]`)
        whose innermost value is a `Name`.
    namer: A naming object that guarantees the names are unique.

  Returns:
    node: See `create_grad`. Returns a temporary variable, which is always a
        simple variable annotated with `temp_var`.

  Raises:
    TypeError: if `node` has an unsupported type, or an attribute/subscript
      chain does not end in a `Name`.
  """
  if isinstance(node, gast.Name):
    name = node.id
  elif isinstance(node, (gast.Attribute, gast.Subscript)):
    # Name the temporary after the underlying variable. Walking the chain
    # handles nested accesses such as `x.a.b` or `x[i][j]`, where
    # `node.value` is not itself a `Name` and has no `.id`.
    base = node.value
    while isinstance(base, (gast.Attribute, gast.Subscript)):
      base = base.value
    if not isinstance(base, gast.Name):
      raise TypeError('Cannot create a temporary for an access on %s' %
                      type(base).__name__)
    name = base.id
  else:
    raise TypeError
  temp_node = gast.Name(id=namer.temp(name), annotation=None, ctx=None)
  anno.setanno(temp_node, 'temp_var', node)
  return temp_node
def add_comment(node, text, location='above'):
  """Attach a comment to `node` as its `comment` annotation.

  When the `SourceWithCommentGenerator` class is used, these comments are
  emitted as part of the generated source code. A node holds at most one
  comment: calling `add_comment` again overrides the previous one.

  Args:
    node: The AST node whose containing statement will be commented.
    text: The comment string.
    location: Where the comment should appear. Valid values are 'above',
        'below' and 'right'.

  Returns:
    The same node, now carrying the comment annotation.
  """
  comment = {'location': location, 'text': text}
  anno.setanno(node, 'comment', comment, safe=False)
  return node
def visit_Name(self, node):
  """Substitute a name that has an entry in `self.replacements`.

  Names without a replacement are returned unchanged. The first use of a
  replacement inserts the stored node(s) directly; every later use inserts
  a copy made with `ast_.copy_node`. Each inserted node gets a
  `replacement` annotation pointing at the original name. Inserted nodes
  that have a `ctx` field also take over `node.ctx`.

  Returns:
    `node` itself, a single replacement node, or the sequence of
    replacement nodes when there is more than one.
  """
  if node.id not in self.replacements:
    return node

  replacement = self.replacements[node.id]
  if node.id in self.seen:
    # NOTE In principle we don't want to copy, because it might break
    # references held in annotations, but we will copy if we have to to
    # avoid duplicate nodes
    replacement = ast_.copy_node(replacement)
  else:
    self.seen.add(node.id)

  if isinstance(replacement, gast.AST):
    replacement = [replacement]
  for new_node in replacement:
    anno.setanno(new_node, 'replacement', node, safe=False)
    if 'ctx' in new_node._fields:
      new_node.ctx = node.ctx

  if len(replacement) != 1:
    return replacement
  only_node, = replacement
  return only_node
def visit_Expr(self, node):
  """Link a push expression statement to its matching pop.

  Only statements of the form `push(...)` or `push_stack(...)` are
  handled. The pushed value (second positional argument) is recorded as
  `push_var`. The op id, taken from the `.s` of the last argument, is used
  to find the paired pop in `self.push_pop_pairs`. The pop is stored as a
  `pop` annotation on both the statement and the call.

  A missing pairing is ignored unless `self.strict` is set, in which case
  the `KeyError` propagates.
  """
  call = node.value
  if not isinstance(call, gast.Call):
    return
  fn_handle = _get_stack_op_handle(call)
  if not fn_handle or fn_handle not in [utils.push, utils.push_stack]:
    return

  op_id = call.args[-1].s
  anno.setanno(node, 'push_var', call.args[1])
  try:
    pairs_for_op = self.push_pop_pairs[op_id]
    matching_pop = pairs_for_op[self.fn_map[fn_handle]]
  except KeyError:
    if not self.strict:
      return
    raise
  anno.setanno(node, 'pop', matching_pop, False)
  anno.setanno(call, 'pop', matching_pop, False)
def get_push_pop_stack():
  """Create linked push/pop nodes for substacks, plus a fresh op id.

  Returns:
    A tuple `(push, pop, op_id)`. `push` and `pop` are deep copies of
    `PUSH_STACK` and `POP_STACK`, so they keep those nodes' annotations.
    `push` additionally carries a `pop` annotation pointing to `pop` and a
    `gen_push` annotation set to True. `pop` carries a `push` annotation
    pointing back to `push`. `op_id` is the node produced by
    `_generate_op_id()`.
  """
  push, pop = copy.deepcopy(PUSH_STACK), copy.deepcopy(POP_STACK)
  # Cross-link the pair so either side can be found from the other.
  anno.setanno(push, 'pop', pop)
  anno.setanno(push, 'gen_push', True)
  anno.setanno(pop, 'push', push)
  return push, pop, _generate_op_id()
def visit_Assign(self, node):
  """Link a `x = pop(...)` / `x = pop_stack(...)` assignment to its push.

  The assigned target is recorded as `pop_var`. The op id (the `.s` of the
  second call argument) selects the paired operations in
  `self.push_pop_pairs`. The matching push is stored as a `push`
  annotation on both the assignment and the call.

  Raises:
    ValueError: if the op id is unknown or, when `self.strict` is set, if
      the op id is not bound to exactly one valid push/pop pair.
    KeyError: in strict mode, if the matching push cannot be looked up.
  """
  call = node.value
  if not isinstance(call, gast.Call):
    return
  fn_handle = _get_stack_op_handle(call)
  if not (fn_handle and fn_handle in [utils.pop, utils.pop_stack]):
    return

  # Retrieve the op_id, e.g. val = tangent.pop(_stack,'abc')
  #                                                    ^^^
  _, op_id_node = call.args
  op_id = op_id_node.s
  anno.setanno(node, 'pop_var', node.targets[0])

  if op_id not in self.push_pop_pairs:
    raise ValueError('op_id %s not known' % op_id)
  keys = self.push_pop_pairs[op_id].keys()

  if self.strict:
    # The op_id must be associated with exactly two operations...
    if len(keys) != 2:
      raise ValueError('Instead of 2 push/pop fns, found %d' % len(keys))
    # ...which are either `push`/`pop` or `push_stack`/`pop_stack`.
    found = set(keys)
    is_plain_pair = found == set((utils.push, utils.pop))
    is_stack_pair = found == set((utils.push_stack, utils.pop_stack))
    if not is_plain_pair and not is_stack_pair:
      raise ValueError('Invalid push/pop function pair. Found %s' % keys)

  try:
    matching_push = self.push_pop_pairs[op_id][self.fn_map[fn_handle]]
  except KeyError:
    if not self.strict:
      return
    raise
  anno.setanno(node, 'push', matching_push, False)
  anno.setanno(call, 'push', matching_push, False)
def mark(self, node):
  """Store `self.src` as the `pre_anf` annotation of `node`.

  Nothing happens if `node` already has a `pre_anf` annotation or if
  `self.src` is falsy.
  """
  if anno.hasanno(node, 'pre_anf') or not self.src:
    return
  anno.setanno(node, 'pre_anf', self.src)
def visit_Call(self, node):
  """Generate forward-mode (tangent) code for a call.

  If `self.target` is unset, the call is returned unchanged.

  If the callee has no registered tangent template (`tangents.tangents`),
  it is replaced by a call to a generated tangent function named by
  `naming.tangent_name`. That call receives the tangents of the `gast.Name`
  arguments as extra keyword arguments. The callee and its active argument
  positions are appended to `self.required` (once per distinct pair).

  Otherwise the template is filled in with the call's arguments. Missing
  arguments fall back to the template's defaults, and then to the original
  function's defaults for `grads.DEFAULT` placeholders.

  Args:
    node: A `gast.Call` node carrying a `func` annotation.

  Returns:
    `node` unchanged, a one-element list holding the replacement call, or
    the list of statements produced from the tangent template.

  Raises:
    errors.ForwardNotImplementedError: if `func` is listed in
      `tangents.UNIMPLEMENTED_TANGENTS`.
    NotImplementedError: if the call is marked for tracing.
    ValueError: if `func` has no tangent template and its source cannot be
      parsed.
  """
  if not self.target:
    return node
  func = anno.getanno(node, 'func')
  if func in tangents.UNIMPLEMENTED_TANGENTS:
    raise errors.ForwardNotImplementedError(func)
  if func == tracing.Traceable:
    raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                              quoting.unquote(node))
  if func not in tangents.tangents:
    try:
      quoting.parse_function(func)
    except Exception:
      # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
      # not swallowed and misreported as a missing tangent.
      raise ValueError('No tangent found for %s, and could not get source.' %
                       func.__name__)

    # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
    active_args = tuple(i for i, arg in enumerate(node.args)
                        if isinstance(arg, gast.Name))
    # TODO: Stack arguments are currently not considered
    # active, but for forward-mode applied to call trees,
    # they have to be. When we figure out how to update activity
    # analysis to do the right thing, we'll want to add the extra check:
    # `and arg.id in self.active_variables`

    # TODO: Duplicate of code in reverse_ad.
    already_counted = False
    for f, a in self.required:
      if f.__name__ == func.__name__ and set(a) == set(active_args):
        already_counted = True
        break
    if not already_counted:
      self.required.append((func, active_args))

    fn_name = naming.tangent_name(func, active_args)
    orig_args = quoting.parse_function(func).body[0].args
    tangent_keywords = []
    for i in active_args:
      grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
      arg_grad_node = create.create_grad(
          orig_args.args[i], self.namer, tangent=True)
      grad_node.ctx = gast.Load()
      tangent_keywords.append(
          gast.keyword(arg=arg_grad_node.id, value=grad_node))
    # Update the original call
    rhs = gast.Call(
        func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
        args=node.args,
        keywords=tangent_keywords + node.keywords)
    # Set self.value to False to trigger whole primal replacement
    self.value = False
    return [rhs]

  template_ = tangents.tangents[func]

  # Match the function call to the template
  sig = funcsigs.signature(template_)
  sig = sig.replace(parameters=list(sig.parameters.values())[1:])
  kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
  bound_args = sig.bind(*node.args, **kwargs)
  bound_args.apply_defaults()

  # If any keyword arguments weren't passed, we fill them using the
  # defaults of the original function
  if grads.DEFAULT in bound_args.arguments.values():
    # Build a mapping from names to defaults
    args = quoting.parse_function(func).body[0].args
    defaults = {}
    for arg, default in zip(*map(reversed, [args.args, args.defaults])):
      defaults[arg.id] = default
    for arg, default in zip(args.kwonlyargs, args.kw_defaults):
      if default is not None:
        defaults[arg.id] = default
    for name, value in bound_args.arguments.items():
      if value is grads.DEFAULT:
        bound_args.arguments[name] = defaults[name]

  # Let's fill in the template. The first argument is the output, which
  # was stored in a temporary variable
  output_name = six.get_function_code(template_).co_varnames[0]
  arg_replacements = {output_name: self.tmp_node}
  arg_replacements.update(bound_args.arguments)

  # If the template uses *args, then we pack the corresponding inputs
  flags = six.get_function_code(template_).co_flags

  if flags & inspect.CO_VARARGS:
    to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
    vararg_name = six.get_function_code(template_).co_varnames[-1]
    target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
    # NOTE(review): `value` (and the `value`/`target` pair recomputed
    # below) is never emitted as an assignment in this method, unlike the
    # packing/unpacking in reverse mode. Confirm whether that is intended.
    value = gast.Tuple(elts=to_pack, ctx=gast.Load())

    # And we fill in the packed tuple into the template
    arg_replacements[six.get_function_code(template_).co_varnames[
        -1]] = target
  tangent_node = template.replace(
      template_,
      replace_grad=template.Replace.TANGENT,
      namer=self.namer,
      **arg_replacements)

  # If the template uses the answer in the RHS of the tangent,
  # we need to make sure that the regular answer is replaced
  # with self.tmp_node, but that the gradient is not. We have
  # to be extra careful for statements like a = exp(a), because
  # both the target and RHS variables have the same name.
  tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
  tmp_grad_name = tmp_grad_node.id
  ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
  for _node in tangent_node:
    for succ in gast.walk(_node):
      if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
        succ.id = ans_grad_node.id

  if flags & inspect.CO_VARARGS:
    # If the template packs arguments, then we have to unpack the
    # derivatives afterwards
    # We also have to update the replacements tuple then
    dto_pack = [
        create.create_temp_grad(arg, self.namer, True) for arg in to_pack
    ]
    value = create.create_grad(target, self.namer, tangent=True)
    target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

  # Stack pops have to be special-cased, we have
  # to set the 'push' attribute, so we know that if we
  # remove this pop, we have to remove the equivalent push.
  # NOTE: this only works if we're doing forward-over-reverse,
  # where reverse is applied in joint mode, with no call tree.
  # Otherwise, the pushes and pops won't be matched within a single
  # function call.
  if func == tangent.pop:
    if len(self.metastack):
      anno.setanno(tangent_node[0], 'push', self.metastack.pop())
    else:
      anno.setanno(tangent_node[0], 'push', None)

  return tangent_node
from __future__ import division from copy import copy as native_copy import types import autograd import numpy import six from tangent import annotations as anno from tangent import non_differentiable from tangent import quoting INIT_GRAD = quoting.quote('tangent.init_grad') ADD_GRAD = quoting.quote('tangent.add_grad') anno.setanno(INIT_GRAD, 'init_grad', True) anno.setanno(ADD_GRAD, 'add_grad', True) def array_size(x, axis): """Calculate the size of `x` along `axis` dimensions only.""" axis_shape = x.shape if axis is None else tuple(x.shape[a] for a in axis) return max(numpy.prod(axis_shape), 1) class Stack(object): """A stack type that proxies list's `append` and `pop` methods. We don't use list directly so that we can test its type for the multiple- dispatch that occurs in `add_grad` and `init_grad`. """
def visit_Call(self, node):
  """Create adjoint for call.

  We don't allow unpacking of parameters, so we know that each argument
  gets passed in explicitly, allowing us to create partials for each.
  However, templates might perform parameter unpacking (for cases where
  the number of arguments is variable) and express their gradient as a
  tuple. In this case, we have to unpack this tuple of partials.

  Args:
    node: A `gast.Call` node carrying a `func` annotation.

  Returns:
    A `(primal, adjoint)` pair: the primal call (either `node` or a call to
    a generated primal function) and a list of adjoint statements. Traced
    calls are delegated to `self.primal_and_adjoint_for_tracing`
    (presumably returning the same kind of pair).

  Raises:
    errors.ReverseNotImplementedError: if `func` is listed in
      `grads.UNIMPLEMENTED_ADJOINTS`.
  """
  # Find the function we are differentiating
  func = anno.getanno(node, 'func')

  if func in non_differentiable.NON_DIFFERENTIABLE:
    return node, []

  if func == tracing.Traceable:
    return self.primal_and_adjoint_for_tracing(node)

  if func in grads.UNIMPLEMENTED_ADJOINTS:
    raise errors.ReverseNotImplementedError(func)

  # If we don't have an adjoint, we will have to step into the called
  # function and differentiate it
  if func not in grads.adjoints:
    # Only plain names can be active. Literal arguments (e.g. `f(x, 2)`)
    # have no `.id` and are never differentiated; this matches forward mode.
    active_args = tuple(i for i, arg in enumerate(node.args)
                        if isinstance(arg, gast.Name) and
                        arg.id in self.active_variables)

    already_counted = False
    for f, a in self.required:
      if f.__name__ == func.__name__ and set(a) == set(active_args):
        already_counted = True
        break
    if not already_counted:
      self.required.append((func, active_args))

    pri_name = naming.primal_name(func, active_args)
    pri_call = gast.Call(
        func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
        args=[self.substack] + node.args,
        keywords=node.keywords)
    anno.setanno(pri_call, 'pri_call', True)

    dy = create.create_grad(self.target, self.namer)
    dy.ctx = gast.Load()
    adj_name = naming.adjoint_name(func, active_args)
    adj_call = gast.Call(
        func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
        args=[self.substack, dy] + node.args,
        keywords=node.keywords)
    anno.setanno(adj_call, 'adj_call', True)
    adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
    # Distribute the returned tuple of partials to the active arguments.
    for j, i in enumerate(active_args):
      adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                      x=node.args[i].id, i=gast.Num(n=j)))
    return pri_call, adjoint

  # We have a template for the gradient that we need to fill in
  template_ = grads.adjoints[func]

  # Match the function call to the template
  sig = funcsigs.signature(template_)
  sig = sig.replace(parameters=list(sig.parameters.values())[1:])
  kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
  bound_args = sig.bind(*node.args, **kwargs)

  # Fill in any missing kwargs with the defaults from the template
  args = quoting.parse_function(template_).body[0].args
  kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
  kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
  for arg, val in kwargs.items():
    if arg.id not in bound_args.arguments:
      bound_args.arguments[arg.id] = val

  # Let's fill in the template. The first argument is the output, which
  # was stored in a temporary variable
  output_name = six.get_function_code(template_).co_varnames[0]
  arg_replacements = {output_name: ast_.copy_node(self.target)}
  arg_replacements.update(bound_args.arguments)

  # If the template uses *args, then we pack the corresponding inputs
  packing = []
  flags = six.get_function_code(template_).co_flags

  if flags & inspect.CO_VARARGS:
    to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
    vararg_name = six.get_function_code(template_).co_varnames[-1]
    target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
    value = gast.Tuple(elts=to_pack, ctx=gast.Load())
    packing = [gast.Assign(targets=[target], value=value)]

    # And we fill in the packed tuple into the template
    arg_replacements[six.get_function_code(
        template_).co_varnames[-1]] = target
  adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
  unpacking = []
  if flags & inspect.CO_VARARGS:
    # If the template packs arguments, then we have to unpack the
    # derivatives afterwards
    # We also have to update the replacements tuple then
    dto_pack = [create.create_temp_grad(arg, self.namer)
                for arg in to_pack]
    value = create.create_grad(target, self.namer)
    target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
    unpacking = [gast.Assign(targets=[target], value=value)]

  return node, packing + adjoint + unpacking
def visit_FunctionDef(self, node):
  """Visit a function definition and annotate it with `self.func`.

  The children are visited first via `generic_visit`. The definition node
  itself then receives a `func` annotation holding `self.func`. Nothing is
  returned.
  """
  self.generic_visit(node)
  anno.setanno(node, 'func', self.func)
from tangent import errors from tangent import fixes from tangent import funcsigs from tangent import grads from tangent import naming from tangent import non_differentiable from tangent import quoting from tangent import template from tangent import tracing from tangent import utils # Some AST nodes to fill in to templates that use stacks or reset gradients PUSH = quoting.quote('tangent.push') POP = quoting.quote('tangent.pop') anno.setanno(PUSH, 'push_func', True) anno.setanno(POP, 'pop_func', True) PUSH_STACK = quoting.quote('tangent.push_stack') POP_STACK = quoting.quote('tangent.pop_stack') anno.setanno(PUSH_STACK, 'push_func', True) anno.setanno(POP_STACK, 'pop_func', True) def _generate_op_id(): return quoting.quote("'_{}'".format(uuid4().hex[:8])) def get_push_pop(): """Create pop and push nodes that are linked. Returns: