Code Example #1
  def testResetOfBlockingOperation(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q_empty = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, (
          (),))
      dequeue_op = q_empty.dequeue()
      dequeue_many_op = q_empty.dequeue_many(1)
      dequeue_up_to_op = q_empty.dequeue_up_to(1)

      q_full = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, ((),))
      sess.run(q_full.enqueue_many(([1.0, 2.0, 3.0, 4.0, 5.0],)))
      enqueue_op = q_full.enqueue((6.0,))
      enqueue_many_op = q_full.enqueue_many(([6.0],))

      threads = [
          self.checkedThread(
              self._blockingDequeue, args=(sess, dequeue_op)),
          self.checkedThread(
              self._blockingDequeueMany, args=(sess, dequeue_many_op)),
          self.checkedThread(
              self._blockingDequeueUpTo, args=(sess, dequeue_up_to_op)),
          self.checkedThread(
              self._blockingEnqueue, args=(sess, enqueue_op)),
          self.checkedThread(
              self._blockingEnqueueMany, args=(sess, enqueue_many_op))
      ]
      for t in threads:
        t.start()
      time.sleep(0.1)
      sess.close()  # Will cancel the blocked operations.
      for t in threads:
        t.join()
Code Example #2
  def testDequeueInDifferentOrders(self):
    with self.cached_session():
      # Specify seeds to make the test deterministic
      # (https://en.wikipedia.org/wiki/Taxicab_number).
      q1 = data_flow_ops.RandomShuffleQueue(
          10, 5, dtypes_lib.int32, ((),), seed=1729)
      q2 = data_flow_ops.RandomShuffleQueue(
          10, 5, dtypes_lib.int32, ((),), seed=87539319)
      enq1 = q1.enqueue_many(([1, 2, 3, 4, 5],))
      enq2 = q2.enqueue_many(([1, 2, 3, 4, 5],))
      deq1 = q1.dequeue()
      deq2 = q2.dequeue()

      enq1.run()
      enq1.run()
      enq2.run()
      enq2.run()

      results = [[], [], [], []]

      for _ in range(5):
        results[0].append(deq1.eval())
        results[1].append(deq2.eval())

      q1.close().run()
      q2.close().run()

      for _ in range(5):
        results[2].append(deq1.eval())
        results[3].append(deq2.eval())

      # No two should match
      for i in range(1, 4):
        for j in range(i):
          self.assertNotEqual(results[i], results[j])
Code Example #3
  def testResetOfBlockingOperation(self):
    with self.cached_session() as sess:
      q_empty = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, (
          (),))
      dequeue_op = q_empty.dequeue()
      dequeue_many_op = q_empty.dequeue_many(1)
      dequeue_up_to_op = q_empty.dequeue_up_to(1)

      q_full = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, ((),))
      sess.run(q_full.enqueue_many(([1.0, 2.0, 3.0, 4.0, 5.0],)))
      enqueue_op = q_full.enqueue((6.0,))
      enqueue_many_op = q_full.enqueue_many(([6.0],))

      threads = [
          self.checkedThread(
              self._blockingDequeue, args=(sess, dequeue_op)),
          self.checkedThread(
              self._blockingDequeueMany, args=(sess, dequeue_many_op)),
          self.checkedThread(
              self._blockingDequeueUpTo, args=(sess, dequeue_up_to_op)),
          self.checkedThread(
              self._blockingEnqueue, args=(sess, enqueue_op)),
          self.checkedThread(
              self._blockingEnqueueMany, args=(sess, enqueue_many_op))
      ]
      for t in threads:
        t.start()
      time.sleep(0.1)
      sess.close()  # Will cancel the blocked operations.
      for t in threads:
        t.join()
Code Example #4
  def testSelectQueueOutOfRange(self):
    with self.cached_session():
      q1 = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
      q2 = data_flow_ops.RandomShuffleQueue(15, 0, dtypes_lib.float32)
      enq_q = data_flow_ops.RandomShuffleQueue.from_list(3, [q1, q2])
      with self.assertRaisesOpError("is not in"):
        enq_q.dequeue().eval()
Code Example #5
  def testDequeueUpToWithTensorParameter(self):
    with self.cached_session():
      # Define a first queue that contains integer counts.
      dequeue_counts = [random.randint(1, 10) for _ in range(100)]
      count_q = data_flow_ops.RandomShuffleQueue(100, 0, dtypes_lib.int32)
      enqueue_counts_op = count_q.enqueue_many((dequeue_counts,))
      total_count = sum(dequeue_counts)

      # Define a second queue that contains total_count elements.
      elems = [random.randint(0, 100) for _ in range(total_count)]
      q = data_flow_ops.RandomShuffleQueue(total_count, 0, dtypes_lib.int32, (
          (),))
      enqueue_elems_op = q.enqueue_many((elems,))

      # Define a subgraph that first dequeues a count, then DequeuesUpTo
      # that number of elements.
      dequeued_t = q.dequeue_up_to(count_q.dequeue())

      enqueue_counts_op.run()
      enqueue_elems_op.run()

      dequeued_elems = []
      for _ in dequeue_counts:
        dequeued_elems.extend(dequeued_t.eval())
      self.assertItemsEqual(elems, dequeued_elems)
Code Example #6
  def testEnqueueManyWithShape(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(
          10, 5, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
      q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
      self.assertAllEqual(4, q.size().eval())

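      # The two enqueue ops below are only constructed, not run: building them
      # verifies that the given values match the queue's declared element shape.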
      q2 = data_flow_ops.RandomShuffleQueue(
          10, 5, dtypes_lib.int32, shapes=tensor_shape.TensorShape([3]))
      q2.enqueue(([1, 2, 3],))
      q2.enqueue_many(([[1, 2, 3]],))
Code Example #7
  def testEnqueueAndBlockingDequeue(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(3, 0, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      def enqueue():
        # The enqueue_ops should run after the dequeue op has blocked.
        # TODO(mrry): Figure out how to do this without sleeping.
        time.sleep(0.1)
        for enqueue_op in enqueue_ops:
          sess.run(enqueue_op)

      results = []

      def dequeue():
        for _ in xrange(len(elems)):
          results.append(sess.run(dequeued_t))

      enqueue_thread = self.checkedThread(target=enqueue)
      dequeue_thread = self.checkedThread(target=dequeue)
      enqueue_thread.start()
      dequeue_thread.start()
      enqueue_thread.join()
      dequeue_thread.join()

      self.assertItemsEqual(elems, results)
Code Example #8
  def testBigDequeueMany(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(2, 0, dtypes_lib.int32, ((),))
      elem = np.arange(4, dtype=np.int32)
      enq_list = [q.enqueue((e,)) for e in elem]
      deq = q.dequeue_many(4)

      results = []

      def blocking_dequeue():
        # Will only complete after 4 enqueues complete.
        results.extend(sess.run(deq))

      thread = self.checkedThread(target=blocking_dequeue)
      thread.start()
      # The dequeue should start and then block.
      for enq in enq_list:
        # TODO(mrry): Figure out how to do this without sleeping.
        time.sleep(0.1)
        self.assertEqual(len(results), 0)
        sess.run(enq)

      # Enough enqueued to unblock the dequeue
      thread.join()
      self.assertItemsEqual(elem, results)
Code Example #9
  def testReadUpToFromRandomShuffleQueue(self):
    shared_queue = data_flow_ops.RandomShuffleQueue(
        capacity=55,
        min_after_dequeue=28,
        dtypes=[dtypes_lib.string, dtypes_lib.string],
        shapes=[[], []])
    self._verify_read_up_to_out(shared_queue)
Code Example #10
  def testBlockingDequeueFromClosedQueue(self):
    with self.test_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(10, 2, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0, 40.0]
      enqueue_op = q.enqueue_many((elems,))
      close_op = q.close()
      dequeued_t = q.dequeue()

      enqueue_op.run()

      results = []

      def dequeue():
        for _ in elems:
          results.append(sess.run(dequeued_t))
        self.assertItemsEqual(elems, results)
        # Expect the operation to fail due to the queue being closed.
        with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
                                     "is closed and has insufficient"):
          sess.run(dequeued_t)

      dequeue_thread = self.checkedThread(target=dequeue)
      dequeue_thread.start()
      # The close_op should run after the dequeue_thread has blocked.
      # TODO(mrry): Figure out how to do this without sleeping.
      time.sleep(0.1)
      # The dequeue thread blocked when it hit the min_size requirement.
      self.assertEqual(len(results), 2)
      close_op.run()
      dequeue_thread.join()
      # Once the queue is closed, the min_size requirement is lifted.
      self.assertEqual(len(results), 4)
Code Example #11
    def _apply_transform(self, transform_input):
        filename_queue = input_ops.string_input_producer(self._work_units,
                                                         shuffle=self.shuffle,
                                                         seed=self._seed)

        if self.shuffle:
            queue = data_flow_ops.RandomShuffleQueue(
                capacity=self.queue_capacity,
                min_after_dequeue=self.min_after_dequeue,
                dtypes=[dtypes.string, dtypes.string],
                shapes=[[], []],
                seed=self.seed)
        else:
            queue = data_flow_ops.FIFOQueue(
                capacity=self.queue_capacity,
                dtypes=[dtypes.string, dtypes.string],
                shapes=[[], []])

        enqueue_ops = []
        for _ in range(self.num_threads):
            reader = self._reader_cls(**self._reader_kwargs)
            enqueue_ops.append(queue.enqueue(reader.read(filename_queue)))

        runner = queue_runner.QueueRunner(queue, enqueue_ops)
        queue_runner.add_queue_runner(runner)
        dequeued = queue.dequeue_many(self.batch_size)

        # pylint: disable=not-callable
        return self.return_type(*dequeued)
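
The method above fans several readers into one shared queue behind a QueueRunner. Below is a minimal sketch of the same pattern outside the class, assuming TF1 graph mode; the shard filenames, the use of TFRecordReader, the seed, and the queue sizes are illustrative assumptions rather than values from the original code.

import tensorflow as tf

# Hypothetical work units (file shards) feeding the readers.
filename_queue = tf.train.string_input_producer(
    ["shard-00000.tfrecord", "shard-00001.tfrecord"], shuffle=True, seed=42)

# Shared shuffle queue holding (key, serialized record) string pairs.
common_queue = tf.RandomShuffleQueue(
    capacity=256, min_after_dequeue=64,
    dtypes=[tf.string, tf.string], shapes=[[], []], seed=42)

# One reader per enqueue op; the QueueRunner drives them on parallel threads.
enqueue_ops = []
for _ in range(4):
    reader = tf.TFRecordReader()
    enqueue_ops.append(common_queue.enqueue(reader.read(filename_queue)))

tf.train.add_queue_runner(tf.train.QueueRunner(common_queue, enqueue_ops))
keys, values = common_queue.dequeue_many(32)  # one shuffled batch of records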
Code Example #12
  def testMultiDequeueUpToNoBlocking(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(
          10, 0, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
      float_elems = [
          10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0
      ]
      int_elems = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14],
                   [15, 16], [17, 18], [19, 20]]
      enqueue_op = q.enqueue_many((float_elems, int_elems))
      dequeued_t = q.dequeue_up_to(4)
      dequeued_single_t = q.dequeue()

      enqueue_op.run()

      results = []
      float_val, int_val = sess.run(dequeued_t)
      # dequeue_up_to has undefined shape.
      self.assertEqual([None], dequeued_t[0].get_shape().as_list())
      self.assertEqual([None, 2], dequeued_t[1].get_shape().as_list())
      results.extend(zip(float_val, int_val.tolist()))

      float_val, int_val = sess.run(dequeued_t)
      results.extend(zip(float_val, int_val.tolist()))

      float_val, int_val = sess.run(dequeued_single_t)
      self.assertEqual(float_val.shape, dequeued_single_t[0].get_shape())
      self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
      results.append((float_val, int_val.tolist()))

      float_val, int_val = sess.run(dequeued_single_t)
      results.append((float_val, int_val.tolist()))

      self.assertItemsEqual(zip(float_elems, int_elems), results)
Code Example #13
  def testParallelEnqueue(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      # Run one producer thread for each element in elems.
      def enqueue(enqueue_op):
        sess.run(enqueue_op)

      threads = [
          self.checkedThread(
              target=enqueue, args=(e,)) for e in enqueue_ops
      ]
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()

      # Dequeue every element using a single thread.
      results = []
      for _ in xrange(len(elems)):
        results.append(dequeued_t.eval())
      self.assertItemsEqual(elems, results)
Code Example #14
  def testBlockingDequeueUpToFromClosedQueueReturnsRemainder(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
      elems = [10.0, 20.0, 30.0, 40.0]
      enqueue_op = q.enqueue_many((elems,))
      close_op = q.close()
      dequeued_t = q.dequeue_up_to(3)

      enqueue_op.run()

      results = []

      def dequeue():
        results.extend(sess.run(dequeued_t))
        self.assertEquals(3, len(results))
        results.extend(sess.run(dequeued_t))
        self.assertEquals(4, len(results))

      dequeue_thread = self.checkedThread(target=dequeue)
      dequeue_thread.start()
      # The close_op should run after the dequeue_thread has blocked.
      # TODO(mrry): Figure out how to do this without sleeping.
      time.sleep(0.1)
      close_op.run()
      dequeue_thread.join()
      self.assertItemsEqual(results, elems)
Code Example #15
  def testEnqueue(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32)
      enqueue_op = q.enqueue((10.0,))
      self.assertAllEqual(0, q.size().eval())
      enqueue_op.run()
      self.assertAllEqual(1, q.size().eval())
Code Example #16
  def testBlockingDequeueUpToSmallerThanMinAfterDequeue(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(
          capacity=10,
          min_after_dequeue=2,
          dtypes=dtypes_lib.float32,
          shapes=((),))
      elems = [10.0, 20.0, 30.0, 40.0]
      enqueue_op = q.enqueue_many((elems,))
      close_op = q.close()
      dequeued_t = q.dequeue_up_to(3)

      enqueue_op.run()

      results = []

      def dequeue():
        results.extend(sess.run(dequeued_t))
        self.assertEquals(3, len(results))
        # min_after_dequeue is 2, we ask for 3 elements, and we end up only
        # getting the remaining 1.
        results.extend(sess.run(dequeued_t))
        self.assertEquals(4, len(results))

      dequeue_thread = self.checkedThread(target=dequeue)
      dequeue_thread.start()
      # The close_op should run after the dequeue_thread has blocked.
      # TODO(mrry): Figure out how to do this without sleeping.
      time.sleep(0.1)
      close_op.run()
      dequeue_thread.join()
      self.assertItemsEqual(results, elems)
Code Example #17
  def testBlockingDequeueUpTo(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
      elems = [10.0, 20.0, 30.0, 40.0]
      enqueue_op = q.enqueue_many((elems,))
      dequeued_t = q.dequeue_up_to(4)

      dequeued_elems = []

      def enqueue():
        # The enqueue_op should run after the dequeue op has blocked.
        # TODO(mrry): Figure out how to do this without sleeping.
        time.sleep(0.1)
        sess.run(enqueue_op)

      def dequeue():
        dequeued_elems.extend(sess.run(dequeued_t).tolist())

      enqueue_thread = self.checkedThread(target=enqueue)
      dequeue_thread = self.checkedThread(target=dequeue)
      enqueue_thread.start()
      dequeue_thread.start()
      enqueue_thread.join()
      dequeue_thread.join()

      self.assertItemsEqual(elems, dequeued_elems)
Code Example #18
  def testParallelDequeueUpToRandomPartition(self):
    with self.cached_session() as sess:
      dequeue_sizes = [random.randint(50, 150) for _ in xrange(10)]
      total_elements = sum(dequeue_sizes)
      q = data_flow_ops.RandomShuffleQueue(
          total_elements, 0, dtypes_lib.float32, shapes=())

      elems = [10.0 * x for x in xrange(total_elements)]
      enqueue_op = q.enqueue_many((elems,))
      dequeue_ops = [q.dequeue_up_to(size) for size in dequeue_sizes]

      enqueue_op.run()

      # Dequeue random number of items in parallel on 10 threads.
      dequeued_elems = []

      def dequeue(dequeue_op):
        dequeued_elems.extend(sess.run(dequeue_op))

      threads = []
      for dequeue_op in dequeue_ops:
        threads.append(self.checkedThread(target=dequeue, args=(dequeue_op,)))
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()
      self.assertItemsEqual(elems, dequeued_elems)
Code Example #19
  def testBlockingDequeueManyFromClosedQueueWithElementsRemaining(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
      elems = [10.0, 20.0, 30.0, 40.0]
      enqueue_op = q.enqueue_many((elems,))
      close_op = q.close()
      dequeued_t = q.dequeue_many(3)
      cleanup_dequeue_t = q.dequeue_many(q.size())

      enqueue_op.run()

      results = []

      def dequeue():
        results.extend(sess.run(dequeued_t))
        self.assertEqual(len(results), 3)
        # Expect the operation to fail due to the queue being closed.
        with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
                                     "is closed and has insufficient"):
          sess.run(dequeued_t)
        # While the last dequeue failed, we want to ensure that it returns
        # any elements that it potentially reserved to dequeue. Thus the
        # next cleanup should return a single element.
        results.extend(sess.run(cleanup_dequeue_t))

      dequeue_thread = self.checkedThread(target=dequeue)
      dequeue_thread.start()
      # The close_op should run after the dequeue_thread has blocked.
      # TODO(mrry): Figure out how to do this without sleeping.
      time.sleep(0.1)
      close_op.run()
      dequeue_thread.join()
      self.assertEqual(len(results), 4)
Code Example #20
  def testBlockingEnqueueManyToFullQueue(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(4, 0, dtypes_lib.float32, ((),))
      elems = [10.0, 20.0, 30.0, 40.0]
      enqueue_op = q.enqueue_many((elems,))
      blocking_enqueue_op = q.enqueue_many(([50.0, 60.0],))
      dequeued_t = q.dequeue()

      enqueue_op.run()

      def blocking_enqueue():
        sess.run(blocking_enqueue_op)

      thread = self.checkedThread(target=blocking_enqueue)
      thread.start()
      # The dequeue ops should run after the blocking_enqueue_op has blocked.
      # TODO(mrry): Figure out how to do this without sleeping.
      time.sleep(0.1)

      results = []
      for _ in elems:
        time.sleep(0.01)
        results.append(dequeued_t.eval())
      results.append(dequeued_t.eval())
      results.append(dequeued_t.eval())
      self.assertItemsEqual(elems + [50.0, 60.0], results)
      # There wasn't room for 50.0 or 60.0 in the queue when the first
      # element was dequeued.
      self.assertNotEqual(50.0, results[0])
      self.assertNotEqual(60.0, results[0])
      # Similarly for 60.0 and the second element.
      self.assertNotEqual(60.0, results[1])
      thread.join()
Code Example #21
  def testBlockingDequeueManyFromClosedQueue(self):
    with self.cached_session() as sess:
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
      elems = [10.0, 20.0, 30.0, 40.0]
      enqueue_op = q.enqueue_many((elems,))
      close_op = q.close()
      dequeued_t = q.dequeue_many(4)

      enqueue_op.run()

      progress = []  # Must be mutable

      def dequeue():
        self.assertItemsEqual(elems, sess.run(dequeued_t))
        progress.append(1)
        # Expect the operation to fail due to the queue being closed.
        with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
                                     "is closed and has insufficient"):
          sess.run(dequeued_t)
        progress.append(2)

      self.assertEqual(len(progress), 0)
      dequeue_thread = self.checkedThread(target=dequeue)
      dequeue_thread.start()
      # The close_op should run after the dequeue_thread has blocked.
      # TODO(mrry): Figure out how to do this without sleeping.
      for _ in range(100):
        time.sleep(0.01)
        if len(progress) == 1:
          break
      self.assertEqual(len(progress), 1)
      time.sleep(0.01)
      close_op.run()
      dequeue_thread.join()
      self.assertEqual(len(progress), 2)
Code Example #22
def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128):
    """Reads multiple records in parallel from data_sources using n readers.

  It uses a ParallelReader to read from multiple files in parallel using
  multiple readers created using `reader_class` with `reader_kwargs`.

  If `shuffle` is True the common_queue will be a RandomShuffleQueue,
  otherwise it will be a FIFOQueue.

  Usage:
      data_sources = ['path_to/train*']
      key, value = parallel_read(data_sources, tf.TFRecordReader, num_readers=4)

  Args:
    data_sources: a list/tuple of files or the location of the data, e.g.
      /cns/../train@128, /cns/.../train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses, e.g. TFRecordReader.
    num_epochs: The number of times each data source is read. If left as None,
        the data will be cycled through indefinitely.
    num_readers: an integer, number of Readers to create.
    reader_kwargs: an optional dict of kwargs for the reader.
    shuffle: boolean, whether to shuffle the files and the records by using a
      RandomShuffleQueue as the common_queue.
    dtypes: A list of types. The length of dtypes must equal the number
        of elements in each record. If it is None it will default to
        [tf.string, tf.string] for (key, value).
    capacity: integer, capacity of the common_queue.
    min_after_dequeue: integer, minimum number of records in the common_queue
      after dequeue. Needed for a good shuffle.

  Returns:
    key, value: a tuple of keys and values from the data_sources.
  """
    data_files = get_data_files(data_sources)
    with ops.name_scope('parallel_read'):
        filename_queue = tf_input.string_input_producer(data_files,
                                                        num_epochs=num_epochs,
                                                        shuffle=shuffle)
        dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
        if shuffle:
            common_queue = data_flow_ops.RandomShuffleQueue(
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                dtypes=dtypes)
        else:
            common_queue = data_flow_ops.FIFOQueue(capacity=capacity,
                                                   dtypes=dtypes)

        return ParallelReader(reader_class,
                              common_queue,
                              num_readers=num_readers,
                              reader_kwargs=reader_kwargs).read(filename_queue)
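
A minimal consumption sketch for parallel_read above, assuming TF1 graph mode with queue runners; the 'path_to/train*' glob (borrowed from the docstring's Usage note) and the choice of tf.TFRecordReader are illustrative assumptions.

import tensorflow as tf

key, value = parallel_read(
    ['path_to/train*'],    # data_sources; must match existing record files
    tf.TFRecordReader,     # any io_ops.ReaderBase subclass
    num_readers=4,
    shuffle=True)          # records are drawn from a RandomShuffleQueue

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # epoch counter, if num_epochs is set
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        k, v = sess.run([key, value])  # one (key, serialized record) pair
    finally:
        coord.request_stop()
        coord.join(threads)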
Code Example #23
  def testEnqueueWithShape(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(
          10, 5, dtypes_lib.float32, shapes=tensor_shape.TensorShape([3, 2]))
      enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
      enqueue_correct_op.run()
      self.assertAllEqual(1, q.size().eval())
      with self.assertRaises(ValueError):
        q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
Code Example #24
  def testEmptyDequeueUpTo(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, shapes=())
      enqueue_op = q.enqueue((10.0,))
      dequeued_t = q.dequeue_up_to(0)

      self.assertEqual([], dequeued_t.eval().tolist())
      enqueue_op.run()
      self.assertEqual([], dequeued_t.eval().tolist())
Code Example #25
  def testHighDimension(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.int32, (
          (4, 4, 4, 4)))
      elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
      enqueue_op = q.enqueue_many((elems,))
      dequeued_t = q.dequeue_many(10)

      enqueue_op.run()
      self.assertItemsEqual(dequeued_t.eval().tolist(), elems.tolist())
Code Example #26
  def testEmptyEnqueueMany(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32)
      empty_t = constant_op.constant(
          [], dtype=dtypes_lib.float32, shape=[0, 2, 3])
      enqueue_op = q.enqueue_many((empty_t,))
      size_t = q.size()

      self.assertEqual(0, size_t.eval())
      enqueue_op.run()
      self.assertEqual(0, size_t.eval())
Code Example #27
File: read_data.py  Project: wgwangang/mycodes
def string_int_pair_producer(strings_ints):
    capacity = len(strings_ints[0])

    q = data_flow_ops.RandomShuffleQueue(capacity=capacity,
                                         dtypes=[tf.string, tf.int64],
                                         min_after_dequeue=0,
                                         name="name_label_queue")
    enq = q.enqueue_many(strings_ints)
    queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq]))

    return q
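
A small consumption sketch for string_int_pair_producer above, assuming TF1 graph mode; the (name, label) data below is hypothetical.

import tensorflow as tf

names_and_labels = (["a.jpg", "b.jpg", "c.jpg"], [0, 1, 0])  # hypothetical pairs
q = string_int_pair_producer(names_and_labels)
name_t, label_t = q.dequeue()  # one shuffled (name, label) pair per session run

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Starts the QueueRunner registered inside string_int_pair_producer.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run([name_t, label_t]))
    coord.request_stop()
    coord.join(threads)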
Code Example #28
  def testEnqueueToClosedQueue(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(10, 4, dtypes_lib.float32)
      enqueue_op = q.enqueue((10.0,))
      close_op = q.close()

      enqueue_op.run()
      close_op.run()

      # Expect the operation to fail due to the queue being closed.
      with self.assertRaisesRegexp(errors_impl.CancelledError, "is closed"):
        enqueue_op.run()
Code Example #29
  def testDequeueUpToNoBlocking(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      enqueue_op = q.enqueue_many((elems,))
      dequeued_t = q.dequeue_up_to(5)

      enqueue_op.run()

      results = dequeued_t.eval().tolist()
      results.extend(dequeued_t.eval())
      self.assertItemsEqual(elems, results)
Code Example #30
  def testQueueSizeAfterEnqueueAndDequeue(self):
    with self.cached_session():
      q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
      enqueue_op = q.enqueue((10.0,))
      dequeued_t = q.dequeue()
      size = q.size()
      self.assertEqual([], size.get_shape())

      enqueue_op.run()
      self.assertEqual(1, size.eval())
      dequeued_t.op.run()
      self.assertEqual(0, size.eval())