Beispiel #1
0
 def visit_Assign(self, node):
     """Prepend `x = 2 * x` when the first assignment target is named `z`.

     Args:
       node: The assignment node being visited.

     Returns:
       The node that was passed in, unmodified.
     """
     # Only plain names carry an `id`. Attribute, subscript and tuple targets
     # (`a.b = ...`, `a[0] = ...`, `a, b = ...`) can never be `z`, so they are
     # passed through instead of raising AttributeError.
     target_name = getattr(node.targets[0], 'id', None)
     # NOTE Without this test, we'd have an infinite loop
     if target_name == 'z':
         statement = quoting.quote("x = 2 * x")
         self.insert_top(statement)
     return node
Beispiel #2
0
  def primal_and_adjoint_for_tracing(self, node):
    """Create the primal/adjoint statement pair for a traced function call.

    The statements are produced by filling in the templates registered under
    `tracing.Traceable` in `grads.primals` and `grads.adjoints`.

    Args:
      node: ast.Call node of a function we wish to trace, instead of transform

    Returns:
      primal: new ast.Assign node to replace the original primal call
      adjoint: new ast.Assign node using the VJP generated in primal to
        calculate the adjoint.
    """
    primal_tmpl = grads.primals[tracing.Traceable]
    adjoint_tmpl = grads.adjoints[tracing.Traceable]

    # Gather the pieces that get substituted into the templates
    call_args = node.args
    result_target = ast_.copy_node(self.orig_target)
    vjp_name = quoting.quote(self.namer.unique('%s_grad' % node.func.id))
    tmp_name = create.create_temp(quoting.quote('tmp'), self.namer)
    # Calls carrying keyword arguments are rejected here
    assert len(node.keywords) == 0

    # The primal call is replaced wholesale by the filled-in template
    # TODO: do we need to set 'pri_call' on this?
    primal = template.replace(
        primal_tmpl,
        namer=self.namer,
        result=result_target,
        fn=node.func,
        tmp=tmp_name,
        vjp=vjp_name,
        args=gast.Tuple(elts=call_args, ctx=gast.Load()))

    # One temporary gradient per call argument; these are the targets that
    # the adjoint assigns to using the vjp created by the primal
    grad_temps = []
    for arg in call_args:
      grad_temps.append(create.create_temp_grad(arg, self.namer))
    grad_targets = gast.Tuple(elts=grad_temps, ctx=gast.Store())

    adjoint = template.replace(
        adjoint_tmpl,
        namer=self.namer,
        result=result_target,
        vjp=vjp_name,
        dargs=grad_targets)

    return primal, adjoint
Beispiel #3
0
def _create_joint(fwdbwd, func, wrt, input_derivative):
    """Turn a joint primal/adjoint definition into a user-facing function.

    By default, gradient functions expect the stack to be passed to them
    explicitly. This function rewrites the definition so that the stack is
    created inside the function body instead of being an argument.

    For consistency, gradient functions always return a tuple, even if the
    gradient of only one input was required. A tuple of length one is
    unpacked here.

    Args:
      fwdbwd: An AST. The function definition of the joint primal and adjoint.
          It is modified in place.
      func: A function handle. The original function that was differentiated.
      wrt: A tuple of integers. The arguments with respect to which we
          differentiated.
      input_derivative: Compared against `INPUT_DERIVATIVE.DefaultOne`; when
          equal, the appended gradient argument gets the default `1.0`.

    Returns:
      The function definition of the new function.
    """
    # A single-element return tuple is collapsed to its only element
    return_stmt = fwdbwd.body[-1]
    returned = return_stmt.value
    if len(returned.elts) == 1:
        return_stmt.value = returned.elts[0]

    # The first parameter names the stack; create it at the top of the body
    generated_params = fwdbwd.args.args
    stack_name = generated_params[0].id
    tape_init = comments.add_comment(
        quoting.quote('%s = tangent.Stack()' % stack_name),
        'Initialize the tape')
    fwdbwd.body = [tape_init] + fwdbwd.body

    # Remember the gradient parameter's name before swapping in the
    # signature of the original function
    grad_name = generated_params[1].id
    fwdbwd.args = quoting.parse_function(func).body[0].args

    # Give the function a nice name
    fwdbwd.name = naming.joint_name(func, wrt)

    # Re-add the initial gradient as a trailing argument, optionally with a
    # default so that callers may omit it
    fwdbwd = ast_.append_args(fwdbwd, [grad_name])
    if input_derivative == INPUT_DERIVATIVE.DefaultOne:
        default_grad = quoting.quote('1.0')
        fwdbwd.args.defaults.append(default_grad)
    return fwdbwd
