Ejemplo n.º 1
0
    def pointgroup_cluster_algorithm(self,
                                     coords,
                                     point_offset_preds,
                                     point_semantic_preds,
                                     batch_idxs,
                                     batch_size,
                                     stuff_preds=None):
        '''
        Group the selected points into instance proposals with one
        ball-query + BFS clustering pass.

        The pass runs on ``coords + point_offset_preds`` when
        ``self.cluster_sets == 'Q'`` and on the raw ``coords`` when it is 'P'.

        :param coords: point coordinates; singleton dimensions are squeezed
            away before use (presumably (N, 3) afterwards - confirm with caller)
        :param point_offset_preds: per-point offset tensor, indexed like coords
        :param point_semantic_preds: per-point semantic labels; points with a
            label > 1 are clustered when stuff_preds is None
        :param batch_idxs: per-point batch index
        :param batch_size: forwarded to utils.get_batch_offsets
        :param stuff_preds: optional per-point tensor; when given, the points
            equal to 1 are clustered instead of the label > 1 selection
        :return: tuple (scores, proposals_idx, proposals_offset)
            scores: (nProposal, 1), all ones, on the device of point_offset_preds
            proposals_idx: (sumNPoint, 2), int, column 0 is the cluster id,
                column 1 the corresponding point index in N
            proposals_offset: (nProposal + 1), int
        :raises ValueError: if self.cluster_sets is neither 'Q' nor 'P'
        '''
        # Validate the mode first: previously an unknown value ran the whole
        # preparation and then failed with UnboundLocalError at the return.
        cluster_sets = self.cluster_sets
        if cluster_sets not in ('Q', 'P'):
            raise ValueError(
                "unsupported cluster_sets {!r}; expected 'Q' or 'P'".format(
                    cluster_sets))

        #### get proposal clusters
        if stuff_preds is None:
            object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)
        else:
            object_idxs = torch.nonzero(stuff_preds == 1).view(-1)
        coords = coords.squeeze()

        batch_idxs_ = batch_idxs[object_idxs]
        batch_offsets_ = utils.get_batch_offsets(batch_idxs_, batch_size)
        coords_ = coords[object_idxs]
        semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

        # The two modes differ only in the query coordinates and in the
        # mean-active setting handed to the ball query.
        if cluster_sets == 'Q':
            query_coords = coords_ + point_offset_preds[object_idxs]
            mean_active = self.cluster_shift_meanActive
        else:
            query_coords = coords_
            mean_active = self.cluster_meanActive

        idx, start_len = pointgroup_ops.ballquery_batch_p(
            query_coords, batch_idxs_, batch_offsets_, self.cluster_radius,
            mean_active)
        proposals_idx, proposals_offset = pointgroup_ops.bfs_cluster(
            semantic_preds_cpu, idx.cpu(), start_len.cpu(),
            self.cluster_npoint_thre)
        # bfs_cluster indexes into the selected subset; translate column 1
        # back to indices into the full point set.
        proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()

        # Use the tensor's device directly: indexing element 0 would raise
        # on an empty tensor.
        scores = torch.ones(proposals_offset.shape[0] - 1,
                            1,
                            device=point_offset_preds.device)

        return scores, proposals_idx, proposals_offset
