Пример #1
0
def dead_code_elimination(node):
    """Remove definitions whose values are never read.

    Reaching-definitions information (through ``annotate.unused``) yields the
    definitions that are never used; all of them are removed except function
    arguments and ``for`` loops, which are always kept.

    Push/pop pairs are removed together: if a removed statement contains a
    node annotated with ``'push'`` (i.e. a pop), the push stored in that
    annotation is removed as well. This *requires dead code elimination to be
    performed on the primal and adjoint simultaneously*.

    Args:
      node: The AST to optimize.

    Returns:
      ``node`` after the removal, with ``anno.clearanno`` applied to it.
    """
    always_kept = (gast.arguments, gast.For)
    doomed = {entry[1] for entry in annotate.unused(node)
              if not isinstance(entry[1], always_kept)}

    # Iterate over a snapshot: the pushes discovered here are added to the
    # removal set but are not themselves searched for further pops.
    for statement in tuple(doomed):
        doomed.update(anno.getanno(child, 'push')
                      for child in gast.walk(statement)
                      if anno.getanno(child, 'push', False))

    transformers.Remove(doomed).visit(node)
    anno.clearanno(node)
    return node
Пример #2
0
 def visit(self, node):
     """Propagate dataflow facts through CFG ``node`` and, if needed, onward.

     For a node that wraps an AST statement (``node.value``), the incoming
     fact is the ``self.op`` reduction of the out-annotations of those
     predecessors that already have one (an empty frozenset if none do). The
     in, gen, kill and out annotations are then written onto the statement,
     with ``out = (in - kill) | gen``. If the out-annotation changed, every
     successor is visited recursively, so propagation stops once the
     out-facts stop changing.

     For a node without a value (presumably the CFG exit node — confirm),
     the reduction of all predecessors' out-annotations is stored in
     ``self.exit``.

     Args:
       node: A CFG node exposing ``value``, ``prev`` and ``next``.
     """
     if node.value:
         # Remember the previous out-fact so we can tell whether it changed.
         if anno.hasanno(node.value, self.out_label):
             before = hash(anno.getanno(node.value, self.out_label))
         else:
             before = None
         # Combine the out-facts of predecessors that were already annotated.
         preds = [
             anno.getanno(pred.value, self.out_label) for pred in node.prev
             if anno.hasanno(pred.value, self.out_label)
         ]
         if preds:
             incoming = functools.reduce(self.op, preds[1:], preds[0])
         else:
             incoming = frozenset()
         anno.setanno(node.value, self.in_label, incoming, safe=False)
         gen, kill = self.gen(node, incoming)
         anno.setanno(node.value, self.gen_label, gen, safe=False)
         anno.setanno(node.value, self.kill_label, kill, safe=False)
         anno.setanno(node.value,
                      self.out_label, (incoming - kill) | gen,
                      safe=False)
         # Only revisit successors when the out-fact changed.
         # NOTE(review): the change is detected by comparing hashes, so a hash
         # collision would silently stop propagation early.
         if hash(anno.getanno(node.value, self.out_label)) != before:
             for succ in node.next:
                 self.visit(succ)
     else:
         # NOTE(review): unlike the branch above, unannotated predecessors are
         # not skipped here and an empty ``node.prev`` makes ``preds[0]``
         # raise IndexError — confirm this node always has annotated
         # predecessors.
         preds = [
             anno.getanno(pred.value, self.out_label) for pred in node.prev
         ]
         self.exit = functools.reduce(self.op, preds[1:], preds[0])
Пример #3
0
def test_defined():
    # Run the "definitely defined" forward analysis over ``g``.
    tree = tangent.quoting.parse_function(g)
    cfg.forward(tree, cfg.Defined())
    statements = tree.body[0].body

    # Entering the second statement, only x is guaranteed to be defined.
    defined_on_entry = anno.getanno(statements[1], 'defined_in')
    assert len(defined_on_entry) == 1

    # After the first statement of the if-branch, both x and y are defined.
    first_in_branch = statements[0].body[0]
    defined_on_exit = anno.getanno(first_in_branch, 'defined_out')
    assert len(defined_on_exit) == 2
