# Example 1
 def get_queries(self, queries):
     """Compute embedding and biases of queries.

     The head entity is transformed by a relation-specific Givens
     reflection and then translated by the relation embedding.
     """
     head = self.entity(queries[:, 0])
     reflected = givens_reflection(self.rel_diag(queries[:, 1]), head)
     translation = self.rel(queries[:, 1])
     head_bias = self.bh(queries[:, 0])
     return reflected + translation, head_bias
# Example 2
 def get_queries(self, queries):
     """Compute embedding and biases of queries.

     Reflects the head entity, maps both head and relation onto the
     hyperbolic manifold (curvature is relation-specific), and combines
     them with a Mobius addition.
     """
     head_idx = queries[:, 0]
     rel_idx = queries[:, 1]
     # Per-relation curvature, kept positive via softplus.
     curvature = F.softplus(self.c[rel_idx])
     translation, _ = torch.chunk(self.rel(rel_idx), 2, dim=1)
     translation = expmap0(translation, curvature)
     reflected = givens_reflection(self.rel_diag(rel_idx),
                                   self.entity(head_idx))
     reflected = expmap0(reflected, curvature)
     query_emb = project(mobius_add(reflected, translation, curvature),
                         curvature)
     return (query_emb, curvature), self.bh(head_idx)
# Example 3
 def get_queries(self, queries):
     """Compute embedding and biases of queries.

     Builds two candidate transforms of the head entity (a Givens
     rotation and a Givens reflection), blends them with a learned
     attention over a relation-specific context vector, then moves the
     result to the hyperbolic manifold and translates it there.
     """
     head_idx = queries[:, 0]
     rel_idx = queries[:, 1]
     curvature = F.softplus(self.c[rel_idx])
     head = self.entity(head_idx)
     # Relation parameters hold rotation and reflection halves side by side.
     rot_params, ref_params = torch.chunk(self.rel_diag(rel_idx), 2, dim=1)
     rotated = givens_rotations(rot_params, head).view((-1, 1, self.rank))
     reflected = givens_reflection(ref_params, head).view((-1, 1, self.rank))
     candidates = torch.cat([reflected, rotated], dim=1)
     context = self.context_vec(rel_idx).view((-1, 1, self.rank))
     # Scaled dot-product attention over the two candidates.
     scores = torch.sum(context * candidates * self.scale,
                        dim=-1,
                        keepdim=True)
     weights = self.act(scores)
     attended = torch.sum(weights * candidates, dim=1)
     lhs = expmap0(attended, curvature)
     translation, _ = torch.chunk(self.rel(rel_idx), 2, dim=1)
     translation = expmap0(translation, curvature)
     query_emb = project(mobius_add(lhs, translation, curvature), curvature)
     return (query_emb, curvature), self.bh(head_idx)
# Example 4
 def get_reflection_queries(self, queries):
     """Return head-entity embeddings reflected by their relation's
     Givens-reflection parameters."""
     return givens_reflection(self.ref(queries[:, 1]),
                              self.entity(queries[:, 0]))