Пример #1
0
def get_gridworld():
    """Build and return the 'Sample Grid World' environment.

    The layout rows, tick labels, axis labels, action dictionary, reward
    dictionary and the ``get_next_state`` helper are all taken from names
    defined outside this function.

    Returns:
        The populated EnvBaseline object, with its start state, limited
        start-state list and default policy set.
    """
    gridworld = EnvBaseline(name='Sample Grid World',
                            s_hash_rowL=s_hash_rowL,
                            row_tickL=row_tickL,
                            x_axis_label=x_axis_label,
                            col_tickL=col_tickL,
                            y_axis_label=y_axis_label,
                            colorD={
                                'Goal': 'g',
                                'Pit': 'r',
                                'Start': 'b'
                            },
                            basic_color='skyblue')

    gridworld.set_info('Sample Grid World showing basic MDP creation.')

    # add actions from each state
    #   (note: a_prob will be normalized within add_action_dict)
    gridworld.add_action_dict(actionD)

    # for each action, define the next state and transition probability
    # (here we use the layout definition to aid the logic)
    for s_hash, aL in actionD.items():
        for a_desc in aL:
            sn_hash = get_next_state(s_hash, a_desc)
            # transitions into states without a rewardD entry earn 0.0
            reward = rewardD.get(sn_hash, 0.0)

            # for deterministic MDP, use t_prob=1.0
            gridworld.add_transition(s_hash,
                                     a_desc,
                                     sn_hash,
                                     t_prob=1.0,
                                     reward_obj=reward)

    # after the "add" commands, send all states and actions to environment
    # (any required normalization is done here as well.)
    gridworld.define_env_states_actions()

    # If there is a start state, define it here.
    gridworld.start_state_hash = 'Start'

    # If a limited number of start states are desired, define them here.
    # The start cell is identified by the hash 'Start' (see start_state_hash
    # above and default_policyD below), so that same hash is used here
    # rather than a (2, 0) tuple, which is not used as a state anywhere else.
    gridworld.define_limited_start_state_list(['Start', (2, 2)])

    # if a default policy is desired, define it as a dict.
    # (index=state hash, value=action description)
    gridworld.default_policyD = {
        (0, 0): 'R',
        (1, 0): 'U',
        (0, 1): 'R',
        (0, 2): 'R',
        (1, 2): 'U',
        'Start': 'U',
        (2, 2): 'U',
        (2, 1): 'R',
        (2, 3): 'L'
    }

    return gridworld
Пример #2
0
def get_gridworld(step_reward=0.0):
    """Build and return the 'Simple Grid World' environment.

    Args:
        step_reward: reward used for any transition whose next state has no
            entry in the local reward table (default 0.0).

    Returns:
        The populated EnvBaseline object with a GenericLayout, a start
        state, a limited start-state list and a default policy attached.
    """
    env = EnvBaseline(name='Simple Grid World')  # GenericLayout set below
    env.set_info('Simple Grid World Example.')

    # offset added to the two components of a state hash by each action
    offsetD = {
        'U': (-1, 0),
        'D': (1, 0),
        'R': (0, 1),
        'L': (0, -1),
    }

    # actions available from each state hash
    actionD = {
        (0, 0): ('D', 'R'),
        (0, 1): ('L', 'R'),
        (0, 2): ('L', 'D', 'R'),
        (1, 0): ('U', 'D'),
        (1, 2): ('U', 'D', 'R'),
        (2, 0): ('U', 'R'),
        (2, 1): ('L', 'R'),
        (2, 2): ('L', 'R', 'U'),
        (2, 3): ('L', 'U')
    }

    # reward keyed by next state; all other transitions get step_reward
    rewardD = {(0, 3): 1, (1, 3): -1}

    for s_hash, a_descL in actionD.items():
        i, j = s_hash
        for a_desc in a_descL:
            # a_prob will be normalized
            env.add_action(s_hash, a_desc, a_prob=1.0)

            di, dj = offsetD[a_desc]
            sn_hash = (i + di, j + dj)

            # deterministic move: the action always leads to sn_hash
            env.add_transition(s_hash,
                               a_desc,
                               sn_hash,
                               t_prob=1.0,
                               reward_obj=rewardD.get(sn_hash, step_reward))

    # send all states and actions to environment
    env.define_env_states_actions()

    # uses default "get_layout_row_col_of_state"
    env.layout = GenericLayout(env)

    # If there is a start state, define it here.
    env.start_state_hash = (2, 0)
    env.define_limited_start_state_list([(2, 0), (2, 2)])

    # define default policy (if any)
    # Policy Dictionary for: GridWorld
    # index=state_hash, value=action_desc (Vpi shown for gamma=0.9)
    env.default_policyD = {
        (0, 0): 'R',  # Vpi=0.81
        (1, 0): 'U',  # Vpi=0.729
        (0, 1): 'R',  # Vpi=0.9
        (0, 2): 'R',  # Vpi=1.0
        (1, 2): 'U',  # Vpi=0.9
        (2, 0): 'U',  # Vpi=0.6561
        (2, 2): 'U',  # Vpi=0.81
        (2, 1): 'R',  # Vpi=0.729
        (2, 3): 'L',  # Vpi=0.729
    }

    return env