Пример #1
0
class MonotoneLowerBound(Planner):
    '''
    Planner that solves a slip-free copy of an MDP with value iteration and
    caches the resulting per-state values in ``self.lower_values``.

    NOTE(review): the default ``name`` is 'MonotoneUpperBound' although the
    class is MonotoneLowerBound; left unchanged because callers may rely on it.
    '''

    def __init__(self, mdp, name='MonotoneUpperBound'):
        '''
        Args:
            mdp (MDP): Source MDP. It is deep-copied; the copy is modified.
            name (str): Planner name forwarded to Planner.__init__.
        '''
        deterministic_mdp = MonotoneLowerBound._construct_deterministic_relaxation_mdp(mdp)
        Planner.__init__(self, deterministic_mdp, name)

        solver = ValueIteration(deterministic_mdp)
        self.vi = solver
        self.states = solver.get_states()
        solver._compute_matrix_from_trans_func()
        solver.run_vi()
        self.lower_values = self._construct_lower_values()

    @staticmethod
    def _construct_deterministic_relaxation_mdp(mdp):
        '''Return a deep copy of ``mdp`` whose slip probability is set to 0.0.'''
        copied = copy.deepcopy(mdp)
        copied.set_slip_prob(0.0)
        return copied

    def _construct_lower_values(self):
        '''
        Map every state in ``self.states`` to its value from ``self.vi``.

        Returns:
            (defaultdict): Built without a default factory, so a missing key
            raises KeyError exactly like a plain dict.
        '''
        return defaultdict(None, {s: self.vi.get_value(s) for s in self.states})
class MonotoneLowerBound(Planner):
    '''
    Lower-bound planner built from a deterministic relaxation of an MDP.

    The relaxation is a deep copy of the input MDP with its slip probability
    set to 0.0. Value iteration is run on it, and the value of each state it
    enumerates is stored in ``self.lower_values``.

    NOTE(review): the default ``name`` reads 'MonotoneUpperBound'; kept as-is
    for compatibility with existing callers.
    '''

    def __init__(self, mdp, name='MonotoneUpperBound'):
        relaxed = MonotoneLowerBound._construct_deterministic_relaxation_mdp(
            mdp)
        Planner.__init__(self, relaxed, name)

        # Solve the relaxed MDP immediately; values are read back below.
        self.vi = ValueIteration(relaxed)
        self.states = self.vi.get_states()
        self.vi._compute_matrix_from_trans_func()
        self.vi.run_vi()
        self.lower_values = self._construct_lower_values()

    @staticmethod
    def _construct_deterministic_relaxation_mdp(mdp):
        '''Deep-copy ``mdp`` and set the copy's slip probability to 0.0.'''
        result = copy.deepcopy(mdp)
        result.set_slip_prob(0.0)
        return result

    def _construct_lower_values(self):
        '''Return a factory-less defaultdict of state -> value from ``self.vi``.'''
        lookup = defaultdict()
        solver = self.vi
        for s in self.states:
            lookup[s] = solver.get_value(s)
        return lookup
