示例#1
0
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None, frames=None, args=None, using_gold_fame=False):
        """Score frames (from the pooled output) and arguments (per token).

        Args:
            input_ids, token_type_ids, attention_mask: passed positionally to ``self.bert``.
            lus: lexical-unit ids; combined with ``self.lufrmap`` by ``dataio.get_masks``
                to weight the frame loss per example.
            frames: gold frame ids. When given, a loss is returned instead of logits.
            args: gold per-token argument labels; indexed alongside ``frames``, so it
                must be supplied whenever ``frames`` is.
            using_gold_fame: when True, the argument-label mask is built from the gold
                frame instead of the predicted one. The default False keeps masking
                with the frame predicted by ``self.logit2label``.

        Returns:
            ``(0.5 * frame_loss + 0.5 * arg_loss) / batch_size`` when ``frames`` is given,
            otherwise the tuple ``(frame_logits, arg_logits)``.
        """
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        pooled_output = self.dropout(pooled_output)

        frame_logits = self.frame_classifier(pooled_output)
        arg_logits = self.arg_classifier(sequence_output)

        if frames is None:
            return frame_logits, arg_logits

        # The LU -> candidate-frame masks are only used by the loss, so build them here.
        lufr_masks = dataio.get_masks(lus, self.lufrmap, num_label=self.num_frames).to(device)

        frame_loss = 0  # loss for frame identification
        arg_loss = 0  # loss for argument identification
        for i in range(len(frame_logits)):
            frame_logit = frame_logits[i]
            arg_logit = arg_logits[i]
            lufr_mask = lufr_masks[i]
            gold_frame = frames[i]
            gold_arg = args[i]

            # Frame classifier: per-class weights come from the LU's candidate-frame mask.
            loss_fct_frame = CrossEntropyLoss(weight=lufr_mask)
            frame_loss += loss_fct_frame(frame_logit.view(-1, self.num_frames), gold_frame.view(-1))

            # Argument classifier: its class weights come from the chosen frame's
            # argument mask (gold frame if requested, otherwise the predicted frame).
            if using_gold_fame:
                frame_for_args = gold_frame
            else:
                frame_for_args, _ = self.logit2label(frame_logit, lufr_mask)
            frarg_mask = dataio.get_masks([frame_for_args], self.frargmap, num_label=self.num_args).to(device)[0]
            loss_fct_arg = CrossEntropyLoss(weight=frarg_mask)

            flat_arg_logit = arg_logit.view(-1, self.num_args)
            flat_gold_arg = gold_arg.view(-1)
            if attention_mask is not None:
                # only keep active (attention_mask == 1) tokens in the loss
                active_loss = attention_mask[i].view(-1) == 1
                flat_arg_logit = flat_arg_logit[active_loss]
                flat_gold_arg = flat_gold_arg[active_loss]
            arg_loss += loss_fct_arg(flat_arg_logit, flat_gold_arg)

        # equally weighted frame/argument loss, averaged over the batch
        total_loss = 0.5 * frame_loss + 0.5 * arg_loss
        return total_loss / len(frame_logits)
示例#2
0
 def forward(self, input_ids, token_type_ids=None, attention_mask=None, tgt_idxs=0, lus=None, frames=None,):
     """Classify frames from each example's target-token vector plus its LU embedding.

     For every example ``i`` the BERT hidden state at ``tgt_idxs[i]`` is concatenated
     with ``self.lu_embeddings(lus)`` and passed to ``self.classifier``.

     NOTE(review): ``tgt_idxs`` is indexed per example, so the default ``0`` cannot
     actually be used; callers must pass per-example indices.

     Returns the batch-averaged cross-entropy loss (class weights from each LU's
     candidate-frame mask) when ``frames`` is given, otherwise the logits.
     """
     sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
     sequence_output = self.dropout(sequence_output)
     # gather the hidden state at each example's target position
     tgt_vec = torch.stack([seq[tgt_idxs[i]] for i, seq in enumerate(sequence_output)])
     lu_vec = self.lu_embeddings(lus)

     tgt_embs = torch.cat((tgt_vec, lu_vec), -1)
     logits = self.classifier(tgt_embs)
     if frames is None:
         return logits

     # masks are only needed for the loss, so skip building them at inference time
     masks = dataio.get_masks(lus, self.lufrmap, num_label=self.num_labels).to(device)
     total_loss = 0
     for i in range(len(logits)):
         loss_fct = CrossEntropyLoss(weight=masks[i])
         total_loss += loss_fct(logits[i].view(-1, self.num_labels), frames[i].view(-1))
     return total_loss / len(logits)
