def _get_flow_fast(self, h_render, h_real, idx, label_img, cam, b_real,
                       b_ren, K_real):
        m_real = copy.deepcopy(self._mesh[idx])
        m_real = transform_mesh(m_real, h_real)

        rmi_real = RayMeshIntersector(m_real)
        sub = 1
        tl, br = b_real.limit_bb()

        rays_origin_real = self._rays_origin_real[cam][
            int(tl[0]):int(br[0]),
            int(tl[1]):int(br[1])][::sub, ::sub]
        rays_dir_real = self._rays_dir[cam][
            int(tl[0]):int(br[0]),
            int(tl[1]):int(br[1])][::sub, ::sub]

        real_locations, real_index_ray, real_res_mesh_id = rmi_real.intersects_location(
            ray_origins=np.reshape(rays_origin_real, (-1, 3)),
            ray_directions=np.reshape(rays_dir_real, (-1, 3)),
            multiple_hits=False)

        h_trafo = h_render @ np.linalg.inv(h_real)
        ren_locations = (real_locations @ h_trafo[:3, :3].T) + h_trafo[:3, 3]

        uv_ren = backproject_points_np(ren_locations, K=self.K_ren)
        uv_real = backproject_points_np(real_locations, K=K_real)
        dis = uv_ren - uv_real

        uv_real = np.uint32(uv_real)
        idx = np.uint32(uv_real[:, 0] * (self._w) + uv_real[:, 1])
        disparity_pixels = np.zeros((self._h, self._w, 2)) - 999
        disparity_pixels = np.reshape(disparity_pixels, (-1, 2))
        disparity_pixels[idx] = dis
        disparity_pixels = np.reshape(disparity_pixels, (self._h, self._w, 2))

        f_3 = disparity_pixels[:, :, 0] != -999

        u_map = disparity_pixels[:, :, 0]
        v_map = disparity_pixels[:, :, 1]
        u_map = fill(u_map, u_map == -999)
        v_map = fill(v_map, v_map == -999)

        real_tl = np.zeros((2))
        real_tl[0] = int(b_real.tl[0])
        real_tl[1] = int(b_real.tl[1])
        real_br = np.zeros((2))
        real_br[0] = int(b_real.br[0])
        real_br[1] = int(b_real.br[1])
        ren_tl = np.zeros((2))
        ren_tl[0] = int(b_ren.tl[0])
        ren_tl[1] = int(b_ren.tl[1])
        ren_br = np.zeros((2))
        ren_br[0] = int(b_ren.br[0])
        ren_br[1] = int(b_ren.br[1])

        return u_map, v_map, f_3, torch.tensor(
            real_tl, dtype=torch.int32), torch.tensor(
                real_br, dtype=torch.int32), torch.tensor(
                    ren_tl, dtype=torch.int32), torch.tensor(ren_br,
                                                             dtype=torch.int32)
Beispiel #2
0
def ray_cast_mesh(mesh, rays_origins, ray_directions):
    """Cast rays against ``mesh`` and keep only the first hit of each ray.

    Thin wrapper around ``RayMeshIntersector.intersects_id`` called with
    ``multiple_hits=False`` and ``return_locations=True``.

    Returns:
        ``(index_triangles, index_ray, point_cloud)``: the ids of the hit
        triangles, the ids of the rays that hit, and the 3D hit locations.
    """
    hits = RayMeshIntersector(mesh).intersects_id(
        ray_origins=rays_origins,
        ray_directions=ray_directions,
        multiple_hits=False,
        return_locations=True,
    )
    triangle_ids, ray_ids, locations = hits
    return triangle_ids, ray_ids, locations
Beispiel #3
0
    def __init__(self, mesh_path, k_image, size):
        """Load the mesh at ``mesh_path`` and precompute the camera rays.

        ``k_image`` and ``size`` are forwarded to ``get_rays`` with a near/far
        range of 0.3 / 1.4. Only the ray start points and directions are kept.
        """
        # H and W are not used below; unpacking only enforces a 2-element size.
        H, W = size
        self._rmi = RayMeshIntersector(trimesh.load(mesh_path))
        rays = get_rays(k_image, size, extrinsic=None, d_min=0.3, d_max=1.4)
        self._start, _stop, self._dir = rays
