Beispiel #1
0
    def visit_Call(self, node):
        """Annotate a call node with the Python object it invokes.

        The children are visited first. The callee expression is then looked
        up in `self.namespace` (falling back to `builtins` for plain names)
        and the result is stored on `node` under the `func` annotation.
        """
        self.generic_visit(node)

        def lookup(expr):
            # Dotted access: resolve the owner first, then fetch the attribute.
            if isinstance(expr, gast.Attribute):
                owner = lookup(expr.value)
                return getattr(owner, expr.attr)
            # Any expression other than a name or attribute stays unresolved.
            if not isinstance(expr, gast.Name):
                return None
            identifier = expr.id
            if identifier not in self.namespace:
                # TODO: we should detect when tracing is a fallback.
                return getattr(builtins, identifier)
            return self.namespace[identifier]

        callee = lookup(node.func)
        if hasattr(callee, 'should_trace'):
            # If the user has used the @tangent.trace decorator,
            # then we'll switch to tracing the function.
            callee = tracing.Traceable
        elif hasattr(callee, 'fun'):
            # TODO: use a less dicey API to check if a function is
            # autograd-wrapped.
            # Autograd primitives keep around their original wrapped function.
            # That original must be the `func` annotation, otherwise the
            # derivatives would have to be redefined for every autograd-wrapped
            # version of NumPy. Autograd wrappers also only expose
            # fn(*args, **kwargs) as their signature, while code generation
            # needs access to the default values of the real function.
            callee = callee.fun
        anno.setanno(node, 'func', callee)
Beispiel #2
0
def create_temp_grad(node, namer, tangent=False):
    """Build the simple variable that holds a partial gradient of `node`.

    Args:
      node: A `gast.Name` or `gast.Subscript`. For a subscript the variable
          name is derived from the subscripted value (`node.value`), so the
          temporary partial of `x[i]` is named after `x`.
      namer: Object whose `temp_grad(name, tangent)` method supplies the
          name of the temporary.
      tangent: Forwarded unchanged to `namer.temp_grad`.

    Returns:
      A `gast.Name` (created with `ctx=None`) that carries a
      `temp_adjoint_var` annotation pointing back at `node`.

    Raises:
      TypeError: If `node` is neither a `gast.Name` nor a `gast.Subscript`.
    """
    if isinstance(node, gast.Subscript):
        base = node.value
    elif isinstance(node, gast.Name):
        base = node
    else:
        raise TypeError

    grad_name = namer.temp_grad(base.id, tangent)
    temp_node = gast.Name(id=grad_name, annotation=None, ctx=None)
    anno.setanno(temp_node, 'temp_adjoint_var', node)
    return temp_node
Beispiel #3
0
 def _name_grad(node):
     """Return a gradient `gast.Name` for the plain variable `node`.

     `namer` and `tangent` are free variables taken from the enclosing
     scope. The new node gets an `adjoint_var` annotation pointing back at
     `node`. Raises `TypeError` when `node` is not a `gast.Name`.
     """
     if isinstance(node, gast.Name):
         grad_node = gast.Name(
             id=namer.grad(node.id, tangent), ctx=None, annotation=None)
         anno.setanno(grad_node, 'adjoint_var', node)
         return grad_node
     raise TypeError
Beispiel #4
0
 def visit(self, node):
     """Refresh the dataflow annotations of `node` and propagate changes.

     For a node that has a value, the outputs of all already-annotated
     predecessors are combined with `self.op`, the in/gen/kill/out
     annotations are rewritten, and every successor is revisited whenever
     the hash of the output annotation changed. For a node without a value,
     the predecessors' outputs are combined and stored in `self.exit`.
     """
     if not node.value:
         # No value on this node: only fold the predecessors' outputs.
         outgoing = [
             anno.getanno(pred.value, self.out_label) for pred in node.prev
         ]
         self.exit = functools.reduce(self.op, outgoing[1:], outgoing[0])
         return

     stmt = node.value

     # Hash of the previous output (if any), used below to detect a change.
     old_hash = None
     if anno.hasanno(stmt, self.out_label):
         old_hash = hash(anno.getanno(stmt, self.out_label))

     # Only predecessors that already carry an output take part.
     available = [
         anno.getanno(pred.value, self.out_label) for pred in node.prev
         if anno.hasanno(pred.value, self.out_label)
     ]
     incoming = (functools.reduce(self.op, available[1:], available[0])
                 if available else frozenset())
     anno.setanno(stmt, self.in_label, incoming, safe=False)

     gen, kill = self.gen(node, incoming)
     anno.setanno(stmt, self.gen_label, gen, safe=False)
     anno.setanno(stmt, self.kill_label, kill, safe=False)
     outgoing = (incoming - kill) | gen
     anno.setanno(stmt, self.out_label, outgoing, safe=False)

     # Revisit the successors only when the output actually moved.
     if hash(anno.getanno(stmt, self.out_label)) != old_hash:
         for succ in node.next:
             self.visit(succ)
