Example #1
0
    def splitting(self):
        """Re-voxelize at half the voxel size and install the result on ``self``.

        Reads the current voxel state from ``self.precompute(id=None)``
        (``voxel_vertex_idx``, ``voxel_center_xyz``, ``voxel_vertex_emb``) and
        passes it to ``splitting_points`` with ``self.voxel_size / 2.0``. The
        returned points, feats, embedding values and keys then replace
        ``self.values`` (when new values are returned), ``self.total_size``,
        ``self.num_keys``, ``self.points``, ``self.feats`` and ``self.keep``.
        ``self.voxel_size`` itself is not modified here.
        """
        logger.info("splitting...")
        encoder_states = self.precompute(id=None)
        feats = encoder_states['voxel_vertex_idx']
        points = encoder_states['voxel_center_xyz']
        values = encoder_states['voxel_vertex_emb']
        new_points, new_feats, new_values, new_keys = splitting_points(
            points, feats, values, self.voxel_size / 2.0)
        new_num_keys = new_keys.size(0)
        new_point_length = new_points.size(0)

        # set new voxel embeddings
        if new_values is not None:
            # A brand-new Parameter replaces the old one. NOTE(review): any
            # optimizer still holding the old parameter presumably needs to be
            # rebuilt by the caller -- confirm.
            self.values.weight = nn.Parameter(new_values)
            self.values.num_embeddings = self.values.weight.size(0)

        self.total_size = new_num_keys
        # `* 0 +` overwrites the value while keeping num_keys's existing
        # container (presumably a tensor buffer, so dtype/device are kept).
        self.num_keys = self.num_keys * 0 + self.total_size

        self.points = new_points
        self.feats = new_feats
        # Every voxel produced by the split starts out kept.
        self.keep = self.keep.new_ones(new_point_length)
        # Log plain ints: formatting `self.keep.sum()` would print a tensor
        # repr. `keep` is all ones here, so its sum equals new_point_length.
        logger.info(
            "splitting done. # of voxels before: {}, after: {} voxels".format(
                points.size(0), new_point_length))
Example #2
0
    def splitting(self):
        """Split the voxels shared by every entry of ``self.all_voxels``.

        Vertex-index rows and voxel centers are gathered from each voxel set's
        ``precompute(id=None)`` and concatenated. Duplicate vertex-index rows
        are collapsed with ``torch.unique``, keeping one center per unique row.
        The deduplicated voxels go through ``splitting_points`` with
        ``self.voxel_size / 2.0``. The new embeddings are written into
        ``self.all_voxels[0].values``, and every voxel set receives the same
        new ``points``, ``feats``, ``keep``, ``total_size`` and ``num_keys``.
        """
        logger.info("splitting...")
        vertex_idx_parts = []
        center_xyz_parts = []
        vertex_emb = None
        for voxel_set in self.all_voxels:
            states = voxel_set.precompute(id=None)
            vertex_idx_parts.append(states['voxel_vertex_idx'])
            center_xyz_parts.append(states['voxel_center_xyz'])
            # Only the last set's embeddings survive the loop; they are
            # presumably identical across sets because the table is shared.
            vertex_emb = states['voxel_vertex_emb']

        stacked_idx = torch.cat(vertex_idx_parts, 0)
        stacked_xyz = torch.cat(center_xyz_parts, 0)

        unique_feats, inverse = torch.unique(
            stacked_idx, dim=0, return_inverse=True)
        # For every unique row, record the position of one original row that
        # maps to it; that row's center is used as the representative point.
        row_of_unique = unique_feats.new_zeros(unique_feats.size(0))
        row_of_unique.scatter_(
            0, inverse,
            torch.arange(inverse.size(0), device=unique_feats.device))
        unique_points = stacked_xyz[row_of_unique]

        new_points, new_feats, new_values, new_keys = splitting_points(
            unique_points, unique_feats, vertex_emb, self.voxel_size / 2.0)
        new_num_keys = new_keys.size(0)
        new_point_length = new_points.size(0)

        # The embedding table lives on the first voxel set (shared voxels).
        if vertex_emb is not None:
            shared_table = self.all_voxels[0].values
            shared_table.weight = nn.Parameter(new_values)
            shared_table.num_embeddings = new_num_keys

        for voxel_set in self.all_voxels:
            voxel_set.total_size = new_num_keys
            # `* 0 +` resets the value but keeps num_keys's existing container.
            voxel_set.num_keys = voxel_set.num_keys * 0 + voxel_set.total_size
            voxel_set.points = new_points
            voxel_set.feats = new_feats
            voxel_set.keep = voxel_set.keep.new_ones(new_point_length)

        logger.info("splitting done. # of voxels before: {}, after: {} voxels".format(
            unique_points.size(0), new_point_length))