Example #1
0
    def __init__(self,
                 data_paths,
                 input_transform=None,
                 target_transform=None,
                 data_root='/',
                 explicit_rotation=-1,
                 ignore_label=255,
                 return_transformation=False,
                 augment_data=False,
                 elastic_distortion=False,
                 config=None,
                 **kwargs):
        """Initialise the base dataset, the voxelizer and the label map.

        Args:
            data_paths: Passed through to ``VoxelizationDatasetBase``.
            input_transform: Passed through to the base class.
            target_transform: Passed through to the base class.
            data_root: Passed through to the base class.
            explicit_rotation: Stored as ``self.explicit_rotation``.
            ignore_label: Given to the base class as ``ignore_mask`` and to
                the voxelizer as ``ignore_label``.
            return_transformation: Passed through to the base class.
            augment_data: Stored, and used as the voxelizer's
                ``use_augmentation`` flag.
            elastic_distortion: Stored as ``self.elastic_distortion``.
            config: Stored as ``self.config``.
            **kwargs: Accepted and ignored.
        """
        self.augment_data = augment_data
        self.elastic_distortion = elastic_distortion
        self.config = config
        # NOTE(review): `cache` is not a parameter of this method, so it has
        # to resolve to a module-level name -- confirm that is intended.
        VoxelizationDatasetBase.__init__(
            self,
            data_paths,
            input_transform=input_transform,
            target_transform=target_transform,
            cache=cache,
            data_root=data_root,
            ignore_mask=ignore_label,
            return_transformation=return_transformation)
        # The argument used to be dropped on the floor. It is assigned after
        # the base initialiser so the caller's value is the one that sticks
        # for methods reading ``self.explicit_rotation``.
        self.explicit_rotation = explicit_rotation

        self.sparse_voxelizer = SparseVoxelizer(
            voxel_size=self.VOXEL_SIZE,
            clip_bound=self.CLIP_BOUND,
            use_augmentation=augment_data,
            scale_augmentation_bound=self.SCALE_AUGMENTATION_BOUND,
            rotation_augmentation_bound=self.ROTATION_AUGMENTATION_BOUND,
            translation_augmentation_ratio_bound=self.
            TRANSLATION_AUGMENTATION_RATIO_BOUND,
            rotation_axis=self.LOCFEAT_IDX,
            ignore_label=ignore_label)

        # Map labels that are not evaluated to the ignore mask and renumber
        # the remaining ones contiguously from 0. A ``None`` IGNORE_LABELS is
        # treated as "nothing ignored" instead of raising a TypeError.
        ignored = self.IGNORE_LABELS if self.IGNORE_LABELS is not None else ()
        label_map = {}
        n_used = 0
        for label in range(self.NUM_LABELS):
            if label in ignored:
                label_map[label] = self.ignore_mask
            else:
                label_map[label] = n_used
                n_used += 1
        label_map[self.ignore_mask] = self.ignore_mask
        self.label_map = label_map
        # From here on NUM_LABELS counts only the evaluated classes.
        self.NUM_LABELS -= len(ignored)
Example #2
0
    def __init__(self,
                 config,
                 train=True,
                 download=False,
                 data_precent=1.0,
                 whole_scene=True):
        """Load S3DIS either as pre-cut blocks or as a list of whole scenes.

        Args:
            config: Object providing ``voxel_size`` for the voxelizer.
            train: Selects the "train" split when true, otherwise "val".
            download: Not used by this method.
            data_precent: Stored as ``self.data_precent`` in block mode only.
            whole_scene: When true, index one directory per scene under
                ``self.data_path``; otherwise load the ``s3dis_*.pth`` file
                for the split with ``torch.load``.
        """
        super().__init__()
        # Recorded for both modes. It used to be set only when whole_scene
        # was true, so block-mode instances had no ``whole_scene`` attribute
        # for later ``self.whole_scene`` checks.
        self.whole_scene = whole_scene
        self.train = train
        self.split = "train" if self.train else "val"

        self.NUM_IN_CHANNEL = 9
        self.NUM_LABELS = 13
        self.IGNORE_LABELS = [-1]  # labels that are not evaluated
        self.NEED_PRED_POSTPROCESSING = False

        if not whole_scene:
            self.data_precent = data_precent
            self.folder = "indoor3d_sem_seg_hdf5_data"
            self.data_path = "/data/eva_share_users/zhaotianchen/"
            self.data_dir = os.path.join(self.data_path, self.folder)
            self.url = (
                "https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip"
            )

            split_file = "s3dis_train.pth" if self.train else "s3dis_val.pth"
            self.data_dict = torch.load(
                os.path.join(self.data_dir, split_file), 'cpu')

            # The file stores sequences of arrays; flatten them along the
            # first axis so one index addresses one block.
            # (presumably [batches, 4096, 9] for data -- TODO confirm)
            self.data_dict['data'] = np.concatenate(
                self.data_dict['data'], axis=0)
            self.data_dict['label'] = np.concatenate(
                self.data_dict['label'], axis=0)
        else:
            area_ids = [1, 2, 3, 4, 6] if self.train else [5]

            self.data_path = "/home/zhaotianchen/project/point-transformer/pvcnn/data/s3dis/pointcnn"

            # Each entry is "<area>/<scene>", relative to self.data_path.
            self.filelists = []
            for area_id in area_ids:
                area_name = "Area_{}".format(area_id)
                scene_names = os.listdir(
                    os.path.join(self.data_path, area_name))
                for scene_name in scene_names:
                    self.filelists.append(os.path.join(area_name, scene_name))

        # Both modes use the same augmentation-free voxelizer.
        self.sparse_voxelizer = SparseVoxelizer(
            voxel_size=config.voxel_size,
            clip_bound=None,
            use_augmentation=False,
            scale_augmentation_bound=None,
            rotation_augmentation_bound=None,
            translation_augmentation_ratio_bound=None,
            rotation_axis=0,  # this isn't actually used
            ignore_label=-1)
