def setUpClass(self):
    """Build ``self.dataset_dict`` with "train" and "val" EmbeddingDatasets.

    Each split holds 2D embeddings produced by ``c_f.angle_to_coord`` from a
    list of angles, paired with hand-assigned integer labels.
    """
    # NOTE(review): named setUpClass but takes `self`; presumably decorated
    # with @classmethod outside this view — confirm.
    train_angles = [0, 10, 20, 30, 50, 60, 70, 80]
    train_embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in train_angles]
    )
    train_labels = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])

    val_angles = [1, 11, 21, 31, 51, 59, 71, 81]
    val_embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in val_angles]
    )
    val_labels = torch.LongTensor([1, 1, 1, 1, 1, 0, 0, 0])

    self.dataset_dict = {
        "train": c_f.EmbeddingDataset(train_embeddings, train_labels),
        "val": c_f.EmbeddingDataset(val_embeddings, val_labels),
    }
def test_normalized_softmax_loss(self):
    """Compare NormalizedSoftmaxLoss with a manual cross-entropy.

    The reference logits are ``embeddings @ normalize(W, dim=0)`` divided by
    the temperature.
    """
    temperature = 0.1
    for dtype in TEST_DTYPES:
        loss_func = NormalizedSoftmaxLoss(
            temperature=temperature, num_classes=10, embedding_size=2
        )
        angles = torch.arange(0, 180)
        coords = [c_f.angle_to_coord(angle) for angle in angles]
        embeddings = torch.tensor(coords, requires_grad=True, dtype=dtype)
        embeddings = embeddings.to(TEST_DEVICE)  # 2D embeddings
        labels = torch.randint(low=0, high=10, size=(180,)).to(TEST_DEVICE)

        loss = loss_func(embeddings, labels)
        loss.backward()

        # Reference: L2-normalize each column of W, then scale by 1/temperature.
        normalized_weights = torch.nn.functional.normalize(loss_func.W, p=2, dim=0)
        logits = torch.matmul(embeddings, normalized_weights)
        expected = torch.nn.functional.cross_entropy(logits / temperature, labels)

        if dtype == torch.float16:
            rtol = 1e-2
        else:
            rtol = 1e-5
        self.assertTrue(torch.isclose(loss, expected, rtol=rtol))
def test_sim_mining(self):
    """Check the similarity miner's triplets and its hardest pos/neg stats.

    The expected hardest values are the cosines of 120 and 20 degrees.
    """
    angles = [0, 20, 40, 60, 80, 100, 120, 140, 160]
    embeddings = torch.FloatTensor([c_f.angle_to_coord(angle) for angle in angles])
    anchors, positives, negatives = self.sim_miner(embeddings, self.labels)
    self.helper(anchors, positives, negatives)

    expected_hardest_pos = np.cos(np.radians(120))
    expected_hardest_neg = np.cos(np.radians(20))
    self.assertAlmostEqual(
        self.sim_miner.hardest_pos_pair_dist, expected_hardest_pos, places=5
    )
    self.assertAlmostEqual(
        self.sim_miner.hardest_neg_pair_dist, expected_hardest_neg, places=5
    )
def test_distance_weighted_miner(self):
    """Check DistanceWeightedMiner across a sweep of nonzero-loss cutoffs.

    For each cutoff, the mined anchor-negative distances must not exceed the
    cutoff. Their variance and mean must also be within 10% of the targets
    derived from the interval between the minimum anchor-negative distance
    and the cutoff.
    """
    angles = torch.arange(0, 180)
    embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in angles],
        requires_grad=True,
        dtype=torch.float,
    )  # 2D embeddings
    labels = torch.randint(low=0, high=2, size=(180,))

    all_a, _, all_n = lmu.get_all_triplets_indices(labels)
    every_an_dist = torch.nn.functional.pairwise_distance(
        embeddings[all_a], embeddings[all_n], 2
    )
    smallest_an_dist = torch.min(every_an_dist)

    for cutoff_tenths in range(5, 15):
        non_zero_cutoff = (float(cutoff_tenths) / 10.0) - 0.01
        miner = DistanceWeightedMiner(0, non_zero_cutoff)
        a, _, n = miner(embeddings, labels)
        an_dist = torch.nn.functional.pairwise_distance(
            embeddings[a], embeddings[n], 2
        )
        self.assertTrue(torch.max(an_dist) <= non_zero_cutoff)

        width = non_zero_cutoff - smallest_an_dist
        # width**2 / 12 is the variance of a uniform distribution of that width;
        # the mean target is half the width.
        target_var = (width ** 2) / 12
        target_mean = width / 2
        observed_var = torch.var(an_dist)
        observed_mean = torch.mean(an_dist)
        self.assertTrue(torch.abs(observed_var - target_var) / target_var < 0.1)
        self.assertTrue(torch.abs(observed_mean - target_mean) / target_mean < 0.1)
