def _get_triple_score(
    self, head: BoxTensor, tail: BoxTensor, relation: BoxTensor
) -> torch.Tensor:
    """Score a (head, relation, tail) triple as a log conditional volume.

    The score is
    ``logvol(tail ∩ relation ∩ head) - logvol(tail ∩ relation)``,
    where intersections come from ``gumbel_intersection`` (with
    ``self.gumbel_beta``) and log-volumes from ``_log_soft_volume_adjusted``
    (with ``temp=self.softbox_temp`` and ``self.gumbel_beta``).

    Args:
        head: Head-entity box.
        tail: Tail-entity box.
        relation: Relation box.

    Returns:
        The log-volume difference described above.

    Note:
        In eval mode, when ``head`` and ``tail`` differ in rank, the
        lower-rank side and ``relation`` are tiled by reassigning their
        ``.data`` attribute, so the caller's box objects are modified.
    """
    if self.is_eval():
        head_rank = len(head.data.shape)
        tail_rank = len(tail.data.shape)
        # torch.cat(k * [x]) stacks k copies of x along dim 0.
        # NOTE(review): presumably one side holds a batch of candidate
        # entities during evaluation and the other side is repeated
        # shape[-3] times to match -- confirm shapes against the caller.
        if head_rank > tail_rank:
            n_copies = head.data.shape[-3]
            tail.data = torch.cat(n_copies * [tail.data])
            relation.data = torch.cat(n_copies * [relation.data])
        elif head_rank < tail_rank:
            n_copies = tail.data.shape[-3]
            head.data = torch.cat(n_copies * [head.data])
            relation.data = torch.cat(n_copies * [relation.data])

    # Only tail ∩ relation and (tail ∩ relation) ∩ head feed the score.
    tail_relation_box = relation.gumbel_intersection(
        tail, gumbel_beta=self.gumbel_beta)
    tail_head_relation_box = tail_relation_box.gumbel_intersection(
        head, gumbel_beta=self.gumbel_beta)

    tail_head_relation_box_vol = tail_head_relation_box._log_soft_volume_adjusted(
        tail_head_relation_box.z,
        tail_head_relation_box.Z,
        temp=self.softbox_temp,
        gumbel_beta=self.gumbel_beta,
    )
    tail_relation_box_vol = tail_relation_box._log_soft_volume_adjusted(
        tail_relation_box.z,
        tail_relation_box.Z,
        temp=self.softbox_temp,
        gumbel_beta=self.gumbel_beta,
    )

    # Log of the ratio vol(tail ∩ relation ∩ head) / vol(tail ∩ relation).
    score_head = tail_head_relation_box_vol - tail_relation_box_vol
    return score_head
def _get_triple_score(
    self, head: BoxTensor, tail: BoxTensor, relation: BoxTensor
) -> torch.Tensor:
    """Return ``logvol(head ∩ tail) - logvol(tail)`` for a triple.

    Log-volumes are those produced by ``_log_soft_volume_adjusted`` using
    ``self.softbox_temp`` as the temperature and ``self.gumbel_beta``; the
    intersection is ``head.gumbel_intersection(tail)``.

    Args:
        head: Head-entity box.
        tail: Tail-entity box.
        relation: Not referenced by this scorer; kept so the signature
            matches the other triple scorers.

    Returns:
        The difference of the two log-volumes.
    """
    beta = self.gumbel_beta
    temperature = self.softbox_temp

    def log_volume(box):
        # Soft (temperature-smoothed) log-volume of ``box`` from its corners.
        return box._log_soft_volume_adjusted(
            box.z, box.Z, temp=temperature, gumbel_beta=beta)

    overlap = head.gumbel_intersection(tail, gumbel_beta=beta)
    # Left operand is evaluated first, matching the original call order.
    return log_volume(overlap) - log_volume(tail)
