Ejemplo n.º 1
0
            def add_triplet_loss():
                """Builds the keypoint triplet loss and registers it.

                Relies on names defined outside this function (`inputs`,
                `outputs`, `configs`, `FLAGS`, `summaries`). The loss is added
                to the `tf.GraphKeys.LOSSES` collection, and both the loss and
                the summaries returned by the loss helper are stored in
                `summaries`.
                """

                def _masks_2d_to_3d(masks_2d):
                    # Transfers masks from the 2D keypoint profile to the 3D one.
                    return keypoint_utils.transfer_keypoint_masks(
                        masks_2d,
                        input_keypoint_profile=configs['keypoint_profile_2d'],
                        output_keypoint_profile=configs['keypoint_profile_3d'],
                        enforce_surjectivity=True)

                def _unstack_embedding_pair(embedding_keys):
                    # Stacks the requested embeddings and splits them along
                    # axis 1 into an (anchor, positive) pair.
                    stacked = pipeline_utils.stack_embeddings(
                        outputs, embedding_keys)
                    anchor, positive = tf.unstack(stacked, axis=1)
                    return anchor, positive

                def _l2_normalized(embeddings):
                    return tf.math.l2_normalize(embeddings, axis=-1)

                # 3D keypoints of the anchor/positive pair serve as the
                # anchor/match keypoints of the loss.
                anchor_kp_3d, positive_kp_3d = tf.unstack(
                    inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

                if FLAGS.use_inferred_keypoint_masks_for_triplet_label:
                    anchor_masks_2d, positive_masks_2d = tf.unstack(
                        inputs[
                            common_module.KEY_PREPROCESSED_KEYPOINT_MASKS_2D],
                        num=2,
                        axis=1)
                    anchor_masks_3d = _masks_2d_to_3d(anchor_masks_2d)
                    positive_masks_3d = _masks_2d_to_3d(positive_masks_2d)
                else:
                    # No keypoint masks are passed to the loss in this case.
                    anchor_masks_3d = None
                    positive_masks_3d = None

                # Embeddings the loss itself is computed on.
                anchor_emb, positive_emb = _unstack_embedding_pair(
                    configs['triplet_embedding_keys'])
                if FLAGS.use_normalized_embeddings_for_triplet_loss:
                    anchor_emb = _l2_normalized(anchor_emb)
                    positive_emb = _l2_normalized(positive_emb)

                # Embeddings passed as the mining embeddings.
                anchor_mining_emb, positive_mining_emb = (
                    _unstack_embedding_pair(
                        configs['triplet_mining_embedding_keys']))
                if FLAGS.use_normalized_embeddings_for_triplet_mining:
                    anchor_mining_emb = _l2_normalized(anchor_mining_emb)
                    positive_mining_emb = _l2_normalized(positive_mining_emb)

                # The positive embeddings are passed as both the positives
                # and the match set.
                loss, loss_summaries = (
                    loss_utils.compute_keypoint_triplet_losses(
                        anchor_embeddings=anchor_emb,
                        positive_embeddings=positive_emb,
                        match_embeddings=positive_emb,
                        anchor_keypoints=anchor_kp_3d,
                        match_keypoints=positive_kp_3d,
                        margin=FLAGS.triplet_loss_margin,
                        min_negative_keypoint_distance=(
                            configs['min_negative_keypoint_distance']),
                        use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
                        exclude_inactive_triplet_loss=(
                            FLAGS.exclude_inactive_triplet_loss),
                        anchor_keypoint_masks=anchor_masks_3d,
                        match_keypoint_masks=positive_masks_3d,
                        embedding_sample_distance_fn=(
                            configs['triplet_embedding_sample_distance_fn']),
                        keypoint_distance_fn=configs['keypoint_distance_fn'],
                        anchor_mining_embeddings=anchor_mining_emb,
                        positive_mining_embeddings=positive_mining_emb,
                        match_mining_embeddings=positive_mining_emb,
                        summarize_percentiles=FLAGS.summarize_percentiles))

                tf.losses.add_loss(loss, loss_collection=tf.GraphKeys.LOSSES)
                summaries.update(loss_summaries)
                summaries['train/triplet_loss'] = loss
def test_transfer_keypoint_masks_case_2(self):
  """Checks mask transfer from the '3DSTD16' to the '2DSTD13' profile.

  Two rows of 16 masks are transferred and compared against two expected
  rows of 13 masks.
  """
  first_source_row = [
      1.0,  # NOSE
      1.0,  # NECK
      1.0,  # LEFT_SHOULDER
      1.0,  # RIGHT_SHOULDER
      0.0,  # LEFT_ELBOW
      1.0,  # RIGHT_ELBOW
      1.0,  # LEFT_WRIST
      0.0,  # RIGHT_WRIST
      0.0,  # SPINE
      0.0,  # PELVIS
      1.0,  # LEFT_HIP
      0.0,  # RIGHT_HIP
      1.0,  # LEFT_KNEE
      1.0,  # RIGHT_KNEE
      0.0,  # LEFT_ANKLE
      0.0,  # RIGHT_ANKLE
  ]
  # Element-wise complement of the first row.
  second_source_row = [
      0.0,  # NOSE
      0.0,  # NECK
      0.0,  # LEFT_SHOULDER
      0.0,  # RIGHT_SHOULDER
      1.0,  # LEFT_ELBOW
      0.0,  # RIGHT_ELBOW
      0.0,  # LEFT_WRIST
      1.0,  # RIGHT_WRIST
      1.0,  # SPINE
      1.0,  # PELVIS
      0.0,  # LEFT_HIP
      1.0,  # RIGHT_HIP
      0.0,  # LEFT_KNEE
      0.0,  # RIGHT_KNEE
      1.0,  # LEFT_ANKLE
      1.0,  # RIGHT_ANKLE
  ]
  # Shape = [2, 16].
  source_masks = tf.constant([first_source_row, second_source_row])

  source_profile = keypoint_profiles.create_keypoint_profile_or_die(
      '3DSTD16')
  target_profile = keypoint_profiles.create_keypoint_profile_or_die(
      '2DSTD13')

  # Expected to have shape [2, 13], matching the rows below.
  transferred_masks = keypoint_utils.transfer_keypoint_masks(
      source_masks, source_profile, target_profile)

  first_expected_row = [
      1.0,  # NOSE_TIP
      1.0,  # LEFT_SHOULDER
      1.0,  # RIGHT_SHOULDER
      0.0,  # LEFT_ELBOW
      1.0,  # RIGHT_ELBOW
      1.0,  # LEFT_WRIST
      0.0,  # RIGHT_WRIST
      1.0,  # LEFT_HIP
      0.0,  # RIGHT_HIP
      1.0,  # LEFT_KNEE
      1.0,  # RIGHT_KNEE
      0.0,  # LEFT_ANKLE
      0.0,  # RIGHT_ANKLE
  ]
  second_expected_row = [
      0.0,  # NOSE_TIP
      0.0,  # LEFT_SHOULDER
      0.0,  # RIGHT_SHOULDER
      1.0,  # LEFT_ELBOW
      0.0,  # RIGHT_ELBOW
      0.0,  # LEFT_WRIST
      1.0,  # RIGHT_WRIST
      0.0,  # LEFT_HIP
      1.0,  # RIGHT_HIP
      0.0,  # LEFT_KNEE
      0.0,  # RIGHT_KNEE
      1.0,  # LEFT_ANKLE
      1.0,  # RIGHT_ANKLE
  ]
  self.assertAllClose(transferred_masks,
                      [first_expected_row, second_expected_row])