def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Rename the legacy ``output_norm.`` entries under *prefix* to ``after_norm.``.

    The seven-parameter signature looks like the hook protocol used by
    ``load_state_dict`` pre-hooks (NOTE(review): confirm at the registration
    site). Only ``state_dict`` and ``prefix`` are used. The remaining
    parameters are accepted solely to satisfy that calling convention.

    Args:
        state_dict: Mapping of parameter names to values, modified in place
            by ``rename_state_dict`` (presumably; that helper is defined
            elsewhere).
        prefix: Key prefix of the module owning these parameters.
    """
    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
    legacy_key = prefix + "output_norm."
    current_key = prefix + "after_norm."
    rename_state_dict(legacy_key, current_key, state_dict)
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Rename legacy ``input_layer.`` / ``norm.`` entries under *prefix*.

    The renames are applied in order: ``input_layer.`` becomes ``embed.``,
    then ``norm.`` becomes ``after_norm.``.

    The seven-parameter signature looks like the hook protocol used by
    ``load_state_dict`` pre-hooks (NOTE(review): confirm at the registration
    site). Only ``state_dict`` and ``prefix`` are used. The remaining
    parameters exist only to satisfy that calling convention.

    NOTE(review): another ``_pre_hook`` appears immediately before this one in
    the reviewed source. If both live in the same module, this definition
    shadows the earlier one, so confirm they belong to separate modules.

    Args:
        state_dict: Mapping of parameter names to values, modified in place
            by ``rename_state_dict`` (presumably; that helper is defined
            elsewhere).
        prefix: Key prefix of the module owning these parameters.
    """
    # (legacy suffix, current suffix) pairs, applied in this exact order.
    legacy_to_current = (
        # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
        ("input_layer.", "embed."),
        # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
        ("norm.", "after_norm."),
    )
    for legacy_suffix, current_suffix in legacy_to_current:
        rename_state_dict(prefix + legacy_suffix, prefix + current_suffix, state_dict)
def test_v0_3_transformer_input_compatibility():
    """Check that a state dict using v0.3.x key names can still be loaded.

    The test takes the current model's state dict and renames its keys back
    to the old v0.3.x names. It then loads the result into the model, relying
    on the model's load-time renaming to map them forward again.
    NOTE(review): with the default ``strict`` mode, ``load_state_dict`` is
    expected to fail on any leftover legacy key. Confirm against the PyTorch
    version in use.
    """
    args = make_arg()
    # Only the model is needed; the other returned values are unused.
    model, _x, _ilens, _y, _data = prepare("pytorch", args)

    # these old names are used in v.0.3.x
    # (module prefix, current suffix, legacy suffix), applied in this order.
    current_to_legacy = (
        ("encoder.", "embed.", "input_layer."),
        ("encoder.", "after_norm.", "norm."),
        ("decoder.", "after_norm.", "output_norm."),
    )

    state_dict = model.state_dict()
    for module_prefix, current_suffix, legacy_suffix in current_to_legacy:
        rename_state_dict(
            module_prefix + current_suffix,
            module_prefix + legacy_suffix,
            state_dict,
        )
    model.load_state_dict(state_dict)