def test_image_pipeline_and_pin_memory(self):
        """
        Smoke test: the image pipeline must run through a pin-memory MultiThreadedAugmenter without crashing.

        Returns early (nothing to test) when torch is not installed. Otherwise pulls 50 batches and checks
        that the last batch's 'data' entry is a torch.Tensor living in pinned memory.
        """
        try:
            import torch
        except ImportError:
            # torch is optional for this package; without it there is nothing to check
            return

        from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose

        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
            NumpyToTensor(keys='data', cast_to='float'),
        ])

        # positional args mirror the original call: 4 workers, 1 cached batch each, no seeds, and the trailing
        # True is the pin-memory switch this test is about
        augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, True)

        for _ in range(50):
            last_batch = augmenter.next()

        assert isinstance(last_batch['data'], torch.Tensor)
        assert last_batch['data'].is_pinned()

        # give the workers time to finish filling their caches; otherwise they print a harmless but ugly error
        # when the test tears down
        sleep(2)
Beispiel #2
0
def get_transforms(mode="train", target_size=128):
    """Return the composed transform pipeline for one dataset split.

    * ``"train"``: ResizeTransform, MirrorTransform(axes=(1,)) and a SpatialTransform (elastic
      deformation, rotation and scale ranges as listed below) with a ``(target_size, target_size)`` patch.
    * ``"val"`` / ``"test"``: CenterCropTransform to ``target_size`` followed by ResizeTransform.
    * any other ``mode``: no transforms besides the final tensor conversion.

    Every pipeline ends with ``NumpyToTensor()``.

    :param mode: dataset split name
    :param target_size: size used for cropping, resizing and the spatial patch
    :return: a ``Compose`` of the selected transforms
    """
    if mode == "train":
        # train deliberately skips the center crop used by val/test
        steps = [
            ResizeTransform(target_size=target_size, order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(
                patch_size=(target_size, target_size),
                random_crop=False,
                patch_center_dist_from_border=target_size // 2,
                do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                do_rotation=True, p_rot_per_sample=0.5,
                angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                scale=(0.5, 1.9), p_scale_per_sample=0.5,
                border_mode_data="nearest", border_mode_seg="nearest",
            ),
        ]
    elif mode in ("val", "test"):
        steps = [
            CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    else:
        steps = []

    steps.append(NumpyToTensor())
    return Compose(steps)
Beispiel #3
0
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes,
                          num_workers=2, num_cached_per_worker=2,
                          do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                          do_rotation=False, a_x=(0., 2 * np.pi), a_y=(0., 2 * np.pi), a_z=(0., 2 * np.pi),
                          do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Create and start a multi-process training batch generator on 352x352 2D patches.

    Pipeline: MirrorTransform((0, 1)); with probability 0.67 a SpatialTransform, otherwise a
    RandomCropTransform to the same patch size; then the segmentation is one-hot encoded into
    ``'seg_onehot'``.

    :param patient_data_train: data handed to BatchGenerator_2D
    :param BATCH_SIZE: batch size of the underlying generator
    :param num_classes: number of classes for the one-hot conversion
    :param num_workers: number of augmenter workers
    :param num_cached_per_worker: number of batches each worker keeps cached
    :param do_elastic_transform, alpha, sigma: elastic deformation switch and parameters
    :param do_rotation, a_x, a_y, a_z: rotation switch and per-axis angle ranges
    :param do_scale, scale_range: scaling switch and range
    :param seeds: ``None`` (no seeds), ``'range'`` (seeds ``0..num_workers-1``) or a sequence with one seed
        per worker
    :return: the already-restarted MultiThreadedAugmenter
    """
    if seeds is None:
        worker_seeds = [None] * num_workers
    elif seeds == 'range':
        worker_seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
        worker_seeds = seeds

    patch_size = (352, 352)
    batch_gen = BatchGenerator_2D(patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
                                  PATCH_SIZE=patch_size)

    # SpatialTransform's positional arguments are kept in exactly the original order
    pipeline = Compose([
        MirrorTransform((0, 1)),
        RndTransform(SpatialTransform(patch_size, list(np.array(patch_size) // 2),
                                      do_elastic_transform, alpha, sigma,
                                      do_rotation, a_x, a_y, a_z,
                                      do_scale, scale_range, 'constant', 0, 3, 'constant', 0, 0,
                                      random_crop=False),
                     prob=0.67,
                     alternative_transform=RandomCropTransform(patch_size)),
        ConvertSegToOnehotTransform(range(num_classes), seg_channel=0, output_key='seg_onehot'),
    ])

    augmenter = MultiThreadedAugmenter(batch_gen, pipeline, num_workers, num_cached_per_worker, worker_seeds)
    augmenter.restart()
    return augmenter
Beispiel #4
0
def get_transforms(mode="train", target_size=128):
    """Return the composed transform pipeline for one dataset split.

    All known splits (``"train"``, ``"val"``, ``"test"``) resize, apply MirrorTransform(axes=(1,))
    and convert to tensors. They differ only in how the resize target is passed: train uses the
    tuple ``(target_size, target_size)``, val/test pass ``target_size`` unchanged. Any other ``mode``
    yields only the tensor conversion.

    :param mode: dataset split name
    :param target_size: resize target
    :return: a ``Compose`` of the selected transforms
    """
    if mode == "train":
        resize_to = (target_size, target_size)
    elif mode in ("val", "test"):
        resize_to = target_size
    else:
        return Compose([NumpyToTensor()])

    return Compose([
        ResizeTransform(target_size=resize_to, order=1),
        MirrorTransform(axes=(1, )),
        NumpyToTensor(),
    ])
Beispiel #5
0
    def __init__(self, im_4D):
        """Keep the image and set up the mirror and spatial transforms applied to it.

        :param im_4D: the image to augment, stored unchanged on ``self.im_4D``
        """
        self.im_4D = im_4D

        self.mirror_transform = MirrorTransform(axes=(0, 1, 2))

        # rotation only around z (x/y angle ranges are (0, 0)), random scaling, no elastic deformation,
        # no random crop; patch_size=None presumably keeps the input size -- confirm against SpatialTransform
        spatial_kwargs = dict(
            patch_size=None,
            random_crop=False,
            do_elastic_deform=False, alpha=(0., 1000.), sigma=(10., 13.),
            do_rotation=True, angle_x=(0, 0), angle_y=(0, 0), angle_z=(0, 2 * np.pi),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0, order_data=1,
        )
        self.spatial_transform = SpatialTransform(**spatial_kwargs)
Beispiel #6
0
    def test_image_pipeline(self):
        """
        Smoke test: mirroring + axis transposition through a MultiThreadedAugmenter must not crash.
        """
        from batchgenerators.transforms import MirrorTransform, TransposeAxesTransform, Compose

        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        ])
        # same positional args as the pin-memory variant, but with the trailing pin-memory flag off
        augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, False)

        for _ in range(50):
            augmenter.next()

        # give the workers time to finish filling their caches; otherwise they print a harmless but ugly error
        # when the test tears down
        sleep(2)
Beispiel #7
0
def get_augmentation(patch_size,
                     params=default_3D_augmentation_params,
                     border_val_seg=-1):
    """Compose a spatial + (optional) gamma + mirror augmentation pipeline.

    Prints the patch size, then builds: SpatialTransform configured from ``params``, a GammaTransform
    if ``params["do_gamma"]`` is truthy, and a MirrorTransform over ``params["mirror_axes"]``.

    :param patch_size: output patch size of the spatial transform
    :param params: augmentation parameter dict (read via ``params.get``; ``"p_gamma"`` is read with
        ``[]`` and therefore required when gamma is enabled)
    :param border_val_seg: constant used to pad the segmentation
    :return: the ``Compose`` of the transforms
    """
    print(f'patch size after augmentation {patch_size}')

    spatial = SpatialTransform(
        patch_size,
        patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"),
        scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"),
        border_cval_data=0,
        order_data=3,
        border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
    )
    steps = [spatial]

    if params.get("do_gamma"):
        gamma = GammaTransform(params.get("gamma_range"), False, True,
                               retain_stats=params.get("gamma_retain_stats"),
                               p_per_sample=params["p_gamma"])
        steps.append(gamma)

    steps.append(MirrorTransform(params.get("mirror_axes")))
    return Compose(steps)
Beispiel #8
0
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE, contrast_range=(0.75, 1.5),
                          gamma_range=(0.6, 2),
                          num_workers=5, num_cached_per_worker=3,
                          do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                          do_rotation=False, a_x=(0., 2 * np.pi), a_y=(0., 2 * np.pi), a_z=(0., 2 * np.pi),
                          do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Create and start a multi-process 3D training batch generator with brain-mask-aware augmentation.

    Samples (160, 192, 160) volumes, selects data channels 0-3, builds a brain mask, mirrors, applies a
    randomly cropped SpatialTransform to ``INPUT_PATCH_SIZE``, then intensity augmentation (stretch,
    contrast, gamma, stretch, brightness), selects seg channel 0 and one-hot encodes it into
    ``"seg_onehot"``.

    :param patient_data_train: data handed to BatchGenerator3D_random_sampling
    :param INPUT_PATCH_SIZE: output patch size of the spatial transform
    :param num_classes: number of classes for the one-hot conversion
    :param BATCH_SIZE: batch size of the underlying generator
    :param contrast_range: range for ContrastAugmentationTransform
    :param gamma_range: range for GammaTransform
    :param num_workers: number of augmenter workers
    :param num_cached_per_worker: number of batches each worker keeps cached
    :param do_elastic_transform, alpha, sigma: elastic deformation switch and parameters
    :param do_rotation, a_x, a_y, a_z: rotation switch and per-axis angle ranges
    :param do_scale, scale_range: scaling switch and range
    :param seeds: ``None`` (no seeds), ``'range'`` (seeds ``0..num_workers-1``) or a sequence with one seed
        per worker
    :return: the already-restarted MultiThreadedAugmenter
    """
    if seeds is None:
        worker_seeds = [None] * num_workers
    elif seeds == 'range':
        worker_seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
        worker_seeds = seeds

    batch_gen = BatchGenerator3D_random_sampling(patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
                                                 patch_size=(160, 192, 160), convert_labels=True)

    pipeline = Compose([
        DataChannelSelectionTransform([0, 1, 2, 3]),
        GenerateBrainMaskTransform(),
        MirrorTransform(),
        # note: `// 2.` (float floor division) is kept as-is, so the border distances are floats
        SpatialTransform(INPUT_PATCH_SIZE, list(np.array(INPUT_PATCH_SIZE) // 2.),
                         do_elastic_deform=do_elastic_transform, alpha=alpha, sigma=sigma,
                         do_rotation=do_rotation, angle_x=a_x, angle_y=a_y, angle_z=a_z,
                         do_scale=do_scale, scale=scale_range,
                         border_mode_data='nearest', border_cval_data=0, order_data=3,
                         border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                         random_crop=True),
        BrainMaskAwareStretchZeroOneTransform((-5, 5), True),
        ContrastAugmentationTransform(contrast_range, True),
        GammaTransform(gamma_range, False),
        BrainMaskAwareStretchZeroOneTransform(per_channel=True),
        BrightnessTransform(0.0, 0.1, True),
        SegChannelSelectionTransform([0]),
        ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"),
    ])

    augmenter = MultiThreadedAugmenter(batch_gen, pipeline, num_workers, num_cached_per_worker, worker_seeds)
    augmenter.restart()
    return augmenter
def get_default_augmentation_withEDT(dataloader_train,
                                     dataloader_val,
                                     patch_size,
                                     idx_of_edts,
                                     params=default_3D_augmentation_params,
                                     border_val_seg=-1,
                                     pin_memory=True,
                                     seeds_train=None,
                                     seeds_val=None):
    """Build training and validation MultiThreadedAugmenters whose batches also carry the EDT under 'bound'.

    Same layout as the default augmentation pipeline, except that the ``'data'`` channels listed in
    ``idx_of_edts`` are moved into a separate ``'bound'`` key (via AppendChannelsTransform with
    ``remove_from_input=True``), so they are not intensity transformed. ``'bound'`` is converted to a float
    tensor together with ``'data'`` and ``'target'``.

    Training: optional data/seg channel selection, spatial transform (wrapped in 3D->2D->3D when
    ``params["dummy_2D"]`` is truthy), mirroring, EDT split, optional gamma, optional masking,
    RemoveLabelTransform(-1, 0), optional one-hot seg move into 'data' plus optional pyramid augmentations,
    rename 'seg' -> 'target', tensor conversion.
    Validation: RemoveLabelTransform(-1, 0), optional channel selection, EDT split, optional one-hot seg
    move, rename, tensor conversion.

    :param dataloader_train: data loader feeding the training augmenter
    :param dataloader_val: data loader feeding the validation augmenter
    :param patch_size: output patch size of the spatial transform
    :param idx_of_edts: indices (within 'data') of the channels to move to 'bound' -- presumably the EDT maps
    :param params: augmentation parameter dict (read via ``params.get``; ``"p_gamma"`` is read with ``[]``)
    :param border_val_seg: constant used to pad the segmentation in the spatial transform
    :param pin_memory: forwarded to both MultiThreadedAugmenters
    :param seeds_train: worker seeds for the training augmenter
    :param seeds_val: worker seeds for the validation augmenter
    :return: tuple ``(batchgenerator_train, batchgenerator_val)``
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # Mirror while the EDT channels are still part of 'data'. batchgenerators' MirrorTransform only flips its
    # data/label keys ('data'/'seg' by default), so mirroring after the EDT has been moved to 'bound' would
    # leave it unflipped and misaligned with the mirrored image and segmentation. Gamma is applied voxel-wise,
    # so running it after the flip instead of before yields the same kind of augmentation.
    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    # Move the EDT to a different key so that it does not get intensity transformed (gamma below).
    tr_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(
            MaskTransform(params.get("mask_was_used_for_normalization"),
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # plain truthiness: the former `x and not None and x` had an always-true `not None` in the middle
        if params.get("advanced_pyramid_augmentations"):
            # the one-hot channels appended by MoveSegAsOneHotToData are addressed from the end of 'data'
            num_labels = len(params.get("all_segmentation_labels"))
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(channel_idx=list(range(-num_labels, 0)),
                                                   p_per_sample=0.4,
                                                   key="data",
                                                   strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Move the EDT to a different key, mirroring the training pipeline's batch layout.
    val_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the training workers (at least one)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the default training and validation MultiThreadedAugmenters.

    Training pipeline, in order: optional data/seg channel selection; SpatialTransform configured from
    ``params`` (wrapped in 3D->2D->3D conversion when ``params["dummy_2D"]`` is truthy); optional gamma;
    optional mirroring (``params["do_mirror"]``); optional MaskTransform; RemoveLabelTransform(-1, 0);
    optionally move the segmentation as one-hot into 'data' and apply the cascade augmentations to it;
    rename 'seg' -> 'target'; optional region conversion; conversion of 'data'/'target' to float tensors.
    The validation pipeline only does the channel/label bookkeeping steps, with no spatial or intensity
    augmentation.

    :param dataloader_train: data loader feeding the training augmenter
    :param dataloader_val: data loader feeding the validation augmenter
    :param patch_size: output patch size of the spatial transform
    :param params: augmentation parameter dict (read via ``params.get``; ``"p_gamma"`` is read with ``[]``)
    :param border_val_seg: constant used to pad the segmentation in the spatial transform
    :param pin_memory: forwarded to both MultiThreadedAugmenters
    :param seeds_train: worker seeds for the training augmenter
    :param seeds_val: worker seeds for the validation augmenter
    :param regions: if not None, 'target' is converted with ConvertSegmentationToRegionsTransform(regions, ...)
    :return: tuple ``(batchgenerator_train, batchgenerator_val)``
    :raises AssertionError: if ``params`` still uses the legacy ``'mirror'`` key
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # plain truthiness: the former `x and not None and x` had an always-true `not None` in the middle
        if params.get("cascade_do_cascade_augmentations"):
            # the one-hot channels appended above are addressed from the end of 'data'
            num_labels = len(params.get("all_segmentation_labels"))
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-num_labels, 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # each argument reads the param key of the same meaning (these two used to be swapped, which fed
            # the fill probability into the size threshold and vice versa)
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-num_labels, 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the training workers (at least one)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
Beispiel #11
0
def Transforms(patch_size,
               params=default_3D_augmentation_params,
               border_val_seg=-1):
    """Compose the default training transform pipeline without wrapping it in an augmenter.

    In order: optional data/seg channel selection; SpatialTransform configured from ``params`` (wrapped in
    3D->2D->3D conversion when ``params["dummy_2D"]`` is truthy); optional gamma; MirrorTransform over
    ``params["mirror_axes"]``; optional MaskTransform; RemoveLabelTransform(-1, 0); optionally move the
    segmentation as one-hot into 'data' and apply the pyramid augmentations to it; rename 'seg' -> 'target';
    convert 'data'/'target' to float tensors.

    :param patch_size: output patch size of the spatial transform
    :param params: augmentation parameter dict (read via ``params.get``; ``"p_gamma"`` is read with ``[]``)
    :param border_val_seg: constant used to pad the segmentation in the spatial transform
    :return: the ``Compose`` of all transforms
    """
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels"),
                                          data_key="data"))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(
            MaskTransform(params.get("mask_was_used_for_normalization"),
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # plain truthiness: the former `x and not None and x` had an always-true `not None` in the middle
        if params.get("advanced_pyramid_augmentations"):
            # the one-hot channels appended above are addressed from the end of 'data'
            num_labels = len(params.get("all_segmentation_labels"))
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(channel_idx=list(range(-num_labels, 0)),
                                                   p_per_sample=0.4,
                                                   key="data",
                                                   strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    return tr_transforms
Beispiel #12
0
    def get_training_transforms(self):
        assert self.params.get(
            'mirror'
        ) is None, "old version of params, use new keyword do_mirror"

        tr_transforms = []

        if self.params.get("selected_data_channels"):
            tr_transforms.append(
                DataChannelSelectionTransform(
                    self.params.get("selected_data_channels")))
        if self.params.get("selected_seg_channels"):
            tr_transforms.append(
                SegChannelSelectionTransform(
                    self.params.get("selected_seg_channels")))

        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        if self.params.get("dummy_2D", False):
            ignore_axes = (0, )
            tr_transforms.append(Convert3DTo2DTransform())
        else:
            ignore_axes = None

        tr_transforms.append(
            SpatialTransform(
                self._spatial_transform_patch_size,
                patch_center_dist_from_border=None,
                do_elastic_deform=self.params.get("do_elastic"),
                alpha=self.params.get("elastic_deform_alpha"),
                sigma=self.params.get("elastic_deform_sigma"),
                do_rotation=self.params.get("do_rotation"),
                angle_x=self.params.get("rotation_x"),
                angle_y=self.params.get("rotation_y"),
                angle_z=self.params.get("rotation_z"),
                do_scale=self.params.get("do_scaling"),
                scale=self.params.get("scale_range"),
                order_data=self.params.get("order_data"),
                border_mode_data=self.params.get("border_mode_data"),
                border_cval_data=self.params.get("border_cval_data"),
                order_seg=self.params.get("order_seg"),
                border_mode_seg=self.params.get("border_mode_seg"),
                border_cval_seg=self.params.get("border_cval_seg"),
                random_crop=self.params.get("random_crop"),
                p_el_per_sample=self.params.get("p_eldef"),
                p_scale_per_sample=self.params.get("p_scale"),
                p_rot_per_sample=self.params.get("p_rot"),
                independent_scale_for_each_axis=self.params.get(
                    "independent_scale_factor_for_each_axis"),
            ))

        if self.params.get("dummy_2D"):
            tr_transforms.append(Convert2DTo3DTransform())

        # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
        # channel gets in the way
        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
        tr_transforms.append(
            GaussianBlurTransform((0.5, 1.5),
                                  different_sigma_per_channel=True,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5), )
        tr_transforms.append(
            BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3),
                                              p_per_sample=0.15))
        if self.params.get("do_additive_brightness"):
            tr_transforms.append(
                BrightnessTransform(
                    self.params.get("additive_brightness_mu"),
                    self.params.get("additive_brightness_sigma"),
                    True,
                    p_per_sample=self.params.get(
                        "additive_brightness_p_per_sample"),
                    p_per_channel=self.params.get(
                        "additive_brightness_p_per_channel")))
        tr_transforms.append(
            ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                          p_per_sample=0.15))
        tr_transforms.append(
            SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                           per_channel=True,
                                           p_per_channel=0.5,
                                           order_downsample=0,
                                           order_upsample=3,
                                           p_per_sample=0.25,
                                           ignore_axes=ignore_axes), )
        tr_transforms.append(
            GammaTransform(self.params.get("gamma_range"),
                           True,
                           True,
                           retain_stats=self.params.get("gamma_retain_stats"),
                           p_per_sample=0.15))  # inverted gamma

        if self.params.get("do_gamma"):
            tr_transforms.append(
                GammaTransform(
                    self.params.get("gamma_range"),
                    False,
                    True,
                    retain_stats=self.params.get("gamma_retain_stats"),
                    p_per_sample=self.params["p_gamma"]))
        if self.params.get("do_mirror") or self.params.get("mirror"):
            tr_transforms.append(
                MirrorTransform(self.params.get("mirror_axes")))
        if self.params.get("use_mask_for_norm"):
            use_mask_for_norm = self.params.get("use_mask_for_norm")
            tr_transforms.append(
                MaskTransform(use_mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        tr_transforms.append(RemoveLabelTransform(-1, 0))
        tr_transforms.append(RenameTransform('seg', 'target', True))
        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return Compose(tr_transforms)
Beispiel #13
0
    def get_training_transforms(self):
        """Assemble the training augmentation pipeline from ``self.params``.

        Order: optional data/seg channel selection, optional 3D->2D
        conversion, ``SpatialTransform`` (all of its options read from
        ``self.params``), optional 2D->3D conversion, optional gamma, optional
        mirroring, optional masking, then label clean-up, renaming ``seg`` to
        ``target`` and conversion to float tensors.

        Returns
        -------
        Compose
            The composed training transforms.
        """
        params = self.params
        assert params.get(
            'mirror'
        ) is None, "old version of params, use new keyword do_mirror"

        pipeline = []

        data_channels = params.get("selected_data_channels")
        if data_channels:
            pipeline.append(DataChannelSelectionTransform(data_channels))

        seg_channels = params.get("selected_seg_channels")
        if seg_channels:
            pipeline.append(SegChannelSelectionTransform(seg_channels))

        dummy_2d = params.get("dummy_2D", False)
        if dummy_2d:
            # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
            pipeline.append(Convert3DTo2DTransform())

        # (SpatialTransform keyword, key in params) pairs.
        spatial_keys = (
            ("do_elastic_deform", "do_elastic"),
            ("alpha", "elastic_deform_alpha"),
            ("sigma", "elastic_deform_sigma"),
            ("do_rotation", "do_rotation"),
            ("angle_x", "rotation_x"),
            ("angle_y", "rotation_y"),
            ("angle_z", "rotation_z"),
            ("do_scale", "do_scaling"),
            ("scale", "scale_range"),
            ("order_data", "order_data"),
            ("border_mode_data", "border_mode_data"),
            ("border_cval_data", "border_cval_data"),
            ("order_seg", "order_seg"),
            ("border_mode_seg", "border_mode_seg"),
            ("border_cval_seg", "border_cval_seg"),
            ("random_crop", "random_crop"),
            ("p_el_per_sample", "p_eldef"),
            ("p_scale_per_sample", "p_scale"),
            ("p_rot_per_sample", "p_rot"),
            ("independent_scale_for_each_axis",
             "independent_scale_factor_for_each_axis"),
        )
        spatial_kwargs = {arg: params.get(key) for arg, key in spatial_keys}
        pipeline.append(
            SpatialTransform(self._spatial_transform_patch_size,
                             patch_center_dist_from_border=None,
                             **spatial_kwargs))

        if dummy_2d:
            pipeline.append(Convert2DTo3DTransform())

        if params.get("do_gamma", False):
            pipeline.append(
                GammaTransform(params.get("gamma_range"),
                               False,
                               True,
                               retain_stats=params.get("gamma_retain_stats"),
                               p_per_sample=params["p_gamma"]))

        if params.get("do_mirror", False):
            pipeline.append(MirrorTransform(params.get("mirror_axes")))

        mask_for_norm = params.get("use_mask_for_norm")
        if mask_for_norm:
            pipeline.append(
                MaskTransform(mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        pipeline += [
            RemoveLabelTransform(-1, 0),
            RenameTransform('seg', 'target', True),
            NumpyToTensor(['data', 'target'], 'float'),
        ]
        return Compose(pipeline)
Beispiel #14
0
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True):
    """Build the "insaneDA" training and validation batch generators.

    Parameters
    ----------
    dataloader_train, dataloader_val
        Data loaders wrapped by the returned ``MultiThreadedAugmenter``s.
    patch_size
        Patch size passed to ``SpatialTransform``.
    params : dict
        Augmentation configuration, read via ``params.get`` / ``params[...]``.
    border_val_seg : int
        Constant fill value for the segmentation outside the image border.
    seeds_train, seeds_val
        Seeds forwarded to the training / validation augmenter.
    order_seg, order_data : int
        Interpolation orders for segmentation and data in ``SpatialTransform``.
    deep_supervision_scales
        If not None, ``target`` is additionally downsampled to these scales.
    soft_ds : bool
        Use ``DownsampleSegForDSTransform3`` (requires ``classes``) instead of
        ``DownsampleSegForDSTransform2``.
    classes
        Class list, only used when ``soft_ds`` is True.
    pin_memory : bool
        Forwarded to both augmenters.

    Returns
    -------
    tuple
        ``(batchgenerator_train, batchgenerator_val)``.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    def _deep_supervision_transform():
        # One fresh instance per pipeline (train and val each get their own).
        if soft_ds:
            assert classes is not None
            return DownsampleSegForDSTransform3(deep_supervision_scales,
                                                'target', 'target', classes)
        return DownsampleSegForDSTransform2(deep_supervision_scales,
                                            0,
                                            0,
                                            input_key='target',
                                            output_key='target')

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        # Each keyword now receives the param whose name
                        # matches it (they used to be swapped).
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        tr_transforms.append(_deep_supervision_transform())
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        val_transforms.append(_deep_supervision_transform())

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
Beispiel #15
0
 def run(self, img_data, seg_data):
     """Augment ``img_data``/``seg_data`` ``self.cycles`` times.

     The enabled augmentations (mirror, contrast, brightness, gamma, gaussian
     noise, spatial) are composed and run through a
     ``SingleThreadedAugmenter`` over a ``DataParser`` of the inputs.

     Parameters
     ----------
     img_data, seg_data
         Image and segmentation data handed to ``DataParser``.

     Returns
     -------
     tuple
         ``(aug_img_data, aug_seg_data)``: the batches of all cycles
         concatenated along axis 0 (most recent cycle first), with axis 1
         moved to the last position (channel-first -> channel-last).

     Raises
     ------
     ValueError
         If ``self.cycles`` is smaller than 1, i.e. nothing was augmented.
     """
     # Create a parser for the batchgenerators module
     data_generator = DataParser(img_data, seg_data)
     # Initialize empty transform list
     transforms = []
     # Add mirror augmentation
     if self.mirror:
         aug_mirror = MirrorTransform(axes=self.config_mirror_axes)
         transforms.append(aug_mirror)
     # Add contrast augmentation
     if self.contrast:
         aug_contrast = ContrastAugmentationTransform(
             self.config_contrast_range,
             preserve_range=True,
             per_channel=True,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_contrast)
     # Add brightness augmentation
     if self.brightness:
         aug_brightness = BrightnessMultiplicativeTransform(
             self.config_brightness_range,
             per_channel=True,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_brightness)
     # Add gamma augmentation
     if self.gamma:
         aug_gamma = GammaTransform(self.config_gamma_range,
                                    invert_image=False,
                                    per_channel=True,
                                    retain_stats=True,
                                    p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gamma)
     # Add gaussian noise augmentation
     if self.gaussian_noise:
         aug_gaussian_noise = GaussianNoiseTransform(
             self.config_gaussian_noise_range,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gaussian_noise)
     # Add spatial transformations as augmentation
     # (rotation, scaling, elastic deformation)
     if self.rotations or self.scaling or self.elastic_deform or \
         self.cropping:
         # Identify patch shape (full image or cropping)
         if self.cropping:
             patch_shape = self.cropping_patch_shape
         else:
             patch_shape = img_data[0].shape[0:-1]
         # Assembling the spatial transformation
         aug_spatial_transform = SpatialTransform(
             patch_shape, [i // 2 for i in patch_shape],
             do_elastic_deform=self.elastic_deform,
             alpha=self.config_elastic_deform_alpha,
             sigma=self.config_elastic_deform_sigma,
             do_rotation=self.rotations,
             angle_x=self.config_rotations_angleX,
             angle_y=self.config_rotations_angleY,
             angle_z=self.config_rotations_angleZ,
             do_scale=self.scaling,
             scale=self.config_scaling_range,
             border_mode_data='constant',
             border_cval_data=0,
             border_mode_seg='constant',
             border_cval_seg=0,
             order_data=3,
             order_seg=0,
             p_el_per_sample=self.config_p_per_sample,
             p_rot_per_sample=self.config_p_per_sample,
             p_scale_per_sample=self.config_p_per_sample,
             random_crop=self.cropping)
         # Append spatial transformation to transformation list
         transforms.append(aug_spatial_transform)
     # Compose the batchgenerators transforms
     all_transforms = Compose(transforms)
     # Assemble transforms into a augmentation generator
     augmentation_generator = SingleThreadedAugmenter(
         data_generator, all_transforms)
     # Perform the data augmentation x times (x = cycles); collect the
     # batches and concatenate once instead of re-concatenating every cycle
     img_batches = []
     seg_batches = []
     for _ in range(self.cycles):
         augmentation = next(augmentation_generator)
         img_batches.append(augmentation["data"])
         seg_batches.append(augmentation["seg"])
     if not img_batches:
         raise ValueError("cycles must be at least 1 to produce augmented data")
     # Most recent cycle first, matching the previous prepend-order behaviour
     aug_img_data = np.concatenate(img_batches[::-1], axis=0)
     aug_seg_data = np.concatenate(seg_batches[::-1], axis=0)
     # Transform data from channel-first back to channel-last structure
     # Data structure channel-first 3D:  (batch, channel, x, y, z)
     # Data structure channel-last 3D:   (batch, x, y, z, channel)
     aug_img_data = np.moveaxis(aug_img_data, 1, -1)
     aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)
     # Return augmentated image and segmentation data
     return aug_img_data, aug_seg_data
Beispiel #16
0
def get_transforms(mode="train",
                   n_channels=1,
                   target_size=128,
                   add_resize=False,
                   add_noise=False,
                   mask_type="",
                   batch_size=16,
                   rotate=True,
                   elastic_deform=True,
                   rnd_crop=False,
                   color_augment=True):
    """Return a ``Compose`` of batchgenerators transforms for ``mode``.

    ``"train"``: pad, resize, mirror, spatial augmentation (wrapped in two
    reshapes), optional brightness augmentation, gaussian noise and clipping
    to [-1.5, 1.5]. With ``mask_type == "gaussian"`` and ``add_noise`` an
    extra, stronger gaussian noise transform is appended.
    ``"val"``: pad, resize, center crop, clipping, and a copy of ``data`` to
    ``data_clean``.
    Any other mode: only the tensor conversion.

    ``NumpyToTensor`` is always the last transform. ``add_resize`` is
    accepted but not used.
    """

    def _pad_and_resize():
        # Common first steps of the "train" and "val" pipelines.
        return [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),
        ]

    pipeline = []
    extra_noise = []

    if mode == "train":
        pipeline = _pad_and_resize() + [
            MirrorTransform(axes=(2, )),
            ReshapeTransform(new_shape=(1, -1, "h", "w")),
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=rnd_crop,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=elastic_deform,
                             alpha=(0., 100.),
                             sigma=(10., 13.),
                             do_rotation=rotate,
                             angle_x=(-0.1, 0.1),
                             angle_y=(0, 1e-8),
                             angle_z=(0, 1e-8),
                             scale=(0.9, 1.2),
                             border_mode_data="nearest",
                             border_mode_seg="nearest"),
            ReshapeTransform(new_shape=(batch_size, -1, "h", "w")),
        ]
        if color_augment:
            pipeline.append(
                BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1)))
        pipeline.extend([
            GaussianNoiseTransform(noise_variance=(0., 0.05)),
            ClipValueRange(min=-1.5, max=1.5),
        ])
        if mask_type == "gaussian":
            extra_noise.append(GaussianNoiseTransform(noise_variance=(0., 0.2)))
    elif mode == "val":
        pipeline = _pad_and_resize() + [
            CenterCropTransform(crop_size=(target_size, target_size)),
            ClipValueRange(min=-1.5, max=1.5),
            CopyTransform({"data": "data_clean"}, copy=True),
        ]

    if add_noise:
        pipeline = pipeline + extra_noise

    pipeline.append(NumpyToTensor())
    return Compose(pipeline)
