Beispiel #1
0
class LaserScanVis:
    """Interactive vispy viewer that steps through a sequence of laser scans.

    Two canvases are opened: one with 3D point-cloud views (range-coloured
    scan, plus semantic / instance colourings when enabled) and one with the
    2D projected range image (plus semantic / instance projections when
    enabled). Keys: ``N`` next scan, ``B`` previous scan (both wrap around
    at the ends of ``scan_names``), ``Q`` / ``Escape`` quit.
    """

    def __init__(self,
                 scan,
                 scan_names,
                 label_names,
                 offset=0,
                 semantics=True,
                 instances=False):
        """Build the canvases and show the scan at ``offset``.

        Args:
            scan: scan loader; must provide ``open_scan``, ``open_label``,
                ``colorize`` and the arrays read in :meth:`update_scan`.
            scan_names: sequence of scan paths, indexed by ``offset``.
            label_names: sequence of label paths, indexed by ``offset``
                (only read when ``semantics`` is True).
            offset: index of the first scan to display.
            semantics: whether to show semantic colourings.
            instances: whether to show instance colourings; requires
                ``semantics``.

        Raises:
            ValueError: if ``instances`` is True while ``semantics`` is False.
        """
        self.scan = scan
        self.scan_names = scan_names
        self.label_names = label_names
        self.offset = offset
        self.semantics = semantics
        self.instances = instances
        # label files are only opened when semantics=True (see update_scan),
        # so instance views cannot be shown without them
        if not self.semantics and self.instances:
            raise ValueError("Instances are only allowed when semantics=True")

        self.reset()
        self.update_scan()

    def reset(self):
        """Create both canvases and all views/visuals, wired to key/draw events."""
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        # 3D canvas: one turntable view per enabled colouring, side by side
        self.canvas = SceneCanvas(keys='interactive', show=True)
        # interface (n next, b back, q quit, very simple)
        self.canvas.events.key_press.connect(self.key_press)
        self.canvas.events.draw.connect(self.draw)
        self.grid = self.canvas.central_widget.add_grid()

        self.scan_view, self.scan_vis = self._add_cloud_view(0)
        if self.semantics:
            print("Using semantics in visualizer")
            self.sem_view, self.sem_vis = self._add_cloud_view(1)
        if self.instances:
            print("Using instances in visualizer")
            self.inst_view, self.inst_vis = self._add_cloud_view(2)

        # 2D canvas: one row of height canvas_H per enabled image
        self.multiplier = 1
        self.canvas_W = 1024
        self.canvas_H = 64
        if self.semantics:
            self.multiplier += 1
        if self.instances:
            self.multiplier += 1

        self.img_canvas = SceneCanvas(keys='interactive',
                                      show=True,
                                      size=(self.canvas_W,
                                            self.canvas_H * self.multiplier))
        self.img_grid = self.img_canvas.central_widget.add_grid()
        # interface (n next, b back, q quit, very simple)
        self.img_canvas.events.key_press.connect(self.key_press)
        self.img_canvas.events.draw.connect(self.draw)

        # row 0: depth, row 1: semantics, row 2: instances
        self.img_view, self.img_vis = self._add_image_view(0)
        if self.semantics:
            self.sem_img_view, self.sem_img_vis = self._add_image_view(1)
        if self.instances:
            self.inst_img_view, self.inst_img_vis = self._add_image_view(2)

    def _add_cloud_view(self, column):
        """Add a turntable 3D view with a Markers visual at grid cell (0, column)."""
        view = vispy.scene.widgets.ViewBox(border_color='white',
                                           parent=self.canvas.scene)
        self.grid.add_widget(view, 0, column)
        vis = visuals.Markers()
        view.camera = 'turntable'
        view.add(vis)
        visuals.XYZAxis(parent=view.scene)
        return view, vis

    def _add_image_view(self, row):
        """Add a 2D view with a viridis Image visual at image-grid cell (row, 0)."""
        view = vispy.scene.widgets.ViewBox(border_color='white',
                                           parent=self.img_canvas.scene)
        self.img_grid.add_widget(view, row, 0)
        vis = visuals.Image(cmap='viridis')
        view.add(vis)
        return view, vis

    @staticmethod
    def _normalize(values):
        """Linearly rescale ``values`` to [0, 1]; constant input maps to all zeros."""
        lo = values.min()
        span = values.max() - lo
        if span > 0:
            return (values - lo) / span
        return np.zeros(values.shape)

    def get_mpl_colormap(self, cmap_name):
        """Return matplotlib colormap ``cmap_name`` as a (256, 3) float32 table.

        Channels are in BGR order (``[:, 2::-1]`` drops alpha and reverses
        RGB); callers flip them back with ``[..., ::-1]`` when drawing.
        """
        cmap = plt.get_cmap(cmap_name)
        sm = plt.cm.ScalarMappable(cmap=cmap)
        color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]
        return color_range.reshape(256, 3).astype(np.float32) / 255.0

    def update_scan(self):
        """Load the scan (and labels) at ``self.offset`` and refresh every view."""
        # first open data
        self.scan.open_scan(self.scan_names[self.offset])
        if self.semantics:
            self.scan.open_label(self.label_names[self.offset])
            self.scan.colorize()

        title = "scan " + str(self.offset) + " of " + str(len(self.scan_names))
        self.canvas.title = title
        self.img_canvas.title = title

        # the 16th root compresses the range so near points are distinguishable
        power = 16

        # point cloud coloured by (compressed, normalised) range
        range_data = self.scan.unproj_range ** (1 / power)
        viridis_range = (self._normalize(range_data) * 255).astype(np.uint8)
        viridis_colors = self.get_mpl_colormap("viridis")[viridis_range]
        self.scan_vis.set_data(self.scan.points,
                               face_color=viridis_colors[..., ::-1],
                               edge_color=viridis_colors[..., ::-1],
                               size=1)

        if self.semantics:
            self.sem_vis.set_data(
                self.scan.points,
                face_color=self.scan.sem_label_color[..., ::-1],
                edge_color=self.scan.sem_label_color[..., ::-1],
                size=1)

        if self.instances:
            self.inst_vis.set_data(
                self.scan.points,
                face_color=self.scan.inst_label_color[..., ::-1],
                edge_color=self.scan.inst_label_color[..., ::-1],
                size=1)

        # range image: compress positive ranges, fill negative pixels with the
        # smallest valid value, then normalise against the valid range
        data = np.copy(self.scan.proj_range)
        valid = data > 0
        if np.any(valid):
            data[valid] = data[valid] ** (1 / power)
            lo = data[valid].min()
            data[data < 0] = lo
            span = data.max() - lo
            data = (data - lo) / span if span > 0 else np.zeros_like(data)
        else:
            # no valid pixel at all: show a blank image instead of crashing
            data = np.zeros_like(data)
        self.img_vis.set_data(data)
        self.img_vis.update()

        if self.semantics:
            self.sem_img_vis.set_data(self.scan.proj_sem_color[..., ::-1])
            self.sem_img_vis.update()

        if self.instances:
            self.inst_img_vis.set_data(self.scan.proj_inst_color[..., ::-1])
            self.inst_img_vis.update()

    # interface
    def key_press(self, event):
        """Handle N (next), B (back) and Q/Escape (quit); N/B wrap around."""
        # key presses stay blocked until draw() unblocks them, presumably so
        # a new key is not handled before the canvas has redrawn
        self.canvas.events.key_press.block()
        self.img_canvas.events.key_press.block()
        if event.key == 'N':
            self.offset = (self.offset + 1) % len(self.scan_names)
            self.update_scan()
        elif event.key == 'B':
            self.offset = (self.offset - 1) % len(self.scan_names)
            self.update_scan()
        elif event.key == 'Q' or event.key == 'Escape':
            self.destroy()

    def draw(self, event):
        """Re-enable key presses on both canvases after a redraw."""
        if self.canvas.events.key_press.blocked():
            self.canvas.events.key_press.unblock()
        if self.img_canvas.events.key_press.blocked():
            self.img_canvas.events.key_press.unblock()

    def destroy(self):
        """Close both canvases and stop the vispy event loop."""
        self.canvas.close()
        self.img_canvas.close()
        vispy.app.quit()

    def run(self):
        """Enter the vispy event loop (blocks until quit)."""
        vispy.app.run()
