def __init__(self, root_dir='argoverse-data//data', avm=None, social=False,
             train_seq_size=20, cuda=False, test=False, oracle=True):
    """Initialise the lane-centre dataset and attach an Argoverse map API.

    Args:
        root_dir: Dataset root, forwarded to the parent class.
        avm: Pre-built ArgoverseMap to reuse; a fresh one is built when None.
        social: Accepted for interface compatibility; not read here.
        train_seq_size: Sequence length, forwarded to the parent class.
        cuda: Forwarded to the parent class.
        test: Forwarded to the parent class.
        oracle: Stored as ``self.oracle``.
    """
    super(Argoverse_LaneCentre_Data, self).__init__(root_dir, train_seq_size, cuda, test)
    self.avm = avm if avm is not None else ArgoverseMap()
    self.stationary_threshold = 2.0
    self.oracle = oracle
    print("Done loading map")
def map_features_helper(locations, dfs_threshold_multiplier=2.0, save_str="", avm=None, mfu=None,
                        rotation=None, translation=None, generate_candidate_centerlines=0,
                        compute_all=False, city=None):
    """Compute oracle (and optionally candidate) centerline features for a trajectory.

    Args:
        locations: Trajectory coordinates passed to
            ``MapFeaturesUtils.get_candidate_centerlines_for_trajectory``.
        dfs_threshold_multiplier: Accepted for interface compatibility; not used here.
        save_str: Suffix appended to every feature key.
        avm: ArgoverseMap to reuse; built when None.
        mfu: MapFeaturesUtils to reuse; built when None.
        rotation: Rotation forwarded to ``normalize_xy``.
        translation: Translation forwarded to ``normalize_xy``.
        generate_candidate_centerlines: ``max_candidates`` for the centerline queries.
        compute_all: When true (and ``generate_candidate_centerlines > 0``), also
            compute test-mode candidate centerlines.
        city: City name for the map queries. Previously the function read an
            undefined free name ``city``; when omitted, a module-level ``city``
            is still honoured for backward compatibility.

    Returns:
        Dict of feature name -> value.

    Raises:
        ValueError: if no city is supplied and no module-level ``city`` exists.
    """
    if city is None:
        city = globals().get("city")
        if city is None:
            raise ValueError("map_features_helper() requires a `city` argument")

    # Initialize map utilities if not provided
    if avm is None:
        avm = ArgoverseMap()
    if mfu is None:
        mfu = MapFeaturesUtils()

    # Best-fitting (oracle) centerline for the current vehicle
    heuristic_oracle_centerline = mfu.get_candidate_centerlines_for_trajectory(
        locations, city, avm=avm, viz=False,
        max_candidates=generate_candidate_centerlines, mode='train')[0]
    features = {
        "HEURISTIC_ORACLE_CENTERLINE" + save_str: heuristic_oracle_centerline,
        "HEURISTIC_ORACLE_CENTERLINE_NORMALIZED" + save_str: normalize_xy(
            heuristic_oracle_centerline, translation=translation, rotation=rotation)[0],
    }

    # Top-fitting candidate centerlines (can be used at test time)
    if compute_all and generate_candidate_centerlines > 0:
        test_candidate_centerlines = mfu.get_candidate_centerlines_for_trajectory(
            locations, city, avm=avm, viz=False,
            max_candidates=generate_candidate_centerlines, mode='test')
        features["TEST_CANDIDATE_CENTERLINES" + save_str] = test_candidate_centerlines
        # Only normalise when a rotation or translation was specified
        if rotation is not None or translation is not None:
            features['TEST_CANDIDATE_CENTERLINE_NORMALIZED' + save_str] = [
                normalize_xy(centerline, translation=translation, rotation=rotation)[0]
                for centerline in test_candidate_centerlines
            ]
    return features
def __init__(self, question_h5, image_feature_h5_path, lidar_feature_h5_path, vocab,
             load_lidar=True, npoint=1024, normal_channel=True, uniform=False,
             cache_size=15000, drivable_area=False, mode='prefix', image_h5=None,
             lidar_h5=None, max_samples=None, question_families=None,
             image_idx_start_from=None):
    """Load all question data into memory and optionally set up LiDAR/map access.

    Args:
        question_h5: Mapping-like question file. Reads the keys 'questions',
            'answers', 'image_idxs', 'video_names' and 'question_length'.
        image_feature_h5_path: Stored as ``self.image_feature_h5``.
        lidar_feature_h5_path: Root passed to ArgoverseTrackingLoader when
            ``load_lidar`` is true.
        vocab: Passed to ``get_answer_classes`` to map answers to classes.
        load_lidar: When true, build the tracking loader and ArgoverseMap.
        drivable_area: Stored as ``self.drivable_area``.
        npoint, normal_channel, uniform, cache_size, mode, image_h5, lidar_h5,
        max_samples, question_families, image_idx_start_from: Accepted but not
            read in this method.
    """
    # Read the whole question file into memory.
    self.all_questions = question_h5['questions'][:]
    self.all_answers = get_answer_classes(question_h5['answers'][:], vocab)
    self.all_image_idxs = question_h5['image_idxs'][:]
    self.all_video_names = (question_h5['video_names'][:]).astype(str)
    self.questions_length = question_h5['question_length'][:]
    self.image_feature_h5 = image_feature_h5_path
    self.load_lidar = load_lidar
    # Always define this attribute so it exists even when LiDAR loading is off.
    self.drivable_area = drivable_area

    # LiDAR / map resources
    if self.load_lidar:
        self.argoverse_loader = ArgoverseTrackingLoader(lidar_feature_h5_path)
        self.am = ArgoverseMap()
def visualize_ground_lidar_pts(log_id: str, dataset_dir: str, experiment_prefix: str):
    """Draw ground-classified LiDAR returns into camera images for one log, then encode videos.

    Every LiDAR sweep that has a matching ``city_SE3_egovehicle`` pose file is
    handed to ``draw_ground_pts_in_image``. Afterwards, one ffmpeg command per
    camera turns the rendered frames into an mp4. Stereo cameras use 5 fps and
    all others use 10 fps.

    Args:
        log_id: The ID of a log.
        dataset_dir: Where the dataset is stored.
        experiment_prefix: Output prefix.
    """
    sdb = SynchronizationDB(dataset_dir, collect_single_log_id=log_id)
    log_dir = f"{dataset_dir}/{log_id}"
    city_name = read_json_file(f"{log_dir}/city_info.json")["city_name"]
    avm = ArgoverseMap()

    sweep_paths = sorted(glob.glob(f"{log_dir}/lidar/PC_*.ply"))
    for file_idx, ply_fpath in enumerate(sweep_paths):
        if file_idx % 500 == 0:
            print(f"\tOn file {file_idx} of {log_id}")

        # File names look like PC_<timestamp>.ply; keep the timestamp text.
        ts_str = ply_fpath.split("/")[-1].split(".")[0].split("_")[-1]
        pose_fpath = f"{log_dir}/poses/city_SE3_egovehicle_{ts_str}.json"
        if not Path(pose_fpath).exists():
            continue

        pose = read_json_file(pose_fpath)
        quat = np.array(pose["rotation"])
        trans = np.array(pose["translation"])
        city_to_egovehicle_se3 = SE3(rotation=quat2rotmat(quat), translation=trans)

        draw_ground_pts_in_image(
            sdb,
            load_ply(ply_fpath),
            city_to_egovehicle_se3,
            avm,
            log_id,
            int(ts_str),
            city_name,
            dataset_dir,
            experiment_prefix,
        )

    for camera_name in CAMERA_LIST:
        fps = 5 if "stereo" in camera_name else 10
        cmd = f"ffmpeg -r {fps} -f image2 -i '{experiment_prefix}_ground_viz/{log_id}/{camera_name}/%*.jpg' {experiment_prefix}_ground_viz/{experiment_prefix}_{log_id}_{camera_name}_{fps}fps.mp4"
        print(cmd)
        run_command(cmd)
def test_remove_extended_predecessors() -> None:
    """Test remove_extended_predecessors() for map_api"""
    avm = ArgoverseMap()
    city_name = "MIA"
    # Lane 9619209 contains xy[0]; predecessors before it should be trimmed.
    xy = np.array([[-130.0, 2315.0], [-129.0, 2315.0], [-128.0, 2315.0]])
    lane_seqs = [
        [9621385, 9619110, 9619209, 9631133],
        [9621385, 9619110, 9619209],
        [9619209, 9631133],
    ]
    expected = [[9619209, 9631133], [9619209], [9619209, 9631133]]

    result = avm.remove_extended_predecessors(lane_seqs, xy, city_name)

    assert np.array_equal(result, expected), "remove_extended_predecessors() failed!"
def get_all_lanes(city_name: str, avm: Optional[ArgoverseMap] = None) -> list:
    """Return the centerline of every lane segment in a city.

    Args:
        city_name: Key into the map's centerline index (e.g. "PIT" or "MIA").
        avm: Map API instance; a new ArgoverseMap is built when omitted.

    Returns:
        List of ``lane.centerline`` values, in the iteration order of
        ``avm.city_lane_centerlines_dict[city_name]``.
    """
    if avm is None:
        avm = ArgoverseMap()
    # The hallucinated-bbox table lookup done previously was never used.
    lane_segments = avm.city_lane_centerlines_dict[city_name]
    return [lane.centerline for lane in lane_segments.values()]
