def test_inference_reader(self):
    """Run the pretrained reader on a tiny fixed batch and pin its prediction.

    Loads the "qqaatw/realm-orqa-nq-reader" checkpoint with
    ``reader_beam_size=2`` and ``max_span_width=3``, feeds two rows of five
    token ids with fixed relevance scores, and checks both the shapes and the
    exact values of the predicted block index and span start/end positions.
    """
    config = RealmConfig(reader_beam_size=2, max_span_width=3)
    model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config)

    # Two rows of 5 token ids each: 0..4 and 5..9.
    concat_input_ids = torch.arange(10).view((2, 5))
    # In every row the first two tokens are type 0 and the last three type 1.
    concat_token_type_ids = torch.tensor(
        [[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64
    )
    # One score per row; presumably must match reader_beam_size=2 -- confirm.
    relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        output = model(
            concat_input_ids,
            token_type_ids=concat_token_type_ids,
            relevance_score=relevance_score,
            return_dict=True,
        )

    # NOTE(review): block_idx is expected to be 0-d while start_pos/end_pos are
    # expected to hold one element; confirm this asymmetry is intended.
    self.assertEqual(output.block_idx.shape, torch.Size(()))
    self.assertEqual(output.start_pos.shape, torch.Size((1,)))
    self.assertEqual(output.end_pos.shape, torch.Size((1,)))

    # The expected values are integer indices, so compare them exactly.
    # .item() accepts both the 0-d and the 1-element shapes asserted above,
    # and assertEqual reports the actual value on failure (assertTrue(allclose)
    # does not).
    self.assertEqual(output.block_idx.item(), 1)
    self.assertEqual(output.start_pos.item(), 3)
    self.assertEqual(output.end_pos.item(), 3)
def create_and_check_reader(
    self,
    config,
    input_ids,
    token_type_ids,
    input_mask,
    scorer_encoder_inputs,
    reader_inputs,
    sequence_labels,
    token_labels,
    choice_labels,
):
    """Build a randomly initialised RealmReader and check its output shapes.

    Only ``config`` and the first three entries of ``reader_inputs`` (ids,
    attention mask, token type ids) are consumed. The remaining parameters
    are accepted so that the full prepared-inputs tuple can be passed
    positionally. Every prediction field of the result is asserted to be 0-d.
    """
    reader = RealmReader(config=config)
    reader.to(torch_device)
    reader.eval()

    # One random relevance score per beam entry.
    scores = floats_tensor([self.reader_beam_size])
    ids = reader_inputs[0]
    mask = reader_inputs[1]
    segment_ids = reader_inputs[2]

    result = reader(
        ids,
        attention_mask=mask,
        token_type_ids=segment_ids,
        relevance_score=scores,
    )

    # Checked in this fixed order so the first failing field is reported.
    for field_name in ("block_idx", "candidate", "start_pos", "end_pos"):
        self.parent.assertEqual(getattr(result, field_name).shape, ())
def test_reader_from_pretrained(self):
    """Smoke test: loading the "qqaatw/realm-orqa-nq-reader" checkpoint yields a model."""
    self.assertIsNotNone(RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader"))