예제 #1
0
    def test_make(self):
        """Test that ``EmbeddingSpecification.make`` derives sizes and forwards components.

        Each case supplies either ``embedding_dim`` or ``shape``. The test expects
        the missing one to be derived: ``embedding_dim`` as the product of
        ``shape``, and ``shape`` as the 1-tuple ``(embedding_dim,)``. The same mock
        initializer/normalizer/constrainer/regularizer objects are passed in every
        case and must be found, by identity, on the created embedding.
        """
        initializer = Mock()
        normalizer = Mock()
        constrainer = Mock()
        regularizer = Mock()
        for embedding_dim, shape in [
            (None, (3, )),
            (None, (3, 5)),
            (3, None),
        ]:
            # subTest labels any failure with the offending parameters and lets
            # the remaining cases still run instead of aborting at the first one.
            with self.subTest(embedding_dim=embedding_dim, shape=shape):
                spec = EmbeddingSpecification(
                    embedding_dim=embedding_dim,
                    shape=shape,
                    initializer=initializer,
                    normalizer=normalizer,
                    constrainer=constrainer,
                    regularizer=regularizer,
                )
                emb = spec.make(num_embeddings=self.num)

                # check shape: whichever of embedding_dim / shape was omitted
                # must be derived from the other one
                self.assertEqual(emb.embedding_dim,
                                 (embedding_dim or int(numpy.prod(shape))))
                self.assertEqual(emb.shape, (shape or (embedding_dim, )))
                self.assertEqual(emb.num_embeddings, self.num)

                # check attributes: components are forwarded unchanged (identity)
                self.assertIs(emb.initializer, initializer)
                self.assertIs(emb.normalizer, normalizer)
                self.assertIs(emb.constrainer, constrainer)
                self.assertIs(emb.regularizer, regularizer)
예제 #2
0
 def test_make_complex(self):
     """Check the shape reported for an embedding built with a complex dtype.

     A ``(5, 5)`` specification with ``torch.cfloat`` is expected to yield an
     embedding whose ``shape`` is ``(5, 10)``.
     """
     # NOTE(review): the doubled last dimension presumably comes from storing
     # real and imaginary parts side by side — confirm against the implementation.
     specification = EmbeddingSpecification(dtype=torch.cfloat, shape=(5, 5))
     embedding = specification.make(num_embeddings=100)
     expected_shape = (5, 10)
     self.assertEqual(expected_shape, embedding.shape)
예제 #3
0
 def __init__(self, triples_factory: TriplesFactory):
     """Initialize the model with 50-dimensional representations and fixed scores.

     :param triples_factory: forwarded unchanged to the base class constructor.
     """
     dim = 50
     super().__init__(
         triples_factory=triples_factory,
         entity_representations=EmbeddingSpecification(embedding_dim=dim),
         relation_representations=EmbeddingSpecification(embedding_dim=dim),
     )
     # One float score per entity: 0.0, 1.0, ..., num_entities - 1.
     self.scores = torch.arange(self.num_entities, dtype=torch.float)
예제 #4
0
 def __init__(self, *, triples_factory: CoreTriplesFactory):
     """Initialize the model with trainable per-entity scores and a backward counter.

     :param triples_factory: forwarded unchanged to the base class constructor.
     """
     dim = 50
     super().__init__(
         triples_factory=triples_factory,
         entity_representations=EmbeddingSpecification(embedding_dim=dim),
         relation_representations=EmbeddingSpecification(embedding_dim=dim),
     )
     # One float score per entity (0.0 .. num_entities - 1); requires_grad=True
     # so autograd tracks gradients for this tensor.
     self.scores = torch.arange(
         self.num_entities,
         dtype=torch.float,
         requires_grad=True,
     )
     # Counter starts at zero. NOTE(review): presumably incremented by another
     # method of this class — not visible here.
     self.num_backward_propagations = 0
예제 #5
0
 def __init__(self, *, triples_factory: TriplesFactory):
     """Initialize the model, then replace its embeddings with plain ``nn.Embedding`` layers.

     :param triples_factory: forwarded unchanged to the base class constructor.
     """
     super().__init__(
         triples_factory=triples_factory,
         entity_representations=EmbeddingSpecification(embedding_dim=50),
         relation_representations=EmbeddingSpecification(embedding_dim=50),
     )
     # Read the dimension once, before either embedding attribute is reassigned,
     # so both replacement layers use the same width.
     # NOTE(review): relies on the base class exposing ``embedding_dim`` — confirm.
     dim = self.embedding_dim
     self.entity_embeddings = nn.Embedding(self.num_entities, dim)
     self.relation_embeddings = nn.Embedding(self.num_relations, dim)
예제 #6
0
 def test_make_errors(self):
     """Test that ``make`` raises ``KeyError`` for an invalid component key.

     The ``initializer``, ``constrainer`` and ``normalizer`` options are each set,
     in turn, to the string ``"garbage"``. Building the embedding is expected to
     raise ``KeyError`` every time.
     """
     for component in ("initializer", "constrainer", "normalizer"):
         # subTest reports which component failed to raise, and the remaining
         # components are still checked.
         with self.subTest(component=component), self.assertRaises(KeyError):
             EmbeddingSpecification(
                 shape=(1, 1),
                 **{component: "garbage"},
             ).make(num_embeddings=1)
예제 #7
0
 def _pre_instantiation_hook(
     self, kwargs: MutableMapping[str, Any]
 ) -> MutableMapping[str, Any]:
     """Add a combined CompGCN representation under the ``"combined"`` key.

     :param kwargs: keyword arguments, first passed through the parent hook.
     :return: the parent hook's result with ``"combined"`` set.
     """
     kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
     # Triples factory sized from this test case's attributes, with inverse
     # triples enabled.
     factory = generate_triples_factory(
         num_entities=self.num_entities,
         num_relations=self.num_relations,
         num_triples=self.num_triples,
         create_inverse_triples=True,
     )
     specification = EmbeddingSpecification(embedding_dim=self.dim)
     kwargs["combined"] = pykeen.nn.emb.CombinedCompGCNRepresentations(
         triples_factory=factory,
         embedding_specification=specification,
         dims=self.dim,
     )
     return kwargs
예제 #8
0
class RGCNRepresentationTests(cases.RepresentationTestCase):
    """Tests for :class:`pykeen.nn.emb.RGCNRepresentations`."""

    cls = pykeen.nn.emb.RGCNRepresentations
    # Used both as the embedding dimension below and as the entity count of
    # the generated triples factory.
    num = 8
    kwargs = {
        "embedding_specification": EmbeddingSpecification(embedding_dim=num),
    }
    # Sizes of the generated triples factory.
    num_relations: int = 7
    num_triples: int = 31
    # NOTE(review): not referenced within this class — confirm the base test
    # case uses it, otherwise it is vestigial.
    num_bases: int = 2

    def _pre_instantiation_hook(
        self, kwargs: MutableMapping[str, Any]
    ) -> MutableMapping[str, Any]:
        """Extend the parent hook's kwargs with a generated ``triples_factory``.

        :param kwargs: keyword arguments, first passed through the parent hook.
        :return: the parent hook's result with ``"triples_factory"`` set.
        """
        updated = super()._pre_instantiation_hook(kwargs=kwargs)
        updated["triples_factory"] = generate_triples_factory(
            num_entities=self.num,
            num_relations=self.num_relations,
            num_triples=self.num_triples,
        )
        return updated