Ejemplo n.º 1
0
    def make_data_provider(self, **kwargs):
        """Build the DatasetDataProvider that feeds this input pipeline.

        A TFSequenceExampleDecoder is assembled from the context features,
        sequence features and item handlers returned by the ``_create_*``
        helper methods. It is wrapped in a Dataset that reads
        ``self.data_files`` with ``tf.TFRecordReader``.

        Args:
            **kwargs: Extra keyword arguments forwarded unchanged to
                DatasetDataProvider.

        Returns:
            A DatasetDataProvider built with this pipeline's ``shuffle`` and
            ``num_epochs`` settings.
        """
        decoder = TFSequenceExampleDecoder(
            context_keys_to_features=self._create_context_features(),
            sequence_keys_to_features=self._create_sequence_features(),
            items_to_handlers=self._create_items_to_handlers())

        meta = self.meta_data
        # Every metadata entry is optional here: missing keys yield None
        # (or an empty dict for the item descriptions).
        samples_for_mode = meta.get('num_samples', {}).get(self.mode)
        class_count = meta.get('num_classes')
        descriptions = meta.get('items_to_descriptions', {})
        label_names = meta.get('labels_to_classes')

        dataset = Dataset(data_sources=self.data_files,
                          reader=tf.TFRecordReader,
                          decoder=decoder,
                          num_samples=samples_for_mode,
                          num_classes=class_count,
                          items_to_descriptions=descriptions,
                          meta_data=meta,
                          labels_to_names=label_names)

        provider = DatasetDataProvider(dataset=dataset,
                                       shuffle=self.shuffle,
                                       num_epochs=self.num_epochs,
                                       **kwargs)
        return provider
Ejemplo n.º 2
0
    def make_data_provider(self, **kwargs):
        """Creates DataProvider instance for this input pipeline.

        Two string features are parsed from each record: ``self.source_field``
        (no default value) and ``self.target_field`` (default ``""``). Each is
        handed to its own SplitTokensDecoder, and four items are exposed:
        ``source_token``, ``source_len``, ``target_token`` and ``target_len``.

        Args:
            **kwargs: Additional keyword arguments passed to the
                DatasetDataProvider.

        Returns:
            A DatasetDataProvider reading ``self.files`` with
            ``tf.TFRecordReader``, using this pipeline's ``shuffle`` and
            ``num_epochs`` settings.
        """
        splitter_source = SplitTokensDecoder(
            tokens_feature_name='source_token',
            length_feature_name='source_len',
            append_token='SEQUENCE_END',
            delimiter=self.source_delimiter)

        # Unlike the source splitter, the target splitter is also given a
        # prepend token.
        splitter_target = SplitTokensDecoder(
            tokens_feature_name='target_token',
            length_feature_name='target_len',
            prepend_token='SEQUENCE_START',
            append_token='SEQUENCE_END',
            delimiter=self.target_delimiter)

        source_field = self.source_field
        target_field = self.target_field

        keys_to_features = {
            source_field: tf.FixedLenFeature((), tf.string),
            target_field: tf.FixedLenFeature((),
                                             tf.string,
                                             default_value=""),
        }

        def _split_handler(field, splitter, item):
            """Return a handler that decodes `field` with `splitter` and keeps `item`."""
            # The callback argument is deliberately not named `dict`, so the
            # builtin is not shadowed inside the lambda.
            return tfslim.tfexample_decoder.ItemHandlerCallback(
                keys=[field],
                func=lambda keys_to_tensors: splitter.decode(
                    keys_to_tensors[field], [item])[0])

        items_to_handlers = {
            'source_token':
            _split_handler(source_field, splitter_source, 'source_token'),
            'source_len':
            _split_handler(source_field, splitter_source, 'source_len'),
            'target_token':
            _split_handler(target_field, splitter_target, 'target_token'),
            'target_len':
            _split_handler(target_field, splitter_target, 'target_len'),
        }

        decoder = TFExampleDecoder(keys_to_features, items_to_handlers)

        dataset = Dataset(data_sources=self.files,
                          reader=tf.TFRecordReader,
                          decoder=decoder)

        return DatasetDataProvider(dataset=dataset,
                                   shuffle=self.shuffle,
                                   num_epochs=self.num_epochs,
                                   **kwargs)
Ejemplo n.º 3
0
    def test_TFRecordSeparateGetDataset(self):
        """Request 'image' and 'label' in separate ``get`` calls and check shapes."""
        height = 300
        width = 280
        dataset_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), 'tfrecord_separate_get'))

        with self.test_session():
            provider = DatasetDataProvider(_create_tfrecord_dataset(dataset_dir))

        # Each item is fetched through its own call to provider.get.
        (image_tensor,) = provider.get(['image'])
        (label_tensor,) = provider.get(['label'])
        image_tensor = _resize_image(image_tensor, height, width)

        with session.Session('') as sess:
            with queues.QueueRunners(sess):
                image_value, label_value = sess.run([image_tensor, label_tensor])
            self.assertListEqual([height, width, 3], list(image_value.shape))
            self.assertListEqual([1], list(label_value.shape))