Пример #4
0
 def visit(self, node):
     """Drop AD-generated pushes (and their pops) of undefined variables.

     A statement qualifies when it carries the ``push_var``, ``pop`` and
     ``gen_push`` annotations. If the pushed variable's name is not among the
     names of the definitions reaching the statement (``definitions_in``),
     both the statement and its paired pop are passed to ``self.remove``.
     Visiting then continues as usual.
     """
     generated_push = (anno.hasanno(node, 'push_var')
                       and anno.hasanno(node, 'pop')
                       and anno.hasanno(node, 'gen_push'))
     if generated_push:
         reaching_names = frozenset(
             name for name, _ in anno.getanno(node, 'definitions_in'))
         pushed_name = ast_.get_name(anno.getanno(node, 'push_var'))
         if pushed_name not in reaching_names:
             self.remove(node)
             self.remove(anno.getanno(node, 'pop'))
     return super(CleanStack, self).visit(node)
Пример #5
0
 def visit(self, node):
     """Mark every reaching definition that ``node``'s subtree reads.

     Nodes annotated with ``definitions_gen`` add those definitions to
     ``self.definitions`` and set ``self.reaching_definitions`` while their
     children are visited. Every loaded ``gast.Name`` marks each reaching
     definition of the same name (first tuple element) as used.
     """
     generates_definitions = anno.hasanno(node, 'definitions_gen')
     if generates_definitions:
         self.definitions.update(anno.getanno(node, 'definitions_gen'))
         self.reaching_definitions = anno.getanno(node, 'definitions_in')

     if isinstance(node, gast.Name) and isinstance(node.ctx, gast.Load):
         read_name = node.id
         self.used.update(
             definition for definition in self.reaching_definitions
             if definition[0] == read_name)

     super(Unused, self).visit(node)

     # Leave the statement's context once its whole subtree has been seen.
     if generates_definitions:
         self.reaching_definitions = None
Пример #6
0
 def _init(self, node):
     """Build ``<grad> = INIT_GRAD(<var>)`` for the gradient variable ``node``.

     ``<grad>`` is the name of ``node``. ``<var>`` is the node stored in its
     ``adjoint_var`` annotation, or in ``temp_adjoint_var`` when the former
     is absent.
     """
     target_name = ast_.get_name(node)
     source_key = ('adjoint_var' if anno.hasanno(node, 'adjoint_var')
                   else 'temp_adjoint_var')
     primal_var = anno.getanno(node, source_key)

     target = gast.Name(id=target_name, ctx=gast.Store(), annotation=None)
     init_call = gast.Call(func=utils.INIT_GRAD,
                           args=[primal_var],
                           keywords=[])
     return gast.Assign(targets=[target], value=init_call)
Пример #7
0
def test_reaching():
    tree = tangent.quoting.parse_function(f)
    cfg.forward(tree, cfg.ReachingDefinitions())
    stmts = tree.body[0].body

    def n_reaching(stmt):
        # Number of definitions that reach ``stmt``.
        return len(anno.getanno(stmt, 'definitions_in'))

    # Only the argument reaches the first expression.
    assert n_reaching(stmts[0]) == 1
    loop_stmts = stmts[1].body
    # At the loop head x is the argument or the previous iteration's value.
    assert n_reaching(loop_stmts[0]) == 2
    # Here x can only come from the preceding line.
    assert n_reaching(loop_stmts[1]) == 1
    # After the loop x is the argument or the last definition in the body.
    assert n_reaching(stmts[2]) == 2
Пример #8
0
    def visit_FunctionDef(self, node):
        """Turn ``node`` into its forward-mode (tangent) counterpart.

        The body is transformed statement by statement. A tangent argument is
        appended for every argument position in ``self.wrt``. When
        ``self.check_dims`` is set, shape checks between each primal and its
        seed tangent are prepended. Tangent initializations for the remaining
        arguments are prepended, and the function is renamed.

        Args:
            node: The ``gast.FunctionDef`` to transform; modified in place.

        Returns:
            ``node``.

        Raises:
            ValueError: If the number of tangent arguments created differs
                from the number of requested ``wrt`` positions (e.g. a
                position is out of range or repeated).
        """
        self.namer = naming.Namer.build(node)

        # Get the tangent of the body; one statement may expand into several.
        new_body = []
        for stmt in node.body:
            new = self.visit(stmt)
            if isinstance(new, (list, tuple)):
                new_body.extend(new)
            else:
                new_body.append(new)
        node.body = new_body

        # Add in the initial gradient argument
        grad_args = [
            create.create_grad(arg, self.namer, tangent=True)
            for i, arg in enumerate(node.args.args) if i in self.wrt
        ]
        if len(self.wrt) != len(grad_args):
            # Format the message itself; applying ``%`` to the exception
            # object would raise a TypeError instead of this ValueError.
            raise ValueError(
                'Mismatch between requested and retrieved derivative arguments. '
                'Requested %d, found %d' % (len(self.wrt), len(grad_args)))

        node.args.args += grad_args

        if self.check_dims:
            # Define the shape check code quote
            def shape_match_template(primal, tangent_):
                if not tangent.shapes_match(primal, tangent_):
                    raise ValueError(
                        'Shape mismatch between argument value (%s) and seed derivative '
                        '(%s)' \
                  % (numpy.shape(primal), numpy.shape(tangent_)))

            # Add a shape check for each seed derivative & primal pair.
            shape_check_nodes = []
            for iwrt, tangent_var in zip(self.wrt, grad_args):
                primal = node.args.args[iwrt]
                shape_check = template.replace(shape_match_template,
                                               primal=primal,
                                               tangent_=tangent_var)[0]
                shape_check_nodes.append(shape_check)
            node.body = shape_check_nodes + node.body

        # Add in gradient initialization statements for everything else.
        # NOTE(review): this enumerates node.args.args *after* the tangent
        # arguments were appended, so appended tangent arguments whose
        # position is not in self.wrt also get an init statement — confirm
        # this is intended.
        grad_init_nodes = [
            template.replace('d[z] = init_grad(z)',
                             replace_grad=template.Replace.TANGENT,
                             namer=self.namer,
                             z=arg,
                             init_grad=utils.INIT_GRAD)
            for i, arg in enumerate(node.args.args) if i not in self.wrt
        ]
        node.body = grad_init_nodes + node.body

        # Rename the function
        func = anno.getanno(node, 'func')
        node.name = naming.tangent_name(func, self.wrt)

        return node
Пример #9
0
 def visit(self, node, abort=astor.codegen.SourceGenerator.abort_visit):
     """Emit source for ``node``, placing its ``'comment'`` annotation if any.

     The comment is written above, below, or to the right of the node's code,
     according to ``comment['location']``.

     Args:
       node: The AST node to emit.
       abort: Not used in this method (presumably kept for signature
         compatibility with astor's ``SourceGenerator.visit`` — confirm).

     Raises:
       TangentParseError: If the comment location is not ``'above'``,
         ``'below'`` or ``'right'``.
     """
     if anno.hasanno(node, 'comment'):
         comment = anno.getanno(node, 'comment')
         # Preprocess the comment to fit to maximum line width of 80 characters
         # (78 characters of text plus the '# ' prefix; the current indentation
         # is not taken into account). The truncated text is written back into
         # the annotation dict itself.
         linewidth = 78
         if comment['location'] in ('above', 'below'):
             comment['text'] = comment['text'][:linewidth]
         # One leading newline right after a new indentation level, two otherwise.
         n_newlines = 1 if self.new_indentation else 2
         if comment['location'] == 'above':
             self.result.append('\n' * n_newlines)
             self.result.append(self.indent_with * self.indentation)
             self.result.append('# %s' % comment['text'])
             super(SourceWithCommentGenerator, self).visit(node)
         elif comment['location'] == 'below':
             super(SourceWithCommentGenerator, self).visit(node)
             self.result.append('\n')
             self.result.append(self.indent_with * self.indentation)
             self.result.append('# %s' % comment['text'])
             self.result.append('\n' * (n_newlines - 1))
         elif comment['location'] == 'right':
             super(SourceWithCommentGenerator, self).visit(node)
             self.result.append(' # %s' % comment['text'])
         else:
             raise TangentParseError('Only valid comment locations are '
                                     'above, below, right')
     else:
         # NOTE(review): the new-indentation flag is only reset for nodes
         # without a comment; commented nodes leave it unchanged.
         self.new_indentation = False
         super(SourceWithCommentGenerator, self).visit(node)
