def splitting(self):
    """Split the current voxels into smaller ones and rebuild encoder state.

    Reads the current state with ``self.precompute(id=None)`` and passes the
    voxel centres (``voxel_center_xyz``), per-voxel vertex indices
    (``voxel_vertex_idx``) and vertex embeddings (``voxel_vertex_emb``) to
    ``splitting_points`` with ``self.voxel_size / 2.0``. It then replaces, in
    place:

    * ``self.values`` -- embedding weight and ``num_embeddings``, only when
      ``splitting_points`` returned a new table;
    * ``self.total_size`` / ``self.num_keys`` -- the new number of keys;
    * ``self.points`` / ``self.feats`` -- the new centres and vertex indices;
    * ``self.keep`` -- reset to all ones, one entry per new voxel.

    ``self.voxel_size`` is not modified here; presumably the caller halves it
    after splitting -- TODO confirm.
    """
    logger.info("splitting...")
    encoder_states = self.precompute(id=None)
    points = encoder_states['voxel_center_xyz']
    feats = encoder_states['voxel_vertex_idx']
    values = encoder_states['voxel_vertex_emb']

    new_points, new_feats, new_values, new_keys = splitting_points(
        points, feats, values, self.voxel_size / 2.0)
    new_num_keys = new_keys.size(0)
    new_point_length = new_points.size(0)

    # set new voxel embeddings (only if splitting_points produced a table)
    if new_values is not None:
        self.values.weight = nn.Parameter(new_values)
        self.values.num_embeddings = self.values.weight.size(0)

    self.total_size = new_num_keys
    # `* 0 +` keeps num_keys the same kind of object it already was
    # (presumably a tensor buffer -- TODO confirm) instead of rebinding an int.
    self.num_keys = self.num_keys * 0 + self.total_size
    self.points = new_points
    self.feats = new_feats
    self.keep = self.keep.new_ones(new_point_length)

    # Every new voxel is kept, so the count equals self.keep.sum(); log the
    # plain int so the message shows a number rather than a tensor repr.
    logger.info(
        "splitting done. # of voxels before: {}, after: {} voxels".format(
            points.size(0), new_point_length))
def splitting(self):
    """Split the voxels of every sub-encoder in ``self.all_voxels`` jointly.

    Gathers each sub-encoder's ``precompute(id=None)`` state and concatenates
    the per-voxel vertex-index rows and centres. It deduplicates identical
    vertex-index rows, keeping one centre per unique row, and splits the
    result once with ``splitting_points(..., self.voxel_size / 2.0)``. The
    same new points, vertex indices and key count are then installed on
    every sub-encoder.

    The new embedding table is written only to ``self.all_voxels[0].values``.
    This assumes the embedding module is shared by all sub-encoders
    (suggested by the "shared voxels" design) -- TODO confirm.
    ``self.voxel_size`` is not modified here.
    """
    logger.info("splitting...")
    all_feats, all_points = [], []
    for voxels in self.all_voxels:
        encoder_states = voxels.precompute(id=None)
        all_feats.append(encoder_states['voxel_vertex_idx'])
        all_points.append(encoder_states['voxel_center_xyz'])
        # Embeddings are treated as shared, so the table read from the last
        # sub-encoder is used for the joint split below.
        values = encoder_states['voxel_vertex_emb']

    feats = torch.cat(all_feats, 0)
    points = torch.cat(all_points, 0)

    # Voxels that appear in several sub-encoders have identical vertex-index
    # rows; keep one row (and one centre) per distinct voxel.
    unique_feats, unique_idx = torch.unique(feats, dim=0, return_inverse=True)
    # For each unique row, store the position of one of its occurrences in
    # `feats`. With repeated targets, scatter_ keeps an unspecified writer;
    # presumably harmless because duplicates share a centre -- TODO confirm.
    representative_idx = unique_feats.new_zeros(unique_feats.size(0)).scatter_(
        0, unique_idx,
        torch.arange(unique_idx.size(0), device=unique_feats.device))
    unique_points = points[representative_idx]

    new_points, new_feats, new_values, new_keys = splitting_points(
        unique_points, unique_feats, values, self.voxel_size / 2.0)
    new_num_keys = new_keys.size(0)
    new_point_length = new_points.size(0)

    # set new voxel embeddings (shared voxels). Guard on the returned table:
    # nn.Parameter(None) would silently install an empty weight.
    if new_values is not None:
        shared_values = self.all_voxels[0].values
        shared_values.weight = nn.Parameter(new_values)
        shared_values.num_embeddings = new_num_keys

    for voxels in self.all_voxels:
        voxels.total_size = new_num_keys
        # `* 0 +` keeps num_keys the same kind of object it already was.
        voxels.num_keys = voxels.num_keys * 0 + voxels.total_size
        voxels.points = new_points
        voxels.feats = new_feats
        voxels.keep = voxels.keep.new_ones(new_point_length)

    logger.info("splitting done. # of voxels before: {}, after: {} voxels".format(
        unique_points.size(0), new_point_length))