Example #1
0
 def visit(self, node):
     """Apply the dataflow transfer function to `node` and propagate changes.

     For a node wrapping a statement (`node.value` is truthy) the incoming
     set is the reduction, using `self.op`, of the outgoing sets of every
     predecessor that already carries an `out_label` annotation, or an empty
     frozenset when no predecessor does. The in/gen/kill/out sets are stored
     as annotations on `node.value`. Successors are visited recursively
     whenever the hash of the outgoing set differs from the hash recorded
     before this visit.

     For a node without a statement, the outgoing sets of all predecessors
     are reduced with `self.op` and the result is stored in `self.exit`.

     Args:
       node: A graph node exposing `value`, `prev` and `next`.
     """
     stmt = node.value
     if not stmt:
         # NOTE(review): presumably the exit node of the graph -- confirm.
         outgoing_sets = [
             anno.getanno(pred.value, self.out_label) for pred in node.prev
         ]
         self.exit = functools.reduce(self.op, outgoing_sets[1:],
                                      outgoing_sets[0])
         return

     # Fingerprint of the previous result, used below to detect a change.
     previous = None
     if anno.hasanno(stmt, self.out_label):
         previous = hash(anno.getanno(stmt, self.out_label))

     # Only predecessors that were already processed contribute.
     pred_outputs = [
         anno.getanno(pred.value, self.out_label)
         for pred in node.prev
         if anno.hasanno(pred.value, self.out_label)
     ]
     if pred_outputs:
         incoming = functools.reduce(self.op, pred_outputs)
     else:
         incoming = frozenset()

     # The incoming set is stored before `self.gen` runs, as in the original
     # ordering of side effects.
     anno.setanno(stmt, self.in_label, incoming, safe=False)
     gen, kill = self.gen(node, incoming)
     anno.setanno(stmt, self.gen_label, gen, safe=False)
     anno.setanno(stmt, self.kill_label, kill, safe=False)
     outgoing = (incoming - kill) | gen
     anno.setanno(stmt, self.out_label, outgoing, safe=False)

     # Nothing changed: no need to revisit the successors.
     if hash(anno.getanno(stmt, self.out_label)) == previous:
         return
     for succ in node.next:
         self.visit(succ)
Example #2
0
 def visit(self, node):
     """Collect definitions and record which of them are read under `node`.

     When `node` carries a `definitions_gen` annotation, its definitions are
     added to `self.definitions` and `self.reaching_definitions` is set from
     the `definitions_in` annotation for the duration of the visit. Every
     loaded name marks the reaching definitions whose first element equals
     the name's id as used.

     Args:
       node: The AST node to visit.
     """
     if anno.hasanno(node, 'definitions_gen'):
         generated = anno.getanno(node, 'definitions_gen')
         self.definitions.update(generated)
         self.reaching_definitions = anno.getanno(node, 'definitions_in')

     is_read = isinstance(node, gast.Name) and isinstance(node.ctx, gast.Load)
     if is_read:
         matching = [
             definition for definition in self.reaching_definitions
             if definition[0] == node.id
         ]
         self.used.update(matching)

     super(Unused, self).visit(node)

     # Leaving the annotated node: its reaching definitions no longer apply.
     if anno.hasanno(node, 'definitions_gen'):
         self.reaching_definitions = None
Example #3
0
 def visit(self, node):
     """Remove all AD-generated pushes of unused variables.

     A node qualifies when it carries the `push_var`, `pop` and `gen_push`
     annotations. If the name of the pushed variable is not among the names
     in the node's `definitions_in` annotation, both the node and the node
     stored in its `pop` annotation are passed to `self.remove`.

     Args:
       node: The AST node to visit.

     Returns:
       Whatever the parent class' `visit` returns for `node`.
     """
     required = ('push_var', 'pop', 'gen_push')
     if all(anno.hasanno(node, key) for key in required):
         reaching_names = frozenset(
             name for name, _ in anno.getanno(node, 'definitions_in'))
         pushed_name = ast_.get_name(anno.getanno(node, 'push_var'))
         if pushed_name not in reaching_names:
             self.remove(node)
             self.remove(anno.getanno(node, 'pop'))
     return super(CleanStack, self).visit(node)
