def __init__(
        self,
        datasets_dir,
        real_dataset_dir,
        resection_params,
        train_batch_size,
        num_workers,
        pseudo_dir=None,
        split_ratio=0.9,
        split_seed=42,
        debug_ratio=0.02,
        log=None,
        debug=False,
        augment=True,
        verbose=False,
        cache_validation_set=True,
        histogram_standardization=True,
    ):
        """Build public-data train/validation sets and the real-resection test set.

        Public subjects are split into training and validation subsets; the
        real resection dataset in ``real_dataset_dir`` becomes the test set.
        The validation set can be cached to disk, and every set can be shrunk
        for debugging.

        Args:
            datasets_dir: Root directory forwarded to the parent initializer.
            real_dataset_dir: Directory of the real resection (test) dataset.
            resection_params: Resection parameters; also used to name the
                validation cache.
            train_batch_size: Training batch size, forwarded to the parent.
            num_workers: Data-loading worker count, forwarded to the parent.
            pseudo_dir: Accepted but not used by this initializer.
            split_ratio: Fraction of public subjects assigned to training.
            split_seed: Seed for the train/validation split.
            debug_ratio: Fraction of each dataset kept when ``debug`` is true.
            log: Stored as ``self.log``.
            debug: Whether to reduce all three datasets with ``debug_ratio``.
            augment: Use the training transform (True) or the validation
                transform (False) for the train and validation sets.
            verbose: Whether to print the dataset lengths at the end.
            cache_validation_set: Whether to replace the validation set by its
                cached on-disk copy.
            histogram_standardization: Whether ``self.landmarks_path`` points at
                the bundled default landmarks (otherwise it is ``None``).
        """
        super().__init__(datasets_dir, train_batch_size, num_workers)
        self.resection_params = resection_params

        # The bundled landmarks were precomputed from 90% of the public training data.
        self.landmarks_path = (
            Path(__file__).parent / 'landmarks' / 'histogram_landmarks_default.npy'
            if histogram_standardization
            else None
        )

        train_public, val_public = self.split_subjects(
            self.get_public_subjects(), split_ratio, split_seed)

        if augment:
            transform = self.get_train_transform()
        else:
            transform = self.get_val_transform()
        # NOTE(review): validation shares the training transform; this looks
        # deliberate (validation needs simulated resections too) — confirm.
        self.train_dataset = tio.SubjectsDataset(train_public, transform=transform)
        self.val_dataset = tio.SubjectsDataset(val_public, transform=transform)
        if cache_validation_set:
            self.val_dataset = cache(
                self.val_dataset, resection_params, augment=augment)

        self.test_dataset = get_real_resection_dataset(
            real_dataset_dir,
            transform=get_test_transform(self.landmarks_path),
        )

        if debug:
            self.train_dataset, self.val_dataset, self.test_dataset = (
                reduce_dataset(full_dataset, debug_ratio)
                for full_dataset in (
                    self.train_dataset, self.val_dataset, self.test_dataset)
            )

        self.train_loader = self.get_train_loader(self.train_dataset)
        self.val_loader = self.get_val_loader(self.val_dataset)
        self.test_loader = self.get_val_loader(self.test_dataset)

        self.log = log

        if verbose:
            self.print_lengths()
Exemple #2
0
def get_torchio_dataset(inputs, targets, transform):
    """
    Create a ``torchio.SubjectsDataset`` from paired image and label paths.

    Arguments:
        * inputs (list): paths to MR images.
        * targets (list): paths to the ground-truth segmentations, in the same
          order as ``inputs``. Pairs are formed with ``zip``, so surplus entries
          of the longer list are ignored.
        * transform (False/None/torchio.transforms.Transform): transform applied
          to every subject; any falsy value means "no transform". Intensity
          transforms are not applied to the ``LABEL`` image.

    Output:
        * dataset (torchio.SubjectsDataset): one ``torchio.Subject`` per
          (image, label) pair, with keys ``'MRI'`` and ``'LABEL'``.
    """
    subjects = [
        torchio.Subject({
            'MRI': torchio.Image(image_path, torchio.INTENSITY),
            # Images of type LABEL are skipped by intensity transforms.
            'LABEL': torchio.Image(label_path, torchio.LABEL),
        })
        for image_path, label_path in zip(inputs, targets)
    ]
    # A falsy transform (e.g. False) is passed on as None, the dataset default.
    return torchio.SubjectsDataset(
        subjects, transform=transform if transform else None)
Exemple #3
0
def load_pretrain_datasets(data_shape, batch=3, workers=4, transform=None,
                           data_path='/home/mitch/Data/MSD/'):
    """Build one DataLoader per task directory found under ``data_path``.

    Every task directory must contain ``imagesTr/`` and ``labelsTr/``; images
    and labels are paired after sorting. All tasks share a preprocessing
    pipeline (clip intensities to [-80, 300], rescale to [0, 1], resize to
    ``data_shape``, round labels) followed by ``transform``.

    Args:
        data_shape: Spatial size passed to ``interpolate`` when resizing.
        batch: Batch size of every loader.
        workers: Number of worker processes of every loader.
        transform: Extra transform applied after preprocessing; defaults to a
            no-op ``RandomFlip(p=0.)``.
        data_path: Root directory containing one sub-directory per task.

    Returns:
        list of ``DataLoader``, in sorted task-directory order.
    """
    import os

    # Trailing '' keeps the final separator, so `directory + 'imagesTr/*'` works.
    directories = sorted(glob.glob(os.path.join(data_path, '*', '')))

    loaders = []  # one DataLoader per task

    if transform is None:
        # p=0 makes this a placeholder that never flips.
        transform = tio.RandomFlip(p=0.)
    # Preprocessing shared by every task.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    # Tasks with unusual shapes keep a single index of the last axis of the
    # intensity image before the shared pipeline.
    # NOTE(review): assumes sorted order puts the brain task at index 0 and the
    # prostate task at index 4 — confirm against the directory layout.
    braintransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 2], dim=0),
                            types_to_apply=[tio.INTENSITY])
    braintransform = tio.Compose([braintransform, transform])
    prostatetransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 1], dim=0),
                               types_to_apply=[tio.INTENSITY])
    prostatetransform = tio.Compose([prostatetransform, transform])

    for i, directory in enumerate(directories):
        images = sorted(glob.glob(directory + 'imagesTr/*'))
        segs = sorted(glob.glob(directory + 'labelsTr/*'))

        subject_list = [
            tio.Subject(img=tio.ScalarImage(image), label=tio.LabelMap(seg))
            for image, seg in zip(images, segs)
        ]

        # Handle the special cases.
        if i == 0:
            task_transform = braintransform
        elif i == 4:
            task_transform = prostatetransform
        else:
            task_transform = transform

        dataset = tio.SubjectsDataset(subject_list, transform=task_transform)
        loaders.append(
            DataLoader(dataset,
                       num_workers=workers,
                       batch_size=batch,
                       pin_memory=True))

    return loaders
