Example #1
def main():
    # Setup MDP, Agents.
    size = 5
    agent = {
        "x": 1,
        "y": 1,
        "dx": 1,
        "dy": 0,
        "dest_x": size,
        "dest_y": size,
        "has_block": 0
    }
    blocks = [{"x": size, "y": 1}]
    lavas = [{"x": x + 1, "y": (size + 1) // 2} for x in range(size)]

    mdp = TrenchOOMDP(size, size, agent, blocks, lavas)
    ql_agent = QLearnerAgent(actions=mdp.get_actions())
    rand_agent = RandomAgent(actions=mdp.get_actions())

    # Run experiment and make plot.
    # run_agents_on_mdp([ql_agent, rand_agent], mdp, instances=30, episodes=250, steps=250)

    vi = ValueIteration(mdp, delta=0.0001, max_iterations=5000)
    iters, val = vi.run_vi()
    print " done."
    states = vi.get_states()
    num_states = len(states)
    print(num_states, states)
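To also inspect the solution itself, the solved ValueIteration object can be queried per state; a minimal continuation (assuming the standard simple_rl accessors vi.get_value, vi.policy, and mdp.get_init_state, which the later examples also rely on) might look like this:

    # Query the solved planner (sketch; assumes standard simple_rl accessors).
    init_state = mdp.get_init_state()
    print("Value at the initial state:", vi.get_value(init_state))
    print("Greedy action at the initial state:", vi.policy(init_state))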
Example #2
    def __init__(self,
                 mdp,
                 lower_values_init,
                 upper_values_init,
                 tau=10.,
                 name='BRTDP'):
        '''
        Args:
            mdp (MDP): underlying MDP to plan in
            lower_values_init (defaultdict): lower bound initialization on the value function
            upper_values_init (defaultdict): upper bound initialization on the value function
            tau (float): scaling factor to help determine when the bounds on the value function are tight enough
            name (str): Name of the planner
        '''
        Planner.__init__(self, mdp, name)
        self.lower_values = lower_values_init
        self.upper_values = upper_values_init

        # Using the value iteration class for accessing the matrix of transition probabilities
        vi = ValueIteration(mdp, sample_rate=1000)
        self.states = vi.get_states()
        vi._compute_matrix_from_trans_func()
        self.trans_dict = vi.trans_dict

        self.max_diff = (self.upper_values[self.mdp.init_state] -
                         self.lower_values[self.mdp.init_state]) / tau
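Example #3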
class MonotoneUpperBound(Planner):
    def __init__(self, mdp, name='MonotoneUpperBound'):
        Planner.__init__(self, mdp, name)
        self.vi = ValueIteration(mdp)
        self.states = self.vi.get_states()
        self.upper_values = self._construct_upper_values()

    def _construct_upper_values(self):
        values = defaultdict()
        for state in self.states:
            values[state] = 1. / (1. - self.gamma)
        return values
Example #4
class MonotoneUpperBound(Planner):
    def __init__(self, mdp, name='MonotoneUpperBound'):
        Planner.__init__(self, mdp, name)
        self.vi = ValueIteration(mdp)
        self.states = self.vi.get_states()
        self.upper_values = self._construct_upper_values()

    def _construct_upper_values(self):
        values = defaultdict()
        for state in self.states:
            values[state] = 1. / (1. - self.gamma)
        return values
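The 1. / (1. - gamma) initialization is a valid upper bound only under the usual assumption (not stated in the snippet) that rewards are normalized to lie in [0, 1]; in that case the discounted return of any policy is bounded by the geometric series

    V^*(s) \le \sum_{t=0}^{\infty} \gamma^t R_{\max} = \frac{R_{\max}}{1 - \gamma} \le \frac{1}{1 - \gamma}.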
Example #5
def make_multitask_sa_info_sa(mdp_distr, beta, is_deterministic_ib=False):
    '''
    Args:
        mdp_distr (simple_rl.MDPDistribution)
        beta (float)
        is_deterministic_ib (bool)

    Returns:
        (simple_rl.StateAbstraction)
    '''

    master_sa = None
    all_state_absr = []
    for mdp in mdp_distr.get_all_mdps():

        # Get demo policy.
        vi = ValueIteration(mdp)
        vi.run_vi()
        demo_policy = get_lambda_policy(
            make_det_policy_eps_greedy(vi.policy,
                                       vi.get_states(),
                                       mdp.get_actions(),
                                       epsilon=0.2))

        # Get abstraction.
        pmf_s_phi, phi_pmf, abstr_policy_pmf = run_info_sa(
            mdp,
            demo_policy,
            beta=beta,
            is_deterministic_ib=is_deterministic_ib)
        crisp_sa = convert_prob_sa_to_sa(ProbStateAbstraction(phi_pmf))
        all_state_absr.append(crisp_sa)

    # Make master state abstr by intersection.
    vi = ValueIteration(mdp_distr.get_all_mdps()[0])
    ground_states = vi.get_states()

    master_sa = sa_helpers.merge_state_abstr(all_state_absr, ground_states)

    return master_sa
Example #6
def _make_mini_mdp_option_policy(mini_mdp):
    '''
    Args:
        mini_mdp (MDP)

    Returns:
        (tuple): (lambda : state --> action, ValueIteration used to compute it)
    '''
    # Solve the MDP defined by the terminal abstract state.
    mini_mdp_vi = ValueIteration(mini_mdp, delta=0.005, max_iterations=500, sample_rate=20)
    iters, val = mini_mdp_vi.run_vi()

    o_policy_dict = make_dict_from_lambda(mini_mdp_vi.policy, mini_mdp_vi.get_states())
    o_policy = PolicyFromDict(o_policy_dict)

    return o_policy.get_action, mini_mdp_vi
Example #7
    def __init__(self, mdp, lower_values_init, upper_values_init, tau=10., name='BRTDP'):
        '''
        Args:
            mdp (MDP): underlying MDP to plan in
            lower_values_init (defaultdict): lower bound initialization on the value function
            upper_values_init (defaultdict): upper bound initialization on the value function
            tau (float): scaling factor to help determine when the bounds on the value function are tight enough
            name (str): Name of the planner
        '''
        Planner.__init__(self, mdp, name)
        self.lower_values = lower_values_init
        self.upper_values = upper_values_init

        # Using the value iteration class for accessing the matrix of transition probabilities
        vi = ValueIteration(mdp, sample_rate=1000)
        self.states = vi.get_states()
        vi._compute_matrix_from_trans_func()
        self.trans_dict = vi.trans_dict

        self.max_diff = (self.upper_values[self.mdp.init_state] - self.lower_values[self.mdp.init_state]) / tau
Example #8
def main():

    # Make MDP.
    grid_dim = 11
    mdp = FourRoomMDP(width=grid_dim, height=grid_dim, init_loc=(1, 1), slip_prob=0.05, goal_locs=[(grid_dim, grid_dim)], gamma=0.99)

    # Experiment Type.
    exp_type = "learn_w_abstr"

    # For comparing policies and visualizing.
    beta = 1
    is_deterministic_ib = True
    is_agent_in_control = True

    # For main plotting experiment.
    beta_range = list(chart_utils.drange(0.0, 4.0, 1.0))
    instances = 1

    # Get demo policy.
    vi = ValueIteration(mdp)
    _, val = vi.run_vi()

    # Epsilon greedy policy
    demo_policy = get_lambda_policy(make_det_policy_eps_greedy(vi.policy, vi.get_states(), mdp.get_actions(), epsilon=0.1))

    if exp_type == "plot_info_sa_val_and_num_states":
        # Makes the main two plots.
        make_info_sa_val_and_size_plots(mdp, demo_policy, beta_range, instances=instances, is_agent_in_control=is_agent_in_control)
    elif exp_type == "compare_policies":
        # Makes a plot comparing value of pi-phi combo from info_sa with \pi_d.
        info_sa_compare_policies(mdp, demo_policy, beta=beta, is_deterministic_ib=is_deterministic_ib, is_agent_in_control=is_agent_in_control)
    elif exp_type == "visualize_info_sa_abstr":
        # Visualize the state abstraction found by info_sa.
        info_sa_visualize_abstr(mdp, demo_policy, beta=beta, is_deterministic_ib=is_deterministic_ib, is_agent_in_control=is_agent_in_control)
    elif exp_type == "learn_w_abstr":
        # Run learning experiments for different settings of \beta.
        learn_w_abstr(mdp, demo_policy, is_deterministic_ib=is_deterministic_ib)
    elif exp_type == "planning":
        info_sa_planning_experiment()
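get_lambda_policy and make_det_policy_eps_greedy come from the surrounding info_sa helper module and are not shown here. As a rough sketch (the _sketch names are hypothetical, not the library's actual implementation), an epsilon-greedy wrapper around the deterministic VI policy could be built like this:

import random

def make_det_policy_eps_greedy_sketch(policy, states, actions, epsilon=0.1):
    # Build an explicit {state: {action: prob}} table from a deterministic policy.
    pmf = {}
    for s in states:
        greedy_a = policy(s)
        pmf[s] = {a: epsilon / len(actions) for a in actions}
        pmf[s][greedy_a] += 1.0 - epsilon
    return pmf

def get_lambda_policy_sketch(pmf):
    # Turn the table back into a callable state -> action that samples from it.
    def policy(state):
        actions, probs = zip(*pmf[state].items())
        return random.choices(actions, weights=probs, k=1)[0]
    return policy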
Example #9
class MonotoneLowerBound(Planner):
    def __init__(self, mdp, name='MonotoneLowerBound'):
        relaxed_mdp = MonotoneLowerBound._construct_deterministic_relaxation_mdp(mdp)

        Planner.__init__(self, relaxed_mdp, name)
        self.vi = ValueIteration(relaxed_mdp)
        self.states = self.vi.get_states()
        self.vi._compute_matrix_from_trans_func()
        self.vi.run_vi()
        self.lower_values = self._construct_lower_values()

    @staticmethod
    def _construct_deterministic_relaxation_mdp(mdp):
        relaxed_mdp = copy.deepcopy(mdp)
        relaxed_mdp.set_slip_prob(0.0)
        return relaxed_mdp

    def _construct_lower_values(self):
        values = defaultdict()
        for state in self.states:
            values[state] = self.vi.get_value(state)
        return values
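Example #10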
class MonotoneLowerBound(Planner):
    def __init__(self, mdp, name='MonotoneLowerBound'):
        relaxed_mdp = MonotoneLowerBound._construct_deterministic_relaxation_mdp(
            mdp)

        Planner.__init__(self, relaxed_mdp, name)
        self.vi = ValueIteration(relaxed_mdp)
        self.states = self.vi.get_states()
        self.vi._compute_matrix_from_trans_func()
        self.vi.run_vi()
        self.lower_values = self._construct_lower_values()

    @staticmethod
    def _construct_deterministic_relaxation_mdp(mdp):
        relaxed_mdp = copy.deepcopy(mdp)
        relaxed_mdp.set_slip_prob(0.0)
        return relaxed_mdp

    def _construct_lower_values(self):
        values = defaultdict()
        for state in self.states:
            values[state] = self.vi.get_value(state)
        return values
Example #11
def run_info_sa(mdp, demo_policy_lambda, iters=500, beta=20.0, convergence_threshold=0.01, is_deterministic_ib=False):
    '''
    Args:
        mdp (simple_rl.MDP)
        demo_policy_lambda (lambda : simple_rl.State --> str)
        iters (int)
        beta (float)
        convergence_threshold (float): When all three distributions satisfy
            L1(p_{t+1}, p_t) < @convergence_threshold, we stop iterating.
        is_deterministic_ib (bool): If true, run DIB, else IB.

    Returns:
        (dict): P(s_phi)
        (dict): P(s_phi | s)
        (dict): P(a | s_phi)

    Summary:
        Runs the Blahut-Arimoto like algorithm for the given mdp.
    '''
    print "~"*16
    print "~~ BETA =", beta, "~~"
    print "~"*16

    # Get state and action space.
    vi = ValueIteration(mdp)
    ground_states = vi.get_states()
    num_ground_states = len(ground_states)
    actions = mdp.get_actions()

    # Get pmf demo policy and stationary distribution.
    demo_policy_pmf = get_pmf_policy(demo_policy_lambda, ground_states, actions)
    pmf_s = get_stationary_rho_from_policy(demo_policy_lambda, mdp, ground_states)

    # Init distributions.
    phi_pmf, abstract_states = init_random_phi(ground_states, deterministic=is_deterministic_ib)
    pmf_s_phi = compute_prob_of_s_phi(pmf_s, phi_pmf, ground_states, abstract_states)
    abstr_policy_pmf = compute_abstr_policy(demo_policy_pmf, ground_states, abstract_states, actions, phi_pmf, pmf_s, pmf_s_phi, deterministic=is_deterministic_ib)

    # info_sa.
    for i in range(iters):
        # print 'Iteration {0} of {1}'.format(i+1, iters)

        # (A) Compute \phi.
        next_phi_pmf = compute_phi_pmf(pmf_s, pmf_s_phi, demo_policy_pmf, abstr_policy_pmf, ground_states, abstract_states, beta=beta, deterministic=is_deterministic_ib)

        # (B) Compute \rho(s).
        next_pmf_s_phi = compute_prob_of_s_phi(pmf_s, next_phi_pmf, ground_states, abstract_states)

        # (C) Compute \pi_\phi.
        next_abstr_policy_pmf = compute_abstr_policy(demo_policy_pmf, ground_states, abstract_states, actions, next_phi_pmf, pmf_s, next_pmf_s_phi, deterministic=is_deterministic_ib)

        # Convergence checks.
        coding_update_delta = max([l1_distance(next_phi_pmf[s], phi_pmf[s]) for s in ground_states])
        policy_update_delta = max([l1_distance(next_abstr_policy_pmf[s_phi], abstr_policy_pmf[s_phi]) for s_phi in abstract_states])
        state_distr_update_delta = l1_distance(next_pmf_s_phi, pmf_s_phi)

        # Debugging.
        is_coding_converged = coding_update_delta < convergence_threshold
        is_policy_converged = policy_update_delta < convergence_threshold
        is_pmf_s_phi_converged = state_distr_update_delta < convergence_threshold

        # Update pointers.
        phi_pmf = next_phi_pmf
        pmf_s_phi = next_pmf_s_phi
        abstr_policy_pmf = next_abstr_policy_pmf

        if is_coding_converged and is_policy_converged and is_pmf_s_phi_converged:
            print "\tinfo_sa Converged."
            break

    return pmf_s_phi, phi_pmf, abstr_policy_pmf
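l1_distance, used in the convergence checks above, is another helper that is not shown; assuming each distribution is a dict mapping outcomes to probabilities, a minimal sketch is:

def l1_distance_sketch(pmf_a, pmf_b):
    # Sum of absolute probability differences over the union of supports.
    support = set(pmf_a) | set(pmf_b)
    return sum(abs(pmf_a.get(x, 0.0) - pmf_b.get(x, 0.0)) for x in support)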
Example #12
def _draw_state(screen,
                taxi_oomdp,
                state,
                policy=None,
                action_char_dict={},
                show_value=False,
                agent=None,
                draw_statics=True,
                agent_shape=None):
    '''
    Args:
        screen (pygame.Surface)
        taxi_oomdp (TaxiOOMDP)
        state (State)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''
    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            if agent.name == 'Q-learning':
                # Use agent value estimates.
                for s in agent.q_func.keys():
                    val_text_dict[s.get_agent_x()][
                        s.get_agent_y()] = agent.get_value(s)
            # slightly abusing the distinction between agents and planning modules...
            else:
                for s in taxi_oomdp.get_states():
                    val_text_dict[s.get_agent_x()][
                        s.get_agent_y()] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(taxi_oomdp, sample_rate=10)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.get_agent_x()][s.get_agent_y()] = vi.get_value(
                    s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        for s in taxi_oomdp.get_states():
            policy_dict[s.get_agent_x()][s.get_agent_y()] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / taxi_oomdp.width
    cell_height = (scr_height - height_buffer * 2) / taxi_oomdp.height
    objects = state.get_objects()
    agent_x, agent_y = objects["agent"][0]["x"], objects["agent"][0]["y"]
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)

    # Statics
    if draw_statics:
        # Draw walls.
        for w in taxi_oomdp.walls:
            w_x, w_y = w["x"], w["y"]
            top_left_point = width_buffer + cell_width * (
                w_x - 1) + 5, height_buffer + cell_height * (
                    taxi_oomdp.height - w_y) + 5
            pygame.draw.rect(
                screen, (46, 49, 49),
                top_left_point + (cell_width - 10, cell_height - 10), 0)

    # Draw the destination.
    for i, p in enumerate(objects["passenger"]):
        # Dest.
        dest_x, dest_y = p["dest_x"], p["dest_y"]
        top_left_point = int(width_buffer + cell_width * (dest_x - 1) +
                             38), int(height_buffer + cell_height *
                                      (taxi_oomdp.height - dest_y) + 18)

        passenger_size = cell_width / 11
        # purple
        dest_col = (188, 30, 230)

        n, r = 6, passenger_size
        x, y = top_left_point[0], top_left_point[1]
        color = dest_col
        pygame.draw.polygon(screen, color,
                            [(x + r * math.cos(2 * math.pi * i / n),
                              y + r * math.sin(2 * math.pi * i / n))
                             for i in range(n)])

    # Draw new agent.
    top_left_point = width_buffer + cell_width * (
        agent_x - 1), height_buffer + cell_height * (taxi_oomdp.height -
                                                     agent_y)
    agent_center = int(top_left_point[0] +
                       cell_width / 2.0), int(top_left_point[1] +
                                              cell_height / 2.0)
    agent_shape = _draw_agent(agent_center,
                              screen,
                              base_size=min(cell_width, cell_height) / 2.5 - 4)

    for i, p in enumerate(objects["passenger"]):
        # Passenger.
        x, y = p["x"], p["y"]
        passenger_size = cell_width / 11
        if p["in_taxi"]:
            top_left_point = int(width_buffer + cell_width * (x - 1) +
                                 passenger_size +
                                 58), int(height_buffer + cell_height *
                                          (taxi_oomdp.height - y) +
                                          passenger_size + 15)
        else:
            top_left_point = int(width_buffer + cell_width * (x - 1) +
                                 passenger_size +
                                 25), int(height_buffer + cell_height *
                                          (taxi_oomdp.height - y) +
                                          passenger_size + 34)

        # light green
        dest_col = (59, 189, 23)

        n, r = 6, passenger_size
        x, y = top_left_point[0], top_left_point[1]
        color = dest_col
        pygame.draw.polygon(screen, color,
                            [(x + r * math.cos(2 * math.pi * i / n),
                              y + r * math.sin(2 * math.pi * i / n))
                             for i in range(n)])

    if draw_statics:
        # For each row:
        for i in range(taxi_oomdp.width):
            # For each column:
            for j in range(taxi_oomdp.height):
                top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
                r = pygame.draw.rect(
                    screen, (46, 49, 49),
                    top_left_point + (cell_width, cell_height), 3)

                # Show value of states.
                if show_value and not taxi_helpers.is_wall(
                        taxi_oomdp, i + 1, taxi_oomdp.height - j):
                    # Draw the value.
                    val = val_text_dict[i + 1][taxi_oomdp.height - j]
                    color = mdpv.val_to_color(val)
                    pygame.draw.rect(
                        screen, color,
                        top_left_point + (cell_width, cell_height), 0)
                    value_text = reg_font.render(str(round(val, 2)), True,
                                                 (46, 49, 49))
                    text_center_point = int(top_left_point[0] +
                                            cell_width / 2.0 -
                                            10), int(top_left_point[1] +
                                                     cell_height / 3.0)
                    screen.blit(value_text, text_center_point)

                # Show optimal action to take in each grid cell.
                if policy and not taxi_helpers.is_wall(taxi_oomdp, i + 1,
                                                       taxi_oomdp.height - j):
                    a = policy_dict[i + 1][taxi_oomdp.height - j]
                    if a not in action_char_dict:
                        text_a = a
                    else:
                        text_a = action_char_dict[a]
                    text_center_point = int(top_left_point[0] +
                                            cell_width / 2.0 -
                                            10), int(top_left_point[1] +
                                                     cell_height / 3.0)
                    text_rendered_a = cc_font.render(text_a, True,
                                                     (46, 49, 49))
                    screen.blit(text_rendered_a, text_center_point)

    pygame.display.flip()

    return agent_shape
Example #13
def _draw_state(screen,
                grid_mdp,
                state,
                policy=None,
                action_char_dict={},
                show_value=False,
                agent=None,
                draw_statics=False,
                agent_shape=None):
    '''
    Args:
        screen (pygame.Surface)
        grid_mdp (MDP)
        state (State)
        show_value (bool)
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''
    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(grid_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(grid_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / grid_mdp.width
    cell_height = (scr_height - height_buffer * 2) / grid_mdp.height
    goal_locs = grid_mdp.get_goal_locs()
    lava_locs = grid_mdp.get_lava_locs()
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    # Draw the static entities.
    if draw_statics:
        # For each row:
        for i in range(grid_mdp.width):
            # For each column:
            for j in range(grid_mdp.height):

                top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
                r = pygame.draw.rect(
                    screen, (46, 49, 49),
                    top_left_point + (cell_width, cell_height), 3)

                if policy and not grid_mdp.is_wall(i + 1, grid_mdp.height - j):
                    a = policy_dict[i + 1][grid_mdp.height - j]
                    if a not in action_char_dict:
                        text_a = a
                    else:
                        text_a = action_char_dict[a]
                    text_center_point = int(top_left_point[0] +
                                            cell_width / 2.0 -
                                            10), int(top_left_point[1] +
                                                     cell_height / 3.0)
                    text_rendered_a = cc_font.render(text_a, True,
                                                     (46, 49, 49))
                    screen.blit(text_rendered_a, text_center_point)

                if show_value and not grid_mdp.is_wall(i + 1,
                                                       grid_mdp.height - j):
                    # Draw the value.
                    val = val_text_dict[i + 1][grid_mdp.height - j]
                    color = mdpv.val_to_color(val)
                    pygame.draw.rect(
                        screen, color,
                        top_left_point + (cell_width, cell_height), 0)
                    # text_center_point = int(top_left_point[0] + cell_width/2.0 - 10), int(top_left_point[1] + cell_height/7.0)
                    # text = str(round(val,2))
                    # text_rendered = reg_font.render(text, True, (46, 49, 49))
                    # screen.blit(text_rendered, text_center_point)

                if grid_mdp.is_wall(i + 1, grid_mdp.height - j):
                    # Draw the walls.
                    top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                    r = pygame.draw.rect(
                        screen, (94, 99, 99),
                        top_left_point + (cell_width - 10, cell_height - 10),
                        0)

                if (i + 1, grid_mdp.height - j) in goal_locs:
                    # Draw goal.
                    circle_center = int(top_left_point[0] + cell_width /
                                        2.0), int(top_left_point[1] +
                                                  cell_height / 2.0)
                    circler_color = (154, 195, 157)
                    pygame.draw.circle(screen, circler_color, circle_center,
                                       int(min(cell_width, cell_height) / 3.0))

                if (i + 1, grid_mdp.height - j) in lava_locs:
                    # Draw lava.
                    circle_center = int(top_left_point[0] + cell_width /
                                        2.0), int(top_left_point[1] +
                                                  cell_height / 2.0)
                    circler_color = (224, 145, 157)
                    pygame.draw.circle(screen, circler_color, circle_center,
                                       int(min(cell_width, cell_height) / 4.0))

                # Current state.
                if not show_value and (i + 1, grid_mdp.height - j) == (
                        state.x, state.y) and agent_shape is None:
                    tri_center = int(top_left_point[0] +
                                     cell_width / 2.0), int(top_left_point[1] +
                                                            cell_height / 2.0)
                    agent_shape = _draw_agent(
                        tri_center,
                        screen,
                        base_size=min(cell_width, cell_height) / 2.5 - 8)

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)
        top_left_point = width_buffer + cell_width * (
            state.x - 1), height_buffer + cell_height * (grid_mdp.height -
                                                         state.y)
        tri_center = int(top_left_point[0] +
                         cell_width / 2.0), int(top_left_point[1] +
                                                cell_height / 2.0)

        # Draw new.
        # if not show_value or policy is not None:
        agent_shape = _draw_agent(
            tri_center,
            screen,
            base_size=min(cell_width, cell_height) / 2.5 - 16)

    pygame.display.flip()

    return agent_shape
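Example #14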
def _draw_state(screen,
                two_goal_oomdp,
                state,
                policy=None,
                action_char_dict={},
                show_value=False,
                agent=None,
                draw_statics=True,
                agent_shape=None):
    '''
    Args:
        screen (pygame.Surface)
        two_goal_oomdp (TwoGoalOOMDP)
        state (State)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''
    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            if agent.name == 'Q-learning':
                # Use agent value estimates.
                for s in agent.q_func.keys():
                    val_text_dict[s.get_agent_x()][s.get_agent_y()] = agent.get_value(s)
            # slightly abusing the distinction between agents and planning modules...
            else:
                for s in two_goal_oomdp.get_states():
                    val_text_dict[s.get_agent_x()][s.get_agent_y()] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(two_goal_oomdp, sample_rate=10)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.get_agent_x()][s.get_agent_y()] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda : defaultdict(str))
    if policy:
        for s in two_goal_oomdp.get_states():
            policy_dict[s.get_agent_x()][s.get_agent_y()] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0) # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / two_goal_oomdp.width
    cell_height = (scr_height - height_buffer * 2) / two_goal_oomdp.height
    objects = state.get_objects()
    agent_x, agent_y = objects["agent"][0]["x"], objects["agent"][0]["y"]
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255,255,255), agent_shape)

    # Statics
    if draw_statics:
        # Draw walls.
        for w in two_goal_oomdp.walls:
            w_x, w_y = w["x"], w["y"]
            top_left_point = width_buffer + cell_width * (w_x - 1) + 5, height_buffer + cell_height * (
                    two_goal_oomdp.height - w_y) + 5
            pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width - 10, cell_height - 10), 0)

    # Draw the two goals
    col_idxs = [-2, 0]
    for i, g in enumerate(two_goal_oomdp.goals):
        dest_x, dest_y = g["x"], g["y"]
        top_left_point = int(width_buffer + cell_width*(dest_x - 1) + 75), int(height_buffer + cell_height*(two_goal_oomdp.height - dest_y) + 65)
        dest_col = (int(max(color_ls[col_idxs[i]][0]-30, 0)), int(max(color_ls[col_idxs[i]][1]-30, 0)), int(max(color_ls[col_idxs[i]][2]-30, 0)))
        center = top_left_point + (cell_width / 2, cell_height / 2)
        radius = 45
        iterations = 150
        for i in range(iterations):
            ang = i * 3.14159 * 2 / iterations
            dx = int(math.cos(ang) * radius)
            dy = int(math.sin(ang) * radius)
            x = center[0] + dx
            y = center[1] + dy
            pygame.draw.circle(screen, dest_col, (x, y), 5)

    # Draw new agent.
    top_left_point = width_buffer + cell_width * (agent_x - 1), height_buffer + cell_height * (
                two_goal_oomdp.height - agent_y)
    agent_center = int(top_left_point[0] + cell_width / 2.0), int(top_left_point[1] + cell_height / 2.0)
    agent_shape = _draw_agent(agent_center, screen, base_size=min(cell_width, cell_height) / 2.5 - 4)

    if draw_statics:
        # For each row:
        for i in range(two_goal_oomdp.width):
            # For each column:
            for j in range(two_goal_oomdp.height):
                top_left_point = width_buffer + cell_width*i, height_buffer + cell_height*j
                r = pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width, cell_height), 3)

                # Show value of states.
                if show_value and not two_goal_helpers.is_wall(two_goal_oomdp, i + 1, two_goal_oomdp.height - j):
                    # Draw the value.
                    val = val_text_dict[i + 1][two_goal_oomdp.height - j]
                    color = mdpv.val_to_color(val)
                    pygame.draw.rect(screen, color, top_left_point + (cell_width, cell_height), 0)
                    value_text = reg_font.render(str(round(val, 2)), True, (46, 49, 49))
                    text_center_point = int(top_left_point[0] + cell_width / 2.0 - 10), int(
                        top_left_point[1] + cell_height / 3.0)
                    screen.blit(value_text, text_center_point)

                # Show optimal action to take in each grid cell.
                if policy and not two_goal_helpers.is_wall(two_goal_oomdp, i + 1, two_goal_oomdp.height - j):
                    a = policy_dict[i+1][two_goal_oomdp.height - j]
                    if a not in action_char_dict:
                        text_a = a
                    else:
                        text_a = action_char_dict[a]
                    text_center_point = int(top_left_point[0] + cell_width/2.0 - 10), int(top_left_point[1] + cell_height/3.0)
                    text_rendered_a = cc_font.render(text_a, True, (46, 49, 49))
                    screen.blit(text_rendered_a, text_center_point)

    pygame.display.flip()

    return agent_shape
Example #15
class NavigationMDP(GridWorldMDP):
    '''
        Class for Navigation MDP from:
            MacGlashan, James, and Michael L. Littman. "Between Imitation and
            Intention Learning." IJCAI. 2015.
    '''

    ACTIONS = ["up", "down", "left", "right"]

    @staticmethod
    def states_to_features(states, phi):
        """
        Returns phi(states)    
        """
        return np.asarray([phi(s) for s in states], dtype=np.float32)

    @staticmethod
    def states_to_coord(states, phi):
        """
        Returns the (x, y) coordinates of @states
        """
        return np.asarray([(s.x, s.y) for s in states], dtype=np.float32)

    def __init__(
            self,
            width=30,
            height=30,
            living_cell_types=["empty", "yellow", "red", "green", "purple"],
            living_cell_rewards=[0, 0, -10, -10, -10],
            living_cell_distribution="probability",
            living_cell_type_probs=[0.68, 0.17, 0.05, 0.05, 0.05],
            living_cell_locs=[
                np.inf, np.inf, [(1, 1), (5, 5)], [(2, 2)], [(4, 4)]
            ],
            goal_cell_locs=[],
            goal_cell_rewards=[],
            goal_cell_types=[],
            gamma=0.99,
            slip_prob=0.00,
            step_cost=0.0,
            is_goal_terminal=True,
            traj_init_cell_types=[0],
            planning_init_loc=(1, 1),
            planning_rand_init=True,
            name="Navigation MDP"):
        """
        Note: Locations are specified in (x,y) format, but (row, col) convention 
            is used while storing in memory. 
        Args:
            height (int): Height of navigation grid in no. of cells.
            width (int): Width of navigation grid in no. of cells.
            living_cell_types (list of cell types: [str, str, ...]): Non-goal cell types.
            living_cell_rewards (list of int): Reward for each @cell_type.
            living_cell_distribution (str):
                "probability" - will assign cells according to @living_cell_type_probs.
                "manual" - uses @living_cell_locs to assign cells to state space.
            living_cell_type_probs (list of floats): Probability corresponding to 
                each @living_cell_types. 
                Note: Goal isn't factored in, so actual probabilities can be off.
                Default values are chosen arbitrarily larger than percolation threshold 
                for square lattice, which is just an approximation to match cell 
                distribution with that of the paper.
            living_cell_locs (list of list of tuples
            [[(x1,y1), (x2,y2)], [(x3,y3), ...], np.inf, ...}):
                Specifies living cell locations. If elements are set to np.inf, 
                they will be sampled uniformly at random.
            goal_cell_locs (list of tuples: [(int, int)...]): Goal locations.
            goal_cell_rewards (list of int): Goal rewards.
            goal_cell_types (list of str/int): Type of goal corresponding to @goal_cell_locs.
            traj_init_cell_types (list of int): Trajectory init state sampling cell type 
            """
        assert height > 0 and isinstance(height, int) and width > 0 \
               and isinstance(width,
                              int), "height and widht must be integers and > 0"
        assert len(living_cell_types) == len(living_cell_rewards)
        assert living_cell_distribution == "manual" or len(
            living_cell_types) == len(living_cell_type_probs)
        assert living_cell_distribution == "probability" or len(
            living_cell_types) == len(living_cell_locs)
        assert len(goal_cell_types) == len(goal_cell_locs) == len(
            goal_cell_rewards)

        GridWorldMDP.__init__(self,
                              width=width,
                              height=height,
                              init_loc=planning_init_loc,
                              rand_init=planning_rand_init,
                              goal_locs=goal_cell_locs,
                              lava_locs=[()],
                              walls=[],
                              is_goal_terminal=is_goal_terminal,
                              gamma=gamma,
                              init_state=None,
                              slip_prob=slip_prob,
                              step_cost=step_cost,
                              name=name)

        # Living (navigation) cell types (str) and ids
        self.living_cell_types = living_cell_types
        self.living_cell_ids = list(range(len(living_cell_types)))
        # State space (2d grid where each element holds a cell id)
        self.state_space = self.__generate_state_space(
            height, width, living_cell_distribution, living_cell_type_probs,
            living_cell_locs)
        # Preserve a copy without goals
        self.state_space_wo_goals = self.state_space.copy()

        # Rewards
        self.living_cell_rewards = living_cell_rewards
        self.state_rewards = np.asarray(
            [[self.living_cell_rewards[item] for item in row]
             for row in self.state_space]).reshape(height, width)
        # Preserve a copy without goals
        self.state_rewards_wo_goals = self.state_rewards.copy()

        # Update cells and cell_rewards with goal and its rewards
        self.reset_goals(goal_cell_locs, goal_cell_rewards, goal_cell_types)

        # Find set of Empty/Navigable cells for sampling trajectory init state
        self.set_traj_init_cell_types(cell_types=traj_init_cell_types)

        # Run value iteration
        self.value_iter = ValueIteration(self, sample_rate=1)

        # Additional book-keeping
        self.feature_cell_dist = None
        self.feature_cell_dist_kind = 0

    def get_states(self):
        return self.value_iter.get_states()

    def get_trans_dict(self):
        self.value_iter._compute_matrix_from_trans_func()
        return self.value_iter.trans_dict

    def __generate_state_space(self, height, width, living_cell_distribution,
                               living_cell_type_probs, living_cell_locs):

        assert living_cell_distribution in ["probability", "manual"]
        # Assign cell type over state space
        if living_cell_distribution == "probability":

            cells = np.random.choice(len(living_cell_type_probs),
                                     p=living_cell_type_probs,
                                     replace=True,
                                     size=(height, width))
        else:

            inf_cells = [
                idx for idx, elem in enumerate(living_cell_locs)
                if elem == np.inf
            ]
            if len(inf_cells) == 0:
                cells = -1 * np.ones((height, width), dtype=int)
            else:
                cells = np.random.choice(inf_cells,
                                         p=[1. / len(inf_cells)] *
                                         len(inf_cells),
                                         replace=True,
                                         size=(height, width))

            for cell_type, cell_locs in enumerate(living_cell_locs):
                if cell_type not in inf_cells:
                    for cell_loc in cell_locs:
                        row, col = self._xy_to_rowcol(cell_loc[0], cell_loc[1])
                        cells[row, col] = cell_type

        # Additional check to ensure all states have corresponding cell type
        assert not np.any(cells == -1), \
            "Some grid cells have unassigned cell type! When you use manual " \
            "distribution, make sure each state of the MDP is covered by a " \
            "cell type. Check usage of np.inf in @living_cell_locs."

        return cells

    def _xy_to_rowcol(self, x, y):
        """
        Converts (x,y) to (row,col)
        """
        return self.height - y, x - 1

    def _rowcol_to_xy(self, row, col):
        """
        Converts (row,col) to (x,y)
        """
        return col + 1, self.height - row

    def _reward_func(self, state, action):
        '''
        Args:
            state (State)
            action (str)

        Returns
            (float)
        '''
        r, c = self._xy_to_rowcol(state.x, state.y)
        if self._is_goal_state_action(state, action):
            next_state = self._transition_func(state, action)
            return self.goal_cell_rewards[self.goal_xy_to_idx[
                (next_state.x, next_state.y)]] \
                   + self.state_rewards[r, c] - self.step_cost
        elif self.state_rewards[r, c] == 0:
            return 0 - self.step_cost
        else:
            return self.state_rewards[r, c] - self.step_cost

    def reset_goals(self, goal_cell_locs, goal_cell_rewards, goal_types):
        """
        Resets the goals. Updates cell type grid and cell reward grid as per
        new goal configuration.
        """
        self.goal_cell_locs = goal_cell_locs
        self.goal_cell_rewards = goal_cell_rewards
        self.goal_cell_types = goal_types
        self.goal_cell_ids = list(
            range(self.living_cell_ids[-1] + 1,
                  self.living_cell_ids[-1] + 1 + len(self.goal_cell_locs)))
        # Reset goal xy to idx dict
        self.goal_xy_to_idx = {}
        # Reset cell type and cell reward grid with no goals
        self.state_space = self.state_space_wo_goals.copy()
        self.state_rewards = self.state_rewards_wo_goals.copy()

        # Update goals and their rewards
        for idx, goal_loc in enumerate(self.goal_cell_locs):
            goal_r, goal_c = self._xy_to_rowcol(goal_loc[0], goal_loc[1])
            self.state_space[goal_r, goal_c] = self.goal_cell_ids[idx]
            self.state_rewards[goal_r, goal_c] = self.goal_cell_rewards[idx]
            self.goal_xy_to_idx[(goal_loc[0], goal_loc[1])] = idx
        self.cell_ids = self.living_cell_ids + self.goal_cell_ids
        self.cell_types = self.living_cell_types + self.goal_cell_types
        self.cell_type_rewards = self.living_cell_rewards + self.goal_cell_rewards

        if len(self.goal_cell_locs) != 0:
            self._policy_invalidated = True

    def set_traj_init_cell_types(self, cell_types=[0]):
        """
        Sets cell types for sampling first state of trajectory 
        """
        self.traj_init_cell_row_idxs, self.traj_init_cell_col_idxs = [], []
        for cell_type in cell_types:
            rs, cs = np.where(self.state_space == cell_type)
            self.traj_init_cell_row_idxs.extend(rs)
            self.traj_init_cell_col_idxs.extend(cs)
        self.num_traj_init_states = len(self.traj_init_cell_row_idxs)

    def sample_empty_state(self, idx=None):
        """
        Returns a random empty/white state of type GridWorldState()
        """

        if idx is None:
            rand_idx = np.random.randint(len(self.traj_init_cell_row_idxs))
        else:
            assert 0 <= idx < len(self.traj_init_cell_row_idxs)
            rand_idx = idx

        x, y = self._rowcol_to_xy(self.traj_init_cell_row_idxs[rand_idx],
                                  self.traj_init_cell_col_idxs[rand_idx])
        return GridWorldState(x, y)

    def sample_init_states(self, n, repetition=False):
        """
        Returns a list of random empty/white state of type GridWorldState()
        Note: if repetition is False, the max no. of states returned = # of empty cells in the grid
        """
        assert n > 0

        if repetition is False:
            return [
                self.sample_empty_state(rand_idx)
                for rand_idx in np.random.permutation(
                    len(self.traj_init_cell_row_idxs))[:n]
            ]
        else:
            return [self.sample_empty_state() for i in range(n)]

    def plan(self, state, policy=None, horizon=100):
        '''
        Args:
            state (State)
            policy (fn): S->A
            horizon (int)

        Returns:
            (list): List of actions
        '''
        action_seq = []
        state_seq = [state]
        steps = 0

        while (not state.is_terminal()) and steps < horizon:
            next_action = policy(state)
            action_seq.append(next_action)
            state = self.transition_func(state, next_action)
            state_seq.append(state)
            steps += 1

        return action_seq, state_seq

    def run_value_iteration(self):
        """
        Runs value iteration (if needed) and returns ValueIteration object.
        """
        # If value iteration was run previously, don't re-run it
        if self._policy_invalidated:
            _ = self.value_iter.run_vi()
            self._policy_invalidated = False
        return self.value_iter

    def sample_data(self,
                    n_trajectory,
                    init_states=None,
                    init_repetition=False,
                    policy=None,
                    horizon=100,
                    pad_extra_trajectories=True,
                    map_actions_to_index=True):
        """
        Args:
            n_trajectory: number of trajectories to sample
            init_states:
                None - to use random init state
                [GridWorldState(x,y),...] - to use specific init states
            init_repetition: When init_state is set to None, this will sample
                every possible init state and try to not repeat init state
                unless @n_trajectory > @self.num_traj_init_states
            policy (fn): S->A
            horizon (int): planning horizon
            pad_extra_trajectories: If True, this will always return
                @n_trajectory many trajectories, overrides @init_repetition
                if # unique states !=  @n_trajectory
            value_iter_sampling_rate (int): Used for value iteration if policy
                is set to None
            map_actions_to_index (bool): Set True to get action indices in
                trajectory
        Returns:
            (Traj_states, Traj_actions) where
                Traj_states: [[s1, s2, ..., sT], [s4, s1, ..., sT], ...],
                Traj_actions: [[a1, a2, ..., aT], [a4, a1, ..., aT], ...]
        """
        a_s = []
        d_mdp_states = []
        visited_at_init = defaultdict(lambda: False)
        action_to_idx = {a: i for i, a in enumerate(self.actions)}

        if init_states is None:
            init_states = self.sample_init_states(n_trajectory,
                                                  init_repetition)
            if len(init_states) < n_trajectory and pad_extra_trajectories:
                init_states += self.sample_init_states(n_trajectory -
                                                       len(init_states),
                                                       repetition=True)
        else:
            if len(init_states) < n_trajectory and pad_extra_trajectories:
                # More init states need to be sampled
                init_states += self.sample_init_states(
                    n_trajectory - len(init_states), init_repetition)
            else:
                # We have sufficient init states pre-specified, ignore the rest
                # as we only need n_trajectory many
                init_states = init_states[:n_trajectory]

        if policy is None:
            if len(self.goal_cell_locs) == 0:
                raise ValueError("Cannot determine policy, no goals assigned!")
            policy = self.run_value_iteration().policy

        for init_state in init_states:
            action_seq, state_seq = self.plan(init_state,
                                              policy=policy,
                                              horizon=horizon)
            d_mdp_states.append(state_seq)
            if map_actions_to_index:
                a_s.append([action_to_idx[a] for a in action_seq])
            else:
                a_s.append(action_seq)
        return d_mdp_states, a_s

    def __transfrom(self, mat, type):
        if type == "normalize_manhattan":
            return mat / (self.width + self.height)
        if type == "normalize_euclidean":
            return mat / np.sqrt(self.width**2 + self.height**2)
        else:
            return mat

    def compute_grid_distance_features(self,
                                       incl_cells,
                                       incl_goals,
                                       normalize=False):
        """
        Computes distances to specified cell types for entire grid. 
        Returns 3D array (row,col,distance)
        """
        # Convert 2 flags to decimal representation. This is useful to check if
        # requested features are different from those previously stored
        feature_kind = int(str(int(incl_cells)) + str(int(incl_goals)), 2)
        if self.feature_cell_dist is not None \
                and self.feature_cell_dist_kind == feature_kind:
            return self.__transfrom(
                self.feature_cell_dist,
                "normalize_manhattan" if normalize else "None")
        dist_cell_types = copy.deepcopy(
            self.living_cell_ids) if incl_cells else []
        dist_cell_types += self.goal_cell_ids if incl_goals else []

        loc_cells = [
            np.vstack(np.where(self.state_space == cell)).transpose()
            for cell in dist_cell_types
        ]
        self.feature_cell_dist = np.zeros(
            self.state_space.shape + (len(dist_cell_types), ), np.float32)
        for row in range(self.height):
            for col in range(self.width):
                # Note: if particular cell type is missing in the grid, this
                # will assign distance -1 to it
                # Ord=1: Manhattan, Ord=2: Euclidean and so on
                self.feature_cell_dist[row, col] = [
                    np.linalg.norm([row, col] - loc_cell, ord=1, axis=1).min()
                    if len(loc_cell) != 0 else -1 for loc_cell in loc_cells
                ]

        self.feature_cell_dist_kind = feature_kind
        return self.__transfrom(self.feature_cell_dist,
                                "normalize_manhattan" if normalize else "None")

    def feature_at_loc(self,
                       x,
                       y,
                       feature_type="indicator",
                       incl_cell_distances=False,
                       incl_goal_indicator=False,
                       incl_goal_distances=False,
                       normalize_distance=False,
                       dtype=np.float32):
        """
        Returns feature vector at a state corresponding to (x,y) location
        Args:
            x, y (int, int): cartesian coordinates in 2d state space
            feature_type (str): "indicator" to use one-hot encoding of cell type
                "cartesian" to use (x,y) and "rowcol" to use (row, col) as feature
            incl_cell_distances (bool): Include distance to each type of cell.
            incl_goal_distances (bool): Include distance to the goals.
            incl_goal_indicator: True - adds goal indicator feature
                If all goals have same color, it'll add single indicator variable 
                for all goals, otherwise it'll use different indicator variables 
                for each goal.
                (only applicable for "indicator" feature type).
            normalize_distance (bool): Whether to normalize cell type distances
                 to 0-1 range (only used when "incl_distance_features" is True).
            dtype (numpy datatype): cast feature vector to dtype
        """
        row, col = self._xy_to_rowcol(x, y)
        assert feature_type in ["indicator", "cartesian", "rowcol"]

        if feature_type == "indicator":
            if incl_goal_indicator:
                ind_feature = np.eye(len(self.cell_ids))[self.state_space[row,
                                                                          col]]
            else:
                if (x, y) in self.goal_cell_locs:
                    ind_feature = np.zeros(len(self.living_cell_ids))
                else:
                    ind_feature = np.eye(len(
                        self.living_cell_ids))[self.state_space[row, col]]
        elif feature_type == "cartesian":
            ind_feature = np.array([x, y])
        elif feature_type == "rowcol":
            ind_feature = np.array([row, col])

        if incl_cell_distances or incl_goal_distances:
            return np.hstack((ind_feature,
                              self.compute_grid_distance_features(
                                  incl_cell_distances, incl_goal_distances,
                                  normalize_distance)[row, col])).astype(dtype)
        else:
            return ind_feature.astype(dtype)

    def feature_at_state(self,
                         mdp_state,
                         feature_type="indicator",
                         incl_cell_distances=False,
                         incl_goal_indicator=False,
                         incl_goal_distances=False,
                         normalize_distance=False,
                         dtype=np.float32):
        """
        Returns feature vector at a state corresponding to an MDP state
        Args:
            mdp_state (GridWorldState): state object
            feature_type (str): "indicator" to use one-hot encoding of cell type
                "cartesian" to use (x,y) and "rowcol" to use (row, col) as feature
            incl_cell_distances (bool): Include distance to each type of cell.
            incl_goal_distances (bool): Include distance to the goals.
            incl_goal_indicator: True - adds goal indicator feature
                If all goals have same color, it'll add single indicator variable 
                for all goals, otherwise it'll use different indicator variables 
                for each goal.
                (only applicable for "indicator" feature type).
            normalize_distance (bool): Whether to normalize cell type distances
                to 0-1 range (only used when "incl_distance_features" is True).
            dtype (numpy datatype): cast feature vector to dtype
        """
        return self.feature_at_loc(mdp_state.x, mdp_state.y, feature_type,
                                   incl_cell_distances, incl_goal_indicator,
                                   incl_goal_distances, normalize_distance,
                                   dtype)

    def __display_text(self,
                       ax,
                       x_start,
                       x_end,
                       y_start,
                       y_end,
                       values,
                       fontsize=12):
        """
        Ref: https://stackoverflow.com/questions/33828780/matplotlib-display-array-values-with-imshow
        """
        x_size = x_end - x_start
        y_size = y_end - y_start
        x_positions = np.linspace(start=x_start,
                                  stop=x_end,
                                  num=x_size,
                                  endpoint=False)
        y_positions = np.linspace(start=y_start,
                                  stop=y_end,
                                  num=y_size,
                                  endpoint=False)
        for y_index, y in enumerate(y_positions):
            for x_index, x in enumerate(x_positions):
                label = values[y_index, x_index]
                ax.text(x,
                        y,
                        label,
                        color='black',
                        ha='center',
                        va='center',
                        fontsize=fontsize)

    def visualize_grid(self,
                       values=None,
                       cmap=None,
                       trajectories=None,
                       subplot_str=None,
                       new_fig=True,
                       show_colorbar=False,
                       show_rewards_colorbar=False,
                       int_cells_cmap=cm.viridis,
                       init_marker=".k",
                       traj_marker="-k",
                       text_values=None,
                       text_size=10,
                       traj_linewidth=0.7,
                       init_marker_sz=10,
                       goal_marker="*c",
                       goal_marker_sz=10,
                       end_marker="",
                       end_marker_sz=10,
                       axis_tick_font_sz=8,
                       title="Navigation MDP"):
        """
        Args:
            values (2d ndarray): Values to be visualized in the grid, 
                defaults to cell types.
            cmap (Matplotlib Colormap): Colormap corresponding to values,
                defaults to a ListedColormap built from the colors in
                @self.cell_types.
            trajectories: Trajectories to be shown on the grid.
            subplot_str (str): Subplot number string (e.g., "411", "412", etc.)
            new_fig (bool): Whether to create a new figure or draw on the
                current one.
            show_rewards_colorbar (bool): Shows colorbar with cell reward values.
            title (str): Title of the plot.
        """
        if new_fig:
            plt.figure(figsize=(max(self.height // 4, 6),
                                max(self.width // 4, 6)))
        # Subplot if needed
        if subplot_str is not None:
            plt.subplot(subplot_str)
        # Colormap
        if cmap is None:
            norm = colors.Normalize(vmin=0, vmax=len(self.cell_types) - 1)
            # Leave string colors as it is, convert int colors to normalized rgba
            cell_colors = [
                int_cells_cmap(norm(cell)) if isinstance(cell, int) else cell
                for cell in self.cell_types
            ]
            cmap = colors.ListedColormap(cell_colors)

        if values is None:
            values = self.state_space.copy()
        # Plot values
        im = plt.imshow(values, interpolation='None', cmap=cmap)
        plt.title(title)
        ax = plt.gca()
        ax.set_xticklabels('')
        ax.set_yticklabels('')
        ax.set_xticks(np.arange(self.width), minor=True)
        ax.set_yticks(np.arange(self.height), minor=True)
        ax.set_xticklabels(1 + np.arange(self.width),
                           minor=True,
                           fontsize=axis_tick_font_sz)
        ax.set_yticklabels(1 + np.arange(self.height)[::-1],
                           minor=True,
                           fontsize=axis_tick_font_sz)
        # Plot Trajectories
        if trajectories is not None and len(trajectories) > 0:
            for state_seq in trajectories:
                if len(state_seq) == 0:
                    continue
                path_xs = [s.x - 1 for s in state_seq]
                path_ys = [self.height - (s.y) for s in state_seq]
                plt.plot(path_xs,
                         path_ys,
                         traj_marker,
                         linewidth=traj_linewidth)
                plt.plot(path_xs[0],
                         path_ys[0],
                         init_marker,
                         markersize=init_marker_sz)  # Mark init state
                plt.plot(path_xs[-1],
                         path_ys[-1],
                         end_marker,
                         markersize=end_marker_sz)  # Mark end state
        # Mark goals
        if len(self.goal_cell_locs) != 0:
            for goal_x, goal_y in self.goal_cell_locs:
                plt.plot(goal_x - 1,
                         self.height - goal_y,
                         goal_marker,
                         markersize=goal_marker_sz)
        # Text values on cell
        if text_values is not None:
            self.__display_text(ax,
                                0,
                                self.width,
                                0,
                                self.height,
                                text_values,
                                fontsize=text_size)
        # Colorbar
        if show_colorbar:
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="3%", pad=0.05)
            if show_rewards_colorbar:
                cb = plt.colorbar(im,
                                  ticks=range(len(self.cell_type_rewards)),
                                  cax=cax)
                cb.set_ticklabels(self.cell_type_rewards)
            else:
                plt.colorbar(im, cax=cax)
        if subplot_str is None:
            plt.show()
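A brief usage sketch for this visualizer (hypothetical: `nav_mdp` is assumed to be an instance of the grid-world class this method belongs to, and `get_value_grid()` is assumed to return a 2D value array, mirroring the NavigationWorldMDP example further below):

V = nav_mdp.get_value_grid()  # assumed helper returning a (height x width) array
nav_mdp.visualize_grid(values=V, cmap=cm.viridis, text_values=np.round(V, 1),
                       show_colorbar=True, title="State values")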
Ejemplo n.º 16
0
def draw_state(screen,
               cleanup_mdp,
               state,
               policy=None,
               action_char_dict={},
               show_value=False,
               agent=None,
               draw_statics=False,
               agent_shape=None):
    '''
    Args:
        screen (pygame.Surface)
        grid_mdp (MDP)
        state (State)
        show_value (bool)
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''
    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(cleanup_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(cleanup_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.

    width = cleanup_mdp.width
    height = cleanup_mdp.height

    cell_width = (scr_width - width_buffer * 2) / width
    cell_height = (scr_height - height_buffer * 2) / height
    # goal_locs = grid_mdp.get_goal_locs()
    # lava_locs = grid_mdp.get_lavacc_locs()
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    # room_locs = [(x + 1, y + 1) for room in cleanup_mdp.rooms for (x, y) in room.points_in_room]
    door_locs = set([(door.x + 1, door.y + 1) for door in state.doors])

    # Draw the static entities.
    # print(draw_statics)
    # draw_statics = True
    # if draw_statics:
        # For each row:
    for i in range(width):
        # For each column:
        for j in range(height):

            top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
            r = pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width, cell_height), 3)

            # if policy and not grid_mdp.is_wall(i+1, height - j):
            if policy and (i + 1, height - j) in cleanup_mdp.legal_states:
                a = policy_dict[i + 1][height - j]
                if a not in action_char_dict:
                    text_a = a
                else:
                    text_a = action_char_dict[a]
                text_center_point = int(top_left_point[0] + cell_width / 2.0 - 10), int(
                    top_left_point[1] + cell_height / 3.0)
                text_rendered_a = cc_font.render(text_a, True, (46, 49, 49))
                screen.blit(text_rendered_a, text_center_point)

            # if show_value and not grid_mdp.is_wall(i+1, grid_mdp.height - j):
            if show_value and (i + 1, height - j) in cleanup_mdp.legal_states:
                # Draw the value.
                val = val_text_dict[i + 1][height - j]
                color = mdpv.val_to_color(val)
                pygame.draw.rect(screen, color, top_left_point + (cell_width, cell_height), 0)
                # text_center_point = int(top_left_point[0] + cell_width/2.0 - 10), int(top_left_point[1] + cell_height/7.0)
                # text = str(round(val,2))
                # text_rendered = reg_font.render(text, True, (46, 49, 49))
                # screen.blit(text_rendered, text_center_point)

            # if grid_mdp.is_wall(i+1, grid_mdp.height - j):
            if (i + 1, height - j) not in cleanup_mdp.legal_states:
                # Draw the walls.
                top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                pygame.draw.rect(screen, (94, 99, 99), top_left_point + (cell_width - 10, cell_height - 10), 0)

            if (i + 1, height - j) in door_locs:
                # Draw door
                # door_color = (66, 83, 244)
                door_color = (0, 0, 0)
                top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                pygame.draw.rect(screen, door_color, top_left_point + (cell_width - 10, cell_height - 10), 0)

            else:
                room = cleanup_mdp.check_in_room(state.rooms, i + 1 - 1, height - j - 1)  # Minus 1 for inconsistent x, y
                if room:
                    top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                    room_rgb = _get_rgb(room.color)
                    pygame.draw.rect(screen, room_rgb, top_left_point + (cell_width - 10, cell_height - 10), 0)

            block = cleanup_mdp.find_block(state.blocks, i + 1 - 1, height - j - 1)
            # print(state)
            # print(block)
            if block:
                circle_center = int(top_left_point[0] + cell_width / 2.0), int(top_left_point[1] + cell_height / 2.0)
                block_rgb = _get_rgb(block.color)
                pygame.draw.circle(screen, block_rgb, circle_center, int(min(cell_width, cell_height) / 4.0))

            # Current state.
            if not show_value and (i + 1, height - j) == (state.x + 1, state.y + 1) and agent_shape is None:
                tri_center = int(top_left_point[0] + cell_width / 2.0), int(top_left_point[1] + cell_height / 2.0)
                agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height) / 2.5 - 8)

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)
        top_left_point = width_buffer + cell_width * ((state.x + 1) - 1), height_buffer + cell_height * (
                height - (state.y + 1))
        tri_center = int(top_left_point[0] + cell_width / 2.0), int(top_left_point[1] + cell_height / 2.0)

        # Draw new.
        # if not show_value or policy is not None:
        agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height) / 2.5 - 16)

    pygame.display.flip()

    return agent_shape
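A minimal driver sketch for this drawing routine (hypothetical: the CleanUpMDP constructor, get_init_state(), and the use of transition_func are assumptions about the surrounding simple_rl code, not part of this listing):

import random
import pygame

pygame.init()
screen = pygame.display.set_mode((720, 720))
screen.fill((255, 255, 255))

cleanup_mdp = CleanUpMDP()                # assumed constructor
cur_state = cleanup_mdp.get_init_state()  # assumed accessor
agent_shape = None
for _ in range(10):                       # animate a few random steps
    agent_shape = draw_state(screen, cleanup_mdp, cur_state,
                             agent_shape=agent_shape)
    action = random.choice(cleanup_mdp.get_actions())
    cur_state = cleanup_mdp.transition_func(cur_state, action)
pygame.quit()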
Ejemplo n.º 17
0
class NavigationWorldMDP(MDP):
    """Class for Navigation MDP from:

            MacGlashan, James, and Michael L. Littman. "Between Imitation and
            Intention Learning." IJCAI. 2015.
    """
    # Static constants.
    ACTIONS = ["up", "down", "left", "right"]
    CELL_KIND_NAV = "nav"
    CELL_KIND_WALL = "wall"
    CELL_KIND_GOAL = "goal"

    def __init__(self,
                 width=30,
                 height=30,
                 nav_cell_types=["white", "yellow", "red", "green", "purple"],
                 nav_cell_rewards=[0, 0, -10, -10, -10],
                 nav_cell_p_or_locs=[0.68, 0.17, 0.05, 0.05, 0.05],
                 wall_cell_types=[],
                 wall_cell_rewards=[],
                 wall_cell_locs=[],
                 goal_cell_types=["blue"],
                 goal_cell_rewards=[1.],
                 goal_cell_locs=[[(21, 21)]],
                 init_loc=(1, 1),
                 rand_init=False,
                 gamma=0.99, slip_prob=0.00, step_cost=0.5,
                 is_goal_terminal=True,
                 name="Navigation MDP"):
        """Navigation World MDP constructor.

        Args:
            height (int): Navigation grid height (in no. of cells).
            width (int): Navigation grid width (in no. of cells).
            nav_cell_types (list of <str / int>): Navigation cell types.
            nav_cell_rewards (list of float): Rewards associated with
                @nav_cell_types.
            nav_cell_p_or_locs (list of <float /
            list of tuples (int x, int y)>):
                For each type in @nav_cell_types, either the probability with
                which that cell type is sampled, or a list of fixed/forced
                locations [(x, y), ...] for it.
                (Default values are chosen arbitrarily larger than the
                percolation threshold for a square lattice--just an
                approximation to match the cell distribution in the paper.)
            goal_cell_types (list of <str / int>): Goal cell types.
            goal_cell_locs (list of list of tuples (int, int)): Goal cell
                locations [(x,y),...] associated with @goal_cell_types.
            goal_cell_rewards (list of float): Rewards associated with
                @goal_cell_types.
            init_loc (int x, int y): Init cell to compute state space
                reachability and value iteration.
            gamma (float): MDP discount factor
            slip_prob (float): With this probability agent could fall either
                on left or right from the intended cell.
            step_cost (float): Living penalty.
            is_goal_terminal (bool): True to set goal state terminal.
        Note:
            Locations are specified in (x,y) format, but (row, col)
            convention is used while storing in memory.
        """
        assert len(nav_cell_types) == len(nav_cell_rewards) \
            == len(nav_cell_p_or_locs)
        assert len(wall_cell_types) == len(wall_cell_rewards) \
            == len(wall_cell_locs)
        assert len(goal_cell_types) == len(goal_cell_rewards) \
            == len(goal_cell_locs)

        self.width = width
        self.height = height
        self.init_loc = init_loc
        self.rand_init = rand_init
        self.gamma = gamma
        self.slip_prob = slip_prob
        self.step_cost = step_cost
        self.is_goal_terminal = is_goal_terminal
        self.name = name
        self.goal_cell_locs = goal_cell_locs

        # Setup init location.
        self.rand_init = rand_init
        if rand_init:
            init_loc = random.randint(1, width), random.randint(1, height)
            while self.is_wall(*init_loc) or self.is_goal(*init_loc):
                init_loc = random.randint(1, width), random.randint(1, height)
        self.init_loc = init_loc
        # Construct base class
        MDP.__init__(self, NavigationWorldMDP.ACTIONS, self._transition_func,
                     self._reward_func,
                     init_state=NavigationWorldState(*init_loc),
                     gamma=gamma)

        # Navigation MDP
        self.__reset_nav_mdp()
        self.__register_cell_types(nav_cell_types, wall_cell_types,
                                   goal_cell_types)
        self.__add_cells_by_locs(self.goal_cell_ids, goal_cell_locs,
                                 NavigationWorldMDP.CELL_KIND_GOAL)
        self.__add_cells_by_locs(self.wall_cell_ids, wall_cell_locs,
                                 NavigationWorldMDP.CELL_KIND_WALL)
        self.__add_nav_cells(self.nav_cell_ids, nav_cell_p_or_locs,
                             NavigationWorldMDP.CELL_KIND_NAV)
        self.__check_state_map()

        self.__register_cell_rewards(nav_cell_rewards,
                                     wall_cell_rewards, goal_cell_rewards)

        # Initialize value iteration object (computes reachable states)
        self.value_iter = ValueIteration(self, sample_rate=1)
        self._policy_invalidated = True

    # ---------------------
    # -- Navigation MDP --
    # ---------------------
    def _xy_to_rowcol(self, x, y):
        """Converts (x, y) to (row, col).

        """
        return self.height - y, x - 1

    def _rowcol_to_xy(self, row, col):
        """Converts (row, col) to (x, y).

        """
        return col + 1, self.height - row
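    # Worked example of the coordinate convention (for height = 30):
    #   (x=1, y=30) -> (row=0, col=0)    top-left cell of the grid
    #   (x=1, y=1)  -> (row=29, col=0)   bottom-left cell of the grid
    # i.e., row 0 of the internal state map corresponds to the top row.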

    def __reset_cell_type_params(self):
        self.__max_cell_id = -1
        self.combined_cell_types = []
        self.combined_cell_ids = []
        self.cell_type_to_id = {}
        self.cell_id_to_type = {}

    def __reset_nav_mdp(self):

        self.__reset_cell_type_params()
        self.num_traj_init_states = None
        self.feature_cell_dist = None
        self.feature_cell_dist_kind = 0
        self.nav_p_cell_ids = []
        self.nav_p_cell_probs = []
        self.__reset_state_map()

    def __reset_state_map(self):
        self.map_state_cell_id = -1 * np.ones((self.height, self.width),
                                              dtype=int)
        self.xy_to_cell_kind = defaultdict(lambda: "<Undefined>")

    def __check_state_map(self):

        if np.any(self.map_state_cell_id == -1):
            raise ValueError("Some states have an unassigned cell type! "
                             "Make sure each state of the MDP is covered by "
                             "a cell type. Check usage of probability values "
                             "in @nav_cell_p_or_locs.")

    def __assign_cell_ids(self, cell_types):

        n = len(cell_types)
        cell_ids = list(
            range(self.__max_cell_id + 1, self.__max_cell_id + 1 + n))
        for cell_type, cell_id in zip(cell_types, cell_ids):
            self.cell_type_to_id[cell_type] = cell_id
            self.cell_id_to_type[cell_id] = cell_type
        self.combined_cell_types += cell_types
        self.combined_cell_ids += cell_ids
        self.__max_cell_id += n
        return cell_ids

    def __register_cell_types(self, nav_cell_types, wall_cell_types,
                              goal_cell_types):

        self.__reset_cell_type_params()
        self.nav_cell_types = nav_cell_types
        self.wall_cell_types = wall_cell_types
        self.goal_cell_types = goal_cell_types

        self.nav_cell_ids = self.__assign_cell_ids(nav_cell_types)
        self.wall_cell_ids = self.__assign_cell_ids(wall_cell_types)
        self.goal_cell_ids = self.__assign_cell_ids(goal_cell_types)
        self.n_unique_cells = self.__max_cell_id + 1

    def __add_cells_by_locs(self, cell_ids, cell_locs_list,
                            kind="<Undefined>"):

        for idx, cell_id in enumerate(cell_ids):
            cell_locs = cell_locs_list[idx]
            assert isinstance(cell_locs, list)
            for x, y in cell_locs:
                r, c = self._xy_to_rowcol(x, y)
                self.map_state_cell_id[r, c] = cell_id
                self.xy_to_cell_kind[(x, y)] = kind

    def __add_nav_cells(self, cell_ids, cell_p_or_locs_list,
                        kind="<Undefined>"):

        self.nav_p_cell_ids = []
        self.nav_p_cell_probs = []

        for idx, cell_id in enumerate(cell_ids):
            if isinstance(cell_p_or_locs_list[idx], list):  # locations
                cell_locs = cell_p_or_locs_list[idx]
                for x, y in cell_locs:
                    r, c = self._xy_to_rowcol(x, y)
                    if self.map_state_cell_id[r, c] == -1:
                        self.map_state_cell_id[r, c] = cell_id
                        self.xy_to_cell_kind[(x, y)] = kind
            else:
                assert isinstance(cell_p_or_locs_list[idx],
                                  float)  # probability values
                prob = cell_p_or_locs_list[idx]
                self.nav_p_cell_ids.append(cell_id)
                self.nav_p_cell_probs.append(prob)

        assert round(sum(self.nav_p_cell_probs),
                     9) == 1, "Probability values must sum to 1."
        for r in range(self.height):
            for c in range(self.width):
                if self.map_state_cell_id[r, c] == -1:
                    self.map_state_cell_id[r, c] = np.random.choice(
                        self.nav_p_cell_ids,
                        p=self.nav_p_cell_probs)
                    x, y = self._rowcol_to_xy(r, c)
                    self.xy_to_cell_kind[(x, y)] = kind

    def __register_cell_rewards(self, nav_cell_rewards,
                                wall_cell_rewards, goal_cell_rewards):
        self.nav_cell_rewards = nav_cell_rewards
        self.wall_cell_rewards = wall_cell_rewards
        self.goal_cell_rewards = goal_cell_rewards
        self.cell_type_rewards = nav_cell_rewards + \
            wall_cell_rewards + \
            goal_cell_rewards

    def get_cell_id(self, x, y):
        """Get cell id of (x, y) location.

        Returns:
            A unique id assigned to each cell type.
        """
        return self.map_state_cell_id[tuple(self._xy_to_rowcol(x, y))]

    def is_wall(self, x, y):
        """Checks if (x,y) cell is wall or not.

        Returns:
            (bool): True iff (x, y) is a wall location.
        """
        return self.get_state_kind(x, y) == NavigationWorldMDP.CELL_KIND_WALL

    def is_goal(self, x, y):
        """Checks if (x,y) cell is goal or not.

        Returns:
            (bool): True iff (x, y) is a goal location.
        """
        return self.get_state_kind(x, y) == NavigationWorldMDP.CELL_KIND_GOAL

    def get_state_kind(self, x, y):
        """Returns kind of state at (x, y) location.

        Returns:
            (str): "wall", "nav", or "goal".
        """
        return self.xy_to_cell_kind[(x, y)]

    def _reset_goals(self, goal_cell_locs, goal_cell_rewards, goal_cell_types):
        """Resets the goals.

        Resamples previous goal locations with navigation cells.
        """
        # Re-sample old goal state cells
        for r in range(self.height):
            for c in range(self.width):
                x, y = self._rowcol_to_xy(r, c)
                if self.is_goal(x, y):
                    self.map_state_cell_id[r, c] = np.random.choice(
                        self.nav_p_cell_ids,
                        p=self.nav_p_cell_probs)
                    self.xy_to_cell_kind[
                        (x, y)] = NavigationWorldMDP.CELL_KIND_NAV

        self.__register_cell_types(self.nav_cell_types, self.wall_cell_types,
                                   goal_cell_types)
        self.__add_cells_by_locs(self.goal_cell_ids, goal_cell_locs,
                                 NavigationWorldMDP.CELL_KIND_GOAL)
        self.__check_state_map()
        self.__register_cell_rewards(self.nav_cell_rewards,
                                     self.wall_cell_rewards, goal_cell_rewards)
        self._policy_invalidated = True

    def _reset_rewards(self, nav_cell_rewards, wall_cell_rewards,
                       goal_cell_rewards):
        """Resets rewards corresponding to navigation, wall, and goal cells.

        """
        self.__register_cell_rewards(nav_cell_rewards, wall_cell_rewards,
                                     goal_cell_rewards)
        self._policy_invalidated = True

    # ---------
    # -- MDP --
    # ---------
    def _is_goal_state_action(self, state, action):
        """
        Args:
            state (State)
            action (str)

        Returns:
            (bool): True iff the state-action pair sends the agent to
            the goal state.

        """
        if self.is_goal(state.x, state.y) and self.is_goal_terminal:
            # Already at terminal.
            return False

        if action == "left" and self.is_goal(state.x - 1, state.y):
            return True
        elif action == "right" and self.is_goal(state.x + 1, state.y):
            return True
        elif action == "down" and self.is_goal(state.x, state.y - 1):
            return True
        elif action == "up" and self.is_goal(state.x, state.y + 1):
            return True
        else:
            return False

    def _transition_func(self, state, action):
        """
        Args:
            state (State)
            action (str)

        Returns
            (State)
        """
        if state.is_terminal():
            return state

        r = random.random()
        if self.slip_prob > r:
            # Flip dir.
            if action == "up":
                action = random.choice(["left", "right"])
            elif action == "down":
                action = random.choice(["left", "right"])
            elif action == "left":
                action = random.choice(["up", "down"])
            elif action == "right":
                action = random.choice(["up", "down"])

        if action == "up" and state.y < self.height and not self.is_wall(
                state.x, state.y + 1):
            next_state = NavigationWorldState(state.x, state.y + 1)
        elif action == "down" and state.y > 1 and not self.is_wall(
                state.x, state.y - 1):
            next_state = NavigationWorldState(state.x, state.y - 1)
        elif action == "right" and state.x < self.width and not self.is_wall(
                state.x + 1, state.y):
            next_state = NavigationWorldState(state.x + 1, state.y)
        elif action == "left" and state.x > 1 and not self.is_wall(state.x - 1,
                                                                   state.y):
            next_state = NavigationWorldState(state.x - 1, state.y)
        else:
            next_state = NavigationWorldState(state.x, state.y)

        if self.is_goal(next_state.x, next_state.y) and self.is_goal_terminal:
            next_state.set_terminal(True)

        return next_state

    def _reward_func(self, state, action):
        """
        Args:
            state (State)
            action (str)

        Returns
            (float)
        """
        next_state = self._transition_func(state, action)
        return self.cell_type_rewards[
            self.get_cell_id(next_state.x, next_state.y)] - self.step_cost

    # -------------------------
    # -- Trajectory Sampling --
    # -------------------------
    def plan(self, state, policy=None, horizon=100):
        '''
        Args:
            state (State)
            policy (fn): S->A
            horizon (int)

        Returns:
            (list): List of actions
        '''
        action_seq = []
        state_seq = [state]
        steps = 0

        while (not state.is_terminal()) and steps < horizon:
            next_action = policy(state)
            action_seq.append(next_action)
            state = self.transition_func(state, next_action)
            state_seq.append(state)
            steps += 1

        return action_seq, state_seq

    def set_traj_init_cell_types(self, cell_types):
        """
        Sets cell types for sampling first state of trajectory
        """
        self.traj_init_cell_row_idxs, self.traj_init_cell_col_idxs = [], []
        for cell_type in cell_types:
            rs, cs = np.where(self.map_state_cell_id ==
                              self.cell_type_to_id[cell_type])
            self.traj_init_cell_row_idxs.extend(rs)
            self.traj_init_cell_col_idxs.extend(cs)
        self.num_traj_init_states = len(self.traj_init_cell_row_idxs)

    def sample_traj_init_state(self, idx=None):
        """Samples trajectory init state (GridWorldState).

        Type of init cells to be sampled can be specified using
        set_traj_init_cell_types().
        """
        if idx is None:
            rand_idx = np.random.randint(len(self.traj_init_cell_row_idxs))
        else:
            assert 0 <= idx < len(self.traj_init_cell_row_idxs)
            rand_idx = idx

        x, y = self._rowcol_to_xy(self.traj_init_cell_row_idxs[rand_idx],
                                  self.traj_init_cell_col_idxs[rand_idx])
        return NavigationWorldState(x, y)

    def sample_init_states(self, n, init_unique=False, skip_states=None):
        """Returns a list of init states (GridWorldState).

        If init_unique is True, the max no. of states returned =
        # of empty cells in the grid.
        """
        assert n > 0

        if init_unique:
            c = 0
            init_states_list = []
            for idx in np.random.permutation(
                    len(self.traj_init_cell_row_idxs)):
                state = self.sample_traj_init_state(idx)
                if skip_states is None or state not in skip_states:
                    init_states_list.append(state)
                    c += 1
                    if c == n:
                        return init_states_list
            return init_states_list
        else:
            return [self.sample_traj_init_state() for i in range(n)]

    def sample_trajectories(self, n_traj, horizon, init_states=None,
                            init_cell_types=None, init_unique=False,
                            policy=None, rand_init_to_match_n_traj=True):
        """Samples trajectories.

        Args:
            n_traj: Number of trajectories to sample.
            horizon (int): Planning horizon (max trajectory length).
            init_states:
                None - to use random init state
                [GridWorldState(x,y),...] - to use specific init states
            init_unique (bool): When set to True, this samples distinct init
                states and avoids repeating an init state unless
                @n_traj > @self.num_traj_init_states.
            policy (fn): S->A
            rand_init_to_match_n_traj: If True, this will always return
                @n_traj many trajectories. If # of unique states are less
                than @n_traj, this will override the effect of @init_unique and
                sample repeated init states.
        Returns:
            (traj_states_list, traj_actions_list) where
                traj_states: [s1, s2, ..., sT]
                traj_actions: [a1, a2, ..., aT]
        """
        assert len(init_cell_types) >= 1

        self.set_traj_init_cell_types(init_cell_types)
        traj_states_list = []
        traj_action_list = []
        traj_init_states = []

        if init_states is None:
            # If no init_states are provided, sample n_traj many init states.
            traj_init_states = self.sample_init_states(n_traj,
                                                       init_unique=init_unique)
        else:
            traj_init_states = copy.deepcopy(init_states)

        if len(traj_init_states) >= n_traj:
            traj_init_states = traj_init_states[:n_traj]
        else:
            # If # init_states < n_traj, sample remaining ones
            if len(traj_init_states) < n_traj:
                traj_init_states += self.sample_init_states(
                    n_traj - len(traj_init_states), init_unique=init_unique,
                    skip_states=traj_init_states)

            # If rand_init_to_match_n_traj is set to True, sample more
            # init_states if needed (may not be unique)
            if rand_init_to_match_n_traj and len(traj_init_states) < n_traj:
                traj_init_states += self.sample_init_states(n_traj,
                                                            init_unique=False)

        if policy is None:
            if len(self.goal_cell_locs) == 0:
                print("Running value iteration with no goals assigned..")
            policy = self.run_value_iteration().policy

        for init_state in traj_init_states:
            action_seq, state_seq = self.plan(init_state, policy=policy,
                                              horizon=horizon)
            traj_states_list.append(state_seq)
            traj_action_list.append(action_seq)

        return traj_states_list, traj_action_list

    # ---------------------
    # -- Value Iteration --
    # ---------------------
    def run_value_iteration(self):
        """Runs value iteration (if needed).

        Returns:
            ValueIteration object.
        """
        # If value iteration was run previously, don't re-run it
        if self._policy_invalidated:
            self.value_iter = ValueIteration(self, sample_rate=1)
            _ = self.value_iter.run_vi()
            self._policy_invalidated = False
        return self.value_iter

    def get_value_grid(self):
        """Returns value over states space grid.

        """
        value_iter = self.run_value_iteration()
        V = np.zeros((self.height, self.width), dtype=np.float32)
        for row in range(self.height):
            for col in range(self.width):
                x, y = self._rowcol_to_xy(row, col)
                V[row, col] = value_iter.value_func[NavigationWorldState(x, y)]
        return V

    def get_all_states(self):
        """Returns all states.

        """
        return [NavigationWorldState(x, y) for x in range(1, self.width + 1)
                for y in range(1, self.height + 1)]

    def get_reachable_states(self):
        """Returns all reachable states from @self.init_loc.

        """
        return self.value_iter.get_states()

    def get_trans_dict(self):
        """Returns transition dynamics matrix.

        """
        self.value_iter._compute_matrix_from_trans_func()
        return self.value_iter.trans_dict

    # --------------
    # -- Features --
    # --------------
    def __transform(self, mat, type):
        if type == "normalize_manhattan":
            return mat / (self.width + self.height)
        if type == "normalize_euclidean":
            return mat / np.sqrt(self.width**2 + self.height**2)
        else:
            return mat

    def compute_grid_distance_features(self, incl_cells, incl_goals,
                                       normalize=False):
        """Computes distances to specified cell types for entire grid.

        Returns:
            3D array (row, col, distance)
        """
        # Convert 2 flags to decimal representation. This is useful to check if
        # requested features are different from those previously stored
        feature_kind = int(str(int(incl_cells)) + str(int(incl_goals)), 2)
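        # e.g., incl_cells=True, incl_goals=False -> "10" (binary) -> 2,
        # while incl_cells=True, incl_goals=True -> "11" -> 3.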
        if self.feature_cell_dist is not None \
                and self.feature_cell_dist_kind == feature_kind:
            return self.__transform(self.feature_cell_dist,
                                    "normalize_manhattan" if normalize
                                    else "None")

        dist_cell_ids = copy.deepcopy(
            self.nav_cell_ids) if incl_cells else []
        dist_cell_ids += self.goal_cell_ids if incl_goals else []
        loc_cells = [
            np.vstack(np.where(self.map_state_cell_id == cell_id)).transpose()
            for cell_id in dist_cell_ids]

        self.feature_cell_dist = np.zeros(
            self.map_state_cell_id.shape + (len(dist_cell_ids),), np.float32)

        for row in range(self.height):
            for col in range(self.width):
                # Note: if particular cell type is missing in the grid, this
                # will assign distance -1 to it
                # Ord=1: Manhattan, Ord=2: Euclidean and so on
                self.feature_cell_dist[row, col] = [
                    np.linalg.norm([row, col] - loc_cell, ord=1, axis=1).min()
                    if len(loc_cell) != 0 else -1 for loc_cell in loc_cells]

        self.feature_cell_dist_kind = feature_kind
        return self.__transform(self.feature_cell_dist,
                                "normalize_manhattan" if normalize else "None")

    def cell_id_ind_feature(self, cell_id, include_goal=True):
        """Indicator feature for cell_id.

        """
        if include_goal:
            return np.eye(len(self.combined_cell_ids))[cell_id]
        else:
            # use 0 vector for goals
            return np.vstack(
                (np.eye(len(self.nav_cell_ids)),
                 np.zeros((len(self.wall_cell_ids), len(self.nav_cell_ids))),
                 np.zeros((len(self.goal_cell_ids), len(self.nav_cell_ids))))
            )[cell_id]

    def feature_at_loc(self, x, y, feature_type="indicator",
                       incl_cell_distances=False, incl_goal_indicator=True,
                       incl_goal_distances=False, normalize_distance=False,
                       dtype=np.float32):
        """Returns feature vector at a state corresponding to (x, y) location.

        Args:
            x, y (int, int): cartesian coordinates in 2d state space
            feature_type (str): "indicator" to use one-hot encoding of
                cell type "cartesian" to use (x, y) and "rowcol" to use
                (row, col) as feature.
            incl_cell_distances (bool): Include distance to each type of cell.
            incl_goal_distances (bool): Include distance to the goals.
            incl_goal_indicator (bool): If True, adds a goal indicator feature.
                If all goals share the same color, a single indicator
                variable is added for all goals; otherwise a separate
                indicator variable is used for each goal.
                (Only applicable to the "indicator" feature type.)
            normalize_distance (bool): Whether to normalize cell type
                distances to the 0-1 range (only used when
                @incl_cell_distances or @incl_goal_distances is True).
            dtype (numpy datatype): cast feature vector to dtype
        """
        row, col = self._xy_to_rowcol(x, y)
        assert feature_type in ["indicator", "cartesian", "rowcol"]

        if feature_type == "indicator":
            ind_feature = self.cell_id_ind_feature(
                self.map_state_cell_id[row, col], incl_goal_indicator)
        elif feature_type == "cartesian":
            ind_feature = np.array([x, y])
        elif feature_type == "rowcol":
            ind_feature = np.array([row, col])

        if incl_cell_distances or incl_goal_distances:
            return np.hstack((ind_feature,
                              self.compute_grid_distance_features(
                                  incl_cell_distances, incl_goal_distances,
                                  normalize_distance)[row, col])).astype(dtype)
        else:
            return ind_feature.astype(dtype)

    def feature_at_state(self, mdp_state, feature_type="indicator",
                         incl_cell_distances=False, incl_goal_indicator=True,
                         incl_goal_distances=False, normalize_distance=False,
                         dtype=np.float32):
        """Returns feature vector at a state corresponding to MdP State.

        Args:
            mdp_state (int, int): GridWorldState object
            feature_type (str): "indicator" to use one-hot encoding of
                cell type "cartesian" to use (x, y) and "rowcol" to use
                (row, col) as feature.
            incl_cell_distances (bool): Include distance to each type of cell.
            incl_goal_distances (bool): Include distance to the goals.
            incl_goal_indicator: True - adds goal indicator feature
                If all goals have same color, it'll add single indicator
                variable for all goals, otherwise it'll use different
                indicator variables for each goal.
                (only applicable for "indicator" feature type).
            normalize_distance (bool): Whether to normalize cell type
                distances to 0-1 range (only used when
                "incl_distance_features" is True).
            dtype (numpy datatype): cast feature vector to dtype
        """
        return self.feature_at_loc(mdp_state.x, mdp_state.y, feature_type,
                                   incl_cell_distances, incl_goal_indicator,
                                   incl_goal_distances, normalize_distance,
                                   dtype)
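    # Usage sketch (hypothetical; `nav_mdp` is assumed to be an instance of
    # this class):
    #   phi = nav_mdp.feature_at_loc(3, 4, feature_type="indicator",
    #                                incl_goal_distances=True,
    #                                normalize_distance=True)
    #   # -> one-hot cell-type vector concatenated with the normalized
    #   #    Manhattan distance to the nearest cell of each goal type.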

    # -------------------
    # -- Visualization --
    # -------------------
    @staticmethod
    def __display_text(ax, x_start, x_end, y_start, y_end, values,
                       fontsize=12):
        x_size = x_end - x_start
        y_size = y_end - y_start
        x_positions = np.linspace(start=x_start, stop=x_end, num=x_size,
                                  endpoint=False)
        y_positions = np.linspace(start=y_start, stop=y_end, num=y_size,
                                  endpoint=False)
        for y_index, y in enumerate(y_positions):
            for x_index, x in enumerate(x_positions):
                label = values[y_index, x_index]
                ax.text(x, y, label, color='black', ha='center',
                        va='center', fontsize=fontsize)

    def visualize_grid(self, values=None, cmap=cm.viridis, trajectories=None,
                       subplot_str=None, new_fig=True, show_colorbar=False,
                       show_rewards_colorbar=False, state_space_cmap=True,
                       init_marker=".k", traj_marker="-k",
                       text_values=None, text_size=10,
                       traj_linewidth=0.7, init_marker_sz=10,
                       goal_marker="*c", goal_marker_sz=10,
                       end_marker="", end_marker_sz=10,
                       axis_tick_font_sz=8, title=None):
        """Visualize Navigation World.

        Args:
            values (2d ndarray): Values to be visualized in the grid,
                defaults to cell types.
            cmap (Matplotlib Colormap): Colormap corresponding to values,
                defaults to ListedColormap with colors specified in
                @self.nav_cell_types and @self.goal_cell_types
            trajectories: Trajectories to be shown on the grid.
            subplot_str (str): Subplot number string (e.g., "411", "412", etc.)
            new_fig (bool): Whether to create a new figure or draw on the
                current one.
            show_rewards_colorbar (bool): Shows colorbar with cell
                reward values.
            title (str): Title of the plot.
        """
        if new_fig:
            plt.figure(
                figsize=(max(self.height // 4, 6), max(self.width // 4, 6)))
        # Subplot if needed
        if subplot_str is not None:
            plt.subplot(subplot_str)

        # Use state space (cell types) if values is None
        if values is None:
            values = self.map_state_cell_id.copy()

        # Colormap
        if cmap is not None and state_space_cmap:
            norm = colors.Normalize(vmin=0,
                                    vmax=len(self.combined_cell_types)-1)
            # Leave string colors as it is, convert int colors to
            # normalized rgba
            cell_colors = [
                cmap(norm(cell)) if isinstance(cell, int) else cell
                for cell in self.combined_cell_types]
            cmap = colors.ListedColormap(cell_colors, N=self.n_unique_cells)

        # Plot values
        im = plt.imshow(values, interpolation='None', cmap=cmap)
        plt.title(title if title else self.name)
        ax = plt.gca()
        ax.set_xticklabels('')
        ax.set_yticklabels('')
        ax.set_xticks(np.arange(self.width), minor=True)
        ax.set_yticks(np.arange(self.height), minor=True)
        ax.set_xticklabels(1 + np.arange(self.width), minor=True,
                           fontsize=axis_tick_font_sz)
        ax.set_yticklabels(1 + np.arange(self.height)[::-1], minor=True,
                           fontsize=axis_tick_font_sz)
        # Plot Trajectories
        if trajectories is not None and len(trajectories) > 0:
            for state_seq in trajectories:
                if len(state_seq) == 0:
                    continue
                path_xs = [s.x - 1 for s in state_seq]
                path_ys = [self.height - (s.y) for s in state_seq]
                plt.plot(path_xs, path_ys, traj_marker,
                         linewidth=traj_linewidth)
                plt.plot(path_xs[0], path_ys[0], init_marker,
                         markersize=init_marker_sz)  # Mark init state
                plt.plot(path_xs[-1], path_ys[-1], end_marker,
                         markersize=end_marker_sz)  # Mark end state
        # Mark goals
        if len(self.goal_cell_locs) != 0:
            for goal_cells in self.goal_cell_locs:
                for goal_x, goal_y in goal_cells:
                    plt.plot(goal_x - 1, self.height - goal_y, goal_marker,
                             markersize=goal_marker_sz)
        # Text values on cell
        if text_values is not None:
            self.__display_text(ax, 0, self.width, 0, self.height, text_values,
                                fontsize=text_size)
        # Colorbar
        if show_colorbar:
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="3%", pad=0.05)
            if show_rewards_colorbar:
                cb = plt.colorbar(im, ticks=range(len(self.cell_type_rewards)),
                                  cax=cax)
                cb.set_ticklabels(self.cell_type_rewards)
            else:
                plt.colorbar(im, cax=cax)
        if subplot_str is None:
            plt.show()
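A short end-to-end sketch of how this class might be exercised (the parameter values below are illustrative assumptions, not taken from the listing):

nav_mdp = NavigationWorldMDP(width=15, height=15,
                             goal_cell_locs=[[(13, 13)]])
# Sample a few trajectories under the value-iteration policy, starting in white cells.
traj_states, traj_actions = nav_mdp.sample_trajectories(
    n_traj=5, horizon=50, init_cell_types=["white"], init_unique=True)
# Overlay the sampled trajectories on the cell-type grid with a reward colorbar.
nav_mdp.visualize_grid(trajectories=traj_states,
                       show_colorbar=True, show_rewards_colorbar=True)
# 2D (height x width) array of optimal state values.
V = nav_mdp.get_value_grid()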
Ejemplo n.º 18
0
def visualize_options_grid(grid_mdp,
                           action_abstr,
                           scr_width=720,
                           scr_height=720):
    '''
    Args:
        grid_mdp (GridWorldMDP)
        action_abstr (ActionAbstraction)
    '''
    pygame.init()
    title_font = pygame.font.SysFont("CMU Serif", 32)
    small_font = pygame.font.SysFont("CMU Serif", 22)

    if len(action_abstr.get_actions()) == 0:
        print("Options Error: 0 options found. Can't visualize.")
        sys.exit(0)

    if isinstance(grid_mdp, MDPDistribution):
        goal_locs = set([])
        for m in grid_mdp.get_all_mdps():
            for g in m.get_goal_locs():
                goal_locs.add(g)
        grid_mdp = grid_mdp.sample()
    else:
        goal_locs = grid_mdp.get_goal_locs()

    # Pygame init.
    screen = pygame.display.set_mode((scr_width, scr_height))
    pygame.init()
    screen.fill((255, 255, 255))
    pygame.display.update()
    mdp_visualizer._draw_title_text(grid_mdp, screen)
    option_text_point = scr_width / 2.0 - (14 * 7), 18 * scr_height / 20.0

    # Setup states to compute option init/term funcs.
    state_dict = defaultdict(lambda: defaultdict(None))
    vi = ValueIteration(grid_mdp)
    state_space = vi.get_states()
    for s in state_space:
        state_dict[s.x][s.y] = s

    # Draw initial option.
    option_index = 0
    opt_str = "Option " + str(option_index + 1) + " of " + str(
        len(action_abstr.get_actions()))  # + ":" + str(next_option)
    option_text = title_font.render(opt_str, True, (46, 49, 49))
    screen.blit(option_text, option_text_point)
    next_option = action_abstr.get_actions()[option_index]
    visualize_option(screen, grid_mdp, state_dict, option=next_option)

    # Initiation rect and text.
    option_text = small_font.render("Init: ", True, (46, 49, 49))
    screen.blit(option_text, (40, option_text_point[1]))
    pygame.draw.rect(screen, colors[0], (90, option_text_point[1]) + (24, 24))

    # Terminal rect and text.
    option_text = small_font.render("Term: ", True, (46, 49, 49))
    screen.blit(option_text, (scr_width - 150, option_text_point[1]))
    pygame.draw.rect(screen, colors[1],
                     (scr_width - 80, option_text_point[1]) + (24, 24))
    pygame.display.flip()

    # Keep updating options every space press.
    done = False
    while not done:
        # Check for key presses.
        for event in pygame.event.get():
            if event.type == QUIT or (event.type == KEYDOWN
                                      and event.key == K_ESCAPE):
                # Quit.
                pygame.quit()
                sys.exit()
            if event.type == KEYDOWN and event.key == K_RIGHT:
                # Toggle to the next option.
                option_index = (option_index + 1) % len(
                    action_abstr.get_actions())
            elif event.type == KEYDOWN and event.key == K_LEFT:
                # Go to the previous option.
                option_index = (option_index - 1) % len(
                    action_abstr.get_actions())
                if option_index < 0:
                    option_index = len(action_abstr.get_actions()) - 1

            next_option = action_abstr.get_actions()[option_index]
            visualize_option(screen,
                             grid_mdp,
                             state_dict,
                             option=next_option,
                             goal_locs=goal_locs)
            pygame.draw.rect(screen, (255, 255, 255),
                             (130, option_text_point[1]) +
                             (scr_width - 290, 50))
            opt_str = "Option " + str(option_index + 1) + " of " + str(
                len(action_abstr.get_actions()))  # + ":" + str(next_option)
            option_text = title_font.render(opt_str, True, (46, 49, 49))
            screen.blit(option_text, option_text_point)
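A hedged invocation sketch (GridWorldMDP, ActionAbstraction, its constructor arguments, and the pre-built `options` list are assumptions about the surrounding simple_rl code, not shown in this listing):

grid_mdp = GridWorldMDP(width=10, height=10, goal_locs=[(10, 10)])
# `options` is assumed to be a list of Option objects built elsewhere.
action_abstr = ActionAbstraction(options=options,
                                 prim_actions=grid_mdp.get_actions())
visualize_options_grid(grid_mdp, action_abstr,
                       scr_width=720, scr_height=720)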
Ejemplo n.º 19
0
def _draw_state(screen,
                grid_mdp,
                state,
                policy=None,
                action_char_dict={},
                show_value=False,
                agent=None,
                draw_statics=False,
                agent_shape=None,
                options=[]):
    '''
    Args:
        screen (pygame.Surface)
        grid_mdp (MDP)
        state (State)
        show_value (bool)
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''
    # Make value dict.
    print('options=', options)
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(grid_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(grid_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / grid_mdp.width
    cell_height = (scr_height - height_buffer * 2) / grid_mdp.height
    goal_locs = grid_mdp.get_goal_locs()
    lava_locs = grid_mdp.get_lava_locs()
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    # Draw the static entities.
    if draw_statics:
        # For each row:
        for i in range(grid_mdp.width):
            # For each column:
            for j in range(grid_mdp.height):

                top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
                r = pygame.draw.rect(
                    screen, (46, 49, 49),
                    top_left_point + (cell_width, cell_height), 3)

                if policy and not grid_mdp.is_wall(i + 1, grid_mdp.height - j):
                    a = policy_dict[i + 1][grid_mdp.height - j]
                    if a not in action_char_dict:
                        text_a = a
                    else:
                        text_a = action_char_dict[a]
                    text_center_point = int(top_left_point[0] +
                                            cell_width / 2.0 -
                                            10), int(top_left_point[1] +
                                                     cell_height / 3.0)
                    text_rendered_a = cc_font.render(text_a, True,
                                                     (46, 49, 49))
                    screen.blit(text_rendered_a, text_center_point)

                if show_value and not grid_mdp.is_wall(i + 1,
                                                       grid_mdp.height - j):
                    # Draw the value.
                    val = val_text_dict[i + 1][grid_mdp.height - j]
                    color = mdpv.val_to_color(val)
                    pygame.draw.rect(
                        screen, color,
                        top_left_point + (cell_width, cell_height), 0)

                if grid_mdp.is_wall(i + 1, grid_mdp.height - j):
                    # Draw the walls.
                    top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                    r = pygame.draw.rect(
                        screen, (94, 99, 99),
                        top_left_point + (cell_width - 10, cell_height - 10),
                        0)

                if (i + 1, grid_mdp.height - j) in goal_locs:
                    # Draw goal.
                    # TODO: Better visualization?
                    # circle_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0)
                    # circler_color = (154, 195, 157)
                    # pygame.draw.circle(screen, circler_color, circle_center, int(min(cell_width, cell_height) / 3.0))
                    pass

                if (i + 1, grid_mdp.height - j) in lava_locs:
                    # Draw lava.
                    circle_center = int(top_left_point[0] + cell_width /
                                        2.0), int(top_left_point[1] +
                                                  cell_height / 2.0)
                    circler_color = (224, 145, 157)
                    pygame.draw.circle(screen, circler_color, circle_center,
                                       int(min(cell_width, cell_height) / 4.0))

                # print('options')
                # print(i+1)
                # print(grid_mdp.height - j)
                # print(options)
                if (i + 1, j + 1) in options:
                    # Draw options.
                    # print('Needs to draw options at', i+1, '_', grid_mdp.height - j)
                    #circle_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0)
                    #circler_color = (200, 200, 0)
                    #
                    #pygame.draw.circle(screen, circler_color, circle_center, int(min(cell_width, cell_height) / 4.0))

                    # Add a number for the option
                    indices = [
                        k for k, x in enumerate(options) if x == (i + 1, j + 1)
                    ]
                    for index in indices:
                        ind = int(index / 2) + 1
                        circle_center = int(
                            top_left_point[0] + cell_width / 2.0) + int(
                                cell_width / 6.0 *
                                (ind + 1 - len(options) / 2)), int(
                                    top_left_point[1] + cell_height / 2.0)
                        circler_color = (200, 200, 0)

                        pygame.draw.circle(
                            screen, circler_color, circle_center,
                            int(min(cell_width, cell_height) / 4.0))

                    for index in indices:
                        ind = int(index / 2) + 1
                        print('INDEX=', ind)
                        font = pygame.font.SysFont(None, 24)
                        text = font.render(str(ind), True, (0, 0, 0),
                                           (200, 200, 0))
                        textrect = text.get_rect()
                        textrect.centerx = int(
                            top_left_point[0] + cell_width / 2.0) + int(
                                cell_width / 6.0 *
                                (ind + 1 - len(options) / 2))
                        textrect.centery = int(top_left_point[1] +
                                               cell_height / 2.0)
                        screen.blit(text, textrect)

                # Current state.
                # if not show_value and (i+1,grid_mdp.height - j) == (state.x, state.y) and agent_shape is None:
                #     tri_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0)
                #     agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height)/2.5 - 8)

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)
        top_left_point = width_buffer + cell_width * (
            state.x - 1), height_buffer + cell_height * (grid_mdp.height -
                                                         state.y)
        tri_center = int(top_left_point[0] +
                         cell_width / 2.0), int(top_left_point[1] +
                                                cell_height / 2.0)

        # Draw new.
        agent_shape = _draw_agent(
            tri_center,
            screen,
            base_size=min(cell_width, cell_height) / 2.5 - 8)

    pygame.display.flip()

    return agent_shape
Ejemplo n.º 20
0
def _draw_augmented_state(screen,
                          taxi_oomdp,
                          state,
                          policy=None,
                          action_char_dict={},
                          show_value=False,
                          agent=None,
                          draw_statics=True,
                          agent_shape=None):
    '''
    Args:
        screen (pygame.Surface)
        taxi_oomdp (TaxiOOMDP)
        state (State)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''
    # There are multiple potential states for each grid cell (e.g. the state also depends on whether the taxi
    # currently has the passenger or not), but the value and the policy for each cell are simply given by the most
    # recent state returned by get_states(). Began trying to display at least two values and optimal actions for each
    # cell (depending on the onboarding status of the passenger), but quickly realized that it gets very complicated
    # as the MDP gets more complicated (e.g. the state also depends on the location of the passenger).
    # Displaying multiple values and optimal actions per cell would require either handcrafting the pipeline or
    # investing a lot of time into making the pipeline customizable and robust. The incomplete attempt is left below
    # as commented-out code.

    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    # val_text_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    if show_value:
        if agent is not None:
            if agent.name == 'Q-learning':
                # Use agent value estimates.
                for s in agent.q_func.keys():
                    val_text_dict[s.get_agent_x()][
                        s.get_agent_y()] = agent.get_value(s)
                    # val_text_dict[s.get_agent_x()][s.get_agent_y()][
                    #     s.get_first_obj_of_class("passenger")["in_taxi"]] += agent.get_value(s)
            # slightly abusing the distinction between agents and planning modules...
            else:
                for s in taxi_oomdp.get_states():
                    val_text_dict[s.get_agent_x()][
                        s.get_agent_y()] = agent.get_value(s)
                    # val_text_dict[s.get_agent_x()][s.get_agent_y()][
                    #     s.get_first_obj_of_class("passenger")["in_taxi"]] += agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(taxi_oomdp, sample_rate=10)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.get_agent_x()][s.get_agent_y()] = vi.get_value(
                    s)
                # val_text_dict[s.get_agent_x()][s.get_agent_y()][
                #     s.get_first_obj_of_class("passenger")["in_taxi"]] += vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    # policy_dict = defaultdict(lambda: defaultdict(lambda : defaultdict(str)))
    if policy:
        for s in taxi_oomdp.get_states():
            policy_dict[s.get_agent_x()][s.get_agent_y()] = policy(s)
            # if policy_dict[s.get_agent_x()][s.get_agent_y()][s.get_first_obj_of_class("passenger")["in_taxi"]] != '':
            #     policy_dict[s.get_agent_x()][s.get_agent_y()][s.get_first_obj_of_class("passenger")["in_taxi"]] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / taxi_oomdp.width
    cell_height = (scr_height - height_buffer * 2) / taxi_oomdp.height
    objects = state.get_objects()
    agent_x, agent_y = objects["agent"][0]["x"], objects["agent"][0]["y"]
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)

    # Statics
    if draw_statics:
        # Draw walls.
        for w in taxi_oomdp.walls:
            w_x, w_y = w["x"], w["y"]
            top_left_point = width_buffer + cell_width * (
                w_x - 1) + 5, height_buffer + cell_height * (
                    taxi_oomdp.height - w_y) + 5
            pygame.draw.rect(
                screen, (46, 49, 49),
                top_left_point + (cell_width - 10, cell_height - 10), 0)

        # Draw tolls.
        for t in taxi_oomdp.tolls:
            t_x, t_y = t["x"], t["y"]
            top_left_point = width_buffer + cell_width * (
                t_x - 1) + 5, height_buffer + cell_height * (
                    taxi_oomdp.height - t_y) + 5
            # Clear the space and redraw with the correct transparency (instead of simply adding a new layer,
            # which would affect the transparency).
            pygame.draw.rect(
                screen, (255, 255, 255),
                top_left_point + (cell_width - 10, cell_height - 10), 0)
            pygame.gfxdraw.box(
                screen, top_left_point + (cell_width - 10, cell_height - 10),
                (224, 230, 67))

        # Draw traffic cells.
        for t in taxi_oomdp.traffic_cells:
            t_x, t_y = t["x"], t["y"]
            top_left_point = width_buffer + cell_width * (
                t_x - 1) + 5, height_buffer + cell_height * (
                    taxi_oomdp.height - t_y) + 5
            # Clear the space and redraw with the correct transparency (instead of simply adding a new layer,
            # which would affect the transparency).
            pygame.draw.rect(
                screen, (255, 255, 255),
                top_left_point + (cell_width - 10, cell_height - 10), 0)
            pygame.gfxdraw.box(
                screen, top_left_point + (cell_width - 10, cell_height - 10),
                (58, 28, 232))

        # Draw fuel stations.
        for f in taxi_oomdp.fuel_stations:
            f_x, f_y = f["x"], f["y"]
            top_left_point = width_buffer + cell_width * (
                f_x - 1) + 5, height_buffer + cell_height * (
                    taxi_oomdp.height - f_y) + 5
            pygame.draw.rect(
                screen, (144, 0, 255),
                top_left_point + (cell_width - 10, cell_height - 10), 0)

    # Draw the destination.
    for i, p in enumerate(objects["passenger"]):
        # Dest.
        dest_x, dest_y = p["dest_x"], p["dest_y"]
        top_left_point = int(width_buffer + cell_width * (dest_x - 1) +
                             27), int(height_buffer + cell_height *
                                      (taxi_oomdp.height - dest_y) + 14)
        dest_col = (int(max(color_ls[-i - 1][0] - 30,
                            0)), int(max(color_ls[-i - 1][1] - 30, 0)),
                    int(max(color_ls[-i - 1][2] - 30, 0)))
        pygame.draw.rect(screen, dest_col,
                         top_left_point + (cell_width / 6, cell_height / 6), 0)

    # Draw new agent.
    top_left_point = width_buffer + cell_width * (
        agent_x - 1), height_buffer + cell_height * (taxi_oomdp.height -
                                                     agent_y)
    agent_center = int(top_left_point[0] +
                       cell_width / 2.0), int(top_left_point[1] +
                                              cell_height / 2.0)
    agent_shape = _draw_agent(agent_center,
                              screen,
                              base_size=min(cell_width, cell_height) / 2.5 - 4)

    # Draw the passengers.
    for i, p in enumerate(objects["passenger"]):
        # Passenger
        pass_x, pass_y = p["x"], p["y"]
        taxi_size = int(min(cell_width, cell_height) / 9.0)
        if p["in_taxi"]:
            top_left_point = int(width_buffer + cell_width * (pass_x - 1) +
                                 taxi_size +
                                 58), int(height_buffer + cell_height *
                                          (taxi_oomdp.height - pass_y) +
                                          taxi_size + 16)
        else:
            top_left_point = int(width_buffer + cell_width * (pass_x - 1) +
                                 taxi_size +
                                 26), int(height_buffer + cell_height *
                                          (taxi_oomdp.height - pass_y) +
                                          taxi_size + 38)
        dest_col = (max(color_ls[-i - 1][0] - 30,
                        0), max(color_ls[-i - 1][1] - 30,
                                0), max(color_ls[-i - 1][2] - 30, 0))
        pygame.draw.circle(screen, dest_col, top_left_point, taxi_size)

    if draw_statics:
        # For each row:
        for i in range(taxi_oomdp.width):
            # For each column:
            for j in range(taxi_oomdp.height):
                top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
                r = pygame.draw.rect(
                    screen, (46, 49, 49),
                    top_left_point + (cell_width, cell_height), 3)

                # Show value of states.
                if show_value and not taxi_helpers.is_wall(
                        taxi_oomdp, i + 1, taxi_oomdp.height - j):
                    # Draw the value.
                    val = val_text_dict[i + 1][taxi_oomdp.height - j]
                    color = mdpv.val_to_color(val)
                    pygame.draw.rect(
                        screen, color,
                        top_left_point + (cell_width, cell_height), 0)
                    value_text = reg_font.render(str(round(val, 2)), True,
                                                 (46, 49, 49))
                    text_center_point = int(top_left_point[0] +
                                            cell_width / 2.0 -
                                            10), int(top_left_point[1] +
                                                     cell_height / 3.0)
                    screen.blit(value_text, text_center_point)

                    # Draw the value depending on the status of the passenger (incomplete)
                    # val_1 = val_text_dict[s.get_agent_x()][s.get_agent_y()][0]  # passenger not in taxi
                    # val_2 = val_text_dict[s.get_agent_x()][s.get_agent_y()][1]  # passenger in taxi
                    # color = mdpv.val_to_color((val_1 + val_2) / 2.)
                    # pygame.draw.rect(screen, color, top_left_point + (cell_width, cell_height), 0)
                    #
                    # value_text = reg_font.render(str(round(val_1, 2)), True, (46, 49, 49))
                    # text_center_point = int(top_left_point[0] + cell_width / 2.0 - 10), int(
                    #     top_left_point[1] + cell_height / 1.5)
                    # screen.blit(value_text, text_center_point)
                    #
                    # value_text = reg_font.render(str(round(val_2, 2)), True, (46, 49, 49))
                    # text_center_point = int(top_left_point[0] + cell_width / 2.0 - 10), int(
                    #     top_left_point[1] + cell_height / 4.5)
                    # screen.blit(value_text, text_center_point)

                # Show optimal action to take in each grid cell.
                if policy and not taxi_helpers.is_wall(taxi_oomdp, i + 1,
                                                       taxi_oomdp.height - j):
                    a = policy_dict[i + 1][taxi_oomdp.height - j]
                    if a not in action_char_dict:
                        text_a = a
                    else:
                        text_a = action_char_dict[a]
                    text_center_point = int(top_left_point[0] +
                                            cell_width / 2.0 -
                                            10), int(top_left_point[1] +
                                                     cell_height / 3.0)
                    text_rendered_a = cc_font.render(text_a, True,
                                                     (46, 49, 49))
                    screen.blit(text_rendered_a, text_center_point)

                    # Draw the policy depending on the status of the passenger (incomplete)
                    # a_1 = policy_dict[i + 1][taxi_oomdp.height - j][0]
                    # a_2 = policy_dict[i + 1][taxi_oomdp.height - j][1]
                    # if a_1 not in action_char_dict: text_a_1 = a_1
                    # else: text_a_1 = action_char_dict[a_1]
                    # if a_2 not in action_char_dict: text_a_2 = a_2
                    # else: text_a_2 = action_char_dict[a_2]
                    # text_center_point = int(top_left_point[0] + cell_width/2.0 - 10), int(top_left_point[1] + cell_height/1.5)
                    # text_rendered_a = cc_font.render(text_a_1, True, (46, 49, 49))
                    # screen.blit(text_rendered_a, text_center_point)
                    # text_center_point = int(top_left_point[0] + cell_width / 2.0 - 10), int(
                    #     top_left_point[1] + cell_height / 4.5)
                    # text_rendered_a = cc_font.render(text_a_2, True, (46, 49, 49))
                    # screen.blit(text_rendered_a, text_center_point)

    pygame.display.flip()

    return agent_shape
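
The action_char_dict argument above is a plain mapping from action-name strings to the single characters rendered in each cell. A minimal sketch of what a caller might pass, assuming the usual simple_rl Taxi action names ("up", "down", "left", "right", "pickup", "dropoff"); the names are an assumption, not taken from this file:

# Illustrative sketch only; the action names below are assumed.
taxi_action_char_dict = {
    "up": "^",
    "down": "v",
    "left": "<",
    "right": ">",
    "pickup": "P",
    "dropoff": "D",
}
# _draw_augmented_state(screen, taxi_oomdp, state, policy=vi.policy,
#                       action_char_dict=taxi_action_char_dict, show_value=True)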
Ejemplo n.º 21
0
def visualize_options_grid(grid_mdp, action_abstr, scr_width=720, scr_height=720):
    '''
    Args:
        grid_mdp (GridWorldMDP)
        action_abstr (ActionAbstraction)
    '''
    pygame.init()
    title_font = pygame.font.SysFont("CMU Serif", 32)
    small_font = pygame.font.SysFont("CMU Serif", 22)

    if len(action_abstr.get_actions()) == 0:
        print("Options Error: 0 options found. Can't visualize.")
        sys.exit(0)

    if isinstance(grid_mdp, MDPDistribution):
        goal_locs = set([])
        for m in grid_mdp.get_all_mdps():
            for g in m.get_goal_locs():
                goal_locs.add(g)
        grid_mdp = grid_mdp.sample()
    else:
        goal_locs = grid_mdp.get_goal_locs()

    # Pygame display setup (pygame.init() was already called above).
    screen = pygame.display.set_mode((scr_width, scr_height))
    screen.fill((255, 255, 255))
    pygame.display.update()
    mdp_visualizer._draw_title_text(grid_mdp, screen)
    option_text_point = scr_width / 2.0 - (14*7), 18*scr_height / 20.0

    # Setup states to compute option init/term funcs.
    state_dict = defaultdict(lambda : defaultdict(None))
    vi = ValueIteration(grid_mdp)
    state_space = vi.get_states()
    for s in state_space:
        state_dict[s.x][s.y] = s

    # Draw initial option.
    option_index = 0
    opt_str = "Option " + str(option_index + 1) + " of " + str(len(action_abstr.get_actions())) # + ":" + str(next_option)
    option_text = title_font.render(opt_str, True, (46, 49, 49))
    screen.blit(option_text, option_text_point)
    next_option = action_abstr.get_actions()[option_index]
    visualize_option(screen, grid_mdp, state_dict, option=next_option)

    # Initiation rect and text.
    option_text = small_font.render("Init: ", True, (46, 49, 49))
    screen.blit(option_text, (40, option_text_point[1]))
    pygame.draw.rect(screen, colors[0], (90, option_text_point[1]) + (24, 24))

    # Terminal rect and text.
    option_text = small_font.render("Term: ", True, (46, 49, 49))
    screen.blit(option_text, (scr_width - 150, option_text_point[1]))
    pygame.draw.rect(screen, colors[1], (scr_width - 80, option_text_point[1]) + (24, 24))
    pygame.display.flip()

    # Keep updating options every space press.
    done = False
    while not done:
        # Check for key presses.
        for event in pygame.event.get():
            if event.type == QUIT or (event.type == KEYDOWN and event.key == K_ESCAPE):
                # Quit.
                pygame.quit()
                sys.exit()
            if event.type == KEYDOWN and event.key == K_RIGHT:
                # Toggle to the next option.
                option_index = (option_index + 1) % len(action_abstr.get_actions())
            elif event.type == KEYDOWN and event.key == K_LEFT:
                # Go to the previous option (Python's modulo already wraps around to the last index).
                option_index = (option_index - 1) % len(action_abstr.get_actions())

            next_option = action_abstr.get_actions()[option_index]
            visualize_option(screen, grid_mdp, state_dict, option=next_option, goal_locs=goal_locs)
            pygame.draw.rect(screen, (255, 255, 255), (130, option_text_point[1]) + (scr_width-290 , 50))
            opt_str = "Option " + str(option_index + 1) + " of " + str(len(action_abstr.get_actions())) # + ":" + str(next_option)
            option_text = title_font.render(opt_str, True, (46, 49, 49))
            screen.blit(option_text, option_text_point)
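
A hedged usage sketch for the viewer above. The GridWorldMDP constructor is standard simple_rl; how the ActionAbstraction holding the options is built is left as an assumption (e.g. the policy-blocks helper later in this collection produces the options it would wrap):

from simple_rl.tasks import GridWorldMDP

# Illustrative only: the option-construction step is assumed to happen elsewhere.
grid = GridWorldMDP(width=9, height=9, init_loc=(1, 1), goal_locs=[(9, 9)])
# action_abstr = <ActionAbstraction wrapping the options to inspect>
# visualize_options_grid(grid, action_abstr, scr_width=720, scr_height=720)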
Ejemplo n.º 22
0
def old_draw_state(screen,
                   cleanup_mdp,
                   state,
                   policy=None,
                   action_char_dict={},
                   show_value=False,
                   agent=None,
                   draw_statics=False,
                   agent_shape=None):
    '''
    Args:
        screen (pygame.Surface)
        cleanup_mdp (MDP)
        state (State)
        show_value (bool)
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''

    print('Inside draw state\n\n\n\n')
    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(cleanup_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(cleanup_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.

    width = cleanup_mdp.width
    height = cleanup_mdp.height

    cell_width = (scr_width - width_buffer * 2) / width
    cell_height = (scr_height - height_buffer * 2) / height
    # goal_locs = grid_mdp.get_goal_locs()
    # lava_locs = grid_mdp.get_lava_locs()
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)

    # room_locs = [(x + 1, y + 1) for room in cleanup_mdp.rooms for (x, y) in room.points_in_room]
    door_locs = set([(door.x + 1, door.y + 1) for door in state.doors])

    # Draw the static entities.
    # print(draw_statics)
    # draw_statics = True
    # if draw_statics:
    # For each row:

    for i in range(width):
        # For each column:
        for j in range(height):

            top_left_point = width_buffer + cell_width * i, height_buffer + cell_height * j
            r = pygame.draw.rect(screen, (46, 49, 49),
                                 top_left_point + (cell_width, cell_height), 3)
            '''
            # if policy and not grid_mdp.is_wall(i+1, height - j):
            if policy and (i + 1, height - j) in cleanup_mdp.legal_states:
                a = policy_dict[i + 1][height - j]
                if a not in action_char_dict:
                    text_a = a
                else:
                    text_a = action_char_dict[a]
                text_center_point = int(top_left_point[0] + cell_width / 2.0 - 10), int(
                    top_left_point[1] + cell_height / 3.0)
                text_rendered_a = cc_font.render(text_a, True, (46, 49, 49))
                screen.blit(text_rendered_a, text_center_point)

            # if show_value and not grid_mdp.is_wall(i+1, grid_mdp.height - j):
            if show_value and (i + 1, height - j) in cleanup_mdp.legal_states:
                # Draw the value.
                val = val_text_dict[i + 1][height - j]
                color = mdpv.val_to_color(val)
                pygame.draw.rect(screen, color, top_left_point + (cell_width, cell_height), 0)
                # text_center_point = int(top_left_point[0] + cell_width/2.0 - 10), int(top_left_point[1] + cell_height/7.0)
                # text = str(round(val,2))
                # text_rendered = reg_font.render(text, True, (46, 49, 49))
                # screen.blit(text_rendered, text_center_point)
            '''

            # if grid_mdp.is_wall(i+1, grid_mdp.height - j):
            if (i + 1, height - j) not in cleanup_mdp.legal_states:
                # Draw the walls.
                top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                pygame.draw.rect(
                    screen, (94, 99, 99),
                    top_left_point + (cell_width - 10, cell_height - 10), 0)

            if (i + 1, height - j) in door_locs:
                # Draw door
                # door_color = (66, 83, 244)
                door_color = (0, 0, 0)
                top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                pygame.draw.rect(
                    screen, door_color,
                    top_left_point + (cell_width - 10, cell_height - 10), 0)

            else:
                room = cleanup_mdp.check_in_room(
                    state.rooms, i + 1 - 1,
                    height - j - 1)  # Minus 1 for inconsistent x, y
                if room:
                    top_left_point = width_buffer + cell_width * i + 5, height_buffer + cell_height * j + 5
                    room_rgb = _get_rgb(room.color)
                    pygame.draw.rect(
                        screen, room_rgb,
                        top_left_point + (cell_width - 10, cell_height - 10),
                        0)

            block = cleanup_mdp.find_block(state.blocks, i + 1 - 1,
                                           height - j - 1)
            # print(state)
            # print(block)
            '''
            # ROMA: to draw objects if needed
            if block:
                circle_center = int(top_left_point[0] + cell_width / 2.0), int(top_left_point[1] + cell_height / 2.0)
                block_rgb = _get_rgb(block.color)
                pygame.draw.circle(screen, block_rgb, circle_center, int(min(cell_width, cell_height) / 4.0))
            
            # Current state.
            # ROMA: to draw the agent if needed
            if not show_value and (i + 1, height - j) == (state.x + 1, state.y + 1) and agent_shape is None:
                tri_center = int(top_left_point[0] + cell_width / 2.0), int(top_left_point[1] + cell_height / 2.0)
                agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height) / 2.5 - 8)
            '''
    '''
    if agent_shape is not None:
        # Clear the old shape.
        pygame.draw.rect(screen, (255, 255, 255), agent_shape)
        top_left_point = width_buffer + cell_width * ((state.x + 1) - 1), height_buffer + cell_height * (
                height - (state.y + 1))
        tri_center = int(top_left_point[0] + cell_width / 2.0), int(top_left_point[1] + cell_height / 2.0)

        # Draw new.
        # if not show_value or policy is not None:
        agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height) / 2.5 - 16)
    '''
    pygame.display.flip()

    return agent_shape
Ejemplo n.º 23
0
def info_sa_planning_experiment(min_grid_size=5, max_grid_size=11, beta=10.0):
    '''
    Args:
        min_grid_size (int)
        max_grid_size (int)
        beta (float): Hyperparameter for InfoSA.

    Summary:
        Writes num iterations and time (seconds) for planning with and without abstractions.
    '''
    vanilla_file = "vi.csv"
    sa_file = "vi-$\\phi$.csv"
    file_prefix = os.path.join("results", "planning-four_room")
    
    clear_files(dir_name=file_prefix)

    for grid_dim in xrange(min_grid_size, max_grid_size + 1):
        # ======================
        # == Make Environment ==
        # ======================
        mdp = FourRoomMDP(width=grid_dim, height=grid_dim, init_loc=(1, 1), goal_locs=[(grid_dim, grid_dim)], gamma=0.9)
        
        # Get demo policy.
        vi = ValueIteration(mdp)
        vi.run_vi()
        demo_policy = get_lambda_policy(make_det_policy_eps_greedy(vi.policy, vi.get_states(), mdp.get_actions(), epsilon=0.2))

        # =======================
        # == Make Abstractions ==
        # =======================
        pmf_s_phi, phi_pmf, abstr_policy = run_info_sa(mdp, demo_policy, iters=500, beta=beta, convergence_threshold=0.00001)
        lambda_abstr_policy = get_lambda_policy(abstr_policy)
        prob_s_phi = ProbStateAbstraction(phi_pmf)
        crisp_s_phi = convert_prob_sa_to_sa(prob_s_phi)

        # ============
        # == Run VI ==
        # ============
        vanilla_vi = ValueIteration(mdp, delta=0.0001, sample_rate=25)
        sa_vi = AbstractValueIteration(ground_mdp=mdp, state_abstr=crisp_s_phi, delta=0.0001, vi_sample_rate=25, amdp_sample_rate=25)

        # ==========
        # == Plan ==
        # ==========
        print "Running VIs."
        start_time = time.clock()
        vanilla_iters, vanilla_val = vanilla_vi.run_vi()
        vanilla_time = round(time.clock() - start_time, 2)

        mdp.reset()
        start_time = time.clock()
        sa_iters, sa_abs_val = sa_vi.run_vi()
        sa_time = round(time.clock() - start_time, 2)
        sa_val = evaluate_agent(FixedPolicyAgent(sa_vi.policy), mdp, instances=25)

        print "\n" + "*"*20
        print "Vanilla", "\n\t Iters:", vanilla_iters, "\n\t Value:", round(vanilla_val, 4), "\n\t Time:", vanilla_time
        print 
        print "Phi:", "\n\t Iters:", sa_iters, "\n\t Value:", round(sa_val, 4), "\n\t Time:", sa_time
        print "*"*20 + "\n\n"

        write_datum(os.path.join(file_prefix, "iters", vanilla_file), vanilla_iters)
        write_datum(os.path.join(file_prefix, "iters", sa_file), sa_iters)

        write_datum(os.path.join(file_prefix, "times", vanilla_file), vanilla_time)
        write_datum(os.path.join(file_prefix, "times", sa_file), sa_time)
Ejemplo n.º 24
0
def _draw_state(screen,
                pudd_mdp,
                state,
                policy=None,
                action_char_dict={},
                show_value=False,
                agent=None,
                draw_statics=False,
                agent_shape=None):
    '''
    Args:
        screen (pygame.Surface)
        pudd_mdp (MDP)
        state (State)
        show_value (bool)
        agent (Agent): Used to show value, by default uses VI.
        draw_statics (bool)
        agent_shape (pygame.rect)

    Returns:
        (pygame.Shape)
    '''
    # Make value dict.
    val_text_dict = defaultdict(lambda: defaultdict(float))
    if show_value:
        if agent is not None:
            # Use agent value estimates.
            for s in agent.q_func.keys():
                val_text_dict[s.x][s.y] = agent.get_value(s)
        else:
            # Use Value Iteration to compute value.
            vi = ValueIteration(pudd_mdp)
            vi.run_vi()
            for s in vi.get_states():
                val_text_dict[s.x][s.y] = vi.get_value(s)

    # Make policy dict.
    policy_dict = defaultdict(lambda: defaultdict(str))
    if policy:
        vi = ValueIteration(pudd_mdp)
        vi.run_vi()
        for s in vi.get_states():
            policy_dict[s.x][s.y] = policy(s)

    # Prep some dimensions to make drawing easier.
    scr_width, scr_height = screen.get_width(), screen.get_height()
    width_buffer = scr_width / 10.0
    height_buffer = 30 + (scr_height / 10.0)  # Add 30 for title.
    cell_width = (scr_width - width_buffer * 2) / 10  #pudd_mdp.width
    cell_height = (scr_height - height_buffer * 2) / 10  # pudd_mdp.height
    goal_locs = pudd_mdp.get_goal_locs()
    # puddle_rects = pudd_mdp.get_puddle_rects()
    font_size = int(min(cell_width, cell_height) / 4.0)
    reg_font = pygame.font.SysFont("CMU Serif", font_size)
    cc_font = pygame.font.SysFont("Courier", font_size * 2 + 2)
    delta = 0.2
    #print ("goal locs", goal_locs)
    # Draw the static entities.
    if draw_statics:
        # For each row:
        for i in np.linspace(0, 1, 11):
            # For each column:
            for j in np.linspace(0, 1, 11):
                #print ("i,j",i,j)

                top_left_point = width_buffer + cell_width * i * 10, height_buffer + cell_height * j * 10
                r = pygame.draw.rect(
                    screen, (46, 49, 49),
                    top_left_point + (cell_width, cell_height), 3)

                if pudd_mdp.is_puddle_location(i, j):
                    # Draw the walls.
                    #print ("True for ", i, j)
                    top_left_point = width_buffer + cell_width * (
                        i * 10) + 5, height_buffer + cell_height * j * 10 + 5
                    r = pygame.draw.rect(
                        screen, (0, 127, 255),
                        top_left_point + (cell_width - 10, cell_height - 10),
                        0)

                if [min(i + delta, 1.0), min(j + delta, 1.0)] in goal_locs:
                    # Draw goal.
                    circle_center = int(top_left_point[0] + cell_width /
                                        2.0), int(top_left_point[1] +
                                                  cell_height / 2.0)
                    circle_color = (154, 195, 157)
                    pygame.draw.circle(screen, circle_color, circle_center,
                                       int(min(cell_width, cell_height) / 3.0))

                # Current state.
                if not show_value and (round(i, 1), round(j, 1)) == (round(
                        state.x, 1), round(state.y,
                                           1)) and agent_shape is None:
                    tri_center = int(top_left_point[0] +
                                     cell_width / 2.0), int(top_left_point[1] +
                                                            cell_height / 2.0)
                    agent_shape = _draw_agent(
                        tri_center,
                        screen,
                        base_size=min(cell_width, cell_height) / 2.5 - 8)


#    if agent_shape is not None:
#        # Clear the old shape.
#        pygame.draw.rect(screen, (255,255,255), agent_shape)
#        top_left_point = width_buffer + cell_width*(state.x - 1), height_buffer + cell_height*(pudd_mdp.height - state.y)
#        tri_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0)

# Draw new.
#        agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height)/2.5 - 8)

    pygame.display.flip()

    return agent_shape
Ejemplo n.º 25
0
def make_abstr_mdp(mdp, state_abstr, action_abstr, sample_rate=25):
    '''
    Args:
        mdp (MDP)
        state_abstr (StateAbstraction)
        action_abstr (ActionAbstraction)
        sample_rate (int): Sample rate for computing the abstract R and T.

    Returns:
        (MDP)
    '''

    # Grab ground state space.
    vi = ValueIteration(mdp)
    state_space = vi.get_states()

    # Make abstract reward and transition functions.
    def abstr_reward_lambda(abstr_state, abstr_action):
        # Get relevant MDP components from the lower MDP.
        lower_states = state_abstr.get_lower_states_in_abs_state(abstr_state)
        lower_reward_func = mdp.get_reward_func()
        lower_trans_func = mdp.get_transition_func()

        # Compute reward.
        total_reward = 0
        for ground_s in lower_states:
            for sample in xrange(sample_rate):
                s_prime, reward = abstr_action.rollout(ground_s,
                                                       lower_reward_func,
                                                       lower_trans_func)
                total_reward += float(reward) / (
                    len(lower_states) * sample_rate)  # Add weighted reward.

        # print "~"*20
        # print "R_A:", abstr_state, abstr_action, total_reward
        # print "~"*20

        return total_reward

    def abstr_transition_lambda(abstr_state, abstr_action):
        # print "Abstr Transition Func:"
        # print "\t abstr_state:", abstr_state
        # Get relevant MDP components from the lower MDP.
        lower_states = state_abstr.get_lower_states_in_abs_state(abstr_state)
        lower_reward_func = mdp.get_reward_func()
        lower_trans_func = mdp.get_transition_func()

        # Compute next state distribution.
        s_prime_prob_dict = defaultdict(int)
        for ground_s in lower_states:
            for sample in xrange(sample_rate):
                s_prime, reward = abstr_action.rollout(ground_s,
                                                       lower_reward_func,
                                                       lower_trans_func)
                s_prime_prob_dict[s_prime] += (
                    1.0 / (len(lower_states) * sample_rate)
                )  # Weighted average.

        # Form distribution and sample s_prime.
        end_ground_state = s_prime_prob_dict.keys()[list(
            np.random.multinomial(
                1, s_prime_prob_dict.values()).tolist()).index(1)]
        end_abstr_state = state_abstr.phi(end_ground_state,
                                          level=abstr_state.get_level())

        return end_abstr_state

    # Make the components of the MDP.
    abstr_init_state = state_abstr.phi(mdp.get_init_state())
    abstr_action_space = action_abstr.get_actions()
    abstr_state_space = state_abstr.get_abs_states()
    abstr_reward_func = RewardFunc(abstr_reward_lambda, abstr_state_space,
                                   abstr_action_space)
    abstr_transition_func = TransitionFunc(abstr_transition_lambda,
                                           abstr_state_space,
                                           abstr_action_space,
                                           sample_rate=sample_rate)

    # Make the MDP.
    abstr_mdp = MDP(actions=abstr_action_space,
                    init_state=abstr_init_state,
                    reward_func=abstr_reward_func.reward_func,
                    transition_func=abstr_transition_func.transition_func,
                    gamma=0.5)

    return abstr_mdp
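
A hedged sketch of planning in the resulting abstract MDP. The state_abstr and action_abstr names are assumptions: a StateAbstraction and ActionAbstraction already built for this MDP (their construction is not shown here):

from simple_rl.tasks import GridWorldMDP

# Illustrative only: state_abstr / action_abstr are assumed to exist for this MDP.
mdp = GridWorldMDP(width=9, height=9, init_loc=(1, 1), goal_locs=[(9, 9)])
abstr_mdp = make_abstr_mdp(mdp, state_abstr, action_abstr, sample_rate=25)

# The result is an ordinary MDP, so the usual planners apply.
abstr_vi = ValueIteration(abstr_mdp, delta=0.0001)
iters, val = abstr_vi.run_vi()
print "Abstract VI:", iters, "iterations, value", round(val, 4)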
Ejemplo n.º 26
0
def make_policy_blocks_options(mdp_distr, num_options, task_samples):
    '''
    Args:
        mdp_distr (MDPDistribution)
        num_options (int)
        task_samples (int)

    Returns:
        (list): Contains policy blocks options.
    '''
    option_set = []
    # Fill solution set for task_samples draws from MDP distribution
    L = []
    for new_task in xrange(task_samples):
        print "  Sample " + str(new_task + 1) + " of " + str(task_samples) + "."

        # Sample the MDP.
        mdp = mdp_distr.sample()

        # Run VI to get a policy for the MDP as well as the list of states
        print "\tRunning VI...",
        sys.stdout.flush()
        # Run VI
        vi = ValueIteration(mdp, delta=0.0001, max_iterations=5000, sample_rate=5)
        iters, val = vi.run_vi()
        print " done."

        policy = make_dict_from_lambda(vi.policy, vi.get_states())
        L.append(policy)

    power_L = get_power_set(L)
    num_iters = 1
    print 'Beginning policy blocks for {2} options with {0} solution policies and power set of size {1}'\
        .format(len(L), len(power_L), num_options)

    while len(power_L) > 0 and len(option_set) < num_options:
        print 'Running iteration {0} of policy blocks...'.format(num_iters)
        # Initialize empty set of candidate option policies
        C = []
        # Iterate over the power set of solution policies
        for policy_set in power_L:
            # Compute candidate policy as merge over policy set
            candidate = policy_blocks_merge(policy_set)
            if candidate not in C:
                # Compute score of each candidate policy
                C.append((candidate, policy_blocks_score_policy(candidate, L)))
        # Identify the candidate policy with highest score and add to option set
        C = sorted(C, key=lambda x: x[1])
        pi_star = C[-1][0]
        if pi_star not in option_set:
            option_set.append(pi_star)

        # Subtract chosen candidate from L by iterating through power set
        power_L = map(lambda policy_set: [policy_blocks_subtract_pair(p, pi_star) for p in policy_set], power_L)

        # Remove empty elements of power set
        power_L = filter(lambda policy_set: sum(map(lambda x: len(x), policy_set)) > 0, power_L)

        num_iters += 1

    # Generate true option set
    ret = []
    for o in option_set:
        init_predicate = CovPredicate(y=True, policy=o)
        term_predicate = CovPredicate(y=False, policy=o)
        print map(str, o.keys())
        print o.values()
        print '**'
        opt = Option(init_predicate=init_predicate, term_predicate=term_predicate, policy=o)
        ret.append(opt)

    print 'Policy blocks returning with {0} options'.format(len(ret))

    return ret
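
A minimal sketch of calling the policy-blocks helper over a small task distribution; the MDPDistribution import path and the particular grid tasks are assumptions, not taken from this file:

from simple_rl.mdp import MDPDistribution
from simple_rl.tasks import GridWorldMDP

# Illustrative only: two grid tasks that share dynamics but differ in goal location.
mdp_distr = MDPDistribution({
    GridWorldMDP(width=9, height=9, init_loc=(1, 1), goal_locs=[(9, 9)]): 0.5,
    GridWorldMDP(width=9, height=9, init_loc=(1, 1), goal_locs=[(1, 9)]): 0.5,
})
pb_options = make_policy_blocks_options(mdp_distr, num_options=3, task_samples=5)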