def test_loss(self):
    """Check CrossBatchMemory against a brute-force recomputation.

    Runs several iterations, growing an external copy of the memory bank
    (``all_embeddings``/``all_labels``), and verifies three variants:
    no inner miner, an inner miner, and an inner miner combined with
    externally supplied (outer-miner) indices.
    """
    num_labels = 10
    num_iter = 10
    batch_size = 32
    inner_loss = ContrastiveLoss()
    inner_miner = MultiSimilarityMiner(0.3)
    outer_miner = MultiSimilarityMiner(0.2)
    # Three independent CrossBatchMemory wrappers so each keeps its own memory bank.
    # NOTE(review): assumes self.embedding_size / self.memory_size are set in setUp
    # and that memory_size >= num_iter * batch_size so nothing is evicted — confirm.
    self.loss = CrossBatchMemory(loss=inner_loss, embedding_size=self.embedding_size, memory_size=self.memory_size)
    self.loss_with_miner = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
    self.loss_with_miner2 = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
    # Reference copy of what the internal memory bank should contain.
    all_embeddings = torch.FloatTensor([])
    all_labels = torch.LongTensor([])
    for i in range(num_iter):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.randint(0, num_labels, (batch_size,))
        loss = self.loss(embeddings, labels)
        loss_with_miner = self.loss_with_miner(embeddings, labels)
        # Outer miner indices refer to the current batch only.
        oa1, op, oa2, on = outer_miner(embeddings, labels)
        loss_with_miner_and_input_indices = self.loss_with_miner2(embeddings, labels, (oa1, op, oa2, on))
        # The wrappers have now absorbed this batch, so mirror it here too.
        all_embeddings = torch.cat([all_embeddings, embeddings])
        all_labels = torch.cat([all_labels, labels])

        # loss with no inner miner
        indices_tuple = lmu.get_all_pairs_indices(labels, all_labels)
        # remove_self_comparisons presumably drops pairs where the query is
        # matched against its own copy in the memory bank — TODO confirm.
        a1, p, a2, n = self.loss_with_miner.remove_self_comparisons(indices_tuple) if False else self.loss.remove_self_comparisons(indices_tuple)
        # Shift memory-bank indices past the current batch, which sits first
        # in the concatenated tensor below.
        p = p + batch_size
        n = n + batch_size
        correct_loss = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1, p, a2, n))
        self.assertTrue(torch.isclose(loss, correct_loss))

        # loss with inner miner
        indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
        a1, p, a2, n = self.loss_with_miner.remove_self_comparisons(indices_tuple)
        p = p + batch_size
        n = n + batch_size
        correct_loss_with_miner = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1, p, a2, n))
        self.assertTrue(torch.isclose(loss_with_miner, correct_loss_with_miner))

        # loss with inner and outer miner
        indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
        a1, p, a2, n = self.loss_with_miner2.remove_self_comparisons(indices_tuple)
        p = p + batch_size
        n = n + batch_size
        # Outer-miner indices address the current batch, which occupies
        # positions [0, batch_size) of the concatenation — no offset needed.
        a1 = torch.cat([oa1, a1])
        p = torch.cat([op, p])
        a2 = torch.cat([oa2, a2])
        n = torch.cat([on, n])
        correct_loss_with_miner_and_input_indice = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1, p, a2, n))
        self.assertTrue(torch.isclose(loss_with_miner_and_input_indices, correct_loss_with_miner_and_input_indice))
def test_tuplestoweights_sampler(self):
    """The sampler must yield exactly `subset_size` indices, and every
    index it actually drew must carry a nonzero weight."""
    temporary_folder = "cifar100_temp_for_pytorch_metric_learning_test"
    subset_size = 1000

    # Pretrained trunk with the classifier head removed, wrapped for multi-GPU.
    trunk = models.resnet18(pretrained=True)
    trunk.fc = c_f.Identity()
    trunk = torch.nn.DataParallel(trunk)
    trunk.to(torch.device("cuda"))

    miner = MultiSimilarityMiner(epsilon=-0.2)
    eval_transform = transforms.Compose(
        [
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    dataset = datasets.CIFAR100(temporary_folder, train=True, download=True, transform=eval_transform)
    sampler = TuplesToWeightsSampler(trunk, miner, dataset, subset_size=subset_size)

    drawn_indices = list(iter(sampler))
    self.assertTrue(len(drawn_indices) == subset_size)
    distinct = torch.unique(torch.tensor(drawn_indices))
    self.assertTrue(torch.all(sampler.weights[distinct] != 0))
    shutil.rmtree(temporary_folder)
def test_key_mismatch(self):
    """MultipleLosses must reject weight/miner dicts whose keys do not
    line up with the keys of the losses dict."""
    lossA = ContrastiveLoss()
    lossB = TripletMarginLoss(0.1)

    # Weight key "blah" has no corresponding loss.
    with self.assertRaises(AssertionError):
        MultipleLosses(
            losses={"lossA": lossA, "lossB": lossB},
            weights={"blah": 1, "lossB": 0.23},
        )

    minerA = MultiSimilarityMiner()
    # Miner key "blah" has no corresponding loss.
    with self.assertRaises(AssertionError):
        MultipleLosses(
            losses={"lossA": lossA, "lossB": lossB},
            weights={"lossA": 1, "lossB": 0.23},
            miners={"blah": minerA},
        )
def test_input_indices_tuple(self):
    """A user-supplied indices tuple must be forwarded to every wrapped
    loss, for both dict-style and list-style construction."""
    lossA = ContrastiveLoss()
    lossB = TripletMarginLoss(0.1)
    miner = MultiSimilarityMiner()
    dict_wrapper = MultipleLosses(
        losses={"lossA": lossA, "lossB": lossB},
        weights={"lossA": 1, "lossB": 0.23},
    )
    list_wrapper = MultipleLosses(losses=[lossA, lossB], weights=[1, 0.23])
    for wrapper in [dict_wrapper, list_wrapper]:
        for dtype in TEST_DTYPES:
            angles = torch.arange(0, 180)
            # 2D embeddings on the unit circle
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)
            labels = torch.randint(low=0, high=10, size=(180,))
            indices_tuple = miner(embeddings, labels)
            combined = wrapper(embeddings, labels, indices_tuple)
            combined.backward()
            # Hand-computed weighted sum of the two component losses.
            expected = lossA(embeddings, labels, indices_tuple)
            expected = expected + 0.23 * lossB(embeddings, labels, indices_tuple)
            self.assertTrue(torch.isclose(combined, expected))
def test_empty_output(self):
    """All-distinct labels produce no positive pairs, so the miner must
    return empty anchor/positive tensors."""
    miner = MultiSimilarityMiner(0.1)
    batch_size = 32
    for dtype in (torch.float16, torch.float32, torch.float64):
        batch = torch.randn(batch_size, 64).type(dtype).to(self.device)
        unique_labels = torch.arange(batch_size)
        anchors, positives, _, _ = miner(batch, unique_labels)
        self.assertTrue(len(anchors) == 0)
        self.assertTrue(len(positives) == 0)
def test_empty_output(self):
    """Every label is unique, hence no positive pairs exist and the miner
    must yield empty anchor/positive outputs for every dtype."""
    miner = MultiSimilarityMiner(0.1)
    batch_size = 32
    for dtype in TEST_DTYPES:
        batch = torch.randn(batch_size, 64).type(dtype).to(TEST_DEVICE)
        unique_labels = torch.arange(batch_size)
        anchors, positives, _, _ = miner(batch, unique_labels)
        self.assertTrue(len(anchors) == 0)
        self.assertTrue(len(positives) == 0)
def test_multi_similarity_miner(self):
    """Brute-force check of MultiSimilarityMiner's pair selection.

    Recomputes, with explicit O(n^2) loops over dot-product similarities,
    which positive/negative pairs the multi-similarity rule should keep,
    and compares against the miner's output as sets of (anchor, other)
    index pairs.
    """
    epsilon = 0.1
    miner = MultiSimilarityMiner(epsilon)
    for dtype in [torch.float16, torch.float32, torch.float64]:
        embedding_angles = torch.arange(0, 64)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.randint(low=0, high=10, size=(64,))
        # Enumerate every ordered pair (i, j), i != j, with its dot-product
        # similarity, split by whether the labels match.
        pos_pairs = []
        neg_pairs = []
        for i in range(len(embeddings)):
            anchor, anchor_label = embeddings[i], labels[i]
            for j in range(len(embeddings)):
                if j != i:
                    other, other_label = embeddings[j], labels[j]
                    if anchor_label == other_label:
                        pos_pairs.append((i, j, torch.matmul(anchor, other.t()).item()))
                    if anchor_label != other_label:
                        neg_pairs.append((i, j, torch.matmul(anchor, other.t()).item()))
        correct_a1, correct_p = [], []
        correct_a2, correct_n = [], []
        # Keep a positive pair (a1, p) if its similarity is below the
        # anchor's hardest (largest) negative similarity plus epsilon.
        for a1, p, ap_sim in pos_pairs:
            max_neg_sim = c_f.neg_inf(dtype)
            for a2, n, an_sim in neg_pairs:
                if a2 == a1:
                    if an_sim > max_neg_sim:
                        max_neg_sim = an_sim
            if ap_sim < max_neg_sim + epsilon:
                correct_a1.append(a1)
                correct_p.append(p)
        # Keep a negative pair (a2, n) if its similarity is above the
        # anchor's hardest (smallest) positive similarity minus epsilon.
        # (a1/p/ap_sim are deliberately re-bound by this inner loop.)
        for a2, n, an_sim in neg_pairs:
            min_pos_sim = c_f.pos_inf(dtype)
            for a1, p, ap_sim in pos_pairs:
                if a2 == a1:
                    if ap_sim < min_pos_sim:
                        min_pos_sim = ap_sim
            if an_sim > min_pos_sim - epsilon:
                correct_a2.append(a2)
                correct_n.append(n)
        # Compare as sets: pair ordering within the miner output is not specified.
        correct_pos_pairs = set([(a, p) for a, p in zip(correct_a1, correct_p)])
        correct_neg_pairs = set([(a, n) for a, n in zip(correct_a2, correct_n)])
        a1, p1, a2, n2 = miner(embeddings, labels)
        pos_pairs = set([(a.item(), p.item()) for a, p in zip(a1, p1)])
        neg_pairs = set([(a.item(), n.item()) for a, n in zip(a2, n2)])
        self.assertTrue(pos_pairs == correct_pos_pairs)
        self.assertTrue(neg_pairs == correct_neg_pairs)
def test_length_mistmatch(self):
    """List-style losses, weights, and miners must all share one length."""
    lossA = ContrastiveLoss()
    lossB = TripletMarginLoss(0.1)

    # Two losses but only one weight.
    with self.assertRaises(AssertionError):
        MultipleLosses(losses=[lossA, lossB], weights=[1])

    minerA = MultiSimilarityMiner()
    # Two losses but only one miner.
    with self.assertRaises(AssertionError):
        MultipleLosses(
            losses=[lossA, lossB],
            weights=[1, 0.2],
            miners=[minerA],
        )
def test_multi_similarity_miner(self):
    """Brute-force check of MultiSimilarityMiner for multiple distances.

    For each dtype and each distance (similarity-like CosineSimilarity,
    distance-like LpDistance), recomputes the multi-similarity selection
    rule with explicit loops and compares against the miner's output.
    ``distance.is_inverted`` flips every comparison: for similarities
    "harder" means larger, for distances "harder" means smaller.
    """
    epsilon = 0.1
    for dtype in TEST_DTYPES:
        for distance in [CosineSimilarity(), LpDistance()]:
            miner = MultiSimilarityMiner(epsilon, distance=distance)
            embedding_angles = torch.arange(0, 64)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(64,))
            # Full pairwise distance/similarity matrix.
            mat = distance(embeddings)
            # Enumerate every ordered pair (i, j), i != j, split by label match.
            pos_pairs = []
            neg_pairs = []
            for i in range(len(embeddings)):
                anchor_label = labels[i]
                for j in range(len(embeddings)):
                    if j != i:
                        other_label = labels[j]
                        if anchor_label == other_label:
                            pos_pairs.append((i, j, mat[i, j]))
                        if anchor_label != other_label:
                            neg_pairs.append((i, j, mat[i, j]))
            correct_a1, correct_p = [], []
            correct_a2, correct_n = [], []
            # Keep a positive pair if it is "harder" (within epsilon) than
            # the anchor's most difficult negative.
            for a1, p, ap_sim in pos_pairs:
                most_difficult = (
                    c_f.neg_inf(dtype) if distance.is_inverted else c_f.pos_inf(dtype)
                )
                for a2, n, an_sim in neg_pairs:
                    if a2 == a1:
                        condition = (
                            (an_sim > most_difficult)
                            if distance.is_inverted
                            else (an_sim < most_difficult)
                        )
                        if condition:
                            most_difficult = an_sim
                condition = (
                    (ap_sim < most_difficult + epsilon)
                    if distance.is_inverted
                    else (ap_sim > most_difficult - epsilon)
                )
                if condition:
                    correct_a1.append(a1)
                    correct_p.append(p)
            # Keep a negative pair if it is "harder" (within epsilon) than
            # the anchor's most difficult positive.
            # (a1/p/ap_sim are deliberately re-bound by this inner loop.)
            for a2, n, an_sim in neg_pairs:
                most_difficult = (
                    c_f.pos_inf(dtype) if distance.is_inverted else c_f.neg_inf(dtype)
                )
                for a1, p, ap_sim in pos_pairs:
                    if a2 == a1:
                        condition = (
                            (ap_sim < most_difficult)
                            if distance.is_inverted
                            else (ap_sim > most_difficult)
                        )
                        if condition:
                            most_difficult = ap_sim
                condition = (
                    (an_sim > most_difficult - epsilon)
                    if distance.is_inverted
                    else (an_sim < most_difficult + epsilon)
                )
                if condition:
                    correct_a2.append(a2)
                    correct_n.append(n)
            # Compare as sets: the miner's pair ordering is not specified.
            correct_pos_pairs = set([
                (a, p) for a, p in zip(correct_a1, correct_p)
            ])
            correct_neg_pairs = set([
                (a, n) for a, n in zip(correct_a2, correct_n)
            ])
            a1, p1, a2, n2 = miner(embeddings, labels)
            pos_pairs = set([(a.item(), p.item()) for a, p in zip(a1, p1)])
            neg_pairs = set([(a.item(), n.item()) for a, n in zip(a2, n2)])
            self.assertTrue(pos_pairs == correct_pos_pairs)
            self.assertTrue(neg_pairs == correct_neg_pairs)