Exemple #4
0
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    """Build the inference dataset for one image, optionally with TTA copies.

    Preprocessing: reorient to canonical, histogram-standardize with fixed
    landmarks, z-normalize with mean masking, resample when any voxel size
    differs from 1 by more than ``tolerance`` (or whenever an MNI transform is
    given, in which case the Colin27 T1 image is the resampling target), and
    crop so every spatial dimension is a multiple of 8.

    Args:
        input_path: Path to the input image.
        tta_iterations: Number of extra copies augmented with random flips and
            random affine transforms.
        interpolation: Image interpolation used for resampling and augmentation.
        tolerance: Allowed absolute deviation of each voxel size from 1.
        mni_transform_path: Optional path to a matrix stored under ``TO_MNI``.

    Returns:
        The non-augmented dataset if there are no TTA copies, otherwise a
        ``ConcatDataset`` of it followed by the augmented copies.
    """
    use_mni = mni_transform_path is not None
    if use_mni:
        affine = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: affine})
    else:
        image = tio.ScalarImage(input_path)
    subject = tio.Subject({IMAGE_NAME: image})

    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    histogram_standardization = tio.HistogramStandardization(
        {IMAGE_NAME: landmarks})
    steps = [
        tio.ToCanonical(),
        histogram_standardization,
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]

    voxel_sizes = np.array(nib.load(input_path).header.get_zooms())
    off_isotropic = np.any(np.abs(voxel_sizes - 1) > tolerance)
    if off_isotropic or use_mni:
        resample_kwargs = {'image_interpolation': interpolation}
        if use_mni:
            resample_kwargs['pre_affine_name'] = TO_MNI
            resample_kwargs['target'] = tio.datasets.Colin27().t1.path
        steps.append(tio.Resample(**resample_kwargs))
    steps.append(tio.EnsureShapeMultiple(8, method='crop'))

    preprocess_transform = tio.Compose(steps)
    plain_dataset = tio.SubjectsDataset([subject],
                                        transform=preprocess_transform)

    tta_subjects = tta_iterations * [subject]
    if not tta_subjects:
        return plain_dataset

    tta_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    tta_dataset = tio.SubjectsDataset(tta_subjects, transform=tta_transform)
    return torch.utils.data.ConcatDataset((plain_dataset, tta_dataset))
Exemple #5
0
def load_kidney_seg(data_shape, batch=3, workers=4, transform=None,
                    num_cases=210, data_dir='data'):
    """Build a DataLoader over the kidney segmentation cases.

    Case ``i`` is read from ``{data_dir}/case_{i:05d}/imaging.nii.gz`` and
    ``{data_dir}/case_{i:05d}/segmentation.nii.gz`` for ``i`` in
    ``range(num_cases)``. Images are clipped to [-80, 300], rescaled to
    [0, 1], resized to ``data_shape``, labels are rounded, and then
    ``transform`` is applied.

    Args:
        data_shape: Spatial size passed to ``interpolate`` when resizing.
        batch: Loader batch size.
        workers: Number of loader worker processes.
        transform: Extra transform applied after preprocessing; defaults to a
            no-op ``RandomFlip(p=0.)``.
        num_cases: Number of cases to load (210 by default).
        data_dir: Root directory of the cases (``'data'`` by default).

    Returns:
        DataLoader over a ``tio.SubjectsDataset`` with ``img`` and ``label``.
    """
    # Take the input transform and apply it after clip, normalization, resize.
    if transform is None:
        # p=0 makes this a placeholder that never flips.
        transform = tio.RandomFlip(p=0.)
    # Preprocessing shared by every case.
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    subject_list = [
        tio.Subject(
            img=tio.ScalarImage(f'{data_dir}/case_{i:05d}/imaging.nii.gz'),
            label=tio.LabelMap(f'{data_dir}/case_{i:05d}/segmentation.nii.gz'),
        )
        for i in range(num_cases)
    ]
    dataset = tio.SubjectsDataset(subject_list, transform=transform)
    return DataLoader(dataset,
                      num_workers=workers,
                      batch_size=batch,
                      pin_memory=True)
def cache(dataset, resection_params, augment=True, caches_dir='/tmp/val_set_cache', num_workers=12):
    """Write every subject of ``dataset`` to disk once and reload them untransformed.

    The cache directory name encodes ``wm_lesion_p``, ``clot_p``, ``shape`` and
    ``texture`` from ``resection_params`` plus whether augmentation was
    enabled, so different configurations do not share a cache. An existing
    cache directory is reused without recomputation.

    Args:
        dataset: Dataset whose items are subjects with ``image`` and ``label``.
        resection_params: Mapping providing the four keys listed above.
        augment: Only affects the cache directory name.
        caches_dir: Parent directory of all caches.
        num_workers: Worker processes used while generating the cache.

    Returns:
        tio.SubjectsDataset (without transform) read from the cached files.
    """
    import shutil

    caches_dir = Path(caches_dir)
    wm_lesion_p = resection_params['wm_lesion_p']
    clot_p = resection_params['clot_p']
    shape = resection_params['shape']
    texture = resection_params['texture']
    augment_string = '_no_augmentation' if not augment else ''
    dir_name = f'wm_{wm_lesion_p}_clot_{clot_p}_{shape}_{texture}{augment_string}'
    cache_dir = caches_dir / dir_name
    image_dir = cache_dir / 'image'
    label_dir = cache_dir / 'label'
    if not cache_dir.is_dir():
        print('Caching validation set')
        # Fill a temporary sibling and rename it only after every subject has
        # been written. An interrupted run would otherwise leave a partial
        # cache_dir that later calls would accept as complete.
        partial_dir = caches_dir / f'{dir_name}.partial'
        if partial_dir.exists():
            shutil.rmtree(partial_dir)
        partial_image_dir = partial_dir / 'image'
        partial_label_dir = partial_dir / 'label'
        partial_image_dir.mkdir(parents=True)
        partial_label_dir.mkdir(parents=True)
        loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            collate_fn=lambda x: x[0],  # default batch size is 1: unwrap it
        )
        for subject in tqdm(loader):
            # The label was created rather than loaded, so it has no path of
            # its own; both files are named after the image.
            filename = subject.image.path.name
            subject.image.save(partial_image_dir / filename)
            subject.label.save(partial_label_dir / filename)
        partial_dir.rename(cache_dir)

    subjects = []
    for im_path, label_path in zip(sglob(image_dir), sglob(label_dir)):
        subject = tio.Subject(
            image=tio.ScalarImage(im_path),
            label=tio.LabelMap(label_path),
        )
        subjects.append(subject)
    return tio.SubjectsDataset(subjects)