Пример #10
0
    def visit_Expr(self, node):
        """Return the primal and adjoint for an expression statement.

        Only pushes (``utils.push(_stack, x, op_id)``) receive an adjoint. It
        consists of the push's adjoint template from ``grads.adjoints``,
        followed by an accumulation ``d[x] = add_grad(d[x], <partial>)``.
        ``<partial>`` is a copy of the target of the template's first
        statement. Every adjoint statement is commented with the original
        push's source.

        Returns:
            A tuple ``(node, adjoint_statements)``; the list is empty for
            anything that is not a push.
        """
        call = node.value
        is_push = (isinstance(call, gast.Call)
                   and anno.getanno(call, 'func') == utils.push)
        if not is_push:
            return node, []

        source = quoting.unquote(node)
        stack, val, op_id = call.args
        rhs_adjoint = template.replace(grads.adjoints[utils.push],
                                       namer=self.namer,
                                       stack=stack,
                                       val=val,
                                       op_id=op_id)

        # Sum the temporary adjoint produced by the template into d[val].
        partial = ast_.copy_node(rhs_adjoint[0].targets[0])
        accumulation = template.replace(
            'd[xi] = add_grad(d[xi], dxi_partial)',
            namer=self.namer,
            replace_grad=template.Replace.FULL,
            xi=val,
            dxi_partial=partial,
            add_grad=utils.ADD_GRAD)

        note = 'Grad of: %s' % source
        adjoint = [comments.add_comment(stmt, note)
                   for stmt in rhs_adjoint + [accumulation]]
        return node, adjoint
Пример #11
0
 def is_active(self, node):
   """Return whether ``node`` reads any variable in its ``active_in`` set."""
   active_variables = anno.getanno(node, 'active_in')
   return any(
       isinstance(child, gast.Name) and isinstance(child.ctx, gast.Load) and
       child.id in active_variables
       for child in gast.walk(node))
Пример #12
0
 def visit_Assign(self, node):
     """Drop the accumulation when the target has no reaching definition.

     Applies to assignments whose value is a call with an ``add_grad``
     annotation on its function (presumably ``x = add_grad(x, partial)``).
     When the target's name is not among the names in ``definitions_in``,
     the value is replaced by the call's second argument.
     """
     rhs = node.value
     if not (isinstance(rhs, gast.Call) and anno.hasanno(rhs.func, 'add_grad')):
         return node
     reaching_names = frozenset(
         name for name, _ in anno.getanno(node, 'definitions_in'))
     if ast_.get_name(node.targets[0]) not in reaching_names:
         node.value = rhs.args[1]
     return node
Пример #13
0
  def visit(self, node):
    """Dispatch to ``visit_<NodeType>``, falling back to ``generic_visit``.

    A node's ``active_in`` annotation, when present, is stored in
    ``self.active_variables`` before dispatching so children can consult it.
    """
    if anno.hasanno(node, 'active_in'):
      self.active_variables = anno.getanno(node, 'active_in')
    handler_name = 'visit_%s' % node.__class__.__name__
    handler = getattr(self, handler_name, self.generic_visit)
    return handler(node)
Пример #14
0
def _get_stack_op_handle(node):
    """Return the function called by ``node``, or False if unknown.

    A truthy ``'func'`` annotation is returned as-is. Otherwise the call's
    function source text is matched against the four ``tangent`` stack
    operations; anything else yields False.
    """
    assert isinstance(node, gast.Call), 'Only can get fn handles of Call nodes'
    annotated = anno.getanno(node, 'func', False)
    if annotated:
        return annotated
    stack_ops = {
        'tangent.pop': utils.pop,
        'tangent.push': utils.push,
        'tangent.pop_stack': utils.pop_stack,
        'tangent.push_stack': utils.push_stack,
    }
    return stack_ops.get(quoting.unquote(node.func), False)
Пример #15
0
 def prepend_uninitialized_grads(self, node):
     """Initialize gradient variables that are read before being defined.

     For each loaded ``Name`` in ``node`` that carries an ``adjoint_var`` or
     ``temp_adjoint_var`` annotation and is not in ``node``'s ``defined_in``
     set, an initialization built by ``self._init`` is passed to
     ``self.insert_top``. This happens once per name, tracked in ``self.added``.
     Nodes without ``defined_in`` are left alone.

     Returns:
       ``node``.
     """
     if not anno.hasanno(node, 'defined_in'):
         return node
     defined = anno.getanno(node, 'defined_in')
     for child in gast.walk(node):
         if not (isinstance(child, gast.Name)
                 and isinstance(child.ctx, gast.Load)):
             continue
         is_gradient = (anno.hasanno(child, 'adjoint_var')
                        or anno.hasanno(child, 'temp_adjoint_var'))
         if (is_gradient and child.id not in defined
                 and child.id not in self.added):
             self.added.add(child.id)
             self.insert_top(self._init(child))
     return node
