def test_mean_reducer(self):
        """Check that ``MeanReducer`` returns the plain mean of the losses.

        For every dtype in ``TEST_DTYPES`` a random loss vector is reduced
        under each of the four reduction types ("element", "pos_pair",
        "neg_pair", "triplet"); the result must compare equal to
        ``torch.mean(losses)`` in every case.
        """
        num_samples, num_features = 100, 64
        mean_reducer = MeanReducer()

        def random_sample_indices():
            # One random batch index per loss entry.
            return torch.randint(0, num_samples, (num_samples, ))

        for dtype in TEST_DTYPES:
            embeddings = torch.randn(num_samples, num_features).type(dtype)
            embeddings = embeddings.to(self.device)
            labels = torch.randint(0, 10, (num_samples, ))
            pairs = (random_sample_indices(), random_sample_indices())
            # A triplet reuses the pair indices and appends a third index set.
            triplets = pairs + (random_sample_indices(), )
            losses = torch.randn(num_samples).type(dtype).to(self.device)

            cases = (
                ("element", torch.arange(num_samples)),
                ("pos_pair", pairs),
                ("neg_pair", pairs),
                ("triplet", triplets),
            )
            for reduction_type, indices in cases:
                reduced = mean_reducer(
                    {
                        "loss": {
                            "losses": losses,
                            "indices": indices,
                            "reduction_type": reduction_type,
                        }
                    },
                    embeddings,
                    labels,
                )
                self.assertTrue(reduced == torch.mean(losses))
Exemplo n.º 2
0
    def test_triplet_margin_loss(self):
        """Compare TripletMarginLoss with a hand-computed reference.

        The loss built without an explicit reducer is expected to equal the
        summed triplet losses divided by the number of triplets whose loss is
        non-zero; the loss built with ``MeanReducer`` is expected to equal the
        sum divided by the total number of triplets. ``backward()`` is also
        called on both losses.
        """
        margin = 0.2
        default_loss_fn = TripletMarginLoss(margin=margin)
        mean_loss_fn = TripletMarginLoss(margin=margin, reducer=MeanReducer())

        angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=torch.float,
        )  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        default_loss = default_loss_fn(embeddings, labels)
        mean_loss = mean_loss_fn(embeddings, labels)
        default_loss.backward()
        mean_loss.backward()

        # (anchor, positive, negative) index triples for the labels above.
        triplets = [
            (0, 1, 2), (0, 1, 3), (0, 1, 4),
            (1, 0, 2), (1, 0, 3), (1, 0, 4),
            (2, 3, 0), (2, 3, 1), (2, 3, 4),
            (3, 2, 0), (3, 2, 1), (3, 2, 4),
        ]

        def euclidean(u, v):
            return torch.sqrt(torch.sum((u - v)**2))

        triplet_losses = [
            torch.relu(
                euclidean(embeddings[a], embeddings[p]) -
                euclidean(embeddings[a], embeddings[n]) + margin)
            for a, p, n in triplets
        ]
        total_loss = sum(triplet_losses)
        non_zero_count = sum(1 for term in triplet_losses if term > 0)

        self.assertTrue(
            torch.isclose(default_loss, total_loss / non_zero_count))
        self.assertTrue(
            torch.isclose(mean_loss, total_loss / len(triplets)))
 def test_setting_reducers(self):
     """Check that a reducer passed to a loss constructor is the one it uses.

     For ``TripletMarginLoss`` the reducer is read from ``loss.reducer``.
     For the other loss class (``ContrastiveLoss``) it is read from every
     value of ``loss.reducer.reducers``. In both cases the stored object
     must have exactly the type of the reducer that was passed in.
     """
     for loss_cls in [TripletMarginLoss, ContrastiveLoss]:
         for reducer in [
                 ThresholdReducer(low=0),
                 MeanReducer(),
                 AvgNonZeroReducer(),
         ]:
             loss_fn = loss_cls(reducer=reducer)
             expected_type = type(reducer)
             if isinstance(loss_fn, TripletMarginLoss):
                 installed_reducers = [loss_fn.reducer]
             else:
                 installed_reducers = list(loss_fn.reducer.reducers.values())
             for installed in installed_reducers:
                 # Exact type identity is intended here, not isinstance():
                 # a subclass of the requested reducer should not pass.
                 assert type(installed) is expected_type, (
                     "{} stored a {} but was given a {}".format(
                         loss_cls.__name__,
                         type(installed).__name__,
                         expected_type.__name__,
                     ))
