def test_init_and_from_pretrained(self):
    """Instantiate a TF RAG token model, call it once, then save and reload it.

    The retriever is built from ``self.get_rag_config()`` while the model
    itself is configured from the ``facebook/rag-sequence-base`` checkpoint
    config. The test makes no assertions: it passes when every step runs
    without raising.
    """
    retriever_config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        retriever_config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
    model = TFRagTokenForGeneration(model_config, retriever=retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="tf").input_ids
    answer_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

    # One forward pass before the model is saved.
    model(question_ids, decoder_input_ids=answer_ids)

    # this should not give any warnings
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        model = TFRagTokenForGeneration.from_pretrained(save_dir, retriever=retriever)
def test_rag_token_inference(self):
    """Run one labelled forward pass of the TF RAG token model and pin its outputs.

    Checks the logits shape, the loss and the document scores against
    reference values (absolute tolerance 1e-3 for the numeric comparisons).
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model
    model.set_retriever(retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="tf").input_ids
    label_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

    result = model(question_ids, labels=label_ids)

    self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50264]))

    want_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])
    want_loss = tf.convert_to_tensor([36.3557])

    tf.debugging.assert_near(result.loss, want_loss, atol=1e-3)
    tf.debugging.assert_near(result.doc_scores, want_doc_scores, atol=1e-3)
def test_rag_sequence_from_pretrained(self):
    """Check that a saved-and-reloaded RAG sequence model matches a directly built one.

    Two models are compared on the same question/label pair:

    1. a model assembled with ``from_pretrained_question_encoder_generator``,
       saved to a temporary directory and loaded back with ``from_pretrained``;
    2. a model constructed directly from separately loaded question-encoder
       and generator sub-models.

    Their losses must agree to 4 decimal places.
    """
    rag_config = self.get_rag_config()
    rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    rag_retriever = RagRetriever(
        rag_config,
        question_encoder_tokenizer=rag_question_encoder_tokenizer,
        generator_tokenizer=rag_decoder_tokenizer,
    )

    input_ids = rag_question_encoder_tokenizer(
        "who sings does he love me with reba", return_tensors="pt"
    ).input_ids
    decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids

    input_ids = input_ids.to(torch_device)
    decoder_input_ids = decoder_input_ids.to(torch_device)

    with tempfile.TemporaryDirectory() as tmp_dirname:
        rag_sequence = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base",
            "facebook/bart-large-cnn",
            retriever=rag_retriever,
            config=rag_config,
        ).to(torch_device)

        # check that the from pretrained methods work
        rag_sequence.save_pretrained(tmp_dirname)
        # `from_pretrained` returns the loaded model; it does not update the
        # object it is called on. Bind the result so the forward pass below
        # really exercises the reloaded checkpoint (previously the return
        # value was dropped and the original model was evaluated instead).
        rag_sequence = RagSequenceForGeneration.from_pretrained(tmp_dirname, retriever=rag_retriever)
        rag_sequence.to(torch_device)

    with torch.no_grad():
        output = rag_sequence(
            input_ids,
            labels=decoder_input_ids,
        )

    loss_pretrained = output.loss
    # Drop the first model before the second set of weights is loaded.
    del rag_sequence

    question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
    rag_sequence = RagSequenceForGeneration(
        config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
    )
    rag_sequence.to(torch_device)

    with torch.no_grad():
        output = rag_sequence(
            input_ids,
            labels=decoder_input_ids,
        )

    loss_init = output.loss

    self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)
def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
    """Create ``self.retriever`` the first time this is called; do nothing afterwards.

    Args:
        config: Configuration forwarded to ``RagRetriever``.
        question_encoder_tokenizer: Tokenizer forwarded as ``question_encoder_tokenizer``.
        generator_tokenizer: Tokenizer forwarded as ``generator_tokenizer``.
        index: Index object forwarded as ``index``.

    The retriever is constructed with ``init_retrieval=False``; once it has
    been stored, ``self.initialized`` is set to ``True`` so later calls
    return immediately.
    """
    if self.initialized:
        return

    self.retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_encoder_tokenizer,
        generator_tokenizer=generator_tokenizer,
        index=index,
        init_retrieval=False,
    )
    self.initialized = True
def test_rag_sequence_from_pretrained(self):
    """Compare the loss of two differently constructed TF RAG sequence models.

    The first model is assembled with ``from_pretrained_question_encoder_generator``
    and saved to a temporary directory; the second is built directly from
    separately loaded question-encoder and generator sub-models. Both are
    evaluated on the same question/label pair and their losses are compared
    to 4 decimal places.
    """
    load_weight_prefix = "tf_rag_model_1"

    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    question_ids = question_tok("who sings does he love me with reba", return_tensors="tf").input_ids
    label_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

    with tempfile.TemporaryDirectory() as save_dir:
        model = TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base",
            "facebook/bart-large-cnn",
            retriever=retriever,
            config=config,
        )

        # check that the from pretrained methods work
        model.save_pretrained(save_dir)
        # NOTE(review): the value returned by from_pretrained is not kept, so
        # the forward pass below runs on the model built above, not on the
        # reloaded copy. Presumably only "loading does not raise" is being
        # checked here - confirm whether the reloaded model should be used.
        model.from_pretrained(save_dir, retriever=retriever)

        first_output = model(question_ids, labels=label_ids)

    loss_pretrained = first_output.loss
    del model

    question_encoder = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    generator = TFAutoModelForSeq2SeqLM.from_pretrained(
        "facebook/bart-large-cnn",
        load_weight_prefix=load_weight_prefix,
        name="generator",
    )
    model = TFRagSequenceForGeneration(
        config=config,
        question_encoder=question_encoder,
        generator=generator,
        retriever=retriever,
    )

    second_output = model(question_ids, labels=label_ids)
    loss_init = second_output.loss

    self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
def test_rag_sequence_generate_batch(self):
    """Beam-search generation for a padded batch of three questions (RAG sequence model).

    IMPORTANT (note carried over from the original author): this test fails
    on GPU but is fine on CPU, because beam search is very sensitive.
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.sequence_model
    model.set_retriever(retriever)

    questions = [
        "who sings does he love me with reba",
        "how many pages is invisible man by ralph ellison",
        "what",
    ]

    encoded = question_tok.batch_encode_plus(
        questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    question_ids = encoded.input_ids.to(torch_device)
    question_mask = encoded.attention_mask.to(torch_device)

    generated = model.generate(
        question_ids,
        attention_mask=question_mask,
        decoder_start_token_id=model.generator.config.decoder_start_token_id,
        num_beams=4,
        num_return_sequences=1,
        max_length=10,
    )

    # sequence generate test: decode every row first, then compare in order
    decoded = [generator_tok.decode(generated[row], skip_special_tokens=True) for row in range(3)]

    # Expected outputs as given by model at integration time.
    expected = [
        '"I Know Him So Well"',
        '"Howl" chronicles the',
        "Otis the Aardvark",
    ]

    for actual_text, expected_text in zip(decoded, expected):
        self.assertEqual(actual_text, expected_text)
def test_rag_token_inference_save_pretrained(self):
    """Save and reload the TF RAG token model, then pin the reloaded model's outputs.

    The logits shape, loss and document scores of the reloaded model are
    compared with reference values (absolute tolerance 1e-3 for the numeric
    comparisons).
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model
    model.set_retriever(retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="tf").input_ids
    label_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

    # model must run once to be functional before loading/saving works
    model(question_ids, labels=label_ids)

    # check that outputs after saving and loading are equal
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        model = TFRagTokenForGeneration.from_pretrained(save_dir, retriever=retriever)

    result = model(question_ids, labels=label_ids)

    self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50264]))

    want_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])
    want_loss = tf.convert_to_tensor([36.3557])

    tf.debugging.assert_near(result.loss, want_loss, atol=1e-3)
    tf.debugging.assert_near(result.doc_scores, want_doc_scores, atol=1e-3)
def test_rag_token_generate_batch(self):
    """Beam-search generation for a padded batch of two questions (RAG token model)."""
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model
    model.set_retriever(retriever)

    questions = [
        "who sings does he love me with reba",
        "how many pages is invisible man by ralph ellison",
    ]

    encoded = question_tok.batch_encode_plus(
        questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    question_ids = encoded.input_ids.to(torch_device)

    # NOTE(review): the batch is padded, yet only input_ids are forwarded to
    # generate (no attention_mask). The expected strings below presumably
    # reflect that setup - confirm before changing either.
    generated = model.generate(
        question_ids,
        decoder_start_token_id=model.generator.config.decoder_start_token_id,
        num_beams=4,
        num_return_sequences=1,
        max_length=10,
    )

    # sequence generate test: decode every row first, then compare in order
    decoded = [generator_tok.decode(generated[row], skip_special_tokens=True) for row in range(2)]

    # Expected outputs as given by model at integration time.
    expected = [
        '"People Need Love" is the',
        '"How many pages is invisible man',
    ]

    for actual_text, expected_text in zip(decoded, expected):
        self.assertEqual(actual_text, expected_text)
def get_retriever(self, config):
    """Build a ``RagRetriever`` backed by a two-document in-memory dataset.

    The dataset holds documents "foo" and "bar" whose embeddings are an
    all-ones vector and an all-twos vector of length
    ``self.retrieval_vector_size``, indexed with a flat inner-product FAISS
    index. ``transformers.retrieval_rag.load_dataset`` is patched while the
    retriever is constructed so that it returns this dataset.

    Args:
        config: RAG configuration; ``config.generator.model_type`` selects
            the generator tokenizer (BART when it equals ``"bart"``,
            otherwise T5).

    Returns:
        The constructed ``RagRetriever``.
    """
    dim = self.retrieval_vector_size
    tiny_dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [np.ones(dim), 2 * np.ones(dim)],
        }
    )
    tiny_dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)

    if config.generator.model_type == "bart":
        generator_tok = self.bart_tokenizer
    else:
        generator_tok = self.t5_tokenizer

    with patch("transformers.retrieval_rag.load_dataset") as fake_load_dataset:
        fake_load_dataset.return_value = tiny_dataset
        built = RagRetriever(
            config,
            question_encoder_tokenizer=self.dpr_tokenizer,
            generator_tokenizer=generator_tok,
        )
    return built
def test_rag_token_inference_nq_checkpoint(self):
    """Save and reload the NQ-checkpoint TF RAG token model, then pin its outputs.

    The model comes from ``self.token_model_nq_checkpoint``. After a
    save/load round trip, its logits shape, loss and document scores are
    compared with reference values (absolute tolerance 1e-3 for the numeric
    comparisons).
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model_nq_checkpoint(retriever=retriever)

    # check that outputs after saving and loading are equal
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        model = TFRagTokenForGeneration.from_pretrained(save_dir, retriever=retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="tf").input_ids
    label_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

    result = model(question_ids, labels=label_ids)

    self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50265]))

    want_doc_scores = tf.convert_to_tensor([[62.9402, 62.7107, 62.2382, 62.1194, 61.8578]])
    want_loss = tf.convert_to_tensor([32.521812])

    tf.debugging.assert_near(result.loss, want_loss, atol=1e-3)
    tf.debugging.assert_near(result.doc_scores, want_doc_scores, atol=1e-3)
