Esempio n. 1
0
def test_reaching():
    """Check the sizes of the `definitions_in` annotations on `f`.

    Parses `f`, runs `cfg.ReachingDefinitions` forward over the result and
    asserts how many definitions are annotated as reaching selected
    statements of the function body.
    """

    def reaching_count(stmt):
        # Number of entries in the 'definitions_in' annotation of `stmt`.
        return len(anno.getanno(stmt, 'definitions_in'))

    tree = tangent.quoting.parse_function(f)
    cfg.forward(tree, cfg.ReachingDefinitions())
    statements = tree.body[0].body

    # First statement: only the argument reaches the expression.
    assert reaching_count(statements[0]) == 1

    loop_statements = statements[1].body
    # First loop statement: x is either the argument or comes from the
    # previous loop iteration.
    assert reaching_count(loop_statements[0]) == 2
    # Second loop statement: x can only come from the line just before it.
    assert reaching_count(loop_statements[1]) == 1

    # Third statement: x is the argument or the last definition made in the
    # while body.
    assert reaching_count(statements[2]) == 2
Esempio n. 2
0
def _fix(node):
  """Repair the naively constructed adjoint.

  See `fixes.py` for details.

  The results of the `Defined` and `ReachingDefinitions` analyses of the
  primal are returned as well, so that `split` mode can use them to carry
  state over from the primal to the adjoint.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`. The primal is read from `node.body[0]` and the
        adjoint from `node.body[1]`.

  Returns:
    node: A module with the primal and adjoint function with additional
        variable definitions and such added so that pushes onto the stack and
        gradient accumulations are all valid.
    defined: The variables defined at the end of the primal.
    reaching: The variable definitions that reach the end of the primal.
  """
  # The primal is analysed through an explicitly built CFG; both analysis
  # objects are kept because their `exit` attributes are part of the result.
  primal_graph = cfg.CFG.build_cfg(node.body[0])
  defined_analysis = cfg.Defined()
  defined_analysis.visit(primal_graph.entry)
  reaching_analysis = cfg.ReachingDefinitions()
  reaching_analysis.visit(primal_graph.entry)

  # Run the same two analyses over the adjoint.
  for analysis in (cfg.Defined, cfg.ReachingDefinitions):
    cfg.forward(node.body[1], analysis())

  # Remove pushes of variables that were never defined
  fixes.CleanStack().visit(node)
  fixes.FixStack().visit(node.body[0])

  # CleanGrad: change accumulation into definition if possible.
  # FixGrad: define gradients that might or might not be defined.
  for grad_fix in (fixes.CleanGrad, fixes.FixGrad):
    grad_fix().visit(node.body[1])

  return node, defined_analysis.exit, reaching_analysis.exit
Esempio n. 3
0
def read_counts(node):
    """Count how many times each variable definition was used.

    Reaching definitions analysis is run over `node` first, then a
    `ReadCounts` visitor walks the same tree.

    Args:
        node: An AST to analyze.

    Returns:
        A dictionary from assignment nodes to the number of times the
        assigned to variable was used (the visitor's `n_read` attribute).
    """
    reaching_analysis = cfg.ReachingDefinitions()
    cfg.forward(node, reaching_analysis)

    counter = ReadCounts()
    counter.visit(node)
    return counter.n_read
Esempio n. 4
0
def unused(node):
    """Find unused definitions that can be removed.

    Reaching definitions analysis is run over `node` first, followed by a
    walk over the AST with an `Unused` visitor to find all variable
    definitions that are not used later on.

    Args:
        node: The AST of e.g. a function body in which to look for unused
            variable definitions.

    Returns:
        The visitor's `unused` attribute after visiting all the nodes: a set
        of definitions in the form of `(variable_name, node)` pairs which are
        unused in this AST.
    """
    reaching_analysis = cfg.ReachingDefinitions()
    cfg.forward(node, reaching_analysis)

    finder = Unused()
    finder.visit(node)
    return finder.unused