class ReferenceDistanceLoss(torch.nn.Module):
    """Distance to positive references minus distance to negative references.

    The loss is ``d(embeddings, pos_embs) - d(embeddings, neg_embs)`` where
    ``d`` is ``LpDistance(p=2).pairwise_distance``. A reference set that is
    ``None`` contributes zero to the loss.

    Args:
        pos_embs: Optional tensor of positive reference embeddings. A
            detached clone is stored, so later changes to the argument do
            not affect the loss.
        neg_embs: Optional tensor of negative reference embeddings, stored
            the same way.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.

    NOTE(review): the references are plain attributes, not registered
    buffers, so ``module.to(device)`` does not move them.
    """

    def __init__(self, pos_embs=None, neg_embs=None, **kwargs):
        super().__init__(**kwargs)
        self.pos_embs = (pos_embs.detach().clone()
                         if pos_embs is not None else None)
        self.neg_embs = (neg_embs.detach().clone()
                         if neg_embs is not None else None)
        self.distance = LpDistance(p=2)

    def forward(self, embeddings, *args, reduction="mean"):
        """Compute the reference-distance loss for ``embeddings``.

        Args:
            embeddings: Tensor of embeddings. Its shape is assumed to be
                compatible with the stored references for
                ``LpDistance.pairwise_distance`` -- TODO confirm.
            *args: Ignored; accepted so extra positional arguments (e.g.
                labels) do not break the call.
            reduction: ``"mean"`` to average the loss, ``"none"`` to return
                it unreduced.

        Returns:
            The reduced or unreduced loss tensor. For an empty batch a
            scalar zero that is still connected to ``embeddings``.

        Raises:
            ValueError: If ``reduction`` is neither ``"mean"`` nor ``"none"``.
        """
        # Validate up front so a bad argument fails before any work is done.
        if reduction not in ("mean", "none"):
            raise ValueError(f"unknown reduction: {reduction}")

        if len(embeddings) == 0:
            # torch.nn.Module has no ``zero_losses``; build the zero from the
            # input so device, dtype and the autograd graph are preserved.
            return torch.sum(embeddings * 0)

        # Use a zero with the embeddings' dtype/device: an integer
        # ``torch.tensor(0)`` placeholder makes ``.mean()`` fail when no
        # reference set is configured.
        zero = embeddings.new_zeros(())

        pos_loss = zero
        if self.pos_embs is not None:
            pos_loss = self.distance.pairwise_distance(embeddings,
                                                       self.pos_embs)

        neg_loss = zero
        if self.neg_embs is not None:
            neg_loss = self.distance.pairwise_distance(embeddings,
                                                       self.neg_embs)

        loss = pos_loss - neg_loss
        return loss.mean() if reduction == "mean" else loss
 def __init__(self, pos_embs=None, neg_embs=None, **kwargs):
     """Store detached copies of the reference embeddings.

     Args:
         pos_embs: Optional positive reference tensor.
         neg_embs: Optional negative reference tensor.
         **kwargs: Forwarded to the parent initializer.
     """
     super().__init__(**kwargs)

     def _frozen_copy(reference):
         # Detach and clone so the stored tensor is independent of the input.
         if reference is None:
             return None
         return reference.detach().clone()

     self.pos_embs = _frozen_copy(pos_embs)
     self.neg_embs = _frozen_copy(neg_embs)
     self.distance = LpDistance(p=2)
コード例 #3
0
    def test_uniform_histogram_miner(self):
        """Every histogram bin of mined distances holds the requested number
        of pairs or is empty, for several distance functions and dtypes."""
        torch.manual_seed(93612)
        batch_size, embedding_size = 128, 32
        num_bins, pos_per_bin, neg_per_bin = 100, 25, 123

        def _histogram(mined_dists, all_dists):
            # Bin the mined distances over the range spanned by all pairs.
            return torch.histc(
                mined_dists,
                bins=num_bins,
                min=torch.min(all_dists),
                max=torch.max(all_dists),
            )

        distances = [
            LpDistance(p=1),
            LpDistance(p=2),
            LpDistance(normalize_embeddings=False),
            SNRDistance(),
        ]
        for distance in distances:
            miner = UniformHistogramMiner(
                num_bins=num_bins,
                pos_per_bin=pos_per_bin,
                neg_per_bin=neg_per_bin,
                distance=distance,
            )
            for dtype in TEST_DTYPES:
                embeddings = torch.randn(
                    batch_size, embedding_size, device=TEST_DEVICE, dtype=dtype
                )
                labels = torch.randint(
                    0, 2, size=(batch_size, ), device=TEST_DEVICE
                )

                # Distances of every possible positive / negative pair.
                all_a1, all_p, all_a2, all_n = lmu.get_all_pairs_indices(labels)
                dist_mat = distance(embeddings)
                all_pos_dists = dist_mat[all_a1, all_p]
                all_neg_dists = dist_mat[all_a2, all_n]

                mined_a1, mined_p, mined_a2, mined_n = miner(embeddings, labels)

                if dtype == torch.float16:
                    continue  # histc doesn't work for Half tensor

                pos_histogram = _histogram(dist_mat[mined_a1, mined_p],
                                           all_pos_dists)
                neg_histogram = _histogram(dist_mat[mined_a2, mined_n],
                                           all_neg_dists)

                pos_ok = (pos_histogram == pos_per_bin) | (pos_histogram == 0)
                self.assertTrue(torch.all(pos_ok))
                neg_ok = (neg_histogram == neg_per_bin) | (neg_histogram == 0)
                self.assertTrue(torch.all(neg_ok))
