Esempio n. 1
0
 def add_positive_pairwise_loss():
     """Adds positive pairwise loss.

     Stacks the embeddings named by
     `configs['positive_pairwise_embedding_keys']` out of `outputs`, splits
     the stack along axis 1 into anchor and positive embeddings, and computes
     the loss with `loss_utils.compute_positive_pairwise_loss`.

     If `FLAGS.use_normalized_embeddings_for_positive_pairwise_loss` is set,
     both embedding sets are L2-normalized along the last axis first. The
     loss is added to the `tf.GraphKeys.LOSSES` collection. Its summaries,
     plus the loss itself under 'train/positive_pairwise_loss', are written
     into the enclosing `summaries` dict.

     `outputs`, `configs`, `common_module`, `summaries` and `FLAGS` come from
     the enclosing scope (not visible here). Takes no arguments and returns
     None.
     """
     stacked_embeddings = pipeline_utils.stack_embeddings(
         outputs,
         configs['positive_pairwise_embedding_keys'],
         common_module=common_module)
     # Axis 1 must hold exactly two entries (anchor, positive); the two-name
     # unpack fails otherwise.
     anchors, positives = tf.unstack(stacked_embeddings, axis=1)

     if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss:
         # Anchors are normalized before positives, same as the op order of
         # two separate statements.
         anchors, positives = (
             tf.math.l2_normalize(anchors, axis=-1),
             tf.math.l2_normalize(positives, axis=-1),
         )

     loss, loss_summaries = loss_utils.compute_positive_pairwise_loss(
         anchors,
         positives,
         loss_weight=FLAGS.positive_pairwise_loss_weight,
         distance_fn=configs[
             'positive_pairwise_embedding_sample_distance_fn'])
     tf.losses.add_loss(loss, loss_collection=tf.GraphKeys.LOSSES)

     # Per-component summaries first, then the headline training scalar.
     summaries.update(loss_summaries)
     summaries['train/positive_pairwise_loss'] = loss
  def test_compute_positive_pairwise_loss(self):
    """Checks the weighted loss and its summaries for a fixed 2x3x2 batch."""
    # NOTE(review): the expected values match the mean squared-L2 distance
    # between paired rows: per-pair distances are 202, 74, 10, 10, 74, 202,
    # so the mean is 572 / 6 ~= 95.333 and the 6x-weighted loss is 572.
    # That this is the default `distance_fn` is an assumption; confirm in
    # loss_utils.
    anchors = tf.constant([
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    ])
    positives = tf.constant([
        [[12.0, 11.0], [10.0, 9.0], [8.0, 7.0]],
        [[6.0, 5.0], [4.0, 3.0], [2.0, 1.0]],
    ])

    loss, loss_summaries = loss_utils.compute_positive_pairwise_loss(
        anchors, positives, loss_weight=6.0)

    self.assertAlmostEqual(loss, 572.0)
    self._assert_dict_equal_or_almost_equal(
        loss_summaries, {
            'pairwise_loss/PositivePair/Loss/Original': 95.333333333,
            'pairwise_loss/PositivePair/Loss/Weighted': 572.0,
            'pairwise_loss/PositivePair/Loss/Weight': 6.0,
        })