class FlexibleReplayPoolTest(unittest.TestCase):
    """Unit tests for ``FlexibleReplayPool``.

    Covers field allocation, single/batched sample insertion, pickling
    round-trips, and the index/batch sampling helpers.
    """

    def setUp(self):
        # Shared fixture: a 10-row pool with two float32 fields of shape (1,).
        self.pool = FlexibleReplayPool(max_size=10, fields={
            'field1': {
                'shape': (1, ),
                'dtype': 'float32'
            },
            'field2': {
                'shape': (1, ),
                'dtype': 'float32'
            },
        })

    # ------------------------------------------------------------------
    # Helpers. The leading underscore keeps unittest from collecting them.
    # ------------------------------------------------------------------

    @staticmethod
    def _add_random_samples(pool, num_samples):
        """Add ``num_samples`` rows of uniform [0, 1) data to every field.

        Args:
            pool: The pool to fill; its ``fields`` mapping provides each
                field's per-row ``shape``.
            num_samples: Number of rows to add.
        """
        pool.add_samples(
            num_samples,
            **{
                field_name: np.random.uniform(
                    0, 1, (num_samples, *field_attrs['shape']))
                for field_name, field_attrs in pool.fields.items()
            })

    def _sequential_samples(self):
        """Return one full pool's worth of deterministic samples.

        ``field1`` holds ``0 .. max_size - 1`` and ``field2`` holds
        ``-2 * field1``; both have shape ``(max_size, 1)``.
        """
        return {
            'field1': np.arange(self.pool._max_size)[:, None],
            'field2': -np.arange(self.pool._max_size)[:, None] * 2,
        }

    def _assert_pickle_round_trip(self, pool, expected_size):
        """Pickle and unpickle ``pool``, then check the copy matches it.

        Args:
            pool: The pool to round-trip.
            expected_size: The ``_size`` the unpickled copy must report.
        """
        deserialized = pickle.loads(pickle.dumps(pool))

        # Every attribute restored on the copy must equal the original's.
        for key in deserialized.__dict__:
            np.testing.assert_array_equal(
                pool.__dict__[key], deserialized.__dict__[key])

        # The copy must be a distinct object, not the original itself.
        self.assertIsNot(pool, deserialized)
        self.assertEqual(deserialized._size, expected_size)

        for field_name in pool.fields:
            np.testing.assert_array_equal(
                getattr(pool, field_name), getattr(deserialized, field_name))

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_multi_dimensional_field(self):
        """A pool with a multi-dimensional field survives pickling."""
        pool = FlexibleReplayPool(max_size=10, fields={
            'field1': {
                'shape': (1, 3),
                'dtype': 'float32'
            },
            'field2': {
                'shape': (1, ),
                'dtype': 'float32'
            },
        })
        num_samples = pool._max_size // 2
        # Fill half of the pool with random data.
        self._add_random_samples(pool, num_samples)
        self.assertEqual(pool._size, num_samples)

        self._assert_pickle_round_trip(pool, num_samples)

    def test_field_initialization(self):
        """Fields are allocated as (max_size, *shape) arrays of zeros."""
        # No data is added here: this checks the freshly constructed pool.
        for field_name, field_attrs in self.pool.fields.items():
            field_values = getattr(self.pool, field_name)

            self.assertEqual(
                field_values.shape,
                (self.pool._max_size, *field_attrs['shape']))
            self.assertEqual(field_values.dtype.name, field_attrs['dtype'])
            np.testing.assert_array_equal(field_values, 0.0)

    def test_serialize_deserialize_full(self):
        """A completely filled pool survives pickling."""
        self._add_random_samples(self.pool, self.pool._max_size)
        self.assertEqual(self.pool._size, self.pool._max_size)

        self._assert_pickle_round_trip(self.pool, self.pool._max_size)

    def test_serialize_deserialize_not_full(self):
        """A half-filled pool survives pickling."""
        num_samples = self.pool._max_size // 2
        self._add_random_samples(self.pool, num_samples)
        self.assertEqual(self.pool._size, num_samples)

        self._assert_pickle_round_trip(self.pool, num_samples)

    def test_serialize_deserialize_empty(self):
        """An empty, zero-filled pool survives pickling."""
        # Nothing is added: confirm the pool is empty and still all zeros.
        self.assertEqual(self.pool._size, 0)
        for field_name in self.pool.field_names:
            np.testing.assert_array_equal(getattr(self.pool, field_name), 0.0)

        self._assert_pickle_round_trip(self.pool, 0)

    def test_add_sample(self):
        """Samples added one at a time land in consecutive rows."""
        for value in range(self.pool._max_size):
            sample = {
                'field1': np.array([[value]]),
                'field2': np.array([[-value * 2]]),
            }
            self.pool.add_sample(**sample)

        np.testing.assert_array_equal(
            self.pool.field1, np.arange(self.pool._max_size)[:, None])
        np.testing.assert_array_equal(
            self.pool.field2, -np.arange(self.pool._max_size)[:, None] * 2)

    def test_add_samples(self):
        """A batch of samples added at once lands in consecutive rows."""
        samples = self._sequential_samples()
        self.pool.add_samples(num_samples=self.pool._max_size, **samples)

        np.testing.assert_array_equal(
            self.pool.field1, np.arange(self.pool._max_size)[:, None])
        np.testing.assert_array_equal(
            self.pool.field2, -np.arange(self.pool._max_size)[:, None] * 2)

    def test_random_indices(self):
        """random_indices is empty for an empty pool and in-range otherwise."""
        empty_pool_indices = self.pool.random_indices(4)
        self.assertEqual(empty_pool_indices.shape, (0, ))
        self.assertEqual(empty_pool_indices.dtype, np.int64)

        self.pool.add_samples(
            num_samples=self.pool._max_size, **self._sequential_samples())

        full_pool_indices = self.pool.random_indices(4)
        self.assertEqual(full_pool_indices.shape, (4, ))
        self.assertTrue(np.all(full_pool_indices < self.pool.size))
        self.assertTrue(np.all(full_pool_indices >= 0))

    def test_random_batch(self):
        """random_batch is empty for an empty pool and samples stored rows."""
        empty_pool_batch = self.pool.random_batch(4)
        for values in empty_pool_batch.values():
            self.assertEqual(values.size, 0)

        self.pool.add_samples(
            num_samples=self.pool._max_size, **self._sequential_samples())

        full_pool_batch = self.pool.random_batch(4)
        for values in full_pool_batch.values():
            self.assertEqual(values.shape, (4, 1))

        # Values must come from the stored data: field1 in [0, max_size),
        # field2 non-positive and even.
        self.assertTrue(
            np.all(full_pool_batch['field1'] < self.pool._max_size))
        self.assertTrue(np.all(full_pool_batch['field1'] >= 0))
        self.assertTrue(np.all(full_pool_batch['field2'] % 2 == 0))
        self.assertTrue(np.all(full_pool_batch['field2'] <= 0))

    def test_last_n_batch(self):
        """last_n_batch returns the most recently added rows."""
        empty_pool_batch = self.pool.last_n_batch(4)
        for values in empty_pool_batch.values():
            self.assertEqual(values.size, 0)

        samples = self._sequential_samples()
        self.pool.add_samples(num_samples=self.pool._max_size, **samples)

        full_pool_batch = self.pool.last_n_batch(4)
        for key, values in full_pool_batch.items():
            np.testing.assert_array_equal(samples[key][-4:], values)
            self.assertEqual(values.shape, (4, 1))

    def test_batch_by_indices(self):
        """batch_by_indices rejects invalid indices and honours order."""
        # Indices that are invalid for the current (empty) pool must raise.
        with self.assertRaises(ValueError):
            self.pool.batch_by_indices(np.array([-1, 2, 4]))

        samples = self._sequential_samples()
        self.pool.add_samples(num_samples=self.pool._max_size, **samples)

        # Request every row in reverse order.
        batch = self.pool.batch_by_indices(
            np.arange(self.pool._max_size)[::-1])
        for key, values in batch.items():
            # Reverse along the row axis only; np.flip without an axis would
            # also flip the trailing feature axis.
            np.testing.assert_array_equal(samples[key][::-1], values)
            self.assertEqual(values.shape, (self.pool._max_size, 1))