コード例 #4
0
 def setUpClass(self):
     """Create BatchHardMiner fixtures and the expected triplet indices.

     NOTE(review): no ``@classmethod`` decorator is visible in this excerpt
     and the device is hard-coded to CUDA -- confirm both against the
     full file.
     """
     self.device = torch.device('cuda')

     def _hard_miner(distance):
         return BatchHardMiner(distance=distance)

     self.dist_miner = _hard_miner(LpDistance(normalize_embeddings=False))
     self.normalized_dist_miner = _hard_miner(
         LpDistance(normalize_embeddings=True))
     self.normalized_dist_miner_squared = _hard_miner(
         LpDistance(normalize_embeddings=True, power=2))
     self.sim_miner = _hard_miner(CosineSimilarity())

     def _indices(values):
         # Expected index tensors live on the test device.
         return torch.LongTensor(values).to(self.device)

     self.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])
     self.correct_a = _indices([0, 1, 2, 3, 4, 6, 7, 8])
     self.correct_p = _indices([4, 4, 8, 8, 0, 2, 2, 2])
     # Two alternative negative assignments are stored.
     self.correct_n = [
         _indices([2, 2, 1, 4, 3, 5, 5, 5]),
         _indices([2, 2, 1, 4, 5, 5, 5, 5]),
     ]
コード例 #5
0
 def test_with_no_valid_pairs(self):
     """NTXentLoss and SupConLoss must be zero (and still support backward)
     for each of the three small label configurations below."""
     temperature = 0.1
     cases = [
         ([0], torch.LongTensor([0])),
         ([0, 10, 20], torch.LongTensor([0, 0, 0])),
         ([0, 40, 60], torch.LongTensor([1, 2, 3])),
     ]
     for loss_class in [NTXentLoss, SupConLoss]:
         loss_funcs = [
             loss_class(temperature),
             loss_class(temperature, distance=LpDistance()),
         ]
         for loss_func in loss_funcs:
             for dtype in TEST_DTYPES:
                 for embedding_angles, labels in cases:
                     coords = [c_f.angle_to_coord(a) for a in embedding_angles]
                     # 2D embeddings
                     embeddings = torch.tensor(
                         coords, requires_grad=True, dtype=dtype
                     ).to(TEST_DEVICE)
                     loss = loss_func(embeddings, labels)
                     loss.backward()
                     self.assertEqual(loss, 0)
コード例 #6
0
 def __init__(self, margin, normalize_embeddings):
     """Set up a triplet-margin loss with a hard-triplet miner.

     Args:
         margin: Margin passed to ``losses.TripletMarginLoss``.
         normalize_embeddings: Forwarded to ``LpDistance``.
     """
     self.margin = margin
     self.distance = LpDistance(
         normalize_embeddings=normalize_embeddings,
         collect_stats=True,
     )
     # The miner and the loss share the same distance object.
     self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
     loss_kwargs = {
         "margin": self.margin,
         "swap": True,
         "distance": self.distance,
         "reducer": reducers.AvgNonZeroReducer(collect_stats=True),
         "collect_stats": True,
     }
     self.loss_fn = losses.TripletMarginLoss(**loss_kwargs)
コード例 #7
0
 def __init__(self, pos_margin, neg_margin, normalize_embeddings):
     """Set up a contrastive loss with a hard-triplet miner.

     Args:
         pos_margin: Positive margin passed to ``losses.ContrastiveLoss``.
         neg_margin: Negative margin passed to ``losses.ContrastiveLoss``.
         normalize_embeddings: Forwarded to ``LpDistance``.

     NOTE(review): the original comment said the distance is *squared*
     Euclidean, but ``LpDistance`` is built without a ``power`` argument
     here -- confirm which is intended.
     """
     self.pos_margin, self.neg_margin = pos_margin, neg_margin
     self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
     self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
     loss_kwargs = {
         "pos_margin": self.pos_margin,
         "neg_margin": self.neg_margin,
         "distance": self.distance,
     }
     self.loss_fn = losses.ContrastiveLoss(**loss_kwargs)
コード例 #8
0
 def __init__(self, margin, normalize_embeddings):
     """Set up a triplet-margin loss with a hard-triplet miner.

     Args:
         margin: Margin passed to ``losses.TripletMarginLoss``.
         normalize_embeddings: Stored on the instance and forwarded to
             ``LpDistance``.
     """
     self.margin = margin
     self.normalize_embeddings = normalize_embeddings
     self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
     # The miner and the loss share the same distance object.
     self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
     loss_kwargs = {
         "margin": self.margin,
         "swap": True,
         "distance": self.distance,
     }
     self.loss_fn = losses.TripletMarginLoss(**loss_kwargs)
コード例 #9
0
    def test_pair_margin_miner(self):
        """Compare PairMarginMiner output with a brute-force selection over
        a grid of positive and negative margins."""
        for dtype in TEST_DTYPES:
            for distance in [LpDistance(), CosineSimilarity()]:
                embedding_angles = torch.arange(0, 16)
                # 2D embeddings
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in embedding_angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(self.device)
                labels = torch.randint(low=0, high=2, size=(16,))
                mat = distance(embeddings)

                # Enumerate every ordered pair once and split by label match.
                num_embeddings = len(embeddings)
                pos_pairs, neg_pairs = [], []
                for i in range(num_embeddings):
                    for j in range(num_embeddings):
                        if i == j:
                            continue
                        entry = (i, j, mat[i, j])
                        if labels[j] == labels[i]:
                            pos_pairs.append(entry)
                        else:
                            neg_pairs.append(entry)

                for pos_margin_int in range(-1, 4):
                    pos_margin = float(pos_margin_int) * 0.05
                    for neg_margin_int in range(2, 7):
                        neg_margin = float(neg_margin_int) * 0.05
                        miner = PairMarginMiner(
                            pos_margin, neg_margin, distance=distance
                        )

                        # The comparisons flip when the distance is inverted.
                        if distance.is_inverted:
                            correct_pos = {(i, j) for i, j, value in pos_pairs
                                           if value < pos_margin}
                            correct_neg = {(i, j) for i, j, value in neg_pairs
                                           if value > neg_margin}
                        else:
                            correct_pos = {(i, j) for i, j, value in pos_pairs
                                           if value > pos_margin}
                            correct_neg = {(i, j) for i, j, value in neg_pairs
                                           if value < neg_margin}

                        a1, p1, a2, n2 = miner(embeddings, labels)
                        mined_pos = {(a.item(), p.item())
                                     for a, p in zip(a1, p1)}
                        mined_neg = {(a.item(), n.item())
                                     for a, n in zip(a2, n2)}

                        self.assertTrue(mined_pos == correct_pos)
                        self.assertTrue(mined_neg == correct_neg)