class PointVis:
    """Vispy viewer for a single point cloud, optionally with label colours.

    The left view colours points by their (compressed) distance from the
    origin; when ``viz_label`` is set, a second camera-linked view colours
    them through ``sem_color_lut``. Joint arrows and a box can optionally be
    added to the left view.
    """

    def __init__(self,
                 target_pts=None,
                 viz_dict=None,
                 viz_point=True,
                 viz_label=True,
                 viz_joint=False,
                 viz_box=False):
        """Build the canvas and draw ``viz_dict`` (if given).

        Args:
            target_pts: forwarded to :meth:`update_scan` as ``points``.
            viz_dict: data dict forwarded to :meth:`update_scan`.
            viz_point: stored flag; not otherwise read in this class.
            viz_label: add a second view coloured by label.
            viz_joint: add joint arrows to the point view.
            viz_box: add a translucent box to the point view.
        """
        self.viz_point = viz_point
        self.viz_label = viz_label
        self.viz_joint = viz_joint
        self.viz_box = viz_box
        self.reset()

        self.update_scan(target_pts, viz_dict)

    def reset(self, sem_color_dict=None):
        """(Re)build the colour table, the canvas and all views/visuals.

        Args:
            sem_color_dict: optional label -> colour dict, see :meth:`map_color`.
        """
        self.map_color(sem_color_dict=sem_color_dict)
        self.canvas = SceneCanvas(keys='interactive', show=True)
        self.grid = self.canvas.central_widget.add_grid()

        # point view (grid cell 0, 0)
        self.scan_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.grid.add_widget(self.scan_view, 0, 0)
        self.scan_view.camera = 'turntable'

        self.scan_vis = visuals.Markers()
        self.scan_view.add(self.scan_vis)

        if self.viz_joint:
            self.joint_vis = visuals.Arrow(connect='segments',
                                           arrow_size=18,
                                           color='blue',
                                           width=10,
                                           arrow_type='angle_60')
            self.arrow_length = 10
            self.scan_view.add(self.joint_vis)
        if self.viz_box:
            vertices, faces, outline = create_box(width=1,
                                                  height=1,
                                                  depth=1,
                                                  width_segments=1,
                                                  height_segments=1,
                                                  depth_segments=1)
            # alpha 0.2 so points inside the box stay visible
            vertices['color'][:, 3] = 0.2
            self.box = visuals.Box(vertex_colors=vertices['color'],
                                   edge_color='b')
            self.box.transform = STTransform(translate=[-2.5, 0, 0])
            self.theta = 0
            self.phi = 0
            self.scan_view.add(self.box)
        visuals.XYZAxis(parent=self.scan_view.scene)

        # label view (grid cell 0, 1), camera linked to the point view
        if self.viz_label:
            print("Using nocs in visualizer")
            self.nocs_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.canvas.scene)
            self.grid.add_widget(self.nocs_view, 0, 1)
            self.label_vis = visuals.Markers()
            self.nocs_view.camera = 'turntable'
            self.nocs_view.add(self.label_vis)
            visuals.XYZAxis(parent=self.nocs_view.scene)
            self.nocs_view.camera.link(self.scan_view.camera)

    def update_scan(self, points, viz_dict=None):
        """Redraw the views from ``viz_dict`` and refresh optional overlays.

        Args:
            points: currently unused; the drawn points come from ``viz_dict``.
            viz_dict: dict with ``'input1'`` and ``'input2'`` point arrays
                (3 columns each) and ``'label'``, a 1-D integer label array
                for ``'input1'``. Every ``'input2'`` point gets label 4. When
                None, only the joint/box overlays are refreshed.
        """
        self.canvas.title = "scan "
        if viz_dict is not None:
            target_pts = np.concatenate(
                [viz_dict['input1'], viz_dict['input2']], axis=0)
            gt_labels = np.concatenate([
                viz_dict['label'],
                np.full(viz_dict['input2'].shape[0], 4, dtype=np.int32),
            ], axis=0)
            print(np.max(target_pts, axis=0).reshape(1, 3))
            print(np.min(target_pts, axis=0).reshape(1, 3))

            viridis_colors = self._range_colors(target_pts)
            self.scan_vis.set_data(target_pts,
                                   face_color=viridis_colors[..., ::-1],
                                   edge_color=viridis_colors[..., ::-1],
                                   size=5)

            if self.viz_label:
                label_colors = self.sem_color_lut[gt_labels].reshape((-1, 3))
                self.label_vis.set_data(target_pts,
                                        face_color=label_colors[..., ::-1],
                                        edge_color=label_colors[..., ::-1],
                                        size=5)
        if self.viz_joint:
            self.update_joints()

        if self.viz_box:
            self.update_boxes()

    def _range_colors(self, pts, power=16):
        """Viridis colour (BGR order) per point from its compressed distance to the origin."""
        range_data = np.linalg.norm(pts, axis=1) ** (1 / power)
        lo = range_data.min()
        span = range_data.max() - lo
        # a constant range maps to the first colour instead of NaN
        scaled = (range_data - lo) / span if span > 0 else np.zeros_like(range_data)
        return self.get_mpl_colormap("viridis")[(scaled * 255).astype(np.uint8)]

    def map_color(self, max_classes=20, sem_color_dict=None):
        """Build ``self.sem_color_lut``, a (K, 3) label -> colour table in [0, 1].

        The table is indexed by label and its channels are reversed
        (``[..., ::-1]``) when drawn in :meth:`update_scan`.

        Args:
            max_classes: number of random colours when no dict is given; at
                least 5 rows are always allocated because rows 0, 1, 2 and 4
                are overwritten with fixed colours.
            sem_color_dict: optional ``{int label: 3 values in 0-255}``. The
                table is zero-padded by 100 rows past the largest key
                (presumably so labels missing from the dict still index
                safely).
        """
        if sem_color_dict:
            max_sem_key = max(0, max(sem_color_dict) + 1)
            self.sem_color_lut = np.zeros((max_sem_key + 100, 3),
                                          dtype=np.float32)
            for key, value in sem_color_dict.items():
                self.sem_color_lut[key] = np.array(value, np.float32) / 255.0
        else:
            max_sem_key = max(max_classes, 5)
            self.sem_color_lut = np.random.uniform(low=0.0,
                                                   high=1.0,
                                                   size=(max_sem_key, 3))
            self.sem_color_lut[0] = np.full((3), 0.2)  # gray-ish
            self.sem_color_lut[4] = np.full((3), 0.6)  # label used for input2
            self.sem_color_lut[1] = np.array([1.0, 0.0, 0.0])
            self.sem_color_lut[2] = np.array([0.0, 0.0, 1.0])

    def get_mpl_colormap(self, cmap_name):
        """Return matplotlib colormap ``cmap_name`` as a (256, 3) float32 BGR table."""
        cmap = plt.get_cmap(cmap_name)
        sm = plt.cm.ScalarMappable(cmap=cmap)
        # [:, 2::-1] drops alpha and reverses RGB -> BGR
        color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]
        return color_range.reshape(256, 3).astype(np.float32) / 255.0

    def update_joints(self, joints=None):
        """Draw one arrow segment of length ``self.arrow_length`` per joint.

        Args:
            joints: optional dict with ``'p'`` (a 3-vector position) and
                ``'l'`` (a 3-vector); the arrow is parallel to ``-l``. When
                None, two fixed demo arrows are drawn.
        """
        # float32 so the offsets below are not truncated to integers
        if joints is not None:
            start_coords = np.asarray(joints['p'], dtype=np.float32).reshape(1, 3)
            point_towards = start_coords + np.asarray(
                joints['l'], dtype=np.float32).reshape(1, 3)
        else:
            start_coords = np.array([[1, 0, 0], [-1, 0, 0]], dtype=np.float32)
            point_towards = np.array([[0, 0, 1], [0, 0, 1]], dtype=np.float32)
        direction_vectors = start_coords - point_towards
        norms = np.linalg.norm(direction_vectors, axis=-1, keepdims=True)
        norms[norms == 0] = 1.0  # zero-length directions stay zero, not NaN
        direction_vectors = direction_vectors / norms

        # each joint becomes a segment centred on its start coordinate
        half = 0.5 * self.arrow_length * direction_vectors
        vertices = np.repeat(start_coords, 2, axis=0)
        vertices[::2] += half
        vertices[1::2] -= half

        self.joint_vis.set_data(
            pos=vertices,
            arrows=vertices.reshape((len(vertices) // 2, 6)),
        )

    def update_boxes(self):
        """Placeholder: box updates are not implemented yet."""
        pass

    def draw(self, event):
        """Unblock key presses on the canvas (and image canvas, if one exists)."""
        # this class never creates img_canvas itself, so look it up defensively
        for canvas in (self.canvas, getattr(self, 'img_canvas', None)):
            if canvas is not None and canvas.events.key_press.blocked():
                canvas.events.key_press.unblock()

    def destroy(self):
        """Close the canvas and stop the vispy event loop."""
        self.canvas.close()
        vispy.app.quit()

    def run(self):
        """Enter the vispy event loop (blocks until quit)."""
        vispy.app.run()
class LaserDetVis:
    """Vispy viewer for a point cloud with predicted and ground-truth 3D boxes.

    Points are coloured by intensity when available. Predicted boxes are drawn
    in red with score labels, ground-truth boxes in green and, when
    ``show_img`` is set, an image is shown on a second canvas.
    ``key_press`` raises NotImplementedError here; it is presumably meant to
    be overridden by a subclass.
    """

    def __init__(self, show_img=False):
        """Create the canvas(es).

        Args:
            show_img: also open an image canvas fed by ``update_view(img=...)``.
        """
        self.show_img = show_img
        self.canvas_size = (1920, 1920)
        self.running = True
        self.intensity_mode = False
        # data last handed to add_points / add_image; None until provided
        self.points = None
        self.viridis_colors = None
        self.image = None

        self.reset()

    def is_running(self):
        """Return False once :meth:`destroy` has been called."""
        return self.running

    def key_press(self, event):
        """Key handler connected to the canvases; must be overridden."""
        raise NotImplementedError

    def add_points(self, points):
        """Store a point cloud (and intensity colours) for drawing.

        Args:
            points: (N, 3) or (N, >=4) array. Columns 0-2 are xyz; when
                present, column 3 is used as intensity to colour the points.
        """
        self.points = points[:, 0:3]
        if points.shape[1] < 4 or len(points) == 0:
            # no intensity to show: drop colours left over from a previous cloud
            self.viridis_colors = None
            return
        intensity = points[:, 3]
        lo = intensity.min()
        span = intensity.max() - lo
        scaled = (intensity - lo) / span if span > 0 else np.zeros(intensity.shape)
        # map into the upper half of the colormap (indices 127..255)
        viridis_map = self.get_mpl_colormap("viridis")
        self.viridis_colors = viridis_map[(scaled * 128 + 127).astype(np.uint8)]

    def compute_box_3d(self, obj):
        """Return the 8 corners and 12 edges of one 3D box.

        Corner ordering::

                7 -------- 4
               /|         /|
              6 -------- 5 .
              | |        | |
              . 3 -------- 0
              |/         |/
              2 -------- 1

        Args:
            obj: one box ``[x, y, z, dx, dy, dz, heading]`` (numpy array or
                torch tensor); (x, y, z) is the box centre and the corners are
                rotated by ``heading`` via ``rotate_points_along_z``.

        Returns:
            (corners, connect): ``corners`` is an (8, 3) numpy array and
            ``connect`` a (12, 2) int32 array of corner-index pairs (edges).
        """
        boxes3d, _ = check_numpy_to_torch(obj)

        template = boxes3d.new_tensor((
            [1, 1, -1],
            [1, -1, -1],
            [-1, -1, -1],
            [-1, 1, -1],
            [1, 1, 1],
            [1, -1, 1],
            [-1, -1, 1],
            [-1, 1, 1],
        )) / 2

        corners3d = boxes3d[None, 3:6].repeat(8, 1) * template[:, :]
        corners3d = rotate_points_along_z(corners3d.view(-1, 8, 3),
                                          boxes3d[None, 6]).view(8, 3)
        corners3d += boxes3d[0:3]

        connect = np.array([[0, 1], [1, 5], [0, 4], [4, 5], [1, 2], [0, 3],
                            [5, 6], [4, 7], [2, 3], [2, 6], [3, 7], [6, 7]],
                           dtype=np.int32)

        # the corners feed vispy, so always return numpy whatever the input was
        return corners3d.detach().cpu().numpy(), connect

    def add_objs(self, objs):
        """Build line geometry for a set of boxes.

        Args:
            objs: sequence of N boxes, each (x, y, z, w, l, h, yaw).

        Returns:
            (vertices, connect, label_pos): stacked (8N, 3) corners, (12N, 2)
            edge index pairs into ``vertices``, and one (x, y, z) tuple per
            box (its corner 7) where a text label is placed. Empty input
            yields empty (0, 3) / (0, 2) arrays and an empty list.
        """
        if len(objs) == 0:
            return (np.zeros((0, 3), dtype=np.float32),
                    np.zeros((0, 2), dtype=np.int32), [])

        obj_vertices = []
        obj_vert_connect = []
        obj_label_pos = []
        offset = 0
        for obj in objs:
            vertices, connect = self.compute_box_3d(obj)
            obj_vertices.append(vertices)
            # shift edge indices so they point into the stacked vertex array
            obj_vert_connect.append(connect + offset)
            offset += len(vertices)
            obj_label_pos.append(
                (vertices[-1, 0], vertices[-1, 1], vertices[-1, 2]))

        return (np.concatenate(obj_vertices, axis=0),
                np.concatenate(obj_vert_connect, axis=0), obj_label_pos)

    def add_image(self, img):
        """Store an image for drawing; (M, N), (M, N, 3) or (M, N, 4) array."""
        self.image = img

    def reset(self):
        """Create the canvas(es), camera and visuals, wired to key/draw events."""
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        self.canvas = SceneCanvas(keys='interactive',
                                  show=True,
                                  size=self.canvas_size)
        # interface (n next, b back, q quit, very simple)
        self.canvas.events.key_press.connect(self.key_press)
        self.canvas.events.draw.connect(self.draw)

        # 3D view with a z-up turntable camera
        self.scan_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.scan_view.camera = vispy.scene.TurntableCamera(elevation=30,
                                                            azimuth=-90,
                                                            distance=30,
                                                            translate_speed=30,
                                                            up='+z')
        self.grid = self.canvas.central_widget.add_grid()
        self.grid.add_widget(self.scan_view)

        self.scan_vis = visuals.Markers(parent=self.scan_view.scene)
        self.scan_vis.antialias = 0
        visuals.XYZAxis(parent=self.scan_view.scene)

        # predicted boxes, their score labels and ground-truth boxes
        self.line = visuals.Line(width=1,
                                 method='gl',
                                 parent=self.scan_view.scene)
        self.text = visuals.Text(color='red',
                                 font_size=600,
                                 bold=True,
                                 parent=self.scan_view.scene)
        self.gt_line = visuals.Line(width=1000, parent=self.scan_view.scene)

        if self.show_img:
            self.img_canvas = SceneCanvas(keys='interactive',
                                          show=True,
                                          size=(1242, 375))
            self.img_grid = self.img_canvas.central_widget.add_grid()
            # interface (n next, b back, q quit, very simple)
            self.img_canvas.events.key_press.connect(self.key_press)
            self.img_canvas.events.draw.connect(self.draw)

            self.img_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.img_canvas.scene)
            self.img_grid.add_widget(self.img_view, 0, 0)
            self.img_vis = visuals.Image(cmap='viridis')
            self.img_view.add(self.img_vis)

    def get_mpl_colormap(self, cmap_name):
        """Return matplotlib colormap ``cmap_name`` as a (256, 3) float32 BGR table."""
        cmap = plt.get_cmap(cmap_name)
        sm = plt.cm.ScalarMappable(cmap=cmap)
        # [:, 2::-1] drops alpha and reverses RGB -> BGR
        color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]
        return color_range.reshape(256, 3).astype(np.float32) / 255.0

    def update_view(self,
                    idx,
                    points=None,
                    objs=None,
                    gt_objs=None,
                    ref_scores=None,
                    ref_labels=None,
                    img=None):
        """Redraw the scene for frame ``idx``.

        Arguments left as None keep the previously stored points/image; a
        None or empty ``objs`` / ``gt_objs`` clears the corresponding boxes.

        Args:
            idx: frame index shown in the window title.
            points: point cloud, see :meth:`add_points`.
            objs: predicted boxes, see :meth:`add_objs`.
            gt_objs: ground-truth boxes, see :meth:`add_objs`.
            ref_scores: per-box scores, drawn as labels together with
                ``ref_labels``.
            ref_labels: per-box labels; only their count is used here.
            img: image for the image canvas (when ``show_img``).
        """
        title = "scan " + str(idx)
        self.canvas.title = title

        if points is not None:
            self.add_points(points)
        if self.points is not None:
            if self.viridis_colors is not None:
                colors = self.viridis_colors[..., ::-1]
                self.scan_vis.set_data(self.points,
                                       face_color=colors,
                                       edge_color=colors,
                                       size=1)
            else:
                self.scan_vis.set_data(self.points, size=1)

        if objs is None or len(objs) == 0:
            self.line.set_data()
        else:
            obj_vertices, obj_vert_connect, obj_label_pos = self.add_objs(objs)
            self.line.set_data(pos=obj_vertices,
                               color='red',
                               connect=obj_vert_connect)
            if ref_scores is not None and ref_labels is not None:
                self.text.text = [
                    '{:.2f}'.format(score)
                    for score, _ in zip(ref_scores, ref_labels)
                ]
                self.text.pos = obj_label_pos

        if gt_objs is None or len(gt_objs) == 0:
            self.gt_line.set_data()
        else:
            gt_obj_vertices, gt_obj_vert_connect, _ = self.add_objs(gt_objs)
            self.gt_line.set_data(pos=gt_obj_vertices,
                                  color='green',
                                  connect=gt_obj_vert_connect)

        if self.show_img:
            self.img_canvas.title = title
            if img is not None:
                self.add_image(img)
            if self.image is not None:
                self.img_vis.set_data(self.image)
                self.img_vis.update()

    def draw(self, event):
        """Re-enable key presses on the canvas(es) after a redraw."""
        if self.canvas.events.key_press.blocked():
            self.canvas.events.key_press.unblock()
        if self.show_img:
            if self.img_canvas.events.key_press.blocked():
                self.img_canvas.events.key_press.unblock()

    def destroy(self):
        """Close the canvas(es), stop the event loop and mark as not running."""
        self.canvas.close()
        if self.show_img:
            self.img_canvas.close()

        vispy.app.quit()
        self.running = False

    def run(self):
        """Enter the vispy event loop (blocks until quit)."""
        vispy.app.run()
