コード例 #1
0
  def testThreads(self):
    """Checks thread naming and per-run exception capture of a QueueRunner.

    One enqueue op (`count_up_to`) ends via OUT_OF_RANGE while the other
    (`_MockOp("bad_op")`) fails. Threads are created and run twice; each run
    must report exactly one captured exception.
    """
    with self.cached_session() as sess:
      # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      var = variables.VariableV1(zero64)
      count_up_to = var.count_up_to(3)
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      variables.global_variables_initializer().run()
      qr = queue_runner_impl.QueueRunner(queue, [count_up_to,
                                                 _MockOp("bad_op")])
      threads = qr.create_threads(sess, start=True)
      self.assertEqual(sorted(t.name for t in threads),
                       ["QueueRunnerThread-fifo_queue-CountUpTo:0",
                        "QueueRunnerThread-fifo_queue-bad_op"])
      for t in threads:
        t.join()
      exceptions = qr.exceptions_raised
      self.assertEqual(1, len(exceptions))
      self.assertIn("Operation not in the graph", str(exceptions[0]))

      # Second run: errors from the first run must not be carried over, so
      # again exactly one exception is expected.
      threads = qr.create_threads(sess, start=True)
      for t in threads:
        t.join()
      exceptions = qr.exceptions_raised
      self.assertEqual(1, len(exceptions))
      self.assertIn("Operation not in the graph", str(exceptions[0]))
コード例 #2
0
  def testQueueRunnerSerializationRoundTrip(self):
    """Round-trips a QueueRunner through to_proto/from_proto.

    Also checks that a proto whose `queue_closed_exception_types` field is
    empty (as in protos written before the field existed) is restored with
    OutOfRangeError as the only queue-closed exception type.
    """
    g = ops.Graph()
    with g.as_default():
      fifo = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue")
      enqueue = control_flow_ops.no_op(name="enqueue")
      close = control_flow_ops.no_op(name="close")
      cancel = control_flow_ops.no_op(name="cancel")
      closed_types = (errors_impl.OutOfRangeError,
                      errors_impl.CancelledError)
      runner = queue_runner_impl.QueueRunner(
          fifo, [enqueue], close, cancel,
          queue_closed_exception_types=closed_types)

      def check_ops(restored):
        # Fields that must survive deserialization for every proto version.
        self.assertEqual("queue", restored.queue.name)
        self.assertEqual(1, len(restored.enqueue_ops))
        self.assertEqual(enqueue, restored.enqueue_ops[0])
        self.assertEqual(close, restored.close_op)
        self.assertEqual(cancel, restored.cancel_op)

      proto = queue_runner_impl.QueueRunner.to_proto(runner)
      current = queue_runner_impl.QueueRunner.from_proto(proto)
      check_ops(current)
      self.assertEqual(closed_types, current.queue_closed_exception_types)

      # Emulate a QueueRunnerDef serialized before the
      # queue_closed_exception_types field existed.
      del proto.queue_closed_exception_types[:]
      legacy = queue_runner_impl.QueueRunner.from_proto(proto)
      check_ops(legacy)
      self.assertEqual((errors_impl.OutOfRangeError,),
                       legacy.queue_closed_exception_types)
コード例 #3
0
 def testRealDequeueEnqueue(self):
   """Moves values from q0 into q1 with a QueueRunner until q0 is closed.

   Closing q0 must end the runner without recorded exceptions, leave both
   values in q1, and leave q1 closed.
   """
   with self.cached_session() as sess:
     q0 = data_flow_ops.FIFOQueue(3, dtypes.float32)
     enqueue0 = q0.enqueue((10.0,))
     close0 = q0.close()
     q1 = data_flow_ops.FIFOQueue(30, dtypes.float32)
     enqueue1 = q1.enqueue((q0.dequeue(),))
     dequeue1 = q1.dequeue()
     qr = queue_runner_impl.QueueRunner(q1, [enqueue1])
     threads = qr.create_threads(sess)
     for t in threads:
       t.start()
     # Enqueue 2 values, then close queue0.
     enqueue0.run()
     enqueue0.run()
     close0.run()
     # Wait for the queue runner to terminate.
     for t in threads:
       t.join()
     # It should have terminated cleanly.
     self.assertEqual(0, len(qr.exceptions_raised))
     # The 2 values should be in queue1.
     self.assertEqual(10.0, dequeue1.eval())
     self.assertEqual(10.0, dequeue1.eval())
     # And queue1 should now be closed. (assertRaisesRegex replaces the
     # deprecated assertRaisesRegexp alias, removed in Python 3.12.)
     with self.assertRaisesRegex(errors_impl.OutOfRangeError, "is closed"):
       dequeue1.eval()
コード例 #4
0
 def testName(self):
   """The runner takes its queue's scoped name and is collected by scope."""
   with ops.name_scope("scope"):
     scoped_queue = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue")
   runner = queue_runner_impl.QueueRunner(scoped_queue,
                                          [control_flow_ops.no_op()])
   self.assertEqual("scope/queue", runner.name)
   queue_runner_impl.add_queue_runner(runner)
   collected = ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS, "scope")
   self.assertEqual(1, len(collected))
コード例 #5
0
def boston_input_fn_with_queue(num_epochs=None):
    """Returns `boston_input_fn` output after registering a dummy queue runner.

    The registered runner's only op is a constant, so it performs no real
    enqueueing; it is added with `queue_runner_impl.add_queue_runner`.

    Args:
      num_epochs: Passed through unchanged to `boston_input_fn`.

    Returns:
      The `(features, labels)` pair produced by `boston_input_fn`.
    """
    features, labels = boston_input_fn(num_epochs=num_epochs)

    # Minimal runner: an int32 queue paired with a constant "enqueue" op.
    dummy_runner = queue_runner_impl.QueueRunner(
        data_flow_ops.FIFOQueue(30, dtypes.int32),
        [constant_op.constant(0)])
    queue_runner_impl.add_queue_runner(dummy_runner)

    return features, labels
コード例 #6
0
 def testRequestStopOnException(self):
   """An exception in a runner thread is re-raised by Coordinator.join()."""
   with self.cached_session() as sess:
     queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
     qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")])
     coord = coordinator.Coordinator()
     threads = qr.create_threads(sess, coord)
     for t in threads:
       t.start()
     # The exception should be re-raised when joining. assertRaisesRegex
     # replaces the deprecated assertRaisesRegexp alias (removed in 3.12).
     with self.assertRaisesRegex(ValueError, "Operation not in the graph"):
       coord.join()
コード例 #7
0
 def testStartQueueRunnersRaisesIfNotASession(self):
   """start_queue_runners raises TypeError for a non-session argument."""
   zero64 = constant_op.constant(0, dtype=dtypes.int64)
   var = variables.VariableV1(zero64)
   count_up_to = var.count_up_to(3)
   queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
   init_op = variables.global_variables_initializer()
   qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
   queue_runner_impl.add_queue_runner(qr)
   with self.cached_session():
     init_op.run()
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(TypeError, "tf.Session"):
       queue_runner_impl.start_queue_runners("NotASession")
コード例 #8
0
 def testStartQueueRunnersIgnoresMonitoredSession(self):
   """start_queue_runners returns no threads when given a MonitoredSession."""
   zero64 = constant_op.constant(0, dtype=dtypes.int64)
   var = variables.VariableV1(zero64)
   count_up_to = var.count_up_to(3)
   queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
   init_op = variables.global_variables_initializer()
   qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
   queue_runner_impl.add_queue_runner(qr)
   with self.cached_session():
     init_op.run()
     # Use the MonitoredSession as a context manager so it is closed when
     # the check is done instead of being left open.
     with monitored_session.MonitoredSession() as mon_sess:
       threads = queue_runner_impl.start_queue_runners(mon_sess)
     self.assertFalse(threads)
