Example #1
    def test_model_fn(batch, model, epoch):
        coords = batch['locs'].cuda()              # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
        voxel_coords = batch['voxel_locs'].cuda()  # (M, 1 + 3), long, cuda
        p2v_map = batch['p2v_map'].cuda()          # (N), int, cuda
        v2p_map = batch['v2p_map'].cuda()          # (M, 1 + maxActive), int, cuda

        coords_float = batch['locs_float'].cuda()  # (N, 3), float32, cuda
        feats = batch['feats'].cuda()              # (N, C), float32, cuda

        batch_offsets = batch['offsets'].cuda()    # (B + 1), int, cuda

        spatial_shape = batch['spatial_shape']

        if cfg.use_coords:
            feats = torch.cat((feats, coords_float), 1)
        voxel_feats = pointgroup_ops.voxelization(feats, v2p_map, cfg.mode)  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, cfg.batch_size)

        ret = model(input_, p2v_map, coords_float, coords[:, 0].int(), batch_offsets, epoch)
        semantic_scores = ret['semantic_scores']  # (N, nClass) float32, cuda
        pt_offsets = ret['pt_offsets']            # (N, 3), float32, cuda
        if (epoch > cfg.prepare_epochs):
            scores, proposals_idx, proposals_offset = ret['proposal_scores']

        ##### preds
        with torch.no_grad():
            preds = {}
            preds['semantic'] = semantic_scores
            preds['pt_offsets'] = pt_offsets
            if (epoch > cfg.prepare_epochs):
                preds['score'] = scores
                preds['proposals'] = (proposals_idx, proposals_offset)

        return preds
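
The pointgroup_ops.voxelization call above is a fused CUDA op. As a reference for what it is used for here, a minimal plain-PyTorch sketch of mean-pooling voxelization follows; the v2p_map layout (per-voxel point count in column 0, then point indices) is an assumption read off the (M, 1 + maxActive) shape comment, and cfg.mode is assumed to select mean pooling.

    import torch

    def voxelize_mean(feats, v2p_map):
        # feats: (N, C) point features; v2p_map: (M, 1 + maxActive) int
        out = feats.new_zeros(v2p_map.shape[0], feats.shape[1])
        for m in range(v2p_map.shape[0]):
            n = int(v2p_map[m, 0])             # number of points in voxel m
            idxs = v2p_map[m, 1:1 + n].long()  # their row indices into feats
            out[m] = feats[idxs].mean(0)       # mean-pool to one (C,) vector
        return out                             # (M, C)
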
Example #2
    def clusters_voxelization(self, clusters_idx, clusters_offset, feats, coords, fullscale, scale, mode):
        '''
        :param clusters_idx: (SumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N, cpu
        :param clusters_offset: (nCluster + 1), int, cpu
        :param feats: (N, C), float, cuda
        :param coords: (N, 3), float, cuda
        :param fullscale: int, target spatial extent of each voxelized cluster
        :param scale: float, upper bound on the per-cluster scaling factor
        :param mode: int, voxelization mode
        :return: (voxelization_feats, inp_map) -- a SparseConvTensor and the point-to-voxel input map
        '''
        c_idxs = clusters_idx[:, 1].cuda()
        clusters_feats = feats[c_idxs.long()]
        clusters_coords = coords[c_idxs.long()]

        clusters_coords_mean = pointgroup_ops.sec_mean(clusters_coords, clusters_offset.cuda())  # (nCluster, 3), float
        clusters_coords_mean = torch.index_select(clusters_coords_mean, 0, clusters_idx[:, 0].cuda().long())  # (sumNPoint, 3), float
        clusters_coords -= clusters_coords_mean

        clusters_coords_min = pointgroup_ops.sec_min(clusters_coords, clusters_offset.cuda())  # (nCluster, 3), float
        clusters_coords_max = pointgroup_ops.sec_max(clusters_coords, clusters_offset.cuda())  # (nCluster, 3), float

        clusters_scale = 1 / ((clusters_coords_max - clusters_coords_min) / fullscale).max(1)[0] - 0.01  # (nCluster), float
        clusters_scale = torch.clamp(clusters_scale, min=None, max=scale)

        min_xyz = clusters_coords_min * clusters_scale.unsqueeze(-1)  # (nCluster, 3), float
        max_xyz = clusters_coords_max * clusters_scale.unsqueeze(-1)

        clusters_scale = torch.index_select(clusters_scale, 0, clusters_idx[:, 0].cuda().long())

        clusters_coords = clusters_coords * clusters_scale.unsqueeze(-1)

        coords_range = max_xyz - min_xyz  # renamed from `range` to avoid shadowing the builtin
        offset = -min_xyz \
            + torch.clamp(fullscale - coords_range - 0.001, min=0) * torch.rand(3).cuda() \
            + torch.clamp(fullscale - coords_range + 0.001, max=0) * torch.rand(3).cuda()
        offset = torch.index_select(offset, 0, clusters_idx[:, 0].cuda().long())
        clusters_coords += offset
        assert clusters_coords.shape.numel() == ((clusters_coords >= 0) * (clusters_coords < fullscale)).sum()

        clusters_coords = clusters_coords.long()
        clusters_coords = torch.cat([clusters_idx[:, 0].view(-1, 1).long(), clusters_coords.cpu()], 1)  # (sumNPoint, 1 + 3)

        out_coords, inp_map, out_map = pointgroup_ops.voxelization_idx(clusters_coords, int(clusters_idx[-1, 0]) + 1, mode)
        # out_coords: (M, 1 + 3), long
        # inp_map: (sumNPoint,), int
        # out_map: (M, 1 + maxActive), int

        out_feats = pointgroup_ops.voxelization(clusters_feats, out_map.cuda(), mode)  # (M, C), float, cuda

        spatial_shape = [fullscale] * 3
        voxelization_feats = spconv.SparseConvTensor(out_feats, out_coords.int().cuda(), spatial_shape, int(clusters_idx[-1, 0]) + 1)

        return voxelization_feats, inp_map
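
The sec_mean / sec_min / sec_max ops above reduce the flat (sumNPoint, 3) coordinate tensor cluster by cluster, with clusters_offset delimiting each cluster's slice. A hedged CPU sketch of that contract (not the CUDA implementation):

    import torch

    def sec_reduce(x, offsets, op='mean'):
        # x: (sumNPoint, D); offsets: (nCluster + 1,), offsets[k]..offsets[k+1] is cluster k
        out = []
        for k in range(offsets.shape[0] - 1):
            seg = x[offsets[k]:offsets[k + 1]]
            out.append({'mean': seg.mean(0),
                        'min': seg.min(0)[0],
                        'max': seg.max(0)[0]}[op])
        return torch.stack(out, 0)  # (nCluster, D)
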
Example #3
    def point_voxelization(self, data_dict):
        '''Transferring batch point clouds to a sparse tensor

        Args:
            data_dict: dict

        Returns:
            input: SparseConvTensor

        '''
        feats = data_dict['seg_features'].cuda()
        v2p_map = data_dict['v2p_map'].cuda()
        voxel_locs = data_dict['voxel_locs'].int().cuda()
        spatial_shape = data_dict['spatial_shape']

        voxel_feats = pointgroup_ops.voxelization(
            feats, v2p_map,
            self.config['Segmentation']['mode'])  # (M, C), float, cuda
        input = spconv.SparseConvTensor(voxel_feats, voxel_locs, spatial_shape,
                                        self.config['TRAIN']['batch_size'])

        return input
Example #4
    def model_fn(batch, model, epoch):
        #print ('model fn')
        ##### prepare input and forward
        # batch {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
        # 'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
        # 'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
        # 'id': tbl, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
        coords = batch['locs'].cuda()  # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
        voxel_coords = batch['voxel_locs'].cuda()  # (M, 1 + 3), long, cuda
        p2v_map = batch['p2v_map'].cuda()  # (N), int, cuda
        v2p_map = batch['v2p_map'].cuda()  # (M, 1 + maxActive), int, cuda

        coords_float = batch['locs_float'].cuda()  # (N, 3), float32, cuda
        feats = batch['feats'].cuda()  # (N, C), float32, cuda
        labels = batch['labels'].cuda()  # (N), long, cuda
        '''instance_labels = batch['instance_labels'].cuda()      # (N), long, cuda, 0~total_nInst, -100

        instance_info = batch['instance_info'].cuda()          # (N, 9), float32, cuda, (meanxyz, minxyz, maxxyz)
        instance_pointnum = batch['instance_pointnum'].cuda()  # (total_nInst), int, cuda'''

        batch_offsets = batch['offsets'].cuda()  # (B + 1), int, cuda

        spatial_shape = batch['spatial_shape']

        if cfg.use_coords:
            feats = torch.cat((feats, coords_float), 1)
        voxel_feats = pointgroup_ops.voxelization(
            feats, v2p_map, cfg.mode)  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(),
                                         spatial_shape, cfg.batch_size)

        ret = model(input_, p2v_map, coords_float, coords[:, 0].int(),
                    batch_offsets, epoch)
        semantic_scores = ret['semantic_scores']  # (N, nClass) float32, cuda
        '''pt_offsets = ret['pt_offsets']           # (N, 3), float32, cuda
        if(epoch > cfg.prepare_epochs):
            scores, proposals_idx, proposals_offset = ret['proposal_scores']
            # scores: (nProposal, 1) float, cuda
            # proposals_idx: (sumNPoint, 2), int, cpu, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset: (nProposal + 1), int, cpu'''

        loss_inp = {}
        loss_inp['semantic_scores'] = (semantic_scores, labels)
        '''loss_inp['pt_offsets'] = (pt_offsets, coords_float, instance_info, instance_labels)
        if(epoch > cfg.prepare_epochs):
            loss_inp['proposal_scores'] = (scores, proposals_idx, proposals_offset, instance_pointnum)'''

        loss, loss_out, infos = loss_fn(loss_inp, epoch)

        ##### accuracy / visual_dict / meter_dict
        with torch.no_grad():
            preds = {}
            preds['semantic'] = semantic_scores
            '''preds['pt_offsets'] = pt_offsets
            if(epoch > cfg.prepare_epochs):
                preds['score'] = scores
                preds['proposals'] = (proposals_idx, proposals_offset)'''

            visual_dict = {}
            visual_dict['loss'] = loss
            for k, v in loss_out.items():
                visual_dict[k] = v[0]

            meter_dict = {}
            meter_dict['loss'] = (loss.item(), coords.shape[0])
            for k, v in loss_out.items():
                meter_dict[k] = (float(v[0]), v[1])
        #print (meter_dict.keys())
        return loss, preds, visual_dict, meter_dict
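
The (value, count) pairs stored in meter_dict are typically consumed by a count-weighted running average. A minimal sketch of such a meter (the AverageMeter name is illustrative, not taken from this codebase):

    class AverageMeter:
        """Count-weighted running average for (value, count) pairs."""

        def __init__(self):
            self.sum, self.count = 0.0, 0

        def update(self, value, n=1):
            self.sum += value * n
            self.count += n

        @property
        def avg(self):
            return self.sum / max(self.count, 1)
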
Example #5
    def model_fn(batch, model, epoch):
        #print ('model fn')
        ##### prepare input and forward
        # batch {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
        # 'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
        # 'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
        # 'id': tbl, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
        coords = batch['locs'].cuda()  # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
        voxel_coords = batch['voxel_locs'].cuda()  # (M, 1 + 3), long, cuda
        p2v_map = batch['p2v_map'].cuda()  # (N), int, cuda
        v2p_map = batch['v2p_map'].cuda()  # (M, 1 + maxActive), int, cuda

        coords_float = batch['locs_float'].cuda()  # (N, 3), float32, cuda
        feats = batch['feats'].cuda()  # (N, C), float32, cuda
        labels = batch['labels'].cuda()  # (N), long, cuda

        groups = batch['groups']
        group2points = batch['group2points']

        group_fulls = batch['group_fulls']
        group2point_fulls = batch['group2point_fulls']

        #for i in range(4):
        #    print (len(groups[i][0]))

        classes = []
        poss = []
        negs = []
        for i in range(20):
            classes.append([])

        for g in range(len(groups)):
            group = groups[g]
            for i in range(20):
                for s in range(len(group[i])):
                    classes[i].append((i, g, group[i][s]))

        ignore = []
        mini = 10  #min(min(map(len, classes)),30)
        for i in range(20):
            random.shuffle(classes[i])
            if len(classes[i]) == 0:
                ignore.append(i)
                continue
            if len(classes[i]) >= mini:
                classes[i] = classes[i][:mini]
            else:
                while len(classes[i]) < mini:
                    classes[i].append(random.choice(classes[i]))

        batch_offsets = batch['offsets'].cuda()  # (B + 1), int, cuda

        spatial_shape = batch['spatial_shape']

        if cfg.use_coords:
            feats = torch.cat((feats, coords_float), 1)
        voxel_feats = pointgroup_ops.voxelization(
            feats, v2p_map, cfg.mode)  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(),
                                         spatial_shape, cfg.batch_size)

        ret, output_feats = model(input_, p2v_map, coords_float,
                                  coords[:, 0].int(), batch_offsets, epoch)

        tmpfeat = torch.Tensor(20, mini, model.m).cuda()

        label = torch.zeros(20 * mini).long().cuda()

        for i in range(20):
            if i in ignore:
                tmpfeat[i, :, :] = 0
                label[i * mini:i * mini + mini] = -100
                continue
            for j in range(mini):
                sample = classes[i][j]
                c0 = sample[0]
                b0 = sample[1]
                idx0 = sample[2]
                idx_off = torch.tensor(np.asarray(
                    group2points[b0][idx0])) + batch['offsets'][b0]
                feat = output_feats[idx_off]
                feat = torch.mean(feat, 0)
                tmpfeat[i, j, :] = feat

                label[i * mini + j] = i

        if model.start == 1:
            model.start = 0

            model.feat = torch.mean(tmpfeat.detach(), 1)
        else:
            model.feat = 0.9 * model.feat + 0.1 * torch.mean(
                tmpfeat.detach(), 1)

        #model.feat=nn.functional.normalize(model.feat,1)
        #tmpfeat=nn.functional.normalize(tmpfeat,2)

        tmpfeat = torch.reshape(tmpfeat, (20 * mini, model.m))

        product = torch.matmul(tmpfeat, torch.transpose(model.feat, 0, 1)) / 0.07

        semantic_scores_pred = ret['semantic_scores']  # (N, nClass) float32, cuda

        semantic_scores_crfs = torch.zeros(
            (semantic_scores_pred.shape[0], 20)).cuda()
        result_crfs = torch.zeros((semantic_scores_pred.shape[0], )).cuda()
        #print ('semantic_scores_crfs',semantic_scores_crfs.shape)
        for batch_idx in range(len(groups)):
            group = groups[batch_idx]
            start = batch['offsets'][batch_idx]
            end = batch['offsets'][batch_idx + 1]
            #CRF
            Q = torch.nn.functional.softmax(semantic_scores_pred[start:end, :], 1)  # torch.exp(semantic_scores_pred)

            #print (start,end,Q.shape)
            cnt = 0
            group_full = group_fulls[batch_idx]
            pair = torch.zeros((len(group_full), 6 + 32)).cuda()
            U = torch.zeros((len(group_full), 20)).cuda()

            for i in group_full:
                idxs = torch.from_numpy(
                    np.asarray(group2point_fulls[batch_idx][i]))  #.cuda()
                pair[cnt, :6] = torch.mean(feats[idxs], 0) * 1
                #pair[cnt,3:6]=torch.mean(coords_float[idxs],0)*0
                pair[cnt, 6:38] = torch.mean(output_feats[idxs], 0) * 0

                #print (Q[idxs],Q[idxs].shape)
                U[cnt, :] = torch.mean(Q[idxs], 0)
                #print (torch.max(idxs))
                assert (torch.max(idxs) < Q.shape[0])
                #logits[cnt,]=torch.mean(semantic_scores_pred[idxs],0)
                cnt += 1
            #U=U/torch.abs(torch.sum(U,1))
            #print ('U', U,U.shape,torch.unique(U))
            #Usum=torch.unsqueeze(torch.abs(torch.sum(U,1)),1).repeat(1,20)
            #print (torch.abs(torch.sum(U,1)))
            #U=U/Usum
            #print (U,U.shape,torch.unique(U))
            #assert
            Q_ori = torch.log((U + 1e-5))

            #print ('Q_ori',name,torch.unique(Q_ori))

            pair1 = torch.unsqueeze(pair, 1).repeat(1, pair.shape[0], 1)
            pair2 = torch.transpose(pair1, 0, 1)

            #print (U1.shape,U2.shape)
            diff = torch.sum((pair1 - pair2).pow(2), 2)

            #print ('torch.exp(-0.5*diff)',torch.exp(-0.5*diff),torch.unique(torch.exp(-0.5*diff)),torch.exp(-0.5*diff).shape)

            k = torch.exp(-0.5 * diff)  #w*exp(-0.5 * |f_i - f_j|^2)

            Q = U

            for it in range(1):
                #print ('iter',it)
                assert k.shape[0] > 0
                kQ = torch.mm(k, Q) / k.shape[0] * 1
                #kQ=torch.mm(kQ,miu)
                #Qsum=torch.expand_dims(torch.sum(kQ,1),1,20)-kQ
                #print ('Q',torch.unique(Q))
                #print ('kQ',kQ,torch.unique(kQ),kQ.shape)
                #print ('expkQ',torch.exp(kQ))
                #kQ[:]=0
                #print ('Q_ori',Q_ori,torch.unique(Q_ori),Q_ori.shape)
                newQ = Q_ori + kQ  #target is: Q_ori*torch.exp(kQ)

                #newQ=torch.nn.functional.softmax(newQ,1)

                #print ('nq', torch.unique(newQ))

                Q = newQ  #.clone()
                #print ('newQ',newQ,torch.unique(newQ),newQ.shape,torch.unique(torch.sum(newQ,1)))

            #Q_result=torch.argmax(Q,1)

            semantic_scores_crf = torch.zeros((end - start, 20)).cuda()
            semantic_scores_crf[:] = semantic_scores_pred[start:end, :]  # .clone()
            #result_crf=torch.zeros((semantic_scores_pred.shape[0],))

            #print (group_full, len(group_full))
            #print (Q.shape)

            for i in range(len(group_full)):
                #print (i)
                g = group_full[i]
                #print (g)
                #print (group2point_fulls[batch_idx])
                #print (group2point_fulls[batch_idx][g])
                #print (np.asarray(group2point_fulls[batch_idx][g]))
                idxs = torch.from_numpy(
                    np.asarray(group2point_fulls[batch_idx][g])).cuda()
                #print (idxs)
                #tmp=torch.unsqueeze(Q[i],0)
                #print ('Q',tmp,torch.unique(tmp),tmp.shape)
                #print (i,newQ[i])
                #for j in idxs:
                semantic_scores_crf[idxs, :] = newQ[i]

                #result_crf[idxs]=torch.argmax(Q[i],1)
            #print (semantic_scores_crf.shape, torch.unique(semantic_scores_crf))
            semantic_scores_crfs[start:end, :] = semantic_scores_crf
            #print (start,end)

        #print ('semantic_scros_crf',semantic_scores_crfs.shape,torch.unique(semantic_scores_crfs))

        loss_inp = {}
        loss_inp['semantic_scores'] = (product, label)

        loss_inp['semantic_scores_crf'] = (semantic_scores_crfs, labels)

        loss_inp['semantic_scores_pred'] = (semantic_scores_pred, labels)

        #print (semantic_scores_crfs, labels, semantic_scores_pred, labels)

        loss, loss_relation, loss_pred, loss_crf, loss_out, infos = loss_fn(
            loss_inp, epoch)

        ##### accuracy / visual_dict / meter_dict
        with torch.no_grad():

            preds = {}
            #preds['semantic'] = semantic_scores

            visual_dict = {}
            visual_dict['loss'] = loss
            for k, v in loss_out.items():
                visual_dict[k] = v[0]

            meter_dict = {}
            meter_dict['loss'] = (loss.item(), coords.shape[0])
            for k, v in loss_out.items():
                meter_dict[k] = (float(v[0]), v[1])

        return loss, loss_relation, loss_pred, loss_crf, preds, visual_dict, meter_dict
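
Stripped of the debugging prints, the per-scene refinement above is a single mean-field-style update: a unary term from the log of the grouped softmax scores plus a Gaussian-kernel message over per-group pair features. A condensed sketch under the same shapes as the loop above:

    import torch

    def crf_step(U, pair_feats):
        # U: (G, nClass) per-group softmax scores; pair_feats: (G, D)
        log_unary = torch.log(U + 1e-5)                  # Q_ori above
        diff = torch.cdist(pair_feats, pair_feats) ** 2  # squared feature distances
        k = torch.exp(-0.5 * diff)                       # Gaussian pairwise kernel
        message = k @ U / k.shape[0]                     # kQ above (one iteration, Q = U)
        return log_unary + message                       # newQ, left unnormalized
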
Example #6
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs,
                batch_offsets, epoch):
        '''
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda
        '''
        ret = {}

        batch_idxs = batch_idxs.squeeze()

        semantic_scores = []
        point_offset_preds = []
        voxel_occupancy_preds = []

        voxel_feats = pointgroup_ops.voxelization(
            input['pt_feats'], input['v2p_map'],
            input['mode'])  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, input['voxel_coords'],
                                         input['spatial_shape'],
                                         input['batch_size'])
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        output_feats = output_feats.squeeze(dim=0)

        ### point prediction
        #### point semantic label prediction
        semantic_scores.append(
            self.point_semantic(output_feats))  # (N, nClass), float

        ### only used to evaluate based on ground truth
        # semantic_scores.append(input['point_semantic_scores'][0])  # (N, nClass), float
        ### ground truth for each category
        # CATE_NUM = 0
        # semantic_output = self.point_semantic(output_feats)
        # if (input['point_semantic_scores'][0].max(dim=1)[1] == CATE_NUM).sum() > 0:
        #     semantic_output[input['point_semantic_scores'][0].max(dim=1)[1] == CATE_NUM] = \
        #     input['point_semantic_scores'][0][input['point_semantic_scores'][0].max(dim=1)[1] == CATE_NUM].float()
        # semantic_output[semantic_output.max(dim=1)[1] == CATE_NUM] = \
        # input['point_semantic_scores'][0][semantic_output.max(dim=1)[1] == CATE_NUM].float()
        # semantic_scores.append(semantic_output)

        point_semantic_preds = semantic_scores[0].max(1)[1]

        #### point offset prediction
        point_offset_pred = self.point_offset(output_feats)
        point_offset_preds.append(point_offset_pred)  # (N, 3), float32
        # only used to evaluate based on ground truth
        # point_offset_preds.append(input['point_offset_preds'])  # (N, 3), float32

        voxel_occupancy_preds.append(self.point_occupancy(output.features))
        point_occupancy_pred = voxel_occupancy_preds[0][
            input_map.long()].squeeze(dim=1)

        if (epoch > self.prepare_epochs):
            #### get proposal clusters
            object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)

            batch_idxs_ = batch_idxs[object_idxs]
            batch_offsets_ = utils.get_batch_offsets(batch_idxs_,
                                                     input['batch_size'])
            coords_ = coords[object_idxs]
            pt_offsets_ = point_offset_preds[0][object_idxs]
            point_occupancy_pred_ = point_occupancy_pred[object_idxs]

            semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

            # idx_occupancy, start_len_occupancy = pointgroup_ops.ballquery_batch_p(
            #     coords_ + pt_offsets_, batch_idxs_,
            #     batch_offsets_, self.cluster_radius, self.cluster_shift_meanActive
            # )
            # proposals_idx_occupancy, proposals_offset_occupancy = pointgroup_ops.bfs_occupancy_cluster(
            #     semantic_preds_cpu, point_occupancy_pred_.cpu(), idx_occupancy.cpu(),
            #     start_len_occupancy.cpu(), self.cluster_npoint_thre, self.occupancy_cluster['occupancy_threshold_shift']
            # )
            # proposals_idx_occupancy[:, 1] = object_idxs[proposals_idx_occupancy[:, 1].long()].int()
            # # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # # proposals_offset_shift: (nProposal + 1), int
            #
            # idx, start_len = pointgroup_ops.ballquery_batch_p(
            #     coords_, batch_idxs_, batch_offsets_, self.cluster_radius, self.cluster_meanActive
            # )
            # proposals_idx, proposals_offset = pointgroup_ops.bfs_occupancy_cluster(
            #     semantic_preds_cpu, point_occupancy_pred_.cpu(), idx.cpu(),
            #     start_len.cpu(), self.cluster_npoint_thre, self.occupancy_cluster['occupancy_threshold']
            # )
            # proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()

            idx, start_len = pointgroup_ops.ballquery_batch_p(
                coords_, batch_idxs_, batch_offsets_, self.cluster_radius,
                self.cluster_meanActive)
            proposals_idx, proposals_offset = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx.cpu(), start_len.cpu(),
                self.cluster_npoint_thre)
            proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()

            idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(
                coords_ + pt_offsets_, batch_idxs_, batch_offsets_,
                self.cluster_radius, self.cluster_shift_meanActive)
            proposals_idx_shift, proposals_offset_shift = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx_shift.cpu(), start_len_shift.cpu(),
                self.cluster_npoint_thre)
            proposals_idx_shift[:, 1] = object_idxs[
                proposals_idx_shift[:, 1].long()].int()
            # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset_shift: (nProposal + 1), int

            # proposals_idx_occupancy[:, 0] += (proposals_offset.size(0) - 1)
            # proposals_offset_occupancy += proposals_offset[-1]
            # proposals_idx = torch.cat((proposals_idx, proposals_idx_occupancy), dim=0)
            # proposals_offset = torch.cat((proposals_offset, proposals_offset_occupancy[1:]))

            # proposals_idx_filtered = []
            # proposals_offset_filtered = [0]
            # for proposal in proposals_idx_shift[:, 0].unique():
            #     proposal_index = proposals_idx_shift[:, 0] == proposal
            #     proposal_idxs = proposals_idx_shift[proposal_index, 1].long()
            #     proposals_indexs = proposals_idx_shift[proposal_index, :]
            #
            #     proposal_occupancy_mean = point_occupancy_pred[proposal_idxs].mean()
            #
            #     valid_point_index = torch.ones_like(proposal_idxs).byte()
            #     valid_point_index[point_occupancy_pred[proposal_idxs] < proposal_occupancy_mean *
            #                       (1 - self.occupancy_cluster['occupancy_filter_threshold'])] = 0
            #     valid_point_index[point_occupancy_pred[proposal_idxs] > proposal_occupancy_mean *
            #                       (1 + self.occupancy_cluster['occupancy_filter_threshold'])] = 0
            #     proposal_idx_filtered = proposals_indexs[valid_point_index, :]
            #
            #     proposals_idx_filtered.append(proposal_idx_filtered)
            #     proposals_offset_filtered.append(proposals_offset_filtered[-1] + proposal_idx_filtered.shape[0])
            #
            # proposals_idx_filtered = torch.cat(proposals_idx_filtered, dim=0)
            # proposals_offset_filtered = torch.tensor(proposals_offset_filtered).int()
            # proposals_idx_shift = proposals_idx_filtered
            # proposals_offset_shift = proposals_offset_filtered

            proposals_idx_shift[:, 0] += (proposals_offset.size(0) - 1)
            proposals_offset_shift += proposals_offset[-1]
            proposals_idx = torch.cat((proposals_idx, proposals_idx_shift),
                                      dim=0)
            proposals_offset = torch.cat(
                (proposals_offset, proposals_offset_shift[1:]))

            #### proposals voxelization again
            input_feats, inp_map = self.clusters_voxelization(
                proposals_idx, proposals_offset, output_feats, coords,
                self.score_fullscale, self.score_scale, self.mode)

            #### score
            score = self.score_unet(input_feats)
            score = self.score_outputlayer(score)
            score_feats = score.features[inp_map.long()]  # (sumNPoint, C)
            score_feats = pointgroup_ops.roipool(
                score_feats, proposals_offset.cuda())  # (nProposal, C)
            scores = self.score_linear(score_feats)  # (nProposal, 1)

            # proposals_idx = proposals_idx_occupancy
            # proposals_offset = proposals_offset_occupancy
            # scores = torch.ones(proposals_offset.shape[0] - 1, 1).to(point_offset_preds[0].device)

            ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

        ret['point_semantic_scores'] = semantic_scores
        ret['point_offset_preds'] = point_offset_preds
        ret['point_features'] = output_feats
        ret['voxel_occupancy_preds'] = voxel_occupancy_preds

        return ret
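
The cluster-id shift and offset concatenation used above to merge the two proposal sets can be checked on toy tensors; a self-contained illustration:

    import torch

    idx_a = torch.tensor([[0, 10], [0, 11], [1, 12]])  # (cluster_id, point_id)
    off_a = torch.tensor([0, 2, 3])                    # two proposals
    idx_b = torch.tensor([[0, 20], [0, 21]])
    off_b = torch.tensor([0, 2])                       # one proposal

    idx_b[:, 0] += off_a.size(0) - 1                   # shift ids past set A's proposals
    off_b = off_b + off_a[-1]                          # shift by set A's point total
    merged_idx = torch.cat((idx_a, idx_b), dim=0)
    merged_off = torch.cat((off_a, off_b[1:]))
    assert merged_off.tolist() == [0, 2, 3, 5]         # still a valid CSR-style layout
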
Example #7
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs,
                batch_offsets, epoch):
        '''
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda
        '''
        ret = {}

        batch_idxs = batch_idxs.squeeze()

        semantic_scores = []
        point_offset_preds = []

        voxel_feats = pointgroup_ops.voxelization(
            input['pt_feats'], input['v2p_map'],
            input['mode'])  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, input['voxel_coords'],
                                         input['spatial_shape'],
                                         input['batch_size'])
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        output_feats = output_feats.squeeze(dim=0)

        ### point prediction
        #### point semantic label prediction
        semantic_scores.append(
            self.point_semantic(output_feats))  # (N, nClass), float

        ### only used to evaluate based on ground truth
        # semantic_scores.append(input['point_semantic_scores'][0])  # (N, nClass), float
        ### ground truth for each category
        # CATE_NUM = 0
        # semantic_output = self.point_semantic(output_feats)
        # if (input['point_semantic_scores'][0].max(dim=1)[1] == CATE_NUM).sum() > 0:
        #     semantic_output[input['point_semantic_scores'][0].max(dim=1)[1] == CATE_NUM] = \
        #     input['point_semantic_scores'][0][input['point_semantic_scores'][0].max(dim=1)[1] == CATE_NUM].float()
        # semantic_output[semantic_output.max(dim=1)[1] == CATE_NUM] = \
        # input['point_semantic_scores'][0][semantic_output.max(dim=1)[1] == CATE_NUM].float()
        # semantic_scores.append(semantic_output)

        point_semantic_preds = semantic_scores[0].max(1)[1]

        #### point offset prediction
        point_offset_pred = self.point_offset(output_feats)
        if self.instance_triplet_loss['activate']:
            point_offset_pred = point_offset_pred - input['pt_feats'][:, 3:]
        point_offset_preds.append(point_offset_pred)  # (N, 3), float32
        # only used to evaluate based on ground truth
        # point_offset_preds.append(input['point_offset_preds'])  # (N, 3), float32

        if self.voxel_center_prediction['activate']:
            voxel_center_preds = self.voxel_center_pred(output.features)
            voxel_center_offset_preds = self.voxel_center_offset(
                output.features)
            voxel_center_semantic_preds = self.voxel_center_semantic(
                output.features)

            ret['voxel_center_preds'] = voxel_center_preds
            ret['voxel_center_offset_preds'] = voxel_center_offset_preds
            ret['voxel_center_semantic_preds'] = voxel_center_semantic_preds

        if self.point_xyz_reconstruction_loss['activate']:
            point_reconstructed_coords = self.point_reconstruction_coords(
                output_feats)

            ret['point_reconstructed_coords'] = point_reconstructed_coords

        if self.instance_classifier['activate']:
            instance_id_preds = self.point_instance_classifier(output_feats)

            ret['instance_id_preds'] = instance_id_preds

        if (epoch > self.prepare_epochs):
            #### get proposal clusters
            object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)

            batch_idxs_ = batch_idxs[object_idxs]
            batch_offsets_ = utils.get_batch_offsets(batch_idxs_,
                                                     input['batch_size'])
            coords_ = coords[object_idxs]
            pt_offsets_ = point_offset_preds[0][object_idxs]

            semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

            idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(
                coords_ + pt_offsets_, batch_idxs_, batch_offsets_,
                self.cluster_radius, self.cluster_shift_meanActive)
            # idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(
            #     coords_ + pt_offsets_ + (torch.rand(coords_.shape) * 1e-2).cuda(), batch_idxs_,
            #     batch_offsets_, 0.001, self.cluster_shift_meanActive
            # )
            proposals_idx_shift, proposals_offset_shift = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx_shift.cpu(), start_len_shift.cpu(),
                self.cluster_npoint_thre)

            proposals_idx_shift[:, 1] = object_idxs[
                proposals_idx_shift[:, 1].long()].int()
            # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset_shift: (nProposal + 1), int

            idx, start_len = pointgroup_ops.ballquery_batch_p(
                coords_, batch_idxs_, batch_offsets_, self.cluster_radius,
                self.cluster_meanActive)
            proposals_idx, proposals_offset = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx.cpu(), start_len.cpu(),
                self.cluster_npoint_thre)
            proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()
            # proposals_idx: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset: (nProposal + 1), int

            proposals_idx_shift[:, 0] += (proposals_offset.size(0) - 1)
            proposals_offset_shift += proposals_offset[-1]
            proposals_idx = torch.cat((proposals_idx, proposals_idx_shift),
                                      dim=0)
            proposals_offset = torch.cat(
                (proposals_offset, proposals_offset_shift[1:]))

            #### proposals voxelization again
            input_feats, inp_map = self.clusters_voxelization(
                proposals_idx, proposals_offset, output_feats, coords,
                self.score_fullscale, self.score_scale, self.mode)

            #### score
            score = self.score_unet(input_feats)
            score = self.score_outputlayer(score)
            score_feats = score.features[inp_map.long()]  # (sumNPoint, C)
            score_feats = pointgroup_ops.roipool(
                score_feats, proposals_offset.cuda())  # (nProposal, C)
            scores = self.score_linear(score_feats)  # (nProposal, 1)

            ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

        ret['point_semantic_scores'] = semantic_scores
        ret['point_offset_preds'] = point_offset_preds
        if self.instance_triplet_loss['activate']:
            ret['point_offset_feats'] = output_feats
        ret['point_features'] = output_feats

        return ret
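
utils.get_batch_offsets is not shown in these examples; given the (B + 1) offsets convention used throughout, it plausibly reduces to a bincount plus cumulative sum. A sketch under that assumption, not the project's implementation:

    import torch

    def get_batch_offsets(batch_idxs, batch_size):
        # batch_idxs: (N,) batch index per point -> (B + 1,) cumulative offsets
        counts = torch.bincount(batch_idxs.long(), minlength=batch_size)
        offsets = torch.zeros(batch_size + 1, dtype=torch.int32, device=batch_idxs.device)
        offsets[1:] = torch.cumsum(counts, dim=0)
        return offsets  # offsets[b]..offsets[b+1] spans batch item b
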
Example #8
    def model_fn(batch, model, epoch):
        #print ('model fn')
        ##### prepare input and forward
        # batch {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
        # 'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
        # 'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
        # 'id': tbl, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
        coords = batch['locs'].cuda()  # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
        voxel_coords = batch['voxel_locs'].cuda()  # (M, 1 + 3), long, cuda
        p2v_map = batch['p2v_map'].cuda()  # (N), int, cuda
        v2p_map = batch['v2p_map'].cuda()  # (M, 1 + maxActive), int, cuda

        coords_float = batch['locs_float'].cuda()  # (N, 3), float32, cuda
        feats = batch['feats'].cuda()  # (N, C), float32, cuda
        labels = batch['labels'].cuda()  # (N), long, cuda

        groups = batch['groups']
        group2points = batch['group2points']

        #for i in range(4):
        #    print (len(groups[i][0]))

        classes = []
        poss = []
        negs = []
        for i in range(20):
            classes.append([])

        for g in range(len(groups)):
            group = groups[g]
            for i in range(20):
                for s in range(len(group[i])):
                    classes[i].append((i, g, group[i][s]))

        ignore = []
        mini = 10  #min(min(map(len, classes)),30)
        for i in range(20):
            random.shuffle(classes[i])
            if len(classes[i]) == 0:
                ignore.append(i)
                continue
            if len(classes[i]) >= mini:
                classes[i] = classes[i][:mini]
            else:
                while len(classes[i]) < mini:
                    classes[i].append(random.choice(classes[i]))
        '''for i in range(20):
            if len(classes[i])==0:
                continue
            for times in range(10):
                poss.append((random.choice(classes[i]),random.choice(classes[i])))
            for j in range(20):
                if j==i or len(classes[j])==0:
                  continue 
                negs.append((random.choice(classes[i]),random.choice(classes[j])))
            while(len(negs)<20):
                j=random.randint(0,19)
                if j==i or len(classes[j])==0:
                  continue 
                negs.append((random.choice(classes[i]),random.choice(classes[j])))
        
        
        posidxs=[]
        for pos in poss:
            c0 = pos[0][0]
            b0 = pos[0][1]
            idx0 = pos[0][2]
            #print ('fffff',group2points[b0].keys())
            idx_off0 = torch.tensor(np.asarray(group2points[b0][idx0])) + batch['offsets'][b0]
            #print ('sssssss',group2points[b0])
            c1 = pos[1][0]
            b1 = pos[1][1]
            idx1 = pos[1][2]
            idx_off1 = torch.tensor(np.asarray(group2points[b1][idx1])) + batch['offsets'][b1]
            posidxs.append((idx_off0,idx_off1))

        negidxs=[]
        for neg in negs:
            c0 = neg[0][0]
            b0 = neg[0][1]
            idx0 = neg[0][2]
            idx_off0 = torch.tensor(np.asarray(group2points[b0][idx0])) + batch['offsets'][b0]
            #print (idx_off0)
            c1 = neg[1][0]
            b1 = neg[1][1]
            idx1 = neg[1][2]
            idx_off1 = torch.tensor(np.asarray(group2points[b1][idx1])) + batch['offsets'][b1]
            negidxs.append((idx_off0,idx_off1))'''

        batch_offsets = batch['offsets'].cuda()  # (B + 1), int, cuda

        spatial_shape = batch['spatial_shape']

        if cfg.use_coords:
            feats = torch.cat((feats, coords_float), 1)
        voxel_feats = pointgroup_ops.voxelization(
            feats, v2p_map, cfg.mode)  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(),
                                         spatial_shape, cfg.batch_size)

        ret, output_feats = model(input_, p2v_map, coords_float,
                                  coords[:, 0].int(), batch_offsets, epoch)

        #print (torch.unique(voxel_feats),output_feats)

        tmpfeat = torch.Tensor(20, mini, model.m).cuda()

        label = torch.zeros(20 * mini).long().cuda()

        for i in range(20):
            if i in ignore:
                tmpfeat[i, :, :] = 0
                label[i * mini:i * mini + mini] = -100
                continue
            for j in range(mini):
                sample = classes[i][j]
                c0 = sample[0]
                b0 = sample[1]
                idx0 = sample[2]
                #print ('fffff',group2points[b0].keys())
                idx_off = torch.tensor(np.asarray(
                    group2points[b0][idx0])) + batch['offsets'][b0]
                feat = output_feats[idx_off]
                feat = torch.mean(feat, 0)
                tmpfeat[i, j, :] = feat

                label[i * mini + j] = i

        if model.start == 1:
            model.start = 0

            model.feat = torch.mean(tmpfeat.detach(), 1)
        else:
            model.feat = 0.9 * model.feat + 0.1 * torch.mean(
                tmpfeat.detach(), 1)

        #model.feat=nn.functional.normalize(model.feat,1)
        #tmpfeat=nn.functional.normalize(tmpfeat,2)

        tmpfeat = torch.reshape(tmpfeat, (20 * mini, model.m))

        product = torch.matmul(tmpfeat, torch.transpose(model.feat, 0, 1)) / 0.07

        loss_inp = {}
        loss_inp['semantic_scores'] = (product, label)
        '''posfeat=torch.zeros((10,128*2))
        for j in range(len(posidxs)):
          posidx=posidxs[j]
          feat0=output_feats[posidx[0]]      
          feat0=torch.mean(feat0,0)
          feat1=output_feats[posidx[1]]      
          feat1=torch.mean(feat0,0)
          feat=torch.cat((feat0,feat1),1)
          posfeat[j][:128]=feat0
          posfeat[j][128:]=feat1


        negfeat=torch.zeros((20,128*2))
        for j in range(len(negidxs)):
          negidx=negidxs[j]
          feat0=output_feats[negidx[0]]      
          feat0=torch.mean(feat0,0)
          feat1=output_feats[negidx[1]]      
          feat1=torch.mean(feat0,0)
          feat=torch.cat((feat0,feat1),1)
          negfeat[j][:128]=feat0
          negfeat[j][128:]=feat1'''
        '''semantic_scores = ret['semantic_scores'] # (N, nClass) float32, cuda


        print ('s',semantic_scores.shape)
        #print ('groups',groups)


        loss_inp = {}
        loss_inp['semantic_scores'] = (semantic_scores, labels)
        print ('l',labels.shape,'\n')'''

        loss, loss_out, infos = loss_fn(loss_inp, epoch)

        ##### accuracy / visual_dict / meter_dict
        with torch.no_grad():

            preds = {}
            #preds['semantic'] = semantic_scores

            visual_dict = {}
            visual_dict['loss'] = loss
            for k, v in loss_out.items():
                visual_dict[k] = v[0]

            meter_dict = {}
            meter_dict['loss'] = (loss.item(), coords.shape[0])
            for k, v in loss_out.items():
                meter_dict[k] = (float(v[0]), v[1])

        return loss, preds, visual_dict, meter_dict
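
The (product, label) pair handed to loss_fn above compares group features against the EMA class prototypes at temperature 0.07, with padded classes marked -100; the natural objective is cross-entropy that ignores the padding. A minimal sketch of that loss, assuming loss_fn does something equivalent:

    import torch.nn.functional as F

    def prototype_loss(group_feats, prototypes, labels, temperature=0.07):
        # group_feats: (20 * mini, m); prototypes: (20, m); labels: (20 * mini,)
        logits = group_feats @ prototypes.t() / temperature  # `product` above
        return F.cross_entropy(logits, labels, ignore_index=-100)
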
Example #9
    def test_model_fn(batch, model, epoch):
        #print ('test model fn')

        coords = batch['locs'].cuda()  # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
        voxel_coords = batch['voxel_locs'].cuda()  # (M, 1 + 3), long, cuda
        p2v_map = batch['p2v_map'].cuda()  # (N), int, cuda
        v2p_map = batch['v2p_map'].cuda()  # (M, 1 + maxActive), int, cuda

        coords_float = batch['locs_float'].cuda()  # (N, 3), float32, cuda
        feats = batch['feats'].cuda()  # (N, C), float32, cuda

        batch_offsets = batch['offsets'].cuda()  # (B + 1), int, cuda

        spatial_shape = batch['spatial_shape']
        labels = batch['labels']
        groups = batch['groups']
        group2points = batch['group2points']

        if cfg.use_coords:
            feats = torch.cat((feats, coords_float), 1)
        voxel_feats = pointgroup_ops.voxelization(
            feats, v2p_map, cfg.mode)  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(),
                                         spatial_shape, cfg.batch_size)

        ret, output_feats = model(input_, p2v_map, coords_float,
                                  coords[:, 0].int(), batch_offsets, epoch)
        products = torch.zeros((output_feats.shape[0], 20))
        result = torch.zeros((output_feats.shape[0]))
        result_feat = torch.zeros((output_feats.shape[0], 32))
        #for group in groups:

        #print ('len',len(groups),groups)

        feat_voxel = []
        gt_voxel = []
        rgb_voxel = []

        group = groups[0]
        for i in range(20):
            for s in range(len(group[i])):

                idxs = group2points[0][group[i][s]]
                feats = output_feats[idxs]
                #rgbs=feats[idxs]
                feat = torch.mean(feats, 0)
                #rgb=torch.mean(rgbs,0)
                #feat_voxel.append(feat.detach().cpu().numpy())
                #gt_voxel.append(i)
                #rgb_voxel.append(rgb.detach().cpu().numpy())

                #stdfeat=torch.Tensor(torch.load('exp/scannetv2/pointgroup/pointgroup_run1_scannet/pointgroup_run1_scannet-000000074_feat.pth')).cuda()

                product = torch.matmul(feat, torch.transpose(model.feat, 0, 1))
                #print (i,int(torch.argmax(product)),product)
                #result[idxs]=int(torch.argmax(product))
                for idx in idxs:
                    result[idx] = int(torch.argmax(product))
                    result_feat[idx, :] = feat
                    products[idx, :] = product
                #print (int(torch.argmax(product)))
                #print ('idxs',len(idxs),idxs)

        #print ('unique',torch.unique(result))
        #semantic_scores=torch.argmax(product,1)
        '''feats=np.zeros((20,16))
        labels=np.zeros((20))
        for i in range(feat.shape[0]):
          if label[i]==-100:
            continue
          feats[label[i],:]+=feat[i,:].cpu().numpy()
          labels[label[i]]+=1
        for i in range(20):
          feats[i,:]/=labels[i]'''

        #print (feat.shape,label.shape)

        #semantic_scores = ret['semantic_scores']  # (N, nClass) float32, cuda
        '''pt_offsets = ret['pt_offsets']            # (N, 3), float32, cuda
        if (epoch > cfg.prepare_epochs):
            scores, proposals_idx, proposals_offset = ret['proposal_scores']'''
        '''product=torch.matmul(feat,torch.transpose(model.feat,0,1))
        
        
        print ('product',product.shape)
        
        result=torch.argmax(product,1)
        
        print ('result',result.shape)'''

        #loss_inp = {}
        #loss_inp['semantic_scores'] = (product, label)

        ##### preds
        with torch.no_grad():
            preds = {}
            preds['semantic'] = result
            preds['products'] = products
            #print ('feat')
            preds['feat'] = result_feat  #np.asarray(feat_voxel)
            #preds['gt'] = np.asarray(gt_voxel)
            #preds['rgb'] = np.asarray(rgb_voxel)
            '''preds['pt_offsets'] = pt_offsets
            if (epoch > cfg.prepare_epochs):
                preds['score'] = scores
                preds['proposals'] = (proposals_idx, proposals_offset)'''

        return preds
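
Condensed, the inference rule above matches each supervoxel's mean feature against the class prototypes and broadcasts the winning class to the group's points. A sketch with group2point as {group_id: [point indices]} and all tensors on the same device:

    import torch

    def classify_groups(output_feats, prototypes, group2point):
        result = torch.zeros(output_feats.shape[0], dtype=torch.long)
        for group_id, idxs in group2point.items():
            feat = output_feats[idxs].mean(0)                     # group mean feature
            result[idxs] = int(torch.argmax(feat @ prototypes.t()))
        return result  # (N,) per-point class ids
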
Example #10
    def forward(self, data_dict, epoch):
        '''
        :param data_dict: dict with locs, voxel_locs, p2v_map, v2p_map, locs_float, feats, spatial_shape
        :param epoch: int, current epoch, used to gate the clustering branch
        '''
        #unpack input data
        coords = data_dict['locs']  # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
        voxel_coords = data_dict['voxel_locs']  # (M, 1 + 3), long, cuda
        p2v_map = data_dict['p2v_map']  # (N), int, cuda
        v2p_map = data_dict['v2p_map']  # (M, 1 + maxActive), int, cuda

        coords_float = data_dict['locs_float']  # (N, 3), float32, cuda
        feats = data_dict['feats']  # (N, C), float32, cuda

        spatial_shape = data_dict['spatial_shape']

        batch_idxs = coords[:, 0].int()

        if self.cfg.use_coords:
            feats = torch.cat((coords_float, feats), 1)
        voxel_feats = pointgroup_ops.voxelization(
            feats, v2p_map, self.cfg.mode)  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(),
                                         spatial_shape, self.cfg.batch_size)

        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[p2v_map.long()]  ## F

        #### semantic segmentation
        semantic_scores = self.linear(output_feats)  # (N, nClass), float
        semantic_preds = semantic_scores.max(1)[1]  # (N), long   ---- S

        data_dict['semantic_scores'] = semantic_scores

        #### offset
        pt_offsets_feats = self.offset(output_feats)
        pt_offsets = self.offset_linear(
            pt_offsets_feats)  # (N, 3), float32 --- O

        data_dict['pt_offsets'] = pt_offsets

        if (epoch > self.prepare_epochs):
            #### get proposal clusters
            object_idxs = torch.nonzero(semantic_preds >= 0).view(-1)

            batch_idxs_ = batch_idxs[object_idxs]
            batch_offsets_ = utils.get_batch_offsets(batch_idxs_,
                                                     input_.batch_size)
            coords_ = coords_float[object_idxs]
            pt_offsets_ = pt_offsets[object_idxs]

            semantic_preds_cpu = semantic_preds[object_idxs].int().cpu()

            #Q
            idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(
                coords_ + pt_offsets_, batch_idxs_, batch_offsets_,
                self.cluster_radius, self.cluster_shift_meanActive)
            proposals_idx_shift, proposals_offset_shift = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx_shift.cpu(), start_len_shift.cpu(),
                self.cluster_npoint_thre)
            proposals_idx_shift[:, 1] = object_idxs[
                proposals_idx_shift[:, 1].long()].int()
            # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset_shift: (nProposal + 1), int

            #P
            idx, start_len = pointgroup_ops.ballquery_batch_p(
                coords_, batch_idxs_, batch_offsets_, self.cluster_radius,
                self.cluster_meanActive)
            proposals_idx, proposals_offset = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx.cpu(), start_len.cpu(),
                self.cluster_npoint_thre)
            proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()
            # proposals_idx: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset: (nProposal + 1), int

            proposals_idx_shift[:, 0] += (proposals_offset.size(0) - 1)
            proposals_offset_shift += proposals_offset[-1]
            proposals_idx = torch.cat((proposals_idx, proposals_idx_shift),
                                      dim=0)
            proposals_offset = torch.cat(
                (proposals_offset, proposals_offset_shift[1:]))

            #### proposals voxelization again
            input_feats, inp_map = self.clusters_voxelization(
                proposals_idx, proposals_offset, output_feats, coords_float,
                self.score_fullscale, self.score_scale, self.mode)

            #### score
            score = self.score_unet(input_feats)
            score = self.score_outputlayer(score)
            score_feats = score.features[inp_map.long()]  # (sumNPoint, C)
            score_feats = pointgroup_ops.roipool(
                score_feats, proposals_offset.cuda())  # (nProposal, C)
            scores = self.score_linear(score_feats)  # (nProposal, 1)

            object_feats = self.object_features_linear(score_feats)

            data_dict['proposal_scores'] = (scores, proposals_idx,
                                            proposals_offset)
            data_dict['proposal_feature'] = object_feats

        return data_dict
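
pointgroup_ops.roipool above reduces the (sumNPoint, C) point features to one vector per proposal, sectioned by proposals_offset; in PointGroup this is commonly max pooling. A plain-PyTorch sketch of that behavior:

    import torch

    def roipool_max(feats, offsets):
        # feats: (sumNPoint, C); offsets: (nProposal + 1,)
        pooled = [feats[offsets[k]:offsets[k + 1]].max(0)[0]
                  for k in range(offsets.shape[0] - 1)]
        return torch.stack(pooled, 0)  # (nProposal, C)
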
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs,
                batch_offsets, epoch):
        '''
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda
        '''
        ret = {}

        semantic_scores = []
        point_offset_preds = []

        voxel_feats = pointgroup_ops.voxelization(
            input['pt_feats'], input['v2p_map'],
            input['mode'])  # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(voxel_feats, input['voxel_coords'],
                                         input['spatial_shape'],
                                         input['batch_size'])
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        output_feats = output_feats.squeeze(dim=0)

        ### point prediction
        #### point semantic label prediction
        semantic_scores.append(
            self.point_semantic(output_feats))  # (N, nClass), float
        # point_semantic_preds = semantic_scores
        point_semantic_preds = semantic_scores[0].max(1)[1]
        #### point offset prediction
        point_offset_preds.append(
            self.point_offset(output_feats))  # (N, 3), float32

        point_feats = []
        sampled_indexes = []
        for sample_indx in range(1, len(batch_offsets)):
            coords_input = coords[batch_offsets[sample_indx - 1]:batch_offsets[sample_indx], :].unsqueeze(dim=0)
            rgb_input = rgb[batch_offsets[sample_indx - 1]:batch_offsets[sample_indx], :].unsqueeze(dim=0)
            if not input['test'] and self.center_clustering['use_gt_semantic']:
                point_semantic_preds = input['semantic_labels']
            point_semantic_pred = point_semantic_preds[batch_offsets[sample_indx - 1]:batch_offsets[sample_indx]]
            for semantic_idx in point_semantic_pred.unique():
                if semantic_idx < 2:
                    continue
                semantic_point_idx = (point_semantic_pred == semantic_idx).nonzero().squeeze(dim=1)

                sampled_index = pointnet2_utils.furthest_point_sample(
                    coords_input[:, semantic_point_idx, :3].contiguous(),
                    self.pointnet_max_npoint).squeeze(dim=0).long()
                sampled_index = semantic_point_idx[sampled_index]
                sampled_indexes.append(
                    torch.cat((torch.LongTensor(sampled_index.shape[0],
                                                1).fill_(sample_indx).cuda(),
                               sampled_index.unsqueeze(dim=1)),
                              dim=1))

                sampled_coords_input = coords_input[:, sampled_index, :]
                sampled_rgb_input = rgb_input[:, sampled_index, :]

                point_feat, _ = self.pointnet_encoder(
                    sampled_coords_input,
                    torch.cat((sampled_rgb_input, sampled_coords_input),
                              dim=2).transpose(1, 2).contiguous())
                point_feats.append(point_feat)

        point_feats = torch.cat(point_feats, dim=0)
        sampled_indexes = torch.cat(sampled_indexes, dim=0)

        ### center prediction
        center_preds = self.center_pred(point_feats)
        center_semantic_preds = self.center_semantic(point_feats)
        center_offset_preds = self.center_offset(point_feats)

        ret['point_semantic_scores'] = semantic_scores
        ret['point_offset_preds'] = point_offset_preds

        ret['center_preds'] = (center_preds, sampled_indexes)
        ret['center_semantic_preds'] = (center_semantic_preds, sampled_indexes)
        ret['center_offset_preds'] = (center_offset_preds, sampled_indexes)

        ret['point_features'] = output_feats

        return ret
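
pointnet2_utils.furthest_point_sample above is a CUDA op; for reference, a slow plain-PyTorch sketch of the greedy algorithm it implements:

    import torch

    def furthest_point_sample(xyz, npoint):
        # xyz: (N, 3) -> (npoint,) indices, greedily maximizing spatial coverage
        N = xyz.shape[0]
        selected = torch.zeros(npoint, dtype=torch.long)
        dist = torch.full((N,), float('inf'))
        farthest = 0
        for i in range(npoint):
            selected[i] = farthest
            d = ((xyz - xyz[farthest]) ** 2).sum(1)   # distance to newest sample
            dist = torch.minimum(dist, d)             # nearest-sample distance so far
            farthest = int(torch.argmax(dist))        # next pick: farthest point
        return selected
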
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs,
                batch_offsets, epoch):
        '''
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda
        '''
        ret = {}

        point_semantic_scores = []
        point_offset_preds = []
        local_point_semantic_scores = []
        local_point_offset_preds = []
        '''point feature extraction'''
        # voxelize point cloud, and those voxels are passed to a sparse 3D U-Net
        # the output of sparse 3D U-Net is then remapped back to points
        voxel_feats = pointgroup_ops.voxelization(
            input['pt_feats'], input['v2p_map'],
            input['mode'])  # (M, C), float, cuda
        input_ = spconv.SparseConvTensor(voxel_feats, input['voxel_coords'],
                                         input['spatial_shape'],
                                         input['batch_size'])
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        '''point-wise predictions'''
        point_semantic_scores.append(
            self.point_semantic(output_feats))  # (N, nClass), float
        point_semantic_preds = point_semantic_scores[-1].max(1)[1]

        point_offset_preds.append(
            self.point_offset(output_feats))  # (N, 3), float32

        # cluster proposals with the stable output of the backbone network
        if (epoch > self.prepare_epochs):
            if not input['test'] and self.local_proposal['use_gt_semantic']:
                point_semantic_preds = input['semantic_labels']
            '''clustering algorithm'''
            object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)
            batch_idxs_ = batch_idxs[object_idxs]
            batch_offsets_ = utils.get_batch_offsets(batch_idxs_,
                                                     input['batch_size'])
            coords_ = coords[object_idxs]
            pt_offsets_ = point_offset_preds[-1][object_idxs]

            semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

            # shift points by their predicted offsets so points of the same
            # instance collapse around a common center; ball query plus
            # breadth-first search then grow clusters in that shifted space
            idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(
                coords_ + pt_offsets_, batch_idxs_, batch_offsets_,
                self.cluster_radius, self.cluster_shift_meanActive)
            proposals_idx_shift, proposals_offset_shift = pointgroup_ops.bfs_cluster(
                semantic_preds_cpu, idx_shift.cpu(), start_len_shift.cpu(),
                self.cluster_npoint_thre)
            proposals_idx_shift[:, 1] = object_idxs[
                proposals_idx_shift[:, 1].long()].int()
            # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
            # proposals_offset_shift: (nProposal + 1), int
            '''local proposal refinement'''
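            # idea: for every proposal, gather the proposal itself plus its
            # nearby proposals into one local group, then re-voxelize and
            # re-encode that group to refine the per-point predictions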
            local_proposals_idx = []
            local_proposals_offset = [0]

            # compute proposal centers
            proposal_center_coords = []
            for proposal_index in range(1, len(proposals_offset_shift)):
                proposal_point_index = proposals_idx_shift[
                    proposals_offset_shift[proposal_index - 1]:
                    proposals_offset_shift[proposal_index], 1]
                proposal_center_coords.append(
                    coords[proposal_point_index.long(), :].mean(dim=0,
                                                                keepdim=True))
            proposal_center_coords = torch.cat(proposal_center_coords, dim=0)

            # select the topk closest proposals for each proposal
            proposal_dist_mat = euclidean_dist(proposal_center_coords,
                                               proposal_center_coords)
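            # mask the diagonal with a large distance so a proposal can never
            # pick itself as one of its nearest neighbours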
            proposal_dist_mat[range(len(proposal_dist_mat)),
                              range(len(proposal_dist_mat))] = 100
            closest_proposals_dist, closest_proposals_index = proposal_dist_mat.topk(
                k=min(self.local_proposal['topk'], proposal_dist_mat.shape[1]),
                dim=1,
                largest=False)
            valid_closest_proposals_index = closest_proposals_dist < self.local_proposal[
                'dist_th']

            # select proposals which are the closest and in the distance threshold
            for proposal_index in range(1, len(proposals_offset_shift)):
                local_indexs = []
                local_indexs.append(proposals_idx_shift[proposals_offset_shift[
                    proposal_index -
                    1]:proposals_offset_shift[proposal_index]])
                closest_proposals_ind = closest_proposals_index[
                    proposal_index - 1, :]
                for selected_proposal_index in closest_proposals_ind[
                        valid_closest_proposals_index[proposal_index - 1, :]]:
                    local_index = proposals_idx_shift[
                        proposals_offset_shift[selected_proposal_index]:
                        proposals_offset_shift[selected_proposal_index +
                                               1]][:, 1].unsqueeze(dim=1)
                    local_indexs.append(
                        torch.cat((torch.LongTensor(
                            local_index.shape[0],
                            1).fill_(proposal_index - 1).int(),
                                   local_index.cpu()),
                                  dim=1))

                local_proposals_idx.append(torch.cat(local_indexs, dim=0))
                local_proposals_offset.append(local_proposals_offset[-1] +
                                              local_proposals_idx[-1].shape[0])

            local_proposals_idx = torch.cat(local_proposals_idx, dim=0)
            local_proposals_offset = torch.tensor(local_proposals_offset).int()

            #### proposals voxelization again
            input_feats, inp_map = self.clusters_voxelization(
                local_proposals_idx, local_proposals_offset, output_feats,
                coords, self.local_proposal['local_proposal_full_scale'],
                self.local_proposal['local_proposal_scale'], self.mode)

            #### cluster features
            if self.local_proposal['reuse_backbone_unet']:
                proposals = self.unet(input_feats)
                proposals = self.output_layer(proposals)
            else:
                proposals = self.proposal_unet(input_feats)
                proposals = self.proposal_outputlayer(proposals)
            proposals_point_features = proposals.features[
                inp_map.long()]  # (sumNPoint, C)

            ret['proposals_point_features'] = (proposals_point_features,
                                               local_proposals_idx)

            ### scatter mean point predictions
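            # local proposals overlap, so one point may receive several
            # refined predictions; scatter_mean averages them per point index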
            if self.local_proposal['scatter_mean_target'] == 'prediction':
                refined_point_semantic_score = point_semantic_scores[-1]
                local_point_semantic_score = self.point_semantic(
                    proposals_point_features)
                refined_point_semantic_score[:min((local_proposals_idx[:, 1].max() + 1).item(), coords.shape[0]), :] = \
                    scatter_mean(local_point_semantic_score, local_proposals_idx[:, 1].cuda().long(), dim=0)
                point_semantic_scores.append(refined_point_semantic_score)
                point_semantic_preds = refined_point_semantic_score.max(1)[1]

                refined_point_offset_pred = point_offset_preds[-1]
                local_point_offset_pred = self.point_offset(
                    proposals_point_features)
                refined_point_offset_pred[:min((local_proposals_idx[:, 1].max() + 1).item(), coords.shape[0]), :] = \
                    scatter_mean(local_point_offset_pred, local_proposals_idx[:, 1].cuda().long(), dim=0)
                point_offset_preds.append(refined_point_offset_pred)

            ### scatter mean point features
            elif self.local_proposal['scatter_mean_target'] == 'feature':
                refined_point_features = torch.zeros_like(output_feats).cuda()
                refined_point_features[:min((local_proposals_idx[:, 1].max() + 1).item(), coords.shape[0]), :] = \
                    scatter_mean(proposals_point_features, local_proposals_idx[:, 1].cuda().long(), dim=0)
                #### fill rows left all-zero by scatter_mean with output_feats
                filled_index = (refined_point_features == torch.zeros_like(
                    refined_point_features[0, :])).all(dim=1)
                refined_point_features[filled_index, :] = output_feats[
                    filled_index, :]

                ### refined point prediction
                #### refined point semantic label prediction
                point_semantic_scores.append(
                    self.point_semantic(
                        refined_point_features))  # (N, nClass), float
                point_semantic_preds = point_semantic_scores[-1].max(1)[1]

                #### point offset prediction
                point_offset_preds.append(
                    self.point_offset(
                        output_feats +
                        refined_point_features))  # (N, 3), float32

            elif self.local_proposal['scatter_mean_target'] is False:
                local_point_semantic_scores.append(
                    self.point_semantic(proposals_point_features))
                local_point_offset_preds.append(
                    self.point_offset(proposals_point_features))

                ret['local_point_semantic_scores'] = (
                    local_point_semantic_scores, local_proposals_idx)
                ret['local_point_offset_preds'] = (local_point_offset_preds,
                                                   local_proposals_idx)

            if input['test']:
                if self.local_proposal['scatter_mean_target'] is False:
                    local_point_semantic_score = self.point_semantic(
                        proposals_point_features)
                    local_point_offset_pred = self.point_offset(
                        proposals_point_features)

                    local_point_semantic_score = scatter_mean(
                        local_point_semantic_score,
                        local_proposals_idx[:, 1].cuda().long(),
                        dim=0)

                    point_semantic_score = point_semantic_scores[-1]
                    point_semantic_score[:local_point_semantic_score.shape[
                        0], :] += local_point_semantic_score
                    point_semantic_preds = point_semantic_score.max(1)[1]

                    point_offset_pred = point_offset_preds[-1]
                    local_point_offset_pred = scatter_mean(
                        local_point_offset_pred,
                        local_proposals_idx[:, 1].cuda().long(),
                        dim=0)
                    point_offset_pred[:local_point_offset_pred.shape[0], :] = \
                        local_point_offset_pred
                else:
                    # fallback: for the 'prediction' / 'feature' targets the
                    # refined offsets were already scattered back above, so
                    # reuse the latest point-wise prediction here; otherwise
                    # point_offset_pred would be undefined below
                    point_offset_pred = point_offset_preds[-1]

                self.cluster_sets = 'Q'
                scores, proposals_idx, proposals_offset = self.pointgroup_cluster_algorithm(
                    coords, point_offset_pred, point_semantic_preds,
                    batch_idxs, input['batch_size'])
                ret['proposal_scores'] = (scores, proposals_idx,
                                          proposals_offset)

        ret['point_semantic_scores'] = point_semantic_scores
        ret['point_offset_preds'] = point_offset_preds
        ret['point_features'] = output_feats

        return ret
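
The helper euclidean_dist called above is not shown in this snippet. Below is
a minimal sketch consistent with how it is used (two (n, 3) center tensors
in, an (n, n) pairwise distance matrix out), assuming torch is imported as in
the surrounding code; the implementation in the source repository may differ:

    def euclidean_dist(a, b):
        # a: (n, 3), b: (m, 3) -> (n, m) pairwise Euclidean distances,
        # via the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
        aa = (a ** 2).sum(dim=1, keepdim=True)      # (n, 1)
        bb = (b ** 2).sum(dim=1, keepdim=True).t()  # (1, m)
        dist_sq = (aa + bb - 2 * a @ b.t()).clamp(min=0)
        return dist_sq.sqrt()
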
    def forward(self, input, input_map, coords, rgb, ori_coords, batch_idxs, batch_offsets, epoch):
        '''
        :param input_map: (N), int, cuda
        :param coords: (N, 3), float, cuda
        :param batch_idxs: (N), int, cuda
        :param batch_offsets: (B + 1), int, cuda
        '''
        ret = {}

        batch_idxs = batch_idxs.squeeze()

        point_offset_preds = []
        point_semantic_scores = []

        proposals_confidence_preds = []
        proposals_idx_shifts = []
        proposals_offset_shifts = []

        voxel_feats = pointgroup_ops.voxelization(input['pt_feats'], input['v2p_map'], input['mode']) # (M, C), float, cuda

        input_ = spconv.SparseConvTensor(
            voxel_feats, input['voxel_coords'],
            input['spatial_shape'], input['batch_size']
        )
        output = self.input_conv(input_)
        output = self.unet(output)
        output = self.output_layer(output)
        output_feats = output.features[input_map.long()]
        output_feats = output_feats.squeeze(dim=0)

        ### point prediction
        #### point semantic label prediction
        point_semantic_scores.append(self.point_semantic(output_feats))  # (N, nClass), float
        point_semantic_preds = point_semantic_scores[-1].max(1)[1]

        #### point offset prediction
        point_offset_preds.append(self.point_offset(output_feats))  # (N, 3), float32

        point_features = output_feats.clone()

        if (epoch > self.prepare_epochs):

            for _ in range(self.proposal_refinement['refine_times']):
                #### get proposal clusters
                object_idxs = torch.nonzero(point_semantic_preds > 1).view(-1)

                batch_idxs_ = batch_idxs[object_idxs]
                batch_offsets_ = utils.get_batch_offsets(batch_idxs_, input['batch_size'])
                coords_ = coords[object_idxs]
                pt_offsets_ = point_offset_preds[-1][object_idxs]

                semantic_preds_cpu = point_semantic_preds[object_idxs].int().cpu()

                idx_shift, start_len_shift = pointgroup_ops.ballquery_batch_p(coords_ + pt_offsets_, batch_idxs_,
                                                                              batch_offsets_, self.cluster_radius,
                                                                              self.cluster_shift_meanActive)
                proposals_idx_shift, proposals_offset_shift = pointgroup_ops.bfs_cluster(semantic_preds_cpu,
                                                                                         idx_shift.cpu(),
                                                                                         start_len_shift.cpu(),
                                                                                         self.cluster_npoint_thre)
                proposals_idx_shift[:, 1] = object_idxs[proposals_idx_shift[:, 1].long()].int()
                # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
                # proposals_offset_shift: (nProposal + 1), int

                # no cluster survived the size threshold; skip this
                # refinement round
                if proposals_idx_shift.shape[0] == 0:
                    continue

                #### proposals voxelization again
                input_feats, inp_map = self.clusters_voxelization(
                    proposals_idx_shift, proposals_offset_shift, output_feats,
                    coords, self.proposal_refinement['proposal_refine_full_scale'],
                    self.proposal_refinement['proposal_refine_scale'], self.mode
                )

                #### proposal features
                proposals = self.proposal_unet(input_feats)
                proposals = self.proposal_outputlayer(proposals)
                proposal_feats = proposals.features[inp_map.long()]  # (sumNPoint, C)
                proposal_feats = pointgroup_ops.roipool(proposal_feats, proposals_offset_shift.cuda())  # (nProposal, C)

                proposals_confidence_preds.append(self.proposal_confidence_linear(proposal_feats))  # (nProposal, 1)
                proposals_idx_shifts.append(proposals_idx_shift)
                proposals_offset_shifts.append(proposals_offset_shift)

                # take the first point of every proposal; its index is used
                # below to recover which batch the proposal belongs to
                proposal_pts_idxs = proposals_idx_shift[proposals_offset_shift[:-1].long()][:, 1].cuda()

                if len(proposal_pts_idxs) == 0:
                    continue

                # walk the offsets from the back so each proposal ends up
                # with the smallest batch index whose offset bounds its
                # first point
                proposal_batch_idxs = proposal_pts_idxs.clone()
                for _batch_idx in range(len(batch_offsets_) - 1, 0, -1):
                    proposal_batch_idxs[proposal_pts_idxs < batch_offsets_[_batch_idx]] = _batch_idx

                refined_point_features = []
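                # attention-based refinement: per-batch proposal features act
                # as the source sequence and the per-point backbone features
                # as queries, yielding refined per-point features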
                for _batch_idx in range(1, len(batch_offsets)):
                    key_input = proposal_feats[proposal_batch_idxs == _batch_idx, :].unsqueeze(dim=0)
                    query_input = output_feats[batch_offsets[_batch_idx - 1]:batch_offsets[_batch_idx],
                                  :].unsqueeze(dim=0).permute(0, 2, 1)
                    point_refined_feature, _ = self.proposal_transformer(
                        src=key_input,
                        query_embed=query_input,
                    )

                    refined_point_features.append(point_refined_feature.squeeze(dim=0))

                refined_point_features = torch.cat(refined_point_features, dim=0)
                assert refined_point_features.shape[0] == point_features.shape[0], \
                    'refined point-wise features must cover all N points'

                point_features = refined_point_features.clone()

                ### refined point prediction
                #### refined point semantic label prediction
                point_semantic_scores.append(self.point_semantic(refined_point_features))  # (N, nClass), float
                point_semantic_preds = point_semantic_scores[-1].max(1)[1]

                #### point offset prediction
                point_offset_preds.append(self.point_offset(refined_point_features))  # (N, 3), float32

            if (epoch == self.test_epoch) and input['test']:
                self.cluster_sets = 'Q'
                scores, proposals_idx, proposals_offset = self.pointgroup_cluster_algorithm(
                    coords, point_offset_preds[-1], point_semantic_preds,
                    batch_idxs, input['batch_size']
                )
                ret['proposal_scores'] = (scores, proposals_idx, proposals_offset)

            ret['proposal_confidences'] = (proposals_confidence_preds, proposals_idx_shifts, proposals_offset_shifts)

        ret['point_semantic_scores'] = point_semantic_scores
        ret['point_offset_preds'] = point_offset_preds

        return ret
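
The helper utils.get_batch_offsets used in both forward passes is likewise
not shown here. A minimal sketch, assuming points are sorted by batch id (as
the batch_offsets slicing above requires); the real utility may be
implemented differently, e.g. fully vectorized:

    def get_batch_offsets(batch_idxs, bs):
        # batch_idxs: (N), int, cuda -- batch id of every point
        # returns: (bs + 1), int -- points of batch b occupy rows
        # batch_offsets[b]:batch_offsets[b + 1]
        batch_offsets = torch.zeros(bs + 1, dtype=torch.int32,
                                    device=batch_idxs.device)
        for b in range(bs):
            batch_offsets[b + 1] = batch_offsets[b] + (batch_idxs == b).sum()
        return batch_offsets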