Пример #16
0
def test_resolve():
    # resolve_calls should annotate each Call with the function it refers to.
    # g is local to this test, so it is only reachable from f as a free
    # variable; the local names below must not be renamed.
    def g(x):
        return 2 * x

    def f(x):
        return g(x)

    node = annotate.resolve_calls(f)
    # body[0] is the FunctionDef of f; its first statement is `return g(x)`.
    assert anno.getanno(node.body[0].body[0].value, 'func') == g

    # Redefine f to call a name (h) that is not defined anywhere.
    def f(x):
        return h(x)

    # NOTE(review): this result is never used below.
    node = quoting.parse_function(f)
    # Resolving the undefined name is expected to raise AttributeError.
    with pytest.raises(AttributeError):
        annotate.resolve_calls(f)
Пример #17
0
def create_grad(node, namer, tangent=False):
    """Given a variable, create a variable for the gradient.

    Args:
      node: A node to create a gradient for: a normal variable (`x`), a
          subscript (`x[i]`) or a string constant naming a variable (`'x'`).
          If the node has a `temp_var` annotation, the gradient of the
          annotated node is created instead.
      namer: The namer object which will determine the name to use for the
          gradient.
      tangent: Whether a tangent (instead of adjoint) is created.

    Returns:
      node: A node representing the gradient with the correct name, e.g. the
          gradient of `x[i]` is `dx[i]` and the gradient of `'x'` is the
          string `'dx'` (the exact name is chosen by `namer`).

          Note that for names and subscripts this returns an invalid node,
          with `ctx` set to None. It is assumed that this attribute is filled
          in later.

          The gradient `Name` (the returned node itself, or the `value` of a
          returned subscript) has an `adjoint_var` annotation referring to the
          node it is a gradient of. Returned string constants carry no
          annotation.

    Raises:
      TypeError: If `node`, or the base of a subscript, is not of a supported
          node type.
    """
    if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)):
        raise TypeError('Cannot create a gradient for a %s node; expected '
                        'Name, Subscript or Str' % type(node).__name__)

    if anno.hasanno(node, 'temp_var'):
        return create_grad(anno.getanno(node, 'temp_var'), namer, tangent)

    def _name_grad(name_node):
        """Create the gradient `Name` for a plain variable `Name`."""
        if not isinstance(name_node, gast.Name):
            raise TypeError('Cannot create a gradient name for a %s node; '
                            'expected Name' % type(name_node).__name__)
        grad_name = namer.grad(name_node.id, tangent)
        grad_node = gast.Name(id=grad_name, ctx=None, annotation=None)
        anno.setanno(grad_node, 'adjoint_var', name_node)
        return grad_node

    if isinstance(node, gast.Subscript):
        grad_node = create_grad(node.value, namer, tangent=tangent)
        grad_node.ctx = gast.Load()
        # NOTE: the slice node is shared with `node`, not copied.
        return gast.Subscript(value=grad_node, slice=node.slice, ctx=None)
    elif isinstance(node, gast.Str):
        grad_node = create_grad(gast.Name(id=node.s, ctx=None,
                                          annotation=None),
                                namer,
                                tangent=tangent)
        return gast.Str(grad_node.id)
    else:
        return _name_grad(node)
Пример #18
0
def assignment_propagation(node):
    """Perform assignment propagation.

    Assignment propagation is not a compiler optimization as much as a
    readability optimization. If a variable name is used only once, it gets
    renamed when possible e.g. `y = x; z = y` will become `z = x`.

    Reads the `definitions_in` annotation of every single-target `Name = Name`
    assignment (presumably produced by a prior reaching-definitions pass —
    confirm). Annotations are cleared with `anno.clearanno` afterwards.

    Args:
      node: The AST to optimize. Assignment values inside it are rewritten in
        place.

    Returns:
      The optimized AST.
    """
    n_reads = read_counts(node)

    to_remove = []
    for succ in gast.walk(node):
        # We found an assignment of the form a = b
        # - Left-hand side is a Name, right-hand side is a Name.
        if (isinstance(succ, gast.Assign)
                and isinstance(succ.value, gast.Name)
                and len(succ.targets) == 1
                and isinstance(succ.targets[0], gast.Name)):
            rhs_name = succ.value.id
            # We now find all the places that b was defined
            rhs_defs = [
                def_[1] for def_ in anno.getanno(succ, 'definitions_in')
                if def_[0] == rhs_name
            ]
            # If b was defined in only one place (not an argument), and wasn't used
            # anywhere else but in a = b, and was defined as b = x, then we can fold
            # the statements
            if (len(rhs_defs) == 1 and isinstance(rhs_defs[0], gast.Assign)
                    and n_reads[rhs_defs[0]] == 1
                    and isinstance(rhs_defs[0].value, gast.Name)
                    and isinstance(rhs_defs[0].targets[0], gast.Name)):
                # Mark rhs_def for deletion
                to_remove.append(rhs_defs[0])
                # Propagate the definition: a = b becomes a = x
                succ.value = rhs_defs[0].value

    # Remove the definitions we folded
    transformers.Remove(to_remove).visit(node)
    anno.clearanno(node)
    return node
