예제 #1
0
def get_inference_dataloader(
        dataset: Dataset,
        transforms: Callable,
        batch_size: int = 16,
        num_workers: int = 8,
        pin_memory: bool = True,
        limit_num_samples: Optional[int] = None) -> DataLoader:
    """Build a non-shuffled DataLoader for inference.

    Args:
        dataset: dataset instance to iterate over.
        transforms: callable applied to every sample via ``TransformedDataset``.
        batch_size: number of samples per batch.
        num_workers: number of DataLoader worker processes.
        pin_memory: forwarded to ``DataLoader``.
        limit_num_samples: if given, only a random subset of this many samples
            is used. The subset is drawn from a private RNG seeded with this
            value, so it is reproducible across runs.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``drop_last=False``. When
        ``torch.distributed`` is initialized, a non-shuffling
        ``DistributedSampler`` splits the samples across processes.
    """
    if limit_num_samples is not None:
        # Private RandomState instead of np.random.seed(): reseeding the global
        # NumPy RNG here would silently reset the caller's random state. A
        # RandomState seeded with the same value yields the same permutation.
        rng = np.random.RandomState(limit_num_samples)
        indices = rng.permutation(len(dataset))[:limit_num_samples]
        dataset = Subset(dataset, indices)

    dataset = TransformedDataset(dataset, transform_fn=transforms)

    sampler = None
    if dist.is_available() and dist.is_initialized():
        # NOTE(review): DistributedSampler may pad with repeated samples so that
        # every rank gets the same count; deduplicate predictions if that matters.
        sampler = data_dist.DistributedSampler(dataset, shuffle=False)

    loader = DataLoader(dataset,
                        shuffle=False,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        sampler=sampler,
                        pin_memory=pin_memory,
                        drop_last=False)
    return loader
예제 #2
0
def get_trainval_datasets(path, csv_path, train_folds, val_folds, read_img_mask_fn=read_img_mask):
    """Build the train and validation datasets for a fold-based split.

    Rows flagged in the CSV ``skip`` column are dropped. The remaining rows are
    grouped into folds with ``get_fold_indices_dict``, and the indices of
    ``train_folds`` / ``val_folds`` are concatenated (in the order given) to
    form the two splits of ``UnoSatTiles(path)``.

    Args:
        path: root passed to ``UnoSatTiles``.
        csv_path: CSV file with at least ``skip`` and ``fold_index`` columns.
        train_folds: iterable of fold keys used for training.
        val_folds: iterable of fold keys used for validation.
        read_img_mask_fn: sample-reading transform applied to both splits.

    Returns:
        ``(train_ds, val_ds)``, each wrapped in ``TransformedDataset``.
    """
    tiles = UnoSatTiles(path)

    metadata = pd.read_csv(csv_path)
    keep_mask = ~metadata['skip']  # drop tiles marked to be skipped
    metadata = metadata[keep_mask].copy()
    fold_count = len(metadata['fold_index'].unique())

    fold_to_indices = get_fold_indices_dict(metadata, fold_count)
    train_indices = [idx for fold in train_folds for idx in fold_to_indices[fold]]
    val_indices = [idx for fold in val_folds for idx in fold_to_indices[fold]]

    train_split, val_split = train_val_split(tiles, train_indices, val_indices)

    # Attach the image/mask reading step to both splits.
    return (
        TransformedDataset(train_split, transform_fn=read_img_mask_fn),
        TransformedDataset(val_split, transform_fn=read_img_mask_fn),
    )
#################### Globals ####################

seed = 12
debug = False
device = 'cuda'

num_classes = 2

#################### Dataflow ####################

# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would turn a missing variable into a bare KeyError on the next line.
if "INPUT_PATH" not in os.environ:
    raise RuntimeError("Environment variable INPUT_PATH must be set to the input data root")
data_path = os.path.join(os.environ['INPUT_PATH'], "test_tiles")

# Test tiles, with `read_img_in_db` applied to each sample.
test_dataset = UnoSatTestTiles(data_path)
test_dataset = TransformedDataset(test_dataset, transform_fn=read_img_in_db)

batch_size = 4
num_workers = 12

# Per-channel normalization statistics (3 channels).
# NOTE(review): presumably computed offline on the training tiles in the units
# produced by `read_img_in_db` (e.g. via get_train_mean_std) — TODO confirm.
mean = [-17.398721187929123, -10.020421713800838, -12.10841437771272]
std = [6.290316422115964, 5.776936185931195, 5.795418280085563]
# Used as max_pixel_value, so A.Normalize computes (x - mean) / std on the
# values as-is, without rescaling.
max_value = 1.0

