def get_train_transform(landmarks_path, resection_params=None):
    """Assemble the TorchIO augmentation pipeline used for training.

    Args:
        landmarks_path: Landmarks for ``tio.HistogramStandardization`` on the
            ``'image'`` key; ``None`` skips histogram standardization.
        resection_params: Forwarded to ``get_simulation_transform``; ``None``
            skips the resection-simulation step.

    Returns:
        A ``tio.Compose`` whose steps run in this order: optional resection
        simulation, optional histogram standardization, resolution/artifact
        augmentation, bias field, z-normalization, noise, spatial
        augmentation and finally ``get_tight_crop()``.
    """
    # 90 % affine / 10 % elastic deformation, followed by a random flip.
    spatial = tio.Compose((
        tio.OneOf({
            tio.RandomAffine(): 0.9,
            tio.RandomElasticDeformation(): 0.1,
        }),
        tio.RandomFlip(),
    ))
    # With probability 0.75, degrade resolution via anisotropy or blur.
    low_resolution = tio.OneOf(
        (tio.RandomAnisotropy(), tio.RandomBlur()),
        p=0.75,
    )

    steps = []
    if resection_params is not None:
        steps.append(get_simulation_transform(resection_params))
    if landmarks_path is not None:
        steps.append(tio.HistogramStandardization({'image': landmarks_path}))
    steps += [
        low_resolution,
        tio.RandomGhosting(p=0.2),
        tio.RandomSpike(p=0.2),
        tio.RandomMotion(p=0.2),
        tio.RandomBiasField(p=0.5),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        # Noise must come after z-normalization and after blur.
        tio.RandomNoise(p=0.75),
        spatial,
        get_tight_crop(),
    ]
    return tio.Compose(steps)
示例#2
0
def load_pretrain_datasets(data_shape, batch=3, workers=4, transform=None,
                           data_path='/home/mitch/Data/MSD/'):
    """Build one DataLoader per task directory of a segmentation dataset root.

    Every sample goes through: intensity clipping to [-80, 300], intensity
    rescaling to (0, 1), resizing of every image to ``data_shape`` with
    ``interpolate``, rounding of label maps, and finally ``transform``.

    Args:
        data_shape: Target spatial size passed to ``interpolate``.
        batch: Batch size of every loader.
        workers: Number of DataLoader worker processes.
        transform: Extra transform applied after the preprocessing above;
            defaults to a no-op (``tio.RandomFlip(p=0.)``).
        data_path: Root directory whose sub-directories each contain
            ``imagesTr/`` and ``labelsTr/``. Defaults to the previously
            hard-coded location.

    Returns:
        A list of ``DataLoader`` objects, one per sub-directory in sorted order.
    """
    import os

    directories = sorted(glob.glob(os.path.join(data_path, '*', '')))

    loaders = []  # one DataLoader per task
    datasets = []  # dataset objects before they are wrapped in loaders

    if transform is None:
        # p=0 makes the flip a no-op placeholder.
        transform = tio.RandomFlip(p=0.)
    # Preprocessing shared by all tasks.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    # interpolate expects a batch dimension: add it, resize, then drop it.
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    # Tasks with extra image dimensions: keep a single slice along the last
    # axis (index 2 for the "brain" task, index 1 for the "prostate" task).
    # NOTE(review): assumes the last axis indexes modalities — confirm.
    braintransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 2], dim=0),
                            types_to_apply=[tio.INTENSITY])
    braintransform = tio.Compose([braintransform, transform])
    prostatetransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 1], dim=0),
                               types_to_apply=[tio.INTENSITY])
    prostatetransform = tio.Compose([prostatetransform, transform])

    for i, directory in enumerate(directories):
        images = sorted(glob.glob(os.path.join(directory, 'imagesTr', '*')))
        segs = sorted(glob.glob(os.path.join(directory, 'labelsTr', '*')))

        subject_list = [
            tio.Subject(img=tio.ScalarImage(image), label=tio.LabelMap(seg))
            for image, seg in zip(images, segs)
        ]

        # Special cases by sorted position; assumes the brain task sorts first
        # and the prostate task fifth — TODO confirm for other dataset roots.
        if i == 0:
            dataset_transform = braintransform
        elif i == 4:
            dataset_transform = prostatetransform
        else:
            dataset_transform = transform
        datasets.append(
            tio.SubjectsDataset(subject_list, transform=dataset_transform))

        loaders.append(
            DataLoader(datasets[-1],
                       num_workers=workers,
                       batch_size=batch,
                       pin_memory=True))

    return loaders
示例#3
0
def get_transform(augmentation, landmarks_path):
    """Return the training transform, or test preprocessing plus spatial augmentation.

    Args:
        augmentation: If truthy, return ``datasets.get_train_transform``.
        landmarks_path: Forwarded to the ``datasets`` transform factories.

    Returns:
        A TorchIO transform. NOTE(review): the falsy branch still appends a
        random flip and a random affine/elastic deformation after the test
        preprocessing; confirm this is intended.
    """
    import datasets
    import torchio as tio

    if augmentation:
        return datasets.get_train_transform(landmarks_path)

    preprocess = datasets.get_test_transform(landmarks_path)
    spatial_augment = tio.Compose((
        tio.RandomFlip(),
        tio.OneOf({
            tio.RandomAffine(): 0.8,
            tio.RandomElasticDeformation(): 0.2,
        }),
    ))
    return tio.Compose((preprocess, spatial_augment))
示例#4
0
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    """Wrap a single image in a dataset, optionally with test-time augmentation.

    Preprocessing: reorient to canonical, histogram-standardize with fixed
    landmarks, z-normalize, resample when the voxel size differs from 1 by
    more than ``tolerance`` on any axis (or when an MNI transform is given),
    then crop to a shape multiple of 8.

    Args:
        input_path: Path of the image to load.
        tta_iterations: Number of extra randomly flipped/affine-augmented
            copies to append; values ``<= 0`` disable augmentation.
        interpolation: Image interpolation for resampling and the affine.
        tolerance: Allowed deviation of each voxel spacing from 1.
        mni_transform_path: Optional matrix file; when given, the image is
            resampled to Colin27 using it as a pre-affine.

    Returns:
        The preprocessed ``SubjectsDataset``, or a ``ConcatDataset`` of it and
        the augmented copies when ``tta_iterations`` yields any.
    """
    use_mni = mni_transform_path is not None
    if use_mni:
        to_mni = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: to_mni})
    else:
        image = tio.ScalarImage(input_path)
    subject = tio.Subject({IMAGE_NAME: image})

    # Fixed histogram-standardization landmarks.
    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    steps = [
        tio.ToCanonical(),
        tio.HistogramStandardization({IMAGE_NAME: landmarks}),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]

    spacing = np.array(nib.load(input_path).header.get_zooms())
    not_isotropic = np.any(np.abs(spacing - 1) > tolerance)
    if not_isotropic or use_mni:
        resample_kwargs = {'image_interpolation': interpolation}
        if use_mni:
            resample_kwargs['pre_affine_name'] = TO_MNI
            resample_kwargs['target'] = tio.datasets.Colin27().t1.path
        steps.append(tio.Resample(**resample_kwargs))
    steps.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess = tio.Compose(steps)

    plain_dataset = tio.SubjectsDataset([subject], transform=preprocess)

    augmented_subjects = tta_iterations * [subject]
    if not augmented_subjects:
        return plain_dataset

    augment = tio.Compose((
        preprocess,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    augmented_dataset = tio.SubjectsDataset(augmented_subjects,
                                            transform=augment)
    return torch.utils.data.ConcatDataset((plain_dataset, augmented_dataset))
示例#5
0
    def get_volume_torchio_without_motion(self, idx, return_orig=False):
        """Load subject ``idx`` and replay its recorded transforms up to the motion step.

        Args:
            idx: Row index passed to ``self.get_row`` and
                ``self.get_transformations``.
            return_orig: Currently unused.

        Returns:
            If ``self.df_data`` has no ``"history"`` column, the untransformed
            ``Subject``. Otherwise a tuple ``(subject, motion)``: ``subject``
            has every recorded transform that precedes the first
            ``MotionFromTimeCourse`` applied (the motion itself is not
            applied), and ``motion`` is that transform, or ``None`` when the
            history contains no motion transform.
        """
        subject_row = self.get_row(idx)
        dict_suj = dict()
        if not pd.isna(subject_row["image_filename"]):
            path_imgs = self.read_path(subject_row["image_filename"])
            if path_imgs:
                dict_suj["t1"] = ScalarImage(path_imgs)

        if "label_filename" in subject_row.keys() and not pd.isna(
                subject_row["label_filename"]):
            path_imgs = self.read_path(subject_row["label_filename"])
            dict_suj["label"] = LabelMap(path_imgs)
        sub = Subject(dict_suj)
        if "history" not in self.df_data.columns:
            return sub

        trsfms = self.get_transformations(idx)
        trsfms_short = []
        # Initialised so a history without a motion transform returns None
        # instead of raising UnboundLocalError.
        tmot = None
        for tr in trsfms.transforms:
            print(tr.name)
            if isinstance(tr, torchio.transforms.LabelsToImage):
                # Point the replayed transform at this subject's label key.
                tr.label_key = "label"
            if isinstance(tr, torchio.transforms.MotionFromTimeCourse):
                tmot = tr
                break
            trsfms_short.append(tr)
        res = torchio.Compose(trsfms_short)(sub)
        return res, tmot
def initialize_transforms_simple(p=0.8):
    """Compose a fixed chain of flip, artifact and intensity augmentations.

    Args:
        p: Probability given to the flip, motion, bias-field, blur and noise
            transforms. ``RandomAnisotropy`` and the final
            ``RescaleIntensity((0, 255))`` get no explicit ``p`` and use the
            TorchIO default.

    Returns:
        A ``tio.Compose`` of the transforms, in the order they are built here.
    """
    # RandomAffine and RandomElasticDeformation were tried and left out; the
    # elastic deformation was noted to slow the data loader down.
    flip = RandomFlip(axes=(0, 1, 2), flip_probability=1, p=p)
    motion = RandomMotion(degrees=10,
                          translation=10,
                          num_transforms=2,
                          image_interpolation='linear',
                          p=p)
    anisotropy = RandomAnisotropy(axes=(0, 1, 2), downsampling=2)
    bias_field = RandomBiasField(coefficients=0.5, order=3, p=p)
    blur = RandomBlur(std=(0, 2), p=p)
    noise = RandomNoise(mean=0, std=(0, 5), p=p)
    rescale = RescaleIntensity((0, 255))
    return tio.Compose(
        [flip, motion, anisotropy, bias_field, blur, noise, rescale])
示例#7
0
def getRandomMotionGhostingFast(
        degrees: Tuple[float, float] = (1.0, 3.0),
        translation: Tuple[float, float] = (1.0, 3.0),
        num_transforms: Tuple[int, int] = (2, 10),
        num_ghosts: Tuple[int, int] = (2, 5),
        intensity: Tuple[float, float] = (0.01, 0.75),
        restore: Tuple[float, float] = (0.01, 1.0),
        motion_image_interpolation: str = 'linear',
        ghosting_axes: Tuple[int, ...] = (0, 1),
        p_motion: float = 1,
        p_ghosting: float = 1
):
    """Compose ``RandomMotionExtended`` followed by ``RandomGhostingExtended``.

    The range arguments are ``(low, high)`` pairs (hence ``Tuple[X, X]``;
    ``Tuple[X]`` would denote a 1-tuple) and are forwarded unchanged.

    Args:
        degrees: Range forwarded to ``RandomMotionExtended``.
        translation: Range forwarded to ``RandomMotionExtended``.
        num_transforms: Range forwarded to ``RandomMotionExtended``.
        num_ghosts: Range forwarded to ``RandomGhostingExtended``.
        intensity: Range forwarded to ``RandomGhostingExtended``.
        restore: Range forwarded to ``RandomGhostingExtended``.
        motion_image_interpolation: Interpolation for the motion transform.
        ghosting_axes: Axes forwarded to the ghosting transform.
        p_motion: Probability of the motion transform.
        p_ghosting: Probability of the ghosting transform.

    Returns:
        A ``tio.Compose`` applying motion first, then ghosting.
    """
    transform = tio.Compose([
        RandomMotionExtended(degrees=degrees,
                             translation=translation,
                             num_transforms=num_transforms,
                             image_interpolation=motion_image_interpolation,
                             p=p_motion),
        RandomGhostingExtended(num_ghosts=num_ghosts,
                               axes=ghosting_axes,
                               intensity=intensity,
                               restore=restore,
                               p=p_ghosting)
    ])
    return transform
示例#8
0
 def apply_transform(self, subject):
     """Add a motion- and ghosting-corrupted copy of ``subject['gt']`` as ``"inp"``.

     Parameters are drawn on every call from the configured ``(low, high)``
     ranges. Float parameters are drawn uniformly from ``[low, high)``.
     Integer parameters (number of motion transforms and ghosts) are drawn
     uniformly from ``[low, high]`` inclusive.

     Args:
         subject: Subject containing a ``'gt'`` image; it is deep-copied
             before corruption, so ``'gt'`` itself is left untouched.

     Returns:
         The same ``subject`` with the corrupted image added under ``"inp"``.
     """
     im = deepcopy(subject['gt'])
     degrees = float(np.random.uniform(low=self.degrees[0],
                                       high=self.degrees[1]))
     translation = float(np.random.uniform(low=self.translation[0],
                                           high=self.translation[1]))
     # np.random.randint excludes `high`; add 1 so the configured upper bound
     # is reachable (e.g. num_ghosts=(2, 5) could previously never yield 5,
     # and a degenerate range like (2, 2) raised ValueError).
     num_transforms = int(np.random.randint(low=self.num_transforms[0],
                                            high=self.num_transforms[1] + 1))
     num_ghosts = int(np.random.randint(low=self.num_ghosts[0],
                                        high=self.num_ghosts[1] + 1))
     intensity = float(np.random.uniform(low=self.intensity[0],
                                         high=self.intensity[1]))
     restore = float(np.random.uniform(low=self.restore[0],
                                       high=self.restore[1]))
     transform_rnd = tio.Compose([
         tio.transforms.RandomMotion(degrees=degrees,
                                     translation=translation,
                                     num_transforms=num_transforms,
                                     image_interpolation=self.motion_image_interpolation,
                                     p=self.p_motion),
         tio.transforms.RandomGhosting(num_ghosts=num_ghosts,
                                       axes=self.ghosting_axes,
                                       intensity=intensity,
                                       restore=restore,
                                       p=self.p_ghosting)
     ])
     subject.add_image(transform_rnd(im), "inp")
     return subject
示例#9
0
def byol_aug(filename):
    """Load an image and return one randomly augmented view of it for BYOL.

    BYOL minimizes the distance between the representations of a sample and
    of a transformed view of the same sample. This builds such a view with
    cropping/padding, z-normalization, blur, noise, affine or elastic
    deformation, bias field and MRI artifacts (motion, spikes or ghosts).

    Args:
        filename: Path of the image, loaded as a ``tio.ScalarImage``.

    Returns:
        The transformed image (not a dataset).
    """
    image = tio.ScalarImage(filename)
    pipeline = tio.Compose([
        tio.CropOrPad((180, 220, 170)),
        # Zero mean, unit variance over the mask from tio.ZNormalization.mean.
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RandomBlur(p=0.25),  # blur 25 % of the time
        tio.RandomNoise(p=0.25),  # Gaussian noise 25 % of the time
        # 80 % of the time: affine (weight 0.8) or elastic (weight 0.2).
        tio.OneOf({
            tio.RandomAffine(): 0.8,
            tio.RandomElasticDeformation(): 0.2,
        }, p=0.8),
        tio.RandomBiasField(p=0.3),  # field inhomogeneity 30 % of the time
        # 50 % of the time: motion (weight 1), spikes (2) or ghosting (2).
        tio.OneOf({
            tio.RandomMotion(): 1,
            tio.RandomSpike(): 2,
            tio.RandomGhosting(): 2,
        }, p=0.5),
    ])
    return pipeline(image)
示例#10
0
 def test_transforms(self):
     """Smoke test: a long chain of TorchIO transforms runs on ``self.sample``.

     The same elastic-deformation instance is used both directly and inside
     ``OneOf``.
     """
     landmarks = {
         't1': np.linspace(0, 100, 13),
         't2': np.linspace(0, 100, 13),
     }
     elastic = torchio.RandomElasticDeformation(max_displacement=1)
     pipeline = torchio.Compose((
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1
         }),
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     ))
     pipeline(self.sample)
