Пример #1
0
    def test_TFRecordSeparateGetDataset(self):
        """Request 'image' and 'label' through two separate provider.get calls.

        Builds a TFRecord dataset in a fresh temporary directory and wraps it in
        a DatasetDataProvider. It fetches each item with its own ``get`` call and
        resizes the image with ``_resize_image``. It then evaluates both tensors
        in a session inside a ``queues.QueueRunners`` context. Finally it checks
        that the image has shape [height, width, 3] and the label has shape [1].
        """
        temp_prefix = os.path.join(self.get_temp_dir(), 'tfrecord_separate_get')
        dataset_dir = tempfile.mkdtemp(prefix=temp_prefix)

        height, width = 300, 280

        with self.test_session():
            dataset = _create_tfrecord_dataset(dataset_dir)
            provider = DatasetDataProvider(dataset)

        # Each item is requested on its own; this is what distinguishes this
        # test from fetching several items in one get() call.
        (image_tensor,) = provider.get(['image'])
        (label_tensor,) = provider.get(['label'])
        image_tensor = _resize_image(image_tensor, height, width)

        with session.Session('') as sess:
            with queues.QueueRunners(sess):
                image_value, label_value = sess.run([image_tensor, label_tensor])
            self.assertListEqual([height, width, 3], list(image_value.shape))
            self.assertListEqual([1], list(label_value.shape))
Пример #2
0
    def test_TFRecordDataset(self):
        """Request 'image' and 'label' together through one provider.get call.

        Builds a TFRecord dataset in a fresh temporary directory and wraps it in
        a DatasetDataProvider. It fetches both items in a single ``get`` call and
        resizes the image with ``_resize_image``. It then evaluates the tensors in
        a session inside a ``queues.QueueRunners`` context. Finally it checks
        that the image has shape [height, width, 3] and the label has shape [1].
        """
        dataset_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), 'tfrecord_dataset'))

        height, width = 300, 280

        with self.test_session():
            test_dataset = _create_tfrecord_dataset(dataset_dir)
            provider = DatasetDataProvider(test_dataset)
            # NOTE(review): a 'record_key' check is disabled in this test. It
            # requested 'record_key' alongside 'image' and 'label', then asserted
            # that the decoded key splits on ':' into exactly two parts: the
            # first equal to test_dataset.data_sources[0], the second all
            # digits. Re-enable it once DatasetDataProvider is confirmed to
            # support the 'record_key' item.
            image_tensor, label_tensor = provider.get(['image', 'label'])
            image_tensor = _resize_image(image_tensor, height, width)

            with session.Session('') as sess:
                with queues.QueueRunners(sess):
                    image_value, label_value = sess.run(
                        [image_tensor, label_tensor])
            self.assertListEqual([height, width, 3], list(image_value.shape))
            self.assertListEqual([1], list(label_value.shape))