def load_kidney_seg(data_shape, batch=3, workers=4, transform=None,
                    num_cases=210, data_dir="data"):
    """Build a DataLoader over the kidney segmentation cases.

    Every subject is clipped to [-80, 300] (intensity images only), rescaled
    to [0, 1], resized to ``data_shape`` and has its label map rounded; the
    optional ``transform`` is applied after that preprocessing.

    Args:
        data_shape: Target spatial size passed to ``interpolate``.
        batch: Batch size of the returned DataLoader.
        workers: Number of DataLoader worker processes.
        transform: Extra torchio transform appended after preprocessing.
            When None, a ``tio.RandomFlip(p=0.)`` placeholder is used.
        num_cases: Number of cases to load, indexed ``0 .. num_cases - 1``
            (defaults to the previously hard-coded 210).
        data_dir: Directory containing the ``case_XXXXX`` folders
            (defaults to the previously hard-coded ``"data"``).

    Returns:
        A ``DataLoader`` wrapping a ``tio.SubjectsDataset``.
    """
    if transform is None:
        # p=0. means the flip never fires; it only keeps Compose uniform.
        transform = tio.RandomFlip(p=0.)

    # Preprocessing shared by every subject.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    # Add a leading batch dim for interpolate, then drop it again.
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    # Round label voxels to whole numbers after resizing.
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    subject_list = [
        tio.Subject(
            img=tio.ScalarImage(f"{data_dir}/case_{i:05d}/imaging.nii.gz"),
            label=tio.LabelMap(
                f"{data_dir}/case_{i:05d}/segmentation.nii.gz"),
        )
        for i in range(num_cases)
    ]
    dataset = tio.SubjectsDataset(subject_list, transform=transform)
    return DataLoader(dataset, num_workers=workers, batch_size=batch,
                      pin_memory=True)
def load_pretrain_datasets(data_shape, batch=3, workers=4, transform=None,
                           data_path='/home/mitch/Data/MSD/'):
    """Build one DataLoader per task directory found under ``data_path``.

    Each task directory is expected to contain ``imagesTr/`` and ``labelsTr/``
    folders. Their files are sorted and paired up by position. All subjects
    are clipped to [-80, 300] (intensity images only), rescaled to [0, 1],
    resized to ``data_shape`` and have their label maps rounded before the
    optional ``transform`` runs.

    Args:
        data_shape: Target spatial size passed to ``interpolate``.
        batch: Batch size of every returned DataLoader.
        workers: Number of worker processes per DataLoader.
        transform: Extra torchio transform appended after preprocessing.
            When None, a ``tio.RandomFlip(p=0.)`` placeholder is used.
        data_path: Root directory holding one sub-directory per task
            (defaults to the previously hard-coded location).

    Returns:
        A list of ``DataLoader`` objects, one per task directory, in sorted
        directory order.

    Raises:
        ValueError: If a task directory has a different number of images and
            label maps, which would make the positional pairing unreliable.
    """
    import os

    if transform is None:
        # p=0. means the flip never fires; it only keeps Compose uniform.
        transform = tio.RandomFlip(p=0.)

    # Preprocessing shared by every task.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    # Add a leading batch dim for interpolate, then drop it again.
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    # Round label voxels to whole numbers after resizing.
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    # Some tasks have multi-channel images. These transforms keep a single
    # index of the last axis and re-add a leading dim before preprocessing.
    braintransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 2], dim=0),
                            types_to_apply=[tio.INTENSITY])
    braintransform = tio.Compose([braintransform, transform])
    prostatetransform = Lambda(
        lambda x: torch.unsqueeze(x[:, :, :, 1], dim=0),
        types_to_apply=[tio.INTENSITY])
    prostatetransform = tio.Compose([prostatetransform, transform])

    # NOTE(review): special cases are chosen by position in the sorted
    # directory list. Indices 0 and 4 presumably correspond to the MSD
    # brain-tumour and prostate tasks; confirm if the directory layout changes.
    special_transforms = {0: braintransform, 4: prostatetransform}

    # Trailing '' makes the pattern end in a separator, so only directories
    # match.
    directories = sorted(glob.glob(os.path.join(data_path, '*', '')))
    loaders = []  # one DataLoader per task directory
    for i, directory in enumerate(directories):
        images = sorted(glob.glob(os.path.join(directory, 'imagesTr', '*')))
        segs = sorted(glob.glob(os.path.join(directory, 'labelsTr', '*')))
        if len(images) != len(segs):
            raise ValueError(
                f"{directory}: found {len(images)} images but "
                f"{len(segs)} label maps")
        subject_list = [
            tio.Subject(img=tio.ScalarImage(image), label=tio.LabelMap(seg))
            for image, seg in zip(images, segs)
        ]
        dataset = tio.SubjectsDataset(
            subject_list, transform=special_transforms.get(i, transform))
        loaders.append(
            DataLoader(dataset, num_workers=workers, batch_size=batch,
                       pin_memory=True))
    return loaders
