def test_inference_reader(self):
        """Run the pretrained REALM reader on a fixed toy batch and pin its prediction.

        The batch holds two 5-token sequences; the second one is given the
        higher relevance score (0.7 vs 0.3). The test checks the output shapes
        and the exact predicted (block index, span start, span end).

        NOTE(review): loads the "qqaatw/realm-orqa-nq-reader" checkpoint via
        ``from_pretrained``, so network access or a local cache is presumably
        required.
        """
        config = RealmConfig(reader_beam_size=2, max_span_width=3)
        model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader",
                                            config=config)

        # Two sequences of 5 token ids each: [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]].
        concat_input_ids = torch.arange(10).view((2, 5))
        # First two positions are segment 0, the remaining three segment 1.
        concat_token_type_ids = torch.tensor(
            [[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
        relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)

        output = model(concat_input_ids,
                       token_type_ids=concat_token_type_ids,
                       relevance_score=relevance_score,
                       return_dict=True)

        block_idx_expected_shape = torch.Size(())
        start_pos_expected_shape = torch.Size((1, ))
        end_pos_expected_shape = torch.Size((1, ))
        self.assertEqual(output.block_idx.shape, block_idx_expected_shape)
        self.assertEqual(output.start_pos.shape, start_pos_expected_shape)
        self.assertEqual(output.end_pos.shape, end_pos_expected_shape)

        # These outputs are integer indices, so compare them exactly. A
        # tolerance-based `torch.allclose` adds nothing for integers, demands an
        # exact dtype match, and on failure only reports "False is not true".
        # `.item()` is safe: the assertions above guarantee one element each.
        self.assertEqual(output.block_idx.item(), 1)
        self.assertEqual(output.start_pos.item(), 3)
        self.assertEqual(output.end_pos.item(), 3)
 def create_and_check_reader(
     self,
     config,
     input_ids,
     token_type_ids,
     input_mask,
     scorer_encoder_inputs,
     reader_inputs,
     sequence_labels,
     token_labels,
     choice_labels,
 ):
     """Build a freshly configured RealmReader and check that its outputs are scalars.

     Only ``config`` and ``reader_inputs`` are used here; the other parameters
     are accepted but ignored (presumably so every checker shares one
     signature — verify against the caller). ``reader_inputs`` is indexed as
     (input ids, attention mask, token type ids).
     """
     reader = RealmReader(config=config)
     reader.to(torch_device)
     reader.eval()

     # Random relevance scores of shape [reader_beam_size].
     scores = floats_tensor([self.reader_beam_size])
     reader_ids = reader_inputs[0]
     reader_mask = reader_inputs[1]
     reader_type_ids = reader_inputs[2]
     outputs = reader(
         reader_ids,
         attention_mask=reader_mask,
         token_type_ids=reader_type_ids,
         relevance_score=scores,
     )

     # Each of these fields must come back as a 0-d tensor.
     for field_name in ("block_idx", "candidate", "start_pos", "end_pos"):
         self.parent.assertEqual(getattr(outputs, field_name).shape, ())
 def test_reader_from_pretrained(self):
     """Check that the "qqaatw/realm-orqa-nq-reader" checkpoint loads as a RealmReader."""
     self.assertIsNotNone(
         RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader"))