Example #1
0
def _post_state_dict_hook(module: nn.Module, state_dict: Dict[str, Any],
                          prefix: str, *args: Any) -> Dict[str, Any]:
    """
    Strip the FlattenParamsWrapper internal prefix from ``state_dict`` keys.

    Runs after ``state_dict()`` has executed and before the result is
    handed back to the caller, so users never see the wrapper-internal
    ``FPW_MODULE`` naming level.
    """
    # Hoist every key under the FPW_MODULE level up to the wrapper level.
    _replace_by_prefix(state_dict, f"{prefix}{FPW_MODULE}.", prefix)
    return state_dict
Example #2
0
 def _pre_load_state_dict_hook(
     module: nn.Module,
     state_dict: Dict[str, Any],
     prefix: str,
     *args: Any,
 ) -> None:
     """
     ``_pre_load_state_dict_hook`` is called before
     ``self._load_from_state_dict()`` is called. For ``checkpoint_wrapper``,
     it will add back the module prefix so that non-checkpointed modules can
     be loaded into checkpoint_wrapper modules properly.
     """
     # Re-insert the checkpoint-wrapper level so key names match this module.
     _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}.")
Example #3
0
 def test_replace_by_prefix(self):
     """Round-trip check: applying a prefix rename and its inverse restores the dict."""
     sd = {
         "layer.a": torch.tensor(1),
         "abc.layer.def": torch.tensor(2),
         "layer.b": torch.tensor(3),
     }
     snapshot = dict(sd)
     # Forward rename: only keys starting with "layer." are touched.
     _replace_by_prefix(sd, "layer.", "module.layer.")
     expected = {
         "module.layer.a": torch.tensor(1),
         "abc.layer.def": torch.tensor(2),
         "module.layer.b": torch.tensor(3),
     }
     assert sd == expected
     # Reverse rename must restore the original mapping exactly.
     _replace_by_prefix(sd, "module.layer.", "layer.")
     assert sd == snapshot
Example #4
0
 def _post_state_dict_hook(
     module: nn.Module,
     state_dict: Dict[str, Any],
     prefix: str,
     *args: Any,
 ) -> Dict[str, Any]:
     """
     Drop the checkpoint-wrapper prefix from ``state_dict`` keys.

     Invoked after ``state_dict()`` of this FSDP module has executed.
     Stripping ``_CHECKPOINT_PREFIX`` lets the result load into
     non-checkpointed modules; loading back into checkpoint-wrapped
     modules still works because the matching pre-load hook adds the
     prefix back before loading.
     """
     _replace_by_prefix(state_dict, prefix + f"{_CHECKPOINT_PREFIX}.", prefix)
     return state_dict
Example #5
0
def _pre_load_state_dict_hook(
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    Re-insert the FlattenParamsWrapper internal prefix before loading.

    Runs ahead of ``_load_from_state_dict()``: every key is pushed down
    under the ``FPW_MODULE`` level, except the ``flat_param_*`` entries,
    which belong one level above and are moved back up afterwards.
    """
    # Push every key down to the FPW_MODULE level.
    _replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
    # The flat_param_* keys need to live one level up; move them back.
    flat_prefix = f"{prefix}{FPW_MODULE}.{FLAT_PARAM}"
    for k in [key for key in state_dict if key.startswith(flat_prefix)]:
        last_part = k.split(".")[-1]
        assert last_part.startswith(
            FLAT_PARAM
        ), f"Expected key to contain flat_param, but key name is {k}"
        _replace_by_prefix(state_dict, k, prefix + last_part)