def test_dfs():
    r"""Check that a forward DFS over the lane graph returns the expected lane sequences.

    Lane graph under test (the numbers in parentheses are presumably lane
    lengths in meters -- confirm against the map data):

                9629626
               /       \
              /         \
          9620336    9632589
          (10.77)     (8.33)
             |          |
             |          |
          9628835    9621228
           (31.9)    (31.96)
             |          |
             |          |
          9629406    9626257
           (7.9)      (7.81)

    The search starts at lane 9629626 in city "MIA" with an initial distance
    of 0.0, a threshold of 30.0, and predecessor extension disabled.
    """
    start_lane = 9629626
    city = "MIA"
    argoverse_map = ArgoverseMap()

    # Positional arguments: lane id, city, starting distance, threshold,
    # extend-along-predecessor flag.
    found_sequences = argoverse_map.dfs(start_lane, city, 0.0, 30.0, False)

    # Both branches are expected to stop at the second lane below the root.
    wanted_sequences = [
        [9629626, 9620336, 9628835],
        [9629626, 9632589, 9621228],
    ]
    assert np.array_equal(found_sequences, wanted_sequences), "dfs over lane graph failed!"
def get_candidate_centerlines_for_trajectory(
        self,
        xy: np.ndarray,
        city_name: str,
        avm: ArgoverseMap,
        viz: bool = False,
        max_search_radius: float = 50.0,
        seq_len: int = 50,
        max_candidates: int = 10,
        mode: str = "test",
) -> List[np.ndarray]:
    """Get centerline candidates up to a threshold.

    Algorithm:
    1. Take the lanes in the bubble of last observed coordinate
    2. Extend before and after considering all possible candidates
    3. Get centerlines based on point in polygon score.

    Args:
        xy: Trajectory coordinates; column 0 is used as x and column 1 as y,
            and the last row is the centre of the lane search,
        city_name: City name,
        avm: Argoverse map_api instance,
        viz: Visualize candidate centerlines with matplotlib,
        max_search_radius: Max search radius for finding nearby lanes in meters,
        seq_len: Sequence length,
        max_candidates: Maximum number of centerlines to return,
        mode: train/val/test mode; "test" uses the test-set heuristic, any
            other value returns only the centerline of the best-scoring
            lane sequence

    Returns:
        candidate_centerlines: List of candidate centerlines

    Raises:
        AssertionError: If no lane is found around the last coordinate even
            after expanding the search bubble, or if a past and a future
            lane sequence of the same lane do not share that lane.

    """
    # Expand a *local* copy of the search radius. Doubling
    # self._MANHATTAN_THRESHOLD in place would leak the enlarged radius into
    # every later call on this instance and make results order-dependent.
    manhattan_threshold = self._MANHATTAN_THRESHOLD
    last_x, last_y = xy[-1, 0], xy[-1, 1]

    # Get all lane candidates within a bubble
    curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
        last_x, last_y, city_name, manhattan_threshold)

    # Keep expanding the bubble until at least 1 lane is found
    while (len(curr_lane_candidates) < 1
           and manhattan_threshold < max_search_radius):
        manhattan_threshold *= 2
        curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
            last_x, last_y, city_name, manhattan_threshold)

    assert len(curr_lane_candidates) > 0, "No nearby lanes found!!"

    # Set dfs threshold
    traj_len = xy.shape[0]

    # Thresholds for traversing in the front and back: a scale constant times
    # the number of remaining / elapsed steps divided by 10.
    # NOTE(review): the original comment says "assuming a speed of 50 mps";
    # confirm against the *_SCALE constants.
    dfs_threshold_front = (self._DFS_THRESHOLD_FRONT_SCALE *
                           (seq_len + 1 - traj_len) / 10)
    dfs_threshold_back = self._DFS_THRESHOLD_BACK_SCALE * (traj_len + 1) / 10

    # DFS to get all successor and predecessor candidates
    obs_pred_lanes: List[Sequence[int]] = []
    for lane in curr_lane_candidates:
        candidates_future = avm.dfs(lane, city_name, 0, dfs_threshold_front)
        candidates_past = avm.dfs(lane, city_name, 0, dfs_threshold_back,
                                  True)

        # Merge past and future: every past sequence must end on the lane
        # that every future sequence starts with, so drop the duplicate.
        for past_lane_seq in candidates_past:
            for future_lane_seq in candidates_future:
                assert (
                    past_lane_seq[-1] == future_lane_seq[0]
                ), "Incorrect DFS for candidate lanes past and future"
                obs_pred_lanes.append(past_lane_seq + future_lane_seq[1:])

    # Removing overlapping lanes
    obs_pred_lanes = remove_overlapping_lane_seq(obs_pred_lanes)

    # Sort lanes based on point in polygon score
    obs_pred_lanes, scores = self.sort_lanes_based_on_point_in_polygon_score(
        obs_pred_lanes, xy, city_name, avm)

    # If the best centerline is not along the direction of travel, re-sort
    if mode == "test":
        candidate_centerlines = self.get_heuristic_centerlines_for_test_set(
            obs_pred_lanes, xy, city_name, avm, max_candidates, scores)
    else:
        # Outside test mode only the top-ranked lane sequence is used.
        candidate_centerlines = avm.get_cl_from_lane_seq(
            [obs_pred_lanes[0]], city_name)

    if viz:
        plt.figure(0, figsize=(8, 7))
        for centerline_coords in candidate_centerlines:
            visualize_centerline(centerline_coords)

        # Full input trajectory.
        plt.plot(
            xy[:, 0],
            xy[:, 1],
            "-",
            color="#d33e4c",
            alpha=1,
            linewidth=3,
            zorder=15,
        )

        # Marker on the final trajectory point.
        plt.plot(
            last_x,
            last_y,
            "o",
            color="#d33e4c",
            alpha=1,
            markersize=10,
            zorder=15,
        )
        plt.xlabel("Map X")
        plt.ylabel("Map Y")
        plt.axis("off")
        plt.title(f"Number of candidates = {len(candidate_centerlines)}")
        plt.show()

    return candidate_centerlines