Beispiel #5
0
  def visit_For(self, node):
    """Rewrite `for a in x` over an active iterable into an indexed loop.

    Iterators of the form `for a in x` rely on an implicit indexing
    operation, which Tangent cannot reverse without more information, so an
    explicit indexing operation is created instead:

        for a in x:            for i in range(len(x)):
          f(a)        becomes    a = x[i]
                                 f(a)

    Integer indexes are used, which will cause strange behavior if the
    iterator's `next()` behavior deviates from a plain incrementing index.
    The right thing to do (eventually) is to write a multiple-dispatch
    version of the `next` operator, and its adjoint, so that e.g. dicts can
    be handled.

    The loop is only rewritten when the iterable is a name, subscript or
    attribute whose name appears in the node's `active_in` annotation. The
    (possibly modified) node is always returned.
    """
    if not isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)):
      return node
    iter_name = ast.get_name(node.iter)
    if iter_name not in anno.getanno(node, 'active_in'):
      return node

    # Keep copies of the original loop variable and iterable, and get a
    # unique name for the new index variable.
    loop_var = copy.deepcopy(node.target)
    index = quoting.quote(self.namer.unique('_idx'))
    iterable = copy.deepcopy(node.iter)

    # `a = x[i]`, to be placed at the top of the loop body.
    element_assign = template.replace(
        'old_target = x[i]', old_target=loop_var, x=iterable, i=index)

    node.target = gast.Name(id=index.id, ctx=gast.Store(), annotation=None)
    range_call = quoting.quote('range(len(%s))' % iter_name)
    node.iter = range_call
    # Annotate both calls of `range(len(...))` with the builtins they invoke.
    anno.setanno(range_call, 'func', range)
    anno.setanno(range_call.args[0], 'func', len)
    node.body = [element_assign] + node.body
    return node
Beispiel #6
0
def create_temp(node, namer):
    """Build a temporary variable in which `node` can be stored.

    Args:
      node: A `gast.Name`, `gast.Attribute` or `gast.Subscript`. For
          attributes and subscripts the name is taken from `node.value.id`.
      namer: Object whose `temp(name)` method supplies the temporary's name.

    Returns:
      A `gast.Name` (created with `ctx=None`) annotated with `temp_var`,
      which points back at `node`.

    Raises:
      TypeError: If `node` is not one of the supported node types.
    """
    if isinstance(node, (gast.Attribute, gast.Subscript)):
        base_name = node.value.id
    elif isinstance(node, gast.Name):
        base_name = node.id
    else:
        raise TypeError

    temp_name = namer.temp(base_name)
    temp_node = gast.Name(id=temp_name, annotation=None, ctx=None)
    anno.setanno(temp_node, 'temp_var', node)
    return temp_node
Beispiel #7
0
def add_comment(node, text, location='above'):
  """Attach a comment to the given node.

  The comment is stored under the node's `comment` annotation as a dict with
  the keys `location` and `text`. If the `SourceWithCommentGenerator` class
  is used these comments will be output as part of the source code.

  A node holds at most one comment: calling `add_comment` again overrides
  the existing one.

  Args:
    node: The AST node whose containing statement will be commented.
    text: A comment string.
    location: Where the comment should appear. Valid values are 'above',
        'below' and 'right'.

  Returns:
    The same node, with the comment stored as an annotation.
  """
  comment = {'location': location, 'text': text}
  anno.setanno(node, 'comment', comment, safe=False)
  return node