Beispiel #4
0
  def visit_For(self, node):
    """Rewrite `for a in x` into an explicit index loop when `x` is active.

    Iterators of the form `for a in x` rely on an implicit indexing
    operation, which Tangent cannot reverse without more information, so an
    explicit indexing operation is created instead:

        for a in x:
          f(a)

    becomes

        for i in range(len(x)):
          a = x[i]
          f(a)

    Integer indexes are used, which will cause strange behavior if the
    iterator's `next()` behavior deviates from a plain incrementing index.
    The right thing to do (eventually) is to write a multiple-dispatch
    version of the `next` operator, and its adjoint, so that e.g. dicts can
    be handled.

    Args:
      node: The `For` node being visited.

    Returns:
      The same node, rewritten in place when its iterable is active.
    """
    # Only names, subscripts and attributes are candidates for the rewrite
    if not isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)):
      return node

    # NOTE(review): the `range(len(...))` call below is built from this name
    # while the item access uses a copy of the full iterable expression;
    # confirm the two agree for Subscript and Attribute iterables.
    iter_name = ast.get_name(node.iter)
    if iter_name not in anno.getanno(node, 'active_in'):
      return node

    # Keep copies of the original loop variable and iterable, and pick a
    # unique name for the new index variable
    loop_var = copy.deepcopy(node.target)
    index = quoting.quote(self.namer.unique('_idx'))
    iterable = copy.deepcopy(node.iter)

    item_access = template.replace(
      'old_target = x[i]',
      old_target=loop_var,
      x=iterable,
      i=index)

    # Loop over indexes and fetch the item as the first body statement
    node.target = gast.Name(id=index.id, ctx=gast.Store(), annotation=None)
    range_call = quoting.quote('range(len(%s))' % iter_name)
    node.iter = range_call
    anno.setanno(range_call, 'func', range)
    anno.setanno(range_call.args[0], 'func', len)
    node.body = [item_access] + node.body

    return node
Beispiel #5
0
 def tmp_node(self):
   """Return the cached temporary node, creating it on first use.

   The node is built by quoting a unique name derived from 'tmp' and is
   stored on `self._tmp_node` so later calls return the same object.
   """
   cached = self._tmp_node
   if cached is None:
     cached = quoting.quote(self.namer.unique('tmp'))
     self._tmp_node = cached
   return cached
Beispiel #6
0
def append_args(node, node_list):
    """Append arguments to `node` by running `ArgAppend` over it.

    Args:
      node: The AST node handed to `ArgAppend(...).visit`.
      node_list: A list of argument nodes and/or strings. Every string is
          turned into a node with `quoting.quote`; other entries are passed
          through untouched.

    Returns:
      The result of `ArgAppend(node_list).visit(node)`.

    Raises:
      TypeError: If `node_list` is not a list.
    """
    if not isinstance(node_list, list):
        raise TypeError('Please pass in a list')
    # Convert element by element so that a list mixing names and ready-made
    # nodes does not hand raw strings to the transformer.
    node_list = [
        quoting.quote(n) if isinstance(n, str) else n for n in node_list
    ]
    return ArgAppend(node_list).visit(node)
