def test_assert(self):
    """Iterating a sampler built from invalid weights must raise AssertionError.

    Three invalid configurations are covered: all-zero weights, too few
    positive weights for sampling 10 items without replacement, and
    negative weights. Sampler construction is expected to succeed; only
    iteration is expected to raise.
    """

    def check_raises(case, probs, replacement):
        with self.subTest(case=case):
            sampler = WeightedRandomSampler(probs, 10, replacement)
            # Use assertRaises rather than try/except: an assertTrue(False)
            # inside `try ... except AssertionError` is caught by that same
            # except clause, so such a test could never fail.
            with self.assertRaises(AssertionError):
                list(sampler)

    check_raises("all zeros", np.zeros(10, dtype='float32'), True)
    # Presumably init_probs(10, 5) yields 10 weights of which only 5 are
    # positive, so 10 draws without replacement cannot succeed — TODO confirm.
    check_raises("not enough pos", self.init_probs(10, 5), False)
    check_raises("neg probs", -np.ones(10, dtype='float32'), True)
def test_no_replacement(self):
    """Sampling without replacement yields distinct, positively-weighted indices."""
    # Presumably init_probs(20, 10) gives 20 weights, 10 of them positive,
    # so exactly 10 distinct indices can be drawn — TODO confirm.
    probs = self.init_probs(20, 10)
    sampler = WeightedRandomSampler(probs, 10, False)
    # Use TestCase assertions, not bare `assert`: bare asserts are stripped
    # under `python -O`, which would make this test pass vacuously.
    self.assertEqual(len(sampler), 10)
    idxs = list(sampler)
    self.assertEqual(len(idxs), len(sampler))
    for idx in idxs:
        self.assertGreater(probs[idx], 0.)
    # Without replacement no index may be drawn twice.
    self.assertEqual(len(set(idxs)), len(idxs))
def test_raise(self):
    """Constructor rejects invalid ``num_samples`` / ``replacement`` with ValueError.

    Cases: a float ``num_samples``, a negative ``num_samples`` and a non-bool
    ``replacement``.
    """
    cases = (
        ("float num_samples", 2.3, True),
        ("neg num_samples", -1, True),
        ("no-bool replacement", 5, 5),
    )
    for case, num_samples, replacement in cases:
        with self.subTest(case=case):
            probs = self.init_probs(10, 5)
            with self.assertRaises(ValueError):
                WeightedRandomSampler(probs, num_samples, replacement)
def test_replacement(self):
    """Sampling with replacement can draw more samples (30) than there are
    weights (20), and every drawn index has a positive weight."""
    probs = self.init_probs(20, 10)
    sampler = WeightedRandomSampler(probs, 30, True)
    # Use TestCase assertions, not bare `assert`: bare asserts are stripped
    # under `python -O`, which would make this test pass vacuously.
    self.assertEqual(len(sampler), 30)
    idxs = list(sampler)
    self.assertEqual(len(idxs), len(sampler))
    for idx in idxs:
        self.assertGreater(probs[idx], 0.)
def _make_balanced_sampler(labels):
    """Build a sampler that weights each sample by its label's inverse frequency.

    Each sample gets weight ``1 / count(label)``, so every class present in
    ``labels`` has the same total weight (1).

    Args:
        labels: 1-D sequence of class labels (e.g. ints); any values that
            ``np.unique`` can sort are accepted.

    Returns:
        ``WeightedRandomSampler(weights, len(labels))``; the replacement
        argument is left at the sampler's own default.
    """
    # np.unique counts only labels that actually occur. Gaps in the label
    # range therefore never produce a 1/0 divide-by-zero warning (as
    # np.bincount would), and labels need not be non-negative ints.
    _, inverse, counts = np.unique(
        np.asarray(labels), return_inverse=True, return_counts=True)
    weights = 1. / counts[inverse]
    return WeightedRandomSampler(weights, len(weights))