示例#11
0
    def __call__(self, sample, metadata=None):
        """Apply a random bias field to ``sample`` with probability ``self.p``.

        Args:
            sample: Input data; ``.dtype`` and ``.astype`` are used, so
                presumably a numpy array — TODO confirm.
            metadata: Mapping-like metadata. ``MetadataKW.BIAS_FIELD`` is set
                to ``[history]`` when the field is applied and to ``[None]``
                otherwise. A new dict is used when ``None`` is passed.

        Returns:
            Tuple ``(output, metadata)``. ``output`` has the dtype of
            ``sample``; it is ``sample`` itself when no field is applied.
        """
        if metadata is None:
            # The old None default crashed on the item assignments below.
            metadata = {}

        if np.random.random() >= self.p:
            metadata[MetadataKW.BIAS_FIELD] = [None]
            return sample, metadata

        # The draw above already decided to apply the field, so the inner
        # TorchIO transform must always fire (p=1). Passing self.p again made
        # the effective probability self.p ** 2.
        random_bias_field = tio.Compose([
            tio.RandomBiasField(coefficients=self.coefficients,
                                order=self.order,
                                p=1)
        ])
        data_out, history = tio_transform(x=sample,
                                          transform=random_bias_field)

        # Keep the input's data type.
        data_out = data_out.astype(sample.dtype)

        # Record the applied transform's history in the metadata.
        metadata[MetadataKW.BIAS_FIELD] = [history]
        return data_out, metadata