Пример #19
0
 def active(node, incoming):
     """Compute the gen and kill sets of CFG ``node`` for activity analysis.

     * Function arguments: the arguments at the positions in ``wrt`` (a free
       variable of the enclosing scope) are generated.
     * Assignments: the updated variables are generated if the right-hand
       side is a ``utils.pop`` call, so that all values popped off the stack
       are live, or if it mentions any name in ``incoming``. Otherwise they
       are killed.

     Returns:
       A ``(gen, kill)`` tuple of sets.
     """
     gen, kill = set(), set()
     stmt = node.value
     if isinstance(stmt, gast.arguments):
         gen.update(stmt.args[i].id for i in wrt)
     if isinstance(stmt, gast.Assign):
         is_pop = anno.getanno(stmt.value, 'func', False) == utils.pop
         if is_pop or any(isinstance(child, gast.Name) and child.id in incoming
                          for child in gast.walk(stmt.value)):
             gen.update(ast_.get_updated(stmt))
         else:
             kill.update(ast_.get_updated(stmt))
     return gen, kill
Пример #20
0
  def visit_For(self, node):
    """Rewrite loops over an active iterable into explicit integer indexing.

    Iterators of the form ``for a in x`` rely on an implicit indexing
    operation, which Tangent cannot reverse without more information. When
    ``x`` is a ``Name``, ``Subscript`` or ``Attribute`` whose base name is in
    the loop's ``active_in`` set, the loop is rewritten as::

        for a in x:          for i in range(len(x)):
          f(a)        ->       a = x[i]
                               f(a)

    Here ``i`` is a fresh name obtained from ``self.namer.unique('_idx')``.
    Integer indexes are used, which will cause strange behavior if the
    iterator's ``next()`` behavior deviates from a plain incrementing index.
    The right thing to do (eventually) is to write a multiple-dispatch version
    of the ``next`` operator, and its adjoint, so that we can handle e.g.
    dicts.

    Args:
      node: The ``gast.For`` node; modified in place when rewritten.

    Returns:
      ``node``.
    """
    if not isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)):
      return node
    iter_name = ast.get_name(node.iter)
    if iter_name not in anno.getanno(node, 'active_in'):
      return node

    # Source of the *whole* iterable expression (e.g. ``x[0]`` or
    # ``self.xs``). Using only the base name would make the loop run over
    # ``len(x)`` while the body indexes ``x[0][i]``. For a plain Name this is
    # the same as iter_name.
    iter_src = quoting.unquote(node.iter).strip()

    # Get a unique iterator name
    old_target = copy.deepcopy(node.target)
    new_target = quoting.quote(self.namer.unique('_idx'))
    old_iter = copy.deepcopy(node.iter)

    item_access = template.replace(
        'old_target = x[i]',
        old_target=old_target,
        x=old_iter,
        i=new_target)

    node.target = gast.Name(id=new_target.id, ctx=gast.Store(), annotation=None)
    node.iter = quoting.quote('range(len(%s))' % iter_src)
    anno.setanno(node.iter, 'func', range)
    anno.setanno(node.iter.args[0], 'func', len)
    node.body = [item_access] + node.body

    return node
Пример #21
0
def is_insert_grad_of_statement(node):
    """Check whether a context manager calls `insert_grad_of`.

    Args:
      node: The context manager (`with`) node.

    Returns:
      True if every item of the `with` statement is an `insert_grad_of` call,
      False if none is.

    Raises:
      ValueError: If the `insert_grad_of` calls are mixed with other calls.
    """
    # One flag per `with` item: is its context expression `insert_grad_of`?
    tangent_calls = [
        anno.getanno(item.context_expr, 'func', None) is utils.insert_grad_of
        for item in node.items
    ]
    if all(tangent_calls):
        return True
    elif any(tangent_calls):
        raise ValueError(
            'insert_grad_of cannot be mixed with other context managers in '
            'the same with statement')
    else:
        return False
