Example #1
0
    def test_rescal_predict(self):
        """Test RESCAL's predict function.

        Checks that ``predict`` yields one score per input triple and that the
        scores have a floating-point dtype.
        """
        import numpy as np  # only needed for the dtype check below

        rescal = RESCAL(config=RESCAL_CONFIG)
        predictions = rescal.predict(triples=TEST_TRIPLES)

        self.assertEqual(len(predictions), len(TEST_TRIPLES))
        # ``assertTrue(x, float)`` treats ``float`` as the failure message and
        # can never fail; compare the dtype explicitly instead.
        # NOTE(review): assumes predictions is array-like (e.g. a numpy array
        # of scores) — confirm against RESCAL.predict.
        self.assertTrue(
            np.issubdtype(np.asarray(predictions).dtype, np.floating),
            'predictions should be floating-point scores',
        )
Example #2
0
File: test_models.py  Project: mberr/PyKEEN
    def test_rescal_predict(self):
        """Test RESCAL's predict function.

        Checks that ``predict`` yields one score per input triple and that the
        scores have a floating-point dtype.
        """
        import numpy as np  # only needed for the dtype check below

        rescal = RESCAL(**RESCAL_CONFIG)
        # Entity/relation counts are assigned after construction, not passed in.
        rescal.num_entities = RESCAL_CONFIG[NUM_ENTITIES]
        rescal.num_relations = RESCAL_CONFIG[NUM_RELATIONS]
        predictions = rescal.predict(triples=TEST_TRIPLES)

        self.assertEqual(len(predictions), len(TEST_TRIPLES))
        # ``assertTrue(x, float)`` treats ``float`` as the failure message and
        # can never fail; compare the dtype explicitly instead.
        # NOTE(review): assumes predictions is array-like (e.g. a numpy array
        # of scores) — confirm against RESCAL.predict.
        self.assertTrue(
            np.issubdtype(np.asarray(predictions).dtype, np.floating),
            'predictions should be floating-point scores',
        )
Example #3
0
File: test_models.py  Project: mberr/PyKEEN
 def test_instantiate_rescal(self):
     """Check that RESCAL can be built from its config and exposes the expected sizes."""
     model = RESCAL(**RESCAL_CONFIG)
     # Entity/relation counts are assigned after construction, not passed in.
     model.num_entities = RESCAL_CONFIG[NUM_ENTITIES]
     model.num_relations = RESCAL_CONFIG[NUM_RELATIONS]
     self.assertIsNotNone(model)

     # (attribute name, expected value), checked in this order.
     expected_attributes = (
         ('num_entities', 5),
         ('num_relations', 5),
         ('embedding_dim', 5),
         ('margin_loss', 4),
     )
     for attribute, expected in expected_attributes:
         self.assertEqual(getattr(model, attribute), expected)
Example #4
0
    def __init__(self,
                 num_entities,
                 num_relations,
                 relation_embeddings,
                 entity_embeddings,
                 preferred_device='cpu'):
        """Wrap a RESCAL model around pre-computed embeddings.

        :param num_entities: value assigned to the wrapped model's
            ``num_entities``.
        :param num_relations: value assigned to the wrapped model's
            ``num_relations``.
        :param relation_embeddings: weights passed to
            ``nn.Embedding.from_pretrained`` for the relation embeddings.
        :param entity_embeddings: weights passed to
            ``nn.Embedding.from_pretrained`` for the entity embeddings; its
            ``shape[1]`` is used as ``embedding_dim``.
        :param preferred_device: device name forwarded to RESCAL; the value
            ``'cuda'`` is rewritten to ``'gpu'`` first.
        """
        super(RescalUnit, self).__init__()

        # NOTE(review): 'cuda' is rewritten to 'gpu', presumably the spelling
        # RESCAL accepts for an accelerator — confirm.
        device = 'gpu' if preferred_device == 'cuda' else preferred_device

        model = RESCAL(
            preferred_device=device,
            random_seed=0,
            embedding_dim=entity_embeddings.shape[1],
            margin_loss=1,
            scoring_function=2,
        )
        self.model = model

        model.num_entities = num_entities
        model.num_relations = num_relations
        # from_pretrained is called with its defaults, so (per the PyTorch
        # docs, freeze=True) these embedding weights are not trained.
        model.relation_embeddings = nn.Embedding.from_pretrained(relation_embeddings)
        model.entity_embeddings = nn.Embedding.from_pretrained(entity_embeddings)
Example #5
0
 def test_instantiate_rescal(self):
     """Check that RESCAL can be built from a config dict and exposes the expected sizes."""
     model = RESCAL(config=RESCAL_CONFIG)
     self.assertIsNotNone(model)

     # (attribute name, expected value), checked in this order.
     expected_attributes = (
         ('num_entities', 5),
         ('num_relations', 5),
         ('embedding_dim', 5),
         ('margin_loss', 4),
     )
     for attribute, expected in expected_attributes:
         self.assertEqual(getattr(model, attribute), expected)
Example #6
0
def get_model(args, num_entities, num_relations):
    """Initialize a RESCAL model.

    Parameters
    ----------
    args : object
        Hyper-parameter container (e.g. an argparse namespace). Its
        ``preferred_device``, ``random_seed``, ``embedding_dim``,
        ``margin_loss`` and ``scoring_function`` attributes are forwarded
        to ``RESCAL``.
    num_entities : int
        The total number of distinct entities in the dataset.
    num_relations : int
        The total number of distinct relations in the dataset.

    Returns
    -------
    RESCAL
        The constructed model with ``num_entities`` and ``num_relations``
        assigned.
    """
    model = RESCAL(
        preferred_device=args.preferred_device,
        random_seed=args.random_seed,
        embedding_dim=args.embedding_dim,
        margin_loss=args.margin_loss,
        scoring_function=args.scoring_function,
    )

    # Entity/relation counts are assigned after construction, not passed in.
    model.num_entities = num_entities
    model.num_relations = num_relations

    # NOTE(review): embeddings are not initialised here (an
    # ``_init_embeddings()`` call was previously left commented out);
    # confirm where the caller creates them.
    return model