def test_rag_token_inference(self):
    """Run one labelled, gradient-free forward pass of the RAG token model and pin its outputs.

    Checks the logits shape, then the document scores, then the loss against
    reference values using ``_assert_tensors_equal`` with ``TOLERANCE``.
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model
    model.set_retriever(retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="pt").input_ids
    label_ids = generator_tok("Linda Davis", return_tensors="pt").input_ids

    question_ids = question_ids.to(torch_device)
    label_ids = label_ids.to(torch_device)

    with torch.no_grad():
        result = model(question_ids, labels=label_ids)

    self.assertEqual(result.logits.shape, torch.Size([5, 5, 50264]))

    want_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
    _assert_tensors_equal(want_doc_scores, result.doc_scores, atol=TOLERANCE)

    want_loss = torch.tensor([36.3557]).to(torch_device)
    _assert_tensors_equal(want_loss, result.loss, atol=TOLERANCE)
def test_rag_sequence_generate_beam(self):
    """Two-beam generation with the RAG sequence model, returning both beams.

    Both returned sequences are decoded and compared with reference texts.
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    # This test drives the *sequence* model.
    sequence_model = self.sequence_model
    sequence_model.set_retriever(retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="pt").input_ids
    question_ids = question_ids.to(torch_device)

    generated = sequence_model.generate(
        question_ids,
        decoder_start_token_id=sequence_model.generator.config.decoder_start_token_id,
        num_beams=2,
        num_return_sequences=2,
    )

    # sequence generate test: decode both beams first, then compare in order
    decoded = [generator_tok.decode(generated[row], skip_special_tokens=True) for row in range(2)]

    # Expected outputs as given by model at integration time.
    EXPECTED_OUTPUT_TEXT_1 = """\"She's My Kind of Girl\" was released through Epic Records in Japan in March 1972, giving the duo a Top 10 hit. Two more singles were released in Japan, \"En Carousel\" and \"Love Has Its Ways\" Ulvaeus and Andersson persevered with their songwriting and experimented with new sounds and vocal arrangements."""
    EXPECTED_OUTPUT_TEXT_2 = """In September 2018, Björn Ulvaeus revealed that the two new songs, \"I Still Have Faith In You\" and \"Don't Shut Me Down\", would be released no earlier than March 2019. The two new tracks will feature in a TV special set to air later in the year."""

    for actual_text, expected_text in zip(decoded, (EXPECTED_OUTPUT_TEXT_1, EXPECTED_OUTPUT_TEXT_2)):
        self.assertEqual(actual_text, expected_text)
