Ejemplo n.º 1
0
def compute_negative_indicator_matrix(anchor_points,
                                      match_points,
                                      distance_fn,
                                      min_negative_distance,
                                      anchor_point_masks=None,
                                      match_point_masks=None):
    """Builds the all-pair indicator matrix of negative matches.

  An (anchor, match) pair is flagged as negative when its pairwise distance
  is greater than or equal to `min_negative_distance`.

  Args:
    anchor_points: A tensor for anchor points. Shape = [num_anchors, ...,
      point_dim].
    match_points: A tensor for match points. Shape = [num_matches, ...,
      point_dim].
    distance_fn: A function handle forwarded to
      `distance_utils.compute_distance_matrix` for computing distances.
    min_negative_distance: A float for the minimum negative distance threshold.
    anchor_point_masks: A tensor for anchor point masks. Shape = [num_anchors,
      ...]. Ignored if None.
    match_point_masks: A tensor for match point masks. Shape = [num_matches,
      ...]. Ignored if None.

  Returns:
    A boolean tensor for negative indicator matrix. Shape = [num_anchors,
      num_matches].
  """
    # Anchors play the role of "start" points and matches the role of "end"
    # points in the distance utility's terminology.
    mask_kwargs = {
        'start_point_masks': anchor_point_masks,
        'end_point_masks': match_point_masks,
    }
    pairwise_distances = distance_utils.compute_distance_matrix(
        anchor_points, match_points, distance_fn=distance_fn, **mask_kwargs)
    is_negative = pairwise_distances >= min_negative_distance
    return is_negative
Ejemplo n.º 2
0
    def test_compute_distance_matrix_with_end_masks(self):
        """Checks the distance matrix when only `end_point_masks` is given."""

        def masked_add(lhs, rhs, masks):
            # Sum of (lhs + rhs) over the last two axes, counting only the
            # entries whose mask is non-zero.
            expanded_masks = tf.expand_dims(masks, axis=-1)
            masked_sums = (lhs + rhs) * expanded_masks
            return tf.math.reduce_sum(masked_sums, axis=[-2, -1])

        # Shape = [2, 3, 1].
        starts = tf.constant([
            [[1.0], [2.0], [3.0]],
            [[4.0], [5.0], [6.0]],
        ])
        # Shape = [3, 3, 1].
        ends = tf.constant([
            [[11.0], [12.0], [13.0]],
            [[14.0], [15.0], [16.0]],
            [[17.0], [18.0], [19.0]],
        ])
        # Shape = [3, 3].
        end_masks = tf.constant([
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
        ])

        # Shape = [2, 3].
        distances = distance_utils.compute_distance_matrix(
            starts, ends, distance_fn=masked_add, end_point_masks=end_masks)

        with self.session() as sess:
            actual = sess.run(distances)

        # Hand-computed masked sums, one row per start point.
        expected = [[28.0, 15.0, 60.0], [34.0, 18.0, 69.0]]
        self.assertAllClose(actual, expected)
Ejemplo n.º 3
0
    def test_compute_distance_matrix_with_start_masks(self):
        """Checks the distance matrix when only `start_point_masks` is given."""

        def masked_add(lhs, rhs, masks):
            # Sum of (lhs + rhs) over the last two axes, counting only the
            # entries whose mask is non-zero.
            expanded_masks = tf.expand_dims(masks, axis=-1)
            masked_sums = (lhs + rhs) * expanded_masks
            return tf.math.reduce_sum(masked_sums, axis=[-2, -1])

        # Shape = [2, 3, 1].
        starts = tf.constant([
            [[1.0], [2.0], [3.0]],
            [[4.0], [5.0], [6.0]],
        ])
        # Shape = [3, 3, 1].
        ends = tf.constant([
            [[11.0], [12.0], [13.0]],
            [[14.0], [15.0], [16.0]],
            [[17.0], [18.0], [19.0]],
        ])
        # Shape = [2, 3].
        start_masks = tf.constant([
            [1.0, 1.0, 1.0],
            [1.0, 1.0, 0.0],
        ])

        # Shape = [2, 3].
        distances = distance_utils.compute_distance_matrix(
            starts, ends, distance_fn=masked_add, start_point_masks=start_masks)

        # Hand-computed masked sums, one row per start point.
        expected = [[42.0, 51.0, 60.0], [32.0, 38.0, 44.0]]
        self.assertAllClose(distances, expected)
