Example #1
File: main.py  Project: segaljared/mdp
def run_q_learning_forest(S, r1, r2):
    forest = WrappedForest(S, r1, r2)
    n_episodes = 10000
    how_often = n_episodes // 100  # report stats every 1/100th of the run

    stats = IterationStats('stats/ql_forest.csv', dims=2)

    def on_episode(episode, time, q_learner, q):
        forest.print_policy(print, q_learner.get_policy())
        stats.save_iteration(episode, time,
                             numpy.nanmean(numpy.nanmax(q, axis=0)), q)

    def is_done(state, action, next_state):
        # State 0 is the youngest forest state, reached after a cut or a
        # fire, so returning to it ends the episode.
        return next_state.state_num == 0

    gamma = 0.99
    start = time.time()
    numpy.random.seed(5263228)
    q_l = QLearning(forest,
                    0.5,
                    0.2,
                    gamma,
                    on_episode=on_episode,
                    start_at_0=True,
                    alpha=0.1,
                    is_done=is_done,
                    every_n_episode=how_often)
    stats.start_writing()
    q_l.run(n_episodes)
    stats.done_writing()
    forest.print_policy(print, q_l.get_policy())
    print('took {} s'.format(time.time() - start))

    # Re-open the saved per-iteration stats and plot the average Q curve.
    stats = IterationStats('stats/ql_forest.csv', dims=2)
    analysis.create_iteration_value_graph(
        stats, 'average Q',
        'Average Q for each iteration on Forest Q Learning', 'forest_results')
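
The statistic saved each reporting interval, numpy.nanmean(numpy.nanmax(q, axis=0)), is the average over all states of the best available action value. The snippet below is a standalone illustration on a made-up Q-table, assuming q has shape (actions, states) as the axis=0 max implies; it is not part of the project.

import numpy

# Toy Q-table with shape (actions, states); NaN marks an unavailable
# state/action pair.
q = numpy.array([[0.5, numpy.nan, 0.2],
                 [0.1, 0.9,       0.4]])

best_per_state = numpy.nanmax(q, axis=0)  # [0.5, 0.9, 0.4]
print(numpy.nanmean(best_per_state))      # 0.6 -- the number logged each interval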
Example #2
File: main.py  Project: segaljared/mdp
def run_q_learning_grid_world():
    world = GridWorld('simple_grid.txt', -0.01, include_treasure=False)
    n_episodes = 500000
    how_often = n_episodes // 500  # report stats every 1/500th of the run

    stats = IterationStats('stats/ql_simple_grid.csv', dims=5)

    def on_update(state, action, next_state, q_learner):
        #print('[{},{}] - {} -> [{},{}]'.format(state.x, state.y, action[0], next_state.x, next_state.y))
        pass

    def on_episode(episode, time, q_learner, q):
        world.print_policy(print, q_learner.get_policy())
        stats.save_iteration(episode, time,
                             numpy.nanmean(numpy.nanmax(q, axis=0)), q)
        #time.sleep(1)

    # Locate the goal tile so the initializer below can bias actions toward it.
    for state in world.get_states():
        if state.tile_type == GridWorldTile.GOAL:
            goal_state = state
            break

    def initialize_toward_goal(state: GridWorldTile):
        """Initial Q-values: slightly favor the action that points toward the goal."""
        actions = state.get_actions()
        if len(actions) == 0:
            return []
        diff_x = goal_state.x - state.x
        diff_y = goal_state.y - state.y
        best_value = 0.1
        if len(actions) == 5 and actions[4][0].startswith('get treasure'):
            # Picking up treasure, when present, takes priority over moving.
            best_action = actions[4][0]
        elif abs(diff_x) >= abs(diff_y):
            best_action = 'move east' if diff_x > 0 else 'move west'
        else:
            # diff_y < 0 maps to 'move north', i.e. y grows toward the south.
            best_action = 'move north' if diff_y < 0 else 'move south'
        values = [-0.1] * len(actions)
        for i, action in enumerate(actions):
            if action[0] == best_action:
                values[i] = best_value
        return values

    gamma = 0.99
    q_l = QLearning(world,
                    0.5,
                    0.05,
                    gamma,
                    on_update=on_update,
                    on_episode=on_episode,
                    initializer=initialize_toward_goal,
                    start_at_0=True,
                    alpha=0.1,
                    every_n_episode=how_often)
    stats.start_writing()
    q_l.run(n_episodes)
    stats.done_writing()
    world.print_policy(print, q_l.get_policy())
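
The initialize_toward_goal helper seeds the Q-table so that, in every tile, the move pointing toward the goal starts slightly above the others (0.1 versus -0.1). The fragment below is a minimal standalone sketch of just that direction rule; toward_goal_action and the sample offsets are illustrative, not part of the project.

def toward_goal_action(dx, dy):
    # dx, dy are the goal's coordinates minus the tile's; the larger offset
    # picks the axis, and dy < 0 means the goal lies to the north
    # (y increases toward the south in this grid).
    if abs(dx) >= abs(dy):
        return 'move east' if dx > 0 else 'move west'
    return 'move north' if dy < 0 else 'move south'

print(toward_goal_action(3, -1))  # move east
print(toward_goal_action(0, 2))   # move south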