Beispiel #4
0
class QtViewer(QSplitter):
    """Qt splitter widget that presents a ``viewer`` model.

    The widget hosts a vispy ``SceneCanvas`` with the dims sliders beneath
    it, dock widgets for the layer list, the layer controls and an initially
    hidden console. It keeps the vispy visuals, camera, cursor and stylesheet
    in sync with the events emitted by ``viewer``.
    """

    # Untemplated stylesheet, read once when the class is defined and themed
    # with the active palette in ``_update_palette``.
    with open(os.path.join(resources_dir, 'stylesheet.qss'), 'r') as f:
        raw_stylesheet = f.read()

    def __init__(self, viewer):
        """Build the widget tree for ``viewer`` and connect its events.

        Parameters
        ----------
        viewer : object
            Viewer model; this widget reads its ``dims``, ``layers``,
            ``events``, ``palette``, ``keymap``, ``cursor`` and
            ``active_layer`` attributes.
        """
        super().__init__()

        self.pool = QThreadPool()

        QCoreApplication.setAttribute(
            Qt.AA_UseStyleSheetPropagationInWidgetStyles, True)

        self.viewer = viewer
        self.dims = QtDims(self.viewer.dims)
        self.controls = QtControls(self.viewer)
        self.layers = QtLayerList(self.viewer.layers)
        self.layerButtons = QtLayerButtons(self.viewer)
        self.viewerButtons = QtViewerButtons(self.viewer)
        self.console = QtConsole({'viewer': self.viewer})

        # layer list panel: layer buttons, the list itself, viewer buttons
        layerList = QWidget()
        layerList.setObjectName('layerList')
        layerListLayout = QVBoxLayout()
        layerListLayout.addWidget(self.layerButtons)
        layerListLayout.addWidget(self.layers)
        layerListLayout.addWidget(self.viewerButtons)
        layerListLayout.setContentsMargins(8, 4, 8, 6)
        layerList.setLayout(layerListLayout)
        self.dockLayerList = QtViewerDockWidget(
            self,
            layerList,
            name='layer list',
            area='left',
            allowed_areas=['left', 'right'],
        )
        self.dockLayerControls = QtViewerDockWidget(
            self,
            self.controls,
            name='layer controls',
            area='left',
            allowed_areas=['left', 'right'],
        )
        self.dockConsole = QtViewerDockWidget(
            self,
            self.console,
            name='console',
            area='bottom',
            allowed_areas=['top', 'bottom'],
            shortcut='Ctrl+Shift+C',
        )
        self.dockConsole.setVisible(False)
        self.dockLayerControls.visibilityChanged.connect(self._constrain_width)
        self.dockLayerList.setMaximumWidth(258)
        self.dockLayerList.setMinimumWidth(258)

        self.aboutKeybindings = QtAboutKeybindings(self.viewer)
        self.aboutKeybindings.hide()

        # This dictionary holds the corresponding vispy visual for each layer
        self.layer_to_visual = {}

        if self.console.shell is not None:
            # the lambda discards any argument ``clicked`` passes along
            self.viewerButtons.consoleButton.clicked.connect(
                lambda: self.toggle_console())
        else:
            self.viewerButtons.consoleButton.setEnabled(False)

        self.canvas = SceneCanvas(keys=None, vsync=True)
        self.canvas.events.ignore_callback_errors = False
        self.canvas.events.draw.connect(self.dims.enable_play)
        self.canvas.native.setMinimumSize(QSize(200, 200))
        self.canvas.context.set_depth_func('lequal')

        self.canvas.connect(self.on_mouse_move)
        self.canvas.connect(self.on_mouse_press)
        self.canvas.connect(self.on_mouse_release)
        self.canvas.connect(self.on_key_press)
        self.canvas.connect(self.on_key_release)
        self.canvas.connect(self.on_draw)

        self.view = self.canvas.central_widget.add_view()
        self._update_camera()

        # main area: canvas on top, dims sliders below
        main_widget = QWidget()
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(10, 22, 10, 2)
        main_layout.addWidget(self.canvas.native)
        main_layout.addWidget(self.dims)
        main_layout.setSpacing(10)
        main_widget.setLayout(main_layout)

        self.setOrientation(Qt.Vertical)
        self.addWidget(main_widget)

        self._last_visited_dir = str(Path.home())

        # viewer cursor names -> Qt cursors ('square' is built on demand)
        self._cursors = {
            'disabled':
            QCursor(
                QPixmap(':/icons/cursor/cursor_disabled.png').scaled(20, 20)),
            'cross':
            Qt.CrossCursor,
            'forbidden':
            Qt.ForbiddenCursor,
            'pointing':
            Qt.PointingHandCursor,
            'standard':
            QCursor(),
        }

        self._update_palette(viewer.palette)

        # key -> generator awaiting its release step (see on_key_press)
        self._key_release_generators = {}

        self.viewer.events.interactive.connect(self._on_interactive)
        self.viewer.events.cursor.connect(self._on_cursor)
        self.viewer.events.reset_view.connect(self._on_reset_view)
        self.viewer.events.palette.connect(
            lambda event: self._update_palette(event.palette))
        self.viewer.layers.events.reordered.connect(self._reorder_layers)
        self.viewer.layers.events.added.connect(self._add_layer)
        self.viewer.layers.events.removed.connect(self._remove_layer)
        self.viewer.dims.events.camera.connect(
            lambda event: self._update_camera())
        # stop any animations whenever the layers change
        self.viewer.events.layers_change.connect(lambda x: self.dims.stop())

        self.setAcceptDrops(True)

    def _constrain_width(self, event):
        """Cap the layer-controls width; allow it wider only when floated."""
        if self.dockLayerControls.isFloating():
            self.controls.setMaximumWidth(700)
        else:
            self.controls.setMaximumWidth(220)

    def _add_layer(self, event):
        """When a layer is added, set its parent and order."""
        layers = event.source
        layer = event.item
        vispy_layer = create_vispy_visual(layer)
        vispy_layer.camera = self.view.camera
        vispy_layer.node.parent = self.view.scene
        vispy_layer.order = len(layers)
        self.layer_to_visual[layer] = vispy_layer

    def _remove_layer(self, event):
        """When a layer is removed, remove its parent."""
        layer = event.item
        vispy_layer = self.layer_to_visual[layer]
        vispy_layer.node.transforms = ChainTransform()
        vispy_layer.node.parent = None
        del self.layer_to_visual[layer]

    def _reorder_layers(self, event):
        """When the list is reordered, propagate changes to draw order."""
        for i, layer in enumerate(self.viewer.layers):
            vispy_layer = self.layer_to_visual[layer]
            vispy_layer.order = i
        # NOTE(review): private vispy attribute; presumably clearing it makes
        # the canvas recompute its draw order on the next draw - confirm.
        self.canvas._draw_order.clear()
        self.canvas.update()

    def _update_camera(self):
        """Install an Arcball (3D) or PanZoom (2D) camera to match
        ``viewer.dims.ndisplay``, resetting the view when the camera type
        actually changes."""
        if self.viewer.dims.ndisplay == 3:
            # Set a 3D camera
            if not isinstance(self.view.camera, ArcballCamera):
                self.view.camera = ArcballCamera(name="ArcballCamera", fov=0)
                # flip y-axis to have correct alignment
                # self.view.camera.flip = (0, 1, 0)

                self.view.camera.viewbox_key_event = viewbox_key_event
                self.viewer.reset_view()
        else:
            # Set 2D camera
            if not isinstance(self.view.camera, PanZoomCamera):
                self.view.camera = PanZoomCamera(aspect=1,
                                                 name="PanZoomCamera")
                # flip y-axis to have correct alignment
                self.view.camera.flip = (0, 1, 0)

                self.view.camera.viewbox_key_event = viewbox_key_event
                self.viewer.reset_view()

    def screenshot(self):
        """Take currently displayed screen and convert to an image array.

        Returns
        -------
        image : array
            Numpy array of type ubyte and shape (h, w, 4). Index [0, 0] is the
            upper-left corner of the rendered region.
        """
        img = self.canvas.native.grabFramebuffer()
        return QImg2array(img)

    def _open_images(self):
        """Add image files from the menubar."""
        filenames, _ = QFileDialog.getOpenFileNames(
            parent=self,
            caption='Select image(s)...',
            directory=self._last_visited_dir,  # home dir by default
        )
        # covers both a cancelled dialog (empty list) and None
        if filenames:
            self._add_files(filenames)

    def _open_folder(self):
        """Add a folder of files from the menubar."""
        folder = QFileDialog.getExistingDirectory(
            parent=self,
            caption='Select folder...',
            directory=self._last_visited_dir,  # home dir by default
        )
        if folder not in {'', None}:
            self._add_files([folder])

    def _add_files(self, filenames):
        """Add an image layer to the viewer.

        If multiple images are selected, they are stacked along the 0th
        axis.

        Parameters
        ----------
        filenames : list
            List of filenames to be opened
        """
        if len(filenames) > 0:
            self.viewer.add_image(path=filenames)
            self._last_visited_dir = os.path.dirname(filenames[0])

    def _on_interactive(self, event):
        """Mirror ``viewer.interactive`` onto the vispy view."""
        self.view.interactive = self.viewer.interactive

    def _on_cursor(self, event):
        """Apply ``viewer.cursor`` to the canvas.

        The 'square' cursor is a pixmap scaled to ``viewer.cursor_size``,
        falling back to a cross outside the 10-300 size range.
        """
        cursor = self.viewer.cursor
        size = self.viewer.cursor_size
        if cursor == 'square':
            if size < 10 or size > 300:
                q_cursor = self._cursors['cross']
            else:
                q_cursor = QCursor(
                    QPixmap(':/icons/cursor/cursor_square.png').scaledToHeight(
                        size))
        else:
            q_cursor = self._cursors[cursor]
        self.canvas.native.setCursor(q_cursor)

    def _on_reset_view(self, event):
        """Apply the view parameters carried by a reset_view event."""
        if isinstance(self.view.camera, ArcballCamera):
            quat = self.view.camera._quaternion.create_from_axis_angle(
                *event.quaternion)
            self.view.camera._quaternion = quat
            self.view.camera.center = event.center
            self.view.camera.scale_factor = event.scale_factor
        else:
            # Assumes default camera has the same properties as PanZoomCamera
            self.view.camera.rect = event.rect

    def _update_palette(self, palette):
        """Theme the widget, console and canvas with ``palette``."""
        # template and apply the primary stylesheet
        themed_stylesheet = template(self.raw_stylesheet, **palette)
        self.console.style_sheet = themed_stylesheet
        self.console.syntax_style = palette['syntax_style']
        bracket_color = QtGui.QColor(*str_to_rgb(palette['highlight']))
        self.console._bracket_matcher.format.setBackground(bracket_color)
        self.setStyleSheet(themed_stylesheet)
        self.aboutKeybindings.setStyleSheet(themed_stylesheet)
        self.canvas.bgcolor = palette['canvas']

    def toggle_console(self):
        """Toggle console visible and not visible."""
        viz = not self.dockConsole.isVisible()
        # modulate visibility at the dock widget level as console is dockable
        self.dockConsole.setVisible(viz)
        # NOTE(review): re-asserts the floating state it just checked; looks
        # like a no-op unless Qt re-shows the window as a side effect - confirm.
        if self.dockConsole.isFloating():
            self.dockConsole.setFloating(True)

        # refresh the button's 'expanded' property so the stylesheet reapplies
        self.viewerButtons.consoleButton.setProperty(
            'expanded', self.dockConsole.isVisible())
        self.viewerButtons.consoleButton.style().unpolish(
            self.viewerButtons.consoleButton)
        self.viewerButtons.consoleButton.style().polish(
            self.viewerButtons.consoleButton)

    def on_mouse_press(self, event):
        """Called whenever mouse pressed in canvas.
        """
        if event.pos is None:
            return

        event = ReadOnlyWrapper(event)
        mouse_press_callbacks(self.viewer, event)

        layer = self.viewer.active_layer
        if layer is not None:
            # Line below needed until layer mouse callbacks are refactored
            self.layer_to_visual[layer].on_mouse_press(event)
            mouse_press_callbacks(layer, event)

    def on_mouse_move(self, event):
        """Called whenever mouse moves over canvas.
        """
        if event.pos is None:
            return

        mouse_move_callbacks(self.viewer, event)

        layer = self.viewer.active_layer
        if layer is not None:
            # Line below needed until layer mouse callbacks are refactored
            self.layer_to_visual[layer].on_mouse_move(event)
            mouse_move_callbacks(layer, event)

    def on_mouse_release(self, event):
        """Called whenever mouse released in canvas.
        """
        mouse_release_callbacks(self.viewer, event)

        layer = self.viewer.active_layer
        if layer is not None:
            # Line below needed until layer mouse callbacks are refactored
            self.layer_to_visual[layer].on_mouse_release(event)
            mouse_release_callbacks(layer, event)

    def on_key_press(self, event):
        """Called whenever key pressed in canvas.

        Looks the key combination up in the active layer's keymap first,
        then in the viewer's. If the bound function is a generator function,
        it is advanced to its first ``yield`` now and resumed by
        ``on_key_release``.
        """
        # Check for a missing key first: the auto-repeat test below reads
        # ``event.key.name`` and would fail on ``None``.
        if event.key is None:
            return
        if (event.native is not None and event.native.isAutoRepeat()
                and event.key.name not in ['Up', 'Down', 'Left', 'Right']):
            # ignore held-down keys, except the navigation keys
            return

        comb = components_to_key_combo(event.key.name, event.modifiers)

        layer = self.viewer.active_layer

        if layer is not None and comb in layer.keymap:
            parent = layer
        elif comb in self.viewer.keymap:
            parent = self.viewer
        else:
            return

        func = parent.keymap[comb]
        gen = func(parent)

        if inspect.isgeneratorfunction(func):
            try:
                next(gen)
            except StopIteration:  # only one statement
                pass
            else:
                self._key_release_generators[event.key] = gen

    def on_key_release(self, event):
        """Called whenever key released in canvas.

        Resumes (once) the generator stored for this key by
        ``on_key_press``. The generator is removed so that a later release
        of the same key cannot advance a stale generator.
        """
        gen = self._key_release_generators.pop(event.key, None)
        if gen is None:
            return
        try:
            next(gen)
        except StopIteration:
            pass

    def on_draw(self, event):
        """Called whenever drawn in canvas. Called for all layers, not just top
        """
        for visual in self.layer_to_visual.values():
            visual.on_draw(event)

    def keyPressEvent(self, event):
        """Forward Qt key presses on this widget to the canvas key_press event."""
        self.canvas._backend._keyEvent(self.canvas.events.key_press, event)
        event.accept()

    def keyReleaseEvent(self, event):
        """Forward Qt key releases on this widget to the canvas key_release event."""
        self.canvas._backend._keyEvent(self.canvas.events.key_release, event)
        event.accept()

    def dragEnterEvent(self, event):
        """Accept drags that carry URLs; ignore everything else."""
        if event.mimeData().hasUrls():
            event.accept()
        else:
            event.ignore()

    def dropEvent(self, event):
        """Add local files and web URLS with drag and drop."""
        filenames = []
        for url in event.mimeData().urls():
            if url.isLocalFile():
                filenames.append(url.toLocalFile())
            else:
                filenames.append(url.toString())
        self._add_files(filenames)

    def closeEvent(self, event):
        """Clear the thread pool if it has active threads, then accept."""
        if self.pool.activeThreadCount() > 0:
            self.pool.clear()
        event.accept()

    def shutdown(self):
        """Clear the thread pool, close the canvas and shut down the console."""
        self.pool.clear()
        self.canvas.close()
        self.console.shutdown()
