def test_rescal_predict(self):
    """Test RESCAL's predict function.

    Builds a RESCAL model from ``RESCAL_CONFIG``, scores ``TEST_TRIPLES``
    and checks that one prediction is returned per input triple.
    """
    rescal = RESCAL(config=RESCAL_CONFIG)
    predictions = rescal.predict(triples=TEST_TRIPLES)

    self.assertEqual(len(predictions), len(TEST_TRIPLES))
    # ``assertTrue(expr, msg)`` only tests the truthiness of ``expr``; passing
    # ``float`` as the second argument makes it the failure message, so such
    # a check can never fail. Use a real type assertion instead.
    # NOTE(review): assumes ``predictions`` is a NumPy array or torch tensor,
    # whose shape entries are plain ``int`` values -- confirm against
    # ``RESCAL.predict``.
    self.assertIsInstance(predictions.shape[0], int)
def test_rescal_predict(self):
    """Test RESCAL's predict function.

    Builds a RESCAL model from the keyword arguments in ``RESCAL_CONFIG``,
    assigns the entity and relation counts on the instance, scores
    ``TEST_TRIPLES`` and checks that one prediction is returned per input
    triple.
    """
    rescal = RESCAL(**RESCAL_CONFIG)
    # The counts are set on the instance after construction rather than
    # relying on the constructor to store them.
    rescal.num_entities = RESCAL_CONFIG[NUM_ENTITIES]
    rescal.num_relations = RESCAL_CONFIG[NUM_RELATIONS]

    predictions = rescal.predict(triples=TEST_TRIPLES)

    self.assertEqual(len(predictions), len(TEST_TRIPLES))
    # ``assertTrue(expr, msg)`` only tests the truthiness of ``expr``; passing
    # ``float`` as the second argument makes it the failure message, so such
    # a check can never fail. Use a real type assertion instead.
    # NOTE(review): assumes ``predictions`` is a NumPy array or torch tensor,
    # whose shape entries are plain ``int`` values -- confirm against
    # ``RESCAL.predict``.
    self.assertIsInstance(predictions.shape[0], int)
def test_instantiate_rescal(self):
    """Check that a RESCAL model can be built from ``RESCAL_CONFIG``.

    The config is unpacked into keyword arguments; the entity and relation
    counts are then assigned on the instance before its attributes are
    compared with the expected values.
    """
    model = RESCAL(**RESCAL_CONFIG)
    model.num_entities = RESCAL_CONFIG[NUM_ENTITIES]
    model.num_relations = RESCAL_CONFIG[NUM_RELATIONS]

    self.assertIsNotNone(model)

    # Attribute name -> value the test expects, checked in this order.
    expected_attributes = (
        ('num_entities', 5),
        ('num_relations', 5),
        ('embedding_dim', 5),
        ('margin_loss', 4),
    )
    for attribute, expected in expected_attributes:
        self.assertEqual(getattr(model, attribute), expected)
def __init__(self, num_entities, num_relations, relation_embeddings,
             entity_embeddings, preferred_device='cpu'):
    """Wrap a RESCAL model around pre-trained embeddings.

    Parameters
    ----------
    num_entities
        Value stored on the wrapped model as ``num_entities``.
    num_relations
        Value stored on the wrapped model as ``num_relations``.
    relation_embeddings
        Pre-trained relation weights handed to
        ``nn.Embedding.from_pretrained``.
    entity_embeddings
        Pre-trained entity weights handed to
        ``nn.Embedding.from_pretrained``; its second dimension is used as
        the model's ``embedding_dim``.
    preferred_device
        Device name forwarded to RESCAL. ``'cuda'`` is passed on as
        ``'gpu'``; any other value is forwarded unchanged.
    """
    super(RescalUnit, self).__init__()

    # RESCAL is given 'gpu' whenever the caller asks for 'cuda'.
    device = 'gpu' if preferred_device == 'cuda' else preferred_device

    # Seed, margin and scoring function are fixed here; only the device and
    # the embedding dimension depend on the arguments.
    model_kwargs = dict(
        preferred_device=device,
        random_seed=0,
        embedding_dim=entity_embeddings.shape[1],
        margin_loss=1,
        scoring_function=2,
    )
    model = RESCAL(**model_kwargs)
    self.model = model

    model.num_entities = num_entities
    model.num_relations = num_relations

    # Replace the model's embedding tables with the supplied weights.
    model.relation_embeddings = nn.Embedding.from_pretrained(
        relation_embeddings,
    )
    model.entity_embeddings = nn.Embedding.from_pretrained(
        entity_embeddings,
    )
def test_instantiate_rescal(self):
    """Check that a RESCAL model can be built from a config mapping.

    ``RESCAL_CONFIG`` is passed as the single ``config`` argument and the
    resulting instance's attributes are compared with the expected values.
    """
    model = RESCAL(config=RESCAL_CONFIG)

    self.assertIsNotNone(model)

    # Attribute name -> value the test expects, checked in this order.
    expected_attributes = (
        ('num_entities', 5),
        ('num_relations', 5),
        ('embedding_dim', 5),
        ('margin_loss', 4),
    )
    for attribute, expected in expected_attributes:
        self.assertEqual(getattr(model, attribute), expected)
def get_model(args, num_entities, num_relations):
    """Build a RESCAL model from parsed arguments.

    Parameters
    ----------
    args : object
        Object exposing the attributes ``preferred_device``,
        ``random_seed``, ``embedding_dim``, ``margin_loss`` and
        ``scoring_function``, which are forwarded to ``RESCAL`` as keyword
        arguments of the same names.
    num_entities : int
        The total number of distinct entities in the dataset.
    num_relations : int
        The total number of distinct relations in the dataset.

    Returns
    -------
    RESCAL
        The constructed model with ``num_entities`` and ``num_relations``
        assigned on it.
    """
    hyperparameter_names = (
        'preferred_device',
        'random_seed',
        'embedding_dim',
        'margin_loss',
        'scoring_function',
    )
    hyperparameters = {
        name: getattr(args, name) for name in hyperparameter_names
    }

    model = RESCAL(**hyperparameters)
    model.num_entities = num_entities
    model.num_relations = num_relations
    # NOTE(review): ``model._init_embeddings()` is not called here (the call
    # is disabled) -- confirm that embedding initialisation happens elsewhere.
    return model