def collect_rl_metrics(self, sample_outputs, greedy_outputs, target,
                       gold_entity, entity_dir):
    """
    Collect policy-gradient (self-critical) RL training metrics.

    The greedy decode serves as the reward baseline and the sampled
    decode as the exploration trajectory: the REINFORCE loss weights the
    sampled sequence's log-probabilities by the (sample - greedy) reward
    advantage.
    """
    batch_size = target.size(0)
    rl_metrics = Pack(num_samples=batch_size)

    # Logits / predicted words for both decoding modes.
    sample_logits = sample_outputs.logits
    sampled_words = sample_outputs.pred_word
    greedy_logits = greedy_outputs.logits
    greedy_words = greedy_outputs.pred_word

    # Rewards; the greedy pass also yields BLEU / F1 for reporting.
    sample_reward, _, _ = self.reward_fn(
        sampled_words, target, gold_entity, entity_dir)
    greedy_reward, bleu_score, f1_score = self.reward_fn(
        greedy_words, target, gold_entity, entity_dir)
    advantage = sample_reward - greedy_reward

    # Per-token log-probability of the sampled sequence, padding masked.
    # Shape: [batch_size, max_len].
    sample_log_prob = self.nll_loss(sample_logits, sampled_words,
                                    mask=sampled_words.ne(self.padding_idx),
                                    reduction=False, matrix=False)
    weighted = sample_log_prob * advantage.to(sample_log_prob.device)
    # Reduce per-sequence sums with the same reduction the NLL loss is
    # configured with (e.g. torch.mean).
    rl_loss = getattr(torch, self.nll_loss.reduction)(weighted.sum(dim=-1))
    loss = 0 + rl_loss

    # Report greedy-decode accuracy against the reference.
    rl_acc = accuracy(greedy_logits, target, padding_idx=self.padding_idx)
    if advantage.dim() == 2:
        advantage = advantage.sum(dim=-1)
    rl_metrics.add(loss=loss,
                   reward=advantage.mean(),
                   rl_acc=rl_acc,
                   bleu_score=bleu_score.mean(),
                   f1_score=f1_score.mean())
    return rl_metrics
def collect_metrics(self, outputs, target):
    """
    Collect basic generation metrics: NLL, token accuracy and total loss.
    """
    metrics = Pack(num_samples=target.size(0))

    logits = outputs.logits
    nll = self.nll_loss(logits, target)
    # Count non-padding tokens so the NLL can be normalized per word.
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll, num_words), acc=acc)

    # Generation NLL is the only loss term here.
    metrics.add(loss=0 + nll)
    return metrics
def collect_metrics(self, outputs, target, ptr_index, kb_index):
    """
    Collect metrics for the pointer/KB model.

    Only the generation NLL currently contributes to the loss; the gate
    loss (over ``ptr_index``) and the KB selector loss (over ``kb_index``)
    are disabled, with a zero placeholder kept so callers still see a
    ``ptr`` entry in the reported metrics.
    """
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)

    # Generation loss.
    logits = outputs.logits
    nll = self.nll_loss(logits, target)
    loss = 0 + nll

    # Disabled gate/selector losses: zero placeholder, moved to GPU to
    # match the device of the rest of the loss.
    loss_ptr = torch.tensor(0.0)
    if self.use_gpu:
        loss_ptr = loss_ptr.cuda()
    loss = loss + loss_ptr

    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(loss=loss, ptr=loss_ptr, acc=acc,
                logits=logits, prob=outputs.prob)
    return metrics
