Example #1
0
 def add_positive_pairwise_loss():
     """Registers the positive pairwise loss and records its summaries.

     The embeddings named by `configs['positive_pairwise_embedding_keys']` are
     stacked from `outputs` and split along axis 1 into an anchor part and a
     positive part. Both parts are optionally L2-normalized (controlled by
     `FLAGS.use_normalized_embeddings_for_positive_pairwise_loss`) before being
     handed to `loss_utils.compute_positive_pairwise_loss`.

     Side effects: the loss is added to the `tf.GraphKeys.LOSSES` collection,
     its summaries are merged into the enclosing `summaries` dict, and the loss
     itself is stored under 'train/positive_pairwise_loss'.
     """
     stacked_embeddings = pipeline_utils.stack_embeddings(
         outputs,
         configs['positive_pairwise_embedding_keys'],
         common_module=common_module)
     anchors, positives = tf.unstack(stacked_embeddings, axis=1)

     if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss:
         # Tuple assignment evaluates left to right, so the anchor op is still
         # created before the positive op.
         anchors, positives = (
             tf.math.l2_normalize(anchors, axis=-1),
             tf.math.l2_normalize(positives, axis=-1))

     distance_fn = configs['positive_pairwise_embedding_sample_distance_fn']
     pair_loss, pair_loss_summaries = loss_utils.compute_positive_pairwise_loss(
         anchors,
         positives,
         loss_weight=FLAGS.positive_pairwise_loss_weight,
         distance_fn=distance_fn)

     tf.losses.add_loss(pair_loss, loss_collection=tf.GraphKeys.LOSSES)
     summaries.update(pair_loss_summaries)
     summaries['train/positive_pairwise_loss'] = pair_loss
