Exemplo n.º 1
0
def load_kidney_seg(data_shape, batch=3, workers=4, transform=None,
                    data_dir="data", num_cases=210):
    """Build a DataLoader over the kidney segmentation cases.

    Every subject goes through a fixed preprocessing chain (intensity clip to
    [-80, 300], intensity rescale to [0, 1], resize to ``data_shape``, label
    rounding) followed by the caller-supplied ``transform``.

    Args:
        data_shape: Target size forwarded to ``interpolate`` for every image
            (presumably the spatial shape expected by
            ``torch.nn.functional.interpolate`` -- TODO confirm).
        batch: Batch size given to the DataLoader.
        workers: Number of DataLoader worker processes.
        transform: Optional extra transform applied after the preprocessing
            chain. When ``None`` a ``RandomFlip`` with probability 0 is used
            as a placeholder.
        data_dir: Directory holding the ``case_XXXXX`` folders.
        num_cases: Number of consecutive cases (starting at 0) to load.

    Returns:
        A ``DataLoader`` over a ``tio.SubjectsDataset`` whose subjects expose
        an ``img`` scalar image and a ``label`` label map.
    """
    # Take the input transform and apply it after clip, normalization, resize.
    if transform is None:
        transform = tio.RandomFlip(p=0.)

    # Preprocessing shared by every subject.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    # A temporary leading dimension is added around ``interpolate`` and
    # removed again afterwards.
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    # Labels are rounded after resizing.
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    subject_list = [
        tio.Subject(
            img=tio.ScalarImage(
                "{}/case_{:05d}/imaging.nii.gz".format(data_dir, i)),
            label=tio.LabelMap(
                "{}/case_{:05d}/segmentation.nii.gz".format(data_dir, i)),
        )
        for i in range(num_cases)
    ]
    dataset = tio.SubjectsDataset(subject_list, transform=transform)
    return DataLoader(dataset,
                      num_workers=workers,
                      batch_size=batch,
                      pin_memory=True)
Exemplo n.º 2
0
def load_pretrain_datasets(data_shape, batch=3, workers=4, transform=None,
                           data_path='/home/mitch/Data/MSD/'):
    """Build one DataLoader per task directory found under ``data_path``.

    Each immediate sub-directory of ``data_path`` is treated as a task with
    images in ``imagesTr/`` and segmentations in ``labelsTr/``; the two sorted
    file lists are paired positionally.

    Every subject goes through a fixed preprocessing chain (intensity clip to
    [-80, 300], intensity rescale to [0, 1], resize to ``data_shape``, label
    rounding) followed by the caller-supplied ``transform``. The tasks at
    sorted index 0 and 4 additionally get a channel-selection step first
    (presumably the brain and prostate tasks, going by the original variable
    names -- verify against the directory layout).

    Args:
        data_shape: Target size forwarded to ``interpolate`` for every image.
        batch: Batch size given to each DataLoader.
        workers: Number of worker processes per DataLoader.
        transform: Optional extra transform applied after the preprocessing
            chain. When ``None`` a ``RandomFlip`` with probability 0 is used
            as a placeholder.
        data_path: Root directory containing one sub-directory per task.

    Returns:
        A list of ``DataLoader`` objects, one per task directory, in sorted
        directory order.
    """
    import os

    # The trailing empty component makes the pattern end with a separator so
    # that only directories match and each result ends with a separator too.
    directories = sorted(glob.glob(os.path.join(data_path, '*', '')))

    if transform is None:
        transform = tio.RandomFlip(p=0.)

    # Preprocessing shared by every task.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    # Deal with weird shapes: pick a single index along the last axis of the
    # intensity image and restore a leading dimension before the common chain.
    braintransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 2], dim=0),
                            types_to_apply=[tio.INTENSITY])
    braintransform = tio.Compose([braintransform, transform])
    prostatetransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 1], dim=0),
                               types_to_apply=[tio.INTENSITY])
    prostatetransform = tio.Compose([prostatetransform, transform])

    # Task index -> transform override; every other task uses ``transform``.
    special_transforms = {0: braintransform, 4: prostatetransform}

    loaders = []  # one DataLoader per task
    for i, directory in enumerate(directories):
        images = sorted(glob.glob(os.path.join(directory, 'imagesTr', '*')))
        segs = sorted(glob.glob(os.path.join(directory, 'labelsTr', '*')))

        subject_list = [
            tio.Subject(img=tio.ScalarImage(image), label=tio.LabelMap(seg))
            for image, seg in zip(images, segs)
        ]

        dataset = tio.SubjectsDataset(
            subject_list, transform=special_transforms.get(i, transform))
        loaders.append(
            DataLoader(dataset,
                       num_workers=workers,
                       batch_size=batch,
                       pin_memory=True))

    return loaders