def collect_metrics(self, outputs, target, emo_target):
    """
    Collect copy-generation and emotion-prediction metrics.

    Args:
        outputs: model outputs carrying ``out_copy`` — a probability
            distribution over source positions, [batch, max_len, src_len]
            — and ``logits`` — per-step emotion distribution,
            [batch, max_len, class_num].
        target: tuple of (token ids [batch, max_len], lengths [batch]).
        emo_target: tuple whose first element holds per-step emotion labels.

    Returns:
        Pack with ``loss`` (copy NLL), ``emo_loss`` and ``acc``.
    """
    num_samples = target[0].size(0)
    num_words = target[1].sum().item()
    metrics = Pack(num_samples=num_samples)

    target_len = target[1]
    mask = sequence_mask(target_len)
    mask = mask.float()

    # Copy NLL: gather the probability assigned to each gold token, mask
    # padding, and normalize by the total word count.
    out_copy = outputs.out_copy  # [batch, max_len, src_len]
    target_loss = out_copy.gather(2, target[0].unsqueeze(-1)).squeeze(-1)
    target_loss = target_loss * mask
    target_loss = target_loss + 1e-15  # avoid log(0)
    loss = -(target_loss.log().sum() / num_words)

    # Emotion NLL over the per-step emotion distribution.
    out_emo = outputs.logits  # [batch, max_len, class_num]
    batch_size, max_len, class_num = out_emo.size()
    target_emo_loss = out_emo.gather(
        2, emo_target[0].unsqueeze(-1)).squeeze(-1)

    # The emotion target is one step shorter, so build a mask shifted by
    # one position.  Use out-of-place subtraction: the original in-place
    # `target_len -= 1` silently mutated the caller's length tensor
    # (target[1]) as a side effect.
    shifted_len = target_len - 1
    mask_ = sequence_mask(shifted_len)
    mask_ = mask_.float()
    new_mask = mask.data.new(batch_size, max_len).zero_()
    # assumes mask_ is (batch, max_len - 1) wide — TODO confirm against
    # sequence_mask's handling of the maximum length
    new_mask[:, :max_len - 1] = mask_
    target_emo_loss = target_emo_loss * new_mask
    target_emo_loss = target_emo_loss + 1e-15
    emo_loss = -(target_emo_loss.log().sum() / num_words)

    metrics.add(loss=loss)
    metrics.add(emo_loss=emo_loss)

    # Accuracy is computed on the copy distribution only.
    acc = accuracy(out_copy, target[0], mask=mask)
    metrics.add(acc=acc)
    return metrics
def collect_metrics(self, outputs, target, epoch=-1):
    """
    Collect NLL / accuracy metrics plus the persona (cue-attention) loss.

    When ``attn_index`` is present in ``outputs`` the total loss is a
    weighted mix of the persona selection loss (0.7) and the generation
    NLL (0.3); otherwise it is the generation NLL alone.

    Returns:
        (metrics, scores) where ``scores`` are per-sequence negated NLL
        values (higher is better).
    """
    metrics = Pack(num_samples=target.size(0))

    logits = outputs.logits
    # Per-sequence scores for re-ranking / reporting.
    scores = -self.nll_loss(logits, target, reduction=False)
    nll_loss = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll_loss, num_words), acc=acc)

    if 'attn_index' in outputs:
        # Persona selection: supervise the cue attention with the gold
        # persona index.
        attn_acc = attn_accuracy(outputs.cue_attn, outputs.attn_index)
        metrics.add(attn_acc=attn_acc)
        per_logits = torch.log(outputs.cue_attn + self.eps)  # [batch, sent_num]
        per_labels = outputs.attn_index  # [batch]
        use_per_loss = self.persona_loss(per_logits, per_labels)
        metrics.add(use_per_loss=use_per_loss)
        loss = 0.7 * use_per_loss + 0.3 * nll_loss
    else:
        loss = 0 + nll_loss

    metrics.add(loss=loss)
    return metrics, scores
