class TestExperienceReplayBuffer(unittest.TestCase):
    """Deterministic smoke test for ExperienceReplayBuffer (features-based State API)."""

    def setUp(self):
        # Pin every RNG source so the sampled indices below are reproducible.
        np.random.seed(1)
        random.seed(1)
        torch.manual_seed(1)
        self.replay_buffer = ExperienceReplayBuffer(5)

    def test_run(self):
        states = torch.arange(0, 20)
        actions = torch.arange(0, 20)
        rewards = torch.arange(0, 20)
        # Hard-coded expectations derived from the seeds fixed in setUp;
        # any change to the store/sample call order invalidates them.
        expected_samples = torch.tensor([
            [0, 0, 0],
            [1, 1, 0],
            [0, 1, 1],
            [3, 0, 0],
            [1, 4, 4],
            [1, 2, 4],
            [2, 4, 3],
            [4, 7, 4],
            [7, 4, 6],
            [6, 5, 6],
        ])
        expected_weights = np.ones((10, 3))
        sampled_features = []
        sampled_weights = []
        for step in range(10):
            current = State(states[step].unsqueeze(0), torch.tensor([1]))
            following = State(states[step + 1].unsqueeze(0), torch.tensor([1]))
            # Store one transition, then draw a batch of 3 from the buffer.
            self.replay_buffer.store(current, actions[step], rewards[step], following)
            batch = self.replay_buffer.sample(3)
            sampled_features.append(batch[0].features)
            sampled_weights.append(batch[-1])
        tt.assert_equal(
            torch.cat(sampled_features).view(expected_samples.shape),
            expected_samples)
        np.testing.assert_array_equal(expected_weights,
                                      np.vstack(sampled_weights))
# --- Example 2: the same test suite, ported to the newer State/buffer API ---
class TestExperienceReplayBuffer(unittest.TestCase):
    """Deterministic smoke test for ExperienceReplayBuffer (observation-based State API)."""

    def test_run(self):
        # Pin every RNG source so the sampled indices below are reproducible.
        np.random.seed(1)
        random.seed(1)
        torch.manual_seed(1)
        self.replay_buffer = ExperienceReplayBuffer(5)

        states = torch.arange(0, 20)
        actions = torch.arange(0, 20).view((-1, 1))
        rewards = torch.arange(0, 20)
        # Hard-coded expectations derived from the fixed seeds above;
        # any change to the store/sample call order invalidates them.
        expected_samples = torch.tensor([
            [0, 0, 0],
            [1, 1, 0],
            [0, 1, 1],
            [3, 0, 0],
            [1, 4, 4],
            [1, 2, 4],
            [2, 4, 3],
            [4, 7, 4],
            [7, 4, 6],
            [6, 5, 6],
        ])
        expected_weights = np.ones((10, 3))
        sampled_observations = []
        sampled_weights = []
        for step in range(10):
            current = State(states[step])
            # In this API the reward rides on the successor state.
            following = State(states[step + 1], reward=rewards[step])
            self.replay_buffer.store(current, actions[step], following)
            batch = self.replay_buffer.sample(3)
            sampled_observations.append(batch[0].observation)
            sampled_weights.append(batch[-1])
        tt.assert_equal(
            torch.cat(sampled_observations).view(expected_samples.shape),
            expected_samples)
        np.testing.assert_array_equal(expected_weights,
                                      np.vstack(sampled_weights))

    def test_store_device(self):
        # Only meaningful on a CUDA machine; silently a no-op otherwise.
        if torch.cuda.is_available():
            self.replay_buffer = ExperienceReplayBuffer(5,
                                                        device='cuda',
                                                        store_device='cpu')

            states = torch.arange(0, 20).to('cuda')
            actions = torch.arange(0, 20).view((-1, 1)).to('cuda')
            rewards = torch.arange(0, 20).to('cuda')
            first = State(states[0])
            second = State(states[1], reward=rewards[1])
            self.replay_buffer.store(first, actions[0], second)
            batch = self.replay_buffer.sample(3)
            # Sampled batches come back on the compute device ...
            self.assertEqual(batch[0].device, torch.device('cuda'))
            # ... while the backing storage stays on the store device.
            self.assertEqual(self.replay_buffer.buffer[0][0].device,
                             torch.device('cpu'))