def crop_pc(points, labels, search_tree, pick_idx, num_points=None):
    """Crop a fixed-size neighbourhood of a point cloud for training.

    Looks up the ``num_points`` nearest neighbours of ``points[pick_idx]`` in
    ``search_tree`` and returns the selected points/labels after passing the
    indices through ``DP.shuffle_idx`` (presumably a random permutation).

    Args:
        points: array indexed as ``points[i, :]``; presumably (N, 3)
            coordinates -- TODO confirm against callers.
        labels: per-point labels, indexed with the same indices as ``points``.
        search_tree: spatial index over ``points`` exposing
            ``query(X, k) -> (distances, indices)``; presumably the KDTree
            pickled during preprocessing.
        pick_idx: index of the point used as the crop centre.
        num_points: number of neighbours to select. Defaults to
            ``cfg.num_points`` when None, matching the previous behaviour.

    Returns:
        Tuple ``(select_points, select_labels, select_idx)`` where
        ``select_idx`` holds the (shuffled) indices into ``points``/``labels``.
    """
    if num_points is None:
        num_points = cfg.num_points
    # query() takes a 2-D batch of query points, so keep a leading axis of 1.
    center_point = points[pick_idx, :].reshape(1, -1)
    # query() returns (distances, indices); keep the index row of our one query.
    select_idx = search_tree.query(center_point, num_points)[1][0]
    select_idx = DP.shuffle_idx(select_idx)
    select_points = points[select_idx]
    select_labels = labels[select_idx]
    return select_points, select_labels, select_idx
    def tf_map(self, batch_pc, batch_label, batch_pc_idx, batch_cloud_idx):
        """Build the multi-level network inputs for one batch.

        For each of ``cfg.num_layers`` levels this records:
          * the current points,
          * ``DP.knn_search(current, current, cfg.k_n)`` (presumably the
            ``cfg.k_n`` nearest-neighbour indices of every current point),
          * the rows of those neighbour indices for the retained points,
          * ``DP.knn_search(retained, current, 1)`` (presumably used for
            up-sampling back to the current level).
        The retained points are the first ``N // cfg.sub_sampling_ratio[level]``
        entries along axis 1, and they become the input of the next level.

        Returns:
            A flat list: every level's points, then every level's neighbour
            indices, then every level's pooling indices, then every level's
            up-sampling indices, followed by
            ``[features, batch_label, batch_pc_idx, batch_cloud_idx]`` where
            ``features`` is the original ``batch_pc``.
        """
        features = batch_pc
        points_per_level, neighbors_per_level = [], []
        pools_per_level, upsamples_per_level = [], []

        current = batch_pc
        for level in range(cfg.num_layers):
            neighbors = DP.knn_search(current, current, cfg.k_n)
            n_keep = current.shape[1] // cfg.sub_sampling_ratio[level]
            # Keep the leading slice along axis 1; this presumably relies on the
            # points having been shuffled upstream (TODO confirm).
            kept = current[:, :n_keep, :]
            pooled = neighbors[:, :n_keep, :]
            upsampled = DP.knn_search(kept, current, 1)

            points_per_level.append(current)
            neighbors_per_level.append(neighbors)
            pools_per_level.append(pooled)
            upsamples_per_level.append(upsampled)
            current = kept

        return (points_per_level + neighbors_per_level + pools_per_level
                + upsamples_per_level
                + [features, batch_label, batch_pc_idx, batch_cloud_idx])
    def __init__(self, mode, test_id=None, batch_size=20, data_list=None):
        """Set up the SemanticKITTI dataset from one sequence or a file list.

        Args:
            mode: dataset mode; stored on ``self.mode`` (not otherwise used here).
            test_id: sequence id. When given, the file list is built from that
                sequence via ``DP.get_file_list`` and ``test_id`` is stored as
                ``self.test_scan_number``; ``data_list`` is then ignored.
            batch_size: stored on ``self.batch_size``.
            data_list: explicit file list, used only when ``test_id`` is None.

        Raises:
            ValueError: if both ``test_id`` and ``data_list`` are None
                (previously this surfaced as an opaque TypeError from
                ``sorted(None)``).
        """
        if test_id is None and data_list is None:
            raise ValueError('Either test_id or data_list must be provided')

        self.name = 'SemanticKITTI'
        self.dataset_path = '/tmp2/tsunghan/PCL_Seg_data/sequences_0.06'
        self.batch_size = batch_size
        self.num_classes = cfg.num_classes
        self.ignored_labels = np.sort([0])
        # Keep the requested mode available on the instance.
        self.mode = mode

        self.seq_list = np.sort(os.listdir(self.dataset_path))
        if test_id is not None:
            self.test_scan_number = test_id
            data_list = DP.get_file_list(self.dataset_path, [test_id])
        self.data_list = sorted(data_list)
    def __init__(self, mode, data_list=None):
        """Set up the SemanticKITTI dataset for training or validation.

        Args:
            mode: 'training' or 'validation'. Selects the default sequences when
                ``data_list`` is None, and is stored on ``self.mode``.
            data_list: explicit file list; when given it is used instead of the
                mode's default sequences.

        Raises:
            ValueError: if ``data_list`` is None and ``mode`` is not one of the
                supported modes (previously this raised UnboundLocalError).
        """
        self.name = 'SemanticKITTI'
        self.dataset_path = '/tmp2/tsunghan/PCL_Seg_data/sequences_0.06'

        self.num_classes = cfg.num_classes
        self.ignored_labels = np.sort([0])

        self.mode = mode
        if data_list is None:
            # Sequence 08 is held out for validation; the rest are for training.
            default_sequences = {
                'training': ['00', '01', '02', '03', '04', '05', '06', '07',
                             '09', '10'],
                'validation': ['08'],
            }
            if mode not in default_sequences:
                raise ValueError(
                    "Unknown mode {!r}; expected 'training' or 'validation' "
                    "when data_list is not given".format(mode))
            data_list = DP.get_file_list(self.dataset_path,
                                         default_sequences[mode])
        self.data_list = sorted(data_list)
    # NOTE(review): fragment of a per-sequence preprocessing loop. The enclosing
    # loop header and the definitions of seq_id, seq_path, output_path,
    # grid_size and remap_lut are not visible here.
    seq_path_out = join(output_path, seq_id)
    pc_path = join(seq_path, 'velodyne')
    pc_path_out = join(seq_path_out, 'velodyne')
    KDTree_path_out = join(seq_path_out, 'KDTree')
    # Create the output directories if missing
    # (os.makedirs(path, exist_ok=True) would be the idiomatic equivalent).
    os.makedirs(seq_path_out) if not exists(seq_path_out) else None
    os.makedirs(pc_path_out) if not exists(pc_path_out) else None
    os.makedirs(KDTree_path_out) if not exists(KDTree_path_out) else None

    # Only sequence ids below 11 take this branch, which reads a 'labels'
    # directory -- presumably the labelled sequences. No else branch is
    # visible in this fragment.
    if int(seq_id) < 11:
        label_path = join(seq_path, 'labels')
        label_path_out = join(seq_path_out, 'labels')
        os.makedirs(label_path_out) if not exists(label_path_out) else None
        scan_list = np.sort(os.listdir(pc_path))
        for scan_id in scan_list:
            print(scan_id)
            points = DP.load_pc_kitti(join(pc_path, scan_id))
            # The label file shares the scan's name with its last 4 characters
            # (presumably the '.bin' extension) replaced by '.label'; remap_lut
            # is handed to the loader, presumably to remap raw label ids.
            labels = DP.load_label_kitti(
                join(label_path,
                     str(scan_id[:-4]) + '.label'), remap_lut)
            # Grid sub-sampling of points and labels with cell size grid_size.
            sub_points, sub_labels = DP.grid_sub_sampling(points,
                                                          labels=labels,
                                                          grid_size=grid_size)
            # KD-tree over the sub-sampled points, pickled alongside the arrays.
            search_tree = KDTree(sub_points)
            KDTree_save = join(KDTree_path_out, str(scan_id[:-4]) + '.pkl')
            # [:-4] strips the 4-character extension; np.save appends '.npy'.
            np.save(join(pc_path_out, scan_id)[:-4], sub_points)
            np.save(join(label_path_out, scan_id)[:-4], sub_labels)
            with open(KDTree_save, 'wb') as f:
                pickle.dump(search_tree, f)
            if seq_id == '08':
                # NOTE(review): only the directory creation is visible; whatever
                # writes into proj_path appears truncated from this view. This
                # check runs once per scan and could be hoisted out of the loop.
                proj_path = join(seq_path_out, 'proj')
                os.makedirs(proj_path) if not exists(proj_path) else None
 def get_class_weight(self):
     """Return the result of ``DP.get_class_weights`` for this dataset.

     Forwards the dataset path, file list and class count, presumably to
     obtain per-class loss weights.
     """
     path, files, n_classes = self.dataset_path, self.data_list, self.num_classes
     return DP.get_class_weights(path, files, n_classes)