def for_stmt(target, iter_, body, orelse, local_writes):
    """Functional form of a for statement.

    If `iter_` is a tensor, the loop is staged with `_tf_while_stmt`, walking
    `iter_` by integer index. Otherwise the call is delegated unchanged to
    `py_defaults.for_stmt`.

    Args:
      target: object whose `.val` is set to the current element each step.
      iter_: the object being iterated.
      body: callable holding the loop body.
      orelse: the loop's `else` clause; only forwarded on the non-tensor path.
      local_writes: objects with a `.val` attribute written by the loop; on
        the staged path their `.val` is overwritten with the final loop state.
    """
    if not tf.is_tensor(iter_):
        py_defaults.for_stmt(target, iter_, body, orelse, local_writes)
        return

    # Undefined placeholders are dropped: only defined values are carried
    # through the staged loop state.
    loop_vars = [
        v for v in local_writes if not py_defaults.is_undefined(v.val)
    ]
    length = _tf_len(iter_)

    def for_test(i, *_):
        return i < length

    def for_body(index, *carried):
        """Loads the carried state, runs one iteration, returns next state."""
        for loop_var, value in zip(loop_vars, carried):
            loop_var.val = value
        target.val = iter_[index]
        new_vals, _ = staging.execute_isolated(body, loop_vars)
        # Slot 0 of the state is the iteration index.
        return [index + 1] + new_vals

    initial_state = [0] + [v.val for v in loop_vars]
    final_state = _tf_while_stmt(for_test, for_body, initial_state)
    # Skip the index in slot 0 when writing back.
    for loop_var, value in zip(loop_vars, final_state[1:]):
        loop_var.val = value
def while_stmt(cond, body, orelse, local_writes):
    """Functional form of a while statement.

    `cond` is called once up front. If that result is a tensor, the loop is
    staged with `_tf_while_stmt`. Otherwise the call is handed to
    `staging.run_python_while` together with the already-computed condition.

    Args:
      cond: callable producing the loop condition.
      body: callable holding the loop body.
      orelse: the loop's `else` clause; only forwarded on the non-tensor path.
      local_writes: objects with a `.val` attribute written by the loop; on
        the staged path their `.val` is overwritten with the final loop state.
    """
    first_test = cond()
    if not tf.is_tensor(first_test):
        staging.run_python_while(cond, body, orelse, first_test)
        return

    # Undefined placeholders are dropped: only defined values are carried
    # through the staged loop state.
    loop_vars = [
        v for v in local_writes if not py_defaults.is_undefined(v.val)
    ]

    def _load(state):
        """Copies a loop-state tuple back into the tracked variables."""
        for loop_var, value in zip(loop_vars, state):
            loop_var.val = value

    def while_test(*state):
        _load(state)
        _, test_value = staging.execute_isolated(cond, loop_vars)
        return test_value

    def while_body(*state):
        _load(state)
        new_vals, _ = staging.execute_isolated(body, loop_vars)
        return new_vals

    initial_state = [v.val for v in loop_vars]
    final_state = _tf_while_stmt(while_test, while_body, initial_state)
    for loop_var, value in zip(loop_vars, final_state):
        loop_var.val = value
def while_stmt(cond, body, _, local_writes):
    """Functional form of a while statement, staged with `lax.while_loop`.

    The third positional argument (the loop's `else` clause) is ignored.

    Args:
      cond: callable producing the loop condition.
      body: callable holding the loop body.
      _: unused `else` clause.
      local_writes: objects with a `.val` attribute written by the loop;
        their `.val` is overwritten with the final loop state.

    Returns:
      The final loop state returned by `lax.while_loop`.
    """
    # Undefined placeholders are dropped: only defined values are carried
    # through the loop state.
    loop_vars = [
        v for v in local_writes if not py_defaults.is_undefined(v.val)
    ]

    def _load(state):
        """Copies the loop carry back into the tracked variables."""
        for loop_var, value in zip(loop_vars, state):
            loop_var.val = value

    def while_test(state):
        _load(state)
        _, test_value = staging.execute_isolated(cond, loop_vars)
        return test_value

    def while_body(state):
        _load(state)
        new_vals, _ = staging.execute_isolated(body, loop_vars)
        return new_vals

    initial_state = [v.val for v in loop_vars]
    final_values = lax.while_loop(while_test, while_body, initial_state)
    for loop_var, value in zip(loop_vars, final_values):
        loop_var.val = value
    return final_values
def protected_func():
    """Calls function and raises an error if undefined symbols are returned.

    Depends on the free variables `func` and `branch_name`, which come from
    an enclosing scope not visible here.

    Returns:
      The value returned by `func()` when no undefined symbols are detected.

    Raises:
      ValueError: if `func()` returns an undefined placeholder, or a tuple
        for which `_filter_undefined` reports any undefined symbols.
    """
    results = func()
    undefined_symbols = None
    if isinstance(results, tuple):
        undefined_symbols = _filter_undefined(results)
    elif py_defaults.is_undefined(results):
        # Single return value
        undefined_symbols = results.symbol_name

    if undefined_symbols:
        # Fill both placeholders in one `format` call. The previous
        # `%`-then-`format` sequence re-parsed `branch_name` as part of the
        # format template, so any braces in it would break the message.
        message = (
            'The following symbols must also be initialized in the {} '
            'branch: {}. Alternatively, you may initialize them before '
            'the if statement.').format(branch_name, undefined_symbols)
        raise ValueError(message)
    return results
def for_stmt(target, iter_, body, orelse, modified_vars):
    """Functional form of a for statement, staged with `lax.fori_loop`.

    Runs `len(iter_)` iterations. Each step sets `target.val = iter_[index]`
    and then executes `body`.

    Args:
      target: object whose `.val` is set to the current element each step.
      iter_: indexable object supporting `len()`.
      body: callable holding the loop body.
      orelse: the loop's `else` clause; ignored by this implementation.
      modified_vars: objects with a `.val` attribute written by the loop;
        their `.val` is overwritten with the final loop state.
    """
    del orelse  # Not used by this staged form.

    # Undefined placeholders are dropped: only defined values are carried
    # through the loop state.
    loop_vars = [
        v for v in modified_vars if not py_defaults.is_undefined(v.val)
    ]

    def for_body(index, carried):
        """Loads the carry, runs one iteration, returns the new carry."""
        for loop_var, value in zip(loop_vars, carried):
            loop_var.val = value
        target.val = iter_[index]
        new_vals, _ = staging.execute_isolated(body, loop_vars)
        return new_vals

    final_values = lax.fori_loop(
        0, len(iter_), for_body, [v.val for v in loop_vars])
    for loop_var, value in zip(loop_vars, final_values):
        loop_var.val = value
def _filter_undefined(all_symbols):
    """Returns the names of undefined symbols contained in all_symbols.

    Args:
      all_symbols: iterable of values, any of which may be an undefined
        placeholder as decided by `py_defaults.is_undefined`.

    Returns:
      A list with the `symbol_name` of each undefined placeholder, in input
      order. The list is empty if none are undefined.
    """
    # Read `symbol_name`, the same attribute protected_func reads from a
    # single undefined return value. The previous `.name` was inconsistent
    # with that path.
    return [s.symbol_name for s in all_symbols if py_defaults.is_undefined(s)]