示例#12
0
def load_kidney_seg(data_shape, batch=3, workers=4, transform=None,
                    data_dir='data', num_cases=210):
    """Build a DataLoader over the kidney CT cases.

    Each sample goes through: intensity clipping to [-80, 300], intensity
    rescaling to (0, 1), resizing of every image to ``data_shape`` with
    ``interpolate``, rounding of label maps, and finally ``transform``.

    Args:
        data_shape: Target spatial size passed to ``interpolate``.
        batch: Batch size.
        workers: Number of DataLoader worker processes.
        transform: Extra transform applied after the preprocessing above;
            defaults to a no-op (``tio.RandomFlip(p=0.)``).
        data_dir: Directory containing ``case_XXXXX/imaging.nii.gz`` and
            ``case_XXXXX/segmentation.nii.gz``. Defaults to ``'data'``.
        num_cases: Number of cases, loaded as ``case_00000`` up to
            ``case_{num_cases - 1}``. Defaults to 210.

    Returns:
        A ``DataLoader`` over a ``tio.SubjectsDataset`` of all cases.
    """
    import os

    if transform is None:
        # p=0 makes the flip a no-op placeholder.
        transform = tio.RandomFlip(p=0.)
    # Preprocessing shared by all cases.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    # interpolate expects a batch dimension: add it, resize, then drop it.
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    subject_list = []
    for i in range(num_cases):
        case_dir = os.path.join(data_dir, "case_{:05d}".format(i))
        subject_list.append(
            tio.Subject(
                img=tio.ScalarImage(os.path.join(case_dir, "imaging.nii.gz")),
                label=tio.LabelMap(
                    os.path.join(case_dir, "segmentation.nii.gz"))))
    dataset = tio.SubjectsDataset(subject_list, transform=transform)
    return DataLoader(dataset,
                      num_workers=workers,
                      batch_size=batch,
                      pin_memory=True)
