def compute_negative_indicator_matrix(anchor_points,
                                      match_points,
                                      distance_fn,
                                      min_negative_distance,
                                      anchor_point_masks=None,
                                      match_point_masks=None):
    """Flags every anchor/match pair that is far enough apart to be a negative.

    The all-pair distance matrix is computed with
    `distance_utils.compute_distance_matrix` and thresholded: a pair is marked
    as a negative match when its distance is greater than or equal to
    `min_negative_distance`.

    Args:
      anchor_points: A tensor for anchor points. Shape = [num_anchors, ...,
        point_dim].
      match_points: A tensor for match points. Shape = [num_matches, ...,
        point_dim].
      distance_fn: A function handle used to compute the distance matrix.
      min_negative_distance: A float for the minimum distance at which a pair
        counts as a negative match.
      anchor_point_masks: A tensor for anchor point masks. Shape =
        [num_anchors, ...]. Ignored if None.
      match_point_masks: A tensor for match point masks. Shape =
        [num_matches, ...]. Ignored if None.

    Returns:
      A boolean tensor for the negative indicator matrix. Shape =
        [num_anchors, num_matches].
    """
    pairwise_distances = distance_utils.compute_distance_matrix(
        anchor_points,
        match_points,
        distance_fn=distance_fn,
        start_point_masks=anchor_point_masks,
        end_point_masks=match_point_masks,
    )
    is_negative_pair = pairwise_distances >= min_negative_distance
    return is_negative_pair
def test_compute_distance_matrix_with_end_masks(self):
    """Checks the all-pair distance matrix when only end point masks are given."""
    # Shape = [2, 3, 1].
    start_points = tf.constant([
        [[1.0], [2.0], [3.0]],
        [[4.0], [5.0], [6.0]],
    ])
    # Shape = [3, 3, 1].
    end_points = tf.constant([
        [[11.0], [12.0], [13.0]],
        [[14.0], [15.0], [16.0]],
        [[17.0], [18.0], [19.0]],
    ])
    # Shape = [3, 3].
    end_point_masks = tf.constant([
        [1.0, 0.0, 1.0],
        [1.0, 0.0, 0.0],
        [1.0, 1.0, 1.0],
    ])

    def masked_add(lhs, rhs, masks):
        # Sums `lhs + rhs` over the point and coordinate axes, counting only
        # the points whose mask is non-zero.
        masks = tf.expand_dims(masks, axis=-1)
        return tf.math.reduce_sum((lhs + rhs) * masks, axis=[-2, -1])

    # Shape = [2, 3].
    distance_matrix = distance_utils.compute_distance_matrix(
        start_points,
        end_points,
        distance_fn=masked_add,
        end_point_masks=end_point_masks)

    # The tensor is handed to the assertion directly, as the sibling tests
    # do, instead of going through `self.session()` / `sess.run`, which is
    # only usable in graph mode.
    self.assertAllClose(distance_matrix,
                        [[28.0, 15.0, 60.0], [34.0, 18.0, 69.0]])
def test_compute_distance_matrix_with_start_masks(self):
    """Checks the all-pair distance matrix when only start point masks are given."""
    # Two start point sets with three 1-D points each. Shape = [2, 3, 1].
    start_points = tf.constant([
        [[1.0], [2.0], [3.0]],
        [[4.0], [5.0], [6.0]],
    ])
    # Three end point sets with three 1-D points each. Shape = [3, 3, 1].
    end_points = tf.constant([
        [[11.0], [12.0], [13.0]],
        [[14.0], [15.0], [16.0]],
        [[17.0], [18.0], [19.0]],
    ])
    # The last point of the second start set is masked out. Shape = [2, 3].
    start_point_masks = tf.constant([
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 0.0],
    ])

    def masked_add(lhs, rhs, masks):
        # Adds the two point sets and sums over unmasked points only.
        expanded_masks = tf.expand_dims(masks, axis=-1)
        masked_sums = (lhs + rhs) * expanded_masks
        return tf.math.reduce_sum(masked_sums, axis=[-2, -1])

    expected_distances = [[42.0, 51.0, 60.0], [32.0, 38.0, 44.0]]

    # Shape = [2, 3].
    actual_distances = distance_utils.compute_distance_matrix(
        start_points,
        end_points,
        distance_fn=masked_add,
        start_point_masks=start_point_masks)

    self.assertAllClose(actual_distances, expected_distances)
def test_compute_distance_matrix(self):
    """Checks the unmasked all-pair matrix using subtraction as the distance."""
    start_points = tf.constant([[1], [2]])  # Shape = [2, 1].
    end_points = tf.constant([[3], [4], [5]])  # Shape = [3, 1].

    # Entry [i][j] is start_points[i] - end_points[j].
    expected_differences = [
        [[-2], [-3], [-4]],
        [[-1], [-2], [-3]],
    ]

    pairwise_differences = distance_utils.compute_distance_matrix(
        start_points, end_points, distance_fn=tf.math.subtract)

    self.assertAllEqual(pairwise_differences, expected_differences)