# Train / validation split directories. ``PATH`` is defined outside this
# view; presumably a ``pathlib.Path`` since it is combined with ``/`` -- TODO confirm.
train_path = PATH / 'train'
val_path = PATH / 'val'

# %%
import torch
import numpy as np
import random
from torchvision.datasets import ImageFolder
from pltools.data import ToTensor
from data_loading.experimental import DataLoader
from data_loading import numpy_collate

from batchgenerators.transforms import ZeroMeanUnitVarianceTransform, Compose, MirrorTransform

# Transforms applied before augmentation.
pre_transforms = [
    ZeroMeanUnitVarianceTransform(),
]
# Augmentations used only for training.
train_transforms = [
    MirrorTransform(axes=(0, 1)),
]
# Final conversion of 'data' (float32) and 'label' (int64) to torch tensors.
post_transforms = [
    ToTensor(
        keys=('data', 'label'),
        dtypes=(('data', torch.float32), ('label', torch.int64)),
    ),
]


def init_fn(worker_id):
    """Seed NumPy's global RNG from torch's seed (``DataLoader`` worker_init_fn).

    Inside a ``DataLoader`` worker, ``torch.initial_seed()`` returns that
    worker's seed (the same value as the worker info's ``seed``). Unlike the
    private ``torch.utils.data._utils.worker._worker_info`` it also works
    outside of a worker. The seed is reduced modulo 2**32 because
    ``np.random.seed`` only accepts 32-bit unsigned values.

    Parameters
    ----------
    worker_id : int
        Worker index passed by ``DataLoader``; not needed because the torch
        seed of the worker already depends on it.
    """
    np.random.seed(torch.initial_seed() % 2**32)


