Example #1
    # requires torch; Pack is a project-level container (see the sketch below)
    def generate(self, batch_iter, num_batches):
        self.model.eval()
        itoemo = ['NORM', 'POS', 'NEG']  # index -> emotion label
        with torch.no_grad():
            results = []
            for inputs in batch_iter:
                enc_inputs = inputs
                dec_inputs = inputs.num_tgt_input
                enc_outputs = Pack()
                outputs = self.model(enc_inputs, dec_inputs, hidden=None)
                logits = outputs.logits
                # max over the label dimension returns (values, indices);
                # the indices are the predicted emotion ids per token
                _, preds = logits.max(dim=2)
                preds = preds.tolist()
                tgt_raw = inputs.raw_tgt

                pred_tokens = []   # target tokens, one list per example
                pred_emos = []     # predicted emotion labels per token
                target_emos = []   # gold emotion labels per token
                tgt_emo = inputs.tgt_emo[0].tolist()
                for a, b, c in zip(tgt_raw, preds, tgt_emo):
                    a = a[1:]  # drop the leading BOS token
                    temp_a = []
                    emo_b = []
                    emo_c = []

                    for i, entity in enumerate(a):
                        temp_a.append(entity)
                        emo_b.append(itoemo[b[i]])
                        emo_c.append(itoemo[c[i]])

                    assert len(temp_a) == len(emo_b)
                    assert len(emo_c) == len(emo_b)
                    pred_tokens.append([temp_a])
                    pred_emos.append([emo_b])
                    target_emos.append(emo_c)

                # keep the example id when the batch provides one
                if hasattr(inputs, 'id') and inputs.id is not None:
                    enc_outputs.add(id=inputs.id)
                enc_outputs.add(tgt=tgt_raw, preds=pred_tokens,
                                emos=pred_emos, target_emos=target_emos)
                results += enc_outputs.flatten()
            return results
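Both examples rely on a Pack container with attribute-style access, an add() method, and a flatten() that splits a batched pack into per-example packs. The project defines this class itself; a minimal sketch of the assumed interface (a dict subclass, as in similar dialogue-generation codebases; the real class may differ):

class Pack(dict):
    """Dict with attribute access that can flatten a batch into examples."""

    def __getattr__(self, name):
        return self[name]

    def add(self, **kwargs):
        # attach named fields, e.g. pack.add(preds=..., emos=...)
        for k, v in kwargs.items():
            self[k] = v

    def flatten(self):
        # turn {key: [v_1, ..., v_B]} into [Pack(key=v_1), ..., Pack(key=v_B)]
        pack_list = []
        for vals in zip(*self.values()):
            pack_list.append(Pack(zip(self.keys(), vals)))
        return pack_list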
Example #2
# requires torch and nltk (sentence_bleu, SmoothingFunction); Pack and
# compute_prf are project-level utilities
def reward_fn1(self, preds, targets, gold_ents, ptr_index, task_label):
    """
    General reward: the per-example average of smoothed BLEU-1 and BLEU-2;
    entity F1 is also computed for reporting.
    """
    # reward weights: alpha1 scales BLEU; alpha2 is reserved for the
    # currently disabled entity-F1 term
    alpha1 = 1.0
    alpha2 = 0.3

    # acc reward (disabled): a padding-masked token-accuracy reward over
    # preds vs. targets was previously computed here

    # convert token ids back to plain text for BLEU / entity-F1 scoring
    pred_text = self.tgt_field.denumericalize(preds)
    tgt_text = self.tgt_field.denumericalize(targets)
    batch_size = targets.size(0)
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    kb_plain = self.kb_field.denumericalize(batch_kb_inputs)

    result = Pack()
    result.add(pred_text=pred_text, tgt_text=tgt_text, gold_ents=gold_ents, kb_plain=kb_plain)
    result_list = result.flatten()

    # bleu reward
    bleu_score = []
    for res in result_list:
        hyp_toks = res.pred_text.split()
        ref_toks = res.tgt_text.split()
        try:
            bleu_1 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                                   smoothing_function=SmoothingFunction().method7,
                                   weights=[1, 0, 0, 0])
        except Exception:
            bleu_1 = 0.0
        try:
            bleu_2 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                                   smoothing_function=SmoothingFunction().method7,
                                   weights=[0.5, 0.5, 0, 0])
        except Exception:
            bleu_2 = 0.0
        bleu = (bleu_1 + bleu_2) / 2
        bleu_score.append(bleu)
    bleu_score = torch.tensor(bleu_score, dtype=torch.float)

    # entity f1 reward
    f1_score = []
    report_f1 = []
    for res in result_list:
        if len(res.gold_ents) == 0:
            # no gold entities: count as perfect for the reward, but leave it
            # out of the reported F1 average
            f1_pred = 1.0
        else:
            # TODO: consider matching entities with underscores split into
            # separate tokens
            gold_entity = res.gold_ents
            pred_sent = res.pred_text
            f1_pred, _ = compute_prf(gold_entity, pred_sent,
                                     global_entity_list=[], kb_plain=res.kb_plain)
            report_f1.append(f1_pred)
        f1_score.append(f1_pred)
    if len(report_f1) == 0:
        report_f1.append(0.0)
    f1_score = torch.tensor(f1_score, dtype=torch.float)
    report_f1 = torch.tensor(report_f1, dtype=torch.float)

    if self.use_gpu:
        bleu_score = bleu_score.cuda()
        f1_score = f1_score.cuda()
        report_f1 = report_f1.cuda()

    # compound reward: currently BLEU only; add alpha2 * f1_score.unsqueeze(-1)
    # to restore the entity-F1 term
    reward = alpha1 * bleu_score.unsqueeze(-1)

    return reward, bleu_score, report_f1
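The BLEU part of the reward can be sanity-checked in isolation with NLTK. A minimal, self-contained sketch over toy strings (reward_fn1 itself additionally needs the numericalized fields, the KB tensor, and the project's compute_prf):

import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def bleu_reward(pred_texts, tgt_texts):
    # average of smoothed BLEU-1 and BLEU-2, mirroring reward_fn1
    smooth = SmoothingFunction().method7
    scores = []
    for hyp, ref in zip(pred_texts, tgt_texts):
        hyp_toks, ref_toks = hyp.split(), ref.split()
        try:
            bleu_1 = sentence_bleu([ref_toks], hyp_toks, weights=[1, 0, 0, 0],
                                   smoothing_function=smooth)
            bleu_2 = sentence_bleu([ref_toks], hyp_toks, weights=[0.5, 0.5, 0, 0],
                                   smoothing_function=smooth)
        except Exception:
            bleu_1 = bleu_2 = 0.0
        scores.append((bleu_1 + bleu_2) / 2)
    return torch.tensor(scores, dtype=torch.float)

# bleu_reward(["the hotel is near"], ["the hotel is near the mall"])
# returns a (batch,) tensor; reward_fn1 unsqueezes it to shape (batch, 1).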