Beispiel #7
0
"""
from __future__ import absolute_import
from __future__ import division

from copy import copy as native_copy
import types

import autograd
import numpy
import six
from tangent import annotations as anno
from tangent import non_differentiable
from tangent import quoting


# Quoted references to the gradient helper functions. Each node is annotated
# so that it can be recognised later by its annotation key.
INIT_GRAD = quoting.quote('tangent.init_grad')
anno.setanno(INIT_GRAD, 'init_grad', True)

ADD_GRAD = quoting.quote('tangent.add_grad')
anno.setanno(ADD_GRAD, 'add_grad', True)


def array_size(x, axis):
  """Calculate the size of `x` along `axis` dimensions only.

  Args:
    x: An object exposing a `shape` attribute (e.g. a NumPy array).
    axis: `None` to use every dimension, a single integer axis, or an
        iterable of integer axes.

  Returns:
    The product of the selected dimension lengths, clamped to at least 1 so
    that the result can safely be used as a divisor.
  """
  if axis is None:
    axis_shape = x.shape
  else:
    # A scalar axis such as `axis=0` is not iterable; treat it as a
    # one-element collection of axes.
    if numpy.ndim(axis) == 0:
      axis = (axis,)
    axis_shape = tuple(x.shape[a] for a in axis)
  return max(numpy.prod(axis_shape), 1)


class Stack(object):
  """A stack type that proxies list's `append` and `pop` methods.

  We don't use list directly so that we can test its type for the multiple-
Beispiel #8
0
def test_function_compile():
  """`compile_function` rejects inputs that it cannot compile.

  A node produced by `quoting.quote` is expected to raise TypeError, and a
  module parsed from a plain assignment is expected to raise ValueError.
  """
  with pytest.raises(TypeError):
    quoted_statement = quoting.quote('x = y')
    compile_.compile_function(quoted_statement)

  with pytest.raises(ValueError):
    parsed_module = gast.parse('x = y')
    compile_.compile_function(parsed_module)
Beispiel #9
0
def replace(template,
            replace_grad=Replace.PARTIAL,
            namer=None,
            **replacements):
    """Substitute placeholders in a Python template (quote).

    Args:
      template: A function, AST node or string to be used as a template.
          Note that if a function is passed, any placeholder is expected to
          also be a function argument. If a string is passed, it must
          represent valid Python code, and any variable it references is a
          placeholder.
      replace_grad: If Replace.NONE, statements of the form `d[x]` are
          ignored. For the other possible values, see
          `ReplaceGradTransformer`.
      namer: See `ReplaceGradTransformer`.
      **replacements: A mapping from placeholder names to (lists of) AST
          nodes that these placeholders will be replaced by. If a string is
          passed, `quote` will be called on it to turn it into a node.

    Returns:
      body: An AST node or list of AST nodes with the replacements made. If
          the template was a function, a list will be returned. If the
          template was a node, the same node will be returned. If the
          template was a string, an AST node will be returned (a `Module`
          node in the case of a multi-line string, an `Expr` node otherwise).

    Raises:
      ValueError: If a function is used as a template and an incorrect set
          of replacements was passed.
    """
    # Templates come in three flavours: functions, AST nodes and strings
    from_function = isinstance(template, types.FunctionType)
    if from_function:
        tree = quoting.parse_function(template).body[0]
        signature = tree.args
        # Every parameter of a function template is a placeholder; the
        # parameters themselves are stripped from the definition
        placeholders = {arg.id for arg in signature.args}
        signature.args = []
        if signature.vararg:
            # NOTE(review): the vararg is added as-is; confirm that it
            # compares equal to the corresponding replacement key.
            placeholders.add(signature.vararg)
            signature.vararg = None
        if placeholders != set(replacements):
            raise ValueError('too many or few replacements')
    elif isinstance(template, gast.AST):
        tree = template
    else:
        tree = quoting.quote(template, return_expr=True)

    # String replacements are quoted into nodes first
    for name in list(replacements):
        value = replacements[name]
        if isinstance(value, six.string_types):
            replacements[name] = quoting.quote(value)

    # Substitute the placeholders
    ReplaceTransformer(replacements).visit(tree)

    # Expand the d[x] operator unless that was switched off
    if replace_grad is not Replace.NONE:
        use_tangent = replace_grad is Replace.TANGENT
        grad_transformer = ReplaceGradTransformer(replace_grad=replace_grad,
                                                  namer=namer,
                                                  tangent=use_tangent)
        grad_transformer.visit(tree)

    # Function templates yield their body, everything else the tree itself
    return tree.body if from_function else tree
