예제 #1
0
def combine_transform():
    """
    Combined transform demo: contrast augmentation plus mirroring.

    Builds a Compose of ContrastAugmentationTransform and MirrorTransform,
    wraps the skimage camera() demo image in a DataLoader, and plots one
    augmented batch produced by a MultiThreadedAugmenter.

    :return: None (displays the batch via my_data_loader.plot_batch)
    """
    my_transforms = []
    # Contrast augmentation; preserve_range keeps intensities inside the
    # original value range after the contrast scaling.
    brightness_transform = ContrastAugmentationTransform((0.3, 3.),
                                                         preserve_range=True)
    my_transforms.append(brightness_transform)

    # Mirror along the spatial axes (2, 3) of the (b, c, x, y) batch layout.
    mirror_transform = MirrorTransform(axes=(2, 3))
    my_transforms.append(mirror_transform)

    all_transform = Compose(my_transforms)

    batchgen = my_data_loader.DataLoader(camera(), 4)
    multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transform,
                                                     4, 2)

    # Show the effect of the transforms on one batch.
    # next(...) replaces the non-idiomatic direct __next__() call.
    my_data_loader.plot_batch(next(multithreaded_generator))
예제 #2
0
    def get_batches(self, batch_size=1):
        """
        Create a batch generator over this subject's slices, wrapped in the
        configured normalization/augmentation transforms.

        :param batch_size: number of slices per generated batch
        :return: MultiThreadedAugmenter yielding dicts with
                 data: (batch_size, channels, x, y) and
                 seg: (batch_size, x, y, channels)
        """
        # Do not use more than 1 if you want to keep the original slice order
        # (threads do return batches in random order).
        num_processes = 1

        if self.HP.TYPE == "combined":
            # Load from Npy file for Fusion
            data = self.subject
            seg = []
            nr_of_samples = len([self.subject]) * self.HP.INPUT_DIM[0]
            num_batches = int(nr_of_samples / batch_size / num_processes)
            batch_gen = SlicesBatchGeneratorNpyImg_fusion(
                (data, seg),
                BATCH_SIZE=batch_size,
                num_batches=num_batches,
                seed=None)
        else:
            # Load Features
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                data_img = nib.load(
                    join(self.data_dir, "270g_125mm_peaks.nii.gz"))
            else:
                data_img = nib.load(
                    join(self.data_dir, self.HP.FEATURES_FILENAME + ".nii.gz"))
            data = data_img.get_data()
            # Replace NaNs so downstream transforms do not propagate them.
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, self.HP.DATASET, self.HP.RESOLUTION)
            # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")  #If we want to test HCP_32g on HighRes net

            # Load Segmentation
            if self.use_gt_mask:
                seg = nib.load(
                    join(self.data_dir,
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()

                # The *_808080 label files are already at target size and
                # must not be rescaled.
                if self.HP.LABELS_FILENAME not in [
                        "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                        "bundle_peaks_808080", "bundle_masks_20_808080",
                        "bundle_masks_72_808080", "bundle_peaks_Part1_808080",
                        "bundle_peaks_Part2_808080",
                        "bundle_peaks_Part3_808080",
                        "bundle_peaks_Part4_808080"
                ]:
                    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                        # By using "HCP" but lower resolution scale_input_to_unet_shape
                        # will automatically downsample the HCP sized seg_mask
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, "HCP", self.HP.RESOLUTION)
                    else:
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, self.HP.DATASET, self.HP.RESOLUTION)
            else:
                # Use dummy mask in case we only want to predict on some data
                # (where we do not have Ground Truth).
                seg = np.zeros(
                    (self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
                     self.HP.INPUT_DIM[0],
                     self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

            batch_gen = SlicesBatchGenerator((data, seg),
                                             batch_size=batch_size)

        batch_gen.HP = self.HP
        tfs = []  # transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

        # Optional test-time data augmentation: one spatial transform plus
        # two intensity transforms.
        if self.HP.TEST_TIME_DAUG:
            center_dist_from_border = int(
                self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(
                SpatialTransform(
                    self.HP.INPUT_DIM,
                    patch_center_dist_from_border=center_dist_from_border,
                    do_elastic_deform=True,
                    alpha=(90., 120.),
                    sigma=(9., 11.),
                    do_rotation=True,
                    angle_x=(-0.8, 0.8),
                    angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True,
                    scale=(0.9, 1.5),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True))
            # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
            # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
            tfs.append(
                ContrastAugmentationTransform(contrast_range=(0.7, 1.3),
                                              preserve_range=True,
                                              per_channel=False))
            tfs.append(
                BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                                  per_channel=False))

        tfs.append(ReorderSegTransform())
        batch_gen = MultiThreadedAugmenter(
            batch_gen,
            Compose(tfs),
            num_processes=num_processes,
            num_cached_per_queue=2,
            seeds=None
        )  # Only use num_processes=1, otherwise global_idx of SlicesBatchGenerator not working
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
예제 #3
0
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True):
    """
    Build an aggressive ("insane") augmentation pipeline for training and a
    minimal bookkeeping pipeline for validation.

    :param dataloader_train: batch generator producing training batches
    :param dataloader_val: batch generator producing validation batches
    :param patch_size: spatial patch size passed to SpatialTransform
    :param params: dict of augmentation hyper-parameters (new-style keys;
        the legacy 'mirror' key is rejected, use 'do_mirror')
    :param border_val_seg: fill value for seg voxels created by padding
    :param seeds_train: per-process seeds for the training augmenter
    :param seeds_val: per-process seeds for the validation augmenter
    :param order_seg: interpolation order for segmentations
    :param order_data: interpolation order for image data
    :param deep_supervision_scales: if not None, downsampled copies of
        'target' are produced for deep supervision
    :param soft_ds: use soft (per-class) downsampling; requires `classes`
    :param classes: class labels, required when soft_ds is True
    :param pin_memory: forwarded to MultiThreadedAugmenter
    :return: tuple (batchgenerator_train, batchgenerator_val)
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if
    # applicable). Otherwise the overloaded color channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    # Cascade: move the previous stage's segmentation into the data as
    # one-hot channels and optionally corrupt it for robustness.
    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # Bugfix: the original condition read
        # `params.get(...) and not None and params.get(...)`; `not None` is
        # always True, so the intended `is not None` check (matching the
        # pattern used above) is spelled out here.
        if params.get("cascade_do_cascade_augmentations"
                      ) is not None and params.get(
                          "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        # NOTE(review): these two param keys look swapped
                        # relative to the keyword names; this mirrors upstream
                        # nnU-Net code, so it is preserved — confirm before
                        # changing.
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")
                    ))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # Validation: no spatial/intensity augmentation, only the bookkeeping
    # transforms needed to match the training output format.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the workers of training (at least one).
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
예제 #4
0
# Demo loader over the single skimage camera() test image (batch size 1).
batchgen = DataLoader(data.camera(), 1, None, False)
#batch = next(batchgen)

#print(batch['data'].shape)
def plot_batch(batch):
    """Show every sample of a batch side by side as grayscale images."""
    samples = batch['data']
    n = samples.shape[0]
    for idx in range(n):
        plt.subplot(1, n, idx + 1)
        plt.imshow(samples[idx, 0], cmap="gray")
    plt.show()
#plot_batch(batch)

# Demo pipeline: contrast + Gaussian noise + spatial augmentation, applied
# to the camera() image via a multithreaded augmenter.
my_transforms = []

# Contrast augmentation; preserve_range keeps intensities in the original range.
brightness_transform = ContrastAugmentationTransform((0.3, 3.), preserve_range=True)
my_transforms.append(brightness_transform)

# Additive Gaussian noise, variance drawn from (0, 20).
noise_transform = GaussianNoiseTransform(noise_variance=(0, 20))
my_transforms.append(noise_transform)

# Elastic deformation, full-circle z rotation and scaling, centered on the image.
spatial_transform = SpatialTransform_2(data.camera().shape, np.array(data.camera().shape)//2,
                                     do_elastic_deform=True, deformation_scale=(0,0.05),
                                     do_rotation=True, angle_z=(0, 2*np.pi),
                                     do_scale=True, scale=(0.8, 1.2),
                                     border_mode_data='constant', border_cval_data=0, order_data=1,
                                     random_crop=False)
my_transforms.append(spatial_transform)
all_transforms = Compose(my_transforms)
# 4 worker processes, 2 cached batches per worker.
multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transforms, 4, 2, seeds=None)
plot_batch(next(multithreaded_generator))
예제 #5
0
    def get_train_transforms(self) -> List[AbstractTransform]:
        """
        Assemble the training-time augmentation pipeline.

        :return: list of AbstractTransform; the final entry converts 'data'
            and 'target' to float torch tensors.
        """
        # used for transpose and rot90: for each patch axis, count how many
        # axes share its size — only equally-sized axes can be swapped or
        # rotated without changing the patch shape
        matching_axes = np.array(
            [sum([i == j for j in self.patch_size]) for i in self.patch_size])
        valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])

        tr_transforms = []

        if self.data_aug_params['selected_seg_channels'] is not None:
            tr_transforms.append(
                SegChannelSelectionTransform(
                    self.data_aug_params['selected_seg_channels']))

        # dummy 2D aug: augment anisotropic 3D data slice-wise
        if self.do_dummy_2D_aug:
            ignore_axes = (0, )
            tr_transforms.append(Convert3DTo2DTransform())
            patch_size_spatial = self.patch_size[1:]
        else:
            patch_size_spatial = self.patch_size
            ignore_axes = None

        tr_transforms.append(
            SpatialTransform(
                patch_size_spatial,
                patch_center_dist_from_border=None,
                do_elastic_deform=False,
                do_rotation=True,
                angle_x=self.data_aug_params["rotation_x"],
                angle_y=self.data_aug_params["rotation_y"],
                angle_z=self.data_aug_params["rotation_z"],
                p_rot_per_axis=0.5,
                do_scale=True,
                scale=self.data_aug_params['scale_range'],
                border_mode_data="constant",
                border_cval_data=0,
                order_data=3,
                border_mode_seg="constant",
                border_cval_seg=-1,
                order_seg=1,
                random_crop=False,
                p_el_per_sample=0.2,
                p_scale_per_sample=0.2,
                p_rot_per_sample=0.4,
                independent_scale_for_each_axis=True,
            ))

        if self.do_dummy_2D_aug:
            tr_transforms.append(Convert2DTo3DTransform())

        if np.any(matching_axes > 1):
            tr_transforms.append(
                Rot90Transform((0, 1, 2, 3),
                               axes=valid_axes,
                               data_key='data',
                               label_key='seg',
                               p_per_sample=0.5), )

        if np.any(matching_axes > 1):
            tr_transforms.append(
                TransposeAxesTransform(valid_axes,
                                       data_key='data',
                                       label_key='seg',
                                       p_per_sample=0.5))

        # blur with either a median or a gaussian filter, never both at once
        tr_transforms.append(
            OneOfTransform([
                MedianFilterTransform((2, 8),
                                      same_for_each_channel=False,
                                      p_per_sample=0.2,
                                      p_per_channel=0.5),
                GaussianBlurTransform((0.3, 1.5),
                                      different_sigma_per_channel=True,
                                      p_per_sample=0.2,
                                      p_per_channel=0.5)
            ]))

        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))

        tr_transforms.append(
            BrightnessTransform(0,
                                0.5,
                                per_channel=True,
                                p_per_sample=0.1,
                                p_per_channel=0.5))

        # contrast augmentation either with or without range preservation
        tr_transforms.append(
            OneOfTransform([
                ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                              preserve_range=True,
                                              per_channel=True,
                                              data_key='data',
                                              p_per_sample=0.2,
                                              p_per_channel=0.5),
                ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                              preserve_range=False,
                                              per_channel=True,
                                              data_key='data',
                                              p_per_sample=0.2,
                                              p_per_channel=0.5),
            ]))

        tr_transforms.append(
            SimulateLowResolutionTransform(zoom_range=(0.25, 1),
                                           per_channel=True,
                                           p_per_channel=0.5,
                                           order_downsample=0,
                                           order_upsample=3,
                                           p_per_sample=0.15,
                                           ignore_axes=ignore_axes))

        tr_transforms.append(
            GammaTransform((0.7, 1.5),
                           invert_image=True,
                           per_channel=True,
                           retain_stats=True,
                           p_per_sample=0.1))
        # NOTE(review): identical to the previous GammaTransform (both use
        # invert_image=True); comparable trainers use invert_image=False for
        # the second one — confirm whether the duplication is intended.
        tr_transforms.append(
            GammaTransform((0.7, 1.5),
                           invert_image=True,
                           per_channel=True,
                           retain_stats=True,
                           p_per_sample=0.1))

        if self.do_mirroring:
            tr_transforms.append(MirrorTransform(self.mirror_axes))

        # blank out 1-5 random rectangles per sample, filled with the mean
        tr_transforms.append(
            BlankRectangleTransform([[max(1, p // 10), p // 3]
                                     for p in self.patch_size],
                                    rectangle_value=np.mean,
                                    num_rectangles=(1, 5),
                                    force_square=False,
                                    p_per_sample=0.4,
                                    p_per_channel=0.5))

        # NOTE(review): x[y] // 6 is 0 for patch dims < 6, which makes np.log
        # diverge — presumably all patch dimensions are >= 6; confirm.
        tr_transforms.append(
            BrightnessGradientAdditiveTransform(
                lambda x, y: np.exp(
                    np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
                (-0.5, 1.5),
                max_strength=lambda x, y: np.random.uniform(-5, -1)
                if np.random.uniform() < 0.5 else np.random.uniform(1, 5),
                mean_centered=False,
                same_for_all_channels=False,
                p_per_sample=0.3,
                p_per_channel=0.5))

        tr_transforms.append(
            LocalGammaTransform(
                lambda x, y: np.exp(
                    np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
                (-0.5, 1.5),
                lambda: np.random.uniform(0.01, 0.8)
                if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4),
                same_for_all_channels=False,
                p_per_sample=0.3,
                p_per_channel=0.5))

        tr_transforms.append(
            SharpeningTransform(strength=(0.1, 1),
                                same_for_each_channel=False,
                                p_per_sample=0.2,
                                p_per_channel=0.5))

        if any(self.use_mask_for_norm.values()):
            tr_transforms.append(
                MaskTransform(self.use_mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        tr_transforms.append(RemoveLabelTransform(-1, 0))

        # cascade: feed the previous stage's segmentation into the data as
        # one-hot channels and randomly corrupt it
        if self.data_aug_params["move_last_seg_chanel_to_data"]:
            all_class_labels = np.arange(1, self.num_classes)
            tr_transforms.append(
                MoveSegAsOneHotToData(1, all_class_labels, 'seg', 'data'))
            if self.data_aug_params["cascade_do_cascade_augmentations"]:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(channel_idx=list(
                        range(-len(all_class_labels), 0)),
                                                       p_per_sample=0.4,
                                                       key="data",
                                                       strel_size=(1, 8),
                                                       p_per_label=1))

                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(range(-len(all_class_labels), 0)),
                        key="data",
                        p_per_sample=0.2,
                        fill_with_other_class_p=0.15,
                        dont_do_if_covers_more_than_X_percent=0))

        tr_transforms.append(RenameTransform('seg', 'target', True))

        if self.regions is not None:
            tr_transforms.append(
                ConvertSegmentationToRegionsTransform(self.regions, 'target',
                                                      'target'))

        if self.deep_supervision_scales is not None:
            tr_transforms.append(
                DownsampleSegForDSTransform2(self.deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return tr_transforms
             up_kwargs={
                 'attention': True
             },
             encode_block=ResBlockStack,
             encode_kwargs_fn=encode_kwargs_fn,
             decode_block=ResBlock).cuda()

# Input patch size (x, y, z) used for training.
patch_size = (160, 160, 80)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Multiply LR by 0.2 when the monitored metric plateaus for 30 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 factor=0.2,
                                                 patience=30)

tr_transform = Compose([
    GammaTransform((0.9, 1.1)),
    ContrastAugmentationTransform((0.9, 1.1)),
    BrightnessMultiplicativeTransform((0.9, 1.1)),
    MirrorTransform(axes=[0]),
    SpatialTransform_2(
        patch_size,
        (90, 90, 50),
        random_crop=True,
        do_elastic_deform=True,
        deformation_scale=(0, 0.05),
        do_rotation=True,
        angle_x=(-0.1 * np.pi, 0.1 * np.pi),
        angle_y=(0, 0),
        angle_z=(0, 0),
        do_scale=True,
        scale=(0.9, 1.1),
        border_mode_data='constant',
예제 #7
0
# export NEPTUNE_API_TOKEN = '...' !!!  (the token must be set in the environment)
logging.getLogger().setLevel('INFO')
source_files = [__file__]
if hparams.config:
    source_files.append(hparams.config)
neptune_logger = NeptuneLogger(project_name=hparams.neptune_project,
                               params=vars(hparams),
                               experiment_name=hparams.experiment_name,
                               tags=[hparams.experiment_name],
                               upload_source_files=source_files)
tb_logger = loggers.TensorBoardLogger(hparams.log_dir)

# Intensity-only augmentation pipeline (no spatial transforms).
transform = Compose([
    BrightnessTransform(mu=0.0, sigma=0.3, data_key='data'),
    GammaTransform(gamma_range=(0.7, 1.3), data_key='data'),
    ContrastAugmentationTransform(contrast_range=(0.3, 1.7), data_key='data')
])

# Subject keys: one id per line in each key file.
with open(hparams.train_set, 'r') as keyfile:
    train_keys = [l.strip() for l in keyfile.readlines()]
print(train_keys)

with open(hparams.val_set, 'r') as keyfile:
    val_keys = [l.strip() for l in keyfile.readlines()]
print(val_keys)

train_ds = MedDataset(hparams.data_path,
                      train_keys,
                      hparams.patches_per_subject,
                      hparams.patch_size,
                      image_group=hparams.image_group,
예제 #8
0
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True,
                              regions=None):
    """Wrap train/val dataloaders in multithreaded augmentation pipelines.

    The training pipeline applies the full "insane" augmentation stack
    (spatial transform, optional dummy-2D round trip, Gaussian noise/blur,
    brightness, contrast, simulated low resolution, gamma, mirroring, optional
    cascade augmentations); the validation pipeline only performs the label
    and format conversions the network needs (no stochastic augmentation).

    Args:
        dataloader_train: batchgenerators-style loader yielding dicts with
            'data' and 'seg' entries.
        dataloader_val: same, for validation.
        patch_size: target spatial patch size for SpatialTransform.
        params: dict of augmentation hyperparameters; see
            default_3D_augmentation_params for the expected keys.
        border_val_seg: fill value for seg borders created by the spatial
            transform (mapped back to 0 later by RemoveLabelTransform).
        seeds_train: per-worker seeds for the training augmenter (or None).
        seeds_val: per-worker seeds for the validation augmenter (or None).
        order_seg: interpolation order for segmentations.
        order_data: interpolation order for image data.
        deep_supervision_scales: if not None, downsampled copies of 'target'
            are produced for deep supervision.
        soft_ds: use soft (one-hot interpolated) downsampling; requires
            `classes`.
        classes: list of label values, required when soft_ds is True.
        pin_memory: forwarded to MultiThreadedAugmenter.
        regions: optional region definitions for region-based training.

    Returns:
        (batchgenerator_train, batchgenerator_val): two
        MultiThreadedAugmenter instances yielding dicts with torch tensors
        under 'data' and 'target'.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    # ------------------------- training pipeline -------------------------
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded (it temporarily holds a spatial axis)!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis"),
                         p_independent_scale_per_axis=params.get(
                             "p_independent_scale_per_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # We need to put the color augmentations after the dummy 2d part (if
    # applicable). Otherwise the overloaded color channel gets in the way.
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    # Re-apply the normalization mask so padded regions stay zeroed after
    # the spatial transforms.
    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    # Cascade training: previous-stage segmentation travels as one-hot data
    # channels and is randomly corrupted to prevent the network from simply
    # copying it.
    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # BUGFIX: this previously read `... and not None and ...`, in which
        # `not None` is a constant True; the intended guard is an explicit
        # None check followed by the truthiness check.
        if params.get("cascade_do_cascade_augmentations"
                      ) is not None and params.get(
                          "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        # BUGFIX: these two keyword arguments previously
                        # received each other's parameter — the
                        # "...fill_with_other_class_p" value belongs to
                        # fill_with_other_class_p and the
                        # "...max_size_percent_threshold" value to
                        # dont_do_if_covers_more_than_X_percent.
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # ------------------------ validation pipeline ------------------------
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count (at least one).
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
# ----- Example 9 -----
def get_moreDA_augmentation(
    dataloader_train,
    dataloader_val,
    patch_size,
    params=default_3D_augmentation_params,
    border_val_seg=-1,
    seeds_train=None,
    seeds_val=None,
    order_seg=1,
    order_data=3,
    deep_supervision_scales=None,
    soft_ds=False,
    classes=None,
    pin_memory=True,
    anisotropy=False,
    extra_label_keys=None,
    val_mode=False,
    use_conf=False,
):
    '''
    Wrap the given dataloaders in augmentation pipelines and return them as
    (train_loader, val_loader).

    When val_mode is False the training loader applies the full augmentation
    stack (spatial transform, optional dummy-2D round trip for anisotropic
    data, noise/blur, color transforms, mirroring) in a MultiThreadedAugmenter
    while the validation loader only converts formats in a
    SingleThreadedAugmenter. When val_mode is True, both loaders use the
    conversion-only pipeline.

    Each yielded batch dict contains torch tensors under 'image' and 'gt'
    (the raw 'data'/'seg' keys are renamed), plus any keys listed in
    extra_label_keys.

    NOTE(review): the extra_label_keys kwarg passed to the transforms appears
    to require a patched batchgenerators fork — confirm before reuse.

    :param anisotropy: force the dummy-2D pipeline (2D spatial transform on
        3D data) in addition to params["dummy_2D"].
    :param use_conf: if True, run the training augmenter with a single worker
        thread and a cache of 1.
    :param val_mode: if True, skip training augmentation entirely.
    :return: train_loader, val_loader
    '''

    if not val_mode:
        assert params.get(
            'mirror'
        ) is None, "old version of params, use new keyword do_mirror"

        tr_transforms = []
        if params.get("selected_data_channels") is not None:
            tr_transforms.append(
                DataChannelSelectionTransform(
                    params.get("selected_data_channels")))
        if params.get("selected_seg_channels") is not None:
            tr_transforms.append(
                SegChannelSelectionTransform(
                    params.get("selected_seg_channels")))

        # anistropic setting: treat the volume as a stack of 2D slices for
        # the spatial transform (in-plane rotations only) and mutate the
        # elastic/rotation params accordingly.
        if anisotropy or params.get("dummy_2D"):
            ignore_axes = (0, )
            tr_transforms.append(
                Convert3DTo2DTransform(extra_label_keys=extra_label_keys))
            patch_size = patch_size[1:]  # 2D patch size

            print('Using dummy2d data augmentation')
            params["elastic_deform_alpha"] = (0., 200.)
            params["elastic_deform_sigma"] = (9., 13.)
            # full in-plane rotation; no out-of-plane rotation
            params["rotation_x"] = (-180. / 360 * 2. * np.pi,
                                    180. / 360 * 2. * np.pi)
            params["rotation_y"] = (-0. / 360 * 2. * np.pi,
                                    0. / 360 * 2. * np.pi)
            params["rotation_z"] = (-0. / 360 * 2. * np.pi,
                                    0. / 360 * 2. * np.pi)

        else:
            ignore_axes = None

        # 1. Spatial Transform: rotation, scaling
        tr_transforms.append(
            SpatialTransform(patch_size,
                             patch_center_dist_from_border=None,
                             do_elastic_deform=params.get("do_elastic"),
                             alpha=params.get("elastic_deform_alpha"),
                             sigma=params.get("elastic_deform_sigma"),
                             do_rotation=params.get("do_rotation"),
                             angle_x=params.get("rotation_x"),
                             angle_y=params.get("rotation_y"),
                             angle_z=params.get("rotation_z"),
                             p_rot_per_axis=params.get("rotation_p_per_axis"),
                             do_scale=params.get("do_scaling"),
                             scale=params.get("scale_range"),
                             border_mode_data=params.get("border_mode_data"),
                             border_cval_data=0,
                             order_data=order_data,
                             border_mode_seg="constant",
                             border_cval_seg=border_val_seg,
                             order_seg=order_seg,
                             random_crop=params.get("random_crop"),
                             p_el_per_sample=params.get("p_eldef"),
                             p_scale_per_sample=params.get("p_scale"),
                             p_rot_per_sample=params.get("p_rot"),
                             independent_scale_for_each_axis=params.get(
                                 "independent_scale_factor_for_each_axis"),
                             extra_label_keys=extra_label_keys))

        # restore the 3D layout before any intensity augmentation
        if anisotropy or params.get("dummy_2D"):
            tr_transforms.append(
                Convert2DTo3DTransform(extra_label_keys=extra_label_keys))

        # 2. Noise Augmentation: gaussian noise, gaussian blur
        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
        tr_transforms.append(
            GaussianBlurTransform((0.5, 1.),
                                  different_sigma_per_channel=True,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5))

        # 3. Color Augmentation: brightness, constrast, low resolution, gamma_transform
        tr_transforms.append(
            BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                              p_per_sample=0.15))
        if params.get("do_additive_brightness"):
            tr_transforms.append(
                BrightnessTransform(params.get("additive_brightness_mu"),
                                    params.get("additive_brightness_sigma"),
                                    True,
                                    p_per_sample=params.get(
                                        "additive_brightness_p_per_sample"),
                                    p_per_channel=params.get(
                                        "additive_brightness_p_per_channel")))
        tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
        tr_transforms.append(
            SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                           per_channel=True,
                                           p_per_channel=0.5,
                                           order_downsample=0,
                                           order_upsample=3,
                                           p_per_sample=0.25,
                                           ignore_axes=ignore_axes))
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           True,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=0.1))  # inverted gamma
        if params.get("do_gamma"):
            tr_transforms.append(
                GammaTransform(params.get("gamma_range"),
                               False,
                               True,
                               retain_stats=params.get("gamma_retain_stats"),
                               p_per_sample=params["p_gamma"]))

        # 4. Mirror Transform
        if params.get("do_mirror") or params.get("mirror"):
            tr_transforms.append(
                MirrorTransform(params.get("mirror_axes"),
                                extra_label_keys=extra_label_keys))

        # if params.get("mask_was_used_for_normalization") is not None:
        #     mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        #     tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

        # map the border fill label (-1) back to background (0), then expose
        # the batch under the 'image'/'gt' keys expected downstream
        tr_transforms.append(
            RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys))
        tr_transforms.append(RenameTransform('data', 'image', True))
        tr_transforms.append(RenameTransform('seg', 'gt', True))

        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                tr_transforms.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                                 'gt', classes))
            else:
                tr_transforms.append(
                    DownsampleSegForDSTransform2(
                        deep_supervision_scales,
                        0,
                        0,
                        input_key='gt',
                        output_key='gt',
                        extra_label_keys=extra_label_keys))
        toTensorKeys = [
            'image', 'gt'
        ] + extra_label_keys if extra_label_keys is not None else [
            'image', 'gt'
        ]
        tr_transforms.append(NumpyToTensor(toTensorKeys, 'float'))
        tr_transforms = Compose(tr_transforms)

        # one seed per worker thread
        if seeds_train is not None:
            seeds_train = [seeds_train] * params.get('num_threads')
        if use_conf:
            num_threads = 1
            num_cached_per_thread = 1
        else:
            num_threads, num_cached_per_thread = params.get(
                'num_threads'), params.get("num_cached_per_thread")
        batchgenerator_train = MultiThreadedAugmenter(dataloader_train,
                                                      tr_transforms,
                                                      num_threads,
                                                      num_cached_per_thread,
                                                      seeds=seeds_train,
                                                      pin_memory=pin_memory)

        # validation: conversion-only pipeline, no stochastic augmentation
        val_transforms = []
        val_transforms.append(
            RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys))
        if params.get("selected_data_channels") is not None:
            val_transforms.append(
                DataChannelSelectionTransform(
                    params.get("selected_data_channels")))
        if params.get("selected_seg_channels") is not None:
            val_transforms.append(
                SegChannelSelectionTransform(
                    params.get("selected_seg_channels")))
        val_transforms.append(RenameTransform('data', 'image', True))
        val_transforms.append(RenameTransform('seg', 'gt', True))

        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                val_transforms.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                                 'gt', classes))
            else:
                val_transforms.append(
                    DownsampleSegForDSTransform2(
                        deep_supervision_scales,
                        0,
                        0,
                        input_key='gt',
                        output_key='gt',
                        extra_label_keys=extra_label_keys))

        val_transforms.append(NumpyToTensor(toTensorKeys, 'float'))
        val_transforms = Compose(val_transforms)

        # batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
        #                                             params.get("num_cached_per_thread"),
        #                                             seeds=seeds_val, pin_memory=pin_memory)
        if seeds_val is not None:
            seeds_val = [seeds_val] * 1
        # batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, 1,
        #                                             params.get("num_cached_per_thread"),
        #                                             seeds=seeds_val, pin_memory=False)
        # single-threaded to keep slice order deterministic
        batchgenerator_val = SingleThreadedAugmenter(dataloader_val,
                                                     val_transforms)

    else:
        # val_mode: both loaders share the conversion-only pipeline
        val_transforms = []
        val_transforms.append(
            RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys))
        if params.get("selected_data_channels") is not None:
            val_transforms.append(
                DataChannelSelectionTransform(
                    params.get("selected_data_channels")))
        if params.get("selected_seg_channels") is not None:
            val_transforms.append(
                SegChannelSelectionTransform(
                    params.get("selected_seg_channels")))
        val_transforms.append(RenameTransform('data', 'image', True))
        val_transforms.append(RenameTransform('seg', 'gt', True))

        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                val_transforms.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                                 'gt', classes))
            else:
                val_transforms.append(
                    DownsampleSegForDSTransform2(
                        deep_supervision_scales,
                        0,
                        0,
                        input_key='gt',
                        output_key='gt',
                        extra_label_keys=extra_label_keys))

        toTensorKeys = [
            'image', 'gt'
        ] + extra_label_keys if extra_label_keys is not None else [
            'image', 'gt'
        ]
        val_transforms.append(NumpyToTensor(toTensorKeys, 'float'))
        val_transforms = Compose(val_transforms)
        batchgenerator_val = SingleThreadedAugmenter(dataloader_val,
                                                     val_transforms)
        if dataloader_train is not None:
            batchgenerator_train = SingleThreadedAugmenter(
                dataloader_train, val_transforms)
        else:
            batchgenerator_train = None

    return batchgenerator_train, batchgenerator_val
