コード例 #1
0
    def setup_augmentation(self):
        """Create the train, validation and test augmenters from the config.

        ``config.transforms_train`` and ``config.transforms_val`` are mappings
        whose entries carry an ``"active"`` flag, a ``"type"`` callable and a
        ``"kwargs"`` mapping. Active entries are instantiated in sorted key
        order and wrapped in ``Compose``. The validation transform settings
        and ``config.augmenter_val`` are reused for the test generator.
        """
        cfg = self.config

        def instantiate_active(specs):
            # Sorted keys give a deterministic transform order.
            return [specs[name]["type"](**specs[name]["kwargs"])
                    for name in sorted(specs.keys())
                    if specs[name]["active"]]

        train_pipeline = instantiate_active(cfg.transforms_train)
        self.augmenter_train = cfg.augmenter_train(
            self.generator_train,
            Compose(train_pipeline),
            **cfg.augmenter_train_kwargs)

        val_pipeline = instantiate_active(cfg.transforms_val)
        self.augmenter_val = cfg.augmenter_val(
            self.generator_val,
            Compose(val_pipeline),
            **cfg.augmenter_val_kwargs)
        # The test augmenter gets its own Compose over the same transform list.
        self.augmenter_test = cfg.augmenter_val(
            self.generator_test,
            Compose(val_pipeline),
            **cfg.augmenter_val_kwargs)
