Example #1
0
 def test_dtype(self):
     """ZNormalization must accept an integer (uint8) tensor.

     Regression test for https://github.com/fepegar/torchio/issues/407. It runs
     both with the ``tio.ZNormalization.mean`` mask and with the default settings.
     """
     uint8_tensor = (100 * torch.rand(1, 2, 3, 4)).byte()
     for znorm_kwargs in ({'masking_method': tio.ZNormalization.mean}, {}):
         tio.ZNormalization(**znorm_kwargs)(uint8_tensor)
Example #2
0
 def test_transforms(self):
     """Push ``self.sample`` through one long Compose of many torchio transforms."""
     landmarks_dict = {
         image_name: np.linspace(0, 100, 13) for image_name in ('t1', 't2')
     }
     elastic = torchio.RandomElasticDeformation(max_displacement=1)
     pipeline = torchio.Compose((
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         # The same elastic instance is reused inside OneOf on purpose.
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1
         }),
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     ))
     pipeline(self.sample)
Example #3
0
def byol_aug(filename):
    """Load the image at ``filename`` and return one randomly augmented view of it.

    BYOL minimizes the distance between the representation of a sample and that
    of an augmented view of the same sample. The view produced here is built by:

    * cropping/padding to 180 x 220 x 170 voxels;
    * z-normalization using the ``tio.ZNormalization.mean`` foreground mask;
    * random blur (25% of calls) and Gaussian noise (25%);
    * an affine (weight 0.8) or elastic (weight 0.2) deformation, 80% of calls;
    * a bias field (30%);
    * a motion (weight 1), spike (2) or ghosting (2) artifact, 50% of calls.

    Args:
        filename: Path of an image readable by ``tio.ScalarImage``.

    Returns:
        The transformed image. Note that this is a single image, not a dataset.
    """
    image = tio.ScalarImage(filename)
    get_foreground = tio.ZNormalization.mean
    training_transform = tio.Compose([
        tio.CropOrPad((180, 220, 170)),
        tio.ZNormalization(
            masking_method=get_foreground),  # zero mean, unit variance of foreground
        tio.RandomBlur(p=0.25),  # blur 25% of times
        tio.RandomNoise(p=0.25),  # Gaussian noise 25% of times
        tio.OneOf({  # either
            tio.RandomAffine(): 0.8,  # random affine
            tio.RandomElasticDeformation(): 0.2,  # or random elastic deformation
        }, p=0.8),  # applied to 80% of images
        tio.RandomBiasField(p=0.3),  # magnetic field inhomogeneity 30% of times
        tio.OneOf({  # either
            tio.RandomMotion(): 1,  # random motion artifact
            tio.RandomSpike(): 2,  # or spikes
            tio.RandomGhosting(): 2,  # or ghosts
        }, p=0.5),  # applied to 50% of images
    ])

    return training_transform(image)
Example #4
0
    def test_bounds_mask(self):
        """Check anatomical-label masks and bounds masks on a 2x2x2 one-channel tensor."""
        transform = tio.ZNormalization()
        with self.assertRaises(ValueError):
            transform.get_mask_from_anatomical_label('test', 0)
        tensor = torch.rand((1, 2, 2, 2))

        # (label, tensor dimension, index along that dimension that must be fully masked)
        expectations = (
            ('Left', 1, 0),
            ('Right', 1, 1),
            ('Posterior', 2, 0),
            ('Anterior', 2, 1),
            ('Inferior', 3, 0),
            ('Superior', 3, 1),
        )
        for label, dim, inside in expectations:
            mask = transform.get_mask_from_anatomical_label(label, tensor)
            leading = (slice(None),) * dim
            assert mask[leading + (inside,)].sum() == 4
            assert mask[leading + (1 - inside,)].sum() == 0

        mask = transform.get_mask_from_bounds(3 * (0, 1), tensor)
        assert mask[0, 0, 0, 0] == 1
        assert mask.sum() == 1
