def if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
            composite_symbol_names):
  """Functional form of an if statement.

  Dispatches to the staged (TensorFlow) implementation when `cond` is a dense
  tensor, and to the plain Python implementation otherwise.

  Args:
    cond: Boolean.
    body: Callable with no arguments, returning the outputs of the positive
      (if) branch.
    orelse: Callable with no arguments, returning the outputs of the negative
      (else) branch.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. It is not needed,
      and should not be called, when dispatching to code matching Python's
      default semantics. It is useful for checkpointing, to avoid unintended
      side effects when staging requires evaluating all code paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement of get_state, used to
      restore checkpointed values. Its single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.

  Returns:
    Tuple containing the statement outputs.
  """
  # Note: tf.cond doesn't support SparseTensor, so only dense tensors are
  # routed to the staged implementation.
  if not tensors.is_dense_tensor(cond):
    return _py_if_stmt(cond, body, orelse)
  return tf_if_stmt(cond, body, orelse, get_state, set_state,
                    basic_symbol_names, composite_symbol_names)
def while_stmt(
    test,
    body,
    get_state,
    set_state,
    init_vars,
    basic_symbol_names=None,
    composite_symbol_names=None,
    opts=None,
):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that vary
  across loop iterations. Below, "state" refers either to a tuple of entities
  representing an actual state, or to a list of arguments of the corresponding
  types.

  Args:
    test: Callable with the state as arguments, and boolean return type. The
      loop condition.
    body: Callable with the state as arguments, and state as return type. The
      actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). Only useful when staging the loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. Only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  # The dispatch decision needs one evaluation of the loop condition. It runs
  # inside a throwaway graph so that unwanted side effects stay isolated.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    first_test = test(*init_vars)

  # TensorFlow: evaluating `test` more than once is acceptable here, so the
  # re-evaluation that `_tf_while_stmt` performs is fine.
  if tensors.is_dense_tensor(first_test):
    return _tf_while_stmt(test, body, get_state, set_state, init_vars,
                          basic_symbol_names, composite_symbol_names, opts)

  # Normal Python: one evaluation of `test` has already been consumed, so for
  # consistency one iteration is unrolled before entering the regular loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if first_test:
    unrolled_vars = body(*init_vars)
    return _py_while_stmt(test, body, get_state, set_state, unrolled_vars,
                          opts)
  return init_vars
def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Functional form of an if statement.

  The conditional operates on a state, which includes all symbols whose values
  are a function of the branch taken.

  For example, given the code below:

  ```
    x = 1
    if x > 0:
      x = -x
  ```

  The state is represented by the variable `x`. The `body`, `orelse` and
  `set_state` functions must bind to the original `x` symbol, using `nonlocal`.

  The inputs and outputs of the callables representing the branch blocks are
  not explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are controlled by the set_state/get_state
  functions. Accordingly, this function does not return a value.

  Args:
    cond: Boolean.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. It is not needed,
      and should not be called, when dispatching to code matching Python's
      default semantics. It is useful for checkpointing, to avoid unintended
      side effects when staging requires evaluating all code paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement of get_state, used to
      restore checkpointed values. Its single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    symbol_names: Tuple containing basic loop var names.
    nouts: Number of variables output by the statement. Vars which are not
      outputs will not be passed through staged control flow such as tf.cond.
      This includes variables that are defined before the conditional, but are
      not used after it.
  """
  # Note: tf.cond doesn't support SparseTensor, so anything that is not a
  # dense tensor takes the plain Python path.
  if not tensors.is_dense_tensor(cond):
    _py_if_stmt(cond, body, orelse)
  else:
    _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
def if_exp(cond, if_true, if_false, expr_repr):
  """Functional form of a conditional expression, dispatched on `cond`.

  When `tensors.is_dense_tensor(cond)` holds, the call is forwarded to the
  TensorFlow implementation; otherwise it is forwarded to the Python one.

  Args:
    cond: The condition value used both for dispatch and for the conditional.
    if_true: Forwarded unchanged to the selected implementation (presumably a
      callable producing the value of the true branch - confirm against
      `_py_if_exp` / `_tf_if_exp`).
    if_false: Forwarded unchanged to the selected implementation (presumably
      the counterpart for the false branch - confirm as above).
    expr_repr: Forwarded only to the TensorFlow implementation; the Python
      implementation does not receive it. Presumably a textual representation
      of the expression - confirm against `_tf_if_exp`.

  Returns:
    Whatever the selected implementation returns.
  """
  staged = tensors.is_dense_tensor(cond)
  return (_tf_if_exp(cond, if_true, if_false, expr_repr) if staged
          else _py_if_exp(cond, if_true, if_false))