def test_transforms(self):
    """Apply a long chain of TorchIO transforms to ``self.sample``.

    The test passes as long as composing and running every transform on the
    sample does not raise.
    """
    landmarks = {
        't1': np.linspace(0, 100, 13),
        't2': np.linspace(0, 100, 13),
    }
    # One elastic instance is used both on its own and inside OneOf below.
    elastic = torchio.RandomElasticDeformation(max_displacement=1)

    preprocessing = (
        torchio.CropOrPad((9, 21, 30)),
        torchio.ToCanonical(),
        torchio.Resample((1, 1.1, 1.25)),
    )
    augmentation = (
        torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=2, num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
    )
    normalization = (
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(masking_method='label'),
        torchio.HistogramStandardization(landmarks_dict=landmarks),
    )
    spatial = (
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({
            torchio.RandomAffine(): 3,
            elastic: 1,
        }),
    )
    borders = (
        torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
        torchio.Crop((3, 2, 8, 0, 1, 4)),
    )

    pipeline = torchio.Compose(
        preprocessing + augmentation + normalization + spatial + borders
    )
    pipeline(self.sample)
def show_fpg(
        subject,
        to_ras=False,
        stretch_slices=True,
        indices=None,
        intensity_name='mri',
        parcellation=True,
        label_name='heart',
        ):
    """Plot orthogonal slices of an intensity image and of a label map.

    The top row shows ``subject[intensity_name]`` and the bottom row shows
    ``subject[label_name]``. Both rows are cut at the same voxel indices and
    drawn with the same physical extents, which are taken from the intensity
    image.

    Args:
        subject: TorchIO subject holding both images.
        to_ras: If True, reorient the subject with ``tio.ToCanonical`` first.
        stretch_slices: If True, pass every intensity slice through
            ``stretch`` before plotting.
        indices: Voxel indices ``(i, j, k)`` of the sagittal, coronal and
            axial slices. If None, indices just past the centre are used,
            with an extra offset of 7 along the third axis.
        intensity_name: Key of the intensity image shown on the top row.
        parcellation: If True, label slices whose maximum exceeds 1 are
            colorized with ``color_table``.
        label_name: Key of the label map shown on the bottom row.
    """
    subject = tio.ToCanonical()(subject) if to_ras else subject
    intensity_image = subject[intensity_name]
    label_image = subject[label_name]

    def flip(x):
        return np.rot90(x)

    _, axes = plt.subplots(2, 3, figsize=(12, 8))

    if indices is None:
        half_shape = torch.Tensor(subject.spatial_shape) // 2
        i, j, k = half_shape.long()
        i += 1
        j += 1
        k += 7  # use a better slice
    else:
        i, j, k = indices

    # Bounds and orientation come from the image actually being displayed,
    # not from a hard-coded key.
    bounds_x, bounds_y, bounds_z = get_bounds(intensity_image)
    orientation = ''.join(intensity_image.orientation)
    if orientation != 'RAS':
        import warnings
        warnings.warn(f'Image orientation should be RAS+, not {orientation}+')

    # Extents for the sagittal, coronal and axial panels, in that order.
    extents = (
        bounds_y + bounds_z,
        bounds_x + bounds_z,
        bounds_x + bounds_y,
    )

    def plot_row(row, slices, **kwargs):
        # Draw the three orthogonal slices into one row of the figure.
        for ax, view, extent in zip(axes[row], slices, extents):
            ax.imshow(flip(view), extent=extent, **kwargs)

    data = intensity_image.data
    slices = data[0, i], data[0, :, j], data[0, ..., k]
    if stretch_slices:
        slices = [stretch(s.numpy()) for s in slices]
    plot_row(0, slices, cmap='gray', interpolation='none')

    data = label_image.data
    slices = data[0, i], data[0, :, j], data[0, ..., k]
    if parcellation:
        # Binary masks (max <= 1) are shown as-is; multi-label maps are
        # converted to RGB with the colour table.
        slices = [
            color_table.colorize(s.long()) if s.max() > 1 else s
            for s in slices
        ]
    plot_row(1, slices, interpolation='none')
    plt.tight_layout()
