def _get_copying_decoder(self, tokens_feature_name, length_feature_name,
                         prepend_token, append_token, delimiter):
    """Build a WordCopyingDecoder configured with the given feature names,
    sentinel tokens, and token delimiter."""
    decoder_kwargs = dict(
        tokens_feature_name=tokens_feature_name,
        length_feature_name=length_feature_name,
        prepend_token=prepend_token,
        append_token=append_token,
        delimiter=delimiter,
    )
    return copying_decoder.WordCopyingDecoder(**decoder_kwargs)
# Example #2
def make_word_copying_data_provider(data_sources_source,
                                    data_sources_target,
                                    data_sources_schema=None,
                                    reader=tf.TextLineReader,
                                    num_samples=None,
                                    source_delimiter=" ",
                                    target_delimiter=" ",
                                    **kwargs):
    """Builds a copying data provider for word-only copying.

  Args:
    data_sources_source: A list of data sources for the source text
      files.
    data_sources_target: A list of data sources for the target text
      files.
    data_sources_schema: An optional list of data sources for the schema
      location text files.
    reader: A reader that can handle the source and schema files.
    num_samples: Optional, number of records in the dataset
    source_delimiter: Split tokens in the source data on this
      delimiter. Defaults to space.
    target_delimiter: Split tokens in the target data on this
      delimiter. Defaults to space.
    kwargs: Additional arguments (shuffle, num_epochs, etc) that are passed
      to the data provider

  Returns:
    A WordCopyingDataProvider.
  """

    # BUG FIX: previously the hard-coded tf.TextLineReader and " " were
    # passed here, silently ignoring the `reader` and `source_delimiter`
    # arguments supplied by the caller. Forward the actual parameters.
    dataset_source, dataset_schemas = _make_copying_data_provider_base(
        data_sources_source,
        data_sources_schema,
        reader=reader,
        num_samples=num_samples,
        source_delimiter=source_delimiter,
        **kwargs)

    # The target dataset is optional (e.g. at inference time there may be
    # no target text); only build a decoder/dataset when sources are given.
    dataset_target = None
    if data_sources_target is not None:
        decoder_target = copying_decoder.WordCopyingDecoder(
            tokens_feature_name="target_tokens",
            length_feature_name="target_len",
            prepend_token="SEQUENCE_START",
            append_token="SEQUENCE_END",
            delimiter=target_delimiter)

        dataset_target = tf.contrib.slim.dataset.Dataset(
            data_sources=data_sources_target,
            reader=reader,
            decoder=decoder_target,
            num_samples=num_samples,
            items_to_descriptions={})

    return WordCopyingDataProvider(dataset1=dataset_source,
                                   dataset2=dataset_target,
                                   schemas=dataset_schemas,
                                   **kwargs)