def test_inference_open_qa(self):
    """Check the answer produced by the pretrained open-QA model.

    The tokenizer, retriever and ``RealmForOpenQA`` model are all obtained
    via ``from_pretrained`` on the same checkpoint id; the model is also
    handed a freshly constructed ``RealmConfig``. A single question is
    tokenized, passed through the model, and the decoded
    ``predicted_answer_ids`` are compared with the expected answer string.
    """
    from transformers.models.realm.retrieval_realm import RealmRetriever

    checkpoint = "qqaatw/realm-orqa-nq-openqa"

    realm_config = RealmConfig()
    tokenizer = RealmTokenizer.from_pretrained(checkpoint)
    retriever = RealmRetriever.from_pretrained(checkpoint)
    model = RealmForOpenQA.from_pretrained(
        checkpoint,
        retriever=retriever,
        config=realm_config,
    )

    # Encode a one-element batch, truncated to the searcher sequence length,
    # and move the resulting tensors to the device the model lives on.
    encoded_question = tokenizer(
        ["Who is the pioneer in modern computer science?"],
        padding=True,
        truncation=True,
        max_length=model.config.searcher_seq_len,
        return_tensors="pt",
    ).to(model.device)

    answer_ids = model(**encoded_question).predicted_answer_ids
    decoded_answer = tokenizer.decode(answer_ids)

    self.assertEqual(decoded_answer, "alan mathison turing")
def test_training(self):
    """Run one forward/backward pass through the trainable REALM models.

    Two models are exercised in ``train()`` mode:

    * ``RealmKnowledgeAugEncoder`` -- the loss is read from ``.loss``.
    * ``RealmForOpenQA`` -- the loss is read from ``.reader_output.loss``.

    The method returns early, without running anything, when
    ``self.model_tester.is_training`` is falsy.
    """
    if not self.model_tester.is_training:
        return

    config, *inputs = self.model_tester.prepare_config_and_inputs()
    input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
    config.return_dict = True

    tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")

    # RealmKnowledgeAugEncoder training
    model = RealmKnowledgeAugEncoder(config)
    model.to(torch_device)
    model.train()

    inputs_dict = {
        "input_ids": scorer_encoder_inputs[0].to(torch_device),
        "attention_mask": scorer_encoder_inputs[1].to(torch_device),
        "token_type_ids": scorer_encoder_inputs[2].to(torch_device),
        "relevance_score": floats_tensor([self.model_tester.batch_size, self.model_tester.num_candidates]),
    }
    inputs_dict["labels"] = torch.zeros(
        (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
    )
    loss = model(**inputs_dict).loss
    loss.backward()

    # RealmForOpenQA training
    openqa_config = copy.deepcopy(config)
    openqa_config.vocab_size = 30522  # the retrieved texts will inevitably have more than 99 vocabs.
    openqa_config.num_block_records = 5
    openqa_config.searcher_beam_size = 2

    # Use the builtin ``object`` as dtype: the ``np.object`` alias was
    # deprecated in NumPy 1.20 and removed in NumPy 1.24, where accessing it
    # raises AttributeError. The builtin is the equivalent dtype on every
    # NumPy version.
    block_records = np.array(
        [
            b"This is the first record.",
            b"This is the second record.",
            b"This is the third record.",
            b"This is the fourth record.",
            b"This is the fifth record.",
        ],
        dtype=object,
    )
    retriever = RealmRetriever(block_records, tokenizer)
    model = RealmForOpenQA(openqa_config, retriever)
    model.to(torch_device)
    model.train()

    # Only the first example of the batch is used here; its own input ids
    # are passed (as a plain list) for ``answer_ids``.
    inputs_dict = {
        "input_ids": input_ids[:1].to(torch_device),
        "attention_mask": input_mask[:1].to(torch_device),
        "token_type_ids": token_type_ids[:1].to(torch_device),
        "answer_ids": input_ids[:1].tolist(),
    }
    inputs = self._prepare_for_class(inputs_dict, RealmForOpenQA)
    loss = model(**inputs).reader_output.loss
    loss.backward()