def test_dfs():
    r"""Check forward lane-graph DFS from lane 9629626 in MIA.

    The search starts at the root of the graph below. It does not extend
    along predecessors, and it runs with a threshold of 30.0. It must return
    exactly the two depth-2 branches. The deepest lanes (9629406 and 9626257)
    are not part of the expected sequences. The numbers in parentheses are
    presumably lane lengths.

    Lane Graph:
                9629626
               /       \
              /         \
          9620336    9632589
          (10.77)     (8.33)
             |          |
             |          |
          9628835    9621228
           (31.9)    (31.96)
             |          |
             |          |
          9629406    9626257
           (7.9)      (7.81)

    """
    map_api = ArgoverseMap()

    root_lane, city = 9629626, "MIA"
    start_dist, max_dist = 0.0, 30.0
    along_predecessors = False
    found_seqs = map_api.dfs(root_lane, city, start_dist, max_dist,
                             along_predecessors)

    left_branch = [9629626, 9620336, 9628835]
    right_branch = [9629626, 9632589, 9621228]
    wanted = [left_branch, right_branch]
    assert np.array_equal(found_seqs, wanted), "dfs over lane graph failed!"
    def get_candidate_centerlines_for_trajectory(
        self,
        xy: np.ndarray,
        city_name: str,
        avm: ArgoverseMap,
        viz: bool = False,
        max_search_radius: float = 50.0,
        seq_len: int = 50,
        max_candidates: int = 10,
        mode: str = "test",
    ) -> List[np.ndarray]:
        """Get centerline candidates up to a threshold.

        Algorithm:
        1. Take the lanes in the bubble of the last observed coordinate.
        2. Extend before and after, considering all possible candidates.
        3. Get centerlines based on point-in-polygon score.

        Args:
            xy: Trajectory coordinates. The row ``xy[-1]`` is the query point
                for the lane search, and ``xy.shape[0]`` scales the DFS depth.
            city_name: City name.
            avm: Argoverse map_api instance.
            viz: If True, plot the candidate centerlines and the trajectory.
            max_search_radius: Limit for growing the search bubble, in meters.
                The bubble is doubled while it is below this value, so the
                last query may use a bubble up to twice this size.
            seq_len: Sequence length.
            max_candidates: Maximum number of centerlines to return. Only
                used in ``"test"`` mode.
            mode: train/val/test mode. Any value other than ``"test"`` returns
                only the centerline of the best-scoring lane sequence.

        Returns:
            candidate_centerlines: List of candidate centerlines.

        Raises:
            AssertionError: If no lane is found within the search bubble, or
                if a past and a future DFS sequence do not meet at the same
                lane.
        """
        # Grow a *local* copy of the bubble size. Mutating
        # self._MANHATTAN_THRESHOLD here would let one sparse-map query
        # permanently enlarge the bubble for every later call on this instance.
        manhattan_threshold = self._MANHATTAN_THRESHOLD
        query_x, query_y = xy[-1, 0], xy[-1, 1]

        # Get all lane candidates within a bubble
        curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
            query_x, query_y, city_name, manhattan_threshold)

        # Keep expanding the bubble until at least 1 lane is found
        while (len(curr_lane_candidates) < 1
               and manhattan_threshold < max_search_radius):
            manhattan_threshold *= 2
            curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
                query_x, query_y, city_name, manhattan_threshold)

        assert len(curr_lane_candidates) > 0, "No nearby lanes found!!"

        # The DFS depth scales with the number of remaining (front) and
        # observed (back) timesteps. The original author assumed a speed of
        # 50 mps. NOTE(review): the /10 presumably converts 10 Hz timesteps
        # to seconds; confirm.
        traj_len = xy.shape[0]
        dfs_threshold_front = (self._DFS_THRESHOLD_FRONT_SCALE *
                               (seq_len + 1 - traj_len) / 10)
        dfs_threshold_back = (self._DFS_THRESHOLD_BACK_SCALE *
                              (traj_len + 1) / 10)

        # DFS to get all successor and predecessor candidates. The trailing
        # True argument is extend_along_predecessor, i.e. search backwards.
        obs_pred_lanes: List[Sequence[int]] = []
        for lane in curr_lane_candidates:
            candidates_future = avm.dfs(lane, city_name, 0,
                                        dfs_threshold_front)
            candidates_past = avm.dfs(lane, city_name, 0, dfs_threshold_back,
                                      True)

            # Merge past and future. Both halves share `lane` as their pivot,
            # so its duplicate is dropped from the front of the future half.
            for past_lane_seq in candidates_past:
                for future_lane_seq in candidates_future:
                    assert (
                        past_lane_seq[-1] == future_lane_seq[0]
                    ), "Incorrect DFS for candidate lanes past and future"
                    obs_pred_lanes.append(past_lane_seq + future_lane_seq[1:])

        # Removing overlapping lanes
        obs_pred_lanes = remove_overlapping_lane_seq(obs_pred_lanes)

        # Sort lanes based on point in polygon score
        obs_pred_lanes, scores = self.sort_lanes_based_on_point_in_polygon_score(
            obs_pred_lanes, xy, city_name, avm)

        # If the best centerline is not along the direction of travel, re-sort
        if mode == "test":
            candidate_centerlines = self.get_heuristic_centerlines_for_test_set(
                obs_pred_lanes, xy, city_name, avm, max_candidates, scores)
        else:
            # Outside test mode, only the top-ranked lane sequence is used.
            candidate_centerlines = avm.get_cl_from_lane_seq(
                [obs_pred_lanes[0]], city_name)

        if viz:
            plt.figure(0, figsize=(8, 7))
            for centerline_coords in candidate_centerlines:
                visualize_centerline(centerline_coords)
            plt.plot(
                xy[:, 0],
                xy[:, 1],
                "-",
                color="#d33e4c",
                alpha=1,
                linewidth=3,
                zorder=15,
            )

            final_x = xy[-1, 0]
            final_y = xy[-1, 1]

            plt.plot(
                final_x,
                final_y,
                "o",
                color="#d33e4c",
                alpha=1,
                markersize=10,
                zorder=15,
            )
            plt.xlabel("Map X")
            plt.ylabel("Map Y")
            plt.axis("off")
            plt.title(f"Number of candidates = {len(candidate_centerlines)}")
            plt.show()

        return candidate_centerlines