Exemple #7
0
def get_subjects(path, structures, transform):
    """
    Browse the ``path`` folder to build a dataset. Each entry of ``path`` must be
    a subject folder containing ``ct.nii`` and one mask ``<structure>.nii`` per
    structure.

    Structure ``structures[i]`` is encoded as label ``i + 1`` in a single
    ``label_map`` image (any non-zero mask voxel counts as inside). Voxels
    claimed by more than one structure are set to 0 (background).

    :param path: root folder.
    :type path: str
    :param structures: list of structures.
    :type structures: list[str]
    :param transform: transforms to be applied.
    :type transform: :class:`tio.transforms.Transform`
    :return: Base TorchIO dataset, with subjects in sorted folder-name order.
    :rtype: :class:`tio.SubjectsDataset`
    """
    subjects = []
    # Sorted so the subject order does not depend on the filesystem.
    for subject_id in sorted(os.listdir(path)):
        subject_dir = os.path.join(path, subject_id)
        subject = tio.Subject(
            ct=tio.ScalarImage(os.path.join(subject_dir, 'ct.nii')),
        )
        shape = subject["ct"].shape
        label_map = torch.zeros(shape, dtype=torch.long)
        # Number of structures covering each voxel, used to detect overlaps.
        coverage = torch.zeros(shape, dtype=torch.long)
        for label, structure in enumerate(structures, start=1):
            mask = tio.LabelMap(os.path.join(subject_dir, structure + '.nii')).data > 0
            # Assign the label directly instead of summing mask * label: a sum
            # turns some overlaps into another structure's label (1 + 2 == 3).
            label_map[mask] = label
            coverage += mask.long()
        label_map[coverage > 1] = 0

        subject.add_image(tio.LabelMap(tensor=label_map, affine=subject["ct"].affine), 'label_map')
        subjects.append(subject)

    return tio.SubjectsDataset(subjects, transform=transform)
    def __init__(
        self,
        fold,
        num_folds,
        datasets_dir,
        dataset_name,
        train_batch_size,
        num_workers,
        use_public_landmarks=False,
        pseudo_dirname=None,
        split_seed=42,
        log=None,
        verbose=True,
    ):
        """Set up one cross-validation fold over a real resection dataset.

        The subjects in ``self.datasets_dir / 'real' / dataset_name`` are split
        into training and validation subjects with ``split_subjects``. The
        validation loader is also used as the test loader.

        Args:
            fold: Fold index passed to ``split_subjects``.
            num_folds: Total number of folds passed to ``split_subjects``.
            datasets_dir: Root directory forwarded to the parent initializer.
            dataset_name: Sub-directory of ``self.datasets_dir / 'real'``.
            train_batch_size: Training batch size, forwarded to the parent.
            num_workers: Data-loading worker count, forwarded to the parent.
            use_public_landmarks: If true, use ``get_landmarks_path()``;
                otherwise compute landmarks from this fold's training dataset.
            pseudo_dirname: Optional sub-directory of ``self.datasets_dir / 'real'``
                whose subjects are appended to the training set.
            split_seed: Seed passed to ``split_subjects``.
            log: Stored as ``self.log``.
            verbose: Whether to print dataset lengths (test set excluded).
        """
        super().__init__(datasets_dir, train_batch_size, num_workers)
        # Real resections are used as-is: no resection parameters here, and the
        # training transform below is built with resect=False.
        self.resection_params = None
        real_dataset_dir = self.datasets_dir / 'real' / dataset_name
        real_subjects = get_real_resection_subjects(real_dataset_dir)
        train_subjects, val_subjects = self.split_subjects(
            real_subjects, fold, num_folds, split_seed)
        # Created without a transform and passed to get_landmarks_path before
        # set_transform; presumably so the landmarks come from untransformed
        # training images — confirm before reordering.
        self.train_dataset = tio.SubjectsDataset(train_subjects)
        if use_public_landmarks:
            self.landmarks_path = get_landmarks_path()
        else:
            self.landmarks_path = get_landmarks_path(
                dataset=self.train_dataset)
        train_transform = self.get_train_transform(resect=False)
        self.train_dataset.set_transform(train_transform)
        test_transform = get_test_transform(self.landmarks_path)
        self.val_dataset = tio.SubjectsDataset(val_subjects,
                                               transform=test_transform)

        # Optional pseudo-labelled subjects, with the same training transform.
        if pseudo_dirname is not None:
            pseudo_dir = self.datasets_dir / 'real' / pseudo_dirname
            pseudo_dataset = get_real_resection_dataset(
                pseudo_dir, transform=train_transform)
            self.train_dataset = torch.utils.data.ConcatDataset(
                (self.train_dataset, pseudo_dataset))

        self.train_loader = self.get_train_loader(self.train_dataset)
        # The same loader object serves as both validation and test loader.
        self.val_loader = self.test_loader = self.get_val_loader(
            self.val_dataset)

        self.log = log
        if verbose:
            self.print_lengths(test=False)
Exemple #9
0
 def test_from_batch(self):
     """A subject survives a DataLoader batch round trip with identical t1 data."""
     subjects_dataset = tio.SubjectsDataset([self.sample_subject])
     first_batch = tio.utils.get_first_item(DataLoader(subjects_dataset))
     rebuilt_dataset = tio.SubjectsDataset.from_batch(first_batch)
     original_t1 = subjects_dataset[0].t1.data
     rebuilt_t1 = rebuilt_dataset[0].t1.data
     self.assertTensorEqual(original_t1, rebuilt_t1)
Exemple #10
0
 def test_label_probabilities(self):
     """Each label's probability is split evenly across that label's voxels."""
     labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, 1, -1)
     label_image = torchio.Image(tensor=labels, type=torchio.LABEL)
     sample = torchio.SubjectsDataset([torchio.Subject(label=label_image)])[0]
     weights = {0: 0, 1: 50, 2: 25, 3: 25}
     sampler = LabelSampler(5, 'label', label_probabilities=weights)
     probabilities = sampler.get_probability_map(sample)
     # Label 1's 50% is shared by its three voxels (2/12 each); label 2's 25%
     # goes to its single voxel (3/12); label 3 has no voxels.
     expected = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
     assert torch.all(probabilities.squeeze().eq(expected))