Beispiel #5
0
class LaserScanVis:
    """Class that creates and handles a visualizer for a pointcloud"""
    def __init__(self,
                 scan,
                 scan_names,
                 label_names,
                 offset=0,
                 semantics=True,
                 bboxes_names=None,
                 use_bbox_measurements=False,
                 bboxes_labels_names=None,
                 roi_filter=False,
                 instances=False):
        """Store the configuration, build the canvases and show frame ``offset``.

        Parameters
        ----------
        scan : object
            Scan loader; this class calls its ``open_scan``, ``open_label``,
            ``colorize``, ``open_bbox`` and ``open_bbox_labels`` methods and
            reads the resulting point/colour/projection attributes.
        scan_names, label_names : sequence
            Per-frame scan and label file names, indexed by ``offset``.
        offset : int
            Index of the first frame to show.
        semantics, instances : bool
            Whether to add semantic / instance views (instances require
            semantics).
        bboxes_names, bboxes_labels_names : sequence or None
            Optional per-frame bounding-box and box-label files.
        use_bbox_measurements : bool
            Forwarded to ``scan.open_bbox``.
        roi_filter : bool
            Grey out semantic points outside a fixed region of interest.

        Raises
        ------
        ValueError
            If ``instances`` is requested without ``semantics``.
        """
        self.scan = scan
        self.scan_names = scan_names
        self.label_names = label_names
        self.offset = offset
        self.semantics = semantics
        self.bboxes_names = bboxes_names
        self.use_bbox_measurements = use_bbox_measurements
        self.bboxes_labels_names = bboxes_labels_names
        self.roi_filter = roi_filter
        self.instances = instances
        # sanity check
        if not self.semantics and self.instances:
            print("Instances are only allowed in when semantics=True")
            raise ValueError("Instances are only allowed when semantics=True")

        # Box / Text visuals of the frame currently on screen. They are kept
        # per instance (not as module globals) so navigation works even when
        # no bounding boxes were ever drawn.
        self.bbox_visuals = []
        self.bbox_label_visuals = []

        self.reset()
        self.update_scan()

    def reset(self):
        """ Reset. """
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        # new canvas prepared for visualizing data
        self.canvas = SceneCanvas(keys='interactive', show=True)
        # interface (n next, b back, q quit, very simple)
        self.canvas.events.key_press.connect(self.key_press)
        self.canvas.events.draw.connect(self.draw)
        # grid
        self.grid = self.canvas.central_widget.add_grid()

        # laserscan part
        self.scan_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.grid.add_widget(self.scan_view, 0, 0)
        self.scan_vis = visuals.Markers()
        self.scan_view.camera = 'turntable'
        self.scan_view.add(self.scan_vis)
        visuals.XYZAxis(parent=self.scan_view.scene)
        # add semantics
        if self.semantics:
            print("Using semantics in visualizer")
            self.sem_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.canvas.scene)
            self.grid.add_widget(self.sem_view, 0, 1)
            self.sem_vis = visuals.Markers()
            self.sem_view.camera = 'turntable'

            self.sem_view.add(self.sem_vis)
            visuals.XYZAxis(parent=self.sem_view.scene)
            self.sem_view.camera.link(self.scan_view.camera)

        if self.instances:
            print("Using instances in visualizer")
            self.inst_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.canvas.scene)
            self.grid.add_widget(self.inst_view, 0, 2)
            self.inst_vis = visuals.Markers()
            self.inst_view.camera = 'turntable'
            self.inst_view.add(self.inst_vis)
            visuals.XYZAxis(parent=self.inst_view.scene)
            # self.inst_view.camera.link(self.scan_view.camera)

        # img canvas size
        self.multiplier = 1
        self.canvas_W = 1024
        self.canvas_H = 64
        if self.semantics:
            self.multiplier += 1
        if self.instances:
            self.multiplier += 1

        # new canvas for img
        self.img_canvas = SceneCanvas(keys='interactive',
                                      show=True,
                                      size=(self.canvas_W,
                                            self.canvas_H * self.multiplier))
        # grid
        self.img_grid = self.img_canvas.central_widget.add_grid()
        # interface (n next, b back, q quit, very simple)
        self.img_canvas.events.key_press.connect(self.key_press)
        self.img_canvas.events.draw.connect(self.draw)

        # add a view for the depth
        self.img_view = vispy.scene.widgets.ViewBox(
            border_color='white', parent=self.img_canvas.scene)
        self.img_grid.add_widget(self.img_view, 0, 0)
        self.img_vis = visuals.Image(cmap='viridis')
        self.img_view.add(self.img_vis)

        # add semantics
        if self.semantics:
            self.sem_img_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.img_canvas.scene)
            self.img_grid.add_widget(self.sem_img_view, 1, 0)
            self.sem_img_vis = visuals.Image(cmap='viridis')
            self.sem_img_view.add(self.sem_img_vis)

        # add instances
        if self.instances:
            self.inst_img_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.img_canvas.scene)
            self.img_grid.add_widget(self.inst_img_view, 2, 0)
            self.inst_img_vis = visuals.Image(cmap='viridis')
            self.inst_img_view.add(self.inst_img_vis)

    def roi_filter_(self, pointcloud, colors, x_roi, y_roi, z_roi):
        """Append every scan point to ``pointcloud`` and a colour to ``colors``.

        Points strictly inside the axis-aligned box given by ``x_roi``,
        ``y_roi`` and ``z_roi`` (each a ``(min, max)`` pair) keep their
        semantic colour; every other point gets grey ``(0.5, 0.5, 0.5)``.
        Points and colours are paired as by ``zip`` (shorter one wins).
        """
        min_x, max_x = x_roi
        min_y, max_y = y_roi
        min_z, max_z = z_roi
        grey = (0.5, 0.5, 0.5)

        for point, sem_color in zip(self.scan.points,
                                    self.scan.sem_label_color):
            inside = (min_x < point[0] < max_x and min_y < point[1] < max_y
                      and min_z < point[2] < max_z)
            pointcloud.append(np.array(point))
            colors.append(np.array(sem_color if inside else grey))

    def get_mpl_colormap(self, cmap_name):
        """Return matplotlib colormap ``cmap_name`` as a (256, 3) float32
        table in [0, 1], with channels reversed (BGR order)."""
        cmap = plt.get_cmap(cmap_name)

        # Initialize the matplotlib color map
        sm = plt.cm.ScalarMappable(cmap=cmap)

        # Obtain linear color range
        color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]

        return color_range.reshape(256, 3).astype(np.float32) / 255.0

    def _clear_bbox_visuals(self):
        """Detach the current frame's box and label visuals from the scene."""
        for visual in self.bbox_visuals + self.bbox_label_visuals:
            visual.parent = None
        self.bbox_visuals = []
        self.bbox_label_visuals = []

    def update_scan(self):
        """Load frame ``self.offset`` and refresh every view.

        Opens the scan (plus labels, bounding boxes and box labels when
        configured), updates the window titles, then redraws the range-colored
        cloud, the semantic / instance clouds, the bounding boxes and the
        projected range / semantic / instance images.
        """
        # first open data
        self.scan.open_scan(self.scan_names[self.offset])
        if self.semantics:
            self.scan.open_label(self.label_names[self.offset])
            self.scan.colorize()
        if self.bboxes_names:
            self.scan.open_bbox(self.bboxes_names[self.offset],
                                self.use_bbox_measurements)
        if self.bboxes_labels_names:
            self.scan.open_bbox_labels(self.bboxes_labels_names[self.offset])
        # then change names
        title = "scan " + str(self.offset) + " of " + str(len(self.scan_names))
        self.canvas.title = title
        self.img_canvas.title = title

        # plot scan, colored by range; the 1/power root compresses the range
        # values before they are mapped onto the 256-entry colormap
        power = 16
        range_data = np.copy(self.scan.unproj_range)
        range_data = range_data**(1 / power)
        viridis_range = ((range_data - range_data.min()) /
                         (range_data.max() - range_data.min()) * 255).astype(
                             np.uint8)
        viridis_map = self.get_mpl_colormap("viridis")
        viridis_colors = viridis_map[viridis_range]
        self.scan_vis.set_data(self.scan.points,
                               face_color=viridis_colors[..., ::-1],
                               edge_color=viridis_colors[..., ::-1],
                               size=1)

        # plot semantics
        if self.semantics:
            if self.roi_filter:
                pointcloud, colors = [], []
                self.roi_filter_(pointcloud, colors, [0, 45], [-14, 14],
                                 [-2, 1])
                pointcloud = np.array(pointcloud)
                colors = np.array(colors)
            else:
                # pair points with colors like zip() would, without a
                # per-point Python loop
                count = min(len(self.scan.points),
                            len(self.scan.sem_label_color))
                pointcloud = np.array(self.scan.points[:count])
                colors = np.array(self.scan.sem_label_color[:count])
            self.sem_view.add(self.sem_vis)
            self.sem_vis.set_data(pointcloud,
                                  face_color=colors,
                                  edge_color=colors,
                                  size=1)

        # plot instances
        if self.instances:
            self.inst_vis.set_data(
                self.scan.points,
                face_color=self.scan.inst_label_color[..., ::-1],
                edge_color=self.scan.inst_label_color[..., ::-1],
                size=1)

        # plot bounding boxes; any visuals left from a previous frame are
        # detached first so they cannot accumulate in the scene
        self._clear_bbox_visuals()
        if self.bboxes_names and self.scan.bboxes:

            # the semantic cloud is reset (set_data called without data) so
            # that view shows the boxes
            self.sem_vis.set_data()
            self.sem_vis.update()

            color = (0, 1, 1, 0.6)
            edge_color = (0, 0.05, 1)

            # bbox layout as used here: [width, depth, height, center, angle]
            for bbox in self.scan.bboxes:
                width = bbox[0]
                depth = bbox[1]
                height = bbox[2]
                box = vispy.scene.visuals.Box(width=width,
                                              height=height,
                                              depth=depth,
                                              color=color,
                                              edge_color=edge_color,
                                              parent=self.sem_view.scene)
                box.transform = vispy.visuals.transforms.MatrixTransform()
                box.transform.rotate(-bbox[4], (0, 0, 1))
                box.transform.translate(bbox[3])
                self.bbox_visuals.append(box)

            if self.bboxes_labels_names:
                for i, text in enumerate(self.scan.bbox_labels):
                    center = self.scan.bboxes[i][3]
                    label = vispy.scene.visuals.Text(text=text,
                                                     parent=self.sem_view.scene,
                                                     color="red",
                                                     bold=True)
                    # place the label one unit above the box center
                    label.pos = center[0], center[1], center[2] + 1
                    label.font_size = 600
                    self.bbox_label_visuals.append(label)

        # now do all the range image stuff
        # plot range image: same root compression, invalid (<0) pixels are
        # set to the smallest valid value, then normalized
        data = np.copy(self.scan.proj_range)
        data[data > 0] = data[data > 0]**(1 / power)
        data[data < 0] = data[data > 0].min()
        data = (data - data[data > 0].min()) / \
            (data.max() - data[data > 0].min())
        self.img_vis.set_data(data)
        self.img_vis.update()

        if self.semantics:
            self.sem_img_vis.set_data(self.scan.proj_sem_color[..., ::-1])
            self.sem_img_vis.update()

        if self.instances:
            self.inst_img_vis.set_data(self.scan.proj_inst_color[..., ::-1])
            self.inst_img_vis.update()

    # interface
    def key_press(self, event):
        """Handle N (next), B (back) and Q/Escape (quit).

        Key handling is blocked until the next draw (see ``draw``). N and B
        wrap around at either end of ``scan_names`` instead of running off
        the list.
        """
        self.canvas.events.key_press.block()
        self.img_canvas.events.key_press.block()
        if event.key == 'N':
            self._clear_bbox_visuals()
            self.offset = (self.offset + 1) % len(self.scan_names)
            self.update_scan()

        elif event.key == 'B':
            self._clear_bbox_visuals()
            self.offset = (self.offset - 1) % len(self.scan_names)
            self.update_scan()
        elif event.key == 'Q' or event.key == 'Escape':
            self.destroy()

    def draw(self, event):
        """Re-enable key handling on both canvases after a draw."""
        if self.canvas.events.key_press.blocked():
            self.canvas.events.key_press.unblock()
        if self.img_canvas.events.key_press.blocked():
            self.img_canvas.events.key_press.unblock()

    def destroy(self):
        """Close both canvases and quit the vispy application."""
        # destroy the visualization
        self.canvas.close()
        self.img_canvas.close()
        vispy.app.quit()

    def run(self):
        """Start the vispy application event loop."""
        vispy.app.run()