Beispiel #10
0
 def visit(self, node):
     """Insert a `<name> = None` statement for pushed variables not yet defined.

     When `node` carries a 'push_var' annotation and the annotated variable's
     name is missing from the node's 'defined_in' annotation, an assignment of
     `None` to that name is inserted via `insert_top`. Visiting then
     continues with the parent class implementation.
     """
     if anno.hasanno(node, 'push_var'):
         pushed = anno.getanno(node, 'push_var')
         varname = ast_.get_name(pushed)
         already_defined = anno.getanno(node, 'defined_in')
         if varname not in already_defined:
             placeholder = quoting.quote('{} = None'.format(varname))
             self.insert_top(placeholder)
     return super(FixStack, self).visit(node)
Beispiel #11
0
def store_state(node, reaching, defined, stack):
  """Push the final state of the primal onto the stack for the adjoint.

  Python's scoping rules make it possible for variables to not be defined in
  certain blocks based on the control flow path taken at runtime. In order to
  make sure we don't try to push non-existing variables onto the stack, we
  define these variables explicitly (by assigning `None` to them) at the
  beginning of the function.

  All the variables that reach the return statement are pushed onto the
  stack, and in the adjoint they are popped off in reverse order.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`.
    reaching: The variable definitions that reach the end of the primal, as
        `(name, definition node)` pairs.
    defined: The set of variables defined at the end of the primal.
    stack: The stack node to use for storing and restoring state.

  Returns:
    node: A node with the requisite pushes and pops added to make sure that
        state is transferred between primal and adjoint split motion calls.
        The input is returned untouched when nothing needs to be stored.
  """
  # Definitions that are function arguments do not need to be stored
  defs = [def_ for def_ in reaching if not isinstance(def_[1], gast.arguments)]
  if not defs:
    return node
  reaching, original_defs = zip(*defs)

  # Explicitly define variables that might or might not be in scope at the
  # end. The names are sorted so that the generated code does not depend on
  # set iteration order and is identical from run to run.
  assignments = [
      quoting.quote('{} = None'.format(id_))
      for id_ in sorted(set(reaching) - defined)
  ]

  # Store variables at the end of the function and restore them
  store = []
  load = []
  for id_, def_ in zip(reaching, original_defs):
    # If the original definition of a value that we need to store
    # was an initialization as a stack, then we should be using `push_stack`
    # to store its state, and `pop_stack` to restore it. This allows
    # us to avoid doing any `add_grad` calls on the stack, which result
    # in type errors in unoptimized mode (they are usually elided
    # after calling `dead_code_elimination`).
    if isinstance(
        def_, gast.Assign) and 'tangent.Stack()' in quoting.unquote(def_.value):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    store.append(
        template.replace(
            'push(_stack, val, op_id)',
            push=push,
            val=id_,
            _stack=stack,
            op_id=op_id))
    load.append(
        template.replace(
            'val = pop(_stack, op_id)',
            pop=pop,
            val=id_,
            _stack=stack,
            op_id=op_id))

  # Primal: placeholders first, pushes right before the return statement.
  # Adjoint: pops in reverse push order, ahead of the existing body.
  body, return_ = node.body[0].body[:-1], node.body[0].body[-1]
  node.body[0].body = assignments + body + store + [return_]
  node.body[1].body = load[::-1] + node.body[1].body

  return node
Beispiel #12
0
def _generate_op_id():
  """Return a quoted string literal `'_xxxxxxxx'` built from a fresh UUID.

  The eight characters after the underscore are the first eight hex digits
  of `uuid4()`.
  """
  short_hex = uuid4().hex[:8]
  literal = "'_{}'".format(short_hex)
  return quoting.quote(literal)
