def test_with_no_valid_pairs(self):
    """A batch with a single embedding forms no pairs, so both loss variants must return 0."""
    dist_loss = ContrastiveLoss(use_similarity=False)
    sim_loss = ContrastiveLoss(use_similarity=True)
    angles = [0]
    # 2D embeddings
    embeddings = torch.FloatTensor([c_f.angle_to_coord(a) for a in angles])
    labels = torch.LongTensor([0])
    self.assertEqual(dist_loss(embeddings, labels), 0)
    self.assertEqual(sim_loss(embeddings, labels), 0)
def test_backward(self):
    """backward() must not error for either loss variant, across several dtypes."""
    dist_loss = ContrastiveLoss(use_similarity=False)
    sim_loss = ContrastiveLoss(use_similarity=True)
    for dtype in [torch.float16, torch.float32, torch.float64]:
        for loss_fn in [dist_loss, sim_loss]:
            angles = [0]
            # 2D embeddings
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)
            labels = torch.LongTensor([0])
            loss = loss_fn(embeddings, labels)
            loss.backward()
def test_backward(self):
    """backward() must not error for LpDistance- and CosineSimilarity-based losses."""
    lp_loss = ContrastiveLoss()
    cos_loss = ContrastiveLoss(distance=CosineSimilarity())
    for dtype in TEST_DTYPES:
        for loss_fn in [lp_loss, cos_loss]:
            angles = [0]
            # 2D embeddings
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)
            labels = torch.LongTensor([0])
            loss = loss_fn(embeddings, labels)
            loss.backward()
def test_with_no_valid_pairs(self):
    """With a single embedding there are no pairs; both distance types must yield 0."""
    lp_loss = ContrastiveLoss()
    cos_loss = ContrastiveLoss(distance=CosineSimilarity())
    for dtype in TEST_DTYPES:
        angles = [0]
        # 2D embeddings
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)
        labels = torch.LongTensor([0])
        self.assertEqual(lp_loss(embeddings, labels), 0)
        self.assertEqual(cos_loss(embeddings, labels), 0)
def test_with_no_valid_pairs(self):
    """Pair-less batch must give a 0 loss for both variants, across several dtypes."""
    dist_loss = ContrastiveLoss(use_similarity=False)
    sim_loss = ContrastiveLoss(use_similarity=True)
    for dtype in [torch.float16, torch.float32, torch.float64]:
        angles = [0]
        # 2D embeddings
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)
        labels = torch.LongTensor([0])
        self.assertEqual(dist_loss(embeddings, labels), 0)
        self.assertEqual(sim_loss(embeddings, labels), 0)
def test_input_indices_tuple(self):
    """create_indices_tuple must merge caller-supplied indices with memory-derived ones.

    For each miner, the merged tuple's length per column must equal the sum of
    the memory-pair count and the converted input-tuple count, and the merged
    pairs must still respect label agreement (positives) / disagreement (negatives).
    """
    batch_size = 32
    pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1, use_similarity=False)
    triplet_miner = TripletMarginMiner(margin=1)
    self.loss = CrossBatchMemory(
        loss=ContrastiveLoss(),
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    for i in range(30):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.arange(batch_size)
        self.loss(embeddings, labels)
        for curr_miner in [pair_miner, triplet_miner]:
            input_indices_tuple = curr_miner(embeddings, labels)
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)
            # Input tuple converted to pair form (triplets become pairs).
            a1ii, pii, a2ii, nii = lmu.convert_to_pairs(input_indices_tuple, labels)
            # All pairs between the batch and the memory.
            a1i, pi, a2i, ni = lmu.get_all_pairs_indices(labels, self.loss.label_memory)
            a1, p, a2, n = self.loss.create_indices_tuple(
                batch_size,
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
                input_indices_tuple,
            )
            # Positives share a label; negatives never do.
            self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))
            # Merged counts are exactly memory pairs + input pairs.
            self.assertTrue(len(a1) == len(a1i) + len(a1ii))
            self.assertTrue(len(p) == len(pi) + len(pii))
            self.assertTrue(len(a2) == len(a2i) + len(a2ii))
            self.assertTrue(len(n) == len(ni) + len(nii))