Beispiel #6
0
class ScannetVis(QWidget):
    """Qt widget that visualizes an RGB-D sequence and its 3D reconstruction.

    Three vispy ``SceneCanvas`` objects are embedded in the widget:

    * ``img_canvas`` -- the masked RGB frame (``scan.masked_img``) above the
      depth frame (``scan.depth_im``);
    * ``canvas`` -- the reconstructed mesh, detected-object centers, camera
      frustum and text labels;
    * ``scene_graph_canvas`` -- the scene-graph PNG written for the current
      frame, when that file exists.

    A side panel searches detected objects by class name and toggles their
    3D bounding boxes. Keys handled by ``key_press``: N next frame, S start
    sequential processing, P pause, U force a redraw, Q/Escape quit.
    """

    def __init__(self,
                 scan,
                 rgb_names,
                 depth_names,
                 pose_names,
                 offset=0,
                 skip_im=10,
                 mesh_plot=True,
                 parent=None):
        """
        Args:
            scan: data/reconstruction object; this class uses its
                ``open_scan``, ``masked_img``, ``depth_im``,
                ``num_dets_to_consider`` and ``tsdf_vol`` members.
            rgb_names, depth_names, pose_names: per-frame file lists indexed
                by ``offset``.
            offset: index of the first frame to show.
            skip_im: the 3D reconstruction runs only on frames whose index is
                a multiple of this value.
            mesh_plot: stored as ``self.mesh_plot``; not read in this class.
            parent: optional Qt parent widget.
        """
        super(ScannetVis, self).__init__(parent=parent)

        self.scan = scan
        self.rgb_names = rgb_names
        self.depth_names = depth_names
        self.pose_names = pose_names
        self.offset = offset
        self.offset_prev = offset
        self.skip_im = skip_im
        self.mesh_plot = mesh_plot

        self.keyboard_inputs = None
        self.total = len(self.rgb_names)

        # Sequential-processing state. It must exist before any timer or
        # update_seq_scan() reads it; previously it only appeared on key 'S'.
        self.start = False
        self.timer1 = None
        self.timer2 = None

        # checkBox_list entries are [QCheckBox, class_name, node_id];
        # checkBox_with_3D[i] is the Line visual for checkBox_list[i].
        self.checkBox_list = []
        self.checkBox_with_3D = []

        self.reset()
        self.initUI()
        self.update_scan()

    def initUI(self):
        """Lay out the canvases, help text and search panel in this widget."""
        self.setStyleSheet("background-color: white;")
        self.principalLayout = QHBoxLayout(self)

        # left-left frame: RGB (with detection masks) and depth images
        self.left2Frame = QFrame(self)
        self.left2Frame.setFrameShape(QFrame.StyledPanel)
        self.left2Frame.setFrameShadow(QFrame.Raised)
        self.vertical2Layout = QVBoxLayout(self.left2Frame)
        self.principalLayout.addWidget(self.left2Frame)

        self.img_canvas.create_native()
        self.img_canvas.native.setMinimumSize(320, 480)
        self.vertical2Layout.addWidget(self.img_canvas.native)

        # left frame: reconstructed 3D scene
        self.leftFrame = QFrame(self)
        self.leftFrame.setFrameShape(QFrame.StyledPanel)
        self.leftFrame.setFrameShadow(QFrame.Raised)
        self.verticalLayout = QVBoxLayout(self.leftFrame)
        self.principalLayout.addWidget(self.leftFrame)

        self.canvas.create_native()
        self.canvas.native.setMinimumSize(640, 480)
        self.verticalLayout.addWidget(self.canvas.native)

        # left-center frame intended for the 3D scene graph.
        # NOTE(review): SGFrame gets no layout or children; the scene-graph
        # canvas below is added to leftFrame's layout (self.verticalLayout),
        # so SGFrame is an empty panel. Confirm whether that is intended.
        self.SGFrame = QFrame(self)
        self.SGFrame.setFrameShape(QFrame.StyledPanel)
        self.SGFrame.setFrameShadow(QFrame.Raised)
        self.principalLayout.addWidget(self.SGFrame)

        self.scene_graph_canvas.create_native()
        self.scene_graph_canvas.native.setMinimumSize(640, 480)
        self.verticalLayout.addWidget(self.scene_graph_canvas.native)

        # center frame: control panel. The help text and search controls
        # are placed in vertical2Layout (left-left frame), so keyFrame's own
        # layout stays empty.
        self.keyFrame = QFrame(self)
        self.keyFrame.setFrameShape(QFrame.StyledPanel)
        self.keyFrame.setFrameShadow(QFrame.Raised)
        self.keysverticalLayout = QVBoxLayout(self.keyFrame)

        self.label1 = QLabel(
            "To navigate: "
            "\n   n: next (next scan) "
            "\n   s: start (start processing sequential rgb-d images)"
            "\n   p: pause (pause processing)"
            "\n   q: quit (exit program)"
            "\n\n To control 3D view: "
            "\n   LMB: orbits the view around its center point"
            "\n   RMB or scroll: change scale_factor (i.e. zoom level)"
            "\n   SHIFT + LMB: translate the center point"
            "\n   SHIFT + RMB: change FOV")
        self.label2 = QLabel("To find specific objects in 3D Space : ")
        self.vertical2Layout.addWidget(self.label1)
        self.vertical2Layout.addWidget(self.label2)

        self.le = QLineEdit(self)
        self.vertical2Layout.addWidget(self.le)

        self.spb = QPushButton('search', self)
        self.vertical2Layout.addWidget(self.spb)
        self.spb.clicked.connect(self.search_button_click)

        self.cpb = QPushButton('clear', self)
        self.vertical2Layout.addWidget(self.cpb)
        self.cpb.clicked.connect(self.clear_button_click)

        self.verticalLayoutR = QVBoxLayout()
        self.verticalLayoutR.addWidget(self.keyFrame)
        self.verticalLayoutR.setContentsMargins(0, 0, 0, 0)
        self.verticalLayoutR.setSpacing(0)
        self.principalLayout.addLayout(self.verticalLayoutR)

        # right frame: thumbnails and check boxes of searched objects
        self.rightFrame = QFrame(self)
        self.rightFrame.setFrameShape(QFrame.StyledPanel)
        self.rightFrame.setFrameShadow(QFrame.Raised)
        self.verticalLayoutRight = QVBoxLayout(self.rightFrame)
        self.verticalLayoutRight.setContentsMargins(0, 0, 0, 0)
        self.verticalLayoutRight.setSpacing(0)
        self.principalLayout.addWidget(self.rightFrame)

        self.setLayout(self.principalLayout)
        self.setWindowTitle('Searching objects')
        self.setGeometry(300, 300, 300, 200)
        self.show()

    def reset(self):
        """Create the three vispy canvases and every visual drawn on them."""
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        # --- 3D point cloud / mesh canvas ---
        self.canvas = SceneCanvas(keys='interactive', show=True)
        # keys: n next, s start, p pause, u redraw, q quit
        self.canvas.events.key_press.connect(self.key_press)
        self.canvas.events.draw.connect(self.draw)
        self.grid = self.canvas.central_widget.add_grid()

        self.scan_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.grid.add_widget(self.scan_view, 0, 0)

        # Kept as an attribute; the view below uses the 'arcball' camera.
        self.scene_cam = vispy.scene.cameras.BaseCamera()

        self.scan_vis = visuals.Mesh()
        self.scan_vis_mean = visuals.Line()
        self.scan_vis_cam = visuals.Line()
        self.scan_bbox_3d = visuals.Line()
        self.label_vis = visuals.Text()

        self.scan_view.add(self.scan_vis)
        self.scan_view.add(self.scan_vis_mean)
        self.scan_view.add(self.scan_vis_cam)
        self.scan_view.add(self.scan_bbox_3d)
        self.scan_view.add(self.label_vis)

        self.scan_view.camera = 'arcball'
        self.tr = self.scan_vis.transforms.get_transform(map_from='visual',
                                                         map_to='canvas')

        visuals.XYZAxis(parent=self.scan_view.scene)

        # --- 2D canvas: RGB (row 0) above depth (row 1) ---
        self.canvas_W = 320
        self.canvas_H = 280
        self.multiplier = 2
        self.img_canvas = SceneCanvas(keys='interactive',
                                      show=True,
                                      size=(self.canvas_W,
                                            self.canvas_H * self.multiplier))
        self.img_grid = self.img_canvas.central_widget.add_grid()
        self.img_canvas.events.key_press.connect(self.key_press)
        self.img_canvas.events.draw.connect(self.draw)

        self.rgb_img_raw_view = vispy.scene.widgets.ViewBox(
            border_color='white', parent=self.img_canvas.scene)
        self.img_grid.add_widget(self.rgb_img_raw_view, 0, 0)
        self.rgb_img_raw_vis = visuals.Image(cmap='viridis')
        self.rgb_img_raw_view.add(self.rgb_img_raw_vis)

        self.depth_img_view = vispy.scene.widgets.ViewBox(
            border_color='white', parent=self.img_canvas.scene)
        self.img_grid.add_widget(self.depth_img_view, 1, 0)
        self.depth_img_vis = visuals.Image(cmap='viridis')
        self.depth_img_view.add(self.depth_img_vis)

        # --- 2D canvas for the 3D scene-graph image ---
        self.scene_graph_canvas = SceneCanvas(keys='interactive',
                                              show=True,
                                              size=(640, 480))
        self.scene_graph_grid = self.scene_graph_canvas.central_widget.add_grid(
        )
        self.scene_graph_canvas.events.key_press.connect(self.key_press)
        self.scene_graph_canvas.events.draw.connect(self.draw)

        self.scene_graph_view = vispy.scene.widgets.ViewBox(
            border_color='white', parent=self.scene_graph_canvas.scene)
        self.scene_graph_grid.add_widget(self.scene_graph_view, 0, 0)
        self.scene_graph_vis = visuals.Image(cmap='viridis')
        self.scene_graph_view.add(self.scene_graph_vis)

    def get_mpl_colormap(self, cmap_name):
        """Return matplotlib colormap ``cmap_name`` as a lookup table.

        Returns:
            float32 array of shape (256, 3) with values in [0, 1] and the
            channels in BGR order (RGBA bytes, columns 2, 1, 0).
        """
        cmap = plt.get_cmap(cmap_name)

        # Initialize the matplotlib color map
        sm = plt.cm.ScalarMappable(cmap=cmap)

        # Obtain linear color range
        color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]

        return color_range.reshape(256, 3).astype(np.float32) / 255.0

    def update_yolact(self):
        """Load frame ``self.offset`` and show its masked RGB and depth images.

        Both images are resized to 320x240. The RGB image gets a
        "Frame <n>" caption on a black box in its top-left corner.
        """
        title = "scan " + str(self.offset)
        self.img_canvas.title = title

        # recon=False: load the frame (fills scan.masked_img / depth_im)
        # without integrating it into the 3D volume.
        _, _, _, _ = self.scan.open_scan(self.rgb_names[self.offset],
                                         self.depth_names[self.offset],
                                         self.pose_names[self.offset],
                                         self.offset,
                                         recon=False)

        text_str = 'Frame %d ' % (self.offset)
        font_face = cv2.FONT_HERSHEY_DUPLEX
        font_scale = 0.6
        font_thickness = 1
        text_w, text_h = cv2.getTextSize(text_str, font_face, font_scale,
                                         font_thickness)[0]
        masked_img = self.scan.masked_img.copy()
        masked_img = cv2.resize(masked_img, (320, 240),
                                interpolation=cv2.INTER_AREA)

        x1, y1 = 0, 0
        text_pt = (x1, y1 + 15)
        text_color = [255, 255, 255]
        color = [0, 0, 0]
        cv2.rectangle(masked_img, (x1, y1), (x1 + text_w, y1 + text_h + 4),
                      color, -1)
        cv2.putText(masked_img, text_str, text_pt, font_face, font_scale,
                    text_color, font_thickness, cv2.LINE_AA)

        self.rgb_img_raw_vis.set_data(masked_img)
        self.rgb_img_raw_vis.update()

        depth_img = cv2.resize(self.scan.depth_im.copy(), (320, 240),
                               interpolation=cv2.INTER_AREA)
        self.depth_img_vis.set_data(depth_img)
        self.depth_img_vis.update()

    def update_3d_recon(self):
        """Integrate frame ``self.offset`` into the 3D scene and redraw it.

        Runs only when ``self.offset`` is a multiple of ``self.skip_im``.
        It updates the mesh and the camera frustum. When there are detections
        and ``tsdf_vol.debug_same_node_detector`` is set, it also updates the
        object-center polyline and labels. It shows the scene-graph PNG for
        this frame if one has been written, and prints timing information.
        """
        if self.offset % self.skip_im != 0:
            return

        title = "scan " + str(self.offset)
        start_time = time.time()
        verts, faces, norms, colors = self.scan.open_scan(
            self.rgb_names[self.offset],
            self.depth_names[self.offset],
            self.pose_names[self.offset],
            self.offset,
            recon=True)
        self.verts, self.faces, self.norms, self.colors = verts, faces, norms, colors

        self.canvas.title = title
        self.scan_vis.set_data(vertices=verts,
                               faces=faces,
                               vertex_colors=colors / 255.)
        self.scan_vis.update()

        if self.scan.num_dets_to_consider > 0 and self.scan.tsdf_vol.debug_same_node_detector:
            self.mean_pose = np.array(self.scan.tsdf_vol.mask_centers)
            self.scan_vis_mean.set_data(self.mean_pose,
                                        color='red',
                                        width=3,
                                        connect='strip')
            self.scan_vis_mean.update()

            # label each detected object at its center
            self.label_vis.text = self.scan.tsdf_vol.class_label
            self.label_vis.pos = self.mean_pose
            self.label_vis.font_size = int(40)

        self.cam_frustum = np.array(self.scan.tsdf_vol.cam_frustum)
        self.scan_vis_cam.set_data(self.cam_frustum,
                                   color='blue',
                                   width=3,
                                   connect=self.scan.tsdf_vol.cam_connect)
        self.scan_vis_cam.update()
        # Drop the previous camera label (assumed to be the last entry)
        # before appending the current one.
        if ('camera' in self.label_vis.text):
            self.label_vis.text.pop()
            self.label_vis.pos = self.label_vis.pos[:-1, :]
        self.label_vis.text += self.scan.tsdf_vol.cam_label
        self.label_vis.pos = np.append(self.label_vis.pos,
                                       self.scan.tsdf_vol.cam_centers,
                                       axis=0)

        # Draw the scene graph image generated for this frame, if any
        generated_scene_graph_file = os.path.join(
            self.scan.tsdf_vol.scene_graph_path,
            'scene_graph' + str(self.offset) + '.png')
        if os.path.exists(generated_scene_graph_file):
            print('Draw scene graph{}'.format(self.offset))
            sg_img = cv2.cvtColor(cv2.imread(generated_scene_graph_file),
                                  cv2.COLOR_BGR2RGB)
            self.sg_img = cv2.resize(sg_img, (640, 480),
                                     interpolation=cv2.INTER_AREA)
            self.scene_graph_vis.set_data(self.sg_img)
            self.scene_graph_vis.update()

        print("--- %s seconds of %d to %d images---" %
              (time.time() - start_time, self.offset - self.skip_im + 1,
               self.offset))
        print("--- fps : {} ---".format(self.skip_im /
                                        (time.time() - start_time)))

    def update_scan(self):
        """Refresh the 2D images, then the 3D reconstruction, for ``offset``."""
        self.update_yolact()
        self.update_3d_recon()

    def _advance_offset(self):
        """Step to the next frame during sequential playback.

        Returns:
            True if ``self.offset`` was advanced. False when the last frame
            has already been reached; playback is then paused instead of
            indexing past the end of the file lists.
        """
        if self.offset + 1 >= self.total:
            self.start = False
            print('Reached the last frame ({}); pausing.'.format(self.offset))
            return False
        self.offset += 1
        return True

    def _stop_timers(self):
        """Stop and forget the sequential-processing timers, if running."""
        for timer in (self.timer1, self.timer2):
            if timer is not None:
                timer.stop()
        self.timer1 = None
        self.timer2 = None

    def update_seq_scan(self):
        """Unblock key input and, while playing, process and draw the next frame."""
        if self.canvas.events.key_press.blocked():
            self.canvas.events.key_press.unblock()
        if self.img_canvas.events.key_press.blocked():
            self.img_canvas.events.key_press.unblock()
        if self.scene_graph_canvas.events.key_press.blocked():
            self.scene_graph_canvas.events.key_press.unblock()
        if not (self.start and self._advance_offset()):
            return

        self.update_yolact()
        self.update_3d_recon()

        self.canvas.scene.update()
        self.img_canvas.scene.update()
        self.scene_graph_canvas.update()
        self.canvas.on_draw(None)
        self.img_canvas.on_draw(None)
        self.scene_graph_canvas.on_draw(None)

    # interface
    def key_press(self, event):
        """Handle a key press from any of the three canvases.

        N: next frame (wraps to 0 after the last). S: start sequential
        processing driven by two ~30 Hz timers. P: pause. U: request a redraw
        of all canvases. Q / Escape: close everything.
        """
        self.keyboard_inputs = event.key
        if event.key == 'N':
            self.offset += 1
            if self.offset >= self.total:
                self.offset = 0
            self.update_scan()

        elif event.key == 'S':
            # Start to process RGB-D sequences. Stop timers from an earlier
            # 'S' first so repeated presses do not stack extra timers.
            self.start = True
            self._stop_timers()
            self.timer1 = vispy.app.Timer(0.033,
                                          connect=self.on_timer1,
                                          start=True)
            self.timer2 = vispy.app.Timer(0.033,
                                          connect=self.on_timer2,
                                          start=True)

        elif event.key == 'P':
            # Pause to process RGB sequences
            self.start = False

        elif event.key == 'U':
            # test when updated draw function
            self.canvas.scene.update()
            self.img_canvas.scene.update()
            self.scene_graph_canvas.update()

        elif event.key == 'Q' or event.key == 'Escape':
            self.destroy()

    def on_timer1(self, event):
        """Timer callback: while playing, advance a frame and show its images."""
        if self.start and self._advance_offset():
            self.update_yolact()

    def on_timer2(self, event):
        """Timer callback: while playing, update the 3D reconstruction."""
        if self.start:
            self.update_3d_recon()

    def search_button_click(self):
        """List detected objects whose class equals the search text.

        For each match, the right panel gets the object's thumbnail
        (``thumbnail_<id>_<detection_cnt / 2>.png`` in ``tsdf_vol.bbox_path``)
        and a check box. Each check box also gets an empty Line visual in the
        3D view, which ``checkBoxState`` fills with the bounding box.
        """
        query = self.le.text()
        print('searching object : {}'.format(query))
        objects_dict = self.scan.tsdf_vol.node_data

        found_any = False

        self.clear_searched_items(self.verticalLayoutRight)

        for key, val in objects_dict.items():
            if val['class'] != query:
                continue
            print('find {}'.format(query))

            # use the entry itself: objects_dict[str(key)] breaks for
            # non-string keys
            thumbnail_path = os.path.join(
                self.scan.tsdf_vol.bbox_path,
                'thumbnail_' + str(key) + '_' +
                str(int(val['detection_cnt'] / 2)) + '.png')
            bgr_img = cv2.imread(thumbnail_path)
            if bgr_img is None:
                # cv2.imread returns None for a missing/unreadable file; keep
                # the check box but skip the thumbnail instead of crashing.
                print('thumbnail not found: {}'.format(thumbnail_path))
            else:
                cv2_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
                image = QImage(cv2_img.data, cv2_img.shape[1],
                               cv2_img.shape[0], cv2_img.strides[0],
                               QImage.Format_RGB888)
                image_frame = QLabel()
                image_frame.setPixmap(QPixmap.fromImage(image))
                self.verticalLayoutRight.addWidget(image_frame)

            checkBox = QCheckBox(val['class'] + str(key))
            self.checkBox_list.append([checkBox, val['class'], str(key)])

            scan_bbox_3d = visuals.Line()
            self.checkBox_with_3D.append(scan_bbox_3d)
            self.scan_view.add(scan_bbox_3d)

            checkBox.stateChanged.connect(self.checkBoxState)
            self.verticalLayoutRight.addWidget(checkBox)
            found_any = True

        if not found_any:
            searched_obj = QLabel("Nothing was found!")
        else:
            searched_obj = QLabel(
                "Check box if you want to find objects in 3D Scene.")
        self.verticalLayoutRight.addWidget(searched_obj)

    def clear_button_click(self):
        """Remove the previous search results."""
        print('clear previous searched object')
        self.clear_searched_items(self.verticalLayoutRight)

    def clear_searched_items(self, layout):
        """Delete every widget in ``layout`` and detach the search bbox lines."""
        # reset searching results widget
        while layout.count() > 0:
            item = layout.takeAt(0)
            if not item:
                continue

            w = item.widget()
            if w:
                w.deleteLater()

        # Detach the 3D bbox lines of the searched objects. Previously an
        # empty replacement Line was attached to scan_view for each one and
        # then dropped from the list, leaking orphan visuals into the scene.
        for line in self.checkBox_with_3D:
            line.parent = None

        self.checkBox_list = []
        self.checkBox_with_3D = []

    def checkBoxState(self):
        """Draw the 3D bbox of every checked search result; clear unchecked ones."""
        # indices into the 8 corners of tsdf_vol.bbox_3ds[node_id]
        bbox_connect = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5],
                                 [5, 6], [6, 7], [7, 4], [0, 4], [1, 5],
                                 [2, 6], [3, 7]])
        # checkBox_list is composed of [QCheckBox, class_name, class_3D_ID]
        for i, (checkbox, _class_name, node_id) in enumerate(
                self.checkBox_list):
            if checkbox.isChecked():
                print('checked!!!')
                bbox_3d = np.array(self.scan.tsdf_vol.bbox_3ds[node_id])
                self.checkBox_with_3D[i].set_data(bbox_3d,
                                                  color='green',
                                                  width=3,
                                                  connect=bbox_connect)
            else:
                # swap in an empty, attached Line so the box disappears but
                # can be drawn again when re-checked
                self.checkBox_with_3D[i].parent = None
                self.checkBox_with_3D[i] = visuals.Line()
                self.scan_view.add(self.checkBox_with_3D[i])

    def draw(self, event):
        """Draw-event hook: re-enable key input and honor a pending pause."""
        if self.canvas.events.key_press.blocked():
            self.canvas.events.key_press.unblock()
        if self.img_canvas.events.key_press.blocked():
            self.img_canvas.events.key_press.unblock()
        if self.scene_graph_canvas.events.key_press.blocked():
            self.scene_graph_canvas.events.key_press.unblock()

        if self.keyboard_inputs == 'P':
            # Pause to process RGB sequences
            self.start = False

    def destroy(self):
        """Stop playback timers, close all canvases and quit the vispy app."""
        self._stop_timers()
        self.canvas.close()
        self.img_canvas.close()
        self.scene_graph_canvas.close()
        vispy.app.quit()

    def run(self):
        """Select the PyQt5 backend and enter the vispy event loop."""
        # NOTE(review): the canvases already exist when this runs, so a
        # backend has already been chosen; confirm use_app is still needed.
        vispy.app.use_app(backend_name="PyQt5", call_reuse=True)
        vispy.app.run()