def __init__(self, split, config, train=True):
    """Set up the dataset from preprocessed arrays or from raw forecasting sequences.

    Args:
        split: Raw-data directory for ArgoverseForecastingLoader, used only when
            ``config['preprocess']`` is falsy.
        config: Configuration mapping. Reads 'preprocess', 'preprocess_train' or
            'preprocess_val', 'raster', and 'map_scale'.
        train: Selects the train or validation preprocessed file.
    """
    self.config = config
    self.train = train

    use_preprocessed = 'preprocess' in config and config['preprocess']
    if use_preprocessed:
        preprocess_key = 'preprocess_train' if train else 'preprocess_val'
        self.split = np.load(self.config[preprocess_key], allow_pickle=True)
    else:
        self.avl = ArgoverseForecastingLoader(split)
        # Sort for a deterministic sequence order.
        self.avl.seq_list = sorted(self.avl.seq_list)
        self.am = ArgoverseMap()

    if 'raster' in config and config['raster']:
        # TODO: DELETE
        self.map_query = MapQuery(config['map_scale'])
def visualize_forecating_data_on_map(args: Any) -> None:
    """Plot each forecasting log in ``args.dataset_dir`` one at a time with the map.

    Args:
        args: Namespace providing ``dataset_dir``, ``save_image`` and
            ``overwrite_rendered_file``.
    """
    print("Loading map...")
    avm = ArgoverseMap()
    visualizer = ForecastingOnMapVisualizer(
        dataset_dir=args.dataset_dir,
        save_img=args.save_image,
        overwrite_rendered_file=args.overwrite_rendered_file,
    )
    for log_num in range(visualizer.num):
        print(f"Processing the file: {visualizer.filenames[log_num]}")
        visualizer.plot_log_one_at_a_time(avm, log_num=log_num)
def __init__(self, cfg):
    """Build the vectorised map index and the forecasting loader.

    Args:
        cfg: Config mapping. Reads 'last_observe' and 'data_locate' (the path
            to the dataset folder).
    """
    super().__init__()

    # Map-side structures; each step reads attributes set by the previous ones.
    self.am = ArgoverseMap()
    self.axis_range = self.get_map_range(self.am)
    bbox_table, tableidx_to_laneid = self.am.build_hallucinated_lane_bbox_index()
    self.city_halluc_bbox_table = bbox_table
    self.city_halluc_tableidx_to_laneid_map = tableidx_to_laneid
    self.laneid_map = self.process_laneid_map()
    self.vector_map, self.extra_map = self.generate_vector_map()

    self.last_observe = cfg['last_observe']
    # cfg['data_locate'] must point at your dataset folder.
    self.root_dir = cfg['data_locate']
    self.afl = ArgoverseForecastingLoader(self.root_dir)

    self.map_feature = {"PIT": [], "MIA": []}
    self.city_name = {}
    self.center_xy = {}
    self.rotate_matrix = {}
def verify_manhattan_search_functionality():
    """Visually check lane lookup via Manhattan bounding-box search (10 trials).

    For each trial, lane segments near a query point are fetched with
    ``get_lane_ids_in_xy_bbox`` (search range 5000). Segments whose centerline
    mean lies inside ``pittsburgh_bounds`` are drawn. Each figure is saved as
    ``<trial_idx>_<datetime>.jpg`` and then shown.

    Relies on module-level ``pittsburgh_bounds`` and ``n_poly_vertices``.
    """
    # Previously bound to `adm` while the body queried `avm` (NameError).
    avm = ArgoverseMap()
    ref_query_x = 422.0
    ref_query_y = 1005.0
    city_name = "PIT"  # or 'MIA'

    for trial_idx in range(10):
        query_x = ref_query_x + (np.random.rand() - 0.5) * 10
        query_y = ref_query_y + (np.random.rand() - 0.5) * 10
        # NOTE(review): this fixed point overrides the jittered query above, so
        # every trial queries the same location. Kept to preserve behaviour.
        query_x, query_y = (3112.80160113, 1817.07585338)

        lane_segment_ids = avm.get_lane_ids_in_xy_bbox(query_x, query_y, city_name, 5000)
        fig = plt.figure(figsize=(22.5, 8))
        ax = fig.add_subplot(111)
        plot_lane_segment_patch(pittsburgh_bounds, ax, color="m", alpha=0.1)

        for lane_segment_id in lane_segment_ids:
            patch_color = "y"
            lane_centerline = avm.get_lane_segment_centerline(lane_segment_id, city_name)
            test_x, test_y = lane_centerline.mean(axis=0)
            inside = point_inside_polygon(
                n_poly_vertices, pittsburgh_bounds[:, 0], pittsburgh_bounds[:, 1], test_x, test_y)
            if inside:
                halluc_lane_polygon = avm.get_lane_segment_polygon(lane_segment_id, city_name)
                xmin, ymin, xmax, ymax = find_lane_segment_bounds_in_table(
                    avm, city_name, lane_segment_id)
                add_lane_segment_to_ax(
                    ax, lane_centerline, halluc_lane_polygon, patch_color, xmin, xmax, ymin, ymax)

        ax.axis("equal")
        # Save before show(): after show() returns, the figure may be blank.
        datetime_str = generate_datetime_string()
        plt.savefig(f"{trial_idx}_{datetime_str}.jpg")
        plt.show()
        plt.close("all")
