Beispiel #1
0
def BuildVoxelNet(config_path='second.pytorch/second/configs/xyres_16.proto',
                  ckpt_path='second.pytorch/second/voxelnet-331653.tckpt'):
    """Build a SECOND ``TorchInferenceContext`` and restore its weights.

    Args:
        config_path: network config file. Defaults to the xyres_16 config,
            relative to the current working directory.
        ckpt_path: checkpoint file to restore, also relative by default.

    Returns:
        The built and restored ``TorchInferenceContext``.
    """
    inference_ctx = TorchInferenceContext()
    inference_ctx.build(Path(config_path))
    inference_ctx.restore(Path(ckpt_path))
    return inference_ctx
Beispiel #2
0
def BuildVoxelNet(
        config_path='/home/lucerna/MEGAsync/project/AVP/second/configs/xyres_16.proto',
        ckpt_path='/home/lucerna/MEGAsync/project/AVP/second/voxelnet-331653.tckpt'):
    """Build a SECOND ``TorchInferenceContext`` and restore its weights.

    Args:
        config_path: network config file. The default is the original
            machine-specific absolute path; pass your own on other machines.
        ckpt_path: checkpoint file to restore (same caveat for the default).

    Returns:
        The built and restored ``TorchInferenceContext``.
    """
    inference_ctx = TorchInferenceContext()
    inference_ctx.build(Path(config_path))
    inference_ctx.restore(Path(ckpt_path))
    return inference_ctx