def get_train_transform(landmarks_path, resection_params=None):
    """Build the training pipeline: optional preprocessing, then augmentation.

    Args:
        landmarks_path: Histogram-standardization landmarks for the ``'image'``
            key, or ``None`` to skip histogram standardization.
        resection_params: If not ``None``, ``get_simulation_transform`` is
            called with it and the result is applied first.

    Returns:
        A ``tio.Compose`` with the optional steps followed by resolution and
        artifact augmentations, z-normalization, noise, a spatial augmentation
        and ``get_tight_crop()``.
    """
    spatial_augmentation = tio.Compose((
        tio.OneOf({
            tio.RandomAffine(): 0.9,
            tio.RandomElasticDeformation(): 0.1,
        }),
        tio.RandomFlip(),
    ))
    resolution_augmentation = tio.OneOf(
        (tio.RandomAnisotropy(), tio.RandomBlur()),
        p=0.75,
    )

    optional_steps = []
    if resection_params is not None:
        optional_steps.append(get_simulation_transform(resection_params))
    if landmarks_path is not None:
        optional_steps.append(
            tio.HistogramStandardization({'image': landmarks_path}))

    return tio.Compose(optional_steps + [
        # tio.RandomGamma(p=0.2),
        resolution_augmentation,
        tio.RandomGhosting(p=0.2),
        tio.RandomSpike(p=0.2),
        tio.RandomMotion(p=0.2),
        tio.RandomBiasField(p=0.5),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RandomNoise(p=0.75),  # always after ZNorm and after blur!
        spatial_augmentation,
        get_tight_crop(),
    ])
def get_test_transform(landmarks_path):
    """Build the deterministic evaluation pipeline.

    Histogram standardization of ``'image'`` is included only when
    ``landmarks_path`` is not ``None``. It is followed by z-normalization
    (``tio.ZNormalization.mean`` mask) and ``get_tight_crop()``.
    """
    if landmarks_path is None:
        histogram_steps = []
    else:
        histogram_steps = [tio.HistogramStandardization({'image': landmarks_path})]
    return tio.Compose(histogram_steps + [
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        get_tight_crop(),
    ])
Example #7
0
    def __getitem__(self, index):
        """Return the preprocessed image at ``index`` together with its label.

        The image path is read from column 1 of ``self.df`` (relative to
        ``self.root_dir``) and the diagnosis from column 2. ``'AD'`` maps to
        label 0.0 and anything else to 1.0.

        Every image is z-normalized and resampled with ``tio.Resample(target=2)``.
        When ``self.transform`` is truthy, one random transform from
        ``transforms_dict`` and one from ``transforms_dict2`` are also applied.
        The result is trilinearly resized to 256 x 256 x 166 voxels.

        Returns:
            ``(image, label)``. ``image`` is a tensor of shape
            ``(channels, 256, 256, 166)`` and ``label`` is a 0-d float tensor.
        """
        output_size = (256, 256, 166)
        img_path = os.path.join(self.root_dir, self.df.iloc[index, 1])
        subject = tio.Subject(img=Image(img_path, type=tio.INTENSITY))

        # Preprocessing shared by every phase.
        transformations = [
            tio.ZNormalization(),
            # NOTE(review): confirm the "affine" key really holds a separate
            # pre-affine matrix and is not the image's own affine.
            tio.Resample(target=2, pre_affine_name="affine"),
        ]
        if self.transform:
            # Training phase: add one random augmentation from each dict.
            transformations += [
                tio.OneOf(transforms_dict),
                tio.OneOf(transforms_dict2),
            ]

        transformed_image = Compose(transformations)(subject)

        # interpolate expects a 5D input: (batch, channels, W, H, D).
        batch = transformed_image.img.data.unsqueeze(dim=0)
        resized = torch.nn.functional.interpolate(
            input=batch,
            size=output_size,
            mode='trilinear',
            align_corners=False,  # the default; explicit to silence the warning
        )
        # Drop the batch dimension. Stays a torch tensor, with no NumPy round trip,
        # and is no longer limited to single-channel images.
        resampled_image = resized[0]

        diagnosis = self.df.iloc[index, 2]
        y_label = torch.tensor(0.0 if diagnosis == 'AD' else 1.0, dtype=torch.float)

        return resampled_image, y_label
