Exemplo n.º 1
0
    def test_capacity_greater_or_equal_to_size_and_power_of_2(self):
        """Capacity is a power of two that is at least the number of values."""
        cases = (
            ([4., 5., 3., 2.], 4),  # four values already fill a power of two
            ([4., 5., 3., 2., 9], 8),  # five values round up to eight
        )
        for tree_values, expected_capacity in cases:
            tree = replay_lib.SumTree()
            tree.set_all(tree_values)
            self.assertEqual(expected_capacity, tree.capacity)
Exemplo n.º 2
0
    def test_set_all_cannot_add_negative_nan_or_inf_values(self):
        """set_all rejects negative, NaN and infinite values with ValueError."""
        for bad_value in (-1, np.nan, np.inf):
            with self.assertRaises(ValueError):
                replay_lib.SumTree().set_all([1, bad_value])
Exemplo n.º 3
0
    def test_getting_and_setting_state(self):
        """State exported with get_state can be loaded into a fresh tree."""
        original = replay_lib.SumTree()
        original.set_all([4, 5, 3, 9])
        exported = original.get_state()

        restored = replay_lib.SumTree()
        restored.set_state(exported)
        restored.check_valid()

        # The restored tree must match the original in contents and layout.
        np.testing.assert_allclose(restored.values, original.values)
        self.assertEqual(original.size, restored.size)
        self.assertEqual(original.capacity, restored.capacity)
Exemplo n.º 4
0
 def test_exception_raised_when_index_out_of_bounds_in_get(self):
     """get raises IndexError for the indices just outside [0, size)."""
     sum_tree = replay_lib.SumTree()
     size = 3
     sum_tree.resize(size)
     # One index just below and one just past the valid range. subTest makes
     # a failure report which index did not raise.
     for i in [-1, size]:
         with self.subTest(index=i):
             with self.assertRaises(IndexError):
                 sum_tree.get([i])
Exemplo n.º 5
0
 def test_set_updates_total_sum(self):
     """Overwriting one value shifts the root sum by the difference."""
     tree = replay_lib.SumTree()
     initial = [4, 5, 3, 9]
     tree.set_all(initial)
     tree.set([1], [2])  # replaces the value 5 stored at index 1 with 2
     expected_root = sum(initial) - 5 + 2
     self.assertAlmostEqual(expected_root, tree.root())
     tree.check_valid()
Exemplo n.º 6
0
 def test_set_all(self):
   """set_all stores every value and sets size to the number of values."""
   sum_tree = replay_lib.SumTree()
   values = [4., 5., 3.]
   sum_tree.set_all(values)
   self.assertLen(values, sum_tree.size)
   # Check each stored value individually through get().
   for i, value in enumerate(values):
     np.testing.assert_array_almost_equal([value], sum_tree.get([i]))
   sum_tree.check_valid()
Exemplo n.º 7
0
 def test_resizes_preserves_values_when_shrinking(self):
   """Shrinking keeps the leading values that still fit."""
   tree = replay_lib.SumTree()
   values = [4., 5., 3., 8., 2.]
   tree.set_all(values)
   kept = len(values) - 2
   tree.resize(kept)
   for index, expected in enumerate(values[:kept]):
     np.testing.assert_array_almost_equal([expected], tree.get([index]))
   tree.check_valid()
Exemplo n.º 8
0
 def test_resizing_to_size_between_current_size_and_capacity(self):
     """Growing within the existing capacity zero-fills the new slots."""
     sum_tree = replay_lib.SumTree()
     values = [4., 5., 3., 8., 2.]
     sum_tree.set_all(values)
     new_size = 7
     # Precondition: the new size lies strictly between size and capacity.
     # Use test assertions rather than a bare `assert`: the latter is
     # stripped under `python -O` and gives no diagnostic message.
     self.assertLess(sum_tree.size, new_size)
     self.assertLess(new_size, sum_tree.capacity)
     sum_tree.resize(new_size)
     expected = values + [0.] * (new_size - len(values))
     np.testing.assert_allclose(expected, sum_tree.values)
     sum_tree.check_valid()
Exemplo n.º 9
0
 def test_resize_preserves_values_and_zeros_the_rest_when_growing(self):
   """Growing keeps the existing values and zero-fills the new slots."""
   sum_tree = replay_lib.SumTree()
   values = [4., 5., 3.]
   sum_tree.set_all(values)
   new_size = len(values) + 5
   sum_tree.resize(new_size)
   # Original values are preserved at their indices.
   for i, value in enumerate(values):
     np.testing.assert_array_almost_equal([value], sum_tree.get([i]))
   # Every newly added slot reads as zero.
   for i in range(len(values), new_size):
     np.testing.assert_array_almost_equal([0.], sum_tree.get([i]))
   sum_tree.check_valid()
Exemplo n.º 10
0
    def test_set_cannot_add_negative_nan_or_inf_values(self):
        """set rejects negative, NaN and infinite values with ValueError."""
        tree = replay_lib.SumTree()
        tree.set_all([0, 1, 2])
        for invalid in (-1, np.nan, np.inf):
            with self.assertRaises(ValueError):
                tree.set([1], [invalid])
Exemplo n.º 11
0
  def test_with_random_data(self, seed):
    """Drive SumTree and NaiveSumTree with the same seed and compare them."""
    actual_tree = replay_lib.SumTree()
    reference_tree = NaiveSumTree()

    # Both operation generators receive the same seed. Each pair of yielded
    # results is a checkpoint at which the two trees are compared.
    paired_steps = zip(
        random_operations(actual_tree, seed),
        random_operations(reference_tree, seed))
    for actual_result, reference_result in paired_steps:
      both_returned = (actual_result is not None
                       and reference_result is not None)
      if both_returned:
        np.testing.assert_allclose(actual_result, reference_result)
      actual_tree.check_valid()
      self.assertAlmostEqual(reference_tree.root(), actual_tree.root())
      np.testing.assert_allclose(reference_tree.values, actual_tree.values)
