# Ejemplo n.º 1
# (score: 0 — header artifact from the code-sharing site this snippet was scraped from)
    # Baseline evaluation on the test set before any training, so best_f1
    # starts from the untrained model's argument-role F1 (floored at 0).
    # NOTE(review): `eval` shadows the Python builtin; presumably a
    # project-local evaluation helper defined elsewhere in this file.
    _, _, argument_f1_test = eval(model, test_iter, os.path.join(hp.logdir,'0') + '_test')
    best_f1 = max(0,argument_f1_test)
    no_gain_rc = 0# count of consecutive epochs with no F1 improvement (for early stopping)
    for epoch in range(1, hp.n_epochs + 1):
        # One full pass over the training data.
        train(model, train_iter, optimizer, hp)

        # Per-epoch prediction/metric dump path: <logdir>/<epoch>_dev, _test.
        fname = os.path.join(hp.logdir, str(epoch))

        print(f"=========eval dev at epoch={epoch}=========")
        metric_dev,trigger_f1_dev, argument_f1_dev = eval(model, dev_iter, fname + '_dev')

        print(f"=========eval test at epoch={epoch}=========")
        metric_test,trigger_f1_test, argument_f1_test = eval(model, test_iter, fname + '_test')

        # Optionally push per-epoch metrics to a Telegram chat.
        if hp.telegram_bot_token:
            report_to_telegram('[epoch {}] dev\n{}'.format(epoch, metric_dev), hp.telegram_bot_token, hp.telegram_chat_id)
            report_to_telegram('[epoch {}] test\n{}'.format(epoch, metric_test), hp.telegram_bot_token, hp.telegram_chat_id)

        # Checkpoint whenever the TEST argument-role F1 improves.
        # NOTE(review): selecting the model on the *test* split leaks test
        # data; selection would normally use argument_f1_dev — confirm intent.
        if argument_f1_test >best_f1:
            print("角色词 F1 值由 {:.3f} 更新至 {:.3f} ".format(best_f1, argument_f1_test))
            best_f1 = argument_f1_test
            print("=======保存模型=======")
            torch.save(model, hp.model_path)
            no_gain_rc = 0
        else:
            no_gain_rc = no_gain_rc+1

        ## Early stopping: abort once hp.early_stop consecutive epochs pass with no gain.
        if no_gain_rc > hp.early_stop:
            print("连续{}个epoch没有提升,在epoch={}提前终止".format(no_gain_rc,epoch))
            break

        # NOTE(review): the remainder of this loop body looks like a second,
        # pasted-in variant of the checkpoint/report logic above.
        # dev_arg_f1, dev_arg_f1_max, dev_table, test_table,
        # TELEGRAM_BOT_TOKEN and TELEGRAM_CHAT_ID are not defined anywhere
        # in view, so these lines would raise NameError if reached —
        # verify against the full file and remove one of the two variants.
        if dev_arg_f1 >= dev_arg_f1_max:
          dev_arg_f1_max = dev_arg_f1
          metric_output = os.path.join(hp.module_output, hp.result_output)
          model_save_path = os.path.join(hp.module_output, hp.model_save_name)
          torch.save(model, model_save_path)
          with open(metric_output, 'a') as fout:
            fout.write(f"=========eval dev at epoch={epoch}=========\n")
            fout.write(dev_table.get_string())
            fout.write(f"\n=========eval test at epoch={epoch}=========\n")
            fout.write(test_table.get_string())
            fout.write('\n\n')

        # Duplicate of the Telegram report earlier in the loop, but using
        # module-level constants instead of hp.* — see NOTE(review) above.
        if hp.telegram_bot_token:
            report_to_telegram('[epoch {}] dev\n{}'.format(epoch, metric_dev), TELEGRAM_BOT_TOKEN, TELEGRAM_CHAT_ID)
            report_to_telegram('[epoch {}] test\n{}'.format(epoch, metric_test), TELEGRAM_BOT_TOKEN, TELEGRAM_CHAT_ID)


    # Append a separator marking the transition from base-set pre-training
    # to fine-tuning on the novel event types.
    # NOTE(review): metric_output is only assigned inside a conditional
    # branch of the training loop above — it may be unbound here; verify.
    with open(metric_output, 'a') as fout:
        fout.write("----------------End of Pre-training on base set------------------\n\n\n")
        fout.write("----------------Finetune on novel set------------------\n")
    # finetune on novel set
    dev_arg_f1_max = 0  # reset best dev argument-F1 for the fine-tuning phase

    # Novel-set datasets. novel_shot presumably caps examples per novel
    # event (few-shot train via hp.novel_shot; dev/test use a large cap of
    # 1000 so evaluation sees all available examples) — TODO confirm
    # against the ACE2005DatasetNovel definition.
    ft_train_dataset = ACE2005DatasetNovel(hp.trainset, all_arguments, argument2idx, novel_event=novel_event, novel_shot=hp.novel_shot)
    ft_dev_dataset = ACE2005DatasetNovel(hp.devset, all_arguments, argument2idx, novel_event=novel_event, novel_shot=1000)
    ft_test_dataset = ACE2005DatasetNovel(hp.testset, all_arguments, argument2idx, novel_event=novel_event, novel_shot=1000)
    # Per-sample weights drive a WeightedRandomSampler so fine-tuning
    # oversamples under-represented examples.
    samples_weight = ft_train_dataset.get_samples_weight()
    sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight))