Example 1
 def compute_loss(self, model, net_output, sample, reduce=True):
     _2nd_lprobs, target = self.get_lprobs_and_target(
         model, net_output, sample)
     temp_net_output = (net_output[1]["dec_out"], net_output[1])
     # logger.info("Info of net_output:{}".format(net_output[0].size()))
     # logger.info("Info of temp_net_output:{}".format(temp_net_output[0].size()))
     _1st_lprobs, _ = self.get_lprobs_and_target(model, temp_net_output,
                                                 sample)
     # logger.info("shape of _1st_lprobs:{}, shape of _2nd_lprobs:{}".format(_1st_lprobs.size(), _2nd_lprobs.size()))
     _2nd_loss, _2nd_nll_loss = label_smoothed_nll_loss(
         _2nd_lprobs,
         target,
         self.eps,
         ignore_index=self.padding_idx,
         reduce=reduce,
     )
     _1st_loss, _1st_nll_loss = label_smoothed_nll_loss(
         _1st_lprobs,
         target,
         self.eps,
         ignore_index=self.padding_idx,
         reduce=reduce)
     loss = 0.5 * _1st_loss + 0.5 * _2nd_loss
     nll_loss = 0.5 * _1st_nll_loss + 0.5 * _2nd_nll_loss
     return loss, nll_loss
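Every example on this page delegates the smoothing itself to label_smoothed_nll_loss. As a reference point, here is a minimal sketch of what that helper is commonly assumed to compute, modeled on the standard fairseq implementation (not taken from any of the examples above):

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    # lprobs: (N, C) log-probabilities; target: (N,) or (N, 1) class indices
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)     # true-class negative log-likelihood
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)     # uniform-smoothing term
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (lprobs.size(-1) - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss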
Example 2
    def compute_loss(self, model, net_output, sample, reduction, log_probs):
        # number loss
        _number = net_output["num_output"]
        number = sample["target_lengths"].float()
        diff = torch.sqrt(torch.pow(_number - number, 2) + 1e-6).sum()
        qua_loss = diff
        # alphas_pen
        # alphas_pen = net_output["alphas_pen"]
        # qua_loss = diff + self.args.lambda_alpha * alphas_pen
        target = sample["target"]  # no eos bos
        # N, T -> N * T
        target = target.view(-1)
        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
        if not hasattr(lprobs, "batch_first"):
            logging.warning(
                "ERROR: we need to know whether the net output is batch-first; "
                "please set the batch_first attribute on the return value of "
                "model.get_normalized_probs. For now we assume it is True, but "
                "in the future an exception will be raised instead."
            )
        batch_first = getattr(lprobs, "batch_first", True)
        if not batch_first:
            lprobs = lprobs.transpose(0, 1)

        # N, T, D -> N * T, D
        lprobs = lprobs.view(-1, lprobs.size(-1))
        ce_loss, _ = label_smoothed_nll_loss(
            lprobs, target.long(), 0.1, ignore_index=self.padding_idx, reduce=reduction,
        )

        return lprobs, qua_loss, ce_loss
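A note on the "number loss" in this example: sqrt((a - b)^2 + 1e-6) is just a smoothed absolute difference, so qua_loss is essentially the summed L1 error between predicted and reference target lengths, with the small constant keeping the gradient finite when the two match. A toy check (the values below are made up for illustration):

import torch

pred_len = torch.tensor([4.2, 7.9])   # hypothetical predicted lengths
ref_len = torch.tensor([5.0, 8.0])    # reference target lengths
smoothed = torch.sqrt(torch.pow(pred_len - ref_len, 2) + 1e-6).sum()
plain_l1 = (pred_len - ref_len).abs().sum()
print(smoothed.item(), plain_l1.item())  # nearly identical values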
Example 3
    def compute_loss(self, model, ctc_logits, len_ctc_logits, logits, target,
                     target_lengths, reduction, log_probs):
        # N, T -> N * T
        ctc_lprob, lprobs = model.get_normalized_probs(ctc_logits,
                                                       logits,
                                                       log_probs=log_probs)
        ctc_loss = self.cal_ctc_loss(ctc_lprob, len_ctc_logits, target,
                                     target_lengths - 1)

        if not hasattr(lprobs, "batch_first"):
            logging.warning(
                "ERROR: we need to know whether the net output is batch-first; "
                "please set the batch_first attribute on the return value of "
                "model.get_normalized_probs. For now we assume it is True, but "
                "in the future an exception will be raised instead.")
        batch_first = getattr(lprobs, "batch_first", True)
        if not batch_first:
            lprobs = lprobs.transpose(0, 1)

        # N, T, D -> N * T, D
        target = target.view(-1)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss, _ = label_smoothed_nll_loss(
            lprobs,
            target.long(),
            0.1,
            ignore_index=self.padding_idx,
            reduce=reduction,
        )

        return lprobs, ctc_loss, loss
Example 4
 def get_elementwise_loss(self, model, net_output, sample):
     lprobs = model.get_normalized_probs(net_output, log_probs=True)
     lprobs = lprobs.view(-1, lprobs.size(-1))
     target = model.get_targets(sample, net_output).view(-1, 1)
     loss, nll_loss = label_smoothed_nll_loss(
         lprobs, target, self.eps, ignore_index=None, reduce=False,
     )
     return loss, nll_loss
Example 5
    def compute_loss(self, model, net_output, sample, reduce=True, reverse=False):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample, reverse=reverse)

        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss
Example 6
 def forward(self, model, sample, reduce=True, log_probs=True):
     net_output = model(**sample['net_input'])
     sample_size = sample['target'].size(
         0) if self.sentence_avg else sample['ntokens']
     lprobs = model.get_normalized_probs(net_output[0], log_probs=True)
     lprobs = lprobs.view(-1, lprobs.size(-1))
     target = model.get_targets(sample, net_output[0]).view(-1, 1)
     primary_loss, primary_nll_loss = label_smoothed_nll_loss(
         lprobs,
         target,
         self.eps,
         ignore_index=self.padding_idx,
         reduce=reduce,
     )
     auxiliary_token_lens = model.get_auxiliary_token_lens(sample)
     auxiliary_probs = model.auxiliary_decoder.get_normalized_probs(
         net_output[1], log_probs=True)
     auxiliary_probs = auxiliary_probs.view(-1, auxiliary_probs.size(-1))
     auxiliary_target = model.get_auxiliary_target(sample,
                                                   net_output[1]).view(
                                                       -1, 1)
     auxiliary_loss, auxiliary_nll_loss = label_smoothed_nll_loss(
         auxiliary_probs,
         auxiliary_target,
         self.eps,
         ignore_index=self.padding_idx,
         reduce=reduce,
     )
     loss = self.primary_loss_weight * primary_loss + self.auxiliary_loss_weight * auxiliary_loss
     logging_output = {
         'loss': loss.data,
         'primary_loss': primary_loss.data,
         'primary_nll_loss': primary_nll_loss.data,
         'auxiliary_loss': auxiliary_loss.data,
         'auxiliary_nll_loss': auxiliary_nll_loss.data,
         'ntokens': sample['ntokens'],
         'auxiliary_ntokens': auxiliary_token_lens.sum().data,
         'nsentences': sample['target'].size(0),
         'sample_size': sample_size,
     }
     return loss, sample_size, logging_output
Example 7
 def compute_loss(self, model, net_output, sample, reduce=True):
     lprobs = model.get_normalized_probs(net_output, log_probs=True)
     lprobs = lprobs.view(-1, lprobs.size(-1))
     target = model.get_targets(sample, net_output).view(-1, 1)
     loss, nll_loss = label_smoothed_nll_loss(
         lprobs,
         target,
         self.label_smoothing,
         ignore_index=self.padding_idx,
         reduce=reduce,
     )
     return loss, nll_loss
Example 8
    def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'):
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
        lprobs = lprobs.view(-1, lprobs.size(-1))  # -> (B x T) x C
        target = target.view(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=(reduction == 'sum'),
        )

        mask = target.ne(self.padding_idx)
        correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
        total = torch.sum(mask)
        return loss, nll_loss, correct, total
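The correct/total counters in this example are plain token-level accuracy over non-padding positions. The same masking pattern in isolation, with toy sizes and a made-up padding index:

import torch

pad_idx = 1
lprobs = torch.log_softmax(torch.randn(6, 4), dim=-1)  # (B x T) x C
target = torch.tensor([2, 3, 1, 0, 1, 2])              # 1 is padding here
mask = target.ne(pad_idx)
correct = lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)).sum()
total = mask.sum()
print(f"token accuracy: {correct.item()}/{total.item()}")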
Example 9
    def compute_loss(self, model, net_output, sample, reduction, log_probs):
        pad_id = self.task.target_dictionary.pad()

        # number loss
        _number = net_output["num_output"]
        number = sample["target_lengths"].float()
        diff = torch.sqrt(torch.pow(_number - number, 2) + 1e-12).sum()
        qua_loss = diff

        target = sample["target"].view(-1) # N, T -> N * T
        lprobs_ctc, lprobs = model.get_normalized_probs(net_output, retrun_ctc=True, log_probs=log_probs)

        pred_mask = net_output["pred_mask"]
        lprobs = lprobs.view(-1, lprobs.size(-1)) # N, T, D -> N * T, D
        target_masked = torch.where(
            pred_mask,
            sample["target"],
            torch.ones_like(sample["target"]) * pad_id).long().view(-1)
        ce_loss, _ = label_smoothed_nll_loss(
            lprobs,
            target_masked,
            0.0,
            ignore_index=self.task.target_dictionary.pad(),
            reduce=reduction,
        )

        # CTC loss
        target = sample["target"]
        pad_mask = target != self.task.target_dictionary.pad()
        targets_flat = target.masked_select(pad_mask)
        target_lengths = sample["target_lengths"]

        len_lprobs = net_output["len_logits_ctc"]
        with torch.backends.cudnn.flags(enabled=False):
            ctc_loss = F.ctc_loss(
                lprobs_ctc.transpose(0, 1), # T x B x V
                targets_flat,
                len_lprobs,
                target_lengths,
                blank=self.task.target_dictionary.pad(),
                reduction="sum",
                zero_infinity=True,
            )

        pred_mask = net_output["pred_mask"]
        target_masked = (sample["target"] * pred_mask[:, 1:-1]).view(-1)

        return lprobs, target_masked, ctc_loss, qua_loss, ce_loss
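For the F.ctc_loss call in this example, PyTorch expects log-probabilities shaped T x B x V (hence the transpose), a flat 1-D target tensor with padding removed, and per-sequence input and target lengths. A shape-only sanity check with toy sizes (none of these values come from the example's model):

import torch
import torch.nn.functional as F

T, B, V, blank = 12, 2, 30, 0
log_probs = F.log_softmax(torch.randn(T, B, V), dim=-1)   # T x B x V
targets = torch.randint(1, V, (B, 5))                     # B x S, blank symbol excluded
target_lengths = torch.tensor([5, 3])
input_lengths = torch.full((B,), T, dtype=torch.long)
pad_mask = torch.arange(5).unsqueeze(0) < target_lengths.unsqueeze(1)
targets_flat = targets.masked_select(pad_mask)            # same flattening idea as above
loss = F.ctc_loss(log_probs, targets_flat, input_lengths, target_lengths,
                  blank=blank, reduction="sum", zero_infinity=True)
print(loss.item())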
Example 10
    def cif_loss(self, model, sample, net_output, reduce):
        target = sample["target"]
        # N, T -> N * T
        target = target.view(-1)
        lprobs = model.get_normalized_probs_cif(net_output, log_probs=True)

        # N, T, D -> N * T, D
        lprobs = lprobs.view(-1, lprobs.size(-1))
        ce_loss, _ = label_smoothed_nll_loss(
            lprobs,
            target.long(),
            0.1,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return ce_loss, lprobs
Example 11
    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        n_correct = 0
        if isinstance(target, dict):
            t_lprobs = target["target_logprobs"]

            if not lprobs.batch_first:
                lprobs = lprobs.transpose(0, 1)
                t_lprobs = t_lprobs.transpose(0, 1)
            nsentences, seq_len = lprobs.size()[:2]
            ntokens = nsentences * seq_len
            t_probs = t_lprobs.exp()
            mask_indices = (net_output[1]["mask_indices"][0] if
                            len(net_output[1]["mask_indices"]) > 0 else None)

            # mask_indices is True for those masking frames
            if mask_indices is not None:  # B X T
                t_probs = t_probs.masked_fill(
                    mask_indices.eq(False).unsqueeze(-1), 0)
                ntokens = mask_indices.int().sum()
            t_probs = t_probs.detach()
            t_lprobs = t_lprobs.detach()
            loss = (-(t_probs * (lprobs - t_lprobs)).sum() if reduce else
                    -(t_probs * (lprobs - t_lprobs)).sum(-1, keepdim=True))
            nll_loss = loss
        else:
            nsentences = target.size(0)
            mask = target.ne(self.padding_idx)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs.view(-1, lprobs.size(-1)),
                target.view(-1),
                self.eps,
                ignore_index=self.padding_idx,
                reduce=reduce,
            )
            n_correct = torch.sum(
                lprobs.argmax(-1).masked_select(mask).eq(
                    target.masked_select(mask)))
            ntokens = torch.sum(mask)
        return loss, nll_loss, nsentences, ntokens, n_correct
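In the dict-target branch of this example, -(t_probs * (lprobs - t_lprobs)).sum() is the KL divergence from the (detached) teacher distribution to the student, i.e. a soft-label distillation loss rather than label-smoothed NLL. A quick equivalence check against F.kl_div on random toy distributions:

import torch
import torch.nn.functional as F

student_lp = F.log_softmax(torch.randn(3, 5), dim=-1)  # student log-probs
teacher_lp = F.log_softmax(torch.randn(3, 5), dim=-1)  # teacher log-probs
teacher_p = teacher_lp.exp()
manual = -(teacher_p * (student_lp - teacher_lp)).sum()
builtin = F.kl_div(student_lp, teacher_p, reduction="sum")
assert torch.allclose(manual, builtin)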
Example 12
 def forward(self, model, sample, reduce=True, log_probs=True):
     net_output = model(**sample['net_input'])
     sample_size = sample['target'].size(
         0) if self.sentence_avg else sample['ntokens']
     lprobs = model.get_normalized_probs(net_output[0], log_probs=True)
     lprobs = lprobs.view(-1, lprobs.size(-1))
     target = model.get_targets(sample, net_output[0]).view(-1, 1)
     loss, nll_loss = label_smoothed_nll_loss(
         lprobs,
         target,
         self.eps,
         ignore_index=self.padding_idx,
         reduce=reduce,
     )
     auxiliary_probs = model.auxiliary_decoder.get_normalized_probs(
         net_output[1], log_probs=True)
     if self.auxiliary_loss_class_weights is not None:
         class_weights = self.auxiliary_loss_class_weights.to(
             sample["auxiliary_target"].device)
     else:
         class_weights = None
     auxiliary_loss = F.nll_loss(
         auxiliary_probs,
         sample["auxiliary_target"].view(-1),
         weight=class_weights,
         reduction='sum' if reduce else 'none',
     )
     loss = loss + self.auxiliary_loss_weight * auxiliary_loss
     logging_output = {
         'loss': loss.data,
         'nll_loss': nll_loss.data,
         'auxiliary_loss': auxiliary_loss.data,
         'ntokens': sample['ntokens'],
         'nsentences': sample['target'].size(0),
         'sample_size': sample_size,
     }
     return loss, sample_size, logging_output
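The auxiliary branch in this example passes log-probabilities (log_probs=True) into F.nll_loss, which is exactly what that function expects; feeding it raw logits or probabilities would give wrong values without raising an error. A small reminder of the equivalence with F.cross_entropy on toy data:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])
log_probs = F.log_softmax(logits, dim=-1)
assert torch.allclose(F.nll_loss(log_probs, targets, reduction="sum"),
                      F.cross_entropy(logits, targets, reduction="sum"))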
Example 13
    def compute_loss(self, model, net_output, sample, reduce=True):
        self.user_mode = net_output[1]["user_mode"]
        if self.is_mode('max_lm_margin'):
            if self.training:
                lm_out_margin, decoder_out_margin = net_output[1]['lm_out_margin'], net_output[1]['decoder_out_margin']
                lp = self.get_log_probs(lm_out_margin, model).detach()
                lq = self.get_log_probs(decoder_out_margin, model)
                idx1 = net_output[1]['encoder_out']['idx1']
                idx2 = net_output[1]['encoder_out']['idx2']
                nkl = (lp.exp() * (lq - lp)).sum() / len(idx2) * len(idx1)
                new_sample = self.sub_sample_by_id(sample, idx1, model)
                loss, nll_loss = super().compute_loss(model, net_output, new_sample)
                loss_lm, nll_loss_lm = super().compute_loss(model, net_output[1]['lm_out'], new_sample)
                loss = loss + float(self.user_mode['rkl']) * nkl + loss_lm
            else:
                loss, nll_loss = super().compute_loss(model, net_output, sample)
            return loss, nll_loss
        elif self.is_mode('simple_mask'):
            msk = model.decoder.a_in_b(sample['target'], sample['net_input']['src_tokens'])
            mask = (msk | ((torch.rand_like(msk.float()) > float(self.user_mode['simple_mask'])) & (sample['target'].ne(model.decoder.pad))).long()).float()
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            target = model.get_targets(sample, net_output).view(-1, 1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, target, self.eps, ignore_index=None, reduce=False,
            )
            loss, nll_loss = (loss * mask.view(-1)).sum(), (nll_loss * mask.view(-1)).sum()
            return loss, nll_loss
        elif self.is_mode('endorsement', 'lse'):
            pe = net_output[0].float()
            ne = net_output[1]['neg_out'][0].float()
            p_mask = net_output[1]['p_mask'].float()
            n_mask = net_output[1]['n_mask'].float()
            loss = ((1 + torch.exp(-pe)).log() * p_mask).sum() + ((1 + torch.exp(ne)).log() * n_mask).sum()
            #pe[pe == float('-inf')] = float('inf')
            #loss = ((1 + torch.exp(-pe)).log() * p_mask).sum() + ((1 + torch.exp(ne)).log() * n_mask).sum()
            return loss, loss
        elif self.has_mode('endorsement'):
            if self.has_mode('pretrain'):
                pe = net_output[0].float()
                ne = net_output[1]['neg_out'][0].float()
                p_mask = net_output[1]['p_mask']
                n_mask = net_output[1]['n_mask']
                n_in_p_mask = net_output[1]['n_in_p_mask'].float()
                ploss = (-F.logsigmoid(pe).masked_select(p_mask)).sum()
                nloss = (-F.logsigmoid(-ne*(1 - n_in_p_mask)).masked_select(n_mask)).sum()

                wloss = net_output[1]['wloss'].float().sum() * 0.05
                freq_loss = net_output[1]['freq_loss'].float().sum() * 1

                nll_loss = loss = ploss + nloss + wloss + freq_loss
                if self.has_mode('self_align'):
                    wloss = net_output[1]['self_align_decoder_out'][1]['wloss'].float().sum() * 0.05
                    freq_loss = net_output[1]['self_align_decoder_out'][1]['freq_loss'].float().sum() * 1
                    ploss = net_output[1]['self_align_decoder_out'][1]['ploss']
                    nloss = net_output[1]['self_align_decoder_out'][1]['nloss']
                    align_loss = ploss + nloss + freq_loss
                    loss = nll_loss = loss + align_loss
                return loss, nll_loss
            elif self.has_mode('drop_words'):
                if model.training:
                    sample['target'] = net_output[1]['drop_words_target']
                loss, nll_loss = super().compute_loss(model, net_output, sample)
                return loss, nll_loss
            else:
                if model.training:
                    m = net_output[1]['edm_decoder_out'][1]['m'].detach()
                    s_mask_inf = net_output[1]['edm_decoder_out'][1]['s_mask_inf'].detach()
                    max_match = (m + s_mask_inf.unsqueeze(-1)).max(1)[0]
                    a, b = (float(self.user_mode['a']), float(self.user_mode['b'])) if 'a' in self.user_mode else (1.0, 0.0)
                    weight = (a * (max_match - b)).sigmoid().float()
                    #weight = (max_match).sigmoid().float()
                    #weight[:] = 1
                    # dbg = model.decoder.get_align(sample['net_input']['src_tokens'], sample['target'], net_output[1]['edm_decoder_out'][1])
                    if self.has_mode('dbg_log_endorsement'):
                        data = {'src_tokens' : model.decoder.get_src_words(sample['net_input']['src_tokens'], ''), 'target' : model.decoder.get_src_words(sample['target'], ''), 'attn' : net_output[1]['attn'].data.cpu(), 'm' : m.data.cpu()}
                        torch.save(data, open("output/handcraft/dbg_%s.pt"%model.encoder.task.args.user_mode, "wb"))
                    if self.has_mode('add_exact'):
                        word_exactness = net_output[1]['word_exactness']
                        exactness = word_exactness[sample['target']]
                        tgt_notin_src = 1 - model.decoder.a_in_b(sample['target'], sample['net_input']['src_tokens'])
                        weight_exact = 1 - tgt_notin_src.float() * exactness
                        weight = weight * weight_exact
                    if self.has_mode('hard_weight'):
                        weight = (weight > torch.rand_like(weight)).float()

                    loss, nll_loss = self.get_elementwise_loss(model, net_output, sample)
                    target_mask = sample['target'].ne(self.padding_idx)
                    loss = weight.masked_select(target_mask) * loss.masked_select(target_mask.view(-1))
                    if self.has_mode('sent_weight'):
                        sweight = (((weight * target_mask.float()).sum(1) / target_mask.float().sum(1)).unsqueeze(1) * target_mask.float()).masked_select(target_mask)
                        loss = loss * sweight
                    loss = loss.sum()
                    nll_loss = (weight.masked_select(target_mask) * nll_loss.masked_select(target_mask.view(-1))).sum()
                    #loss = nll_loss = loss.masked_select(target_mask.view(-1)).sum()
                    #loss, nll_loss = super().compute_loss(model, net_output, sample)
                    return loss, nll_loss
                else:
                    pass            
        elif self.is_mode('endorsement', 'bak'):
            pe = net_output[0].float()
            ne = net_output[1]['neg_out'][0].float()
            p_mask = net_output[1]['p_mask']
            n_mask = net_output[1]['n_mask']
            #loss = ((1 + torch.exp(-pe)).log() * p_mask).sum() + ((1 + torch.exp(ne)).log() * n_mask).sum()
            loss = (-F.logsigmoid(pe).masked_select(p_mask)).sum() + (-F.logsigmoid(-ne).masked_select(n_mask)).sum()
            return loss, loss
        elif self.is_mode('attn_endorse'):
            weight = net_output[1]['attn'].max(2)[0].detach().float()#.pow(0.5)
            loss, nll_loss = self.get_elementwise_loss(model, net_output, sample)
            target_mask = sample['target'].ne(self.padding_idx)
            loss = (weight.masked_select(target_mask) * loss.masked_select(target_mask.view(-1))).sum()
            nll_loss = (weight.masked_select(target_mask) * nll_loss.masked_select(target_mask.view(-1))).sum()
            return loss, nll_loss

        loss, nll_loss = super().compute_loss(model, net_output, sample)
        if self.is_mode('rl_word'):
            rl_loss = net_output[1]['encoder_out']['rl_loss'].sum()
            loss = loss + rl_loss
        elif self.has_mode('sep_lm', 'sep_lm1', 'sep_lm2', 'sep_lm3'):
            lm_loss, _ = super().compute_loss(model, net_output[1]['lm_out'], sample)
            if self.has_mode('only_lm'):
                loss = lm_loss
            else:
                loss = loss + lm_loss
        elif self.is_mode('ignore_batch'):
            if model.batch_id % 10 == 1:
                loss, nll_loss = loss * 0, nll_loss * 0
        elif self.is_mode('gated'):
            #gs = torch.stack([x.g.squeeze(-1) for x in net_output[1]['inner_states'][1:]])
            gs = net_output[1]['inner_states'][1].g.squeeze(-1)
            lg = gs.float().sum()
            #print(loss.data.item(), lg.data.item())
            loss = loss + float(self.user_mode['gated']) * lg
        elif self.has_mode('add_lm'):
            lm_loss, lm_nll_loss = super().compute_loss(model, net_output[1]['lm_out'], sample)
            loss = loss + lm_loss
        elif self.is_mode('decomposable1'):
            gs = net_output[1]['inner_states'][-1].g.squeeze(-1)
            lg = gs.float().sum()
            loss = loss + float(self.user_mode['decomposable']) * lg
        elif self.is_mode('rl_edm'):
            if self.training:
                scores = net_output[1]['word_scores'].float()
                lprob = net_output[1]['word_lprob']
                tokens = net_output[1]['word_tokens']
                mask = tokens.ne(model.decoder.pad).float()
                b = 2.0
                #idx = (tokens == model.decoder.model.decoder.dictionary.indices['town'])
                #lp = lprob[idx]
                #rl_loss = -(-1 * lp).sum() * 100
                #rl_loss = -(torch.min(scores - b, 0.1*(scores - b)) * mask * lprob ).sum()
                #2 * sigmoid(0.5 * x - 2) - 1 from -6 to 19
                #rl_loss = -(((scores * 0.5 - 2).sigmoid() * 2 - 1) * mask * lprob ).sum()
                mask = ((scores < b) & tokens.ne(model.decoder.pad))
                rl_loss = -((scores[mask] - b) * lprob[mask] ).sum() * 0.01
                loss = loss + rl_loss
        if self.has_mode('dbg_log_endorsement'):
            data = {'src_tokens' : model.decoder.get_src_words(sample['net_input']['src_tokens'], ''), 'target' : model.decoder.get_src_words(sample['target'], ''), 'attn' : net_output[1]['attn'].data.cpu()}
            torch.save(data, open("output/handcraft/dbg_%s.pt"%model.encoder.task.args.user_mode, "wb"))
        return loss, nll_loss
Example 14
    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        logits = model.get_logits(net_output).float()  # self pred
        target = model.get_targets(sample, net_output)  # self target

        logits_ali = model.get_ali_logits(net_output)  # ali pred
        target_ali = model.get_ali_targets(sample, net_output)  # ali target

        weights = None
        if hasattr(model, 'get_target_weights') and not self.infonce:
            weights = model.get_target_weights(target, net_output)
            if torch.is_tensor(weights):
                weights = weights.float()

        if self.infonce:
            loss = F.cross_entropy(
                logits,
                target,
                reduction="sum" if reduce else "none",
            )
            lprobs = utils.log_softmax(logits_ali.float(), dim=-1)
            loss_ali, _ = label_smoothed_nll_loss(
                lprobs,
                target_ali,
                0.0,
                reduce="sum" if reduce else "none",
            )
        else:
            loss = F.binary_cross_entropy_with_logits(
                logits,
                target.float(),
                weights,
                reduction="sum" if reduce else "none",
            )

        sample_size = target.numel() if self.infonce else target.long().sum().item()
        # note: loss_ali is only defined in the infonce branch above, so this
        # criterion assumes infonce is enabled when the alignment loss is used
        losses = [loss.detach().clone(), loss_ali.detach().clone()]
        loss += 0.001 * loss_ali

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f'{len(extra_losses)}, {len(self.loss_weights)}'
            for p, coef in zip(extra_losses, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    losses.append(p)

        logging_output = {
            'loss': loss.item(),
            'ntokens': sample_size,
            'nsentences': sample['id'].numel(),
            'sample_size': sample_size,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float((net_output[lk]))

        if len(losses) > 1:
            for i, l in enumerate(losses):
                logging_output[f'loss_{i}'] = l.item()

        if self.infonce:
            with torch.no_grad():
                if logits.numel() == 0:
                    corr = 0
                    count = 0

                    corr_ali = 0
                    count_ali = 0
                else:
                    assert logits.dim() > 1, logits.shape
                    max = logits.argmax(-1) == 0
                    min = logits.argmin(-1) == 0
                    both = max & min
                    corr = max.long().sum().item() - both.long().sum().item()
                    count = max.numel()

                    corr_ali = (logits_ali.argmax(1) == target_ali).sum()
                    count_ali = target_ali.numel()

                logging_output["correct"] = corr
                logging_output["count"] = count

                logging_output["correct_ali"] = corr_ali.data
                logging_output["count_ali"] = count_ali

        if log_pred:
            logging_output['logits'] = logits.cpu().numpy()
            logging_output['target'] = target.cpu().numpy()
        return loss, sample_size, logging_output
Example 15
    def compute_encoder_classification_loss(self,
                                            net_input,
                                            net_output,
                                            reduce=True,
                                            classification_step=True,
                                            language_classifier_one_vs_rest=0,
                                            vocab_size=256000):
        # net_input["src_tokens"] is B x T
        # Take first src token (src lang ID). B
        src_lang_target = net_input["src_tokens"][:, 0] - vocab_size + 1  # label 0 is reserved for padding

        encoder_classification_out = net_output[1]["classification_out"]
        max_len, bsz, num_total_labels = encoder_classification_out.shape
        lang_target_padding = 0

        if not torch.all(src_lang_target > 0):
            print("Violating 1")
            print(net_input["src_tokens"])
            print(src_lang_target)
            exit()

        # print("pred shape 1", F.log_softmax(src_lang_pred.float(), dim=-1).shape)
        # print("pred shape 2", model.get_normalized_probs(src_lang_pred, log_probs=True).shape) <-- this eats the 1st dim
        lprobs = F.log_softmax(encoder_classification_out.float(), dim=-1)  # log-softmax over labels
        target = src_lang_target.repeat(max_len, 1)  # B --> T x B
        # Get indices
        src_pad_idx = net_input["src_tokens"].eq(self.padding_idx).transpose(
            0, 1)
        src_one_lang_idx = target == language_classifier_one_vs_rest
        # print("ONE LANG", src_one_lang_idx)
        # print("OTHER LANG", ~src_one_lang_idx)

        if language_classifier_one_vs_rest != 0:  # Change target to binary
            target[torch.logical_and(src_one_lang_idx, ~src_pad_idx)] = 1
            target[torch.logical_and(~src_one_lang_idx, ~src_pad_idx)] = 2

        target[src_pad_idx] = lang_target_padding

        if not torch.all(target < num_total_labels):
            print("Violating 2")
            # print("src", net_input["src_tokens"])
            print("target", target)
            exit()

        if self.ignore_prefix_size > 0:
            if getattr(lprobs, "batch_first", False):
                lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
                target = target[:, self.ignore_prefix_size:].contiguous()
            else:
                lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
                target = target[self.ignore_prefix_size:, :].contiguous()

        lprobs, target = lprobs.view(-1, lprobs.size(-1)), target.view(-1)

        if classification_step:
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                target,
                self.eps,
                ignore_index=lang_target_padding,
                reduce=reduce,
            )
        else:
            loss, nll_loss = self.nll_loss_rev(
                lprobs,
                target,
                self.eps,
                ignore_index=lang_target_padding,
                reduce=reduce,
            )

        stats_per_lang = {}
        unique_targets = torch.unique(target, sorted=True)

        for id in unique_targets:
            mask = target.eq(id)
            n_correct = torch.sum(
                lprobs.argmax(1).masked_select(mask).eq(
                    target.masked_select(mask)))
            n_total = torch.sum(mask)
            stats_per_lang[utils.item(id.data)] = [
                utils.item(n_correct.data),
                utils.item(n_total.data)
            ]

        # Calc overall accuracy
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(
                target.masked_select(mask)))
        n_total = torch.sum(mask)

        return loss, nll_loss, n_correct, n_total, stats_per_lang