def test_key_mismatch(self):
    """MultipleLosses must reject weight/miner dicts whose keys don't match the losses."""
    contrastive = ContrastiveLoss()
    triplet = TripletMarginLoss(0.1)
    # Weight key "blah" does not match loss key "lossA".
    self.assertRaises(
        AssertionError,
        lambda: MultipleLosses(
            losses={"lossA": contrastive, "lossB": triplet},
            weights={"blah": 1, "lossB": 0.23},
        ),
    )
    miner = MultiSimilarityMiner()
    # Miner key "blah" does not match any loss key.
    self.assertRaises(
        AssertionError,
        lambda: MultipleLosses(
            losses={"lossA": contrastive, "lossB": triplet},
            weights={"lossA": 1, "lossB": 0.23},
            miners={"blah": miner},
        ),
    )
def test_multiple_losses(self):
    """MultipleLosses output must equal the weighted sum of its component losses."""
    contrastive = ContrastiveLoss()
    triplet = TripletMarginLoss(0.1)
    combined = MultipleLosses(
        losses={"lossA": contrastive, "lossB": triplet},
        weights={"lossA": 1, "lossB": 0.23},
    )
    for dtype in TEST_DTYPES:
        angles = torch.arange(0, 180)
        # 2D embeddings
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)
        labels = torch.randint(low=0, high=10, size=(180,))
        loss = combined(embeddings, labels)
        loss.backward()
        expected = contrastive(embeddings, labels) + triplet(embeddings, labels) * 0.23
        self.assertTrue(torch.isclose(loss, expected))