Example #4
0
 def prepend_uninitialized_grads(self, node):
     """Insert initializers for gradient names that are read but not defined.

     Only nodes with a `defined_in` annotation are processed. Every loaded
     name below `node` that carries an `adjoint_var` or `temp_adjoint_var`
     annotation, is missing from `defined_in`, and was not handled before
     gets recorded in `self.added` and has the statement built by
     `self._init` passed to `self.insert_top`.

     Args:
       node: The AST node to inspect.

     Returns:
       The same `node`.
     """
     if not anno.hasanno(node, 'defined_in'):
         return node

     for candidate in gast.walk(node):
         is_load = (isinstance(candidate, gast.Name)
                    and isinstance(candidate.ctx, gast.Load))
         if not is_load:
             continue
         is_gradient = (anno.hasanno(candidate, 'adjoint_var')
                        or anno.hasanno(candidate, 'temp_adjoint_var'))
         if not is_gradient:
             continue
         if candidate.id in anno.getanno(node, 'defined_in'):
             continue
         if candidate.id in self.added:
             continue
         self.added.add(candidate.id)
         self.insert_top(self._init(candidate))
     return node
Example #5
0
 def visit(self, node, abort=astor.codegen.SourceGenerator.abort_visit):
     """Generate source for `node`, emitting its `comment` annotation if any.

     The `comment` annotation is a dict with a `location` (`'above'`,
     `'below'` or `'right'`) and a `text`. Comments placed above or below
     are truncated in place to 78 characters. Nodes without a comment reset
     `self.new_indentation` and are delegated to the parent class.

     Args:
       node: The node to generate source for.
       abort: Unused by this method.

     Raises:
       TangentParseError: If the comment location is not one of the three
         supported values.
     """
     parent_visit = super(SourceWithCommentGenerator, self).visit

     if not anno.hasanno(node, 'comment'):
         self.new_indentation = False
         parent_visit(node)
         return

     comment = anno.getanno(node, 'comment')
     location = comment['location']
     if location not in ('above', 'below', 'right'):
         raise TangentParseError('Only valid comment locations are '
                                 'above, below, right')

     # Preprocess the comment to fit to maximum line width of 80 characters
     linewidth = 78
     if location != 'right':
         comment['text'] = comment['text'][:linewidth]
     n_newlines = 1 if self.new_indentation else 2

     if location == 'above':
         self.result.extend([
             '\n' * n_newlines,
             self.indent_with * self.indentation,
             '# %s' % comment['text'],
         ])
         parent_visit(node)
     elif location == 'below':
         parent_visit(node)
         self.result.extend([
             '\n',
             self.indent_with * self.indentation,
             '# %s' % comment['text'],
             '\n' * (n_newlines - 1),
         ])
     else:
         # Trailing comment on the same line as the generated code.
         parent_visit(node)
         self.result.append(' # %s' % comment['text'])
Example #6
0
def explicit_loop_indexes(node):
  """Run `ExplicitLoopIndexes` over `node` and drop the `active_*` annotations.

  Args:
    node: The AST to transform.

  Returns:
    The transformed AST, with the `active_in`, `active_out`, `active_gen` and
    `active_kill` annotations removed from every node that had them.
  """
  transformed = ExplicitLoopIndexes().visit(node)
  activity_labels = ('active_in', 'active_out', 'active_gen', 'active_kill')
  for child in gast.walk(transformed):
    present = (label for label in activity_labels
               if anno.hasanno(child, label))
    for label in present:
      anno.delanno(child, label)
  return transformed
