Example #1
0
def test_trajectory(init_seed, T):
    """Roll out a single SanityEnv episode by hand and check ``Trajectory`` bookkeeping.

    One vectorized environment is stepped for at most ``T`` steps and every
    transition is recorded into a fresh ``Trajectory``. On the step flagged
    ``last``, the observation stored is ``step_info['last_observation']``
    (wrapped in a list to keep the batch dimension) rather than the
    observation returned by ``step``. Once that happens, adding another
    observation must raise ``AssertionError``. Afterwards, the lengths of all
    per-step buffers and the completion flags are cross-checked.
    """
    venv = VecStepInfo(make_vec_env(lambda: TimeLimit(SanityEnv()), 1, init_seed))  # single environment

    traj = Trajectory()
    assert len(traj) == 0
    assert not traj.completed

    obs, _ = venv.reset()
    traj.add_observation(obs)
    for _ in range(T):
        action = [venv.action_space.sample()]
        next_obs, rewards, step_infos = venv.step(action)
        # Unbatch the reward and step info of the only environment.
        reward = rewards[0]
        info = step_infos[0]
        # NOTE(review): presumably next_obs already belongs to a new episode
        # on the last step, hence the stored 'last_observation' is used.
        traj.add_observation([info['last_observation']] if info.last else next_obs)
        traj.add_action(action)
        traj.add_reward(reward)
        traj.add_step_info(info)
        obs = next_obs
        if info.last:
            # A finished trajectory must refuse any further observation.
            with pytest.raises(AssertionError):
                traj.add_observation(obs)
            break

    n = len(traj)
    assert 0 < n <= T
    # Observations carry one extra entry: the initial observation from reset().
    assert n + 1 == len(traj.observations)
    assert n + 1 == len(traj.numpy_observations)
    for per_step in (traj.actions, traj.numpy_actions, traj.rewards,
                     traj.numpy_rewards, traj.numpy_dones, traj.numpy_masks):
        assert n == len(per_step)
    # Masks are the element-wise negation of dones.
    assert np.allclose(np.logical_not(traj.numpy_dones), traj.numpy_masks)
    assert n == len(traj.step_infos)

    if n < T:
        # The loop only stops early on a 'last' step; the test expects
        # SanityEnv to end on a terminal state in that case.
        assert info.last
        assert traj.completed
        assert traj.reach_terminal
        assert not traj.reach_time_limit
        assert np.allclose(traj.observations[-1], [info['last_observation']])
    if not info.last:
        assert not traj.completed
        assert not traj.reach_terminal
        assert not traj.reach_time_limit
Example #2
0
def test_episode_runner(env_id, num_env, init_seed, T):
    """Check ``EpisodeRunner`` input validation and the shapes of its trajectories.

    The runner must raise ``AssertionError`` for a vectorized env with more
    than one sub-environment, and for an env that is not wrapped in
    ``VecStepInfo``. For a valid single environment, every returned
    ``Trajectory`` must have buffers whose lengths and shapes agree with the
    trajectory length.
    """
    def make_env():
        if env_id == 'Sanity':
            return TimeLimit(SanityEnv())
        return gym.make(env_id)

    env = VecStepInfo(make_vec_env(make_env, num_env, init_seed))
    agent = RandomAgent(None, env, None)
    runner = EpisodeRunner()

    if num_env > 1:
        # Only single-environment rollouts are supported.
        with pytest.raises(AssertionError):
            runner(agent, env, T)
        return

    # The inner env (without the VecStepInfo wrapper) must be rejected.
    with pytest.raises(AssertionError):
        runner(agent, env.env, T)

    for traj in runner(agent, env, T):
        assert isinstance(traj, Trajectory)
        n = len(traj)
        assert n <= env.spec.max_episode_steps
        # One more observation than steps (the initial one).
        assert traj.numpy_observations.shape == (n + 1, *env.observation_space.shape)
        # Discrete actions are scalars; other spaces keep their own shape.
        action_tail = () if isinstance(env.action_space, gym.spaces.Discrete) else env.action_space.shape
        assert traj.numpy_actions.shape == (n, *action_tail)
        for per_step in (traj.numpy_rewards, traj.numpy_dones, traj.numpy_masks):
            assert per_step.shape == (n,)
        assert len(traj.step_infos) == n
        if traj.completed:
            assert np.allclose(traj.observations[-1], traj.step_infos[-1]['last_observation'])
