def test_loss_fn(self):
    model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)
    input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
    target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
    decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
    lm_labels = target_ids[:, 1:].clone()  # why clone?
    model_computed_loss = model(
        input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
    ).loss

    logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
    lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
    smoothed_loss, nll_loss = label_smoothed_nll_loss(
        lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
    )
    with self.assertRaises(AssertionError):
        # TODO: understand why this breaks
        self.assertEqual(nll_loss, model_computed_loss)

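# All of the snippets below call a shared `label_smoothed_nll_loss` helper. For reference, here is a
# minimal sketch of that helper, following the fairseq-style implementation used by the transformers
# seq2seq examples. Most snippets expect the summed `(loss, nll_loss)` pair returned here; the
# relation-scoring `forward` further down assumes a per-element variant instead.
import torch


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed NLL over log-probabilities, ignoring padding positions (from fairseq)."""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
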
def _step(self, batch: dict) -> Tuple: pad_token_id = self.tokenizer.pad_token_id src_ids, src_mask = batch["input_ids"], batch["attention_mask"] if isinstance(self.model, T5ForConditionalGeneration): tgt_ids = batch["labels"] decoder_input_ids = self.model._shift_right(tgt_ids) else: #decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) y = batch["labels"] decoder_input_ids = y[:, :-1].contiguous() tgt_ids = y[:, 1:].clone() if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero batch["decoder_input_ids"] = decoder_input_ids self.save_readable_batch(batch) outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) lm_logits = outputs[0] if self.hparams.label_smoothing == 0: # Same behavior as modeling_bart.py, besides ignoring pad_token_id ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) assert lm_logits.shape[-1] == self.vocab_size loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) else: lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) loss, nll_loss = label_smoothed_nll_loss( lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id) return (loss, )
def _step(self, batch: dict) -> Tuple: pad_token_id = self.tokenizer.pad_token_id src_ids, src_mask = batch["input_ids"], batch["attention_mask"] tgt_ids = batch["labels"] if isinstance(self.model, T5ForConditionalGeneration): decoder_input_ids = self.model._shift_right(tgt_ids) else: decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) lm_logits = outputs[0] if self.hparams.label_smoothing == 0: # Same behavior as modeling_bart.py, besides ignoring pad_token_id ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) assert lm_logits.shape[-1] == self.vocab_size loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) else: lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) loss, nll_loss = label_smoothed_nll_loss( lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id) return (loss, )
def _step(self, batch: dict) -> Tuple:
    pad_token_id = self.tokenizer.pad_token_id
    source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(target_ids)
        lm_labels = target_ids
    else:
        decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
        lm_labels = target_ids[:, 1:].clone()  # why clone?

    outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        lm_logits = outputs[0]
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
        )
    return (loss,)

def _step(self, batch: dict) -> Tuple:
    pad_token_id = self.tokenizer.pad_token_id
    # source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
    source_ids, source_mask, target_ids, topic_p = (
        batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"], batch['topic_p']
    )
    decoder_input_ids = target_ids[:, :-1].contiguous()
    lm_labels = target_ids[:, 1:].clone()
    outputs = self(
        source_ids,
        attention_mask=source_mask,
        decoder_input_ids=decoder_input_ids,
        topic_p=topic_p,
        use_cache=False,
    )

    # calculate loss
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        lm_logits = outputs[0]
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
        )
    return (loss,)

def forward(
    self,
    input_ids,
    attention_mask=None,
    encoder_outputs=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    decoder_cached_states=None,
    use_cache=False,
    is_training=False,
):
    if is_training:
        _decoder_input_ids = shift_tokens_right(decoder_input_ids, self.config.pad_token_id)
    else:
        _decoder_input_ids = decoder_input_ids

    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        encoder_outputs=encoder_outputs,
        decoder_input_ids=_decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_cached_states=decoder_cached_states,
        use_cache=use_cache,
    )
    lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
    if is_training:
        # loss_fct = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.config.pad_token_id)
        # loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), decoder_input_ids.view(-1))
        lprobs = F.log_softmax(lm_logits, dim=-1)
        loss, _ = label_smoothed_nll_loss(
            lprobs, decoder_input_ids, epsilon=0.1, ignore_index=self.config.pad_token_id
        )
        return loss
    return (lm_logits,) + outputs[1:]

def _compute_loss(self, model, inputs):
    inputs = copy.deepcopy(inputs)
    if self.args.label_smoothing == 0:
        if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
            # force training to ignore pad token
            labels = inputs.pop("labels")
            logits = model(**inputs, use_cache=False)[0]
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
        else:
            # compute usual loss via models
            loss, logits = model(**inputs, use_cache=False)[:2]
    else:
        # compute label smoothed loss
        labels = inputs.pop("labels")
        logits = model(**inputs, use_cache=False)[0]
        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        loss, _ = label_smoothed_nll_loss(
            lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
        )
    return loss, logits

def forward(self, outs, graph_state, target_rel=None, work=False):
    def get_scores(dep, head):
        head = torch.tanh(self.transfer_head(head))
        dep = torch.tanh(self.transfer_dep(dep))
        head = F.dropout(head, p=self.dropout, training=self.training)
        dep = F.dropout(dep, p=self.dropout, training=self.training)
        dep_num, bsz, _ = dep.size()
        head_num = head.size(0)
        bias_dep = dep.new_ones((dep_num, bsz, 1))
        bias_head = head.new_ones((head_num, bsz, 1))
        # seq_len x bsz x dim
        dep = torch.cat([dep, bias_dep], 2)
        head = torch.cat([head, bias_head], 2)
        # bsz x dep_num x vocab_size x dim
        dep = self.proj(dep).view(dep_num, bsz, self.vocabs['rel'].size, -1).transpose(0, 1).contiguous()
        # bsz x dim x head_num
        head = head.permute(1, 2, 0)
        # bsz x dep_num x vocab_size x head_num
        scores = torch.bmm(
            dep.view(bsz, dep_num * self.vocabs['rel'].size, -1), head
        ).view(bsz, dep_num, self.vocabs['rel'].size, head_num)
        return scores

    scores = get_scores(outs, graph_state).permute(1, 0, 3, 2).contiguous()
    dep_num, bsz, _ = outs.size()
    head_num = graph_state.size(0)
    log_probs = F.log_softmax(scores, dim=-1)
    _, rel = torch.max(log_probs, -1)
    if work:
        # dep_num x bsz x head x vocab
        return log_probs

    rel_mask = torch.eq(target_rel, self.vocabs['rel'].token2idx(NIL)) + torch.eq(
        target_rel, self.vocabs['rel'].token2idx(PAD)
    )
    rel_acc = (torch.eq(rel, target_rel).float().masked_fill_(rel_mask, 0.)).sum().item()
    rel_tot = rel_mask.numel() - rel_mask.float().sum().item()
    if not self.training:
        print('rel acc %.3f' % (rel_acc / rel_tot))
    rel_loss = label_smoothed_nll_loss(
        log_probs.view(-1, self.vocabs['rel'].size), target_rel.view(-1), 0.
    ).view(dep_num, bsz, head_num)
    rel_loss = rel_loss.masked_fill_(rel_mask, 0.).sum((0, 2))
    return rel_loss

def _compute_loss(self, logits, labels, ignore_index):
    if self.args.label_smoothing == 0:
        # Same behavior as modeling_bart.py
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        assert logits.shape[-1] == self.model.config.vocab_size
        loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
        )
    return loss

def _training_step(
    self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
) -> float:
    model.train()
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(self.args.device)

    # Our model outputs do not work with DataParallel, so forcing return tuple.
    if isinstance(model, nn.DataParallel):
        inputs["return_tuple"] = True

    if self.label_smoothing == 0:
        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
    else:
        labels = inputs.pop("labels")
        labels[labels == -100] = model.config.pad_token_id
        outputs = model(**inputs)
        lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, labels, self.label_smoothing, ignore_index=model.config.pad_token_id
        )

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training
    if self.args.gradient_accumulation_steps > 1:
        loss = loss / self.args.gradient_accumulation_steps

    if self.args.fp16:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    return loss.item()

def _step(self, batch: dict) -> Tuple: pad_token_id = self.tokenizer.pad_token_id src_ids, src_mask = batch["input_ids"], batch["attention_mask"] tgt_ids = batch["labels"] if isinstance(self.model, T5ForConditionalGeneration): decoder_input_ids = self.model._shift_right(tgt_ids) else: decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero batch["decoder_input_ids"] = decoder_input_ids self.save_readable_batch(batch) outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) lm_logits = outputs[0] if self.hparams.label_smoothing == 0: # Same behavior as modeling_bart.py, besides ignoring pad_token_id ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) assert lm_logits.shape[-1] == self.vocab_size loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) else: lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) loss, nll_loss = label_smoothed_nll_loss( lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id ) batch_size = src_ids.shape[0] loss_log = {'ce': loss.item()} if self.unlikelihood_training: ul_loss = self.unlikelihood_loss(decoder_input_ids, lm_logits, self.weight_vector, self.unlikelihood_selective_penalty) ul_loss_weighted = ul_loss * self.unlikelihood_alpha loss_log['ul_logr'] = ul_loss_weighted.item()/batch_size loss += ul_loss_weighted self.losses.append(loss_log) return (loss,)
def step(self, batch):
    source_ids, source_mask, target_ids = (
        batch["input_ids"],
        batch["attention_mask"],
        batch["decoder_input_ids"],
    )
    decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
    lm_labels = target_ids[:, 1:].clone()  # why clone?
    outputs = self(
        source_ids,
        attention_mask=source_mask,
        decoder_input_ids=decoder_input_ids,
        use_cache=False,
    )
    lprobs = F.log_softmax(outputs[0], dim=-1)
    loss = label_smoothed_nll_loss(
        lprobs, lm_labels, epsilon=0.1, ignore_index=self.tokenizer.pad_token_id
    )
    return loss[0]

def _step(self, batch):
    # assert is_frozen(self.teacher)
    pad_token_id = self.tokenizer.pad_token_id
    input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
    decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

    # noinspection PyCallingNonCallable
    lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
        input_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
        output_attentions=False,
        use_cache=False,
    )  # TODO(@sshleifer): return_dict=True cleanup

    # Same cross entropy vs. label smoothing logic as finetune.py
    assert lm_logits.shape[-1] == self.model.config.vocab_size
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        student_lm_loss, _ = label_smoothed_nll_loss(
            lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
        )

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_lm_loss)

    loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
    if self.different_encoder:
        with torch.no_grad():
            teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
                input_ids, attention_mask=src_mask, output_hidden_states=True
            )
        if self.hparams.alpha_encoder_loss > 0:
            loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
        hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy)

    teacher_enc_outputs = (enc_outputs,)
    assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)

    with torch.no_grad():
        tloss, tlogits, tdec_hidden, _ = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=teacher_enc_outputs,
            decoder_input_ids=decoder_input_ids,
            lm_labels=tgt_ids,
            output_hidden_states=True,
        )
    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
    if self.alpha_hid > 0:
        hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
    blended_loss = (
        self.alpha_ce * loss_ce
        + self.alpha_mlm * student_lm_loss
        + self.hparams.alpha_encoder_loss * loss_encoder
        + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
    )
    return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec

def _step(self, batch: dict) -> tuple:
    """Compute the loss for a batch"""
    pad_token_id = self.tokenizer.pad_token_id
    input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, pad_token_id)

    # noinspection PyCallingNonCallable
    student_outputs = self(
        input_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=self.do_calc_hidden_loss,
        output_attentions=False,
        use_cache=False,
    )
    lm_logits = student_outputs["logits"]

    # Same cross entropy vs. label smoothing logic as finetune.py
    assert lm_logits.shape[-1] == self.model.config.vocab_size
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
        student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
    else:
        lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
        student_lm_loss, _ = label_smoothed_nll_loss(
            lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
        )

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_lm_loss)

    teacher_enc_outputs = student_outputs["encoder_last_hidden_state"]  # use this unless self.different_base_models
    hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
    if self.different_encoder:  # compute encoder hidden state loss
        all_teacher_encoder_outputs = self.teacher.get_encoder()(
            input_ids,
            attention_mask=src_mask,
            output_hidden_states=self.do_calc_hidden_loss,
        )
        if self.different_base_models:
            teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
        elif self.do_calc_hidden_loss:
            hid_loss_enc = self.calc_hidden_loss(
                src_mask,
                student_outputs["encoder_hidden_states"],
                all_teacher_encoder_outputs["hidden_states"],
                self.e_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

    teacher_outputs = self.teacher(
        input_ids,
        attention_mask=src_mask,
        encoder_outputs=(teacher_enc_outputs,),
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=self.do_calc_hidden_loss,
        use_cache=False,  # since we are not passing labels, never let this default to True
    )
    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
    if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
        hid_loss_dec = self.calc_hidden_loss(
            dec_mask,
            student_outputs["decoder_hidden_states"],
            teacher_outputs["decoder_hidden_states"],
            self.d_matches,
            normalize_hidden=self.hparams.normalize_hidden,
        )

    blended_loss = (
        self.alpha_ce * loss_ce
        + self.alpha_mlm * student_lm_loss
        + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
    )
    return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

def _step(self, batch):
    # assert is_frozen(self.teacher) copied_decoder_layers
    pad_token_id = self.tokenizer.pad_token_id
    input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, pad_token_id)

    # noinspection PyCallingNonCallable
    lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
        input_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
        output_attentions=False,
        use_cache=False,
    )

    # Same cross entropy vs. label smoothing logic as finetune.py
    assert lm_logits.shape[-1] == self.model.config.vocab_size
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        student_lm_loss, _ = label_smoothed_nll_loss(
            lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
        )

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_lm_loss)

    hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
    if self.different_encoder:  # compute encoder hidden state loss
        with torch.no_grad():
            teacher_enc_hid = self.teacher.get_encoder()(
                input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
            ).hidden_states
        hid_loss_enc = self.calc_hidden_loss(
            src_mask,
            enc_hidden_state,
            teacher_enc_hid,
            self.e_matches,
            normalize_hidden=self.hparams.normalize_hidden,
        )

    with torch.no_grad():
        outputs = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=(enc_outputs,),
            decoder_input_ids=decoder_input_ids,
            lm_labels=labels,
            output_hidden_states=True,
            return_dict=True,
        )
    tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
    if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
        hid_loss_dec = self.calc_hidden_loss(
            dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
        )

    blended_loss = (
        self.alpha_ce * loss_ce
        + self.alpha_mlm * student_lm_loss
        + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
    )
    return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

def _valid_step(self, batch: dict) -> dict:
    pad_token_id = self.tokenizer.pad_token_id
    src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
    tgt_ids = batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(tgt_ids)
    else:
        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
    if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
        batch["decoder_input_ids"] = decoder_input_ids
        self.save_readable_batch(batch)

    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs["logits"]
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        assert lm_logits.shape[-1] == self.vocab_size
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        # sequence-level accuracy: a prediction counts only if every position matches the target
        p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
        acc = p_tensor.sum() / sum(p_tensor.shape)
        lm_loss, ti_loss = loss + 0., loss + 0.
        # ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id, reduction="none")
        # assert lm_logits.shape[-1] == self.vocab_size
        # batch_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        # lm_mask = (batch["loss_labels"].eq(0).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)
        # ti_mask = (batch["loss_labels"].eq(1).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)
        # lm_loss = torch.sum(batch_loss * lm_mask) / (torch.sum(lm_mask) + 1e-20)
        # ti_loss = torch.sum(batch_loss * ti_mask) / (torch.sum(ti_mask) + 1e-20)
        # loss = torch.sum(batch_loss * (lm_mask + ti_mask)) / (torch.sum(lm_mask + ti_mask) + 1e-20)
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
        )
        # mirror the non-smoothed branch so acc, lm_loss, and ti_loss are always defined
        p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
        acc = p_tensor.sum() / sum(p_tensor.shape)
        lm_loss, ti_loss = loss + 0., loss + 0.
    # print(src_ids)
    # print(batch.keys())
    # print(batch["ids"].cpu().numpy())
    # print(batch["loss_labels"].cpu().numpy())
    # for i in range(3):
    #     print(self.tokenizer.convert_ids_to_tokens(batch["input_ids"].cpu().numpy()[i]))
    #     print(self.tokenizer.convert_ids_to_tokens(decoder_input_ids.cpu().numpy()[i]))
    #     print(self.tokenizer.convert_ids_to_tokens(batch["labels"].cpu().numpy()[i]))
    #     # print(self.tokenizer.unk_token, self.tokenizer.pad_token, self.tokenizer.eos_token, self.tokenizer.bos_token, self.tokenizer.cls_token, self.tokenizer.mask_token)
    #     print("=" * 10)
    # print("=" * 30)
    # print(loss, lm_loss, ti_loss)
    # exit()
    return {'loss': loss, 'acc': acc}