Example #8
0
    def __getitem__(self, idx):
        """Load and preprocess the image at ``idx`` of ``self.list_images``.

        If ``load_reorient`` fails, a random image from ``self.get_random`` is
        used instead. The failing file is deleted when ``self.remove_corrupt`` is
        set and the file exists. All-zero images are replaced the same way.

        The image is then resampled, cropped/padded to 128^3 and z-normalized
        with the ``tio.ZNormalization.mean`` foreground mask.

        Returns:
            dict with ``'image'`` (tensor), ``'label'`` (``np.array(0)``),
            ``'spacing'`` (the spacing as an array; ``[2, 2, 2]`` after a load
            failure) and ``'fn'`` (the file name).
        """
        # NOTE(review): a collate_fn that filters out None elements was planned,
        # but this method currently never returns None.
        img_name = self.list_images[idx]

        try:
            _, image, spacing = load_reorient(img_name)
        except Exception:  # narrowed from a bare except so Ctrl-C still works
            # Best effort: substitute a random image rather than failing the batch.
            spacing = [2, 2, 2]
            print('error loading {0}'.format(img_name))
            if self.remove_corrupt and os.path.isfile(img_name):
                os.remove(img_name)
            image, _ = self.get_random(fa=True)

        if not np.any(image):
            image, _ = self.get_random(fa=True)

        # preprocessing
        original_spacing = np.array(spacing)
        target_shape = 128, 128, 128
        crop_pad = tio.CropOrPad(target_shape)
        standardize = tio.ZNormalization(masking_method=tio.ZNormalization.mean)

        if 'wb' not in self.class_names and 'abd-pel' not in self.class_names:
            downsample = tio.Resample(
                (2 / spacing[0], 2 / spacing[1], 2 / spacing[2]))
            try:
                image = standardize(crop_pad(downsample(image)))
            except Exception:  # narrowed from a bare except
                # Best effort: keep the unprocessed image and report the file.
                # NOTE(review): its shape may then differ from target_shape.
                print(img_name)
        else:
            # Per-axis ratio between the current spatial shape and the target.
            x, y, z = image.shape[1:] / np.asarray(target_shape)
            downsample = tio.Resample((x, y, z))
            image = standardize(crop_pad(downsample(image)))

        if image.shape[0] > 1:
            # Keep only the first channel.
            image = np.expand_dims(image[0, :, :, :], axis=0)
            print(img_name)

        return {
            'image': torch.from_numpy(image),
            'label': np.array(0),
            'spacing': original_spacing,
            'fn': img_name,
        }
Example #9
0
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    """Create the inference dataset for the single image at ``input_path``.

    Preprocessing runs in this order:
      1. reorient to canonical;
      2. histogram standardization with fixed landmarks;
      3. z-normalization with the ``tio.ZNormalization.mean`` mask;
      4. optional resampling;
      5. cropping so every dimension is a multiple of 8.

    Resampling happens when any header zoom differs from 1 by more than
    ``tolerance``, or when ``mni_transform_path`` is given. In the latter case the
    matrix read from that file is stored under ``TO_MNI``, used as the
    pre-affine, and the target is the Colin27 T1.

    Returns:
        The un-augmented ``tio.SubjectsDataset`` when ``tta_iterations`` yields
        no copies. Otherwise a ``ConcatDataset`` of it followed by
        ``tta_iterations`` copies that are additionally randomly flipped and
        affinely transformed (test-time augmentation).
    """
    image_kwargs = {}
    if mni_transform_path is not None:
        image_kwargs[TO_MNI] = tio.io.read_matrix(mni_transform_path)
    image = tio.ScalarImage(input_path, **image_kwargs)
    subject = tio.Subject({IMAGE_NAME: image})

    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    steps = [
        tio.ToCanonical(),
        tio.HistogramStandardization({IMAGE_NAME: landmarks}),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]

    spacing = np.array(nib.load(input_path).header.get_zooms())
    not_isotropic = np.any(np.abs(spacing - 1) > tolerance)
    if not_isotropic or mni_transform_path is not None:
        resample_kwargs = {'image_interpolation': interpolation}
        if mni_transform_path is not None:
            resample_kwargs.update(
                pre_affine_name=TO_MNI,
                target=tio.datasets.Colin27().t1.path,
            )
        steps.append(tio.Resample(**resample_kwargs))
    steps.append(tio.EnsureShapeMultiple(8, method='crop'))

    preprocess_transform = tio.Compose(steps)
    no_aug_dataset = tio.SubjectsDataset([subject],
                                         transform=preprocess_transform)

    tta_subjects = [subject] * tta_iterations
    if not tta_subjects:
        return no_aug_dataset

    tta_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    tta_dataset = tio.SubjectsDataset(tta_subjects, transform=tta_transform)
    return torch.utils.data.ConcatDataset((no_aug_dataset, tta_dataset))