コード例 #10
0
    def test_ntxent_loss(self):
        """Compare NTXentLoss (default distance and LpDistance) with a
        hand-computed reference on five 2D embeddings."""
        temperature = 0.1
        loss_funcA = NTXentLoss(temperature=temperature)
        loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())

        def sim_term(x, y):
            # exp(dot product / temperature)
            return torch.exp(torch.matmul(x, y) / temperature)

        def dist_term(x, y):
            # exp(-Euclidean distance / temperature)
            return torch.exp(-torch.sqrt(torch.sum((x - y)**2)) / temperature)

        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
                     (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
                     (4, 0), (4, 1), (4, 2), (4, 3)]

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            coords = [c_f.angle_to_coord(a) for a in embedding_angles]
            # 2D embeddings
            embeddings = torch.tensor(
                coords, requires_grad=True, dtype=dtype).to(self.device)

            labels = torch.LongTensor([0, 0, 1, 1, 2])

            lossA = loss_funcA(embeddings, labels)
            lossB = loss_funcB(embeddings, labels)

            total_lossA, total_lossB = 0, 0
            for a1, p in pos_pairs:
                anchor, positive = embeddings[a1], embeddings[p]
                numeratorA = sim_term(anchor, positive)
                numeratorB = dist_term(anchor, positive)
                # The denominator is the positive term plus every negative
                # term that shares this anchor.
                denominatorA = numeratorA.clone()
                denominatorB = numeratorB.clone()
                for a2, n in neg_pairs:
                    if a2 != a1:
                        continue
                    negative = embeddings[n]
                    denominatorA += sim_term(anchor, negative)
                    denominatorB += dist_term(anchor, negative)
                total_lossA += -torch.log(numeratorA / denominatorA)
                total_lossB += -torch.log(numeratorB / denominatorB)

            total_lossA /= len(pos_pairs)
            total_lossB /= len(pos_pairs)
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            for obtained, expected in ((lossA, total_lossA),
                                       (lossB, total_lossB)):
                self.assertTrue(torch.isclose(obtained, expected, rtol=rtol))
コード例 #11
0
ファイル: loss.py プロジェクト: jac99/MinkLoc3D
 def __init__(self, pos_margin, neg_margin, normalize_embeddings):
     """Set up a stats-collecting contrastive loss with a hard-triplet miner.

     Args:
         pos_margin: Positive margin passed to ``losses.ContrastiveLoss``.
         neg_margin: Negative margin passed to ``losses.ContrastiveLoss``.
         normalize_embeddings: Forwarded to ``LpDistance``.

     NOTE(review): the original comment said the distance is *squared*
     Euclidean, but ``LpDistance`` is built without a ``power`` argument
     here -- confirm which is intended.
     """
     self.pos_margin, self.neg_margin = pos_margin, neg_margin
     self.distance = LpDistance(
         normalize_embeddings=normalize_embeddings,
         collect_stats=True,
     )
     self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
     loss_kwargs = {
         "pos_margin": self.pos_margin,
         "neg_margin": self.neg_margin,
         "distance": self.distance,
         "reducer": reducers.AvgNonZeroReducer(collect_stats=True),
         "collect_stats": True,
     }
     self.loss_fn = losses.ContrastiveLoss(**loss_kwargs)
コード例 #12
0
 def setUpClass(self):
     """Create HDCMiner fixtures and the expected positive/negative pairs.

     NOTE(review): no ``@classmethod`` decorator is visible in this excerpt
     and the device is hard-coded to CUDA -- confirm both against the
     full file.
     """
     self.device = torch.device('cuda')

     def _hdc_miner(distance):
         # Every miner in this fixture uses the same filter percentage.
         return HDCMiner(filter_percentage=0.3, distance=distance)

     self.dist_miner = _hdc_miner(LpDistance(normalize_embeddings=False))
     self.normalized_dist_miner = _hdc_miner(
         LpDistance(normalize_embeddings=True))
     self.normalized_dist_miner_squared = _hdc_miner(
         LpDistance(normalize_embeddings=True, power=2))
     self.sim_miner = _hdc_miner(CosineSimilarity())

     self.labels = torch.LongTensor([0, 0, 1, 1, 1, 0])

     def _pair_tensor(first, second):
         # Stack two index lists into rows of (first, second) on the device.
         columns = [torch.LongTensor(first), torch.LongTensor(second)]
         return torch.stack(columns, dim=1).to(self.device)

     self.correct_pos_pairs = _pair_tensor([0, 5, 1, 5], [5, 0, 5, 1])
     self.correct_neg_pairs = _pair_tensor([1, 2, 4, 5, 0, 2],
                                           [2, 1, 5, 4, 2, 0])
