Example no. 1
0
def train(dataset_name,
          batch_size=50,
          num_workers=2,
          max_epoch=15,
          lr=3e-5,
          weight_decay=1e-5,
          add_subject_loss=False,
          eval=False,
          continue_train=False,
          large_bert=False,
          subject_1=False,
          use_cls=True,
          softmax=False,
          opt='adam',
          seed=31415926535897932,
          cuda_device=0,
          sort=True,
          metric="micro_f1"):
    print("@@@@@@@@@@@ args @@@@@@@@@@@")
    print(locals())
    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@")

    os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)

    if seed is not None:
        torch.manual_seed(seed)

    root_path = cf.rootpath
    dataset_file = ["train.json", "dev.json", "test.json"]
    dataset_pkl_file = ["train.pkl", "dev.pkl", "test.pkl"]

    if large_bert:
        bert_path, bert_large = cf.modelpath, True
    else:
        bert_path, bert_large = cf.modelpath, False

    ckpt = './ckpt/%s_%s_%s_%s_%s_%s_bert.th' % (
        cf.outputname,
        dataset_name,
        "softmax" if softmax else "sigmoid",
        "withCLS" if use_cls else "withoutCLS",
        "-1" if subject_1 else "-2",
        "1024" if bert_large else "768",
    )

    # This function loads the dataset. On a first pass through the code you can skip
    # details like this; the one thing worth noting is that the tokenizer (tokenize)
    # is handed to the data loader here.
    def get_dataset(_model):
        if all(
                map(
                    lambda x: os.path.exists(
                        os.path.join(root_path, dataset_name, x)),
                    dataset_pkl_file)):
            dataset = list(
                map(
                    lambda x: pickle.load(
                        open(os.path.join(root_path, dataset_name, x), "rb")),
                    dataset_pkl_file))

            if softmax:
                list(map(lambda x: x.split(), dataset))
        else:
            dataset = list(
                map(
                    lambda x: SentenceREDataset(
                        path=os.path.join(root_path, dataset_name, x),
                        rel2id=rel2id,
                        tokenizer=_model.sentence_encoder.tokenize,
                        kwargs=None,
                        sort=sort), dataset_file))

            list(
                map(
                    lambda x, y: pickle.dump(
                        x, open(os.path.join(root_path, dataset_name, y), "wb")
                    ), dataset, dataset_pkl_file))

        if dataset_name in ["nyt10", "nyt10_1", "nyt10_2"]:
            list(map(lambda x: x.set_max_words(100), dataset))
            list(map(lambda x: x.remove_na(), dataset))
            # list(map(lambda x: x.remove_repeat(), dataset))
            list(map(lambda x: x.char_idx_to_word_idx(), dataset))
            for d in dataset:
                d.NA_id = -1
        if dataset_name in ["semeval_1"]:
            for d in dataset:
                d.NA_id = -1
        if dataset_name in ["webnlg", "webnlg_1"]:
            for d in dataset:
                d.NA_id = -1

        dataset_loader = list(
            map(
                lambda x: DataLoader(dataset=x,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     pin_memory=True,
                                     num_workers=num_workers,
                                     collate_fn=SentenceREDataset.collate_fn),
                dataset))

        return dataset_loader

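    # Relation-name -> id mapping, used by the datasets, the model, and the metric.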
    rel2id = json.load(
        open(os.path.join(root_path, dataset_name, 'rel2id.json')))

    # This part builds the network structure: following the architecture, first the
    # embedding/encoder layer, then the rest of the model.
    sentence_encoder = encoder.BERTHiddenStateEncoder(pretrain_path=bert_path)
    _model = model.PARA(sentence_encoder,
                        len(rel2id),
                        rel2id,
                        num_token_labels=2,
                        subject_1=subject_1,
                        use_cls=use_cls)

    # This fetches the dataset loaders. Deep learning trains in batches rather than
    # feeding the whole dataset to the model at once, so the loaders handle batching.
    train_loader, val_loader, test_loader = get_dataset(_model)

    # Because the OpenNRE framework is used, this mainly passes in hyperparameters;
    # it is this step that actually assembles the full training setup around the network.
    _framework = framework.SentenceRE(
        train_loader=train_loader,
        val_loader=val_loader
        if dataset_name not in ["nyt10", "nyt10_1"] else test_loader,
        test_loader=test_loader,
        model=_model,
        ckpt=ckpt,
        max_epoch=max_epoch,
        lr=lr,
        weight_decay=weight_decay,
        opt=opt,
        add_subject_loss=add_subject_loss,
        loss_func=PARALossSoftmax() if softmax else PARALoss(),
        metric=F1Metric(multi_label=not softmax,
                        na_id=train_loader.dataset.NA_id,
                        ignore_na=dataset_name == "semeval",
                        rel2id=rel2id,
                        print_error_prob=1),
    )

    # Handles resuming training: if training stopped after a few epochs for whatever
    # reason, the checkpoint can be loaded and training continued. train_model does
    # the actual training.
    if not eval:
        if continue_train:
            _framework.parallel_model.load_state_dict(
                torch.load(ckpt).state_dict())
        _framework.train_model(metric=metric)
    _framework.parallel_model.load_state_dict(torch.load(ckpt).state_dict())

    # print("TRAIN---------------------------")
    # result = _framework.eval_model(_framework.train_loader)
    # print('Accuracy on train set: {}'.format(result['acc']))
    # print('Micro Precision: {}'.format(result['micro_p']))
    # print('Micro Recall: {}'.format(result['micro_r']))
    # print('Micro F1: {}'.format(result['micro_f1']))
    #
    # print("DEV---------------------------")
    # result = _framework.eval_model(_framework.val_loader)
    # print('Accuracy on dev set: {}'.format(result['acc']))
    # print('Micro Precision: {}'.format(result['micro_p']))
    # print('Micro Recall: {}'.format(result['micro_r']))
    # print('Micro F1: {}'.format(result['micro_f1']))

    print("TEST---------------------------")
    result = _framework.eval_model(_framework.test_loader)
    print('Accuracy on test set: {}'.format(result['acc']))
    print('Micro Precision: {}'.format(result['micro_p']))
    print('Micro Recall: {}'.format(result['micro_r']))
    print('Micro F1: {}'.format(result['micro_f1']))

    if os.path.exists(os.path.join(root_path, dataset_name,
                                   "test_sample.json")):
        test_sample_dataset = SentenceREDataset(
            path=os.path.join(root_path, dataset_name, "test_sample.json"),
            rel2id=rel2id,
            tokenizer=_model.sentence_encoder.tokenize,
            kwargs=None,
            sort=sort)
        test_sample_loader = DataLoader(
            dataset=test_sample_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=num_workers,
            collate_fn=SentenceREDataset.collate_fn)
        print("TEST-Sample--------------------")
        result = _framework.eval_model(test_sample_loader)
        print('Accuracy on test set: {}'.format(result['acc']))
        print('Micro Precision: {}'.format(result['micro_p']))
        print('Micro Recall: {}'.format(result['micro_r']))
        print('Micro F1: {}'.format(result['micro_f1']))
        _framework.metric.df.to_excel(
            os.path.join(root_path, dataset_name, "res.xlsx"))
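
# Usage sketch (an assumption, not in the original source): the first variant reads
# <cf.rootpath>/<dataset_name>/{train,dev,test}.json, so it could be driven roughly
# like this with one of the dataset names the code already branches on ("semeval",
# "nyt10", "webnlg", ...).
if __name__ == "__main__":
    train("semeval", batch_size=50, max_epoch=15, lr=3e-5, eval=False)
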
def train(dataset_name,
          batch_size=50,
          num_workers=0,
          max_epoch=15,
          lr=3e-5,
          weight_decay=1e-5,
          add_subject_loss=False,
          eval=False,
          continue_train=False,
          large_bert=False,
          subject_1=False,
          use_cls=True,
          softmax=False,
          opt='adam',
          seed=31415926535897932,
          cuda_device=0,
          sort=True,
          metric="micro_f1"):
    print("@@@@@@@@@@@ args @@@@@@@@@@@")
    print(locals())
    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@")

    os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)

    if seed is not None:
        torch.manual_seed(seed)

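    # In this variant the data locations are hard-coded: splits are read from
    # <root_path>/<dataset_name>/First/5%_train/ and cached next to them as .pkl files.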
    root_path = 'src/REDN/Datasets'
    dataset_file = [
        "First/5%_train/train.json", "First/5%_train/dev.json",
        "First/5%_train/test.json"
    ]
    dataset_pkl_file = [
        "First/5%_train/train.pkl", "First/5%_train/dev.pkl",
        "First/5%_train/test.pkl"
    ]

    bert_path, bert_large = 'src/REDN/pretrain/bert-base-uncased', False

    ckpt = 'Checkpoints/%s_%s_%s_%s_%s_%s_%s_bert.th' % (
        "5%_train",
        cf.outputname,
        dataset_name,
        "softmax" if softmax else "sigmoid",
        "withCLS" if use_cls else "withoutCLS",
        "-1" if subject_1 else "-2",
        "1024" if bert_large else "768",
    )

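    # Load the three splits, caching the tokenized datasets as .pkl files after the
    # first run; the encoder's tokenize function is passed to each dataset here.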
    def get_dataset(_model):
        if all(
                map(
                    lambda x: os.path.exists(
                        os.path.join(root_path, dataset_name, x)),
                    dataset_pkl_file)):
            dataset = list(
                map(
                    lambda x: pickle.load(
                        open(os.path.join(root_path, dataset_name, x), "rb")),
                    dataset_pkl_file))

            if softmax:
                list(map(lambda x: x.split(), dataset))
        else:
            dataset = list(
                map(
                    lambda x: SentenceREDataset(
                        path=os.path.join(root_path, dataset_name, x),
                        rel2id=rel2id,
                        tokenizer=_model.sentence_encoder.tokenize,
                        kwargs=None,
                        sort=sort), dataset_file))

            list(
                map(
                    lambda x, y: pickle.dump(
                        x, open(os.path.join(root_path, dataset_name, y), "wb")
                    ), dataset, dataset_pkl_file))

        if dataset_name in ["nyt10", "nyt10_1", "nyt10_2"]:
            list(map(lambda x: x.set_max_words(100), dataset))
            list(map(lambda x: x.remove_na(), dataset))
            # list(map(lambda x: x.remove_repeat(), dataset))
            list(map(lambda x: x.char_idx_to_word_idx(), dataset))
            for d in dataset:
                d.NA_id = -1
        if dataset_name in ["semeval_1"]:
            for d in dataset:
                d.NA_id = -1
        if dataset_name in ["webnlg", "webnlg_1"]:
            for d in dataset:
                d.NA_id = -1
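        # fewrel is likewise treated as having no NA relation (NA_id = -1); the
        # commented-out line would cap sentences at 64 words.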
        if dataset_name in ["fewrel"]:
            # list(map(lambda x: x.set_max_words(64), dataset))
            for d in dataset:
                d.NA_id = -1

        dataset_loader = list(
            map(
                lambda x: DataLoader(dataset=x,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     pin_memory=True,
                                     num_workers=num_workers,
                                     collate_fn=SentenceREDataset.collate_fn),
                dataset))

        return dataset_loader

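    # Relation label mapping for this dataset; note it is read from the First/ folder,
    # not from the 5%_train subfolder.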
    rel2id = json.load(
        open(os.path.join(root_path, dataset_name, 'First/rel2id.json')))

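    # Build the network: a BERT hidden-state encoder feeding the PARA relation model.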
    sentence_encoder = encoder.BERTHiddenStateEncoder(pretrain_path=bert_path)
    _model = model.PARA(sentence_encoder,
                        len(rel2id),
                        rel2id,
                        num_token_labels=2,
                        subject_1=subject_1,
                        use_cls=use_cls)

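    # Build the batched train/dev/test data loaders.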
    train_loader, val_loader, test_loader = get_dataset(_model)

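    # Wire the model, loaders, loss and metric into the SentenceRE training framework.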
    _framework = framework.SentenceRE(
        train_loader=train_loader,
        val_loader=val_loader
        if dataset_name not in ["nyt10", "nyt10_1"] else test_loader,
        test_loader=test_loader,
        model=_model,
        ckpt=ckpt,
        max_epoch=max_epoch,
        lr=lr,
        weight_decay=weight_decay,
        opt=opt,
        add_subject_loss=add_subject_loss,
        loss_func=PARALossSoftmax() if softmax else PARALoss(),
        metric=F1Metric(multi_label=not softmax,
                        na_id=train_loader.dataset.NA_id,
                        ignore_na=dataset_name == "semeval",
                        rel2id=rel2id,
                        print_error_prob=1),
    )

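    # Train unless running in eval-only mode; optionally resume from the checkpoint,
    # then reload the saved checkpoint before evaluation.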
    if not eval:
        if continue_train:
            _framework.parallel_model.load_state_dict(
                torch.load(ckpt).state_dict())
        _framework.train_model(metric=metric)
    _framework.parallel_model.load_state_dict(torch.load(ckpt).state_dict())

    print("TRAIN---------------------------")
    result = _framework.eval_model(_framework.train_loader)
    print('Accuracy on train set: {}'.format(result['acc']))
    print('Micro Precision: {}'.format(result['micro_p']))
    print('Micro Recall: {}'.format(result['micro_r']))
    print('Micro F1: {}'.format(result['micro_f1']))

    print("DEV---------------------------")
    result = _framework.eval_model(_framework.val_loader)
    print('Accuracy on dev set: {}'.format(result['acc']))
    print('Micro Precision: {}'.format(result['micro_p']))
    print('Micro Recall: {}'.format(result['micro_r']))
    print('Micro F1: {}'.format(result['micro_f1']))

    print("TEST---------------------------")
    result = _framework.eval_model(_framework.test_loader)
    print('Accuracy on test set: {}'.format(result['acc']))
    print('Micro Precision: {}'.format(result['micro_p']))
    print('Micro Recall: {}'.format(result['micro_r']))
    print('Micro F1: {}'.format(result['micro_f1']))

    ### Uncomment this section and comment out the other test sections
    ### to use Few-Shot Evaluation. In the "main" function,
    ### set eval=True.
    # print("Few Shot Evaluation------------")
    # torch.save(_framework.parallel_model, ckpt)
    # _framework.parallel_model.load_state_dict(torch.load(ckpt).state_dict())
    # t = tqdm(range(1000), leave=False)
    # acc = []
    # for i in t:
    #   result = _framework.eval_model_few_shot(_framework.test_loader, K=5, Q=1, iter=1)
    #   acc.append(result['acc'])
    #   _framework.parallel_model.load_state_dict(torch.load(ckpt).state_dict())
    #   t.set_postfix(acc=sum(acc)/len(acc))
    # print('Accuracy on test set: {}'.format(sum(acc)/len(acc)))

    if os.path.exists(os.path.join(root_path, dataset_name, "test.txt")):
        test_sample_dataset = SentenceREDataset(
            path=os.path.join(root_path, dataset_name, "test.txt"),
            rel2id=rel2id,
            tokenizer=_model.sentence_encoder.tokenize,
            kwargs=None,
            sort=sort)
        test_sample_loader = DataLoader(
            dataset=test_sample_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=num_workers,
            collate_fn=SentenceREDataset.collate_fn)
        print("TEST-Sample--------------------")
        result = _framework.eval_model(test_sample_loader)
        print('Accuracy on test set: {}'.format(result['acc']))
        print('Micro Precision: {}'.format(result['micro_p']))
        print('Micro Recall: {}'.format(result['micro_r']))
        print('Micro F1: {}'.format(result['micro_f1']))
        _framework.metric.df.to_excel(
            os.path.join(root_path, dataset_name, "res.xlsx"))
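
# Usage sketch (an assumption, not in the original source): per the commented
# few-shot instructions above, that block would be uncommented and this variant
# called with eval=True, so training is skipped and only the saved checkpoint is
# loaded for evaluation.
if __name__ == "__main__":
    train("fewrel", eval=True)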