def create_and_batch_tfds(self, ds: Dataset, mode,
                          args=None, num_replicas_in_sync=1) -> tf.data.Dataset:
    """ Creates a dataset according to the `mode`.

    Args:
        ds: A neurst.data.datasets.Dataset object.
        mode: A ModeKeys indicating the running mode.
        args: A dict containing dataset arguments.
        num_replicas_in_sync: The number of GPUs or other workers. We will generate
            global batches, and each global batch is equally divisible by the
            number of replicas.

    Returns:
        A tf.data.Dataset.
    """
    if args is None:
        args = self._args
    else:
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    src_eos = tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64)
    trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)

    assert isinstance(ds, AbstractParallelDataset), (
        "The dataset for SeqToSeq task must inherit AbstractParallelDataset.")

    dataset = ds.build(map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0],
                       auto_shard=(mode == compat.ModeKeys.TRAIN),
                       shuffle=(mode == compat.ModeKeys.TRAIN))

    if mode == compat.ModeKeys.INFER:
        # Inference: pad source features only and keep incomplete batches.
        logging.info("Creating test dataset.")
        return dataset.cache().padded_batch(
            dataset_utils.adjust_batch_size(args["batch_size"],
                                            num_replicas_in_sync=num_replicas_in_sync),
            padded_shapes={"feature": [None]},
            padding_values={"feature": src_eos},
            drop_remainder=False)
    elif mode == compat.ModeKeys.EVAL:
        # Evaluation: pad both sides with their respective EOS ids.
        logging.info("Creating evaluation dataset.")
        return dataset.cache().padded_batch(
            dataset_utils.adjust_batch_size(args["batch_size"],
                                            num_replicas_in_sync=num_replicas_in_sync),
            padded_shapes={"feature": [None], "label": [None]},
            padding_values={"feature": src_eos, "label": trg_eos},
            drop_remainder=False)
    else:
        # Training: drop over-length pairs, then batch by tokens with
        # length bucketing so that similarly sized examples share a batch.
        logging.info("Creating training dataset.")
        dataset = dataset_utils.clean_dataset_by_length(
            dataset, {"feature": args["max_src_len"], "label": args["max_trg_len"]})
        if args["cache_dataset"]:
            dataset = dataset.cache()
        if args["shuffle_buffer"]:
            dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
        padding_values = {"feature": src_eos, "label": trg_eos}
        if args["max_src_len"] is None:
            raise RuntimeError("Must provide `max_src_len` for training.")
        if args["max_trg_len"] is None:
            raise RuntimeError("Must provide `max_trg_len` for training.")
        src_bucket_boundaries, trg_bucket_boundaries = dataset_utils.associated_bucket_boundaries(
            dataset_utils.create_batch_bucket_boundaries(args["max_src_len"]),
            dataset_utils.create_batch_bucket_boundaries(args["max_trg_len"]))
        bucket_boundaries = {"feature": src_bucket_boundaries,
                             "label": trg_bucket_boundaries}
        bucket_batch_sizes = dataset_utils.adjust_batch_size(
            args["batch_size"],
            args["batch_size_per_gpu"],
            bucket_boundaries=bucket_boundaries if args["batch_by_tokens"] else None,
            boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
            num_replicas_in_sync=num_replicas_in_sync)
        return dataset_utils.batch_examples_by_token(
            dataset,
            bucket_boundaries=bucket_boundaries,
            bucket_batch_sizes=bucket_batch_sizes,
            padding_values=padding_values,
            example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()})
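
# --- Illustration (not part of NeurST) --------------------------------------
# `dataset_utils.batch_examples_by_token` above groups examples into length
# buckets and pads within each bucket so every batch carries roughly the same
# number of tokens. The standalone sketch below reproduces that idea with
# plain tf.data (`Dataset.bucket_by_sequence_length`, available in TF >= 2.6);
# the `_demo` names and the ~512-tokens-per-batch budget are hypothetical.
def _demo_batch_by_tokens():
    import tensorflow as tf

    def gen():
        # Toy parallel examples with varying lengths.
        for n, m in [(3, 4), (10, 12), (30, 28), (9, 7)]:
            yield {"feature": tf.range(n, dtype=tf.int64),
                   "label": tf.range(m, dtype=tf.int64)}

    dataset = tf.data.Dataset.from_generator(
        gen, output_signature={"feature": tf.TensorSpec([None], tf.int64),
                               "label": tf.TensorSpec([None], tf.int64)})
    eos = tf.constant(0, dtype=tf.int64)
    boundaries = [8, 16, 32]  # upper length bound of each bucket
    # One batch size per bucket (len(boundaries) + 1): shorter sequences get
    # larger batches, so each full batch holds about 512 tokens.
    batch_sizes = [512 // b for b in boundaries] + [512 // 64]
    batched = dataset.bucket_by_sequence_length(
        element_length_func=lambda x: tf.maximum(tf.size(x["feature"]),
                                                 tf.size(x["label"])),
        bucket_boundaries=boundaries,
        bucket_batch_sizes=batch_sizes,
        padding_values={"feature": eos, "label": eos})
    for batch in batched:
        # Prints per-batch padded shapes, e.g. {'feature': (2, 10), 'label': (2, 12)}.
        print({k: tuple(v.shape) for k, v in batch.items()})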
def create_and_batch_tfds(self, ds, mode, args=None,
                          num_replicas_in_sync=1) -> tf.data.Dataset:
    """ Creates and batches the dataset, applying the GPU efficient level to
    training data: padded sequence lengths and per-replica batch sizes are
    rounded to multiples of EFFICIENT_MULTIPLIER[level]. Inference and
    evaluation fall back to the parent implementation.
    """
    if mode in [compat.ModeKeys.INFER, compat.ModeKeys.EVAL]:
        return super(Translation, self).create_and_batch_tfds(
            ds, mode, args, num_replicas_in_sync)
    if args is None:
        args = self._args
    else:
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    level = args.get("gpu_efficient_level", None)
    auto_scale_batch = args.get("auto_scaling_batch_size", None)
    logging.info(f"Creating training dataset with GPU efficient level={level}.")
    dataset = ds.build(map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0],
                       auto_shard=True, shuffle=True)
    dataset = dataset_utils.clean_dataset_by_length(
        dataset, {"feature": args["max_src_len"], "label": args["max_trg_len"]})
    if args["cache_dataset"]:
        dataset = dataset.cache()
    if args["shuffle_buffer"]:
        dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
    padding_values = {
        "feature": tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64),
        "label": tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)}
    if args["max_src_len"] is None:
        raise RuntimeError("Must provide `max_src_len` for training.")
    if args["max_trg_len"] is None:
        raise RuntimeError("Must provide `max_trg_len` for training.")
    # Round the length limits up to the nearest multiple for this level.
    max_src_len = minimal_multiple(args["max_src_len"], EFFICIENT_MULTIPLIER[level])
    max_trg_len = minimal_multiple(args["max_trg_len"], EFFICIENT_MULTIPLIER[level])
    max_len = max(max_src_len, max_trg_len)
    batch_size = dataset_utils.adjust_batch_size(args["batch_size"],
                                                 args["batch_size_per_gpu"],
                                                 num_replicas_in_sync=num_replicas_in_sync,
                                                 verbose=False)
    if auto_scale_batch:
        batch_size = _auto_scale_batch_size(batch_size, level)
        logging.info(f"Auto scaling batch size to {batch_size}.")
    if level == GPU_EFFICIENT_LEVEL.LEVEL5:
        # LEVEL5: fully static batches, padded to the fixed maximum lengths.
        _batch_size = batch_size
        if args["batch_by_tokens"]:
            _batch_size = _batch_size // max_len
        logging.info("Batching dataset with fixed shape: "
                     f"batch={_batch_size} x (feature={max_src_len}, label={max_trg_len}).")
        return dataset.padded_batch(
            _batch_size // num_replicas_in_sync * num_replicas_in_sync,
            padded_shapes={"feature": [max_src_len], "label": [max_trg_len]},
            drop_remainder=True,
            padding_values=padding_values)
    else:
        # Other levels: bucket by length, with every bucket boundary a
        # multiple of EFFICIENT_MULTIPLIER[level].
        src_bucket_boundaries = [EFFICIENT_MULTIPLIER[level] * i for i in
                                 range(1, max_src_len // EFFICIENT_MULTIPLIER[level] + 1)]
        if src_bucket_boundaries[-1] < max_src_len:
            src_bucket_boundaries.append(minimal_multiple(
                src_bucket_boundaries[-1] + 1, EFFICIENT_MULTIPLIER[level]))
        trg_bucket_boundaries = [EFFICIENT_MULTIPLIER[level] * i for i in
                                 range(1, max_trg_len // EFFICIENT_MULTIPLIER[level] + 1)]
        if trg_bucket_boundaries[-1] < max_trg_len:
            trg_bucket_boundaries.append(minimal_multiple(
                trg_bucket_boundaries[-1] + 1, EFFICIENT_MULTIPLIER[level]))
        src_bucket_boundaries, trg_bucket_boundaries = dataset_utils.associated_bucket_boundaries(
            src_bucket_boundaries, trg_bucket_boundaries)
        bucket_boundaries = {"feature": src_bucket_boundaries,
                             "label": trg_bucket_boundaries}
        bucket_batch_sizes = dataset_utils.adjust_batch_size(
            batch_size,
            bucket_boundaries=bucket_boundaries if args["batch_by_tokens"] else None,
            boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
            num_replicas_in_sync=num_replicas_in_sync)
        if level != GPU_EFFICIENT_LEVEL.LEVEL0:
            # Round each bucket's batch size down so that the per-replica
            # share is itself a multiple of the level's multiplier.
            if isinstance(bucket_batch_sizes, list):
                bucket_batch_sizes = [
                    int(maximum_lower_multiple(x // num_replicas_in_sync,
                                               EFFICIENT_MULTIPLIER[level]) * num_replicas_in_sync)
                    for x in bucket_batch_sizes]
            else:
                bucket_batch_sizes = int(maximum_lower_multiple(
                    bucket_batch_sizes // num_replicas_in_sync,
                    EFFICIENT_MULTIPLIER[level]) * num_replicas_in_sync)
        return dataset_utils.batch_examples_by_token(
            dataset,
            bucket_boundaries=bucket_boundaries,
            bucket_batch_sizes=bucket_batch_sizes,
            padding_values=padding_values,
            example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()})
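
# --- Illustration (not part of NeurST) --------------------------------------
# A worked sketch of the level-based rounding above. From the call sites it is
# assumed that minimal_multiple(x, d) returns the smallest multiple of d that
# is >= x, and maximum_lower_multiple(x, d) the largest multiple of d that is
# <= x; the multiplier value 8 and all demo numbers below are made up.
def _demo_efficient_rounding():
    def minimal_multiple(x, d):
        return ((x + d - 1) // d) * d    # e.g. minimal_multiple(50, 8) == 56

    def maximum_lower_multiple(x, d):
        return (x // d) * d              # e.g. maximum_lower_multiple(19, 8) == 16

    multiplier = 8                       # hypothetical EFFICIENT_MULTIPLIER[level]
    max_src_len = minimal_multiple(50, multiplier)  # 50 rounds up to 56
    src_bucket_boundaries = [multiplier * i
                             for i in range(1, max_src_len // multiplier + 1)]
    print(src_bucket_boundaries)         # [8, 16, 24, 32, 40, 48, 56]

    # Each bucket's batch size is rounded down so that the per-replica share
    # is itself a multiple of the multiplier, e.g. 77 examples on 4 replicas:
    num_replicas_in_sync, raw_batch_size = 4, 77
    aligned = maximum_lower_multiple(raw_batch_size // num_replicas_in_sync,
                                     multiplier) * num_replicas_in_sync
    print(aligned)                       # 77 // 4 == 19 -> 16 per replica -> 64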