def __init__(self, data_dir, obs_len=20, position_downscaling_factor=100):
    """Initialise the dataset over a directory of forecasting trajectories.

    Args:
        data_dir: Directory with all trajectories.
        obs_len: Length of observed trajectory.
        position_downscaling_factor: Stored as ``self.position_downscaling_factor``.

    Raises:
        AssertionError: if ``data_dir`` is not an existing directory.
    """
    self.data_dir = data_dir
    self.obs_len = obs_len
    self.position_downscaling_factor = position_downscaling_factor
    # Explicit check instead of `assert` so it still runs under `python -O`.
    # AssertionError is kept so existing callers that catch it still work.
    if not os.path.isdir(data_dir):
        raise AssertionError('Invalid Data Directory')
    self.afl = ArgoverseForecastingLoader(data_dir)
    self.avm = ArgoverseMap()
def __init__(self, root_dir='argoverse-data//data', avm=None, train_seq_size=20,
             mode="train", save=False, load_saved=False):
    """Initialise the multi-lane dataset with map and feature helpers.

    Args:
        root_dir: Dataset root, forwarded to the parent class.
        avm: ArgoverseMap to reuse; built when None.
        train_seq_size: Forwarded to the parent class.
        mode: Stored as ``self.mode``.
        save: Stored as ``self.save``.
        load_saved: Stored as ``self.load_saved``.
    """
    super(Argoverse_MultiLane_Data, self).__init__(root_dir, train_seq_size)
    self.avm = avm if avm is not None else ArgoverseMap()
    self.map_features_utils_instance = MapFeaturesUtils()
    self.social_features_utils_instance = SocialFeaturesUtils()
    self.mode = mode
    self.save = save
    self.load_saved = load_saved
def plot_nearby_halluc_lanes(
    ax: plt.Axes,
    city_name: str,
    avm: ArgoverseMap,
    query_x: float,
    query_y: float,
    patch_color: str = "r",
    radius: float = 20.0,
) -> None:
    """Produce lane segment graphs for visual verification.

    Draws every lane polygon returned by ``get_lane_ids_in_xy_bbox`` around
    (query_x, query_y). Each polygon is labelled with its lane id at the mean
    of its vertices.
    """
    for lane_id in avm.get_lane_ids_in_xy_bbox(query_x, query_y, city_name, radius):
        polygon = avm.get_lane_segment_polygon(lane_id, city_name)
        plot_lane_segment_patch(polygon, ax, color=patch_color, alpha=0.3)
        label_x = polygon[:, 0].mean()
        label_y = polygon[:, 1].mean()
        plt.text(label_x, label_y, str(lane_id))
def get_pruned_guesses(
    forecasted_trajectories: Dict[int, List[np.ndarray]],
    city_names: Dict[int, str],
    gt_trajectories: Dict[int, np.ndarray],
    prune_n_guesses: Optional[int] = None,
    avm: Optional["ArgoverseMap"] = None,
) -> Dict[int, List[np.ndarray]]:
    """Prune the number of guesses using map.

    Each forecasted trajectory is scored by how many of its points fall on the
    "driveable_area" raster layer. The ``prune_n_guesses`` highest-scoring
    trajectories are kept.

    Args:
        forecasted_trajectories: Trajectories forecasted by the algorithm.
        city_names: Dict mapping sequence id to city name.
        gt_trajectories: Ground Truth trajectories (not used for pruning).
        prune_n_guesses: Guesses to keep per sequence. When None, falls back to
            ``parse_arguments().prune_n_guesses`` (the original behaviour).
        avm: Map API instance to reuse; a new ArgoverseMap is built when None.

    Returns:
        Pruned number of forecasted trajectories, highest score first.
    """
    if prune_n_guesses is None:
        # Only parse the command line when the caller did not pass a value.
        prune_n_guesses = parse_arguments().prune_n_guesses
    if avm is None:
        avm = ArgoverseMap()

    pruned_guesses = {}
    for seq_id, trajectories in forecasted_trajectories.items():
        city_name = city_names[seq_id]
        da_points = [
            np.sum(avm.get_raster_layer_points_boolean(trajectory, city_name, "driveable_area"))
            for trajectory in trajectories
        ]
        sorted_idx = np.argsort(da_points)[::-1]
        pruned_guesses[seq_id] = [trajectories[i] for i in sorted_idx[:prune_n_guesses]]
    return pruned_guesses
def __init__(self, root, train=True, test=False):
    """Load forecasting data from ``root`` and pick the train, validation or test subset.

    train_data and test_data live at separate paths. In test mode the whole
    loader is kept. Otherwise the first 70% is the training split and the
    remaining 30% is the validation split.
    """
    self.test = test
    afl = ArgoverseForecastingLoader(root)
    self.avm = ArgoverseMap()
    if self.test:
        self.afl = afl
        return
    cut = int(0.7 * len(afl))
    self.afl = afl[:cut] if train else afl[cut:]
