def interaction_function(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    m_r: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the TransR-style interaction function for given embeddings.

    Computes ``-||clamp(h @ m_r) + r - clamp(t @ m_r)||_2 ** 2``, where
    ``clamp`` limits the L2 norm of every projected vector to at most 1.

    The projection uses :func:`torch.matmul` semantics. ``h`` and ``t`` are
    treated as (batches of) matrices whose last axis (``d_e``) is contracted
    with the second-to-last axis of ``m_r``. All axes before the final two
    broadcast against each other like batch axes.

    Consequently ``m_r`` must NOT carry its own per-entity axis aligned with
    the entity axis of ``h``/``t``. For example, ``h: (b, 1, d_e)``,
    ``t: (1, e, d_e)``, ``r: (b, 1, d_r)`` and ``m_r: (b, d_e, d_r)`` give
    projections of shape ``(b, 1, d_r)`` / ``(b, e, d_r)`` and scores of
    shape ``(b, e)``. Passing ``h: (b, e, d_e)`` together with
    ``m_r: (b, e, d_e, d_r)`` misaligns the batch axes: ``(b,)`` broadcasts
    against ``(b, e)``, so no per-entity projection is computed.

    :param h: shape: (..., n_h, d_e)
        Head embeddings.
    :param r: shape: (..., n_r, d_r)
        Relation embeddings; must broadcast against the projected h and t.
    :param t: shape: (..., n_t, d_e)
        Tail embeddings.
    :param m_r: shape: (..., d_e, d_r)
        The relation specific linear transformations.
    :return: shape: the broadcast shape of the projected tensors without the
        last axis, e.g. (batch_size, num_entities)
        The scores.
    """
    # project to relation specific subspace (batched matmul, see docstring)
    h_bot = h @ m_r
    t_bot = t @ m_r
    # ensure constraints: projected vectors have L2 norm <= 1
    h_bot = clamp_norm(h_bot, p=2, dim=-1, maxnorm=1.0)
    t_bot = clamp_norm(t_bot, p=2, dim=-1, maxnorm=1.0)
    # evaluate score function: negative squared L2 distance over the last axis
    return -linalg.vector_norm(h_bot + r - t_bot, dim=-1) ** 2
def score_h(
    self,
    rt_batch: torch.LongTensor,
    slice_size: Optional[int] = None,
) -> torch.FloatTensor:  # noqa: D102
    """Score every entity as head for each (relation, tail) pair.

    The score is ``-||h + r - t||_p`` with ``p = self.scoring_fct_norm``.
    It is evaluated for all candidate heads at once by broadcasting.

    :param rt_batch:
        Column 0 holds relation indices, column 1 tail entity indices.
    :param slice_size:
        Accepted for interface compatibility; not used here.
    :return: shape: (batch_size, num_entities), assuming 2-D embedding tables
        The scores.
    """
    all_heads = self.entity_embeddings(indices=None)
    relations = self.relation_embeddings(indices=rt_batch[:, 0])
    tails = self.entity_embeddings(indices=rt_batch[:, 1])
    # TODO: Use torch.cdist (see the note in score_hrt())
    offset = (relations - tails).unsqueeze(1)  # one (r - t) row per batch entry
    return -linalg.vector_norm(
        all_heads.unsqueeze(0) + offset,
        dim=-1,
        ord=self.scoring_fct_norm,
    )
def score_t(self, hr_batch: torch.LongTensor, **kwargs) -> torch.FloatTensor:  # noqa: D102
    """Score every entity as tail for each (head, relation) pair.

    The score is ``-||h + r - t||_p`` with ``p = self.scoring_fct_norm``.
    It is evaluated for all candidate tails at once by broadcasting.

    :param hr_batch:
        Column 0 holds head entity indices, column 1 relation indices.
    :param kwargs:
        Accepted for interface compatibility; ignored.
    :return: shape: (batch_size, num_entities), assuming 2-D embedding tables
        The scores.
    """
    heads = self.entity_embeddings(indices=hr_batch[:, 0])
    relations = self.relation_embeddings(indices=hr_batch[:, 1])
    all_tails = self.entity_embeddings(indices=None)
    # TODO: Use torch.cdist (see the note in score_hrt())
    translated = (heads + relations).unsqueeze(1)  # (b, 1, d)
    return -linalg.vector_norm(
        translated - all_tails.unsqueeze(0),
        dim=-1,
        ord=self.scoring_fct_norm,
    )
def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
    """Score each (head, relation, tail) triple as ``-||h + r - t||_p``.

    :param hrt_batch:
        Columns 0, 1 and 2 hold head, relation and tail indices.
    :return: shape: (batch_size, 1), assuming 2-D embedding tables
        The scores; ``p`` is ``self.scoring_fct_norm``.
    """
    heads, relations, tails = (
        self.entity_embeddings(indices=hrt_batch[:, 0]),
        self.relation_embeddings(indices=hrt_batch[:, 1]),
        self.entity_embeddings(indices=hrt_batch[:, 2]),
    )
    # TODO: Use torch.cdist
    # There were some performance/memory issues with cdist, cf.
    # https://github.com/pytorch/pytorch/issues?q=cdist however, @mberr thinks
    # they are mostly resolved by now. A Benefit would be that we can harness the
    # future (performance) improvements made by the core torch developers. However,
    # this will require some benchmarking.
    residual = heads + relations - tails
    return -linalg.vector_norm(residual, dim=-1, ord=self.scoring_fct_norm, keepdim=True)
def score_hrt(self, hrt_batch: torch.LongTensor, **kwargs) -> torch.FloatTensor:  # noqa: D102
    """Score each triple by translating on a relation-specific hyperplane.

    Head and tail are projected onto the hyperplane with normal ``w_r`` via
    ``x - (w_r . x) * w_r``. The triple is scored as
    ``-||h_perp + d_r - t_perp||_2``.
    NOTE(review): that projection formula presumes ``w_r`` has unit L2 norm;
    it is not normalised here.

    Side effect: calls ``self.regularize_if_necessary()``.

    :param hrt_batch:
        Columns 0, 1 and 2 hold head, relation and tail indices.
    :param kwargs:
        Accepted for interface compatibility; ignored.
    :return: shape: (batch_size, 1), assuming 2-D embedding tables
        The scores.
    """
    relation_ids = hrt_batch[:, 1]
    head = self.entity_embeddings(indices=hrt_batch[:, 0])
    translation = self.relation_embeddings(indices=relation_ids)
    normal = self.normal_vector_embeddings(indices=relation_ids)
    tail = self.entity_embeddings(indices=hrt_batch[:, 2])

    def _project(x: torch.FloatTensor) -> torch.FloatTensor:
        # Remove the component of x along the hyperplane normal.
        return x - (normal * x).sum(dim=-1, keepdim=True) * normal

    head_perp = _project(head)
    tail_perp = _project(tail)

    self.regularize_if_necessary()
    return -linalg.vector_norm(head_perp + translation - tail_perp, ord=2, dim=-1, keepdim=True)
def update(self, *tensors: torch.FloatTensor) -> None:  # noqa: D102
    """Accumulate the TransH soft-constraint penalties into ``self.regularization_term``.

    Two penalties are added:

    * entity constraint: ``sum(relu(||e||_2 ** 2 - 1))``;
    * orthogonality constraint:
      ``sum(relu((w_r . d_r / ||d_r||_2) ** 2 - self.epsilon))``.
      It penalises translation vectors ``d_r`` that are not orthogonal to the
      hyperplane normal ``w_r``.

    If ``self.apply_only_once`` is set and ``self.updated`` is already true,
    nothing is added.

    :param tensors:
        Exactly three tensors: entity embeddings, normal vector embeddings,
        relation (translation) embeddings, in that order.
    :raises KeyError:
        If not exactly three tensors are given.
    """
    if len(tensors) != 3:
        # KeyError (rather than ValueError) kept for backward compatibility.
        raise KeyError("Expects exactly three tensors")
    if self.apply_only_once and self.updated:
        return
    entity_embeddings, normal_vector_embeddings, relation_embeddings = tensors

    # Entity soft constraint: penalise squared norms exceeding 1.
    self.regularization_term += torch.sum(
        functional.relu(linalg.vector_norm(entity_embeddings, dim=-1) ** 2 - 1.0),
    )

    # Orthogonality soft constraint. Square the *dot product* of w_r and the
    # unit-normalised d_r; squaring elementwise products before summing would
    # also penalise perfectly orthogonal pairs.
    d_r_n = functional.normalize(relation_embeddings, dim=-1)
    dot_sq = torch.sum(normal_vector_embeddings * d_r_n, dim=-1) ** 2
    self.regularization_term += torch.sum(functional.relu(dot_sq - self.epsilon))

    self.updated = True
def score_h(self, rt_batch: torch.LongTensor, **kwargs) -> torch.FloatTensor:  # noqa: D102
    """Score every entity as head for each (relation, tail) pair (TransH).

    All heads and the given tail are projected onto the relation's hyperplane
    (normal ``w_r``, via ``x - (w_r . x) * w_r``). The result is scored as
    ``-||h_perp + d_r - t_perp||_2``.

    Side effect: calls ``self.regularize_if_necessary()``.

    :param rt_batch:
        Column 0 holds relation indices, column 1 tail entity indices.
    :param kwargs:
        Accepted for interface compatibility; ignored.
    :return: shape: (batch_size, num_entities), assuming 2-D embedding tables
        The scores.
    """
    all_heads = self.entity_embeddings(indices=None)
    relation_ids = rt_batch[:, 0]
    translation = self.relation_embeddings(indices=relation_ids)
    normal = self.normal_vector_embeddings(indices=relation_ids)
    tail = self.entity_embeddings(indices=rt_batch[:, 1])

    # Broadcast every head against every batch relation: (b, num_entities, d).
    heads_b = all_heads.unsqueeze(0)
    normal_b = normal.unsqueeze(1)
    head_perp = heads_b - (normal_b * heads_b).sum(dim=-1, keepdim=True) * normal_b
    tail_perp = tail - (normal * tail).sum(dim=-1, keepdim=True) * normal

    self.regularize_if_necessary()
    offset = (translation - tail_perp).unsqueeze(1)
    return -linalg.vector_norm(head_perp + offset, ord=2, dim=-1)
def score_t(self, hr_batch: torch.LongTensor, slice_size: Optional[int] = None) -> torch.FloatTensor:  # noqa: D102
    """Score every entity as tail for each (head, relation) pair (TransH).

    The head and all tails are projected onto the relation's hyperplane
    (normal ``w_r``, via ``x - (w_r . x) * w_r``). The result is scored as
    ``-||h_perp + d_r - t_perp||_2``.

    Side effect: calls ``self.regularize_if_necessary()``.

    :param hr_batch:
        Column 0 holds head entity indices, column 1 relation indices.
    :param slice_size:
        Accepted for interface compatibility; not used here.
    :return: shape: (batch_size, num_entities), assuming 2-D embedding tables
        The scores.
    """
    relation_ids = hr_batch[:, 1]
    head = self.entity_embeddings(indices=hr_batch[:, 0])
    translation = self.relation_embeddings(indices=relation_ids)
    normal = self.normal_vector_embeddings(indices=relation_ids)
    all_tails = self.entity_embeddings(indices=None)

    head_perp = head - (normal * head).sum(dim=-1, keepdim=True) * normal
    # Broadcast every tail against every batch relation: (b, num_entities, d).
    tails_b = all_tails.unsqueeze(0)
    normal_b = normal.unsqueeze(1)
    tail_perp = tails_b - (normal_b * tails_b).sum(dim=-1, keepdim=True) * normal_b

    self.regularize_if_necessary()
    anchor = (head_perp + translation).unsqueeze(1)
    return -linalg.vector_norm(anchor - tail_perp, ord=2, dim=-1)
def interaction_function(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the interaction function of RotatE for given embeddings.

    The score is ``-||h ∘ r - t||``, where ``∘`` is the elementwise complex
    (Hadamard) product. The norm is taken jointly over the embedding axis and
    the (real, imag) axis, i.e. it is the L2 norm of the complex difference
    vector.

    The embeddings have to be in a broadcastable shape.

    WARNING: No forward constraints are applied (e.g. ``r`` is not forced to
    unit modulus here).

    :param h: shape: (..., e, 2)
        Head embeddings. Last dimension corresponds to (real, imag).
    :param r: shape: (..., e, 2)
        Relation embeddings. Last dimension corresponds to (real, imag).
    :param t: shape: (..., e, 2)
        Tail embeddings. Last dimension corresponds to (real, imag).

    :return: shape: (...)
        The scores.
    """
    # Decompose into real and imaginary part
    h_re = h[..., 0]
    h_im = h[..., 1]
    r_re = r[..., 0]
    r_im = r[..., 1]

    # Rotate (=Hadamard product in complex space).
    rot_h = torch.stack(
        [
            h_re * r_re - h_im * r_im,
            h_re * r_im + h_im * r_re,
        ],
        dim=-1,
    )
    # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed
    # NOTE(review): presumably this refers to missing complex-tensor support;
    # the (real, imag) pair is kept as a trailing axis of size 2 instead.
    diff = rot_h - t
    scores = -linalg.vector_norm(diff, dim=(-2, -1))

    return scores