def test_inference_open_qa(self):
        """End-to-end open-domain QA check for ``RealmForOpenQA``.

        Loads the tokenizer, retriever and model from one pretrained
        checkpoint, asks a single question and asserts the decoded predicted
        answer string.
        """
        from transformers.models.realm.retrieval_realm import RealmRetriever

        # Same checkpoint id as the tokenizer used by ``test_training``.
        checkpoint = "google/realm-orqa-nq-openqa"

        tokenizer = RealmTokenizer.from_pretrained(checkpoint)
        retriever = RealmRetriever.from_pretrained(checkpoint)

        # No ``config=RealmConfig()`` here: an explicitly passed config is used
        # instead of the checkpoint's own configuration, so the model would be
        # built from library defaults rather than the checkpoint's settings.
        model = RealmForOpenQA.from_pretrained(checkpoint, retriever=retriever)

        question = "Who is the pioneer in modern computer science?"

        # Tokenize as a batch of one, capped at the searcher's sequence length,
        # and move the resulting tensors onto the model's device.
        question_inputs = tokenizer(
            [question],
            padding=True,
            truncation=True,
            max_length=model.config.searcher_seq_len,
            return_tensors="pt",
        ).to(model.device)

        predicted_answer_ids = model(**question_inputs).predicted_answer_ids

        predicted_answer = tokenizer.decode(predicted_answer_ids)
        self.assertEqual(predicted_answer, "alan mathison turing")
Exemplo n.º 2
0
    def test_training(self):
        """Check that a training step (forward + backward) runs for REALM heads.

        First exercises ``RealmKnowledgeAugEncoder`` on the model tester's
        random inputs. Then it exercises ``RealmForOpenQA`` backed by a tiny
        in-memory corpus of five block records. Only verifies that
        ``loss.backward()`` completes for each model.
        """
        if not self.model_tester.is_training:
            return

        config, *inputs = self.model_tester.prepare_config_and_inputs()
        input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
        config.return_dict = True

        tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")

        # RealmKnowledgeAugEncoder training
        model = RealmKnowledgeAugEncoder(config)
        model.to(torch_device)
        model.train()

        # Positions 0/1/2 of scorer_encoder_inputs feed input_ids /
        # attention_mask / token_type_ids respectively.
        inputs_dict = {
            "input_ids": scorer_encoder_inputs[0].to(torch_device),
            "attention_mask": scorer_encoder_inputs[1].to(torch_device),
            "token_type_ids": scorer_encoder_inputs[2].to(torch_device),
            "relevance_score": floats_tensor([self.model_tester.batch_size, self.model_tester.num_candidates]),
            "labels": torch.zeros(
                (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
            ),
        }
        loss = model(**inputs_dict).loss
        loss.backward()

        # RealmForOpenQA training
        openqa_config = copy.deepcopy(config)
        openqa_config.vocab_size = 30522  # the retrieved texts will inevitably have more than 99 vocabs.
        # presumably must equal len(block_records) below — confirm against RealmRetriever
        openqa_config.num_block_records = 5
        openqa_config.searcher_beam_size = 2

        # `np.object` was deprecated in NumPy 1.20 and removed in 1.24; the
        # builtin `object` is the equivalent dtype on every NumPy version.
        block_records = np.array(
            [
                b"This is the first record.",
                b"This is the second record.",
                b"This is the third record.",
                b"This is the fourth record.",
                b"This is the fifth record.",
            ],
            dtype=object,
        )
        retriever = RealmRetriever(block_records, tokenizer)
        model = RealmForOpenQA(openqa_config, retriever)
        model.to(torch_device)
        model.train()

        # Only the first example is used; answer_ids is passed as a Python list.
        inputs_dict = {
            "input_ids": input_ids[:1].to(torch_device),
            "attention_mask": input_mask[:1].to(torch_device),
            "token_type_ids": token_type_ids[:1].to(torch_device),
            "answer_ids": input_ids[:1].tolist(),
        }
        inputs = self._prepare_for_class(inputs_dict, RealmForOpenQA)
        loss = model(**inputs).reader_output.loss
        loss.backward()