Beispiel #13
0
  def visit_FunctionDef(self, node):
    """Build the primal and adjoint definitions for a function.

    The incoming definition is turned into the primal in place (renamed, with
    the stack as first argument), and a separate adjoint definition is
    created from the `gast.FunctionDef` adjoint template.

    Args:
      node: The function definition to differentiate. It must carry a 'func'
          annotation and end with its only return statement.

    Returns:
      A `(primal, adjoint)` pair of function definitions, where `primal` is
      the modified input node.

    Raises:
      ValueError: If the function contains more than one return statement or
          its last statement is not a return.
    """
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or not isinstance(node.body[-1], gast.Return)):
      raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    # (everything except the trailing return statement)
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
      body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
      adjoint_body[0] = comments.add_comment(
          adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt], ctx=gast.Load())

    if self.preserve_result:
      # Append an extra Assign operation to the primal body
      # that saves the original output value
      stored_result_node = quoting.quote(self.namer.unique('result'))
      assign_stored_result = template.replace(
          'result=orig_result',
          result=stored_result_node,
          orig_result=return_node.value)
      body.append(assign_stored_result)
      # The saved result is returned by the adjoint after the gradients
      dx.elts.append(stored_result_node)

    # Everything returned by the adjoint is read, never assigned
    for _dx in dx.elts:
      _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return values
    # The adjoint will receive a value for the initial gradient of the cost
    # NOTE(review): `y.id` assumes the returned value (or the first tuple
    # element) is a plain name -- confirm against earlier passes.
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
      y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body, return_dx=return_dx)
    # Adjoint signature: stack, initial gradient, then the original
    # arguments (the primal's arguments minus its leading stack)
    adjoint.args.args.extend([self.stack, dy])
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
Beispiel #14
0
 def substack(self):
   """Return a fresh copy of the substack name node.

   The node is created on first access from a unique name based on
   `naming.SUBSTACK_NAME` and cached on `self._substack`; every call hands
   out a copy made with `ast_.copy_node`.
   """
   if not hasattr(self, '_substack'):
     unique_name = self.namer.unique(naming.SUBSTACK_NAME)
     self._substack = quoting.quote(unique_name)
   cached = self._substack
   return ast_.copy_node(cached)
Beispiel #15
0
from tangent import comments
from tangent import create
from tangent import errors
from tangent import fixes
from tangent import funcsigs
from tangent import grads
from tangent import naming
from tangent import non_differentiable
from tangent import quoting
from tangent import template
from tangent import tracing
from tangent import utils


# AST nodes to fill in to templates that use stacks. Each quoted function
# reference is annotated as a push or pop function right after it is created.
PUSH = quoting.quote('tangent.push')
anno.setanno(PUSH, 'push_func', True)
POP = quoting.quote('tangent.pop')
anno.setanno(POP, 'pop_func', True)

# The same pair for the stack-specific push and pop functions
PUSH_STACK = quoting.quote('tangent.push_stack')
anno.setanno(PUSH_STACK, 'push_func', True)
POP_STACK = quoting.quote('tangent.pop_stack')
anno.setanno(POP_STACK, 'pop_func', True)


def _generate_op_id():
  """Return a quoted string literal `'_xxxxxxxx'` built from a fresh UUID.

  The eight characters after the underscore are the first eight hex digits
  of `uuid4()`.
  """
  short_hex = uuid4().hex[:8]
  literal = "'_{}'".format(short_hex)
  return quoting.quote(literal)


def get_push_pop():
  """Create pop and push nodes that are linked.
Beispiel #16
0
def test_node_replace():
    """Replacing both names in a quoted assignment yields the new source."""
    assignment = quoting.quote("a = b")
    replaced = template.replace(assignment, a="y", b="x * 2")
    source = quoting.unquote(replaced)
    assert source == "y = x * 2"