def test_rag_token_generate_beam(self):
    """Two-beam generation with the RAG token model, returning both beams.

    Both returned sequences are decoded and compared with reference texts.
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model
    model.set_retriever(retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="pt").input_ids
    question_ids = question_ids.to(torch_device)

    generated = model.generate(
        question_ids,
        decoder_start_token_id=model.generator.config.decoder_start_token_id,
        num_beams=2,
        num_return_sequences=2,
    )

    # sequence generate test: decode both beams first, then compare in order
    decoded = [generator_tok.decode(generated[row], skip_special_tokens=True) for row in range(2)]

    # Expected outputs as given by model at integration time.
    EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
    EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"

    for actual_text, expected_text in zip(decoded, (EXPECTED_OUTPUT_TEXT_1, EXPECTED_OUTPUT_TEXT_2)):
        self.assertEqual(actual_text, expected_text)
def test_rag_sequence_generate_beam(self):
    """Two-beam generation with the RAG sequence model, returning both beams.

    Both returned sequences are decoded and compared with reference texts;
    note that each reference text starts with a single space.
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    # This test drives the *sequence* model.
    sequence_model = self.sequence_model
    sequence_model.set_retriever(retriever)

    question_ids = question_tok("who sings does he love me with reba", return_tensors="pt").input_ids
    question_ids = question_ids.to(torch_device)

    generated = sequence_model.generate(
        question_ids,
        decoder_start_token_id=sequence_model.generator.config.decoder_start_token_id,
        num_beams=2,
        num_return_sequences=2,
    )

    # sequence generate test: decode both beams first, then compare in order
    decoded = [generator_tok.decode(generated[row], skip_special_tokens=True) for row in range(2)]

    # Expected outputs as given by model at integration time.
    EXPECTED_OUTPUT_TEXT_1 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" that day."""
    EXPECTED_OUTPUT_TEXT_2 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" (a top ten hit in Austria)"""

    for actual_text, expected_text in zip(decoded, (EXPECTED_OUTPUT_TEXT_1, EXPECTED_OUTPUT_TEXT_2)):
        self.assertEqual(actual_text, expected_text)
