def test_rag_sequence_from_pretrained(self):
    """Compare the loss of a saved-and-reloaded RAG model with a hand-assembled one.

    A ``TFRagSequenceForGeneration`` is first built with
    ``from_pretrained_question_encoder_generator``, written to a temporary
    directory with ``save_pretrained`` and loaded back with
    ``from_pretrained``. Its loss on a fixed question/answer pair is then
    compared (to 4 decimal places) with the loss of a model assembled
    directly from a separately loaded question encoder and generator.
    """
    load_weight_prefix = "tf_rag_model_1"

    rag_config = self.get_rag_config()
    rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    rag_retriever = RagRetriever(
        rag_config,
        question_encoder_tokenizer=rag_question_encoder_tokenizer,
        generator_tokenizer=rag_decoder_tokenizer,
    )

    # The question is the model input; the tokenized answer is passed as labels.
    input_ids = rag_question_encoder_tokenizer(
        "who sings does he love me with reba", return_tensors="tf"
    ).input_ids
    decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids

    with tempfile.TemporaryDirectory() as tmp_dirname:
        rag_sequence = TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base",
            "facebook/bart-large-cnn",
            retriever=rag_retriever,
            config=rag_config,
        )
        # check that the from pretrained methods work
        rag_sequence.save_pretrained(tmp_dirname)
        # `from_pretrained` returns a new model instead of updating the one it
        # is called on. Rebind the name so that the loss below is produced by
        # the reloaded weights; previously the result was discarded and the
        # save/load round trip was never actually compared.
        rag_sequence = TFRagSequenceForGeneration.from_pretrained(tmp_dirname, retriever=rag_retriever)

        output = rag_sequence(input_ids, labels=decoder_input_ids)
        loss_pretrained = output.loss

        # Drop the first model before the second set of weights is loaded.
        del rag_sequence

    question_encoder = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    generator = TFAutoModelForSeq2SeqLM.from_pretrained(
        "facebook/bart-large-cnn", load_weight_prefix=load_weight_prefix, name="generator"
    )

    rag_sequence = TFRagSequenceForGeneration(
        config=rag_config,
        question_encoder=question_encoder,
        generator=generator,
        retriever=rag_retriever,
    )

    output = rag_sequence(input_ids, labels=decoder_input_ids)
    loss_init = output.loss

    self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
def test_rag_sequence_generate_batch(self):
    """Generate answers for ``self.test_data_questions`` and compare the decoded text.

    The questions are tokenized as one padded, truncated batch, passed to
    ``generate`` together with their attention mask, and the decoded
    outputs must equal the expected answer strings exactly.
    """
    checkpoint = "facebook/rag-sequence-nq"

    tokenizer = RagTokenizer.from_pretrained(checkpoint)
    retriever = RagRetriever.from_pretrained(checkpoint, index_name="exact", use_dummy_dataset=True)
    model = TFRagSequenceForGeneration.from_pretrained(checkpoint, retriever=retriever)

    encoded = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )

    generated_ids = model.generate(encoded.input_ids, attention_mask=encoded.attention_mask)
    decoded_answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Leading spaces are part of the decoded strings being compared.
    expected_answers = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
    ]
    self.assertListEqual(decoded_answers, expected_answers)
def test_rag_sequence_generate_batch_from_context_input_ids(self):
    """Generate from pre-retrieved context instead of raw question ids.

    The test runs retrieval by hand: it encodes the questions with the
    model's question encoder, calls the retriever on the result, computes
    document scores, and passes the retrieved context tensors and scores
    to ``generate``. The decoded outputs must equal the expected strings.
    """
    checkpoint = "facebook/rag-sequence-nq"

    tokenizer = RagTokenizer.from_pretrained(checkpoint)
    retriever = RagRetriever.from_pretrained(checkpoint, index_name="exact", use_dummy_dataset=True)
    model = TFRagSequenceForGeneration.from_pretrained(checkpoint, retriever=retriever, from_pt=True)

    encoded = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    question_ids = encoded.input_ids
    # NOTE(review): the attention mask is read here but never passed to the
    # question encoder or to `generate` below - confirm this is intended.
    attention_mask = encoded.attention_mask

    # First element of the question encoder's output is used as the query vectors.
    question_hidden_states = model.question_encoder(question_ids)[0]
    docs_dict = retriever(question_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")

    # Insert an axis at position 1, multiply by the transposed retrieved
    # document embeddings, then squeeze that axis back out.
    expanded_queries = tf.expand_dims(question_hidden_states, axis=[1])
    unsqueezed_scores = tf.matmul(expanded_queries, docs_dict["retrieved_doc_embeds"], transpose_b=True)
    doc_scores = tf.squeeze(unsqueezed_scores, axis=[1])

    generated_ids = model.generate(
        context_input_ids=docs_dict["context_input_ids"],
        context_attention_mask=docs_dict["context_attention_mask"],
        doc_scores=doc_scores,
        do_deduplication=True,
    )
    decoded_answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # NOTE(review): this list has 15 entries, while
    # test_rag_sequence_generate_batch expects 8 answers for the same
    # `self.test_data_questions` - confirm which length matches the fixture.
    expected_answers = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
        " reticular formation",
        " walls of the abdomen",
        " spodumene",
        " obama",
        " new orleans",
        " japan",
        " old trafford",
    ]
    self.assertListEqual(decoded_answers, expected_answers)
def sequence_model(self):
    """Return a TFRagSequenceForGeneration built from two pretrained checkpoints.

    The question encoder and generator checkpoint names are passed, in that
    order, to ``from_pretrained_question_encoder_generator``.
    """
    question_encoder_checkpoint = "facebook/dpr-question_encoder-single-nq-base"
    generator_checkpoint = "facebook/bart-large-cnn"
    return TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
        question_encoder_checkpoint,
        generator_checkpoint,
    )