def collect_metrics(self, outputs, target, bridge=None, epoch=-1):
    """
    Collect generation metrics plus the neural-topic-model (NTM) term.
    """
    metrics = Pack(num_samples=target.size(0))

    # Response generation loss / accuracy.
    logits = outputs.logits
    nll = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll, num_words), acc=acc)

    # Neural topic model term, scaled down by vocabulary size.
    # NOTE(review): `.item()` detaches the NTM loss from the autograd
    # graph, so it only shifts the reported loss value and contributes
    # no gradient — confirm this is intentional.
    ntm_loss = outputs.ntm_loss.sum().item()
    total = nll + ntm_loss / self.topic_vocab_size * 0.3

    metrics.add(loss=total)
    return metrics
def collect_metrics(self, outputs, target, epoch=-1):
    """
    Collect training metrics for the knowledge-grounded generator.

    Depending on configuration flags this mixes the generation NLL with
    a prior/posterior KL term, a bag-of-words loss, a DSSM matching loss
    and a policy-gradient term.

    Returns:
        (metrics, scores) where ``scores`` are per-sequence negated NLL
        values (higher is better).
    """
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)
    loss = 0
    # test begin
    # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
    # loss += nll
    # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
    # metrics.add(attn_acc=attn_acc)
    # metrics.add(loss=loss)
    # return metrics
    # test end
    logits = outputs.logits
    # Per-sequence scores: negated unreduced NLL.
    scores = -self.nll_loss(logits, target, reduction=False)
    nll_loss = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll_loss, num_words), acc=acc)
    if self.use_posterior:
        # KL between the prior knowledge attention and the (detached)
        # posterior attention; epsilon guards log(0).
        kl_loss = self.kl_loss(torch.log(outputs.prior_attn + 1e-10),
                               outputs.posterior_attn.detach())
        metrics.add(kl=kl_loss)
        if self.use_bow:
            # Bag-of-words loss: predict each target token (minus the
            # final position) from the latent knowledge representation.
            bow_logits = outputs.bow_logits
            bow_labels = target[:, :-1]
            bow_logits = bow_logits.repeat(1, bow_labels.size(-1), 1)
            bow = self.nll_loss(bow_logits, bow_labels)
            loss += bow
            metrics.add(bow=bow)
        if self.use_dssm:
            # DSSM matching: regress toward the (detached) reply vector
            # and classify positive/negative pairs with BCE.
            mse = self.mse_loss(outputs.dssm, outputs.reply_vec.detach())
            loss += mse
            metrics.add(mse=mse)
            pos_logits = outputs.pos_logits
            pos_target = torch.ones_like(pos_logits)
            neg_logits = outputs.neg_logits
            neg_target = torch.zeros_like(neg_logits)
            pos_loss = F.binary_cross_entropy_with_logits(
                pos_logits, pos_target, reduction='none')
            neg_loss = F.binary_cross_entropy_with_logits(
                neg_logits, neg_target, reduction='none')
            loss += (pos_loss + neg_loss).mean()
            metrics.add(pos_loss=pos_loss.mean(), neg_loss=neg_loss.mean())
        # After pretraining (or when neither auxiliary task is enabled),
        # include the main NLL and KL terms.
        if epoch == -1 or epoch > self.pretrain_epoch or \
                (self.use_bow is not True and self.use_dssm is not True):
            loss += nll_loss
            loss += kl_loss
            if self.use_pg:
                # Policy gradient over knowledge selection: reward is
                # negated perplexity; the baseline-corrected (detached)
                # reward weights the probability of the chosen index.
                posterior_probs = outputs.posterior_attn.gather(
                    1, outputs.indexs.view(-1, 1))
                reward = -perplexity(logits, target, self.weight,
                                     self.padding_idx) * 100
                pg_loss = -(reward.detach() - self.baseline) * \
                    posterior_probs.view(-1)
                pg_loss = pg_loss.mean()
                loss += pg_loss
                metrics.add(pg_loss=pg_loss, reward=reward.mean())
        if 'attn_index' in outputs:
            attn_acc = attn_accuracy(outputs.posterior_attn,
                                     outputs.attn_index)
            metrics.add(attn_acc=attn_acc)
    else:
        loss += nll_loss
    metrics.add(loss=loss)
    return metrics, scores