# Example 1
def reverse_ad(node, wrt, preserve_result):
  """Perform reverse-mode AD on an AST.

  This function analyses the AST to determine which variables are active and
  proceeds by taking the naive derivative. Before returning the primal and
  adjoint it annotates push and pop statements as such.

  Args:
    node: A `FunctionDef` AST node.
    wrt: A tuple of argument indices with respect to which we take the
        derivative.
    preserve_result: A boolean indicating whether the generated
        derivative function should also return the original return value.

  Returns:
    mod: A `Module` node containing the naive primal and adjoint of the
        function which can be fed to the `split` and `joint` functions.
    required: A list of tuples of functions and argument indices. These
        functions were called by the function but did not have an adjoint.
    stack: The `stack` attribute of the `ReverseAD` transformer after it has
        visited `node`. NOTE(review): presumably the stack shared by the
        generated push/pop statements -- confirm against `ReverseAD`.

  Raises:
    TypeError: If `node` is not a `gast.FunctionDef`.
  """
  if not isinstance(node, gast.FunctionDef):
    raise TypeError(
        'reverse_ad expects a gast.FunctionDef node, got %s'
        % type(node).__name__)

  # Activity analysis, seeded with the `wrt` argument indices.
  cfg.forward(node, cfg.Active(wrt))

  # The transformer returns the primal and the adjoint as a pair; both are
  # wrapped in a single module so later passes see them together.
  ad = ReverseAD(wrt, preserve_result)
  pri, adj = ad.visit(node)
  mod = gast.Module(body=[pri, adj])

  # Annotate the push and pop statements in the generated code.
  mod = annotate.find_stacks(mod)
  return mod, ad.required, ad.stack
# Example 2
def forward_ad(node, wrt, preserve_result=False, check_dims=True):
    """Perform forward-mode AD on an AST.

    This function runs activity analysis over the AST and then applies the
    `ForwardAD` transformer to take the naive derivative. Push and pop
    statements in the result are annotated before it is wrapped in a module.

    Args:
        node: A `FunctionDef` AST node.
        wrt: A tuple of argument indices with respect to which we take the
            derivative.
        preserve_result: A boolean indicating whether the original
            non-differentiated function value should be returned.
        check_dims: A boolean indicating whether the provided derivatives
            should have the same shape as their corresponding arguments.

    Returns:
        mod: A `Module` node wrapping the node produced by the `ForwardAD`
            transformer (after stack annotation), with `anno.clearanno`
            applied to it.
        required: A list of tuples of functions and argument indices. These
            functions were called by the function but no derivative was
            available for them.

    Raises:
        TypeError: If `node` is not a `gast.FunctionDef`.
    """
    if not isinstance(node, gast.FunctionDef):
        raise TypeError(
            'forward_ad expects a gast.FunctionDef node, got %s'
            % type(node).__name__)

    # Activity analysis, seeded with every argument index (not only `wrt`).
    # NOTE(review): confirm seeding all arguments is intentional here.
    cfg_obj = cfg.CFG.build_cfg(node)
    cfg.Active(range(len(node.args.args))).visit(cfg_obj.entry)

    # Build the forward-mode function.
    fad = ForwardAD(wrt, preserve_result, check_dims)
    node = fad.visit(node)

    # Annotate push and pop statements.
    node = annotate.find_stacks(node)

    # Wrap the result in a module and clean up the naive forward-mode code.
    # NOTE(review): `clearanno` runs after `find_stacks`; confirm whether the
    # stack annotations are expected to survive this call.
    mod = gast.Module(body=[node])
    anno.clearanno(mod)

    return mod, fad.required