def test_shift_indices_tuple(self):
    """c_f.shift_indices_tuple must offset memory-side index columns by batch_size.

    For pair tuples (a1, p, a2, n), anchors (columns 0 and 2) stay put while
    positives/negatives (columns 1 and 3) shift into the concatenated
    [batch, memory] index space. For triplet tuples (a, p, n), only the
    anchor column stays put.
    """

    def check_shifted_pairs(original, shifted, all_labels, batch_size):
        # Anchor columns untouched, pos/neg columns shifted by batch_size.
        self.assertTrue(torch.equal(original[0], shifted[0]))
        self.assertTrue(torch.equal(original[2], shifted[2]))
        self.assertTrue(torch.equal(original[1], shifted[1] - batch_size))
        self.assertTrue(torch.equal(original[3], shifted[3] - batch_size))
        a1, p, a2, n = shifted
        # Shifted indices still pair same-label positives and different-label negatives.
        self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
        self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

    for dtype in TEST_DTYPES:
        batch_size = 32
        pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
        triplet_miner = TripletMarginMiner(margin=1)
        self.loss = CrossBatchMemory(
            loss=ContrastiveLoss(),
            embedding_size=self.embedding_size,
            memory_size=self.memory_size,
        )
        for i in range(30):
            embeddings = (
                torch.randn(batch_size, self.embedding_size)
                .to(TEST_DEVICE)
                .type(dtype)
            )
            labels = torch.arange(batch_size).to(TEST_DEVICE)
            loss = self.loss(embeddings, labels)
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

            # Case 1: exhaustive batch-vs-memory pairs.
            indices_tuple = lmu.get_all_pairs_indices(labels, self.loss.label_memory)
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            check_shifted_pairs(indices_tuple, shifted, all_labels, batch_size)

            # Case 2: pairs produced by a pair miner.
            indices_tuple = pair_miner(
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
            )
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            check_shifted_pairs(indices_tuple, shifted, all_labels, batch_size)

            # Case 3: triplets — only pos and neg columns get shifted.
            indices_tuple = triplet_miner(
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
            )
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
            self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(indices_tuple[2], shifted[2] - batch_size))
            a, p, n = shifted
            self.assertTrue(not torch.any((all_labels[a] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
def test_input_indices_tuple(self):
    """MultipleLosses must forward a caller-supplied indices_tuple to each sub-loss."""
    contrastive = ContrastiveLoss()
    triplet = TripletMarginLoss(0.1)
    miner = MultiSimilarityMiner()
    # Dict-configured and list-configured instances should behave the same.
    dict_version = MultipleLosses(
        losses={"lossA": contrastive, "lossB": triplet},
        weights={"lossA": 1, "lossB": 0.23},
    )
    list_version = MultipleLosses(losses=[contrastive, triplet], weights=[1, 0.23])
    for combined in [dict_version, list_version]:
        for dtype in TEST_DTYPES:
            angles = torch.arange(0, 180)
            # 2D embeddings
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)
            labels = torch.randint(low=0, high=10, size=(180,))
            indices_tuple = miner(embeddings, labels)
            loss = combined(embeddings, labels, indices_tuple)
            loss.backward()
            expected = (
                contrastive(embeddings, labels, indices_tuple)
                + triplet(embeddings, labels, indices_tuple) * 0.23
            )
            self.assertTrue(torch.isclose(loss, expected))
def test_deepcopy_reducer(self):
    """A user-supplied reducer must be cloned per sub-loss and record pair statistics."""
    loss_fn = ContrastiveLoss(pos_margin=0, neg_margin=2, reducer=AvgNonZeroReducer())
    embeddings = torch.randn(128, 64)
    labels = torch.randint(low=0, high=10, size=(128,))
    loss = loss_fn(embeddings, labels)
    # Each sub-loss ("pos_loss"/"neg_loss") must have its own reducer that saw pairs.
    self.assertTrue(loss_fn.reducer.reducers["pos_loss"].pos_pairs_past_filter > 0)
    self.assertTrue(loss_fn.reducer.reducers["neg_loss"].neg_pairs_past_filter > 0)
def test_loss(self):
    """CrossBatchMemory must equal the inner loss computed over [batch, memory].

    Checks three configurations: no miner, an inner miner, and an inner miner
    combined with a caller-supplied (outer-mined) indices tuple.
    """
    num_labels = 10
    num_iter = 10
    batch_size = 32
    inner_loss = ContrastiveLoss()
    inner_miner = MultiSimilarityMiner(0.3)
    outer_miner = MultiSimilarityMiner(0.2)
    self.loss = CrossBatchMemory(
        loss=inner_loss,
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    self.loss_with_miner = CrossBatchMemory(
        loss=inner_loss,
        miner=inner_miner,
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    self.loss_with_miner2 = CrossBatchMemory(
        loss=inner_loss,
        miner=inner_miner,
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    all_embeddings = torch.FloatTensor([])
    all_labels = torch.LongTensor([])
    for i in range(num_iter):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.randint(0, num_labels, (batch_size,))
        loss = self.loss(embeddings, labels)
        loss_with_miner = self.loss_with_miner(embeddings, labels)
        oa1, op, oa2, on = outer_miner(embeddings, labels)
        loss_with_miner_and_input_indices = self.loss_with_miner2(
            embeddings, labels, (oa1, op, oa2, on)
        )
        # Grow the reference memory the same way CrossBatchMemory does.
        all_embeddings = torch.cat([all_embeddings, embeddings])
        all_labels = torch.cat([all_labels, labels])

        # --- loss with no inner miner ---
        indices_tuple = lmu.get_all_pairs_indices(labels, all_labels)
        a1, p, a2, n = self.loss.remove_self_comparisons(indices_tuple)
        # p/n index into the memory, which sits after the batch in the concat.
        p = p + batch_size
        n = n + batch_size
        correct_loss = inner_loss(
            torch.cat([embeddings, all_embeddings], dim=0),
            torch.cat([labels, all_labels], dim=0),
            (a1, p, a2, n),
        )
        self.assertTrue(torch.isclose(loss, correct_loss))

        # --- loss with inner miner ---
        indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
        a1, p, a2, n = self.loss_with_miner.remove_self_comparisons(indices_tuple)
        p = p + batch_size
        n = n + batch_size
        correct_loss_with_miner = inner_loss(
            torch.cat([embeddings, all_embeddings], dim=0),
            torch.cat([labels, all_labels], dim=0),
            (a1, p, a2, n),
        )
        self.assertTrue(torch.isclose(loss_with_miner, correct_loss_with_miner))

        # --- loss with inner and outer miner ---
        indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
        a1, p, a2, n = self.loss_with_miner2.remove_self_comparisons(indices_tuple)
        p = p + batch_size
        n = n + batch_size
        # The outer-mined (batch-only) indices are prepended to the mined ones.
        a1 = torch.cat([oa1, a1])
        p = torch.cat([op, p])
        a2 = torch.cat([oa2, a2])
        n = torch.cat([on, n])
        correct_loss_with_miner_and_input_indice = inner_loss(
            torch.cat([embeddings, all_embeddings], dim=0),
            torch.cat([labels, all_labels], dim=0),
            (a1, p, a2, n),
        )
        self.assertTrue(
            torch.isclose(
                loss_with_miner_and_input_indices,
                correct_loss_with_miner_and_input_indice,
            )
        )
def test_length_mistmatch(self):
    """MultipleLosses must reject lists of mismatched lengths.

    NOTE(review): the "mistmatch" typo in the method name is kept on purpose —
    renaming would change the test id.
    """
    contrastive = ContrastiveLoss()
    triplet = TripletMarginLoss(0.1)
    # Two losses but only one weight.
    self.assertRaises(
        AssertionError,
        lambda: MultipleLosses(losses=[contrastive, triplet], weights=[1]),
    )
    miner = MultiSimilarityMiner()
    # Two losses but only one miner.
    self.assertRaises(
        AssertionError,
        lambda: MultipleLosses(
            losses=[contrastive, triplet],
            weights=[1, 0.2],
            miners=[miner],
        ),
    )
def test_queue(self):
    """The circular memory must hold the most recent batch at the expected slots."""
    for dtype in [torch.float16, torch.float32, torch.float64]:
        batch_size = 32
        self.loss = CrossBatchMemory(
            loss=ContrastiveLoss(),
            embedding_size=self.embedding_size,
            memory_size=self.memory_size,
        )
        for i in range(30):
            embeddings = (
                torch.randn(batch_size, self.embedding_size)
                .to(self.device)
                .type(dtype)
            )
            labels = torch.arange(batch_size).to(self.device)
            q = self.loss.queue_idx
            # Before the call, the write cursor must be at i * batch_size (mod memory).
            self.assertTrue(q == (i * batch_size) % self.memory_size)
            loss = self.loss(embeddings, labels)
            start_idx = q
            # End index: landing exactly on memory_size means the slice ends there
            # (not at 0), otherwise wrap with modulo.
            if q + batch_size == self.memory_size:
                end_idx = self.memory_size
            else:
                end_idx = (q + batch_size) % self.memory_size
            if start_idx < end_idx:
                # Contiguous write.
                self.assertTrue(
                    torch.equal(
                        embeddings, self.loss.embedding_memory[start_idx:end_idx]
                    )
                )
                self.assertTrue(
                    torch.equal(labels, self.loss.label_memory[start_idx:end_idx])
                )
            else:
                # Wrapped write: tail of memory plus head of memory.
                correct_embeddings = torch.cat(
                    [
                        self.loss.embedding_memory[start_idx:],
                        self.loss.embedding_memory[:end_idx],
                    ],
                    dim=0,
                )
                correct_labels = torch.cat(
                    [
                        self.loss.label_memory[start_idx:],
                        self.loss.label_memory[:end_idx],
                    ],
                    dim=0,
                )
                self.assertTrue(torch.equal(embeddings, correct_embeddings))
                self.assertTrue(torch.equal(labels, correct_labels))
def test_contrastive_loss(self):
    """Compare ContrastiveLoss (legacy API) against a hand-computed pairwise loss.

    A/B average over non-zero pair losses only; C/D average over all pairs.
    A/C use squared L2 distance; B/D use dot-product similarity.
    """
    loss_funcA = ContrastiveLoss(
        pos_margin=0.25,
        neg_margin=1.5,
        use_similarity=False,
        avg_non_zero_only=True,
        squared_distances=True,
    )
    loss_funcB = ContrastiveLoss(
        pos_margin=1.5,
        neg_margin=0.6,
        use_similarity=True,
        avg_non_zero_only=True,
    )
    loss_funcC = ContrastiveLoss(
        pos_margin=0.25,
        neg_margin=1.5,
        use_similarity=False,
        avg_non_zero_only=False,
        squared_distances=True,
    )
    loss_funcD = ContrastiveLoss(
        pos_margin=1.5,
        neg_margin=0.6,
        use_similarity=True,
        avg_non_zero_only=False,
    )
    embedding_angles = [0, 20, 40, 60, 80]
    # 2D embeddings
    embeddings = torch.FloatTensor([c_f.angle_to_coord(a) for a in embedding_angles])
    labels = torch.LongTensor([0, 0, 1, 1, 2])
    lossA = loss_funcA(embeddings, labels)
    lossB = loss_funcB(embeddings, labels)
    lossC = loss_funcC(embeddings, labels)
    lossD = loss_funcD(embeddings, labels)

    pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    neg_pairs = [
        (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1),
        (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3),
    ]
    correct_pos_losses = [0, 0, 0, 0]
    correct_neg_losses = [0, 0, 0, 0]
    num_non_zero_pos = [0, 0, 0, 0]
    num_non_zero_neg = [0, 0, 0, 0]

    for a, p in pos_pairs:
        anchor, positive = embeddings[a], embeddings[p]
        # Squared-distance positive loss (A/C) and similarity positive loss (B/D).
        correct_lossA = torch.relu(torch.sum((anchor - positive) ** 2) - 0.25)
        correct_lossB = torch.relu(1.5 - torch.matmul(anchor, positive))
        correct_pos_losses[0] += correct_lossA
        correct_pos_losses[1] += correct_lossB
        correct_pos_losses[2] += correct_lossA
        correct_pos_losses[3] += correct_lossB
        if correct_lossA > 0:
            num_non_zero_pos[0] += 1
            num_non_zero_pos[2] += 1
        if correct_lossB > 0:
            num_non_zero_pos[1] += 1
            num_non_zero_pos[3] += 1

    for a, n in neg_pairs:
        anchor, negative = embeddings[a], embeddings[n]
        correct_lossA = torch.relu(1.5 - torch.sum((anchor - negative) ** 2))
        correct_lossB = torch.relu(torch.matmul(anchor, negative) - 0.6)
        correct_neg_losses[0] += correct_lossA
        correct_neg_losses[1] += correct_lossB
        correct_neg_losses[2] += correct_lossA
        correct_neg_losses[3] += correct_lossB
        if correct_lossA > 0:
            num_non_zero_neg[0] += 1
            num_non_zero_neg[2] += 1
        if correct_lossB > 0:
            num_non_zero_neg[1] += 1
            num_non_zero_neg[3] += 1

    # A and B: average over non-zero pairs only.
    for i in range(2):
        if num_non_zero_pos[i] > 0:
            correct_pos_losses[i] /= num_non_zero_pos[i]
        if num_non_zero_neg[i] > 0:
            correct_neg_losses[i] /= num_non_zero_neg[i]
    # C and D: plain mean over all pairs.
    for i in range(2, 4):
        correct_pos_losses[i] /= len(pos_pairs)
        correct_neg_losses[i] /= len(neg_pairs)

    correct_losses = [0, 0, 0, 0]
    for i in range(4):
        correct_losses[i] = correct_pos_losses[i] + correct_neg_losses[i]

    self.assertTrue(torch.isclose(lossA, correct_losses[0]))
    self.assertTrue(torch.isclose(lossB, correct_losses[1]))
    self.assertTrue(torch.isclose(lossC, correct_losses[2]))
    self.assertTrue(torch.isclose(lossD, correct_losses[3]))
def test_queue(self):
    """The memory queue must store either the whole batch or only enqueue_idx rows.

    With enqueue_idx, only the selected embeddings/labels are written and the
    cursor advances by the enqueue batch size instead of the full batch size.
    """
    for test_enqueue_idx in [False, True]:
        for dtype in TEST_DTYPES:
            batch_size = 32
            enqueue_batch_size = 15
            self.loss = CrossBatchMemory(
                loss=ContrastiveLoss(),
                embedding_size=self.embedding_size,
                memory_size=self.memory_size,
            )
            for i in range(30):
                embeddings = (
                    torch.randn(batch_size, self.embedding_size)
                    .to(TEST_DEVICE)
                    .type(dtype)
                )
                labels = torch.arange(batch_size).to(TEST_DEVICE)
                q = self.loss.queue_idx
                # Effective number of rows written this iteration.
                B = enqueue_batch_size if test_enqueue_idx else batch_size
                if test_enqueue_idx:
                    enqueue_idx = torch.arange(enqueue_batch_size) * 2
                else:
                    enqueue_idx = None
                self.assertTrue(q == (i * B) % self.memory_size)
                loss = self.loss(embeddings, labels, enqueue_idx=enqueue_idx)
                start_idx = q
                # Hitting memory_size exactly keeps the slice end there; otherwise wrap.
                if q + B == self.memory_size:
                    end_idx = self.memory_size
                else:
                    end_idx = (q + B) % self.memory_size
                if start_idx < end_idx:
                    # Contiguous write.
                    if test_enqueue_idx:
                        self.assertTrue(
                            torch.equal(
                                embeddings[enqueue_idx],
                                self.loss.embedding_memory[start_idx:end_idx],
                            )
                        )
                        self.assertTrue(
                            torch.equal(
                                labels[enqueue_idx],
                                self.loss.label_memory[start_idx:end_idx],
                            )
                        )
                    else:
                        self.assertTrue(
                            torch.equal(
                                embeddings,
                                self.loss.embedding_memory[start_idx:end_idx],
                            )
                        )
                        self.assertTrue(
                            torch.equal(
                                labels, self.loss.label_memory[start_idx:end_idx]
                            )
                        )
                else:
                    # Wrapped write: tail plus head of the circular buffer.
                    correct_embeddings = torch.cat(
                        [
                            self.loss.embedding_memory[start_idx:],
                            self.loss.embedding_memory[:end_idx],
                        ],
                        dim=0,
                    )
                    correct_labels = torch.cat(
                        [
                            self.loss.label_memory[start_idx:],
                            self.loss.label_memory[:end_idx],
                        ],
                        dim=0,
                    )
                    if test_enqueue_idx:
                        self.assertTrue(
                            torch.equal(embeddings[enqueue_idx], correct_embeddings)
                        )
                        self.assertTrue(
                            torch.equal(labels[enqueue_idx], correct_labels)
                        )
                    else:
                        self.assertTrue(torch.equal(embeddings, correct_embeddings))
                        self.assertTrue(torch.equal(labels, correct_labels))
def test_contrastive_loss(self):
    """Compare ContrastiveLoss (distance/reducer API) against a hand-computed loss.

    A/B use the default reducer (averages non-zero pair losses); C/D use
    MeanReducer (averages over all pairs). A/C use squared L2 distance;
    B/D use cosine similarity (dot product, since inputs are unit vectors
    per c_f.angle_to_coord).
    """
    loss_funcA = ContrastiveLoss(
        pos_margin=0.25, neg_margin=1.5, distance=LpDistance(power=2)
    )
    loss_funcB = ContrastiveLoss(
        pos_margin=1.5, neg_margin=0.6, distance=CosineSimilarity()
    )
    loss_funcC = ContrastiveLoss(
        pos_margin=0.25,
        neg_margin=1.5,
        distance=LpDistance(power=2),
        reducer=MeanReducer(),
    )
    loss_funcD = ContrastiveLoss(
        pos_margin=1.5,
        neg_margin=0.6,
        distance=CosineSimilarity(),
        reducer=MeanReducer(),
    )
    for dtype in TEST_DTYPES:
        embedding_angles = [0, 20, 40, 60, 80]
        # 2D embeddings
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)
        labels = torch.LongTensor([0, 0, 1, 1, 2])
        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        lossC = loss_funcC(embeddings, labels)
        lossD = loss_funcD(embeddings, labels)

        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1),
            (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3),
        ]
        correct_pos_losses = [0, 0, 0, 0]
        correct_neg_losses = [0, 0, 0, 0]
        num_non_zero_pos = [0, 0, 0, 0]
        num_non_zero_neg = [0, 0, 0, 0]

        for a, p in pos_pairs:
            anchor, positive = embeddings[a], embeddings[p]
            # Squared-distance positive loss (A/C); similarity positive loss (B/D).
            correct_lossA = torch.relu(torch.sum((anchor - positive) ** 2) - 0.25)
            correct_lossB = torch.relu(1.5 - torch.matmul(anchor, positive))
            correct_pos_losses[0] += correct_lossA
            correct_pos_losses[1] += correct_lossB
            correct_pos_losses[2] += correct_lossA
            correct_pos_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_pos[0] += 1
                num_non_zero_pos[2] += 1
            if correct_lossB > 0:
                num_non_zero_pos[1] += 1
                num_non_zero_pos[3] += 1

        for a, n in neg_pairs:
            anchor, negative = embeddings[a], embeddings[n]
            correct_lossA = torch.relu(1.5 - torch.sum((anchor - negative) ** 2))
            correct_lossB = torch.relu(torch.matmul(anchor, negative) - 0.6)
            correct_neg_losses[0] += correct_lossA
            correct_neg_losses[1] += correct_lossB
            correct_neg_losses[2] += correct_lossA
            correct_neg_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_neg[0] += 1
                num_non_zero_neg[2] += 1
            if correct_lossB > 0:
                num_non_zero_neg[1] += 1
                num_non_zero_neg[3] += 1

        # A and B: average over non-zero pairs only.
        for i in range(2):
            if num_non_zero_pos[i] > 0:
                correct_pos_losses[i] /= num_non_zero_pos[i]
            if num_non_zero_neg[i] > 0:
                correct_neg_losses[i] /= num_non_zero_neg[i]
        # C and D: plain mean over all pairs.
        for i in range(2, 4):
            correct_pos_losses[i] /= len(pos_pairs)
            correct_neg_losses[i] /= len(neg_pairs)

        correct_losses = [0, 0, 0, 0]
        for i in range(4):
            correct_losses[i] = correct_pos_losses[i] + correct_neg_losses[i]

        # Looser tolerance for half precision.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(lossA, correct_losses[0], rtol=rtol))
        self.assertTrue(torch.isclose(lossB, correct_losses[1], rtol=rtol))
        self.assertTrue(torch.isclose(lossC, correct_losses[2], rtol=rtol))
        self.assertTrue(torch.isclose(lossD, correct_losses[3], rtol=rtol))