Example no. 1
0
    def test_triplet_margin_loss(self):
        """Verify TripletMarginLoss against a manually computed reference.

        The default reducer averages over non-zero triplets only, while
        MeanReducer averages over all triplets.
        """
        margin = 0.2
        loss_fn_nonzero = TripletMarginLoss(margin=margin)
        loss_fn_mean = TripletMarginLoss(margin=margin, reducer=MeanReducer())
        angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in angles],
            requires_grad=True,
            dtype=torch.float,
        )  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        lossA = loss_fn_nonzero(embeddings, labels)
        lossB = loss_fn_mean(embeddings, labels)
        lossA.backward()
        lossB.backward()

        # All valid (anchor, positive, negative) index triples for the labels.
        triplets = [(0, 1, 2), (0, 1, 3), (0, 1, 4), (1, 0, 2), (1, 0, 3),
                    (1, 0, 4), (2, 3, 0), (2, 3, 1), (2, 3, 4), (3, 2, 0),
                    (3, 2, 1), (3, 2, 4)]

        def euclidean(u, v):
            return torch.sqrt(torch.sum((u - v)**2))

        per_triplet = []
        for a_idx, p_idx, n_idx in triplets:
            anc = embeddings[a_idx]
            pos = embeddings[p_idx]
            neg = embeddings[n_idx]
            per_triplet.append(
                torch.relu(euclidean(anc, pos) - euclidean(anc, neg) + margin))

        correct_loss = sum(per_triplet)
        num_non_zero_triplets = sum(1 for t in per_triplet if t > 0)
        self.assertTrue(
            torch.isclose(lossA, correct_loss / num_non_zero_triplets))
        self.assertTrue(torch.isclose(lossB, correct_loss / len(triplets)))
 def test_with_no_valid_triplets(self):
     """All labels are distinct, so no triplets form and the loss is zero."""
     angles = [0, 20, 40, 60, 80]
     embeddings = torch.FloatTensor(
         [c_f.angle_to_coord(a) for a in angles])  # 2D embeddings
     labels = torch.LongTensor([0, 1, 2, 3, 4])
     # Both averaging modes must yield exactly 0 when no triplet exists.
     for avg_flag in (True, False):
         loss_fn = TripletMarginLoss(margin=0.2, avg_non_zero_only=avg_flag)
         self.assertEqual(loss_fn(embeddings, labels), 0)
Example no. 3
0
 def test_with_no_valid_triplets(self):
     """No two samples share a label, so both reducers must return zero."""
     loss_fns = (
         TripletMarginLoss(margin=0.2),
         TripletMarginLoss(margin=0.2, reducer=MeanReducer()),
     )
     labels = torch.LongTensor([0, 1, 2, 3, 4])
     for dtype in [torch.float16, torch.float32, torch.float64]:
         angles = [0, 20, 40, 60, 80]
         embeddings = torch.tensor(
             [c_f.angle_to_coord(a) for a in angles],
             requires_grad=True,
             dtype=dtype,
         ).to(self.device)  # 2D embeddings
         for loss_fn in loss_fns:
             self.assertEqual(loss_fn(embeddings, labels), 0)
