コード例 #1
0
    def test_sanity_check(self):
        """CrossBatchMemory should match its wrapped loss when memory adds nothing.

        Two scenarios per dtype / memory_size:
        - enqueue_idx=None with batch_size == memory_size, so the batch exactly
          fills the memory;
        - enqueue_idx selecting the second half of a double-size batch, so the
          first half is queried against the enqueued second half.
        """
        # cross batch memory with batch_size == memory_size should be equivalent to just using the inner loss function
        for dtype in TEST_DTYPES:
            for test_enqueue_idx in [False, True]:
                for memory_size in range(20, 40, 5):
                    inner_loss = NTXentLoss(temperature=0.1)
                    inner_miner = TripletMarginMiner(margin=0.1)
                    # Wrapper without a miner.
                    loss = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                    )
                    # Same wrapper, but mining triplets before computing the loss.
                    loss_with_miner = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                        miner=inner_miner,
                    )
                    for i in range(10):
                        if test_enqueue_idx:
                            # Only the second half of the batch is enqueued;
                            # the first half acts purely as the query set.
                            enqueue_idx = torch.arange(memory_size, memory_size * 2)
                            not_enqueue_idx = torch.arange(memory_size)
                            batch_size = memory_size * 2
                        else:
                            enqueue_idx = None
                            batch_size = memory_size
                        embeddings = (
                            torch.randn(batch_size, self.embedding_size)
                            .to(TEST_DEVICE)
                            .type(dtype)
                        )
                        labels = torch.randint(0, 4, (batch_size,)).to(TEST_DEVICE)

                        if test_enqueue_idx:
                            # Reference value: pairs between the query half and
                            # the enqueued half, with the enqueued-side indices
                            # shifted so they address the second half of
                            # `embeddings`.
                            pairs = lmu.get_all_pairs_indices(
                                labels[not_enqueue_idx], labels[enqueue_idx]
                            )
                            pairs = c_f.shift_indices_tuple(pairs, memory_size)
                            inner_loss_val = inner_loss(embeddings, labels, pairs)
                        else:
                            inner_loss_val = inner_loss(embeddings, labels)
                        loss_val = loss(embeddings, labels, enqueue_idx=enqueue_idx)
                        self.assertTrue(torch.isclose(inner_loss_val, loss_val))

                        if test_enqueue_idx:
                            # Reference value: triplets mined between the query
                            # half and the enqueued half, shifted the same way.
                            triplets = inner_miner(
                                embeddings[not_enqueue_idx],
                                labels[not_enqueue_idx],
                                embeddings[enqueue_idx],
                                labels[enqueue_idx],
                            )
                            triplets = c_f.shift_indices_tuple(triplets, memory_size)
                            inner_loss_val = inner_loss(embeddings, labels, triplets)
                        else:
                            triplets = inner_miner(embeddings, labels)
                            inner_loss_val = inner_loss(embeddings, labels, triplets)
                        loss_val = loss_with_miner(
                            embeddings, labels, enqueue_idx=enqueue_idx
                        )
                        self.assertTrue(torch.isclose(inner_loss_val, loss_val))
コード例 #2
0
 def test_input_indices_tuple(self):
     """create_indices_tuple should merge externally supplied (mined) indices
     with the pairs found against the memory bank: the output pair counts must
     equal memory-derived pairs plus the converted input tuple's pairs, and
     every output positive/negative pair must have matching/differing labels.
     """
     batch_size = 32
     pair_miner = PairMarginMiner(pos_margin=0,
                                  neg_margin=1,
                                  use_similarity=False)
     triplet_miner = TripletMarginMiner(margin=1)
     self.loss = CrossBatchMemory(loss=ContrastiveLoss(),
                                  embedding_size=self.embedding_size,
                                  memory_size=self.memory_size)
     for i in range(30):
         embeddings = torch.randn(batch_size, self.embedding_size)
         # arange labels: every label is unique within a single batch.
         labels = torch.arange(batch_size)
         # Forward pass populates self.loss.embedding_memory / label_memory.
         self.loss(embeddings, labels)
         for curr_miner in [pair_miner, triplet_miner]:
             input_indices_tuple = curr_miner(embeddings, labels)
             all_labels = torch.cat([labels, self.loss.label_memory], dim=0)
             # Pairs implied by the externally mined tuple (triplets are
             # converted to pair form).
             a1ii, pii, a2ii, nii = lmu.convert_to_pairs(
                 input_indices_tuple, labels)
             # Pairs between the current batch and the memory bank.
             a1i, pi, a2i, ni = lmu.get_all_pairs_indices(
                 labels, self.loss.label_memory)
             a1, p, a2, n = self.loss.create_indices_tuple(
                 batch_size, embeddings, labels, self.loss.embedding_memory,
                 self.loss.label_memory, input_indices_tuple)
             # Positives must share a label; negatives must differ.
             self.assertTrue(not torch.any((all_labels[a1] -
                                            all_labels[p]).bool()))
             self.assertTrue(
                 torch.all((all_labels[a2] - all_labels[n]).bool()))
             # Output = memory pairs + input-tuple pairs, per column.
             self.assertTrue(len(a1) == len(a1i) + len(a1ii))
             self.assertTrue(len(p) == len(pi) + len(pii))
             self.assertTrue(len(a2) == len(a2i) + len(a2ii))
             self.assertTrue(len(n) == len(ni) + len(nii))