# Beispiel #7
# 0
class LaserScanVis():
    """Class that creates and handles a visualizer for a pointcloud"""
    def __init__(self,
                 W,
                 H,
                 show_mesh=False,
                 show_diff=False,
                 show_range=False,
                 show_remissions=False,
                 show_target=True,
                 show_label=True):
        """Store the display options and build every canvas via ``reset``.

        Args:
            W: image widths; ``W[0]`` sizes the source canvas, ``W[1]`` the
                target and difference canvases.
            H: image heights, indexed like ``W``.
            show_mesh: add a mesh view beside the 3D scan views.
            show_diff: open a canvas with label/range/remission differences.
            show_range: add range-image rows.
            show_remissions: add remission-image rows.
            show_target: show the target scan (3D view and 2D canvas).
            show_label: add label-image rows.
        """
        # which panels get built
        self.show_label = show_label
        self.show_target = show_target
        self.show_mesh = show_mesh
        self.show_remissions = show_remissions
        self.show_diff = show_diff
        self.show_range = show_range

        # geometry and playback state
        self.W, self.H = W, H
        self.point_size = 3
        self.frame = self.nframes = 0
        self.view_mode = 'label'

        self.reset()

    def reset(self):
        """Create every canvas, view and visual used by the visualizer.

        Canvases are created in this order: 3D scan canvas, 2D source canvas,
        2D target canvas (``show_target``), difference canvas (``show_diff``).
        All of them forward key presses to ``self.key_press``. The mesh view
        (``show_mesh``) is added to the 3D canvas at grid cell (0, 2).
        """
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        self._build_scan_canvas()

        # One image row per enabled 2D layer. The source, target and
        # difference canvases stack the same layers, so they share `h`.
        # Flags are tested by truthiness everywhere; the original mixed
        # `is True` with truthiness and reserved a label row even when
        # show_label was off.
        h = max(1, sum(bool(flag) for flag in (self.show_label,
                                               self.show_range,
                                               self.show_remissions)))

        self._build_source_canvas(h)
        if self.show_target:
            self._build_target_canvas(h)
        if self.show_diff:
            self._build_diff_canvas(h)
        if self.show_mesh:
            self._build_mesh_view()

    @staticmethod
    def _add_image_view(scene, grid, row, **image_kwargs):
        """Put a bordered ViewBox holding a new Image at ``grid[row, 0]``.

        Returns:
            ``(view, image)``.
        """
        view = vispy.scene.widgets.ViewBox(border_color='white', parent=scene)
        image = visuals.Image(**image_kwargs)
        view.add(image)
        grid.add_widget(view, row, 0)
        return view, image

    def _build_scan_canvas(self):
        """Create the 3D canvas: source cloud at (0, 0), target at (0, 1)."""
        self.scan_canvas = SceneCanvas(keys='interactive',
                                       show=True,
                                       title='',
                                       size=(1600, 600),
                                       bgcolor='white')
        self.scan_canvas.events.key_press.connect(self.key_press)
        self.grid_view = self.scan_canvas.central_widget.add_grid()

        # source laserscan 3D
        self.scan_view = vispy.scene.widgets.ViewBox(
            border_color='white', parent=self.scan_canvas.scene)
        self.scan_vis = visuals.Markers()
        self.scan_view.camera = 'turntable'
        self.scan_view.add(self.scan_vis)
        visuals.XYZAxis(parent=self.scan_view.scene)
        self.grid_view.add_widget(self.scan_view, 0, 0)

        # target laserscan 3D; its camera is linked to the source view
        if self.show_target:
            self.back_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.scan_canvas.scene)
            self.back_vis = visuals.Markers()
            self.back_view.camera = 'turntable'
            self.back_view.camera.link(self.scan_view.camera)
            self.back_view.add(self.back_vis)
            visuals.XYZAxis(parent=self.back_view.scene)
            self.grid_view.add_widget(self.back_view, 0, 1)

    def _build_source_canvas(self, h):
        """Create the 2D source canvas with ``h`` stacked image rows."""
        title = 'Source ' + str(self.H[0]) + 'x' + str(self.W[0])
        self.source_canvas = SceneCanvas(keys='interactive',
                                         show=True,
                                         title=title,
                                         size=(self.W[0], h * self.H[0]))
        self.source_canvas.events.key_press.connect(self.key_press)
        self.source_view = self.source_canvas.central_widget.add_grid()
        scene, grid, row = self.source_canvas.scene, self.source_view, 0

        if self.show_label:
            self.img_view, self.img_vis = self._add_image_view(
                scene, grid, row, cmap='viridis')
            row += 1
        if self.show_range:
            self.range_view_source, self.range_image_source = \
                self._add_image_view(scene, grid, row)
            row += 1
        if self.show_remissions:
            self.remissions_view_source, self.remissions_image_source = \
                self._add_image_view(scene, grid, row)

    def _build_target_canvas(self, h):
        """Create the 2D target canvas, mirroring the source canvas rows."""
        title = 'Target ' + str(self.H[1]) + 'x' + str(self.W[1])
        self.target_canvas = SceneCanvas(keys='interactive',
                                         show=True,
                                         title=title,
                                         size=(self.W[1], h * self.H[1]))
        self.target_canvas.events.key_press.connect(self.key_press)
        self.target_view = self.target_canvas.central_widget.add_grid()
        scene, grid, row = self.target_canvas.scene, self.target_view, 0

        if self.show_label:
            self.test_view, self.test_vis = self._add_image_view(
                scene, grid, row, cmap='viridis')
            row += 1
        if self.show_range:
            self.range_view_target, self.range_image_target = \
                self._add_image_view(scene, grid, row, cmap='viridis')
            row += 1
        if self.show_remissions:
            self.remissions_view_target, self.remissions_image_target = \
                self._add_image_view(scene, grid, row, cmap='viridis')

    def _build_diff_canvas(self, h):
        """Create the canvas showing label/range/remission differences."""
        self.diff_canvas = SceneCanvas(keys='interactive',
                                       show=True,
                                       title='Difference Range Image',
                                       size=(self.W[1], self.H[1] * h))
        self.diff_canvas.events.key_press.connect(self.key_press)
        self.diff_view = self.diff_canvas.central_widget.add_grid()
        scene, grid, row = self.diff_canvas.scene, self.diff_view, 0

        if self.show_label:
            self.diff_view_label, self.diff_image_label = \
                self._add_image_view(scene, grid, row, cmap='viridis')
            row += 1
        if self.show_range:
            self.diff_view_depth, self.diff_image_depth = \
                self._add_image_view(scene, grid, row)
            row += 1
        if self.show_remissions:
            self.diff_view_remissions, self.diff_image_remissions = \
                self._add_image_view(scene, grid, row)

    def _build_mesh_view(self):
        """Add an unshaded mesh view at (0, 2), linked to the source camera."""
        self.mesh_view = vispy.scene.widgets.ViewBox(
            border_color='white', parent=self.scan_canvas.scene)
        self.mesh_vis = visuals.Mesh(shading=None)
        self.mesh_view.camera = 'turntable'
        self.mesh_view.camera.link(self.scan_view.camera)
        self.mesh_view.add(self.mesh_vis)
        visuals.XYZAxis(parent=self.mesh_view.scene)
        self.grid_view.add_widget(self.mesh_view, 0, 2)

    def set_title(self):
        self.scan_canvas.title = 'Frame %d of %d' % (self.frame + 1,
                                                     self.nframes)

    def set_source_scan(self, scan):
        """Show a single source scan in the 2D source canvas (no 3D update).

        Args:
            scan: object providing ``proj_color`` (channel order reversed
                before display), ``proj_range`` (passed through
                ``convert_range``) and ``proj_remissions``, depending on the
                enabled panels.
        """
        if self.show_label:
            self.img_vis.set_data(scan.proj_color[..., ::-1])
            self.img_vis.update()
        if self.show_range:
            self.range_image_source.set_data(convert_range(scan.proj_range))
            # request a redraw, as done for the label image and in set_data()
            self.range_image_source.update()
        if self.show_remissions:
            self.remissions_image_source.set_data(scan.proj_remissions)
            self.remissions_image_source.update()

    def set_data(self,
                 scan_source,
                 scan_target,
                 verts=None,
                 verts_colors=None,
                 faces=None,
                 W=None,
                 H=None):
        """Display a source/target scan pair (2D images, target 3D and mesh).

        Args:
            scan_source: source scan; ``proj_color``, ``proj_range`` and
                ``proj_remissions`` are read depending on the enabled panels.
            scan_target: target scan with the same attributes. It is also
                passed to ``set_target_3d`` when ``show_target`` is set.
            verts, verts_colors, faces: mesh data used when ``show_mesh`` is
                set. The channel order of ``verts_colors`` is reversed before
                display; ``None`` colors are passed through unchanged.
            W, H: unused; kept for backward compatibility.
        """
        self.set_title()

        if self.show_target:
            self.set_target_3d(scan_target)

        if self.show_label:
            self.img_vis.set_data(scan_source.proj_color[..., ::-1])
            self.img_vis.update()
            if self.show_target:
                self.test_vis.set_data(scan_target.proj_color[..., ::-1])
                self.test_vis.update()

        if self.show_range:
            self.range_image_source.set_data(scan_source.proj_range)
            self.range_image_source.update()
            if self.show_target:
                self.range_image_target.set_data(scan_target.proj_range)
                self.range_image_target.update()

        if self.show_remissions:
            # -1 marks empty pixels; zero them (on a copy) in both images so
            # source and target use the same value range.
            source_data = np.copy(scan_source.proj_remissions)
            source_data[source_data == -1] = 0
            self.remissions_image_source.set_data(source_data)
            self.remissions_image_source.update()
            if self.show_target:
                target_data = np.copy(scan_target.proj_remissions)
                target_data[target_data == -1] = 0
                self.remissions_image_target.set_data(target_data)
                self.remissions_image_target.update()

        if self.show_mesh:
            colors = None if verts_colors is None else verts_colors[..., ::-1]
            self.mesh_vis.set_data(vertices=verts,
                                   vertex_colors=colors,
                                   faces=faces)
            self.mesh_vis.update()

    def set_diff(self, label_diff, range_diff, remissions_diff, m_iou, m_acc,
                 MSE):
        """ Update the enabled difference images and the diff-canvas title.

        Args:
            label_diff: label difference image; its last axis is reversed
                before display.
            range_diff: range difference image, displayed as-is.
            remissions_diff: remission difference image, displayed as-is.
            m_iou: mean IoU as a fraction (shown multiplied by 100, with %).
            m_acc: mean accuracy as a fraction (shown multiplied by 100, with %).
            MSE: mean squared error, shown unscaled.
    """
        if self.show_label:
            self.diff_image_label.set_data(label_diff[..., ::-1])
            self.diff_image_label.update()

        if self.show_range:
            # The raw difference is what gets displayed; no convert_range()
            # pass is needed (its result used to be overwritten right away).
            self.diff_image_depth.set_data(range_diff)
            self.diff_image_depth.update()

        if self.show_remissions:
            self.diff_image_remissions.set_data(remissions_diff)
            self.diff_image_remissions.update()

        self.diff_canvas.title = \
            'IoU %5.2f%%, Acc %5.2f%%, MSE %f' % (m_iou * 100.0, m_acc * 100, MSE)

    def set_mesh(self, verts, verts_colors, faces):
        """ Replace the mesh geometry; does nothing unless show_mesh is set.

        The vertex colours are forwarded unchanged (no axis reversal).
        """
        if not self.show_mesh:
            return
        mesh = self.mesh_vis
        mesh.set_data(vertices=verts, vertex_colors=verts_colors, faces=faces)
        mesh.update()

    def set_source_3d(self, scan_source):
        """ Plot the source scan's points in the 3D scan view.

        Point colours depend on ``self.view_mode``:
          * 'label' -- ``scan_source.label_color``;
          * 'range' -- square root of ``unproj_range``, min-max normalised
            and mapped through the viridis colormap;
          * 'rem'   -- ``remissions`` clipped to [0, 1], through viridis.
        The colour arrays have their last axis reversed before plotting.

        Raises:
            ValueError: if ``view_mode`` is none of the modes above.
        """
        points = scan_source.points

        if self.view_mode == 'label':
            colors = scan_source.label_color
        else:
            if self.view_mode == 'range':
                power = 2
                # negative values would become NaN under the root and
                # poison min()/max() for the whole scan
                range_data = np.clip(scan_source.unproj_range.reshape(-1), 0,
                                     None)**(1 / power)
                if range_data.size and range_data.max() > range_data.min():
                    lo = range_data.min()
                    range_data = (range_data - lo) / (range_data.max() - lo)
                else:
                    # empty or constant range: avoid 0/0 -> NaN
                    range_data = np.zeros_like(range_data)
                viridis_range = (range_data * 255).astype(np.uint8)
            elif self.view_mode == 'rem':
                # values outside [0, 1] would wrap around in the uint8 cast
                range_data = np.clip(scan_source.remissions.reshape(-1), 0.0,
                                     1.0)
                viridis_range = (range_data * 255).astype(np.uint8)
            else:
                raise ValueError("unknown view_mode: %r" % (self.view_mode,))
            viridis_map = get_mpl_colormap("viridis")
            colors = viridis_map[viridis_range]

        self.scan_vis.set_data(points,
                               face_color=colors[..., ::-1],
                               edge_color=colors[..., ::-1],
                               size=self.point_size)
        self.scan_vis.update()

    def set_target_3d(self, scan_target):
        """ Plot the target scan's back-projected points in the 3D view.

        Point colours depend on ``self.view_mode``:
          * 'label' -- ``scan_target.label_color``;
          * 'range' -- square root of ``proj_range`` (negatives clamped to 0),
            min-max normalised and mapped through viridis;
          * 'rem'   -- ``proj_remissions`` clipped to [0, 1], through viridis.
        The colour arrays have their last axis reversed before plotting.

        Raises:
            ValueError: if ``view_mode`` is none of the modes above.
        """
        points = scan_target.back_points

        if self.view_mode == 'label':
            colors = scan_target.label_color
        else:
            if self.view_mode == 'range':
                power = 2
                # projected images can hold -1 entries (set_data zeroes them
                # for remissions); under the root they would become NaN and
                # poison min()/max() for the whole image
                range_data = np.clip(scan_target.proj_range.reshape(-1), 0,
                                     None)**(1 / power)
                if range_data.size and range_data.max() > range_data.min():
                    lo = range_data.min()
                    range_data = (range_data - lo) / (range_data.max() - lo)
                else:
                    # empty or constant range: avoid 0/0 -> NaN
                    range_data = np.zeros_like(range_data)
                viridis_range = (range_data * 255).astype(np.uint8)
            elif self.view_mode == 'rem':
                # -1 entries (or values > 1) would wrap around in the uint8 cast
                range_data = np.clip(scan_target.proj_remissions.reshape(-1),
                                     0.0, 1.0)
                viridis_range = (range_data * 255).astype(np.uint8)
            else:
                raise ValueError("unknown view_mode: %r" % (self.view_mode,))
            viridis_map = get_mpl_colormap("viridis")
            colors = viridis_map[viridis_range]

        self.back_vis.set_data(points,
                               face_color=colors[..., ::-1],
                               edge_color=colors[..., ::-1],
                               size=self.point_size)
        self.back_vis.update()

    # interface
    def key_press(self, event):
        """ Translate a key event into ``self.action`` (and ``view_mode``).

        N -> 'next', B -> 'back', Q/Escape -> destroy() then 'quit',
        1/2/3 -> 'change' with view_mode 'label'/'range'/'rem'.
        Other keys are ignored.
        """
        key = event.key
        if key == 'N' or key == 'B':
            self.action = 'next' if key == 'N' else 'back'
        elif key in ('Q', 'Escape'):
            self.destroy()
            self.action = 'quit'
        else:
            for digit, mode in (('1', 'label'), ('2', 'range'), ('3', 'rem')):
                if key == digit:
                    self.action = 'change'
                    self.view_mode = mode
                    break

    def get_action(self, timeout=0):
        """ Let the vispy app sleep ``timeout``, then pop the pending action.

        The stored action is reset to 'no' so the same key press is not
        handled twice.
        """
        vispy.app.use_app().sleep(timeout)
        pending, self.action = self.action, 'no'
        return pending

    def destroy(self):
        """ Disconnect key handlers, close the canvases and quit the app.

        The source canvas is always closed; the target and diff canvases only
        when show_target / show_diff are set.

        NOTE(review): set_title and the visible setup code use scan_canvas /
        diff_canvas; confirm that source_canvas and target_canvas are defined
        on this object too.
        """
        def _shut(canvas):
            # detach every key_press callback before closing
            canvas.events.key_press.disconnect()
            canvas.close()

        _shut(self.source_canvas)
        if self.show_target:
            _shut(self.target_canvas)
        if self.show_diff:
            _shut(self.diff_canvas)
        vispy.app.quit()
