예제 #1
0
    def create_and_batch_tfds(self,
                              ds: Dataset,
                              mode,
                              args=None,
                              num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Creates a dataset according to the `mode`.

        Args:
            ds: A neurst.data.datasets.Dataset object. Must be an instance of
                AbstractParallelDataset for this task.
            mode: A ModeKeys indicating the running mode.
            args: A dict containing dataset arguments. If None, `self._args` is used
                as-is; otherwise it is combined with `self._args` by `deep_merge_dict`.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by number of replicas.

        Returns:
            A tf.data.Dataset of padded batches:
              - INFER: cached; "feature" padded with the source EOS id; the final
                partial batch is kept.
              - EVAL: cached; "feature"/"label" padded with the source/target EOS ids;
                the final partial batch is kept.
              - any other mode (training): filtered by `max_src_len`/`max_trg_len`,
                optionally cached and shuffled, then bucketed by example length via
                `dataset_utils.batch_examples_by_token`.

        Raises:
            AssertionError: If `ds` is not an AbstractParallelDataset.
            RuntimeError: In training, if `max_src_len` or `max_trg_len` is None.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        # EOS ids are used as padding values for the respective sides.
        src_eos = tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64)
        trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)

        assert isinstance(ds, AbstractParallelDataset), (
            "The dataset for SeqToSeq task must inherit AbstractParallelDataset."
        )

        # Sharding and shuffling at build time are only enabled for training.
        is_training = mode == compat.ModeKeys.TRAIN
        dataset = ds.build(map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                           map_output_dtypes=self.inputs_signature(mode)[0],
                           auto_shard=is_training,
                           shuffle=is_training)

        if mode == compat.ModeKeys.INFER:
            logging.info("Creating test dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={"feature": [None]},
                padding_values={"feature": src_eos},
                drop_remainder=False)
        elif mode == compat.ModeKeys.EVAL:
            logging.info("Creating evaluation dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={"feature": [None], "label": [None]},
                padding_values={"feature": src_eos, "label": trg_eos},
                drop_remainder=False)
        else:
            logging.info("Creating training dataset.")
            # Validate the length limits BEFORE anything consumes them:
            # clean_dataset_by_length and the bucket-boundary helpers below all
            # take these values, so a missing limit must be rejected here with a
            # clear message rather than surfacing from inside those calls.
            if args["max_src_len"] is None:
                raise RuntimeError("Must provide `max_src_len` for training.")
            if args["max_trg_len"] is None:
                raise RuntimeError("Must provide `max_trg_len` for training.")
            dataset = dataset_utils.clean_dataset_by_length(
                dataset, {
                    "feature": args["max_src_len"],
                    "label": args["max_trg_len"]
                })
            if args["cache_dataset"]:
                dataset = dataset.cache()
            if args["shuffle_buffer"]:
                dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
            padding_values = {"feature": src_eos, "label": trg_eos}
            src_bucket_boundaries, trg_bucket_boundaries = dataset_utils.associated_bucket_boundaries(
                dataset_utils.create_batch_bucket_boundaries(args["max_src_len"]),
                dataset_utils.create_batch_bucket_boundaries(args["max_trg_len"]))
            bucket_boundaries = {
                "feature": src_bucket_boundaries,
                "label": trg_bucket_boundaries
            }
            # Boundaries are only passed when batching by tokens; otherwise
            # adjust_batch_size receives None (presumably yielding one shared size).
            bucket_batch_sizes = dataset_utils.adjust_batch_size(
                args["batch_size"],
                args["batch_size_per_gpu"],
                bucket_boundaries=bucket_boundaries if args["batch_by_tokens"] else None,
                # A bucket's length is the larger of its source/target boundaries.
                boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
                num_replicas_in_sync=num_replicas_in_sync)
            return dataset_utils.batch_examples_by_token(
                dataset,
                bucket_boundaries=bucket_boundaries,
                bucket_batch_sizes=bucket_batch_sizes,
                padding_values=padding_values,
                example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()})
