# Imports assumed by these test methods (following PyTorch FSDP test-suite
# conventions; `self` is a distributed test-case instance). `Model` is a
# test fixture defined elsewhere; a hypothetical sketch appears at the end
# of this listing.
from copy import deepcopy

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module
    model = Model(wrap_fsdp=True).cuda()
    ignored_modules = [model.outer]
    ignored_param_to_param_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters are not cloned
    for param, param_name in ignored_param_to_param_name.items():
        self.assertTrue(param_name in sd)
        self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
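# For orientation, the core API pattern the tests in this listing exercise,
# sketched outside the test harness. `MyModel` and its `head` submodule are
# illustrative names, not from the source; running this requires an
# initialized process group and a GPU, since gathering a FULL_STATE_DICT
# performs collective communication.
def _full_state_dict_round_trip():
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.body = torch.nn.Linear(8, 8)
            self.head = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.head(self.body(x))

    model = MyModel().cuda()
    # Parameters of ignored modules are left alone (not flattened or
    # sharded), so their state-dict entries alias the live tensors.
    fsdp_model = FSDP(model, ignored_modules=[model.head])
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()  # unsharded tensors, original key names
    # FULL_STATE_DICT restores the original parameter names, so the result
    # loads directly into an unwrapped copy of the model.
    plain_model = MyModel().cuda()
    plain_model.load_state_dict(sd)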
def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module that includes
    # both parameters and a buffer
    model = Model(wrap_fsdp=True, register_buffers=True).cuda()
    ignored_modules = [model.outer]
    ignored_tensor_to_tensor_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
        model.outer.buffer: "outer.buffer",
    }
    buffer_to_buffer_name = {
        model.inner.buffer: "inner.buffer",
        model.outer.buffer: "outer.buffer",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd1 = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters and all buffers are not cloned
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd1)
        self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd1)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
    # Check that saving a state dict again yields the same data pointers for
    # the ignored parameters and the buffers
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd2 = fsdp_model.state_dict()
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd2)
        self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
        self.assertEqual(
            sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr()
        )
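# Background for the data_ptr() assertions above: a plain nn.Module state
# dict holds references to the live tensors rather than clones, so storage
# pointers are stable across calls; the test asserts that FSDP preserves
# this aliasing for ignored parameters and for buffers. A minimal
# standalone illustration (not from the source):
def _illustrate_state_dict_aliasing():
    lin = torch.nn.Linear(2, 2)
    sd_a = lin.state_dict()
    sd_b = lin.state_dict()
    # state_dict() detaches but does not copy, so storage is shared
    assert sd_a["weight"].data_ptr() == lin.weight.data_ptr()
    assert sd_a["weight"].data_ptr() == sd_b["weight"].data_ptr()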
def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
    # `prefix` and `ignore_inner` are assumed to be supplied by a
    # parametrization decorator in the test harness.
    # Initialize an FSDP-wrapped model with an ignored module that includes
    # both parameters and a buffer
    model = Model(
        wrap_fsdp=True, register_buffers=True, ignore_inner=ignore_inner
    ).cuda()
    ignored_modules = [model.outer]
    ignored_tensor_to_tensor_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    if ignore_inner:
        ignored_tensor_to_tensor_name = {
            **ignored_tensor_to_tensor_name,
            model.inner.bias: "inner.bias",
            model.inner.weight: "inner.weight",
        }
    # Note that when model.inner is not ignored, this test also ensures that
    # non-ignored buffers are not cloned.
    buffer_to_buffer_name = {
        model.inner.buffer: "inner.buffer",
        model.outer.buffer: "outer.buffer",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    prefix_str = "foo." if prefix else ""
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd1 = fsdp_model.state_dict(prefix=prefix_str)
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters and all buffers are not cloned
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        prefixed_tensor_name = f"{prefix_str}{tensor_name}"
        self.assertTrue(prefixed_tensor_name in sd1)
        self.assertEqual(
            tensor.data_ptr(),
            sd1[prefixed_tensor_name].data_ptr(),
            prefixed_tensor_name,
        )
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
    nonwrapped_model.load_state_dict(to_load, strict=True)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
    # Check that saving a state dict again yields the same data pointers for
    # the ignored parameters and the buffers
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd2 = fsdp_model.state_dict(prefix=prefix_str)
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        prefixed_tensor_name = f"{prefix_str}{tensor_name}"
        self.assertTrue(prefixed_tensor_name in sd2)
        self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr())
        self.assertEqual(
            sd1[prefixed_tensor_name].data_ptr(),
            sd2[prefixed_tensor_name].data_ptr(),
        )
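# The tests above reference a `Model` fixture that is not shown here. A
# minimal hypothetical sketch, consistent with the attribute names and
# constructor flags used above (`inner`, `outer`, `buffer`, `wrap_fsdp`,
# `register_buffers`, `ignore_inner`); the real fixture may differ:
class Model(torch.nn.Module):
    def __init__(self, wrap_fsdp, register_buffers=False, ignore_inner=False):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)
        self.outer = torch.nn.Linear(4, 5)
        if register_buffers:
            # Registered as "inner.buffer" / "outer.buffer" in the state dict
            self.inner.register_buffer("buffer", torch.randn(4, 4))
            self.outer.register_buffer("buffer", torch.randn(4, 5))
        if wrap_fsdp:
            # Ignoring `inner` here keeps its parameters out of FSDP's
            # flat-sharding, matching the ignore_inner branch of the test.
            self.inner = FSDP(
                self.inner,
                ignored_modules=[self.inner] if ignore_inner else [],
            )

    def forward(self, x):
        return self.outer(self.inner(x))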