def test_rag_token_from_pretrained(self):
        """Compare the loss of two differently constructed RAG-token models.

        The first model is assembled with
        ``from_pretrained_question_encoder_generator``; the second is built
        by loading the question encoder and the generator separately and
        passing them to the ``TFRagTokenForGeneration`` constructor. Both are
        run on the same question/label pair and their losses must agree to
        four decimal places.

        The first model is also saved to a temporary directory and loaded
        back. The reloaded instance is not kept: that step only checks that
        saving and loading complete without raising.
        """
        question_encoder_checkpoint = "facebook/dpr-question_encoder-single-nq-base"
        generator_checkpoint = "facebook/bart-large-cnn"
        weight_prefix = "tf_rag_model_1"

        config = self.get_rag_config()
        generator_tokenizer = BartTokenizer.from_pretrained(generator_checkpoint)
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            question_encoder_checkpoint
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tokenizer,
            generator_tokenizer=generator_tokenizer,
        )

        # Shared inputs for both models: the tokenized question and the
        # tokenized answer used as labels.
        question_encoding = question_tokenizer(
            "who sings does he love me with reba", return_tensors="tf"
        )
        input_ids = question_encoding.input_ids
        answer_encoding = generator_tokenizer("Linda Davis", return_tensors="tf")
        labels = answer_encoding.input_ids

        with tempfile.TemporaryDirectory() as save_dir:
            assembled_model = TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
                question_encoder_checkpoint,
                generator_checkpoint,
                retriever=retriever,
                config=config,
            )
            # Smoke-check the save/load round trip; the loaded model is
            # intentionally discarded.
            assembled_model.save_pretrained(save_dir)
            assembled_model.from_pretrained(save_dir, retriever=retriever)

            loss_pretrained = assembled_model(input_ids, labels=labels).loss
            # Drop the reference to the first model before the second one
            # is loaded.
            del assembled_model

        question_encoder = TFAutoModel.from_pretrained(question_encoder_checkpoint)
        generator = TFAutoModelForSeq2SeqLM.from_pretrained(
            generator_checkpoint,
            load_weight_prefix=weight_prefix,
            name="generator",
        )
        manual_model = TFRagTokenForGeneration(
            config=config,
            question_encoder=question_encoder,
            generator=generator,
            retriever=retriever,
        )

        loss_init = manual_model(input_ids, labels=labels).loss

        self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
 def token_model(self):
     """Return a ``TFRagTokenForGeneration`` built from two pretrained checkpoints.

     The model is assembled with
     ``from_pretrained_question_encoder_generator`` from the question-encoder
     checkpoint and the generator checkpoint named below; no retriever or
     config is passed.
     """
     question_encoder_checkpoint = "facebook/dpr-question_encoder-single-nq-base"
     generator_checkpoint = "facebook/bart-large-cnn"
     build_model = TFRagTokenForGeneration.from_pretrained_question_encoder_generator
     return build_model(question_encoder_checkpoint, generator_checkpoint)