def test_inference_open_qa(self):
    """End-to-end open-QA inference against the REALM NQ checkpoint.

    Loads the tokenizer, retriever and ``RealmForOpenQA`` model from the
    ``"qqaatw/realm-orqa-nq-openqa"`` checkpoint. It tokenizes a single
    question, truncated to ``model.config.searcher_seq_len`` tokens, and runs
    the model. It then decodes ``predicted_answer_ids`` and asserts that the
    decoded text is exactly ``"alan mathison turing"``.
    """
    # Imported locally, as in the original test body.
    from transformers.models.realm.retrieval_realm import RealmRetriever

    checkpoint = "qqaatw/realm-orqa-nq-openqa"

    tokenizer = RealmTokenizer.from_pretrained(checkpoint)
    retriever = RealmRetriever.from_pretrained(checkpoint)

    # No explicit ``config=`` is passed. A freshly constructed default
    # ``RealmConfig()`` would replace the checkpoint's own configuration,
    # and ``searcher_seq_len`` is read from it below.
    model = RealmForOpenQA.from_pretrained(
        checkpoint,
        retriever=retriever,
    )

    question = "Who is the pioneer in modern computer science?"

    question = tokenizer(
        [question],
        padding=True,
        truncation=True,
        # Truncate to the sequence length the searcher is configured for.
        max_length=model.config.searcher_seq_len,
        return_tensors="pt",
    ).to(model.device)

    predicted_answer_ids = model(**question).predicted_answer_ids

    predicted_answer = tokenizer.decode(predicted_answer_ids)
    self.assertEqual(predicted_answer, "alan mathison turing")
def test_open_qa_from_pretrained(self):
    """Smoke test: the REALM open-QA checkpoint loads into ``RealmForOpenQA``.

    Only checks that ``from_pretrained`` returns an object; no inference is
    run.
    """
    checkpoint_id = "qqaatw/realm-orqa-nq-openqa"
    loaded_model = RealmForOpenQA.from_pretrained(checkpoint_id)
    self.assertIsNotNone(loaded_model)