def test_sanity_check(self):
    # cross batch memory with batch_size == memory_size should be equivalent
    # to just using the inner loss function
    for dtype in TEST_DTYPES:
        for test_enqueue_idx in [False, True]:
            for memory_size in range(20, 40, 5):
                inner_loss = NTXentLoss(temperature=0.1)
                inner_miner = TripletMarginMiner(margin=0.1)
                loss = CrossBatchMemory(
                    loss=inner_loss,
                    embedding_size=self.embedding_size,
                    memory_size=memory_size,
                )
                loss_with_miner = CrossBatchMemory(
                    loss=inner_loss,
                    embedding_size=self.embedding_size,
                    memory_size=memory_size,
                    miner=inner_miner,
                )
                for i in range(10):
                    if test_enqueue_idx:
                        enqueue_idx = torch.arange(memory_size, memory_size * 2)
                        not_enqueue_idx = torch.arange(memory_size)
                        batch_size = memory_size * 2
                    else:
                        enqueue_idx = None
                        batch_size = memory_size
                    embeddings = (
                        torch.randn(batch_size, self.embedding_size)
                        .to(TEST_DEVICE)
                        .type(dtype)
                    )
                    labels = torch.randint(0, 4, (batch_size,)).to(TEST_DEVICE)

                    if test_enqueue_idx:
                        pairs = lmu.get_all_pairs_indices(
                            labels[not_enqueue_idx], labels[enqueue_idx]
                        )
                        pairs = c_f.shift_indices_tuple(pairs, memory_size)
                        inner_loss_val = inner_loss(embeddings, labels, pairs)
                    else:
                        inner_loss_val = inner_loss(embeddings, labels)
                    loss_val = loss(embeddings, labels, enqueue_idx=enqueue_idx)
                    self.assertTrue(torch.isclose(inner_loss_val, loss_val))

                    if test_enqueue_idx:
                        triplets = inner_miner(
                            embeddings[not_enqueue_idx],
                            labels[not_enqueue_idx],
                            embeddings[enqueue_idx],
                            labels[enqueue_idx],
                        )
                        triplets = c_f.shift_indices_tuple(triplets, memory_size)
                        inner_loss_val = inner_loss(embeddings, labels, triplets)
                    else:
                        triplets = inner_miner(embeddings, labels)
                        inner_loss_val = inner_loss(embeddings, labels, triplets)
                    loss_val = loss_with_miner(
                        embeddings, labels, enqueue_idx=enqueue_idx
                    )
                    self.assertTrue(torch.isclose(inner_loss_val, loss_val))
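# A minimal sketch of the equivalence checked above, assuming the API used in
# this file: when batch_size == memory_size, each forward pass overwrites the
# entire memory bank with the current batch, so CrossBatchMemory should reduce
# to the wrapped loss. Names mirror this test; sizes are illustrative.
#
#   xbm = CrossBatchMemory(loss=NTXentLoss(temperature=0.1),
#                          embedding_size=128, memory_size=32)
#   emb = torch.randn(32, 128)
#   lbl = torch.randint(0, 4, (32,))
#   assert torch.isclose(xbm(emb, lbl), NTXentLoss(temperature=0.1)(emb, lbl))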
def test_input_indices_tuple(self):
    batch_size = 32
    pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1, use_similarity=False)
    triplet_miner = TripletMarginMiner(margin=1)
    self.loss = CrossBatchMemory(
        loss=ContrastiveLoss(),
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    for i in range(30):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.arange(batch_size)
        self.loss(embeddings, labels)
        for curr_miner in [pair_miner, triplet_miner]:
            input_indices_tuple = curr_miner(embeddings, labels)
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)
            a1ii, pii, a2ii, nii = lmu.convert_to_pairs(input_indices_tuple, labels)
            a1i, pi, a2i, ni = lmu.get_all_pairs_indices(
                labels, self.loss.label_memory
            )
            a1, p, a2, n = self.loss.create_indices_tuple(
                batch_size,
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
                input_indices_tuple,
            )
            self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))
            self.assertTrue(len(a1) == len(a1i) + len(a1ii))
            self.assertTrue(len(p) == len(pi) + len(pii))
            self.assertTrue(len(a2) == len(a2i) + len(a2ii))
            self.assertTrue(len(n) == len(ni) + len(nii))
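# What the length assertions above encode, assuming the semantics this test
# exercises: create_indices_tuple is expected to be the union of (a) all
# batch-vs-memory pairs (a1i, pi, a2i, ni) and (b) the externally mined
# within-batch tuple converted to pair form (a1ii, pii, a2ii, nii), so each
# output component should have length len(batch_vs_memory) + len(within_batch).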
def test_shift_indices_tuple(self):
    for dtype in TEST_DTYPES:
        batch_size = 32
        pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
        triplet_miner = TripletMarginMiner(margin=1)
        self.loss = CrossBatchMemory(
            loss=ContrastiveLoss(),
            embedding_size=self.embedding_size,
            memory_size=self.memory_size,
        )
        for i in range(30):
            embeddings = (
                torch.randn(batch_size, self.embedding_size)
                .to(TEST_DEVICE)
                .type(dtype)
            )
            labels = torch.arange(batch_size).to(TEST_DEVICE)
            loss = self.loss(embeddings, labels)
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

            indices_tuple = lmu.get_all_pairs_indices(
                labels, self.loss.label_memory
            )
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
            self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
            self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
            a1, p, a2, n = shifted
            self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

            indices_tuple = pair_miner(
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
            )
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
            self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
            self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
            a1, p, a2, n = shifted
            self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

            indices_tuple = triplet_miner(
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
            )
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
            self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(indices_tuple[2], shifted[2] - batch_size))
            a, p, n = shifted
            self.assertTrue(not torch.any((all_labels[a] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
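# Note on the assertions above: shift_indices_tuple offsets only the indices
# that point into the memory bank (the positive/negative halves of the tuple)
# by batch_size, so they address torch.cat([batch, memory]) correctly. Anchor
# indices point into the current batch and are left unchanged, which is why
# the anchor components compare equal before and after shifting.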
def test_uniform_histogram_miner(self):
    torch.manual_seed(93612)
    batch_size = 128
    embedding_size = 32
    num_bins, pos_per_bin, neg_per_bin = 100, 25, 123
    for distance in [
        LpDistance(p=1),
        LpDistance(p=2),
        LpDistance(normalize_embeddings=False),
        SNRDistance(),
    ]:
        miner = UniformHistogramMiner(
            num_bins=num_bins,
            pos_per_bin=pos_per_bin,
            neg_per_bin=neg_per_bin,
            distance=distance,
        )
        for dtype in TEST_DTYPES:
            embeddings = torch.randn(
                batch_size, embedding_size, device=TEST_DEVICE, dtype=dtype
            )
            labels = torch.randint(0, 2, size=(batch_size,), device=TEST_DEVICE)
            a1, p, a2, n = lmu.get_all_pairs_indices(labels)
            dist_mat = distance(embeddings)
            pos_pairs = dist_mat[a1, p]
            neg_pairs = dist_mat[a2, n]
            a1, p, a2, n = miner(embeddings, labels)
            if dtype == torch.float16:
                continue  # histc doesn't work for Half tensor
            pos_histogram = torch.histc(
                dist_mat[a1, p],
                bins=num_bins,
                min=torch.min(pos_pairs),
                max=torch.max(pos_pairs),
            )
            neg_histogram = torch.histc(
                dist_mat[a2, n],
                bins=num_bins,
                min=torch.min(neg_pairs),
                max=torch.max(neg_pairs),
            )
            self.assertTrue(
                torch.all((pos_histogram == pos_per_bin) | (pos_histogram == 0))
            )
            self.assertTrue(
                torch.all((neg_histogram == neg_per_bin) | (neg_histogram == 0))
            )
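# Sketch of what the assertions above verify, assuming UniformHistogramMiner
# behaves as this test expects: the miner bins all pair distances into num_bins
# equal-width bins and samples at most pos_per_bin / neg_per_bin pairs from
# each bin. Re-histogramming the mined pairs over the same distance range
# should therefore yield bins that are either empty or exactly at the per-bin
# quota, i.e. a uniform profile of mined distances.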
def test_get_all_pairs_triplets_indices(self):
    original_x = torch.arange(10)
    for i in range(1, 11):
        x = original_x.repeat(i)
        # Each label appears i times, so each of the len(x) anchors has
        # (i - 1) positives and len(x) - i negatives.
        correct_num_pos = len(x) * (i - 1)
        correct_num_neg = len(x) * (len(x) - i)
        a1, p, a2, n = lmu.get_all_pairs_indices(x)
        self.assertTrue(len(a1) == len(p) == correct_num_pos)
        self.assertTrue(len(a2) == len(n) == correct_num_neg)
        # Every (anchor, positive) pair combines with each of the anchor's
        # negatives to form a triplet.
        correct_num_triplets = len(x) * (i - 1) * (len(x) - i)
        a, p, n = lmu.get_all_triplets_indices(x)
        self.assertTrue(len(a) == len(p) == len(n) == correct_num_triplets)
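# Worked example for the counting formulas above, with i = 3 (so len(x) = 30):
# each anchor has i - 1 = 2 positives and len(x) - i = 27 negatives, giving
# 30 * 2 = 60 positive pairs, 30 * 27 = 810 negative pairs, and
# 30 * 2 * 27 = 1620 triplets.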
def test_loss(self):
    num_labels = 10
    num_iter = 10
    batch_size = 32
    inner_loss = ContrastiveLoss()
    inner_miner = MultiSimilarityMiner(0.3)
    outer_miner = MultiSimilarityMiner(0.2)
    self.loss = CrossBatchMemory(
        loss=inner_loss,
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    self.loss_with_miner = CrossBatchMemory(
        loss=inner_loss,
        miner=inner_miner,
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    self.loss_with_miner2 = CrossBatchMemory(
        loss=inner_loss,
        miner=inner_miner,
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    all_embeddings = torch.FloatTensor([])
    all_labels = torch.LongTensor([])
    for i in range(num_iter):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.randint(0, num_labels, (batch_size,))
        loss = self.loss(embeddings, labels)
        loss_with_miner = self.loss_with_miner(embeddings, labels)
        oa1, op, oa2, on = outer_miner(embeddings, labels)
        loss_with_miner_and_input_indices = self.loss_with_miner2(
            embeddings, labels, (oa1, op, oa2, on)
        )
        all_embeddings = torch.cat([all_embeddings, embeddings])
        all_labels = torch.cat([all_labels, labels])

        # loss with no inner miner
        indices_tuple = lmu.get_all_pairs_indices(labels, all_labels)
        a1, p, a2, n = self.loss.remove_self_comparisons(indices_tuple)
        p = p + batch_size
        n = n + batch_size
        correct_loss = inner_loss(
            torch.cat([embeddings, all_embeddings], dim=0),
            torch.cat([labels, all_labels], dim=0),
            (a1, p, a2, n),
        )
        self.assertTrue(torch.isclose(loss, correct_loss))

        # loss with inner miner
        indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
        a1, p, a2, n = self.loss_with_miner.remove_self_comparisons(indices_tuple)
        p = p + batch_size
        n = n + batch_size
        correct_loss_with_miner = inner_loss(
            torch.cat([embeddings, all_embeddings], dim=0),
            torch.cat([labels, all_labels], dim=0),
            (a1, p, a2, n),
        )
        self.assertTrue(torch.isclose(loss_with_miner, correct_loss_with_miner))

        # loss with inner and outer miner
        indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
        a1, p, a2, n = self.loss_with_miner2.remove_self_comparisons(indices_tuple)
        p = p + batch_size
        n = n + batch_size
        a1 = torch.cat([oa1, a1])
        p = torch.cat([op, p])
        a2 = torch.cat([oa2, a2])
        n = torch.cat([on, n])
        correct_loss_with_miner_and_input_indices = inner_loss(
            torch.cat([embeddings, all_embeddings], dim=0),
            torch.cat([labels, all_labels], dim=0),
            (a1, p, a2, n),
        )
        self.assertTrue(
            torch.isclose(
                loss_with_miner_and_input_indices,
                correct_loss_with_miner_and_input_indices,
            )
        )
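# The reference computation above mirrors what CrossBatchMemory is expected to
# do internally (per the usage in this test): mine or enumerate pairs between
# the current batch and the accumulated memory, drop self-comparisons (a batch
# embedding paired with its own copy in memory), shift the memory-side indices
# by batch_size so they address torch.cat([embeddings, all_embeddings]), and
# feed the resulting tuple to the inner loss.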
@classmethod
def setUpClass(self):
    self.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])
    self.a1_idx, self.p_idx, self.a2_idx, self.n_idx = lmu.get_all_pairs_indices(
        self.labels
    )
    self.distance = LpDistance(normalize_embeddings=False)
    self.gt = {
        "batch_semihard_hard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.SEMIHARD,
                neg_strategy=BatchEasyHardMiner.HARD,
            ),
            "easiest_triplet": -1,
            "hardest_triplet": -1,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 2,
            "easiest_neg_pair": 3,
            "hardest_neg_pair": 2,
            "expected": {
                "correct_a": torch.LongTensor([0, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([1, 6, 6]).to(TEST_DEVICE),
                    torch.LongTensor([1, 8, 6]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_hard_semihard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.HARD,
                neg_strategy=BatchEasyHardMiner.SEMIHARD,
            ),
            "easiest_triplet": -1,
            "hardest_triplet": -1,
            "easiest_pos_pair": 3,
            "hardest_pos_pair": 6,
            "easiest_neg_pair": 7,
            "hardest_neg_pair": 4,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [torch.LongTensor([4, 4, 2, 2, 2]).to(TEST_DEVICE)],
                "correct_n": [
                    torch.LongTensor([5, 5, 1, 1, 1]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_easy_semihard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.SEMIHARD,
            ),
            "easiest_triplet": -2,
            "hardest_triplet": -1,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 3,
            "easiest_neg_pair": 4,
            "hardest_neg_pair": 2,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([2, 3, 0, 1, 8, 4, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 3, 4, 1, 8, 4, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 3, 0, 5, 8, 4, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 3, 4, 5, 8, 4, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_hard_hard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.HARD,
                neg_strategy=BatchEasyHardMiner.HARD,
            ),
            "easiest_triplet": 2,
            "hardest_triplet": 5,
            "easiest_pos_pair": 3,
            "hardest_pos_pair": 6,
            "easiest_neg_pair": 3,
            "hardest_neg_pair": 1,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_p": [
                    torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(TEST_DEVICE)
                ],
                "correct_n": [
                    torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_easy_hard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.HARD,
            ),
            "easiest_triplet": -2,
            "hardest_triplet": 2,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 3,
            "easiest_neg_pair": 3,
            "hardest_neg_pair": 1,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_hard_easy": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.HARD,
                neg_strategy=BatchEasyHardMiner.EASY,
            ),
            "easiest_triplet": -4,
            "hardest_triplet": 3,
            "easiest_pos_pair": 3,
            "hardest_pos_pair": 6,
            "easiest_neg_pair": 8,
            "hardest_neg_pair": 3,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_p": [
                    torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(TEST_DEVICE)
                ],
                "correct_n": [
                    torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0]).to(TEST_DEVICE)
                ],
            },
        },
        "batch_easy_easy": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.EASY,
            ),
            "easiest_triplet": -7,
            "hardest_triplet": -1,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 3,
            "easiest_neg_pair": 8,
            "hardest_neg_pair": 3,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0]).to(TEST_DEVICE)
                ],
            },
        },
        "batch_easy_easy_with_min_val": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.EASY,
                allowed_neg_range=[1, 7],
                allowed_pos_range=[1, 7],
            ),
            "easiest_triplet": -6,
            "hardest_triplet": -1,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 3,
            "easiest_neg_pair": 7,
            "hardest_neg_pair": 3,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([7, 8, 5, 0, 8, 0, 0, 1]).to(TEST_DEVICE)
                ],
            },
        },
        "batch_easy_all": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.ALL,
            ),
            "easiest_triplet": 0,
            "hardest_triplet": 0,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 3,
            "easiest_neg_pair": 8,
            "hardest_neg_pair": 1,
            "expected": {
                "correct_a1": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_a2": self.a2_idx,
                "correct_n": [self.n_idx],
            },
        },
        "batch_all_easy": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.ALL,
                neg_strategy=BatchEasyHardMiner.EASY,
            ),
            "easiest_triplet": 0,
            "hardest_triplet": 0,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 6,
            "easiest_neg_pair": 8,
            "hardest_neg_pair": 3,
            "expected": {
                "correct_a1": self.a1_idx,
                "correct_p": [self.p_idx],
                "correct_a2": torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_n": [
                    torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0, 0]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_all_all": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.ALL,
                neg_strategy=BatchEasyHardMiner.ALL,
            ),
            "easiest_triplet": 0,
            "hardest_triplet": 0,
            "easiest_pos_pair": 1,
            "hardest_pos_pair": 6,
            "easiest_neg_pair": 8,
            "hardest_neg_pair": 1,
            "expected": {
                "correct_a1": self.a1_idx,
                "correct_p": [self.p_idx],
                "correct_a2": self.a2_idx,
                "correct_n": [self.n_idx],
            },
        },
    }
def test_per_anchor_reducer(self):
    for inner_reducer in [MeanReducer(), AvgNonZeroReducer()]:
        reducer = PerAnchorReducer(inner_reducer)
        batch_size = 100
        embedding_size = 64
        for dtype in TEST_DTYPES:
            embeddings = (
                torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
            )
            labels = torch.randint(0, 10, (batch_size,))
            pos_pair_indices = lmu.get_all_pairs_indices(labels)[:2]
            neg_pair_indices = lmu.get_all_pairs_indices(labels)[2:]
            triplet_indices = lmu.get_all_triplets_indices(labels)
            for indices, reduction_type in [
                (torch.arange(batch_size), "element"),
                (pos_pair_indices, "pos_pair"),
                (neg_pair_indices, "neg_pair"),
                (triplet_indices, "triplet"),
            ]:
                loss_size = (
                    len(indices) if reduction_type == "element" else len(indices[0])
                )
                losses = torch.randn(loss_size).type(dtype).to(TEST_DEVICE)
                loss_dict = {
                    "loss": {
                        "losses": losses,
                        "indices": indices,
                        "reduction_type": reduction_type,
                    }
                }
                if reduction_type == "triplet":
                    self.assertRaises(
                        NotImplementedError,
                        lambda: reducer(loss_dict, embeddings, labels),
                    )
                    continue
                output = reducer(loss_dict, embeddings, labels)
                if reduction_type == "element":
                    loss_dict = {
                        "loss": {
                            "losses": losses,
                            "indices": c_f.torch_arange_from_size(embeddings),
                            "reduction_type": "element",
                        }
                    }
                else:
                    anchors = indices[0]
                    correct_output = torch.zeros(
                        batch_size, device=TEST_DEVICE, dtype=dtype
                    )
                    for i in range(len(embeddings)):
                        matching_pairs_mask = anchors == i
                        num_matching_pairs = torch.sum(matching_pairs_mask)
                        if num_matching_pairs > 0:
                            correct_output[i] = (
                                torch.sum(losses[matching_pairs_mask])
                                / num_matching_pairs
                            )
                    loss_dict = {
                        "loss": {
                            "losses": correct_output,
                            "indices": c_f.torch_arange_from_size(embeddings),
                            "reduction_type": "element",
                        }
                    }
                correct_output = inner_reducer(loss_dict, embeddings, labels)
                self.assertTrue(torch.isclose(output, correct_output))
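# Reference semantics assumed by the test above: PerAnchorReducer first
# averages each anchor's pair losses into one loss per element (anchors with
# no pairs are left at zero in the reference computation), then hands the
# resulting element-wise losses to the wrapped reducer; triplet losses are
# unsupported and raise NotImplementedError.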
def to_pairs(y):
    # Return the set of unordered positive-pair index tuples for labels y.
    p1_inds, p2_inds, _, _ = get_all_pairs_indices(torch.as_tensor(y))
    return {
        tuple(sorted((float(p1_inds[i]), float(p2_inds[i]))))
        for i in range(len(p1_inds))
    }
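# Hypothetical usage (labels below are illustrative): for y = [0, 0, 1], the
# only positive pair is between indices 0 and 1, so
#   to_pairs([0, 0, 1]) == {(0.0, 1.0)}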