def test_rag_sequence_from_pretrained(self):
        """Check that two ways of building a RAG sequence model yield the same loss.

        The first model is created with
        ``from_pretrained_question_encoder_generator``, saved to a temporary
        directory and passed through ``from_pretrained`` once. The second model
        is assembled by hand from a separately loaded question encoder and
        generator. Both are run on the same question/answer pair and their
        losses are asserted to agree to four decimal places.
        """
        load_weight_prefix = "tf_rag_model_1"
        question_encoder_name = "facebook/dpr-question_encoder-single-nq-base"
        generator_name = "facebook/bart-large-cnn"

        rag_config = self.get_rag_config()
        generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_encoder_name)
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=question_tokenizer,
            generator_tokenizer=generator_tokenizer,
        )

        # Fixed question / answer pair shared by both model variants.
        question_encoding = question_tokenizer("who sings does he love me with reba", return_tensors="tf")
        input_ids = question_encoding.input_ids
        answer_encoding = generator_tokenizer("Linda Davis", return_tensors="tf")
        decoder_input_ids = answer_encoding.input_ids

        with tempfile.TemporaryDirectory() as tmp_dirname:
            composed_model = TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
                question_encoder_name,
                generator_name,
                retriever=rag_retriever,
                config=rag_config,
            )

            # Exercise the save / load round trip.
            # NOTE(review): the object returned by from_pretrained is discarded, so
            # the loss below comes from the instance that was saved, not from the
            # reloaded one - confirm this is intended.
            composed_model.save_pretrained(tmp_dirname)
            composed_model.from_pretrained(tmp_dirname, retriever=rag_retriever)

            loss_pretrained = composed_model(input_ids, labels=decoder_input_ids).loss
            del composed_model

        # Second variant: load both sub-models separately and wire them together.
        question_encoder = TFAutoModel.from_pretrained(question_encoder_name)
        generator = TFAutoModelForSeq2SeqLM.from_pretrained(
            generator_name,
            load_weight_prefix=load_weight_prefix,
            name="generator",
        )
        assembled_model = TFRagSequenceForGeneration(
            config=rag_config,
            question_encoder=question_encoder,
            generator=generator,
            retriever=rag_retriever,
        )

        loss_init = assembled_model(input_ids, labels=decoder_input_ids).loss

        self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
Пример #2
0
    def test_rag_sequence_generate_batch(self):
        """Generate answers for ``self.test_data_questions`` in one batch.

        The questions are tokenized with padding and truncation, passed to
        ``generate`` together with their attention mask, decoded without
        special tokens and compared against a fixed list of strings.
        """
        checkpoint = "facebook/rag-sequence-nq"

        tokenizer = RagTokenizer.from_pretrained(checkpoint)
        retriever = RagRetriever.from_pretrained(checkpoint, index_name="exact", use_dummy_dataset=True)
        model = TFRagSequenceForGeneration.from_pretrained(checkpoint, retriever=retriever)

        encoded = tokenizer(self.test_data_questions, return_tensors="tf", padding=True, truncation=True)
        question_ids = encoded.input_ids
        question_mask = encoded.attention_mask

        generated_ids = model.generate(question_ids, attention_mask=question_mask)
        decoded_answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_answers = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(decoded_answers, expected_answers)
    def test_rag_sequence_generate_batch_from_context_input_ids(self):
        """Generate answers from documents retrieved outside of ``generate``.

        The question encoder and retriever are called by hand; the document
        scores are computed here with a matrix product and then handed to
        ``generate`` alongside the retrieved context ids and attention mask.
        The decoded answers are compared against a fixed list of strings.
        """
        checkpoint = "facebook/rag-sequence-nq"

        tokenizer = RagTokenizer.from_pretrained(checkpoint)
        retriever = RagRetriever.from_pretrained(checkpoint, index_name="exact", use_dummy_dataset=True)
        model = TFRagSequenceForGeneration.from_pretrained(checkpoint, retriever=retriever, from_pt=True)

        encoded = tokenizer(self.test_data_questions, return_tensors="tf", padding=True, truncation=True)
        input_ids = encoded.input_ids
        # Kept from the original flow; the generate call below does not use it.
        attention_mask = encoded.attention_mask

        # First element of the question encoder output feeds the retriever.
        question_hidden_states = model.question_encoder(input_ids)[0]
        docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")

        # Score documents: insert an axis at position 1 of the question states,
        # multiply by the transposed retrieved embeddings, then drop that axis.
        expanded_states = tf.expand_dims(question_hidden_states, axis=[1])
        raw_scores = tf.matmul(expanded_states, docs_dict["retrieved_doc_embeds"], transpose_b=True)
        doc_scores = tf.squeeze(raw_scores, axis=[1])

        generated_ids = model.generate(
            context_input_ids=docs_dict["context_input_ids"],
            context_attention_mask=docs_dict["context_attention_mask"],
            doc_scores=doc_scores,
            do_deduplication=True,
        )
        decoded_answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_answers = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
            " reticular formation",
            " walls of the abdomen",
            " spodumene",
            " obama",
            " new orleans",
            " japan",
            " old trafford",
        ]
        self.assertListEqual(decoded_answers, expected_answers)
 def sequence_model(self):
     """Build a ``TFRagSequenceForGeneration`` from the two checkpoint names below.

     The model is created with ``from_pretrained_question_encoder_generator``;
     no retriever or config is passed.
     """
     question_encoder_name = "facebook/dpr-question_encoder-single-nq-base"
     generator_name = "facebook/bart-large-cnn"
     return TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
         question_encoder_name,
         generator_name,
     )