Ejemplo n.º 2
0
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs,
                batch_offsets, epoch):
        '''
        Run the sparse backbone and the point-wise heads; once
        ``epoch > self.prepare_epochs`` also cluster and score proposals.

        :param input: mapping read with the keys 'pt_feats', 'v2p_map', 'mode',
            'voxel_coords', 'spatial_shape' and 'batch_size'
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param rgb: not used by this method
        :param ori_coords: not used by this method
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda; not used by this method
        :param epoch: current epoch, compared against self.prepare_epochs
        :return: dict with
            'point_semantic_scores': one-element list, (N, nClass), float
            'point_offset_preds': one-element list, (N, 3), float32
            'point_features': backbone features gathered through input_map
            'voxel_occupancy_preds': one-element list holding the output of
                self.point_occupancy on the voxel features
            'proposal_scores': (scores, proposals_idx, proposals_offset),
                only present when epoch > self.prepare_epochs

        NOTE: an occupancy-aware clustering variant (bfs_occupancy_cluster)
        and a ground-truth evaluation path used to live here as commented-out
        code; see version control if they are needed again.
        '''
        ret = {}

        batch_idxs = batch_idxs.squeeze()

        voxel_feats = pointgroup_ops.voxelization(
            input['pt_feats'], input['v2p_map'],
            input['mode'])  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, input['voxel_coords'],
                                         input['spatial_shape'],
                                         input['batch_size'])
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        output_feats = output_feats.squeeze(dim=0)

        ### point prediction
        #### point semantic label prediction
        semantic_scores = [self.point_semantic(output_feats)
                           ]  # (N, nClass), float
        point_semantic_preds = semantic_scores[0].max(1)[1]

        #### point offset prediction
        point_offset_preds = [self.point_offset(output_feats)
                              ]  # (N, 3), float32

        #### voxel occupancy prediction
        # The per-point gather of this prediction was only consumed by the
        # removed occupancy clustering, so it is no longer computed.
        voxel_occupancy_preds = [self.point_occupancy(output.features)]

        if epoch > self.prepare_epochs:
            #### get proposal clusters
            object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)

            batch_idxs_ = batch_idxs[object_idxs]
            batch_offsets_ = utils.get_batch_offsets(batch_idxs_,
                                                     input['batch_size'])
            coords_ = coords[object_idxs]
            pt_offsets_ = point_offset_preds[0][object_idxs]

            semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

            def _cluster(query_coords, mean_active):
                # One ball-query + BFS pass over the selected points.
                # Returns (idx, offset):
                #   idx: (sumNPoint, 2), int, column 0 cluster id,
                #        column 1 point index in N
                #   offset: (nProposal + 1), int
                nbr_idx, nbr_start_len = pointgroup_ops.ballquery_batch_p(
                    query_coords, batch_idxs_, batch_offsets_,
                    self.cluster_radius, mean_active)
                clu_idx, clu_offset = pointgroup_ops.bfs_cluster(
                    semantic_preds_cpu, nbr_idx.cpu(), nbr_start_len.cpu(),
                    self.cluster_npoint_thre)
                # map subset-local indices back to indices into all N points
                clu_idx[:, 1] = object_idxs[clu_idx[:, 1].long()].int()
                return clu_idx, clu_offset

            # clustering on the original coordinates
            proposals_idx, proposals_offset = _cluster(
                coords_, self.cluster_meanActive)
            # clustering on the offset-shifted coordinates
            proposals_idx_shift, proposals_offset_shift = _cluster(
                coords_ + pt_offsets_, self.cluster_shift_meanActive)

            # Append the shifted proposals after the unshifted ones: cluster
            # ids and offsets of the second set continue where the first ends.
            proposals_idx_shift[:, 0] += (proposals_offset.size(0) - 1)
            proposals_offset_shift += proposals_offset[-1]
            proposals_idx = torch.cat((proposals_idx, proposals_idx_shift),
                                      dim=0)
            proposals_offset = torch.cat(
                (proposals_offset, proposals_offset_shift[1:]))

            #### proposals voxelization again
            input_feats, inp_map = self.clusters_voxelization(
                proposals_idx, proposals_offset, output_feats, coords,
                self.score_fullscale, self.score_scale, self.mode)

            #### score
            score = self.score_unet(input_feats)
            score = self.score_outputlayer(score)
            score_feats = score.features[inp_map.long()]  # (sumNPoint, C)
            score_feats = pointgroup_ops.roipool(
                score_feats, proposals_offset.cuda())  # (nProposal, C)
            scores = self.score_linear(score_feats)  # (nProposal, 1)

            ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

        ret['point_semantic_scores'] = semantic_scores
        ret['point_offset_preds'] = point_offset_preds
        ret['point_features'] = output_feats
        ret['voxel_occupancy_preds'] = voxel_occupancy_preds

        return ret
