def test_capacity_greater_or_equal_to_size_and_power_of_2(self):
    """Checks `capacity` after `set_all` with four and then five values.

    A fresh tree is built for each case: four values are expected to give a
    capacity of 4, five values a capacity of 8.
    """
    cases = (
        ([4., 5., 3., 2.], 4),
        ([4., 5., 3., 2., 9], 8),
    )
    for initial_values, expected_capacity in cases:
        tree = replay_lib.SumTree()
        tree.set_all(initial_values)
        self.assertEqual(expected_capacity, tree.capacity)
def test_set_all_cannot_add_negative_nan_or_inf_values(self):
    """Checks that `set_all` raises ValueError for -1, NaN and inf entries.

    Each invalid value is paired with a valid `1` and given to a fresh tree.
    """
    for invalid_value in (-1, np.nan, np.inf):
        with self.assertRaises(ValueError):
            replay_lib.SumTree().set_all([1, invalid_value])
def test_getting_and_setting_state(self):
    """Round-trips a tree through `get_state` / `set_state`.

    The restored tree must pass `check_valid` and match the source tree's
    values, size and capacity.
    """
    source_tree = replay_lib.SumTree()
    source_tree.set_all([4, 5, 3, 9])
    snapshot = source_tree.get_state()

    restored_tree = replay_lib.SumTree()
    restored_tree.set_state(snapshot)
    restored_tree.check_valid()

    np.testing.assert_allclose(restored_tree.values, source_tree.values)
    self.assertEqual(source_tree.size, restored_tree.size)
    self.assertEqual(source_tree.capacity, restored_tree.capacity)
def test_exception_raised_when_index_out_of_bounds_in_get(self):
    """Checks that `get` raises IndexError for indices -1 and `size`."""
    tree_size = 3
    tree = replay_lib.SumTree()
    tree.resize(tree_size)

    # One index just below and one just past the valid range [0, size).
    for bad_index in (-1, tree_size):
        with self.assertRaises(IndexError):
            tree.get([bad_index])
def test_set_updates_total_sum(self):
    """Checks that `root()` reflects a single-element `set`.

    Index 1 (originally 5) is overwritten with 2; the root must equal the
    original sum adjusted by that change, and the tree must stay valid.
    """
    initial_values = [4, 5, 3, 9]
    tree = replay_lib.SumTree()
    tree.set_all(initial_values)

    tree.set([1], [2])

    expected_total = sum(initial_values) - 5 + 2
    self.assertAlmostEqual(expected_total, tree.root())
    tree.check_valid()
def test_set_all(self):
    """Checks that `set_all` stores every value and sets the tree size.

    After `set_all`, the tree's `size` must equal the number of values given,
    each value must be readable back through `get`, and `check_valid` must
    pass.
    """
    sum_tree = replay_lib.SumTree()
    values = [4., 5., 3.]
    sum_tree.set_all(values)

    # The tree's size is the quantity under test, so compare it directly
    # (expected first, as elsewhere in this file) rather than asserting the
    # length of the input fixture.
    self.assertEqual(len(values), sum_tree.size)

    for index, expected in enumerate(values):
        np.testing.assert_array_almost_equal([expected], sum_tree.get([index]))
    sum_tree.check_valid()
def test_resizes_preserves_values_when_shrinking(self):
    """Checks that shrinking by two keeps the leading values readable.

    Five values are set, the tree is resized to three, and the first three
    values must still be returned by `get`; the tree must remain valid.
    """
    initial_values = [4., 5., 3., 8., 2.]
    tree = replay_lib.SumTree()
    tree.set_all(initial_values)

    surviving_values = initial_values[:-2]
    tree.resize(len(surviving_values))

    for index, expected in enumerate(surviving_values):
        np.testing.assert_array_almost_equal([expected], tree.get([index]))
    tree.check_valid()
