def test_mean_reducer(self):
    """MeanReducer must return the plain mean of the losses for every reduction type.

    The same random ``losses`` tensor is paired with element, pos-pair,
    neg-pair and triplet indices; in each case the reducer output is expected
    to equal ``torch.mean(losses)`` exactly.
    """
    reducer = MeanReducer()
    batch_size, embedding_size = 100, 64
    for dtype in TEST_DTYPES:
        embeddings = torch.randn(batch_size, embedding_size).type(dtype).to(self.device)
        labels = torch.randint(0, 10, (batch_size,))
        first_idx = torch.randint(0, batch_size, (batch_size,))
        second_idx = torch.randint(0, batch_size, (batch_size,))
        pair_indices = (first_idx, second_idx)
        third_idx = torch.randint(0, batch_size, (batch_size,))
        triplet_indices = (first_idx, second_idx, third_idx)
        losses = torch.randn(batch_size).type(dtype).to(self.device)
        cases = (
            (torch.arange(batch_size), "element"),
            (pair_indices, "pos_pair"),
            (pair_indices, "neg_pair"),
            (triplet_indices, "triplet"),
        )
        for indices, reduction_type in cases:
            payload = {
                "losses": losses,
                "indices": indices,
                "reduction_type": reduction_type,
            }
            output = reducer({"loss": payload}, embeddings, labels)
            expected = torch.mean(losses)
            self.assertTrue(output == expected)
def test_triplet_margin_loss(self):
    """Compare TripletMarginLoss with a hand-computed Euclidean triplet loss.

    ``loss_funcA`` (no reducer given) is checked against the average over
    triplets with a positive loss; ``loss_funcB`` (MeanReducer) is checked
    against the average over all listed triplets. Both losses are also
    back-propagated.
    """
    margin = 0.2
    loss_funcA = TripletMarginLoss(margin=margin)
    loss_funcB = TripletMarginLoss(margin=margin, reducer=MeanReducer())
    coords = [c_f.angle_to_coord(angle) for angle in [0, 20, 40, 60, 80]]
    embeddings = torch.tensor(coords, requires_grad=True, dtype=torch.float)  # 2D embeddings
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    lossA = loss_funcA(embeddings, labels)
    lossB = loss_funcB(embeddings, labels)
    lossA.backward()
    lossB.backward()

    # Every (anchor, positive, negative) index triple permitted by `labels`.
    triplets = [
        (0, 1, 2), (0, 1, 3), (0, 1, 4),
        (1, 0, 2), (1, 0, 3), (1, 0, 4),
        (2, 3, 0), (2, 3, 1), (2, 3, 4),
        (3, 2, 0), (3, 2, 1), (3, 2, 4),
    ]
    total = 0
    nonzero_count = 0
    for a_idx, p_idx, n_idx in triplets:
        anchor = embeddings[a_idx]
        positive = embeddings[p_idx]
        negative = embeddings[n_idx]
        term = torch.relu(
            torch.sqrt(torch.sum((anchor - positive) ** 2))
            - torch.sqrt(torch.sum((anchor - negative) ** 2))
            + margin
        )
        nonzero_count += int(term > 0)
        total += term

    self.assertTrue(torch.isclose(lossA, total / nonzero_count))
    self.assertTrue(torch.isclose(lossB, total / len(triplets)))
def test_setting_reducers(self):
    """A reducer passed to a loss constructor must be the type the loss ends up using.

    For TripletMarginLoss the test checks ``loss_obj.reducer`` directly. For
    ContrastiveLoss it checks every entry of ``loss_obj.reducer.reducers``.
    unittest assertions are used instead of bare ``assert``, so the checks
    survive ``python -O`` and report useful failure messages.
    """
    for loss in [TripletMarginLoss, ContrastiveLoss]:
        for reducer in [
            ThresholdReducer(low=0),
            MeanReducer(),
            AvgNonZeroReducer(),
        ]:
            loss_obj = loss(reducer=reducer)
            if isinstance(loss_obj, TripletMarginLoss):
                self.assertIs(type(loss_obj.reducer), type(reducer))
            else:
                sub_reducers = list(loss_obj.reducer.reducers.values())
                # Without this, an empty mapping would make the loop below pass vacuously.
                self.assertGreater(len(sub_reducers), 0)
                for sub_reducer in sub_reducers:
                    self.assertIs(type(sub_reducer), type(reducer))
def test_with_no_valid_triplets(self):
    """With all labels distinct there are no positives, so both losses must be 0.

    The check runs for float16, float32 and float64 embeddings, with the
    default reducer and with MeanReducer.
    """
    default_loss_func = TripletMarginLoss(margin=0.2)
    mean_loss_func = TripletMarginLoss(margin=0.2, reducer=MeanReducer())
    for dtype in (torch.float16, torch.float32, torch.float64):
        coords = [c_f.angle_to_coord(angle) for angle in [0, 20, 40, 60, 80]]
        embeddings = torch.tensor(coords, requires_grad=True, dtype=dtype).to(self.device)  # 2D embeddings
        # Each sample is its own class: no anchor/positive pair exists.
        labels = torch.LongTensor([0, 1, 2, 3, 4])
        default_loss = default_loss_func(embeddings, labels)
        mean_loss = mean_loss_func(embeddings, labels)
        self.assertEqual(default_loss, 0)
        self.assertEqual(mean_loss, 0)
def test_backward(self):
    """Smoke test: backward() must run through TripletMarginLoss for every dtype/reducer pair."""
    margin = 0.2
    loss_funcs = [
        TripletMarginLoss(margin=margin),
        TripletMarginLoss(margin=margin, reducer=MeanReducer()),
    ]
    for dtype in (torch.float16, torch.float32, torch.float64):
        for loss_func in loss_funcs:
            coords = [c_f.angle_to_coord(angle) for angle in [0, 20, 40, 60, 80]]
            embeddings = torch.tensor(coords, requires_grad=True, dtype=dtype).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])
            loss_func(embeddings, labels).backward()
def test_backward(self):
    """Smoke test: backward() must run for each TripletMarginLoss variant and test dtype.

    Variants: default reducer, MeanReducer, and ``smooth_loss=True``.
    """
    margin = 0.2
    loss_funcs = [
        TripletMarginLoss(margin=margin),
        TripletMarginLoss(margin=margin, reducer=MeanReducer()),
        TripletMarginLoss(smooth_loss=True),
    ]
    for dtype in TEST_DTYPES:
        for loss_func in loss_funcs:
            coords = [c_f.angle_to_coord(angle) for angle in [0, 20, 40, 60, 80]]
            embeddings = torch.tensor(coords, requires_grad=True, dtype=dtype).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])
            loss_func(embeddings, labels).backward()
def test_triplet_margin_loss(self):
    """Compare four TripletMarginLoss configurations with hand-computed values.

    - A / B use the default distance: Euclidean, per the reference below.
    - C / D use CosineSimilarity; the reference term there is
      ``relu(sim(a, n) - sim(a, p) + margin)``.
    - A / C pass no reducer and are checked against the average over triplets
      with a positive loss.
    - B / D use MeanReducer and are checked against the average over all
      listed triplets.

    float16 is compared with a looser relative tolerance.
    """
    margin = 0.2
    loss_funcA = TripletMarginLoss(margin=margin)
    loss_funcB = TripletMarginLoss(margin=margin, reducer=MeanReducer())
    loss_funcC = TripletMarginLoss(margin=margin, distance=CosineSimilarity())
    loss_funcD = TripletMarginLoss(
        margin=margin, reducer=MeanReducer(), distance=CosineSimilarity()
    )
    # Every (anchor, positive, negative) index triple permitted by the labels below.
    triplets = [
        (0, 1, 2), (0, 1, 3), (0, 1, 4),
        (1, 0, 2), (1, 0, 3), (1, 0, 4),
        (2, 3, 0), (2, 3, 1), (2, 3, 4),
        (3, 2, 0), (3, 2, 1), (3, 2, 4),
    ]
    for dtype in TEST_DTYPES:
        coords = [c_f.angle_to_coord(angle) for angle in [0, 20, 40, 60, 80]]
        embeddings = torch.tensor(coords, requires_grad=True, dtype=dtype).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        lossC = loss_funcC(embeddings, labels)
        lossD = loss_funcD(embeddings, labels)

        euclid_total, cosine_total = 0, 0
        euclid_nonzero, cosine_nonzero = 0, 0
        for a_idx, p_idx, n_idx in triplets:
            anchor = embeddings[a_idx]
            positive = embeddings[p_idx]
            negative = embeddings[n_idx]
            euclid_term = torch.relu(
                torch.sqrt(torch.sum((anchor - positive) ** 2))
                - torch.sqrt(torch.sum((anchor - negative) ** 2))
                + margin
            )
            cosine_term = torch.relu(
                torch.sum(anchor * negative) - torch.sum(anchor * positive) + margin
            )
            euclid_nonzero += int(euclid_term > 0)
            cosine_nonzero += int(cosine_term > 0)
            euclid_total += euclid_term
            cosine_total += cosine_term

        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        expectations = [
            (lossA, euclid_total / euclid_nonzero),
            (lossB, euclid_total / len(triplets)),
            (lossC, cosine_total / cosine_nonzero),
            (lossD, cosine_total / len(triplets)),
        ]
        for actual, expected in expectations:
            self.assertTrue(torch.isclose(actual, expected, rtol=rtol))
def get_default_reducer(self):
    """Build and return a new MeanReducer to serve as this class's default reducer."""
    default_reducer = MeanReducer()
    return default_reducer
def test_contrastive_loss(self):
    """Compare four ContrastiveLoss configurations with hand-computed pair losses.

    - A / C use ``LpDistance(power=2)``; the reference uses squared Euclidean
      distance.
    - B / D use CosineSimilarity; the reference uses the dot product.
    - A / B pass no reducer. Each sub-loss (positive, negative) is averaged
      over its non-zero terms, and an all-zero sub-loss stays 0.
    - C / D use MeanReducer; each sub-loss is averaged over all of its pairs.

    float16 is compared with a looser relative tolerance.
    """
    loss_funcA = ContrastiveLoss(pos_margin=0.25, neg_margin=1.5, distance=LpDistance(power=2))
    loss_funcB = ContrastiveLoss(pos_margin=1.5, neg_margin=0.6, distance=CosineSimilarity())
    loss_funcC = ContrastiveLoss(pos_margin=0.25, neg_margin=1.5, distance=LpDistance(power=2), reducer=MeanReducer())
    loss_funcD = ContrastiveLoss(pos_margin=1.5, neg_margin=0.6, distance=CosineSimilarity(), reducer=MeanReducer())
    pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    neg_pairs = [
        (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1),
        (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3),
    ]

    def avg_nonzero(total, count):
        # Average over non-zero terms; leave the sum untouched when none are non-zero.
        return total / count if count > 0 else total

    for dtype in TEST_DTYPES:
        coords = [c_f.angle_to_coord(angle) for angle in [0, 20, 40, 60, 80]]
        embeddings = torch.tensor(coords, requires_grad=True, dtype=dtype).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        lossC = loss_funcC(embeddings, labels)
        lossD = loss_funcD(embeddings, labels)

        dist_pos_sum = sim_pos_sum = 0
        dist_pos_nz = sim_pos_nz = 0
        for a, p in pos_pairs:
            anchor, positive = embeddings[a], embeddings[p]
            dist_term = torch.relu(torch.sum((anchor - positive) ** 2) - 0.25)
            sim_term = torch.relu(1.5 - torch.matmul(anchor, positive))
            dist_pos_sum = dist_pos_sum + dist_term
            sim_pos_sum = sim_pos_sum + sim_term
            dist_pos_nz += int(dist_term > 0)
            sim_pos_nz += int(sim_term > 0)

        dist_neg_sum = sim_neg_sum = 0
        dist_neg_nz = sim_neg_nz = 0
        for a, n in neg_pairs:
            anchor, negative = embeddings[a], embeddings[n]
            dist_term = torch.relu(1.5 - torch.sum((anchor - negative) ** 2))
            sim_term = torch.relu(torch.matmul(anchor, negative) - 0.6)
            dist_neg_sum = dist_neg_sum + dist_term
            sim_neg_sum = sim_neg_sum + sim_term
            dist_neg_nz += int(dist_term > 0)
            sim_neg_nz += int(sim_term > 0)

        expected = [
            avg_nonzero(dist_pos_sum, dist_pos_nz) + avg_nonzero(dist_neg_sum, dist_neg_nz),
            avg_nonzero(sim_pos_sum, sim_pos_nz) + avg_nonzero(sim_neg_sum, sim_neg_nz),
            dist_pos_sum / len(pos_pairs) + dist_neg_sum / len(neg_pairs),
            sim_pos_sum / len(pos_pairs) + sim_neg_sum / len(neg_pairs),
        ]
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        for actual, target in zip((lossA, lossB, lossC, lossD), expected):
            self.assertTrue(torch.isclose(actual, target, rtol=rtol))
def test_contrastive_loss(self):
    """Compare four ContrastiveLoss configurations (``use_similarity`` API) with references.

    - A / C: distances with ``squared_distances=True``; the reference uses
      squared Euclidean distance.
    - B / D: ``use_similarity=True``; the reference uses the dot product.
    - A / B pass no reducer. Each sub-loss is averaged over its non-zero
      terms, and an all-zero sub-loss stays 0.
    - C / D use MeanReducer; each sub-loss is averaged over all of its pairs.

    All four losses are back-propagated before the comparison.
    """
    loss_funcA = ContrastiveLoss(pos_margin=0.25, neg_margin=1.5, use_similarity=False, squared_distances=True)
    loss_funcB = ContrastiveLoss(pos_margin=1.5, neg_margin=0.6, use_similarity=True)
    loss_funcC = ContrastiveLoss(pos_margin=0.25, neg_margin=1.5, use_similarity=False, squared_distances=True, reducer=MeanReducer())
    loss_funcD = ContrastiveLoss(pos_margin=1.5, neg_margin=0.6, use_similarity=True, reducer=MeanReducer())
    coords = [c_f.angle_to_coord(angle) for angle in [0, 20, 40, 60, 80]]
    embeddings = torch.tensor(coords, requires_grad=True, dtype=torch.float)  # 2D embeddings
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    computed = [f(embeddings, labels) for f in (loss_funcA, loss_funcB, loss_funcC, loss_funcD)]
    for loss in computed:
        loss.backward()

    pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    neg_pairs = [
        (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1),
        (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3),
    ]

    dist_pos_sum = sim_pos_sum = 0
    dist_pos_nz = sim_pos_nz = 0
    for a, p in pos_pairs:
        anchor, positive = embeddings[a], embeddings[p]
        dist_term = torch.relu(torch.sum((anchor - positive) ** 2) - 0.25)
        sim_term = torch.relu(1.5 - torch.matmul(anchor, positive))
        dist_pos_sum = dist_pos_sum + dist_term
        sim_pos_sum = sim_pos_sum + sim_term
        dist_pos_nz += int(dist_term > 0)
        sim_pos_nz += int(sim_term > 0)

    dist_neg_sum = sim_neg_sum = 0
    dist_neg_nz = sim_neg_nz = 0
    for a, n in neg_pairs:
        anchor, negative = embeddings[a], embeddings[n]
        dist_term = torch.relu(1.5 - torch.sum((anchor - negative) ** 2))
        sim_term = torch.relu(torch.matmul(anchor, negative) - 0.6)
        dist_neg_sum = dist_neg_sum + dist_term
        sim_neg_sum = sim_neg_sum + sim_term
        dist_neg_nz += int(dist_term > 0)
        sim_neg_nz += int(sim_term > 0)

    def avg_nonzero(total, count):
        # Average over non-zero terms; leave the sum untouched when none are non-zero.
        return total / count if count > 0 else total

    expected = [
        avg_nonzero(dist_pos_sum, dist_pos_nz) + avg_nonzero(dist_neg_sum, dist_neg_nz),
        avg_nonzero(sim_pos_sum, sim_pos_nz) + avg_nonzero(sim_neg_sum, sim_neg_nz),
        dist_pos_sum / len(pos_pairs) + dist_neg_sum / len(neg_pairs),
        sim_pos_sum / len(pos_pairs) + sim_neg_sum / len(neg_pairs),
    ]
    for actual, target in zip(computed, expected):
        self.assertTrue(torch.isclose(actual, target))
def test_per_anchor_reducer(self):
    """Check PerAnchorReducer against a hand-computed per-anchor reduction.

    For each inner reducer (MeanReducer, AvgNonZeroReducer) and test dtype,
    random losses are attached to element, pos-pair, neg-pair and triplet
    indices. Expected behaviour per reduction type:

    - "triplet": must raise NotImplementedError.
    - "element": the expected value is the inner reducer applied directly to
      the losses.
    - pair types: each anchor's losses are first averaged, and anchors with
      no pairs get 0. The inner reducer is then applied to those per-anchor
      values as an "element" loss.

    float16 uses a looser relative tolerance, matching the convention of the
    other dtype-looping tests.
    """
    for inner_reducer in [MeanReducer(), AvgNonZeroReducer()]:
        reducer = PerAnchorReducer(inner_reducer)
        batch_size = 100
        embedding_size = 64
        for dtype in TEST_DTYPES:
            embeddings = (
                torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
            )
            labels = torch.randint(0, 10, (batch_size,))
            # Compute all pairs once, then split into the (anchor, positive)
            # and (anchor, negative) halves.
            all_pairs = lmu.get_all_pairs_indices(labels)
            pos_pair_indices = all_pairs[:2]
            neg_pair_indices = all_pairs[2:]
            triplet_indices = lmu.get_all_triplets_indices(labels)
            # float16 rounding can differ by an ulp between summation paths,
            # so exact-ish default tolerances are too strict there.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            for indices, reduction_type in [
                (torch.arange(batch_size), "element"),
                (pos_pair_indices, "pos_pair"),
                (neg_pair_indices, "neg_pair"),
                (triplet_indices, "triplet"),
            ]:
                loss_size = (
                    len(indices) if reduction_type == "element" else len(indices[0])
                )
                losses = torch.randn(loss_size).type(dtype).to(TEST_DEVICE)
                loss_dict = {
                    "loss": {
                        "losses": losses,
                        "indices": indices,
                        "reduction_type": reduction_type,
                    }
                }
                if reduction_type == "triplet":
                    with self.assertRaises(NotImplementedError):
                        reducer(loss_dict, embeddings, labels)
                    continue

                output = reducer(loss_dict, embeddings, labels)

                if reduction_type == "element":
                    expected_losses = losses
                else:
                    # Mean loss per anchor; anchors without pairs stay at 0.
                    anchors = indices[0]
                    expected_losses = torch.zeros(
                        batch_size, device=TEST_DEVICE, dtype=dtype
                    )
                    for i in range(batch_size):
                        matching_pairs_mask = anchors == i
                        num_matching_pairs = torch.sum(matching_pairs_mask)
                        if num_matching_pairs > 0:
                            expected_losses[i] = (
                                torch.sum(losses[matching_pairs_mask])
                                / num_matching_pairs
                            )

                expected_dict = {
                    "loss": {
                        "losses": expected_losses,
                        "indices": c_f.torch_arange_from_size(embeddings),
                        "reduction_type": "element",
                    }
                }
                correct_output = inner_reducer(expected_dict, embeddings, labels)
                self.assertTrue(torch.isclose(output, correct_output, rtol=rtol))