def __init__(self, split, config, train=False):
    """Set up the validation/test dataset.

    Args:
        split: 'val' selects the validation split; any other value selects test.
        config: Config mapping. Reads 'val_split'/'test_split' (raw data
            directories), 'preprocess', and 'preprocess_val'/'preprocess_test'
            (preprocessed files, read only when 'preprocess' is truthy).
        train: Stored as ``self.train``.
    """
    self.config = config
    self.train = train
    is_val = split == 'val'

    data_dir = config['val_split'] if is_val else config['test_split']
    self.avl = ArgoverseForecastingLoader(data_dir)

    if 'preprocess' in config and config['preprocess']:
        # The train/else branches were identical; only the split matters.
        preprocess_path = self.config['preprocess_val'] if is_val else self.config['preprocess_test']
        self.split = np.load(preprocess_path, allow_pickle=True)
    else:
        # Previously `split` had been overwritten with the *preprocess file path*
        # and a second loader was built from it. The raw-data loader above is
        # the correct one.
        self.am = ArgoverseMap()
def draw_lane_ids(lane_ids: List[int], am: ArgoverseMap, ax: Axes, city_name: str) -> None:
    """Label each lane's centerline with its id near both ends.

    Writes ``s_<id>`` at the third centerline point and ``e_<id>`` at the
    third-from-last point.

    Args:
        lane_ids: Lane ids to label (converted with ``int`` for the map query).
        am: Map API used to fetch centerlines.
        ax: Axes to draw on.
        city_name: City the lanes belong to.
    """
    for lane_id in lane_ids:
        pts = am.get_lane_segment_centerline(int(lane_id), city_name)
        head = pts[2]
        tail = pts[-3]
        ax.text(head[0], head[1], f"s_{lane_id}")
        ax.text(tail[0], tail[1], f"e_{lane_id}")
def build_city_lane_graphs(
        am: ArgoverseMap) -> Mapping[str, Mapping[int, List[int]]]:
    """Build a directed lane-connectivity graph for the MIA and PIT cities.

    Edges run from a lane to each left/right neighbour that
    ``lanes_travel_same_direction`` accepts, from each predecessor to the lane,
    and from the lane to each successor. Node ids are stored as strings. Each
    adjacency list is de-duplicated and sorted.

    Args:
        am: Map API providing ``build_centerline_index``.

    Returns:
        Dict city name -> {lane id (str): sorted neighbour ids (str)}.
    """
    centerline_index = am.build_centerline_index()
    graphs = {}
    for city in ["MIA", "PIT"]:
        adjacency = {}

        def _add_edge(src, dst):
            adjacency.setdefault(str(src), []).append(str(dst))

        for lane_id, segment in centerline_index[city].items():
            # allow left/right lane changes
            for neighbour_id in (segment.l_neighbor_id, segment.r_neighbor_id):
                if neighbour_id and lanes_travel_same_direction(lane_id, neighbour_id, am, city):
                    _add_edge(lane_id, neighbour_id)
            for pred_id in segment.predecessors or ():
                _add_edge(pred_id, lane_id)
            for succ_id in segment.successors or ():
                _add_edge(lane_id, succ_id)

        graphs[city] = {node: sorted(set(nbrs)) for node, nbrs in adjacency.items()}
    return graphs
def __init__(self, tracking_dataset_dir, dataset_name=None, argoverse_map=None,
             argoverse_loader=None, save_imgs=False):
    """Wrap an Argoverse tracking dataset and precompute valid target objects.

    Args:
        tracking_dataset_dir: Root of the tracking dataset.
        dataset_name: Stored as ``self.dataset_prefix_name``.
        argoverse_map: ArgoverseMap to reuse; built when None.
        argoverse_loader: ArgoverseTrackingLoader to reuse; built when None.
        save_imgs: Forwarded to ``_get_valid_target_objects``.
    """
    # Raise the root logger's threshold to CRITICAL (affects all loggers globally).
    logging.getLogger().setLevel(logging.CRITICAL)

    self.dataset_dir = tracking_dataset_dir
    if argoverse_map is None:
        argoverse_map = ArgoverseMap()
    self.am = argoverse_map
    if argoverse_loader is None:
        argoverse_loader = ArgoverseTrackingLoader(tracking_dataset_dir)
    self.argoverse_loader = argoverse_loader
    self.dataset_prefix_name = dataset_name

    self.objects_from_to = self._get_objects_from_to()
    self.valid_target_objects = self._get_valid_target_objects(save_imgs=save_imgs)
def __init__(self, root_dir='argoverse-data/forecasting_sample/data', train_seq_size=20,
             mode="train", save=False, load_saved=False, avm=None):
    """Initialise the social + centerline dataset with map and feature helpers.

    Args:
        root_dir: Dataset root, forwarded to the parent class.
        train_seq_size: Forwarded to the parent class.
        mode: Stored as ``self.mode``.
        save: Stored as ``self.save``.
        load_saved: Stored as ``self.load_saved``.
        avm: ArgoverseMap to reuse; built when None.
    """
    super(Argoverse_Social_Centerline_Data, self).__init__(root_dir, train_seq_size)
    self.avm = avm if avm is not None else ArgoverseMap()
    self.map_features_utils_instance = MapFeaturesUtils()
    self.social_features_utils_instance = SocialFeaturesUtils()
    self.save = save
    self.mode = mode
    self.load_saved = load_saved
