def test_regnet_fsdp_convergence_on_swav_with_larc(self):
    """
    Pretrain the SWAV architecture with LARC enabled, once under DDP and
    once under FSDP (no activation checkpointing), and check that the
    recorded results agree to high precision.
    """
    with with_temp_files(count=2) as file_names:
        ddp_file, fsdp_file = file_names
        self.run_pretraining(
            with_fsdp=False,
            with_checkpointing=False,
            with_larc=True,
            output_file_name=ddp_file,
        )
        self.run_pretraining(
            with_fsdp=True,
            with_checkpointing=False,
            with_larc=True,
            output_file_name=fsdp_file,
        )

        results = []
        for path in (ddp_file, fsdp_file):
            with open(path, "rb") as f:
                loaded = pickle.load(f)
            # TODO (Quentin) - figure out why it diverges slightly after a while
            loaded[3] = round(loaded[3], 5)
            loaded[4] = round(loaded[4], 4)
            results.append(loaded)

        ddp_results, fsdp_results = results
        self.assertEqual(
            len(ddp_results), len(fsdp_results), "DDP vs FSDP (LARC) Loss Lengths"
        )
        for ddp_value, fsdp_value in zip(ddp_results, fsdp_results):
            self.assertAlmostEqual(ddp_value, fsdp_value, places=5)
def test_regnet_fsdp_convergence_on_swav(self):
    """
    Pretrain the SWAV architecture three ways — DDP, FSDP, and FSDP with
    activation checkpointing (all without LARC) — and check that every
    run produces identical results.
    """
    with with_temp_files(count=3) as file_names:
        # (with_fsdp, with_checkpointing) for each of the three runs
        run_modes = [(False, False), (True, False), (True, True)]
        for (use_fsdp, use_checkpointing), out_file in zip(run_modes, file_names):
            self.run_pretraining(
                with_fsdp=use_fsdp,
                with_checkpointing=use_checkpointing,
                with_larc=False,
                output_file_name=out_file,
            )

        results = []
        for path in file_names:
            with open(path, "rb") as f:
                results.append(pickle.load(f))

        self.assertEqual(results[0], results[1], "DDP vs FSDP")
        self.assertEqual(results[1], results[2], "Activation checkpointing")
def test_memory_tracking_fsdp(self):
    """Spawn two workers to exercise layer memory tracking under FSDP."""
    world_size = 2
    with with_temp_files(count=1) as sync_file:
        mp.spawn(
            self._layer_memory_tracking_worker,
            args=(sync_file, world_size),
            nprocs=world_size,
        )
def test_synch_bn_pytorch(self, group_size: int):
    """Spawn two workers running the PyTorch SyncBN check for *group_size*."""
    world_size = 2
    with with_temp_files(count=1) as sync_file:
        mp.spawn(
            self._test_synch_bn_pytorch_worker,
            args=(world_size, group_size, sync_file),
            nprocs=world_size,
        )
def test_backward_world_size_2(self):
    """Run the backward-pass worker across 2 processes with batch size 2."""
    world_size = 2
    batch_size = 2
    with with_temp_files(count=1) as sync_file:
        mp.spawn(
            self.worker_fn,
            args=(world_size, batch_size, sync_file),
            nprocs=world_size,
        )
def test_gather_embeddings_word_size_2(self):
    """
    Run the embedding-gathering worker across 2 processes with batch size 2.

    NOTE(review): the method name says "word_size" — presumably a typo for
    "world_size"; left unchanged since renaming alters the test identifier.
    """
    world_size = 2
    batch_size = 2
    with with_temp_files(count=1) as sync_file:
        mp.spawn(
            self.worker_fn,
            args=(world_size, batch_size, sync_file),
            nprocs=world_size,
        )
def test_checkpoint_consolidation(self):
    """Run the consolidation worker twice (with and without heads) in a scratch directory."""
    world_size = 2
    with in_temporary_directory():
        for with_heads in (True, False):
            with with_temp_files(count=1) as sync_file:
                mp.spawn(
                    self._worker,
                    args=(sync_file, world_size, with_heads),
                    nprocs=world_size,
                )
def test_local_norm_computations(self):
    """
    Sanity check: with a single GPU involved there is no sharding, so the
    sharded and non-sharded norms should trivially be the same.
    """
    world_size = 1
    with with_temp_files(count=1) as sync_file:
        mp.spawn(
            self._norm_computation_worker,
            args=(sync_file, world_size),
            nprocs=world_size,
        )
def test_norm_computations(self):
    """
    Trying with 2 GPUs: the sharded computation should span across GPUs
    and lead to sensibly the same results as normal DDP with LARC.
    """
    world_size = 2
    with with_temp_files(count=1) as sync_file:
        mp.spawn(
            self._norm_computation_worker,
            args=(sync_file, world_size),
            nprocs=world_size,
        )
def run_pretraining(
    with_fsdp: bool,
    with_checkpointing: bool,
    with_larc: bool,
    output_file_name: str,
):
    """
    Spawn 2 pretraining worker processes with the given configuration,
    writing their results to *output_file_name*.

    NOTE(review): no ``self``/``cls`` parameter, yet this is invoked as
    ``self.run_pretraining(...)`` elsewhere — presumably a ``@staticmethod``
    decorator sits just above this definition; confirm in the full file.
    """
    num_processes = 2
    with with_temp_files(count=1) as sync_file:
        worker_args = (
            with_fsdp,
            with_checkpointing,
            with_larc,
            sync_file,
            output_file_name,
        )
        mp.spawn(
            TestRegnetFSDP._pretraining_worker,
            args=worker_args,
            nprocs=num_processes,
        )