Example #1
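These snippets come from PyTorch's FSDP state-dict tests. As a minimal sketch, they assume roughly the following imports; the exact module paths follow current torch.distributed.fsdp conventions and may differ across PyTorch versions:

from copy import deepcopy

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType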
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module
        model = Model(wrap_fsdp=True).cuda()
        ignored_modules = [model.outer]
        ignored_param_to_param_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd = fsdp_model.state_dict()

        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters are not cloned
        for param, param_name in ignored_param_to_param_name.items():
            self.assertTrue(param_name in sd)
            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        nonwrapped_model.load_state_dict(sd)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
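The `Model` helper referenced above is defined elsewhere in the test suite. The following is a hypothetical stand-in, only to make the snippets self-contained: a two-layer module whose `inner` submodule can be FSDP-wrapped, with optional buffers and an `ignore_inner` flag matching Example #3. The real definition may differ.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Model(nn.Module):
    # Hypothetical reconstruction of the test helper; shapes and layout are
    # illustrative, not copied from the PyTorch test suite.
    def __init__(self, wrap_fsdp, register_buffers=False, ignore_inner=False):
        super().__init__()
        self.inner = nn.Linear(4, 4)
        self.outer = nn.Linear(4, 5)
        if register_buffers:
            self.inner.register_buffer("buffer", torch.randn(4))
            self.outer.register_buffer("buffer", torch.randn(5))
        if wrap_fsdp:
            # With ignore_inner=True the inner wrapper ignores its own
            # module's parameters (see Example #3).
            self.inner = FSDP(
                self.inner,
                ignored_modules=[self.inner] if ignore_inner else [],
            )

    def forward(self, x):
        return self.outer(self.inner(x))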
Example #2
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(wrap_fsdp=True, register_buffers=True).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
            model.outer.buffer: "outer.buffer",
        }
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd1 = fsdp_model.state_dict()
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            self.assertTrue(tensor_name in sd1)
            self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()
        nonwrapped_model.load_state_dict(sd1)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that saving the state dict again leaves the ignored parameters
        # and buffers with the same data pointers
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict()
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            self.assertTrue(tensor_name in sd1)  # still present in the first state dict
            self.assertTrue(tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
            self.assertEqual(sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr())
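These test methods run inside PyTorch's multi-process test harness (FSDPTest), one process per GPU. As a rough sketch of the per-rank setup they rely on, not part of the tests themselves, FSDP needs an initialized process group and a CUDA device bound to each rank:

import os

import torch
import torch.distributed as dist

def setup_process_group(rank: int, world_size: int) -> None:
    # Illustrative only: the real tests inherit this setup from the harness.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)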
Example #3
    def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(wrap_fsdp=True,
                      register_buffers=True,
                      ignore_inner=ignore_inner).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored, this test also ensures
        # that non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd1 = fsdp_model.state_dict(prefix=prefix_str)
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(tensor.data_ptr(),
                             sd1[prefixed_tensor_name].data_ptr(),
                             f"{prefixed_tensor_name}")
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(sd1[prefixed_tensor_name].data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
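Example #3 takes `prefix` and `ignore_inner` arguments, so it is meant to run parametrized. A sketch of how such a test is typically registered with PyTorch's internal test utilities; the class name here is hypothetical, and the body is elided to the code shown above:

from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class TestFSDPIgnoredModulesStateDict(FSDPTest):  # hypothetical class name
    @parametrize("prefix", [False, True])
    @parametrize("ignore_inner", [False, True])
    def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
        ...  # body as in Example #3

# Expands each (prefix, ignore_inner) combination into its own test case.
instantiate_parametrized_tests(TestFSDPIgnoredModulesStateDict)

if __name__ == "__main__":
    run_tests()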