Ejemplo n.º 1
0
    def testRouletteWheel_UniqueMode(self):
        """Exercises RouletteWheel with and without unique_mode.

        With unique_mode=True every add must carry a key; re-adding an
        existing key is rejected (add returns False, add_many does not count
        it) and the weight first stored for that key is kept. With
        unique_mode=False keys are rejected and duplicate objects are all
        stored.
        """
        random.seed(12345678987654321)

        # --- unique_mode=True: entries are identified by their key. ---
        wheel = utils.RouletteWheel(unique_mode=True)
        self.assertEqual(True, wheel.add([1, 2, 3], 1, 'a'))
        self.assertEqual(True, wheel.add([4, 5], 0.5, 'b'))
        # 'a' is already present, so this add is refused.
        self.assertEqual(False, wheel.add([1, 2, 3], 1.5, 'a'))
        expected = [([1, 2, 3], 1.0), ([4, 5], 0.5)]
        self.assertEqual(expected, list(wheel))
        self.assertEqual(1.5, wheel.total_weight)

        # Only the 'c' and 'd' entries are new; both 'a' entries are dropped.
        num_added = wheel.add_many(
            [[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]],
            [0.1, 0.2, 0.1, 2.0],
            ['c', 'a', 'd', 'a'])
        self.assertEqual(2, num_added)
        expected = [([1, 2, 3], 1.0), ([4, 5], 0.5), ([5, 6, 2, 3], 0.1),
                    ([8], 0.1)]
        self.assertEqual(expected, list(wheel))
        self.assertTrue(np.isclose(1.7, wheel.total_weight))
        # An empty batch is legal and adds nothing.
        self.assertEqual(0, wheel.add_many([], [], []))

        # Unique mode insists on keys.
        with self.assertRaises(ValueError):
            wheel.add([7, 8, 9], 2.0)
        with self.assertRaises(ValueError):
            wheel.add_many([[7, 8, 9], [10]], [2.0, 2.0])

        # Key lookups: membership and the weight recorded for each key.
        for key, present in (('a', True), ('b', True), ('z', False)):
            self.assertEqual(present, wheel.has_key(key))
        self.assertEqual(1.0, wheel.get_weight('a'))
        self.assertEqual(0.5, wheel.get_weight('b'))

        # --- unique_mode=False: no keys allowed, duplicates are kept. ---
        wheel = utils.RouletteWheel(unique_mode=False)
        self.assertEqual(True, wheel.add([1, 2, 3], 1))
        self.assertEqual(True, wheel.add([4, 5], 0.5))
        self.assertEqual(True, wheel.add([1, 2, 3], 1.5))
        expected = [([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5)]
        self.assertEqual(expected, list(wheel))
        self.assertEqual(3, wheel.total_weight)

        num_added = wheel.add_many(
            [[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]],
            [0.1, 0.2, 0.1, 0.2])
        self.assertEqual(4, num_added)
        expected = [([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5),
                    ([5, 6, 2, 3], 0.1), ([1, 2, 3], 0.2), ([8], 0.1),
                    ([1, 2, 3], 0.2)]
        self.assertEqual(expected, list(wheel))
        self.assertTrue(np.isclose(3.6, wheel.total_weight))

        # Supplying keys outside unique mode is an error.
        with self.assertRaises(ValueError):
            wheel.add([7, 8, 9], 2.0, 'a')
        with self.assertRaises(ValueError):
            wheel.add_many([[7, 8, 9], [10]], [2.0, 2.0], ['a', 'b'])
Ejemplo n.º 2
0
 def testRouletteWheel_AddZeroWeights(self):
     """Zero-weight entries are accepted and kept in the wheel.

     A wheel whose only entry has weight 0 is not considered empty.
     """
     wheel = utils.RouletteWheel()
     added = wheel.add('a', 0)
     self.assertEqual(True, added)
     self.assertFalse(wheel.is_empty())

     objs = ['b', 'c', 'd', 'e']
     weights = [0, 0.1, 0, 0]
     self.assertEqual(4, wheel.add_many(objs, weights))

     expected = [('a', 0.0), ('b', 0.0), ('c', 0.1), ('d', 0.0), ('e', 0.0)]
     self.assertEqual(expected, list(wheel))
Ejemplo n.º 3
0
    def testRouletteWheel_IncrementalSave(self):
        """Entries flushed with incremental_save() can be reloaded from disk.

        Entries are added two at a time. After each incremental_save() a fresh
        wheel is constructed from the same save file; it must contain exactly
        the entries added so far, in insertion order, with matching weights
        and per-key weight lookups.
        """
        entries = [([1, 2, 3], 0.1, 'a'), ([4, 5], 0.2, 'b'), ([6], 0.3, 'c'),
                   ([7, 8, 9, 10], 0.25, 'd'), ([-1, -2], 0.15, 'e'),
                   ([-3, -4, -5], 0.5, 'f')]

        # Context manager closes (and deletes) the temp file deterministically,
        # even when an assertion fails, instead of relying on garbage
        # collection of an unclosed file object.
        # NOTE(review): on Windows an open NamedTemporaryFile cannot be
        # reopened by name; confirm whether this test must run there.
        with tempfile.NamedTemporaryFile() as f:
            r = utils.RouletteWheel(unique_mode=True, save_file=f.name)
            self.assertTrue(r.is_empty())
            # len(entries) is even, so entries[i + 1] is always valid.
            for i in range(0, len(entries), 2):
                r.add(*entries[i])
                r.add(*entries[i + 1])
                r.incremental_save()

                # Reload into a new wheel to check what actually reached disk.
                r2 = utils.RouletteWheel(unique_mode=True, save_file=f.name)
                self.assertEqual(i + 2, len(r2))
                # Materialize once: the iteration length must agree with
                # len(r2), and each item must match the entry added at that
                # position.
                loaded = list(r2)
                for j, (obj, weight) in enumerate(loaded):
                    expected_obj, expected_weight, key = entries[j]
                    self.assertEqual(expected_obj, obj)
                    self.assertEqual(expected_weight, weight)
                    self.assertEqual(weight, r2.get_weight(key))
                self.assertEqual(i + 2, len(loaded))
Ejemplo n.º 4
0
    def testRouletteWheel(self):
        """Checks weight accounting and sampling of a default RouletteWheel.

        Verifies that:
        * sampling from an empty wheel raises RuntimeError;
        * total_weight tracks the sum of added weights;
        * sample()/sample_many() return (obj, float weight) pairs present in
          the wheel;
        * empirical sampling frequencies are proportional to the weights.
        """
        random.seed(12345678987654321)
        r = utils.RouletteWheel()
        self.assertTrue(r.is_empty())
        with self.assertRaises(RuntimeError):
            r.sample()  # Cannot sample when empty.
        self.assertEqual(0, r.total_weight)
        self.assertEqual(True, r.add('a', 0.1))
        self.assertFalse(r.is_empty())
        self.assertEqual(0.1, r.total_weight)
        self.assertEqual(True, r.add('b', 0.01))
        # total_weight is an accumulated float sum; compare with a tolerance
        # (as the other RouletteWheel tests do via np.isclose) rather than
        # relying on exact binary rounding of the running sum.
        self.assertAlmostEqual(0.11, r.total_weight)
        self.assertEqual(True, r.add('c', 0.5))
        self.assertEqual(True, r.add('d', 0.1))
        self.assertEqual(True, r.add('e', 0.05))
        self.assertEqual(True, r.add('f', 0.03))
        self.assertEqual(True, r.add('g', 0.001))
        self.assertAlmostEqual(0.791, r.total_weight)
        self.assertFalse(r.is_empty())

        # Check that sampling is correct.
        obj, weight = r.sample()
        self.assertIsInstance(weight, float, 'Type: %s' % type(weight))
        self.assertIn((obj, weight), r)
        for obj, weight in r.sample_many(100):
            self.assertIsInstance(weight, float, 'Type: %s' % type(weight))
            self.assertIn((obj, weight), r)

        # Check that sampling distribution is correct: the observed frequency
        # of each entry, rescaled by total_weight, should recover its weight.
        n = 1000000
        c = Counter(r.sample_many(n))
        for obj, w in r:
            estimated_w = c[(obj, w)] / float(n) * r.total_weight
            self.assertTrue(
                np.isclose(w, estimated_w, atol=1e-3),
                'Expected %s, got %s, for object %s' % (w, estimated_w, obj))