示例#13
0
    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, optionally augmented.

        When ``self.augment`` is set, one of a random affine, a random elastic
        deformation or a flip along axis 2 is applied jointly to image and mask.

        Returns:
            ``image`` as a float tensor and ``mask`` as an int tensor. Both
            have a leading channel dimension and the three spatial axes
            reversed by ``permute(0, 3, 2, 1)``.
        """
        # TorchIO images expect 4D tensors, so add a channel dimension.
        image = self.CT_partition[idx].unsqueeze(0)
        mask = self.mask_partition[idx].unsqueeze(0)
        if self.augment:
            aug = tio.Compose([
                tio.OneOf({
                    tio.RandomAffine(scales=(0.9, 1.1, 0.9, 1.1, 1, 1),
                                     degrees=(5.0, 5.0, 0)): 0.35,
                    tio.RandomElasticDeformation(
                        num_control_points=9,
                        max_displacement=(0.1, 0.1, 0.1),
                        locked_borders=2,
                        image_interpolation='linear'): 0.35,
                    tio.RandomFlip(axes=(2,)): 0.3,
                }),
            ])
            # The mask must be a LabelMap so TorchIO resamples it with label
            # (nearest-neighbour) interpolation. As a ScalarImage it was
            # interpolated linearly, producing fractional labels that the
            # IntTensor cast below then truncated.
            subject = tio.Subject(ct=tio.ScalarImage(tensor=image),
                                  mask=tio.LabelMap(tensor=mask))
            output = aug(subject)
            image = output['ct'].data
            mask = output['mask'].data
        # The mask holds integer labels.
        mask = mask.type(torch.IntTensor)
        image = image.type(torch.FloatTensor)

        # Tensors are C x W x H x D here; permute to C x D x H x W. After
        # batching this gives N x 1 x D x H x W.
        image = image.permute(0, 3, 2, 1)
        mask = mask.permute(0, 3, 2, 1)

        return image, mask
def get_context(device, variables, augmentation_mode, **kwargs):
    """Extend the base config context with the selected augmentation mode.

    Args:
        device: Forwarded to ``base_config.get_context``.
        variables: Forwarded to ``base_config.get_context``.
        augmentation_mode: One of ``'no_augmentation'``, ``'standard'``,
            ``'dwi_reconstruction'`` or ``'combined'``. It decides whether the
            augmentation step (index 1 of the training transform) is removed
            or what replaces it.
        **kwargs: Forwarded to ``base_config.get_context``.

    Returns:
        The context. This file is appended to ``file_paths``, the mode is
        stored in ``config``, and the training transform is modified in place.

    Raises:
        ValueError: If ``augmentation_mode`` is not one of the modes above.
    """
    context = base_config.get_context(device, variables, **kwargs)
    context.file_paths.append(os.path.abspath(__file__))
    context.config.update({'augmentation_mode': augmentation_mode})

    # The training transform is a tio.Compose whose second step is the
    # augmentation.
    dataset_definition = context.get_component_definition("dataset")
    training_transform = dataset_definition['params']['transforms']['training']

    dwi_augmentation = ReconstructMeanDWI(num_dwis=(1, 7),
                                          num_directions=(1, 3),
                                          directionality=(4, 10))

    noise = tio.RandomNoise(std=0.035, p=0.3)
    blur = tio.RandomBlur((0, 1), p=0.2)
    standard_steps = [
        tio.RandomFlip(axes=(0, 1, 2)),
        tio.RandomElasticDeformation(p=0.5,
                                     num_control_points=(7, 7, 4),
                                     locked_borders=1,
                                     image_interpolation='bspline',
                                     exclude="full_dwi"),
        tio.RandomBiasField(p=0.5),
        tio.RescaleIntensity((0, 1), (0.01, 99.9)),
        tio.RandomGamma(p=0.8),
        tio.RescaleIntensity((-1, 1)),
        # Blur and noise in one of the two orders, chosen at random.
        tio.OneOf([
            tio.Compose([blur, noise]),
            tio.Compose([noise, blur]),
        ]),
    ]
    standard_augmentations = tio.Compose(standard_steps, exclude="full_dwi")

    if augmentation_mode == 'no_augmentation':
        training_transform.transforms.pop(1)
        return context

    if augmentation_mode == 'standard':
        replacement = standard_augmentations
    elif augmentation_mode == 'dwi_reconstruction':
        replacement = dwi_augmentation
    elif augmentation_mode == 'combined':
        replacement = tio.Compose([dwi_augmentation, standard_augmentations])
    else:
        raise ValueError(f"Invalid augmentation mode {augmentation_mode}")
    training_transform.transforms[1] = replacement
    return context
示例#15
0
def get_test_transform(landmarks_path):
    """Build the deterministic evaluation preprocessing pipeline.

    Args:
        landmarks_path: Landmarks for ``tio.HistogramStandardization`` on the
            ``'image'`` key; ``None`` skips histogram standardization.

    Returns:
        A ``tio.Compose`` with optional histogram standardization, then
        z-normalization, then ``get_tight_crop()``.
    """
    if landmarks_path is None:
        standardization = []
    else:
        standardization = [
            tio.HistogramStandardization({'image': landmarks_path})]
    return tio.Compose(standardization + [
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        get_tight_crop(),
    ])
示例#16
0
 def get_large_composed_transform(self):
     """Instantiate every random transform class, shuffled, and compose them.

     ``RandomSwap`` gets a 10-voxel patch size instead of its default (15) so
     that it fits the test subject, which is (10, 20, 30).
     """
     transform_classes = get_all_random_transforms()
     shuffle(transform_classes)
     instances = []
     for transform_class in transform_classes:
         instance = transform_class()
         if instance.name == 'RandomSwap':
             instance.patch_size = np.array((10, 10, 10))
         instances.append(instance)
     return tio.Compose(instances)
示例#17
0
def createTIODynDS(path_gt,
                   path_corrupt,
                   is_infer=False,
                   p=1,
                   transforms=None,
                   **kwargs):
    """Build a dataset of consecutive time-point pairs of NIfTI volumes.

    Each subject pairs a time point (``gt``/``inp``) with the previous one
    (``gt_tp_prev``/``inp_tp_prev``) for the same file. When ``path_corrupt``
    is falsy, the ground truth doubles as the input and is corrupted on the
    fly by ``MotionCorrupter``.

    Args:
        path_gt: Directory searched recursively for ``*.nii``/``*.nii.gz``
            ground-truth volumes.
        path_corrupt: Directory of pre-corrupted inputs, or falsy to corrupt
            on the fly.
        is_infer: Currently unused.
        p: Probability of the on-the-fly corruption ``tio.Lambda``.
        transforms: Transforms applied before the corruption step. The given
            sequence is copied and never modified. Defaults to none.
        **kwargs: Forwarded to ``MotionCorrupter``.

    Returns:
        A ``tio.SubjectsDataset``.
    """
    # Copy so that neither a shared default nor the caller's list is mutated
    # by the appends below (the old ``transforms=[]`` default grew per call).
    transforms = [] if transforms is None else list(transforms)

    files_gt = glob(path_gt + "/**/*.nii", recursive=True) + glob(
        path_gt + "/**/*.nii.gz", recursive=True)
    if path_corrupt:
        files_inp = glob(path_corrupt + "/**/*.nii", recursive=True) + glob(
            path_corrupt + "/**/*.nii.gz", recursive=True)
        corruptFly = False
    else:
        files_inp = files_gt.copy()
        corruptFly = True
    subjects = []

    inp_dicts, files_inp = __process_TPs(files_inp)
    gt_dicts, _ = __process_TPs(files_gt)
    for filename in files_inp:
        inp_files = [d for d in inp_dicts if filename in d['filename']]
        gt_files = [d for d in gt_dicts if filename in d['filename']]
        # Sort so "previous" really is the preceding time point: set iteration
        # order is arbitrary (and varies between runs for strings). Assumes
        # tp values sort chronologically — confirm against __process_TPs.
        # zip over an empty/singleton list yields no pairs instead of raising.
        tps = sorted({dic["tp"] for dic in inp_files})
        for tp_prev, tp in zip(tps, tps[1:]):
            inp_tp_prev = [d for d in inp_files if tp_prev == d['tp']]
            gt_tp_prev = [d for d in gt_files if tp_prev == d['tp']]
            inp_tp = [d for d in inp_files if tp == d['tp']]
            gt_tp = [d for d in gt_files if tp == d['tp']]
            if len(gt_tp_prev) > 0 and len(gt_tp) > 0:
                subjects.append(
                    tio.Subject(
                        gt_tp_prev=tio.ScalarImage(gt_tp_prev[0]['path']),
                        inp_tp_prev=tio.ScalarImage(inp_tp_prev[0]['path']),
                        gt=tio.ScalarImage(gt_tp[0]['path']),
                        inp=tio.ScalarImage(inp_tp[0]['path']),
                        filename=filename,
                        tp=tp,
                        tag="CorruptNGT",
                    ))
            else:
                print(
                    "Warning: Not Implemented if GT is missing. Skipping Sub-TP."
                )

    if corruptFly:
        moco = MotionCorrupter(**kwargs)
        transforms.append(tio.Lambda(moco.perform, p=p))
    transforms.append(ProcessTIOSubsTPs())
    transform = tio.Compose(transforms)
    subjects_dataset = tio.SubjectsDataset(subjects, transform=transform)
    return subjects_dataset
def hippo_inference(context, args, i, log_callback=None):
    """Segment both sides of subject ``i`` and reassemble them into one volume.

    Dataset items ``2*i`` and ``2*i + 1`` are run through ``context.model`` as
    the left and right sides. The right side is flipped along its first
    spatial axis and concatenated in front of the left side. The result then
    goes through a fixed crop/pad (``inverse_transforms``).

    Args:
        context: Provides ``model``, ``device`` and ``dataset``.
        args: Reads ``output_probabilities``, ``lateral_uniformity``,
            ``remove_isolated_components`` and ``remove_holes``.
        i: Subject index.
        log_callback: Called with a summary string when post-processing
            changed voxels; ``None`` disables logging.

    Returns:
        The concatenated class probabilities when ``args.output_probabilities``
        is set; otherwise a label tensor with a leading channel dimension in
        which left-side labels are offset past the right side's labels.
    """
    subject_name = context.dataset.subjects_dataset.subject_folder_names[i]
    log_text = f"subject {subject_name}: "
    log = False

    # Presumably undoes the preprocessing crop/pad — TODO confirm the values.
    inverse_transforms = tio.Compose([
        tio.Crop((0, 0, 0, 0, 2, 2)),
        tio.Pad((62, 62, 70, 58, 0, 0)),
    ])
    with torch.no_grad():
        left_side_prob = context.model(context.dataset[i*2][0].to(context.device))[0]
        right_side_prob = context.model(context.dataset[i*2 + 1][0].to(context.device))[0]
    if args.output_probabilities:
        right_side_prob = torch.flip(right_side_prob, dims=(1,))
        out = torch.cat((right_side_prob, left_side_prob), dim=1)
        out = out.cpu()
        out = inverse_transforms(out)
        return out

    left_side = torch.argmax(left_side_prob, dim=0)
    right_side = torch.argmax(right_side_prob, dim=0)

    if args.lateral_uniformity:
        left_side, left_removed_count = lateral_uniformity(left_side, left_side_prob, return_counts=True)
        right_side, right_removed_count = lateral_uniformity(right_side, right_side_prob, return_counts=True)
        total_removed = left_removed_count + right_removed_count
        if total_removed > 0:
            log_text += f" Changed {total_removed} voxels to enforce lateral uniformity."
            # Previously missing, so this change alone was never reported.
            log = True

    # Shift left-side labels past the right side's highest label so the two
    # sides get distinct label values.
    left_side[left_side != 0] += torch.max(right_side)
    right_side = torch.flip(right_side, dims=(0,))
    out = torch.cat((right_side, left_side), dim=0)

    out = out.cpu().numpy()

    if args.remove_isolated_components:
        num_components = out.max()
        out, components_removed, component_voxels_removed = keep_components(out, num_components, return_counts=True)
        if component_voxels_removed > 0:
            log_text += f" Removed {component_voxels_removed} voxels from " \
                        f"{components_removed} detected isolated components."
            log = True
    if args.remove_holes:
        out, hole_voxels_removed = remove_holes(out, hole_size=64, return_counts=True)
        if hole_voxels_removed > 0:
            log_text += f" Filled {hole_voxels_removed} voxels from detected holes."
            log = True
    # The default log_callback=None used to crash here whenever log was set.
    if log and log_callback is not None:
        log_callback(log_text)

    out = torch.from_numpy(out).unsqueeze(0)
    out = inverse_transforms(out)

    return out
示例#19
0
    def test_all_random_transforms(self):
        """Replaying a random pipeline from its history reproduces its output.

        Every ``torchio`` attribute whose name starts with ``Random`` is
        instantiated and composed. The pipeline is then rebuilt from the
        transformed subject's ``history`` via ``compose_from_history`` and
        re-applied with the recorded seeds; both results must be equal.
        """
        sample = Subject(t1=ScalarImage(tensor=torch.rand(1, 20, 20, 20)),
                         seg=LabelMap(tensor=torch.rand(1, 20, 20, 20) > 1))

        random_names = [n for n in dir(torchio) if n.startswith('Random')]
        # Downsample last so the image shape is not changed mid-chain.
        random_names.remove('RandomDownsample')
        random_names.append('RandomDownsample')

        def instantiate(name):
            transform_class = getattr(torchio, name)
            # RandomLabelsToImage is the only one needing an __init__ argument.
            if name == 'RandomLabelsToImage':
                return transform_class(label_key='seg')
            return transform_class()

        pipeline = torchio.Compose([instantiate(n) for n in random_names])
        with warnings.catch_warnings():
            # Silence the elastic-deformation RuntimeWarning.
            warnings.simplefilter('ignore', RuntimeWarning)
            transformed = pipeline(sample)

        replay_transforms, replay_seeds = compose_from_history(
            transformed.history)
        replayed = self.apply_transforms(subject=sample,
                                         trsfm_list=replay_transforms,
                                         seeds_list=replay_seeds)

        self.assertTensorEqual(transformed.t1.data, replayed.t1.data)
        self.assertTensorEqual(transformed.seg.data, replayed.seg.data)
示例#20
0
    def __init__(self, use_tio_flip=True):
        """Set up flip, affine and noise augmentations for the keys ``x``/``y``.

        Args:
            use_tio_flip: When exactly ``True``, use ``tio.RandomFlip(p=0.5)``;
                any other value selects ``Flip3D()``.
        """
        if use_tio_flip is True:
            self.flip = tio.RandomFlip(p=0.5)
        else:
            self.flip = Flip3D()
        self.affine = tio.RandomAffine(
            p=0.5,
            scales=0.1,
            degrees=5,
            translation=0,
            image_interpolation="nearest",
        )
        # Noise only on "x"; "y" is presumably the mask and stays noise-free.
        self.random_noise = tio.RandomNoise(p=0.5, std=(0, 0.1), include=["x"])
        pipeline = [self.flip, self.affine, self.random_noise]
        self.transform = tio.Compose(pipeline, include=["x", "y"])
示例#21
0
def main(input_path, checkpoint_path, output_dir, landmarks_path, batch_size, num_workers, resample):
    """Segment every image from ``get_paths(input_path)`` and save the masks.

    Masks are the foreground channels of the model's softmax thresholded at
    0.5. Each is saved in ``output_dir`` as ``<name>_seg_cnn.nii[.gz]``.

    Args:
        input_path: Passed to ``get_paths`` to list the input images.
        checkpoint_path: ``torch.load``-able file with a ``'model'`` state dict.
        output_dir: Output directory; created if missing.
        landmarks_path: Forwarded to ``datasets.get_test_transform``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        resample: If truthy, apply ``tio.Resample()`` before preprocessing.

    Returns:
        0, as an exit status.
    """
    import torch
    from tqdm import tqdm
    import torchio as tio
    import models
    import datasets

    fps = get_paths(input_path)
    subjects = [tio.Subject(image=tio.ScalarImage(fp)) for fp in fps]  # key must be 'image' as in get_test_transform
    transform = tio.Compose((
        tio.ToCanonical(),
        datasets.get_test_transform(landmarks_path),
    ))
    if resample:
        transform = tio.Compose((
            tio.Resample(),
            transform,
        ))
    dataset = tio.SubjectsDataset(subjects, transform)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = models.get_unet().to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    # Scoped no_grad instead of torch.set_grad_enabled(False), which left
    # autograd disabled globally after this function returned.
    with torch.no_grad():
        for batch in tqdm(loader):
            inputs = batch['image'][tio.DATA].float().to(device)
            seg = model(inputs).softmax(dim=1)[:, 1:].cpu() > 0.5
            for tensor, affine, path in zip(seg, batch['image'][tio.AFFINE], batch['image'][tio.PATH]):
                image = tio.LabelMap(tensor=tensor, affine=affine.numpy())
                path = Path(path)
                out_path = output_dir / path.name.replace('.nii', '_seg_cnn.nii')
                image.save(out_path)
    return 0
示例#22
0
def get_test_time_transform():
    """Yield one transform per axis permutation and flip combination.

    There are 6 permutations of the three spatial axes times 8 flip subsets,
    so 48 ``tio.Compose`` objects are yielded. Each permutes the dimensions
    and then, unless the subset is empty, flips the selected axes.
    """
    all_axes = np.array([0, 1, 2])
    for permutation in itertools.permutations((0, 1, 2)):
        for flip_mask in itertools.product([False, True], repeat=3):
            flipped = all_axes[np.array(flip_mask)]
            steps = [PermuteDimensions(permutation)]
            if flipped.size:
                steps.append(tio.Flip(tuple(flipped.tolist())))
            yield tio.Compose(steps)
示例#23
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Build a long pipeline exercising most TorchIO transforms, for tests.

     Args:
         channels: Image keys; each gets linearly spaced histogram landmarks,
             and ``channels[0]`` is the ``CopyAffine`` reference.
         is_3d: Use 3D arguments when truthy; otherwise use arguments for
             (presumably 2D) images with a singleton third spatial dimension.
         labels: Append ``RandomLabelsToImage(label_key='label')`` at the end.

     Returns:
         A ``tio.Compose`` of all the transforms.
     """
     if is_3d:
         disp = 1
         cp_args = (9, 21, 30)
         resize_args = (10, 20, 30)
         spatial_axes = (0, 1, 2)
         swap_patch = (2, 3, 4)
         pad_args = (1, 2, 3, 0, 5, 6)
         crop_args = (3, 2, 8, 0, 1, 4)
     else:
         disp = (1, 1, 0.01)
         cp_args = (21, 30, 1)
         resize_args = (10, 20, 1)
         spatial_axes = (0, 1)
         swap_patch = (3, 4, 1)
         pad_args = (0, 0, 3, 0, 5, 6)
         crop_args = (0, 0, 8, 0, 1, 4)
     landmarks_dict = {channel: np.linspace(0, 100, 13) for channel in channels}
     # Shared by the plain step and the OneOf below.
     elastic = tio.RandomElasticDeformation(max_displacement=disp)
     remapping = {1: 2, 2: 1, 3: 20, 4: 25}
     pipeline = [
         tio.CropOrPad(cp_args),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_args),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=spatial_axes),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=spatial_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks_dict),
         elastic,
         tio.RandomAffine(),
         tio.OneOf({
             tio.RandomAffine(): 3,
             elastic: 1,
         }),
         tio.RemapLabels(remapping=remapping, masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(pad_args, padding_mode=3),
         tio.Crop(crop_args),
     ]
     if labels:
         pipeline.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(pipeline)
