def test_project_entity(self):
    """Check ``project_entity`` against an explicit projection-matrix formulation.

    Three properties are verified for the projected entities:

    * they have shape ``(batch_size, num_entities, relation_dim)``,
    * their L2 norm is at most 1 (up to a small float slack),
    * they equal ``clamp_norm(M_re @ e)`` with ``M_re = r_p e_p^T + I^{d_r x d_e}``,
      i.e. ``r_p (e_p^T e) + e'`` where ``e'`` is ``e`` padded/truncated to ``d_r``.
    """
    n_batch, n_entities = 2, 7
    d_e, d_r = 3, 5

    # Entity embeddings, pre-clamped into the unit L2 ball, plus entity projection vectors.
    ent = clamp_norm(torch.rand(1, n_entities, d_e), maxnorm=1, p=2, dim=-1)
    ent_proj = torch.rand(1, n_entities, d_e)
    # Relation projection vectors (one per batch element).
    rel_proj = torch.rand(n_batch, 1, d_r)

    projected = project_entity(e=ent, e_p=ent_proj, r_p=rel_proj)

    # (1, E, .) combined with (B, 1, .) must broadcast to one vector per (batch, entity) pair.
    assert projected.shape == (n_batch, n_entities, d_r)

    # Every projected vector must lie in the closed unit L2 ball (1e-6 float slack).
    assert (projected.norm(p=2, dim=-1) <= 1.0 + 1.0e-06).all()

    # Reference: materialize M_re = r_p e_p^T + I^{d_r x d_e} and apply it to e directly.
    outer = torch.matmul(rel_proj.unsqueeze(dim=-1), ent_proj.unsqueeze(dim=-2))
    identity = torch.eye(d_r, d_e).view(1, 1, d_r, d_e)
    m_re = outer + identity
    assert m_re.shape == (n_batch, n_entities, d_r, d_e)

    reference = torch.matmul(m_re, ent.unsqueeze(dim=-1)).squeeze(dim=-1)
    reference = clamp_norm(reference, p=2, dim=-1, maxnorm=1)
    assert torch.allclose(reference, projected)
def _exp_score(self, h, r, t, h_p, r_p, t_p, p, power_norm) -> torch.FloatTensor:
    """Compute the expected score by hand, as a reference for the interaction under test.

    Head and tail are projected with ``project_entity`` (using their own projection
    vectors ``h_p`` / ``t_p`` and the shared relation projection ``r_p``). The score is
    the negated *powered* p-norm of the translation residual, ``-sum_i |x_i|^p``.

    :param h: head embedding.
    :param r: relation (translation) embedding.
    :param t: tail embedding.
    :param h_p: head projection vector.
    :param r_p: relation projection vector.
    :param t_p: tail projection vector.
    :param p: order of the norm.
    :param power_norm: must be truthy; only the powered form (no p-th root) is implemented.
    :return: a scalar tensor. ``.sum()`` reduces over *all* elements, so this presumably
        expects inputs for a single triple -- TODO confirm against the calling test case.
    """
    # Only the powered variant sum_i |x_i|^p is implemented here.
    assert power_norm
    h_bot = project_entity(e=h, e_p=h_p, r_p=r_p)
    t_bot = project_entity(e=t, e_p=t_p, r_p=r_p)
    residual = h_bot + r - t_bot
    # abs() is required for a p-norm: without it an odd p (e.g. p=1) lets negative
    # components cancel positive ones. For even p the result is unchanged.
    return -(residual.abs() ** p).sum()