Example #1
def evaluate_ppl(model, data_iter, verbose=True):
    """Compute the average per-batch perplexity of `model` over `data_iter`."""
    total_ppl = 0
    total_nums = 0
    padding_idx = model.padding_idx

    for batch_id, inputs in enumerate(data_iter, 1):

        enc_inputs = inputs
        # Teacher forcing: the decoder sees the target shifted right (last token
        # dropped), with lengths reduced by one.
        dec_inputs = inputs.tgt[0][:, :-1], inputs.tgt[1] - 1
        logits = model(enc_inputs, dec_inputs).logits
        # Gold targets are the decoder inputs shifted left by one.
        tgt = inputs.tgt[0][:, 1:]
        batch_ppl = perplexity(logits=logits, targets=tgt, padding_idx=padding_idx)

        # `perplexity` is assumed to return one value per sample; average it so
        # the running total stays a plain float.
        total_ppl += batch_ppl.mean().item()
        total_nums += 1

    ppl = total_ppl / total_nums
    message = f"PPL: {ppl:.4f}"
    if verbose:
        print(message)
    return message
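
Both examples rely on a project-level `perplexity` helper that is not shown here. Below is a minimal sketch of what such a helper could look like, assuming `logits` holds unnormalized scores of shape `(batch, seq_len, vocab)` and that one perplexity value is returned per sample; if the model already emits log-probabilities, `F.nll_loss` would take the place of `F.cross_entropy`. The signature mirrors the calls above but is an assumption, not the project's actual API.

import torch
import torch.nn.functional as F

def perplexity(logits, targets, weight=None, padding_idx=None):
    """Hypothetical per-sample perplexity: exp of the length-normalized token NLL."""
    # logits: (batch, seq_len, vocab) unnormalized scores; targets: (batch, seq_len)
    nll = F.cross_entropy(logits.transpose(1, 2), targets,
                          weight=weight, reduction='none')      # (batch, seq_len)
    mask = targets.ne(padding_idx).float() if padding_idx is not None \
        else torch.ones_like(nll)
    nll = (nll * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return nll.exp()  # one perplexity value per sample
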
Example #2

    def collect_metrics(self, outputs, target, epoch=-1):
        """
        collect_metrics
        """
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        # test begin
        # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
        # loss += nll
        # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
        # metrics.add(attn_acc=attn_acc)
        # metrics.add(loss=loss)
        # return metrics
        # test end

        logits = outputs.logits
        # Per-sample (unreduced) NLL, negated so that higher scores are better.
        scores = -self.nll_loss(logits, target, reduction=False)
        nll_loss = self.nll_loss(logits, target)
        # Count non-padding tokens so the NLL can be normalized per word.
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll_loss, num_words), acc=acc)

        if self.use_posterior:
            # KL between the prior and the (detached) posterior knowledge-attention
            # distributions.
            kl_loss = self.kl_loss(torch.log(outputs.prior_attn + 1e-10),
                                   outputs.posterior_attn.detach())
            metrics.add(kl=kl_loss)
            if self.use_bow:
                # Bag-of-words auxiliary loss: the BoW distribution is repeated
                # across the target length so NLL can be computed per target token.
                bow_logits = outputs.bow_logits
                bow_labels = target[:, :-1]
                bow_logits = bow_logits.repeat(1, bow_labels.size(-1), 1)
                bow = self.nll_loss(bow_logits, bow_labels)
                loss += bow
                metrics.add(bow=bow)
            if self.use_dssm:
                # DSSM auxiliary objectives: MSE between the DSSM vector and the
                # detached reply vector, plus a binary positive/negative reply
                # discrimination loss below.
                mse = self.mse_loss(outputs.dssm, outputs.reply_vec.detach())
                loss += mse
                metrics.add(mse=mse)
                pos_logits = outputs.pos_logits
                pos_target = torch.ones_like(pos_logits)
                neg_logits = outputs.neg_logits
                neg_target = torch.zeros_like(neg_logits)
                pos_loss = F.binary_cross_entropy_with_logits(pos_logits,
                                                              pos_target,
                                                              reduction='none')
                neg_loss = F.binary_cross_entropy_with_logits(neg_logits,
                                                              neg_target,
                                                              reduction='none')
                loss += (pos_loss + neg_loss).mean()
                metrics.add(pos_loss=pos_loss.mean(), neg_loss=neg_loss.mean())

            # Add the main NLL and KL terms once auxiliary pretraining is over
            # (or when no BoW/DSSM pretraining task is configured).
            if epoch == -1 or epoch > self.pretrain_epoch or \
               (self.use_bow is not True and self.use_dssm is not True):
                loss += nll_loss
                loss += kl_loss
                if self.use_pg:
                    # REINFORCE-style term: the probability of the sampled
                    # knowledge index is weighted by a centered reward derived
                    # from the (negative, scaled) perplexity below.
                    posterior_probs = outputs.posterior_attn.gather(
                        1, outputs.indexs.view(-1, 1))
                    reward = -perplexity(logits, target, self.weight,
                                         self.padding_idx) * 100
                    pg_loss = -(reward.detach() -
                                self.baseline) * posterior_probs.view(-1)
                    pg_loss = pg_loss.mean()
                    loss += pg_loss
                    metrics.add(pg_loss=pg_loss, reward=reward.mean())
            if 'attn_index' in outputs:
                attn_acc = attn_accuracy(outputs.posterior_attn,
                                         outputs.attn_index)
                metrics.add(attn_acc=attn_acc)
        else:
            loss += nll_loss

        metrics.add(loss=loss)
        return metrics, scores
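
For context, here is a hedged sketch of how a training step might drive `collect_metrics`. The `iterate` name, the `optimizer`/`grad_clip` arguments, and the attribute-style `metrics.loss` access are assumptions inferred from the call pattern in the examples above, not the project's confirmed interface; `torch` is assumed to be imported at module level.

    def iterate(self, inputs, optimizer=None, grad_clip=None,
                is_training=True, epoch=-1):
        """Hypothetical train/eval step built around collect_metrics."""
        enc_inputs = inputs
        # Teacher forcing: decoder input drops the last token, targets drop the first.
        dec_inputs = inputs.tgt[0][:, :-1], inputs.tgt[1] - 1
        target = inputs.tgt[0][:, 1:]

        outputs = self.forward(enc_inputs, dec_inputs)
        metrics, scores = self.collect_metrics(outputs, target, epoch=epoch)

        if is_training:
            assert optimizer is not None
            optimizer.zero_grad()
            metrics.loss.backward()
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=grad_clip)
            optimizer.step()
        return metrics, scores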