def test_no_drop_remainder(apply_transformations, dataset_generator,
                           comm_size, micro_batch_size,
                           num_batches, size_final_batch):
  batch_size = comm_size * micro_batch_size
  num_samples = num_batches * batch_size + size_final_batch
  (x_train, y_train) = dataset_generator(num_samples)

  reference_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  tnt_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

  # dataset should behave like the sequential dataset with `drop_remainder=True`
  tnt_dataset = apply_transformations(tnt_dataset,
                                      batch_size=batch_size,
                                      drop_remainder=False)

  for rank in range(comm_size):   # verify each rank separately
    # load local dataset for `rank`
    dist_dataset = ds.DistributedDataset(tnt_dataset,
                                         num_ranks=comm_size,
                                         rank=rank)
    local_dataset = dist_dataset.distribute_dataset_across_ranks()
    micro_batch_size = dist_dataset.get_microbatch_size(batch_size)

    # rebuild the reference dataset each time to prevent
    # shuffling effects for repeated iterations
    ref_dataset = apply_transformations(reference_dataset,
                                        batch_size=batch_size,
                                        drop_remainder=True)
    validate_local_dataset(ref_dataset, local_dataset, micro_batch_size, rank)
def test_with_drop_remainder(apply_transformations, dataset_generator,
                             comm_size, micro_batch_size,
                             num_samples, nepochs):
  batch_size = comm_size * micro_batch_size
  (x_train, y_train) = dataset_generator(num_samples)

  reference_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  tnt_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  tnt_dataset = apply_transformations(tnt_dataset,
                                      batch_size=batch_size,
                                      drop_remainder=True)

  for rank in range(comm_size):   # verify each rank separately
    # load local dataset for `rank`
    dist_dataset = ds.DistributedDataset(tnt_dataset,
                                         num_ranks=comm_size,
                                         rank=rank)
    local_dataset = dist_dataset.distribute_dataset_across_ranks()
    micro_batch_size = dist_dataset.get_microbatch_size(batch_size)

    # rebuild the reference dataset each time to prevent
    # shuffling effects for repeated iterations
    ref_dataset = apply_transformations(reference_dataset,
                                        batch_size=batch_size,
                                        drop_remainder=True)
    for epoch in range(nepochs):
      validate_local_dataset(ref_dataset, local_dataset, micro_batch_size, rank)
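# The two tests above delegate the actual comparison to `validate_local_dataset`,
# which is not shown in this section. The helper below is only a minimal sketch of
# what it is assumed to check: for every global batch of the reference dataset, the
# micro-batch seen by `rank` equals that rank's contiguous slice of the global batch.
# The name `_validate_local_dataset_sketch`, the contiguous-slice layout, and the
# numpy import are assumptions, not the library's actual implementation.
import numpy as np

def _validate_local_dataset_sketch(ref_dataset, local_dataset, micro_batch_size, rank):
  for global_batch, micro_batch in zip(ref_dataset, local_dataset):
    # both datasets yield (samples, labels) tuples of tensors
    for ref_tensor, local_tensor in zip(global_batch, micro_batch):
      expected = ref_tensor[rank * micro_batch_size : (rank + 1) * micro_batch_size]
      assert np.allclose(expected.numpy(), local_tensor.numpy())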
def fit(self,
        x=None,
        y=None,
        callbacks=None,
        validation_data=None,
        tnt_micro_batch_size=None,
        tnt_validation_micro_batch_size=None,
        tnt_distribute_dataset=True,
        tnt_distribute_validation_dataset=True,
        **kwargs):
  self._setup_for_execution('fit', x, y, callbacks, kwargs)

  if tnt_distribute_dataset:
    distributed_x = ds.DistributedDataset(dataset=x,
                                          num_ranks=self.comm_size,
                                          rank=self.rank,
                                          shuffle_seed=self.default_shuffle_seed)
    x = distributed_x.distribute_dataset_across_ranks(
          user_micro_batch_size=tnt_micro_batch_size,
          is_training=True)
  else:
    logger.info("Automatic dataset distribution is disabled. "
                "Make sure the dataset is sharded manually across ranks.")

  # always switch off Keras shuffling; shuffling is handled by the
  # distributed dataset using a seed shared across ranks
  kwargs["shuffle"] = False

  if validation_data:
    if tnt_distribute_validation_dataset:
      distributed_validation_data = ds.DistributedDataset(
            dataset=validation_data,
            num_ranks=self.comm_size,
            rank=self.rank,
            shuffle_seed=self.default_shuffle_seed)
      validation_data = distributed_validation_data.distribute_dataset_across_ranks(
            user_micro_batch_size=tnt_validation_micro_batch_size,
            is_training=False)
    else:
      logger.info("Automatic distribution for the validation dataset is disabled.")

  return self.model.fit(x,
                        validation_data=validation_data,
                        callbacks=callbacks,
                        **kwargs)
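# When `tnt_distribute_dataset=False`, the log message above asks the caller to
# shard the data manually. A minimal sketch of one way to do that with plain
# tf.data is given below; `build_manual_shard` is a hypothetical helper, `tf` is
# assumed to be imported as in the rest of this file, `comm_size`/`rank` are
# assumed to match the model's `self.comm_size`/`self.rank`, and `batch_size`
# must be divisible by `comm_size`.
def build_manual_shard(x_train, y_train, batch_size, comm_size, rank):
  dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  # keep only this rank's samples, then batch with the per-rank micro-batch size
  dataset = dataset.shard(num_shards=comm_size, index=rank)
  return dataset.batch(batch_size // comm_size, drop_remainder=True)

# e.g. model.fit(build_manual_shard(...), tnt_distribute_dataset=False)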
def predict(self,
            x=None,
            callbacks=None,
            tnt_micro_batch_size=None,
            tnt_distribute_dataset=True,
            **kwargs):
  self._setup_for_execution('predict', x, None, callbacks, kwargs)

  if tnt_distribute_dataset:
    test_dataset = ds.DistributedDataset(dataset=x,
                                         num_ranks=self.comm_size,
                                         rank=self.rank,
                                         shuffle_seed=self.default_shuffle_seed)
    x = test_dataset.distribute_dataset_across_ranks(
          user_micro_batch_size=tnt_micro_batch_size,
          is_training=False)
  else:
    logger.info("Automatic dataset distribution is disabled.")

  return self.model.predict(x, callbacks=callbacks, **kwargs)
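# Usage sketch for the `fit`/`predict` methods above. `tnt_model` is assumed to
# be an instance of this wrapper around a compiled Keras model, the datasets are
# assumed to be batched with the global batch size, and everything except the
# tnt_* keywords declared in the signatures above (variable names, epoch count,
# micro-batch value) is illustrative only.
def run_training_sketch(tnt_model, train_dataset, val_dataset, test_dataset):
  tnt_model.fit(train_dataset,
                validation_data=val_dataset,
                tnt_micro_batch_size=32,              # per-rank micro-batch size
                tnt_validation_micro_batch_size=32,
                epochs=2)
  return tnt_model.predict(test_dataset, tnt_micro_batch_size=32)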
def test_batch_not_multiple_num_ranks(apply_transformations, dataset_generator,
                                      comm_size, micro_batch_size,
                                      size_batch_remainder):
  batch_size = comm_size * micro_batch_size + size_batch_remainder
  num_samples = 4 * batch_size
  (x_train, y_train) = dataset_generator(num_samples)

  tnt_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  tnt_dataset = apply_transformations(tnt_dataset,
                                      batch_size=batch_size,
                                      drop_remainder=True)

  for rank in range(comm_size):   # verify each rank separately
    dist_dataset = ds.DistributedDataset(tnt_dataset,
                                         num_ranks=comm_size,
                                         rank=rank)
    # distributing the dataset should fail because the batch size is not a
    # multiple of the number of ranks
    with pytest.raises(ValueError):
      local_dataset = dist_dataset.distribute_dataset_across_ranks()
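# The tests in this section receive `dataset_generator` and `apply_transformations`
# as pytest fixtures defined elsewhere. A minimal sketch of how such fixtures might
# look is given below, assuming the samples are small numpy arrays and the only
# transformations needed are a deterministic shuffle plus batching; the real
# fixtures may be parametrized with many more transformation variants.
import numpy as np
import pytest

@pytest.fixture
def dataset_generator():
  def generate(num_samples):
    x = np.arange(num_samples, dtype=np.float32).reshape(num_samples, 1)
    y = np.arange(num_samples, dtype=np.float32)
    return (x, y)
  return generate

@pytest.fixture
def apply_transformations():
  def transform(dataset, batch_size, drop_remainder):
    # fixed seed so that the reference and distributed datasets
    # can be compared element-wise
    dataset = dataset.shuffle(buffer_size=1000, seed=42,
                              reshuffle_each_iteration=False)
    return dataset.batch(batch_size, drop_remainder=drop_remainder)
  return transform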