コード例 #1
0
 @classmethod
 def setUpClass(cls):
     """Create the miners and reference index tensors shared by the class.

     unittest invokes ``setUpClass`` on the class itself with no arguments,
     so it has to be a classmethod; every fixture below is therefore stored
     as a class attribute.

     Fixtures:
         dist_miner, normalized_dist_miner, normalized_dist_miner_squared,
         sim_miner: ``PairMarginMiner`` instances built with different
             margins and distance/similarity settings.
         labels: class label for each of the nine samples.
         correct_a1 / correct_p: index pairs whose two samples share a
             label in ``labels``.
         correct_a2 / correct_n: index pairs whose two samples carry
             different labels in ``labels``.
     """
     # Unnormalized distances with wide margins.
     cls.dist_miner = PairMarginMiner(
         pos_margin=4,
         neg_margin=4,
         use_similarity=False,
         normalize_embeddings=False,
     )
     # Normalized embeddings, plain distances.
     cls.normalized_dist_miner = PairMarginMiner(
         pos_margin=1.29,
         neg_margin=1.28,
         use_similarity=False,
         normalize_embeddings=True,
     )
     # Normalized embeddings, squared distances.
     cls.normalized_dist_miner_squared = PairMarginMiner(
         pos_margin=1.66,
         neg_margin=1.64,
         use_similarity=False,
         normalize_embeddings=True,
         squared_distances=True,
     )
     # Similarity-based mining on normalized embeddings.
     cls.sim_miner = PairMarginMiner(
         pos_margin=0.17,
         neg_margin=0.18,
         use_similarity=True,
         normalize_embeddings=True,
     )

     cls.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])

     # Same-label pairs (all of them belong to label 1).
     cls.correct_a1 = torch.LongTensor([2, 2, 3, 7, 8, 8])
     cls.correct_p = torch.LongTensor([7, 8, 8, 2, 2, 3])

     # Different-label pairs.
     cls.correct_a2 = torch.LongTensor([
         0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
         5, 6, 6, 7, 7, 8
     ])
     cls.correct_n = torch.LongTensor([
         2, 3, 2, 3, 0, 1, 4, 5, 0, 1, 4, 5, 2, 3, 5, 6, 7, 2, 3, 4, 6, 7,
         8, 4, 5, 4, 5, 5
     ])
コード例 #2
0
 def test_input_indices_tuple(self):
     """Check how ``create_indices_tuple`` combines mined and memory pairs.

     For 30 random batches with unique labels, and for both a pair miner
     and a triplet miner, the tuple returned by
     ``CrossBatchMemory.create_indices_tuple`` must:

     * index only same-label samples in its (a1, p) part,
     * index only different-label samples in its (a2, n) part,
     * have, component by component, a length equal to the number of
       batch-vs-memory pairs plus the number of pairs derived from the
       miner output.
     """
     num_samples = 32
     miners = (
         PairMarginMiner(pos_margin=0, neg_margin=1, use_similarity=False),
         TripletMarginMiner(margin=1),
     )
     self.loss = CrossBatchMemory(
         loss=ContrastiveLoss(),
         embedding_size=self.embedding_size,
         memory_size=self.memory_size,
     )

     for _ in range(30):
         embeddings = torch.randn(num_samples, self.embedding_size)
         labels = torch.arange(num_samples)
         self.loss(embeddings, labels)

         for miner in miners:
             mined = miner(embeddings, labels)
             # Labels addressed by the combined indices: batch first,
             # then the memory contents.
             all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

             mined_a1, mined_p, mined_a2, mined_n = lmu.convert_to_pairs(
                 mined, labels)
             mem_a1, mem_p, mem_a2, mem_n = lmu.get_all_pairs_indices(
                 labels, self.loss.label_memory)
             a1, p, a2, n = self.loss.create_indices_tuple(
                 num_samples,
                 embeddings,
                 labels,
                 self.loss.embedding_memory,
                 self.loss.label_memory,
                 mined,
             )

             positive_label_gap = all_labels[a1] - all_labels[p]
             self.assertTrue(not torch.any(positive_label_gap.bool()))
             negative_label_gap = all_labels[a2] - all_labels[n]
             self.assertTrue(torch.all(negative_label_gap.bool()))

             combined_parts = (a1, p, a2, n)
             memory_parts = (mem_a1, mem_p, mem_a2, mem_n)
             mined_parts = (mined_a1, mined_p, mined_a2, mined_n)
             for combined, from_memory, from_miner in zip(
                     combined_parts, memory_parts, mined_parts):
                 self.assertTrue(
                     len(combined) == len(from_memory) + len(from_miner))
コード例 #3
0
    def test_shift_indices_tuple(self):
        """Check ``c_f.shift_indices_tuple`` on pair and triplet index tuples.

        For every dtype in ``TEST_DTYPES`` and 30 random batches with unique
        labels, three index tuples are built against the cross-batch memory
        (all batch-vs-memory pairs, pair-miner output, triplet-miner output)
        and each is verified with ``_assert_shift_is_consistent``.
        """
        batch_size = 32
        for dtype in TEST_DTYPES:
            pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
            triplet_miner = TripletMarginMiner(margin=1)
            self.loss = CrossBatchMemory(
                loss=ContrastiveLoss(),
                embedding_size=self.embedding_size,
                memory_size=self.memory_size,
            )
            for _ in range(30):
                embeddings = (
                    torch.randn(batch_size, self.embedding_size)
                    .to(TEST_DEVICE)
                    .type(dtype)
                )
                labels = torch.arange(batch_size).to(TEST_DEVICE)
                # The returned loss value is not needed; the call is kept
                # because the memory attributes are read right after it
                # (presumably it enqueues the batch -- confirm in
                # CrossBatchMemory).
                self.loss(embeddings, labels)
                # Shifted indices address this concatenation: batch labels
                # first, memory labels after.
                all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

                self._assert_shift_is_consistent(
                    lmu.get_all_pairs_indices(labels, self.loss.label_memory),
                    batch_size,
                    all_labels,
                )
                self._assert_shift_is_consistent(
                    pair_miner(
                        embeddings,
                        labels,
                        self.loss.embedding_memory,
                        self.loss.label_memory,
                    ),
                    batch_size,
                    all_labels,
                )
                self._assert_shift_is_consistent(
                    triplet_miner(
                        embeddings,
                        labels,
                        self.loss.embedding_memory,
                        self.loss.label_memory,
                    ),
                    batch_size,
                    all_labels,
                )

    def _assert_shift_is_consistent(self, indices_tuple, batch_size, all_labels):
        """Assert that shifting ``indices_tuple`` offsets only non-anchor slots.

        Args:
            indices_tuple: a 4-tuple ``(a1, p, a2, n)`` or a 3-tuple
                ``(a, p, n)`` of index tensors.
            batch_size: offset expected on the non-anchor entries.
            all_labels: labels indexed by the shifted tuple.

        Checks that anchor entries are unchanged, that every other entry is
        moved by exactly ``batch_size``, and that the shifted indices still
        pair same-label samples as positives and different-label samples as
        negatives.
        """
        shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
        is_triplet = len(indices_tuple) == 3
        # Slots holding anchors: (a, p, n) -> 0 ; (a1, p, a2, n) -> 0 and 2.
        anchor_slots = (0,) if is_triplet else (0, 2)

        for slot, (original, moved) in enumerate(zip(indices_tuple, shifted)):
            if slot in anchor_slots:
                self.assertTrue(torch.equal(original, moved))
            else:
                self.assertTrue(torch.equal(original, moved - batch_size))

        if is_triplet:
            anchors, positives, negatives = shifted
            same_label_gap = all_labels[anchors] - all_labels[positives]
            # Positives share the anchor's label (asserted below), so
            # comparing positives with negatives covers anchor-vs-negative.
            diff_label_gap = all_labels[positives] - all_labels[negatives]
        else:
            a1, p, a2, n = shifted
            same_label_gap = all_labels[a1] - all_labels[p]
            diff_label_gap = all_labels[a2] - all_labels[n]

        self.assertTrue(not torch.any(same_label_gap.bool()))
        self.assertTrue(torch.all(diff_label_gap.bool()))
