コード例 #1
0
ファイル: ms_pacman.py プロジェクト: kekag/MsPacman-RL
    def ProcessModel(self, request, context):
        """Apply one Q-learning update from a single transition sent by a client.

        Args:
            request: A ``ModelRequest`` carrying a pickled ``state`` and
                ``next_state``, the ``reward`` received, the ``action`` taken
                and the episode ``done`` flag.
            context: RPC context supplied by the server framework (unused).

        Returns:
            ``mpm_pb2.Empty()``.

        Side effects:
            Fits the module-level ``model`` for one epoch on the updated
            Q-vector, increments the global ``server_fits``, and increments
            the global ``episodes_processed`` when ``request.done`` is set.
        """
        global server_fits
        global episodes_processed
        if server_fits == 0:
            print("Client found, fitting model")
        elif server_fits % 100 == 0 and not verbose:  # sanity check every 100 fits
            print(f"{server_fits}th fit")

        # NOTE(security): pickle.loads can execute arbitrary code embedded in
        # the payload; this service must only be reachable by trusted clients.
        state = pickle.loads(request.state)
        next_state = pickle.loads(request.next_state)

        # A terminal transition has no successor state to bootstrap from, so
        # its target is the immediate reward alone (this also skips a wasted
        # forward pass on the terminal observation).
        target = request.reward
        if not request.done:
            target += gamma * np.max(model.predict(next_state))

        # Q-values for every action; only the taken action's entry is moved
        # toward the target. NOTE(review): if the client omits `action`, an
        # unset integer proto field presumably reads as 0 — the client must send it.
        target_vec = model.predict(state)[0]
        target_vec[request.action] = target

        model.fit(state,
                  target_vec.reshape(-1, action_count),
                  epochs=1,
                  verbose=0)
        server_fits += 1

        if request.done:
            print("Done fitting model for current episode")
            episodes_processed += 1
        return mpm_pb2.Empty()
コード例 #2
0
ファイル: ms_pacman.py プロジェクト: kekag/MsPacman-RL
 def SaveModel(self, request, context):
     """Save the server's model, and optionally its reward plot.

     Args:
         request: A ``SaveRequest``; when ``model_only`` is false a reward
             plot is also written, provided at least three episodes have
             been processed.
         context: RPC context supplied by the server framework (unused).

     Returns:
         ``mpm_pb2.Empty()``.
     """
     print(f"Saving model to {model_file}")
     model.save(model_file)

     # NOTE(review): nothing visible in this file appends to `rewards` on the
     # server side (the client keeps its own list) — confirm the plot has data.
     plot_wanted = not request.model_only and episodes_processed >= 3
     if plot_wanted:
         print(f"Saving reward plot to {figure_file}")
         plot_reward(rewards, figure_file)

     return mpm_pb2.Empty()
コード例 #3
0
ファイル: ms_pacman.py プロジェクト: kekag/MsPacman-RL
 def DropClient(self, request, context):
     """Acknowledge that a client has stopped training.

     Only logs the event; no server state is modified.

     Returns:
         ``mpm_pb2.Empty()``.
     """
     notice = "Client terminated training"
     print(notice)
     return mpm_pb2.Empty()
コード例 #4
0
ファイル: ms_pacman.py プロジェクト: kekag/MsPacman-RL
def _to_model_input(observation):
    """Turn a raw observation into a single-sample model input.

    Adds a leading batch axis and a trailing channel axis, so an observation
    of shape ``S`` becomes shape ``(1, *S, 1)``.
    """
    array = np.asarray(observation)
    return array.reshape((1, ) + array.shape + (1, ))


def _td_target(reward, next_state, done):
    """Return the one-step Q-learning target for a transition.

    Terminal transitions have no successor to bootstrap from, so their target
    is the immediate reward alone; otherwise the discounted best predicted
    Q-value of ``next_state`` (from the module-level ``model``) is added.
    """
    if done:
        return reward
    return reward + gamma * np.max(model.predict(next_state))


