Example #1
 def test_with_no_valid_pairs(self):
     lossA = ContrastiveLoss(use_similarity=False)
     lossB = ContrastiveLoss(use_similarity=True)
     embedding_angles = [0]
     embeddings = torch.FloatTensor(
         [c_f.angle_to_coord(a) for a in embedding_angles])  #2D embeddings
     labels = torch.LongTensor([0])
     self.assertEqual(lossA(embeddings, labels), 0)
     self.assertEqual(lossB(embeddings, labels), 0)
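These snippets appear to be drawn from the pytorch-metric-learning test suite, so several names come from the surrounding test harness rather than from the snippets themselves: `c_f` (the library's common-functions module), `lmu` (its loss-and-miner utilities), and fixtures such as `TEST_DTYPES`, `TEST_DEVICE`, `self.device`, `self.embedding_size`, and `self.memory_size`. A minimal sketch of that assumed setup (the fixture values are illustrative, not canonical):

    import math
    import unittest

    import torch
    from pytorch_metric_learning.losses import ContrastiveLoss
    from pytorch_metric_learning.utils import common_functions as c_f
    from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

    # The tests call c_f.angle_to_coord(a) to place a 2D embedding on the
    # unit circle at `a` degrees; an equivalent helper, if needed:
    def angle_to_coord(angle):
        return math.cos(math.radians(angle)), math.sin(math.radians(angle))

    TEST_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    TEST_DTYPES = [torch.float16, torch.float32, torch.float64]

    class TestContrastiveLossExamples(unittest.TestCase):
        def setUp(self):
            self.device = TEST_DEVICE
            self.embedding_size = 64   # illustrative
            self.memory_size = 128     # illustrative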
Example #2
 def test_backward(self):
     loss_funcA = ContrastiveLoss(use_similarity=False)
     loss_funcB = ContrastiveLoss(use_similarity=True)
     for dtype in [torch.float16, torch.float32, torch.float64]:
         for loss_func in [loss_funcA, loss_funcB]:
             embedding_angles = [0]
             embeddings = torch.tensor(
                 [c_f.angle_to_coord(a) for a in embedding_angles],
                 requires_grad=True,
                 dtype=dtype).to(self.device)  #2D embeddings
             labels = torch.LongTensor([0])
             loss = loss_func(embeddings, labels)
             loss.backward()
Example #3
 def test_backward(self):
     loss_funcA = ContrastiveLoss()
     loss_funcB = ContrastiveLoss(distance=CosineSimilarity())
     for dtype in TEST_DTYPES:
         for loss_func in [loss_funcA, loss_funcB]:
             embedding_angles = [0]
             embeddings = torch.tensor(
                 [c_f.angle_to_coord(a) for a in embedding_angles],
                 requires_grad=True,
                 dtype=dtype).to(self.device)  #2D embeddings
             labels = torch.LongTensor([0])
             loss = loss_func(embeddings, labels)
             loss.backward()
Example #4
 def test_with_no_valid_pairs(self):
     loss_funcA = ContrastiveLoss()
     loss_funcB = ContrastiveLoss(distance=CosineSimilarity())
     for dtype in TEST_DTYPES:
         embedding_angles = [0]
         embeddings = torch.tensor(
             [c_f.angle_to_coord(a) for a in embedding_angles],
             requires_grad=True,
             dtype=dtype).to(self.device)  #2D embeddings
         labels = torch.LongTensor([0])
         lossA = loss_funcA(embeddings, labels)
         lossB = loss_funcB(embeddings, labels)
         self.assertEqual(lossA, 0)
         self.assertEqual(lossB, 0)
Example #5
 def test_with_no_valid_pairs(self):
     loss_funcA = ContrastiveLoss(use_similarity=False)
     loss_funcB = ContrastiveLoss(use_similarity=True)
     for dtype in [torch.float16, torch.float32, torch.float64]:
         embedding_angles = [0]
         embeddings = torch.tensor(
             [c_f.angle_to_coord(a) for a in embedding_angles],
             requires_grad=True,
             dtype=dtype).to(self.device)  #2D embeddings
         labels = torch.LongTensor([0])
         lossA = loss_funcA(embeddings, labels)
         lossB = loss_funcB(embeddings, labels)
         self.assertEqual(lossA, 0)
         self.assertEqual(lossB, 0)
Example #6
 def test_input_indices_tuple(self):
     batch_size = 32
     pair_miner = PairMarginMiner(pos_margin=0,
                                  neg_margin=1,
                                  use_similarity=False)
     triplet_miner = TripletMarginMiner(margin=1)
     self.loss = CrossBatchMemory(loss=ContrastiveLoss(),
                                  embedding_size=self.embedding_size,
                                  memory_size=self.memory_size)
     for i in range(30):
         embeddings = torch.randn(batch_size, self.embedding_size)
         labels = torch.arange(batch_size)
         self.loss(embeddings, labels)
         for curr_miner in [pair_miner, triplet_miner]:
             input_indices_tuple = curr_miner(embeddings, labels)
             all_labels = torch.cat([labels, self.loss.label_memory], dim=0)
             a1ii, pii, a2ii, nii = lmu.convert_to_pairs(
                 input_indices_tuple, labels)
             a1i, pi, a2i, ni = lmu.get_all_pairs_indices(
                 labels, self.loss.label_memory)
             a1, p, a2, n = self.loss.create_indices_tuple(
                 batch_size, embeddings, labels, self.loss.embedding_memory,
                 self.loss.label_memory, input_indices_tuple)
             self.assertTrue(not torch.any((all_labels[a1] -
                                            all_labels[p]).bool()))
             self.assertTrue(
                 torch.all((all_labels[a2] - all_labels[n]).bool()))
             self.assertTrue(len(a1) == len(a1i) + len(a1ii))
             self.assertTrue(len(p) == len(pi) + len(pii))
             self.assertTrue(len(a2) == len(a2i) + len(a2ii))
             self.assertTrue(len(n) == len(ni) + len(nii))
Example #7
    def test_key_mismatch(self):
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)
        self.assertRaises(
            AssertionError,
            lambda: MultipleLosses(
                losses={
                    "lossA": lossA,
                    "lossB": lossB
                },
                weights={
                    "blah": 1,
                    "lossB": 0.23
                },
            ),
        )

        minerA = MultiSimilarityMiner()
        self.assertRaises(
            AssertionError,
            lambda: MultipleLosses(
                losses={
                    "lossA": lossA,
                    "lossB": lossB
                },
                weights={
                    "lossA": 1,
                    "lossB": 0.23
                },
                miners={"blah": minerA},
            ),
        )
Example #8
    def test_multiple_losses(self):
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)
        loss_func = MultipleLosses(losses={
            "lossA": lossA,
            "lossB": lossB
        },
                                   weights={
                                       "lossA": 1,
                                       "lossB": 0.23
                                   })

        for dtype in TEST_DTYPES:
            embedding_angles = torch.arange(0, 180)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  #2D embeddings
            labels = torch.randint(low=0, high=10, size=(180, ))

            loss = loss_func(embeddings, labels)
            loss.backward()

            correct_loss = lossA(embeddings,
                                 labels) + lossB(embeddings, labels) * 0.23
            self.assertTrue(torch.isclose(loss, correct_loss))
Example #9
    def test_shift_indices_tuple(self):
        for dtype in TEST_DTYPES:
            batch_size = 32
            pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
            triplet_miner = TripletMarginMiner(margin=1)
            self.loss = CrossBatchMemory(
                loss=ContrastiveLoss(),
                embedding_size=self.embedding_size,
                memory_size=self.memory_size,
            )
            for i in range(30):
                embeddings = (
                    torch.randn(batch_size, self.embedding_size)
                    .to(TEST_DEVICE)
                    .type(dtype)
                )
                labels = torch.arange(batch_size).to(TEST_DEVICE)
                loss = self.loss(embeddings, labels)
                all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

                indices_tuple = lmu.get_all_pairs_indices(
                    labels, self.loss.label_memory
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
                a1, p, a2, n = shifted
                self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

                indices_tuple = pair_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
                a1, p, a2, n = shifted
                self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

                indices_tuple = triplet_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2] - batch_size))
                a, p, n = shifted
                self.assertTrue(not torch.any((all_labels[a] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
Example #10
    def test_input_indices_tuple(self):
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)
        miner = MultiSimilarityMiner()
        loss_func1 = MultipleLosses(losses={
            "lossA": lossA,
            "lossB": lossB
        },
                                    weights={
                                        "lossA": 1,
                                        "lossB": 0.23
                                    })

        loss_func2 = MultipleLosses(losses=[lossA, lossB], weights=[1, 0.23])

        for loss_func in [loss_func1, loss_func2]:
            for dtype in TEST_DTYPES:
                embedding_angles = torch.arange(0, 180)
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in embedding_angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)  # 2D embeddings
                labels = torch.randint(low=0, high=10, size=(180, ))
                indices_tuple = miner(embeddings, labels)

                loss = loss_func(embeddings, labels, indices_tuple)
                loss.backward()

                correct_loss = (
                    lossA(embeddings, labels, indices_tuple) +
                    lossB(embeddings, labels, indices_tuple) * 0.23)
                self.assertTrue(torch.isclose(loss, correct_loss))
Example #11
 def test_deepcopy_reducer(self):
     loss_fn = ContrastiveLoss(pos_margin=0,
                               neg_margin=2,
                               reducer=AvgNonZeroReducer())
     embeddings = torch.randn(128, 64)
     labels = torch.randint(low=0, high=10, size=(128, ))
     loss = loss_fn(embeddings, labels)
     self.assertTrue(
         loss_fn.reducer.reducers["pos_loss"].pos_pairs_past_filter > 0)
     self.assertTrue(
         loss_fn.reducer.reducers["neg_loss"].neg_pairs_past_filter > 0)
Example #12
    def test_loss(self):
        num_labels = 10
        num_iter = 10
        batch_size = 32
        inner_loss = ContrastiveLoss()
        inner_miner = MultiSimilarityMiner(0.3)
        outer_miner = MultiSimilarityMiner(0.2)
        self.loss = CrossBatchMemory(loss=inner_loss, embedding_size=self.embedding_size, memory_size=self.memory_size)
        self.loss_with_miner = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
        self.loss_with_miner2 = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
        all_embeddings = torch.FloatTensor([])
        all_labels = torch.LongTensor([])
        for i in range(num_iter):
            embeddings = torch.randn(batch_size, self.embedding_size)
            labels = torch.randint(0,num_labels,(batch_size,))
            loss = self.loss(embeddings, labels)
            loss_with_miner = self.loss_with_miner(embeddings, labels)
            oa1, op, oa2, on = outer_miner(embeddings, labels)
            loss_with_miner_and_input_indices = self.loss_with_miner2(embeddings, labels, (oa1, op, oa2, on))
            all_embeddings = torch.cat([all_embeddings, embeddings])
            all_labels = torch.cat([all_labels, labels])

            # loss with no inner miner
            indices_tuple = lmu.get_all_pairs_indices(labels, all_labels)
            a1,p,a2,n = self.loss.remove_self_comparisons(indices_tuple)
            p = p+batch_size
            n = n+batch_size
            correct_loss = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss, correct_loss))

            # loss with inner miner
            indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
            a1,p,a2,n = self.loss_with_miner.remove_self_comparisons(indices_tuple)
            p = p+batch_size
            n = n+batch_size
            correct_loss_with_miner = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss_with_miner, correct_loss_with_miner))

            # loss with inner and outer miner
            indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
            a1,p,a2,n = self.loss_with_miner2.remove_self_comparisons(indices_tuple)
            p = p+batch_size
            n = n+batch_size
            a1 = torch.cat([oa1, a1])
            p = torch.cat([op, p])
            a2 = torch.cat([oa2, a2])
            n = torch.cat([on, n])
            correct_loss_with_miner_and_input_indice = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss_with_miner_and_input_indices, correct_loss_with_miner_and_input_indice))
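The `p = p + batch_size` / `n = n + batch_size` bookkeeping above is the subtle step: the reference loss is evaluated on `torch.cat([embeddings, all_embeddings], dim=0)`, so anchor indices keep addressing the current batch while positive/negative indices, which address memory rows, shift down by `batch_size` in the stacked tensor. A tiny sketch of just that shift, with hypothetical index values:

    import torch

    batch_size = 32
    p_memory = torch.tensor([0, 5, 17])  # hypothetical rows of all_embeddings
    p_stacked = p_memory + batch_size    # the same rows inside the stacked tensor
    assert torch.equal(p_stacked, torch.tensor([32, 37, 49]))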
Example #13
    def test_length_mistmatch(self):
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)
        self.assertRaises(
            AssertionError,
            lambda: MultipleLosses(losses=[lossA, lossB], weights=[1]))

        minerA = MultiSimilarityMiner()
        self.assertRaises(
            AssertionError,
            lambda: MultipleLosses(
                losses=[lossA, lossB],
                weights=[1, 0.2],
                miners=[minerA],
            ),
        )
Example #14
    def test_queue(self):
        for dtype in [torch.float16, torch.float32, torch.float64]:
            batch_size = 32
            self.loss = CrossBatchMemory(loss=ContrastiveLoss(),
                                         embedding_size=self.embedding_size,
                                         memory_size=self.memory_size)
            for i in range(30):
                embeddings = torch.randn(batch_size, self.embedding_size).to(
                    self.device).type(dtype)
                labels = torch.arange(batch_size).to(self.device)
                q = self.loss.queue_idx
                self.assertTrue(q == (i * batch_size) % self.memory_size)
                loss = self.loss(embeddings, labels)

                start_idx = q
                if q + batch_size == self.memory_size:
                    end_idx = self.memory_size
                else:
                    end_idx = (q + batch_size) % self.memory_size
                if start_idx < end_idx:
                    self.assertTrue(
                        torch.equal(
                            embeddings,
                            self.loss.embedding_memory[start_idx:end_idx]))
                    self.assertTrue(
                        torch.equal(labels,
                                    self.loss.label_memory[start_idx:end_idx]))
                else:
                    correct_embeddings = torch.cat([
                        self.loss.embedding_memory[start_idx:],
                        self.loss.embedding_memory[:end_idx]
                    ],
                                                   dim=0)
                    correct_labels = torch.cat([
                        self.loss.label_memory[start_idx:],
                        self.loss.label_memory[:end_idx]
                    ],
                                               dim=0)
                    self.assertTrue(torch.equal(embeddings,
                                                correct_embeddings))
                    self.assertTrue(torch.equal(labels, correct_labels))
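What this test pins down is that `CrossBatchMemory` writes each batch into its memory as a ring buffer: `queue_idx` advances by the batch size and wraps modulo `memory_size`. A standalone sketch of that arithmetic, with a memory size chosen so the wrap-around branch is actually exercised (values illustrative, no library calls):

    memory_size, batch_size = 100, 32
    queue_idx = 0
    for i in range(30):
        assert queue_idx == (i * batch_size) % memory_size
        start = queue_idx
        if start + batch_size == memory_size:
            end = memory_size
        else:
            end = (start + batch_size) % memory_size
        # the batch occupies memory[start:end], wrapping when start >= end
        queue_idx = (start + batch_size) % memory_size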
Example #15
    def test_contrastive_loss(self):
        loss_funcA = ContrastiveLoss(pos_margin=0.25,
                                     neg_margin=1.5,
                                     use_similarity=False,
                                     avg_non_zero_only=True,
                                     squared_distances=True)
        loss_funcB = ContrastiveLoss(pos_margin=1.5,
                                     neg_margin=0.6,
                                     use_similarity=True,
                                     avg_non_zero_only=True)
        loss_funcC = ContrastiveLoss(pos_margin=0.25,
                                     neg_margin=1.5,
                                     use_similarity=False,
                                     avg_non_zero_only=False,
                                     squared_distances=True)
        loss_funcD = ContrastiveLoss(pos_margin=1.5,
                                     neg_margin=0.6,
                                     use_similarity=True,
                                     avg_non_zero_only=False)

        embedding_angles = [0, 20, 40, 60, 80]
        embeddings = torch.FloatTensor(
            [c_f.angle_to_coord(a) for a in embedding_angles])  #2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        lossC = loss_funcC(embeddings, labels)
        lossD = loss_funcD(embeddings, labels)

        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0),
                     (2, 1), (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1),
                     (4, 2), (4, 3)]

        correct_pos_losses = [0, 0, 0, 0]
        correct_neg_losses = [0, 0, 0, 0]
        num_non_zero_pos = [0, 0, 0, 0]
        num_non_zero_neg = [0, 0, 0, 0]
        for a, p in pos_pairs:
            anchor, positive = embeddings[a], embeddings[p]
            correct_lossA = torch.relu(
                torch.sum((anchor - positive)**2) - 0.25)
            correct_lossB = torch.relu(1.5 - torch.matmul(anchor, positive))
            correct_pos_losses[0] += correct_lossA
            correct_pos_losses[1] += correct_lossB
            correct_pos_losses[2] += correct_lossA
            correct_pos_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_pos[0] += 1
                num_non_zero_pos[2] += 1
            if correct_lossB > 0:
                num_non_zero_pos[1] += 1
                num_non_zero_pos[3] += 1

        for a, n in neg_pairs:
            anchor, negative = embeddings[a], embeddings[n]
            correct_lossA = torch.relu(1.5 - torch.sum((anchor - negative)**2))
            correct_lossB = torch.relu(torch.matmul(anchor, negative) - 0.6)
            correct_neg_losses[0] += correct_lossA
            correct_neg_losses[1] += correct_lossB
            correct_neg_losses[2] += correct_lossA
            correct_neg_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_neg[0] += 1
                num_non_zero_neg[2] += 1
            if correct_lossB > 0:
                num_non_zero_neg[1] += 1
                num_non_zero_neg[3] += 1

        for i in range(2):
            if num_non_zero_pos[i] > 0:
                correct_pos_losses[i] /= num_non_zero_pos[i]
            if num_non_zero_neg[i] > 0:
                correct_neg_losses[i] /= num_non_zero_neg[i]

        for i in range(2, 4):
            correct_pos_losses[i] /= len(pos_pairs)
            correct_neg_losses[i] /= len(neg_pairs)

        correct_losses = [0, 0, 0, 0]
        for i in range(4):
            correct_losses[i] = correct_pos_losses[i] + correct_neg_losses[i]

        self.assertTrue(torch.isclose(lossA, correct_losses[0]))
        self.assertTrue(torch.isclose(lossB, correct_losses[1]))
        self.assertTrue(torch.isclose(lossC, correct_losses[2]))
        self.assertTrue(torch.isclose(lossD, correct_losses[3]))
Example #16
    def test_queue(self):
        for test_enqueue_idx in [False, True]:
            for dtype in TEST_DTYPES:
                batch_size = 32
                enqueue_batch_size = 15
                self.loss = CrossBatchMemory(
                    loss=ContrastiveLoss(),
                    embedding_size=self.embedding_size,
                    memory_size=self.memory_size,
                )
                for i in range(30):
                    embeddings = (torch.randn(
                        batch_size,
                        self.embedding_size).to(TEST_DEVICE).type(dtype))
                    labels = torch.arange(batch_size).to(TEST_DEVICE)
                    q = self.loss.queue_idx
                    B = enqueue_batch_size if test_enqueue_idx else batch_size
                    if test_enqueue_idx:
                        enqueue_idx = torch.arange(enqueue_batch_size) * 2
                    else:
                        enqueue_idx = None

                    self.assertTrue(q == (i * B) % self.memory_size)
                    loss = self.loss(embeddings,
                                     labels,
                                     enqueue_idx=enqueue_idx)

                    start_idx = q
                    if q + B == self.memory_size:
                        end_idx = self.memory_size
                    else:
                        end_idx = (q + B) % self.memory_size
                    if start_idx < end_idx:
                        if test_enqueue_idx:
                            self.assertTrue(
                                torch.equal(
                                    embeddings[enqueue_idx],
                                    self.loss.
                                    embedding_memory[start_idx:end_idx],
                                ))
                            self.assertTrue(
                                torch.equal(
                                    labels[enqueue_idx],
                                    self.loss.label_memory[start_idx:end_idx],
                                ))
                        else:
                            self.assertTrue(
                                torch.equal(
                                    embeddings,
                                    self.loss.
                                    embedding_memory[start_idx:end_idx],
                                ))
                            self.assertTrue(
                                torch.equal(
                                    labels,
                                    self.loss.label_memory[start_idx:end_idx]))
                    else:
                        correct_embeddings = torch.cat(
                            [
                                self.loss.embedding_memory[start_idx:],
                                self.loss.embedding_memory[:end_idx],
                            ],
                            dim=0,
                        )
                        correct_labels = torch.cat(
                            [
                                self.loss.label_memory[start_idx:],
                                self.loss.label_memory[:end_idx],
                            ],
                            dim=0,
                        )
                        if test_enqueue_idx:
                            self.assertTrue(
                                torch.equal(embeddings[enqueue_idx],
                                            correct_embeddings))
                            self.assertTrue(
                                torch.equal(labels[enqueue_idx],
                                            correct_labels))
                        else:
                            self.assertTrue(
                                torch.equal(embeddings, correct_embeddings))
                            self.assertTrue(torch.equal(
                                labels, correct_labels))
Example #17
    def test_contrastive_loss(self):
        loss_funcA = ContrastiveLoss(pos_margin=0.25,
                                     neg_margin=1.5,
                                     distance=LpDistance(power=2))
        loss_funcB = ContrastiveLoss(pos_margin=1.5,
                                     neg_margin=0.6,
                                     distance=CosineSimilarity())
        loss_funcC = ContrastiveLoss(pos_margin=0.25,
                                     neg_margin=1.5,
                                     distance=LpDistance(power=2),
                                     reducer=MeanReducer())
        loss_funcD = ContrastiveLoss(pos_margin=1.5,
                                     neg_margin=0.6,
                                     distance=CosineSimilarity(),
                                     reducer=MeanReducer())

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  #2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            lossA = loss_funcA(embeddings, labels)
            lossB = loss_funcB(embeddings, labels)
            lossC = loss_funcC(embeddings, labels)
            lossD = loss_funcD(embeddings, labels)

            pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
            neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
                         (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
                         (4, 0), (4, 1), (4, 2), (4, 3)]

            correct_pos_losses = [0, 0, 0, 0]
            correct_neg_losses = [0, 0, 0, 0]
            num_non_zero_pos = [0, 0, 0, 0]
            num_non_zero_neg = [0, 0, 0, 0]
            for a, p in pos_pairs:
                anchor, positive = embeddings[a], embeddings[p]
                correct_lossA = torch.relu(
                    torch.sum((anchor - positive)**2) - 0.25)
                correct_lossB = torch.relu(1.5 -
                                           torch.matmul(anchor, positive))
                correct_pos_losses[0] += correct_lossA
                correct_pos_losses[1] += correct_lossB
                correct_pos_losses[2] += correct_lossA
                correct_pos_losses[3] += correct_lossB
                if correct_lossA > 0:
                    num_non_zero_pos[0] += 1
                    num_non_zero_pos[2] += 1
                if correct_lossB > 0:
                    num_non_zero_pos[1] += 1
                    num_non_zero_pos[3] += 1

            for a, n in neg_pairs:
                anchor, negative = embeddings[a], embeddings[n]
                correct_lossA = torch.relu(1.5 -
                                           torch.sum((anchor - negative)**2))
                correct_lossB = torch.relu(
                    torch.matmul(anchor, negative) - 0.6)
                correct_neg_losses[0] += correct_lossA
                correct_neg_losses[1] += correct_lossB
                correct_neg_losses[2] += correct_lossA
                correct_neg_losses[3] += correct_lossB
                if correct_lossA > 0:
                    num_non_zero_neg[0] += 1
                    num_non_zero_neg[2] += 1
                if correct_lossB > 0:
                    num_non_zero_neg[1] += 1
                    num_non_zero_neg[3] += 1

            for i in range(2):
                if num_non_zero_pos[i] > 0:
                    correct_pos_losses[i] /= num_non_zero_pos[i]
                if num_non_zero_neg[i] > 0:
                    correct_neg_losses[i] /= num_non_zero_neg[i]

            for i in range(2, 4):
                correct_pos_losses[i] /= len(pos_pairs)
                correct_neg_losses[i] /= len(neg_pairs)

            correct_losses = [0, 0, 0, 0]
            for i in range(4):
                correct_losses[
                    i] = correct_pos_losses[i] + correct_neg_losses[i]

            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(lossA, correct_losses[0], rtol=rtol))
            self.assertTrue(torch.isclose(lossB, correct_losses[1], rtol=rtol))
            self.assertTrue(torch.isclose(lossC, correct_losses[2], rtol=rtol))
            self.assertTrue(torch.isclose(lossD, correct_losses[3], rtol=rtol))
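For orientation, the reference computations in the last two examples expand the standard contrastive terms pair by pair: with a distance `d`, positive pairs contribute `relu(d - pos_margin)` and negative pairs `relu(neg_margin - d)`; with a similarity `s`, the inequalities flip to `relu(pos_margin - s)` and `relu(s - neg_margin)`. A minimal single-pair check using `loss_funcA`'s margins and squared L2 distance (embedding values illustrative):

    import math
    import torch

    def coord(angle):  # 2D unit-circle point at `angle` degrees
        return [math.cos(math.radians(angle)), math.sin(math.radians(angle))]

    anchor = torch.tensor(coord(0))
    positive = torch.tensor(coord(20))
    negative = torch.tensor(coord(40))

    d_pos = torch.sum((anchor - positive) ** 2)  # squared L2, as in LpDistance(power=2)
    d_neg = torch.sum((anchor - negative) ** 2)

    pos_term = torch.relu(d_pos - 0.25)  # pull positives within pos_margin
    neg_term = torch.relu(1.5 - d_neg)   # push negatives beyond neg_margin
    print(pos_term.item(), neg_term.item())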