def testManyHotLabels(self):
    """The loss should average one-vs-all softmax CE over positive pairs."""
    batch_size, num_classes, num_positive = 7, 400, 5
    # Keep the test simple: every example's positives are the first
    # `num_positive` classes.
    labels = np.zeros((batch_size, num_classes), np.float32)
    labels[:, :num_positive] = 1.0
    logits = np.random.uniform(size=labels.shape).astype(np.float32) * 10 + 1e7
    losses = label_lib.MultiLabelContrastiveLoss(
        tf.convert_to_tensor(labels, dtype=tf.float32),
        tf.convert_to_tensor(logits, dtype=tf.float32))
    # Reference value: for each positive class i, take the softmax cross
    # entropy of (positive i) vs. all negative pairs, then average over i.
    negatives = logits[:, num_positive:]
    one_vs_all_labels = np.zeros(
        (batch_size, num_classes - num_positive + 1), np.float32)
    one_vs_all_labels[:, 0] = 1
    per_positive_losses = [
        tf.nn.softmax_cross_entropy_with_logits(
            labels=one_vs_all_labels,
            logits=np.concatenate([logits[:, i:(i + 1)], negatives], axis=1))
        for i in range(num_positive)
    ]
    expected_loss = tf.add_n(per_positive_losses) / num_positive
    self.assertAllClose(expected_loss, losses)
def testNoPositiveLabels(self):
    """Tests that the loss is zero for slices with no positive label."""
    batch_size, num_classes = 7, 400
    # With no positive pair anywhere, every per-example loss must vanish.
    all_zeros = tf.zeros((batch_size, num_classes))
    losses = label_lib.MultiLabelContrastiveLoss(
        labels=all_zeros, logits=all_zeros)
    self.assertAllClose(losses, tf.zeros(batch_size))
def testOneHotLabels(self):
    """Tests that the loss equals softmax CE when the labels are one hot."""
    num_classes, batch_size = 400, 7
    # Three one-hot label slices per example, each with a random target class.
    class_ids = np.random.randint(0, num_classes, size=(batch_size, 3))
    labels = tf.one_hot(class_ids, depth=num_classes, dtype=tf.float32)
    raw_logits = (
        np.random.uniform(size=(batch_size, 3, num_classes)) * 10 + 1e7)
    logits = tf.convert_to_tensor(raw_logits, dtype=tf.float32)
    losses = label_lib.MultiLabelContrastiveLoss(labels, logits)
    # One-hot labels reduce the multi-label loss to plain softmax CE.
    expected = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    self.assertAllClose(expected, losses)
def _ComputePerQueryLoss(query_modality, result_modality):
    """Returns the contrastive loss of each local query vs. global results."""
    # Pair every query in the local batch with every result in the global
    # batch, then label each pair via the configured label_fn.
    pair_inputs = label_lib.ExamplePairs.BetweenLocalAndGlobalBatches(
        input_batch,
        query_modality=query_modality,
        result_modality=result_modality)
    pair_labels = p.label_fn(pair_inputs)
    # [num_queries, num_results]
    similarities = self.score_function(
        local_flat_encodings[query_modality],
        global_flat_encodings[result_modality])
    # [num_queries]
    return label_lib.MultiLabelContrastiveLoss(
        labels=tf.reshape(pair_labels, similarities.shape),
        logits=similarities)
def testIgnoreLabels(self):
    """Tests that pairs marked IGNORE_PAIR_LABEL are excluded from the loss."""
    x = label_lib.IGNORE_PAIR_LABEL
    labels = tf.constant(
        [[0, 1, 0, x], [1, 0, 0, 0], [x, 0, x, 1]], dtype=tf.float32)
    logits = tf.constant([[1, 2, 3, 4]] * 3, dtype=tf.float32)
    losses = label_lib.MultiLabelContrastiveLoss(labels, logits)
    # Each row's expected loss is sparse softmax CE over only the columns
    # that are NOT ignored, with the positive column as the target index.
    expected_losses = tf.stack([
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=1, logits=[1.0, 2.0, 3.0]),
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=0, logits=[1.0, 2.0, 3.0, 4.0]),
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=1, logits=[2.0, 4.0]),
    ])
    self.assertAllClose(expected_losses, losses)