Exemple #11
0
 def get_sample(self, image_shape):
     """Return a subject with a random ``t1`` and a ``prob`` map equal to 1 at one voxel."""
     t1 = torch.rand(*image_shape)
     prob = torch.zeros_like(t1)
     prob[0, 3, 3, 3] = 1  # the only non-zero voxel: (3, 3, 3) of channel 0
     images = {
         't1': torchio.ScalarImage(tensor=t1),
         'prob': torchio.ScalarImage(tensor=prob),
     }
     return torchio.SubjectsDataset([torchio.Subject(**images)])[0]
Exemple #12
0
def createTIODynDS(path_gt,
                   path_corrupt,
                   is_infer=False,
                   p=1,
                   transforms=None,
                   **kwargs):
    """Build a dataset of consecutive time-point pairs for dynamic MRI.

    For every input file name, each time point is paired with the one before
    it (in sorted order). A subject holds ground-truth and input images for
    both time points (``gt_tp_prev``, ``inp_tp_prev``, ``gt``, ``inp``).
    Pairs whose ground truth is missing are skipped with a warning.

    Args:
        path_gt: Directory searched recursively for ground-truth ``.nii`` /
            ``.nii.gz`` files.
        path_corrupt: Directory of pre-corrupted inputs; if falsy, the ground
            truth files are used as inputs and corrupted on the fly with
            ``MotionCorrupter(**kwargs)`` with probability ``p``.
        is_infer: Accepted but not used.
        p: Probability of applying the on-the-fly corruption.
        transforms: Optional iterable of transforms applied first. It is
            copied, never mutated.
        **kwargs: Forwarded to ``MotionCorrupter``.

    Returns:
        tio.SubjectsDataset whose transform ends with ``ProcessTIOSubsTPs()``.
    """
    # Copy so that neither a caller's list nor a shared default accumulates
    # the transforms appended below across calls.
    transforms = list(transforms) if transforms is not None else []
    files_gt = glob(path_gt + "/**/*.nii", recursive=True) + glob(
        path_gt + "/**/*.nii.gz", recursive=True)
    if path_corrupt:
        files_inp = glob(path_corrupt + "/**/*.nii", recursive=True) + glob(
            path_corrupt + "/**/*.nii.gz", recursive=True)
        corruptFly = False
    else:
        files_inp = files_gt.copy()
        corruptFly = True
    subjects = []

    inp_dicts, files_inp = __process_TPs(files_inp)
    gt_dicts, _ = __process_TPs(files_gt)
    for filename in files_inp:
        inp_files = [d for d in inp_dicts if filename in d['filename']]
        gt_files = [d for d in gt_dicts if filename in d['filename']]
        # Sorted so each pair is (previous, current); a set has no defined
        # iteration order. Assumes time-point values are mutually comparable.
        tps = sorted(set(dic["tp"] for dic in inp_files))
        tp_prev = tps.pop(0)
        for tp in tps:
            inp_tp_prev = [d for d in inp_files if tp_prev == d['tp']]
            gt_tp_prev = [d for d in gt_files if tp_prev == d['tp']]
            inp_tp = [d for d in inp_files if tp == d['tp']]
            gt_tp = [d for d in gt_files if tp == d['tp']]
            tp_prev = tp
            if gt_tp_prev and gt_tp:
                subjects.append(
                    tio.Subject(
                        gt_tp_prev=tio.ScalarImage(gt_tp_prev[0]['path']),
                        inp_tp_prev=tio.ScalarImage(inp_tp_prev[0]['path']),
                        gt=tio.ScalarImage(gt_tp[0]['path']),
                        inp=tio.ScalarImage(inp_tp[0]['path']),
                        filename=filename,
                        tp=tp,
                        tag="CorruptNGT",
                    ))
            else:
                print(
                    "Warning: Not Implemented if GT is missing. Skipping Sub-TP."
                )

    if corruptFly:
        moco = MotionCorrupter(**kwargs)
        transforms.append(tio.Lambda(moco.perform, p=p))
    transforms.append(ProcessTIOSubsTPs())
    transform = tio.Compose(transforms)
    subjects_dataset = tio.SubjectsDataset(subjects, transform=transform)
    return subjects_dataset
Exemple #13
0
def calculate_ssim(global_list):
    """Return the SSIM of every subject, grouped as in ``global_list``.

    Each group of subjects is wrapped in a ``tio.SubjectsDataset`` using the
    module-level ``test_transform``. Every transformed subject goes through
    ``test_network``, whose fourth (last) return value is the SSIM.
    """
    def ssim_for_group(subject_group):
        group_dataset = tio.SubjectsDataset(subject_group,
                                            transform=test_transform)
        return [ssim for _, _, _, ssim in map(test_network, tqdm(group_dataset))]

    return [ssim_for_group(group) for group in global_list]
Exemple #14
0
    def __init__(self, images_dir, labels_dir):
        """Collect image/label subjects and build ``self.training_set``.

        If ``hp.in_class`` and ``hp.out_class`` are both 1, each image in
        ``images_dir`` is paired with one label map from ``labels_dir``.
        Otherwise each image is paired with four label maps taken from the
        ``artery``, ``lung``, ``trachea`` and ``vein`` sub-directories of
        ``labels_dir``. Files are matched with the ``hp.fold_arch`` glob pattern
        and paired in sorted order.

        Args:
            images_dir: Directory of the images (str or path-like).
            labels_dir: Directory of the labels (str or path-like).
        """
        self.subjects = []

        images_dir = Path(images_dir)
        # Path joining also works for Path inputs, unlike `labels_dir + '/x'`.
        labels_dir = Path(labels_dir)
        self.image_paths = sorted(images_dir.glob(hp.fold_arch))

        if (hp.in_class == 1) and (hp.out_class == 1):
            self.label_paths = sorted(labels_dir.glob(hp.fold_arch))

            for image_path, label_path in zip(self.image_paths,
                                              self.label_paths):
                subject = tio.Subject(
                    source=tio.ScalarImage(image_path),
                    label=tio.LabelMap(label_path),
                )
                self.subjects.append(subject)
        else:
            self.artery_label_paths = sorted(
                (labels_dir / 'artery').glob(hp.fold_arch))
            self.lung_label_paths = sorted(
                (labels_dir / 'lung').glob(hp.fold_arch))
            self.trachea_label_paths = sorted(
                (labels_dir / 'trachea').glob(hp.fold_arch))
            self.vein_label_paths = sorted(
                (labels_dir / 'vein').glob(hp.fold_arch))

            for (image_path, artery_label_path, lung_label_path,
                 trachea_label_path, vein_label_path) in zip(
                     self.image_paths, self.artery_label_paths,
                     self.lung_label_paths, self.trachea_label_paths,
                     self.vein_label_paths):
                subject = tio.Subject(
                    source=tio.ScalarImage(image_path),
                    # NOTE(review): the key is spelled 'atery'; kept because
                    # downstream code may index it by this name — confirm.
                    atery=tio.LabelMap(artery_label_path),
                    lung=tio.LabelMap(lung_label_path),
                    trachea=tio.LabelMap(trachea_label_path),
                    vein=tio.LabelMap(vein_label_path),
                )
                self.subjects.append(subject)

        self.transforms = self.transform()

        self.training_set = tio.SubjectsDataset(self.subjects,
                                                transform=self.transforms)