Example #10
0
    def __init__(self, foldername, annotations_data):
        """Store the data location, build transforms and eagerly load every volume.

        Args:
            foldername: Directory that contains the image files.
            annotations_data: Object whose ``filelist`` attribute lists the file
                names, relative to ``foldername``, to load.
        """
        self.folder = foldername
        self.filedata = annotations_data
        self.fileinfo = np.array(self.filedata.filelist)
        self.standardizer = tio.ZNormalization()
        self.padder = tio.CropOrPad((144, 176, 144))

        # Each volume is read with nibabel, converted to float32 and given a
        # leading channel dimension. All volumes are kept in memory.
        self.images = [
            torch.from_numpy(
                nib.load(os.path.join(self.folder, filename)).get_fdata()
            ).float().unsqueeze(0)
            for filename in self.fileinfo
        ]
Example #11
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Compose (almost) every torchio transform for a 3D or a 2D test subject.

     ``channels`` are image names. Each gets dummy histogram landmarks, and the
     first one is the affine reference for ``CopyAffine``. When ``labels`` is
     truthy, a ``RandomLabelsToImage`` on ``'label'`` is appended.
     """
     if is_3d:
         max_disp = 1
         crop_or_pad_shape = (9, 21, 30)
         resize_shape = (10, 20, 30)
         spatial_axes = (0, 1, 2)
         swap_patch = (2, 3, 4)
         padding = (1, 2, 3, 0, 5, 6)
         cropping = (3, 2, 8, 0, 1, 4)
     else:
         max_disp = (1, 1, 0.01)
         crop_or_pad_shape = (21, 30, 1)
         resize_shape = (10, 20, 1)
         spatial_axes = (0, 1)
         swap_patch = (3, 4, 1)
         padding = (0, 0, 3, 0, 5, 6)
         cropping = (0, 0, 8, 0, 1, 4)
     landmarks = {name: np.linspace(0, 100, 13) for name in channels}
     elastic = tio.RandomElasticDeformation(max_displacement=max_disp)
     steps = [
         tio.CropOrPad(crop_or_pad_shape),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_shape),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=spatial_axes),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=spatial_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks),
         elastic,
         tio.RandomAffine(),
         tio.OneOf({
             tio.RandomAffine(): 3,
             elastic: 1,
         }),
         tio.RemapLabels(remapping={1: 2, 2: 1, 3: 20, 4: 25},
                         masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(padding, padding_mode=3),
         tio.Crop(cropping),
     ]
     if labels:
         steps.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(steps)
Example #12
0
def normalization(histogram_transform, dataset):
    """Normalize ``dataset[0]`` and plot the histogram of its ``mri`` intensities.

    ``histogram_transform`` is applied first. It is followed by z-normalization
    using the ``tio.ZNormalization.mean`` foreground mask.

    Args:
        histogram_transform: Transform applied before z-normalization,
            typically a ``tio.HistogramStandardization``.
        dataset: Indexable collection of subjects with an ``mri`` image.

    Returns:
        The normalized subject, so callers can reuse it.
    """
    znorm_transform = tio.ZNormalization(
        masking_method=tio.ZNormalization.mean)

    sample = dataset[0]
    transform = tio.Compose([histogram_transform, znorm_transform])
    znormed = transform(sample)

    fig, ax = plt.subplots(dpi=100)
    plot_histogram(ax, znormed.mri.data, label='Z-normed', alpha=1)
    ax.set_title('Intensity values of one sample after z-normalization')
    ax.set_xlabel('Intensity')
    ax.grid()
    # Display through pyplot: matplotlib Axes objects have no ``show`` method,
    # so the previous ``ax.show()`` raised AttributeError.
    plt.show()
    return znormed
Example #13
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Compose (almost) every torchio transform for a 3D or a 2D test subject.

     ``channels`` are image names, each given dummy histogram landmarks. When
     ``labels`` is truthy, a ``RandomLabelsToImage`` on ``'label'`` is appended.
     """
     if is_3d:
         max_disp = 1
         crop_or_pad_shape = (9, 21, 30)
         spatial_axes = (0, 1, 2)
         swap_patch = (2, 3, 4)
         padding = (1, 2, 3, 0, 5, 6)
         cropping = (3, 2, 8, 0, 1, 4)
     else:
         max_disp = (1, 1, 0.01)
         crop_or_pad_shape = (21, 30, 1)
         spatial_axes = (0, 1)
         swap_patch = (3, 4, 1)
         padding = (0, 0, 3, 0, 5, 6)
         cropping = (0, 0, 8, 0, 1, 4)
     landmarks = {name: np.linspace(0, 100, 13) for name in channels}
     elastic = torchio.RandomElasticDeformation(max_displacement=max_disp)
     steps = [
         torchio.CropOrPad(crop_or_pad_shape),
         torchio.ToCanonical(),
         torchio.RandomDownsample(downsampling=(1.75, 2),
                                  axes=spatial_axes),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=spatial_axes, flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1,
         }),
         torchio.Pad(padding, padding_mode=3),
         torchio.Crop(cropping),
     ]
     if labels:
         steps.append(torchio.RandomLabelsToImage(label_key='label'))
     return torchio.Compose(steps)
