示例#1
0
 def visit_Assign(self, node):
     """Put an assignment statement into A-normal form.

     Only the first target is inspected:
       * Subscript/Attribute target: the RHS is trivialized and the target
         itself is visited.
       * Tuple target: the RHS is visited and assigned to a fresh name from
         the namer; one ``elt = name[i]`` statement per tuple element is
         marked and handed to ``self.append``.
       * Name target: only the generic visit below applies.

     Args:
       node: The ``gast.Assign`` node to transform.

     Returns:
       The transformed node (result of ``self.generic_visit``).

     Raises:
       ValueError: If the first target is not a Name, Subscript, Attribute or
         Tuple.
     """
     self.src = quoting.unquote(node)
     self.mark(node)
     self.trivializing = True
     # Expose the target to the namer while this statement is processed;
     # it is reset to None before returning.
     self.namer.target = node.targets[0]
     if isinstance(node.targets[0], (gast.Subscript, gast.Attribute)):
         node.value = self.trivialize(node.value)
         node.targets[0] = self.visit(node.targets[0])
     elif isinstance(node.targets[0], gast.Tuple):
         # Rewrite `a, b = value` as `tmp = value` followed by `a = tmp[0]`,
         # `b = tmp[1]`, ... (the per-element statements go through
         # self.append, presumably emitted after this statement).
         node.value = self.visit(node.value)
         name = self.namer.name(node.targets[0])
         target = gast.Name(id=name, ctx=gast.Store(), annotation=None)
         for i, elt in enumerate(node.targets[0].elts):
             stmt = gast.Assign(targets=[elt],
                                value=gast.Subscript(
                                    value=gast.Name(id=name,
                                                    ctx=gast.Load(),
                                                    annotation=None),
                                    slice=gast.Index(value=gast.Num(n=i)),
                                    ctx=gast.Load()))
             self.mark(stmt)
             self.append(stmt)
         node.targets[0] = target
     elif not isinstance(node.targets[0], gast.Name):
         raise ValueError(
             'Cannot Assign to %s' % type(node.targets[0]).__name__)
     node = self.generic_visit(node)
     self.namer.target = None
     self.trivializing = False
     return node
示例#2
0
    def visit_Expr(self, node):
        """Build the adjoint of an expression statement.

        Only stack pushes such as ``utils.push(_stack, x, op_id)`` receive an
        adjoint; every other expression statement gets an empty one.

        Returns:
          A ``(primal, adjoint)`` pair: the unchanged ``node`` and a list of
          adjoint statements, each annotated with a "Grad of:" comment.
        """
        call = node.value
        if not (isinstance(call, gast.Call)
                and anno.getanno(call, 'func') == utils.push):
            return node, []

        source_text = quoting.unquote(node)
        stack, val, op_id = call.args
        rhs_statements = template.replace(grads.adjoints[utils.push],
                                          namer=self.namer,
                                          stack=stack,
                                          val=val,
                                          op_id=op_id)

        # Sum the partial computed by the first adjoint statement into d[val]
        partial = ast_.copy_node(rhs_statements[0].targets[0])
        accumulate = template.replace(
            'd[xi] = add_grad(d[xi], dxi_partial)',
            namer=self.namer,
            replace_grad=template.Replace.FULL,
            xi=val,
            dxi_partial=partial,
            add_grad=utils.ADD_GRAD)

        note = 'Grad of: %s' % source_text
        return node, [comments.add_comment(stmt, note)
                      for stmt in rhs_statements + [accumulate]]
示例#3
0
  def visit_Assign(self, node):
    """Visit assignment statement.

    Builds the statements that compute both the primal and the tangent of
    `node`.

    Notes
    -----
    This method sets the `self.target` attribute to the first assignment
    target for callees to use, and resets it to `None` before returning.
    `self.value` holds the RHS; the RHS visitor may set it to `False` to
    request that the whole primal statement be replaced (the `else` branch
    below).

    Returns
    -------
    The original `node` if visiting the RHS created no forward-mode
    statement, otherwise a list of statements computing the primal and its
    tangent, each annotated with a comment naming the source statement.
    """

    # Generate tangent name of the assignment
    self.target = node.targets[0]
    self.value = node.value

    # Get the tangent node
    tangent_node = self.visit(self.value)

    # If no forward-mode statement was created, then no extra work is needed.
    if tangent_node == self.value:
      self.target = None
      return node

    if self.value:
      new_node = template.replace(
          tangents.tangents[gast.Assign],
          replace_grad=template.Replace.TANGENT,
          namer=self.namer,
          temp=self.tmp_node,
          tangent=tangent_node,
          target=self.target,
          value=self.value)
      # Ensure that we use a unique tmp node per primal/tangent pair
      self.reset_tmp_node()
    else:
      # We're already in ANF form right now,
      # so we can assume LHS is a single Name "z"
      def template_(z, f):
        tmp = f
        d[z] = tmp[0]
        z = tmp[1]

      new_node = template.replace(
          template_,
          replace_grad=template.Replace.TANGENT,
          namer=self.namer,
          z=self.target,
          f=tangent_node[0])

    # The target is only needed while the RHS is being processed
    self.target = None

    # Annotate every generated statement with the statement it came from;
    # the source text does not change inside the loop, so unparse it once.
    comment = 'Primal and tangent of: %s' % quoting.unquote(node)
    for i, stmt in enumerate(new_node):
      new_node[i] = comments.add_comment(stmt, comment)
    return new_node