コード例 #3
0
    def test_shift_indices_tuple(self):
        """shift_indices_tuple should offset only the memory-side columns by
        batch_size: for pairs (a1, p, a2, n) the anchors stay put while p and n
        are shifted; for triplets (a, p, n) the anchor stays put while p and n
        are shifted. Shifted indices must still pick equal labels for positives
        and unequal labels for negatives out of the concatenated
        [batch labels, memory labels] list.
        """
        for dtype in TEST_DTYPES:
            batch_size = 32
            pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
            triplet_miner = TripletMarginMiner(margin=1)
            self.loss = CrossBatchMemory(
                loss=ContrastiveLoss(),
                embedding_size=self.embedding_size,
                memory_size=self.memory_size,
            )
            for i in range(30):
                embeddings = (
                    torch.randn(batch_size, self.embedding_size)
                    .to(TEST_DEVICE)
                    .type(dtype)
                )
                labels = torch.arange(batch_size).to(TEST_DEVICE)
                # Forward pass populates the embedding/label memory.
                loss = self.loss(embeddings, labels)
                all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

                # Case 1: pairs computed directly between batch and memory labels.
                indices_tuple = lmu.get_all_pairs_indices(
                    labels, self.loss.label_memory
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                # Anchor columns (a1, a2) are untouched; p and n gain batch_size.
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
                a1, p, a2, n = shifted
                self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

                # Case 2: pairs produced by a pair miner across batch and memory.
                indices_tuple = pair_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
                a1, p, a2, n = shifted
                self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

                # Case 3: triplets — anchor column unchanged, p and n shifted.
                indices_tuple = triplet_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2] - batch_size))
                a, p, n = shifted
                self.assertTrue(not torch.any((all_labels[a] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
コード例 #4
0
    def test_uniform_histogram_miner(self):
        """UniformHistogramMiner should return pairs whose distance histogram
        is flat: when mined distances are binned over the full range of
        all-pairs distances, every non-empty bin holds exactly pos_per_bin
        (resp. neg_per_bin) pairs.
        """
        torch.manual_seed(93612)
        batch_size = 128
        embedding_size = 32
        num_bins, pos_per_bin, neg_per_bin = 100, 25, 123
        for distance in [
                LpDistance(p=1),
                LpDistance(p=2),
                LpDistance(normalize_embeddings=False),
                SNRDistance(),
        ]:
            miner = UniformHistogramMiner(
                num_bins=num_bins,
                pos_per_bin=pos_per_bin,
                neg_per_bin=neg_per_bin,
                distance=distance,
            )
            for dtype in TEST_DTYPES:
                embeddings = torch.randn(batch_size,
                                         embedding_size,
                                         device=TEST_DEVICE,
                                         dtype=dtype)
                # Two classes only, so both pair types are plentiful.
                labels = torch.randint(0,
                                       2,
                                       size=(batch_size, ),
                                       device=TEST_DEVICE)

                # All pairs and their distances: these fix the histogram range.
                a1, p, a2, n = lmu.get_all_pairs_indices(labels)
                dist_mat = distance(embeddings)
                pos_pairs = dist_mat[a1, p]
                neg_pairs = dist_mat[a2, n]

                # Reassigned: the mined subset of pairs.
                a1, p, a2, n = miner(embeddings, labels)

                if dtype == torch.float16:
                    continue  # histc doesn't work for Half tensor

                pos_histogram = torch.histc(
                    dist_mat[a1, p],
                    bins=num_bins,
                    min=torch.min(pos_pairs),
                    max=torch.max(pos_pairs),
                )
                neg_histogram = torch.histc(
                    dist_mat[a2, n],
                    bins=num_bins,
                    min=torch.min(neg_pairs),
                    max=torch.max(neg_pairs),
                )

                # Each bin is either empty or exactly at its quota.
                self.assertTrue(
                    torch.all((pos_histogram == pos_per_bin)
                              | (pos_histogram == 0)))
                self.assertTrue(
                    torch.all((neg_histogram == neg_per_bin)
                              | (neg_histogram == 0)))
コード例 #5
0
    def test_get_all_pairs_triplets_indices(self):
        """Verify pair/triplet counts from lmu for labels 0..9 repeated k times.

        With n = 10 * k labels, each element has (k - 1) positives and (n - k)
        negatives, giving n*(k-1) positive pairs, n*(n-k) negative pairs, and
        n*(k-1)*(n-k) triplets.
        """
        base_labels = torch.arange(10)

        for num_repeats in range(1, 11):
            labels = base_labels.repeat(num_repeats)
            n = len(labels)
            expected_pos = n * (num_repeats - 1)
            expected_neg = n * (n - num_repeats)

            anchor1, pos, anchor2, neg = lmu.get_all_pairs_indices(labels)
            self.assertTrue(len(anchor1) == len(pos) == expected_pos)
            self.assertTrue(len(anchor2) == len(neg) == expected_neg)

            expected_triplets = expected_pos * (n - num_repeats)
            anc, pos, neg = lmu.get_all_triplets_indices(labels)
            self.assertTrue(len(anc) == len(pos) == len(neg) == expected_triplets)
コード例 #6
0
    def test_loss(self):
        """CrossBatchMemory losses should equal the inner loss computed over
        the current batch concatenated with everything previously seen.

        Covers three configurations: no miner, an inner miner, and an inner
        miner combined with externally supplied (outer-mined) indices.
        NOTE(review): assumes self.memory_size >= num_iter * batch_size so no
        embeddings are evicted during the test — confirm via setUp.
        """
        num_labels = 10
        num_iter = 10
        batch_size = 32
        inner_loss = ContrastiveLoss()
        inner_miner = MultiSimilarityMiner(0.3)
        outer_miner = MultiSimilarityMiner(0.2)
        self.loss = CrossBatchMemory(loss=inner_loss, embedding_size=self.embedding_size, memory_size=self.memory_size)
        self.loss_with_miner = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
        self.loss_with_miner2 = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
        # Running record of everything fed to the wrappers so far.
        all_embeddings = torch.FloatTensor([])
        all_labels = torch.LongTensor([])
        for i in range(num_iter):
            embeddings = torch.randn(batch_size, self.embedding_size)
            labels = torch.randint(0,num_labels,(batch_size,))
            loss = self.loss(embeddings, labels)
            loss_with_miner = self.loss_with_miner(embeddings, labels)
            oa1, op, oa2, on = outer_miner(embeddings, labels)
            loss_with_miner_and_input_indices = self.loss_with_miner2(embeddings, labels, (oa1, op, oa2, on))
            all_embeddings = torch.cat([all_embeddings, embeddings])
            all_labels = torch.cat([all_labels, labels])

            # loss with no inner miner
            indices_tuple = lmu.get_all_pairs_indices(labels, all_labels)
            a1,p,a2,n = self.loss.remove_self_comparisons(indices_tuple)
            # Memory-side indices refer to all_embeddings, which is appended
            # after the current batch below, so shift them by batch_size.
            p = p+batch_size
            n = n+batch_size
            correct_loss = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss, correct_loss))

            # loss with inner miner
            indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
            a1,p,a2,n = self.loss_with_miner.remove_self_comparisons(indices_tuple)
            p = p+batch_size
            n = n+batch_size
            correct_loss_with_miner = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss_with_miner, correct_loss_with_miner))

            # loss with inner and outer miner
            indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
            a1,p,a2,n = self.loss_with_miner2.remove_self_comparisons(indices_tuple)
            p = p+batch_size
            n = n+batch_size
            # The outer-mined pairs index the current batch directly, so they
            # are prepended without shifting.
            a1 = torch.cat([oa1, a1])
            p = torch.cat([op, p])
            a2 = torch.cat([oa2, a2])
            n = torch.cat([on, n])
            correct_loss_with_miner_and_input_indice = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss_with_miner_and_input_indices, correct_loss_with_miner_and_input_indice))
 @classmethod
 def setUpClass(cls):
     """Build fixtures shared by the BatchEasyHardMiner tests.

     Fix: unittest invokes ``setUpClass`` on the class with no arguments, so
     it must be a ``@classmethod``. The original plain method would raise a
     TypeError ("missing argument: self") during class-level setup; the
     decorator is added and the parameter renamed to ``cls``. All attributes
     become class attributes, which the test methods still read via ``self``.
     """
     # Label multiset: class 0 -> indices {0, 1, 4}, class 1 -> {2, 3, 6, 7, 8},
     # class 2 -> {5}.
     cls.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])
     # All positive/negative pair indices; reused as ground truth wherever a
     # strategy is ALL.
     cls.a1_idx, cls.p_idx, cls.a2_idx, cls.n_idx = lmu.get_all_pairs_indices(
         cls.labels)
     cls.distance = LpDistance(normalize_embeddings=False)
     # Expected miner outputs keyed by "batch_<pos_strategy>_<neg_strategy>".
     # "correct_p"/"correct_n" list every acceptable answer when distance ties
     # make more than one mining result valid.
     cls.gt = {
         "batch_semihard_hard": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.SEMIHARD,
                 neg_strategy=BatchEasyHardMiner.HARD,
             ),
             "easiest_triplet": -1,
             "hardest_triplet": -1,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 2,
             "easiest_neg_pair": 3,
             "hardest_neg_pair": 2,
             "expected": {
                 "correct_a": torch.LongTensor([0, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 6, 6]).to(TEST_DEVICE),
                     torch.LongTensor([1, 8, 6]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_hard_semihard": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.HARD,
                 neg_strategy=BatchEasyHardMiner.SEMIHARD,
             ),
             "easiest_triplet": -1,
             "hardest_triplet": -1,
             "easiest_pos_pair": 3,
             "hardest_pos_pair": 6,
             "easiest_neg_pair": 7,
             "hardest_neg_pair": 4,
             "expected": {
                 "correct_a": torch.LongTensor([0, 1, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([4, 4, 2, 2, 2]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([5, 5, 1, 1, 1]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_easy_semihard": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.SEMIHARD,
             ),
             "easiest_triplet": -2,
             "hardest_triplet": -1,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 3,
             "easiest_neg_pair": 4,
             "hardest_neg_pair": 2,
             "expected": {
                 "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 3, 0, 1, 8, 4, 5, 5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 3, 4, 1, 8, 4, 5, 5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 3, 0, 5, 8, 4, 5, 5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 3, 4, 5, 8, 4, 5, 5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_hard_hard": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.HARD,
                 neg_strategy=BatchEasyHardMiner.HARD,
             ),
             "easiest_triplet": 2,
             "hardest_triplet": 5,
             "easiest_pos_pair": 3,
             "hardest_pos_pair": 6,
             "easiest_neg_pair": 3,
             "hardest_neg_pair": 1,
             "expected": {
                 "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_easy_hard": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.HARD,
             ),
             "easiest_triplet": -2,
             "hardest_triplet": 2,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 3,
             "easiest_neg_pair": 3,
             "hardest_neg_pair": 1,
             "expected": {
                 "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(TEST_DEVICE),
                     torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_hard_easy": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.HARD,
                 neg_strategy=BatchEasyHardMiner.EASY,
             ),
             "easiest_triplet": -4,
             "hardest_triplet": 3,
             "easiest_pos_pair": 3,
             "hardest_pos_pair": 6,
             "easiest_neg_pair": 8,
             "hardest_neg_pair": 3,
             "expected": {
                 "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_easy_easy": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.EASY,
             ),
             "easiest_triplet": -7,
             "hardest_triplet": -1,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 3,
             "easiest_neg_pair": 8,
             "hardest_neg_pair": 3,
             "expected": {
                 "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_easy_easy_with_min_val": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.EASY,
                 allowed_neg_range=[1, 7],
                 allowed_pos_range=[1, 7],
             ),
             "easiest_triplet": -6,
             "hardest_triplet": -1,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 3,
             "easiest_neg_pair": 7,
             "hardest_neg_pair": 3,
             "expected": {
                 "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                 ],
                 "correct_n": [
                     torch.LongTensor([7, 8, 5, 0, 8, 0, 0, 1]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_easy_all": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.EASY,
                 neg_strategy=BatchEasyHardMiner.ALL,
             ),
             "easiest_triplet": 0,
             "hardest_triplet": 0,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 3,
             "easiest_neg_pair": 8,
             "hardest_neg_pair": 1,
             "expected": {
                 "correct_a1": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_p": [
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                     torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                 ],
                 "correct_a2": cls.a2_idx,
                 "correct_n": [cls.n_idx],
             },
         },
         "batch_all_easy": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.ALL,
                 neg_strategy=BatchEasyHardMiner.EASY,
             ),
             "easiest_triplet": 0,
             "hardest_triplet": 0,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 6,
             "easiest_neg_pair": 8,
             "hardest_neg_pair": 3,
             "expected": {
                 "correct_a1": cls.a1_idx,
                 "correct_p": [cls.p_idx],
                 "correct_a2": torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8]).to(TEST_DEVICE),
                 "correct_n": [
                     torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0, 0]).to(TEST_DEVICE),
                 ],
             },
         },
         "batch_all_all": {
             "miner": BatchEasyHardMiner(
                 distance=cls.distance,
                 pos_strategy=BatchEasyHardMiner.ALL,
                 neg_strategy=BatchEasyHardMiner.ALL,
             ),
             "easiest_triplet": 0,
             "hardest_triplet": 0,
             "easiest_pos_pair": 1,
             "hardest_pos_pair": 6,
             "easiest_neg_pair": 8,
             "hardest_neg_pair": 1,
             "expected": {
                 "correct_a1": cls.a1_idx,
                 "correct_p": [cls.p_idx],
                 "correct_a2": cls.a2_idx,
                 "correct_n": [cls.n_idx],
             },
         },
     }
