コード例 #1
0
ファイル: predicting.py プロジェクト: kumiko-oreyome/qa_mrc
def eval_all():

    output_model_file = "../model_dir/best_model"
    output_config_file = "../model_dir/bert_config.json"

    config = BertConfig(output_config_file)
    model = BertForQuestionAnswering(config)
    model.load_state_dict(
        torch.load(output_model_file))  #, map_location='cpu'))
    evaluate(model.cpu(), result_file="../metric/predicts.json")
コード例 #2
0
def eval_all():
    output_model_file = "../model_dir/best_model"
    output_config_file = "../model_dir/bert_configbase.json"

    config = BertConfig(output_config_file)
    model = BertForQuestionAnswering(config)
    # 针对多卡训练加载模型的方法:
    state_dict = torch.load(output_model_file, map_location='cuda:0')
    # 初始化一个空 dict
    new_state_dict = OrderedDict()
    # 修改 key,没有module字段则需要不上,如果有,则需要修改为 module.features
    for k, v in state_dict.items():
        if 'module' not in k:
            k = k
        else:
            k = k.replace('module.', '')
        new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    # model.load_state_dict(torch.load(output_model_file)) #, map_location='cpu'))
    evaluate(model.cpu(), result_file="../metric/predicts_dev.json")