Beispiel #1
0
    def test_intra_pair_variance_loss(self):
        pos_eps, neg_eps = 0.01, 0.02
        loss_func = IntraPairVarianceLoss(pos_eps, neg_eps)

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            loss = loss_func(embeddings, labels)
            loss.backward()

            pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
            neg_pairs = [
                (0, 2),
                (0, 3),
                (0, 4),
                (1, 2),
                (1, 3),
                (1, 4),
                (2, 0),
                (2, 1),
                (2, 4),
                (3, 0),
                (3, 1),
                (3, 4),
                (4, 0),
                (4, 1),
                (4, 2),
                (4, 3),
            ]

            pos_total, neg_total = 0, 0
            mean_pos = 0
            mean_neg = 0
            for a, p in pos_pairs:
                mean_pos += torch.matmul(embeddings[a], embeddings[p])
            for a, n in neg_pairs:
                mean_neg += torch.matmul(embeddings[a], embeddings[n])
            mean_pos /= len(pos_pairs)
            mean_neg /= len(neg_pairs)

            for a, p in pos_pairs:
                pos_total += (torch.relu(
                    ((1 - pos_eps) * mean_pos -
                     torch.matmul(embeddings[a], embeddings[p])))**2)
            for a, n in neg_pairs:
                neg_total += (torch.relu(
                    (torch.matmul(embeddings[a], embeddings[n]) -
                     (1 + neg_eps) * mean_neg))**2)

            pos_total /= len(pos_pairs)
            neg_total /= len(neg_pairs)
            correct_total = pos_total + neg_total
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(loss, correct_total, rtol=rtol))
Beispiel #2
0
 def test_with_no_valid_pairs(self):
     loss_func = IntraPairVarianceLoss(0.01,0.01)
     for dtype in TEST_DTYPES:
         embedding_angles = [0]
         embeddings = torch.tensor([c_f.angle_to_coord(a) for a in embedding_angles], requires_grad=True, dtype=dtype).to(self.device) #2D embeddings
         labels = torch.LongTensor([0])
         loss = loss_func(embeddings, labels)
         loss.backward()
         self.assertEqual(loss, 0)