def add_triplet_loss():
  """Adds the keypoint-based triplet loss to the training graph.

  Uses `inputs`, `outputs`, `configs`, `summaries` and `FLAGS` from the
  enclosing scope. Anchor/positive pairs are split out of the second axis of
  the 3D keypoints and of the stacked embeddings. The loss is computed by
  `loss_utils.compute_keypoint_triplet_losses` with the positive side also
  used as the match side. It is added to the `tf.GraphKeys.LOSSES`
  collection, and its summaries are merged into `summaries` along with a
  `'train/triplet_loss'` entry.
  """

  def masks_2d_to_3d(keypoint_masks_2d):
    # Re-express 2D keypoint masks in the 3D keypoint profile.
    return keypoint_utils.transfer_keypoint_masks(
        keypoint_masks_2d,
        input_keypoint_profile=configs['keypoint_profile_2d'],
        output_keypoint_profile=configs['keypoint_profile_3d'],
        enforce_surjectivity=True)

  def anchor_positive_embeddings(config_key, normalize):
    # Stack the embeddings named by `configs[config_key]`, split them into
    # anchor/positive along axis 1, and optionally L2-normalize each one.
    stacked = pipeline_utils.stack_embeddings(outputs, configs[config_key])
    anchor, positive = tf.unstack(stacked, axis=1)
    if normalize:
      anchor = tf.math.l2_normalize(anchor, axis=-1)
      positive = tf.math.l2_normalize(positive, axis=-1)
    return anchor, positive

  anchor_keypoints_3d, positive_keypoints_3d = tf.unstack(
      inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

  if FLAGS.use_inferred_keypoint_masks_for_triplet_label:
    anchor_masks_2d, positive_masks_2d = tf.unstack(
        inputs[common_module.KEY_PREPROCESSED_KEYPOINT_MASKS_2D],
        num=2,
        axis=1)
    anchor_masks_3d = masks_2d_to_3d(anchor_masks_2d)
    positive_masks_3d = masks_2d_to_3d(positive_masks_2d)
  else:
    # Without inferred masks, no keypoint masks are passed to the loss.
    anchor_masks_3d = None
    positive_masks_3d = None

  # Embeddings the loss is computed on.
  anchor_embeddings, positive_embeddings = anchor_positive_embeddings(
      'triplet_embedding_keys',
      FLAGS.use_normalized_embeddings_for_triplet_loss)
  # Embeddings passed to the loss as the `*_mining_embeddings` arguments.
  anchor_mining, positive_mining = anchor_positive_embeddings(
      'triplet_mining_embedding_keys',
      FLAGS.use_normalized_embeddings_for_triplet_mining)

  loss_and_summaries = loss_utils.compute_keypoint_triplet_losses(
      anchor_embeddings=anchor_embeddings,
      positive_embeddings=positive_embeddings,
      match_embeddings=positive_embeddings,
      anchor_keypoints=anchor_keypoints_3d,
      match_keypoints=positive_keypoints_3d,
      margin=FLAGS.triplet_loss_margin,
      min_negative_keypoint_distance=(
          configs['min_negative_keypoint_distance']),
      use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
      exclude_inactive_triplet_loss=FLAGS.exclude_inactive_triplet_loss,
      anchor_keypoint_masks=anchor_masks_3d,
      match_keypoint_masks=positive_masks_3d,
      embedding_sample_distance_fn=(
          configs['triplet_embedding_sample_distance_fn']),
      keypoint_distance_fn=configs['keypoint_distance_fn'],
      anchor_mining_embeddings=anchor_mining,
      positive_mining_embeddings=positive_mining,
      match_mining_embeddings=positive_mining,
      summarize_percentiles=FLAGS.summarize_percentiles)
  triplet_loss, triplet_loss_summaries = loss_and_summaries

  tf.losses.add_loss(triplet_loss, loss_collection=tf.GraphKeys.LOSSES)
  summaries.update(triplet_loss_summaries)
  summaries['train/triplet_loss'] = triplet_loss
def test_transfer_keypoint_masks_case_2(self):
  """Checks mask transfer from the 3DSTD16 profile to the 2DSTD13 profile."""
  # Two samples of 3DSTD16 keypoint masks, one value per keypoint.
  # Shape = [2, 16].
  masks_3d = tf.constant([
      [
          1.0,  # NOSE
          1.0,  # NECK
          1.0,  # LEFT_SHOULDER
          1.0,  # RIGHT_SHOULDER
          0.0,  # LEFT_ELBOW
          1.0,  # RIGHT_ELBOW
          1.0,  # LEFT_WRIST
          0.0,  # RIGHT_WRIST
          0.0,  # SPINE
          0.0,  # PELVIS
          1.0,  # LEFT_HIP
          0.0,  # RIGHT_HIP
          1.0,  # LEFT_KNEE
          1.0,  # RIGHT_KNEE
          0.0,  # LEFT_ANKLE
          0.0,  # RIGHT_ANKLE
      ],
      [
          0.0,  # NOSE
          0.0,  # NECK
          0.0,  # LEFT_SHOULDER
          0.0,  # RIGHT_SHOULDER
          1.0,  # LEFT_ELBOW
          0.0,  # RIGHT_ELBOW
          0.0,  # LEFT_WRIST
          1.0,  # RIGHT_WRIST
          1.0,  # SPINE
          1.0,  # PELVIS
          0.0,  # LEFT_HIP
          1.0,  # RIGHT_HIP
          0.0,  # LEFT_KNEE
          0.0,  # RIGHT_KNEE
          1.0,  # LEFT_ANKLE
          1.0,  # RIGHT_ANKLE
      ]
  ])
  profile_3d = keypoint_profiles.create_keypoint_profile_or_die('3DSTD16')
  profile_2d = keypoint_profiles.create_keypoint_profile_or_die('2DSTD13')

  # The 2DSTD13 list below has no NECK, SPINE or PELVIS entries; the NOSE
  # value is expected under NOSE_TIP. Shape = [2, 13].
  expected_masks_2d = [
      [
          1.0,  # NOSE_TIP
          1.0,  # LEFT_SHOULDER
          1.0,  # RIGHT_SHOULDER
          0.0,  # LEFT_ELBOW
          1.0,  # RIGHT_ELBOW
          1.0,  # LEFT_WRIST
          0.0,  # RIGHT_WRIST
          1.0,  # LEFT_HIP
          0.0,  # RIGHT_HIP
          1.0,  # LEFT_KNEE
          1.0,  # RIGHT_KNEE
          0.0,  # LEFT_ANKLE
          0.0,  # RIGHT_ANKLE
      ],
      [
          0.0,  # NOSE_TIP
          0.0,  # LEFT_SHOULDER
          0.0,  # RIGHT_SHOULDER
          1.0,  # LEFT_ELBOW
          0.0,  # RIGHT_ELBOW
          0.0,  # LEFT_WRIST
          1.0,  # RIGHT_WRIST
          0.0,  # LEFT_HIP
          1.0,  # RIGHT_HIP
          0.0,  # LEFT_KNEE
          0.0,  # RIGHT_KNEE
          1.0,  # LEFT_ANKLE
          1.0,  # RIGHT_ANKLE
      ]
  ]

  actual_masks_2d = keypoint_utils.transfer_keypoint_masks(
      masks_3d, profile_3d, profile_2d)
  self.assertAllClose(actual_masks_2d, expected_masks_2d)