def _fn():
  self._initialize(metadata)
  dataset = inputter.make_dataset(data_file, training=training)
  if training:
    batch_size_multiple = 1
    if batch_type == "tokens" and self.dtype == tf.float16:
      batch_size_multiple = 8
    dataset = data.training_pipeline(
        dataset,
        batch_size,
        batch_type=batch_type,
        batch_multiplier=batch_multiplier,
        bucket_width=bucket_width,
        single_pass=single_pass,
        process_fn=process_fn,
        num_threads=num_threads,
        shuffle_buffer_size=sample_buffer_size,
        prefetch_buffer_size=prefetch_buffer_size,
        dataset_size=self.features_inputter.get_dataset_size(features_file),
        maximum_features_length=maximum_features_length,
        maximum_labels_length=maximum_labels_length,
        features_length_fn=self.features_inputter.get_length,
        labels_length_fn=self.labels_inputter.get_length,
        batch_size_multiple=batch_size_multiple,
        num_shards=num_shards,
        shard_index=shard_index)
  else:
    dataset = data.inference_pipeline(
        dataset,
        batch_size,
        process_fn=process_fn,
        num_threads=num_threads,
        prefetch_buffer_size=prefetch_buffer_size,
        bucket_width=bucket_width,
        length_fn=self.features_inputter.get_length)
  iterator = dataset.make_initializable_iterator()
  # Add the initializer to a standard collection for it to be initialized.
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  return iterator.get_next()
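# Sketch (not library code) of how the returned elements can be fetched in TF 1.x graph
# mode. Because the iterator initializer is added to the TABLE_INITIALIZERS collection
# above, running ``tf.tables_initializer()`` initializes both the vocabulary lookup
# tables and the dataset iterator. The ``input_fn`` argument is a hypothetical callable
# that builds the input graph by invoking ``_fn``.
import tensorflow as tf

def fetch_one_batch(input_fn):
  """Builds the input graph and fetches a single batch of elements."""
  with tf.Graph().as_default():
    elements = input_fn()
    with tf.Session() as sess:
      sess.run(tf.tables_initializer())  # also runs the iterator initializer
      return sess.run(elements)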
def make_training_dataset(self,
                          features_file,
                          labels_file,
                          batch_size,
                          batch_type="examples",
                          batch_multiplier=1,
                          batch_size_multiple=1,
                          shuffle_buffer_size=None,
                          bucket_width=None,
                          maximum_features_length=None,
                          maximum_labels_length=None,
                          single_pass=False,
                          num_shards=1,
                          shard_index=0,
                          num_threads=4,
                          prefetch_buffer_size=None):
  """See :meth:`opennmt.inputters.inputter.ExampleInputter.make_training_dataset`."""
  _ = labels_file
  dataset = self.make_dataset(features_file, training=True)
  dataset = data.training_pipeline(
      dataset,
      batch_size,
      batch_type=batch_type,
      batch_multiplier=batch_multiplier,
      bucket_width=bucket_width,
      single_pass=single_pass,
      process_fn=lambda x: self._generate_example(x, training=True),
      num_threads=num_threads,
      shuffle_buffer_size=shuffle_buffer_size,
      prefetch_buffer_size=prefetch_buffer_size,
      maximum_features_length=maximum_features_length,
      maximum_labels_length=maximum_labels_length,
      features_length_fn=self.get_length,
      batch_size_multiple=batch_size_multiple,
      num_shards=num_shards,
      shard_index=shard_index)
  return dataset
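# Usage sketch (assumptions: ``inputter`` is an instance of a subclass defining the
# override above, and "train.txt" is a hypothetical source file). Note that
# ``labels_file`` is intentionally ignored (``_ = labels_file``): ``_generate_example``
# derives both features and labels from the same source elements.
dataset = inputter.make_training_dataset(
    "train.txt",
    None,                     # unused by this implementation
    batch_size=4096,
    batch_type="tokens",      # batch by token count rather than example count
    bucket_width=1,           # group sequences of similar length to limit padding
    shuffle_buffer_size=500000,
    maximum_features_length=100)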
def _input_fn_impl(self,
                   mode,
                   batch_size,
                   metadata,
                   features_file,
                   labels_file=None,
                   batch_type="examples",
                   batch_multiplier=1,
                   bucket_width=None,
                   single_pass=False,
                   num_threads=None,
                   sample_buffer_size=None,
                   prefetch_buffer_size=None,
                   maximum_features_length=None,
                   maximum_labels_length=None):
  """See ``input_fn``."""
  self._initialize(metadata)
  feat_dataset, feat_process_fn = self._get_features_builder(features_file)
  if labels_file is None:
    dataset = feat_dataset
    # Parallel inputs must be caught in a single tuple and not considered as multiple arguments.
    process_fn = lambda *arg: feat_process_fn(item_or_tuple(arg))
  else:
    labels_dataset, labels_process_fn = self._get_labels_builder(labels_file)
    dataset = tf.data.Dataset.zip((feat_dataset, labels_dataset))
    process_fn = lambda features, labels: (
        feat_process_fn(features), labels_process_fn(labels))
    dataset, process_fn = self._augment_parallel_dataset(dataset, process_fn, mode=mode)
  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = data.training_pipeline(
        dataset,
        batch_size,
        batch_type=batch_type,
        batch_multiplier=batch_multiplier,
        bucket_width=bucket_width,
        single_pass=single_pass,
        process_fn=process_fn,
        num_threads=num_threads,
        shuffle_buffer_size=sample_buffer_size,
        prefetch_buffer_size=prefetch_buffer_size,
        dataset_size=self._get_dataset_size(features_file),
        maximum_features_length=maximum_features_length,
        maximum_labels_length=maximum_labels_length,
        features_length_fn=self._get_features_length,
        labels_length_fn=self._get_labels_length)
  else:
    dataset = data.inference_pipeline(
        dataset,
        batch_size,
        process_fn=process_fn,
        num_threads=num_threads,
        prefetch_buffer_size=prefetch_buffer_size)
  iterator = dataset.make_initializable_iterator()
  # Add the initializer to a standard collection for it to be initialized.
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  return iterator.get_next()
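# Self-contained sketch of the parallel-dataset pattern used in the labels_file branch
# above: two line-based datasets are zipped so that each element is a (features, labels)
# pair, and a single process_fn transforms both sides at once. File names are
# illustrative only.
import tensorflow as tf

src = tf.data.TextLineDataset("src-train.txt")
tgt = tf.data.TextLineDataset("tgt-train.txt")
parallel = tf.data.Dataset.zip((src, tgt))
parallel = parallel.map(
    lambda features, labels: (tf.string_split([features]).values,
                              tf.string_split([labels]).values))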
def make_training_dataset(self,
                          features_file,
                          labels_file,
                          batch_size,
                          batch_type="examples",
                          batch_multiplier=1,
                          batch_size_multiple=1,
                          shuffle_buffer_size=None,
                          bucket_width=None,
                          maximum_features_length=None,
                          maximum_labels_length=None,
                          single_pass=False,
                          num_shards=1,
                          shard_index=0,
                          num_threads=4,
                          prefetch_buffer_size=None):
  """Builds a dataset to be used for training.

  It supports the full training pipeline, including:

  * sharding
  * shuffling
  * filtering
  * bucketing
  * prefetching

  Args:
    features_file: The training source file.
    labels_file: The training target file.
    batch_size: The batch size to use.
    batch_type: The training batching strategy to use: can be "examples" or
      "tokens".
    batch_multiplier: The batch size multiplier to prepare splitting across
      replicated graph parts.
    batch_size_multiple: When :obj:`batch_type` is "tokens", ensure that the
      resulting batch size is a multiple of this value.
    shuffle_buffer_size: The number of elements from which to sample.
    bucket_width: The width of the length buckets to select batch candidates
      from (for efficiency). Set ``None`` to not constrain batch formation.
    maximum_features_length: The maximum length or list of maximum lengths of
      the features sequence(s). ``None`` to not constrain the length.
    maximum_labels_length: The maximum length of the labels sequence.
      ``None`` to not constrain the length.
    single_pass: If ``True``, makes a single pass over the training data.
    num_shards: The number of data shards (usually the number of workers in a
      distributed setting).
    shard_index: The shard index this data pipeline should read from.
    num_threads: The number of elements processed in parallel.
    prefetch_buffer_size: The number of batches to prefetch asynchronously. If
      ``None``, use an automatically tuned value on TensorFlow 1.8+ and 1 on
      older versions.

  Returns:
    A ``tf.data.Dataset``.
  """
  dataset_size = self.features_inputter.get_dataset_size(features_file)
  map_func = lambda *arg: self.make_features(arg, training=True)
  dataset = self.make_dataset([features_file, labels_file], training=True)
  dataset = training_pipeline(
      dataset,
      batch_size,
      batch_type=batch_type,
      batch_multiplier=batch_multiplier,
      bucket_width=bucket_width,
      single_pass=single_pass,
      process_fn=map_func,
      num_threads=num_threads,
      shuffle_buffer_size=shuffle_buffer_size,
      prefetch_buffer_size=prefetch_buffer_size,
      dataset_size=dataset_size,
      maximum_features_length=maximum_features_length,
      maximum_labels_length=maximum_labels_length,
      features_length_fn=self.features_inputter.get_length,
      labels_length_fn=self.labels_inputter.get_length,
      batch_size_multiple=batch_size_multiple,
      num_shards=num_shards,
      shard_index=shard_index)
  return dataset
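# Usage sketch (assumptions: ``example_inputter`` is an already initialized example
# inputter exposing the method above; file names and hyperparameters are illustrative).
# It exercises the token-level batching, bucketing, and sharding options documented in
# the docstring.
dataset = example_inputter.make_training_dataset(
    "src-train.txt",
    "tgt-train.txt",
    batch_size=3072,
    batch_type="tokens",        # token-level batching
    batch_size_multiple=8,      # e.g. keep batch sizes a multiple of 8 for float16
    bucket_width=1,
    shuffle_buffer_size=500000,
    maximum_features_length=100,
    maximum_labels_length=100,
    num_shards=2,               # two workers reading disjoint shards
    shard_index=0)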