def testIsLoopExit(self):
  """IsLoopExit is true for Exit ops (value or ref input) and false otherwise."""
  # Exit op whose input is a plain Python int.
  value_exit_op = control_flow_ops.exit(1).op
  self.assertTrue(control_flow_util.IsLoopExit(value_exit_op))

  # Exit op whose input is a ref-typed tensor.
  ref_exit_op = control_flow_ops.exit(test_ops.ref_output()).op
  self.assertTrue(control_flow_util.IsLoopExit(ref_exit_op))

  # An ordinary op that is not an Exit.
  non_exit_op = test_ops.int_output().op
  self.assertFalse(control_flow_util.IsLoopExit(non_exit_op))
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
  """Update pending count for the inputs of op and enqueue ready ops.

  For every input tensor `x` of `op`, the pending count of its producing op
  `x.op` is decremented. When `x.op` becomes ready it is appended to `queue`.
  The exception is loop Exit ops, which are deferred. Once every pending exit
  of their loop's grad state has arrived, they are all enqueued together.

  Args:
    grads: Gradient bookkeeping passed through to `_HasAnyNotNoneGrads` and
      `_SetGrad`. Presumably a dict keyed by op; confirm against the caller.
    op: The operation whose inputs are being processed.
    queue: Collection of ready ops; only `append` is used here.
    pending_count: Mapping from op to its remaining pending count. Mutated in
      place.
    loop_state: Control-flow state object, or None. The Exit branch calls
      `loop_state.GetGradState` without a None check. It therefore assumes
      loop_state is non-None whenever an input op is a loop Exit.
      NOTE(review): it also assumes GetGradState returns a non-None grad
      state for that Exit.
  """
  for x in op.inputs:
    pending_count[x.op] -= 1
    ready = (pending_count[x.op] == 0)
    if loop_state and not ready:
      # Inside while loops, a loop Switch is treated as ready while its count
      # is still positive. Presumably this is because some of its consumers
      # are never processed before it (e.g. via the loop back edge).
      # TODO confirm.
      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # if x is an exit without real gradient, defer processing them.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          # Exits that already have a non-None gradient are enqueued right
          # away. The rest are collected in `unused_exits`.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if _IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)
def GetGradState(self, op, before):
  """Look up the gradient loop state associated with `op`.

  Args:
    op: The operation whose forward while loop is looked up.
    before: When true and `op` is a loop Exit, the lookup starts from the
      outer context of the Exit's own control-flow context. That outer
      context is resolved with `GetWhileContext()`. Otherwise the while
      context comes from `util.GetWhileContext(op)`.

  Returns:
    `self._map.get(...)` for the resolved forward while context. Returns None
    when no while context is resolved.
  """
  if not (before and util.IsLoopExit(op)):
    while_ctxt = util.GetWhileContext(op)
  else:
    outer_ctxt = op._get_control_flow_context().outer_context  # pylint: disable=protected-access
    while_ctxt = outer_ctxt.GetWhileContext() if outer_ctxt else outer_ctxt
  if not while_ctxt:
    return None
  return self._map.get(while_ctxt)
def _get_op_control_flow_context(op): """Returns the control flow of the given op. Args: op: tf.Operation for which the control flow context is requested. Returns: op_control_flow_context: which the is control flow context of the given op. If the operation type is LoopExit, returns the outer control flow context. """ # pylint: disable=protected-access op_control_flow_context = op._control_flow_context # pylint: enable=protected-access if control_flow_util.IsLoopExit(op): op_control_flow_context = op_control_flow_context.outer_context return op_control_flow_context
def while_loop_op(op):
  """Checks whether `op` is one of the primitive ops that build a while loop.

  Args:
    op: A tf.Operation.

  Returns:
    A truthy value if `op` is a loop Switch, Merge, Enter or Exit, is accepted
    by `loop_cond_op`, or has type NextIteration/RefNextIteration. Otherwise
    returns a falsy value.
  """
  predicates = (control_flow_util.IsLoopSwitch,
                control_flow_util.IsLoopMerge,
                control_flow_util.IsLoopEnter,
                control_flow_util.IsLoopExit,
                loop_cond_op)
  # Same order and short-circuiting as a chained `or`: the first truthy
  # result is returned as-is.
  for predicate in predicates:
    verdict = predicate(op)
    if verdict:
      return verdict
  return op.type in ('RefNextIteration', 'NextIteration')
def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Build the control-flow state gradients() needs when while loops are involved.

  A _ControlFlowState is created lazily, when the first loop Exit op is found
  in `between_op_list`. If no Exit op is present, None is returned.
  gradients() only runs its control-flow logic when this state is not None.

  Every Exit op found is registered through `AddWhileContext`. That call may
  modify `between_op_list` and `between_ops`.

  Args:
    between_op_list: List of ops to scan.
    between_ops: Collection of ops, forwarded to `AddWhileContext`.
    colocate_gradients_with_ops: If true, each `AddWhileContext` call runs
      under `ops.colocate_with` for the Exit op being registered.

  Returns:
    A _ControlFlowState, or None when `between_op_list` contains no loop Exit.
  """
  state = None
  # Iterate the live list, not a copy. If AddWhileContext appends ops to it,
  # those ops are scanned as well.
  for candidate in between_op_list:
    if not util.IsLoopExit(candidate):
      continue
    if state is None:
      state = _ControlFlowState()
    if not colocate_gradients_with_ops:
      state.AddWhileContext(candidate, between_op_list, between_ops)
      continue
    with ops.colocate_with(candidate):
      state.AddWhileContext(candidate, between_op_list, between_ops)
  return state