def test_rekey_optim_state_dict_to_names(
    self,
    use_multiple_param_groups: bool,
):
    """Tests :meth:`rekey_optim_state_dict` when re-keying by parameter
    name.

    A non-wrapped model (i.e. one without FSDP modules) re-keys its own
    optimizer state dict using parameter names, which should produce the
    same format that :meth:`full_optim_state_dict` returns. That re-keyed
    dict is then sharded with :meth:`shard_full_optim_state_dict` against a
    wrapped model (i.e. one with FSDP modules), and the result must match
    the wrapped model's per-rank optimizer state dict.
    """
    NUM_ITERS = 3

    def init_and_step(wrap):
        # Build a nested model (FSDP-wrapped or not) and train it briefly so
        # that its optimizer has non-trivial state
        model, optim, optim_input = self._init_nested_model(
            wrap=wrap,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model, optim, num_iters=NUM_ITERS)
        return model, optim, optim_input

    fsdp_model, fsdp_optim, fsdp_optim_input = init_and_step(wrap=True)
    plain_model, plain_optim, plain_optim_input = init_and_step(wrap=False)

    # Re-key the non-wrapped model's optimizer state dict by parameter name,
    # resolving the names against the non-wrapped model itself
    name_keyed_osd = FSDP.rekey_optim_state_dict(
        plain_optim.state_dict(),
        OptimStateKeyType.PARAM_NAME,
        plain_model,
        plain_optim_input,
    )
    # Sharding against the wrapped model maps the names back to
    # (flattened) parameter IDs
    sharded_osd = FSDP.shard_full_optim_state_dict(
        name_keyed_osd,
        fsdp_model,
        fsdp_optim_input,
    )

    # The sharded result should agree with the wrapped model's own per-rank
    # optimizer state dict, including the parameter keys themselves
    fsdp_osd = fsdp_optim.state_dict()
    for check in (self._check_same_param_groups, self._check_same_state):
        check(sharded_osd, fsdp_osd, check_same_param_keys=True)

    # Sanity check: the sharded dict can be loaded and trained with
    fsdp_optim.load_state_dict(sharded_osd)
    self._step_model(fsdp_model, fsdp_optim, num_iters=NUM_ITERS)
def test_shard_full_optim_state_dict_unmanaged_params(
    self,
    add_to_fsdp_module: bool,
):
    """
    Tests :meth:`shard_full_optim_state_dict` when there are unmanaged
    parameters.

    - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
      added to a module to be wrapped with FSDP, in which case there should
      be an error since we require that all unflattened parameters
      comprising a flattened parameter have the same scalar state (e.g.
      Adam "step"), but the added parameters are missing their entries.
    - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
      added to a module not to be wrapped with FSDP, in which case there
      should be no error (emulating model parallel use cases where some
      parameters may be managed externally to FSDP).

    We do not separately test unmanaged parameters for
    :meth:`scatter_full_optim_state_dict` to save CI cost since it calls
    into the same subroutine :meth:`_flatten_full_optim_state_dict`.
    """
    NUM_ITERS = 1
    # Create a normal wrapped model and step it to populate optimizer state
    model, optim, optim_input = self._init_nested_model(wrap=True)
    self._step_model(model, optim, num_iters=NUM_ITERS)
    # Materialize the full optimizer state dict on all ranks (not only rank
    # 0) to avoid having to broadcast it from rank 0
    full_osd = FSDP.full_optim_state_dict(
        model, optim, optim_input, rank0_only=False,
    )

    # Create a new model with the same structure but with additional
    # unmanaged parameters; this is the model into which we want to load
    device = torch.device("cuda")
    model = NestedModel().to(device)
    model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
        model, add_to_fsdp_module,
    )
    optim_input = list(model.parameters())

    if add_to_fsdp_module:
        # With the unmanaged parameters inside an FSDP-wrapped module, the
        # flattened parameter comprises some unflattened parameters with
        # zero-dimensional tensor state (i.e. Adam "step") and others
        # without it (i.e. the unmanaged parameters), so sharding must
        # raise this error
        error_prefix = (
            "^(All unflattened parameters comprising a "
            "single flattened parameter must have scalar state with the "
            "same value and dtype)"
        )
        with self.assertRaisesRegex(ValueError, error_prefix):
            FSDP.shard_full_optim_state_dict(full_osd, model, optim_input)
        return

    # With the unmanaged parameters outside any FSDP-wrapped module, they
    # are simply ignored without erroring, enabling model parallelism use
    # cases where some parameters are managed externally to FSDP
    sharded_osd = FSDP.shard_full_optim_state_dict(
        full_osd, model, optim_input,
    )
    # Add entries for the unmanaged parameters so that the state dict can be
    # loaded into an optimizer over all of the new model's parameters
    for unmanaged_param in unmanaged_params:
        NestedModel.add_unmanaged_param_entry(
            sharded_osd, unmanaged_param, NUM_ITERS,
        )
    # Check that the optimizer state dict can be loaded
    optim = torch.optim.Adam(optim_input, lr=1e-3)
    optim.load_state_dict(sharded_osd)
