Example #1
0
def dead_code_elimination(node):
    """Delete variable definitions whose values are never used.

    The candidates come from `annotate.unused(node)`; per the original design
    this is based on a reaching-definitions analysis of the function
    definitions. Function arguments and `for` loop headers are never deleted.

    Push/pop pairs are kept consistent: any node inside a deleted statement
    that carries a 'push' annotation causes the annotated push node to be
    deleted as well. Because of this, dead code elimination has to run on the
    primal and the adjoint at the same time.

    Args:
      node: The AST to optimize.

    Returns:
      The optimized AST (the same object as `node`, modified in place by
      `transformers.Remove`, with its annotations cleared).
    """
    keep_types = (gast.arguments, gast.For)
    dead = {
        def_[1]
        for def_ in annotate.unused(node)
        if not isinstance(def_[1], keep_types)
    }

    # Pushes paired with anything inside a dead statement. They are gathered
    # separately, so only the originally dead statements are walked.
    paired_pushes = set()
    for stmt in dead:
        for child in gast.walk(stmt):
            push = anno.getanno(child, 'push', False)
            if push:
                paired_pushes.add(push)

    transformers.Remove(dead | paired_pushes).visit(node)
    anno.clearanno(node)
    return node
Example #2
0
def _reaching_defs_of(stmt, name):
    """Return the set of definition nodes of `name` that reach `stmt`.

    Reads the 'definitions_in' annotation of `stmt`, whose entries are indexed
    as (variable name, defining node). Returns None when `stmt` has no such
    annotation, so that callers can treat missing information conservatively.
    """
    reaching = anno.getanno(stmt, 'definitions_in', None)
    if reaching is None:
        return None
    return {def_[1] for def_ in reaching if def_[0] == name}


def assignment_propagation(node):
    """Perform assignment propagation.

    Assignment propagation is not so much a compiler optimization as a
    readability optimization. When a variable is only a copy that is read
    once, the copy is folded away: e.g. `y = x; z = y` becomes `z = x`.

    A statement `a = b` is rewritten to `a = x` (and `b = x` deleted) only if:
      * exactly one definition of `b` reaches `a = b`;
      * that definition is a single-target assignment `b = x` of a plain name,
        so deleting it cannot drop any other store;
      * `read_counts` reports exactly one read of that definition (presumably
        the read in `a = b`);
      * the definitions of `x` that reach `a = b` are the same ones that
        reached `b = x`. This prevents folding across a redefinition of `x`,
        e.g. `y = x; x = 5; z = y` must not become `z = x`.

    Args:
      node: The AST to optimize.

    Returns:
      The optimized AST (the same object as `node`, modified in place, with its
      annotations cleared).
    """
    n_reads = read_counts(node)

    to_remove = []
    for succ in gast.walk(node):
        # Look for an assignment of the form `a = b`: a single Name target and
        # a Name value.
        if not (isinstance(succ, gast.Assign)
                and isinstance(succ.value, gast.Name)
                and len(succ.targets) == 1
                and isinstance(succ.targets[0], gast.Name)):
            continue

        rhs_name = succ.value.id
        # Every definition of `b` that reaches `a = b`.
        rhs_defs = [
            def_[1] for def_ in anno.getanno(succ, 'definitions_in')
            if def_[0] == rhs_name
        ]
        if len(rhs_defs) != 1:
            continue
        rhs_def = rhs_defs[0]

        # `b` must be defined by a plain `b = x` that is read only here.
        if not (isinstance(rhs_def, gast.Assign)
                and n_reads[rhs_def] == 1
                and isinstance(rhs_def.value, gast.Name)
                and len(rhs_def.targets) == 1
                and isinstance(rhs_def.targets[0], gast.Name)):
            continue

        # `x` must denote the same value here as it did at `b = x`. If the
        # annotation on `rhs_def` is missing, the helper returns None, the
        # comparison fails, and we do not fold.
        src_name = rhs_def.value.id
        if (_reaching_defs_of(rhs_def, src_name)
                != _reaching_defs_of(succ, src_name)):
            continue

        # Mark `b = x` for deletion and propagate `x` into `a = b`.
        to_remove.append(rhs_def)
        succ.value = rhs_def.value

    # Remove the definitions we folded.
    transformers.Remove(to_remove).visit(node)
    anno.clearanno(node)
    return node
Example #3
0
def joint(node):
  """Fuse the primal and adjoint into one function definition.

  After `_fix`, the module's first function is taken as the primal and its
  second as the adjoint. The merged function keeps the primal's name, takes
  the adjoint's arguments, and runs every primal statement except the last
  one, followed by all adjoint statements.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`.

  Returns:
    func: A `Module` node with a single function definition containing the
        combined primal and adjoint.
  """
  node, _, _ = _fix(node)
  primal, adjoint = node.body[0], node.body[1]

  # The primal's final statement is dropped so that execution continues into
  # the adjoint. Presumably that statement is the primal's return; confirm
  # against `reverse_ad`.
  merged_statements = primal.body[:-1] + adjoint.body
  merged_def = gast.FunctionDef(
      name=primal.name,
      args=adjoint.args,
      body=merged_statements,
      decorator_list=[],
      returns=None)
  func = gast.Module(body=[merged_def])

  anno.clearanno(func)
  return func
Example #4
0
def split(node, stack):
  """Hand state from the primal over to the adjoint through a stack.

  `_fix` returns the (possibly rewritten) module together with `defined` and
  `reaching` information. `store_state` uses that information to add the
  code that saves and restores state via `stack`.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`.
    stack: The stack node used for storing and restoring state.

  Returns:
    func: A `Module` node with two function definitions containing the primal
        and adjoint respectively.
  """
  fixed, defined, reaching = _fix(node)
  result = store_state(fixed, reaching, defined, stack)
  # Clear annotations before returning the tree.
  anno.clearanno(result)
  return result
Example #5
0
def forward_ad(node, wrt, preserve_result=False, check_dims=True):
    """Perform forward-mode AD on an AST.

    Steps:
      1. Build the control-flow graph of the function and run activity
         analysis, marking the indices of all `node.args.args` arguments as
         active.
      2. Apply the `ForwardAD` transformer to produce the naive derivative
         code.
      3. Annotate stack (push/pop) statements with `annotate.find_stacks`.
      4. Wrap the result in a `Module` and clear the annotations.

    Args:
      node: A `FunctionDef` AST node.
      wrt: A tuple of argument indices with respect to which we take the
          derivative.
      preserve_result: A boolean indicating whether the original
          non-differentiated function value should be returned.
      check_dims: A boolean indicating whether the provided derivatives should
          have the same shape as their corresponding arguments.

    Returns:
      mod: A `Module` node whose body is the single node produced by the
          `ForwardAD` transformer.
      required: A list of tuples of functions and argument indices. These
          functions were called by the function but did not have a derivative.

    Raises:
      TypeError: If `node` is not a `gast.FunctionDef`.
    """
    if not isinstance(node, gast.FunctionDef):
        raise TypeError(
            'forward_ad expects a gast.FunctionDef node, got %s'
            % type(node).__name__)

    # Activity analysis over the function's control-flow graph.
    cfg_obj = cfg.CFG.build_cfg(node)
    cfg.Active(range(len(node.args.args))).visit(cfg_obj.entry)

    # Build the forward-mode function.
    fad = ForwardAD(wrt, preserve_result, check_dims)
    node = fad.visit(node)

    # Annotate stacks.
    node = annotate.find_stacks(node)

    # Wrap the naive forward-mode code in a module and clean up annotations.
    node = gast.Module([node])
    anno.clearanno(node)

    return node, fad.required