示例#1
0
    def test_nca_loss(self):
        """Compare NCALoss against a hand-computed reference on 2D points.

        Five 2D embeddings are built from the angles 0, 20, 40, 60 and 80 via
        ``c_f.angle_to_coord`` and labelled [0, 0, 1, 1, 2]. For each anchor
        ``i`` that has a positive ``p``, the reference term is

            -log(exp(-s * d_ip) / sum_{j != i} exp(-s * d_ij))

        where ``d`` is the squared Euclidean distance and ``s`` is
        ``softmax_scale``. The expected loss is the mean of those terms. The
        check is repeated for every dtype in ``TEST_DTYPES`` on
        ``self.device``, matching ``test_zero_loss``.
        """
        softmax_scale = 10
        loss_func = NCALoss(softmax_scale=softmax_scale)

        embedding_angles = [0, 20, 40, 60, 80]
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        # (anchor, other) index pairs implied by `labels`.
        # Each anchor that has a positive has exactly one, so the numerator
        # plus that anchor's negatives sums over every other point.
        # Anchor 4 (label 2) has no positive and contributes no term.
        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0),
                     (2, 1), (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1),
                     (4, 2), (4, 3)]

        for dtype in TEST_DTYPES:
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings

            loss = loss_func(embeddings, labels)
            loss.backward()

            correct_total = 0
            for a1, p in pos_pairs:
                anchor1, positive = embeddings[a1], embeddings[p]
                ap_dist = torch.sum((anchor1 - positive) ** 2)
                numerator = torch.exp(-ap_dist * softmax_scale)
                # Clone so the in-place `+=` below does not also modify
                # `numerator`.
                denominator = numerator.clone()
                for a2, n in neg_pairs:
                    if a2 == a1:
                        anchor2, negative = embeddings[a2], embeddings[n]
                        an_dist = torch.sum((anchor2 - negative) ** 2)
                        denominator += torch.exp(-an_dist * softmax_scale)

                correct_total += -torch.log(numerator / denominator)

            correct_total /= len(pos_pairs)
            # 16-bit float dtypes are far coarser than isclose's default
            # rtol of 1e-5, so loosen the tolerance for them.
            low_precision = dtype in (torch.float16, torch.bfloat16)
            rtol = 1e-2 if low_precision else 1e-5
            self.assertTrue(torch.isclose(loss, correct_total, rtol=rtol))
示例#2
0
 def test_zero_loss(self):
     """Check that NCALoss evaluates to exactly 0 when no negatives exist.

     There are two cases:
     - two embeddings that share one label, so the only other point is a
       positive;
     - a lone embedding.

     For every dtype in ``TEST_DTYPES`` the loss must equal 0, and calling
     ``backward()`` on it must succeed.
     """
     loss_func = NCALoss(10)
     zero_loss_cases = [([0, 20], [0, 0]), ([0], [0])]
     for dtype in TEST_DTYPES:
         for angles, raw_labels in zero_loss_cases:
             coords = [c_f.angle_to_coord(angle) for angle in angles]
             embeddings = torch.tensor(
                 coords, requires_grad=True, dtype=dtype
             ).to(self.device)  # 2D embeddings
             label_tensor = torch.LongTensor(raw_labels)
             loss = loss_func(embeddings, label_tensor)
             loss.backward()
             self.assertEqual(loss, 0)