Exemplo n.º 12
0
    def test_query_raises_exception_if_target_out_of_range(self):
        """query rejects targets below zero or at/above the total sum."""
        sum_tree = replay_lib.SumTree()
        values = [3., 1., 2., 5.]
        sum_tree.set_all(values)
        total = sum(values)

        for target in (-1., total, total + 1.):
            with self.assertRaises(ValueError):
                sum_tree.query([target])

        # The tree's own root is also an exclusive upper bound.
        with self.assertRaises(ValueError):
            sum_tree.query([sum_tree.root()])
Exemplo n.º 13
0
 def test_get_with_multiple_indexes(self):
     """get returns the values at several indices, in request order."""
     tree = replay_lib.SumTree()
     values = [4., 5., 3., 9.]
     tree.set_all(values)
     wanted = [1, 3]
     expected = [values[k] for k in wanted]
     np.testing.assert_allclose(expected, tree.get(wanted))
Exemplo n.º 14
0
 def test_root_returns_sum(self):
     """root() equals the sum of all stored values."""
     values = [3., 1., 2., 5.]
     tree = replay_lib.SumTree()
     tree.set_all(values)
     expected_total = sum(values)
     self.assertAlmostEqual(expected_total, tree.root())
Exemplo n.º 15
0
 def test_query_never_returns_an_index_with_zero_index(self, target):
     """query never lands on an index whose stored value is zero."""
     tree = replay_lib.SumTree()
     values = np.array([0, 1, 0, 0, 3, 0, 2, 0, 3, 0], dtype=np.float64)
     zero_value_indices = np.flatnonzero(values == 0)
     tree.set_all(values)
     sampled_index = tree.query([target])[0]
     self.assertNotIn(sampled_index, zero_value_indices)
Exemplo n.º 16
0
 def test_query_multiple(self):
     """query maps each target to the index whose cumulative span holds it."""
     tree = replay_lib.SumTree()
     tree.set_all([3., 1., 2., 5.])
     targets = [2.9, 3., 4]
     expected_indices = [0, 1, 2]
     np.testing.assert_array_equal(expected_indices, tree.query(targets))
Exemplo n.º 17
0
 def test_query_typical(self, expected_index, target):
     """A single query target maps to the expected index."""
     sum_tree = replay_lib.SumTree()
     values = [3., 1., 2., 5.]
     sum_tree.set_all(values)
     # query may return an ndarray; compare element-wise instead of relying
     # on the truthiness of a `list == ndarray` comparison.
     np.testing.assert_array_equal([expected_index], sum_tree.query([target]))
Exemplo n.º 18
0
 def test_set_multiple(self):
     """set writes several (unordered) indices in a single call."""
     tree = replay_lib.SumTree()
     tree.set_all([4, 5, 3, 9])
     indices, new_values = [2, 0], [99, 88]
     tree.set(indices, new_values)
     np.testing.assert_allclose([88, 5, 99, 9], tree.values)
Exemplo n.º 19
0
 def test_size_is_correct(self):
     """size starts at zero and then tracks the value passed to resize."""
     tree = replay_lib.SumTree()
     self.assertEqual(0, tree.size)
     target_size = 3
     tree.resize(target_size)
     self.assertEqual(target_size, tree.size)
Exemplo n.º 20
0
 def test_resize_returns_zero_values_initially(self):
     """Slots created by resize on an empty tree all read as zero."""
     sum_tree = replay_lib.SumTree()
     size = 3
     sum_tree.resize(size)
     for i in range(size):
         # get returns an array-like; compare element-wise rather than
         # relying on the truthiness of `0 == array`.
         np.testing.assert_array_equal([0.], sum_tree.get([i]))
Exemplo n.º 21
0
 def test_values_returns_values(self):
     """The values attribute exposes exactly what set_all stored."""
     stored = [4., 5., 3.]
     tree = replay_lib.SumTree()
     tree.set_all(stored)
     np.testing.assert_allclose(stored, tree.values)
Exemplo n.º 22
0
 def test_resize_to_0(self):
     """A tree resized to zero stays valid and reports a NaN root."""
     tree = replay_lib.SumTree()
     tree.resize(0)
     tree.check_valid()
     root_value = tree.root()
     self.assertTrue(np.isnan(root_value))
Exemplo n.º 23
0
 def test_resize_to_1(self):
     """A single freshly allocated slot gives a root of zero."""
     single_slot = 1
     tree = replay_lib.SumTree()
     tree.resize(single_slot)
     tree.check_valid()
     self.assertEqual(0, tree.root())
Exemplo n.º 24
0
 def test_set_single(self):
     """set overwrites one index and leaves the others untouched."""
     tree = replay_lib.SumTree()
     tree.set_all([4, 5, 3, 9])
     tree.set([2], [99])
     expected = [4, 5, 99, 9]
     np.testing.assert_allclose(expected, tree.values)
Exemplo n.º 25
0
 def test_can_create_empty(self):
     """A newly constructed tree is valid, empty and has a NaN root."""
     tree = replay_lib.SumTree()
     tree.check_valid()
     self.assertEqual(0, tree.size)
     empty_root = tree.root()
     self.assertTrue(np.isnan(empty_root))