Пример #3
0
def _draw_state(screen,
                grid_mdp,
                state,
                policy=None,
                action_char_dict=None,
                show_value=False,
                agent=None,
                draw_statics=False,
                agent_shape=None):
    '''
    Render a grid-world MDP onto a pygame surface and flip the display.

    Args:
        screen (pygame.Surface)
        grid_mdp (MDP)
        state (State): Current state; its ``x``/``y`` locate the agent.
        policy (callable): Optional state -> action. The action (or its
            ``action_char_dict`` entry) is drawn as text in non-wall cells.
        action_char_dict (dict): Optional action -> display-string mapping;
            unmapped actions are drawn as-is. Defaults to an empty mapping.
        show_value (bool): Fill non-wall cells with a value colour.
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool): If True, draw every cell (outline, value fill,
            policy text, walls, goals, lava and, when not showing values,
            the agent).
        agent_shape (pygame.rect): Rect of a previously drawn agent. When set
            (passed in or drawn during the static pass), it is erased and the
            agent is redrawn at ``state``.

    Returns:
        (pygame.Shape): The agent's drawn shape, or None if it was never drawn.
    '''
    if action_char_dict is None:
        action_char_dict = {}

    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(grid_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict (ValueIteration is used here to enumerate the states).
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(grid_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / grid_mdp.width
    cell_height = (scr_height - height_buffer * 2) / grid_mdp.height
    goal_locs = grid_mdp.get_goal_locs()
    lava_locs = grid_mdp.get_lava_locs()
    font_size = int(min(cell_width, cell_height) / 4.0)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    # Draw the static entities.
    if draw_statics:
        for i in range(grid_mdp.width):  # screen column
            for j in range(grid_mdp.height):  # screen row, top to bottom
                # 1-indexed grid coordinates: grid y grows upward while the
                # screen row index j grows downward.
                gx, gy = i + 1, grid_mdp.height - j
                is_wall = grid_mdp.is_wall(gx, gy)

                top_left_point = (width_buffer + cell_width * i,
                                  height_buffer + cell_height * j)
                pygame.draw.rect(screen, (46, 49, 49),
                                 top_left_point + (cell_width, cell_height), 3)

                if show_value and not is_wall:
                    # Value fill goes first: it is a filled rect and would
                    # otherwise paint over the policy glyph drawn next.
                    color = mdpv.val_to_color(val_text_dict[gx][gy])
                    pygame.draw.rect(screen, color,
                                     top_left_point + (cell_width, cell_height),
                                     0)

                if policy and not is_wall:
                    a = policy_dict[gx][gy]
                    text_a = action_char_dict.get(a, a)
                    text_center_point = (
                        int(top_left_point[0] + cell_width / 2.0 - 10),
                        int(top_left_point[1] + cell_height / 3.0))
                    screen.blit(cc_font.render(text_a, True, (46, 49, 49)),
                                text_center_point)

                if is_wall:
                    # Walls are inset 5px inside the cell outline.
                    top_left_point = (width_buffer + cell_width * i + 5,
                                      height_buffer + cell_height * j + 5)
                    pygame.draw.rect(
                        screen, (94, 99, 99),
                        top_left_point + (cell_width - 10, cell_height - 10),
                        0)

                cell_center = (int(top_left_point[0] + cell_width / 2.0),
                               int(top_left_point[1] + cell_height / 2.0))

                if (gx, gy) in goal_locs:
                    # Draw goal.
                    pygame.draw.circle(screen, (154, 195, 157), cell_center,
                                       int(min(cell_width, cell_height) / 3.0))

                if (gx, gy) in lava_locs:
                    # Draw lava.
                    pygame.draw.circle(screen, (224, 145, 157), cell_center,
                                       int(min(cell_width, cell_height) / 4.0))

                # Current state: drawn at most once, and not when showing values.
                if not show_value and (gx, gy) == (
                        state.x, state.y) and agent_shape is None:
                    agent_shape = _draw_agent(
                        cell_center,
                        screen,
                        base_size=min(cell_width, cell_height) / 2.5 - 8)

    if agent_shape is not None:
        # Erase the previously drawn agent, then redraw it at ``state``.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)
        top_left_point = (width_buffer + cell_width * (state.x - 1),
                          height_buffer + cell_height * (grid_mdp.height - state.y))
        tri_center = (int(top_left_point[0] + cell_width / 2.0),
                      int(top_left_point[1] + cell_height / 2.0))
        agent_shape = _draw_agent(
            tri_center,
            screen,
            base_size=min(cell_width, cell_height) / 2.5 - 16)

    pygame.display.flip()

    return agent_shape
def _draw_state(screen,
                two_goal_oomdp,
                state,
                policy=None,
                action_char_dict=None,
                show_value=False,
                agent=None,
                draw_statics=True,
                agent_shape=None):
    '''
    Render a two-goal OO-MDP grid onto a pygame surface and flip the display.

    Args:
        screen (pygame.Surface)
        two_goal_oomdp (TwoGoalOOMDP)
        state (State): Current OO state; the agent position is read from
            ``state.get_objects()["agent"][0]``.
        policy (callable): Optional state -> action, drawn as text per cell.
        action_char_dict (dict): Optional action -> display-string mapping;
            unmapped actions are drawn as-is. Defaults to an empty mapping.
        show_value (bool): Fill each non-wall cell with a value colour and
            print the value rounded to two decimals.
        agent (Agent): Source of value estimates; if None, Value Iteration
            is run instead.
        draw_statics (bool): Draw walls, grid lines, values and policy text.
        agent_shape (pygame.rect): Previous agent rect, erased before redraw.

    Returns:
        (pygame.Shape): The newly drawn agent shape.
    '''
    if action_char_dict is None:
        action_char_dict = {}

    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            if agent.name == 'Q-learning':
                # Use agent value estimates for the states it has seen.
                value_states = agent.q_func.keys()
            else:
                # Slightly abusing the distinction between agents and
                # planning modules: query every MDP state.
                value_states = two_goal_oomdp.get_states()
            for s in value_states:
                val_text_dict[s.get_agent_x()][s.get_agent_y()] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(two_goal_oomdp, sample_rate=10)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.get_agent_x()][s.get_agent_y()] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        for s in two_goal_oomdp.get_states():
            policy_dict[s.get_agent_x()][s.get_agent_y()] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / two_goal_oomdp.width
    cell_height = (scr_height - height_buffer * 2) / two_goal_oomdp.height
    objects = state.get_objects()
    agent_x, agent_y = objects["agent"][0]["x"], objects["agent"][0]["y"]
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)
    dark = (46, 49, 49)

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)

    # Statics
    if draw_statics:
        # Draw walls, inset 5px inside their cell.
        for w in two_goal_oomdp.walls:
            w_x, w_y = w["x"], w["y"]
            top_left_point = (
                width_buffer + cell_width * (w_x - 1) + 5,
                height_buffer + cell_height * (two_goal_oomdp.height - w_y) + 5)
            pygame.draw.rect(screen, dark,
                             top_left_point + (cell_width - 10, cell_height - 10), 0)

    # Draw each goal as a ring of small dots, in a darkened colour_ls entry.
    # Indices cycle so more than two goals do not raise IndexError.
    col_idxs = [-2, 0]
    radius = 45
    dot_count = 150
    for goal_idx, g in enumerate(two_goal_oomdp.goals):
        dest_x, dest_y = g["x"], g["y"]
        # NOTE(review): the ring is centred on this offset point itself, not
        # on the cell centre; the +75/+65 pixel offsets are hard-coded and do
        # not scale with cell size.
        center = (
            int(width_buffer + cell_width * (dest_x - 1) + 75),
            int(height_buffer + cell_height * (two_goal_oomdp.height - dest_y) + 65))
        base_col = color_ls[col_idxs[goal_idx % len(col_idxs)]]
        dest_col = tuple(int(max(base_col[c] - 30, 0)) for c in range(3))
        for k in range(dot_count):
            ang = k * 2 * math.pi / dot_count
            x = center[0] + int(math.cos(ang) * radius)
            y = center[1] + int(math.sin(ang) * radius)
            pygame.draw.circle(screen, dest_col, (x, y), 5)

    # Draw new agent.
    top_left_point = (width_buffer + cell_width * (agent_x - 1),
                      height_buffer + cell_height * (two_goal_oomdp.height - agent_y))
    agent_center = (int(top_left_point[0] + cell_width / 2.0),
                    int(top_left_point[1] + cell_height / 2.0))
    agent_shape = _draw_agent(agent_center, screen,
                              base_size=min(cell_width, cell_height) / 2.5 - 4)

    if draw_statics:
        # NOTE(review): the value fill below is drawn after the goals and the
        # agent above, so with show_value=True it paints over any of them in
        # a non-wall cell. Confirm whether that is intended.
        for i in range(two_goal_oomdp.width):  # screen column
            for j in range(two_goal_oomdp.height):  # screen row, top to bottom
                gx, gy = i + 1, two_goal_oomdp.height - j
                top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
                pygame.draw.rect(screen, dark, top_left_point + (cell_width, cell_height), 3)
                text_center_point = (int(top_left_point[0] + cell_width / 2.0 - 10),
                                     int(top_left_point[1] + cell_height / 3.0))

                # Show value of states.
                if show_value and not two_goal_helpers.is_wall(two_goal_oomdp, gx, gy):
                    val = val_text_dict[gx][gy]
                    pygame.draw.rect(screen, mdpv.val_to_color(val),
                                     top_left_point + (cell_width, cell_height), 0)
                    value_text = reg_font.render(str(round(val, 2)), True, dark)
                    screen.blit(value_text, text_center_point)

                # Show the policy's action in each grid cell.
                if policy and not two_goal_helpers.is_wall(two_goal_oomdp, gx, gy):
                    a = policy_dict[gx][gy]
                    text_a = action_char_dict.get(a, a)
                    screen.blit(cc_font.render(text_a, True, dark), text_center_point)

    pygame.display.flip()

    return agent_shape
