Example #1
0
def eval(model, iterator, fname, write):
    """Evaluate trigger prediction over *iterator*.

    Runs the model without gradients, collects gold and predicted trigger
    tags, prints trigger classification/identification P/R/F1, and — when
    ``write`` is truthy — dumps per-token predictions plus the metric
    summary to ``fname``.

    Args:
        model: model exposing ``predict_triggers``.
        iterator: batch iterator yielding
            (tokens, triggers, entities, postags, adj, seqlens, words, trigger labels).
        fname: output path used when ``write`` is truthy.
        write: whether to persist predictions and metrics to ``fname``.

    Returns:
        str: the formatted metric summary.
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all = [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_2d, triggers_2d, entities_3d, postags_2d, adj, seqlen_1d, words, triggers = batch

            trigger_logits, trigger_hat_2d = model.predict_triggers(
                tokens_2d=tokens_2d,
                entities_3d=entities_3d,
                postags_2d=postags_2d,
                seqlen_1d=seqlen_1d,
                adjm=adj)
            words_all.extend(words)
            triggers_all.extend(triggers)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())

    triggers_true, triggers_pred = [], []
    with open('temp', 'w') as fout:
        for i, (words, triggers, triggers_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all)):
            # Predictions are padded; trim to the sentence length.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item)
                                  for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item)
                                  for item in find_triggers(triggers_hat)])

            for w, t, t_h in zip(words, triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write("\n")

    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true,
                                                   triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r,
                                                 trigger_f1))
    print('[trigger identification]')
    # Identification ignores the event type: keep (sentence, start, end) only.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true,
                                                      triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_,
                                                 trigger_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)
    final = fname
    if write:
        # BUG FIX: the temp read handle was never closed; use `with`.
        with open(final, 'w') as fout, open("temp", "r") as fin:
            fout.write("{}\n".format(fin.read()))
            fout.write(metric)
    # BUG FIX: the scratch file previously leaked whenever `write` was falsy.
    os.remove("temp")
    return metric
Example #2
0
    def predict_triggers(self, tokens_x_2d, entities_x_3d, postags_x_2d,
                         head_indexes_2d, triggers_y_2d, arguments_2d):
        """Predict trigger labels and build event/entity argument tensors.

        Encodes tokens with BERT, gathers word-head positions, scores each
        position with the trigger classifier, then concatenates the mean
        hidden state of every predicted trigger span with that of every
        candidate entity span for downstream argument classification.

        Args:
            tokens_x_2d: token ids, [batch_size, seq_len].
            entities_x_3d: entity features (unused here; kept for interface).
            postags_x_2d: POS-tag ids (unused here; kept for interface).
            head_indexes_2d: word-head token indices, [batch_size, seq_len].
            triggers_y_2d: gold trigger label ids, [batch_size, seq_len].
            arguments_2d: per-sentence dict with a 'candidates' list of
                (start, end, type_str) entity spans.

        Returns:
            (trigger_logits, triggers_y_2d, trigger_hat_2d,
             argument_hidden, argument_keys) where argument_keys holds
            (sent_idx, t_start, t_end, t_type, e_start, e_end, e_type).
        """
        tokens_x_2d = torch.LongTensor(tokens_x_2d).to(self.device)
        triggers_y_2d = torch.LongTensor(triggers_y_2d).to(self.device)
        head_indexes_2d = torch.LongTensor(head_indexes_2d).to(self.device)

        if self.training:
            self.bert.train()
            encoded_layers, _ = self.bert(tokens_x_2d)
            enc = encoded_layers[-1]
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(tokens_x_2d)
                enc = encoded_layers[-1]

        batch_size = tokens_x_2d.shape[0]

        # BUG FIX: the original aliased ``x = enc`` and then assigned
        # ``x[i] = torch.index_select(...)`` — an in-place write into the
        # encoder output, which corrupts rows gathered later and is unsafe
        # under autograd while training. Build the gathered tensor
        # out-of-place instead.
        x = torch.stack([
            torch.index_select(enc[i], 0, head_indexes_2d[i])
            for i in range(batch_size)
        ])  # x: [batch_size, seq_len, hidden_size]

        trigger_logits = self.fc_trigger(x)
        trigger_hat_2d = trigger_logits.argmax(-1)

        argument_hidden, argument_keys = [], []
        for i in range(batch_size):
            candidates = arguments_2d[i]['candidates']

            # Mean-pool each gold entity span once, reused for every trigger.
            # NOTE(review): an empty span (e_start == e_end) yields a NaN
            # mean — assumed not to occur in the data; confirm upstream.
            golden_entity_tensors = {}
            for j in range(len(candidates)):
                e_start, e_end, e_type_str = candidates[j]
                golden_entity_tensors[candidates[j]] = x[i,
                                                         e_start:e_end, ].mean(
                                                             dim=0)

            predicted_triggers = find_triggers([
                idx2trigger[trigger] for trigger in trigger_hat_2d[i].tolist()
            ])
            for predicted_trigger in predicted_triggers:
                t_start, t_end, t_type_str = predicted_trigger
                event_tensor = x[i, t_start:t_end, ].mean(dim=0)
                for j in range(len(candidates)):
                    e_start, e_end, e_type_str = candidates[j]
                    entity_tensor = golden_entity_tensors[candidates[j]]

                    argument_hidden.append(
                        torch.cat([event_tensor, entity_tensor]))
                    argument_keys.append((i, t_start, t_end, t_type_str,
                                          e_start, e_end, e_type_str))

        return trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys
Example #3
0
def train(model, iterator, optimizer, hp):
    """Train the trigger + argument model for one pass over *iterator*.

    Accumulates training metrics in a rolling 100-step window and prints
    P/R/F1 for triggers and arguments every 100 batches.

    Args:
        model: DataParallel-wrapped model exposing ``module.predict_triggers``
            and ``module.predict_arguments``.
        iterator: batch iterator.
        optimizer: torch optimizer.
        hp: hyper-parameter namespace; ``hp.LOSS_alpha`` weights the
            argument loss against the trigger loss.
    """
    model.train()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    # Key tensors (translated from the original notes):
    # argument_keys:    correctly predicted triggers paired with gold entities
    # arguments_2d:     gold triggers -> gold roles
    # arguments_y_2d:   CRF input labels [dim0, seq_len]
    # argument_hat_1d:  CRF decode result
    # argument_hat_2d:  argument_keys + argument_hat_1d folded into dicts
    # trigger_hat_2d:   CRF-predicted triggers
    # triggers_y_2d:    gold triggers
    for i, batch in enumerate(iterator):
        tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d, adjm = batch
        optimizer.zero_grad()
        # crf loss, trigger labels, predicted triggers, entity-event tensors, 7-tuples
        trigger_loss, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers(
            tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
            postags_x_2d=postags_x_2d, head_indexes_2d=head_indexes_2d,
            triggers_y_2d=triggers_y_2d, arguments_2d=arguments_2d, adjm=adjm)

        if len(argument_keys) > 0:
            argument_loss, arguments_y_2d, argument_hat_1d, argument_hat_2d = model.module.predict_arguments(
                argument_hidden, argument_keys, arguments_2d, adjm)
            loss = trigger_loss + hp.LOSS_alpha * argument_loss
        else:
            loss = trigger_loss

        # BUG FIX: the original clipped gradients *before* loss.backward(),
        # i.e. before any gradients for this step existed. Clip after
        # backward and before the optimizer step.
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()

        #### training-accuracy bookkeeping ####
        base = len(words_all)  # sentence index offset within the current window
        triggers_hat_batch = trigger_hat_2d.cpu().numpy().tolist()
        if len(argument_keys) > 0:
            arguments_hat_batch = argument_hat_2d
        else:
            # No predicted-trigger/gold-entity pairs: empty event dicts.
            arguments_hat_batch = [{'events': {}} for _ in range(len(arguments_2d))]

        words_all.extend(words_2d)
        triggers_all.extend(triggers_2d)
        triggers_hat_all.extend(triggers_hat_batch)
        arguments_all.extend(arguments_2d)
        arguments_hat_all.extend(arguments_hat_batch)

        # BUG FIX: the original re-scanned *all* sentences accumulated so far
        # on every batch, duplicating entries in triggers_true/pred and
        # arguments_true/pred (inflated metrics, O(n^2) work). Score only the
        # current batch; `base` keeps sentence ids consistent in the window.
        for off, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(
                zip(words_2d, triggers_2d, triggers_hat_batch, arguments_2d, arguments_hat_batch)):
            ii = base + off
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(ii, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(ii, *item) for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for a_start, a_end, a_type_idx in arguments['events'][trigger]:
                    arguments_true.append((ii, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for a_start, a_end, a_type_idx in arguments_hat['events'][trigger]:
                    arguments_pred.append((ii, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))

        if i % 100 == 0:  # monitoring
            trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
            argument_p, argument_r, argument_f1 = calc_metric(arguments_true, arguments_pred)
            # Reset the 100-step accumulation window.
            words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
            triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
            if len(argument_keys) > 0:
                print("【识别到事件】step: {}, loss: {:.3f}, trigger_loss:{:.3f}, argument_loss:{:.3f}".format(i, loss.item(), trigger_loss.item(), argument_loss.item()),
                      '【trigger】 P={:.3f}  R={:.3f}  F1={:.3f}'.format(trigger_p, trigger_r, trigger_f1),
                      '【argument】 P={:.3f}  R={:.3f}  F1={:.3f}'.format(argument_p, argument_r, argument_f1)
                      )
            else:
                print("【未识别到事件】step: {}, loss: {:.3f} ".format(i, loss.item()),
                      '【trigger】 P={:.3f}  R={:.3f}  F1={:.3f}'.format(trigger_p, trigger_r, trigger_f1)
                      )
Example #4
0
    def predict_triggers(self, tokens_x_2d, entities_x_3d, postags_x_2d,
                         head_indexes_2d, triggers_y_2d, arguments_2d, adjm):
        """Encode tokens, run the trigger CRF, and collect gold roles for
        each correctly predicted trigger.

        Args:
            tokens_x_2d: token ids, [batch_size, seq_length].
            entities_x_3d: entity features (unused here; kept for interface).
            postags_x_2d: POS-tag ids (unused here; kept for interface).
            head_indexes_2d: word-head token indices, [batch_size, seq_length].
            triggers_y_2d: gold trigger label ids, [batch_size, seq_length].
            arguments_2d: per-sentence dict whose 'events' maps gold
                (t_start, t_end, t_type_str) triggers to role tuples.
            adjm: adjacency matrices (unused here; kept for interface).

        Returns:
            (trigger_loss, triggers_y_2d, trigger_hat_2d, sen_mask_id,
             argument_keys) — argument_keys maps
            (sent_idx, t_start, t_end, t_type_str) of every predicted
            trigger that exactly matches a gold event to that event's list
            of (a_start, a_end, a_type_idx) roles.
        """
        # token ids [batch_size, seq_length]
        tokens_x_2d = torch.LongTensor(tokens_x_2d).to(self.device)
        # trigger label ids [batch_size, seq_length]
        triggers_y_2d = torch.LongTensor(triggers_y_2d).to(self.device)
        # effective word-level length of each sentence
        xlen = [max(x) for x in head_indexes_2d]
        head_indexes_2d = torch.LongTensor(head_indexes_2d).to(self.device)

        if self.training:
            self.PreModel.train()
            x_1, _ = self.PreModel(tokens_x_2d)
        else:
            self.PreModel.eval()
            with torch.no_grad():
                x_1, _ = self.PreModel(tokens_x_2d)

        batch_size = tokens_x_2d.shape[0]
        SEQ_LEN = x_1.size()[1]

        # Gather word-head positions (this also drops [CLS]); built
        # out-of-place so the encoder output is never mutated.
        x = torch.zeros(x_1.size(), dtype=x_1.dtype).to(self.device)
        for i in range(batch_size):
            x[i] = torch.index_select(x_1[i], 0, head_indexes_2d[i])

        # 1 for real tokens, 0 for padding.
        mask = numpy.zeros(shape=[batch_size, SEQ_LEN], dtype=numpy.uint8)
        for i in range(len(xlen)):
            mask[i, :xlen[i]] = 1
        # NOTE(review): ByteTensor masks are deprecated in newer torch;
        # kept because the CRF implementation presumably expects it — confirm.
        mask = torch.ByteTensor(mask).to(self.device)

        self.mask = mask

        emb = x  # [batch_size, seq_len, hidden_size]
        trigger_logits1 = self.tri_fc1(emb)
        trigger_logits1 = nn.functional.leaky_relu_(
            trigger_logits1)  # [batch_size, seq_len, trigger_size + 2]

        ## trigger CRF: training loss + Viterbi decode ##
        trigger_loss = self.tri_CRF1.neg_log_likelihood_loss(
            feats=trigger_logits1, mask=mask, tags=triggers_y_2d)
        _, trigger_hat_2d = self.tri_CRF1.forward(feats=trigger_logits1,
                                                  mask=mask)

        self.emb = emb
        self.tri_result = trigger_hat_2d

        # Gold roles for each correctly predicted trigger.
        argument_keys = {}
        sen_mask_id = []

        for i in range(batch_size):
            # items: (t_start, t_end, event type string)
            predicted_triggers = find_triggers([
                self.idx2trigger[trigger]
                for trigger in trigger_hat_2d[i].tolist()
            ])
            for t_start, t_end, t_type_str in predicted_triggers:
                # Only keep predictions that exactly match a gold event.
                gold_roles = arguments_2d[i]['events'].get(
                    (t_start, t_end, t_type_str))
                if gold_roles is not None:
                    key = (i, t_start, t_end, t_type_str)
                    # setdefault replaces the original manual
                    # "if key in dict ... else create-then-append" dance.
                    for a_start, a_end, a_type_idx in gold_roles:
                        argument_keys.setdefault(key, []).append(
                            (a_start, a_end, a_type_idx))
        return trigger_loss, triggers_y_2d, trigger_hat_2d, sen_mask_id, argument_keys
Example #5
0
def _roles_to_str(events):
    """Map each (start, end, role_id) tuple in *events* to (start, end, role_str).

    *events* is a dict of trigger -> list of role tuples; the conversion is
    done in place (preserving the original side effect) and the dict is
    returned for convenience.
    """
    for arg_key in events:
        events[arg_key] = [(s, e, idx2argument[r]) for s, e, r in events[arg_key]]
    return events


def eval(model, iterator, fname):
    """Evaluate trigger and argument prediction over *iterator*.

    Collects gold/predicted triggers and arguments, prints classification
    and identification P/R/F1 for both, and writes per-token predictions
    plus the metric summary to a file derived from *fname*.

    Returns:
        (metric_str, trigger_f1, argument_f1)
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []

    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d, adjm = batch

            trigger_loss, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers(
                tokens_x_2d=tokens_x_2d,
                entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d,
                head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d,
                arguments_2d=arguments_2d,
                adjm=adjm)

            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)

            if len(argument_keys) > 0:
                argument_loss, arguments_y_2d, argument_hat_1d, argument_hat_2d = model.module.predict_arguments(
                    argument_hidden, argument_keys, arguments_2d, adjm)
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # No candidate pairs: substitute empty event dicts.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w', encoding="utf-8") as fout:
        for i, (words, triggers, triggers_hat, arguments,
                arguments_hat) in enumerate(
                    zip(words_all, triggers_all, triggers_hat_all,
                        arguments_all, arguments_hat_all)):
            # Predictions are padded; trim to the sentence length.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item)
                                  for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item)
                                  for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_pred.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            # words[1:-1] skips the [CLS]/[SEP] wrapper tokens.
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))

            # Convert role ids to strings for readable output. NOTE(review):
            # this mutates the batch's arguments dicts in place, as the
            # original code did — confirm no caller reuses them afterwards.
            fout.write('#真实值#\t{}\n'.format(_roles_to_str(arguments['events'])))
            fout.write('#预测值#\t{}\n'.format(_roles_to_str(arguments_hat['events'])))
            fout.write("\n")

    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true,
                                                   triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r,
                                                 trigger_f1))

    print('[argument classification]')
    argument_p, argument_r, argument_f1 = calc_metric(arguments_true,
                                                      arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))

    print('[trigger identification]')
    # Identification ignores the type: keep (sentence, start, end) only.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true,
                                                      triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_,
                                                 trigger_f1_))

    print('[argument identification]')
    # Identification drops the role type (last element).
    arguments_true = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_true]
    arguments_pred = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_pred]
    argument_p_, argument_r_, argument_f1_ = calc_metric(
        arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_,
                                                 argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p, argument_r, argument_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)
    metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p_, argument_r_, argument_f1_)
    final = fname + ".trigger-F%.2f argument-F%.2f" % (trigger_f1, argument_f1)
    # BUG FIX: the temp read handle was never closed; use `with`.
    with open(final, 'w', encoding="utf-8") as fout, \
            open("temp", "r", encoding="utf-8") as fin:
        fout.write("{}\n".format(fin.read()))
        fout.write(metric)
    os.remove("temp")
    return metric, trigger_f1, argument_f1
