def testRouletteWheel_UniqueMode(self):
    """Exercises RouletteWheel key handling with unique_mode on and off.

    With unique_mode=True every item must come with a key, and items whose
    key is already present are rejected. With unique_mode=False keys are not
    accepted: passing them raises ValueError.
    """
    random.seed(12345678987654321)

    # --- unique_mode=True: keys are mandatory and deduplicated. ---
    r = utils.RouletteWheel(unique_mode=True)
    self.assertEqual(True, r.add([1, 2, 3], 1, 'a'))
    self.assertEqual(True, r.add([4, 5], 0.5, 'b'))
    # Key 'a' is already present, so this add is rejected and the stored
    # weight for 'a' stays 1.0 (checked via list(r) below).
    self.assertEqual(False, r.add([1, 2, 3], 1.5, 'a'))
    self.assertEqual([([1, 2, 3], 1.0), ([4, 5], 0.5)], list(r))
    self.assertEqual(1.5, r.total_weight)

    # Only keys 'c' and 'd' are new; both 'a' items are duplicates and dropped.
    self.assertEqual(
        2,
        r.add_many([[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]],
                   [0.1, 0.2, 0.1, 2.0],
                   ['c', 'a', 'd', 'a']))
    self.assertEqual(
        [([1, 2, 3], 1.0), ([4, 5], 0.5), ([5, 6, 2, 3], 0.1), ([8], 0.1)],
        list(r))
    # 1.0 + 0.5 + 0.1 + 0.1 is not exact in binary floating point; compare
    # with a tolerance and get both values in the failure message.
    self.assertAlmostEqual(1.7, r.total_weight)
    self.assertEqual(0, r.add_many([], [], []))  # Adding no items is allowed.
    with self.assertRaises(ValueError):  # Key not given.
        r.add([7, 8, 9], 2.0)
    with self.assertRaises(ValueError):  # Keys not given.
        r.add_many([[7, 8, 9], [10]], [2.0, 2.0])
    self.assertEqual(True, r.has_key('a'))
    self.assertEqual(True, r.has_key('b'))
    self.assertEqual(False, r.has_key('z'))
    self.assertEqual(1.0, r.get_weight('a'))
    self.assertEqual(0.5, r.get_weight('b'))

    # --- unique_mode=False: duplicates allowed, keys forbidden. ---
    r = utils.RouletteWheel(unique_mode=False)
    self.assertEqual(True, r.add([1, 2, 3], 1))
    self.assertEqual(True, r.add([4, 5], 0.5))
    # The same object may be added again without a key.
    self.assertEqual(True, r.add([1, 2, 3], 1.5))
    self.assertEqual([([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5)],
                     list(r))
    self.assertEqual(3, r.total_weight)
    # Every item is added, including repeated objects.
    self.assertEqual(
        4,
        r.add_many([[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]],
                   [0.1, 0.2, 0.1, 0.2]))
    self.assertEqual(
        [([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5),
         ([5, 6, 2, 3], 0.1), ([1, 2, 3], 0.2), ([8], 0.1), ([1, 2, 3], 0.2)],
        list(r))
    self.assertAlmostEqual(3.6, r.total_weight)
    with self.assertRaises(ValueError):  # Key is given.
        r.add([7, 8, 9], 2.0, 'a')
    with self.assertRaises(ValueError):  # Keys are given.
        r.add_many([[7, 8, 9], [10]], [2.0, 2.0], ['a', 'b'])
def testRouletteWheel_AddZeroWeights(self):
    """Zero-weight items are accepted and kept alongside weighted ones."""
    expected_contents = [
        ('a', 0.0),
        ('b', 0.0),
        ('c', 0.1),
        ('d', 0.0),
        ('e', 0.0),
    ]
    wheel = utils.RouletteWheel()

    # A single zero-weight add succeeds and makes the wheel non-empty.
    self.assertEqual(True, wheel.add('a', 0))
    self.assertFalse(wheel.is_empty())

    # A batch mixing zero and non-zero weights adds every item.
    added_count = wheel.add_many(['b', 'c', 'd', 'e'], [0, 0.1, 0, 0])
    self.assertEqual(4, added_count)

    self.assertEqual(expected_contents, list(wheel))
def testRouletteWheel_IncrementalSave(self):
    """Checks that incremental_save() output can be reloaded by a new wheel.

    Entries are added two at a time. After each incremental_save(), a fresh
    RouletteWheel built from the same save_file must contain exactly the
    entries added so far, in insertion order, with matching weights, and
    get_weight(key) must return each entry's weight.
    """
    entries = [([1, 2, 3], 0.1, 'a'),
               ([4, 5], 0.2, 'b'),
               ([6], 0.3, 'c'),
               ([7, 8, 9, 10], 0.25, 'd'),
               ([-1, -2], 0.15, 'e'),
               ([-3, -4, -5], 0.5, 'f')]
    # Manage the temp file with a context manager so it is closed (and, with
    # the default delete=True, removed) deterministically rather than
    # whenever the file object happens to be garbage-collected.
    with tempfile.NamedTemporaryFile() as f:
        r = utils.RouletteWheel(unique_mode=True, save_file=f.name)
        self.assertTrue(r.is_empty())
        for i in range(0, len(entries), 2):
            r.add(*entries[i])
            r.add(*entries[i + 1])
            r.incremental_save()

            # Reload from disk and compare against what has been added so far.
            r2 = utils.RouletteWheel(unique_mode=True, save_file=f.name)
            self.assertEqual(i + 2, len(r2))
            loaded = list(r2)
            # Iteration must yield exactly as many items as len() reports.
            self.assertEqual(i + 2, len(loaded))
            for (obj, weight), (want_obj, want_weight, key) in zip(loaded,
                                                                   entries):
                self.assertEqual(want_obj, obj)
                self.assertEqual(want_weight, weight)
                self.assertEqual(weight, r2.get_weight(key))
def testRouletteWheel(self):
    """Checks RouletteWheel bookkeeping and sampling behaviour.

    Sampling an empty wheel raises RuntimeError; total_weight tracks the sum
    of added weights; sampled (obj, weight) pairs are float-weighted members
    of the wheel; and over many draws each object's empirical frequency,
    scaled by total_weight, approximates its weight.
    """
    random.seed(12345678987654321)
    r = utils.RouletteWheel()
    self.assertTrue(r.is_empty())
    with self.assertRaises(RuntimeError):
        r.sample()  # Cannot sample when empty.
    self.assertEqual(0, r.total_weight)

    self.assertEqual(True, r.add('a', 0.1))
    self.assertFalse(r.is_empty())
    self.assertEqual(0.1, r.total_weight)
    self.assertEqual(True, r.add('b', 0.01))
    # total_weight is a sum of float weights; exact equality would depend on
    # summation order and rounding, so compare with a tolerance.
    self.assertAlmostEqual(0.11, r.total_weight)
    self.assertEqual(True, r.add('c', 0.5))
    self.assertEqual(True, r.add('d', 0.1))
    self.assertEqual(True, r.add('e', 0.05))
    self.assertEqual(True, r.add('f', 0.03))
    self.assertEqual(True, r.add('g', 0.001))
    self.assertAlmostEqual(0.791, r.total_weight)
    self.assertFalse(r.is_empty())

    # Check that sampling is correct: every draw is a member of the wheel
    # with a float weight.
    obj, weight = r.sample()
    self.assertIsInstance(weight, float, 'Type: %s' % type(weight))
    self.assertIn((obj, weight), r)
    for obj, weight in r.sample_many(100):
        self.assertIsInstance(weight, float, 'Type: %s' % type(weight))
        self.assertIn((obj, weight), r)

    # Check that sampling distribution is correct: the observed frequency of
    # each object, scaled by total_weight, should approximate its weight.
    n = 1000000
    c = Counter(r.sample_many(n))
    for obj, w in r:
        estimated_w = c[(obj, w)] / float(n) * r.total_weight
        self.assertTrue(
            np.isclose(w, estimated_w, atol=1e-3),
            'Expected %s, got %s, for object %s' % (w, estimated_w, obj))