def _step(self, batch: dict) -> Tuple:
    pad_token_id = self.tokenizer.pad_token_id
    src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
    tgt_ids = batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(tgt_ids)
    else:
        decoder_input_ids = shift_tokens_right(
            tgt_ids, pad_token_id, self.model.config.decoder_start_token_id)

    if not self.already_saved_batch:
        # This would be slightly better if it only happened on rank zero
        batch["decoder_input_ids"] = decoder_input_ids
        self.save_readable_batch(batch)

    outputs = self(src_ids,
                   attention_mask=src_mask,
                   decoder_input_ids=decoder_input_ids,
                   use_cache=False)
    lm_logits = outputs["logits"]

    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        assert lm_logits.shape[-1] == self.vocab_size
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                           tgt_ids.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, tgt_ids, self.hparams.label_smoothing,
            ignore_index=pad_token_id)
    return (loss,)

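# For reference, a minimal sketch of the three-argument shift_tokens_right
# used above. It mirrors the helper in recent versions of
# transformers.models.bart.modeling_bart (the exact upstream implementation
# may differ slightly): the decoder input is the label sequence shifted one
# position to the right, with decoder_start_token_id prepended and any -100
# label placeholders replaced by pad_token_id.
import torch


def shift_tokens_right_sketch(input_ids: torch.Tensor, pad_token_id: int,
                              decoder_start_token_id: int) -> torch.Tensor:
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 is only a loss-masking sentinel; it must not be fed to the decoder.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted
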
def __call__(self, batch) -> Dict[str, torch.Tensor]:
    if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
        batch = self._encode(batch)
        input_ids, attention_mask, labels = (
            batch["input_ids"],
            batch["attention_mask"],
            batch["labels"],
        )
    else:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        attention_mask = torch.stack([x["attention_mask"] for x in batch])
        labels = torch.stack([x["labels"] for x in batch])

        labels = trim_batch(labels, self.pad_token_id)
        input_ids, attention_mask = trim_batch(
            input_ids, self.pad_token_id, attention_mask=attention_mask)

    if isinstance(self.tokenizer, T5Tokenizer):
        decoder_input_ids = self._shift_right_t5(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)

    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "labels": labels,
    }
    return batch

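# The collator above calls the legacy two-argument shift_tokens_right from
# transformers 3.x-era examples. A minimal sketch of that variant: instead of
# prepending a fixed start token, it wraps the last non-pad token (usually
# <eos>, or the mBART language code) around to position 0.
import torch


def shift_tokens_right_legacy_sketch(input_ids: torch.Tensor,
                                     pad_token_id: int) -> torch.Tensor:
    prev_output_tokens = input_ids.clone()
    # Index of the last non-pad token in each row.
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens
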
def preprocess_data_mbart(data):
    input_text, target_text, tokenizer, args = data

    tokenized_example = tokenizer.prepare_seq2seq_batch(
        src_texts=[input_text],
        tgt_texts=[target_text],
        src_lang=args.src_lang,
        tgt_lang=args.tgt_lang,
        max_length=args.max_seq_length,
        padding="max_length",  # pad_to_max_length=True won't work in this case
        return_tensors="pt",
        truncation=True,
    )

    # Clone before pads are replaced with -100 so the shift sees real pad ids.
    decoder_input_ids = tokenized_example["labels"].clone()
    decoder_input_ids = shift_tokens_right(decoder_input_ids,
                                           tokenizer.pad_token_id)

    labels = tokenized_example["labels"]
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_ids": tokenized_example["input_ids"].squeeze(),
        "attention_mask": tokenized_example["attention_mask"].squeeze(),
        "decoder_input_ids": decoder_input_ids.squeeze(),
        "labels": labels.squeeze(),
    }

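# Hypothetical usage sketch for preprocess_data_mbart. The checkpoint name,
# language codes, and texts below are illustrative assumptions, not taken
# from the snippet above:
#
#   from argparse import Namespace
#   from transformers import MBartTokenizer
#
#   tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
#   args = Namespace(src_lang="en_XX", tgt_lang="ro_RO", max_seq_length=64)
#   features = preprocess_data_mbart(
#       ("UN Chief Says There Is No Military Solution",
#        "Şeful ONU declară că nu există o soluţie militară",
#        tokenizer, args))
#   # features["labels"] uses -100 on pad positions so CrossEntropyLoss
#   # ignores them, while features["decoder_input_ids"] keeps real pad ids.
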
def test_seq2seq_max_target_length(self):
    batch = self.tokenizer.prepare_seq2seq_batch(self.src_text,
                                                 tgt_texts=self.tgt_text,
                                                 max_length=3,
                                                 max_target_length=10)
    batch["decoder_input_ids"] = shift_tokens_right(
        batch.labels, self.tokenizer.pad_token_id)
    self.assertEqual(batch.input_ids.shape[1], 3)
    self.assertEqual(batch.decoder_input_ids.shape[1], 10)
    # max_target_length will default to max_length if not specified
    batch = self.tokenizer.prepare_seq2seq_batch(self.src_text,
                                                 tgt_texts=self.tgt_text,
                                                 max_length=3)
    batch["decoder_input_ids"] = shift_tokens_right(
        batch.labels, self.tokenizer.pad_token_id)
    self.assertEqual(batch.input_ids.shape[1], 3)
    self.assertEqual(batch.decoder_input_ids.shape[1], 3)

def test_batch_fairseq_parity(self):
    batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
        self.src_text, tgt_texts=self.tgt_text, return_tensors="pt")
    batch["decoder_input_ids"] = shift_tokens_right(
        batch.labels, self.tokenizer.pad_token_id)

    for k in batch:
        batch[k] = batch[k].tolist()
    # batch = {k: v.tolist() for k, v in batch.items()}
    # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
    assert batch.input_ids[1][-2:] == [2, EN_CODE]
    assert batch.decoder_input_ids[1][0] == RO_CODE
    assert batch.decoder_input_ids[1][-1] == 2
    assert batch.labels[1][-2:] == [2, RO_CODE]

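# Illustration with assumed toy ids: given labels = [tok1, tok2, 2 (eos),
# RO_CODE], the legacy shift wraps the last non-pad token to the front,
# giving decoder_input_ids = [RO_CODE, tok1, tok2, 2]. That is why the
# asserts above check position 0 for RO_CODE and position -1 for eos (2),
# matching fairseq's mBART target-side convention.
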
def test_seq2seq_dataset_truncation(self, tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    max_src_len = 4
    max_tgt_len = 8
    assert max_len_target > max_tgt_len  # Will be truncated
    assert max_len_source > max_src_len  # Will be truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=max_src_len,
        max_target_length=max_tgt_len,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset,
                            batch_size=2,
                            collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_src_len
        # show that targets are the same len
        assert batch["labels"].shape[1] == max_tgt_len
        if tok_name != MBART_TINY:
            continue
        # check language codes in correct place
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], tokenizer.pad_token_id)
        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
        break  # No need to test every batch

def test_enro_tokenizer_prepare_seq2seq_batch(self):
    batch = self.tokenizer.prepare_seq2seq_batch(
        self.src_text,
        tgt_texts=self.tgt_text,
        max_length=len(self.expected_src_tokens),
        return_tensors="pt")
    batch["decoder_input_ids"] = shift_tokens_right(
        batch.labels, self.tokenizer.pad_token_id)

    self.assertIsInstance(batch, BatchEncoding)

    self.assertEqual((2, 14), batch.input_ids.shape)
    self.assertEqual((2, 14), batch.attention_mask.shape)
    result = batch.input_ids.tolist()[0]
    self.assertListEqual(self.expected_src_tokens, result)
    self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
    # Test that special tokens are reset
    self.assertEqual(self.tokenizer.prefix_tokens, [])
    self.assertEqual(self.tokenizer.suffix_tokens,
                     [self.tokenizer.eos_token_id, EN_CODE])

def test_forward(self):
    src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
    expected_ids = [38, 121, 14, 697, 38848, 0]

    model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(
        src, tgt_texts=tgt, return_tensors="pt").to(torch_device)
    self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())

    desired_keys = {
        "input_ids",
        "attention_mask",
        "labels",
    }
    self.assertSetEqual(desired_keys, set(model_inputs.keys()))
    model_inputs["decoder_input_ids"] = shift_tokens_right(
        model_inputs.labels, self.tokenizer.pad_token_id)
    model_inputs["return_dict"] = True
    model_inputs["use_cache"] = False
    with torch.no_grad():
        outputs = self.model(**model_inputs)
    max_indices = outputs.logits.argmax(-1)
    self.tokenizer.batch_decode(max_indices)

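# Note: the argmax over logits above yields per-position predictions under
# teacher forcing (each step conditions on the gold prefix), not a real
# autoregressive translation; actual decoding would go through
# self.model.generate(...). The test only checks that forward and decode run.
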
def _step(self, batch: dict) -> tuple:
    """Compute the loss for a batch"""
    pad_token_id = self.tokenizer.pad_token_id
    input_ids, src_mask, labels = (batch["input_ids"],
                                   batch["attention_mask"], batch["labels"])
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, pad_token_id)

    # noinspection PyCallingNonCallable
    student_outputs = self(
        input_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=self.do_calc_hidden_loss,
        output_attentions=False,
        use_cache=False,
    )
    lm_logits = student_outputs["logits"]

    # Same cross entropy vs. label smoothing logic as finetune.py
    assert lm_logits.shape[-1] == self.model.config.vocab_size
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
        student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                                   labels.view(-1))
    else:
        lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
        student_lm_loss, _ = label_smoothed_nll_loss(
            lprobs, labels, self.hparams.label_smoothing,
            ignore_index=pad_token_id)

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_lm_loss)

    teacher_enc_outputs = student_outputs[
        "encoder_last_hidden_state"]  # use this unless self.different_base_models
    hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
    if self.different_encoder:  # compute encoder hidden state loss
        all_teacher_encoder_outputs = self.teacher.get_encoder()(
            input_ids,
            attention_mask=src_mask,
            output_hidden_states=self.do_calc_hidden_loss,
        )
        if self.different_base_models:
            teacher_enc_outputs = all_teacher_encoder_outputs[
                "last_hidden_state"]
        elif self.do_calc_hidden_loss:
            hid_loss_enc = self.calc_hidden_loss(
                src_mask,
                student_outputs["encoder_hidden_states"],
                all_teacher_encoder_outputs["hidden_states"],
                self.e_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

    teacher_outputs = self.teacher(
        input_ids,
        attention_mask=src_mask,
        encoder_outputs=(teacher_enc_outputs,),
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=self.do_calc_hidden_loss,
        use_cache=False,  # since we are not passing labels, never let this default to True
    )
    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce = self.calc_ce_loss(dec_mask, lm_logits,
                                teacher_outputs["logits"])
    if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
        hid_loss_dec = self.calc_hidden_loss(
            dec_mask,
            student_outputs["decoder_hidden_states"],
            teacher_outputs["decoder_hidden_states"],
            self.d_matches,
            normalize_hidden=self.hparams.normalize_hidden,
        )

    blended_loss = (self.alpha_ce * loss_ce +
                    self.alpha_mlm * student_lm_loss +
                    self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
    return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

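# Both _step implementations above call label_smoothed_nll_loss. For
# reference, a sketch of that helper as it appears in the transformers
# seq2seq example utils (originally from fairseq); the upstream version may
# differ in minor details:
import torch


def label_smoothed_nll_loss_sketch(lprobs, target, epsilon, ignore_index=-100):
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    # Smoothing term: the sum of -log p over the whole vocabulary per position.
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
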
pad_token_id = tokenizer.pad_token_id
mask_token_id = tokenizer.mask_token_id
# model = BartModel.from_pretrained('./bart-base', return_dict=True)
model = BartForConditionalGeneration.from_pretrained(
    tmp_model_name_path, return_dict=True).to(device)
model.eval()

st, ed = 0, 0
all_loss = []
while ed < len(ipt):
    st, ed = ed, (ed + batch_size) if (ed + batch_size < len(ipt)) else len(ipt)
    input_ids = tokenizer(ipt[st:ed],
                          return_tensors="pt",
                          padding=True,
                          truncation=True,
                          max_length=1000).input_ids.to(device)
    with torch.no_grad():
        src_ids = input_ids
        tgt_ids = tokenizer(opt[st:ed],
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=1000).input_ids.to(device)
        # tgt_ids = torch.cat([torch.zeros([batch_size, 1], dtype=tgt_ids.dtype), tgt_ids], 1)
        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        # print(src_ids, tokenizer.decode(src_ids[0], skip_special_tokens=False))
        # print(tgt_ids, tokenizer.decode(tgt_ids[0], skip_special_tokens=False))
        outputs = model(src_ids,
                        decoder_input_ids=decoder_input_ids,
                        use_cache=False)
        lm_logits = outputs["logits"]
        # print(src_ids.size(), lm_logits.size(), decoder_input_ids.size())
        tmp_batch_size = lm_logits.size()[0]
        pad_pos = torch.eq(tgt_ids, pad_token_id).to(torch.float)
        sen_pos = torch.eq(tgt_ids, mask_token_id).to(torch.float)
        dis_pos = torch.cat(
            [torch.zeros([tmp_batch_size, 1]).to(sen_pos.device),
             sen_pos[:, :-1]], 1)
        loss_mask = 1 - (pad_pos + sen_pos + dis_pos)
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        ce_loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
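        # The snippet ends here. A plausible continuation (an assumption, not
        # part of the original) would compute a per-example masked loss and
        # accumulate it into all_loss:
        #
        #   token_loss = ce_loss_fct(
        #       lm_logits.view(-1, lm_logits.size(-1)),
        #       tgt_ids.view(-1)).view(tmp_batch_size, -1)
        #   seq_loss = (token_loss * loss_mask).sum(1) / loss_mask.sum(1)
        #   all_loss.extend(seq_loss.tolist())
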
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    vis_inputs=None,
    vis_attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    reduce_loss=False,
    **kwargs,
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if labels is not None:
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id,
                self.config.decoder_start_token_id)

    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        vis_inputs=vis_inputs,
        vis_attention_mask=vis_attention_mask,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

    masked_lm_loss = None
    if labels is not None:
        # loss_fct = CrossEntropyLoss()
        if reduce_loss:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
        else:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
        masked_lm_loss = loss_fct(
            lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

    if not return_dict:
        output = (lm_logits,) + outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return Seq2SeqLMOutput(
        loss=masked_lm_loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        cross_attentions=outputs.cross_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
    )

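# Hypothetical usage sketch for the multimodal forward above. The shape of
# vis_inputs is an assumption (whatever visual features the surrounding model
# expects); names below are illustrative:
#
#   out = model(input_ids=input_ids,
#               attention_mask=attention_mask,
#               vis_inputs=(vis_feats, vis_boxes),
#               labels=labels,
#               return_dict=True)
#   loss = out.loss  # scalar if reduce_loss=True,
#                    # per-token (reduction="none") otherwise
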
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    vis_inputs=None,
    vis_attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    # different to other models, Bart automatically creates decoder_input_ids from
    # input_ids if no decoder_input_ids are provided
    if decoder_input_ids is None and decoder_inputs_embeds is None:
        decoder_input_ids = shift_tokens_right(
            input_ids, self.config.pad_token_id,
            self.config.decoder_start_token_id)

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (output_hidden_states
                            if output_hidden_states is not None else
                            self.config.output_hidden_states)
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            vis_inputs=vis_inputs,
            vis_attention_mask=vis_attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    if attention_mask is None:
        attention_mask = input_ids.ne(self.config.pad_token_id).to(
            dtype=torch.float, device=input_ids.device)

    if vis_attention_mask is None:
        # Infer the number of visual tokens V_L from the concatenated
        # encoder output: total length minus the text length L.
        B, L = attention_mask.size()
        V_L = encoder_outputs[0].size(1) - L
        vis_attention_mask = attention_mask.new_ones(B, V_L)
    encoder_attention_mask = torch.cat([attention_mask, vis_attention_mask],
                                       dim=1)

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_outputs[0],
        # encoder_attention_mask=attention_mask,
        encoder_attention_mask=encoder_attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    if not return_dict:
        return decoder_outputs + encoder_outputs

    return Seq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )

def _valid_step(self, batch: dict) -> dict:
    pad_token_id = self.tokenizer.pad_token_id
    src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
    tgt_ids = batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(tgt_ids)
    else:
        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
    if not self.already_saved_batch:
        # This would be slightly better if it only happened on rank zero
        batch["decoder_input_ids"] = decoder_input_ids
        self.save_readable_batch(batch)

    outputs = self(src_ids,
                   attention_mask=src_mask,
                   decoder_input_ids=decoder_input_ids,
                   use_cache=False)
    lm_logits = outputs["logits"]

    # Sequence-level exact match: a sample counts as correct only if the
    # argmax prediction equals the target at every position. (Computed here
    # for both branches; the original set `acc` only in the unsmoothed
    # branch, which would raise a NameError under label smoothing.)
    p_tensor = torch.all(torch.max(lm_logits, 2)[1] == tgt_ids, 1)
    acc = p_tensor.sum() / sum(p_tensor.shape)

    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        assert lm_logits.shape[-1] == self.vocab_size
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]),
                           tgt_ids.view(-1))
        lm_loss, ti_loss = loss + 0., loss + 0.
        # Alternative per-mask variant kept for reference:
        # ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id, reduction="none")
        # assert lm_logits.shape[-1] == self.vocab_size
        # batch_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        # lm_mask = (batch["loss_labels"].eq(0).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)
        # ti_mask = (batch["loss_labels"].eq(1).to(torch.float)[:, None] * torch.ones_like(tgt_ids) * (1 - tgt_ids.eq(pad_token_id).to(torch.float))).view(-1)
        # lm_loss = torch.sum(batch_loss * lm_mask) / (torch.sum(lm_mask) + 1e-20)
        # ti_loss = torch.sum(batch_loss * ti_mask) / (torch.sum(ti_mask) + 1e-20)
        # loss = torch.sum(batch_loss * (lm_mask + ti_mask)) / (torch.sum(lm_mask + ti_mask) + 1e-20)
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, tgt_ids, self.hparams.label_smoothing,
            ignore_index=pad_token_id)
        # The original read `ti_loss = ti_loss + 0.` here, referencing
        # ti_loss before assignment; mirror the branch above instead.
        lm_loss, ti_loss = loss + 0., loss + 0.
    return {'loss': loss, 'acc': acc}

def forward(
    self,
    input_ids=None,
    input_intensity_labels=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    # decoder_input is initialized as an all-zero sequence of length 1,
    # so this branch should not execute either
    if decoder_input_ids is None and decoder_inputs_embeds is None:
        decoder_input_ids = shift_tokens_right(
            input_ids, self.config.pad_token_id,
            self.config.decoder_start_token_id
        )

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        inputs_embeds = self.resnet(inputs_embeds)
        # inputs_embeds = self.batchnorm(inputs_embeds)
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if input_intensity_labels is not None:
        inputs_intensity_embeds = self.shared(input_intensity_labels)
        # inputs_intensity_embeds = self.batchnorm(inputs_intensity_embeds)
        inputs_embeds = torch.cat((inputs_embeds, inputs_intensity_embeds), dim=-1)
        inputs_embeds = self.linear2(inputs_embeds)
    attention_mask = torch.ones(input_shape[0], input_shape[1]).to(self.args.device)

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True.
    # return_dict is set to False in the config, so this branch should not execute
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    # decoder
    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn).
    # The decoder is unrolled for four fixed steps; each step appends its last
    # hidden state to decoder_inputs_embeds and feeds the extended sequence
    # back in (an embedding-level autoregressive loop of fixed length).
    decoder_outputs = ()
    for i in range(4):
        if decoder_input_ids is not None:
            dec_size = decoder_input_ids.size()[:2]
            decoder_attention_mask = torch.ones(dec_size[0], dec_size[1]).to(self.args.device)
        elif decoder_inputs_embeds is not None:
            dec_size = decoder_inputs_embeds.size()[:2]
            decoder_attention_mask = torch.ones(dec_size[0], dec_size[1]).to(self.args.device)
        else:
            decoder_attention_mask = None

        # make masks if user doesn't supply
        # if not use_cache:
        #     if decoder_input_ids is not None:
        #         decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
        #             self.config,
        #             input_ids,
        #             decoder_input_ids=decoder_input_ids,
        #             decoder_padding_mask=decoder_attention_mask,
        #             causal_mask_dtype=self.shared.weight.dtype,
        #         )
        #     else:
        #         decoder_input_embeds, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
        #             self.config,
        #             input_ids,
        #             decoder_input_ids=decoder_inputs_embeds,
        #             decoder_padding_mask=decoder_attention_mask,
        #             causal_mask_dtype=self.shared.weight.dtype,
        #         )
        # else:
        #     decoder_padding_mask, causal_mask = None, None
        assert decoder_input_ids is not None or decoder_inputs_embeds is not None

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Append the newest hidden state and switch to embedding inputs for
        # the next iteration.
        decoder_inputs_embeds = torch.cat(
            (decoder_inputs_embeds, decoder_outputs[0][:, -1:, :]), dim=1)
        decoder_input_ids = None

    # return
    if not return_dict:
        out = decoder_outputs + encoder_outputs
        out = (self.linear(out[0]),) + out
    else:
        out = Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
    return out