예제 #2
0
 def create_and_batch_tfds(self, ds, mode, args=None, num_replicas_in_sync=1):
     """ Creates a dataset, applying the GPU efficient level for training.

     INFER and EVAL modes are delegated unchanged to the parent class. For training,
     length limits, bucket boundaries and batch sizes are aligned to
     `EFFICIENT_MULTIPLIER[level]`, where `level` is `args["gpu_efficient_level"]`.

     Args:
         ds: The dataset object whose `build` method produces the tf.data pipeline.
         mode: A ModeKeys indicating the running mode.
         args: A dict of dataset arguments combined with `self._args` via
             `deep_merge_dict`, or None to use `self._args` directly.
         num_replicas_in_sync: The number of GPUs or other workers; global batch
             sizes are kept divisible by it.

     Returns:
         A tf.data.Dataset of padded batches. At GPU_EFFICIENT_LEVEL.LEVEL5 every
         batch has a fixed shape and the remainder is dropped; at other levels
         examples are bucketed by length via `dataset_utils.batch_examples_by_token`.

     Raises:
         RuntimeError: If `max_src_len` or `max_trg_len` is None.
         ValueError: At LEVEL5, if the static batch size rounds down to zero.
     """
     if mode in [compat.ModeKeys.INFER, compat.ModeKeys.EVAL]:
         return super(Translation, self).create_and_batch_tfds(
             ds, mode, args, num_replicas_in_sync)
     if args is None:
         args = self._args
     else:
         args = deep_merge_dict(self._args, args, local_overwrite=False)
     level = args.get("gpu_efficient_level", None)
     auto_scale_batch = args.get("auto_scaling_batch_size", None)
     logging.info(f"Creating training dataset with GPU efficient level={level}.")
     # Validate the length limits before anything consumes them (the original
     # checked only after clean_dataset_by_length had already used them).
     if args["max_src_len"] is None:
         raise RuntimeError("Must provide `max_src_len` for training.")
     if args["max_trg_len"] is None:
         raise RuntimeError("Must provide `max_trg_len` for training.")
     dataset = ds.build(map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                        map_output_dtypes=self.inputs_signature(mode)[0],
                        auto_shard=True, shuffle=True)
     dataset = dataset_utils.clean_dataset_by_length(
         dataset, {"feature": args["max_src_len"], "label": args["max_trg_len"]})
     if args["cache_dataset"]:
         dataset = dataset.cache()
     if args["shuffle_buffer"]:
         dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
     # EOS ids are used as padding values for the respective sides.
     padding_values = {"feature": tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64),
                       "label": tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)}
     multiplier = EFFICIENT_MULTIPLIER[level]
     # Align both length limits via minimal_multiple with the level's multiplier.
     max_src_len = minimal_multiple(args["max_src_len"], multiplier)
     max_trg_len = minimal_multiple(args["max_trg_len"], multiplier)
     max_len = max(max_src_len, max_trg_len)
     batch_size = dataset_utils.adjust_batch_size(args["batch_size"], args["batch_size_per_gpu"],
                                                  num_replicas_in_sync=num_replicas_in_sync,
                                                  verbose=False)
     if auto_scale_batch:
         batch_size = _auto_scale_batch_size(batch_size, level)
         logging.info(f"Auto scaling batch size to {batch_size}.")

     if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
         _batch_size = batch_size
         if args["batch_by_tokens"]:
             # Token budget divided by the padded length gives examples per batch.
             _batch_size = _batch_size // max_len
         logging.info("Batching dataset with fixed shape: "
                      f"batch={_batch_size} x (feature={max_src_len}, label={max_trg_len}).")
         # Round down so the global batch divides evenly across replicas.
         global_batch_size = _batch_size // num_replicas_in_sync * num_replicas_in_sync
         if global_batch_size <= 0:
             raise ValueError(
                 f"Static batch size rounds down to {global_batch_size} "
                 f"(batch_size={batch_size}, max_len={max_len}, "
                 f"num_replicas_in_sync={num_replicas_in_sync}); increase `batch_size`.")
         return dataset.padded_batch(
             global_batch_size,
             padded_shapes={"feature": [max_src_len], "label": [max_trg_len]},
             drop_remainder=True, padding_values=padding_values)

     def _efficient_boundaries(limit):
         # Multiples of `multiplier` up to `limit`; if `limit` itself is not
         # covered, append the next aligned boundary past the last one.
         boundaries = [multiplier * i for i in range(1, limit // multiplier + 1)]
         if boundaries[-1] < limit:
             boundaries.append(minimal_multiple(boundaries[-1] + 1, multiplier))
         return boundaries

     src_bucket_boundaries, trg_bucket_boundaries = dataset_utils.associated_bucket_boundaries(
         _efficient_boundaries(max_src_len), _efficient_boundaries(max_trg_len))
     bucket_boundaries = {
         "feature": src_bucket_boundaries,
         "label": trg_bucket_boundaries
     }
     bucket_batch_sizes = dataset_utils.adjust_batch_size(
         batch_size,
         bucket_boundaries=bucket_boundaries if args["batch_by_tokens"] else None,
         # A bucket's length is the larger of its source/target boundaries.
         boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
         num_replicas_in_sync=num_replicas_in_sync)
     if level != GPU_EFFICIENT_LEVEL.LEVEL0:
         def _align_batch(size):
             # Align the per-replica batch via maximum_lower_multiple (presumably
             # rounding down to a multiple), then scale back to a global batch.
             return int(maximum_lower_multiple(size // num_replicas_in_sync,
                                               multiplier) * num_replicas_in_sync)

         # adjust_batch_size may return one size per bucket or a single size.
         if isinstance(bucket_batch_sizes, list):
             bucket_batch_sizes = [_align_batch(x) for x in bucket_batch_sizes]
         else:
             bucket_batch_sizes = _align_batch(bucket_batch_sizes)
     return dataset_utils.batch_examples_by_token(
         dataset,
         bucket_boundaries=bucket_boundaries,
         bucket_batch_sizes=bucket_batch_sizes,
         padding_values=padding_values,
         example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()}
     )