def test_rollout(self):
    buffer = NStepAdvantageBuffer(self.v, self.features, 2, 3, discount_factor=0.5)
    actions = torch.ones((3))
    states = StateArray(torch.arange(0, 12).unsqueeze(1).float(), (12,))

    buffer.store(states[0:3], actions, torch.zeros(3))
    buffer.store(states[3:6], actions, torch.ones(3))
    states, _, advantages = buffer.advantages(states[6:9])

    expected_states = StateArray(torch.arange(0, 6).unsqueeze(1).float(), (6,))
    expected_next_states = StateArray(
        torch.cat((torch.arange(6, 9), torch.arange(6, 9))).unsqueeze(1).float(), (6,)
    )
    expected_returns = torch.tensor([0.5, 0.5, 0.5, 1, 1, 1]).float()
    expected_lengths = torch.tensor([2., 2, 2, 1, 1, 1])

    self.assert_states_equal(states, expected_states)
    tt.assert_allclose(
        advantages,
        self._compute_expected_advantages(
            expected_states, expected_returns, expected_next_states, expected_lengths
        ),
    )
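# A minimal sketch of the `_compute_expected_advantages` helper referenced by the
# buffer tests above and below. It assumes the expected advantage is the observed
# n-step return, plus the discounted value of the state where the rollout was
# truncated, minus the value of the starting state (discount_factor=0.5 in these
# tests). The actual helper on the test class may differ in detail.
def _compute_expected_advantages(self, states, returns, next_states, lengths):
    return (
        returns
        + (0.5 ** lengths) * self.v.eval(self.features.eval(next_states))
        - self.v.eval(self.features.eval(states))
    )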
def test_list(self):
    model = nn.Linear(2, 2)
    net = nn.RLNetwork(model, (2,))
    features = torch.randn((4, 2))
    done = torch.tensor([False, False, True, False])
    out = net(StateArray(features, (4,), done=done))
    tt.assert_almost_equal(
        out,
        torch.tensor(
            [
                [0.0479387, -0.2268031],
                [0.2346841, 0.0743403],
                [0.0, 0.0],
                [0.2204496, 0.086818],
            ]
        ),
    )

    features = torch.randn(3, 2)
    done = torch.tensor([False, False, False])
    out = net(StateArray(features, (3,), done=done))
    tt.assert_almost_equal(
        out,
        torch.tensor(
            [
                [0.4234636, 0.1039939],
                [0.6514298, 0.3354351],
                [-0.2543002, -0.2041451],
            ]
        ),
    )
def test_multi_env(self):
    state = StateArray(torch.randn(2, 2), (2,))
    self.agent.act(state)
    tt.assert_allclose(
        self.test_agent.last_state.observation,
        torch.tensor([[0.3923, -0.2236, 0.], [-0.3195, -1.2050, 0.]]),
        atol=1e-04,
    )
    self.agent.act(state)
    tt.assert_allclose(
        self.test_agent.last_state.observation,
        torch.tensor([[0.3923, -0.2236, 1e-3], [-0.3195, -1.2050, 1e-3]]),
        atol=1e-04,
    )
    self.agent.act(
        StateArray(state.observation, (2,), done=torch.tensor([False, True]))
    )
    tt.assert_allclose(
        self.test_agent.last_state.observation,
        torch.tensor([[0.3923, -0.2236, 2e-3], [-0.3195, -1.2050, 2e-3]]),
        atol=1e-04,
    )
    self.agent.act(state)
    tt.assert_allclose(
        self.test_agent.last_state.observation,
        torch.tensor([[0.3923, -0.2236, 3e-3], [-0.3195, -1.2050, 0.]]),
        atol=1e-04,
    )
    self.agent.act(state)
    tt.assert_allclose(
        self.test_agent.last_state.observation,
        torch.tensor([[0.3923, -0.2236, 4e-3], [-0.3195, -1.2050, 1e-3]]),
        atol=1e-04,
    )
def test_apply_done(self):
    observation = torch.randn(3, 4)
    state = StateArray(observation, (3,), mask=torch.tensor([0., 0., 0.]))
    model = torch.nn.Linear(4, 2)
    output = state.apply(model, 'observation')
    self.assertEqual(output.shape, (3, 2))
    self.assertEqual(output.sum().item(), 0)
def test_auto_mask(self):
    observation = torch.randn(3, 4)
    state = StateArray({
        'observation': observation,
        'done': torch.tensor([True, False, True]),
    }, (3,))
    tt.assert_equal(state.mask, torch.tensor([0., 1., 0.]))
def test_reinforce(self):
    states = StateArray(torch.randn((3, STATE_DIM)), (3,))
    actions = torch.tensor([0, 1, 0])
    original_probs = self.q(states, actions)
    tt.assert_almost_equal(
        original_probs,
        torch.tensor(
            [
                [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
            ]
        ),
        decimal=3,
    )

    target_dists = torch.tensor(
        [[0, 0, 1, 0, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0]]
    ).float()

    def _loss(dist, target_dist):
        log_dist = torch.log(torch.clamp(dist, min=1e-5))
        log_target_dist = torch.log(torch.clamp(target_dist, min=1e-5))
        return (target_dist * (log_target_dist - log_dist)).sum(dim=-1).mean()

    self.q.reinforce(_loss(original_probs, target_dists))
    new_probs = self.q(states, actions)
    tt.assert_almost_equal(
        torch.sign(new_probs - original_probs), torch.sign(target_dists - 0.5)
    )
def test_eval_actions(self):
    states = StateArray(torch.randn(3, STATE_DIM), (3,))
    actions = [1, 2, 0]
    result = self.q.eval(states, actions)
    self.assertEqual(result.shape, torch.Size([3]))
    tt.assert_almost_equal(result, torch.tensor([-0.7262873, 0.3484948, -0.0296164]))
def test_done(self):
    states = StateArray(
        torch.randn((3, STATE_DIM)), (3,), mask=torch.tensor([1, 0, 1])
    )
    probs = self.q(states)
    self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
    tt.assert_almost_equal(
        probs.sum(dim=2),
        torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
        decimal=3,
    )
    tt.assert_almost_equal(
        probs,
        torch.tensor(
            [
                [
                    [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                    [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
                ],
                [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
                [
                    [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                    [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
                ],
            ]
        ),
        decimal=3,
    )
def test_reinforce_list(self):
    states = StateArray(
        torch.randn(5, STATE_DIM), (5,), mask=torch.tensor([1, 1, 0, 1, 0])
    )
    result = self.v(states)
    tt.assert_almost_equal(
        result, torch.tensor([0.7053187, 0.3975691, 0., 0.2701665, 0.])
    )

    self.v.reinforce(loss(result, torch.tensor([1, -1, 1, 1, 1])).float())
    result = self.v(states)
    tt.assert_almost_equal(
        result, torch.tensor([0.9732854, 0.5453826, 0., 0.4344811, 0.])
    )
def test_multi_reinforce(self):
    states = StateArray(
        torch.randn(6, STATE_DIM), (6,), mask=torch.tensor([1, 1, 0, 1, 0, 0])
    )
    result1 = self.v(states[0:2])
    self.v.reinforce(loss(result1, torch.tensor([1, 2])).float())
    result2 = self.v(states[2:4])
    self.v.reinforce(loss(result2, torch.tensor([1, 1])).float())
    result3 = self.v(states[4:6])
    self.v.reinforce(loss(result3, torch.tensor([1, 2])).float())
    with self.assertRaises(Exception):
        self.v.reinforce(loss(result3, torch.tensor([1, 2])).float())
def test_eval_list(self):
    states = StateArray(
        torch.randn(5, STATE_DIM), (5,), mask=torch.tensor([1, 1, 0, 1, 0])
    )
    result = self.q.eval(states)
    tt.assert_almost_equal(
        result,
        torch.tensor([
            [-0.238509, -0.726287, -0.034026],
            [-0.35688755, -0.6612102, 0.34849477],
            [0., 0., 0.],
            [0.1944, -0.5536, -0.2345],
            [0., 0., 0.],
        ]),
        decimal=2,
    )
def test_multi_rollout(self):
    buffer = NStepAdvantageBuffer(self.v, self.features, 2, 2, discount_factor=0.5)
    raw_states = StateArray(torch.arange(0, 12).unsqueeze(1).float(), (12,))
    actions = torch.ones((2))

    buffer.store(raw_states[0:2], actions, torch.ones(2))
    buffer.store(raw_states[2:4], actions, torch.ones(2))
    states, actions, advantages = buffer.advantages(raw_states[4:6])
    expected_states = StateArray(torch.arange(0, 4).unsqueeze(1).float(), (4,))
    expected_returns = torch.tensor([1.5, 1.5, 1, 1])
    expected_next_states = StateArray(torch.tensor([4., 5, 4, 5]).unsqueeze(1), (4,))
    expected_lengths = torch.tensor([2., 2, 1, 1])
    self.assert_states_equal(states, expected_states)
    tt.assert_allclose(
        advantages,
        self._compute_expected_advantages(
            expected_states, expected_returns, expected_next_states, expected_lengths
        ),
    )

    buffer.store(raw_states[4:6], actions, torch.ones(2))
    buffer.store(raw_states[6:8], actions, torch.ones(2))
    states, actions, advantages = buffer.advantages(raw_states[8:10])
    expected_states = StateArray(torch.arange(4, 8).unsqueeze(1).float(), (4,))
    self.assert_states_equal(states, expected_states)
    tt.assert_allclose(
        advantages,
        self._compute_expected_advantages(
            expected_states,
            torch.tensor([1.5, 1.5, 1, 1]),
            StateArray(torch.tensor([8, 9, 8, 9]).unsqueeze(1).float(), (4,)),
            torch.tensor([2., 2, 1, 1]),
        ),
    )
def _to_state(self, obs, rew, done, info):
    obs = obs.astype(self.observation_space.dtype)
    rew = rew.astype("float32")
    done = done.astype("bool")
    mask = (1 - done).astype("float32")
    return StateArray(
        {
            "observation": torch.tensor(obs, device=self._device),
            "reward": torch.tensor(rew, device=self._device),
            "done": torch.tensor(done, device=self._device),
            "mask": torch.tensor(mask, device=self._device),
        },
        shape=(self._env.num_envs,),
    )
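# A minimal, self-contained sketch (an assumption, not part of the wrapper above) of
# the conversion `_to_state` performs, using a fake 2-environment batch and the
# ALL-style StateArray import path:
import numpy as np
import torch
from all.core import StateArray

obs = np.zeros((2, 4), dtype="float32")
rew = np.array([1.0, 0.0], dtype="float32")
done = np.array([False, True])
state = StateArray(
    {
        "observation": torch.tensor(obs),
        "reward": torch.tensor(rew),
        "done": torch.tensor(done),
        "mask": torch.tensor((1 - done).astype("float32")),
    },
    shape=(2,),
)
# state.mask -> tensor([1., 0.]): environments whose episode ended get a zero mask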
def test_rollout_with_dones(self):
    buffer = NStepAdvantageBuffer(self.v, self.features, 3, 3, discount_factor=0.5)
    done = torch.tensor([False] * 12)
    done[5] = True
    done[7] = True
    done[9] = True
    states = StateArray(torch.arange(0, 12).unsqueeze(1).float(), (12,), done=done)
    actions = torch.ones((3))

    buffer.store(states[0:3], actions, torch.zeros(3))
    buffer.store(states[3:6], actions, torch.ones(3))
    buffer.store(states[6:9], actions, 2 * torch.ones(3))
    states, actions, advantages = buffer.advantages(states[9:12])

    expected_states = StateArray(
        torch.arange(0, 9).unsqueeze(1).float(), (9,), done=done[0:9]
    )
    expected_next_done = torch.tensor([True] * 9)
    expected_next_done[5] = False
    expected_next_done[7] = False
    expected_next_done[8] = False
    expected_next_states = StateArray(
        torch.tensor([9, 7, 5, 9, 7, 11, 9, 10, 11]).unsqueeze(1).float(),
        (9,),
        done=expected_next_done,
    )
    expected_returns = torch.tensor([1, 0.5, 0, 2, 1, 2, 2, 2, 2]).float()
    expected_lengths = torch.tensor([3, 2, 1, 2, 1, 2, 1, 1, 1]).float()

    self.assert_states_equal(states, expected_states)
    tt.assert_allclose(
        advantages,
        self._compute_expected_advantages(
            expected_states, expected_returns, expected_next_states, expected_lengths
        ),
    )
def learn_step(self, idxs, transition_batch, weights):
    Otm1, old_action, env_rew, done, Ot = transition_batch
    batch_size = len(Otm1)
    actions = torch.tensor(old_action, device=self.device)
    rewards = torch.tensor(env_rew, device=self.device)
    # cast to float so that `1 - dones` and zeros_like/ones_like yield float masks
    dones = torch.tensor(done, device=self.device).float()

    states = StateArray(
        {
            'observation': torch.tensor(Otm1, device=self.device),
            'reward': rewards,
            'done': torch.zeros_like(dones),
            'mask': torch.ones_like(dones),
        },
        shape=(batch_size,))
    next_states = StateArray(
        {
            'observation': torch.tensor(Ot, device=self.device),
            'done': dones,
            'mask': 1 - dones,
        },
        shape=(batch_size,))

    # forward pass
    values = self.q(states, actions)

    # one-step Q-learning targets from the target network
    targets = rewards + self.discount_factor * torch.max(
        self.q.target(next_states), dim=1)[0]

    # compute loss and backward pass
    loss = mse_loss(values, targets)
    self.q.reinforce(loss)

    self.logger.record_mean("critic_loss", loss.detach().cpu().numpy())
    self.logger.record_sum("learner_steps", batch_size)
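# A small self-contained sketch (assumed numbers, not from the source) of the
# one-step Q-learning target computed in learn_step above:
# target = reward + discount_factor * max_a Q_target(next_state, a).
import torch

discount_factor = 0.99                      # assumed value
rewards = torch.tensor([1.0, 0.0])          # assumed batch of rewards
next_q = torch.tensor([[0.2, 0.5],          # assumed Q_target(s', a) for two actions
                       [1.0, -0.3]])
targets = rewards + discount_factor * torch.max(next_q, dim=1)[0]
# targets -> tensor([1.4950, 0.9900])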
def test_single_q_values(self):
    states = StateArray(torch.randn((3, STATE_DIM)), (3,))
    actions = torch.tensor([0, 1, 0])
    probs = self.q(states, actions)
    self.assertEqual(probs.shape, (3, ATOMS))
    tt.assert_almost_equal(probs.sum(dim=1), torch.tensor([1.0, 1.0, 1.0]), decimal=3)
    tt.assert_almost_equal(
        probs,
        torch.tensor([
            [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
            [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
            [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
        ]),
        decimal=3,
    )
def test_run(self):
    states = StateArray(torch.arange(0, 20), (20,), reward=torch.arange(-1, 19).float())
    actions = torch.arange(0, 20)
    for i in range(3):
        self.replay_buffer.store(states[i], actions[i], states[i + 1])
        self.assertEqual(len(self.replay_buffer), 0)
    for i in range(3, 6):
        self.replay_buffer.store(states[i], actions[i], states[i + 1])
        self.assertEqual(len(self.replay_buffer), i - 2)

    sample = self.replay_buffer.buffer.buffer[0]
    self.assert_states_equal(sample[0], states[0])
    tt.assert_equal(sample[1], actions[0])
    tt.assert_equal(sample[2].reward, torch.tensor(0 + 1 * 0.5 + 2 * 0.25 + 3 * 0.125))
    tt.assert_equal(
        self.replay_buffer.buffer.buffer[1][2].reward,
        torch.tensor(1 + 2 * 0.5 + 3 * 0.25 + 4 * 0.125),
    )
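# Arithmetic behind the expected rewards above (assuming the buffer under test is an
# n-step buffer with n = 4 and discount_factor = 0.5, which matches the expectations):
# the first stored transition aggregates the next four rewards 0, 1, 2, 3 as
#   0 * 1 + 1 * 0.5 + 2 * 0.25 + 3 * 0.125 = 1.375
# and the second aggregates 1, 2, 3, 4 as
#   1 * 1 + 2 * 0.5 + 3 * 0.25 + 4 * 0.125 = 3.25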
def test_run(self):
    states = StateArray(torch.arange(0, 20), (20,), reward=torch.arange(-1, 19).float())
    actions = torch.arange(0, 20).view((-1, 1))
    expected_samples = State(
        torch.tensor(
            [
                [0, 1, 2],
                [0, 1, 3],
                [5, 5, 5],
                [6, 6, 2],
                [7, 7, 7],
                [7, 8, 8],
                [7, 7, 7],
            ]
        )
    )
    expected_weights = [
        [1.0000, 1.0000, 1.0000],
        [0.5659, 0.7036, 0.5124],
        [0.0631, 0.0631, 0.0631],
        [0.0631, 0.0631, 0.1231],
        [0.0631, 0.0631, 0.0631],
        [0.0776, 0.0631, 0.0631],
        [0.0866, 0.0866, 0.0866],
    ]
    actual_samples = []
    actual_weights = []
    for i in range(10):
        self.replay_buffer.store(states[i], actions[i], states[i + 1])
        if i > 2:
            sample = self.replay_buffer.sample(3)
            sample_states = sample[0].observation
            self.replay_buffer.update_priorities(torch.randn(3))
            actual_samples.append(sample_states)
            actual_weights.append(sample[-1])
    actual_samples = State(torch.cat(actual_samples).view((-1, 3)))
    self.assert_states_equal(actual_samples, expected_samples)
    np.testing.assert_array_almost_equal(
        expected_weights, np.vstack(actual_weights), decimal=3
    )
def test_as_output(self):
    observation = torch.randn(3, 4)
    state = StateArray(observation, (3,))
    tensor = torch.randn(3, 5)
    self.assertEqual(state.as_output(tensor).shape, (3, 5))
def learn_step(self, idxs, transition_batch, weights):
    Otm1, targ_vec, old_action, env_rew, done, Ot = transition_batch
    batch_size = len(Ot)

    obsm1 = self.obs_preproc(torch.tensor(Otm1, device=self.device))
    targ_vec = torch.tensor(targ_vec, device=self.device)
    actions = torch.tensor(old_action, device=self.device)
    rewards = torch.tensor(env_rew, device=self.device)
    done = torch.tensor(done, device=self.device).float()
    next_obs = self.obs_preproc(torch.tensor(Ot, device=self.device))
    weights = torch.tensor(weights, device=self.device)

    states = StateArray(
        {
            'observation': obsm1,
            'reward': rewards,
            'done': done,
        },
        shape=(batch_size,))
    # next_states wraps the preprocessed next observation
    next_states = StateArray(
        {
            'observation': next_obs,
            'reward': torch.zeros(batch_size, device=self.device),
            'done': torch.zeros(batch_size, device=self.device),
            'mask': torch.ones(batch_size, device=self.device),
        },
        shape=(batch_size,))
    # prediction_reward = self.predictor(Ot) * targ_vec

    # evaluate the current policy without building a graph
    with torch.no_grad():
        distribution = self.policy_learner(states)
        _log_probs = distribution.log_prob(actions).detach().squeeze()

    value_feature1 = self.features(states)
    value_feature2 = self.features(next_states)
    _actions = distribution.sample()

    # soft Q-learning targets bootstrapped from the target value network
    q_targets = rewards + self.discount_factor * self.v.target(value_feature2).detach()
    # soft value targets: minimum of the two target critics minus the entropy bonus
    v_targets = torch.min(
        self.qs[0].target(value_feature1, _actions),
        self.qs[1].target(value_feature1, _actions),
    ) - self.temperature * _log_probs

    # update the Q-functions and the V-function
    for i in range(2):
        self.qs[i].reinforce(
            mse_loss(self.qs[i](value_feature1, actions), q_targets))
    self.v.reinforce(mse_loss(self.v(value_feature1), v_targets))

    # update the policy
    distribution = self.policy_learner(states)
    _actions2 = distribution.sample()
    _log_probs2 = distribution.log_prob(_actions2).squeeze()
    loss = (-self.qs[0](value_feature1, _actions2).detach()
            + self.temperature * _log_probs2).mean()
    self.policy_learner.reinforce(loss)
    self.features.reinforce()
    self.qs[0].zero_grad()

    # adjust the temperature toward the entropy target
    temperature_grad = (_log_probs + self.entropy_target).mean()
    self.temperature += self.lr_temperature * temperature_grad.detach().cpu().numpy()
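# A small numeric sketch (assumed values, not from the source) of the temperature
# adjustment above: temperature_grad = mean(log_prob) + entropy_target. With
# entropy_target = -2.0, a mean log-probability of -1.5 (so the policy's entropy is
# above the target), and lr_temperature = 0.01, the gradient is -1.5 + (-2.0) = -3.5,
# and the temperature decreases by 0.035, shrinking the entropy bonus.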