def visit_Assign(self, node):
    """Insert ``x = 2 * x`` via ``insert_top`` when the assignment targets ``z``.

    The statement is only inserted for assignments whose first target is the
    name ``z``. This guard is essential: without it the transformation never
    terminates (an infinite loop).

    Args:
      node: The Assign node being visited.

    Returns:
      The same ``node``, unchanged.
    """
    first_target = node.targets[0]
    if first_target.id != 'z':
        return node
    self.insert_top(quoting.quote("x = 2 * x"))
    return node
def primal_and_adjoint_for_tracing(self, node):
    """Build the primal and adjoint of a traceable function.

    Instead of transforming the called function, its call is replaced by the
    `tracing.Traceable` primal template, which produces a VJP; the adjoint
    template then uses that VJP to compute the gradients of the call's
    positional arguments.

    Args:
      node: ast.Call node of a function we wish to trace, instead of transform

    Returns:
      primal: new ast.Assign node to replace the original primal call
      adjoint: new ast.Assign node using the VJP generated in primal to
        calculate the adjoint.
    """
    primal_template = grads.primals[tracing.Traceable]
    adjoint_template = grads.adjoints[tracing.Traceable]

    call_args = node.args
    result_target = ast_.copy_node(self.orig_target)
    # Name node of the VJP: produced by the primal, consumed by the adjoint.
    vjp_name = quoting.quote(self.namer.unique('%s_grad' % node.func.id))
    tmp_name = create.create_temp(quoting.quote('tmp'), self.namer)
    # Keyword arguments are not handled by this path.
    assert len(node.keywords) == 0

    # The primal call is replaced wholesale.
    # TODO: do we need to set 'pri_call' on this?
    packed_args = gast.Tuple(elts=call_args, ctx=gast.Load())
    primal = template.replace(
        primal_template,
        namer=self.namer,
        result=result_target,
        fn=node.func,
        tmp=tmp_name,
        vjp=vjp_name,
        args=packed_args)

    # The adjoint unpacks the VJP's output into the gradients of the call args.
    grad_targets = [create.create_temp_grad(arg, self.namer)
                    for arg in call_args]
    packed_grads = gast.Tuple(elts=grad_targets, ctx=gast.Store())
    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        result=result_target,
        vjp=vjp_name,
        dargs=packed_grads)
    return primal, adjoint
def _create_joint(fwdbwd, func, wrt, input_derivative):
    """Create a user-friendly gradient function.

    By default, gradient functions expect the stack to be passed to them
    explicitly. This function modifies the function so that the stack doesn't
    need to be passed and gets initialized in the function body instead.

    For consistency, gradient functions always return a tuple, even if the
    gradient of only one input was required. We unpack the tuple if it is of
    length one.

    Args:
      fwdbwd: An AST. The function definition of the joint primal and adjoint.
      func: A function handle. The original function that was differentiated.
      wrt: A tuple of integers. The arguments with respect to which we
          differentiated.
      input_derivative: Compared against `INPUT_DERIVATIVE.DefaultOne`; when
          equal, the appended initial-gradient argument gets the default
          value `1.0`.

    Returns:
      The function definition of the new function.
    """
    # Unwrap a one-element return tuple so a single gradient is returned bare.
    final_return = fwdbwd.body[-1]
    returned = final_return.value.elts
    if len(returned) == 1:
        final_return.value = returned[0]

    # The first argument is the stack: create it in the body instead of
    # requiring the caller to pass it.
    stack_name = fwdbwd.args.args[0].id
    init_stack = comments.add_comment(
        quoting.quote('%s = tangent.Stack()' % stack_name),
        'Initialize the tape')
    fwdbwd.body = [init_stack] + fwdbwd.body

    # Remember the initial-gradient argument's name before the signature is
    # swapped for the original function's one.
    grad_name = fwdbwd.args.args[1].id
    fwdbwd.args = quoting.parse_function(func).body[0].args
    fwdbwd.name = naming.joint_name(func, wrt)

    # Re-append the initial gradient so it can be passed as a keyword argument.
    fwdbwd = ast_.append_args(fwdbwd, [grad_name])
    if input_derivative == INPUT_DERIVATIVE.DefaultOne:
        fwdbwd.args.defaults.append(quoting.quote('1.0'))
    return fwdbwd