def get_candidate_centerlines_for_trajectory(
        self,
        xy: np.ndarray,
        city_name: str,
        avm: ArgoverseMap,
        yaw_deg: int,
        viz: bool = False,
        max_search_radius: float = 100.0,
        seq_len: int = 50,
        max_candidates: int = 10,
        mode: str = "test",
) -> tuple:
    """Get centerline candidates up to a threshold, plus their image-space form.

    Algorithm:
    1. Take the lanes in the bubble of last observed coordinate
    2. Extend before and after considering all possible candidates
    3. Get centerlines based on point in polygon score.
    4. Transform the centerlines into a 224x224 ego-centred raster frame.

    Args:
        xy: Trajectory coordinates; column 0 is used as x and column 1 as y.
            The last row is the centre of the lane search and row 19 is the
            raster centroid, so at least 20 rows are required,
        city_name: City name,
        avm: Argoverse map_api instance,
        yaw_deg: Ego heading in degrees; its negation (in radians) is passed
            as the raster rotation,
        viz: When True, also draw the candidate centerlines into a raster
            image,
        max_search_radius: Max search radius for finding nearby lanes in meters,
        seq_len: Sequence length,
        max_candidates: Maximum number of centerlines to return,
        mode: train/val/test mode; "test" uses the test-set heuristic, any
            other value returns only the centerline of the best-scoring
            lane sequence

    Returns:
        A 4-tuple ``(candidate_centerlines, img, candidate_centerlines_normalized,
        world_to_image_space)``:
        candidate_centerlines: List of candidate centerlines in map coordinates,
        img: float32 array of shape (224, 224, 3) with values in [0, 1]
            (white background, black centerlines) when ``viz`` is True,
            otherwise None,
        candidate_centerlines_normalized: Image-space centerlines after
            ``crop_tensor``, keeping only those with more than one element,
        world_to_image_space: Transform returned by
            ``world_to_image_pixels_matrix``

    Raises:
        AssertionError: If no lane is found around the last coordinate even
            after expanding the search bubble, or if a past and a future
            lane sequence of the same lane do not share that lane.
        IndexError: If ``xy`` has fewer than 20 rows.

    """
    # Expand a *local* copy of the search radius. Doubling
    # self._MANHATTAN_THRESHOLD in place would leak the enlarged radius into
    # every later call on this instance and make results order-dependent.
    manhattan_threshold = self._MANHATTAN_THRESHOLD
    last_x, last_y = xy[-1, 0], xy[-1, 1]

    # Get all lane candidates within a bubble
    curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
        last_x, last_y, city_name, manhattan_threshold)

    # Keep expanding the bubble until at least 1 lane is found
    while (len(curr_lane_candidates) < 1
           and manhattan_threshold < max_search_radius):
        manhattan_threshold *= 2
        curr_lane_candidates = avm.get_lane_ids_in_xy_bbox(
            last_x, last_y, city_name, manhattan_threshold)

    assert len(curr_lane_candidates) > 0, "No nearby lanes found!!"

    # Set dfs threshold
    traj_len = xy.shape[0]

    # Thresholds for traversing in the front and back: a scale constant times
    # the number of remaining / elapsed steps divided by 10.
    # NOTE(review): the original comment says "assuming a speed of 50 mps";
    # confirm against the *_SCALE constants.
    dfs_threshold_front = (self._DFS_THRESHOLD_FRONT_SCALE *
                           (seq_len + 1 - traj_len) / 10)
    dfs_threshold_back = self._DFS_THRESHOLD_BACK_SCALE * (traj_len + 1) / 10

    # DFS to get all successor and predecessor candidates
    obs_pred_lanes: List[Sequence[int]] = []
    for lane in curr_lane_candidates:
        candidates_future = avm.dfs(lane, city_name, 0, dfs_threshold_front)
        candidates_past = avm.dfs(lane, city_name, 0, dfs_threshold_back,
                                  True)

        # Merge past and future: every past sequence must end on the lane
        # that every future sequence starts with, so drop the duplicate.
        for past_lane_seq in candidates_past:
            for future_lane_seq in candidates_future:
                assert (
                    past_lane_seq[-1] == future_lane_seq[0]
                ), "Incorrect DFS for candidate lanes past and future"
                obs_pred_lanes.append(past_lane_seq + future_lane_seq[1:])

    # Removing overlapping lanes
    obs_pred_lanes = remove_overlapping_lane_seq(obs_pred_lanes)

    # Sort lanes based on point in polygon score
    obs_pred_lanes, scores = self.sort_lanes_based_on_point_in_polygon_score(
        obs_pred_lanes, xy, city_name, avm)

    # If the best centerline is not along the direction of travel, re-sort
    if mode == "test":
        candidate_centerlines = self.get_heuristic_centerlines_for_test_set(
            obs_pred_lanes, xy, city_name, avm, max_candidates, scores)
    else:
        # Outside test mode only the top-ranked lane sequence is used.
        candidate_centerlines = avm.get_cl_from_lane_seq(
            [obs_pred_lanes[0]], city_name)

    # Build the world -> image transform centred on trajectory row 19.
    # NOTE(review): 19 is presumably the last observed step of a 20-step
    # observation window -- confirm against the caller.
    centroid = xy[19]
    raster_size = (224, 224)
    ego_yaw = -math.pi * yaw_deg / 180
    world_to_image_space = world_to_image_pixels_matrix(
        raster_size,
        (0.6, 0.6),  # presumably pixel size in meters -- TODO confirm
        ego_translation_m=centroid,
        ego_yaw_rad=ego_yaw,
        ego_center_in_image_ratio=np.array([0.125, 0.5]),
    )

    # The raster canvas exists only when visualization is requested. It is
    # initialised to None so the return below never touches an unbound name
    # (the original raised UnboundLocalError for viz=False).
    img = None
    if viz:
        img = 255 * np.ones(shape=(raster_size[0], raster_size[1], 3),
                            dtype=np.uint8)

    candidate_centerlines_normalized = []
    for centerline_coords in candidate_centerlines:
        cnt_line = transform_points(centerline_coords, world_to_image_space)

        if img is not None:
            # Draw the centerline in black with sub-pixel precision.
            cv2.polylines(img, [cv2_subpixel(cnt_line)],
                          False, (0, 0, 0),
                          lineType=cv2.LINE_AA,
                          shift=CV2_SHIFT)

        # Keep only centerlines that still have more than one element after
        # cropping to the raster.
        cropped_vector = crop_tensor(cnt_line, raster_size)
        if len(cropped_vector) > 1:
            candidate_centerlines_normalized.append(cropped_vector)

    if img is not None:
        # Scale the uint8 canvas to float32 in [0, 1].
        img = img.astype(np.float32) / 255

    return (candidate_centerlines, img, candidate_centerlines_normalized,
            world_to_image_space)