def test_resizing_to_size_between_current_size_and_capacity(self):
    """Checks growing to a size strictly between the size and the capacity.

    Five values are set and the tree is resized to 7. The existing values
    must be kept, the two new slots must read as zero, and `check_valid`
    must pass.
    """
    sum_tree = replay_lib.SumTree()
    values = [4., 5., 3., 8., 2.]
    sum_tree.set_all(values)
    new_size = 7

    # Precondition for the scenario under test. Uses TestCase assertions
    # rather than a bare `assert`, which is stripped under `python -O` and
    # reports no values on failure.
    self.assertLess(sum_tree.size, new_size)
    self.assertLess(new_size, sum_tree.capacity)

    sum_tree.resize(new_size)
    np.testing.assert_allclose(values + [0., 0.], sum_tree.values)
    sum_tree.check_valid()
def test_resize_preserves_values_and_zeros_the_rest_when_growing(self):
    """Checks that growing keeps existing values and zero-fills new slots.

    Three values are set and the tree is grown by five; `get` must return the
    original values for the first three indices and 0 for the remaining five.
    """
    initial_values = [4., 5., 3.]
    tree = replay_lib.SumTree()
    tree.set_all(initial_values)

    grown_size = len(initial_values) + 5
    tree.resize(grown_size)

    # Original values followed by one zero per newly added slot.
    expected_by_index = initial_values + [0.] * (grown_size - len(initial_values))
    for index, expected in enumerate(expected_by_index):
        np.testing.assert_array_almost_equal([expected], tree.get([index]))
    tree.check_valid()
def test_set_cannot_add_negative_nan_or_inf_values(self):
    """Checks that `set` raises ValueError for -1, NaN and inf.

    Each invalid value is written to index 1 of a tree holding [0, 1, 2].
    """
    tree = replay_lib.SumTree()
    tree.set_all([0, 1, 2])

    for invalid_value in (-1, np.nan, np.inf):
        with self.assertRaises(ValueError):
            tree.set([1], [invalid_value])
def test_with_random_data(self, seed):
    """Compares `SumTree` against `NaiveSumTree` under the same operations.

    Both trees are driven by `random_operations` with the same `seed`. After
    every step, any pair of non-None results is compared, the real tree is
    validated, and both trees must agree on root and values.

    Args:
      seed: Seed passed to `random_operations` for both trees (presumably
        supplied by a parameterized decorator outside this block).
    """
    real_tree = replay_lib.SumTree()
    reference_tree = NaiveSumTree()

    # Randomly perform operations, periodically stopping to compare.
    paired_steps = zip(
        random_operations(real_tree, seed),
        random_operations(reference_tree, seed))
    for real_result, reference_result in paired_steps:
        both_returned = real_result is not None and reference_result is not None
        if both_returned:
            np.testing.assert_allclose(real_result, reference_result)

        real_tree.check_valid()
        self.assertAlmostEqual(reference_tree.root(), real_tree.root())
        np.testing.assert_allclose(reference_tree.values, real_tree.values)
def test_query_raises_exception_if_target_out_of_range(self):
    """Checks that `query` raises ValueError for out-of-range targets.

    Targets tried: -1, the sum of the values, the sum plus one, and the
    tree's own `root()`.
    """
    leaf_values = [3., 1., 2., 5.]
    tree = replay_lib.SumTree()
    tree.set_all(leaf_values)

    total = sum(leaf_values)
    for bad_target in (-1., total, total + 1.):
        with self.assertRaises(ValueError):
            tree.query([bad_target])

    # `root()` is evaluated inside the context manager, after the queries
    # above have run.
    with self.assertRaises(ValueError):
        tree.query([tree.root()])
def test_get_with_multiple_indexes(self):
    """Checks that `get([1, 3])` returns the values at indices 1 and 3."""
    leaf_values = [4., 5., 3., 9.]
    tree = replay_lib.SumTree()
    tree.set_all(leaf_values)

    fetched = tree.get([1, 3])

    expected = [leaf_values[1], leaf_values[3]]
    np.testing.assert_allclose(expected, fetched)
def test_root_returns_sum(self):
    """Checks that `root()` equals the sum of the values given to `set_all`."""
    leaf_values = [3., 1., 2., 5.]
    tree = replay_lib.SumTree()
    tree.set_all(leaf_values)

    expected_total = sum(leaf_values)
    self.assertAlmostEqual(expected_total, tree.root())