Example #2
0
            def add_triplet_loss():
                """Registers the paired keypoint triplet loss and its summaries.

                Inputs and embeddings carry two views per sample, stacked on
                axis 1 (anchor first, positive second). The anchor views are
                matched against the positive views, both for the loss embeddings
                (`configs['triplet_embedding_keys']`) and for the mining
                embeddings (`configs['triplet_mining_embedding_keys']`). Each
                embedding pair can be L2-normalized independently via flags.

                When `FLAGS.use_inferred_keypoint_masks_for_triplet_label` is
                set, 3D keypoint masks are derived from the preprocessed 2D masks
                and forwarded to the loss. Otherwise the masks are None.

                Side effects: the loss is added to `tf.GraphKeys.LOSSES`, its
                summaries are merged into the enclosing `summaries` dict, and the
                loss is stored under 'train/triplet_loss'.
                """

                def _masks_2d_to_3d(masks_2d):
                    # Converts 2D keypoint masks to the 3D keypoint profile.
                    return keypoint_utils.transfer_keypoint_masks(
                        masks_2d,
                        input_keypoint_profile=configs['keypoint_profile_2d'],
                        output_keypoint_profile=configs['keypoint_profile_3d'],
                        enforce_surjectivity=True)

                def _split_embedding_pair(embedding_keys, normalize):
                    # Stacks the named embeddings, splits them into the two
                    # views along axis 1, and optionally L2-normalizes each one.
                    first_view, second_view = tf.unstack(
                        pipeline_utils.stack_embeddings(outputs, embedding_keys),
                        axis=1)
                    if normalize:
                        first_view = tf.math.l2_normalize(first_view, axis=-1)
                        second_view = tf.math.l2_normalize(second_view, axis=-1)
                    return first_view, second_view

                anchor_keypoints, positive_keypoints = tf.unstack(
                    inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

                anchor_masks = positive_masks = None
                if FLAGS.use_inferred_keypoint_masks_for_triplet_label:
                    anchor_masks_2d, positive_masks_2d = tf.unstack(
                        inputs[common_module.KEY_PREPROCESSED_KEYPOINT_MASKS_2D],
                        num=2,
                        axis=1)
                    anchor_masks = _masks_2d_to_3d(anchor_masks_2d)
                    positive_masks = _masks_2d_to_3d(positive_masks_2d)

                anchor_embeddings, positive_embeddings = _split_embedding_pair(
                    configs['triplet_embedding_keys'],
                    FLAGS.use_normalized_embeddings_for_triplet_loss)
                anchor_mining, positive_mining = _split_embedding_pair(
                    configs['triplet_mining_embedding_keys'],
                    FLAGS.use_normalized_embeddings_for_triplet_mining)

                # The "match" arguments reuse the positive views, so anchors are
                # compared against the batch of positives.
                loss, loss_summaries = loss_utils.compute_keypoint_triplet_losses(
                    anchor_embeddings=anchor_embeddings,
                    positive_embeddings=positive_embeddings,
                    match_embeddings=positive_embeddings,
                    anchor_keypoints=anchor_keypoints,
                    match_keypoints=positive_keypoints,
                    margin=FLAGS.triplet_loss_margin,
                    min_negative_keypoint_distance=(
                        configs['min_negative_keypoint_distance']),
                    use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
                    exclude_inactive_triplet_loss=(
                        FLAGS.exclude_inactive_triplet_loss),
                    anchor_keypoint_masks=anchor_masks,
                    match_keypoint_masks=positive_masks,
                    embedding_sample_distance_fn=(
                        configs['triplet_embedding_sample_distance_fn']),
                    keypoint_distance_fn=configs['keypoint_distance_fn'],
                    anchor_mining_embeddings=anchor_mining,
                    positive_mining_embeddings=positive_mining,
                    match_mining_embeddings=positive_mining,
                    summarize_percentiles=FLAGS.summarize_percentiles)

                tf.losses.add_loss(loss, loss_collection=tf.GraphKeys.LOSSES)
                summaries.update(loss_summaries)
                summaries['train/triplet_loss'] = loss
Example #3
0
            def add_triplet_loss():
                """Registers the anchor-only keypoint triplet loss and summaries.

                Only the first (anchor) view of the stacked 3D keypoints is used.
                If the input 3D keypoint profile names differ from the target
                profile names, the anchor keypoints are first reduced to the
                target keypoints. The anchor embeddings are matched against the
                anchor embeddings and keypoints themselves, and the second
                embedding view serves as the positive. The loss embeddings and
                the mining embeddings can each be L2-normalized via flags.

                Side effects: the loss is added to `tf.GraphKeys.LOSSES`, its
                summaries are merged into the enclosing `summaries` dict, and the
                loss is stored under 'train/triplet_loss'.
                """

                def _split_embedding_pair(embedding_keys, normalize):
                    # Stacks the named embeddings, splits them into the two
                    # views along axis 1, and optionally L2-normalizes each one.
                    first_view, second_view = tf.unstack(
                        pipeline_utils.stack_embeddings(outputs, embedding_keys),
                        axis=1)
                    if normalize:
                        first_view = tf.math.l2_normalize(first_view, axis=-1)
                        second_view = tf.math.l2_normalize(second_view, axis=-1)
                    return first_view, second_view

                anchor_keypoints, _ = tf.unstack(
                    inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

                source_names = configs['keypoint_profile_3d'].keypoint_names
                target_names = configs['target_keypoint_profile_3d'].keypoint_names
                if source_names != target_names:
                    # Keep only the target keypoints when the profiles differ.
                    anchor_keypoints, _ = keypoint_utils.select_keypoints_by_name(
                        anchor_keypoints,
                        input_keypoint_names=source_names,
                        output_keypoint_names=target_names)

                anchor_embeddings, positive_embeddings = _split_embedding_pair(
                    configs['triplet_embedding_keys'],
                    FLAGS.use_normalized_embeddings_for_triplet_loss)
                anchor_mining, positive_mining = _split_embedding_pair(
                    configs['triplet_mining_embedding_keys'],
                    FLAGS.use_normalized_embeddings_for_triplet_mining)

                # The "match" arguments reuse the anchor views, so anchors are
                # compared against the batch of anchors.
                loss, loss_summaries = loss_utils.compute_keypoint_triplet_losses(
                    anchor_embeddings=anchor_embeddings,
                    positive_embeddings=positive_embeddings,
                    match_embeddings=anchor_embeddings,
                    anchor_keypoints=anchor_keypoints,
                    match_keypoints=anchor_keypoints,
                    margin=FLAGS.triplet_loss_margin,
                    min_negative_keypoint_distance=(
                        configs['min_negative_keypoint_distance']),
                    use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
                    exclude_inactive_triplet_loss=(
                        FLAGS.exclude_inactive_triplet_loss),
                    embedding_sample_distance_fn=(
                        configs['triplet_embedding_sample_distance_fn']),
                    keypoint_distance_fn=configs['keypoint_distance_fn'],
                    anchor_mining_embeddings=anchor_mining,
                    positive_mining_embeddings=positive_mining,
                    match_mining_embeddings=anchor_mining,
                    summarize_percentiles=FLAGS.summarize_percentiles)

                tf.losses.add_loss(loss, loss_collection=tf.GraphKeys.LOSSES)
                summaries.update(loss_summaries)
                summaries['train/triplet_loss'] = loss