示例#1
0
class FlexibleReplayPoolTest(unittest.TestCase):
    """Tests for ``FlexibleReplayPool``: storage, sampling and pickling."""

    def setUp(self):
        """Create an empty 10-slot pool with two ``(1,)`` float32 fields."""
        self.pool = FlexibleReplayPool(max_size=10,
                                       fields={
                                           'field1': {
                                               'shape': (1, ),
                                               'dtype': 'float32'
                                           },
                                           'field2': {
                                               'shape': (1, ),
                                               'dtype': 'float32'
                                           },
                                       })

    # ------------------------------------------------------------------
    # Private helpers (not collected by unittest: no ``test`` prefix).
    # ------------------------------------------------------------------

    @staticmethod
    def _add_random_samples(pool, num_samples):
        """Add ``num_samples`` uniform-random rows to every field of pool."""
        pool.add_samples(
            num_samples, **{
                field_name:
                np.random.uniform(0, 1, (num_samples, *field_attrs['shape']))
                for field_name, field_attrs in pool.fields.items()
            })

    def _sequential_samples(self):
        """Return one deterministic sample per slot of ``self.pool``.

        ``field1`` holds ``0..max_size-1`` and ``field2`` holds the same
        values multiplied by ``-2``; both are column vectors.
        """
        indices = np.arange(self.pool._max_size)[:, None]
        return {
            'field1': indices,
            'field2': -indices * 2,
        }

    def _assert_pickle_round_trip(self, pool):
        """Pickle and unpickle ``pool``; check the copy matches the original.

        Every entry of the restored ``__dict__`` and every field array is
        compared with the original, and the restored object must be a
        distinct instance.

        Returns:
            The unpickled pool, so callers can make further assertions.
        """
        restored = pickle.loads(pickle.dumps(pool))

        for key in restored.__dict__:
            np.testing.assert_array_equal(pool.__dict__[key],
                                          restored.__dict__[key])

        self.assertIsNot(pool, restored)

        for field_name in pool.fields:
            np.testing.assert_array_equal(getattr(pool, field_name),
                                          getattr(restored, field_name))

        return restored

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_multi_dimensional_field(self):
        """A pool with a ``(1, 3)``-shaped field survives a pickle round trip."""
        pool = FlexibleReplayPool(max_size=10,
                                  fields={
                                      'field1': {
                                          'shape': (1, 3),
                                          'dtype': 'float32'
                                      },
                                      'field2': {
                                          'shape': (1, ),
                                          'dtype': 'float32'
                                      },
                                  })
        # Fill half of the pool with random data.
        num_samples = pool._max_size // 2
        self._add_random_samples(pool, num_samples)

        self.assertEqual(pool._size, num_samples)

        restored = self._assert_pickle_round_trip(pool)
        self.assertEqual(restored._size, num_samples)

    def test_field_initialization(self):
        """Fresh fields have the declared shape and dtype and are all zero."""
        for field_name, field_attrs in self.pool.fields.items():
            field_values = getattr(self.pool, field_name)
            self.assertEqual(field_values.shape,
                             (self.pool._max_size, *field_attrs['shape']))
            self.assertEqual(field_values.dtype.name, field_attrs['dtype'])

            np.testing.assert_array_equal(field_values, 0.0)

    def test_serialize_deserialize_full(self):
        """A completely filled pool survives a pickle round trip."""
        self._add_random_samples(self.pool, self.pool._max_size)

        self.assertEqual(self.pool._size, self.pool._max_size)

        restored = self._assert_pickle_round_trip(self.pool)
        self.assertEqual(restored._size, restored._max_size)

    def test_serialize_deserialize_not_full(self):
        """A half-filled pool survives a pickle round trip."""
        num_samples = self.pool._max_size // 2
        self._add_random_samples(self.pool, num_samples)

        self.assertEqual(self.pool._size, num_samples)

        restored = self._assert_pickle_round_trip(self.pool)
        self.assertEqual(restored._size, num_samples)

    def test_serialize_deserialize_empty(self):
        """An untouched (empty, all-zero) pool survives a pickle round trip."""
        self.assertEqual(self.pool._size, 0)
        for field_name in self.pool.field_names:
            np.testing.assert_array_equal(getattr(self.pool, field_name), 0.0)

        restored = self._assert_pickle_round_trip(self.pool)
        self.assertEqual(restored._size, 0)

    def test_add_sample(self):
        """Samples added one at a time land in consecutive slots."""
        for value in range(self.pool._max_size):
            sample = {
                'field1': np.array([[value]]),
                'field2': np.array([[-value * 2]]),
            }
            self.pool.add_sample(**sample)

        expected = self._sequential_samples()
        np.testing.assert_array_equal(self.pool.field1, expected['field1'])
        np.testing.assert_array_equal(self.pool.field2, expected['field2'])

    def test_add_samples(self):
        """A batch that fills the pool is stored in order."""
        self.pool.add_samples(num_samples=self.pool._max_size,
                              **self._sequential_samples())

        # Build the expectation afresh so it cannot alias the arrays that
        # were handed to the pool.
        expected = self._sequential_samples()
        np.testing.assert_array_equal(self.pool.field1, expected['field1'])
        np.testing.assert_array_equal(self.pool.field2, expected['field2'])

    def test_random_indices(self):
        """``random_indices`` is empty for an empty pool, in-range otherwise."""
        empty_pool_indices = self.pool.random_indices(4)
        self.assertEqual(empty_pool_indices.shape, (0, ))
        self.assertEqual(empty_pool_indices.dtype, np.int64)

        self.pool.add_samples(num_samples=self.pool._max_size,
                              **self._sequential_samples())
        full_pool_indices = self.pool.random_indices(4)
        self.assertEqual(full_pool_indices.shape, (4, ))
        self.assertTrue(np.all(full_pool_indices < self.pool.size))
        self.assertTrue(np.all(full_pool_indices >= 0))

    def test_random_batch(self):
        """``random_batch`` is empty for an empty pool, else stored rows."""
        empty_pool_batch = self.pool.random_batch(4)
        for values in empty_pool_batch.values():
            self.assertEqual(values.size, 0)

        self.pool.add_samples(num_samples=self.pool._max_size,
                              **self._sequential_samples())
        full_pool_batch = self.pool.random_batch(4)

        for values in full_pool_batch.values():
            self.assertEqual(values.shape, (4, 1))

        # field1 was filled with 0..max_size-1.
        self.assertTrue(
            np.all(full_pool_batch['field1'] < self.pool._max_size))
        self.assertTrue(np.all(full_pool_batch['field1'] >= 0))

        # field2 was filled with non-positive even numbers.
        self.assertTrue(np.all(full_pool_batch['field2'] % 2 == 0))
        self.assertTrue(np.all(full_pool_batch['field2'] <= 0))

    def test_last_n_batch(self):
        """``last_n_batch`` returns the most recently added rows, in order."""
        empty_pool_batch = self.pool.last_n_batch(4)
        for values in empty_pool_batch.values():
            self.assertEqual(values.size, 0)

        samples = self._sequential_samples()
        self.pool.add_samples(num_samples=self.pool._max_size, **samples)
        full_pool_batch = self.pool.last_n_batch(4)

        for key, values in full_pool_batch.items():
            np.testing.assert_array_equal(samples[key][-4:], values)
            self.assertEqual(values.shape, (4, 1))

    def test_batch_by_indices(self):
        """``batch_by_indices`` returns rows in the requested index order."""
        # On the still-empty pool these indices must be rejected.
        with self.assertRaises(ValueError):
            self.pool.batch_by_indices(np.array([-1, 2, 4]))

        samples = self._sequential_samples()
        self.pool.add_samples(num_samples=self.pool._max_size, **samples)

        batch = self.pool.batch_by_indices(
            np.flip(np.arange(self.pool._max_size)))
        for key, values in batch.items():
            np.testing.assert_array_equal(np.flip(samples[key]), values)
            self.assertEqual(values.shape, (self.pool._max_size, 1))