Пример #5
0
def _draw_state(screen,
                pudd_mdp,
                state,
                policy=None,
                action_char_dict=None,
                show_value=False,
                agent=None,
                draw_statics=False,
                agent_shape=None):
    '''
    Render a puddle-world MDP onto a pygame surface and flip the display.

    Cells are sized for a fixed 10x10 layout. NOTE(review): the loops visit
    11 coordinates per axis (0.0, 0.1, ..., 1.0), so an 11th row and column
    are drawn into the right/bottom margin. Confirm this is intended.

    Args:
        screen (pygame.Surface)
        pudd_mdp (MDP)
        state (State): Current state; ``x``/``y`` are compared with cell
            coordinates after rounding to one decimal.
        policy (callable): Optional state -> action. NOTE(review): it is
            evaluated into ``policy_dict`` but nothing is drawn from it.
        action_char_dict (dict): Accepted for signature parity with the
            other renderers; not used here.
        show_value (bool): Suppresses drawing the agent. NOTE(review): values
            are computed into ``val_text_dict`` but not drawn.
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool): If True, draw cells, puddles, goals and the agent.
        agent_shape (pygame.rect): If not None, the agent is not drawn again.

    Returns:
        (pygame.Shape): The incoming ``agent_shape``, or the newly drawn one.
    '''
    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(pudd_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(pudd_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    # Fixed 10x10 layout rather than pudd_mdp.width / pudd_mdp.height.
    cell_width = (scr_width - width_buffer * 2) / 10
    cell_height = (scr_height - height_buffer * 2) / 10
    goal_locs = pudd_mdp.get_goal_locs()
    delta = 0.2
    coords = np.linspace(0, 1, 11)  # 0.0, 0.1, ..., 1.0

    # Draw the static entities.
    if draw_statics:
        for i in coords:  # x: screen column
            for j in coords:  # y: screen row, top to bottom
                top_left_point = width_buffer + cell_width * i * 10, height_buffer + cell_height * j * 10
                pygame.draw.rect(screen, (46, 49, 49),
                                 top_left_point + (cell_width, cell_height), 3)

                if pudd_mdp.is_puddle_location(i, j):
                    # Draw the puddle, inset 5px inside the cell outline.
                    top_left_point = width_buffer + cell_width * (
                        i * 10) + 5, height_buffer + cell_height * j * 10 + 5
                    pygame.draw.rect(
                        screen, (0, 127, 255),
                        top_left_point + (cell_width - 10, cell_height - 10),
                        0)

                # NOTE(review): exact float match against goal_locs; entries
                # are presumably [x, y] lists offset by ``delta`` and clamped
                # to 1.0. Confirm against get_goal_locs().
                if [min(i + delta, 1.0), min(j + delta, 1.0)] in goal_locs:
                    # Draw goal.
                    circle_center = (int(top_left_point[0] + cell_width / 2.0),
                                     int(top_left_point[1] + cell_height / 2.0))
                    pygame.draw.circle(screen, (154, 195, 157), circle_center,
                                       int(min(cell_width, cell_height) / 3.0))

                # Current state: drawn at most once, and not when showing values.
                if not show_value and (round(i, 1), round(j, 1)) == (round(
                        state.x, 1), round(state.y,
                                           1)) and agent_shape is None:
                    tri_center = (int(top_left_point[0] + cell_width / 2.0),
                                  int(top_left_point[1] + cell_height / 2.0))
                    agent_shape = _draw_agent(
                        tri_center,
                        screen,
                        base_size=min(cell_width, cell_height) / 2.5 - 8)

    # NOTE: an incoming ``agent_shape`` is never erased or redrawn here; the
    # agent is only drawn during the static pass above.
    pygame.display.flip()

    return agent_shape
Пример #6
0
def _draw_state(screen,
                grid_mdp,
                state,
                policy=None,
                action_char_dict=None,
                show_value=False,
                agent=None,
                draw_statics=False,
                agent_shape=None,
                options=None):
    '''
    Render a grid-world MDP with numbered option markers and flip the display.

    Args:
        screen (pygame.Surface)
        grid_mdp (MDP)
        state (State): Current state; its ``x``/``y`` locate the agent.
        policy (callable): Optional state -> action drawn as text per cell.
        action_char_dict (dict): Optional action -> display-string mapping;
            unmapped actions are drawn as-is. Defaults to an empty mapping.
        show_value (bool): Fill non-wall cells with a value colour.
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool): If True, draw cells, values, policy, walls,
            lava and option markers.
        agent_shape (pygame.rect): Rect of a previously drawn agent; when set
            it is erased and the agent is redrawn at ``state``.
        options (list): Cells to mark. Entry ``k`` is labelled ``k // 2 + 1``,
            so consecutive pairs share a label (presumably the two cells of
            one option; TODO confirm). Defaults to an empty list.

    Returns:
        (pygame.Shape): The agent's drawn shape, or the incoming value.
    '''
    if action_char_dict is None:
        action_char_dict = {}
    if options is None:
        options = []

    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(grid_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict (ValueIteration is used here to enumerate the states).
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(grid_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / grid_mdp.width
    cell_height = (scr_height - height_buffer * 2) / grid_mdp.height
    goal_locs = grid_mdp.get_goal_locs()
    lava_locs = grid_mdp.get_lava_locs()
    font_size = int(min(cell_width, cell_height) / 4.0)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)
    option_font = pygame.font.SysFont(None, 24)
    option_color = (200, 200, 0)
    marker_radius = int(min(cell_width, cell_height) / 4.0)

    # Draw the static entities.
    if draw_statics:
        for i in range(grid_mdp.width):  # screen column
            for j in range(grid_mdp.height):  # screen row, top to bottom
                # 1-indexed grid coordinates (grid y grows upward).
                gx, gy = i + 1, grid_mdp.height - j
                is_wall = grid_mdp.is_wall(gx, gy)

                top_left_point = (width_buffer + cell_width * i,
                                  height_buffer + cell_height * j)
                pygame.draw.rect(screen, (46, 49, 49),
                                 top_left_point + (cell_width, cell_height), 3)

                if show_value and not is_wall:
                    # Value fill goes first: it is a filled rect and would
                    # otherwise paint over the policy glyph drawn next.
                    color = mdpv.val_to_color(val_text_dict[gx][gy])
                    pygame.draw.rect(screen, color,
                                     top_left_point + (cell_width, cell_height),
                                     0)

                if policy and not is_wall:
                    a = policy_dict[gx][gy]
                    text_a = action_char_dict.get(a, a)
                    text_center_point = (
                        int(top_left_point[0] + cell_width / 2.0 - 10),
                        int(top_left_point[1] + cell_height / 3.0))
                    screen.blit(cc_font.render(text_a, True, (46, 49, 49)),
                                text_center_point)

                if is_wall:
                    # Walls are inset 5px inside the cell outline.
                    top_left_point = (width_buffer + cell_width * i + 5,
                                      height_buffer + cell_height * j + 5)
                    pygame.draw.rect(
                        screen, (94, 99, 99),
                        top_left_point + (cell_width - 10, cell_height - 10),
                        0)

                if (gx, gy) in goal_locs:
                    # TODO: goal cells are currently not drawn.
                    pass

                if (gx, gy) in lava_locs:
                    # Draw lava.
                    circle_center = (int(top_left_point[0] + cell_width / 2.0),
                                     int(top_left_point[1] + cell_height / 2.0))
                    pygame.draw.circle(screen, (224, 145, 157), circle_center,
                                       int(min(cell_width, cell_height) / 4.0))

                # Option markers: one labelled disc per occurrence of this cell.
                # NOTE(review): options are matched on (i + 1, j + 1), i.e.
                # screen-row order, whereas walls/goals/lava use
                # (i + 1, grid_mdp.height - j). Confirm which is intended.
                option_cell = (gx, j + 1)
                if option_cell in options:
                    marker_y = int(top_left_point[1] + cell_height / 2.0)
                    markers = []
                    for index, loc in enumerate(options):
                        if loc != option_cell:
                            continue
                        label = index // 2 + 1
                        marker_x = int(top_left_point[0] + cell_width / 2.0) + int(
                            cell_width / 6.0 * (label + 1 - len(options) / 2))
                        markers.append((label, marker_x))

                    # All discs first, then all labels, so an overlapping disc
                    # never covers a label.
                    for _, marker_x in markers:
                        pygame.draw.circle(screen, option_color,
                                           (marker_x, marker_y), marker_radius)
                    for label, marker_x in markers:
                        text = option_font.render(str(label), True, (0, 0, 0),
                                                  option_color)
                        textrect = text.get_rect()
                        textrect.centerx = marker_x
                        textrect.centery = marker_y
                        screen.blit(text, textrect)

                # The agent is not drawn during the static pass in this renderer.

    if agent_shape is not None:
        # Erase the previously drawn agent, then redraw it at ``state``.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)
        top_left_point = (width_buffer + cell_width * (state.x - 1),
                          height_buffer + cell_height * (grid_mdp.height - state.y))
        tri_center = (int(top_left_point[0] + cell_width / 2.0),
                      int(top_left_point[1] + cell_height / 2.0))
        agent_shape = _draw_agent(
            tri_center,
            screen,
            base_size=min(cell_width, cell_height) / 2.5 - 8)

    pygame.display.flip()

    return agent_shape
Пример #7
0
def draw_state(screen,
               cleanup_mdp,
               state,
               policy=None,
               action_char_dict=None,
               show_value=False,
               agent=None,
               draw_statics=False,
               agent_shape=None):
    '''
    Render a Cleanup-domain state (walls, doors, rooms, blocks, agent) with
    pygame and flip the display.

    Args:
        screen (pygame.Surface)
        cleanup_mdp (MDP): Supplies ``width``/``height``, ``legal_states``,
            ``check_in_room()`` and ``find_block()``.
        state (State): Supplies ``x``/``y``, ``doors``, ``rooms`` and
            ``blocks``. Its coordinates appear to be 0-indexed (they are
            compared as ``state.x + 1``); TODO confirm.
        policy (callable): Optional state -> action drawn as text per cell.
        action_char_dict (dict): Optional action -> display-string mapping;
            unmapped actions are drawn as-is. Defaults to an empty mapping.
        show_value (bool): Fill legal cells with a value colour.
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool): Currently ignored; every cell is always drawn.
        agent_shape (pygame.rect): Rect of a previously drawn agent; when set
            (passed in or drawn below), it is erased and the agent redrawn.

    Returns:
        (pygame.Shape): The agent's drawn shape, or None if never drawn.
    '''
    if action_char_dict is None:
        action_char_dict = {}

    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(cleanup_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict (ValueIteration is used here to enumerate the states).
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(cleanup_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.

    width = cleanup_mdp.width
    height = cleanup_mdp.height

    cell_width = (scr_width - width_buffer * 2) / width
    cell_height = (scr_height - height_buffer * 2) / height
    font_size = int(min(cell_width, cell_height) / 4.0)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)
    inset_size = (cell_width - 10, cell_height - 10)

    # Door cells in 1-indexed grid coordinates.
    door_locs = {(door.x + 1, door.y + 1) for door in state.doors}

    # NOTE(review): draw_statics is not consulted; all cells are redrawn.
    for i in range(width):  # screen column
        for j in range(height):  # screen row, top to bottom
            # 1-indexed grid coordinates (grid y grows upward).
            gx, gy = i + 1, height - j
            is_legal = (gx, gy) in cleanup_mdp.legal_states

            top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
            # Walls, doors and rooms are drawn inset 5px inside the outline.
            inset_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
            pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width, cell_height), 3)

            if show_value and is_legal:
                # Value fill goes first so it does not cover the policy glyph.
                color = mdpv.val_to_color(val_text_dict[gx][gy])
                pygame.draw.rect(screen, color, top_left_point + (cell_width, cell_height), 0)

            # NOTE(review): the door/room fills below still paint over this
            # glyph in door and room cells.
            if policy and is_legal:
                a = policy_dict[gx][gy]
                text_a = action_char_dict.get(a, a)
                text_center_point = (int(top_left_point[0] + cell_width / 2.0 - 10),
                                     int(top_left_point[1] + cell_height / 3.0))
                screen.blit(cc_font.render(text_a, True, (46, 49, 49)), text_center_point)

            if not is_legal:
                # Draw the walls.
                top_left_point = inset_point
                pygame.draw.rect(screen, (94, 99, 99), inset_point + inset_size, 0)

            if (gx, gy) in door_locs:
                # Draw door.
                top_left_point = inset_point
                pygame.draw.rect(screen, (0, 0, 0), inset_point + inset_size, 0)
            else:
                # check_in_room / find_block take 0-indexed coordinates.
                room = cleanup_mdp.check_in_room(state.rooms, gx - 1, gy - 1)
                if room:
                    top_left_point = inset_point
                    pygame.draw.rect(screen, _get_rgb(room.color), inset_point + inset_size, 0)

            block = cleanup_mdp.find_block(state.blocks, gx - 1, gy - 1)
            if block:
                circle_center = (int(top_left_point[0] + cell_width / 2.0),
                                 int(top_left_point[1] + cell_height / 2.0))
                pygame.draw.circle(screen, _get_rgb(block.color), circle_center,
                                   int(min(cell_width, cell_height) / 4.0))

            # Current state: drawn at most once, and not when showing values.
            if not show_value and (gx, gy) == (state.x + 1, state.y + 1) and agent_shape is None:
                tri_center = (int(top_left_point[0] + cell_width / 2.0),
                              int(top_left_point[1] + cell_height / 2.0))
                agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height) / 2.5 - 8)

    if agent_shape is not None:
        # Erase the previously drawn agent, then redraw it at ``state``.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)
        top_left_point = (width_buffer + cell_width * state.x,
                          height_buffer + cell_height * (height - (state.y + 1)))
        tri_center = (int(top_left_point[0] + cell_width / 2.0),
                      int(top_left_point[1] + cell_height / 2.0))
        agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height) / 2.5 - 16)

    pygame.display.flip()

    return agent_shape
Пример #8
0
def old_draw_state(screen,
                   cleanup_mdp,
                   state,
                   policy=None,
                   action_char_dict=None,
                   show_value=False,
                   agent=None,
                   draw_statics=False,
                   agent_shape=None):
    '''
    Legacy renderer for a cleanup-domain state: draws cell borders, walls,
    doors and coloured rooms, then flips the display.

    The value/policy overlays and the block and agent sprites are disabled in
    this version. ``policy``, ``action_char_dict``, ``show_value``, ``agent``
    and ``draw_statics`` are accepted only so the signature matches the other
    ``_draw_*`` functions; they are ignored.

    Args:
        screen (pygame.Surface)
        cleanup_mdp (MDP): must provide ``width``, ``height``,
            ``legal_states`` and ``check_in_room``.
        state (State): must provide ``doors`` and ``rooms``.
        policy, action_char_dict, show_value, agent, draw_statics: unused.
        agent_shape (pygame.rect): returned unchanged.

    Returns:
        (pygame.Shape): ``agent_shape``, unchanged (no agent is drawn here).
    '''
    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.

    width = cleanup_mdp.width
    height = cleanup_mdp.height
    cell_width = (scr_width - width_buffer * 2) / width
    cell_height = (scr_height - height_buffer * 2) / height

    # Door positions shifted into the 1-indexed (x, y) convention used below.
    door_locs = {(door.x + 1, door.y + 1) for door in state.doors}

    for i in range(width):
        for j in range(height):
            # Screen column i / row j is the 1-indexed cell (i + 1, height - j).
            # Screen rows grow downward while y grows upward.
            cell = (i + 1, height - j)

            # Cell border.
            top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
            pygame.draw.rect(screen, (46, 49, 49),
                             top_left_point + (cell_width, cell_height), 3)

            # Filled square inset 5px from the border, shared by walls, doors and rooms.
            inset_rect = (width_buffer + cell_width * i + 5,
                          height_buffer + cell_height * j + 5,
                          cell_width - 10, cell_height - 10)

            if cell not in cleanup_mdp.legal_states:
                # Draw the walls.
                pygame.draw.rect(screen, (94, 99, 99), inset_rect, 0)

            if cell in door_locs:
                # Draw door.
                pygame.draw.rect(screen, (0, 0, 0), inset_rect, 0)
            else:
                # check_in_room is called with the cell coordinates minus 1
                # (the original "inconsistent x, y" offset).
                room = cleanup_mdp.check_in_room(state.rooms, i, height - j - 1)
                if room:
                    pygame.draw.rect(screen, _get_rgb(room.color), inset_rect, 0)

    pygame.display.flip()

    return agent_shape
# Example #9
# 0
def _draw_state(screen,
                taxi_oomdp,
                state,
                policy=None,
                action_char_dict=None,
                show_value=False,
                agent=None,
                draw_statics=True,
                agent_shape=None):
    '''
    Render one Taxi OO-MDP state onto ``screen`` and flip the display.

    Drawing order: erase the previous agent sprite, walls (if
    ``draw_statics``), passenger destinations (purple hexagons), the taxi,
    passengers (green hexagons). When ``draw_statics`` is set, this is
    followed by grid borders and the optional per-cell value/policy overlays.

    Args:
        screen (pygame.Surface)
        taxi_oomdp (TaxiOOMDP)
        state (State)
        policy (callable, optional): maps a state to an action. The action,
            or its glyph from ``action_char_dict``, is printed in each cell.
        action_char_dict (dict, optional): action -> display string.
        show_value (bool): fill each non-wall cell with a colour for its value
            and print the value. The fill is drawn last, so it also covers
            the taxi and passengers in that cell.
        agent (Agent, optional): source of value estimates. When None, Value
            Iteration is run on ``taxi_oomdp``.
        draw_statics (bool): draw walls, grid borders and overlays.
        agent_shape (pygame.rect): previous agent sprite, erased before drawing.

    Returns:
        (pygame.Shape): the shape returned by ``_draw_agent`` for the new taxi.
    '''
    if action_char_dict is None:
        action_char_dict = {}

    # Make value dict: val_text_dict[x][y] holds the value of the last state
    # seen at agent position (x, y).
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            if agent.name == 'Q-learning':
                # Use agent value estimates.
                for s in agent.q_func.keys():
                    val_text_dict[s.get_agent_x()][s.get_agent_y()] = agent.get_value(s)
            # slightly abusing the distinction between agents and planning modules...
            else:
                for s in taxi_oomdp.get_states():
                    val_text_dict[s.get_agent_x()][s.get_agent_y()] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(taxi_oomdp, sample_rate=10)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.get_agent_x()][s.get_agent_y()] = vi.get_value(s)

    # Make policy dict (same last-state-wins keying as the value dict).
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        for s in taxi_oomdp.get_states():
            policy_dict[s.get_agent_x()][s.get_agent_y()] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / taxi_oomdp.width
    cell_height = (scr_height - height_buffer * 2) / taxi_oomdp.height
    objects = state.get_objects()
    agent_x, agent_y = objects["agent"][0]["x"], objects["agent"][0]["y"]
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)
    passenger_size = cell_width / 11
    text_color = (46, 49, 49)

    def draw_hexagon(center_x, center_y, radius, color):
        # Regular hexagon around (center_x, center_y) with a vertex on +x.
        pygame.draw.polygon(screen, color,
                            [(center_x + radius * math.cos(2 * math.pi * k / 6),
                              center_y + radius * math.sin(2 * math.pi * k / 6))
                             for k in range(6)])

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)

    # Statics
    if draw_statics:
        # Draw walls.
        for w in taxi_oomdp.walls:
            w_x, w_y = w["x"], w["y"]
            top_left_point = (width_buffer + cell_width * (w_x - 1) + 5,
                              height_buffer + cell_height * (taxi_oomdp.height - w_y) + 5)
            pygame.draw.rect(screen, (46, 49, 49),
                             top_left_point + (cell_width - 10, cell_height - 10), 0)

    # Draw the destinations (purple hexagons).
    for p in objects["passenger"]:
        dest_x, dest_y = p["dest_x"], p["dest_y"]
        center_x = int(width_buffer + cell_width * (dest_x - 1) + 38)
        center_y = int(height_buffer + cell_height * (taxi_oomdp.height - dest_y) + 18)
        draw_hexagon(center_x, center_y, passenger_size, (188, 30, 230))

    # Draw new agent.
    top_left_point = (width_buffer + cell_width * (agent_x - 1),
                      height_buffer + cell_height * (taxi_oomdp.height - agent_y))
    agent_center = (int(top_left_point[0] + cell_width / 2.0),
                    int(top_left_point[1] + cell_height / 2.0))
    agent_shape = _draw_agent(agent_center,
                              screen,
                              base_size=min(cell_width, cell_height) / 2.5 - 4)

    # Draw the passengers (light-green hexagons). A passenger riding in the
    # taxi is drawn at a different offset inside its cell.
    for p in objects["passenger"]:
        x, y = p["x"], p["y"]
        if p["in_taxi"]:
            offset_x, offset_y = 58, 15
        else:
            offset_x, offset_y = 25, 34
        center_x = int(width_buffer + cell_width * (x - 1) + passenger_size + offset_x)
        center_y = int(height_buffer + cell_height * (taxi_oomdp.height - y) +
                       passenger_size + offset_y)
        draw_hexagon(center_x, center_y, passenger_size, (59, 189, 23))

    if draw_statics:
        for i in range(taxi_oomdp.width):
            for j in range(taxi_oomdp.height):
                top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
                pygame.draw.rect(screen, (46, 49, 49),
                                 top_left_point + (cell_width, cell_height), 3)

                # Screen column i / row j is the 1-indexed cell (i + 1, height - j).
                cell_x, cell_y = i + 1, taxi_oomdp.height - j
                text_center_point = (int(top_left_point[0] + cell_width / 2.0 - 10),
                                     int(top_left_point[1] + cell_height / 3.0))

                # Show value of states.
                if show_value and not taxi_helpers.is_wall(taxi_oomdp, cell_x, cell_y):
                    val = val_text_dict[cell_x][cell_y]
                    pygame.draw.rect(screen, mdpv.val_to_color(val),
                                     top_left_point + (cell_width, cell_height), 0)
                    value_text = reg_font.render(str(round(val, 2)), True, text_color)
                    screen.blit(value_text, text_center_point)

                # Show optimal action to take in each grid cell.
                if policy and not taxi_helpers.is_wall(taxi_oomdp, cell_x, cell_y):
                    a = policy_dict[cell_x][cell_y]
                    text_a = action_char_dict.get(a, a)
                    text_rendered_a = cc_font.render(text_a, True, text_color)
                    screen.blit(text_rendered_a, text_center_point)

    pygame.display.flip()

    return agent_shape