Ejemplo n.º 3
0
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs,
                batch_offsets, epoch):
        '''
        Backbone pass plus point-wise heads, optional auxiliary heads and,
        once ``epoch > self.prepare_epochs``, proposal clustering and scoring.

        :param input: mapping read with the keys 'pt_feats', 'v2p_map', 'mode',
            'voxel_coords', 'spatial_shape' and 'batch_size'
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param rgb: not used by this method
        :param ori_coords: not used by this method
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda; not used by this method
        :param epoch: current epoch, compared against self.prepare_epochs
        :return: dict; always holds 'point_semantic_scores',
            'point_offset_preds' and 'point_features'. The remaining keys
            depend on the activated options and on the epoch.
        '''
        ret = {}

        batch_idxs = batch_idxs.squeeze()

        # voxel features -> sparse tensor -> backbone
        sparse_in = spconv.SparseConvTensor(
            pointgroup_ops.voxelization(input['pt_feats'], input['v2p_map'],
                                        input['mode']),  # (M, C), float, cuda
            input['voxel_coords'], input['spatial_shape'],
            input['batch_size'])
        output = self.output_layer(self.unet(self.input_conv(sparse_in)))
        # gather voxel features back to the points
        output_feats = output.features[input_map.long()].squeeze(dim=0)

        ### point prediction
        #### semantic labels
        semantic_scores = [self.point_semantic(output_feats)
                           ]  # (N, nClass), float
        point_semantic_preds = semantic_scores[0].max(1)[1]

        #### offsets
        offset_pred = self.point_offset(output_feats)
        if self.instance_triplet_loss['activate']:
            # subtract the trailing feature columns of the raw point features
            offset_pred = offset_pred - input['pt_feats'][:, 3:]
        point_offset_preds = [offset_pred]  # (N, 3), float32

        #### optional voxel-level heads
        if self.voxel_center_prediction['activate']:
            voxel_features = output.features
            center_preds = self.voxel_center_pred(voxel_features)
            center_offset_preds = self.voxel_center_offset(voxel_features)
            center_semantic_preds = self.voxel_center_semantic(voxel_features)

            ret['voxel_center_preds'] = center_preds
            ret['voxel_center_offset_preds'] = center_offset_preds
            ret['voxel_center_semantic_preds'] = center_semantic_preds

        #### optional point-level heads
        if self.point_xyz_reconstruction_loss['activate']:
            ret['point_reconstructed_coords'] = \
                self.point_reconstruction_coords(output_feats)

        if self.instance_classifier['activate']:
            ret['instance_id_preds'] = self.point_instance_classifier(
                output_feats)

        if epoch > self.prepare_epochs:
            #### get proposal clusters
            fg_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)

            fg_batch_idxs = batch_idxs[fg_idxs]
            fg_batch_offsets = utils.get_batch_offsets(fg_batch_idxs,
                                                       input['batch_size'])
            fg_coords = coords[fg_idxs]
            fg_offsets = point_offset_preds[0][fg_idxs]

            fg_labels_cpu = point_semantic_preds[fg_idxs].int().cpu()

            def _bfs_proposals(query_xyz, mean_active):
                # One ball-query + BFS pass over the selected points.
                # members: (sumNPoint, 2), int, column 0 cluster id,
                #          column 1 point index in N
                # bounds:  (nProposal + 1), int
                nbr_idx, nbr_start_len = pointgroup_ops.ballquery_batch_p(
                    query_xyz, fg_batch_idxs, fg_batch_offsets,
                    self.cluster_radius, mean_active)
                members, bounds = pointgroup_ops.bfs_cluster(
                    fg_labels_cpu, nbr_idx.cpu(), nbr_start_len.cpu(),
                    self.cluster_npoint_thre)
                # map subset-local indices back to indices into all N points
                members[:, 1] = fg_idxs[members[:, 1].long()].int()
                return members, bounds

            # shifted coordinates first, then the original coordinates
            shift_members, shift_bounds = _bfs_proposals(
                fg_coords + fg_offsets, self.cluster_shift_meanActive)
            plain_members, plain_bounds = _bfs_proposals(
                fg_coords, self.cluster_meanActive)

            # the shifted proposals are appended after the plain ones, so
            # their cluster ids and offsets continue where the plain set ends
            shift_members[:, 0] += (plain_bounds.size(0) - 1)
            shift_bounds += plain_bounds[-1]
            proposals_idx = torch.cat((plain_members, shift_members), dim=0)
            proposals_offset = torch.cat((plain_bounds, shift_bounds[1:]))

            #### proposals voxelization again
            cluster_input, cluster_map = self.clusters_voxelization(
                proposals_idx, proposals_offset, output_feats, coords,
                self.score_fullscale, self.score_scale, self.mode)

            #### score
            score_out = self.score_outputlayer(self.score_unet(cluster_input))
            per_point = score_out.features[cluster_map.long()]  # (sumNPoint, C)
            per_proposal = pointgroup_ops.roipool(
                per_point, proposals_offset.cuda())  # (nProposal, C)
            scores = self.score_linear(per_proposal)  # (nProposal, 1)

            ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

        ret['point_semantic_scores'] = semantic_scores
        ret['point_offset_preds'] = point_offset_preds
        if self.instance_triplet_loss['activate']:
            ret['point_offset_feats'] = output_feats
        ret['point_features'] = output_feats

        return ret
