Beispiel #1
0
  def testRouletteWheel_IncrementalSave(self):
    """Entries added in pairs survive incremental_save() and reload in order.

    After every second add, a fresh wheel is constructed on the same save file
    and must contain exactly the entries saved so far, with matching objects,
    weights, and per-key weights.
    """
    # Each entry is (object, weight, key), passed positionally to add().
    entries = [
        ([1, 2, 3], 0.1, 'a'),
        ([4, 5], 0.2, 'b'),
        ([6], 0.3, 'c'),
        ([7, 8, 9, 10], 0.25, 'd'),
        ([-1, -2], 0.15, 'e'),
        ([-3, -4, -5], 0.5, 'f')]

    # Use a context manager so the temporary file handle is closed (and the
    # file removed, NamedTemporaryFile's default) deterministically when the
    # test ends, instead of relying on garbage collection of `f`.
    with tempfile.NamedTemporaryFile() as f:
      r = utils.RouletteWheel(unique_mode=True, save_file=f.name)
      self.assertTrue(r.is_empty())
      # `entries` has an even length, so entries[i + 1] is always valid.
      for i in range(0, len(entries), 2):
        r.add(*entries[i])
        r.add(*entries[i + 1])
        r.incremental_save()

        # An independent wheel built on the same save file must see all
        # entries persisted so far, in insertion order.
        r2 = utils.RouletteWheel(unique_mode=True, save_file=f.name)
        self.assertEqual(i + 2, len(r2))
        count = 0
        for j, (obj, weight) in enumerate(r2):
          self.assertEqual(entries[j][0], obj)
          self.assertEqual(entries[j][1], weight)
          self.assertEqual(weight, r2.get_weight(entries[j][2]))
          count += 1
        # Iteration must yield the same number of items len() reports.
        self.assertEqual(i + 2, count)
