def test_triplet_margin_loss(self):
    """Compare TripletMarginLoss with a hand-computed Euclidean triplet loss.

    Five 2D embeddings (angles 0..80 degrees) labelled [0, 0, 1, 1, 2] give
    twelve anchor/positive/negative triplets. The test asserts that the
    default reducer equals the mean over triplets with non-zero loss, and
    that MeanReducer equals the mean over all twelve triplets.
    """
    margin = 0.2
    default_loss_fn = TripletMarginLoss(margin=margin)
    mean_loss_fn = TripletMarginLoss(margin=margin, reducer=MeanReducer())

    angles = [0, 20, 40, 60, 80]
    points = torch.tensor(
        [c_f.angle_to_coord(angle) for angle in angles],
        requires_grad=True,
        dtype=torch.float,
    )  # 2D embeddings
    targets = torch.LongTensor([0, 0, 1, 1, 2])

    loss_default = default_loss_fn(points, targets)
    loss_mean = mean_loss_fn(points, targets)
    loss_default.backward()
    loss_mean.backward()

    # Every (anchor, positive, negative) triple implied by the labels above.
    triplets = [
        (0, 1, 2), (0, 1, 3), (0, 1, 4),
        (1, 0, 2), (1, 0, 3), (1, 0, 4),
        (2, 3, 0), (2, 3, 1), (2, 3, 4),
        (3, 2, 0), (3, 2, 1), (3, 2, 4),
    ]
    expected_total = 0
    active_count = 0
    for anchor_idx, pos_idx, neg_idx in triplets:
        anchor = points[anchor_idx]
        positive = points[pos_idx]
        negative = points[neg_idx]
        ap_dist = torch.sqrt(torch.sum((anchor - positive) ** 2))
        an_dist = torch.sqrt(torch.sum((anchor - negative) ** 2))
        triplet_loss = torch.relu(ap_dist - an_dist + margin)
        if triplet_loss > 0:
            active_count += 1
        expected_total += triplet_loss

    self.assertTrue(
        torch.isclose(loss_default, expected_total / active_count))
    self.assertTrue(
        torch.isclose(loss_mean, expected_total / len(triplets)))
def test_with_no_valid_triplets(self):
    """Both averaging modes must give zero loss when no triplet exists.

    Every embedding has its own label, so no anchor/positive pair can be
    formed.
    """
    nonzero_avg_fn = TripletMarginLoss(margin=0.2, avg_non_zero_only=True)
    full_avg_fn = TripletMarginLoss(margin=0.2, avg_non_zero_only=False)

    angles = [0, 20, 40, 60, 80]
    points = torch.FloatTensor(
        [c_f.angle_to_coord(angle) for angle in angles])  # 2D embeddings
    distinct_labels = torch.LongTensor([0, 1, 2, 3, 4])

    self.assertEqual(nonzero_avg_fn(points, distinct_labels), 0)
    self.assertEqual(full_avg_fn(points, distinct_labels), 0)
def test_with_no_valid_triplets(self):
    """Both reducers must give zero loss when no triplet can be formed.

    The five embeddings carry five distinct labels, so there are no
    anchor/positive pairs. The check runs for float16, float32 and float64.
    """
    margin = 0.2
    default_fn = TripletMarginLoss(margin=margin)
    mean_fn = TripletMarginLoss(margin=margin, reducer=MeanReducer())
    angles = [0, 20, 40, 60, 80]

    for dtype in [torch.float16, torch.float32, torch.float64]:
        points = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        distinct_labels = torch.LongTensor([0, 1, 2, 3, 4])

        default_value = default_fn(points, distinct_labels)
        mean_value = mean_fn(points, distinct_labels)

        self.assertEqual(default_value, 0)
        self.assertEqual(mean_value, 0)