# ----- Example 10 -----
 def run(self, img_data, seg_data):
     """Apply the configured augmentations to a batch of images/segmentations.

     Builds a batchgenerators pipeline from the enabled flags on self
     (mirror, contrast, brightness, gamma, gaussian_noise, rotations,
     scaling, elastic_deform, cropping), runs it `self.cycles` times over
     the data, and concatenates the results.

     Args:
         img_data: channel-last image batch, e.g. (batch, x, y, z, channel).
         seg_data: matching segmentation/class data batch.

     Returns:
         (aug_img_data, aug_seg_data): augmented batches in channel-last
         layout, with `cycles` augmented copies stacked along the batch axis.
     """
     # Define label for segmentation for segmentation augmentation
     if self.seg_augmentation: seg_label = "seg"
     else: seg_label = "class"
     # Create a parser for the batchgenerators module
     data_generator = DataParser(img_data, seg_data, seg_label)
     # Initialize empty transform list
     transforms = []
     # Add mirror augmentation
     if self.mirror:
         aug_mirror = MirrorTransform(axes=self.config_mirror_axes)
         transforms.append(aug_mirror)
     # Add contrast augmentation
     if self.contrast:
         aug_contrast = ContrastAugmentationTransform(
             self.config_contrast_range,
             preserve_range=self.config_contrast_preserverange,
             per_channel=self.coloraug_per_channel,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_contrast)
     # Add brightness augmentation
     if self.brightness:
         aug_brightness = BrightnessMultiplicativeTransform(
             self.config_brightness_range,
             per_channel=self.coloraug_per_channel,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_brightness)
     # Add gamma augmentation
     if self.gamma:
         aug_gamma = GammaTransform(self.config_gamma_range,
                                    invert_image=False,
                                    per_channel=self.coloraug_per_channel,
                                    retain_stats=True,
                                    p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gamma)
     # Add gaussian noise augmentation
     if self.gaussian_noise:
         aug_gaussian_noise = GaussianNoiseTransform(
             self.config_gaussian_noise_range,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gaussian_noise)
     # Add spatial transformations as augmentation
     # (rotation, scaling, elastic deformation)
     if self.rotations or self.scaling or self.elastic_deform or \
         self.cropping:
         # Identify patch shape (full image or cropping)
         if self.cropping: patch_shape = self.cropping_patch_shape
         # shape[0:-1] drops the trailing channel axis (channel-last input)
         else: patch_shape = img_data[0].shape[0:-1]
         # Assembling the spatial transformation
         # (patch center = middle of the patch in every axis)
         aug_spatial_transform = SpatialTransform(
             patch_shape, [i // 2 for i in patch_shape],
             do_elastic_deform=self.elastic_deform,
             alpha=self.config_elastic_deform_alpha,
             sigma=self.config_elastic_deform_sigma,
             do_rotation=self.rotations,
             angle_x=self.config_rotations_angleX,
             angle_y=self.config_rotations_angleY,
             angle_z=self.config_rotations_angleZ,
             do_scale=self.scaling,
             scale=self.config_scaling_range,
             border_mode_data='constant',
             border_cval_data=0,
             border_mode_seg='constant',
             border_cval_seg=0,
             order_data=3,
             order_seg=0,
             p_el_per_sample=self.config_p_per_sample,
             p_rot_per_sample=self.config_p_per_sample,
             p_scale_per_sample=self.config_p_per_sample,
             random_crop=self.cropping)
         # Append spatial transformation to transformation list
         transforms.append(aug_spatial_transform)
     # Compose the batchgenerators transforms
     all_transforms = Compose(transforms)
     # Assemble transforms into a augmentation generator
     augmentation_generator = SingleThreadedAugmenter(
         data_generator, all_transforms)
     # Perform the data augmentation x times (x = cycles)
     aug_img_data = None
     aug_seg_data = None
     for i in range(0, self.cycles):
         # Run the computation process for the data augmentations
         augmentation = next(augmentation_generator)
         # Access augmentated data from the batchgenerators data structure
         if aug_img_data is None and aug_seg_data is None:
             aug_img_data = augmentation["data"]
             aug_seg_data = augmentation[seg_label]
         # Concatenate the new data augmentated data with the cached data
         else:
             aug_img_data = np.concatenate(
                 (augmentation["data"], aug_img_data), axis=0)
             aug_seg_data = np.concatenate(
                 (augmentation[seg_label], aug_seg_data), axis=0)
     # Transform data from channel-first back to channel-last structure
     # Data structure channel-first 3D:  (batch, channel, x, y, z)
     # Data structure channel-last 3D:   (batch, x, y, z, channel)
     aug_img_data = np.moveaxis(aug_img_data, 1, -1)
     aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)
     # Return augmentated image and segmentation data
     return aug_img_data, aug_seg_data
# ----- Example 11 -----
def _build_epoch_transforms(patch_size, p_spatial, alpha, sigma, scale,
                            zoom_range, contrast_range, mirror, random_crop):
    """Build the augmentation pipeline for one training stage.

    Returns a ``(final_transform, center_crop)`` pair:
    ``final_transform`` applies the spatial transform with probability
    ``p_spatial`` (falling back to a plain center crop otherwise), followed by
    resampling/contrast/brightness (and optionally mirroring) with
    probability 1.0; ``center_crop`` is the deterministic crop used for
    validation.

    Parameters mirror the knobs that differ between the epoch-0/30/50 stages:
    elastic deformation strength (``alpha``, ``sigma``), scaling range,
    resample zoom range, contrast range, whether to mirror, and whether the
    spatial transform uses random cropping.
    """
    spatial_transform = SpatialTransform(
        patch_size,
        np.array(patch_size) // 2,
        do_elastic_deform=True,
        alpha=alpha,
        sigma=sigma,
        do_rotation=True,
        angle_z=(-np.pi / 2., np.pi / 2.),
        do_scale=True,
        scale=scale,
        border_mode_data='constant',
        border_cval_data=0,
        order_data=3,
        random_crop=random_crop)
    other_transforms = [
        ResampleTransform(zoom_range=zoom_range),
        ContrastAugmentationTransform(contrast_range, True),
        BrightnessTransform(0.0, 0.1, True),
    ]
    if mirror:
        other_transforms.append(Mirror((2, 3)))
    sometimes_spatial_transforms = RndTransform(
        spatial_transform,
        prob=p_spatial,
        alternative_transform=CenterCropTransform(patch_size))
    sometimes_other_transforms = RndTransform(Compose(other_transforms),
                                              prob=1.0)
    final_transform = Compose(
        [sometimes_spatial_transforms, sometimes_other_transforms])
    return final_transform, CenterCropTransform(patch_size)