Пример #22
0
  def is_active(self, node):
    """Checks whether an assignment statement is active.

    An assignment is active when its right hand side reads a variable in
    ``self.active_variables``. A right hand side that is a ``utils.pop`` call
    is always treated as active so that it gets processed.

    Args:
      node: an instance of gast.Assign

    Returns:
      Whether the statement is active.
    """
    rhs = node.value
    if isinstance(rhs, gast.Call) and (
        anno.getanno(rhs, 'func', False) == utils.pop):
      return True
    return any(
        isinstance(child, gast.Name) and isinstance(child.ctx, gast.Load) and
        child.id in self.active_variables
        for child in gast.walk(rhs))
Пример #23
0
  def visit_FunctionDef(self, node):
    """Turn ``node`` into its forward-mode (tangent) counterpart.

    Each body statement is visited, and list/tuple results are spliced in. A
    tangent argument is appended for every argument position in
    ``self.wrt``. ``d[z] = init_grad(z)`` statements are prepended for every
    other position, and the function is renamed with
    ``naming.tangent_name``.

    Returns:
      ``node``, modified in place.
    """
    self.namer = naming.Namer.build(node)

    transformed = []
    for stmt in node.body:
      result = self.visit(stmt)
      if isinstance(result, (list, tuple)):
        transformed.extend(result)
      else:
        transformed.append(result)
    node.body = transformed

    seed_tangents = [
        create.create_grad(arg, self.namer, tangent=True)
        for position, arg in enumerate(node.args.args)
        if position in self.wrt
    ]
    node.args.args += seed_tangents

    # NOTE(review): positions are enumerated over the already-extended
    # argument list, so appended tangent arguments not in self.wrt are
    # initialized as well.
    init_statements = []
    for position, arg in enumerate(node.args.args):
      if position in self.wrt:
        continue
      init_statements.append(
          template.replace(
              'd[z] = init_grad(z)',
              replace_grad=template.Replace.TANGENT,
              namer=self.namer,
              z=arg,
              init_grad=utils.INIT_GRAD))
    node.body = init_statements + node.body

    node.name = naming.tangent_name(anno.getanno(node, 'func'), self.wrt)
    return node
Пример #24
0
def remove_repeated_comments(node):
  """Remove comments that repeat themselves.

  Several statements may carry the same comment so that it survives if one
  of them is deleted by an optimization pass. Walking the tree in
  ``gast.walk`` order, this pass deletes each comment annotation whose text
  equals that of the previously seen comment, so only the first of a run of
  identical comments is kept.

  Args:
    node: An AST

  Returns:
    ``node``, with repeated comment annotations removed in place.
  """
  previous_text = None
  for child in gast.walk(node):
    if not anno.hasanno(child, 'comment'):
      continue
    text = anno.getanno(child, 'comment')['text']
    if text == previous_text:
      anno.delanno(child, 'comment')
    previous_text = text
  return node
Пример #25
0
  def visit_Expr(self, node):
    """Pair pushes with their tangent push.

    For calls to ``tangent.push`` / ``tangent.push_stack``, the tangent
    template from ``tangents.tangents`` is filled in. Its first statement is
    recorded on ``self.metastack`` so that pushes and pops can be matched in
    tangent mode. The original push followed by the tangent statements is
    returned (this reverses the usual tangent/primal order). Any other
    expression is returned unchanged.
    """
    call = node.value
    if not isinstance(call, gast.Call):
      return node
    fn = anno.getanno(call, 'func')
    if fn not in (tangent.push, tangent.push_stack):
      return node

    push_template = tangents.tangents[fn]
    stack_node, var_node, op_id = call.args
    tangent_stmts = template.replace(
        push_template,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        x=var_node,
        stack=stack_node,
        op_id=op_id)
    self.metastack.append(tangent_stmts[0])
    return [node] + tangent_stmts