Beispiel #8
0
 def visit_Name(self, node):
     """Substitute `node` with its registered replacement, if there is one.

     Names that are not keys of `self.replacements` are returned untouched.
     Otherwise every replacement node is annotated with `replacement`
     (pointing at the original name) and, when it has a `ctx` field,
     inherits the original's context. A single replacement is returned as a
     bare node, several as a list.
     """
     name = node.id
     if name not in self.replacements:
         return node

     # NOTE In principle we don't want to copy, because it might break
     # references held in annotations, but we will copy if we have to, to
     # avoid duplicate nodes.
     if name in self.seen:
         substitutes = ast_.copy_node(self.replacements[name])
     else:
         self.seen.add(name)
         substitutes = self.replacements[name]

     if isinstance(substitutes, gast.AST):
         substitutes = [substitutes]
     for substitute in substitutes:
         anno.setanno(substitute, 'replacement', node, safe=False)
         if 'ctx' in substitute._fields:
             substitute.ctx = node.ctx

     if len(substitutes) == 1:
         (only,) = substitutes
         return only
     return substitutes
Beispiel #9
0
 def visit_Expr(self, node):
     """Link a push statement to its matching pop.

     Only expression statements that call `utils.push` or `utils.push_stack`
     are handled. The statement is annotated with `push_var` (the second
     call argument); then both the statement and the call receive a `pop`
     annotation holding the matching pop looked up in `self.push_pop_pairs`.
     A failed lookup is re-raised in strict mode and ignored otherwise.
     """
     call = node.value
     if not isinstance(call, gast.Call):
         return
     fn_handle = _get_stack_op_handle(call)
     if not (fn_handle and fn_handle in [utils.push, utils.push_stack]):
         return

     # The op id is the string carried by the last call argument.
     op_id = call.args[-1].s
     anno.setanno(node, 'push_var', call.args[1])
     try:
         matching_pop = self.push_pop_pairs[op_id][self.fn_map[fn_handle]]
     except KeyError as e:
         if self.strict:
             raise e
         return
     anno.setanno(node, 'pop', matching_pop, False)
     anno.setanno(call, 'pop', matching_pop, False)
Beispiel #10
0
def get_push_pop_stack():
  """Create linked push and pop nodes for substacks, plus a fresh op id.

  Returns:
    A `(push, pop, op_id)` triple. `push` and `pop` are deep copies of the
    module-level `PUSH_STACK` and `POP_STACK` nodes. `push` is annotated
    with `pop` (the pop node) and with `gen_push` set to True; `pop` is
    annotated with `push` (the push node), so each node links to the other.
    `op_id` is whatever `_generate_op_id()` returns.
  """
  push_node = copy.deepcopy(PUSH_STACK)
  pop_node = copy.deepcopy(POP_STACK)
  # Cross-link the two nodes and flag the push node.
  anno.setanno(push_node, 'pop', pop_node)
  anno.setanno(push_node, 'gen_push', True)
  anno.setanno(pop_node, 'push', push_node)
  return push_node, pop_node, _generate_op_id()
Beispiel #11
0
    def visit_Assign(self, node):
        """Link a pop assignment to its matching push.

        Only assignments whose value is a call to `utils.pop` or
        `utils.pop_stack` are handled. The statement is annotated with
        `pop_var` (its first target); then both the statement and the call
        receive a `push` annotation holding the matching push looked up in
        `self.push_pop_pairs`.

        Raises:
          ValueError: If the op id is unknown, or (strict mode only) if the
              op id is not associated with exactly one valid push/pop pair.
          KeyError: In strict mode, if the matching push cannot be found.
        """
        call = node.value
        if not isinstance(call, gast.Call):
            return
        fn_handle = _get_stack_op_handle(call)
        if not (fn_handle and fn_handle in [utils.pop, utils.pop_stack]):
            return

        # Retrieve the op_id, e.g. val = tangent.pop(_stack,'abc')
        #                                                    ^^^
        _, op_id_node = call.args
        op_id = op_id_node.s
        anno.setanno(node, 'pop_var', node.targets[0])

        if op_id not in self.push_pop_pairs:
            raise ValueError('op_id %s not known' % op_id)
        push_pop_nodes = self.push_pop_pairs[op_id]
        keys = push_pop_nodes.keys()

        if self.strict:
            # The op_id must be associated with exactly two operations ...
            if len(keys) != 2:
                raise ValueError('Instead of 2 push/pop fns, found %d' %
                                 len(keys))
            # ... and those two must be either `push` and `pop`
            # or `push_stack` and `pop_stack`.
            found = set(keys)
            if (found != set((utils.push, utils.pop))
                    and found != set((utils.push_stack, utils.pop_stack))):
                raise ValueError('Invalid push/pop function pair. Found %s' %
                                 keys)

        try:
            matching_push = self.push_pop_pairs[op_id][self.fn_map[fn_handle]]
        except KeyError as e:
            if self.strict:
                raise e
            return
        anno.setanno(node, 'push', matching_push, False)
        anno.setanno(call, 'push', matching_push, False)