def _get_triple_score(
    self, head: BoxTensor, tail: BoxTensor, relation: BoxTensor
) -> torch.Tensor:
    """Return ``logvol(head ∩ tail) - logvol(tail)`` using Bessel volumes.

    The intersection is ``head.gumbel_intersection(tail)`` and both
    log-volumes come from ``_log_bessel_volume`` with ``self.gumbel_beta``.

    Args:
        head: Head-entity box.
        tail: Tail-entity box.
        relation: Not referenced by this scorer; kept for a uniform signature.

    Returns:
        The difference of the two log-volumes.

    Note:
        In eval mode a rank mismatch between ``head`` and ``tail`` is
        resolved by reassigning ``.data`` on the lower-rank box, which
        modifies the caller's object.
    """
    if self.is_eval():
        head_rank, tail_rank = len(head.data.shape), len(tail.data.shape)
        # Repeat the lower-rank side shape[-3] times along dim 0.
        # NOTE(review): presumably a batch of eval candidates -- confirm shapes.
        if head_rank > tail_rank:
            tail.data = torch.cat(head.data.shape[-3] * [tail.data])
        elif tail_rank > head_rank:
            head.data = torch.cat(tail.data.shape[-3] * [head.data])

    beta = self.gumbel_beta
    overlap = head.gumbel_intersection(tail, gumbel_beta=beta)
    log_overlap_volume = overlap._log_bessel_volume(
        overlap.z, overlap.Z, gumbel_beta=beta)
    log_tail_volume = tail._log_bessel_volume(
        tail.z, tail.Z, gumbel_beta=beta)
    return log_overlap_volume - log_tail_volume
def _get_triple_score(
    self, head: BoxTensor, tail: BoxTensor, relation: BoxTensor
) -> torch.Tensor:
    """Estimate ``log vol(head ∩ tail) / vol(tail)`` from sampled boxes.

    ``head`` and ``tail`` are perturbed with ``self.reparam_trick`` (using
    ``self.gumbel_beta`` and ``self.n_samples``). Each volume is then
    estimated twice with ``_log_gumbel_volume``:

    * "fwd": sampled min corner ``z`` with un-sampled max corner ``Z``;
    * "bwd": un-sampled ``z`` with sampled ``Z``.

    The two estimates are combined with ``logsumexp``. Averaging them instead
    would subtract ``log 2`` from both numerator and denominator, which
    cancels in the difference.

    Args:
        head: Head-entity box.
        tail: Tail-entity box.
        relation: Not referenced by this scorer; kept for a uniform signature.

    Returns:
        The log-ratio estimate. It is returned unclamped: with sampled
        volumes it may exceed 0.

    Note:
        In eval mode a rank mismatch between ``head`` and ``tail`` is
        resolved by reassigning ``.data`` on the lower-rank box, which
        modifies the caller's object.
    """
    if self.is_eval():
        # Repeat the lower-rank side shape[-3] times along dim 0.
        # NOTE(review): presumably a batch of eval candidates -- confirm shapes.
        if len(head.data.shape) > len(tail.data.shape):
            tail.data = torch.cat(head.data.shape[-3] * [tail.data])
        elif len(head.data.shape) < len(tail.data.shape):
            head.data = torch.cat(tail.data.shape[-3] * [head.data])

    head_sample = self.reparam_trick(
        head, gumbel_beta=self.gumbel_beta, n_samples=self.n_samples)
    tail_sample = self.reparam_trick(
        tail, gumbel_beta=self.gumbel_beta, n_samples=self.n_samples)

    intersection_sample_box = head_sample.gumbel_intersection(
        tail_sample, gumbel_beta=self.gumbel_beta)
    intersection_box = head.gumbel_intersection(
        tail, gumbel_beta=self.gumbel_beta)

    # fwd = (sampled z, un-sampled Z); bwd = (un-sampled z, sampled Z).
    intersection_volume_fwd = intersection_sample_box._log_gumbel_volume(
        intersection_sample_box.z, intersection_box.Z)
    intersection_volume_bwd = intersection_sample_box._log_gumbel_volume(
        intersection_box.z, intersection_sample_box.Z)
    tail_volume_fwd = tail_sample._log_gumbel_volume(tail_sample.z, tail.Z)
    tail_volume_bwd = tail_sample._log_gumbel_volume(tail.z, tail_sample.Z)

    intersection_score = torch.logsumexp(
        torch.stack((intersection_volume_fwd, intersection_volume_bwd)), 0)
    tail_score = torch.logsumexp(
        torch.stack((tail_volume_fwd, tail_volume_bwd)), 0)

    # No debugger hook or host-side check here: inspecting score > 0 would
    # force a device sync on every call and halt unattended runs.
    score = intersection_score - tail_score
    return score