def testOutOfRangeError(self):
  """single_pass_read must raise OutOfRangeError once its input is exhausted.

  Eleven reads are attempted against a single TFRecord file. The test expects
  one of them to fail with OutOfRangeError. This presumes the file written by
  test_utils.create_tfrecord_files holds fewer than 11 records (confirm there).
  """
  with self.cached_session():
    (record_file,) = test_utils.create_tfrecord_files(
        self.get_temp_dir(), num_files=1)

  record_key, record_value = parallel_reader.single_pass_read(
      record_file, reader_class=io_ops.TFRecordReader)
  # single_pass_read needs local variables initialised before reading.
  local_init = variables.local_variables_initializer()

  with self.cached_session() as sess:
    sess.run(local_init)
    with queues.QueueRunners(sess):
      attempted_reads = 11
      with self.assertRaises(errors_impl.OutOfRangeError):
        for _ in range(attempted_reads):
          sess.run([record_key, record_value])
def testTFRecordSeparateGetDataset(self):
  """Fetching 'image' and 'label' through separate provider.get calls works.

  The image is resized to [height, width]. After one session run, the image is
  expected to have shape [height, width, 3] and the label shape [1].
  """
  # mkdtemp's first positional parameter is `suffix`, and without `dir` the
  # directory lands in the system temp dir and is never cleaned up. Create it
  # under the test's managed temp dir, with the intended name as a prefix.
  dataset_dir = tempfile.mkdtemp(
      prefix='tfrecord_separate_get', dir=self.get_temp_dir())

  height = 300
  width = 280

  with self.cached_session():
    provider = dataset_data_provider.DatasetDataProvider(
        _create_tfrecord_dataset(dataset_dir))
    # Deliberately two separate get() calls: that is what this test covers.
    [image] = provider.get(['image'])
    [label] = provider.get(['label'])
    image = _resize_image(image, height, width)

    with session.Session('') as sess:
      with queues.QueueRunners(sess):
        image_value, label_value = sess.run([image, label])

    self.assertListEqual([height, width, 3], list(image_value.shape))
    self.assertListEqual([1], list(label_value.shape))
def testTFRecordReader(self):
  """Reads nine records with single_pass_read and inspects their keys.

  Each fetched record key, converted with str(), must contain 'flowers'.
  """
  with self.cached_session():
    (record_file,) = test_utils.create_tfrecord_files(
        self.get_temp_dir(), num_files=1)

  record_key, record_value = parallel_reader.single_pass_read(
      record_file, reader_class=io_ops.TFRecordReader)
  # single_pass_read needs local variables initialised before reading.
  local_init = variables.local_variables_initializer()

  with self.cached_session() as sess:
    sess.run(local_init)
    with queues.QueueRunners(sess):
      expected_reads = 9
      # sess.run is called exactly expected_reads times, in order. Only the
      # key half of each fetched (key, value) pair is inspected.
      flower_hits = sum(
          1 for _ in range(expected_reads)
          if 'flowers' in str(sess.run([record_key, record_value])[0]))
      self.assertGreater(flower_hits, 0)
      self.assertEqual(flower_hits, expected_reads)
def testTFRecordDataset(self):
  """'record_key', 'image' and 'label' can be fetched together from a provider.

  The fetched key must decode to '<data source>:<digits>', where the data
  source equals test_dataset.data_sources[0]. The resized image must have
  shape [height, width, 3] and the label shape [1].
  """
  # mkdtemp's first positional parameter is `suffix`, and without `dir` the
  # directory lands in the system temp dir and is never cleaned up. Create it
  # under the test's managed temp dir, with the intended name as a prefix.
  dataset_dir = tempfile.mkdtemp(
      prefix='tfrecord_dataset', dir=self.get_temp_dir())

  height = 300
  width = 280

  with self.cached_session():
    test_dataset = _create_tfrecord_dataset(dataset_dir)
    provider = dataset_data_provider.DatasetDataProvider(test_dataset)
    key, image, label = provider.get(['record_key', 'image', 'label'])
    image = _resize_image(image, height, width)

    with session.Session('') as sess:
      with queues.QueueRunners(sess):
        key_value, image_value, label_value = sess.run([key, image, label])

    # The key is returned as bytes; after decoding it is split on ':' into
    # the source path and the record index.
    split_key = key_value.decode('utf-8').split(':')
    self.assertEqual(2, len(split_key))
    self.assertEqual(test_dataset.data_sources[0], split_key[0])
    self.assertTrue(split_key[1].isdigit())
    self.assertListEqual([height, width, 3], list(image_value.shape))
    self.assertListEqual([1], list(label_value.shape))