def main():
    """Assemble a RAG token model from a question encoder and BART, then save it.

    Command-line options (all strings):
        --model_path: question-encoder checkpoint, also used to load its tokenizer.
        --out_path:   directory that receives the model, tokenizer and retriever.
        --index_name: value written to ``model.config.index_name``.

    The generator is always ``facebook/bart-large`` and the model config is
    switched to the dummy dataset before the retriever is created.
    """
    parser = argparse.ArgumentParser()
    string_options = (
        (
            "--model_path",
            "/dccstor/dialog/sfeng/transformers_doc2dial/checkpoints/colbert-converted-60000/question_encoder/",
        ),
        ("--out_path", "tmp"),
        ("--index_name", "exact"),
    )
    for flag, default_value in string_options:
        parser.add_argument(flag, type=str, default=default_value)
    args = parser.parse_args()

    model = RagTokenForGeneration.from_pretrained_question_encoder_generator(args.model_path, "facebook/bart-large")

    question_encoder_tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
    tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

    # The retriever reads these settings from the model config, so they are
    # set before it is constructed.
    model.config.use_dummy_dataset = True
    model.config.index_name = args.index_name
    retriever = RagRetriever(
        model.config,
        question_encoder_tokenizer=question_encoder_tokenizer,
        generator_tokenizer=generator_tokenizer,
    )

    # Persist all three artefacts into the same output directory.
    for artefact in (model, tokenizer, retriever):
        artefact.save_pretrained(args.out_path)
from transformers import (
    AutoTokenizer,
    RagRetriever,
    RagTokenForGeneration,
    RagTokenizer,
)

# Assemble a RAG token model from a DPR question encoder and a BART generator.
model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base",
    "facebook/bart-large",
)

# Load each component's tokenizer and wrap both in a single RagTokenizer.
question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# Switch the config to the dummy dataset with the "exact" index before the
# retriever is built from it.
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(
    model.config,
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
)

# Write model, tokenizer and retriever into the current directory.
model.save_pretrained("./")
tokenizer.save_pretrained("./")
retriever.save_pretrained("./")