def test_transforms(self):
    """Smoke-test that a long chain of transforms runs end to end.

    Each transform is applied to the output of the previous one, starting
    from ``self.get_sample()``. The test passes as long as no transform raises.
    """
    landmarks_dict = dict(
        t1=np.linspace(0, 100, 13),
        t2=np.linspace(0, 100, 13),
    )
    transforms = (
        CenterCropOrPad((9, 21, 30)),
        ToCanonical(),
        Resample((1, 1.1, 1.25)),
        RandomFlip(axes=(0, 1, 2), flip_probability=1),
        RandomMotion(proportion_to_augment=1),
        RandomGhosting(proportion_to_augment=1, axes=(0, 1, 2)),
        RandomSpike(),
        RandomNoise(),
        RandomBlur(),
        RandomSwap(patch_size=2, num_iterations=5),
        # Pass a list of image types, as the other Lambda usages do
        # (e.g. types_to_apply=[LABEL]), rather than a bare string.
        Lambda(lambda x: 1.5 * x, types_to_apply=[INTENSITY]),
        RandomBiasField(),
        Rescale((0, 1)),
        ZNormalization(masking_method='label'),
        HistogramStandardization(landmarks_dict=landmarks_dict),
        RandomElasticDeformation(proportion_to_augment=1),
        RandomAffine(),
        Pad((1, 2, 3, 0, 5, 6)),
        Crop((3, 2, 8, 0, 1, 4)),
    )
    transformed = self.get_sample()
    for transform in transforms:
        transformed = transform(transformed)
def test_image_types(self):
    """A Lambda restricted to LABEL images must leave t1/t2 unchanged.

    It must also add 1 to every voxel of the label image.
    """
    add_one_to_labels = Lambda(lambda x: x + 1, types_to_apply=[LABEL])
    result = add_one_to_labels(self.sample_subject)
    # Intensity images are expected to pass through untouched.
    for name in ('t1', 't2'):
        assert torch.all(torch.eq(
            getattr(result, name).data,
            getattr(self.sample_subject, name).data))
    # Only the label image is shifted by one.
    assert torch.all(torch.eq(
        result.label.data, self.sample_subject.label.data + 1))
def test_lambda(self):
    """A Lambda with no type filter must add 1 to every image in the subject."""
    add_one = Lambda(lambda x: x + 1)
    result = add_one(self.sample_subject)
    for name in ('t1', 't2', 'label'):
        expected = getattr(self.sample_subject, name).data + 1
        assert torch.all(torch.eq(getattr(result, name).data, expected))
def test_image_types(self):
    """A Lambda restricted to LABEL must change only the 'label' entry.

    The 't1' and 't2' data must be unchanged, and 'label' must have 1 added.
    """
    add_one_to_labels = Lambda(lambda x: x + 1, types_to_apply=[LABEL])
    result = add_one_to_labels(self.sample)
    # Intensity entries are expected to pass through untouched.
    for key in ('t1', 't2'):
        assert torch.all(
            torch.eq(result[key][DATA], self.sample[key][DATA]))
    # Only the label entry is shifted by one.
    assert torch.all(
        torch.eq(result['label'][DATA], self.sample['label'][DATA] + 1))
