Example #1
0
  def test_reduce_mean(self):
    """Checks `reduce_weighted_mean` along axis -2 with per-row weights."""
    # Two batches of three 2-D rows each: shape [2, 3, 2].
    values = tf.constant([
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    ])
    # One scalar weight per row: shape [2, 3, 1]. Batch 0 weights rows 0 and 2;
    # batch 1 weights only row 1.
    row_weights = tf.constant([
        [[1.0], [0.0], [1.0]],
        [[0.0], [1.0], [0.0]],
    ])
    # Batch 0: average of [1, 2] and [5, 6] -> [3, 4].
    # Batch 1: only [9, 10] carries weight -> [9, 10].
    expected = [[[3.0, 4.0]], [[9.0, 10.0]]]

    # keepdims=True retains the reduced axis, so the result is [2, 1, 2].
    result = data_utils.reduce_weighted_mean(
        values, row_weights, axis=-2, keepdims=True)
    self.assertAllClose(result, expected)
Example #2
0
def centralize_masked_points(points, point_masks):
  """Replaces masked-out points with the weighted center of the kept points.

  A point is kept wherever its mask casts to True (i.e. is nonzero); every
  other point is overwritten with the mean of the points along the
  `num_points` axis, weighted by `point_masks`.

  Args:
    points: A tensor for points. Shape = [..., num_points, point_dim].
    point_masks: A tensor for the masks. Shape = [..., num_points].

  Returns:
    A tensor for points with masked out points centralized.
  """
  # Append a trailing axis so each mask lines up with a whole point:
  # [..., num_points] -> [..., num_points, 1].
  mask_column = tf.expand_dims(point_masks, axis=-1)
  # Weighted mean over the points axis, keeping that axis (size 1) so it can
  # stand in for any individual point.
  # NOTE(review): when every mask in a group is zero, the value here depends on
  # how reduce_weighted_mean handles zero total weight -- confirm.
  centers = data_utils.reduce_weighted_mean(
      points, weights=mask_column, axis=-2, keepdims=True)
  is_kept = tf.cast(mask_column, dtype=tf.bool)
  # Relies on tf.where broadcasting condition/x/y together (TF2 semantics).
  return tf.where(is_kept, points, centers)
Example #3
0
def compute_mpjpes(lhs_points, rhs_points, point_masks=None):
  """Computes the Mean Per-Joint Position Errors (MPJPEs).

  Each MPJPE is the mean, over the `num_points` axis, of the L2 distances
  between corresponding LHS and RHS points. When `point_masks` is given, that
  mean is weighted by the masks.

  Args:
    lhs_points: A tensor for the LHS points. Shape = [..., num_points,
      point_dim].
    rhs_points: A tensor for the RHS points. Shape = [..., num_points,
      point_dim].
    point_masks: A tensor for the masks. Shape = [..., num_points]. Ignored if
      None.

  Returns:
    A tensor for MPJPEs. Shape = [...].
  """
  # Per-point distances with the point_dim axis dropped (keepdims=False).
  per_point_errors = distance_utils.compute_l2_distances(
      lhs_points, rhs_points, keepdims=False)
  # Reduce over the points axis. None is forwarded unchanged as the weights,
  # presumably yielding an unweighted mean -- NOTE(review): confirm in
  # data_utils.reduce_weighted_mean.
  mpjpes = data_utils.reduce_weighted_mean(
      per_point_errors, weights=point_masks, axis=-1)
  return mpjpes