Ejemplo n.º 1
0
def run(config, seed, device):
    """Train an agent until ``config['train.timestep']`` env steps have been taken.

    Seeds global RNGs, builds a monitored vectorized environment (optionally
    standardizing observations and/or rewards), then repeatedly calls
    ``engine.train`` while printing logs every ``config['log.freq']``
    iterations and checkpointing every ``config['checkpoint.freq']``
    iterations (both also on the very first iteration). When the timestep
    budget is reached a final checkpoint is written and the collected
    per-iteration logs are saved with ``pickle_dump`` as ``train_logs``
    (``ext='.pkl'``) under ``<log.dir>/<ID>/<seed>``.

    Args:
        config: experiment configuration indexed by string keys.
        seed: random seed; also the last component of the log directory.
        device: device forwarded to ``Agent``.

    Returns:
        None.
    """
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)

    env = VecMonitor(make_env(config, seed))
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])

    agent = Agent(config, env, device)
    engine = Engine(config, agent=agent, env=env,
                    runner=EpisodeRunner(reset_on_call=False))

    train_logs = []
    n_iter = 0  # number of training iterations completed so far
    while not (agent.total_timestep >= config['train.timestep']):
        train_logger = engine.train(n_iter)
        train_logs.append(train_logger.logs)
        first_iter = n_iter == 0
        if first_iter or (n_iter + 1) % config['log.freq'] == 0:
            train_logger.dump(keys=None, index=0, indent=0, border='-'*50)
        if first_iter or (n_iter + 1) % config['checkpoint.freq'] == 0:
            agent.checkpoint(logdir, n_iter + 1)
        n_iter += 1

    # NOTE(review): in-loop checkpoints are labelled with the completed
    # iteration count, but this final one uses count + 1. Kept as in the
    # original so checkpoint file names are unchanged; confirm intended.
    agent.checkpoint(logdir, n_iter + 1)
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
Ejemplo n.º 2
0
def initializer(config, seed, device):
    """Create a module-global ``env`` and ``agent`` for the current process.

    Presumably used as a per-worker initializer (e.g. for a process pool) so
    each worker owns its own environment and agent. TODO confirm against
    the caller. The env is wrapped in ``VecMonitor`` and, when
    ``config['env.standardize_obs']`` is truthy, in
    ``VecStandardizeObservation`` with ``clip=5.``.

    Args:
        config: experiment configuration indexed by string keys.
        seed: seed forwarded to ``make_env``.
        device: device forwarded to ``Agent``.
    """
    global env, agent
    # Restrict torch to a single thread in this process. The original author
    # flags this as VERY IMPORTANT to avoid getting stuck (likely thread
    # oversubscription across worker processes; TODO confirm).
    torch.set_num_threads(1)
    monitored = VecMonitor(make_env(config, seed))
    env = (VecStandardizeObservation(monitored, clip=5.)
           if config['env.standardize_obs'] else monitored)
    agent = Agent(config, env, device)
Ejemplo n.º 3
0
def run(config, seed, device, logdir):
    """Train an agent with a replay buffer and persist train/eval logs.

    Builds a training env (``VecMonitor`` + ``VecStepInfo``) and a separate
    evaluation env (``VecMonitor`` only), both created from the same
    ``seed``. It then runs ``engine.train()``, which returns
    ``(train_logs, eval_logs)``. Each log list is saved under ``logdir``
    with ``pickle_dump`` (``ext='.pkl'``).

    Args:
        config: experiment configuration indexed by string keys.
        seed: random seed for global RNGs and both environments.
        device: device forwarded to ``Agent`` and ``ReplayBuffer``.
        logdir: ``pathlib.Path``-like directory for outputs (must support ``/``).

    Returns:
        None.
    """
    set_global_seeds(seed)

    env = VecStepInfo(VecMonitor(make_env(config, seed)))
    eval_env = VecMonitor(make_env(config, seed))

    agent = Agent(config, env, device)
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env,
                    replay=replay, logdir=logdir)

    train_logs, eval_logs = engine.train()
    for stem, logs in (('train_logs', train_logs), ('eval_logs', eval_logs)):
        pickle_dump(obj=logs, f=logdir/stem, ext='.pkl')
