def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       symbol_names, opts):
  """Runs a for loop over a TF range as a while loop, without the range tensor.

  The (start, limit, delta) inputs of the range op are read directly. The loop
  index is carried through `_tf_while_stmt` as an extra leading loop variable
  named '<internal iterate>', starting at `start` and advancing by `delta`.

  NOTE(review): another `_tf_range_for_stmt` is defined immediately after this
  one; if both live in the same module, the later definition shadows this one.

  Args:
    iter_: tensor produced by a range op; its op's three inputs are unpacked
      as (start, limit, delta).
    extra_test: optional callable giving an additional loop condition, or None.
    body: callable invoked with the current index on every iteration.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables.
    opts: loop options dict; 'maximum_iterations' is written into it.
  """
  range_start, range_limit, range_delta = iter_.op.inputs
  index_ref = compat_util.BasicRef(range_start)

  def loop_cond():
    # The sign of delta decides which bound check applies.
    ascending = math_ops.logical_and(range_delta >= 0,
                                     index_ref.value < range_limit)
    descending = math_ops.logical_and(range_delta < 0,
                                      index_ref.value > range_limit)
    in_range = math_ops.logical_or(ascending, descending)
    if extra_test is None:
      return in_range
    return control_flow_ops.cond(in_range, extra_test, lambda: False)

  def loop_body():
    body(index_ref.value)
    index_ref.value += range_delta

  def get_loop_vars():
    return (index_ref.value,) + get_state()

  def set_loop_vars(all_vars):
    index_ref.value = all_vars[0]
    # The index is private to the loop and is not handed back to the caller.
    # If user code reads the iterate after the loop, it is tracked as one of
    # the regular loop variables instead.
    set_state(all_vars[1:])

  opts['maximum_iterations'] = math_ops.cast(
      misc.get_range_len(range_start, range_limit, range_delta), dtypes.int32)

  internal_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(loop_cond, loop_body, get_loop_vars, set_loop_vars,
                 internal_names, opts)
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Runs a for loop over a TF range as a while loop, without the range tensor.

  The (start, limit, delta) inputs of the range op are read directly. The loop
  index is carried through `_tf_while_stmt` as an extra leading loop variable
  named '<internal iterate>'.

  Args:
    iter_: tensor produced by a range op; its op's three inputs are unpacked
      as (start, limit, delta).
    extra_test: optional callable giving an additional loop condition, or None.
    body: callable invoked with the current index on every iteration.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables, in `get_state` order.
    opts: loop options dict. 'iterate_names' is read from it, and
      'maximum_iterations' is written into it outside XLA contexts.
  """
  range_start, range_limit, range_delta = iter_.op.inputs
  index_ref = compat_util.BasicRef(range_start)

  def get_loop_vars():
    user_vars = get_state()
    merged = []
    for name, value in zip(symbol_names, user_vars):
      # If the loop's own iterate variable is still Undefined, use the current
      # index in its place so it has a concrete value.
      if (name == opts['iterate_names'] and
          isinstance(value, variables.Undefined)):
        value = index_ref.value
      merged.append(value)
    return (index_ref.value,) + tuple(merged)

  def set_loop_vars(all_vars):
    index_ref.value = all_vars[0]
    # The index is private to the loop and is not handed back to the caller.
    # If user code reads the iterate after the loop, it is tracked as one of
    # the regular loop variables instead.
    set_state(all_vars[1:])

  def loop_body():
    body(index_ref.value)
    index_ref.value += range_delta

  def loop_cond():
    # TODO(b/159713842): Remove once constant folding works.
    static_delta = tensor_util.constant_value(range_delta)
    if static_delta is None:
      # The step's sign is unknown at trace time, so pick the comparison in-graph.
      keep_going = math_ops.logical_or(
          math_ops.logical_and(range_delta >= 0, index_ref.value < range_limit),
          math_ops.logical_and(range_delta < 0, index_ref.value > range_limit))
    elif static_delta >= 0:
      keep_going = index_ref.value < range_limit
    else:
      keep_going = index_ref.value > range_limit

    if extra_test is None:
      return keep_going
    return control_flow_ops.cond(keep_going, extra_test, lambda: False)

  # TODO(b/134181679): Remove.
  in_xla = control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
  if not in_xla:
    opts['maximum_iterations'] = math_ops.cast(
        misc.get_range_len(range_start, range_limit, range_delta),
        dtypes.int32)

  _tf_while_stmt(loop_cond, loop_body, get_loop_vars, set_loop_vars,
                 ('<internal iterate>',) + symbol_names, opts)
def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop.

  Each iteration pulls the next element with
  `iterator_ops.get_next_as_optional`. Whether an element was available is
  carried through `_tf_while_stmt` as an extra leading loop variable named
  '<internal has_next>'. That same flag is the loop condition.

  Args:
    iter_: the iterator to consume; passed to
      `iterator_ops.get_next_as_optional` on every iteration.
    extra_test: optional callable giving an additional loop condition, or
      None. It is used as the true branch of a tf.cond on the has_next flag.
    body: callable invoked with each element obtained from `iter_`.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables, in `get_state` order.
    opts: loop options dict, forwarded to `_verify_tf_loop_vars` and
      `_tf_while_stmt`.
  """
  # Prepend a name for the internal flag so names stay aligned with the tuples
  # produced by aug_get_state below.
  symbol_names = ('<internal has_next>',) + symbol_names
  # Mutable holder for the has_next flag, shared by the closures below.
  has_next = compat_util.BasicRef(True)

  def aug_get_state():
    # Loop vars as seen by _tf_while_stmt: the flag first, then the user vars.
    return (has_next.value,) + get_state()

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    has_next.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
    set_state(loop_vars)

  # Captured once, before the loop, so main_path can check every iteration's
  # vars against the initial ones.
  init_vars = aug_get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    opt_iterate = iterator_ops.get_next_as_optional(iter_)
    has_next.value = opt_iterate.has_value()
    # Snapshot of the loop vars before `body` runs. It already includes the
    # has_next value just computed. main_path verifies against it, and
    # noop_path returns it unchanged.
    loop_vars = aug_get_state()  # reflects the state last set via aug_set_state.

    def main_path():
      # An element is available: run the user body and collect the new vars.
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in
      # _tf_while_stmt, but it must run earlier. Otherwise tf.cond, which
      # compares these outputs with noop_path's, would fail first.
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      # Iterator exhausted: leave the loop vars unchanged.
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Write the tf.cond outputs back through aug_set_state. That way, state
    # later read through aug_get_state (e.g. by _tf_while_stmt) is the
    # conditional tensors, not the ones created inside main_path.
    aug_set_state(
        control_flow_ops.cond(has_next.value, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    # prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    # -> current_iteration_body -> set_state -> has_next.value
    main_test = has_next.value
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)
def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Runs a for loop over a TF value whose length can be computed.

  The value is unstacked into a TensorArray, and its elements are read one at
  a time. The read position is carried through `_tf_while_stmt` as an extra
  leading loop variable named '<internal iterate>'.

  Args:
    iter_: the TF value to iterate over; it must support `py_builtins.len_`
      and have a `dtype`.
    extra_test: optional callable giving an additional loop condition, or None.
    body: callable invoked with each element.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables.
    opts: loop options dict; 'maximum_iterations' is written into it outside
      XLA contexts.
  """
  num_elements = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # A TensorArray costs an extra copy of the data, but its gradients can be
  # computed more efficiently than those of StridedSlice.
  elements = tensor_array_ops.TensorArray(
      iter_.dtype, size=num_elements).unstack(iter_)

  position = compat_util.BasicRef(0)

  def loop_cond():
    has_more = position.value < num_elements
    if extra_test is None:
      return has_more
    return control_flow_ops.cond(has_more, extra_test, lambda: False)

  def loop_body():
    body(elements.read(position.value))
    position.value += 1

  def get_loop_vars():
    return (position.value,) + get_state()

  def set_loop_vars(all_vars):
    position.value = all_vars[0]
    # The position is private to the loop and is not handed back to the
    # caller. A user-visible iterate is tracked among the regular loop vars.
    set_state(all_vars[1:])

  # TODO(b/159186914): Remove.
  if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = num_elements

  _tf_while_stmt(loop_cond, loop_body, get_loop_vars, set_loop_vars,
                 ('<internal iterate>',) + symbol_names, opts)
def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over the rows of a TF ragged tensor.

  Row `i` is obtained as `iter_[i]`. The row index is carried through
  `_tf_while_stmt` as an extra leading loop variable named
  '<internal iterate>'.

  Args:
    iter_: the ragged tensor to iterate over (along its outermost dimension).
    extra_test: optional callable giving an additional loop condition, or None.
    body: callable invoked with each row.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables.
    opts: loop options dict; 'maximum_iterations' is written into it outside
      XLA contexts.
  """
  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    # The number of rows is known statically.
    n = iter_.shape[0]
  else:
    # The body reads one row per iteration, so the loop must run nrows() times.
    # Note that `row_lengths()[0]` would be the length of the *first row*, not
    # the row count, and would fail on a tensor with no rows.
    # nrows() uses the row_splits dtype, which may be int64. Cast to int32 so
    # the comparison with the Python-int index below (presumably converted to
    # an int32 loop tensor) is well-typed.
    n = math_ops.cast(iter_.nrows(), dtypes.int32)

  iterate_index = compat_util.BasicRef(0)

  def aug_get_state():
    return (iterate_index.value,) + get_state()

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    iterate_index.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    body(iter_[iterate_index.value])
    iterate_index.value += 1

  def aug_test():
    main_test = iterate_index.value < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  # TODO(b/159186914): Remove.
  if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = n

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)