Example #3
0
class SparseVoxelizationDataset(VoxelizationDatasetBase):
    """This dataset loads RGB point clouds and their labels as a list of points
    and voxelizes the pointcloud with sufficient data augmentation.
    """
    # Voxelization arguments
    CLIP_BOUND = None
    VOXEL_SIZE = 0.05  # 5cm

    # Augmentation arguments
    SCALE_AUGMENTATION_BOUND = (0.9, 1.1)
    ROTATION_AUGMENTATION_BOUND = ((-np.pi / 6, np.pi / 6), (-np.pi, np.pi),
                                   (-np.pi / 6, np.pi / 6))
    TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.05, 0.05),
                                            (-0.2, 0.2))
    # Iterable of (granularity, magnitude) pairs, or None to disable.
    ELASTIC_DISTORT_PARAMS = None
    # When set, points are sub-sampled with this voxel size before voxelizing.
    PREVOXELIZE_VOXEL_SIZE = None

    def __init__(self,
                 data_paths,
                 input_transform=None,
                 target_transform=None,
                 data_root='/',
                 explicit_rotation=-1,
                 ignore_label=255,
                 return_transformation=False,
                 augment_data=False,
                 elastic_distortion=False,
                 config=None,
                 **kwargs):
        """Initialise the base dataset, the voxelizer and the label map.

        Args:
            data_paths: Passed through to ``VoxelizationDatasetBase``.
            input_transform: Passed through to the base class.
            target_transform: Passed through to the base class.
            data_root: Passed through to the base class.
            explicit_rotation: Stored as ``self.explicit_rotation``; values
                greater than 1 make ``__getitem__`` pick a fixed rotation
                angle from the sample index.
            ignore_label: Given to the base class as ``ignore_mask`` and to
                the voxelizer as ``ignore_label``.
            return_transformation: Passed through to the base class.
            augment_data: Stored, and used as the voxelizer's
                ``use_augmentation`` flag.
            elastic_distortion: Enables ``_augment_elastic_distortion`` in
                ``__getitem__``.
            config: Stored as ``self.config``.
            **kwargs: Accepted and ignored.
        """
        self.augment_data = augment_data
        self.elastic_distortion = elastic_distortion
        self.config = config
        # NOTE(review): `cache` is not a parameter of this method, so it has
        # to resolve to a module-level name -- confirm that is intended.
        VoxelizationDatasetBase.__init__(
            self,
            data_paths,
            input_transform=input_transform,
            target_transform=target_transform,
            cache=cache,
            data_root=data_root,
            ignore_mask=ignore_label,
            return_transformation=return_transformation)
        # The argument used to be dropped although __getitem__ reads
        # self.explicit_rotation. Assigned after the base initialiser so the
        # caller's value is the one that sticks.
        self.explicit_rotation = explicit_rotation

        self.sparse_voxelizer = SparseVoxelizer(
            voxel_size=self.VOXEL_SIZE,
            clip_bound=self.CLIP_BOUND,
            use_augmentation=augment_data,
            scale_augmentation_bound=self.SCALE_AUGMENTATION_BOUND,
            rotation_augmentation_bound=self.ROTATION_AUGMENTATION_BOUND,
            translation_augmentation_ratio_bound=self.
            TRANSLATION_AUGMENTATION_RATIO_BOUND,
            rotation_axis=self.LOCFEAT_IDX,
            ignore_label=ignore_label)

        # Map labels that are not evaluated to the ignore mask and renumber
        # the remaining ones contiguously from 0. A ``None`` IGNORE_LABELS
        # (which __getitem__ tolerates) is treated as "nothing ignored".
        ignored = self.IGNORE_LABELS if self.IGNORE_LABELS is not None else ()
        label_map = {}
        n_used = 0
        for label in range(self.NUM_LABELS):
            if label in ignored:
                label_map[label] = self.ignore_mask
            else:
                label_map[label] = n_used
                n_used += 1
        label_map[self.ignore_mask] = self.ignore_mask
        self.label_map = label_map
        # From here on NUM_LABELS counts only the evaluated classes.
        self.NUM_LABELS -= len(ignored)

    def get_output_id(self, iteration):
        """Return the data path stored at position ``iteration``."""
        return self.data_paths[iteration]

    def convert_mat2cfl(self, mat):
        """Split a point matrix into (coords, feats, labels).

        The first three columns are returned as coordinates, the last column
        as labels and everything in between as features (generally
        xyz, rgb, label).
        """
        return mat[:, :3], mat[:, 3:-1], mat[:, -1]

    def _augment_elastic_distortion(self, pointcloud):
        """Apply every configured elastic distortion with probability 0.95.

        Returns ``pointcloud`` unchanged when ``ELASTIC_DISTORT_PARAMS`` is
        None or the random draw falls outside the 0.95 window.
        """
        if self.ELASTIC_DISTORT_PARAMS is not None:
            if random.random() < 0.95:
                for granularity, magnitude in self.ELASTIC_DISTORT_PARAMS:
                    pointcloud = t.elastic_distortion(pointcloud, granularity,
                                                      magnitude)
        return pointcloud

    def __getitem__(self, index):
        """Load, optionally augment, voxelize and relabel one sample.

        Returns:
            ``(coords, feats, labels)``; when ``self.return_transformation``
            is set, the float32 point cloud and the float32 transformation
            (with a leading axis added) are appended.
        """
        if self.explicit_rotation > 1:
            # Each stored sample is exposed explicit_rotation times, once per
            # evenly spaced angle in [-pi, pi).
            rotation_space = np.linspace(-np.pi, np.pi,
                                         self.explicit_rotation + 1)
            rotation_angle = rotation_space[index % self.explicit_rotation]
            index //= self.explicit_rotation
        else:
            rotation_angle = None

        pointcloud, center = self.load_ply(index)
        if self.PREVOXELIZE_VOXEL_SIZE is not None:
            inds = ME.SparseVoxelize(pointcloud[:, :3] /
                                     self.PREVOXELIZE_VOXEL_SIZE,
                                     return_index=True)
            pointcloud = pointcloud[inds]

        if self.elastic_distortion:
            pointcloud = self._augment_elastic_distortion(pointcloud)

        coords, feats, labels = self.convert_mat2cfl(pointcloud)
        outs = self.sparse_voxelizer.voxelize(
            coords,
            feats,
            labels,
            center=center,
            rotation_angle=rotation_angle,
            return_transformation=self.return_transformation)
        if self.return_transformation:
            coords, feats, labels, transformation = outs
            transformation = np.expand_dims(transformation, 0)
        else:
            coords, feats, labels = outs

        if self.input_transform is not None:
            coords, feats, labels = self.input_transform(coords, feats, labels)
        if self.target_transform is not None:
            coords, feats, labels = self.target_transform(
                coords, feats, labels)

        # Map labels not used for evaluation to the ignore mask.
        if self.IGNORE_LABELS is not None:
            # Builtin ``int`` replaces the ``np.int`` alias, which newer NumPy
            # releases no longer provide; the resulting dtype is the same.
            labels = np.array([self.label_map[x] for x in labels],
                              dtype=int)

        return_args = [coords, feats, labels]
        if self.return_transformation:
            return_args.extend([
                pointcloud.astype(np.float32),
                transformation.astype(np.float32)
            ])
        return tuple(return_args)

    def cleanup(self):
        """Delegate cleanup to the sparse voxelizer."""
        self.sparse_voxelizer.cleanup()
