def test_compute_keypoint_triplet_losses_with_sample_mining_embeddings(
    self):
  """Tests triplet losses when separate mining embeddings are supplied.

  Exercises `loss_utils.compute_keypoint_triplet_losses` with distinct
  `*_mining_embeddings` arguments, a mocked keypoint-distance function, and a
  sample-distance function built via `loss_utils.create_sample_distance_fn`,
  then checks the scalar loss and every `triplet_loss/*` and
  `triplet_mining/*` summary against hand-computed values.
  """
  # Anchor embedding samples. Shape = [3, 3, 1, 2].
  anchor_embeddings = tf.constant([
      [[[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]]],
      [[[3.0, 4.0]], [[3.0, 4.0]], [[3.0, 4.0]]],
      [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
  ])
  # Positive embedding samples. Shape = [3, 3, 1, 2].
  positive_embeddings = tf.constant([
      [[[2.0, 1.0]], [[2.0, 1.0]], [[2.0, 1.0]]],
      [[[6.0, 5.0]], [[6.0, 5.0]], [[6.0, 5.0]]],
      [[[7.0, 6.0]], [[7.0, 6.0]], [[7.0, 6.0]]],
  ])
  # Candidate-match embedding samples; note the extra sample dimension (2)
  # relative to the anchors. Shape = [4, 3, 2, 2].
  match_embeddings = tf.constant([
      [[[3.0, 2.0], [3.0, 2.0]], [[3.0, 2.0], [3.0, 2.0]],
       [[3.0, 2.0], [3.0, 2.0]]],
      [[[4.0, 3.0], [4.0, 3.0]], [[4.0, 3.0], [4.0, 3.0]],
       [[4.0, 3.0], [4.0, 3.0]]],
      [[[6.0, 5.0], [6.0, 5.0]], [[6.0, 5.0], [6.0, 5.0]],
       [[6.0, 5.0], [6.0, 5.0]]],
      [[[8.0, 7.0], [8.0, 7.0]], [[8.0, 7.0], [8.0, 7.0]],
       [[8.0, 7.0], [8.0, 7.0]]],
  ])
  # Shape = [3, 1].
  anchor_keypoints = tf.constant([[1], [2], [3]])
  # Shape = [4, 1].
  match_keypoints = tf.constant([[1], [2], [3], [4]])

  def mock_keypoint_distance_fn(unused_lhs, unused_rhs):
    # Fixed anchor-vs-match keypoint distance matrix; with
    # min_negative_keypoint_distance=0.5 the 0.0 entries mark matches that
    # are too close to serve as negatives. Shape = [3, 4].
    return tf.constant([[1.0, 1.0, 1.0, 1.0],
                        [1.0, 1.0, 0.0, 1.0],
                        [1.0, 0.0, 1.0, 0.0]])

  # Mining embeddings deliberately differ from the loss embeddings above so
  # the test verifies that negatives are mined on these tensors while the
  # loss values are computed on the originals. Shape = [3, 3, 1, 2].
  anchor_mining_embeddings = tf.constant([
      [[[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]]],
      [[[3.0, 4.0]], [[3.0, 4.0]], [[3.0, 4.0]]],
      [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
  ])
  # Shape = [3, 3, 1, 2].
  positive_mining_embeddings = tf.constant([
      [[[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]]],
      [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
      [[[6.0, 7.0]], [[6.0, 7.0]], [[6.0, 7.0]]],
  ])
  # Shape = [4, 3, 1, 2].
  match_mining_embeddings = tf.constant([
      [[[2.0, 3.0]], [[2.0, 3.0]], [[2.0, 3.0]]],
      [[[3.0, 4.0]], [[3.0, 4.0]], [[3.0, 4.0]]],
      [[[5.0, 6.0]], [[5.0, 6.0]], [[5.0, 6.0]]],
      [[[7.0, 8.0]], [[7.0, 8.0]], [[7.0, 8.0]]],
  ])

  loss, summaries = loss_utils.compute_keypoint_triplet_losses(
      anchor_embeddings,
      positive_embeddings,
      match_embeddings,
      anchor_keypoints,
      match_keypoints,
      margin=120.0,
      min_negative_keypoint_distance=0.5,
      use_semi_hard=True,
      exclude_inactive_triplet_loss=True,
      # Sum distances over both the sample and feature axes.
      embedding_sample_distance_fn=loss_utils.create_sample_distance_fn(
          pairwise_reduction=functools.partial(
              tf.math.reduce_sum, axis=[-2, -1]),
          componentwise_reduction=functools.partial(
              tf.math.reduce_sum, axis=[-1])),
      keypoint_distance_fn=mock_keypoint_distance_fn,
      anchor_mining_embeddings=anchor_mining_embeddings,
      positive_mining_embeddings=positive_mining_embeddings,
      match_mining_embeddings=match_mining_embeddings)

  with self.session() as sess:
    loss_result, summaries_result = sess.run([loss, summaries])

  self.assertAlmostEqual(loss_result, 57.0)
  expected_summaries_result = {
      'triplet_loss/Margin': 120.0,
      'triplet_loss/Anchor/Positive/Distance/Mean': 48.0 / 3,
      'triplet_loss/Anchor/Positive/Distance/Median': 12.0,
      'triplet_loss/Anchor/HardNegative/Distance/Mean': 48.0 / 3,
      'triplet_loss/Anchor/HardNegative/Distance/Median': 12.0,
      'triplet_loss/Anchor/SemiHardNegative/Distance/Mean': 348.0 / 3,
      'triplet_loss/Anchor/SemiHardNegative/Distance/Median': 120.0,
      'triplet_loss/HardNegative/Loss/All': 360.0 / 3,
      'triplet_loss/HardNegative/Loss/Active': 360.0 / 3,
      'triplet_loss/HardNegative/ActiveTripletNum': 3,
      'triplet_loss/HardNegative/ActiveTripletRatio': 1.0,
      'triplet_loss/SemiHardNegative/Loss/All': 114.0 / 3,
      # "Active" loss averages only over the 2 active triplets.
      'triplet_loss/SemiHardNegative/Loss/Active': 114.0 / 2,
      'triplet_loss/SemiHardNegative/ActiveTripletNum': 2,
      'triplet_loss/SemiHardNegative/ActiveTripletRatio': 2.0 / 3,
      'triplet_mining/Anchor/Positive/Distance/Mean': 30.0 / 3,
      'triplet_mining/Anchor/Positive/Distance/Median': 6.0,
      'triplet_mining/Anchor/HardNegative/Distance/Mean': 6.0 / 3,
      'triplet_mining/Anchor/HardNegative/Distance/Median': 0.0,
      'triplet_mining/Anchor/SemiHardNegative/Distance/Mean': 156.0 / 3,
      'triplet_mining/Anchor/SemiHardNegative/Distance/Median': 54.0,
      'triplet_mining/HardNegative/Loss/All': 384.0 / 3,
      'triplet_mining/HardNegative/Loss/Active': 384.0 / 3,
      'triplet_mining/HardNegative/ActiveTripletNum': 3,
      'triplet_mining/HardNegative/ActiveTripletRatio': 1.0,
      'triplet_mining/SemiHardNegative/Loss/All': 234.0 / 3,
      'triplet_mining/SemiHardNegative/Loss/Active': 234.0 / 3,
      'triplet_mining/SemiHardNegative/ActiveTripletNum': 3,
      'triplet_mining/SemiHardNegative/ActiveTripletRatio': 1.0,
  }
  self._assert_dict_equal_or_almost_equal(summaries_result,
                                          expected_summaries_result)
def add_triplet_loss():
  """Adds triplet loss.

  Builds a keypoint triplet loss between paired anchor/positive views,
  registers it in the TF1 loss collection, and records its summaries.

  NOTE(review): this is a closure — it reads `inputs`, `outputs`, `configs`,
  `summaries`, `common_module`, and `FLAGS` from the enclosing (not shown)
  training-pipeline scope.
  """
  # Split the stacked [batch, 2, ...] 3-D keypoints into the anchor view and
  # the positive view.
  anchor_keypoints_3d, positive_keypoints_3d = tf.unstack(
      inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

  anchor_keypoint_masks_3d, positive_keypoint_masks_3d = None, None
  if FLAGS.use_inferred_keypoint_masks_for_triplet_label:
    # Derive 3-D keypoint masks from the preprocessed 2-D masks by mapping
    # between the 2-D and 3-D keypoint profiles.
    anchor_keypoint_masks_2d, positive_keypoint_masks_2d = tf.unstack(
        inputs[common_module.KEY_PREPROCESSED_KEYPOINT_MASKS_2D],
        num=2,
        axis=1)
    anchor_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
        anchor_keypoint_masks_2d,
        input_keypoint_profile=configs['keypoint_profile_2d'],
        output_keypoint_profile=configs['keypoint_profile_3d'],
        enforce_surjectivity=True)
    positive_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
        positive_keypoint_masks_2d,
        input_keypoint_profile=configs['keypoint_profile_2d'],
        output_keypoint_profile=configs['keypoint_profile_3d'],
        enforce_surjectivity=True)

  # Embeddings used to compute the loss values.
  triplet_anchor_embeddings, triplet_positive_embeddings = tf.unstack(
      pipeline_utils.stack_embeddings(outputs,
                                      configs['triplet_embedding_keys']),
      axis=1)
  if FLAGS.use_normalized_embeddings_for_triplet_loss:
    triplet_anchor_embeddings = tf.math.l2_normalize(
        triplet_anchor_embeddings, axis=-1)
    triplet_positive_embeddings = tf.math.l2_normalize(
        triplet_positive_embeddings, axis=-1)

  # Embeddings used (possibly under different keys/normalization) to mine
  # hard/semi-hard negatives.
  triplet_anchor_mining_embeddings, triplet_positive_mining_embeddings = (
      tf.unstack(
          pipeline_utils.stack_embeddings(
              outputs, configs['triplet_mining_embedding_keys']),
          axis=1))
  if FLAGS.use_normalized_embeddings_for_triplet_mining:
    triplet_anchor_mining_embeddings = tf.math.l2_normalize(
        triplet_anchor_mining_embeddings, axis=-1)
    triplet_positive_mining_embeddings = tf.math.l2_normalize(
        triplet_positive_mining_embeddings, axis=-1)

  # The positive embeddings double as the pool of candidate matches, so
  # negatives are mined from the positives of other batch elements.
  triplet_loss, triplet_loss_summaries = (
      loss_utils.compute_keypoint_triplet_losses(
          anchor_embeddings=triplet_anchor_embeddings,
          positive_embeddings=triplet_positive_embeddings,
          match_embeddings=triplet_positive_embeddings,
          anchor_keypoints=anchor_keypoints_3d,
          match_keypoints=positive_keypoints_3d,
          margin=FLAGS.triplet_loss_margin,
          min_negative_keypoint_distance=(
              configs['min_negative_keypoint_distance']),
          use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
          exclude_inactive_triplet_loss=(
              FLAGS.exclude_inactive_triplet_loss),
          anchor_keypoint_masks=anchor_keypoint_masks_3d,
          match_keypoint_masks=positive_keypoint_masks_3d,
          embedding_sample_distance_fn=(
              configs['triplet_embedding_sample_distance_fn']),
          keypoint_distance_fn=configs['keypoint_distance_fn'],
          anchor_mining_embeddings=triplet_anchor_mining_embeddings,
          positive_mining_embeddings=triplet_positive_mining_embeddings,
          match_mining_embeddings=triplet_positive_mining_embeddings,
          summarize_percentiles=FLAGS.summarize_percentiles))
  # Register the loss with the TF1 graph-level loss collection and expose it
  # (plus its summaries) to the caller via the shared `summaries` dict.
  tf.losses.add_loss(triplet_loss, loss_collection=tf.GraphKeys.LOSSES)
  summaries.update(triplet_loss_summaries)
  summaries['train/triplet_loss'] = triplet_loss
def test_compute_keypoint_triplet_losses(self):
  """Tests triplet losses without separate mining embeddings.

  With no `*_mining_embeddings` arguments, the same embeddings are used for
  both mining and loss computation, so the `triplet_mining/*` summaries are
  expected to mirror the `triplet_loss/*` summaries.
  """
  # Shape = [3, 1, 1, 2].
  anchor_embeddings = tf.constant([
      [[[1.0, 2.0]]],
      [[[3.0, 4.0]]],
      [[[5.0, 6.0]]],
  ])
  # Shape = [3, 1, 1, 2].
  positive_embeddings = tf.constant([
      [[[1.0, 2.0]]],
      [[[5.0, 6.0]]],
      [[[6.0, 7.0]]],
  ])
  # Shape = [4, 1, 1, 2].
  match_embeddings = tf.constant([
      [[[2.0, 3.0]]],
      [[[3.0, 4.0]]],
      [[[5.0, 6.0]]],
      [[[7.0, 8.0]]],
  ])
  # Shape = [3, 1].
  anchor_keypoints = tf.constant([[1], [2], [3]])
  # Shape = [4, 1].
  match_keypoints = tf.constant([[1], [2], [3], [4]])

  def mock_keypoint_distance_fn(unused_lhs, unused_rhs):
    # Fixed anchor-vs-match keypoint distances; 0.0 entries fall below
    # min_negative_keypoint_distance=0.5 and are excluded as negatives.
    # Shape = [3, 4].
    return tf.constant([[1.0, 1.0, 1.0, 1.0],
                        [1.0, 1.0, 0.0, 1.0],
                        [1.0, 0.0, 1.0, 0.0]])

  loss, summaries = loss_utils.compute_keypoint_triplet_losses(
      anchor_embeddings,
      positive_embeddings,
      match_embeddings,
      anchor_keypoints,
      match_keypoints,
      margin=20.0,
      min_negative_keypoint_distance=0.5,
      use_semi_hard=True,
      exclude_inactive_triplet_loss=True,
      keypoint_distance_fn=mock_keypoint_distance_fn)

  with self.session() as sess:
    loss_result, summaries_result = sess.run([loss, summaries])

  self.assertAlmostEqual(loss_result, 11.0)
  expected_summaries_result = {
      'triplet_loss/Margin': 20.0,
      'triplet_loss/Anchor/Positive/Distance/Mean': 10.0 / 3,
      'triplet_loss/Anchor/Positive/Distance/Median': 2.0,
      'triplet_loss/Anchor/HardNegative/Distance/Mean': 2.0 / 3,
      'triplet_loss/Anchor/HardNegative/Distance/Median': 0.0,
      'triplet_loss/Anchor/SemiHardNegative/Distance/Mean': 52.0 / 3,
      'triplet_loss/Anchor/SemiHardNegative/Distance/Median': 18.0,
      'triplet_loss/HardNegative/Loss/All': 68.0 / 3,
      'triplet_loss/HardNegative/Loss/Active': 68.0 / 3,
      'triplet_loss/HardNegative/ActiveTripletNum': 3,
      'triplet_loss/HardNegative/ActiveTripletRatio': 1.0,
      'triplet_loss/SemiHardNegative/Loss/All': 22.0 / 3,
      # "Active" loss averages only over the 2 active triplets.
      'triplet_loss/SemiHardNegative/Loss/Active': 22.0 / 2,
      'triplet_loss/SemiHardNegative/ActiveTripletNum': 2,
      'triplet_loss/SemiHardNegative/ActiveTripletRatio': 2.0 / 3,
      'triplet_mining/Anchor/Positive/Distance/Mean': 10.0 / 3,
      'triplet_mining/Anchor/Positive/Distance/Median': 2.0,
      'triplet_mining/Anchor/HardNegative/Distance/Mean': 2.0 / 3,
      'triplet_mining/Anchor/HardNegative/Distance/Median': 0.0,
      'triplet_mining/Anchor/SemiHardNegative/Distance/Mean': 52.0 / 3,
      'triplet_mining/Anchor/SemiHardNegative/Distance/Median': 18.0,
      'triplet_mining/HardNegative/Loss/All': 68.0 / 3,
      'triplet_mining/HardNegative/Loss/Active': 68.0 / 3,
      'triplet_mining/HardNegative/ActiveTripletNum': 3,
      'triplet_mining/HardNegative/ActiveTripletRatio': 1.0,
      'triplet_mining/SemiHardNegative/Loss/All': 22.0 / 3,
      'triplet_mining/SemiHardNegative/Loss/Active': 22.0 / 2,
      'triplet_mining/SemiHardNegative/ActiveTripletNum': 2,
      'triplet_mining/SemiHardNegative/ActiveTripletRatio': 2.0 / 3,
  }
  self._assert_dict_equal_or_almost_equal(summaries_result,
                                          expected_summaries_result)
def add_triplet_loss():
  """Adds triplet loss.

  Single-view variant: anchors are matched against themselves (anchor
  embeddings/keypoints serve as the match pool), registers the loss in the
  TF1 loss collection, and records its summaries.

  NOTE(review): this is a closure — it reads `inputs`, `outputs`, `configs`,
  `summaries`, `common_module`, and `FLAGS` from the enclosing (not shown)
  training-pipeline scope.
  """
  # Use only the first of the two stacked views' 3-D keypoints.
  anchor_keypoints_3d = tf.unstack(
      inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)[0]
  if (configs['keypoint_profile_3d'].keypoint_names !=
      configs['target_keypoint_profile_3d'].keypoint_names):
    # Select target keypoints to use if they are different than input.
    anchor_keypoints_3d, _ = keypoint_utils.select_keypoints_by_name(
        anchor_keypoints_3d,
        input_keypoint_names=(
            configs['keypoint_profile_3d'].keypoint_names),
        output_keypoint_names=(
            configs['target_keypoint_profile_3d'].keypoint_names))

  # Embeddings used to compute the loss values.
  triplet_anchor_embeddings, triplet_positive_embeddings = tf.unstack(
      pipeline_utils.stack_embeddings(outputs,
                                      configs['triplet_embedding_keys']),
      axis=1)
  if FLAGS.use_normalized_embeddings_for_triplet_loss:
    triplet_anchor_embeddings = tf.math.l2_normalize(
        triplet_anchor_embeddings, axis=-1)
    triplet_positive_embeddings = tf.math.l2_normalize(
        triplet_positive_embeddings, axis=-1)

  # Embeddings used (possibly under different keys/normalization) to mine
  # hard/semi-hard negatives.
  triplet_anchor_mining_embeddings, triplet_positive_mining_embeddings = (
      tf.unstack(
          pipeline_utils.stack_embeddings(
              outputs, configs['triplet_mining_embedding_keys']),
          axis=1))
  if FLAGS.use_normalized_embeddings_for_triplet_mining:
    triplet_anchor_mining_embeddings = tf.math.l2_normalize(
        triplet_anchor_mining_embeddings, axis=-1)
    triplet_positive_mining_embeddings = tf.math.l2_normalize(
        triplet_positive_mining_embeddings, axis=-1)

  # Anchors double as the pool of candidate matches, so negatives are mined
  # from the anchors of other batch elements.
  triplet_loss, triplet_loss_summaries = (
      loss_utils.compute_keypoint_triplet_losses(
          anchor_embeddings=triplet_anchor_embeddings,
          positive_embeddings=triplet_positive_embeddings,
          match_embeddings=triplet_anchor_embeddings,
          anchor_keypoints=anchor_keypoints_3d,
          match_keypoints=anchor_keypoints_3d,
          margin=FLAGS.triplet_loss_margin,
          min_negative_keypoint_distance=(
              configs['min_negative_keypoint_distance']),
          use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
          exclude_inactive_triplet_loss=(
              FLAGS.exclude_inactive_triplet_loss),
          embedding_sample_distance_fn=(
              configs['triplet_embedding_sample_distance_fn']),
          keypoint_distance_fn=configs['keypoint_distance_fn'],
          anchor_mining_embeddings=triplet_anchor_mining_embeddings,
          positive_mining_embeddings=triplet_positive_mining_embeddings,
          match_mining_embeddings=triplet_anchor_mining_embeddings,
          summarize_percentiles=FLAGS.summarize_percentiles))
  # Register the loss with the TF1 graph-level loss collection and expose it
  # (plus its summaries) to the caller via the shared `summaries` dict.
  tf.losses.add_loss(triplet_loss, loss_collection=tf.GraphKeys.LOSSES)
  summaries.update(triplet_loss_summaries)
  summaries['train/triplet_loss'] = triplet_loss