def test_rag_sequence_generate_batch(self):
    """Generate answers for the batch of test questions and check them.

    The questions in ``self.test_data_questions`` are tokenized as one padded
    batch, passed to ``TFRagSequenceForGeneration.generate`` together with
    the attention mask, decoded, and compared against the reference answers.
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    # Load the checkpoint the same way the from-context variant of this test
    # does (from_pt=True), so both tests exercise identically loaded weights.
    rag_sequence = TFRagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
    )

    input_dict = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    input_ids = input_dict.input_ids
    attention_mask = input_dict.attention_mask

    output_ids = rag_sequence.generate(
        input_ids,
        attention_mask=attention_mask,
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    # One decoded answer per question. The from-context variant of this test
    # expects 15 answers for the same question set, so the full list is used
    # here as well.
    # NOTE(review): confirm the count against self.test_data_questions.
    EXPECTED_OUTPUTS = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
        " reticular formation",
        " walls of the abdomen",
        " spodumene",
        " obama",
        " new orleans",
        " japan",
        " old trafford",
    ]
    self.assertListEqual(outputs, EXPECTED_OUTPUTS)
def test_rag_sequence_generate_batch_from_context_input_ids(self):
    """Generate answers from pre-retrieved context inputs and check them.

    Instead of letting ``generate`` run retrieval itself, this test encodes
    the questions, calls the retriever directly, computes the document scores
    by hand, and passes ``context_input_ids``, ``context_attention_mask`` and
    ``doc_scores`` to ``generate``. The decoded answers are compared against
    the reference list.
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    rag_sequence = TFRagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
    )

    input_dict = tokenizer(
        self.test_data_questions,
        return_tensors="tf",
        padding=True,
        truncation=True,
    )
    input_ids = input_dict.input_ids
    attention_mask = input_dict.attention_mask

    # The batch is tokenized with padding=True, so hand the tokenizer's mask
    # to the question encoder explicitly instead of relying on its default.
    question_hidden_states = rag_sequence.question_encoder(
        input_ids, attention_mask=attention_mask
    )[0]
    docs_dict = retriever(
        input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf"
    )

    # Score each retrieved document against its question: insert an axis at
    # position 1, multiply by the transposed retrieved embeddings, then drop
    # that axis again.
    expanded_question_states = tf.expand_dims(question_hidden_states, axis=[1])
    doc_scores = tf.squeeze(
        tf.matmul(
            expanded_question_states,
            docs_dict["retrieved_doc_embeds"],
            transpose_b=True,
        ),
        axis=[1],
    )

    output_ids = rag_sequence.generate(
        context_input_ids=docs_dict["context_input_ids"],
        context_attention_mask=docs_dict["context_attention_mask"],
        doc_scores=doc_scores,
        do_deduplication=True,
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    EXPECTED_OUTPUTS = [
        " albert einstein",
        " june 22, 2018",
        " amplitude modulation",
        " tim besley ( chairman )",
        " june 20, 2018",
        " 1980",
        " 7.0",
        " 8",
        " reticular formation",
        " walls of the abdomen",
        " spodumene",
        " obama",
        " new orleans",
        " japan",
        " old trafford",
    ]
    self.assertListEqual(outputs, EXPECTED_OUTPUTS)