Ejemplo n.º 4
0
    def forward(self, input, input_map, coords, batch_idxs, batch_offsets,
                epoch):
        '''
        Backbone pass, semantic and offset heads and, once
        ``epoch > self.prepare_epochs``, proposal clustering and scoring.

        :param input: sparse tensor fed to self.input_conv; its batch_size
            attribute is read when proposals are generated
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda; not used by this method
        :param epoch: current epoch, compared against self.prepare_epochs
        :return: dict with 'semantic_preds', 'semantic_scores', 'pt_offsets'
            and, after the preparation epochs, 'proposals_idx', 'score_feats'
            and 'proposal_scores'
        '''
        ret = {}

        backbone_out = self.output_layer(self.unet(self.input_conv(input)))
        output_feats = backbone_out.features[input_map.long()]

        #### semantic segmentation
        semantic_scores = self.linear(output_feats)  # (N, nClass), float
        semantic_preds = semantic_scores.max(1)[1]  # (N), long

        # ScanRefer:
        # loss_helper.py needs semantic_preds for the loss calculation
        ret['semantic_preds'] = semantic_preds
        ret['semantic_scores'] = semantic_scores

        #### offset
        pt_offsets = self.offset_linear(
            self.offset(output_feats))  # (N, 3), float32
        ret['pt_offsets'] = pt_offsets

        if epoch <= self.prepare_epochs:
            return ret

        #### get proposal clusters
        # NOTE: the clustering runs twice, on the shifted and on the original
        # coordinates; both results are concatenated afterwards. The offset
        # tensors mark where each proposal starts (what batch_offsets does for
        # batches), since there is no extra dimension for proposals or batches.
        # NOTE: sumNPoint counts every (cluster, point) assignment. A point
        # only shares a cluster with points of the same object type, but it
        # may belong to several clusters, hence sumNPoint >= N (all points of
        # all scenes passed in as one batch).
        fg_idxs = torch.nonzero(semantic_preds > 1).view(-1)

        fg_batch_idxs = batch_idxs[fg_idxs]
        fg_batch_offsets = utils.get_batch_offsets(fg_batch_idxs,
                                                   input.batch_size)
        fg_coords = coords[fg_idxs]
        fg_offsets = pt_offsets[fg_idxs]

        fg_labels_cpu = semantic_preds[fg_idxs].int().cpu()

        def _bfs_proposals(query_xyz, mean_active):
            # One ball-query + BFS pass over the selected points.
            # members: (sumNPoint, 2), int, column 0 cluster id,
            #          column 1 point index in N
            # bounds:  (nProposal + 1), int
            nbr_idx, nbr_start_len = pointgroup_ops.ballquery_batch_p(
                query_xyz, fg_batch_idxs, fg_batch_offsets,
                self.cluster_radius, mean_active)
            members, bounds = pointgroup_ops.bfs_cluster(
                fg_labels_cpu, nbr_idx.cpu(), nbr_start_len.cpu(),
                self.cluster_npoint_thre)
            # map subset-local indices back to indices into all N points
            members[:, 1] = fg_idxs[members[:, 1].long()].int()
            return members, bounds

        shift_members, shift_bounds = _bfs_proposals(
            fg_coords + fg_offsets, self.cluster_shift_meanActive)
        plain_members, plain_bounds = _bfs_proposals(fg_coords,
                                                     self.cluster_meanActive)

        # TODO: Doesn't one need to extract only the unique points from each array? (proposal_idx U proposal_idx_shift)
        shift_members[:, 0] += (plain_bounds.size(0) - 1)
        shift_bounds += plain_bounds[-1]
        proposals_idx = torch.cat((plain_members, shift_members), dim=0)
        proposals_offset = torch.cat((plain_bounds, shift_bounds[1:]))

        # ScanRefer:
        # loss_helper.py needs the predicted instances for the loss calculation
        ret['proposals_idx'] = proposals_idx

        #### proposals voxelization again
        cluster_input, cluster_map = self.clusters_voxelization(
            proposals_idx, proposals_offset, output_feats, coords,
            self.score_fullscale, self.score_scale, self.mode)

        #### score
        score_out = self.score_outputlayer(self.score_unet(cluster_input))
        score_feats = score_out.features[cluster_map.long()]  # (sumNPoint, C)
        score_feats = pointgroup_ops.roipool(
            score_feats, proposals_offset.cuda())  # (nProposal, C)

        # ScanRefer:
        # keep the intermediate result
        ret['score_feats'] = score_feats

        scores = self.score_linear(score_feats)  # (nProposal, 1)

        ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

        return ret
