class TestPrioritizedReplayBuffer(unittest.TestCase):
    """Regression test for PrioritizedReplayBuffer sampling and importance weights.

    NOTE(review): a second class with this exact name is defined later in this
    file; at import time that later definition shadows this one, so these tests
    will not run as written — likely a merge artifact. Confirm which API
    version (StateArray/.observation here vs. State/.features below) is
    current and remove the stale duplicate.

    The expected samples and weights below are hard-coded fixtures that depend
    on the exact seeds set in setUp AND on the exact order in which random
    numbers are consumed inside test_run; any change to that order invalidates
    them.
    """

    def setUp(self):
        # Seed all three RNG sources so the prioritized sampling (and the
        # torch.randn priorities fed back in) are fully deterministic.
        random.seed(1)
        np.random.seed(1)
        torch.manual_seed(1)
        # Capacity 5, alpha (prioritization exponent) 0.6.
        self.replay_buffer = PrioritizedReplayBuffer(5, 0.6)

    def test_run(self):
        # 20 consecutive integer observations; rewards are attached to the
        # state array itself (StateArray API), offset by -1 from the index.
        states = StateArray(
            torch.arange(0, 20), (20,), reward=torch.arange(-1, 19).float()
        )
        actions = torch.arange(0, 20).view((-1, 1))
        # Hard-coded fixture: the observations expected from each of the 7
        # sample(3) calls below, given the seeds set in setUp.
        expected_samples = State(
            torch.tensor(
                [
                    [0, 1, 2],
                    [0, 1, 3],
                    [5, 5, 5],
                    [6, 6, 2],
                    [7, 7, 7],
                    [7, 8, 8],
                    [7, 7, 7],
                ]
            )
        )
        # Hard-coded fixture: the importance-sampling weights returned
        # alongside each sample (compared at 3 decimal places below).
        expected_weights = [
            [1.0000, 1.0000, 1.0000],
            [0.5659, 0.7036, 0.5124],
            [0.0631, 0.0631, 0.0631],
            [0.0631, 0.0631, 0.1231],
            [0.0631, 0.0631, 0.0631],
            [0.0776, 0.0631, 0.0631],
            [0.0866, 0.0866, 0.0866],
        ]
        actual_samples = []
        actual_weights = []
        for i in range(10):
            # Store transitions one at a time; rewards ride along inside the
            # StateArray, so store takes (state, action, next_state) only.
            self.replay_buffer.store(states[i], actions[i], states[i + 1])
            if i > 2:
                # Once a few transitions exist, sample a minibatch of 3 and
                # immediately overwrite the sampled items' priorities with
                # random values (consumes torch RNG state — order matters).
                sample = self.replay_buffer.sample(3)
                sample_states = sample[0].observation
                self.replay_buffer.update_priorities(torch.randn(3))
                actual_samples.append(sample_states)
                # sample[-1] is the importance-sampling weight vector.
                actual_weights.append(sample[-1])
        actual_samples = State(torch.cat(actual_samples).view((-1, 3)))
        self.assert_states_equal(actual_samples, expected_samples)
        np.testing.assert_array_almost_equal(
            expected_weights, np.vstack(actual_weights), decimal=3
        )

    def assert_states_equal(self, actual, expected):
        # Compare both the observation tensors and the done-masks.
        tt.assert_almost_equal(actual.observation, expected.observation)
        self.assertEqual(actual.mask, expected.mask)
class TestPrioritizedReplayBuffer(unittest.TestCase):
    """Regression test for PrioritizedReplayBuffer sampling and importance weights.

    NOTE(review): this duplicates the class of the same name defined earlier
    in this file; because both share a name, this later definition shadows the
    earlier one at import time — likely a merge artifact. This variant uses
    the older API (State/.features/.raw, rewards passed separately to store);
    confirm which version is current and delete the stale one.

    The expected samples and weights are hard-coded fixtures that depend on
    the exact seeds set in setUp AND on the exact order in which random
    numbers are consumed inside test_run; any change to that order
    invalidates them.
    """

    def setUp(self):
        # Seed all three RNG sources so the prioritized sampling (and the
        # torch.randn priorities fed back in) are fully deterministic.
        random.seed(1)
        np.random.seed(1)
        torch.manual_seed(1)
        # Capacity 5, alpha (prioritization exponent) 0.6.
        self.replay_buffer = PrioritizedReplayBuffer(5, 0.6)

    def test_run(self):
        # 20 consecutive integer observations; rewards are passed to store()
        # as a separate tensor in this (older) API.
        states = State(torch.arange(0, 20))
        actions = torch.arange(0, 20)
        rewards = torch.arange(0, 20)
        # Hard-coded fixture: the features expected from each of the 7
        # sample(3) calls below, given the seeds set in setUp.
        expected_samples = State(
            torch.tensor([
                [0, 2, 2],
                [0, 1, 1],
                [3, 3, 5],
                [5, 3, 6],
                [3, 5, 7],
                [8, 5, 8],
                [8, 5, 5],
            ]))
        # Hard-coded fixture: the importance-sampling weights returned
        # alongside each sample (compared at full default precision below).
        expected_weights = [[1., 1., 1.],
                            [0.56589746, 0.5124394, 0.5124394],
                            [0.5124343, 0.5124343, 0.5124343],
                            [0.5090894, 0.6456939, 0.46323255],
                            [0.51945686, 0.5801515, 0.45691562],
                            [0.45691025, 0.5096957, 0.45691025],
                            [0.5938914, 0.6220026, 0.6220026]]
        actual_samples = []
        actual_weights = []
        for i in range(10):
            # Store transitions one at a time: (state, action, reward, next_state).
            self.replay_buffer.store(states[i], actions[i], rewards[i], states[i + 1])
            if i > 2:
                # Once a few transitions exist, sample a minibatch of 3 and
                # immediately overwrite the sampled items' priorities with
                # random values (consumes torch RNG state — order matters).
                sample = self.replay_buffer.sample(3)
                sample_states = sample[0].features
                self.replay_buffer.update_priorities(torch.randn(3))
                actual_samples.append(sample_states)
                # sample[-1] is the importance-sampling weight vector.
                actual_weights.append(sample[-1])
        actual_samples = State(torch.cat(actual_samples).view((-1, 3)))
        self.assert_states_equal(actual_samples, expected_samples)
        np.testing.assert_array_almost_equal(expected_weights, np.vstack(actual_weights))

    def assert_states_equal(self, actual, expected):
        # Compare both the raw feature tensors and the done-masks.
        tt.assert_almost_equal(actual.raw, expected.raw)
        tt.assert_equal(actual.mask, expected.mask)