Exemple #1
0
def get_root_nucleus_loader(
    path,
    patch_shape,
    samples=None,
    target_resolution=None,
    download=False,
    offsets=None,
    boundaries=False,
    binary=False,
    **kwargs,
):
    """Create a 3D segmentation loader over the root nucleus ``*.h5`` files in ``path``.

    The data is requested through
    ``_download_plantseg(path, download, "root", "nuclei")``. When
    ``target_resolution`` is given, ``_resize`` is called first and the loader
    reads the ``rescaled/...`` datasets instead of the native ones.

    Args:
        path: Folder containing the ``*.h5`` files.
        patch_shape: Patch shape; must have exactly three entries.
        samples: Optional integer indices into the sorted file list. Only the
            selected files are used.
        target_resolution: Optional resolution forwarded to ``_resize``.
        download: Forwarded to ``_download_plantseg``.
        offsets: If not None, an ``AffinityTransform`` with these offsets is
            registered as ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        binary: If true, ``labels_to_binary`` is registered as
            ``label_transform``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    At most one of ``offsets``, ``boundaries`` and ``binary`` may be set.
    """
    assert len(patch_shape) == 3
    _download_plantseg(path, download, "root", "nuclei")
    if target_resolution is not None:
        _resize(path, NATIVE_RESOLUTION, target_resolution)

    # sorted so that the indices in `samples` refer to a stable file order
    volumes = sorted(glob(os.path.join(path, "*.h5")))

    if samples is not None:
        n_volumes = len(volumes)
        assert all(isinstance(idx, int) for idx in samples)
        assert all(idx < n_volumes for idx in samples)
        volumes = [volumes[idx] for idx in samples]

    # the three target options are mutually exclusive
    assert sum((offsets is not None, boundaries, binary)) <= 1
    override = None
    if offsets is not None:
        # affinities plus a binary target channel for foreground background segmentation
        override = (
            "label_transform2",
            torch_em.transform.label.AffinityTransform(
                offsets=offsets, add_binary_target=True, add_mask=True),
            "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden.",
        )
    elif boundaries:
        override = (
            "label_transform",
            torch_em.transform.label.BoundaryTransform(add_binary_target=True),
            "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    elif binary:
        override = (
            "label_transform",
            torch_em.transform.label.labels_to_binary,
            "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    if override is not None:
        key, transform, msg = override
        kwargs = update_kwargs(kwargs, key, transform, msg=msg)

    kwargs = update_kwargs(kwargs, "patch_shape", patch_shape)
    kwargs = update_kwargs(kwargs, "ndim", 3)

    # resized data is read from the "rescaled" group
    raw_key, label_key = (
        ("raw", "label_uint16_smooth")
        if target_resolution is None
        else ("rescaled/raw", "rescaled/label_uint16_smooth")
    )
    return torch_em.default_segmentation_loader(
        volumes, raw_key, volumes, label_key, **kwargs
    )
 def _get_kwargs(self, with_roi=False):
     """Build a dict of training components on top of ``self.data_path``.

     The same loader instance is used for training and validation. The model
     is a small ``UNet2d`` and everything is configured for the CPU without
     mixed precision.

     Args:
         with_roi: If true, the loader is restricted to ``np.s_[:6, :, :]``;
             otherwise no roi is passed.

     Returns:
         A dict with the keys ``name``, ``train_loader``, ``val_loader``,
         ``model``, ``loss``, ``metric``, ``optimizer``, ``device`` and
         ``mixed_precision``.
     """
     roi = None
     if with_roi:
         roi = np.s_[:6, :, :]

     loader = default_segmentation_loader(
         self.data_path, "raw",
         self.data_path, "labels",
         batch_size=1,
         patch_shape=(1, 128, 128),
         ndim=2,
         rois=roi,
     )
     net = UNet2d(in_channels=1, out_channels=1, depth=2, initial_features=4)

     return dict(
         name=self.name,
         train_loader=loader,
         val_loader=loader,
         model=net,
         loss=DiceLoss(),
         metric=DiceLoss(),
         optimizer=torch.optim.Adam(net.parameters(), lr=1e-5),
         device=torch.device("cpu"),
         mixed_precision=False,
     )
Exemple #3
0
def get_loader(split, patch_shape, batch_size, n_samples=None, roi=None):
    """Create a shuffled 2D loader with affinity targets for ``split``.

    File paths come from ``get_paths(split, patch_shape, raw_key)``. Raw data
    is read from the ``raw`` key and labels from the ``label`` key of the same
    files. Labels are converted to affinities for the module-level ``OFFSETS``
    (with a mask channel, without a binary target channel) and returned as
    ``torch.float32``.

    Args:
        split: Split identifier forwarded to ``get_paths``.
        patch_shape: Patch shape forwarded to ``get_paths`` and the loader.
        batch_size: Batch size; also determines the worker count
            (``8 * batch_size``).
        n_samples: Forwarded to the loader as ``n_samples``.
        roi: NOTE(review): accepted but never used in this function.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.
    """
    raw_key, label_key = 'raw', 'label'
    paths = get_paths(split, patch_shape, raw_key)

    loader_options = dict(
        batch_size=batch_size,
        patch_shape=patch_shape,
        sampler=torch_em.data.MinForegroundSampler(min_fraction=0.1, p_reject=1.),
        label_transform2=torch_em.transform.label.AffinityTransform(
            offsets=OFFSETS,
            ignore_label=None,
            add_binary_target=False,
            add_mask=True,
        ),
        n_samples=n_samples,
        num_workers=8 * batch_size,
        shuffle=True,
        label_dtype=torch.float32,
        ndim=2,
    )
    return torch_em.default_segmentation_loader(
        paths, raw_key, paths, label_key, **loader_options
    )
Exemple #4
0
def get_loader(split, patch_shape, batch_size=1, n_samples=None):
    """Create a shuffled loader with affinity targets for a train/val split.

    The path lists returned by ``get_datasets()`` are cut at 80% of their
    length: ``split == 'train'`` selects the first part, any other value
    selects the remainder. Raw data and labels are both read from the ``data``
    key of their respective files.

    Args:
        split: ``'train'`` for the first 80% of the files; anything else for
            the remaining files.
        patch_shape: Patch shape forwarded to the loader.
        batch_size: Batch size; also determines the worker count
            (``8 * batch_size``).
        n_samples: Forwarded to the loader as ``n_samples``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.
    """
    raw_paths, label_paths = get_datasets()

    cut = int(0.8 * len(raw_paths))
    part = slice(None, cut) if split == 'train' else slice(cut, None)
    raw_paths, label_paths = raw_paths[part], label_paths[part]

    data_key = 'data'

    # affinities plus a binary target channel for foreground background segmentation
    ignore_label = None  # set ignore label here
    to_affinities = torch_em.transform.label.AffinityTransform(
        offsets=OFFSETS,
        ignore_label=ignore_label,
        add_binary_target=True,
        add_mask=True,
    )

    return torch_em.default_segmentation_loader(
        raw_paths, data_key,
        label_paths, data_key,
        batch_size=batch_size,
        patch_shape=patch_shape,
        label_transform2=to_affinities,
        n_samples=n_samples,
        num_workers=8 * batch_size,
        shuffle=True,
    )
def get_platynereis_nuclei_loader(path, patch_shape,
                                  sample_ids=None, rois=None,
                                  offsets=None, boundaries=False, binary=False,
                                  download=False, **kwargs):
    """Create a segmentation loader for the nuclei training files below ``path``.

    The files are expected in ``<path>/nuclei`` and are named
    ``train_data_nuclei_XX.h5`` with ``XX`` running from 01 to 12. If
    ``_check_data`` reports the data as incomplete, ``_require_platy_data`` is
    called to provide it.

    Args:
        path: Root folder of the dataset.
        patch_shape: Patch shape forwarded to the loader.
        sample_ids: Iterable of sample numbers in ``[1, 12]``. Defaults to all
            twelve samples. The argument is not modified; a sorted copy is
            used internally.
        rois: Optional mapping from sample number to the roi used for that
            sample. Samples without an entry use the full volume
            (``np.s_[:, :, :]``). Defaults to no rois.
        offsets: If not None, an ``AffinityTransform`` with these offsets is
            registered as ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        binary: If true, ``labels_to_binary`` is registered as
            ``label_transform``.
        download: Forwarded to ``_require_platy_data``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``. ``ndim`` defaults to 3
            when not given.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    Raises:
        AssertionError: If a sample id is outside ``[1, 12]`` or more than one
            of ``offsets``, ``boundaries`` and ``binary`` is set.
    """
    nuc_root = os.path.join(path, "nuclei")
    prefix = "train_data_nuclei_"
    ext = ".h5"
    n_files = 12
    data_is_complete = _check_data(nuc_root, "train_data_nuclei", ext, n_files)
    if not data_is_complete:
        _require_platy_data(path, "nuclei", download)

    if sample_ids is None:
        sample_ids = range(1, n_files + 1)
    # Sort into a new list: the caller's sequence must not be reordered, and
    # this also accepts tuples and other iterables.
    sample_ids = sorted(sample_ids)
    assert min(sample_ids) >= 1 and max(sample_ids) <= n_files, \
        f"sample_ids must be in the range [1, {n_files}], got {sample_ids}"

    # A None default avoids sharing one dict object between calls.
    if rois is None:
        rois = {}

    data_paths = [
        os.path.join(nuc_root, f"{prefix}{sample:02}{ext}")
        for sample in sample_ids
    ]
    data_rois = [rois.get(sample, np.s_[:, :, :]) for sample in sample_ids]

    kwargs = update_kwargs(kwargs, "patch_shape", patch_shape)
    kwargs = update_kwargs(kwargs, "rois", data_rois)
    # set ndim to 3 if not otherwise specified
    if "ndim" not in kwargs:
        kwargs["ndim"] = 3

    # the three target options are mutually exclusive
    assert sum((offsets is not None, boundaries, binary)) <= 1
    if offsets is not None:
        # we add a binary target channel for foreground background segmentation
        label_transform = torch_em.transform.label.AffinityTransform(offsets=offsets,
                                                                     ignore_label=-1,
                                                                     add_binary_target=True,
                                                                     add_mask=True)
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform, msg=msg)
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    raw_key = "volumes/raw"
    label_key = "volumes/labels/nucleus_instance_labels"
    return torch_em.default_segmentation_loader(
        data_paths, raw_key,
        data_paths, label_key,
        **kwargs
    )
def get_loader(patch_shape, batch_size):
    """Create a loader with binarized cell labels for the data in ``DATA_FOLDER``.

    ``_download_covid_if`` is always called with ``download=True``. All
    ``*.h5`` files in ``DATA_FOLDER`` are used in sorted order. Raw data is
    read from ``raw/serum_IgG/s0`` and labels from ``labels/cells/s0``. Labels
    go through ``labels_to_binary`` (``label_transform``) and the
    module-level ``rescale_labels`` (``label_transform2``).

    Args:
        patch_shape: Patch shape forwarded to the loader.
        batch_size: Batch size forwarded to the loader.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.
    """
    _download_covid_if(DATA_FOLDER, True)
    h5_files = sorted(glob(os.path.join(DATA_FOLDER, "*.h5")))
    return torch_em.default_segmentation_loader(
        h5_files, "raw/serum_IgG/s0",
        h5_files, "labels/cells/s0",
        batch_size=batch_size,
        patch_shape=patch_shape,
        label_transform=torch_em.transform.label.labels_to_binary,
        label_transform2=rescale_labels,
    )
Exemple #7
0
def get_monuseg_loader(path,
                       patch_shape,
                       download=False,
                       roi=None,
                       offsets=None,
                       boundaries=False,
                       binary=False,
                       **kwargs):
    """Create a 2D loader over the ``*.tif`` images and labels below ``path``.

    Images are read from ``<path>/images`` and labels from ``<path>/labels``.
    The loader is created with ``is_seg_dataset=False``.

    Args:
        path: Root folder of the dataset; passed to ``_download_monuseg``.
        patch_shape: Patch shape forwarded to the loader.
        download: Forwarded to ``_download_monuseg``.
        roi: Forwarded to the loader as ``rois``.
        offsets: If not None, an ``AffinityTransform`` with these offsets is
            registered as ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        binary: If true, ``labels_to_binary`` is registered as
            ``label_transform``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    At most one of ``offsets``, ``boundaries`` and ``binary`` may be set.
    """
    _download_monuseg(path, download)

    # the three target options are mutually exclusive
    assert sum((offsets is not None, boundaries, binary)) <= 1
    selected = None
    if offsets is not None:
        # affinities plus a binary target channel for foreground background segmentation
        selected = (
            'label_transform2',
            torch_em.transform.label.AffinityTransform(
                offsets=offsets, add_binary_target=True, add_mask=True),
            "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden.",
        )
    elif boundaries:
        selected = (
            "label_transform",
            torch_em.transform.label.BoundaryTransform(add_binary_target=True),
            "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    elif binary:
        selected = (
            "label_transform",
            torch_em.transform.label.labels_to_binary,
            "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    if selected is not None:
        key, transform, msg = selected
        kwargs = update_kwargs(kwargs, key, transform, msg=msg)

    kwargs = update_kwargs(kwargs, "patch_shape", patch_shape)
    kwargs = update_kwargs(kwargs, "ndim", 2)

    tif_pattern = "*.tif"
    return torch_em.default_segmentation_loader(
        os.path.join(path, "images"), tif_pattern,
        os.path.join(path, "labels"), tif_pattern,
        is_seg_dataset=False,
        rois=roi,
        **kwargs
    )
Exemple #8
0
def get_loader(args, split, patch_shape):
    """Create a shuffled loader with affinity targets over the zarr files of ``split``.

    All ``*.zarr`` entries in ``<args.input>/<split>`` are used. Raw data is
    read from the ``raw`` key and labels from the ``labels`` key of the same
    files. Labels are converted to affinities for the module-level
    ``OFFSETS``, with a mask channel.

    Args:
        args: Object providing ``input`` (root folder) and ``batch_size``.
        split: Name of the sub-folder to load from.
        patch_shape: Patch shape forwarded to the loader.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    Raises:
        AssertionError: If no ``*.zarr`` entry is found.
    """
    pattern = os.path.join(args.input, split, "*.zarr")
    zarr_paths = glob(pattern)
    assert len(zarr_paths) > 0

    to_affinities = torch_em.transform.AffinityTransform(OFFSETS, add_mask=True)
    batch_size = args.batch_size
    return torch_em.default_segmentation_loader(
        zarr_paths, "raw",
        zarr_paths, "labels",
        batch_size, patch_shape,
        shuffle=True,
        num_workers=4 * args.batch_size,
        label_transform2=to_affinities,
    )
Exemple #9
0
def get_vnc_mito_loader(path,
                        patch_shape,
                        offsets=None,
                        boundaries=False,
                        binary=False,
                        download=False,
                        **kwargs):
    """Create a segmentation loader for the mitochondria labels in ``vnc_train.h5``.

    The data is requested through ``_get_vnc_data(path, download)``. Raw data
    is read from the ``raw`` key and labels from ``labels/mitochondria`` in
    ``<path>/vnc_train.h5``.

    Args:
        path: Folder containing ``vnc_train.h5``.
        patch_shape: Patch shape forwarded to the loader.
        offsets: If not None, an ``AffinityTransform`` with these offsets is
            registered as ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        binary: If true, ``labels_to_binary`` is registered as
            ``label_transform``.
        download: Forwarded to ``_get_vnc_data``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    At most one of ``offsets``, ``boundaries`` and ``binary`` may be set.
    """
    _get_vnc_data(path, download)
    volume = os.path.join(path, "vnc_train.h5")

    # the three target options are mutually exclusive
    n_selected = sum((offsets is not None, boundaries, binary))
    assert n_selected <= 1, f"{offsets}, {boundaries}, {binary}"

    choice = None
    if offsets is not None:
        # affinities plus a binary target channel for foreground background segmentation
        choice = (
            'label_transform2',
            torch_em.transform.label.AffinityTransform(
                offsets=offsets,
                ignore_label=None,
                add_binary_target=True,
                add_mask=True),
            "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden.",
        )
    elif boundaries:
        choice = (
            'label_transform',
            torch_em.transform.label.BoundaryTransform(add_binary_target=True),
            "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    elif binary:
        choice = (
            'label_transform',
            torch_em.transform.label.labels_to_binary,
            "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    if choice is not None:
        key, transform, msg = choice
        kwargs = update_kwargs(kwargs, key, transform, msg=msg)

    return torch_em.default_segmentation_loader(
        volume, "raw",
        volume, "labels/mitochondria",
        patch_shape=patch_shape,
        **kwargs
    )
def get_loader(split, patch_shape, batch_size, n_samples=None, roi=None):
    """Create a shuffled 2D loader with connected-component labels for ``split``.

    File paths come from ``get_paths(split, patch_shape, raw_key)``. Raw data
    is read from the ``raw`` key and labels from the ``label`` key of the same
    files. Labels are passed through ``connected_components`` with
    ``ensure_zero=True`` and returned as ``torch.int64``.

    Args:
        split: Split identifier forwarded to ``get_paths``.
        patch_shape: Patch shape forwarded to ``get_paths`` and the loader.
        batch_size: Batch size; also determines the worker count
            (``8 * batch_size``).
        n_samples: Forwarded to the loader as ``n_samples``.
        roi: NOTE(review): accepted but never used in this function.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.
    """
    raw_key, label_key = 'raw', 'label'
    paths = get_paths(split, patch_shape, raw_key)

    loader_options = dict(
        batch_size=batch_size,
        patch_shape=patch_shape,
        sampler=torch_em.data.MinForegroundSampler(min_fraction=0.1, p_reject=1.),
        label_transform=partial(torch_em.transform.label.connected_components,
                                ensure_zero=True),
        n_samples=n_samples,
        num_workers=8 * batch_size,
        shuffle=True,
        label_dtype=torch.int64,
        ndim=2,
    )
    return torch_em.default_segmentation_loader(
        paths, raw_key, paths, label_key, **loader_options
    )
Exemple #11
0
def get_isbi_loader(path, patch_shape, download=False,
                    offsets=None, boundaries=False,
                    use_original_labels=False,
                    **kwargs):
    """Create a segmentation loader for the ISBI data stored at ``path``.

    The file is requested through
    ``download_source(path, ISBI_URL, download, CHECKSUM)``. Raw data is read
    from the ``raw`` key. Labels are read from ``labels/membranes`` when
    ``use_original_labels`` is true and from ``labels/gt_segmentation``
    otherwise.

    Args:
        path: Path of the data file; used for both raw data and labels.
        patch_shape: Patch shape with three entries. A first entry of 1
            selects ``ndim=2``, anything else ``ndim=3``.
        download: Forwarded to ``download_source``.
        offsets: If not None, an ``AffinityTransform`` with these offsets
            (mask channel, no binary target channel) is registered as
            ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        use_original_labels: Select the ``labels/membranes`` dataset. Cannot
            be combined with ``offsets`` or ``boundaries``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.
    """
    assert len(patch_shape) == 3
    download_source(path, ISBI_URL, download, CHECKSUM)

    # a singleton first axis selects the 2D setting
    if patch_shape[0] == 1:
        ndim = 2
    else:
        ndim = 3
    kwargs = update_kwargs(kwargs, "patch_shape", patch_shape)
    kwargs = update_kwargs(kwargs, "ndim", ndim)

    if use_original_labels:
        label_key = "labels/membranes"
    else:
        label_key = "labels/gt_segmentation"

    # offsets and boundaries are mutually exclusive
    assert offsets is None or not boundaries
    selected = None
    if offsets is not None:
        assert not use_original_labels
        selected = (
            "label_transform2",
            torch_em.transform.label.AffinityTransform(offsets=offsets,
                                                       ignore_label=None,
                                                       add_binary_target=False,
                                                       add_mask=True),
            "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden.",
        )
    elif boundaries:
        assert not use_original_labels
        selected = (
            "label_transform",
            torch_em.transform.label.BoundaryTransform(),
            "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    if selected is not None:
        key, transform, msg = selected
        kwargs = update_kwargs(kwargs, key, transform, msg=msg)

    return torch_em.default_segmentation_loader(
        path, "raw", path, label_key, **kwargs
    )
def get_loader(input_path, patch_shape, roi,
               batch_size=1, use_diagonal_offsets=False):
    """Create a shuffled 2D loader with affinity targets for ``input_path``.

    Raw data is read from the ``raw`` key and labels from
    ``labels/gt_segmentation`` of the same file. The affinity offsets come
    from ``get_offsets(use_diagonal_offsets)``.

    Args:
        input_path: Path of the data file; used for both raw data and labels.
        patch_shape: Patch shape forwarded to the loader.
        roi: Forwarded to the loader as ``rois``.
        batch_size: Batch size; also determines the worker count
            (``8 * batch_size``).
        use_diagonal_offsets: Forwarded to ``get_offsets``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.
    """
    # affinities with a mask channel; no binary target channel is added here
    to_affinities = torch_em.transform.label.AffinityTransform(
        offsets=get_offsets(use_diagonal_offsets),
        ignore_label=None,
        add_binary_target=False,
        add_mask=True,
    )

    loader_options = dict(
        batch_size=batch_size,
        patch_shape=patch_shape,
        label_transform2=to_affinities,
        rois=roi,
        ndim=2,
        num_workers=8*batch_size,
        shuffle=True,
    )
    return torch_em.default_segmentation_loader(
        input_path, 'raw',
        input_path, 'labels/gt_segmentation',
        **loader_options
    )
def get_loader(datasets, patch_shape, batch_size=1, n_samples=None, roi=None):
    """Create a shuffled loader with connected-component labels for ``datasets``.

    Each dataset name ``ds`` is mapped to the file ``<ROOT>/<ds>.n5``. Raw
    data is read from the ``raw`` key and labels from the ``labels`` key of
    the same files. Labels are passed through ``connected_components`` and
    returned as ``torch.int64``.

    Args:
        datasets: Iterable of dataset names (without the ``.n5`` extension).
        patch_shape: Patch shape forwarded to the loader.
        batch_size: Batch size; also determines the worker count
            (``8 * batch_size``).
        n_samples: Forwarded to the loader as ``n_samples``.
        roi: NOTE(review): accepted but never used in this function.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.
    """
    n5_paths = []
    for ds in datasets:
        n5_paths.append(os.path.join(ROOT, f'{ds}.n5'))

    fg_sampler = torch_em.data.MinForegroundSampler(min_fraction=0.05,
                                                    p_reject=.75)
    relabel = torch_em.transform.label.connected_components

    loader_options = dict(
        batch_size=batch_size,
        patch_shape=patch_shape,
        label_transform=relabel,
        sampler=fg_sampler,
        n_samples=n_samples,
        num_workers=8 * batch_size,
        shuffle=True,
        label_dtype=torch.int64,
    )
    return torch_em.default_segmentation_loader(
        n5_paths, 'raw', n5_paths, 'labels', **loader_options
    )
Exemple #14
0
def get_hpa_segmentation_loader(path, patch_shape, split,
                                offsets=None, boundaries=False, binary=False,
                                channels=None,
                                download=False, n_workers_preproc=8, **kwargs):
    """Create a 2D multi-channel segmentation loader for the ``split`` of the HPA data.

    If ``_check_data(path)`` reports the data as incomplete, it is obtained
    with ``_download_hpa_data`` and prepared with ``_process_hpa_data``. The
    loader then uses all ``*.h5`` files in ``<path>/<split>`` in sorted order,
    reading raw data from the ``raw`` key and labels from the ``labels`` key.

    Args:
        path: Root folder of the dataset.
        patch_shape: Patch shape forwarded to the loader.
        split: Name of the sub-folder of ``path`` to load from.
        offsets: If not None, an ``AffinityTransform`` with these offsets is
            registered as ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        binary: If true, ``labels_to_binary`` is registered as
            ``label_transform``.
        channels: Channel names forwarded to ``_process_hpa_data``. Defaults
            to ``["microtubules", "protein", "nuclei", "er"]``.
        download: Forwarded to ``_download_hpa_data``.
        n_workers_preproc: Forwarded to ``_process_hpa_data``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    Raises:
        AssertionError: If more than one of ``offsets``, ``boundaries`` and
            ``binary`` is set.
    """
    # Build the default per call so no list object is shared between calls.
    if channels is None:
        channels = ["microtubules", "protein", "nuclei", "er"]

    data_is_complete = _check_data(path)
    if not data_is_complete:
        _download_hpa_data(path, "segmentation", download)
        _process_hpa_data(path, channels, n_workers_preproc, remove=True)

    # the three target options are mutually exclusive
    assert sum((offsets is not None, boundaries, binary)) <= 1
    if offsets is not None:
        # we add a binary target channel for foreground background segmentation
        label_transform = torch_em.transform.label.AffinityTransform(offsets=offsets,
                                                                     add_binary_target=True,
                                                                     add_mask=True)
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform, msg=msg)
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    kwargs = update_kwargs(kwargs, "patch_shape", patch_shape)
    kwargs = update_kwargs(kwargs, "ndim", 2)
    kwargs = update_kwargs(kwargs, "with_channels", True)

    # glob returns files in arbitrary order; sort for a reproducible file order
    paths = sorted(glob(os.path.join(path, split, "*.h5")))
    raw_key = "raw"
    label_key = "labels"
    return torch_em.default_segmentation_loader(
        paths, raw_key, paths, label_key, **kwargs
    )
Exemple #15
0
def get_covid_if_loader(path,
                        patch_shape,
                        sample_range=None,
                        target="cells",
                        download=False,
                        offsets=None,
                        boundaries=False,
                        binary=False,
                        **kwargs):
    """Create a 2D segmentation loader for the ``*.h5`` files in ``path``.

    The data is requested through ``_download_covid_if(path, download)``.
    The keys read from each file depend on ``target``:

    * ``"cells"``: raw ``raw/serum_IgG/s0``, labels ``labels/cells/s0``
    * ``"nuclei"``: raw ``raw/nuclei/s0``, labels ``labels/nuclei/s0``

    Args:
        path: Folder containing the ``*.h5`` files.
        patch_shape: Patch shape forwarded to the loader.
        sample_range: Optional ``(start, stop)`` pair selecting the files
            ``gt_image_XXX.h5`` for ``XXX`` in ``range(start, stop)``. A
            ``None`` start means 0; a ``None`` stop means the number of
            ``*.h5`` files found. All selected files must exist.
        target: Either ``"cells"`` or ``"nuclei"``.
        download: Forwarded to ``_download_covid_if``.
        offsets: If not None, an ``AffinityTransform`` with these offsets is
            registered as ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        binary: If true, ``labels_to_binary`` is registered as
            ``label_transform``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    At most one of ``offsets``, ``boundaries`` and ``binary`` may be set.
    """
    available_targets = ("cells", "nuclei")
    # TODO support all of these
    # available_targets = ("cells", "nuclei", "infected_cells")
    assert target in available_targets, f"{target} not found in {available_targets}"

    if target == "cells":
        raw_key, label_key = "raw/serum_IgG/s0", "labels/cells/s0"
    elif target == "nuclei":
        raw_key, label_key = "raw/nuclei/s0", "labels/nuclei/s0"

    _download_covid_if(path, download)

    file_paths = sorted(glob(os.path.join(path, "*.h5")))
    if sample_range is not None:
        start, stop = sample_range
        first = 0 if start is None else start
        last = len(file_paths) if stop is None else stop
        # the selection is rebuilt from the expected file names
        file_paths = [
            os.path.join(path, f"gt_image_{idx:03}.h5")
            for idx in range(first, last)
        ]
        assert all(
            os.path.exists(fp)
            for fp in file_paths), f"Invalid sample range {sample_range}"

    # the three target options are mutually exclusive
    assert sum((offsets is not None, boundaries, binary)) <= 1
    selected = None
    if offsets is not None:
        # affinities plus a binary target channel for foreground background segmentation
        selected = (
            "label_transform2",
            torch_em.transform.label.AffinityTransform(
                offsets=offsets, add_binary_target=True, add_mask=True),
            "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden.",
        )
    elif boundaries:
        selected = (
            "label_transform",
            torch_em.transform.label.BoundaryTransform(add_binary_target=True),
            "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    elif binary:
        selected = (
            "label_transform",
            torch_em.transform.label.labels_to_binary,
            "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    if selected is not None:
        key, transform, msg = selected
        kwargs = update_kwargs(kwargs, key, transform, msg=msg)

    kwargs = update_kwargs(kwargs, "patch_shape", patch_shape)
    kwargs = update_kwargs(kwargs, "ndim", 2)

    return torch_em.default_segmentation_loader(
        file_paths, raw_key, file_paths, label_key, **kwargs
    )
Exemple #16
0
def get_mitoem_loader(
    path,
    patch_shape,
    splits,
    samples=("human", "rat"),
    download=False,
    offsets=None,
    boundaries=False,
    binary=False,
    **kwargs,
):
    """Create a 3D segmentation loader for the MitoEM data below ``path``.

    For every requested sample, the data is created via
    ``_require_mitoem_sample`` when ``_check_data`` reports it as missing.
    The loader then uses the files ``<path>/<sample>_<split>.n5`` for all
    requested sample/split combinations, reading raw data from the ``raw``
    key and labels from the ``labels`` key. ``ndim`` is always set to 3.

    Args:
        path: Root folder of the dataset; created if it does not exist.
        patch_shape: Patch shape; must have exactly three entries.
        splits: A single split name or an iterable of split names, taken
            from ``"train"`` and ``"val"``.
        samples: Iterable of sample names, taken from ``"human"`` and
            ``"rat"``.
        download: Forwarded to ``_require_mitoem_sample``.
        offsets: If not None, an ``AffinityTransform`` with these offsets is
            registered as ``label_transform2``.
        boundaries: If true, a ``BoundaryTransform`` is registered as
            ``label_transform``.
        binary: If true, ``labels_to_binary`` is registered as
            ``label_transform``.
        **kwargs: Further arguments for
            ``torch_em.default_segmentation_loader``.

    Returns:
        The result of ``torch_em.default_segmentation_loader``.

    At most one of ``offsets``, ``boundaries`` and ``binary`` may be set.
    """
    assert len(patch_shape) == 3
    if isinstance(splits, str):
        splits = [splits]
    assert set(splits) <= {"train", "val"}, f"{splits}"
    assert set(samples) <= {"human", "rat"}, f"{samples}"
    os.makedirs(path, exist_ok=True)

    data_paths = []
    for sample in samples:
        sample_is_missing = not _check_data(path, sample)
        if sample_is_missing:
            print("The MitoEM data for sample", sample,
                  "is not available yet and will be downloaded and created.")
            print(
                "Note that this dataset is large, so this step can take several hours (depending on your internet)."
            )
            _require_mitoem_sample(path, sample, download)
            print("The MitoEM data for sample", sample, "has been created.")

        for split in splits:
            n5_file = os.path.join(path, f"{sample}_{split}.n5")
            assert os.path.exists(n5_file), n5_file
            data_paths.append(n5_file)

    # the three target options are mutually exclusive
    n_selected = sum((offsets is not None, boundaries, binary))
    assert n_selected <= 1, f"{offsets}, {boundaries}, {binary}"

    choice = None
    if offsets is not None:
        # affinities plus a binary target channel for foreground background segmentation
        choice = (
            'label_transform2',
            torch_em.transform.label.AffinityTransform(
                offsets=offsets,
                ignore_label=None,
                add_binary_target=True,
                add_mask=True),
            "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden.",
        )
    elif boundaries:
        choice = (
            'label_transform',
            torch_em.transform.label.BoundaryTransform(add_binary_target=True),
            "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    elif binary:
        choice = (
            'label_transform',
            torch_em.transform.label.labels_to_binary,
            "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )
    if choice is not None:
        key, transform, msg = choice
        kwargs = update_kwargs(kwargs, key, transform, msg=msg)

    # ndim is set unconditionally, replacing any value passed by the caller
    kwargs["ndim"] = 3
    return torch_em.default_segmentation_loader(
        data_paths, "raw",
        data_paths, "labels",
        patch_shape=patch_shape,
        **kwargs
    )