# Example #10
# 0
def _draw_augmented_state(screen,
                          taxi_oomdp,
                          state,
                          policy=None,
                          action_char_dict=None,
                          show_value=False,
                          agent=None,
                          draw_statics=True,
                          agent_shape=None):
    '''
    Render one augmented Taxi OO-MDP state onto ``screen`` and flip the display.

    Besides walls, this variant draws tolls, traffic cells and fuel stations.
    Destinations and passengers are coloured from ``color_ls``, indexed from
    the end by passenger position and darkened by 30 per channel.

    Each grid cell can correspond to several states, e.g. depending on whether
    the passenger is in the taxi or where the passenger is. The value and
    action shown for a cell come from the last state visited for that agent
    position. Showing every variant per cell was attempted and abandoned: it
    does not scale to richer MDPs without a custom or much more configurable
    pipeline.

    Args:
        screen (pygame.Surface)
        taxi_oomdp (TaxiOOMDP)
        state (State)
        policy (callable, optional): maps a state to an action. The action,
            or its glyph from ``action_char_dict``, is printed in each cell.
        action_char_dict (dict, optional): action -> display string.
        show_value (bool): fill each non-wall cell with a colour for its value
            and print the value. The fill is drawn last, so it also covers the
            taxi and passengers in that cell.
        agent (Agent, optional): source of value estimates. When None, Value
            Iteration is run on ``taxi_oomdp``.
        draw_statics (bool): draw walls, tolls, traffic, fuel stations, grid
            borders and overlays.
        agent_shape (pygame.rect): previous agent sprite, erased before drawing.

    Returns:
        (pygame.Shape): the shape returned by ``_draw_agent`` for the new taxi.
    '''
    if action_char_dict is None:
        action_char_dict = {}

    # Make value dict: val_text_dict[x][y] holds the value of the last state
    # seen at agent position (x, y).
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            if agent.name == 'Q-learning':
                # Use agent value estimates.
                for s in agent.q_func.keys():
                    val_text_dict[s.get_agent_x()][s.get_agent_y()] = agent.get_value(s)
            # slightly abusing the distinction between agents and planning modules...
            else:
                for s in taxi_oomdp.get_states():
                    val_text_dict[s.get_agent_x()][s.get_agent_y()] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(taxi_oomdp, sample_rate=10)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.get_agent_x()][s.get_agent_y()] = vi.get_value(s)

    # Make policy dict (same last-state-wins keying as the value dict).
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        for s in taxi_oomdp.get_states():
            policy_dict[s.get_agent_x()][s.get_agent_y()] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / taxi_oomdp.width
    cell_height = (scr_height - height_buffer * 2) / taxi_oomdp.height
    objects = state.get_objects()
    agent_x, agent_y = objects["agent"][0]["x"], objects["agent"][0]["y"]
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)
    text_color = (46, 49, 49)

    def cell_inset_rect(cell_x, cell_y):
        # (left, top, width, height) of 1-indexed cell (cell_x, cell_y), inset 5px.
        return (width_buffer + cell_width * (cell_x - 1) + 5,
                height_buffer + cell_height * (taxi_oomdp.height - cell_y) + 5,
                cell_width - 10, cell_height - 10)

    def passenger_color(index):
        # color_ls entry counted from the end, each channel darkened by 30 (floored at 0).
        base = color_ls[-index - 1]
        return tuple(int(max(base[k] - 30, 0)) for k in range(3))

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)

    # Statics
    if draw_statics:
        # Draw walls.
        for w in taxi_oomdp.walls:
            pygame.draw.rect(screen, (46, 49, 49), cell_inset_rect(w["x"], w["y"]), 0)

        # Tolls and traffic cells: clear to white first, then draw the box.
        # NOTE(review): the comment this replaces spoke of "correct transparency",
        # but the colours below are RGB without an alpha value, so the boxes are
        # drawn opaque. Confirm whether an alpha component was intended.
        for t in taxi_oomdp.tolls:
            rect = cell_inset_rect(t["x"], t["y"])
            pygame.draw.rect(screen, (255, 255, 255), rect, 0)
            pygame.gfxdraw.box(screen, rect, (224, 230, 67))

        for t in taxi_oomdp.traffic_cells:
            rect = cell_inset_rect(t["x"], t["y"])
            pygame.draw.rect(screen, (255, 255, 255), rect, 0)
            pygame.gfxdraw.box(screen, rect, (58, 28, 232))

        # Draw fuel stations.
        for f in taxi_oomdp.fuel_stations:
            pygame.draw.rect(screen, (144, 0, 255), cell_inset_rect(f["x"], f["y"]), 0)

    # Draw the destinations: small squares in the passenger's colour.
    for i, p in enumerate(objects["passenger"]):
        dest_x, dest_y = p["dest_x"], p["dest_y"]
        top_left_point = (int(width_buffer + cell_width * (dest_x - 1) + 27),
                          int(height_buffer + cell_height * (taxi_oomdp.height - dest_y) + 14))
        pygame.draw.rect(screen, passenger_color(i),
                         top_left_point + (cell_width / 6, cell_height / 6), 0)

    # Draw new agent.
    top_left_point = (width_buffer + cell_width * (agent_x - 1),
                      height_buffer + cell_height * (taxi_oomdp.height - agent_y))
    agent_center = (int(top_left_point[0] + cell_width / 2.0),
                    int(top_left_point[1] + cell_height / 2.0))
    agent_shape = _draw_agent(agent_center,
                              screen,
                              base_size=min(cell_width, cell_height) / 2.5 - 4)

    # Draw the passengers as circles. A passenger riding in the taxi is drawn
    # at a different offset inside its cell.
    passenger_radius = int(min(cell_width, cell_height) / 9.0)
    for i, p in enumerate(objects["passenger"]):
        pass_x, pass_y = p["x"], p["y"]
        if p["in_taxi"]:
            offset_x, offset_y = 58, 16
        else:
            offset_x, offset_y = 26, 38
        center = (int(width_buffer + cell_width * (pass_x - 1) + passenger_radius + offset_x),
                  int(height_buffer + cell_height * (taxi_oomdp.height - pass_y) +
                      passenger_radius + offset_y))
        pygame.draw.circle(screen, passenger_color(i), center, passenger_radius)

    if draw_statics:
        for i in range(taxi_oomdp.width):
            for j in range(taxi_oomdp.height):
                top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
                pygame.draw.rect(screen, (46, 49, 49),
                                 top_left_point + (cell_width, cell_height), 3)

                # Screen column i / row j is the 1-indexed cell (i + 1, height - j).
                cell_x, cell_y = i + 1, taxi_oomdp.height - j
                text_center_point = (int(top_left_point[0] + cell_width / 2.0 - 10),
                                     int(top_left_point[1] + cell_height / 3.0))

                # Show value of states.
                if show_value and not taxi_helpers.is_wall(taxi_oomdp, cell_x, cell_y):
                    val = val_text_dict[cell_x][cell_y]
                    pygame.draw.rect(screen, mdpv.val_to_color(val),
                                     top_left_point + (cell_width, cell_height), 0)
                    value_text = reg_font.render(str(round(val, 2)), True, text_color)
                    screen.blit(value_text, text_center_point)

                # Show optimal action to take in each grid cell.
                if policy and not taxi_helpers.is_wall(taxi_oomdp, cell_x, cell_y):
                    a = policy_dict[cell_x][cell_y]
                    text_a = action_char_dict.get(a, a)
                    text_rendered_a = cc_font.render(text_a, True, text_color)
                    screen.blit(text_rendered_a, text_center_point)

    pygame.display.flip()

    return agent_shape