Exemplo n.º 1
0
    def visit_Expr(self, node):
        """Build the adjoint of an expression statement.

        Only explicit pushes such as ``utils.push(_stack, x, op_id)`` receive
        an adjoint; any other expression statement gets an empty one.

        Args:
          node: The expression statement being visited.

        Returns:
          A ``(node, adjoint)`` pair: the statement as passed in and a list of
          adjoint statements, which is empty unless the statement is a push.
        """
        call = node.value
        is_push = (isinstance(call, gast.Call)
                   and anno.getanno(call, 'func') == utils.push)
        if not is_push:
            return node, []

        # The source text of the push labels every adjoint statement below
        grad_comment = 'Grad of: %s' % quoting.unquote(node)
        stack, val, op_id = call.args

        # Fill in the adjoint template registered for pushes
        push_adjoint = template.replace(grads.adjoints[utils.push],
                                        namer=self.namer,
                                        stack=stack,
                                        val=val,
                                        op_id=op_id)

        # The target of the first adjoint statement is the partial that has
        # to be summed into the gradient of the pushed value
        partial = ast_.copy_node(push_adjoint[0].targets[0])
        accumulate = template.replace(
            'd[xi] = add_grad(d[xi], dxi_partial)',
            namer=self.namer,
            replace_grad=template.Replace.FULL,
            xi=val,
            dxi_partial=partial,
            add_grad=utils.ADD_GRAD)

        return node, [comments.add_comment(stmt, grad_comment)
                      for stmt in push_adjoint + [accumulate]]
Exemplo n.º 2
0
  def primal_and_adjoint_for_tracing(self, node):
    """Create the primal and adjoint for a call that is traced, not transformed.

    Both results come from filling in the templates registered under
    `tracing.Traceable`; the VJP name handed to the primal template is the
    same one the adjoint template receives.

    Args:
      node: Call node of the function to trace. Its `keywords` must be empty.

    Returns:
      primal: the filled-in primal template, replacing the original call.
      adjoint: the filled-in adjoint template, which uses the VJP named in the
        primal and one temporary gradient per call argument.
    """
    traceable = tracing.Traceable
    primal_tpl = grads.primals[traceable]
    adjoint_tpl = grads.adjoints[traceable]

    # Gather everything the two templates need
    call_args = node.args
    result = ast_.copy_node(self.orig_target)
    vjp_name = quoting.quote(self.namer.unique('%s_grad' % node.func.id))
    scratch = create.create_temp(quoting.quote('tmp'), self.namer)
    assert len(node.keywords) == 0

    # The primal is replaced wholesale by the template
    # TODO: do we need to set 'pri_call' on this?
    packed_args = gast.Tuple(elts=call_args, ctx=gast.Load())
    primal = template.replace(primal_tpl, namer=self.namer, result=result,
                              fn=node.func, tmp=scratch, vjp=vjp_name,
                              args=packed_args)

    # The adjoint stores into one temporary gradient per argument
    arg_grads = [create.create_temp_grad(arg, self.namer) for arg in call_args]
    packed_grads = gast.Tuple(elts=arg_grads, ctx=gast.Store())
    adjoint = template.replace(adjoint_tpl, namer=self.namer, result=result,
                               vjp=vjp_name, dargs=packed_grads)

    return primal, adjoint
Exemplo n.º 3
0
  def visit_For(self, node):
    """Build the primal and adjoint of a `for` loop.

    Args:
      node: The loop to differentiate.

    Returns:
      A `(primal, adjoint)` pair obtained by filling in the primal and adjoint
      templates registered under `gast.For`.

    Raises:
      ValueError: If the loop has an `else` clause.
    """
    if node.orelse:
      raise ValueError

    # Differentiate the statements inside the loop first
    fwd_body, bwd_body = self.visit_statements(node.body)

    # The loop counter gets its own push/pop pair
    push_iter, pop_iter, iter_op_id = get_push_pop()
    loop_counter = self.namer.counter()

    # So does the loop target (the `i` in `for i in range ...`), which is
    # set aside in a temporary each iteration and pushed to the stack later
    push_tgt, pop_tgt, target_op_id = get_push_pop()
    saved_target = create.create_temp(node.target, self.namer)

    primal = template.replace(
        grads.primals[gast.For],
        body=fwd_body,
        i=loop_counter,
        push=push_iter,
        target=node.target,
        iter_=node.iter,
        push_target=push_tgt,
        _target=saved_target,
        _stack=self.stack,
        op_id_iter=iter_op_id,
        op_id_target=target_op_id)

    adjoint = template.replace(
        grads.adjoints[gast.For],
        adjoint_body=bwd_body,
        i=loop_counter,
        pop=pop_iter,
        pop_target=pop_tgt,
        target=ast_.copy_node(node.target),
        _stack=self.stack,
        op_id_iter=iter_op_id,
        op_id_target=target_op_id)

    return primal, adjoint
