Example #1
0
  def primal_and_adjoint_for_tracing(self, node):
    """Create the primal and adjoint statements for a traced function call.

    Both results are produced by filling in the templates stored under
    `tracing.Traceable` in `grads.primals` and `grads.adjoints`.

    Args:
      node: ast.Call node of a function we wish to trace, instead of
        transform. The call must not carry keyword arguments.

    Returns:
      primal: new ast.Assign node to replace the original primal call
      adjoint: new ast.Assign node using the VJP generated in primal to
        calculate the adjoint.
    """
    primal_template = grads.primals[tracing.Traceable]
    adjoint_template = grads.adjoints[tracing.Traceable]

    # Gather the pieces both templates need: the call arguments, a copy of
    # the assignment target, a unique name for the VJP and a temporary.
    call_args = node.args
    result = ast_.copy_node(self.orig_target)
    vjp_name = self.namer.unique('%s_grad' % node.func.id)
    vjp = quoting.quote(vjp_name)
    tmp = create.create_temp(quoting.quote('tmp'), self.namer)
    assert len(node.keywords) == 0

    # The primal call is replaced wholesale by the filled-in template.
    # TODO: do we need to set 'pri_call' on this?
    packed_args = gast.Tuple(elts=call_args, ctx=gast.Load())
    primal = template.replace(
        primal_template,
        namer=self.namer,
        result=result,
        fn=node.func,
        tmp=tmp,
        vjp=vjp,
        args=packed_args)

    # The adjoint unpacks the VJP's output into one temporary gradient per
    # argument of the original call.
    grad_temps = []
    for arg in call_args:
      grad_temps.append(create.create_temp_grad(arg, self.namer))
    grad_targets = gast.Tuple(elts=grad_temps, ctx=gast.Store())

    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        result=result,
        vjp=vjp,
        dargs=grad_targets)

    return primal, adjoint
Example #2
0
  def visit_For(self, node):
    """Build the primal and adjoint of a `for` loop.

    Args:
      node: The loop node to transform. Its body is transformed with
        `self.visit_statements`; loops with an `else` clause are rejected.

    Returns:
      primal: Result of filling in the `grads.primals[gast.For]` template.
      adjoint: Result of filling in the `grads.adjoints[gast.For]` template.

    Raises:
      ValueError: If the loop has an `else` clause.
    """
    if node.orelse:
      # Give the error a message so callers that re-raise it with added
      # context (e.g. via '%s' formatting) do not produce an empty reason.
      raise ValueError('no support for else clause on for loops')

    # Construct the primal and adjoint of the loop
    body, adjoint_body = self.visit_statements(node.body)

    # We create a loop counter which will be pushed on the stack
    push, pop, op_id = get_push_pop()
    counter = self.namer.counter()

    # In `for i in range ...` the variable `i` is the target, which we
    # temporarily set aside each iteration to push to the stack later
    push_target, pop_target, op_id_target = get_push_pop()
    tmp_target = create.create_temp(node.target, self.namer)

    primal_template = grads.primals[gast.For]
    primal = template.replace(
        primal_template,
        body=body,
        i=counter,
        push=push,
        target=node.target,
        iter_=node.iter,
        push_target=push_target,
        _target=tmp_target,
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    # The adjoint receives its own copy of the target node rather than the
    # node object already handed to the primal template.
    adjoint_template = grads.adjoints[gast.For]
    adjoint = template.replace(
        adjoint_template,
        adjoint_body=adjoint_body,
        i=counter,
        pop=pop,
        pop_target=pop_target,
        target=ast_.copy_node(node.target),
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    return primal, adjoint
Example #3
0
  def visit_Assign(self, node):
    """Build the primal and adjoint statements for an assignment.

    The primal stores the target's previous value on the stack before
    performing the assignment; the adjoint restores that value, evaluates
    the adjoint of the right-hand side, resets the target's gradient and
    accumulates the partial gradients found in the right-hand side adjoint.

    Args:
      node: The assignment node. Exactly one target is supported.

    Returns:
      primal: List of statements for the forward pass.
      adjoint: List of statements for the backward pass. When the node is
        active, each one is passed through `comments.add_comment` with the
        original source of the assignment.

    Raises:
      ValueError: If the assignment has more than one target, or if visiting
        the right-hand side raises ValueError (re-raised with the assignment
        targets added to the message).
    """
    if len(node.targets) != 1:
      raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
      orig_src = anno.getanno(node, 'pre_anf')
    else:
      orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if isinstance(node.value, gast.Call) and \
        anno.hasanno(node.value, 'func') and \
        anno.getanno(node.value, 'func') in (utils.Stack, utils.pop_stack):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in the
    # adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
      return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
      fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
      context = [t.id if hasattr(t, 'id') else t for t in node.targets]
      raise ValueError(
          'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
      adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
      for succ in gast.walk(n):
        if anno.hasanno(succ, 'temp_adjoint_var'):
          xi = anno.getanno(succ, 'temp_adjoint_var')
          dxi_partial = ast_.copy_node(succ)
          accumulations.append(template.replace(
              'd[xi] = add_grad(d[xi], dxi_partial)',
              namer=self.namer, replace_grad=template.Replace.FULL,
              xi=xi, dxi_partial=dxi_partial, add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to change.
    # Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
      assign = [fx]
    elif (isinstance(fx, list) and
          any(isinstance(ifx, gast.Assign) for ifx in fx)):
      assign = fx
    else:
      assign = template.replace(
          'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
      assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally accumulating
    # the partials
    adjoint = [create_tmp, restore_substack, restore
              ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
        isinstance(self.orig_target.slice.value, gast.Name)):
      push, pop, op_id = get_push_pop()
      i = self.orig_target.slice.value
      push_index = template.replace(
          'push(_stack, i, op_id)',
          push=push,
          i=i,
          _stack=self.stack,
          op_id=op_id)
      # The keyword has to be spelled `_stack` to match the placeholder in
      # the template string, exactly as in the matching push above.
      pop_index = template.replace(
          'i = pop(_stack, op_id)',
          pop=pop,
          i=i,
          _stack=self.stack,
          op_id=op_id)

      # The index is pushed last in the primal and popped first in the
      # adjoint, mirroring the stack order.
      primal.append(push_index)
      adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for idx, adj in enumerate(adjoint):
      adjoint[idx] = comments.add_comment(adj, 'Grad of: %s' % orig_src)

    return primal, adjoint