Example #4
0
class S3DIS(data.Dataset):
    """S3DIS dataset voxelized with ``SparseVoxelizer``.

    Two loading modes are selected by ``whole_scene``:

    * ``whole_scene=False``: samples are blocks read from ``s3dis_train.pth``
      / ``s3dis_val.pth`` (dicts with ``'data'`` and ``'label'`` entries).
    * ``whole_scene=True``: every sample is one scene directory holding
      ``xyzrgb.npy`` and ``label.npy``; Areas 1-4 and 6 form the train split,
      Area 5 the val split.

    In both modes the training split gets a random rotation about the z axis
    and a random scale in [0.95, 1.05) applied to the xyz columns.
    """

    def __init__(self,
                 config,
                 train=True,
                 download=False,
                 data_precent=1.0,
                 whole_scene=True):
        """Load the block file or index the scene directories.

        Args:
            config: Object providing ``voxel_size`` for the voxelizer.
            train: Selects the "train" split when true, otherwise "val".
            download: Not used.
            data_precent: Stored as ``self.data_precent`` in block mode only.
            whole_scene: Chooses between whole-scene and block loading.
        """
        super().__init__()
        # Recorded for both modes: __getitem__ and __len__ branch on it. It
        # used to be set only when whole_scene was true, so block-mode
        # instances raised AttributeError on first use.
        self.whole_scene = whole_scene
        self.train = train
        self.split = "train" if self.train else "val"

        self.NUM_IN_CHANNEL = 9
        self.NUM_LABELS = 13
        self.IGNORE_LABELS = [-1]  # labels that are not evaluated
        self.NEED_PRED_POSTPROCESSING = False

        if not whole_scene:
            self._load_blocks(data_precent)
        else:
            self._index_scenes()

        # Both modes use the same augmentation-free voxelizer.
        self.sparse_voxelizer = SparseVoxelizer(
            voxel_size=config.voxel_size,
            clip_bound=None,
            use_augmentation=False,
            scale_augmentation_bound=None,
            rotation_augmentation_bound=None,
            translation_augmentation_ratio_bound=None,
            rotation_axis=0,  # this isn't actually used
            ignore_label=-1)

    def _load_blocks(self, data_precent):
        """Read the pre-cut block file for the current split into memory."""
        self.data_precent = data_precent
        self.folder = "indoor3d_sem_seg_hdf5_data"
        self.data_path = "/data/eva_share_users/zhaotianchen/"
        self.data_dir = os.path.join(self.data_path, self.folder)
        self.url = (
            "https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip"
        )

        # NOTE: the .pth files are produced offline from the hdf5 release at
        # self.url (the generating script is not part of this class).
        split_file = "s3dis_train.pth" if self.train else "s3dis_val.pth"
        self.data_dict = torch.load(
            os.path.join(self.data_dir, split_file), 'cpu')

        # The file stores sequences of arrays; flatten them along the first
        # axis so one index addresses one block.
        # (presumably [batches, 4096, 9] for data -- TODO confirm)
        self.data_dict['data'] = np.concatenate(
            self.data_dict['data'], axis=0)
        self.data_dict['label'] = np.concatenate(
            self.data_dict['label'], axis=0)

    def _index_scenes(self):
        """Collect the "<area>/<scene>" directories of the current split."""
        area_ids = [1, 2, 3, 4, 6] if self.train else [5]

        self.data_path = "/home/zhaotianchen/project/point-transformer/pvcnn/data/s3dis/pointcnn"

        self.filelists = []
        for area_id in area_ids:
            area_name = "Area_{}".format(area_id)
            scene_names = os.listdir(os.path.join(self.data_path, area_name))
            for scene_name in scene_names:
                self.filelists.append(os.path.join(area_name, scene_name))

    @staticmethod
    def _random_rotate_scale(xyz):
        """Return ``xyz`` rotated about z by a random angle and rescaled."""
        theta = np.random.uniform(0, 2 * np.pi)
        scale_factor = np.random.uniform(0.95, 1.05)
        rot_mat = np.array([[np.cos(theta), np.sin(theta), 0],
                            [-np.sin(theta), np.cos(theta), 0],
                            [0, 0, 1]])
        return np.dot(xyz, rot_mat) * scale_factor

    def _get_block(self, index):
        """Voxelize block ``index`` of the in-memory block file."""
        # reshape() yields a view of the cached array, so take a copy: the
        # in-place augmentation below must not accumulate in self.data_dict
        # from one access to the next.
        data = self.data_dict['data'][index].reshape([-1, 9]).copy()
        target = self.data_dict['label'][index].reshape([-1])

        if 'train' in self.split:
            data[:, :3] = self._random_rotate_scale(data[:, :3])

        # Shift coordinates so the minimum along every axis is 0. The
        # (possibly augmented) 9-column array is passed on as features.
        coords = data[:, :3] - (data[:, :3]).min(0)
        outs = self.sparse_voxelizer.voxelize(
            coords,
            data,
            target,
            center=None,
            rotation_angle=None,
            return_transformation=False)

        # Unpacking doubles as a check that exactly five values came back.
        (coords, feats, labels, unique_map, inverse_map) = outs
        return outs

    def _get_scene(self, index):
        """Load scene ``index`` from disk and voxelize it."""
        root_path = os.path.join(self.data_path, self.filelists[index])
        data = np.load(os.path.join(root_path,
                                    "xyzrgb.npy")).astype(np.float32)

        if 'train' in self.split:
            data[:, :3] = self._random_rotate_scale(data[:, :3])

        label = np.load(os.path.join(root_path,
                                     "label.npy")).astype(np.float32)
        label = label.squeeze(-1)

        # Append xyz divided by its global maximum, shifted by -0.5, as three
        # extra feature columns.
        normalized_xyz = (data[:, :3] / data[:, :3].max() - 0.5)
        data = np.concatenate([data, normalized_xyz], axis=-1)

        # coords is a view into data, so this shift also moves the xyz
        # columns of the feature array.
        coords = data[:, :3]
        coords[:, :3] -= np.expand_dims(coords[:, :3].min(0), axis=0)
        return self.sparse_voxelizer.voxelize(coords,
                                              data,
                                              label,
                                              center=None,
                                              rotation_angle=None,
                                              return_transformation=False)

    def __getitem__(self, index):
        """Return the voxelizer output tuple for sample ``index``."""
        if not self.whole_scene:
            outs = self._get_block(index)
        else:
            outs = self._get_scene(index)

        assert isinstance(outs, tuple)
        return outs

    def __len__(self):
        """Number of blocks (block mode) or scenes (whole-scene mode)."""
        if not self.whole_scene:
            return len(self.data_dict['data'])
        return len(self.filelists)

    def cleanup(self):
        """Delegate cleanup to the sparse voxelizer."""
        self.sparse_voxelizer.cleanup()

    def reorder_result(self, result):
        """Return ``result`` unchanged."""
        return result