Beispiel #8
0
class LaserScanExtract:
    """Extract the points of one semantic class from a list of scans.

    On construction every scan in ``scan_names`` is opened (with its label
    file when ``semantics`` is set). Whenever more than ``min_points`` points
    carry the label ``classid``, those points and their remissions are saved
    with ``np.savetxt`` to ``cluster<N>_<k>.txt`` in the current working
    directory. N is the number of points; k counts the files written so far.

    The vispy visualizer built by :meth:`reset` is not created by default
    (the call in ``__init__`` is disabled). ``key_press``, ``draw`` and
    ``destroy`` use the canvases that :meth:`reset` creates.
    """
    def __init__(self,
                 scan,
                 scan_names,
                 label_names,
                 offset=0,
                 semantics=True,
                 instances=False,
                 classid=51,
                 min_points=50):
        """
        Args:
            scan: scan loader with open_scan/open_label and the ``points``,
                ``remissions`` and ``sem_label`` attributes used below.
            scan_names: scan file paths.
            label_names: label file paths, same order as scan_names.
            offset: current scan index (used by the key handlers).
            semantics: open label files together with the scans.
            instances: only allowed together with semantics.
            classid: semantic label whose points are extracted.
            min_points: a scan is saved only if it has strictly more points
                of ``classid`` than this (default keeps the former 50).

        Raises:
            ValueError: if ``instances`` is set without ``semantics``.
        """
        self.scan = scan
        self.scan_names = scan_names
        self.label_names = label_names
        self.offset = offset
        self.total = len(self.scan_names)
        self.semantics = semantics
        self.instances = instances
        self.classid = classid
        self.min_points = min_points
        # sanity check
        if not self.semantics and self.instances:
            print("Instances are only allowed in when semantics=True")
            raise ValueError("instances=True requires semantics=True")

        # self.reset()
        self.update_scan()

    def reset(self):
        """ Reset. """
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        # new canvas prepared for visualizing data
        self.canvas = SceneCanvas(keys='interactive', show=True)
        # interface (n next, b back, q quit, very simple)
        self.canvas.events.key_press.connect(self.key_press)
        self.canvas.events.draw.connect(self.draw)
        # grid
        self.grid = self.canvas.central_widget.add_grid()

        # laserscan part
        self.scan_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.grid.add_widget(self.scan_view, 0, 0)
        self.scan_vis = visuals.Markers()
        self.scan_view.camera = 'turntable'
        self.scan_view.add(self.scan_vis)
        visuals.XYZAxis(parent=self.scan_view.scene)
        # add semantics
        if self.semantics:
            print("Using semantics in visualizer")
            self.sem_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.canvas.scene)
            self.grid.add_widget(self.sem_view, 0, 1)
            self.sem_vis = visuals.Markers()
            self.sem_view.camera = 'turntable'
            self.sem_view.add(self.sem_vis)
            visuals.XYZAxis(parent=self.sem_view.scene)
            # self.sem_view.camera.link(self.scan_view.camera)

        if self.instances:
            print("Using instances in visualizer")
            self.inst_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.canvas.scene)
            self.grid.add_widget(self.inst_view, 0, 2)
            self.inst_vis = visuals.Markers()
            self.inst_view.camera = 'turntable'
            self.inst_view.add(self.inst_vis)
            visuals.XYZAxis(parent=self.inst_view.scene)
            # self.inst_view.camera.link(self.scan_view.camera)

        # img canvas size
        self.multiplier = 1
        self.canvas_W = 1024
        self.canvas_H = 64
        if self.semantics:
            self.multiplier += 1
        if self.instances:
            self.multiplier += 1

        # new canvas for img
        self.img_canvas = SceneCanvas(keys='interactive',
                                      show=True,
                                      size=(self.canvas_W,
                                            self.canvas_H * self.multiplier))
        # grid
        self.img_grid = self.img_canvas.central_widget.add_grid()
        # interface (n next, b back, q quit, very simple)
        self.img_canvas.events.key_press.connect(self.key_press)
        self.img_canvas.events.draw.connect(self.draw)

        # add a view for the depth
        self.img_view = vispy.scene.widgets.ViewBox(
            border_color='white', parent=self.img_canvas.scene)
        self.img_grid.add_widget(self.img_view, 0, 0)
        self.img_vis = visuals.Image(cmap='viridis')
        self.img_view.add(self.img_vis)

        # add semantics
        if self.semantics:
            self.sem_img_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.img_canvas.scene)
            self.img_grid.add_widget(self.sem_img_view, 1, 0)
            self.sem_img_vis = visuals.Image(cmap='viridis')
            self.sem_img_view.add(self.sem_img_vis)

        # add instances
        if self.instances:
            self.inst_img_view = vispy.scene.widgets.ViewBox(
                border_color='white', parent=self.img_canvas.scene)
            self.img_grid.add_widget(self.inst_img_view, 2, 0)
            self.inst_img_vis = visuals.Image(cmap='viridis')
            self.inst_img_view.add(self.inst_img_vis)

    def get_mpl_colormap(self, cmap_name):
        """ Return a matplotlib colormap as a (256, 3) float32 table in [0, 1].

        The channels are reversed: ``[:, 2::-1]`` keeps B, G, R of the RGBA
        output.
        """
        cmap = plt.get_cmap(cmap_name)

        # Initialize the matplotlib color map
        sm = plt.cm.ScalarMappable(cmap=cmap)

        # Obtain linear color range
        color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]

        return color_range.reshape(256, 3).astype(np.float32) / 255.0

    def update_scan(self):
        """ Walk all scans and save the points labelled ``self.classid``.

        Every scan is processed, regardless of ``self.offset``. Output files
        go to the current working directory.

        NOTE(review): with semantics=False no label file is opened, so
        ``sem_label`` is whatever the scan object holds after open_scan.
        """
        count = 0
        for offset, name in enumerate(self.scan_names):
            self.scan.open_scan(name)
            if self.semantics:
                self.scan.open_label(self.label_names[offset])

            # indices of the points that belong to the requested class
            idx = np.where(self.scan.sem_label == self.classid)[0]

            if idx.size > self.min_points:
                print(name, idx.size)
                # one row per point: x, y, z, remission
                cluster = np.hstack(
                    (self.scan.points[idx, :],
                     np.expand_dims(self.scan.remissions[idx], axis=1)))
                np.savetxt("cluster%d_%d.txt" % (idx.size, count), cluster)
                count += 1

    # interface
    def key_press(self, event):
        """ N/B step the offset (with wrap-around) and re-run the extraction;
        Q/Escape closes everything. Key events stay blocked until draw(). """
        self.canvas.events.key_press.block()
        self.img_canvas.events.key_press.block()
        if event.key == 'N':
            self.offset += 1
            if self.offset >= self.total:
                self.offset = 0
            self.update_scan()
        elif event.key == 'B':
            self.offset -= 1
            if self.offset < 0:
                self.offset = self.total - 1
            self.update_scan()
        elif event.key == 'Q' or event.key == 'Escape':
            self.destroy()

    def draw(self, event):
        """ Re-enable key handling that key_press blocked. """
        if self.canvas.events.key_press.blocked():
            self.canvas.events.key_press.unblock()
        if self.img_canvas.events.key_press.blocked():
            self.img_canvas.events.key_press.unblock()

    def destroy(self):
        """ Close both canvases and quit the vispy app. """
        # destroy the visualization
        self.canvas.close()
        self.img_canvas.close()
        vispy.app.quit()

    def run(self):
        """ Enter the vispy event loop. """
        vispy.app.run()
