Example #1
0
class TestBuffer(TestCase):
    """Shape checks for a batch sampled from a ``Buffer`` filled in ``setUp``."""

    def train_batch(self):
        """Returns an iterator over the integers 0..99 used for testing"""
        return iter(range(100))

    def setUp(self) -> None:
        """Build one experience and append it ``batch_size`` times to the buffer.

        The class previously defined ``setUp`` twice with identical logic; the
        later definition silently replaced the earlier one, so only a single
        definition is kept here.
        """
        self.state = np.random.rand(4, 84, 84)
        self.next_state = np.random.rand(4, 84, 84)
        self.action = np.ones([1])
        self.reward = np.ones([1])
        self.done = np.zeros([1])
        self.experience = Experience(self.state, self.action, self.reward,
                                     self.done, self.next_state)

        # Mocked source whose ``step`` always returns the same triple; none of
        # the tests visible in this class call it.
        self.source = Mock()
        self.source.step = Mock(return_value=(self.experience, torch.tensor(0),
                                              False))

        self.batch_size = 8
        # The buffer argument and the number of appended experiences are the
        # same value, so derive both from ``batch_size``.
        self.buffer = Buffer(self.batch_size)

        for _ in range(self.batch_size):
            self.buffer.append(self.experience)

    def test_sample_batch(self):
        """check that a single sample is returned"""
        sample = self.buffer.sample()
        self.assertEqual(len(sample), 5)
        self.assertEqual(sample[0].shape, (self.batch_size, 4, 84, 84))
        self.assertEqual(sample[1].shape, (self.batch_size, 1))
        self.assertEqual(sample[2].shape, (self.batch_size, 1))
        self.assertEqual(sample[3].shape, (self.batch_size, 1))
        self.assertEqual(sample[4].shape, (self.batch_size, 4, 84, 84))
class TestBuffer(TestCase):
    """Shape checks for ``Buffer`` sampling, directly and through a DataLoader."""

    def setUp(self) -> None:
        """Build one experience and append it ``batch_size`` times to the buffer."""
        self.state = np.random.rand(4, 84, 84)
        self.next_state = np.random.rand(4, 84, 84)
        self.action = np.ones([1])
        self.reward = np.ones([1])
        self.done = np.zeros([1])
        self.experience = Experience(self.state, self.action, self.reward,
                                     self.done, self.next_state)

        # Mocked source whose ``step`` always returns the same triple; none of
        # the tests visible in this class call it.
        self.source = Mock()
        self.source.step = Mock(return_value=(self.experience, torch.tensor(0),
                                              False))

        self.batch_size = 8
        # The buffer argument and the number of appended experiences are the
        # same value, so derive both from ``batch_size``.
        self.buffer = Buffer(self.batch_size)

        for _ in range(self.batch_size):
            self.buffer.append(self.experience)

    def _expected_shapes(self):
        """Return the expected shape of each of the five batch fields, in order."""
        return [
            (self.batch_size, 4, 84, 84),
            (self.batch_size, 1),
            (self.batch_size, 1),
            (self.batch_size, 1),
            (self.batch_size, 4, 84, 84),
        ]

    def test_sample_batch(self):
        """check that a single sample is returned"""
        sample = self.buffer.sample()
        self.assertEqual(len(sample), 5)
        for idx, shape in enumerate(self._expected_shapes()):
            self.assertEqual(sample[idx].shape, shape)

    def test_dataloader(self):
        """tests that the buffer works with dataloader"""
        dataset = RLDataset(self.buffer, sample_size=self.batch_size)
        dl = DataLoader(dataset, batch_size=self.batch_size)

        num_batches = 0
        for sample_batched in dl:
            num_batches += 1
            self.assertIsInstance(sample_batched, list)
            for idx, shape in enumerate(self._expected_shapes()):
                self.assertEqual(sample_batched[idx].shape, torch.Size(shape))

        # Without this check the test passes vacuously when the loader yields
        # nothing, because every assertion above lives inside the loop.
        self.assertGreater(num_batches, 0)