示例#4
0
def _get_stack_op_handle(node):
    """Return the stack-operation function that a Call node refers to.

    The node's 'func' annotation is used when it is present and truthy.
    Otherwise the unparsed callee is matched against the spellings
    'tangent.pop', 'tangent.push', 'tangent.pop_stack' and
    'tangent.push_stack'.

    Args:
      node: A `gast.Call` node.

    Returns:
      The truthy 'func' annotation, else the matching `utils` stack function,
      else False when the callee is not a recognized stack operation.
    """
    assert isinstance(node, gast.Call), 'Only can get fn handles of Call nodes'
    fn_handle = anno.getanno(node, 'func', False)
    if fn_handle:
        return fn_handle
    # A plain dict with a lookup default: a defaultdict would insert a new
    # key on every miss, which is pointless for a throwaway mapping.
    fn_map = {
        'tangent.pop': utils.pop,
        'tangent.push': utils.push,
        'tangent.pop_stack': utils.pop_stack,
        'tangent.push_stack': utils.push_stack,
    }
    return fn_map.get(quoting.unquote(node.func), False)
示例#5
0
文件: anf.py 项目: zouzias/tangent
 def visit_AugAssign(self, node):
     """Rewrite ``t op= v`` as the plain assignment ``t = t op v`` in ANF.

     The right-hand side is trivialized first, then the target; the resulting
     ``gast.Assign`` is marked, visited, and returned in place of ``node``.

     NOTE(review): the Load-context operand is built from ``target.id``, so
     this relies on ``self.trivialize`` returning a Name for the target.
     Confirm how Subscript/Attribute targets (e.g. ``x[i] += 1``) are meant
     to be handled.
     """
     self.src = quoting.unquote(node)
     self.trivializing = True
     self.namer.target = node.target
     operand = self.trivialize(node.value)
     store_target = self.trivialize(node.target)
     expanded = gast.BinOp(
         left=gast.Name(id=store_target.id, ctx=gast.Load(), annotation=None),
         op=node.op,
         right=operand)
     assign = gast.Assign(targets=[store_target], value=expanded)
     self.mark(assign)
     result = self.generic_visit(assign)
     self.namer.target = None
     self.trivializing = False
     return result
示例#6
0
def test_insert():
    """A TreeTransformer can prepend a statement before the one it visits."""
    def f(x):
        y = x
        return y

    node = quoting.parse_function(f)

    class PrependDoubling(transformers.TreeTransformer):
        def visit_Assign(self, node):
            # Only fire on the assignment to `y`. Without this guard the
            # visitor would loop forever on the statements it prepends.
            assigns_to_y = node.targets[0].id == 'y'
            if assigns_to_y:
                self.prepend(quoting.quote("x = 2 * x"))
            return node

    PrependDoubling().visit(node)
    source_lines = quoting.unquote(node).split('\n')
    assert source_lines[1].strip() == "x = 2 * x"
示例#7
0
def test_remove():
    """A TreeTransformer can remove the statement it is visiting."""
    def f(x):
        while True:
            y = x
            z = y
        return y

    node = quoting.parse_function(f)

    class RemoveZ(transformers.TreeTransformer):
        def visit_Assign(self, node):
            # If the target is z, then remove this statement
            if node.targets[0].id == 'z':
                self.remove(node)
            return node

    RemoveZ().visit(node)
    # With `z = y` gone, `return y` is the fourth line of the output
    assert quoting.unquote(node).split('\n')[3].strip() == "return y"
