def _get_triple_score(self, head: BoxTensor, tail: BoxTensor, relation: BoxTensor,
                      head_rev: BoxTensor, tail_rev: BoxTensor,
                      relation_rev: BoxTensor) -> torch.Tensor:
    """Score a triple as the mean of a forward and a reverse box term.

    Forward term: ``_log_soft_volume_adjusted`` of the Gumbel intersection
    of ``relation``, ``tail`` and ``head`` (intersected in that order),
    minus ``_log_soft_volume_adjusted`` of ``tail``. The reverse term is the
    same computation on ``relation_rev``, ``tail_rev`` and ``head_rev``.
    Volumes use ``self.softbox_temp`` as ``temp`` and ``self.gumbel_beta``
    as ``gumbel_beta``. If ``_log_soft_volume_adjusted`` returns
    log-volumes (as its name suggests -- confirm), each term is a log
    volume ratio against the tail-side box.

    In eval mode, when ``head`` and ``tail`` differ in rank, the boxes on
    the lower-rank side are replicated along dim 0 with ``torch.cat``.
    The repeat count is ``shape[-3]`` of the higher-rank box. This
    reassigns ``.data`` on the argument boxes themselves, so callers see
    the replicated tensors afterwards.

    Args:
        head, tail, relation: boxes used by the forward term.
        head_rev, tail_rev, relation_rev: boxes used by the reverse term.

    Returns:
        ``0.5 * (forward + reverse)``.
    """
    if self.is_eval():
        if len(head.data.shape) > len(tail.data.shape):
            source = head
            # NOTE(review): head_rev is replicated together with the tail
            # side (and tail_rev with the head side below). This implies
            # head_rev shares tail's shape -- presumably the *_rev boxes hold
            # the reversed triple's roles; confirm against the caller.
            followers = (tail, relation, relation_rev, head_rev)
        elif len(head.data.shape) < len(tail.data.shape):
            source = tail
            followers = (head, relation, relation_rev, tail_rev)
        else:
            followers = ()
        for box in followers:
            # The repeat count is re-read from the source box each time,
            # matching a per-statement read of ``source.data.shape[-3]``.
            box.data = torch.cat(source.data.shape[-3] * [box.data])

    def log_volume(box):
        # Adjusted soft volume of a box with the model's temperature/beta.
        return box._log_soft_volume_adjusted(
            box.z, box.Z, temp=self.softbox_temp,
            gumbel_beta=self.gumbel_beta)

    def directional_score(rel, given, other):
        # (rel & given) & other, compared against the volume of ``given``.
        joint = rel.gumbel_intersection(
            given, gumbel_beta=self.gumbel_beta).gumbel_intersection(
                other, gumbel_beta=self.gumbel_beta)
        return log_volume(joint) - log_volume(given)

    forward = directional_score(relation, tail, head)
    reverse = directional_score(relation_rev, tail_rev, head_rev)
    return 0.5 * (forward + reverse)
def _get_triple_score(self, head: BoxTensor, tail: BoxTensor,
                      relation: BoxTensor) -> torch.Tensor:
    """Score a triple via the relation-transformed head box and the tail box.

    The head is first mapped through ``self.get_relation_transform(head,
    relation)``. The score is ``_log_soft_volume_adjusted`` of the Gumbel
    intersection of that transformed box with ``tail``, minus
    ``_log_soft_volume_adjusted`` of ``tail``. Volumes use
    ``self.softbox_temp`` as ``temp`` and ``self.gumbel_beta`` as
    ``gumbel_beta``.

    In eval mode, when ``head`` and ``tail`` differ in rank:

    - if ``head`` has the higher rank, ``tail`` is replicated along dim 0
      (``torch.cat``, ``head.data.shape[-3]`` copies);
    - otherwise the transformed box is replicated ``tail.data.shape[-3]``
      times.

    The rank test looks at ``head``, not at the transformed box. Replication
    reassigns ``.data`` in place, so a caller-supplied ``tail`` is modified.

    Returns:
        The intersection volume term minus the tail volume term.
    """
    mapped = self.get_relation_transform(head, relation)

    if self.is_eval():
        head_rank = len(head.data.shape)
        tail_rank = len(tail.data.shape)
        if head_rank > tail_rank:
            tail.data = torch.cat([tail.data] * head.data.shape[-3])
        elif head_rank < tail_rank:
            mapped.data = torch.cat([mapped.data] * tail.data.shape[-3])

    beta = self.gumbel_beta
    overlap = mapped.gumbel_intersection(tail, gumbel_beta=beta)
    temp = self.softbox_temp
    overlap_log_vol = overlap._log_soft_volume_adjusted(
        overlap.z, overlap.Z, temp=temp, gumbel_beta=beta)
    tail_log_vol = tail._log_soft_volume_adjusted(
        tail.z, tail.Z, temp=temp, gumbel_beta=beta)
    return overlap_log_vol - tail_log_vol
def _get_triple_score(self, head: BoxTensor, tail: BoxTensor,
                      relation: BoxTensor) -> torch.Tensor:
    """Score a triple from the overlap of the head and tail boxes only.

    The score is ``_log_soft_volume_adjusted`` of the Gumbel intersection
    ``head`` with ``tail`` (beta ``self.gumbel_beta``), minus
    ``_log_soft_volume_adjusted`` of ``tail``. Both volumes use
    ``self.softbox_temp`` as ``temp``.

    ``relation`` is accepted for interface compatibility but is not used.
    Unlike the relation-aware variants, no eval-mode replication is done
    here. This presumably relies on ``gumbel_intersection`` handling any
    shape mismatch itself -- confirm.

    Returns:
        The overlap volume term minus the tail volume term.
    """
    overlap = head.gumbel_intersection(tail, gumbel_beta=self.gumbel_beta)
    volume_opts = {"temp": self.softbox_temp, "gumbel_beta": self.gumbel_beta}
    return (overlap._log_soft_volume_adjusted(overlap.z, overlap.Z, **volume_opts)
            - tail._log_soft_volume_adjusted(tail.z, tail.Z, **volume_opts))