def _pre_load_state_dict_hook(state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any) -> None:
    """Rewrite incoming state-dict keys so they match the wrapped module layout.

    All keys under ``prefix`` are pushed down into the ``_fpw_module`` namespace;
    the single ``flat_param`` entry, however, lives on the wrapper itself and is
    lifted back up one level.
    """
    nested_prefix = prefix + "_fpw_module."
    replace_by_prefix_(state_dict, prefix, nested_prefix)
    # flat_param belongs to the wrapper, not the inner module — undo the push-down for it.
    nested_flat_key = nested_prefix + "flat_param"
    if nested_flat_key in state_dict:
        replace_by_prefix_(state_dict, nested_flat_key, prefix + "flat_param")
def _pre_load_state_dict_hook(state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any) -> None:
    """Rewrite incoming state-dict keys so they match the wrapped module layout.

    Every key under ``prefix`` is pushed down into the ``_fpw_module`` namespace;
    each ``flat_param_*`` entry, however, belongs to the wrapper itself and is
    lifted back up one level.
    """
    nested_prefix = prefix + "_fpw_module."
    replace_by_prefix_(state_dict, prefix, nested_prefix)
    # Snapshot the matching keys first, since replace_by_prefix_ mutates the dict.
    flat_prefix = nested_prefix + "flat_param"
    matching_keys = [key for key in list(state_dict.keys()) if key.startswith(flat_prefix)]
    for key in matching_keys:
        suffix = key.rsplit(".", 1)[-1]
        assert suffix.startswith("flat_param_"), suffix
        replace_by_prefix_(state_dict, key, prefix + suffix)
def test_replace_by_prefix():
    """replace_by_prefix_ renames only keys starting with the prefix, in place."""
    state_dict = {
        "layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "layer.b": torch.tensor(3),
    }
    # "abc.layer.def" merely contains "layer." — it does not start with it,
    # so it must survive untouched.
    expected = {
        "module.layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "module.layer.b": torch.tensor(3),
    }
    replace_by_prefix_(state_dict, "layer.", "module.layer.")
    assert state_dict == expected
def _post_state_dict_hook(module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any) -> "OrderedDict[str, Tensor]":
    """Strip the inner ``_fpw_module.`` namespace from outgoing state-dict keys.

    Returns the (mutated) ``state_dict`` so it can be chained by the caller.
    """
    nested_prefix = prefix + "_fpw_module."
    replace_by_prefix_(state_dict, nested_prefix, prefix)
    return state_dict
def _post_state_dict_hook(module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any) -> "OrderedDict[str, Tensor]":
    """Strip the inner ``_fpw_module.`` namespace from outgoing state-dict keys.

    Returns the (mutated) ``state_dict`` so it can be chained by the caller.
    """
    # Move everything from the ._fpw_module level up one level.
    replace_by_prefix_(state_dict, prefix + "_fpw_module.", prefix)
    return state_dict