コード例 #13
0
    def test_logit_getter(self):
        """Exercise LogitGetter: shape inference, layer-name lookup, custom
        distance, transpose handling and weight copying."""
        embedding_size, num_classes, batch_size = 512, 10, 32

        for dtype in TEST_DTYPES:
            embeddings = torch.randn(batch_size, embedding_size)
            embeddings = embeddings.to(TEST_DEVICE).type(dtype)

            kwargs = dict(num_classes=num_classes, embedding_size=embedding_size)
            loss1, loss2, loss3 = [
                loss_cls(**kwargs).to(TEST_DEVICE).type(dtype)
                for loss_cls in (ArcFaceLoss, NormalizedSoftmaxLoss,
                                 ProxyAnchorLoss)
            ]
            common = (embeddings, batch_size, num_classes)

            # test the ability to infer shape
            for loss in [loss1, loss2, loss3]:
                self.helper_tester(loss, *common)

            # test specifying wrong layer name
            with self.assertRaises(AttributeError):
                LogitGetter(loss1, layer_name="blah")

            # test specifying correct layer name
            self.helper_tester(loss1, *common, layer_name="W")

            # test specifying a distance metric
            self.helper_tester(loss1, *common, distance=LpDistance())

            # test specifying transpose incorrectly
            LG = LogitGetter(loss1, transpose=False)
            with self.assertRaises(RuntimeError):
                LG(embeddings)

            # test specifying transpose correctly
            self.helper_tester(loss1, *common, transpose=True)

            # test copying weights: zeroing the loss weights afterwards must
            # make the two diverge
            LG = LogitGetter(loss1)
            self.assertTrue(torch.all(LG.weights == loss1.W))
            loss1.W.data *= 0
            self.assertFalse(torch.all(LG.weights == loss1.W))

            # test not copying weights: the two stay equal after the change
            LG = LogitGetter(loss1, copy_weights=False)
            self.assertTrue(torch.all(LG.weights == loss1.W))
            loss1.W.data *= 0
            self.assertTrue(torch.all(LG.weights == loss1.W))
コード例 #14
0
 def test_backward(self):
     """Smoke test: NTXentLoss (default distance and LpDistance) supports
     ``backward()`` for every test dtype."""
     temperature = 0.1
     loss_funcs = [
         NTXentLoss(temperature=temperature),
         NTXentLoss(temperature=temperature, distance=LpDistance()),
     ]
     for dtype in TEST_DTYPES:
         for loss_func in loss_funcs:
             embedding_angles = [0, 20, 40, 60, 80]
             coords = [c_f.angle_to_coord(a) for a in embedding_angles]
             # 2D embeddings
             embeddings = torch.tensor(
                 coords, requires_grad=True, dtype=dtype).to(self.device)
             labels = torch.LongTensor([0, 0, 1, 1, 2])
             loss_func(embeddings, labels).backward()
コード例 #15
0
ファイル: test_criterion.py プロジェクト: AlexSchuy/hgcal-dev
def normal_loss(c, scale):
    """Print the triplet-margin loss for two 5-point clusters of 2D samples.

    Args:
        c: ``loc`` of the normal distribution the samples are drawn from.
        scale: ``scale`` of the normal distribution.
    """
    criterion = TripletMarginLoss(
        triplets_per_anchor=1,
        distance=LpDistance(normalize_embeddings=False, p=1),
    )
    l_rv = stats.norm(loc=c, scale=scale)
    # NOTE(review): the zero-centred distribution below is replaced on the
    # next line, so both clusters are drawn from ``l_rv`` -- confirm intent.
    r_rv = stats.norm(scale=scale)
    r_rv = l_rv
    left = l_rv.rvs(10).reshape((5, 2))
    right = r_rv.rvs(10).reshape((5, 2))
    points = np.concatenate((left, right))
    df = pd.DataFrame({
        'x': points[:, 0],
        'y': points[:, 1],
        # First five samples are class 0, last five are class 1.
        'label': np.concatenate((np.zeros(5), np.ones(5))),
    })
    embeddings = torch.as_tensor(points)
    labels = torch.as_tensor(df['label'])
    loss = criterion(embeddings, labels)
    print(f'center = {c}, loss = {loss}')
コード例 #16
0
 def get_pos_neg_vals(self, use_pairwise):
     """Return a ``(pos_val, neg_val, distance)`` triple.

     ``(0, 1, LpDistance(power=2))`` when ``use_pairwise`` is truthy,
     otherwise ``(1, 0, CosineSimilarity())``.
     """
     pairwise_vals = (0, 1, LpDistance(power=2))
     return pairwise_vals if use_pairwise else (1, 0, CosineSimilarity())