def visit_For(self, node):
    """Rewrite a loop over an active iterable into explicit integer indexing.

    Iterators of the form `for a in x` rely on an implicit indexing operation,
    which Tangent cannot reverse without more information. So, we create an
    explicit indexing operation:

        for a in x:             for _idx in range(len(x)):
            f(a)       -->          a = x[_idx]
                                    f(a)

    Note that integer indexes are used, which will cause strange behavior if
    the iterator's `next()` behavior deviates from a plain incrementing index.
    The right thing to do (eventually) is to write a multiple-dispatch version
    of the `next` operator, and its adjoint, so that we can handle e.g. dicts.

    Only loops whose iterable is a Name, Subscript or Attribute whose variable
    name is in the node's `active_in` annotation are rewritten.

    Args:
      node: The For node being visited.

    Returns:
      `node`, modified in place when it was rewritten.
    """
    if isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)):
        iter_name = ast.get_name(node.iter)
        if iter_name in anno.getanno(node, 'active_in'):
            # Get a unique iterator name
            old_target = copy.deepcopy(node.target)
            new_target = quoting.quote(self.namer.unique('_idx'))
            old_iter = copy.deepcopy(node.iter)

            item_access = template.replace(
                'old_target = x[i]', old_target=old_target, x=old_iter,
                i=new_target)

            node.target = gast.Name(id=new_target.id, ctx=gast.Store(),
                                    annotation=None)
            # `iter_name` is only the underlying variable name (`x` for `x`,
            # `x.y` and `x[0]`), so the length must be taken of the full
            # iterable expression; otherwise `for a in x[0]` would loop
            # len(x) times while indexing into x[0].
            if isinstance(old_iter, gast.Name):
                iter_source = iter_name
            else:
                iter_source = quoting.unquote(old_iter)
            node.iter = quoting.quote('range(len(%s))' % iter_source)
            anno.setanno(node.iter, 'func', range)
            anno.setanno(node.iter.args[0], 'func', len)
            node.body = [item_access] + node.body
    return node
def tmp_node(self):
    """Return the cached `tmp` name node, creating it on first use.

    The node is built from `self.namer.unique('tmp')` the first time
    `self._tmp_node` is found to be None, and is stored there for reuse.
    """
    cached = self._tmp_node
    if cached is None:
        cached = quoting.quote(self.namer.unique('tmp'))
        self._tmp_node = cached
    return cached
def append_args(node, node_list):
    """Append arguments to a function definition using `ArgAppend`.

    Args:
      node: The AST node to transform.
      node_list: A list whose entries are AST nodes or strings. Each string
          is turned into a node with `quoting.quote` before appending, so
          strings and nodes may be mixed.

    Returns:
      The result of `ArgAppend(node_list).visit(node)`.

    Raises:
      TypeError: If `node_list` is not a list.
    """
    if not isinstance(node_list, list):
        raise TypeError('Please pass in a list')
    # Quote strings one by one; previously a list mixing strings and nodes
    # was passed through unquoted.
    node_list = [quoting.quote(n) if isinstance(n, str) else n
                 for n in node_list]
    return ArgAppend(node_list).visit(node)
""" from __future__ import absolute_import from __future__ import division from copy import copy as native_copy import types import autograd import numpy import six from tangent import annotations as anno from tangent import non_differentiable from tangent import quoting INIT_GRAD = quoting.quote('tangent.init_grad') ADD_GRAD = quoting.quote('tangent.add_grad') anno.setanno(INIT_GRAD, 'init_grad', True) anno.setanno(ADD_GRAD, 'add_grad', True) def array_size(x, axis): """Calculate the size of `x` along `axis` dimensions only.""" axis_shape = x.shape if axis is None else tuple(x.shape[a] for a in axis) return max(numpy.prod(axis_shape), 1) class Stack(object): """A stack type that proxies list's `append` and `pop` methods. We don't use list directly so that we can test its type for the multiple-
def test_function_compile():
    """`compile_function` raises TypeError for an Assign, ValueError for a Module."""
    assign_node = quoting.quote('x = y')
    with pytest.raises(TypeError):
        compile_.compile_function(assign_node)

    module_node = gast.parse('x = y')
    with pytest.raises(ValueError):
        compile_.compile_function(module_node)
