def test_init_and_from_pretrained(self):
    """Construct a RAG token model directly, run it once, then save and reload it.

    The retriever is built from ``self.get_rag_config()`` while the model
    itself uses the ``facebook/rag-sequence-base`` configuration. Nothing is
    asserted; the test passes when every step completes without raising.
    """
    retriever_config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    retriever = RagRetriever(
        retriever_config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
    model = TFRagTokenForGeneration(model_config, retriever=retriever)

    question = question_tok(
        "who sings does he love me with reba", return_tensors="tf"
    )
    question_ids = question.input_ids
    answer = generator_tok("Linda Davis", return_tensors="tf")
    answer_ids = answer.input_ids

    # One forward pass before the model is saved.
    model(question_ids, decoder_input_ids=answer_ids)

    # Per the original author's note, the save/reload round trip is expected
    # to emit no warnings (this is not asserted here).
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        model = TFRagTokenForGeneration.from_pretrained(
            save_dir, retriever=retriever
        )
def test_rag_token_from_pretrained(self):
    """Compare the loss of two differently constructed RAG token models.

    The first model comes from
    ``from_pretrained_question_encoder_generator``; the second is assembled
    from a separately loaded question encoder and generator. Both are run on
    the same question/label pair and their losses must agree to 4 places.
    """
    weight_prefix = "tf_rag_model_1"
    config = self.get_rag_config()

    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    question = question_tok(
        "who sings does he love me with reba", return_tensors="tf"
    )
    question_ids = question.input_ids
    label = generator_tok("Linda Davis", return_tensors="tf")
    label_ids = label.input_ids

    with tempfile.TemporaryDirectory() as save_dir:
        combined_model = TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base",
            "facebook/bart-large-cnn",
            retriever=retriever,
            config=config,
        )

        # Exercise save_pretrained/from_pretrained. The reloaded model is
        # not kept: the loss below is computed with the in-memory model.
        combined_model.save_pretrained(save_dir)
        combined_model.from_pretrained(save_dir, retriever=retriever)

        loss_pretrained = combined_model(question_ids, labels=label_ids).loss

        # Drop the reference to the first model before building the second.
        del combined_model

    encoder = TFAutoModel.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    seq2seq = TFAutoModelForSeq2SeqLM.from_pretrained(
        "facebook/bart-large-cnn",
        load_weight_prefix=weight_prefix,
        name="generator",
    )
    assembled_model = TFRagTokenForGeneration(
        config=config,
        question_encoder=encoder,
        generator=seq2seq,
        retriever=retriever,
    )
    loss_init = assembled_model(question_ids, labels=label_ids).loss

    self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
def test_rag_token_greedy_search(self):
    """Generate answers for the first two test questions with a single beam.

    Loads ``facebook/rag-token-nq`` (converted from PyTorch weights) with a
    dummy-dataset retriever, forces ``num_beams`` to 1 on the model config,
    and checks the decoded answers against a fixed list.
    """
    checkpoint = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(checkpoint)
    retriever = RagRetriever.from_pretrained(
        checkpoint, index_name="exact", use_dummy_dataset=True
    )
    model = TFRagTokenForGeneration.from_pretrained(
        checkpoint, retriever=retriever, from_pt=True
    )

    # Only the first two questions are used.
    questions = self.test_data_questions[:2]
    encoded = tokenizer(
        questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )

    # Restrict generation to one beam.
    model.config.num_beams = 1

    generated_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
    )
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [" albert einstein", " september 22, 2017"]
    self.assertListEqual(answers, expected_answers)
def test_rag_token_generate_batch(self):
    """Generate answers for every test question and compare with fixed answers.

    Per the original author's note, the expected answers were produced with
    ``num_beam=4``, so this is effectively a beam-search test.

    NOTE(review): another method with this exact name is defined later in
    this file. If both belong to the same class, the later definition
    replaces this one and this test never runs -- confirm.
    """
    checkpoint = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(checkpoint)
    retriever = RagRetriever.from_pretrained(
        checkpoint, index_name="exact", use_dummy_dataset=True
    )
    model = TFRagTokenForGeneration.from_pretrained(checkpoint, retriever=retriever)

    encoded = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )

    generated_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
    )
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " september 22, 2017",
        " amplitude modulation",
        " stefan persson",
        " april 20, 2018",
        " the 1970s",
        " 7.1. 2",
        " 13",
    ]
    self.assertListEqual(answers, expected_answers)
