Code Example #1
File: test.py    Project: zoskia/RL-MsPacman
def test_model(model_path, max_steps):
    dqn = DQN()
    env = gym.make("MsPacman-v0")

    # placeholder for a batch of preprocessed frames and the online Q-network
    X_state = tf.placeholder(
        tf.float32, shape=[None, input_height, input_width, input_channels])
    online_q_values, online_vars = dqn.create_model(X_state, "qnetwork_online")
    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, model_path)

        obs = env.reset()

        for step in range(max_steps):
            state = preprocess_observation(obs)

            # evaluate the Q-values for the current state and act greedily
            q_values = online_q_values.eval(feed_dict={X_state: [state]})
            action = np.argmax(q_values)

            # apply the chosen action in the environment and render the frame
            obs, reward, done, info = env.step(action)
            env.render()
            time.sleep(0.05)
            if done:
                break
    env.close()
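
A minimal sketch of how this evaluation routine might be invoked; the checkpoint path and step budget below are illustrative placeholders, not values taken from the project:

if __name__ == "__main__":
    # hypothetical call: the path and max_steps are placeholder values
    test_model("./checkpoints/mspacman_dqn.ckpt", max_steps=10000)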
Code Example #2
File: training.py    Project: zoskia/RL-MsPacman
def train_model():
    iteration = 0
    loss_val = np.infty
    game_length = 0
    total_max_q = 0
    mean_max_q = 0.0
    done = True
    state = []

    dqn = DQN()
    env = gym.make("MsPacman-v0")

    X_state = tf.placeholder(
        tf.float32, shape=[None, input_height, input_width, input_channels])

    online_q_values, online_vars = dqn.create_model(X_state, "qnetwork_online")
    target_q_values, target_vars = dqn.create_model(X_state, "qnetwork_target")

    # ops that copy the online DQN's variables into the target DQN
    copy_ops = [
        target_var.assign(online_vars[var_name])
        for var_name, target_var in target_vars.items()
    ]
    copy_online_to_target = tf.group(*copy_ops)

    # placeholders and ops used by the training step (chosen action, target y,
    # loss and the optimizer op), as returned by define_train_variables
    X_action, global_step, loss, training_op, y = define_train_variables(
        online_q_values)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:

        restore_session(copy_online_to_target, init, saver, sess)

        while True:
            step = global_step.eval()
            if step >= n_steps:
                break

            iteration += 1
            print(
                "\rIteration {}\tTraining step {}/{} ({:.1f}%)\tLoss {:5f}\tMean Max-Q {:5f}   "
                .format(iteration, step, n_steps, step * 100 / n_steps,
                        loss_val, mean_max_q),
                end="")

            state = skip_some_steps(done, env, state)

            done, q_values, next_state = evaluate_and_play_online_dqn(
                X_state, env, online_q_values, state, step)
            state = next_state

            mean_max_q = compute_statistics(done, game_length, mean_max_q,
                                            q_values, total_max_q)

            if iteration < training_start or iteration % training_interval != 0:
                continue

            loss_val = train_online_dqn(X_action, X_state, loss, sess,
                                        target_q_values, training_op, y)

            # Copy the online DQN to the target DQN
            if step % copy_steps == 0:
                copy_online_to_target.run()

            # Save model
            if step % save_steps == 0:
                saver.save(sess, checkpoint_path)
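
Both listings call a preprocess_observation helper that is not shown here. As a hedged sketch only (the project's actual preprocessing and its input_height / input_width / input_channels values may differ), a common way to prepare Ms. Pac-Man frames for a DQN is to crop, downsample and grayscale the raw 210x160 RGB observation:

import numpy as np

def preprocess_observation(obs):
    # assumed helper, not the project's code: crop off the score area,
    # downsample by a factor of 2, convert to grayscale and rescale pixels
    img = obs[1:176:2, ::2]          # 210x160x3 -> 88x80x3
    img = img.mean(axis=2)           # RGB -> grayscale
    img = (img - 128) / 128          # scale roughly to [-1, 1]
    return img.reshape(88, 80, 1)    # single channel, matching X_state's last axis

Under this sketch the network input would be 88x80x1; those dimensions are assumptions for the example, not values read from the project.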