Example #5
0
  def __init__(self,
               data_paths,
               input_transform=None,
               target_transform=None,
               data_root='/',
               explicit_rotation=-1,
               # ignore_label=255,
               ignore_label= -1,
               return_transformation=False,
               augment_data=False,
               elastic_distortion=False,
               config=None,
               phase=DatasetPhase.Train,
               **kwargs):
    """Initialise the base dataset, voxelizer, optional aux data and labels.

    Args:
      data_paths: Passed through to ``VoxelizationDatasetBase``.
      input_transform: Passed through to the base class.
      target_transform: Passed through to the base class.
      data_root: Passed through to the base class.
      explicit_rotation: Stored as ``self.explicit_rotation``.
      ignore_label: Given to the base class as ``ignore_mask`` and to the
        voxelizer as ``ignore_label``.
      return_transformation: Passed through to the base class.
      augment_data: Stored, and used as the voxelizer's ``use_augmentation``.
      elastic_distortion: Stored as ``self.elastic_distortion``.
      config: Must provide ``voxel_size``; ``use_aux``, ``log_dir`` and
        ``load_whole`` are read when present.
      phase: ``DatasetPhase.Train`` or ``DatasetPhase.Val``.
      **kwargs: Accepted and ignored.

    Raises:
      NotImplementedError: For any other ``phase``.
      AssertionError: If ``config.use_aux`` is set and the aux prediction
        file is missing from ``config.log_dir``.
    """
    self.augment_data = augment_data
    self.elastic_distortion = elastic_distortion
    self.config = config
    self.phase = phase
    # NOTE(review): `cache` is not a parameter of this method, so it has to
    # resolve to a module-level name -- confirm that is intended.
    VoxelizationDatasetBase.__init__(
        self,
        data_paths,
        input_transform=input_transform,
        target_transform=target_transform,
        cache=cache,
        data_root=data_root,
        ignore_mask=ignore_label,
        return_transformation=return_transformation)
    # The argument used to be dropped on the floor. Assigned after the base
    # initialiser so the caller's value is the one that sticks for methods
    # reading ``self.explicit_rotation``.
    self.explicit_rotation = explicit_rotation

    self.sparse_voxelizer = SparseVoxelizer(
        voxel_size=config.voxel_size,
        clip_bound=self.CLIP_BOUND,
        use_augmentation=augment_data,
        scale_augmentation_bound=self.SCALE_AUGMENTATION_BOUND,
        rotation_augmentation_bound=self.ROTATION_AUGMENTATION_BOUND,
        translation_augmentation_ratio_bound=self.TRANSLATION_AUGMENTATION_RATIO_BOUND,
        rotation_axis=self.LOCFEAT_IDX,
        ignore_label=ignore_label)

    if self.phase == DatasetPhase.Train:
      self.split = 'train'
    elif self.phase == DatasetPhase.Val:
      self.split = 'val'
    else:
      raise NotImplementedError

    if hasattr(config, 'use_aux') and config.use_aux:
      # Joined as a path so a log_dir without a trailing separator still
      # points inside the directory.
      aux_path = os.path.join(self.config.log_dir,
                              'preds_{}.pth'.format(self.split))
      # The message previously printed a literal "{}" instead of the split.
      assert os.path.exists(aux_path), \
          "No Aux file found, link the `preds_{}` file into the log_dir".format(self.split)
      self.aux_data = torch.load(aux_path, 'cpu')['pred']

    # load_whole: read the whole split from one .pth file instead of the
    # individual ply files.
    self.load_whole = bool(hasattr(config, "load_whole") and config.load_whole)
    if self.load_whole:
      datapath = "/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles/"
      datapath = datapath + 'new_{}.pth'.format(self.split)
      self.data_dict = torch.load(datapath, 'cpu')

    # A ``None`` IGNORE_LABELS is treated as "nothing ignored" instead of
    # raising a TypeError.
    ignored = self.IGNORE_LABELS if self.IGNORE_LABELS is not None else ()

    if not self.load_whole:
      # The pre-built .pth already has its labels mapped, so the map is only
      # needed when loading individual files: labels that are not evaluated
      # go to the ignore mask, the rest are renumbered from 0.
      label_map = {}
      n_used = 0
      for label in range(self.NUM_LABELS):
        if label in ignored:
          label_map[label] = self.ignore_mask
        else:
          label_map[label] = n_used
          n_used += 1
      label_map[self.ignore_mask] = self.ignore_mask
      self.label_map = label_map

    # In both modes NUM_LABELS counts only the evaluated classes from here on.
    self.NUM_LABELS -= len(ignored)