Exemplo n.º 3
0
 def test_transforms(self):
     """Smoke test: run a sample through a long chain of transforms.

     Nothing is asserted on the result; the test only fails if one of the
     transforms raises while being constructed or applied.
     """
     # Histogram landmarks for both image names: 13 evenly spaced values
     # between 0 and 100.
     landmarks = {
         't1': np.linspace(0, 100, 13),
         't2': np.linspace(0, 100, 13),
     }
     pipeline = [
         CenterCropOrPad((9, 21, 30)),
         ToCanonical(),
         Resample((1, 1.1, 1.25)),
         RandomFlip(axes=(0, 1, 2), flip_probability=1),
         RandomMotion(proportion_to_augment=1),
         RandomGhosting(proportion_to_augment=1, axes=(0, 1, 2)),
         RandomSpike(),
         RandomNoise(),
         RandomBlur(),
         RandomSwap(patch_size=2, num_iterations=5),
         Lambda(lambda x: 1.5 * x, types_to_apply=INTENSITY),
         RandomBiasField(),
         Rescale((0, 1)),
         ZNormalization(masking_method='label'),
         HistogramStandardization(landmarks_dict=landmarks),
         RandomElasticDeformation(proportion_to_augment=1),
         RandomAffine(),
         Pad((1, 2, 3, 0, 5, 6)),
         Crop((3, 2, 8, 0, 1, 4)),
     ]
     # Feed the output of each step into the next one.
     current = self.get_sample()
     for step in pipeline:
         current = step(current)
Exemplo n.º 4
0
 def test_image_types(self):
     """A Lambda limited to LABEL adds 1 to the label and leaves t1/t2 as is."""
     add_one = Lambda(lambda x: x + 1, types_to_apply=[LABEL])
     result = add_one(self.sample_subject)

     # t1 and t2 are expected to come back with their data unchanged.
     for name in ('t1', 't2'):
         after = getattr(result, name).data
         before = getattr(self.sample_subject, name).data
         assert torch.all(torch.eq(after, before))

     # The label data is expected to be shifted by one.
     label_after = result.label.data
     label_expected = self.sample_subject.label.data + 1
     assert torch.all(torch.eq(label_after, label_expected))
Exemplo n.º 5
0
 def test_lambda(self):
     """A Lambda without type restriction adds 1 to t1, t2 and label."""
     add_one = Lambda(lambda x: x + 1)
     result = add_one(self.sample_subject)

     # Each of the three images must equal the original data plus one.
     for name in ('t1', 't2', 'label'):
         after = getattr(result, name).data
         expected = getattr(self.sample_subject, name).data + 1
         assert torch.all(torch.eq(after, expected))
Exemplo n.º 6
0
 def test_image_types(self):
     """A Lambda limited to LABEL adds 1 to the label and leaves t1/t2 as is."""
     add_one = Lambda(lambda x: x + 1, types_to_apply=[LABEL])
     result = add_one(self.sample)

     # t1 and t2 are expected to come back with their data unchanged.
     for name in ('t1', 't2'):
         after = result[name][DATA]
         before = self.sample[name][DATA]
         assert torch.all(torch.eq(after, before))

     # The label data is expected to be shifted by one.
     label_after = result['label'][DATA]
     label_expected = self.sample['label'][DATA] + 1
     assert torch.all(torch.eq(label_after, label_expected))
Exemplo n.º 7
0
 def test_lambda(self):
     """A Lambda without type restriction adds 1 to t1, t2 and label."""
     add_one = Lambda(lambda x: x + 1)
     result = add_one(self.sample)

     # Each of the three entries must equal the original data plus one.
     for name in ('t1', 't2', 'label'):
         after = result[name][DATA]
         expected = self.sample[name][DATA] + 1
         assert torch.all(torch.eq(after, expected))