def test_backward(self):
    """Smoke-test backprop through LargeMarginSoftmaxLoss and SphereFaceLoss.

    Both losses are run on scaled 2D embeddings for every dtype in
    TEST_DTYPES; the test only checks that forward and backward complete.
    """
    margin, scale = 10, 2
    for dtype in TEST_DTYPES:
        candidate_losses = [
            LargeMarginSoftmaxLoss(
                margin=margin, scale=scale, num_classes=10, embedding_size=2
            ),
            SphereFaceLoss(
                margin=margin, scale=scale, num_classes=10, embedding_size=2
            ),
        ]
        for loss_func in candidate_losses:
            angles = torch.arange(0, 180)
            # multiply by 10 to make the embeddings unnormalized
            coords = np.array([c_f.angle_to_coord(angle) for angle in angles]) * 10
            embeddings = torch.tensor(
                coords, requires_grad=True, dtype=dtype
            ).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(180,)).to(self.device)

            loss = loss_func(embeddings, labels)
            loss.backward()
def test_triplet_margin_loss(self):
    """Compare TripletMarginLoss with a hand-enumerated set of triplets.

    The default reducer is checked against the average over non-zero
    triplets, and MeanReducer against the average over all triplets.
    """
    margin = 0.2
    loss_func_default = TripletMarginLoss(margin=margin)
    loss_func_mean = TripletMarginLoss(margin=margin, reducer=MeanReducer())
    angles = [0, 20, 40, 60, 80]
    embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in angles],
        requires_grad=True,
        dtype=torch.float,
    )  # 2D embeddings
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    loss_default = loss_func_default(embeddings, labels)
    loss_mean = loss_func_mean(embeddings, labels)
    loss_default.backward()
    loss_mean.backward()

    triplets = [
        (0, 1, 2), (0, 1, 3), (0, 1, 4),
        (1, 0, 2), (1, 0, 3), (1, 0, 4),
        (2, 3, 0), (2, 3, 1), (2, 3, 4),
        (3, 2, 0), (3, 2, 1), (3, 2, 4),
    ]

    running_total = 0
    nonzero_count = 0
    for a, p, n in triplets:
        anchor, positive, negative = embeddings[a], embeddings[p], embeddings[n]
        ap_dist = torch.sqrt(torch.sum((anchor - positive) ** 2))
        an_dist = torch.sqrt(torch.sum((anchor - negative) ** 2))
        triplet_loss = torch.relu(ap_dist - an_dist + margin)
        if triplet_loss > 0:
            nonzero_count += 1
        running_total += triplet_loss

    self.assertTrue(torch.isclose(loss_default, running_total / nonzero_count))
    self.assertTrue(torch.isclose(loss_mean, running_total / len(triplets)))
def test_with_distance_weighted_miner(self):
    """Smoke-test CrossBatchMemory over 20 consecutive batches per dtype.

    The memory wraps NTXentLoss together with a DistanceWeightedMiner; the
    test only checks that no exception is raised.
    """
    memory_size = 256
    for dtype in TEST_DTYPES:
        loss_with_miner = CrossBatchMemory(
            loss=NTXentLoss(temperature=0.1),
            embedding_size=2,
            memory_size=memory_size,
            miner=DistanceWeightedMiner(cutoff=0.5, nonzero_loss_cutoff=1.4),
        )
        for _ in range(20):
            angles = torch.arange(0, 32)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(angle) for angle in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(32,)).to(TEST_DEVICE)
            loss_val = loss_with_miner(embeddings, labels)
            loss_val.backward()
    self.assertTrue(True)  # just check if we got here without an exception
