class TestBufferSampling(TestCase):
    """Sampling tests for ``Buffer``.

    NOTE(review): this class used to be named ``TestBuffer`` as well. A second
    ``class TestBuffer`` defined later in the module rebinds that name, so this
    class was shadowed. Test runners discover test classes by looking up
    module-level names, so it was never executed. The distinct name keeps
    these tests discoverable. No caller could reach the old class by name
    anyway, because the later definition replaced it.
    """

    def train_batch(self):
        """Return an iterator over the integers 0..99, used for testing."""
        return iter(range(100))

    def setUp(self) -> None:
        """Build one random experience and append ``batch_size`` copies of it to a fresh ``Buffer``."""
        self.state = np.random.rand(4, 84, 84)
        self.next_state = np.random.rand(4, 84, 84)
        self.action = np.ones([1])
        self.reward = np.ones([1])
        self.done = np.zeros([1])
        self.experience = Experience(self.state, self.action, self.reward, self.done, self.next_state)
        # Mocked experience source whose step() returns (experience, tensor(0), False).
        # Not used by the tests in this class.
        self.source = Mock()
        self.source.step = Mock(return_value=(self.experience, torch.tensor(0), False))
        self.batch_size = 8
        # Constructed with the same value as batch_size (previously a literal 8).
        # Presumably this is the buffer capacity; TODO confirm against Buffer.
        self.buffer = Buffer(self.batch_size)
        for _ in range(self.batch_size):
            self.buffer.append(self.experience)

    def test_sample_batch(self):
        """Check that ``sample()`` returns one batch of ``batch_size`` experiences.

        The sample must hold five fields, in the order they are passed to
        ``Experience`` (state, action, reward, done, next_state). Each field is
        stacked along a leading batch dimension.
        """
        sample = self.buffer.sample()
        self.assertEqual(len(sample), 5)
        self.assertEqual(sample[0].shape, (self.batch_size, 4, 84, 84))
        self.assertEqual(sample[1].shape, (self.batch_size, 1))
        self.assertEqual(sample[2].shape, (self.batch_size, 1))
        self.assertEqual(sample[3].shape, (self.batch_size, 1))
        self.assertEqual(sample[4].shape, (self.batch_size, 4, 84, 84))
class TestBuffer(TestCase):
    """Tests for ``Buffer`` sampling and its use through ``RLDataset`` / ``DataLoader``."""

    def setUp(self) -> None:
        """Build one random experience and append ``batch_size`` copies of it to a fresh ``Buffer``."""
        self.state = np.random.rand(4, 84, 84)
        self.next_state = np.random.rand(4, 84, 84)
        self.action = np.ones([1])
        self.reward = np.ones([1])
        self.done = np.zeros([1])
        self.experience = Experience(self.state, self.action, self.reward, self.done, self.next_state)
        # Mocked experience source whose step() returns (experience, tensor(0), False).
        # Not used by the tests in this class.
        self.source = Mock()
        self.source.step = Mock(return_value=(self.experience, torch.tensor(0), False))
        self.batch_size = 8
        # Constructed with the same value as batch_size (previously a literal 8).
        # Presumably this is the buffer capacity; TODO confirm against Buffer.
        self.buffer = Buffer(self.batch_size)
        for _ in range(self.batch_size):
            self.buffer.append(self.experience)

    def _assert_batch_shapes(self, batch):
        """Assert the per-field shapes of a batch of ``batch_size`` experiences.

        The expected field order is the order passed to ``Experience``:
        state, action, reward, done, next_state. Shapes are normalised with
        ``tuple()``, so the check works both for numpy ``ndarray.shape`` and
        for ``torch.Size``.
        """
        image_shape = (self.batch_size, 4, 84, 84)
        scalar_shape = (self.batch_size, 1)
        expected_shapes = [image_shape, scalar_shape, scalar_shape, scalar_shape, image_shape]
        for field_index, expected in enumerate(expected_shapes):
            with self.subTest(field=field_index):
                self.assertEqual(tuple(batch[field_index].shape), expected)

    def test_sample_batch(self):
        """Check that ``sample()`` returns one batch of ``batch_size`` experiences with five fields."""
        sample = self.buffer.sample()
        self.assertEqual(len(sample), 5)
        self._assert_batch_shapes(sample)

    def test_dataloader(self):
        """Check that the buffer can be consumed through ``RLDataset`` wrapped in a ``DataLoader``."""
        dataset = RLDataset(self.buffer, sample_size=self.batch_size)
        dl = DataLoader(dataset, batch_size=self.batch_size)
        n_batches = 0
        for sample_batched in dl:
            n_batches += 1
            self.assertIsInstance(sample_batched, list)
            self._assert_batch_shapes(sample_batched)
        # Without this check the test would pass vacuously if the loader yielded nothing.
        self.assertGreater(n_batches, 0, "DataLoader yielded no batches")