Example #7
0
 def visit_Assign(self, node):
     """Replace an `add_grad` call whose target has no reaching definition.

     When the assigned value is a call whose function carries the `add_grad`
     annotation and the name of the first target is not among the names in
     the node's `definitions_in` annotation, the call is replaced by its
     second argument.

     Args:
       node: The assignment node.

     Returns:
       The same `node`, possibly with its `value` replaced.
     """
     value = node.value
     is_accumulation = (isinstance(value, gast.Call)
                        and anno.hasanno(value.func, 'add_grad'))
     if not is_accumulation:
         return node

     reaching_names = frozenset(
         name for name, _ in anno.getanno(node, 'definitions_in'))
     target_name = ast_.get_name(node.targets[0])
     if target_name not in reaching_names:
         node.value = value.args[1]
     return node
Example #8
0
  def visit(self, node):
    """Dispatch `node` to its `visit_<ClassName>` method.

    Falls back to `self.generic_visit` when no specific method exists. If the
    node carries an `active_in` annotation, `self.active_variables` is
    updated from it before dispatching.

    Args:
      node: The node to visit.

    Returns:
      The result of the selected visitor method.
    """
    handler_name = 'visit_' + node.__class__.__name__
    handler = getattr(self, handler_name, self.generic_visit)

    # Set certain attributes for child nodes
    if anno.hasanno(node, 'active_in'):
      active = anno.getanno(node, 'active_in')
      self.active_variables = active

    result = handler(node)
    return result
Example #9
0
 def _init(self, node):
     """Build an assignment that initializes the gradient named by `node`.

     The argument of the initializer call is taken from the node's
     `adjoint_var` annotation, or from `temp_adjoint_var` when the former is
     absent.

     Args:
       node: A node whose name (via `ast_.get_name`) is the assignment target.

     Returns:
       A `gast.Assign` of `utils.INIT_GRAD(<var>)` to that name.
     """
     gradname = ast_.get_name(node)
     if anno.hasanno(node, 'adjoint_var'):
         source_label = 'adjoint_var'
     else:
         source_label = 'temp_adjoint_var'
     var = anno.getanno(node, source_label)

     target = gast.Name(id=gradname, ctx=gast.Store(), annotation=None)
     init_call = gast.Call(func=utils.INIT_GRAD, args=[var], keywords=[])
     return gast.Assign(targets=[target], value=init_call)
Example #10
0
 def visit(self, node):
     """Visit `node`, dropping statements that were marked for removal.

     `self.remove` is raised when `node` is in `self.to_remove`, and
     `self.is_call` when the node carries a `pri_call` or `adj_call`
     annotation. Both flags are evaluated and reset once a statement node
     has been visited.

     Args:
       node: The node to visit.

     Returns:
       The visited node; `None` for a removed statement or an `if` with
       neither body nor `orelse`; the `orelse` list for a loop whose body is
       empty.
     """
     if node in self.to_remove:
         self.remove = True
     if any(anno.hasanno(node, key) for key in ('pri_call', 'adj_call')):
         # We don't remove function calls for now; removing them also
         # removes the push statements inside of them, but not the
         # corresponding pop statements
         self.is_call = True

     result = super(Remove, self).visit(node)

     if isinstance(node, grammar.STATEMENTS):
         should_drop = self.remove and not self.is_call
         self.remove = self.is_call = False
         if should_drop:
             result = None

     if isinstance(node, gast.If) and not node.body:
         # If we optimized away an entire if block, we need to handle that
         if not node.orelse:
             return None
         # Only the else branch survived: negate the test and swap branches.
         node.test = gast.UnaryOp(op=gast.Not(), operand=node.test)
         node.body, node.orelse = node.orelse, node.body
     elif isinstance(node, (gast.While, gast.For)) and not node.body:
         return node.orelse
     return result
