Exemplo n.º 1
0
    def testConflictingRecordKeyItem(self):
        """Publishing the record key under an existing item name must fail.

        `record_key='image'` collides with the 'image' item that the test
        dataset presumably already provides, so constructing the provider is
        expected to raise ValueError.
        """
        import shutil  # local import: only needed for temp-dir cleanup

        # NOTE: the first positional argument of mkdtemp is the *suffix*.
        dataset_dir = tempfile.mkdtemp('tfrecord_dataset')
        # mkdtemp never deletes the directory it creates; register cleanup so
        # repeated runs do not leak test data into the system temp directory.
        self.addCleanup(shutil.rmtree, dataset_dir, ignore_errors=True)

        with self.cached_session():
            with self.assertRaises(ValueError):
                dataset_data_provider.DatasetDataProvider(
                    _create_tfrecord_dataset(dataset_dir), record_key='image')
Exemplo n.º 2
0
    def testTFRecordSeparateGetDataset(self):
        """Items fetched with separate `provider.get` calls evaluate correctly.

        'image' and 'label' are requested in two independent `get` calls. The
        image is resized to (height, width), and both tensors are evaluated
        together in a single `sess.run`. The test asserts the resulting image
        shape is [height, width, 3] and the label shape is [1].
        """
        import shutil  # local import: only needed for temp-dir cleanup

        # NOTE: the first positional argument of mkdtemp is the *suffix*.
        dataset_dir = tempfile.mkdtemp('tfrecord_separate_get')
        # mkdtemp never deletes the directory it creates; register cleanup so
        # the TFRecord files written below do not leak between runs.
        self.addCleanup(shutil.rmtree, dataset_dir, ignore_errors=True)

        height = 300
        width = 280

        with self.cached_session():
            provider = dataset_data_provider.DatasetDataProvider(
                _create_tfrecord_dataset(dataset_dir))
        # Each item is requested through its own get() call on the same provider.
        [image] = provider.get(['image'])
        [label] = provider.get(['label'])
        image = _resize_image(image, height, width)

        with session.Session('') as sess:
            # Queue runners must be active while evaluating reader-backed tensors.
            with queues.QueueRunners(sess):
                image, label = sess.run([image, label])
            self.assertListEqual([height, width, 3], list(image.shape))
            self.assertListEqual([1], list(label.shape))
Exemplo n.º 3
0
    def testTFRecordDataset(self):
        """The provider yields record key, image and label with expected forms.

        The record key is expected to decode to "<data_source>:<digits>", where
        <data_source> equals the dataset's first data source. The resized image
        must have shape [height, width, 3], and the label must have shape [1].
        """
        import shutil  # local import: only needed for temp-dir cleanup

        # NOTE: the first positional argument of mkdtemp is the *suffix*.
        dataset_dir = tempfile.mkdtemp('tfrecord_dataset')
        # mkdtemp never deletes the directory it creates; register cleanup so
        # the TFRecord files written below do not leak between runs.
        self.addCleanup(shutil.rmtree, dataset_dir, ignore_errors=True)

        height = 300
        width = 280

        with self.cached_session():
            test_dataset = _create_tfrecord_dataset(dataset_dir)
            provider = dataset_data_provider.DatasetDataProvider(test_dataset)
            key, image, label = provider.get(['record_key', 'image', 'label'])
            image = _resize_image(image, height, width)

            with session.Session('') as sess:
                # Queue runners must be active while evaluating reader-backed
                # tensors.
                with queues.QueueRunners(sess):
                    key, image, label = sess.run([key, image, label])
            # The evaluated key is bytes; decode before splitting "source:index".
            split_key = key.decode('utf-8').split(':')
            self.assertEqual(2, len(split_key))
            self.assertEqual(test_dataset.data_sources[0], split_key[0])
            self.assertTrue(split_key[1].isdigit())
            self.assertListEqual([height, width, 3], list(image.shape))
            self.assertListEqual([1], list(label.shape))