def test_input_indices_tuple(self):
    """Check how CrossBatchMemory combines user-supplied indices with memory pairs.

    Runs 30 random batches of 32 embeddings with labels ``0..31``. Each batch is
    first passed through ``self.loss``. Then, for both a pair miner and a triplet
    miner, the test calls ``self.loss.create_indices_tuple`` with the miner's
    output and verifies two things:

    * Every (a1, p) pair has equal labels and every (a2, n) pair has different
      labels. Labels are looked up in ``labels`` concatenated with
      ``self.loss.label_memory``.
    * Each of the four returned index tensors is as long as the corresponding
      tensors from ``lmu.get_all_pairs_indices`` (batch vs. memory) and
      ``lmu.convert_to_pairs`` (miner output) combined.
    """
    batch_size = 32
    pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1, use_similarity=False)
    triplet_miner = TripletMarginMiner(margin=1)
    self.loss = CrossBatchMemory(
        loss=ContrastiveLoss(),
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    for _ in range(30):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.arange(batch_size)
        # NOTE(review): presumably this forward call enqueues the batch into
        # self.loss.embedding_memory / label_memory, which are read below — confirm.
        self.loss(embeddings, labels)
        for curr_miner in [pair_miner, triplet_miner]:
            input_indices_tuple = curr_miner(embeddings, labels)
            # The returned indices are checked against this concatenation
            # (current batch first, then memory). Assumes create_indices_tuple
            # offsets its indices to match this layout — TODO confirm.
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)
            a1ii, pii, a2ii, nii = lmu.convert_to_pairs(input_indices_tuple, labels)
            a1i, pi, a2i, ni = lmu.get_all_pairs_indices(labels, self.loss.label_memory)
            a1, p, a2, n = self.loss.create_indices_tuple(
                batch_size,
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
                input_indices_tuple,
            )
            # Positive pairs must share a label; negative pairs must never share one.
            self.assertFalse(torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))
            # The combined tuple is expected to hold the memory pairs plus the
            # miner's pairs, so each length equals the sum of the two sources.
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(len(a1), len(a1i) + len(a1ii))
            self.assertEqual(len(p), len(pi) + len(pii))
            self.assertEqual(len(a2), len(a2i) + len(a2ii))
            self.assertEqual(len(n), len(ni) + len(nii))
def compute_loss(self, embeddings, labels, indices_tuple, pos_num):
    """Evaluate ``self.loss_method`` on a pairwise matrix and pair-form indices.

    The pairwise matrix compares the first ``pos_num`` rows of ``embeddings``
    against every row of ``embeddings``. It uses ``self.use_similarity`` and
    ``self.squared_distances``, passed through to ``lmu.get_pairwise_mat``.

    Args:
        embeddings: Embedding tensor. It is sliced with ``[:pos_num, :]``, so at
            least two dimensions are required.
        labels: Labels forwarded to both ``lmu.convert_to_pairs`` and
            ``self.loss_method``.
        indices_tuple: Mined indices. They are converted to pair form with
            ``lmu.convert_to_pairs`` before being handed to the loss.
        pos_num: End index of the row slice used as the first operand of the
            pairwise matrix. Standard slicing rules apply, so ``None`` keeps
            every row.

    Returns:
        Whatever ``self.loss_method(mat, labels, pair_indices)`` returns.
    """
    leading_rows = embeddings[:pos_num, :]
    pairwise = lmu.get_pairwise_mat(
        leading_rows,
        embeddings,
        self.use_similarity,
        self.squared_distances,
    )
    pair_indices = lmu.convert_to_pairs(indices_tuple, labels)
    return self.loss_method(pairwise, labels, pair_indices)