def test_seq2seq_max_target_length(self):
    batch = self.tokenizer.prepare_seq2seq_batch(
        self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
    )
    batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    self.assertEqual(batch.input_ids.shape[1], 3)
    self.assertEqual(batch.decoder_input_ids.shape[1], 10)
    # max_target_length will default to max_length if not specified
    batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
    batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    self.assertEqual(batch.input_ids.shape[1], 3)
    self.assertEqual(batch.decoder_input_ids.shape[1], 3)
def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
            decoder_input_ids=None, decoder_attention_mask=None,
            decoder_cached_states=None, use_cache=False, is_training=False):
    if is_training:
        _decoder_input_ids = shift_tokens_right(decoder_input_ids, self.config.pad_token_id)
    else:
        _decoder_input_ids = decoder_input_ids
    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        encoder_outputs=encoder_outputs,
        decoder_input_ids=_decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_cached_states=decoder_cached_states,
        use_cache=use_cache,
    )
    lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
    if is_training:
        # loss_fct = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.config.pad_token_id)
        # loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), decoder_input_ids.view(-1))
        lprobs = F.log_softmax(lm_logits, dim=-1)
        loss, _ = label_smoothed_nll_loss(
            lprobs, decoder_input_ids, epsilon=0.1, ignore_index=self.config.pad_token_id)
        return loss
    return (lm_logits,) + outputs[1:]
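# Several snippets above and below call label_smoothed_nll_loss without showing it.
# The following is a minimal sketch of the fairseq-style helper these snippets appear
# to assume (the exact signature and sum reduction are assumptions, not taken from
# this file):
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed NLL loss; returns (smoothed loss, plain NLL loss)."""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        # Zero out positions that correspond to padding / ignored labels.
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss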
def training_step(self, batch, batch_idx):
    src, tgt = batch[0], batch[1]
    src_input = self.tokenizer.encode_batch(src, max_length=self.max_len)
    tgt_input = self.tokenizer.encode_batch(tgt, max_length=self.max_len)
    input_ids = src_input["input_ids"].to(self.device)
    attention_mask = src_input["attention_mask"].to(self.device)
    labels = tgt_input["input_ids"].to(self.device)
    decoder_input_ids = shift_tokens_right(labels, self.tokenizer.token2idx["[PAD]"])
    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
    )
    lm_logits = outputs[0]
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.token2idx["[PAD]"])
    lm_loss = loss_fn(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
    self.save_model()
    return {"loss": lm_loss}
def convert_to_features(example_batch: Dict[str, Any]) -> Dict[str, Any]:
    input_encodings = tokenizer.batch_encode_plus(
        example_batch["utterance"],
        max_length=32,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    target_encodings = tokenizer.batch_encode_plus(
        example_batch["code"],
        max_length=32,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    labels = target_encodings["input_ids"]
    decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
    labels[labels[:, :] == model.config.pad_token_id] = -100
    encodings = {
        "input_ids": input_encodings["input_ids"].numpy().copy(),
        "attention_mask": input_encodings["attention_mask"].numpy().copy(),
        "decoder_input_ids": decoder_input_ids.numpy().copy(),
        "labels": labels.numpy().copy(),
    }
    return encodings
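# Usage note: convert_to_features above is shaped for batched mapping over a dataset,
# e.g. with the Hugging Face datasets library (the dataset and column names here are
# hypothetical): dataset = dataset.map(convert_to_features, batched=True)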
def _step(self, batch):
    if batch['task'][0] == 'response':
        pad_token_id = self.tokenizer.pad_token_id
        target_ids = batch['target_ids']
        decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)
        outputs = self(input_ids=batch['source_ids'],
                       attention_mask=batch['source_mask'],
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False,
                       task=batch['task'][0])
        lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target_ids, self.hparams.label_smoothing, ignore_index=pad_token_id)
    elif batch['task'][0] in ['cfemotion', 'emotion', 'sentiment']:
        outputs = self(input_ids=batch['source_ids'],
                       attention_mask=batch['source_mask'],
                       lm_labels=batch['label'],
                       task=batch['task'][0])
        loss = outputs[0]
    else:
        raise ValueError('The dataset contains an invalid task.')
    if self.hparams.task_weights:
        loss = self.task_weights[self.tasks.index(batch['task'][0])] * loss
    return loss
def _step(self, batch: dict) -> Tuple:
    pad_token_id = self.tokenizer.pad_token_id
    src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
    tgt_ids = batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(tgt_ids)
    else:
        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        assert lm_logits.shape[-1] == self.vocab_size
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id)
    return (loss,)
def preprocess_data_mbart(data):
    input_text, target_text, tokenizer, args = data
    tokenized_example = tokenizer.prepare_seq2seq_batch(
        src_texts=[input_text],
        tgt_texts=[target_text],
        src_lang=args.src_lang,
        tgt_lang=args.tgt_lang,
        max_length=args.max_seq_length,
        padding="max_length",  # pad_to_max_length=True won't work in this case
        return_tensors="pt",
        truncation=True,
    )
    decoder_input_ids = tokenized_example["labels"].clone()
    decoder_input_ids = shift_tokens_right(decoder_input_ids, tokenizer.pad_token_id)
    labels = tokenized_example["labels"]
    labels[labels == tokenizer.pad_token_id] = -100
    return {
        "input_ids": tokenized_example["input_ids"].squeeze(),
        "attention_mask": tokenized_example["attention_mask"].squeeze(),
        "decoder_input_ids": decoder_input_ids.squeeze(),
        "labels": labels.squeeze(),
    }
def __call__(self, batch) -> Dict[str, torch.Tensor]:
    if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
        batch = self._encode(batch)
        input_ids, attention_mask, labels = (
            batch["input_ids"],
            batch["attention_mask"],
            batch["labels"],
        )
    else:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        attention_mask = torch.stack([x["attention_mask"] for x in batch])
        labels = torch.stack([x["labels"] for x in batch])
        labels = trim_batch(labels, self.pad_token_id)
        input_ids, attention_mask = trim_batch(
            input_ids, self.pad_token_id, attention_mask=attention_mask)
    if isinstance(self.tokenizer, T5Tokenizer):
        decoder_input_ids = self._shift_right_t5(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "labels": labels,
    }
    return batch
def test_shift_tokens_right(self):
    input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
    shifted = shift_tokens_right(input_ids, 1)
    n_pad_before = input_ids.eq(1).float().sum()
    n_pad_after = shifted.eq(1).float().sum()
    self.assertEqual(shifted.shape, input_ids.shape)
    self.assertEqual(n_pad_after, n_pad_before - 1)
    self.assertTrue(torch.eq(shifted[:, 0], 2).all())
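# For reference, a shift_tokens_right consistent with the test above wraps the last
# non-pad token of each row (typically </s>, id 2) around to position 0 and shifts
# everything else one step right. This is a minimal sketch of that behavior, not
# necessarily the exact implementation imported by these snippets:
def shift_tokens_right(input_ids, pad_token_id):
    """Shift input ids one token to the right, wrapping the last non-pad token to position 0."""
    prev_output_tokens = input_ids.clone()
    # Index of the last non-pad token in each row (the </s> position).
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens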
def test_batch_fairseq_parity(self):
    batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
        self.src_text, tgt_texts=self.tgt_text, return_tensors="pt")
    batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    for k in batch:
        batch[k] = batch[k].tolist()
    # batch = {k: v.tolist() for k, v in batch.items()}
    # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
    # batch.decoder_inputs_ids[0][0] ==
    assert batch.input_ids[1][-2:] == [2, EN_CODE]
    assert batch.decoder_input_ids[1][0] == RO_CODE
    assert batch.decoder_input_ids[1][-1] == 2
    assert batch.labels[1][-2:] == [2, RO_CODE]
def _step(self, batch):
    labels = torch.full(
        (len(batch['task']), 1),
        int(batch['task'][0] != 'response'),  # fake
        dtype=torch.long).cuda()
    if batch['task'][0] == 'response':
        pad_token_id = self.tokenizer.pad_token_id
        target_ids = batch['target_ids']
        decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)
        outputs = self(input_ids=batch['source_ids'],
                       attention_mask=batch['source_mask'],
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False,
                       task=batch['task'][0])
        lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target_ids, self.hparams.label_smoothing, ignore_index=pad_token_id)
        x = outputs[1]  # last hidden state
    elif batch['task'][0] in ['cfemotion', 'emotion', 'sentiment']:
        outputs = self(input_ids=batch['source_ids'],
                       attention_mask=batch['source_mask'],
                       lm_labels=batch['label'],
                       task=batch['task'][0])
        loss = outputs[0]
        x = outputs[2]  # last hidden state
    else:
        raise ValueError('The dataset contains an invalid task.')
    eos_mask = batch['source_ids'].eq(self.model.config.eos_token_id)
    if len(torch.unique(eos_mask.sum(1))) > 1:
        raise ValueError('All examples must have the same number of <eos> tokens.')
    sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
    logits = self.discriminator(sentence_representation)
    adv_loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))
    return loss + adv_loss
def test_enro_tokenizer_prepare_seq2seq_batch(self):
    batch = self.tokenizer.prepare_seq2seq_batch(
        self.src_text,
        tgt_texts=self.tgt_text,
        max_length=len(self.expected_src_tokens),
    )
    batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
    self.assertIsInstance(batch, BatchEncoding)
    self.assertEqual((2, 14), batch.input_ids.shape)
    self.assertEqual((2, 14), batch.attention_mask.shape)
    result = batch.input_ids.tolist()[0]
    self.assertListEqual(self.expected_src_tokens, result)
    self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
    # Test that special tokens are reset
    self.assertEqual(self.tokenizer.prefix_tokens, [])
    self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_seq2seq_dataset_truncation(tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    max_src_len = 4
    max_tgt_len = 8
    assert max_len_target > max_src_len  # Will be truncated
    assert max_len_source > max_src_len  # Will be truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes an error
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=max_src_len,
        max_target_length=max_tgt_len,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed
        assert batch["input_ids"].shape[1] == max_src_len
        # show that targets have the same length
        assert batch["labels"].shape[1] == max_tgt_len
        if tok_name != MBART_TINY:
            continue
        # check that language codes are in the correct place
        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
        break  # No need to test every batch
def test_forward(self):
    src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
    expected_ids = [38, 121, 14, 697, 38848, 0]
    model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
    self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
    desired_keys = {
        "input_ids",
        "attention_mask",
        "labels",
    }
    self.assertSetEqual(desired_keys, set(model_inputs.keys()))
    model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
    model_inputs["return_dict"] = True
    model_inputs["use_cache"] = False
    with torch.no_grad():
        outputs = self.model(**model_inputs)
    max_indices = outputs.logits.argmax(-1)
    self.tokenizer.batch_decode(max_indices)
def _step_D(self, batch):
    labels = torch.full(
        (len(batch['task']), 1),
        int(batch['task'][0] == 'response'),  # real
        dtype=torch.long).cuda()
    if batch['task'][0] == 'response':
        pad_token_id = self.tokenizer.pad_token_id
        target_ids = batch['target_ids']
        decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)
        outputs = self(input_ids=batch['source_ids'],
                       attention_mask=batch['source_mask'],
                       decoder_input_ids=decoder_input_ids,
                       use_cache=False,
                       task=batch['task'][0])
    elif batch['task'][0] in ['cfemotion', 'emotion', 'sentiment']:
        outputs = self(
            input_ids=batch['source_ids'],
            attention_mask=batch['source_mask'],
            # lm_labels=batch['label'],
            task=batch['task'][0])
    else:
        raise ValueError('The dataset contains an invalid task.')
    x = outputs[1]  # last hidden state
    eos_mask = batch['source_ids'].eq(self.model.config.eos_token_id)
    if len(torch.unique(eos_mask.sum(1))) > 1:
        raise ValueError('All examples must have the same number of <eos> tokens.')
    sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
    logits = self.discriminator(sentence_representation)
    loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))
    return loss
def main():
    parser = argparse.ArgumentParser(
        description="Convert BART to LongBART. Replaces BART encoder's SelfAttention with LongformerSelfAttention"
    )
    parser.add_argument(
        '--base_model',
        type=str,
        default='facebook/bart-large',
        help='The name or path of the base model you want to convert')
    parser.add_argument(
        '--tokenizer_name_or_path',
        type=str,
        default='facebook/bart-large',
        help='The name or path of the tokenizer')
    parser.add_argument(
        '--save_model_to',
        type=str,
        required=True,
        help='The path to save the converted model')
    parser.add_argument(
        '--attention_window',
        type=int,
        default=512,
        help='attention window size for longformer self attention (one sided)')
    parser.add_argument(
        '--max_pos',
        type=int,
        default=4096 * 4,
        help='maximum encoder positions')
    args = parser.parse_args()
    if not os.path.exists(args.save_model_to):
        os.mkdir(args.save_model_to)
    create_long_model(
        save_model_to=args.save_model_to,
        base_model=args.base_model,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        attention_window=args.attention_window,
        max_pos=args.max_pos)
    tokenizer = BartTokenizer.from_pretrained(args.save_model_to)
    TXT = "My friends are <mask> but they eat too many carbs."
    model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to)
    model.model.encoder.config.gradient_checkpointing = True
    model.model.decoder.config.gradient_checkpointing = True
    data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
    logits = model(input_ids,
                   attention_mask=attention_mask,
                   decoder_input_ids=decoder_input_ids,
                   use_cache=False)[0]
    masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(5)
    print(tokenizer.convert_ids_to_tokens(predictions))
def _step(self, batch):
    # assert is_frozen(self.teacher)
    pad_token_id = self.tokenizer.pad_token_id
    input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
    decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
    # noinspection PyCallingNonCallable
    lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
        input_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
        output_attentions=False,
        use_cache=False,
    )  # TODO(@sshleifer): return_dict=True cleanup
    # Same cross entropy vs. label smoothing logic as finetune.py
    assert lm_logits.shape[-1] == self.model.config.vocab_size
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        student_lm_loss, _ = label_smoothed_nll_loss(
            lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id)

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_lm_loss)

    loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
    if self.different_encoder:
        with torch.no_grad():
            teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
                input_ids, attention_mask=src_mask, output_hidden_states=True)
        if self.hparams.alpha_encoder_loss > 0:
            loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
        hid_loss_enc = self.calc_hidden_loss(
            src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy)
    teacher_enc_outputs = (enc_outputs,)
    assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
    with torch.no_grad():
        tloss, tlogits, tdec_hidden, _ = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=teacher_enc_outputs,
            decoder_input_ids=decoder_input_ids,
            lm_labels=tgt_ids,
            output_hidden_states=True,
        )
    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
    if self.alpha_hid > 0:
        hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
    blended_loss = (self.alpha_ce * loss_ce
                    + self.alpha_mlm * student_lm_loss
                    + self.hparams.alpha_encoder_loss * loss_encoder
                    + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
    return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
def _step(self, batch: dict) -> Tuple:
    pad_token_id = self.tokenizer.pad_token_id
    src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
    tgt_ids = batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(tgt_ids)
    else:
        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
    if not self.already_saved_batch:
        # This would be slightly better if it only happened on rank zero
        batch["decoder_input_ids"] = decoder_input_ids
        self.save_readable_batch(batch)
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        assert lm_logits.shape[-1] == self.vocab_size
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id)
    batch_size = src_ids.shape[0]
    loss_log = {'ce': loss.item()}
    if self.unlikelihood_training_tokens:
        ul_token_loss = self.unlikelihood_loss_token(decoder_input_ids, tgt_ids, lm_logits)
        ul_token_loss_weighted = ul_token_loss * self.unlikelihood_tokens_alpha
        print("*******************")
        print(f"loss: {loss}")
        print(f"UL token loss: {ul_token_loss_weighted}")
        print("*******************")
        loss_log['ul_token'] = ul_token_loss_weighted.item() / batch_size
        loss += ul_token_loss_weighted
    if self.unlikelihood_training_copy:
        ul_copy_loss = self.unlikelihood_loss_copying(
            src_ids, decoder_input_ids, lm_logits, self.unlikelihood_copy_n)
        ul_copy_loss_weighted = self.unlikelihood_copy_alpha * ul_copy_loss
        print("*******************")
        print(f"loss: {loss}")
        print(f"UL copy loss: {ul_copy_loss_weighted}")
        print("*******************")
        loss_log['ul_copy'] = ul_copy_loss_weighted.item() / batch_size
        loss += ul_copy_loss_weighted
    if self.unlikelihood_training:
        ul_loss = self.unlikelihood_loss(
            decoder_input_ids, lm_logits, self.weight_vector, self.unlikelihood_selective_penalty)
        ul_loss_weighted = ul_loss * self.unlikelihood_alpha
        print("*******************")
        print(f"loss: {loss}")
        print(f"UL loss: {ul_loss_weighted}")
        print("*******************")
        loss_log['ul_logr'] = ul_loss_weighted.item() / batch_size
        loss += ul_loss_weighted
    self.losses.append(loss_log)
    return (loss,)
def generate_decode_copy(model, args, input_ids):
    assert input_ids.shape[0] == 1, "Ngram repetition penalty only supported for batch size of 1 (for now)"
    alpha = args.decode_state_change_copy_alpha
    n = args.decode_state_change_copy_n
    temp = model
    model = SummarizationModule(args)
    model.model = temp.to('cuda')
    model.model.train()
    input_ids = input_ids.to('cuda')
    # initialize optimizer
    # optimizer = SGD(model.model.parameters(), lr=1.0)
    # optimizer.zero_grad()
    # first generate output normally
    tgt_ids, logs = model.model.generate(
        input_ids,
        do_sample=False,
        num_beams=1,
        max_length=1024,
        early_stopping=False,
        num_return_sequences=1,
        decoder_start_token_id=model.model.config.pad_token_id)
    # treat generated text as decoder target
    pad_token_id = model.model.config.pad_token_id
    decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
    outputs = model(input_ids, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    input_ids = input_ids.tolist()
    decoder_input_ids = decoder_input_ids.tolist()
    tgt_ids = tgt_ids.tolist()
    open('input_ids.txt', 'w').write(json.dumps(input_ids))
    open('decoder_input_ids.txt', 'w').write(json.dumps(decoder_input_ids))
    open('tgt_ids.txt', 'w').write(json.dumps(tgt_ids))
    return None
    # NOTE: everything below is unreachable because of the early return above
    # apply repetition penalty
    loss = model.unlikelihood_loss_copying(input_ids, decoder_input_ids, lm_logits, n) * alpha
    # loss.backward()
    # optimizer.step()
    # now re-generate output
    if args.decode_method == 'greedy':
        gen_ids, logs = model.model.generate(
            input_ids,
            do_sample=False,
            ngram_copy_penalty=args.decode_ngram_copy_penalty,
            ngram_copy_weight=args.decode_ngram_copy_weight,
            max_length=args.max_target_length,
            early_stopping=False,
            num_return_sequences=1,
            decoder_start_token_id=model.model.config.pad_token_id)
    elif args.decode_method == 'beam':
        gen_ids, logs = model.model.generate(
            input_ids,
            ngram_copy_penalty=args.decode_ngram_copy_penalty,
            ngram_copy_weight=args.decode_ngram_copy_weight,
            num_beams=args.decode_num_beams,
            max_length=args.max_target_length,
            early_stopping=False,
            num_return_sequences=1,
            decoder_start_token_id=model.model.config.pad_token_id)
    else:
        gen_ids, logs = model.model.generate(
            input_ids,
            do_sample=True,
            ngram_copy_penalty=args.decode_ngram_copy_penalty,
            ngram_copy_weight=args.decode_ngram_copy_weight,
            top_p=args.decode_p,
            max_length=args.max_target_length,
            early_stopping=False,
            num_return_sequences=1,
            decoder_start_token_id=model.model.config.pad_token_id)
    return gen_ids, logs
def _step(self, batch):
    # assert is_frozen(self.teacher) copied_decoder_layers
    pad_token_id = self.tokenizer.pad_token_id
    input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, pad_token_id)
    # noinspection PyCallingNonCallable
    lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
        input_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
        output_attentions=False,
        use_cache=False,
    )
    # Same cross entropy vs. label smoothing logic as finetune.py
    assert lm_logits.shape[-1] == self.model.config.vocab_size
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
    else:
        lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        student_lm_loss, _ = label_smoothed_nll_loss(
            lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id)

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_lm_loss)

    hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
    if self.different_encoder:  # compute encoder hidden state loss
        with torch.no_grad():
            teacher_enc_hid = self.teacher.get_encoder()(
                input_ids,
                attention_mask=src_mask,
                output_hidden_states=True,
                return_dict=True).hidden_states
        hid_loss_enc = self.calc_hidden_loss(
            src_mask,
            enc_hidden_state,
            teacher_enc_hid,
            self.e_matches,
            normalize_hidden=self.hparams.normalize_hidden,
        )
    with torch.no_grad():
        outputs = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=(enc_outputs,),
            decoder_input_ids=decoder_input_ids,
            lm_labels=labels,
            output_hidden_states=True,
            return_dict=True,
        )
        tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
    if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
        hid_loss_dec = self.calc_hidden_loss(
            dec_mask, dec_hidden, tdec_hidden, self.d_matches,
            normalize_hidden=self.hparams.normalize_hidden)
    blended_loss = (self.alpha_ce * loss_ce
                    + self.alpha_mlm * student_lm_loss
                    + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
    return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
def convert_examples_to_features_question_generation(
    examples,
    tokenizer,
    max_length=512,
    max_length_label=32,
    bart=False,
):
    """
    This function converts a list of examples into features that can be used
    as inputs for the question generation model.

    INPUTS:
    - examples: list of object <class: InputExample>, examples to convert.
    - tokenizer: torch tokenizer object, tokenize the examples.
    - max_length: int, size of the maximum input sentence (list of tokens).
    - max_length_label: int, size of the maximum label sentence.
    - bart: boolean, saying whether the model is BART or not.

    OUTPUTS:
    - features: list of object <class: InputFeatures>, list of features for model.
    """
    processor = DataProcessor()
    features = []
    pad_token = tokenizer.pad_token_id
    for (ex_index, example) in enumerate(examples):
        ######## ENCODING INPUT ########
        # This will encode both answer and context with a separator.
        inputs = tokenizer.encode_plus(example.answer,
                                       example.context,
                                       add_special_tokens=True,
                                       max_length=max_length,
                                       truncation='only_second')
        input_ids = inputs["input_ids"]
        token_type_ids = [0] * (len(tokenizer.encode(example.answer)) + 1)
        token_type_ids += [1] * (len(input_ids) - len(token_type_ids))
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1] * len(input_ids)

        ######## PADDING ########
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [pad_token] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length
        assert len(input_ids) == max_length, \
            "Error with input ids length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, \
            "Error with mask length {} vs {}".format(len(attention_mask), max_length)
        assert len(token_type_ids) == max_length, \
            "Error with token length {} vs {}".format(len(token_type_ids), max_length)

        ######## ENCODING LABEL ########
        if example.question is not None:
            if bart:
                label = tokenizer.encode_plus(
                    example.question,
                    max_length=max_length_label,
                    truncation=True,
                )
                label_ids = label["input_ids"]
                padding_length = max_length_label - len(label_ids)
                label_ids = label_ids + [-100] * padding_length
                decoder_input_ids = shift_tokens_right(
                    torch.tensor(label_ids).unsqueeze(0), -100).squeeze(0).tolist()
                decoder_input_ids = [
                    x if x != -100 else pad_token for x in decoder_input_ids
                ]
            else:
                label_ids = tokenizer.encode(
                    example.question,
                    add_special_tokens=True,
                    max_length=max_length_label,
                    truncation=True,
                )
                decoder_input_ids = label_ids
                padding_length = max_length_label - len(label_ids)
                label_ids = label_ids + [-100] * padding_length
                decoder_input_ids = decoder_input_ids + [pad_token] * padding_length
            decoder_attention_mask = [1] * max_length_label
            assert len(label_ids) == max_length_label, \
                "Error with label length {} vs {}".format(len(label_ids), max_length_label)
            assert len(decoder_input_ids) == max_length_label, \
                "Error with decoder input length {} vs {}".format(len(decoder_input_ids), max_length_label)
            assert len(decoder_attention_mask) == max_length_label, \
                "Error with decoder mask length {} vs {}".format(len(decoder_attention_mask), max_length_label)
            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    label=label_ids,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                ))
        else:
            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    label=None,
                    decoder_input_ids=None,
                    decoder_attention_mask=None,
                ))
    return features
def _step(self, batch: dict) -> tuple:
    """Compute the loss for a batch"""
    pad_token_id = self.tokenizer.pad_token_id
    input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
    if isinstance(self.model, T5ForConditionalGeneration):
        decoder_input_ids = self.model._shift_right(labels)
    else:
        decoder_input_ids = shift_tokens_right(labels, pad_token_id)
    # noinspection PyCallingNonCallable
    student_outputs = self(
        input_ids,
        attention_mask=src_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=self.do_calc_hidden_loss,
        output_attentions=False,
        use_cache=False,
    )
    lm_logits = student_outputs.logits
    # Same cross entropy vs. label smoothing logic as finetune.py
    assert lm_logits.shape[-1] == self.model.config.vocab_size
    if self.hparams.label_smoothing == 0:
        # Same behavior as modeling_bart.py, besides ignoring pad_token_id
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
    else:
        lprobs = F.log_softmax(lm_logits, dim=-1)
        student_lm_loss, _ = label_smoothed_nll_loss(
            lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id)

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_lm_loss)

    teacher_enc_outputs = student_outputs.encoder_last_hidden_state  # use this unless self.different_base_models
    hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
    if self.different_encoder:  # compute encoder hidden state loss
        all_teacher_encoder_outputs = self.teacher.get_encoder()(
            input_ids,
            attention_mask=src_mask,
            output_hidden_states=self.do_calc_hidden_loss,
        )
        if self.different_base_models:
            teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
        elif self.do_calc_hidden_loss:
            hid_loss_enc = self.calc_hidden_loss(
                src_mask,
                student_outputs.encoder_hidden_states,
                all_teacher_encoder_outputs.hidden_states,
                self.e_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )
    teacher_outputs = self.teacher(
        input_ids,
        attention_mask=src_mask,
        encoder_outputs=(teacher_enc_outputs,),
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=self.do_calc_hidden_loss,
        use_cache=False,  # since we are not passing labels, never let this default to True
    )
    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
    if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
        hid_loss_dec = self.calc_hidden_loss(
            dec_mask,
            student_outputs.decoder_hidden_states,
            teacher_outputs.decoder_hidden_states,
            self.d_matches,
            normalize_hidden=self.hparams.normalize_hidden,
        )
    blended_loss = (self.alpha_ce * loss_ce
                    + self.alpha_mlm * student_lm_loss
                    + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec))
    return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
def train_epoch(self, batch_size, label_smooth_epsilon, weight, text_type):
    assert 'train' in self._dataset
    random.shuffle(self._dataset['train'])
    print_train_loss_ssl = 0.0
    print_train_loss = 0.0
    print_train_loss_net = 0.0
    num = 0
    for i in trange(0, len(self._dataset['train']), batch_size, desc='BART Training'):
        self._model.split_to_gpus(n_gpus=min(2, torch.cuda.device_count()))
        self._model.train()
        self._ssl_model.train()
        batch = self._dataset['train'][i:i + batch_size]
        self._optimizer.zero_grad()
        for j in range(0, len(batch), LIL_BATCH_SIZE):
            lil_batch = batch[j:j + LIL_BATCH_SIZE]
            src_lengths = torch.tensor([len(t.src_tokens) for t in lil_batch])
            src_tokens = collate_tokens(
                [t.src_tokens for t in lil_batch],
                pad_idx=self._model.dictionary.pad())
            tgt_tokens = collate_tokens(
                [t.tgt_tokens for t in lil_batch],
                pad_idx=self._model.dictionary.pad())
            loss_net = self._get_label_smoothed_nll_loss(
                src_lengths=src_lengths,
                src_tokens=src_tokens,
                tgt_tokens=tgt_tokens,
                epsilon=label_smooth_epsilon)
            # SSL training
            if text_type == 0:
                texts = [t.texts for t in lil_batch]
                ssl_batch = tokenizer.batch_encode_plus(
                    texts,
                    padding=True,
                    max_length=SRC_MAX_LEN,
                    truncation=True,
                    return_tensors='pt')
                ssl_input_ids, ssl_label_ids = mask_tokens(
                    ssl_batch.input_ids, tokenizer, MLM_PROBABILITY)
                ssl_decoder_input_ids = shift_tokens_right(
                    ssl_batch.input_ids, self._ssl_model.config.pad_token_id)
            elif text_type == 1:
                ssl_input_ids, ssl_label_ids = mask_tokens(
                    src_tokens, tokenizer, MLM_PROBABILITY)
                ssl_decoder_input_ids = shift_tokens_right(
                    src_tokens, self._ssl_model.config.pad_token_id)
            else:
                ssl_input_ids, ssl_label_ids = mask_tokens(
                    tgt_tokens, tokenizer, MLM_PROBABILITY)
                ssl_decoder_input_ids = shift_tokens_right(
                    tgt_tokens, self._ssl_model.config.pad_token_id)
            ssl_input_ids, ssl_label_ids, ssl_decoder_input_ids = (
                ssl_input_ids.cuda(), ssl_label_ids.cuda(), ssl_decoder_input_ids.cuda())
            ssl_outputs = self._ssl_model(
                input_ids=ssl_input_ids,
                decoder_input_ids=ssl_decoder_input_ids,
                labels=ssl_label_ids,
            )
            loss_ssl = ssl_outputs[0]
            print_train_loss_ssl += loss_ssl
            print_train_loss_net += loss_net
            num += 1
            # Total training loss
            loss = (loss_net * NET_WEIGHT + loss_ssl * weight) * len(lil_batch) / batch_size
            print_train_loss += loss * batch_size
            if torch.isnan(loss):
                print('warning: nan loss')
                print(f'tgt_text: {lil_batch[0].tgt_text}')
            else:
                loss.backward()
        self._optimizer.step()
        self._lr_scheduler.step()
        self._global_step += 1
        if self._global_step % self._eval_steps == 0:
            self.gen_log()
    print("net training loss:", print_train_loss_net.item() / num)
    print("ssl training loss:", print_train_loss_ssl.item() / num)
    print("total training loss:", print_train_loss.item() / num)