train_dset = PytorchDatasetWrapper(ImageFolder(train_path))
# NOTE: rebinds ``train_transforms`` from the list of training augmentations
# to the composed pipeline pre -> train -> post.
train_transforms = Compose(
    [*pre_transforms, *train_transforms, *post_transforms])
def run_experiment(cp: str, test=True) -> str:
    """
    Run classification experiment on patches

    Parameters
    ----------
    cp : str
        path to config file
    test : bool
        if True, evaluate the trained model on the test set and on the
        validation set and save ROC curves as ``test_roc.pdf`` and
        ``best_val_roc.pdf`` in the experiment folder

    Returns
    -------
    str
        path to experiment folder
    """
    # setup config
    ch = ConfigHandlerPyTorchDelira(cp)
    ch = feature_map_params(ch)

    if 'mixed_precision' not in ch or ch['mixed_precision'] is None:
        ch['mixed_precision'] = True
    if 'debug_delira' in ch and ch['debug_delira'] is not None:
        delira.set_debug_mode(ch['debug_delira'])
        print("Debug mode active: settings n_process_augmentation to 1!")
        ch['augment.n_process'] = 1

    dset_keys = ['train', 'val', 'test']

    losses = {'class_ce': torch.nn.CrossEntropyLoss()}
    train_metrics = {}
    val_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}
    test_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}

    #########################
    #   Setup Parameters    #
    #########################
    params_dict = ch.get_params(losses=losses,
                                train_metrics=train_metrics,
                                val_metrics=val_metrics,
                                add_self=ch['add_config_to_params'])
    params = Parameters(**params_dict)

    #################
    #   Setup IO    #
    #################
    # setup io
    load_sample = load_pickle
    load_fn = LoadPatches(load_fn=load_sample,
                          patch_size=ch['patch_size'],
                          **ch['data.load_patch'])

    datasets = {}
    for key in dset_keys:
        p = os.path.join(ch["data.path"], str(key))

        datasets[key] = BaseExtendCacheDataset(p,
                                               load_fn=load_fn,
                                               **ch['data.kwargs'])

    #############################
    #   Setup Transformations   #
    #############################
    base_transforms = []
    base_transforms.append(PopKeys("mapping"))

    train_transforms = []
    if ch['augment.mode']:
        logger.info("Training augmentation enabled.")
        train_transforms.append(
            SpatialTransform(patch_size=ch['patch_size'],
                             **ch['augment.kwargs']))
        train_transforms.append(MirrorTransform(axes=(0, 1)))
    process = ch['augment.n_process'] if 'augment.n_process' in ch else 1

    #########################
    #   Setup Datamanagers  #
    #########################
    datamanagers = {}
    for key in dset_keys:
        if key == 'train':
            trafos = base_transforms + train_transforms
            sampler = WeightedPrevalenceRandomSampler
        else:
            trafos = base_transforms
            sampler = SequentialSampler

        datamanagers[key] = BaseDataManager(
            data=datasets[key],
            batch_size=params.nested_get('batch_size'),
            n_process_augmentation=process,
            transforms=Compose(trafos),
            sampler_cls=sampler,
        )

    #############################
    #   Initialize Experiment   #
    #############################
    experiment = \
        PyTorchExperiment(
            params=params,
            model_cls=ClassNetwork,
            name=ch['exp.name'],
            save_path=ch['exp.dir'],
            optim_builder=create_optims_default_pytorch,
            trainer_cls=PyTorchNetworkTrainer,
            mixed_precision=ch['mixed_precision'],
            mixed_precision_kwargs={'verbose': False},
            key_mapping={"input_batch": "data"},
            **ch['exp.kwargs'],
        )

    # save configurations
    ch.dump(os.path.join(experiment.save_path, 'config.json'))

    #################
    #   Training    #
    #################
    model = experiment.run(datamanagers['train'],
                           datamanagers['val'],
                           save_path_exp=experiment.save_path,
                           ch=ch,
                           metric_keys={'val_CE': ['pred', 'label']},
                           val_freq=1,
                           verbose=True)
    ################
    #   Testing    #
    ################
    if test and datamanagers['test'] is not None:
        softmax_fn = metric_wrapper_pytorch(
            partial(torch.nn.functional.softmax, dim=1))

        def _save_roc(split, file_name):
            """Predict ``split`` with ``model`` and save its ROC curve."""
            # metrics and metric_keys are used differently than in original
            # Delira implementation in order to support Evaluator
            # see mscl.training.predictor
            preds = experiment.test(
                network=model,
                test_data=datamanagers[split],
                metrics=test_metrics,
                metric_keys={'CE': ['pred', 'label']},
                verbose=True,
            )

            probs = softmax_fn(preds[0]['pred'])
            labels = [d['label'] for d in datasets[split]]
            fpr, tpr, _ = roc_curve(labels, probs[:, 1])
            roc_auc = auc(fpr, tpr)

            plt.plot(fpr, tpr, label='ROC (AUC = %0.2f)' % roc_auc)
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiver operating characteristic example')
            plt.legend(loc="lower right")
            plt.savefig(os.path.join(experiment.save_path, file_name))
            plt.close()

        _save_roc('test', 'test_roc.pdf')
        _save_roc('val', 'best_val_roc.pdf')

    return experiment.save_path
