Beispiel #1
0
    def test_assert(self):
        """Iterating a sampler over invalid weights must raise AssertionError.

        Three invalid configurations are checked: all-zero weights, too few
        positive weights for sampling without replacement, and negative
        weights. The sampler is constructed outside ``assertRaises`` so that
        only a failure during iteration satisfies the check.
        """
        # NOTE: the previous try/except form could never fail. Its
        # ``self.assertTrue(False)`` sentinel raises AssertionError itself,
        # which the ``except AssertionError`` clause then swallowed.
        def assert_iteration_fails(probs, replacement):
            # Draw 10 samples; consuming the iterator must trip the check.
            sampler = WeightedRandomSampler(probs, 10, replacement)
            with self.assertRaises(AssertionError):
                for _ in sampler:
                    pass

        # all zeros
        assert_iteration_fails(np.zeros((10, )).astype('float32'), True)

        # not enough positive weights for 10 draws without replacement
        assert_iteration_fails(self.init_probs(10, 5), False)

        # negative weights
        assert_iteration_fails(-1.0 * np.ones((10, )).astype('float32'), True)
Beispiel #2
0
 def test_no_replacement(self):
     """Sampling without replacement yields distinct, positive-weight indices."""
     probs = self.init_probs(20, 10)
     sampler = WeightedRandomSampler(probs, 10, False)
     # unittest assertions rather than bare ``assert``: they are not stripped
     # under ``python -O`` and they report the offending values on failure.
     self.assertEqual(len(sampler), 10)
     idxs = list(sampler)
     for idx in idxs:
         # every drawn index must point at a strictly positive weight
         self.assertGreater(probs[idx], 0.)
     # no index may be drawn twice when replacement is disabled
     self.assertEqual(len(set(idxs)), len(idxs))
Beispiel #3
0
    def test_raise(self):
        """Invalid constructor arguments must raise ValueError."""
        probs = self.init_probs(10, 5)

        # float num_samples
        with self.assertRaises(ValueError):
            WeightedRandomSampler(probs, 2.3, True)

        # negative num_samples
        with self.assertRaises(ValueError):
            WeightedRandomSampler(probs, -1, True)

        # non-bool replacement
        with self.assertRaises(ValueError):
            WeightedRandomSampler(probs, 5, 5)
Beispiel #4
0
 def test_replacement(self):
     """With replacement, 30 draws are allowed and each hits a positive weight."""
     probs = self.init_probs(20, 10)
     sampler = WeightedRandomSampler(probs, 30, True)
     # unittest assertions rather than bare ``assert``: they survive
     # ``python -O`` and show the failing values.
     self.assertEqual(len(sampler), 30)
     for idx in sampler:
         self.assertGreater(probs[idx], 0.)
Beispiel #5
0
def _make_balanced_sampler(labels):
    """Build a sampler that gives every present class equal total weight.

    Each example is weighted by the inverse of its class's frequency, so the
    weights of all examples sharing a label sum to 1.

    Args:
        labels: non-negative integer class ids, one per example (anything
            ``np.bincount`` accepts and that can index a numpy array).

    Returns:
        A ``WeightedRandomSampler`` over the per-example weights, drawing
        ``len(labels)`` samples with the sampler's default replacement mode.
    """
    # Index the counts by label *before* inverting. Class ids that never
    # occur have a count of 0, and inverting the whole bincount array would
    # divide by zero for them (RuntimeWarning, or FloatingPointError under
    # np.seterr(divide='raise')) even though those entries are never used.
    weights = 1. / np.bincount(labels)[labels]
    return WeightedRandomSampler(weights, len(weights))