def transform(self):
        """Build and return the training transform pipeline.

        Reads the module-level ``hp`` hyper-parameter object:
        ``hp.mode`` ('3d' or '2d'), ``hp.aug`` (enable augmentation) and
        ``hp.crop_or_pad_size``.

        Returns:
            A torchio ``Compose`` transform.

        Raises:
            Exception: if ``hp.mode`` is neither '3d' nor '2d'.
        """
        # Validate the mode up front; the original '3d' and '2d' branches
        # were byte-for-byte identical, so a single pipeline covers both.
        if hp.mode not in ('3d', '2d'):
            raise Exception('no such kind of mode!')

        if hp.aug:
            training_transform = Compose([
                # ToCanonical(),
                # A single int is accepted by CropOrPad and expanded to a
                # cubic target shape.
                CropOrPad((hp.crop_or_pad_size), padding_mode='reflect'),
                # RandomMotion(),
                RandomBiasField(),
                ZNormalization(),
                RandomNoise(),
                RandomFlip(axes=(0, )),
                OneOf({
                    RandomAffine(): 0.8,
                    RandomElasticDeformation(): 0.2,
                }),
            ])
        else:
            training_transform = Compose([
                CropOrPad((hp.crop_or_pad_size, hp.crop_or_pad_size,
                           hp.crop_or_pad_size),
                          padding_mode='reflect'),
                ZNormalization(),
            ])

        return training_transform
Beispiel #2
0
def mri_artifact(p=1):
    """Return a transform that applies exactly one random MRI artifact.

    With probability ``p`` one of motion, ghosting or spike artifacts is
    applied (weights roughly uniform); otherwise the input passes through.
    """
    artifact_weights = {
        RandomMotion(): 0.34,
        RandomGhosting(): 0.33,
        RandomSpike(): 0.33,
    }
    return OneOf(artifact_weights, p=p)
Beispiel #3
0
 def test_reproducibility_oneof(self):
     """A pipeline replayed from a subject's history must reproduce it."""
     first, second = self.get_subjects()
     pipeline = Compose([
         OneOf([RandomNoise(p=1.0),
                RandomSpike(num_spikes=3, p=1.0)]),
         RandomNoise(p=.5),
     ])
     result = pipeline(first)
     # Rebuild the transform list and seeds from the recorded history,
     # then apply them to an identical subject.
     replay_transforms, replay_seeds = compose_from_history(
         history=result.history)
     replayed = self.apply_transforms(second,
                                      trsfm_list=replay_transforms,
                                      seeds_list=replay_seeds)
     self.assertTensorEqual(result.img.data, replayed.img.data)
Beispiel #4
0
def get_brats(
        data_root='/scratch/weina/dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/',
        fold=1,
        seed=None,
        **kwargs):
    """ data iter for brats

    Args:
        data_root: root directory of the BraTS 2019 training data.
        fold: cross-validation fold used to select the train/val csv files.
        seed: random seed; when ``None`` it defaults to the distributed rank
            (0 if torch.distributed is not initialized).

    Returns:
        ``(train, val)`` pair of ``BratsIter`` iterators for the fold.
    """
    # The original signature evaluated ``torch.distributed.get_rank()`` in the
    # default-argument expression, i.e. exactly once at import time — before
    # ``init_process_group`` runs, so every rank silently got seed 0.
    # Resolving the default at call time fixes that; explicit seeds behave
    # exactly as before.
    if seed is None:
        seed = (torch.distributed.get_rank()
                if torch.distributed.is_initialized() else 0)
    logging.debug("BratsIter:: fold = {}, seed = {}".format(fold, seed))
    # args for transforms
    d_size, h_size, w_size = 155, 240, 240
    input_size = [7, 223, 223]
    # Per-axis spacing that resamples the native volume to input_size.
    spacing = (d_size / input_size[0], h_size / input_size[1],
               w_size / input_size[2])
    Mean, Std, Max = read_brats_mean(fold, data_root)
    normalize = transforms.Normalize(mean=Mean, std=Std)
    training_transform = Compose([
        RandomBiasField(),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
        normalize
    ])
    # Validation only resamples and normalizes — no random augmentation.
    val_transform = Compose([Resample(spacing), normalize])

    train = BratsIter(csv_file=os.path.join(data_root, 'IDH_label',
                                            'train_fold_{}.csv'.format(fold)),
                      brats_path=os.path.join(data_root, 'all'),
                      brats_transform=training_transform,
                      shuffle=True)

    val = BratsIter(csv_file=os.path.join(data_root, 'IDH_label',
                                          'val_fold_{}.csv'.format(fold)),
                    brats_path=os.path.join(data_root, 'all'),
                    brats_transform=val_transform,
                    shuffle=False)
    return train, val
