Esempio n. 1
0
class TestPrioritizedReplayBuffer(unittest.TestCase):
    """Golden-value regression test for ``PrioritizedReplayBuffer``.

    This variant builds its transitions with ``StateArray``, reads sampled
    states through ``.observation`` and calls
    ``store(state, action, next_state)``.

    The expected samples and weights are only reproducible because every RNG
    is seeded in ``setUp``. ``test_run`` also exercises the buffer in a fixed
    sequence: store, sample, then ``update_priorities(torch.randn(3))``.

    NOTE(review): a second class named ``TestPrioritizedReplayBuffer`` is
    defined later in this file. That later binding replaces this one at
    module level, so name-based test discovery never sees this class. Rename
    one of the two.
    """

    def setUp(self):
        # Seed every RNG the buffer or the test may draw from, so sampling
        # and the random priorities below are deterministic.
        for seed in (random.seed, np.random.seed, torch.manual_seed):
            seed(1)
        # NOTE(review): arguments presumably are (capacity, alpha); confirm
        # against PrioritizedReplayBuffer's signature.
        self.replay_buffer = PrioritizedReplayBuffer(5, 0.6)

    def test_run(self):
        """Store 10 transitions and sample 3 after each store from the 4th on."""
        observations = torch.arange(0, 20)
        rewards = torch.arange(-1, 19).float()
        states = StateArray(observations, (20,), reward=rewards)
        actions = torch.arange(0, 20).view((-1, 1))

        # One row per sampling step: 7 steps, each with a batch of 3.
        expected_samples = State(torch.tensor([
            [0, 1, 2],
            [0, 1, 3],
            [5, 5, 5],
            [6, 6, 2],
            [7, 7, 7],
            [7, 8, 8],
            [7, 7, 7],
        ]))
        expected_weights = [
            [1.0000, 1.0000, 1.0000],
            [0.5659, 0.7036, 0.5124],
            [0.0631, 0.0631, 0.0631],
            [0.0631, 0.0631, 0.1231],
            [0.0631, 0.0631, 0.0631],
            [0.0776, 0.0631, 0.0631],
            [0.0866, 0.0866, 0.0866],
        ]

        sampled_observations = []
        sampled_weights = []
        for step in range(10):
            self.replay_buffer.store(states[step], actions[step], states[step + 1])
            if step < 3:
                # Do not sample until four transitions have been stored.
                continue
            batch = self.replay_buffer.sample(3)
            # Read the observations before the priorities are updated.
            observation = batch[0].observation
            # Random priorities, presumably applied to the batch just
            # sampled; confirm against update_priorities.
            self.replay_buffer.update_priorities(torch.randn(3))
            sampled_observations.append(observation)
            # The last element of the sample is compared to expected_weights.
            sampled_weights.append(batch[-1])

        actual_samples = State(torch.cat(sampled_observations).view((-1, 3)))
        self.assert_states_equal(actual_samples, expected_samples)
        np.testing.assert_array_almost_equal(
            expected_weights, np.vstack(sampled_weights), decimal=3
        )

    def assert_states_equal(self, actual, expected):
        """Assert two states have near-equal observations and equal masks."""
        tt.assert_almost_equal(actual.observation, expected.observation)
        self.assertEqual(actual.mask, expected.mask)
class TestPrioritizedReplayBuffer(unittest.TestCase):
    """Golden-value regression test for ``PrioritizedReplayBuffer``.

    This variant passes the reward to ``store`` as a separate argument
    (``store(state, action, reward, next_state)``). It reads sampled states
    through ``.features`` and compares them through ``.raw``.

    NOTE(review): this appears to target a different, older version of the
    buffer/State API than the earlier class of the same name in this file.
    Because it reuses the name ``TestPrioritizedReplayBuffer``, this
    definition shadows that earlier class at module level.

    The expected values are only reproducible because every RNG is seeded in
    ``setUp``. ``test_run`` also calls the buffer in a fixed sequence: store,
    sample, then ``update_priorities(torch.randn(3))``.
    """

    def setUp(self):
        # Seed every RNG the buffer or the test may draw from, so sampling
        # and the random priorities below are deterministic.
        for seed in (random.seed, np.random.seed, torch.manual_seed):
            seed(1)
        # NOTE(review): arguments presumably are (capacity, alpha); confirm
        # against PrioritizedReplayBuffer's signature.
        self.replay_buffer = PrioritizedReplayBuffer(5, 0.6)

    def test_run(self):
        """Store 10 transitions and sample 3 after each store from the 4th on."""
        states = State(torch.arange(0, 20))
        actions = torch.arange(0, 20)
        rewards = torch.arange(0, 20)

        # One row per sampling step: 7 steps, each with a batch of 3.
        expected_samples = State(torch.tensor([
            [0, 2, 2],
            [0, 1, 1],
            [3, 3, 5],
            [5, 3, 6],
            [3, 5, 7],
            [8, 5, 8],
            [8, 5, 5],
        ]))
        expected_weights = [
            [1.0, 1.0, 1.0],
            [0.56589746, 0.5124394, 0.5124394],
            [0.5124343, 0.5124343, 0.5124343],
            [0.5090894, 0.6456939, 0.46323255],
            [0.51945686, 0.5801515, 0.45691562],
            [0.45691025, 0.5096957, 0.45691025],
            [0.5938914, 0.6220026, 0.6220026],
        ]

        sampled_features = []
        sampled_weights = []
        for step in range(10):
            self.replay_buffer.store(
                states[step], actions[step], rewards[step], states[step + 1]
            )
            if step < 3:
                # Do not sample until four transitions have been stored.
                continue
            batch = self.replay_buffer.sample(3)
            # Read the features before the priorities are updated.
            features = batch[0].features
            # Random priorities, presumably applied to the batch just
            # sampled; confirm against update_priorities.
            self.replay_buffer.update_priorities(torch.randn(3))
            sampled_features.append(features)
            # The last element of the sample is compared to expected_weights.
            sampled_weights.append(batch[-1])

        actual_samples = State(torch.cat(sampled_features).view((-1, 3)))
        self.assert_states_equal(actual_samples, expected_samples)
        np.testing.assert_array_almost_equal(
            expected_weights, np.vstack(sampled_weights)
        )

    def assert_states_equal(self, actual, expected):
        """Assert two states have near-equal raw values and equal masks."""
        tt.assert_almost_equal(actual.raw, expected.raw)
        tt.assert_equal(actual.mask, expected.mask)