Example #1
    def test_optim_state_dict_nested(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        rank0_only: bool,
        use_diff_optim_inputs: bool,
    ) -> None:
        """
        Tests :meth:`full_optim_state_dict` and :meth:`sharded_optim_state_dict`
        by comparing the returned dict for an FSDP-wrapped model with that of
        an equivalent non-wrapped model.

        The test checks the equivalence excluding the parameter keys since the
        FSDP and normal optimizer state dicts key by names and IDs,
        respectively. This means that the test can pass even if parameter keys
        are incorrectly mapped to values. Their correct mapping is tested in
        other tests that exercise the save/load workflow.
        """
        if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT:
            return  # not supported
        NUM_ITERS = 3
        model1, optim1, optim_input = self._init_nested_model(
            wrap=True, use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(
                model1, optim1, optim_input, rank0_only=rank0_only,
            )
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(
                model1, optim1, optim_input
            )
        # Non-target ranks get an empty state dict
        if rank0_only and self.rank != 0:
            self.assertEqual(len(fsdp_osd), 0)
            return
        model2, optim2, _ = self._init_nested_model(
            wrap=False, use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
        ref_osd = optim2.state_dict()
        # Check the losses to eliminate model drift as a source of error
        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"
        # Do not check the parameter keys since the full/sharded optimizer state
        # dict uses parameter names, while the non-wrapped equivalent uses
        # parameter IDs
        check_same_param_keys = False
        self._check_same_param_groups(
            fsdp_osd, ref_osd, check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            fsdp_osd, ref_osd, check_same_param_keys=check_same_param_keys,
        )
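
The test above exercises FSDP.full_optim_state_dict with rank0_only, which consolidates the optimizer state onto rank 0 and returns an empty dict on the other ranks. Below is a minimal sketch of how that API is commonly used for checkpointing, assuming a distributed process group is already initialized and the model is FSDP-wrapped; the function name save_full_optim_checkpoint and the checkpoint path are illustrative and not part of the test.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_full_optim_checkpoint(model, optim, path):
    # Consolidate the sharded optimizer state onto rank 0; non-target ranks
    # receive an empty dict, matching the rank0_only branch tested above.
    full_osd = FSDP.full_optim_state_dict(model, optim, rank0_only=True)
    if dist.get_rank() == 0:
        # The consolidated dict is keyed by parameter names, not parameter IDs.
        torch.save(full_osd, path)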
Example #2
    def test_rekey_optim_state_dict_to_ids(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
    ):
        """Tests :meth:`rekey_optim_state_dict` with the new keys being
        parameter IDs by checking that a wrapped model (i.e. with FSDP modules)
        can rekey its optimizer state dict to match that of an equivalent
        non-wrapped model (i.e. without FSDP modules)."""
        NUM_ITERS = 3
        # Run a wrapped model for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(
            wrap=True, use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(model1, optim1, optim_input1)
            # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
            # have the full state dict
            fsdp_osd = self._broadcast_full_osd(fsdp_osd)
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1, optim_input1)
        # Run a non-wrapped model for a few iterations
        model2, optim2, optim_input2 = self._init_nested_model(
            wrap=False, use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)
        # Re-key the wrapped model's optimizer state dict using parameter IDs
        # according to the non-wrapped model
        rekeyed_osd = FSDP.rekey_optim_state_dict(
            fsdp_osd, OptimStateKeyType.PARAM_ID, model2, optim_input2,
        )
        # Check that the re-keyed dict and actual dict are the same
        osd = optim2.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            rekeyed_osd, osd, check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            rekeyed_osd, osd, check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        if state_dict_type != StateDictType.SHARDED_STATE_DICT:
            optim2.load_state_dict(rekeyed_osd)
            self._step_model(model2, optim2, num_iters=NUM_ITERS)
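
As a rough illustration of the rekeying flow tested above (not part of the test), the sketch below consolidates the FSDP optimizer state on all ranks, rekeys it from parameter names to parameter IDs, and loads it into an equivalent unwrapped model's optimizer. The names nonwrapped_model/nonwrapped_optim are placeholders, a distributed process group is assumed to be initialized, and the OptimStateKeyType import path may differ across PyTorch versions.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType

def load_into_nonwrapped(wrapped_model, wrapped_optim, nonwrapped_model, nonwrapped_optim):
    # Consolidate on every rank so each rank can rekey and load locally
    # (the test instead broadcasts the rank-0 dict).
    full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim, rank0_only=False)
    # Re-key from parameter names (FSDP's convention) to parameter IDs
    # (the convention of a plain optimizer state dict).
    rekeyed_osd = FSDP.rekey_optim_state_dict(
        full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model,
    )
    nonwrapped_optim.load_state_dict(rekeyed_osd)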
Example #3
    def test_shard_full_optim_state_dict_unmanaged_params(
        self,
        state_dict_type: StateDictType,
        add_to_fsdp_module: bool,
    ):
        """
        Tests :meth:`shard_full_optim_state_dict` when there are unmanaged
        parameters.
          - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
          added to a module to be wrapped with FSDP, in which case there should
          be an error since we require that all unflattened parameters
          comprising a flattened parameter have the same scalar state (e.g.
          Adam "step"), but the added parameter is missing its entry.
          - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
          added to a module not to be wrapped with FSDP, in which case there
          should be no error (emulating model parallel use cases where some
          parameters may be managed externally to FSDP).
        We do not separately test unmanaged parameters for
        :meth:`scatter_full_optim_state_dict` and
        :meth:`flatten_sharded_optim_state_dict` to save CI cost since they
        call into the same subroutine :meth:`_flatten_optim_state_dict`.
        """
        NUM_ITERS = 1
        # Create a normal wrapped model
        model, optim, optim_input = self._init_nested_model(wrap=True)
        self._step_model(model, optim, num_iters=NUM_ITERS)

        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(
                model, optim, optim_input, rank0_only=False,
            )  # save on all ranks to avoid having to broadcast from rank 0
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(model, optim, optim_input)
        # Create a new model with the same structure but additional unmanaged
        # parameters, representing the model for which we want to load
        device = torch.device("cuda")
        model = NestedModel().to(device)
        model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
            model, add_to_fsdp_module,
        )
        optim_input = list(model.parameters())
        if add_to_fsdp_module:
            # If we add the unmanaged parameters to a module wrapped with FSDP,
            # then the flattened parameter will be comprised of some
            # unflattened parameters with zero-dimensional tensor state (i.e.
            # Adam "step") and others without (i.e. the unmanaged parameters),
            # which should trigger an error (raised to guard correctness)
            error_prefix = "^(All unflattened parameters comprising a " \
                "single flattened parameter must have scalar state with the " \
                "same value and dtype)"
            with self.assertRaisesRegex(ValueError, error_prefix):
                if state_dict_type == StateDictType.FULL_STATE_DICT:
                    FSDP.shard_full_optim_state_dict(
                        fsdp_osd, model, optim_input,
                    )
                else:
                    FSDP.flatten_sharded_optim_state_dict(
                        fsdp_osd, model, optim_input,
                    )
        else:
            # If we add the unmanaged parameters to a module not wrapped with
            # FSDP, then we simply ignore them without erroring to enable
            # model parallelism use cases, where some parameters are managed
            # externally to FSDP
            if state_dict_type == StateDictType.FULL_STATE_DICT:
                flattened_osd = FSDP.shard_full_optim_state_dict(
                    fsdp_osd, model, optim_input,
                )
            else:
                flattened_osd = FSDP.flatten_sharded_optim_state_dict(
                    fsdp_osd, model, optim_input,
                )
            # Add entries for the unmanaged parameters to be able to load
            for unmanaged_param in unmanaged_params:
                NestedModel.add_unmanaged_param_entry(
                    flattened_osd, unmanaged_param, NUM_ITERS,
                )
            # Check that we can load the optimizer state dict
            optim = torch.optim.Adam(optim_input, lr=1e-3)
            optim.load_state_dict(flattened_osd)
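
Finally, a hedged sketch of the load path that shard_full_optim_state_dict serves outside the test: each rank loads a previously saved full optimizer state dict, shards it to match its local FSDP-managed parameters, and loads the result into its optimizer. The function name and the checkpoint path argument are placeholders, and the model is assumed to already be FSDP-wrapped.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def load_full_optim_checkpoint(model, optim, path):
    # Every rank loads the full (unsharded) optimizer state dict...
    full_osd = torch.load(path)
    # ...then flattens/shards it to match this rank's FSDP wrapping before
    # handing it to the local optimizer.
    sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model)
    optim.load_state_dict(sharded_osd)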