def get_point_in_polygon_score(self, lane_seq: List[int], xy_seq: np.ndarray, city_name: str,
                               avm: ArgoverseMap) -> int:
    """Get the number of coordinates that lie inside the lane sequence polygon.

    The lane polygons are each cleaned with ``buffer(0)`` and merged with
    ``cascaded_union`` before testing containment.

    Args:
        lane_seq: Sequence of lane ids.
        xy_seq: Trajectory coordinates.
        city_name: City name (PIT/MIA).
        avm: Argoverse map_api instance.

    Returns:
        Number of coordinates in the trajectory that lie within the lane sequence.
    """
    lane_polygons = [
        Polygon(avm.get_lane_segment_polygon(lane_id, city_name)).buffer(0)
        for lane_id in lane_seq
    ]
    merged_polygon = cascaded_union(lane_polygons)
    return sum(merged_polygon.contains(Point(point)) for point in xy_seq)
def test_filter_objs_to_roi():
    """Use the map to filter out an object that lies outside the ROI in a parking lot.

    Of two vehicle labels, only the one inside the ROI should survive.
    """
    avm = ArgoverseMap()
    lidar_ts = 315970611820366000
    log_city_name = "PIT"
    log_id = "21e37598-52d4-345c-8ef9-03ae19615d3d"
    dataset_dir = TEST_DATA_LOC / "roi_based_test"

    # expected to be dropped: outside the ROI
    outside_obj = {
        "center": {"x": -14.102872067388489, "y": 19.466695178746022, "z": 0.11740010190455852},
        "rotation": {"x": 0.0, "y": 0.0, "z": -0.038991328555453404, "w": 0.9992395490058831},
        "length": 4.56126567460171,
        "width": 1.9370055686754908,
        "height": 1.5820081349372281,
        "track_label_uuid": "03a321bf955a4d7781682913884abf06",
        "timestamp": 315970611820366000,
        "label_class": "VEHICLE",
    }
    # expected to be kept: inside the ROI
    inside_obj = {
        "center": {"x": -20.727430239506702, "y": 3.4488006757501353, "z": 0.4036619561689685},
        "rotation": {"x": 0.0, "y": 0.0, "z": 0.0013102003738908123, "w": 0.9999991416871218},
        "length": 4.507580779458834,
        "width": 1.9243189627993598,
        "height": 1.629934978730058,
        "track_label_uuid": "bb0f40e4f68043e285d64a839f2f092c",
        "timestamp": 315970611820366000,
        "label_class": "VEHICLE",
    }

    city_SE3_egovehicle = get_city_SE3_egovehicle_at_sensor_t(lidar_ts, dataset_dir, log_id)
    records = [json_label_dict_to_obj_record(label) for label in (outside_obj, inside_obj)]
    dts_filtered = filter_objs_to_roi(np.array(records), avm, city_SE3_egovehicle, log_city_name)

    assert dts_filtered.size == 1
    assert dts_filtered.dtype == "O"  # array of objects
    assert isinstance(dts_filtered, np.ndarray)
    assert dts_filtered[0].track_id == "bb0f40e4f68043e285d64a839f2f092c"
