def shuffle_with_seed(self, dataset, ds_kwargs):
  if 'seed' not in ds_kwargs or ds_kwargs['seed'] is None:
    logger.warn("Shuffling with fixed shuffle seed {}.".format(self.shuffle_seed))
    ds_kwargs['seed'] = self.shuffle_seed
  else:
    logger.debug("Shuffling with shuffle seed {}.".format(ds_kwargs['seed']))
  return dataset.shuffle(**ds_kwargs)
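# Illustrative sketch (not part of the library): the fixed shuffle seed above
# matters because `tf.data.Dataset.shuffle` is only deterministic when given a
# `seed`; with the same seed and `reshuffle_each_iteration=False`, every rank
# (and every epoch) sees the samples in the same order.
import tensorflow as tf

ds = tf.data.Dataset.range(8)
first = list(ds.shuffle(buffer_size=8, seed=42,
                        reshuffle_each_iteration=False).as_numpy_iterator())
second = list(ds.shuffle(buffer_size=8, seed=42,
                         reshuffle_each_iteration=False).as_numpy_iterator())
assert first == second  # identical order across iterations (and across ranks)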
def _create_tnt_model(cls, model: tf.keras.Model,
                      parallel_strategy: tnt.ParallelStrategy = tnt.ParallelStrategy.ALL if TF_DEFAULT_PIPELINING_FLAG \
                                                                else tnt.ParallelStrategy.DATA,
                      num_pipeline_stages: int = 1):
  replica_group = tnt.Group()

  if (tnt.ParallelStrategy.PIPELINING in parallel_strategy) and isinstance(model, tf.keras.Sequential):
    logger.warn("Cannot pipeline a `tf.keras.Sequential` model; disabling model parallelism.")
    parallel_strategy = parallel_strategy ^ tnt.ParallelStrategy.PIPELINING

  logger.info(f"Creating parallel model using {parallel_strategy}.")
  if tnt.ParallelStrategy.PIPELINING in parallel_strategy:
    rank = tnt.get_rank()

    partition_generator = pgen.GraphPartitionGenerator(model)
    rank_mapper = rmapper.RankMapper(
        num_ranks = tnt.get_size(),
        pipeline_graph = partition_generator.get_pipeline_graph())

    pipeline_group = rank_mapper.get_pipelining_group_for_rank(rank)
    logger.info(f"[Pipelining] Creating pipelined model with {pipeline_group.size} partitions.")

    # get my partition
    model = pm.PartitionedModel(model = model, group = pipeline_group,
                                partition_generator = partition_generator,
                                rank_mapper = rank_mapper,
                                num_pipeline_stages = num_pipeline_stages)
    if tnt.ParallelStrategy.DATA in parallel_strategy:
      replica_group = rank_mapper.get_replica_group_for_rank(rank)
    else:
      if pipeline_group.size != tnt.get_size():
        raise ValueError(f"Provided model has only {pipeline_group.size} partitions; "
                         f"use {pipeline_group.size} ranks or a different parallel strategy.")

  if tnt.ParallelStrategy.DATA in parallel_strategy:
    # replicate my partition across the data parallel group
    logger.info(f"[DataParallel] Replicating local model across ranks {replica_group.group}.")
    model = dpm.DataParallelModel(model = model, group = replica_group)
  return model
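# Illustrative sketch (an assumption about the implementation): the membership
# test (`in`) and the XOR used above behave as below if `tnt.ParallelStrategy`
# is an `enum.Flag`-style bitmask, with `ALL` combining both strategies.
# XOR-ing out PIPELINING leaves plain data parallelism; note that `^` only
# *clears* a flag when it is set, which is why the code checks membership first.
import enum

class ParallelStrategy(enum.Flag):
  DATA = 1
  PIPELINING = 2
  ALL = DATA | PIPELINING

strategy = ParallelStrategy.ALL
assert ParallelStrategy.PIPELINING in strategy
strategy = strategy ^ ParallelStrategy.PIPELINING  # disable model parallelism
assert strategy == ParallelStrategy.DATA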
def _configure_rebuild(self, dataset):
  self.built = False
  dist_dataset = tnt.data.Dataset(dataset = dataset, num_ranks = 1, rank = 0)
  dist_dataset.distribute_dataset_across_ranks(apply_batch = False)

  # model is already built with the same `nano_batch_size`
  if self.nano_batch_size == dist_dataset.micro_batch_size // self.num_pipeline_stages:
    self.built = True
    return

  micro_batch_size = dist_dataset.micro_batch_size
  self.nano_batch_size = micro_batch_size // self.num_pipeline_stages
  if self.nano_batch_size * self.num_pipeline_stages != micro_batch_size:
    logger.warn(f"[PartitionedModel] The micro-batch size {micro_batch_size} is not a multiple of "
                f"the number of pipeline stages ({self.num_pipeline_stages}); removing the remainder.")
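# Worked example (illustrative): the nano-batch size is the micro-batch split
# evenly across pipeline stages, with any remainder dropped; e.g. a micro-batch
# of 10 samples over 3 stages yields nano-batches of 3 and leaves 1 sample unused.
micro_batch_size, num_pipeline_stages = 10, 3
nano_batch_size = micro_batch_size // num_pipeline_stages          # 3
assert nano_batch_size * num_pipeline_stages != micro_batch_size   # remainder of 1 is dropped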
def _pad_dataset_if_necessary(dataset, num_samples, batch_size, min_last_batch_size):
  last_batch_size = _get_last_incomplete_batch_size(num_samples, batch_size)
  if last_batch_size == 0:
    logger.debug(f"No padding required: number of samples {num_samples} is a multiple "
                 f"of the batch size {batch_size}.")
    return dataset

  logger.info(f"Incomplete last batch in the dataset: number of samples is "
              f"{last_batch_size} (!= batch size {batch_size}).")

  if version_utils.tf_version_below_equal('2.1'):
    num_samples_multiple = num_samples - last_batch_size
    logger.warn(f"Number of samples ({num_samples}) is not a multiple of batch size. "
                f"This use case is not supported in TF v{version_utils.current_version()}. "
                f"Dropping the last incomplete batch from the dataset, "
                f"and proceeding with {num_samples_multiple} samples.")
    return dataset.take(num_samples_multiple)

  if last_batch_size < min_last_batch_size:
    logger.debug(f"Padding required for the last batch: number of samples is "
                 f"{last_batch_size} (< min_batch_size {min_last_batch_size}).")

    # Create a helper dataset that contains one full batch and one incomplete batch
    helper_dataset = dataset.take(min_last_batch_size + last_batch_size)
    helper_dataset = helper_dataset.batch(min_last_batch_size, drop_remainder=False)

    # If `padded_shapes` is unspecified, all dimensions of all components
    # are padded to the maximum size in the batch.
    # The second batch in `helper_dataset` will now contain
    # `min_last_batch_size - last_batch_size` default-initialized samples.
    helper_dataset = helper_dataset.padded_batch(2)

    # Switch back to a list of samples instead of batches
    helper_dataset = helper_dataset.unbatch().unbatch()

    # The remaining samples in the dataset are those generated through padding
    padding_samples = helper_dataset.skip(min_last_batch_size + last_batch_size)
    dataset = dataset.concatenate(padding_samples)
    logger.info(f"[Rank {tnt.get_rank()}] Dataset padded with "
                f"{min_last_batch_size - last_batch_size} samples.")
  return dataset
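# Standalone sketch (illustrative) of the padding trick above: `padded_batch(2)`
# pads the shorter of two consecutive batches to the longer one's size with
# default-initialized values, so unbatching twice and skipping the original
# samples isolates exactly the padding samples.
import tensorflow as tf

dataset = tf.data.Dataset.range(10)               # batch size 4 -> last batch has 2 samples
last_batch_size, min_last_batch_size = 2, 4

helper = dataset.take(min_last_batch_size + last_batch_size)      # 6 samples
helper = helper.batch(min_last_batch_size, drop_remainder=False)  # batches of 4 and 2
helper = helper.padded_batch(2)                                   # pad 2nd batch to size 4
helper = helper.unbatch().unbatch()                               # back to 8 single samples
padding = helper.skip(min_last_batch_size + last_batch_size)      # the 2 padding samples
print(list(padding.as_numpy_iterator()))                          # [0, 0]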
def distributed_batch(self, dataset, batch_size, micro_batch_size):
  if self.batching_info.drop_remainder:
    dataset = self.batching_info.apply(dataset, new_batch_size = batch_size)
    dataset = dataset.unbatch()
  else:  # no drop remainder
    num_samples = ds_helpers.get_num_samples(dataset)
    if num_samples == tf.data.experimental.INFINITE_CARDINALITY:
      raise ValueError("[DistributedDataset] Infinite dataset provided")

    # Total number of samples is not a multiple of the batch size
    if num_samples % batch_size != 0:
      logger.warn("Number of samples ({}) is not a multiple of batch size. "
                  "Removing the last incomplete batch from the dataset.".format(num_samples))
      num_samples_multiple = (num_samples // batch_size) * batch_size
      dataset = dataset.take(num_samples_multiple)

  dataset = self.batching_info.apply(dataset, new_batch_size = micro_batch_size)
  dataset = dataset.shard(num_shards = self.num_ranks, index = self.rank)

  logger.info("Using batch size = {}, micro batch size = {}.".format(
              batch_size, micro_batch_size))
  return dataset
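# Illustrative sketch: sharding by rank *after* micro-batching assigns each rank
# every `num_ranks`-th micro-batch, so a global batch of 6 split into
# micro-batches of 3 across 2 ranks gives rank 0 the first half of each batch
# and rank 1 the second half.
import tensorflow as tf

num_ranks, micro_batch_size = 2, 3
dataset = tf.data.Dataset.range(12).batch(micro_batch_size)
for rank in range(num_ranks):
  shard = dataset.shard(num_shards = num_ranks, index = rank)
  print(rank, [b.tolist() for b in shard.as_numpy_iterator()])
# 0 [[0, 1, 2], [6, 7, 8]]
# 1 [[3, 4, 5], [9, 10, 11]]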
def _validate_micro_batch_size_for_batch_normalization(self, micro_batch_size):
  if micro_batch_size < 16:
    for layer in self.layers:
      if isinstance(layer, tf.keras.layers.BatchNormalization):
        logger.warn("Micro batch size should be at least 16 when using Batch Normalization.")
        return
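# Illustrative sketch: the check above scans the model's layers for
# BatchNormalization, whose per-batch statistics become noisy for small
# micro-batches, hence the warning threshold of 16.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,)),
                             tf.keras.layers.BatchNormalization()])
assert any(isinstance(layer, tf.keras.layers.BatchNormalization)
           for layer in model.layers)  # a micro-batch size below 16 would trigger the warning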