Example #6
0
class SparseVoxelizationDataset(VoxelizationDatasetBase):
  """This dataset loads RGB point clouds and their labels as a list of points
  and voxelizes the pointcloud with sufficient data augmentation.
  """
  # Voxelization arguments
  CLIP_BOUND = None
  VOXEL_SIZE = 0.05  # 5cm

  # Augmentation arguments
  SCALE_AUGMENTATION_BOUND = (0.9, 1.1)
  ROTATION_AUGMENTATION_BOUND = ((-np.pi / 6, np.pi / 6), (-np.pi, np.pi), (-np.pi / 6, np.pi / 6))
  TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.05, 0.05), (-0.2, 0.2))
  # Iterable of (granularity, magnitude) pairs, or None to disable.
  ELASTIC_DISTORT_PARAMS = None
  # When set, points are sub-sampled with this voxel size before voxelizing.
  PREVOXELIZE_VOXEL_SIZE = None

  def __init__(self,
               data_paths,
               input_transform=None,
               target_transform=None,
               data_root='/',
               explicit_rotation=-1,
               # ignore_label=255,
               ignore_label= -1,
               return_transformation=False,
               augment_data=False,
               elastic_distortion=False,
               config=None,
               phase=DatasetPhase.Train,
               **kwargs):
    """Initialise the base dataset, voxelizer, optional aux data and labels.

    Args:
      data_paths: Passed through to ``VoxelizationDatasetBase``.
      input_transform: Passed through to the base class.
      target_transform: Passed through to the base class.
      data_root: Passed through to the base class.
      explicit_rotation: Stored as ``self.explicit_rotation``; values greater
        than 1 make ``__getitem__`` pick a fixed rotation angle per index.
      ignore_label: Given to the base class as ``ignore_mask`` and to the
        voxelizer as ``ignore_label``.
      return_transformation: Passed through to the base class.
      augment_data: Stored, and used as the voxelizer's ``use_augmentation``.
      elastic_distortion: Enables ``_augment_elastic_distortion``.
      config: Must provide ``voxel_size``; ``use_aux``, ``log_dir``,
        ``load_whole`` and ``is_export`` are read when present.
      phase: ``DatasetPhase.Train`` or ``DatasetPhase.Val``.
      **kwargs: Accepted and ignored.

    Raises:
      NotImplementedError: For any other ``phase``.
      AssertionError: If ``config.use_aux`` is set and the aux prediction
        file is missing from ``config.log_dir``.
    """
    self.augment_data = augment_data
    self.elastic_distortion = elastic_distortion
    self.config = config
    self.phase = phase
    # NOTE(review): `cache` is not a parameter of this method, so it has to
    # resolve to a module-level name -- confirm that is intended.
    VoxelizationDatasetBase.__init__(
        self,
        data_paths,
        input_transform=input_transform,
        target_transform=target_transform,
        cache=cache,
        data_root=data_root,
        ignore_mask=ignore_label,
        return_transformation=return_transformation)
    # The argument used to be dropped although __getitem__ reads
    # self.explicit_rotation. Assigned after the base initialiser so the
    # caller's value is the one that sticks.
    self.explicit_rotation = explicit_rotation

    self.sparse_voxelizer = SparseVoxelizer(
        voxel_size=config.voxel_size,
        clip_bound=self.CLIP_BOUND,
        use_augmentation=augment_data,
        scale_augmentation_bound=self.SCALE_AUGMENTATION_BOUND,
        rotation_augmentation_bound=self.ROTATION_AUGMENTATION_BOUND,
        translation_augmentation_ratio_bound=self.TRANSLATION_AUGMENTATION_RATIO_BOUND,
        rotation_axis=self.LOCFEAT_IDX,
        ignore_label=ignore_label)

    if self.phase == DatasetPhase.Train:
      self.split = 'train'
    elif self.phase == DatasetPhase.Val:
      self.split = 'val'
    else:
      raise NotImplementedError

    if hasattr(config, 'use_aux') and config.use_aux:
      # Joined as a path so a log_dir without a trailing separator still
      # points inside the directory.
      aux_path = os.path.join(self.config.log_dir,
                              'preds_{}.pth'.format(self.split))
      # The message previously printed a literal "{}" instead of the split.
      assert os.path.exists(aux_path), \
          "No Aux file found, link the `preds_{}` file into the log_dir".format(self.split)
      self.aux_data = torch.load(aux_path, 'cpu')['pred']

    # load_whole: read the whole split from one .pth file instead of the
    # individual ply files.
    self.load_whole = bool(hasattr(config, "load_whole") and config.load_whole)
    if self.load_whole:
      datapath = "/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles/"
      datapath = datapath + 'new_{}.pth'.format(self.split)
      self.data_dict = torch.load(datapath, 'cpu')

    # A ``None`` IGNORE_LABELS (which __getitem__ tolerates) is treated as
    # "nothing ignored" instead of raising a TypeError.
    ignored = self.IGNORE_LABELS if self.IGNORE_LABELS is not None else ()

    if not self.load_whole:
      # The pre-built .pth already has its labels mapped, so the map is only
      # needed when loading individual files: labels that are not evaluated
      # go to the ignore mask, the rest are renumbered from 0.
      label_map = {}
      n_used = 0
      for label in range(self.NUM_LABELS):
        if label in ignored:
          label_map[label] = self.ignore_mask
        else:
          label_map[label] = n_used
          n_used += 1
      label_map[self.ignore_mask] = self.ignore_mask
      self.label_map = label_map

    # In both modes NUM_LABELS counts only the evaluated classes from here on.
    self.NUM_LABELS -= len(ignored)

  def get_output_id(self, iteration):
    """Return the data path stored at position ``iteration``."""
    return self.data_paths[iteration]

  def convert_mat2cfl(self, mat):
    """Split a point matrix into (coords, feats, labels).

    The first three columns are returned as coordinates, the last column as
    labels and everything in between as features (generally xyz, rgb, label).
    """
    return mat[:, :3], mat[:, 3:-1], mat[:, -1]

  def _augment_elastic_distortion(self, pointcloud):
    """Apply every configured elastic distortion to ``pointcloud``.

    Returns ``pointcloud`` unchanged when ``ELASTIC_DISTORT_PARAMS`` is None.
    """
    if self.ELASTIC_DISTORT_PARAMS is not None:
      # NOTE(review): re-seeding the global RNG on every call makes the draw
      # below identical each time and resets the `random` state for all other
      # users of the module -- confirm this is intended and not debug residue.
      random.seed(123)
      if random.random() < 0.95:
        for granularity, magnitude in self.ELASTIC_DISTORT_PARAMS:
          pointcloud = t.elastic_distortion(pointcloud, granularity, magnitude)
    return pointcloud

  def __getitem__(self, index):
    """Load, optionally augment, voxelize and relabel one sample.

    Returns:
      ``(coords, feats, labels, unique_map, inverse_map)``; when
      ``self.return_transformation`` is set, the float32 point cloud and the
      float32 transformation (with a leading axis added) are appended. With
      ``config.use_aux`` the labels carry the aux predictions as a second
      column.
    """
    if self.explicit_rotation > 1:
      # Each stored sample is exposed explicit_rotation times, once per
      # evenly spaced angle in [-pi, pi).
      rotation_space = np.linspace(-np.pi, np.pi, self.explicit_rotation + 1)
      rotation_angle = rotation_space[index % self.explicit_rotation]
      index //= self.explicit_rotation
    else:
      rotation_angle = None

    pointcloud, center = self.load_ply(index)
    pointcloud = pointcloud.astype('float32')  # when load_whole, data from pth file stays float64
    if self.PREVOXELIZE_VOXEL_SIZE is not None:
      inds = ME.SparseVoxelize(pointcloud[:, :3] / self.PREVOXELIZE_VOXEL_SIZE, return_index=True)
      pointcloud = pointcloud[inds]

    if self.elastic_distortion:
      pointcloud = self._augment_elastic_distortion(pointcloud)

    coords, feats, labels = self.convert_mat2cfl(pointcloud)

    outs = self.sparse_voxelizer.voxelize(
        coords,
        feats,
        labels,
        center=center,
        rotation_angle=rotation_angle,
        return_transformation=self.return_transformation)

    if self.return_transformation:
      coords, feats, labels, unique_map, inverse_map, transformation = outs
      transformation = np.expand_dims(transformation, 0)
    else:
      coords, feats, labels, unique_map, inverse_map = outs

    # Read the optional config switches the same defensive way __init__
    # does, so a config without them behaves as "off" instead of raising.
    use_aux = getattr(self.config, 'use_aux', False)
    if use_aux:
      # Keep only the aux predictions of the points that survived
      # voxelization.
      aux = self.aux_data[index]
      if isinstance(aux, torch.Tensor):
        aux = aux[unique_map].numpy()
      else:
        aux = aux[unique_map]
      aux = aux + 1  # align with preds
      # Carry the aux predictions as a second label column.
      labels = np.stack([labels, aux], axis=1)

    if getattr(self.config, 'is_export', False):
      # Exporting disables both transforms (this sticks on the instance).
      self.input_transform = None
      self.target_transform = None

    if self.input_transform is not None:
      coords, feats, labels = self.input_transform(coords, feats, labels)
    if self.target_transform is not None:
      coords, feats, labels = self.target_transform(coords, feats, labels)

    # Map labels not used for evaluation to the ignore mask.
    if self.IGNORE_LABELS is not None:
      if self.load_whole:
        labels = labels - 1  # should align all labels to [-1, 20]
      else:
        # NOTE(review): with use_aux the labels are two-column here, so each
        # `x` is an array and cannot be used as a dict key -- confirm whether
        # use_aux is only ever combined with load_whole.
        # Builtin ``int`` replaces the ``np.int`` alias, which newer NumPy
        # releases no longer provide; the resulting dtype is the same.
        labels = np.array([self.label_map[x] for x in labels], dtype=int)

    return_args = [coords, feats, labels, unique_map, inverse_map]
    if self.return_transformation:
      return_args.extend([pointcloud.astype(np.float32), transformation.astype(np.float32)])
    return tuple(return_args)

  def cleanup(self):
    """Delegate cleanup to the sparse voxelizer."""
    self.sparse_voxelizer.cleanup()
    def __init__(self,
                 config,
                 train=True,
                 cylinder_voxelize=False,
                 sample_stride=1):
        """Read the nuScenes metadata for one split and build the voxelizer.

        Args:
            config: Object providing ``voxel_size`` for the voxelizer.
            train: Selects the 'train' split when true, otherwise "val".
            cylinder_voxelize: Stored as ``self.cylinder_voxelize``.
            sample_stride: Step used to sub-sample both the info list and the
                label-filename list.
        """
        # --- constructor arguments -----------------------------------------
        self.config = config
        self.train = train
        self.split = 'train' if self.train else "val"
        self.cylinder_voxelize = cylinder_voxelize
        self.sample_stride = sample_stride

        # --- dataset constants ---------------------------------------------
        self.NUM_IN_CHANNEL = 5
        self.NUM_LABELS = 16  # TODO: fix
        self.IGNORE_LABELS = [-1]  # DEBUG: not actually used
        self.NEED_PRED_POSTPROCESSING = False

        # --- file locations ------------------------------------------------
        self.metadata_pickle_path = "/data/eva_share_users/zhaotianchen/nuscenes/nuscenes_infos_{}.pkl".format(
            self.split)
        self.label_mapping_path = "/data/eva_share_users/zhaotianchen/nuscenes/nuscenes_label_mapping.yaml"
        self.label_filename_path = "/data/eva_share_users/zhaotianchen/nuscenes/nuscenes_label_filename_{}.pkl".format(
            self.split)
        self.data_path = "/data/eva_share_users/zhaotianchen/nuscenes"

        # --- metadata, label mapping and label file names --------------------
        with open(self.metadata_pickle_path, 'rb') as meta_file:
            split_meta = pickle.load(meta_file)
        with open(self.label_mapping_path, 'r') as mapping_file:
            mapping_cfg = yaml.safe_load(mapping_file)
        self.learning_map = mapping_cfg['learning_map']
        all_infos = split_meta['infos']

        # The label file names live in their own pickle, so the full nuScenes
        # index does not have to be loaded.
        with open(self.label_filename_path, 'rb') as names_file:
            all_label_names = pickle.load(names_file)

        # Keep every sample_stride-th entry of both lists.
        every_nth = slice(None, None, self.sample_stride)
        self.nusc_infos = all_infos[every_nth]
        self.label_filenames = all_label_names[every_nth]

        # --- voxel grid / augmentation settings ----------------------------
        self.grid_size = np.asarray([480, 360, 32])
        self.rotate_aug = False
        self.flip_aug = False
        self.scale_aug = False
        self.ignore_label = 0
        self.return_test = False
        self.fixed_volume_space = False
        self.max_volume_space = [50, np.pi, 3]
        self.min_volume_space = [0, -np.pi, -5]
        self.transform = False
        self.trans_std = [0.1, 0.1, 0.1]

        # One rotation noise value, drawn once per instance from [-pi/4, pi/4).
        quarter_pi = np.pi / 4
        self.noise_rotation = np.random.uniform(-quarter_pi, quarter_pi)

        self.sparse_voxelizer = SparseVoxelizer(
            voxel_size=config.voxel_size,
            clip_bound=None,
            use_augmentation=False,
            scale_augmentation_bound=None,
            rotation_augmentation_bound=None,
            translation_augmentation_ratio_bound=None,
            rotation_axis=0,  # this isn't actually used
            ignore_label=-1)
