def all_distances_sg(self, input_tensor, nodes_ind):
  """Computes tree-training distances, stopping gradients on some inputs.

  Args:
    input_tensor: Tensor of size batch_size x 3 containing user, positive
      item, negative item indices.
    nodes_ind: Tensor of size batch_size x tot_node_batch containing nodes
      indices, where tot_node_batch equals sum(self.node_batch_per_level).

  Returns:
    A tuple (user_node_distance, pos_item_node_distance, user_item_distance):
      user_node_distance: Tensor of size batch_size x tot_node_batch with the
        distances between the nodes and the user. No gradient flows back to
        the user points through this term (tf.stop_gradient).
      pos_item_node_distance: Tensor of size batch_size x tot_node_batch with
        the distances between the nodes and the positive item (items[:, 0, :]).
        No gradient flows back to the item points through this term.
      user_item_distance: Tensor of size batch_size x 2 with the distances
        between the user and the positive and negative items; gradients flow
        to both the users and the items here.
  """
  # Curvature is kept positive by passing the raw variable through softplus.
  curvature = tf.math.softplus(self.c)
  users, items, nodes = self.as_hyperbolic_points(input_tensor, nodes_ind)

  # Only the node side is trained by the node-distance terms; the user and
  # positive-item points are treated as constants there.
  frozen_users = tf.stop_gradient(users)
  frozen_pos_items = tf.stop_gradient(items[:, 0, :])

  distances_user_to_nodes = hyp_utils.hyp_distance_batch_rhs(
      frozen_users, nodes, curvature)
  distances_pos_item_to_nodes = hyp_utils.hyp_distance_batch_rhs(
      frozen_pos_items, nodes, curvature)
  distances_user_to_items = hyp_utils.hyp_distance_batch_rhs(
      users, items, curvature)

  return (distances_user_to_nodes,
          distances_pos_item_to_nodes,
          distances_user_to_items)
def similarity_score(self, lhs, rhs, eval_mode):
  """Scores lhs against rhs as the negative squared hyperbolic distance.

  The distance helper is chosen as follows:
    * not eval_mode:                  hyp_utils.hyp_distance
    * eval_mode and self.rhs_dep_lhs: hyp_utils.hyp_distance_batch_rhs
    * eval_mode otherwise:            hyp_utils.hyp_distance_all_pairs
  The output shape is whatever the selected helper returns.

  Args:
    lhs: Left-hand-side points passed to the distance helper.
    rhs: Right-hand-side points passed to the distance helper.
    eval_mode: Truthy to use one of the evaluation-time distance helpers.

  Returns:
    -distance**2 computed with curvature softplus(self.c).
  """
  curvature = tf.math.softplus(self.c)
  if not eval_mode:
    distance = hyp_utils.hyp_distance(lhs, rhs, curvature)
  elif self.rhs_dep_lhs:
    distance = hyp_utils.hyp_distance_batch_rhs(lhs, rhs, curvature)
  else:
    distance = hyp_utils.hyp_distance_all_pairs(lhs, rhs, curvature)
  return -distance**2