Ejemplo n.º 1
0
def test_defined():
    """Check the `Defined` analysis annotations on the parsed fixture `g`.

    Runs `cfg.forward` with a `cfg.Defined` analysis over the parsed function
    and inspects the `defined_in` / `defined_out` annotations it leaves on the
    statements.
    """
    func_ast = tangent.quoting.parse_function(g)
    cfg.forward(func_ast, cfg.Defined())
    statements = func_ast.body[0].body

    # Going into the second statement only one variable (x) is guaranteed to
    # be defined.
    defined_before_second = anno.getanno(statements[1], 'defined_in')
    assert len(defined_before_second) == 1

    # Coming out of the first statement of the `if` body, both x and y are
    # defined.
    branch_statements = statements[0].body
    defined_after_branch = anno.getanno(branch_statements[0], 'defined_out')
    assert len(defined_after_branch) == 2
Ejemplo n.º 2
0
def _fix(node):
  """Repair the naively constructed adjoint.

  See `fixes.py` for details.

  In addition to patching the AST, this hands back the results of the
  dataflow analyses run on the primal, so that `split` mode can use them to
  carry over the state from primal to adjoint.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`. The primal is read from `node.body[0]` and the
        adjoint from `node.body[1]`.

  Returns:
    node: A module with the primal and adjoint function with additional
        variable definitions and such added so that pushes onto the stack and
        gradient accumulations are all valid.
    defined: The variables defined at the end of the primal.
    reaching: The variable definitions that reach the end of the primal.
  """
  # Dataflow analyses on the primal: the analysis objects are kept because
  # their `exit` results are part of the return value.
  primal_cfg = cfg.CFG.build_cfg(node.body[0])
  defined_analysis = cfg.Defined()
  defined_analysis.visit(primal_cfg.entry)
  reaching_analysis = cfg.ReachingDefinitions()
  reaching_analysis.visit(primal_cfg.entry)

  # Same two analyses on the adjoint, run through `cfg.forward`; each analysis
  # object is created right before it is used.
  adjoint = node.body[1]
  for analysis_cls in (cfg.Defined, cfg.ReachingDefinitions):
    cfg.forward(adjoint, analysis_cls())

  # Remove pushes of variables that were never defined. `CleanStack` runs over
  # the whole module, so the function definitions are looked up again from
  # `node.body` afterwards instead of reusing earlier references.
  fixes.CleanStack().visit(node)
  fixes.FixStack().visit(node.body[0])

  # First change accumulation into definition where possible, then define
  # gradients that might or might not be defined. Order matters.
  for grad_pass_cls in (fixes.CleanGrad, fixes.FixGrad):
    grad_pass_cls().visit(node.body[1])

  return node, defined_analysis.exit, reaching_analysis.exit