Beispiel #19
0
def get_arteries_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True,
                              regions=None,
                              use_nondetMultiThreadedAugmenter: bool = False):
    """Build training and validation augmenters with a reduced augmentation pipeline.

    Training transforms, in order:
    - optional data/seg channel selection;
    - a SpatialTransform with elastic deformation and rotation switched off,
      so only scaling stays configurable through ``params``;
    - optional mirroring;
    - label -1 -> 0 remapping;
    - 'seg' -> 'target' renaming;
    - optional region conversion;
    - conversion of 'data' and 'target' to float tensors.

    Validation applies only the deterministic steps: label remapping, channel
    selection, renaming, regions and tensor conversion.

    Args:
        dataloader_train: data loader feeding the training augmenter.
        dataloader_val: data loader feeding the validation augmenter.
        patch_size: output patch size handed to SpatialTransform.
        params: augmentation config dict. It must use ``do_mirror``, not the
            legacy ``mirror`` key.
        border_val_seg: constant fill value for segmentation outside the image.
        seeds_train: per-worker seeds forwarded to the training
            MultiThreadedAugmenter.
        seeds_val: per-worker seeds forwarded to the validation
            MultiThreadedAugmenter.
        order_seg: interpolation order for the segmentation.
        order_data: interpolation order for the data.
        deep_supervision_scales: accepted for signature compatibility but
            currently ignored; deep-supervision downsampling is disabled here.
        soft_ds: accepted but currently ignored (see above).
        classes: accepted but currently ignored (see above).
        pin_memory: forwarded to both MultiThreadedAugmenters.
        regions: if given, segmentations are converted to region targets.
        use_nondetMultiThreadedAugmenter: accepted but currently ignored; a
            MultiThreadedAugmenter is always used.

    Returns:
        tuple: ``(batchgenerator_train, batchgenerator_val)``, on both of which
        ``restart()`` has already been called.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # NOTE: the dummy-2D path (Convert3DTo2DTransform before the spatial
    # transform) is intentionally disabled in this pipeline, so data stays in
    # its original dimensionality throughout.

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=False,
                         do_rotation=False,
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=False,
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    # Convert2DTo3DTransform is the inverse of Convert3DTo2DTransform and relies
    # on the original shape recorded by it. Since the 3D->2D step is disabled
    # above, the inverse must not run either; previously it was still appended
    # when ``dummy_2D`` was set and would fail at runtime.

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    # Cascade augmentations (moving the last seg channel into data) are not
    # used by this pipeline.

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    # Deep-supervision target downsampling is disabled here; see the
    # ``deep_supervision_scales`` note in the docstring.

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)
    batchgenerator_train.restart()

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training workers, but always at least one.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val
def create_data_gen_train(patient_data_train,
                          BATCH_SIZE,
                          num_classes,
                          patch_size,
                          num_workers=5,
                          num_cached_per_worker=2,
                          do_elastic_transform=False,
                          alpha=(0., 1300.),
                          sigma=(10., 13.),
                          do_rotation=False,
                          a_x=(0., 2 * np.pi),
                          a_y=(0., 2 * np.pi),
                          a_z=(0., 2 * np.pi),
                          do_scale=True,
                          scale_range=(0.75, 1.25),
                          seeds=None):
    """Create and start a multi-threaded augmenting generator for training.

    The pipeline applies, in order:
    - motion augmentation and mirroring on axes (3, 4);
    - a 3D->2D conversion;
    - with probability 0.67 a SpatialTransform, otherwise a random crop, both
      to ``patch_size[1:]``;
    - conversion back to 3D;
    - two randomly applied gamma transforms;
    - outlier clipping and zero-mean/unit-variance normalisation;
    - one-hot encoding of the segmentation into ``'seg_onehot'``.

    Args:
        seeds: ``None`` for unseeded workers, the string ``'range'`` for
            seeds ``0 .. num_workers - 1``, or a sequence with exactly one
            seed per worker.
        Remaining arguments are forwarded to the batch generator and the
        SpatialTransform as shown below.

    Returns:
        The MultiThreadedAugmenter, with ``restart()`` already called.
    """
    # Resolve the per-worker seeds before building anything else.
    if seeds is None:
        worker_seeds = [None] * num_workers
    elif seeds == 'range':
        worker_seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
        worker_seeds = seeds

    # NOTE(review): the loader's patch size is hard-coded and independent of
    # ``patch_size``; the transforms below crop down to ``patch_size[1:]``.
    batch_source = BatchGenerator(patient_data_train,
                                  BATCH_SIZE,
                                  num_batches=None,
                                  seed=False,
                                  PATCH_SIZE=(10, 352, 352))

    in_plane_size = patch_size[1:]
    pipeline = Compose([
        MotionAugmentationTransform(0.1, 0, 20),
        MirrorTransform((3, 4)),
        Convert3DTo2DTransform(),
        # Either the full spatial augmentation or, as the fallback, a plain
        # random crop to the same in-plane size.
        RndTransform(SpatialTransform(in_plane_size, 112,
                                      do_elastic_transform, alpha, sigma,
                                      do_rotation, a_x, a_y, a_z,
                                      do_scale, scale_range,
                                      'constant', 0, 3, 'constant', 0, 0,
                                      random_crop=False),
                     prob=0.67,
                     alternative_transform=RandomCropTransform(in_plane_size)),
        Convert2DTo3DTransform(patch_size),
        RndTransform(GammaTransform((0.85, 1.3), False), prob=0.5),
        RndTransform(GammaTransform((0.85, 1.3), True), prob=0.5),
        CutOffOutliersTransform(0.3, 99.7, True),
        ZeroMeanUnitVarianceTransform(True),
        ConvertSegToOnehotTransform(range(num_classes), 0, 'seg_onehot'),
    ])

    augmenter = MultiThreadedAugmenter(batch_source, pipeline,
                                       num_workers, num_cached_per_worker,
                                       worker_seeds)
    augmenter.restart()
    return augmenter
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True,
                              regions=None):
    """Build training and validation augmenters with the heavy ("insane") pipeline.

    Training transforms, in order:
    - channel selection;
    - optional dummy-2D conversion;
    - SpatialTransform (elastic deformation, rotation and scaling as
      configured in ``params``);
    - Gaussian noise and blur;
    - brightness and contrast changes;
    - simulated low resolution;
    - inverted and regular gamma;
    - additive brightness;
    - mirroring;
    - masking;
    - label remapping;
    - optional cascade augmentations, region conversion and deep-supervision
      downsampling;
    - conversion to tensors.

    Validation applies only the deterministic steps.

    Args:
        dataloader_train: data loader feeding the training augmenter.
        dataloader_val: data loader feeding the validation augmenter.
        patch_size: output patch size handed to SpatialTransform.
        params: augmentation config dict. It must use ``do_mirror``, not the
            legacy ``mirror`` key.
        border_val_seg: constant fill value for segmentation outside the image.
        seeds_train: per-worker seeds forwarded to the training
            MultiThreadedAugmenter.
        seeds_val: per-worker seeds forwarded to the validation
            MultiThreadedAugmenter.
        order_seg: interpolation order for the segmentation.
        order_data: interpolation order for the data.
        deep_supervision_scales: if not None, 'target' is downsampled to these
            scales for deep supervision.
        soft_ds: use soft (one-hot) downsampling. Requires ``classes``.
        classes: class list for soft downsampling.
        pin_memory: forwarded to both MultiThreadedAugmenters.
        regions: if given, segmentations are converted to region targets.

    Returns:
        tuple: ``(batchgenerator_train, batchgenerator_val)``. Unlike the other
        builders in this file, ``restart()`` is not called here.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    # Example call arguments and params (reference only; not enforced):
    # 'patch_size': array([288, 320]),
    # 'border_val_seg': -1,
    # 'seeds_train': None,
    # 'seeds_val': None,
    # 'order_seg': 1,
    # 'order_data': 3,
    # 'deep_supervision_scales': [[1, 1, 1],
    #                             [1.0, 0.5, 0.5],
    #                             [1.0, 0.25, 0.25],
    #                             [0.5, 0.125, 0.125],
    #                             [0.5, 0.0625, 0.0625]],
    # 'soft_ds': False,
    # 'classes': None,
    # 'pin_memory': True,
    # 'regions': None
    # params
    # {'selected_data_channels': None,
    #  'selected_seg_channels': [0],
    #  'do_elastic': True,
    #  'elastic_deform_alpha': (0.0, 300.0),
    #  'elastic_deform_sigma': (9.0, 15.0),
    #  'p_eldef': 0.1,
    #  'do_scaling': True,
    #  'scale_range': (0.65, 1.6),
    #  'independent_scale_factor_for_each_axis': True,
    #  'p_independent_scale_per_axis': 0.3,
    #  'p_scale': 0.3,
    #  'do_rotation': True,
    #  'rotation_x': (-3.141592653589793, 3.141592653589793),
    #  'rotation_y': (-0.5235987755982988, 0.5235987755982988),
    #  'rotation_z': (-0.5235987755982988, 0.5235987755982988),
    #  'rotation_p_per_axis': 1,
    #  'p_rot': 0.7,
    #  'random_crop': False,
    #  'random_crop_dist_to_border': None,
    #  'do_gamma': True,
    #  'gamma_retain_stats': True,
    #  'gamma_range': (0.5, 1.6),
    #  'p_gamma': 0.3,
    #  'do_mirror': True,
    #  'mirror_axes': (0, 1, 2),
    #  'dummy_2D': True,
    #  'mask_was_used_for_normalization': OrderedDict([(0, False)]),
    #  'border_mode_data': 'constant',
    #  'all_segmentation_labels': None,
    #  'move_last_seg_chanel_to_data': False,
    #  'cascade_do_cascade_augmentations': False,
    #  'cascade_random_binary_transform_p': 0.4,
    #  'cascade_random_binary_transform_p_per_label': 1,
    #  'cascade_random_binary_transform_size': (1, 8),
    #  'cascade_remove_conn_comp_p': 0.2,
    #  'cascade_remove_conn_comp_max_size_percent_threshold': 0.15,
    #  'cascade_remove_conn_comp_fill_with_other_class_p': 0.0,
    #  'do_additive_brightness': True,
    #  'additive_brightness_p_per_sample': 0.3,
    #  'additive_brightness_p_per_channel': 1,
    #  'additive_brightness_mu': 0,
    #  'additive_brightness_sigma': 0.2,
    #  'num_threads': 12,
    #  'num_cached_per_thread': 1,
    #  'patch_size_for_spatialtransform': array([288, 320])}

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Don't do color augmentations while in 2d mode with 3d data, because the
    # color channel is overloaded. ``ignore_axes`` keeps the low-resolution
    # simulation below off the folded axis.
    if params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis"),
                         p_independent_scale_per_axis=params.get(
                             "p_independent_scale_per_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # Color augmentations go after the dummy-2D part (if applicable);
    # otherwise the overloaded color channel gets in the way.
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # Previously written as ``x and not None and x``; ``not None`` is
        # always True, so this is a plain truthiness check.
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # Each keyword now reads its matching params key. The two
                # values used to be swapped, which turned the transform into
                # a no-op (dont_do_if_covers_more_than_X_percent=0.0 with
                # the listed defaults).
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # ========================================================
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training workers, but always at least one.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
Beispiel #22
0
def get_default_augmentation(dataloader_train, dataloader_val=None, params=None,
                             patch_size=None, border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the default training augmenter and, optionally, a validation augmenter.

    Training transforms, in order:
    - channel selection;
    - optional dummy-2D conversion around a SpatialTransform configured from
      ``params``;
    - optional gamma;
    - optional mirroring;
    - label -1 -> 0 remapping;
    - 'seg' -> 'target' renaming;
    - optional region conversion;
    - tensor conversion.

    Validation applies only the deterministic steps.

    Args:
        dataloader_train: data loader feeding the training augmenter.
        dataloader_val: data loader for validation. If None, no validation
            augmenter is built.
        params: augmentation config dict (required). It must use
            ``do_mirror``, not the legacy ``mirror`` key.
        patch_size: output patch size handed to SpatialTransform.
        border_val_seg: constant fill value for segmentation outside the image.
        pin_memory: forwarded to the MultiThreadedAugmenters.
        seeds_train: per-worker seeds for the training augmenter.
        seeds_val: per-worker seeds for the validation augmenter.
        regions: if given, segmentations are converted to region targets.

    Returns:
        tuple: ``(batchgenerator_train, batchgenerator_val)``. The second item
        is None when ``dataloader_val`` is None. Every returned augmenter has
        already had ``restart()`` called.

    Raises:
        AssertionError: if ``params`` is None or still uses the legacy
            ``mirror`` key.
    """
    # Check params for None before dereferencing it. The previous order called
    # params.get(...) first and crashed with AttributeError instead of
    # reporting this message.
    assert params is not None, "augmentation params expect to be not None"
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    # Undo the 3D->2D folding applied above, under the same condition.
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)
    batchgenerator_train.restart()

    if dataloader_val is None:
        return batchgenerator_train, None

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training workers, but always at least one.
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val