예제 #1
0
def _known_len_for_stmt(iter_, extra_test, body, init_state):
  """Overload of for_stmt for iterables whose length can be computed.

  The loop is lowered to a `while_stmt` driven by a hidden integer index. The
  index is prepended to the loop state and stripped again before returning.

  Args:
    iter_: Indexable object being iterated; its length is obtained through
      `builtins.dynamic_len`.
    extra_test: Callable receiving the loop state (without the index). Its
      result is and-ed with the bounds check to decide whether to continue.
    body: Callable receiving the current element followed by the loop state
      and returning the updated loop state. The result is appended to a
      tuple, so it is expected to be a tuple.
    init_state: Initial loop state. It is appended to a tuple, so it is
      expected to be a tuple.

  Returns:
    The final loop state without the hidden index. When exactly one state
    value remains, that value is returned directly rather than wrapped.
  """
  length = builtins.dynamic_len(iter_)

  def step(index, *loop_vars):
    # Run the user body on the current element, then advance the index.
    updated = body(iter_[index], *loop_vars)
    return (index + 1,) + updated

  def keep_going(index, *loop_vars):
    # Continue only while in bounds and the extra test still holds.
    return gen_math_ops.logical_and(index < length, extra_test(*loop_vars))

  final_state = while_stmt(
      keep_going,
      step,
      init_state=(0,) + init_state,
      extra_deps=(iter_,),
      opts={'maximum_iterations': length})

  # The index is an implementation detail that is not syntactically visible
  # to the caller, so it is removed from the results.
  visible_state = final_state[1:]

  # TODO(mdan): Remove this special case.
  return visible_state[0] if len(visible_state) == 1 else visible_state
예제 #2
0
def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
    """Overload of for_loop for iterables whose length can be computed.

    The iteration is expressed as a `while_loop` over a hidden counter that
    is placed in front of the loop state and dropped before returning.

    Args:
        iterated: Indexable object being iterated; its length is obtained
            through `builtins.dynamic_len`.
        extra_cond: Callable receiving the loop state (without the counter).
            Its result is and-ed with the bounds check on every iteration.
        loop_body: Callable receiving the current element followed by the
            loop state and returning the updated loop state. The result is
            appended to a tuple, so it is expected to be a tuple.
        init_state: Initial loop state. It is appended to a tuple, so it is
            expected to be a tuple.

    Returns:
        The final loop state without the hidden counter. When exactly one
        state value remains, that value is returned directly rather than
        wrapped.
    """
    total = builtins.dynamic_len(iterated)

    def advance(position, *carried):
        # Feed the current element to the user body and bump the counter.
        element = iterated[position]
        next_carried = loop_body(element, *carried)
        return (position + 1, ) + next_carried

    def should_continue(position, *carried):
        # Both the bounds check and the user-supplied condition must hold.
        return gen_math_ops.logical_and(position < total,
                                        extra_cond(*carried))

    outcome = while_loop(should_continue,
                         advance,
                         init_state=(0, ) + init_state,
                         extra_deps=(iterated, ),
                         opts={'maximum_iterations': total})

    # The counter is not syntactically visible to the caller; strip it.
    outcome = outcome[1:]

    # TODO (mdan): Remove this special case. id:976
    # https://github.com/imdone/tensorflow/issues/977
    return outcome[0] if len(outcome) == 1 else outcome