Ejemplo n.º 1
0
        def _switch_to_new_mode(self):
            """
            Activate the debug mode stored in ``self._mode``.

            The mode that was active before the switch is written back into
            ``self._mode``, so calling this helper a second time restores the
            original mode.

            """
            # capture the currently active mode before it gets overwritten
            outgoing_mode = get_current_debug_mode()
            set_debug_mode(self._mode)
            # keep the previous mode around so it can be switched back later
            self._mode = outgoing_mode
    def n_process_augmentation(self):
        """
        Number of processes to use for augmentation.

        In debug mode a single process is always reported; otherwise the
        stored ``self._n_process_augmentation`` value is returned.

        NOTE(review): this is described as a property, but no ``@property``
        decorator is visible directly above this definition -- confirm.

        Returns
        -------
        int
            number of augmentation processes
        """
        # debug mode forces single-process augmentation
        return 1 if get_current_debug_mode() else self._n_process_augmentation
Ejemplo n.º 3
0
    def __init__(self,
                 transform: AbstractTransform,
                 nopython=True,
                 target="cpu",
                 parallel=False,
                 **options):
        """
        Wrap ``transform`` so that its ``__call__`` is compiled by ``numba.jit``.

        Parameters
        ----------
        transform : :class:`AbstractTransform`
            the transform whose ``__call__`` attribute gets replaced by a
            jit-compiled version
        nopython : bool
            forwarded to ``numba.jit``; forced to ``False`` in debug mode
        target : str
            forwarded to ``numba.jit``; forced to ``"cpu"`` in debug mode
        parallel : bool
            forwarded to ``numba.jit`` unchanged (not overridden in debug mode)
        **options :
            additional keyword arguments forwarded to ``numba.jit``

        NOTE(review): the jitted function is assigned as an *instance*
        attribute. Implicit calls like ``transform(x)`` look ``__call__`` up on
        the type, so only explicit ``transform.__call__(x)`` calls use the
        compiled version -- confirm how ``self._transform`` is invoked.
        """
        if not get_current_debug_mode():
            pass
        else:
            # debug mode: relax numba settings so the code runs on the cpu
            # without requiring nopython compilation
            logging.debug("Debug mode detected. Overwriting numba options "
                          "nopython to False and target to cpu")
            nopython, target = False, "cpu"

        uncompiled_call = transform.__call__
        transform.__call__ = numba.jit(uncompiled_call,
                                       nopython=nopython,
                                       target=target,
                                       parallel=parallel,
                                       **options)
        self._transform = transform
Ejemplo n.º 4
0
 def _resolve_augmenter_cls(num_processes, **kwargs):
     """
     Choose an augmenter implementation and instantiate it.

     A sequential augmenter is used in debug mode or when ``num_processes``
     is 0; in every other case a parallel augmenter is created.

     Parameters
     ----------
     num_processes : int
         the number of processes to use for dataloading + augmentation;
         ``None`` is forwarded unchanged to the parallel augmenter (which is
         documented to then use the number of available CPUs)
     **kwargs :
         additional keyword arguments passed on to the chosen class

     Returns
     -------
     :class:`AbstractAugmenter`
         an instance of the chosen augmenter class
     """
     run_sequentially = get_current_debug_mode() or num_processes == 0
     if not run_sequentially:
         return _ParallelAugmenter(num_processes=num_processes, **kwargs)
     return _SequentialAugmenter(**kwargs)
Ejemplo n.º 5
0
    def _make_dataset(self, path: str):
        """
        Collect the sample directories below ``path`` and optionally load them.

        Every sub-directory of every sub-directory of ``path`` is one sample.
        The sorted samples are paired with ``flip=False``; if
        ``self._include_flipped`` is set, the same samples are appended once
        more paired with ``flip=True``.

        Parameters
        ----------
        path : str
            root directory containing one directory per patient

        Returns
        -------
        list
            the ``(directory, flip)`` tuples if ``self.lazy`` is set, otherwise
            the loaded samples (sequentially in debug mode, via a process pool
            otherwise)
        """
        sample_dirs = sorted(
            sample_dir
            for patient_dir in subdirs(path)
            for sample_dir in subdirs(patient_dir)
        )

        # all unflipped entries come first, followed by the flipped copies
        flip_flags = (False, True) if self._include_flipped else (False,)
        entries = [(sample_dir, flag)
                   for flag in flip_flags
                   for sample_dir in sample_dirs]

        if self.lazy:
            return entries

        if get_current_debug_mode():
            # no multiprocessing while debugging
            return [self._load_fn(sample_dir,
                                  self._img_size,
                                  flip=flag,
                                  **self._add_kwargs)
                    for sample_dir, flag in entries]

        loader = partial(multiproc_fn,
                         img_size=self._img_size,
                         func=self._load_fn,
                         **self._add_kwargs)
        with Pool() as pool:
            return pool.map(loader, entries)
