Beispiel #1
0
def get_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:
    """Build the dataset selected by the command-line arguments.

    Exactly one input source must be specified: either ``args.dataset``
    (a string resolved to a dataset factory via ``str_to_object``) or
    ``args.path_tif`` (a single TIFF path, or a directory whose files are
    listed via ``files_from_dir``).

    Parameters
    ----------
    args
        Parsed arguments. Reads ``dataset`` and ``path_tif``; when
        ``dataset`` is set, also reads ``dataset_kwargs``, where a missing
        or ``None`` value is treated as "no keyword arguments".

    Returns
    -------
    torch.utils.data.Dataset
        Dataset object. For the TIFF source this is a ``TiffDataset`` whose
        signal column is ``'path_bf'``, with an all-``None`` ``'path_target'``
        column and ``norm_around_center`` as signal and target transform.

    Raises
    ------
    ValueError
        If zero or both input sources are specified, if ``args.dataset``
        does not resolve to a callable, or if ``args.path_tif`` does not
        exist.

    """
    if sum([args.dataset is not None, args.path_tif is not None]) != 1:
        raise ValueError('Must specify one input source type')

    if args.dataset is not None:
        ds_fn = str_to_object(args.dataset)
        # Builtin callable() is the idiomatic check; no typing import needed.
        if not callable(ds_fn):
            raise ValueError(f'{args.dataset} must be callable')
        # Tolerate an absent/None dataset_kwargs instead of failing with
        # "argument after ** must be a mapping".
        ds_kwargs = getattr(args, 'dataset_kwargs', None) or {}
        return ds_fn(**ds_kwargs)

    # After the exclusivity check above, path_tif is the only source left.
    if not os.path.exists(args.path_tif):
        raise ValueError(f'Path does not exist: {args.path_tif}')
    if os.path.isdir(args.path_tif):
        paths_tif = files_from_dir(args.path_tif)
    else:
        paths_tif = [args.path_tif]
    return TiffDataset(
        dataframe=pd.DataFrame({
            'path_bf': paths_tif,
            # No targets for this source; the scalar broadcasts to every row.
            'path_target': None,
        }),
        transform_signal=[norm_around_center],
        transform_target=[norm_around_center],
        col_signal='path_bf',
    )
Beispiel #2
0
def DummyFnetDataset(train: bool = False) -> TiffDataset:
    """Build a tiny single-row ``TiffDataset`` for tests.

    The base dataframe has one row pairing ``data/EM_low.tif`` (column
    ``path_signal``) with ``data/MBP_low.tif`` (column ``path_target``).
    The paths are relative; presumably ``TiffDataset`` resolves them against
    the current working directory — confirm.

    Parameters
    ----------
    train
        When False (the default), the dataframe is passed through
        ``add_augmentations`` before the dataset is built. When True, the
        single un-augmented row is used as-is.

    Returns
    -------
    TiffDataset
        Dataset wrapping the (possibly augmented) dataframe.
    """
    data_dir = "data"
    columns = {
        "path_signal": [os.path.join(data_dir, "EM_low.tif")],
        "path_target": [os.path.join(data_dir, "MBP_low.tif")],
    }
    # Names the index "arbitrary"; the reason is not evident here
    # (presumably to exercise index-name handling downstream — verify).
    frame = pd.DataFrame(columns).rename_axis("arbitrary")
    # NOTE(review): augmentations are applied only when train is False,
    # which looks inverted relative to the usual convention — confirm intent.
    chosen = frame if train else add_augmentations(frame)
    return TiffDataset(dataframe=chosen)