def test_rag_token_inference_save_pretrained(self):
    """Check logits shape, loss and doc scores after a save/reload round trip.

    NOTE(review): ``self.token_model`` is read here as an attribute, not
    called. This presumably relies on a property-style decorator on
    ``token_model`` that is not visible in this view -- confirm.
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model
    model.set_retriever(retriever)

    question = question_tok(
        "who sings does he love me with reba", return_tensors="tf"
    )
    question_ids = question.input_ids
    label = generator_tok("Linda Davis", return_tensors="tf")
    label_ids = label.input_ids

    # Per the original author's note, the model has to be run once before
    # saving/loading works.
    model(question_ids, labels=label_ids)

    # Round-trip the model through a temporary directory.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        model = TFRagTokenForGeneration.from_pretrained(
            save_dir, retriever=retriever
        )

    result = model(question_ids, labels=label_ids)

    self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50264]))

    doc_scores_ref = tf.convert_to_tensor(
        [[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]
    )
    loss_ref = tf.convert_to_tensor([36.3557])
    tf.debugging.assert_near(result.loss, loss_ref, atol=1e-3)
    tf.debugging.assert_near(result.doc_scores, doc_scores_ref, atol=1e-3)
def test_rag_token_generate_batch(self):
    """Generate answers for every test question and compare with fixed answers.

    The model is loaded from ``facebook/rag-token-nq`` with ``from_pt=True``.
    Per the original author's notes, the expected answers were produced with
    ``num_beam=4``; setting ``rag_token.config.num_beams = 1`` reportedly
    changes two of them (obama, united stadium), and the test would pass
    with greedy generation if the expected answers were regenerated that way.

    NOTE(review): an earlier method in this file has this exact name. If
    both belong to the same class, this definition replaces the earlier
    one -- confirm.
    """
    checkpoint = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(checkpoint)
    retriever = RagRetriever.from_pretrained(
        checkpoint, index_name="exact", use_dummy_dataset=True
    )
    model = TFRagTokenForGeneration.from_pretrained(
        checkpoint, retriever=retriever, from_pt=True
    )

    encoded = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )

    # The beam count is left at whatever the loaded configuration specifies.
    generated_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
    )
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_answers = [
        " albert einstein",
        " september 22, 2017",
        " amplitude modulation",
        " stefan persson",
        " april 20, 2018",
        " the 1970s",
        " 7.1. 2",
        " 13",
        " step by step",
        " stomach",
        " spodumene",
        " obama",
        " northern new jersey",
        " india",
        " united stadium",
    ]
    self.assertListEqual(answers, expected_answers)
def test_rag_token_inference_nq_checkpoint(self):
    """Check the NQ checkpoint's logits shape, loss and doc scores after reload.

    The model returned by ``self.token_model_nq_checkpoint`` is saved to a
    temporary directory, reloaded, and then run on one question/label pair.
    """
    config = self.get_rag_config()
    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    model = self.token_model_nq_checkpoint(retriever=retriever)

    # Round-trip the model through a temporary directory.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        model = TFRagTokenForGeneration.from_pretrained(
            save_dir, retriever=retriever
        )

    question = question_tok(
        "who sings does he love me with reba", return_tensors="tf"
    )
    question_ids = question.input_ids
    label = generator_tok("Linda Davis", return_tensors="tf")
    label_ids = label.input_ids

    result = model(question_ids, labels=label_ids)

    self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50265]))

    doc_scores_ref = tf.convert_to_tensor(
        [[62.9402, 62.7107, 62.2382, 62.1194, 61.8578]]
    )
    loss_ref = tf.convert_to_tensor([32.521812])
    tf.debugging.assert_near(result.loss, loss_ref, atol=1e-3)
    tf.debugging.assert_near(result.doc_scores, doc_scores_ref, atol=1e-3)
def token_model_nq_checkpoint(self, retriever):
    """Load ``facebook/rag-token-nq`` from PyTorch weights with ``retriever`` attached.

    Args:
        retriever: Passed through as the ``retriever`` keyword of
            ``TFRagTokenForGeneration.from_pretrained``.

    Returns:
        Whatever ``TFRagTokenForGeneration.from_pretrained`` returns.
    """
    load_kwargs = {"from_pt": True, "retriever": retriever}
    return TFRagTokenForGeneration.from_pretrained(
        "facebook/rag-token-nq", **load_kwargs
    )
def token_model(self):
    """Build a RAG token model from the DPR question encoder and BART generator.

    NOTE(review): at least one test in this file reads ``self.token_model``
    as an attribute without calling it. A property-style decorator is
    presumably applied outside this view -- confirm.

    Returns:
        Whatever
        ``TFRagTokenForGeneration.from_pretrained_question_encoder_generator``
        returns for the two checkpoint names below.
    """
    question_encoder_name = "facebook/dpr-question_encoder-single-nq-base"
    generator_name = "facebook/bart-large-cnn"
    return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
        question_encoder_name, generator_name
    )