transforms = A.Compose(
    [A.Normalize(mean=mean, std=std, max_pixel_value=max_value),
     ToTensorV2()])

data_loader = get_inference_dataloader(
    test_dataset,
    transforms=transforms,
예제 #4
0
def get_train_val_loaders(
    train_ds: Dataset,
    val_ds: Dataset,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    pin_memory: bool = True,
    train_sampler: Optional[Sampler] = None,
    val_sampler: Optional[Sampler] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and train-evaluation DataLoaders.

    Args:
        train_ds: raw training dataset; wrapped with ``train_transforms``.
        val_ds: raw validation dataset; wrapped with ``val_transforms``.
        train_transforms: per-sample transform for training.
        val_transforms: per-sample transform for validation and train-eval.
        batch_size: training batch size.
        num_workers: DataLoader worker processes (all three loaders).
        val_batch_size: batch size for val/train-eval; defaults to ``4 * batch_size``.
        pin_memory: forwarded to every ``DataLoader``.
        train_sampler: optional training sampler. Under ``torch.distributed`` it
            is wrapped in ``DistributedProxySampler``; if absent there, a
            ``DistributedSampler`` is used.
        val_sampler: optional validation sampler, handled like ``train_sampler``
            (non-shuffling ``DistributedSampler`` by default when distributed).
        limit_train_num_samples: if given, use a reproducible random subset of
            this size (private RNG seeded with this value).
        limit_val_num_samples: same, for the validation set.

    Returns:
        ``(train_loader, val_loader, train_eval_loader)``. ``train_eval_loader``
        iterates a random subset of ``len(val_ds)`` training samples (or the
        whole training set when it is not larger) with ``val_transforms``.
    """

    def _random_subset(ds, num_samples):
        # Private RandomState: same permutation as np.random.seed(num_samples)
        # + np.random.permutation, without reseeding the caller's global RNG.
        rng = np.random.RandomState(num_samples)
        return Subset(ds, rng.permutation(len(ds))[:num_samples])

    if limit_train_num_samples is not None:
        train_ds = _random_subset(train_ds, limit_train_num_samples)

    if limit_val_num_samples is not None:
        val_ds = _random_subset(val_ds, limit_val_num_samples)

    # Random training samples used to evaluate the model on training data.
    # Drawn from the global NumPy RNG, so it follows the caller's seeding.
    if len(val_ds) < len(train_ds):
        train_eval_indices = np.random.permutation(len(train_ds))[:len(val_ds)]
        train_eval_ds = Subset(train_ds, train_eval_indices)
    else:
        train_eval_ds = train_ds

    train_ds = TransformedDataset(train_ds, transform_fn=train_transforms)
    val_ds = TransformedDataset(val_ds, transform_fn=val_transforms)
    train_eval_ds = TransformedDataset(train_eval_ds,
                                       transform_fn=val_transforms)

    # The train-eval loader needs its own sampler. Reusing val_sampler (built
    # for val_ds) yields indices of the wrong dataset: out of range when
    # len(val_ds) > len(train_eval_ds), and a user-supplied val sampler has
    # no meaning on training data.
    train_eval_sampler = None
    if dist.is_available() and dist.is_initialized():
        if train_sampler is not None:
            train_sampler = DistributedProxySampler(train_sampler)
        else:
            train_sampler = data_dist.DistributedSampler(train_ds)

        if val_sampler is not None:
            val_sampler = DistributedProxySampler(val_sampler)
        else:
            val_sampler = data_dist.DistributedSampler(val_ds, shuffle=False)

        train_eval_sampler = data_dist.DistributedSampler(train_eval_ds,
                                                          shuffle=False)

    train_loader = DataLoader(train_ds,
                              shuffle=train_sampler is None,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              sampler=train_sampler,
                              pin_memory=pin_memory,
                              drop_last=True)

    if val_batch_size is None:
        val_batch_size = batch_size * 4
    val_loader = DataLoader(val_ds,
                            shuffle=False,
                            sampler=val_sampler,
                            batch_size=val_batch_size,
                            num_workers=num_workers,
                            pin_memory=pin_memory,
                            drop_last=False)

    train_eval_loader = DataLoader(train_eval_ds,
                                   shuffle=False,
                                   sampler=train_eval_sampler,
                                   batch_size=val_batch_size,
                                   num_workers=num_workers,
                                   pin_memory=pin_memory,
                                   drop_last=False)

    return train_loader, val_loader, train_eval_loader
예제 #5
0
def get_train_mean_std(train_dataset, unique_id="", cache_dir="/tmp/unosat/"):
    """Return the per-channel mean and std of ``train_dataset``, cached on disk.

    The statistics are loaded from
    ``<cache_dir>/train_mean_std_[<unique_id>_]<len(train_dataset)>.pth`` when
    that file exists. Otherwise they are computed in one pass over the dataset
    and saved there.

    Args:
        train_dataset: dataset whose samples, after ``ToTensorV2``, provide an
            ``'image'`` entry.
        unique_id: optional tag distinguishing datasets of the same length.
        cache_dir: directory for the cache file; created if missing.

    Returns:
        ``(mean, std)`` as lists of Python floats, one value per channel.

    Raises:
        RuntimeError: if the statistics have to be computed while
            ``torch.distributed`` is initialized (not supported).
    """
    cache_dir = Path(cache_dir)
    # exist_ok avoids a race between an existence check and mkdir when several
    # processes share the cache directory.
    cache_dir.mkdir(parents=True, exist_ok=True)

    # The "_" separator goes between the tag and the length. The former
    # "{len}{tag}_" order produced names like "train_mean_std_100abc_.pth" and
    # let different (len, tag) pairs collide, e.g. (1, "23") and (12, "3").
    prefix = "{}_".format(unique_id) if unique_id else ""
    fp = cache_dir / "train_mean_std_{}{}.pth".format(prefix, len(train_dataset))

    if fp.exists():
        # NOTE(review): torch.load unpickles the file. A cache in a shared,
        # world-writable directory such as /tmp could be tampered with; prefer a
        # private cache_dir (or weights_only=True on recent PyTorch).
        mean_std = torch.load(fp.as_posix())
    else:
        # Only the computation is unsupported in distributed mode; an existing
        # cache file can still be loaded by every process.
        if dist.is_available() and dist.is_initialized():
            raise RuntimeError(
                "Current implementation of Mean/Std computation is not working in distrib config"
            )
        mean_std = _compute_train_mean_std(train_dataset)
        torch.save(mean_std, fp.as_posix())

    return mean_std['mean'].tolist(), mean_std['std'].tolist()


def _compute_train_mean_std(train_dataset):
    """Compute per-channel mean and std of ``train_dataset`` in a single pass.

    Each sample goes through ``ToTensorV2``. Its ``'image'`` entry is reshaped
    to ``(batch, channels, -1)``, and per-image channel means of ``x`` and
    ``x**2`` are averaged with ignite ``Average`` metrics. The results are then
    combined as ``std = sqrt(E[x^2] - E[x]^2)``.

    NOTE(review): averaging per-image means equals the pixel-level statistics
    only if all images have the same number of pixels.

    Returns:
        dict with ``'mean'`` and ``'std'`` tensors.
    """
    from ignite.engine import Engine
    from ignite.metrics import Average
    from ignite.contrib.handlers import ProgressBar
    from albumentations.pytorch import ToTensorV2

    train_dataset = TransformedDataset(train_dataset,
                                       transform_fn=ToTensorV2())
    train_loader = DataLoader(train_dataset,
                              shuffle=False,
                              drop_last=False,
                              batch_size=16,
                              num_workers=10,
                              pin_memory=False)

    def compute_mean_std(engine, batch):
        b, c, *_ = batch['image'].shape
        # float64 limits cancellation error in E[x^2] - E[x]^2.
        data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
        return {
            "mean": torch.mean(data, dim=-1),
            "mean^2": torch.mean(data ** 2, dim=-1),
        }

    compute_engine = Engine(compute_mean_std)
    ProgressBar(desc="Compute Mean/Std").attach(compute_engine)
    img_mean = Average(output_transform=lambda output: output['mean'])
    img_mean2 = Average(output_transform=lambda output: output['mean^2'])
    img_mean.attach(compute_engine, 'mean')
    img_mean2.attach(compute_engine, 'mean2')
    state = compute_engine.run(train_loader)

    mean = state.metrics['mean']
    # Rounding can make the variance estimate slightly negative, which would
    # become NaN under sqrt; clamp it at zero.
    var = torch.clamp(state.metrics['mean2'] - mean ** 2, min=0.0)
    return {'mean': mean, 'std': torch.sqrt(var)}