Example no. 4
0
    def test_backward(self):
        """Smoke test: the loss must backpropagate for every dtype/reducer."""
        margin = 0.2
        loss_fns = [
            TripletMarginLoss(margin=margin),
            TripletMarginLoss(margin=margin, reducer=MeanReducer()),
        ]
        for dtype in [torch.float16, torch.float32, torch.float64]:
            for loss_fn in loss_fns:
                angles = [0, 20, 40, 60, 80]
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(self.device)  # 2D embeddings
                labels = torch.LongTensor([0, 0, 1, 1, 2])
                # backward() raising would fail the test.
                loss_fn(embeddings, labels).backward()
    def test_key_mismatch(self):
        """Mismatched keys between losses, weights, and miners must raise."""
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)

        # Weight key "blah" has no matching loss key.
        with self.assertRaises(AssertionError):
            MultipleLosses(
                losses={"lossA": lossA, "lossB": lossB},
                weights={"blah": 1, "lossB": 0.23},
            )

        # Miner key "blah" has no matching loss key.
        minerA = MultiSimilarityMiner()
        with self.assertRaises(AssertionError):
            MultipleLosses(
                losses={"lossA": lossA, "lossB": lossB},
                weights={"lossA": 1, "lossB": 0.23},
                miners={"blah": minerA},
            )
    def test_input_indices_tuple(self):
        """MultipleLosses must forward an externally mined indices_tuple.

        Both the dict-based and the list-based constructions are checked
        against a manually computed weighted sum.
        """
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)
        miner = MultiSimilarityMiner()
        dict_based = MultipleLosses(
            losses={"lossA": lossA, "lossB": lossB},
            weights={"lossA": 1, "lossB": 0.23},
        )
        list_based = MultipleLosses(losses=[lossA, lossB], weights=[1, 0.23])

        for loss_func in (dict_based, list_based):
            for dtype in TEST_DTYPES:
                angles = torch.arange(0, 180)
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)  # 2D embeddings
                labels = torch.randint(low=0, high=10, size=(180, ))
                indices_tuple = miner(embeddings, labels)

                loss = loss_func(embeddings, labels, indices_tuple)
                loss.backward()

                # Hand-computed weighted sum must match the combined loss.
                expected = (lossA(embeddings, labels, indices_tuple) +
                            lossB(embeddings, labels, indices_tuple) * 0.23)
                self.assertTrue(torch.isclose(loss, expected))
Example no. 7
0
    def test_multiple_losses(self):
        """The combined loss equals the manually weighted sum of its parts."""
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)
        loss_func = MultipleLosses(
            losses={"lossA": lossA, "lossB": lossB},
            weights={"lossA": 1, "lossB": 0.23},
        )

        for dtype in TEST_DTYPES:
            angles = torch.arange(0, 180)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(180, ))

            combined = loss_func(embeddings, labels)
            combined.backward()

            # Weighted sum computed directly from the component losses.
            expected = (lossA(embeddings, labels) +
                        lossB(embeddings, labels) * 0.23)
            self.assertTrue(torch.isclose(combined, expected))
Example no. 8
0
    def test_metric_loss_only(self):
        """Each trainer class must accept a valid config and reject broken ones.

        For every trainer class, a known-good kwargs dict is built and the
        trainer constructed; then one kwarg at a time is corrupted and the
        constructor is expected to raise AssertionError.
        """
        loss_fn = NTXentLoss()
        dataset = datasets.FakeData()
        model = torch.nn.Identity()
        batch_size = 32

        for trainer_class in [
            MetricLossOnly,
            DeepAdversarialMetricLearning,
            TrainWithClassifier,
            TwoStreamMetricLoss,
        ]:

            model_dict = {"trunk": model}
            optimizer_dict = {"trunk_optimizer": None}
            loss_fn_dict = {"metric_loss": loss_fn}
            lr_scheduler_dict = {"trunk_scheduler_by_iteration": None}

            # The adversarial trainer additionally needs a generator model
            # and two extra loss functions.
            if trainer_class is DeepAdversarialMetricLearning:
                model_dict["generator"] = model
                loss_fn_dict["synth_loss"] = loss_fn
                loss_fn_dict["g_adv_loss"] = TripletMarginLoss()

            kwargs = {
                "models": model_dict,
                "optimizers": optimizer_dict,
                "batch_size": batch_size,
                "loss_funcs": loss_fn_dict,
                "mining_funcs": {},
                "dataset": dataset,
                "freeze_these": ["trunk"],
                "lr_schedulers": lr_scheduler_dict,
            }

            # Valid configuration: construction must succeed.
            trainer = trainer_class(**kwargs)

            # Corrupt one kwarg at a time; the trainer's validation is
            # expected to raise AssertionError for each corruption.
            for k in [
                "models",
                "mining_funcs",
                "loss_funcs",
                "freeze_these",
                "lr_schedulers",
            ]:
                new_kwargs = copy.deepcopy(kwargs)
                if k == "models":
                    new_kwargs[k] = {}  # no "trunk" model at all
                if k == "mining_funcs":
                    new_kwargs[k] = {"dog": None}  # presumably an invalid miner key
                if k == "loss_funcs":
                    if trainer_class is DeepAdversarialMetricLearning:
                        new_kwargs[k] = {}  # adversarial trainer requires its losses
                    else:
                        # Other trainers are not exercised with empty loss_funcs.
                        continue
                if k == "freeze_these":
                    new_kwargs[k] = ["frog"]  # name not present among the models
                if k == "lr_schedulers":
                    new_kwargs[k] = {"trunk_scheduler": None}  # key lacks the expected suffix
                with self.assertRaises(AssertionError):
                    trainer = trainer_class(**new_kwargs)
    def test_backward(self):
        """Smoke test: backward() must run for every dtype and loss variant."""
        margin = 0.2
        variants = [
            TripletMarginLoss(margin=margin),
            TripletMarginLoss(margin=margin, reducer=MeanReducer()),
            TripletMarginLoss(smooth_loss=True),
        ]
        for dtype in TEST_DTYPES:
            for loss_func in variants:
                angles = [0, 20, 40, 60, 80]
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)  # 2D embeddings
                labels = torch.LongTensor([0, 0, 1, 1, 2])
                # backward() raising would fail the test.
                loss_func(embeddings, labels).backward()
Example no. 10
0
def normal_loss(c, scale):
    """Sample two 2-D Gaussian clusters and print their triplet margin loss.

    Two clusters of 5 points each are drawn, labelled 0 and 1, and scored
    with an un-normalized L1 ``TripletMarginLoss`` (one triplet per anchor).

    Args:
        c: Center (``loc``) of the sampled normal distribution.
        scale: Standard deviation of the sampled normal distribution.

    Returns:
        The computed loss tensor (also printed to stdout).
    """
    criterion = TripletMarginLoss(
        triplets_per_anchor=1,
        distance=LpDistance(normalize_embeddings=False, p=1))
    l_rv = stats.norm(loc=c, scale=scale)
    # NOTE(review): the original code built stats.norm(scale=scale) for the
    # right cluster and immediately overwrote it with l_rv, so both clusters
    # are sampled from the SAME distribution. The dead assignment has been
    # removed; confirm whether the right cluster was meant to stay
    # zero-centred instead.
    r_rv = l_rv
    left = l_rv.rvs(10).reshape((5, 2))
    right = r_rv.rvs(10).reshape((5, 2))
    df = pd.DataFrame()
    df['x'] = np.concatenate((left[:, 0], right[:, 0]))
    df['y'] = np.concatenate((left[:, 1], right[:, 1]))
    df['label'] = np.concatenate((np.zeros(5), np.ones(5)))
    embeddings = torch.as_tensor(np.concatenate((left, right)))
    labels = torch.as_tensor(df['label'])
    loss = criterion(embeddings, labels)
    print(f'center = {c}, loss = {loss}')
    return loss
    def test_length_mistmatch(self):
        """List-form losses/weights/miners of unequal length must raise.

        (The "mistmatch" typo in the method name is kept so external
        references to this test keep working.)
        """
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)

        # One weight for two losses.
        with self.assertRaises(AssertionError):
            MultipleLosses(losses=[lossA, lossB], weights=[1])

        # One miner for two losses.
        minerA = MultiSimilarityMiner()
        with self.assertRaises(AssertionError):
            MultipleLosses(
                losses=[lossA, lossB],
                weights=[1, 0.2],
                miners=[minerA],
            )
Example no. 12
0
    def test_check_shapes(self):
        """Malformed embedding/label shapes raise ValueError; valid ones don't."""
        embeddings = torch.randn(32, 512, 3)
        labels = torch.randn(32)
        loss_fn = TripletMarginLoss()

        # 3-dimensional embeddings are rejected.
        with self.assertRaises(ValueError):
            loss_fn(embeddings, labels)

        # Embedding count (33) must match label count (32).
        embeddings = torch.randn(33, 512)
        with self.assertRaises(ValueError):
            loss_fn(embeddings, labels)

        # 2D labels are rejected.
        embeddings = torch.randn(32, 512)
        labels = labels.unsqueeze(1)
        with self.assertRaises(ValueError):
            loss_fn(embeddings, labels)

        # With correct shapes the call succeeds and returns a tensor.
        labels = labels.squeeze(1)
        self.assertTrue(torch.is_tensor(loss_fn(embeddings, labels)))
    def test_triplet_margin_loss(self):
        """Compare four TripletMarginLoss configurations to hand-computed values.

        A: Lp distance, default reducer (mean over non-zero triplets)
        B: Lp distance, MeanReducer (mean over all triplets)
        C: CosineSimilarity, default reducer
        D: CosineSimilarity, MeanReducer
        """
        margin = 0.2
        loss_funcA = TripletMarginLoss(margin=margin)
        loss_funcB = TripletMarginLoss(margin=margin, reducer=MeanReducer())
        loss_funcC = TripletMarginLoss(margin=margin,
                                       distance=CosineSimilarity())
        loss_funcD = TripletMarginLoss(margin=margin,
                                       reducer=MeanReducer(),
                                       distance=CosineSimilarity())
        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            lossA = loss_funcA(embeddings, labels)
            lossB = loss_funcB(embeddings, labels)
            lossC = loss_funcC(embeddings, labels)
            lossD = loss_funcD(embeddings, labels)

            # All valid (anchor, positive, negative) index triples.
            triplets = [
                (0, 1, 2),
                (0, 1, 3),
                (0, 1, 4),
                (1, 0, 2),
                (1, 0, 3),
                (1, 0, 4),
                (2, 3, 0),
                (2, 3, 1),
                (2, 3, 4),
                (3, 2, 0),
                (3, 2, 1),
                (3, 2, 4),
            ]

            correct_loss = 0
            correct_loss_cosine = 0
            num_non_zero_triplets = 0
            num_non_zero_triplets_cosine = 0
            for a, p, n in triplets:
                anchor, positive, negative = embeddings[a], embeddings[
                    p], embeddings[n]
                # Euclidean hinge: relu(d(a,p) - d(a,n) + margin).
                curr_loss = torch.relu(
                    torch.sqrt(torch.sum((anchor - positive)**2)) -
                    torch.sqrt(torch.sum((anchor - negative)**2)) + margin)
                # Cosine hinge: relu(sim(a,n) - sim(a,p) + margin); note the
                # sign flip because similarity increases with closeness.
                curr_loss_cosine = torch.relu(
                    torch.sum(anchor * negative) -
                    torch.sum(anchor * positive) + margin)
                if curr_loss > 0:
                    num_non_zero_triplets += 1
                if curr_loss_cosine > 0:
                    num_non_zero_triplets_cosine += 1
                correct_loss += curr_loss
                correct_loss_cosine += curr_loss_cosine
            # Looser tolerance for half precision.
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(
                torch.isclose(lossA,
                              correct_loss / num_non_zero_triplets,
                              rtol=rtol))
            self.assertTrue(
                torch.isclose(lossB, correct_loss / len(triplets), rtol=rtol))
            self.assertTrue(
                torch.isclose(lossC,
                              correct_loss_cosine /
                              num_non_zero_triplets_cosine,
                              rtol=rtol))
            self.assertTrue(
                torch.isclose(lossD,
                              correct_loss_cosine / len(triplets),
                              rtol=rtol))
Example no. 14
0
 def test_collect_stats_flag(self):
     """The global collect-stats flag must propagate to loss, distance, reducer."""
     self.assertTrue(c_f.COLLECT_STATS == WITH_COLLECT_STATS)
     loss_fn = TripletMarginLoss()
     for flag in (loss_fn.collect_stats,
                  loss_fn.distance.collect_stats,
                  loss_fn.reducer.collect_stats):
         self.assertTrue(flag == WITH_COLLECT_STATS)
Example no. 15
0
 def __init__(self, ignore_index: int = 255, alpha: float = 1.0):
     """Initialize the combined loss.

     Args:
         ignore_index: Label value ignored by the cross-entropy class loss.
         alpha: Weighting factor; its use is not visible in this block —
             presumably it scales one of the two losses (TODO confirm).
     """
     self.alpha = alpha
     # Per-class loss; pixels/targets equal to ignore_index are skipped.
     self.class_loss = CrossEntropyLoss(ignore_index=ignore_index)
     # Metric-learning loss over instance embeddings (default settings).
     self.instance_loss = TripletMarginLoss()
Example no. 16
0
def triplet_margin_loss_factory(triplets_per_anchor, normalize_embeddings, p):
    """Build a TripletMarginLoss over a configurable Lp distance.

    Args:
        triplets_per_anchor: Number of triplets sampled per anchor.
        normalize_embeddings: Whether LpDistance normalizes embeddings first.
        p: Order of the Lp norm.

    Returns:
        A configured TripletMarginLoss instance.
    """
    distance = LpDistance(normalize_embeddings=normalize_embeddings, p=p)
    return TripletMarginLoss(
        triplets_per_anchor=triplets_per_anchor, distance=distance)