def test_rag_token_from_pretrained(self):
    """Check that two ways of assembling a RAG-token model yield the same loss.

    A model is first built with ``from_pretrained_question_encoder_generator``
    (its ``save_pretrained`` / ``from_pretrained`` methods are exercised along
    the way), then a second one is built by passing separately loaded
    question-encoder and generator models to the constructor. Both are run on
    the same question/answer pair and the losses must agree to 4 places.
    """
    question_encoder_ckpt = "facebook/dpr-question_encoder-single-nq-base"
    generator_ckpt = "facebook/bart-large-cnn"
    load_weight_prefix = "tf_rag_model_1"

    config = self.get_rag_config()
    generator_tokenizer = BartTokenizer.from_pretrained(generator_ckpt)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        question_encoder_ckpt
    )
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=question_tokenizer,
        generator_tokenizer=generator_tokenizer,
    )

    question_encoding = question_tokenizer(
        "who sings does he love me with reba", return_tensors="tf"
    )
    input_ids = question_encoding.input_ids
    answer_encoding = generator_tokenizer("Linda Davis", return_tensors="tf")
    labels = answer_encoding.input_ids

    def compute_loss(model):
        # Same inputs for both models so the two losses are comparable.
        return model(input_ids, labels=labels).loss

    with tempfile.TemporaryDirectory() as save_dir:
        pretrained_model = (
            TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
                question_encoder_ckpt,
                generator_ckpt,
                retriever=retriever,
                config=config,
            )
        )
        # Check that saving and loading through the pretrained API works.
        pretrained_model.save_pretrained(save_dir)
        # NOTE(review): the value returned by from_pretrained is not kept, so
        # the loss below comes from the model built above rather than from
        # the copy reloaded from disk -- confirm this is intended.
        pretrained_model.from_pretrained(save_dir, retriever=retriever)

        loss_pretrained = compute_loss(pretrained_model)

    # Drop the first model before the second one is instantiated.
    del pretrained_model

    question_encoder = TFAutoModel.from_pretrained(question_encoder_ckpt)
    generator = TFAutoModelForSeq2SeqLM.from_pretrained(
        generator_ckpt,
        load_weight_prefix=load_weight_prefix,
        name="generator",
    )
    assembled_model = TFRagTokenForGeneration(
        config=config,
        question_encoder=question_encoder,
        generator=generator,
        retriever=retriever,
    )
    loss_init = compute_loss(assembled_model)

    self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
def token_model(self):
    """Return a ``TFRagTokenForGeneration`` built from two pretrained checkpoints.

    The question encoder comes from the DPR checkpoint and the generator from
    the BART checkpoint named below; no retriever or config is passed.
    """
    question_encoder_ckpt = "facebook/dpr-question_encoder-single-nq-base"
    generator_ckpt = "facebook/bart-large-cnn"
    return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
        question_encoder_ckpt, generator_ckpt
    )