Example 1
def test_transforms(self):
    landmarks_dict = dict(
        t1=np.linspace(0, 100, 13),
        t2=np.linspace(0, 100, 13),
    )
    elastic = torchio.RandomElasticDeformation(max_displacement=1)
    transforms = (
        torchio.CropOrPad((9, 21, 30)),
        torchio.ToCanonical(),
        torchio.Resample((1, 1.1, 1.25)),
        torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=2, num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(masking_method='label'),
        torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({
            torchio.RandomAffine(): 3,
            elastic: 1,
        }),
        torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
        torchio.Crop((3, 2, 8, 0, 1, 4)),
    )
    transform = torchio.Compose(transforms)
    transform(self.sample)
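A minimal sketch (not part of the original test) of a subject the composed transform above can run on; the image names match landmarks_dict and the 'label' masking key, and the shape is illustrative:

import torch
import torchio

# Hypothetical stand-in for self.sample used by the test above
sample = torchio.Subject(
    t1=torchio.ScalarImage(tensor=torch.rand(1, 20, 30, 40)),
    t2=torchio.ScalarImage(tensor=torch.rand(1, 20, 30, 40)),
    label=torchio.LabelMap(tensor=torch.randint(0, 2, (1, 20, 30, 40))),
)
transformed = transform(sample)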
Example 2
def show_fpg(
        subject,
        to_ras=False,
        stretch_slices=True,
        indices=None,
        intensity_name='mri',
        parcellation=True,
):
    subject = tio.ToCanonical()(subject) if to_ras else subject

    def flip(x):
        return np.rot90(x)

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    if indices is None:
        half_shape = torch.Tensor(subject.spatial_shape) // 2
        i, j, k = half_shape.long()
        i += 1
        j += 1
        k += 7  # use a better slice
    else:
        i, j, k = indices
    bounds_x, bounds_y, bounds_z = get_bounds(subject.mri)

    orientation = ''.join(subject.mri.orientation)
    if orientation != 'RAS':
        import warnings
        warnings.warn(f'Image orientation should be RAS+, not {orientation}+')

    kwargs = dict(cmap='gray', interpolation='none')
    data = subject[intensity_name].data
    slices = data[0, i], data[0, :, j], data[0, ..., k]
    if stretch_slices:
        slices = [stretch(s.numpy()) for s in slices]
    sag, cor, axi = slices

    axes[0, 0].imshow(flip(sag), extent=bounds_y + bounds_z, **kwargs)
    axes[0, 1].imshow(flip(cor), extent=bounds_x + bounds_z, **kwargs)
    axes[0, 2].imshow(flip(axi), extent=bounds_x + bounds_y, **kwargs)

    kwargs = dict(interpolation='none')
    data = subject.heart.data
    slices = data[0, i], data[0, :, j], data[0, ..., k]
    if parcellation:
        sag, cor, axi = [
            color_table.colorize(s.long()) if s.max() > 1 else s
            for s in slices
        ]
    else:
        sag, cor, axi = slices
    axes[1, 0].imshow(flip(sag), extent=bounds_y + bounds_z, **kwargs)
    axes[1, 1].imshow(flip(cor), extent=bounds_x + bounds_z, **kwargs)
    axes[1, 2].imshow(flip(axi), extent=bounds_x + bounds_y, **kwargs)

    plt.tight_layout()
Example 3
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    if mni_transform_path is None:
        image = tio.ScalarImage(input_path)
    else:
        affine = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: affine})
    subject = tio.Subject({IMAGE_NAME: image})
    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    hist_std = tio.HistogramStandardization({IMAGE_NAME: landmarks})
    preprocess_transforms = [
        tio.ToCanonical(),
        hist_std,
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]
    zooms = nib.load(input_path).header.get_zooms()
    pixdim = np.array(zooms)
    diff_to_1_iso = np.abs(pixdim - 1)
    if np.any(diff_to_1_iso > tolerance) or mni_transform_path is not None:
        kwargs = {'image_interpolation': interpolation}
        if mni_transform_path is not None:
            kwargs['pre_affine_name'] = TO_MNI
            kwargs['target'] = tio.datasets.Colin27().t1.path
        resample_transform = tio.Resample(**kwargs)
        preprocess_transforms.append(resample_transform)
    preprocess_transforms.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess_transform = tio.Compose(preprocess_transforms)
    no_aug_dataset = tio.SubjectsDataset([subject],
                                         transform=preprocess_transform)

    aug_subjects = tta_iterations * [subject]
    if not aug_subjects:
        return no_aug_dataset
    augment_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    aug_dataset = tio.SubjectsDataset(aug_subjects,
                                      transform=augment_transform)
    dataset = torch.utils.data.ConcatDataset((no_aug_dataset, aug_dataset))
    return dataset
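A hypothetical usage sketch for get_dataset: the path and batch size are illustrative, and IMAGE_NAME is the same module-level key used to build the subject above.

dataset = get_dataset('t1.nii.gz', tta_iterations=4)  # illustrative path
loader = torch.utils.data.DataLoader(dataset, batch_size=2)
for batch in loader:
    # one preprocessed subject plus four augmented copies overall
    inputs = batch[IMAGE_NAME][tio.DATA]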
Example 4
def get_transform(self, channels, is_3d=True, labels=True):
    landmarks_dict = {
        channel: np.linspace(0, 100, 13)
        for channel in channels
    }
    disp = 1 if is_3d else (1, 1, 0.01)
    elastic = tio.RandomElasticDeformation(max_displacement=disp)
    cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
    resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
    flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
    swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
    pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
    crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
    remapping = {1: 2, 2: 1, 3: 20, 4: 25}
    transforms = [
        tio.CropOrPad(cp_args),
        tio.EnsureShapeMultiple(2, method='crop'),
        tio.Resize(resize_args),
        tio.ToCanonical(),
        tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
        tio.CopyAffine(channels[0]),
        tio.Resample((1, 1.1, 1.25)),
        tio.RandomFlip(axes=flip_axes, flip_probability=1),
        tio.RandomMotion(),
        tio.RandomGhosting(axes=(0, 1, 2)),
        tio.RandomSpike(),
        tio.RandomNoise(),
        tio.RandomBlur(),
        tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
        tio.RandomBiasField(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
        tio.ZNormalization(),
        tio.HistogramStandardization(landmarks_dict),
        elastic,
        tio.RandomAffine(),
        tio.OneOf({
            tio.RandomAffine(): 3,
            elastic: 1,
        }),
        tio.RemapLabels(remapping=remapping, masking_method='Left'),
        tio.RemoveLabels([1, 3]),
        tio.SequentialLabels(),
        tio.Pad(pad_args, padding_mode=3),
        tio.Crop(crop_args),
    ]
    if labels:
        transforms.append(tio.RandomLabelsToImage(label_key='label'))
    return tio.Compose(transforms)
Example 5
    def test_get_subjects(self):
        ct = tio.ScalarImage(tensor=torch.rand(1, 3, 3, 2))
        structure = tio.LabelMap(
            tensor=torch.ones((1, 3, 3, 2), dtype=torch.uint8))

        subject_1_dir = "tests/test_data/subjects/subject_1"

        os.makedirs(subject_1_dir, exist_ok=True)
        ct.save(os.path.join(subject_1_dir, "ct.nii"))
        structure.save(os.path.join(subject_1_dir, "structure.nii"))
        transform = tio.Compose(
            [tio.ToCanonical(),
             tio.RescaleIntensity(1, (1, 99.0))])
        subject_dataset = get_subjects(os.path.dirname(subject_1_dir),
                                       structures=["structure"],
                                       transform=transform)
        self.assertEqual(len(subject_dataset), 1)
        shutil.rmtree(os.path.dirname(subject_1_dir), ignore_errors=True)
Example 6
def get_transform(self, channels, is_3d=True, labels=True):
    landmarks_dict = {
        channel: np.linspace(0, 100, 13)
        for channel in channels
    }
    disp = 1 if is_3d else (1, 1, 0.01)
    elastic = torchio.RandomElasticDeformation(max_displacement=disp)
    cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
    flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
    swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
    pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
    crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
    transforms = [
        torchio.CropOrPad(cp_args),
        torchio.ToCanonical(),
        # RandomDownsample was later renamed RandomAnisotropy (cf. Example 4)
        torchio.RandomDownsample(downsampling=(1.75, 2),
                                 axes=axes_downsample),
        torchio.Resample((1, 1.1, 1.25)),
        torchio.RandomFlip(axes=flip_axes, flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(),
        torchio.HistogramStandardization(landmarks_dict),
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({
            torchio.RandomAffine(): 3,
            elastic: 1,
        }),
        torchio.Pad(pad_args, padding_mode=3),
        torchio.Crop(crop_args),
    ]
    if labels:
        transforms.append(torchio.RandomLabelsToImage(label_key='label'))
    return torchio.Compose(transforms)
Example 7
def main(
    input_path,
    checkpoint_path,
    output_dir,
    landmarks_path,
    batch_size,
    num_workers,
    resample,
):
    import torch
    from tqdm import tqdm
    import torchio as tio
    import models
    import datasets

    fps = get_paths(input_path)
    subjects = [tio.Subject(image=tio.ScalarImage(fp)) for fp in fps]  # key must be 'image' as in get_test_transform
    transform = tio.Compose((
        tio.ToCanonical(),
        datasets.get_test_transform(landmarks_path),
    ))
    if resample:
        transform = tio.Compose((
            tio.Resample(),
            transform,
            # tio.CropOrPad((264, 268, 144)),  # for BITE?
        ))
    dataset = tio.SubjectsDataset(subjects, transform)
    checkpoint = torch.load(checkpoint_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = models.get_unet().to(device)
    model.load_state_dict(checkpoint['model'])
    output_dir = Path(output_dir)
    model.eval()
    torch.set_grad_enabled(False)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    output_dir.mkdir(exist_ok=True, parents=True)
    for batch in tqdm(loader):
        inputs = batch['image'][tio.DATA].float().to(device)
        seg = model(inputs).softmax(dim=1)[:, 1:].cpu() > 0.5
        for tensor, affine, path in zip(seg, batch['image'][tio.AFFINE], batch['image'][tio.PATH]):
            image = tio.LabelMap(tensor=tensor, affine=affine.numpy())
            path = Path(path)
            out_path = output_dir / path.name.replace('.nii', '_seg_cnn.nii')
            image.save(out_path)
    return 0
Example 8
"""

import pprint
import torch
import torchio as tio
import matplotlib.pyplot as plt

torch.manual_seed(0)

batch_size = 4
subject = tio.datasets.FPG()
subject.remove_image('seg')
subjects = 4 * [subject]

transform = tio.Compose((
    tio.ToCanonical(),
    tio.RandomGamma(p=0.75),
    tio.RandomBlur(p=0.5),
    tio.RandomFlip(),
    tio.RescaleIntensity((-1, 1)),
))

dataset = tio.SubjectsDataset(subjects, transform=transform)

transformed = dataset[0]
print('Applied transforms:')  # noqa: T001
pprint.pprint(transformed.history)  # noqa: T003
print('\nComposed transform to reproduce history:')  # noqa: T001
print(transformed.get_composed_history())  # noqa: T001
print('\nComposed transform to invert applied transforms when possible:')  # noqa: T001, E501
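A hedged completion sketch: recent TorchIO versions expose the inverse through Subject.get_inverse_transform and Subject.apply_inverse_transform.

print(transformed.get_inverse_transform())  # noqa: T001
# Applying it would undo the invertible (spatial) transforms:
# restored = transformed.apply_inverse_transform()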
Example 9
            hr=tio.ScalarImage(t2_file),
            lr_1=tio.ScalarImage(t2_file),
            lr_2=tio.ScalarImage(t2_file),
            lr_3=tio.ScalarImage(t2_file),
        )

    subjects.append(subject)

print('DHCP Dataset size:', len(subjects), 'subjects')

# DATA AUGMENTATION
normalization = tio.ZNormalization()
spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75)
flip = tio.RandomFlip(axes=('LR',), flip_probability=0.5)

tocanonical = tio.ToCanonical()

b1 = tio.Blur(std=(0.001, 0.001, 1), include='lr_1')  # blur
d1 = tio.Resample((0.8, 0.8, 2), include='lr_1')  # downsampling
u1 = tio.Resample(target='hr', include='lr_1')  # upsampling

if in_channels == 3:
    b2 = tio.Blur(std=(0.001, 1, 0.001), include='lr_2')  # blur
    d2 = tio.Resample((0.8, 2, 0.8), include='lr_2')  # downsampling
    u2 = tio.Resample(target='hr', include='lr_2')  # upsampling

    b3 = tio.Blur(std=(1, 0.001, 0.001), include='lr_3')  # blur
    d3 = tio.Resample((2, 0.8, 0.8), include='lr_3')  # downsampling
    u3 = tio.Resample(target='hr', include='lr_3')  # upsampling

if in_channels == 1:
Example 10
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.hparams.cuda else "cpu")

        if self.hparams.modelID == 0:
            self.net = ResNet(
                in_channels=self.hparams.in_channels,
                out_channels=self.hparams.out_channels,
                res_blocks=self.hparams.model_res_blocks,
                starting_nfeatures=self.hparams.model_starting_nfeatures,
                updown_blocks=self.hparams.model_updown_blocks,
                is_relu_leaky=self.hparams.model_relu_leaky,
                do_batchnorm=self.hparams.model_do_batchnorm,
                res_drop_prob=self.hparams.model_drop_prob,
                is_replicatepad=self.hparams.model_is_replicatepad,
                out_act=self.hparams.model_out_act,
                forwardV=self.hparams.model_forwardV,
                upinterp_algo=self.hparams.model_upinterp_algo,
                post_interp_convtrans=self.hparams.model_post_interp_convtrans,
                is3D=self.hparams.is3D,
            )  # TODO: think of 2D
            # self.net = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
        elif self.hparams.modelID == 2:
            self.net = DualSpaceResNet(
                in_channels=self.hparams.in_channels,
                out_channels=self.hparams.out_channels,
                res_blocks=self.hparams.model_res_blocks,
                starting_nfeatures=self.hparams.model_starting_nfeatures,
                updown_blocks=self.hparams.model_updown_blocks,
                is_relu_leaky=self.hparams.model_relu_leaky,
                do_batchnorm=self.hparams.model_do_batchnorm,
                res_drop_prob=self.hparams.model_drop_prob,
                is_replicatepad=self.hparams.model_is_replicatepad,
                out_act=self.hparams.model_out_act,
                forwardV=self.hparams.model_forwardV,
                upinterp_algo=self.hparams.model_upinterp_algo,
                post_interp_convtrans=self.hparams.model_post_interp_convtrans,
                is3D=self.hparams.is3D,
                connect_mode=self.hparams.model_dspace_connect_mode,
                inner_norm_ksp=self.hparams.model_inner_norm_ksp,
            )
        elif self.hparams.modelID == 3:  # Primal-Dual Network, complex Primal
            self.net = PrimalDualNetwork(
                n_primary=5, n_dual=5, n_iterations=10,
                use_original_block=True,
                use_original_init=True,
                use_complex_primal=True,
                g_normtype="magmax",
                transform="Fourier",
                return_abs=True,
            )
        elif self.hparams.modelID == 4:  # Primal-Dual Network, absolute Primal
            self.net = PrimalDualNetwork(
                n_primary=5, n_dual=5, n_iterations=10,
                use_original_block=True,
                use_original_init=True,
                use_complex_primal=False,
                g_normtype="magmax",
                transform="Fourier",
            )
        elif self.hparams.modelID == 5:  # Primal-Dual UNet Network, absolute Primal
            self.net = PrimalDualNetwork(
                n_primary=4, n_dual=5, n_iterations=2,
                use_original_block=False,
                use_original_init=False,
                use_complex_primal=False,
                g_normtype="magmax",
                transform="Fourier",
            )
        elif self.hparams.modelID == 6:  # Primal-Dual Network v2 (no residual), complex Primal
            self.net = PrimalDualNetworkNoResidue(
                n_primary=5, n_dual=5, n_iterations=10,
                use_original_block=True,
                use_original_init=True,
                use_complex_primal=True,
                residuals=False,
                g_normtype="magmax",
                transform="Fourier",
                return_abs=True,
            )

        else:
            # TODO: other models
            sys.exit("Only ReconResNet and DualSpaceResNet have been implemented so far in ReconEngine")

        if bool(self.hparams.preweights_path):
            print("Pre-weights found, loding...")
            chk = torch.load(self.hparams.preweights_path, map_location='cpu')
            self.net.load_state_dict(chk['state_dict'])

        if self.hparams.lossID == 0:
            if self.hparams.in_channels != 1 or self.hparams.out_channels != 1:
                sys.exit(
                    "Perceptual Loss used here only works for 1 channel input and output")
            self.loss = PerceptualLoss(device=device, loss_model="unet3Dds", resize=None,
                                       loss_type=self.hparams.ploss_type, n_level=self.hparams.ploss_level)  # TODO thinkof 2D
        elif self.hparams.lossID == 1:
            self.loss = nn.L1Loss(reduction='mean')
        elif self.hparams.lossID == 2:
            self.loss = MS_SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        elif self.hparams.lossID == 3:
            self.loss = SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        else:
            sys.exit("Invalid Loss ID")

        self.dataspace = DataSpaceHandler(**self.hparams)

        if self.hparams.ds_mode == 0:
            trans = tioTransforms
            augs = tioAugmentations
        elif self.hparams.ds_mode == 1:
            trans = pytTransforms
            augs = pytAugmentations

        # TODO parameterised everything
        self.init_transforms = []
        self.aug_transforms = []
        self.transforms = []
        if self.hparams.ds_mode == 0 and self.hparams.cannonicalResample:  # Only applicable for TorchIO
            self.init_transforms += [tio.ToCanonical(), tio.Resample('gt')]
        if self.hparams.ds_mode == 0 and self.hparams.forceNormAffine:  # Only applicable for TorchIO
            self.init_transforms += [trans.ForceAffine()]
        if self.hparams.croppad and self.hparams.ds_mode == 1:
            self.init_transforms += [
                trans.CropOrPad(size=self.hparams.input_shape)]
        self.init_transforms += [trans.IntensityNorm(type=self.hparams.norm_type, return_meta=self.hparams.motion_return_meta)]
        # dataspace_transforms = self.dataspace.getTransforms() #TODO: dataspace transforms are not in use
        # self.init_transforms += dataspace_transforms
        if bool(self.hparams.random_crop) and self.hparams.ds_mode == 1:
            self.aug_transforms += [augs.RandomCrop(
                size=self.hparams.random_crop, p=self.hparams.p_random_crop)]
        if self.hparams.p_contrast_augment > 0:
            self.aug_transforms += [augs.getContrastAugs(
                p=self.hparams.p_contrast_augment)]
        # if the task is MoCo and pre-corrupted volumes are not supplied
        if self.hparams.taskID == 1 and not bool(self.hparams.train_path_inp):
            if self.hparams.motion_mode == 0 and self.hparams.ds_mode == 0:
                motion_params = {
                    k.split('motionmg_')[1]: v
                    for k, v in self.hparams.items()
                    if k.startswith('motionmg')
                }
                self.transforms += [
                    tioMotion.RandomMotionGhostingFast(**motion_params),
                    trans.IntensityNorm(),
                ]
            elif self.hparams.motion_mode == 1 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv0(
                    sigma_range=self.hparams.motion_sigma_range,
                    n_threads=self.hparams.motion_n_threads,
                    p=self.hparams.motion_p,
                    return_meta=self.hparams.motion_return_meta,
                )]
            elif self.hparams.motion_mode == 2 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv1(
                    sigma_range=self.hparams.motion_sigma_range,
                    n_threads=self.hparams.motion_n_threads,
                    restore_original=self.hparams.motion_restore_original,
                    p=self.hparams.motion_p,
                    return_meta=self.hparams.motion_return_meta,
                )]
            else:
                sys.exit(
                    "Error: invalid motion_mode, ds_mode, is3D combo. Please double check!")

        self.static_metamat = (
            sio.loadmat(self.hparams.static_metamat_file)
            if bool(self.hparams.static_metamat_file) else None
        )
        if self.hparams.taskID == 0 and self.hparams.use_datacon:
            self.datacon = DataConsistency(
                isRadial=self.hparams.is_radial, metadict=self.static_metamat)
        else:
            self.datacon = None

        input_shape = (self.hparams.input_shape if self.hparams.is3D
                       else self.hparams.input_shape[:-1])
        self.example_input_array = torch.empty(
            self.hparams.batch_size, self.hparams.in_channels, *input_shape).float()
        self.saver = ResSaver(
            self.hparams.res_path, save_inp=self.hparams.save_inp, do_norm=self.hparams.do_savenorm)
Example 11
def main(
    input_path,
    parcellation_path,
    output_image_path,
    output_label_path,
    min_volume,
    max_volume,
    volumes_path,
):
    """Console script for resector."""
    import torchio
    import resector
    hemispheres = 'left', 'right'
    input_path = Path(input_path)
    output_dir = input_path.parent
    stem = input_path.name.split('.nii')[0]  # assume it's a .nii file

    gm_paths = []
    resectable_paths = []
    for hemisphere in hemispheres:
        dst = output_dir / f'{stem}_gray_matter_{hemisphere}_seg.nii.gz'
        gm_paths.append(dst)
        if not dst.is_file():
            gm = resector.parcellation.get_gray_matter_mask(
                parcellation_path, hemisphere)
            resector.io.write(gm, dst)
        dst = output_dir / f'{stem}_resectable_{hemisphere}_seg.nii.gz'
        resectable_paths.append(dst)
        if not dst.is_file():
            resectable = resector.parcellation.get_resectable_hemisphere_mask(
                parcellation_path,
                hemisphere,
            )
            resector.io.write(resectable, dst)
    noise_path = output_dir / f'{stem}_noise.nii.gz'
    if not noise_path.is_file():
        resector.parcellation.make_noise_image(
            input_path,
            parcellation_path,
            noise_path,
        )

    if volumes_path is not None:
        import pandas as pd
        df = pd.read_csv(volumes_path)
        volumes = df.Volume.values
        kwargs = dict(volumes=volumes)
    else:
        kwargs = dict(volumes_range=(min_volume, max_volume))

    transform = torchio.Compose((
        torchio.ToCanonical(),
        resector.RandomResection(**kwargs),
    ))
    subject = torchio.Subject(
        image=torchio.ScalarImage(input_path),
        resection_resectable_left=torchio.LabelMap(resectable_paths[0]),
        resection_resectable_right=torchio.LabelMap(resectable_paths[1]),
        resection_gray_matter_left=torchio.LabelMap(gm_paths[0]),
        resection_gray_matter_right=torchio.LabelMap(gm_paths[1]),
        resection_noise=torchio.ScalarImage(noise_path),
    )
    transformed = transform(subject)
    transformed['image'].save(output_image_path)
    transformed['label'].save(output_label_path)
    return 0
Example 12
    def colorize(self, label_map: np.ndarray) -> np.ndarray:
        rgb = np.stack(3 * [label_map], axis=-1)
        for label in np.unique(label_map):
            mask = label_map == label
            color = self.get_color(label)
            rgb[mask] = color
        return rgb

color_table = ColorTable('colormap.txt')


fpg = one_subject  # tio.datasets.FPG()
print('Sample subject:', fpg)
show_fpg(fpg)

to_ras = tio.ToCanonical()
fpg_ras = to_ras(fpg)
print('Old orientation:', fpg.mri.orientation)
print('New orientation:', fpg_ras.mri.orientation)
show_fpg(fpg_ras)

print(fpg_ras.mri)
print(fpg_ras.heart)

# Another handy use for the Resample transform is to apply a precomputed
# transformation to a standard space, such as the MNI space. The FPG dataset
# includes this transform:

np.set_printoptions(precision=2, suppress=True)
print(fpg.mri.affine)
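
# A hedged sketch, mirroring Example 3, of resampling into MNI space with a
# precomputed affine; the matrix path and the 'to_mni' key are illustrative.
to_mni = tio.io.read_matrix('t1_to_mni.tfm')  # hypothetical path
mri_with_matrix = tio.ScalarImage(fpg.mri.path, to_mni=to_mni)
resample_to_mni = tio.Resample(
    target=tio.datasets.Colin27().t1.path,  # MNI template, as in Example 3
    pre_affine_name='to_mni',
)
fpg_mni = resample_to_mni(tio.Subject(mri=mri_with_matrix))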

# TODO: maybe apply CropOrPad
Example 13
def main(
    input_path,
    parcellation_path,
    output_image_path,
    output_label_path,
    seed,
    min_volume,
    max_volume,
    volumes_path,
    simplex_path,
    std_blur,
    shape,
    texture,
    center_ras,
    wm_lesion,
    clot,
    verbose,
    debug_dir,
    cleanup,
):
    import torchio as tio
    import resector

    if seed is not None:
        import torch
        torch.manual_seed(seed)

    if debug_dir is not None:
        resector.io.debug_dir = Path(debug_dir).expanduser().absolute()

    resectable_paths, gm_paths, noise_path, existed = ensure_images(
        input_path,
        parcellation_path,
    )

    try:
        if volumes_path is not None:
            import pandas as pd
            df = pd.read_csv(volumes_path)
            volumes = df.Volume.values
            kwargs = dict(volumes=volumes)
        else:
            kwargs = dict(volumes_range=(min_volume, max_volume))
        if std_blur is not None:
            kwargs['sigmas_range'] = std_blur, std_blur
        kwargs['simplex_path'] = simplex_path
        kwargs['wm_lesion_p'] = wm_lesion
        kwargs['clot_p'] = clot
        kwargs['verbose'] = verbose
        kwargs['shape'] = shape
        kwargs['texture'] = texture
        kwargs['center_ras'] = center_ras

        transform = tio.Compose((
            tio.ToCanonical(),
            resector.RandomResection(**kwargs),
        ))
        subject = tio.Subject(
            image=tio.ScalarImage(input_path),
            resection_resectable_left=tio.LabelMap(resectable_paths[0]),
            resection_resectable_right=tio.LabelMap(resectable_paths[1]),
            resection_gray_matter_left=tio.LabelMap(gm_paths[0]),
            resection_gray_matter_right=tio.LabelMap(gm_paths[1]),
            resection_noise=tio.ScalarImage(noise_path),
        )
        with resector.timer('RandomResection', verbose):
            transformed = transform(subject)
        with resector.timer('Saving images', verbose):
            transformed['image'].save(output_image_path)
            transformed['label'].save(output_label_path)
        return_code = 0
    except Exception:
        return_code = 1
        raise
    finally:
        if not existed and cleanup:
            with resector.timer('Cleaning up', verbose):
                for p in resectable_paths:
                    p.unlink()
                for p in gm_paths:
                    p.unlink()
                noise_path.unlink()
    return return_code