Example #14
0
 def test_z_normalization(self):
     """The default ZNormalization should yield mean 0 and std 1 for t1."""
     t1_data = tio.ZNormalization()(self.sample_subject).t1.data
     self.assertAlmostEqual(float(t1_data.mean()), 0., places=6)
     self.assertAlmostEqual(float(t1_data.std()), 1.)
Example #15
0
 def test_no_std(self):
     """A constant image has zero standard deviation, so ZNormalization must raise."""
     constant_image = tio.ScalarImage(tensor=torch.ones(1, 2, 2, 2))
     with self.assertRaises(RuntimeError):
         znorm = tio.ZNormalization()
         znorm(constant_image)
"""
Exclude images from transform
=============================

In this example we show how the kwargs ``include`` and ``exclude`` can be
used to apply a transform to only some of the images within a subject.
"""

import torch
import torchio as tio

# Seed torch's RNG before any random transform runs.
torch.manual_seed(0)

# Load the sample subject and show it before any transform is applied.
subject = tio.datasets.Pediatric(years=(4.5, 8.5))
subject.plot()
transform = tio.Compose([
    tio.RandomAffine(degrees=(20, 30)),
    tio.ZNormalization(),
    tio.RandomBlur(std=(3, 4), include='t1'),  # only the 't1' image is blurred
    tio.RandomNoise(std=(1, 1.5), exclude='t1'),  # noise on every image except 't1'
])
transformed = transform(subject)
transformed.plot()
Example #17
0
 def __init__(self, foldername, annotations_data):
     """Record the data folder and file list, and build the per-sample transforms."""
     self.folder, self.filedata = foldername, annotations_data
     self.fileinfo = np.array(annotations_data.filelist)
     self.standardizer = tio.ZNormalization()
     self.padder = tio.CropOrPad((144, 176, 144))
Example #18
0
# Normalization
# Train histogram-standardization landmarks from ``image_paths`` and write them
# to ``histogram_landmarks_path``.
landmarks = tio.HistogramStandardization.train(
    image_paths,
    output_path=histogram_landmarks_path,
)
np.set_printoptions(suppress=True, precision=3)
print('\nTrained landmarks:', landmarks)

# Histogram standardization of the 'mri' image with the trained landmarks
landmarks_dict = {'mri': landmarks}
histogram_transform = tio.HistogramStandardization(landmarks_dict)

