Example #1
def visualize_search_path(search_policy, eval_env, difficulty=0.5):
    """Plot the start, the goal, and the waypoints the search policy plans between them."""
    set_env_difficulty(eval_env, difficulty)

    if search_policy.open_loop:
        state = eval_env.reset()
        start = state['observation']
        goal = state['goal']

        search_policy.select_action(state)  # selecting an action triggers planning
        waypoints = search_policy.get_waypoints()
    else:
        # Closed-loop: roll out a full episode and keep the waypoints it produced.
        goal, observations, waypoints, _ = Collector.get_trajectory(search_policy, eval_env)
        start = observations[0]

    plt.figure(figsize=(6, 6))
    plot_walls(eval_env.walls)

    waypoint_vec = np.array(waypoints)

    print(f'waypoints: {waypoint_vec}')
    print(f'waypoints shape: {waypoint_vec.shape}')
    print(f'start: {start}')
    print(f'goal: {goal}')

    plt.scatter([start[0]], [start[1]], marker='+',
                color='red', s=200, label='start')
    plt.scatter([goal[0]], [goal[1]], marker='*',
                color='green', s=200, label='goal')
    plt.plot(waypoint_vec[:, 0], waypoint_vec[:, 1], 'y-s', alpha=0.3, label='waypoint')
    plt.legend(loc='lower left', bbox_to_anchor=(-0.1, -0.15), ncol=4, fontsize=16)
    plt.show()
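The plotting recipe above can be exercised without a trained policy. Below is a minimal sketch with synthetic data, where a straight line of points stands in for search_policy.get_waypoints(); all values are illustrative, not from the codebase.

import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for the start/goal observations and planned waypoints.
start = np.array([0.1, 0.1])
goal = np.array([0.9, 0.9])
waypoint_vec = np.linspace(start, goal, num=6)  # shape (6, 2)

plt.figure(figsize=(6, 6))
plt.scatter([start[0]], [start[1]], marker='+', color='red', s=200, label='start')
plt.scatter([goal[0]], [goal[1]], marker='*', color='green', s=200, label='goal')
plt.plot(waypoint_vec[:, 0], waypoint_vec[:, 1], 'y-s', alpha=0.3, label='waypoint')
plt.legend(loc='lower left', bbox_to_anchor=(-0.1, -0.15), ncol=4, fontsize=16)
plt.show()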
Example #2
def take_cleanup_steps(search_policy, eval_env, num_cleanup_steps):
    """Run graph-cleanup steps at high difficulty and return the wall-clock time taken."""
    set_env_difficulty(eval_env, 0.95)

    search_policy.set_cleanup(True)
    cleanup_start = time.perf_counter()
    # Collector.eval_agent(search_policy, eval_env, num_cleanup_steps, by_episode=False) # random goals in env
    Collector.step_cleanup(
        search_policy, eval_env,
        num_cleanup_steps)  # samples goals from nodes in state graph
    cleanup_end = time.perf_counter()
    search_policy.set_cleanup(False)
    cleanup_time = cleanup_end - cleanup_start
    return cleanup_time
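The time.perf_counter() bracket around Collector.step_cleanup generalizes to any expensive call. A small sketch of that timing pattern; timed is a hypothetical helper, not part of the codebase.

import time

def timed(fn, *args, **kwargs):
    # Wrap any call in the same time.perf_counter() bracket used above.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

# e.g. result, secs = timed(sum, range(10_000_000))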
Example #3
def visualize_compare_search(agent, search_policy, eval_env, difficulty=0.5, seed=0):
    """Roll out the same episode with and without search and plot the trajectories side by side."""
    set_env_difficulty(eval_env, difficulty)

    plt.figure(figsize=(12, 5))
    for col_index in range(2):
        title = 'no search' if col_index == 0 else 'search'
        plt.subplot(1, 2, col_index + 1)
        plot_walls(eval_env.walls)
        use_search = (col_index == 1)

        # Reseed before each rollout so both policies face the same start and goal.
        set_global_seed(seed)
        set_env_seed(eval_env, seed + 1)

        policy = search_policy if use_search else agent
        goal, observations, waypoints, _ = Collector.get_trajectory(policy, eval_env)
        start = observations[0]

        obs_vec = np.array(observations)
        waypoint_vec = np.array(waypoints)

        print(f'policy: {title}')
        print(f'start: {start}')
        print(f'goal: {goal}')
        print(f'steps: {obs_vec.shape[0] - 1}')
        print('-' * 10)

        plt.plot(obs_vec[:, 0], obs_vec[:, 1], 'b-o', alpha=0.3)
        plt.scatter([start[0]], [start[1]], marker='+',
                    color='red', s=200, label='start')
        plt.scatter([obs_vec[-1, 0]], [obs_vec[-1, 1]], marker='+',
                    color='green', s=200, label='end')
        plt.scatter([goal[0]], [goal[1]], marker='*',
                    color='green', s=200, label='goal')
        plt.title(title, fontsize=24)

        if use_search:
            plt.plot(waypoint_vec[:, 0], waypoint_vec[:, 1], 'y-s', alpha=0.3, label='waypoint')
            plt.legend(loc='lower left', bbox_to_anchor=(-0.8, -0.15), ncol=4, fontsize=16)
    plt.show()
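The reseeding inside the loop is what makes the two panels comparable: both rollouts start from the same state and chase the same goal. A toy sketch of the reseed-per-run idea; toy_rollout is illustrative, not part of the codebase.

import numpy as np

def toy_rollout(step_scale, seed):
    # Stand-in for Collector.get_trajectory: a seeded random walk.
    rng = np.random.default_rng(seed)
    return np.cumsum(rng.normal(0.0, step_scale, size=(20, 2)), axis=0)

# Reseeding per run means both 'policies' see the same noise stream,
# so any difference in the endpoints comes from the policy, not the RNG.
for step_scale in (0.05, 0.01):
    traj = toy_rollout(step_scale, seed=0)
    print(step_scale, traj[-1])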
Example #4
def visualize_trajectory(agent, eval_env, difficulty=0.5):
    """Roll out the agent twice and plot both trajectories side by side."""
    set_env_difficulty(eval_env, difficulty)

    plt.figure(figsize=(8, 4))
    for col_index in range(2):
        plt.subplot(1, 2, col_index + 1)
        plot_walls(eval_env.walls)
        goal, observations_list, _, _ = Collector.get_trajectory(agent, eval_env)
        obs_vec = np.array(observations_list)

        print(f'traj {col_index}, num steps: {len(obs_vec) - 1}')

        plt.plot(obs_vec[:, 0], obs_vec[:, 1], 'b-o', alpha=0.3)
        plt.scatter([obs_vec[0, 0]], [obs_vec[0, 1]], marker='+',
                    color='red', s=200, label='start')
        plt.scatter([obs_vec[-1, 0]], [obs_vec[-1, 1]], marker='+',
                    color='green', s=200, label='end')
        plt.scatter([goal[0]], [goal[1]], marker='*',
                    color='green', s=200, label='goal')
        if col_index == 0:
            plt.legend(loc='lower left', bbox_to_anchor=(0.3, 1), ncol=3, fontsize=16)
    plt.show()
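Note the legend trick above: it is drawn only on the first subplot, and bbox_to_anchor pins it above the axes so one legend serves both panels without covering a trajectory. A minimal, self-contained sketch of that placement:

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))
for col_index in range(2):
    plt.subplot(1, 2, col_index + 1)
    plt.plot([0, 1], [0, 1], 'b-o', alpha=0.3, label='trajectory')
    if col_index == 0:
        # loc='lower left' with bbox_to_anchor=(0.3, 1) pins the legend's
        # lower-left corner above the axes, outside the plotting area.
        plt.legend(loc='lower left', bbox_to_anchor=(0.3, 1), ncol=3, fontsize=16)
plt.show()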
Example #5
def cleanup_and_eval_search_policy(search_policy,
                                   eval_env,
                                   num_evals=10,
                                   difficulty=0.5):
    """Evaluate the search policy on its initial graph, after filtering, and after
    cleanup steps; return a (g, rb_vec) snapshot for each of the three stages."""
    set_env_difficulty(eval_env, difficulty)
    search_policy.reset_stats()
    success_rate, eval_time = eval_search_policy(search_policy,
                                                 eval_env,
                                                 num_evals=num_evals)

    # Initial sparse graph
    print(
        f'Initial {search_policy} has success rate {success_rate:.2f}, evaluated in {eval_time:.2f} seconds'
    )
    initial_g, initial_rb = search_policy.g.copy(), search_policy.rb_vec.copy()

    # Filter search policy
    search_policy.filter_keep_k_nearest()

    set_env_difficulty(eval_env, difficulty)
    search_policy.reset_stats()
    success_rate, eval_time = eval_search_policy(search_policy,
                                                 eval_env,
                                                 num_evals=num_evals)
    print(
        f'Filtered {search_policy} has success rate {success_rate:.2f}, evaluated in {eval_time:.2f} seconds'
    )
    filtered_g, filtered_rb = search_policy.g.copy(), search_policy.rb_vec.copy()

    # Cleanup steps
    num_cleanup_steps = int(1e4)
    cleanup_time = take_cleanup_steps(search_policy, eval_env,
                                      num_cleanup_steps)
    print(
        f'Took {num_cleanup_steps} cleanup steps in {cleanup_time:.2f} seconds'
    )

    set_env_difficulty(eval_env, difficulty)
    search_policy.reset_stats()
    success_rate, eval_time = eval_search_policy(search_policy,
                                                 eval_env,
                                                 num_evals=num_evals)
    print(
        f'Cleaned {search_policy} has success rate {success_rate:.2f}, evaluated in {eval_time:.2f} seconds'
    )
    cleaned_g, cleaned_rb = search_policy.g.copy(), search_policy.rb_vec.copy()

    return (initial_g, initial_rb), (filtered_g, filtered_rb), (cleaned_g, cleaned_rb)
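The three (graph, replay-buffer) snapshots let later cells compare the graph before filtering, after filtering, and after cleanup. A hypothetical call, assuming search_policy and eval_env were constructed earlier; the argument values are illustrative.

# Hypothetical usage:
snapshots = cleanup_and_eval_search_policy(search_policy, eval_env,
                                           num_evals=20, difficulty=0.6)
(initial_g, initial_rb), (filtered_g, filtered_rb), (cleaned_g, cleaned_rb) = snapshots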