else:
        raise NotImplementedError("Currently, the %s model is not supported" %
                                  (train_config["model"]))

    # Restore the checkpoint saved for the epoch recorded in
    # best_epochs["dev"]; map_location forces the tensors onto the CPU.
    best_checkpoint_path = os.path.join(
        finetune_config["load_model_dir"],
        "epoch%d.pt" % (best_epochs["dev"]))
    model.load_state_dict(
        torch.load(best_checkpoint_path, map_location=torch.device("cpu")))

    # Hand the fine-tuning configuration to the model's increase_net hook.
    model.increase_net(finetune_config)

    # Settings that must be identical in the training and fine-tuning
    # configs; if any of them differs, the model's increase_input_size hook
    # is invoked with the fine-tuning configuration.
    input_setting_keys = (
        "max_npv", "max_npe", "max_npvl", "max_npel",
        "max_ngv", "max_nge", "max_ngvl", "max_ngel",
        "share_emb", "share_arch",
    )
    # A list (not a generator) is built on purpose so that every key is
    # looked up in both configs even after the first mismatch is found.
    input_settings_agree = all([
        train_config[name] == finetune_config[name]
        for name in input_setting_keys
    ])
    if not input_settings_agree:
        model.increase_input_size(finetune_config)
    if not all([
            train_config[key] == finetune_config[key] for key in [
                "predict_net", "predict_net_hidden_dim",
                "predict_net_num_heads", "predict_net_mem_len",
                "predict_net_mem_init", "predict_net_recurrent_steps"
            ]
    ]):
        new_predict_net = model.create_predict_net(
            finetune_config["predict_net"],
            pattern_dim=model.predict_net.pattern_dim,
            graph_dim=model.predict_net.graph_dim,
            hidden_dim=finetune_config["predict_net_hidden_dim"],
            num_heads=finetune_config["predict_net_num_heads"],
            recurrent_steps=finetune_config["predict_net_recurrent_steps"],
            mem_len=finetune_config["predict_net_mem_len"],