コード例 #9
0
 def testMultipleSessions(self):
   """create_threads yields equally many threads for two different sessions.

   The threads are only created here, never started.
   """
   with self.cached_session() as sess, session.Session() as other_sess:
     counter = variables.VariableV1(
         constant_op.constant(0, dtype=dtypes.int64))
     increment = counter.count_up_to(3)
     fifo = data_flow_ops.FIFOQueue(10, dtypes.float32)
     variables.global_variables_initializer().run()
     coord = coordinator.Coordinator()
     runner = queue_runner_impl.QueueRunner(fifo, [increment])
     first = runner.create_threads(sess, coord=coord)
     second = runner.create_threads(other_sess, coord=coord)
     self.assertEqual(len(first), len(second))
コード例 #10
0
 def testExceptionsCaptured(self):
   """Errors from invalid enqueue ops are recorded in exceptions_raised."""
   # cached_session() replaces the deprecated test_session(), matching the
   # other queue runner tests.
   with self.cached_session() as sess:
     queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
     # Plain strings are not ops in the graph, so every thread fails.
     qr = queue_runner_impl.QueueRunner(queue, ["i fail", "so fail"])
     threads = qr.create_threads(sess)
     variables.global_variables_initializer().run()
     for t in threads:
       t.start()
     for t in threads:
       t.join()
     exceptions = qr.exceptions_raised
     self.assertEqual(2, len(exceptions))
     self.assertIn("Operation not in the graph", str(exceptions[0]))
     self.assertIn("Operation not in the graph", str(exceptions[1]))
コード例 #11
0
 def testIgnoreMultiStarts(self):
   """A repeated create_threads call for the same session returns []."""
   with self.cached_session() as sess:
     # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
     counter = variables.VariableV1(
         constant_op.constant(0, dtype=dtypes.int64))
     increment = counter.count_up_to(3)
     fifo = data_flow_ops.FIFOQueue(10, dtypes.float32)
     variables.global_variables_initializer().run()
     coord = coordinator.Coordinator()
     runner = queue_runner_impl.QueueRunner(fifo, [increment])
     # NOTE: threads are created but deliberately never started.
     unused_threads = runner.create_threads(sess, coord=coord)
     self.assertEqual([], runner.create_threads(sess, coord=coord))
コード例 #12
0
 def testGracePeriod(self):
   """Coordinator.join succeeds after a stop request unblocks the enqueue."""
   with self.cached_session() as sess:
     # With capacity 2 the runner's enqueue blocks almost immediately.
     fifo = data_flow_ops.FIFOQueue(2, dtypes.float32)
     push = fifo.enqueue((10.0,))
     pop = fifo.dequeue()
     runner = queue_runner_impl.QueueRunner(fifo, [push])
     coord = coordinator.Coordinator()
     runner.create_threads(sess, coord, start=True)
     # Remove one element, pause briefly, then ask everything to stop.
     pop.op.run()
     time.sleep(0.02)
     coord.request_stop()
     # The stop request closes the queue, releasing the blocked enqueue, so
     # the join must complete within the grace period.
     coord.join(stop_grace_period_secs=1.0)
コード例 #13
0
  def _testScopedExportWithQueue(self, test_dir, exported_filename):
    """Builds a graph with a queue runner under "queue1" and exports it.

    Args:
      test_dir: Directory the meta graph file is written to.
      exported_filename: Name of the exported file inside `test_dir`.

    Returns:
      The meta graph returned by `meta_graph.export_scoped_meta_graph` for
      export scope "queue1".
    """
    graph = ops.Graph()
    with graph.as_default():
      with ops.name_scope("queue1"):
        input_queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
        # Trailing comma: a one-element tuple for the queue's single
        # component (a bare `(9876)` is just an int).
        enqueue = input_queue.enqueue((9876,), name="enqueue")
        close = input_queue.close(name="close")
        qr = queue_runner_impl.QueueRunner(input_queue, [enqueue], close)
        queue_runner_impl.add_queue_runner(qr)
        input_queue.dequeue(name="dequeue")

      # `graph` is the default graph here; pass it directly.
      orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filename),
          graph=graph,
          export_scope="queue1")

    return orig_meta_graph