Example #11
0
def create_grad(node, namer, tangent=False):
    """Given a variable, create a variable for the gradient.

    Args:
      node: A node to create a gradient for, can be a normal variable (`x`),
          a subscript (`x[i]`) or a string holding a variable name.
      namer: The namer object which will determine the name to use for the
          gradient.
      tangent: Whether a tangent (instead of adjoint) is created.

    Returns:
      A node representing the gradient with the correct name e.g. the
      gradient of `x[i]` is `dx[i]`.

      Note that this returns an invalid node: the `ctx` of the returned name
      or subscript is `None`. It is assumed that this attribute is filled in
      later.

      A returned name node has an `adjoint_var` annotation referring to the
      node it is an adjoint of.

    Raises:
      TypeError: If `node` is not a subscript, name or string node.
    """
    if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)):
        raise TypeError

    # A node annotated with `temp_var` gets the gradient of the annotated
    # node instead.
    if anno.hasanno(node, 'temp_var'):
        replaced = anno.getanno(node, 'temp_var')
        return create_grad(replaced, namer, tangent)

    if isinstance(node, gast.Subscript):
        # Differentiate the subscripted value and re-apply the same slice.
        base_grad = create_grad(node.value, namer, tangent=tangent)
        base_grad.ctx = gast.Load()
        return gast.Subscript(value=base_grad, slice=node.slice, ctx=None)

    if isinstance(node, gast.Str):
        # Treat the string contents as a variable name.
        as_name = gast.Name(id=node.s, ctx=None, annotation=None)
        named_grad = create_grad(as_name, namer, tangent=tangent)
        return gast.Str(named_grad.id)

    # Only `gast.Name` is left after the type check above.
    grad_id = namer.grad(node.id, tangent)
    grad_node = gast.Name(id=grad_id, ctx=None, annotation=None)
    anno.setanno(grad_node, 'adjoint_var', node)
    return grad_node
Example #12
0
def remove_repeated_comments(node):
  """Remove comments that repeat themselves.

  Several statements may carry the same comment so that it survives when one
  of them is deleted by an optimization pass. This pass walks the tree in
  `gast.walk` order and deletes the `comment` annotation from every node
  whose comment text equals the text of the previously seen comment, so only
  the first of a run is kept.

  Args:
    node: An AST

  Returns:
    The same AST, where comments are not repeated in sequence.

  """
  previous_text = None
  for child in gast.walk(node):
    if not anno.hasanno(child, 'comment'):
      continue
    text = anno.getanno(child, 'comment')['text']
    if text == previous_text:
      anno.delanno(child, 'comment')
    previous_text = text
  return node
Example #13
0
  def visit(self, node):
    """Visit a node.

    This method is largely modelled after the ast.NodeTransformer class.

    Args:
      node: The node to visit.

    Returns:
      A tuple of the primal and adjoint, each of which is a node or a list of
      nodes.

    Raises:
      ValueError: If there is no `visit_<ClassName>` method for `node`.
    """
    handler_name = 'visit_' + node.__class__.__name__
    if not hasattr(self, handler_name):
      raise ValueError('Unknown node type: %s' % node.__class__.__name__)
    handler = getattr(self, handler_name)

    def _link(statements, label, counterpart):
      """Annotate one statement, or each of several, with `counterpart`."""
      if isinstance(statements, gast.AST):
        statements = [statements]
      for statement in statements:
        anno.setdefaultanno(statement, label, counterpart)

    # If this node is a statement, inform all child nodes what the active
    # variables in this statement are
    if anno.hasanno(node, 'active_in'):
      self.active_variables = anno.getanno(node, 'active_in')
    pri, adj = handler(node)

    # Annotate primal and adjoint statements
    _link(pri, 'adj', adj)
    _link(adj, 'pri', pri)

    return pri, adj
Example #14
0
 def visit(self, node):
     """Insert a `<name> = None` placeholder for pushed-but-undefined names.

     Applies to nodes with a `push_var` annotation: when the name of the
     pushed variable is not in the node's `defined_in` annotation, a
     statement assigning `None` to it is passed to `self.insert_top`.

     Args:
       node: The AST node to visit.

     Returns:
       Whatever the parent class' `visit` returns for `node`.
     """
     if anno.hasanno(node, 'push_var'):
         pushed = anno.getanno(node, 'push_var')
         varname = ast_.get_name(pushed)
         defined = anno.getanno(node, 'defined_in')
         if varname not in defined:
             placeholder = quoting.quote('{} = None'.format(varname))
             self.insert_top(placeholder)
     return super(FixStack, self).visit(node)
