コード例 #1
0
ファイル: control_flow.py プロジェクト: y8tang/tensorflow
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       symbol_names, opts):
    """Lowers a `for` over a TF range op into `_tf_while_stmt`.

    The range op itself is elided: its `(start, limit, delta)` inputs are read
    directly, and a hidden counter named `<internal iterate>` is carried as the
    first loop variable, ahead of the user's loop variables.

    Args:
      iter_: tensor produced by a range op; `iter_.op.inputs` is unpacked as
        `(start, limit, delta)`.
      extra_test: optional callable, or None. When given, it is only evaluated
        (via `control_flow_ops.cond`) while the counter is still in range.
      body: callable invoked with the current counter value.
      get_state: callable returning the tuple of user loop variables.
      set_state: callable receiving the updated user loop variables.
      symbol_names: names of the user loop variables.
      opts: loop options dict; its `maximum_iterations` entry is overwritten.
    """
    range_start, range_limit, step = iter_.op.inputs

    counter = compat_util.BasicRef(range_start)

    def aug_get_state():
        # Slot 0 of the loop state is always the hidden counter.
        return (counter.value,) + get_state()

    def aug_set_state(aug_loop_vars):
        # Slot 0 is the counter; it is not an output of the loop. If the
        # iterate is read after the loop, it shows up separately among the
        # user loop variables.
        counter.value = aug_loop_vars[0]
        set_state(aug_loop_vars[1:])

    def aug_body():
        body(counter.value)
        counter.value += step

    def aug_test():
        # Works for either sign of `step`: ascending ranges stop once the
        # counter reaches `limit` from below, descending ones from above.
        ascending = math_ops.logical_and(step >= 0, counter.value < range_limit)
        descending = math_ops.logical_and(step < 0, counter.value > range_limit)
        in_range = math_ops.logical_or(ascending, descending)
        if extra_test is None:
            return in_range
        # Only consult extra_test while the range still has elements.
        return control_flow_ops.cond(in_range, extra_test, lambda: False)

    range_len = misc.get_range_len(range_start, range_limit, step)
    opts['maximum_iterations'] = math_ops.cast(range_len, dtypes.int32)

    loop_names = ('<internal iterate>',) + symbol_names
    _tf_while_stmt(aug_test, aug_body, aug_get_state, aug_set_state,
                   loop_names, opts)
