Example #1
0
    def test_trans_d_predict(self):
        """Test TransD's predict function.

        Checks that ``predict`` returns one prediction per test triple and
        that the predictions are floating-point values.
        """
        # NOTE(review): this test is named for TransD but instantiates TransR;
        # confirm whether TransD (if imported in this module) was intended.
        trans_d = TransR(config=TRANS_D_CONFIG)
        predictions = trans_d.predict(triples=TEST_TRIPLES)

        self.assertEqual(len(predictions), len(TEST_TRIPLES))
        # assertTrue(x, float) would treat `float` as a failure message and
        # always pass, so check the element type explicitly instead.
        # assumes `predictions` is an array/tensor whose elements support
        # .item() (numpy / torch) — TODO confirm against predict()'s return type.
        self.assertIsInstance(predictions[0].item(), float)
Example #2
0
    def test_trans_r_predict(self):
        """Test TransR's predict function.

        Checks that ``predict`` returns one prediction per test triple and
        that the predictions are floating-point values.
        """
        trans_r = TransR(**TRANS_R_CONFIG)
        predictions = trans_r.predict(triples=TEST_TRIPLES)

        self.assertEqual(len(predictions), len(TEST_TRIPLES))
        # assertTrue(x, float) would treat `float` as a failure message and
        # always pass, so check the element type explicitly instead.
        # assumes `predictions` is an array/tensor whose elements support
        # .item() (numpy / torch) — TODO confirm against predict()'s return type.
        self.assertIsInstance(predictions[0].item(), float)
Example #3
0
    def test_trans_d_predict(self):
        """Test TransD's predict function.

        Builds the model from ``TRANS_D_CONFIG``, sets the entity/relation
        counts from the same config, and checks that ``predict`` returns one
        floating-point prediction per test triple.
        """
        # NOTE(review): this test is named for TransD but instantiates TransR;
        # confirm whether TransD (if imported in this module) was intended.
        trans_d = TransR(**TRANS_D_CONFIG)
        trans_d.num_entities = TRANS_D_CONFIG[NUM_ENTITIES]
        trans_d.num_relations = TRANS_D_CONFIG[NUM_RELATIONS]
        predictions = trans_d.predict(triples=TEST_TRIPLES)

        self.assertEqual(len(predictions), len(TEST_TRIPLES))
        # assertTrue(x, float) would treat `float` as a failure message and
        # always pass, so check the element type explicitly instead.
        # assumes `predictions` is an array/tensor whose elements support
        # .item() (numpy / torch) — TODO confirm against predict()'s return type.
        self.assertIsInstance(predictions[0].item(), float)
Example #4
0
    def test_compute_scores_trans_r(self):
        """Test that TransR's score function computes the scores correctly.

        Passes already-projected head, relation and tail embeddings directly
        to ``_compute_scores`` and compares the result against hand-computed
        expected scores.
        """
        trans_r = TransR(config=TRANS_R_CONFIG)
        proj_h_embs = torch.tensor([[1., 1.], [1., 1.]], dtype=torch.float)
        proj_r_embs = torch.tensor([[1., 1.], [2., 2.]], dtype=torch.float)
        proj_t_embs = torch.tensor([[2., 2.], [4., 4.]], dtype=torch.float)

        scores = trans_r._compute_scores(proj_h_embs, proj_r_embs,
                                         proj_t_embs).cpu().numpy().tolist()

        expected_scores = [0., 4.]
        self.assertEqual(len(scores), len(expected_scores))
        # Compare with a tolerance: the scores come out of float32 tensor
        # arithmetic, so exact equality is brittle.
        for score, expected in zip(scores, expected_scores):
            self.assertAlmostEqual(score, expected, places=5)
Example #5
0
 def test_instantiate_trans_r(self):
     """Check that a TransR model can be built and exposes the configured values.

     The entity and relation counts are copied from ``TRANS_R_CONFIG`` onto the
     model after construction, then each attribute is compared with its
     expected value.
     """
     model = TransR(**TRANS_R_CONFIG)
     model.num_entities = TRANS_R_CONFIG[NUM_ENTITIES]
     model.num_relations = TRANS_R_CONFIG[NUM_RELATIONS]
     self.assertIsNotNone(model)

     # (attribute name, expected value), checked in this order.
     expected_attributes = (
         ('num_entities', 5),
         ('num_relations', 5),
         ('embedding_dim', 5),
         ('relation_embedding_dim', 3),
         ('scoring_fct_norm', 1),
         ('margin_loss', 4),
     )
     for attribute_name, expected_value in expected_attributes:
         self.assertEqual(getattr(model, attribute_name), expected_value)