def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([
            p.clone() for p in model.parameters()
        ] if rank0_only and self.rank != 0 else list(local_model.parameters()))

        with model.summon_full_params(
                model,
                recurse=True,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            # The sleep below causes failures if summon_full_params does
            # not synchronize streams (i.e. without the stream sync fix).
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP params has issues, so clone the params instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)
 def test_fsdp_calc_grad_norm_error(self, norm_type):
     """Test the abnormal cases of grad norm cal API."""
     model = DeterministicModel(False)
     input = torch.rand(12, 2, device=self.rank)
     out = model(input)
     out.sum().backward()
     error_msg = f"Order {norm_type} not supported for matrix norm"
     with self.assertRaisesRegex(RuntimeError, error_msg):
         total_norm = _calc_grad_norm(model.parameters(), norm_type)
    def test_summon_full_params_equivalence(self):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        with model.summon_full_params(recurse=True):
            # The sleep below causes failures if summon_full_params does
            # not synchronize streams (i.e. without the stream sync fix).
            torch.cuda._sleep(1000000)
            fsdp_params = deepcopy(list(model.parameters()))

        self.assertEqual(fsdp_params, list(local_model.parameters()))
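
The summon_full_params tests in this listing all follow the same pattern: gather the sharded parameters back to their full shapes inside the context manager, snapshot them, and compare against a plain (non-FSDP) replica. Below is a minimal hedged sketch of that pattern, assuming an already constructed FSDP model and reference module; the helper name is illustrative and not part of the test file.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def assert_full_params_match(fsdp_model: FSDP, reference: nn.Module) -> None:
    # Inside summon_full_params the sharded parameters are temporarily
    # gathered back to their full, unsharded shapes; clone them before
    # leaving the context, since they are resharded again on exit.
    with FSDP.summon_full_params(fsdp_model):
        gathered = [p.detach().clone() for p in fsdp_model.parameters()]
    for got, expected in zip(gathered, reference.parameters()):
        torch.testing.assert_close(got.cpu(), expected.detach().cpu())
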
Example #4
 def _run_fsdp_one_iteration(self, norm_type, nested_fsdp, cpu_offload):
     """Test FSDP with clip grad norm."""
     fsdp_model = DeterministicModel(nested_fsdp, cpu_offload=cpu_offload)
     local_model = DeterministicModel(False)
     input = torch.rand(14, 2, device=self.rank)
     fsdp_model = FSDP(fsdp_model, cpu_offload=cpu_offload)
     self.assertTrue(len(input) >= self.world_size)
     out = local_model(input[:self.world_size])
     out.sum().backward()
     in_data = torch.tensor(input[self.rank], device=self.rank)
     out_fsdp = fsdp_model(in_data)
     out_fsdp.sum().backward()
     total_norms_fsdp = _collect_total_grad_norm_fsdp(
         fsdp_model, norm_type, self.rank)
     total_norms_local = _collect_total_grad_norm_local(
         local_model, norm_type)
     total_norms_local /= self.world_size
     norm_cap = total_norms_fsdp / 2.0
     self.assertEqual(total_norms_local, total_norms_fsdp)
     fsdp_model.clip_grad_norm_(norm_cap, norm_type=norm_type)
     nn_utils.clip_grad_norm_(local_model.parameters(),
                              norm_cap,
                              norm_type=norm_type)
     total_norms_after_clip_fsdp = _collect_total_grad_norm_fsdp(
         fsdp_model, norm_type, self.rank)
     total_norms_after_clip_local = _collect_total_grad_norm_local(
         local_model, norm_type)
     self.assertTrue(total_norms_after_clip_fsdp <= norm_cap)
     self.assertEqual(total_norms_after_clip_local,
                      total_norms_after_clip_fsdp)
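
For reference, here is a small self-contained sketch of the local (non-FSDP) side of the comparison above: clip the gradients of a plain nn.Module with torch.nn.utils.clip_grad_norm_ and check that the recomputed total norm respects the cap. The toy module, input size, and cap value are illustrative assumptions.

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
model(torch.rand(14, 2)).sum().backward()

norm_cap = 0.1
# clip_grad_norm_ returns the total norm computed *before* clipping.
pre_clip_norm = nn.utils.clip_grad_norm_(model.parameters(), norm_cap, norm_type=2.0)

# Recompute the total norm after clipping; it should not exceed the cap.
post_clip_norm = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(p.grad, 2.0) for p in model.parameters()]),
    2.0,
)
assert post_clip_norm <= norm_cap + 1e-6
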
Example #5
 def test_fsdp_calc_grad_norm(self, norm_type, nested_fsdp):
     """Test grad norm cal API."""
     model = FSDP(DeterministicModel(nested_fsdp))
     input = torch.rand(15, 2, device=self.rank)
     out = model(input)
     out.sum().backward()
     total_norm = _calc_grad_norm(model.params_with_grad, norm_type)
     total_norm_expected = _collect_total_grad_norm_local(model, norm_type)
     self.assertEqual(total_norm, total_norm_expected)
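
_calc_grad_norm and _collect_total_grad_norm_local are helpers defined elsewhere in the test file. The following is a hedged sketch of how such a total gradient norm is typically computed locally; the function below is an illustration, not the test's actual helper.

import torch

def total_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    # The p-norm of the concatenation of all gradients equals the p-norm
    # of the vector of per-parameter p-norms.
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    per_param_norms = torch.stack(
        [torch.linalg.vector_norm(g, norm_type) for g in grads]
    )
    return torch.linalg.vector_norm(per_param_norms, norm_type)
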
Example #6
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(
            DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
        )
        local_model = DeterministicModel(wrap_fsdp=False)

        params_to_compare = (
            [p.clone() for p in model.parameters()]
            if rank0_only and self.rank != 0
            else list(local_model.parameters())
        )

        writeback = not rank0_only

        with model.summon_full_params(
            model,
            recurse=True,
            rank0_only=rank0_only,
            writeback=writeback,
            offload_to_cpu=offload_to_cpu,
        ):
            if writeback:
                with torch.no_grad():
                    for p in model.parameters():
                        p.add_(1)
                    for p in params_to_compare:
                        p.add_(1)
            # The sleep below causes failures if summon_full_params does
            # not synchronize streams (i.e. without the stream sync fix).
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP params has issues, so clone the params instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)

        # CPU offload is enabled, so the params should point back to CPU
        # after exiting summon_full_params.
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("cpu"))
    def test_summon_from_non_fsdp(self):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2, fsdp_3):
                super().__init__()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2
                self.fsdp_3 = fsdp_3

        model_fsdp = FSDPContainer(
            FSDP(DeterministicModel(wrap_fsdp=True)),
            FSDP(DeterministicModel(wrap_fsdp=True)),
            DeterministicModel(wrap_fsdp=False),
        )
        model_no_fsdp = FSDPContainer(
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
        )

        params_to_compare = list(model_no_fsdp.parameters())
        with FSDP.summon_full_params(model_fsdp):
            fsdp_params = [p.clone() for p in model_fsdp.parameters()]

        self.assertEqual(params_to_compare, fsdp_params)
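
A note on the last test: the static FSDP.summon_full_params can be applied to a plain nn.Module root and recurses into any FSDP-wrapped children. Below is a minimal sketch of that usage, assuming the container and its FSDP submodules were already built under an initialized process group; the helper name is illustrative.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def full_param_shapes(container: nn.Module):
    # `container` itself is not FSDP-wrapped; summon_full_params still
    # unshards the parameters of every FSDP instance nested inside it,
    # while parameters of plain submodules are returned unchanged.
    with FSDP.summon_full_params(container):
        return [p.shape for p in container.parameters()]
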