import tensorflow as tf

# Local modules providing SplitTokensDecoder and ParallelDataProvider
# (package path assumed; adjust to where these live in your project).
from seq2seq.data import parallel_data_provider
from seq2seq.data import split_tokens_decoder


def make_parallel_data_provider(data_sources_source,
                                data_sources_target,
                                reader=tf.TextLineReader,
                                num_samples=None,
                                delimiter=" ",
                                **kwargs):
  """Creates a DataProvider that reads parallel text data.

  Args:
    data_sources_source: A list of data sources for the source text files.
    data_sources_target: A list of data sources for the target text files.
      Can be None for inference mode.
    reader: The reader class to use, e.g. `tf.TextLineReader`.
    num_samples: Optional, the number of records in the dataset.
    delimiter: Split tokens in the data on this delimiter. Defaults to space.
    kwargs: Additional arguments (shuffle, num_epochs, etc.) that are passed
      to the data provider.

  Returns:
    A DataProvider instance.
  """
  decoder_source = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=delimiter)

  dataset_source = tf.contrib.slim.dataset.Dataset(
      data_sources=data_sources_source,
      reader=reader,
      decoder=decoder_source,
      num_samples=num_samples,
      items_to_descriptions={})

  dataset_target = None
  if data_sources_target is not None:
    decoder_target = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="target_tokens",
        length_feature_name="target_len",
        prepend_token="SEQUENCE_START",
        append_token="SEQUENCE_END",
        delimiter=delimiter)

    dataset_target = tf.contrib.slim.dataset.Dataset(
        data_sources=data_sources_target,
        reader=reader,
        decoder=decoder_target,
        num_samples=num_samples,
        items_to_descriptions={})

  return parallel_data_provider.ParallelDataProvider(
      dataset1=dataset_source,
      dataset2=dataset_target,
      **kwargs)
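# Usage sketch (file paths are hypothetical; assumes a TF1-style
# graph/session workflow where the provider's output tensors are later
# consumed via queue runners). `get()` is the standard slim DataProvider
# accessor that returns the tensors for the named items:
#
#   data_provider = make_parallel_data_provider(
#       data_sources_source=["train.sources.txt"],
#       data_sources_target=["train.targets.txt"],
#       shuffle=True,
#       num_epochs=None)
#   source_tokens, source_len, target_tokens, target_len = data_provider.get(
#       ["source_tokens", "source_len", "target_tokens", "target_len"])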
def make_data_provider(self, **kwargs):
  """Creates a ParallelDataProvider from this pipeline's `params`.

  Reads source files (and target files, when present) line by line,
  splitting each line into tokens on the configured delimiters.
  """
  decoder_source = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=self.params["source_delimiter"])

  dataset_source = tf.contrib.slim.dataset.Dataset(
      data_sources=self.params["source_files"],
      reader=tf.TextLineReader,
      decoder=decoder_source,
      num_samples=None,
      items_to_descriptions={})

  dataset_target = None
  if len(self.params["target_files"]) > 0:
    decoder_target = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="target_tokens",
        length_feature_name="target_len",
        prepend_token="SEQUENCE_START",
        append_token="SEQUENCE_END",
        delimiter=self.params["target_delimiter"])

    dataset_target = tf.contrib.slim.dataset.Dataset(
        data_sources=self.params["target_files"],
        reader=tf.TextLineReader,
        decoder=decoder_target,
        num_samples=None,
        items_to_descriptions={})

  return parallel_data_provider.ParallelDataProvider(
      dataset1=dataset_source,
      dataset2=dataset_target,
      shuffle=self.params["shuffle"],
      num_epochs=self.params["num_epochs"],
      **kwargs)
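# Usage sketch (assumes this method belongs to an input-pipeline class,
# here called ParallelTextInputPipeline, whose `params` dict carries the
# keys read above; the class name and file paths are hypothetical):
#
#   pipeline = ParallelTextInputPipeline(params={
#       "source_files": ["train.sources.txt"],
#       "target_files": ["train.targets.txt"],
#       "source_delimiter": " ",
#       "target_delimiter": " ",
#       "shuffle": True,
#       "num_epochs": None,
#   })
#   data_provider = pipeline.make_data_provider()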