def test_TFRecordSeparateGetDataset(self):
    """Fetch 'image' and 'label' through two separate provider.get calls.

    Tensors requested one at a time from the same DatasetDataProvider are
    evaluated together in a single sess.run. The test then checks that the
    resized image has shape [height, width, 3] and the label has shape [1].
    """
    target_height, target_width = 300, 280
    tmp_prefix = os.path.join(self.get_temp_dir(), 'tfrecord_separate_get')
    dataset_dir = tempfile.mkdtemp(prefix=tmp_prefix)

    with self.test_session():
        provider = DatasetDataProvider(_create_tfrecord_dataset(dataset_dir))
        # Single-element unpacking: fails immediately if get() returns a
        # list of any other length.
        (image_tensor,) = provider.get(['image'])
        (label_tensor,) = provider.get(['label'])
        resized_tensor = _resize_image(image_tensor, target_height, target_width)

        # The queue runners are bound to the session opened just before them.
        with session.Session('') as sess, queues.QueueRunners(sess):
            image_value, label_value = sess.run([resized_tensor, label_tensor])
            self.assertListEqual(
                [target_height, target_width, 3], list(image_value.shape))
            self.assertListEqual([1], list(label_value.shape))
def test_TFRecordDataset(self):
    """Fetch 'image' and 'label' with a single provider.get call.

    Builds a TFRecord dataset in a fresh temp directory, then resizes the
    image to 300x280. After evaluating both tensors in one sess.run, it
    checks that the image has shape [height, width, 3] and the label has
    shape [1].
    """
    height = 300
    width = 280
    dataset_dir = tempfile.mkdtemp(
        prefix=os.path.join(self.get_temp_dir(), 'tfrecord_dataset'))

    with self.test_session():
        test_dataset = _create_tfrecord_dataset(dataset_dir)
        provider = DatasetDataProvider(test_dataset)
        # TODO(review): this test used to also request 'record_key'. It then
        # checked that the key decodes (utf-8) to
        # '<test_dataset.data_sources[0]>:<digits>'. That check was disabled
        # by commenting it out. Restore it once 'record_key' support is
        # confirmed for this provider, or drop this note.
        image, label = provider.get(['image', 'label'])
        image = _resize_image(image, height, width)

        with session.Session('') as sess:
            # Queue runners must be running for sess.run to dequeue records.
            with queues.QueueRunners(sess):
                image, label = sess.run([image, label])
                self.assertListEqual([height, width, 3], list(image.shape))
                self.assertListEqual([1], list(label.shape))