def test_avg_non_zero_reducer(self):
    """Check AvgNonZeroReducer against a hand-computed reference value.

    For each reduction type ("element", "pos_pair", "neg_pair", "triplet")
    the reducer is fed both random losses and all-zero losses. The expected
    output is the mean of the strictly positive losses, or zero when no
    loss is positive.
    """
    num_samples = 100
    num_features = 64
    reducer = AvgNonZeroReducer()

    embeddings = torch.randn(num_samples, num_features)
    labels = torch.randint(0, 10, (num_samples,))
    first_idx = torch.randint(0, num_samples, (num_samples,))
    second_idx = torch.randint(0, num_samples, (num_samples,))
    third_idx = torch.randint(0, num_samples, (num_samples,))
    pair_indices = (first_idx, second_idx)
    triplet_indices = (first_idx, second_idx, third_idx)
    random_losses = torch.randn(num_samples)
    all_zero_losses = torch.zeros(num_samples)

    # Maps each reduction type to the index structure handed to the reducer.
    indices_by_type = {
        "element": torch.arange(num_samples),
        "pos_pair": pair_indices,
        "neg_pair": pair_indices,
        "triplet": triplet_indices,
    }

    for reduction_type, indices in indices_by_type.items():
        for loss_values in (random_losses, all_zero_losses):
            loss_dict = {
                "loss": {
                    "losses": loss_values,
                    "indices": indices,
                    "reduction_type": reduction_type,
                }
            }
            output = reducer(loss_dict, embeddings, labels)

            positive_losses = loss_values[loss_values > 0]
            if len(positive_losses) == 0:
                # No positive entries: the reference value is a zero scalar.
                expected = torch.mean(loss_values) * 0
            else:
                expected = torch.mean(positive_losses)
            self.assertTrue(output == expected)
 def test_deepcopy_reducer(self):
     """Run ContrastiveLoss once and check both of its sub-reducers saw pairs.

     After one forward pass, the reducers stored under "pos_loss" and
     "neg_loss" must each report a positive number of pairs past the filter.
     """
     num_samples, num_features = 128, 64
     contrastive = ContrastiveLoss(
         pos_margin=0, neg_margin=2, reducer=AvgNonZeroReducer()
     )
     embeddings = torch.randn(num_samples, num_features)
     labels = torch.randint(low=0, high=10, size=(num_samples,))

     # The loss value itself is not inspected; the forward pass is run so
     # that the reducers' counters can be checked below.
     contrastive(embeddings, labels)

     sub_reducers = contrastive.reducer.reducers
     self.assertTrue(sub_reducers["pos_loss"].pos_pairs_past_filter > 0)
     self.assertTrue(sub_reducers["neg_loss"].neg_pairs_past_filter > 0)
 def test_setting_reducers(self):
     """Check that a reducer passed to a loss constructor is the one it uses.

     TripletMarginLoss is expected to expose the reducer directly as
     ``.reducer``; for ContrastiveLoss every entry of ``.reducer.reducers``
     is expected to have the same type as the reducer that was passed in.

     unittest assertions are used instead of bare ``assert`` statements so
     the checks are not compiled out when Python runs with ``-O``.
     """
     for loss_cls in [TripletMarginLoss, ContrastiveLoss]:
         for reducer in [
                 ThresholdReducer(low=0),
                 MeanReducer(),
                 AvgNonZeroReducer(),
         ]:
             loss_fn = loss_cls(reducer=reducer)
             # Exact type identity is intentional: isinstance() would also
             # accept subclasses of the requested reducer type.
             expected_type = type(reducer)
             if isinstance(loss_fn, TripletMarginLoss):
                 self.assertIs(type(loss_fn.reducer), expected_type)
             else:
                 for sub_reducer in loss_fn.reducer.reducers.values():
                     self.assertIs(type(sub_reducer), expected_type)
    def test_multiple_reducers(self):
        """Check MultipleReducers with a different reducer per loss key.

        "lossA" is handled by AvgNonZeroReducer and "lossB" by
        DivisorReducer. For every dtype and reduction type the expected
        total is the mean of the positive "lossA" entries plus the sum of
        "lossB" divided by the sum of its divisor summands.
        """
        num_samples, num_features = 100, 64
        part_a, part_b = 32, 15
        reducer = MultipleReducers(
            {"lossA": AvgNonZeroReducer(), "lossB": DivisorReducer()}
        )

        for dtype in TEST_DTYPES:
            embeddings = torch.randn(num_samples, num_features)
            embeddings = embeddings.type(dtype).to(TEST_DEVICE)
            labels = torch.randint(0, 10, (num_samples,))
            first_idx = torch.randint(0, num_samples, (num_samples,))
            second_idx = torch.randint(0, num_samples, (num_samples,))
            third_idx = torch.randint(0, num_samples, (num_samples,))
            pair_indices = (first_idx, second_idx)
            triplet_indices = (first_idx, second_idx, third_idx)
            losses_a = torch.randn(num_samples).type(dtype).to(TEST_DEVICE)
            losses_b = torch.randn(num_samples).type(dtype).to(TEST_DEVICE)

            cases = (
                (torch.arange(num_samples), "element"),
                (pair_indices, "pos_pair"),
                (pair_indices, "neg_pair"),
                (triplet_indices, "triplet"),
            )
            for indices, reduction_type in cases:
                entry_a = {
                    "losses": losses_a,
                    "indices": indices,
                    "reduction_type": reduction_type,
                }
                entry_b = {
                    "losses": losses_b,
                    "indices": indices,
                    "reduction_type": reduction_type,
                    "divisor_summands": {"partA": part_a, "partB": part_b},
                }
                loss_dict = {"lossA": entry_a, "lossB": entry_b}
                output = reducer(loss_dict, embeddings, labels)

                # Reference value: each loss reduced by hand, then added.
                mean_positive_a = torch.mean(losses_a[losses_a > 0])
                scaled_sum_b = torch.sum(losses_b) / (part_a + part_b)
                expected = mean_positive_a + scaled_sum_b
                self.assertTrue(output == expected)
    def test_ntxent_loss(self):
        """Compare NTXentLoss and SupConLoss outputs with a manual computation.

        Five configurations are checked for every dtype in TEST_DTYPES:
          A: NTXentLoss with its default distance,
          B: NTXentLoss with LpDistance,
          C: NTXentLoss with PerAnchorReducer(AvgNonZeroReducer()),
          D: SupConLoss with its default distance,
          E: SupConLoss with LpDistance.

        The reference values are built pair by pair from the embeddings and
        compared with ``torch.isclose``.
        """
        temperature = 0.1
        loss_funcA = NTXentLoss(temperature=temperature)
        loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
        loss_funcC = NTXentLoss(
            temperature=temperature, reducer=PerAnchorReducer(AvgNonZeroReducer())
        )
        loss_funcD = SupConLoss(temperature=temperature)
        loss_funcE = SupConLoss(temperature=temperature, distance=LpDistance())

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 10, 20, 50, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(
                TEST_DEVICE
            )  # 2D embeddings

            labels = torch.LongTensor([0, 0, 0, 1, 1, 2])

            obtained_losses = [
                x(embeddings, labels)
                for x in [loss_funcA, loss_funcB, loss_funcC, loss_funcD, loss_funcE]
            ]

            # Ordered (anchor, positive) index pairs whose labels match.
            pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)]
            # Ordered (anchor, other) index pairs whose labels differ.
            neg_pairs = [
                (0, 3),
                (0, 4),
                (0, 5),
                (1, 3),
                (1, 4),
                (1, 5),
                (2, 3),
                (2, 4),
                (2, 5),
                (3, 0),
                (3, 1),
                (3, 2),
                (3, 5),
                (4, 0),
                (4, 1),
                (4, 2),
                (4, 5),
                (5, 0),
                (5, 1),
                (5, 2),
                (5, 3),
                (5, 4),
            ]

            # A and B accumulate a single scalar; C, D and E accumulate one
            # value per anchor. Only anchors 0-4 appear in pos_pairs, hence
            # the length-5 accumulators.
            total_lossA, total_lossB, total_lossC, total_lossD, total_lossE = (
                0,
                0,
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
            )
            for a1, p in pos_pairs:
                anchor, positive = embeddings[a1], embeddings[p]
                # Numerator A is based on the dot product, numerator B on the
                # negated Euclidean distance; both are scaled by temperature.
                numeratorA = torch.exp(torch.matmul(anchor, positive) / temperature)
                numeratorB = torch.exp(
                    -torch.sqrt(torch.sum((anchor - positive) ** 2)) / temperature
                )
                # A/B denominators start from the positive term itself and
                # then only gain this anchor's negatives; D/E denominators
                # gain every pair that shares the anchor.
                denominatorA = numeratorA.clone()
                denominatorB = numeratorB.clone()
                denominatorD = 0
                denominatorE = 0
                for a2, n in pos_pairs + neg_pairs:
                    if a2 == a1:
                        negative = embeddings[n]
                        curr_denomD = torch.exp(
                            torch.matmul(anchor, negative) / temperature
                        )
                        curr_denomE = torch.exp(
                            -torch.sqrt(torch.sum((anchor - negative) ** 2))
                            / temperature
                        )
                        denominatorD += curr_denomD
                        denominatorE += curr_denomE
                        if (a2, n) not in pos_pairs:
                            denominatorA += curr_denomD
                            denominatorB += curr_denomE
                    else:
                        # Pairs with a different anchor do not contribute.
                        continue

                curr_lossA = -torch.log(numeratorA / denominatorA)
                curr_lossB = -torch.log(numeratorB / denominatorB)
                curr_lossD = -torch.log(numeratorA / denominatorD)
                curr_lossE = -torch.log(numeratorB / denominatorE)
                total_lossA += curr_lossA
                total_lossB += curr_lossB
                total_lossC[a1] += curr_lossA
                total_lossD[a1] += curr_lossD
                total_lossE[a1] += curr_lossE

            # A and B: mean over all positive pairs.
            total_lossA /= len(pos_pairs)
            total_lossB /= len(pos_pairs)
            # Number of positive pairs for anchors 0..4, matching pos_pairs.
            pos_pair_per_anchor = torch.tensor(
                [2, 2, 2, 1, 1], device=TEST_DEVICE, dtype=dtype
            )
            # C, D and E: per-anchor average first, then mean across anchors.
            total_lossC, total_lossD, total_lossE = [
                torch.mean(x / pos_pair_per_anchor)
                for x in [total_lossC, total_lossD, total_lossE]
            ]

            # float16 results are compared with a looser relative tolerance.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(obtained_losses[0], total_lossA, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[1], total_lossB, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[2], total_lossC, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[3], total_lossD, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[4], total_lossE, rtol=rtol))