def compute_negative_indicator_matrix(anchors, matches, distance_fn,
                                      min_negative_distance):
    """Flags every anchor/match pair that is far enough apart to be a negative.

    A pair is marked as a negative match when its entry in the all-pair
    distance matrix is greater than or equal to `min_negative_distance`.

    Args:
      anchors: A tensor for anchor points. Shape = [num_anchors, ...].
      matches: A tensor for match points. Shape = [num_matches, ...].
      distance_fn: A function handle used to compute the distance matrix.
      min_negative_distance: A float for the minimum distance at which a pair
        counts as a negative match.

    Returns:
      A boolean tensor for the negative indicator matrix. Shape =
        [num_anchors, num_matches].
    """
    pairwise_distances = distance_utils.compute_distance_matrix(
        anchors, matches, distance_fn=distance_fn)
    is_negative_pair = pairwise_distances >= min_negative_distance
    return is_negative_pair
def compute_positive_indicator_matrix(anchors, matches, distance_fn,
                                      max_positive_distance):
    """Flags every anchor/match pair that is close enough to be a positive.

    The all-pair distance matrix is averaged with its own transpose before
    thresholding, so each pair is judged on the mean of its two directed
    distances. A pair is positive when that mean is less than or equal to
    `max_positive_distance`.

    NOTE(review): adding the matrix to its transpose presumably requires
    `num_anchors == num_matches` -- confirm against callers.

    Args:
      anchors: A tensor for anchor points. Shape = [num_anchors, ...].
      matches: A tensor for match points. Shape = [num_matches, ...].
      distance_fn: A function handle used to compute the distance matrix.
      max_positive_distance: A float for the maximum distance at which a pair
        counts as a positive match.

    Returns:
      A float32 tensor for the positive indicator matrix, 1.0 for positive
        pairs and 0.0 otherwise. Shape = [num_anchors, num_matches].
    """
    pairwise_distances = distance_utils.compute_distance_matrix(
        anchors, matches, distance_fn=distance_fn)
    symmetric_distances = (
        pairwise_distances + tf.transpose(pairwise_distances)) / 2.0
    is_positive_pair = symmetric_distances <= max_positive_distance
    return tf.cast(is_positive_pair, dtype=tf.dtypes.float32)
