# Method of an episode runner; assumes module-level `import numpy as np` and
# `import torch`, with `BatchEpisode` imported from the surrounding library.
def __call__(self, T):
    D = BatchEpisode(self.env_spec)

    obs = self.env.reset()
    D.add_observation(obs)

    self.agent.reset(self.config)  # e.g. RNN initial states

    done = None  # per-environment termination flags, used for the RNN mask

    for t in range(T):
        # For a recurrent agent, pass a mask that zeroes the hidden states of
        # sub-environments whose episodes just terminated.
        if self.agent.recurrent and done is not None and any(done):
            kwargs = {'mask': torch.from_numpy(np.logical_not(done).astype(np.float32))}
        else:
            kwargs = {}
        out_agent = self.agent.choose_action(obs, **kwargs)

        action = out_agent.pop('action')
        if torch.is_tensor(action):
            raw_action = list(action.detach().cpu().numpy())
        else:
            raw_action = action
        D.add_action(raw_action)

        obs, reward, done, info = self.env.step(raw_action)
        D.add_observation(obs)
        D.add_reward(reward)
        D.add_done(done)
        D.add_info(info)

        # Mark sub-environments whose episodes have just terminated.
        for n, d in enumerate(done):
            if d:
                D.set_completed(n)

        # Record other information, e.g. log-probability of action, policy entropy
        D.add_batch_info(out_agent)

        # Stop early once every sub-environment has completed an episode.
        if all(D.completes):
            break

    return D
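
# A minimal, self-contained sketch (not part of the runner above) of the RNN
# mask logic in __call__: when a sub-environment terminates, its `done` flag
# is True, so the mask, the logical negation of `done` cast to float32,
# zeroes that environment's recurrent state on the next step. `rnn_mask` is
# a hypothetical helper name, not a library function.
import numpy as np
import torch

def rnn_mask(done):
    """Return a float32 tensor: 1.0 where the episode continues, 0.0 where done."""
    return torch.from_numpy(np.logical_not(np.asarray(done)).astype(np.float32))

# Example: the second of three sub-environments just terminated.
# rnn_mask([False, True, False]) -> tensor([1., 0., 1.])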
import numpy as np

# `make_vec_env`, `make_gym_env`, `EnvSpec`, `BatchEpisode` and `Seeder` are
# imported from the library under test (lagom).
def test_batch_episode(vec_env, env_id):
    env = make_vec_env(vec_env, make_gym_env, env_id, 3, 0)
    env_spec = EnvSpec(env)
    D = BatchEpisode(env_spec)

    # Use a fixed ("sticky") action so the rollout can be replayed exactly
    # against identically-seeded single environments below.
    if env_id == 'CartPole-v1':
        sticky_action = 1
        action_shape = ()
        action_dtype = np.int32
    elif env_id == 'Pendulum-v0':
        sticky_action = [0.1]
        action_shape = env_spec.action_space.shape
        action_dtype = np.float32

    obs = env.reset()
    D.add_observation(obs)
    for t in range(30):
        action = [sticky_action] * env.num_env
        obs, reward, done, info = env.step(action)
        D.add_observation(obs)
        D.add_action(action)
        D.add_reward(reward)
        D.add_done(done)
        D.add_info(info)
        D.add_batch_info({'V': [0.1 * (t + 1), (t + 1), 10 * (t + 1)]})
        for n, d in enumerate(done):
            if d:
                D.set_completed(n)

    assert D.N == 3
    assert len(D.Ts) == 3
    assert D.maxT == max(D.Ts)

    assert all([isinstance(x, np.ndarray) for x in [D.numpy_observations,
                                                    D.numpy_actions,
                                                    D.numpy_rewards,
                                                    D.numpy_dones,
                                                    D.numpy_masks]])
    assert all([x.dtype == np.float32 for x in [D.numpy_observations,
                                                D.numpy_rewards,
                                                D.numpy_masks]])
    assert all([x.shape == (3, D.maxT) for x in [D.numpy_rewards,
                                                 D.numpy_dones,
                                                 D.numpy_masks]])
    assert D.numpy_actions.dtype == action_dtype
    assert D.numpy_dones.dtype == np.bool
    assert D.numpy_observations.shape == (3, D.maxT + 1) + env_spec.observation_space.shape
    assert D.numpy_actions.shape == (3, D.maxT) + action_shape

    assert isinstance(D.batch_infos, list) and len(D.batch_infos) == 30
    assert np.allclose([0.1 * (x + 1) for x in range(30)], [info['V'][0] for info in D.batch_infos])
    assert np.allclose([1 * (x + 1) for x in range(30)], [info['V'][1] for info in D.batch_infos])
    assert np.allclose([10 * (x + 1) for x in range(30)], [info['V'][2] for info in D.batch_infos])

    # Replay the same sticky action in three single environments seeded
    # identically to the vectorized ones, and check the recorded batch
    # matches step by step.
    seeder = Seeder(0)
    seed1, seed2, seed3 = seeder(3)
    env1 = make_gym_env(env_id, seed1)
    env2 = make_gym_env(env_id, seed2)
    env3 = make_gym_env(env_id, seed3)
    for n, ev in enumerate([env1, env2, env3]):
        obs = ev.reset()
        assert np.allclose(obs, D.observations[n][0])
        assert np.allclose(obs, D.numpy_observations[n, 0, ...])
        for t in range(30):
            obs, reward, done, info = ev.step(sticky_action)
            assert np.allclose(reward, D.rewards[n][t])
            assert np.allclose(reward, D.numpy_rewards[n, t])
            assert np.allclose(done, D.dones[n][t])
            assert done == D.numpy_dones[n, t]
            assert int(not done) == D.masks[n][t]
            assert int(not done) == D.numpy_masks[n, t]
            if done:
                # The terminal observation is stashed in `info`; everything
                # after the terminal step is padding.
                assert np.allclose(obs, D.infos[n][t]['terminal_observation'])
                assert D.completes[n]
                assert np.allclose(0.0, D.numpy_observations[n, t + 1 + 1:, ...])
                assert np.allclose(0.0, D.numpy_actions[n, t + 1:, ...])
                assert np.allclose(0.0, D.numpy_rewards[n, t + 1:])
                assert np.allclose(True, D.numpy_dones[n, t + 1:])
                assert np.allclose(0.0, D.numpy_masks[n, t + 1:])
                break
            else:
                assert np.allclose(obs, D.observations[n][t + 1])
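
# A self-contained illustration of the padding convention that the assertions
# above check, not lagom's implementation: after an episode ends at step t,
# rewards/actions/masks are zero-padded and dones are padded with True out to
# maxT. `pad_rewards` is a hypothetical helper name.
import numpy as np

def pad_rewards(rewards_per_env, maxT):
    """Zero-pad variable-length per-environment reward lists to (N, maxT) float32."""
    out = np.zeros((len(rewards_per_env), maxT), dtype=np.float32)
    for n, rs in enumerate(rewards_per_env):
        out[n, :len(rs)] = rs
    return out

# Example: two environments, the first terminated one step earlier.
# pad_rewards([[1.0, 1.0], [1.0, 1.0, 1.0]], maxT=3)
# -> [[1., 1., 0.],
#     [1., 1., 1.]]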