def test_base_buffer_exception(self):
     """BaseBuffer rejects invalid sizes and guards reads before allocation."""
     # a non-positive size raises ValueError
     with self.assertRaises(ValueError):
         ub_data.BaseBuffer(0, 1)
     # a non-positive batch trips an assertion
     with self.assertRaises(AssertionError):
         ub_data.BaseBuffer(1, 0)
     buffer = ub_data.BaseBuffer(1, 1)
     self.assertTrue(buffer.isnull)
     # reading before any space exists -> "Buffer space not created"
     with self.assertRaises(AssertionError):
         buffer[0]
     # writing allocates the space on demand, after which a read succeeds
     buffer._set_data(1, indices=0)
     buffer[0]
 def test_base_buffer_relative_index(self):
     """``buf.rel`` indexes a wrapped buffer with offsets counted from head.

     The buffer is filled past capacity so it wraps. ``rel`` is then checked
     with int, slice, tuple, list and ndarray keys. An out-of-range relative
     index must raise ``IndexError``.
     """
     capacity = 10
     batch = 1
     n_samples = 15  # more adds than capacity: exercises wrap-around
     buf = ub_data.BaseBuffer(capacity, batch=batch)
     for step in range(n_samples):
         buf.add({'a': ([step], [step + 1])})
     head = n_samples % capacity
     self.assertEqual(head, buf.head)
     self.assertEqual(head, buf.tail)

     def expect(key, first, second):
         # every stored item is a pair ([i], [i + 1]); check both members
         field = buf.rel[key]['a']
         self.assertArrayEqual(first, field[0])
         self.assertArrayEqual(second, field[1])

     def expect_shape(key, shape):
         field = buf.rel[key]['a']
         self.assertEqual(shape, field[0].shape)
         self.assertEqual(shape, field[1].shape)

     ascending = np.arange(2)
     descending = np.arange(2, 0, -1)
     up = ascending.reshape(-1, 1)
     down = descending.reshape(-1, 1)
     # int / slice keys: the trailing batch axis is kept
     expect(np.s_[1], [head + 1], [head + 2])
     expect(np.s_[-1], [n_samples - 1], [n_samples])
     expect(np.s_[1:3], up + head + 1, up + head + 2)
     expect(np.s_[-3:-1], n_samples - 1 - down, n_samples - down)
     expect_shape(np.s_[-1:1], (0, 1))
     # tuple keys: the second component indexes into the batch axis
     expect(np.s_[1, 0], head + 1, head + 2)
     expect(np.s_[-1, 0], n_samples - 1, n_samples)
     expect(np.s_[1:3, 0], ascending + head + 1, ascending + head + 2)
     expect(np.s_[-3:-1, 0],
            n_samples - 1 - descending,
            n_samples - descending)
     expect_shape(np.s_[-1:1, 0], (0, ))
     # list and ndarray keys pick individual relative positions
     picked0 = [[head + 1], [head + 3]]
     picked1 = [[head + 2], [head + 4]]
     expect([1, 3], picked0, picked1)
     expect(np.asarray([1, 3]), picked0, picked1)
     # relative offsets beyond the capacity are rejected
     with self.assertRaises(IndexError):
         buf.rel[capacity + 1]
 def test_sampler_exception(self):
     """Samplers validate the type of buffer they are given."""
     # Uniform/Permute samplers: `buffer` must be an instance of BaseBuffer
     for sampler_cls in (ub_data.UniformSampler, ub_data.PermuteSampler):
         with self.assertRaises(ValueError):
             sampler_cls(None)
     # PrioritizedSampler: `buffer` must be a ReplayBuffer. The buffers are
     # built inside the context so construction stays under assertRaises.
     for make_buffer in (lambda: ub_data.BaseBuffer(10),
                         lambda: ub_data.DynamicBuffer()):
         with self.assertRaises(ValueError):
             ub_data.PrioritizedSampler(make_buffer(), 1.0)
 def test_base_buffer_auto_calc_space(self):
     """BaseBuffer geometry is fixed up front or inferred on the first add.

     With an explicit ``batch``, capacity, slot count and batch size are all
     known at construction. With ``batch=None`` they stay ``None`` until the
     first ``add``. That add infers ``batch`` from the data and sets
     ``slots = ceil(capacity / batch)``.
     """
     capacity = 10
     batch = 1
     buf = ub_data.BaseBuffer(capacity, batch=batch)
     self.assertEqual(0, len(buf))
     self.assertEqual(0, buf.len_slots())
     self.assertEqual(capacity, buf.capacity)
     self.assertEqual(capacity, buf.slots)
     self.assertEqual(batch, buf.batch)
     self.assertEqual(0, buf.head)
     self.assertEqual(0, buf.tail)
     self.assertTrue(buf.isnull)
     self.assertFalse(buf.isfull)
     self.assertTrue(buf.ready_for_sample)
     # batch=None: sizing is deferred, so the geometry is still unknown
     buf = ub_data.BaseBuffer(capacity, batch=None)
     self.assertEqual(0, len(buf))
     self.assertEqual(0, buf.len_slots())
     self.assertIsNone(buf.capacity)
     self.assertIsNone(buf.slots)
     self.assertIsNone(buf.batch)
     self.assertEqual(0, buf.head)
     self.assertEqual(0, buf.tail)
     self.assertTrue(buf.isnull)
     self.assertFalse(buf.isfull)
     self.assertTrue(buf.ready_for_sample)
     # a two-element add fixes batch=2 and derives the number of slots
     buf.add({'a': [0, 1]})
     self.assertEqual(2, len(buf))
     self.assertEqual(1, buf.len_slots())
     self.assertEqual(capacity, buf.capacity)
     self.assertEqual(math.ceil(capacity / 2), buf.slots)
     self.assertEqual(2, buf.batch)
     self.assertEqual(0, buf.head)
     self.assertEqual(1, buf.tail)
     self.assertFalse(buf.isnull)
     self.assertFalse(buf.isfull)
     self.assertTrue(buf.ready_for_sample)
 def test_base_buffer_shape(self):
     """Absolute and relative indexing yield the same result shapes.

     Each ``add`` stores ``batch`` items in one slot. An int key returns
     ``(batch,)``, a slice adds a leading slot axis, and a second key
     component indexes into the batch axis.
     """
     capacity = 10
     batch = 3
     n_samples = 15
     buf = ub_data.BaseBuffer(capacity, batch=batch)
     for _ in range(n_samples):
         buf.add({'a': np.arange(batch)})
     # shapes are plain tuples: compare them exactly with assertEqual
     # absolute (slot) indexing
     self.assertEqual((batch, ), buf[1]['a'].shape)
     self.assertEqual((2, batch), buf[1:3]['a'].shape)
     self.assertEqual((2, ), buf[1:3, 0]['a'].shape)
     self.assertEqual((2, 2), buf[1:3, :2]['a'].shape)
     # relative indexing must agree with absolute indexing on shapes
     self.assertEqual((batch, ), buf.rel[1]['a'].shape)
     self.assertEqual((2, batch), buf.rel[1:3]['a'].shape)
     self.assertEqual((2, ), buf.rel[1:3, 0]['a'].shape)
     self.assertEqual((2, 2), buf.rel[1:3, :2]['a'].shape)
 def test_uniform_sampler_with_base_buffer(self):
     """UniformSampler samples from a wrapped BaseBuffer and writes back.

     The buffer is filled past capacity. It is then sampled three ways: with
     the default batch size, with an explicit ``batch_size``, and with
     ``seq_len``. Each sample must have the expected shape and equal the
     buffer at ``samp.indices``. Finally, ``samp.update`` must write the
     (zeroed) sample back to those indices.
     """
     capacity = 10
     batch = 1
     n_samples = 15
     buf = ub_data.BaseBuffer(capacity, batch=batch)
     samp = ub_data.UniformSampler(buf)
     for i in range(n_samples):
         buf.add({'a': ([i], [i + 1])})
         if i < capacity - 1:
             self.assertFalse(buf.isfull)
             self.assertEqual(i + 1, len(buf))
         else:
             self.assertTrue(buf.isfull)
             self.assertEqual(capacity, len(buf))
     # only the newest `capacity` items survive; the roll puts the
     # overwritten slots (holding the newest items) at the front
     exp = np.arange(n_samples - capacity, n_samples)
     exp_a0 = np.roll(exp, n_samples % capacity)
     exp_a1 = exp_a0 + 1
     exp_a0 = np.expand_dims(exp_a0, axis=-1)
     exp_a1 = np.expand_dims(exp_a1, axis=-1)
     # (expected, actual) argument order
     self.assertArrayEqual(exp_a0, buf.data['a'][0])
     self.assertArrayEqual(exp_a1, buf.data['a'][1])

     def check_sample(sample, shape):
         # both fields must have `shape` and equal the buffer contents at
         # the indices the sampler drew for this sample
         self.assertEqual(shape, sample['a'][0].shape)
         self.assertEqual(shape, sample['a'][1].shape)
         self.assertArrayEqual(sample['a'][0], buf[samp.indices]['a'][0])
         self.assertArrayEqual(sample['a'][1], buf[samp.indices]['a'][1])

     # default batch size: one item per stored element
     check_sample(samp(), (capacity, ))
     # explicit batch size
     batch_size = 3
     check_sample(samp(batch_size=batch_size), (batch_size, ))
     # batch of sequences
     seq_len = 2
     sample = samp(batch_size=batch_size, seq_len=seq_len)
     check_sample(sample, (batch_size, seq_len))
     # zero the last sample and push it back through the sampler
     sample['a'] = (np.zeros_like(sample['a'][0]),
                    np.zeros_like(sample['a'][1]))
     samp.update(sample)
     self.assertArrayEqual(buf[samp.indices]['a'][0], sample['a'][0])
     self.assertArrayEqual(buf[samp.indices]['a'][1], sample['a'][1])
 def test_uniform_sampler_with_base_buffer_rel(self):
     """``samp.rel`` indexes the buffer relative to the last sampled indices.

     The sampler's cached indices are set by hand so the expected absolute
     indices are known. An int key shifts every sampled slot index. A slice
     key stacks the shifted indices along a new leading axis. Assigning
     through ``samp.rel`` writes back into the buffer.
     """
     capacity = 10
     batch = 1
     n_samples = 15
     buf = ub_data.BaseBuffer(capacity, batch=batch)
     samp = ub_data.UniformSampler(buf)
     for i in range(n_samples):
         buf.add({'a': ([i], [i + 1])})
     # pretend the sampler just drew slots 0, 1, 2 in batch column 0
     slot_inds = np.arange(3)
     batch_inds = np.zeros(3, dtype=np.int64)
     samp._cached_inds = (slot_inds, batch_inds)

     def check(rel_key, abs_slots):
         # samp.rel[rel_key] must equal the buffer at the absolute slots
         expected = buf[(abs_slots, batch_inds)]['a']
         actual = samp.rel[rel_key]['a']
         self.assertArrayEqual(expected[0], actual[0])
         self.assertArrayEqual(expected[1], actual[1])

     # offset 0 is the sample itself
     self.assertArrayEqual(buf[samp.indices]['a'][0], samp.rel[0]['a'][0])
     self.assertArrayEqual(buf[samp.indices]['a'][1], samp.rel[0]['a'][1])
     # int keys shift every sampled slot
     check(-3, slot_inds - 3)
     check(3, slot_inds + 3)
     # slice keys: a (2, 1) column of offsets broadcasts against the
     # (3,) sampled slots, giving a (2, 3) index grid
     check(np.s_[1:3], slot_inds + np.array([[1], [2]], dtype=np.int64))
     check(np.s_[-3:-1], slot_inds + np.array([[-3], [-2]], dtype=np.int64))
     # an empty relative window keeps the sample axis
     self.assertEqual((0, 3), samp.rel[-1:1]['a'][0].shape)
     # setitem through samp.rel writes back to the buffer
     window = samp.rel[1:3]
     window['a'] = (np.zeros_like(window['a'][0]),
                    np.zeros_like(window['a'][1]))
     samp.rel[1:3] = window
     self.assertTrue(np.all(samp.rel[1:3]['a'][0] == 0))
     self.assertTrue(np.all(samp.rel[1:3]['a'][1] == 0))
    def test_base_buffer_multidim(self):
        """BaseBuffer stores multi-dimensional items per (slot, batch) cell.

        Each ``add`` writes one ``(batch, dim)`` array into one slot. The
        buffer has ``capacity // batch`` slots and wraps once the slots are
        used up. ``ravel_index``/``unravel_index`` must agree with NumPy
        over the ``(slots, batch)`` grid.
        """
        capacity = 20
        batch = 2
        dim = 2
        n_samples = 15  # more adds than slots: exercises wrap-around
        n_slots = capacity // batch
        buf = ub_data.BaseBuffer(capacity, batch=batch)
        data = np.arange(n_samples * batch * dim).reshape(
            (n_samples, batch, dim))
        for i in range(n_samples):
            buf.add({'a': data[i]})
            if (i + 1) * batch < capacity:
                self.assertFalse(buf.isfull)
                self.assertEqual((i + 1) * batch, len(buf))
                self.assertEqual(i + 1, buf.len_slots())
                self.assertEqual(0, buf.head)
            else:
                self.assertTrue(buf.isfull)
                self.assertEqual(capacity, len(buf))
                self.assertEqual(n_slots, buf.len_slots())
            self.assertEqual((i + 1) % n_slots, buf.tail)
        # the last n_slots adds survive; roll so the overwritten slots
        # (holding the newest adds) come first, matching storage order
        exp = np.arange(n_samples * batch * dim - capacity * dim,
                        n_samples * batch * dim)
        # derive the item shape from batch/dim instead of hard-coding (2, 2)
        exp = exp.reshape(-1, batch, dim)
        exp = np.roll(exp, n_samples % n_slots, axis=0)
        self.assertArrayEqual(exp, buf.data['a'])

        # ravel/unravel index must match NumPy on the (slots, batch) grid
        def check_ravel(indices):
            self.assertArrayEqual(
                np.ravel_multi_index(indices, (buf.slots, buf.batch)),
                buf.ravel_index(indices))

        check_ravel(([1, 2, 3], 0))
        check_ravel(([[1], [2], [3]], [0, 1]))

        def check_unravel(indices):
            self.assertArrayEqual(
                np.unravel_index(indices, (buf.slots, buf.batch)),
                buf.unravel_index(indices))

        check_unravel([4, 5, 6])
        check_unravel(7)
 def test_permute_sampler_with_base_buffer(self):
     """One pass of PermuteSampler covers every stored element.

     Three configurations are tested: the whole buffer, ``batch_size=3``,
     and ``batch_size=3, seq_len=2``. For each, the sampler is iterated
     once. Every yielded batch must have the requested shape and match the
     buffer at ``samp.indices``. The pass must touch every element of the
     buffer, and no element may be drawn more often than the configuration
     allows.
     """
     capacity = 10
     batch = 1
     n_samples = 15
     buf = ub_data.BaseBuffer(capacity, batch=batch)
     samp = ub_data.PermuteSampler(buf)
     for i in range(n_samples):
         buf.add({'a': ([i], [i + 1])})

     def run_pass(shape, **sample_kwargs):
         # Iterate the sampler once, checking every batch, and return
         # (number of batches, per-element draw counts).
         n_batches = 0
         drawn = []
         for sample in samp(**sample_kwargs):
             self.assertEqual(shape, sample['a'][0].shape)
             self.assertEqual(shape, sample['a'][1].shape)
             self.assertArrayEqual(buf[samp.indices]['a'][0], sample['a'][0])
             self.assertArrayEqual(buf[samp.indices]['a'][1], sample['a'][1])
             n_batches += 1
             flat = np.ravel_multi_index(samp.indices,
                                         (buf.len_slots(), buf.batch))
             drawn.append(np.ravel(flat))
         # concatenate flattened indices so ragged batches are tolerated
         all_drawn = (np.concatenate(drawn) if drawn
                      else np.empty(0, dtype=np.intp))
         unique, counts = np.unique(all_drawn, return_counts=True)
         # every stored element must appear at least once; this must be
         # assertEqual - assertTrue(a, b) treats b as a message and never
         # fails for a non-empty buffer
         self.assertEqual(len(buf), len(unique))
         return n_batches, counts

     # batch_size=None: a single batch holding every element exactly once
     n_batches, counts = run_pass((len(buf), ))
     self.assertEqual(1, n_batches)
     self.assertTrue(np.all(counts == 1))
     # batch_size=3: 4 batches (12 draws) over 10 elements, so some
     # elements are drawn twice, but never more often than that
     batch_size = 3
     n_batches, counts = run_pass((batch_size, ), batch_size=batch_size)
     self.assertEqual(4, n_batches)
     self.assertTrue(np.all(counts <= 2))
     # batch_size=3, seq_len=2: each draw is a length-2 window, so an
     # element can appear in several windows; at most 4 times allowed
     seq_len = 2
     n_batches, counts = run_pass((batch_size, seq_len),
                                  batch_size=batch_size, seq_len=seq_len)
     self.assertEqual(3, n_batches)
     self.assertTrue(np.all(counts <= 4))
    def test_base_buffer(self):
        """Exercise a batch-1 BaseBuffer end to end.

        Covers construction state, circular ``add`` past capacity,
        ``__getitem__``/``__setitem__``/``update`` with index arrays, and
        agreement of ``ravel_index``/``unravel_index`` with NumPy.
        """
        capacity = 10
        batch = 1
        n_samples = 15  # more adds than capacity: wraps around
        wrapped = n_samples % capacity  # slots overwritten on the 2nd lap
        buf = ub_data.BaseBuffer(capacity, batch=batch)
        self.assertEqual(capacity, buf.capacity)
        self.assertEqual(capacity, buf.slots)
        self.assertEqual(batch, buf.batch)
        self.assertEqual(0, buf.head)
        self.assertEqual(0, buf.tail)
        self.assertTrue(buf.isnull)
        self.assertFalse(buf.isfull)
        self.assertTrue(buf.ready_for_sample)
        for step in range(n_samples):
            buf.add({'a': ([step], [step + 1])})
            filled = step + 1
            if filled < capacity:
                self.assertFalse(buf.isfull)
                self.assertEqual(filled, len(buf))
                self.assertEqual(filled, buf.len_slots())
                self.assertEqual(0, buf.head)
            else:
                self.assertTrue(buf.isfull)
                self.assertEqual(capacity, len(buf))
                self.assertEqual(capacity, buf.len_slots())
            self.assertEqual(filled % capacity, buf.tail)

        def column_pair(values):
            # ([i], [i + 1]) items are stored as (N, 1) columns
            first = np.expand_dims(values, axis=-1)
            return first, first + 1

        # the newest `wrapped` values overwrote the front of the storage
        survivors = np.roll(np.arange(n_samples - capacity, n_samples),
                            wrapped)
        col0, col1 = column_pair(survivors)
        self.assertArrayEqual(col0, buf.data['a'][0])
        self.assertArrayEqual(col1, buf.data['a'][1])
        # getitem: the front slots return the newest values
        got = buf[np.arange(wrapped)]
        col0, col1 = column_pair(np.arange(n_samples - wrapped, n_samples))
        self.assertArrayEqual(col0, got['a'][0])
        self.assertArrayEqual(col1, got['a'][1])
        # setitem: overwrite the front with the values that precede the
        # survivors, so storage holds one contiguous ascending run
        lap_start = n_samples - capacity
        older0, older1 = column_pair(np.arange(lap_start - wrapped, lap_start))
        new_data = {'a': (older0, older1)}
        buf[np.arange(wrapped)] = new_data
        run_start = lap_start - wrapped
        col0, col1 = column_pair(np.arange(run_start, run_start + capacity))
        self.assertArrayEqual(col0, buf.data['a'][0])
        self.assertArrayEqual(col1, buf.data['a'][1])
        # update() with the same data and indices yields the same contents
        buf.update(new_data, indices=np.arange(wrapped))
        self.assertArrayEqual(col0, buf.data['a'][0])
        self.assertArrayEqual(col1, buf.data['a'][1])

        # ravel/unravel index must match NumPy on the (slots, batch) grid
        def check_ravel(indices):
            self.assertArrayEqual(
                np.ravel_multi_index(indices, (buf.slots, buf.batch)),
                buf.ravel_index(indices))

        def check_unravel(indices):
            self.assertArrayEqual(
                np.unravel_index(indices, (buf.slots, buf.batch)),
                buf.unravel_index(indices))

        for multi_index in (([1, 2, 3], 0), ([1, 2, 3], [0])):
            check_ravel(multi_index)
        for flat_index in ([4, 5, 6], 7):
            check_unravel(flat_index)