def testOutOfRangeError(self):
  """Checks that reading past the single pass raises OutOfRangeError.

  Generates one TFRecord file in the test's temp directory and wraps it
  with `parallel_reader.single_pass_read`. It then issues up to 11 reads
  and expects one of them to raise `errors_impl.OutOfRangeError`.

  NOTE(review): assumes the generated file holds fewer than 11 records.
  Confirm this against the defaults of `test_utils.create_tfrecord_files`.
  """
  with self.test_session():
    (record_file,) = test_utils.create_tfrecord_files(
        self.get_temp_dir(), num_files=1)

  record_key, record_value = parallel_reader.single_pass_read(
      record_file, reader_class=io_ops.TFRecordReader)
  # single_pass_read presumably creates local variables (e.g. an epoch
  # counter), so local initialization is needed before reading. Confirm.
  init_locals = variables.local_variables_initializer()
  # Intended to exceed the number of records available in one pass.
  read_budget = 11

  with self.test_session() as session:
    session.run(init_locals)
    with queues.QueueRunners(session), \
        self.assertRaises(errors_impl.OutOfRangeError):
      for _attempt in range(read_budget):
        session.run([record_key, record_value])
예제 #2
0
    def testOutOfRangeError(self):
        """Checks that reading past the single pass raises OutOfRangeError.

        Generates one TFRecord file in the test's temp directory and wraps
        it with `parallel_reader.single_pass_read`. It then issues up to 11
        reads and expects one of them to raise
        `errors_impl.OutOfRangeError`.

        NOTE(review): assumes the generated file holds fewer than 11
        records. Confirm this against the defaults of
        `test_utils.create_tfrecord_files`.
        """
        with self.test_session():
            (record_file,) = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=1)

        record_key, record_value = parallel_reader.single_pass_read(
            record_file, reader_class=io_ops.TFRecordReader)
        # single_pass_read presumably creates local variables (e.g. an
        # epoch counter), so local initialization is needed first. Confirm.
        init_locals = variables.local_variables_initializer()
        # Intended to exceed the number of records available in one pass.
        read_budget = 11

        with self.test_session() as session:
            session.run(init_locals)
            with queues.QueueRunners(session), \
                    self.assertRaises(errors_impl.OutOfRangeError):
                for _attempt in range(read_budget):
                    session.run([record_key, record_value])
  def testTFRecordReader(self):
    """Checks that single_pass_read yields keys from the generated file.

    Generates one TFRecord file in the test's temp directory, reads 9
    (key, value) pairs through `parallel_reader.single_pass_read`, and
    asserts that every key's string form contains 'flowers'.

    NOTE(review): assumes the generated file holds at least 9 records whose
    keys mention 'flowers'. Confirm this against
    `test_utils.create_tfrecord_files`.
    """
    with self.test_session():
      [tfrecord_path] = test_utils.create_tfrecord_files(
          self.get_temp_dir(), num_files=1)

    key, value = parallel_reader.single_pass_read(
        tfrecord_path, reader_class=io_ops.TFRecordReader)
    init_op = variables.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(init_op)
      with queues.QueueRunners(sess):
        flowers = 0
        num_reads = 9
        for _ in range(num_reads):
          # Index the result instead of unpacking into `_`, which would
          # reassign the loop variable.
          current_key = sess.run([key, value])[0]
          if 'flowers' in str(current_key):
            flowers += 1
        self.assertGreater(flowers, 0)
        # assertEquals is a deprecated alias, removed in Python 3.12.
        self.assertEqual(flowers, num_reads)
예제 #4
0
    def testTFRecordReader(self):
        """Checks that single_pass_read yields keys from the generated file.

        Generates one TFRecord file in the test's temp directory, reads 9
        (key, value) pairs through `parallel_reader.single_pass_read`, and
        asserts that every key's string form contains 'flowers'.

        NOTE(review): assumes the generated file holds at least 9 records
        whose keys mention 'flowers'. Confirm this against
        `test_utils.create_tfrecord_files`.
        """
        with self.test_session():
            [tfrecord_path
             ] = test_utils.create_tfrecord_files(self.get_temp_dir(),
                                                  num_files=1)

        key, value = parallel_reader.single_pass_read(
            tfrecord_path, reader_class=io_ops.TFRecordReader)
        init_op = variables.local_variables_initializer()

        with self.test_session() as sess:
            sess.run(init_op)
            with queues.QueueRunners(sess):
                flowers = 0
                num_reads = 9
                for _ in range(num_reads):
                    # Index the result instead of unpacking into `_`, which
                    # would reassign the loop variable.
                    current_key = sess.run([key, value])[0]
                    if 'flowers' in str(current_key):
                        flowers += 1
                self.assertGreater(flowers, 0)
                # assertEquals is a deprecated alias, removed in Python 3.12.
                self.assertEqual(flowers, num_reads)