Ejemplo n.º 1
0
    def export_surfaces(self, field_fn, th, bits):
        """
        Extract a triangle mesh from the implicit field with the marching cubes algorithm:
            Lewiner, Thomas, et al. "Efficient implementation of marching cubes' cases with topological guarantees."
            Journal of graphics tools 8.2 (2003): 1-15.

        Args:
            field_fn: field function, forwarded to ``self.get_scores``.
            th: threshold, forwarded to ``self.get_scores``.
            bits: number of samples per voxel along each axis; every voxel
                contributes ``bits ** 3`` samples to the dense grid.

        Returns:
            PlyData with a ``vertex`` element (x, y, z stored as f4) and a
            ``face`` element (3 vertex indices per triangle, stored as i4).
        """
        logger.info("marching cube...")
        encoder_states = self.precompute(id=None)
        points = encoder_states['voxel_center_xyz']

        scores = self.get_scores(field_fn, th=th, bits=bits, encoder_states=encoder_states)
        coords, residual = discretize_points(points, self.voxel_size)
        A, B, C = [s + 1 for s in coords.max(0).values.cpu().tolist()]

        # prepare grids: cells of voxels that are not listed keep the value 1
        # (which becomes 0 after the flip below)
        full_grids = points.new_ones(A * B * C, bits ** 3)
        full_grids[coords[:, 0] * B * C + coords[:, 1] * C + coords[:, 2]] = scores
        # interleave voxel axes with sub-voxel axes -> dense (A*bits, B*bits, C*bits) volume
        full_grids = full_grids.reshape(A, B, C, bits, bits, bits)
        full_grids = full_grids.permute(0, 3, 1, 4, 2, 5).reshape(A * bits, B * bits, C * bits)
        full_grids = 1 - full_grids

        # marching cube
        from skimage import measure
        space_step = self.voxel_size.item() / bits
        mc_kwargs = dict(
            volume=full_grids.cpu().numpy(), level=0.5,
            spacing=(space_step, space_step, space_step)
        )
        try:
            # `marching_cubes_lewiner` is deprecated in newer scikit-image; the same
            # algorithm is available as `marching_cubes(method='lewiner')`.
            verts, faces, _normals, _values = measure.marching_cubes(method='lewiner', **mc_kwargs)
        except TypeError:
            # older scikit-image releases without the `method` argument
            verts, faces, _normals, _values = measure.marching_cubes_lewiner(**mc_kwargs)
        verts += (residual - (self.voxel_size / 2)).cpu().numpy()

        # fill the PLY structured arrays column-wise instead of via per-row Python tuples
        vertex = np.empty(verts.shape[0], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
        for axis, name in enumerate(('x', 'y', 'z')):
            vertex[name] = verts[:, axis]
        face = np.empty(faces.shape[0], dtype=[('vertex_indices', 'i4', (3,))])
        face['vertex_indices'] = faces
        return PlyData([PlyElement.describe(vertex, 'vertex'), PlyElement.describe(face, 'face')])
Ejemplo n.º 2
0
    def export_voxels(self, return_mesh=False):
        """Export the voxels kept after pruning (``self.keep``) as a PLY object.

        Args:
            return_mesh: when False, export one vertex per kept voxel center, with
                the voxel's original index stored in the ``quality`` property. When
                True, export every kept voxel as a cube made of shared corner
                vertices and de-duplicated quad faces.

        Returns:
            PlyData with a ``vertex`` element (x, y, z[, quality] stored as f4) and,
            in mesh mode, a ``face`` element with 4 vertex indices per quad.
        """
        logger.info("exporting learned sparse voxels...")
        keep_mask = self.keep.bool()
        voxel_idx = torch.arange(self.keep.size(0), device=self.keep.device)[keep_mask]
        voxel_pts = self.points[keep_mask]

        if not return_mesh:
            # HACK: we export the original voxel indices as "quality" in case for editing
            # NOTE: quality is stored as f4, so indices above 2**24 do not round-trip exactly.
            pts = voxel_pts.detach().cpu().numpy()
            vertex = np.empty(
                pts.shape[0], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('quality', 'f4')])
            vertex['x'] = pts[:, 0]
            vertex['y'] = pts[:, 1]
            vertex['z'] = pts[:, 2]
            vertex['quality'] = voxel_idx.cpu().numpy()
            return PlyData([PlyElement.describe(vertex, 'vertex')])

        # generate polygon for voxels on a half-voxel integer lattice, where both
        # voxel centers and voxel corners have integer coordinates
        center_coords, residual = discretize_points(voxel_pts, self.voxel_size / 2)
        offsets = torch.tensor([[-1,-1,-1],[-1,-1,1],[-1,1,-1],[1,-1,-1],[1,1,-1],[1,-1,1],[-1,1,1],[1,1,1]], device=center_coords.device)
        vertex_coords = center_coords[:, None, :] + offsets[None, :, :]   # (N, 8, 3) integer corners
        vertex_points = vertex_coords.type_as(residual) * self.voxel_size / 2 + residual

        # de-duplicate corners shared by neighbouring voxels, keeping first-seen order;
        # a single bulk .tolist() avoids per-element tensor indexing (and GPU syncs)
        vertex_ids, vertex_xyz = {}, []
        for key, xyz in zip(vertex_coords.reshape(-1, 3).cpu().tolist(),
                            vertex_points.reshape(-1, 3).cpu().tolist()):
            key = tuple(key)
            if key not in vertex_ids:
                vertex_ids[key] = len(vertex_xyz)
                vertex_xyz.append(xyz)

        # six quads per voxel, given as indices into the 8 corner offsets above
        faceidxs = torch.tensor([[1,6,7,5],[7,6,2,4],[5,7,4,3],[1,0,2,6],[1,5,3,0],[0,3,4,2]], device=vertex_coords.device)
        all_faces = vertex_coords[:, faceidxs].reshape(-1, 4, 3)   # (N*6, 4, 3)

        # every quad lies on one cube side, so its corner sum is exactly 4x its center;
        # a face shared by two neighbouring voxels therefore has the same corner sum
        face_quads, seen_faces = [], set()
        for corner_sum, corners in zip(all_faces.sum(1).cpu().tolist(), all_faces.cpu().tolist()):
            corner_sum = tuple(corner_sum)
            if corner_sum not in seen_faces:
                seen_faces.add(corner_sum)
                face_quads.append([vertex_ids[tuple(c)] for c in corners])

        vertex = np.array([tuple(xyz) for xyz in vertex_xyz], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
        face = np.empty(len(face_quads), dtype=[('vertex_indices', 'i4', (4,))])
        face['vertex_indices'] = np.asarray(face_quads, dtype=np.int64).reshape(-1, 4)
        return PlyData([PlyElement.describe(vertex, 'vertex'), PlyElement.describe(face, 'face')])
Ejemplo n.º 3
0
    def __init__(self, args, voxel_path=None, bbox_path=None, shared_values=None):
        """Build the sparse voxel grid from a voxel file or an initial bounding box.

        Args:
            args: configuration namespace. ``voxel_path`` and ``max_hits`` are read
                directly. Optional fields read via ``getattr`` include
                ``initial_boundingbox``, ``voxel_size``, ``scene_scale``,
                ``raymarching_stepsize_ratio`` (otherwise ``raymarching_stepsize``),
                ``voxel_embed_dim``, ``deterministic_step``, ``use_octree`` and
                ``track_max_probs``.
            voxel_path: overrides ``args.voxel_path``. Either a ``.ply`` file (point
                cloud, or voxel mesh with a ``face`` element) or a legacy text file
                whose columns from index 3 on hold the voxel xyz.
            bbox_path: overrides ``args.initial_boundingbox``. A text file whose first
                6 values go to ``bbox2voxels`` and whose last value is the default
                voxel size.
            shared_values: embedding module to reuse. When None, a new ``Embedding``
                with one row per voxel corner is created if ``voxel_embed_dim > 0``.
        """
        super().__init__(args)
        # read initial voxels or learned sparse voxels
        self.voxel_path = voxel_path if voxel_path is not None else args.voxel_path
        self.bbox_path = bbox_path if bbox_path is not None else getattr(args, "initial_boundingbox", None)
        assert (self.bbox_path is not None) or (self.voxel_path is not None), \
            "at least initial bounding box or pretrained voxel files are required."
        self.voxel_index = None
        self.scene_scale = getattr(args, "scene_scale", 1.0)

        if self.voxel_path is not None:
            # read voxel file
            assert os.path.exists(self.voxel_path), "voxel file must exist"

            if Path(self.voxel_path).suffix == '.ply':
                from plyfile import PlyData, PlyElement
                plyvoxel = PlyData.read(self.voxel_path)
                elements = [x.name for x in plyvoxel.elements]
                assert 'vertex' in elements
                plydata = plyvoxel['vertex']
                # "quality" is a per-vertex property, not a PLY element
                vertex_props = [prop.name for prop in plydata.properties]
                fine_points = torch.from_numpy(
                    np.stack([plydata['x'], plydata['y'], plydata['z']]).astype('float32').T)

                if 'face' in elements:
                    # read voxel meshes... automatically detect voxel size from the first
                    # face. This assumes quad faces whose first two corners differ by
                    # exactly one voxel edge (as written by export_voxels).
                    faces = plyvoxel['face']['vertex_indices']
                    t = fine_points[faces[0].astype('int64')]
                    voxel_size = torch.abs(t[0] - t[1]).max().item()

                    # NOTE: every unique mesh corner is used as a voxel center, which
                    # yields a denser grid than the exported voxels; corner->center
                    # recovery is intentionally not performed here.
                    fine_points = torch.unique(fine_points, dim=0)

                else:
                    # voxel size must be provided
                    assert getattr(args, "voxel_size", None) is not None, "final voxel size is essential."
                    voxel_size = args.voxel_size

                    # original voxel indices (see export_voxels) are stored as "quality";
                    # they only line up with fine_points when vertices were not de-duplicated
                    if 'quality' in vertex_props:
                        self.voxel_index = torch.from_numpy(plydata['quality']).long()

            else:
                # supporting the old style .txt voxel points; the voxel size must be
                # provided because the file does not carry it
                assert getattr(args, "voxel_size", None) is not None, "final voxel size is essential."
                voxel_size = args.voxel_size
                fine_points = torch.from_numpy(np.loadtxt(self.voxel_path)[:, 3:].astype('float32'))
        else:
            # read bounding-box file
            bbox = np.loadtxt(self.bbox_path)
            voxel_size = bbox[-1] if getattr(args, "voxel_size", None) is None else args.voxel_size
            fine_points = torch.from_numpy(bbox2voxels(bbox[:6], voxel_size))

        half_voxel = voxel_size * .5

        # transform from voxel centers to voxel corners (key/values)
        fine_coords, _ = discretize_points(fine_points, half_voxel)
        fine_keys0 = offset_points(fine_coords, 1.0).reshape(-1, 3)
        fine_keys, fine_feats = torch.unique(fine_keys0, dim=0, sorted=True, return_inverse=True)
        fine_feats = fine_feats.reshape(-1, 8)
        # exact integer count (a float32 scalar_tensor would round counts above 2**24)
        num_keys = torch.tensor(fine_keys.size(0), dtype=torch.long)

        # ray-marching step size
        if getattr(args, "raymarching_stepsize_ratio", 0) > 0:
            step_size = args.raymarching_stepsize_ratio * voxel_size
        else:
            step_size = args.raymarching_stepsize

        # register parameters (will be saved to checkpoints)
        self.register_buffer("points", fine_points)          # voxel centers
        self.register_buffer("keys", fine_keys.long())       # id used to find voxel corners/embeddings
        self.register_buffer("feats", fine_feats.long())     # for each voxel, 8 voxel corner ids
        self.register_buffer("num_keys", num_keys)
        self.register_buffer("keep", fine_feats.new_ones(fine_feats.size(0)).long())  # whether the voxel will be pruned

        self.register_buffer("voxel_size", torch.scalar_tensor(voxel_size))
        self.register_buffer("step_size", torch.scalar_tensor(step_size))
        self.register_buffer("max_hits", torch.scalar_tensor(args.max_hits))

        logger.info("loaded {} voxel centers, {} voxel corners".format(fine_points.size(0), num_keys))

        # set-up other hyperparameters and initialize running time caches
        self.embed_dim = getattr(args, "voxel_embed_dim", None)
        self.deterministic_step = getattr(args, "deterministic_step", False)
        self.use_octree = getattr(args, "use_octree", False)
        self.track_max_probs = getattr(args, "track_max_probs", False)
        self._runtime_caches = {
            "flatten_centers": None,
            "flatten_children": None,
            "max_voxel_probs": None
        }

        # sparse voxel embeddings (a missing/None embed dim means "no own embeddings"
        # instead of a TypeError from comparing None > 0)
        if shared_values is None and self.embed_dim is not None and self.embed_dim > 0:
            self.values = Embedding(num_keys, self.embed_dim, None)
        else:
            self.values = shared_values
Ejemplo n.º 4
0
    def __init__(self,
                 args,
                 voxel_path=None,
                 bbox_path=None,
                 shared_values=None):
        """Build the sparse voxel grid from a voxel file or an initial bounding box.

        Args:
            args: configuration namespace. ``voxel_path`` and ``max_hits`` are read
                directly. ``voxel_size`` is required when a voxel file is given.
                Optional fields read via ``getattr`` include ``initial_boundingbox``,
                ``raymarching_stepsize_ratio`` (otherwise ``raymarching_stepsize``),
                ``voxel_embed_dim``, ``deterministic_step``, ``use_octree`` and
                ``track_max_probs``.
            voxel_path: overrides ``args.voxel_path``. Either a ``.ply`` point cloud
                or a legacy text file whose columns from index 3 on hold the voxel xyz.
            bbox_path: overrides ``args.initial_boundingbox``. A text file whose first
                6 values go to ``bbox2voxels`` and whose last value is the voxel size.
            shared_values: embedding module to reuse. When None, a new ``Embedding``
                with one row per voxel corner is created if ``voxel_embed_dim > 0``.
        """
        super().__init__(args)
        # read initial voxels or learned sparse voxels
        self.voxel_path = voxel_path if voxel_path is not None else args.voxel_path
        self.bbox_path = bbox_path if bbox_path is not None else getattr(
            args, "initial_boundingbox", None)
        assert (self.bbox_path is not None) or (self.voxel_path is not None), \
            "at least initial bounding box or pretrained voxel files are required."
        self.voxel_index = None
        if self.voxel_path is not None:
            assert os.path.exists(self.voxel_path), "voxel file must exist"
            assert getattr(args, "voxel_size",
                           None) is not None, "final voxel size is essential."

            voxel_size = args.voxel_size

            if Path(self.voxel_path).suffix == '.ply':
                from plyfile import PlyData, PlyElement
                plydata = PlyData.read(self.voxel_path)['vertex']
                fine_points = torch.from_numpy(
                    np.stack([plydata['x'], plydata['y'],
                              plydata['z']]).astype('float32').T)
                # optional per-vertex "quality" property holding the original voxel
                # indices; checked explicitly instead of swallowing a lookup error
                vertex_props = [prop.name for prop in plydata.properties]
                if 'quality' in vertex_props:
                    self.voxel_index = torch.from_numpy(
                        plydata['quality']).long()
            else:
                # supporting the old version voxel points
                fine_points = torch.from_numpy(
                    np.loadtxt(self.voxel_path)[:, 3:].astype('float32'))
        else:
            bbox = np.loadtxt(self.bbox_path)
            voxel_size = bbox[-1]
            fine_points = torch.from_numpy(bbox2voxels(bbox[:6], voxel_size))
        half_voxel = voxel_size * .5

        # transform from voxel centers to voxel corners (key/values)
        fine_coords, _ = discretize_points(fine_points, half_voxel)
        fine_keys0 = offset_points(fine_coords, 1.0).reshape(-1, 3)
        fine_keys, fine_feats = torch.unique(fine_keys0,
                                             dim=0,
                                             sorted=True,
                                             return_inverse=True)
        fine_feats = fine_feats.reshape(-1, 8)
        # exact integer count (a float32 scalar_tensor would round counts above 2**24)
        num_keys = torch.tensor(fine_keys.size(0), dtype=torch.long)

        # ray-marching step size
        if getattr(args, "raymarching_stepsize_ratio", 0) > 0:
            step_size = args.raymarching_stepsize_ratio * voxel_size
        else:
            step_size = args.raymarching_stepsize

        # register parameters (will be saved to checkpoints)
        self.register_buffer("points", fine_points)  # voxel centers
        self.register_buffer(
            "keys",
            fine_keys.long())  # id used to find voxel corners/embeddings
        self.register_buffer(
            "feats", fine_feats.long())  # for each voxel, 8 voxel corner ids
        self.register_buffer("num_keys", num_keys)
        self.register_buffer(
            "keep",
            fine_feats.new_ones(
                fine_feats.size(0)).long())  # whether the voxel will be pruned

        self.register_buffer("voxel_size", torch.scalar_tensor(voxel_size))
        self.register_buffer("step_size", torch.scalar_tensor(step_size))
        self.register_buffer("max_hits", torch.scalar_tensor(args.max_hits))

        # set-up other hyperparameters and initialize running time caches
        self.embed_dim = getattr(args, "voxel_embed_dim", None)
        self.deterministic_step = getattr(args, "deterministic_step", False)
        self.use_octree = getattr(args, "use_octree", False)
        self.track_max_probs = getattr(args, "track_max_probs", False)
        self._runtime_caches = {
            "flatten_centers": None,
            "flatten_children": None,
            "max_voxel_probs": None
        }

        # sparse voxel embeddings (a missing/None embed dim means "no own embeddings"
        # instead of a TypeError from comparing None > 0)
        if (shared_values is None and self.embed_dim is not None
                and self.embed_dim > 0):
            self.values = Embedding(num_keys, self.embed_dim, None)
        else:
            self.values = shared_values