Exemple #3
0
    def get_candidate_centerlines_for_trajectory(
        self,
        xy: np.ndarray,
        city_name: str,
        avm: ArgoverseMap,
        yaw_deg: int,
        viz: bool = False,
        max_search_radius: float = 100.0,
        seq_len: int = 50,
        max_candidates: int = 10,
        mode: str = "test",
    ) -> List[np.ndarray]:
        """Get centerline candidates up to a threshold and rasterize them.

        Algorithm:
        1. Take the lanes in the bubble of the last observed coordinate.
        2. Extend before and after, considering all possible candidates.
        3. Get centerlines based on point-in-polygon score.
        4. Draw the candidate centerlines into an ego-centred 224x224 raster
           and collect their image-space versions.

        Args:
            xy: Trajectory coordinates. The row ``xy[-1]`` is the query point
                for the lane search, and ``xy[19]`` is the raster centre.
                NOTE(review): this assumes at least 20 rows, with row 19
                presumably the last observed step; confirm.
            city_name: City name.
            avm: Argoverse map_api instance.
            yaw_deg: Heading in degrees. It is negated and converted to
                radians before building the world-to-image transform.
            viz: Kept for interface compatibility. The raster and the
                normalized centerlines are part of the return value, so they
                are always produced.
            max_search_radius: Limit for growing the search bubble, in meters.
                The bubble is doubled while it is below this value, so the
                last query may use a bubble up to twice this size.
            seq_len: Sequence length.
            max_candidates: Maximum number of centerlines to return. Only
                used in ``"test"`` mode.
            mode: train/val/test mode. Any value other than ``"test"`` returns
                only the centerline of the best-scoring lane sequence.

        Returns:
            A 4-tuple (the ``List[np.ndarray]`` annotation is inaccurate):
            - candidate_centerlines: List of candidate centerlines (world frame).
            - img: float32 raster of shape (224, 224, 3) scaled to [0, 1],
              with white background and centerlines drawn in black.
            - candidate_centerlines_normalized: The image-space centerlines
              returned by ``crop_tensor`` that kept more than one point.
            - world_to_image_space: The world-to-image transform used.

        Raises:
            AssertionError: If no lane is found within the search bubble, or
                if a past and a future DFS sequence do not meet at the same
                lane.
        """
        # Grow a *local* copy of the bubble size. Mutating
        # self._MANHATTAN_THRESHOLD here would let one sparse-map query
        # permanently enlarge the bubble for every later call on this instance.
        manhattan_threshold = self._MANHATTAN_THRESHOLD
        query_x, query_y = xy[-1, 0], xy[-1, 1]

        # Get all lane candidates within a bubble
        curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
            query_x, query_y, city_name, manhattan_threshold)

        # Keep expanding the bubble until at least 1 lane is found
        while (len(curr_lane_candidates) < 1
               and manhattan_threshold < max_search_radius):
            manhattan_threshold *= 2
            curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
                query_x, query_y, city_name, manhattan_threshold)

        assert len(curr_lane_candidates) > 0, "No nearby lanes found!!"

        # The DFS depth scales with the number of remaining (front) and
        # observed (back) timesteps. The original author assumed a speed of
        # 50 mps. NOTE(review): the /10 presumably converts 10 Hz timesteps
        # to seconds; confirm.
        traj_len = xy.shape[0]
        dfs_threshold_front = (self._DFS_THRESHOLD_FRONT_SCALE *
                               (seq_len + 1 - traj_len) / 10)
        dfs_threshold_back = (self._DFS_THRESHOLD_BACK_SCALE *
                              (traj_len + 1) / 10)

        # DFS to get all successor and predecessor candidates. The trailing
        # True argument is extend_along_predecessor, i.e. search backwards.
        obs_pred_lanes: List[Sequence[int]] = []
        for lane in curr_lane_candidates:
            candidates_future = avm.dfs(lane, city_name, 0,
                                        dfs_threshold_front)
            candidates_past = avm.dfs(lane, city_name, 0, dfs_threshold_back,
                                      True)

            # Merge past and future. Both halves share `lane` as their pivot,
            # so its duplicate is dropped from the front of the future half.
            for past_lane_seq in candidates_past:
                for future_lane_seq in candidates_future:
                    assert (
                        past_lane_seq[-1] == future_lane_seq[0]
                    ), "Incorrect DFS for candidate lanes past and future"
                    obs_pred_lanes.append(past_lane_seq + future_lane_seq[1:])

        # Removing overlapping lanes
        obs_pred_lanes = remove_overlapping_lane_seq(obs_pred_lanes)

        # Sort lanes based on point in polygon score
        obs_pred_lanes, scores = self.sort_lanes_based_on_point_in_polygon_score(
            obs_pred_lanes, xy, city_name, avm)

        # If the best centerline is not along the direction of travel, re-sort
        if mode == "test":
            candidate_centerlines = self.get_heuristic_centerlines_for_test_set(
                obs_pred_lanes, xy, city_name, avm, max_candidates, scores)
        else:
            # Outside test mode, only the top-ranked lane sequence is used.
            candidate_centerlines = avm.get_cl_from_lane_seq(
                [obs_pred_lanes[0]], city_name)

        # Ego-centred world->image transform: 224x224 raster with 0.6 x 0.6
        # pixel size (presumably meters per pixel). NOTE(review): the ego is
        # presumably placed at 12.5% / 50% of the image extent; confirm
        # against world_to_image_pixels_matrix.
        centroid = xy[19]
        raster_size = (224, 224)
        ego_yaw = -math.pi * yaw_deg / 180

        world_to_image_space = world_to_image_pixels_matrix(
            raster_size,
            (0.6, 0.6),
            ego_translation_m=centroid,
            ego_yaw_rad=ego_yaw,
            ego_center_in_image_ratio=np.array([0.125, 0.5]),
        )

        # Rasterize unconditionally. `img` and the normalized centerlines are
        # returned, and gating this on `viz` left `img` unbound (so the
        # function raised UnboundLocalError) whenever viz was False.
        img = 255 * np.ones(shape=(raster_size[0], raster_size[1], 3),
                            dtype=np.uint8)
        candidate_centerlines_normalized = []
        for centerline_coords in candidate_centerlines:
            cnt_line = transform_points(centerline_coords,
                                        world_to_image_space)
            # cv2.polylines draws into `img` in place. The shift argument
            # treats coordinates as fixed-point, presumably matching the
            # scaling applied by cv2_subpixel.
            cv2.polylines(img, [cv2_subpixel(cnt_line)],
                          False, (0, 0, 0),
                          lineType=cv2.LINE_AA,
                          shift=CV2_SHIFT)
            cropped_vector = crop_tensor(cnt_line, raster_size)
            # Keep only crops that still form a line (at least two points).
            if len(cropped_vector) > 1:
                candidate_centerlines_normalized.append(cropped_vector)

        img = img.astype(np.float32) / 255

        return candidate_centerlines, img, candidate_centerlines_normalized, world_to_image_space