def test_backward(self):
    """loss.backward() must run for both reducers and all float dtypes.

    This is a smoke test: it checks only that the backward pass does not
    raise, not the gradient values.
    """
    margin = 0.2
    candidate_fns = [
        TripletMarginLoss(margin=margin),
        TripletMarginLoss(margin=margin, reducer=MeanReducer()),
    ]
    angles = [0, 20, 40, 60, 80]

    for dtype in [torch.float16, torch.float32, torch.float64]:
        for criterion in candidate_fns:
            points = torch.tensor(
                [c_f.angle_to_coord(angle) for angle in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            targets = torch.LongTensor([0, 0, 1, 1, 2])
            criterion(points, targets).backward()
def test_key_mismatch(self):
    """MultipleLosses must reject mismatched dict keys.

    Construction should raise AssertionError when the weights or miners
    dict contains a key that is not in the losses dict.
    """
    lossA = ContrastiveLoss()
    lossB = TripletMarginLoss(0.1)

    # "blah" does not name any loss, so the weights dict is invalid.
    with self.assertRaises(AssertionError):
        MultipleLosses(
            losses={"lossA": lossA, "lossB": lossB},
            weights={"blah": 1, "lossB": 0.23},
        )

    minerA = MultiSimilarityMiner()
    # Weights are valid here, but the miners dict uses an unknown key.
    with self.assertRaises(AssertionError):
        MultipleLosses(
            losses={"lossA": lossA, "lossB": lossB},
            weights={"lossA": 1, "lossB": 0.23},
            miners={"blah": minerA},
        )
def test_input_indices_tuple(self):
    """MultipleLosses should equal the weighted sum of its sub-losses.

    The check uses a miner-produced indices_tuple, passed to the combined
    loss and to each sub-loss individually. It covers both the dict-based
    and the list-based constructor forms.
    """
    lossA = ContrastiveLoss()
    lossB = TripletMarginLoss(0.1)
    miner = MultiSimilarityMiner()
    weight_b = 0.23

    by_name = MultipleLosses(
        losses={"lossA": lossA, "lossB": lossB},
        weights={"lossA": 1, "lossB": weight_b},
    )
    by_position = MultipleLosses(losses=[lossA, lossB], weights=[1, weight_b])

    for combined in (by_name, by_position):
        for dtype in TEST_DTYPES:
            angles = torch.arange(0, 180)
            points = torch.tensor(
                [c_f.angle_to_coord(angle) for angle in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            targets = torch.randint(low=0, high=10, size=(180,))
            mined = miner(points, targets)

            total = combined(points, targets, mined)
            total.backward()

            expected = (lossA(points, targets, mined)
                        + lossB(points, targets, mined) * weight_b)
            self.assertTrue(torch.isclose(total, expected))
def test_multiple_losses(self):
    """A dict-configured MultipleLosses should match a manual weighted sum.

    The weights are 1 for lossA and 0.23 for lossB, and the check runs for
    every dtype in TEST_DTYPES.
    """
    lossA = ContrastiveLoss()
    lossB = TripletMarginLoss(0.1)
    weight_b = 0.23
    combined = MultipleLosses(
        losses={"lossA": lossA, "lossB": lossB},
        weights={"lossA": 1, "lossB": weight_b},
    )

    for dtype in TEST_DTYPES:
        angles = torch.arange(0, 180)
        points = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        targets = torch.randint(low=0, high=10, size=(180,))

        total = combined(points, targets)
        total.backward()

        expected = lossA(points, targets) + lossB(points, targets) * weight_b
        self.assertTrue(torch.isclose(total, expected))
def test_metric_loss_only(self):
    """Check how each trainer class validates its constructor arguments.

    Each trainer class must accept a baseline configuration. It must raise
    AssertionError when exactly one of models / mining_funcs / loss_funcs /
    freeze_these / lr_schedulers is replaced by an invalid value. The
    loss_funcs case is only checked for DeepAdversarialMetricLearning.
    """
    loss_fn = NTXentLoss()
    dataset = datasets.FakeData()
    model = torch.nn.Identity()
    batch_size = 32
    trainer_classes = [
        MetricLossOnly,
        DeepAdversarialMetricLearning,
        TrainWithClassifier,
        TwoStreamMetricLoss,
    ]
    keys_to_corrupt = [
        "models",
        "mining_funcs",
        "loss_funcs",
        "freeze_these",
        "lr_schedulers",
    ]

    for trainer_class in trainer_classes:
        is_adversarial = trainer_class is DeepAdversarialMetricLearning

        models = {"trunk": model}
        loss_funcs = {"metric_loss": loss_fn}
        if is_adversarial:
            # The adversarial trainer needs a generator and two extra losses.
            models["generator"] = model
            loss_funcs["synth_loss"] = loss_fn
            loss_funcs["g_adv_loss"] = TripletMarginLoss()

        kwargs = {
            "models": models,
            "optimizers": {"trunk_optimizer": None},
            "batch_size": batch_size,
            "loss_funcs": loss_funcs,
            "mining_funcs": {},
            "dataset": dataset,
            "freeze_these": ["trunk"],
            "lr_schedulers": {"trunk_scheduler_by_iteration": None},
        }
        # The baseline configuration must construct without error.
        trainer_class(**kwargs)

        for key in keys_to_corrupt:
            bad_kwargs = copy.deepcopy(kwargs)
            if key == "models":
                bad_kwargs[key] = {}
            elif key == "mining_funcs":
                bad_kwargs[key] = {"dog": None}
            elif key == "loss_funcs":
                if not is_adversarial:
                    # Only the adversarial trainer is tested with empty loss_funcs.
                    continue
                bad_kwargs[key] = {}
            elif key == "freeze_these":
                bad_kwargs[key] = ["frog"]
            elif key == "lr_schedulers":
                bad_kwargs[key] = {"trunk_scheduler": None}
            with self.assertRaises(AssertionError):
                trainer_class(**bad_kwargs)
def test_backward(self):
    """loss.backward() must run for each TripletMarginLoss configuration.

    The configurations are the default reducer, MeanReducer and
    smooth_loss=True, each tried with every dtype in TEST_DTYPES. Only the
    absence of errors is checked.
    """
    margin = 0.2
    configurations = [
        TripletMarginLoss(margin=margin),
        TripletMarginLoss(margin=margin, reducer=MeanReducer()),
        TripletMarginLoss(smooth_loss=True),
    ]
    angles = [0, 20, 40, 60, 80]

    for dtype in TEST_DTYPES:
        for criterion in configurations:
            points = torch.tensor(
                [c_f.angle_to_coord(angle) for angle in angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            targets = torch.LongTensor([0, 0, 1, 1, 2])
            criterion(points, targets).backward()
def normal_loss(c, scale, independent_classes=False):
    """Print and return the triplet loss of two small Gaussian point clouds.

    Ten values are drawn for each class and reshaped into five 2D points.
    Class 0 is drawn from ``scipy.stats.norm(loc=c, scale=scale)``. The loss
    uses one triplet per anchor and un-normalized L1 distances.

    Args:
        c: Location (mean) of the class-0 distribution.
        scale: Scale (standard deviation) of the distribution(s).
        independent_classes: If False (default, the historical behaviour),
            class 1 is drawn from the *same* distribution as class 0, so
            ``c`` shifts both clouds together. If True, class 1 is drawn from
            ``norm(loc=0, scale=scale)`` instead.

    Returns:
        The loss tensor produced by the criterion (also printed).
    """
    criterion = TripletMarginLoss(
        triplets_per_anchor=1,
        distance=LpDistance(normalize_embeddings=False, p=1),
    )
    left_rv = stats.norm(loc=c, scale=scale)
    # Previously an N(0, scale) distribution was built here and then
    # immediately overwritten with left_rv; the flag makes that choice explicit.
    right_rv = stats.norm(scale=scale) if independent_classes else left_rv

    # Draw class 0 before class 1 so the global RNG stream is consumed in the
    # same order as before.
    left = left_rv.rvs(10).reshape((5, 2))
    right = right_rv.rvs(10).reshape((5, 2))

    embeddings = torch.as_tensor(np.concatenate((left, right)))
    labels = torch.as_tensor(np.concatenate((np.zeros(5), np.ones(5))))
    loss = criterion(embeddings, labels)
    print(f'center = {c}, loss = {loss}')
    return loss
def test_length_mistmatch(self):
    """List-configured MultipleLosses must reject length mismatches.

    Construction should raise AssertionError when the weights or miners
    list does not have the same length as the losses list.
    """
    lossA = ContrastiveLoss()
    lossB = TripletMarginLoss(0.1)

    # Two losses but only one weight.
    with self.assertRaises(AssertionError):
        MultipleLosses(losses=[lossA, lossB], weights=[1])

    minerA = MultiSimilarityMiner()
    # Weights line up, but there is only one miner for two losses.
    with self.assertRaises(AssertionError):
        MultipleLosses(
            losses=[lossA, lossB],
            weights=[1, 0.2],
            miners=[minerA],
        )
def test_check_shapes(self):
    """TripletMarginLoss must validate input shapes.

    It should raise ValueError for a 3-D embeddings tensor, for a batch-size
    mismatch, and for 2-D labels. With well-formed inputs it should return a
    tensor.
    """
    points = torch.randn(32, 512, 3)
    targets = torch.randn(32)
    criterion = TripletMarginLoss()

    # Embeddings with three dimensions are rejected.
    with self.assertRaises(ValueError):
        criterion(points, targets)

    # 33 embeddings against 32 labels.
    points = torch.randn(33, 512)
    with self.assertRaises(ValueError):
        criterion(points, targets)

    # Labels given as a (32, 1) column instead of a vector.
    points = torch.randn(32, 512)
    targets = targets.unsqueeze(1)
    with self.assertRaises(ValueError):
        criterion(points, targets)

    # Matching (32, 512) embeddings and (32,) labels are accepted.
    targets = targets.squeeze(1)
    self.assertTrue(torch.is_tensor(criterion(points, targets)))
def test_triplet_margin_loss(self):
    """Compare TripletMarginLoss with a reference per-triplet computation.

    Four configurations are checked for every dtype in TEST_DTYPES:
    Euclidean or cosine-similarity distance, each with the default reducer
    or MeanReducer. The test asserts:

    * the default reducer equals the mean over triplets with non-zero loss;
    * MeanReducer equals the mean over all triplets.

    float16 is compared with a looser relative tolerance.
    """
    margin = 0.2
    euclid_default = TripletMarginLoss(margin=margin)
    euclid_mean = TripletMarginLoss(margin=margin, reducer=MeanReducer())
    cosine_default = TripletMarginLoss(margin=margin,
                                       distance=CosineSimilarity())
    cosine_mean = TripletMarginLoss(margin=margin,
                                    reducer=MeanReducer(),
                                    distance=CosineSimilarity())
    angles = [0, 20, 40, 60, 80]

    for dtype in TEST_DTYPES:
        points = torch.tensor(
            [c_f.angle_to_coord(angle) for angle in angles],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        targets = torch.LongTensor([0, 0, 1, 1, 2])

        loss_euclid_default = euclid_default(points, targets)
        loss_euclid_mean = euclid_mean(points, targets)
        loss_cosine_default = cosine_default(points, targets)
        loss_cosine_mean = cosine_mean(points, targets)

        # Every (anchor, positive, negative) triple implied by the labels.
        triplets = [
            (0, 1, 2), (0, 1, 3), (0, 1, 4),
            (1, 0, 2), (1, 0, 3), (1, 0, 4),
            (2, 3, 0), (2, 3, 1), (2, 3, 4),
            (3, 2, 0), (3, 2, 1), (3, 2, 4),
        ]

        euclid_sum = 0
        cosine_sum = 0
        euclid_active = 0
        cosine_active = 0
        for a_idx, p_idx, n_idx in triplets:
            anchor = points[a_idx]
            positive = points[p_idx]
            negative = points[n_idx]
            # Distance-based margin: d(a, p) - d(a, n) + margin.
            euclid_term = torch.relu(
                torch.sqrt(torch.sum((anchor - positive) ** 2))
                - torch.sqrt(torch.sum((anchor - negative) ** 2))
                + margin)
            # Similarity-based margin: s(a, n) - s(a, p) + margin.
            cosine_term = torch.relu(
                torch.sum(anchor * negative)
                - torch.sum(anchor * positive)
                + margin)
            if euclid_term > 0:
                euclid_active += 1
            if cosine_term > 0:
                cosine_active += 1
            euclid_sum += euclid_term
            cosine_sum += cosine_term

        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(
            loss_euclid_default, euclid_sum / euclid_active, rtol=rtol))
        self.assertTrue(torch.isclose(
            loss_euclid_mean, euclid_sum / len(triplets), rtol=rtol))
        self.assertTrue(torch.isclose(
            loss_cosine_default, cosine_sum / cosine_active, rtol=rtol))
        self.assertTrue(torch.isclose(
            loss_cosine_mean, cosine_sum / len(triplets), rtol=rtol))
def test_collect_stats_flag(self):
    """Every collect_stats flag should mirror WITH_COLLECT_STATS.

    This covers the global c_f.COLLECT_STATS and the collect_stats flags of
    a default TripletMarginLoss, its distance and its reducer.
    """
    self.assertTrue(c_f.COLLECT_STATS == WITH_COLLECT_STATS)
    criterion = TripletMarginLoss()
    self.assertTrue(criterion.collect_stats == WITH_COLLECT_STATS)
    # Sub-components are checked in the same order as before: distance, reducer.
    for component_name in ("distance", "reducer"):
        component = getattr(criterion, component_name)
        self.assertTrue(component.collect_stats == WITH_COLLECT_STATS)
def __init__(self, ignore_index: int = 255, alpha: float = 1.0):
    """Create the classification and instance (triplet) loss terms.

    Args:
        ignore_index: Target value to be ignored by the classification
            term; forwarded to ``CrossEntropyLoss``.
        alpha: Stored as ``self.alpha``. Presumably the weight of the
            instance term relative to the class term — TODO confirm against
            the forward/call method (not visible here).
    """
    # Initialise the base class before assigning attributes. If the enclosing
    # class is a torch.nn.Module (likely, since sub-modules are assigned
    # below), assigning a sub-module before Module.__init__() raises
    # AttributeError. For a plain ``object`` base this call is a no-op.
    super().__init__()
    self.alpha = alpha
    self.class_loss = CrossEntropyLoss(ignore_index=ignore_index)
    self.instance_loss = TripletMarginLoss()
def triplet_margin_loss_factory(triplets_per_anchor, normalize_embeddings, p):
    """Build a TripletMarginLoss that measures distances with LpDistance.

    Args:
        triplets_per_anchor: Forwarded unchanged to TripletMarginLoss.
        normalize_embeddings: Forwarded unchanged to LpDistance.
        p: Forwarded unchanged to LpDistance (the Lp exponent).

    Returns:
        The configured TripletMarginLoss instance.
    """
    lp_distance = LpDistance(normalize_embeddings=normalize_embeddings, p=p)
    return TripletMarginLoss(triplets_per_anchor=triplets_per_anchor,
                             distance=lp_distance)