Example no. 1
0
    def test_project_entity(self):
        """Test _project_entity."""
        batch_size = 2
        embedding_dim = 3
        relation_dim = 5
        num_entities = 7

        # Random entity embeddings, clamped onto the unit ball, plus the
        # per-entity projection vectors.
        ent = torch.rand(1, num_entities, embedding_dim)
        ent = clamp_norm(ent, maxnorm=1, p=2, dim=-1)
        ent_proj = torch.rand(1, num_entities, embedding_dim)

        # Random per-relation projection vectors.
        rel_proj = torch.rand(batch_size, 1, relation_dim)

        # Apply the projection under test.
        projected = project_entity(e=ent, e_p=ent_proj, r_p=rel_proj)

        # The result must broadcast over batch and entities and live in
        # relation space.
        assert projected.shape == (batch_size, num_entities, relation_dim)

        # Every projected vector must respect the unit-norm constraint
        # (allowing a small numerical tolerance).
        assert (projected.norm(p=2, dim=-1) <= 1.0 + 1.0e-06).all()

        # Cross-check the factorized computation against the explicit
        # matrix formulation:
        # e_{\bot} = M_{re} e = (r_p e_p^T + I^{d_r \times d_e}) e
        #                     = r_p (e_p^T e) + e'
        proj_matrix = rel_proj.unsqueeze(dim=-1) @ ent_proj.unsqueeze(dim=-2)
        identity = torch.eye(relation_dim, embedding_dim)
        proj_matrix = proj_matrix + identity.view(1, 1, relation_dim, embedding_dim)
        assert proj_matrix.shape == (
            batch_size, num_entities, relation_dim, embedding_dim,
        )
        expected = (proj_matrix @ ent.unsqueeze(dim=-1)).squeeze(dim=-1)
        expected = clamp_norm(expected, p=2, dim=-1, maxnorm=1)
        assert torch.allclose(expected, projected)
Example no. 2
0
    def test_project_entity(self):
        """Test _project_entity."""
        gen = self.generator
        num_entities = self.instance.num_entities
        dim_e = self.embedding_dim
        dim_r = self.instance.relation_dim

        # Random entity embeddings (clamped onto the unit ball) and the
        # per-entity projection vectors.
        ent = torch.rand(1, num_entities, dim_e, generator=gen)
        ent = clamp_norm(ent, maxnorm=1, p=2, dim=-1)
        ent_p = torch.rand(1, num_entities, dim_e, generator=gen)

        # Random relation embeddings (clamped onto the unit ball) and the
        # per-relation projection vectors.
        rel = torch.rand(self.batch_size, 1, dim_r, generator=gen)
        rel = clamp_norm(rel, maxnorm=1, p=2, dim=-1)
        rel_p = torch.rand(self.batch_size, 1, dim_r, generator=gen)

        # Apply the projection under test.
        projected = _project_entity(e=ent, e_p=ent_p, r=rel, r_p=rel_p)

        # The result broadcasts over batch and entities, in relation space.
        assert projected.shape == (self.batch_size, num_entities, dim_r)

        # Every projected vector respects the unit-norm constraint
        # (allowing a small numerical tolerance).
        assert (projected.norm(p=2, dim=-1) <= 1.0 + 1.0e-06).all()
Example no. 3
0
 def _exp_score(self, h, r, m_r, t, p, power_norm) -> torch.FloatTensor:
     """Compute the expected score: project h and t by m_r, translate by r."""
     # Only the power-norm variant is supported by this reference score.
     assert power_norm
     # Project head and tail through the relation matrix, clamping each
     # projected vector back onto the unit ball.
     h_bot = clamp_norm(h.unsqueeze(dim=0) @ m_r, p=2, dim=-1, maxnorm=1.0)
     t_bot = clamp_norm(t.unsqueeze(dim=0) @ m_r, p=2, dim=-1, maxnorm=1.0)
     # Negative summed p-th power of the translation residual.
     # NOTE(review): no abs() before **p — presumably p is even here; confirm
     # against the model implementation this mirrors.
     return -((h_bot + r - t_bot) ** p).sum()
Example no. 4
0
def test_clamp_norm():
    """Test ``clamp_norm``."""
    max_norm = 1.0
    eps = 1.0e-06
    rng = torch.manual_seed(42)
    for p in [1, 2, float('inf')]:
        for _ in range(10):
            x = torch.rand(10, 20, 30, generator=rng)
            for dim in range(x.ndimension()):
                clamped = clamp_norm(x, maxnorm=max_norm, p=p, dim=dim)

                # The clamped tensor may not exceed the maximum norm
                # along the chosen dimension (up to tolerance).
                assert (clamped.norm(p=p, dim=dim) <= max_norm + eps).all()

                # Slices whose norm was already below the maximum must
                # pass through unchanged.
                norms = x.norm(p=p, dim=dim)
                small = torch.stack([norms < max_norm] * x.shape[dim], dim=dim)
                assert (clamped[small] == x[small]).all()