示例#8
0
 def visit_Call(self, node):
     """Record stack push/pop calls, grouped by their op_id.

     Each ``push``/``push_stack``/``pop``/``pop_stack`` call is stored as
     ``self.push_pop_pairs[op_id][fn_handle] = node``; other calls are
     ignored.

     Raises:
       AssertionError: If a call with the same function and op_id has already
         been recorded.
     """
     fn_handle = _get_stack_op_handle(node)
     if fn_handle and fn_handle in [
             utils.pop, utils.push, utils.push_stack, utils.pop_stack
     ]:
         # Retrieve the op_id, e.g. tangent.push(_stack,val,'abc')
         #                                                   ^^^
         if fn_handle in [utils.push, utils.push_stack]:
             _, _, op_id_node = node.args
         else:
             # Only pop / pop_stack remain after the membership check above
             _, op_id_node = node.args
         op_id = op_id_node.s
         ops_for_id = self.push_pop_pairs.setdefault(op_id, dict())
         # Check this op_id's entries, which are keyed by function. The outer
         # dict is keyed by op_id, so checking it can never detect a clash.
         assert fn_handle not in ops_for_id, (
             'Conflicting op_ids. '
             'Already have fn %s with '
             'op_id %s') % (quoting.unquote(node.func), op_id)
         ops_for_id[fn_handle] = node