示例#24
0
def create_trainDS(path, p=1, **kwargs):
    """Wrap every NIfTI file under ``path`` in a motion-corrupting dataset.

    Files are collected recursively: all ``*.nii`` matches first, then all
    ``*.nii.gz`` matches. Each file becomes a ``tio.Subject`` holding the
    image under ``im`` and the file's base name under ``filename``.

    Args:
        path: Root directory to search (joined to the pattern by plain
            string concatenation).
        p: Probability given to the ``tio.Lambda`` that wraps
            ``MotionCorrupter.perform``.
        **kwargs: Forwarded unchanged to ``MotionCorrupter``.

    Returns:
        A ``tio.SubjectsDataset`` over the discovered subjects.
    """
    nifti_paths = []
    for suffix in ("/**/*.nii", "/**/*.nii.gz"):
        nifti_paths.extend(glob(path + suffix, recursive=True))

    subjects = [
        tio.Subject(
            im=tio.ScalarImage(nifti),
            filename=os.path.basename(nifti),
        )
        for nifti in nifti_paths
    ]
    corrupter = MotionCorrupter(**kwargs)
    pipeline = tio.Compose([tio.Lambda(corrupter.perform, p=p)])
    return tio.SubjectsDataset(subjects, transform=pipeline)
示例#25
0
文件: data.py 项目: JIiminIT/Torch
def normalization(histogram_transform, dataset):
    """Plot the intensity histogram of ``dataset[0]`` after normalization.

    The first item of ``dataset`` is passed through ``histogram_transform``
    and then a ``tio.ZNormalization`` using the ``tio.ZNormalization.mean``
    masking method. The resulting ``mri`` image data is drawn with
    ``plot_histogram`` and the figure is shown with ``plt.show()``.

    Args:
        histogram_transform: Transform applied before z-normalization
            (presumably a ``tio.HistogramStandardization`` — confirm at the
            call sites).
        dataset: Indexable dataset whose items expose an ``mri`` image.
    """
    znorm_transform = tio.ZNormalization(
        masking_method=tio.ZNormalization.mean)

    sample = dataset[0]
    transform = tio.Compose([histogram_transform, znorm_transform])
    znormed = transform(sample)

    _, ax = plt.subplots(dpi=100)
    plot_histogram(ax, znormed.mri.data, label='Z-normed', alpha=1)
    ax.set_title('Intensity values of one sample after z-normalization')
    ax.set_xlabel('Intensity')
    ax.grid()
    # Display through pyplot: matplotlib ``Axes`` objects have no ``show``
    # method, so calling ``ax.show()`` would raise AttributeError.
    plt.show()