Exemplo n.º 4
0
 def visit_Return(self, node):
   """Rewrite a return statement so that it returns gradient nodes.

   Args:
     node: The return statement; it is modified in place.

   Returns:
     The same node. Its value is a tuple holding, for each returned value,
     the node produced by `create.create_grad(..., tangent=True)`, followed
     by a copy of the original return value if `self.preserve_result` is
     set.

   Raises:
     ValueError: If the returned value is not a name, subscript or tuple.
   """
   # Copy the value before the tuple case below mutates it
   original = ast_.copy_node(node.value)
   value = node.value

   def to_grad(expr):
     return create.create_grad(expr, self.namer, tangent=True)

   if isinstance(value, gast.Tuple):
     value.elts = [to_grad(elt) for elt in value.elts]
   elif isinstance(value, (gast.Name, gast.Subscript)):
     # A single return value is wrapped in a one-element tuple
     value = gast.Tuple(elts=[to_grad(value)], ctx=gast.Load())
   else:
     raise ValueError

   for elt in value.elts:
     elt.ctx = gast.Load()
   if self.preserve_result:
     value.elts.append(original)

   node.value = value
   return node
Exemplo n.º 5
0
 def visit_Name(self, node):
     """Swap a name for its registered replacement, if it has one.

     Args:
       node: The name node being visited.

     Returns:
       `node` itself when its id has no entry in `self.replacements`.
       Otherwise the replacement: a single node, or the registered sequence
       when it holds more than one node. Each replacement node is annotated
       with the original node under 'replacement' and takes over its `ctx`
       if it has such a field.
     """
     name = node.id
     if name not in self.replacements:
         return node

     # NOTE In principle we don't want to copy, because it might break
     # references held in annotations, but a name that has been replaced
     # before gets a copy so that no node ends up in the tree twice
     replacement = self.replacements[name]
     if name in self.seen:
         replacement = ast_.copy_node(replacement)
     else:
         self.seen.add(name)

     nodes = [replacement] if isinstance(replacement, gast.AST) else replacement
     for new_node in nodes:
         anno.setanno(new_node, 'replacement', node, safe=False)
         if 'ctx' in new_node._fields:
             new_node.ctx = node.ctx

     if len(nodes) == 1:
         return nodes[0]
     return nodes
