    def test_shard_full_optim_state_dict_unmanaged_params(
        self,
        state_dict_type: StateDictType,
        add_to_fsdp_module: bool,
    ):
        """
        Tests :meth:`shard_full_optim_state_dict` when there are unmanaged
        parameters.
          - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
          added to a module to be wrapped with FSDP, in which case there should
          be an error since we require that all unflattened parameters
          comprising a flattened parameter have the same scalar state (e.g.
          Adam "step"), but the added parameter is missing its entry.
          - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
          added to a module not to be wrapped with FSDP, in which case there
          should be no error (emulating model parallel use cases where some
          parameters may be managed externally to FSDP).
        We do not separately test unmanaged parameters for
        :meth:`scatter_full_optim_state_dict` and
        :meth:`flatten_sharded_optim_state_dict` to save CI cost since they
        call into the same subroutine :meth:`_flatten_optim_state_dict`.
        """
        NUM_ITERS = 1
        # Create a normal wrapped model
        model, optim, optim_input = self._init_nested_model(wrap=True)
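        # Step once so that the optimizer state (e.g. Adam "step" and moment
        # estimates) is populated before saving the optimizer state dict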
        self._step_model(model, optim, num_iters=NUM_ITERS)

        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(
                model, optim, optim_input, rank0_only=False,
            )  # save on all ranks to avoid having to broadcast from rank 0
        else:
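            # The sharded optimizer state dict is saved on all ranks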
            fsdp_osd = FSDP.sharded_optim_state_dict(model, optim, optim_input)
        # Create a new model with the same structure but additional unmanaged
        # parameters, representing the model into which we want to load the
        # optimizer state
        device = torch.device("cuda")
        model = NestedModel().to(device)
        model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
            model, add_to_fsdp_module,
        )
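        # Use all of the new model's parameters (including the unmanaged ones)
        # as the optimizer input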
        optim_input = list(model.parameters())
        if add_to_fsdp_module:
            # If we add the unmanaged parameters to a module wrapped with FSDP,
            # then the flattened parameter will consist of some unflattened
            # parameters with zero-dimensional tensor state (e.g. Adam "step")
            # and others without (i.e. the unmanaged parameters), which
            # triggers an error since we cannot ensure correctness of the
            # flattened state
            error_prefix = "^(All unflattened parameters comprising a " \
                "single flattened parameter must have scalar state with the " \
                "same value and dtype)"
            with self.assertRaisesRegex(ValueError, error_prefix):
                if state_dict_type == StateDictType.FULL_STATE_DICT:
                    FSDP.shard_full_optim_state_dict(
                        fsdp_osd, model, optim_input,
                    )
                else:
                    FSDP.flatten_sharded_optim_state_dict(
                        fsdp_osd, model, optim_input,
                    )
        else:
            # If we add the unmanaged parameters to a module not wrapped with
            # FSDP, then we simply ignore them without erroring to enable
            # model parallelism use cases, where some parameters are managed
            # externally to FSDP
            if state_dict_type == StateDictType.FULL_STATE_DICT:
                flattened_osd = FSDP.shard_full_optim_state_dict(
                    fsdp_osd, model, optim_input,
                )
            else:
                flattened_osd = FSDP.flatten_sharded_optim_state_dict(
                    fsdp_osd, model, optim_input,
                )
            # Add entries for the unmanaged parameters so that the optimizer
            # state dict can be loaded
            for unmanaged_param in unmanaged_params:
                NestedModel.add_unmanaged_param_entry(
                    flattened_osd, unmanaged_param, NUM_ITERS,
                )
            # Check that we can load the optimizer state dict
            optim = torch.optim.Adam(optim_input, lr=1e-3)
            optim.load_state_dict(flattened_osd)

    def _test_load_optim_state(
        self,
        model_class: _ModelClass,
        use_multiple_param_groups: bool,
        halve_world_size: bool,
        osd_comm_method: _OSDCommMethod,
        use_diff_optim_inputs: bool,
        **new_model_kwargs,
    ):
        """
        (1) Runs a model with full world size for K iterations to generate a
        full/sharded optimizer state dict;
        (2) initializes a model with (possibly) halved world size and a
        possibly different FSDP wrapping scheme (based on ``new_model_kwargs``);
        (3) loads the full/sharded optimizer state dict from (1) according to the
        halved-world-size model;
        (4) runs the halved-world-size model for K iterations; and
        (5) checks that the sharded optimizer state dict from (3) matches the
        halved-world-size model's local optimizer state dict, meaning that the
        former could have equivalently been loaded into the local optimizer.
        """
        NUM_ITERS = 3
        initializer = self._model_class[model_class]
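        # Save a sharded optimizer state dict when testing
        # `flatten_sharded_optim_state_dict()` and a full one otherwise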
        osd_method = (
            FSDP.sharded_optim_state_dict
            if osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD
            else FSDP.full_optim_state_dict
        )

        # First, run a wrapped model with full world size for a few iterations
        model1, optim1, optim_input1 = initializer(
            wrap=True, use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        fsdp_osd1 = osd_method(model1, optim1, optim_input1)
        if halve_world_size:
            # Create a new process group with halved world size
            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
            new_group = dist.new_group(ranks=new_group_ranks)
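            # Ranks excluded from the new group exit early since they do not
            # participate in the rest of the test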
            if self.rank not in new_group_ranks:
                return
        else:
            # Continue using the same group and hence world size
            new_group = dist.distributed_c10d._get_default_group()
        # Second, run a wrapped model with (possibly) halved world size and
        # (possibly) differing `optim_input` across ranks
        model2, optim2, optim_input2 = initializer(
            wrap=True, group=new_group,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)
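        # Save the second model's (full or sharded) optimizer state dict on
        # the new process group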
        fsdp_osd2 = osd_method(model2, optim2, optim_input2, group=new_group)
        # Compute two sharded optim state dicts: (1) for the first model
        # according to the second model and (2) for the second model according
        # to the second model
        if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST:
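            # Broadcast each full optimizer state dict across the new group
            # before sharding it according to the second model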
            fsdp_osd1 = self._broadcast_full_osd(fsdp_osd1, group=new_group)
            sharded_osd1 = FSDP.shard_full_optim_state_dict(
                fsdp_osd1, model2, optim_input2,
            )
            fsdp_osd2 = self._broadcast_full_osd(fsdp_osd2, group=new_group)
            sharded_osd2 = FSDP.shard_full_optim_state_dict(
                fsdp_osd2, model2, optim_input2,
            )
        elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD:
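            # Scatter each full optimizer state dict from rank 0; nonzero
            # ranks pass `None` since only rank 0 provides the full state dict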
            sharded_osd1 = FSDP.scatter_full_optim_state_dict(
                fsdp_osd1 if self.rank == 0 else None, model2, optim_input2,
                group=new_group,
            )
            sharded_osd2 = FSDP.scatter_full_optim_state_dict(
                fsdp_osd2 if self.rank == 0 else None, model2, optim_input2,
                group=new_group,
            )
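            # The scattered state should be placed on GPU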
            self._check_state_device(sharded_osd1, on_gpu=True)
            self._check_state_device(sharded_osd2, on_gpu=True)
        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
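            # Flatten each rank's sharded optimizer state dict according to
            # the second model; no broadcast of a full state dict is needed
            # since each rank already holds its own shard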
            sharded_osd1 = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd1, model2, optim_input2,
            )
            sharded_osd2 = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd2, model2, optim_input2,
            )

        # As a sanity check, verify that sharding the second model's full/sharded
        # optimizer state dict according to itself is equivalent to its local
        # optimizer's state dict
        local_osd2 = optim2.state_dict()
        check_same_param_keys = True  # should all have matching parameter IDs
        self._check_same_param_groups(
            sharded_osd2, local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd2, local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        # Check that sharding the first model's full/sharded optimizer state dict
        # according to the second model is equivalent to the second model's
        # local optimizer state dict
        self._check_same_param_groups(
            sharded_osd1, local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd1, local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        # As a final sanity check, verify that we can load the state dict and
        # run a few more iterations
        if osd_comm_method != _OSDCommMethod.FLATTEN_SHARDED_OSD:
            optim2.load_state_dict(sharded_osd1)
            self._step_model(model2, optim2, num_iters=NUM_ITERS)