def test_trans_d_predict(self):
    """Test TransD's predict function.

    Builds the model from ``TRANS_D_CONFIG``, scores ``TEST_TRIPLES`` and checks
    that one floating-point score is returned per triple.
    """
    import numpy as np

    # NOTE(review): a later method in this class is also named
    # ``test_trans_d_predict`` and replaces this one when the class body is
    # executed, so this method never runs; consider renaming or removing one.
    # NOTE(review): this builds a TransR model from the TransD config; a TransD
    # model class is probably the intended target — confirm.
    trans_d = TransR(**TRANS_D_CONFIG)
    # Same construction pattern as the other TransR tests: entity/relation
    # counts are assigned after construction.
    trans_d.num_entities = TRANS_D_CONFIG[NUM_ENTITIES]
    trans_d.num_relations = TRANS_D_CONFIG[NUM_RELATIONS]

    predictions = trans_d.predict(triples=TEST_TRIPLES)

    self.assertEqual(len(predictions), len(TEST_TRIPLES))
    # ``assertTrue(x, float)`` treats ``float`` as the failure message and so
    # always passed; check the dtype of the returned scores instead.
    # Assumes predict() returns a CPU array-like (e.g. a NumPy array) — confirm.
    self.assertTrue(np.issubdtype(np.asarray(predictions).dtype, np.floating))
def test_trans_r_predict(self):
    """Test TransR's predict function.

    Builds the model from ``TRANS_R_CONFIG``, scores ``TEST_TRIPLES`` and checks
    that one floating-point score is returned per triple.
    """
    import numpy as np

    trans_r = TransR(**TRANS_R_CONFIG)
    # Assign entity/relation counts after construction, matching the other
    # TransR tests that build the model via keyword-unpacked config.
    trans_r.num_entities = TRANS_R_CONFIG[NUM_ENTITIES]
    trans_r.num_relations = TRANS_R_CONFIG[NUM_RELATIONS]

    predictions = trans_r.predict(triples=TEST_TRIPLES)

    self.assertEqual(len(predictions), len(TEST_TRIPLES))
    # ``assertTrue(x, float)`` treats ``float`` as the failure message and so
    # always passed; check the dtype of the returned scores instead.
    # Assumes predict() returns a CPU array-like (e.g. a NumPy array) — confirm.
    self.assertTrue(np.issubdtype(np.asarray(predictions).dtype, np.floating))
def test_trans_d_predict(self):
    """Test TransD's predict function.

    Builds the model from ``TRANS_D_CONFIG``, scores ``TEST_TRIPLES`` and checks
    that one floating-point score is returned per triple.
    """
    import numpy as np

    # NOTE(review): this builds a TransR model from the TransD config; a TransD
    # model class is probably the intended target — confirm.
    trans_d = TransR(**TRANS_D_CONFIG)
    trans_d.num_entities = TRANS_D_CONFIG[NUM_ENTITIES]
    trans_d.num_relations = TRANS_D_CONFIG[NUM_RELATIONS]

    predictions = trans_d.predict(triples=TEST_TRIPLES)

    self.assertEqual(len(predictions), len(TEST_TRIPLES))
    # ``assertTrue(x, float)`` treats ``float`` as the failure message and so
    # always passed; check the dtype of the returned scores instead.
    # Assumes predict() returns a CPU array-like (e.g. a NumPy array) — confirm.
    self.assertTrue(np.issubdtype(np.asarray(predictions).dtype, np.floating))
def test_compute_scores_trans_r(self):
    """Test that TransR's score function computes the scores correctly.

    Row 0: h + r - t = (1 + 1 - 2, 1 + 1 - 2) = (0, 0), expected score 0.
    Row 1: h + r - t = (1 + 2 - 4, 1 + 2 - 4) = (-1, -1), expected score 4
    (presumably the squared L1 distance of (-1, -1) — confirm against
    ``_compute_scores``).
    """
    # Build the model the same way as the instantiation test does: the config
    # is passed as keyword arguments, not as a single ``config=`` argument.
    trans_r = TransR(**TRANS_R_CONFIG)

    proj_h_embs = torch.tensor([[1., 1.], [1., 1.]], dtype=torch.float)
    proj_r_embs = torch.tensor([[1., 1.], [2., 2.]], dtype=torch.float)
    proj_t_embs = torch.tensor([[2., 2.], [4., 4.]], dtype=torch.float)

    scores = trans_r._compute_scores(proj_h_embs, proj_r_embs, proj_t_embs).cpu().numpy().tolist()

    self.assertEqual(scores, [0., 4.])
def test_instantiate_trans_r(self):
    """Check that a TransR model can be built from ``TRANS_R_CONFIG``.

    After construction the entity/relation counts are copied from the config,
    then each hyper-parameter attribute is compared with its expected value.
    """
    model = TransR(**TRANS_R_CONFIG)
    model.num_entities = TRANS_R_CONFIG[NUM_ENTITIES]
    model.num_relations = TRANS_R_CONFIG[NUM_RELATIONS]

    self.assertIsNotNone(model)

    # Checked in this order; the values mirror the test configuration.
    expected_attributes = {
        'num_entities': 5,
        'num_relations': 5,
        'embedding_dim': 5,
        'relation_embedding_dim': 3,
        'scoring_fct_norm': 1,
        'margin_loss': 4,
    }
    for attribute_name, expected_value in expected_attributes.items():
        self.assertEqual(getattr(model, attribute_name), expected_value)