def testConflictingRecordKeyItem(self):
    """DatasetDataProvider must raise ValueError for a clashing record_key.

    'image' is used as an item name of the dataset built by
    _create_tfrecord_dataset (it is fetched with provider.get(['image'])),
    so requesting it as the record_key name is expected to be rejected.
    """
    # tempfile.mkdtemp's first positional parameter is `suffix`, and without
    # `dir` the directory is created in the system temp root and never
    # removed. Create it under the test's own temp dir instead.
    dataset_dir = tempfile.mkdtemp(
        prefix='tfrecord_dataset', dir=self.get_temp_dir())
    with self.cached_session():
        with self.assertRaises(ValueError):
            dataset_data_provider.DatasetDataProvider(
                _create_tfrecord_dataset(dataset_dir), record_key='image')
def testTFRecordSeparateGetDataset(self):
    """Items requested through separate provider.get() calls run together.

    Writes a TFRecord dataset, fetches 'image' and 'label' with two
    independent get() calls, resizes the image to (height, width), and
    checks the shapes of one evaluated example.
    """
    # tempfile.mkdtemp's first positional parameter is `suffix`, and without
    # `dir` the directory is created in the system temp root and never
    # removed. Create it under the test's own temp dir instead.
    dataset_dir = tempfile.mkdtemp(
        prefix='tfrecord_separate_get', dir=self.get_temp_dir())
    height = 300
    width = 280
    with self.cached_session():
        provider = dataset_data_provider.DatasetDataProvider(
            _create_tfrecord_dataset(dataset_dir))
        [image] = provider.get(['image'])
        [label] = provider.get(['label'])
        image = _resize_image(image, height, width)
        with session.Session('') as sess:
            # Presumably starts the input queue runners for `sess` and stops
            # them on exit; verify against the `queues` helper.
            with queues.QueueRunners(sess):
                image, label = sess.run([image, label])
            self.assertListEqual([height, width, 3], list(image.shape))
            self.assertListEqual([1], list(label.shape))
def testTFRecordDataset(self):
    """Fetch record_key, image and label together and validate them.

    The evaluated record key must have the form '<data source>:<digits>',
    where '<data source>' equals the dataset's first data source. The
    resized image and the label must have the expected shapes.
    """
    # tempfile.mkdtemp's first positional parameter is `suffix`, and without
    # `dir` the directory is created in the system temp root and never
    # removed. Create it under the test's own temp dir instead.
    dataset_dir = tempfile.mkdtemp(
        prefix='tfrecord_dataset', dir=self.get_temp_dir())
    height = 300
    width = 280
    with self.cached_session():
        test_dataset = _create_tfrecord_dataset(dataset_dir)
        provider = dataset_data_provider.DatasetDataProvider(test_dataset)
        key, image, label = provider.get(['record_key', 'image', 'label'])
        image = _resize_image(image, height, width)
        with session.Session('') as sess:
            with queues.QueueRunners(sess):
                key, image, label = sess.run([key, image, label])
            # Split on the LAST colon only: the data-source path itself may
            # contain ':' (e.g. a Windows drive letter), which a plain
            # split(':') would break into extra pieces.
            split_key = key.decode('utf-8').rsplit(':', 1)
            self.assertEqual(2, len(split_key))
            self.assertEqual(test_dataset.data_sources[0], split_key[0])
            self.assertTrue(split_key[1].isdigit())
            self.assertListEqual([height, width, 3], list(image.shape))
            self.assertListEqual([1], list(label.shape))