def make_parallel_data_provider(data_sources_source,
                                data_sources_target,
                                reader=tf.TextLineReader,
                                num_samples=None,
                                delimiter=" ",
                                **kwargs):
  """Creates a DataProvider that reads parallel text data.

  Args:
    data_sources_source: A list of data sources for the source text files.
    data_sources_target: A list of data sources for the target text files.
      Can be None for inference mode.
    reader: The reader class used to read the data files. Defaults to
      `tf.TextLineReader`.
    num_samples: Optional, number of records in the dataset
    delimiter: Split tokens in the data on this delimiter. Defaults to space.
    kwargs: Additional arguments (shuffle, num_epochs, etc) that are passed
      to the data provider

  Returns:
    A DataProvider instance
  """

  def _build_dataset(sources, decoder):
    # Both source and target side share identical Dataset boilerplate;
    # only the data sources and the decoder differ.
    return tf.contrib.slim.dataset.Dataset(
        data_sources=sources,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions={})

  source_decoder = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=delimiter)
  source_dataset = _build_dataset(data_sources_source, source_decoder)

  target_dataset = None
  if data_sources_target is not None:
    # Target sequences additionally get a start-of-sequence token prepended.
    target_decoder = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="target_tokens",
        length_feature_name="target_len",
        prepend_token="SEQUENCE_START",
        append_token="SEQUENCE_END",
        delimiter=delimiter)
    target_dataset = _build_dataset(data_sources_target, target_decoder)

  return parallel_data_provider.ParallelDataProvider(
      dataset1=source_dataset, dataset2=target_dataset, **kwargs)
# Example 2
    def make_data_provider(self, **kwargs):
        """Creates a DataProvider reading the configured parallel text files.

        The target side is optional: when ``self.params["target_files"]``
        is empty (inference mode), only the source dataset is built and the
        target dataset is ``None``.
        """

        def _make_dataset(files, decoder):
            # Source and target side share identical Dataset boilerplate;
            # only the files and the decoder differ.
            return tf.contrib.slim.dataset.Dataset(
                data_sources=files,
                reader=tf.TextLineReader,
                decoder=decoder,
                num_samples=None,
                items_to_descriptions={})

        source_decoder = split_tokens_decoder.SplitTokensDecoder(
            tokens_feature_name="source_tokens",
            length_feature_name="source_len",
            append_token="SEQUENCE_END",
            delimiter=self.params["source_delimiter"])
        source_dataset = _make_dataset(
            self.params["source_files"], source_decoder)

        target_dataset = None
        if self.params["target_files"]:
            # Target sequences additionally get a start-of-sequence token.
            target_decoder = split_tokens_decoder.SplitTokensDecoder(
                tokens_feature_name="target_tokens",
                length_feature_name="target_len",
                prepend_token="SEQUENCE_START",
                append_token="SEQUENCE_END",
                delimiter=self.params["target_delimiter"])
            target_dataset = _make_dataset(
                self.params["target_files"], target_decoder)

        return parallel_data_provider.ParallelDataProvider(
            dataset1=source_dataset,
            dataset2=target_dataset,
            shuffle=self.params["shuffle"],
            num_epochs=self.params["num_epochs"],
            **kwargs)