def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The range tensor itself is never iterated. Its bounds are read from the
  inputs of the op that produced it, and an internal counter is carried
  through the while loop as an extra leading loop variable.

  Args:
    iter_: The range tensor being iterated. Its op must have exactly three
      inputs, which are unpacked as (start, limit, delta).
    extra_test: Optional zero-argument callable. When given, it is combined
      with the range check through `control_flow_ops.cond`.
    body: Callable invoked with the current counter value on each iteration.
    get_state: Callable returning the tuple of user loop variables.
    set_state: Callable receiving the user loop variables (counter excluded).
    symbol_names: Tuple of names of the user loop variables.
    opts: Loop options dict. It is mutated in place: 'maximum_iterations' is
      set to the range length cast to int32.
  """
  start, limit, delta = iter_.op.inputs
  counter = compat_util.BasicRef(start)

  def aug_get_state():
    # Slot 0 of the augmented state always holds the counter.
    current = counter.value
    return (current,) + get_state()

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    counter.value = aug_loop_vars[0]
    # The counter is not a loop output. If the user's iterate symbol is read
    # after the loop, it is tracked separately among the user loop vars.
    set_state(aug_loop_vars[1:])

  def aug_body():
    body(counter.value)
    counter.value += delta

  def aug_test():
    # Both step directions are checked because delta's sign is not known here.
    counting_up = math_ops.logical_and(delta >= 0, counter.value < limit)
    counting_down = math_ops.logical_and(delta < 0, counter.value > limit)
    in_range = math_ops.logical_or(counting_up, counting_down)
    if extra_test is None:
      return in_range
    return control_flow_ops.cond(in_range, extra_test, lambda: False)

  range_len = misc.get_range_len(start, limit, delta)
  opts['maximum_iterations'] = math_ops.cast(range_len, dtypes.int32)

  _tf_while_stmt(aug_test, aug_body, aug_get_state, aug_set_state,
                 ('<internal iterate>',) + symbol_names, opts)
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The range's (start, limit, delta) are taken from the inputs of the op that
  produced `iter_`. A closure-held index is threaded through the while loop
  as an extra leading loop variable named '<internal iterate>'.

  Args:
    iter_: The range tensor. Its op must have exactly three inputs.
    extra_test: Optional zero-argument callable, combined with the range
      check through `control_flow_ops.cond`.
    body: Callable invoked with the current index on each iteration.
    get_state: Callable returning the user loop variables, in the same order
      as `symbol_names`.
    set_state: Callable receiving the user loop variables (index excluded).
    symbol_names: Tuple of names of the user loop variables.
    opts: Loop options dict. `opts['iterate_names']` is read when seeding
      state, and `opts` is passed to `_add_max_iterations_hint`.
  """
  start, limit, delta = iter_.op.inputs

  index = start

  def aug_get_state():
    user_vars = get_state()
    seeded = []
    for name, var in zip(symbol_names, user_vars):
      # If the user's iterate symbol is still Undefined, substitute the
      # current index so the loop state starts out with a value for it.
      if name == opts['iterate_names'] and isinstance(var, variables.Undefined):
        var = index
      seeded.append(var)
    return (index,) + tuple(seeded)

  def aug_set_state(aug_loop_vars):
    nonlocal index
    # TODO(b/171479293): Drop the lint override.
    index, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The index is not a loop output. If the user's iterate symbol is read
    # after the loop, it is tracked separately among the user loop vars.
    set_state(user_vars)

  def aug_body():
    nonlocal index
    body(index)
    index += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    const_delta = tensor_util.constant_value(delta)
    if const_delta is None:
      main_test = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, index < limit),
          math_ops.logical_and(delta < 0, index > limit))
    elif const_delta >= 0:
      main_test = index < limit
    else:
      main_test = index > limit

    if extra_test is not None:
      main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  range_len = math_ops.cast(
      misc.get_range_len(start, limit, delta), dtypes.int32)
  _add_max_iterations_hint(opts, range_len)

  _tf_while_stmt(
      aug_test, aug_body, aug_get_state, aug_set_state,
      ('<internal iterate>',) + symbol_names, opts)
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The bounds come from the inputs of the op that produced `iter_`. A
  `BasicRef`-held counter is carried as an extra leading loop variable
  named '<internal iterate>'.

  Args:
    iter_: The range tensor. Its op must have exactly three inputs,
      unpacked as (start, limit, delta).
    extra_test: Optional zero-argument callable, combined with the range
      check through `control_flow_ops.cond`.
    body: Callable invoked with the current counter value on each iteration.
    get_state: Callable returning the user loop variables, in the same order
      as `symbol_names`.
    set_state: Callable receiving the user loop variables (counter excluded).
    symbol_names: Tuple of names of the user loop variables.
    opts: Loop options dict. `opts['iterate_names']` is read. Outside an XLA
      context, `opts['maximum_iterations']` is set in place.
  """
  start, limit, delta = iter_.op.inputs
  counter = compat_util.BasicRef(start)

  def aug_get_state():
    seeded = []
    for name, var in zip(symbol_names, get_state()):
      # An Undefined user iterate symbol is replaced with the current counter
      # value so the loop state starts with a value for it.
      if name == opts['iterate_names'] and isinstance(var, variables.Undefined):
        var = counter.value
      seeded.append(var)
    return (counter.value,) + tuple(seeded)

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    counter.value = aug_loop_vars[0]
    # The counter is not a loop output. If the user's iterate symbol is read
    # after the loop, it is tracked separately among the user loop vars.
    set_state(aug_loop_vars[1:])

  def aug_body():
    body(counter.value)
    counter.value += delta

  def aug_test():
    current = counter.value
    # TODO(b/159713842): Remove once constant folding works.
    const_delta = tensor_util.constant_value(delta)
    if const_delta is None:
      main_test = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, current < limit),
          math_ops.logical_and(delta < 0, current > limit))
    else:
      main_test = current < limit if const_delta >= 0 else current > limit

    if extra_test is not None:
      main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  # TODO(b/134181679): Remove.
  in_xla = control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
  if not in_xla:
    opts['maximum_iterations'] = math_ops.cast(
        misc.get_range_len(start, limit, delta), dtypes.int32)

  _tf_while_stmt(
      aug_test, aug_body, aug_get_state, aug_set_state,
      ('<internal iterate>',) + symbol_names, opts)
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       init_vars, basic_symbol_names, composite_symbol_names,
                       opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The range's (start, limit, delta) are read from the inputs of the op that
  produced `iter_`. The loop runs via `_tf_while_stmt` with the counter
  prepended to the loop variables.

  Args:
    iter_: The range tensor. Its op must have exactly three inputs.
    extra_test: Optional callable taking the user loop vars. It is evaluated
      through `control_flow_ops.cond` only when the range check passes.
    body: Callable `(iterate, *loop_vars)` returning the updated user vars,
      or a falsy value when there are none.
    get_state: Forwarded to `_tf_while_stmt`.
    set_state: Forwarded to `_tf_while_stmt`.
    init_vars: Tuple of initial user loop variables. These are checked by
      `_disallow_undefs_into_loop`.
    basic_symbol_names: Names of the user loop vars. '<internal iterate>' is
      prepended for the counter.
    composite_symbol_names: Forwarded to `_tf_while_stmt`.
    opts: Loop options dict. 'maximum_iterations' is set in place.

  Returns:
    If `_tf_while_stmt` returns a tuple or list, its elements after the
    leading counter, or `()` when only the counter was returned. Otherwise
    the result is returned unchanged.
  """
  _disallow_undefs_into_loop(*init_vars)

  start, limit, delta = iter_.op.inputs

  def while_body(iterate, *loop_vars):
    updated_vars = body(iterate, *loop_vars)
    next_vars = (iterate + delta,)
    if updated_vars:
      next_vars += updated_vars
    return next_vars

  def while_cond(iterate, *loop_vars):
    """Cond function for `tf.while_loop`."""
    stepping_up = math_ops.logical_and(delta >= 0, iterate < limit)
    stepping_down = math_ops.logical_and(delta < 0, iterate > limit)
    main_test = math_ops.logical_or(stepping_up, stepping_down)
    if extra_test is None:
      return main_test
    return control_flow_ops.cond(
        main_test,
        lambda: extra_test(*loop_vars),
        lambda: False,
    )

  opts['maximum_iterations'] = math_ops.cast(
      misc.get_range_len(start, limit, delta), dtypes.int32)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (start,) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # The counter is not returned to the caller. If a symbol with the same name
  # exists outside the loop, it is captured among the loop variables and is
  # updated through them.
  if not isinstance(results, (tuple, list)):
    return results
  assert len(results) >= 1  # Has at least the iterate.
  return results[1:] if len(results) > 1 else ()
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       init_vars, basic_symbol_names, composite_symbol_names,
                       opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The range's (start, limit, delta) are read from the inputs of the op that
  produced `iter_`. The loop runs via `_tf_while_stmt` with the counter
  prepended to the loop variables.

  Args:
    iter_: The range tensor. Its op must have exactly three inputs.
    extra_test: Optional callable taking the user loop vars. It is evaluated
      through `control_flow_ops.cond` only when the range check passes.
    body: Callable `(iterate, *loop_vars)` returning the updated user vars,
      or a falsy value when there are none.
    get_state: Forwarded to `_tf_while_stmt`.
    set_state: Forwarded to `_tf_while_stmt`.
    init_vars: Tuple of initial user loop variables. These are checked by
      `_disallow_undefs_into_loop`.
    basic_symbol_names: Names of the user loop vars. '<internal iterate>' is
      prepended for the counter.
    composite_symbol_names: Forwarded to `_tf_while_stmt`.
    opts: Loop options dict. Inside an XLA context, 'maximum_iterations' is
      set in place.

  Returns:
    If `_tf_while_stmt` returns a tuple or list, its elements after the
    leading counter, or `()` when only the counter was returned. Otherwise
    the result is returned unchanged.
  """
  _disallow_undefs_into_loop(*init_vars)

  start, limit, delta = iter_.op.inputs

  def while_body(iterate, *loop_vars):
    new_vars = body(iterate, *loop_vars)
    loop_vars = (iterate + delta,)
    if new_vars:
      loop_vars += new_vars
    return loop_vars

  def while_cond(iterate, *loop_vars):
    """Cond function for `tf.while_loop`."""

    def build_main_test():
      """Main iteration condition."""
      # TODO(b/138857806): The optimizer should handle this.
      # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
      # compile time constant.
      delta_const = tensor_util.constant_value(delta)
      if delta_const is not None:
        # Support single element arrays. `ndarray.item()` is the replacement
        # for `np.asscalar`, which was deprecated in NumPy 1.16 and removed
        # in NumPy 1.23. It performs the same size-1 -> Python scalar
        # conversion.
        delta_const = delta_const.item()
        if delta_const >= 0:
          return iterate < limit
        return iterate > limit
      return math_ops.logical_or(
          math_ops.logical_and(delta >= 0, iterate < limit),
          math_ops.logical_and(delta < 0, iterate > limit))

    main_test = build_main_test()
    if extra_test is not None:
      return control_flow_ops.cond(
          main_test,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return main_test

  # TODO(b/134181679): The op should handle this optimizations.
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # This specific dtype is required by while_loop.
    opts['maximum_iterations'] = math_ops.cast(
        misc.get_range_len(start, limit, delta), dtypes.int32)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (start,) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
    else:
      results = ()

  return results