def setUpClass(self):
    """Create the miners, labels and index fixtures used by the tests.

    Four ``PairMarginMiner`` variants are stored: a raw-distance miner, a
    normalized-distance miner, a normalized squared-distance miner and a
    similarity-based miner.  The ``correct_*`` tensors are presumably the
    expected (anchor, positive) and (anchor, negative) indices for
    ``self.labels`` -- confirm against the tests that consume them.

    NOTE(review): ``unittest`` calls ``setUpClass`` on the class; no
    ``@classmethod`` decorator is visible from here, so the signature is
    left exactly as found.
    """
    as_long = torch.LongTensor

    # Miners working on distances (use_similarity=False).
    self.dist_miner = PairMarginMiner(
        pos_margin=4,
        neg_margin=4,
        use_similarity=False,
        normalize_embeddings=False,
    )
    self.normalized_dist_miner = PairMarginMiner(
        pos_margin=1.29,
        neg_margin=1.28,
        use_similarity=False,
        normalize_embeddings=True,
    )
    self.normalized_dist_miner_squared = PairMarginMiner(
        pos_margin=1.66,
        neg_margin=1.64,
        use_similarity=False,
        normalize_embeddings=True,
        squared_distances=True,
    )

    # Miner working on similarities (use_similarity=True).
    self.sim_miner = PairMarginMiner(
        pos_margin=0.17,
        neg_margin=0.18,
        use_similarity=True,
        normalize_embeddings=True,
    )

    self.labels = as_long([0, 0, 1, 1, 0, 2, 1, 1, 1])

    # First pair list: correct_a1[k] is paired with correct_p[k].
    pair_firsts = [2, 2, 3, 7, 8, 8]
    pair_seconds = [7, 8, 8, 2, 2, 3]
    self.correct_a1 = as_long(pair_firsts)
    self.correct_p = as_long(pair_seconds)

    # Second pair list: correct_a2[k] is paired with correct_n[k].
    other_firsts = [
        0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4,
        4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 7, 7, 8,
    ]
    other_seconds = [
        2, 3, 2, 3, 0, 1, 4, 5, 0, 1, 4, 5, 2, 3,
        5, 6, 7, 2, 3, 4, 6, 7, 8, 4, 5, 4, 5, 5,
    ]
    self.correct_a2 = as_long(other_firsts)
    self.correct_n = as_long(other_seconds)
def test_input_indices_tuple(self):
    """Check ``CrossBatchMemory.create_indices_tuple`` with mined indices.

    For 30 random batches and for both a pair miner and a triplet miner,
    the tuple returned by ``create_indices_tuple`` must:

    * pair only equal labels in its first two entries,
    * pair only different labels in its last two entries,
    * have, entry by entry, a length equal to the batch-vs-memory pairs
      from ``lmu.get_all_pairs_indices`` plus the miner output converted
      to pairs by ``lmu.convert_to_pairs``.
    """
    batch_size = 32
    pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1, use_similarity=False)
    triplet_miner = TripletMarginMiner(margin=1)
    self.loss = CrossBatchMemory(
        loss=ContrastiveLoss(),
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )

    for _ in range(30):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.arange(batch_size)
        # Return value is not needed; the call is kept because the code
        # below reads self.loss.embedding_memory / label_memory
        # (presumably populated by this forward pass).
        self.loss(embeddings, labels)

        for curr_miner in (pair_miner, triplet_miner):
            input_indices_tuple = curr_miner(embeddings, labels)
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

            # Explicit 4-way unpacking keeps the implicit arity check.
            a1ii, pii, a2ii, nii = lmu.convert_to_pairs(input_indices_tuple, labels)
            a1i, pi, a2i, ni = lmu.get_all_pairs_indices(
                labels, self.loss.label_memory
            )
            a1, p, a2, n = self.loss.create_indices_tuple(
                batch_size,
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
                input_indices_tuple,
            )

            # A non-zero label difference means the two labels differ.
            self.assertFalse(torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

            # assertEqual reports both lengths when the check fails.
            length_checks = (
                (a1, a1i, a1ii),
                (p, pi, pii),
                (a2, a2i, a2ii),
                (n, ni, nii),
            )
            for combined, from_memory, from_miner in length_checks:
                self.assertEqual(len(combined), len(from_memory) + len(from_miner))
def test_shift_indices_tuple(self):
    """Check that ``c_f.shift_indices_tuple`` offsets the right entries.

    For every dtype in ``TEST_DTYPES`` and 30 random batches, three index
    tuples are shifted by ``batch_size``: all batch-vs-memory pairs, the
    pair miner output and the triplet miner output.  For pair tuples the
    entries at positions 0 and 2 must be unchanged and those at positions
    1 and 3 must be offset by ``batch_size``; for triplet tuples position
    0 is unchanged and positions 1 and 2 are offset.  The shifted indices
    must also keep the expected label relations when used to index
    ``cat([labels, label_memory])``.
    """
    batch_size = 32

    def assert_shift(original, shifted, unchanged, offset):
        # Entries listed in `unchanged` must be identical; entries listed
        # in `offset` must differ from the original by exactly batch_size.
        for pos in unchanged:
            self.assertTrue(torch.equal(original[pos], shifted[pos]))
        for pos in offset:
            self.assertTrue(
                torch.equal(original[pos], shifted[pos] - batch_size)
            )

    def assert_label_relations(all_labels, same_a, same_b, diff_a, diff_b):
        # A non-zero label difference means the two labels differ.
        self.assertFalse(
            torch.any((all_labels[same_a] - all_labels[same_b]).bool())
        )
        self.assertTrue(
            torch.all((all_labels[diff_a] - all_labels[diff_b]).bool())
        )

    def check_pair_tuple(indices_tuple, all_labels):
        shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
        assert_shift(indices_tuple, shifted, unchanged=(0, 2), offset=(1, 3))
        a1, p, a2, n = shifted
        assert_label_relations(all_labels, a1, p, a2, n)

    def check_triplet_tuple(indices_tuple, all_labels):
        shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
        assert_shift(indices_tuple, shifted, unchanged=(0,), offset=(1, 2))
        a, p, n = shifted
        # As in the original test, the negative is compared against the
        # positive (which was just checked to share the anchor's label).
        assert_label_relations(all_labels, a, p, p, n)

    for dtype in TEST_DTYPES:
        pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
        triplet_miner = TripletMarginMiner(margin=1)
        self.loss = CrossBatchMemory(
            loss=ContrastiveLoss(),
            embedding_size=self.embedding_size,
            memory_size=self.memory_size,
        )

        for _ in range(30):
            embeddings = (
                torch.randn(batch_size, self.embedding_size)
                .to(TEST_DEVICE)
                .type(dtype)
            )
            labels = torch.arange(batch_size).to(TEST_DEVICE)
            # Return value is not needed; the call is kept because the
            # code below reads self.loss.embedding_memory / label_memory
            # (presumably populated by this forward pass).
            self.loss(embeddings, labels)
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

            # 1) Every batch-vs-memory pair.
            check_pair_tuple(
                lmu.get_all_pairs_indices(labels, self.loss.label_memory),
                all_labels,
            )

            # 2) Pairs mined between the batch and the memory.
            check_pair_tuple(
                pair_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                ),
                all_labels,
            )

            # 3) Triplets mined between the batch and the memory.
            check_triplet_tuple(
                triplet_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                ),
                all_labels,
            )