Ejemplo n.º 4
0
    def make_data_provider(self, **kwargs):
        """Build the DatasetDataProvider for this image/label input pipeline.

        Records are parsed for the ``image/encoded``, ``image/format`` and
        ``image/class/label`` features and exposed as the items ``image`` and
        ``label``.

        Args:
            **kwargs: Extra keyword arguments forwarded unchanged to
                DatasetDataProvider.

        Returns:
            A DatasetDataProvider reading ``self.data_files`` with
            ``tf.TFRecordReader``, using this pipeline's ``shuffle`` and
            ``num_epochs`` settings.

        Note:
            ``meta_data['num_classes']`` and ``meta_data['labels_to_classes']``
            are read by subscript, so unlike the other metadata entries they
            must be present.
        """
        meta = self.meta_data

        encoded_feature = tf.FixedLenFeature((), tf.string, default_value='')
        format_feature = tf.FixedLenFeature(
            (), tf.string, default_value=meta.get('image_format'))
        label_feature = tf.FixedLenFeature(
            [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64))
        keys_to_features = {
            'image/encoded': encoded_feature,
            'image/format': format_feature,
            'image/class/label': label_feature,
        }

        channels = meta.get('channels')
        dims = [meta.get('height'), meta.get('width'), channels]
        # If any dimension is missing (or zero), no reshaping should be done.
        image_shape = dims if all(dims) else None

        image_handler = tfslim.tfexample_decoder.Image(shape=image_shape,
                                                       channels=channels)
        label_handler = tfslim.tfexample_decoder.Tensor('image/class/label',
                                                        shape=[])
        items_to_handlers = {'image': image_handler, 'label': label_handler}

        decoder = TFExampleDecoder(keys_to_features, items_to_handlers)

        dataset = Dataset(
            data_sources=self.data_files,
            reader=tf.TFRecordReader,
            decoder=decoder,
            num_samples=meta.get('num_samples', {}).get(self.mode),
            num_classes=meta['num_classes'],
            items_to_descriptions=meta.get('items_to_descriptions', {}),
            meta_data=meta,
            labels_to_names=meta['labels_to_classes'])

        provider = DatasetDataProvider(dataset=dataset,
                                       shuffle=self.shuffle,
                                       num_epochs=self.num_epochs,
                                       **kwargs)
        return provider
Ejemplo n.º 5
0
    def test_TFRecordDataset(self):
        """Fetch 'image' and 'label' in a single ``get`` call and check shapes.

        The checks on the 'record_key' item are currently disabled and kept
        below as commented-out code.
        """
        height = 300
        width = 280
        dataset_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), 'tfrecord_dataset'))

        with self.test_session():
            test_dataset = _create_tfrecord_dataset(dataset_dir)
            provider = DatasetDataProvider(test_dataset)
            # Disabled: key, image, label = provider.get(['record_key', 'image', 'label'])
            image_tensor, label_tensor = provider.get(['image', 'label'])
            image_tensor = _resize_image(image_tensor, height, width)

            with session.Session('') as sess:
                with queues.QueueRunners(sess):
                    # Disabled: key, image, label = sess.run([key, image, label])
                    image_value, label_value = sess.run(
                        [image_tensor, label_tensor])
            # Disabled record-key checks:
            # split_key = key.decode('utf-8').split(':')
            # self.assertEqual(2, len(split_key))
            # self.assertEqual(test_dataset.data_sources[0], split_key[0])
            # self.assertTrue(split_key[1].isdigit())
            self.assertListEqual([height, width, 3], list(image_value.shape))
            self.assertListEqual([1], list(label_value.shape))
Ejemplo n.º 6
0
    def make_data_provider(self, **kwargs):
        """Build the DatasetDataProvider for this image-captioning pipeline.

        Context features hold the image string (``self.image_field``) and its
        format (``"image/format"``, defaulting to ``self.image_format``).
        Sequence features hold the caption ids and caption tokens. The items
        exposed are ``image``, ``target_ids``, ``target_token`` and
        ``target_len``.

        Args:
            **kwargs: Extra keyword arguments forwarded unchanged to
                DatasetDataProvider.

        Returns:
            A DatasetDataProvider reading ``self.files`` with
            ``tf.TFRecordReader``, using this pipeline's ``shuffle`` and
            ``num_epochs`` settings.
        """
        format_key = "image/format"

        context_keys_to_features = {
            self.image_field: tf.FixedLenFeature([], dtype=tf.string),
            format_key: tf.FixedLenFeature([],
                                           dtype=tf.string,
                                           default_value=self.image_format),
        }

        sequence_keys_to_features = {
            self.caption_ids_field:
            tf.FixedLenSequenceFeature([], dtype=tf.int64),
            self.caption_tokens_field:
            tf.FixedLenSequenceFeature([], dtype=tf.string),
        }

        def _caption_length(keys_to_tensors):
            """Return the element count of the caption-tokens tensor."""
            return tf.size(keys_to_tensors[self.caption_tokens_field])

        decoders = tfslim.tfexample_decoder
        items_to_handlers = {
            'image': decoders.Image(image_key=self.image_field,
                                    format_key=format_key,
                                    channels=3),
            'target_ids': decoders.Tensor(self.caption_ids_field),
            'target_token': decoders.Tensor(self.caption_tokens_field),
            'target_len': decoders.ItemHandlerCallback(
                keys=[self.caption_tokens_field], func=_caption_length),
        }

        decoder = TFSequenceExampleDecoder(context_keys_to_features,
                                           sequence_keys_to_features,
                                           items_to_handlers)

        dataset = Dataset(data_sources=self.files,
                          reader=tf.TFRecordReader,
                          decoder=decoder,
                          num_samples=None,
                          items_to_descriptions={})

        provider = DatasetDataProvider(dataset=dataset,
                                       shuffle=self.shuffle,
                                       num_epochs=self.num_epochs,
                                       **kwargs)
        return provider