def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
    )
    local_model = DeterministicModel(wrap_fsdp=False)
    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )
    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        # Below sleep causes failures without stream synchronization in
        # summon_full_params fix.
        torch.cuda._sleep(1000000)
        # FSDP param deepcopy() of params has issues
        fsdp_params = [p.clone() for p in model.parameters()]

    self.assertEqual(fsdp_params, params_to_compare)
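# The DeterministicModel used throughout these tests is defined elsewhere in
# the test file; the class below is only an illustrative sketch of what such a
# model might look like (a small, seeded module whose inner Linear can be
# wrapped in FSDP so that wrapped and unwrapped copies start from identical
# weights). Its exact structure is an assumption, not the real definition.
class DeterministicModelSketch(torch.nn.Module):
    def __init__(self, wrap_fsdp, cpu_offload=None):
        super().__init__()
        # Fixed seed so every rank (and the local reference copy) initializes
        # identical weights.
        torch.manual_seed(0)
        self.inner = torch.nn.Linear(2, 2).cuda()
        if wrap_fsdp:
            # Optionally shard the inner layer, mirroring the nested-FSDP case.
            self.inner = FSDP(self.inner, cpu_offload=cpu_offload)
        self.outer = torch.nn.Linear(2, 2).cuda()

    def forward(self, x):
        return self.outer(self.inner(x))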
def test_fsdp_calc_grad_norm_error(self, norm_type):
    """Test the error cases of the grad norm calculation API."""
    model = DeterministicModel(False)
    input = torch.rand(12, 2, device=self.rank)
    out = model(input)
    out.sum().backward()
    error_msg = f"Order {norm_type} not supported for matrix norm"
    with self.assertRaisesRegex(RuntimeError, error_msg):
        total_norm = _calc_grad_norm(model.parameters(), norm_type)
def test_summon_full_params_equivalence(self):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    with model.summon_full_params(recurse=True):
        # Below sleep causes failures without stream synchronization in
        # summon_full_params fix.
        torch.cuda._sleep(1000000)
        fsdp_params = deepcopy(list(model.parameters()))

    self.assertEqual(fsdp_params, list(local_model.parameters()))
def _run_fsdp_one_iteration(self, norm_type, nested_fsdp, cpu_offload):
    """Test FSDP with gradient norm clipping."""
    fsdp_model = DeterministicModel(nested_fsdp, cpu_offload=cpu_offload)
    local_model = DeterministicModel(False)
    input = torch.rand(14, 2, device=self.rank)
    fsdp_model = FSDP(fsdp_model, cpu_offload=cpu_offload)
    self.assertTrue(len(input) >= self.world_size)
    out = local_model(input[: self.world_size])
    out.sum().backward()
    in_data = torch.tensor(input[self.rank], device=self.rank)
    out_fsdp = fsdp_model(in_data)
    out_fsdp.sum().backward()
    total_norms_fsdp = _collect_total_grad_norm_fsdp(fsdp_model, norm_type, self.rank)
    total_norms_local = _collect_total_grad_norm_local(local_model, norm_type)
    total_norms_local /= self.world_size
    norm_cap = total_norms_fsdp / 2.0
    self.assertEqual(total_norms_local, total_norms_fsdp)
    fsdp_model.clip_grad_norm_(norm_cap, norm_type=norm_type)
    nn_utils.clip_grad_norm_(local_model.parameters(), norm_cap, norm_type=norm_type)
    total_norms_after_clip_fsdp = _collect_total_grad_norm_fsdp(
        fsdp_model, norm_type, self.rank
    )
    total_norms_after_clip_local = _collect_total_grad_norm_local(
        local_model, norm_type
    )
    self.assertTrue(total_norms_after_clip_fsdp <= norm_cap)
    self.assertEqual(total_norms_after_clip_local, total_norms_after_clip_fsdp)
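import torch.distributed as dist

# _collect_total_grad_norm_local and _collect_total_grad_norm_fsdp are helpers
# defined in the test file. The sketches below are assumptions about their
# behavior (aggregate the gradients of all parameters into a single p-norm,
# with the FSDP variant all-reducing the per-rank contribution), not the
# actual implementations.
def _collect_total_grad_norm_local_sketch(model, norm_type):
    if norm_type == float("inf"):
        # Infinity norm: the largest absolute gradient entry.
        return max(p.grad.detach().abs().max() for p in model.parameters())
    total = 0.0
    for p in model.parameters():
        total += torch.linalg.vector_norm(p.grad.detach(), norm_type) ** norm_type
    return total ** (1.0 / norm_type)


def _collect_total_grad_norm_fsdp_sketch(model, norm_type, rank):
    # Combine each rank's local norm across the process group: max for the
    # infinity norm, sum of p-th powers otherwise.
    local_norm = _collect_total_grad_norm_local_sketch(model, norm_type)
    if norm_type == float("inf"):
        op, norm_type = dist.ReduceOp.MAX, 1.0
    else:
        op = dist.ReduceOp.SUM
    total = (local_norm ** norm_type).to(rank)
    dist.all_reduce(total, op=op)
    return total ** (1.0 / norm_type)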
def test_fsdp_calc_grad_norm(self, norm_type, nested_fsdp):
    """Test the grad norm calculation API."""
    model = FSDP(DeterministicModel(nested_fsdp))
    input = torch.rand(15, 2, device=self.rank)
    out = model(input)
    out.sum().backward()
    total_norm = _calc_grad_norm(model.params_with_grad, norm_type)
    total_norm_expected = _collect_total_grad_norm_local(model, norm_type)
    self.assertEqual(total_norm, total_norm_expected)
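# _calc_grad_norm is imported from FSDP internals by this test file. The
# sketch below is only an assumption of what it conceptually computes (the
# p-norm over the local gradients of the given parameters); the real function
# may differ in details such as dtype handling.
def _calc_grad_norm_sketch(parameters, p):
    parameters = [param for param in parameters if param.grad is not None]
    if not parameters:
        return torch.tensor(0.0)
    if p == float("inf"):
        return max(param.grad.detach().abs().max() for param in parameters)
    # Stack each parameter's gradient norm and take the p-norm of the stack.
    return torch.linalg.vector_norm(
        torch.stack(
            [torch.linalg.vector_norm(param.grad.detach(), p) for param in parameters]
        ),
        p,
    )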
def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )

    writeback = not rank0_only
    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=writeback,
        offload_to_cpu=offload_to_cpu,
    ):
        if writeback:
            with torch.no_grad():
                for p in model.parameters():
                    p.add_(1)
                for p in params_to_compare:
                    p.add_(1)
        # Below sleep causes failures without stream synchronization in
        # summon_full_params fix.
        torch.cuda._sleep(1000000)
        # FSDP param deepcopy() of params has issues
        fsdp_params = [p.clone() for p in model.parameters()]

    self.assertEqual(fsdp_params, params_to_compare)

    # CPU offload is enabled for main API, so we should point back to CPU
    for param in model.parameters():
        self.assertEqual(param.device, torch.device("cpu"))
def test_summon_from_non_fsdp(self):
    class FSDPContainer(nn.Module):
        def __init__(self, fsdp_1, fsdp_2, fsdp_3):
            super().__init__()
            self.fsdp_1 = fsdp_1
            self.fsdp_2 = fsdp_2
            self.fsdp_3 = fsdp_3

    model_fsdp = FSDPContainer(
        FSDP(DeterministicModel(wrap_fsdp=True)),
        FSDP(DeterministicModel(wrap_fsdp=True)),
        DeterministicModel(wrap_fsdp=False),
    )
    model_no_fsdp = FSDPContainer(
        DeterministicModel(wrap_fsdp=False),
        DeterministicModel(wrap_fsdp=False),
        DeterministicModel(wrap_fsdp=False),
    )

    params_to_compare = list(model_no_fsdp.parameters())
    with FSDP.summon_full_params(model_fsdp):
        fsdp_params = [p.clone() for p in model_fsdp.parameters()]

    self.assertEqual(params_to_compare, fsdp_params)