Ejemplo n.º 4
0
 def test_compute_distance_matrix(self):
      """Checks all-pair distances using plain subtraction as the distance."""
      # Shape = [2, 1]
      starts = tf.constant([[1], [2]])
      # Shape = [3, 1]
      ends = tf.constant([[3], [4], [5]])

      distances = distance_utils.compute_distance_matrix(
          starts, ends, distance_fn=tf.math.subtract)

      # Entry [i][j] is starts[i] - ends[j].
      expected = [
          [[-2], [-3], [-4]],
          [[-1], [-2], [-3]],
      ]
      self.assertAllEqual(distances, expected)
Ejemplo n.º 5
0
def compute_negative_indicator_matrix(anchors, matches, distance_fn,
                                      min_negative_distance):
    """Builds the all-pair indicator matrix of negative matches.

  An (anchor, match) pair is flagged as negative when its pairwise distance
  is greater than or equal to `min_negative_distance`.

  Args:
    anchors: A tensor for anchor points. Shape = [num_anchors, ...].
    matches: A tensor for match points. Shape = [num_matches, ...].
    distance_fn: A function handle forwarded to
      `distance_utils.compute_distance_matrix` for computing distances.
    min_negative_distance: A float for the minimum negative distance threshold.

  Returns:
    A boolean tensor for negative indicator matrix. Shape = [num_anchors,
      num_matches].
  """
    pairwise_distances = distance_utils.compute_distance_matrix(
        anchors,
        matches,
        distance_fn=distance_fn,
    )
    is_negative = pairwise_distances >= min_negative_distance
    return is_negative
Ejemplo n.º 6
0
def compute_positive_indicator_matrix(anchors, matches, distance_fn,
                                      max_positive_distance):
    """Builds the all-pair indicator matrix of positive matches.

  The pairwise distance matrix is averaged with its own transpose before
  thresholding, so a pair is flagged as positive when that averaged distance
  is less than or equal to `max_positive_distance`.

  Args:
    anchors: A tensor for anchor points. Shape = [num_anchors, ...].
    matches: A tensor for match points. Shape = [num_matches, ...].
    distance_fn: A function handle forwarded to
      `distance_utils.compute_distance_matrix` for computing distances.
    max_positive_distance: A float for the maximum positive distance threshold.

  Returns:
    A float tensor for positive indicator matrix. Shape = [num_anchors,
      num_matches].
  """
    pairwise_distances = distance_utils.compute_distance_matrix(
        anchors,
        matches,
        distance_fn=distance_fn,
    )
    # Average d(i, j) with d(j, i).
    symmetric_distances = (
        pairwise_distances + tf.transpose(pairwise_distances)) / 2.0
    is_positive = symmetric_distances <= max_positive_distance
    return tf.cast(is_positive, dtype=tf.dtypes.float32)