class Nuscenes(data.Dataset):
    def __init__(self,
                 config,
                 train=True,
                 cylinder_voxelize=False,
                 sample_stride=1,
                 data_root="/data/eva_share_users/zhaotianchen/nuscenes"):
        """Load the nuScenes metadata needed to serve samples.

        Args:
            config: configuration object; only ``config.voxel_size`` is read
                here (it is forwarded to ``SparseVoxelizer``).
            train: when true the ``'train'`` split files are loaded,
                otherwise the ``'val'`` ones.
            cylinder_voxelize: stored on the instance; ``__getitem__`` uses
                it to choose between the sparse voxelizer and the
                cylinder (polar grid) branch.
            sample_stride: keep every ``sample_stride``-th entry of both the
                info list and the label filename list.
            data_root: directory holding ``nuscenes_infos_<split>.pkl``,
                ``nuscenes_label_mapping.yaml`` and
                ``nuscenes_label_filename_<split>.pkl``. It is also stored
                as ``self.data_path``, the prefix for the per-sample files.
                The default reproduces the previously hard-coded location.
        """
        self.config = config
        self.train = train
        self.split = 'train' if self.train else "val"
        self.cylinder_voxelize = cylinder_voxelize
        self.sample_stride = sample_stride

        self.NUM_IN_CHANNEL = 5
        self.NUM_LABELS = 16  # TODO: fix
        self.IGNORE_LABELS = [-1]  # DEBUG: not actually used
        self.NEED_PRED_POSTPROCESSING = False

        # Every on-disk location is derived from the single dataset root.
        self.data_path = data_root
        self.metadata_pickle_path = os.path.join(
            data_root, "nuscenes_infos_{}.pkl".format(self.split))
        self.label_mapping_path = os.path.join(
            data_root, "nuscenes_label_mapping.yaml")
        self.label_filename_path = os.path.join(
            data_root, "nuscenes_label_filename_{}.pkl".format(self.split))

        # NOTE: pickle.load can execute arbitrary code while unpickling, so
        # these files must come from a trusted source.
        with open(self.metadata_pickle_path, 'rb') as f:
            metadata = pickle.load(f)
        with open(self.label_mapping_path, 'r') as stream:
            label_mapping = yaml.safe_load(stream)
        self.learning_map = label_mapping['learning_map']

        # The label filenames come from their own pickle, so the full
        # NuScenes devkit object does not have to be instantiated here.
        with open(self.label_filename_path, 'rb') as f:
            label_filenames = pickle.load(f)

        # Sub-sample both lists with the same stride so that they stay
        # index-aligned.
        self.nusc_infos = metadata['infos'][::self.sample_stride]
        self.label_filenames = label_filenames[::self.sample_stride]

        # Settings read by the cylinder branch of __getitem__.
        self.grid_size = np.asarray([480, 360, 32])
        self.rotate_aug = False
        self.flip_aug = False
        self.scale_aug = False
        self.ignore_label = 0
        self.return_test = False
        self.fixed_volume_space = False
        self.max_volume_space = [50, np.pi, 3]
        self.min_volume_space = [0, -np.pi, -5]
        self.transform = False
        self.trans_std = [0.1, 0.1, 0.1]

        # Drawn once per dataset instance.
        min_rad = -np.pi / 4
        max_rad = np.pi / 4
        self.noise_rotation = np.random.uniform(min_rad, max_rad)

        self.sparse_voxelizer = SparseVoxelizer(
            voxel_size=config.voxel_size,
            clip_bound=None,
            use_augmentation=False,
            scale_augmentation_bound=None,
            rotation_augmentation_bound=None,
            translation_augmentation_ratio_bound=None,
            rotation_axis=0,  # this isn't actually used
            ignore_label=-1)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.nusc_infos)

    def __getitem__(self, index):
        info = self.nusc_infos[index]
        lidar_path = info['lidar_path'][16:]
        # lidar_sd_token = self.nusc.get('sample', info['token'])['data']['LIDAR_TOP']
        # lidarseg_labels_filename = os.path.join(self.data_path,
        # self.nusc.get('lidarseg', lidar_sd_token)['filename'])
        lidarseg_labels_filename = os.path.join(self.data_path,
                                                self.label_filenames[index])
        points_label = np.fromfile(lidarseg_labels_filename,
                                   dtype=np.uint8).reshape([-1, 1])
        points_label = np.vectorize(
            self.learning_map.__getitem__)(points_label)
        points = np.fromfile(os.path.join(self.data_path, lidar_path),
                             dtype=np.float32,
                             count=-1).reshape([-1, 5])

        # data_tuple = (points[:, :3], points_label.astype(np.uint8))
        # if self.return_ref:
        # data_tuple += (points[:, 3],)

        # ------ split of 2 files: pc_dataset & dataset_nuscenes.py -------

        # data = self.point_cloud_dataset[index]
        # if len(data) == 2:
        # xyz, labels = data
        # elif len(data) == 3:
        # xyz, labels, sig = data
        # if len(sig.shape) == 2: sig = np.squeeze(sig)
        # else:
        # raise Exception('Return invalid data tuple

        xyz = points[:, :3]
        labels = np.squeeze(
            points_label - 1,
            axis=-1)  # INFO: the label is 0~16, minus one adnd ignore -1
        points[:, :3] = (points[:, :3] / points[:, :3].max() - 0.5
                         )  # normalize the coordinate xyz in feature

        if not self.cylinder_voxelize:
            xyz[:, :3] -= np.expand_dims(xyz[:, :3].min(0), axis=0)
            outs = self.sparse_voxelizer.voxelize(xyz,
                                                  points,
                                                  labels,
                                                  center=None,
                                                  rotation_angle=None,
                                                  return_transformation=False)

            # for visualization:
            # d = {}
            # d['coord_pt'] = xyz
            # d['label_pt'] = labels
            # d['coord_voxel'] = outs[0]
            # d['label_voxel'] = outs[2]
            # torch.save(d, './nuscene-voxel-demo.pth')
            # import ipdb; ipdb.set_trace()
            return outs  # (coord, feat, target)

        else:  # conduct cylinder-like voxelization
            # TODO: cylinder voxelization needs extra network part to align, hard to adapt, use naive way

            # random data augmentation by rotation
            if self.rotate_aug:
                rotate_rad = np.deg2rad(np.random.random() * 360) - np.pi
                c, s = np.cos(rotate_rad), np.sin(rotate_rad)
                j = np.matrix([[c, s], [-s, c]])
                xyz[:, :2] = np.dot(xyz[:, :2], j)

            # random data augmentation by flip x , y or x+y
            if self.flip_aug:
                flip_type = np.random.choice(4, 1)
                if flip_type == 1:
                    xyz[:, 0] = -xyz[:, 0]
                elif flip_type == 2:
                    xyz[:, 1] = -xyz[:, 1]
                elif flip_type == 3:
                    xyz[:, :2] = -xyz[:, :2]
            if self.scale_aug:
                noise_scale = np.random.uniform(0.95, 1.05)
                xyz[:, 0] = noise_scale * xyz[:, 0]
                xyz[:, 1] = noise_scale * xyz[:, 1]
            # convert coordinate into polar coordinates

            if self.transform:
                noise_translate = np.array([
                    np.random.normal(0, self.trans_std[0], 1),
                    np.random.normal(0, self.trans_std[1], 1),
                    np.random.normal(0, self.trans_std[2], 1)
                ]).T

                xyz[:, 0:3] += noise_translate

            xyz_pol = cart2polar(xyz)

            max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0)
            min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0)
            max_bound = np.max(xyz_pol[:, 1:], axis=0)
            min_bound = np.min(xyz_pol[:, 1:], axis=0)
            max_bound = np.concatenate(([max_bound_r], max_bound))
            min_bound = np.concatenate(([min_bound_r], min_bound))
            if self.fixed_volume_space:
                max_bound = np.asarray(self.max_volume_space)
                min_bound = np.asarray(self.min_volume_space)
            # get grid index
            crop_range = max_bound - min_bound
            cur_grid_size = self.grid_size
            intervals = crop_range / (cur_grid_size - 1)

            if (intervals == 0).any(): print("Zero interval!")
            grid_ind = (np.floor(
                (np.clip(xyz_pol, min_bound, max_bound) - min_bound) /
                intervals)).astype(np.int)

            voxel_position = np.zeros(self.grid_size, dtype=np.float32)
            dim_array = np.ones(len(self.grid_size) + 1, int)
            dim_array[0] = -1
            voxel_position = np.indices(self.grid_size) * intervals.reshape(
                dim_array) + min_bound.reshape(dim_array)
            voxel_position = polar2cat(voxel_position)

            # process labels
            processed_label = np.ones(self.grid_size,
                                      dtype=np.uint8) * self.ignore_label
            label_voxel_pair = np.concatenate([grid_ind, labels], axis=1)
            label_voxel_pair = label_voxel_pair[np.lexsort(
                (grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :]
            processed_label = nb_process_label(np.copy(processed_label),
                                               label_voxel_pair)
            data_tuple = (voxel_position, processed_label)

            # center data on each voxel for PTnet
            voxel_centers = (grid_ind.astype(np.float32) +
                             0.5) * intervals + min_bound
            return_xyz = xyz_pol - voxel_centers
            return_xyz = np.concatenate((return_xyz, xyz_pol, xyz[:, :2]),
                                        axis=1)

            # if len(data) == 2:
            # return_fea = return_xyz
            # elif len(data) == 3:
            # return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1)
            return_fea = return_xyz

            if self.return_test:
                data_tuple += (grid_ind, labels, return_fea, index)
            else:
                data_tuple += (grid_ind, labels, return_fea)

            d1 = {
                'coord': grid_ind,
                'label': labels,
            }
            torch.save(d1, './nuscene-cylinder-voxel.pth')

            import ipdb
            ipdb.set_trace()

            return data_tuple

    # def __len__(self):
    # if not self.whole_scene:
    # return len(self.data_dict['data'])
    # else:
    # return len(self.filelists)

    def cleanup(self):
        self.sparse_voxelizer.cleanup()

    def reorder_result(self, result):  # used for valid
        return result