Example 1
def _flatten_full_optim_state_dict(
    full_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    shard_state: bool,
) -> Dict[str, Any]:
    """
    Flattens the full optimizer state dict, still keying by unflattened
    parameter names. If ``shard_state=True``, then the optimizer states of
    FSDP-managed ``FlatParameter`` s are sharded; otherwise, they are kept
    unsharded.

    Returns:
        Dict[str, Any]: The flattened optimizer state dict.
    """
    full_osd = full_optim_state_dict
    if "state" not in full_osd or "param_groups" not in full_osd:
        raise ValueError(
            "`full_optim_state_dict` must have the keys \"state\" and "
            "\"param_groups\" to be a valid optimizer state dict")
    flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
    param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model)

    # Construct the "state" part
    flat_osd_state: Dict[_OptimStateKey, Any] = {}
    full_osd_state = full_osd["state"]
    for param, unflat_param_names in param_to_unflat_param_names.items():
        if isinstance(param, FlatParameter):  # flatten FSDP parameters' states
            assert param in flat_param_to_fsdp_module, \
                "Check the `flat_param_to_fsdp_module` construction\n" \
                f"param: {param}"
            fsdp_module = flat_param_to_fsdp_module[param]
            flat_state = _flatten_optim_state(
                full_osd_state,
                unflat_param_names,
                fsdp_module,
                param,
                shard_state,
            )
            key = _OptimStateKey(tuple(unflat_param_names), True)
            flat_osd_state[key] = flat_state
        else:  # do not flatten non-FSDP parameters' states
            assert len(unflat_param_names) == 1
            unflat_param_name = unflat_param_names[0]
            if unflat_param_name not in full_osd_state:
                # The state dict may not have an entry for a parameter if it
                # was not passed into the optimizer
                continue
            key = _OptimStateKey(tuple(unflat_param_names), False)
            flat_osd_state[key] = copy.copy(full_osd_state[unflat_param_name])

    # Construct the "param_groups" part -- copy as is since it will be
    # rekeyed later according to the target rank's `optim_input`
    flat_osd_param_groups = copy.deepcopy(full_osd["param_groups"])
    return {"state": flat_osd_state, "param_groups": flat_osd_param_groups}
Example 2
def _rekey_sharded_optim_state_dict(
    sharded_osd: Dict[str, Any],
    model: torch.nn.Module,
    optim_input: Optional[Union[List[Dict[str, Any]],
                                Iterable[torch.nn.Parameter], ]] = None,
) -> Dict[str, Any]:
    """
    Rekeys the optimizer state dict from unflattened parameter names to
    flattened parameter IDs according to the calling rank's ``optim_input``,
    which may be different across ranks. In particular, the unflattened
    parameter names are represented as :class:`_OptimStateKey` s.
    """
    param_to_flat_param_id = _get_param_to_param_id(model, optim_input)
    param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model)
    # All parameter keys in `param_to_flat_param_id` should be in
    # `param_to_unflat_param_names` -- strict inequality follows when not all
    # parameters are passed to the optimizer via `optim_input`
    assert len(param_to_flat_param_id) <= len(param_to_unflat_param_names)

    unflat_param_names_to_flat_param_id: Dict[Tuple[str, ...],
                                              int] = {}  # for "state"
    unflat_param_name_to_flat_param_id: Dict[str,
                                             int] = {}  # for "param_groups"
    for param, unflat_param_names in param_to_unflat_param_names.items():
        if param not in param_to_flat_param_id:
            # This parameter was not passed to the optimizer via `optim_input`
            continue
        flat_param_id = param_to_flat_param_id[param]
        unflat_param_names_to_flat_param_id[tuple(
            unflat_param_names)] = flat_param_id
        for unflat_param_name in unflat_param_names:
            unflat_param_name_to_flat_param_id[
                unflat_param_name] = flat_param_id

    sharded_osd_state = sharded_osd["state"]
    rekeyed_osd_state = {}
    for key, param_state in sharded_osd_state.items():
        flat_param_id = unflat_param_names_to_flat_param_id[
            key.unflat_param_names]
        rekeyed_osd_state[flat_param_id] = param_state

    rekeyed_osd_param_groups: List[Dict[str, Any]] = []
    for unflat_param_group in sharded_osd["param_groups"]:
        flat_param_group = copy.deepcopy(unflat_param_group)
        flat_param_ids = sorted(
            set(unflat_param_name_to_flat_param_id[unflat_param_name]
                for unflat_param_name in unflat_param_group["params"]))
        flat_param_group["params"] = flat_param_ids
        rekeyed_osd_param_groups.append(flat_param_group)

    return {
        "state": rekeyed_osd_state,
        "param_groups": rekeyed_osd_param_groups
    }
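The rekeying is easiest to see on toy data. The sketch below is a simplified, hypothetical walk-through of the same two steps (rekeying "state" and "param_groups"); the `_OptimStateKey` stand-in, the parameter names, and the ID assignments are made up, and in the real function the ID mappings are derived from `model` and `optim_input`.

from collections import namedtuple
import copy

_OptimStateKey = namedtuple("_OptimStateKey", ["unflat_param_names", "is_flat_param"])

# Toy sharded OSD in the format produced by the flattening step
sharded_osd = {
    "state": {
        _OptimStateKey(("encoder.weight", "encoder.bias"), True): {"step": 10},
        _OptimStateKey(("head.weight",), False): {"step": 10},
    },
    "param_groups": [
        {"lr": 1e-3, "params": ["encoder.weight", "encoder.bias", "head.weight"]},
    ],
}

# Assume this rank's optimizer sees the FlatParameter as ID 0 and the plain
# parameter as ID 1 (derived from `optim_input` in the real code)
unflat_param_names_to_flat_param_id = {("encoder.weight", "encoder.bias"): 0,
                                       ("head.weight",): 1}
unflat_param_name_to_flat_param_id = {"encoder.weight": 0, "encoder.bias": 0,
                                      "head.weight": 1}

# Rekey "state" from _OptimStateKey to flat parameter ID
rekeyed_state = {unflat_param_names_to_flat_param_id[key.unflat_param_names]: state
                 for key, state in sharded_osd["state"].items()}

# Rekey each param group's "params" list the same way
rekeyed_groups = []
for group in sharded_osd["param_groups"]:
    new_group = copy.deepcopy(group)
    new_group["params"] = sorted({unflat_param_name_to_flat_param_id[name]
                                  for name in group["params"]})
    rekeyed_groups.append(new_group)

rekeyed_osd = {"state": rekeyed_state, "param_groups": rekeyed_groups}
# rekeyed_osd["state"] is now {0: {"step": 10}, 1: {"step": 10}} and
# rekeyed_osd["param_groups"][0]["params"] is [0, 1]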
Example 3
def _flatten_full_optim_state_dict(
    full_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    shard_state: bool,
    optim_input: Optional[Union[List[Dict[str, Any]],
                                Iterable[torch.nn.Parameter], ]] = None,
) -> Tuple[Dict[str, Any], Set[int]]:
    """
    Args:
        shard_state (bool): Whether to shard flattened positive-dimension
            tensor state; if ``False``, then the full flattened tensor is
            kept in the returned :class:`dict`.

    Returns:
        Tuple[Dict[str, Any], Set[int]]: The flattened optimizer state dict
            and a set of the parameter IDs corresponding to FSDP parameters.
    """
    full_osd = full_optim_state_dict  # alias
    if "state" not in full_osd or "param_groups" not in full_osd:
        raise ValueError(
            "`full_optim_state_dict` must have the keys \"state\" and "
            "\"param_groups\" to be a valid optimizer state dict")

    flat_param_id_to_param = _get_param_id_to_param(model, optim_input)
    flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
    param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model)

    # Handle the "state" part of the optimizer state dict
    flat_osd_state: Dict[int, Any] = {}
    full_osd_state = full_osd["state"]
    unflat_param_names_to_flat_param_id: Dict[str, int] = {}
    fsdp_flat_param_ids = set()  # save which IDs are for FSDP parameters
    for flat_param_id, param in enumerate(
            flat_param_id_to_param):  # type: ignore[assignment]
        assert param in param_to_unflat_param_names, \
            "Check the `param_to_unflat_params` construction\n" \
            f"param: {param}"
        unflat_param_names = param_to_unflat_param_names[param]
        # For FSDP parameters, we need to flatten
        if isinstance(param, FlatParameter):
            assert param in flat_param_to_fsdp_module, \
                "Check the `flat_param_to_fsdp_module` mapping " \
                f"construction\nparam={param}"
            fsdp_module = flat_param_to_fsdp_module[param]
            flat_state = _flatten_optim_state(
                full_osd_state,
                unflat_param_names,
                fsdp_module,
                param,
                shard_state,
            )
            flat_osd_state[flat_param_id] = flat_state
            for unflat_param_name in unflat_param_names:
                unflat_param_names_to_flat_param_id[
                    unflat_param_name] = flat_param_id
            fsdp_flat_param_ids.add(flat_param_id)
        # For parameters from non-FSDP modules, we do not need to flatten
        else:
            assert len(unflat_param_names) == 1
            unflat_param_name = unflat_param_names[0]
            if unflat_param_name not in full_osd_state:
                # A non-FSDP module's parameter may be ignored and hence not
                # have an entry in the optimizer state
                continue
            # Remap from unflattened to flattened parameter ID -- do not
            # deepcopy to avoid unnecessarily duplicating tensor storage
            flat_osd_state[flat_param_id] = \
                copy.copy(full_osd_state[unflat_param_name])
            unflat_param_names_to_flat_param_id[
                unflat_param_name] = flat_param_id

    # Handle the "param_groups" part of the optimizer state dict
    sharded_osd_param_groups: List[Dict[str, Any]] = []
    for unflat_param_group in full_osd["param_groups"]:
        flat_param_group = copy.deepcopy(unflat_param_group)
        # Map from unflattened parameter names to flattened parameter IDs
        flat_param_ids = sorted(
            set(unflat_param_names_to_flat_param_id[unflat_param_name]
                for unflat_param_name in unflat_param_group["params"]))
        flat_param_group["params"] = flat_param_ids
        sharded_osd_param_groups.append(flat_param_group)

    optim_state_dict = {
        "state": flat_osd_state,
        "param_groups": sharded_osd_param_groups,
    }
    return optim_state_dict, fsdp_flat_param_ids
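For orientation, the sketch below shows the shape of the two returned objects with made-up values: "state" is keyed by this rank's flat parameter IDs, and `fsdp_flat_param_ids` records which of those IDs belong to FSDP `FlatParameter`s. The model layout and state values are hypothetical.

# Hypothetical return values for a model with one FSDP FlatParameter (ID 0)
# and one plain, non-FSDP parameter (ID 1)
optim_state_dict = {
    "state": {
        0: {"step": 10, "exp_avg": [0.1, 0.2, 0.3]},  # flattened (possibly sharded) FSDP state
        1: {"step": 10, "exp_avg": [0.4, 0.5]},        # non-FSDP state, copied as-is
    },
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
fsdp_flat_param_ids = {0}

# A caller can use the ID set to treat FSDP-managed entries differently,
# e.g. scattering their sharded tensor state while copying plain entries
for flat_param_id, state in optim_state_dict["state"].items():
    if flat_param_id in fsdp_flat_param_ids:
        pass  # handle flattened/sharded FlatParameter state
    else:
        pass  # handle ordinary per-parameter state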