Code example #1
File: Main.py  Project: yusuf000/thesisRL
def train():
    # Resolve the env id and build a CartPole-v0 environment set up for deepq.
    env_type, env_id = get_env_type("CartPole-v0")
    env = build_env(env_id, 'deepq')
    # Train a DQN agent with an MLP Q-network; checkpoint_path and save_path
    # control where intermediate checkpoints and the final model are written.
    model = deepq.learn(env,
                        network='mlp',
                        lr=1e-3,
                        total_timesteps=100000,
                        buffer_size=50000,
                        checkpoint_freq=100,
                        checkpoint_path="cartpole_model",
                        exploration_fraction=0.1,
                        exploration_final_eps=0.02,
                        print_freq=10,
                        save_path="model/cartpole_model",
                        callback=callback)
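Here `callback`, `get_env_type`, and `build_env` are defined elsewhere in the project (the latter two apparently mirror helpers of the same name in baselines' run.py). Baselines' deepq calls the callback with the training loop's locals and globals and stops training once it returns True, so a typical CartPole callback looks roughly like the following sketch (an assumption, not the project's actual code):

def callback(lcl, _glb):
    # Stop training once the mean reward over the last 100 completed episodes
    # reaches the CartPole-v0 "solved" threshold of 199.
    is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
    return is_solved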
Code example #2
def train_dqn(env_id, num_timesteps):
    """Train a DQN model.

    Parameters
    ----------
    env_id: str
        environment to train on
    num_timesteps: int
        number of env steps to optimize for

    Note: the body actually reads its settings from the global FLAGS object.
    """

    # 1. Create gym environment
    env = gym.make(FLAGS.env)

    # 2. Apply action space wrapper
    env = MarioActionSpaceWrapper(env)

    # 3. Apply observation space wrapper to reduce input size
    env = ProcessFrame84(env)

    # 4. Create a CNN model for Q-Function
    model = cnn_to_mlp(
      convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
      hiddens=[256],
      dueling=FLAGS.dueling
    )

    # 5. Train the model
    act = deepq.learn(
        env,
        q_func=model,
        lr=FLAGS.lr,
        max_timesteps=FLAGS.timesteps,
        buffer_size=10000,
        exploration_fraction=FLAGS.exploration_fraction,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=FLAGS.prioritized,
        callback=deepq_callback
    )
    act.save("mario_model.pkl")
    env.close()
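The saved policy can be reloaded later through the same old-style deepq API used above; a minimal sketch (an assumption, which presumes the same wrappers are reapplied and that FLAGS.env still names the Mario environment):

def enjoy_mario():
    # Rebuild the environment with the same wrappers used during training.
    env = gym.make(FLAGS.env)
    env = MarioActionSpaceWrapper(env)
    env = ProcessFrame84(env)
    # Load the saved act function and roll out one greedy episode.
    act = deepq.load("mario_model.pkl")
    obs, done = env.reset(), False
    while not done:
        env.render()
        obs, rew, done, _ = env.step(act(obs[None])[0])
    env.close()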
Code example #3
File: Main.py  Project: yusuf000/thesisRL
def play():
    # Rebuild the same CartPole environment and restore the trained model
    # from "model/cartpole_model" via load_path.
    env_type, env_id = get_env_type("CartPole-v0")
    env = build_env(env_id, 'deepq')
    model = deepq.learn(env,
                        network='mlp',
                        lr=1e-3,
                        total_timesteps=100000,
                        buffer_size=50000,
                        checkpoint_freq=100,
                        exploration_fraction=0.1,
                        exploration_final_eps=0.02,
                        print_freq=10,
                        load_path="model/cartpole_model",
                        callback=callback)
    obs, done = env.reset(), False
    # A plain gym env returns a single observation; add a batch dimension so
    # model.step() sees the same shape it would get from a vectorized env.
    if not isinstance(env, VecEnv):
        obs = np.expand_dims(np.array(obs), axis=0)
    # Recurrent policies expose an initial_state; feed-forward ones do not.
    state = model.initial_state if hasattr(model, 'initial_state') else None

    # One running reward accumulator per environment in the batch.
    episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
    while True:
        if state is not None:
            actions, _, state, _ = model.step(obs)
        else:
            actions, _, _, _ = model.step(obs)

        obs, rew, done, _ = env.step(actions.numpy())
        if not isinstance(env, VecEnv):
            obs = np.expand_dims(np.array(obs), axis=0)
        episode_rew += rew
        env.render()
        # Print and reset the reward counter for every env that finished an
        # episode; the loop itself runs until the process is interrupted.
        done_any = done.any() if isinstance(done, np.ndarray) else done
        if done_any:
            for i in np.nonzero(done)[0]:
                print('episode_rew={}'.format(episode_rew[i]))
                episode_rew[i] = 0

    env.close()
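One caveat about the non-vectorized branch: build_env normally returns a VecEnv, which auto-resets finished episodes, but a bare gym env does not, so the loop above would keep stepping a finished episode. A small helper along these lines (a sketch, not part of the project) could be called right after the done check:

def reset_if_needed(env, obs, done_any):
    # VecEnvs auto-reset finished sub-envs; a bare gym env must be reset
    # manually, re-adding the batch dimension expected by model.step().
    if done_any and not isinstance(env, VecEnv):
        obs = np.expand_dims(np.array(env.reset()), axis=0)
    return obs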
Code example #4
def main(seed, fraction, discount, path, gpu):
    # Pin the run to the requested GPU; `dirs` is a base output directory
    # defined elsewhere in the project.
    with tf.device('/device:GPU:%s' % gpu):

        logger.configure(dir=dirs + '%s/%s/' % (path, seed),
                         format_strs=['csv'])
        kwargs = dict(
            seed=seed,
            network=models.mlp(num_layers=2,
                               num_hidden=128,
                               activation=tf.nn.relu),
            lr=1e-4,
            total_timesteps=1500000,
            buffer_size=150000,
            #exploration_fraction=1.0, #random act
            #exploration_final_eps=1.0, #random act
            exploration_fraction=0.2,
            exploration_final_eps=0.02,
            learning_starts=2000,
            target_network_update_freq=500,
            myopic_fraction=fraction,
            final_gamma=discount,
            #gamma=discount,
            prioritized_replay=True,
            prioritized_replay_alpha=0.6,
            print_freq=5)

        # Record the hyperparameters used for this run alongside the logs.
        with open(dirs + '%s/%s/params.txt' % (path, seed), 'w') as f:
            f.write(str(kwargs))

        env = gym.make("dense-v0")
        act = deepq.learn(env=env, **kwargs)

        print("Saving model to maze.pkl")

        act.save(dirs + '%s/%s/maze.pkl' % (path, seed))
        save_plot(path, seed)
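save_plot is defined elsewhere in the project. A rough equivalent (an assumption based on the progress.csv that baselines' CSV logger writes; the column names may differ in this fork) would read the run's log and save a learning curve:

import pandas as pd
import matplotlib.pyplot as plt

def save_plot(path, seed):
    # logger.configure(..., format_strs=['csv']) writes progress.csv in the run dir.
    run_dir = dirs + '%s/%s/' % (path, seed)
    df = pd.read_csv(run_dir + 'progress.csv')
    df.plot(x='steps', y='mean 100 episode reward')
    plt.savefig(run_dir + 'learning_curve.png')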