def _pre_load_state_dict_hook(state_dict: Union[Dict[str, Tensor],
                                                "OrderedDict[str, Tensor]"],
                              prefix: str, *args: Any) -> None:
    """Rewrite ``state_dict`` keys in place for the wrapped module layout.

    Every key under ``prefix`` is first nested one level deeper, below
    ``prefix + "_fpw_module."``. If that yields a
    ``prefix + "_fpw_module.flat_param"`` key, keys starting with that name
    are lifted back up to ``prefix + "flat_param"`` (presumably because the
    flattened parameter is registered on the wrapper rather than on the
    inner module -- confirm against the wrapper class).

    NOTE(review): a second ``_pre_load_state_dict_hook`` is defined later in
    this listing and shadows this one at import time.

    Args:
        state_dict: mapping from parameter/buffer names to tensors; mutated.
        prefix: key prefix of the module this hook runs for.
        *args: any further positional arguments; ignored.
    """
    inner_prefix = prefix + "_fpw_module."
    replace_by_prefix_(state_dict, prefix, inner_prefix)

    nested_flat_param = inner_prefix + "flat_param"
    if nested_flat_param not in state_dict:
        return
    # Prefix-based move: any key beginning with the nested name is hoisted too.
    replace_by_prefix_(state_dict, nested_flat_param, prefix + "flat_param")
def _pre_load_state_dict_hook(state_dict: Union[Dict[str, Tensor],
                                                "OrderedDict[str, Tensor]"],
                              prefix: str, *args: Any) -> None:
    """Rewrite ``state_dict`` keys in place for the wrapped module layout.

    All keys under ``prefix`` are first pushed down below
    ``prefix + "_fpw_module."``. The flattened-parameter entries that are
    *direct* children of ``_fpw_module`` -- the legacy single ``flat_param``
    name and ``flat_param_<suffix>`` names -- are then moved back up to
    ``prefix + <name>``.

    A key such as ``prefix + "_fpw_module.flat_params_head.weight"``
    belongs to an ordinary submodule whose name merely starts with
    ``flat_param``. It is left in place. Previously it tripped an
    ``assert``, and under ``python -O`` it was silently renamed to
    ``prefix + "weight"``.

    Args:
        state_dict: mapping from parameter/buffer names to tensors; mutated.
        prefix: key prefix of the module this hook runs for.
        *args: any further positional arguments; ignored.
    """
    # Push everything down to ._fpw_module level.
    inner_prefix = prefix + "_fpw_module."
    replace_by_prefix_(state_dict, prefix, inner_prefix)
    # The flat_param* keys need to move one level up. Snapshot the keys
    # because the dict is mutated inside the loop.
    for key in list(state_dict.keys()):
        if not key.startswith(inner_prefix):
            continue
        name = key[len(inner_prefix):]
        if "." in name:
            # Nested key of some submodule, never a direct flat parameter.
            continue
        if name == "flat_param" or name.startswith("flat_param_"):
            # Exact rename of this single key (no prefix matching), so
            # unrelated keys sharing the same leading characters are untouched.
            state_dict[prefix + name] = state_dict.pop(key)
# Example no. 3
# 0
def test_replace_by_prefix():
    """``replace_by_prefix_`` renames only keys that *start* with the prefix.

    ``"abc.layer.def"`` contains ``"layer."`` but not as a prefix, so it must
    be left untouched.
    """
    state_dict = {
        "layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "layer.b": torch.tensor(3),
    }
    replace_by_prefix_(state_dict, "layer.", "module.layer.")
    expected = {
        "module.layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "module.layer.b": torch.tensor(3),
    }
    # Compare keys and values separately. ``dict ==`` on tensor values relies
    # on ``bool(t1 == t2)``, which is only defined for single-element tensors
    # and gives no useful message when it fails.
    assert set(state_dict) == set(expected), sorted(state_dict)
    for key, value in expected.items():
        assert torch.equal(state_dict[key], value), (key, state_dict[key], value)
def _post_state_dict_hook(module: nn.Module,
                          state_dict: "OrderedDict[str, Tensor]", prefix: str,
                          *args: Any) -> "OrderedDict[str, Tensor]":
    """Drop the ``_fpw_module.`` level from keys saved under ``prefix``.

    Keys of the form ``prefix + "_fpw_module." + rest`` become
    ``prefix + rest``. The dict is modified in place and also returned.

    NOTE(review): an identical ``_post_state_dict_hook`` is defined later in
    this listing and shadows this one at import time.

    Args:
        module: not used by this hook.
        state_dict: mapping from parameter/buffer names to tensors; mutated.
        prefix: key prefix of the module this hook runs for.
        *args: any further positional arguments; ignored.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    inner = prefix + "_fpw_module."
    replace_by_prefix_(state_dict, inner, prefix)
    return state_dict
def _post_state_dict_hook(module: nn.Module,
                          state_dict: "OrderedDict[str, Tensor]", prefix: str,
                          *args: Any) -> "OrderedDict[str, Tensor]":
    """Hoist keys nested under ``prefix + "_fpw_module."`` up to ``prefix``.

    The rename happens in place; the same mapping is returned.

    Args:
        module: not used by this hook.
        state_dict: mapping from parameter/buffer names to tensors; mutated.
        prefix: key prefix of the module this hook runs for.
        *args: any further positional arguments; ignored.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    outer_prefix = prefix
    nested_prefix = outer_prefix + "_fpw_module."
    # Every key under the inner module moves up one level.
    replace_by_prefix_(state_dict, nested_prefix, outer_prefix)
    return state_dict