def test_reduce_mean(self):
    """Checks reduce_weighted_mean over axis -2 with keepdims=True."""
    # Two groups of three 2-D rows. Shape = [2, 3, 2].
    values = tf.constant([
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    ])
    # One weight per row. Shape = [2, 3, 1].
    row_weights = tf.constant([
        [[1.0], [0.0], [1.0]],
        [[0.0], [1.0], [0.0]],
    ])
    # Group 0 averages rows 0 and 2; group 1 keeps only row 1.
    expected = [[[3.0, 4.0]], [[9.0, 10.0]]]
    # Shape = [2, 1, 2].
    result = data_utils.reduce_weighted_mean(
        values, row_weights, axis=-2, keepdims=True)
    self.assertAllClose(result, expected)
def centralize_masked_points(points, point_masks):
    """Replaces masked-out points with the center of the kept points.

    The center is the mean of `points` over the point axis, weighted by
    `point_masks`. A point whose mask casts to True (non-zero) is returned
    unchanged; every other point is replaced by that center.

    Args:
      points: A tensor for points. Shape = [..., num_points, point_dim].
      point_masks: A tensor for the masks. Shape = [..., num_points].

    Returns:
      A tensor for points with masked out points centralized.
    """
    # Add a trailing axis so masks line up with the point_dim axis.
    # Shape = [..., num_points, 1].
    expanded_masks = tf.expand_dims(point_masks, axis=-1)
    # Weighted center of the kept points. Shape = [..., 1, point_dim].
    centers = data_utils.reduce_weighted_mean(
        points, weights=expanded_masks, axis=-2, keepdims=True)
    keep = tf.cast(expanded_masks, dtype=tf.bool)
    # tf.where broadcasts `keep`, `points` and `centers` to a common shape.
    return tf.where(keep, points, centers)
def compute_mpjpes(lhs_points, rhs_points, point_masks=None):
    """Computes the Mean Per-Joint Position Errors (MPJPEs).

    Each error is the mean, over points, of the L2 distance between
    corresponding LHS and RHS points. If `point_masks` is given, it is used
    as the weights of that mean.

    Args:
      lhs_points: A tensor for the LHS points.
        Shape = [..., num_points, point_dim].
      rhs_points: A tensor for the RHS points.
        Shape = [..., num_points, point_dim].
      point_masks: A tensor for the masks. Shape = [..., num_points]. Ignored
        if None.

    Returns:
      A tensor for MPJPEs. Shape = [...].
    """
    # Per-point distances with the point_dim axis dropped
    # (presumably shape [..., num_points]).
    per_point_errors = distance_utils.compute_l2_distances(
        lhs_points, rhs_points, keepdims=False)
    # Reduce over the point axis. `point_masks` (possibly None) is passed
    # straight through as the weights.
    return data_utils.reduce_weighted_mean(
        per_point_errors, weights=point_masks, axis=-1)