Example #6
0
def eval(model, iterator, fname):
    """Evaluate joint trigger/entity tagging and argument prediction.

    Iterates (test, labels) batches through the model, prints trigger and
    argument classification/identification P/R/F1, and writes per-token
    predictions plus the metric summary to a file derived from *fname*.

    Returns:
        (metric_str, trigger_f1, argument_f1)
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, (test, labels) in enumerate(iterator):
            trigger_logits, trigger_entities_hat_2d, triggers_y_2d, argument_hidden_logits, arguments_y_1d, argument_hidden_hat_1d, argument_hat_2d, argument_keys = model(
                test, labels)

            # test[3]=words, test[4]=trigger labels, test[-1]=arguments.
            # NOTE(review): positional layout assumed from usage — confirm
            # against the dataset collate function.
            words_all.extend(test[3])
            triggers_all.extend(test[4])
            triggers_hat_all.extend(
                trigger_entities_hat_2d.cpu().numpy().tolist())
            arguments_2d = test[-1]
            arguments_all.extend(arguments_2d)
            if len(argument_keys) > 0:
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # No candidate pairs: substitute empty event dicts.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w', encoding='utf-8') as fout:
        for i, (words, triggers, triggers_hat, arguments,
                arguments_hat) in enumerate(
                    zip(words_all, triggers_all, triggers_hat_all,
                        arguments_all, arguments_hat_all)):
            # Predictions are padded; trim to the sentence length.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger_entities[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true_, entities_true = find_triggers(
                triggers[:len(words)])
            triggers_pred_, entities_pred = find_triggers(triggers_hat)
            triggers_true.extend([(i, *item) for item in triggers_true_])
            triggers_pred.extend([(i, *item) for item in triggers_pred_])

            # NOTE(review): unlike the trigger tuples, these argument tuples
            # carry no sentence index, so identical spans in different
            # sentences are conflated — confirm this is intended.
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append(
                        (t_type_str, a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                # Discard predictions that run past the sentence.
                if t_start >= len(words) or t_end >= len(words):
                    continue
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    if a_start >= len(words) or a_end >= len(words):
                        continue
                    arguments_pred.append(
                        (t_type_str, a_start, a_end, a_type_idx))

            for w, t, t_h in zip(words, triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write('#arguments#{}\n'.format(arguments['events']))
            fout.write('#arguments_hat#{}\n'.format(arguments_hat['events']))
            fout.write("\n")

    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true,
                                                   triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r,
                                                 trigger_f1))

    print('[argument classification]')
    argument_p, argument_r, argument_f1 = calc_metric(arguments_true,
                                                      arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))
    print('[trigger identification]')
    # Identification ignores the event/role type.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true,
                                                      triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_,
                                                 trigger_f1_))

    print('[argument identification]')
    arguments_true = [(item[0], item[1], item[2]) for item in arguments_true]
    arguments_pred = [(item[0], item[1], item[2]) for item in arguments_pred]
    argument_p_, argument_r_, argument_f1_ = calc_metric(
        arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_,
                                                 argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p, argument_r, argument_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)
    metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p_, argument_r_, argument_f1_)
    final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    # BUG FIX: the temp read handle was never closed; use `with`.
    with open(final, 'w', encoding='utf-8') as fout, \
            open("temp", "r", encoding='utf-8') as fin:
        fout.write("{}\n".format(fin.read()))
        fout.write(metric)
    os.remove("temp")
    return metric, trigger_f1, argument_f1
Example #7
0
def eval(model, iterator, fname):
    """Evaluate trigger prediction for the context-sentence model.

    Builds previous/next-sentence padding flags per batch, runs
    ``model.predict_triggers``, prints trigger classification and
    identification P/R/F1, and writes predictions + metrics to a file
    derived from *fname*.

    Returns:
        dict: {"trigger classification": [P, R, F1],
               "trigger identification": [P, R, F1]}
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d, \
            pre_sent_tokens_x, next_sent_tokens_x, pre_sent_len, next_sent_len, maxlen = batch

            pre_sent_flags = []
            next_sent_flags = []

            pre_sent_len_mat = []
            next_sent_len_mat = []

            # BUG FIX: the original reused `i` as the inner loop variable,
            # clobbering the batch index from enumerate(). Use distinct names.
            # NOTE(review): 768 is presumably the hidden size — TODO confirm
            # and hoist to a named constant/model attribute.
            for plen in pre_sent_len:
                tmp = [[1] * 768] * plen + [[0] * 768] * (maxlen - plen)
                pre_sent_flags.append(tmp)
                pre_sent_len_mat.append([plen] * 768)

            for nlen in next_sent_len:
                tmp = [[1] * 768] * nlen + [[0] * 768] * (maxlen - nlen)
                next_sent_flags.append(tmp)
                next_sent_len_mat.append([nlen] * 768)

            trigger_logits, triggers_y_2d, trigger_hat_2d = model.predict_triggers(
                tokens_x_2d=tokens_x_2d,
                entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d,
                head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d,
                arguments_2d=arguments_2d,
                pre_sent_tokens_x=pre_sent_tokens_x,
                next_sent_tokens_x=next_sent_tokens_x,
                pre_sent_flags=pre_sent_flags,
                next_sent_flags=next_sent_flags,
                pre_sent_len_mat=pre_sent_len_mat,
                next_sent_len_mat=next_sent_len_mat)

            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)

    triggers_true, triggers_pred = [], []
    with open('temp', 'w', encoding='utf-8') as fout:
        for i, (words, triggers, triggers_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all)):
            # Predictions are padded; trim to the sentence length.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item)
                                  for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item)
                                  for item in find_triggers(triggers_hat)])

            # words[1:-1] skips the [CLS]/[SEP] wrapper tokens.
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write("\n")

    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true,
                                                   triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r,
                                                 trigger_f1))

    print('[trigger identification]')
    # Identification ignores the event type.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true,
                                                      triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_,
                                                 trigger_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)

    final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)

    metric_2 = {
        "trigger classification": [trigger_p, trigger_r, trigger_f1],
        "trigger identification": [trigger_p_, trigger_r_, trigger_f1_]
    }

    # BUG FIX: temp was written with utf-8 but read back with the platform
    # default encoding (decode errors on non-UTF-8 locales), and the read
    # handle was never closed. Read/write explicitly in utf-8 via `with`.
    with open(final, 'w', encoding='utf-8') as fout, \
            open("temp", "r", encoding='utf-8') as fin:
        fout.write("{}\n".format(fin.read()))
        fout.write(metric)
    os.remove("temp")
    return metric_2