Exemple #15
0
def training_network(landmarks, dataset, subjects, training_split_ratio=0.9):
    """Split ``subjects`` into training and validation ``tio.SubjectsDataset``s.

    The first ``int(training_split_ratio * len(subjects))`` subjects, in the
    given order (no shuffling), form the training set and the rest form the
    validation set. Both sets are reoriented to canonical, resampled to 4 mm,
    cropped/padded to (48, 60, 48), histogram-standardized with ``landmarks``
    and z-normalized. The training set also gets random motion, bias field,
    noise, flips and an affine (p=0.8) or elastic (p=0.2) deformation.

    Args:
        landmarks: Histogram-standardization landmarks for the ``'mri'`` image.
        dataset: Kept for backward compatibility. The split is now computed
            from ``subjects`` so it always matches the list being sliced.
        subjects: Subjects to split.
        training_split_ratio: Fraction of subjects used for training.

    Returns:
        tuple: ``(training_set, validation_set)``.
    """
    training_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    validation_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    # Derive the split size from the list actually being sliced.
    num_training_subjects = int(training_split_ratio * len(subjects))

    training_subjects = subjects[:num_training_subjects]
    validation_subjects = subjects[num_training_subjects:]

    training_set = tio.SubjectsDataset(training_subjects,
                                       transform=training_transform)

    validation_set = tio.SubjectsDataset(validation_subjects,
                                         transform=validation_transform)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set
    def __init__(self, imgs_path):
        """Wrap every image listed under ``imgs_path`` and build ``self.test_set``.

        The listing from ``get_listdir`` is sorted in place and stored as
        ``self.img_list``. Each path becomes a subject with a single
        ``source`` image, and the dataset uses ``self.transform()``.
        """
        listing = get_listdir(imgs_path)
        listing.sort()
        self.img_list = listing
        self.subjects = [
            torchio.Subject(source=torchio.ScalarImage(image_path))
            for image_path in self.img_list
        ]
        self.transforms = self.transform()

        self.test_set = torchio.SubjectsDataset(self.subjects, transform=self.transforms)
Exemple #17
0
def create_trainDS(path, p=1, **kwargs):
    """Build a training dataset whose images are motion-corrupted on the fly.

    All ``.nii`` files, then all ``.nii.gz`` files, found recursively below
    ``path`` become subjects with an ``im`` image and their base ``filename``.
    The dataset applies ``MotionCorrupter(**kwargs).perform`` with
    probability ``p``.
    """
    nifti_files = [
        found
        for pattern in ("/**/*.nii", "/**/*.nii.gz")
        for found in glob(path + pattern, recursive=True)
    ]
    subjects = [
        tio.Subject(
            im=tio.ScalarImage(nifti),
            filename=os.path.basename(nifti),
        )
        for nifti in nifti_files
    ]
    corrupter = MotionCorrupter(**kwargs)
    transform = tio.Compose([tio.Lambda(corrupter.perform, p=p)])
    return tio.SubjectsDataset(subjects, transform=transform)
Exemple #18
0
def create_trainDS_precorrupt(path_gt, path_corrupt, p=1, norm_mode=0):
    """Build a training dataset that pairs ground truth with pre-corrupted files.

    All ``.nii`` files, then all ``.nii.gz`` files, found recursively below
    ``path_gt`` become subjects with an ``im`` image and their base
    ``filename``. The only transform is ``ReadCorrupted``, configured with
    ``path_corrupt``, ``p`` and ``norm_mode``.
    """
    gt_files = [
        found
        for pattern in ("/**/*.nii", "/**/*.nii.gz")
        for found in glob(path_gt + pattern, recursive=True)
    ]
    subjects = [
        tio.Subject(
            im=tio.ScalarImage(gt_file),
            filename=os.path.basename(gt_file),
        )
        for gt_file in gt_files
    ]
    reader = ReadCorrupted(path_corrupt=path_corrupt, p=p, norm_mode=norm_mode)
    return tio.SubjectsDataset(subjects, transform=tio.Compose([reader]))
Exemple #19
0
    def setUp(self):
        """Create the temp directory, seed the RNGs and build the shared subjects.

        Five subjects are built from images returned by ``self.get_image_path``:
        ``subject_a`` (t1), ``subject_a4`` (the ``t1_a`` image with
        ``components=2``), ``subject_b`` (t1 + binary label), ``subject_c``
        (binary label only) and ``subject_d`` (t1 with
        ``pre_affine=registration_matrix``, t2 and a binary label). They are
        wrapped in ``self.dataset``, and ``self.sample_subject`` is its last
        item (``subject_d``).
        """
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        # Seed both RNGs before any fixture is created.
        # NOTE(review): presumably so the images produced by get_image_path are
        # reproducible — confirm before reordering the calls below.
        random.seed(42)
        np.random.seed(42)

        # Homogeneous matrix: translation of 10 along the first axis and a
        # scaling of 1.2 along the third.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        # subject_d must stay last: sample_subject below takes index -1.
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d
Exemple #20
0
def SubjectsDataset():
    """Build a ``tio.SubjectsDataset`` pairing MRIs with brain label maps.

    Sorted ``*.nii.gz`` files are read from the ``image`` and ``label``
    sub-directories of the module-level ``dataset_dir`` and paired in order.

    Returns:
        tuple: ``(dataset, subjects)``, the dataset and the list it wraps.
    """
    image_paths, label_paths = (
        sorted((dataset_dir / folder).glob('*.nii.gz'))
        for folder in ('image', 'label')
    )
    assert len(image_paths) == len(label_paths)

    subjects = [
        tio.Subject(
            mri=tio.ScalarImage(image_path),
            brain=tio.LabelMap(label_path),
        )
        for image_path, label_path in zip(image_paths, label_paths)
    ]
    dataset = tio.SubjectsDataset(subjects)
    # e.g. "Dataset size: 566 subjects"
    print('Dataset size:', len(dataset), 'subjects')
    return dataset, subjects
