def test_full_optim_state_dict_keys(self):
    """Tests that the parameter keys returned by
    :meth:`full_optim_state_dict` match those of :meth:`state_dict` with full
    ``state_dict_type`` for a non-FSDP-root model with nested FSDP instances
    and ignored modules."""
    device = torch.device("cuda")
    model = NestedModel().to(device)
    wrapped_model = NestedModel.wrap(model, ignore_modules=True)
    # Add checkpointing to ensure optim_state_dict and state_dict strip out
    # checkpointing prefixes.
    apply_activation_checkpointing_wrapper(
        model,
        check_fn=lambda module: isinstance(module, torch.nn.Sequential),
    )
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._step_model(model, optim, device)
    optim_state_dict = FSDP.full_optim_state_dict(
        wrapped_model, optim, rank0_only=False
    )
    with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT):
        state_dict = wrapped_model.state_dict()
    self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
    # Check that checkpointing prefix was indeed stripped.
    for key in optim_state_dict["state"]:
        self.assertNotIn(_CHECKPOINT_PREFIX, key)
def _get_full_state_dict_mgr(self, model, state_dict_rank0_and_offload):
    return FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(
            rank0_only=state_dict_rank0_and_offload,
            offload_to_cpu=state_dict_rank0_and_offload,
        ),
    )
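# Hedged usage sketch (not part of the original tests; _example_full_state_dict_usage
# is a hypothetical name): the context manager built above is meant to wrap
# state_dict() so that, with rank0_only/offload_to_cpu enabled, only rank 0
# materializes the full state dict and its tensors live on CPU.
def _example_full_state_dict_usage(self, model):
    with self._get_full_state_dict_mgr(model, state_dict_rank0_and_offload=True):
        sd = model.state_dict()
    # Per the rank0_and_offload tests below, nonzero ranks are expected to see
    # an empty dict; rank 0 holds CPU tensors that are later broadcast.
    return sd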
def test_wrong_state_dict_config(self):
    model = FSDP(Model(wrap_fsdp=True).cuda())
    with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"):
        with model.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()
        ):
            pass
def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module
    model = Model(wrap_fsdp=True).cuda()
    ignored_modules = [model.outer]
    ignored_param_to_param_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters are not cloned
    for param, param_name in ignored_param_to_param_name.items():
        self.assertTrue(param_name in sd)
        self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module that includes
    # both parameters and a buffer
    model = Model(wrap_fsdp=True, register_buffers=True).cuda()
    ignored_modules = [model.outer]
    ignored_tensor_to_tensor_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
        model.outer.buffer: "outer.buffer",
    }
    buffer_to_buffer_name = {
        model.inner.buffer: "inner.buffer",
        model.outer.buffer: "outer.buffer",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd1 = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters and all buffers are not cloned
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd1)
        self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd1)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
    # Check that if we save a state dict again, the ignored parameters and
    # buffers still have the same data pointers
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd2 = fsdp_model.state_dict()
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd1)  # check again just in case
        self.assertTrue(tensor_name in sd2)
        self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
        self.assertEqual(sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr())
def test_state_dict_type(self):
    module = SkipModel(double_nest=True)
    with enable_wrap(wrapper_cls=FSDP):
        fsdp = wrap(module)
    with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
        pass
    for module in FSDP.fsdp_modules(fsdp):
        self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
def _state_dict(model: Module, state_dict_type: str):
    try:
        enum_val = STATE_DICT_MAPPING[state_dict_type]
    except KeyError:
        raise ValueError(f"No state_dict type for {state_dict_type}")

    with FSDP.state_dict_type(model, enum_val):
        return model.state_dict()
def _load_state_dict(model: Module, state_dict_type: str, state_dict: Dict[str, Any]):
    try:
        enum_val = STATE_DICT_MAPPING[state_dict_type]
    except KeyError:
        raise ValueError(f"No state_dict type for {state_dict_type}")

    with FSDP.state_dict_type(model, enum_val):
        return model.load_state_dict(state_dict, strict=True)
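# STATE_DICT_MAPPING is defined elsewhere in the test file; based on the inline
# mapping used by _dist_train below, it is assumed to translate the string
# parameterizations into StateDictType values. A minimal sketch:
STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}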
def _get_state_dict_mgr(self, model, state_dict_type, state_dict_rank0_and_offload):
    _state_dict_type = STATE_DICT_MAPPING[state_dict_type]
    if state_dict_type == "state_dict":
        config = FullStateDictConfig(
            rank0_only=state_dict_rank0_and_offload,
            offload_to_cpu=state_dict_rank0_and_offload,
        )
    else:
        config = None
    return FSDP.state_dict_type(model, _state_dict_type, config)
def test_distributed_checkpoint(self, state_dict_type) -> None:
    with enable_wrap(wrapper_cls=FSDP):
        torch.manual_seed(100)
        model = wrap(SkipModel(double_nest=True))
        torch.manual_seed(200)
        new_model = wrap(SkipModel(double_nest=True))

    with FullyShardedDataParallel.summon_full_params(
        model
    ), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertNotEqual(params, new_params)

    with tempfile.TemporaryDirectory() as path:
        paths = [path]
        dist.broadcast_object_list(paths)
        path = paths[0]
        writer = FileSystemWriter(path)
        reader = FileSystemReader(path)
        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = model.state_dict()

        save_state_dict(state_dict, writer)

        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

    with FullyShardedDataParallel.summon_full_params(
        model
    ), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertEqual(params, new_params)
def test_full_optim_state_dict_keys(self):
    """Tests that the parameter keys returned by
    :meth:`full_optim_state_dict` match those of :meth:`state_dict` with full
    ``state_dict_type`` for a non-FSDP-root model with nested FSDP instances
    and ignored modules."""
    device = torch.device("cuda")
    model = NestedModel().to(device)
    wrapped_model = NestedModel.wrap(model, ignore_modules=True)
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._step_model(model, optim, device)
    optim_state_dict = FSDP.full_optim_state_dict(
        wrapped_model, optim, rank0_only=False
    )
    with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT):
        state_dict = wrapped_model.state_dict()
    self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
def _dist_train(
    self, wrap_fsdp: bool, state_dict_type: str = "", with_context: bool = False
):
    # TODO: Move this test to common_fsdp.
    model = self._initialize_model(wrap_fsdp)
    optim = SGD(model.parameters(), lr=0.1)

    in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
    for _ in range(3):
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if wrap_fsdp:
        blank_model = FSDP(Model(True).cuda())
        _zero_model(blank_model)
        if with_context:
            state_dict_type = {
                "state_dict": StateDictType.FULL_STATE_DICT,
                "local_state_dict": StateDictType.LOCAL_STATE_DICT,
                "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
            }[state_dict_type]
            with model.state_dict_type(state_dict_type):
                state_dict = model.state_dict()
            with blank_model.state_dict_type(state_dict_type):
                blank_model.load_state_dict(state_dict)
        else:
            state_dict = self._state_dict(model, state_dict_type)
            self._load_state_dict(blank_model, state_dict_type, state_dict)
        return get_full_params(blank_model)
    else:
        return list(model.parameters())
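# _zero_model and get_full_params come from the shared FSDP test utilities and
# are not shown in this file. The stand-ins below are assumptions that mirror
# the inline patterns used in the ignored-modules tests (zeroing under
# torch.no_grad and gathering via summon_full_params + deepcopy); the real
# helpers may differ in detail.
def _zero_model(model):
    # Zero every parameter so that a later load_state_dict visibly restores values.
    with FSDP.summon_full_params(model):
        for param in model.parameters():
            with torch.no_grad():
                param.zero_()


def get_full_params(model):
    # Gather the unsharded parameters and return a detached copy of them.
    with FSDP.summon_full_params(model):
        return deepcopy(list(model.parameters()))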
def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
    # Initialize an FSDP-wrapped model with an ignored module that includes
    # both parameters and a buffer
    model = Model(
        wrap_fsdp=True, register_buffers=True, ignore_inner=ignore_inner
    ).cuda()
    ignored_modules = [model.outer]
    ignored_tensor_to_tensor_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    if ignore_inner:
        ignored_tensor_to_tensor_name = {
            **ignored_tensor_to_tensor_name,
            model.inner.bias: "inner.bias",
            model.inner.weight: "inner.weight",
        }
    # Note that when model.inner is not ignored, this test also ensures that
    # non-ignored buffers are not cloned.
    buffer_to_buffer_name = {
        model.inner.buffer: "inner.buffer",
        model.outer.buffer: "outer.buffer",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    prefix_str = "foo." if prefix else ""
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd1 = fsdp_model.state_dict(prefix=prefix_str)
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters and all buffers are not cloned
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        prefixed_tensor_name = f"{prefix_str}{tensor_name}"
        self.assertTrue(prefixed_tensor_name in sd1)
        self.assertEqual(
            tensor.data_ptr(),
            sd1[prefixed_tensor_name].data_ptr(),
            f"{prefixed_tensor_name}",
        )
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
    nonwrapped_model.load_state_dict(to_load, strict=True)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
    # Check that if we save a state dict again, the ignored parameters and
    # buffers still have the same data pointers
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd2 = fsdp_model.state_dict(prefix=prefix_str)
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        prefixed_tensor_name = f"{prefix_str}{tensor_name}"
        self.assertTrue(prefixed_tensor_name in sd2)
        self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr())
        self.assertEqual(
            sd1[prefixed_tensor_name].data_ptr(),
            sd2[prefixed_tensor_name].data_ptr(),
        )
def test_basic_save_and_load_state_dict(
    self, state_dict_type, cpu_offload, fp16, state_dict_rank0_and_offload
):
    """
    Tests that we can save a state_dict and load it into a blank model,
    with various configs such as fp16 and cpu offload, and that the
    parameters match as expected.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    for model_call in [
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()

        ctx = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with ctx:
            fsdp_state_dict = _get_state_dict(
                model, cpu_offload.offload_params, fp16
            )

        self._validate_state_dict_contents(
            fsdp_state_dict, state_dict_rank0_and_offload
        )
        if fp16:
            # Verify fp16 is the type
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # zero the model to ensure parameters are different.
        _zero_model(model_new)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
            model_new.load_state_dict(fsdp_state_dict)
        with FullyShardedDataParallel.summon_full_params(model_new):
            with FullyShardedDataParallel.summon_full_params(model):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
                if fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)
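# _broadcast_state_dict is another helper that is not shown in this file. With
# rank0_only full state dicts only rank 0 holds the contents, so they must be
# replicated to the other ranks before load_state_dict. A hedged sketch using
# broadcast_object_list (the real implementation may differ):
def _broadcast_state_dict(self, state_dict):
    # Rank 0 contributes its (CPU) state dict; every other rank receives a copy.
    olist = [state_dict if self.rank == 0 else None]
    dist.broadcast_object_list(olist)
    return olist[0]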
def test_state_dict_skip_module(self, state_dict_type, double_nest):
    torch.cuda.set_device(self.rank)

    def _create_module(wrap_fsdp=True):
        LINEAR_SKIP = "linear_skip"
        ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
        with ctx:
            module = SkipModel(double_nest=double_nest)
            # Full name of linear_skip param tensors in SkipModel, as would be
            # stored in checkpoint.
            linear_skip_tensor_names = [
                k for k in dict(module.named_parameters()).keys() if LINEAR_SKIP in k
            ]
            # skip SkipModule
            linear_skip = getattr(module, LINEAR_SKIP)
            delattr(module, LINEAR_SKIP)
            # Wrap FSDP
            fsdp = wrap(module)
            # reattach
            setattr(module, LINEAR_SKIP, linear_skip)
            return fsdp, linear_skip_tensor_names

    fsdp, linear_skip_tensor_names = _create_module()
    # Run a forward pass
    inp = torch.randn((1, 10), device=torch.cuda.current_device())
    loss = fsdp(inp)
    loss.sum().backward()

    with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
        state_dict = fsdp.state_dict()
    if self.rank == 0 and state_dict_type != "local_state_dict":
        sd_keys = list(state_dict.keys())
        expected = list(SkipModel(double_nest=False).state_dict().keys())
        self.assertEqual(sorted(sd_keys), sorted(expected))
        # TODO: parameters in linear_skip_tensor_names should not be handled
        # by FSDP.state_dict(). Have a check once this is implemented in
        # FSDP.state_dict().

    # Check that it can be loaded into FSDP.
    new_fsdp, _ = _create_module()
    _zero_model(new_fsdp)
    for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
        self.assertNotEqual(p1, p2)
    with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
        if state_dict_type != "local_state_dict":
            # FlatParameter does not support deepcopy yet.
            state_dict = deepcopy(state_dict)
        new_fsdp.load_state_dict(state_dict, strict=True)
    for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
        self.assertEqual(p1, p2)

    # Test that the checkpoint can be loaded into a local model.
    local, _ = _create_module(wrap_fsdp=False)
    for param in local.parameters():
        with torch.no_grad():
            param.zero_()

    with fsdp.summon_full_params(fsdp):
        for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
            self.assertNotEqual(p1, p2)

    if state_dict_type == "local_state_dict":
        return
    state_dict = _gather_state_dict(state_dict)
    with fsdp.summon_full_params(fsdp):
        if self.rank == 0:
            local.load_state_dict(state_dict, strict=True)
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertEqual(p1, p2)
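# _gather_state_dict is likewise a shared helper that is not shown here. For a
# sharded_state_dict the values are ShardedTensors, which have to be gathered
# into regular tensors before loading into a plain local module. The sketch
# below is an assumption (duck-typed on local_shards/gather instead of
# importing ShardedTensor) and only illustrates the idea:
def _gather_state_dict(state_dict):
    gathered = {}
    rank = dist.get_rank()
    for key, value in state_dict.items():
        if hasattr(value, "local_shards"):
            # ShardedTensor-like value: materialize the full tensor on rank 0.
            out = (
                torch.empty(
                    value.size(), dtype=value.dtype,
                    device=torch.cuda.current_device(),
                )
                if rank == 0
                else None
            )
            value.gather(dst=0, out=out)
            gathered[key] = out
        else:
            # Full (unsharded) tensors pass through unchanged.
            gathered[key] = value
    return gathered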
def test_basic_save_and_load_state_dict(
    self, state_dict_type, cpu_offload, fp16, state_dict_rank0_and_offload
):
    """
    Tests that we can save a state_dict and load it into a blank model,
    with various configs such as fp16 and cpu offload, and that the
    parameters match as expected.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    for model_call in [
        partial(self._get_non_fsdp_root_module, cpu_offload=cpu_offload),
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()

        ctx = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with ctx:
            fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16)

        ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]

        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        if fp16:
            # Verify fp16 is the type
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # zero the model to ensure parameters are different.
        _zero_model(model_new)
        self._compare_models(model, model_new, self.assertNotEqual)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            if not isinstance(model, FSDP):
                # Move everything to CPU to avoid running into
                # https://github.com/pytorch/pytorch/issues/77113; some params
                # will still be on GPU for non-FSDP root modules.
                for k in fsdp_state_dict.keys():
                    fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
            model_new.load_state_dict(fsdp_state_dict, strict=True)

        self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)
def test_save_and_load_after_forward_state_dict(
    self, state_dict_type, mixed_precision, state_dict_rank0_and_offload
):
    """
    Test that saving after some training results in params being updated as
    expected.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    torch.cuda.set_device(self.rank)
    mixed_precision = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if mixed_precision
        else None
    )
    model = self._get_simple_nested_model(mixed_precision=mixed_precision)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_params = get_full_params(model)
    for _ in range(6):
        inp = torch.randn(1, 10, device=torch.cuda.current_device())
        output = model(*inp)
        loss = output.sum()
        expected_dtype = torch.float32 if mixed_precision is None else torch.float16
        self.assertEqual(expected_dtype, loss.dtype)
        loss.backward()
        optim.step()
    trained_params = get_full_params(model)
    # Ensure some training occurred
    self.assertNotEqual(initial_params, trained_params)
    # Save a copy of the state_dict
    fsd_mgr = self._get_state_dict_mgr(
        model, state_dict_type, state_dict_rank0_and_offload
    )
    with fsd_mgr:
        state_dict = model.state_dict()
        if state_dict_type == "state_dict":
            state_dict = {k: v.clone() for k, v in state_dict.items()}
        else:
            for sharded_tensor in state_dict.values():
                shard = sharded_tensor._local_shards[0]
                shard.tensor = shard.tensor.clone().detach_()
    self._validate_state_dict_contents(model, state_dict, state_dict_rank0_and_offload)
    _zero_model(model)

    # Ensure checkpointed params have the full param dtype
    for tensor in state_dict.values():
        self.assertEqual(tensor.dtype, torch.float32)

    # Load state_dict into zeroed model
    if state_dict_rank0_and_offload:
        # Broadcast the state dict and move it back to GPU in
        # preparation for loading.
        state_dict = self._broadcast_state_dict(state_dict)
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cuda()

    with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
        model.load_state_dict(state_dict, strict=True)
    loaded_params = get_full_params(model)
    self.assertEqual(loaded_params, trained_params)