def _known_len_for_stmt(iter_, extra_test, body, init_state):
  """Overload of for_stmt for iterated objects that define a length.

  The loop is lowered to `while_stmt` by prepending an integer index to the
  loop state. Each step reads `iter_[index]`, passes it to `body` along with
  the current loop variables, and advances the index by one. The loop keeps
  running while the index is below `builtins.dynamic_len(iter_)` and
  `extra_test` holds for the current loop variables.

  Args:
    iter_: the object being iterated; it is indexed with the loop index.
    extra_test: callable receiving the loop variables and returning an extra
      continuation condition, combined with the bounds check via
      `gen_math_ops.logical_and`.
    body: callable invoked as `body(element, *loop_vars)`; must return the
      new loop variables as a tuple (it is concatenated onto a tuple).
    init_state: tuple of initial loop-variable values.

  Returns:
    The final loop variables with the internal index removed. If exactly one
    loop variable remains it is returned bare instead of as a 1-tuple.
  """
  # NOTE(review): dynamic_len presumably yields a Python int or a scalar
  # tensor depending on iter_; it is used both as bound and iteration cap.
  length = builtins.dynamic_len(iter_)

  def while_body(index, *loop_vars):
    element = iter_[index]
    updated = body(element, *loop_vars)
    return (index + 1,) + updated

  def while_cond(index, *loop_vars):
    in_bounds = index < length
    return gen_math_ops.logical_and(in_bounds, extra_test(*loop_vars))

  initial_state = (0,) + init_state
  loop_opts = {'maximum_iterations': length}
  final_state = while_stmt(
      while_cond,
      while_body,
      init_state=initial_state,
      extra_deps=(iter_,),
      opts=loop_opts)

  # The index is bookkeeping only; it has no counterpart in the user's code.
  user_vars = final_state[1:]
  # TODO(mdan): Remove this special case.
  return user_vars[0] if len(user_vars) == 1 else user_vars
def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
  """Overload of for_loop for iterated objects that define a length.

  The loop is lowered to `while_loop` by prepending an integer index to the
  loop state. Each step reads `iterated[index]`, passes it to `loop_body`
  along with the current loop variables, and advances the index by one. The
  loop keeps running while the index is below
  `builtins.dynamic_len(iterated)` and `extra_cond` holds for the current
  loop variables.

  Args:
    iterated: the object being iterated; it is indexed with the loop index.
    extra_cond: callable receiving the loop variables and returning an extra
      continuation condition, combined with the bounds check via
      `gen_math_ops.logical_and`.
    loop_body: callable invoked as `loop_body(element, *loop_vars)`; must
      return the new loop variables as a tuple (it is concatenated onto a
      tuple).
    init_state: tuple of initial loop-variable values.

  Returns:
    The final loop variables with the internal index removed. If exactly one
    loop variable remains it is returned bare instead of as a 1-tuple.
  """
  # NOTE(review): dynamic_len presumably yields a Python int or a scalar
  # tensor depending on `iterated`; it is used both as bound and iteration cap.
  length = builtins.dynamic_len(iterated)

  def while_body(index, *loop_vars):
    element = iterated[index]
    updated = loop_body(element, *loop_vars)
    return (index + 1,) + updated

  def while_cond(index, *loop_vars):
    in_bounds = index < length
    return gen_math_ops.logical_and(in_bounds, extra_cond(*loop_vars))

  initial_state = (0,) + init_state
  loop_opts = {'maximum_iterations': length}
  final_state = while_loop(
      while_cond,
      while_body,
      init_state=initial_state,
      extra_deps=(iterated,),
      opts=loop_opts)

  # The index is bookkeeping only; it has no counterpart in the user's code.
  user_vars = final_state[1:]
  # TODO(mdan): Remove this special case.
  return user_vars[0] if len(user_vars) == 1 else user_vars