def collect_metrics(self, outputs, target, epoch=-1):
    """Assemble per-batch training metrics and the combined loss.

    Returns a ``(metrics, scores)`` pair where ``metrics`` is a ``Pack``
    holding ``nll``, ``acc``, optional persona metrics, and ``loss``,
    and ``scores`` are the per-sample (unreduced) negative NLL values.

    NOTE(review): another ``collect_metrics`` definition appears later in
    this file and will shadow this one at class-creation time — confirm
    which implementation is the intended one.
    """
    batch_size = target.size(0)
    metrics = Pack(num_samples=batch_size)
    total_loss = 0

    logits = outputs.logits
    # Unreduced scores: higher means the model assigns more probability
    # to the reference target.
    scores = -self.nll_loss(logits, target, reduction=False)
    nll = self.nll_loss(logits, target)
    token_count = target.ne(self.padding_idx).sum().item()
    metrics.add(nll=(nll, token_count),
                acc=accuracy(logits, target, padding_idx=self.padding_idx))

    if 'attn_index' in outputs:
        # Persona-selection supervision: compare cue attention against the
        # gold persona index.
        metrics.add(
            attn_acc=attn_accuracy(outputs.cue_attn, outputs.attn_index))
        # cue_attn: (batch_size, sent_num); attn_index: (batch_size)
        selection_loss = self.persona_loss(
            torch.log(outputs.cue_attn + self.eps), outputs.attn_index)
        metrics.add(use_per_loss=selection_loss)
        # Fixed 0.7 / 0.3 mixing of persona-selection and generation losses.
        total_loss += 0.7 * selection_loss + 0.3 * nll
    else:
        total_loss += nll

    metrics.add(loss=total_loss)
    return metrics, scores
def collect_metrics(self, outputs, target, epoch=-1):
    """Assemble per-batch training metrics and the combined loss.

    Combines the generation NLL with, depending on configuration flags:
    a prior/posterior KL term (``use_posterior``), a bag-of-words
    auxiliary loss (``use_bow``), a DSSM matching loss with
    positive/negative BCE terms (``use_dssm``), and a policy-gradient
    term (``use_pg``).

    Args:
        outputs: ``Pack``-like object carrying model tensors
            (``logits``, and — configuration-dependent — ``prior_attn``,
            ``posterior_attn``, ``bow_logits``, ``dssm``, ``reply_vec``,
            ``pos_logits``, ``neg_logits``, ``indexs``, ``attn_index``).
        target: (batch_size, seq_len) reference token ids.
        epoch: current epoch; ``-1`` (default) means "past pretraining",
            enabling the NLL + KL (+ PG) branch unconditionally.

    Returns:
        (metrics, scores): the populated ``Pack`` and the per-sample
        (unreduced) negative NLL scores.
    """
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)
    loss = 0
    # BUGFIX: kl_loss was previously only bound inside the
    # ``use_posterior`` branch but added to ``loss`` unconditionally in
    # the post-pretrain branch below, raising NameError when
    # ``use_posterior`` is False. Default to 0 so that configuration runs.
    kl_loss = 0

    logits = outputs.logits
    # Unreduced scores: higher means more probability on the reference.
    scores = -self.nll_loss(logits, target, reduction=False)
    nll_loss = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll_loss, num_words), acc=acc)

    if self.use_posterior:
        # KL(posterior || prior); posterior is detached so the gradient
        # only flows into the prior attention.
        kl_loss = self.kl_loss(torch.log(outputs.prior_attn + 1e-10),
                               outputs.posterior_attn.detach())
        metrics.add(kl=kl_loss)

    if self.use_bow:
        # Bag-of-words auxiliary loss: the single BoW distribution is
        # repeated across the (shifted) target positions.
        bow_logits = outputs.bow_logits
        bow_labels = target[:, :-1]
        bow_logits = bow_logits.repeat(1, bow_labels.size(-1), 1)
        bow = self.nll_loss(bow_logits, bow_labels)
        loss += bow
        metrics.add(bow=bow)

    if self.use_dssm:
        # Match the DSSM vector to the (detached) reply representation,
        # plus a BCE ranking signal on positive/negative pairs.
        mse = self.mse_loss(outputs.dssm, outputs.reply_vec.detach())
        loss += mse
        metrics.add(mse=mse)
        pos_logits = outputs.pos_logits
        pos_target = torch.ones_like(pos_logits)
        neg_logits = outputs.neg_logits
        neg_target = torch.zeros_like(neg_logits)
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_logits, pos_target, reduction='none')
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_logits, neg_target, reduction='none')
        loss += (pos_loss + neg_loss).mean()
        metrics.add(pos_loss=pos_loss.mean(), neg_loss=neg_loss.mean())

    # After pretraining (or when neither auxiliary task is enabled),
    # optimize NLL + KL, optionally with a policy-gradient term.
    if epoch == -1 or epoch > self.pretrain_epoch or \
            (self.use_bow is not True and self.use_dssm is not True):
        loss += nll_loss
        loss += kl_loss
        if self.use_pg:
            # REINFORCE over the sampled knowledge index; reward is
            # (scaled) negative perplexity against a running baseline.
            posterior_probs = outputs.posterior_attn.gather(
                1, outputs.indexs.view(-1, 1))
            reward = -perplexity(logits, target, self.weight,
                                 self.padding_idx) * 100
            pg_loss = -(reward.detach() - self.baseline) * \
                posterior_probs.view(-1)
            pg_loss = pg_loss.mean()
            loss += pg_loss
            metrics.add(pg_loss=pg_loss, reward=reward.mean())
        if 'attn_index' in outputs:
            attn_acc = attn_accuracy(outputs.posterior_attn,
                                     outputs.attn_index)
            metrics.add(attn_acc=attn_acc)
    else:
        loss += nll_loss

    metrics.add(loss=loss)
    return metrics, scores