コード例 #4
0
 def test_empty_output(self):
     """Check that the miner yields no positive pairs for unique labels.

     Every sample gets its own label (``torch.arange``), so the first two
     outputs of the miner, unpacked here as anchor and positive indices,
     must both be empty for every dtype in ``TEST_DTYPES``.
     """
     num_samples = 32
     miner = PairMarginMiner(0, 1)
     for dtype in TEST_DTYPES:
         random_embeddings = torch.randn(num_samples, 64)
         embeddings = random_embeddings.type(dtype).to(self.device)
         labels = torch.arange(num_samples)
         anchors, positives, _, _ = miner(embeddings, labels)
         for mined_indices in (anchors, positives):
             self.assertTrue(len(mined_indices) == 0)
コード例 #5
0
    def test_pair_margin_miner(self):
        """Compare ``PairMarginMiner`` output with a brute-force reference.

        For each dtype and each distance (``LpDistance``,
        ``CosineSimilarity``), 16 embeddings are built from the angles 0..15
        and given random binary labels. Every ordered pair ``(i, j)`` with
        ``i != j`` is classified as positive (same label) or negative
        (different label) together with its entry in the distance matrix.

        For a grid of margins the expected result is:

        * non-inverted distance: positives with value > ``pos_margin``,
          negatives with value < ``neg_margin``;
        * inverted distance: the two comparisons are flipped.

        The sets of pairs returned by the miner must equal these sets.
        """
        pos_margin_grid = [float(step) * 0.05 for step in range(-1, 4)]
        neg_margin_grid = [float(step) * 0.05 for step in range(2, 7)]

        for dtype in TEST_DTYPES:
            for distance in [LpDistance(), CosineSimilarity()]:
                angles = torch.arange(0, 16)
                coords = [c_f.angle_to_coord(angle) for angle in angles]
                # 2D embeddings
                embeddings = torch.tensor(
                    coords, requires_grad=True, dtype=dtype
                ).to(self.device)
                labels = torch.randint(low=0, high=2, size=(16,))
                mat = distance(embeddings)

                # Partition all ordered pairs by label agreement, keeping
                # the matrix entry next to each pair.
                num_points = len(embeddings)
                same_label_pairs = []
                other_label_pairs = []
                for i in range(num_points):
                    for j in range(num_points):
                        if i == j:
                            continue
                        entry = (i, j, mat[i, j])
                        if labels[j] == labels[i]:
                            same_label_pairs.append(entry)
                        else:
                            other_label_pairs.append(entry)

                for pos_margin in pos_margin_grid:
                    for neg_margin in neg_margin_grid:
                        miner = PairMarginMiner(
                            pos_margin, neg_margin, distance=distance
                        )

                        if distance.is_inverted:
                            expected_pos = {
                                (i, j)
                                for i, j, value in same_label_pairs
                                if value < pos_margin
                            }
                            expected_neg = {
                                (i, j)
                                for i, j, value in other_label_pairs
                                if value > neg_margin
                            }
                        else:
                            expected_pos = {
                                (i, j)
                                for i, j, value in same_label_pairs
                                if value > pos_margin
                            }
                            expected_neg = {
                                (i, j)
                                for i, j, value in other_label_pairs
                                if value < neg_margin
                            }

                        a1, p1, a2, n2 = miner(embeddings, labels)
                        mined_pos = {
                            (anc.item(), pos.item()) for anc, pos in zip(a1, p1)
                        }
                        mined_neg = {
                            (anc.item(), neg.item()) for anc, neg in zip(a2, n2)
                        }

                        self.assertTrue(mined_pos == expected_pos)
                        self.assertTrue(mined_neg == expected_neg)