def compute_loss(self, model, net_output, sample, reduce=True):
    _2nd_lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
    # wrap the first-pass decoder output so get_lprobs_and_target can be reused on it
    temp_net_output = (net_output[1]["dec_out"], net_output[1])
    # logger.info("Info of net_output:{}".format(net_output[0].size()))
    # logger.info("Info of temp_net_output:{}".format(temp_net_output[0].size()))
    _1st_lprobs, _ = self.get_lprobs_and_target(model, temp_net_output, sample)
    # logger.info("shape of _1st_lprobs:{}, shape of _2nd_lprobs:{}".format(_1st_lprobs.size(), _2nd_lprobs.size()))
    _2nd_loss, _2nd_nll_loss = label_smoothed_nll_loss(
        _2nd_lprobs,
        target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )
    _1st_loss, _1st_nll_loss = label_smoothed_nll_loss(
        _1st_lprobs,
        target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )
    # average the label-smoothed losses of the two decoding passes
    loss = 0.5 * _1st_loss + 0.5 * _2nd_loss
    nll_loss = 0.5 * _1st_nll_loss + 0.5 * _2nd_nll_loss
    return loss, nll_loss
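# The criteria in this file lean on fairseq's label_smoothed_nll_loss helper for their
# token-level term. The sketch below is a minimal stand-in for readers without fairseq at
# hand; it mirrors the upstream behaviour closely, but the exact epsilon normalisation
# differs between fairseq versions, so treat it as illustrative rather than authoritative.
def _label_smoothed_nll_loss_sketch(lprobs, target, epsilon, ignore_index=None, reduce=True):
    # lprobs: (N, V) log-probabilities; target: (N,) or (N, 1) gold indices
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)    # NLL of the gold token
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)    # uniform smoothing term over the vocabulary
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)                  # newer fairseq divides by (V - 1) instead
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss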
def compute_loss(self, model, net_output, sample, reduction, log_probs):
    # quantity (length) loss: smoothed |predicted token count - true target length|
    _number = net_output["num_output"]
    number = sample["target_lengths"].float()
    diff = torch.sqrt(torch.pow(_number - number, 2) + 1e-6).sum()
    qua_loss = diff
    # alphas_pen
    # alphas_pen = net_output["alphas_pen"]
    # qua_loss = diff + self.args.lambda_alpha * alphas_pen

    target = sample["target"]  # no eos/bos
    # N, T -> N * T
    target = target.view(-1)

    lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
    if not hasattr(lprobs, "batch_first"):
        logging.warning(
            "ERROR: we need to know whether the net output is batch-first; "
            "you need to set the batch_first attribute on the return value of "
            "model.get_normalized_probs. For now we assume it is batch-first, "
            "but in the future this will raise an exception instead."
        )
    batch_first = getattr(lprobs, "batch_first", True)
    if not batch_first:
        lprobs = lprobs.transpose(0, 1)

    # N, T, D -> N * T, D
    lprobs = lprobs.view(-1, lprobs.size(-1))

    ce_loss, _ = label_smoothed_nll_loss(
        lprobs,
        target.long(),
        0.1,
        ignore_index=self.padding_idx,
        reduce=reduction,
    )
    return lprobs, qua_loss, ce_loss
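# A minimal stand-alone sketch of the quantity loss above, with invented toy values
# (torch is already imported at module level): the expression is a smoothed absolute
# difference between predicted and true lengths, and the 1e-6 keeps the square root
# differentiable at zero.
def _quantity_loss_demo():
    _number = torch.tensor([11.8, 7.2, 20.5])   # predicted counts (e.g. summed CIF weights)
    number = torch.tensor([12.0, 7.0, 21.0])    # true target lengths
    qua_loss = torch.sqrt(torch.pow(_number - number, 2) + 1e-6).sum()
    return qua_loss  # ~= 0.2 + 0.2 + 0.5 = 0.9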
def compute_loss(self, model, ctc_logits, len_ctc_logits, logits, target, target_lengths, reduction, log_probs):
    ctc_lprob, lprobs = model.get_normalized_probs(ctc_logits, logits, log_probs=log_probs)
    ctc_loss = self.cal_ctc_loss(ctc_lprob, len_ctc_logits, target, target_lengths - 1)

    if not hasattr(lprobs, "batch_first"):
        logging.warning(
            "ERROR: we need to know whether the net output is batch-first; "
            "you need to set the batch_first attribute on the return value of "
            "model.get_normalized_probs. For now we assume it is batch-first, "
            "but in the future this will raise an exception instead."
        )
    batch_first = getattr(lprobs, "batch_first", True)
    if not batch_first:
        lprobs = lprobs.transpose(0, 1)

    # N, T -> N * T
    target = target.view(-1)
    # N, T, D -> N * T, D
    lprobs = lprobs.view(-1, lprobs.size(-1))

    loss, _ = label_smoothed_nll_loss(
        lprobs,
        target.long(),
        0.1,
        ignore_index=self.padding_idx,
        reduce=reduction,
    )
    return lprobs, ctc_loss, loss
def get_elementwise_loss(self, model, net_output, sample):
    lprobs = model.get_normalized_probs(net_output, log_probs=True)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = model.get_targets(sample, net_output).view(-1, 1)
    loss, nll_loss = label_smoothed_nll_loss(
        lprobs,
        target,
        self.eps,
        ignore_index=None,
        reduce=False,
    )
    return loss, nll_loss
def compute_loss(self, model, net_output, sample, reduce=True, reverse=False):
    lprobs, target = self.get_lprobs_and_target(model, net_output, sample, reverse=reverse)
    loss, nll_loss = label_smoothed_nll_loss(
        lprobs,
        target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )
    return loss, nll_loss
def forward(self, model, sample, reduce=True, log_probs=True):
    net_output = model(**sample['net_input'])
    sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']

    lprobs = model.get_normalized_probs(net_output[0], log_probs=True)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = model.get_targets(sample, net_output[0]).view(-1, 1)
    primary_loss, primary_nll_loss = label_smoothed_nll_loss(
        lprobs,
        target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )

    auxiliary_token_lens = model.get_auxiliary_token_lens(sample)
    auxiliary_probs = model.auxiliary_decoder.get_normalized_probs(net_output[1], log_probs=True)
    auxiliary_probs = auxiliary_probs.view(-1, auxiliary_probs.size(-1))
    auxiliary_target = model.get_auxiliary_target(sample, net_output[1]).view(-1, 1)
    auxiliary_loss, auxiliary_nll_loss = label_smoothed_nll_loss(
        auxiliary_probs,
        auxiliary_target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )

    loss = self.primary_loss_weight * primary_loss + self.auxiliary_loss_weight * auxiliary_loss

    logging_output = {
        'loss': loss.data,
        'primary_loss': primary_loss.data,
        'primary_nll_loss': primary_nll_loss.data,
        'auxiliary_loss': auxiliary_loss.data,
        'auxiliary_nll_loss': auxiliary_nll_loss.data,
        'ntokens': sample['ntokens'],
        'auxiliary_ntokens': auxiliary_token_lens.sum().data,
        'nsentences': sample['target'].size(0),
        'sample_size': sample_size,
    }
    return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
    lprobs = model.get_normalized_probs(net_output, log_probs=True)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = model.get_targets(sample, net_output).view(-1, 1)
    loss, nll_loss = label_smoothed_nll_loss(
        lprobs,
        target,
        self.label_smoothing,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )
    return loss, nll_loss
def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'):
    if not lprobs.batch_first:
        lprobs = lprobs.transpose(0, 1)
    lprobs = lprobs.view(-1, lprobs.size(-1))  # -> (B x T) x C
    target = target.view(-1)

    loss, nll_loss = label_smoothed_nll_loss(
        lprobs,
        target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=(reduction == 'sum'),
    )

    mask = target.ne(self.padding_idx)
    correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
    total = torch.sum(mask)
    return loss, nll_loss, correct, total
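# A minimal stand-alone sketch of the masked accuracy bookkeeping used above, with invented
# toy values (padding_idx assumed to be 1; torch is already imported at module level):
def _masked_accuracy_demo():
    lprobs = torch.log_softmax(torch.randn(6, 5), dim=-1)   # (B*T) x C log-probabilities
    target = torch.tensor([4, 2, 1, 1, 0, 3])               # two positions are padding (id 1)
    mask = target.ne(1)                                      # drop padding from the statistics
    correct = lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)).sum()
    total = mask.sum()
    return correct.float() / total.float()                   # token accuracy over real tokens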
def compute_loss(self, model, net_output, sample, reduction, log_probs):
    pad_id = self.task.target_dictionary.pad()

    # quantity (length) loss
    _number = net_output["num_output"]
    number = sample["target_lengths"].float()
    diff = torch.sqrt(torch.pow(_number - number, 2) + 1e-12).sum()
    qua_loss = diff

    target = sample["target"].view(-1)  # N, T -> N * T
    lprobs_ctc, lprobs = model.get_normalized_probs(net_output, retrun_ctc=True, log_probs=log_probs)

    # cross-entropy only on the positions selected by pred_mask; the rest are set to pad
    pred_mask = net_output["pred_mask"]
    lprobs = lprobs.view(-1, lprobs.size(-1))  # N, T, D -> N * T, D
    target_masked = torch.where(
        pred_mask,
        sample["target"],
        torch.ones_like(sample["target"]) * pad_id,
    ).long().view(-1)
    ce_loss, _ = label_smoothed_nll_loss(
        lprobs,
        target_masked,
        0.0,
        ignore_index=self.task.target_dictionary.pad(),
        reduce=reduction,
    )

    # CTC loss
    target = sample["target"]
    pad_mask = target != self.task.target_dictionary.pad()
    targets_flat = target.masked_select(pad_mask)
    target_lengths = sample["target_lengths"]
    len_lprobs = net_output["len_logits_ctc"]
    with torch.backends.cudnn.flags(enabled=False):
        ctc_loss = F.ctc_loss(
            lprobs_ctc.transpose(0, 1),  # T x B x V
            targets_flat,
            len_lprobs,
            target_lengths,
            blank=self.task.target_dictionary.pad(),
            reduction="sum",
            zero_infinity=True,
        )

    # recompute the masked target for the return value, dropping the first and last
    # positions of pred_mask
    pred_mask = net_output["pred_mask"]
    target_masked = (sample["target"] * pred_mask[:, 1:-1]).view(-1)
    return lprobs, target_masked, ctc_loss, qua_loss, ce_loss
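# A minimal sketch of the target layout that F.ctc_loss expects, mirroring the masked_select
# flattening above (toy values; pad id assumed to be 1, F is torch.nn.functional as elsewhere
# in this file):
def _ctc_target_flattening_demo():
    target = torch.tensor([[5, 7, 9, 1],
                           [6, 8, 1, 1]])                    # B x T padded target matrix
    target_lengths = torch.tensor([3, 2])
    pad_mask = target != 1
    targets_flat = target.masked_select(pad_mask)            # tensor([5, 7, 9, 6, 8])
    # F.ctc_loss(log_probs, targets_flat, input_lengths, target_lengths, ...) takes the
    # log-probabilities as T x B x V and the concatenated, un-padded targets shown here.
    return targets_flat, target_lengths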
def cif_loss(self, model, sample, net_output, reduce):
    target = sample["target"]
    # N, T -> N * T
    target = target.view(-1)

    lprobs = model.get_normalized_probs_cif(net_output, log_probs=True)
    # N, T, D -> N * T, D
    lprobs = lprobs.view(-1, lprobs.size(-1))

    ce_loss, _ = label_smoothed_nll_loss(
        lprobs,
        target.long(),
        0.1,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )
    return ce_loss, lprobs
def compute_loss(self, model, net_output, sample, reduce=True):
    lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
    n_correct = 0
    if isinstance(target, dict):
        t_lprobs = target["target_logprobs"]
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
            t_lprobs = t_lprobs.transpose(0, 1)
        nsentences, seq_len = lprobs.size()[:2]
        ntokens = nsentences * seq_len
        t_probs = t_lprobs.exp()
        mask_indices = (
            net_output[1]["mask_indices"][0]
            if len(net_output[1]["mask_indices"]) > 0
            else None
        )
        # mask_indices is True for the masked frames
        if mask_indices is not None:
            # B X T
            t_probs = t_probs.masked_fill(mask_indices.eq(False).unsqueeze(-1), 0)
            ntokens = mask_indices.int().sum()
        t_probs = t_probs.detach()
        t_lprobs = t_lprobs.detach()
        loss = (
            -(t_probs * (lprobs - t_lprobs)).sum()
            if reduce
            else -(t_probs * (lprobs - t_lprobs)).sum(-1, keepdim=True)
        )
        nll_loss = loss
    else:
        nsentences = target.size(0)
        mask = target.ne(self.padding_idx)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs.view(-1, lprobs.size(-1)),
            target.view(-1),
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        n_correct = torch.sum(
            lprobs.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
        )
        ntokens = torch.sum(mask)
    return loss, nll_loss, nsentences, ntokens, n_correct
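# In the teacher-distribution branch above, -(t_probs * (lprobs - t_lprobs)).sum() is the KL
# divergence KL(teacher || student) summed over frames. A minimal numerical check with
# invented toy distributions (F is torch.nn.functional, as elsewhere in this file):
def _teacher_kl_demo():
    t_lprobs = torch.log_softmax(torch.randn(3, 4), dim=-1)  # teacher log-probs
    lprobs = torch.log_softmax(torch.randn(3, 4), dim=-1)    # student log-probs
    t_probs = t_lprobs.exp()
    loss = -(t_probs * (lprobs - t_lprobs)).sum()
    ref = F.kl_div(lprobs, t_probs, reduction="sum")          # same quantity via F.kl_div
    assert torch.allclose(loss, ref, atol=1e-6)
    return loss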
def forward(self, model, sample, reduce=True, log_probs=True):
    net_output = model(**sample['net_input'])
    sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']

    lprobs = model.get_normalized_probs(net_output[0], log_probs=True)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = model.get_targets(sample, net_output[0]).view(-1, 1)
    loss, nll_loss = label_smoothed_nll_loss(
        lprobs,
        target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )

    auxiliary_probs = model.auxiliary_decoder.get_normalized_probs(net_output[1], log_probs=True)
    if self.auxiliary_loss_class_weights is not None:
        class_weights = self.auxiliary_loss_class_weights.to(sample["auxiliary_target"].device)
    else:
        class_weights = None
    auxiliary_loss = F.nll_loss(
        auxiliary_probs,
        sample["auxiliary_target"].view(-1),
        weight=class_weights,
        reduction='sum' if reduce else 'none',
    )

    loss = loss + self.auxiliary_loss_weight * auxiliary_loss

    logging_output = {
        'loss': loss.data,
        'nll_loss': nll_loss.data,
        'auxiliary_loss': auxiliary_loss.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample['target'].size(0),
        'sample_size': sample_size,
    }
    return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
    self.user_mode = net_output[1]["user_mode"]
    if self.is_mode('max_lm_margin'):
        if self.training:
            lm_out_margin, decoder_out_margin = net_output[1]['lm_out_margin'], net_output[1]['decoder_out_margin']
            lp = self.get_log_probs(lm_out_margin, model).detach()
            lq = self.get_log_probs(decoder_out_margin, model)
            idx1 = net_output[1]['encoder_out']['idx1']
            idx2 = net_output[1]['encoder_out']['idx2']
            nkl = (lp.exp() * (lq - lp)).sum() / len(idx2) * len(idx1)
            new_sample = self.sub_sample_by_id(sample, idx1, model)
            loss, nll_loss = super().compute_loss(model, net_output, new_sample)
            loss_lm, nll_loss_lm = super().compute_loss(model, net_output[1]['lm_out'], new_sample)
            loss = loss + float(self.user_mode['rkl']) * nkl + loss_lm
        else:
            loss, nll_loss = super().compute_loss(model, net_output, sample)
        return loss, nll_loss
    elif self.is_mode('simple_mask'):
        msk = model.decoder.a_in_b(sample['target'], sample['net_input']['src_tokens'])
        mask = (
            msk
            | ((torch.rand_like(msk.float()) > float(self.user_mode['simple_mask']))
               & (sample['target'].ne(model.decoder.pad))).long()
        ).float()
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1, 1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=None,
            reduce=False,
        )
        loss, nll_loss = (loss * mask.view(-1)).sum(), (nll_loss * mask.view(-1)).sum()
        return loss, nll_loss
    elif self.is_mode('endorsement', 'lse'):
        pe = net_output[0].float()
        ne = net_output[1]['neg_out'][0].float()
        p_mask = net_output[1]['p_mask'].float()
        n_mask = net_output[1]['n_mask'].float()
        loss = ((1 + torch.exp(-pe)).log() * p_mask).sum() + ((1 + torch.exp(ne)).log() * n_mask).sum()
        # pe[pe == float('-inf')] = float('inf')
        # loss = ((1 + torch.exp(-pe)).log() * p_mask).sum() + ((1 + torch.exp(ne)).log() * n_mask).sum()
        return loss, loss
    elif self.has_mode('endorsement'):
        if self.has_mode('pretrain'):
            pe = net_output[0].float()
            ne = net_output[1]['neg_out'][0].float()
            p_mask = net_output[1]['p_mask']
            n_mask = net_output[1]['n_mask']
            n_in_p_mask = net_output[1]['n_in_p_mask'].float()
            ploss = (-F.logsigmoid(pe).masked_select(p_mask)).sum()
            nloss = (-F.logsigmoid(-ne * (1 - n_in_p_mask)).masked_select(n_mask)).sum()
            wloss = net_output[1]['wloss'].float().sum() * 0.05
            freq_loss = net_output[1]['freq_loss'].float().sum() * 1
            nll_loss = loss = ploss + nloss + wloss + freq_loss
            if self.has_mode('self_align'):
                wloss = net_output[1]['self_align_decoder_out'][1]['wloss'].float().sum() * 0.05
                freq_loss = net_output[1]['self_align_decoder_out'][1]['freq_loss'].float().sum() * 1
                ploss = net_output[1]['self_align_decoder_out'][1]['ploss']
                nloss = net_output[1]['self_align_decoder_out'][1]['nloss']
                align_loss = ploss + nloss + freq_loss
                loss = nll_loss = loss + align_loss
            return loss, nll_loss
        elif self.has_mode('drop_words'):
            if model.training:
                sample['target'] = net_output[1]['drop_words_target']
            loss, nll_loss = super().compute_loss(model, net_output, sample)
            return loss, nll_loss
        else:
            if model.training:
                m = net_output[1]['edm_decoder_out'][1]['m'].detach()
                s_mask_inf = net_output[1]['edm_decoder_out'][1]['s_mask_inf'].detach()
                max_match = (m + s_mask_inf.unsqueeze(-1)).max(1)[0]
                a, b = (float(self.user_mode['a']), float(self.user_mode['b'])) if 'a' in self.user_mode else (1.0, 0.0)
                weight = (a * (max_match - b)).sigmoid().float()
                # weight = (max_match).sigmoid().float()
                # weight[:] = 1
                # dbg = model.decoder.get_align(sample['net_input']['src_tokens'], sample['target'], net_output[1]['edm_decoder_out'][1])
                if self.has_mode('dbg_log_endorsement'):
                    data = {
                        'src_tokens': model.decoder.get_src_words(sample['net_input']['src_tokens'], ''),
                        'target': model.decoder.get_src_words(sample['target'], ''),
                        'attn': net_output[1]['attn'].data.cpu(),
                        'm': m.data.cpu(),
                    }
                    torch.save(data, open("output/handcraft/dbg_%s.pt" % model.encoder.task.args.user_mode, "wb"))
                if self.has_mode('add_exact'):
                    word_exactness = net_output[1]['word_exactness']
                    exactness = word_exactness[sample['target']]
                    tgt_notin_src = 1 - model.decoder.a_in_b(sample['target'], sample['net_input']['src_tokens'])
                    weight_exact = 1 - tgt_notin_src.float() * exactness
                    weight = weight * weight_exact
                if self.has_mode('hard_weight'):
                    weight = (weight > torch.rand_like(weight)).float()
                loss, nll_loss = self.get_elementwise_loss(model, net_output, sample)
                target_mask = sample['target'].ne(self.padding_idx)
                loss = weight.masked_select(target_mask) * loss.masked_select(target_mask.view(-1))
                if self.has_mode('sent_weight'):
                    sweight = (((weight * target_mask.float()).sum(1) / target_mask.float().sum(1)).unsqueeze(1)
                               * target_mask.float()).masked_select(target_mask)
                    loss = loss * sweight
                loss = loss.sum()
                nll_loss = (weight.masked_select(target_mask) * nll_loss.masked_select(target_mask.view(-1))).sum()
                # loss = nll_loss = loss.masked_select(target_mask.view(-1)).sum()
                # loss, nll_loss = super().compute_loss(model, net_output, sample)
                return loss, nll_loss
            else:
                pass
    elif self.is_mode('endorsement', 'bak'):
        pe = net_output[0].float()
        ne = net_output[1]['neg_out'][0].float()
        p_mask = net_output[1]['p_mask']
        n_mask = net_output[1]['n_mask']
        # loss = ((1 + torch.exp(-pe)).log() * p_mask).sum() + ((1 + torch.exp(ne)).log() * n_mask).sum()
        loss = (-F.logsigmoid(pe).masked_select(p_mask)).sum() + (-F.logsigmoid(-ne).masked_select(n_mask)).sum()
        return loss, loss
    elif self.is_mode('attn_endorse'):
        weight = net_output[1]['attn'].max(2)[0].detach().float()  # .pow(0.5)
        loss, nll_loss = self.get_elementwise_loss(model, net_output, sample)
        target_mask = sample['target'].ne(self.padding_idx)
        loss = (weight.masked_select(target_mask) * loss.masked_select(target_mask.view(-1))).sum()
        nll_loss = (weight.masked_select(target_mask) * nll_loss.masked_select(target_mask.view(-1))).sum()
        return loss, nll_loss

    loss, nll_loss = super().compute_loss(model, net_output, sample)
    if self.is_mode('rl_word'):
        rl_loss = net_output[1]['encoder_out']['rl_loss'].sum()
        loss = loss + rl_loss
    elif self.has_mode('sep_lm', 'sep_lm1', 'sep_lm2', 'sep_lm3'):
        lm_loss, _ = super().compute_loss(model, net_output[1]['lm_out'], sample)
        if self.has_mode('only_lm'):
            loss = lm_loss
        else:
            loss = loss + lm_loss
    elif self.is_mode('ignore_batch'):
        if model.batch_id % 10 == 1:
            loss, nll_loss = loss * 0, nll_loss * 0
    elif self.is_mode('gated'):
        # gs = torch.stack([x.g.squeeze(-1) for x in net_output[1]['inner_states'][1:]])
        gs = net_output[1]['inner_states'][1].g.squeeze(-1)
        lg = gs.float().sum()
        # print(loss.data.item(), lg.data.item())
        loss = loss + float(self.user_mode['gated']) * lg
    elif self.has_mode('add_lm'):
        lm_loss, lm_nll_loss = super().compute_loss(model, net_output[1]['lm_out'], sample)
        loss = loss + lm_loss
    elif self.is_mode('decomposable1'):
        gs = net_output[1]['inner_states'][-1].g.squeeze(-1)
        lg = gs.float().sum()
        loss = loss + float(self.user_mode['decomposable']) * lg
    elif self.is_mode('rl_edm'):
        if self.training:
            scores = net_output[1]['word_scores'].float()
            lprob = net_output[1]['word_lprob']
            tokens = net_output[1]['word_tokens']
            mask = tokens.ne(model.decoder.pad).float()
            b = 2.0
            # idx = (tokens == model.decoder.model.decoder.dictionary.indices['town'])
            # lp = lprob[idx]
            # rl_loss = -(-1 * lp).sum() * 100
            # rl_loss = -(torch.min(scores - b, 0.1 * (scores - b)) * mask * lprob).sum()
            # 2 * sigmoid(0.5 * x - 2) - 1 from -6 to 19
            # rl_loss = -(((scores * 0.5 - 2).sigmoid() * 2 - 1) * mask * lprob).sum()
            mask = ((scores < b) & tokens.ne(model.decoder.pad))
            rl_loss = -((scores[mask] - b) * lprob[mask]).sum() * 0.01
            loss = loss + rl_loss

    if self.has_mode('dbg_log_endorsement'):
        data = {
            'src_tokens': model.decoder.get_src_words(sample['net_input']['src_tokens'], ''),
            'target': model.decoder.get_src_words(sample['target'], ''),
            'attn': net_output[1]['attn'].data.cpu(),
        }
        torch.save(data, open("output/handcraft/dbg_%s.pt" % model.encoder.task.args.user_mode, "wb"))
    return loss, nll_loss
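# The 'lse' and 'bak' endorsement branches above are the same binary objective written two
# ways: log(1 + exp(-x)) == softplus(-x) == -logsigmoid(x). A minimal check with random
# values (F is torch.nn.functional, as elsewhere in this file):
def _endorsement_loss_equivalence_demo():
    x = torch.randn(5)
    assert torch.allclose((1 + torch.exp(-x)).log(), -F.logsigmoid(x), atol=1e-6)
    assert torch.allclose((1 + torch.exp(-x)).log(), F.softplus(-x), atol=1e-6)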
def forward(self, model, sample, reduce=True, log_pred=False):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    net_output = model(**sample['net_input'])
    logits = model.get_logits(net_output).float()            # self-supervised prediction
    target = model.get_targets(sample, net_output)           # self-supervised target
    logits_ali = model.get_ali_logits(net_output)            # alignment prediction
    target_ali = model.get_ali_targets(sample, net_output)   # alignment target

    weights = None
    if hasattr(model, 'get_target_weights') and not self.infonce:
        weights = model.get_target_weights(target, net_output)
        if torch.is_tensor(weights):
            weights = weights.float()

    if self.infonce:
        loss = F.cross_entropy(
            logits,
            target,
            reduction="sum" if reduce else "none",
        )
        lprobs = utils.log_softmax(logits_ali.float(), dim=-1)
        loss_ali, _ = label_smoothed_nll_loss(
            lprobs,
            target_ali,
            0.0,
            reduce=reduce,  # label_smoothed_nll_loss expects a bool, not a reduction string
        )
    else:
        loss = F.binary_cross_entropy_with_logits(
            logits,
            target.float(),
            weights,
            reduction="sum" if reduce else "none",
        )

    sample_size = target.numel() if self.infonce else target.long().sum().item()

    # note: loss_ali is only computed in the infonce branch above
    losses = [loss.detach().clone(), loss_ali.detach().clone()]
    loss += 0.001 * loss_ali

    if self.loss_weights is not None:
        assert hasattr(model, "get_extra_losses")
        extra_losses = model.get_extra_losses(net_output)
        if torch.is_tensor(extra_losses):
            extra_losses = [extra_losses]
        if len(self.loss_weights) == 1 and len(extra_losses) != 1:
            self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
        assert len(extra_losses) == len(self.loss_weights), \
            f'{len(extra_losses)}, {len(self.loss_weights)}'
        for p, coef in zip(extra_losses, self.loss_weights):
            if coef != 0 and p is not None:
                p = coef * p.float() * sample_size
                loss += p
                losses.append(p)

    logging_output = {
        'loss': loss.item(),
        'ntokens': sample_size,
        'nsentences': sample['id'].numel(),
        'sample_size': sample_size,
    }

    for lk in self.log_keys:
        if lk in net_output:
            logging_output[lk] = float(net_output[lk])

    if len(losses) > 1:
        for i, l in enumerate(losses):
            logging_output[f'loss_{i}'] = l.item()

    if self.infonce:
        with torch.no_grad():
            if logits.numel() == 0:
                corr = 0
                count = 0
                corr_ali = 0
                count_ali = 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                corr_ali = (logits_ali.argmax(1) == target_ali).sum()
                count_ali = target_ali.numel()
            logging_output["correct"] = corr
            logging_output["count"] = count
            logging_output["correct_ali"] = corr_ali.data
            logging_output["count_ali"] = count_ali

    if log_pred:
        logging_output['logits'] = logits.cpu().numpy()
        logging_output['target'] = target.cpu().numpy()

    return loss, sample_size, logging_output
def compute_encoder_classification_loss(self, net_input, net_output, reduce=True,
                                        classification_step=True,
                                        language_classifier_one_vs_rest=0,
                                        vocab_size=256000):
    # net_input["src_tokens"] is B x T.
    # Take the first src token (the src language ID); label 0 is reserved for padding.
    src_lang_target = net_input["src_tokens"][:, 0] - vocab_size + 1

    encoder_classification_out = net_output[1]["classification_out"]
    max_len, bsz, num_total_labels = encoder_classification_out.shape
    lang_target_padding = 0

    if not torch.all(src_lang_target > 0):
        print("Violating 1")
        print(net_input["src_tokens"])
        print(src_lang_target)
        exit()

    # print("pred shape 1", F.log_softmax(src_lang_pred.float(), dim=-1).shape)
    # print("pred shape 2", model.get_normalized_probs(src_lang_pred, log_probs=True).shape)  # <-- this eats the 1st dim
    lprobs = F.log_softmax(encoder_classification_out.float(), dim=-1)  # softmax
    target = src_lang_target.repeat(max_len, 1)  # B --> T x B

    # Get indices
    src_pad_idx = net_input["src_tokens"].eq(self.padding_idx).transpose(0, 1)
    src_one_lang_idx = target == language_classifier_one_vs_rest
    # print("ONE LANG", src_one_lang_idx)
    # print("OTHER LANG", ~src_one_lang_idx)
    if language_classifier_one_vs_rest != 0:
        # Change target to binary
        target[torch.logical_and(src_one_lang_idx, ~src_pad_idx)] = 1
        target[torch.logical_and(~src_one_lang_idx, ~src_pad_idx)] = 2
    target[src_pad_idx] = lang_target_padding

    if not torch.all(target < num_total_labels):
        print("Violating 2")
        # print("src", net_input["src_tokens"])
        print("target", target)
        exit()

    if self.ignore_prefix_size > 0:
        if getattr(lprobs, "batch_first", False):
            lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
            target = target[:, self.ignore_prefix_size:].contiguous()
        else:
            lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
            target = target[self.ignore_prefix_size:, :].contiguous()

    lprobs, target = lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    if classification_step:
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=lang_target_padding,
            reduce=reduce,
        )
    else:
        loss, nll_loss = self.nll_loss_rev(
            lprobs,
            target,
            self.eps,
            ignore_index=lang_target_padding,
            reduce=reduce,
        )

    # Per-language accuracy statistics
    stats_per_lang = {}
    unique_targets = torch.unique(target, sorted=True)
    for id in unique_targets:
        mask = target.eq(id)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        n_total = torch.sum(mask)
        stats_per_lang[utils.item(id.data)] = [
            utils.item(n_correct.data),
            utils.item(n_total.data),
        ]

    # Calc overall accuracy
    mask = target.ne(self.padding_idx)
    n_correct = torch.sum(
        lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
    )
    n_total = torch.sum(mask)

    return loss, nll_loss, n_correct, n_total, stats_per_lang
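# A minimal sketch of how the language-classification target above is built, with invented
# toy ids (vocab_size=10, language-ID tokens 10 and 11 in the first source position, pad id
# assumed to be 1; torch is already imported at module level):
def _encoder_classification_target_demo():
    src_tokens = torch.tensor([[10, 4, 1],                    # language 1, last position padded
                               [11, 5, 6]])                   # language 2
    vocab_size = 10
    src_lang_target = src_tokens[:, 0] - vocab_size + 1       # tensor([1, 2]); 0 stays reserved for padding
    target = src_lang_target.repeat(src_tokens.size(1), 1)    # broadcast to T x B
    src_pad_idx = src_tokens.eq(1).transpose(0, 1)            # T x B padding positions
    target[src_pad_idx] = 0                                   # padded frames get the ignore label
    # with language_classifier_one_vs_rest = 1, non-pad positions would instead be
    # binarised to 1 (the chosen language) vs. 2 (every other language).
    return target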