def test_init_and_from_pretrained(self):
        """Construct a TF RAG token model from a config, run it once, then save and reload it.

        The retriever is built from ``self.get_rag_config()``, while the model
        itself is built from the ``facebook/rag-sequence-base`` config. Nothing
        is asserted on the outputs; the test passes when every step runs
        without raising.
        """
        retriever_config = self.get_rag_config()
        generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            retriever_config,
            question_encoder_tokenizer=question_tok,
            generator_tokenizer=generator_tok,
        )

        # The model uses the published base config, not the retriever's test config.
        model_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
        rag = TFRagTokenForGeneration(model_config, retriever=retriever)

        question_ids = question_tok(
            "who sings does he love me with reba", return_tensors="tf"
        ).input_ids
        answer_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

        # One forward pass before saving; the result is intentionally ignored.
        rag(question_ids, decoder_input_ids=answer_ids)

        # this should not give any warnings
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag.save_pretrained(tmpdirname)
            rag = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=retriever)
Exemplo n.º 2
0
    def test_rag_token_inference(self):
        """Run the TF RAG token model on one question/answer pair and compare against recorded values.

        Checks the logits shape, the loss and the document scores (the latter
        two with an absolute tolerance of 1e-3).
        """
        config = self.get_rag_config()
        generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tok,
            generator_tokenizer=generator_tok,
        )

        model = self.token_model
        model.set_retriever(retriever)

        question_ids = question_tok("who sings does he love me with reba", return_tensors="tf").input_ids
        answer_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

        # The tokenized answer is passed as labels so the output carries a loss.
        output = model(question_ids, labels=answer_ids)

        self.assertEqual(output.logits.shape, tf.TensorShape([5, 5, 50264]))

        # Reference values to compare the model output against.
        loss_reference = tf.convert_to_tensor([36.3557])
        doc_scores_reference = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])

        tf.debugging.assert_near(output.loss, loss_reference, atol=1e-3)
        tf.debugging.assert_near(output.doc_scores, doc_scores_reference, atol=1e-3)
