def _testBatchTrainDataset(self, check_fn, batch_size, **kwargs):
    # Generate correlated random values that stand in for feature/label sequence
    # lengths (the identity length_fn below uses the values themselves as lengths).
    num_examples = 1000
    features = tf.random.normal([num_examples], mean=12, stddev=6, seed=42)
    labels_diff = tf.random.normal([num_examples], mean=0, stddev=3, seed=42)
    labels = features + labels_diff

    # Truncate to integers and clamp to a minimum length of 1.
    features = tf.maximum(tf.cast(1, tf.int32), tf.cast(features, tf.int32))
    labels = tf.maximum(tf.cast(1, tf.int32), tf.cast(labels, tf.int32))

    # Zip the two streams and apply the batching transformation under test,
    # forwarding any extra options (batch_type, length_bucket_width, ...).
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(features),
        tf.data.Dataset.from_tensor_slices(labels)))
    dataset = dataset.apply(dataset_util.batch_sequence_dataset(
        batch_size,
        length_fn=[lambda x: x, lambda x: x],
        **kwargs))

    iterator = iter(dataset)
    check_fn(iterator)
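
# A minimal, hypothetical driver for the helper above (not part of the original
# test suite). It assumes the enclosing class is a tf.test.TestCase and only
# checks basic invariants of example-count batching.
def testBatchTrainDatasetExamples(self):
    def _check_fn(iterator):
        for _ in range(10):
            features, labels = next(iterator)
            # Features and labels stay aligned, and no batch holds more than
            # the requested number of examples.
            self.assertEqual(features.shape[0], labels.shape[0])
            self.assertLessEqual(features.shape[0], 64)

    self._testBatchTrainDataset(_check_fn, 64, batch_type="examples")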
Example #2
    def make_training_dataset(
        self,
        features_file,
        labels_file,
        batch_size,
        batch_type="examples",
        batch_multiplier=1,
        batch_size_multiple=1,
        shuffle_buffer_size=None,
        length_bucket_width=None,
        maximum_features_length=None,
        maximum_labels_length=None,
        single_pass=False,
        num_shards=1,
        shard_index=0,
        num_threads=4,
        prefetch_buffer_size=None,
        cardinality_multiple=1,
        weights=None,
        batch_autotune_mode=False,
    ):
        """Builds a dataset to be used for training. It supports the full training
        pipeline, including:

        * sharding
        * shuffling
        * filtering
        * bucketing
        * prefetching

        Args:
          features_file: The source file or a list of training source files.
          labels_file: The target file or a list of training target files.
          batch_size: The batch size to use.
          batch_type: The training batching strategy to use: can be "examples" or
            "tokens".
          batch_multiplier: The batch size multiplier to prepare splitting across
            replicated graph parts.
          batch_size_multiple: When :obj:`batch_type` is "tokens", ensure that the
            resulting batch size is a multiple of this value.
          shuffle_buffer_size: The number of elements from which to sample.
          length_bucket_width: The width of the length buckets to select batch
            candidates from (for efficiency). Set ``None`` to not constrain batch
            formation.
          maximum_features_length: The maximum length or list of maximum lengths of
            the features sequence(s). ``None`` to not constrain the length.
          maximum_labels_length: The maximum length of the labels sequence.
            ``None`` to not constrain the length.
          single_pass: If ``True``, makes a single pass over the training data.
          num_shards: The number of data shards (usually the number of workers in a
            distributed setting).
          shard_index: The shard index this data pipeline should read from.
          num_threads: The number of elements processed in parallel.
          prefetch_buffer_size: The number of batches to prefetch asynchronously. If
            ``None``, use an automatically tuned value.
          cardinality_multiple: Ensure that the dataset cardinality is a multiple of
            this value when :obj:`single_pass` is ``True``.
          weights: An optional list of weights to create a weighted dataset out of
            multiple training files.
          batch_autotune_mode: When enabled, all batches are padded to the maximum
            sequence length.

        Returns:
          A ``tf.data.Dataset``.

        See Also:
          :func:`opennmt.data.training_pipeline`
        """
        # With a labels file, build a parallel (features, labels) pipeline;
        # otherwise read a single stream of features.
        if labels_file is not None:
            data_files = [features_file, labels_file]
            maximum_length = [maximum_features_length, maximum_labels_length]
            features_length_fn = self.features_inputter.get_length
            labels_length_fn = self.labels_inputter.get_length
        else:
            data_files = features_file
            maximum_length = maximum_features_length
            features_length_fn = self.get_length
            labels_length_fn = None

        dataset = self.make_dataset(data_files, training=True)

        # Convert raw dataset elements into model features, then filter out
        # examples that exceed the configured maximum lengths.
        map_fn = lambda *arg: self.make_features(
            element=misc.item_or_tuple(arg), training=True)
        filter_fn = lambda *arg: (self.keep_for_training(
            misc.item_or_tuple(arg), maximum_length=maximum_length))
        transform_fns = [
            lambda dataset: dataset.map(map_fn,
                                        num_parallel_calls=num_threads or 4),
            lambda dataset: dataset.filter(filter_fn),
        ]

        if batch_autotune_mode:
            # In this mode we want to return batches where all sequences are padded
            # to the maximum possible length in order to maximize the memory usage.
            # Shuffling, sharding, prefetching, etc. are not applied since correctness and
            # performance are not important.

            if isinstance(dataset, list):  # Ignore weighted dataset.
                dataset = dataset[0]

            # We repeat the dataset now to ensure full batches are always returned.
            dataset = dataset.repeat()
            for transform_fn in transform_fns:
                dataset = dataset.apply(transform_fn)

            # length_fn returns the maximum length instead of the actual example length so
            # that batches are built as if each example has the maximum length.
            if labels_file is not None:
                constant_length_fn = [
                    lambda x: maximum_features_length,
                    lambda x: maximum_labels_length,
                ]
            else:
                constant_length_fn = lambda x: maximum_features_length

            # The length dimension is set to the maximum length in the padded shapes.
            padded_shapes = self.get_padded_shapes(
                dataset.element_spec, maximum_length=maximum_length)
            dataset = dataset.apply(
                dataset_util.batch_sequence_dataset(
                    batch_size,
                    batch_type=batch_type,
                    batch_multiplier=batch_multiplier,
                    length_bucket_width=1,
                    length_fn=constant_length_fn,
                    padded_shapes=padded_shapes,
                ))
            return dataset

        # Standard path: if sampling weights are given, pass them alongside the
        # dataset(s) so the training pipeline builds a weighted dataset; the
        # pipeline then handles sharding, shuffling, filtering, bucketing,
        # batching, and prefetching.
        if weights is not None:
            dataset = (dataset, weights)
        dataset = dataset_util.training_pipeline(
            batch_size,
            batch_type=batch_type,
            batch_multiplier=batch_multiplier,
            batch_size_multiple=batch_size_multiple,
            transform_fns=transform_fns,
            length_bucket_width=length_bucket_width,
            features_length_fn=features_length_fn,
            labels_length_fn=labels_length_fn,
            single_pass=single_pass,
            num_shards=num_shards,
            shard_index=shard_index,
            num_threads=num_threads,
            dataset_size=self.get_dataset_size(data_files),
            shuffle_buffer_size=shuffle_buffer_size,
            prefetch_buffer_size=prefetch_buffer_size,
            cardinality_multiple=cardinality_multiple,
        )(dataset)
        return dataset
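
# A minimal, hypothetical usage sketch (not from the original source). It assumes
# `inputter` is an already-initialized parallel/example inputter exposing the
# make_training_dataset() method above, and that "train.src" / "train.tgt" are
# existing tokenized training files.
dataset = inputter.make_training_dataset(
    "train.src",
    "train.tgt",
    batch_size=4096,
    batch_type="tokens",          # batch budget counted in tokens, not examples
    length_bucket_width=1,        # bucket by length to reduce padding
    maximum_features_length=100,  # drop overly long source examples
    maximum_labels_length=100,    # drop overly long target examples
    shuffle_buffer_size=500000,
)
# Each element is a (features, labels) pair of padded tensors produced by the
# underlying inputters.
for features, labels in dataset.take(1):
    pass

# Illustrative call for the batch_autotune_mode branch above: every batch is
# padded to the configured maximum lengths, e.g. to probe the largest batch
# size that fits in device memory (argument values are assumptions).
autotune_dataset = inputter.make_training_dataset(
    "train.src",
    "train.tgt",
    batch_size=4096,
    batch_type="tokens",
    maximum_features_length=100,
    maximum_labels_length=100,
    batch_autotune_mode=True,
)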