def test_with_no_valid_pairs(self):
    """NTXentLoss and SupConLoss must return 0 for degenerate batches.

    Each loss is tried with its default distance and with LpDistance, on
    three batches:
      * a single sample (no pairs at all);
      * all samples sharing one label (no negatives);
      * all labels distinct (no positives).
    """
    cases = [
        ([0], torch.LongTensor([0])),
        ([0, 10, 20], torch.LongTensor([0, 0, 0])),
        ([0, 40, 60], torch.LongTensor([1, 2, 3])),
    ]
    temperature = 0.1
    for loss_class in [NTXentLoss, SupConLoss]:
        variants = [
            loss_class(temperature),
            loss_class(temperature, distance=LpDistance()),
        ]
        for loss_func in variants:
            for dtype in TEST_DTYPES:
                for angles, labels in cases:
                    embeddings = torch.tensor(
                        [c_f.angle_to_coord(angle) for angle in angles],
                        requires_grad=True,
                        dtype=dtype,
                    ).to(TEST_DEVICE)  # 2D embeddings
                    loss = loss_func(embeddings, labels)
                    loss.backward()
                    self.assertEqual(loss, 0)
def test_soft_triple_loss(self):
    """Compare SoftTripleLoss with OriginalImplementationSoftTriple.

    Both share the same ``fc`` and are checked for 1 through 11 centers per
    class.
    """
    embedding_size, num_classes = 2, 11
    la, gamma, reg_weight, margin = 20, 0.1, 0.2, 0.01
    for centers_per_class in range(1, 12):
        loss_func = SoftTripleLoss(
            embedding_size,
            num_classes,
            centers_per_class=centers_per_class,
            la=la,
            gamma=gamma,
            reg_weight=reg_weight,
            margin=margin,
        )
        reference = OriginalImplementationSoftTriple(
            la, gamma, reg_weight, margin, embedding_size, num_classes, centers_per_class
        )
        # Reuse the reference's fc so both implementations see identical weights.
        loss_func.fc = reference.fc

        angles = torch.arange(0, 180)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=torch.float,
        )  # 2D embeddings
        labels = torch.randint(low=0, high=10, size=(180,))

        loss = loss_func(embeddings, labels)
        loss.backward()
        expected = reference(embeddings, labels)
        self.assertTrue(torch.isclose(loss, expected))
def test_input_indices_tuple(self):
    """Check MultipleLosses when given a mined ``indices_tuple``.

    Both the dict-configured and list-configured MultipleLosses must equal
    the weighted sum of their component losses (weights 1 and 0.23).
    """
    contrastive = ContrastiveLoss()
    triplet = TripletMarginLoss(0.1)
    miner = MultiSimilarityMiner()
    by_name = MultipleLosses(
        losses={"lossA": contrastive, "lossB": triplet},
        weights={"lossA": 1, "lossB": 0.23},
    )
    by_position = MultipleLosses(losses=[contrastive, triplet], weights=[1, 0.23])

    for combined in (by_name, by_position):
        for dtype in TEST_DTYPES:
            angles = torch.arange(0, 180)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(angle) for angle in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(180,))
            indices_tuple = miner(embeddings, labels)

            loss = combined(embeddings, labels, indices_tuple)
            loss.backward()

            expected = (
                contrastive(embeddings, labels, indices_tuple)
                + triplet(embeddings, labels, indices_tuple) * 0.23
            )
            self.assertTrue(torch.isclose(loss, expected))
def test_normalized_dist_squared_mining(self):
    """Check the squared-distance miner's triplets and its hardest pair stats.

    The hardest positive and negative values are compared with manually
    computed squared distances (between indices 2/8 and 1/2).
    """
    for dtype in TEST_DTYPES:
        angles = [0, 20, 40, 60, 80, 100, 120, 140, 160]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles], dtype=dtype
        ).to(TEST_DEVICE)
        miner = self.normalized_dist_miner_squared
        a, p, n = miner(embeddings, self.labels)
        self.helper(a, p, n)

        def squared_dist(i, j):
            return torch.sum((embeddings[i] - embeddings[j]) ** 2).item()

        expected_pos = squared_dist(2, 8)
        expected_neg = squared_dist(1, 2)
        places = 2 if dtype == torch.float16 else 5
        self.assertAlmostEqual(miner.hardest_pos_pair_dist, expected_pos, places=places)
        self.assertAlmostEqual(miner.hardest_neg_pair_dist, expected_neg, places=places)