def compute_keypoint_triplet_losses(
        anchor_embeddings,
        positive_embeddings,
        match_embeddings,
        anchor_keypoints,
        match_keypoints,
        margin,
        min_negative_keypoint_distance,
        use_semi_hard,
        exclude_inactive_triplet_loss,
        anchor_keypoint_masks=None,
        match_keypoint_masks=None,
        embedding_sample_distance_fn=create_sample_distance_fn(),
        keypoint_distance_fn=keypoint_utils.compute_procrustes_aligned_mpjpes,
        anchor_mining_embeddings=None,
        positive_mining_embeddings=None,
        match_mining_embeddings=None,
        summarize_percentiles=True):
    """Computes triplet losses with both hard and semi-hard negatives.

    Both the hard-negative and the semi-hard-negative losses are always
    computed so that summaries for both are available; `use_semi_hard` and
    `exclude_inactive_triplet_loss` only select which one is returned as the
    final loss.

    Args:
      anchor_embeddings: A tensor for anchor embeddings. Shape = [num_anchors,
        embedding_dim] or [num_anchors, num_samples, embedding_dim].
      positive_embeddings: A tensor for positive match embeddings. Shape =
        [num_anchors, embedding_dim] or [num_anchors, num_samples,
        embedding_dim].
      match_embeddings: A tensor for candidate negative match embeddings. Shape
        = [num_matches, embedding_dim] or [num_matches, num_samples,
        embedding_dim].
      anchor_keypoints: A tensor for anchor keypoints for computing pair
        labels. Shape = [num_anchors, ..., num_keypoints, keypoint_dim].
      match_keypoints: A tensor for match keypoints for computing pair labels.
        Shape = [num_matches, ..., num_keypoints, keypoint_dim].
      margin: A float for triplet loss margin.
      min_negative_keypoint_distance: A float for the minimum negative distance
        threshold. If negative, uses all other samples as negative matches. In
        this case, `num_anchors` and `num_matches` are assumed to be equal.
        Note that this option is for saving negative match computation. To
        support different `num_anchors` and `num_matches`, set this to 0
        (without saving computation).
      use_semi_hard: A boolean for whether to use semi-hard negative triplet
        loss as the final loss.
      exclude_inactive_triplet_loss: A boolean for whether to exclude inactive
        triplets in the final loss computation.
      anchor_keypoint_masks: A tensor for anchor keypoint masks for computing
        pair labels. Shape = [num_anchors, ..., num_keypoints]. Ignored if
        None.
      match_keypoint_masks: A tensor for match keypoint masks for computing
        pair labels. Shape = [num_matches, ..., num_keypoints]. Ignored if
        None.
      embedding_sample_distance_fn: A function handle for computing sample
        embedding distances, which takes two embedding tensors of shape [...,
        num_samples, embedding_dim] and returns a distance tensor of shape
        [...].
      keypoint_distance_fn: A function handle for computing keypoint distance
        matrix, which takes two matrix tensors and returns an element-wise
        distance matrix tensor.
      anchor_mining_embeddings: A tensor for anchor embeddings for triplet
        mining. Shape = [num_anchors, embedding_dim] or [num_anchors,
        num_samples, embedding_dim]. Use None to ignore and use
        `anchor_embeddings` instead.
      positive_mining_embeddings: A tensor for positive match embeddings for
        triplet mining. Shape = [num_anchors, embedding_dim] or [num_anchors,
        num_samples, embedding_dim]. Use None to ignore and use
        `positive_embeddings` instead.
      match_mining_embeddings: A tensor for candidate negative match embeddings
        for triplet mining. Shape = [num_matches, embedding_dim] or
        [num_matches, num_samples, embedding_dim]. Use None to ignore and use
        `match_embeddings` instead.
      summarize_percentiles: A boolean for whether to summarize percentiles of
        certain variables, e.g., embedding distances in triplet loss. Consider
        turning this off in case tensorflow_probability percentile computation
        causes failures at random due to empty tensor.

    Returns:
      loss: A tensor for triplet loss. Shape = [].
      summaries: A dictionary for loss and batch statistics summaries.
    """

    def maybe_expand_sample_dim(embeddings):
        """Inserts a singleton sample axis into rank-2 embeddings."""
        if len(embeddings.shape.as_list()) == 2:
            return tf.expand_dims(embeddings, axis=-2)
        return embeddings

    anchor_embeddings = maybe_expand_sample_dim(anchor_embeddings)
    positive_embeddings = maybe_expand_sample_dim(positive_embeddings)
    match_embeddings = maybe_expand_sample_dim(match_embeddings)

    # Resolve the embeddings used for mining once; each falls back to the
    # corresponding loss embeddings when no dedicated mining tensor is given.
    anchor_mining = (
        anchor_embeddings if anchor_mining_embeddings is None else
        maybe_expand_sample_dim(anchor_mining_embeddings))
    positive_mining = (
        positive_embeddings if positive_mining_embeddings is None else
        maybe_expand_sample_dim(positive_mining_embeddings))
    match_mining = (
        match_embeddings if match_mining_embeddings is None else
        maybe_expand_sample_dim(match_mining_embeddings))

    if min_negative_keypoint_distance >= 0.0:
        anchor_match_negative_indicator_matrix = (
            compute_negative_indicator_matrix(
                anchor_points=anchor_keypoints,
                match_points=match_keypoints,
                distance_fn=keypoint_distance_fn,
                min_negative_distance=min_negative_keypoint_distance,
                anchor_point_masks=anchor_keypoint_masks,
                match_point_masks=match_keypoint_masks))
    else:
        # A negative threshold skips the keypoint distance computation: every
        # off-diagonal pair is treated as a negative match.
        num_anchors = tf.shape(anchor_keypoints)[0]
        anchor_match_negative_indicator_matrix = tf.math.logical_not(
            tf.eye(num_anchors, dtype=tf.bool))

    anchor_positive_distances = embedding_sample_distance_fn(
        anchor_embeddings, positive_embeddings)

    if anchor_mining_embeddings is None and positive_mining_embeddings is None:
        # No dedicated mining embeddings: reuse the loss distances as is.
        anchor_positive_mining_distances = anchor_positive_distances
    else:
        anchor_positive_mining_distances = embedding_sample_distance_fn(
            anchor_mining, positive_mining)

    anchor_match_distance_matrix = distance_utils.compute_distance_matrix(
        anchor_embeddings,
        match_embeddings,
        distance_fn=embedding_sample_distance_fn)

    if anchor_mining_embeddings is None and match_mining_embeddings is None:
        anchor_match_mining_distance_matrix = anchor_match_distance_matrix
    else:
        anchor_match_mining_distance_matrix = (
            distance_utils.compute_distance_matrix(
                anchor_mining,
                match_mining,
                distance_fn=embedding_sample_distance_fn))

    num_total_triplets = tf.cast(
        tf.shape(anchor_embeddings)[0], dtype=tf.float32)

    def compute_loss_and_create_summaries(use_semi_hard):
        """Computes loss and creates summaries."""
        (loss, num_active_triplets, negative_distances, mining_loss,
         num_active_mining_triplets, negative_mining_distances) = (
             compute_hard_negative_triplet_loss(
                 anchor_positive_distances,
                 anchor_match_distance_matrix,
                 anchor_match_negative_indicator_matrix,
                 margin=margin,
                 use_semi_hard=use_semi_hard,
                 anchor_positive_mining_distances=(
                     anchor_positive_mining_distances),
                 anchor_match_mining_distance_matrix=(
                     anchor_match_mining_distance_matrix)))

        # Drop entries equal to the dtype maximum before summarizing. Each
        # tensor is filtered on its own values: building the mining mask from
        # the already-filtered `negative_distances` would yield an all-True
        # mask (of possibly different length) that filters nothing.
        negative_distances = tf.boolean_mask(
            negative_distances,
            mask=negative_distances < negative_distances.dtype.max)
        negative_mining_distances = tf.boolean_mask(
            negative_mining_distances,
            mask=(negative_mining_distances <
                  negative_mining_distances.dtype.max))

        active_triplet_ratio = (
            tf.cast(num_active_triplets, dtype=tf.float32) /
            num_total_triplets)
        active_mining_triplet_ratio = (
            tf.cast(num_active_mining_triplets, dtype=tf.float32) /
            num_total_triplets)

        # Rescale by the active ratio; the ratio is treated as a constant for
        # gradients and floored to avoid division by zero.
        active_loss = (
            loss /
            tf.math.maximum(1e-12, tf.stop_gradient(active_triplet_ratio)))
        active_mining_loss = (
            mining_loss / tf.math.maximum(
                1e-12, tf.stop_gradient(active_mining_triplet_ratio)))

        tag = 'SemiHardNegative' if use_semi_hard else 'HardNegative'
        summaries = {
            # Summaries related to triplet loss computation.
            'triplet_loss/Anchor/%s/Distance/Mean' % tag:
                tf.math.reduce_mean(negative_distances),
            'triplet_loss/%s/Loss/All' % tag:
                loss,
            'triplet_loss/%s/Loss/Active' % tag:
                active_loss,
            'triplet_loss/%s/ActiveTripletNum' % tag:
                num_active_triplets,
            'triplet_loss/%s/ActiveTripletRatio' % tag:
                active_triplet_ratio,
            # Summaries related to triplet mining.
            'triplet_mining/Anchor/%s/Distance/Mean' % tag:
                tf.math.reduce_mean(negative_mining_distances),
            'triplet_mining/%s/Loss/All' % tag:
                mining_loss,
            'triplet_mining/%s/Loss/Active' % tag:
                active_mining_loss,
            'triplet_mining/%s/ActiveTripletNum' % tag:
                num_active_mining_triplets,
            'triplet_mining/%s/ActiveTripletRatio' % tag:
                active_mining_triplet_ratio,
        }
        if summarize_percentiles:
            summaries.update({
                'triplet_loss/Anchor/%s/Distance/Median' % tag:
                    tfp.stats.percentile(negative_distances, q=50),
                'triplet_mining/Anchor/%s/Distance/Median' % tag:
                    tfp.stats.percentile(negative_mining_distances, q=50),
            })

        return loss, active_loss, summaries

    hard_negative_loss, hard_negative_active_loss, hard_negative_summaries = (
        compute_loss_and_create_summaries(use_semi_hard=False))
    (semi_hard_negative_loss, semi_hard_negative_active_loss,
     semi_hard_negative_summaries) = (
         compute_loss_and_create_summaries(use_semi_hard=True))

    summaries = {
        'triplet_loss/Margin':
            tf.constant(margin),
        'triplet_loss/Anchor/Positive/Distance/Mean':
            tf.math.reduce_mean(anchor_positive_distances),
        'triplet_mining/Anchor/Positive/Distance/Mean':
            tf.math.reduce_mean(anchor_positive_mining_distances),
    }
    if summarize_percentiles:
        summaries.update({
            'triplet_loss/Anchor/Positive/Distance/Median':
                tfp.stats.percentile(anchor_positive_distances, q=50),
            'triplet_mining/Anchor/Positive/Distance/Median':
                tfp.stats.percentile(anchor_positive_mining_distances, q=50),
        })
    summaries.update(hard_negative_summaries)
    summaries.update(semi_hard_negative_summaries)

    # Select the final loss from the four precomputed variants.
    if use_semi_hard:
        if exclude_inactive_triplet_loss:
            loss = semi_hard_negative_active_loss
        else:
            loss = semi_hard_negative_loss
    else:
        if exclude_inactive_triplet_loss:
            loss = hard_negative_active_loss
        else:
            loss = hard_negative_loss

    return loss, summaries