def replace(template, replace_grad=Replace.PARTIAL, namer=None, **replacements):
    """Replace placeholders in a Python template (quote).

    Args:
      template: A function, AST node or string to be used as a template. Note
          that if a function is passed, any placeholder is expected to also be a
          function argument (a `*args` argument counts as a placeholder too).
          If a string is passed, it must represent valid Python code, and any
          variable it references is a placeholder.
      replace_grad: If Replace.NONE, statements of the form `d[x]` are ignored.
          For the other possible values, see `ReplaceGradTransformer`.
      namer: See `ReplaceGradTransformer`.
      **replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by. If a string is passed,
          `quote` will be called on it to turn it into a node.

    Returns:
      body: An AST node or list of AST nodes with the replacements made. If the
          template was a function, a list will be returned. If the template was
          a node, the same node will be returned. If the template was a string,
          an AST node will be returned (a `Module` node in the case of a
          multi-line string, an `Expr` node otherwise).

    Raises:
      ValueError: If a function is used as a template and an incorrect set of
          replacements was passed.
    """
    # Handle the 3 different types of templates: funcs, nodes, and strings
    is_function = isinstance(template, types.FunctionType)
    if is_function:
        tree = quoting.parse_function(template).body[0]
        placeholders = set(arg.id for arg in tree.args.args)
        tree.args.args = []
        if tree.args.vararg:
            # gast represents `*args` as a Name node; register its identifier
            # (not the node itself) so it can match a replacement keyword.
            # A plain string is accepted as-is.
            vararg = tree.args.vararg
            placeholders.add(getattr(vararg, 'id', vararg))
            tree.args.vararg = None
        if set(replacements) != placeholders:
            raise ValueError('too many or few replacements')
    elif isinstance(template, gast.AST):
        tree = template
    else:
        tree = quoting.quote(template, return_expr=True)

    # If the replacements are strings, turn them into nodes. Only values are
    # reassigned, so mutating the dict while iterating it is safe.
    for k, v in replacements.items():
        if isinstance(v, six.string_types):
            replacements[k] = quoting.quote(v)

    # Perform the replacement
    ReplaceTransformer(replacements).visit(tree)

    # Handle the d[x] operator
    if replace_grad is not Replace.NONE:
        rgt = ReplaceGradTransformer(
            replace_grad=replace_grad,
            namer=namer,
            tangent=replace_grad is Replace.TANGENT)
        rgt.visit(tree)

    # Return the AST node with replacements made
    if is_function:
        return tree.body
    else:
        return tree
def visit(self, node):
    """Pre-declare not-yet-defined pushed variables, then visit normally.

    When `node` has a `push_var` annotation whose variable name is missing
    from the node's `defined_in` annotation, the statement `<name> = None` is
    passed to `insert_top`. Visiting then continues in the parent class.
    """
    if not anno.hasanno(node, 'push_var'):
        return super(FixStack, self).visit(node)

    pushed_name = ast_.get_name(anno.getanno(node, 'push_var'))
    if pushed_name not in anno.getanno(node, 'defined_in'):
        self.insert_top(quoting.quote('{} = None'.format(pushed_name)))
    return super(FixStack, self).visit(node)
def store_state(node, reaching, defined, stack):
    """Push the final state of the primal onto the stack for the adjoint.

    Python's scoping rules make it possible for variables to not be defined in
    certain blocks based on the control flow path taken at runtime. In order to
    make sure we don't try to push non-existing variables onto the stack, we
    define these variables explicitly (by assigning `None` to them) at the
    beginning of the function.

    All the variables that reach the return statement are pushed onto the
    stack, and in the adjoint they are popped off in reverse order.

    Args:
      node: A module with the primal and adjoint function definitions as
          returned by `reverse_ad`.
      reaching: The variable definitions that reach the end of the primal.
      defined: The variables defined at the end of the primal.
      stack: The stack node to use for storing and restoring state.

    Returns:
      node: A node with the requisite pushes and pops added to make sure that
          state is transferred between primal and adjoint split motion calls.
    """
    defs = [def_ for def_ in reaching
            if not isinstance(def_[1], gast.arguments)]
    if not defs:
        return node
    reaching, original_defs = zip(*defs)

    # Explicitly define variables that might or might not be in scope at the
    # end. Sorting the names makes the order of the generated statements
    # independent of set iteration order, which varies with string hash
    # randomization between interpreter runs.
    assignments = [quoting.quote('{} = None'.format(id_))
                   for id_ in sorted(set(reaching) - defined)]

    # Store variables at the end of the function and restore them
    store = []
    load = []
    for id_, def_ in zip(reaching, original_defs):
        # If the original definition of a value that we need to store
        # was an initialization as a stack, then we should be using
        # `push_stack` to store its state, and `pop_stack` to restore it. This
        # allows us to avoid doing any `add_grad` calls on the stack, which
        # result in type errors in unoptimized mode (they are usually elided
        # after calling `dead_code_elimination`).
        if (isinstance(def_, gast.Assign) and
                'tangent.Stack()' in quoting.unquote(def_.value)):
            push, pop, op_id = get_push_pop_stack()
        else:
            push, pop, op_id = get_push_pop()
        store.append(
            template.replace(
                'push(_stack, val, op_id)',
                push=push,
                val=id_,
                _stack=stack,
                op_id=op_id))
        load.append(
            template.replace(
                'val = pop(_stack, op_id)',
                pop=pop,
                val=id_,
                _stack=stack,
                op_id=op_id))

    body, return_ = node.body[0].body[:-1], node.body[0].body[-1]
    node.body[0].body = assignments + body + store + [return_]
    # Pops run in reverse push order at the start of the adjoint.
    node.body[1].body = load[::-1] + node.body[1].body

    return node
