Example #1
    def make_data_provider(self, **kwargs):
        """Creates DataProvider instance for this input pipeline. Additional keyword arguments
        are passed to the DataProvider.
        """
        decoder_source = SplitTokensDecoder(tokens_feature_name='source_token',
                                            length_feature_name='source_len',
                                            append_token='SEQUENCE_END',
                                            delimiter=self.source_delimiter)

        dataset_source = Dataset(data_sources=self.source_files,
                                 reader=tf.TextLineReader,
                                 decoder=decoder_source,
                                 num_samples=None,
                                 items_to_descriptions={})

        dataset_target = None
        # Target data is optional; without it only the source dataset is provided.
        if self.target_files:
            decoder_target = SplitTokensDecoder(
                tokens_feature_name='target_token',
                length_feature_name='target_len',
                prepend_token='SEQUENCE_START',
                append_token='SEQUENCE_END',
                delimiter=self.target_delimiter)

            dataset_target = Dataset(data_sources=self.target_files,
                                     reader=tf.TextLineReader,
                                     decoder=decoder_target,
                                     num_samples=None,
                                     items_to_descriptions={})

        return ParallelDatasetProvider(dataset_source=dataset_source,
                                       dataset_target=dataset_target,
                                       shuffle=self.shuffle,
                                       num_epochs=self.num_epochs,
                                       **kwargs)
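
A slim-style data provider exposes its decoded items through a get() method, keyed by the feature names registered on the decoders above. A minimal consumption sketch follows; `pipeline` stands in for an instance of the class defining this method, and the TF 1.x queue-runner session setup is an assumption, not part of the original:

    import tensorflow as tf

    # `pipeline` is assumed to be an instance of the input-pipeline class above.
    provider = pipeline.make_data_provider()
    source_tokens, source_len = provider.get(['source_token', 'source_len'])

    with tf.Session() as sess:
        # num_epochs-based readers create local variables that need initializing.
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        tokens, length = sess.run([source_tokens, source_len])
        coord.request_stop()
        coord.join(threads)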
Example #2
    def make_data_provider(self, **kwargs):
        """Creates DataProvider instance for this input pipeline. Additional keyword arguments
        are passed to the DataProvider.
        """
        context_keys_to_features = self._create_context_features()
        sequence_keys_to_features = self._create_sequence_features()
        items_to_handlers = self._create_items_to_handlers()

        decoder = TFSequenceExampleDecoder(
            context_keys_to_features=context_keys_to_features,
            sequence_keys_to_features=sequence_keys_to_features,
            items_to_handlers=items_to_handlers)

        dataset = Dataset(
            data_sources=self.data_files,
            reader=tf.TFRecordReader,
            decoder=decoder,
            num_samples=self.meta_data.get('num_samples', {}).get(self.mode),
            num_classes=self.meta_data.get('num_classes'),
            items_to_descriptions=self.meta_data.get('items_to_descriptions',
                                                     {}),
            meta_data=self.meta_data,
            labels_to_names=self.meta_data.get('labels_to_classes'))

        return DatasetDataProvider(dataset=dataset,
                                   shuffle=self.shuffle,
                                   num_epochs=self.num_epochs,
                                   **kwargs)
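
For reference, the meta_data lookups above imply a dictionary shaped roughly like the following (the keys come from the code; the concrete values are illustrative assumptions):

    meta_data = {
        'num_samples': {'train': 10000, 'eval': 1000},  # keyed by self.mode
        'num_classes': 2,
        'items_to_descriptions': {'label': 'The class label'},
        'labels_to_classes': {0: 'negative', 1: 'positive'},
    }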
Example #3
    def make_data_provider(self, **kwargs):
        """Creates DataProvider instance for this input pipeline. Additional keyword arguments
        are passed to the DataProvider.
        """
        splitter_source = SplitTokensDecoder(
            tokens_feature_name='source_token',
            length_feature_name='source_len',
            append_token='SEQUENCE_END',
            delimiter=self.source_delimiter)

        splitter_target = SplitTokensDecoder(
            tokens_feature_name='target_token',
            length_feature_name='target_len',
            prepend_token='SEQUENCE_START',
            append_token='SEQUENCE_END',
            delimiter=self.target_delimiter)

        keys_to_features = {
            self.source_field: tf.FixedLenFeature((), tf.string),
            self.target_field: tf.FixedLenFeature((),
                                                  tf.string,
                                                  default_value="")
        }

        # Each ItemHandlerCallback maps the raw string feature to one of the
        # tensors produced by the corresponding SplitTokensDecoder. The lambda
        # parameter is named `tensors` to avoid shadowing the builtin `dict`.
        items_to_handlers = {
            'source_token': tfslim.tfexample_decoder.ItemHandlerCallback(
                keys=[self.source_field],
                func=lambda tensors: splitter_source.decode(
                    tensors[self.source_field], ['source_token'])[0]),
            'source_len': tfslim.tfexample_decoder.ItemHandlerCallback(
                keys=[self.source_field],
                func=lambda tensors: splitter_source.decode(
                    tensors[self.source_field], ['source_len'])[0]),
            'target_token': tfslim.tfexample_decoder.ItemHandlerCallback(
                keys=[self.target_field],
                func=lambda tensors: splitter_target.decode(
                    tensors[self.target_field], ['target_token'])[0]),
            'target_len': tfslim.tfexample_decoder.ItemHandlerCallback(
                keys=[self.target_field],
                func=lambda tensors: splitter_target.decode(
                    tensors[self.target_field], ['target_len'])[0]),
        }

        decoder = TFExampleDecoder(keys_to_features, items_to_handlers)

        dataset = Dataset(data_sources=self.files,
                          reader=tf.TFRecordReader,
                          decoder=decoder)

        return DatasetDataProvider(dataset=dataset,
                                   shuffle=self.shuffle,
                                   num_epochs=self.num_epochs,
                                   **kwargs)
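
Functionally, each decode callback splits one raw string feature into tokens and derives a length. In raw TensorFlow the per-line transformation looks roughly like this (a behavioral sketch based on the decoder arguments shown above, not the library's actual SplitTokensDecoder implementation):

    import tensorflow as tf

    line = tf.constant('the quick brown fox')
    tokens = tf.string_split([line], delimiter=' ').values  # -> [num_tokens] strings
    tokens = tf.concat([tokens, ['SEQUENCE_END']], axis=0)  # append_token behavior
    length = tf.size(tokens)                                # -> scalar token count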
Example #4
    def make_data_provider(self, **kwargs):
        """Creates DataProvider instance for this input pipeline. Additional keyword arguments
        are passed to the DataProvider.
        """
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature(
                (), tf.string, default_value=''),
            'image/format': tf.FixedLenFeature(
                (), tf.string, default_value=self.meta_data.get('image_format')),
            'image/class/label': tf.FixedLenFeature(
                [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
        }

        image_shape = [
            self.meta_data.get('height'),
            self.meta_data.get('width'),
            self.meta_data.get('channels')
        ]
        if not all(image_shape):
            # Height, width, or channels missing from meta_data: skip reshaping.
            image_shape = None

        items_to_handlers = {
            'image': tfslim.tfexample_decoder.Image(
                shape=image_shape, channels=self.meta_data.get('channels')),
            'label': tfslim.tfexample_decoder.Tensor(
                'image/class/label', shape=[]),
        }

        decoder = TFExampleDecoder(keys_to_features, items_to_handlers)

        dataset = Dataset(data_sources=self.data_files,
                          reader=tf.TFRecordReader,
                          decoder=decoder,
                          num_samples=self.meta_data.get('num_samples',
                                                         {}).get(self.mode),
                          num_classes=self.meta_data['num_classes'],
                          items_to_descriptions=self.meta_data.get(
                              'items_to_descriptions', {}),
                          meta_data=self.meta_data,
                          labels_to_names=self.meta_data['labels_to_classes'])

        return DatasetDataProvider(dataset=dataset,
                                   shuffle=self.shuffle,
                                   num_epochs=self.num_epochs,
                                   **kwargs)
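
The feature keys above follow the common slim image-classification layout. For context, a record satisfying this keys_to_features contract could be serialized like so (a writer-side sketch not shown in the original; the file path and label value are placeholders):

    import tensorflow as tf

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    encoded = tf.gfile.GFile('image.jpg', 'rb').read()  # placeholder path
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(encoded),
        'image/format': _bytes_feature(b'jpeg'),
        'image/class/label': _int64_feature(3),  # placeholder label
    }))
    serialized = example.SerializeToString()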
Example #5
    def make_data_provider(self, **kwargs):
        """Creates DataProvider instance for this input pipeline. Additional keyword arguments
        are passed to the DataProvider.
        """
        context_keys_to_features = {
            self.image_field: tf.FixedLenFeature([], dtype=tf.string),
            'image/format': tf.FixedLenFeature(
                [], dtype=tf.string, default_value=self.image_format),
        }

        sequence_keys_to_features = {
            self.caption_ids_field: tf.FixedLenSequenceFeature([], dtype=tf.int64),
            self.caption_tokens_field: tf.FixedLenSequenceFeature([], dtype=tf.string),
        }

        items_to_handlers = {
            'image': tfslim.tfexample_decoder.Image(
                image_key=self.image_field,
                format_key='image/format',
                channels=3),
            'target_ids': tfslim.tfexample_decoder.Tensor(self.caption_ids_field),
            'target_token': tfslim.tfexample_decoder.Tensor(self.caption_tokens_field),
            'target_len': tfslim.tfexample_decoder.ItemHandlerCallback(
                keys=[self.caption_tokens_field],
                func=lambda x: tf.size(x[self.caption_tokens_field])),
        }

        decoder = TFSequenceExampleDecoder(context_keys_to_features,
                                           sequence_keys_to_features,
                                           items_to_handlers)

        dataset = Dataset(data_sources=self.files,
                          reader=tf.TFRecordReader,
                          decoder=decoder,
                          num_samples=None,
                          items_to_descriptions={})

        return DatasetDataProvider(dataset=dataset,
                                   shuffle=self.shuffle,
                                   num_epochs=self.num_epochs,
                                   **kwargs)
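
Context features hold per-example values (the encoded image and its format), while sequence features hold one value per caption position. A matching tf.SequenceExample can be assembled as below; the field-name strings are assumptions, since the self.*_field values are not shown in the original:

    import tensorflow as tf

    encoded_image = b''  # placeholder for raw encoded image bytes

    context = tf.train.Features(feature={
        'image/data': tf.train.Feature(  # assumed value of self.image_field
            bytes_list=tf.train.BytesList(value=[encoded_image])),
        'image/format': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b'jpeg'])),
    })

    caption_ids = [4, 8, 15]            # one id per token
    caption_tokens = [b'a', b'dog', b'runs']
    feature_lists = tf.train.FeatureLists(feature_list={
        'image/caption_ids': tf.train.FeatureList(feature=[  # assumed field name
            tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))
            for i in caption_ids]),
        'image/caption': tf.train.FeatureList(feature=[      # assumed field name
            tf.train.Feature(bytes_list=tf.train.BytesList(value=[t]))
            for t in caption_tokens]),
    })

    example = tf.train.SequenceExample(context=context, feature_lists=feature_lists)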