예제 #1
0
  def testIsLoopEnter(self):
    """Checks IsLoopEnter / IsLoopConstantEnter on Enter-style ops.

    Builds a plain Enter, a RefEnter and a constant Enter (all with
    frame_name "name") and verifies that every one of them is reported as a
    loop Enter. Only the one built with is_constant=True may be reported as
    a loop-constant Enter. Finally, an op produced by test_ops.int_output()
    must not be reported as a loop Enter.
    """
    util = control_flow_util
    # Each entry pairs an Enter-style op with whether it was built as a
    # constant Enter. Ops are created in the same order as listed here.
    enter_cases = [
        (gen_control_flow_ops.enter(1, frame_name="name").op, False),
        (gen_control_flow_ops.ref_enter(test_ops.ref_output(),
                                        frame_name="name").op, False),
        (gen_control_flow_ops.enter(1, frame_name="name",
                                    is_constant=True).op, True),
    ]
    for enter_op, built_as_constant in enter_cases:
      self.assertTrue(util.IsLoopEnter(enter_op))
      # Keep truthiness-based assertions rather than assertEqual so the check
      # does not depend on the helper returning an exact bool.
      if built_as_constant:
        self.assertTrue(util.IsLoopConstantEnter(enter_op))
      else:
        self.assertFalse(util.IsLoopConstantEnter(enter_op))

    non_enter_op = test_ops.int_output().op
    self.assertFalse(util.IsLoopEnter(non_enter_op))
예제 #2
0
def while_loop_op(op):
    """Tells whether `op` is one of the building-block ops of a TF while loop.

    Args:
      op: A tf.Operation.

    Returns:
      A truthy value if `op` is one of [Switch, Merge, Enter, Exit,
      NextIteration, LoopCond]. The Switch/Merge/Enter/Exit checks are
      delegated to control_flow_util, and the LoopCond check to
      `loop_cond_op`. The first truthy check result is returned as-is.
      Otherwise the result of the NextIteration op-type test is returned.
    """
    # Checked in this order; evaluation stops at the first truthy result,
    # mirroring a short-circuiting `or` chain.
    loop_checks = (
        control_flow_util.IsLoopSwitch,
        control_flow_util.IsLoopMerge,
        control_flow_util.IsLoopEnter,
        control_flow_util.IsLoopExit,
        loop_cond_op,
    )
    for check in loop_checks:
        outcome = check(op)
        if outcome:
            return outcome
    # NextIteration (and its ref variant) is matched by op type name here.
    next_iteration_types = ('RefNextIteration', 'NextIteration')
    return op.type in next_iteration_types