示例#3
0
 def forward(self, input_ids, token_type_ids=None, attention_mask=None, tgt_idxs=0, lus=None, frames=None, arg_idxs=None, args=None):
     """Classify argument labels from argument/target token vectors, LU and frame embeddings.

     The classifier input is ``cat(arg_vec, tgt_vec, lu_vec, frame_vec)``, where
     ``arg_vec``/``tgt_vec`` are BERT hidden states at ``arg_idxs[i]``/``tgt_idxs[i]``.
     ``frames`` is embedded unconditionally, so it is required even at inference.

     Returns the batch-averaged cross-entropy loss (class weights from each frame's
     argument mask) when ``args`` is given, otherwise the logits.
     """
     sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
     sequence_output = self.dropout(sequence_output)

     # target and argument token vectors, one per example
     tgt_vec = torch.stack([seq[tgt_idxs[i]] for i, seq in enumerate(sequence_output)])
     arg_vec = torch.stack([seq[arg_idxs[i]] for i, seq in enumerate(sequence_output)])
     lu_vec = self.lu_embeddings(lus)
     frame_vec = self.frame_embeddings(frames)
     arg_embs = torch.cat((arg_vec, tgt_vec, lu_vec, frame_vec), -1)

     logits = self.classifier(arg_embs)
     if args is None:
         return logits

     # masks are only needed for the loss, so skip building them at inference time
     masks = dataio.get_masks(frames, self.frargmap, num_label=self.num_labels).to(device)
     total_loss = 0
     for i in range(len(logits)):
         loss_fct = CrossEntropyLoss(weight=masks[i])
         total_loss += loss_fct(logits[i].view(-1, self.num_labels), args[i].view(-1))
     return total_loss / len(logits)
示例#4
0
 def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None, frames=None,):
     """Classify frames from BERT's pooled output.

     Returns the batch-averaged cross-entropy loss (class weights from each LU's
     candidate-frame mask) when ``frames`` is given, otherwise the logits.
     """
     _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
     pooled_output = self.dropout(pooled_output)
     logits = self.classifier(pooled_output)
     if frames is None:
         return logits

     # masks are only needed for the loss, so skip building them at inference time
     masks = dataio.get_masks(lus, self.lufrmap, num_label=self.num_labels).to(device)
     total_loss = 0
     for i in range(len(logits)):
         loss_fct = CrossEntropyLoss(weight=masks[i])
         total_loss += loss_fct(logits[i].view(-1, self.num_labels), frames[i].view(-1))
     return total_loss / len(logits)