Ejemplo n.º 4
0
def test_vec_monitor(env_id, num_env, init_seed):
    """VecMonitor reports episode statistics and resets per-env counters.

    Takes 2000 random-action steps. Whenever a sub-environment reports
    ``done``, its info dict must contain ``last_observation`` and an
    ``episode`` record with ``return``, ``horizon`` and ``time``. That
    sub-environment's running reward and horizon must also already be zero.
    At the end, both history queues must hold ``min(100, finished)`` entries.
    """
    env = VecMonitor(make_vec_env(lambda: gym.make(env_id), num_env, init_seed))
    env.reset()

    finished = 0
    for _ in range(2000):
        actions = [env.action_space.sample() for _ in range(len(env))]
        _, _, dones, infos = env.step(actions)
        for idx, (done, info) in enumerate(zip(dones, infos)):
            if not done:
                continue
            assert 'last_observation' in info
            assert 'episode' in info
            episode = info['episode']
            for key in ('return', 'horizon', 'time'):
                assert key in episode
            assert env.episode_rewards[idx] == 0.0
            assert env.episode_horizons[idx] == 0.0
            finished += 1

    expected_len = min(100, finished)
    assert expected_len == len(env.return_queue)
    assert expected_len == len(env.horizon_queue)
Ejemplo n.º 5
0
def make_env(config, seed, mode):
    """Build a single-environment vectorized env for training or evaluation.

    Every env is created via ``gym.make(config['env.id'])``. ``ClipAction`` is
    applied when ``config['env.clip_action']`` is truthy and the action space
    is a ``Box``. The env is then vectorized (one copy) and wrapped in
    ``VecMonitor``. In ``'train'`` mode it is additionally wrapped in
    observation/reward standardization (each controlled by config) and
    finally ``VecStepInfo``. ``'eval'`` mode gets none of these extra wrappers.

    Args:
        config: configuration read via ``'env.id'``, ``'env.clip_action'``,
            ``'env.standardize_obs'``, ``'env.standardize_reward'`` and
            ``'agent.gamma'``.
        seed: seed forwarded to ``make_vec_env``.
        mode: ``'train'`` or ``'eval'``.

    Returns:
        The wrapped vectorized environment.

    Raises:
        ValueError: if ``mode`` is not ``'train'`` or ``'eval'``.
    """
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``. A typo such as 'evaluate' would then silently produce an
    # eval-style env (no standardization, no VecStepInfo).
    if mode not in ('train', 'eval'):
        raise ValueError(
            "mode must be 'train' or 'eval', got {!r}".format(mode))

    def _make_env():
        env = gym.make(config['env.id'])
        # ClipAction is only applied to continuous (Box) action spaces.
        if config['env.clip_action'] and isinstance(env.action_space, Box):
            env = ClipAction(env)
        return env

    env = make_vec_env(_make_env, 1, seed)  # single environment
    env = VecMonitor(env)
    if mode == 'train':
        if config['env.standardize_obs']:
            env = VecStandardizeObservation(env, clip=5.)
        if config['env.standardize_reward']:
            env = VecStandardizeReward(env,
                                       clip=10.,
                                       gamma=config['agent.gamma'])
        env = VecStepInfo(env)
    return env
Ejemplo n.º 6
0
def test_get_wrapper(env_id):
    """``get_wrapper`` locates wrappers by class name in env stacks.

    For a plain ``gym`` env wrapped as ClipReward -> FlattenObservation ->
    FrameStack, the first two wrappers must be found by name and the base
    name ``'Env'`` must yield ``None``. For a 3-env vectorized env wrapped in
    ``VecMonitor``, ``'VecMonitor'`` must be found and ``'ClipReward'`` must
    yield ``None``.
    """
    def build():
        return gym.make(env_id)

    stacked = FrameStack(FlattenObservation(ClipReward(build(), 0.1, 0.5)), 4)
    for wrapper_name in ('ClipReward', 'FlattenObservation'):
        assert get_wrapper(stacked, wrapper_name).__class__.__name__ == wrapper_name
    assert get_wrapper(stacked, 'Env') is None

    del stacked

    # vec_env
    vec_env = VecMonitor(make_vec_env(build, 3, 0))
    assert get_wrapper(vec_env, 'VecMonitor').__class__.__name__ == 'VecMonitor'
    assert get_wrapper(vec_env, 'ClipReward') is None