Example #15
0
  def visit_Assign(self, node):
    """Visit assignment statement.

    Builds the primal and adjoint statement lists for a single-target
    assignment. The primal stores the previous value of the target (and a
    fresh substack) before performing the assignment; the adjoint restores
    them, runs the adjoint of the right-hand side, resets the gradient of the
    target and accumulates the partial gradients found in the adjoint.

    Sets `self.orig_target`, and `self.target` for active assignments, so
    that the right-hand-side visitors can access them.

    Args:
      node: The assignment node; it must have exactly one target.

    Returns:
      A `(primal, adjoint)` tuple of statement lists.

    Raises:
      ValueError: If the assignment does not have exactly one target, or if
        visiting the right-hand side raises a ValueError (re-raised with the
        assignment targets added to the message).
    """
    if len(node.targets) != 1:
      raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
      orig_src = anno.getanno(node, 'pre_anf')
    else:
      orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if isinstance(node.value, gast.Call) and \
        anno.hasanno(node.value, 'func') and \
        anno.getanno(node.value, 'func') in (utils.Stack, utils.pop_stack):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in the
    # adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
      return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
      fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
      context = [t.id if hasattr(t, 'id') else t for t in node.targets]
      raise ValueError(
          'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
      adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
      for succ in gast.walk(n):
        if anno.hasanno(succ, 'temp_adjoint_var'):
          xi = anno.getanno(succ, 'temp_adjoint_var')
          dxi_partial = ast_.copy_node(succ)
          accumulations.append(template.replace(
              'd[xi] = add_grad(d[xi], dxi_partial)',
              namer=self.namer, replace_grad=template.Replace.FULL,
              xi=xi, dxi_partial=dxi_partial, add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to change.
    # Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
      assign = [fx]
    elif (isinstance(fx, list) and
          any(isinstance(ifx, gast.Assign) for ifx in fx)):
      assign = fx
    else:
      assign = template.replace(
          'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
      assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally accumulating
    # the partials
    adjoint = [create_tmp, restore_substack, restore
              ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
        isinstance(self.orig_target.slice.value, gast.Name)):
      push, pop, op_id = get_push_pop()
      index = self.orig_target.slice.value
      push_index = template.replace(
          'push(_stack, i, op_id)',
          push=push,
          i=index,
          _stack=self.stack,
          op_id=op_id)
      # The keyword must be `_stack`, the placeholder used in the template
      # and in every sibling template above; the former `_stack_` matched
      # nothing, so the placeholder was presumably left unsubstituted.
      pop_index = template.replace(
          'i = pop(_stack, op_id)',
          pop=pop,
          i=index,
          _stack=self.stack,
          op_id=op_id)

      # The index is pushed last in the primal and popped first in the
      # adjoint.
      primal.append(push_index)
      adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    grad_comment = 'Grad of: %s' % orig_src
    adjoint = [comments.add_comment(adj, grad_comment) for adj in adjoint]

    return primal, adjoint
Example #16
0
 def mark(self, node):
     """Annotate `node` with `self.src` as `pre_anf`, unless already set.

     Nothing happens when the node already has a `pre_anf` annotation or
     when `self.src` is falsy.

     Args:
       node: The node to annotate.
     """
     if anno.hasanno(node, 'pre_anf'):
         return
     if self.src:
         anno.setanno(node, 'pre_anf', self.src)
Example #17
0
 def visit(self, node):
     """Visit `node` with its reaching definitions exposed on `self`.

     For nodes carrying a `definitions_in` annotation,
     `self.reaching_definitions` holds that annotation while the parent
     class visits the node, and is reset to `None` afterwards.

     Args:
       node: The AST node to visit.
     """
     label = 'definitions_in'
     if anno.hasanno(node, label):
         self.reaching_definitions = anno.getanno(node, label)

     super(ReadCounts, self).visit(node)

     # Leaving the annotated node: its reaching definitions no longer apply.
     if anno.hasanno(node, label):
         self.reaching_definitions = None