Exemple #21
0
 def test_batch_history(self):
     """Inverting the history kept in a collated batch restores every image shape."""
     # https://github.com/fepegar/torchio/discussions/743
     subject = self.sample_subject
     pipeline = tio.Compose([
         tio.RandomAffine(),
         tio.CropOrPad(5),
         tio.OneHot(),
     ])
     loader = torch.utils.data.DataLoader(
         tio.SubjectsDataset([subject], transform=pipeline),
         collate_fn=tio.utils.history_collate,
     )
     batch = tio.utils.get_first_item(loader)
     transformed: tio.Subject = tio.utils.get_subjects_from_batch(batch)[0]
     restored = transformed.apply_inverse_transform()
     original_images = subject.get_images(intensity_only=False)
     restored_images = restored.get_images(intensity_only=False)
     for original_image, restored_image in zip(original_images, restored_images):
         assert original_image.shape == restored_image.shape
    def __init__(self, path, images=None, labels=None, transforms=None):
        """Build ``self.subject_dataset`` from one sub-folder per subject.

        Each sub-folder of ``path`` may contain an ``attributes.json`` whose
        contents seed the subject's fields. Other files are keyed by their name
        up to the first ``.``; names in ``images`` are loaded as
        ``tio.ScalarImage`` and names in ``labels`` as ``tio.LabelMap``.
        Sub-folders missing any requested name, and entries of ``path`` that
        are not directories, are skipped.

        Args:
            path: Root directory containing one folder per subject.
            images: Names of intensity images to load, or None.
            labels: Names of label maps to load, or None.
            transforms: Transform given to the resulting ``tio.SubjectsDataset``.
        """
        self.transforms = transforms
        self.subjects = []

        self.images = images
        self.labels = labels

        image_names = images if images is not None else ()
        label_names = labels if labels is not None else ()
        requested_names = [*image_names, *label_names]

        self.subject_folder_names = os.listdir(path)
        self.subject_folders = [f"{path}/{folder}/" for folder in self.subject_folder_names]
        for subject_folder in self.subject_folders:
            # Stray files next to the subject folders cannot hold a subject,
            # and os.listdir would raise NotADirectoryError on them.
            if not os.path.isdir(subject_folder):
                continue
            subject_files = os.listdir(subject_folder)
            subject_data = {}

            attributes_file = "attributes.json"
            if attributes_file in subject_files:
                with open(f"{subject_folder}/{attributes_file}", encoding="utf-8") as f:
                    subject_data = json.load(f)
                subject_files.remove(attributes_file)

            # Key each file by its name up to the first '.'. Slicing with
            # str.find would chop the last character off names without a dot,
            # because find returns -1 for them.
            file_map = {file.split(".", 1)[0]: file for file in subject_files}

            if any(name not in file_map for name in requested_names):
                continue
            for name in image_names:
                subject_data[name] = tio.ScalarImage(subject_folder + file_map[name])
            for name in label_names:
                subject_data[name] = tio.LabelMap(subject_folder + file_map[name])

            self.subjects.append(tio.Subject(**subject_data))
        self.subject_dataset = tio.SubjectsDataset(self.subjects, transform=transforms)
def main(image_dir, label_dir, checkpoint_path, output_dir, landmarks_path,
         df_path, batch_size, num_workers, multi_gpu):
    """Run a trained U-Net on labelled images and report the median Dice.

    Images from ``image_dir`` are paired in order with labels from
    ``label_dir``, preprocessed with ``datasets.get_test_transform`` and passed
    to ``engine.Evaluator().infer`` together with ``output_dir``. The returned
    table is saved to ``df_path``, and the median and IQR of its ``Dice``
    column (times 100) are printed.

    Raises:
        FileExistsError: if ``output_dir`` already exists (``mkdir`` is called
            without ``exist_ok``).

    Returns:
        0
    """
    import torch
    import torchio as tio
    import models
    import datasets
    import engine
    import utils

    fps = get_paths(image_dir)
    lfps = get_paths(label_dir)
    assert len(fps) == len(lfps)
    # key must be 'image' as in get_test_transform
    subjects = [
        tio.Subject(image=tio.ScalarImage(fp), label=tio.LabelMap(lfp))
        for (fp, lfp) in zip(fps, lfps)
    ]
    transform = datasets.get_test_transform(landmarks_path)
    dataset = tio.SubjectsDataset(subjects, transform)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Map tensors onto the chosen device so a checkpoint saved on GPU can
    # still be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = models.get_unet().to(device)
    if multi_gpu:
        model = torch.nn.DataParallel(model)
        # DataParallel wraps the network; the weights belong to .module.
        model.module.load_state_dict(checkpoint['model'])
    else:
        model.load_state_dict(checkpoint['model'])
    output_dir = Path(output_dir)
    model.eval()
    torch.set_grad_enabled(False)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         num_workers=num_workers)
    output_dir.mkdir(parents=True)
    evaluator = engine.Evaluator()
    df = evaluator.infer(model, loader, output_dir)
    df.to_csv(df_path)
    # NOTE(review): assumes get_median_iqr returns an array-like; multiplying
    # a plain tuple by 100 would repeat it instead of scaling — confirm.
    med, iqr = 100 * utils.get_median_iqr(df.Dice)
    print(f'{med:.1f} ({iqr:.1f})')
    return 0
