Exemplo n.º 1
0
    def test_inference_embedder(self):
        """Check the pretrained embedder's output shape and leading values for a fixed input."""
        projection_dim = 128
        checkpoint = "google/realm-cc-news-pretrained-embedder"

        embedder = RealmEmbedder.from_pretrained(checkpoint)
        token_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        # Only the first element of the model output is verified here.
        projected = embedder(token_ids)[0]

        # One input row is expected to map to one row of `projection_dim` values.
        self.assertEqual(projected.shape, torch.Size((1, projection_dim)))

        # Compare the first three columns against the recorded reference values.
        reference = torch.tensor([[-0.0714, -0.0837, -0.1314]])
        matches_reference = torch.allclose(projected[:, :3], reference, atol=1e-4)
        self.assertTrue(matches_reference)
Exemplo n.º 2
0
 def create_and_check_embedder(
     self,
     config,
     input_ids,
     token_type_ids,
     input_mask,
     scorer_encoder_inputs,
     reader_inputs,
     sequence_labels,
     token_labels,
     choice_labels,
 ):
     """Run a RealmEmbedder built from `config` and check the shape of its projected score.

     The model is moved to `torch_device` and switched to eval mode before the
     forward pass. Only `config`, `input_ids`, `token_type_ids` and `input_mask`
     are used in this body; the remaining parameters are accepted but ignored.
     """
     embedder = RealmEmbedder(config=config)
     embedder.to(torch_device)
     embedder.eval()

     outputs = embedder(
         input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
     )

     actual_shape = outputs.projected_score.shape
     wanted_shape = (self.batch_size, self.retriever_proj_size)
     self.parent.assertEqual(actual_shape, wanted_shape)
 def test_embedder_from_pretrained(self):
     """Verify that loading the pretrained embedder checkpoint yields a non-None model."""
     checkpoint = "qqaatw/realm-cc-news-pretrained-embedder"
     embedder = RealmEmbedder.from_pretrained(checkpoint)
     self.assertIsNotNone(embedder)