Exemplo n.º 6
0
  def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.

    Args:
      node: The call node; its 'func' annotation holds the called function.

    Returns:
      A `(primal, adjoint)` pair, depending on the called function:
        * non-differentiable function: `(node, [])`;
        * `tracing.Traceable`: whatever `primal_and_adjoint_for_tracing`
          returns;
        * no registered adjoint: a call to the function's primal name and a
          list of statements calling its adjoint name and unpacking the
          result into the gradients of the active arguments;
        * registered adjoint template: `node` itself and the filled-in
          template, surrounded by packing/unpacking assignments when the
          template takes `*args`.

    Raises:
      errors.ReverseNotImplementedError: If the function is listed in
        `grads.UNIMPLEMENTED_ADJOINTS`.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
      return node, []

    if func == tracing.Traceable:
      return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
      raise errors.ReverseNotImplementedError(func)


    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
      # Positions of the arguments that are active variables.
      # NOTE(review): reads `arg.id`, so every argument is assumed to be a
      # name node here -- confirm against the caller.
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if arg.id in self.active_variables)

      # Record (func, active_args) as required, unless an entry with the same
      # function name and the same set of active arguments already exists
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      # The primal call takes the substack as an extra first argument
      pri_name = naming.primal_name(func, active_args)
      pri_call = gast.Call(
          func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
          args=[self.substack] + node.args,
          keywords=node.keywords)
      anno.setanno(pri_call, 'pri_call', True)

      # dy is the gradient of the assignment target, read by the adjoint call
      dy = create.create_grad(self.target, self.namer)
      dy.ctx = gast.Load()
      # NOTE(review): dx is built here but not referenced again in this
      # method -- confirm whether it is still needed.
      dx = create.create_grad(node.args[0], self.namer)
      dx.ctx = gast.Store()
      # The adjoint call takes the substack and dy ahead of the original
      # arguments
      adj_name = naming.adjoint_name(func, active_args)
      adj_call = gast.Call(
          func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
          args=[self.substack, dy] + node.args,
          keywords=node.keywords)
      anno.setanno(adj_call, 'adj_call', True)
      adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
      # The j-th entry of the adjoint call's result goes to the gradient of
      # the j-th active argument (argument number i of the call)
      for j, i in enumerate(active_args):
        adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                        x=node.args[i].id, i=gast.Num(n=j)))
      return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]

    # Match the function call to the template
    # The template's first parameter is dropped from the signature (it is the
    # output, see below), so the rest lines up with the call's arguments
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    # Both lists are reversed before zipping, which pairs each default with
    # the trailing positional parameter it belongs to
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
      if arg.id not in bound_args.arguments:
        bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      # co_argcount includes the output parameter, hence the - 1: everything
      # past the template's named inputs gets packed
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      # NOTE(review): takes the last entry of co_varnames as the *args name,
      # which presumably relies on the template having no other local
      # variables -- confirm.
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())
      packing = [gast.Assign(targets=[target], value=value)]

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(
          template_).co_varnames[-1]] = target
    adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
    unpacking = []
    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [create.create_temp_grad(arg, self.namer)
                  for arg in to_pack]
      value = create.create_grad(target, self.namer)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
      unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking
Exemplo n.º 7
0
  def visit_Assign(self, node):
    """Build the primal and adjoint of an assignment statement.

    The primal stores the target's previous value (and a fresh substack) on
    the stack before assigning; the adjoint restores them, runs the adjoint
    of the right-hand side, resets the target's gradient and accumulates the
    partials found in the right-hand side's adjoint.

    Args:
      node: The assignment to differentiate; it must have a single target.

    Returns:
      A `(primal, adjoint)` pair of statement lists. For an assignment that
      `self.is_active` reports as inactive this is `[store, node]` and
      `[restore, reset]`.

    Raises:
      ValueError: If the assignment has more than one target, or if visiting
        the right-hand side raises a ValueError (re-raised with the targets
        added to the message).
    """
    if len(node.targets) != 1:
      raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
      orig_src = anno.getanno(node, 'pre_anf')
    else:
      orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if isinstance(node.value, gast.Call) and \
        anno.hasanno(node.value, 'func') and \
        anno.getanno(node.value, 'func') in (utils.Stack, utils.pop_stack):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in the
    # adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
      return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
      fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
      context = [t.id if hasattr(t, 'id') else t for t in node.targets]
      raise ValueError(
          'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
      adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
      for succ in gast.walk(n):
        if anno.hasanno(succ, 'temp_adjoint_var'):
          xi = anno.getanno(succ, 'temp_adjoint_var')
          dxi_partial = ast_.copy_node(succ)
          accumulations.append(template.replace(
              'd[xi] = add_grad(d[xi], dxi_partial)',
              namer=self.namer, replace_grad=template.Replace.FULL,
              xi=xi, dxi_partial=dxi_partial, add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to change.
    # Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
      assign = [fx]
    elif (isinstance(fx, list) and
          any([isinstance(ifx, gast.Assign) for ifx in fx])):
      assign = fx
    else:
      assign = template.replace(
          'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
      assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally accumulating
    # the partials
    adjoint = [create_tmp, restore_substack, restore
              ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
        isinstance(self.orig_target.slice.value, gast.Name)):
      push, pop, op_id = get_push_pop()
      i = self.orig_target.slice.value
      push_index = template.replace(
          'push(_stack, i, op_id)',
          push=push,
          i=i,
          _stack=self.stack,
          op_id=op_id)
      # The keyword has to be spelled `_stack`, like the placeholder in the
      # template, so that the pop reads from the same stack node as the push
      # above (and as every other store/restore pair in this method)
      pop_index = template.replace(
          'i = pop(_stack, op_id)',
          pop=pop,
          i=i,
          _stack=self.stack,
          op_id=op_id)

      # The index is pushed last in the primal and popped first in the
      # adjoint
      primal.append(push_index)
      adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for i, adj in enumerate(adjoint):
      adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)

    return primal, adjoint
Exemplo n.º 8
0
  def visit_FunctionDef(self, node):
    """Turn a function definition into its primal and build its adjoint.

    Args:
      node: The function definition to differentiate. It is modified in place
        (name, arguments and body) and returned as the primal.

    Returns:
      A `(primal, adjoint)` pair: `node` itself and the adjoint built from
      the adjoint template registered under `gast.FunctionDef`.

    Raises:
      ValueError: If the function contains more than one return statement or
        its last statement is not a return.
    """
    # A namer built from the function guarantees unique names that don't
    # override existing ones
    self.namer = naming.Namer.build(node)

    # Exactly one return statement is allowed, and it has to come last
    returns = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if len(returns) > 1 or not isinstance(node.body[-1], gast.Return):
      raise ValueError('function must have exactly one return statement')
    final_return = ast_.copy_node(returns[0])

    # Differentiate everything but the return statement
    primal_body, adjoint_body = self.visit_statements(node.body[:-1])

    # Mark where the forward and the backward pass begin
    if primal_body:
      primal_body[0] = comments.add_comment(
          primal_body[0], 'Beginning of forward pass')
    if adjoint_body:
      adjoint_body[0] = comments.add_comment(
          adjoint_body[0], 'Beginning of backward pass')

    # Pick the gradients of the arguments we differentiate with respect to
    # while the argument list is still the original one
    params = node.args.args
    wrt_grads = gast.Tuple(
        [create.create_grad(params[idx], self.namer) for idx in self.wrt],
        ctx=gast.Load())

    if self.preserve_result:
      # Save the original output value at the end of the primal body and
      # return it after the gradients
      saved_result = quoting.quote(self.namer.unique('result'))
      primal_body.append(template.replace(
          'result=orig_result',
          result=saved_result,
          orig_result=final_return.value))
      wrt_grads.elts.append(saved_result)

    for grad in wrt_grads.elts:
      grad.ctx = gast.Load()
    return_grads = gast.Return(value=wrt_grads)

    # The primal takes the stack as its first argument, is renamed to its
    # primal name and keeps the original return statement
    node.args.args = [self.stack] + params
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)
    node.body = primal_body + node.body[-1:]

    # The cost is the first of potentially multiple return values; the
    # adjoint receives a value for its initial gradient
    cost = node.body[-1].value
    if isinstance(cost, gast.Tuple):
      cost = cost.elts[0]
    dcost = gast.Name(
        id=self.namer.grad(cost.id), ctx=gast.Param(), annotation=None)

    # Fill in the adjoint template; its arguments are the stack, the initial
    # gradient and then the original arguments
    adjoint, = template.replace(
        grads.adjoints[gast.FunctionDef],
        namer=self.namer,
        adjoint_body=adjoint_body,
        return_dx=return_grads)
    adjoint.args.args += [self.stack, dcost]
    adjoint.args.args += node.args.args[1:]
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
Exemplo n.º 9
0
 def substack(self):
   """Return a copy of the quoted substack name, creating it on first use.

   The quoted node is cached on the instance as `_substack`; callers always
   receive a copy made with `ast_.copy_node`.
   """
   if hasattr(self, '_substack'):
     return ast_.copy_node(self._substack)
   self._substack = quoting.quote(self.namer.unique(naming.SUBSTACK_NAME))
   return ast_.copy_node(self._substack)