def _post_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    """
    Post-process ``state_dict`` after ``state_dict()`` has executed and
    before the result is returned to the user.

    Keys that start with ``prefix`` followed by the FlattenParamsWrapper
    internal module name (``FPW_MODULE``) and a dot are renamed so that they
    start with ``prefix`` alone, removing the wrapper's internal prefix.

    Args:
        module: Module argument supplied to the hook; not used here.
        state_dict: Mapping whose keys are rewritten.
        prefix: Key prefix under which this module's entries are stored.
        *args: Additional hook arguments; ignored.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    # Lift every entry that sits under FPW_MODULE up by one level.
    wrapped_prefix = f"{prefix}{FPW_MODULE}."
    _replace_by_prefix(state_dict, wrapped_prefix, prefix)
    return state_dict
def _pre_load_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    ``_pre_load_state_dict_hook`` is called before
    ``self._load_from_state_dict()`` is called.

    For ``checkpoint_wrapper``, it adds the checkpoint-wrapped module prefix
    back onto the keys so that a state dict saved from non-checkpointed
    modules can be loaded into checkpoint_wrapper modules properly.

    Args:
        module: Module argument supplied to the hook; not used here.
        state_dict: Mapping whose keys are rewritten.
        prefix: Key prefix under which this module's entries are stored.
        *args: Additional hook arguments; ignored.
    """
    # Every key under ``prefix`` is pushed one level down, beneath the
    # checkpoint wrapper's own prefix.
    checkpointed_prefix = f"{prefix}{_CHECKPOINT_PREFIX}."
    _replace_by_prefix(state_dict, prefix, checkpointed_prefix)
def test_replace_by_prefix(self):
    """
    Check that ``_replace_by_prefix`` renames only keys that *start* with the
    old prefix, and that applying the inverse rename restores the original
    mapping.
    """
    state_dict = {
        "layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "layer.b": torch.tensor(3),
    }
    # Shallow snapshot taken before any renaming, used for the round-trip check.
    snapshot = dict(state_dict)

    # Forward rename: "abc.layer.def" merely contains "layer." and must be
    # left untouched.
    _replace_by_prefix(state_dict, "layer.", "module.layer.")
    expected_after_rename = {
        "module.layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "module.layer.b": torch.tensor(3),
    }
    assert state_dict == expected_after_rename

    # Inverse rename brings back the original keys.
    _replace_by_prefix(state_dict, "module.layer.", "layer.")
    assert state_dict == snapshot
def _post_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    """
    ``_post_state_dict_hook`` is called after the ``state_dict()`` of this
    module is executed.

    For ``checkpoint_wrapper``, it strips the checkpoint-wrapped module prefix
    from the keys so that the resulting state dict can be loaded into
    non-checkpointed modules. Loading into checkpoint-wrapped modules is
    still expected to work, since the prefix is added back before the
    state dict is loaded.

    Args:
        module: Module argument supplied to the hook; not used here.
        state_dict: Mapping whose keys are rewritten.
        prefix: Key prefix under which this module's entries are stored.
        *args: Additional hook arguments; ignored.

    Returns:
        The ``state_dict`` object that was passed in.
    """
    # Drop the checkpoint wrapper level from every matching key.
    checkpointed_prefix = prefix + f"{_CHECKPOINT_PREFIX}."
    _replace_by_prefix(state_dict, checkpointed_prefix, prefix)
    return state_dict
def _pre_load_state_dict_hook(
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    Pre-process ``state_dict`` before ``_load_from_state_dict()`` is executed.

    The keys are rewritten in two steps:

    1. Every key under ``prefix`` is pushed down one level so that it sits
       beneath the FlattenParamsWrapper internal module (``FPW_MODULE``).
    2. Keys that then sit under ``FPW_MODULE`` and begin with ``FLAT_PARAM``
       are moved back up one level, directly under ``prefix``.

    Args:
        state_dict: Mapping whose keys are rewritten.
        prefix: Key prefix under which this module's entries are stored.
        *args: Additional hook arguments; ignored.
    """
    # Step 1: push everything down to the FPW_MODULE level.
    fpw_prefix = f"{prefix}{FPW_MODULE}."
    _replace_by_prefix(state_dict, prefix, fpw_prefix)

    # Step 2: the flat_param_* keys actually belong one level up.
    flat_param_key = f"{fpw_prefix}{FLAT_PARAM}"
    # Snapshot the matching keys first, because the dict is rewritten while
    # we walk over them.
    flat_keys = [key for key in state_dict if key.startswith(flat_param_key)]
    for key in flat_keys:
        leaf_name = key.rsplit(".", 1)[-1]
        assert leaf_name.startswith(
            FLAT_PARAM
        ), f"Expected key to contain flat_param, but key name is {key}"
        _replace_by_prefix(state_dict, key, prefix + leaf_name)