Beispiel #1
0
def if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
            composite_symbol_names):
    """Functional form of an if statement.

    Dispatches on the type of `cond`: a dense tensor is handed to the staged
    implementation `tf_if_stmt` (together with the state accessors and the
    symbol names), while any other value is executed with ordinary Python
    semantics through `_py_if_stmt`.

    Args:
        cond: Boolean condition; either a Python value or a dense tensor.
        body: Callable with no arguments whose return value holds the outputs
            of the positive (if) branch.
        orelse: Callable with no arguments whose return value holds the
            outputs of the negative (else) branch.
        get_state: Function returning a tuple with the values of all composite
            symbols modified within the conditional, giving access to state
            that the branches may mutate through side effects. It is not
            needed, and must not be called, when dispatching to code that
            follows Python's default semantics. It is useful for checkpointing
            to avoid unintended side effects when staging requires evaluating
            every code path.
        set_state: Function that sets the values of all composite symbols
            modified within the conditional; the complement of `get_state`,
            used to restore checkpointed values. Its single argument is a
            tuple holding a value for each composite symbol that a branch may
            modify. This is usually the result of a call to `get_state`.
        basic_symbol_names: Tuple containing the names of the basic
            (non-composite) symbols of the conditional. Only forwarded to the
            staged `tf_if_stmt` path.
        composite_symbol_names: Tuple containing the names of the composite
            symbols of the conditional. Only forwarded to the staged
            `tf_if_stmt` path.

    Returns:
        Tuple containing the statement outputs, as produced by whichever
        implementation was dispatched to.
    """
    # tf.cond cannot consume a SparseTensor, so only dense tensors are staged;
    # everything else falls back to plain Python execution.
    if not tensors.is_dense_tensor(cond):
        return _py_if_stmt(cond, body, orelse)
    return tf_if_stmt(cond, body, orelse, get_state, set_state,
                      basic_symbol_names, composite_symbol_names)
Beispiel #2
0
def while_stmt(
    test,
    body,
    get_state,
    set_state,
    init_vars,
    basic_symbol_names=None,
    composite_symbol_names=None,
    opts=None,
):
    """Functional form of a while statement.

    The loop operates on a so-called state: all symbols that vary across
    loop iterations. Below, "state" means either a tuple of entities
    representing an actual state, or a list of arguments of the corresponding
    types.

    `test` is evaluated once on `init_vars` to choose an implementation. If
    that result is a dense tensor, the loop is staged via `_tf_while_stmt`.
    Otherwise, the loop runs with Python semantics. If the first test is
    falsy, `init_vars` is returned as-is. If it is truthy, one iteration of
    `body` is unrolled here and the rest of the loop is delegated to
    `_py_while_stmt`.

    Args:
        test: Callable taking the state as arguments and returning a boolean;
            the loop condition.
        body: Callable taking the state as arguments and returning the new
            state; the actual loop body.
        get_state: Additional callable that can capture extra state (such as
            the values of composite symbols). Only useful when staging the
            loop.
        set_state: Additional callable that saves values captured by
            `get_state` back into the Python environment. Only useful when
            staging the loop.
        init_vars: Tuple containing the initial state.
        basic_symbol_names: Tuple containing basic loop var names. Only
            forwarded to the staged `_tf_while_stmt` path.
        composite_symbol_names: Tuple containing composite loop var names.
            Only forwarded to the staged `_tf_while_stmt` path.
        opts: Optional dict of extra loop parameters.

    Returns:
        Tuple containing the final state.
    """
    # Probe the loop condition once to decide how to dispatch. The probe runs
    # under a temporary FuncGraph to minimize unwanted side effects.
    # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
    with func_graph.FuncGraph('tmp').as_default():
        first_test = test(*init_vars)

    # Staged path: `_tf_while_stmt` evaluates `test` again, which is acceptable
    # for TensorFlow.
    if tensors.is_dense_tensor(first_test):
        return _tf_while_stmt(test, body, get_state, set_state, init_vars,
                              basic_symbol_names, composite_symbol_names, opts)

    # Python path: the probe already consumed one evaluation of `test`, so for
    # consistency the first iteration is unrolled here before handing off to
    # the plain Python loop.
    # TODO(mdan): Push the "first_test" value via opts into _py_while_stmt?
    if first_test:
        after_first = body(*init_vars)
        return _py_while_stmt(test, body, get_state, set_state, after_first,
                              opts)
    return init_vars
Beispiel #3
0
def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
    """Functional form of an if statement.

    The conditional operates on a state, which includes every symbol whose
    value depends on the branch taken.

    For example, given the code below that computes the absolute value of `x`:

    ```
      x = -1
      if x < 0:
        x = -x
    ```

    the state is represented by the variable `x`. The `body`, `orelse` and
    `set_state` functions must bind to the original `x` symbol, using
    `nonlocal`.

    The inputs and outputs of the callables representing the conditional's
    branches are not explicit; instead, these functions must use
    nonlocal/global for side effects. The inputs and outputs are controlled
    by the `get_state`/`set_state` functions.

    A dense tensor `cond` is dispatched to the staged `_tf_if_stmt`; any
    other value is executed with Python semantics through `_py_if_stmt`.

    Args:
        cond: Boolean condition; either a Python value or a dense tensor.
        body: Callable representing the main block of the conditional.
        orelse: Callable representing the else block of the conditional.
        get_state: Function returning a tuple with the values of all composite
            symbols modified within the conditional, giving access to state
            that the branches may mutate through side effects. It is not
            needed, and must not be called, when dispatching to code that
            follows Python's default semantics. It is useful for checkpointing
            to avoid unintended side effects when staging requires evaluating
            every code path.
        set_state: Function that sets the values of all composite symbols
            modified within the conditional; the complement of `get_state`,
            used to restore checkpointed values. Its single argument is a
            tuple holding a value for each composite symbol that a branch may
            modify. This is usually the result of a call to `get_state`.
        symbol_names: Tuple containing the names of the basic symbols involved
            in the conditional. Only forwarded to the staged `_tf_if_stmt`
            path.
        nouts: Number of variables output by the statement. Vars which are
            not outputs will not be passed through staged control flow such
            as tf.cond. This includes variables that are defined before the
            conditional, but are not used after it.

    Returns:
        None. The outcome is communicated solely through the side effects of
        the branch callables.
    """
    # tf.cond cannot consume a SparseTensor, so only dense tensors are staged;
    # everything else falls back to plain Python execution.
    if not tensors.is_dense_tensor(cond):
        _py_if_stmt(cond, body, orelse)
        return
    _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
Beispiel #4
0
def if_exp(cond, if_true, if_false, expr_repr):
    """Functional form of a conditional (ternary) expression.

    When `cond` is a dense tensor the expression is staged via `_tf_if_exp`;
    otherwise it is evaluated with Python semantics via `_py_if_exp`.

    Args:
        cond: The condition; either a Python value or a dense tensor.
        if_true: Producer of the value for the true branch (presumably a
            zero-argument callable; confirm against `_py_if_exp`).
        if_false: Producer of the value for the false branch (presumably a
            zero-argument callable; confirm against `_py_if_exp`).
        expr_repr: Representation of the expression. Only forwarded to the
            staged `_tf_if_exp` path (presumably a source string used for
            diagnostics; TODO confirm).

    Returns:
        Whatever the dispatched implementation returns.
    """
    staged = tensors.is_dense_tensor(cond)
    if staged:
        return _tf_if_exp(cond, if_true, if_false, expr_repr)
    return _py_if_exp(cond, if_true, if_false)