コード例 #17
0
    def test_multi_similarity_miner(self):
        """Compare MultiSimilarityMiner output with a brute-force reference.

        For each dtype and distance, every ordered pair is classified by
        label, the expected positive and negative pairs are selected with
        explicit loops, and the resulting sets must equal the mined sets.
        """
        epsilon = 0.1
        for dtype in TEST_DTYPES:
            for distance in [CosineSimilarity(), LpDistance()]:
                miner = MultiSimilarityMiner(epsilon, distance=distance)
                embedding_angles = torch.arange(0, 64)
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in embedding_angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)  # 2D embeddings
                labels = torch.randint(low=0, high=10, size=(64, ))
                mat = distance(embeddings)
                # Collect (anchor, other, value) for every ordered pair,
                # split by whether the two labels match.
                pos_pairs = []
                neg_pairs = []
                for i in range(len(embeddings)):
                    anchor_label = labels[i]
                    for j in range(len(embeddings)):
                        if j != i:
                            other_label = labels[j]
                            if anchor_label == other_label:
                                pos_pairs.append((i, j, mat[i, j]))
                            if anchor_label != other_label:
                                neg_pairs.append((i, j, mat[i, j]))

                correct_a1, correct_p = [], []
                correct_a2, correct_n = [], []
                # Expected positives: for each positive pair, find the
                # extreme negative value of the same anchor (largest when
                # the distance is inverted, smallest otherwise) and keep the
                # pair if it lies within epsilon of that value.
                for a1, p, ap_sim in pos_pairs:
                    most_difficult = (c_f.neg_inf(dtype)
                                      if distance.is_inverted else
                                      c_f.pos_inf(dtype))
                    for a2, n, an_sim in neg_pairs:
                        if a2 == a1:
                            condition = ((an_sim > most_difficult)
                                         if distance.is_inverted else
                                         (an_sim < most_difficult))
                            if condition:
                                most_difficult = an_sim
                    condition = ((ap_sim < most_difficult + epsilon)
                                 if distance.is_inverted else
                                 (ap_sim > most_difficult - epsilon))
                    if condition:
                        correct_a1.append(a1)
                        correct_p.append(p)

                # Expected negatives: the mirrored rule, using the extreme
                # positive value of the same anchor (smallest when the
                # distance is inverted, largest otherwise).
                for a2, n, an_sim in neg_pairs:
                    most_difficult = (c_f.pos_inf(dtype)
                                      if distance.is_inverted else
                                      c_f.neg_inf(dtype))
                    for a1, p, ap_sim in pos_pairs:
                        if a2 == a1:
                            condition = ((ap_sim < most_difficult)
                                         if distance.is_inverted else
                                         (ap_sim > most_difficult))
                            if condition:
                                most_difficult = ap_sim
                    condition = ((an_sim > most_difficult - epsilon)
                                 if distance.is_inverted else
                                 (an_sim < most_difficult + epsilon))
                    if condition:
                        correct_a2.append(a2)
                        correct_n.append(n)

                correct_pos_pairs = set([
                    (a, p) for a, p in zip(correct_a1, correct_p)
                ])
                correct_neg_pairs = set([
                    (a, n) for a, n in zip(correct_a2, correct_n)
                ])

                # From here on pos_pairs / neg_pairs are rebound to the sets
                # of index pairs returned by the miner.
                a1, p1, a2, n2 = miner(embeddings, labels)
                pos_pairs = set([(a.item(), p.item()) for a, p in zip(a1, p1)])
                neg_pairs = set([(a.item(), n.item()) for a, n in zip(a2, n2)])

                self.assertTrue(pos_pairs == correct_pos_pairs)
                self.assertTrue(neg_pairs == correct_neg_pairs)
コード例 #18
0
    def test_triplet_margin_miner(self):
        """Compare TripletMarginMiner output with brute-force triplet sets.

        All (anchor, positive, negative) triplets are enumerated together
        with a distance difference, then for a range of margins the expected
        sets for the "all", "hard", "semihard" and "easy" modes are built and
        compared with what the miners return.
        """
        for dtype in TEST_DTYPES:
            for distance in [LpDistance(), CosineSimilarity()]:
                embedding_angles = torch.arange(0, 16)
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in embedding_angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(self.device)  # 2D embeddings
                labels = torch.randint(low=0, high=2, size=(16, ))
                mat = distance(embeddings)
                # Enumerate every triplet as (i, j, k, difference), where the
                # difference is (ap - an) for inverted distances and
                # (an - ap) otherwise.
                triplets = []
                for i in range(len(embeddings)):
                    anchor_label = labels[i]
                    for j in range(len(embeddings)):
                        if j == i:
                            continue
                        positive_label = labels[j]
                        if positive_label == anchor_label:
                            ap_dist = mat[i, j]
                            for k in range(len(embeddings)):
                                if k == j or k == i:
                                    continue
                                negative_label = labels[k]
                                if negative_label != positive_label:
                                    an_dist = mat[i, k]
                                    if distance.is_inverted:
                                        triplets.append(
                                            (i, j, k, ap_dist - an_dist))
                                    else:
                                        triplets.append(
                                            (i, j, k, an_dist - ap_dist))

                # Margins from -0.05 to 0.50 in steps of 0.05.
                for margin_int in range(-1, 11):
                    margin = float(margin_int) * 0.05
                    minerA = TripletMarginMiner(margin,
                                                type_of_triplets="all",
                                                distance=distance)
                    minerB = TripletMarginMiner(margin,
                                                type_of_triplets="hard",
                                                distance=distance)
                    minerC = TripletMarginMiner(margin,
                                                type_of_triplets="semihard",
                                                distance=distance)
                    minerD = TripletMarginMiner(margin,
                                                type_of_triplets="easy",
                                                distance=distance)

                    # Expected sets, by difference d:
                    #   D (easy):     d > margin
                    #   A (all):      d <= margin
                    #   C (semihard): 0 < d <= margin
                    #   B (hard):     d <= margin and d <= 0
                    correctA, correctB, correctC, correctD = [], [], [], []
                    for i, j, k, distance_diff in triplets:
                        if distance_diff > margin:
                            correctD.append((i, j, k))
                        else:
                            correctA.append((i, j, k))
                            if distance_diff > 0:
                                correctC.append((i, j, k))
                            if distance_diff <= 0:
                                correctB.append((i, j, k))

                    for correct, miner in [
                        (correctA, minerA),
                        (correctB, minerB),
                        (correctC, minerC),
                        (correctD, minerD),
                    ]:
                        correct_triplets = set(correct)
                        a1, p1, n1 = miner(embeddings, labels)
                        mined_triplets = set([(a.item(), p.item(), n.item())
                                              for a, p, n in zip(a1, p1, n1)])
                        self.assertTrue(mined_triplets == correct_triplets)
