def test_rekey_optim_state_dict_to_names(
    self,
    use_multiple_param_groups: bool,
):
    """Tests :meth:`rekey_optim_state_dict` with the new keys being
    parameter names by checking that a non-wrapped model (i.e. without
    FSDP modules) can rekey its optimizer state dict to match the
    expected output of :meth:`full_optim_state_dict`, hence be sharded
    using :meth:`shard_full_optim_state_dict`, and finally match the
    per-rank optimizer state dict of a wrapped model (i.e. with FSDP
    modules)."""
    NUM_ITERS = 3
    # Run a wrapped model for a few iterations
    model1, optim1, optim_input1 = self._init_nested_model(
        wrap=True, use_multiple_param_groups=use_multiple_param_groups,
    )
    self._step_model(model1, optim1, num_iters=NUM_ITERS)
    # Run a non-wrapped model for a few iterations
    model2, optim2, optim_input2 = self._init_nested_model(
        wrap=False, use_multiple_param_groups=use_multiple_param_groups,
    )
    self._step_model(model2, optim2, num_iters=NUM_ITERS)
    # Re-key the non-wrapped model's optimizer state dict using parameter
    # names (still according to itself)
    osd2 = optim2.state_dict()
    rekeyed_osd = FSDP.rekey_optim_state_dict(
        osd2, OptimStateKeyType.PARAM_NAME, model2, optim_input2,
    )
    # Shard the non-wrapped model's re-keyed optimizer state dict, which
    # maps back to (flattened) parameter IDs
    sharded_osd = FSDP.shard_full_optim_state_dict(
        rekeyed_osd, model1, optim_input1,
    )
    # Check that this sharded optimizer state dict matches the wrapped
    # model's per-rank optimizer state dict
    osd1 = optim1.state_dict()
    check_same_param_keys = True
    self._check_same_param_groups(
        sharded_osd, osd1, check_same_param_keys=check_same_param_keys,
    )
    self._check_same_state(
        sharded_osd, osd1, check_same_param_keys=check_same_param_keys,
    )
    # As a sanity check, check that we can load and run a few iterations
    optim1.load_state_dict(sharded_osd)
    self._step_model(model1, optim1, num_iters=NUM_ITERS)
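# NOTE: An illustrative sketch (not part of the test) of the user-facing flow
# the test above exercises: loading a checkpoint saved from a non-wrapped model
# into an FSDP-wrapped one. `plain_model`, `plain_optim`, `wrapped_model`, and
# `wrapped_optim` are hypothetical placeholder names, and the sketch assumes
# the optional `optim_input` argument can be omitted.
#
#   osd = plain_optim.state_dict()  # keyed by parameter ID
#   named_osd = FSDP.rekey_optim_state_dict(
#       osd, OptimStateKeyType.PARAM_NAME, plain_model,
#   )  # now keyed by parameter name, like `FSDP.full_optim_state_dict` output
#   sharded_osd = FSDP.shard_full_optim_state_dict(named_osd, wrapped_model)
#   wrapped_optim.load_state_dict(sharded_osd)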
def test_rekey_optim_state_dict_to_ids(
    self,
    use_multiple_param_groups: bool,
):
    """Tests :meth:`rekey_optim_state_dict` with the new keys being
    parameter IDs by checking that a wrapped model (i.e. with FSDP
    modules) can rekey its optimizer state dict to match that of an
    equivalent non-wrapped model (i.e. without FSDP modules)."""
    NUM_ITERS = 3
    # Run a wrapped model for a few iterations
    model1, optim1, optim_input1 = self._init_nested_model(
        wrap=True, use_multiple_param_groups=use_multiple_param_groups,
    )
    self._step_model(model1, optim1, num_iters=NUM_ITERS)
    full_osd = FSDP.full_optim_state_dict(model1, optim1, optim_input1)
    # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
    # have the full state dict
    full_osd = self._broadcast_full_osd(full_osd)
    # Run a non-wrapped model for a few iterations
    model2, optim2, optim_input2 = self._init_nested_model(
        wrap=False, use_multiple_param_groups=use_multiple_param_groups,
    )
    self._step_model(model2, optim2, num_iters=NUM_ITERS)
    # Re-key the wrapped model's optimizer state dict using parameter IDs
    # according to the non-wrapped model
    rekeyed_osd = FSDP.rekey_optim_state_dict(
        full_osd, OptimStateKeyType.PARAM_ID, model2, optim_input2,
    )
    # Check that the re-keyed dict and actual dict are the same
    osd = optim2.state_dict()
    check_same_param_keys = True
    self._check_same_param_groups(
        rekeyed_osd, osd, check_same_param_keys=check_same_param_keys,
    )
    self._check_same_state(
        rekeyed_osd, osd, check_same_param_keys=check_same_param_keys,
    )
    # As a sanity check, check that we can load and run a few iterations
    optim2.load_state_dict(rekeyed_osd)
    self._step_model(model2, optim2, num_iters=NUM_ITERS)
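# NOTE: An illustrative sketch (not part of the test) of the reverse flow the
# test above exercises: exporting an FSDP-wrapped model's full optimizer state
# so that an equivalent non-wrapped model can consume it. `wrapped_model`,
# `wrapped_optim`, `plain_model`, and `plain_optim` are hypothetical
# placeholder names, and the sketch assumes the optional `optim_input`
# argument can be omitted.
#
#   full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
#   id_keyed_osd = FSDP.rekey_optim_state_dict(
#       full_osd, OptimStateKeyType.PARAM_ID, plain_model,
#   )  # keyed by the non-wrapped model's parameter IDs
#   plain_optim.load_state_dict(id_keyed_osd)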