Example 1
    def __getitem__(self, idx):

        # Note: the DataLoader's collate_fn is modified so that it filters out None elements (see the sketch after this example).
        img_name = self.list_images[idx]

        try:
            _, image, spacing = load_reorient(img_name)
        except Exception:
            spacing = [2, 2, 2]
            print('error loading {0}'.format(img_name))
            if self.remove_corrupt and os.path.isfile(img_name):
                os.remove(img_name)
            image, _ = self.get_random(fa=True)

        if not np.any(image):
            image, _ = self.get_random(fa=True)

        #preprocessing
        original_spacing = np.array(spacing)
        get_foreground = tio.ZNormalization.mean
        target_shape = 128, 128, 128
        crop_pad = tio.CropOrPad(target_shape)

        ###operations###
        standardize = tio.ZNormalization(masking_method=get_foreground)
        if 'wb' not in self.class_names and 'abd-pel' not in self.class_names:
            downsample = tio.Resample(
                (2 / spacing[0], 2 / spacing[1], 2 / spacing[2]))
            try:
                image = standardize(crop_pad(downsample(image)))
            except Exception:
                # the image is left unprocessed if the transform chain fails
                print(img_name)
        else:
            x, y, z = image.shape[1:] / np.asarray(target_shape)
            downsample = tio.Resample((x, y, z))
            image = standardize(crop_pad(downsample(image)))

        if image.shape[0] > 1:
            # keep only the first channel if the image has several
            image = np.expand_dims(image[0, :, :, :], axis=0)
            print(img_name)

        sample = {
            'image': torch.from_numpy(image),
            'label': np.array(0),
            'spacing': original_spacing,
            'fn': img_name
        }

        return sample
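
A minimal sketch of the collate_fn mentioned in the comment above, assuming samples that fail to load are returned as None (the helper name filter_none_collate is hypothetical):

from torch.utils.data.dataloader import default_collate

def filter_none_collate(batch):
    # drop samples that failed to load before the default collation runs
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)

# usage: DataLoader(dataset, batch_size=4, collate_fn=filter_none_collate)
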
Example 2
 def test_different_spaces(self):
     t1 = self.sample_subject.t1
     label = tio.Resample(2)(self.sample_subject.label)
     new_subject = tio.Subject(t1=t1, label=label)
     with self.assertRaises(RuntimeError):
         tio.RandomAffine()(new_subject)
     tio.RandomAffine(check_shape=False)(new_subject)
Example 3
 def test_bad_affine(self):
     shape = 1, 2, 3
     affine = np.eye(3)
     target = shape, affine
     transform = tio.Resample(target)
     with self.assertRaises(RuntimeError):
         transform(self.sample_subject)
Example 4
 def test_spacing(self):
     # Should this raise an error if sizes are different?
     spacing = 2
     transform = tio.Resample(spacing)
     transformed = transform(self.sample_subject)
     for image in transformed.get_images(intensity_only=False):
         self.assertEqual(image.spacing, 3 * (spacing, ))
Example 5
 def test_reference_path(self):
     reference_image, reference_path = self.get_reference_image_and_path()
     transform = tio.Resample(reference_path)
     transformed = transform(self.sample_subject)
     for image in transformed.values():
         self.assertEqual(reference_image.shape, image.shape)
         self.assertTensorAlmostEqual(reference_image.affine, image.affine)
Example 6
 def test_transforms(self):
     landmarks_dict = dict(
         t1=np.linspace(0, 100, 13),
         t2=np.linspace(0, 100, 13),
     )
     elastic = torchio.RandomElasticDeformation(max_displacement=1)
     transforms = (
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1
         }),
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     )
     transform = torchio.Compose(transforms)
     transform(self.sample)
Example 7
 def test_reference_name(self):
     subject = self.get_inconsistent_shape_subject()
     reference_name = 't1'
     transform = tio.Resample(reference_name)
     transformed = transform(subject)
     reference_image = subject[reference_name]
     for image in transformed.get_images(intensity_only=False):
         self.assertEqual(reference_image.shape, image.shape)
         self.assertTensorAlmostEqual(reference_image.affine, image.affine)
Example 8
 def test_affine(self):
     spacing = 1
     affine_name = 'pre_affine'
     transform = tio.Resample(spacing, pre_affine_name=affine_name)
     transformed = transform(self.sample_subject)
     for image in transformed.values():
         if affine_name in image:
             target_affine = np.eye(4)
             target_affine[:3, 3] = 10, 0, -0.1
             self.assertTensorAlmostEqual(image.affine, target_affine)
         else:
             self.assertTensorEqual(image.affine, np.eye(4))
Example 9
def inference(dataset, model):

    for i in range(len(dataset)):
        subject = dataset[i]
        untransformed_subject = dataset.subjects[i]

        print(f"Running model for subject {subject['name']}")

        out_folder = args.out_folder
        if out_folder == "":
            out_folder = Path(subject["folder"])
        else:
            out_folder = Path(args.out_folder) / subject['name']
            out_folder.mkdir(exist_ok=True, parents=True)

        with torch.no_grad():
            subject = patch_predict(model=model,
                                    device=device,
                                    subjects=[subject],
                                    patch_batch_size=1,
                                    patch_size=96,
                                    patch_overlap=48,
                                    padding_mode="edge",
                                    overlap_mode="average")[0]

        transform = subject.get_composed_history()
        inverse_transform = transform.inverse(warn=False)

        pred_subject = tio.Subject({'y': subject['y_pred']})
        inverse_pred_subject = inverse_transform(pred_subject)
        output_label = inverse_pred_subject.get_first_image()

        label_data = output_label['data'][0].numpy()

        label_data, hole_voxels_removed = remove_holes(label_data, hole_size=64)
        print(f"Filled {hole_voxels_removed} voxels from detected holes.")

        label_data, small_lesions_removed = remove_small_components(label_data, 3)
        print(f"Removed {small_lesions_removed} voxels from small predictions less than size 3.")

        label_data = torch.from_numpy(label_data[None]).to(torch.int32)
        output_label.set_data(label_data)

        target_image = untransformed_subject.get_first_image()
        output_label = tio.Resample(target_image)(output_label)

        if output_label.spatial_shape != target_image.spatial_shape:
            raise Warning("Segmentation shape and original image shape do not match.")

        print()

        output_label.save(out_folder / args.output_filename)
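
remove_holes and remove_small_components are project-specific helpers. A plausible sketch of remove_holes with scipy.ndimage, assuming the (data, voxels_changed) return signature implied by the calls above:

import numpy as np
from scipy import ndimage

def remove_holes(label_data, hole_size=64):
    # fill background components smaller than hole_size voxels (assumed semantics)
    labeled, num_components = ndimage.label(label_data == 0)
    voxels_filled = 0
    for component in range(1, num_components + 1):
        mask = labeled == component
        if mask.sum() < hole_size:
            label_data[mask] = 1
            voxels_filled += int(mask.sum())
    return label_data, voxels_filled

# remove_small_components would make the analogous pass over foreground components.
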
Example 10
    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.df.iloc[index, 1])
        subject = tio.Subject(img=Image(img_path, type=tio.INTENSITY))

        if self.transform:
            # in training phase
            transformations = (
                tio.ZNormalization(),
                tio.Resample(target=2,
                             pre_affine_name="affine"),  # preprocessing
                tio.OneOf(transforms_dict),
                tio.OneOf(transforms_dict2))
        else:
            # in validation and testing phase
            transformations = (
                tio.ZNormalization(),
                tio.Resample(target=2,
                             pre_affine_name="affine")  # preprocessing
            )

        transformations = Compose(transformations)
        transformed_image = transformations(subject)

        get_image = transformed_image.img
        tensor_resampled_image = get_image.data
        tensor_resampled_image = tensor_resampled_image.unsqueeze(
            dim=0)  # add a batch dimension

        resampled_image = torch.nn.functional.interpolate(
            input=tensor_resampled_image,
            size=(256, 256, 166),
            mode='trilinear'
        )  # trilinear interpolation expects a 5D input (mini-batch, channels, height, width, depth)
        resampled_image = resampled_image.reshape(1, 256, 256, 166)  # drop the batch dimension

        y_label = 0.0 if self.df.iloc[index, 2] == 'AD' else 1.0
        y_label = torch.tensor(y_label, dtype=torch.float)

        return resampled_image, y_label
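
The manual interpolate/reshape step above can also be expressed inside the TorchIO pipeline: tio.Resize resamples an image to a fixed spatial shape. A sketch (whether it reproduces the exact trilinear result depends on the chosen interpolation):

transformations = Compose((
    tio.ZNormalization(),
    tio.Resample(target=2, pre_affine_name="affine"),
    tio.Resize((256, 256, 166), image_interpolation='linear'),
))
transformed_image = transformations(subject)
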
Example 11
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    if mni_transform_path is None:
        image = tio.ScalarImage(input_path)
    else:
        affine = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: affine})
    subject = tio.Subject({IMAGE_NAME: image})
    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    hist_std = tio.HistogramStandardization({IMAGE_NAME: landmarks})
    preprocess_transforms = [
        tio.ToCanonical(),
        hist_std,
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]
    zooms = nib.load(input_path).header.get_zooms()
    pixdim = np.array(zooms)
    diff_to_1_iso = np.abs(pixdim - 1)
    if np.any(diff_to_1_iso > tolerance) or mni_transform_path is not None:
        kwargs = {'image_interpolation': interpolation}
        if mni_transform_path is not None:
            kwargs['pre_affine_name'] = TO_MNI
            kwargs['target'] = tio.datasets.Colin27().t1.path
        resample_transform = tio.Resample(**kwargs)
        preprocess_transforms.append(resample_transform)
    preprocess_transforms.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess_transform = tio.Compose(preprocess_transforms)
    no_aug_dataset = tio.SubjectsDataset([subject],
                                         transform=preprocess_transform)

    aug_subjects = tta_iterations * [subject]
    if not aug_subjects:
        return no_aug_dataset
    augment_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    aug_dataset = tio.SubjectsDataset(aug_subjects,
                                      transform=augment_transform)
    dataset = torch.utils.data.ConcatDataset((no_aug_dataset, aug_dataset))
    return dataset
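
Standalone usage could look like this (the input path is assumed; the identity collate_fn mirrors the one used by segment_resection in Example 24):

dataset = get_dataset('t1.nii.gz', tta_iterations=4)
loader = torch.utils.data.DataLoader(dataset, collate_fn=lambda x: x)
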
Example 12
 def get_transform(self, channels, is_3d=True, labels=True):
     landmarks_dict = {
         channel: np.linspace(0, 100, 13)
         for channel in channels
     }
     disp = 1 if is_3d else (1, 1, 0.01)
     elastic = tio.RandomElasticDeformation(max_displacement=disp)
     cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
     resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
     flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
     swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
     pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
     crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
     remapping = {1: 2, 2: 1, 3: 20, 4: 25}
     transforms = [
         tio.CropOrPad(cp_args),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_args),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=flip_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks_dict),
         elastic,
         tio.RandomAffine(),
         tio.OneOf({
             tio.RandomAffine(): 3,
             elastic: 1,
         }),
         tio.RemapLabels(remapping=remapping, masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(pad_args, padding_mode=3),
         tio.Crop(crop_args),
     ]
     if labels:
         transforms.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(transforms)
Example 13
    def apply_transform(self, subject):
        current_spacing = subject.get_first_image().spacing

        if isinstance(self.target_spacing, str):
            target_spacing = self.spacing_modes[self.target_spacing](
                current_spacing)
            target_spacing = (target_spacing, target_spacing, target_spacing)
        else:
            target_spacing = self.target_spacing

        # No operation if all current spacings are within tolerance of the target spacing
        if all(
                abs(cur - tar) < tol for cur, tar, tol in zip(
                    current_spacing, target_spacing, self.tolerance)):
            return subject

        # Iteratively search for a spacing that is a simple rational multiple of the current spacing (denominator = step) and lies within tolerance of the target
        new_spacing = []
        for cur, tar, tol in zip(current_spacing, target_spacing,
                                 self.tolerance):
            step = 1

            spacing = cur

            while abs(spacing - tar) > tol:

                if cur < tar:
                    scale = tar / cur
                    scale = round(scale * step) / step
                else:
                    scale = cur / tar
                    scale = 1 / (round(scale * step) / step)

                spacing = cur * scale
                step += 1

            new_spacing.append(spacing)

        new_spacing = tuple(new_spacing)
        resample = tio.Resample(target=new_spacing,
                                image_interpolation=self.image_interpolation,
                                pre_affine_name=self.pre_affine_name,
                                scalars_only=self.scalars_only,
                                **self.kwargs)
        subject = resample(subject)

        return subject
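
A standalone trace of the rounding loop above, with illustrative values: for a current spacing of 0.9 mm, a target of 1.0 mm and a tolerance of 0.05 mm, the loop stops at step 6, where the rounded scale 7/6 yields a spacing of about 1.05 mm:

cur, tar, tol = 0.9, 1.0, 0.05  # illustrative values
spacing, step = cur, 1
while abs(spacing - tar) > tol:
    scale = tar / cur if cur < tar else cur / tar
    scale = round(scale * step) / step  # rational approximation with denominator step
    if cur >= tar:
        scale = 1 / scale
    spacing = cur * scale
    step += 1
print(spacing)  # ~1.05, the first simple rational multiple of 0.9 within tolerance
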
Example 14
 def get_transform(self, channels, is_3d=True, labels=True):
     landmarks_dict = {
         channel: np.linspace(0, 100, 13)
         for channel in channels
     }
     disp = 1 if is_3d else (1, 1, 0.01)
     elastic = torchio.RandomElasticDeformation(max_displacement=disp)
     cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
     flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
     swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
     pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
     crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
     transforms = [
         torchio.CropOrPad(cp_args),
         torchio.ToCanonical(),
         torchio.RandomDownsample(downsampling=(1.75, 2),
                                  axes=axes_downsample),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=flip_axes, flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1,
         }),
         torchio.Pad(pad_args, padding_mode=3),
         torchio.Crop(crop_args),
     ]
     if labels:
         transforms.append(torchio.RandomLabelsToImage(label_key='label'))
     return torchio.Compose(transforms)
Example 15
    def preprocess(self, raw):

        resample = tio.Resample(raw.shape[0] / self.dim[0],
                                image_interpolation='bspline')
        resampled = resample(np.expand_dims(raw, 0))

        # #fix if image is bigger than fixed third dimension
        # if raw.shape[2] > self.upper:
        #     # crop the image along the third dimension
        #     dh = int(raw.shape[2] - self.upper / 2)
        #     raw = raw[:, :, dh:-dh]
        #
        # # resizing
        # if self.dim[:2] != raw.shape[:-1]:
        #     output = np.zeros((self.dim[0], self.dim[1], self.dim[2]))
        #     for i in range(raw.shape[-1]):
        #         try:
        #             output[:, :, i] = cv2.resize(raw[:, :, i].astype('float32'), (self.dim[0], self.dim[1]))
        #         except Exception as e:
        #             pass
        #             #print(e)
        #     raw = output
        #
        # # check the third dimension
        # if raw.shape[2] < self.upper:
        #
        #     ##pad the image with zeros
        #
        #     if (self.upper - raw.shape[2]) % 2 == 0:
        #         w = int((self.upper - raw.shape[-1]) / 2)
        #         u = np.zeros((self.dim[0], self.dim[1], w))
        #         raw = np.concatenate((u, raw, u), axis=-1)
        #     else:
        #         w = int((self.upper - raw.shape[-1] - 1) / 2)
        #         u = np.zeros((self.dim[0], self.dim[1], w))
        #         d = np.zeros((self.dim[0], self.dim[1], w + 1))
        #         raw = np.concatenate((u, raw, d), axis=-1)

        # print(f"[DEBUG2] {raw.shape}")

        return resampled
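
The commented-out crop/resize/pad logic is essentially what tio.CropOrPad automates. A sketch of an equivalent preprocessing step, assuming self.dim holds the target shape:

def preprocess_croppad(self, raw):
    # resample as above, then crop or pad symmetrically to the target shape
    resample = tio.Resample(raw.shape[0] / self.dim[0],
                            image_interpolation='bspline')
    crop_pad = tio.CropOrPad(tuple(self.dim))
    return crop_pad(resample(np.expand_dims(raw, 0)))
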
Example 16
    def test_dataset(self, modalities):
        """
        Testing the Dataset class
        Note that test is not for 
        """
        output_path_mod = pathlib.Path(self.output_path, str("temp_folder_" + ("_").join(modalities.split(",")))).as_posix()
        comp_path = pathlib.Path(output_path_mod).resolve().joinpath('dataset.csv').as_posix()
        comp_table = pd.read_csv(comp_path, index_col=0)
        print(comp_path, comp_table)
        
        #Loading from nrrd files
        subjects_nrrd = Dataset.load_image(output_path_mod, ignore_multi=True)
        #Loading files directly
        # subjects_direct = Dataset.load_directly(self.input_path,modalities=modalities,ignore_multi=True)
        
        #The number of subjects is equal to the number of components which is 2 for this dataset
        # assert len(subjects_nrrd) == len(subjects_direct) == 2, "There was some error in generation of subject object"
        # assert subjects_nrrd[0].keys() == subjects_direct[0].keys()

        # del subjects_direct
        # To check if all metadata items present in the keys
        # temp_nrrd = subjects_nrrd[0]
        # columns_shdbe_present = set([col if col.split("_")[0]=="metadata" else "mod_"+("_").join(col.split("_")[1:]) for col in list(comp_table.columns) if col.split("_")[0] in ["folder","metadata"]])
        # print(columns_shdbe_present)
        # assert set(temp_nrrd.keys()).issubset(columns_shdbe_present), "Not all items present in dictionary, some fault in going through the different columns in a single component"

        transforms = tio.Compose([tio.Resample(4), tio.CropOrPad((96,96,40)), select_roi_names(["larynx"]), tio.OneHot()])

        #Forming dataset and dataloader
        test_set = tio.SubjectsDataset(subjects_nrrd, transform=transforms)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=2, shuffle=True, collate_fn=collate_fn)

        # Check that test_set is correct
        assert len(test_set) == 2

        # Get items from the test loader.
        # If this call fails, there is an error in the formation of the test set.
        data = next(iter(test_loader))
        A = [item[1].shape == (2, 1, 96, 96, 40) if "RTSTRUCT" not in item[0] else item[1].shape == (2, 2, 96, 96, 40) for item in data.items()]
        assert all(A), "There is some problem in the transformation/the formation of the subject object"
Example 17
def main(input_path, checkpoint_path, output_dir, landmarks_path, batch_size, num_workers, resample):
    import torch
    from tqdm import tqdm
    import torchio as tio
    import models
    import datasets

    fps = get_paths(input_path)
    subjects = [tio.Subject(image=tio.ScalarImage(fp)) for fp in fps]  # key must be 'image' as in get_test_transform
    transform = tio.Compose((
        tio.ToCanonical(),
        datasets.get_test_transform(landmarks_path),
    ))
    if resample:
        transform = tio.Compose((
            tio.Resample(),
            transform,
            # tio.CropOrPad((264, 268, 144)),  # for BITE?
        ))
    dataset = tio.SubjectsDataset(subjects, transform)
    checkpoint = torch.load(checkpoint_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = models.get_unet().to(device)
    model.load_state_dict(checkpoint['model'])
    output_dir = Path(output_dir)
    model.eval()
    torch.set_grad_enabled(False)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    output_dir.mkdir(exist_ok=True, parents=True)
    for batch in tqdm(loader):
        inputs = batch['image'][tio.DATA].float().to(device)
        seg = model(inputs).softmax(dim=1)[:, 1:].cpu() > 0.5
        for tensor, affine, path in zip(seg, batch['image'][tio.AFFINE], batch['image'][tio.PATH]):
            image = tio.LabelMap(tensor=tensor, affine=affine.numpy())
            path = Path(path)
            out_path = output_dir / path.name.replace('.nii', '_seg_cnn.nii')
            image.save(out_path)
    return 0
Example 18
def save_feature_maps(input_path, output_dir):
    torch.set_grad_enabled(False)
    output_dir = Path(output_dir).expanduser().absolute()
    output_dir.mkdir(exist_ok=True, parents=True)
    device = get_device()
    repo = 'fepegar/resseg'
    model_name = 'ressegnet'
    model = torch.hub.load(repo, model_name)
    model.to(device)
    model.eval()
    hooks = []
    for args, module in get_all_modules(model):
        hook = module.register_forward_hook(get_activation(args))
        hooks.append(hook)  # keep the handles so the hooks can be removed later
    preprocessed_subject = get_dataset(input_path)[0]
    image = preprocessed_subject[IMAGE_NAME]
    inputs = image.data.unsqueeze(0).float().to(device)
    with torch.cuda.amp.autocast():
        model(inputs)

    downsampled = [image]
    for _ in range(2):
        target = torch.Tensor(downsampled[-1].spacing) * 2
        target = tuple(target.tolist())
        transform = tio.Resample(target, image_interpolation='nearest')
        downsampled_image = transform(downsampled[-1])
        downsampled.append(downsampled_image)
    for args, features in tqdm(activation.items()):
        part, level, conv_layer = args
        affine = downsampled[level].affine
        for i, feature_map in enumerate(tqdm(features[0], leave=False)):
            features_image = tio.ScalarImage(
                tensor=feature_map.unsqueeze(0).cpu().float(),
                affine=affine,
            )
            name = f'{part}_level_{level}_layer_{conv_layer}_feature_{i}.nii.gz'
            path = output_dir / name
            features_image.save(path)
Example 19
subject = tio.Subject(
    hr=tio.ScalarImage(image_file),
    lr=tio.ScalarImage(image_file),
)

# normalization = tio.ZNormalization(masking_method='label')  # or masking_method=tio.ZNormalization.mean
normalization = tio.ZNormalization()
onehot = tio.OneHot()

spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0)

bias = tio.RandomBiasField(coefficients=0.5, p=0)
flip = tio.RandomFlip(axes=('LR', ), p=1)
noise = tio.RandomNoise(std=0.1, p=0)

sampling_1mm = tio.Resample(1)
sampling_05mm = tio.Resample(0.5)
blur = tio.RandomBlur(0.5)

sampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr')
blur_jess = tio.Blur(std=(0.001, 0.001, 1), exclude='hr')
downsampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr')
upsampling_jess = tio.Resample(target='hr', exclude='hr')

tocanonical = tio.ToCanonical()
crop1 = tio.CropOrPad((290, 290, 200))
crop2 = tio.CropOrPad((290, 290, 200), include='lr')

#transforms = tio.Compose([spatial, bias, flip, normalization, noise])
#transforms = tio.Compose([normalization, sampling_1mm, noise, blur, sampling_05mm])
#transforms = tio.Compose([blur_jess,sampling_jess])
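
Only one of the commented compositions would be active at a time; following their pattern, the low-resolution branch could be wired up as (a sketch):

transforms = tio.Compose([tocanonical, normalization, blur_jess,
                          downsampling_jess, upsampling_jess])
transformed = transforms(subject)  # degrades 'lr' while leaving 'hr' intact
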
Example 20
subjects = []
for (image_path, label_path) in zip(df['imagepath'], df['maskpath']):
    subject = tio.Subject(image=tio.ScalarImage(image_path),
                          mask=tio.LabelMap(label_path))
    subjects.append(subject)

# assemble the special torchio dataset of patients
dataset = tio.SubjectsDataset(subjects)

# reduce the mask to a single class
if config.to_one_class:
    for subject in dataset.dry_iter():
        subject['mask'] = one(subject['mask'])

training_transform = tio.Compose([
    tio.Resample(4),
    tio.ZNormalization(
        masking_method=tio.ZNormalization.mean
    ),  # the setting everyone recommended on the torchio forums
    tio.RandomFlip(p=0.25),
    tio.RandomNoise(p=0.25),
    # !!! the tensors have to be force-cast to float
    tio.Lambda(to_float)
])

validation_transform = tio.Compose([
    tio.Resample(4),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    tio.RandomNoise(p=0.25),
    tio.Lambda(to_float)
])
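
These composed transforms would then typically be attached to subject datasets, e.g.:

training_set = tio.SubjectsDataset(subjects, transform=training_transform)
validation_set = tio.SubjectsDataset(subjects, transform=validation_transform)
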
Example 21
    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.df.iloc[index, 0])  # 1
        image = nib.load(img_path)

        img_shape = image.shape
        target_img_shape = (1, 256, 256, 166)
        target_img_affine = "[[3.22155614e-08  2.46488298e-04 - 1.20344027e+00  9.32926025e+01], \
         [-2.46255622e-04 - 9.42077751e-01 - 3.14872030e-04  1.56945999e+02], \
         [-9.41189972e-01  2.46487912e-04  4.11920069e-08  1.12556000e+02], \
         [0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]"

        resampled_img_data = image.get_fdata()
        resampled_data_arr = np.asarray(resampled_img_data)

        # add a channel dimension: (D, H, W) -> (1, D, H, W)
        resampled_data_arr = np.expand_dims(resampled_data_arr, axis=0)
        if resampled_data_arr.shape != target_img_shape:
            """
            resampled_nii = resample_img(image, target_affine=np.eye(4) * 2, target_shape=target_img_shape,
                                         interpolation='nearest')
            resampled_img_data = resampled_nii.get_fdata()
            """
            transform = tio.Resample(4)
            resampled_img_data = transform(resampled_data_arr)
        else:
            resampled_img_data = image.get_fdata()

        resampled_data_arr = np.asarray(resampled_img_data)

        # min_max_normalization
        resampled_data_arr -= np.min(resampled_data_arr)
        resampled_data_arr /= np.max(resampled_data_arr)

        #if self.transform:
        #    resampled_data_arr = self.transform(resampled_data_arr)



        if self.transform:
            #flip = tio.RandomFlip(axes=('P',), flip_probability=1) #Please use one of: L, R, P, A, I, S, T, B
            #resampled_data_arr = flip(resampled_data_arr)
            resampled_data_arr = transform_flip(resampled_data_arr)

        y_label = 0.0 if self.df.iloc[index, 1] == 'AD' else 1.0  # because we use cross entropy  #1

        # y_label = [1.0, 0.0] if (self.df.iloc[index, 1] == 'AD') else [0.0, 1.0]  # for other cross entropy #2

        y_label = torch.tensor(y_label, dtype=torch.float)

        return resampled_data_arr, y_label
Example 22
            lr_3=tio.ScalarImage(t2_file),
        )

    subjects.append(subject)

print('DHCP Dataset size:', len(subjects), 'subjects')

# DATA AUGMENTATION
normalization = tio.ZNormalization()
spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75)
flip = tio.RandomFlip(axes=('LR', ), flip_probability=0.5)

tocanonical = tio.ToCanonical()

b1 = tio.Blur(std=(0.001, 0.001, 1), include='lr_1')  #blur
d1 = tio.Resample((0.8, 0.8, 2), include='lr_1')  #downsampling
u1 = tio.Resample(target='hr', include='lr_1')  #upsampling

if in_channels == 3:
    b2 = tio.Blur(std=(0.001, 1, 0.001), include='lr_2')  #blur
    d2 = tio.Resample((0.8, 2, 0.8), include='lr_2')  #downsampling
    u2 = tio.Resample(target='hr', include='lr_2')  #upsampling

    b3 = tio.Blur(std=(1, 0.001, 0.001), include='lr_3')  #blur
    d3 = tio.Resample((2, 0.8, 0.8), include='lr_3')  #downsampling
    u3 = tio.Resample(target='hr', include='lr_3')  #upsampling

if in_channels == 1:
    transforms = [tocanonical, flip, spatial, normalization, b1, d1, u1]
    training_transform = tio.Compose(transforms)
    validation_transform = tio.Compose(
Example 23
 def test_image_target(self):
     tio.Resample(self.sample_subject.t1)(self.sample_subject)
Example 24
def segment_resection(
    input_path,
    model,
    output_path=None,
    tta_iterations=0,
    interpolation='bspline',
    num_workers=0,
    show_progress=True,
    binarize=True,
    postprocess=True,
    mni_transform_path=None,
):
    dataset = get_dataset(
        input_path,
        tta_iterations,
        interpolation,
        mni_transform_path=mni_transform_path,
    )

    device = get_device()
    model.to(device)
    model.eval()
    torch.set_grad_enabled(False)
    loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=lambda x: x,
    )
    all_results = []
    for subjects_list_batch in tqdm(loader, disable=not show_progress):
        tensors = [
            subject[IMAGE_NAME][tio.DATA] for subject in subjects_list_batch
        ]
        inputs = torch.stack(tensors).float().to(device)
        with torch.cuda.amp.autocast():
            try:
                probs = model(inputs).softmax(
                    dim=1)[:, 1:].cpu()  # discard background
            except Exception as e:
                print(e)
                raise
        iterable = list(zip(subjects_list_batch, probs))
        for subject, prob in tqdm(iterable, leave=False, unit='subject'):
            subject.image.set_data(prob)
            subject_back = subject.apply_inverse_transform(
                warn=False, image_interpolation='linear')
            all_results.append(subject_back.image.data)
    result = torch.stack(all_results)
    mean_prob = result.mean(dim=0)
    if binarize:
        mean_prob = (mean_prob >= 0.5).byte()
        class_ = tio.LabelMap
    else:
        class_ = tio.ScalarImage
    image_kwargs = {
        'tensor': mean_prob,
        'affine': subject_back[IMAGE_NAME].affine,
    }
    resample_kwargs = {
        'target': input_path,
        'image_interpolation': interpolation,
    }
    if mni_transform_path is not None:
        to_mni = tio.io.read_matrix(mni_transform_path)
        from_mni = np.linalg.inv(to_mni)
        image_kwargs[FROM_MNI] = from_mni
        resample_kwargs['pre_affine_name'] = FROM_MNI
    image = class_(**image_kwargs)
    if postprocess:
        image = tio.KeepLargestComponent()(image)
    resample = tio.Resample(**resample_kwargs)
    image_native = resample(image)
    if output_path is None:
        input_path = Path(input_path)
        split = input_path.name.split('.')
        stem = split[0]
        exts_string = '.'.join(split[1:])
        output_path = input_path.parent / f'{stem}_seg.{exts_string}'
    image_native.save(output_path)
Example 25
                pbar.write(
                    f"\tFilled {hole_voxels_removed} voxels from detected holes."
                )

            out = torch.from_numpy(out).unsqueeze(0)
            out = out.int()
            image = tio.LabelMap(tensor=out)

        else:
            # saving raw probabilities is not implemented yet:
            # image = tio.ScalarImage(tensor=probs)
            raise NotImplementedError

        # inverse_transforms = subject.get_composed_history().inverse(warn=False)
        # image = inverse_transforms(image)

        orig_subject = dataset.subjects_map[
            subject["name"]]  # access subject without applying transformations

        # compare shapes ignoring the channel dimension
        if image.shape[1:] != orig_subject.shape[1:]:
            resample_transform = tio.Resample(orig_subject.get_images()[0])
            image = resample_transform(image)

        assert orig_subject.shape[1:] == image.shape[
            1:], "Segmentation shape and original image shape do not match"

        pbar.write("\tSaving image...")
        image.save(out_folder / args.output_filename)
        pbar.write("\tFinished subject")
        pbar.update(1)
Example 26
    def __getitem__(self, index):
        file_npy = self.df.iloc[index][0]
        assert os.path.exists(file_npy), f'npy file {file_npy} does not exist'
        array_npy = np.load(file_npy)  # shape (D,H,W)
        if array_npy.ndim > 3:
            array_npy = np.squeeze(array_npy)
        array_npy = np.expand_dims(array_npy, axis=0)  #(C,D,H,W)

        # if depth_interval == 2: (128,128,128) -> (64,128,128)
        depth_start_random = random.randint(0, 20) % self.depth_interval
        array_npy = array_npy[:, depth_start_random::self.depth_interval, :, :]

        subject1 = tio.Subject(oct=tio.ScalarImage(tensor=array_npy), )
        subjects_list = [subject1]

        crop_h = random.randint(0, self.random_crop_h)
        # pad_h_a, pad_h_b = math.floor(crop_h / 2), math.ceil(crop_h / 2)
        pad_h_a = random.randint(0, crop_h)
        pad_h_b = crop_h - pad_h_a

        transform_1 = tio.Compose([
            # tio.OneOf({
            #     tio.RandomAffine(): 0.8,
            #     tio.RandomElasticDeformation(): 0.2,
            # }, p=0.75,),
            # tio.RandomGamma(log_gamma=(-0.3, 0.3)),
            tio.RandomFlip(axes=2, flip_probability=0.5),
            # tio.RandomAffine(
            #     scales=(0, 0, 0.9, 1.1, 0, 0), degrees=(0, 0, -5, 5, 0, 0),
            #     image_interpolation='nearest'),
            tio.Crop(cropping=(0, 0, crop_h, 0, 0, 0)),  # (d,h,w) crop height
            tio.Pad(padding=(0, 0, pad_h_a, pad_h_b, 0, 0)),
            tio.RandomNoise(std=(0, self.random_noise)),
            tio.Resample(self.resample_ratio),
            # tio.RescaleIntensity((0, 255))
        ])

        if random.randint(1, 20) == 5:
            transform = tio.Compose([tio.Resample(self.resample_ratio)])
        else:
            transform = transform_1

        subjects_dataset = tio.SubjectsDataset(subjects_list,
                                               transform=transform)

        inputs = subjects_dataset[0]['oct'][tio.DATA]
        array_3d = np.squeeze(inputs.cpu().numpy())  #shape: (D,H,W)
        array_3d = array_3d.astype(np.uint8)

        if self.imgaug_iaa is not None:
            self.imgaug_iaa.deterministic = True
        else:
            if (self.image_shape is None) or\
                    (array_3d.shape[1:3]) == (self.image_shape[0:2]):  # (H,W)
                array_4d = np.expand_dims(array_3d, axis=-1)  #(D,H,W,C)

        if 'array_4d' not in locals():
            list_images = []
            for i in range(array_3d.shape[0]):
                img = array_3d[i, :, :]  #(H,W)
                if (img.shape[0:2]) != (self.image_shape[0:2]):  # (H,W)
                    img = cv2.resize(
                        img, (self.image_shape[1],
                              self.image_shape[0]))  # resize(width,height)

                # cvtColor does not support float64
                img = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_GRAY2BGR)
                # otherwise MultiplyBrightness raises an error
                img = img.astype(np.uint8)
                if self.imgaug_iaa is not None:
                    img = self.imgaug_iaa(image=img)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                list_images.append(img)

            array_4d = np.array(list_images)  # (D,H,W)
            array_4d = np.expand_dims(array_4d, axis=-1)  #(D,H,W,C)

        if self.imgaug_iaa is not None:
            self.imgaug_iaa.deterministic = False

        if self.channel_first:
            array_4d = np.transpose(array_4d,
                                    (3, 0, 1, 2))  #(D,H,W,C)->(C,D,H,W)

        array_4d = array_4d.astype(np.float32)
        array_4d = array_4d / 255.
        # if array_4d.shape != (1, 64, 64, 64):
        #     print(file_npy)

        # https://pytorch.org/docs/stable/data.html
        # It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing).
        # tensor_x = torch.from_numpy(array_4d)

        label = int(self.df.iloc[index][1])

        return array_4d, label
Example 27
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.hparams.cuda else "cpu")

        if self.hparams.modelID == 0:
            self.net = ResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks,
                              starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks,
                              is_relu_leaky=self.hparams.model_relu_leaky, do_batchnorm=self.hparams.model_do_batchnorm,
                              res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV,
                              upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D)  # TODO think of 2D
            # self.net = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
        elif self.hparams.modelID == 2:
            self.net = DualSpaceResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks,
                                        starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks,
                                        is_relu_leaky=self.hparams.model_relu_leaky, do_batchnorm=self.hparams.model_do_batchnorm,
                                        res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV,
                                        upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D,
                                        connect_mode=self.hparams.model_dspace_connect_mode, inner_norm_ksp=self.hparams.model_inner_norm_ksp)
        elif self.hparams.modelID == 3: #Primal-Dual Network, complex Primal
            self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = True,
                            g_normtype = "magmax",
                            transform = "Fourier",
                            return_abs = True)
        elif self.hparams.modelID == 4: #Primal-Dual Network, absolute Primal
            self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = False,
                            g_normtype = "magmax",
                            transform = "Fourier")
        elif self.hparams.modelID == 5: #Primal-Dual UNet Network, absolute Primal
            self.net = PrimalDualNetwork(n_primary=4, n_dual=5, n_iterations=2,
                            use_original_block = False,
                            use_original_init = False,
                            use_complex_primal = False,
                            g_normtype = "magmax",
                            transform = "Fourier")
        elif self.hparams.modelID == 6: #Primal-Dual Network v2 (no residual), complex Primal
            self.net = PrimalDualNetworkNoResidue(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = True,
                            residuals=False,
                            g_normtype = "magmax",
                            transform = "Fourier",
                            return_abs = True)

        else:
            # TODO: other models
            sys.exit("Only ReconResNet and DualSpaceResNet have been implemented so far in ReconEngine")

        if bool(self.hparams.preweights_path):
            print("Pre-weights found, loding...")
            chk = torch.load(self.hparams.preweights_path, map_location='cpu')
            self.net.load_state_dict(chk['state_dict'])

        if self.hparams.lossID == 0:
            if self.hparams.in_channels != 1 or self.hparams.out_channels != 1:
                sys.exit(
                    "Perceptual Loss used here only works for 1 channel input and output")
            self.loss = PerceptualLoss(device=device, loss_model="unet3Dds", resize=None,
                                       loss_type=self.hparams.ploss_type, n_level=self.hparams.ploss_level)  # TODO thinkof 2D
        elif self.hparams.lossID == 1:
            self.loss = nn.L1Loss(reduction='mean')
        elif self.hparams.lossID == 2:
            self.loss = MS_SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        elif self.hparams.lossID == 3:
            self.loss = SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        else:
            sys.exit("Invalid Loss ID")

        self.dataspace = DataSpaceHandler(**self.hparams)

        if self.hparams.ds_mode == 0:
            trans = tioTransforms
            augs = tioAugmentations
        elif self.hparams.ds_mode == 1:
            trans = pytTransforms
            augs = pytAugmentations

        # TODO: parameterise everything
        self.init_transforms = []
        self.aug_transforms = []
        self.transforms = []
        if self.hparams.ds_mode == 0 and self.hparams.cannonicalResample:  # Only applicable for TorchIO
            self.init_transforms += [tio.ToCanonical(), tio.Resample('gt')]
        if self.hparams.ds_mode == 0 and self.hparams.forceNormAffine:  # Only applicable for TorchIO
            self.init_transforms += [trans.ForceAffine()]
        if self.hparams.croppad and self.hparams.ds_mode == 1:
            self.init_transforms += [
                trans.CropOrPad(size=self.hparams.input_shape)]
        self.init_transforms += [trans.IntensityNorm(type=self.hparams.norm_type, return_meta=self.hparams.motion_return_meta)]
        # dataspace_transforms = self.dataspace.getTransforms() #TODO: dataspace transforms are not in use
        # self.init_transforms += dataspace_transforms
        if bool(self.hparams.random_crop) and self.hparams.ds_mode == 1:
            self.aug_transforms += [augs.RandomCrop(
                size=self.hparams.random_crop, p=self.hparams.p_random_crop)]
        if self.hparams.p_contrast_augment > 0:
            self.aug_transforms += [augs.getContrastAugs(
                p=self.hparams.p_contrast_augment)]
        # if the task is MoCo and pre-corrupted volumes are not supplied
        if self.hparams.taskID == 1 and not bool(self.hparams.train_path_inp):
            if self.hparams.motion_mode == 0 and self.hparams.ds_mode == 0:
                motion_params = {k.split('motionmg_')[
                    1]: v for k, v in self.hparams.items() if k.startswith('motionmg')}
                self.transforms += [tioMotion.RandomMotionGhostingFast(
                    **motion_params), trans.IntensityNorm()]
            elif self.hparams.motion_mode == 1 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv0(
                    sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
            elif self.hparams.motion_mode == 2 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv1(sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads,
                                                         restore_original=self.hparams.motion_restore_original, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
            else:
                sys.exit(
                    "Error: invalid motion_mode, ds_mode, is3D combo. Please double check!")

        self.static_metamat = sio.loadmat(self.hparams.static_metamat_file) if bool(
            self.hparams.static_metamat_file) else None
        if self.hparams.taskID == 0 and self.hparams.use_datacon:
            self.datacon = DataConsistency(
                isRadial=self.hparams.is_radial, metadict=self.static_metamat)
        else:
            self.datacon = None

        input_shape = self.hparams.input_shape if self.hparams.is3D else self.hparams.input_shape[
            :-1]
        self.example_input_array = torch.empty(
            self.hparams.batch_size, self.hparams.in_channels, *input_shape).float()
        self.saver = ResSaver(
            self.hparams.res_path, save_inp=self.hparams.save_inp, do_norm=self.hparams.do_savenorm)
Example 28
fpg.mri.affine  # .numpy()

# spatial shape, e.g. to decide on a CropOrPad target
fpg_ras.mri.spatial_shape

target_shape = 96, 96, 48
crop_pad = tio.CropOrPad(target_shape)
show_fpg(crop_pad(fpg_ras))

# Random affine: to simulate different positions and sizes of the patient within the scanner, we can use a RandomAffine transform.
# To improve visualization, we will use a 2D image and add a grid to it.


image = tio.ScalarImage('slice_7t.png')
spacing = image.spacing[0]
image = tio.Resample(spacing)(image)
print('Downloaded slice:', image)
to_pil(image)



tio.ScalarImage('slice_7t.png').spacing

#grid over the image
slice_grid = copy.deepcopy(image)
data = slice_grid.data
white = data.max()
N = 16
data[..., ::N, :, :] = white
data[..., :, ::N, :] = white
to_pil(slice_grid)
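
With the grid in place, the RandomAffine mentioned in the comment makes the deformation visible; a sketch consistent with the snippet above:

random_affine = tio.RandomAffine(scales=(0.9, 1.1), degrees=10)
to_pil(random_affine(slice_grid))
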
Example 29
    def get_frame(image, i):
        return image.data[..., i].permute(1, 2, 0).byte()

    plt.rcParams['animation.embed_limit'] = 25
    fig, ax = plt.subplots()
    im = ax.imshow(get_frame(image, 0))

    def _update_frame(i):
        # _update_frame was missing from this snippet; this is the standard updater pattern
        im.set_data(get_frame(image, i))
        return (im,)

    return animation.FuncAnimation(
        fig,
        _update_frame,
        repeat_delay=image['delay'],
        frames=image.shape[-1],
    )


# Source: https://thehigherlearning.wordpress.com/2014/06/25/watching-a-cell-divide-under-an-electron-microscope-is-mesmerizing-gif/  # noqa: E501
array, delay = read_clip('nBTu3oi.gif')
plt.imshow(array[..., 0].transpose(1, 2, 0))
plt.plot()
image = tio.ScalarImage(tensor=array, delay=delay)
original_animation = plot_gif(image)

transform = tio.Compose((
    tio.Resample((2, 2, 1)),
    tio.RandomAffine(degrees=(0, 0, 20)),
))

torch.manual_seed(0)
transformed = transform(image)
transformed_animation = plot_gif(transformed)
def patch_sampler(img_filenames,
                  labelmap_filenames,
                  patch_size,
                  sampler_type,
                  out_dir,
                  max_patches=None,
                  voxel_spacing=(),
                  patch_overlap=(0, 0, 0),
                  min_labeled_voxels=1.0,
                  label_prob=0.8,
                  save_patches=False,
                  batch_size=None,
                  prepare_batches=False,
                  inference=False):
    """Reshape a 3D volumes into a collection of 2D patches
    The resulting patches are allocated in a dedicated array.
    
    Parameters
    ----------
    img_filenames : list of strings  
        Paths to images to extract patches from 
    patch_size : tuple of ints (patch_x, patch_y, patch_z)
        The dimensions of one patch
    patch_overlap : tuple of ints (0, patch_x, patch_y)
        The maximum patch overlap between the patches 
    min_labeled_voxels is not None: : float between 0 and 1
        The minimum percentage of labeled pixels for a patch. If set to None patches are extracted based on center_voxel.
    labelmap_filenames : list of strings 
        Paths to labelmap
        
    Returns
    -------
    img_patches, label_patches : array, shape = (n_patches, patch_x, patch_y, patch_z, 1)
         The collection of patches extracted from the volumes, where `n_patches`
         is the total number of patches extracted.
    """

    if max_patches is not None:
        max_patches = int(max_patches / len(img_filenames))
    img_patches = []
    label_patches = []
    patch_counter = 0
    save_counter = 0
    img_ids = []
    label_ids = []
    save_size = 1
    if prepare_batches: save_size = batch_size
    print(f'\nExtracting patches from: {img_filenames}\n')
    for i in tqdm(range(len(img_filenames)), leave=False):
        if voxel_spacing:
            util.update_affine(img_filenames[i], labelmap_filenames[i])
        if labelmap_filenames:
            subject = tio.Subject(img=tio.Image(img_filenames[i],
                                                type=tio.INTENSITY),
                                  labelmap=tio.LabelMap(labelmap_filenames[i]))
        # Apply transformations
        #transform = tio.ZNormalization()
        #transformed = transform(subject)
        transform = tio.RescaleIntensity((0, 1))
        transformed = transform(subject)
        if voxel_spacing:
            transform = tio.Resample(voxel_spacing)
            transformed = transform(transformed)
        num_img_patches = 0
        if sampler_type == 'grid':
            sampler = tio.data.GridSampler(transformed, patch_size,
                                           patch_overlap)
            for patch in sampler:
                img_patch = np.array(patch.img.data)
                label_patch = np.array(patch.labelmap.data)
                labeled_voxels = torch.count_nonzero(
                    patch.labelmap.data) >= patch_size[0] * patch_size[
                        1] * patch_size[2] * min_labeled_voxels
                center = label_patch[0,
                                     int(patch_size[0] / 2),
                                     int(patch_size[1] / 2),
                                     int(patch_size[2] / 2)] != 0
                if labeled_voxels or center:
                    img_patches.append(img_patch)
                    label_patches.append(label_patch)
                    patch_counter += 1
                    num_img_patches += 1
                if save_patches:
                    img_patches, label_patches, img_ids, label_ids, save_counter, patch_counter = save(
                        img_patches, label_patches, img_ids, label_ids,
                        save_counter, patch_counter, save_size, patch_size,
                        inference, out_dir)
                # Check if max_patches for img
                if max_patches is not None:
                    if num_img_patches > max_patches:
                        break
        else:
            # Define sampler
            one_label = 1.0 - label_prob
            label_probabilities = {0: one_label, 1: label_prob}
            sampler = tio.data.LabelSampler(
                patch_size, label_probabilities=label_probabilities)
            if max_patches is None:
                generator = sampler(transformed)
            else:
                generator = sampler(transformed, max_patches)
            for patch in generator:
                img_patches.append(np.array(patch.img.data))
                label_patches.append(np.array(patch.labelmap.data))
                patch_counter += 1
                if save_patches:
                    img_patches, label_patches, img_ids, label_ids, save_counter, patch_counter = save(
                        img_patches, label_patches, img_ids, label_ids,
                        save_counter, patch_counter, save_size, patch_size,
                        inference, out_dir)
    print('Finished extracting patches.')
    if save_patches:
        return img_ids, label_ids
    else:
        if patch_size[0] == 1:
            return np.array(img_patches).reshape(
                len(img_patches), patch_size[1], patch_size[2],
                1), np.array(label_patches).reshape(len(label_patches),
                                                    patch_size[1],
                                                    patch_size[2], 1)
        else:
            # 3D patches: label patches get the same spatial shape as the image patches
            return np.array(img_patches).reshape(
                len(img_patches), patch_size[0], patch_size[1], patch_size[2],
                1), np.array(label_patches).reshape(len(label_patches),
                                                    patch_size[0],
                                                    patch_size[1],
                                                    patch_size[2], 1)
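
A hypothetical call, with file paths assumed, extracting 2D-style grid patches (patch_size[0] == 1 triggers the 2D reshape above):

img_patches, label_patches = patch_sampler(
    img_filenames=['volume_0.nii.gz'],
    labelmap_filenames=['labels_0.nii.gz'],
    patch_size=(1, 256, 256),
    sampler_type='grid',
    out_dir='patches/',
    max_patches=100,
)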