def test():
    """Evaluate every ``*.pt`` checkpoint in ``model_dir + framenet`` on ``tst``.

    For each checkpoint, the predicted frame is the highest-softmax candidate among
    the frames the LU's mask allows. A ``<epoch>-result.txt`` report is written,
    containing the accuracy plus one gold/prediction/score/lu/candidates row per
    example. Finally an ``accuracy.txt`` summary is written with one line per
    checkpoint.

    Relies on module-level ``model_dir``, ``framenet``, ``tst``, ``batch_size``,
    ``bert_io``, ``dataio`` and ``device``.
    """
    import os

    model_path = model_dir + framenet + '/'
    models = glob.glob(model_path + '*.pt')
    # The test set is identical for every checkpoint, so convert it once.
    tst_data = bert_io.convert_to_bert_input_frameid(tst)
    # candidate logits are 1-D, so normalize over dim 0 (the implicit default)
    frame_softmax = nn.Softmax(dim=0)
    results = []
    for m in models:
        print('model:', m)
        model = torch.load(m)
        model.eval()

        # sampler=None -> sequential order, so reports are reproducible between runs
        tst_dataloader = DataLoader(tst_data, sampler=None, batch_size=batch_size)

        predictions, true_labels, scores, candis, all_lus = [], [], [], [], []
        for batch in tst_dataloader:
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_tgt_idxs, b_lus, b_frames, b_masks = batch

            # only the logits are needed; a separate loss forward pass was unused
            with torch.no_grad():
                logits = model(b_input_ids,
                               token_type_ids=None,
                               lus=b_lus,
                               attention_mask=b_masks)
            logits = logits.detach().cpu().numpy()
            label_ids = b_frames.to('cpu').numpy()
            # the masks are only read element-wise below, so no device transfer is needed
            masks = dataio.get_masks(b_lus,
                                     bert_io.lufrmap,
                                     num_label=len(bert_io.frame2idx))
            for lu in b_lus:
                candi = [bert_io.idx2frame[c] for c in bert_io.lufrmap[str(int(lu))]]
                candis.append(str(len(candi)) + '\t' + ','.join(candi))
                all_lus.append(bert_io.idx2lu[int(lu)])

            for logit, mask in zip(logits, masks):
                # restrict the prediction to the frames allowed by this LU's mask
                cand_idxs = [fr_idx for fr_idx in range(len(mask)) if mask[fr_idx] > 0]
                cand_logits = torch.tensor([logit[fr_idx].item() for fr_idx in cand_idxs])
                score, indice = frame_softmax(cand_logits).view(1, -1).max(1)
                predictions.append([cand_idxs[int(indice)]])
                scores.append(float(score))
            true_labels.append(label_ids)

        pred_tags = [bert_io.idx2frame[p_i] for p in predictions for p_i in p]
        valid_tags = [
            bert_io.idx2frame[l_ii] for l in true_labels for l_i in l
            for l_ii in l_i
        ]

        acc = accuracy_score(pred_tags, valid_tags)
        print("Accuracy: {}".format(acc))
        results.append(m + '\t' + str(acc) + '\n')

        # Assumes the epoch is the second '-'-separated field of the file name.
        # Split the base name only, so a '-' in the directory path cannot shift it.
        epoch = os.path.basename(m).split('-')[1]
        fname = model_path + str(epoch) + '-result.txt'
        with open(fname, 'w', encoding='utf-8') as f:
            f.write('accuracy: ' + str(acc) + '\n\n')
            f.write('gold' + '\t' + 'prediction' + '\t' + 'score' + '\t' + 'lu' + '\t' + 'candis' + '\n')
            for gold, pred, score, lu_txt, candi_txt in zip(valid_tags, pred_tags, scores, all_lus, candis):
                f.write(gold + '\t' + pred + '\t' + str(score) + '\t' + lu_txt + '\t' + candi_txt + '\n')
    fname = model_path + 'accuracy.txt'
    with open(fname, 'w', encoding='utf-8') as f:
        f.writelines(results)

    print('result is written to', fname)
    def test(self, tst=False, model_dir='.', MAX_LEN=256, batch_size=8):
        """Evaluate the saved frame-identification model on ``tst`` and return its accuracy.

        Loads ``<model_dir>/<framenet>-frameid-<version>.pt``. For each example it
        predicts the highest-softmax frame among the candidates allowed by the LU's
        mask, then scores predictions against the gold frames.

        Args:
            tst: test data handed to ``frameid.gen_bert_input_representation``.
                NOTE(review): the default ``False`` is unlikely to be a valid input.
            model_dir: directory holding the saved model.
            MAX_LEN: maximum sequence length forwarded to the input builder.
            batch_size: batch size forwarded to the input builder.

        Returns:
            ``accuracy_score`` over all predicted vs. gold frame names.
        """
        model_path = model_dir + '/' + self.framenet + '-frameid-' + str(
            self.version) + '.pt'
        print('your model is', model_path)
        model = torch.load(model_path)
        model.eval()

        # Forward the caller's settings (they were previously ignored in favor of 256 / 8).
        tst_data, tst_sampler, tst_dataloader = frameid.gen_bert_input_representation(
            self, tst, MAX_LEN=MAX_LEN, batch_size=batch_size)

        # candidate logits are 1-D, so normalize over dim 0 (the implicit default)
        frame_softmax = nn.Softmax(dim=0)
        predictions, true_labels = [], []
        for batch in tst_dataloader:
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_tgt_idxs, b_lus, b_frames, b_masks = batch

            # only the logits are needed; the loss/accuracy accumulators were never used
            with torch.no_grad():
                logits = model(b_input_ids,
                               token_type_ids=None,
                               tgt_idxs=b_tgt_idxs,
                               lus=b_lus,
                               attention_mask=b_masks)
            logits = logits.detach().cpu().numpy()
            label_ids = b_frames.to('cpu').numpy()
            # the masks are only read element-wise below, so no device transfer is needed
            masks = dataio.get_masks(b_lus,
                                     self.lufrmap,
                                     num_label=len(self.frame2idx))

            for logit, mask in zip(logits, masks):
                cand_idxs = [fr_idx for fr_idx in range(len(mask)) if mask[fr_idx] > 0]
                # logits are indexed as logit[0][frame]; the model output appears to
                # carry an extra singleton axis here
                cand_logits = torch.tensor([logit[0][fr_idx].item() for fr_idx in cand_idxs])
                _, indice = frame_softmax(cand_logits).view(1, -1).max(1)
                predictions.append([cand_idxs[int(indice)]])
            true_labels.append(label_ids)

        pred_tags = [self.idx2frame[p_i] for p in predictions for p_i in p]
        valid_tags = [
            self.idx2frame[l_ii] for l in true_labels for l_i in l
            for l_ii in l_i
        ]
        acc = accuracy_score(pred_tags, valid_tags)
        print("Accuracy: {}".format(acc))

        return acc
    def joint_parser(self, text):
        """Parse ``text`` into CoNLL-style entries annotated with frames and arguments.

        Pipeline: ``dataio.preprocessor`` -> ``targetid.baseline`` (target ID) ->
        ``data2tgt_data`` (wraps targets in <tgt> ... </tgt>) -> ``self.model``, which
        yields frame logits and per-token argument logits.

        Each entry of ``tid_data`` gets two lists appended in place:
          * a frame sequence, with the predicted frame at target positions and '_' elsewhere;
          * the predicted BIO argument tags, one per original token.

        Returns:
            The list of annotated entries; empty when no targets are found.
        """
        conll_data = dataio.preprocessor(text)
        tid_data = targetid.baseline(conll_data)
        tgt_data = data2tgt_data(tid_data, mode='parse')

        result = []
        if not tgt_data:
            return result

        bert_inputs = self.bert_io.convert_to_bert_input_JointFrameParsing(
            tgt_data)
        loader = DataLoader(bert_inputs, sampler=None, batch_size=1)

        frame_preds, arg_preds = [], []
        for batch in loader:
            batch = tuple(t.to(device) for t in batch)
            ids_t, tok_maps, lus_t, attn = batch

            with torch.no_grad():
                frame_out, arg_out = self.model(ids_t,
                                                token_type_ids=None,
                                                lus=lus_t,
                                                attention_mask=attn)
            frame_out = frame_out.detach().cpu().numpy()
            arg_out = arg_out.detach().cpu().numpy()
            ids_np = ids_t.to('cpu').numpy()
            lu_masks = dataio.get_masks(
                lus_t,
                self.bert_io.lufrmap,
                num_label=len(self.bert_io.frame2idx)).to(device)

            for b, frame_logit in enumerate(frame_out):
                ids_row = ids_np[b]
                tok_map = tok_maps[b]

                pred_frame, _ = logit2label(frame_logit, lu_masks[b])
                arg_mask = dataio.get_masks(
                    [pred_frame],
                    self.bert_io.bio_frargmap,
                    num_label=len(self.bert_io.bio_arg2idx)).to(device)[0]

                # one argument label per BERT sub-token
                bert_arg_labels = []
                for tok_logit in arg_out[b]:
                    label, _ = logit2label(tok_logit, arg_mask)
                    bert_arg_labels.append(int(label))

                # map back to original tokens; -1 entries and token ids 1 and 2
                # are skipped (presumably padding/special tokens - TODO confirm)
                word_args = []
                for idx in tok_map:
                    if idx == -1:
                        continue
                    if int(ids_row[idx]) in (1, 2):
                        continue
                    word_args.append(bert_arg_labels[idx])

                frame_preds.append([int(pred_frame)])
                arg_preds.append(word_args)

        frame_tags = [
            self.bert_io.idx2frame[f] for group in frame_preds for f in group
        ]
        arg_tags = [[self.bert_io.idx2bio_arg[a] for a in seq]
                    for seq in arg_preds]

        for n, arg_tag in enumerate(arg_tags):
            conll = tid_data[n]
            frame_seq = ['_'] * len(conll[0])
            # put the predicted frame wherever the target column is not '_'
            for pos, tgt_mark in enumerate(conll[1]):
                if tgt_mark != '_':
                    frame_seq[pos] = frame_tags[n]
            conll.append(frame_seq)
            conll.append(arg_tag)
            result.append(conll)

        return result