def test_nca_loss(self):
    """Compare NCALoss with a manual per-positive-pair softmax.

    Each term is ``exp(-softmax_scale * squared_distance)``.
    """
    softmax_scale = 10
    loss_func = NCALoss(softmax_scale=softmax_scale)
    angles = [0, 20, 40, 60, 80]
    embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in angles],
        requires_grad=True,
        dtype=torch.float,
    )  # 2D embeddings
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    loss = loss_func(embeddings, labels)
    loss.backward()

    pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    neg_pairs = [
        (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
        (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
        (4, 0), (4, 1), (4, 2), (4, 3),
    ]

    def sq_dist(i, j):
        return torch.sum((embeddings[i] - embeddings[j]) ** 2)

    expected = 0
    for anchor_idx, pos_idx in pos_pairs:
        numerator = torch.exp(-sq_dist(anchor_idx, pos_idx) * softmax_scale)
        denominator = numerator.clone()
        for neg_anchor_idx, neg_idx in neg_pairs:
            if neg_anchor_idx != anchor_idx:
                continue
            denominator += torch.exp(-sq_dist(neg_anchor_idx, neg_idx) * softmax_scale)
        expected += -torch.log(numerator / denominator)
    expected /= len(pos_pairs)
    self.assertTrue(torch.isclose(loss, expected))
def test_intra_pair_variance_loss(self):
    """Compare IntraPairVarianceLoss with a manual computation.

    Squared hinge penalties are applied to positive similarities below
    ``(1 - pos_eps) * mean_pos`` and to negative similarities above
    ``(1 + neg_eps) * mean_neg``; each penalty is averaged over its pair
    count.
    """
    pos_eps, neg_eps = 0.01, 0.02
    loss_func = IntraPairVarianceLoss(pos_eps, neg_eps)
    for dtype in TEST_DTYPES:
        angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        loss = loss_func(embeddings, labels)
        loss.backward()

        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
            (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
            (4, 0), (4, 1), (4, 2), (4, 3),
        ]

        def sim(i, j):
            return torch.matmul(embeddings[i], embeddings[j])

        mean_pos = sum(sim(a, p) for a, p in pos_pairs) / len(pos_pairs)
        mean_neg = sum(sim(a, n) for a, n in neg_pairs) / len(neg_pairs)

        pos_total = sum(
            torch.relu((1 - pos_eps) * mean_pos - sim(a, p)) ** 2
            for a, p in pos_pairs
        ) / len(pos_pairs)
        neg_total = sum(
            torch.relu(sim(a, n) - (1 + neg_eps) * mean_neg) ** 2
            for a, n in neg_pairs
        ) / len(neg_pairs)
        correct_total = pos_total + neg_total

        if dtype == torch.float16:
            rtol = 1e-2
        else:
            rtol = 1e-5
        self.assertTrue(torch.isclose(loss, correct_total, rtol=rtol))
def test_multiple_losses(self):
    """Check that MultipleLosses equals ``ContrastiveLoss + 0.23 * TripletMarginLoss``."""
    contrastive = ContrastiveLoss()
    triplet = TripletMarginLoss(0.1)
    loss_func = MultipleLosses(
        losses={"lossA": contrastive, "lossB": triplet},
        weights={"lossA": 1, "lossB": 0.23},
    )
    for dtype in TEST_DTYPES:
        angles = torch.arange(0, 180)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.randint(low=0, high=10, size=(180,))

        loss = loss_func(embeddings, labels)
        loss.backward()

        expected = contrastive(embeddings, labels) + triplet(embeddings, labels) * 0.23
        self.assertTrue(torch.isclose(loss, expected))
def test_distance_weighted_miner(self, with_ref_labels=False):
    """Check DistanceWeightedMiner, optionally against a reference set.

    When ``with_ref_labels`` is true, positives and negatives are taken from
    a cloned reference embedding set with its own random labels. For each
    cutoff, the mined anchor-negative distances must respect the cutoff and
    have variance and mean within 10% of the targets.
    """
    for dtype in TEST_DTYPES:
        angles = torch.arange(0, 256)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        ref_embeddings = embeddings.clone() if with_ref_labels else None
        labels = torch.randint(low=0, high=2, size=(256,))
        ref_labels = torch.randint(low=0, high=2, size=(256,)) if with_ref_labels else None
        # Negatives come from the reference set when one is supplied.
        candidates = ref_embeddings if with_ref_labels else embeddings

        a, _, n = lmu.get_all_triplets_indices(labels, ref_labels)
        all_an_dist = torch.nn.functional.pairwise_distance(embeddings[a], candidates[n], 2)
        min_an_dist = torch.min(all_an_dist)

        for cutoff_tenths in range(5, 15):
            non_zero_cutoff = (float(cutoff_tenths) / 10.0) - 0.01
            miner = DistanceWeightedMiner(0, non_zero_cutoff)
            a, _, n = miner(embeddings, labels, ref_embeddings, ref_labels)
            an_dist = torch.nn.functional.pairwise_distance(embeddings[a], candidates[n], 2)
            self.assertTrue(torch.max(an_dist) <= non_zero_cutoff)

            width = non_zero_cutoff - min_an_dist
            target_var = (width ** 2) / 12  # variance formula for uniform distribution
            target_mean = width / 2
            var_err = torch.abs(torch.var(an_dist) - target_var) / target_var
            mean_err = torch.abs(torch.mean(an_dist) - target_mean) / target_mean
            self.assertTrue(var_err < 0.1)
            self.assertTrue(mean_err < 0.1)
def test_sim_mining(self):
    """Check that the pair miner's (anchor, positive) and (anchor, negative) pairs validate.

    The pairs are stacked into (k, 2) index tensors and passed to
    ``self.helper``.
    """
    angles = [0, 20, 40, 60, 80, 100]
    embeddings = torch.FloatTensor([c_f.angle_to_coord(angle) for angle in angles])
    pos_anchors, positives, neg_anchors, negatives = self.sim_miner(embeddings, self.labels)
    self.helper(
        torch.stack([pos_anchors, positives], dim=1),
        torch.stack([neg_anchors, negatives], dim=1),
    )
def test_ntxent_loss(self):
    """Compare NTXentLoss with a manual per-positive-pair softmax.

    The softmax is taken over dot products divided by the temperature.
    """
    temperature = 0.1
    loss_func = NTXentLoss(temperature=temperature)
    angles = [0, 20, 40, 60, 80]
    embeddings = torch.FloatTensor(
        [c_f.angle_to_coord(angle) for angle in angles]
    )  # 2D embeddings
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    loss = loss_func(embeddings, labels)

    pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    neg_pairs = [
        (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
        (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
        (4, 0), (4, 1), (4, 2), (4, 3),
    ]

    expected = 0
    for anchor_idx, pos_idx in pos_pairs:
        anchor = embeddings[anchor_idx]
        numerator = torch.exp(torch.matmul(anchor, embeddings[pos_idx]) / temperature)
        denominator = numerator.clone()
        own_negatives = (embeddings[n] for a2, n in neg_pairs if a2 == anchor_idx)
        for negative in own_negatives:
            denominator += torch.exp(torch.matmul(anchor, negative) / temperature)
        expected += -torch.log(numerator / denominator)
    expected /= len(pos_pairs)
    self.assertTrue(torch.isclose(loss, expected))
def test_with_no_valid_pairs(self):
    """NTXentLoss must return 0 for a single-sample batch, which has no pairs."""
    loss_func = NTXentLoss(temperature=0.1)
    embeddings = torch.FloatTensor(
        [c_f.angle_to_coord(angle) for angle in [0]]
    )  # 2D embeddings
    labels = torch.LongTensor([0])
    self.assertEqual(loss_func(embeddings, labels), 0)
def test_cosface_loss(self):
    """Compare CosFaceLoss with a manual cross-entropy.

    ``margin`` is subtracted from each sample's target-class logit, and the
    logits are then multiplied by ``scale``.
    """
    margin = 0.5
    scale = 64
    for dtype in TEST_DTYPES:
        loss_func = CosFaceLoss(
            margin=margin, scale=scale, num_classes=10, embedding_size=2
        )
        angles = torch.arange(0, 180)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(TEST_DEVICE)  # 2D embeddings
        labels = torch.randint(low=0, high=10, size=(180,))

        loss = loss_func(embeddings, labels)
        loss.backward()

        weights = torch.nn.functional.normalize(loss_func.W, p=2, dim=0)
        logits = torch.matmul(embeddings, weights)
        for row, target in enumerate(labels):
            logits[row, target] -= margin
        expected = torch.nn.functional.cross_entropy(
            logits * scale, labels.to(TEST_DEVICE)
        )

        if dtype == torch.float16:
            rtol = 1e-2
        else:
            rtol = 1e-5
        self.assertTrue(torch.isclose(loss, expected, rtol=rtol))
def test_proxyanchor_loss(self):
    """Compare ProxyAnchorLoss with OriginalImplementationProxyAnchor.

    Both use the same proxies (cast to the tested dtype).
    """
    num_classes = 10
    embedding_size = 2
    margin = 0.5
    for dtype in TEST_DTYPES:
        alpha = 1 if dtype == torch.float16 else 32
        loss_func = ProxyAnchorLoss(
            num_classes, embedding_size, margin=margin, alpha=alpha
        ).to(TEST_DEVICE)
        reference = OriginalImplementationProxyAnchor(
            num_classes, embedding_size, mrg=margin, alpha=alpha
        ).to(TEST_DEVICE)
        # Cast the reference proxies to dtype, then share them with the tested loss.
        reference.proxies.data = reference.proxies.data.type(dtype)
        loss_func.proxies = reference.proxies

        angles = list(range(0, 180))
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(TEST_DEVICE)  # 2D embeddings
        labels = torch.randint(low=0, high=5, size=(180,)).to(TEST_DEVICE)

        loss = loss_func(embeddings, labels)
        loss.backward()
        expected = reference(embeddings, labels)

        if dtype == torch.float16:
            rtol = 1e-2
        else:
            rtol = 1e-5
        self.assertTrue(torch.isclose(loss, expected, rtol=rtol))
def test_npairs_loss(self):
    """Compare NPairsLoss, with and without LpRegularizer, against a manual softmax.

    The expected pairs are one anchor/positive pair for each of classes 0
    and 1. Each anchor's negative is the other pair's positive. The
    regularized loss is expected to be the plain loss plus 1.
    """
    loss_funcA = NPairsLoss()
    loss_funcB = NPairsLoss(embedding_regularizer=LpRegularizer())
    for dtype in TEST_DTYPES:
        embedding_angles = list(range(0, 180, 20))[:7]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 1, 2, 3])

        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)

        pos_pairs = [(0, 1), (2, 3)]
        neg_pairs = [(0, 3), (2, 1)]

        total_loss = 0
        for a1, p in pos_pairs:
            anchor, positive = embeddings[a1], embeddings[p]
            numerator = torch.exp(torch.matmul(anchor, positive))
            denominator = numerator.clone()
            for a2, n in neg_pairs:
                if a2 == a1:
                    negative = embeddings[n]
                    denominator += torch.exp(torch.matmul(anchor, negative))
            total_loss += -torch.log(numerator / denominator)
        # Average over the number of positive pairs (not the length of one
        # pair tuple, which only coincidentally equals it here).
        total_loss /= len(pos_pairs)

        # float16 cannot meet the default rtol=1e-5; loosen it for half precision.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(lossA, total_loss, rtol=rtol))
        # l2_reg is going to be 1 since the embeddings are normalized
        self.assertTrue(torch.isclose(lossB, total_loss + 1, rtol=rtol))
def test_proxy_nca_loss(self):
    """Compare ProxyNCALoss with a manual computation.

    For each sample, the reference is a softmax over
    ``exp(-softmax_scale * squared_distance)`` to every L2-normalized proxy.
    """
    for dtype in TEST_DTYPES:
        softmax_scale = 1 if dtype == torch.float16 else 10
        loss_func = ProxyNCALoss(
            softmax_scale=softmax_scale, num_classes=10, embedding_size=2
        )
        angles = torch.arange(0, 180)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(TEST_DEVICE)  # 2D embeddings
        labels = torch.randint(low=0, high=10, size=(180,))

        loss = loss_func(embeddings, labels)
        loss.backward()

        proxies = torch.nn.functional.normalize(loss_func.proxies, p=2, dim=1)
        expected = 0
        for emb, label in zip(embeddings, labels):
            sq_dists_to_all = torch.sum((emb - proxies) ** 2, dim=1)
            denominator = torch.sum(torch.exp(-sq_dists_to_all * softmax_scale))
            sq_dist_to_own = torch.sum((emb - proxies[label]) ** 2)
            numerator = torch.exp(-sq_dist_to_own * softmax_scale)
            expected += -torch.log(numerator / denominator)
        expected /= len(embeddings)

        if dtype == torch.float16:
            rtol = 1e-2
        else:
            rtol = 1e-5
        self.assertTrue(torch.isclose(loss, expected, rtol=rtol))
def test_angular_miner(self):
    """Check that AngularMiner returns exactly the triplets above each angle threshold.

    A triplet's angle is ``atan(|a-p| / (2 * |c-n|))``, where ``c`` is the
    midpoint of the anchor and positive.
    """
    for dtype in TEST_DTYPES:
        angles = torch.arange(0, 16)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.randint(low=0, high=2, size=(16,))

        # Enumerate every (anchor, positive, negative) triplet with its angle.
        triplets = []
        for i, (anchor, anchor_label) in enumerate(zip(embeddings, labels)):
            for j, (positive, positive_label) in enumerate(zip(embeddings, labels)):
                if j == i or positive_label != anchor_label:
                    continue
                center = (anchor + positive) / 2
                ap_dist = torch.nn.functional.pairwise_distance(
                    anchor.unsqueeze(0), positive.unsqueeze(0), 2
                )
                for k, (negative, negative_label) in enumerate(zip(embeddings, labels)):
                    if k in (i, j) or negative_label == positive_label:
                        continue
                    nc_dist = torch.nn.functional.pairwise_distance(
                        center.unsqueeze(0), negative.unsqueeze(0), 2
                    )
                    triplets.append((i, j, k, torch.atan(ap_dist / (2 * nc_dist))))

        for angle_in_degrees in range(0, 70, 10):
            miner = AngularMiner(angle_in_degrees)
            threshold = np.radians(angle_in_degrees)
            correct_triplets = {
                (i, j, k) for i, j, k, angle in triplets if angle > threshold
            }
            a1, p1, n1 = miner(embeddings, labels)
            mined_triplets = {
                (a.item(), p.item(), n.item()) for a, p, n in zip(a1, p1, n1)
            }
            self.assertTrue(mined_triplets == correct_triplets)
def test_lifted_structure_loss(self):
    """Compare LiftedStructureLoss with a manual computation.

    For each positive pair, the negative term sums ``exp(neg_margin - dist)``
    over the negatives of both the anchor and the positive. The squared
    hinge is averaged over ``2 * len(pos_pairs)``.
    """
    neg_margin = 0.5
    loss_func = LiftedStructureLoss(neg_margin=neg_margin)
    angles = [0, 20, 40, 60, 80]
    embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in angles],
        requires_grad=True,
        dtype=torch.float,
    )  # 2D embeddings
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    loss = loss_func(embeddings, labels)
    loss.backward()

    pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    neg_pairs = [
        (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
        (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
        (4, 0), (4, 1), (4, 2), (4, 3),
    ]

    def dist(x, y):
        return torch.sqrt(torch.sum((x - y) ** 2))

    expected = 0
    for a1, p in pos_pairs:
        anchor, positive = embeddings[a1], embeddings[p]
        pos_term = dist(anchor, positive)
        neg_term = 0
        for a2, n in neg_pairs:
            negative = embeddings[n]
            if a2 == a1:
                neg_term += torch.exp(neg_margin - dist(anchor, negative))
            elif a2 == p:
                neg_term += torch.exp(neg_margin - dist(positive, negative))
        expected += torch.relu(torch.log(neg_term) + pos_term) ** 2
    expected /= 2 * len(pos_pairs)
    self.assertTrue(torch.isclose(loss, expected))
def test_angular_loss(self):
    """Compare AngularLoss with a manual computation over all valid triplets.

    Runs for float16, float32 and float64 embeddings. For each anchor, the
    reference is ``log(1 + sum(exp(exponent)))`` over that anchor's
    triplets, averaged over the anchors.
    """
    alpha = 40  # degrees; shared by the loss under test and the reference
    loss_func = AngularLoss(alpha=alpha)
    for dtype in [torch.float16, torch.float32, torch.float64]:
        embedding_angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        loss = loss_func(embeddings, labels)
        loss.backward()

        sq_tan_alpha = (
            torch.tan(torch.tensor(np.radians(alpha), dtype=dtype).to(self.device)) ** 2
        )
        triplets = [
            (0, 1, 2), (0, 1, 3), (0, 1, 4),
            (1, 0, 2), (1, 0, 3), (1, 0, 4),
            (2, 3, 0), (2, 3, 1), (2, 3, 4),
            (3, 2, 0), (3, 2, 1), (3, 2, 4),
        ]
        # One accumulator per anchor index; only anchors 0-3 appear above.
        correct_losses = [0, 0, 0, 0]
        for a, p, n in triplets:
            anchor, positive, negative = embeddings[a], embeddings[p], embeddings[n]
            exponent = 4 * sq_tan_alpha * torch.matmul(
                anchor + positive, negative
            ) - 2 * (1 + sq_tan_alpha) * torch.matmul(anchor, positive)
            correct_losses[a] += torch.exp(exponent)

        total_loss = 0
        for c in correct_losses:
            total_loss += torch.log(1 + c)
        total_loss /= len(correct_losses)

        # float16 has only ~3 significant decimal digits, so isclose's default
        # rtol=1e-5 would demand near bit-exact agreement between two
        # differently ordered computations; loosen it for half precision.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(loss, total_loss, rtol=rtol))
def test_arcface_loss(self):
    """Compare ArcFaceLoss with a manual cross-entropy.

    ``margin`` (in degrees) is added to each target-class angle before
    taking the cosine, and the logits are then multiplied by ``scale``.
    """
    margin = 30
    scale = 64
    loss_func = ArcFaceLoss(margin=margin, scale=scale, num_classes=10, embedding_size=2)
    angles = torch.arange(0, 180)
    embeddings = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in angles],
        requires_grad=True,
        dtype=torch.float,
    )  # 2D embeddings
    labels = torch.randint(low=0, high=10, size=(180,))

    loss = loss_func(embeddings, labels)
    loss.backward()

    weights = torch.nn.functional.normalize(loss_func.W, p=2, dim=0)
    logits = torch.matmul(embeddings, weights)
    margin_radians = torch.tensor(np.radians(margin))
    for row, target in enumerate(labels):
        logits[row, target] = torch.cos(torch.acos(logits[row, target]) + margin_radians)
    expected = torch.nn.functional.cross_entropy(logits * scale, labels)
    self.assertTrue(torch.isclose(loss, expected))
def test_tuplet_margin_loss(self):
    """Compare TupletMarginLoss with a manual computation.

    For each positive pair, the reference is
    ``log(1 + sum(exp(scale * (an_cos - ap_cos))))``, where ``ap_cos`` is the
    cosine of the anchor-positive angle reduced by ``margin`` degrees.
    """
    margin, scale = 5, 64
    loss_func = TupletMarginLoss(margin=margin, scale=scale)
    for dtype in TEST_DTYPES:
        angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        loss = loss_func(embeddings, labels)
        loss.backward()

        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [
            (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
            (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
            (4, 0), (4, 1), (4, 2), (4, 3),
        ]

        expected = 0
        for a1, p in pos_pairs:
            anchor, positive = embeddings[a1], embeddings[p]
            # embeddings are normalized, so dot product == cosine
            ap_angle = torch.acos(torch.matmul(anchor, positive))
            ap_cos = torch.cos(ap_angle - np.radians(margin))
            exp_sum = 0
            for a2, n in neg_pairs:
                if a2 != a1:
                    continue
                an_cos = torch.matmul(embeddings[a2], embeddings[n])
                exp_sum += torch.exp(scale * (an_cos - ap_cos))
            expected += torch.log(1 + exp_sum)
        expected /= len(pos_pairs)

        if dtype == torch.float16:
            rtol = 1e-2
        else:
            rtol = 1e-5
        self.assertTrue(torch.isclose(loss, expected, rtol=rtol))
def test_with_no_valid_triplets(self):
    """AngularLoss must return 0 when every label is unique, so no positives exist."""
    loss_func = AngularLoss(alpha=40)
    angles = [0, 20, 40, 60, 80]
    embeddings = torch.FloatTensor(
        [c_f.angle_to_coord(angle) for angle in angles]
    )  # 2D embeddings
    labels = torch.LongTensor([0, 1, 2, 3, 4])
    self.assertEqual(loss_func(embeddings, labels), 0)
def test_with_no_valid_triplets(self):
    """MarginLoss must return 0 when every label is unique, so no positives exist."""
    margin, nu, beta = 0.1, 0, 1
    loss_func = MarginLoss(margin=margin, nu=nu, beta=beta)
    angles = [0, 20, 40, 60, 80]
    embeddings = torch.FloatTensor(
        [c_f.angle_to_coord(angle) for angle in angles]
    )  # 2D embeddings
    labels = torch.LongTensor([0, 1, 2, 3, 4])
    result = loss_func(embeddings, labels)
    self.assertEqual(result, 0)
def test_with_no_valid_triplets(self):
    """TripletMarginLoss must return 0 when every label is unique.

    This must hold whether or not ``avg_non_zero_only`` is set.
    """
    averaged_nonzero = TripletMarginLoss(margin=0.2, avg_non_zero_only=True)
    averaged_all = TripletMarginLoss(margin=0.2, avg_non_zero_only=False)
    angles = [0, 20, 40, 60, 80]
    embeddings = torch.FloatTensor(
        [c_f.angle_to_coord(angle) for angle in angles]
    )  # 2D embeddings
    labels = torch.LongTensor([0, 1, 2, 3, 4])
    for loss_func in (averaged_nonzero, averaged_all):
        self.assertEqual(loss_func(embeddings, labels), 0)