def _get_prediction_loss(self,
                             fwd_pass: ForwardPassOutputs) -> torch.Tensor:
        """
        Calculate and return the KL loss on the teacher's prediction layer.

        Also record prediction-loss metrics.
        """
        assert isinstance(self, TorchGeneratorAgent)
        # The code below relies on TorchGeneratorAgent methods and attributes
        pred_loss = F.kl_div(
            F.log_softmax(fwd_pass.student_scores, dim=-1, dtype=torch.float),
            F.softmax(fwd_pass.teacher_scores, dim=-1, dtype=torch.float),
            reduction='none',
        ).type_as(fwd_pass.student_scores)
        # Sum over the vocabulary and zero out padding positions
        pred_loss = pred_loss.sum(dim=-1) * fwd_pass.mask
        self.record_local_metric(
            'pred_ppl',
            PPLMetric.many(pred_loss.sum(dim=-1), fwd_pass.tokens_per_example),
        )  # Sum over tokens
        self.record_local_metric(
            'pred_loss',
            AverageMetric.many(pred_loss.sum(dim=-1),
                               fwd_pass.tokens_per_example),
        )  # Sum over tokens
        pred_loss = pred_loss.sum() / fwd_pass.num_tokens
        return pred_loss
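The masking and reduction pattern above can be reproduced with standalone tensors. This is a minimal sketch; the shapes and mask values are made up for illustration and are not taken from ForwardPassOutputs:

    import torch
    import torch.nn.functional as F

    bsz, seqlen, vocab = 2, 4, 8                     # toy shapes
    student_scores = torch.randn(bsz, seqlen, vocab)
    teacher_scores = torch.randn(bsz, seqlen, vocab)
    mask = torch.tensor([[1., 1., 1., 0.],           # 1 = real token, 0 = padding
                         [1., 1., 0., 0.]])

    # token-level KL between teacher and student distributions, no reduction
    kl = F.kl_div(
        F.log_softmax(student_scores, dim=-1),
        F.softmax(teacher_scores, dim=-1),
        reduction='none',
    )
    kl = kl.sum(dim=-1) * mask        # sum over the vocabulary, mask out padding
    per_example = kl.sum(dim=-1)      # one summed loss per example (for metrics)
    loss = kl.sum() / mask.sum()      # average over all real tokens (for backprop)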
Example #2
    def compute_loss(self, batch, return_output=False):
        """
        Override TGA.compute_loss to ignore start token.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')
        model_output = self.model(*self._model_input(batch),
                                  ys=batch.label_vec)
        scores, preds, *_ = model_output

        if scores.size(1) != batch.label_vec.size(1):
            # ignore start
            scores = scores[:, 1:, :]
            preds = preds[:, 1:]

        score_view = scores.reshape(-1, scores.size(-1))
        loss = self.criterion(score_view, batch.label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        # save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

        self.record_local_metric('loss',
                                 AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric('token_acc',
                                 AverageMetric.many(correct, target_tokens))
        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        if return_output:
            return (loss, model_output)
        else:
            return loss
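This override and the ones that follow share the same metric bookkeeping: sum the per-token criterion output within each example, count the non-padding tokens, and count the correct predictions. A minimal sketch with toy tensors; NULL_IDX, the shapes, and the values are illustrative only:

    import torch

    NULL_IDX = 0                                     # assumed padding index
    label_vec = torch.tensor([[5, 6, 7, 0],          # (bsz, seqlen), 0 = padding
                              [8, 9, 0, 0]])
    preds = torch.tensor([[5, 6, 2, 0],
                          [8, 9, 0, 0]])

    notnull = label_vec.ne(NULL_IDX)                 # real (non-padding) positions
    token_losses = torch.rand(2, 4) * notnull        # stand-in criterion output, 0 at padding

    loss = token_losses.sum(dim=1)                   # per-example summed loss
    target_tokens = notnull.long().sum(dim=-1)       # tokens per example: [3, 2]
    correct = ((label_vec == preds) * notnull).sum(dim=-1)  # correct tokens: [2, 2]

    # per-example values feed the metrics; the backward loss averages over tokens
    backward_loss = loss.sum() / target_tokens.sum()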
    def compute_loss(self, batch, return_output=False):
        """
        Override from TorchGeneratorAgent.

        Compute and return the loss for the given batch.

        Easily overridable for customized loss functions.

        If return_output is True, the full output from the call to self.model()
        is also returned, via a (loss, model_output) pair.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')

        bsz = batch.text_vec.size(0)
        world_cardinality = self.world_cardinality
        embedding_size = self.opt.get('embedding_size')
        encoder_states = self.model.encoder(*self._encoder_input(batch))

        enc_output = encoder_states[0].view(bsz, world_cardinality, -1,
                                            embedding_size).contiguous()
        enc_output_mask = encoder_states[1].view(bsz, world_cardinality,
                                                 -1).contiguous()
        encoder_states = (enc_output, enc_output_mask)

        scores, preds = self.model.selfconscious_decode_forced(
            encoder_states, batch.label_vec)
        model_output = (scores, preds, encoder_states)

        score_view = scores.view(-1, scores.size(-1))
        loss = self.criterion(score_view, batch.label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        # save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

        self.record_local_metric('loss',
                                 AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric('token_acc',
                                 AverageMetric.many(correct, target_tokens))

        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token

        if return_output:
            return (loss, model_output)
        else:
            return loss
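The reshaping above appears to regroup encoder outputs that were computed for bsz * world_cardinality flattened contexts back into one group per example before forced decoding. A minimal sketch of that regrouping; the sizes below are illustrative:

    import torch

    bsz, world_cardinality, seqlen, embedding_size = 2, 3, 5, 4   # toy sizes
    flat_output = torch.randn(bsz * world_cardinality, seqlen, embedding_size)
    flat_mask = torch.ones(bsz * world_cardinality, seqlen, dtype=torch.bool)

    # regroup so each example carries its own set of world encodings
    enc_output = flat_output.view(bsz, world_cardinality, -1, embedding_size).contiguous()
    enc_output_mask = flat_mask.view(bsz, world_cardinality, -1).contiguous()
    encoder_states = (enc_output, enc_output_mask)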
Example #4
    def compute_loss(
        self,
        batch: Batch,
        return_output: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
        """
        Override standard TGA.compute_loss to call relevant RAG Model Interface.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')

        model_output = self.get_model_output(batch)
        scores, preds, enc_state, *_ = model_output

        self._record_retrieval_metrics(batch, enc_state)
        (
            loss,
            metric_loss,
            metric_correct,
            metric_target_tokens,
        ) = self._rag_model_interface.compute_loss(self.criterion, scores,
                                                   preds, enc_state,
                                                   batch.label_vec)

        self.record_local_metric(
            'loss', AverageMetric.many(metric_loss, metric_target_tokens))
        self.record_local_metric(
            'ppl', PPLMetric.many(metric_loss, metric_target_tokens))
        self.record_local_metric(
            'token_acc',
            AverageMetric.many(metric_correct, metric_target_tokens))
        self.record_local_metric(
            'token_em',
            AverageMetric.many([
                x == y for x, y in zip(metric_correct, metric_target_tokens)
            ]),
        )

        if return_output:
            return loss, model_output
        else:
            return loss
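The token_em metric above is per-example exact match: an example counts as 1 only when every one of its non-padding tokens is predicted correctly, i.e. its correct-token count equals its token count. A tiny illustration with made-up counts:

    metric_correct = [4, 7, 3]        # correctly predicted tokens per example
    metric_target_tokens = [4, 8, 3]  # non-padding tokens per example

    token_em = [x == y for x, y in zip(metric_correct, metric_target_tokens)]
    # -> [True, False, True]; averaging gives the exact-match rate over examples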
    def compute_loss(self, batch, return_output=False):
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')
        model_output = self.model(*self._model_input(batch),
                                  ys=batch.label_vec,
                                  res_lens=batch.label_lengths)
        scores, preds, vhred_kl_loss, bow_loss, *_ = model_output
        score_view = scores.view(-1, scores.size(-1))
        loss = self.criterion(score_view / self.opt['temp'],
                              batch.label_vec[:, 1:].contiguous().view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        # save loss to metrics
        notnull = batch.label_vec[:, :-1].ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec[:, :-1] == preds) * notnull).sum(dim=-1)

        self.record_local_metric('loss',
                                 AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric('token_acc',
                                 AverageMetric.many(correct, target_tokens))
        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token

        # for vhred
        if vhred_kl_loss != -1 and bow_loss != -1:
            loss += (vhred_kl_loss *
                     self.model.anneal_weight(self._number_training_updates) +
                     self.opt['bow_w'] * bow_loss)
            self.metrics['kl_loss_cnt'] += 1
            self.metrics['kl_loss'] += vhred_kl_loss.item()
            self.metrics['bow_loss_cnt'] += 1
            self.metrics['bow_loss'] += bow_loss.item()

        if return_output:
            return (loss, model_output)
        else:
            return loss
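The VHRED branch scales the KL term by self.model.anneal_weight(self._number_training_updates), whose definition is not shown here. A common choice is a schedule that ramps from 0 to 1 as training progresses; the sketch below is a hypothetical linear ramp, and anneal_steps is an assumed parameter rather than anything taken from the model above:

    def anneal_weight(num_training_updates: int, anneal_steps: int = 10000) -> float:
        """Hypothetical KL-annealing schedule: linear ramp from 0 to 1."""
        return min(1.0, num_training_updates / anneal_steps)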
Example #6
    def compute_loss(self, batch, return_output=False):
        """
        Compute and return the loss for the given batch.

        Easily overridable for customized loss functions.

        If return_output is True, the full output from the call to self.model()
        is also returned, via a (loss, model_output) pair.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')
        model_output = self.model(self._model_input(batch))
        scores, preds, *_ = model_output
        # Recompute the predictions directly from the scores
        preds = torch.argmax(scores, dim=2)
        score_view = scores.view(-1, scores.size(-1))
        loss = self.criterion(score_view, batch.label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        # save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

        self.record_local_metric('loss',
                                 AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric('token_acc',
                                 AverageMetric.many(correct, target_tokens))
        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        if return_output:
            return (loss, model_output)
        else:
            return loss
    def compute_loss(self, batch, return_output=False):
        """
        Compute and return the loss for the given batch.

        Easily overridable for customized loss functions.

        If return_output is True, the full output from the call to self.model()
        is also returned, via a (loss, model_output) pair.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')

        model_input = self._model_input(batch)
        with torch.no_grad():
            teacher_output = self.teacher_agent.model(*model_input,
                                                      ys=batch.label_vec)
            teacher_scores, teacher_preds, *_ = teacher_output

        model_output = self.model(*model_input, ys=batch.label_vec)
        scores, preds, *_ = model_output

        if scores.size(-1) < teacher_scores.size(-1):
            vocab_difference = teacher_scores.size(-1) - scores.size(-1)
            scores = F.pad(scores, (0, vocab_difference), "constant", 0)
            # zero out the teacher's extra vocabulary entries as well
            teacher_scores[:, :, -vocab_difference:] = 0

        score_view = scores.view(-1, scores.size(-1))

        loss = self.criterion(score_view, batch.label_vec.view(-1))

        loss = loss.view(scores.shape[:-1]).sum(dim=1)

        # teacher loss (for record keeping)
        teacher_score_view = teacher_scores.view(-1, teacher_scores.size(-1))
        teacher_loss = self.criterion(teacher_score_view,
                                      batch.label_vec.view(-1))
        teacher_loss = teacher_loss.view(teacher_scores.shape[:-1]).sum(dim=1)

        # KL loss
        ce_loss_fct = nn.KLDivLoss(reduction="none")
        loss_kl = (ce_loss_fct(
            F.log_softmax(scores / self.distill_temperature, dim=-1),
            F.softmax(teacher_scores / self.distill_temperature, dim=-1)) *
                   (self.distill_temperature)**2).view(scores.shape[0],
                                                       -1).sum(dim=-1)

        # save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

        teacher_correct = ((batch.label_vec == teacher_preds) *
                           notnull).sum(dim=-1)

        self.record_local_metric('kl_loss',
                                 AverageMetric.many(loss_kl, target_tokens))
        self.record_local_metric('loss',
                                 AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric(
            'token_acc',
            AverageMetric.many(correct, target_tokens),
        )
        self.record_local_metric(
            'teacher_loss', AverageMetric.many(teacher_loss, target_tokens))
        self.record_local_metric('teacher_ppl',
                                 PPLMetric.many(teacher_loss, target_tokens))
        self.record_local_metric(
            'teacher_token_acc',
            AverageMetric.many(teacher_correct, target_tokens),
        )
        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token

        loss = self.distill_alpha * loss_kl + (1 - self.distill_alpha) * loss
        loss = loss.mean()

        if return_output:
            return (loss, model_output)
        else:
            return loss
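For reference, the temperature-scaled KL term and the distill_alpha blend above can be reproduced with standalone tensors. This is a minimal sketch; the shapes, temperature, and alpha values are illustrative and not taken from the agent:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    temperature, alpha = 2.0, 0.5              # illustrative hyperparameters
    student_logits = torch.randn(2, 4, 8)      # (bsz, seqlen, vocab)
    teacher_logits = torch.randn(2, 4, 8)

    # per-example KL between temperature-softened distributions, scaled by T^2
    kl_fct = nn.KLDivLoss(reduction='none')
    loss_kl = (kl_fct(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
    ) * temperature ** 2).view(student_logits.shape[0], -1).sum(dim=-1)

    ce_loss = torch.tensor(1.3)                # stand-in for the per-token cross-entropy
    loss = (alpha * loss_kl + (1 - alpha) * ce_loss).mean()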