Пример #1
0
    def test_ermlp_predict(self):
        """Test ERMLP's predict function.

        Builds an ERMLP model from ``ERMLP_CONFIG``, predicts on
        ``TEST_TRIPLES`` and checks that one floating-point prediction is
        returned per input triple.
        """
        import numpy as np

        ermlp = ERMLP(config=ERMLP_CONFIG)
        predictions = ermlp.predict(triples=TEST_TRIPLES)

        # One prediction per input triple.
        self.assertEqual(len(predictions), len(TEST_TRIPLES))

        # Check the element dtype explicitly: ``assertTrue(x, float)`` treats
        # ``float`` as the failure message and can never fail.
        # NOTE(review): assumes ``predict`` returns an array-like (e.g. a NumPy
        # array) that ``np.asarray`` can convert — confirm against ERMLP.
        dtype = np.asarray(predictions).dtype
        self.assertTrue(
            np.issubdtype(dtype, np.floating),
            msg='predictions have non-floating dtype {}'.format(dtype),
        )
Пример #2
0
    def test_ermlp_predict(self):
        """Test ERMLP's predict function.

        Builds an ERMLP model from the keyword arguments in ``ERMLP_CONFIG``,
        sets its entity/relation counts from the same config, predicts on
        ``TEST_TRIPLES`` and checks that one floating-point prediction is
        returned per input triple.
        """
        import numpy as np

        ermlp = ERMLP(**ERMLP_CONFIG)
        # NOTE(review): the counts are assigned after construction; presumably
        # the keyword constructor does not populate them — confirm.
        ermlp.num_entities = ERMLP_CONFIG[NUM_ENTITIES]
        ermlp.num_relations = ERMLP_CONFIG[NUM_RELATIONS]
        predictions = ermlp.predict(triples=TEST_TRIPLES)

        # One prediction per input triple.
        self.assertEqual(len(predictions), len(TEST_TRIPLES))

        # Check the element dtype explicitly: ``assertTrue(x, float)`` treats
        # ``float`` as the failure message and can never fail.
        # NOTE(review): assumes ``predict`` returns an array-like (e.g. a NumPy
        # array) that ``np.asarray`` can convert — confirm against ERMLP.
        dtype = np.asarray(predictions).dtype
        self.assertTrue(
            np.issubdtype(dtype, np.floating),
            msg='predictions have non-floating dtype {}'.format(dtype),
        )
Пример #3
0
 def test_instantiate_ermlp(self):
     """Check that an ERMLP model can be built from keyword arguments.

     The entity and relation counts are copied from ``ERMLP_CONFIG`` onto the
     instance, then the model's attributes are compared with fixed values.
     """
     model = ERMLP(**ERMLP_CONFIG)
     model.num_entities = ERMLP_CONFIG[NUM_ENTITIES]
     model.num_relations = ERMLP_CONFIG[NUM_RELATIONS]
     self.assertIsNotNone(model)

     # (attribute name, expected value), checked in order.
     expected_attributes = (
         ('num_entities', 5),
         ('num_relations', 5),
         ('embedding_dim', 2),
         ('margin_loss', 4),
     )
     for attr_name, expected_value in expected_attributes:
         self.assertEqual(getattr(model, attr_name), expected_value)
Пример #4
0
    def test_compute_scores_ermlp(self):
        """Test that ERMLP's score function returns one score per triple.

        Feeds a batch of two (head, relation, tail) embedding rows to
        ``_compute_scores`` and checks that exactly one score comes back for
        each row. Score values themselves are not checked.
        """
        ermlp = ERMLP(config=ERMLP_CONFIG)

        # A batch of two triples, each with 2-dimensional embeddings.
        h_embs = torch.tensor([[1., 1.], [1., 1.]], dtype=torch.float)
        r_embs = torch.tensor([[1., 1.], [2., 2.]], dtype=torch.float)
        t_embs = torch.tensor([[2., 2.], [4., 4.]], dtype=torch.float)

        scores = ermlp._compute_scores(h_embs, r_embs, t_embs)
        # Detach from autograd and move to host memory before converting.
        scores = scores.detach().cpu().numpy().tolist()

        # One score per row in the input batch (2 here).
        self.assertEqual(len(scores), h_embs.shape[0])
Пример #5
0
 def test_instantiate_ermlp(self):
     """Check that an ERMLP model can be built from a config mapping.

     The model is constructed with ``config=ERMLP_CONFIG`` and its attributes
     are compared with fixed expected values.
     """
     model = ERMLP(config=ERMLP_CONFIG)
     self.assertIsNotNone(model)

     # (attribute name, expected value), checked in order.
     expected_attributes = (
         ('num_entities', 5),
         ('num_relations', 5),
         ('embedding_dim', 2),
         ('margin_loss', 4),
     )
     for attr_name, expected_value in expected_attributes:
         self.assertEqual(getattr(model, attr_name), expected_value)