Ejemplo n.º 5
0
    def forward(self, input, input_map, coords, batch_idxs, batch_offsets,
                epoch):
        '''
        Backbone pass, semantic and offset heads and, once
        ``epoch > self.prepare_epochs``, proposal clustering and scoring.

        :param input: sparse tensor fed to self.input_conv; its batch_size
            attribute is read when proposals are generated
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda; not used by this method
        :param epoch: current epoch, compared against self.prepare_epochs
        :return: dict with 'semantic_scores', 'pt_offsets' and, after the
            preparation epochs, 'proposal_scores' holding
            (scores, proposals_idx, proposals_offset)
        '''
        ret = {}

        backbone_out = self.output_layer(self.unet(self.input_conv(input)))
        output_feats = backbone_out.features[input_map.long()]

        #### semantic segmentation
        semantic_scores = self.linear(output_feats)  # (N, nClass), float
        semantic_preds = semantic_scores.max(1)[1]  # (N), long
        ret['semantic_scores'] = semantic_scores

        #### offset
        pt_offsets = self.offset_linear(
            self.offset(output_feats))  # (N, 3), float32
        ret['pt_offsets'] = pt_offsets

        # no proposals are generated during the preparation epochs
        if epoch <= self.prepare_epochs:
            return ret

        #### get proposal clusters
        fg_idxs = torch.nonzero(semantic_preds > 1).view(-1)

        fg_batch_idxs = batch_idxs[fg_idxs]
        fg_batch_offsets = utils.get_batch_offsets(fg_batch_idxs,
                                                   input.batch_size)
        fg_coords = coords[fg_idxs]
        fg_offsets = pt_offsets[fg_idxs]

        fg_labels_cpu = semantic_preds[fg_idxs].int().cpu()

        def _bfs_proposals(query_xyz, mean_active):
            # One ball-query + BFS pass over the selected points.
            # members: (sumNPoint, 2), int, column 0 cluster id,
            #          column 1 point index in N
            # bounds:  (nProposal + 1), int
            nbr_idx, nbr_start_len = pointgroup_ops.ballquery_batch_p(
                query_xyz, fg_batch_idxs, fg_batch_offsets,
                self.cluster_radius, mean_active)
            members, bounds = pointgroup_ops.bfs_cluster(
                fg_labels_cpu, nbr_idx.cpu(), nbr_start_len.cpu(),
                self.cluster_npoint_thre)
            # map subset-local indices back to indices into all N points
            members[:, 1] = fg_idxs[members[:, 1].long()].int()
            return members, bounds

        # shifted coordinates first, then the original coordinates
        shift_members, shift_bounds = _bfs_proposals(
            fg_coords + fg_offsets, self.cluster_shift_meanActive)
        plain_members, plain_bounds = _bfs_proposals(fg_coords,
                                                     self.cluster_meanActive)

        # the shifted proposals are appended after the plain ones, so their
        # cluster ids and offsets continue where the plain set ends
        shift_members[:, 0] += (plain_bounds.size(0) - 1)
        shift_bounds += plain_bounds[-1]
        proposals_idx = torch.cat((plain_members, shift_members), dim=0)
        proposals_offset = torch.cat((plain_bounds, shift_bounds[1:]))

        #### proposals voxelization again
        cluster_input, cluster_map = self.clusters_voxelization(
            proposals_idx, proposals_offset, output_feats, coords,
            self.score_fullscale, self.score_scale, self.mode)

        #### score
        score_out = self.score_outputlayer(self.score_unet(cluster_input))
        per_point = score_out.features[cluster_map.long()]  # (sumNPoint, C)
        per_proposal = pointgroup_ops.roipool(
            per_point, proposals_offset.cuda())  # (nProposal, C)
        scores = self.score_linear(per_proposal)  # (nProposal, 1)

        ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

        return ret
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs,
                batch_offsets, epoch):
        '''
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda
        '''
        ret = {}

        point_semantic_scores = []
        point_offset_preds = []
        local_point_semantic_scores = []
        local_point_offset_preds = []
        '''point feature extraction'''
        # voxelize point cloud, and those voxels are passed to a sparse 3D U-Net
        # the output of sparse 3D U-Net is then remapped back to points
        voxel_feats = pointgroup_ops.voxelization(
            input['pt_feats'], input['v2p_map'],
            input['mode'])  # (M, C), float, cuda
        input_ = spconv.SparseConvTensor(voxel_feats, input['voxel_coords'],
                                         input['spatial_shape'],
                                         input['batch_size'])
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        '''point-wise predictions'''
        point_semantic_scores.append(
            self.point_semantic(output_feats))  # (N, nClass), float
        point_semantic_preds = point_semantic_scores[-1].max(1)[1]

        point_offset_preds.append(
            self.point_offset(output_feats))  # (N, 3), float32

        # cluster proposals with the stable output of the backbone network
        if (epoch > self.prepare_epochs):
            if not input['test'] and self.local_proposal['use_gt_semantic']:
                point_semantic_preds = input['semantic_labels']
            '''clustering algorithm'''
            object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)
            batch_idxs_ = batch_idxs[object_idxs]
            batch_offsets_ = utils.get_batch_offsets(batch_idxs_,
                                                     input['batch_size'])
            coords_ = coords[object_idxs]
            pt_offsets_ = point_offset_preds[-1][object_idxs]

            semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

            idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(
                coords_ + pt_offsets_, batch_idxs_, batch_offsets_,
                self.cluster_radius, self.cluster_shift_meanActive)
            proposals_idx_shift, proposals_offset_shift = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx_shift.cpu(), start_len_shift.cpu(),
                self.cluster_npoint_thre)
            proposals_idx_shift[:, 1] = object_idxs[
                proposals_idx_shift[:, 1].long()].int()
            # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset_shift: (nProposal + 1), int
            '''local proposal refinement'''
            local_proposals_idx = []
            local_proposals_offset = [0]

            # compute proposal centers
            proposal_center_coords = []
            for proposal_index in range(1, len(proposals_offset_shift)):
                proposal_point_index = proposals_idx_shift[
                    proposals_offset_shift[proposal_index - 1]:
                    proposals_offset_shift[proposal_index], 1]
                proposal_center_coords.append(
                    coords[proposal_point_index.long(), :].mean(dim=0,
                                                                keepdim=True))
            proposal_center_coords = torch.cat(proposal_center_coords, dim=0)

            # select the topk closest proposals for each proposal
            proposal_dist_mat = euclidean_dist(proposal_center_coords,
                                               proposal_center_coords)
            proposal_dist_mat[range(len(proposal_dist_mat)),
                              range(len(proposal_dist_mat))] = 100
            closest_proposals_dist, closest_proposals_index = proposal_dist_mat.topk(
                k=min(self.local_proposal['topk'], proposal_dist_mat.shape[1]),
                dim=1,
                largest=False)
            valid_closest_proposals_index = closest_proposals_dist < self.local_proposal[
                'dist_th']

            # select proposals which are the closest and in the distance threshold
            for proposal_index in range(1, len(proposals_offset_shift)):
                local_indexs = []
                local_indexs.append(proposals_idx_shift[proposals_offset_shift[
                    proposal_index -
                    1]:proposals_offset_shift[proposal_index]])
                closest_proposals_ind = closest_proposals_index[
                    proposal_index - 1, :]
                for selected_proposal_index in closest_proposals_ind[
                        valid_closest_proposals_index[proposal_index - 1, :]]:
                    local_index = proposals_idx_shift[
                        proposals_offset_shift[selected_proposal_index]:
                        proposals_offset_shift[selected_proposal_index +
                                               1]][:, 1].unsqueeze(dim=1)
                    local_indexs.append(
                        torch.cat((torch.LongTensor(
                            local_index.shape[0],
                            1).fill_(proposal_index - 1).int(),
                                   local_index.cpu()),
                                  dim=1))

                local_proposals_idx.append(torch.cat(local_indexs, dim=0))
                local_proposals_offset.append(local_proposals_offset[-1] +
                                              local_proposals_idx[-1].shape[0])

            local_proposals_idx = torch.cat(local_proposals_idx, dim=0)
            local_proposals_offset = torch.tensor(local_proposals_offset).int()

            #### proposals voxelization again
            input_feats, inp_map = self.clusters_voxelization(
                local_proposals_idx, local_proposals_offset, output_feats,
                coords, self.local_proposal['local_proposal_full_scale'],
                self.local_proposal['local_proposal_scale'], self.mode)

            #### cluster features
            if self.local_proposal['reuse_backbone_unet']:
                proposals = self.unet(input_feats)
                proposals = self.output_layer(proposals)
            else:
                proposals = self.proposal_unet(input_feats)
                proposals = self.proposal_outputlayer(proposals)
            proposals_point_features = proposals.features[
                inp_map.long()]  # (sumNPoint, C)

            ret['proposals_point_features'] = (proposals_point_features,
                                               local_proposals_idx)

            ### scatter mean point predictions
            if self.local_proposal['scatter_mean_target'] == 'prediction':
                refined_point_semantic_score = point_semantic_scores[-1]
                local_point_semantic_score = self.point_semantic(
                    proposals_point_features)
                refined_point_semantic_score[:min((local_proposals_idx[:, 1].max() + 1).item(), coords.shape[0]), :] = \
                    scatter_mean(local_point_semantic_score, local_proposals_idx[:, 1].cuda().long(), dim=0)
                point_semantic_scores.append(refined_point_semantic_score)
                point_semantic_preds = refined_point_semantic_score.max(1)[1]

                refined_point_offset_pred = point_offset_preds[-1]
                local_point_offset_pred = self.point_offset(
                    proposals_point_features)
                refined_point_offset_pred[:min((local_proposals_idx[:, 1].max() + 1).item(), coords.shape[0]), :] = \
                    scatter_mean(local_point_offset_pred, local_proposals_idx[:, 1].cuda().long(), dim=0)
                point_offset_preds.append(refined_point_offset_pred)

            ### scatter mean point features
            elif self.local_proposal['scatter_mean_target'] == 'feature':
                refined_point_features = torch.zeros_like(output_feats).cuda()
                refined_point_features[:min((local_proposals_idx[:, 1].max() + 1).item(), coords.shape[0]), :] = \
                    scatter_mean(proposals_point_features, local_proposals_idx[:, 1].cuda().long(), dim=0)
                #### filling 0 rows with output_feats
                filled_index = (refined_point_features == torch.zeros_like(
                    refined_point_features[0, :])).all(dim=1)
                refined_point_features[filled_index, :] = output_feats[
                    filled_index, :]

                ### refined point prediction
                #### refined point semantic label prediction
                point_semantic_scores.append(
                    self.point_semantic(
                        refined_point_features))  # (N, nClass), float
                point_semantic_preds = point_semantic_scores[-1].max(1)[1]

                #### point offset prediction
                point_offset_preds.append(
                    self.point_offset(
                        output_feats +
                        refined_point_features))  # (N, 3), float32

            elif self.local_proposal['scatter_mean_target'] == False:
                local_point_semantic_scores.append(
                    self.point_semantic(proposals_point_features))
                local_point_offset_preds.append(
                    self.point_offset(proposals_point_features))

                ret['local_point_semantic_scores'] = (
                    local_point_semantic_scores, local_proposals_idx)
                ret['local_point_offset_preds'] = (local_point_offset_preds,
                                                   local_proposals_idx)

            if input['test']:
                if self.local_proposal['scatter_mean_target'] == False:
                    local_point_semantic_score = self.point_semantic(
                        proposals_point_features)
                    local_point_offset_pred = self.point_offset(
                        proposals_point_features)

                    local_point_semantic_score = scatter_mean(
                        local_point_semantic_score,
                        local_proposals_idx[:, 1].cuda().long(),
                        dim=0)

                    point_semantic_score = point_semantic_scores[-1]
                    point_semantic_score[:local_point_semantic_score.shape[
                        0], :] += local_point_semantic_score
                    point_semantic_preds = point_semantic_score.max(1)[1]

                    point_offset_pred = point_offset_preds[-1]
                    local_point_offset_pred = scatter_mean(
                        local_point_offset_pred,
                        local_proposals_idx[:, 1].cuda().long(),
                        dim=0)
                    point_offset_pred[:local_point_offset_pred.
                                      shape[0], :] = local_point_offset_pred

                self.cluster_sets = 'Q'
                scores, proposals_idx, proposals_offset = self.pointgroup_cluster_algorithm(
                    coords, point_offset_pred, point_semantic_preds,
                    batch_idxs, input['batch_size'])
                ret['proposal_scores'] = (scores, proposals_idx,
                                          proposals_offset)

        ret['point_semantic_scores'] = point_semantic_scores
        ret['point_offset_preds'] = point_offset_preds
        ret['point_features'] = output_feats

        return ret
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs, batch_offsets, epoch):
        '''
        Run the sparse backbone, predict per-point semantic scores and offsets,
        and, once ``epoch`` exceeds ``self.prepare_epochs``, refine those
        predictions ``self.proposal_refinement['refine_times']`` times using
        features of clustered proposals.

        :param input: mapping read for the keys 'pt_feats', 'v2p_map', 'mode',
            'voxel_coords', 'spatial_shape', 'batch_size' and 'test'
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param rgb: not used by this method
        :param ori_coords: not used by this method
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda
        :param epoch: current epoch; compared with ``self.prepare_epochs`` and
            ``self.test_epoch``
        :return: dict with
            'point_semantic_scores': list of score tensors, one per prediction round,
            'point_offset_preds': list of offset tensors, one per prediction round,
            'proposal_confidences': (confidence preds, proposal idxs, proposal
                offsets) lists, only when ``epoch > self.prepare_epochs``,
            'proposal_scores': (scores, proposals_idx, proposals_offset), only
                when additionally ``epoch == self.test_epoch`` and ``input['test']``
        '''
        ret = {}

        batch_idxs = batch_idxs.squeeze()

        point_offset_preds = []
        point_semantic_scores = []

        proposals_confidence_preds = []
        proposals_idx_shifts = []
        proposals_offset_shifts = []

        voxel_feats = pointgroup_ops.voxelization(input['pt_feats'], input['v2p_map'], input['mode'])  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(
            voxel_feats, input['voxel_coords'],
            input['spatial_shape'], input['batch_size']
        )
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        output_feats = output_feats.squeeze(dim=0)

        # Only the point count is needed later (sanity check on refined
        # features), so keep the number instead of cloning the feature tensor.
        num_points = output_feats.shape[0]

        ### point prediction
        #### point semantic label prediction
        point_semantic_scores.append(self.point_semantic(output_feats))  # (N, nClass), float
        point_semantic_preds = point_semantic_scores[-1].max(1)[1]

        #### point offset prediction
        point_offset_preds.append(self.point_offset(output_feats))  # (N, 3), float32

        if epoch > self.prepare_epochs:

            for _ in range(self.proposal_refinement['refine_times']):
                #### get proposal clusters
                # Points whose predicted label is > 1 are clustered.
                object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)

                batch_idxs_ = batch_idxs[object_idxs]
                # Offsets over the *filtered* object points only; they are what
                # the ball query expects, but they do not index the full cloud.
                batch_offsets_ = utils.get_batch_offsets(batch_idxs_, input['batch_size'])
                coords_ = coords[object_idxs]
                pt_offsets_ = point_offset_preds[-1][object_idxs]

                semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

                idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(coords_ + pt_offsets_, batch_idxs_,
                                                                              batch_offsets_, self.cluster_radius,
                                                                              self.cluster_shift_meanActive)
                proposals_idx_shift, proposals_offset_shift = pointgroup_ops.bfs_cluster(semantic_preds_cpu,
                                                                                         idx_shift.cpu(),
                                                                                         start_len_shift.cpu(),
                                                                                         self.cluster_npoint_thre)
                # Map indices in the filtered subset back to indices in N.
                proposals_idx_shift[:, 1] = object_idxs[proposals_idx_shift[:, 1].long()].int()
                # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
                # proposals_offset_shift: (nProposal + 1), int

                if proposals_idx_shift.shape[0] == 0:
                    continue

                #### proposals voxelization again
                input_feats, inp_map = self.clusters_voxelization(
                    proposals_idx_shift, proposals_offset_shift, output_feats,
                    coords, self.proposal_refinement['proposal_refine_full_scale'],
                    self.proposal_refinement['proposal_refine_scale'], self.mode
                )

                #### proposal features
                proposals = self.proposal_unet(input_feats)
                proposals = self.proposal_outputlayer(proposals)
                proposal_feats = proposals.features[inp_map.long()]  # (sumNPoint, C)
                proposal_feats = pointgroup_ops.roipool(proposal_feats, proposals_offset_shift.cuda())  # (nProposal, C)

                proposals_confidence_preds.append(self.proposal_confidence_linear(proposal_feats))  # (nProposal, 1)
                proposals_idx_shifts.append(proposals_idx_shift)
                proposals_offset_shifts.append(proposals_offset_shift)

                # Index (in N) of the first point of every proposal.
                proposal_pts_idxs = proposals_idx_shift[proposals_offset_shift[:-1].long()][:, 1].cuda()

                if len(proposal_pts_idxs) == 0:
                    continue

                # 1-based id of the batch sample each proposal belongs to.
                # proposal_pts_idxs index the full point cloud, so they must be
                # compared against the full-cloud ``batch_offsets`` (not the
                # object-only ``batch_offsets_``). Walking the offsets from the
                # last sample to the first leaves each proposal with the
                # smallest sample id whose end offset exceeds its point index.
                proposal_batch_idxs = proposal_pts_idxs.clone()
                for _batch_idx in range(len(batch_offsets) - 1, 0, -1):
                    proposal_batch_idxs[proposal_pts_idxs < batch_offsets[_batch_idx]] = _batch_idx

                # NOTE(review): a sample without any proposal yields an empty
                # key tensor below - confirm the transformer accepts that.
                refined_point_features = []
                for _batch_idx in range(1, len(batch_offsets)):
                    # Proposal features of this sample act as keys, the
                    # sample's point features as queries.
                    key_input = proposal_feats[proposal_batch_idxs == _batch_idx, :].unsqueeze(dim=0)
                    query_input = output_feats[batch_offsets[_batch_idx - 1]:batch_offsets[_batch_idx],
                                  :].unsqueeze(dim=0).permute(0, 2, 1)
                    point_refined_feature, _ = self.proposal_transformer(
                        src=key_input,
                        query_embed=query_input,
                    )

                    refined_point_features.append(point_refined_feature.squeeze(dim=0))

                refined_point_features = torch.cat(refined_point_features, dim=0)
                assert refined_point_features.shape[0] == num_points, \
                    'point wise features have wrong point numbers'

                ### refined point prediction
                #### refined point semantic label prediction
                point_semantic_scores.append(self.point_semantic(refined_point_features))  # (N, nClass), float
                point_semantic_preds = point_semantic_scores[-1].max(1)[1]

                #### point offset prediction
                point_offset_preds.append(self.point_offset(refined_point_features))  # (N, 3), float32

            if (epoch == self.test_epoch) and input['test']:
                # Final instance proposals come from the 'Q' clustering branch;
                # note this overwrites the module's cluster_sets attribute.
                self.cluster_sets = 'Q'
                scores, proposals_idx, proposals_offset = self.pointgroup_cluster_algorithm(
                    coords, point_offset_preds[-1], point_semantic_preds,
                    batch_idxs, input['batch_size']
                )
                ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

            ret['proposal_confidences'] = (proposals_confidence_preds, proposals_idx_shifts, proposals_offset_shifts)

        ret['point_semantic_scores'] = point_semantic_scores
        ret['point_offset_preds'] = point_offset_preds

        return ret