# Shared imports for the snippets below. TEST_DEVICE and TEST_DTYPES come from
# the test suite's configuration, and HardTripletMinerWithMasks is an external
# helper; neither is defined in this file.
import numpy as np
import pandas as pd
import torch
from scipy import stats

from pytorch_metric_learning import losses, reducers
from pytorch_metric_learning.distances import CosineSimilarity, LpDistance, SNRDistance
from pytorch_metric_learning.losses import (
    ArcFaceLoss,
    ContrastiveLoss,
    NormalizedSoftmaxLoss,
    NTXentLoss,
    ProxyAnchorLoss,
    SupConLoss,
    TripletMarginLoss,
)
from pytorch_metric_learning.miners import (
    BatchEasyHardMiner,
    BatchHardMiner,
    HDCMiner,
    MultiSimilarityMiner,
    PairMarginMiner,
    TripletMarginMiner,
    UniformHistogramMiner,
)
from pytorch_metric_learning.reducers import (
    AvgNonZeroReducer,
    MeanReducer,
    PerAnchorReducer,
)
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from pytorch_metric_learning.utils.inference import LogitGetter


class ReferenceDistanceLoss(torch.nn.Module):
    def __init__(self, pos_embs=None, neg_embs=None, **kwargs):
        super().__init__(**kwargs)
        self.pos_embs = pos_embs.detach().clone() if pos_embs is not None else None
        self.neg_embs = neg_embs.detach().clone() if neg_embs is not None else None
        self.distance = LpDistance(p=2)

    def forward(self, embeddings, *args, reduction="mean"):
        if len(embeddings) == 0:
            # empty batch: nothing to compare, so the loss is zero
            return torch.tensor(0.0, device=embeddings.device)
        pos_loss = torch.tensor(0.0, device=embeddings.device)
        if self.pos_embs is not None:
            pos_loss = self.distance.pairwise_distance(embeddings, self.pos_embs)
        neg_loss = torch.tensor(0.0, device=embeddings.device)
        if self.neg_embs is not None:
            neg_loss = self.distance.pairwise_distance(embeddings, self.neg_embs)
        loss = pos_loss - neg_loss
        if reduction == "mean":
            return loss.mean()
        if reduction == "none":
            return loss
        raise ValueError(f"unknown reduction: {reduction}")
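# A minimal usage sketch for ReferenceDistanceLoss above. The shapes and
# reference tensors are made up for illustration; note that pos_embs and
# neg_embs must be row-aligned with the batch, since pairwise_distance
# compares row i of the batch with row i of the reference.
batch = torch.randn(8, 32)
pos_refs = torch.randn(8, 32)  # rows the embeddings should move toward
neg_refs = torch.randn(8, 32)  # rows the embeddings should move away from
loss_fn = ReferenceDistanceLoss(pos_embs=pos_refs, neg_embs=neg_refs)
scalar_loss = loss_fn(batch)                     # mean over the batch
per_row_loss = loss_fn(batch, reduction="none")  # one value per row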
def test_uniform_histogram_miner(self):
    torch.manual_seed(93612)
    batch_size = 128
    embedding_size = 32
    num_bins, pos_per_bin, neg_per_bin = 100, 25, 123
    for distance in [
        LpDistance(p=1),
        LpDistance(p=2),
        LpDistance(normalize_embeddings=False),
        SNRDistance(),
    ]:
        miner = UniformHistogramMiner(
            num_bins=num_bins,
            pos_per_bin=pos_per_bin,
            neg_per_bin=neg_per_bin,
            distance=distance,
        )
        for dtype in TEST_DTYPES:
            embeddings = torch.randn(
                batch_size, embedding_size, device=TEST_DEVICE, dtype=dtype
            )
            labels = torch.randint(0, 2, size=(batch_size,), device=TEST_DEVICE)
            a1, p, a2, n = lmu.get_all_pairs_indices(labels)
            dist_mat = distance(embeddings)
            pos_pairs = dist_mat[a1, p]
            neg_pairs = dist_mat[a2, n]
            a1, p, a2, n = miner(embeddings, labels)
            if dtype == torch.float16:
                continue  # histc doesn't work for Half tensors
            pos_histogram = torch.histc(
                dist_mat[a1, p],
                bins=num_bins,
                min=torch.min(pos_pairs),
                max=torch.max(pos_pairs),
            )
            neg_histogram = torch.histc(
                dist_mat[a2, n],
                bins=num_bins,
                min=torch.min(neg_pairs),
                max=torch.max(neg_pairs),
            )
            self.assertTrue(
                torch.all((pos_histogram == pos_per_bin) | (pos_histogram == 0))
            )
            self.assertTrue(
                torch.all((neg_histogram == neg_per_bin) | (neg_histogram == 0))
            )
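# A standalone sketch of the behavior the test above verifies, outside the
# unittest harness. The sizes and the distance choice are illustrative.
emb = torch.randn(128, 32)
lbl = torch.randint(0, 2, size=(128,))
hist_miner = UniformHistogramMiner(
    num_bins=100, pos_per_bin=25, neg_per_bin=123, distance=SNRDistance()
)
a1, p, a2, n = hist_miner(emb, lbl)  # pairs sampled uniformly across distance bins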
@classmethod
def setUpClass(self):
    self.device = torch.device("cuda")
    self.dist_miner = BatchHardMiner(
        distance=LpDistance(normalize_embeddings=False)
    )
    self.normalized_dist_miner = BatchHardMiner(
        distance=LpDistance(normalize_embeddings=True)
    )
    self.normalized_dist_miner_squared = BatchHardMiner(
        distance=LpDistance(normalize_embeddings=True, power=2)
    )
    self.sim_miner = BatchHardMiner(distance=CosineSimilarity())
    self.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])
    self.correct_a = torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(self.device)
    self.correct_p = torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(self.device)
    self.correct_n = [
        torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(self.device),
        torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(self.device),
    ]
def test_with_no_valid_pairs(self):
    all_embedding_angles = [[0], [0, 10, 20], [0, 40, 60]]
    all_labels = [
        torch.LongTensor([0]),
        torch.LongTensor([0, 0, 0]),
        torch.LongTensor([1, 2, 3]),
    ]
    temperature = 0.1
    for loss_class in [NTXentLoss, SupConLoss]:
        loss_funcA = loss_class(temperature)
        loss_funcB = loss_class(temperature, distance=LpDistance())
        for loss_func in [loss_funcA, loss_funcB]:
            for dtype in TEST_DTYPES:
                for embedding_angles, labels in zip(all_embedding_angles, all_labels):
                    embeddings = torch.tensor(
                        [c_f.angle_to_coord(a) for a in embedding_angles],
                        requires_grad=True,
                        dtype=dtype,
                    ).to(TEST_DEVICE)  # 2D embeddings
                    loss = loss_func(embeddings, labels)
                    loss.backward()
                    self.assertEqual(loss, 0)
def __init__(self, margin, normalize_embeddings):
    self.margin = margin
    self.distance = LpDistance(
        normalize_embeddings=normalize_embeddings, collect_stats=True
    )
    # We use triplet loss with Euclidean distance
    self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
    reducer_fn = reducers.AvgNonZeroReducer(collect_stats=True)
    self.loss_fn = losses.TripletMarginLoss(
        margin=self.margin,
        swap=True,
        distance=self.distance,
        reducer=reducer_fn,
        collect_stats=True,
    )
def __init__(self, pos_margin, neg_margin, normalize_embeddings):
    self.pos_margin = pos_margin
    self.neg_margin = neg_margin
    self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
    self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
    # We use contrastive loss with Euclidean distance
    # (LpDistance defaults to power=1, i.e. unsquared)
    self.loss_fn = losses.ContrastiveLoss(
        pos_margin=self.pos_margin,
        neg_margin=self.neg_margin,
        distance=self.distance,
    )
def __init__(self, margin, normalize_embeddings):
    self.margin = margin
    self.normalize_embeddings = normalize_embeddings
    self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
    # We use triplet loss with Euclidean distance
    self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
    self.loss_fn = losses.TripletMarginLoss(
        margin=self.margin, swap=True, distance=self.distance
    )
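# The wrappers above only show __init__. A minimal sketch of the equivalent
# mine-then-compute wiring, substituting the library's TripletMarginMiner for
# the external HardTripletMinerWithMasks helper (which is not defined here);
# the margin and tensor sizes are illustrative.
sketch_distance = LpDistance(normalize_embeddings=True)
sketch_miner = TripletMarginMiner(margin=0.1, distance=sketch_distance)
sketch_loss_fn = losses.TripletMarginLoss(
    margin=0.1,
    swap=True,
    distance=sketch_distance,
    reducer=reducers.AvgNonZeroReducer(),
)
emb = torch.randn(16, 8)
lbl = torch.randint(0, 3, size=(16,))
loss = sketch_loss_fn(emb, lbl, sketch_miner(emb, lbl))  # mined triplets feed the loss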
def test_pair_margin_miner(self):
    for dtype in TEST_DTYPES:
        for distance in [LpDistance(), CosineSimilarity()]:
            embedding_angles = torch.arange(0, 16)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=2, size=(16,))
            mat = distance(embeddings)
            pos_pairs = []
            neg_pairs = []
            for i in range(len(embeddings)):
                anchor_label = labels[i]
                for j in range(len(embeddings)):
                    if j == i:
                        continue
                    if labels[j] == anchor_label:
                        pos_pairs.append((i, j, mat[i, j]))
                    else:
                        neg_pairs.append((i, j, mat[i, j]))
            for pos_margin_int in range(-1, 4):
                pos_margin = float(pos_margin_int) * 0.05
                for neg_margin_int in range(2, 7):
                    neg_margin = float(neg_margin_int) * 0.05
                    miner = PairMarginMiner(pos_margin, neg_margin, distance=distance)
                    correct_pos_pairs = []
                    correct_neg_pairs = []
                    for i, j, ap_dist in pos_pairs:
                        condition = (
                            ap_dist < pos_margin
                            if distance.is_inverted
                            else ap_dist > pos_margin
                        )
                        if condition:
                            correct_pos_pairs.append((i, j))
                    for i, j, an_dist in neg_pairs:
                        condition = (
                            an_dist > neg_margin
                            if distance.is_inverted
                            else an_dist < neg_margin
                        )
                        if condition:
                            correct_neg_pairs.append((i, j))
                    correct_pos = set(correct_pos_pairs)
                    correct_neg = set(correct_neg_pairs)
                    a1, p1, a2, n2 = miner(embeddings, labels)
                    mined_pos = {(a.item(), p.item()) for a, p in zip(a1, p1)}
                    mined_neg = {(a.item(), n.item()) for a, n in zip(a2, n2)}
                    self.assertTrue(mined_pos == correct_pos)
                    self.assertTrue(mined_neg == correct_neg)
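# A compact sketch of the PairMarginMiner semantics checked above: for a
# non-inverted distance it keeps positives farther than pos_margin and
# negatives closer than neg_margin. The margins and sizes are illustrative.
pair_miner = PairMarginMiner(pos_margin=0.2, neg_margin=0.8, distance=LpDistance())
emb = torch.randn(16, 4)
lbl = torch.randint(0, 2, size=(16,))
a1, p, a2, n = pair_miner(emb, lbl)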
def test_ntxent_loss(self):
    temperature = 0.1
    loss_funcA = NTXentLoss(temperature=temperature)
    loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
    for dtype in TEST_DTYPES:
        embedding_angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])
        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4),
            (1, 2), (1, 3), (1, 4),
            (2, 0), (2, 1), (2, 4),
            (3, 0), (3, 1), (3, 4),
            (4, 0), (4, 1), (4, 2), (4, 3),
        ]
        total_lossA, total_lossB = 0, 0
        for a1, p in pos_pairs:
            anchor, positive = embeddings[a1], embeddings[p]
            numeratorA = torch.exp(torch.matmul(anchor, positive) / temperature)
            numeratorB = torch.exp(
                -torch.sqrt(torch.sum((anchor - positive) ** 2)) / temperature
            )
            denominatorA = numeratorA.clone()
            denominatorB = numeratorB.clone()
            for a2, n in neg_pairs:
                if a2 != a1:
                    continue
                negative = embeddings[n]
                denominatorA += torch.exp(
                    torch.matmul(anchor, negative) / temperature
                )
                denominatorB += torch.exp(
                    -torch.sqrt(torch.sum((anchor - negative) ** 2)) / temperature
                )
            curr_lossA = -torch.log(numeratorA / denominatorA)
            curr_lossB = -torch.log(numeratorB / denominatorB)
            total_lossA += curr_lossA
            total_lossB += curr_lossB
        total_lossA /= len(pos_pairs)
        total_lossB /= len(pos_pairs)
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(lossA, total_lossA, rtol=rtol))
        self.assertTrue(torch.isclose(lossB, total_lossB, rtol=rtol))
def __init__(self, pos_margin, neg_margin, normalize_embeddings):
    self.pos_margin = pos_margin
    self.neg_margin = neg_margin
    self.distance = LpDistance(
        normalize_embeddings=normalize_embeddings, collect_stats=True
    )
    self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
    # We use contrastive loss with Euclidean distance
    # (LpDistance defaults to power=1, i.e. unsquared)
    reducer_fn = reducers.AvgNonZeroReducer(collect_stats=True)
    self.loss_fn = losses.ContrastiveLoss(
        pos_margin=self.pos_margin,
        neg_margin=self.neg_margin,
        distance=self.distance,
        reducer=reducer_fn,
        collect_stats=True,
    )
@classmethod
def setUpClass(self):
    self.device = torch.device("cuda")
    self.dist_miner = HDCMiner(
        filter_percentage=0.3, distance=LpDistance(normalize_embeddings=False)
    )
    self.normalized_dist_miner = HDCMiner(
        filter_percentage=0.3, distance=LpDistance(normalize_embeddings=True)
    )
    self.normalized_dist_miner_squared = HDCMiner(
        filter_percentage=0.3,
        distance=LpDistance(normalize_embeddings=True, power=2),
    )
    self.sim_miner = HDCMiner(filter_percentage=0.3, distance=CosineSimilarity())
    self.labels = torch.LongTensor([0, 0, 1, 1, 1, 0])
    correct_a1 = torch.LongTensor([0, 5, 1, 5])
    correct_p = torch.LongTensor([5, 0, 5, 1])
    self.correct_pos_pairs = torch.stack([correct_a1, correct_p], dim=1).to(
        self.device
    )
    correct_a2 = torch.LongTensor([1, 2, 4, 5, 0, 2])
    correct_n = torch.LongTensor([2, 1, 5, 4, 2, 0])
    self.correct_neg_pairs = torch.stack([correct_a2, correct_n], dim=1).to(
        self.device
    )
def test_logit_getter(self):
    embedding_size = 512
    num_classes = 10
    batch_size = 32
    for dtype in TEST_DTYPES:
        embeddings = (
            torch.randn(batch_size, embedding_size).to(TEST_DEVICE).type(dtype)
        )
        kwargs = {"num_classes": num_classes, "embedding_size": embedding_size}
        loss1 = ArcFaceLoss(**kwargs).to(TEST_DEVICE).type(dtype)
        loss2 = NormalizedSoftmaxLoss(**kwargs).to(TEST_DEVICE).type(dtype)
        loss3 = ProxyAnchorLoss(**kwargs).to(TEST_DEVICE).type(dtype)

        # test the ability to infer shape
        for loss in [loss1, loss2, loss3]:
            self.helper_tester(loss, embeddings, batch_size, num_classes)

        # test specifying the wrong layer name
        self.assertRaises(AttributeError, LogitGetter, loss1, layer_name="blah")

        # test specifying the correct layer name
        self.helper_tester(
            loss1, embeddings, batch_size, num_classes, layer_name="W"
        )

        # test specifying a distance metric
        self.helper_tester(
            loss1, embeddings, batch_size, num_classes, distance=LpDistance()
        )

        # test specifying transpose incorrectly
        LG = LogitGetter(loss1, transpose=False)
        self.assertRaises(RuntimeError, LG, embeddings)

        # test specifying transpose correctly
        self.helper_tester(
            loss1, embeddings, batch_size, num_classes, transpose=True
        )

        # test copying weights
        LG = LogitGetter(loss1)
        self.assertTrue(torch.all(LG.weights == loss1.W))
        loss1.W.data *= 0
        self.assertTrue(not torch.all(LG.weights == loss1.W))

        # test not copying weights
        LG = LogitGetter(loss1, copy_weights=False)
        self.assertTrue(torch.all(LG.weights == loss1.W))
        loss1.W.data *= 0
        self.assertTrue(torch.all(LG.weights == loss1.W))
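# A minimal sketch of the LogitGetter API exercised above: it wraps a
# classification-style loss and exposes its learned weight matrix as a logit
# layer for inference. The sizes are illustrative.
arcface = ArcFaceLoss(num_classes=10, embedding_size=512)
LG = LogitGetter(arcface)
logits = LG(torch.randn(32, 512))  # shape: (32, 10)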
def test_backward(self):
    temperature = 0.1
    loss_funcA = NTXentLoss(temperature=temperature)
    loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
    for dtype in TEST_DTYPES:
        for loss_func in [loss_funcA, loss_funcB]:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])
            loss = loss_func(embeddings, labels)
            loss.backward()
def normal_loss(c, scale):
    criterion = TripletMarginLoss(
        triplets_per_anchor=1,
        distance=LpDistance(normalize_embeddings=False, p=1),
    )
    l_rv = stats.norm(loc=c, scale=scale)
    r_rv = l_rv  # both classes are drawn from the same distribution
    l = l_rv.rvs(10).reshape((5, 2))
    r = r_rv.rvs(10).reshape((5, 2))
    df = pd.DataFrame()
    df["x"] = np.concatenate((l[:, 0], r[:, 0]))
    df["y"] = np.concatenate((l[:, 1], r[:, 1]))
    df["label"] = np.concatenate((np.zeros(5), np.ones(5)))
    embeddings = torch.as_tensor(np.concatenate((l, r)))
    labels = torch.as_tensor(df["label"])
    loss = criterion(embeddings, labels)
    print(f"center = {c}, loss = {loss}")
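# An illustrative driver for normal_loss; the centers and scale are arbitrary.
for center in [0.0, 1.0, 5.0]:
    normal_loss(center, scale=1.0)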
def get_pos_neg_vals(self, use_pairwise):
    output = (0, 1, LpDistance(power=2))
    if not use_pairwise:
        return (1, 0, CosineSimilarity())
    return output
def test_multi_similarity_miner(self):
    epsilon = 0.1
    for dtype in TEST_DTYPES:
        for distance in [CosineSimilarity(), LpDistance()]:
            miner = MultiSimilarityMiner(epsilon, distance=distance)
            embedding_angles = torch.arange(0, 64)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(64,))
            mat = distance(embeddings)
            pos_pairs = []
            neg_pairs = []
            for i in range(len(embeddings)):
                anchor_label = labels[i]
                for j in range(len(embeddings)):
                    if j != i:
                        other_label = labels[j]
                        if anchor_label == other_label:
                            pos_pairs.append((i, j, mat[i, j]))
                        if anchor_label != other_label:
                            neg_pairs.append((i, j, mat[i, j]))
            correct_a1, correct_p = [], []
            correct_a2, correct_n = [], []
            for a1, p, ap_sim in pos_pairs:
                most_difficult = (
                    c_f.neg_inf(dtype) if distance.is_inverted else c_f.pos_inf(dtype)
                )
                for a2, n, an_sim in neg_pairs:
                    if a2 == a1:
                        condition = (
                            an_sim > most_difficult
                            if distance.is_inverted
                            else an_sim < most_difficult
                        )
                        if condition:
                            most_difficult = an_sim
                condition = (
                    ap_sim < most_difficult + epsilon
                    if distance.is_inverted
                    else ap_sim > most_difficult - epsilon
                )
                if condition:
                    correct_a1.append(a1)
                    correct_p.append(p)
            for a2, n, an_sim in neg_pairs:
                most_difficult = (
                    c_f.pos_inf(dtype) if distance.is_inverted else c_f.neg_inf(dtype)
                )
                for a1, p, ap_sim in pos_pairs:
                    if a2 == a1:
                        condition = (
                            ap_sim < most_difficult
                            if distance.is_inverted
                            else ap_sim > most_difficult
                        )
                        if condition:
                            most_difficult = ap_sim
                condition = (
                    an_sim > most_difficult - epsilon
                    if distance.is_inverted
                    else an_sim < most_difficult + epsilon
                )
                if condition:
                    correct_a2.append(a2)
                    correct_n.append(n)
            correct_pos_pairs = set(zip(correct_a1, correct_p))
            correct_neg_pairs = set(zip(correct_a2, correct_n))
            a1, p1, a2, n2 = miner(embeddings, labels)
            pos_pairs = {(a.item(), p.item()) for a, p in zip(a1, p1)}
            neg_pairs = {(a.item(), n.item()) for a, n in zip(a2, n2)}
            self.assertTrue(pos_pairs == correct_pos_pairs)
            self.assertTrue(neg_pairs == correct_neg_pairs)
def test_triplet_margin_miner(self):
    for dtype in TEST_DTYPES:
        for distance in [LpDistance(), CosineSimilarity()]:
            embedding_angles = torch.arange(0, 16)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=2, size=(16,))
            mat = distance(embeddings)
            triplets = []
            for i in range(len(embeddings)):
                anchor_label = labels[i]
                for j in range(len(embeddings)):
                    if j == i:
                        continue
                    positive_label = labels[j]
                    if positive_label != anchor_label:
                        continue
                    ap_dist = mat[i, j]
                    for k in range(len(embeddings)):
                        if k == j or k == i:
                            continue
                        negative_label = labels[k]
                        if negative_label != positive_label:
                            an_dist = mat[i, k]
                            if distance.is_inverted:
                                triplets.append((i, j, k, ap_dist - an_dist))
                            else:
                                triplets.append((i, j, k, an_dist - ap_dist))
            for margin_int in range(-1, 11):
                margin = float(margin_int) * 0.05
                minerA = TripletMarginMiner(
                    margin, type_of_triplets="all", distance=distance
                )
                minerB = TripletMarginMiner(
                    margin, type_of_triplets="hard", distance=distance
                )
                minerC = TripletMarginMiner(
                    margin, type_of_triplets="semihard", distance=distance
                )
                minerD = TripletMarginMiner(
                    margin, type_of_triplets="easy", distance=distance
                )
                correctA, correctB, correctC, correctD = [], [], [], []
                for i, j, k, distance_diff in triplets:
                    if distance_diff > margin:
                        correctD.append((i, j, k))
                    else:
                        correctA.append((i, j, k))
                        if distance_diff > 0:
                            correctC.append((i, j, k))
                        if distance_diff <= 0:
                            correctB.append((i, j, k))
                for correct, miner in [
                    (correctA, minerA),
                    (correctB, minerB),
                    (correctC, minerC),
                    (correctD, minerD),
                ]:
                    correct_triplets = set(correct)
                    a1, p1, n1 = miner(embeddings, labels)
                    mined_triplets = {
                        (a.item(), p.item(), n.item()) for a, p, n in zip(a1, p1, n1)
                    }
                    self.assertTrue(mined_triplets == correct_triplets)
def test_ntxent_loss(self):
    temperature = 0.1
    loss_funcA = NTXentLoss(temperature=temperature)
    loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
    loss_funcC = NTXentLoss(
        temperature=temperature, reducer=PerAnchorReducer(AvgNonZeroReducer())
    )
    loss_funcD = SupConLoss(temperature=temperature)
    loss_funcE = SupConLoss(temperature=temperature, distance=LpDistance())
    for dtype in TEST_DTYPES:
        embedding_angles = [0, 10, 20, 50, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(TEST_DEVICE)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 0, 1, 1, 2])
        obtained_losses = [
            x(embeddings, labels)
            for x in [loss_funcA, loss_funcB, loss_funcC, loss_funcD, loss_funcE]
        ]
        pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)]
        neg_pairs = [
            (0, 3), (0, 4), (0, 5),
            (1, 3), (1, 4), (1, 5),
            (2, 3), (2, 4), (2, 5),
            (3, 0), (3, 1), (3, 2), (3, 5),
            (4, 0), (4, 1), (4, 2), (4, 5),
            (5, 0), (5, 1), (5, 2), (5, 3), (5, 4),
        ]
        total_lossA, total_lossB = 0, 0
        total_lossC, total_lossD, total_lossE = (
            torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
            torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
            torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
        )
        for a1, p in pos_pairs:
            anchor, positive = embeddings[a1], embeddings[p]
            numeratorA = torch.exp(torch.matmul(anchor, positive) / temperature)
            numeratorB = torch.exp(
                -torch.sqrt(torch.sum((anchor - positive) ** 2)) / temperature
            )
            denominatorA = numeratorA.clone()
            denominatorB = numeratorB.clone()
            denominatorD = 0
            denominatorE = 0
            for a2, n in pos_pairs + neg_pairs:
                if a2 != a1:
                    continue
                negative = embeddings[n]
                curr_denomD = torch.exp(torch.matmul(anchor, negative) / temperature)
                curr_denomE = torch.exp(
                    -torch.sqrt(torch.sum((anchor - negative) ** 2)) / temperature
                )
                denominatorD += curr_denomD
                denominatorE += curr_denomE
                if (a2, n) not in pos_pairs:
                    denominatorA += curr_denomD
                    denominatorB += curr_denomE
            curr_lossA = -torch.log(numeratorA / denominatorA)
            curr_lossB = -torch.log(numeratorB / denominatorB)
            curr_lossD = -torch.log(numeratorA / denominatorD)
            curr_lossE = -torch.log(numeratorB / denominatorE)
            total_lossA += curr_lossA
            total_lossB += curr_lossB
            total_lossC[a1] += curr_lossA
            total_lossD[a1] += curr_lossD
            total_lossE[a1] += curr_lossE
        total_lossA /= len(pos_pairs)
        total_lossB /= len(pos_pairs)
        pos_pair_per_anchor = torch.tensor(
            [2, 2, 2, 1, 1], device=TEST_DEVICE, dtype=dtype
        )
        total_lossC, total_lossD, total_lossE = [
            torch.mean(x / pos_pair_per_anchor)
            for x in [total_lossC, total_lossD, total_lossE]
        ]
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(obtained_losses[0], total_lossA, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[1], total_lossB, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[2], total_lossC, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[3], total_lossD, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[4], total_lossE, rtol=rtol))
@classmethod
def setUpClass(self):
    self.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])
    self.a1_idx, self.p_idx, self.a2_idx, self.n_idx = lmu.get_all_pairs_indices(
        self.labels
    )
    self.distance = LpDistance(normalize_embeddings=False)
    self.gt = {
        "batch_semihard_hard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.SEMIHARD,
                neg_strategy=BatchEasyHardMiner.HARD,
            ),
            "easiest_triplet": -1, "hardest_triplet": -1,
            "easiest_pos_pair": 1, "hardest_pos_pair": 2,
            "easiest_neg_pair": 3, "hardest_neg_pair": 2,
            "expected": {
                "correct_a": torch.LongTensor([0, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([1, 6, 6]).to(TEST_DEVICE),
                    torch.LongTensor([1, 8, 6]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_hard_semihard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.HARD,
                neg_strategy=BatchEasyHardMiner.SEMIHARD,
            ),
            "easiest_triplet": -1, "hardest_triplet": -1,
            "easiest_pos_pair": 3, "hardest_pos_pair": 6,
            "easiest_neg_pair": 7, "hardest_neg_pair": 4,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [torch.LongTensor([4, 4, 2, 2, 2]).to(TEST_DEVICE)],
                "correct_n": [torch.LongTensor([5, 5, 1, 1, 1]).to(TEST_DEVICE)],
            },
        },
        "batch_easy_semihard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.SEMIHARD,
            ),
            "easiest_triplet": -2, "hardest_triplet": -1,
            "easiest_pos_pair": 1, "hardest_pos_pair": 3,
            "easiest_neg_pair": 4, "hardest_neg_pair": 2,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([2, 3, 0, 1, 8, 4, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 3, 4, 1, 8, 4, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 3, 0, 5, 8, 4, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 3, 4, 5, 8, 4, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_hard_hard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.HARD,
                neg_strategy=BatchEasyHardMiner.HARD,
            ),
            "easiest_triplet": 2, "hardest_triplet": 5,
            "easiest_pos_pair": 3, "hardest_pos_pair": 6,
            "easiest_neg_pair": 3, "hardest_neg_pair": 1,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(TEST_DEVICE)
                ],
                "correct_n": [
                    torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_easy_hard": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.HARD,
            ),
            "easiest_triplet": -2, "hardest_triplet": 2,
            "easiest_pos_pair": 1, "hardest_pos_pair": 3,
            "easiest_neg_pair": 3, "hardest_neg_pair": 1,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(TEST_DEVICE),
                    torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_hard_easy": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.HARD,
                neg_strategy=BatchEasyHardMiner.EASY,
            ),
            "easiest_triplet": -4, "hardest_triplet": 3,
            "easiest_pos_pair": 3, "hardest_pos_pair": 6,
            "easiest_neg_pair": 8, "hardest_neg_pair": 3,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(TEST_DEVICE)
                ],
                "correct_n": [
                    torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0]).to(TEST_DEVICE)
                ],
            },
        },
        "batch_easy_easy": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.EASY,
            ),
            "easiest_triplet": -7, "hardest_triplet": -1,
            "easiest_pos_pair": 1, "hardest_pos_pair": 3,
            "easiest_neg_pair": 8, "hardest_neg_pair": 3,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0]).to(TEST_DEVICE)
                ],
            },
        },
        "batch_easy_easy_with_min_val": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.EASY,
                allowed_neg_range=[1, 7],
                allowed_pos_range=[1, 7],
            ),
            "easiest_triplet": -6, "hardest_triplet": -1,
            "easiest_pos_pair": 1, "hardest_pos_pair": 3,
            "easiest_neg_pair": 7, "hardest_neg_pair": 3,
            "expected": {
                "correct_a": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_n": [
                    torch.LongTensor([7, 8, 5, 0, 8, 0, 0, 1]).to(TEST_DEVICE)
                ],
            },
        },
        "batch_easy_all": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.EASY,
                neg_strategy=BatchEasyHardMiner.ALL,
            ),
            "easiest_triplet": 0, "hardest_triplet": 0,
            "easiest_pos_pair": 1, "hardest_pos_pair": 3,
            "easiest_neg_pair": 8, "hardest_neg_pair": 1,
            "expected": {
                "correct_a1": torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(TEST_DEVICE),
                "correct_p": [
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 8, 7]).to(TEST_DEVICE),
                    torch.LongTensor([1, 0, 3, 2, 1, 7, 6, 7]).to(TEST_DEVICE),
                ],
                "correct_a2": self.a2_idx,
                "correct_n": [self.n_idx],
            },
        },
        "batch_all_easy": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.ALL,
                neg_strategy=BatchEasyHardMiner.EASY,
            ),
            "easiest_triplet": 0, "hardest_triplet": 0,
            "easiest_pos_pair": 1, "hardest_pos_pair": 6,
            "easiest_neg_pair": 8, "hardest_neg_pair": 3,
            "expected": {
                "correct_a1": self.a1_idx,
                "correct_p": [self.p_idx],
                "correct_a2": torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8]).to(
                    TEST_DEVICE
                ),
                "correct_n": [
                    torch.LongTensor([8, 8, 5, 0, 8, 0, 0, 0, 0]).to(TEST_DEVICE),
                ],
            },
        },
        "batch_all_all": {
            "miner": BatchEasyHardMiner(
                distance=self.distance,
                pos_strategy=BatchEasyHardMiner.ALL,
                neg_strategy=BatchEasyHardMiner.ALL,
            ),
            "easiest_triplet": 0, "hardest_triplet": 0,
            "easiest_pos_pair": 1, "hardest_pos_pair": 6,
            "easiest_neg_pair": 8, "hardest_neg_pair": 1,
            "expected": {
                "correct_a1": self.a1_idx,
                "correct_p": [self.p_idx],
                "correct_a2": self.a2_idx,
                "correct_n": [self.n_idx],
            },
        },
    }
def test_contrastive_loss(self):
    loss_funcA = ContrastiveLoss(
        pos_margin=0.25, neg_margin=1.5, distance=LpDistance(power=2)
    )
    loss_funcB = ContrastiveLoss(
        pos_margin=1.5, neg_margin=0.6, distance=CosineSimilarity()
    )
    loss_funcC = ContrastiveLoss(
        pos_margin=0.25,
        neg_margin=1.5,
        distance=LpDistance(power=2),
        reducer=MeanReducer(),
    )
    loss_funcD = ContrastiveLoss(
        pos_margin=1.5,
        neg_margin=0.6,
        distance=CosineSimilarity(),
        reducer=MeanReducer(),
    )
    for dtype in TEST_DTYPES:
        embedding_angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])
        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        lossC = loss_funcC(embeddings, labels)
        lossD = loss_funcD(embeddings, labels)
        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4),
            (1, 2), (1, 3), (1, 4),
            (2, 0), (2, 1), (2, 4),
            (3, 0), (3, 1), (3, 4),
            (4, 0), (4, 1), (4, 2), (4, 3),
        ]
        correct_pos_losses = [0, 0, 0, 0]
        correct_neg_losses = [0, 0, 0, 0]
        num_non_zero_pos = [0, 0, 0, 0]
        num_non_zero_neg = [0, 0, 0, 0]
        for a, p in pos_pairs:
            anchor, positive = embeddings[a], embeddings[p]
            correct_lossA = torch.relu(torch.sum((anchor - positive) ** 2) - 0.25)
            correct_lossB = torch.relu(1.5 - torch.matmul(anchor, positive))
            correct_pos_losses[0] += correct_lossA
            correct_pos_losses[1] += correct_lossB
            correct_pos_losses[2] += correct_lossA
            correct_pos_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_pos[0] += 1
                num_non_zero_pos[2] += 1
            if correct_lossB > 0:
                num_non_zero_pos[1] += 1
                num_non_zero_pos[3] += 1
        for a, n in neg_pairs:
            anchor, negative = embeddings[a], embeddings[n]
            correct_lossA = torch.relu(1.5 - torch.sum((anchor - negative) ** 2))
            correct_lossB = torch.relu(torch.matmul(anchor, negative) - 0.6)
            correct_neg_losses[0] += correct_lossA
            correct_neg_losses[1] += correct_lossB
            correct_neg_losses[2] += correct_lossA
            correct_neg_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_neg[0] += 1
                num_non_zero_neg[2] += 1
            if correct_lossB > 0:
                num_non_zero_neg[1] += 1
                num_non_zero_neg[3] += 1
        for i in range(2):
            if num_non_zero_pos[i] > 0:
                correct_pos_losses[i] /= num_non_zero_pos[i]
            if num_non_zero_neg[i] > 0:
                correct_neg_losses[i] /= num_non_zero_neg[i]
        for i in range(2, 4):
            correct_pos_losses[i] /= len(pos_pairs)
            correct_neg_losses[i] /= len(neg_pairs)
        correct_losses = [
            correct_pos_losses[i] + correct_neg_losses[i] for i in range(4)
        ]
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(lossA, correct_losses[0], rtol=rtol))
        self.assertTrue(torch.isclose(lossB, correct_losses[1], rtol=rtol))
        self.assertTrue(torch.isclose(lossC, correct_losses[2], rtol=rtol))
        self.assertTrue(torch.isclose(lossD, correct_losses[3], rtol=rtol))
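# A compact sketch of the margin semantics verified above: with
# LpDistance(power=2), positive pairs are penalized once their squared
# distance exceeds pos_margin, and negative pairs once their squared distance
# drops below neg_margin. The inputs here are illustrative.
contrastive = ContrastiveLoss(
    pos_margin=0.25, neg_margin=1.5, distance=LpDistance(power=2)
)
emb = torch.nn.functional.normalize(torch.randn(6, 4), dim=1)
lbl = torch.LongTensor([0, 0, 1, 1, 2, 2])
print(contrastive(emb, lbl))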
def triplet_margin_loss_factory(triplets_per_anchor, normalize_embeddings, p):
    criterion = TripletMarginLoss(
        triplets_per_anchor=triplets_per_anchor,
        distance=LpDistance(normalize_embeddings=normalize_embeddings, p=p),
    )
    return criterion
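# A usage sketch for the factory above; the argument values are illustrative.
criterion = triplet_margin_loss_factory(
    triplets_per_anchor="all", normalize_embeddings=False, p=1
)
emb = torch.randn(10, 2)
lbl = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
loss = criterion(emb, lbl)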