示例#9
0
  def visit_Call(self, node):
    """Build the forward-mode (tangent) code for a call on an assignment RHS.

    Does nothing (returns `node`) unless `self.target` is set.

    Two strategies are used:
      * If `func` has no entry in `tangents.tangents`, the call is rewritten
        to call a generated tangent function (named by
        `naming.tangent_name`). The tangents of the Name arguments are passed
        as keywords, the function is recorded in `self.required`, and
        `self.value` is set to False so the caller replaces the whole primal
        statement.
      * Otherwise the registered template is filled in with the call's bound
        arguments (missing defaults come from the original function's
        source) and `self.tmp_node` as the output.

    Args:
      node: The `gast.Call` node.

    Returns:
      `node` unchanged, a one-element list holding the rewritten call, or the
      list of statements produced from the tangent template.

    Raises:
      errors.ForwardNotImplementedError: If `func` is listed in
        `tangents.UNIMPLEMENTED_TANGENTS`.
      NotImplementedError: If `func` is `tracing.Traceable`.
      ValueError: If there is no tangent template and the source of `func`
        cannot be parsed.
    """
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      try:
        quoting.parse_function(func)
      except Exception:
        # Catch only Exception so KeyboardInterrupt/SystemExit propagate;
        # getattr keeps the message from failing on callables lacking
        # __name__.
        raise ValueError('No tangent found for %s, and could not get source.' %
                         (getattr(func, '__name__', func),))

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      tangent_keywords = []
      for i in active_args:
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      # NOTE(review): `value` is overwritten below before it is ever used.
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(template_).co_varnames[
          -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      # NOTE(review): `dto_pack`, `value` and `target` are not used after
      # this point, so no unpacking statement is emitted. Confirm whether
      # that is intended (the namer calls here may still matter).
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if len(self.metastack):
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
示例#10
0
def anf_lines(f):
    """Convert function `f` to A-normal form and return its source as lines.

    The function is parsed, transformed with `anf.anf`, unparsed, and the
    resulting source string is split on newline characters.
    """
    parsed = quoting.parse_function(f)
    normalized = anf.anf(parsed)
    return quoting.unquote(normalized).split('\n')
示例#11
0
def store_state(node, reaching, defined, stack):
  """Push the final state of the primal onto the stack for the adjoint.

  Python's scoping rules make it possible for variables to not be defined in
  certain blocks based on the control flow path taken at runtime. In order to
  make sure we don't try to push non-existing variables onto the stack, we
  define these variables explicitly (by assigning `None` to them) at the
  beginning of the function.

  All the variables that reach the return statement are pushed onto the
  stack, and in the adjoint they are popped off in reverse order.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`.
    reaching: The variable definitions that reach the end of the primal, as
        `(name, definition_node)` pairs. Definitions that are function
        arguments are skipped.
    defined: The set of variable names defined at the end of the primal.
    stack: The stack node to use for storing and restoring state.

  Returns:
    node: A node with the requisite pushes and pops added to make sure that
        state is transferred between primal and adjoint split motion calls.
  """
  defs = [def_ for def_ in reaching if not isinstance(def_[1], gast.arguments)]
  if not defs:
    return node
  reaching, original_defs = zip(*defs)

  # Explicitly define variables that might or might not be in scope at the
  # end. Sorted so the generated code does not depend on set iteration order.
  assignments = []
  for id_ in sorted(set(reaching) - defined):
    assignments.append(quoting.quote('{} = None'.format(id_)))

  # Store variables at the end of the function and restore them
  store = []
  load = []
  for id_, def_ in zip(reaching, original_defs):
    # If the original definition of a value that we need to store
    # was an initialization as a stack, then we should be using `push_stack`
    # to store its state, and `pop_stack` to restore it. This allows
    # us to avoid doing any `add_grad` calls on the stack, which result
    # in type errors in unoptimized mode (they are usually elided
    # after calling `dead_code_elimination`).
    if isinstance(
        def_, gast.Assign) and 'tangent.Stack()' in quoting.unquote(def_.value):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    store.append(
        template.replace(
            'push(_stack, val, op_id)',
            push=push,
            val=id_,
            _stack=stack,
            op_id=op_id))
    load.append(
        template.replace(
            'val = pop(_stack, op_id)',
            pop=pop,
            val=id_,
            _stack=stack,
            op_id=op_id))

  # Primal: pre-definitions first, pushes just before the final return.
  # Adjoint: pops (in reverse push order) at the very start.
  body, return_ = node.body[0].body[:-1], node.body[0].body[-1]
  node.body[0].body = assignments + body + store + [return_]
  node.body[1].body = load[::-1] + node.body[1].body

  return node
示例#12
0
  def visit_Assign(self, node):
    """Visit assignment statement.

    Produces the primal and adjoint statements for a single-target
    assignment. The primal pushes the target's previous value (and a fresh
    substack) onto `self.stack` before assigning. The adjoint then:
      * pops the pushed values back,
      * runs the adjoint of the RHS,
      * resets the target's gradient,
      * accumulates the partials into the gradients of the RHS variables.
    If the node is not active, only the store/restore/reset statements are
    produced.

    Args:
      node: A `gast.Assign` node with exactly one target.

    Returns:
      A `(primal, adjoint)` pair of statement lists.

    Raises:
      ValueError: If the assignment has more than one target, or if visiting
        the RHS raises a ValueError (re-raised with the target names).
    """
    if len(node.targets) != 1:
      raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
      orig_src = anno.getanno(node, 'pre_anf')
    else:
      orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if isinstance(node.value, gast.Call) and \
        anno.hasanno(node.value, 'func') and \
        anno.getanno(node.value, 'func') in (utils.Stack, utils.pop_stack):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in the
    # adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
      return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
      fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
      context = [t.id if hasattr(t, 'id') else t for t in node.targets]
      raise ValueError(
          'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
      adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
      for succ in gast.walk(n):
        if anno.hasanno(succ, 'temp_adjoint_var'):
          xi = anno.getanno(succ, 'temp_adjoint_var')
          dxi_partial = ast_.copy_node(succ)
          accumulations.append(template.replace(
              'd[xi] = add_grad(d[xi], dxi_partial)',
              namer=self.namer, replace_grad=template.Replace.FULL,
              xi=xi, dxi_partial=dxi_partial, add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to change.
    # Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
      assign = [fx]
    elif (isinstance(fx, list) and
          any(isinstance(ifx, gast.Assign) for ifx in fx)):
      assign = fx
    else:
      assign = template.replace(
          'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
      assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally accumulating
    # the partials
    adjoint = [create_tmp, restore_substack, restore
              ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
        isinstance(self.orig_target.slice.value, gast.Name)):
      push, pop, op_id = get_push_pop()
      i = self.orig_target.slice.value
      push_index = template.replace(
          'push(_stack, i, op_id)',
          push=push,
          i=i,
          _stack=self.stack,
          op_id=op_id)
      # Same `_stack` keyword as push_index, so that the template's `_stack`
      # placeholder is replaced by self.stack.
      pop_index = template.replace(
          'i = pop(_stack, op_id)',
          pop=pop,
          i=i,
          _stack=self.stack,
          op_id=op_id)

      primal.append(push_index)
      adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for i, adj in enumerate(adjoint):
      adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)

    return primal, adjoint
示例#13
0
def test_string_replace():
    """template.replace accepts a template given as a source string."""
    expected = "y = x * 2"
    result = template.replace("a = b", a="y", b="x * 2")
    assert quoting.unquote(result) == expected
示例#14
0
def test_node_replace():
    """template.replace accepts a template given as an already-parsed node."""
    template_node = quoting.quote("a = b")
    result = template.replace(template_node, a="y", b="x * 2")
    assert quoting.unquote(result) == "y = x * 2"