Ejemplo n.º 1
0
    def testManyHotLabels(self):
        """Tests the loss on examples that each have several positive labels.

        The expected value for each example is the mean, over its positive
        classes, of the softmax cross entropy of that one positive against all
        of the negative classes.
        """
        batch_size, num_classes, num_positive = 7, 400, 5

        # Every example marks the first `num_positive` classes as positive,
        # which keeps the bookkeeping below simple.
        labels = np.zeros((batch_size, num_classes), np.float32)
        labels[:, :num_positive] = 1.0

        # Random logits with a constant 1e7 offset (presumably to exercise
        # numerical stability at large magnitudes).
        noise = np.random.uniform(size=labels.shape).astype(np.float32)
        logits = noise * 10 + 1e7
        actual_loss = label_lib.MultiLabelContrastiveLoss(
            tf.convert_to_tensor(labels, dtype=tf.float32),
            tf.convert_to_tensor(logits, dtype=tf.float32))

        # One-vs-negatives problem: column 0 holds the single positive and the
        # remaining columns hold every negative class.
        negatives = logits[:, num_positive:]
        num_negatives = num_classes - num_positive
        target = np.zeros((batch_size, num_negatives + 1), np.float32)
        target[:, 0] = 1

        per_positive_losses = [
            tf.nn.softmax_cross_entropy_with_logits(
                labels=target,
                logits=np.concatenate(
                    [logits[:, k:k + 1], negatives], axis=1))
            for k in range(num_positive)
        ]
        expected_loss = tf.add_n(per_positive_losses) / num_positive
        self.assertAllClose(expected_loss, actual_loss)
Ejemplo n.º 2
0
 def testNoPositiveLabels(self):
     """Tests that slices without any positive label yield a zero loss."""
     batch_size = 7
     num_classes = 400
     shape = (batch_size, num_classes)

     # Both the labels and the logits are all zeros.
     all_zero_labels = tf.zeros(shape)
     all_zero_logits = tf.zeros(shape)
     losses = label_lib.MultiLabelContrastiveLoss(labels=all_zero_labels,
                                                  logits=all_zero_logits)

     # One loss value per batch element, each expected to be zero.
     self.assertAllClose(losses, tf.zeros(batch_size))
Ejemplo n.º 3
0
  def testOneHotLabels(self):
    """Tests that one-hot labels reduce the loss to plain softmax CE."""
    batch_size, num_classes = 7, 400
    leading_shape = (batch_size, 3)

    # Exactly one positive class per [batch, 3] slot.
    positive_class = np.random.randint(0, num_classes, size=leading_shape)
    labels = tf.one_hot(positive_class, depth=num_classes, dtype=tf.float32)

    # Random logits with a constant 1e7 offset, converted to float32.
    raw_logits = (
        np.random.uniform(size=leading_shape + (num_classes,)) * 10 + 1e7)
    logits_tensor = tf.convert_to_tensor(raw_logits, dtype=tf.float32)

    actual = label_lib.MultiLabelContrastiveLoss(labels, logits_tensor)
    reference = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits_tensor)
    self.assertAllClose(reference, actual)
Ejemplo n.º 4
0
        def _ComputePerQueryLoss(query_modality, result_modality):
            """Returns the contrastive loss of each query for a modality pair.

            Args:
              query_modality: Key selecting the local (query) encodings.
              result_modality: Key selecting the global (result) encodings.

            Returns:
              The output of `MultiLabelContrastiveLoss` for the flattened
              labels and similarities; [num_queries].
            """
            # Label every (query, result) pair between the local batch and
            # the global batch with the configured labeler.
            pair_labels = p.label_fn(
                label_lib.ExamplePairs.BetweenLocalAndGlobalBatches(
                    input_batch,
                    query_modality=query_modality,
                    result_modality=result_modality))

            query_encodings = local_flat_encodings[query_modality]
            result_encodings = global_flat_encodings[result_modality]
            # [num_queries, num_results]
            similarities = self.score_function(query_encodings,
                                               result_encodings)

            # Labels are reshaped to line up with the similarity matrix.
            # [num_queries]
            return label_lib.MultiLabelContrastiveLoss(
                labels=tf.reshape(pair_labels, similarities.shape),
                logits=similarities)
Ejemplo n.º 5
0
    def testIgnoreLabels(self):
        """Tests that IGNORE_PAIR_LABEL pairs do not contribute to the loss.

        Each row's expected loss is the sparse softmax cross entropy computed
        over only the logits whose label is not the ignore marker.
        """
        ignore = label_lib.IGNORE_PAIR_LABEL
        labels = tf.constant(
            [
                [0, 1, 0, ignore],
                [1, 0, 0, 0],
                [ignore, 0, ignore, 1],
            ],
            dtype=tf.float32)
        # All three rows share the same logits.
        logits = tf.constant([[1, 2, 3, 4]] * 3, dtype=tf.float32)

        actual = label_lib.MultiLabelContrastiveLoss(labels, logits)

        # Row 0: last column ignored; the positive is index 1 of the rest.
        row0 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=1, logits=[1.0, 2.0, 3.0])
        # Row 1: nothing ignored; the positive is index 0.
        row1 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=0, logits=[1.0, 2.0, 3.0, 4.0])
        # Row 2: columns 0 and 2 ignored, leaving logits [2, 4] with the
        # positive at index 1.
        row2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=1, logits=[2.0, 4.0])

        self.assertAllClose(tf.stack([row0, row1, row2]), actual)