Code Example #1
File: run_atari.py Project: christopher-hsu/ray
def train():

    logger.configure()
    set_global_seeds(args.seed)

    directory = os.path.join(
        args.log_dir,
        '_'.join([args.env,
                  datetime.datetime.now().strftime("%m%d%H%M")]))
    if not os.path.exists(directory):
        os.makedirs(directory)
    else:
        raise ValueError("The directory already exists: %s" % directory)
    json.dump(vars(args),
              open(os.path.join(directory, 'learning_prop.json'), 'w'))

    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    nb_test_steps = args.nb_test_steps if args.nb_test_steps > 0 else None
    if args.record == 1:
        env = Monitor(env, directory=args.log_dir)
    with tf.device(args.device):
        model = deepq.models.cnn_to_mlp(
            # each tuple is one conv layer: (num_outputs, kernel_size, stride)
            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            hiddens=[256],
            dueling=bool(args.dueling),
        )

        act, records = deepq.learn(
            env,
            q_func=model,
            lr=args.learning_rate,
            lr_decay_factor=args.learning_rate_decay_factor,
            lr_growth_factor=args.learning_rate_growth_factor,
            max_timesteps=args.nb_train_steps,
            buffer_size=args.buffer_size,
            exploration_fraction=args.eps_fraction,
            exploration_final_eps=args.eps_min,
            train_freq=4,
            print_freq=1000,
            checkpoint_freq=int(args.nb_train_steps / 10),
            learning_starts=args.nb_warmup_steps,
            target_network_update_freq=args.target_update_freq,
            gamma=0.99,
            prioritized_replay=bool(args.prioritized),
            prioritized_replay_alpha=args.prioritized_replay_alpha,
            epoch_steps=args.nb_epoch_steps,
            gpu_memory=args.gpu_memory,
            double_q=args.double_q,
            save_dir=directory,
            nb_test_steps=nb_test_steps,
            scope=args.scope,
            test_eps=args.test_eps,
        )
        print("Saving model to model.pkl")
        act.save(os.path.join(directory, "model.pkl"))
    env.close()
    plot(records, directory)
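train() reads every hyperparameter from a module-level args object rather than taking arguments. Below is a minimal sketch of the argparse setup it assumes; the flag names are taken from the attributes referenced above, while the defaults are illustrative only and not from the original project.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--env', default='BreakoutNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_dir', default='.')
parser.add_argument('--device', default='/gpu:0')
parser.add_argument('--record', type=int, default=0)
parser.add_argument('--dueling', type=int, default=1)
parser.add_argument('--double_q', type=int, default=1)
parser.add_argument('--nb_train_steps', type=int, default=int(1e7))
parser.add_argument('--nb_test_steps', type=int, default=0)
# ...plus the remaining attributes referenced in train() (learning_rate,
# learning_rate_decay_factor, learning_rate_growth_factor, buffer_size,
# eps_fraction, eps_min, nb_warmup_steps, target_update_freq, prioritized,
# prioritized_replay_alpha, nb_epoch_steps, gpu_memory, scope, test_eps),
# defined the same way.
args = parser.parse_args()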
Code Example #2
def test(env_id,
         isAtari,
         act_greedy,
         nb_itrs=3,
         nb_test_steps=10000,
         render=False):
    total_rewards = []
    for _ in range(nb_itrs):
        if isAtari:
            from baselines0.common.atari_wrappers import make_atari
            env_new = make_atari(env_id)
            env_new = deepq.wrap_atari_dqn(env_new)
        else:
            env_new = envs.make(env_id, render, figID=1)
        obs = env_new.reset()

        if nb_test_steps is None:
            done_test = False
            episode_reward = 0
            t = 0
            while not done_test:
                action = act_greedy(np.array(obs)[None])[0]
                obs, rew, done, info = env_new.step(action)
                if render:
                    env_new.render(mode='test')
                episode_reward += rew
                t += 1
                if done:
                    obs = env_new.reset()
                    if (isAtari and (info['ale.lives'] == 0)) or (not isAtari):
                        done_test = done
            if render:
                env_new.close()
            total_rewards.append(episode_reward)
        else:
            t = 0
            episodes = []
            episode_reward = 0
            while t < nb_test_steps:
                action = act_greedy(np.array(obs)[None])[0]
                obs, rew, done, info = env_new.step(action)
                episode_reward += rew
                t += 1
                if done:
                    obs = env_new.reset()
                    if (isAtari and (info['ale.lives'] == 0)) or (not isAtari):
                        episodes.append(episode_reward)
                        episode_reward = 0
            if not episodes:
                episodes.append(episode_reward)
            total_rewards.append(np.mean(episodes))

    return np.array(total_rewards, dtype=np.float32)
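For orientation, here is a usage sketch of this evaluation helper. The only assumption is that act_greedy is a callable mapping a batched observation array to actions, for example the act object returned by deepq.learn in Code Example #1.

# act is a trained policy, e.g. the object returned by deepq.learn
scores = test('BreakoutNoFrameskip-v4', isAtari=True, act_greedy=act,
              nb_itrs=3, nb_test_steps=10000)
print("mean return over %d runs: %.2f" % (len(scores), scores.mean()))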
Code Example #3
def test():
    from baselines0.deepq.utils import BatchInput
    import json
    learning_prop = json.load(
        open(os.path.join(args.log_dir, 'learning_prop.json'), 'r'))

    env = make_atari(args.env)
    env = models.wrap_atari_dqn(env)
    observation_space_shape = env.observation_space.shape

    def make_obs_ph(name):
        return BatchInput(observation_space_shape, name=name)

    model = models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[learning_prop['num_units']] * learning_prop['num_layers'],
        dueling=bool(args.dueling),
        init_mean=args.init_mean,
        init_sd=args.init_sd,
    )

    act_params = {
        'make_obs_ph': make_obs_ph,
        'q_func': model,
        'scope': learning_prop['scope'],
        'eps': args.test_eps
    }
    act = simple.load(os.path.join(args.log_dir, args.log_fname), act_params)
    if args.record:
        env = Monitor(env, directory=args.log_dir)
    episode_rew = 0
    t = 0
    while True:
        obs, done = env.reset(), False
        while not done:
            if args.render:
                env.render()
                time.sleep(0.05)
            obs, rew, done, info = env.step(act(obs[None])[0])
            # Reset only the environment but not the recorder
            if args.record and done:
                obs, done = env.env.reset(), False
            episode_rew += rew
            t += 1
        if info['ale.lives'] == 0:
            print("Episode reward %.2f after %d steps" % (episode_rew, t))
            episode_rew = 0
            t = 0
Code Example #4
File: __init__.py Project: christopher-hsu/ADFQ
def make(env_name, type, render=False, record=False, directory='', **kwargs):
    """
    env_name : str
        name of an environment. (e.g. 'CartPole-v0')
    type : str
        type of an environment. One of ['atari', 'classic_control',
        'classic_mdp','target_tracking']
    """
    if type == 'atari':
        from baselines0.common.atari_wrappers import make_atari
        from baselines0.common.atari_wrappers import wrap_deepmind
        from baselines0 import bench, logger

        env = make_atari(env_name)
        env = bench.Monitor(env, logger.get_dir())
        env = wrap_deepmind(env, frame_stack=True, scale=True)
        if record:
            env = Monitor(env, directory=directory)

    elif type == 'classic_control':
        env = gym.make(env_name)
        if record:
            env = Monitor(env, directory=directory)

    elif type == 'classic_mdp':
        from envs import classic_mdp
        env = classic_mdp.model_assign(env_name)

    elif type == 'target_tracking':
        import ttenv
        env = ttenv.make(env_name,
                         render=render,
                         record=record,
                         directory=directory,
                         **kwargs)
    elif type == 'ma_target_tracking':
        import maTTenv
        env = maTTenv.make(env_name,
                           render=render,
                           record=record,
                           directory=directory,
                           **kwargs)
    else:
        raise ValueError('Unknown environment type: %s' % type)

    return env
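A short usage sketch of the factory above. The environment IDs and the num_targets keyword are illustrative only; check the ttenv/maTTenv packages for the IDs and keyword arguments they actually accept.

# classic control, no recording
env = make('CartPole-v0', 'classic_control')

# Atari with episode recording
env = make('BreakoutNoFrameskip-v4', 'atari', record=True, directory='./videos')

# target tracking; extra keyword arguments are forwarded to ttenv.make
env = make('TargetTracking-v0', 'target_tracking', render=True, num_targets=2)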
Code Example #5
File: run_atari.py Project: christopher-hsu/ray
def test():
    from baselines0.deepq.utils import BatchInput

    env = make_atari(args.env)
    env = deepq.wrap_atari_dqn(env)
    observation_space_shape = env.observation_space.shape

    def make_obs_ph(name):
        return BatchInput(observation_space_shape, name=name)

    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[args.num_units] * args.num_layers,
        dueling=bool(args.dueling),
    )
    act_params = {
        'make_obs_ph': make_obs_ph,
        'q_func': model,
        'scope': args.scope
    }
    act = deepq.load(os.path.join(args.log_dir, args.log_fname), act_params)
    if args.record:
        env = Monitor(env, directory=args.log_dir)
    episode_rew = 0
    t = 0
    while True:
        obs, done = env.reset(), False

        while not done:
            if not args.record:
                env.render()
                #time.sleep(0.01)
            obs, rew, done, info = env.step(act(obs[None])[0])
            episode_rew += rew
            t += 1
        if info['ale.lives'] == 0:
            print("Episode reward %.2f after %d steps" % (episode_rew, t))
            episode_rew = 0
            t = 0
Code Example #6
File: cmd_util.py Project: christopher-hsu/ray
def _thunk():
    env = make_atari(env_id)
    env.seed(seed + rank)
    env = Monitor(env, logger.get_dir() and
                  os.path.join(logger.get_dir(), str(rank)))
    return wrap_deepmind(env, **wrapper_kwargs)
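The nested _thunk is the usual baselines pattern for vectorized environments: each worker rank gets its own seeded, monitored copy of the game. Below is a sketch of the surrounding factory that consumes such thunks, assuming import paths that follow the upstream baselines layout (they may differ slightly in the baselines0 fork).

import os
from baselines0 import logger
from baselines0.bench import Monitor
from baselines0.common.atari_wrappers import make_atari, wrap_deepmind
from baselines0.common.vec_env.subproc_vec_env import SubprocVecEnv

def make_atari_vec_env(env_id, num_env, seed, wrapper_kwargs=None):
    """Build num_env seeded, monitored Atari workers and run them in subprocesses."""
    wrapper_kwargs = wrapper_kwargs or {}

    def make_env(rank):
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            env = Monitor(env, logger.get_dir() and
                          os.path.join(logger.get_dir(), str(rank)))
            return wrap_deepmind(env, **wrapper_kwargs)
        return _thunk

    return SubprocVecEnv([make_env(i) for i in range(num_env)])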