コード例 #2
0
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes,
                                  num_workers=5, num_cached_per_worker=2,
                                  do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                                  do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                  do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build and start the multi-threaded 2D training augmenter.

    :param patient_data_train: data handed to ``BatchGenerator_2D``
    :param BATCH_SIZE: batch size handed to ``BatchGenerator_2D``
    :param num_classes: ``range(num_classes)`` is given to
        ``ConvertSegToOnehotTransform``
    :param num_workers: number of augmenter workers
    :param num_cached_per_worker: forwarded to ``MultiThreadedAugmenter``
    :param seeds: ``None`` (one ``None`` per worker), the string ``'range'``
        (``range(num_workers)``), or a sequence asserted to have
        ``num_workers`` entries
    :return: the ``MultiThreadedAugmenter`` after ``restart()`` was called

    The remaining parameters are forwarded positionally to
    ``SpatialTransform``. The patch size is fixed at (352, 352).
    """
    if seeds is None:
        worker_seeds = [None] * num_workers
    elif seeds == 'range':
        worker_seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
        worker_seeds = seeds

    patch = (352, 352)
    batch_source = BatchGenerator_2D(patient_data_train, BATCH_SIZE,
                                     num_batches=None, seed=False,
                                     PATCH_SIZE=patch)

    mirror = Mirror((2, 3))
    spatial = SpatialTransform(patch, list(np.array(patch)//2),
                               do_elastic_transform, alpha, sigma,
                               do_rotation, a_x, a_y, a_z,
                               do_scale, scale_range,
                               'constant', 0, 3, 'constant', 0, 0,
                               random_crop=False)
    # With prob=0.67 the spatial transform is used; RandomCropTransform is
    # registered as the alternative.
    spatial_or_crop = RndTransform(
        spatial, prob=0.67,
        alternative_transform=RandomCropTransform(patch))
    to_onehot = ConvertSegToOnehotTransform(range(num_classes),
                                            seg_channel=0,
                                            output_key='seg_onehot')

    pipeline = Compose([mirror, spatial_or_crop, to_onehot])
    augmenter = MultiThreadedAugmenter(batch_source, pipeline, num_workers,
                                       num_cached_per_worker, worker_seeds)
    augmenter.restart()
    return augmenter
コード例 #3
0
def get_transforms(mode="train", target_size=128):
    """Return the composed transform pipeline for the given mode.

    :param mode: ``"train"``, ``"val"`` or ``"test"``; any other value yields
        a pipeline that only contains ``NumpyToTensor``
    :param target_size: size forwarded to the crop / resize / spatial
        transforms
    :return: ``Compose`` over the selected transforms, ending with
        ``NumpyToTensor``
    """
    if mode == "train":
        pipeline = [
            ResizeTransform(target_size=target_size, order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(
                patch_size=(target_size, target_size),
                random_crop=False,
                patch_center_dist_from_border=target_size // 2,
                do_elastic_deform=True,
                alpha=(0., 1000.),
                sigma=(40., 60.),
                do_rotation=True,
                p_rot_per_sample=0.5,
                angle_x=(-0.1, 0.1),
                angle_y=(0, 1e-8),
                angle_z=(0, 1e-8),
                scale=(0.5, 1.9),
                p_scale_per_sample=0.5,
                border_mode_data="nearest",
                border_mode_seg="nearest",
            ),
        ]
    elif mode in ("val", "test"):
        # Validation and test share the same crop + resize pipeline.
        pipeline = [
            CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    else:
        pipeline = []

    pipeline.append(NumpyToTensor())
    return Compose(pipeline)
コード例 #4
0
    def test_image_pipeline_and_pin_memory(self):
        """Smoke test: the pipeline must run and yield pinned torch tensors.

        Returns early without checking anything when torch cannot be imported.
        """
        try:
            import torch
        except ImportError:
            # Nothing to check without torch.
            return

        from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose

        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5),
            NumpyToTensor(keys='data', cast_to='float'),
        ])

        # The trailing True is presumably pin_memory (the assertion below
        # checks is_pinned) -- confirm against MultiThreadedAugmenter.
        augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, True)

        for _ in range(50):
            res = augmenter.next()

        assert isinstance(res['data'], torch.Tensor)
        assert res['data'].is_pinned()

        # Give the augmenter time to finish caching; otherwise it prints an
        # error on shutdown (harmless for the test result, but noisy).
        sleep(2)
コード例 #5
0
ファイル: train.py プロジェクト: thithaotran/iSeg-2019
def get_train_transform(patch_size):
    """Compose the training augmentation: rotation plus random mirroring.

    :param patch_size: per-axis patch size; only used to derive the second
        positional argument of ``SpatialTransform_2`` (half of each entry)
    :return: ``Compose`` over the spatial transform and a mirror transform
        that is applied with probability 0.3
    """
    # +/- 15 degrees expressed in radians, shared by all three axes.
    rotation_range = (-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi)
    half_patch = [i // 2 for i in patch_size]

    # None is passed as the first positional argument (the patch size slot);
    # elastic deformation and scaling are switched off.
    spatial = SpatialTransform_2(
        None, half_patch,
        do_elastic_deform=False,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=False,
        scale=(0.8, 1.2),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=False,
        p_el_per_sample=0.3,
        p_rot_per_sample=0.3,
        p_scale_per_sample=0.3)
    random_mirror = RndTransform(MirrorTransform(axes=(0, 1, 2)), prob=0.3)

    return Compose(transforms=[spatial, random_mirror])
コード例 #6
0
def get_train_val_generators(fold):
    """Create the training and validation generators for one split.

    :param fold: fold index handed to ``get_split`` together with the
        module-level ``split_seed``
    :return: tuple ``(training generator, validation augmenter)``
    """
    train_keys, val_keys = get_split(fold, split_seed)
    train_data = {key: dataset[key] for key in train_keys}
    val_data = {key: dataset[key] for key in val_keys}

    augmentation_settings = dict(
        num_workers=num_workers,
        do_elastic_transform=True,
        alpha=(0., 350.),
        sigma=(14., 17.),
        do_rotation=True,
        a_x=(0, 2. * np.pi),
        a_y=(-0.000001, 0.00001),
        a_z=(-0.000001, 0.00001),
        do_scale=True,
        scale_range=(0.7, 1.3),
        seeds=workers_seeds,  # original author's note: "new se has no brain mask"
    )
    train_generator = create_data_gen_train(train_data, BATCH_SIZE,
                                            num_classes, INPUT_PATCH_SIZE,
                                            **augmentation_settings)

    val_batches = BatchGenerator(val_data, BATCH_SIZE, num_batches=None,
                                 seed=False, PATCH_SIZE=INPUT_PATCH_SIZE)
    # Validation only converts the segmentation to one-hot over range(4).
    val_pipeline = Compose(
        [ConvertSegToOnehotTransform(range(4), 0, 'seg_onehot')])
    val_generator = MultiThreadedAugmenter(val_batches, val_pipeline, 1, 2,
                                           [0])
    return train_generator, val_generator
コード例 #7
0
    def test_future(self):
        """Run the future-prediction evaluation and store scores and info.

        Writes ``test_future.npy`` (scores) and ``test_future.json`` (info)
        through ``self.elog``.
        """
        cfg = self.config

        info = self.validate_make_default_info()
        future_metrics = [
            "Future KL", "Future Reconstruction NLL",
            "Future Reconstruction Dice", "Prior Maximum Dice",
            "Prior Best Volume Dice"
        ]
        info["coords"]["Metric"] = info["coords"]["Metric"] + future_metrics

        # The regular generators only produce 1 timestep, so a dedicated
        # generator is built here by hand.
        test_data = self.data_val if cfg.test_on_val else self.data_test

        generator = cfg.generator_val(
            test_data,
            cfg.batch_size_val,
            3,
            number_of_threads_in_multithreaded=cfg.augmenter_val_kwargs.
            num_processes)

        # Instantiate every active validation transform in sorted key order.
        specs = cfg.transforms_val
        transforms = [specs[name]["type"](**specs[name]["kwargs"])
                      for name in sorted(specs.keys())
                      if specs[name]["active"]]

        augmenter = cfg.augmenter_val(generator, Compose(transforms),
                                      **cfg.augmenter_val_kwargs)
        test_scores, info = self.test_inner(augmenter, [], info, future=True)

        self.elog.save_numpy_data(np.array(test_scores), "test_future.npy")
        self.elog.save_dict(info, "test_future.json")
コード例 #8
0
def get_train_transform(patch_size):
    """
    Data augmentation for training data, inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py
    :param patch_size: shape of network's input
    :return: ``Compose`` over the spatial, mirror, brightness, gamma, noise
        and blur transforms (in that order)
    """
    # +/- 15 degrees in radians; rotation is only enabled around z.
    z_rotation = (-15 / 360 * 2 * np.pi, 15 / 360 * 2 * np.pi)

    return Compose([
        SpatialTransform_2(
            patch_size,
            (10, 10, 10),
            do_elastic_deform=True,
            deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_z=z_rotation,
            angle_x=(0, 0),
            angle_y=(0, 0),
            do_scale=True,
            scale=(0.75, 1.25),
            border_mode_data='constant',
            border_cval_data=0,
            border_mode_seg='constant',
            border_cval_seg=0,
            order_seg=1,
            random_crop=False,
            p_el_per_sample=0.2,
            p_rot_per_sample=0.2,
            p_scale_per_sample=0.2,
        ),
        MirrorTransform(axes=(0, 1)),
        BrightnessMultiplicativeTransform((0.7, 1.5),
                                          per_channel=True,
                                          p_per_sample=0.2),
        GammaTransform(gamma_range=(0.2, 1.0),
                       invert_image=False,
                       per_channel=False,
                       p_per_sample=0.2),
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2),
        # NOTE(review): p_per_channel=0.0 looks like it would keep the blur
        # from being applied to any channel -- confirm this is intended.
        GaussianBlurTransform(blur_sigma=(0.2, 1.0),
                              different_sigma_per_channel=False,
                              p_per_channel=0.0,
                              p_per_sample=0.2),
    ])
コード例 #9
0
def get_no_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """
    Drop-in replacement for get_default_augmentation that turns off all data
    augmentation; only channel selection, label cleanup, renaming and tensor
    conversion remain.
    :param dataloader_train: loader wrapped by the training augmenter
    :param dataloader_val: loader wrapped by the validation augmenter
    :param patch_size: not used by this function
    :param params: mapping read for "selected_data_channels",
        "selected_seg_channels", "num_threads" and "num_cached_per_thread"
    :param border_val_seg: not used by this function
    :return: tuple (training augmenter, validation augmenter); restart() has
        been called on both
    """
    def channel_selection():
        # Optional data / seg channel selection, in that order.
        selected = []
        data_channels = params.get("selected_data_channels")
        if data_channels is not None:
            selected.append(DataChannelSelectionTransform(data_channels))
        seg_channels = params.get("selected_seg_channels")
        if seg_channels is not None:
            selected.append(SegChannelSelectionTransform(seg_channels))
        return selected

    # Training: channel selection comes before the label removal.
    train_pipeline = Compose(channel_selection() + [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ])
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, train_pipeline, params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=range(params.get('num_threads')), pin_memory=True)
    batchgenerator_train.restart()

    # Validation: the label removal comes before the channel selection.
    val_pipeline = Compose([RemoveLabelTransform(-1, 0)] + channel_selection() + [
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ])
    # Validation uses half of the training workers, but at least one.
    val_workers = max(params.get('num_threads')//2, 1)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_pipeline, val_workers,
        params.get("num_cached_per_thread"),
        seeds=range(val_workers), pin_memory=True)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val
コード例 #10
0
def get_train_transform(patch_size):
    """Compose a showcase training augmentation pipeline.

    These are not necessarily the best transforms for BraTS; the list is
    meant to demonstrate what is available.

    :param patch_size: output patch size of the spatial transform
    :return: ``Compose`` over the spatial, mirror, two gamma and the noise
        transforms (in that order)
    """
    # +/- 15 degrees in radians, used for all three rotation axes.
    rotation_range = (-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi)

    # SpatialTransform_2 runs first: it reduces the data to patch_size and so
    # lowers the cost of everything after it. The later transforms neither
    # change the shape nor transform spatially, so they introduce no border
    # artifacts. Elastic deformation, rotation and scaling each use a
    # per-sample probability of 0.1, i.e. 1 - (1 - 0.1) ** 3 = 27% of the
    # samples are augmented and the rest are only cropped.
    spatial = SpatialTransform_2(
        patch_size, [i // 2 for i in patch_size],
        do_elastic_deform=True,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=True,
        p_el_per_sample=0.1,
        p_rot_per_sample=0.1,
        p_scale_per_sample=0.1)

    # Mirroring along all axes.
    mirror = MirrorTransform(axes=(0, 1, 2))

    # Gamma is a nonlinear transformation of the intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction).
    gamma = GammaTransform(gamma_range=(0.5, 2),
                           invert_image=False,
                           per_channel=True,
                           p_per_sample=0.15)
    # Same transform with invert_image=True: invert, apply gamma, invert back.
    inverted_gamma = GammaTransform(gamma_range=(0.5, 2),
                                    invert_image=True,
                                    per_channel=True,
                                    p_per_sample=0.15)

    noise = GaussianNoiseTransform(noise_variance=(0, 0.05),
                                   p_per_sample=0.15)

    return Compose([spatial, mirror, gamma, inverted_gamma, noise])
コード例 #11
0
    def setUp(self) -> None:
        """Create plain and numba-wrapped transforms plus a random input."""
        from delira.data_loading.numba_transform import NumbaCompose, \
            NumbaTransform

        zoom = 3
        padded_size = (30, 30)

        self._basic_zoom_trafo = ZoomTransform(zoom)
        self._numba_zoom_trafo = NumbaTransform(ZoomTransform,
                                                zoom_factors=zoom)
        self._basic_pad_trafo = PadTransform(new_size=padded_size)
        self._numba_pad_trafo = NumbaTransform(PadTransform,
                                               new_size=padded_size)

        # Both compositions receive the plain (non-numba) pad and zoom
        # transforms, in that order; each gets its own list.
        self._basic_compose_trafo = Compose(
            [self._basic_pad_trafo, self._basic_zoom_trafo])
        self._numba_compose_trafo = NumbaCompose(
            [self._basic_pad_trafo, self._basic_zoom_trafo])

        # Random batch of shape (10, 1, 24, 24).
        self._input = {"data": np.random.rand(10, 1, 24, 24)}
コード例 #12
0
 def get_validation_transforms(self):
     """Compose the validation pipeline.

     Optional data / seg channel selection (only when the corresponding
     entry in ``self.params`` is truthy) is followed by a center crop to
     ``self.patch_size``, label cleanup, renaming 'seg' to 'target' and
     conversion to tensors.
     """
     pipeline = []

     data_channels = self.params.get("selected_data_channels")
     if data_channels:
         pipeline.append(DataChannelSelectionTransform(data_channels))

     seg_channels = self.params.get("selected_seg_channels")
     if seg_channels:
         pipeline.append(SegChannelSelectionTransform(seg_channels))

     pipeline += [
         CenterCropTransform(self.patch_size),
         RemoveLabelTransform(-1, 0),
         RenameTransform('seg', 'target', True),
         NumpyToTensor(['data', 'target'], 'float'),
     ]
     return Compose(pipeline)
コード例 #13
0
    def test_image_pipeline(self):
        """Smoke test: pulling 50 batches through the pipeline must not crash."""
        from batchgenerators.transforms import MirrorTransform, TransposeAxesTransform, Compose

        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        ])

        augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, False)

        for _ in range(50):
            res = augmenter.next()

        # Give the augmenter time to finish caching; otherwise it prints an
        # error on shutdown (harmless for the test result, but noisy).
        sleep(2)
コード例 #14
0
ファイル: augment_data.py プロジェクト: dchealth/covid-mil
def get_augmentation(patch_size,
                     params=default_3D_augmentation_params,
                     border_val_seg=-1):
    """Compose spatial, optional gamma and mirror augmentation.

    :param patch_size: output patch size of the spatial transform (also
        printed)
    :param params: mapping with the augmentation settings; "p_gamma" must be
        present when "do_gamma" is truthy because it is read by index
    :param border_val_seg: constant border value for the segmentation
    :return: ``Compose`` over the configured transforms
    """
    print(f'patch size after augmentation {patch_size}')

    spatial = SpatialTransform(
        patch_size,
        patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"),
        scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"),
        border_cval_data=0,
        order_data=3,
        border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"))
    pipeline = [spatial]

    if params.get("do_gamma"):
        gamma = GammaTransform(params.get("gamma_range"),
                               False,
                               True,
                               retain_stats=params.get("gamma_retain_stats"),
                               p_per_sample=params["p_gamma"])
        pipeline.append(gamma)

    pipeline.append(MirrorTransform(params.get("mirror_axes")))
    return Compose(pipeline)
コード例 #15
0
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE, contrast_range=(0.75, 1.5),
                          gamma_range = (0.6, 2),
                                  num_workers=5, num_cached_per_worker=3,
                                  do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                                  do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                  do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build and start the multi-threaded 3D training augmenter.

    :param patient_data_train: data handed to the random-sampling generator
    :param INPUT_PATCH_SIZE: output patch size of the spatial transform
    :param num_classes: ``range(num_classes)`` is given to
        ``ConvertSegToOnehotTransform``
    :param BATCH_SIZE: batch size handed to the generator
    :param contrast_range: forwarded to ``ContrastAugmentationTransform``
    :param gamma_range: forwarded to ``GammaTransform``
    :param seeds: ``None`` (one ``None`` per worker), the string ``'range'``
        (``range(num_workers)``), or a sequence asserted to have
        ``num_workers`` entries
    :return: the ``MultiThreadedAugmenter`` after ``restart()`` was called

    The generator itself samples with a fixed patch size of (160, 192, 160);
    the remaining parameters configure ``SpatialTransform``.
    """
    if seeds is None:
        worker_seeds = [None] * num_workers
    elif seeds == 'range':
        worker_seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
        worker_seeds = seeds

    batch_source = BatchGenerator3D_random_sampling(
        patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
        patch_size=(160, 192, 160), convert_labels=True)

    pipeline = [
        DataChannelSelectionTransform([0, 1, 2, 3]),
        GenerateBrainMaskTransform(),
        MirrorTransform(),
        SpatialTransform(
            INPUT_PATCH_SIZE, list(np.array(INPUT_PATCH_SIZE)//2.),
            do_elastic_deform=do_elastic_transform, alpha=alpha, sigma=sigma,
            do_rotation=do_rotation, angle_x=a_x, angle_y=a_y, angle_z=a_z,
            do_scale=do_scale, scale=scale_range,
            border_mode_data='nearest', border_cval_data=0, order_data=3,
            border_mode_seg='constant', border_cval_seg=0, order_seg=0,
            random_crop=True),
        BrainMaskAwareStretchZeroOneTransform((-5, 5), True),
        ContrastAugmentationTransform(contrast_range, True),
        GammaTransform(gamma_range, False),
        BrainMaskAwareStretchZeroOneTransform(per_channel=True),
        BrightnessTransform(0.0, 0.1, True),
        SegChannelSelectionTransform([0]),
        ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"),
    ]

    augmenter = MultiThreadedAugmenter(batch_source, Compose(pipeline),
                                       num_workers, num_cached_per_worker,
                                       worker_seeds)
    augmenter.restart()
    return augmenter
コード例 #16
0
def get_data_augmenter(data, batch_size=1,
                       mode=DataLoader.Mode.NORMAL,
                       volumetric=True,
                       normalization_range=None,
                       vector_generator=None,
                       input_shape=None,
                       sample_count=1,
                       transforms=None,
                       threads=1,
                       seed=None):
    """Wrap ``data`` in a ``DataLoader`` and a ``MultiThreadedAugmenter``.

    :param data: sized collection handed to ``DataLoader``
    :param batch_size: batch size; also used to cap the worker count
    :param transforms: optional list of transforms; ``PrepareForTF()`` is
        always appended (to a new list, the argument is not modified)
    :param threads: requested worker count, capped at the number of batches
        ``ceil(len(data) / batch_size)``
    :return: the ``MultiThreadedAugmenter`` over the composed transforms

    All other parameters are forwarded unchanged to ``DataLoader``.
    """
    user_transforms = [] if transforms is None else transforms

    # Never use more workers than there are batches.
    batch_count = int(np.ceil(len(data) / batch_size))
    worker_count = min(batch_count, threads)

    loader_options = dict(
        data=data,
        batch_size=batch_size,
        mode=mode,
        volumetric=volumetric,
        normalization_range=normalization_range,
        vector_generator=vector_generator,
        input_shape=input_shape,
        sample_count=sample_count,
        number_of_threads_in_multithreaded=worker_count,
        seed=seed,
    )
    loader = DataLoader(**loader_options)

    pipeline = Compose(user_transforms + [PrepareForTF()])
    return MultiThreadedAugmenter(loader, pipeline, worker_count)
コード例 #17
0
def get_valid_transform(patch_size):
    """
    Data augmentation for validation data, inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py
    :param patch_size: shape of network's input; passed as both the first and
        the second positional argument of ``SpatialTransform_2``
    :return: ``Compose`` over the spatial transform and a mirror transform
    """
    # Elastic deformation, rotation and scaling are all switched off, with
    # zero / identity ranges; random_crop stays enabled.
    no_rotation = (0, 0)
    spatial = SpatialTransform_2(
        patch_size,
        patch_size,
        do_elastic_deform=False,
        deformation_scale=(0, 0),
        do_rotation=False,
        angle_x=no_rotation,
        angle_y=no_rotation,
        angle_z=no_rotation,
        do_scale=False,
        scale=(1.0, 1.0),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=True,
        p_el_per_sample=0.1,
        p_rot_per_sample=0.1,
        p_scale_per_sample=0.1)

    return Compose([spatial, MirrorTransform(axes=(0, 1))])
コード例 #18
0
def get_transforms(mode="train", target_size=128):
    """Return the composed resize + mirror pipeline for the given mode.

    :param mode: ``"train"``, ``"val"`` or ``"test"``; any other value yields
        a pipeline that only contains ``NumpyToTensor``
    :param target_size: resize target; passed as a
        ``(target_size, target_size)`` tuple for "train" and as the bare
        value for "val" / "test"
    :return: ``Compose`` ending with ``NumpyToTensor``
    """
    if mode == "train":
        resize = ResizeTransform(target_size=(target_size, target_size), order=1)
    elif mode in ("val", "test"):
        resize = ResizeTransform(target_size=target_size, order=1)
    else:
        # Unknown mode: only the tensor conversion is applied.
        return Compose([NumpyToTensor()])

    return Compose([resize, MirrorTransform(axes=(1, )), NumpyToTensor()])
コード例 #19
0
from delira.training.backends import convert_torch_to_numpy
from functools import partial
from delira.data_loading import DataManager, SequentialSampler
from delira_unet import UNetTorch
from delira import set_debug_mode

if __name__ == "__main__":
    # Placeholder paths; fill these in before running the script.
    checkpoint_path = ""
    data_path = ""
    save_path = ""

    set_debug_mode(True)

    # Keep an untouched copy of the input under "data_orig", then rescale
    # "data" to the range (-1, 1).
    transforms = Compose([
        CopyTransform("data", "data_orig"),
        # HistogramEqualization(),
        RangeTransform((-1, 1)),
        # AddGridTransform(),
    ])

    img_size = (1024, 256)
    thresh = 0.5

    print("Load Model")
    # Fixed: the call was misspelled as ``torch.jit.loat`` and its result was
    # discarded, so ``model`` was undefined on the next line.
    model = torch.jit.load(checkpoint_path)
    model.eval()

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
コード例 #20
0
    def get_training_transforms(self):
        """Compose the full training augmentation pipeline from ``self.params``.

        Order: optional channel selection, optional 3D->2D conversion, the
        spatial transform, optional 2D->3D conversion, intensity
        augmentations (noise, blur, brightness, contrast, low resolution,
        gamma), optional mirroring and masking, and finally label cleanup,
        renaming 'seg' to 'target' and tensor conversion.

        :return: ``Compose`` over the configured transforms
        :raises AssertionError: if the legacy 'mirror' key is set in params
        """
        params = self.params
        assert params.get(
            'mirror'
        ) is None, "old version of params, use new keyword do_mirror"

        pipeline = []

        data_channels = params.get("selected_data_channels")
        if data_channels:
            pipeline.append(DataChannelSelectionTransform(data_channels))
        seg_channels = params.get("selected_seg_channels")
        if seg_channels:
            pipeline.append(SegChannelSelectionTransform(seg_channels))

        # In dummy-2D mode the color channel is overloaded, so axis 0 is
        # excluded from the low-resolution simulation further down.
        dummy_2d = params.get("dummy_2D", False)
        ignore_axes = (0, ) if dummy_2d else None
        if dummy_2d:
            pipeline.append(Convert3DTo2DTransform())

        # SpatialTransform keyword -> key in params.
        spatial_keys = (
            ("do_elastic_deform", "do_elastic"),
            ("alpha", "elastic_deform_alpha"),
            ("sigma", "elastic_deform_sigma"),
            ("do_rotation", "do_rotation"),
            ("angle_x", "rotation_x"),
            ("angle_y", "rotation_y"),
            ("angle_z", "rotation_z"),
            ("do_scale", "do_scaling"),
            ("scale", "scale_range"),
            ("order_data", "order_data"),
            ("border_mode_data", "border_mode_data"),
            ("border_cval_data", "border_cval_data"),
            ("order_seg", "order_seg"),
            ("border_mode_seg", "border_mode_seg"),
            ("border_cval_seg", "border_cval_seg"),
            ("random_crop", "random_crop"),
            ("p_el_per_sample", "p_eldef"),
            ("p_scale_per_sample", "p_scale"),
            ("p_rot_per_sample", "p_rot"),
            ("independent_scale_for_each_axis",
             "independent_scale_factor_for_each_axis"),
        )
        spatial_kwargs = {kwarg: params.get(key)
                          for kwarg, key in spatial_keys}
        pipeline.append(
            SpatialTransform(self._spatial_transform_patch_size,
                             patch_center_dist_from_border=None,
                             **spatial_kwargs))

        if dummy_2d:
            pipeline.append(Convert2DTo3DTransform())

        # The color augmentations have to come after the dummy-2D part (if
        # applicable); otherwise the overloaded color channel gets in the way.
        pipeline.append(GaussianNoiseTransform(p_per_sample=0.15))
        pipeline.append(
            GaussianBlurTransform((0.5, 1.5),
                                  different_sigma_per_channel=True,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5))
        pipeline.append(
            BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3),
                                              p_per_sample=0.15))

        if params.get("do_additive_brightness"):
            pipeline.append(
                BrightnessTransform(
                    params.get("additive_brightness_mu"),
                    params.get("additive_brightness_sigma"),
                    True,
                    p_per_sample=params.get(
                        "additive_brightness_p_per_sample"),
                    p_per_channel=params.get(
                        "additive_brightness_p_per_channel")))

        pipeline.append(
            ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                          p_per_sample=0.15))
        pipeline.append(
            SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                           per_channel=True,
                                           p_per_channel=0.5,
                                           order_downsample=0,
                                           order_upsample=3,
                                           p_per_sample=0.25,
                                           ignore_axes=ignore_axes))

        # Inverted gamma is always part of the pipeline.
        pipeline.append(
            GammaTransform(params.get("gamma_range"),
                           True,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=0.15))
        # Regular gamma only when enabled; "p_gamma" is read by index and
        # therefore has to be present in that case.
        if params.get("do_gamma"):
            pipeline.append(
                GammaTransform(params.get("gamma_range"),
                               False,
                               True,
                               retain_stats=params.get("gamma_retain_stats"),
                               p_per_sample=params["p_gamma"]))

        if params.get("do_mirror") or params.get("mirror"):
            pipeline.append(MirrorTransform(params.get("mirror_axes")))

        mask_for_norm = params.get("use_mask_for_norm")
        if mask_for_norm:
            pipeline.append(
                MaskTransform(mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        pipeline += [
            RemoveLabelTransform(-1, 0),
            RenameTransform('seg', 'target', True),
            NumpyToTensor(['data', 'target'], 'float'),
        ]
        return Compose(pipeline)
コード例 #21
0
    def get_training_transforms(self):
        """Compose the training transform chain described by ``self.params``.

        The chain is assembled in a fixed order: optional data/seg channel
        selection, a ``SpatialTransform`` (bracketed by 3D->2D and 2D->3D
        conversion when ``dummy_2D`` is set), optional gamma, optional
        mirroring, optional masking, and finally label cleanup, the
        ``'seg'`` -> ``'target'`` rename and the tensor conversion.

        Returns:
            A ``Compose`` wrapping the assembled list of transforms.

        Raises:
            AssertionError: if ``self.params`` still carries the legacy
                ``'mirror'`` key.
            KeyError: if ``do_gamma`` is truthy but ``p_gamma`` is missing
                (assuming ``self.params`` is a plain dict).
        """
        cfg = self.params
        assert cfg.get(
            'mirror'
        ) is None, "old version of params, use new keyword do_mirror"

        pipeline = []

        # Channel selection is skipped for missing *and* for empty selections.
        data_channels = cfg.get("selected_data_channels")
        if data_channels:
            pipeline.append(DataChannelSelectionTransform(data_channels))

        seg_channels = cfg.get("selected_seg_channels")
        if seg_channels:
            pipeline.append(SegChannelSelectionTransform(seg_channels))

        # In dummy-2D mode the spatial transform is sandwiched between the
        # 3D->2D and 2D->3D converters.
        fake_2d = cfg.get("dummy_2D", False)
        if fake_2d:
            pipeline.append(Convert3DTo2DTransform())

        # (SpatialTransform keyword, params key) pairs; every value is looked
        # up with .get(), so a missing key is forwarded as None.
        spatial_kwargs = {
            keyword: cfg.get(param_key)
            for keyword, param_key in (
                ("do_elastic_deform", "do_elastic"),
                ("alpha", "elastic_deform_alpha"),
                ("sigma", "elastic_deform_sigma"),
                ("do_rotation", "do_rotation"),
                ("angle_x", "rotation_x"),
                ("angle_y", "rotation_y"),
                ("angle_z", "rotation_z"),
                ("do_scale", "do_scaling"),
                ("scale", "scale_range"),
                ("order_data", "order_data"),
                ("border_mode_data", "border_mode_data"),
                ("border_cval_data", "border_cval_data"),
                ("order_seg", "order_seg"),
                ("border_mode_seg", "border_mode_seg"),
                ("border_cval_seg", "border_cval_seg"),
                ("random_crop", "random_crop"),
                ("p_el_per_sample", "p_eldef"),
                ("p_scale_per_sample", "p_scale"),
                ("p_rot_per_sample", "p_rot"),
                ("independent_scale_for_each_axis",
                 "independent_scale_factor_for_each_axis"),
            )
        }
        pipeline.append(
            SpatialTransform(self._spatial_transform_patch_size,
                             patch_center_dist_from_border=None,
                             **spatial_kwargs))

        if fake_2d:
            pipeline.append(Convert2DTo3DTransform())

        if cfg.get("do_gamma", False):
            gamma_range = cfg.get("gamma_range")
            keep_stats = cfg.get("gamma_retain_stats")
            # p_gamma is mandatory once gamma is enabled, hence the subscript.
            pipeline.append(
                GammaTransform(gamma_range,
                               False,
                               True,
                               retain_stats=keep_stats,
                               p_per_sample=cfg["p_gamma"]))

        if cfg.get("do_mirror", False):
            pipeline.append(MirrorTransform(cfg.get("mirror_axes")))

        norm_mask = cfg.get("use_mask_for_norm")
        if norm_mask:
            pipeline.append(
                MaskTransform(norm_mask, mask_idx_in_seg=0, set_outside_to=0))

        # Fixed tail: map label -1 to 0, expose 'seg' as 'target', convert to
        # float tensors.
        pipeline.extend([
            RemoveLabelTransform(-1, 0),
            RenameTransform('seg', 'target', True),
            NumpyToTensor(['data', 'target'], 'float'),
        ])
        return Compose(pipeline)
コード例 #22
0
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True):
    """Build the train/val batch generators with the heavy augmentation set.

    The training pipeline applies channel selection, a spatial transform,
    noise / blur / brightness / contrast / low-resolution / gamma transforms,
    optional mirroring and masking, the optional cascade transforms, and the
    final ``'seg'`` -> ``'target'`` rename and tensor conversion.  The
    validation pipeline only does label cleanup, channel selection, the
    optional seg-to-data move, the rename, optional deep-supervision
    downsampling and the tensor conversion.

    Args:
        dataloader_train: data source wrapped by the training augmenter.
        dataloader_val: data source wrapped by the validation augmenter.
        patch_size: first positional argument of ``SpatialTransform``.
        params: dict-like configuration (read via ``.get`` / ``[]``).
        border_val_seg: ``border_cval_seg`` for ``SpatialTransform``.
        seeds_train: forwarded as ``seeds`` to the training augmenter.
        seeds_val: forwarded as ``seeds`` to the validation augmenter.
        order_seg: ``order_seg`` for ``SpatialTransform``.
        order_data: ``order_data`` for ``SpatialTransform``.
        deep_supervision_scales: when not None, a deep-supervision
            downsampling transform on ``'target'`` is added to both pipelines.
        soft_ds: selects ``DownsampleSegForDSTransform3`` (requires
            ``classes``) instead of ``DownsampleSegForDSTransform2``.
        classes: passed to ``DownsampleSegForDSTransform3`` when ``soft_ds``.
        pin_memory: forwarded to both ``MultiThreadedAugmenter`` instances.

    Returns:
        Tuple ``(batchgenerator_train, batchgenerator_val)``.

    Raises:
        AssertionError: if ``params`` still carries the legacy ``'mirror'``
            key, or if ``soft_ds`` is set without ``classes``.
        KeyError: if ``do_gamma`` is truthy but ``p_gamma`` is missing
            (assuming ``params`` is a plain dict).
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # The assert above rejects the legacy 'mirror' key; the second operand
    # only matters when assertions are disabled (python -O).
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"):
            # channel_idx uses negative indices, i.e. the last
            # len(all_segmentation_labels) entries of 'data'.
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # The two params keys below used to be crossed: the size
                # threshold was passed as the fill probability and vice
                # versa.  Each keyword now receives its matching key.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # Validation: no augmentation, only the bookkeeping transforms.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation runs with half the training workers, but at least one.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
コード例 #23
0
def get_no_augmentation(dataloader_train,
                        dataloader_val,
                        params=default_3D_augmentation_params,
                        deep_supervision_scales=None,
                        soft_ds=False,
                        classes=None,
                        pin_memory=True,
                        regions=None):
    """
    Augmentation-free counterpart of get_default_augmentation.

    Both pipelines contain only bookkeeping transforms: channel selection,
    RemoveLabelTransform(-1, 0), the 'seg' -> 'target' rename, the optional
    region conversion, optional deep-supervision downsampling and the tensor
    conversion. Both augmenters are seeded with range(<number of workers>)
    and restarted before being returned.

    :param dataloader_train: data source for the training augmenter
    :param dataloader_val: data source for the validation augmenter
    :param params: dict-like configuration read via .get()
    :param deep_supervision_scales: if not None, adds a DownsampleSegForDS transform on 'target'
    :param soft_ds: use DownsampleSegForDSTransform3 (needs classes) instead of DownsampleSegForDSTransform2
    :param classes: passed to DownsampleSegForDSTransform3 when soft_ds is set
    :param pin_memory: forwarded to both MultiThreadedAugmenter instances
    :param regions: if not None, adds ConvertSegmentationToRegionsTransform on 'target'
    :return: (batchgenerator_train, batchgenerator_val)
    """

    def channel_selection():
        # Optional data / seg channel selection, in that order.
        picked = []
        data_channels = params.get("selected_data_channels")
        if data_channels is not None:
            picked.append(DataChannelSelectionTransform(data_channels))
        seg_channels = params.get("selected_seg_channels")
        if seg_channels is not None:
            picked.append(SegChannelSelectionTransform(seg_channels))
        return picked

    def target_stage():
        # Everything from the 'seg' -> 'target' rename to the tensor conversion.
        stage = [RenameTransform('seg', 'target', True)]
        if regions is not None:
            stage.append(
                ConvertSegmentationToRegionsTransform(regions, 'target',
                                                      'target'))
        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                stage.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales,
                                                 'target', 'target', classes))
            else:
                stage.append(
                    DownsampleSegForDSTransform2(deep_supervision_scales,
                                                 0,
                                                 0,
                                                 input_key='target',
                                                 output_key='target'))
        stage.append(NumpyToTensor(['data', 'target'], 'float'))
        return stage

    # Training: channel selection comes first, then label cleanup.
    train_pipeline = Compose(channel_selection() +
                             [RemoveLabelTransform(-1, 0)] + target_stage())
    train_workers = params.get('num_threads')
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        train_pipeline,
        train_workers,
        params.get("num_cached_per_thread"),
        seeds=range(train_workers),
        pin_memory=pin_memory)
    batchgenerator_train.restart()

    # Validation: label cleanup comes first, then channel selection.
    val_pipeline = Compose([RemoveLabelTransform(-1, 0)] +
                           channel_selection() + target_stage())
    # Half the training workers, but never fewer than one.
    val_workers = max(params.get('num_threads') // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_pipeline,
        val_workers,
        params.get("num_cached_per_thread"),
        seeds=range(val_workers),
        pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
コード例 #24
0
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the train/val batch generators with the default augmentation set.

    The training pipeline applies channel selection, a spatial transform
    (wrapped in 3D->2D / 2D->3D conversion when ``dummy_2D`` is set), optional
    gamma, mirroring and masking, the optional cascade transforms, the
    ``'seg'`` -> ``'target'`` rename, optional region conversion and tensor
    conversion.  The validation pipeline contains only the bookkeeping
    transforms.

    :param dataloader_train: data source wrapped by the training augmenter
    :param dataloader_val: data source wrapped by the validation augmenter
    :param patch_size: first positional argument of SpatialTransform
    :param params: dict-like configuration (read via .get() / [])
    :param border_val_seg: border_cval_seg for SpatialTransform
    :param pin_memory: forwarded to both MultiThreadedAugmenter instances
    :param seeds_train: forwarded as seeds to the training augmenter
    :param seeds_val: forwarded as seeds to the validation augmenter
    :param regions: if not None, adds ConvertSegmentationToRegionsTransform on 'target'
    :return: (batchgenerator_train, batchgenerator_val)
    :raises AssertionError: if params still carries the legacy 'mirror' key
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        # p_gamma is mandatory once gamma is enabled, hence the subscript.
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"):
            # channel_idx uses negative indices, i.e. the last len(all_segmentation_labels) entries of 'data'.
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # The two params keys below used to be crossed: the size threshold was passed as the fill
            # probability and vice versa. Each keyword now receives its matching key.
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    # Validation: no augmentation, only the bookkeeping transforms.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation runs with half the training workers, but at least one.
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
コード例 #25
0
# Training-only augmentation: a MirrorTransform over axes 0 and 1.
train_transforms = [MirrorTransform(axes=(0, 1))]
# Final conversion step appended to both the training and the validation
# pipeline: 'data' is cast to float32 and 'label' to int64.
post_transforms = [
    ToTensor(keys=('data', 'label'),
             dtypes=(('data', torch.float32), ('label', torch.int64)))
]


def init_fn(worker_id):
    """Seed NumPy's global RNG from the PyTorch seed of the current process.

    Written to be passed as ``worker_init_fn`` to a data loader.  Inside a
    worker process the seed is taken from ``torch.utils.data.get_worker_info()``;
    when no worker info is available (e.g. the function is called from the
    main process) it falls back to ``torch.initial_seed()`` instead of
    raising ``AttributeError``.

    Args:
        worker_id: index supplied by the loader; not used here because the
            seed is read from the worker info.
    """
    import numpy as np

    worker_info = torch.utils.data.get_worker_info()
    seed = torch.initial_seed() if worker_info is None else worker_info.seed
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap the (possibly
    # 64-bit) torch seed into that range, matching a cast to uint32.
    np.random.seed(seed % 2**32)


# Training split: the image folder at `train_path`, wrapped for the loader.
train_dset = PytorchDatasetWrapper(ImageFolder(train_path))
# Rebinds `train_transforms` from a plain list to a single Compose that runs
# pre_transforms, then the training-only transforms, then post_transforms.
# NOTE(review): `pre_transforms` is not defined in the visible part of this
# file -- confirm it is set up earlier.
train_transforms = Compose(pre_transforms + train_transforms + post_transforms)
# Shuffled batches of 32 with 4 workers; each worker runs init_fn at start-up.
train_data = DataLoader(train_dset,
                        shuffle=True,
                        batch_size=32,
                        num_workers=4,
                        batch_transforms=train_transforms,
                        collate_fn=numpy_collate,
                        pin_memory=False,
                        worker_init_fn=init_fn)

# Validation split: same wrapping, but the pipeline omits the training-only
# transforms (pre_transforms followed directly by post_transforms).
val_dset = PytorchDatasetWrapper(ImageFolder(val_path))
val_transforms = Compose(pre_transforms + post_transforms)
val_data = DataLoader(val_dset,
                      shuffle=False,
                      batch_size=32,
                      num_workers=4,
コード例 #26
0
def run(fold=0):
    print fold
    # =================================================================================================================
    I_AM_FOLD = fold
    np.random.seed(65432)
    lasagne.random.set_rng(np.random.RandomState(98765))
    sys.setrecursionlimit(2000)
    BATCH_SIZE = 2
    INPUT_PATCH_SIZE =(128, 128, 128)
    num_classes=4

    EXPERIMENT_NAME = "final"
    results_dir = os.path.join(paths.results_folder)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, EXPERIMENT_NAME)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, "fold%d"%I_AM_FOLD)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)

    n_epochs = 300
    lr_decay = np.float32(0.985)
    base_lr = np.float32(0.0005)
    n_batches_per_epoch = 100
    n_test_batches = 10
    n_feedbacks_per_epoch = 10.
    num_workers = 6
    workers_seeds = [123, 1234, 12345, 123456, 1234567, 12345678]

    # =================================================================================================================

    all_data = load_dataset()
    keys_sorted = np.sort(all_data.keys())

    crossval_folds = KFold(len(all_data.keys()), n_folds=5, shuffle=True, random_state=123456)

    ctr = 0
    for train_idx, test_idx in crossval_folds:
        print len(train_idx), len(test_idx)
        if ctr == I_AM_FOLD:
            train_keys = [keys_sorted[i] for i in train_idx]
            test_keys = [keys_sorted[i] for i in test_idx]
            break
        ctr += 1

    train_data = {i:all_data[i] for i in train_keys}
    test_data = {i:all_data[i] for i in test_keys}

    data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                           contrast_range=(0.75, 1.5), gamma_range = (0.8, 1.5),
                                           num_workers=num_workers, num_cached_per_worker=2,
                                           do_elastic_transform=True, alpha=(0., 1300.), sigma=(10., 13.),
                                           do_rotation=True, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                           do_scale=True, scale_range=(0.75, 1.25), seeds=workers_seeds)

    data_gen_validation = BatchGenerator3D_random_sampling(test_data, BATCH_SIZE, num_batches=None, seed=False,
                                                           patch_size=INPUT_PATCH_SIZE, convert_labels=True)
    val_transforms = []
    val_transforms.append(GenerateBrainMaskTransform())
    val_transforms.append(BrainMaskAwareStretchZeroOneTransform(clip_range=(-5, 5), per_channel=True))
    val_transforms.append(SegChannelSelectionTransform([0]))
    val_transforms.append(ConvertSegToOnehotTransform(range(4), 0, "seg_onehot"))
    val_transforms.append(DataChannelSelectionTransform([0, 1, 2, 3]))
    data_gen_validation = MultiThreadedAugmenter(data_gen_validation, Compose(val_transforms), 2, 2)

    x_sym = T.tensor5()
    seg_sym = T.matrix()

    net, seg_layer = build_net(x_sym, INPUT_PATCH_SIZE, num_classes, 4, 16, batch_size=BATCH_SIZE,
                               do_instance_norm=True)
    output_layer_for_loss = net

    # add some weight decay
    l2_loss = lasagne.regularization.regularize_network_params(output_layer_for_loss, lasagne.regularization.l2) * 1e-5

    # the distinction between prediction_train and test is important only if we enable dropout (batch norm/inst norm
    # does not use or save moving averages)
    prediction_train = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=False,
                                                 batch_norm_update_averages=False, batch_norm_use_averages=False)

    loss_vec = - soft_dice_per_img_in_batch(prediction_train, seg_sym, BATCH_SIZE)[:, 1:]

    loss = loss_vec.mean()
    loss += l2_loss
    acc_train = T.mean(T.eq(T.argmax(prediction_train, axis=1), seg_sym.argmax(-1)), dtype=theano.config.floatX)

    prediction_test = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=True,
                                                batch_norm_update_averages=False, batch_norm_use_averages=False)
    loss_val = - soft_dice_per_img_in_batch(prediction_test, seg_sym, BATCH_SIZE)[:, 1:]

    loss_val = loss_val.mean()
    loss_val += l2_loss
    acc = T.mean(T.eq(T.argmax(prediction_test, axis=1), seg_sym.argmax(-1)), dtype=theano.config.floatX)

    # learning rate has to be a shared variable because we decrease it with every epoch
    params = lasagne.layers.get_all_params(output_layer_for_loss, trainable=True)
    learning_rate = theano.shared(base_lr)
    updates = lasagne.updates.adam(T.grad(loss, params), params, learning_rate=learning_rate, beta1=0.9, beta2=0.999)

    dc = hard_dice_per_img_in_batch(prediction_test, seg_sym.argmax(1), num_classes, BATCH_SIZE).mean(0)

    train_fn = theano.function([x_sym, seg_sym], [loss, acc_train, loss_vec], updates=updates)
    val_fn = theano.function([x_sym, seg_sym], [loss_val, acc, dc])

    all_val_dice_scores=None

    all_training_losses = []
    all_validation_losses = []
    all_validation_accuracies = []
    all_training_accuracies = []
    val_dice_scores = []
    epoch = 0

    while epoch < n_epochs:
        if epoch == 100:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.85, 1.25), gamma_range = (0.8, 1.5),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 1000.), sigma=(10., 13.),
                                                   do_rotation=True, a_x=(0., 2*np.pi), a_y=(-np.pi/8., np.pi/8.),
                                                   a_z=(-np.pi/8., np.pi/8.), do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)

        if epoch == 175:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.9, 1.1), gamma_range = (0.85, 1.3),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 750.), sigma=(10., 13.),
                                                   do_rotation=True, a_x=(0., 2*np.pi), a_y=(-0.00001, 0.00001),
                                                   a_z=(-0.00001, 0.00001), do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)

        epoch_start_time = time.time()
        learning_rate.set_value(np.float32(base_lr* lr_decay**(epoch)))
        print "epoch: ", epoch, " learning rate: ", learning_rate.get_value()
        train_loss = 0
        train_acc_tmp = 0
        train_loss_tmp = 0
        batch_ctr = 0
        for data_dict in data_gen_train:
            data = data_dict["data"].astype(np.float32)
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            if batch_ctr != 0 and batch_ctr % int(np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)) == 0:
                print "number of batches: ", batch_ctr, "/", n_batches_per_epoch
                print "training_loss since last update: ", \
                    train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch), " train accuracy: ", \
                    train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)
                all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                train_loss_tmp = 0
                train_acc_tmp = 0
                if len(val_dice_scores) > 0:
                    all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
                try:
                    printLosses(all_training_losses, all_training_accuracies, all_validation_losses,
                                all_validation_accuracies, os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME),
                                n_feedbacks_per_epoch, val_dice_scores=all_val_dice_scores,
                                val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
                except:
                    pass
            loss_vec, acc, l = train_fn(data, seg)

            loss = loss_vec.mean()
            train_loss += loss
            train_loss_tmp += loss
            train_acc_tmp += acc
            batch_ctr += 1
            if batch_ctr >= n_batches_per_epoch:
                break
        all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        train_loss /= n_batches_per_epoch
        print "training loss average on epoch: ", train_loss

        val_loss = 0
        accuracies = []
        valid_batch_ctr = 0
        all_dice = []
        for data_dict in data_gen_validation:
            data = data_dict["data"].astype(np.float32)
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            w = np.zeros(num_classes, dtype=np.float32)
            w[np.unique(seg.argmax(-1))] = 1
            loss, acc, dice = val_fn(data, seg)
            dice[w==0] = 2
            all_dice.append(dice)
            val_loss += loss
            accuracies.append(acc)
            valid_batch_ctr += 1
            if valid_batch_ctr >= n_test_batches:
                break
        all_dice = np.vstack(all_dice)
        dice_means = np.zeros(num_classes)
        for i in range(num_classes):
            dice_means[i] = all_dice[all_dice[:, i]!=2, i].mean()
        val_loss /= n_test_batches
        print "val loss: ", val_loss
        print "val acc: ", np.mean(accuracies), "\n"
        print "val dice: ", dice_means
        print "This epoch took %f sec" % (time.time()-epoch_start_time)
        val_dice_scores.append(dice_means)
        all_validation_losses.append(val_loss)
        all_validation_accuracies.append(np.mean(accuracies))
        all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
        try:
            printLosses(all_training_losses, all_training_accuracies, all_validation_losses, all_validation_accuracies,
                        os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME), n_feedbacks_per_epoch,
                        val_dice_scores=all_val_dice_scores, val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
        except:
            pass
        with open(os.path.join(results_dir, "%s_Params.pkl" % (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump(lasagne.layers.get_all_param_values(output_layer_for_loss), f)
        with open(os.path.join(results_dir, "%s_allLossesNAccur.pkl"% (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump([all_training_losses, all_training_accuracies, all_validation_losses,
                          all_validation_accuracies, val_dice_scores], f)
        epoch += 1
コード例 #27
0
                      p_rot_per_axis=params.get("rotation_p_per_axis"),
                      do_scale=params.get("do_scaling"),
                      scale=params.get("scale_range"),
                      border_mode_data=params.get("border_mode_data"),
                      border_cval_data=0,
                      order_data=order_data,
                      border_mode_seg="constant",
                      border_cval_seg=border_val_seg,
                      order_seg=order_seg,
                      random_crop=params.get("random_crop"),
                      p_el_per_sample=params.get("p_eldef"),
                      p_scale_per_sample=params.get("p_scale"),
                      p_rot_per_sample=params.get("p_rot"),
                      independent_scale_for_each_axis=params.get(
                          "independent_scale_factor_for_each_axis")))
 tr_transforms = Compose(tr_transforms)
 batchgenerator_train = MultiThreadedAugmenter(
     dataloader_train,
     tr_transforms,
     params.get('num_threads'),
     params.get("num_cached_per_thread"),
     pin_memory=True)
 train_loader = batchgenerator_train
 train_batch = next(train_loader)
 print(
     train_batch.keys()
 )  # dict_keys(['data', 'target']), each with torch.Size([2, 1, 112, 240, 272])
 print((train_batch['seg'] - train_batch['weak_label']).sum())
 train_batch = next(train_loader)
 print((train_batch['seg'] - train_batch['weak_label']).sum())
 ipdb.set_trace()
コード例 #28
0
# Test-split dataset for the conditional GAN, built from `path_test_real` with the
# `load_sample_cgan_test` loader.
# NOTE(review): the two extension lists are presumably the accepted image and label
# file extensions -- confirm against ConditionalGanDataset's signature.
dataset_test = ConditionalGanDataset(path_test_real, load_sample_cgan_test, ['.PNG', '.png'], ['.PNG', '.png'])




### Transforms applied to data

from batchgenerators.transforms import RandomCropTransform, Compose
from batchgenerators.transforms.spatial_transforms import ResizeTransform,SpatialTransform


# Test-time pipeline: ResizeTransform is given a square target of side
# `args.resize_size`, followed by a RandomCropTransform with a square size of
# side `params.nested_get("image_size")`.
transforms = Compose([
    # Disabled fixed-angle rotation step (angle_x pinned to args.rot_angle); kept for reference.
    #SpatialTransform(patch_size=(1024, 1024), do_rotation=True, patch_center_dist_from_border=1024, border_mode_data='reflect',
     #                border_mode_seg='reflect', angle_x=(args.rot_angle, args.rot_angle), angle_y=(0, 0), angle_z=(0, 0),
      #               do_elastic_deform=False, order_data=1, order_seg=1)
    ResizeTransform((int(args.resize_size), int(args.resize_size)), order=1),
    RandomCropTransform((params.nested_get("image_size"), params.nested_get("image_size"))),
    ])


# NOTE(review): RandomSampler is imported but not used in the visible part of this script.
from delira.data_loading import BaseDataManager, SequentialSampler, RandomSampler
# Data manager for the test set: batches of `params.nested_get("batch_size")`, the
# transform pipeline above, a SequentialSampler, and n_process_augmentation=1.
manager_test = BaseDataManager(dataset_test, params.nested_get("batch_size"),
                              transforms=transforms,
                              sampler_cls=SequentialSampler,
                              n_process_augmentation=1)

import warnings
# Silence two whole warning categories for the rest of the process.
warnings.simplefilter("ignore", UserWarning) # ignore UserWarnings raised by dependency code
warnings.simplefilter("ignore", FutureWarning) # ignore FutureWarnings raised by dependency code
コード例 #29
0
def get_default_augmentation_withEDT(dataloader_train,
                                     dataloader_val,
                                     patch_size,
                                     idx_of_edts,
                                     params=default_3D_augmentation_params,
                                     border_val_seg=-1,
                                     pin_memory=True,
                                     seeds_train=None,
                                     seeds_val=None):
    """Build the training and validation augmenters, keeping the EDT channels apart.

    Both pipelines move the channels listed in ``idx_of_edts`` out of ``"data"``
    into a separate ``"bound"`` entry. In the training pipeline this happens
    after the spatial transform and before the gamma transform, so the moved
    channels are spatially augmented but not intensity augmented.

    Args:
        dataloader_train: Batch source handed to the training
            ``MultiThreadedAugmenter``.
        dataloader_val: Batch source handed to the validation
            ``MultiThreadedAugmenter``.
        patch_size: Output patch size passed to ``SpatialTransform``.
        idx_of_edts: Channel indices passed to ``AppendChannelsTransform`` to be
            moved from ``"data"`` to ``"bound"``.
        params: Mapping of augmentation settings, read with ``.get`` (and with
            ``params["p_gamma"]`` when ``"do_gamma"`` is truthy). The default
            object is only read here, never modified.
        border_val_seg: Passed to ``SpatialTransform`` as ``border_cval_seg``.
        pin_memory: Forwarded to both ``MultiThreadedAugmenter`` instances.
        seeds_train: Forwarded as ``seeds`` to the training augmenter.
        seeds_val: Forwarded as ``seeds`` to the validation augmenter.

    Returns:
        Tuple ``(batchgenerator_train, batchgenerator_val)``.

    Raises:
        KeyError: If ``"do_gamma"`` is truthy and ``params`` is a plain dict
            without a ``"p_gamma"`` entry.
    """
    selected_data_channels = params.get("selected_data_channels")
    selected_seg_channels = params.get("selected_seg_channels")
    # A plain truthiness test is equivalent to the former
    # ``x is not None and x`` double lookup.
    dummy_2d = params.get("dummy_2D")
    move_seg_to_data = params.get("move_last_seg_chanel_to_data")

    # ------------------------------------------------------------------
    # Training pipeline
    # ------------------------------------------------------------------
    tr_transforms = []

    if selected_data_channels is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(selected_seg_channels))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # Move the EDT channels to their own key *before* the intensity
    # transforms so that they do not get intensity transformed.
    tr_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    mask_was_used_for_normalization = params.get(
        "mask_was_used_for_normalization")
    if mask_was_used_for_normalization is not None:
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if move_seg_to_data:
        all_labels = params.get("all_segmentation_labels")
        tr_transforms.append(
            MoveSegAsOneHotToData(1, all_labels, 'seg', 'data'))
        # The original condition read ``x and not None and x``; ``not None``
        # is always True, so it was only ever a truthiness test.
        if params.get("advanced_pyramid_augmentations"):
            # Negative indices address the one-hot channels that were just
            # appended at the end of "data".
            one_hot_channels = list(range(-len(all_labels), 0))
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=one_hot_channels,
                    p_per_sample=0.4,
                    key="data",
                    strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(one_hot_channels),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # ------------------------------------------------------------------
    # Validation pipeline (no spatial / intensity / mirror augmentation)
    # ------------------------------------------------------------------
    val_transforms = [RemoveLabelTransform(-1, 0)]

    if selected_data_channels is not None:
        val_transforms.append(
            DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        val_transforms.append(
            SegChannelSelectionTransform(selected_seg_channels))

    # Move the EDT channels to their own key, mirroring the training pipeline.
    val_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if move_seg_to_data:
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half of the configured threads, but at least one.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
コード例 #30
0
 def run(self, img_data, seg_data):
     """Run the enabled augmentations ``self.cycles`` times and stack the batches.

     The pipeline is assembled from the instance switches in this fixed order:
     mirror, contrast, brightness, gamma, gaussian noise, and finally one
     spatial transform (rotation / scaling / elastic deformation / cropping).

     Args:
         img_data: Image batch handed to ``DataParser``. When cropping is
             disabled, the spatial patch shape is taken from
             ``img_data[0].shape[0:-1]``.
         seg_data: Segmentation batch handed to ``DataParser``.

     Returns:
         Tuple ``(aug_img_data, aug_seg_data)``; axis 1 of each array is moved
         to the last position (channel-first to channel-last) before returning.
     """
     # Wrap the raw arrays in a parser usable by the batchgenerators module
     source = DataParser(img_data, seg_data)

     pipeline = []
     # Mirroring
     if self.mirror:
         pipeline.append(MirrorTransform(axes=self.config_mirror_axes))
     # Contrast
     if self.contrast:
         pipeline.append(
             ContrastAugmentationTransform(
                 self.config_contrast_range,
                 preserve_range=True,
                 per_channel=True,
                 p_per_sample=self.config_p_per_sample))
     # Brightness
     if self.brightness:
         pipeline.append(
             BrightnessMultiplicativeTransform(
                 self.config_brightness_range,
                 per_channel=True,
                 p_per_sample=self.config_p_per_sample))
     # Gamma
     if self.gamma:
         pipeline.append(
             GammaTransform(self.config_gamma_range,
                            invert_image=False,
                            per_channel=True,
                            retain_stats=True,
                            p_per_sample=self.config_p_per_sample))
     # Gaussian noise
     if self.gaussian_noise:
         pipeline.append(
             GaussianNoiseTransform(
                 self.config_gaussian_noise_range,
                 p_per_sample=self.config_p_per_sample))
     # Spatial transform: rotation, scaling, elastic deformation, cropping
     if (self.rotations or self.scaling or self.elastic_deform
             or self.cropping):
         # Either the configured crop shape or the full image (last axis dropped)
         patch_shape = (self.cropping_patch_shape if self.cropping
                        else img_data[0].shape[0:-1])
         pipeline.append(
             SpatialTransform(
                 patch_shape, [i // 2 for i in patch_shape],
                 do_elastic_deform=self.elastic_deform,
                 alpha=self.config_elastic_deform_alpha,
                 sigma=self.config_elastic_deform_sigma,
                 do_rotation=self.rotations,
                 angle_x=self.config_rotations_angleX,
                 angle_y=self.config_rotations_angleY,
                 angle_z=self.config_rotations_angleZ,
                 do_scale=self.scaling,
                 scale=self.config_scaling_range,
                 border_mode_data='constant',
                 border_cval_data=0,
                 border_mode_seg='constant',
                 border_cval_seg=0,
                 order_data=3,
                 order_seg=0,
                 p_el_per_sample=self.config_p_per_sample,
                 p_rot_per_sample=self.config_p_per_sample,
                 p_scale_per_sample=self.config_p_per_sample,
                 random_crop=self.cropping))

     # Chain all transforms behind a single-threaded augmentation generator
     generator = SingleThreadedAugmenter(source, Compose(pipeline))

     # Draw one augmented batch per cycle; each new batch is placed in FRONT
     # of what has been collected so far.
     # NOTE(review): the concatenation is done inside the loop on purpose --
     # deferring it could change results if a batch aliases a buffer that a
     # later cycle modifies in place; confirm before restructuring.
     aug_img_data = None
     aug_seg_data = None
     for _ in range(self.cycles):
         batch = next(generator)
         if aug_img_data is None and aug_seg_data is None:
             aug_img_data, aug_seg_data = batch["data"], batch["seg"]
         else:
             aug_img_data = np.concatenate((batch["data"], aug_img_data),
                                           axis=0)
             aug_seg_data = np.concatenate((batch["seg"], aug_seg_data),
                                           axis=0)

     # Channel-first 3D: (batch, channel, x, y, z)
     # Channel-last 3D:  (batch, x, y, z, channel)
     return np.moveaxis(aug_img_data, 1, -1), np.moveaxis(aug_seg_data, 1, -1)