import tensorflow as tf

from opennmt.data import dataset as dataset_util
from opennmt.utils import misc


def _testBatchTrainDataset(self, check_fn, batch_size, **kwargs):
    # Generate 1000 correlated (feature length, label length) pairs, mimicking
    # source/target sequence lengths in parallel training data.
    num_examples = 1000
    features = tf.random.normal([num_examples], mean=12, stddev=6, seed=42)
    labels_diff = tf.random.normal([num_examples], mean=0, stddev=3, seed=42)
    labels = features + labels_diff
    features = tf.maximum(tf.cast(1, tf.int32), tf.cast(features, tf.int32))
    labels = tf.maximum(tf.cast(1, tf.int32), tf.cast(labels, tf.int32))
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(features),
        tf.data.Dataset.from_tensor_slices(labels)))
    # Each element is its own length, so the identity function is used as the
    # length function for both sides.
    dataset = dataset.apply(dataset_util.batch_sequence_dataset(
        batch_size, length_fn=[lambda x: x, lambda x: x], **kwargs))
    iterator = iter(dataset)
    check_fn(iterator)
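
# A minimal companion test sketch (hypothetical; the test name and the exact
# check are assumptions, not part of the original suite). With the default
# "examples" batch type, a full batch should contain exactly `batch_size`
# elements:
def testBatchTrainDatasetExamples(self):
    def _check_fn(iterator):
        features, labels = next(iterator)
        self.assertEqual(features.shape[0], 64)

    self._testBatchTrainDataset(_check_fn, 64)
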
def make_training_dataset(
    self,
    features_file,
    labels_file,
    batch_size,
    batch_type="examples",
    batch_multiplier=1,
    batch_size_multiple=1,
    shuffle_buffer_size=None,
    length_bucket_width=None,
    maximum_features_length=None,
    maximum_labels_length=None,
    single_pass=False,
    num_shards=1,
    shard_index=0,
    num_threads=4,
    prefetch_buffer_size=None,
    cardinality_multiple=1,
    weights=None,
    batch_autotune_mode=False,
):
    """Builds a dataset to be used for training.

    It supports the full training pipeline, including:

    * sharding
    * shuffling
    * filtering
    * bucketing
    * prefetching

    Args:
        features_file: The source file or a list of training source files.
        labels_file: The target file or a list of training target files.
        batch_size: The batch size to use.
        batch_type: The training batching strategy to use: can be "examples"
            or "tokens".
        batch_multiplier: The batch size multiplier to prepare splitting
            across replicated graph parts.
        batch_size_multiple: When :obj:`batch_type` is "tokens", ensure that
            the resulting batch size is a multiple of this value.
        shuffle_buffer_size: The number of elements from which to sample.
        length_bucket_width: The width of the length buckets to select batch
            candidates from (for efficiency). Set ``None`` to not constrain
            batch formation.
        maximum_features_length: The maximum length or list of maximum lengths
            of the features sequence(s). ``None`` to not constrain the length.
        maximum_labels_length: The maximum length of the labels sequence.
            ``None`` to not constrain the length.
        single_pass: If ``True``, makes a single pass over the training data.
        num_shards: The number of data shards (usually the number of workers
            in a distributed setting).
        shard_index: The shard index this data pipeline should read from.
        num_threads: The number of elements processed in parallel.
        prefetch_buffer_size: The number of batches to prefetch
            asynchronously. If ``None``, use an automatically tuned value.
        cardinality_multiple: Ensure that the dataset cardinality is a
            multiple of this value when :obj:`single_pass` is ``True``.
        weights: An optional list of weights to create a weighted dataset out
            of multiple training files.
        batch_autotune_mode: When enabled, all batches are padded to the
            maximum sequence length.

    Returns:
        A ``tf.data.Dataset``.

    See Also:
        :func:`opennmt.data.training_pipeline`
    """
    if labels_file is not None:
        data_files = [features_file, labels_file]
        maximum_length = [maximum_features_length, maximum_labels_length]
        features_length_fn = self.features_inputter.get_length
        labels_length_fn = self.labels_inputter.get_length
    else:
        data_files = features_file
        maximum_length = maximum_features_length
        features_length_fn = self.get_length
        labels_length_fn = None
    dataset = self.make_dataset(data_files, training=True)
    map_fn = lambda *arg: self.make_features(
        element=misc.item_or_tuple(arg), training=True)
    filter_fn = lambda *arg: self.keep_for_training(
        misc.item_or_tuple(arg), maximum_length=maximum_length)
    transform_fns = [
        lambda dataset: dataset.map(map_fn, num_parallel_calls=num_threads or 4),
        lambda dataset: dataset.filter(filter_fn),
    ]
    if batch_autotune_mode:
        # In this mode we want to return batches where all sequences are
        # padded to the maximum possible length in order to maximize the
        # memory usage. Shuffling, sharding, prefetching, etc. are not applied
        # since correctness and performance are not important.
        if isinstance(dataset, list):  # Ignore weighted dataset.
            dataset = dataset[0]
        # We repeat the dataset now to ensure full batches are always
        # returned.
        dataset = dataset.repeat()
        for transform_fn in transform_fns:
            dataset = dataset.apply(transform_fn)
        # length_fn returns the maximum length instead of the actual example
        # length so that batches are built as if each example has the maximum
        # length.
        if labels_file is not None:
            constant_length_fn = [
                lambda x: maximum_features_length,
                lambda x: maximum_labels_length,
            ]
        else:
            constant_length_fn = lambda x: maximum_features_length
        # The length dimension is set to the maximum length in the padded
        # shapes.
        padded_shapes = self.get_padded_shapes(
            dataset.element_spec, maximum_length=maximum_length)
        dataset = dataset.apply(
            dataset_util.batch_sequence_dataset(
                batch_size,
                batch_type=batch_type,
                batch_multiplier=batch_multiplier,
                length_bucket_width=1,
                length_fn=constant_length_fn,
                padded_shapes=padded_shapes,
            ))
        return dataset

    if weights is not None:
        dataset = (dataset, weights)
    dataset = dataset_util.training_pipeline(
        batch_size,
        batch_type=batch_type,
        batch_multiplier=batch_multiplier,
        batch_size_multiple=batch_size_multiple,
        transform_fns=transform_fns,
        length_bucket_width=length_bucket_width,
        features_length_fn=features_length_fn,
        labels_length_fn=labels_length_fn,
        single_pass=single_pass,
        num_shards=num_shards,
        shard_index=shard_index,
        num_threads=num_threads,
        dataset_size=self.get_dataset_size(data_files),
        shuffle_buffer_size=shuffle_buffer_size,
        prefetch_buffer_size=prefetch_buffer_size,
        cardinality_multiple=cardinality_multiple,
    )(dataset)
    return dataset
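
# A minimal usage sketch (hypothetical caller code, not part of this module).
# It assumes `inputter` is an already-initialized example inputter and that
# "train.en"/"train.de" are placeholder names for the training files:
def _demo_training_dataset(inputter):
    dataset = inputter.make_training_dataset(
        "train.en",
        "train.de",
        batch_size=4096,
        batch_type="tokens",
        shuffle_buffer_size=500000,
        length_bucket_width=1,
        maximum_features_length=100,
        maximum_labels_length=100,
    )
    for features, labels in dataset.take(1):
        # Each element is a pair of feature/label structures built by
        # make_features.
        print(features, labels)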