Example #1
    def visit_Expr(self, node):
        # We need to special-case pushes, e.g. utils.push(_stack,x,op_id)
        adjoint = []
        if (isinstance(node.value, gast.Call)
                and anno.getanno(node.value, 'func') == utils.push):
            orig_src = quoting.unquote(node)
            stack, val, op_id = node.value.args
            push_template = grads.adjoints[utils.push]
            adjoint_rhs = template.replace(push_template,
                                           namer=self.namer,
                                           stack=stack,
                                           val=val,
                                           op_id=op_id)

            # Walk the RHS adjoint AST to find temporary adjoint variables to
            # sum
            accumulation = template.replace(
                'd[xi] = add_grad(d[xi], dxi_partial)',
                namer=self.namer,
                replace_grad=template.Replace.FULL,
                xi=val,
                dxi_partial=ast_.copy_node(adjoint_rhs[0].targets[0]),
                add_grad=utils.ADD_GRAD)
            adjoint = adjoint_rhs + [accumulation]
            for i, adj in enumerate(adjoint):
                adjoint[i] = comments.add_comment(adj,
                                                  'Grad of: %s' % orig_src)
        return node, adjoint
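
This visitor instantiates the push adjoint from a string template and then accumulates the temporary partial into the running gradient. Below is a minimal standalone sketch of the name-substitution idea behind template.replace, using only the standard ast module; it is illustrative, not tangent's implementation (which also resolves d[...] gradient names through the namer).

import ast

def fill_template(src, **replacements):
    # Replace bare names in a parsed template with other names
    class Sub(ast.NodeTransformer):
        def visit_Name(self, node):
            node.id = replacements.get(node.id, node.id)
            return node
    tree = Sub().visit(ast.parse(src))
    return ast.fix_missing_locations(tree).body

stmt, = fill_template('dx = add_grad(dx, partial)',
                      dx='b_y', partial='b_y_partial')
print(ast.unparse(stmt))  # b_y = add_grad(b_y, b_y_partial)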
Example #2
  def visit_Assign(self, node):
    """Visit assignment statement.

    Notes
    -----
    This method sets the `self.target` attribute to the first assignment
    target for callees to use.
    """

    # Generate tangent name of the assignment
    self.target = node.targets[0]
    self.value = node.value

    # Get the tangent node
    tangent_node = self.visit(self.value)

    # If no forward-mode statement was created, then no extra work is needed.
    if tangent_node == self.value:
      self.target = None
      return node

    if self.value:
      new_node = template.replace(
          tangents.tangents[gast.Assign],
          replace_grad=template.Replace.TANGENT,
          namer=self.namer,
          temp=self.tmp_node,
          tangent=tangent_node,
          target=self.target,
          value=self.value)
      # Ensure that we use a unique tmp node per primal/tangent pair
      self.reset_tmp_node()
    else:
      # The code is already in ANF form, so we can assume
      # the LHS is a single Name "z"
      def template_(z, f):
        tmp = f
        d[z] = tmp[0]
        z = tmp[1]

      new_node = template.replace(
          template_,
          replace_grad=template.Replace.TANGENT,
          namer=self.namer,
          z=self.target,
          f=tangent_node[0])

    # The statement has been fully processed, so clear the target
    self.target = None

    # Add comments tying the new statements back to the original
    for i in range(len(new_node)):
      new_node[i] = comments.add_comment(new_node[i], 'Primal and tangent of: '
                                         '%s' % quoting.unquote(node))
    return new_node
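
When the RHS has no registered tangent template, the method falls back to a small Python function used as a template (template_ above). The sketch below shows the bare mechanics of lifting a function's body into substitutable AST statements with the standard library; it is an illustration, not tangent's template.replace, which additionally resolves the d[z] gradient name through the namer.

import ast, inspect, textwrap

def function_template_body(fn, **names):
    # Parse the template function's source and rename placeholder variables
    src = textwrap.dedent(inspect.getsource(fn))
    body = ast.parse(src).body[0].body
    class Sub(ast.NodeTransformer):
        def visit_Name(self, node):
            node.id = names.get(node.id, node.id)
            return node
    return [Sub().visit(stmt) for stmt in body]

def template_(z, f):
    tmp = f
    z = tmp[1]

for stmt in function_template_body(template_, z='out', f='fwd_result'):
    print(ast.unparse(stmt))
# tmp = fwd_result
# out = tmp[1]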
Example #3
def _create_joint(fwdbwd, func, wrt, input_derivative):
    """Create a user-friendly gradient function.

    By default, gradient functions expect the stack to be passed to them
    explicitly. This function modifies the function so that the stack doesn't
    need to be passed and gets initialized in the function body instead.

    For consistency, gradient functions always return a tuple, even if the
    gradient of only one input was required. We unpack the tuple if it is of
    length one.

    Args:
      fwdbwd: An AST. The function definition of the joint primal and adjoint.
      func: A function handle. The original function that was differentiated.
      wrt: A tuple of integers. The arguments with respect to which we
        differentiated.
      input_derivative: An INPUT_DERIVATIVE enum value describing how the
        initial gradient is supplied; for DefaultOne, a default seed of 1.0
        is appended to the argument list (see below).

    Returns:
      The function definition of the new function.
    """
    # Correct return to be a non-tuple if there's only one element
    retval = fwdbwd.body[-1]
    if len(retval.value.elts) == 1:
        retval.value = retval.value.elts[0]

    # Make a stack init statement
    init_stack = quoting.quote('%s = tangent.Stack()' % fwdbwd.args.args[0].id)
    init_stack = comments.add_comment(init_stack, 'Initialize the tape')

    # Prepend the stack init to the top of the function
    fwdbwd.body = [init_stack] + fwdbwd.body

    # Replace the function arguments with the original ones
    grad_name = fwdbwd.args.args[1].id
    fwdbwd.args = quoting.parse_function(func).body[0].args

    # Give the function a nice name
    fwdbwd.name = naming.joint_name(func, wrt)

    # Allow the initial gradient to be passed as a keyword argument
    fwdbwd = ast_.append_args(fwdbwd, [grad_name])
    if input_derivative == INPUT_DERIVATIVE.DefaultOne:
        fwdbwd.args.defaults.append(quoting.quote('1.0'))
    return fwdbwd
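
The function above is pure AST surgery on a FunctionDef: unwrap a one-element tuple return, prepend the stack initializer, swap in the original argument list, rename, and append a defaulted seed-gradient argument. A standalone sketch of the same kinds of edits with the standard ast module (names here are illustrative, not tangent's):

import ast

fn = ast.parse('def f(x):\n    return (y,)').body[0]

# Prepend an initializer, as _create_joint does with the tape stack
fn.body = ast.parse('_stack = Stack()').body + fn.body

# Unwrap a one-element tuple return, mirroring the retval fix-up above
ret = fn.body[-1]
if isinstance(ret.value, ast.Tuple) and len(ret.value.elts) == 1:
    ret.value = ret.value.elts[0]

# Rename and append an argument with a default, like the 1.0 gradient seed
fn.name = 'df'
fn.args.args.append(ast.arg(arg='init_grad'))
fn.args.defaults.append(ast.Constant(value=1.0))

print(ast.unparse(ast.fix_missing_locations(fn)))
# def df(x, init_grad=1.0):
#     _stack = Stack()
#     return y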
Example #4
  def visit_With(self, node):
    """Deal with the special `with insert_grad_of(x)` statement."""
    if ast_.is_insert_grad_of_statement(node):
      primal = []
      adjoint = node.body
      if isinstance(adjoint[0], gast.With):
        _, adjoint = self.visit(adjoint[0])
      node.body[0] = comments.add_comment(node.body[0], 'Inserted code')
      # Rename the gradients
      replacements = {}
      for item in node.items:
        if (not isinstance(item.context_expr.args[0], gast.Name) or
            not isinstance(item.optional_vars, gast.Name)):
          raise ValueError(
              'insert_grad_of requires plain variable names')
        replacements[item.optional_vars.id] = create.create_grad(
            item.context_expr.args[0], self.namer)
      template.ReplaceTransformer(replacements).visit(node)
      return primal, adjoint
    else:
      return node, []
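
For reference, this is roughly the user-facing statement the visitor handles, assuming tangent's public API (tangent.grad, tangent.insert_grad_of) as shown in the project README. Both the argument and the `as` target must be plain names, and the body is spliced into the backward pass with the gradient name substituted for the `as` target:

import tangent

def f(x):
    a = x * x
    with tangent.insert_grad_of(a) as da:
        print('gradient of a:', da)  # executes during the backward pass
    return a

df = tangent.grad(f)
df(2.0)  # prints the gradient of a, then returns df/dx = 4.0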
Example #5
def test_comment():
    # The function under test; the assertions below assume its body is a
    # single assignment `y = x` followed by a return
    def f(x):
        y = x
        return y

    node = quoting.parse_function(f).body[0]

    comments.add_comment(node.body[0], 'foo', 'above')
    source = quoting.to_source(node)
    lines = source.split('\n')
    assert lines[1].strip() == '# foo'

    comments.add_comment(node.body[0], 'foo', 'right')
    source = quoting.to_source(node)
    lines = source.split('\n')
    assert lines[1].strip() == 'y = x # foo'

    comments.add_comment(node.body[0], 'foo', 'below')
    source = quoting.to_source(node)
    lines = source.split('\n')
    assert lines[2].strip() == '# foo'
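
A standalone sketch of the behavior this test pins down: attach a comment with an above/right/below position to a statement and honor it when emitting source. This is not tangent's implementation, which stores the comment as an annotation and renders it in quoting.to_source.

import ast

def add_comment(node, text, location):
    node._comment = (text, location)  # stash the comment on the node
    return node

def to_source(tree):
    lines = []
    for stmt in tree.body:
        text, loc = getattr(stmt, '_comment', (None, None))
        if loc == 'above':
            lines.append('# ' + text)
        lines.append(ast.unparse(stmt) + (' # ' + text if loc == 'right' else ''))
        if loc == 'below':
            lines.append('# ' + text)
    return '\n'.join(lines)

tree = ast.parse('y = x')
add_comment(tree.body[0], 'foo', 'right')
print(to_source(tree))  # y = x # foo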
Example #6
  def visit_Assign(self, node):
    """Visit assignment statement."""
    if len(node.targets) != 1:
      raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
      orig_src = anno.getanno(node, 'pre_anf')
    else:
      orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If the RHS itself creates or pops a stack, use the stack-specific
    # push/pop helpers so that the nesting of stacks is explicit
    if (isinstance(node.value, gast.Call) and
        anno.hasanno(node.value, 'func') and
        anno.getanno(node.value, 'func') in (utils.Stack, utils.pop_stack)):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-assignment
    # value; in the adjoint we restore it and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
      return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
      fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
      context = [t.id if hasattr(t, 'id') else t for t in node.targets]
      raise ValueError(
          'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
      adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
      for succ in gast.walk(n):
        if anno.hasanno(succ, 'temp_adjoint_var'):
          xi = anno.getanno(succ, 'temp_adjoint_var')
          dxi_partial = ast_.copy_node(succ)
          accumulations.append(template.replace(
              'd[xi] = add_grad(d[xi], dxi_partial)',
              namer=self.namer, replace_grad=template.Replace.FULL,
              xi=xi, dxi_partial=dxi_partial, add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to change.
    # Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
      assign = [fx]
    elif (isinstance(fx, list) and
          any([isinstance(ifx, gast.Assign) for ifx in fx])):
      assign = fx
    else:
      assign = template.replace(
          'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
      assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally accumulating
    # the partials
    adjoint = [create_tmp, restore_substack, restore
              ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
        isinstance(self.orig_target.slice.value, gast.Name)):
      push, pop, op_id = get_push_pop()
      i = self.orig_target.slice.value
      push_index = template.replace(
          'push(_stack, i, op_id)',
          push=push,
          i=i,
          _stack=self.stack,
          op_id=op_id)
      pop_index = template.replace(
          'i = pop(_stack, op_id)',
          pop=pop,
          i=i,
          _stack=self.stack,
          op_id=op_id)

      primal.append(push_index)
      adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for i, adj in enumerate(adjoint):
      adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)

    return primal, adjoint
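
At runtime, the generated code follows a simple store/restore discipline around every assignment: push the about-to-be-overwritten value in the forward pass, pop it to restore state in the backward pass. A hand-written equivalent for a single assignment (tangent emits this automatically; the names are illustrative):

def primal(x, stack):
    y = 0.0
    stack.append(y)    # store: push the pre-assignment value of y
    y = x * x          # the original assignment
    return y

def adjoint(dy, x, stack):
    y = stack.pop()    # restore: pop the pre-assignment value of y
    dx = dy * 2.0 * x  # adjoint of y = x * x
    return dx

stack = []
print(primal(3.0, stack))        # 9.0
print(adjoint(1.0, 3.0, stack))  # 6.0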
Example #7
  def visit_FunctionDef(self, node):
    # Construct a namer to guarantee we create unique names that don't
    # clash with existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if len(return_nodes) > 1 or not isinstance(node.body[-1], gast.Return):
      raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
      body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
      adjoint_body[0] = comments.add_comment(
          adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt], ctx=gast.Load())

    if self.preserve_result:
      # Append an extra Assign operation to the primal body
      # that saves the original output value
      stored_result_node = quoting.quote(self.namer.unique('result'))
      assign_stored_result = template.replace(
          'result=orig_result',
          result=stored_result_node,
          orig_result=return_node.value)
      body.append(assign_stored_result)
      dx.elts.append(stored_result_node)

    for _dx in dx.elts:
      _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as the first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost, i.e. the first of potentially multiple return values.
    # The adjoint will receive a value for the initial gradient of the cost
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
      y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body, return_dx=return_dx)
    adjoint.args.args.extend([self.stack, dy])
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
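
The result is a pair of functions with a fixed calling convention: the primal takes the stack as its first argument, records intermediates on it, and returns the function's result; the adjoint takes the stack, the seed gradient of the cost, and the original arguments, and returns the input gradients as a tuple. A hand-written sketch of such a pair for f(x) = x * x (function and variable names are illustrative, not the exact output of naming.primal_name and naming.adjoint_name):

def pri_f(_stack, x):
    # Forward pass: compute the result and record what the adjoint needs
    _stack.append(x)
    y = x * x
    return y

def adj_f(_stack, dy, x):
    # Backward pass: consume the tape and return one gradient per
    # argument we differentiated with respect to
    x = _stack.pop()
    dx = dy * 2.0 * x
    return (dx,)

stack = []
y = pri_f(stack, 3.0)
dx, = adj_f(stack, 1.0, 3.0)
print(y, dx)  # 9.0 6.0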