# Example no. 6
# 0
 def get_default_reducer(self):
     """Return a newly constructed AvgNonZeroReducer.

     NOTE(review): presumably used by the enclosing class when the caller
     supplies no reducer -- confirm against the class definition.
     """
     default_reducer = AvgNonZeroReducer()
     return default_reducer
    def test_per_anchor_reducer(self):
        """Check PerAnchorReducer against a manual per-anchor aggregation.

        For each inner reducer (MeanReducer, AvgNonZeroReducer) and each
        dtype in TEST_DTYPES:

        * "triplet" input must raise NotImplementedError;
        * "element" input must equal the inner reducer applied to the
          losses unchanged;
        * "pos_pair" / "neg_pair" input must equal the inner reducer
          applied to the per-anchor mean of the pair losses, where anchors
          without any pair contribute zero.
        """
        batch_size = 100
        embedding_size = 64
        for inner_reducer in [MeanReducer(), AvgNonZeroReducer()]:
            reducer = PerAnchorReducer(inner_reducer)
            for dtype in TEST_DTYPES:
                embeddings = (
                    torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
                )
                labels = torch.randint(0, 10, (batch_size,))
                # Enumerate all pairs once and slice the result, instead of
                # recomputing the same indices for positives and negatives.
                all_pair_indices = lmu.get_all_pairs_indices(labels)
                pos_pair_indices = all_pair_indices[:2]
                neg_pair_indices = all_pair_indices[2:]
                triplet_indices = lmu.get_all_triplets_indices(labels)

                for indices, reduction_type in [
                    (torch.arange(batch_size), "element"),
                    (pos_pair_indices, "pos_pair"),
                    (neg_pair_indices, "neg_pair"),
                    (triplet_indices, "triplet"),
                ]:
                    # "element" indices are a single tensor; the other
                    # reduction types carry a tuple of index tensors.
                    loss_size = (
                        len(indices) if reduction_type == "element" else len(indices[0])
                    )
                    losses = torch.randn(loss_size).type(dtype).to(TEST_DEVICE)
                    loss_dict = {
                        "loss": {
                            "losses": losses,
                            "indices": indices,
                            "reduction_type": reduction_type,
                        }
                    }
                    if reduction_type == "triplet":
                        with self.assertRaises(NotImplementedError):
                            reducer(loss_dict, embeddings, labels)
                        continue

                    output = reducer(loss_dict, embeddings, labels)

                    if reduction_type == "element":
                        # Element losses are passed to the inner reducer as is.
                        per_anchor_losses = losses
                    else:
                        anchors = indices[0]
                        per_anchor_losses = torch.zeros(
                            batch_size, device=TEST_DEVICE, dtype=dtype
                        )
                        for i in range(len(embeddings)):
                            matching_pairs_mask = anchors == i
                            num_matching_pairs = torch.sum(matching_pairs_mask)
                            # Anchors without any pair keep their zero entry.
                            if num_matching_pairs > 0:
                                per_anchor_losses[i] = (
                                    torch.sum(losses[matching_pairs_mask])
                                    / num_matching_pairs
                                )

                    expected_loss_dict = {
                        "loss": {
                            "losses": per_anchor_losses,
                            "indices": c_f.torch_arange_from_size(embeddings),
                            "reduction_type": "element",
                        }
                    }
                    correct_output = inner_reducer(
                        expected_loss_dict, embeddings, labels
                    )
                    self.assertTrue(torch.isclose(output, correct_output))