Beispiel #4
0
 def __init__(self, mesh, check_faces=False):
     """Prepare face-visibility bookkeeping for ``mesh``.

     Args:
         mesh: mesh object providing ``faces``, ``triangles`` and
             ``bounding_sphere``.
         check_faces: if True, start with every face marked 0 and let
             ``compute_visible_faces`` decide. Otherwise mark every face 1.
     """
     self.mesh = mesh
     # 1000 points sampled on the bounding sphere, scaled by 2 about the origin.
     sphere_samples = trimesh.sample.sample_surface(
         mesh.bounding_sphere, count=1000)[0]
     self.sphere_points = sphere_samples * 2
     face_ids = range(len(mesh.faces))
     self.correct_faces = dict.fromkeys(face_ids, 0)
     self.ray_mesh = RayMeshIntersector(geometry=mesh)
     self.faces_centroids = self.mesh.triangles.mean(axis=1)
     self.correct_points = np.array([])
     if not check_faces:
         self.correct_faces = dict.fromkeys(face_ids, 1)
         return
     self.compute_visible_faces()
Beispiel #5
0
class RayCaster:
    """Cast a fixed bundle of camera rays against a static mesh.

    The rays are generated once from the intrinsics ``k_image`` and the image
    ``size``. ``raycast`` places them with a camera pose and intersects them
    with the mesh loaded from ``mesh_path``.
    """

    def __init__(self, mesh_path, k_image, size):
        # H and W are not used below; unpacking only enforces a 2-element size.
        H, W = size
        self._rmi = RayMeshIntersector(trimesh.load(mesh_path))
        self._start, _stop, self._dir = get_rays(k_image,
                                                 size,
                                                 extrinsic=None,
                                                 d_min=0.3,
                                                 d_max=1.4)

    def raycast(self, H_cam):
        """Intersect the camera rays, posed by the 4x4 transform ``H_cam``, with the mesh.

        Ray start points receive the full transform. Ray directions receive
        only its rotation block, since translating a direction is meaningless.

        Returns:
            ``(locations, index_ray, index_tri, ray_origins)``: the first-hit
            locations, the ids of the rays that hit, the ids of the hit
            triangles, and the transformed ray start points.
        """
        rotation_only = np.eye(4)
        rotation_only[:3, :3] = H_cam[:3, :3]
        origins = transform_points(self._start.reshape((-1, 3)), H_cam)
        directions = transform_points(self._dir.reshape((-1, 3)), rotation_only)

        locations, index_ray, index_tri = self._rmi.intersects_location(
            ray_origins=origins,
            ray_directions=directions,
            multiple_hits=False)
        return locations, index_ray, index_tri, origins