コード例 #14
0
 def testStartQueueRunners(self):
   """start_queue_runners runs a registered runner until CountUpTo stops."""
   # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
   counter = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
   increment = counter.count_up_to(3)
   fifo = data_flow_ops.FIFOQueue(10, dtypes.float32)
   init_op = variables.global_variables_initializer()
   runner = queue_runner_impl.QueueRunner(fifo, [increment])
   queue_runner_impl.add_queue_runner(runner)
   with self.cached_session() as sess:
     init_op.run()
     for worker in queue_runner_impl.start_queue_runners(sess):
       worker.join()
     self.assertEqual(0, len(runner.exceptions_raised))
     # Three increments happen before OUT_OF_RANGE ends the runner.
     self.assertEqual(3, counter.eval())
コード例 #15
0
 def testTwoOps(self):
   """Runs a QueueRunner with two CountUpTo ops until both are exhausted."""
   # cached_session() replaces the deprecated test_session(), and VariableV1
   # matches the variable class used by the other CountUpTo tests.
   with self.cached_session() as sess:
     # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
     zero64 = constant_op.constant(0, dtype=dtypes.int64)
     var0 = variables.VariableV1(zero64)
     count_up_to_3 = var0.count_up_to(3)
     var1 = variables.VariableV1(zero64)
     count_up_to_30 = var1.count_up_to(30)
     queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
     qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30])
     threads = qr.create_threads(sess)
     variables.global_variables_initializer().run()
     for t in threads:
       t.start()
     for t in threads:
       t.join()
     self.assertEqual(0, len(qr.exceptions_raised))
     self.assertEqual(3, var0.eval())
     self.assertEqual(30, var1.eval())
コード例 #16
0
 def testBasic(self):
   """A single CountUpTo runner raises the variable to 3 and then stops."""
   with self.cached_session() as sess:
     # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
     counter = variables.VariableV1(
         constant_op.constant(0, dtype=dtypes.int64))
     increment = counter.count_up_to(3)
     fifo = data_flow_ops.FIFOQueue(10, dtypes.float32)
     variables.global_variables_initializer().run()
     runner = queue_runner_impl.QueueRunner(fifo, [increment])
     workers = runner.create_threads(sess)
     self.assertEqual(sorted(w.name for w in workers),
                      ["QueueRunnerThread-fifo_queue-CountUpTo:0"])
     for w in workers:
       w.start()
     for w in workers:
       w.join()
     self.assertEqual(0, len(runner.exceptions_raised))
     # Three increments happen before OUT_OF_RANGE ends the runner.
     self.assertEqual(3, counter.eval())
コード例 #17
0
 def testRespectCoordShouldStop(self):
   """A runner whose coordinator already requested stop does no work."""
   # cached_session() replaces the deprecated test_session(), and VariableV1
   # matches the variable class used by the other CountUpTo tests.
   with self.cached_session() as sess:
     # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
     zero64 = constant_op.constant(0, dtype=dtypes.int64)
     var = variables.VariableV1(zero64)
     count_up_to = var.count_up_to(3)
     queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
     variables.global_variables_initializer().run()
     qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
     # Ask the coordinator to stop.  The queue runner should
     # finish immediately.
     coord = coordinator.Coordinator()
     coord.request_stop()
     threads = qr.create_threads(sess, coord)
     for t in threads:
       t.start()
     coord.join()
     self.assertEqual(0, len(qr.exceptions_raised))
     # The variable should be 0.
     self.assertEqual(0, var.eval())