Beispiel #3
0
def build_network():
    """Flask handler: build a SECOND network and attach it to ``BACKEND``.

    Expects a JSON object body with ``"config_path"`` and ``"checkpoint_path"``.
    On success stores a new ``TorchInferenceContext`` in
    ``BACKEND.inference_ctx`` and returns ``jsonify(results=[{"status": "normal"}])``
    with ``Access-Control-Allow-Headers: *``. Any validation failure returns
    ``error_response(<message>)`` instead.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it is reported through error_response like other bad input.
    instance = request.get_json(silent=True)
    if not isinstance(instance, dict):
        instance = {}
    if BACKEND.root_path is None:
        return error_response("root path is not set")
    cfg_str = instance.get("config_path")
    ckpt_str = instance.get("checkpoint_path")
    # Previously a missing key raised KeyError (HTTP 500).
    if cfg_str is None or ckpt_str is None:
        return error_response("config_path and checkpoint_path are required.")
    cfg_path = Path(cfg_str)
    ckpt_path = Path(ckpt_str)
    if not cfg_path.exists():
        return error_response("config file not exist.")
    if not ckpt_path.exists():
        return error_response("ckpt file not exist.")
    # BACKEND is only mutated via attributes, so no `global` statement is needed.
    BACKEND.inference_ctx = TorchInferenceContext()
    BACKEND.inference_ctx.build(str(cfg_path))
    BACKEND.inference_ctx.restore(str(ckpt_path))
    response = jsonify(results=[{"status": "normal"}])
    response.headers['Access-Control-Allow-Headers'] = '*'
    print("build_network successful!")
    return response
Beispiel #4
0
 def _build(self):
     """Create a new TorchInferenceContext and build it from self.config_path."""
     print("Start build...")
     self.inference_ctx = ctx = TorchInferenceContext()
     ctx.build(self.config_path)
     print("Build succeeded.")
Beispiel #5
0
class SecondModel:
    """Wrapper around a SECOND ``TorchInferenceContext`` for KITTI-style data.

    Call :meth:`initialize` once (loads calibration for ``calib_idx``, builds
    the network and restores the checkpoint), then :meth:`predict` for each
    point cloud.
    """

    def __init__(self, data_path, config_path, ckpt_path, calib_idx=0):
        self.data_path = data_path
        self.config_path = config_path
        self.ckpt_path = ckpt_path
        self.calib_idx = calib_idx

        self.calib_info = None
        self.inference_ctx = None

    def initialize(self):
        """Load calibration info for ``calib_idx``, then build and restore the network."""
        image_infos = get_kitti_image_info(
            self.data_path,
            training=True,
            label_info=False,
            calib=True,
            image_ids=[self.calib_idx]
        )
        self.calib_info = image_infos[0]
        self._build()
        self._restore()

    def predict(self, pointclouds):
        """Run inference and return lidar-frame boxes of detections with score >= 0.5.

        Prints the inference time in milliseconds.
        """
        start = time.perf_counter()
        result_annos = self._inference(pointclouds)
        print("Inference time: {} ms".format(int((time.perf_counter() - start) * 1000)))
        kitti_anno = self.remove_low_score(result_annos[0])
        return self.kitti_cam_to_lidar(kitti_anno)

    # Backward-compatible alias for the original (misspelled) method name.
    predcit = predict

    def _build(self):
        print("Start build...")
        self.inference_ctx = TorchInferenceContext()
        self.inference_ctx.build(self.config_path)
        print("Build succeeded.")

    def _restore(self):
        print("Start restore...")
        self.inference_ctx.restore(self.ckpt_path)
        print("Restore succeeded.")

    def _inference(self, pointclouds):
        """Prepare inputs from ``calib_info`` + ``pointclouds`` and run the network."""
        inputs = self.inference_ctx.get_inference_input_dict(self.calib_info, pointclouds)
        return self.inference_ctx.inference(inputs)

    def kitti_cam_to_lidar(self, kitti_anno):
        """Convert camera-frame boxes of ``kitti_anno`` to the lidar frame.

        Stacks ``location``, ``dimensions`` and ``rotation_y`` into one array
        and converts it with ``box_np_ops.box_camera_to_lidar`` using the
        ``R0_rect`` / ``Tr_velo_to_cam`` matrices from ``calib_info``.
        """
        rect = self.calib_info['calib/R0_rect']
        Tr_velo_to_cam = self.calib_info['calib/Tr_velo_to_cam']
        dims = kitti_anno['dimensions']
        loc = kitti_anno['location']
        rots = kitti_anno['rotation_y']
        boxes_camera = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1)
        return box_np_ops.box_camera_to_lidar(boxes_camera, rect, Tr_velo_to_cam)

    def remove_low_score(self, annos, threshold=0.5):
        """Return a copy of ``annos`` keeping only entries with score >= ``threshold``.

        Every value in ``annos`` is indexed with the same row indices, so all
        values must be arrays aligned with ``annos['score']``.
        """
        keep = np.flatnonzero(np.asarray(annos['score']) >= threshold)
        return {key: value[keep] for key, value in annos.items()}
Beispiel #6
0
 def build_vxnet(self):
     """Build a new TorchInferenceContext from self.config_path and record the path in self.json_setting."""
     print("Start build_vxnet...")
     self.inference_ctx = ctx = TorchInferenceContext()
     ctx.build(self.config_path)
     self.json_setting.set("latest_vxnet_cfg_path", self.config_path)
     print("Build VoxelNet ckpt succeeded.")
Beispiel #7
0
class Processor_ROS:
    """Run a SECOND VoxelNet model on point clouds received through ROS.

    Call :meth:`initialize` once (reads the calibration file, builds the
    network, restores the checkpoint), then :meth:`run` for every point cloud.
    """

    def __init__(self, calib_path, config_path, ckpt_path, settings_path=None):
        """Store paths and create the settings object; the network is not built here.

        Args:
            calib_path: KITTI-format calibration text file (see read_calib).
            config_path: passed to ``TorchInferenceContext.build``.
            ckpt_path: passed to ``TorchInferenceContext.restore``.
            settings_path: file backing ``self.json_setting``. Defaults to
                ``.kittiviewerrc`` in the current user's home directory
                (previously hard-coded to ``/home/hradt/.kittiviewerrc``).
        """
        from pathlib import Path  # only needed for the default settings path

        self.points = None

        if settings_path is None:
            settings_path = Path.home() / ".kittiviewerrc"
        self.json_setting = Settings(str(settings_path))
        self.calib_path = calib_path
        self.config_path = config_path
        self.ckpt_path = ckpt_path

        self.calib_info = None
        self.inputs = None

        self.inference_ctx = None

    def initialize(self):
        """Read calibration, build the network, then restore its checkpoint."""
        self.read_calib()
        self.build_vxnet()
        self.load_vxnet()

    def run(self, points, num_features=4):
        """Detect objects in one point cloud.

        Args:
            points: array reshaped to ``(-1, num_features)`` before inference.
            num_features: values per point (default 4).

        Returns:
            ``(dt_boxes_corners, scores, dt_box_lidar)`` from
            ``kitti_anno_to_corners`` for detections with score >= 0.5.
        """
        self.points = points.reshape([-1, num_features])
        # Points outside the camera view are NOT removed here; use
        # box_np_ops.remove_outside_points (needs the calib matrices and
        # img_shape from self.calib_info) if that becomes necessary.

        [results] = self.inference_vxnet()

        results = remove_low_score(results, 0.5)

        dt_boxes_corners, scores, dt_box_lidar = kitti_anno_to_corners(
            self.calib_info, results)

        print("dt_box_lidar: ", dt_box_lidar)

        return dt_boxes_corners, scores, dt_box_lidar

    def _extend_matrix(self, mat):
        """Append the homogeneous row [0, 0, 0, 1] to a 3x4 matrix."""
        return np.concatenate([mat, np.array([[0., 0., 0., 1.]])], axis=0)

    def read_calib(self, extend_matrix=True):
        """Parse the KITTI calibration file at ``self.calib_path`` into ``self.calib_info``.

        Reads P2 (line 3), R0_rect (line 5) and Tr_velo_to_cam (line 6); each
        line is ``<key>: v1 v2 ...``. With ``extend_matrix`` the matrices are
        padded to 4x4 homogeneous form. Also stores the calib path and a fixed
        KITTI image shape ``[375, 1242]`` (height, width).
        """
        print("Start read_calib...")
        with open(self.calib_path, 'r') as f:
            lines = f.readlines()

        def parse(row, count, shape):
            # split() (not split(' ')) tolerates repeated spaces / trailing newline;
            # the first token is the key, e.g. "P2:".
            values = [float(v) for v in lines[row].split()[1:count + 1]]
            return np.array(values).reshape(shape)

        P2 = parse(2, 12, [3, 4])
        R0_rect = parse(4, 9, [3, 3])
        Tr_velo_to_cam = parse(5, 12, [3, 4])
        if extend_matrix:
            P2 = self._extend_matrix(P2)
            rect_4x4 = np.eye(4, dtype=R0_rect.dtype)
            rect_4x4[:3, :3] = R0_rect
            R0_rect = rect_4x4
            Tr_velo_to_cam = self._extend_matrix(Tr_velo_to_cam)

        self.calib_info = {
            'calib_path': self.calib_path,
            'calib/P2': P2,
            'calib/R0_rect': R0_rect,
            'calib/Tr_velo_to_cam': Tr_velo_to_cam,
            # image shape info for lidar point cloud preprocessing
            # (kitti image size: height, width)
            "img_shape": np.array([375, 1242]),
        }
        print("Read calib file succeeded.")

    def build_vxnet(self):
        """Build a new inference context from ``self.config_path`` and remember the path."""
        print("Start build_vxnet...")
        self.inference_ctx = TorchInferenceContext()
        self.inference_ctx.build(self.config_path)
        self.json_setting.set("latest_vxnet_cfg_path", self.config_path)
        print("Build VoxelNet ckpt succeeded.")

    def load_vxnet(self):
        """Restore the checkpoint at ``self.ckpt_path`` and remember the path."""
        print("Start load_vxnet...")
        self.json_setting.set("latest_vxnet_ckpt_path", self.ckpt_path)
        # Bug fix: previously restored the module-level `ckpt_path` global,
        # ignoring the checkpoint passed to the constructor.
        self.inference_ctx.restore(self.ckpt_path)
        print("Load VoxelNet ckpt succeeded.")

    def inference_vxnet(self):
        """Build network inputs from ``self.points`` and return the detection annos."""
        print("Start inference_vxnet...")
        t = time.perf_counter()
        self.inputs = self.inference_ctx.get_inference_input_dict_ros(
            self.calib_info, self.points)
        print("input preparation time:", time.perf_counter() - t)

        t = time.perf_counter()
        with self.inference_ctx.ctx():
            det_annos = self.inference_ctx.inference(self.inputs)
        print("detection time:", time.perf_counter() - t)
        return det_annos
)
# Checkpoint restored into the inference context in the __main__ block.
# NOTE(review): machine-specific absolute path.
ckpt_path = Path(
    '/home/dongwonshin/Desktop/second.pytorch_multiclass/pretrained_models/original_model/voxelnet-74240.tckpt'
)

# Expose only CUDA device 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# ROS topic the inference results are published on (PointCloud2 publisher).
publish_topic_name = '/inference_results'
# Subsampling steps; NOTE(review): lidar_frame_step only appears in a
# commented-out loop and point_sampling_step is unused in the visible code.
lidar_frame_step = 1
point_sampling_step = 1

if __name__ == '__main__':

    # Network model load
    with open(info_path, 'rb') as f:
        kitti_infos = pickle.load(f)
    inference_ctx = TorchInferenceContext()
    inference_ctx.build(config_path)
    inference_ctx.restore(ckpt_path)

    # publisher init
    rospy.init_node('SECOND network pub_example')
    pcl_pub = rospy.Publisher(publish_topic_name, PointCloud2)
    rospy.sleep(1.)
    rate = rospy.Rate(3)

    # for lidar_file_path in lidar_file_paths[1::lidar_frame_step]:
    idx = 0
    while not rospy.is_shutdown():

        all_points = []
        for flip in [True, False]:
Beispiel #9
0
class KittiViewer(QMainWindow):
    def __init__(self):
        """Initialize viewer state, then build and show the UI.

        Every attribute read by the handlers is initialized *before*
        ``init_ui`` runs, including ``db_sampler`` (read by
        ``sample_to_current_data``) and ``points_range`` (read by
        ``draw_detection``), which were previously never set.
        """
        super().__init__()
        self.title = 'F110 AVP Viewer'
        self.bbox_window = [10, 10, 1600, 900]
        self.sstream = sysio.StringIO()
        self.json_setting = Settings(str(Path.home() / ".kittiviewerrc"))
        self.kitti_infos = None
        self.detection_annos = None
        self.image_idxes = None
        self.root_path = None
        self.current_idx = 0
        self.dt_image_idxes = None
        self.current_image = None
        self.kitti_info = None
        self.points = None
        # [max_x, max_y, max_z, min_x, min_y, min_z]; filled by load_info.
        self.points_range = None
        self.gt_boxes = None
        self.gt_names = None
        self.difficulty = None
        self.group_ids = None
        # Ground-truth database sampler; None means "no database".
        self.db_sampler = None
        self.inference_ctx = None
        self.init_ui()

    def init_ui(self):
        """Create the window, the control-panel widgets and the point-cloud view.

        Initial field values are restored from ``self.json_setting``; button
        signals are connected to the ``on_*`` handlers. The window is shown
        at the end.
        """
        # ouster lidar
        # self.os1 = OS1('10.5.5.66', '10.5.5.1', mode='2048x10')
        # self.os1.start()
        # self.unprocessed_packets = Queue()
        # self.beam_intrinsics = json.loads(self.os1.get_beam_intrinsics())
        # self.beam_alt_angles = self.beam_intrinsics['beam_altitude_angles']
        # self.beam_az_angles = self.beam_intrinsics['beam_azimuth_angles']
        # self.pc_handler = frame_handler(self.unprocessed_packets)
        # self.ind_list = np.array(range(1, 1024, 4))

        self.setWindowTitle(self.title)
        self.setGeometry(*self.bbox_window)
        # self.statusBar().showMessage('Message in statusbar.')
        control_panel_layout = QVBoxLayout()
        # Line edits are pre-filled with the values saved in the settings file.
        root_path = self.json_setting.get("kitti_root_path", "")
        self.w_root_path = QLineEdit(root_path)
        image_idx = self.json_setting.get("image_idx", "0")
        # NOTE(review): w_imgidx is not added to any layout in this method.
        self.w_imgidx = QLineEdit(image_idx)
        det_path = self.json_setting.get("latest_det_path", "")
        # Shown as "PC path:" below; also read by on_loadDetPressed.
        self.w_det_path = QLineEdit(det_path)
        up_scale = self.json_setting.get("up_scale", "")
        self.w_up_scale = QLineEdit(up_scale)
        w_x_shift = self.json_setting.get("w_x_shift", "0")
        self.w_x_shift = QLineEdit(w_x_shift)
        w_y_shift = self.json_setting.get("w_y_shift", "0")
        self.w_y_shift = QLineEdit(w_y_shift)
        w_z_shift = self.json_setting.get("w_z_shift", "0")
        self.w_z_shift = QLineEdit(w_z_shift)

        # self.w_cmd = QLineEdit()
        # self.w_cmd.returnPressed.connect(self.on_CmdReturnPressed)
        self.w_load = QPushButton('load info')
        self.w_load.clicked.connect(self.on_loadButtonPressed)
        # NOTE(review): w_load_det is connected but not added to a layout here.
        self.w_load_det = QPushButton('load detection')
        self.w_load_det.clicked.connect(self.on_loadDetPressed)
        # Drawing-settings control; shown/hidden by the 'control panel' button.
        self.w_config = KittiDrawControl('ctrl')
        config = self.json_setting.get("config", "")
        # Load the saved config before connecting configChanged.
        if config != "":
            self.w_config.loads(config)
        self.w_config.configChanged.connect(self.on_configchanged)
        # NOTE(review): w_plot is not added to a layout here.
        self.w_plot = QPushButton('plot')
        self.w_plot.clicked.connect(self.on_plotButtonPressed)

        self.w_show_panel = QPushButton('control panel')
        self.w_show_panel.clicked.connect(self.on_panel_clicked)

        center_widget = QWidget(self)
        # Log pane written by message()/info()/warning()/error().
        self.w_output = QTextEdit()
        self.w_config_gbox = QGroupBox("Read Config")
        layout = QFormLayout()
        layout.addRow(QLabel("root path:"), self.w_root_path)
        layout.addRow(QLabel("PC path:"), self.w_det_path)
        layout.addRow(QLabel("Up Scale:"), self.w_up_scale)
        layout.addRow(QLabel("x shift:"), self.w_x_shift)
        layout.addRow(QLabel("y shift:"), self.w_y_shift)
        layout.addRow(QLabel("z shift:"), self.w_z_shift)
        self.w_config_gbox.setLayout(layout)

        control_panel_layout.addWidget(self.w_config_gbox)
        h_layout = QHBoxLayout()
        h_layout.addWidget(self.w_load)
        control_panel_layout.addLayout(h_layout)

        # NOTE(review): this horizontal layout is added empty.
        h_layout = QHBoxLayout()
        control_panel_layout.addLayout(h_layout)
        control_panel_layout.addWidget(self.w_show_panel)

        # Network config / checkpoint paths, restored from the settings file.
        vcfg_path = self.json_setting.get("latest_vxnet_cfg_path", "")
        self.w_vconfig_path = QLineEdit(vcfg_path)
        vckpt_path = self.json_setting.get("latest_vxnet_ckpt_path", "")
        self.w_vckpt_path = QLineEdit(vckpt_path)
        layout = QFormLayout()
        layout.addRow(QLabel("config path:"), self.w_vconfig_path)
        layout.addRow(QLabel("ckpt path:"), self.w_vckpt_path)
        control_panel_layout.addLayout(layout)
        self.w_build_net = QPushButton('Build Network')
        self.w_build_net.clicked.connect(self.on_BuildVxNetPressed)

        self.w_load_ckpt = QPushButton('load Network checkpoint')
        self.w_load_ckpt.clicked.connect(self.on_loadVxNetCkptPressed)
        h_layout = QHBoxLayout()
        h_layout.addWidget(self.w_build_net)
        h_layout.addWidget(self.w_load_ckpt)
        control_panel_layout.addLayout(h_layout)
        self.w_inference = QPushButton('Inference Network')
        self.w_inference.clicked.connect(self.on_InferenceVxNetPressed)
        control_panel_layout.addWidget(self.w_inference)
        self.w_load_infer = QPushButton('Load and Inference Network')
        self.w_load_infer.clicked.connect(self.on_LoadInferenceVxNetPressed)
        control_panel_layout.addWidget(self.w_load_infer)

        # NOTE(review): gt_combobox is created but not added to a layout here.
        self.gt_combobox = QComboBox()
        self.gt_combobox.addItem("All")

        save_image_path = self.json_setting.get("save_image_path", "")
        self.w_image_save_path = QLineEdit(save_image_path)
        self.w_save_image = QPushButton('save image')
        self.w_save_image.clicked.connect(self.on_saveimg_clicked)
        control_panel_layout.addWidget(self.w_image_save_path)
        control_panel_layout.addWidget(self.w_save_image)
        control_panel_layout.addWidget(self.w_output)
        self.center_layout = QHBoxLayout()

        self.w_pc_viewer = KittiPointCloudView(
            self.w_config, coors_range=self.w_config.get("CoorsRange"))

        # Viewer on the left, control panel on the right, widths in ratio 2:1.
        self.center_layout.addWidget(self.w_pc_viewer)
        self.center_layout.addLayout(control_panel_layout)
        self.center_layout.setStretch(0, 2)
        self.center_layout.setStretch(1, 1)
        center_widget.setLayout(self.center_layout)
        self.setCentralWidget(center_widget)
        self.show()

    # def worker(self, queue, beam_altitude_angles, beam_azimuth_angles):
    #     build_trig_table(beam_altitude_angles, beam_azimuth_angles)
    #     while True:
    #         buffer = queue.get()["buffer"]
    #         buffer_len = len(buffer)
    #         points = np.zeros((3, 256*buffer_len))
    #         for ind in range(buffer_len):
    #             packet = buffer[ind]
    #             coords = np.array(xyz_points(packet, os16=False))
    #             coords = coords[:, self.ind_list]
    #             points[:, ind*256:(ind+1)*256] = coords
    #         self.points = np.transpose(points)
    #         # print(self.points.shape)
    #         # self.plot_pointcloud()
    #         # print(points)

    # def ouster_worker(self):
    #     self.os1.run_forever(self.pc_handler)

    # def spawn_workers(self, n, worker, *args, **kwargs):
    #     processes = []
    #     for i in range(n):
    #         process = Process(
    #             target=worker,
    #             args=args,
    #             kwargs=kwargs
    #         )
    #         process.start()
    #         processes.append(process)
    #     return processes

    def on_panel_clicked(self):
        """Toggle the visibility of the draw-settings control (``self.w_config``)."""
        toggle = self.w_config.show if self.w_config.isHidden() else self.w_config.hide
        toggle()

    def on_saveimg_clicked(self):
        """Slot for the 'save image' button: save ``self.current_image``."""
        image = self.current_image
        self.save_image(image)

    def on_gt_checkbox_statechanged(self):
        """Make the ground-truth class checkbox active and uncheck the detection one."""
        # NOTE(review): neither checkbox is created in the visible init_ui.
        for box, checked in ((self.w_cb_gt_curcls, True), (self.w_cb_dt_curcls, False)):
            box.setChecked(checked)

    def on_dt_checkbox_statechanged(self):
        """Make the detection class checkbox active and uncheck the ground-truth one."""
        # NOTE(review): neither checkbox is created in the visible init_ui.
        for box, checked in ((self.w_cb_gt_curcls, False), (self.w_cb_dt_curcls, True)):
            box.setChecked(checked)

    def on_gt_combobox_changed(self):
        """Reset the ground-truth class cursor, then reload and redraw the point cloud."""
        self._current_gt_cls_idx = 0
        reload_and_plot = self.on_loadButtonPressed
        reload_and_plot()

    def on_dt_combobox_changed(self):
        """Recompute the image indices containing the selected detection class.

        Resets the detection-class cursor. With "All" selected every image
        index with detections is used; otherwise only images whose
        annotation names contain the selected class. Does nothing else when
        no detections are loaded.
        """
        self._current_dt_cls_idx = 0
        # Check for loaded detections *before* filtering: previously
        # filter_empty_annos was called with None when nothing was loaded
        # (presumably not None-safe — TODO confirm).
        if self.dt_image_idxes is None or self.detection_annos is None:
            return
        annos = kitti.filter_empty_annos(self.detection_annos)
        if annos is None:
            return
        current_class = self.dt_combobox.currentText()
        if current_class == "All":
            self._current_dt_cls_ids = self.dt_image_idxes
        else:
            self._current_dt_cls_ids = [
                anno["image_idx"][0] for anno in annos
                if current_class in anno["name"]
            ]

    def message(self, value, *arg, color="Black"):
        """Append ``print``-formatted ``value``/``arg`` to the output pane in ``color``.

        The pane is then scrolled to the bottom so the newest line is visible.
        """
        text = self.print_str(value, *arg)
        html = '<font color="' + color + '">' + text + "</font><br>"
        self.w_output.insertHtml(html)
        scrollbar = self.w_output.verticalScrollBar()
        scrollbar.setValue(scrollbar.maximum())

    def error(self, value, *arg):
        """Log a red, ``[HH:MM:SS]``-prefixed message to the output pane."""
        return self.message(datetime.datetime.now().strftime("[%H:%M:%S]"),
                            value, *arg, color="Red")

    def info(self, value, *arg):
        """Log a black, ``[HH:MM:SS]``-prefixed message to the output pane."""
        return self.message(datetime.datetime.now().strftime("[%H:%M:%S]"),
                            value, *arg, color="Black")

    def warning(self, value, *arg):
        """Log a yellow, ``[HH:MM:SS]``-prefixed message to the output pane."""
        return self.message(datetime.datetime.now().strftime("[%H:%M:%S]"),
                            value, *arg, color="Yellow")

    def save_image(self, image):
        """Save ``image`` (if given) and a screenshot of the point-cloud view.

        The target path comes from the "save image" line edit and is persisted
        to the settings file. The point-cloud screenshot is written next to it
        as ``<stem>_pc.jpg``.
        """
        img_path = self.w_image_save_path.text()
        self.json_setting.set("save_image_path", img_path)
        # Check the argument actually being saved (previously checked
        # self.current_image), and only report a save that happened.
        if image is not None:
            io.imsave(img_path, image)
            self.info("image saved to", img_path)
        p = self.w_pc_viewer.grabFrameBuffer()
        pc_img_path = str(
            Path(img_path).parent / (Path(img_path).stem + "_pc.jpg"))
        p.save(pc_img_path, 'jpg')
        self.info("point cloud image saved to", pc_img_path)

    def print_str(self, value, *arg):
        """Return exactly what ``print(value, *arg)`` would write, as a string.

        Reuses ``self.sstream`` as a scratch buffer, clearing it first.
        """
        buf = self.sstream
        buf.seek(0)
        buf.truncate()
        print(value, *arg, file=buf)
        return buf.getvalue()

    def on_nextOrPrevPressed(self, prev):
        """Move to the previous (``prev is True``) or next image index, clamped, and redraw."""
        if prev is not True:
            last = len(self.image_idxes) - 1
            self.current_idx = min(self.current_idx + 1, last)
        else:
            self.current_idx = max(self.current_idx - 1, 0)
        target = self.image_idxes[self.current_idx]
        self.w_imgidx.setText(str(target))
        self.plot_all(target)

    def on_nextOrPrevCurClsPressed(self, prev):
        """Step through the image indices that contain the selected class.

        Uses the detection-class list when the DT checkbox is checked, else
        the ground-truth list when the GT checkbox is checked; the cursor is
        clamped to the list bounds. Does nothing if neither is checked.
        """
        if self.w_cb_dt_curcls.isChecked():
            if prev is True:
                self._current_dt_cls_idx = max(self._current_dt_cls_idx - 1, 0)
            else:
                info_len = len(self._current_dt_cls_ids)
                self._current_dt_cls_idx = min(self._current_dt_cls_idx + 1,
                                               info_len - 1)
            image_idx = self._current_dt_cls_ids[self._current_dt_cls_idx]
            self.info("current dt image idx:", image_idx)
        elif self.w_cb_gt_curcls.isChecked():
            if prev is True:
                self._current_gt_cls_idx = max(self._current_gt_cls_idx - 1, 0)
            else:
                info_len = len(self._current_gt_cls_ids)
                self._current_gt_cls_idx = min(self._current_gt_cls_idx + 1,
                                               info_len - 1)
            image_idx = self._current_gt_cls_ids[self._current_gt_cls_idx]
            self.info("current gt image idx:", image_idx)
        else:
            # Previously fell through with image_idx unbound (UnboundLocalError).
            return
        self.plot_all(image_idx)

    def on_CmdReturnPressed(self):
        """Echo the command line-edit text into the output pane.

        Uses the ``w_`` widget names: ``self.cmd`` / ``self.output`` never
        existed, while ``w_output`` is the log pane and ``w_cmd`` is the
        (currently commented-out) command line edit wired to this slot.
        """
        cmd = self.print_str(self.w_cmd.text())
        self.w_output.insertPlainText(cmd)

    def on_loadButtonPressed(self):
        """Slot for 'load info': load the point cloud from disk, then draw it."""
        self.load_info()
        self.plot_pointcloud()

    def on_loadDetPressed(self):
        """Load detections from the "PC path" field and refresh the class selector.

        A file is unpickled; anything else is read with
        ``kitti.get_label_annos``. Populates ``detection_annos``,
        ``dt_image_idxes`` and the detection class combobox ("All" + every
        class name found, in first-seen order).
        """
        det_path = self.w_det_path.text()
        if Path(det_path).is_file():
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(det_path, "rb") as f:
                dt_annos = pickle.load(f)
        else:
            dt_annos = kitti.get_label_annos(det_path)
        if len(dt_annos) == 0:
            self.warning("detection path contain nothing.")
            return
        self.detection_annos = dt_annos
        self.info(f"load {len(dt_annos)} detections.")
        self.json_setting.set("latest_det_path", det_path)
        annos = kitti.filter_empty_annos(self.detection_annos)
        self.dt_image_idxes = [anno["image_idx"][0] for anno in annos]
        # Unique class names in first-seen order (dicts keep insertion order).
        available_cls = list(dict.fromkeys(
            name for anno in self.detection_annos for name in anno["name"]))

        self.dt_combobox.clear()
        self.dt_combobox.addItem("All")
        for cls_name in available_cls:
            self.dt_combobox.addItem(cls_name)

        current_class = self.dt_combobox.currentText()
        if current_class == "All":
            self._current_dt_cls_ids = self.dt_image_idxes
        else:
            # Membership test (as in on_dt_combobox_changed): if anno["name"]
            # is a numpy array, `==` yields an element-wise array whose truth
            # value is ambiguous.
            self._current_dt_cls_ids = [
                anno["image_idx"][0] for anno in annos
                if current_class in anno["name"]
            ]
        self._current_dt_cls_idx = 0

    def sample_to_current_data(self):
        """Augment the current frame with objects drawn from ``self.db_sampler``.

        Calls ``db_sampler.sample_all`` with the current ground truth and the
        calibration of ``self.kitti_info``. Sampled names/boxes/difficulty
        (and group ids, when present) are appended, sampled points are
        prepended to ``self.points``, and the combined box mask is then
        applied to the ground-truth arrays. Reports an error and returns when
        no info is loaded or no sampler is available.
        """
        if self.kitti_info is None:
            self.error(
                "you must load infos and choose a existing image idx first.")
            return
        # db_sampler is not assigned in every __init__ version; treat a missing
        # attribute like "no database" instead of raising AttributeError.
        db_sampler = getattr(self, "db_sampler", None)
        if db_sampler is None:
            self.error("you enable sample but not provide a database")
            return

        rect = self.kitti_info['calib/R0_rect']
        P2 = self.kitti_info['calib/P2']
        Trv2c = self.kitti_info['calib/Tr_velo_to_cam']
        num_features = 4
        if 'pointcloud_num_features' in self.kitti_info:
            num_features = self.kitti_info['pointcloud_num_features']

        # All existing ground-truth boxes are kept; sample_all supplies the
        # masks for the sampled ones.
        gt_boxes_mask = np.ones((self.gt_names.shape[0], ), np.bool_)
        sampled_dict = db_sampler.sample_all(
            self.root_path,
            self.gt_boxes,
            self.gt_names,
            num_features,
            False,
            gt_group_ids=self.group_ids,
            rect=rect,
            Trv2c=Trv2c,
            P2=P2)
        if sampled_dict is not None:
            # Read every field first so a missing key leaves state untouched.
            sampled_gt_names = sampled_dict["gt_names"]
            sampled_gt_boxes = sampled_dict["gt_boxes"]
            sampled_points = sampled_dict["points"]
            sampled_gt_masks = sampled_dict["gt_masks"]
            sampled_difficulty = sampled_dict["difficulty"]
            sampled_group_ids = sampled_dict["group_ids"]

            self.gt_names = np.concatenate(
                [self.gt_names, sampled_gt_names], axis=0)
            self.gt_boxes = np.concatenate([self.gt_boxes, sampled_gt_boxes])
            gt_boxes_mask = np.concatenate(
                [gt_boxes_mask, sampled_gt_masks], axis=0)
            self.difficulty = np.concatenate(
                [self.difficulty, sampled_difficulty], axis=0)
            self.points = np.concatenate([sampled_points, self.points], axis=0)
            if self.group_ids is not None:
                self.group_ids = np.concatenate(
                    [self.group_ids, sampled_group_ids])

        # Per-object noise (prep.noise_per_object_) is currently disabled; if
        # re-enabled it must run before the mask below removes objects.
        self.gt_boxes = self.gt_boxes[gt_boxes_mask]
        self.gt_names = self.gt_names[gt_boxes_mask]
        self.difficulty = self.difficulty[gt_boxes_mask]
        if self.group_ids is not None:
            self.group_ids = self.group_ids[gt_boxes_mask]

    def draw_detection(self, detection_anno, label_color=GLColor.Blue,
                       score_threshold=0.2):
        """Draw the 3D boxes of one detection annotation in the point-cloud view.

        A box is kept when its score is above ``score_threshold`` and its
        center lies strictly inside ``self.points_range`` (set by
        ``load_info``). Each box's alpha is its score relative to the best
        kept score. Kept scores and boxes are also logged.

        Args:
            detection_anno: mapping whose ``"box3d_lidar"`` and ``"scores"``
                entries are tensors (converted with ``.detach().cpu().numpy()``).
            label_color: unused; kept for interface compatibility.
            score_threshold: exclusive minimum score (default 0.2, the
                previously hard-coded value).
        """
        dt_box_color = self.w_config.get("DTBoxColor")[:3]
        dt_box_color = (*dt_box_color, self.w_config.get("DTBoxAlpha"))

        dt_box_lidar = detection_anno["box3d_lidar"].detach().cpu().numpy()
        scores = detection_anno["scores"].detach().cpu().numpy()

        # filter by score
        keep = scores > score_threshold
        dt_box_lidar = dt_box_lidar[keep]
        scores = scores[keep]

        dt_boxes_corners = box_np_ops.center_to_corner_box3d(
            dt_box_lidar[:, :3],
            dt_box_lidar[:, 3:6],
            dt_box_lidar[:, 6],
            origin=[0.5, 0.5, 0],
            axis=2)

        # filter bbox by its center; points_range is
        # [max_x, max_y, max_z, min_x, min_y, min_z].
        max_xyz = np.asarray(self.points_range[:3])
        min_xyz = np.asarray(self.points_range[3:])
        centers = (dt_boxes_corners[:, 0, :] + dt_boxes_corners[:, 6, :]) / 2
        centers = centers[:, :3]
        inside = np.all((centers < max_xyz) & (centers > min_xyz), axis=1)
        dt_boxes_corners = dt_boxes_corners[inside]
        dt_box_lidar = dt_box_lidar[inside]
        scores = scores[inside]

        num_dt = dt_box_lidar.shape[0]
        self.info('num_dt', num_dt)
        if num_dt == 0:
            return

        for score, box in zip(scores, dt_box_lidar):
            self.info('scores', score)
            self.info('dt_box_lidar', box)
        # Normalize by the maximum rather than scores[0] so alpha stays in
        # (0, 1] even when scores are not sorted in descending order.
        peak = scores.max()
        scores_rank = scores / peak if peak > 0 else np.ones_like(scores)
        rgb = np.tile(np.asarray(dt_box_color[:3]), (num_dt, 1))
        dt_box_color = np.concatenate([rgb, scores_rank[:, np.newaxis]], axis=1)
        self.w_pc_viewer.boxes3d("dt_boxes", dt_boxes_corners,
                                 dt_box_color,
                                 self.w_config.get("DTBoxLineWidth"), 1.0)

    def plot_pointcloud(self):
        """Redraw ``self.points`` (and detections, if enabled) in the viewer.

        Colors and sizes come from the draw-settings control. With
        "WithReflectivity", point alpha is ``0.8 * column 3 + 0.2``.
        Existing detection boxes and labels are removed first.
        """
        cfg = self.w_config
        pts = self.points
        viewer = self.w_pc_viewer
        rgba = (*cfg.get("PointColor")[:3], cfg.get("PointAlpha"))
        n_points = pts.shape[0]
        colors = np.tile(np.array(rgba), [n_points, 1])
        viewer.reset_camera()
        sizes = np.full([n_points], cfg.get("PointSize"), dtype=np.float32)
        viewer.draw_bounding_box()
        for key in ("dt_boxes/labels", "dt_boxes"):
            viewer.remove(key)
        if self.detection_annos is not None and cfg.get("DrawDTBoxes"):
            self.draw_detection(self.detection_annos[0])
        if cfg.get("WithReflectivity"):
            if pts.shape[1] >= 4:
                colors = np.concatenate(
                    [colors[:, :3], pts[:, 3:4] * 0.8 + 0.2], axis=1)
            else:
                self.error("Your pointcloud don't contain reflectivity.")
        viewer.scatter("pointcloud", pts[:, :3], colors, size=sizes)
        print('DEBUG: plot_pointcloud')

    def load_info(self):
        """Load the point cloud named in the "PC path" field and prepare it for display.

        Steps:
        1. Parse and persist the scale/shift fields.
        2. Load the ``.npy`` file and transpose it so rows are points.
        3. Center x/y on the bounding-box midpoint, then add the x/y/z shifts.
        4. Keep points with z > 0 and column 3 > 100, then zero column 3.
        5. Multiply by the "Up Scale" factor.

        Sets ``self.points`` and ``self.points_range``
        (``[max_x, max_y, max_z, min_x, min_y, min_z]`` of the final cloud).
        """
        # Parse each field once; the normalized values are what get persisted.
        scale_up = float(self.w_up_scale.text())
        x_shift = float(self.w_x_shift.text())
        y_shift = float(self.w_y_shift.text())
        z_shift = float(self.w_z_shift.text())
        self.json_setting.set("up_scale", str(scale_up))
        self.json_setting.set("w_x_shift", str(x_shift))
        self.json_setting.set("w_y_shift", str(y_shift))
        self.json_setting.set("w_z_shift", str(z_shift))

        det_path = self.w_det_path.text()
        self.json_setting.set("latest_det_path", det_path)
        points = np.transpose(np.load(det_path))

        def bounds(pts):
            # [max_x, max_y, max_z, min_x, min_y, min_z]
            xyz = pts[:, :3]
            return list(np.concatenate([xyz.max(axis=0), xyz.min(axis=0)]))

        raw = bounds(points)
        points[:, 0] -= (raw[0] + raw[3]) / 2
        points[:, 1] -= (raw[1] + raw[4]) / 2

        points[:, 0] += x_shift
        points[:, 1] += y_shift
        points[:, 2] += z_shift
        points = points[points[:, 2] > 0]
        points = points[points[:, 3] > 100]
        points[:, 3] = 0

        print(bounds(points))
        points = points * scale_up
        self.points_range = bounds(points)

        self.points = points
        self.json_setting.set("save_image_path", self.w_image_save_path.text())
        print('DEBUG: self.points.shape', self.points.shape)
        print('DEBUG: self.points_range', self.points_range)

    def plot_all(self, image_idx):
        """Reload the point cloud from the GUI fields and redraw it.

        Args:
            image_idx: unused; kept so existing callers keep working.
                ``load_info`` takes no frame index, so it is not forwarded.

        Returns:
            True once loading and plotting have completed.
        """
        self.load_info()
        self.plot_pointcloud()
        return True

    def on_plotButtonPressed(self):
        """Reload the point cloud from the GUI fields and redraw it.

        Also used as the full-redraw handler for the ``RemoveOutsidePoint``
        config entry, so it must actually reload and plot.
        """
        self.load_info()
        self.plot_pointcloud()

    def closeEvent(self, event):
        """Save the viewer configuration, then let the base class close the window."""
        self.json_setting.set("config", self.w_config.dumps())
        return super().closeEvent(event)

    def on_configchanged(self, msg):
        """Persist the configuration and redraw whatever ``msg.name`` affects.

        Args:
            msg: change notification; only its ``name`` attribute is read.

        The whole configuration is written to ``self.json_setting`` first. Then
        exactly one redraw path runs, chosen by which group the changed entry
        belongs to. Unknown names only trigger the save.
        """
        self.json_setting.set("config", self.w_config.dumps())

        pc_redraw_msgs = {
            "PointSize", "PointAlpha", "GTPointSize", "GTPointAlpha",
            "WithReflectivity", "PointColor", "GTPointColor",
        }
        box_redraw = {"GTBoxColor", "GTBoxAlpha"}
        dt_redraw = {
            "DTBoxColor", "DTBoxAlpha", "DrawDTLabels", "DTScoreAsAlpha",
            "DTScoreThreshold", "DTBoxLineWidth",
        }
        vx_redraw_msgs = {
            "DrawPositiveVoxelsOnly", "DrawVoxels",
            "PosVoxelColor", "PosVoxelAlpha",
            "NegVoxelColor", "NegVoxelAlpha",
        }
        all_redraw_msgs = {"RemoveOutsidePoint"}

        name = msg.name
        if name in vx_redraw_msgs:
            if self.w_config.get("DrawVoxels"):
                self.w_pc_viewer.draw_voxels(self.points, self.gt_boxes)
            else:
                self.w_pc_viewer.remove("voxels")
        elif name in pc_redraw_msgs:
            self.plot_pointcloud()
        elif name in all_redraw_msgs:
            self.on_plotButtonPressed()
        elif name in box_redraw:
            self.plot_gt_boxes_in_pointcloud()
        elif name in dt_redraw:
            if self.detection_annos is not None and self.w_config.get(
                    "DrawDTBoxes"):
                # Only this branch needs the current frame's index, so resolve
                # it here. Unrelated changes then do not depend on
                # kitti_info/image_idxes being present.
                idx = self.image_idxes.index(self.kitti_info["image_idx"])
                self.draw_detection(self.detection_annos[idx])

    def on_loadVxNetCkptPressed(self):
        """Restore VoxelNet weights from the checkpoint path entered in the GUI.

        The raw path text is remembered in ``self.json_setting`` and the path
        is handed to ``self.inference_ctx.restore`` as a ``Path``.
        """
        ckpt_text = self.w_vckpt_path.text()
        self.json_setting.set("latest_vxnet_ckpt_path", ckpt_text)
        self.inference_ctx.restore(Path(ckpt_text))
        self.info("load VoxelNet ckpt succeed.")

    def on_BuildVxNetPressed(self):
        """Build a fresh VoxelNet inference context from the GUI's config path.

        The new context replaces ``self.inference_ctx`` only after ``build``
        returns, so a failing build (an exception from ``build``) leaves any
        previously built context in place. The config path is then stored in
        ``self.json_setting`` as ``latest_vxnet_cfg_path``.
        """
        vconfig_path = Path(self.w_vconfig_path.text())
        inference_ctx = TorchInferenceContext()
        inference_ctx.build(vconfig_path)
        self.inference_ctx = inference_ctx
        self.json_setting.set("latest_vxnet_cfg_path", str(vconfig_path))
        self.info("Build VoxelNet ckpt succeed.")

    def on_InferenceVxNetPressed(self):
        """Run VoxelNet on ``self.points`` and draw the first prediction.

        Wall-clock time for input preparation and for the forward pass is
        reported through ``self.info``. Nothing is drawn when the first
        prediction dict has ``'scores'`` set to None.
        """
        ctx = self.inference_ctx

        started = time.time()
        inputs = ctx.get_inference_input_dict(self.points)
        self.info("input preparation time:", time.time() - started)

        started = time.time()
        with ctx.ctx():
            predictions = ctx.inference(inputs)
        self.info("detection time:", time.time() - started)

        first = predictions[0]
        if first['scores'] is None:
            return
        self.draw_detection(first)

    def on_LoadInferenceVxNetPressed(self):
        """Build the network, restore its checkpoint, then run inference."""
        pipeline = (
            self.on_BuildVxNetPressed,
            self.on_loadVxNetCkptPressed,
            self.on_InferenceVxNetPressed,
        )
        for step in pipeline:
            step()

    @staticmethod
    def get_simpify_labels(labels):
        label_map = {
            "Car": "V",
            "Pedestrian": "P",
            "Cyclist": "C",
            "car": "C",
            "tractor": "T1",
            "trailer": "T2",
        }
        label_count = {
            "Car": 0,
            "Pedestrian": 0,
            "Cyclist": 0,
            "car": 0,
            "tractor": 0,
            "trailer": 0,
        }
        ret = []
        for i, name in enumerate(labels):
            count = 0
            if name in label_count:
                count = label_count[name]
                label_count[name] += 1
            else:
                label_count[name] = 0
            ret.append(f"{label_map[name]}{count}")
        return ret

    @staticmethod
    def get_false_pos_neg(gt_boxes, dt_boxes, labels, fp_thresh=0.1):
        """Classify each ground-truth box by its best-overlapping detection.

        Args:
            gt_boxes: ground-truth boxes, one per entry of ``labels``.
            dt_boxes: detected boxes.
            labels: class name of each ground-truth box. It selects the
                per-class IoU threshold for code 0 from the map below.
            fp_thresh: IoU strictly above which a below-threshold overlap is
                reported as code 1.

        Returns:
            ``(ret, assigned_dt)``.

            ``ret`` is an int64 array with one code per gt box:
              - 0 when its best IoU reaches its class threshold;
              - 1 ("FP") when the best IoU lies strictly between ``fp_thresh``
                and the class threshold;
              - 2 otherwise, meaning no usable match (presumably a miss).

            ``assigned_dt`` is a bool array over the detections. It is True
            exactly for detections that are the best match of a gt box coded
            0 or 1.

        Raises:
            KeyError: when both box sets are non-empty and a label has no
                threshold in the map.
        """
        ret = np.full([len(gt_boxes)], 2, dtype=np.int64)
        assigned_dt = np.zeros([len(dt_boxes)], dtype=np.bool_)
        if len(gt_boxes) == 0 or len(dt_boxes) == 0:
            # Nothing can match, so skip the IoU computation entirely.
            return ret, assigned_dt

        label_thresh_map = {
            "Car": 0.7,
            "Pedestrian": 0.5,
            "Cyclist": 0.5,
            "car": 0.7,
            "tractor": 0.7,
            "trailer": 0.7,
        }
        tp_thresh = np.array([label_thresh_map[n] for n in labels])
        iou = _riou3d_shapely(gt_boxes, dt_boxes)
        best_iou = iou.max(1)     # best IoU per gt box
        best_dt = iou.argmax(1)   # index of that best detection
        ret[best_iou >= tp_thresh] = 0
        ret[np.logical_and(best_iou < tp_thresh,
                           best_iou > fp_thresh)] = 1  # FP
        # Derive assignment from the classification above so the two always
        # agree. Previously a ">=" vs ">" mismatch let a gt coded 2 (e.g. zero
        # overlap with fp_thresh=0) still claim a detection.
        assigned_dt[best_dt[ret != 2]] = True
        return ret, assigned_dt
Beispiel #10
0
 def on_BuildVxNetPressed(self):
     self.inference_ctx = TorchInferenceContext()
     vconfig_path = Path(self.w_vconfig_path.text())
     self.inference_ctx.build(vconfig_path)
     self.json_setting.set("latest_vxnet_cfg_path", str(vconfig_path))
     self.info("Build VoxelNet ckpt succeed.")