示例#26
0
def filter_transform(
    transform: tio.Compose,
    include_types: Sequence[Type[tio.Transform]] = None,
    exclude_types: Sequence[Type[tio.Transform]] = None,
):
    """Recursively rebuild a ``tio.Compose`` keeping only selected children.

    Nested ``tio.Compose`` children are always kept and filtered
    recursively. Any other child is kept only if it is an instance of at
    least one class in ``include_types`` (when given) and of no class in
    ``exclude_types`` (when given). A ``transform`` that is not a
    ``tio.Compose`` is returned as-is, without filtering.

    Args:
        transform: Pipeline to filter.
        include_types: Allowed transform classes, or ``None`` for all.
        exclude_types: Rejected transform classes, or ``None`` for none.

    Returns:
        A new ``tio.Compose``, or the untouched non-Compose ``transform``.
    """
    if not isinstance(transform, tio.Compose):
        return transform

    def _wanted(child):
        # Compose containers always survive; their contents get filtered.
        if isinstance(child, tio.Compose):
            return True
        if include_types is not None and not any(
                isinstance(child, cls) for cls in include_types):
            return False
        if exclude_types is not None and any(
                isinstance(child, cls) for cls in exclude_types):
            return False
        return True

    kept = []
    for child in transform:
        if _wanted(child):
            kept.append(
                filter_transform(child,
                                 include_types=include_types,
                                 exclude_types=exclude_types))
    return tio.Compose(kept)