コード例 #19
0
    def test_ntxent_loss(self):
        """Compare NTXentLoss and SupConLoss with hand-computed references.

        Five loss objects are checked on six 2D embeddings:
        A = NTXentLoss, B = NTXentLoss with LpDistance, C = NTXentLoss with
        PerAnchorReducer(AvgNonZeroReducer()), D = SupConLoss and
        E = SupConLoss with LpDistance.
        """
        temperature = 0.1
        loss_funcA = NTXentLoss(temperature=temperature)
        loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
        loss_funcC = NTXentLoss(
            temperature=temperature, reducer=PerAnchorReducer(AvgNonZeroReducer())
        )
        loss_funcD = SupConLoss(temperature=temperature)
        loss_funcE = SupConLoss(temperature=temperature, distance=LpDistance())

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 10, 20, 50, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(
                TEST_DEVICE
            )  # 2D embeddings

            labels = torch.LongTensor([0, 0, 0, 1, 1, 2])

            obtained_losses = [
                x(embeddings, labels)
                for x in [loss_funcA, loss_funcB, loss_funcC, loss_funcD, loss_funcE]
            ]

            # Ordered (anchor, other) index pairs with equal / different
            # labels, written out by hand for the labels above.
            pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)]
            neg_pairs = [
                (0, 3),
                (0, 4),
                (0, 5),
                (1, 3),
                (1, 4),
                (1, 5),
                (2, 3),
                (2, 4),
                (2, 5),
                (3, 0),
                (3, 1),
                (3, 2),
                (3, 5),
                (4, 0),
                (4, 1),
                (4, 2),
                (4, 5),
                (5, 0),
                (5, 1),
                (5, 2),
                (5, 3),
                (5, 4),
            ]

            # A and B accumulate a single scalar; C, D and E accumulate one
            # entry per anchor. Only anchors 0-4 appear in pos_pairs, hence
            # the length-5 accumulators.
            total_lossA, total_lossB, total_lossC, total_lossD, total_lossE = (
                0,
                0,
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
            )
            for a1, p in pos_pairs:
                anchor, positive = embeddings[a1], embeddings[p]
                # A-style term: exp(dot product / temperature).
                numeratorA = torch.exp(torch.matmul(anchor, positive) / temperature)
                # B-style term: exp(-Euclidean distance / temperature).
                numeratorB = torch.exp(
                    -torch.sqrt(torch.sum((anchor - positive) ** 2)) / temperature
                )
                # A/B denominators start from the current positive term and
                # add only the anchor's negatives; D/E denominators start at
                # zero and add every other pair of the anchor.
                denominatorA = numeratorA.clone()
                denominatorB = numeratorB.clone()
                denominatorD = 0
                denominatorE = 0
                for a2, n in pos_pairs + neg_pairs:
                    if a2 == a1:
                        negative = embeddings[n]
                        curr_denomD = torch.exp(
                            torch.matmul(anchor, negative) / temperature
                        )
                        curr_denomE = torch.exp(
                            -torch.sqrt(torch.sum((anchor - negative) ** 2))
                            / temperature
                        )
                        denominatorD += curr_denomD
                        denominatorE += curr_denomE
                        if (a2, n) not in pos_pairs:
                            denominatorA += curr_denomD
                            denominatorB += curr_denomE
                    else:
                        continue

                curr_lossA = -torch.log(numeratorA / denominatorA)
                curr_lossB = -torch.log(numeratorB / denominatorB)
                curr_lossD = -torch.log(numeratorA / denominatorD)
                curr_lossE = -torch.log(numeratorB / denominatorE)
                total_lossA += curr_lossA
                total_lossB += curr_lossB
                # C reuses the A-style per-pair loss, grouped by anchor.
                total_lossC[a1] += curr_lossA
                total_lossD[a1] += curr_lossD
                total_lossE[a1] += curr_lossE

            # A and B: mean over all positive pairs.
            total_lossA /= len(pos_pairs)
            total_lossB /= len(pos_pairs)
            # C, D and E: average per anchor over its positive pairs, then
            # mean over the five anchors.
            pos_pair_per_anchor = torch.tensor(
                [2, 2, 2, 1, 1], device=TEST_DEVICE, dtype=dtype
            )
            total_lossC, total_lossD, total_lossE = [
                torch.mean(x / pos_pair_per_anchor)
                for x in [total_lossC, total_lossD, total_lossE]
            ]

            # Looser tolerance for half precision.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(obtained_losses[0], total_lossA, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[1], total_lossB, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[2], total_lossC, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[3], total_lossD, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[4], total_lossE, rtol=rtol))
