Example #1
 def test_full_optim_state_dict_keys(self):
     """Tests that the parameter keys returned by
     :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
     full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
     instances and ignored modules."""
     device = torch.device("cuda")
     model = NestedModel().to(device)
     wrapped_model = NestedModel.wrap(model, ignore_modules=True)
     # Add checkpointing to ensure optim_state_dict and state_dict strip out
     # checkpointing prefixes.
     apply_activation_checkpointing_wrapper(
         model,
         check_fn=lambda module: isinstance(module, torch.nn.Sequential))
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._step_model(model, optim, device)
     optim_state_dict = FSDP.full_optim_state_dict(wrapped_model,
                                                   optim,
                                                   rank0_only=False)
     with FSDP.state_dict_type(wrapped_model,
                               StateDictType.FULL_STATE_DICT):
         state_dict = wrapped_model.state_dict()
     self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
     # Check that checkpointing prefix was indeed stripped.
     for key in optim_state_dict["state"]:
         self.assertNotIn(_CHECKPOINT_PREFIX, key)
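
The key-matching property exercised above is what makes a combined model/optimizer checkpoint straightforward. Below is a minimal save-side sketch (not part of the test suite), assuming a generic FSDP-wrapped `model`, its optimizer `optim`, and a hypothetical `save_checkpoint` helper:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

def save_checkpoint(model, optim, path="checkpoint.pt"):
    # Hypothetical helper, not from the test above: gather the full optimizer
    # state dict (keyed by unflattened parameter names) and the full model
    # state dict, whose keys match by the property tested in the example.
    optim_state = FSDP.full_optim_state_dict(model, optim)  # full on rank 0 by default
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model_state = model.state_dict()
    if dist.get_rank() == 0:
        torch.save({"model": model_state, "optim": optim_state}, path)
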
Example #2
    def test_full_optim_state_dict_nested(
        self,
        use_multiple_param_groups: bool,
        rank0_only: bool,
        use_diff_optim_inputs: bool,
    ) -> None:
        """
        Tests :meth:`full_optim_state_dict` by comparing the returned dict for
        an FSDP-wrapped model with that of an equivalent non-wrapped model.

        The test checks the equivalence excluding the parameter keys since the
        FSDP optimizer state dict is keyed by parameter names while the normal
        optimizer state dict is keyed by parameter IDs. This means that the
        test can pass even if parameter keys
        are incorrectly mapped to values. Their correct mapping is tested in
        other tests that exercise the save/load workflow.
        """
        NUM_ITERS = 3
        model1, optim1, optim_input = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
        full_osd = FSDP.full_optim_state_dict(
            model1,
            optim1,
            optim_input,
            rank0_only=rank0_only,
        )
        # Non-target ranks get an empty state dict
        if rank0_only and self.rank != 0:
            self.assertEqual(len(full_osd), 0)
            return
        model2, optim2, _ = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
        ref_osd = optim2.state_dict()
        # Check the losses to eliminate model drift as a source of error
        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"
        # Do not check the parameter keys since the full optimizer state dict
        # uses parameter names, while the non-wrapped equivalent uses parameter
        # IDs
        check_same_param_keys = False
        self._check_same_param_groups(
            full_osd,
            ref_osd,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            full_osd,
            ref_osd,
            check_same_param_keys=check_same_param_keys,
        )
Example #3
    def test_full_optim_state_dict_nested(
        self,
        use_multiple_param_groups: bool,
        rank0_only: bool,
    ) -> None:
        """
        Tests :meth:`full_optim_state_dict` by comparing the returned dict for
        an FSDP-wrapped model with that of an equivalent non-wrapped model.

        The parameter groups in the "param_groups" part and the values in the
        "state" part should be the same, but the parameter keys may be
        different (e.g. the full optimizer state dict uses parameter names
        while the non-wrapped equivalent uses parameter IDs).
        """
        NUM_ITERS = 3
        model1, optim1, optim_input = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
        full_osd = FSDP.full_optim_state_dict(
            model1,
            optim1,
            optim_input,
            rank0_only=rank0_only,
        )
        # Non-target ranks get an empty state dict
        if rank0_only and self.rank != 0:
            self.assertEqual(len(full_osd), 0)
            return
        model2, optim2, _ = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
        ref_osd = optim2.state_dict()
        # Check the losses to eliminate model drift as a source of error
        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"
        # Do not check the parameter keys since the full optimizer state dict
        # uses parameter names, while the non-wrapped equivalent uses parameter
        # IDs
        check_same_param_keys = False
        self._check_same_param_groups(
            full_osd,
            ref_osd,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            full_osd,
            ref_osd,
            check_same_param_keys=check_same_param_keys,
        )
Example #4
 def test_full_optim_state_dict_nested_invalid(self):
     """Tests that :meth:`full_optim_state_dict` raises an error when
     nonzero ranks are missing the optimizer state for parameters on rank
     0."""
     device = torch.device("cuda")
     model = NestedModel.wrap(NestedModel().to(device), None)
     optim_input = list(model.parameters())
     if self.rank != 0:
         # Exclude a parameter so that nonzero ranks are missing state
         optim_input = optim_input[:-1]
     optim = torch.optim.Adam(optim_input, lr=1e-3)
     self._step_model(model, optim, num_iters=3)
     error_regex = (
         "FSDP currently requires each rank to have at least the "
         "optimizer states needed by rank 0's optimizer but some ranks "
         "are missing some of those states"
     )
     with self.assertRaisesRegex(RuntimeError, error_regex):
         FSDP.full_optim_state_dict(
             model, optim, optim_input,
         )
Example #5
 def test_rekey_optim_state_dict_to_ids(
     self,
     use_multiple_param_groups: bool,
 ):
     """Tests :meth:`rekey_optim_state_dict` with the new keys being
     parameter IDs by checking that a wrapped model (i.e. with FSDP modules)
     can rekey its optimizer state dict to match that of an equivalent
     non-wrapped model (i.e. without FSDP modules)."""
     NUM_ITERS = 3
     # Run a wrapped model for a few iterations
     model1, optim1, optim_input1 = self._init_nested_model(
         wrap=True,
         use_multiple_param_groups=use_multiple_param_groups,
     )
     self._step_model(model1, optim1, num_iters=NUM_ITERS)
     full_osd = FSDP.full_optim_state_dict(model1, optim1, optim_input1)
     # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
     # have the full state dict
     full_osd = self._broadcast_full_osd(full_osd)
     # Run a non-wrapped model for a few iterations
     model2, optim2, optim_input2 = self._init_nested_model(
         wrap=False,
         use_multiple_param_groups=use_multiple_param_groups,
     )
     self._step_model(model2, optim2, num_iters=NUM_ITERS)
     # Re-key the wrapped model's optimizer state dict using parameter IDs
     # according to the non-wrapped model
     rekeyed_osd = FSDP.rekey_optim_state_dict(
         full_osd,
         OptimStateKeyType.PARAM_ID,
         model2,
         optim_input2,
     )
     # Check that the re-keyed dict and actual dict are the same
     osd = optim2.state_dict()
     check_same_param_keys = True
     self._check_same_param_groups(
         rekeyed_osd,
         osd,
         check_same_param_keys=check_same_param_keys,
     )
     self._check_same_state(
         rekeyed_osd,
         osd,
         check_same_param_keys=check_same_param_keys,
     )
     # As a sanity check, check that we can load and run a few iterations
     optim2.load_state_dict(rekeyed_osd)
     self._step_model(model2, optim2, num_iters=NUM_ITERS)
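
The test above rekeys in the FSDP-to-non-wrapped direction. As a hedged sketch of the reverse direction (loading an optimizer state dict saved from a non-wrapped model into an FSDP-wrapped one), one can rekey to parameter names and then shard; `load_into_wrapped` and its arguments are illustrative, not part of the test suite:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType

def load_into_wrapped(wrapped_model, wrapped_optim, nonwrapped_osd):
    # Illustrative sketch: `nonwrapped_osd` is an optimizer state dict keyed by
    # parameter IDs (as saved from a non-wrapped model). Re-key it to parameter
    # names so it has the form of a full optimizer state dict, then shard it
    # for the FSDP-wrapped model before loading.
    named_osd = FSDP.rekey_optim_state_dict(
        nonwrapped_osd, OptimStateKeyType.PARAM_NAME, wrapped_model
    )
    sharded_osd = FSDP.shard_full_optim_state_dict(named_osd, wrapped_model)
    wrapped_optim.load_state_dict(sharded_osd)
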
Example #6
 def test_full_optim_state_dict_keys(self):
     """Tests that the parameter keys returned by
     :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
     full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
     instances and ignored modules."""
     device = torch.device("cuda")
     model = NestedModel().to(device)
     wrapped_model = NestedModel.wrap(model, ignore_modules=True)
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._step_model(model, optim, device)
     optim_state_dict = FSDP.full_optim_state_dict(wrapped_model,
                                                   optim,
                                                   rank0_only=False)
     with FSDP.state_dict_type(wrapped_model,
                               StateDictType.FULL_STATE_DICT):
         state_dict = wrapped_model.state_dict()
     self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
Example #7
 def _test_shard_full_optim_state(
     self,
     model_class: str,
     use_multiple_param_groups: bool,
     halve_world_size: bool,
     **new_model_kwargs,
 ):
     """
     (1) Runs a model with full world size for K iterations to generate a
     full optimizer state dict;
     (2) initializes a model with halved world size and possibly different
     FSDP wrapping scheme (based on ``new_model_kwargs``);
     (3) shards the full optimizer state dict from (1) according to the
     halved-world-size model;
     (4) runs the halved-world-size model for K iterations; and
     (5) checks that the sharded optimizer state dict from (3) matches the
     halved-world-size model's local optimizer state dict, meaning that the
     former could have equivalently been loaded into the local optimizer.
     """
     NUM_ITERS = 3
     initializer = self._init_nested_model if model_class == "nested" \
         else self._init_transformer_model if model_class == "transformer" \
         else None
     assert initializer is not None, f"Unsupported model: {model_class}"
     # Run a wrapped model with full world size for a few iterations
     model1, optim1, optim_input1 = initializer(
         wrap=True,
         use_multiple_param_groups=use_multiple_param_groups,
     )
     self._step_model(model1, optim1, num_iters=NUM_ITERS)
     full_osd1 = FSDP.full_optim_state_dict(model1, optim1, optim_input1)
     # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
     # have the full state dict
     full_osd1 = self._broadcast_full_osd(full_osd1)
     if halve_world_size:
         # Create a new process group with halved world size
         new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
         new_group = dist.new_group(ranks=new_group_ranks)
         if self.rank not in new_group_ranks:
             return
     else:
         new_group = dist.distributed_c10d._get_default_group()
     # Run a wrapped model with halved world size (from scratch)
     model2, optim2, optim_input2 = initializer(
         wrap=True,
         group=new_group,
         use_multiple_param_groups=use_multiple_param_groups,
         **new_model_kwargs,  # specify `wrap_alt` to change wrapping
     )
     self._step_model(model2, optim2, num_iters=NUM_ITERS)
     full_osd2 = FSDP.full_optim_state_dict(model2, optim2, optim_input2)
     full_osd2 = self._broadcast_full_osd(full_osd2, group=new_group)
     # As a sanity check, check that sharding the halved-world-size model's
     # full optimizer state dict according to itself is equivalent to its
     # local optimizer's state dict
     local_osd2 = optim2.state_dict()
     sharded_osd2 = FSDP.shard_full_optim_state_dict(
         full_osd2,
         model2,
         optim_input2,
     )
     check_same_param_keys = True  # should all have matching parameter IDs
     self._check_same_param_groups(
         sharded_osd2,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     self._check_same_state(
         sharded_osd2,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     # Check that sharding the full-world-size model's full optimizer state
     # dict according to the halved-world-size model is equivalent to the
     # halved-world-size model's local optimizer state dict
     sharded_osd1 = FSDP.shard_full_optim_state_dict(
         full_osd1,
         model2,
         optim_input2,
     )
     self._check_same_param_groups(
         sharded_osd1,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     self._check_same_state(
         sharded_osd1,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     # As a sanity check, check that we can load and run a few iterations
     optim2.load_state_dict(sharded_osd1)
     self._step_model(model2, optim2, num_iters=NUM_ITERS)
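
Outside of a test, the analogous load path typically goes through `torch.save()`/`torch.load()` rather than a broadcast. A minimal sketch, assuming a checkpoint dict with an "optim" entry (as in the save sketch after Example #1) and a freshly constructed FSDP-wrapped `new_model`/`new_optim`, possibly under a different world size or wrapping scheme:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def load_full_optim_state(new_model, new_optim, path="checkpoint.pt"):
    # Hypothetical helper: every rank loads the full (unsharded) optimizer
    # state dict, then shards it according to the new model's FSDP wrapping
    # and world size before handing it to the local optimizer.
    full_osd = torch.load(path, map_location="cpu")["optim"]
    sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
    new_optim.load_state_dict(sharded_osd)
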
Example #8
 def test_shard_full_optim_state_dict_unmanaged_params(
     self,
     add_to_fsdp_module: bool,
 ):
     """
     Tests :meth:`shard_full_optim_state_dict` when there are unmanaged
     parameters.
       - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
       added to a module to be wrapped with FSDP, in which case there should
        be an error since we require that all unflattened parameters
        comprising a flattened parameter have the same scalar state (e.g.
        Adam "step"), but the added parameter is missing its entry.
       - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
       added to a module not to be wrapped with FSDP, in which case there
       should be no error (emulating model parallel use cases where some
       parameters may be managed externally to FSDP).
     We do not separately test unmanaged parameters for
     :meth:`scatter_full_optim_state_dict` to save CI cost since it calls
     into the same subroutine :meth:`_flatten_full_optim_state_dict`.
     """
     NUM_ITERS = 1
     # Create a normal wrapped model
     model, optim, optim_input = self._init_nested_model(wrap=True)
     self._step_model(model, optim, num_iters=NUM_ITERS)
     full_osd = FSDP.full_optim_state_dict(
         model,
         optim,
         optim_input,
         rank0_only=False,
     )  # save on all ranks to avoid having to broadcast from rank 0
     # Create a new model with the same structure but additional unmanaged
     # parameters, representing the model for which we want to load
     device = torch.device("cuda")
     model = NestedModel().to(device)
     model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
         model,
         add_to_fsdp_module,
     )
     optim_input = list(model.parameters())
     if add_to_fsdp_module:
         # If we add the unmanaged parameters to a module wrapped with FSDP,
         # then the flattened parameter will be comprised of some
         # unflattened parameters with zero-dimensional tensor state (i.e.
         # Adam "step") and others without (i.e. the unmanaged parameters),
          # which should trigger an error to ensure correctness
         error_prefix = "^(All unflattened parameters comprising a " \
             "single flattened parameter must have scalar state with the " \
             "same value and dtype)"
         with self.assertRaisesRegex(ValueError, error_prefix):
             FSDP.shard_full_optim_state_dict(
                 full_osd,
                 model,
                 optim_input,
             )
     else:
         # If we add the unmanaged parameters to a module not wrapped with
         # FSDP, then we simply ignore them without erroring to enable
         # model parallelism use cases, where some parameters are managed
         # externally to FSDP
         sharded_osd = FSDP.shard_full_optim_state_dict(
             full_osd,
             model,
             optim_input,
         )
         # Add entries for the unmanaged parameters to be able to load
         for unmanaged_param in unmanaged_params:
             NestedModel.add_unmanaged_param_entry(
                 sharded_osd,
                 unmanaged_param,
                 NUM_ITERS,
             )
         # Check that we can load the optimizer state dict
         optim = torch.optim.Adam(optim_input, lr=1e-3)
         optim.load_state_dict(sharded_osd)
Example #9
 def _test_shard_full_optim_state(
     self,
     model_class: str,
     use_multiple_param_groups: bool,
     halve_world_size: bool,
     osd_comm_method: _OSDCommMethod,
     **new_model_kwargs,
 ):
     """
     (1) Runs a model with full world size for K iterations to generate a
     full optimizer state dict;
     (2) initializes a model with halved world size and possibly different
     FSDP wrapping scheme (based on ``new_model_kwargs``);
     (3) shards the full optimizer state dict from (1) according to the
     halved-world-size model;
     (4) runs the halved-world-size model for K iterations; and
     (5) checks that the sharded optimizer state dict from (3) matches the
     halved-world-size model's local optimizer state dict, meaning that the
     former could have equivalently been loaded into the local optimizer.
     """
     NUM_ITERS = 3
     initializer = self._init_nested_model if model_class == "nested" \
         else self._init_transformer_model if model_class == "transformer" \
         else None
     assert initializer is not None, f"Unsupported model: {model_class}"
     # First, run a wrapped model with full world size for a few iterations
     model1, optim1, optim_input1 = initializer(
         wrap=True,
         use_multiple_param_groups=use_multiple_param_groups,
     )
     self._step_model(model1, optim1, num_iters=NUM_ITERS)
     full_osd1 = FSDP.full_optim_state_dict(model1, optim1, optim_input1)
     if halve_world_size:
         # Create a new process group with halved world size
         new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
         new_group = dist.new_group(ranks=new_group_ranks)
         if self.rank not in new_group_ranks:
             return
     else:
         # Continue using the same group and hence world size
         new_group = dist.distributed_c10d._get_default_group()
     # Second, run a wrapped model with (possibly) halved world size
     model2, optim2, optim_input2 = initializer(
         wrap=True,
         group=new_group,
         use_multiple_param_groups=use_multiple_param_groups,
         **new_model_kwargs,  # specify `wrap_alt` to change wrapping
     )
     self._step_model(model2, optim2, num_iters=NUM_ITERS)
     full_osd2 = FSDP.full_optim_state_dict(model2, optim2, optim_input2)
     # Compute two sharded optim state dicts: (1) for the first model
     # according to the second model and (2) for the second model according
     # to the second model
     if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST:
         full_osd1 = self._broadcast_full_osd(full_osd1, group=new_group)
         sharded_osd1 = FSDP.shard_full_optim_state_dict(
             full_osd1,
             model2,
             optim_input2,
         )
         full_osd2 = self._broadcast_full_osd(full_osd2, group=new_group)
         sharded_osd2 = FSDP.shard_full_optim_state_dict(
             full_osd2,
             model2,
             optim_input2,
         )
     elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD:
         sharded_osd1 = FSDP.scatter_full_optim_state_dict(
             full_osd1 if self.rank == 0 else None,
             model2,
             optim_input2,
             group=new_group,
         )
         sharded_osd2 = FSDP.scatter_full_optim_state_dict(
             full_osd2 if self.rank == 0 else None,
             model2,
             optim_input2,
             group=new_group,
         )
         self._check_state_device(sharded_osd1, on_gpu=True)
         self._check_state_device(sharded_osd2, on_gpu=True)
     # As a sanity check, check that sharding the second model's full
     # optimizer state dict according to itself is equivalent to its local
     # optimizer's state dict
     local_osd2 = optim2.state_dict()
     check_same_param_keys = True  # should all have matching parameter IDs
     self._check_same_param_groups(
         sharded_osd2,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     self._check_same_state(
         sharded_osd2,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     # Check that sharding the first model's full optimizer state dict
     # according to the second model is equivalent to the second model's
     # local optimizer state dict
     self._check_same_param_groups(
         sharded_osd1,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     self._check_same_state(
         sharded_osd1,
         local_osd2,
         check_same_param_keys=check_same_param_keys,
     )
     # As a sanity check, check that we can load and run a few iterations
     optim2.load_state_dict(sharded_osd1)
     self._step_model(model2, optim2, num_iters=NUM_ITERS)
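
Regarding the two communication methods above: `shard_full_optim_state_dict` expects each rank to already hold the full state dict (hence the broadcast), while `scatter_full_optim_state_dict` only needs it on rank 0 and scatters each rank's shard. A hedged sketch of the scatter-based load, where `load_scattered` and the checkpoint layout are illustrative:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def load_scattered(new_model, new_optim, path="checkpoint.pt"):
    # Illustrative sketch: only rank 0 materializes the full optimizer state
    # dict in memory; nonzero ranks pass None and receive only their shard.
    full_osd = None
    if dist.get_rank() == 0:
        full_osd = torch.load(path, map_location="cpu")["optim"]
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model)
    new_optim.load_state_dict(sharded_osd)
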