コード例 #18
0
 def testTwoOps(self):
   """Two CountUpTo ops on one runner each run up to their own limit."""
   with self.cached_session() as sess:
     # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
     zero64 = constant_op.constant(0, dtype=dtypes.int64)
     small = variables.VariableV1(zero64)
     to_3 = small.count_up_to(3)
     large = variables.VariableV1(zero64)
     to_30 = large.count_up_to(30)
     fifo = data_flow_ops.FIFOQueue(10, dtypes.float32)
     runner = queue_runner_impl.QueueRunner(fifo, [to_3, to_30])
     workers = runner.create_threads(sess)
     expected_names = ["QueueRunnerThread-fifo_queue-CountUpTo:0",
                       "QueueRunnerThread-fifo_queue-CountUpTo_1:0"]
     self.assertEqual(sorted(w.name for w in workers), expected_names)
     self.evaluate(variables.global_variables_initializer())
     for w in workers:
       w.start()
     for w in workers:
       w.join()
     self.assertEqual(0, len(runner.exceptions_raised))
     self.assertEqual(3, self.evaluate(small))
     self.assertEqual(30, self.evaluate(large))
コード例 #19
0
    def testMultiThreadedEstimateDataDistribution(self):
        """Fetches the verified `prob_estimate` from 25 runner threads at once.

        Each queue-runner thread evaluates `prob_estimate` (wrapped by
        `sampling_ops._verify_input`). Any error raised in a runner thread is
        surfaced by `coord.join`.
        """
        num_classes = 10

        # Set up graph.
        random_seed.set_random_seed(1234)
        label = math_ops.cast(
            math_ops.round(random_ops.random_uniform([1]) * num_classes),
            dtypes_lib.int32)

        prob_estimate = sampling_ops._estimate_data_distribution(  # pylint: disable=protected-access
            label, num_classes)
        # Check that prob_estimate is well-behaved in a multithreaded context.
        _, _, [prob_estimate] = sampling_ops._verify_input(  # pylint: disable=protected-access
            [], label, [prob_estimate])

        # Use queues to run multiple threads over the graph, each of which
        # fetches `prob_estimate`.
        queue = data_flow_ops.FIFOQueue(capacity=25,
                                        dtypes=[prob_estimate.dtype],
                                        shapes=[prob_estimate.get_shape()])
        enqueue_op = queue.enqueue([prob_estimate])
        queue_runner_impl.add_queue_runner(
            queue_runner_impl.QueueRunner(queue, [enqueue_op] * 25))
        out_tensor = queue.dequeue()

        # Run the multi-threaded session.
        with self.cached_session() as sess:
            # Need to initialize variables that keep running total of classes seen.
            variables.global_variables_initializer().run()

            coord = coordinator.Coordinator()
            # Pass the session explicitly instead of relying on whichever
            # session happens to be the default.
            threads = queue_runner_impl.start_queue_runners(
                sess=sess, coord=coord)

            try:
                for _ in range(25):
                    sess.run([out_tensor])
            finally:
                # Tell the runner threads to shut down even if a fetch fails.
                coord.request_stop()
            coord.join(threads)
コード例 #20
0
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner_impl


# NOTE(review): this module-level fragment looks like the body of a
# `testTwoOps` test pasted outside any function. As written it cannot run:
# `sess` is not defined anywhere in view, `constant_op` is only imported
# further down (after this code would execute), and `dtypes` is not imported
# in this section. Confirm the intent and move it into a test method.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
# Two variables initialized from zero64, each with its own CountUpTo op
# (limits 3 and 30).
var0 = variables.VariableV1(zero64)
count_up_to_3 = var0.count_up_to(3)
var1 = variables.VariableV1(zero64)
count_up_to_30 = var1.count_up_to(30)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
# A single queue runner driving both CountUpTo ops.
qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30])
threads = qr.create_threads(sess)






from tensorflow.python import summary
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops