def test_avg_non_zero_reducer(self):
    """Check AvgNonZeroReducer against a hand-computed mean of positive losses.

    For every reduction type ("element", "pos_pair", "neg_pair", "triplet"),
    a random loss vector and an all-zero loss vector are reduced. The
    expected value is the mean of the strictly positive entries, or
    ``torch.mean(losses) * 0`` when there are none.
    """
    reducer = AvgNonZeroReducer()
    batch_size = 100
    embedding_size = 64

    embeddings = torch.randn(batch_size, embedding_size)
    labels = torch.randint(0, 10, (batch_size, ))
    first_idx = torch.randint(0, batch_size, (batch_size, ))
    second_idx = torch.randint(0, batch_size, (batch_size, ))
    pair_indices = (first_idx, second_idx)
    third_idx = torch.randint(0, batch_size, (batch_size, ))
    triplet_indices = (*pair_indices, third_idx)

    random_losses = torch.randn(batch_size)
    zero_losses = torch.zeros(batch_size)

    cases = (
        (torch.arange(batch_size), "element"),
        (pair_indices, "pos_pair"),
        (pair_indices, "neg_pair"),
        (triplet_indices, "triplet"),
    )
    for indices, reduction_type in cases:
        for loss_values in (random_losses, zero_losses):
            output = reducer(
                {
                    "loss": {
                        "losses": loss_values,
                        "indices": indices,
                        "reduction_type": reduction_type,
                    }
                },
                embeddings,
                labels,
            )
            positives = loss_values[loss_values > 0]
            # With no positive entries the expectation is a zero built from
            # the losses themselves rather than a bare Python 0.
            expected = (
                torch.mean(positives)
                if len(positives) > 0
                else torch.mean(loss_values) * 0
            )
            self.assertTrue(output == expected)
def test_deepcopy_reducer(self):
    """Check that ContrastiveLoss sub-loss reducers record their own stats.

    A single ``AvgNonZeroReducer`` instance is passed to ``ContrastiveLoss``.
    After one forward pass, the "pos_loss" reducer must report a positive
    ``pos_pairs_past_filter``. The "neg_loss" reducer must report a positive
    ``neg_pairs_past_filter``.

    NOTE(review): per the test name, this presumably relies on the loss
    deep-copying the passed reducer once per sub-loss. Confirm against
    the loss implementation.
    """
    loss_fn = ContrastiveLoss(
        pos_margin=0, neg_margin=2, reducer=AvgNonZeroReducer()
    )
    embeddings = torch.randn(128, 64)
    labels = torch.randint(low=0, high=10, size=(128, ))
    # The forward pass is run only for its side effect of populating the
    # reducers' "*_past_filter" statistics; the loss value itself is unused.
    loss_fn(embeddings, labels)
    sub_reducers = loss_fn.reducer.reducers
    self.assertGreater(sub_reducers["pos_loss"].pos_pairs_past_filter, 0)
    self.assertGreater(sub_reducers["neg_loss"].neg_pairs_past_filter, 0)
def test_setting_reducers(self):
    """Check that a reducer passed to a loss is kept with its exact type.

    ``TripletMarginLoss`` stores a single ``reducer``. ``ContrastiveLoss``
    stores a container whose ``reducers`` mapping holds one reducer per
    sub-loss, and every one of those must have the passed reducer's type.
    """
    for loss_class in [TripletMarginLoss, ContrastiveLoss]:
        for reducer in [
            ThresholdReducer(low=0),
            MeanReducer(),
            AvgNonZeroReducer(),
        ]:
            loss_fn = loss_class(reducer=reducer)
            expected_type = type(reducer)
            # Exact-type comparison (not isinstance) via unittest assertions,
            # which, unlike bare ``assert``, are not stripped under ``-O``.
            if isinstance(loss_fn, TripletMarginLoss):
                self.assertIs(type(loss_fn.reducer), expected_type)
            else:
                for sub_reducer in loss_fn.reducer.reducers.values():
                    self.assertIs(type(sub_reducer), expected_type)
def test_multiple_reducers(self):
    """Check MultipleReducers against the sum of hand-computed sub-reductions.

    "lossA" uses AvgNonZeroReducer, so the expected value is the mean of its
    positive entries. "lossB" uses DivisorReducer with
    ``divisor_summands`` 32 and 15, so the expected value is
    ``sum(lossesB) / (32 + 15)``. The expected total is the sum of the two,
    checked for every dtype in ``TEST_DTYPES`` and every reduction type.
    """
    reducer = MultipleReducers({
        "lossA": AvgNonZeroReducer(),
        "lossB": DivisorReducer()
    })
    batch_size = 100
    embedding_size = 64
    for dtype in TEST_DTYPES:
        embeddings = (torch.randn(
            batch_size, embedding_size).type(dtype).to(TEST_DEVICE))
        labels = torch.randint(0, 10, (batch_size, ))
        pair_indices = (
            torch.randint(0, batch_size, (batch_size, )),
            torch.randint(0, batch_size, (batch_size, )),
        )
        triplet_indices = pair_indices + (torch.randint(
            0, batch_size, (batch_size, )), )
        lossesA = torch.randn(batch_size).type(dtype).to(TEST_DEVICE)
        lossesB = torch.randn(batch_size).type(dtype).to(TEST_DEVICE)
        # The expected value depends only on the loss vectors, not on the
        # indices/reduction type, so compute it once per dtype.
        correct_output = (torch.mean(
            lossesA[lossesA > 0])) + (torch.sum(lossesB) / (32 + 15))
        # The reducers need not evaluate this in the same order or precision
        # as the expression above, so exact float equality is brittle,
        # particularly for float16. Use the tolerance scheme of the other
        # float tests.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        for indices, reduction_type in [
            (torch.arange(batch_size), "element"),
            (pair_indices, "pos_pair"),
            (pair_indices, "neg_pair"),
            (triplet_indices, "triplet"),
        ]:
            loss_dict = {
                "lossA": {
                    "losses": lossesA,
                    "indices": indices,
                    "reduction_type": reduction_type,
                },
                "lossB": {
                    "losses": lossesB,
                    "indices": indices,
                    "reduction_type": reduction_type,
                    "divisor_summands": {
                        "partA": 32,
                        "partB": 15
                    },
                },
            }
            output = reducer(loss_dict, embeddings, labels)
            self.assertTrue(
                torch.isclose(output, correct_output, rtol=rtol))
def test_ntxent_loss(self):
    """Compare NTXentLoss / SupConLoss outputs with hand-computed references.

    Six 2D embeddings are built from angles via ``c_f.angle_to_coord`` with
    labels [0, 0, 0, 1, 1, 2]. Five loss configurations are checked:

    * A: NTXentLoss, default distance. The reference similarity is the dot
      product ``matmul(anchor, positive)``.
    * B: NTXentLoss with LpDistance. The reference similarity is
      ``-sqrt(sum((anchor - positive) ** 2))``.
    * C: NTXentLoss with ``PerAnchorReducer(AvgNonZeroReducer())``. The
      reference averages A's pair losses per anchor, then over anchors.
    * D / E: SupConLoss with default distance / LpDistance. These are reduced
      per anchor in the same way as C.

    NOTE(review): the dot-product reference for A/C/D presumes the default
    distance behaves as a dot-product (cosine) similarity on these
    embeddings. Confirm against the library.
    """
    temperature = 0.1
    loss_funcA = NTXentLoss(temperature=temperature)
    loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
    loss_funcC = NTXentLoss(
        temperature=temperature, reducer=PerAnchorReducer(AvgNonZeroReducer())
    )
    loss_funcD = SupConLoss(temperature=temperature)
    loss_funcE = SupConLoss(temperature=temperature, distance=LpDistance())

    for dtype in TEST_DTYPES:
        embedding_angles = [0, 10, 20, 50, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(
            TEST_DEVICE
        )  # 2D embeddings
        labels = torch.LongTensor([0, 0, 0, 1, 1, 2])

        # Library outputs, in the order A, B, C, D, E.
        obtained_losses = [
            x(embeddings, labels)
            for x in [loss_funcA, loss_funcB, loss_funcC, loss_funcD, loss_funcE]
        ]

        # All ordered (anchor, partner) pairs whose labels match.
        pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)]
        # All ordered (anchor, partner) pairs whose labels differ.
        neg_pairs = [
            (0, 3),
            (0, 4),
            (0, 5),
            (1, 3),
            (1, 4),
            (1, 5),
            (2, 3),
            (2, 4),
            (2, 5),
            (3, 0),
            (3, 1),
            (3, 2),
            (3, 5),
            (4, 0),
            (4, 1),
            (4, 2),
            (4, 5),
            (5, 0),
            (5, 1),
            (5, 2),
            (5, 3),
            (5, 4),
        ]

        # A and B are scalar sums over all positive pairs. C, D and E are
        # per-anchor sums; only anchors 0-4 appear in pos_pairs (label 2
        # occurs once), hence length 5.
        total_lossA, total_lossB, total_lossC, total_lossD, total_lossE = (
            0,
            0,
            torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
            torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
            torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
        )

        for a1, p in pos_pairs:
            anchor, positive = embeddings[a1], embeddings[p]
            numeratorA = torch.exp(torch.matmul(anchor, positive) / temperature)
            numeratorB = torch.exp(
                -torch.sqrt(torch.sum((anchor - positive) ** 2)) / temperature
            )
            # NT-Xent denominators (A, B) start from this positive's own term
            # and add only negatives below. SupCon denominators (D, E) start
            # empty and add every partner of this anchor, positives included.
            denominatorA = numeratorA.clone()
            denominatorB = numeratorB.clone()
            denominatorD = 0
            denominatorE = 0
            for a2, n in pos_pairs + neg_pairs:
                if a2 == a1:
                    # Despite the name, ``negative`` is any partner of this
                    # anchor (positive or negative).
                    negative = embeddings[n]
                    curr_denomD = torch.exp(
                        torch.matmul(anchor, negative) / temperature
                    )
                    curr_denomE = torch.exp(
                        -torch.sqrt(torch.sum((anchor - negative) ** 2)) / temperature
                    )
                    denominatorD += curr_denomD
                    denominatorE += curr_denomE
                    if (a2, n) not in pos_pairs:
                        denominatorA += curr_denomD
                        denominatorB += curr_denomE
                else:
                    # Pairs belonging to a different anchor do not contribute.
                    continue
            curr_lossA = -torch.log(numeratorA / denominatorA)
            curr_lossB = -torch.log(numeratorB / denominatorB)
            curr_lossD = -torch.log(numeratorA / denominatorD)
            curr_lossE = -torch.log(numeratorB / denominatorE)
            total_lossA += curr_lossA
            total_lossB += curr_lossB
            # C reuses A's per-pair losses; only the reduction differs.
            total_lossC[a1] += curr_lossA
            total_lossD[a1] += curr_lossD
            total_lossE[a1] += curr_lossE

        # A, B: mean over all positive pairs.
        total_lossA /= len(pos_pairs)
        total_lossB /= len(pos_pairs)
        # C, D, E: average within each anchor (number of positives per anchor
        # 0..4), then average across anchors.
        pos_pair_per_anchor = torch.tensor(
            [2, 2, 2, 1, 1], device=TEST_DEVICE, dtype=dtype
        )
        total_lossC, total_lossD, total_lossE = [
            torch.mean(x / pos_pair_per_anchor)
            for x in [total_lossC, total_lossD, total_lossE]
        ]

        # float16 gets a looser relative tolerance.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(obtained_losses[0], total_lossA, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[1], total_lossB, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[2], total_lossC, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[3], total_lossD, rtol=rtol))
        self.assertTrue(torch.isclose(obtained_losses[4], total_lossE, rtol=rtol))
def get_default_reducer(self):
    """Return a newly constructed ``AvgNonZeroReducer``."""
    default_reducer = AvgNonZeroReducer()
    return default_reducer
def test_per_anchor_reducer(self):
    """Check PerAnchorReducer against a hand-computed per-anchor reduction.

    Per reduction type:

    * "element": the losses are handed to the inner reducer unchanged.
    * "pos_pair" / "neg_pair": the losses are first averaged per anchor.
      Anchors without any pair keep 0. The resulting per-element vector is
      then handed to the inner reducer.
    * "triplet": must raise NotImplementedError.

    This is exercised for both MeanReducer and AvgNonZeroReducer as the inner
    reducer and for every dtype in ``TEST_DTYPES``.
    """
    for inner_reducer in [MeanReducer(), AvgNonZeroReducer()]:
        reducer = PerAnchorReducer(inner_reducer)
        batch_size = 100
        embedding_size = 64
        for dtype in TEST_DTYPES:
            embeddings = (
                torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
            )
            labels = torch.randint(0, 10, (batch_size,))
            # Compute all pairs once for these labels. The first two entries
            # are the positive-pair indices, the last two the negative-pair
            # indices.
            all_pairs_indices = lmu.get_all_pairs_indices(labels)
            pos_pair_indices = all_pairs_indices[:2]
            neg_pair_indices = all_pairs_indices[2:]
            triplet_indices = lmu.get_all_triplets_indices(labels)
            # Same tolerance scheme as the other float tests: float16 needs a
            # looser bound. For other dtypes 1e-5 equals isclose's default.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            for indices, reduction_type in [
                (torch.arange(batch_size), "element"),
                (pos_pair_indices, "pos_pair"),
                (neg_pair_indices, "neg_pair"),
                (triplet_indices, "triplet"),
            ]:
                loss_size = (
                    len(indices) if reduction_type == "element" else len(indices[0])
                )
                losses = torch.randn(loss_size).type(dtype).to(TEST_DEVICE)
                loss_dict = {
                    "loss": {
                        "losses": losses,
                        "indices": indices,
                        "reduction_type": reduction_type,
                    }
                }
                if reduction_type == "triplet":
                    with self.assertRaises(NotImplementedError):
                        reducer(loss_dict, embeddings, labels)
                    continue

                output = reducer(loss_dict, embeddings, labels)

                if reduction_type == "element":
                    reference_losses = losses
                else:
                    # Average each anchor's pair losses; anchors that never
                    # appear as indices[0] stay at 0.
                    anchors = indices[0]
                    reference_losses = torch.zeros(
                        batch_size, device=TEST_DEVICE, dtype=dtype
                    )
                    for i in range(len(embeddings)):
                        matching_pairs_mask = anchors == i
                        num_matching_pairs = torch.sum(matching_pairs_mask)
                        if num_matching_pairs > 0:
                            reference_losses[i] = (
                                torch.sum(losses[matching_pairs_mask])
                                / num_matching_pairs
                            )
                reference_dict = {
                    "loss": {
                        "losses": reference_losses,
                        "indices": c_f.torch_arange_from_size(embeddings),
                        "reduction_type": "element",
                    }
                }
                correct_output = inner_reducer(reference_dict, embeddings, labels)
                self.assertTrue(torch.isclose(output, correct_output, rtol=rtol))