Exemplo n.º 3
0
    def test_rag_sequence_from_pretrained(self):
        """Check that a saved-and-reloaded RAG sequence model matches one assembled directly.

        A model is created with ``from_pretrained_question_encoder_generator``,
        saved to a temporary directory and reloaded with ``from_pretrained``.
        Its loss on a fixed question/answer pair is then compared (to 4
        decimal places) with the loss of a model assembled from separately
        loaded question-encoder and generator sub-models.
        """
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids
        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids

        input_ids = input_ids.to(torch_device)
        decoder_input_ids = decoder_input_ids.to(torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            rag_sequence = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
                "facebook/dpr-question_encoder-single-nq-base",
                "facebook/bart-large-cnn",
                retriever=rag_retriever,
                config=rag_config,
            ).to(torch_device)
            # check that the from pretrained methods work
            rag_sequence.save_pretrained(tmp_dirname)
            # `from_pretrained` returns a new model instead of loading into the
            # instance it is called on, so rebind the name; otherwise the loss
            # below would come from the model that was never reloaded.
            rag_sequence = RagSequenceForGeneration.from_pretrained(tmp_dirname, retriever=rag_retriever)
            rag_sequence.to(torch_device)

            with torch.no_grad():
                output = rag_sequence(
                    input_ids,
                    labels=decoder_input_ids,
                )

            loss_pretrained = output.loss
            # Drop the reloaded model before the second one is built below.
            del rag_sequence

        # Reference model: the same two checkpoints, loaded separately and wired together by hand.
        question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
        rag_sequence = RagSequenceForGeneration(
            config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
        )
        rag_sequence.to(torch_device)

        with torch.no_grad():
            output = rag_sequence(
                input_ids,
                labels=decoder_input_ids,
            )

        loss_init = output.loss

        self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)
    def test_rag_sequence_from_pretrained(self):
        """Check that a saved-and-reloaded TF RAG sequence model matches one assembled directly.

        A model is created with ``from_pretrained_question_encoder_generator``,
        saved to a temporary directory and reloaded with ``from_pretrained``.
        Its loss on a fixed question/answer pair is then compared (to 4
        decimal places) with the loss of a model assembled from separately
        loaded question-encoder and generator sub-models.
        """
        load_weight_prefix = "tf_rag_model_1"

        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained(
            "facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba",
            return_tensors="tf").input_ids
        decoder_input_ids = rag_decoder_tokenizer(
            "Linda Davis", return_tensors="tf").input_ids

        with tempfile.TemporaryDirectory() as tmp_dirname:
            rag_sequence = TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
                "facebook/dpr-question_encoder-single-nq-base",
                "facebook/bart-large-cnn",
                retriever=rag_retriever,
                config=rag_config,
            )
            # check that the from pretrained methods work
            rag_sequence.save_pretrained(tmp_dirname)
            # `from_pretrained` returns a new model instead of loading into the
            # instance it is called on, so rebind the name; otherwise the loss
            # below would come from the model that was never reloaded.
            rag_sequence = TFRagSequenceForGeneration.from_pretrained(
                tmp_dirname, retriever=rag_retriever)

            output = rag_sequence(input_ids, labels=decoder_input_ids)

            loss_pretrained = output.loss
            # Drop the reloaded model before the second one is built below.
            del rag_sequence

        # Reference model: the same two checkpoints, loaded separately and wired together by hand.
        question_encoder = TFAutoModel.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        generator = TFAutoModelForSeq2SeqLM.from_pretrained(
            "facebook/bart-large-cnn",
            load_weight_prefix=load_weight_prefix,
            name="generator")

        rag_sequence = TFRagSequenceForGeneration(
            config=rag_config,
            question_encoder=question_encoder,
            generator=generator,
            retriever=rag_retriever)

        output = rag_sequence(input_ids, labels=decoder_input_ids)

        loss_init = output.loss

        self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
    def test_rag_token_inference_save_pretrained(self):
        """Save and reload the TF RAG token model, then compare its outputs against recorded values.

        The model is called once before saving; the reloaded model is then
        checked for logits shape, loss and document scores (the latter two
        with an absolute tolerance of 1e-3).
        """
        config = self.get_rag_config()
        generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tok,
            generator_tokenizer=generator_tok,
        )

        rag_token = self.token_model
        rag_token.set_retriever(retriever)

        question_ids = question_tok(
            "who sings does he love me with reba", return_tensors="tf"
        ).input_ids
        answer_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

        # model must run once to be functional before loading/saving works
        rag_token(question_ids, labels=answer_ids)

        # check that outputs after saving and loading are equal
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag_token.save_pretrained(tmpdirname)
            rag_token = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=retriever)

        # From here on `rag_token` is the reloaded model.
        output = rag_token(question_ids, labels=answer_ids)

        self.assertEqual(output.logits.shape, tf.TensorShape([5, 5, 50264]))

        loss_reference = tf.convert_to_tensor([36.3557])
        doc_scores_reference = tf.convert_to_tensor(
            [[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]
        )

        tf.debugging.assert_near(output.loss, loss_reference, atol=1e-3)
        tf.debugging.assert_near(output.doc_scores, doc_scores_reference, atol=1e-3)
    def test_rag_token_inference_nq_checkpoint(self):
        """Save and reload the NQ-checkpoint TF RAG token model, then compare outputs against recorded values.

        The model comes from ``self.token_model_nq_checkpoint``; after the
        save/load round trip its logits shape, loss and document scores are
        checked (the latter two with an absolute tolerance of 1e-3).
        """
        config = self.get_rag_config()
        generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tok,
            generator_tokenizer=generator_tok,
        )

        rag_token = self.token_model_nq_checkpoint(retriever=retriever)

        # check that outputs after saving and loading are equal
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag_token.save_pretrained(tmpdirname)
            rag_token = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=retriever)

        question_ids = question_tok(
            "who sings does he love me with reba", return_tensors="tf"
        ).input_ids
        answer_ids = generator_tok("Linda Davis", return_tensors="tf").input_ids

        # Only the reloaded model is evaluated.
        output = rag_token(question_ids, labels=answer_ids)

        self.assertEqual(output.logits.shape, tf.TensorShape([5, 5, 50265]))

        loss_reference = tf.convert_to_tensor([32.521812])
        doc_scores_reference = tf.convert_to_tensor(
            [[62.9402, 62.7107, 62.2382, 62.1194, 61.8578]]
        )

        tf.debugging.assert_near(output.loss, loss_reference, atol=1e-3)
        tf.debugging.assert_near(output.doc_scores, doc_scores_reference, atol=1e-3)
