def test_compute_l2_distances_keepdims(self):
    """Checks `compute_l2_distances` with `keepdims=True` on a small batch.

    The expected values keep a trailing axis of size 1 where the inputs have
    their point-coordinate axis, so the result has the same rank as the
    inputs.
    """
    # Shape = [2, 1, 2, 2].
    lhs_points = [[[[0.0, 1.0], [2.0, 3.0]]], [[[10.0, 11.0], [12.0, 13.0]]]]
    rhs_points = [[[[0.0, 1.1], [2.3, 3.4]]], [[[10.4, 11.0], [12.4, 13.3]]]]
    # Shape = [2, 1, 2, 1].
    distances = distance_utils.compute_l2_distances(
        lhs_points, rhs_points, keepdims=True)
    self.assertAllClose(distances, [[[[0.1], [0.5]]], [[[0.4], [0.5]]]])
def test_compute_l2_distances(self):
    """Checks per-point L2 distances between two small batches of points."""
    # Both inputs: shape = [2, 1, 2, 2].
    first_batch = [
        [[[0.0, 1.0], [2.0, 3.0]]],
        [[[10.0, 11.0], [12.0, 13.0]]],
    ]
    second_batch = [
        [[[0.0, 1.1], [2.3, 3.4]]],
        [[[10.4, 11.0], [12.4, 13.3]]],
    ]
    # Shape = [2, 1, 2].
    expected_distances = [[[0.1, 0.5]], [[0.4, 0.5]]]

    actual_distances = distance_utils.compute_l2_distances(
        first_batch, second_batch)

    self.assertAllClose(actual_distances, expected_distances)
def compute_mpjpes(lhs_points, rhs_points, point_masks=None):
    """Computes the Mean Per-Joint Position Errors (MPJPEs).

    Without `point_masks`, the result is the plain mean of the per-point L2
    distances over the point axis. With `point_masks`, it is the mask-weighted
    mean of those distances.

    Args:
      lhs_points: A tensor for the LHS points. Shape = [..., num_points,
        point_dim].
      rhs_points: A tensor for the RHS points. Shape = [..., num_points,
        point_dim].
      point_masks: A tensor for the masks. Shape = [..., num_points]. Ignored
        if None.

    Returns:
      A tensor for MPJPEs. Shape = [...].
    """
    per_point_errors = distance_utils.compute_l2_distances(
        lhs_points, rhs_points, keepdims=False)

    if point_masks is not None:
        masked_error_sums = tf.math.reduce_sum(
            per_point_errors * point_masks, axis=-1)
        # Floor the mask total so an all-zero mask cannot divide by zero.
        mask_totals = tf.math.maximum(
            1e-12, tf.math.reduce_sum(point_masks, axis=-1))
        return masked_error_sums / mask_totals

    return tf.math.reduce_mean(per_point_errors, axis=-1)
def compute_scale_distances():
    """Reduces the L2 distances of the configured point pairs to one value.

    Reads `points`, `scale_distance_point_index_pairs`, `get_points` and
    `scale_distance_reduction_fn` from the enclosing scope (not visible from
    this block).

    Returns:
      The result of `scale_distance_reduction_fn` applied over the last axis
      (with `keepdims=True`) to the per-pair distances concatenated along
      that axis.
    """
    # One distance tensor per index pair; `keepdims=True` keeps a trailing
    # axis so the tensors can be concatenated along it.
    per_pair_distances = [
        distance_utils.compute_l2_distances(
            get_points(points, lhs_indices),
            get_points(points, rhs_indices),
            keepdims=True)
        for lhs_indices, rhs_indices in scale_distance_point_index_pairs
    ]
    joined_distances = tf.concat(per_pair_distances, axis=-1)
    return scale_distance_reduction_fn(
        joined_distances, axis=-1, keepdims=True)
def compute_mpjpes(lhs_points, rhs_points):
    """Computes the Mean Per-Joint Position Errors (MPJPEs).

    The per-point L2 distances between the two point sets are averaged over
    the point axis.

    Args:
      lhs_points: A tensor for the LHS points. Shape = [..., num_points,
        point_dim].
      rhs_points: A tensor for the RHS points. Shape = [..., num_points,
        point_dim].

    Returns:
      A tensor for MPJPEs. Shape = [...].
    """
    per_point_errors = distance_utils.compute_l2_distances(
        lhs_points, rhs_points, keepdims=False)
    mean_errors = tf.math.reduce_mean(per_point_errors, axis=-1)
    return mean_errors
def compute_mpjpes(lhs_points, rhs_points, point_masks=None):
    """Computes the Mean Per-Joint Position Errors (MPJPEs).

    The per-point L2 distances are averaged over the point axis by
    `data_utils.reduce_weighted_mean`, with `point_masks` passed as the
    weights.

    Args:
      lhs_points: A tensor for the LHS points. Shape = [..., num_points,
        point_dim].
      rhs_points: A tensor for the RHS points. Shape = [..., num_points,
        point_dim].
      point_masks: A tensor for the masks. Shape = [..., num_points]. Ignored
        if None.

    Returns:
      A tensor for MPJPEs. Shape = [...].
    """
    per_point_errors = distance_utils.compute_l2_distances(
        lhs_points, rhs_points, keepdims=False)
    mean_errors = data_utils.reduce_weighted_mean(
        per_point_errors, weights=point_masks, axis=-1)
    return mean_errors