예제 #1
0
def reverse_ad(node, wrt, preserve_result):
  """Perform reverse-mode AD on an AST.

  This function analyses the AST to determine which variables are active and
  proceeds by taking the naive derivative. Before returning the primal and
  adjoint it annotates push and pop statements as such.

  Args:
    node: A `FunctionDef` AST node.
    wrt: A tuple of argument indices with respect to which we take the
        derivative.
    preserve_result: A boolean indicating whether the generated
        derivative function should also return the original return value.

  Returns:
    mod: A `Module` node containing the naive primal and adjoint of the
        function which can be fed to the `split` and `joint` functions.
    required: A list of tuples of functions and argument indices. These
        functions were called by the function but did not have an adjoint.
    stack: The `stack` attribute of the `ReverseAD` visitor after it has
        visited `node`, returned unchanged (presumably the stack used by the
        generated push/pop statements -- TODO confirm against `ReverseAD`).

  Raises:
    TypeError: If `node` is not a `gast.FunctionDef`.
  """
  if not isinstance(node, gast.FunctionDef):
    raise TypeError('reverse_ad expects a gast.FunctionDef node, got %s'
                    % type(node).__name__)
  # Activity analysis seeded with the `wrt` argument indices. The return value
  # of cfg.forward is discarded; its effect is the annotations it attaches to
  # the nodes of `node` (e.g. 'active_out').
  cfg.forward(node, cfg.Active(wrt))

  ad = ReverseAD(wrt, preserve_result)
  # Visiting the FunctionDef yields the primal and adjoint definitions.
  pri, adj = ad.visit(node)
  mod = gast.Module(body=[pri, adj])
  # Mark the push and pop statements in the combined module.
  mod = annotate.find_stacks(mod)
  return mod, ad.required, ad.stack
예제 #2
0
파일: test_cfg.py 프로젝트: zouzias/tangent
def test_defined():
    """Defined-variable analysis on `g`: check in/out counts."""
    tree = tangent.quoting.parse_function(g)
    cfg.forward(tree, cfg.Defined())
    stmts = tree.body[0].body

    # Entering the second statement, only x is guaranteed to be defined.
    defined_before_second = anno.getanno(stmts[1], 'defined_in')
    assert len(defined_before_second) == 1

    # After the first statement of the if body, both x and y are defined.
    first_in_if = stmts[0].body[0]
    defined_after_if_stmt = anno.getanno(first_in_if, 'defined_out')
    assert len(defined_after_if_stmt) == 2
예제 #3
0
파일: test_cfg.py 프로젝트: zouzias/tangent
def test_reaching():
    """Reaching-definitions analysis on `f`: check counts around a loop."""
    tree = tangent.quoting.parse_function(f)
    cfg.forward(tree, cfg.ReachingDefinitions())
    stmts = tree.body[0].body

    def reaching_count(stmt):
        # Number of definitions annotated as reaching `stmt`.
        return len(anno.getanno(stmt, 'definitions_in'))

    # Only the argument reaches the first expression.
    assert reaching_count(stmts[0]) == 1
    loop_stmts = stmts[1].body
    # Inside the loop x comes either from the argument or the previous
    # iteration.
    assert reaching_count(loop_stmts[0]) == 2
    # Here x can only come from the line just before it.
    assert reaching_count(loop_stmts[1]) == 1
    # After the loop x is either the argument or the loop body's last
    # definition.
    assert reaching_count(stmts[2]) == 2
예제 #4
0
def read_counts(node):
    """Count how many times each variable definition is read.

    Runs reaching-definitions analysis over `node` first, then walks the
    AST with a `ReadCounts` visitor.

    Args:
      node: An AST to analyze.

    Returns:
      A dictionary from assignment nodes to the number of times the variable
      they assign to was used (the visitor's `n_read` attribute).
    """
    analysis = cfg.ReachingDefinitions()
    cfg.forward(node, analysis)

    counter = ReadCounts()
    counter.visit(node)
    return counter.n_read