Exemplo n.º 8
0
def rotate_180(parameters):
    """Build a Lambda transform wrapping ``tensor_rotate_180``.

    Args:
        parameters: Mapping providing ``"axis"`` (forwarded to
            ``tensor_rotate_180``) and ``"probability"`` (used as the
            transform's ``p``).

    Returns:
        The configured ``Lambda`` transform.
    """
    rotate = partial(tensor_rotate_180, axis=parameters["axis"])
    probability = parameters["probability"]
    return Lambda(function=rotate, p=probability)
Exemplo n.º 9
0
def threshold_transform(min, max, p=1):
    """Build a Lambda transform wrapping ``threshold_intensities``.

    Args:
        min: Forwarded to ``threshold_intensities`` as ``min``.
        max: Forwarded to ``threshold_intensities`` as ``max``.
        p: Passed to ``Lambda`` as ``p``.

    Returns:
        The configured ``Lambda`` transform.
    """
    bounded_threshold = partial(threshold_intensities, min=min, max=max)
    return Lambda(function=bounded_threshold, p=p)
Exemplo n.º 10
0
def rotate_180(axis, p=1):
    """Build a Lambda transform wrapping ``tensor_rotate_180``.

    Args:
        axis: Forwarded to ``tensor_rotate_180`` as ``axis``.
        p: Passed to ``Lambda`` as ``p``.

    Returns:
        The configured ``Lambda`` transform.
    """
    rotate = partial(tensor_rotate_180, axis=axis)
    return Lambda(function=rotate, p=p)
Exemplo n.º 11
0
def clip_transform(min, max, p=1):
    """Build a Lambda transform wrapping ``clip_intensities``.

    Args:
        min: Forwarded to ``clip_intensities`` as ``min``.
        max: Forwarded to ``clip_intensities`` as ``max``.
        p: Passed to ``Lambda`` as ``p``.

    Returns:
        The configured ``Lambda`` transform.
    """
    bounded_clip = partial(clip_intensities, min=min, max=max)
    return Lambda(function=bounded_clip, p=p)
Exemplo n.º 12
0
 def test_wrong_return_type(self):
     """Applying a Lambda whose function returns a string must raise ValueError."""
     bad_transform = Lambda(lambda x: 'Not a tensor')
     # The error is expected when the transform is applied, not when built.
     with self.assertRaises(ValueError):
         bad_transform(self.sample)
Exemplo n.º 13
0
 def test_wrong_return_shape(self):
     """Applying a Lambda that returns a one-element tensor must raise ValueError."""
     bad_transform = Lambda(lambda x: torch.rand(1))
     # The error is expected when the transform is applied, not when built.
     with self.assertRaises(ValueError):
         bad_transform(self.sample)
Exemplo n.º 14
0
 def test_wrong_return_data_type(self):
     """Applying a Lambda that returns a comparison result must raise ValueError."""
     # ``torch.rand(1) > 0`` yields the result of a comparison rather than
     # random float values.
     bad_transform = Lambda(lambda x: torch.rand(1) > 0)
     with self.assertRaises(ValueError):
         bad_transform(self.sample_subject)
Exemplo n.º 15
0
from torchio.transforms import Lambda
from torchio import Image, ImagesDataset, INTENSITY, LABEL, Subject

# Demo: apply a Lambda transform to the intensity image of one subject and
# save the result.

# One subject made of a label image and a T1 intensity image.
label_image = Image('label', '~/Dropbox/MRI/t1_brain_seg.nii.gz', LABEL)
t1_image = Image('t1', '~/Dropbox/MRI/t1.nii.gz', INTENSITY)
subject = Subject(label_image, t1_image)
subjects_list = [subject]

dataset = ImagesDataset(subjects_list)
sample = dataset[0]

# Multiply by -1.5, restricted to INTENSITY images via ``types_to_apply``.
transform = Lambda(lambda x: -1.5 * x, types_to_apply=INTENSITY)
transformed = transform(sample)

# Write the transformed 't1' image to the given path.
output_paths = {'t1': '/tmp/t1_lambda.nii'}
dataset.save_sample(transformed, output_paths)