Beispiel #9
0
class SemanticKittiTool:
    """ Class that creates and handles point cloud data for other application

    For the scan at ``self.offset``, the points of every class listed in the
    ``obj`` YAML file are split into instances with DBSCAN. Each instance
    gets an axis-aligned box (``rz`` is always 0), and the boxes are written
    through ``kitti.write_label`` to ``<bbox_path>/labels_2/<offset>.txt``.
    The scan is shown in a vispy canvas: N/B step forward/back through the
    scans, Q/Escape close the window.
    """
    def __init__(self,
                 scan,
                 scan_names,
                 label_names,
                 config,
                 bbox_path,
                 obj,
                 offset=0,
                 semantics=True,
                 instances=False):
        """
        Args:
            scan: scan loader with reset/open_scan/open_label and the
                ``points``, ``sem_label`` and ``unproj_range`` attributes.
            scan_names: scan file paths.
            label_names: label file paths, same order as scan_names.
            config: path of a YAML file whose ``labels`` maps id -> name.
            bbox_path: directory under which ``labels_2`` is created.
            obj: path of a YAML file listing the class names of interest.
            offset: index of the scan to process first.
            semantics, instances: stored flags.
        """
        self.scan = scan
        self.scan_names = scan_names
        self.label_names = label_names
        self.offset = offset
        self.config = config
        self.total = len(self.scan_names)
        self.semantics = semantics
        self.instances = instances
        # yaml.load() without a Loader is unsafe and a TypeError on
        # PyYAML >= 6; the context manager also closes the file.
        with open(obj) as f:
            self._labels_of_interest_name = yaml.safe_load(f)
        self._labels_of_interest_num = self.GetLabelIdx(
            self._labels_of_interest_name)
        self.sizepoints = 1
        self._bbox_path = bbox_path  # path where bounding boxes will be stored

        self._scan_labels = {'inst': [], 'sem': []}
        # make instance colors
        max_inst_id = 100000
        self.inst_color_lut = np.random.uniform(low=0.0,
                                                high=1.0,
                                                size=(max_inst_id, 3))
        # force zero to a gray-ish color
        self.inst_color_lut[0] = np.full((3), 0.1)

        self.reset()

    def get_mpl_colormap(self, cmap_name):
        """ Return a matplotlib colormap as a (256, 3) float32 table in [0, 1].

        The channels are reversed: ``[:, 2::-1]`` keeps B, G, R of the RGBA
        output.
        """
        cmap = plt.get_cmap(cmap_name)

        # Initialize the matplotlib color map
        sm = plt.cm.ScalarMappable(cmap=cmap)
        # Obtain linear color range
        color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]

        return color_range.reshape(256, 3).astype(np.float32) / 255.0

    def GetLabelIdx(self, objects):
        """ Map the class names of interest to numeric label ids.

        Args:
            objects: dict parsed from the ``obj`` YAML file; the class names
                are the first value of this dict (a list of names).

        Returns:
            list with, for each name, the key of ``labels`` in the
            ``self.config`` YAML file that maps to it (see getKeysByValue).
        """
        with open(self.config) as f:
            config = yaml.safe_load(f)
        convlabel = config['labels']
        class_names = list(objects.values())[0]
        return [self.getKeysByValue(convlabel, name) for name in class_names]

    def CreateAll3DBoundingBoxes(self):
        """ Load the scan at ``self.offset``, save its boxes and plot it. """
        self.scan.reset()

        self.scan.open_scan(self.scan_names[self.offset])
        self.scan.open_label(self.label_names[self.offset])

        self.scan_labels = self.scan.sem_label
        self.scan_pts = self.scan.points

        # MergeColorFlag=1 - Merge all colors of all objects in same color frame
        # MergeColorFlag=0 - Create a color frame for each object
        bboxes, _scan_color_all_obj = self.Create3DBoundingBoxes(
            self.scan_pts, self.scan_labels, MergeColorFlag=1)

        self.Save3DBoundingBox(bboxes, str(self.offset) + ".txt")

        # The per-instance colour frame is not used for plotting (it is a
        # plain list when no object of interest is present); the scan is
        # coloured by class of interest instead.
        self.colorize(self.scan_labels)
        self.PlotPcl(self.scan_pts)

    def Save3DBoundingBox(self, bboxes, filename):
        """ Write all boxes of one scan to ``<bbox_path>/labels_2/<filename>``.

        Args:
            bboxes: dict mapping class name -> list of box dicts
                (see Campute3DBoundingBox).
            filename: label file name inside ``labels_2``.

        Raises:
            OSError: if ``labels_2`` cannot be created.
        """
        label_path = os.path.join(self._bbox_path, "labels_2")

        if not os.path.isdir(label_path):
            try:
                # the mode must be octal: 0o755 is rwxr-xr-x (0x755 is not)
                os.makedirs(label_path, 0o755, exist_ok=True)
            except OSError:
                print("Creation of the directory %s failed" % label_path)
                raise
            else:
                print("Successfully created the directory %s " % label_path)

        file_path = os.path.join(label_path, filename)  # full file path

        # Gather every class and write once, so a writer that truncates the
        # file keeps all classes instead of only the last one.
        labels = []
        for bbox in bboxes.items():
            labels.extend(self.conv_to_kitti_format(bbox))
        if bboxes:
            kitti.write_label(labels, file_path)

    def conv_to_kitti_format(self, bbox):
        """ Convert one ``(class_name, [box, ...])`` item to kitti labels.

        Returns:
            list with one kitti.Object3d per box, typed ``class_name``.
        """
        objtype, boxes = bbox
        labels = []
        for obj in boxes:
            # a fresh Object3d per box; reusing one instance made every list
            # entry alias the last box loaded
            label = kitti.Object3d()
            score = 1
            label.loadBox3D(objtype, obj['h'], obj['w'], obj['l'], obj['t'],
                            obj['rz'], score)
            labels.append(label)
        return labels

    def Color3DBoundingBox(self, bboxes, colorframe=None):
        """ Append box vertices to the scan points and give them a new colour.

        NOTE(review): expects every box as ``{'bb': {'vertices': ...}}``, but
        Campute3DBoundingBox builds boxes without these keys, and there is no
        active caller in this class.

        Returns:
            (points with box vertices appended, labels with the box label
            appended). Also sets ``self.sizepoints`` (1 per scan point, 4 per
            box vertex).
        """
        if colorframe is None:
            colorframe = []
        scan_pts = self.scan_pts
        scan_labels = colorframe
        bb_pts = np.array([])
        for data in bboxes.values():
            for box in data:
                vertices = np.asarray(box['bb']['vertices'])

                if len(bb_pts) == 0:
                    bb_pts = vertices
                else:
                    bb_pts = np.concatenate((bb_pts, vertices))

        value = max(scan_labels) + 1
        redcolor = np.ones(bb_pts.shape[0], dtype=int) * value

        scan_labels = np.concatenate((scan_labels, redcolor))

        sizepoints = np.ones(self.scan_pts.shape[0], dtype=int)

        scan_pts = np.concatenate((scan_pts, bb_pts))

        sizebb = np.ones(bb_pts.shape[0], dtype=int) * 4

        self.sizepoints = np.concatenate((sizepoints, sizebb))
        return (scan_pts, scan_labels)

    def Create3DBoundingBoxes(self, ScanPts, ScanLabels, MergeColorFlag=0):
        """ Compute per-instance boxes for all classes of interest.

        Returns:
            (dict class name -> list of boxes, instance colour frame(s)).
        """
        # TODO: check whether objects of interest exist in the labels.

        # A segment may contain more than one object from the same class
        object_segment_idx = self.SplitIntoObjectSegments(ScanLabels)
        boundingboxes, label_color_scan = self.ObjectClassClustering(
            ScanPts, object_segment_idx, MergeColorFlag)

        return (boundingboxes, label_color_scan)

    def SplitIntoObjectSegments(self, ScanLabels):
        """ Return a dict class name -> indices of the points with that label.

        Classes without any point in ``ScanLabels`` are omitted.
        """
        labels_of_interest_num = self._labels_of_interest_num
        class_names = list(self._labels_of_interest_name.values())[0]

        pointsLabel = dict()

        for label_num in labels_of_interest_num:
            pts_idx = np.where(ScanLabels == label_num)[0]
            if pts_idx.size > 0:
                index = labels_of_interest_num.index(label_num)
                pointsLabel.update({class_names[index]: pts_idx})

        return pointsLabel

    def ObjectClassClustering(self, ScanPts, SegmentIdx, MergeColorFlag=0):
        """ Cluster every class segment into instances and box them.

        With MergeColorFlag set, all instances share one colour frame (an
        array over all scan points); otherwise one frame per class is
        collected in a list.

        Returns:
            (dict class name -> list of boxes, colour frame(s)).
        """
        objects = dict()
        label_color_scan = []
        for name, segment in SegmentIdx.items():

            if MergeColorFlag:
                bbox, _, segment_scan_color = self.SegmentClustering(
                    ScanPts, segment, label_color_scan)
                label_color_scan = segment_scan_color
            else:
                bbox, _, segment_scan_color = self.SegmentClustering(
                    ScanPts, segment, [])
                label_color_scan.append(segment_scan_color)

            objects.update({name: bbox})

        return (objects, label_color_scan)

    def SegmentClustering(self, ScanPts, SegmentIdx, ColorScan=None):
        """ Split one class segment into instances with DBSCAN.

        Args:
            ScanPts: all scan points.
            SegmentIdx: indices (into ScanPts) of the segment's points.
            ColorScan: optional existing colour frame; when non-empty it is
                updated in place and new instance ids start after its maximum.

        Returns:
            (list of boxes, list of {'bb', 'idx'} instance dicts, colour frame)
        """
        if ColorScan is not None and len(ColorScan) > 0:
            max_value = int(np.max(ColorScan)) + 1
            segment_scan_color = ColorScan
        else:
            max_value = 1
            segment_scan_color = np.zeros(ScanPts.shape[0], dtype=int)

        # Clustering
        segment_points = ScanPts[SegmentIdx]
        db = DBSCAN(eps=0.3, min_samples=5).fit(segment_points)

        # DBSCAN labels clusters 0..k-1 and noise -1: iterate up to and
        # including the largest label, otherwise the last cluster is lost.
        candidates = np.array(db.labels_)
        num_instances = int(max(db.labels_)) + 1
        instance_boundle = []
        bbox_list = []
        for inst_label in range(num_instances):

            inst_index = np.where(candidates == inst_label)
            instance_index = SegmentIdx[inst_index]
            instance_points = segment_points[inst_index]
            segment_scan_color[instance_index] = int(max_value + inst_label)

            bb = self.Campute3DBoundingBox(instance_points)
            bbox_list.append(bb)
            instance_boundle.append({'bb': bb, 'idx': instance_index})

        return (bbox_list, instance_boundle, segment_scan_color)

    def getKeysByValue(self, dictOfElements, valueToFind):
        """ Return the LAST key whose value equals ``valueToFind``.

        Despite the name a single key is returned; ``[]`` when nothing
        matches.
        """
        match = []
        for key, value in dictOfElements.items():
            if value == valueToFind:
                match = key
        return match

    def Campute3DBoundingBox(self, points):
        """ Compute an axis-aligned box around ``points``.

        Args:
            points: array with exactly three columns (x, y, z).

        Returns:
            dict with ``t`` (3x1 origin: mean of the points with z set to 0),
            ``rz`` (always 0, rotation is not estimated) and the extents
            ``l`` (x), ``w`` (y) and ``h`` (z).
        """
        # translation: object frame at the mass centre on the ground (z = 0)
        t = np.mean(points, axis=0).reshape(3, 1)
        t[2] = 0
        pts = points.T - t

        # rotation is not estimated yet
        rz = 0

        x = pts[0, :]
        y = pts[1, :]
        z = pts[2, :]

        height = np.abs(float(max(z)) - float(min(z)))  # z extent
        length = np.abs(float(max(x)) - float(min(x)))  # x extent
        width = np.abs(float(max(y)) - float(min(y)))  # y extent

        return {'t': t, 'rz': rz, 'h': height, 'w': width, 'l': length}

    def CreateTransf(self, R, T):
        """ Build a 4x4 homogeneous transform from a 3x3 R and a 3x1 T. """
        row = np.zeros((1, 3), dtype=int)

        a = np.concatenate((R, row))
        b = np.concatenate((T, np.array([1]).reshape(1, 1)))
        return np.concatenate((a, b), axis=1)

    def rotationMatrixToEulerAngles(self, R):
        """ Return the [x, y, z] Euler angles (radians) of rotation matrix R.

        Falls back to z = 0 when the matrix is close to gimbal lock.
        """
        sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])

        singular = sy < 1e-6

        if not singular:
            x = math.atan2(R[2, 1], R[2, 2])
            y = math.atan2(-R[2, 0], sy)
            z = math.atan2(R[1, 0], R[0, 0])
        else:
            x = math.atan2(-R[1, 2], R[1, 1])
            y = math.atan2(-R[2, 0], sy)
            z = 0

        return np.array([x, y, z])

    def reset(self):
        """ Create the canvas with a single instance view. """
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        # new canvas prepared for visualizing data
        self.canvas = SceneCanvas(keys='interactive', show=True)
        # interface (n next, b back, q quit, very simple)
        self.canvas.events.key_press.connect(self.key_press)
        self.canvas.events.draw.connect(self.draw)
        # grid
        self.grid = self.canvas.central_widget.add_grid()

        print("Using instances in visualizer")
        self.inst_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.grid.add_widget(self.inst_view, 0, 0)
        self.inst_vis = visuals.Markers()
        self.inst_view.camera = 'turntable'
        self.inst_view.add(self.inst_vis)
        visuals.XYZAxis(parent=self.inst_view.scene)

    def resetTwo(self):
        """ Create the canvas with a scan view and an instance view. """
        # last key press (it should have a mutex, but visualization is not
        # safety critical, so let's do things wrong)
        self.action = "no"  # no, next, back, quit are the possibilities

        # new canvas prepared for visualizing data
        self.canvas = SceneCanvas(keys='interactive', show=True)
        # interface (n next, b back, q quit, very simple)
        self.canvas.events.key_press.connect(self.key_press)
        self.canvas.events.draw.connect(self.draw)
        # grid
        self.grid = self.canvas.central_widget.add_grid()

        # laserscan part
        self.scan_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.grid.add_widget(self.scan_view, 0, 0)
        self.scan_vis = visuals.Markers()
        self.scan_view.camera = 'turntable'
        self.scan_view.add(self.scan_vis)
        visuals.XYZAxis(parent=self.scan_view.scene)

        print("Using instances in visualizer")
        self.inst_view = vispy.scene.widgets.ViewBox(border_color='white',
                                                     parent=self.canvas.scene)
        self.grid.add_widget(self.inst_view, 0, 1)
        self.inst_vis = visuals.Markers()
        self.inst_view.camera = 'turntable'
        self.inst_view.add(self.inst_vis)
        visuals.XYZAxis(parent=self.inst_view.scene)

    def draw(self, event):
        """ Re-enable key handling that key_press blocked. """
        if self.canvas.events.key_press.blocked():
            self.canvas.events.key_press.unblock()

    # interface
    def key_press(self, event):
        """ N/B step the offset (with wrap-around) and rebuild the boxes;
        Q/Escape closes the window. Key events stay blocked until draw(). """
        self.canvas.events.key_press.block()
        if event.key == 'N':
            self.offset += 1
            if self.offset >= self.total:
                self.offset = 0
            self.CreateAll3DBoundingBoxes()
        elif event.key == 'B':
            self.offset -= 1
            if self.offset < 0:
                self.offset = self.total - 1
            self.CreateAll3DBoundingBoxes()
        elif event.key == 'Q' or event.key == 'Escape':
            self.destroy()

    def PlotPcl(self, points):
        """ Show ``points`` in the instance view with ``inst_label_color``. """
        self.inst_vis.set_data(points,
                               face_color=self.inst_label_color[..., ::-1],
                               edge_color=self.inst_label_color[..., ::-1],
                               size=self.sizepoints)

    def PlotPclTwo(self, points):
        """ Show the range-coloured scan and ``points`` side by side. """
        # colour the raw scan by (16th root of) range through viridis
        power = 16
        range_data = np.copy(self.scan.unproj_range)
        range_data = range_data**(1 / power)
        viridis_range = ((range_data - range_data.min()) /
                         (range_data.max() - range_data.min()) * 255).astype(
                             np.uint8)
        viridis_map = self.get_mpl_colormap("viridis")
        viridis_colors = viridis_map[viridis_range]

        self.scan_vis.set_data(self.scan.points,
                               face_color=viridis_colors[..., ::-1],
                               edge_color=viridis_colors[..., ::-1],
                               size=1)

        self.inst_vis.set_data(points,
                               face_color=self.inst_label_color[..., ::-1],
                               edge_color=self.inst_label_color[..., ::-1],
                               size=self.sizepoints)

    def colorizeObject(self, ScanLabels):
        """ Colour each point by looking its label up in ``inst_color_lut``.
        """
        self.inst_label_color = self.inst_color_lut[ScanLabels]
        self.inst_label_color = self.inst_label_color.reshape((-1, 3))

    def colorize(self, ScanLabels):
        """ Colour points of the classes of interest; all others get id 0.

        Sets ``self.inst_label_color`` to one ``inst_color_lut`` row per point.
        """
        shapescan = self.scan.points.shape
        sem_label = ScanLabels
        new_scan_labels = np.zeros(shapescan[0], dtype=int)

        for labelnum in self._labels_of_interest_num:
            idx = np.where(sem_label == labelnum)[0]
            new_scan_labels[idx] = labelnum

        self.inst_label_color = self.inst_color_lut[new_scan_labels]
        self.inst_label_color = self.inst_label_color.reshape((-1, 3))

    def destroy(self):
        """ Close the canvas and quit the vispy app. """
        self.canvas.close()
        vispy.app.quit()

    def run(self):
        """ Enter the vispy event loop. """
        vispy.app.run()