コード例 #2
0
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Lowers a `for` over a TF range op into `_tf_while_stmt`.

  The range op itself is elided: its `(start, limit, delta)` inputs are read
  directly, and a hidden counter named `<internal iterate>` is carried as the
  first loop variable, ahead of the user's loop variables.

  Args:
    iter_: tensor produced by a range op; `iter_.op.inputs` is unpacked as
      `(start, limit, delta)`.
    extra_test: optional callable, or None. When given, it is only evaluated
      (via `control_flow_ops.cond`) while the counter is still in range.
    body: callable invoked with the current counter value.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables, parallel to get_state().
    opts: loop options dict. `opts['iterate_names']` is read; outside an XLA
      context `opts['maximum_iterations']` is overwritten.
  """
  range_start, range_limit, step = iter_.op.inputs

  counter = compat_util.BasicRef(range_start)

  def aug_get_state():
    seeded_vars = []
    for name, var in zip(symbol_names, get_state()):
      # The loop target may still be undefined on entry; seed it with the
      # counter so the while loop receives a concrete initial value for it.
      if (name == opts['iterate_names'] and
          isinstance(var, variables.Undefined)):
        var = counter.value
      seeded_vars.append(var)
    return (counter.value,) + tuple(seeded_vars)

  def aug_set_state(aug_loop_vars):
    # Slot 0 is the counter; it is not an output of the loop. If the iterate
    # is read after the loop, it shows up separately among the user vars.
    counter.value = aug_loop_vars[0]
    set_state(aug_loop_vars[1:])

  def aug_body():
    body(counter.value)
    counter.value += step

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    static_step = tensor_util.constant_value(step)
    if static_step is None:
      # Direction unknown at trace time: accept either sign of step.
      in_range = math_ops.logical_or(
          math_ops.logical_and(step >= 0, counter.value < range_limit),
          math_ops.logical_and(step < 0, counter.value > range_limit))
    elif static_step >= 0:
      in_range = counter.value < range_limit
    else:
      in_range = counter.value > range_limit

    if extra_test is not None:
      # Only consult extra_test while the range still has elements.
      in_range = control_flow_ops.cond(in_range, extra_test, lambda: False)
    return in_range

  # TODO(b/134181679): Remove.
  in_xla = control_flow_util.GraphOrParentsInXlaContext(
      ops.get_default_graph())
  if not in_xla:
    range_len = misc.get_range_len(range_start, range_limit, step)
    opts['maximum_iterations'] = math_ops.cast(range_len, dtypes.int32)

  loop_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      loop_names,
      opts)
コード例 #3
0
def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop.

  The loop is lowered to `_tf_while_stmt`. A hidden flag named
  `<internal has_next>` is prepended to the user's loop variables. Each
  iteration calls `iterator_ops.get_next_as_optional` on the iterator and runs
  `body` (inside a `control_flow_ops.cond`) only when that optional reports
  `has_value()`.

  Args:
    iter_: the iterator; passed to `iterator_ops.get_next_as_optional`.
    extra_test: optional callable, or None. When given, it is only evaluated
      (via `control_flow_ops.cond`) while `has_next` is true.
    body: callable invoked with each element obtained from the iterator.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables.
    opts: loop options, forwarded to `_verify_tf_loop_vars` and
      `_tf_while_stmt`.
  """
  # Prepend the hidden flag's name so the names line up with aug_get_state().
  symbol_names = ('<internal has_next>',) + symbol_names
  # Starts as the Python value True; aug_body overwrites it with the
  # optional's has_value() result on every iteration.
  has_next = compat_util.BasicRef(True)

  def aug_get_state():
    # Loop state layout: (has_next, *user_loop_vars).
    return (has_next.value,) + get_state()

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    has_next.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
    set_state(loop_vars)

  # Snapshot of the initial loop state: checked once here, and used again as
  # the reference when verifying each iteration's loop vars in main_path.
  init_vars = aug_get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt.

    Fetches the next element as an optional, then runs `body` only if it has
    a value; otherwise the loop state passes through unchanged apart from the
    refreshed has_next flag.
    """
    opt_iterate = iterator_ops.get_next_as_optional(iter_)
    has_next.value = opt_iterate.has_value()
    # Captured after has_next was refreshed, so noop_path also propagates the
    # new has_next value.
    loop_vars = aug_get_state()  # updated by set_state() in _tf_while_loop.

    def main_path():
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in _tf_while_stmt,
      # but needs to be done earlier to prevent the tf.cond from blowing up
      # first.
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      # No element available: return the state captured above.
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that the get_state() call in _tf_while_loop sees the
    # conditional tensors.
    aug_set_state(
        control_flow_ops.cond(has_next.value, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    #   prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    #   -> current_iteration_body -> set_state -> has_next.value
    main_test = has_next.value
    if extra_test is not None:
      # Only consult extra_test while the iterator still produced an element.
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)
コード例 #4
0
ファイル: control_flow.py プロジェクト: idodan1/thesis
def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                           symbol_names, opts):
    """Lowers a `for` over a TF entity that admits a length into a while loop.

    The iterable is unstacked into a TensorArray and read one element per
    iteration using a hidden index, named `<internal iterate>`, that is carried
    as the first loop variable ahead of the user's loop variables.

    Args:
      iter_: the iterable; its length comes from `py_builtins.len_` and its
        `dtype` is used for the TensorArray.
      extra_test: optional callable, or None. When given, it is only evaluated
        (via `control_flow_ops.cond`) while elements remain.
      body: callable invoked with each element read from the TensorArray.
      get_state: callable returning the tuple of user loop variables.
      set_state: callable receiving the updated user loop variables.
      symbol_names: names of the user loop variables.
      opts: loop options dict; outside an XLA context its
        `maximum_iterations` entry is overwritten.
    """
    num_elements = py_builtins.len_(iter_)

    # TODO(b/117628877): Revisit performance once XLA has the necessary support.
    # Note: using a TensorArray creates an extra copy, but can calculate
    # gradients more efficiently than StridedSlice.
    element_array = tensor_array_ops.TensorArray(
        iter_.dtype, size=num_elements).unstack(iter_)

    index = compat_util.BasicRef(0)

    def aug_get_state():
        # Slot 0 of the loop state is always the hidden index.
        return (index.value,) + get_state()

    def aug_set_state(aug_loop_vars):
        # Slot 0 is the index; it is not an output of the loop. If the iterate
        # is read after the loop, it shows up separately among the user vars.
        index.value = aug_loop_vars[0]
        set_state(aug_loop_vars[1:])

    def aug_body():
        body(element_array.read(index.value))
        index.value += 1

    def aug_test():
        has_more = index.value < num_elements
        if extra_test is None:
            return has_more
        # Only consult extra_test while elements remain.
        return control_flow_ops.cond(has_more, extra_test, lambda: False)

    # TODO(b/159186914): Remove.
    in_xla = control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph())
    if not in_xla:
        opts['maximum_iterations'] = num_elements

    loop_names = ('<internal iterate>',) + symbol_names
    _tf_while_stmt(
        aug_test,
        aug_body,
        aug_get_state,
        aug_set_state,
        loop_names,
        opts,
    )
コード例 #5
0
def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  Iterates over the rows of `iter_` (its outermost dimension), reading row `i`
  as `iter_[i]`. A hidden index named `<internal iterate>` is carried as the
  first loop variable, ahead of the user's loop variables, and the loop is
  lowered to `_tf_while_stmt`.

  Args:
    iter_: the ragged tensor being iterated.
    extra_test: optional callable, or None. When given, it is only evaluated
      (via `control_flow_ops.cond`) while rows remain.
    body: callable invoked with each row.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable receiving the updated user loop variables.
    symbol_names: names of the user loop variables.
    opts: loop options dict; outside an XLA context its
      `maximum_iterations` entry is overwritten.
  """
  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    # Number of rows is known statically.
    n = iter_.shape[0]
  else:
    # The loop runs once per row, so the bound is the number of rows. (Not
    # `row_lengths()[0]`, which is the length of the first row and cannot be
    # evaluated at all when there are zero rows.) nrows() uses the row_splits
    # dtype (int64 by default); cast it so it compares cleanly with the loop
    # index, which starts from the Python int 0 (int32 as a TF tensor).
    n = math_ops.cast(iter_.nrows(), dtypes.int32)

  iterate_index = compat_util.BasicRef(0)

  def aug_get_state():
    # Slot 0 of the loop state is always the hidden row index.
    return (iterate_index.value,) + get_state()

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    iterate_index.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    body(iter_[iterate_index.value])
    iterate_index.value += 1

  def aug_test():
    main_test = iterate_index.value < n
    if extra_test is not None:
      # Only consult extra_test while rows remain.
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  # TODO(b/159186914): Remove.
  if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = n

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)