def test_trajectory(init_seed, T):
    make_env = lambda: TimeLimit(SanityEnv())
    env = make_vec_env(make_env, 1, init_seed)  # single environment
    env = VecStepInfo(env)
    D = Trajectory()
    assert len(D) == 0
    assert not D.completed

    observation, _ = env.reset()
    D.add_observation(observation)
    for t in range(T):
        action = [env.action_space.sample()]
        next_observation, reward, step_info = env.step(action)
        # unbatch reward and step_info for the single sub-environment
        reward, step_info = map(lambda x: x[0], [reward, step_info])
        if step_info.last:
            D.add_observation([step_info['last_observation']])
        else:
            D.add_observation(next_observation)
        D.add_action(action)
        D.add_reward(reward)
        D.add_step_info(step_info)
        observation = next_observation
        if step_info.last:
            # a completed trajectory must reject further observations
            with pytest.raises(AssertionError):
                D.add_observation(observation)
            break
    assert len(D) > 0
    assert len(D) <= T
    assert len(D) + 1 == len(D.observations)
    assert len(D) + 1 == len(D.numpy_observations)
    assert len(D) == len(D.actions)
    assert len(D) == len(D.numpy_actions)
    assert len(D) == len(D.rewards)
    assert len(D) == len(D.numpy_rewards)
    assert len(D) == len(D.numpy_dones)
    assert len(D) == len(D.numpy_masks)
    assert np.allclose(np.logical_not(D.numpy_dones), D.numpy_masks)
    assert len(D) == len(D.step_infos)
    if len(D) < T:  # episode ended before exhausting the step budget
        assert step_info.last
        assert D.completed
        assert D.reach_terminal
        assert not D.reach_time_limit
        assert np.allclose(D.observations[-1], [step_info['last_observation']])
    if not step_info.last:
        assert not D.completed
        assert not D.reach_terminal
        assert not D.reach_time_limit

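# The test above exercises SanityEnv, a helper defined elsewhere in the test
# suite. Below is a minimal sketch of what such an environment could look like,
# assuming the old gym step API used throughout this file; the spaces, reward,
# and horizon are illustrative assumptions, not the real implementation.
class SanityEnvSketch(gym.Env):
    """Deterministic toy env: the observation counts steps; the episode
    terminates at a true terminal state after `horizon` steps."""
    def __init__(self, horizon=4):
        self.observation_space = gym.spaces.Box(low=0., high=float(horizon), shape=(1,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(2)
        self.horizon = horizon

    def reset(self):
        self.t = 0
        return np.array([0.], dtype=np.float32)

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon  # true terminal, not a time-limit cut
        return np.array([float(self.t)], dtype=np.float32), 1.0, done, {}
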
def test_episode_runner(env_id, num_env, init_seed, T):
    if env_id == 'Sanity':
        make_env = lambda: TimeLimit(SanityEnv())
    else:
        make_env = lambda: gym.make(env_id)
    env = make_vec_env(make_env, num_env, init_seed)
    env = VecStepInfo(env)
    agent = RandomAgent(None, env, None)
    runner = EpisodeRunner()
    if num_env > 1:  # EpisodeRunner supports only a single environment
        with pytest.raises(AssertionError):
            D = runner(agent, env, T)
    else:
        with pytest.raises(AssertionError):
            runner(agent, env.env, T)  # must be VecStepInfo
        D = runner(agent, env, T)
        for traj in D:
            assert isinstance(traj, Trajectory)
            assert len(traj) <= env.spec.max_episode_steps
            assert traj.numpy_observations.shape == (len(traj) + 1, *env.observation_space.shape)
            if isinstance(env.action_space, gym.spaces.Discrete):
                assert traj.numpy_actions.shape == (len(traj),)
            else:
                assert traj.numpy_actions.shape == (len(traj), *env.action_space.shape)
            assert traj.numpy_rewards.shape == (len(traj),)
            assert traj.numpy_dones.shape == (len(traj),)
            assert traj.numpy_masks.shape == (len(traj),)
            assert len(traj.step_infos) == len(traj)
            if traj.completed:
                assert np.allclose(traj.observations[-1], traj.step_infos[-1]['last_observation'])

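# A standalone usage sketch of EpisodeRunner, restricted to the API exercised
# by the test above. The function name, the choice of CartPole-v1, and the
# 100-step budget are arbitrary assumptions for illustration.
def _episode_runner_usage_sketch():
    env = VecStepInfo(make_vec_env(lambda: gym.make('CartPole-v1'), 1, 0))
    agent = RandomAgent(None, env, None)
    runner = EpisodeRunner()
    trajectories = runner(agent, env, 100)  # may include one incomplete trajectory
    return [traj.numpy_rewards.sum() for traj in trajectories]  # per-episode returns
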
def run(config, seed, device, logdir):
    set_global_seeds(seed)

    env = make_env(config, seed)
    env = VecMonitor(env)
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    env = VecStepInfo(env)

    agent = Agent(config, env, device)
    runner = EpisodeRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)
    train_logs = []
    checkpoint_count = 0
    for i in count():
        if agent.total_timestep >= config['train.timestep']:
            break
        train_logger = engine.train(i)
        train_logs.append(train_logger.logs)
        if i == 0 or (i + 1) % config['log.freq'] == 0:
            train_logger.dump(keys=None, index=0, indent=0, border='-'*50)
        if agent.total_timestep >= int(config['train.timestep']*(checkpoint_count/(config['checkpoint.num'] - 1))):
            agent.checkpoint(logdir, i + 1)
            checkpoint_count += 1
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    return None

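# The checkpoint condition in run() spreads config['checkpoint.num'] checkpoints
# evenly across training: checkpoint k fires once total_timestep crosses the
# fraction k/(num-1) of the budget. A worked example with hypothetical numbers
# (train.timestep = 1_000_000, checkpoint.num = 5), i.e. checkpoints at 0%,
# 25%, 50%, 75%, and 100% of training:
def _checkpoint_thresholds_sketch(total=1_000_000, num=5):
    thresholds = [int(total * (k / (num - 1))) for k in range(num)]
    assert thresholds == [0, 250_000, 500_000, 750_000, 1_000_000]
    return thresholds
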
def run(config, seed, device, logdir):
    set_global_seeds(seed)

    env = make_env(config, seed)
    env = VecMonitor(env)
    env = VecStepInfo(env)
    eval_env = make_env(config, seed)
    eval_env = VecMonitor(eval_env)

    agent = Agent(config, env, device)
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
    train_logs, eval_logs = engine.train()
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')
    return None

def make_env(config, seed, mode):
    assert mode in ['train', 'eval']
    def _make_env():
        env = gym.make(config['env.id'])
        if config['env.clip_action'] and isinstance(env.action_space, Box):
            env = ClipAction(env)
        return env
    env = make_vec_env(_make_env, 1, seed)  # single environment
    env = VecMonitor(env)
    if mode == 'train':
        if config['env.standardize_obs']:
            env = VecStandardizeObservation(env, clip=5.)
        if config['env.standardize_reward']:
            env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
        env = VecStepInfo(env)
    return env

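# Usage sketch with a hypothetical config dict: only the training env gets
# standardization and the VecStepInfo wrapper, while the eval env keeps raw
# observations and rewards so evaluation returns stay comparable across runs.
# The config values below are placeholders, not defaults from the library.
def _make_env_usage_sketch():
    config = {'env.id': 'Pendulum-v0', 'env.clip_action': True,
              'env.standardize_obs': True, 'env.standardize_reward': True,
              'agent.gamma': 0.99}
    train_env = make_env(config, seed=0, mode='train')  # VecStepInfo-wrapped
    eval_env = make_env(config, seed=0, mode='eval')    # plain VecMonitor env
    return train_env, eval_env
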
def test_vec_step_info(num_env, init_seed):
    make_env = lambda: gym.make('Pendulum-v0')
    env = make_vec_env(make_env, num_env, init_seed)
    env = VecStepInfo(env)
    observations, step_infos = env.reset()
    assert all([isinstance(step_info, StepInfo) for step_info in step_infos])
    assert all([step_info.first for step_info in step_infos])
    assert all([not step_info.mid for step_info in step_infos])
    assert all([not step_info.last for step_info in step_infos])
    assert all([not step_info.time_limit for step_info in step_infos])
    assert all([not step_info.terminal for step_info in step_infos])
    for _ in range(5000):
        observations, rewards, step_infos = env.step([env.action_space.sample() for _ in range(num_env)])
        for step_info in step_infos:
            assert isinstance(step_info, StepInfo)
            if step_info.last:
                assert step_info.done
                assert np.allclose(step_info['last_observation'], step_info.info['last_observation'])
                assert not step_info.first and not step_info.mid
                # Pendulum episodes end only when cut off by TimeLimit
                assert 'TimeLimit.truncated' in step_info.info
                assert step_info.time_limit
                assert not step_info.terminal
            else:
                assert not step_info.done
                assert step_info.mid
                assert not step_info.first and not step_info.last
                assert not step_info.time_limit
                assert not step_info.terminal
    del make_env, env

    make_env = lambda: gym.make('CartPole-v1')
    env = make_vec_env(make_env, num_env, init_seed)
    env = VecStepInfo(env)
    observations, step_infos = env.reset()
    assert all([isinstance(step_info, StepInfo) for step_info in step_infos])
    assert all([step_info.first for step_info in step_infos])
    assert all([not step_info.mid for step_info in step_infos])
    assert all([not step_info.last for step_info in step_infos])
    assert all([not step_info.time_limit for step_info in step_infos])
    assert all([not step_info.terminal for step_info in step_infos])
    for _ in range(5000):
        observations, rewards, step_infos = env.step([env.action_space.sample() for _ in range(num_env)])
        for step_info in step_infos:
            assert isinstance(step_info, StepInfo)
            if step_info.last:
                assert step_info.done
                assert np.allclose(step_info['last_observation'], step_info.info['last_observation'])
                assert not step_info.first and not step_info.mid
                # under random actions, CartPole ends at a true terminal state before the time limit
                assert 'TimeLimit.truncated' not in step_info.info
                assert not step_info.time_limit
                assert step_info.terminal
            else:
                assert not step_info.done
                assert step_info.mid
                assert not step_info.first and not step_info.last
                assert not step_info.time_limit
                assert not step_info.terminal

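# The time_limit/terminal split that this test pins down is what makes correct
# bootstrapping possible: a TimeLimit cut is not a true terminal state, so a
# value target should still bootstrap from the last observation. Below is a
# sketch of the usual one-step TD target under this convention; the function
# name and signature are illustrative, not part of the library.
def _td_target_sketch(reward, next_value, gamma, step_info):
    if step_info.terminal:
        return reward  # true terminal state: no bootstrapping
    return reward + gamma * next_value  # mid-episode or time-limit cut: bootstrap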