def get_pseudo_loader(
    threshold,
    percentile,
    metric,
    summary_path,
    dataset_name,
    num_workers,
    batch_size=2,
    remove_zero_volume=False,
    datasets_dir='/home/fernando/datasets/real/',
):
    """Build a shuffled training DataLoader over pseudo-labelled subjects.

    The subject IDs returned by ``get_certain_subjects`` are looked up in the
    ``image`` and ``label`` sub-directories of ``datasets_dir / dataset_name``.
    Each is matched to the first (sorted) file named ``<id>_*``.

    Args:
        threshold, percentile, metric, summary_path, remove_zero_volume:
            Forwarded to ``get_certain_subjects``.
        dataset_name: Sub-directory of ``datasets_dir``.
        num_workers: Loader worker processes.
        batch_size: Loader batch size.
        datasets_dir: Root directory of the real datasets.

    Raises:
        FileNotFoundError: if the dataset directory or a subject's image or
            label file does not exist.

    Returns:
        torch.utils.data.DataLoader using ``get_train_transform``.
    """
    subject_ids = get_certain_subjects(
        threshold,
        percentile,
        metric,
        summary_path,
        remove_zero_volume=remove_zero_volume,
    )
    dataset_dir = Path(datasets_dir) / dataset_name
    if not dataset_dir.is_dir():
        raise FileNotFoundError(f'Dataset directory not found: {dataset_dir}')
    image_dir = dataset_dir / 'image'
    label_dir = dataset_dir / 'label'

    def first_match(directory, subject_id):
        # Sorted so the choice is deterministic when several files match.
        matches = sorted(directory.glob(f'{subject_id}_*'))
        if not matches:
            raise FileNotFoundError(
                f'No file for subject {subject_id!r} in {directory}')
        return matches[0]

    subjects = [
        tio.Subject(
            image=tio.ScalarImage(first_match(image_dir, subject_id)),
            label=tio.LabelMap(first_match(label_dir, subject_id)),
        )
        for subject_id in subject_ids
    ]
    transform = get_train_transform(get_landmarks_path())
    dataset = tio.SubjectsDataset(subjects, transform=transform)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=num_workers,
    )
    return loader
Exemple #25
0
    def test_dataset(self, modalities):
        """Load generated nrrd subjects and check that they batch with the expected shapes.

        ``modalities`` is a comma-separated string that selects the folder
        ``temp_folder_<mod1>_<mod2>...`` under ``self.output_path``. After
        resampling, cropping/padding to (96, 96, 40), selecting the ``larynx``
        ROI and one-hot encoding, every item of a batch of 2 must have shape
        (2, 1, 96, 96, 40), or (2, 2, 96, 96, 40) for RTSTRUCT items.
        """
        folder_name = "temp_folder_" + "_".join(modalities.split(","))
        output_path_mod = pathlib.Path(self.output_path, folder_name).as_posix()
        comp_path = pathlib.Path(output_path_mod).resolve().joinpath('dataset.csv').as_posix()
        comp_table = pd.read_csv(comp_path, index_col=0)
        print(comp_path, comp_table)

        # Load the subjects from the nrrd files in the output folder.
        subjects_nrrd = Dataset.load_image(output_path_mod, ignore_multi=True)

        transforms = tio.Compose([
            tio.Resample(4),
            tio.CropOrPad((96, 96, 40)),
            select_roi_names(["larynx"]),
            tio.OneHot(),
        ])

        test_set = tio.SubjectsDataset(subjects_nrrd, transform=transforms)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=2, shuffle=True, collate_fn=collate_fn)

        # The dataset has 2 components, hence 2 subjects.
        assert len(test_set) == 2

        # Fetching a batch fails if the subjects or the transforms are malformed.
        data = next(iter(test_loader))
        shapes_ok = [
            tensor.shape == ((2, 2, 96, 96, 40) if "RTSTRUCT" in key else (2, 1, 96, 96, 40))
            for key, tensor in data.items()
        ]
        assert all(shapes_ok), "There is some problem in the transformation/the formation of subject object"