Beispiel #12
0
 def mark(self, node):
     """Store `self.src` on `node` as its `pre_anf` annotation, at most once.

     Nothing happens when the node already has a `pre_anf` annotation or
     when `self.src` is falsy.
     """
     if anno.hasanno(node, 'pre_anf'):
         return
     if self.src:
         anno.setanno(node, 'pre_anf', self.src)
Beispiel #13
0
  def visit_Call(self, node):
    """Build the forward-mode (tangent) code for a function call.

    The called function is read from the node's `func` annotation. Two cases
    are handled:

    * No tangent template is registered for the function: the call is
      rewritten into a call of a generated tangent function (named by
      `naming.tangent_name`), the `(func, active_args)` pair is recorded in
      `self.required`, and `self.value` is set to False.
    * A template exists in `tangents.tangents`: the call's arguments are
      bound to the template's signature (minus its first parameter, the
      output) and the template is filled in via `template.replace`.

    Args:
      node: The call node; it must carry a `func` annotation.

    Returns:
      `node` unchanged when `self.target` is falsy; a one-element list with
      the rewritten call when no template exists; otherwise the result of
      filling in the tangent template.

    Raises:
      errors.ForwardNotImplementedError: If the function is listed in
          `tangents.UNIMPLEMENTED_TANGENTS`.
      NotImplementedError: If the function is `tracing.Traceable`.
      ValueError: If there is no tangent template and the source of the
          function cannot be parsed.
    """
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      try:
        quoting.parse_function(func)
      except Exception:
        # Only ordinary errors are translated; KeyboardInterrupt and
        # SystemExit must propagate untouched.
        raise ValueError('No tangent found for %s, and could not get source.' %
                         func.__name__)

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      # Record the (function, active arguments) pair once only.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      tangent_keywords = []
      for i in active_args:
        # Pass the tangent of each active argument as a keyword whose name is
        # the tangent name of the callee's corresponding parameter.
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    # Drop the template's first parameter (the output) before binding.
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      # Defaults belong to the trailing positional arguments, so both lists
      # are walked from the end.
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      # NOTE(review): `value` is not read again in this method.
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(template_).co_varnames[
          -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      # NOTE(review): `value` and `target` are rebound here but not used
      # afterwards in this method -- confirm whether an unpacking assignment
      # was intended, as in the reverse-mode counterpart.
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if len(self.metastack):
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
Beispiel #14
0
from __future__ import division

from copy import copy as native_copy
import types

import autograd
import numpy
import six
from tangent import annotations as anno
from tangent import non_differentiable
from tangent import quoting


# Quoted references to `tangent.init_grad` and `tangent.add_grad`; each node
# is tagged with an annotation naming its role.
INIT_GRAD = quoting.quote('tangent.init_grad')
anno.setanno(INIT_GRAD, 'init_grad', True)
ADD_GRAD = quoting.quote('tangent.add_grad')
anno.setanno(ADD_GRAD, 'add_grad', True)


def array_size(x, axis):
  """Calculate the size of `x` along `axis` dimensions only.

  Args:
    x: An object with a `shape` attribute, e.g. a NumPy array.
    axis: `None` to use every dimension, a single integer axis, or an
        iterable of integer axes.

  Returns:
    The product of the lengths of the selected dimensions, but never less
    than 1 (so empty dimensions and empty selections both give 1).
  """
  if axis is None:
    axis_shape = x.shape
  elif isinstance(axis, (int, numpy.integer)):
    # A bare integer axis, as NumPy reductions accept, selects one dimension.
    axis_shape = (x.shape[axis],)
  else:
    axis_shape = tuple(x.shape[a] for a in axis)
  return max(numpy.prod(axis_shape), 1)


class Stack(object):
  """A stack type that proxies list's `append` and `pop` methods.

  We don't use list directly so that we can test its type for the multiple-
  dispatch that occurs in `add_grad` and `init_grad`.
  """
Beispiel #15
0
  def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.

    Args:
      node: The call node; it must carry a `func` annotation.

    Returns:
      A `(primal, adjoint)` pair, where `adjoint` is a list of statements.
      For non-differentiable functions this is `(node, [])`; for traced
      functions it is whatever `primal_and_adjoint_for_tracing` returns.

    Raises:
      errors.ReverseNotImplementedError: If the function is listed in
          `grads.UNIMPLEMENTED_ADJOINTS`.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
      return node, []

    if func == tracing.Traceable:
      return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
      raise errors.ReverseNotImplementedError(func)


    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
      # NOTE(review): `arg.id` assumes every positional argument is a
      # `gast.Name`; other expressions would raise AttributeError here.
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if arg.id in self.active_variables)

      # Record the (function, active arguments) pair once only.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      # Primal call: the generated primal function receives the substack as
      # an extra leading argument.
      pri_name = naming.primal_name(func, active_args)
      pri_call = gast.Call(
          func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
          args=[self.substack] + node.args,
          keywords=node.keywords)
      anno.setanno(pri_call, 'pri_call', True)

      dy = create.create_grad(self.target, self.namer)
      dy.ctx = gast.Load()
      # NOTE(review): `dx` is not used again in this method, and
      # `node.args[0]` raises IndexError for a call without positional
      # arguments -- confirm whether this is still needed.
      dx = create.create_grad(node.args[0], self.namer)
      dx.ctx = gast.Store()
      # Adjoint call: receives the substack and the output gradient first.
      adj_name = naming.adjoint_name(func, active_args)
      adj_call = gast.Call(
          func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
          args=[self.substack, dy] + node.args,
          keywords=node.keywords)
      anno.setanno(adj_call, 'adj_call', True)
      # Store the adjoint call's result in `dxs`, then assign its j-th entry
      # to the gradient of the j-th active argument.
      adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
      for j, i in enumerate(active_args):
        adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                        x=node.args[i].id, i=gast.Num(n=j)))
      return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    # Drop the template's first parameter (the output) before binding.
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    # Defaults belong to the trailing positional arguments, so both lists are
    # walked from the end. The dict is keyed by argument *nodes*, not names.
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    # NOTE(review): `kw_defaults` entries can presumably be None for
    # keyword-only arguments without a default; such a None would be stored
    # as the argument's value below -- confirm.
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
      if arg.id not in bound_args.arguments:
        bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      # Every call argument beyond the template's named parameters (not
      # counting the output parameter) is packed into one tuple assignment.
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())
      packing = [gast.Assign(targets=[target], value=value)]

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(
          template_).co_varnames[-1]] = target
    adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
    unpacking = []
    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [create.create_temp_grad(arg, self.namer)
                  for arg in to_pack]
      # The gradient of the packed tuple is unpacked into the temporary
      # partials of the individual arguments.
      value = create.create_grad(target, self.namer)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
      unpacking = [gast.Assign(targets=[target], value=value)]

    # The primal call is left as is; the adjoint is pack + template + unpack.
    return node, packing + adjoint + unpacking
