def test_loss(self):
        """Verify CrossBatchMemory against a manual re-computation.

        Each iteration pushes a fresh batch into the cross-batch memory and
        compares the returned loss with the inner loss evaluated on the
        concatenation (current batch ++ all previously seen embeddings),
        for three setups: no miner, an inner miner, and an inner miner
        combined with externally supplied (outer-mined) pair indices.

        NOTE(review): relies on self.embedding_size / self.memory_size being
        set by the test fixture, and assumes memory_size is large enough that
        nothing is evicted within num_iter iterations — otherwise
        all_embeddings would diverge from the memory queue; confirm.
        """
        num_labels = 10
        num_iter = 10
        batch_size = 32
        inner_loss = ContrastiveLoss()
        inner_miner = MultiSimilarityMiner(0.3)
        outer_miner = MultiSimilarityMiner(0.2)
        # One CrossBatchMemory wrapper per configuration under test.
        self.loss = CrossBatchMemory(loss=inner_loss, embedding_size=self.embedding_size, memory_size=self.memory_size)
        self.loss_with_miner = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
        self.loss_with_miner2 = CrossBatchMemory(loss=inner_loss, miner=inner_miner, embedding_size=self.embedding_size, memory_size=self.memory_size)
        # Running record of every embedding/label seen so far, used to
        # reconstruct what the memory queue should contain.
        all_embeddings = torch.FloatTensor([])
        all_labels = torch.LongTensor([])
        for i in range(num_iter):
            embeddings = torch.randn(batch_size, self.embedding_size)
            labels = torch.randint(0,num_labels,(batch_size,))
            loss = self.loss(embeddings, labels)
            loss_with_miner = self.loss_with_miner(embeddings, labels)
            # The outer miner operates on the current batch only; its indices
            # are handed to CrossBatchMemory as the input indices tuple.
            oa1, op, oa2, on = outer_miner(embeddings, labels)
            loss_with_miner_and_input_indices = self.loss_with_miner2(embeddings, labels, (oa1, op, oa2, on))
            all_embeddings = torch.cat([all_embeddings, embeddings])
            all_labels = torch.cat([all_labels, labels])

            # loss with no inner miner
            indices_tuple = lmu.get_all_pairs_indices(labels, all_labels)
            a1,p,a2,n = self.loss.remove_self_comparisons(indices_tuple)
            # p/n index into all_embeddings, which is placed AFTER the current
            # batch in the concatenated tensor below — shift by batch_size.
            p = p+batch_size
            n = n+batch_size
            correct_loss = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss, correct_loss))

            # loss with inner miner
            indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
            a1,p,a2,n = self.loss_with_miner.remove_self_comparisons(indices_tuple)
            p = p+batch_size
            n = n+batch_size
            correct_loss_with_miner = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss_with_miner, correct_loss_with_miner))

            # loss with inner and outer miner
            indices_tuple = inner_miner(embeddings, labels, all_embeddings, all_labels)
            a1,p,a2,n = self.loss_with_miner2.remove_self_comparisons(indices_tuple)
            p = p+batch_size
            n = n+batch_size
            # Prepend the outer-mined (current-batch) pairs to the
            # memory-mined pairs before recomputing the reference loss.
            a1 = torch.cat([oa1, a1])
            p = torch.cat([op, p])
            a2 = torch.cat([oa2, a2])
            n = torch.cat([on, n])
            correct_loss_with_miner_and_input_indice = inner_loss(torch.cat([embeddings, all_embeddings], dim=0), torch.cat([labels, all_labels], dim=0), (a1,p,a2,n))
            self.assertTrue(torch.isclose(loss_with_miner_and_input_indices, correct_loss_with_miner_and_input_indice))
# Beispiel #2
    def test_tuplestoweights_sampler(self):
        """Smoke-test TuplesToWeightsSampler on CIFAR100.

        Draws subset_size indices from the sampler and checks that
        (a) exactly subset_size indices are produced and
        (b) every sampled index carries a non-zero mining weight.

        NOTE(review): requires a CUDA device, downloads CIFAR100 into a
        temporary folder (deleted at the end), and fetches pretrained
        resnet18 weights — this is an integration test, not a unit test.
        """
        model = models.resnet18(pretrained=True)
        model.fc = c_f.Identity()  # drop the classifier head; use trunk output as embedding
        model = torch.nn.DataParallel(model)
        model.to(torch.device("cuda"))

        # NOTE(review): negative epsilon presumably tightens the mining
        # threshold so fewer/harder tuples are returned — confirm.
        miner = MultiSimilarityMiner(epsilon=-0.2)

        # Standard ImageNet normalization applied to resized CIFAR images.
        eval_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        temporary_folder = "cifar100_temp_for_pytorch_metric_learning_test"

        dataset = datasets.CIFAR100(temporary_folder,
                                    train=True,
                                    download=True,
                                    transform=eval_transform)
        subset_size = 1000
        sampler = TuplesToWeightsSampler(model,
                                         miner,
                                         dataset,
                                         subset_size=subset_size)
        iterable_as_list = list(iter(sampler))
        self.assertTrue(len(iterable_as_list) == subset_size)
        # Every index the sampler emitted must have been assigned weight > 0.
        unique_idx = torch.unique(torch.tensor(iterable_as_list))
        self.assertTrue(torch.all(sampler.weights[unique_idx] != 0))

        shutil.rmtree(temporary_folder)
    def test_key_mismatch(self):
        """MultipleLosses must reject weight/miner dicts whose keys do not
        match the keys of the losses dict."""
        contrastive = ContrastiveLoss()
        triplet = TripletMarginLoss(0.1)

        # Weight dict with a key ("blah") absent from the losses dict.
        with self.assertRaises(AssertionError):
            MultipleLosses(
                losses={"lossA": contrastive, "lossB": triplet},
                weights={"blah": 1, "lossB": 0.23},
            )

        # Miner dict with a key ("blah") absent from the losses dict.
        stray_miner = MultiSimilarityMiner()
        with self.assertRaises(AssertionError):
            MultipleLosses(
                losses={"lossA": contrastive, "lossB": triplet},
                weights={"lossA": 1, "lossB": 0.23},
                miners={"blah": stray_miner},
            )
    def test_input_indices_tuple(self):
        """A MultipleLosses called with an externally mined indices_tuple
        must equal the manually weighted sum of its component losses,
        for both the dict-based and the list-based constructor forms."""
        lossA = ContrastiveLoss()
        lossB = TripletMarginLoss(0.1)
        miner = MultiSimilarityMiner()

        dict_based = MultipleLosses(
            losses={"lossA": lossA, "lossB": lossB},
            weights={"lossA": 1, "lossB": 0.23},
        )
        list_based = MultipleLosses(losses=[lossA, lossB], weights=[1, 0.23])

        for loss_func in [dict_based, list_based]:
            for dtype in TEST_DTYPES:
                angles = torch.arange(0, 180)
                # 2D embeddings on the unit circle
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)
                labels = torch.randint(low=0, high=10, size=(180,))
                indices_tuple = miner(embeddings, labels)

                loss = loss_func(embeddings, labels, indices_tuple)
                loss.backward()

                expected = (
                    lossA(embeddings, labels, indices_tuple)
                    + 0.23 * lossB(embeddings, labels, indices_tuple)
                )
                self.assertTrue(torch.isclose(loss, expected))
 def test_empty_output(self):
     """When every sample has a unique label there are no positive pairs,
     so the miner must return empty index tensors."""
     miner = MultiSimilarityMiner(0.1)
     batch_size = 32
     for dtype in [torch.float16, torch.float32, torch.float64]:
         data = torch.randn(batch_size, 64).type(dtype).to(self.device)
         # arange gives each row its own label -> zero positive pairs
         distinct_labels = torch.arange(batch_size)
         anchors, positives, _, _ = miner(data, distinct_labels)
         self.assertTrue(len(anchors) == 0)
         self.assertTrue(len(positives) == 0)
# Beispiel #6
 def test_empty_output(self):
     """All-distinct labels produce no positive pairs, so mining yields
     empty anchor/positive tensors for every tested dtype."""
     miner = MultiSimilarityMiner(0.1)
     n_samples = 32
     for dtype in TEST_DTYPES:
         vectors = torch.randn(n_samples, 64).type(dtype).to(TEST_DEVICE)
         unique_labels = torch.arange(n_samples)  # one label per row
         anc, pos, _, _ = miner(vectors, unique_labels)
         self.assertTrue(len(anc) == 0)
         self.assertTrue(len(pos) == 0)
    def test_multi_similarity_miner(self):
        """Compare MultiSimilarityMiner with a brute-force O(n^2) reference.

        Reference rule (dot-product similarity, larger = closer):
        - keep positive pair (a, p) if sim(a, p) < max_neg_sim(a) + epsilon
        - keep negative pair (a, n) if sim(a, n) > min_pos_sim(a) - epsilon
        """
        epsilon = 0.1
        miner = MultiSimilarityMiner(epsilon)
        for dtype in [torch.float16, torch.float32, torch.float64]:
            embedding_angles = torch.arange(0, 64)
            embeddings = torch.tensor([c_f.angle_to_coord(a) for a in embedding_angles], requires_grad=True, dtype=dtype).to(self.device) #2D embeddings
            labels = torch.randint(low=0, high=10, size=(64,))
            # Enumerate every ordered pair (i != j) with its dot-product value.
            pos_pairs = []
            neg_pairs = []
            for i in range(len(embeddings)):
                anchor, anchor_label = embeddings[i], labels[i]
                for j in range(len(embeddings)):
                    if j != i:
                        other, other_label = embeddings[j], labels[j]
                        if anchor_label == other_label:
                            pos_pairs.append((i,j,torch.matmul(anchor, other.t()).item()))
                        if anchor_label != other_label:
                            neg_pairs.append((i,j,torch.matmul(anchor, other.t()).item()))

            # A positive pair survives only if it is not "easy": its similarity
            # must fall below the anchor's hardest (max) negative sim + epsilon.
            correct_a1, correct_p = [], []
            correct_a2, correct_n = [], []
            for a1,p,ap_sim in pos_pairs:
                max_neg_sim = c_f.neg_inf(dtype)
                for a2,n,an_sim in neg_pairs:
                    if a2==a1:
                        if an_sim > max_neg_sim:
                            max_neg_sim = an_sim
                if ap_sim < max_neg_sim + epsilon:
                    correct_a1.append(a1)
                    correct_p.append(p)

            # A negative pair survives only if its similarity exceeds the
            # anchor's easiest (min) positive sim - epsilon.
            for a2,n,an_sim in neg_pairs:
                min_pos_sim = c_f.pos_inf(dtype)
                for a1,p,ap_sim in pos_pairs:
                    if a2==a1:
                        if ap_sim < min_pos_sim:
                            min_pos_sim = ap_sim
                if an_sim > min_pos_sim - epsilon:
                    correct_a2.append(a2)
                    correct_n.append(n)

            # Compare as unordered sets of (anchor, other) index pairs.
            correct_pos_pairs = set([(a,p) for a,p in zip(correct_a1, correct_p)])
            correct_neg_pairs = set([(a,n) for a,n in zip(correct_a2, correct_n)])

            a1, p1, a2, n2 = miner(embeddings, labels)
            pos_pairs = set([(a.item(),p.item()) for a,p in zip(a1,p1)])
            neg_pairs = set([(a.item(),n.item()) for a,n in zip(a2,n2)])

            self.assertTrue(pos_pairs == correct_pos_pairs)
            self.assertTrue(neg_pairs == correct_neg_pairs)
    def test_length_mistmatch(self):
        """List-style losses with too few weights (or too few miners) must
        raise an AssertionError.

        NOTE: the "mistmatch" typo in the method name is kept so the
        test-discovery name stays stable.
        """
        contrastive = ContrastiveLoss()
        triplet = TripletMarginLoss(0.1)

        # One weight supplied for two losses.
        with self.assertRaises(AssertionError):
            MultipleLosses(losses=[contrastive, triplet], weights=[1])

        # One miner supplied for two losses.
        lone_miner = MultiSimilarityMiner()
        with self.assertRaises(AssertionError):
            MultipleLosses(
                losses=[contrastive, triplet],
                weights=[1, 0.2],
                miners=[lone_miner],
            )
# Beispiel #9
    def test_multi_similarity_miner(self):
        """Distance-generalized brute-force check of MultiSimilarityMiner.

        Same idea as the dot-product variant, but parameterized over the
        distance object. For similarity-like measures (distance.is_inverted,
        e.g. CosineSimilarity) larger values mean "closer"; for true metrics
        (e.g. LpDistance) smaller values do — so the direction of the
        "hardest comparison" and the sign of the epsilon margin flip
        accordingly in every condition below.
        """
        epsilon = 0.1
        for dtype in TEST_DTYPES:
            for distance in [CosineSimilarity(), LpDistance()]:
                miner = MultiSimilarityMiner(epsilon, distance=distance)
                embedding_angles = torch.arange(0, 64)
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in embedding_angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)  # 2D embeddings
                labels = torch.randint(low=0, high=10, size=(64, ))
                # Full pairwise distance/similarity matrix.
                mat = distance(embeddings)
                pos_pairs = []
                neg_pairs = []
                for i in range(len(embeddings)):
                    anchor_label = labels[i]
                    for j in range(len(embeddings)):
                        if j != i:
                            other_label = labels[j]
                            if anchor_label == other_label:
                                pos_pairs.append((i, j, mat[i, j]))
                            if anchor_label != other_label:
                                neg_pairs.append((i, j, mat[i, j]))

                correct_a1, correct_p = [], []
                correct_a2, correct_n = [], []
                # Keep a positive pair if it is "harder" than the anchor's
                # hardest negative, shifted by epsilon.
                for a1, p, ap_sim in pos_pairs:
                    # Hardest negative starts at the worst possible value
                    # for the given direction.
                    most_difficult = (c_f.neg_inf(dtype)
                                      if distance.is_inverted else
                                      c_f.pos_inf(dtype))
                    for a2, n, an_sim in neg_pairs:
                        if a2 == a1:
                            condition = ((an_sim > most_difficult)
                                         if distance.is_inverted else
                                         (an_sim < most_difficult))
                            if condition:
                                most_difficult = an_sim
                    condition = ((ap_sim < most_difficult + epsilon)
                                 if distance.is_inverted else
                                 (ap_sim > most_difficult - epsilon))
                    if condition:
                        correct_a1.append(a1)
                        correct_p.append(p)

                # Keep a negative pair if it is "harder" than the anchor's
                # easiest positive, shifted by epsilon.
                for a2, n, an_sim in neg_pairs:
                    most_difficult = (c_f.pos_inf(dtype)
                                      if distance.is_inverted else
                                      c_f.neg_inf(dtype))
                    for a1, p, ap_sim in pos_pairs:
                        if a2 == a1:
                            condition = ((ap_sim < most_difficult)
                                         if distance.is_inverted else
                                         (ap_sim > most_difficult))
                            if condition:
                                most_difficult = ap_sim
                    condition = ((an_sim > most_difficult - epsilon)
                                 if distance.is_inverted else
                                 (an_sim < most_difficult + epsilon))
                    if condition:
                        correct_a2.append(a2)
                        correct_n.append(n)

                # Compare as unordered sets of (anchor, other) index pairs.
                correct_pos_pairs = set([
                    (a, p) for a, p in zip(correct_a1, correct_p)
                ])
                correct_neg_pairs = set([
                    (a, n) for a, n in zip(correct_a2, correct_n)
                ])

                a1, p1, a2, n2 = miner(embeddings, labels)
                pos_pairs = set([(a.item(), p.item()) for a, p in zip(a1, p1)])
                neg_pairs = set([(a.item(), n.item()) for a, n in zip(a2, n2)])

                self.assertTrue(pos_pairs == correct_pos_pairs)
                self.assertTrue(neg_pairs == correct_neg_pairs)