Exemple #26
0
def main(input_path, checkpoint_path, output_dir, landmarks_path, batch_size, num_workers, resample):
    """Segment every image returned by ``get_paths(input_path)`` and save binary masks.

    Args:
        input_path: Passed to ``get_paths`` to collect the input image files.
        checkpoint_path: Checkpoint file whose ``'model'`` entry holds the
            state dict for ``models.get_unet()``.
        output_dir: Directory where masks are written; created if missing.
        landmarks_path: Forwarded to ``datasets.get_test_transform``.
        batch_size: Inference batch size.
        num_workers: Number of ``DataLoader`` worker processes.
        resample: If true, ``tio.Resample()`` is applied before the other
            preprocessing transforms.

    Returns:
        0.
    """
    import torch
    from tqdm import tqdm
    import torchio as tio
    import models
    import datasets

    fps = get_paths(input_path)
    # The key must be 'image', as in get_test_transform.
    subjects = [tio.Subject(image=tio.ScalarImage(fp)) for fp in fps]
    transform = tio.Compose((
        tio.ToCanonical(),
        datasets.get_test_transform(landmarks_path),
    ))
    if resample:
        transform = tio.Compose((
            tio.Resample(),
            transform,
            # tio.CropOrPad((264, 268, 144)),  # for BITE?
        ))
    dataset = tio.SubjectsDataset(subjects, transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = models.get_unet().to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    # Scoped no_grad rather than torch.set_grad_enabled(False), so autograd is
    # not left globally disabled once this function returns.
    with torch.no_grad():
        for batch in tqdm(loader):
            inputs = batch['image'][tio.DATA].float().to(device)
            # Softmax over channels, drop channel 0, threshold the rest at 0.5.
            seg = model(inputs).softmax(dim=1)[:, 1:].cpu() > 0.5
            for tensor, affine, path in zip(seg, batch['image'][tio.AFFINE], batch['image'][tio.PATH]):
                image = tio.LabelMap(tensor=tensor, affine=affine.numpy())
                path = Path(path)
                out_path = output_dir / path.name.replace('.nii', '_seg_cnn.nii')
                image.save(out_path)
    return 0
Exemple #27
0
def createTIODS(path_gt,
                path_corrupt,
                is_infer=False,
                p=1,
                transforms=None,
                **kwargs):
    """Build a ``tio.SubjectsDataset`` pairing ground-truth and input volumes.

    Every ``.nii`` / ``.nii.gz`` file under ``path_gt`` (recursively) is a
    ground-truth candidate. If ``path_corrupt`` is given, its ``.nii`` /
    ``.nii.gz`` files are the inputs. Otherwise the ground-truth files are
    reused as inputs, and a ``MotionCorrupter(**kwargs)`` is appended as a
    ``tio.Lambda`` applied with probability ``p``.

    Each input is paired with the first remaining ground-truth path that
    contains the input's basename. That ground-truth path is then consumed.
    Inputs without a match are skipped.

    Args:
        path_gt: Root directory of the ground-truth volumes.
        path_corrupt: Root directory of the input volumes, or a falsy value
            to corrupt the ground truth on the fly.
        is_infer: Not used by this function.
        p: Probability passed to the on-the-fly corruption ``tio.Lambda``.
        transforms: Optional sequence of transforms to compose. It is copied,
            so the caller's list is never modified.
        **kwargs: Forwarded to ``MotionCorrupter`` when corrupting on the fly.

    Returns:
        tio.SubjectsDataset whose subjects have keys ``gt``, ``inp``,
        ``filename`` and ``tag``.
    """
    # Copy so that neither a shared default nor the caller's list is mutated
    # by the append below (a `[]` default would accumulate one corrupter per call).
    transforms = [] if transforms is None else list(transforms)

    files_gt = glob(path_gt + "/**/*.nii", recursive=True) + glob(
        path_gt + "/**/*.nii.gz", recursive=True)
    if path_corrupt:
        files_inp = glob(path_corrupt + "/**/*.nii", recursive=True) + glob(
            path_corrupt + "/**/*.nii.gz", recursive=True)
        corruptFly = False
    else:
        files_inp = files_gt.copy()
        corruptFly = True
    subjects = []

    for file in files_inp:
        filename = os.path.basename(file)
        # NOTE(review): substring test on the full path, so e.g. '1.nii' also
        # matches '.../11.nii' -- confirm this loose matching is intended.
        gt_files = [f for f in files_gt if filename in f]
        if gt_files:
            gt_path = gt_files[0]
            files_gt.remove(gt_path)
            subjects.append(
                tio.Subject(
                    gt=tio.ScalarImage(gt_path),
                    inp=tio.ScalarImage(file),
                    filename=filename,
                    tag="CorruptNGT",
                ))

    if corruptFly:
        moco = MotionCorrupter(**kwargs)
        transforms.append(tio.Lambda(moco.perform, p=p))
    transform = tio.Compose(transforms)
    subjects_dataset = tio.SubjectsDataset(subjects, transform=transform)
    return subjects_dataset
    def apply_transforms(self, image, labels):
        """Apply ``self.transforms`` to an image/label pair through torchio.

        Both inputs are converted to tensors (float for the image, long for the
        labels), given a leading singleton axis, and have their axis 1 moved to
        the end. They are then wrapped as a ``ScalarImage`` / ``LabelMap`` in a
        one-subject ``SubjectsDataset`` whose transform is ``self.transforms``.
        On the transformed sample, the leading axis is dropped and the last axis
        is moved back to the front.

        Returns:
            tuple: ``(X, Y)`` numpy arrays for the image and the segmentation.
        """
        # torchio images must be built from tensors, not numpy arrays.
        image_tensor = torch.tensor(image, dtype=torch.float, requires_grad=False)
        label_tensor = torch.tensor(labels, dtype=torch.long, requires_grad=False)

        # Expected input is (C x W x H x D): add a channel axis, then move axis 1 last.
        image_tensor = torch.moveaxis(image_tensor.unsqueeze(0), 1, -1)
        label_tensor = torch.moveaxis(label_tensor.unsqueeze(0), 1, -1)

        subject = tio.Subject(
            one_image=tio.ScalarImage(tensor=image_tensor),
            a_segmentation=tio.LabelMap(tensor=label_tensor),
        )
        sample = tio.SubjectsDataset([subject], transform=self.transforms)[0]

        # Re-arrange for PyTorch: drop the channel axis and bring the last axis first.
        X = np.moveaxis(sample['one_image']['data'].numpy()[0], -1, 0)
        Y = np.moveaxis(sample['a_segmentation']['data'].numpy()[0], -1, 0)
        return X, Y
def predict_agg_3d(
    input_array,
    model3d,
    patch_size=(128, 224, 224),
    patch_overlap=(12, 12, 12),
    nb=True,
    device=0,
    debug_verbose=False,
    fpn=False,
    overlap_mode="crop",
):
    """Run ``model3d`` patch-wise over a volume and collect outputs in a GridAggregator.

    Args:
        input_array: Array-like volume. A leading channel axis is added before
            sampling (assumed 3D spatial volume -- TODO confirm with callers).
        model3d: Model called on each batch of patches.
        patch_size: Patch size passed to ``GridSampler``.
        patch_overlap: Patch overlap passed to ``GridSampler``.
        nb: Use ``tqdm.notebook`` for the progress bar if true, else ``tqdm``.
        device: Device the patches are moved to before calling the model.
        debug_verbose: Print input/output shapes for each batch.
        fpn: If true, the model returns a sequence and element 0 is used.
        overlap_mode: Passed to ``GridAggregator``.

    Returns:
        The ``GridAggregator`` holding channel 0 of every patch prediction.
    """
    import torchio as tio
    from torchio import LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    print(input_array.shape)
    img_tens = torch.FloatTensor(input_array[:]).unsqueeze(0)
    print(f"Predict and aggregate on volume of {img_tens.shape}")

    # The image kind is given via torchio's ``type=`` parameter; an unknown
    # keyword such as ``label=`` would just be stored on the image.
    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, type=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, type=tio.LABEL),
    )
    img_sample = tio.SubjectsDataset([one_subject])[-1]

    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=1)
    aggregator1 = GridAggregator(grid_sampler, overlap_mode=overlap_mode)

    if nb:
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm

    with torch.no_grad():
        for patches_batch in tqdm(patch_loader):
            inputs_t = patches_batch["img"]["data"].to(device)
            locations = patches_batch[LOCATION]

            outputs = model3d(inputs_t)
            if fpn:
                outputs = outputs[0]
            if debug_verbose:
                print(f"inputs_t: {inputs_t.shape}")
                print(f"outputs: {outputs.shape}")

            # Keep only the first output channel.
            output = outputs[:, 0:1, :]
            # output = torch.sigmoid(output)

            aggregator1.add_batch(output, locations)

    return aggregator1
Exemple #30
0
# Demo: apply a random torchio pipeline to FPG subjects and inspect the
# transform history recorded on the result. Seeded for reproducibility.
torch.manual_seed(0)

batch_size = 4
subject = tio.datasets.FPG()
# Drop the 'seg' image from the subject.
subject.remove_image('seg')
subjects = [subject] * 4

transform = tio.Compose(
    (
        tio.ToCanonical(),
        tio.RandomGamma(p=0.75),
        tio.RandomBlur(p=0.5),
        tio.RandomFlip(),
        tio.RescaleIntensity((-1, 1)),
    )
)
dataset = tio.SubjectsDataset(subjects, transform=transform)

# Inspect the history of a single transformed subject.
transformed = dataset[0]
print('Applied transforms:')  # noqa: T001
pprint.pprint(transformed.history)  # noqa: T003
print('\nComposed transform to reproduce history:')  # noqa: T001
print(transformed.get_composed_history())  # noqa: T001
print('\nComposed transform to invert applied transforms when possible:')  # noqa: T001, E501
print(transformed.get_inverse_transform(ignore_intensity=False))  # noqa: T001

# Batches are collated with tio.utils.history_collate (presumably so each
# subject's transform history survives batching -- confirm in torchio docs).
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=tio.utils.history_collate,
)