Ejemplo n.º 7
0
def compute_keypoint_triplet_losses(
        anchor_embeddings,
        positive_embeddings,
        match_embeddings,
        anchor_keypoints,
        match_keypoints,
        margin,
        min_negative_keypoint_distance,
        use_semi_hard,
        exclude_inactive_triplet_loss,
        anchor_keypoint_masks=None,
        match_keypoint_masks=None,
        embedding_sample_distance_fn=create_sample_distance_fn(),
        keypoint_distance_fn=keypoint_utils.compute_procrustes_aligned_mpjpes,
        anchor_mining_embeddings=None,
        positive_mining_embeddings=None,
        match_mining_embeddings=None,
        summarize_percentiles=True):
    """Computes triplet losses with both hard and semi-hard negatives.

  Both the hard-negative and the semi-hard-negative losses are always computed
  (and summarized); `use_semi_hard` and `exclude_inactive_triplet_loss` only
  select which of them is returned as the final loss.

  Args:
    anchor_embeddings: A tensor for anchor embeddings. Shape = [num_anchors,
      embedding_dim] or [num_anchors, num_samples, embedding_dim].
    positive_embeddings: A tensor for positive match embeddings. Shape =
      [num_anchors, embedding_dim] or [num_anchors, num_samples, embedding_dim].
    match_embeddings: A tensor for candidate negative match embeddings. Shape =
      [num_matches, embedding_dim] or [num_matches, num_samples, embedding_dim].
    anchor_keypoints: A tensor for anchor keypoints for computing pair labels.
      Shape = [num_anchors, ..., num_keypoints, keypoint_dim].
    match_keypoints: A tensor for match keypoints for computing pair labels.
      Shape = [num_matches, ..., num_keypoints, keypoint_dim].
    margin: A float for triplet loss margin.
    min_negative_keypoint_distance: A float for the minimum negative distance
      threshold. If negative, uses all other samples as negative matches. In
      this case, `num_anchors` and `num_matches` are assumed to be equal. Note
      that this option is for saving negative match computation. To support
      different `num_anchors` and `num_matches`, set this to 0 (without
      saving computation).
    use_semi_hard: A boolean for whether to use semi-hard negative triplet loss
      as the final loss.
    exclude_inactive_triplet_loss: A boolean for whether to exclude inactive
      triplets in the final loss computation.
    anchor_keypoint_masks: A tensor for anchor keypoint masks for computing pair
      labels. Shape = [num_anchors, ..., num_keypoints]. Ignored if None.
    match_keypoint_masks: A tensor for match keypoint masks for computing pair
      labels. Shape = [num_matches, ..., num_keypoints]. Ignored if None.
    embedding_sample_distance_fn: A function handle for computing sample
      embedding distances, which takes two embedding tensors of shape [...,
      num_samples, embedding_dim] and returns a distance tensor of shape [...].
    keypoint_distance_fn: A function handle for computing keypoint distance
      matrix, which takes two matrix tensors and returns an element-wise
      distance matrix tensor.
    anchor_mining_embeddings: A tensor for anchor embeddings for triplet mining.
      Shape = [num_anchors, embedding_dim] or [num_anchors, num_samples,
      embedding_dim]. Use None to ignore and use `anchor_embeddings` instead.
    positive_mining_embeddings: A tensor for positive match embeddings for
      triplet mining. Shape = [num_anchors, embedding_dim] or [num_anchors,
      num_samples, embedding_dim]. Use None to ignore and use
      `positive_embeddings` instead.
    match_mining_embeddings: A tensor for candidate negative match embeddings
      for triplet mining. Shape = [num_matches, embedding_dim] or [num_matches,
      num_samples, embedding_dim]. Use None to ignore and use `match_embeddings`
      instead.
    summarize_percentiles: A boolean for whether to summarize percentiles of
      certain variables, e.g., embedding distances in triplet loss. Consider
      turning this off in case tensorflow_probability percentile computation
      causes failures at random due to empty tensor.

  Returns:
    loss: A tensor for triplet loss. Shape = [].
    summaries: A dictionary for loss and batch statistics summaries.
  """
    def maybe_expand_sample_dim(embeddings):
        """Inserts a singleton sample axis into rank-2 embeddings."""
        if len(embeddings.shape.as_list()) == 2:
            return tf.expand_dims(embeddings, axis=-2)
        return embeddings

    anchor_embeddings = maybe_expand_sample_dim(anchor_embeddings)
    positive_embeddings = maybe_expand_sample_dim(positive_embeddings)
    match_embeddings = maybe_expand_sample_dim(match_embeddings)

    if min_negative_keypoint_distance >= 0.0:
        # Label negatives by thresholding pairwise keypoint distances.
        anchor_match_negative_indicator_matrix = (
            compute_negative_indicator_matrix(
                anchor_points=anchor_keypoints,
                match_points=match_keypoints,
                distance_fn=keypoint_distance_fn,
                min_negative_distance=min_negative_keypoint_distance,
                anchor_point_masks=anchor_keypoint_masks,
                match_point_masks=match_keypoint_masks))
    else:
        # Shortcut: every off-diagonal pair is a negative. The square matrix
        # built here is why `num_anchors` must equal `num_matches` in this mode.
        num_anchors = tf.shape(anchor_keypoints)[0]
        anchor_match_negative_indicator_matrix = tf.math.logical_not(
            tf.eye(num_anchors, dtype=tf.bool))

    anchor_positive_distances = embedding_sample_distance_fn(
        anchor_embeddings, positive_embeddings)

    # Mining distances fall back to the loss distances unless dedicated mining
    # embeddings were supplied for at least one side of the pair.
    if anchor_mining_embeddings is None and positive_mining_embeddings is None:
        anchor_positive_mining_distances = anchor_positive_distances
    else:
        anchor_positive_mining_distances = embedding_sample_distance_fn(
            anchor_embeddings if anchor_mining_embeddings is None else
            maybe_expand_sample_dim(anchor_mining_embeddings),
            positive_embeddings if positive_mining_embeddings is None else
            maybe_expand_sample_dim(positive_mining_embeddings))

    anchor_match_distance_matrix = distance_utils.compute_distance_matrix(
        anchor_embeddings,
        match_embeddings,
        distance_fn=embedding_sample_distance_fn)

    if anchor_mining_embeddings is None and match_mining_embeddings is None:
        anchor_match_mining_distance_matrix = anchor_match_distance_matrix
    else:
        anchor_match_mining_distance_matrix = distance_utils.compute_distance_matrix(
            anchor_embeddings if anchor_mining_embeddings is None else
            maybe_expand_sample_dim(anchor_mining_embeddings),
            match_embeddings if match_mining_embeddings is None else
            maybe_expand_sample_dim(match_mining_embeddings),
            distance_fn=embedding_sample_distance_fn)

    # One candidate triplet per anchor.
    num_total_triplets = tf.cast(tf.shape(anchor_embeddings)[0],
                                 dtype=tf.float32)

    def compute_loss_and_create_summaries(use_semi_hard):
        """Computes loss and creates summaries."""
        (loss, num_active_triplets, negative_distances, mining_loss,
         num_active_mining_triplets,
         negative_mining_distances) = (compute_hard_negative_triplet_loss(
             anchor_positive_distances,
             anchor_match_distance_matrix,
             anchor_match_negative_indicator_matrix,
             margin=margin,
             use_semi_hard=use_semi_hard,
             anchor_positive_mining_distances=anchor_positive_mining_distances,
             anchor_match_mining_distance_matrix=(
                 anchor_match_mining_distance_matrix)))

        # Drop entries equal to dtype.max before summarizing (presumably a
        # sentinel written by `compute_hard_negative_triplet_loss` for anchors
        # without a usable negative -- TODO confirm). Each tensor is filtered
        # with a mask derived from its own values: `negative_distances` has
        # already been filtered (and possibly shortened) by the time the
        # mining tensor is processed, so it cannot serve as that mask.
        negative_distances = tf.boolean_mask(
            negative_distances,
            mask=negative_distances < negative_distances.dtype.max)
        negative_mining_distances = tf.boolean_mask(
            negative_mining_distances,
            mask=(negative_mining_distances <
                  negative_mining_distances.dtype.max))

        active_triplet_ratio = (
            tf.cast(num_active_triplets, dtype=tf.float32) /
            num_total_triplets)
        active_mining_triplet_ratio = (
            tf.cast(num_active_mining_triplets, dtype=tf.float32) /
            num_total_triplets)

        # Rescale the loss by the active-triplet ratio. The ratio is wrapped in
        # stop_gradient so it acts as a constant scale, and is floored at 1e-12
        # to avoid division by zero when no triplet is active.
        active_loss = (
            loss /
            tf.math.maximum(1e-12, tf.stop_gradient(active_triplet_ratio)))
        active_mining_loss = (mining_loss / tf.math.maximum(
            1e-12, tf.stop_gradient(active_mining_triplet_ratio)))

        tag = 'SemiHardNegative' if use_semi_hard else 'HardNegative'
        summaries = {
            # Summaries related to triplet loss computation.
            'triplet_loss/Anchor/%s/Distance/Mean' % tag:
            tf.math.reduce_mean(negative_distances),
            'triplet_loss/%s/Loss/All' % tag:
            loss,
            'triplet_loss/%s/Loss/Active' % tag:
            active_loss,
            'triplet_loss/%s/ActiveTripletNum' % tag:
            num_active_triplets,
            'triplet_loss/%s/ActiveTripletRatio' % tag:
            active_triplet_ratio,

            # Summaries related to triplet mining.
            'triplet_mining/Anchor/%s/Distance/Mean' % tag:
            tf.math.reduce_mean(negative_mining_distances),
            'triplet_mining/%s/Loss/All' % tag:
            mining_loss,
            'triplet_mining/%s/Loss/Active' % tag:
            active_mining_loss,
            'triplet_mining/%s/ActiveTripletNum' % tag:
            num_active_mining_triplets,
            'triplet_mining/%s/ActiveTripletRatio' % tag:
            active_mining_triplet_ratio,
        }
        if summarize_percentiles:
            summaries.update({
                'triplet_loss/Anchor/%s/Distance/Median' % tag:
                tfp.stats.percentile(negative_distances, q=50),
                'triplet_mining/Anchor/%s/Distance/Median' % tag:
                tfp.stats.percentile(negative_mining_distances, q=50),
            })

        return loss, active_loss, summaries

    # Both variants are evaluated so that summaries cover each of them,
    # regardless of which one is returned as the final loss.
    hard_negative_loss, hard_negative_active_loss, hard_negative_summaries = (
        compute_loss_and_create_summaries(use_semi_hard=False))
    (semi_hard_negative_loss, semi_hard_negative_active_loss,
     semi_hard_negative_summaries) = (compute_loss_and_create_summaries(
         use_semi_hard=True))

    summaries = {
        'triplet_loss/Margin':
        tf.constant(margin),
        'triplet_loss/Anchor/Positive/Distance/Mean':
        tf.math.reduce_mean(anchor_positive_distances),
        'triplet_mining/Anchor/Positive/Distance/Mean':
        tf.math.reduce_mean(anchor_positive_mining_distances),
    }
    if summarize_percentiles:
        summaries.update({
            'triplet_loss/Anchor/Positive/Distance/Median':
            tfp.stats.percentile(anchor_positive_distances, q=50),
            'triplet_mining/Anchor/Positive/Distance/Median':
            tfp.stats.percentile(anchor_positive_mining_distances, q=50),
        })
    summaries.update(hard_negative_summaries)
    summaries.update(semi_hard_negative_summaries)

    # Select the final loss from the four precomputed candidates.
    if use_semi_hard:
        if exclude_inactive_triplet_loss:
            loss = semi_hard_negative_active_loss
        else:
            loss = semi_hard_negative_loss
    else:
        if exclude_inactive_triplet_loss:
            loss = hard_negative_active_loss
        else:
            loss = hard_negative_loss

    return loss, summaries