def test_tuplet_margin_loss(self):
        margin, scale = 5, 64
        loss_func = TupletMarginLoss(margin=margin, scale=scale)

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(
                self.device
            )  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            loss = loss_func(embeddings, labels)
            loss.backward()

            pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
            neg_pairs = [
                (0, 2),
                (0, 3),
                (0, 4),
                (1, 2),
                (1, 3),
                (1, 4),
                (2, 0),
                (2, 1),
                (2, 4),
                (3, 0),
                (3, 1),
                (3, 4),
                (4, 0),
                (4, 1),
                (4, 2),
                (4, 3),
            ]

            correct_total = 0

            for a1, p in pos_pairs:
                curr_loss = 0
                anchor1, positive = embeddings[a1], embeddings[p]
                ap_angle = torch.acos(
                    torch.matmul(anchor1, positive)
                )  # embeddings are normalized, so dot product == cosine
                ap_cos = torch.cos(ap_angle - np.radians(margin))
                for a2, n in neg_pairs:
                    if a2 == a1:
                        anchor2, negative = embeddings[a2], embeddings[n]
                        an_cos = torch.matmul(anchor2, negative)
                        curr_loss += torch.exp(scale * (an_cos - ap_cos))

                curr_total = torch.log(1 + curr_loss)
                correct_total += curr_total

            correct_total /= len(pos_pairs)
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(loss, correct_total, rtol=rtol))
Esempio n. 2
0
 def test_with_no_valid_pairs(self):
     loss_func = TupletMarginLoss(0.1, 64)
     embedding_angles = [0]
     embeddings = torch.tensor(
         [c_f.angle_to_coord(a) for a in embedding_angles],
         requires_grad=True,
         dtype=torch.float)  #2D embeddings
     labels = torch.LongTensor([0])
     loss = loss_func(embeddings, labels)
     loss.backward()
     self.assertEqual(loss, 0)