def __init__(self, file_path: str, shuffle: bool = True, random_rotation: bool = False,
             max_car_num: int = 50, freq: int = 10, use_interpolate: bool = False,
             use_lane: bool = False, use_mask: bool = True):
    """Initialise a forecasting dataset backed by ArgoverseForecastingLoader.

    Args:
        file_path: Directory handed to ArgoverseForecastingLoader.
        shuffle, random_rotation, max_car_num, freq, use_interpolate, use_lane,
        use_mask: Stored on the instance under the same names.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist. It subclasses
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError("Path does not exist.")

    self.afl = ArgoverseForecastingLoader(file_path)
    self.shuffle = shuffle
    self.random_rotation = random_rotation
    self.max_car_num = max_car_num
    self.freq = freq
    self.use_interpolate = use_interpolate
    # Previously accepted but silently dropped.
    self.use_lane = use_lane
    self.am = ArgoverseMap()
    self.use_mask = use_mask
    self.file_path = file_path
def save_all_to_pickle(
        datasets=("train1", "train2", "train3", "train4"),
        base_dir='/media/bartosz/hdd1TB/workspace_hdd/datasets/argodataset/argoverse-tracking/',
        out_path="/media/bartosz/hdd1TB/workspace_hdd/SS-LSTM/data/argoverse/train1234_48x48.pickle"):
    """Collect valid target objects from several tracking splits into one pickle.

    Args:
        datasets: Split directory names appended to ``base_dir``.
        base_dir: Root of the tracking dataset (string-concatenated with each split).
        out_path: Output pickle file (written with protocol 2).
    """
    final_dict = {}
    # ArgoverseMap() takes no dataset-specific arguments, so build it once
    # instead of once per split.
    am = ArgoverseMap()
    for dataset in datasets:
        tracking_dataset_dir = base_dir + dataset
        argoverse_loader = ArgoverseTrackingLoader(tracking_dataset_dir)
        argoverse = Argoverse(tracking_dataset_dir=tracking_dataset_dir,
                              dataset_name=dataset,
                              argoverse_map=am,
                              argoverse_loader=argoverse_loader)
        final_dict.update(argoverse.valid_target_objects)
        print("Processed {}".format(dataset))

    # `with` guarantees the file is closed even if pickling fails.
    with open(out_path, "wb") as pickle_out:
        pickle.dump(final_dict, pickle_out, protocol=2)
    print("Saved to pickle {}".format(out_path))
def __init__(self, root, train=True, test=False):
    """Load data from ``root`` and record the [start, end) index range for the split.

    train_data and test_data live at separate paths. In test mode the whole
    directory is used. Otherwise the first 70% of entries is the training split
    and the rest is the validation split. ``end`` is exclusive (test mode uses
    ``end = n``).
    """
    self.test = test
    self.train = train
    self.afl = ArgoverseForecastingLoader(root)
    self.avm = ArgoverseMap()

    n = len(os.listdir(Path(root)))
    split_idx = int(0.7 * n)
    if self.test:
        self.start, self.end = 0, n
    elif self.train:
        self.start, self.end = 0, split_idx
    else:
        # Validation begins exactly where training ends. The old `+ 1` dropped
        # the sample at `split_idx` from both splits.
        self.start, self.end = split_idx, n
def __init__(self, data_dict: Dict[str, Any], args: Any, mode: str,
             base_dir="/work/vita/sadegh/argo/argoverse-api/",
             use_history=True, use_agents=True, use_scene=True):
    """Initialize the Dataset.

    Args:
        data_dict: Dict containing all the data (reads '<mode>_input' and, except
            in test mode, '<mode>_output').
        args: Arguments passed to the baseline code.
        mode: train/val/test mode.
        base_dir: Prefix of the raw data directory (string-concatenated).
        use_history, use_agents, use_scene: Stored as instance flags.
    """
    self.data_dict = data_dict
    self.args = args
    self.mode = mode
    self.use_history = use_history
    self.use_agents = use_agents
    self.use_scene = use_scene

    self.input_data = data_dict["{}_input".format(mode)]
    if mode != "test":
        self.output_data = data_dict["{}_output".format(mode)]
    self.data_size = self.input_data.shape[0]

    # Transpose helpers from per-field sequences into per-sample tuples.
    self.helpers = list(zip(*self.get_helpers()))

    # Raw sequences: <base_dir><mode>/data, with test data stored under "test_obs".
    middle_dir = "test_obs" if mode == "test" else mode
    self.root_dir = base_dir + middle_dir + "/data"
    self.afl = ArgoverseForecastingLoader(self.root_dir)
    self.avm = ArgoverseMap()
    self.mf = MapFeaturesUtils()
def verify_point_in_polygon_for_lanes():
    """Visually check point-in-lane lookup around a fixed MIA location (10 trials).

    Each trial jitters the query point by up to +/-5 in x and y. Lanes that
    contain the point are drawn in yellow. Other lanes returned by the bbox
    search are drawn in red.
    """
    avm = ArgoverseMap()
    ref_query_x = -662
    ref_query_y = 2817
    city_name = "MIA"

    for _ in range(10):
        query_x = ref_query_x + (np.random.rand() - 0.5) * 10
        query_y = ref_query_y + (np.random.rand() - 0.5) * 10

        fig = plt.figure(figsize=(22.5, 8))
        ax = fig.add_subplot(111)
        ax.scatter([query_x], [query_y], 100, color="k", marker=".")

        occupied = avm.get_lane_segments_containing_xy(query_x, query_y, city_name)
        nearby_only = set(avm.get_lane_ids_in_xy_bbox(query_x, query_y, city_name)) - set(occupied)

        for lane_ids, color in ((occupied, "y"), (nearby_only, "r")):
            for lane_id in lane_ids:
                polygon = avm.get_lane_segment_polygon(lane_id, city_name)
                plot_lane_segment_patch(polygon, ax, color=color, alpha=0.3)

        ax.axis("equal")
        plt.show()
        plt.close("all")
def evaluation():
    """Evaluate the saved model on the validation set and pickle its metrics.

    Uses module-level names defined elsewhere: ``val_path``, ``args``,
    ``model_name``, ``device``, ``read_pkl_data`` and ``evaluate``. Metrics are
    written to ``results/<model_name>_predictions.pickle``.
    """
    import os

    # An unused ArgoverseMap() used to be built here (an expensive load).
    val_dataset = read_pkl_data(val_path, batch_size=args.val_batch_size,
                                shuffle=False, repeat=False)

    trained_model = torch.load(model_name + '.pth')
    trained_model.eval()

    with torch.no_grad():
        _valid_total_loss, valid_metrics = evaluate(
            trained_model, val_dataset,
            train_window=args.train_window, max_iter=len(val_dataset),
            device=device, start_iter=args.val_batches,
            use_lane=args.use_lane, batch_size=args.val_batch_size)

    # Ensure the output directory exists so the final write cannot fail on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    with open('results/{}_predictions.pickle'.format(model_name), 'wb') as f:
        pickle.dump(valid_metrics, f)
def filter_objs_to_roi(
    instances: np.ndarray, avm: ArgoverseMap, city_SE3_egovehicle: SE3, city_name: str
) -> np.ndarray:
    """Filter objects to the region of interest (5 meter dilation of driveable area).

    We ignore instances outside of region of interest (ROI) during evaluation.

    Args:
        instances: Numpy array of shape (N,) with ObjectLabelRecord entries
        avm: Argoverse map object
        city_SE3_egovehicle: pose of egovehicle within city map at time of sweep
        city_name: name of city where log was captured

    Returns:
        instances_roi: objects with any of 4 cuboid corners located within ROI
            (an empty array when ``instances`` is empty)
    """
    if len(instances) == 0:
        # np.vstack([]) raises ValueError; there is nothing to filter.
        return instances.copy()

    # for each cuboid, get its 4 corners in the egovehicle frame
    corners_egoframe = np.vstack([dt.as_2d_bbox() for dt in instances])
    corners_cityframe = city_SE3_egovehicle.transform_point_cloud(corners_egoframe)
    corner_within_roi = avm.get_raster_layer_points_boolean(corners_cityframe, city_name, "roi")
    # check for each cuboid if any of its 4 corners lies within the ROI
    is_within_roi = corner_within_roi.reshape(-1, 4).any(axis=1)
    return instances[is_within_roi]
def viz_predictions(
    input_: np.ndarray,
    output: np.ndarray,
    target: np.ndarray,
    centerlines: np.ndarray,
    city_names: np.ndarray,
    idx=None,
    show: bool = True,
) -> None:
    """Visualize predicted trajectories.

    Args:
        input_ (numpy array): Input Trajectory with shape (num_tracks x obs_len x 2)
        output (numpy array of list): Top-k predicted trajectories, each with shape (num_tracks x pred_len x 2)
        target (numpy array): Ground Truth Trajectory with shape (num_tracks x pred_len x 2)
        centerlines (numpy array of list of centerlines): Centerlines (Oracle/Top-k) for each trajectory
        city_names (numpy array): city names for each trajectory
        idx: Unused; kept for interface compatibility.
        show (bool): if True, show

    For every observed and target point, nearby lanes (Manhattan range 2.5) are
    drawn with ``ArgoverseMap.draw_lane``.
    """
    num_tracks = input_.shape[0]
    obs_len = input_.shape[1]
    pred_len = target.shape[1]

    plt.figure(0, figsize=(8, 7))
    avm = ArgoverseMap()
    for i in range(num_tracks):
        city_name = city_names[i]
        plt.plot(input_[i, :, 0], input_[i, :, 1], color="#ECA154", label="Observed",
                 alpha=1, linewidth=3, zorder=15)
        plt.plot(input_[i, -1, 0], input_[i, -1, 1], "o", color="#ECA154", label="Observed",
                 alpha=1, linewidth=3, zorder=15, markersize=9)
        plt.plot(target[i, :, 0], target[i, :, 1], color="#d33e4c", label="Target",
                 alpha=1, linewidth=3, zorder=20)
        plt.plot(target[i, -1, 0], target[i, -1, 1], "o", color="#d33e4c", label="Target",
                 alpha=1, linewidth=3, zorder=20, markersize=9)

        for centerline in centerlines[i]:
            plt.plot(centerline[:, 0], centerline[:, 1], "--", color="grey",
                     alpha=1, linewidth=1, zorder=0)

        # Removed: a per-point lane lookup for predicted points whose result was
        # never used (it was overwritten before any draw).
        for prediction in output[i]:
            plt.plot(prediction[:, 0], prediction[:, 1], color="#007672", label="Predicted",
                     alpha=1, linewidth=3, zorder=15)
            plt.plot(prediction[-1, 0], prediction[-1, 1], "o", color="#007672",
                     label="Predicted", alpha=1, linewidth=3, zorder=15, markersize=9)

        # Draw lanes near every observed and every target point.
        for points in (input_[i, :obs_len], target[i, :pred_len]):
            for x, y in points[:, :2]:
                lane_ids = avm.get_lane_ids_in_xy_bbox(
                    x, y, city_name, query_search_range_manhattan=2.5)
                for lane_id in lane_ids:
                    avm.draw_lane(lane_id, city_name)

        plt.axis("equal")
        plt.xticks([])
        plt.yticks([])
        if show:
            plt.show()