def test_make(self):
    """Test make.

    For each combination of ``embedding_dim`` / ``shape`` below, build an
    :class:`EmbeddingSpecification` and call ``make``. Then check two things:

    - The resulting embedding's ``embedding_dim`` and ``shape`` are derived
      from whichever of the two was given. The omitted one is computed from
      the other.
    - The initializer, normalizer, constrainer and regularizer objects are
      forwarded by identity.
    """
    initializer = Mock()
    normalizer = Mock()
    constrainer = Mock()
    regularizer = Mock()
    for embedding_dim, shape in [
        (None, (3,)),
        (None, (3, 5)),
        (3, None),
    ]:
        # subTest reports each failing parametrization separately instead of
        # aborting the whole loop at the first failure.
        with self.subTest(embedding_dim=embedding_dim, shape=shape):
            spec = EmbeddingSpecification(
                embedding_dim=embedding_dim,
                shape=shape,
                initializer=initializer,
                normalizer=normalizer,
                constrainer=constrainer,
                regularizer=regularizer,
            )
            emb = spec.make(num_embeddings=self.num)

            # check shape: a missing embedding_dim is the product of shape,
            # a missing shape is the 1-tuple (embedding_dim,)
            self.assertEqual(emb.embedding_dim, (embedding_dim or int(numpy.prod(shape))))
            self.assertEqual(emb.shape, (shape or (embedding_dim,)))
            self.assertEqual(emb.num_embeddings, self.num)

            # check attributes: the exact same objects must be passed through
            self.assertIs(emb.initializer, initializer)
            self.assertIs(emb.normalizer, normalizer)
            self.assertIs(emb.constrainer, constrainer)
            self.assertIs(emb.regularizer, regularizer)
def test_make_complex(self):
    """Check the shape of an embedding built from a complex-valued specification.

    A ``(5, 5)`` specification with ``dtype=torch.cfloat`` must yield an
    embedding whose ``shape`` is ``(5, 10)``, i.e. the last axis doubled.
    Presumably this is because real and imaginary parts are stored side by
    side; verify against ``EmbeddingSpecification.make``.
    """
    complex_spec = EmbeddingSpecification(shape=(5, 5), dtype=torch.cfloat)
    embedding = complex_spec.make(num_embeddings=100)
    expected_shape = (5, 10)
    self.assertEqual(expected_shape, embedding.shape)
def __init__(self, triples_factory: TriplesFactory):
    """Create the model with 50-dimensional entity and relation representations.

    :param triples_factory: forwarded unchanged to the parent constructor.

    After construction, ``self.scores`` holds one float per entity:
    ``0.0, 1.0, ..., num_entities - 1``.
    """
    dim = 50
    super().__init__(
        triples_factory=triples_factory,
        entity_representations=EmbeddingSpecification(embedding_dim=dim),
        relation_representations=EmbeddingSpecification(embedding_dim=dim),
    )
    self.scores = torch.arange(self.num_entities, dtype=torch.float)
def __init__(self, *, triples_factory: CoreTriplesFactory):
    """Create the model with 50-dimensional entity and relation representations.

    :param triples_factory: keyword-only; forwarded unchanged to the parent
        constructor.

    Sets up:

    - ``self.scores``: one float per entity (``0.0 .. num_entities - 1``)
      created with ``requires_grad=True``.
    - ``self.num_backward_propagations``: a counter starting at zero. It is
      presumably incremented elsewhere when gradients flow back; verify
      against the rest of the class.
    """
    dim = 50
    super().__init__(
        triples_factory=triples_factory,
        entity_representations=EmbeddingSpecification(embedding_dim=dim),
        relation_representations=EmbeddingSpecification(embedding_dim=dim),
    )
    self.scores = torch.arange(
        self.num_entities,
        dtype=torch.float,
        requires_grad=True,
    )
    self.num_backward_propagations = 0
def __init__(self, *, triples_factory: TriplesFactory):
    """Create the model and attach plain ``nn.Embedding`` tables.

    :param triples_factory: keyword-only; forwarded unchanged to the parent
        constructor, which is configured with 50-dimensional entity and
        relation representations.

    Adds one embedding table for entities and one for relations, both sized
    by ``self.embedding_dim``. That attribute is presumably provided by the
    base class; confirm.
    """
    dim = 50
    super().__init__(
        triples_factory=triples_factory,
        entity_representations=EmbeddingSpecification(embedding_dim=dim),
        relation_representations=EmbeddingSpecification(embedding_dim=dim),
    )
    self.entity_embeddings = nn.Embedding(
        num_embeddings=self.num_entities,
        embedding_dim=self.embedding_dim,
    )
    self.relation_embeddings = nn.Embedding(
        num_embeddings=self.num_relations,
        embedding_dim=self.embedding_dim,
    )
def test_make_errors(self):
    """Test errors on making with an invalid key.

    For each of ``initializer``, ``constrainer`` and ``normalizer``, pass the
    unknown string key ``"garbage"``. Then assert that constructing the
    specification and calling ``make`` raises :class:`KeyError`.
    """
    for parameter in ("initializer", "constrainer", "normalizer"):
        # subTest makes a failure name the offending parameter and still
        # checks the remaining ones.
        with self.subTest(parameter=parameter), self.assertRaises(KeyError):
            EmbeddingSpecification(
                shape=(1, 1),
                **{parameter: "garbage"},
            ).make(num_embeddings=1)
def _pre_instantiation_hook(
    self,
    kwargs: MutableMapping[str, Any],
) -> MutableMapping[str, Any]:
    """Add a ``combined`` CompGCN representation to the instantiation kwargs.

    The parent hook runs first. A generated triples factory (with inverse
    triples) sized by ``num_entities``, ``num_relations`` and ``num_triples``
    is then wrapped in ``CombinedCompGCNRepresentations`` using ``self.dim``.
    """
    kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
    factory = generate_triples_factory(
        num_entities=self.num_entities,
        num_relations=self.num_relations,
        num_triples=self.num_triples,
        create_inverse_triples=True,
    )
    spec = EmbeddingSpecification(embedding_dim=self.dim)
    kwargs["combined"] = pykeen.nn.emb.CombinedCompGCNRepresentations(
        triples_factory=factory,
        embedding_specification=spec,
        dims=self.dim,
    )
    return kwargs
class RGCNRepresentationTests(cases.RepresentationTestCase):
    """Test RGCN representations."""

    cls = pykeen.nn.emb.RGCNRepresentations
    # Number of entities in the generated graph (used as num_entities below).
    # It is also reused as the embedding dimension.
    num = 8
    kwargs = {
        "embedding_specification": EmbeddingSpecification(embedding_dim=num),
    }
    num_relations: int = 7
    num_triples: int = 31
    # NOTE(review): not referenced in this class body; presumably consumed by
    # the base test case or the representation. Verify.
    num_bases: int = 2

    def _pre_instantiation_hook(
        self,
        kwargs: MutableMapping[str, Any],
    ) -> MutableMapping[str, Any]:
        """Add a generated ``triples_factory`` to the instantiation kwargs.

        The factory is sized by ``num``, ``num_relations`` and
        ``num_triples``.
        """
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        factory = generate_triples_factory(
            num_entities=self.num,
            num_relations=self.num_relations,
            num_triples=self.num_triples,
        )
        kwargs["triples_factory"] = factory
        return kwargs