# Z-normalization, with statistics taken inside the tio.ZNormalization.mean mask
znorm_transform = tio.ZNormalization(masking_method=tio.ZNormalization.mean)

# Apply both transforms to the first subject and plot its intensity histogram.
sample = dataset[0]
transform = tio.Compose([histogram_transform, znorm_transform])
znormed = transform(sample)

fig, ax = plt.subplots(dpi=100)
plot_histogram(ax, znormed.mri.data, label='Z-normed', alpha=1)
ax.set_title('Intensity values of one sample after z-normalization')
ax.set_xlabel('Intensity')
ax.grid()

training_transform = Compose([
    ToCanonical(),
    #  Resample(4),
    CropOrPad((112, 112, 48), padding_mode=0),  #reflect , original 112,112,48
Example #19
0
 def test_bad_bounds_mask(self):
     """An unrecognized ``masking_method`` string must be rejected with ValueError."""
     invalid_znorm = tio.ZNormalization(masking_method='test')
     self.assertRaises(ValueError, invalid_znorm, self.sample_subject)
Example #20
0
    subject = tio.Subject(image=tio.ScalarImage(image_path),
                          mask=tio.LabelMap(label_path))
    subjects.append(subject)

# Build a torchio SubjectsDataset from the patient subjects.
dataset = tio.SubjectsDataset(subjects)

# Reduce the mask to a single class.
if config.to_one_class:
    for subject in dataset.dry_iter():
        subject['mask'] = one(subject['mask'])

training_transform = tio.Compose([
    tio.Resample(4),
    tio.ZNormalization(
        masking_method=tio.ZNormalization.mean
    ),  # widely recommended on the torchio forums
    tio.RandomFlip(p=0.25),
    tio.RandomNoise(p=0.25),
    # !!! Tensors have to be explicitly converted to float
    tio.Lambda(to_float)
])

# Validation must be deterministic. It gets the same preprocessing as training
# but no random augmentation (no RandomNoise), so metrics are comparable
# between evaluations.
validation_transform = tio.Compose([
    tio.Resample(4),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    tio.Lambda(to_float)
])

Example #21
0
            create_images(true_gif_output_path,
                          channels[0],
                          channels[1],
                          channels[2],
                          mri_chan=subject['data'][0][1] > 0,
                          angle_num=gif_angle_rotation,
                          angle_view=gif_view_angle,
                          fig_size=fig_size_gif)
            make_gif(true_gif_output_path,
                     os.path.join(true_gif_output_path, 'true.gif'),
                     angle_num=gif_angle_rotation)


if __name__ == "__main__":
    validation_transform = tio.Compose([
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.CropOrPad((240, 240, 160)),
        tio.OneHot(num_classes=5)
    ])
    gen_visuals(
        image_path=
        "../brats_new/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_010",
        transforms=validation_transform,
        model_path="./Models/test_train_many_1e-3.pt",
        gen_pred=True,
        gen_true=True,
        input_channels_list=['flair', 't1', 't2', 't1ce'],
        seg_channels=[1, 2, 4],
        gen_gif=False,
        true_gif_output_path="../output/true",
        pred_gif_output_path="../output/pred",
Example #22
0
def _compare_histograms(original, transformed, title):
    """Plot the ``mri`` histograms of ``original`` (top) and ``transformed`` (bottom).

    Both panels get the y-limit (0, 1e6). Returns the ``(fig, axes)`` pair from
    ``plt.subplots``.
    """
    fig, axes = plt.subplots(2, 1)
    top, bottom = axes
    sns.distplot(original.mri.data, ax=top, kde=False)
    sns.distplot(transformed.mri.data, ax=bottom, kde=False)
    top.set_title('Original histogram')
    bottom.set_title(title)
    for panel in (top, bottom):
        panel.set_ylim((0, 1e6))
    plt.tight_layout()
    return fig, axes


rescaled = rescale(fpg_ras)
fig, axes = _compare_histograms(
    fpg, rescaled, 'Intensity rescaling with percentiles 1 and 99')
show_fpg(rescaled)

# Z-normalization
# Another common normalization approach is forcing the data to have zero mean
# and unit variance; ZNormalization does this.

standardize = tio.ZNormalization()
standardized = standardize(fpg)
fig, axes = _compare_histograms(fpg, standardized, 'Z-normalization')

# The second mode of the distribution, which corresponds to the foreground, is
# far from zero because the background contributes a lot to the mean.
# The statistics could instead be computed from, e.g., values above the mean
# only. Let's check whether the mean is a good threshold to segment the foreground.
Example #23
0
            lr_1=tio.ScalarImage(t2_file),
        )
    if in_channels == 3:
        subject = tio.Subject(
            hr=tio.ScalarImage(t2_file),
            lr_1=tio.ScalarImage(t2_file),
            lr_2=tio.ScalarImage(t2_file),
            lr_3=tio.ScalarImage(t2_file),
        )

    subjects.append(subject)

print('DHCP Dataset size:', len(subjects), 'subjects')

# DATA AUGMENTATION
normalization = tio.ZNormalization()
spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75)
flip = tio.RandomFlip(axes=('LR', ), flip_probability=0.5)

tocanonical = tio.ToCanonical()

# Degradation of 'lr_1' only. It is blurred (almost entirely along the third
# axis), resampled to 0.8 x 0.8 x 2 spacing, then resampled back onto the 'hr'
# grid. Presumably this simulates a low-resolution acquisition; confirm.
b1 = tio.Blur(std=(0.001, 0.001, 1), include='lr_1')  #blur
d1 = tio.Resample((0.8, 0.8, 2), include='lr_1')  #downsampling
u1 = tio.Resample(target='hr', include='lr_1')  #upsampling

if in_channels == 3:
    b2 = tio.Blur(std=(0.001, 1, 0.001), include='lr_2')  #blur
    d2 = tio.Resample((0.8, 2, 0.8), include='lr_2')  #downsampling
    u2 = tio.Resample(target='hr', include='lr_2')  #upsampling

    b3 = tio.Blur(std=(1, 0.001, 0.001), include='lr_3')  #blur
Example #24
0
        subject = tio.Subject(
            t2=tio.ScalarImage(t2_file),
            t1=tio.ScalarImage(t1_file),
            label=tio.LabelMap(seg_file),
        )
        subjects.append(subject)

    dataset = tio.SubjectsDataset(subjects)
    print('Dataset size:', len(dataset), 'subjects')

    onehot = tio.OneHot()
    flip = tio.RandomFlip(axes=('LR', ), flip_probability=0.5)
    bias = tio.RandomBiasField(coefficients=0.5, p=0.5)
    noise = tio.RandomNoise(std=0.1, p=0.25)
    normalization = tio.ZNormalization(masking_method='label')
    spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75)

    transforms = [flip, spatial, bias, normalization, noise, onehot]

    training_transform = tio.Compose(transforms)
    validation_transform = tio.Compose([normalization, onehot])

    #%%
    num_subjects = len(dataset)
    num_training_subjects = int(training_split_ratio * num_subjects)
    num_validation_subjects = num_subjects - num_training_subjects

    num_split_subjects = num_training_subjects, num_validation_subjects
    training_subjects, validation_subjects = torch.utils.data.random_split(
        subjects,
Example #25
0
def _stack_patches(volumes, patch_shape=(1, 1, 7, 1024, 1024)):
    """Stack ``volumes`` into one float32 array of shape ``(len(volumes),) + patch_shape``.

    Each volume is broadcast into its own slot ``out[i, 0]``, so every entry is an
    independent copy. Change ``patch_shape`` if necessary.
    """
    out = np.zeros((len(volumes),) + tuple(patch_shape), dtype=np.float32)
    for i, volume in enumerate(volumes):
        out[i, 0] = volume
    return out


def dataloader(handles, mode = 'train'):
    """Load the cached image/label/gap stacks for ``mode``, building them if needed.

    If ``../inputs/flpickles/<mode>.pickle`` can be loaded, its content is
    returned. Otherwise each row of ``handles`` is read, augmented (train mode
    only), stacked, cached to that pickle and returned.

    Args:
        handles: Presumably a pandas DataFrame (``iterrows`` is used). It has
            columns 'Image' and 'ID' and, unless ``mode == 'test'``, 'Label' and
            'Gap', holding file paths readable by ``io.imread``.
        mode: With 'train', 12 augmented copies are added after the original.
            With any mode other than 'test', labels and gaps are loaded too.

    Returns:
        dict with keys 'Image', 'Label', 'Gap' and 'ID'. 'Image', 'Label' and
        'Gap' hold one float32 array per row, of shape
        ``(n_versions, 1, 1, 7, 1024, 1024)``, where ``n_versions`` is 13 in train
        mode and 1 otherwise.
    """
    pickle_path = '../inputs/flpickles/' + mode + '.pickle'

    # If the cache exists, load it. NOTE: pickle.load must only ever read
    # trusted files; this is a locally generated cache.
    try:
        with open(pickle_path, 'rb') as f:
            return pickle.load(f)
    except Exception:  # narrowed from a bare except
        pass  # missing or unreadable cache: rebuild it below

    images = {'Image': [], 'Label': [], 'Gap': [], 'ID': []}

    # Data augmentations
    # NOTE(review): unlike the affine/elastic transforms, the flips are not
    # seeded, so an image and its label/gap may receive different random flips.
    random_flip = tio.RandomFlip(axes=1)
    random_flip2 = tio.RandomFlip(axes=2)
    random_affine = tio.RandomAffine(seed=0, scales=(3, 3))
    random_elastic = tio.RandomElasticDeformation(
        max_displacement=(0, 20, 40),
        num_control_points=20,
        seed=0,
    )
    rescale = tio.RescaleIntensity((-1, 1), percentiles=(1, 99))
    standardize_foreground = tio.ZNormalization(masking_method=lambda x: x > x.mean())
    blur = tio.RandomBlur(seed=0)
    standardize = tio.ZNormalization()
    add_noise = tio.RandomNoise(std=0.5, seed=42)
    add_spike = tio.RandomSpike(seed=42)
    add_ghosts = tio.RandomGhosting(intensity=1.5, seed=42)
    add_motion = tio.RandomMotion(num_transforms=6, image_interpolation='nearest', seed=42)
    # NOTE(review): RandomSwap moves patches, but the matching label/gap entry
    # is left unchanged. Confirm this is intended.
    swap = tio.RandomSwap(patch_size = 7)

    def spatial_copies(volume):
        """Return the 12 label/gap copies aligned with the image augmentations.

        Only the four spatial transforms are applied. The 8 intensity-only image
        augmentations are paired with the unmodified ``volume``.
        """
        return [
            random_flip(volume),
            random_flip2(volume),
            random_affine(volume),
            random_elastic(volume),
        ] + [volume] * 8

    # For each image
    for _, row in handles.iterrows():
        im = io.imread(row['Image'])
        im = im / 255  # Normalization
        im = np.expand_dims(im, axis=0)
        images['ID'].append(row['ID'])

        im_versions = [im]
        if mode == 'train':
            # The list is evaluated left to right: same augmentation order as before.
            im_versions += [
                random_flip(im),
                random_flip2(im),
                random_affine(im),
                random_elastic(im),
                rescale(im),
                standardize_foreground(im),
                blur(im),
                add_noise(standardize(im)),
                add_spike(im),
                add_ghosts(im),
                add_motion(im),
                swap(im),
            ]
        # One independent slot per version. Reusing a single buffer previously
        # made every entry equal to the last augmentation written.
        images['Image'].append(_stack_patches(im_versions))

        if mode != 'test':
            lb = io.imread(row['Label'])
            lb = label_converter(lb)
            lb = np.expand_dims(lb, axis=0)
            gap = io.imread(row['Gap'])
            gap = np.expand_dims(gap, axis = 0)

            lb_versions = [lb]
            gap_versions = [gap]
            if mode == 'train':
                lb_versions += spatial_copies(lb)
                gap_versions += spatial_copies(gap)
            images['Label'].append(_stack_patches(lb_versions))
            images['Gap'].append(_stack_patches(gap_versions))

    # Save images
    with open(pickle_path, 'wb') as f:
        pickle.dump(images, f)

    return images