def _test_shard_full_optim_state(
    self,
    model_class: str,
    use_multiple_param_groups: bool,
    halve_world_size: bool,
    **new_model_kwargs,
):
    """
    (1) Runs a wrapped model with full world size for K iterations to
    generate a full optimizer state dict;
    (2) initializes a second wrapped model with halved world size if
    ``halve_world_size=True`` (otherwise with full world size) and possibly
    a different FSDP wrapping scheme (based on ``new_model_kwargs``), and
    runs it for K iterations;
    (3) shards the full optimizer state dict from (1) according to the
    second model;
    (4) checks that the sharded optimizer state dict from (3) matches the
    second model's local optimizer state dict, meaning that the former
    could have equivalently been loaded into the local optimizer; and
    (5) as a sanity check, loads it and runs K more iterations.

    When the world size is halved, ranks outside the new process group
    return right after the group is created.
    """
    NUM_ITERS = 3
    if model_class == "nested":
        initializer = self._init_nested_model
    elif model_class == "transformer":
        initializer = self._init_transformer_model
    else:
        # Explicit raise (rather than `assert`) so the check survives `-O`
        raise AssertionError(f"Unsupported model: {model_class}")

    # Run a wrapped model with full world size for a few iterations
    model1, optim1, optim_input1 = initializer(
        wrap=True,
        use_multiple_param_groups=use_multiple_param_groups,
    )
    self._step_model(model1, optim1, num_iters=NUM_ITERS)
    full_osd1 = FSDP.full_optim_state_dict(model1, optim1, optim_input1)
    # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
    # have the full state dict
    full_osd1 = self._broadcast_full_osd(full_osd1)

    if halve_world_size:
        # Create a new process group containing only the even ranks. Every
        # rank calls `new_group()`, including non-members, before the
        # non-members exit.
        new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
        new_group = dist.new_group(ranks=new_group_ranks)
        if self.rank not in new_group_ranks:
            return
    else:
        # Continue using the default group and hence the full world size
        new_group = dist.distributed_c10d._get_default_group()

    # Run a wrapped model with (possibly) halved world size from scratch
    model2, optim2, optim_input2 = initializer(
        wrap=True,
        group=new_group,
        use_multiple_param_groups=use_multiple_param_groups,
        **new_model_kwargs,  # specify `wrap_alt` to change wrapping
    )
    self._step_model(model2, optim2, num_iters=NUM_ITERS)
    full_osd2 = FSDP.full_optim_state_dict(model2, optim2, optim_input2)
    full_osd2 = self._broadcast_full_osd(full_osd2, group=new_group)

    # As a sanity check, sharding the second model's full optimizer state
    # dict according to itself should reproduce its local optimizer's
    # state dict
    local_osd2 = optim2.state_dict()
    sharded_osd2 = FSDP.shard_full_optim_state_dict(
        full_osd2, model2, optim_input2,
    )
    check_same_param_keys = True  # should all have matching parameter IDs
    self._check_same_param_groups(
        sharded_osd2, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )
    self._check_same_state(
        sharded_osd2, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )

    # Sharding the first model's full optimizer state dict according to the
    # second model should be equivalent to the second model's local
    # optimizer state dict
    sharded_osd1 = FSDP.shard_full_optim_state_dict(
        full_osd1, model2, optim_input2,
    )
    self._check_same_param_groups(
        sharded_osd1, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )
    self._check_same_state(
        sharded_osd1, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )

    # As a sanity check, check that we can load and run a few iterations
    optim2.load_state_dict(sharded_osd1)
    self._step_model(model2, optim2, num_iters=NUM_ITERS)
