def test_select_keypoints_by_name(self):
    """Selects LEGACY_2DCOCO13-compatible keypoints from LEGACY_3DH36M17 ones.

    Every one of the 17 input keypoints is filled with its own index, so the
    expected output is just the list of input indices picked out by the 2D
    profile's compatible keypoint names for 'LEGACY_3DH36M17'.
    """
    # Row i of the input is [i, i, i] (as floats) for i in 0..16.
    input_keypoints = tf.constant(
        [[float(index)] * 3 for index in range(17)])
    profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
        'LEGACY_3DH36M17')
    profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
        'LEGACY_2DCOCO13')
    compatible_names = profile_2d.compatible_keypoint_name_dict[
        'LEGACY_3DH36M17']

    output_keypoints, _ = keypoint_utils.select_keypoints_by_name(
        input_keypoints,
        input_keypoint_names=profile_3d.keypoint_names,
        output_keypoint_names=compatible_names)

    # Input rows 0, 2, 3 and 10 are expected to be dropped.
    expected_indices = (1, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16)
    self.assertAllClose(
        output_keypoints, [[float(index)] * 3 for index in expected_indices])
示例#2
0
            def add_triplet_loss():
                """Adds the keypoint triplet loss and its summaries.

                Reads `inputs`, `outputs`, `configs` and `FLAGS` from the
                enclosing scope. Computes the loss with
                `loss_utils.compute_keypoint_triplet_losses`, registers it in
                the `tf.GraphKeys.LOSSES` collection via `tf.losses.add_loss`,
                and writes the returned loss summaries plus the
                `'train/triplet_loss'` entry into the enclosing `summaries`
                dict, which is mutated in place.
                """
                # Unstack the 3D keypoints along axis 1 (exactly two entries,
                # `num=2`) and keep the first one as the anchor keypoints.
                # NOTE(review): presumably axis 1 indexes an (anchor, positive)
                # pair — confirm against the input pipeline.
                anchor_keypoints_3d = tf.unstack(
                    inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)[0]
                if (configs['keypoint_profile_3d'].keypoint_names !=
                        configs['target_keypoint_profile_3d'].keypoint_names):
                    # The target profile's keypoint names differ from the input
                    # profile's, so select the target keypoints by name. The
                    # second return value of the selection is not used here.
                    anchor_keypoints_3d, _ = keypoint_utils.select_keypoints_by_name(
                        anchor_keypoints_3d,
                        input_keypoint_names=(
                            configs['keypoint_profile_3d'].keypoint_names),
                        output_keypoint_names=(
                            configs['target_keypoint_profile_3d'].
                            keypoint_names))

                # Stack the embeddings selected by `triplet_embedding_keys` and
                # split them along axis 1 into anchor and positive embeddings
                # (the two-name unpacking requires that axis to have size 2).
                triplet_anchor_embeddings, triplet_positive_embeddings = tf.unstack(
                    pipeline_utils.stack_embeddings(
                        outputs, configs['triplet_embedding_keys']),
                    axis=1)
                # Optionally L2-normalize the loss embeddings along the last axis.
                if FLAGS.use_normalized_embeddings_for_triplet_loss:
                    triplet_anchor_embeddings = tf.math.l2_normalize(
                        triplet_anchor_embeddings, axis=-1)
                    triplet_positive_embeddings = tf.math.l2_normalize(
                        triplet_positive_embeddings, axis=-1)

                # The embeddings used for triplet mining come from a separate
                # key list (`triplet_mining_embedding_keys`) and have their own
                # normalization flag, so they may differ from the loss embeddings.
                triplet_anchor_mining_embeddings, triplet_positive_mining_embeddings = (
                    tf.unstack(pipeline_utils.stack_embeddings(
                        outputs, configs['triplet_mining_embedding_keys']),
                               axis=1))
                if FLAGS.use_normalized_embeddings_for_triplet_mining:
                    triplet_anchor_mining_embeddings = tf.math.l2_normalize(
                        triplet_anchor_mining_embeddings, axis=-1)
                    triplet_positive_mining_embeddings = tf.math.l2_normalize(
                        triplet_positive_mining_embeddings, axis=-1)

                # Anchors are matched against the anchors themselves: every
                # `match_*` argument reuses the anchor embeddings / keypoints.
                # NOTE(review): presumably negatives are mined from the other
                # anchors in the batch — confirm in loss_utils.
                triplet_loss, triplet_loss_summaries = (
                    loss_utils.compute_keypoint_triplet_losses(
                        anchor_embeddings=triplet_anchor_embeddings,
                        positive_embeddings=triplet_positive_embeddings,
                        match_embeddings=triplet_anchor_embeddings,
                        anchor_keypoints=anchor_keypoints_3d,
                        match_keypoints=anchor_keypoints_3d,
                        margin=FLAGS.triplet_loss_margin,
                        min_negative_keypoint_distance=(
                            configs['min_negative_keypoint_distance']),
                        use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
                        exclude_inactive_triplet_loss=(
                            FLAGS.exclude_inactive_triplet_loss),
                        embedding_sample_distance_fn=(
                            configs['triplet_embedding_sample_distance_fn']),
                        keypoint_distance_fn=configs['keypoint_distance_fn'],
                        # The mining embeddings follow the same anchor-as-match
                        # convention as the loss embeddings above.
                        anchor_mining_embeddings=
                        triplet_anchor_mining_embeddings,
                        positive_mining_embeddings=
                        triplet_positive_mining_embeddings,
                        match_mining_embeddings=
                        triplet_anchor_mining_embeddings,
                        summarize_percentiles=FLAGS.summarize_percentiles))
                # Register the loss in the graph's LOSSES collection.
                tf.losses.add_loss(triplet_loss,
                                   loss_collection=tf.GraphKeys.LOSSES)
                # Record the loss and its summaries in the enclosing dict.
                summaries.update(triplet_loss_summaries)
                summaries['train/triplet_loss'] = triplet_loss