コード例 #1
0
 def test_npair_loss(self):
     """Check that mll.npair_loss is lower for embedding1 than embedding2.

     Rows whose label is negative are dropped first. The loss is then
     computed for each combination of similarity strategy ('dotproduct',
     'distance') and loss strategy ('softmax', 'sigmoid'). For every
     combination, the loss on embedding1 must be strictly less than the loss
     on embedding2.
     """
     # Flatten the labels once, so the one-hot targets and the validity mask
     # are guaranteed to be derived from the same label tensor.
     labels = tf.reshape(self.get_labels1(), [20])
     labels_onehot = tf.one_hot(labels, depth=3)
     # Rows with negative labels (which tf.one_hot encodes as all-zero
     # vectors) are removed from the embeddings and targets by this mask.
     valid_labels = tf.greater_equal(labels, 0)

     embedding1 = tf.boolean_mask(
         tf.reshape(self.get_embedding1(), [20, 3]), valid_labels)
     embedding2 = tf.boolean_mask(
         tf.reshape(self.get_embedding2(), [20, 3]), valid_labels)
     labels_onehot = tf.boolean_mask(labels_onehot, valid_labels)

     strategies = [
         ('dotproduct', 'softmax'),
         ('dotproduct', 'sigmoid'),
         ('distance', 'softmax'),
         ('distance', 'sigmoid'),
     ]
     # Compute every loss before asserting anything, keeping the same order
     # of npair_loss calls as the per-strategy version.
     losses = [
         (similarity, loss_type,
          mll.npair_loss(embedding1, labels_onehot, similarity, loss_type),
          mll.npair_loss(embedding2, labels_onehot, similarity, loss_type))
         for similarity, loss_type in strategies
     ]
     for similarity, loss_type, loss1, loss2 in losses:
         self.assertLess(
             loss1.numpy(), loss2.numpy(),
             msg='similarity=%s, loss=%s' % (similarity, loss_type))
コード例 #2
0
 def npair_loss_i():
   """Compute metric_learning_losses.npair_loss for the enclosing scope.

   Takes no arguments. It reads `sampled_embeddings_i`, `target_i`,
   `similarity_strategy` and `loss_strategy` from the enclosing scope and
   forwards them by keyword. The loss value is returned unchanged.
   """
   loss_kwargs = {
       'embedding': sampled_embeddings_i,
       'target': target_i,
       'similarity_strategy': similarity_strategy,
       'loss_strategy': loss_strategy,
   }
   return metric_learning_losses.npair_loss(**loss_kwargs)