def _class_weights(background, prostate, tumor):
    """Return softened inverse-frequency class weights and the total count.

    Each weight is ``1 / (freq / total) ** 0.25``, normalised so the three
    weights sum to 1.  Returns ``((w_background, w_prostate, w_tumor), total)``.
    """
    total = float(background + prostate + tumor)
    raw = [1 / (freq / total) ** 0.25 for freq in (background, prostate, tumor)]
    norm_sum = raw[0] + raw[1] + raw[2]
    return tuple(w / norm_sum for w in raw), total


def main():
    """Run cross-validated training of the prostate-segmentation U-Net.

    Splits patients into train/val/test, then for every cross-validation run:
    computes oversampling factors and class-frequency-based loss weights,
    builds the model/optimizer, and trains with an augmentation curriculum
    whose strength is reduced at epochs 30 and 50.  The best checkpoint
    (lowest validation loss) is kept; all other per-epoch output is deleted.
    """
    # assign global args (other functions in this module read the global)
    global args
    args = parser.parse_args()

    # make a folder for the experiment (ignore "already exists")
    general_folder_name = args.output_path
    try:
        os.mkdir(general_folder_name)
    except OSError:
        pass

    # create train/test split and return the indices; patients in the test
    # split won't be seen during the whole training
    train_idx, val_idx, test_idx = CreateTrainValTestSplit(
        HistoFile_path=args.input_file,
        num_splits=args.num_splits,
        num_test_folds=args.num_test_folds,
        num_val_folds=args.num_val_folds,
        seed=args.seed)

    IDs = train_idx + val_idx

    print('size of training set {}'.format(len(train_idx)))
    print('size of validation set {}'.format(len(val_idx)))
    print('size of test set {}'.format(len(test_idx)))

    # data loading
    Data = ProstataData(args.input_file)  # For details on this class see README

    # train and validate
    for cv in range(args.cv_start, args.cv_end):
        best_epoch = 0
        train_loss = []
        val_loss = []

        # define patients for training and validation
        train_idx, val_idx = split_training(IDs,
                                            len_val=62,
                                            cv=cv,
                                            cv_runs=args.cv_number)

        oversampling_factor, Slices_total, Natural_probability_tu_slice, Natural_probability_PRO_slice = get_oversampling(
            Data,
            train_idx=sorted(train_idx),
            Batch_Size=args.b,
            patch_size=args.patch_size)

        # NOTE(review): the original Python 2 code relied on integer division
        # here; floor division keeps the batch count an int under Python 3.
        training_batches = Slices_total // args.b

        lr = args.lr
        base_lr = args.lr
        args.seed += 1

        print('train_idx', train_idx, len(train_idx))
        print('val_idx', val_idx, len(val_idx))

        # get class frequencies (per modality) plus normalisation statistics
        print('calculating class frequencie')

        Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC,\
        Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2\
        , ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std \
            = get_class_frequencies(Data, train_idx, patch_size=args.patch_size)

        print(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std)

        print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC,
              Background_frequencie_ADC)
        print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2,
              Background_frequencie_T2)

        # inverse-frequency class weights; order is (background, prostate, tumor)
        weight_ADC, all_ADC = _class_weights(Background_frequencie_ADC,
                                             Prostate_frequencie_ADC,
                                             Tumor_frequencie_ADC)
        weight_T2, all_T2 = _class_weights(Background_frequencie_T2,
                                           Prostate_frequencie_T2,
                                           Tumor_frequencie_T2)

        print(all_ADC)
        print(all_T2)
        print('Weights ADC', weight_ADC[0], weight_ADC[1], weight_ADC[2])
        print('Weights T2', weight_T2[0], weight_T2[1], weight_T2[2])

        # define model
        Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
        Net_Name = 'UNetPytorch'
        model = Net.cuda()

        # model parameter: Adam optimizer plus one weighted CE loss per modality
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr,
                                     weight_decay=args.weight_decay)
        criterion_ADC = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_ADC)).cuda()
        criterion_T2 = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_T2)).cuda()

        # new folder name for this cross-validation run
        folder_name = general_folder_name + '/CV_{}'.format(cv)
        try:
            os.mkdir(folder_name)
        except OSError:
            pass

        checkpoint_file = folder_name + '/checkpoint_' + '{}.pth.tar'.format(
            Net_Name)

        # augmentation curriculum: strength is reduced at epochs 30 and 50
        for epoch in range(args.epochs):
            torch.manual_seed(args.seed + epoch + cv)
            np.random.seed(epoch + cv)
            np.random.shuffle(train_idx)

            if epoch == 0:
                # strongest augmentation, including mirroring
                final_transform, Center_Crop = _build_epoch_transforms(
                    args.patch_size, args.p,
                    alpha=(100., 450.), sigma=(13., 17.), scale=(0.75, 1.25),
                    zoom_range=(0.7, 1.3), contrast_range=(0.75, 1.25),
                    mirror=True, random_crop=True)

            if epoch == 30:
                # weaker deformation, no mirroring
                final_transform, Center_Crop = _build_epoch_transforms(
                    args.patch_size, args.p,
                    alpha=(0., 250.), sigma=(11., 14.), scale=(0.85, 1.15),
                    zoom_range=(0.8, 1.2), contrast_range=(0.85, 1.15),
                    mirror=False, random_crop=True)

            if epoch == 50:
                # mildest augmentation; center crops instead of random crops
                final_transform, Center_Crop = _build_epoch_transforms(
                    args.patch_size, args.p,
                    alpha=(0., 150.), sigma=(10., 12.), scale=(0.85, 1.15),
                    zoom_range=(0.9, 1.1), contrast_range=(0.95, 1.05),
                    mirror=False, random_crop=False)

            train_loader = BatchGenerator(
                Data,
                BATCH_SIZE=args.b,
                split_idx=train_idx,
                seed=args.seed,
                ProbabilityTumorSlices=oversampling_factor,
                epoch=epoch,
                ADC_mean=ADC_mean,
                ADC_std=ADC_std,
                BVAL_mean=BVAL_mean,
                BVAL_std=BVAL_std,
                T2_mean=T2_mean,
                T2_std=T2_std)

            val_loader = BatchGenerator(Data,
                                        BATCH_SIZE=0,
                                        split_idx=val_idx,
                                        seed=args.seed,
                                        ProbabilityTumorSlices=None,
                                        epoch=epoch,
                                        test=True,
                                        ADC_mean=ADC_mean,
                                        ADC_std=ADC_std,
                                        BVAL_mean=BVAL_mean,
                                        BVAL_std=BVAL_std,
                                        T2_mean=T2_mean,
                                        T2_std=T2_std)

            # train on training set
            train_losses = train(train_loader=train_loader,
                                 model=model,
                                 optimizer=optimizer,
                                 criterion_ADC=criterion_ADC,
                                 criterion_T2=criterion_T2,
                                 final_transform=final_transform,
                                 workers=args.workers,
                                 seed=args.seed,
                                 training_batches=training_batches)
            train_loss.append(train_losses)

            # evaluate on validation set
            val_losses = validate(val_loader=val_loader,
                                  model=model,
                                  folder_name=folder_name,
                                  criterion_ADC=criterion_ADC,
                                  criterion_T2=criterion_T2,
                                  split_ixs=val_idx,
                                  epoch=epoch,
                                  workers=1,
                                  Center_Crop=Center_Crop,
                                  seed=args.seed)
            val_loss.append(val_losses)

            # write the loss history to folder_name after every epoch
            TrainingsCSV = pd.DataFrame({
                'train_loss': train_loss,
                'val_loss': val_loss
            })
            TrainingsCSV.to_csv(folder_name + '/TrainingsCSV.csv')

            # val_loss already contains val_losses, so <= keeps ties / new best
            if val_losses <= min(val_loss):
                best_epoch = epoch
                print('best epoch', epoch)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    filename=checkpoint_file)

            optimizer, lr = adjust_learning_rate(optimizer, base_lr, epoch)

        # delete all output except best epoch
        clear_image_data(folder_name, best_epoch, epoch)