Exemplo n.º 4
0
 def test_with_no_valid_triplets(self):
     """Both triplet losses must equal 0 when every label is distinct.

     The five samples all carry different labels, and the test is run for
     float16, float32 and float64 embeddings.
     """
     loss_fns = (
         TripletMarginLoss(margin=0.2),
         TripletMarginLoss(margin=0.2, reducer=MeanReducer()),
     )
     angles = [0, 20, 40, 60, 80]
     for dtype in (torch.float16, torch.float32, torch.float64):
         coords = [c_f.angle_to_coord(angle) for angle in angles]
         embeddings = torch.tensor(
             coords, requires_grad=True,
             dtype=dtype).to(self.device)  # 2D embeddings
         distinct_labels = torch.LongTensor([0, 1, 2, 3, 4])
         # Evaluate both losses before asserting on either of them.
         loss_values = [
             loss_fn(embeddings, distinct_labels) for loss_fn in loss_fns
         ]
         for loss_value in loss_values:
             self.assertEqual(loss_value, 0)
Exemplo n.º 5
0
    def test_backward(self):
        """Smoke test: ``backward()`` runs for each loss variant and dtype.

        Nothing is asserted; the test passes when computing the loss and
        calling ``backward()`` does not raise for float16, float32 or
        float64 embeddings.
        """
        margin = 0.2
        loss_fns = (
            TripletMarginLoss(margin=margin),
            TripletMarginLoss(margin=margin, reducer=MeanReducer()),
        )
        angles = [0, 20, 40, 60, 80]
        for dtype in (torch.float16, torch.float32, torch.float64):
            for loss_fn in loss_fns:
                # Fresh embeddings for every (dtype, loss) combination.
                coords = [c_f.angle_to_coord(angle) for angle in angles]
                embeddings = torch.tensor(
                    coords, requires_grad=True,
                    dtype=dtype).to(self.device)  # 2D embeddings
                labels = torch.LongTensor([0, 0, 1, 1, 2])

                loss_fn(embeddings, labels).backward()
    def test_backward(self):
        """Smoke test: ``backward()`` runs for each loss variant and dtype.

        Three TripletMarginLoss configurations are exercised (plain margin,
        margin with ``MeanReducer``, and ``smooth_loss=True``) for every dtype
        in ``TEST_DTYPES``. Nothing is asserted; the test passes when no call
        raises.
        """
        margin = 0.2
        loss_fns = (
            TripletMarginLoss(margin=margin),
            TripletMarginLoss(margin=margin, reducer=MeanReducer()),
            TripletMarginLoss(smooth_loss=True),
        )
        angles = [0, 20, 40, 60, 80]
        for dtype in TEST_DTYPES:
            for loss_fn in loss_fns:
                # Fresh embeddings for every (dtype, loss) combination.
                coords = [c_f.angle_to_coord(angle) for angle in angles]
                embeddings = torch.tensor(
                    coords,
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)  # 2D embeddings
                labels = torch.LongTensor([0, 0, 1, 1, 2])

                loss_fn(embeddings, labels).backward()
    def test_triplet_margin_loss(self):
        """Compare TripletMarginLoss with a hand-computed reference.

        Four configurations are checked for every dtype in ``TEST_DTYPES``:
        the default distance and ``CosineSimilarity``, each without an
        explicit reducer and with ``MeanReducer``. Without an explicit reducer
        the result is expected to equal the summed triplet losses divided by
        the number of non-zero triplet losses; with ``MeanReducer`` it is
        expected to equal the sum divided by the number of triplets.
        """
        margin = 0.2
        # Order matters: it lines up with the `expected` list built below.
        loss_fns = [
            TripletMarginLoss(margin=margin),
            TripletMarginLoss(margin=margin, reducer=MeanReducer()),
            TripletMarginLoss(margin=margin, distance=CosineSimilarity()),
            TripletMarginLoss(margin=margin,
                              reducer=MeanReducer(),
                              distance=CosineSimilarity()),
        ]
        angles = [0, 20, 40, 60, 80]
        # (anchor, positive, negative) index triples for the labels below.
        triplets = [
            (0, 1, 2), (0, 1, 3), (0, 1, 4),
            (1, 0, 2), (1, 0, 3), (1, 0, 4),
            (2, 3, 0), (2, 3, 1), (2, 3, 4),
            (3, 2, 0), (3, 2, 1), (3, 2, 4),
        ]

        def euclidean(u, v):
            return torch.sqrt(torch.sum((u - v)**2))

        def dot(u, v):
            return torch.sum(u * v)

        def count_non_zero(terms):
            return sum(1 for term in terms if term > 0)

        for dtype in TEST_DTYPES:
            embeddings = torch.tensor(
                [c_f.angle_to_coord(angle) for angle in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            actual = [loss_fn(embeddings, labels) for loss_fn in loss_fns]

            euclidean_terms = []
            cosine_terms = []
            for a, p, n in triplets:
                anchor = embeddings[a]
                positive = embeddings[p]
                negative = embeddings[n]
                euclidean_terms.append(
                    torch.relu(
                        euclidean(anchor, positive) -
                        euclidean(anchor, negative) + margin))
                # For a similarity the roles of positive and negative flip.
                cosine_terms.append(
                    torch.relu(
                        dot(anchor, negative) - dot(anchor, positive) +
                        margin))

            euclidean_total = sum(euclidean_terms)
            cosine_total = sum(cosine_terms)
            expected = [
                euclidean_total / count_non_zero(euclidean_terms),
                euclidean_total / len(triplets),
                cosine_total / count_non_zero(cosine_terms),
                cosine_total / len(triplets),
            ]

            # Half precision needs a looser relative tolerance.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            for got, want in zip(actual, expected):
                self.assertTrue(torch.isclose(got, want, rtol=rtol))
Exemplo n.º 8
0
 def get_default_reducer(self):
     """Return a new ``MeanReducer`` instance as the default reducer."""
     default_reducer = MeanReducer()
     return default_reducer
Exemplo n.º 9
0
    def test_contrastive_loss(self):
        """Compare ContrastiveLoss with a hand-computed reference.

        Four configurations are checked for every dtype in ``TEST_DTYPES``:
        ``LpDistance(power=2)`` (margins 0.25 / 1.5) and ``CosineSimilarity``
        (margins 1.5 / 0.6), each without an explicit reducer and with
        ``MeanReducer``. The reference averages the positive-pair and
        negative-pair terms separately and adds the two averages. Without an
        explicit reducer each average is taken over the non-zero terms only;
        with ``MeanReducer`` it is taken over all pairs.
        """
        # Order matters: it lines up with the `expected` list built below.
        loss_fns = [
            ContrastiveLoss(pos_margin=0.25,
                            neg_margin=1.5,
                            distance=LpDistance(power=2)),
            ContrastiveLoss(pos_margin=1.5,
                            neg_margin=0.6,
                            distance=CosineSimilarity()),
            ContrastiveLoss(pos_margin=0.25,
                            neg_margin=1.5,
                            distance=LpDistance(power=2),
                            reducer=MeanReducer()),
            ContrastiveLoss(pos_margin=1.5,
                            neg_margin=0.6,
                            distance=CosineSimilarity(),
                            reducer=MeanReducer()),
        ]
        angles = [0, 20, 40, 60, 80]
        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4), (1, 2),
            (1, 3), (1, 4), (2, 0), (2, 1),
            (2, 4), (3, 0), (3, 1), (3, 4),
            (4, 0), (4, 1), (4, 2), (4, 3),
        ]

        def squared_l2(u, v):
            return torch.sum((u - v)**2)

        def avg_non_zero(terms):
            # Average over the non-zero terms; with none, keep the raw sum.
            total = sum(terms)
            active = sum(1 for term in terms if term > 0)
            return total / active if active > 0 else total

        def mean(terms):
            return sum(terms) / len(terms)

        for dtype in TEST_DTYPES:
            embeddings = torch.tensor(
                [c_f.angle_to_coord(angle) for angle in angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            actual = [loss_fn(embeddings, labels) for loss_fn in loss_fns]

            pos_distance_terms = []
            pos_similarity_terms = []
            for a, p in pos_pairs:
                anchor, positive = embeddings[a], embeddings[p]
                pos_distance_terms.append(
                    torch.relu(squared_l2(anchor, positive) - 0.25))
                pos_similarity_terms.append(
                    torch.relu(1.5 - torch.matmul(anchor, positive)))

            neg_distance_terms = []
            neg_similarity_terms = []
            for a, n in neg_pairs:
                anchor, negative = embeddings[a], embeddings[n]
                neg_distance_terms.append(
                    torch.relu(1.5 - squared_l2(anchor, negative)))
                neg_similarity_terms.append(
                    torch.relu(torch.matmul(anchor, negative) - 0.6))

            expected = [
                avg_non_zero(pos_distance_terms) +
                avg_non_zero(neg_distance_terms),
                avg_non_zero(pos_similarity_terms) +
                avg_non_zero(neg_similarity_terms),
                mean(pos_distance_terms) + mean(neg_distance_terms),
                mean(pos_similarity_terms) + mean(neg_similarity_terms),
            ]

            # Half precision needs a looser relative tolerance.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            for got, want in zip(actual, expected):
                self.assertTrue(torch.isclose(got, want, rtol=rtol))
    def test_contrastive_loss(self):
        """Compare ContrastiveLoss with a hand-computed reference.

        Four configurations are checked on float embeddings: squared
        distances (margins 0.25 / 1.5) and similarities (margins 1.5 / 0.6),
        each without an explicit reducer and with ``MeanReducer``.
        ``backward()`` is called on every loss. The reference averages the
        positive-pair and negative-pair terms separately and adds the two
        averages. Without an explicit reducer each average is taken over the
        non-zero terms only; with ``MeanReducer`` it is taken over all pairs.
        """
        # Order matters: it lines up with the `expected` list built below.
        loss_fns = [
            ContrastiveLoss(pos_margin=0.25,
                            neg_margin=1.5,
                            use_similarity=False,
                            squared_distances=True),
            ContrastiveLoss(pos_margin=1.5,
                            neg_margin=0.6,
                            use_similarity=True),
            ContrastiveLoss(pos_margin=0.25,
                            neg_margin=1.5,
                            use_similarity=False,
                            squared_distances=True,
                            reducer=MeanReducer()),
            ContrastiveLoss(pos_margin=1.5,
                            neg_margin=0.6,
                            use_similarity=True,
                            reducer=MeanReducer()),
        ]

        angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=torch.float)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        # Compute all four losses first, then back-propagate each in turn.
        actual = [loss_fn(embeddings, labels) for loss_fn in loss_fns]
        for loss_value in actual:
            loss_value.backward()

        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4), (1, 2),
            (1, 3), (1, 4), (2, 0), (2, 1),
            (2, 4), (3, 0), (3, 1), (3, 4),
            (4, 0), (4, 1), (4, 2), (4, 3),
        ]

        def squared_l2(u, v):
            return torch.sum((u - v)**2)

        def avg_non_zero(terms):
            # Average over the non-zero terms; with none, keep the raw sum.
            total = sum(terms)
            active = sum(1 for term in terms if term > 0)
            return total / active if active > 0 else total

        def mean(terms):
            return sum(terms) / len(terms)

        pos_distance_terms = []
        pos_similarity_terms = []
        for a, p in pos_pairs:
            anchor, positive = embeddings[a], embeddings[p]
            pos_distance_terms.append(
                torch.relu(squared_l2(anchor, positive) - 0.25))
            pos_similarity_terms.append(
                torch.relu(1.5 - torch.matmul(anchor, positive)))

        neg_distance_terms = []
        neg_similarity_terms = []
        for a, n in neg_pairs:
            anchor, negative = embeddings[a], embeddings[n]
            neg_distance_terms.append(
                torch.relu(1.5 - squared_l2(anchor, negative)))
            neg_similarity_terms.append(
                torch.relu(torch.matmul(anchor, negative) - 0.6))

        expected = [
            avg_non_zero(pos_distance_terms) +
            avg_non_zero(neg_distance_terms),
            avg_non_zero(pos_similarity_terms) +
            avg_non_zero(neg_similarity_terms),
            mean(pos_distance_terms) + mean(neg_distance_terms),
            mean(pos_similarity_terms) + mean(neg_similarity_terms),
        ]

        for got, want in zip(actual, expected):
            self.assertTrue(torch.isclose(got, want))
    def test_per_anchor_reducer(self):
        """Check ``PerAnchorReducer`` against a manual per-anchor average.

        For each wrapped reducer (``MeanReducer``, ``AvgNonZeroReducer``) and
        each dtype in ``TEST_DTYPES``:

        * "triplet" losses must raise ``NotImplementedError``;
        * "element" losses must match the wrapped reducer applied directly;
        * pair losses must match the wrapped reducer applied to a vector that
          holds, for each sample, the mean loss of the pairs anchored at that
          sample (0 when the sample anchors no pair).
        """
        batch_size, embedding_size = 100, 64

        def element_loss_dict(per_sample_losses, embeddings):
            # Package one loss per sample in the "element" layout.
            return {
                "loss": {
                    "losses": per_sample_losses,
                    "indices": c_f.torch_arange_from_size(embeddings),
                    "reduction_type": "element",
                }
            }

        def mean_loss_per_anchor(pair_losses, anchors, num_anchors, dtype):
            # Samples that anchor no pair keep their initial value of 0.
            per_anchor = torch.zeros(num_anchors,
                                     device=TEST_DEVICE,
                                     dtype=dtype)
            for anchor_idx in range(num_anchors):
                belongs_to_anchor = anchors == anchor_idx
                pair_count = torch.sum(belongs_to_anchor)
                if pair_count > 0:
                    per_anchor[anchor_idx] = (
                        torch.sum(pair_losses[belongs_to_anchor]) / pair_count)
            return per_anchor

        for inner_reducer in [MeanReducer(), AvgNonZeroReducer()]:
            reducer = PerAnchorReducer(inner_reducer)
            for dtype in TEST_DTYPES:
                embeddings = torch.randn(batch_size, embedding_size)
                embeddings = embeddings.type(dtype).to(TEST_DEVICE)
                labels = torch.randint(0, 10, (batch_size,))
                cases = [
                    ("element", torch.arange(batch_size)),
                    ("pos_pair", lmu.get_all_pairs_indices(labels)[:2]),
                    ("neg_pair", lmu.get_all_pairs_indices(labels)[2:]),
                    ("triplet", lmu.get_all_triplets_indices(labels)),
                ]

                for reduction_type, indices in cases:
                    # Element indices are a single tensor; pair and triplet
                    # indices are tuples of equally long tensors.
                    if reduction_type == "element":
                        num_losses = len(indices)
                    else:
                        num_losses = len(indices[0])
                    losses = torch.randn(num_losses).type(dtype).to(TEST_DEVICE)
                    loss_dict = {
                        "loss": {
                            "losses": losses,
                            "indices": indices,
                            "reduction_type": reduction_type,
                        }
                    }

                    if reduction_type == "triplet":
                        with self.assertRaises(NotImplementedError):
                            reducer(loss_dict, embeddings, labels)
                        continue

                    output = reducer(loss_dict, embeddings, labels)

                    if reduction_type == "element":
                        per_sample_losses = losses
                    else:
                        per_sample_losses = mean_loss_per_anchor(
                            losses, indices[0], len(embeddings), dtype)
                    correct_output = inner_reducer(
                        element_loss_dict(per_sample_losses, embeddings),
                        embeddings,
                        labels,
                    )
                    self.assertTrue(torch.isclose(output, correct_output))