def get_dataset(
        input_path,
        tta_iterations=0,
        interpolation='bspline',
        tolerance=0.1,
        mni_transform_path=None,
        ):
    """Wrap one input image into a dataset for inference.

    The image is reoriented, histogram-standardized, z-normalized and, when
    needed, resampled, then cropped so its shape is a multiple of 8.

    Args:
        input_path: Path to the input image.
        tta_iterations: Number of randomly augmented copies to append for
            test-time augmentation.
        interpolation: Interpolation used for resampling and affine
            augmentation.
        tolerance: Maximum allowed absolute difference between each voxel
            spacing value from the file header and 1 before resampling.
        mni_transform_path: Optional path to a matrix stored on the image
            under ``TO_MNI`` and applied before resampling to Colin27.

    Returns:
        The non-augmented ``tio.SubjectsDataset`` if no augmented copies are
        requested, otherwise a ``ConcatDataset`` of it followed by the
        augmented dataset.
    """
    to_mni = mni_transform_path is not None
    image_kwargs = {}
    if to_mni:
        image_kwargs[TO_MNI] = tio.io.read_matrix(mni_transform_path)
    image = tio.ScalarImage(input_path, **image_kwargs)
    subject = tio.Subject({IMAGE_NAME: image})

    reference_landmarks = np.array([
        0.,
        0.31331614,
        0.61505419,
        0.76732501,
        0.98887953,
        1.71169384,
        3.21741126,
        13.06931455,
        32.70817796,
        40.87807389,
        47.83508873,
        63.4408591,
        100.,
    ])
    standardize = tio.HistogramStandardization({IMAGE_NAME: reference_landmarks})
    steps = [
        tio.ToCanonical(),
        standardize,
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]

    spacing = np.array(nib.load(input_path).header.get_zooms())
    off_tolerance = np.any(np.abs(spacing - 1) > tolerance)
    if off_tolerance or to_mni:
        resample_kwargs = {'image_interpolation': interpolation}
        if to_mni:
            resample_kwargs['pre_affine_name'] = TO_MNI
            resample_kwargs['target'] = tio.datasets.Colin27().t1.path
        steps.append(tio.Resample(**resample_kwargs))
    steps.append(tio.EnsureShapeMultiple(8, method='crop'))

    preprocess = tio.Compose(steps)
    plain_dataset = tio.SubjectsDataset([subject], transform=preprocess)

    augmented_subjects = tta_iterations * [subject]
    if not augmented_subjects:
        return plain_dataset

    augment = tio.Compose((
        preprocess,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    augmented_dataset = tio.SubjectsDataset(augmented_subjects, transform=augment)
    return torch.utils.data.ConcatDataset((plain_dataset, augmented_dataset))
def get_transform(self, channels, is_3d=True, labels=True):
    """Compose a broad chain of TorchIO transforms for testing.

    Args:
        channels: Sequence of image names; each gets default histogram
            landmarks, and the first one is the ``CopyAffine`` reference.
        is_3d: Choose 3D sizes/axes if True, otherwise arguments for images
            whose last spatial dimension is a singleton.
        labels: If True, append ``RandomLabelsToImage`` reading ``'label'``.

    Returns:
        A ``tio.Compose`` with all transforms in a fixed order.
    """
    if is_3d:
        disp = 1
        cp_args = (9, 21, 30)
        resize_args = (10, 20, 30)
        flip_axes = (0, 1, 2)
        swap_patch = (2, 3, 4)
        pad_args = (1, 2, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4)
    else:
        disp = (1, 1, 0.01)
        cp_args = (21, 30, 1)
        resize_args = (10, 20, 1)
        flip_axes = (0, 1)
        swap_patch = (3, 4, 1)
        pad_args = (0, 0, 3, 0, 5, 6)
        crop_args = (0, 0, 8, 0, 1, 4)

    # Each channel gets its own landmarks array.
    landmarks = {name: np.linspace(0, 100, 13) for name in channels}
    # Shared between the stand-alone step and the OneOf choice below.
    elastic = tio.RandomElasticDeformation(max_displacement=disp)
    label_remapping = {1: 2, 2: 1, 3: 20, 4: 25}

    pipeline = [
        tio.CropOrPad(cp_args),
        tio.EnsureShapeMultiple(2, method='crop'),
        tio.Resize(resize_args),
        tio.ToCanonical(),
        tio.RandomAnisotropy(downsampling=(1.75, 2), axes=flip_axes),
        tio.CopyAffine(channels[0]),
        tio.Resample((1, 1.1, 1.25)),
        tio.RandomFlip(axes=flip_axes, flip_probability=1),
        tio.RandomMotion(),
        tio.RandomGhosting(axes=(0, 1, 2)),
        tio.RandomSpike(),
        tio.RandomNoise(),
        tio.RandomBlur(),
        tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
        tio.RandomBiasField(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
        tio.ZNormalization(),
        tio.HistogramStandardization(landmarks),
        elastic,
        tio.RandomAffine(),
        tio.OneOf({
            tio.RandomAffine(): 3,
            elastic: 1,
        }),
        tio.RemapLabels(remapping=label_remapping, masking_method='Left'),
        tio.RemoveLabels([1, 3]),
        tio.SequentialLabels(),
        tio.Pad(pad_args, padding_mode=3),
        tio.Crop(crop_args),
    ]
    extras = [tio.RandomLabelsToImage(label_key='label')] if labels else []
    return tio.Compose(pipeline + extras)
def test_get_subjects(self):
    """``get_subjects`` should find exactly one subject in a folder on disk.

    A subject folder containing ``ct.nii`` and ``structure.nii`` is written
    under ``tests/test_data/subjects``. The whole ``subjects`` directory is
    removed after the test, whether it passes or fails.
    """
    ct = tio.ScalarImage(tensor=torch.rand(1, 3, 3, 2))
    structure = tio.LabelMap(
        tensor=torch.ones((1, 3, 3, 2), dtype=torch.uint8))
    subject_1_dir = "tests/test_data/subjects/subject_1"
    subjects_dir = os.path.dirname(subject_1_dir)
    # Register cleanup before touching the filesystem so that a failure in
    # saving, in get_subjects or in the assertion still removes the tree.
    self.addCleanup(shutil.rmtree, subjects_dir, ignore_errors=True)
    os.makedirs(subject_1_dir, exist_ok=True)
    ct.save(os.path.join(subject_1_dir, "ct.nii"))
    structure.save(os.path.join(subject_1_dir, "structure.nii"))

    transform = tio.Compose(
        [tio.ToCanonical(), tio.RescaleIntensity(1, (1, 99.0))])
    subject_dataset = get_subjects(subjects_dir,
                                   structures=["structure"],
                                   transform=transform)
    self.assertEqual(len(subject_dataset), 1)
def get_transform(self, channels, is_3d=True, labels=True):
    """Compose a broad chain of ``torchio`` transforms for testing.

    Args:
        channels: Iterable of image names, each given default histogram
            landmarks.
        is_3d: Choose 3D sizes/axes if True, otherwise arguments for images
            whose last spatial dimension is a singleton.
        labels: If True, append ``RandomLabelsToImage`` reading ``'label'``.

    Returns:
        A ``torchio.Compose`` with all transforms in a fixed order.
    """
    if is_3d:
        disp = 1
        cp_args = (9, 21, 30)
        flip_axes = (0, 1, 2)
        swap_patch = (2, 3, 4)
        pad_args = (1, 2, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4)
    else:
        disp = (1, 1, 0.01)
        cp_args = (21, 30, 1)
        flip_axes = (0, 1)
        swap_patch = (3, 4, 1)
        pad_args = (0, 0, 3, 0, 5, 6)
        crop_args = (0, 0, 8, 0, 1, 4)

    # Each channel gets its own landmarks array.
    landmarks = {name: np.linspace(0, 100, 13) for name in channels}
    # Shared between the stand-alone step and the OneOf choice below.
    elastic = torchio.RandomElasticDeformation(max_displacement=disp)

    pipeline = [
        torchio.CropOrPad(cp_args),
        torchio.ToCanonical(),
        torchio.RandomDownsample(downsampling=(1.75, 2), axes=flip_axes),
        torchio.Resample((1, 1.1, 1.25)),
        torchio.RandomFlip(axes=flip_axes, flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(),
        torchio.HistogramStandardization(landmarks),
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({
            torchio.RandomAffine(): 3,
            elastic: 1,
        }),
        torchio.Pad(pad_args, padding_mode=3),
        torchio.Crop(crop_args),
    ]
    extras = [torchio.RandomLabelsToImage(label_key='label')] if labels else []
    return torchio.Compose(pipeline + extras)
def main(input_path, checkpoint_path, output_dir, landmarks_path, batch_size,
         num_workers, resample):
    # Segment every image returned by get_paths(input_path) with the U-Net
    # stored in checkpoint_path, writing one '<name>_seg_cnn.nii*' label map
    # per input into output_dir. Returns 0 (exit code).
    # Comments are used instead of a docstring so that any CLI help text
    # derived from the docstring stays unchanged.
    import torch
    from tqdm import tqdm
    import torchio as tio
    import models
    import datasets

    fps = get_paths(input_path)
    # key must be 'image' as in get_test_transform
    subjects = [tio.Subject(image=tio.ScalarImage(fp)) for fp in fps]
    transform = tio.Compose((
        tio.ToCanonical(),
        datasets.get_test_transform(landmarks_path),
    ))
    if resample:
        transform = tio.Compose((
            tio.Resample(),
            transform,
            # tio.CropOrPad((264, 268, 144)),  # for BITE?
        ))
    dataset = tio.SubjectsDataset(subjects, transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on a GPU be loaded when only the
    # CPU is available (the fallback chosen just above).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = models.get_unet().to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    # Scoped instead of torch.set_grad_enabled(False) so the global autograd
    # state is not left disabled after this function returns.
    with torch.no_grad():
        for batch in tqdm(loader):
            inputs = batch['image'][tio.DATA].float().to(device)
            # Drop channel 0 of the softmax output and threshold the rest.
            seg = model(inputs).softmax(dim=1)[:, 1:].cpu() > 0.5
            for tensor, affine, path in zip(
                    seg,
                    batch['image'][tio.AFFINE],
                    batch['image'][tio.PATH],
                    ):
                image = tio.LabelMap(tensor=tensor, affine=affine.numpy())
                path = Path(path)
                out_path = output_dir / path.name.replace('.nii', '_seg_cnn.nii')
                image.save(out_path)
    return 0
""" import pprint import torch import torchio as tio import matplotlib.pyplot as plt torch.manual_seed(0) batch_size = 4 subject = tio.datasets.FPG() subject.remove_image('seg') subjects = 4 * [subject] transform = tio.Compose(( tio.ToCanonical(), tio.RandomGamma(p=0.75), tio.RandomBlur(p=0.5), tio.RandomFlip(), tio.RescaleIntensity((-1, 1)), )) dataset = tio.SubjectsDataset(subjects, transform=transform) transformed = dataset[0] print('Applied transforms:') # noqa: T001 pprint.pprint(transformed.history) # noqa: T003 print('\nComposed transform to reproduce history:') # noqa: T001 print(transformed.get_composed_history()) # noqa: T001 print('\nComposed transform to invert applied transforms when possible:' ) # noqa: T001, E501
hr=tio.ScalarImage(t2_file), lr_1=tio.ScalarImage(t2_file), lr_2=tio.ScalarImage(t2_file), lr_3=tio.ScalarImage(t2_file), ) subjects.append(subject) print('DHCP Dataset size:', len(subjects), 'subjects') # DATA AUGMENTATION normalization = tio.ZNormalization() spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75) flip = tio.RandomFlip(axes=('LR', ), flip_probability=0.5) tocanonical = tio.ToCanonical() b1 = tio.Blur(std=(0.001, 0.001, 1), include='lr_1') #blur d1 = tio.Resample((0.8, 0.8, 2), include='lr_1') #downsampling u1 = tio.Resample(target='hr', include='lr_1') #upsampling if in_channels == 3: b2 = tio.Blur(std=(0.001, 1, 0.001), include='lr_2') #blur d2 = tio.Resample((0.8, 2, 0.8), include='lr_2') #downsampling u2 = tio.Resample(target='hr', include='lr_2') #upsampling b3 = tio.Blur(std=(1, 0.001, 0.001), include='lr_3') #blur d3 = tio.Resample((2, 0.8, 0.8), include='lr_3') #downsampling u3 = tio.Resample(target='hr', include='lr_3') #upsampling if in_channels == 1:
def __init__(self, **kwargs):
    """Build network, loss, data transforms and helpers from hyperparameters.

    Every setting is read from ``self.hparams`` after
    ``self.save_hyperparameters()``; ``kwargs`` is not referenced directly in
    this method (presumably ``save_hyperparameters`` captures it into
    ``self.hparams`` -- TODO confirm).

    Unsupported ``modelID`` / ``lossID`` values and unsupported
    ``motion_mode`` / ``ds_mode`` / ``is3D`` combinations terminate the
    process via ``sys.exit`` instead of raising an exception.
    """
    super().__init__()
    self.save_hyperparameters()
    # Device for the perceptual loss and the SSIM-based losses below.
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and self.hparams.cuda else "cpu")

    # --- Network selection by modelID ---
    # NOTE(review): modelID == 1 has no branch and falls through to sys.exit.
    if self.hparams.modelID == 0:
        self.net = ResNet(in_channels=self.hparams.in_channels,
                          out_channels=self.hparams.out_channels,
                          res_blocks=self.hparams.model_res_blocks,
                          starting_nfeatures=self.hparams.model_starting_nfeatures,
                          updown_blocks=self.hparams.model_updown_blocks,
                          is_relu_leaky=self.hparams.model_relu_leaky,
                          do_batchnorm=self.hparams.model_do_batchnorm,
                          res_drop_prob=self.hparams.model_drop_prob,
                          is_replicatepad=self.hparams.model_is_replicatepad,
                          out_act=self.hparams.model_out_act,
                          forwardV=self.hparams.model_forwardV,
                          upinterp_algo=self.hparams.model_upinterp_algo,
                          post_interp_convtrans=self.hparams.model_post_interp_convtrans,
                          is3D=self.hparams.is3D)  # TODO think of 2D
        # self.net = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
    elif self.hparams.modelID == 2:
        # Same arguments as ResNet plus the dual-space specific options.
        self.net = DualSpaceResNet(in_channels=self.hparams.in_channels,
                                   out_channels=self.hparams.out_channels,
                                   res_blocks=self.hparams.model_res_blocks,
                                   starting_nfeatures=self.hparams.model_starting_nfeatures,
                                   updown_blocks=self.hparams.model_updown_blocks,
                                   is_relu_leaky=self.hparams.model_relu_leaky,
                                   do_batchnorm=self.hparams.model_do_batchnorm,
                                   res_drop_prob=self.hparams.model_drop_prob,
                                   is_replicatepad=self.hparams.model_is_replicatepad,
                                   out_act=self.hparams.model_out_act,
                                   forwardV=self.hparams.model_forwardV,
                                   upinterp_algo=self.hparams.model_upinterp_algo,
                                   post_interp_convtrans=self.hparams.model_post_interp_convtrans,
                                   is3D=self.hparams.is3D,
                                   connect_mode=self.hparams.model_dspace_connect_mode,
                                   inner_norm_ksp=self.hparams.model_inner_norm_ksp)
    elif self.hparams.modelID == 3:  # Primal-Dual Network, complex Primal
        self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                                     use_original_block=True,
                                     use_original_init=True,
                                     use_complex_primal=True,
                                     g_normtype="magmax",
                                     transform="Fourier",
                                     return_abs=True)
    elif self.hparams.modelID == 4:  # Primal-Dual Network, absolute Primal
        self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                                     use_original_block=True,
                                     use_original_init=True,
                                     use_complex_primal=False,
                                     g_normtype="magmax",
                                     transform="Fourier")
    elif self.hparams.modelID == 5:  # Primal-Dual UNet Network, absolute Primal
        self.net = PrimalDualNetwork(n_primary=4, n_dual=5, n_iterations=2,
                                     use_original_block=False,
                                     use_original_init=False,
                                     use_complex_primal=False,
                                     g_normtype="magmax",
                                     transform="Fourier")
    elif self.hparams.modelID == 6:  # Primal-Dual Network v2 (no residual), complex Primal
        self.net = PrimalDualNetworkNoResidue(n_primary=5, n_dual=5, n_iterations=10,
                                              use_original_block=True,
                                              use_original_init=True,
                                              use_complex_primal=True,
                                              residuals=False,
                                              g_normtype="magmax",
                                              transform="Fourier",
                                              return_abs=True)
    else:  # TODO: other models
        # NOTE(review): this message is stale -- the Primal-Dual networks
        # (modelID 3-6) are implemented too.
        sys.exit("Only ReconResNet and DualSpaceResNet have been implemented so far in ReconEngine")

    # Optionally warm-start the network from a checkpoint's 'state_dict'.
    if bool(self.hparams.preweights_path):
        # NOTE(review): typo "loding" in the printed message.
        print("Pre-weights found, loding...")
        chk = torch.load(self.hparams.preweights_path, map_location='cpu')
        self.net.load_state_dict(chk['state_dict'])

    # --- Loss selection by lossID ---
    if self.hparams.lossID == 0:
        if self.hparams.in_channels != 1 or self.hparams.out_channels != 1:
            sys.exit(
                "Perceptual Loss used here only works for 1 channel input and output")
        self.loss = PerceptualLoss(device=device, loss_model="unet3Dds", resize=None,
                                   loss_type=self.hparams.ploss_type,
                                   n_level=self.hparams.ploss_level)  # TODO thinkof 2D
    elif self.hparams.lossID == 1:
        self.loss = nn.L1Loss(reduction='mean')
    elif self.hparams.lossID == 2:
        self.loss = MS_SSIM(channel=self.hparams.out_channels, data_range=1,
                            spatial_dims=3 if self.hparams.is3D else 2,
                            nonnegative_ssim=False).to(device)
    elif self.hparams.lossID == 3:
        self.loss = SSIM(channel=self.hparams.out_channels, data_range=1,
                         spatial_dims=3 if self.hparams.is3D else 2,
                         nonnegative_ssim=False).to(device)
    else:
        sys.exit("Invalid Loss ID")

    self.dataspace = DataSpaceHandler(**self.hparams)

    # ds_mode selects the transform/augmentation modules: 0 -> tio*, 1 -> pyt*.
    # NOTE(review): any other ds_mode leaves `trans`/`augs` unbound, and the
    # unconditional `trans.IntensityNorm(...)` below then raises NameError.
    if self.hparams.ds_mode == 0:
        trans = tioTransforms
        augs = tioAugmentations
    elif self.hparams.ds_mode == 1:
        trans = pytTransforms
        augs = pytAugmentations

    # TODO parameterised everything
    self.init_transforms = []
    self.aug_transforms = []
    self.transforms = []
    if self.hparams.ds_mode == 0 and self.hparams.cannonicalResample:  # Only applicable for TorchIO
        self.init_transforms += [tio.ToCanonical(), tio.Resample('gt')]
    if self.hparams.ds_mode == 0 and self.hparams.forceNormAffine:  # Only applicable for TorchIO
        self.init_transforms += [trans.ForceAffine()]
    if self.hparams.croppad and self.hparams.ds_mode == 1:
        self.init_transforms += [
            trans.CropOrPad(size=self.hparams.input_shape)]
    self.init_transforms += [trans.IntensityNorm(type=self.hparams.norm_type,
                                                 return_meta=self.hparams.motion_return_meta)]
    # dataspace_transforms = self.dataspace.getTransforms() #TODO: dataspace transforms are not in use
    # self.init_transforms += dataspace_transforms
    if bool(self.hparams.random_crop) and self.hparams.ds_mode == 1:
        self.aug_transforms += [augs.RandomCrop(
            size=self.hparams.random_crop, p=self.hparams.p_random_crop)]
    if self.hparams.p_contrast_augment > 0:
        self.aug_transforms += [augs.getContrastAugs(
            p=self.hparams.p_contrast_augment)]

    # If the task is MoCo (taskID 1) and pre-corrupted volumes are not
    # supplied, motion is simulated on the fly with one of these transforms.
    if self.hparams.taskID == 1 and not bool(self.hparams.train_path_inp):
        if self.hparams.motion_mode == 0 and self.hparams.ds_mode == 0:
            # Every hparam whose key starts with 'motionmg', with the
            # 'motionmg_' prefix stripped, is forwarded as a keyword argument.
            motion_params = {k.split('motionmg_')[
                1]: v for k, v in self.hparams.items() if k.startswith('motionmg')}
            self.transforms += [tioMotion.RandomMotionGhostingFast(
                **motion_params), trans.IntensityNorm()]
        elif self.hparams.motion_mode == 1 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
            self.transforms += [pytMotion.Motion2Dv0(
                sigma_range=self.hparams.motion_sigma_range,
                n_threads=self.hparams.motion_n_threads,
                p=self.hparams.motion_p,
                return_meta=self.hparams.motion_return_meta)]
        elif self.hparams.motion_mode == 2 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
            self.transforms += [pytMotion.Motion2Dv1(
                sigma_range=self.hparams.motion_sigma_range,
                n_threads=self.hparams.motion_n_threads,
                restore_original=self.hparams.motion_restore_original,
                p=self.hparams.motion_p,
                return_meta=self.hparams.motion_return_meta)]
        else:
            sys.exit(
                "Error: invalid motion_mode, ds_mode, is3D combo. Please double check!")

    # Optional MATLAB .mat metadata, also handed to DataConsistency below.
    self.static_metamat = sio.loadmat(self.hparams.static_metamat_file) if bool(
        self.hparams.static_metamat_file) else None
    if self.hparams.taskID == 0 and self.hparams.use_datacon:
        self.datacon = DataConsistency(
            isRadial=self.hparams.is_radial, metadict=self.static_metamat)
    else:
        self.datacon = None

    # For 2D models the last entry of input_shape is dropped.
    input_shape = self.hparams.input_shape if self.hparams.is3D else self.hparams.input_shape[
        :-1]
    self.example_input_array = torch.empty(
        self.hparams.batch_size, self.hparams.in_channels, *input_shape).float()
    self.saver = ResSaver(
        self.hparams.res_path, save_inp=self.hparams.save_inp, do_norm=self.hparams.do_savenorm)
def main(
        input_path,
        parcellation_path,
        output_image_path,
        output_label_path,
        min_volume,
        max_volume,
        volumes_path,
        ):
    """Console script for resector."""
    # The docstring above is kept verbatim since it may double as CLI help.
    # Flow: make sure the per-hemisphere gray-matter and resectable masks and
    # the noise image exist next to the input (creating missing ones), then
    # reorient the input, apply RandomResection and save image and label.
    import torchio
    import resector

    input_path = Path(input_path)
    output_dir = input_path.parent
    stem = input_path.name.split('.nii')[0]  # assume it's a .nii file

    gm_paths, resectable_paths = [], []
    for hemisphere in ('left', 'right'):
        gm_path = output_dir / f'{stem}_gray_matter_{hemisphere}_seg.nii.gz'
        gm_paths.append(gm_path)
        if not gm_path.is_file():
            gray_matter = resector.parcellation.get_gray_matter_mask(
                parcellation_path, hemisphere)
            resector.io.write(gray_matter, gm_path)

        resectable_path = output_dir / f'{stem}_resectable_{hemisphere}_seg.nii.gz'
        resectable_paths.append(resectable_path)
        if not resectable_path.is_file():
            resectable_mask = resector.parcellation.get_resectable_hemisphere_mask(
                parcellation_path,
                hemisphere,
            )
            resector.io.write(resectable_mask, resectable_path)

    noise_path = output_dir / f'{stem}_noise.nii.gz'
    if not noise_path.is_file():
        resector.parcellation.make_noise_image(
            input_path,
            parcellation_path,
            noise_path,
        )

    if volumes_path is None:
        resection_kwargs = dict(volumes_range=(min_volume, max_volume))
    else:
        import pandas as pd
        resection_kwargs = dict(volumes=pd.read_csv(volumes_path).Volume.values)

    transform = torchio.Compose((
        torchio.ToCanonical(),
        resector.RandomResection(**resection_kwargs),
    ))
    resectable_left, resectable_right = resectable_paths
    gray_matter_left, gray_matter_right = gm_paths
    subject = torchio.Subject(
        image=torchio.ScalarImage(input_path),
        resection_resectable_left=torchio.LabelMap(resectable_left),
        resection_resectable_right=torchio.LabelMap(resectable_right),
        resection_gray_matter_left=torchio.LabelMap(gray_matter_left),
        resection_gray_matter_right=torchio.LabelMap(gray_matter_right),
        resection_noise=torchio.ScalarImage(noise_path),
    )
    result = transform(subject)
    result['image'].save(output_image_path)
    result['label'].save(output_label_path)
    return 0
def colorize(self, label_map: np.ndarray) -> np.ndarray: rgb = np.stack(3 * [label_map], axis=-1) for label in np.unique(label_map): mask = label_map == label color = self.get_color(label) rgb[mask] = color return rgb color_table = ColorTable('colormap.txt') fpg = one_subject #tio.dataset.FPG() print('Sample subject:', fpg) show_fpg(fpg) to_ras = tio.ToCanonical() fpg_ras = to_ras(fpg) print('Old orientation:', fpg.mri.orientation) print('New orientation:', fpg_ras.mri.orientation) show_fpg(fpg_ras) print(fpg_ras.mri) print(fpg_ras.heart) #Another handy use for the Resample transform is to apply a precomputed transformation to a standard space, #such as the MNI space. The FPG dataset includes this transform: np.set_printoptions(precision=2, suppress=True) fpg.mri.affine #.numpy() # To maybe CropOrPad
def main(
        input_path,
        parcellation_path,
        output_image_path,
        output_label_path,
        seed,
        min_volume,
        max_volume,
        volumes_path,
        simplex_path,
        std_blur,
        shape,
        texture,
        center_ras,
        wm_lesion,
        clot,
        verbose,
        debug_dir,
        cleanup,
        ):
    # Simulate one resection on input_path and save the resected image and
    # its label map. Returns 0 on success; any exception propagates.
    # Auxiliary images from ensure_images are deleted afterwards (success or
    # failure) when `cleanup` is set and they did not exist beforehand.
    # Comments are used instead of a docstring so that any CLI help text
    # derived from the docstring stays unchanged.
    import torchio as tio
    import resector

    if seed is not None:
        import torch
        torch.manual_seed(seed)
    if debug_dir is not None:
        resector.io.debug_dir = Path(debug_dir).expanduser().absolute()

    resectable_paths, gm_paths, noise_path, existed = ensure_images(
        input_path,
        parcellation_path,
    )
    # No except clause: errors propagate unchanged to the caller; the
    # finally block is only responsible for removing temporary images.
    try:
        if volumes_path is not None:
            import pandas as pd
            df = pd.read_csv(volumes_path)
            volumes = df.Volume.values
            kwargs = dict(volumes=volumes)
        else:
            kwargs = dict(volumes_range=(min_volume, max_volume))
        if std_blur is not None:
            kwargs['sigmas_range'] = std_blur, std_blur
        kwargs.update(
            simplex_path=simplex_path,
            wm_lesion_p=wm_lesion,
            clot_p=clot,
            verbose=verbose,
            shape=shape,
            texture=texture,
            center_ras=center_ras,
        )
        transform = tio.Compose((
            tio.ToCanonical(),
            resector.RandomResection(**kwargs),
        ))
        subject = tio.Subject(
            image=tio.ScalarImage(input_path),
            resection_resectable_left=tio.LabelMap(resectable_paths[0]),
            resection_resectable_right=tio.LabelMap(resectable_paths[1]),
            resection_gray_matter_left=tio.LabelMap(gm_paths[0]),
            resection_gray_matter_right=tio.LabelMap(gm_paths[1]),
            resection_noise=tio.ScalarImage(noise_path),
        )
        with resector.timer('RandomResection', verbose):
            transformed = transform(subject)
        with resector.timer('Saving images', verbose):
            transformed['image'].save(output_image_path)
            transformed['label'].save(output_label_path)
    finally:
        if not existed and cleanup:
            with resector.timer('Cleaning up', verbose):
                for path in (*resectable_paths, *gm_paths, noise_path):
                    path.unlink()
    return 0