Пример #1
0
 def test_input_indices_tuple(self):
     """Check ``create_indices_tuple`` when an input indices tuple is given.

     Runs 30 random batches through a ``CrossBatchMemory`` that wraps a
     ``ContrastiveLoss``. Each batch uses ``torch.arange`` labels, so every
     sample in a batch has a distinct label. After each forward call, the
     test checks the output of ``create_indices_tuple`` for both a pair miner
     and a triplet miner:

     * every positive pair ``(a1, p)`` has equal labels;
     * every negative pair ``(a2, n)`` has differing labels;
     * each returned index tensor has the combined length of two tensors:
       the one from ``lmu.get_all_pairs_indices(labels, label_memory)`` and
       the one from the miner's tuple after ``lmu.convert_to_pairs``.
     """
     batch_size = 32
     pair_miner = PairMarginMiner(pos_margin=0,
                                  neg_margin=1,
                                  use_similarity=False)
     triplet_miner = TripletMarginMiner(margin=1)
     self.loss = CrossBatchMemory(loss=ContrastiveLoss(),
                                  embedding_size=self.embedding_size,
                                  memory_size=self.memory_size)
     for _ in range(30):
         embeddings = torch.randn(batch_size, self.embedding_size)
         labels = torch.arange(batch_size)
         # Forward pass through the memory-backed loss; presumably this is
         # what populates embedding_memory / label_memory -- TODO confirm.
         self.loss(embeddings, labels)
         for curr_miner in [pair_miner, triplet_miner]:
             input_indices_tuple = curr_miner(embeddings, labels)
             # The assertions below look up returned indices in the batch
             # labels followed by the memory labels.
             all_labels = torch.cat([labels, self.loss.label_memory], dim=0)
             # Miner output in pair form (a1, p, a2, n); for the triplet
             # miner this presumably converts triplets to pairs.
             a1ii, pii, a2ii, nii = lmu.convert_to_pairs(
                 input_indices_tuple, labels)
             # Pairs between the current batch labels and the label memory.
             a1i, pi, a2i, ni = lmu.get_all_pairs_indices(
                 labels, self.loss.label_memory)
             a1, p, a2, n = self.loss.create_indices_tuple(
                 batch_size, embeddings, labels, self.loss.embedding_memory,
                 self.loss.label_memory, input_indices_tuple)
             # Positive pairs must share a label; negative pairs must not.
             self.assertTrue(torch.all(all_labels[a1] == all_labels[p]))
             self.assertTrue(torch.all(all_labels[a2] != all_labels[n]))
             # The returned tuple must hold exactly the memory pairs plus
             # the converted input pairs.
             self.assertEqual(len(a1), len(a1i) + len(a1ii))
             self.assertEqual(len(p), len(pi) + len(pii))
             self.assertEqual(len(a2), len(a2i) + len(a2ii))
             self.assertEqual(len(n), len(ni) + len(nii))
Пример #2
0
 def compute_loss(self, embeddings, labels, indices_tuple, pos_num):
     """Compute the loss on a pairwise matrix restricted to leading rows.

     Args:
         embeddings: indexed as ``embeddings[:pos_num, :]``, so it must
             support 2-D slicing (rank >= 2).
         labels: labels passed to ``lmu.convert_to_pairs`` and to
             ``self.loss_method``.
         indices_tuple: mined indices, converted to pair form with
             ``lmu.convert_to_pairs`` before use.
         pos_num: number of leading rows of ``embeddings`` used as the
             first operand of the pairwise matrix.

     Returns:
         Whatever ``self.loss_method(mat, labels, pairs)`` returns.
     """
     # Only the first ``pos_num`` rows act as the query side of the matrix;
     # the second operand is the full embedding set.
     leading_rows = embeddings[:pos_num, :]
     pairwise = lmu.get_pairwise_mat(
         leading_rows,
         embeddings,
         self.use_similarity,
         self.squared_distances,
     )
     pair_indices = lmu.convert_to_pairs(indices_tuple, labels)
     return self.loss_method(pairwise, labels, pair_indices)