def test_lambda(self):
    """A Lambda with no type filter must add 1 to the data of every entry."""
    add_one = Lambda(lambda x: x + 1)
    result = add_one(self.sample)
    for key in ('t1', 't2', 'label'):
        expected = self.sample[key][DATA] + 1
        assert torch.all(torch.eq(result[key][DATA], expected))
def rotate_180(parameters):
    """Wrap ``tensor_rotate_180`` in a ``Lambda`` transform configured from a mapping.

    Args:
        parameters: Mapping with an ``"axis"`` key, forwarded to
            ``tensor_rotate_180``, and a ``"probability"`` key, used as the
            Lambda's ``p``.

    Returns:
        The configured ``Lambda`` transform.
    """
    axis = parameters["axis"]
    probability = parameters["probability"]
    rotate = partial(tensor_rotate_180, axis=axis)
    return Lambda(function=rotate, p=probability)
def threshold_transform(min, max, p=1):
    """Wrap ``threshold_intensities`` in a ``Lambda`` transform.

    ``min`` and ``max`` are forwarded as keyword arguments to
    ``threshold_intensities``, and ``p`` is passed through as the Lambda's
    ``p``. The parameter names shadow the ``min``/``max`` builtins inside
    this function; they are kept for caller compatibility.
    """
    bound_threshold = partial(threshold_intensities, min=min, max=max)
    return Lambda(function=bound_threshold, p=p)
def rotate_180(axis, p=1):
    """Wrap ``tensor_rotate_180`` (bound to ``axis``) in a ``Lambda`` transform.

    ``p`` is passed through as the Lambda's ``p``.
    """
    rotate = partial(tensor_rotate_180, axis=axis)
    return Lambda(function=rotate, p=p)
def clip_transform(min, max, p=1):
    """Wrap ``clip_intensities`` in a ``Lambda`` transform.

    ``min`` and ``max`` are forwarded as keyword arguments to
    ``clip_intensities``, and ``p`` is passed through as the Lambda's ``p``.
    The parameter names shadow the ``min``/``max`` builtins inside this
    function; they are kept for caller compatibility.
    """
    bound_clip = partial(clip_intensities, min=min, max=max)
    return Lambda(function=bound_clip, p=p)
def test_wrong_return_type(self):
    """A Lambda whose function returns a non-tensor must raise ValueError."""
    returns_string = Lambda(lambda _: 'Not a tensor')
    with self.assertRaises(ValueError):
        returns_string(self.sample)
def test_wrong_return_shape(self):
    """A Lambda returning a 1-element tensor (wrong shape) must raise ValueError."""
    wrong_shape = Lambda(lambda _: torch.rand(1))
    with self.assertRaises(ValueError):
        wrong_shape(self.sample)
def test_wrong_return_data_type(self):
    """A Lambda returning a boolean tensor must raise ValueError."""
    returns_bool = Lambda(lambda _: torch.rand(1) > 0)
    with self.assertRaises(ValueError):
        returns_bool(self.sample_subject)
from torchio.transforms import Lambda
from torchio import Image, ImagesDataset, INTENSITY, LABEL, Subject

# Example: apply a Lambda that multiplies intensity images by -1.5, then save
# the transformed T1 image. The label image is excluded via types_to_apply.
subject = Subject(
    Image('label', '~/Dropbox/MRI/t1_brain_seg.nii.gz', LABEL),
    Image('t1', '~/Dropbox/MRI/t1.nii.gz', INTENSITY),
)
subjects_list = [subject]
dataset = ImagesDataset(subjects_list)
sample = dataset[0]
# Pass a list of image types (as in types_to_apply=[LABEL] elsewhere), not a
# bare string.
transform = Lambda(lambda x: -1.5 * x, types_to_apply=[INTENSITY])
transformed = transform(sample)
dataset.save_sample(transformed, {'t1': '/tmp/t1_lambda.nii'})