Beispiel #5
0
def random_augment(x):
    '''Randomly augment input data.

    Returns: Randomly augmented input
    '''
    # Candidate augmentations, all equally weighted.
    candidates = {
        RandomFlip(): 1,
        RandomElasticDeformation(): 1,
        RandomAffine(): 1,
        RandomNoise(): 1,
        RandomBlur(): 1
    }
    # With probability 0.95 apply exactly one of the candidates.
    picked = OneOf(candidates, p=0.95)
    return augment(x, picked)
Beispiel #6
0
def training_network(landmarks, dataset, subjects):
    """Split *subjects* 90/10 into training/validation SubjectsDatasets.

    Training data gets the full augmentation pipeline; validation only the
    deterministic preprocessing (canonical orientation, resample, crop/pad,
    histogram standardization, z-normalization).
    """
    shared_head = [
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
    ]

    training_transform = Compose(shared_head + [
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    validation_transform = Compose(shared_head + [
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    # 90% of the subjects go to training, the remainder to validation.
    num_training_subjects = int(0.9 * len(dataset))

    training_set = tio.SubjectsDataset(subjects[:num_training_subjects],
                                       transform=training_transform)
    validation_set = tio.SubjectsDataset(subjects[num_training_subjects:],
                                         transform=validation_transform)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set
Beispiel #7
0
 def test_wrong_input_type(self):
     # OneOf must reject an argument that is neither a dict nor a sequence
     # of transforms.
     self.assertRaises(ValueError, OneOf, 1)
Beispiel #8
0
 def test_not_transform(self):
     # Keys must be transform *instances*; passing the classes themselves
     # is rejected.
     self.assertRaises(ValueError, OneOf,
                       {RandomAffine: 1, RandomElasticDeformation: 2})
Beispiel #9
0
 def test_zero_probabilities(self):
     # All-zero weights cannot be normalized into probabilities.
     self.assertRaises(ValueError, OneOf,
                       {RandomAffine(): 0, RandomElasticDeformation(): 0})
Beispiel #10
0
 def test_negative_probabilities(self):
     # Negative weights are invalid.
     self.assertRaises(ValueError, OneOf,
                       {RandomAffine(): -1, RandomElasticDeformation(): 1})
Beispiel #11
0
    input_size = [7, 223, 223]
    spacing = (d_size / input_size[0], h_size / input_size[1],
               w_size / input_size[2])
    training_transform = Compose([
        # RescaleIntensity((0, 1)),  # so that there are no negative values for RandomMotion
        # RandomMotion(),
        # HistogramStandardization({MRI: landmarks}),
        RandomBiasField(),
        # ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        # CropOrPad((48, 60, 48)),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    fold = 1
    data_root = '../../dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/'

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    logging.getLogger().setLevel(logging.DEBUG)

    logging.info("Testing BratsIter without transformer [not torch wapper]")

    # has memory (no metter random or not, it will trival all none overlapped clips)
    train_dataset = BratsIter(csv_file=os.path.join(
Beispiel #12
0
 def test_one_of(self):
     """Smoke test: a weighted OneOf runs cleanly on a sample subject."""
     weighted = {
         RandomAffine(): 0.2,
         RandomElasticDeformation(max_displacement=0.5): 0.8,
     }
     OneOf(weighted)(self.sample_subject)
Beispiel #13
0
    label=torchio.Image(tensor=torch.from_numpy(train_seg),
                        label=torchio.LABEL),
)
valid_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(valid_data),
                       label=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(valid_seg),
                        label=torchio.LABEL),
)
# Define the transforms for the set of training patches
training_transform = Compose([
    RandomNoise(p=0.2),
    RandomFlip(axes=(0, 1, 2)),
    RandomBlur(p=0.2),
    OneOf({
        RandomAffine(): 0.8,
        RandomElasticDeformation(): 0.2,
    }, p=0.5),  # Changed from p=0.75 24/6/20
])
# Create the datasets
training_dataset = torchio.ImagesDataset(
    [train_subject], transform=training_transform)

validation_dataset = torchio.ImagesDataset(
    [valid_subject])
# Define the queue of sampled patches for training and validation
sampler = torchio.data.UniformSampler(PATCH_SIZE)
patches_training_set = torchio.Queue(
    subjects_dataset=training_dataset,
    max_length=MAX_QUEUE_LENGTH,
    samples_per_volume=TRAIN_PATCHES,
    sampler=sampler,
Beispiel #14
0
def mri_artifact(parameters):
    """Return a transform applying ghosting or spiking with equal odds.

    ``parameters["probability"]`` is the chance that any artifact is
    applied at all.
    """
    candidates = {RandomGhosting(): 0.5, RandomSpike(): 0.5}
    return OneOf(candidates, p=parameters["probability"])
Beispiel #15
0
    def __init__(
        self,
        generator,
        zoom_range=None,
        filter_range=None,
        flip=None,
        transpose=False,
        noise=None,
        normalization=False,
        minmax=False,
        affine=None,
        window_width=None,
        window_level=None,
        window_vmax=1.0,
        window_vmin=0.0,
        intensity_shift=0.0,
        rotate=False,
        rotate3d=False,
        **kwargs,
    ):
        """Wrap *generator* with a configurable chain of augmentations.

        Every enabled option appends a ``data -> data`` callable to
        ``self.methods``.  ``data`` is a dict with (at least) an ``'image'``
        entry; the affine and rotate paths additionally require a
        ``'label'`` entry.

        Args:
            generator: upstream data generator; its ``.shapes`` attribute is
                re-exposed as ``self.shapes``.
            zoom_range: factor (or range) for a random zoom of every array.
            filter_range: unused — the Gaussian-filter path is disabled.
            flip: per-axis flip probabilities, e.g. ``(0.5, 0.5, 0.5)``.
            transpose: randomly swap the first two axes with p=0.5.
            noise: stddev of additive Gaussian noise on the image.
            normalization: z-score normalize the image.
            minmax: percentile-clipped min/max scaling of the foreground.
            affine: ``'strong'`` or any other non-None value; applies a
                random torchio affine/elastic deformation with p=0.75.
            window_width: intensity window width; tuple/list draws a random
                factor per sample (defaults to 100 when only level is set).
            window_level: intensity window level (defaults to 50).
            window_vmax: upper bound of the windowed output range.
            window_vmin: lower bound of the windowed output range.
            intensity_shift: max magnitude of a uniform additive shift.
            rotate: random 90-degree rotations about the z-axis.
            rotate3d: random 90-degree rotations about a random axis
                (requires an isotropic cube volume).
        """
        super().__init__(**kwargs)
        self.generator = generator
        self.shapes = generator.shapes
        # augmenting methods, applied in registration order
        self.methods = []

        if normalization:

            def _norm(data):
                # Z-score normalization of the image.
                mean = np.mean(data['image'])
                std = np.std(data['image'])
                std = 1.0 if std == 0. else std  # guard against constant images
                data['image'] = (data['image'] - mean) / std
                return data

            self.methods.append(_norm)

        if minmax:

            def _minmax(data):
                # FIX: the original had a stray trailing comma
                # (``lower_percentile = 0.2,``) that turned the percentile
                # into a 1-tuple; it only worked by accidental broadcasting.
                lower_percentile = 0.2
                upper_percentile = 99.8

                # Foreground = every voxel differing from the corner voxel,
                # which is presumably background — TODO confirm.
                foreground = data['image'] != data['image'][
                    (0, ) * len(data['image'].shape)]
                min_val = np.percentile(data['image'][foreground].ravel(),
                                        lower_percentile)
                max_val = np.percentile(data['image'][foreground].ravel(),
                                        upper_percentile)
                # Clip to the percentile range, then scale to [0, 1].
                data['image'][data['image'] > max_val] = max_val
                data['image'][data['image'] < min_val] = min_val
                data['image'] = (data['image'] - min_val) / (max_val - min_val)
                data['image'][~foreground] = 0

                return data

            self.methods.append(_minmax)

        # affine
        if affine is not None:
            # Imported lazily so torchio is only required when enabled.
            import torchio
            from torchio.transforms import (
                RandomAffine,
                RandomElasticDeformation,
                OneOf,
            )

            if affine == 'strong':
                transform = OneOf(
                    {
                        RandomAffine(translation=10,
                                     degrees=10,
                                     scales=(0.9, 1.1),
                                     default_pad_value='otsu',
                                     image_interpolation='bspline'):
                        0.5,
                        RandomElasticDeformation():
                        0.5
                    },
                    p=0.75,
                )
            else:
                transform = OneOf(
                    {
                        RandomAffine(translation=10): 0.5,
                        RandomElasticDeformation(): 0.5
                    },
                    p=0.75,
                )

            def _affine(data):
                # torchio works on torch tensors.
                for key in data:
                    data[key] = torch.Tensor(data[key])

                subjs = {
                    'label':
                    torchio.Image(tensor=data['label'], type=torchio.LABEL)
                }
                shape = data['image'].shape

                # We need to separate out the case of a 4D (multi-channel)
                # image: each channel becomes its own intensity image so the
                # same spatial transform is applied to all of them.
                if len(shape) == 4:
                    n_channels = shape[-1]
                    for i in range(n_channels):
                        subjs.update({
                            f'ch{i}':
                            torchio.Image(tensor=data['image'][..., i],
                                          type=torchio.INTENSITY)
                        })

                else:
                    assert len(shape) == 3
                    subjs.update({
                        'image':
                        torchio.Image(tensor=data['image'],
                                      type=torchio.INTENSITY)
                    })

                transformed = transform(torchio.Subject(**subjs))

                if 'image' in subjs.keys():
                    data['image'] = transformed.image.numpy()

                else:
                    # if image contains multiple channels,
                    # then aggregate the transformed results into one
                    data['image'] = np.stack(tuple(
                        getattr(transformed, ch).numpy()
                        for ch in subjs.keys() if 'ch' in ch),
                                             axis=-1)
                data['label'] = transformed.label.numpy()

                for key in data:
                    data[key] = data[key].squeeze()

                return data

            self.methods.append(_affine)

        # convert image to float (always registered)
        def _to_float(data):
            # FIX: ``np.float`` was removed in NumPy 1.24; the builtin
            # ``float`` (an alias for float64 here) is the replacement.
            data['image'] = data['image'].astype(float)
            return data

        self.methods.append(_to_float)

        # adjust contrast/window
        if window_width or window_level:
            from ..preprocessings import window
            # Fall back to defaults when only one of the two is given.
            window_width = window_width if window_width else 100
            window_level = window_level if window_level else 50

            def _window(data):
                # A tuple/list means: draw a random value from that range
                # on every call.
                if isinstance(window_width, (tuple, list)):
                    _window_width = random_factor(window_width)
                else:
                    _window_width = window_width
                if isinstance(window_level, (tuple, list)):
                    _window_level = random_factor(window_level)
                else:
                    _window_level = window_level
                data['image'] = window(
                    data['image'],
                    width=_window_width,
                    level=_window_level,
                    vmin=window_vmin,
                    vmax=window_vmax,
                )
                return data

            self.methods.append(_window)

        if noise is not None:
            assert isinstance(noise, float)
            assert noise > 0.

            def _noise(data):
                # Additive zero-mean Gaussian noise on the image only.
                data['image'] += np.random.normal(loc=0.0,
                                                  scale=noise,
                                                  size=data['image'].shape)
                return data

            self.methods.append(_noise)

        if zoom_range is not None:
            from ..preprocessings import zoom

            def _zoom(data):
                # Same random zoom factor for every entry (image and label).
                zoom_factor = random_factor(zoom_range)
                for key in data:
                    data[key] = zoom(data[key], zoom_factor)
                return data

            self.methods.append(_zoom)

        # NOTE: the deprecated Gaussian-filter augmentation (filter_range)
        # was removed as dead code; the parameter is kept for backward
        # compatibility of the signature.

        if flip is not None:
            assert isinstance(flip, (list, tuple))
            for f in flip:
                assert f >= 0 and f <= 1

            def flip_img(img, flip_x, flip_y, flip_z):
                # Reverse each requested axis; trailing axes untouched.
                if flip_x:
                    img = img[::-1, :, :, ...]

                if flip_y:
                    img = img[:, ::-1, :, ...]

                if flip_z:
                    img = img[:, :, ::-1, ...]

                return img

            def _flip(data):
                # Decide once per sample so image and label stay aligned.
                to_flip_x = random.random() < flip[0]
                to_flip_y = random.random() < flip[1]
                to_flip_z = random.random() < flip[2]
                for key in data:
                    data[key] = flip_img(data[key],
                                         flip_x=to_flip_x,
                                         flip_y=to_flip_y,
                                         flip_z=to_flip_z)
                return data

            self.methods.append(_flip)

        if transpose:

            def _transpose(data):
                # Swap the first two axes of every entry with p=0.5.
                if random.random() > 0.5:
                    for key in data:
                        data[key] = np.moveaxis(data[key], 0, 1)
                return data

            self.methods.append(_transpose)

        if intensity_shift > 0.:

            def _shift(data):
                # Uniform additive shift in [-intensity_shift, intensity_shift).
                data['image'] += (np.random.rand() * 2.0 -
                                  1.0) * intensity_shift
                return data

            self.methods.append(_shift)

        if rotate:

            def _rotate(data):
                # randomly rotate 0~3 times about the z-axis by 90 degrees
                times = np.random.randint(0, 4)
                if times > 0:
                    for key in ['image', 'label']:
                        data[key] = np.rot90(data[key], times, (0, 1))
                return data

            self.methods.append(_rotate)

        if rotate3d:

            def _rotate3d(data):
                # check isotropic
                assert data['image'].shape[0] == data['image'].shape[1]
                assert data['image'].shape[0] == data['image'].shape[2]

                # randomly select the plane spanning by the axes: (0, 1), (1, 2), (0, 2)
                the_axes = [0, 1, 2]
                the_axes.remove(np.random.randint(0, 3))
                the_axes = tuple(the_axes)

                # randomly rotate 0~3 times about the axis by 90 degrees
                times = np.random.randint(0, 4)
                if times > 0:
                    for key in ['image', 'label']:
                        data[key] = np.rot90(data[key], times, the_axes)
                return data

            self.methods.append(_rotate3d)