Ejemplo n.º 6
0
# Load the experiment configuration. JSON text is UTF-8 by specification, so
# decode it explicitly instead of relying on the platform's locale encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config = DeliraConfig(**json.load(f))

# Transforms shared by the training and test pipelines. The commented-out
# entries are currently disabled alternatives.
base_transforms = [
    # HistogramEqualization(),
    RangeTransform((-1, 1)),
    # AddGridTransform()
]
# Split-specific additions (currently none).
train_specific_transforms = []
test_specific_transforms = []

train_transforms = Compose(base_transforms + train_specific_transforms)
test_transforms = Compose(base_transforms + test_specific_transforms)

# In debug mode the "Test" split is used for training as well; the test split
# is "Test" in every mode.
train_dir = "Test" if get_current_debug_mode() else "Train"
test_dir = "Test"

train_path = os.path.join(data_path, train_dir)
test_path = os.path.join(data_path, test_dir)

if bone_label is None:
    dset = WholeLegDataset(train_path,
                           include_flipped=True,
                           img_size=config.img_size,
                           contourwidth=5)
    dset_test = WholeLegDataset(test_path,
Ejemplo n.º 7
0
    def __init__(self, data_loader: BaseDataLoader, transforms,
                 n_process_augmentation, sampler, sampler_queues: list,
                 num_cached_per_queue=2, seeds=None, **kwargs):
        """
        Set up the augmenter that wraps ``data_loader``.

        In debug mode a :class:`SingleThreadedAugmenter` is created; otherwise
        a :class:`MultiThreadedAugmenter` with ``n_process_augmentation``
        processes is used.

        Parameters
        ----------
        data_loader : :class:`BaseDataLoader`
            the dataloader providing the actual data
        transforms : Callable or None
            the transforms to use. Can be single callable or None
        n_process_augmentation : int
            the number of processes to use for augmentation (only necessary if
            not in debug mode)
        sampler : :class:`AbstractSampler`
            the sampler to use; must be used here instead of inside the
            dataloader to avoid duplications and oversampling due to
            multiprocessing
        sampler_queues : list of :class:`multiprocessing.Queue`
            queues to pass the sample indices to the actual dataloader
        num_cached_per_queue : int
            the number of samples to cache per queue (only necessary if not in
            debug mode)
        seeds : int or list
            the seeds for each process (only necessary if not in debug mode).
            ``None`` means a seed of 1; an int is replicated for every process.
            A sequence is copied before any adjustment, so the caller's object
            is never modified (and tuples are accepted as well).
        **kwargs :
            additional keyword arguments passed to
            :class:`MultiThreadedAugmenter`
        """

        self._batchsize = data_loader.batch_size

        # don't use multiprocessing in debug mode
        if get_current_debug_mode():
            augmenter = SingleThreadedAugmenter(data_loader, transforms)

        else:
            assert isinstance(n_process_augmentation, int)
            # no seeds are given -> use default seed of 1
            if seeds is None:
                seeds = 1

            if isinstance(seeds, int):
                # only an int is given as seed -> replicate it for each process
                seeds = [seeds] * n_process_augmentation
            else:
                # work on a private copy: adjusting the seeds below must not
                # mutate the caller's list, and must also work for tuples
                seeds = list(seeds)

            # if any process would share the first process's seed, offset
            # every seed by its process index
            if any(seeds[0] == _seed for _seed in seeds[1:]):
                seeds = [_seed + idx for idx, _seed in enumerate(seeds)]

            augmenter = MultiThreadedAugmenter(
                data_loader, transforms,
                num_processes=n_process_augmentation,
                num_cached_per_queue=num_cached_per_queue,
                seeds=seeds,
                **kwargs)

        self._augmenter = augmenter
        self._sampler = sampler
        self._sampler_queues = sampler_queues
        self._queue_id = 0