Beispiel #2
0
  def testRouletteWheel_UniqueMode(self):
    """Checks keyed de-duplication in unique mode and key rejection otherwise.

    With unique_mode=True every add must carry a key, and an already-present
    key is rejected (add returns False, add_many counts only new keys). With
    unique_mode=False keys raise ValueError and duplicate objects are kept.
    """
    random.seed(12345678987654321)

    # --- unique_mode=True: keys required, repeated keys rejected. ---
    wheel = utils.RouletteWheel(unique_mode=True)
    single_adds = (
        ([1, 2, 3], 1, 'a', True),
        ([4, 5], 0.5, 'b', True),
        ([1, 2, 3], 1.5, 'a', False))  # 'a' is already present.
    for obj, weight, key, accepted in single_adds:
      self.assertEqual(accepted, wheel.add(obj, weight, key))
    self.assertEqual([([1, 2, 3], 1.0), ([4, 5], 0.5)], list(wheel))
    self.assertEqual(1.5, wheel.total_weight)

    batch_objs = [[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]]
    batch_weights = [0.1, 0.2, 0.1, 2.0]
    batch_keys = ['c', 'a', 'd', 'a']
    # Only the 'c' and 'd' entries are new; both 'a' entries are rejected.
    self.assertEqual(2, wheel.add_many(batch_objs, batch_weights, batch_keys))
    self.assertEqual(
        [([1, 2, 3], 1.0), ([4, 5], 0.5), ([5, 6, 2, 3], 0.1), ([8], 0.1)],
        list(wheel))
    self.assertTrue(np.isclose(1.7, wheel.total_weight))
    self.assertEqual(0, wheel.add_many([], [], []))  # An empty batch is fine.

    with self.assertRaises(ValueError):
      wheel.add([7, 8, 9], 2.0)  # Missing key in unique mode.
    with self.assertRaises(ValueError):
      wheel.add_many([[7, 8, 9], [10]], [2.0, 2.0])  # Missing keys.

    for key, present in (('a', True), ('b', True), ('z', False)):
      self.assertEqual(present, wheel.has_key(key))
    # The first weight registered for a key wins.
    for key, expected_weight in (('a', 1.0), ('b', 0.5)):
      self.assertEqual(expected_weight, wheel.get_weight(key))

    # --- unique_mode=False: keys forbidden, duplicates kept. ---
    wheel = utils.RouletteWheel(unique_mode=False)
    for obj, weight in (([1, 2, 3], 1), ([4, 5], 0.5), ([1, 2, 3], 1.5)):
      self.assertEqual(True, wheel.add(obj, weight))
    self.assertEqual(
        [([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5)],
        list(wheel))
    self.assertEqual(3, wheel.total_weight)

    plain_objs = [[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]]
    plain_weights = [0.1, 0.2, 0.1, 0.2]
    self.assertEqual(4, wheel.add_many(plain_objs, plain_weights))
    self.assertEqual(
        [([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5),
         ([5, 6, 2, 3], 0.1), ([1, 2, 3], 0.2), ([8], 0.1), ([1, 2, 3], 0.2)],
        list(wheel))
    self.assertTrue(np.isclose(3.6, wheel.total_weight))

    with self.assertRaises(ValueError):
      wheel.add([7, 8, 9], 2.0, 'a')  # Key supplied in non-unique mode.
    with self.assertRaises(ValueError):
      wheel.add_many([[7, 8, 9], [10]], [2.0, 2.0], ['a', 'b'])  # Keys given.
Beispiel #3
0
 def testRouletteWheel_AddZeroWeights(self):
     """Zero weights are accepted by add/add_many and entries keep their order."""
     wheel = utils.RouletteWheel()
     added = wheel.add('a', 0)
     self.assertEqual(True, added)
     # A zero-weight entry still makes the wheel non-empty.
     self.assertFalse(wheel.is_empty())
     num_added = wheel.add_many(['b', 'c', 'd', 'e'], [0, 0.1, 0, 0])
     self.assertEqual(4, num_added)
     expected = [
         ('a', 0.0),
         ('b', 0.0),
         ('c', 0.1),
         ('d', 0.0),
         ('e', 0.0),
     ]
     self.assertEqual(expected, list(wheel))
Beispiel #4
0
    def testRouletteWheel(self):
        """Exercise add/sample on a default wheel and check the distribution.

        Totals produced by several float additions are compared with a
        tolerance: repeated float summation is not guaranteed to land exactly
        on the decimal literal (e.g. 0.791), so exact equality is fragile.
        """
        random.seed(12345678987654321)
        r = utils.RouletteWheel()
        self.assertTrue(r.is_empty())
        with self.assertRaises(RuntimeError):
            r.sample()  # Cannot sample when empty.
        self.assertEqual(0, r.total_weight)
        self.assertEqual(True, r.add('a', 0.1))
        self.assertFalse(r.is_empty())
        # A single addition to 0 is exact, so strict equality is safe here.
        self.assertEqual(0.1, r.total_weight)
        self.assertEqual(True, r.add('b', 0.01))
        self.assertAlmostEqual(0.11, r.total_weight)
        self.assertEqual(True, r.add('c', 0.5))
        self.assertEqual(True, r.add('d', 0.1))
        self.assertEqual(True, r.add('e', 0.05))
        self.assertEqual(True, r.add('f', 0.03))
        self.assertEqual(True, r.add('g', 0.001))
        self.assertAlmostEqual(0.791, r.total_weight)
        self.assertFalse(r.is_empty())

        # Check that sampling is correct: samples are (obj, float weight)
        # pairs that exist in the wheel.
        obj, weight = r.sample()
        self.assertIsInstance(weight, float, 'Type: %s' % type(weight))
        self.assertIn((obj, weight), r)
        for obj, weight in r.sample_many(100):
            self.assertIsInstance(weight, float, 'Type: %s' % type(weight))
            self.assertIn((obj, weight), r)

        # Check that sampling distribution is correct: the empirical
        # frequency of each entry, rescaled by the total weight, should
        # recover its weight. The large n keeps sampling noise small relative
        # to atol=1e-3. NOTE(review): presumably RouletteWheel draws from the
        # `random` module, so the seed above makes this reproducible — verify.
        n = 1000000
        c = Counter(r.sample_many(n))
        for obj, w in r:
            estimated_w = c[(obj, w)] / float(n) * r.total_weight
            self.assertTrue(
                np.isclose(w, estimated_w, atol=1e-3),
                'Expected %s, got %s, for object %s' % (w, estimated_w, obj))