Beispiel #6
0
class MeshSampler:
    """Sample (optionally visibility-filtered) surface points of a mesh and their SDF.

    NOTE(review): this class body also contains ``_get_flow`` and a second
    ``__init__(self, arg)`` further down. Python keeps only the *last*
    definition of a name in a class body, so the second ``__init__`` replaces
    ``__init__(self, mesh, check_faces)`` and ``MeshSampler(mesh)`` does not
    run the mesh-based constructor. Both look pasted in from other classes.
    Confirm which constructor is intended and move the other one out.
    """

    def __init__(self, mesh, check_faces=False):
        """Prepare face-visibility bookkeeping for ``mesh``.

        Args:
            mesh: mesh object providing ``faces``, ``triangles`` and
                ``bounding_sphere``.
            check_faces: if True, start with every face marked 0 and let
                ``compute_visible_faces`` decide. Otherwise mark every face 1.
        """
        self.mesh = mesh
        # 1000 points sampled on the bounding sphere, scaled by 2 about the origin.
        self.sphere_points = trimesh.sample.sample_surface(
            mesh.bounding_sphere, count=1000)[0] * 2
        n_faces = len(mesh.faces)
        self.correct_faces = dict.fromkeys(range(n_faces),
                                           0 if check_faces else 1)
        self.ray_mesh = RayMeshIntersector(geometry=mesh)
        self.faces_centroids = self.mesh.triangles.mean(axis=1)
        self.correct_points = np.array([])
        if check_faces:
            self.compute_visible_faces()

    def visible_faces(self):
        """Return the ``{face_index: 0 or 1}`` visibility dict."""
        return self.correct_faces

    def compute_visible_faces(self):
        """Mark faces that are the first hit of a ray aimed at one of their vertices.

        For every triangle vertex, one ray is cast from each ``sphere_points``
        entry towards that vertex. If any of those rays first hits this
        triangle, the face is marked 1 in ``correct_faces``.

        Returns:
            The updated ``correct_faces`` dict.
        """
        for i, face in enumerate(tqdm(self.mesh.triangles)):
            for point in face:
                ray_directions = point - self.sphere_points
                faces_hit = self.ray_mesh.intersects_first(
                    self.sphere_points, ray_directions)
                if i in faces_hit:
                    self.correct_faces[i] = 1
                    # The face is already visible; its other vertices cannot
                    # change the result.
                    break
        return self.correct_faces

    def sample_points(self, n_points=10000):
        """Sample ``n_points`` surface points and keep those lying on visible faces.

        Stores the kept points in ``correct_points`` and their face normals in
        ``normals_for_points``.

        Returns:
            ``correct_points``.
        """
        samples = trimesh.sample.sample_surface(self.mesh, count=n_points)
        locations, face_index = samples[0], samples[1]
        correct_points = []
        normals_for_points = []
        for point, face in zip(tqdm(locations), face_index):
            if self.correct_faces[face] == 1:
                correct_points.append(point)
                normals_for_points.append(self.mesh.face_normals[face])
        self.correct_points = np.array(correct_points)
        self.normals_for_points = np.array(normals_for_points)
        return self.correct_points

    def compute_sdf(self, sigma=0.0025):
        """Perturb the sampled points with Gaussian noise and query their signed distance.

        Uses ``correct_points`` and ``normals_for_points``, so
        ``sample_points`` must have been called first. Points whose signed
        distance is NaN are dropped together with their normals.

        Returns:
            ``(points, sdf, normals)`` as NumPy arrays of matching length.
        """
        noise = np.random.normal(0, sigma, self.correct_points.shape)
        noisy_points = self.correct_points + noise
        sdf = np.asarray(self.mesh.nearest.signed_distance(noisy_points))
        keep = ~np.isnan(sdf)
        normals = np.asarray(self.normals_for_points)
        return noisy_points[keep], sdf[keep], normals[keep]

    # NOTE(review): ``_get_flow`` reads ``_cfg_d``, ``_mesh``,
    # ``_rays_origin_real``, ``_rays_dir``, ``_grid_x``, ``_grid_y``, ``_h``,
    # ``_w``, ``_max_matches`` and ``_max_iterations``. Neither constructor in
    # this class sets any of them.
    def _get_flow(self, h_render, h_real, idx, label_img, cam, b_real, b_ren):
        """Estimate a dense flow field between a real and a rendered object view.

        Pipeline:
            1. Crop the rays to each bounding box and subsample them by
               ``flow_cfg.sub``.
            2. Cast them against mesh ``idx`` posed at ``h_real`` (rays of
               camera ``cam``) and at ``h_render`` (rays of camera 0).
            3. Visit real hits in random order. Each real pixel is matched to
               the first render pixel whose ray hit the same triangle, and the
               pixel offset is stored. This continues until
               ``self._max_matches`` matches or ``self._max_iterations``
               candidates have been processed.
            4. Densify the samples with nearest-neighbour ``griddata``.

        Returns:
            ``False`` if fewer than ``flow_cfg.min_vis_size`` pixels of
            ``label_img`` equal ``idx``, or if fewer than
            ``flow_cfg.min_matches`` matches fall on those pixels. Otherwise
            ``(u_map, v_map, valid_flow_mask, real_tl, real_br, ren_tl,
            ren_br)``, with the box corners as int32 torch tensors.
        """
        flow_cfg = self._cfg_d.get('flow_cfg', {})
        f_1 = label_img == int(idx)

        min_vis_size = flow_cfg.get('min_vis_size', 200)
        if np.sum(f_1) < min_vis_size:
            # Too little of the object is visible.
            return False

        m_real = transform_mesh(copy.deepcopy(self._mesh[idx]), h_real)
        m_render = transform_mesh(copy.deepcopy(self._mesh[idx]), h_render)
        rmi_real = RayMeshIntersector(m_real)
        rmi_render = RayMeshIntersector(m_render)

        # Crop the rays to the object's bounding boxes and subsample them so
        # that fewer rays have to be traced.
        sub = flow_cfg.get('sub', 2)

        def _window(box):
            tl, br = box.limit_bb()
            return (slice(int(tl[0]), int(br[0])),
                    slice(int(tl[1]), int(br[1])))

        def _crop(arr, window):
            return arr[window][::sub, ::sub]

        win_real = _window(b_real)
        win_ren = _window(b_ren)

        h_idx_real = np.reshape(_crop(self._grid_x, win_real), (-1))
        w_idx_real = np.reshape(_crop(self._grid_y, win_real), (-1))
        rays_origin_real = _crop(self._rays_origin_real[cam], win_real)
        rays_dir_real = _crop(self._rays_dir[cam], win_real)

        # The rendered view always uses the rays of camera 0.
        rays_origin_render = _crop(self._rays_origin_real[0], win_ren)
        rays_dir_render = _crop(self._rays_dir[0], win_ren)
        h_idx_render = np.reshape(_crop(self._grid_x, win_ren), (-1))
        w_idx_render = np.reshape(_crop(self._grid_y, win_ren), (-1))

        render_res_mesh_id = rmi_render.intersects_first(
            ray_origins=np.reshape(rays_origin_render, (-1, 3)),
            ray_directions=np.reshape(rays_dir_render, (-1, 3)))
        real_res_mesh_id = rmi_real.intersects_first(
            ray_origins=np.reshape(rays_origin_real, (-1, 3)),
            ray_directions=np.reshape(rays_dir_real, (-1, 3)))

        # Keep only the rays that hit something (-1 means no hit).
        f_render = render_res_mesh_id != -1
        render_res_mesh_id = render_res_mesh_id[f_render]
        h_idx_render = h_idx_render[f_render]
        w_idx_render = w_idx_render[f_render]

        f_real = real_res_mesh_id != -1
        real_res_mesh_id = real_res_mesh_id[f_real]
        h_idx_real = h_idx_real[f_real]
        w_idx_real = w_idx_real[f_real]

        # Index of the first render ray hitting each triangle. Built once
        # instead of scanning all render hits with np.where per real ray.
        first_render_hit = {}
        for j, tri in enumerate(render_res_mesh_id.tolist()):
            first_render_hit.setdefault(int(tri), j)

        disparity_pixels = np.full((self._h, self._w, 2), -999.0)
        n_real = real_res_mesh_id.shape[0]
        # np.int64 instead of np.long (np.long is missing in NumPy 1.24-1.26).
        idx_pre = np.random.permutation(n_real).astype(np.int64)
        matches = 0
        i = 0
        while (matches < self._max_matches and i < self._max_iterations
               and i < n_real):
            r_id = idx_pre[i]
            i += 1
            j = first_render_hit.get(int(real_res_mesh_id[r_id]))
            if j is None:
                continue
            _h = h_idx_real[r_id]
            _w = w_idx_real[r_id]
            if (_h == -1 or _w == -1 or h_idx_render[j] == -1
                    or w_idx_render[j] == -1):
                continue  # invalid pixel
            disparity_pixels[_h, _w, 0] = h_idx_render[j] - _h
            disparity_pixels[_h, _w, 1] = w_idx_render[j] - _w
            matches += 1

        # Matched pixels that also belong to the object in the label image.
        f_2 = disparity_pixels[:, :, 0] != -999
        f_3 = f_2 * f_1

        min_matches = flow_cfg.get('min_matches', 50)
        if np.sum(f_3) < min_matches:
            return False

        points = np.stack(np.nonzero(f_3), axis=1)
        u_map = griddata(points,
                         disparity_pixels[f_3][:, 0],
                         (self._grid_x, self._grid_y),
                         method='nearest')
        v_map = griddata(points,
                         disparity_pixels[f_3][:, 1],
                         (self._grid_x, self._grid_y),
                         method='nearest')

        # Valid flow = dilated matched pixels, restricted to the object mask.
        dil_kernel_size = flow_cfg.get('dil_kernel_size', 2)
        inp = np.uint8(f_3 * 255)
        kernel = np.ones((dil_kernel_size, dil_kernel_size), np.uint8)
        valid_flow_mask = (cv2.dilate(inp, kernel, iterations=1) != 0) * f_1

        def _corner(point):
            # Truncate the corner coordinates to ints and pack them as int32.
            return torch.tensor([int(point[0]), int(point[1])],
                                dtype=torch.int32)

        return (u_map, v_map, valid_flow_mask, _corner(b_real.tl),
                _corner(b_real.br), _corner(b_ren.tl), _corner(b_ren.br))

    def __init__(self, arg):
        """Load the ScanNet scene, semantic mesh and voxel map described by ``arg``.

        ``arg`` must provide ``scannet_scene_dir``, ``label_scene_dir``,
        ``mapping_scannet_path``, ``mesh_path``, ``map_serialized_path``,
        ``mode`` and ``confidence``.

        NOTE(review): this redefinition replaces the mesh-based ``__init__``
        above (see the class docstring).
        """
        self.arg = arg

        scannet_scene_dir = arg.scannet_scene_dir
        label_scene_dir = arg.label_scene_dir
        mapping_scannet_path = arg.mapping_scannet_path
        self._mesh_path = arg.mesh_path
        map_serialized_path = arg.map_serialized_path

        self._mode = arg.mode
        self._confidence = arg.confidence

        self._gt_dir = f"{scannet_scene_dir}/label-filt/"
        self._label_scene_dir = label_scene_dir
        self._mapping_scannet = get_mapping_scannet(mapping_scannet_path)

        # Per-class RGB colour table from the kimera_interfacer package
        # (header row skipped); class 0 is drawn white.
        kimera_interfacer_path = rospkg.RosPack().get_path('kimera_interfacer')
        mapping = np.genfromtxt(
            f'{kimera_interfacer_path}/cfg/nyu40_segmentation_mapping.csv',
            delimiter=',')
        self._rgb = mapping[1:, 1:4]
        self._rgb[0, :] = 255

        mesh = trimesh.load(self._mesh_path)

        # Frame 0's ground-truth label image is read only for the image size.
        label_gt = imageio.imread(
            os.path.join(scannet_scene_dir, "label-filt/0" + '.png'))
        H, W = label_gt.shape
        size = (H, W)
        self._size = size

        data = np.loadtxt(f"{scannet_scene_dir}/intrinsic/intrinsic_color.txt")
        k_image = np.array([[data[0, 0], 0, data[0, 2]],
                            [0, data[1, 1], data[1, 2]], [0, 0, 1]])

        start, _stop, ray_dirs = get_rays(k_image,
                                          size,
                                          extrinsic=None,
                                          d_min=0.3,
                                          d_max=1.4)
        self._rmi = RayMeshIntersector(mesh)
        self._start = start
        self._dir = ray_dirs

        # Map every face colour to the class whose table colour is closest.
        colors = self._rmi.mesh.visual.face_colors[:, :3]
        self.faces_to_labels = np.zeros((colors.shape[0]))
        unique, inverse = np.unique(colors, return_inverse=True, axis=0)
        for k, c in enumerate(unique):
            self.faces_to_labels[inverse == k] = np.argmin(
                np.linalg.norm(self._rgb - c[:3], axis=1, ord=2), axis=0)

        # Parse the serialized voxel data into a NumPy structure.
        semantic_map = get_semantic_map(map_serialized_path)
        self._voxels, self._mi = parse_protobug_msg_into_accessiable_np_array(
            semantic_map)

        self._output_buffer_probs = np.zeros((H, W, self._voxels.shape[3]))
        self._output_buffer_img = np.zeros((H, W, 3))
        self._output_buffer_label = np.zeros((H, W))

        self._voxel_size = semantic_map.semantic_blocks[0].voxel_size

        v, u = np.mgrid[0:H, 0:W]
        self._v = v.reshape(-1)
        self._u = u.reshape(-1)
