Пример #1
0
 def test_rekey_optim_state_dict_to_names(
     self,
     use_multiple_param_groups: bool,
 ):
     """Check that rekeying an optimizer state dict to parameter names works.

     A non-wrapped model (no FSDP modules) rekeys its own optimizer state
     dict with ``OptimStateKeyType.PARAM_NAME``. The rekeyed dict should have
     the same form as the output of :meth:`full_optim_state_dict`. That lets
     it be passed to :meth:`shard_full_optim_state_dict` for a wrapped model
     (with FSDP modules). The sharded result must match the wrapped model's
     own per-rank optimizer state dict, and the wrapped optimizer must be
     able to load it and keep training.
     """
     num_iters = 3

     def build_and_step(wrap):
         # Build a nested model (FSDP-wrapped or not), its optimizer and
         # optimizer input, then train it for ``num_iters`` iterations.
         model, optim, optim_input = self._init_nested_model(
             wrap=wrap,
             use_multiple_param_groups=use_multiple_param_groups,
         )
         self._step_model(model, optim, num_iters=num_iters)
         return model, optim, optim_input

     # The wrapped model is built and stepped before the non-wrapped one
     fsdp_model, fsdp_optim, fsdp_optim_input = build_and_step(True)
     local_model, local_optim, local_optim_input = build_and_step(False)

     # Rekey the non-wrapped optimizer state by parameter name, still
     # resolved against the non-wrapped model itself
     name_keyed_osd = FSDP.rekey_optim_state_dict(
         local_optim.state_dict(),
         OptimStateKeyType.PARAM_NAME,
         local_model,
         local_optim_input,
     )
     # Sharding against the wrapped model maps the names back onto the
     # (flattened) parameter IDs used by the wrapped optimizer
     sharded_osd = FSDP.shard_full_optim_state_dict(
         name_keyed_osd,
         fsdp_model,
         fsdp_optim_input,
     )

     # The sharded dict must agree with the wrapped optimizer's own per-rank
     # state dict, both in its param groups and in its per-parameter state
     expected_osd = fsdp_optim.state_dict()
     for check in (self._check_same_param_groups, self._check_same_state):
         check(sharded_osd, expected_osd, check_same_param_keys=True)

     # Sanity check: the sharded dict loads and training can continue
     fsdp_optim.load_state_dict(sharded_osd)
     self._step_model(fsdp_model, fsdp_optim, num_iters=num_iters)
Пример #2
0
 def test_rekey_optim_state_dict_to_ids(
     self,
     use_multiple_param_groups: bool,
 ):
     """Check that rekeying an optimizer state dict to parameter IDs works.

     A wrapped model (with FSDP modules) produces its full optimizer state
     dict via :meth:`full_optim_state_dict`, and that dict is broadcast to
     every rank. It is then rekeyed with ``OptimStateKeyType.PARAM_ID``
     against an equivalent non-wrapped model (no FSDP modules). The result
     must match the non-wrapped optimizer's own state dict, and the
     non-wrapped optimizer must be able to load it and keep training.
     """
     num_iters = 3

     def build_and_step(wrap):
         # Build a nested model (FSDP-wrapped or not), its optimizer and
         # optimizer input, then train it for ``num_iters`` iterations.
         model, optim, optim_input = self._init_nested_model(
             wrap=wrap,
             use_multiple_param_groups=use_multiple_param_groups,
         )
         self._step_model(model, optim, num_iters=num_iters)
         return model, optim, optim_input

     fsdp_model, fsdp_optim, fsdp_optim_input = build_and_step(True)
     # Broadcast rather than round-tripping through torch.save()/torch.load()
     # so that every rank ends up holding the full state dict
     full_osd = self._broadcast_full_osd(
         FSDP.full_optim_state_dict(fsdp_model, fsdp_optim, fsdp_optim_input)
     )

     # The non-wrapped model is only built after the full dict is gathered
     local_model, local_optim, local_optim_input = build_and_step(False)

     # Rekey the wrapped model's full state dict by parameter ID, using the
     # IDs assigned by the non-wrapped model
     id_keyed_osd = FSDP.rekey_optim_state_dict(
         full_osd,
         OptimStateKeyType.PARAM_ID,
         local_model,
         local_optim_input,
     )

     # The rekeyed dict must agree with the non-wrapped optimizer's actual
     # state dict, both in its param groups and in its per-parameter state
     expected_osd = local_optim.state_dict()
     for check in (self._check_same_param_groups, self._check_same_state):
         check(id_keyed_osd, expected_osd, check_same_param_keys=True)

     # Sanity check: the rekeyed dict loads and training can continue
     local_optim.load_state_dict(id_keyed_osd)
     self._step_model(local_model, local_optim, num_iters=num_iters)