def eval(model, iterator, fname):
    """Evaluate joint trigger and argument extraction on ``iterator``.

    Runs ``model`` in inference mode, collects gold and predicted
    triggers/arguments across all batches, writes a per-token dump plus
    a metric summary to a report file, and prints precision/recall/F1
    for trigger/argument classification and identification.

    Args:
        model: wrapped model (accessed via ``model.module``, so presumably
            a ``DataParallel``-style container — TODO confirm) exposing
            ``predict_triggers`` and ``predict_arguments``.
        iterator: batch iterator yielding the 9-tuple unpacked below.
        fname: prefix for the output report file; the trigger P/R/F1 is
            appended to the final filename.

    Returns:
        str: human-readable metric summary (also written to the report).
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d = batch

            trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers(
                tokens_x_2d=tokens_x_2d,
                entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d,
                head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d,
                arguments_2d=arguments_2d)

            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)

            if len(argument_keys) > 0:
                argument_logits, arguments_y_1d, argument_hat_1d, argument_hat_2d = model.module.predict_arguments(
                    argument_hidden, argument_keys, arguments_2d)
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # No candidate trigger/argument pairs in this batch:
                # predict an empty event set for every sentence.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w') as fout:
        for i, (words, triggers, triggers_hat, arguments,
                arguments_hat) in enumerate(
                    zip(words_all, triggers_all, triggers_hat_all,
                        arguments_all, arguments_hat_all)):
            # Trim padding positions, then map label ids back to strings.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item)
                                  for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item)
                                  for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_pred.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            # words[1:-1] skips the boundary tokens (presumably [CLS]/[SEP]
            # — TODO confirm against the data loader).
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write('#arguments#{}\n'.format(arguments['events']))
            fout.write('#arguments_hat#{}\n'.format(arguments_hat['events']))
            fout.write("\n")

    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true,
                                                   triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r,
                                                 trigger_f1))

    print('[argument classification]')
    argument_p, argument_r, argument_f1 = calc_metric(arguments_true,
                                                      arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))

    # Identification drops the type fields: a span match alone counts.
    print('[trigger identification]')
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true,
                                                      triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_,
                                                 trigger_f1_))

    print('[argument identification]')
    arguments_true = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_true]
    arguments_pred = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_pred]
    argument_p_, argument_r_, argument_f1_ = calc_metric(
        arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_,
                                                 argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p, argument_r, argument_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)
    metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p_, argument_r_, argument_f1_)
    final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    with open(final, 'w') as fout:
        # Fix: read the temp dump via a context manager so the handle is
        # closed (the original left an unclosed open("temp") descriptor).
        with open("temp", "r") as fin:
            result = fin.read()
        fout.write("{}\n".format(result))
        fout.write(metric)
    os.remove("temp")
    return metric
Beispiel #9
0
def eval_module(model, iterator, fname, module, idx2argument):
    """Evaluate a single argument-role module under the meta-classifier.

    Runs ``model`` in inference mode, predicting triggers and then
    module-specific arguments for one argument role (``module``).
    Gold arguments are kept only when their role matches ``module``
    (a "relaxed" metric; the strict variant is kept commented out).
    Writes a per-sentence dump of tokens, triggers and events to a
    temporary file ``temp`` — note this version never removes it
    (removal is commented out below).

    Args:
        model: wrapped model accessed via ``model.module`` (presumably a
            ``DataParallel``-style container — TODO confirm) exposing
            ``predict_triggers``, ``module_predict_arguments`` and
            ``meta_classifier``.
        iterator: batch iterator yielding the 9-tuple unpacked below.
        fname: report filename prefix (unused here; the report-writing
            code is commented out).
        module: name of the argument role this module network handles.
        idx2argument: mapping from argument-type index to role name.

    Returns:
        tuple: (metric summary string, argument F1, num_proposed,
        num_correct, num_gold) from the relaxed argument metric.
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d = batch

            trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys, trigger_info, auxiliary_feature = model.module.predict_triggers(
                tokens_x_2d=tokens_x_2d,
                entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d,
                head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d,
                arguments_2d=arguments_2d)

            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)

            if len(argument_keys) > 0:
                # Module-specific argument scoring, then the meta-classifier
                # combines module logits with trigger/auxiliary features.
                argument_logits, arguments_y_1d, argument_hat_1d, argument_hat_2d = model.module.module_predict_arguments(
                    argument_hidden, argument_keys, arguments_2d, module)
                module_decisions_logit, module_decisions_y, argument_hat_2d = model.module.meta_classifier(
                    argument_keys, arguments_2d, trigger_info, argument_logits,
                    argument_hat_1d, auxiliary_feature, module)
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # No candidate trigger/argument pairs in this batch:
                # predict an empty event set for every sentence.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w') as fout:
        for i, (words, triggers, triggers_hat, arguments,
                arguments_hat) in enumerate(
                    zip(words_all, triggers_all, triggers_hat_all,
                        arguments_all, arguments_hat_all)):
            # Trim padding positions, then map label ids back to strings.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item)
                                  for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item)
                                  for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    # strict metric
                    #arguments_true.append((i, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))
                    # relaxed metric
                    # Keep only gold arguments whose role matches this
                    # module; the constant 2 stands in for that role's
                    # index inside the module network.
                    if idx2argument[a_type_idx] == module:
                        arguments_true.append(
                            (i, t_type_str, a_start, a_end, 2))
                    #else:
                    #  arguments_true.append((i, t_type_str, a_start, a_end, 1))

            #print(arguments_hat)
            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    # stric metric
                    # arguments_pred.append((i, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))
                    # relaxed metric
                    #if idx2argument[a_type_idx] == module:
                    arguments_pred.append(
                        (i, t_type_str, a_start, a_end, a_type_idx
                         ))  # 2 is the specific argument idx in module network
                    # else:
                    #   print(idx2argument[a_type_idx])
                    #   arguments_pred.append((i, t_type_str, a_start, a_end, 1))

            # if len(arguments_pred) == 0:
            #   print('---batch {} -----'.format(i))
            #   print(arguments_hat)

            # words[1:-1] skips the boundary tokens (presumably [CLS]/[SEP]
            # — TODO confirm against the data loader).
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write('#arguments#{}\n'.format(arguments['events']))
            fout.write('#arguments_hat#{}\n'.format(arguments_hat['events']))
            fout.write("\n")

    # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred]))

    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true,
                                                   triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r,
                                                 trigger_f1))

    print('[argument classification]')
    # num_flag=True additionally returns the raw proposed/correct/gold
    # counts so callers can aggregate across modules.
    argument_p, argument_r, argument_f1, num_proposed, num_correct, num_gold = calc_metric(
        arguments_true, arguments_pred, num_flag=True)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))

    #print('[trigger identification]')
    # triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    # triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    # trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred)
    #print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))

    #print('[argument identification]')
    # strcit metric
    #arguments_true = [(item[0], item[1], item[2], item[3], item[4], item[5]) for item in arguments_true]
    #arguments_pred = [(item[0], item[1], item[2], item[3], item[4], item[5]) for item in arguments_pred]
    # relax metric
    # arguments_true = [(item[0], item[1], item[2], item[3]) for item in arguments_true]
    # arguments_pred = [(item[0], item[1], item[2], item[3]) for item in arguments_pred]
    # argument_p_, argument_r_, argument_f1_ = calc_metric(arguments_true, arguments_pred)
    #print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_, argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    # metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p, argument_r, argument_f1)
    # metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(trigger_p_, trigger_r_, trigger_f1_)
    # metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p_, argument_r_, argument_f1_)
    # final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    # with open(final, 'w') as fout:
    #     result = open("temp", "r").read()
    #     fout.write("{}\n".format(result))
    #     fout.write(metric)
    # os.remove("temp")
    return metric, argument_f1, num_proposed, num_correct, num_gold  #,arguments_true, arguments_pred