def test_empty_output(self):
    """Check that the miner returns no positive pairs for unique labels.

    Labels are ``torch.arange(batch_size)``, so no two samples share a
    label; the first two outputs of ``PairMarginMiner(0, 1)`` must then
    be empty for every dtype in ``TEST_DTYPES``.
    """
    batch_size = 32
    miner = PairMarginMiner(0, 1)
    for dtype in TEST_DTYPES:
        embeddings = torch.randn(batch_size, 64).type(dtype).to(self.device)
        # Labels are left on their default device, as in the original test.
        labels = torch.arange(batch_size)
        anchors, positives, _, _ = miner(embeddings, labels)
        # assertEqual reports the actual length when the check fails.
        self.assertEqual(len(anchors), 0)
        self.assertEqual(len(positives), 0)
def test_pair_margin_miner(self):
    """Compare ``PairMarginMiner`` output with a brute-force reference.

    For every dtype in ``TEST_DTYPES`` and for both ``LpDistance`` and
    ``CosineSimilarity``, 16 embeddings are built from the angles 0..15
    via ``c_f.angle_to_coord`` and given random binary labels.  All
    ordered same-label and different-label pairs are enumerated together
    with their entry in the distance matrix, then filtered by a grid of
    margins:

    * positive pairs are kept when the value is beyond ``pos_margin``
      (``>`` normally, ``<`` when ``distance.is_inverted``),
    * negative pairs are kept when the value is within ``neg_margin``
      (``<`` normally, ``>`` when ``distance.is_inverted``).

    The mined pairs must equal these reference sets.
    """
    pos_margins = [float(step) * 0.05 for step in range(-1, 4)]
    neg_margins = [float(step) * 0.05 for step in range(2, 7)]

    for dtype in TEST_DTYPES:
        for distance in [LpDistance(), CosineSimilarity()]:
            embedding_angles = torch.arange(0, 16)
            # 2D embeddings
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)
            labels = torch.randint(low=0, high=2, size=(16,))
            mat = distance(embeddings)

            # Single pass over all ordered (i, j) pairs with i != j,
            # split by whether the two labels match.
            pos_pairs = []
            neg_pairs = []
            num_embeddings = len(embeddings)
            for i in range(num_embeddings):
                anchor_label = labels[i]
                for j in range(num_embeddings):
                    if j == i:
                        continue
                    entry = (i, j, mat[i, j])
                    if labels[j] == anchor_label:
                        pos_pairs.append(entry)
                    else:
                        neg_pairs.append(entry)

            inverted = distance.is_inverted
            for pos_margin in pos_margins:
                # The positive reference set does not depend on neg_margin.
                if inverted:
                    correct_pos = {
                        (i, j) for i, j, value in pos_pairs if value < pos_margin
                    }
                else:
                    correct_pos = {
                        (i, j) for i, j, value in pos_pairs if value > pos_margin
                    }

                for neg_margin in neg_margins:
                    if inverted:
                        correct_neg = {
                            (i, j)
                            for i, j, value in neg_pairs
                            if value > neg_margin
                        }
                    else:
                        correct_neg = {
                            (i, j)
                            for i, j, value in neg_pairs
                            if value < neg_margin
                        }

                    miner = PairMarginMiner(
                        pos_margin, neg_margin, distance=distance
                    )
                    a1, p1, a2, n2 = miner(embeddings, labels)
                    mined_pos = {(a.item(), p.item()) for a, p in zip(a1, p1)}
                    mined_neg = {(a.item(), n.item()) for a, n in zip(a2, n2)}

                    # assertEqual on sets reports the differing elements.
                    self.assertEqual(mined_pos, correct_pos)
                    self.assertEqual(mined_neg, correct_neg)