def require_distributed_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    The test is returned unchanged when PyTorch, Datasets, Faiss and psutil are all reported as available;
    otherwise it is returned wrapped in ``unittest.skip`` so that it is skipped.
    """
    dependencies_present = (
        is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()
    )
    if dependencies_present:
        return test_case
    return unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
def require_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    The test is returned unchanged when TensorFlow, datasets and faiss are all reported as available; otherwise
    it is returned wrapped in ``unittest.skip`` so that it is skipped.
    """
    dependencies_present = is_tf_available() and is_datasets_available() and is_faiss_available()
    if dependencies_present:
        return test_case
    return unittest.skip("test requires tensorflow, datasets and faiss")(test_case)
def require_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    [`RagRetriever`].

    The test is returned unchanged when PyTorch, datasets and faiss are all reported as available; otherwise it
    is returned wrapped in `unittest.skip` so that it is skipped.
    """
    dependencies_present = is_torch_available() and is_datasets_available() and is_faiss_available()
    if dependencies_present:
        return test_case
    return unittest.skip("test requires PyTorch, datasets and faiss")(test_case)
def test_finetune_bert2bert(self):
    """Fine-tune a tiny BERT2BERT encoder-decoder on a small CNN/DailyMail slice.

    Runs ``Seq2SeqTrainer.train()`` on 32 training and 16 validation examples,
    evaluating with ROUGE-2 every 2 steps. The test only checks that training
    runs through; it makes no assertion on the resulting scores.

    The test is reported as skipped when the ``datasets`` library is not
    available.
    """
    if not is_datasets_available():
        # Skip explicitly: a bare ``return`` would report the test as passed
        # even though nothing was exercised.
        self.skipTest("test requires datasets")

    import datasets

    encoder_max_length = 512
    decoder_max_length = 128
    batch_size = 4

    bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
    bert2bert.config.eos_token_id = tokenizer.sep_token_id
    bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
    # generation length matches the tokenized target length used below
    bert2bert.config.max_length = decoder_max_length

    train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
    val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")

    # keep the run small
    train_dataset = train_dataset.select(range(32))
    val_dataset = val_dataset.select(range(16))

    rouge = datasets.load_metric("rouge")

    def _map_to_encoder_decoder_inputs(batch):
        # Tokenizer will automatically set [BOS] <text> [EOS]
        inputs = tokenizer(
            batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length
        )
        outputs = tokenizer(
            batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length
        )

        batch["input_ids"] = inputs.input_ids
        batch["attention_mask"] = inputs.attention_mask
        batch["decoder_input_ids"] = outputs.input_ids
        # Labels are the target ids with padding replaced by -100
        # (presumably so padded positions are ignored by the loss -- confirm).
        batch["labels"] = [
            [-100 if token == tokenizer.pad_token_id else token for token in labels]
            for labels in outputs.input_ids
        ]
        batch["decoder_attention_mask"] = outputs.attention_mask

        assert all(len(x) == encoder_max_length for x in inputs.input_ids)
        assert all(len(x) == decoder_max_length for x in outputs.input_ids)

        return batch

    def _compute_metrics(pred):
        labels_ids = pred.label_ids
        pred_ids = pred.predictions

        # all unnecessary tokens are removed
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

        rouge_output = rouge.compute(
            predictions=pred_str, references=label_str, rouge_types=["rouge2"]
        )["rouge2"].mid

        return {
            "rouge2_precision": round(rouge_output.precision, 4),
            "rouge2_recall": round(rouge_output.recall, 4),
            "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
        }

    def _prepare(dataset):
        # Tokenize in batches, drop the raw text columns and select the
        # model input columns in "torch" format.
        dataset = dataset.map(
            _map_to_encoder_decoder_inputs,
            batched=True,
            batch_size=batch_size,
            remove_columns=["article", "highlights"],
        )
        dataset.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
        )
        return dataset

    # map train dataset, then the same for the validation dataset
    train_dataset = _prepare(train_dataset)
    val_dataset = _prepare(val_dataset)

    output_dir = self.get_auto_remove_tmp_dir()

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        predict_with_generate=True,
        evaluation_strategy="steps",
        do_train=True,
        do_eval=True,
        warmup_steps=0,
        eval_steps=2,
        logging_steps=2,
    )

    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=bert2bert,
        args=training_args,
        compute_metrics=_compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    # start training
    trainer.train()
class RagTestMixin:
    """Shared fixtures and checks for the RAG model tests.

    ``setUp`` writes small DPR, BART and T5 tokenizer fixtures into a temporary
    directory that ``tearDown`` removes again. The ``check_*`` methods run the
    classes in ``all_model_classes`` and assert on output shapes.

    The concrete test class is expected to provide the ``unittest.TestCase``
    assertion methods and a ``config_and_inputs`` mapping with at least the
    keys ``config``, ``input_ids``, ``attention_mask``, ``decoder_input_ids``
    and ``decoder_attention_mask`` (it is unpacked into the ``check_*``
    methods by the ``test_*`` methods).
    """

    # Model classes under test; empty when torch, datasets or faiss is missing.
    all_model_classes = (
        (RagModel, RagTokenForGeneration, RagSequenceForGeneration)
        if is_torch_available() and is_datasets_available() and is_faiss_available()
        else ()
    )

    retrieval_vector_size = 32  # length of each embedding in the toy index
    n_docs = 3  # document count used in the default shape assertions
    max_combined_length = 16  # expected sequence length of the generator encoder output

    def setUp(self):
        """Create the temporary directory and write the tokenizer fixtures."""
        self.tmpdirname = tempfile.mkdtemp()

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(f"{token}\n" for token in vocab_tokens))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = {token: index for index, token in enumerate(vocab)}
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        # NOTE: from here on ``self.vocab_file`` points at the BART vocab, not the DPR one.
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        # T5 tok
        t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB)
        t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer")
        t5_tokenizer.save_pretrained(t5_tokenizer_path)

    @cached_property
    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """DPR question-encoder tokenizer loaded from the fixture written in ``setUp``."""
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    @cached_property
    def bart_tokenizer(self) -> BartTokenizer:
        """BART tokenizer loaded from the fixture written in ``setUp``."""
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    @cached_property
    def t5_tokenizer(self) -> "T5Tokenizer":
        """T5 tokenizer loaded from the fixture saved in ``setUp``."""
        return T5Tokenizer.from_pretrained(os.path.join(self.tmpdirname, "t5_tokenizer"))

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        shutil.rmtree(self.tmpdirname)

    def get_retriever(self, config):
        """Build a ``RagRetriever`` backed by a three-document in-memory dataset.

        ``load_dataset`` inside ``transformers.retrieval_rag`` is patched to
        return the toy dataset while the retriever is constructed. The
        generator tokenizer is the BART fixture when
        ``config.generator.model_type == "bart"`` and the T5 fixture otherwise.
        """
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1", "3"],
                "text": ["foo", "bar", "qux"],
                "title": ["Foo", "Bar", "Qux"],
                "embeddings": [
                    np.ones(self.retrieval_vector_size),
                    2 * np.ones(self.retrieval_vector_size),
                    3 * np.ones(self.retrieval_vector_size),
                ],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer
        with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.dpr_tokenizer,
                generator_tokenizer=tokenizer,
            )
        return retriever

    def _assert_config_has_submodels(self, config):
        """Assert that ``config`` carries both a question-encoder and a generator config."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

    def _build_eval_model(self, model_class, config, **model_kwargs):
        """Instantiate ``model_class`` on ``torch_device`` in eval mode and check it is encoder-decoder."""
        model = model_class(config, **model_kwargs).to(torch_device)
        model.eval()
        self.assertTrue(model.config.is_encoder_decoder)
        return model

    def _retrieve_context_and_scores(self, model, retriever, config, input_ids, attention_mask, **retriever_kwargs):
        """Encode the question, query ``retriever`` and compute the document scores.

        ``retriever_kwargs`` are forwarded unchanged to the retriever call
        (the callers use this for ``n_docs``).

        Returns:
            Tuple ``(context_input_ids, context_attention_mask, doc_scores)``.
        """
        question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

        out = retriever(
            input_ids,
            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
            prefix=config.generator.prefix,
            return_tensors="pt",
            **retriever_kwargs,
        )

        # cast the retriever outputs to match the model-side tensors
        retrieved_doc_embeds = out["retrieved_doc_embeds"].to(question_hidden_states)
        context_input_ids = out["context_input_ids"].to(input_ids)
        context_attention_mask = out["context_attention_mask"].to(input_ids)

        # compute doc_scores
        doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(1)
        return context_input_ids, context_attention_mask, doc_scores

    def _assert_output_shapes(self, outputs, config, input_ids, decoder_input_ids, n_docs):
        """Assert the shapes of logits, generator encoder hidden states and doc scores for ``n_docs`` documents."""
        # logits
        self.assertEqual(
            outputs.logits.shape,
            (n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
        )
        # generator encoder last hidden states
        self.assertEqual(
            outputs.generator_enc_last_hidden_state.shape,
            (n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
        )
        # doc scores
        self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], n_docs))

    def check_model_with_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Run every model class with an attached retriever and check the output shapes."""
        self._assert_config_has_submodels(config)

        for model_class in self.all_model_classes:
            model = self._build_eval_model(model_class, config, retriever=self.get_retriever(config))

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def check_model_generate(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call ``generate`` on every model class except the first and check it returns something."""
        self._assert_config_has_submodels(config)

        # The first entry of ``all_model_classes`` (``RagModel``) is skipped here.
        for model_class in self.all_model_classes[1:]:
            model = self._build_eval_model(model_class, config, retriever=self.get_retriever(config))

            outputs = model.generate(
                input_ids=input_ids,
                num_beams=2,
                num_return_sequences=2,
                decoder_start_token_id=config.generator.eos_token_id,
            )

            self.assertIsNotNone(outputs)

    def check_model_without_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Run every model class on externally retrieved context and check the output shapes."""
        self._assert_config_has_submodels(config)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = self._build_eval_model(model_class, config)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context_and_scores(
                model, retriever, config, input_ids, attention_mask
            )

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def check_model_custom_n_docs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs
    ):
        """Like ``check_model_without_retriever`` but with ``n_docs`` passed to both retriever and model."""
        self._assert_config_has_submodels(config)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = self._build_eval_model(model_class, config)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context_and_scores(
                model, retriever, config, input_ids, attention_mask, n_docs=n_docs
            )

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=n_docs,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, n_docs)

    def check_model_with_mismatch_n_docs_value(
        self,
        config,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        retriever_n_docs,
        generator_n_docs,
        **kwargs
    ):
        """Check that the model raises ``AssertionError`` when its ``n_docs`` differs from the retriever's."""
        self._assert_config_has_submodels(config)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = self._build_eval_model(model_class, config)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context_and_scores(
                model, retriever, config, input_ids, attention_mask, n_docs=retriever_n_docs
            )

            with self.assertRaises(AssertionError):
                model(
                    context_input_ids=context_input_ids,
                    context_attention_mask=context_attention_mask,
                    doc_scores=doc_scores,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    n_docs=generator_n_docs,
                )

    def check_model_with_encoder_outputs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Feed the generator encoder output of a first pass back in and check the output shapes."""
        self._assert_config_has_submodels(config)

        for model_class in self.all_model_classes:
            model = self._build_eval_model(model_class, config, retriever=self.get_retriever(config))

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            encoder_outputs = BaseModelOutput(outputs.generator_enc_last_hidden_state)

            # run only generator
            outputs = model(
                encoder_outputs=encoder_outputs,
                doc_scores=outputs.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def test_model_with_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_retriever(**inputs_dict)

    def test_model_without_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_without_retriever(**inputs_dict)

    def test_model_with_encoder_outputs(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_encoder_outputs(**inputs_dict)

    def test_model_generate(self):
        inputs_dict = self.config_and_inputs
        self.check_model_generate(**inputs_dict)

    def test_model_with_custom_n_docs(self):
        inputs_dict = self.config_and_inputs
        inputs_dict["n_docs"] = 1
        self.check_model_custom_n_docs(**inputs_dict)

    def test_model_with_mismatch_n_docs_value(self):
        inputs_dict = self.config_and_inputs
        inputs_dict["retriever_n_docs"] = 3
        inputs_dict["generator_n_docs"] = 2
        self.check_model_with_mismatch_n_docs_value(**inputs_dict)
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES from .test_modeling_bart import ModelTester as BartModelTester from .test_modeling_dpr import DPRModelTester from .test_modeling_t5 import T5ModelTester TOLERANCE = 1e-3 T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") if is_torch_available() and is_datasets_available() and is_faiss_available(): import torch from datasets import Dataset import faiss from transformers import ( AutoConfig, AutoModel, AutoModelForSeq2SeqLM, RagConfig, RagModel, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration, RagTokenizer, )
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments from transformers.file_utils import is_datasets_available from transformers.testing_utils import TestCasePlus, require_datasets, slow if is_datasets_available(): import datasets class Seq2seqTrainerTester(TestCasePlus): @slow @require_datasets def test_finetune_bert2bert(self): bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size bert2bert.config.eos_token_id = tokenizer.sep_token_id bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id bert2bert.config.max_length = 128