Example #1
    def __init__(
            self,
            x,  # (n_samples, z, y, x, channels) for 3D or (n_samples, y, x, channels) for 2D data
            y,
            transform,  # the transformations from the batchgenerators API
            batch_size=32,
            shuffle=False,
            seed=None,
            save_to_dir=None,
            save_prefix='',
            save_format='png'):

        self.x = self._check_and_transform_x(x)

        self.y = self._check_and_transform_y(y)

        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format

        super().__init__(x.shape[0], batch_size, shuffle, seed)

        if not isinstance(transform, Compose):
            if isinstance(transform, list):
                self.transform = Compose(transform)
            else:
                self.transform = Compose([transform])
        else:
            self.transform = transform

        # print("NumpyArrayIteratorUpTo3D: self.transform=", self.transform)

        self.data_loader = self._create_data_loader()
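For context, a minimal usage sketch; the class name NumpyArrayIteratorUpTo3D is taken from the commented-out debug print above, and the data shapes are assumptions:

import numpy as np
from batchgenerators.transforms.spatial_transforms import MirrorTransform

x = np.random.rand(8, 64, 64, 1)        # (n_samples, y, x, channels) for 2D data
y = np.random.randint(0, 2, size=(8,))
# a single transform or a list is accepted; __init__ wraps either in a Compose
it = NumpyArrayIteratorUpTo3D(x, y, transform=[MirrorTransform(axes=(0, 1))], batch_size=4)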
Example #2
 def wrap_transforms(self, dataloader_train, dataloader_val,
                     train_transforms, val_transforms):
     tr_gen = NonDetMultiThreadedAugmenter(dataloader_train,
                                           Compose(train_transforms),
                                           self.num_proc_DA,
                                           self.num_cached,
                                           seeds=None,
                                           pin_memory=self.pin_memory)
     val_gen = NonDetMultiThreadedAugmenter(dataloader_val,
                                            Compose(val_transforms),
                                            self.num_proc_DA // 2,
                                            self.num_cached,
                                            seeds=None,
                                            pin_memory=self.pin_memory)
     return tr_gen, val_gen
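Both returned objects are generator-style augmenters; a sketch of how they would typically be consumed (trainer, dl_train, dl_val and the transform lists are hypothetical):

tr_gen, val_gen = trainer.wrap_transforms(dl_train, dl_val, train_tfs, val_tfs)
train_batch = next(tr_gen)  # dict with 'data' (and usually 'seg')
val_batch = next(val_gen)   # validation runs with half as many augmentation workers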
Example #3
def get_transforms(
        patch_shape=(256, 320), other_transforms=None, random_crop=False):
    """
    Initializes the transforms for training.
    Args:
        patch_shape:
        other_transforms: List of transforms that you would like to add (optional). Defaults to None.
        random_crop (boolean): whether or not you want to random crop or center crop. Currently, the Transformed3DGenerator
        only supports random cropping. Transformed2DGenerator supports both random_crop = True and False.
    """
    ndim = len(patch_shape)
    spatial_transform = SpatialTransform(patch_shape,
                                         do_elastic_deform=True,
                                         alpha=(0., 1500.),
                                         sigma=(30., 80.),
                                         do_rotation=True,
                                         angle_z=(0, 2 * np.pi),
                                         do_scale=True,
                                         scale=(0.75, 2.),
                                         border_mode_data="nearest",
                                         border_cval_data=0,
                                         order_data=1,
                                         random_crop=random_crop,
                                         p_el_per_sample=0.1,
                                         p_scale_per_sample=0.1,
                                         p_rot_per_sample=0.1)
    mirror_transform = MirrorTransform(axes=(0, 1))
    transforms_list = [spatial_transform, mirror_transform]
    if other_transforms is not None:
        transforms_list = transforms_list + other_transforms
    composed = Compose(transforms_list)
    return composed
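A usage sketch for get_transforms, appending an extra transform via other_transforms (GaussianNoiseTransform is just an illustrative choice):

from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform

composed = get_transforms(patch_shape=(256, 320),
                          other_transforms=[GaussianNoiseTransform(noise_variance=(0, 0.05))],
                          random_crop=True)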
Example #4
def get_train_transform(patch_size):
    tr_transforms = []

    tr_transforms.append(
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True,
            deformation_scale=(0, 0.05),
            do_rotation=True,
            angle_x=(-5 / 360. * 2 * np.pi, 5 / 360. * 2 * np.pi),
            angle_y=(-5 / 360. * 2 * np.pi, 5 / 360. * 2 * np.pi),
            angle_z=(-5 / 360. * 2 * np.pi, 5 / 360. * 2 * np.pi),
            do_scale=True,
            scale=(0.75, 1.25),
            border_mode_data='constant',
            border_cval_data=-2.34,
            border_mode_seg='constant',
            border_cval_seg=0))

    tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))
    tr_transforms.append(
        BrightnessMultiplicativeTransform((0.7, 1.5),
                                          per_channel=True,
                                          p_per_sample=0.15))
    tr_transforms.append(
        GammaTransform(gamma_range=(0.5, 2),
                       invert_image=True,
                       per_channel=True,
                       p_per_sample=0.15))
    tr_transforms.append(
        GaussianNoiseTransform(noise_variance=(0, 0.15), p_per_sample=0.15))
    tr_transforms = Compose(tr_transforms)
    return tr_transforms
Example #5
def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
    """
    create mutli-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """

    # create instance of batch generator as first element in pipeline.
    data_gen = BatchGenerator(cf, patient_data, **kwargs)

    my_transforms = []
    if do_aug:
        if cf.da_kwargs["mirror"]:
            mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
            my_transforms.append(mirror_transform)

        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
    all_transforms = Compose(my_transforms)
    # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    return multithreaded_generator
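A sketch of wiring this pipeline into training (cf and train_patient_data are hypothetical stand-ins for the config object and patient dictionary described above):

train_gen = create_data_gen_pipeline(cf, train_patient_data, do_aug=True)
batch = next(train_gen)  # dict with 'data', 'seg' and the bounding-box targets added by the final transform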
Example #6
def crop_transform_random():
    '''
    Crop a patch of fixed size from a random position in the image.
    :return:
    '''
    crop_size = (128, 128)
    batchgen = my_data_loader.DataLoader(camera(), 4)

    # randomly crop a (128, 128) patch from the image
    randomCrop = RandomCropTransform(crop_size=crop_size)
    spatial_transform = SpatialTransform(crop_size,
                                         np.array(crop_size) // 2,
                                         do_elastic_deform=True,
                                         alpha=(0., 1500.),
                                         sigma=(30., 50.),
                                         do_rotation=True,
                                         angle_z=(0, 2 * np.pi),
                                         do_scale=True,
                                         scale=(0.5, 2),
                                         border_mode_data='constant',
                                         border_cval_data=0,
                                         order_data=1,
                                         random_crop=False)
    multithreaded_generator = MultiThreadedAugmenter(
        batchgen, Compose([randomCrop, spatial_transform]), 4, 2)
    my_data_loader.plot_batch(multithreaded_generator.__next__())
Example #7
    def test_image_pipeline_and_pin_memory(self):
        '''
        This just should not crash
        :return:
        '''
        try:
            import torch
        except ImportError:
            # don't test if torch is not installed
            return

        tr_transforms = []
        tr_transforms.append(MirrorTransform())
        tr_transforms.append(
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5))
        tr_transforms.append(NumpyToTensor(keys='data', cast_to='float'))

        composed = Compose(tr_transforms)

        dl = self.dl_images
        mt = MultiThreadedAugmenter(dl, composed, 4, 1, None, True)

        for _ in range(50):
            res = mt.next()

        assert isinstance(res['data'], torch.Tensor)
        assert res['data'].is_pinned()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
Example #8
    def get_batches(self,
                    batch_size=128,
                    type=None,
                    subjects=None,
                    num_batches=None):
        data = type
        seg = []

        num_processes = 1  # 6 is a bit faster than 16
        nr_of_samples = len(subjects) * self.HP.INPUT_DIM[0]
        if num_batches is None:
            num_batches_multithr = int(
                nr_of_samples / batch_size /
                num_processes)  #number of batches for exactly one epoch
        else:
            num_batches_multithr = int(num_batches / num_processes)

        batch_gen = SlicesBatchGeneratorPrecomputedBatches(
            (data, seg),
            BATCH_SIZE=batch_size,
            num_batches=num_batches_multithr,
            seed=None)
        batch_gen.HP = self.HP

        batch_gen = MultiThreadedAugmenter(batch_gen,
                                           Compose([]),
                                           num_processes=num_processes,
                                           num_cached_per_queue=1,
                                           seeds=None)
        return batch_gen
Example #9
    def get_batches(self, batch_size=1):
        data = np.nan_to_num(self.data)
        # Use a dummy mask in case we only want to predict on some data (where we do not have ground truth)
        seg = np.zeros(
            (self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
             self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

        num_processes = 1  # do not use more than 1 if you want to keep the original slice order (threads return in random order)
        batch_gen = SlicesBatchGenerator((data, seg), BATCH_SIZE=batch_size)
        batch_gen.HP = self.HP
        tfs = []  # transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(
                ZeroMeanUnitVarianceTransform(
                    per_channel=self.HP.NORMALIZE_PER_CHANNEL))
        tfs.append(ReorderSegTransform())
        batch_gen = MultiThreadedAugmenter(
            batch_gen,
            Compose(tfs),
            num_processes=num_processes,
            num_cached_per_queue=2,
            seeds=None
        )  # Only use num_processes=1, otherwise the global_idx of SlicesBatchGenerator does not work
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
Example #10
def spatial_transforms():
    '''
    Spatial transforms: deformation, scaling, rotation.
    :return:
    '''
    # deformation + rotation + scaling
    spatial_transform = SpatialTransform(camera().shape,
                                         np.array(camera().shape) // 2,
                                         do_elastic_deform=True,
                                         alpha=(0., 1500.),
                                         sigma=(30., 50.),
                                         do_rotation=True,
                                         angle_z=(0, 2 * np.pi),
                                         do_scale=True,
                                         scale=(0.3, 3.),
                                         border_mode_data='constant',
                                         border_cval_data=0,
                                         order_data=1,
                                         random_crop=False)
    my_transforms = []
    my_transforms.append(spatial_transform)
    all_transforms = Compose(my_transforms)
    batchgen = my_data_loader.DataLoader(camera(), 4)
    multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transforms,
                                                     4, 2)

    # visualize the transformed batch
    my_data_loader.plot_batch(multithreaded_generator.__next__())
Example #11
def _augment_data(Config, batch_generator, type=None):
    batch_gen = MultiThreadedAugmenter(batch_generator,
                                       Compose([]),
                                       num_processes=1,
                                       num_cached_per_queue=1,
                                       seeds=None)
    return batch_gen
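Compose([]) here is an identity pipeline: the MultiThreadedAugmenter is used purely for multi-process prefetching, not for augmentation. A quick sketch showing that an empty Compose returns the batch dict unchanged:

import numpy as np
from batchgenerators.transforms.abstract_transforms import Compose

batch = {'data': np.zeros((4, 1, 32, 32))}
out = Compose([])(**batch)
assert out['data'] is batch['data']  # no transform touched the array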
Example #12
def random_transform():
    '''
    Randomly apply the transform to only some of the batches.
    :return:
    '''
    spatial_transform = SpatialTransform(camera().shape,
                                         np.array(camera().shape) // 2,
                                         do_elastic_deform=True,
                                         alpha=(0., 1500.),
                                         sigma=(30., 50.),
                                         do_rotation=True,
                                         angle_z=(0, 2 * np.pi),
                                         do_scale=True,
                                         scale=(0.3, 3.),
                                         border_mode_data='constant',
                                         border_cval_data=0,
                                         order_data=1,
                                         random_crop=False)

    sometimes_spatial_transform = RndTransform(spatial_transform, prob=0.5)
    batchgen = my_data_loader.DataLoader(camera(), 4)
    multithreaded_generator = MultiThreadedAugmenter(
        batchgen, Compose([sometimes_spatial_transform]), 4, 2)
    for _ in range(4):
        my_data_loader.plot_batch(multithreaded_generator.__next__())
Example #13
def get_transformer(bbox_image_shape=[256, 256, 256], deformation_scale=0.2):
    """

    :param bbox_image_shape: e.g. [256, 256, 256]
    :param deformation_scale: degree of deformation; 0 gives almost no deformation and 0.2 is already
        very strong, so values in the range 0-0.25 are reasonable
    :return:
    """
    tr_transforms = []
    # tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))
    # (This is where SpatialTransform_2 differs from SpatialTransform: SpatialTransform_2 applies a constrained deformation, guaranteeing the image is never distorted excessively.)

    tr_transforms.append(
        SpatialTransform_2(
            patch_size=bbox_image_shape,
            patch_center_dist_from_border=[i // 2 for i in bbox_image_shape],
            do_elastic_deform=True, deformation_scale=(deformation_scale, deformation_scale + 0.1),
            do_rotation=False,
            angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),  # random rotation angle range
            angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=False,
            scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=False,
            p_el_per_sample=1.0, p_rot_per_sample=1.0, p_scale_per_sample=1.0
        ))
    # tr_transforms.append(
    #     SpatialTransform(
    #         patch_size=bbox_image.shape,
    #         patch_center_dist_from_border=[i // 2 for i in bbox_image.shape],
    #         do_elastic_deform=True, alpha=(2000., 2100.), sigma=(10., 11.),
    #         do_rotation=False,
    #         angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
    #         angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
    #         angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
    #         do_scale=False,
    #         scale=(0.75, 0.75),
    #         border_mode_data='constant', border_cval_data=0,
    #         border_mode_seg='constant', border_cval_seg=0,
    #         order_seg=1, order_data=3,
    #         random_crop=True,
    #         p_el_per_sample=1.0, p_rot_per_sample=1.0, p_scale_per_sample=1.0
    #     ))
    # the smaller sigma, the more local (i.e. the stronger) the distortion; the larger alpha, the stronger the distortion
    # tr_transforms.append(
    #     SpatialTransform(bbox_image.shape, [i // 2 for i in bbox_image.shape],
    #                      do_elastic_deform=True, alpha=(1300., 1500.), sigma=(10., 11.),
    #                      do_rotation=False, angle_z=(0, 2 * np.pi),
    #                      do_scale=False, scale=(0.3, 0.7),
    #                      border_mode_data='constant', border_cval_data=0, order_data=1,
    #                      border_mode_seg='constant', border_cval_seg=0,
    #                      random_crop=False))

    all_transforms = Compose(tr_transforms)
    return all_transforms
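The returned Compose can be applied directly to a batch dict by unpacking it with ** (the same calling convention Example #24 below relies on); the batch contents here are assumptions:

import numpy as np

transformer = get_transformer(bbox_image_shape=[64, 64, 64], deformation_scale=0.2)
batch = {'data': np.random.rand(1, 1, 64, 64, 64).astype(np.float32),
         'seg': np.zeros((1, 1, 64, 64, 64), dtype=np.float32)}
augmented = transformer(**batch)  # dict with elastically deformed 'data' and 'seg'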
Example #14
def crop_transform_center_seg():
    '''
    Center-crop the segmentation from the exact middle of the image.
    :return:
    '''
    crop_size = (128, 128)
    batchgen = my_data_loader.DataLoader(camera(), 4)
    centerCropSeg = CenterCropSegTransform(output_size=crop_size)
    multithreaded_generator = MultiThreadedAugmenter(batchgen,
                                                     Compose([centerCropSeg]),
                                                     4, 2)
    my_data_loader.plot_batch(multithreaded_generator.__next__())
Example #15
    def _augment_data(self, batch_generator, type=None):

        if self.Config.DATA_AUGMENTATION:
            num_processes = 16  # 2D: 8 is a bit faster than 16
            # num_processes = 8
        else:
            num_processes = 6

        tfs = []  #transforms

        if self.Config.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=self.Config.NORMALIZE_PER_CHANNEL))

        if self.Config.DATA_AUGMENTATION:
            if type == "train":
                # scale: inverted: 0.5 -> bigger; 2 -> smaller
                # patch_center_dist_from_border: if 144/2=72 -> always exactly centered; otherwise a bit off center (brain can get off image and will be cut then)

                if self.Config.DAUG_SCALE:
                    center_dist_from_border = int(self.Config.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                    tfs.append(SpatialTransform(self.Config.INPUT_DIM,
                                                patch_center_dist_from_border=center_dist_from_border,
                                                do_elastic_deform=self.Config.DAUG_ELASTIC_DEFORM,
                                                alpha=(90., 120.), sigma=(9., 11.),
                                                do_rotation=self.Config.DAUG_ROTATE,
                                                angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                                                do_scale=True, scale=(0.9, 1.5), border_mode_data='constant',
                                                border_cval_data=0,
                                                order_data=3,
                                                border_mode_seg='constant', border_cval_seg=0,
                                                order_seg=0, random_crop=True, p_el_per_sample=0.2,
                                                p_rot_per_sample=0.2, p_scale_per_sample=0.2))

                if self.Config.DAUG_RESAMPLE:
                    tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2))

                if self.Config.DAUG_NOISE:
                    tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))

                if self.Config.DAUG_MIRROR:
                    tfs.append(MirrorTransform())

                if self.Config.DAUG_FLIP_PEAKS:
                    tfs.append(FlipVectorAxisTransform())

        tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        #num_cached_per_queue 1 or 2 does not really make a difference
        batch_gen = MultiThreadedAugmenter(batch_generator, Compose(tfs), num_processes=num_processes,
                                           num_cached_per_queue=1, seeds=None, pin_memory=True)
        return batch_gen    # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
Example #16
def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True, random=True, n_batches=None):
    """
    create mutli-threaded train/val/test batch generation and augmentation pipeline.
    :param cities: list of strings or None
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset
    :param test_pids: (optional) list of test patient ids, calls the test generator.
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :param random: bool, whether to draw random batches or go through data linearly
    :return: multithreaded_generator
    """
    data_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size, data_dir=cf.data_dir,
                              label_density=cf.label_density, data_split=data_split, resolution=cf.resolution,
                              gt_instances=cf.gt_instances, n_batches=n_batches, random=random)
    my_transforms = []
    if do_aug:
        mirror_transform = MirrorTransform(axes=(3,))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[-2:],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'],
                                             border_mode_data=cf.da_kwargs['border_mode_data'],
                                             border_mode_seg=cf.da_kwargs['border_mode_seg'],
                                             border_cval_seg=cf.da_kwargs['border_cval_seg'])
        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[-2:]))

    my_transforms.append(GammaTransform(cf.da_kwargs['gamma_range'], invert_image=False, per_channel=True,
                                        retain_stats=cf.da_kwargs['gamma_retain_stats'],
                                        p_per_sample=cf.da_kwargs['p_gamma']))
    my_transforms.append(AddLossMask(cf.ignore_label))
    if cf.label_switches is not None:
        my_transforms.append(StochasticLabelSwitches(cf.name2trainId, cf.label_switches))
    all_transforms = Compose(my_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers,
                                                     seeds=range(cf.n_workers))
    return multithreaded_generator
Example #17
    def get_batches(self,
                    batch_size=128,
                    type=None,
                    subjects=None,
                    num_batches=None):
        data = type
        seg = []

        num_processes = 1  # 6 is a bit faster than 16

        batch_gen = SlicesBatchGeneratorPrecomputedBatches(
            (data, seg), batch_size=batch_size)
        batch_gen.HP = self.HP

        batch_gen = MultiThreadedAugmenter(batch_gen,
                                           Compose([]),
                                           num_processes=num_processes,
                                           num_cached_per_queue=1,
                                           seeds=None)
        return batch_gen
Example #18
def create_data_gen_pipeline(patient_data, cf, test_pids=None, do_aug=True):
    """
    Create multi-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset
    :param test_pids: (optional) list of test patient ids, calls the test generator.
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """
    if test_pids is None:
        data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size,
                                 pre_crop_size=cf.pre_crop_size, dim=cf.dim)
    else:
        data_gen = TestGenerator(patient_data, batch_size=cf.batch_size, n_batches=None,
                                 pre_crop_size=cf.pre_crop_size, test_pids=test_pids, dim=cf.dim)
        cf.n_workers = 1

    my_transforms = []
    if do_aug:
        mirror_transform = Mirror(axes=(2, 3))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToOnehotTransform(classes=(0, 1, 2)))
    my_transforms.append(TransposeChannels())
    all_transforms = Compose(my_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    return multithreaded_generator
Example #19
def combine_transform():
    '''
    Combined transform: contrast + mirroring.
    :return:
    '''
    my_transforms = []
    # contrast transform
    contrast_transform = ContrastAugmentationTransform((0.3, 3.),
                                                       preserve_range=True)
    my_transforms.append(contrast_transform)

    # mirror transform
    mirror_transform = MirrorTransform(axes=(2, 3))
    my_transforms.append(mirror_transform)

    all_transform = Compose(my_transforms)

    batchgen = my_data_loader.DataLoader(camera(), 4)
    multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transform,
                                                     4, 2)

    # visualize the transformed batch
    my_data_loader.plot_batch(multithreaded_generator.__next__())
Example #20
    def test_image_pipeline(self):
        '''
        This just should not crash
        :return:
        '''

        tr_transforms = []
        tr_transforms.append(MirrorTransform())
        tr_transforms.append(
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5))

        composed = Compose(tr_transforms)

        dl = self.dl_images
        mt = MultiThreadedAugmenter(dl, composed, 4, 1, None, False)

        for _ in range(50):
            res = mt.next()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
Example #21
def get_train_transform(patch_size):
    # we now create a list of transforms. These are not necessarily the best transforms to use for BraTS, this is just
    # to showcase some things
    tr_transforms = []

    # the first thing we want to run is the SpatialTransform. It reduces the size of our data to patch_size and thus
    # also reduces the computational cost of all subsequent operations. All subsequent operations do not modify the
    # shape and do not transform spatially, so no border artifacts will be introduced
    # Here we use the new SpatialTransform_2 which uses a new way of parameterizing elastic_deform
    # We use all spatial transformations with a probability of 0.1 per sample. This means that 1 - (1 - 0.1) ** 3 ≈ 27%
    # of samples will be augmented, the rest will just be cropped
    tr_transforms.append(
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True,
            deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_y=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=True,
            scale=(0.75, 1.25),
            border_mode_data='constant',
            border_cval_data=0,
            border_mode_seg='constant',
            border_cval_seg=0,
            order_seg=1,
            order_data=3,
            random_crop=True,
            p_el_per_sample=0.1,
            p_rot_per_sample=0.1,
            p_scale_per_sample=0.1))

    # now we mirror along all axes
    tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))

    # brightness transform for 15% of samples
    tr_transforms.append(
        BrightnessMultiplicativeTransform((0.7, 1.5),
                                          per_channel=True,
                                          p_per_sample=0.15))

    # gamma transform. This is a nonlinear transformation of intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction)
    tr_transforms.append(
        GammaTransform(gamma_range=(0.5, 2),
                       invert_image=False,
                       per_channel=True,
                       p_per_sample=0.15))
    # we can also invert the image, apply the transform and then invert back
    tr_transforms.append(
        GammaTransform(gamma_range=(0.5, 2),
                       invert_image=True,
                       per_channel=True,
                       p_per_sample=0.15))

    # Gaussian Noise
    tr_transforms.append(
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15))

    # blurring. Some BraTS cases have very blurry modalities. This can simulate more patients with this problem and
    # thus make the model more robust to it
    tr_transforms.append(
        GaussianBlurTransform(blur_sigma=(0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_channel=0.5,
                              p_per_sample=0.15))

    # now we compose these transforms together
    tr_transforms = Compose(tr_transforms)
    return tr_transforms
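The 27% figure in the comment above can be checked directly: with an independent probability p = 0.1 for each of the three spatial operations, the chance that at least one fires per sample is

p = 0.1
print(1 - (1 - p) ** 3)  # 0.271 -> roughly 27% of samples get spatially augmented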
Example #22
    def _augment_data(self, batch_generator, type=None):

        if self.Config.DATA_AUGMENTATION:
            num_processes = 15  # 15 is a bit faster than 8 on cluster
            # num_processes = multiprocessing.cpu_count()  # on cluster: gives all cores, not only assigned cores
        else:
            num_processes = 6

        tfs = []

        if self.Config.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=self.Config.NORMALIZE_PER_CHANNEL))

        if self.Config.SPATIAL_TRANSFORM == "SpatialTransformPeaks":
            SpatialTransformUsed = SpatialTransformPeaks
        elif self.Config.SPATIAL_TRANSFORM == "SpatialTransformCustom":
            SpatialTransformUsed = SpatialTransformCustom
        else:
            SpatialTransformUsed = SpatialTransform

        if self.Config.DATA_AUGMENTATION:
            if type == "train":
                # patch_center_dist_from_border:
                #   if 144/2=72 -> always exactly centered; otherwise a bit off center
                #   (brain can get off image and will be cut then)
                if self.Config.DAUG_SCALE:

                    if self.Config.INPUT_RESCALING:
                        source_mm = 2  # for bb
                        target_mm = float(self.Config.RESOLUTION[:-2])
                        scale_factor = target_mm / source_mm
                        scale = (scale_factor, scale_factor)
                    else:
                        scale = (0.9, 1.5)

                    if self.Config.PAD_TO_SQUARE:
                        patch_size = self.Config.INPUT_DIM
                    else:
                        patch_size = None  # keeps dimensions of the data

                    # spatial transform automatically crops/pads to correct size
                    center_dist_from_border = int(self.Config.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                    tfs.append(SpatialTransformUsed(patch_size,
                                                patch_center_dist_from_border=center_dist_from_border,
                                                do_elastic_deform=self.Config.DAUG_ELASTIC_DEFORM,
                                                alpha=self.Config.DAUG_ALPHA, sigma=self.Config.DAUG_SIGMA,
                                                do_rotation=self.Config.DAUG_ROTATE,
                                                angle_x=self.Config.DAUG_ROTATE_ANGLE,
                                                angle_y=self.Config.DAUG_ROTATE_ANGLE,
                                                angle_z=self.Config.DAUG_ROTATE_ANGLE,
                                                do_scale=True, scale=scale, border_mode_data='constant',
                                                border_cval_data=0,
                                                order_data=3,
                                                border_mode_seg='constant', border_cval_seg=0,
                                                order_seg=0, random_crop=True,
                                                p_el_per_sample=self.Config.P_SAMP,
                                                p_rot_per_sample=self.Config.P_SAMP,
                                                p_scale_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_RESAMPLE:
                    tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2, per_channel=False))

                if self.Config.DAUG_RESAMPLE_LEGACY:
                    tfs.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))

                if self.Config.DAUG_GAUSSIAN_BLUR:
                    tfs.append(GaussianBlurTransform(blur_sigma=self.Config.DAUG_BLUR_SIGMA,
                                                     different_sigma_per_channel=False,
                                                     p_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_NOISE:
                    tfs.append(GaussianNoiseTransform(noise_variance=self.Config.DAUG_NOISE_VARIANCE,
                                                      p_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_MIRROR:
                    tfs.append(MirrorTransform())

                if self.Config.DAUG_FLIP_PEAKS:
                    tfs.append(FlipVectorAxisTransform())

        tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        #num_cached_per_queue 1 or 2 does not really make a difference
        batch_gen = MultiThreadedAugmenter(batch_generator, Compose(tfs), num_processes=num_processes,
                                           num_cached_per_queue=1, seeds=None, pin_memory=True)
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
Example #23
def get_default_augmentation(dataloader_train,
                             dataloader_val,
                             patch_size,
                             params=default_3D_augmentation_params,
                             border_val_seg=-1,
                             pin_memory=True,
                             seeds_train=None,
                             seeds_val=None,
                             regions=None):
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size

    # Set order_data=0 and order_seg=0 for some more speed for cascade???
    tr_transforms.append(
        SpatialTransform(patch_size_spatial,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"
                      ) and not None and params.get(
                          "cascade_do_cascade_augmentations"):
            # Remove the following transforms to remove cascade DA ??
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(
                        range(-len(params.get("all_segmentation_labels")), 0)),
                    p_per_sample=params.get(
                        "cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get(
                        "cascade_random_binary_transform_size")))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(
                        range(-len(params.get("all_segmentation_labels")), 0)),
                    key="data",
                    p_per_sample=params.get("cascade_remove_conn_comp_p"),
                    fill_with_other_class_p=params.get(
                        "cascade_remove_conn_comp_max_size_percent_threshold"),
                    dont_do_if_covers_more_than_X_percent=params.get(
                        "cascade_remove_conn_comp_fill_with_other_class_p")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)
    # from batchgenerators.dataloading import SingleThreadedAugmenter
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
    # import IPython;IPython.embed()

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
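A usage sketch for get_default_augmentation (the dataloaders are hypothetical; patch_size follows the 3D convention of the params above):

tr_gen, val_gen = get_default_augmentation(dl_train, dl_val,
                                           patch_size=(128, 128, 128),
                                           params=default_3D_augmentation_params)
train_batch = next(tr_gen)  # dict with 'data' and 'target' as torch tensors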
Example #24
#         border_mode_seg='constant', border_cval_seg=0,
#         order_seg=1, order_data=3,
#         random_crop=True,
#         p_el_per_sample=1.0, p_rot_per_sample=1.0, p_scale_per_sample=1.0
#     ))
# the smaller sigma, the more local (i.e. the stronger) the distortion; the larger alpha, the stronger the distortion
# tr_transforms.append(
#     SpatialTransform(bbox_image.shape, [i // 2 for i in bbox_image.shape],
#                      do_elastic_deform=True, alpha=(1300., 1500.), sigma=(10., 11.),
#                      do_rotation=False, angle_z=(0, 2 * np.pi),
#                      do_scale=False, scale=(0.3, 0.7),
#                      border_mode_data='constant', border_cval_data=0, order_data=1,
#                      border_mode_seg='constant', border_cval_seg=0,
#                      random_crop=False))

all_transforms = Compose(tr_transforms)

bbox_image_batch_trans = all_transforms(**bbox_image_batch)  # the ** unpacks the batch dict into keyword arguments
# bbox_image_batch_trans1 = all_transforms(**bbox_image_batch)

# show3D(np.concatenate([np.squeeze(bbox_image_batch['data'][0],axis=0),np.squeeze(bbox_image_batch_trans['data'][1],axis=0)],axis=1))
# show3Dslice(np.concatenate([np.squeeze(bbox_image_batch['data'][0],axis=0),np.squeeze(bbox_image_batch_trans['data'][1],axis=0)],axis=1))
# show3Dslice(np.concatenate([np.squeeze(bbox_image_batch['seg'][0],axis=0),np.squeeze(bbox_image_batch_trans['seg'][1],axis=0)],axis=1))
show3D(
    np.concatenate([
        np.squeeze(bbox_image_batch['seg'][0], axis=0),
        np.squeeze(bbox_image_batch_trans['seg'][0], axis=0)
    ],
                   axis=1))
show3D(
    np.concatenate([
        np.squeeze(bbox_image_batch['data'][0], axis=0),
        np.squeeze(bbox_image_batch_trans['data'][0], axis=0)
    ],
                   axis=1))  # assumption: the second call, truncated in the original, mirrored the 'seg' comparison above using 'data'
Example #25
    numpy_to_tensor = NumpyToTensor(['data', 'labels'], cast_to=None)
    fname = os.path.join(dataset_dir, 'cifar10_training_data.npz')
    dataset = np.load(fname)
    cifar_dataset_as_arrays = (dataset['data'], dataset['labels'],
                               dataset['filenames'])
    print('batch_size', batch_size)
    print('num_workers', num_workers)
    print('pin_memory', pin_memory)
    print('num_epochs', num_epochs)

    tr_transforms = [
        SpatialTransform((32, 32))
    ] * 1  # SpatialTransform is computationally expensive and we need some
    # load on the CPU; increase the multiplier (e.g. to 5) to stack several of them
    tr_transforms.append(numpy_to_tensor)
    tr_transforms = Compose(tr_transforms)

    cifar_dataset = CifarDataset(dataset_dir,
                                 train=True,
                                 transform=tr_transforms)

    dl = DataLoaderFromDataset(cifar_dataset, batch_size, num_workers, 1)
    mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, pin_memory)

    batches = 0
    for batch in mt:
        batches += 1
    assert len(batch['data'].shape) == 4

    # assumption: the assertion was truncated in the original; it plausibly compared
    # the batch count against one full pass over the dataset
    assert batches == len(cifar_dataset) // batch_size
Example #26
    def get_batches(self, batch_size=1):

        num_processes = 1  # do not use more than 1 if you want to keep the original slice order (threads return in random order)

        if self.HP.TYPE == "combined":
            # Load from Npy file for Fusion
            data = self.subject
            seg = []
            nr_of_samples = len([self.subject]) * self.HP.INPUT_DIM[0]
            num_batches = int(nr_of_samples / batch_size / num_processes)
            batch_gen = SlicesBatchGeneratorNpyImg_fusion(
                (data, seg),
                BATCH_SIZE=batch_size,
                num_batches=num_batches,
                seed=None)
        else:
            # Load Features
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                data_img = nib.load(
                    join(self.data_dir, "270g_125mm_peaks.nii.gz"))
            else:
                data_img = nib.load(
                    join(self.data_dir, self.HP.FEATURES_FILENAME + ".nii.gz"))
            data = data_img.get_data()
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, self.HP.DATASET, self.HP.RESOLUTION)
            # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")  #If we want to test HCP_32g on HighRes net

            #Load Segmentation
            if self.use_gt_mask:
                seg = nib.load(
                    join(self.data_dir,
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()

                if self.HP.LABELS_FILENAME not in [
                        "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                        "bundle_peaks_808080", "bundle_masks_20_808080",
                        "bundle_masks_72_808080", "bundle_peaks_Part1_808080",
                        "bundle_peaks_Part2_808080",
                        "bundle_peaks_Part3_808080",
                        "bundle_peaks_Part4_808080"
                ]:
                    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                        # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, "HCP", self.HP.RESOLUTION)
                    else:
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, self.HP.DATASET, self.HP.RESOLUTION)
            else:
                # Use a dummy mask in case we only want to predict on some data (where we do not have ground truth)
                seg = np.zeros(
                    (self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
                     self.HP.INPUT_DIM[0],
                     self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

            batch_gen = SlicesBatchGenerator((data, seg),
                                             batch_size=batch_size)

        batch_gen.HP = self.HP
        tfs = []  # transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

        if self.HP.TEST_TIME_DAUG:
            center_dist_from_border = int(
                self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(
                SpatialTransform(
                    self.HP.INPUT_DIM,
                    patch_center_dist_from_border=center_dist_from_border,
                    do_elastic_deform=True,
                    alpha=(90., 120.),
                    sigma=(9., 11.),
                    do_rotation=True,
                    angle_x=(-0.8, 0.8),
                    angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True,
                    scale=(0.9, 1.5),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True))
            # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
            # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
            tfs.append(
                ContrastAugmentationTransform(contrast_range=(0.7, 1.3),
                                              preserve_range=True,
                                              per_channel=False))
            tfs.append(
                BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                                  per_channel=False))

        tfs.append(ReorderSegTransform())
        batch_gen = MultiThreadedAugmenter(
            batch_gen,
            Compose(tfs),
            num_processes=num_processes,
            num_cached_per_queue=2,
            seeds=None
        )  # Only use num_processes=1, otherwise the global_idx of SlicesBatchGenerator does not work
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
Example #27
    def get_batches(self,
                    batch_size=128,
                    type=None,
                    subjects=None,
                    num_batches=None):
        data = subjects
        seg = []

        #6 -> >30GB RAM
        if self.HP.DATA_AUGMENTATION:
            num_processes = 8  # 6 is a bit faster than 16
        else:
            num_processes = 6

        nr_of_samples = len(subjects) * self.HP.INPUT_DIM[0]
        if num_batches is None:
            num_batches_multithr = int(
                nr_of_samples / batch_size /
                num_processes)  #number of batches for exactly one epoch
        else:
            num_batches_multithr = int(num_batches / num_processes)

        if self.HP.TYPE == "combined":
            # Simple with .npy  -> just a little bit faster than Nifti (<10%) and f1 not better => use Nifti
            # batch_gen = SlicesBatchGeneratorRandomNpyImg_fusion((data, seg), batch_size=batch_size)
            batch_gen = SlicesBatchGeneratorRandomNpyImg_fusion(
                (data, seg), batch_size=batch_size)
        else:
            batch_gen = SlicesBatchGeneratorRandomNiftiImg(
                (data, seg), batch_size=batch_size)
            # batch_gen = SlicesBatchGeneratorRandomNiftiImg_5slices((data, seg), batch_size=batch_size)

        batch_gen.HP = self.HP
        tfs = []  #transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(
                ZeroMeanUnitVarianceTransform(
                    per_channel=self.HP.NORMALIZE_PER_CHANNEL))

        if self.HP.DATASET == "Schizo" and self.HP.RESOLUTION == "2mm":
            tfs.append(PadToMultipleTransform(16))

        if self.HP.DATA_AUGMENTATION:
            if type == "train":
                # scale: inverted: 0.5 -> bigger; 2 -> smaller
                # patch_center_dist_from_border: if 144/2=72 -> always exactly centered; otherwise a bit off center (brain can get off image and will be cut then)

                if self.HP.DAUG_SCALE:
                    center_dist_from_border = int(
                        self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                    tfs.append(
                        SpatialTransform(
                            self.HP.INPUT_DIM,
                            patch_center_dist_from_border=
                            center_dist_from_border,
                            do_elastic_deform=self.HP.DAUG_ELASTIC_DEFORM,
                            alpha=(90., 120.),
                            sigma=(9., 11.),
                            do_rotation=self.HP.DAUG_ROTATE,
                            angle_x=(-0.8, 0.8),
                            angle_y=(-0.8, 0.8),
                            angle_z=(-0.8, 0.8),
                            do_scale=True,
                            scale=(0.9, 1.5),
                            border_mode_data='constant',
                            border_cval_data=0,
                            order_data=3,
                            border_mode_seg='constant',
                            border_cval_seg=0,
                            order_seg=0,
                            random_crop=True))

                if self.HP.DAUG_RESAMPLE:
                    tfs.append(ResampleTransform(zoom_range=(0.5, 1)))

                if self.HP.DAUG_NOISE:
                    tfs.append(GaussianNoiseTransform(noise_variance=(0,
                                                                      0.05)))

                if self.HP.DAUG_MIRROR:
                    tfs.append(MirrorTransform())

                if self.HP.DAUG_FLIP_PEAKS:
                    tfs.append(FlipVectorAxisTransform())

        #num_cached_per_queue 1 or 2 does not really make a difference
        batch_gen = MultiThreadedAugmenter(batch_gen,
                                           Compose(tfs),
                                           num_processes=num_processes,
                                           num_cached_per_queue=1,
                                           seeds=None)
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
Example #28
def get_transforms(mode="train",
                   n_channels=1,
                   target_size=128,
                   add_resize=False,
                   add_noise=False,
                   mask_type="",
                   batch_size=16,
                   rotate=True,
                   elastic_deform=True,
                   rnd_crop=False,
                   color_augment=True):
    transform_list = []
    noise_list = []

    if mode == "train":

        transform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),

            # RandomCropTransform(crop_size=(target_size + 5, target_size + 5)),
            MirrorTransform(axes=(2, )),
            ReshapeTransform(new_shape=(1, -1, "h", "w")),
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=rnd_crop,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=elastic_deform,
                             alpha=(0., 100.),
                             sigma=(10., 13.),
                             do_rotation=rotate,
                             angle_x=(-0.1, 0.1),
                             angle_y=(0, 1e-8),
                             angle_z=(0, 1e-8),
                             scale=(0.9, 1.2),
                             border_mode_data="nearest",
                             border_mode_seg="nearest"),
            ReshapeTransform(new_shape=(batch_size, -1, "h", "w"))
        ]
        if color_augment:
            transform_list += [  # BrightnessTransform(mu=0, sigma=0.2),
                BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1))
            ]

        transform_list += [
            GaussianNoiseTransform(noise_variance=(0., 0.05)),
            ClipValueRange(min=-1.5, max=1.5),
        ]

        if mask_type == "gaussian":
            noise_list += [GaussianNoiseTransform(noise_variance=(0., 0.2))]

    elif mode == "val":
        tranform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),
            CenterCropTransform(crop_size=(target_size, target_size)),
            ClipValueRange(min=-1.5, max=1.5),
            # BrightnessTransform(mu=0, sigma=0.2),
            # BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1)),
            CopyTransform({"data": "data_clean"}, copy=True)
        ]

    if add_noise:
        transform_list = transform_list + noise_list

    transform_list.append(NumpyToTensor())

    return Compose(transform_list)
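A minimal usage sketch (an assumption, not part of the original source): it presumes the custom transforms above (FillupPadTransform, ClipValueRange, CopyTransform, ...) follow the usual batchgenerators dict-in/dict-out convention, and that the returned Compose is called with the batch as keyword arguments; shapes are illustrative only.

# Hypothetical call site for get_transforms
import numpy as np

val_transforms = get_transforms(mode="val", n_channels=1, target_size=128)
batch = {"data": np.random.randn(16, 1, 140, 140).astype(np.float32)}
out = val_transforms(**batch)
# out["data"] is a torch tensor after NumpyToTensor; out["data_clean"] should
# hold the copy created by CopyTransform (assumption based on its arguments).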
Example No. 29
# In order to use Neptune logging, the API token must be exported first:
# export NEPTUNE_API_TOKEN='...'
logging.getLogger().setLevel('INFO')
source_files = [__file__]
if hparams.config:
    source_files.append(hparams.config)
neptune_logger = NeptuneLogger(project_name=hparams.neptune_project,
                               params=vars(hparams),
                               experiment_name=hparams.experiment_name,
                               tags=[hparams.experiment_name],
                               upload_source_files=source_files)
tb_logger = loggers.TensorBoardLogger(hparams.log_dir)

transform = Compose([
    BrightnessTransform(mu=0.0, sigma=0.3, data_key='data'),
    GammaTransform(gamma_range=(0.7, 1.3), data_key='data'),
    ContrastAugmentationTransform(contrast_range=(0.3, 1.7), data_key='data')
])
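# A minimal application sketch (assumption, not part of the original script):
# a batchgenerators Compose is called with the batch as keyword arguments, e.g.
#   augmented = transform(data=batch_array)   # batch_array: (b, c, x, y)
# where 'data' matches the data_key passed to the three transforms above.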

with open(hparams.train_set, 'r') as keyfile:
    train_keys = [line.strip() for line in keyfile]
print(train_keys)

with open(hparams.val_set, 'r') as keyfile:
    val_keys = [line.strip() for line in keyfile]
print(val_keys)

train_ds = MedDataset(hparams.data_path,
                      train_keys,
                      hparams.patches_per_subject,
                      hparams.patch_size,
Example No. 30
def get_moreDA_augmentation(dataloader_train,
                            dataloader_val,
                            patch_size,
                            params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None,
                            seeds_val=None,
                            order_seg=1,
                            order_data=3,
                            deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None,
                            pin_memory=True,
                            regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations in dummy-2D mode with 3D data: the color
    # channel is overloaded with the first spatial dimension
    if params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
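        # Convert3DTo2DTransform folds the first spatial axis into the channel
        # axis ((b, c, x, y, z) -> (b, c*x, y, z)) so that the SpatialTransform
        # below operates slice-wise in 2D; Convert2DTo3DTransform restores it.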
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size_spatial,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         p_rot_per_axis=params.get("rotation_p_per_axis"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                          p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size"),
                        p_per_label=params.get(
                            "cascade_random_binary_transform_p_per_label")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")
                    ))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError(
                'NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_train = NonDetMultiThreadedAugmenter(
            dataloader_train,
            tr_transforms,
            params.get('num_threads'),
            params.get("num_cached_per_thread"),
            seeds=seeds_train,
            pin_memory=pin_memory)
    else:
        batchgenerator_train = MultiThreadedAugmenter(
            dataloader_train,
            tr_transforms,
            params.get('num_threads'),
            params.get("num_cached_per_thread"),
            seeds=seeds_train,
            pin_memory=pin_memory)
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError(
                'NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_val = NonDetMultiThreadedAugmenter(
            dataloader_val,
            val_transforms,
            max(params.get('num_threads') // 2, 1),
            params.get("num_cached_per_thread"),
            seeds=seeds_val,
            pin_memory=pin_memory)
    else:
        batchgenerator_val = MultiThreadedAugmenter(
            dataloader_val,
            val_transforms,
            max(params.get('num_threads') // 2, 1),
            params.get("num_cached_per_thread"),
            seeds=seeds_val,
            pin_memory=pin_memory)
    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)

    return batchgenerator_train, batchgenerator_val
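A minimal wiring sketch (an assumption following common nnU-Net usage, not part of the original source); dl_tr and dl_val stand in for batchgenerators-style dataloaders and the patch size is illustrative only.

# Hypothetical call site for get_moreDA_augmentation
tr_gen, val_gen = get_moreDA_augmentation(
    dl_tr, dl_val,                        # assumed DataLoader3D-style loaders
    patch_size=(128, 128, 128),
    params=default_3D_augmentation_params,
    deep_supervision_scales=None,
    pin_memory=True)
batch = next(tr_gen)   # {"data": float Tensor, "target": float Tensor}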