def _generate_op_id():
    """Return a quoted string literal of the form `'_xxxxxxxx'`.

    The eight characters are the leading hex digits of a fresh `uuid4()`.
    """
    random_hex = uuid4().hex[:8]
    return quoting.quote("'_{}'".format(random_hex))
def visit_FunctionDef(self, node):
    """Split a function definition into a primal and an adjoint definition.

    The function must contain exactly one `return` statement, and it must be
    the last statement of the body (returns are counted with `gast.walk`, so
    those inside nested definitions count too).

    The primal is `node` itself, modified in place: its body becomes the
    primal statements plus the original return, `self.stack` is prepended to
    its arguments, and it is renamed with `naming.primal_name`. The adjoint is
    built from `grads.adjoints[gast.FunctionDef]` and takes `self.stack`, the
    gradient of the (first) returned value, and the original arguments.

    Args:
      node: The function definition to differentiate; must carry a `func`
          annotation.

    Returns:
      A `(primal, adjoint)` tuple of function definition nodes.

    Raises:
      ValueError: If the function does not have exactly one return statement
          at the end.
    """
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or not isinstance(node.body[-1], gast.Return)):
        raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
        body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
        adjoint_body[0] = comments.add_comment(
            adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to (`self.wrt` holds argument indices)
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt],
                    ctx=gast.Load())

    if self.preserve_result:
        # Append an extra Assign operation to the primal body
        # that saves the original output value
        stored_result_node = quoting.quote(self.namer.unique('result'))
        assign_stored_result = template.replace(
            'result=orig_result',
            result=stored_result_node,
            orig_result=return_node.value)
        body.append(assign_stored_result)
        # NOTE(review): the same `stored_result_node` object is also handed
        # to the template above; if `template.replace` does not copy it, the
        # ctx reset below also changes that assignment's target -- confirm.
        dx.elts.append(stored_result_node)

    # The gradients are read, not written, in the adjoint's return statement
    for _dx in dx.elts:
        _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return values
    # The adjoint will receive a value for the initial gradient of the cost
    # NOTE(review): `y.id` assumes the returned value (or its first tuple
    # element) is a bare Name; other expressions would fail here.
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
        y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body,
                                return_dx=return_dx)
    adjoint.args.args.extend([self.stack, dy])
    # node.args.args now starts with the stack; [1:] are the original args
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
def substack(self):
    """Return a copy of this object's substack name node.

    The node is created lazily on first use, with a unique name derived from
    `naming.SUBSTACK_NAME`, and cached on `self._substack`. Every call returns
    a copy made with `ast_.copy_node`.
    """
    try:
        cached = self._substack
    except AttributeError:
        cached = quoting.quote(self.namer.unique(naming.SUBSTACK_NAME))
        self._substack = cached
    return ast_.copy_node(cached)
from tangent import comments from tangent import create from tangent import errors from tangent import fixes from tangent import funcsigs from tangent import grads from tangent import naming from tangent import non_differentiable from tangent import quoting from tangent import template from tangent import tracing from tangent import utils # Some AST nodes to fill in to templates that use stacks or reset gradients PUSH = quoting.quote('tangent.push') POP = quoting.quote('tangent.pop') anno.setanno(PUSH, 'push_func', True) anno.setanno(POP, 'pop_func', True) PUSH_STACK = quoting.quote('tangent.push_stack') POP_STACK = quoting.quote('tangent.pop_stack') anno.setanno(PUSH_STACK, 'push_func', True) anno.setanno(POP_STACK, 'pop_func', True) def _generate_op_id(): return quoting.quote("'_{}'".format(uuid4().hex[:8])) def get_push_pop(): """Create pop and push nodes that are linked.
def test_node_replace():
    """Placeholders in a quoted Assign node can be replaced by source strings."""
    assign_node = quoting.quote("a = b")
    replaced = template.replace(assign_node, a="y", b="x * 2")
    assert quoting.unquote(replaced) == "y = x * 2"