import os
from pathlib import Path
from typing import Callable, Optional, Tuple

import albumentations as A
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.utils.data.distributed as data_dist
from albumentations.pytorch import ToTensorV2
from ignite.distributed import DistributedProxySampler  # ships with ignite >= 0.4
from torch.utils.data import DataLoader, Dataset, Sampler, Subset

# Project-local helpers (TransformedDataset, UnoSatTiles, UnoSatTestTiles,
# read_img_mask, read_img_in_db, get_fold_indices_dict, train_val_split) are
# assumed importable from this repository's own dataflow modules.


def get_inference_dataloader(
    dataset: Dataset,
    transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    pin_memory: bool = True,
    limit_num_samples: Optional[int] = None,
) -> DataLoader:

    if limit_num_samples is not None:
        # Seed with the sample limit so the same subset is drawn on every run
        np.random.seed(limit_num_samples)
        indices = np.random.permutation(len(dataset))[:limit_num_samples]
        dataset = Subset(dataset, indices)

    dataset = TransformedDataset(dataset, transform_fn=transforms)

    # Shard the dataset across processes when running distributed
    sampler = None
    if dist.is_available() and dist.is_initialized():
        sampler = data_dist.DistributedSampler(dataset, shuffle=False)

    loader = DataLoader(dataset, shuffle=False, batch_size=batch_size,
                        num_workers=num_workers, sampler=sampler,
                        pin_memory=pin_memory, drop_last=False)
    return loader
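
# Hedged consumption sketch for get_inference_dataloader: batches are dicts
# with an 'image' key (the convention used elsewhere in this file); `model`
# stands for any segmentation network and is not a repo-specific class.
def _example_predict(model, loader, device="cuda"):
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in loader:
            x = batch['image'].to(device)
            preds.append(model(x).argmax(dim=1).cpu())  # per-pixel class ids
    return preds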
def get_trainval_datasets(path, csv_path, train_folds, val_folds,
                          read_img_mask_fn=read_img_mask):
    ds = UnoSatTiles(path)

    df = pd.read_csv(csv_path)
    # Remove tiles to skip
    df = df[~df['skip']].copy()
    num_folds = len(df['fold_index'].unique())

    def _get_indices(fold_indices, folds_indices_dict):
        indices = []
        for i in fold_indices:
            indices += folds_indices_dict[i]
        return indices

    folds_indices_dict = get_fold_indices_dict(df, num_folds)
    train_indices = _get_indices(train_folds, folds_indices_dict)
    val_indices = _get_indices(val_folds, folds_indices_dict)
    train_ds, val_ds = train_val_split(ds, train_indices, val_indices)

    # Include data reading transformation
    train_ds = TransformedDataset(train_ds, transform_fn=read_img_mask_fn)
    val_ds = TransformedDataset(val_ds, transform_fn=read_img_mask_fn)
    return train_ds, val_ds
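
# Hedged usage note for get_trainval_datasets: the CSV is expected to carry a
# boolean 'skip' column and an integer 'fold_index' column, one row per tile.
# The paths and fold assignment below are placeholders, not project config.
def _example_fold_split():
    return get_trainval_datasets("/data/unosat/tiles",      # hypothetical path
                                 "/data/unosat/folds.csv",  # hypothetical CSV
                                 train_folds=[0, 1, 2], val_folds=[3])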
#################### Globals ####################

seed = 12
debug = False
device = 'cuda'
num_classes = 2

#################### Dataflow ####################

assert "INPUT_PATH" in os.environ
data_path = os.path.join(os.environ['INPUT_PATH'], "test_tiles")

test_dataset = UnoSatTestTiles(data_path)
test_dataset = TransformedDataset(test_dataset, transform_fn=read_img_in_db)

batch_size = 4
num_workers = 12

# Per-channel stats of the inputs (read in dB by read_img_in_db, hence the
# negative means)
mean = [-17.398721187929123, -10.020421713800838, -12.10841437771272]
std = [6.290316422115964, 5.776936185931195, 5.795418280085563]
max_value = 1.0

transforms = A.Compose([
    A.Normalize(mean=mean, std=std, max_pixel_value=max_value),
    ToTensorV2()
])

data_loader = get_inference_dataloader(
    test_dataset,
    transforms=transforms,
    batch_size=batch_size,
    num_workers=num_workers,
)
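
# One-batch sanity check of the test dataflow above (hedged; assumes the
# dict-style batches with an 'image' key used elsewhere in this file).
def _example_check_batch():
    batch = next(iter(data_loader))
    img = batch['image']
    print("image batch:", tuple(img.shape))  # expected (B, 3, H, W) after ToTensorV2
    assert img.shape[1] == len(mean)  # channels match the normalization stats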
def get_train_val_loaders(
    train_ds: Dataset,
    val_ds: Dataset,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    pin_memory: bool = True,
    train_sampler: Optional[Sampler] = None,
    val_sampler: Optional[Sampler] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:

    if limit_train_num_samples is not None:
        np.random.seed(limit_train_num_samples)
        train_indices = np.random.permutation(len(train_ds))[:limit_train_num_samples]
        train_ds = Subset(train_ds, train_indices)

    if limit_val_num_samples is not None:
        np.random.seed(limit_val_num_samples)
        val_indices = np.random.permutation(len(val_ds))[:limit_val_num_samples]
        val_ds = Subset(val_ds, val_indices)

    # Random samples for evaluation on training dataset
    if len(val_ds) < len(train_ds):
        train_eval_indices = np.random.permutation(len(train_ds))[:len(val_ds)]
        train_eval_ds = Subset(train_ds, train_eval_indices)
    else:
        train_eval_ds = train_ds

    train_ds = TransformedDataset(train_ds, transform_fn=train_transforms)
    val_ds = TransformedDataset(val_ds, transform_fn=val_transforms)
    train_eval_ds = TransformedDataset(train_eval_ds, transform_fn=val_transforms)

    if dist.is_available() and dist.is_initialized():
        # Wrap user-provided samplers for DDP, otherwise shard with a
        # DistributedSampler
        if train_sampler is not None:
            train_sampler = DistributedProxySampler(train_sampler)
        else:
            train_sampler = data_dist.DistributedSampler(train_ds)

        if val_sampler is not None:
            val_sampler = DistributedProxySampler(val_sampler)
        else:
            val_sampler = data_dist.DistributedSampler(val_ds, shuffle=False)

    train_loader = DataLoader(train_ds, shuffle=train_sampler is None,
                              batch_size=batch_size, num_workers=num_workers,
                              sampler=train_sampler, pin_memory=pin_memory,
                              drop_last=True)

    val_batch_size = batch_size * 4 if val_batch_size is None else val_batch_size
    val_loader = DataLoader(val_ds, shuffle=False, sampler=val_sampler,
                            batch_size=val_batch_size, num_workers=num_workers,
                            pin_memory=pin_memory, drop_last=False)

    # train_eval_ds is capped at len(val_ds) samples, so the val sampler is reused
    train_eval_loader = DataLoader(train_eval_ds, shuffle=False, sampler=val_sampler,
                                   batch_size=val_batch_size, num_workers=num_workers,
                                   pin_memory=pin_memory, drop_last=False)

    return train_loader, val_loader, train_eval_loader
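
# Hedged sketch feeding the fold split into get_train_val_loaders. The
# albumentations pipelines are illustrative placeholders; real per-channel
# stats could instead come from get_train_mean_std below.
def _example_training_dataflow(train_ds, val_ds):
    train_transforms = A.Compose([
        A.HorizontalFlip(p=0.5),          # illustrative augmentation
        A.Normalize(max_pixel_value=1.0),
        ToTensorV2(),
    ])
    val_transforms = A.Compose([
        A.Normalize(max_pixel_value=1.0),
        ToTensorV2(),
    ])
    return get_train_val_loaders(train_ds, val_ds,
                                 train_transforms, val_transforms,
                                 batch_size=16, num_workers=8)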
def get_train_mean_std(train_dataset, unique_id="", cache_dir="/tmp/unosat/"):
    # # Ensure that only process 0 in distributed performs the computation,
    # # and the others will use the cache
    # if dist.get_rank() > 0:
    #     torch.distributed.barrier()  # synchronization point for all processes > 0

    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        cache_dir.mkdir(parents=True)

    if len(unique_id) > 0:
        unique_id += "_"

    fp = cache_dir / "train_mean_std_{}{}.pth".format(len(train_dataset), unique_id)
    if fp.exists():
        mean_std = torch.load(fp.as_posix())
    else:
        if dist.is_available() and dist.is_initialized():
            raise RuntimeError("Current implementation of Mean/Std computation "
                               "is not working in distrib config")

        from ignite.engine import Engine
        from ignite.metrics import Average
        from ignite.contrib.handlers import ProgressBar
        from albumentations.pytorch import ToTensorV2

        train_dataset = TransformedDataset(train_dataset, transform_fn=ToTensorV2())
        train_loader = DataLoader(train_dataset, shuffle=False, drop_last=False,
                                  batch_size=16, num_workers=10, pin_memory=False)

        def compute_mean_std(engine, batch):
            # Accumulate per-channel E[x] and E[x^2]; std is derived afterwards
            b, c, *_ = batch['image'].shape
            data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
            mean = torch.mean(data, dim=-1)
            mean2 = torch.mean(data ** 2, dim=-1)
            return {"mean": mean, "mean^2": mean2}

        compute_engine = Engine(compute_mean_std)
        ProgressBar(desc="Compute Mean/Std").attach(compute_engine)
        img_mean = Average(output_transform=lambda output: output['mean'])
        img_mean2 = Average(output_transform=lambda output: output['mean^2'])
        img_mean.attach(compute_engine, 'mean')
        img_mean2.attach(compute_engine, 'mean2')
        state = compute_engine.run(train_loader)

        # Var[x] = E[x^2] - (E[x])^2
        state.metrics['std'] = torch.sqrt(state.metrics['mean2'] -
                                          state.metrics['mean'] ** 2)

        mean_std = {'mean': state.metrics['mean'], 'std': state.metrics['std']}
        # if dist.get_rank() < 1:
        torch.save(mean_std, fp.as_posix())

    # if dist.get_rank() < 1:
    #     torch.distributed.barrier()  # synchronization point for process 0

    return mean_std['mean'].tolist(), mean_std['std'].tolist()
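
# Hedged example connecting get_train_mean_std to the normalization step:
# compute (or load cached) per-channel statistics on the raw training tiles,
# then use them just like the hard-coded mean/std constants above. The
# unique_id value is a hypothetical cache tag.
def _example_normalize_from_stats(train_ds):
    mean, std = get_train_mean_std(train_ds, unique_id="folds012")
    return A.Compose([
        A.Normalize(mean=mean, std=std, max_pixel_value=1.0),
        ToTensorV2(),
    ])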