def main():
    """Run ``n_episodes`` epsilon-greedy Q-learning episodes on ``env``.

    In ``Run.local`` mode the module-level ``model`` is queried and fitted in
    this process, saved after every episode and at the end, and a reward plot
    is written when at least three episodes were run.

    In ``Run.client`` mode greedy actions are requested from the server via
    ``PredictAction`` and each transition is sent to ``ProcessModel``; the
    server is asked to save after every episode (model only) and at the end
    (model and plot), after which ``DropClient`` is sent.

    After each episode the global ``epsilon`` is multiplied by ``decay`` while
    it is above ``min_epsilon``, ``episodes_processed`` is incremented, and
    the episode's total reward is appended to ``rewards``.
    """
    global env
    global model
    global epsilon
    global episodes_processed
    global rewards

    if verbose:
        print("\nAction Space:     ", env.action_space)
        print("Action Meanings:  \n", env.get_action_meanings())

    for i in range(n_episodes):
        print("Episode:", i)
        if verbose:
            print()

        state = _to_model_input(env.reset())

        done = False
        total_reward = 0
        always_noop = False

        while not done:
            if render:
                env.render()

            # get_action returns -1 when the model's greedy action should be used.
            action, action_type, always_noop = get_action(
                always_noop, epsilon, env.action_space.sample())

            if run_as == Run.client:
                if action == -1:
                    action_response = stub.PredictAction(
                        mpm_pb2.StateRequest(state=pickle.dumps(state)))
                    action = action_response.action

                next_state, current_reward, done, info = env.step(action)
                next_state = _to_model_input(next_state)

                # The server updates the Q-value of `request.action`, so the
                # action actually taken must accompany the transition.
                stub.ProcessModel(
                    mpm_pb2.ModelRequest(state=pickle.dumps(state),
                                         next_state=pickle.dumps(next_state),
                                         action=int(action),
                                         reward=current_reward,
                                         done=done))

            elif run_as == Run.local:
                # Q-values for all actions in the current state, computed once
                # and reused for greedy selection and as the fitting target.
                target_vec = model.predict(state)[0]
                if action == -1:
                    action = np.argmax(target_vec)

                next_state, current_reward, done, info = env.step(action)
                next_state = _to_model_input(next_state)

                # Move only the taken action's value toward its TD target.
                target_vec[action] = _td_target(current_reward, next_state,
                                                done)

                model.fit(state,
                          target_vec.reshape(-1, action_count),
                          epochs=1,
                          verbose=0)

            total_reward += current_reward

            if verbose:
                print("EP %i. ACTION: %9s%7s | REWARD: %4i | LIVES: %d" %
                      (episodes_processed, env.get_action_meanings()[action],
                       action_type, current_reward, info.get('ale.lives')))

            state = next_state
            if done:
                rewards.append(total_reward)
                if verbose:
                    print()
                print(f"Reward: {total_reward}\n")

        if run_as == Run.local:
            model.save(model_file)
        else:
            stub.SaveModel(mpm_pb2.SaveRequest(model_only=True))

        if render:
            env.render()

        episodes_processed += 1
        if epsilon > min_epsilon:
            epsilon *= decay
            print("Decayed epsilon to", epsilon)

    env.close()
    if run_as == Run.local:
        model.save(model_file)
        if n_episodes >= 3:
            plot_reward(rewards, figure_file)
    elif run_as == Run.client:
        stub.SaveModel(mpm_pb2.SaveRequest(model_only=False))
        stub.DropClient(mpm_pb2.Empty())
コード例 #5
0
ファイル: ms_pacman.py プロジェクト: kekag/MsPacman-RL
        if n_episodes >= 3:
            plot_reward(rewards, figure_file)
    elif run_as == Run.client:
        stub.SaveModel(mpm_pb2.SaveRequest(model_only=False))
        stub.DropClient(mpm_pb2.Empty())


# Script entry: run training and, on Ctrl+C, optionally save progress before
# exiting. In client mode the server is always told the client is leaving.
try:
    main()
except KeyboardInterrupt:
    print("\nKEYBOARD INTERRUPT")
    try:
        save_requested = False
        if episodes_processed > 0:
            try:
                answer = input("Save model data? [y/n] ")
            except (EOFError, KeyboardInterrupt):
                # A second Ctrl+C or a closed stdin at the prompt means "don't
                # save" instead of escaping this handler with a traceback.
                answer = ""
            save_requested = answer.strip().lower() == 'y'

        if run_as == Run.local:
            env.close()
            if save_requested:
                print(f"Saving model to {model_file}")
                model.save(model_file)
                if episodes_processed >= 3:
                    print(f"Saving reward plot to {figure_file}")
                    plot_reward(rewards, figure_file)
        elif run_as == Run.client:
            env.close()
            if save_requested:
                stub.SaveModel(mpm_pb2.SaveRequest(model_only=False))
            # Notify the server even when nothing is saved.
            stub.DropClient(mpm_pb2.Empty())
        sys.exit(0)
    except SystemExit:
        # Hard exit skipping interpreter cleanup; presumably so background
        # threads (e.g. the RPC channel) cannot keep the process alive.
        os._exit(0)