def BuildVoxelNet():
    config_path = Path('second.pytorch/second/configs/xyres_16.proto')
    ckpt_path = Path('second.pytorch/second/voxelnet-331653.tckpt')
    inference_ctx = TorchInferenceContext()
    inference_ctx.build(config_path)
    inference_ctx.restore(ckpt_path)
    return inference_ctx
def BuildVoxelNet():
    config_path = Path(
        '/home/lucerna/MEGAsync/project/AVP/second/configs/xyres_16.proto')
    ckpt_path = Path(
        '/home/lucerna/MEGAsync/project/AVP/second/voxelnet-331653.tckpt')
    inference_ctx = TorchInferenceContext()
    inference_ctx.build(config_path)
    inference_ctx.restore(ckpt_path)
    return inference_ctx
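# Minimal usage sketch for the builders above. The input-preparation and
# inference calls mirror those used in the classes later in this file;
# `calib_info` and `points` are hypothetical variables the caller must supply
# (a KITTI calib dict and an (N, 4) float32 point cloud).
ctx = BuildVoxelNet()
inputs = ctx.get_inference_input_dict(calib_info, points)
with ctx.ctx():
    det_annos = ctx.inference(inputs)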
# Flask endpoint; assumes flask's `request`/`jsonify`, a module-level BACKEND
# object, and an `error_response` helper from the surrounding app.
def build_network():
    global BACKEND
    instance = request.json
    cfg_path = Path(instance["config_path"])
    ckpt_path = Path(instance["checkpoint_path"])
    response = {"status": "normal"}
    if BACKEND.root_path is None:
        return error_response("root path is not set")
    # if BACKEND.kitti_infos is None:
    #     return error_response("kitti info is not loaded")
    if not cfg_path.exists():
        return error_response("config file does not exist.")
    if not ckpt_path.exists():
        return error_response("ckpt file does not exist.")
    BACKEND.inference_ctx = TorchInferenceContext()
    BACKEND.inference_ctx.build(str(cfg_path))
    BACKEND.inference_ctx.restore(str(ckpt_path))
    response = jsonify(results=[response])
    response.headers['Access-Control-Allow-Headers'] = '*'
    print("build_network successful!")
    return response
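# Hedged client sketch for the Flask endpoint above. The route and port are
# assumptions (the kittiviewer backend conventionally serves /api/build_network
# on port 16666); the paths are borrowed from the first snippet in this file.
import requests

resp = requests.post(
    "http://127.0.0.1:16666/api/build_network",
    json={
        "config_path": "second.pytorch/second/configs/xyres_16.proto",
        "checkpoint_path": "second.pytorch/second/voxelnet-331653.tckpt",
    })
print(resp.json())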
def _build(self):
    print("Start build...")
    self.inference_ctx = TorchInferenceContext()
    self.inference_ctx.build(self.config_path)
    print("Build succeeded.")
# Requires numpy as np, time, and these helpers from the second.pytorch
# codebase: get_kitti_image_info, box_np_ops, TorchInferenceContext.
class SecondModel:
    def __init__(self, data_path, config_path, ckpt_path, calib_idx=0):
        self.data_path = data_path
        self.config_path = config_path
        self.ckpt_path = ckpt_path
        self.calib_idx = calib_idx
        self.calib_info = None
        self.inference_ctx = None

    def initialize(self):
        image_infos = get_kitti_image_info(
            self.data_path,
            training=True,
            label_info=False,
            calib=True,
            image_ids=[self.calib_idx])
        self.calib_info = image_infos[0]
        self._build()
        self._restore()

    def predict(self, pointclouds):  # fixed typo: was `predcit`
        t = time.time()
        result_annos = self._inference(pointclouds)
        print("Inference time: {} ms".format(int((time.time() - t) * 1000)))
        kitti_anno = self.remove_low_score(result_annos[0])
        lidar_boxes = self.kitti_cam_to_lidar(kitti_anno)
        return lidar_boxes

    def _build(self):
        print("Start build...")
        self.inference_ctx = TorchInferenceContext()
        self.inference_ctx.build(self.config_path)
        print("Build succeeded.")

    def _restore(self):
        print("Start restore...")
        self.inference_ctx.restore(self.ckpt_path)
        print("Restore succeeded.")

    def _inference(self, pointclouds):
        inputs = self.inference_ctx.get_inference_input_dict(
            self.calib_info, pointclouds)
        det_annos = self.inference_ctx.inference(inputs)
        return det_annos

    def kitti_cam_to_lidar(self, kitti_anno):
        # Convert camera-frame KITTI boxes to the LiDAR frame.
        rect = self.calib_info['calib/R0_rect']
        Tr_velo_to_cam = self.calib_info['calib/Tr_velo_to_cam']
        dims = kitti_anno['dimensions']
        loc = kitti_anno['location']
        rots = kitti_anno['rotation_y']
        boxes_camera = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                      axis=1)
        boxes_lidar = box_np_ops.box_camera_to_lidar(boxes_camera, rect,
                                                     Tr_velo_to_cam)
        return boxes_lidar

    def remove_low_score(self, annos, threshold=0.5):
        # Keep only detections whose score is at or above the threshold.
        img_filtered_annotations = {}
        relevant_annotation_indices = [
            i for i, s in enumerate(annos['score']) if s >= threshold
        ]
        for key in annos.keys():
            img_filtered_annotations[key] = annos[key][
                relevant_annotation_indices]
        return img_filtered_annotations
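# Usage sketch for SecondModel (hypothetical paths; KITTI .bin scans store
# float32 x, y, z, reflectance quadruples).
import numpy as np

model = SecondModel(
    data_path="/path/to/kitti",  # hypothetical
    config_path="second.pytorch/second/configs/xyres_16.proto",
    ckpt_path="second.pytorch/second/voxelnet-331653.tckpt")
model.initialize()
points = np.fromfile("000000.bin", dtype=np.float32).reshape(-1, 4)
lidar_boxes = model.predict(points)
print(lidar_boxes.shape)  # (N, 7) center/size/yaw boxes in the LiDAR frame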
def build_vxnet(self):
    print("Start build_vxnet...")
    self.inference_ctx = TorchInferenceContext()
    self.inference_ctx.build(self.config_path)
    self.json_setting.set("latest_vxnet_cfg_path", self.config_path)
    print("Build VoxelNet ckpt succeeded.")
# Requires numpy as np and time, plus Settings, remove_low_score,
# kitti_anno_to_corners, and TorchInferenceContext from the surrounding project.
class Processor_ROS:
    def __init__(self, calib_path, config_path, ckpt_path):
        self.points = None
        self.json_setting = Settings(str('/home/hradt/' + ".kittiviewerrc"))
        # self.config_path = self.json_setting.get("latest_vxnet_cfg_path", "")
        self.calib_path = calib_path
        self.config_path = config_path
        self.ckpt_path = ckpt_path
        self.calib_info = None
        self.inputs = None
        self.inference_ctx = None

    def initialize(self):
        self.read_calib()
        self.build_vxnet()
        self.load_vxnet()

    def run(self, points):
        num_features = 4
        rect = self.calib_info['calib/R0_rect']
        P2 = self.calib_info['calib/P2']
        Trv2c = self.calib_info['calib/Tr_velo_to_cam']
        image_shape = self.calib_info['img_shape']
        self.points = points.reshape([-1, num_features])
        # self.points = box_np_ops.remove_outside_points(
        #     self.points, rect, Trv2c, P2, image_shape)
        [results] = self.inference_vxnet()
        results = remove_low_score(results, 0.5)
        dt_boxes_corners, scores, dt_box_lidar = kitti_anno_to_corners(
            self.calib_info, results)
        print("dt_box_lidar: ", dt_box_lidar)
        return dt_boxes_corners, scores, dt_box_lidar

    def _extend_matrix(self, mat):
        # Extend a 3x4 matrix to homogeneous 4x4 form.
        mat = np.concatenate([mat, np.array([[0., 0., 0., 1.]])], axis=0)
        return mat

    def read_calib(self, extend_matrix=True):
        print("Start read_calib...")
        calib_info = {'calib_path': self.calib_path}
        with open(self.calib_path, 'r') as f:
            lines = f.readlines()
        P0 = np.array([float(info)
                       for info in lines[0].split(' ')[1:13]]).reshape([3, 4])
        P1 = np.array([float(info)
                       for info in lines[1].split(' ')[1:13]]).reshape([3, 4])
        P2 = np.array([float(info)
                       for info in lines[2].split(' ')[1:13]]).reshape([3, 4])
        P3 = np.array([float(info)
                       for info in lines[3].split(' ')[1:13]]).reshape([3, 4])
        if extend_matrix:
            P0 = self._extend_matrix(P0)
            P1 = self._extend_matrix(P1)
            P2 = self._extend_matrix(P2)
            P3 = self._extend_matrix(P3)
        # calib_info['calib/P0'] = P0
        # calib_info['calib/P1'] = P1
        calib_info['calib/P2'] = P2
        # calib_info['calib/P3'] = P3
        R0_rect = np.array([
            float(info) for info in lines[4].split(' ')[1:10]
        ]).reshape([3, 3])
        if extend_matrix:
            rect_4x4 = np.zeros([4, 4], dtype=R0_rect.dtype)
            rect_4x4[3, 3] = 1.
            rect_4x4[:3, :3] = R0_rect
        else:
            rect_4x4 = R0_rect
        calib_info['calib/R0_rect'] = rect_4x4
        Tr_velo_to_cam = np.array([
            float(info) for info in lines[5].split(' ')[1:13]
        ]).reshape([3, 4])
        Tr_imu_to_velo = np.array([
            float(info) for info in lines[6].split(' ')[1:13]
        ]).reshape([3, 4])
        if extend_matrix:
            Tr_velo_to_cam = self._extend_matrix(Tr_velo_to_cam)
            Tr_imu_to_velo = self._extend_matrix(Tr_imu_to_velo)
        calib_info['calib/Tr_velo_to_cam'] = Tr_velo_to_cam
        # calib_info['calib/Tr_imu_to_velo'] = Tr_imu_to_velo
        # add image shape info for lidar point cloud preprocessing
        calib_info["img_shape"] = np.array(
            [375, 1242])  # KITTI image size: height, width
        self.calib_info = calib_info
        print("Read calib file succeeded.")

    def build_vxnet(self):
        print("Start build_vxnet...")
        self.inference_ctx = TorchInferenceContext()
        self.inference_ctx.build(self.config_path)
        self.json_setting.set("latest_vxnet_cfg_path", self.config_path)
        print("Build VoxelNet ckpt succeeded.")

    def load_vxnet(self):
        print("Start load_vxnet...")
        self.json_setting.set("latest_vxnet_ckpt_path", self.ckpt_path)
        # was a bare `ckpt_path`, which raises NameError here
        self.inference_ctx.restore(self.ckpt_path)
        print("Load VoxelNet ckpt succeeded.")

    def inference_vxnet(self):
        print("Start inference_vxnet...")
        t = time.time()
        self.inputs = self.inference_ctx.get_inference_input_dict_ros(
            self.calib_info, self.points)
        print("input preparation time:", time.time() - t)
        # print("self.inputs['points'] shape: ", self.inputs["points"].shape)
        t = time.time()
        with self.inference_ctx.ctx():
            det_annos = self.inference_ctx.inference(self.inputs)
        print("detection time:", time.time() - t)
        return det_annos
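# Usage sketch: run one LiDAR scan through the ROS processor above
# (hypothetical file paths; `points` is a flat float32 buffer with four
# features per point, as run() expects).
import numpy as np

proc = Processor_ROS(
    calib_path="/path/to/kitti/calib/000000.txt",  # hypothetical
    config_path="second.pytorch/second/configs/xyres_16.proto",
    ckpt_path="second.pytorch/second/voxelnet-331653.tckpt")
proc.initialize()
points = np.fromfile("/path/to/scan.bin", dtype=np.float32)  # hypothetical
corners, scores, boxes_lidar = proc.run(points)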
# (snippet begins mid-file: config_path and info_path are defined just above
# in the original source)
ckpt_path = Path(
    '/home/dongwonshin/Desktop/second.pytorch_multiclass/pretrained_models/original_model/voxelnet-74240.tckpt'
)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
publish_topic_name = '/inference_results'
lidar_frame_step = 1
point_sampling_step = 1

if __name__ == '__main__':
    # Network model load
    with open(info_path, 'rb') as f:
        kitti_infos = pickle.load(f)
    inference_ctx = TorchInferenceContext()
    inference_ctx.build(config_path)
    inference_ctx.restore(ckpt_path)

    # publisher init; ROS node names may not contain spaces
    # (original used 'SECOND network pub_example', which rospy rejects)
    rospy.init_node('second_network_pub_example')
    pcl_pub = rospy.Publisher(publish_topic_name, PointCloud2, queue_size=1)
    rospy.sleep(1.)
    rate = rospy.Rate(3)

    # for lidar_file_path in lidar_file_paths[1::lidar_frame_step]:
    idx = 0
    while not rospy.is_shutdown():
        all_points = []
        for flip in [True, False]:
            ...  # remainder of the loop is truncated in the original snippet
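# Hedged sketch of one way to publish `all_points` as a PointCloud2 once the
# (truncated) accumulation loop has filled it. create_cloud_xyz32 is the
# standard sensor_msgs helper; the frame_id is an assumption.
import numpy as np
import sensor_msgs.point_cloud2 as pc2
from std_msgs.msg import Header

header = Header()
header.stamp = rospy.Time.now()
header.frame_id = "velodyne"  # assumption: LiDAR frame name
cloud_msg = pc2.create_cloud_xyz32(header, np.asarray(all_points)[:, :3])
pcl_pub.publish(cloud_msg)
rate.sleep()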
class KittiViewer(QMainWindow):
    def __init__(self):
        super().__init__()
        self.title = 'F110 AVP Viewer'
        self.bbox_window = [10, 10, 1600, 900]
        self.sstream = sysio.StringIO()
        self.json_setting = Settings(str(Path.home() / ".kittiviewerrc"))
        self.kitti_infos = None
        self.detection_annos = None
        self.image_idxes = None
        self.root_path = None
        self.current_idx = 0
        self.dt_image_idxes = None
        self.current_image = None
        self.init_ui()
        self.kitti_info = None
        self.points = None
        self.gt_boxes = None
        self.gt_names = None
        self.difficulty = None
        self.group_ids = None
        self.inference_ctx = None

    def init_ui(self):
        # ouster lidar
        # self.os1 = OS1('10.5.5.66', '10.5.5.1', mode='2048x10')
        # self.os1.start()
        # self.unprocessed_packets = Queue()
        # self.beam_intrinsics = json.loads(self.os1.get_beam_intrinsics())
        # self.beam_alt_angles = self.beam_intrinsics['beam_altitude_angles']
        # self.beam_az_angles = self.beam_intrinsics['beam_azimuth_angles']
        # self.pc_handler = frame_handler(self.unprocessed_packets)
        # self.ind_list = np.array(range(1, 1024, 4))
        self.setWindowTitle(self.title)
        self.setGeometry(*self.bbox_window)
        # self.statusBar().showMessage('Message in statusbar.')
        control_panel_layout = QVBoxLayout()
        root_path = self.json_setting.get("kitti_root_path", "")
        self.w_root_path = QLineEdit(root_path)
        image_idx = self.json_setting.get("image_idx", "0")
        self.w_imgidx = QLineEdit(image_idx)
        det_path = self.json_setting.get("latest_det_path", "")
        self.w_det_path = QLineEdit(det_path)
        up_scale = self.json_setting.get("up_scale", "")
        self.w_up_scale = QLineEdit(up_scale)
        w_x_shift = self.json_setting.get("w_x_shift", "0")
        self.w_x_shift = QLineEdit(w_x_shift)
        w_y_shift = self.json_setting.get("w_y_shift", "0")
        self.w_y_shift = QLineEdit(w_y_shift)
        w_z_shift = self.json_setting.get("w_z_shift", "0")
        self.w_z_shift = QLineEdit(w_z_shift)
        # self.w_cmd = QLineEdit()
        # self.w_cmd.returnPressed.connect(self.on_CmdReturnPressed)
        self.w_load = QPushButton('load info')
        self.w_load.clicked.connect(self.on_loadButtonPressed)
        self.w_load_det = QPushButton('load detection')
        self.w_load_det.clicked.connect(self.on_loadDetPressed)
        self.w_config = KittiDrawControl('ctrl')
        config = self.json_setting.get("config", "")
        if config != "":
            self.w_config.loads(config)
        self.w_config.configChanged.connect(self.on_configchanged)
        self.w_plot = QPushButton('plot')
        self.w_plot.clicked.connect(self.on_plotButtonPressed)
        self.w_show_panel = QPushButton('control panel')
        self.w_show_panel.clicked.connect(self.on_panel_clicked)
        center_widget = QWidget(self)
        self.w_output = QTextEdit()
        self.w_config_gbox = QGroupBox("Read Config")
        layout = QFormLayout()
        layout.addRow(QLabel("root path:"), self.w_root_path)
        layout.addRow(QLabel("PC path:"), self.w_det_path)
        layout.addRow(QLabel("Up Scale:"), self.w_up_scale)
        layout.addRow(QLabel("x shift:"), self.w_x_shift)
        layout.addRow(QLabel("y shift:"), self.w_y_shift)
        layout.addRow(QLabel("z shift:"), self.w_z_shift)
        self.w_config_gbox.setLayout(layout)
        control_panel_layout.addWidget(self.w_config_gbox)
        h_layout = QHBoxLayout()
        h_layout.addWidget(self.w_load)
        control_panel_layout.addLayout(h_layout)
        h_layout = QHBoxLayout()
        control_panel_layout.addLayout(h_layout)
        control_panel_layout.addWidget(self.w_show_panel)
        vcfg_path = self.json_setting.get("latest_vxnet_cfg_path", "")
        self.w_vconfig_path = QLineEdit(vcfg_path)
        vckpt_path = self.json_setting.get("latest_vxnet_ckpt_path", "")
        self.w_vckpt_path = QLineEdit(vckpt_path)
        layout = QFormLayout()
        layout.addRow(QLabel("config path:"), self.w_vconfig_path)
        layout.addRow(QLabel("ckpt path:"), self.w_vckpt_path)
        control_panel_layout.addLayout(layout)
        self.w_build_net = QPushButton('Build Network')
        self.w_build_net.clicked.connect(self.on_BuildVxNetPressed)
        self.w_load_ckpt = QPushButton('load Network checkpoint')
        self.w_load_ckpt.clicked.connect(self.on_loadVxNetCkptPressed)
        h_layout = QHBoxLayout()
        h_layout.addWidget(self.w_build_net)
        h_layout.addWidget(self.w_load_ckpt)
        control_panel_layout.addLayout(h_layout)
        self.w_inference = QPushButton('Inference Network')
        self.w_inference.clicked.connect(self.on_InferenceVxNetPressed)
        control_panel_layout.addWidget(self.w_inference)
        self.w_load_infer = QPushButton('Load and Inference Network')
        self.w_load_infer.clicked.connect(self.on_LoadInferenceVxNetPressed)
        control_panel_layout.addWidget(self.w_load_infer)
        self.gt_combobox = QComboBox()
        self.gt_combobox.addItem("All")
        save_image_path = self.json_setting.get("save_image_path", "")
        self.w_image_save_path = QLineEdit(save_image_path)
        self.w_save_image = QPushButton('save image')
        self.w_save_image.clicked.connect(self.on_saveimg_clicked)
        control_panel_layout.addWidget(self.w_image_save_path)
        control_panel_layout.addWidget(self.w_save_image)
        control_panel_layout.addWidget(self.w_output)
        self.center_layout = QHBoxLayout()
        self.w_pc_viewer = KittiPointCloudView(
            self.w_config, coors_range=self.w_config.get("CoorsRange"))
        self.center_layout.addWidget(self.w_pc_viewer)
        self.center_layout.addLayout(control_panel_layout)
        self.center_layout.setStretch(0, 2)
        self.center_layout.setStretch(1, 1)
        center_widget.setLayout(self.center_layout)
        self.setCentralWidget(center_widget)
        self.show()

    # def worker(self, queue, beam_altitude_angles, beam_azimuth_angles):
    #     build_trig_table(beam_altitude_angles, beam_azimuth_angles)
    #     while True:
    #         buffer = queue.get()["buffer"]
    #         buffer_len = len(buffer)
    #         points = np.zeros((3, 256 * buffer_len))
    #         for ind in range(buffer_len):
    #             packet = buffer[ind]
    #             coords = np.array(xyz_points(packet, os16=False))
    #             coords = coords[:, self.ind_list]
    #             points[:, ind * 256:(ind + 1) * 256] = coords
    #         self.points = np.transpose(points)
    #         # print(self.points.shape)
    #         # self.plot_pointcloud()

    # def ouster_worker(self):
    #     self.os1.run_forever(self.pc_handler)

    # def spawn_workers(self, n, worker, *args, **kwargs):
    #     processes = []
    #     for i in range(n):
    #         process = Process(target=worker, args=args, kwargs=kwargs)
    #         process.start()
    #         processes.append(process)
    #     return processes

    def on_panel_clicked(self):
        if self.w_config.isHidden():
            self.w_config.show()
        else:
            self.w_config.hide()

    def on_saveimg_clicked(self):
        self.save_image(self.current_image)

    def on_gt_checkbox_statechanged(self):
        self.w_cb_gt_curcls.setChecked(True)
        self.w_cb_dt_curcls.setChecked(False)

    def on_dt_checkbox_statechanged(self):
        self.w_cb_gt_curcls.setChecked(False)
        self.w_cb_dt_curcls.setChecked(True)

    def on_gt_combobox_changed(self):
        self._current_gt_cls_idx = 0
        self.on_loadButtonPressed()

    def on_dt_combobox_changed(self):
        self._current_dt_cls_idx = 0
        annos = kitti.filter_empty_annos(self.detection_annos)
        if self.dt_image_idxes is not None and annos is not None:
            current_class = self.dt_combobox.currentText()
            if current_class == "All":
                self._current_dt_cls_ids = self.dt_image_idxes
            else:
                self._current_dt_cls_ids = [
                    anno["image_idx"][0] for anno in annos
                    if current_class in anno["name"]
                ]

    def message(self, value, *arg, color="Black"):
        colorHtml = f"<font color=\"{color}\">"
        endHtml = "</font><br>"
        msg = self.print_str(value, *arg)
        self.w_output.insertHtml(colorHtml + msg + endHtml)
        self.w_output.verticalScrollBar().setValue(
            self.w_output.verticalScrollBar().maximum())

    def error(self, value, *arg):
        time_str = datetime.datetime.now().strftime("[%H:%M:%S]")
        return self.message(time_str, value, *arg, color="Red")

    def info(self, value, *arg):
        time_str = datetime.datetime.now().strftime("[%H:%M:%S]")
        return self.message(time_str, value, *arg, color="Black")

    def warning(self, value, *arg):
        time_str = datetime.datetime.now().strftime("[%H:%M:%S]")
        return self.message(time_str, value, *arg, color="Yellow")

    def save_image(self, image):
        img_path = self.w_image_save_path.text()
        self.json_setting.set("save_image_path", img_path)
        if self.current_image is not None:
            io.imsave(img_path, image)
        # p = self.w_pc_viewer.grab()
        p = self.w_pc_viewer.grabFrameBuffer()
        # p = QtGui.QPixmap.grabWindow(self.w_pc_viewer)
        pc_img_path = str(
            Path(img_path).parent / (str(Path(img_path).stem) + "_pc.jpg"))
        p.save(pc_img_path, 'jpg')
        self.info("image saved to", img_path)

    def print_str(self, value, *arg):
        # self.strprint.flush()
        self.sstream.truncate(0)
        self.sstream.seek(0)
        print(value, *arg, file=self.sstream)
        return self.sstream.getvalue()

    def on_nextOrPrevPressed(self, prev):
        if prev is True:
            self.current_idx = max(self.current_idx - 1, 0)
        else:
            info_len = len(self.image_idxes)
            self.current_idx = min(self.current_idx + 1, info_len - 1)
        image_idx = self.image_idxes[self.current_idx]
        self.w_imgidx.setText(str(image_idx))
        self.plot_all(image_idx)

    def on_nextOrPrevCurClsPressed(self, prev):
        if self.w_cb_dt_curcls.isChecked():
            if prev is True:
                self._current_dt_cls_idx = max(self._current_dt_cls_idx - 1, 0)
            else:
                info_len = len(self._current_dt_cls_ids)
                self._current_dt_cls_idx = min(self._current_dt_cls_idx + 1,
                                               info_len - 1)
            image_idx = self._current_dt_cls_ids[self._current_dt_cls_idx]
            self.info("current dt image idx:", image_idx)
        elif self.w_cb_gt_curcls.isChecked():
            if prev is True:
                self._current_gt_cls_idx = max(self._current_gt_cls_idx - 1, 0)
            else:
                info_len = len(self._current_gt_cls_ids)
                self._current_gt_cls_idx = min(self._current_gt_cls_idx + 1,
                                               info_len - 1)
            image_idx = self._current_gt_cls_ids[self._current_gt_cls_idx]
            self.info("current gt image idx:", image_idx)
        self.plot_all(image_idx)

    def on_CmdReturnPressed(self):
        cmd = self.print_str(self.cmd.text())
        self.output.insertPlainText(cmd)

    def on_loadButtonPressed(self):
        # workers = self.spawn_workers(1, self.worker, self.unprocessed_packets,
        #                              self.beam_alt_angles, self.beam_az_angles)
        # workers = self.spawn_workers(1, self.ouster_worker)
        self.load_info()
        self.plot_pointcloud()
        # self.os1.handle_request(self.handler)

    def on_loadDetPressed(self):
        det_path = self.w_det_path.text()
        if Path(det_path).is_file():
            with open(det_path, "rb") as f:
                dt_annos = pickle.load(f)
        else:
            dt_annos = kitti.get_label_annos(det_path)
        if len(dt_annos) == 0:
            self.warning("detection path contains nothing.")
            return
        self.detection_annos = dt_annos
        self.info(f"load {len(dt_annos)} detections.")
        self.json_setting.set("latest_det_path", det_path)
        annos = kitti.filter_empty_annos(self.detection_annos)
        self.dt_image_idxes = [anno["image_idx"][0] for anno in annos]
        # collect the classes present in the detections
        available_cls = []
        for anno in self.detection_annos:
            for name in anno["name"]:
                if name not in available_cls:
                    available_cls.append(name)
        self.dt_combobox.clear()
        self.dt_combobox.addItem("All")
        for cls_name in available_cls:
            self.dt_combobox.addItem(cls_name)
        current_class = self.dt_combobox.currentText()
        if current_class == "All":
            self._current_dt_cls_ids = self.dt_image_idxes
        else:
            # match the membership test used in on_dt_combobox_changed
            # (the original compared a numpy array to a string here)
            self._current_dt_cls_ids = [
                anno["image_idx"][0] for anno in annos
                if current_class in anno["name"]
            ]
        self._current_dt_cls_idx = 0
        """
        if self.kitti_infos is not None:
            t = time.time()
            gt_annos = [info["annos"] for info in self.kitti_infos]
            self.message(get_official_eval_result(gt_annos, dt_annos, 0))
            self.message(f"eval use time: {time.time() - t:.4f}")
        """

    def sample_to_current_data(self):
        if self.kitti_info is None:
            self.error(
                "you must load infos and choose an existing image idx first.")
            return
        sampled_difficulty = []
        # class_names = ["Car"]
        rect = self.kitti_info['calib/R0_rect']
        P2 = self.kitti_info['calib/P2']
        Trv2c = self.kitti_info['calib/Tr_velo_to_cam']
        num_features = 4
        if 'pointcloud_num_features' in self.kitti_info:
            num_features = self.kitti_info['pointcloud_num_features']
        # class_names = self.w_config.get("UsedClass")
        # class_names_group = [["trailer", "tractor"]]
        if self.db_sampler is not None:
            # gt_boxes_mask = np.array(
            #     [n in class_names for n in self.gt_names], dtype=np.bool_)
            gt_boxes_mask = np.ones((self.gt_names.shape[0], ), np.bool_)
            sampled_dict = self.db_sampler.sample_all(
                self.root_path,
                self.gt_boxes,
                self.gt_names,
                num_features,
                False,
                gt_group_ids=self.group_ids,
                rect=rect,
                Trv2c=Trv2c,
                P2=P2)
            if sampled_dict is not None:
                sampled_gt_names = sampled_dict["gt_names"]
                sampled_gt_boxes = sampled_dict["gt_boxes"]
                sampled_points = sampled_dict["points"]
                sampled_gt_masks = sampled_dict["gt_masks"]
                sampled_difficulty = sampled_dict["difficulty"]
                # gt_names = gt_names[gt_boxes_mask].tolist()
                self.gt_names = np.concatenate(
                    [self.gt_names, sampled_gt_names], axis=0)
                # gt_names += [s["name"] for s in sampled]
                self.gt_boxes = np.concatenate(
                    [self.gt_boxes, sampled_gt_boxes])
                gt_boxes_mask = np.concatenate(
                    [gt_boxes_mask, sampled_gt_masks], axis=0)
                self.difficulty = np.concatenate(
                    [self.difficulty, sampled_difficulty], axis=0)
                self.points = np.concatenate(
                    [sampled_points, self.points], axis=0)
                sampled_group_ids = sampled_dict["group_ids"]
                if self.group_ids is not None:
                    self.group_ids = np.concatenate(
                        [self.group_ids, sampled_group_ids])
            '''
            prep.noise_per_object_(
                self.gt_boxes,
                self.points,
                gt_boxes_mask,
                rotation_perturb=[-1.57, 1.57],
                center_noise_std=[1.0, 1.0, 1.0],
                num_try=50)
            '''
            # should remove unrelated objects after per-object noise
            self.gt_boxes = self.gt_boxes[gt_boxes_mask]
            self.gt_names = self.gt_names[gt_boxes_mask]
            self.difficulty = self.difficulty[gt_boxes_mask]
            if self.group_ids is not None:
                self.group_ids = self.group_ids[gt_boxes_mask]
        else:
            self.error("you enabled sampling but did not provide a database")

    def draw_detection(self, detection_anno, label_color=GLColor.Blue):
        dt_box_color = self.w_config.get("DTBoxColor")[:3]
        dt_box_color = (*dt_box_color, self.w_config.get("DTBoxAlpha"))
        dt_box_lidar = np.array(
            [detection_anno["box3d_lidar"].detach().cpu().numpy()])[0]
        scores = np.array(
            [detection_anno["scores"].detach().cpu().numpy()])[0]
        # filter by score
        keep_list = np.where(scores > 0.2)[0]
        dt_box_lidar = dt_box_lidar[keep_list, :]
        scores = scores[keep_list]
        dt_boxes_corners = box_np_ops.center_to_corner_box3d(
            dt_box_lidar[:, :3],
            dt_box_lidar[:, 3:6],
            dt_box_lidar[:, 6],
            origin=[0.5, 0.5, 0],
            axis=2)
        # filter bboxes by their centers; points_range is
        # [max_x, max_y, max_z, min_x, min_y, min_z]
        centers = (dt_boxes_corners[:, 0, :] + dt_boxes_corners[:, 6, :]) / 2
        keep_list = np.where(
            (centers[:, 0] < self.points_range[0])
            & (centers[:, 0] > self.points_range[3])
            & (centers[:, 1] < self.points_range[1])
            & (centers[:, 1] > self.points_range[4])
            & (centers[:, 2] < self.points_range[2])
            & (centers[:, 2] > self.points_range[5]))[0]
        dt_boxes_corners = dt_boxes_corners[keep_list, :, :]
        dt_box_lidar = dt_box_lidar[keep_list, :]
        scores = scores[keep_list]
        num_dt = dt_box_lidar.shape[0]
        self.info('num_dt', num_dt)
        if num_dt != 0:
            for ind in range(num_dt):
                self.info('scores', scores[ind])
                self.info('dt_box_lidar', dt_box_lidar[ind])
            dt_box_color = np.tile(
                np.array(dt_box_color)[np.newaxis, ...], [num_dt, 1])
            scores_rank = scores / scores[0]
            # if self.w_config.get("DTScoreAsAlpha") and scores is not None:
            #     dt_box_color = np.concatenate(
            #         [dt_box_color[:, :3], scores[..., np.newaxis]], axis=1)
            dt_box_color = np.concatenate(
                [dt_box_color[:, :3], scores_rank[..., np.newaxis]], axis=1)
            # dt_box_color = np.concatenate(
            #     [dt_box_color[:, :3], np.ones(scores[..., np.newaxis].shape)],
            #     axis=1)
            self.w_pc_viewer.boxes3d("dt_boxes", dt_boxes_corners,
                                     dt_box_color,
                                     self.w_config.get("DTBoxLineWidth"), 1.0)

    def plot_pointcloud(self):
        point_color = self.w_config.get("PointColor")[:3]
        point_color = (*point_color, self.w_config.get("PointAlpha"))
        point_color = np.tile(np.array(point_color),
                              [self.points.shape[0], 1])
        self.w_pc_viewer.reset_camera()
        point_size = np.full([self.points.shape[0]],
                             self.w_config.get("PointSize"),
                             dtype=np.float32)
        self.w_pc_viewer.draw_bounding_box()
        self.w_pc_viewer.remove("dt_boxes/labels")
        self.w_pc_viewer.remove("dt_boxes")
        if self.detection_annos is not None and self.w_config.get(
                "DrawDTBoxes"):
            detection_anno = self.detection_annos[0]
            self.draw_detection(detection_anno)
        if self.w_config.get("WithReflectivity"):
            if self.points.shape[1] < 4:
                self.error("Your pointcloud doesn't contain reflectivity.")
            else:
                point_color = np.concatenate(
                    [point_color[:, :3], self.points[:, 3:4] * 0.8 + 0.2],
                    axis=1)
        self.w_pc_viewer.scatter(
            "pointcloud", self.points[:, :3], point_color, size=point_size)
        print('DEBUG: plot_pointcloud')

    def load_info(self):
        self.json_setting.set("up_scale", str(float(self.w_up_scale.text())))
        self.json_setting.set("w_x_shift", str(float(self.w_x_shift.text())))
        self.json_setting.set("w_y_shift", str(float(self.w_y_shift.text())))
        self.json_setting.set("w_z_shift", str(float(self.w_z_shift.text())))
        det_path = self.w_det_path.text()
        scale_up = float(self.w_up_scale.text())
        w_x_shift = float(self.w_x_shift.text())
        w_y_shift = float(self.w_y_shift.text())
        w_z_shift = float(self.w_z_shift.text())
        self.json_setting.set("latest_det_path", det_path)
        points = np.transpose(np.load(det_path))
        # points_range is [max_x, max_y, max_z, min_x, min_y, min_z]
        self.points_range = [
            np.max(points[:, 0]),
            np.max(points[:, 1]),
            np.max(points[:, 2]),
            np.min(points[:, 0]),
            np.min(points[:, 1]),
            np.min(points[:, 2])
        ]
        # center the cloud in x/y, then apply the user-specified shifts
        points[:, 0] -= (self.points_range[0] + self.points_range[3]) / 2
        points[:, 1] -= (self.points_range[1] + self.points_range[4]) / 2
        # points[:, 2] -= self.points_range[5]
        points[:, 0] += w_x_shift
        points[:, 1] += w_y_shift
        points[:, 2] += w_z_shift
        points = points[np.where(points[:, 2] > 0)]
        points = points[np.where(points[:, 3] > 100)]
        points[:, 3] = 0
        self.points_range = [
            np.max(points[:, 0]),
            np.max(points[:, 1]),
            np.max(points[:, 2]),
            np.min(points[:, 0]),
            np.min(points[:, 1]),
            np.min(points[:, 2])
        ]
        print(self.points_range)
        points = np.array(points) * scale_up
        self.points_range = [
            np.max(points[:, 0]),
            np.max(points[:, 1]),
            np.max(points[:, 2]),
            np.min(points[:, 0]),
            np.min(points[:, 1]),
            np.min(points[:, 2])
        ]
        self.points = points
        img_path = self.w_image_save_path.text()
        self.w_image_save_path.setText(img_path)
self.json_setting.set("save_image_path", img_path) print('DEBUG: self.points.shape', self.points.shape) print('DEBUG: self.points.shape', self.points_range) def plot_all(self, image_idx): self.load_info(image_idx) self.plot_pointcloud() return True def on_plotButtonPressed(self): image_idx = 107 def closeEvent(self, event): config_str = self.w_config.dumps() self.json_setting.set("config", config_str) return super().closeEvent(event) def on_configchanged(self, msg): # self.warning(msg.name, msg.value) # save config to file idx = self.image_idxes.index(self.kitti_info["image_idx"]) config_str = self.w_config.dumps() self.json_setting.set("config", config_str) pc_redraw_msgs = ["PointSize", "PointAlpha", "GTPointSize"] pc_redraw_msgs += ["GTPointAlpha", "WithReflectivity"] pc_redraw_msgs += ["PointColor", "GTPointColor"] box_redraw = ["GTBoxColor", "GTBoxAlpha"] dt_redraw = [ "DTBoxColor", "DTBoxAlpha", "DrawDTLabels", "DTScoreAsAlpha", "DTScoreThreshold", "DTBoxLineWidth" ] vx_redraw_msgs = ["DrawPositiveVoxelsOnly", "DrawVoxels"] vx_redraw_msgs += ["PosVoxelColor", "PosVoxelAlpha"] vx_redraw_msgs += ["NegVoxelColor", "NegVoxelAlpha"] all_redraw_msgs = ["RemoveOutsidePoint"] if msg.name in vx_redraw_msgs: if self.w_config.get("DrawVoxels"): self.w_pc_viewer.draw_voxels(self.points, self.gt_boxes) else: self.w_pc_viewer.remove("voxels") elif msg.name in pc_redraw_msgs: self.plot_pointcloud() elif msg.name in all_redraw_msgs: self.on_plotButtonPressed() elif msg.name in box_redraw: self.plot_gt_boxes_in_pointcloud() elif msg.name in dt_redraw: if self.detection_annos is not None and self.w_config.get( "DrawDTBoxes"): detection_anno = self.detection_annos[idx] self.draw_detection(detection_anno) def on_loadVxNetCkptPressed(self): ckpt_path = Path(self.w_vckpt_path.text()) self.json_setting.set("latest_vxnet_ckpt_path", self.w_vckpt_path.text()) self.inference_ctx.restore(ckpt_path) # self.w_load_ckpt.setText(self.w_load_ckpt.text() + f": {ckpt_path.stem}") self.info("load VoxelNet ckpt succeed.") def on_BuildVxNetPressed(self): self.inference_ctx = TorchInferenceContext() vconfig_path = Path(self.w_vconfig_path.text()) self.inference_ctx.build(vconfig_path) self.json_setting.set("latest_vxnet_cfg_path", str(vconfig_path)) self.info("Build VoxelNet ckpt succeed.") # self.w_load_config.setText(self.w_load_config.text() + f": {vconfig_path.stem}") def on_InferenceVxNetPressed(self): t = time.time() inputs = self.inference_ctx.get_inference_input_dict(self.points) # print('DEBUG inputs') # for key, value in inputs.items() : # print(key, value.shape) # print('') self.info("input preparation time:", time.time() - t) t = time.time() # print('DEBUG: after filter: ', self.points.shape) # pc_limits = np.asarray(self.inference_ctx.config.model.second.post_center_limit_range) # print('DEBUG model') # print(type(pc_limits)) # print('') with self.inference_ctx.ctx(): predictions_dicts = self.inference_ctx.inference(inputs) self.info("detection time:", time.time() - t) if predictions_dicts[0]['scores'] is not None: self.draw_detection(predictions_dicts[0]) def on_LoadInferenceVxNetPressed(self): self.on_BuildVxNetPressed() self.on_loadVxNetCkptPressed() self.on_InferenceVxNetPressed() @staticmethod def get_simpify_labels(labels): label_map = { "Car": "V", "Pedestrian": "P", "Cyclist": "C", "car": "C", "tractor": "T1", "trailer": "T2", } label_count = { "Car": 0, "Pedestrian": 0, "Cyclist": 0, "car": 0, "tractor": 0, "trailer": 0, } ret = [] for i, name in enumerate(labels): count = 0 if name in label_count: 
count = label_count[name] label_count[name] += 1 else: label_count[name] = 0 ret.append(f"{label_map[name]}{count}") return ret @staticmethod def get_false_pos_neg(gt_boxes, dt_boxes, labels, fp_thresh=0.1): iou = _riou3d_shapely(gt_boxes, dt_boxes) ret = np.full([len(gt_boxes)], 2, dtype=np.int64) assigned_dt = np.zeros([len(dt_boxes)], dtype=np.bool_) label_thresh_map = { "Car": 0.7, "Pedestrian": 0.5, "Cyclist": 0.5, "car": 0.7, "tractor": 0.7, "trailer": 0.7, } tp_thresh = np.array([label_thresh_map[n] for n in labels]) if len(gt_boxes) != 0 and len(dt_boxes) != 0: iou_max_dt_for_gt = iou.max(1) dt_iou_max_dt_for_gt = iou.argmax(1) ret[iou_max_dt_for_gt >= tp_thresh] = 0 ret[np.logical_and(iou_max_dt_for_gt < tp_thresh, iou_max_dt_for_gt > fp_thresh)] = 1 # FP assigned_dt_inds = dt_iou_max_dt_for_gt assigned_dt_inds = assigned_dt_inds[iou_max_dt_for_gt >= fp_thresh] assigned_dt[assigned_dt_inds] = True return ret, assigned_dt
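# Hedged launch sketch for the viewer class above: a standard PyQt5 entry
# point (KittiViewer.show() is already called at the end of init_ui; the
# module layout is an assumption).
import sys
from PyQt5.QtWidgets import QApplication

if __name__ == "__main__":
    app = QApplication(sys.argv)
    viewer = KittiViewer()
    sys.exit(app.exec_())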
def on_BuildVxNetPressed(self):
    self.inference_ctx = TorchInferenceContext()
    vconfig_path = Path(self.w_vconfig_path.text())
    self.inference_ctx.build(vconfig_path)
    self.json_setting.set("latest_vxnet_cfg_path", str(vconfig_path))
    self.info("Build VoxelNet ckpt succeeded.")