Beispiel #16
0
 def visit_FunctionDef(self, node):
     """Visit the function's children, then annotate it with `self.func`."""
     self.generic_visit(node)
     resolved = self.func
     anno.setanno(node, 'func', resolved)
Beispiel #17
0
from tangent import errors
from tangent import fixes
from tangent import funcsigs
from tangent import grads
from tangent import naming
from tangent import non_differentiable
from tangent import quoting
from tangent import template
from tangent import tracing
from tangent import utils


# Quoted AST nodes for the stack helpers, used to fill in templates that use
# stacks or reset gradients. Push nodes are annotated with `push_func`, pop
# nodes with `pop_func`.
PUSH = quoting.quote('tangent.push')
anno.setanno(PUSH, 'push_func', True)
POP = quoting.quote('tangent.pop')
anno.setanno(POP, 'pop_func', True)

# The same pair for substacks.
PUSH_STACK = quoting.quote('tangent.push_stack')
anno.setanno(PUSH_STACK, 'push_func', True)
POP_STACK = quoting.quote('tangent.pop_stack')
anno.setanno(POP_STACK, 'pop_func', True)


def _generate_op_id():
  """Quote a string literal of the form `'_xxxxxxxx'` as an op id.

  The eight characters after the underscore are the first eight hex digits
  of a freshly generated UUID4.
  """
  suffix = uuid4().hex[:8]
  literal = "'_{}'".format(suffix)
  return quoting.quote(literal)


def get_push_pop():
  """Create pop and push nodes that are linked.

  Returns: