def test_distance_weighted_miner(self):
        """Check that DistanceWeightedMiner yields roughly uniform anchor-negative distances.

        180 2D embeddings are built with ``c_f.angle_to_coord`` from the
        integer angles 0..179 and given random labels in {0, 1}. For each
        ``nonzero_loss_cutoff`` in {0.49, 0.59, ..., 1.39}, the test asserts
        three things about the mined anchor-negative distances:

        * they never exceed the cutoff;
        * their variance is within 10% (relative) of a uniform-distribution
          target built from the cutoff and ``min_an_dist``;
        * their mean is within 10% (relative) of a similar target.

        ``min_an_dist`` is the smallest anchor-negative distance over all
        triplets returned by ``lmu.get_all_triplets_indices``.
        """
        embedding_angles = torch.arange(0, 180)
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=torch.float)  # 2D embeddings
        labels = torch.randint(low=0, high=2, size=(180, ))

        # Lower end of the target range: smallest anchor-negative distance
        # over every triplet in the batch.
        a, _, n = lmu.get_all_triplets_indices(labels)
        all_an_dist = torch.nn.functional.pairwise_distance(
            embeddings[a], embeddings[n], 2)
        min_an_dist = torch.min(all_an_dist)

        for non_zero_cutoff_int in range(5, 15):
            # Cutoffs 0.49, 0.59, ..., 1.39.
            non_zero_cutoff = (float(non_zero_cutoff_int) / 10.) - 0.01
            miner = DistanceWeightedMiner(0, non_zero_cutoff)
            # Only anchor-negative distances are checked, so the positive
            # indices returned by the miner are discarded.
            a, _, n = miner(embeddings, labels)
            an_dist = torch.nn.functional.pairwise_distance(
                embeddings[a], embeddings[n], 2)
            self.assertTrue(torch.max(an_dist) <= non_zero_cutoff)
            an_dist_var = torch.var(an_dist)
            an_dist_mean = torch.mean(an_dist)
            # Variance of U(lo, hi) is (hi - lo)**2 / 12.
            target_var = ((non_zero_cutoff - min_an_dist)**2) / 12
            # NOTE(review): the mean of U(lo, hi) is (lo + hi) / 2, but this
            # target is (hi - lo) / 2 and relies on the 10% tolerance. Left
            # unchanged because this statistical check's pass/fail depends on
            # it — confirm the intended target before changing.
            target_mean = (non_zero_cutoff - min_an_dist) / 2
            self.assertTrue(
                torch.abs(an_dist_var - target_var) / target_var < 0.1)
            self.assertTrue(
                torch.abs(an_dist_mean - target_mean) / target_mean < 0.1)
# Example #2
# 0
 def test_with_distance_weighted_miner(self):
     """Smoke-test CrossBatchMemory(NTXentLoss) driven by a DistanceWeightedMiner.

     For every dtype in TEST_DTYPES, a fresh memory-backed loss (memory of
     256, embedding size 2) is built. It is then fed 20 batches of 32 2D
     embeddings (from the integer angles 0..31 via ``c_f.angle_to_coord``)
     with random labels in [0, 10), and each loss is back-propagated. The
     test only checks that this runs without raising.
     """
     batch_size = 32
     for dtype in TEST_DTYPES:
         inner_loss = NTXentLoss(temperature=0.1)
         inner_miner = DistanceWeightedMiner(cutoff=0.5,
                                             nonzero_loss_cutoff=1.4)
         loss_with_miner = CrossBatchMemory(
             loss=inner_loss,
             embedding_size=2,
             memory_size=256,
             miner=inner_miner,
         )
         for _ in range(20):
             angles = torch.arange(0, batch_size)
             coords = [c_f.angle_to_coord(angle) for angle in angles]
             # 2D embeddings
             embeddings = torch.tensor(
                 coords, requires_grad=True, dtype=dtype
             ).to(TEST_DEVICE)
             labels = torch.randint(
                 low=0, high=10, size=(batch_size, )
             ).to(TEST_DEVICE)
             loss_with_miner(embeddings, labels).backward()
             # Reaching this point means no exception was raised.
             self.assertTrue(True)
    def test_distance_weighted_miner(self, with_ref_labels=False):
        """Check DistanceWeightedMiner's anchor-negative distance distribution.

        For each dtype in TEST_DTYPES, 256 2D embeddings are built with
        ``c_f.angle_to_coord`` from the integer angles 0..255 and given random
        labels in {0, 1}. When ``with_ref_labels`` is True, a clone of the
        embeddings with its own random labels is passed as the reference set,
        and negatives are indexed into that reference set.

        For each ``nonzero_loss_cutoff`` in {0.49, 0.59, ..., 1.39}, the test
        asserts three things about the mined anchor-negative distances:

        * they never exceed the cutoff;
        * their variance is within 10% (relative) of a uniform-distribution
          target built from the cutoff and the smallest anchor-negative
          distance over all triplets;
        * their mean is within 10% (relative) of a similar target.

        Args:
            with_ref_labels: if True, mine against separate reference
                embeddings/labels instead of the query batch itself.
        """
        for dtype in TEST_DTYPES:
            embedding_angles = torch.arange(0, 256)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            ref_embeddings = embeddings.clone() if with_ref_labels else None
            labels = torch.randint(low=0, high=2, size=(256,))
            ref_labels = torch.randint(low=0, high=2, size=(256,)) if with_ref_labels else None
            # Negative indices point into the reference set when one is used,
            # otherwise into the query batch itself.
            neg_source = ref_embeddings if with_ref_labels else embeddings

            # Lower end of the target range: smallest anchor-negative distance
            # over every triplet.
            a, _, n = lmu.get_all_triplets_indices(labels, ref_labels)
            all_an_dist = torch.nn.functional.pairwise_distance(
                embeddings[a], neg_source[n], 2)
            min_an_dist = torch.min(all_an_dist)

            for non_zero_cutoff_int in range(5, 15):
                # Cutoffs 0.49, 0.59, ..., 1.39.
                non_zero_cutoff = (float(non_zero_cutoff_int) / 10.) - 0.01
                miner = DistanceWeightedMiner(0, non_zero_cutoff)
                # Only anchor-negative distances are checked, so the positive
                # indices returned by the miner are discarded.
                a, _, n = miner(embeddings, labels, ref_embeddings, ref_labels)
                an_dist = torch.nn.functional.pairwise_distance(
                    embeddings[a], neg_source[n], 2)
                self.assertTrue(torch.max(an_dist) <= non_zero_cutoff)
                an_dist_var = torch.var(an_dist)
                an_dist_mean = torch.mean(an_dist)
                # Variance of U(lo, hi) is (hi - lo)**2 / 12.
                target_var = ((non_zero_cutoff - min_an_dist)**2) / 12
                # NOTE(review): the mean of U(lo, hi) is (lo + hi) / 2, but this
                # target is (hi - lo) / 2 and relies on the 10% tolerance. Left
                # unchanged because this statistical check's pass/fail depends
                # on it — confirm the intended target before changing.
                target_mean = (non_zero_cutoff - min_an_dist) / 2
                self.assertTrue(torch.abs(an_dist_var - target_var) / target_var < 0.1)
                self.assertTrue(torch.abs(an_dist_mean - target_mean) / target_mean < 0.1)
 def test_empty_output(self):
     """Check that DistanceWeightedMiner mines nothing when labels are all distinct.

     Labels are ``torch.arange(batch_size)``, so no two samples share a label
     and there is no anchor-positive pair. All three returned index
     collections must be empty.
     """
     miner = DistanceWeightedMiner(0.1, 0.5)
     batch_size = 32
     embeddings = torch.randn(batch_size, 64)
     labels = torch.arange(batch_size)  # every sample has its own label
     anchors, positives, negatives = miner(embeddings, labels)
     for indices in (anchors, positives, negatives):
         self.assertTrue(len(indices) == 0)
 def test_empty_output(self):
     """Per dtype, check that DistanceWeightedMiner mines nothing without positives.

     For every dtype in TEST_DTYPES, random 64-dimensional embeddings are
     cast to that dtype and moved to ``self.device``. They are paired with
     ``torch.arange`` labels (all distinct), and all three returned index
     collections must be empty.
     """
     miner = DistanceWeightedMiner(0.1, 0.5)
     batch_size = 32
     for dtype in TEST_DTYPES:
         raw = torch.randn(batch_size, 64)
         embeddings = raw.type(dtype).to(self.device)
         # Distinct labels: no anchor-positive pairs can exist.
         labels = torch.arange(batch_size)
         anchors, positives, negatives = miner(embeddings, labels)
         for indices in (anchors, positives, negatives):
             self.assertTrue(len(indices) == 0)