Ejemplo n.º 1
0
def get_train_transform(patch_size):
    """Build a light rotation + random-mirroring augmentation pipeline.

    The spatial transform is given ``None`` as its patch size, so the output
    size is left to ``SpatialTransform_2`` (presumably the input's spatial
    shape -- TODO confirm against the installed batchgenerators). Among the
    spatial augmentations only rotation (up to +/-15 degrees about every
    axis) is switched on; the elastic and scaling settings are inert because
    their ``do_*`` flags are False.

    :param patch_size: sequence of spatial sizes; half of each entry is used
        as the patch-center distance from the border.
    :return: ``Compose`` of the spatial transform followed by a
        ``MirrorTransform`` wrapped in ``RndTransform`` with ``prob=0.3``.
    """
    max_angle = 15 / 360. * 2 * np.pi
    rotation_range = (-max_angle, max_angle)

    spatial = SpatialTransform_2(
        None, [extent // 2 for extent in patch_size],
        do_elastic_deform=False,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=False,
        scale=(0.8, 1.2),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=False,
        p_el_per_sample=0.3,
        p_rot_per_sample=0.3,
        p_scale_per_sample=0.3)
    random_mirror = RndTransform(MirrorTransform(axes=(0, 1, 2)), prob=0.3)
    return Compose(transforms=[spatial, random_mirror])
def get_train_transform(patch_size):
    """Build a spatial + intensity augmentation pipeline.

    Order of application:

    1. elastic deformation, rotation (up to +/-5 degrees about every axis)
       and scaling, cropping to ``patch_size``;
    2. mirroring along axes 0, 1 and 2;
    3. per-channel multiplicative brightness;
    4. per-channel gamma on the inverted image;
    5. additive Gaussian noise.

    :param patch_size: spatial output size; half of each entry is passed as
        the patch-center distance from the border.
    :return: ``Compose`` wrapping the transforms above.
    """
    max_angle = 5 / 360. * 2 * np.pi
    rotation_range = (-max_angle, max_angle)

    pipeline = [
        SpatialTransform_2(
            patch_size, [extent // 2 for extent in patch_size],
            do_elastic_deform=True,
            deformation_scale=(0, 0.05),
            do_rotation=True,
            angle_x=rotation_range,
            angle_y=rotation_range,
            angle_z=rotation_range,
            do_scale=True,
            scale=(0.75, 1.25),
            border_mode_data='constant',
            # Fill value for voxels outside the image; presumably matches the
            # background intensity after this project's normalisation --
            # TODO confirm.
            border_cval_data=-2.34,
            border_mode_seg='constant',
            border_cval_seg=0),
        MirrorTransform(axes=(0, 1, 2)),
        BrightnessMultiplicativeTransform((0.7, 1.5),
                                          per_channel=True,
                                          p_per_sample=0.15),
        GammaTransform(gamma_range=(0.5, 2),
                       invert_image=True,
                       per_channel=True,
                       p_per_sample=0.15),
        GaussianNoiseTransform(noise_variance=(0, 0.15), p_per_sample=0.15),
    ]
    return Compose(pipeline)
def get_train_transform(patch_size):
    """Build the data-augmentation pipeline for training data.

    Inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py

    Spatial augmentation rotates only about the z axis (up to +/-15 degrees),
    with elastic deformation and scaling; mirroring is restricted to axes 0
    and 1. Brightness, gamma, Gaussian noise and Gaussian blur are then each
    applied with probability 0.2 per sample.

    :param patch_size: shape of the network's input
    :return: ``Compose`` object wrapping the list of transformations
    """

    train_transforms = []

    def rad(deg):
        """Return the symmetric range (-deg, +deg) converted to radians."""
        return (-deg / 360 * 2 * np.pi, deg / 360 * 2 * np.pi)

    train_transforms.append(
        SpatialTransform_2(
            patch_size,
            (10, 10, 10),
            do_elastic_deform=True,
            deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_z=rad(15),
            angle_x=(0, 0),
            angle_y=(0, 0),
            do_scale=True,
            scale=(0.75, 1.25),
            border_mode_data='constant',
            border_cval_data=0,
            border_mode_seg='constant',
            border_cval_seg=0,
            order_seg=1,
            random_crop=False,
            p_el_per_sample=0.2,
            p_rot_per_sample=0.2,
            p_scale_per_sample=0.2,
        ))

    train_transforms.append(MirrorTransform(axes=(0, 1)))

    train_transforms.append(
        BrightnessMultiplicativeTransform((0.7, 1.5),
                                          per_channel=True,
                                          p_per_sample=0.2))

    train_transforms.append(
        GammaTransform(gamma_range=(0.2, 1.0),
                       invert_image=False,
                       per_channel=False,
                       p_per_sample=0.2))

    train_transforms.append(
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))

    train_transforms.append(
        GaussianBlurTransform(blur_sigma=(0.2, 1.0),
                              different_sigma_per_channel=False,
                              # Blur is drawn per channel with this
                              # probability; 0.0 would effectively never
                              # select a channel and disable the transform,
                              # so every channel of a selected sample is
                              # blurred.
                              p_per_channel=1.0,
                              p_per_sample=0.2))

    return Compose(train_transforms)
Ejemplo n.º 4
0
def get_train_transform(patch_size):
    """Build a demonstration augmentation pipeline for 3D (BraTS-style) data.

    These are not necessarily the best transforms to use for BraTS; the
    pipeline just showcases the API. It applies, in order: a spatial
    transform with random crop, mirroring along all axes, gamma, inverted
    gamma, and additive Gaussian noise.

    :param patch_size: spatial output size; half of each entry is passed as
        the patch-center distance from the border.
    :return: ``Compose`` wrapping the transforms.
    """
    max_angle = 15 / 360. * 2 * np.pi
    rotation_range = (-max_angle, max_angle)

    # The spatial transform runs first: it crops the data to ``patch_size``
    # and so lowers the cost of every later step. All later steps are
    # intensity-only and keep the shape, so they cannot add border artefacts.
    # SpatialTransform_2 uses a newer parameterisation of elastic deformation.
    # Each spatial augmentation fires with probability 0.1 per sample, so
    # 1 - (1 - 0.1) ** 3 ~= 27% of samples get at least one of them; the
    # rest are only (randomly) cropped.
    spatial = SpatialTransform_2(
        patch_size, [extent // 2 for extent in patch_size],
        do_elastic_deform=True,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=True,
        p_el_per_sample=0.1,
        p_rot_per_sample=0.1,
        p_scale_per_sample=0.1)

    mirror_all_axes = MirrorTransform(axes=(0, 1, 2))

    # Gamma correction is a nonlinear mapping of intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction). The second instance
    # inverts the image, applies gamma, then inverts it back.
    gamma = GammaTransform(gamma_range=(0.5, 2),
                           invert_image=False,
                           per_channel=True,
                           p_per_sample=0.15)
    inverted_gamma = GammaTransform(gamma_range=(0.5, 2),
                                    invert_image=True,
                                    per_channel=True,
                                    p_per_sample=0.15)

    noise = GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15)

    return Compose([spatial, mirror_all_axes, gamma, inverted_gamma, noise])
Ejemplo n.º 5
0
def get_transformer(bbox_image_shape=(256, 256, 256), deformation_scale=0.2):
    """Build an elastic-deformation-only transform for a bounding-box image.

    :param bbox_image_shape: spatial shape of the image, e.g. ``(256, 256, 256)``.
        It is used as the output patch size, and half of each entry is used as
        the patch-center distance from the border. Any sequence of ints works,
        so lists are still accepted.
    :param deformation_scale: lower bound of the elastic deformation strength;
        the sampled range is ``(deformation_scale, deformation_scale + 0.1)``.
        0 gives almost no deformation and 0.2 a strong one, so values in
        0-0.25 are reasonable.
    :return: ``Compose`` wrapping a single ``SpatialTransform_2``.
    """
    # The default shape is an immutable tuple. A list default would be a
    # single object shared by every call and handed to every transform built
    # with the default.
    tr_transforms = []

    # SpatialTransform_2 (unlike SpatialTransform) constrains the warping so
    # that the image is not over-distorted.
    tr_transforms.append(
        SpatialTransform_2(
            patch_size=bbox_image_shape,
            patch_center_dist_from_border=[i // 2 for i in bbox_image_shape],
            do_elastic_deform=True,
            deformation_scale=(deformation_scale, deformation_scale + 0.1),
            # Rotation and scaling are disabled; the ranges below are inert.
            do_rotation=False,
            angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=False,
            scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=False,
            # Deform every sample.
            p_el_per_sample=1.0, p_rot_per_sample=1.0, p_scale_per_sample=1.0
        ))
    # Alternative: the classic SpatialTransform parameterises elastic
    # deformation with alpha/sigma instead. There, a smaller sigma gives a
    # more local (hence more severe) distortion, and a larger alpha a
    # stronger one.

    all_transforms = Compose(tr_transforms)
    return all_transforms
def get_valid_transform(patch_size):
    """Build the transform pipeline applied to validation data.

    Inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py

    Elastic deformation, rotation and scaling are all switched off, so the
    ``SpatialTransform_2`` step only crops to ``patch_size``
    (``random_crop=True``). A ``MirrorTransform`` over axes 0 and 1 follows.

    :param patch_size: shape of the network's input
    :return: ``Compose`` wrapping the transforms
    """
    no_rotation = (0, 0)

    crop_only = SpatialTransform_2(patch_size,
                                   patch_size,
                                   do_elastic_deform=False,
                                   deformation_scale=(0, 0),
                                   do_rotation=False,
                                   angle_x=no_rotation,
                                   angle_y=no_rotation,
                                   angle_z=no_rotation,
                                   do_scale=False,
                                   scale=(1.0, 1.0),
                                   border_mode_data='constant',
                                   border_cval_data=0,
                                   border_mode_seg='constant',
                                   border_cval_seg=0,
                                   order_seg=1,
                                   order_data=3,
                                   random_crop=True,
                                   p_el_per_sample=0.1,
                                   p_rot_per_sample=0.1,
                                   p_scale_per_sample=0.1)

    # NOTE(review): mirroring on validation data makes validation batches
    # random; confirm this is intended.
    valid_transforms = [crop_only, MirrorTransform(axes=(0, 1))]
    return Compose(valid_transforms)
Ejemplo n.º 7
0
    def _make_training_transforms(self):
        """Build the list of training-time augmentation transforms.

        Every setting is read from ``self.training_augmentation_args`` with a
        fallback default. Rotation limits are given in degrees (default 15 per
        axis) and converted to symmetric radian ranges.

        :return: list of transforms: a ``SpatialTransform_2``, an optional
            ``MirrorTransform`` (only when ``do_mirror`` is truthy), then
            multiplicative brightness, Gaussian noise and gamma. When
            ``self.no_data_augmentation`` is set, an empty list is returned.
            The list is not wrapped in ``Compose``.
        """
        if self.no_data_augmentation:
            print("No data augmentation will be performed during training!")
            return []

        aug_args = self.training_augmentation_args

        def symmetric_radians(degrees):
            """Return (-degrees, +degrees) converted to radians."""
            return (-degrees / 360. * 2 * np.pi, degrees / 360. * 2 * np.pi)

        patch_size = self.patch_size[::-1]  # (x, y, z) order
        # Halve element-wise so this works for tuples and lists as well as
        # numpy arrays (``patch_size // 2`` only works for arrays).
        patch_center_dist = [extent // 2 for extent in patch_size]
        rot_angle_x = aug_args.get('angle_x', 15)
        rot_angle_y = aug_args.get('angle_y', 15)
        rot_angle_z = aug_args.get('angle_z', 15)
        p_per_sample = aug_args.get('p_per_sample', 0.15)

        train_transforms = [
            SpatialTransform_2(
                patch_size,
                patch_center_dist,
                do_elastic_deform=aug_args.get('do_elastic_deform', True),
                deformation_scale=aug_args.get('deformation_scale',
                                               (0, 0.25)),
                do_rotation=aug_args.get('do_rotation', True),
                angle_x=symmetric_radians(rot_angle_x),
                angle_y=symmetric_radians(rot_angle_y),
                angle_z=symmetric_radians(rot_angle_z),
                do_scale=aug_args.get('do_scale', True),
                scale=aug_args.get('scale', (0.75, 1.25)),
                border_mode_data='nearest',
                border_cval_data=0,
                order_data=3,
                # Segmentation border mode/value and order are left at the
                # SpatialTransform_2 defaults.
                random_crop=False,
                p_el_per_sample=aug_args.get('p_el_per_sample', 0.5),
                p_rot_per_sample=aug_args.get('p_rot_per_sample', 0.5),
                p_scale_per_sample=aug_args.get('p_scale_per_sample', 0.5))
        ]

        if aug_args.get("do_mirror", False):
            train_transforms.append(MirrorTransform(axes=(0, 1, 2)))

        train_transforms.append(
            BrightnessMultiplicativeTransform(
                aug_args.get('brightness_range', (0.7, 1.5)),
                per_channel=True,
                p_per_sample=p_per_sample))
        train_transforms.append(
            GaussianNoiseTransform(
                noise_variance=aug_args.get('gaussian_noise_variance',
                                            (0, 0.05)),
                p_per_sample=p_per_sample))
        train_transforms.append(
            GammaTransform(gamma_range=aug_args.get('gamma_range', (0.5, 2)),
                           invert_image=False,
                           per_channel=True,
                           p_per_sample=p_per_sample))

        print("train_transforms\n", train_transforms)

        return train_transforms
Ejemplo n.º 8
0
def _append_target_transforms(transforms, regions, deep_supervision_scales,
                              soft_ds, classes):
    """Append the 'target' post-processing shared by train and val pipelines.

    Renames 'seg' to 'target', optionally converts the segmentation to
    regions, optionally downsamples the target for deep supervision, and
    finally converts 'data' and 'target' to float tensors.

    :param transforms: list of transforms, extended in place.
    :param regions: regions for ConvertSegmentationToRegionsTransform, or None.
    :param deep_supervision_scales: scales for deep supervision, or None.
    :param soft_ds: use DownsampleSegForDSTransform3 (needs ``classes``)
        instead of DownsampleSegForDSTransform2.
    :param classes: classes for soft deep supervision.
    :raises AssertionError: if ``soft_ds`` is set but ``classes`` is None.
    """
    transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    transforms.append(NumpyToTensor(['data', 'target'], 'float'))


def get_insaneDA_augmentation2(dataloader_train,
                               dataloader_val,
                               patch_size,
                               params=default_3D_augmentation_params,
                               border_val_seg=-1,
                               seeds_train=None,
                               seeds_val=None,
                               order_seg=1,
                               order_data=3,
                               deep_supervision_scales=None,
                               soft_ds=False,
                               classes=None,
                               pin_memory=True,
                               regions=None):
    """Build multi-threaded train and validation augmenters ("insane" DA).

    The training pipeline applies, in order:
    - channel selection;
    - an optional 3D->2D round trip around SpatialTransform_2;
    - intensity augmentations (noise, blur, brightness, contrast,
      low-resolution simulation, gamma);
    - mirroring and masking;
    - label cleanup;
    - optional cascade augmentations;
    - the shared target post-processing.

    The validation pipeline applies only label cleanup, channel selection,
    the optional one-hot move of the last seg channel, and the shared target
    post-processing.

    :param dataloader_train: data loader feeding the training augmenter.
    :param dataloader_val: data loader feeding the validation augmenter.
    :param patch_size: network patch size; with ``dummy_2D`` its first entry
        is dropped for the spatial transform.
    :param params: augmentation parameter dict (new-style keys such as
        ``do_mirror``; the legacy ``mirror`` key is rejected).
    :param border_val_seg: fill value for segmentations outside the image.
    :param seeds_train: seeds for the training augmenter.
    :param seeds_val: seeds for the validation augmenter.
    :param order_seg: interpolation order for segmentations.
    :param order_data: interpolation order for data.
    :param deep_supervision_scales: if given, targets are downsampled to
        these scales.
    :param soft_ds: use DownsampleSegForDSTransform3 (requires ``classes``).
    :param classes: classes for soft deep supervision.
    :param pin_memory: forwarded to both MultiThreadedAugmenter instances.
    :param regions: if given, segmentations are converted to these regions.
    :return: tuple ``(batchgenerator_train, batchgenerator_val)``.
    :raises AssertionError: if ``params`` has the legacy ``mirror`` key, or
        ``soft_ds`` is set without ``classes``.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform_2(
            patch_size_spatial,
            patch_center_dist_from_border=None,
            do_elastic_deform=params.get("do_elastic"),
            deformation_scale=params.get("eldef_deformation_scale"),
            do_rotation=params.get("do_rotation"),
            angle_x=params.get("rotation_x"),
            angle_y=params.get("rotation_y"),
            angle_z=params.get("rotation_z"),
            do_scale=params.get("do_scaling"),
            scale=params.get("scale_range"),
            border_mode_data=params.get("border_mode_data"),
            border_cval_data=0,
            order_data=order_data,
            border_mode_seg="constant",
            border_cval_seg=border_val_seg,
            order_seg=order_seg,
            random_crop=params.get("random_crop"),
            p_el_per_sample=params.get("p_eldef"),
            p_scale_per_sample=params.get("p_scale"),
            p_rot_per_sample=params.get("p_rot"),
            independent_scale_for_each_axis=params.get(
                "independent_scale_factor_for_each_axis"),
            p_independent_scale_per_axis=params.get(
                "p_independent_scale_per_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if
    # applicable). Otherwise the overloaded color channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # Cascade augmentations act on the last len(all_segmentation_labels)
        # channels of 'data', i.e. the one-hot channels just moved there.
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        # Each keyword takes the params key of the same
                        # meaning: fill probability <- *_fill_with_other_class_p,
                        # size threshold <- *_max_size_percent_threshold.
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    _append_target_transforms(tr_transforms, regions, deep_supervision_scales,
                              soft_ds, classes)
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)
    # For debugging, SingleThreadedAugmenter(dataloader_train, tr_transforms)
    # can be used here instead.

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    _append_target_transforms(val_transforms, regions, deep_supervision_scales,
                              soft_ds, classes)
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count (at least one).
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
Ejemplo n.º 9
0
# Build an augmentation list holding a single SpatialTransform_2 over a whole
# bounding-box image. Elastic deformation and scaling are applied to every
# sample (p=1.0); rotation is disabled, so the angle ranges are inert.
# NOTE(review): `bbox_image` is not defined in this snippet; it is assumed to
# be an array created earlier by the surrounding script -- confirm.
tr_transforms = []
# tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))
# Lower bound of the elastic deformation strength: 0 gives almost no
# deformation and 0.2 a strong one, so 0-0.25 is a reasonable range.
deformation_scale = 0.0
# SpatialTransform_2 differs from SpatialTransform exactly here: it constrains
# the warping so that the image is not over-distorted.
tr_transforms.append(
    SpatialTransform_2(
        patch_size=bbox_image.shape,
        patch_center_dist_from_border=[i // 2 for i in bbox_image.shape],
        do_elastic_deform=True,
        deformation_scale=(deformation_scale, deformation_scale + 0.1),
        do_rotation=False,
        angle_x=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
        angle_y=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
        angle_z=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=False,
        p_el_per_sample=1.0,
        p_rot_per_sample=1.0,
        p_scale_per_sample=1.0))
# tr_transforms.append(
#     SpatialTransform(
#         patch_size=bbox_image.shape,
#         patch_center_dist_from_border=[i // 2 for i in bbox_image.shape],
#         do_elastic_deform=True, alpha=(2000., 2100.), sigma=(10., 11.),
Ejemplo n.º 10
0
# Data loader over a single test image (`data.camera()`, presumably
# skimage.data).
# NOTE(review): the positional args (1, None, False) presumably configure this
# project's DataLoader (batch size first) -- confirm against its definition.
batchgen = DataLoader(data.camera(), 1, None, False)
#batch = next(batchgen)

#print(batch['data'].shape)
def plot_batch(batch):
    """Show channel 0 of every sample in ``batch['data']`` side by side.

    Each sample gets its own subplot in a single row and is drawn with a gray
    colormap; the figure is then displayed with ``plt.show()``.
    """
    images = batch['data']
    n_samples = images.shape[0]
    for position, sample in enumerate(images, start=1):
        plt.subplot(1, n_samples, position)
        plt.imshow(sample[0], cmap="gray")
    plt.show()
#plot_batch(batch)

# Demo pipeline: contrast -> Gaussian noise -> spatial transform. It runs
# through a MultiThreadedAugmenter and one augmented batch is plotted.
my_transforms = []

# NOTE(review): despite its name this is a ContrastAugmentationTransform, not a
# brightness transform.
brightness_transform = ContrastAugmentationTransform((0.3, 3.), preserve_range=True)
my_transforms.append(brightness_transform)

# Large noise variance range; presumably sized for the camera image's 0-255
# intensity scale -- TODO confirm.
noise_transform = GaussianNoiseTransform(noise_variance=(0, 20))
my_transforms.append(noise_transform)

# Full-image spatial transform: output size is the camera image shape, half of
# which is the patch-center distance from the border. It applies any in-plane
# rotation (angle_z in 0..2*pi), scaling 0.8-1.2 and mild elastic deformation.
# It runs after the intensity transforms, so the added noise is warped too.
spatial_transform = SpatialTransform_2(data.camera().shape, np.array(data.camera().shape)//2,
                                     do_elastic_deform=True, deformation_scale=(0,0.05),
                                     do_rotation=True, angle_z=(0, 2*np.pi),
                                     do_scale=True, scale=(0.8, 1.2),
                                     border_mode_data='constant', border_cval_data=0, order_data=1,
                                     random_crop=False)
my_transforms.append(spatial_transform)
all_transforms = Compose(my_transforms)
# Presumably 4 worker processes with 2 cached batches each, unseeded (per the
# batchgenerators MultiThreadedAugmenter positional order) -- confirm.
multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transforms, 4, 2, seeds=None)
plot_batch(next(multithreaded_generator))
# Learning-rate schedule: ReduceLROnPlateau multiplies the learning rate by 0.2
# after 30 scheduler steps without improvement of the monitored metric.
# NOTE(review): `optimizer`, `patch_size` and `np` come from outside this
# snippet.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 factor=0.2,
                                                 patience=30)

# Training transforms, in order:
# - light intensity jitter (gamma, contrast and multiplicative brightness,
#   each with range 0.9-1.1);
# - mirroring along axis 0;
# - a randomly placed spatial transform: rotation about x only, up to
#   +/-0.1*pi; scaling 0.9-1.1; mild elastic deformation;
# - a final random crop to `patch_size`.
# NOTE(review): SpatialTransform_2 already outputs `patch_size`, so the trailing
# RandomCropTransform is presumably a no-op -- confirm.
tr_transform = Compose([
    GammaTransform((0.9, 1.1)),
    ContrastAugmentationTransform((0.9, 1.1)),
    BrightnessMultiplicativeTransform((0.9, 1.1)),
    MirrorTransform(axes=[0]),
    SpatialTransform_2(
        patch_size,
        (90, 90, 50),
        random_crop=True,
        do_elastic_deform=True,
        deformation_scale=(0, 0.05),
        do_rotation=True,
        angle_x=(-0.1 * np.pi, 0.1 * np.pi),
        angle_y=(0, 0),
        angle_z=(0, 0),
        do_scale=True,
        scale=(0.9, 1.1),
        border_mode_data='constant',
    ),
    RandomCropTransform(crop_size=patch_size),
])

# Validation transforms: only a random crop to `patch_size`, no augmentation.
vd_transform = Compose([
    RandomCropTransform(crop_size=patch_size),
])

trainer = Trainer(
    model=model,