예제 #1
0
  def testIsLoopExit(self):
    """Checks that IsLoopExit accepts Exit ops and rejects a non-Exit op."""
    is_loop_exit = control_flow_util.IsLoopExit

    # Exit op built on a plain Python value.
    value_exit_op = control_flow_ops.exit(1).op
    self.assertTrue(is_loop_exit(value_exit_op))

    # Exit op built on the output of test_ops.ref_output() (presumably a
    # ref-typed tensor, i.e. the RefExit variant -- confirm against test_ops).
    ref_exit_op = control_flow_ops.exit(test_ops.ref_output()).op
    self.assertTrue(is_loop_exit(ref_exit_op))

    # An ordinary op that was not produced by control_flow_ops.exit.
    plain_op = test_ops.int_output().op
    self.assertFalse(is_loop_exit(plain_op))
예제 #2
0
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
  """Decrement pending counts for `op`'s inputs and enqueue the ready ones.

  For every input tensor of `op`, the pending count of its producing op is
  decremented. A producer becomes ready when its count reaches zero. When
  `loop_state` is set, a producer also becomes ready if its count is still
  positive but it is a loop Switch. Ready producers that are not loop Exits
  are appended to `queue` directly.

  Ready loop Exit producers are deferred on their loop's grad state until the
  last pending exit of that loop arrives. At that point, every deferred exit
  with a non-None gradient is queued, and the rest are recorded as unused and
  also queued. If at least one exit had a non-None gradient, each trainable
  unused exit first receives a zeros-like gradient. Otherwise the unused exits
  are queued without any gradient being set here.

  Args:
    grads: Gradient bookkeeping handed to `_HasAnyNotNoneGrads`/`_SetGrad`.
    op: The op whose inputs are being processed.
    queue: List of ops still to visit; ready ops are appended to it.
    pending_count: Mapping from op to a counter that is decremented once per
      consuming input seen here; mutated in place.
    loop_state: Control flow state object, or a falsy value when no while
      loops are involved.
  """
  for inp in op.inputs:
    producer = inp.op
    pending_count[producer] -= 1
    remaining = pending_count[producer]
    # A loop Switch is released while its count is still positive. Presumably
    # this is because part of its count comes from the loop's back edge, which
    # is not reached before the Switch is needed.
    is_ready = remaining == 0 or (
        loop_state and remaining > 0 and
        control_flow_util.IsLoopSwitch(producer))
    if not is_ready:
      continue

    if not control_flow_util.IsLoopExit(producer):
      queue.append(producer)
      continue

    # The gradient for this Exit may not be known yet, so hold it until every
    # exit of the same loop has been reached.
    grad_state = loop_state.GetGradState(producer, before=False)
    grad_state.deferred_exits.append(inp)
    grad_state.pending_exits_count -= 1
    if grad_state.pending_exits_count != 0:
      continue

    # Every exit has arrived; split them into used and unused.
    any_real_grad = False
    for exit_tensor in grad_state.deferred_exits:
      if _HasAnyNotNoneGrads(grads, exit_tensor.op):
        any_real_grad = True
        queue.append(exit_tensor.op)
      else:
        grad_state.unused_exits.append(exit_tensor)

    # Unused exits are always queued. If some other exit carries a real
    # gradient, trainable unused exits get a zero gradient first. Otherwise
    # they are left with a None gradient.
    for exit_tensor in grad_state.unused_exits:
      if any_real_grad and _IsTrainable(exit_tensor):
        _SetGrad(grads, exit_tensor, loop_state.ZerosLikeForExit(exit_tensor))
      queue.append(exit_tensor.op)
예제 #3
0
 def GetGradState(self, op, before):
     """Look up the gradient loop state registered for `op`'s forward loop.

     Args:
       op: The forward op whose while context is being resolved.
       before: When True and `op` is a loop Exit, the lookup uses the while
         context that encloses the Exit's own control flow context (reached
         through its `outer_context`), not the Exit's own loop.

     Returns:
       The entry of `self._map` for the resolved forward while context, or
       None when no while context could be resolved.
     """
     if not (before and util.IsLoopExit(op)):
         forward_ctxt = util.GetWhileContext(op)
     else:
         # pylint: disable=protected-access
         enclosing = op._get_control_flow_context().outer_context
         # pylint: enable=protected-access
         # A falsy outer context (e.g. top level) is kept as-is.
         forward_ctxt = enclosing.GetWhileContext() if enclosing else enclosing
     return self._map.get(forward_ctxt) if forward_ctxt else None
예제 #4
0
def _get_op_control_flow_context(op):
    """Return the control flow context that `op` is attributed to.

    Args:
      op: tf.Operation for which the control flow context is requested.

    Returns:
      The op's own `_control_flow_context`. If `op` is a loop Exit (per
      `control_flow_util.IsLoopExit`), the `outer_context` of that context is
      returned instead.
    """
    # pylint: disable=protected-access
    own_ctxt = op._control_flow_context
    # pylint: enable=protected-access
    return (own_ctxt.outer_context if control_flow_util.IsLoopExit(op)
            else own_ctxt)
예제 #5
0
def while_loop_op(op):
    """Return whether `op` is one of the special building-block ops of a while loop.

    Args:
      op: A tf.Operation.

    Returns:
      A truthy value if `op` is a loop Switch, Merge, Enter or Exit (as judged
      by `control_flow_util`), a loop condition op (as judged by
      `loop_cond_op`), or has type NextIteration/RefNextIteration.
    """
    cf_util = control_flow_util
    # The checks run in the same short-circuit order as a single `or` chain.
    is_structural = (cf_util.IsLoopSwitch(op)
                     or cf_util.IsLoopMerge(op)
                     or cf_util.IsLoopEnter(op)
                     or cf_util.IsLoopExit(op)
                     or loop_cond_op(op))
    return is_structural or op.type in ('RefNextIteration', 'NextIteration')
def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Build a _ControlFlowState if any while loop takes part in gradients().

  A _ControlFlowState is created lazily when the first loop Exit op is found
  in `between_op_list`. `AddWhileContext` is then called for every loop Exit
  op, inside `ops.colocate_with(op)` when `colocate_gradients_with_ops` is
  truthy. gradients() only invokes control flow logic when the returned state
  is not None.

  Note that this method modifies `between_op_list` and `between_ops`.

  Returns:
    The _ControlFlowState, or None when no loop Exit op was encountered.
  """
  state = None
  # Iterate over the live list, not a copy. Any ops that AddWhileContext
  # appends to between_op_list are visited by this same loop.
  for candidate in between_op_list:
    if not util.IsLoopExit(candidate):
      continue
    if state is None:
      state = _ControlFlowState()
    if not colocate_gradients_with_ops:
      state.AddWhileContext(candidate, between_op_list, between_ops)
      continue
    with ops.colocate_with(candidate):
      state.AddWhileContext(candidate, between_op_list, between_ops)
  return state