def test(config, intensity=False):
    """Visually smoke-test the point cloud data loader.

    Builds a ``FacilityArea5Dataset`` with elastic distortion and chromatic
    augmentations, wraps it in a shuffled ``DataLoader`` with batch size 1,
    and then, 100 times in a row, fetches one batch, shows it in an Open3D
    viewer window and prints the value returned by ``timer.toc()``.

    Args:
        config: Configuration object. This function reads
            ``config.data_aug_color_trans_ratio`` and
            ``config.data_aug_color_jitter_std``; the object is also passed
            to the dataset constructor.
        intensity: When true, ``make_pcd`` additionally tries to attach the
            feature columns after the first three to the point cloud (see
            the review note inside ``make_pcd``).
    """
    from torch.utils.data import DataLoader
    from lib.utils import Timer
    import open3d as o3d

    def make_pcd(coords, feats):
        """Build an Open3D point cloud from a coords/feats batch."""
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(coords[:, :3].float().numpy())
        # The first three feature columns are scaled by 1/255 and used as
        # the point colors.
        pcd.colors = o3d.utility.Vector3dVector(feats[:, :3].numpy() / 255)
        if intensity:
            # NOTE(review): ``feats[:, 3:3]`` is an empty slice, so no
            # intensity column is actually selected here. Presumably a
            # range such as ``3:4`` was intended, and it is unclear whether
            # ``PointCloud`` accepts an ``intensities`` attribute at all --
            # confirm against the feature layout and the Open3D API before
            # relying on ``intensity=True``.
            pcd.intensities = o3d.utility.Vector3dVector(feats[:, 3:3].numpy())
        return pcd

    timer = Timer()
    DatasetClass = FacilityArea5Dataset
    transformations = [
        t.RandomHorizontalFlip(DatasetClass.ROTATION_AXIS,
                               DatasetClass.IS_TEMPORAL),
        t.ChromaticAutoContrast(),
        t.ChromaticTranslation(config.data_aug_color_trans_ratio),
        t.ChromaticJitter(config.data_aug_color_jitter_std),
    ]

    dataset = DatasetClass(
        config,
        prevoxel_transform=t.ElasticDistortion(
            DatasetClass.ELASTIC_DISTORT_PARAMS),
        input_transform=t.Compose(transformations),
        augment_data=True,
        cache=True,
        elastic_distortion=True)

    data_loader = DataLoader(
        dataset=dataset,
        collate_fn=t.cfl_collate_fn_factory(limit_numpoints=False),
        batch_size=1,
        shuffle=True)

    # Drive the iterator by hand (instead of a ``for`` loop over the loader)
    # so that the tic/toc pair brackets the batch fetch itself.
    batch_iter = iter(data_loader)
    for _ in range(100):
        timer.tic()
        # The builtin ``next`` is used because Python 3 iterators do not
        # provide a ``.next()`` method.
        coords, feats, labels = next(batch_iter)
        pcd = make_pcd(coords, feats)
        o3d.visualization.draw_geometries([pcd])
        print(timer.toc())
def initialize_data_loader(DatasetClass, config, phase, num_workers, shuffle,
                           repeat, augment_data, batch_size, limit_numpoints,
                           input_transform=None, target_transform=None):
    """Construct a dataset of type ``DatasetClass`` and wrap it in a DataLoader.

    Args:
        DatasetClass: Dataset class to instantiate. Its
            ``ELASTIC_DISTORT_PARAMS``, ``ROTATION_AXIS`` and ``IS_TEMPORAL``
            attributes are read when ``augment_data`` is true.
        config: Configuration object. Reads ``return_transformation``,
            ``cache_data`` and, when augmenting,
            ``data_aug_color_trans_ratio`` and ``data_aug_color_jitter_std``.
            It is also passed to the dataset constructor.
        phase: Dataset phase; a ``str`` is first converted with
            ``str2datasetphase_type``.
        num_workers: Forwarded to the ``DataLoader``.
        shuffle: Forwarded to ``InfSampler`` when ``repeat`` is true,
            otherwise to the ``DataLoader`` itself.
        repeat: When true, the loader samples through
            ``InfSampler(dataset, shuffle)`` instead of using ``shuffle``.
        augment_data: When true, elastic distortion is applied as the
            pre-voxelization transform and the random dropout / flip /
            chromatic transforms are appended to the input transforms.
        batch_size: Forwarded to the ``DataLoader``.
        limit_numpoints: Forwarded to the collate-function factory.
        input_transform: Optional iterable of transforms that run before
            the augmentation transforms.
        target_transform: Optional transform forwarded to the dataset.

    Returns:
        The configured ``DataLoader``.
    """
    if isinstance(phase, str):
        phase = str2datasetphase_type(phase)

    # Pick the collate factory according to config.return_transformation.
    make_collate = (t.cflt_collate_fn_factory
                    if config.return_transformation
                    else t.cfl_collate_fn_factory)
    collate_fn = make_collate(limit_numpoints)

    # Pre-voxelization pipeline: elastic distortion only, and only when
    # augmenting; otherwise the dataset receives None.
    prevoxel_steps = (
        [t.ElasticDistortion(DatasetClass.ELASTIC_DISTORT_PARAMS)]
        if augment_data else [])
    prevoxel_transforms = t.Compose(prevoxel_steps) if prevoxel_steps else None

    # Input pipeline: caller-supplied transforms first, augmentations after.
    input_steps = []
    if input_transform is not None:
        input_steps.extend(input_transform)
    if augment_data:
        input_steps.extend([
            t.RandomDropout(0.2),
            t.RandomHorizontalFlip(DatasetClass.ROTATION_AXIS,
                                   DatasetClass.IS_TEMPORAL),
            t.ChromaticAutoContrast(),
            t.ChromaticTranslation(config.data_aug_color_trans_ratio),
            t.ChromaticJitter(config.data_aug_color_jitter_std),
            # Currently disabled:
            # t.HueSaturationTranslation(config.data_aug_hue_max, config.data_aug_saturation_max),
        ])
    input_transforms = t.Compose(input_steps) if input_steps else None

    dataset = DatasetClass(
        config,
        prevoxel_transform=prevoxel_transforms,
        input_transform=input_transforms,
        target_transform=target_transform,
        cache=config.cache_data,
        augment_data=augment_data,
        phase=phase)

    loader_kwargs = dict(
        dataset=dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=collate_fn,
    )
    # Either a sampler or the shuffle flag is passed, never both.
    if repeat:
        loader_kwargs['sampler'] = InfSampler(dataset, shuffle)
    else:
        loader_kwargs['shuffle'] = shuffle

    return DataLoader(**loader_kwargs)