コード例 #20
0
 def setUpClass(self):
     """Build the shared labels, distance and ground-truth table for the miner tests.

     Sets four pieces of state on ``self``:

     * ``labels`` -- nine integer labels drawn from the classes 0, 1 and 2.
     * ``a1_idx``, ``p_idx``, ``a2_idx``, ``n_idx`` -- the four tensors returned
       by ``lmu.get_all_pairs_indices(labels)``.
     * ``distance`` -- an ``LpDistance`` created with
       ``normalize_embeddings=False``.
     * ``gt`` -- a dict keyed by scenario name. Every entry holds a
       ``BatchEasyHardMiner`` under ``"miner"``, six integer values
       (``easiest_/hardest_`` x ``triplet/pos_pair/neg_pair``) and an
       ``"expected"`` dict of index tensors.

     NOTE(review): no ``@classmethod`` decorator is visible in this excerpt.
     If this is the ``unittest.TestCase`` hook it needs one -- confirm against
     the full file.
     NOTE(review): the code that consumes ``gt`` is not visible here, so the
     exact meaning of the six integer values is not documented -- presumably
     they are the distance / margin values the miner is expected to report;
     verify against the test that reads them.
     """
     self.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])
     # Presumably (anchor, positive) and (anchor, negative) index pairs for
     # every valid pair in ``labels`` -- TODO confirm against lmu.
     self.a1_idx, self.p_idx, self.a2_idx, self.n_idx = lmu.get_all_pairs_indices(
         self.labels)
     self.distance = LpDistance(normalize_embeddings=False)
     # Layout of each "expected" dict:
     #   * "correct_a" is a single tensor; entries that use the ALL strategy
     #     carry "correct_a1" / "correct_a2" instead.
     #   * "correct_p" / "correct_n" are lists of tensors; where a list has
     #     several tensors they look like alternative acceptable results
     #     (e.g. ties) -- TODO confirm against the consuming test.
     self.gt = {
         "batch_semihard_hard": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.SEMIHARD,
                 neg_strategy=BatchEasyHardMiner.HARD,
             ),
             "easiest_triplet":
             -1,
             "hardest_triplet":
             -1,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             2,
             "easiest_neg_pair":
             3,
             "hardest_neg_pair":
             2,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 6, 6]).to(TEST_DEVICE),
                     torch.LongTensor([1, 8, 6]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_hard_semihard": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.HARD,
                 neg_strategy=BatchEasyHardMiner.SEMIHARD,
             ),
             "easiest_triplet":
             -1,
             "hardest_triplet":
             -1,
             "easiest_pos_pair":
             3,
             "hardest_pos_pair":
             6,
             "easiest_neg_pair":
             7,
             "hardest_neg_pair":
             4,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 1, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p":
                 [torch.LongTensor([4, 4, 2, 2, 2]).to(TEST_DEVICE)],
                 "correct_n": [
                     torch.LongTensor([5, 5, 1, 1, 1]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_easy_semihard": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.SEMIHARD,
             ),
             "easiest_triplet":
             -2,
             "hardest_triplet":
             -1,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             3,
             "easiest_neg_pair":
             4,
             "hardest_neg_pair":
             2,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8,
                                       7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6,
                                       7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 3, 0, 1, 8, 4, 5,
                                       5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 3, 4, 1, 8, 4, 5,
                                       5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 3, 0, 5, 8, 4, 5,
                                       5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 3, 4, 5, 8, 4, 5,
                                       5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_hard_hard": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.HARD,
                 neg_strategy=BatchEasyHardMiner.HARD,
             ),
             "easiest_triplet":
             2,
             "hardest_triplet":
             5,
             "easiest_pos_pair":
             3,
             "hardest_pos_pair":
             6,
             "easiest_neg_pair":
             3,
             "hardest_neg_pair":
             1,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([4, 4, 8, 8, 0, 2, 2,
                                       2]).to(TEST_DEVICE)
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 2, 1, 4, 3, 5, 5,
                                       5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 2, 1, 4, 5, 5, 5,
                                       5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_easy_hard": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.HARD,
             ),
             "easiest_triplet":
             -2,
             "hardest_triplet":
             2,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             3,
             "easiest_neg_pair":
             3,
             "hardest_neg_pair":
             1,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8,
                                       7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6,
                                       7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 2, 1, 4, 3, 5, 5,
                                       5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 2, 1, 4, 5, 5, 5,
                                       5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_hard_easy": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.HARD,
                 neg_strategy=BatchEasyHardMiner.EASY,
             ),
             "easiest_triplet":
             -4,
             "hardest_triplet":
             3,
             "easiest_pos_pair":
             3,
             "hardest_pos_pair":
             6,
             "easiest_neg_pair":
             8,
             "hardest_neg_pair":
             3,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([4, 4, 8, 8, 0, 2, 2,
                                       2]).to(TEST_DEVICE)
                 ],
                 "correct_n": [
                     torch.LongTensor([8, 8, 5, 0, 8, 0, 0,
                                       0]).to(TEST_DEVICE)
                 ],
             },
         },
         "batch_easy_easy": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.EASY,
             ),
             "easiest_triplet":
             -7,
             "hardest_triplet":
             -1,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             3,
             "easiest_neg_pair":
             8,
             "hardest_neg_pair":
             3,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8,
                                       7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6,
                                       7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([8, 8, 5, 0, 8, 0, 0,
                                       0]).to(TEST_DEVICE)
                 ],
             },
         },
         # Same strategies as "batch_easy_easy", but the miner is additionally
         # given allowed_pos_range / allowed_neg_range of [1, 7].
         "batch_easy_easy_with_min_val": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.EASY,
                 allowed_neg_range=[1, 7],
                 allowed_pos_range=[1, 7],
             ),
             "easiest_triplet":
             -6,
             "hardest_triplet":
             -1,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             3,
             "easiest_neg_pair":
             7,
             "hardest_neg_pair":
             3,
             "expected": {
                 "correct_a":
                 torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8,
                                       7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6,
                                       7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([7, 8, 5, 0, 8, 0, 0,
                                       1]).to(TEST_DEVICE)
                 ],
             },
         },
         # The remaining entries use the ALL strategy on at least one side.
         # Their "expected" dicts use "correct_a1" / "correct_a2" and reuse the
         # all-pairs indices computed at the top of this method.
         # NOTE(review): unlike the literal tensors, self.a1_idx / p_idx /
         # a2_idx / n_idx are not moved with .to(TEST_DEVICE) here -- confirm
         # the consuming test handles the device.
         "batch_easy_all": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.ALL,
             ),
             "easiest_triplet":
             0,
             "hardest_triplet":
             0,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             3,
             "easiest_neg_pair":
             8,
             "hardest_neg_pair":
             1,
             "expected": {
                 "correct_a1":
                 torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8,
                                       7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6,
                                       7]).to(TEST_DEVICE),
                 ],
                 "correct_a2":
                 self.a2_idx,
                 "correct_n": [self.n_idx],
             },
         },
         "batch_all_easy": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.ALL,
                 neg_strategy=BatchEasyHardMiner.EASY,
             ),
             "easiest_triplet":
             0,
             "hardest_triplet":
             0,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             6,
             "easiest_neg_pair":
             8,
             "hardest_neg_pair":
             3,
             "expected": {
                 "correct_a1":
                 self.a1_idx,
                 "correct_p": [self.p_idx],
                 "correct_a2":
                 torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7,
                                   8]).to(TEST_DEVICE),
                 "correct_n": [
                     torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0,
                                       0]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_all_all": {
             "miner":
             BatchEasyHardMiner(
                 distance=self.distance,
                 pos_strategy=BatchEasyHardMiner.ALL,
                 neg_strategy=BatchEasyHardMiner.ALL,
             ),
             "easiest_triplet":
             0,
             "hardest_triplet":
             0,
             "easiest_pos_pair":
             1,
             "hardest_pos_pair":
             6,
             "easiest_neg_pair":
             8,
             "hardest_neg_pair":
             1,
             "expected": {
                 "correct_a1": self.a1_idx,
                 "correct_p": [self.p_idx],
                 "correct_a2": self.a2_idx,
                 "correct_n": [self.n_idx],
             },
         },
     }