def test_query_never_returns_an_index_with_zero_index(self, target):
    """Checks that `query` does not return an index whose value is zero.

    Args:
      target: Query target passed to `query` (presumably supplied by a
        parameterized decorator outside this block).
    """
    leaf_values = np.array([0, 1, 0, 0, 3, 0, 2, 0, 3, 0], dtype=np.float64)
    zero_valued_indices = np.flatnonzero(leaf_values == 0)

    tree = replay_lib.SumTree()
    tree.set_all(leaf_values)

    returned_index = tree.query([target])[0]
    self.assertNotIn(returned_index, zero_valued_indices)
def test_query_multiple(self):
    """Checks `query` with several targets at once.

    With values [3, 1, 2, 5], targets 2.9, 3 and 4 are expected to map to
    indices 0, 1 and 2 respectively.
    """
    tree = replay_lib.SumTree()
    tree.set_all([3., 1., 2., 5.])

    found_indices = tree.query([2.9, 3., 4])

    np.testing.assert_array_equal([0, 1, 2], found_indices)
def test_query_typical(self, expected_index, target):
    """Checks that a single-target `query` returns the expected index.

    Args:
      expected_index: Index the query is expected to return.
      target: Query target passed to `query`.

    Both arguments are presumably supplied by a parameterized decorator
    outside this block.
    """
    tree = replay_lib.SumTree()
    tree.set_all([3., 1., 2., 5.])

    found = tree.query([target])

    self.assertEqual([expected_index], found)
def test_set_multiple(self):
    """Checks that `set` with several indices updates each given position.

    Indices [2, 0] receive [99, 88]; the other values must be unchanged.
    """
    tree = replay_lib.SumTree()
    tree.set_all([4, 5, 3, 9])

    tree.set([2, 0], [99, 88])

    expected_values = [88, 5, 99, 9]
    np.testing.assert_allclose(expected_values, tree.values)
def test_size_is_correct(self):
    """Checks that `size` is 0 for a new tree and 3 after `resize(3)`."""
    tree = replay_lib.SumTree()
    self.assertEqual(0, tree.size)

    requested_size = 3
    tree.resize(requested_size)
    self.assertEqual(requested_size, tree.size)
def test_resize_returns_zero_values_initially(self):
    """Checks that every slot of a freshly resized tree reads as 0."""
    slot_count = 3
    tree = replay_lib.SumTree()
    tree.resize(slot_count)

    for index in range(slot_count):
        stored = tree.get([index])
        self.assertEqual(0, stored)
def test_values_returns_values(self):
    """Checks that `values` matches what was passed to `set_all`."""
    expected_values = [4., 5., 3.]
    tree = replay_lib.SumTree()
    tree.set_all(expected_values)

    np.testing.assert_allclose(expected_values, tree.values)
def test_resize_to_0(self):
    """Checks that a tree resized to 0 is valid and has a NaN root."""
    tree = replay_lib.SumTree()
    tree.resize(0)
    tree.check_valid()

    root_value = tree.root()
    self.assertTrue(np.isnan(root_value))
def test_resize_to_1(self):
    """Checks that a tree resized to 1 is valid and has a root of 0."""
    tree = replay_lib.SumTree()
    tree.resize(1)
    tree.check_valid()

    root_value = tree.root()
    self.assertEqual(0, root_value)
def test_set_single(self):
    """Checks that `set` on one index changes only that position.

    Index 2 is overwritten with 99; the remaining values must be unchanged.
    """
    tree = replay_lib.SumTree()
    tree.set_all([4, 5, 3, 9])

    tree.set([2], [99])

    expected_values = [4, 5, 99, 9]
    np.testing.assert_allclose(expected_values, tree.values)
def test_can_create_empty(self):
    """Checks that a new tree is valid, has size 0 and has a NaN root."""
    tree = replay_lib.SumTree()
    tree.check_valid()

    self.assertEqual(0, tree.size)
    root_value = tree.root()
    self.assertTrue(np.isnan(root_value))