예제 #12
0
def main():
    """Demo: augment GIANA polyp batches with batchgenerators and plot them.

    Loads three pre-exported numpy volumes, applies a contrast augmentation
    to four batches pulled from the data generator, and displays every
    image/label pair, finishing with a single blocking ``plt.show()``.
    """
    image_data_file = [
        'C:/dev/data/Endoviz2018/GIANA/polyp_detection_segmentation/image_data_1.npy',
        'C:/dev/data/Endoviz2018/GIANA/polyp_detection_segmentation/image_data_2.npy',
        'C:/dev/data/Endoviz2018/GIANA/polyp_detection_segmentation/image_data_3.npy'
    ]
    label_data_file = [
        'C:/dev/data/Endoviz2018/GIANA/polyp_detection_segmentation/gt_data_1.npy',
        'C:/dev/data/Endoviz2018/GIANA/polyp_detection_segmentation/gt_data_2.npy',
        'C:/dev/data/Endoviz2018/GIANA/polyp_detection_segmentation/gt_data_3.npy'
    ]

    # Colour augmentation only; spatial and noise transforms stay disabled.
    transformations = Compose(
        [ContrastAugmentationTransform((0.3, 0.5), preserve_range=True)])

    giana_dataset = GianaDataset(image_data_file, label_data_file)
    giana_dataloader = GianaDataGenerator(giana_dataset, 4, 4)
    # NOTE(review): this augmenter is constructed but the loop below pulls
    # batches from the plain dataloader and applies the transform itself, so
    # the threaded pipeline is never consumed here — confirm it is needed.
    multithreaded_generator = MultiThreadedAugmenter(giana_dataloader,
                                                     transformations,
                                                     4,
                                                     2,
                                                     seeds=None)

    for _ in range(4):
        # Pull a raw batch and augment it in-process.
        batch = transformations(**next(giana_dataloader))
        print("Dataset selected: {0}".format(
            giana_dataloader.dataset.selected))
        plt.figure()
        for idx, (image, label) in enumerate(zip(batch['data'], batch['seg'])):
            # channel-first -> channel-last for display; drop the label's
            # singleton channel axis
            plot_image_label(np.rollaxis(image, 0, 3),
                             np.squeeze(label, axis=0),
                             idx)
    plt.show()