def _get_prediction_loss(self, fwd_pass: ForwardPassOutputs) -> torch.Tensor:
    """
    Calculate and return the KL loss on the teacher's prediction layer.

    Also record prediction-loss metrics.
    """
    assert isinstance(self, TorchGeneratorAgent)  # Code relies on methods
    pred_loss = F.kl_div(
        F.log_softmax(fwd_pass.student_scores, dim=-1, dtype=torch.float),
        F.softmax(fwd_pass.teacher_scores, dim=-1, dtype=torch.float),
        reduction='none',
    ).type_as(fwd_pass.student_scores)
    pred_loss = pred_loss.sum(dim=-1) * fwd_pass.mask  # Sum over dictionary
    self.record_local_metric(
        'pred_ppl',
        PPLMetric.many(pred_loss.sum(dim=-1), fwd_pass.tokens_per_example),
    )  # Sum over tokens
    self.record_local_metric(
        'pred_loss',
        AverageMetric.many(pred_loss.sum(dim=-1), fwd_pass.tokens_per_example),
    )  # Sum over tokens
    pred_loss = pred_loss.sum() / fwd_pass.num_tokens
    return pred_loss

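# A minimal, self-contained sketch of the same forward-KL pattern on dummy
# tensors (shapes and mask values are made up; no ParlAI objects involved):
# the student's log-probabilities are matched against the teacher's
# probabilities, summed over the vocabulary, masked, and averaged per token.
import torch
import torch.nn.functional as F

student_scores = torch.randn(2, 5, 11)  # (batch, time, vocab)
teacher_scores = torch.randn(2, 5, 11)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float)

kl = F.kl_div(
    F.log_softmax(student_scores, dim=-1),
    F.softmax(teacher_scores, dim=-1),
    reduction='none',
)
per_token_kl = kl.sum(dim=-1) * mask  # sum over vocabulary, zero out padding
loss = per_token_kl.sum() / mask.sum()  # average over non-padded tokens
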
def compute_loss(self, batch, return_output=False):
    """
    Override TGA.compute_loss to ignore start token.
    """
    if batch.label_vec is None:
        raise ValueError('Cannot compute loss without a label.')
    model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
    scores, preds, *_ = model_output
    if scores.size(1) != batch.label_vec.size(1):
        # ignore start
        scores = scores[:, 1:, :]
        preds = preds[:, 1:]
    score_view = scores.reshape(-1, scores.size(-1))
    loss = self.criterion(score_view, batch.label_vec.view(-1))
    loss = loss.view(scores.shape[:-1]).sum(dim=1)
    # save loss to metrics
    notnull = batch.label_vec.ne(self.NULL_IDX)
    target_tokens = notnull.long().sum(dim=-1)
    correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
    self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
    self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
    self.record_local_metric('token_acc', AverageMetric.many(correct, target_tokens))
    # actually do backwards loss
    loss = loss.sum()
    loss /= target_tokens.sum()  # average loss per token
    if return_output:
        return (loss, model_output)
    else:
        return loss

def compute_loss(self, batch, return_output=False):
    """
    Override from TorchGeneratorAgent.

    Compute and return the loss for the given batch.

    Easily overridable for customized loss functions.

    If return_output is True, the full output from the call to self.model() is
    also returned, via a (loss, model_output) pair.
    """
    if batch.label_vec is None:
        raise ValueError('Cannot compute loss without a label.')

    bsz = batch.text_vec.size(0)
    world_cardinality = self.world_cardinality
    embedding_size = self.opt.get('embedding_size')
    encoder_states = self.model.encoder(*self._encoder_input(batch))
    enc_output = encoder_states[0].view(
        bsz, world_cardinality, -1, embedding_size
    ).contiguous()
    enc_output_mask = encoder_states[1].view(bsz, world_cardinality, -1).contiguous()
    encoder_states = (enc_output, enc_output_mask)

    scores, preds = self.model.selfconscious_decode_forced(
        encoder_states, batch.label_vec
    )
    model_output = (scores, preds, encoder_states)

    score_view = scores.view(-1, scores.size(-1))
    loss = self.criterion(score_view, batch.label_vec.view(-1))
    loss = loss.view(scores.shape[:-1]).sum(dim=1)
    # save loss to metrics
    notnull = batch.label_vec.ne(self.NULL_IDX)
    target_tokens = notnull.long().sum(dim=-1)
    correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
    self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
    self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
    self.record_local_metric('token_acc', AverageMetric.many(correct, target_tokens))
    # actually do backwards loss
    loss = loss.sum()
    loss /= target_tokens.sum()  # average loss per token

    if return_output:
        return (loss, model_output)
    else:
        return loss

def compute_loss(
    self, batch: Batch, return_output: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
    """
    Override standard TGA.compute_loss to call relevant RAG Model Interface.
    """
    if batch.label_vec is None:
        raise ValueError('Cannot compute loss without a label.')

    model_output = self.get_model_output(batch)
    scores, preds, enc_state, *_ = model_output

    self._record_retrieval_metrics(batch, enc_state)
    (
        loss,
        metric_loss,
        metric_correct,
        metric_target_tokens,
    ) = self._rag_model_interface.compute_loss(
        self.criterion, scores, preds, enc_state, batch.label_vec
    )

    self.record_local_metric(
        'loss', AverageMetric.many(metric_loss, metric_target_tokens)
    )
    self.record_local_metric(
        'ppl', PPLMetric.many(metric_loss, metric_target_tokens)
    )
    self.record_local_metric(
        'token_acc', AverageMetric.many(metric_correct, metric_target_tokens)
    )
    self.record_local_metric(
        'token_em',
        AverageMetric.many(
            [x == y for x, y in zip(metric_correct, metric_target_tokens)]
        ),
    )

    if return_output:
        return loss, model_output
    else:
        return loss

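# For reference, the 'token_em' metric above is a per-example exact match: it
# is 1 only when the number of correctly predicted tokens equals the number of
# target tokens, i.e. every target token was predicted correctly. A tiny
# illustration with made-up counts (plain ints instead of tensors):
def token_exact_match(correct, target_tokens):
    # True where all target tokens in an example were predicted correctly.
    return [c == t for c, t in zip(correct, target_tokens)]

assert token_exact_match([5, 3], [5, 4]) == [True, False]
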
def compute_loss(self, batch, return_output=False):
    """
    Override TGA.compute_loss to add the VHRED KL and bag-of-words losses.
    """
    if batch.label_vec is None:
        raise ValueError('Cannot compute loss without a label.')
    model_output = self.model(
        *self._model_input(batch), ys=batch.label_vec, res_lens=batch.label_lengths
    )
    scores, preds, vhred_kl_loss, bow_loss, *_ = model_output
    score_view = scores.view(-1, scores.size(-1))
    loss = self.criterion(
        score_view / self.opt['temp'], batch.label_vec[:, 1:].contiguous().view(-1)
    )
    loss = loss.view(scores.shape[:-1]).sum(dim=1)
    # save loss to metrics
    notnull = batch.label_vec[:, :-1].ne(self.NULL_IDX)
    target_tokens = notnull.long().sum(dim=-1)
    correct = ((batch.label_vec[:, :-1] == preds) * notnull).sum(dim=-1)
    self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
    self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
    self.record_local_metric('token_acc', AverageMetric.many(correct, target_tokens))
    # actually do backwards loss
    loss = loss.sum()
    loss /= target_tokens.sum()  # average loss per token
    # for vhred: add the annealed KL term and the weighted bag-of-words term
    if vhred_kl_loss != -1 and bow_loss != -1:
        loss += (
            vhred_kl_loss * self.model.anneal_weight(self._number_training_updates)
            + self.opt['bow_w'] * bow_loss
        )
        self.metrics['kl_loss_cnt'] += 1
        self.metrics['kl_loss'] += vhred_kl_loss.item()
        self.metrics['bow_loss_cnt'] += 1
        self.metrics['bow_loss'] += bow_loss.item()
    if return_output:
        return (loss, model_output)
    else:
        return loss

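# The KL weight above comes from self.model.anneal_weight(), which is not
# shown in this snippet. As a hedged illustration only, a common choice in
# VHRED-style training is to ramp the KL weight from 0 to 1 over the first N
# updates; the function below is a hypothetical stand-in, not the actual
# implementation.
def linear_kl_anneal_weight(step, total_annealing_steps=10000):
    # Hypothetical linear schedule: 0 at step 0, 1 after total_annealing_steps.
    return min(1.0, step / total_annealing_steps)
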
def compute_loss(self, batch, return_output=False):
    """
    Compute and return the loss for the given batch.

    Easily overridable for customized loss functions.

    If return_output is True, the full output from the call to self.model() is
    also returned, via a (loss, model_output) pair.
    """
    if batch.label_vec is None:
        raise ValueError('Cannot compute loss without a label.')
    model_output = self.model(self._model_input(batch))
    scores, preds, *_ = model_output
    # recompute predictions directly from the scores
    preds = torch.argmax(scores, dim=2)
    score_view = scores.view(-1, scores.size(-1))
    loss = self.criterion(score_view, batch.label_vec.view(-1))
    loss = loss.view(scores.shape[:-1]).sum(dim=1)
    # save loss to metrics
    notnull = batch.label_vec.ne(self.NULL_IDX)
    target_tokens = notnull.long().sum(dim=-1)
    correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
    self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
    self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
    self.record_local_metric('token_acc', AverageMetric.many(correct, target_tokens))
    # actually do backwards loss
    loss = loss.sum()
    loss /= target_tokens.sum()  # average loss per token
    if return_output:
        return (loss, model_output)
    else:
        return loss

def compute_loss(self, batch, return_output=False):
    """
    Compute and return the loss for the given batch.

    Easily overridable for customized loss functions.

    If return_output is True, the full output from the call to self.model() is
    also returned, via a (loss, model_output) pair.
    """
    if batch.label_vec is None:
        raise ValueError('Cannot compute loss without a label.')
    model_input = self._model_input(batch)
    # teacher forward pass (no gradients needed for the frozen teacher)
    with torch.no_grad():
        teacher_output = self.teacher_agent.model(*model_input, ys=batch.label_vec)
        teacher_scores, teacher_preds, *_ = teacher_output
    model_output = self.model(*model_input, ys=batch.label_vec)
    scores, preds, *_ = model_output

    if scores.size(-1) < teacher_scores.size(-1):
        # pad the student vocabulary to the teacher's size
        vocab_difference = teacher_scores.size(-1) - scores.size(-1)
        scores = F.pad(scores, (0, vocab_difference), "constant", 0)
        teacher_scores[:, :, -vocab_difference:] = 0  # also zeros out teacher outputs

    score_view = scores.view(-1, scores.size(-1))
    loss = self.criterion(score_view, batch.label_vec.view(-1))
    loss = loss.view(scores.shape[:-1]).sum(dim=1)

    # teacher loss (for record keeping)
    teacher_score_view = teacher_scores.view(-1, teacher_scores.size(-1))
    teacher_loss = self.criterion(teacher_score_view, batch.label_vec.view(-1))
    teacher_loss = teacher_loss.view(teacher_scores.shape[:-1]).sum(dim=1)

    # KL loss
    ce_loss_fct = nn.KLDivLoss(reduction="none")
    loss_kl = (
        ce_loss_fct(
            F.log_softmax(scores / self.distill_temperature, dim=-1),
            F.softmax(teacher_scores / self.distill_temperature, dim=-1),
        )
        * (self.distill_temperature) ** 2
    ).view(scores.shape[0], -1).sum(dim=-1)

    # save loss to metrics
    notnull = batch.label_vec.ne(self.NULL_IDX)
    target_tokens = notnull.long().sum(dim=-1)
    correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
    teacher_correct = ((batch.label_vec == teacher_preds) * notnull).sum(dim=-1)
    self.record_local_metric('kl_loss', AverageMetric.many(loss_kl, target_tokens))
    self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
    self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
    self.record_local_metric(
        'token_acc',
        AverageMetric.many(correct, target_tokens),
    )
    self.record_local_metric(
        'teacher_loss', AverageMetric.many(teacher_loss, target_tokens)
    )
    self.record_local_metric('teacher_ppl', PPLMetric.many(teacher_loss, target_tokens))
    self.record_local_metric(
        'teacher_token_acc',
        AverageMetric.many(teacher_correct, target_tokens),
    )

    # actually do backwards loss
    loss = loss.sum()
    loss /= target_tokens.sum()  # average loss per token
    # interpolate the per-example KL term with the scalar CE loss, then average
    loss = self.distill_alpha * loss_kl + (1 - self.distill_alpha) * loss
    loss = loss.mean()
    if return_output:
        return (loss, model_output)
    else:
        return loss

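# A standalone toy sketch of the soft/hard loss mix above (Hinton-style
# distillation): temperature-softened KL between teacher and student, scaled
# by T**2, interpolated with the hard cross-entropy via alpha. The tensor
# shapes, temperature, and alpha below are placeholders, not values from the
# agent's configuration.
import torch
import torch.nn as nn
import torch.nn.functional as F

temperature, alpha = 2.0, 0.5
student_scores = torch.randn(2, 5, 11, requires_grad=True)  # (batch, time, vocab)
teacher_scores = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))

soft_loss = (
    nn.KLDivLoss(reduction="batchmean")(
        F.log_softmax(student_scores / temperature, dim=-1),
        F.softmax(teacher_scores / temperature, dim=-1),
    )
    * temperature ** 2
)
hard_loss = F.cross_entropy(student_scores.view(-1, 11), labels.view(-1))
loss = alpha * soft_loss + (1 - alpha) * hard_loss
loss.backward()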