def _flatten_full_optim_state_dict(
    full_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    shard_state: bool,
) -> Dict[str, Any]:
    """
    Flattens the full optimizer state dict, still keying by unflattened
    parameter names. If ``shard_state=True``, then FSDP-managed
    ``FlatParameter`` 's optimizer states are sharded, and otherwise, they are
    kept unsharded.

    Returns:
        Dict[str, Any]: The flattened optimizer state dict.
    """
    full_osd = full_optim_state_dict
    if "state" not in full_osd or "param_groups" not in full_osd:
        raise ValueError(
            "`full_optim_state_dict` must have the keys \"state\" and "
            "\"param_groups\" to be a valid optimizer state dict")
    flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
    param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model)

    # Construct the "state" part
    flat_osd_state: Dict[_OptimStateKey, Any] = {}
    full_osd_state = full_osd["state"]
    for param, unflat_param_names in param_to_unflat_param_names.items():
        if isinstance(param, FlatParameter):  # flatten FSDP parameters' states
            assert param in flat_param_to_fsdp_module, \
                "Check the `flat_param_to_fsdp_module` construction\n" \
                f"param: {param}"
            fsdp_module = flat_param_to_fsdp_module[param]
            flat_state = _flatten_optim_state(
                full_osd_state,
                unflat_param_names,
                fsdp_module,
                param,
                shard_state,
            )
            key = _OptimStateKey(tuple(unflat_param_names), True)
            flat_osd_state[key] = flat_state
        else:  # do not flatten non-FSDP parameters' states
            assert len(unflat_param_names) == 1
            unflat_param_name = unflat_param_names[0]
            if unflat_param_name not in full_osd_state:
                # The state dict may not have an entry for a parameter if it
                # was not passed into the optimizer (e.g. if it is not an
                # FSDP-managed parameter)
                continue
            key = _OptimStateKey(tuple(unflat_param_names), False)
            flat_osd_state[key] = copy.copy(full_osd_state[unflat_param_name])

    # Construct the "param_groups" part -- copy as is since it will be
    # rekeyed later according to the target rank's `optim_input`
    flat_osd_param_groups = copy.deepcopy(full_osd["param_groups"])

    return {"state": flat_osd_state, "param_groups": flat_osd_param_groups}
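
# Illustrative sketch (not from the original source): assuming a model whose
# parameters "lin.weight" and "lin.bias" were flattened into a single FSDP
# ``FlatParameter`` and an Adam optimizer, the "state" part of the input and
# output of ``_flatten_full_optim_state_dict`` might look like:
#
#   full_osd["state"] = {
#       "lin.weight": {"step": ..., "exp_avg": ..., "exp_avg_sq": ...},
#       "lin.bias": {"step": ..., "exp_avg": ..., "exp_avg_sq": ...},
#   }
#   flat_osd["state"] = {
#       _OptimStateKey(("lin.weight", "lin.bias"), True): {
#           # flattened (and sharded if ``shard_state=True``) tensor state
#           "step": ..., "exp_avg": ..., "exp_avg_sq": ...,
#       },
#   }
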
def _rekey_sharded_optim_state_dict(
    sharded_osd: Dict[str, Any],
    model: torch.nn.Module,
    optim_input: Optional[Union[
        List[Dict[str, Any]], Iterable[torch.nn.Parameter],
    ]] = None,
) -> Dict[str, Any]:
    """
    Rekeys the optimizer state dict from unflattened parameter names to
    flattened parameter IDs according to the calling rank's ``optim_input``,
    which may be different across ranks. In particular, the unflattened
    parameter names are represented as :class:`_OptimStateKey` s.
    """
    param_to_flat_param_id = _get_param_to_param_id(model, optim_input)
    param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model)
    # All parameter keys in `param_to_flat_param_id` should be in
    # `param_to_unflat_param_names` -- strict inequality follows when not all
    # parameters are passed to the optimizer via `optim_input`
    assert len(param_to_flat_param_id) <= len(param_to_unflat_param_names)

    unflat_param_names_to_flat_param_id: Dict[Tuple[str, ...], int] = {}  # for "state"
    unflat_param_name_to_flat_param_id: Dict[str, int] = {}  # for "param_groups"
    for param, unflat_param_names in param_to_unflat_param_names.items():
        if param not in param_to_flat_param_id:
            # This parameter was not passed to the optimizer via `optim_input`
            continue
        flat_param_id = param_to_flat_param_id[param]
        unflat_param_names_to_flat_param_id[tuple(unflat_param_names)] = \
            flat_param_id
        for unflat_param_name in unflat_param_names:
            unflat_param_name_to_flat_param_id[unflat_param_name] = \
                flat_param_id

    sharded_osd_state = sharded_osd["state"]
    rekeyed_osd_state = {}
    for key, param_state in sharded_osd_state.items():
        flat_param_id = unflat_param_names_to_flat_param_id[key.unflat_param_names]
        rekeyed_osd_state[flat_param_id] = param_state

    rekeyed_osd_param_groups: List[Dict[str, Any]] = []
    for unflat_param_group in sharded_osd["param_groups"]:
        flat_param_group = copy.deepcopy(unflat_param_group)
        flat_param_ids = sorted(set(
            unflat_param_name_to_flat_param_id[unflat_param_name]
            for unflat_param_name in unflat_param_group["params"]
        ))
        flat_param_group["params"] = flat_param_ids
        rekeyed_osd_param_groups.append(flat_param_group)

    return {
        "state": rekeyed_osd_state,
        "param_groups": rekeyed_osd_param_groups,
    }
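
# Illustrative sketch (an assumption about usage, not code from this file):
# the two helpers above can be composed to remap a saved full optimizer state
# dict to the flattened parameter IDs expected by the calling rank's optimizer
# before loading. ``full_osd``, ``model``, ``optim_input``, and ``optim`` are
# placeholders for the caller's objects:
#
#   flat_osd = _flatten_full_optim_state_dict(full_osd, model, shard_state=True)
#   local_osd = _rekey_sharded_optim_state_dict(flat_osd, model, optim_input)
#   optim.load_state_dict(local_osd)
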
def _flatten_full_optim_state_dict(
    full_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    shard_state: bool,
    optim_input: Optional[Union[
        List[Dict[str, Any]], Iterable[torch.nn.Parameter],
    ]] = None,
) -> Tuple[Dict[str, Any], Set[int]]:
    """
    Args:
        shard_state (bool): Whether to shard flattened positive-dimension
            tensor state; if ``False``, then the full flattened tensor is
            kept in the returned :class:`dict`.

    Returns:
        Tuple[Dict[str, Any], Set[int]]: The flattened optimizer state dict
        and a set of the parameter IDs corresponding to FSDP parameters.
    """
    full_osd = full_optim_state_dict  # alias
    if "state" not in full_osd or "param_groups" not in full_osd:
        raise ValueError(
            "`full_optim_state_dict` must have the keys \"state\" and "
            "\"param_groups\" to be a valid optimizer state dict")
    flat_param_id_to_param = _get_param_id_to_param(model, optim_input)
    flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
    param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model)

    # Handle the "state" part of the optimizer state dict
    flat_osd_state: Dict[int, Any] = {}
    full_osd_state = full_osd["state"]
    unflat_param_names_to_flat_param_id: Dict[str, int] = {}
    fsdp_flat_param_ids = set()  # save which IDs are for FSDP parameters
    for flat_param_id, param in enumerate(flat_param_id_to_param):  # type: ignore[assignment]
        assert param in param_to_unflat_param_names, \
            "Check the `param_to_unflat_param_names` construction\n" \
            f"param: {param}"
        unflat_param_names = param_to_unflat_param_names[param]
        # For FSDP parameters, we need to flatten
        if isinstance(param, FlatParameter):
            assert param in flat_param_to_fsdp_module, \
                "Check the `flat_param_to_fsdp_module` mapping " \
                f"construction\nparam={param}"
            fsdp_module = flat_param_to_fsdp_module[param]
            flat_state = _flatten_optim_state(
                full_osd_state,
                unflat_param_names,
                fsdp_module,
                param,
                shard_state,
            )
            flat_osd_state[flat_param_id] = flat_state
            for unflat_param_name in unflat_param_names:
                unflat_param_names_to_flat_param_id[unflat_param_name] = \
                    flat_param_id
            fsdp_flat_param_ids.add(flat_param_id)
        # For parameters from non-FSDP modules, we do not need to flatten
        else:
            assert len(unflat_param_names) == 1
            unflat_param_name = unflat_param_names[0]
            if unflat_param_name not in full_osd_state:
                # A non-FSDP module's parameter may be ignored and hence not
                # have an entry in the optimizer state
                continue
            # Remap from unflattened to flattened parameter ID -- do not
            # deepcopy to avoid unnecessarily duplicating tensor storage
            flat_osd_state[flat_param_id] = \
                copy.copy(full_osd_state[unflat_param_name])
            unflat_param_names_to_flat_param_id[unflat_param_name] = \
                flat_param_id

    # Handle the "param_groups" part of the optimizer state dict
    sharded_osd_param_groups: List[Dict[str, Any]] = []
    for unflat_param_group in full_osd["param_groups"]:
        flat_param_group = copy.deepcopy(unflat_param_group)
        # Map from unflattened parameter names to flattened parameter IDs
        flat_param_ids = sorted(set(
            unflat_param_names_to_flat_param_id[unflat_param_name]
            for unflat_param_name in unflat_param_group["params"]
        ))
        flat_param_group["params"] = flat_param_ids
        sharded_osd_param_groups.append(flat_param_group)

    optim_state_dict = {
        "state": flat_osd_state,
        "param_groups": sharded_osd_param_groups,
    }
    return optim_state_dict, fsdp_flat_param_ids
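
# Illustrative sketch (not from the original source): for a model with one
# FSDP ``FlatParameter`` assigned flat parameter ID 0 and one non-FSDP
# parameter assigned flat parameter ID 1, the return value might look like:
#
#   optim_state_dict = {
#       "state": {
#           0: {...},  # flattened (and possibly sharded) FlatParameter state
#           1: {...},  # non-FSDP parameter state, shallow-copied as is
#       },
#       "param_groups": [{"lr": ..., "params": [0, 1], ...}],
#   }
#   fsdp_flat_param_ids = {0}
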