def make_data_provider(self, **kwargs):
  splitter_source = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=self.params["source_delimiter"])
  splitter_target = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="target_tokens",
      length_feature_name="target_len",
      prepend_token="SEQUENCE_START",
      append_token="SEQUENCE_END",
      delimiter=self.params["target_delimiter"])

  keys_to_features = {
      self.params["source_field"]: tf.FixedLenFeature((), tf.string),
      self.params["target_field"]: tf.FixedLenFeature(
          (), tf.string, default_value="")
  }

  items_to_handlers = {}
  items_to_handlers["source_tokens"] = tfexample_decoder.ItemHandlerCallback(
      keys=[self.params["source_field"]],
      func=lambda dict: splitter_source.decode(
          dict[self.params["source_field"]], ["source_tokens"])[0])
  items_to_handlers["source_len"] = tfexample_decoder.ItemHandlerCallback(
      keys=[self.params["source_field"]],
      func=lambda dict: splitter_source.decode(
          dict[self.params["source_field"]], ["source_len"])[0])
  items_to_handlers["target_tokens"] = tfexample_decoder.ItemHandlerCallback(
      keys=[self.params["target_field"]],
      func=lambda dict: splitter_target.decode(
          dict[self.params["target_field"]], ["target_tokens"])[0])
  items_to_handlers["target_len"] = tfexample_decoder.ItemHandlerCallback(
      keys=[self.params["target_field"]],
      func=lambda dict: splitter_target.decode(
          dict[self.params["target_field"]], ["target_len"])[0])

  decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                               items_to_handlers)

  dataset = tf.contrib.slim.dataset.Dataset(
      data_sources=self.params["files"],
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=None,
      items_to_descriptions={})

  return tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
      dataset=dataset,
      shuffle=self.params["shuffle"],
      num_epochs=self.params["num_epochs"],
      **kwargs)
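
# Usage sketch, not part of the original code: assuming `pipeline` is an
# instance of the class defining make_data_provider above, with its params
# populated, the slim provider yields one tensor per registered item.
def _example_read_parallel_items(pipeline):
  provider = pipeline.make_data_provider()
  # Each name matches a key in items_to_handlers above.
  return provider.get(
      ["source_tokens", "source_len", "target_tokens", "target_len"])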
def make_parallel_data_provider(data_sources_source,
                                data_sources_target,
                                reader=tf.TextLineReader,
                                num_samples=None,
                                source_delimiter=" ",
                                target_delimiter=" ",
                                **kwargs):
  """Creates a DataProvider that reads parallel text data.

  Args:
    data_sources_source: A list of data sources for the source text files.
    data_sources_target: A list of data sources for the target text files.
      Can be None for inference mode.
    reader: A reader class that can handle the text files, e.g.
      tf.TextLineReader.
    num_samples: Optional, number of records in the dataset.
    source_delimiter: Split tokens in the source data on this delimiter.
      Defaults to space.
    target_delimiter: Split tokens in the target data on this delimiter.
      Defaults to space.
    kwargs: Additional arguments (shuffle, num_epochs, etc.) that are passed
      to the data provider.

  Returns:
    A DataProvider instance.
  """
  decoder_source = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      # append_to_fixed_len=50,
      delimiter=source_delimiter)

  dataset_source = tf.contrib.slim.dataset.Dataset(
      data_sources=data_sources_source,
      reader=reader,
      decoder=decoder_source,
      num_samples=num_samples,
      items_to_descriptions={})

  dataset_target = None
  if data_sources_target is not None:
    decoder_target = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="target_tokens",
        length_feature_name="target_len",
        prepend_token="SEQUENCE_START",
        append_token="SEQUENCE_END",
        delimiter=target_delimiter)

    dataset_target = tf.contrib.slim.dataset.Dataset(
        data_sources=data_sources_target,
        reader=reader,
        decoder=decoder_target,
        num_samples=num_samples,
        items_to_descriptions={})

  return ParallelDataProvider(
      dataset1=dataset_source,
      dataset2=dataset_target,
      **kwargs)
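
# Hedged example (the file paths are hypothetical): wiring two aligned text
# files into a ParallelDataProvider and fetching the decoded items.
def _example_make_parallel_provider():
  provider = make_parallel_data_provider(
      data_sources_source=["data/train.sources.txt"],
      data_sources_target=["data/train.targets.txt"],
      shuffle=True,
      num_epochs=1)
  return provider.get(
      ["source_tokens", "source_len", "target_tokens", "target_len"])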
def make_data_provider(self, **kwargs):
  decoder_source_query = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=self.params["source_delimiter"])

  dataset_source_query = tf.contrib.slim.dataset.Dataset(
      data_sources=self.params["query_files"],
      reader=tf.TextLineReader,
      decoder=decoder_source_query,
      num_samples=None,
      items_to_descriptions={})

  decoder_source_candidate = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_candidate_tokens",
      length_feature_name="source_candidate_len",
      append_token="SEQUENCE_END",
      delimiter=self.params["source_delimiter"])

  dataset_source_candidate = tf.contrib.slim.dataset.Dataset(
      data_sources=self.params["candidate_files"],
      reader=tf.TextLineReader,
      decoder=decoder_source_candidate,
      num_samples=None,
      items_to_descriptions={})

  dataset_target = None
  if self.params["target_files"]:
    decoder_target = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="target_tokens",
        length_feature_name="target_len",
        prepend_token="SEQUENCE_START",
        append_token="SEQUENCE_END",
        delimiter=self.params["target_delimiter"])

    dataset_target = tf.contrib.slim.dataset.Dataset(
        data_sources=self.params["target_files"],
        reader=tf.TextLineReader,
        decoder=decoder_target,
        num_samples=None,
        items_to_descriptions={})

  return parallel_data_provider.TripleDataProvider(
      dataset1=dataset_source_query,
      dataset2=dataset_source_candidate,
      dataset3=dataset_target,
      shuffle=self.params["shuffle"],
      num_epochs=self.params["num_epochs"],
      **kwargs)
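
# Sketch, assuming params points at real query/candidate/target files: the
# TripleDataProvider exposes the items of all three datasets side by side.
def _example_read_triple_items(pipeline):
  provider = pipeline.make_data_provider()
  return provider.get(
      ["source_tokens", "source_len",
       "source_candidate_tokens", "source_candidate_len",
       "target_tokens", "target_len"])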
def test_decode(self):
  decoder = split_tokens_decoder.SplitTokensDecoder(
      delimiter=" ",
      tokens_feature_name="source_tokens",
      length_feature_name="source_len")

  self.assertEqual(decoder.list_items(), ["source_tokens", "source_len"])

  data = tf.constant("Hello world ! 笑w")

  decoded_tokens = decoder.decode(data, ["source_tokens"])
  decoded_length = decoder.decode(data, ["source_len"])
  decoded_both = decoder.decode(data, decoder.list_items())

  with self.test_session() as sess:
    decoded_tokens_ = sess.run(decoded_tokens)[0]
    decoded_length_ = sess.run(decoded_length)[0]
    decoded_both_ = sess.run(decoded_both)

  self.assertEqual(decoded_length_, 4)
  np.testing.assert_array_equal(
      np.char.decode(decoded_tokens_.astype("S"), "utf-8"),
      ["Hello", "world", "!", "笑w"])

  self.assertEqual(decoded_both_[1], 4)
  np.testing.assert_array_equal(
      np.char.decode(decoded_both_[0].astype("S"), "utf-8"),
      ["Hello", "world", "!", "笑w"])
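
# Complementary sketch (not in the original test): with prepend/append
# tokens, as configured for the target decoders elsewhere in this code,
# the reported length includes the special tokens.
def _example_decode_with_special_tokens():
  decoder = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="target_tokens",
      length_feature_name="target_len",
      prepend_token="SEQUENCE_START",
      append_token="SEQUENCE_END")
  tokens, length = decoder.decode(
      tf.constant("Hello world"), ["target_tokens", "target_len"])
  # tokens evaluates to ["SEQUENCE_START", "Hello", "world", "SEQUENCE_END"]
  # and length to 4.
  return tokens, length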
def _make_copying_data_provider_base(data_sources_source,
                                     data_sources_schema,
                                     reader=tf.TextLineReader,
                                     num_samples=None,
                                     source_delimiter=" ",
                                     **kwargs):
  """Prepares the Datasets that are used to make the copying data provider.

  Args:
    data_sources_source: A list of data sources for the source text files.
    data_sources_schema: A list of data sources for the schema location
      text files.
    reader: A reader class that can handle the source and schema files.
    num_samples: Optional, number of records in the dataset.
    source_delimiter: Split tokens in the source data on this delimiter.
      Defaults to space.
    kwargs: Additional arguments (shuffle, num_epochs, etc.) accepted by the
      data provider constructors; not consumed by this base function.

  Returns:
    The Datasets for source and schema.
  """
  decoder_source = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=source_delimiter)

  dataset_source = tf.contrib.slim.dataset.Dataset(
      data_sources=data_sources_source,
      reader=reader,
      decoder=decoder_source,
      num_samples=num_samples,
      items_to_descriptions={})

  dataset_schemas = None
  if data_sources_schema is not None:
    decoder_schemas = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="schema_loc",
        delimiter=" ")

    dataset_schemas = tf.contrib.slim.dataset.Dataset(
        data_sources=data_sources_schema,
        reader=reader,
        decoder=decoder_schemas,
        num_samples=num_samples,
        items_to_descriptions={})

  return dataset_source, dataset_schemas
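
# Minimal call sketch (file paths hypothetical): the two Datasets returned
# here are meant to be combined with a target dataset by a parallel-style
# provider, mirroring make_parallel_data_provider above.
def _example_copying_base_datasets():
  return _make_copying_data_provider_base(
      data_sources_source=["data/train.sources.txt"],
      data_sources_schema=["data/train.schema_locs.txt"])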
def make_data_provider(self, **kwargs):
  data_files = []
  tf.logging.info(self.params["file_input_pattern"].split(","))
  for pattern in self.params["file_input_pattern"].split(","):
    data_files.extend(tf.gfile.Glob(pattern))
  if not data_files:
    tf.logging.fatal("Found no input files matching %s",
                     self.params["file_input_pattern"])
  else:
    tf.logging.info("Prefetching values from %d files matching %s",
                    len(data_files), self.params["file_input_pattern"])

  splitter_target = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="target_tokens",
      length_feature_name="target_len",
      prepend_token="SEQUENCE_START",
      append_token="SEQUENCE_END",
      delimiter=self.params["target_delimiter"])

  # Context features hold the caption string; sequence features hold the
  # per-frame video feature strings.
  context_keys_to_features = {
      self.params["caption_tokens_field"]: tf.FixedLenFeature(
          [], dtype=tf.string),
  }
  sequence_keys_to_features = {
      self.params["video_field"]: tf.FixedLenSequenceFeature(
          [], dtype=tf.string),
  }

  items_to_handlers = {
      "source_tokens": tfexample_decoder.Tensor(self.params["video_field"]),
      "source_len": tfexample_decoder.ItemHandlerCallback(
          keys=[self.params["video_field"]],
          func=lambda x: tf.size(x[self.params["video_field"]])),
      "target_tokens": tfexample_decoder.ItemHandlerCallback(
          keys=[self.params["caption_tokens_field"]],
          func=lambda dict: splitter_target.decode(
              dict[self.params["caption_tokens_field"]],
              ["target_tokens"])[0]),
      "target_len": tfexample_decoder.ItemHandlerCallback(
          keys=[self.params["caption_tokens_field"]],
          func=lambda dict: splitter_target.decode(
              dict[self.params["caption_tokens_field"]],
              ["target_len"])[0])
  }

  decoder = TFSEquenceExampleDecoder(
      context_keys_to_features, sequence_keys_to_features, items_to_handlers)

  dataset = tf.contrib.slim.dataset.Dataset(
      data_sources=data_files,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=None,
      items_to_descriptions={})

  return tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
      dataset=dataset,
      shuffle=self.params["shuffle"],
      num_epochs=self.params["num_epochs"],
      **kwargs)
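
# Consumption sketch (assumes params is configured as above): the video
# frames arrive as a variable-length sequence of feature strings, so
# batching needs dynamic padding.
def _example_video_caption_batch(pipeline, batch_size=16):
  provider = pipeline.make_data_provider()
  source_tokens, source_len, target_tokens, target_len = provider.get(
      ["source_tokens", "source_len", "target_tokens", "target_len"])
  return tf.train.batch(
      [source_tokens, source_len, target_tokens, target_len],
      batch_size=batch_size,
      dynamic_pad=True)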