Пример #26
0
  def visit(self, node):
    """Visit a node and cross-link the resulting primal and adjoint.

    Dispatch mirrors ``ast.NodeTransformer``: the node goes to
    ``visit_<NodeType>``. Before dispatching, the node's ``active_in``
    annotation (if any) is stored in ``self.active_variables`` so child
    visitors know which variables are active in this statement. Afterwards
    ``anno.setdefaultanno`` gives every primal node an ``'adj'`` annotation
    (the adjoint) and every adjoint node a ``'pri'`` annotation (the primal).

    Args:
      node: The node to visit.

    Returns:
      A tuple of the primal and adjoint, each of which is a node or a list of
      nodes.

    Raises:
      ValueError: If there is no visitor for this node type.
    """
    node_type = node.__class__.__name__
    handler_name = 'visit_' + node_type
    if not hasattr(self, handler_name):
      raise ValueError('Unknown node type: %s' % node_type)
    handler = getattr(self, handler_name)

    if anno.hasanno(node, 'active_in'):
      self.active_variables = anno.getanno(node, 'active_in')
    primal, adjoint = handler(node)

    primal_nodes = [primal] if isinstance(primal, gast.AST) else primal
    for stmt in primal_nodes:
      anno.setdefaultanno(stmt, 'adj', adjoint)
    adjoint_nodes = [adjoint] if isinstance(adjoint, gast.AST) else adjoint
    for stmt in adjoint_nodes:
      anno.setdefaultanno(stmt, 'pri', primal)

    return primal, adjoint
Пример #27
0
 def visit(self, node):
     """Ensure pushed variables exist before they are pushed.

     If ``node`` pushes a variable (``push_var`` annotation) that is not in
     its ``defined_in`` set, ``<name> = None`` is passed to
     ``self.insert_top``. Visiting then continues as usual.
     """
     if anno.hasanno(node, 'push_var'):
         pushed = ast_.get_name(anno.getanno(node, 'push_var'))
         needs_placeholder = pushed not in anno.getanno(node, 'defined_in')
         if needs_placeholder:
             placeholder = quoting.quote('{} = None'.format(pushed))
             self.insert_top(placeholder)
     return super(FixStack, self).visit(node)
Пример #28
0
  def visit_Call(self, node):
    """Build the forward-mode (tangent) statements for a call.

    Without ``self.target`` the call is returned unchanged. Otherwise:

    * if ``func`` has no template in ``tangents.tangents``, the call is
      replaced by a call to ``naming.tangent_name(func, active_args)``. The
      tangents of all ``Name`` arguments are passed as keywords, ``(func,
      active_args)`` is recorded in ``self.required`` (presumably so that
      tangent function is generated later — confirm), and ``self.value`` is
      set to False;
    * otherwise the registered template is filled in with the call's
      arguments. Where the template declares ``grads.DEFAULT``, the
      original function's defaults are used. For ``tangent.pop``, the
      matching push from ``self.metastack`` is attached as a ``'push'``
      annotation.

    Args:
      node: The ``gast.Call`` node.

    Returns:
      ``node`` when there is no target, otherwise a list of statements.

    Raises:
      errors.ForwardNotImplementedError: If ``func`` is in
        ``tangents.UNIMPLEMENTED_TANGENTS``.
      NotImplementedError: If ``func`` is ``tracing.Traceable``.
      ValueError: If ``func`` has no template and its source cannot be parsed.
    """
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      try:
        quoting.parse_function(func)
      except Exception:
        # A bare ``except:`` would also turn KeyboardInterrupt/SystemExit into
        # this ValueError. getattr keeps the message itself from failing for
        # callables without __name__.
        raise ValueError('No tangent found for %s, and could not get source.' %
                         getattr(func, '__name__', func))

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      tangent_keywords = []
      for i in active_args:
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      # NOTE(review): this `value` is overwritten below before being used.
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(template_).co_varnames[
          -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      # NOTE(review): `value` and `target` computed here are never used
      # afterwards; the namer calls are kept in case they have side effects.
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if self.metastack:
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
Пример #29
0
def test_active2():
    tree = tangent.quoting.parse_function(i)
    cfg.forward(tree, cfg.Active(wrt=(1,)))
    last_stmt = tree.body[0].body[-1]
    # Activity flows through y, so x and z end up active as well.
    active_after = anno.getanno(last_stmt, 'active_out')
    assert len(active_after) == 3
Пример #30
0
def test_active():
    tree = tangent.quoting.parse_function(h)
    cfg.forward(tree, cfg.Active(wrt=(1,)))
    last_stmt = tree.body[0].body[-1]
    # y was overwritten, so no variable is active after the last statement.
    active_after = anno.getanno(last_stmt, 'active_out')
    assert not active_after