Example 1
    def make_data_provider(self, **kwargs):

        splitter_source = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="source_tokens",
            length_feature_name="source_len",
            append_token="SEQUENCE_END",
            delimiter=self.params["source_delimiter"])

        splitter_target = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="target_tokens",
            length_feature_name="target_len",
            prepend_token="SEQUENCE_START",
            append_token="SEQUENCE_END",
            delimiter=self.params["target_delimiter"])

        keys_to_features = {
            self.params["source_field"]:
            tf.FixedLenFeature((), tf.string),
            self.params["target_field"]:
            tf.FixedLenFeature((), tf.string, default_value="")
        }

        items_to_handlers = {}
        items_to_handlers[
            "source_tokens"] = tfexample_decoder.ItemHandlerCallback(
                keys=[self.params["source_field"]],
                func=lambda dict: splitter_source.decode(
                    dict[self.params["source_field"]], ["source_tokens"])[0])
        items_to_handlers[
            "source_len"] = tfexample_decoder.ItemHandlerCallback(
                keys=[self.params["source_field"]],
                func=lambda dict: splitter_source.decode(
                    dict[self.params["source_field"]], ["source_len"])[0])
        items_to_handlers[
            "target_tokens"] = tfexample_decoder.ItemHandlerCallback(
                keys=[self.params["target_field"]],
                func=lambda dict: splitter_target.decode(
                    dict[self.params["target_field"]], ["target_tokens"])[0])
        items_to_handlers[
            "target_len"] = tfexample_decoder.ItemHandlerCallback(
                keys=[self.params["target_field"]],
                func=lambda dict: splitter_target.decode(
                    dict[self.params["target_field"]], ["target_len"])[0])

        decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                     items_to_handlers)

        dataset = tf.contrib.slim.dataset.Dataset(
            data_sources=self.params["files"],
            reader=tf.TFRecordReader,
            decoder=decoder,
            num_samples=None,
            items_to_descriptions={})

        return tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
            dataset=dataset,
            shuffle=self.params["shuffle"],
            num_epochs=self.params["num_epochs"],
            **kwargs)
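A minimal consumption sketch for the provider returned above, assuming TF 1.x queue runners; `provider` stands for the return value of make_data_provider and is not part of the original snippet.

import tensorflow as tf

# `provider` is assumed to be the DatasetDataProvider built above; the item
# names match the handlers registered in items_to_handlers.
features = provider.get(
    ["source_tokens", "source_len", "target_tokens", "target_len"])

with tf.Session() as sess:
    # num_epochs creates a local epoch counter, so local variables must be
    # initialized before the queue runners start.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    source_tokens, source_len, target_tokens, target_len = sess.run(features)
    coord.request_stop()
    coord.join(threads)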
Example 2
def make_parallel_data_provider(data_sources_source,
                                data_sources_target,
                                reader=tf.TextLineReader,
                                num_samples=None,
                                source_delimiter=" ",
                                target_delimiter=" ",
                                **kwargs):
    """Creates a DataProvider that reads parallel text data.

  Args:
    data_sources_source: A list of data sources for the source text files.
    data_sources_target: A list of data sources for the target text files.
      Can be None for inference mode.
    num_samples: Optional, number of records in the dataset
    delimiter: Split tokens in the data on this delimiter. Defaults to space.
    kwargs: Additional arguments (shuffle, num_epochs, etc) that are passed
      to the data provider

  Returns:
    A DataProvider instance
  """

    decoder_source = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="source_tokens",
        length_feature_name="source_len",
        append_token="SEQUENCE_END",
        #append_to_fixed_len=50,
        delimiter=source_delimiter)

    dataset_source = tf.contrib.slim.dataset.Dataset(
        data_sources=data_sources_source,
        reader=reader,
        decoder=decoder_source,
        num_samples=num_samples,
        items_to_descriptions={})

    dataset_target = None
    if data_sources_target is not None:
        decoder_target = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="target_tokens",
            length_feature_name="target_len",
            prepend_token="SEQUENCE_START",
            append_token="SEQUENCE_END",
            delimiter=target_delimiter)

        dataset_target = tf.contrib.slim.dataset.Dataset(
            data_sources=data_sources_target,
            reader=reader,
            decoder=decoder_target,
            num_samples=num_samples,
            items_to_descriptions={})

    return ParallelDataProvider(dataset1=dataset_source,
                                dataset2=dataset_target,
                                **kwargs)
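A hedged usage sketch for make_parallel_data_provider; the file paths are placeholders, and the fetch assumes ParallelDataProvider exposes the items of both underlying datasets.

provider = make_parallel_data_provider(
    data_sources_source=["data/train.sources.txt"],  # placeholder path
    data_sources_target=["data/train.targets.txt"],  # placeholder path
    source_delimiter=" ",
    target_delimiter=" ",
    shuffle=True,
    num_epochs=None)

# Fetch all four features in one call (item names follow the two decoders).
source_tokens, source_len, target_tokens, target_len = provider.get(
    ["source_tokens", "source_len", "target_tokens", "target_len"])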
Example 3
    def make_data_provider(self, **kwargs):
        decoder_source_query = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="source_tokens",
            length_feature_name="source_len",
            append_token="SEQUENCE_END",
            delimiter=self.params["source_delimiter"])

        dataset_source_query = tf.contrib.slim.dataset.Dataset(
            data_sources=self.params["query_files"],
            reader=tf.TextLineReader,
            decoder=decoder_source_query,
            num_samples=None,
            items_to_descriptions={})

        decoder_source_candidate = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="source_candidate_tokens",
            length_feature_name="source_candidate_len",
            append_token="SEQUENCE_END",
            delimiter=self.params["source_delimiter"])

        dataset_source_candidate = tf.contrib.slim.dataset.Dataset(
            data_sources=self.params["candidate_files"],
            reader=tf.TextLineReader,
            decoder=decoder_source_candidate,
            num_samples=None,
            items_to_descriptions={})

        dataset_target = None
        if len(self.params["target_files"]) > 0:
            decoder_target = split_tokens_decoder.SplitTokensDecoder(
                tokens_feature_name="target_tokens",
                length_feature_name="target_len",
                prepend_token="SEQUENCE_START",
                append_token="SEQUENCE_END",
                delimiter=self.params["target_delimiter"])

            dataset_target = tf.contrib.slim.dataset.Dataset(
                data_sources=self.params["target_files"],
                reader=tf.TextLineReader,
                decoder=decoder_target,
                num_samples=None,
                items_to_descriptions={})

        return parallel_data_provider.TripleDataProvider(
            dataset1=dataset_source_query,
            dataset2=dataset_source_candidate,
            dataset3=dataset_target,
            shuffle=self.params["shuffle"],
            num_epochs=self.params["num_epochs"],
            **kwargs)
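A hypothetical invocation sketch for the method above, assuming an input-pipeline object `pipeline` configured with the params used here and that TripleDataProvider exposes the items of all three datasets.

provider = pipeline.make_data_provider()  # `pipeline` is a placeholder object
tensors = provider.get(
    ["source_tokens", "source_len",
     "source_candidate_tokens", "source_candidate_len",
     "target_tokens", "target_len"])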
Example 4

    def test_decode(self):
        decoder = split_tokens_decoder.SplitTokensDecoder(
            delimiter=" ",
            tokens_feature_name="source_tokens",
            length_feature_name="source_len")

        self.assertEqual(decoder.list_items(), ["source_tokens", "source_len"])

        data = tf.constant("Hello world ! 笑w")

        decoded_tokens = decoder.decode(data, ["source_tokens"])
        decoded_length = decoder.decode(data, ["source_len"])
        decoded_both = decoder.decode(data, decoder.list_items())

        with self.test_session() as sess:
            decoded_tokens_ = sess.run(decoded_tokens)[0]
            decoded_length_ = sess.run(decoded_length)[0]
            decoded_both_ = sess.run(decoded_both)

        self.assertEqual(decoded_length_, 4)
        np.testing.assert_array_equal(
            np.char.decode(decoded_tokens_.astype("S"), "utf-8"),
            ["Hello", "world", "!", "笑w"])

        self.assertEqual(decoded_both_[1], 4)
        np.testing.assert_array_equal(
            np.char.decode(decoded_both_[0].astype("S"), "utf-8"),
            ["Hello", "world", "!", "笑w"])
Example 5
def _make_copying_data_provider_base(data_sources_source,
                                     data_sources_schema,
                                     reader=tf.TextLineReader,
                                     num_samples=None,
                                     source_delimiter=" ",
                                     **kwargs):
    """
  Prepare the Datasets that will be used to make the copying data provider.

  Args:
    data_sources_source: A list of data sources for the source text files.
    data_sources_schema: A list of data sources for the schema location text files.
    reader: A reader that can handle the source and schema files.
    num_samples: Optional, number of records in the dataset
    delimiter: Split tokens in the data on this delimiter. Defaults to space.
    kwargs: Additional arguments (shuffle, num_epochs, etc) that are passed
      to the data provider

  Returns:
    The Datasets for source and schema.
"""
    decoder_source = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="source_tokens",
        length_feature_name="source_len",
        append_token="SEQUENCE_END",
        delimiter=source_delimiter)

    dataset_source = tf.contrib.slim.dataset.Dataset(
        data_sources=data_sources_source,
        reader=reader,
        decoder=decoder_source,
        num_samples=num_samples,
        items_to_descriptions={})

    dataset_schemas = None
    if data_sources_schema is not None:
        decoder_schemas = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="schema_loc", delimiter=" ")
        dataset_schemas = tf.contrib.slim.dataset.Dataset(
            data_sources=data_sources_schema,
            reader=reader,
            decoder=decoder_schemas,
            num_samples=num_samples,
            items_to_descriptions={})
    return dataset_source, dataset_schemas
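A hypothetical call to the helper above; the paths are placeholders, and wrapping the returned Datasets in an actual data provider is left to the caller, as in the snippet.

dataset_source, dataset_schemas = _make_copying_data_provider_base(
    data_sources_source=["data/train.sources.txt"],           # placeholder path
    data_sources_schema=["data/train.schema_locations.txt"],  # placeholder path
    num_samples=None,
    source_delimiter=" ")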
Example 6
    def make_data_provider(self, **kwargs):
        data_files = []
        tf.logging.info(self.params["file_input_pattern"].split(","))
        for pattern in self.params["file_input_pattern"].split(","):
            data_files.extend(tf.gfile.Glob(pattern))
        if not data_files:
            tf.logging.fatal("Found no input files matching %s", self.params["file_input_pattern"])
        else:
            tf.logging.info("Prefetching values from %d files matching %s",
                            len(data_files), self.params["file_input_pattern"])

        splitter_target = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="target_tokens",
            length_feature_name="target_len",
            prepend_token="SEQUENCE_START",
            append_token="SEQUENCE_END",
            delimiter=self.params["target_delimiter"])

        context_keys_to_features = {
            self.params["caption_tokens_field"]: tf.FixedLenFeature(
                [], dtype=tf.string),
        }

        sequence_keys_to_features = {
            self.params["video_field"]: tf.FixedLenSequenceFeature(
                [], dtype=tf.string),
        }

        items_to_handlers = {
            "source_tokens": tfexample_decoder.Tensor(self.params["video_field"]),
            "source_len": tfexample_decoder.ItemHandlerCallback(
                keys=[self.params["video_field"]],
                func=lambda x: tf.size(x[self.params["video_field"]])),
            "target_tokens": tfexample_decoder.ItemHandlerCallback(
                keys=[self.params["caption_tokens_field"]],
                func=lambda dict: splitter_target.decode(
                    dict[self.params["caption_tokens_field"]], ["target_tokens"])[0]),
            "target_len": tfexample_decoder.ItemHandlerCallback(
                keys=[self.params["caption_tokens_field"]],
                func=lambda dict: splitter_target.decode(
                    dict[self.params["caption_tokens_field"]], ["target_len"])[0])
        }

        decoder = TFSEquenceExampleDecoder(context_keys_to_features, sequence_keys_to_features, items_to_handlers)

        dataset = tf.contrib.slim.dataset.Dataset(
            data_sources=data_files,
            reader=tf.TFRecordReader,
            decoder=decoder,
            num_samples=None,
            items_to_descriptions={})

        return tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
            dataset=dataset,
            shuffle=self.params["shuffle"],
            num_epochs=self.params["num_epochs"],
            **kwargs)
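A sketch of a tf.SequenceExample matching the context/sequence keys declared above; the field names "caption" and "video_frames" are placeholder assumptions for self.params["caption_tokens_field"] and self.params["video_field"].

import tensorflow as tf

example = tf.train.SequenceExample()
# Context feature: the caption string (caption_tokens_field).
example.context.feature["caption"].bytes_list.value.append(
    b"a man is playing a guitar")
# Sequence feature: one string entry per video frame (video_field).
frames = example.feature_lists.feature_list["video_frames"]
for frame in [b"frame_0", b"frame_1", b"frame_2"]:
    frames.feature.add().bytes_list.value.append(frame)

with tf.python_io.TFRecordWriter("captions.tfrecords") as writer:  # placeholder path
    writer.write(example.SerializeToString())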