コード例 #21
0
    def test_contrastive_loss(self):
        """Compare ContrastiveLoss outputs with a manually accumulated reference.

        Four loss configurations are exercised, in this order:

        0. ``LpDistance(power=2)`` with margins 0.25 / 1.5, default reducer
        1. ``CosineSimilarity()`` with margins 1.5 / 0.6, default reducer
        2. same as 0, but with ``MeanReducer``
        3. same as 1, but with ``MeanReducer``

        The reference for 0 and 1 averages each of the positive and negative
        parts over the pairs with a non-zero loss only; the reference for
        2 and 3 averages over every pair.
        """
        loss_funcs = [
            ContrastiveLoss(pos_margin=0.25,
                            neg_margin=1.5,
                            distance=LpDistance(power=2)),
            ContrastiveLoss(pos_margin=1.5,
                            neg_margin=0.6,
                            distance=CosineSimilarity()),
            ContrastiveLoss(pos_margin=0.25,
                            neg_margin=1.5,
                            distance=LpDistance(power=2),
                            reducer=MeanReducer()),
            ContrastiveLoss(pos_margin=1.5,
                            neg_margin=0.6,
                            distance=CosineSimilarity(),
                            reducer=MeanReducer()),
        ]
        num_configs = len(loss_funcs)

        # Ordered (anchor, other) index pairs implied by labels [0, 0, 1, 1, 2].
        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4),
            (1, 2), (1, 3), (1, 4),
            (2, 0), (2, 1), (2, 4),
            (3, 0), (3, 1), (3, 4),
            (4, 0), (4, 1), (4, 2), (4, 3),
        ]

        for dtype in TEST_DTYPES:
            coords = [c_f.angle_to_coord(angle)
                      for angle in [0, 20, 40, 60, 80]]
            embeddings = torch.tensor(
                coords, requires_grad=True,
                dtype=dtype).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            obtained = [loss_func(embeddings, labels)
                        for loss_func in loss_funcs]

            # Running sums and non-zero counts, one slot per configuration.
            pos_totals = [0] * num_configs
            neg_totals = [0] * num_configs
            pos_active = [0] * num_configs
            neg_active = [0] * num_configs

            for a, p in pos_pairs:
                anchor = embeddings[a]
                positive = embeddings[p]
                lp_term = torch.relu(
                    torch.sum((anchor - positive)**2) - 0.25)
                cos_term = torch.relu(1.5 - torch.matmul(anchor, positive))
                # Slots 0/2 share the Lp term, slots 1/3 the cosine term.
                for slot, term in enumerate(
                        (lp_term, cos_term, lp_term, cos_term)):
                    pos_totals[slot] += term
                    if term > 0:
                        pos_active[slot] += 1

            for a, n in neg_pairs:
                anchor = embeddings[a]
                negative = embeddings[n]
                lp_term = torch.relu(1.5 -
                                     torch.sum((anchor - negative)**2))
                cos_term = torch.relu(torch.matmul(anchor, negative) - 0.6)
                for slot, term in enumerate(
                        (lp_term, cos_term, lp_term, cos_term)):
                    neg_totals[slot] += term
                    if term > 0:
                        neg_active[slot] += 1

            # Default reducer: divide by the number of non-zero pair losses.
            for slot in (0, 1):
                if pos_active[slot] > 0:
                    pos_totals[slot] /= pos_active[slot]
                if neg_active[slot] > 0:
                    neg_totals[slot] /= neg_active[slot]

            # MeanReducer: divide by the total number of pairs.
            for slot in (2, 3):
                pos_totals[slot] /= len(pos_pairs)
                neg_totals[slot] /= len(neg_pairs)

            expected = [pos_part + neg_part
                        for pos_part, neg_part in zip(pos_totals, neg_totals)]

            # float16 needs a looser tolerance.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            for got, want in zip(obtained, expected):
                self.assertTrue(torch.isclose(got, want, rtol=rtol))
コード例 #22
0
def triplet_margin_loss_factory(triplets_per_anchor, normalize_embeddings, p):
    """Create a TripletMarginLoss that measures distance with an LpDistance.

    Args:
        triplets_per_anchor: forwarded unchanged to ``TripletMarginLoss``.
        normalize_embeddings: forwarded unchanged to ``LpDistance``.
        p: forwarded unchanged to ``LpDistance``.

    Returns:
        The configured ``TripletMarginLoss`` instance.
    """
    lp_distance = LpDistance(normalize_embeddings=normalize_embeddings, p=p)
    return TripletMarginLoss(triplets_per_anchor=triplets_per_anchor,
                             distance=lp_distance)