示例#27
0
def create_trainDS_precorrupt(path_gt, path_corrupt, p=1, norm_mode=0):
    """Build a dataset of ground-truth NIfTI files read via ``ReadCorrupted``.

    Ground-truth files are collected recursively under ``path_gt``: all
    ``*.nii`` matches first, then all ``*.nii.gz`` matches. Each becomes a
    ``tio.Subject`` with the image under ``im`` and the base name under
    ``filename``. The dataset's transform is a ``tio.Compose`` holding a
    single ``ReadCorrupted`` configured with ``path_corrupt``, ``p`` and
    ``norm_mode``.

    Returns:
        A ``tio.SubjectsDataset`` over the discovered subjects.
    """
    gt_files = [
        match
        for pattern in ("/**/*.nii", "/**/*.nii.gz")
        for match in glob(path_gt + pattern, recursive=True)
    ]

    def _make_subject(nifti):
        return tio.Subject(
            im=tio.ScalarImage(nifti),
            filename=os.path.basename(nifti),
        )

    subjects = [_make_subject(nifti) for nifti in gt_files]
    reader = ReadCorrupted(path_corrupt=path_corrupt, p=p, norm_mode=norm_mode)
    return tio.SubjectsDataset(subjects, transform=tio.Compose([reader]))
示例#28
0
    def test_get_subjects(self):
        """``get_subjects`` finds exactly one subject in a fresh directory.

        Writes a random CT volume (``ct.nii``) and an all-ones label map
        (``structure.nii``) into ``tests/test_data/subjects/subject_1``,
        builds a dataset from the parent directory and checks its length.
        The parent directory is removed afterwards, even if the test fails.
        """
        ct = tio.ScalarImage(tensor=torch.rand(1, 3, 3, 2))
        structure = tio.LabelMap(
            tensor=torch.ones((1, 3, 3, 2), dtype=torch.uint8))

        subject_1_dir = "tests/test_data/subjects/subject_1"
        subjects_root = os.path.dirname(subject_1_dir)

        os.makedirs(subject_1_dir, exist_ok=True)
        # Register cleanup immediately so the on-disk fixture is removed
        # even when saving, ``get_subjects`` or the assertion fails.
        self.addCleanup(shutil.rmtree, subjects_root, ignore_errors=True)

        ct.save(os.path.join(subject_1_dir, "ct.nii"))
        structure.save(os.path.join(subject_1_dir, "structure.nii"))
        transform = tio.Compose(
            [tio.ToCanonical(),
             tio.RescaleIntensity(1, (1, 99.0))])
        subject_dataset = get_subjects(subjects_root,
                                       structures=["structure"],
                                       transform=transform)
        self.assertEqual(len(subject_dataset), 1)
示例#29
0
 def get_transform(self, channels, is_3d=True, labels=True):
     """Build a long ``torchio.Compose`` chaining many torchio transforms.

     Args:
         channels: Image keys. Each key gets its own landmarks array
             (``np.linspace(0, 100, 13)``) for histogram standardization.
         is_3d: When ``False``, shape/axis arguments are chosen so the
             last spatial dimension stays a singleton (2D variant).
         labels: When ``True``, ``torchio.RandomLabelsToImage`` reading
             the ``'label'`` key is appended as the final transform.

     Returns:
         A ``torchio.Compose`` that applies the transforms in list order.
     """
     if is_3d:
         max_disp = 1
         target_shape = (9, 21, 30)
         spatial_axes = (0, 1, 2)
         swap_size = (2, 3, 4)
         padding = (1, 2, 3, 0, 5, 6)
         cropping = (3, 2, 8, 0, 1, 4)
     else:
         # Keep the last spatial axis at size 1 (no displacement, no
         # padding/cropping, no flipping/downsampling along it).
         max_disp = (1, 1, 0.01)
         target_shape = (21, 30, 1)
         spatial_axes = (0, 1)
         swap_size = (3, 4, 1)
         padding = (0, 0, 3, 0, 5, 6)
         cropping = (0, 0, 8, 0, 1, 4)

     # A separate landmarks array per channel (not one shared object).
     landmarks = {}
     for key in channels:
         landmarks[key] = np.linspace(0, 100, 13)

     # The same instance is used twice: standalone and inside ``OneOf``.
     deformation = torchio.RandomElasticDeformation(max_displacement=max_disp)

     pipeline = [
         torchio.CropOrPad(target_shape),
         torchio.ToCanonical(),
         torchio.RandomDownsample(downsampling=(1.75, 2), axes=spatial_axes),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=spatial_axes, flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_size, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks),
         deformation,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             deformation: 1,
         }),
         torchio.Pad(padding, padding_mode=3),
         torchio.Crop(cropping),
     ]
     if not labels:
         return torchio.Compose(pipeline)
     pipeline.append(torchio.RandomLabelsToImage(label_key='label'))
     return torchio.Compose(pipeline)
示例#30
0
 def test_batch_history(self):
     # https://github.com/fepegar/torchio/discussions/743
     subject = self.sample_subject
     transform = tio.Compose([
         tio.RandomAffine(),
         tio.CropOrPad(5),
         tio.OneHot(),
     ])
     dataset = tio.SubjectsDataset([subject], transform=transform)
     loader = torch.utils.data.DataLoader(
         dataset,
         collate_fn=tio.utils.history_collate
     )
     batch = tio.utils.get_first_item(loader)
     transformed: tio.Subject = tio.utils.get_subjects_from_batch(batch)[0]
     inverse = transformed.apply_inverse_transform()
     images1 = subject.get_images(intensity_only=False)
     images2 = inverse.get_images(intensity_only=False)
     for image1, image2 in zip(images1, images2):
         assert image1.shape == image2.shape