class TestExperienceReplayBuffer(unittest.TestCase):
    """Seeded sampling test for ``ExperienceReplayBuffer`` (older ``State``/``store`` API).

    This variant builds states as ``State(tensor, torch.tensor([1]))`` and calls
    ``store(state, action, reward, next_state)``.

    NOTE(review): a second class with this exact name appears later in this
    source. If both definitions end up in the same module, the later one
    replaces this one and the test below is never collected — confirm whether
    this class is still meant to run.
    """

    def setUp(self):
        # Fix every RNG the buffer could draw from so the expected table is stable.
        np.random.seed(1)
        random.seed(1)
        torch.manual_seed(1)
        self.replay_buffer = ExperienceReplayBuffer(5)

    def test_run(self):
        """Store ten transitions into a capacity-5 buffer, sampling 3 after each store.

        The sampled ``features`` must match a fixed seed-dependent table, and the
        last element of every sample (the weights) must be all ones.
        """
        expected_samples = torch.tensor([
            [0, 0, 0],
            [1, 1, 0],
            [0, 1, 1],
            [3, 0, 0],
            [1, 4, 4],
            [1, 2, 4],
            [2, 4, 3],
            [4, 7, 4],
            [7, 4, 6],
            [6, 5, 6],
        ])
        expected_weights = np.ones((10, 3))

        observations = torch.arange(0, 20)
        acts = torch.arange(0, 20)
        rews = torch.arange(0, 20)

        sampled_features, sampled_weights = [], []
        for step in range(10):
            # Second State argument is presumably a "done/mask" flag — TODO confirm.
            current = State(observations[step].unsqueeze(0), torch.tensor([1]))
            following = State(observations[step + 1].unsqueeze(0), torch.tensor([1]))
            self.replay_buffer.store(current, acts[step], rews[step], following)

            batch = self.replay_buffer.sample(3)
            sampled_features.append(batch[0].features)
            sampled_weights.append(batch[-1])

        stacked = torch.cat(sampled_features).view(expected_samples.shape)
        tt.assert_equal(stacked, expected_samples)
        np.testing.assert_array_equal(expected_weights, np.vstack(sampled_weights))
class TestExperienceReplayBuffer(unittest.TestCase):
    """Tests for ``ExperienceReplayBuffer`` using ``State(observation, reward=...)``."""

    def test_run(self):
        """Store ten transitions into a capacity-5 buffer, sampling 3 after each store.

        The sampled observations must match a fixed table that depends on the
        seeds set below, and the last element of every sample (the weights)
        must be all ones.
        """
        # Seed every RNG the buffer might draw from so the expected table is reproducible.
        np.random.seed(1)
        random.seed(1)
        torch.manual_seed(1)
        self.replay_buffer = ExperienceReplayBuffer(5)
        states = torch.arange(0, 20)
        actions = torch.arange(0, 20).view((-1, 1))
        rewards = torch.arange(0, 20)
        expected_samples = torch.tensor([
            [0, 0, 0],
            [1, 1, 0],
            [0, 1, 1],
            [3, 0, 0],
            [1, 4, 4],
            [1, 2, 4],
            [2, 4, 3],
            [4, 7, 4],
            [7, 4, 6],
            [6, 5, 6],
        ])
        expected_weights = np.ones((10, 3))
        actual_samples = []
        actual_weights = []
        for i in range(10):
            state = State(states[i])
            # The reward for transition i is attached to the next state.
            next_state = State(states[i + 1], reward=rewards[i])
            self.replay_buffer.store(state, actions[i], next_state)
            sample = self.replay_buffer.sample(3)
            actual_samples.append(sample[0].observation)
            actual_weights.append(sample[-1])
        tt.assert_equal(
            torch.cat(actual_samples).view(expected_samples.shape),
            expected_samples,
        )
        np.testing.assert_array_equal(expected_weights, np.vstack(actual_weights))

    # Previously the body was wrapped in ``if torch.cuda.is_available():``, so on
    # CPU-only machines the test passed with zero assertions. Skipping makes the
    # missing coverage visible in the test report instead.
    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_store_device(self):
        """Samples come back on ``device`` while stored transitions stay on ``store_device``.

        The buffer is built with ``device='cuda'`` and ``store_device='cpu'``.
        The test checks that a sampled batch is on CUDA and that the first
        element of the first stored buffer entry (presumably the state — TODO
        confirm) is on the CPU.
        """
        self.replay_buffer = ExperienceReplayBuffer(5, device='cuda', store_device='cpu')
        states = torch.arange(0, 20).to('cuda')
        actions = torch.arange(0, 20).view((-1, 1)).to('cuda')
        rewards = torch.arange(0, 20).to('cuda')
        state = State(states[0])
        next_state = State(states[1], reward=rewards[1])
        self.replay_buffer.store(state, actions[0], next_state)
        sample = self.replay_buffer.sample(3)
        self.assertEqual(sample[0].device, torch.device('cuda'))
        self.assertEqual(self.replay_buffer.buffer[0][0].device, torch.device('cpu'))