Example #3
0
def run(config, seed, device, logdir):
    """Train an agent with ``EpisodeRunner`` until the timestep budget is spent.

    Builds a monitored, optionally observation/reward-standardized vectorized
    env wrapped in ``VecStepInfo``, then repeatedly calls ``engine.train(i)``
    until ``agent.total_timestep`` reaches ``config['train.timestep']``.
    Logs are dumped on the first iteration and every ``config['log.freq']``
    iterations. Checkpoints are saved when the timestep count crosses
    ``config['checkpoint.num']`` thresholds evenly spaced from 0 to the full
    budget (at most one checkpoint per iteration). The collected training logs
    are written with ``pickle_dump(..., f=logdir/'train_logs', ext='.pkl')``.

    Args:
        config: mapping of configuration keys used below.
        seed: global random seed.
        device: device passed to the agent.
        logdir: output directory (path-like supporting ``/``).

    Returns:
        None
    """
    set_global_seeds(seed)

    env = make_env(config, seed)
    env = VecMonitor(env)
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    env = VecStepInfo(env)

    agent = Agent(config, env, device)
    runner = EpisodeRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)
    train_logs = []
    checkpoint_count = 0
    num_checkpoints = config['checkpoint.num']
    for i in count():
        if agent.total_timestep >= config['train.timestep']:
            break
        train_logger = engine.train(i)
        train_logs.append(train_logger.logs)
        if i == 0 or (i+1) % config['log.freq'] == 0:
            train_logger.dump(keys=None, index=0, indent=0, border='-'*50)
        # Thresholds are the fractions 0, 1/(n-1), ..., 1 of the budget.
        if num_checkpoints == 1:
            # A single checkpoint is taken at the end of training; this also
            # avoids the division by zero that n - 1 == 0 would cause.
            fraction = 1.0
        else:
            fraction = checkpoint_count / (num_checkpoints - 1)
        if agent.total_timestep >= int(config['train.timestep'] * fraction):
            agent.checkpoint(logdir, i + 1)
            checkpoint_count += 1
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    return None
Example #4
0
def run(config, seed, device, logdir):
    """Train an agent using a replay buffer and a separate evaluation env.

    The training env is monitored and wrapped in ``VecStepInfo``. The
    evaluation env is built with the same config and seed and is only
    monitored. ``engine.train()`` returns the training and evaluation logs,
    which are pickled under ``logdir`` as ``train_logs`` and ``eval_logs``
    (``ext='.pkl'``).

    Returns:
        None
    """
    set_global_seeds(seed)

    # Training environment: monitor first, then attach per-step info.
    env = VecStepInfo(VecMonitor(make_env(config, seed)))
    # Evaluation environment: monitored only.
    eval_env = VecMonitor(make_env(config, seed))

    agent = Agent(config, env, device)
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)

    train_logs, eval_logs = engine.train()
    for name, logs in (('train_logs', train_logs), ('eval_logs', eval_logs)):
        pickle_dump(obj=logs, f=logdir/name, ext='.pkl')
    return None
Example #5
0
def make_env(config, seed, mode):
    """Build a single-environment vectorized env for training or evaluation.

    Both modes create ``config['env.id']`` (wrapped in ``ClipAction`` when
    ``config['env.clip_action']`` is set and the action space is a ``Box``)
    and add ``VecMonitor``. Training mode additionally applies the optional
    observation/reward standardization and a final ``VecStepInfo``.

    Args:
        config: mapping of configuration keys used below.
        seed: seed forwarded to ``make_vec_env``.
        mode: ``'train'`` or ``'eval'``.

    Returns:
        The wrapped vectorized environment.
    """
    assert mode in ['train', 'eval']

    def _make_env():
        base = gym.make(config['env.id'])
        clip = config['env.clip_action'] and isinstance(base.action_space, Box)
        return ClipAction(base) if clip else base

    env = VecMonitor(make_vec_env(_make_env, 1, seed))  # single environment
    if mode != 'train':
        return env

    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    return VecStepInfo(env)
Example #6
0
File: test_envs.py  Project: jlqzzz/lagom
def test_vec_step_info(num_env, init_seed):
    """Check the ``StepInfo`` flags produced by ``VecStepInfo`` on two gym envs.

    After ``reset()`` every step info must be ``first`` only. During 5000
    random steps, non-final steps must be ``mid`` only. Final steps must be
    ``done`` and expose ``last_observation``. On Pendulum-v0 they are expected
    to be TimeLimit truncations; on CartPole-v1 they are expected to be
    terminal states.
    """
    def check(env_id, truncated):
        # truncated=True: every episode end must be a TimeLimit truncation;
        # truncated=False: every episode end must be a terminal state.
        make_env = lambda: gym.make(env_id)
        env = VecStepInfo(make_vec_env(make_env, num_env, init_seed))

        _, infos = env.reset()
        assert all(isinstance(info, StepInfo) for info in infos)
        assert all(info.first for info in infos)
        assert not any(info.mid or info.last or info.time_limit or info.terminal
                       for info in infos)

        for _step in range(5000):
            actions = [env.action_space.sample() for _ in range(num_env)]
            _, _, infos = env.step(actions)
            for info in infos:
                assert isinstance(info, StepInfo)
                if not info.last:
                    assert not info.done
                    assert info.mid
                    assert not (info.first or info.last)
                    assert not info.time_limit
                    assert not info.terminal
                    continue
                assert info.done
                assert np.allclose(info['last_observation'],
                                   info.info['last_observation'])
                assert not (info.first or info.mid)
                if truncated:
                    assert 'TimeLimit.truncated' in info.info
                    assert info.time_limit
                    assert not info.terminal
                else:
                    assert 'TimeLimit.truncated' not in info.info
                    assert not info.time_limit
                    assert info.terminal

    # Pendulum episodes are expected to be cut by TimeLimit.
    check('Pendulum-v0', truncated=True)
    # CartPole episodes are expected to hit a terminal state under random actions.
    check('CartPole-v1', truncated=False)