Exemplo n.º 7
0
    def test_rag_token_inference(self):
        """Run the PyTorch RAG token model on one question/answer pair and compare against recorded values.

        Checks the logits shape, then the document scores and the loss via
        ``_assert_tensors_equal`` with ``atol=TOLERANCE``.
        """
        config = self.get_rag_config()
        generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tok,
            generator_tokenizer=generator_tok,
        )

        model = self.token_model
        model.set_retriever(retriever)

        # Tokenize and move both tensors onto the test device.
        question_ids = question_tok(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids.to(torch_device)
        answer_ids = generator_tok("Linda Davis", return_tensors="pt").input_ids.to(torch_device)

        with torch.no_grad():
            output = model(question_ids, labels=answer_ids)

        self.assertEqual(output.logits.shape, torch.Size([5, 5, 50264]))

        doc_scores_reference = torch.tensor(
            [[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]
        ).to(torch_device)
        _assert_tensors_equal(doc_scores_reference, output.doc_scores, atol=TOLERANCE)

        loss_reference = torch.tensor([36.3557]).to(torch_device)
        _assert_tensors_equal(loss_reference, output.loss, atol=TOLERANCE)
Exemplo n.º 8
0
    def test_rag_sequence_generate_beam(self):
        """Generate two beam-search sequences with the RAG sequence model and compare the decoded text.

        Generation uses ``num_beams=2`` and ``num_return_sequences=2``; both
        decoded outputs must equal the strings recorded at integration time.
        """
        config = self.get_rag_config()
        generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tok,
            generator_tokenizer=generator_tok,
        )

        model = self.sequence_model
        model.set_retriever(retriever)

        question_ids = question_tok(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids.to(torch_device)

        # Start decoding from the token configured on the underlying generator.
        start_token_id = model.generator.config.decoder_start_token_id
        output_ids = model.generate(
            question_ids,
            decoder_start_token_id=start_token_id,
            num_beams=2,
            num_return_sequences=2,
        )

        # sequence generate test
        first_text = generator_tok.decode(output_ids[0], skip_special_tokens=True)
        second_text = generator_tok.decode(output_ids[1], skip_special_tokens=True)

        # Expected outputs as given by model at integration time.
        EXPECTED_OUTPUT_TEXT_1 = """\"She's My Kind of Girl\" was released through Epic Records in Japan in March 1972, giving the duo a Top 10 hit. Two more singles were released in Japan, \"En Carousel\" and \"Love Has Its Ways\" Ulvaeus and Andersson persevered with their songwriting and experimented with new sounds and vocal arrangements."""
        EXPECTED_OUTPUT_TEXT_2 = """In September 2018, Björn Ulvaeus revealed that the two new songs, \"I Still Have Faith In You\" and \"Don't Shut Me Down\", would be released no earlier than March 2019. The two new tracks will feature in a TV special set to air later in the year."""

        self.assertEqual(first_text, EXPECTED_OUTPUT_TEXT_1)
        self.assertEqual(second_text, EXPECTED_OUTPUT_TEXT_2)
Exemplo n.º 9
0
    def test_rag_token_generate_beam(self):
        """Generate two beam-search sequences with the RAG token model and compare the decoded text.

        Generation uses ``num_beams=2`` and ``num_return_sequences=2``; both
        decoded outputs must equal the strings recorded at integration time.
        """
        config = self.get_rag_config()
        generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tok,
            generator_tokenizer=generator_tok,
        )

        model = self.token_model
        model.set_retriever(retriever)

        question_ids = question_tok(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids.to(torch_device)

        # Start decoding from the token configured on the underlying generator.
        start_token_id = model.generator.config.decoder_start_token_id
        output_ids = model.generate(
            question_ids,
            decoder_start_token_id=start_token_id,
            num_beams=2,
            num_return_sequences=2,
        )

        # sequence generate test
        first_text = generator_tok.decode(output_ids[0], skip_special_tokens=True)
        second_text = generator_tok.decode(output_ids[1], skip_special_tokens=True)

        # Expected outputs as given by model at integration time.
        EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
        EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"

        self.assertEqual(first_text, EXPECTED_OUTPUT_TEXT_1)
        self.assertEqual(second_text, EXPECTED_OUTPUT_TEXT_2)
Exemplo n.º 10
0
 def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
     """Load the DPR question-encoder tokenizer from the ``dpr_tokenizer`` folder under ``self.tmpdirname``."""
     tokenizer_dir = os.path.join(self.tmpdirname, "dpr_tokenizer")
     return DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_dir)