Exemplo n.º 1
0
    def create_and_batch_tfds(self,
                              ds: Dataset,
                              mode,
                              args=None,
                              num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Creates a batched dataset of "tokens" according to the `mode`.

        Evaluation data is cached and padded into batches without dropping the
        remainder. Training data is filtered by `max_len`, optionally cached and
        shuffled, and then batched either with one static shape
        (`GPU_EFFICIENT_LEVEL.LEVEL5`) or by length buckets. Inference is not
        implemented.

        Args:
            ds: A neurst.data.datasets.Dataset object.
            mode: A ModeKeys indicating the running mode. Any mode other than
                INFER and EVAL is handled as training.
            args: A dict containing dataset arguments. When given, it is combined
                with `self._args` through `deep_merge_dict`.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by number of replicas.

        Returns:
            A tf.data.Dataset.

        Raises:
            NotImplementedError: If `mode` is ModeKeys.INFER.
            RuntimeError: If `max_len` is None when building the training dataset.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)

        # Reject unsupported / misconfigured requests before any input pipeline
        # is constructed.
        if mode == compat.ModeKeys.INFER:
            raise NotImplementedError
        is_eval = (mode == compat.ModeKeys.EVAL)
        if not is_eval and args["max_len"] is None:
            raise RuntimeError("Must provide `max_len` for training.")

        pad = tf.constant(self._data_pipeline.meta["pad_id"], dtype=tf.int64)
        # The pipeline is built exactly once. Every mode that is not EVAL ends up
        # on the training path, which always shards and shuffles.
        is_training = not is_eval
        dataset = ds.build(
            map_func=self.get_data_preprocess_fn(mode, ds.status, args),
            map_output_dtypes=self.inputs_signature(mode)[0],
            auto_shard=is_training,
            shuffle=is_training)

        if is_eval:
            logging.info("Creating evaluation dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={"tokens": [None]},
                padding_values={"tokens": pad},
                drop_remainder=False)

        logging.info("Creating training dataset.")
        level = args.get("gpu_efficient_level", None)
        logging.info(
            f"Creating training dataset with GPU efficient level={level}.")
        dataset = dataset_utils.clean_dataset_by_length(
            dataset, {"tokens": args["max_len"]})
        if args["cache_dataset"]:
            dataset = dataset.cache()
        if args["shuffle_buffer"]:
            dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
        padding_values = {"tokens": pad}

        # NOTE(review): `minimal_multiple` presumably rounds up to a multiple of
        # the level's multiplier -- confirm against its definition.
        multiplier = EFFICIENT_MULTIPLIER[level]
        max_len = minimal_multiple(args["max_len"], multiplier)
        batch_size = dataset_utils.adjust_batch_size(
            args["batch_size"],
            args["batch_size_per_gpu"],
            num_replicas_in_sync=num_replicas_in_sync,
            verbose=False)

        if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
            _batch_size = batch_size
            if args["batch_by_tokens"]:
                # `batch_size` counts tokens here; convert it to a number of
                # examples of length `max_len`.
                _batch_size = _batch_size // max_len
            logging.info(
                f"Batching dataset with fixed shape: batch={_batch_size} x {max_len})."
            )
            # Round down so the global batch splits evenly across replicas.
            return dataset.padded_batch(
                _batch_size // num_replicas_in_sync * num_replicas_in_sync,
                padded_shapes={"tokens": [max_len]},
                padding_values=padding_values,
                drop_remainder=True)

        # Bucketed batching: one bucket per multiple of `multiplier` up to `max_len`.
        bucket_boundaries = [
            multiplier * i for i in range(1, max_len // multiplier + 1)
        ]
        if bucket_boundaries[-1] < max_len:
            bucket_boundaries.append(
                minimal_multiple(bucket_boundaries[-1] + 1, multiplier))
        bucket_boundaries = {"tokens": bucket_boundaries}
        bucket_batch_sizes = dataset_utils.adjust_batch_size(
            batch_size,
            bucket_boundaries=bucket_boundaries
            if args["batch_by_tokens"] else None,
            boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
            num_replicas_in_sync=num_replicas_in_sync)

        if level != GPU_EFFICIENT_LEVEL.LEVEL0:

            def _align_batch_size(size):
                # Align the per-replica size with the level's multiplier, then
                # scale back to a global batch size.
                return int(
                    maximum_lower_multiple(size // num_replicas_in_sync,
                                           multiplier) * num_replicas_in_sync)

            # `adjust_batch_size` may yield one size per bucket or a single size.
            if isinstance(bucket_batch_sizes, list):
                bucket_batch_sizes = [
                    _align_batch_size(x) for x in bucket_batch_sizes
                ]
            else:
                bucket_batch_sizes = _align_batch_size(bucket_batch_sizes)

        return dataset_utils.batch_examples_by_token(
            dataset,
            bucket_boundaries=bucket_boundaries,
            bucket_batch_sizes=bucket_batch_sizes,
            padding_values=padding_values,
            example_length_func=lambda x:
            {k: tf.size(v)
             for k, v in x.items()})
Exemplo n.º 2
0
    def create_and_batch_tfds(self,
                              ds: Dataset,
                              mode,
                              args=None,
                              num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Builds the batched "feature"/"label" dataset for the given `mode`.

        INFER batches only "feature"; EVAL batches "feature" and "label"; both
        cache the data and keep the last partial batch. Every other mode is
        treated as training: examples are filtered by length, optionally cached
        and shuffled, and finally batched by length buckets.

        Args:
            ds: A neurst.data.datasets.Dataset object; it must be an
                AbstractParallelDataset.
            mode: A ModeKeys indicating the running mode.
            args: A dict containing dataset arguments. When given, it is combined
                with `self._args` through `deep_merge_dict`.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by number of replicas.

        Returns:
            A tf.data.Dataset.

        Raises:
            AssertionError: If `ds` is not an AbstractParallelDataset.
            RuntimeError: If `max_src_len` or `max_trg_len` is None for training.
        """
        cfg = (self._args if args is None else deep_merge_dict(
            self._args, args, local_overwrite=False))
        # Both sides are padded with their own end-of-sentence id.
        src_pad = tf.constant(self._src_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)
        trg_pad = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)

        assert isinstance(ds, AbstractParallelDataset), (
            "The dataset for SeqToSeq task must inherit AbstractParallelDataset."
        )

        preprocess_fn = self.get_data_preprocess_fn(mode, ds.status, cfg)
        output_dtypes = self.inputs_signature(mode)[0]
        is_training = (mode == compat.ModeKeys.TRAIN)
        pipeline = ds.build(map_func=preprocess_fn,
                            map_output_dtypes=output_dtypes,
                            auto_shard=is_training,
                            shuffle=is_training)

        # Select which fields get padded for the non-training modes.
        if mode == compat.ModeKeys.INFER:
            logging.info("Creating test dataset.")
            pad_by_field = {"feature": src_pad}
        elif mode == compat.ModeKeys.EVAL:
            logging.info("Creating evaluation dataset.")
            pad_by_field = {"feature": src_pad, "label": trg_pad}
        else:
            pad_by_field = None

        if pad_by_field is not None:
            cached = pipeline.cache()
            global_batch = dataset_utils.adjust_batch_size(
                cfg["batch_size"], num_replicas_in_sync=num_replicas_in_sync)
            return cached.padded_batch(
                global_batch,
                padded_shapes={field: [None] for field in pad_by_field},
                padding_values=pad_by_field,
                drop_remainder=False)

        logging.info("Creating training dataset.")
        length_limits = {
            "feature": cfg["max_src_len"],
            "label": cfg["max_trg_len"]
        }
        pipeline = dataset_utils.clean_dataset_by_length(
            pipeline, length_limits)
        if cfg["cache_dataset"]:
            pipeline = pipeline.cache()
        if cfg["shuffle_buffer"]:
            pipeline = pipeline.shuffle(buffer_size=cfg["shuffle_buffer"])
        pad_by_field = {"feature": src_pad, "label": trg_pad}
        if cfg["max_src_len"] is None:
            raise RuntimeError("Must provide `max_src_len` for training.")
        if cfg["max_trg_len"] is None:
            raise RuntimeError("Must provide `max_trg_len` for training.")

        src_candidates = dataset_utils.create_batch_bucket_boundaries(
            cfg["max_src_len"])
        trg_candidates = dataset_utils.create_batch_bucket_boundaries(
            cfg["max_trg_len"])
        src_bounds, trg_bounds = dataset_utils.associated_bucket_boundaries(
            src_candidates, trg_candidates)
        boundaries = {"feature": src_bounds, "label": trg_bounds}

        def _longest(nested):
            """ Largest value found anywhere in a nested boundary structure. """
            return max(tf.nest.flatten(nested))

        def _field_sizes(example):
            """ Number of elements of every field of one example. """
            return {name: tf.size(value) for name, value in example.items()}

        batch_sizes = dataset_utils.adjust_batch_size(
            cfg["batch_size"],
            cfg["batch_size_per_gpu"],
            bucket_boundaries=boundaries if cfg["batch_by_tokens"] else None,
            boundaries_reduce_to_length_fn=_longest,
            num_replicas_in_sync=num_replicas_in_sync)
        return dataset_utils.batch_examples_by_token(
            pipeline,
            bucket_boundaries=boundaries,
            bucket_batch_sizes=batch_sizes,
            padding_values=pad_by_field,
            example_length_func=_field_sizes)
Exemplo n.º 3
0
 def create_and_batch_tfds(self, ds, mode, args=None, num_replicas_in_sync=1):
     """ Builds the training dataset batched according to `gpu_efficient_level`.

     INFER and EVAL modes are delegated unchanged to the parent implementation.
     For every other mode the examples are filtered by length, optionally
     cached and shuffled, and then batched either with one static shape
     (`GPU_EFFICIENT_LEVEL.LEVEL5`) or by length buckets.

     Args:
         ds: The dataset object providing `build` and `status`.
         mode: A ModeKeys indicating the running mode.
         args: A dict containing dataset arguments. When given, it is combined
             with `self._args` through `deep_merge_dict`.
         num_replicas_in_sync: The number of replicas the global batch is split across.

     Returns:
         The batched dataset.

     Raises:
         RuntimeError: If `max_src_len` or `max_trg_len` is None.
     """
     if mode in (compat.ModeKeys.INFER, compat.ModeKeys.EVAL):
         return super(Translation, self).create_and_batch_tfds(
             ds, mode, args, num_replicas_in_sync)

     cfg = (self._args if args is None
            else deep_merge_dict(self._args, args, local_overwrite=False))
     level = cfg.get("gpu_efficient_level", None)
     auto_scale_batch = cfg.get("auto_scaling_batch_size", None)
     logging.info(f"Creating training dataset with GPU efficient level={level}.")

     preprocess_fn = self.get_data_preprocess_fn(mode, ds.status, cfg)
     output_dtypes = self.inputs_signature(mode)[0]
     pipeline = ds.build(map_func=preprocess_fn,
                         map_output_dtypes=output_dtypes,
                         auto_shard=True, shuffle=True)
     length_limits = {"feature": cfg["max_src_len"], "label": cfg["max_trg_len"]}
     pipeline = dataset_utils.clean_dataset_by_length(pipeline, length_limits)
     if cfg["cache_dataset"]:
         pipeline = pipeline.cache()
     if cfg["shuffle_buffer"]:
         pipeline = pipeline.shuffle(buffer_size=cfg["shuffle_buffer"])

     # Source and target are padded with their own end-of-sentence id.
     src_pad = tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64)
     trg_pad = tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)
     pad_by_field = {"feature": src_pad, "label": trg_pad}

     if cfg["max_src_len"] is None:
         raise RuntimeError("Must provide `max_src_len` for training.")
     if cfg["max_trg_len"] is None:
         raise RuntimeError("Must provide `max_trg_len` for training.")

     multiplier = EFFICIENT_MULTIPLIER[level]
     padded_src_len = minimal_multiple(cfg["max_src_len"], multiplier)
     padded_trg_len = minimal_multiple(cfg["max_trg_len"], multiplier)
     longest = max(padded_src_len, padded_trg_len)
     global_batch = dataset_utils.adjust_batch_size(
         cfg["batch_size"], cfg["batch_size_per_gpu"],
         num_replicas_in_sync=num_replicas_in_sync,
         verbose=False)
     if auto_scale_batch:
         global_batch = _auto_scale_batch_size(global_batch, level)
         logging.info(f"Auto scaling batch size to {global_batch}.")

     if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
         # With token-based batching the budget is converted to a number of
         # examples of the longest padded length.
         examples_per_batch = (global_batch // longest
                               if cfg["batch_by_tokens"] else global_batch)
         logging.info("Batching dataset with fixed shape: "
                      f"batch={examples_per_batch} x (feature={padded_src_len}, label={padded_trg_len}).")
         return pipeline.padded_batch(
             examples_per_batch // num_replicas_in_sync * num_replicas_in_sync,
             padded_shapes={"feature": [padded_src_len], "label": [padded_trg_len]},
             drop_remainder=True, padding_values=pad_by_field)

     def _boundaries_up_to(limit):
         """ Multiples of `multiplier` from one multiplier up to `limit`. """
         bounds = [multiplier * step
                   for step in range(1, limit // multiplier + 1)]
         if bounds[-1] < limit:
             bounds.append(minimal_multiple(bounds[-1] + 1, multiplier))
         return bounds

     def _align_batch_size(size):
         """ Aligns the per-replica share of `size` with `multiplier`. """
         per_replica = maximum_lower_multiple(size // num_replicas_in_sync,
                                              multiplier)
         return int(per_replica * num_replicas_in_sync)

     src_bounds = _boundaries_up_to(padded_src_len)
     trg_bounds = _boundaries_up_to(padded_trg_len)
     src_bounds, trg_bounds = dataset_utils.associated_bucket_boundaries(
         src_bounds, trg_bounds)
     boundaries = {"feature": src_bounds, "label": trg_bounds}

     def _longest_boundary(nested):
         return max(tf.nest.flatten(nested))

     batch_sizes = dataset_utils.adjust_batch_size(
         global_batch,
         bucket_boundaries=boundaries if cfg["batch_by_tokens"] else None,
         boundaries_reduce_to_length_fn=_longest_boundary,
         num_replicas_in_sync=num_replicas_in_sync)
     if level != GPU_EFFICIENT_LEVEL.LEVEL0:
         # `adjust_batch_size` may yield one size per bucket or a single size.
         if isinstance(batch_sizes, list):
             batch_sizes = [_align_batch_size(size) for size in batch_sizes]
         else:
             batch_sizes = _align_batch_size(batch_sizes)

     def _field_sizes(example):
         return {name: tf.size(value) for name, value in example.items()}

     return dataset_utils.batch_examples_by_token(
         pipeline,
         bucket_boundaries=boundaries,
         bucket_batch_sizes=batch_sizes,
         padding_values=pad_by_field,
         example_length_func=_field_sizes
     )
Exemplo n.º 4
0
    def create_and_batch_tfds(self,
                              ds: Dataset,
                              mode,
                              args=None,
                              num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Creates a batched "audio"/"audio_length"/"transcript" dataset according to the `mode`.

        INFER and EVAL cache the data and pad it into plain batches (EVAL also
        carries "transcript"). Every other mode is treated as training: examples
        are filtered by length, optionally cached and shuffled, and grouped into
        buckets keyed on audio length (and on transcript length as well when
        `experimental_frame_transcript_ratio` is set).

        Args:
            ds: A neurst.data.datasets.Dataset object.
            mode: A ModeKeys indicating the running mode.
            args: A dict containing dataset arguments. When given, it is combined
                with `self._args` through `deep_merge_dict`.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by number of replicas.

        Returns:
            A tf.data.Dataset.

        Raises:
            RuntimeError: If `max_src_len` or `max_trg_len` is None for training.
            AssertionError: If the per-replica batch size is not greater than
                `max_src_len` for training.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        float_zero = tf.constant(0, dtype=tf.float32)
        int_zero = tf.constant(0, dtype=tf.int64)
        trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)

        dataset = ds.build(map_func=self.get_data_preprocess_fn(
            mode, ds.status, args),
                           map_output_dtypes=self.inputs_signature(mode)[0],
                           auto_shard=(mode == compat.ModeKeys.TRAIN),
                           shuffle=(mode == compat.ModeKeys.TRAIN))

        if mode == compat.ModeKeys.INFER:
            logging.info("Creating test dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={
                    "audio": [None],
                    "audio_length": []
                },
                padding_values={
                    "audio": float_zero,
                    "audio_length": int_zero
                },
                drop_remainder=False)

        elif mode == compat.ModeKeys.EVAL:
            logging.info("Creating evaluation dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={
                    "audio": [None],
                    "audio_length": [],
                    "transcript": [None]
                },
                padding_values={
                    "audio": float_zero,
                    "audio_length": int_zero,
                    "transcript": trg_eos
                },
                drop_remainder=False)
        else:
            logging.info("Creating training dataset.")
            # Validate the length limits before they are used: `max_src_len` is
            # multiplied below, so a None value must be rejected here to surface
            # the intended RuntimeError instead of a TypeError.
            if args["max_src_len"] is None:
                raise RuntimeError(
                    "`max_src_len` for SpeechToText task must be provided.")
            if args["max_trg_len"] is None:
                raise RuntimeError(
                    "`max_trg_len` for SpeechToText task must be provided.")
            max_src_len = args["max_src_len"]
            # The "audio" limit is expressed in elements: frames x feature dim x
            # channels. NOTE(review): -1 presumably disables the limit for
            # "audio_length" -- confirm against `clean_dataset_by_length`.
            dataset = dataset_utils.clean_dataset_by_length(
                dataset, {
                    "audio": max_src_len * self._audio_feature_dim *
                    self._audio_feature_channels,
                    "audio_length": -1,
                    "transcript": args["max_trg_len"]
                })
            if args["cache_dataset"]:
                dataset = dataset.cache()
            if args["shuffle_buffer"]:
                dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
            padding_values = {
                "audio": float_zero,
                "audio_length": int_zero,
                "transcript": trg_eos
            }
            max_trg_len = minimal_multiple(args["max_trg_len"], 8)
            audio_bucket_boundaries = create_audio_bucket_boundaries(
                max_src_len, args["min_src_bucket_boundary"])
            audio_bucket_boundaries[-1] = minimal_multiple(
                audio_bucket_boundaries[-1], 8)
            batch_size = dataset_utils.adjust_batch_size(
                args["batch_size"],
                args["batch_size_per_gpu"],
                num_replicas_in_sync=num_replicas_in_sync,
                verbose=False)
            batch_size_per_gpu = batch_size // num_replicas_in_sync
            assert batch_size_per_gpu > max_src_len, (
                f"batch size per gpu({batch_size_per_gpu} must be greater than "
                f"`max_src_len`={max_src_len}")
            # One global batch size per audio bucket: the per-replica budget is
            # divided by the bucket's frame boundary.
            if args["disable_batch_efficiency"]:
                bucket_batch_sizes = [
                    int(batch_size_per_gpu // bound * num_replicas_in_sync)
                    for bound in audio_bucket_boundaries
                ]
            else:
                bucket_batch_sizes = [
                    int(
                        minimal_multiple(batch_size_per_gpu // bound, 8) *
                        num_replicas_in_sync)
                    for bound in audio_bucket_boundaries
                ]
            frame_transcript_ratio = args[
                "experimental_frame_transcript_ratio"]
            if frame_transcript_ratio is None:
                logging.info(
                    "WARNING: we recommend one to pre-scan the dataset and estimate the ratio: "
                    "frame length / transcript length.")
            else:
                # Derive a transcript boundary for every audio bucket; the divisor
                # moves linearly from `frame_transcript_ratio` towards
                # `max_src_len / max_trg_len` across the buckets.
                trans_bucket_boundaries = [
                    int(bound /
                        (frame_transcript_ratio + i *
                         (max_src_len / max_trg_len - frame_transcript_ratio) /
                         len(audio_bucket_boundaries)))
                    for i, bound in enumerate(audio_bucket_boundaries)
                ]
                trans_bucket_boundaries = [
                    minimal_multiple(min(i, max_trg_len), 8)
                    for i in trans_bucket_boundaries
                ]
                num_buckets = len(trans_bucket_boundaries)
                # Each audio bucket accepts two transcript lengths: its own
                # boundary and the next bucket's (the last bucket repeats its own).
                true_trans_bucket_boundaries = []
                num_input_shapes = 0
                for idx, tbound in enumerate(trans_bucket_boundaries):
                    max_trans_len = [
                        tbound,
                        trans_bucket_boundaries[min(idx + 1, num_buckets - 1)]
                    ]
                    num_input_shapes += len(set(max_trans_len))
                    true_trans_bucket_boundaries.append(max_trans_len)
                logging.info(
                    f"There are {num_input_shapes} input shapes to be compiled:"
                )
                for batc, bound, tbound in zip(bucket_batch_sizes,
                                               audio_bucket_boundaries,
                                               true_trans_bucket_boundaries):
                    logging.info(f"   - batch={batc}, maximum-frames={bound}, "
                                 f"maximum-transcript-length={set(tbound)}")
                # Both tensors have shape [num_buckets, 2] so they can be compared
                # element-wise in `example_to_bucket_id`.
                true_trans_bucket_boundaries = tf.constant(
                    true_trans_bucket_boundaries, dtype=tf.int32)
                true_audio_bucket_boundaries = tf.transpose(
                    tf.constant([audio_bucket_boundaries] * 2, dtype=tf.int32))

            bucket_batch_sizes = tf.constant(bucket_batch_sizes,
                                             dtype=tf.int64)
            audio_bucket_boundaries = tf.constant(audio_bucket_boundaries,
                                                  dtype=tf.int32)

            def example_to_bucket_id(examples):
                """Return int64 bucket id for this example, calculated based on length."""
                if frame_transcript_ratio is None:
                    # Smallest audio bucket whose boundary covers the example.
                    conditions_c = tf.less_equal(
                        tf.cast(examples["audio_length"], tf.int32),
                        audio_bucket_boundaries)
                    return tf.reduce_min(tf.where(conditions_c))
                conditions_c = tf.logical_and(
                    tf.less_equal(tf.cast(examples["audio_length"], tf.int32),
                                  true_audio_bucket_boundaries),
                    tf.less_equal(tf.size(examples["transcript"]),
                                  true_trans_bucket_boundaries))
                # First matching (audio bucket, transcript slot) coordinate,
                # encoded into one id as row * num_buckets + column.
                minimum_match = tf.where(conditions_c)[0]
                return minimum_match[0] * num_buckets + minimum_match[1]

            def window_size_fn(bucket_id):
                """Return number of examples to be grouped when given a bucket id."""
                if frame_transcript_ratio is None:
                    return bucket_batch_sizes[bucket_id]
                # Decode the audio bucket index from the combined id.
                return bucket_batch_sizes[bucket_id // num_buckets]

            def batching_fn(bucket_id, grouped_dataset):
                """Batch and add padding to a dataset of elements with similar lengths."""
                bucket_batch_size = window_size_fn(bucket_id)

                # Batch the dataset and add padding so that all input sequences in the
                # examples have the same length, and all target sequences have the same
                # lengths as well. Resulting lengths of inputs and targets can differ.
                return grouped_dataset.padded_batch(
                    bucket_batch_size,
                    padded_shapes={
                        "audio":
                        ([(audio_bucket_boundaries[bucket_id]
                           if frame_transcript_ratio is None else
                           audio_bucket_boundaries[bucket_id // num_buckets]) *
                          self._audio_feature_dim *
                          self._audio_feature_channels]),
                        "audio_length": [],
                        "transcript":
                        ([None] if frame_transcript_ratio is None else [
                            true_trans_bucket_boundaries[
                                bucket_id // num_buckets][bucket_id %
                                                          num_buckets]
                        ])
                    },
                    padding_values=padding_values,
                    drop_remainder=True)

            return dataset.apply(
                tf.data.experimental.group_by_window(
                    key_func=example_to_bucket_id,
                    reduce_func=batching_fn,
                    window_size=None,
                    window_size_func=window_size_fn))
Exemplo n.º 5
0
    def create_and_batch_tfds(self,
                              ds,
                              mode,
                              args=None,
                              num_replicas_in_sync=1):
        """ Builds the training dataset with shapes aligned to `gpu_efficient_level`.

        INFER/EVAL modes, a missing level and `GPU_EFFICIENT_LEVEL.LEVEL0` are
        delegated to the parent implementation (with the merged arguments).
        Otherwise examples are filtered by length, optionally cached and
        shuffled, and batched either with one static shape
        (`GPU_EFFICIENT_LEVEL.LEVEL5`) or by buckets keyed on the longer of
        "feature" and "label".

        Args:
            ds: The dataset object providing `build` and `status`.
            mode: A ModeKeys indicating the running mode.
            args: A dict containing dataset arguments. When given, it is combined
                with `self._args` through `deep_merge_dict`.
            num_replicas_in_sync: The number of replicas the global batch is split across.

        Returns:
            The batched dataset.

        Raises:
            AssertionError: If `max_src_len` or `max_trg_len` is falsy, or if
                token-based batching is requested with a per-replica batch size
                not greater than both limits.
        """
        cfg = self._args if args is None else deep_merge_dict(self._args, args)
        level = cfg.get("gpu_efficient_level", None)
        auto_scale_batch = cfg.get("auto_scaling_batch_size", None)
        not_training = mode in (compat.ModeKeys.INFER, compat.ModeKeys.EVAL)
        if (not_training or level is None
                or level == GPU_EFFICIENT_LEVEL.LEVEL0):
            return super(Translation,
                         self).create_and_batch_tfds(ds, mode, cfg,
                                                     num_replicas_in_sync)

        # Source and target are padded with their own end-of-sentence id.
        src_pad = tf.constant(self._src_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)
        trg_pad = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)
        pad_by_field = {"feature": src_pad, "label": trg_pad}

        preprocess_fn = self.get_data_preprocess_fn(mode, ds.status, cfg)
        pipeline = ds.build(auto_shard=True,
                            map_func=preprocess_fn,
                            map_output_dtypes=self.inputs_signature(mode)[0])

        src_limit = cfg["max_src_len"]
        trg_limit = cfg["max_trg_len"]
        by_tokens = cfg["batch_by_tokens"]
        assert src_limit, "Must provide `max_src_len` when `gpu_efficient_level` > 0"
        assert trg_limit, "Must provide `max_trg_len` when `gpu_efficient_level` > 0"
        logging.info(
            f"Creating training dataset with `gpu_efficient_level`={level}.")
        pipeline = clean_dataset_by_length(pipeline, {
            "feature": src_limit,
            "label": trg_limit
        })
        if cfg["cache_dataset"]:
            pipeline = pipeline.cache()
        if cfg["shuffle_buffer"]:
            pipeline = pipeline.shuffle(buffer_size=cfg["shuffle_buffer"])

        per_replica = cfg["batch_size_per_gpu"]
        global_batch = cfg["batch_size"]
        if per_replica is None:
            per_replica = global_batch // num_replicas_in_sync
        if by_tokens:
            assert per_replica > max(src_limit, trg_limit), (
                f"batch size per gpu({per_replica} must be greater than "
                f"both `max_src_len`{src_limit} and `max_trg_len`{trg_limit}"
            )
        if auto_scale_batch:
            scaled = _auto_scale_batch_size(per_replica, level)
            logging.info(
                f"Auto scaling `batch_size_per_gpu` from {per_replica} "
                f"to {scaled}")
            per_replica = scaled

        multiplier = EFFICIENT_MULTIPLIER[level]
        padded_src_len = minimal_multiple(src_limit, multiplier)
        padded_trg_len = minimal_multiple(trg_limit, multiplier)
        longest = max(padded_src_len, padded_trg_len)

        def _global_size(examples_per_replica):
            """ Aligns a per-replica size with `multiplier` and scales it to all replicas. """
            return int(
                minimal_multiple(examples_per_replica, multiplier) *
                num_replicas_in_sync)

        if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
            if by_tokens:
                # Convert the token budget into a number of examples.
                per_replica = per_replica // longest
            return pipeline.padded_batch(_global_size(per_replica),
                                         padded_shapes={
                                             "feature": [padded_src_len],
                                             "label": [padded_trg_len]
                                         },
                                         drop_remainder=True,
                                         padding_values=pad_by_field)

        # One bucket per multiple of `multiplier` up to the longest padded length.
        upper_bounds = [
            multiplier * step for step in range(1, longest // multiplier + 1)
        ]
        if upper_bounds[-1] < longest:
            upper_bounds.append(
                minimal_multiple(upper_bounds[-1] + 1, multiplier))
        lower_bounds = [0] + upper_bounds[:-1]
        if by_tokens:
            sizes = [_global_size(per_replica // bound) for bound in upper_bounds]
        else:
            sizes = [_global_size(per_replica)] * len(upper_bounds)

        logging.info(
            f"There are {len(sizes)} input shapes to be compiled:"
        )
        for size, bound in zip(sizes, upper_bounds):
            logging.info(f"   - batch={size}, maximum-length={bound}")
        sizes_t = tf.constant(sizes, dtype=tf.int64)
        upper_bounds_t = tf.constant(upper_bounds, dtype=tf.int32)

        def _bucket_of(example):
            """ Index of the bucket whose (lower, upper] range holds the example. """
            length = tf.cast(
                tf.maximum(tf.size(example["feature"]),
                           tf.size(example["label"])), tf.int32)
            above_lower = tf.less(lower_bounds, length)
            within_upper = tf.less_equal(length, upper_bounds_t)
            matches = tf.logical_and(above_lower, within_upper)
            return tf.reduce_min(tf.where(matches))

        def _window_size(bucket_id):
            """ Number of examples grouped together for the given bucket. """
            return sizes_t[bucket_id]

        def _pad_and_batch(bucket_id, grouped):
            """ Pads both fields of a bucket to its boundary and batches them. """
            bound = upper_bounds_t[bucket_id]
            return grouped.padded_batch(
                sizes_t[bucket_id],
                padded_shapes={
                    "feature": [bound],
                    "label": [bound]
                },
                padding_values=pad_by_field,
                drop_remainder=True)

        grouping = tf.data.experimental.group_by_window(
            key_func=_bucket_of,
            reduce_func=_pad_and_batch,
            window_size=None,
            window_size_func=_window_size)
        return pipeline.apply(grouping)