예제 #5
0
def unused(node):
    """Find unused definitions that can be removed.

    Runs reaching-definitions analysis over `node`, then walks the AST with an
    `Unused` visitor to collect variable definitions that are not used later
    on.

    Args:
      node: The AST of e.g. a function body to find unused variable
        definitions in.

    Returns:
      The `unused` attribute of the visitor after it has visited every node:
      a set of `(variable_name, node)` pairs for the definitions that are
      unused in this AST.
    """
    analysis = cfg.ReachingDefinitions()
    cfg.forward(node, analysis)

    finder = Unused()
    finder.visit(node)
    return finder.unused
예제 #6
0
def _fix(node):
  """Fix the naive construction of the adjoint.

  See `fixes.py` for details.

  This function also returns the result of reaching definitions analysis so
  that `split` mode can use this to carry over the state from primal to
  adjoint.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`. `node.body[0]` is treated as the primal and
        `node.body[1]` as the adjoint. The return values of the `fixes`
        visitors are discarded, so their changes are expected to be made in
        place on `node`, which is then returned.

  Returns:
    node: A module with the primal and adjoint function with additional
        variable definitions and such added so that pushes onto the stack and
        gradient accumulations are all valid.
    defined: The variables defined at the end of the primal.
    reaching: The variable definitions that reach the end of the primal.
  """
  # Run the Defined and ReachingDefinitions analyses on primal and adjoint.
  # For the primal, the analyses are run directly on its CFG rather than via
  # cfg.forward, presumably so the analysis objects (and their `exit`
  # results, returned below) stay accessible -- TODO confirm.
  pri_cfg = cfg.CFG.build_cfg(node.body[0])
  defined = cfg.Defined()
  defined.visit(pri_cfg.entry)
  reaching = cfg.ReachingDefinitions()
  reaching.visit(pri_cfg.entry)

  # Annotate the adjoint with both analyses; presumably the fixes below read
  # these annotations -- TODO confirm against fixes.py.
  cfg.forward(node.body[1], cfg.Defined())
  cfg.forward(node.body[1], cfg.ReachingDefinitions())

  # Remove pushes of variables that were never defined
  # NOTE: node.body[i] is re-read after each visit instead of being cached,
  # which stays correct even if a visitor replaces a body element.
  fixes.CleanStack().visit(node)
  fixes.FixStack().visit(node.body[0])

  # Change accumulation into definition if possible
  fixes.CleanGrad().visit(node.body[1])
  # Define gradients that might or might not be defined
  fixes.FixGrad().visit(node.body[1])
  return node, defined.exit, reaching.exit
예제 #7
0
파일: test_cfg.py 프로젝트: zouzias/tangent
def test_active2():
    """Activity analysis on `i` wrt argument 1: activity spreads via y."""
    tree = tangent.quoting.parse_function(i)
    cfg.forward(tree, cfg.Active(wrt=(1,)))
    last_stmt = tree.body[0].body[-1]
    # Through y, both x and z become active as well.
    active_after = anno.getanno(last_stmt, 'active_out')
    assert len(active_after) == 3
예제 #8
0
파일: test_cfg.py 프로젝트: zouzias/tangent
def test_active():
    """Activity analysis on `h` wrt argument 1: overwriting kills activity."""
    tree = tangent.quoting.parse_function(h)
    cfg.forward(tree, cfg.Active(wrt=(1,)))
    final_stmt = tree.body[0].body[-1]
    # y was overwritten by this point, so no variable remains active.
    active_after = anno.getanno(final_stmt, 'active_out')
    assert not active_after
예제 #9
0
 def visit_FunctionDef(self, node):
   """Run activity analysis wrt every argument, build a namer, then recurse.

   Every index of `node.args.args` is passed to `cfg.Active`, so all of the
   function's positional arguments seed the activity analysis. A
   `naming.Namer` built from `node` is stored on `self.namer` before the
   children are visited. Returns the result of `generic_visit`.
   """
   every_arg_index = range(len(node.args.args))
   cfg.forward(node, cfg.Active(every_arg_index))
   self.namer = naming.Namer.build(node)
   return self.generic_visit(node)