Example #1
0
    def get_candidate_centerlines_for_traj(
            self,
            xy: np.ndarray,
            city_name: str,
            viz: bool = False,
            max_search_radius: float = 50.0) -> List[np.ndarray]:
        """Get candidate centerlines for a trajectory.

        Algorithm:
        1. Take the lanes in a bubble around the last observed coordinate.
           The bubble starts at a Manhattan threshold of 2.5 and is doubled
           while no lane is found and the threshold is still below
           ``max_search_radius`` (so the final search may use a threshold of
           up to twice that value).
        2. Extend every such lane backwards and forwards with ``self.dfs``
           (threshold = 2 * straight-line displacement between the first and
           last point of ``xy``) and merge each past sequence with each
           future sequence.
        3. Remove overlapping / extended-predecessor sequences, convert them
           to centerlines and keep those selected by
           ``filter_candidate_centerlines``; if none are selected, fall back
           to ``get_centerlines_most_aligned_with_trajectory``.

        Args:
            xy: trajectory of shape (N, 2).
            city_name: name of the city whose map is queried.
            viz: if True, plot the candidate centerlines and the trajectory.
            max_search_radius: the search bubble is only enlarged while its
                threshold is below this value.

        Returns:
            candidate_centerlines: List of candidate centerlines

        Raises:
            AssertionError: if no lane is found around the last coordinate.
        """
        query_x, query_y = xy[-1, 0], xy[-1, 1]

        # Get all lane candidates within a bubble around the last coordinate
        manhattan_threshold = 2.5
        curr_lane_candidates = self.get_lane_ids_in_xy_bbox(
            query_x, query_y, city_name, manhattan_threshold)

        # Keep expanding the bubble until at least 1 lane is found.
        # len() rather than truthiness, in case the helper returns an array.
        while (len(curr_lane_candidates) < 1
               and manhattan_threshold < max_search_radius):
            manhattan_threshold *= 2
            curr_lane_candidates = self.get_lane_ids_in_xy_bbox(
                query_x, query_y, city_name, manhattan_threshold)

        # Explicit raise instead of a bare `assert` so the check is not
        # stripped under `python -O`; AssertionError is kept so existing
        # callers observe the same exception type and message.
        if len(curr_lane_candidates) == 0:
            raise AssertionError("No nearby lanes found!!")

        # DFS threshold: twice the straight-line start-to-end displacement
        displacement = np.sqrt((xy[0, 0] - xy[-1, 0])**2 +
                               (xy[0, 1] - xy[-1, 1])**2)
        dfs_threshold = displacement * 2.0

        # DFS to get all successor and predecessor candidates
        obs_pred_lanes: List[Sequence[int]] = []
        for lane in curr_lane_candidates:
            candidates_future = self.dfs(lane, city_name, 0, dfs_threshold)
            candidates_past = self.dfs(lane, city_name, 0, dfs_threshold, True)

            # Merge past and future; both sequences share the current lane,
            # so it is dropped from the head of the future sequence.
            for past_lane_seq in candidates_past:
                for future_lane_seq in candidates_future:
                    assert past_lane_seq[-1] == future_lane_seq[
                        0], "Incorrect DFS for candidate lanes past and future"
                    obs_pred_lanes.append(past_lane_seq + future_lane_seq[1:])

        # Remove overlapping lane sequences
        obs_pred_lanes = remove_overlapping_lane_seq(obs_pred_lanes)

        # Remove unnecessary extended predecessors
        obs_pred_lanes = self.remove_extended_predecessors(
            obs_pred_lanes, xy, city_name)

        # Get candidate centerlines
        candidate_cl = self.get_cl_from_lane_seq(obs_pred_lanes, city_name)

        # Reduce the number of candidates based on distance travelled along
        # the centerline
        candidate_centerlines = filter_candidate_centerlines(xy, candidate_cl)

        # If no candidate survives the filter, fall back to the ones most
        # aligned with the trajectory
        if len(candidate_centerlines) < 1:
            candidate_centerlines = get_centerlines_most_aligned_with_trajectory(
                xy, candidate_cl)

        if viz:
            plt.figure(0, figsize=(8, 7))
            for centerline_coords in candidate_centerlines:
                visualize_centerline(centerline_coords)
            plt.plot(xy[:, 0],
                     xy[:, 1],
                     "-",
                     color="#d33e4c",
                     alpha=1,
                     linewidth=1,
                     zorder=15)

            # Mark the last observed coordinate
            plt.plot(query_x,
                     query_y,
                     "o",
                     color="#d33e4c",
                     alpha=1,
                     markersize=7,
                     zorder=15)
            plt.xlabel("Map X")
            plt.ylabel("Map Y")
            plt.axis("off")
            plt.title("Number of candidates = {}".format(
                len(candidate_centerlines)))
            plt.show()

        return candidate_centerlines
def test_filter_candidate_centerlines():
    """Test filter candidate centerlines.

    Five centerlines (cl1..cl5, drawn below) are filtered against the
    trajectory ``xy``; only cl5 is expected to be kept.
    """

    # Test Case

    # 0              20 24   30       40               60
    #                        *         *                      50
    #                        |         |
    #                        |         |
    #                        |         |
    #                        |         |
    #                        |         |                      40
    #                  (3)   |         |  (2)
    #                        ^         ^
    #                        ^         ^
    #                        |         |
    #                        |         |
    #                        |         |         (1)
    # *--------<<<-----------|---------|--------<<<--------*  30
    #                        | \       |
    #                xxxxxxxx^x \      ^
    #                        ^ x \(5)  ^
    #         (4)            |  x \    |
    # *-------->>>-----------|---x-\---|-------->>>--------*  20
    #                        |    x \  |
    #                        |     x \ |
    #                        |     x  \|
    #                        |     x   |                      10
    #                        |     x   |
    #                        ^     x   ^
    #                        ^     x   ^
    #                        |         |
    #                        *         *                       0
    #

    xy = np.array([
        [35.0, 0.0],
        [35.0, 4.0],
        [35.0, 8.0],
        [35.0, 12.0],
        [35.0, 16.0],
        [33.0, 20.0],
        [31.0, 24.0],
        [29.0, 28.0],
        [25.0, 28.0],
        [21.0, 28.0],
    ])

    cl1 = np.array([(60.0, 30.0), (0.0, 30.0)])
    cl2 = np.array([(40.0, 0.0), (40.0, 50.0)])
    cl3 = np.array([(30.0, 0.0), (30.0, 50.0)])
    cl4 = np.array([(0.0, 20.0), (60.0, 20.0)])
    cl5 = np.array([(40.0, 0.0), (40.0, 10.0), (30.0, 30.0), (0.0, 30.0)])

    candidate_cl = [cl1, cl2, cl3, cl4, cl5]

    # No sorted(): ordering a list of NumPy arrays compares them element-wise
    # and raises as soon as more than one array is returned.
    filtered_cl = filter_candidate_centerlines(xy, candidate_cl)
    expected_cl = [cl5]

    # Check the count first; iterating over filtered_cl alone would pass
    # vacuously on an empty result and hit IndexError on extra results.
    assert len(filtered_cl) == len(expected_cl), (
        "Filtered centerlines wrong! Expected {} centerline(s), got {}".format(
            len(expected_cl), len(filtered_cl)))
    for expected, actual in zip(expected_cl, filtered_cl):
        assert np.allclose(expected, actual), "Filtered centerlines wrong!"