class LabelGenerator:
    """Produce per-pixel semantic labels for a ScanNet frame.

    Depending on the mode, labels come from:

    - ``"gt"``: the ScanNet ground-truth label images;
    - ``"network_prediction"``: stored network predictions;
    - ``"map_onehot"``, ``"map_probs"`` or any mode containing
      ``"map_probs_with_confidence"``: ray casting the camera into the
      semantic mesh / voxel map.
    """

    def __init__(self, arg):
        """Load the ScanNet scene, semantic mesh and voxel map described by ``arg``.

        ``arg`` must provide ``scannet_scene_dir``, ``label_scene_dir``,
        ``mapping_scannet_path``, ``mesh_path``, ``map_serialized_path``,
        ``mode`` and ``confidence``.
        """
        self.arg = arg

        scannet_scene_dir = arg.scannet_scene_dir
        label_scene_dir = arg.label_scene_dir
        mapping_scannet_path = arg.mapping_scannet_path
        self._mesh_path = arg.mesh_path
        map_serialized_path = arg.map_serialized_path

        self._mode = arg.mode
        self._confidence = arg.confidence

        self._gt_dir = f"{scannet_scene_dir}/label-filt/"
        self._label_scene_dir = label_scene_dir
        self._mapping_scannet = get_mapping_scannet(mapping_scannet_path)

        # Per-class RGB colour table from the kimera_interfacer package
        # (header row skipped); class 0 is drawn white.
        kimera_interfacer_path = rospkg.RosPack().get_path('kimera_interfacer')
        mapping = np.genfromtxt(
            f'{kimera_interfacer_path}/cfg/nyu40_segmentation_mapping.csv',
            delimiter=',')
        self._rgb = mapping[1:, 1:4]
        self._rgb[0, :] = 255

        mesh = trimesh.load(self._mesh_path)

        # Frame 0's ground-truth label image is read only for the image size.
        label_gt = imageio.imread(
            os.path.join(scannet_scene_dir, "label-filt/0" + '.png'))
        H, W = label_gt.shape
        size = (H, W)
        self._size = size

        data = np.loadtxt(f"{scannet_scene_dir}/intrinsic/intrinsic_color.txt")
        k_image = np.array([[data[0, 0], 0, data[0, 2]],
                            [0, data[1, 1], data[1, 2]], [0, 0, 1]])

        start, _stop, ray_dirs = get_rays(k_image,
                                          size,
                                          extrinsic=None,
                                          d_min=0.3,
                                          d_max=1.4)
        self._rmi = RayMeshIntersector(mesh)
        self._start = start
        self._dir = ray_dirs

        # Map every face colour to the class whose table colour is closest.
        colors = self._rmi.mesh.visual.face_colors[:, :3]
        self.faces_to_labels = np.zeros((colors.shape[0]))
        unique, inverse = np.unique(colors, return_inverse=True, axis=0)
        for k, c in enumerate(unique):
            self.faces_to_labels[inverse == k] = np.argmin(
                np.linalg.norm(self._rgb - c[:3], axis=1, ord=2), axis=0)

        # Parse the serialized voxel data into a NumPy structure.
        semantic_map = get_semantic_map(map_serialized_path)
        self._voxels, self._mi = parse_protobug_msg_into_accessiable_np_array(
            semantic_map)

        self._output_buffer_probs = np.zeros((H, W, self._voxels.shape[3]))
        self._output_buffer_img = np.zeros((H, W, 3))
        self._output_buffer_label = np.zeros((H, W))

        self._voxel_size = semantic_map.semantic_blocks[0].voxel_size

        # Row (v) and column (u) of every pixel in row-major order.
        # Ray indices are looked up here, assuming get_rays orders its rays
        # the same way (TODO confirm).
        v, u = np.mgrid[0:H, 0:W]
        self._v = v.reshape(-1)
        self._u = u.reshape(-1)

    def _colorize_labels(self):
        """Paint ``_output_buffer_img`` with the table colour of every label 0..40."""
        for label in range(0, 41):
            self._output_buffer_img[self._output_buffer_label == label, :3] = \
                self._rgb[label]

    def get_label(self, H_cam, frame, visu=True, override_mode=None):
        """Fill the output buffers for ``frame`` and return them.

        Args:
            H_cam: 4x4 camera pose; used only by the ray-casting modes.
            frame: frame id used to build the label file names.
            visu: show matplotlib / Open3D visualisations (ray-casting modes).
            override_mode: use this mode instead of the configured one.

        Returns:
            ``(labels, image_uint8, probs)``.

        Raises:
            ValueError: if the mode is not recognised.
        """
        mode = self._mode if override_mode is None else override_mode

        if mode == "gt":
            self._output_buffer_label = load_label_scannet(
                f"{self._gt_dir}/{frame}.png", self._mapping_scannet)
            self._colorize_labels()
            self._output_buffer_probs.fill(0)
        elif mode == "network_prediction":
            self._output_buffer_label = load_label_network(
                os.path.join(self._label_scene_dir, f"{frame}.png"),
                self._mapping_scannet)
            self._colorize_labels()
            self._output_buffer_probs.fill(0)
        elif (mode == "map_onehot" or mode == "map_probs"
              or "map_probs_with_confidence" in mode):
            self.set_label_raytracing(H_cam, mode, visu)
        else:
            raise ValueError("Invalid mode")

        return self._output_buffer_label, np.uint8(
            self._output_buffer_img), self._output_buffer_probs

    @staticmethod
    def _normalize_probs(probs):
        """Shift each pixel's class scores to a minimum of 0, then scale them to sum to 1.

        Pixels whose scores are all equal (or NaN) have nothing to normalise;
        plain division would produce NaN. Those pixels become one-hot on
        class 0, which is also the label ``argmax`` picks for an all-NaN row.
        """
        shifted = probs - probs.min(axis=2, keepdims=True)
        totals = shifted.sum(axis=2, keepdims=True)
        positive = totals > 0
        out = np.zeros_like(shifted)
        np.divide(shifted, totals, out=out, where=positive)
        out[~positive[..., 0], 0] = 1
        return out

    def set_label_raytracing(self, H_cam, mode, visu):
        """Label the image by ray casting the camera, posed by ``H_cam``, into the map.

        ``"map_onehot"`` labels each hit pixel with the class of the hit face.
        The other map modes copy the class scores of the voxel containing each
        hit into ``_output_buffer_probs`` and then normalise them.
        ``"map_probs_with_confidence"`` additionally resets pixels whose
        highest score is below ``self._confidence`` to class 0. Misses keep
        class 0.
        """
        # Ray starts get the full pose; directions get only its rotation.
        H_turn = np.eye(4)
        H_turn[:3, :3] = H_cam[:3, :3]
        ray_origins = transform(self._start.reshape((-1, 3)), H_cam)
        ray_directions = transform(self._dir.reshape((-1, 3)), H_turn)

        locations, index_ray, index_tri = self._rmi.intersects_location(
            ray_origins=ray_origins,
            ray_directions=ray_directions,
            multiple_hits=False)
        index_ray = np.asarray(index_ray, dtype=np.int64)
        index_tri = np.asarray(index_tri, dtype=np.int64)
        # Pixel (row, col) of every ray that hit the mesh.
        hit_v = self._v[index_ray]
        hit_u = self._u[index_ray]

        if mode == "map_onehot":
            colors = self._rmi.mesh.visual.face_colors[index_tri]
            # NOTE(review): the value is discarded; kept in case the
            # property access has side effects in trimesh.
            self._rmi.mesh.visual.vertex_colors
            # Reset the buffer to invalid.
            self._output_buffer_label[:, :] = 0
            tmp = np.copy(self._output_buffer_img)
            tmp[hit_v, hit_u, :] = colors[:, :3]
            self._output_buffer_label[hit_v, hit_u] = \
                self.faces_to_labels[index_tri]
            self._colorize_labels()

            if visu:
                # NOTE(review): the second imshow draws over the first one.
                plt.imshow(np.uint8(self._output_buffer_img))
                plt.imshow(np.uint8(tmp))
            return self._output_buffer_label, self._output_buffer_img, None

        # Voxel containing each hit point.
        idx_tmp = np.floor(((locations - self._mi + eps) /
                            self._voxel_size)).astype(np.uint32)
        # Background pixels default to a one-hot score on class 0.
        self._output_buffer_probs.fill(0)
        self._output_buffer_probs[:, :, 0] = 1
        if index_ray.size:
            self._output_buffer_probs[hit_v, hit_u, :] = self._voxels[
                idx_tmp[:, 0], idx_tmp[:, 1], idx_tmp[:, 2]]

        self._output_buffer_probs = self._normalize_probs(
            self._output_buffer_probs)

        if "map_probs_with_confidence" in mode:
            low_conf = self._output_buffer_probs.max(axis=2) < self._confidence
            self._output_buffer_probs[low_conf] = 0
            self._output_buffer_probs[low_conf, 0] = 1

        self._output_buffer_label = np.argmax(self._output_buffer_probs,
                                              axis=2)
        self._output_buffer_img.fill(0)
        self._colorize_labels()

        if visu:
            plt.imshow(np.uint8(self._output_buffer_img))
            plt.show()
            self.visu_current_buffer(locations, ray_origins)

    def visu_current_buffer(self,
                            locations=None,
                            ray_origins=None,
                            sub=1000,
                            sub2=8):
        """Show the mesh, ray hits and voxel classes in a blocking Open3D window.

        Args:
            locations: hit points; every ``sub``-th one is drawn as a small
                sphere and as a cube coloured by its voxel's arg-max class.
            ray_origins: ray start points; every ``sub``-th one is drawn as a
                red sphere.
            sub: stride over ``locations`` / ``ray_origins``.
            sub2: stride per axis over the voxel grid.
        """
        vis = o3d.visualization.Visualizer()
        vis.create_window(width=self._size[1],
                          height=self._size[0],
                          visible=True)
        vis.add_geometry(o3d.io.read_triangle_mesh(self._mesh_path))

        if locations is not None:
            # Detected mesh intersection points.
            for j in range(0, locations.shape[0], sub):
                sphere_o3d = o3d.geometry.TriangleMesh.create_sphere(
                    radius=0.01).translate(locations[j, :])
                vis.add_geometry(sphere_o3d)

        if ray_origins is not None:
            # Camera ray start points.
            for j in range(0, ray_origins.shape[0], sub):
                sphere_o3d = o3d.geometry.TriangleMesh.create_sphere(
                    radius=0.01).translate(ray_origins[j, :])
                sphere_o3d.paint_uniform_color([1, 0, 0])
                vis.add_geometry(sphere_o3d)

        # Every sub2-th voxel (per axis) that holds any score, drawn in the
        # colour of its arg-max class. np.float64 replaces np.float, which
        # was removed in NumPy 1.24.
        for i in range(0, self._voxels.shape[0], sub2):
            for j in range(0, self._voxels.shape[1], sub2):
                for k in range(0, self._voxels.shape[2], sub2):
                    voxel = self._voxels[i, j, k, :]
                    if np.sum(voxel) != 0:
                        translation = np.array([i, j, k], dtype=np.float64)
                        translation *= self._voxel_size
                        translation += self._mi
                        draw_cube(vis, translation, self._voxel_size,
                                  self._rgb[np.argmax(voxel)])

        if locations is not None:
            idx_tmp = np.floor(((locations - self._mi + eps) /
                                self._voxel_size)).astype(np.uint32)

            for j in range(0, locations.shape[0], sub):
                cell = idx_tmp[j]
                col_index = np.argmax(self._voxels[cell[0], cell[1],
                                                   cell[2], :])
                translation = np.copy(cell).astype(np.float64)
                translation *= self._voxel_size
                translation += self._mi
                draw_cube(vis, translation, self._voxel_size * 2,
                          self._rgb[col_index])

        vis.run()
        vis.destroy_window()