コード例 #8
0
    def test_per_anchor_reducer(self):
        """PerAnchorReducer should average each anchor's pair losses into one
        per-element loss and then apply the wrapped reducer; element losses
        pass through, and triplet losses raise NotImplementedError.
        """
        for inner_reducer in [MeanReducer(), AvgNonZeroReducer()]:
            reducer = PerAnchorReducer(inner_reducer)
            batch_size = 100
            embedding_size = 64
            for dtype in TEST_DTYPES:
                embeddings = (
                    torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
                )
                labels = torch.randint(0, 10, (batch_size,))
                pos_pair_indices = lmu.get_all_pairs_indices(labels)[:2]
                neg_pair_indices = lmu.get_all_pairs_indices(labels)[2:]
                triplet_indices = lmu.get_all_triplets_indices(labels)

                for indices, reduction_type in [
                    (torch.arange(batch_size), "element"),
                    (pos_pair_indices, "pos_pair"),
                    (neg_pair_indices, "neg_pair"),
                    (triplet_indices, "triplet"),
                ]:
                    # One loss value per element, or per pair/triplet.
                    loss_size = (
                        len(indices) if reduction_type == "element" else len(indices[0])
                    )
                    losses = torch.randn(loss_size).type(dtype).to(TEST_DEVICE)
                    loss_dict = {
                        "loss": {
                            "losses": losses,
                            "indices": indices,
                            "reduction_type": reduction_type,
                        }
                    }
                    # Triplet reduction is unsupported by PerAnchorReducer.
                    if reduction_type == "triplet":
                        self.assertRaises(
                            NotImplementedError,
                            lambda: reducer(loss_dict, embeddings, labels),
                        )
                        continue

                    output = reducer(loss_dict, embeddings, labels)
                    if reduction_type == "element":
                        # Element losses should be forwarded unchanged.
                        loss_dict = {
                            "loss": {
                                "losses": losses,
                                "indices": c_f.torch_arange_from_size(embeddings),
                                "reduction_type": "element",
                            }
                        }
                    else:
                        # Reference: mean of each anchor's pair losses; anchors
                        # with no pairs keep a loss of zero.
                        anchors = indices[0]
                        correct_output = torch.zeros(
                            batch_size, device=TEST_DEVICE, dtype=dtype
                        )
                        for i in range(len(embeddings)):
                            matching_pairs_mask = anchors == i
                            num_matching_pairs = torch.sum(matching_pairs_mask)
                            if num_matching_pairs > 0:
                                correct_output[i] = (
                                    torch.sum(losses[matching_pairs_mask])
                                    / num_matching_pairs
                                )
                        loss_dict = {
                            "loss": {
                                "losses": correct_output,
                                "indices": c_f.torch_arange_from_size(embeddings),
                                "reduction_type": "element",
                            }
                        }
                    # The wrapped reducer applied to the reference per-element
                    # losses must match the PerAnchorReducer output.
                    correct_output = inner_reducer(loss_dict, embeddings, labels)
                    self.assertTrue(torch.isclose(output, correct_output))
コード例 #9
0
ファイル: cluster.py プロジェクト: AlexSchuy/hgcal-dev
def to_pairs(y):
    """Return the set of unordered positive-pair index tuples for labels ``y``.

    Each pair is a sorted 2-tuple of floats, so (i, j) and (j, i) collapse to
    one entry.
    """
    anchors, positives, _, _ = get_all_pairs_indices(torch.as_tensor(y))
    return {
        tuple(sorted((float(a), float(p))))
        for a, p in zip(anchors, positives)
    }