def _test_shard_full_optim_state(
    self,
    model_class: str,
    use_multiple_param_groups: bool,
    halve_world_size: bool,
    osd_comm_method: _OSDCommMethod,
    **new_model_kwargs,
):
    """
    (1) Runs a wrapped model with full world size for K iterations to
    generate a full optimizer state dict;
    (2) initializes a second wrapped model with halved world size if
    ``halve_world_size=True`` (otherwise with full world size) and possibly
    a different FSDP wrapping scheme (based on ``new_model_kwargs``), and
    runs it for K iterations;
    (3) shards the full optimizer state dict from (1) according to the
    second model, communicating it via ``osd_comm_method``;
    (4) checks that the sharded optimizer state dict from (3) matches the
    second model's local optimizer state dict, meaning that the former
    could have equivalently been loaded into the local optimizer; and
    (5) as a sanity check, loads it and runs K more iterations.

    When the world size is halved, ranks outside the new process group
    return right after the group is created.

    Raises:
        ValueError: if ``osd_comm_method`` is neither
            ``BROADCAST_OBJECT_LIST`` nor ``SCATTER_FULL_OSD``.
    """
    NUM_ITERS = 3
    if model_class == "nested":
        initializer = self._init_nested_model
    elif model_class == "transformer":
        initializer = self._init_transformer_model
    else:
        # Explicit raise (rather than `assert`) so the check survives `-O`
        raise AssertionError(f"Unsupported model: {model_class}")

    # First, run a wrapped model with full world size for a few iterations
    model1, optim1, optim_input1 = initializer(
        wrap=True,
        use_multiple_param_groups=use_multiple_param_groups,
    )
    self._step_model(model1, optim1, num_iters=NUM_ITERS)
    full_osd1 = FSDP.full_optim_state_dict(model1, optim1, optim_input1)

    if halve_world_size:
        # Create a new process group containing only the even ranks. Every
        # rank calls `new_group()`, including non-members, before the
        # non-members exit.
        new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
        new_group = dist.new_group(ranks=new_group_ranks)
        if self.rank not in new_group_ranks:
            return
    else:
        # Continue using the same group and hence world size
        new_group = dist.distributed_c10d._get_default_group()

    # Second, run a wrapped model with (possibly) halved world size
    model2, optim2, optim_input2 = initializer(
        wrap=True,
        group=new_group,
        use_multiple_param_groups=use_multiple_param_groups,
        **new_model_kwargs,  # specify `wrap_alt` to change wrapping
    )
    self._step_model(model2, optim2, num_iters=NUM_ITERS)
    full_osd2 = FSDP.full_optim_state_dict(model2, optim2, optim_input2)

    # Compute two sharded optim state dicts: (1) for the first model
    # according to the second model and (2) for the second model according
    # to the second model
    if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST:
        full_osd1 = self._broadcast_full_osd(full_osd1, group=new_group)
        sharded_osd1 = FSDP.shard_full_optim_state_dict(
            full_osd1, model2, optim_input2,
        )
        full_osd2 = self._broadcast_full_osd(full_osd2, group=new_group)
        sharded_osd2 = FSDP.shard_full_optim_state_dict(
            full_osd2, model2, optim_input2,
        )
    elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD:
        # Only rank 0 passes the full state dict; the others pass `None`
        sharded_osd1 = FSDP.scatter_full_optim_state_dict(
            full_osd1 if self.rank == 0 else None, model2, optim_input2,
            group=new_group,
        )
        sharded_osd2 = FSDP.scatter_full_optim_state_dict(
            full_osd2 if self.rank == 0 else None, model2, optim_input2,
            group=new_group,
        )
        self._check_state_device(sharded_osd1, on_gpu=True)
        self._check_state_device(sharded_osd2, on_gpu=True)
    else:
        # Without this, an unknown method would surface much later as an
        # opaque UnboundLocalError on `sharded_osd2`
        raise ValueError(
            f"Unsupported optimizer state dict communication method: "
            f"{osd_comm_method}"
        )

    # As a sanity check, check that sharding the second model's full
    # optimizer state dict according to itself is equivalent to its local
    # optimizer's state dict
    local_osd2 = optim2.state_dict()
    check_same_param_keys = True  # should all have matching parameter IDs
    self._check_same_param_groups(
        sharded_osd2, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )
    self._check_same_state(
        sharded_osd2, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )

    # Check that sharding the first model's full optimizer state dict
    # according to the second model is equivalent to the second model's
    # local optimizer state dict
    self._check_same_param_groups(
        sharded_osd1, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )
    self._check_same_state(
        sharded_osd1, local_osd2,
        check_same_param_keys=check_same_param_keys,
    )

    # As a sanity check, check that we can load and run a few iterations
    optim2.load_state_dict(sharded_osd1)
    self._step_model(model2, optim2, num_iters=NUM_ITERS)