def __init__(
        self,
        datasets_dir,
        real_dataset_dir,
        resection_params,
        train_batch_size,
        num_workers,
        pseudo_dir=None,
        split_ratio=0.9,
        split_seed=42,
        debug_ratio=0.02,
        log=None,
        debug=False,
        augment=True,
        verbose=False,
        cache_validation_set=True,
        histogram_standardization=True,
        ):
    super().__init__(datasets_dir, train_batch_size, num_workers)
    self.resection_params = resection_params
    # Precomputed from 90% of the public training data
    if histogram_standardization:
        self.landmarks_path = (
            Path(__file__).parent
            / 'landmarks'
            / 'histogram_landmarks_default.npy'
        )
    else:
        self.landmarks_path = None
    public_subjects = self.get_public_subjects()
    train_public, val_public = self.split_subjects(
        public_subjects, split_ratio, split_seed)
    train_transform = (
        self.get_train_transform() if augment else self.get_val_transform()
    )
    self.train_dataset = tio.SubjectsDataset(
        train_public, transform=train_transform)
    self.val_dataset = tio.SubjectsDataset(
        val_public, transform=train_transform)
    if cache_validation_set:
        self.val_dataset = cache(
            self.val_dataset, resection_params, augment=augment)
    test_transform = get_test_transform(self.landmarks_path)
    self.test_dataset = get_real_resection_dataset(
        real_dataset_dir, transform=test_transform)
    if debug:
        self.train_dataset = reduce_dataset(self.train_dataset, debug_ratio)
        self.val_dataset = reduce_dataset(self.val_dataset, debug_ratio)
        self.test_dataset = reduce_dataset(self.test_dataset, debug_ratio)
    self.train_loader = self.get_train_loader(self.train_dataset)
    self.val_loader = self.get_val_loader(self.val_dataset)
    self.test_loader = self.get_val_loader(self.test_dataset)
    self.log = log
    if verbose:
        self.print_lengths()

def get_torchio_dataset(inputs, targets, transform):
    """Create a torchio.SubjectsDataset from lists of paths and apply the
    given transform to it.

    Arguments:
    * inputs (list): paths to MR images
    * targets (list): paths to the ground-truth segmentations of the MR images
    * transform (False or torchio.transforms.Transform): transform applied to
      the MR images and, except for intensity transforms, to the segmentations

    Output:
    * dataset (torchio.SubjectsDataset): a dataset of
      torchio.data.subject.Subject instances
    """
    subjects = []
    for image_path, label_path in zip(inputs, targets):
        subject_dict = {
            'MRI': torchio.Image(image_path, torchio.INTENSITY),
            # intensity transforms are not applied to torchio.LABEL images
            'LABEL': torchio.Image(label_path, torchio.LABEL),
        }
        subjects.append(torchio.Subject(subject_dict))
    if transform:
        dataset = torchio.SubjectsDataset(subjects, transform=transform)
    else:
        dataset = torchio.SubjectsDataset(subjects)
    return dataset

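# Hypothetical usage sketch for get_torchio_dataset (not part of the original
# source); the file paths and the RescaleIntensity transform are illustrative
# assumptions.
import torchio

inputs = ['sub-01_t1.nii.gz', 'sub-02_t1.nii.gz']
targets = ['sub-01_seg.nii.gz', 'sub-02_seg.nii.gz']
dataset = get_torchio_dataset(
    inputs, targets, torchio.RescaleIntensity((0, 1)))
print(len(dataset), 'subjects')
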
def load_pretrain_datasets(data_shape, batch=3, workers=4, transform=None):
    data_path = '/home/mitch/Data/MSD/'
    directories = sorted(glob.glob(data_path + '*/'))
    loaders = []  # store a dataloader for each task
    datasets = []  # store dataset objects before turning them into loaders

    if transform is None:
        transform = tio.RandomFlip(p=0.)

    # preprocess all
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    # deal with unusual shapes: keep a single channel for the brain and
    # prostate tasks
    braintransform = Lambda(lambda x: torch.unsqueeze(x[:, :, :, 2], dim=0),
                            types_to_apply=[tio.INTENSITY])
    braintransform = tio.Compose([braintransform, transform])
    prostatetransform = Lambda(
        lambda x: torch.unsqueeze(x[:, :, :, 1], dim=0),
        types_to_apply=[tio.INTENSITY])
    prostatetransform = tio.Compose([prostatetransform, transform])

    for i, directory in enumerate(directories):
        images = sorted(glob.glob(directory + 'imagesTr/*'))
        segs = sorted(glob.glob(directory + 'labelsTr/*'))
        subject_list = []
        for image, seg in zip(images, segs):
            subject_list.append(
                tio.Subject(img=tio.ScalarImage(image),
                            label=tio.LabelMap(seg)))
        # handle special cases
        if i == 0:
            datasets.append(
                tio.SubjectsDataset(subject_list, transform=braintransform))
        elif i == 4:
            datasets.append(
                tio.SubjectsDataset(subject_list,
                                    transform=prostatetransform))
        else:
            datasets.append(
                tio.SubjectsDataset(subject_list, transform=transform))
        loaders.append(
            DataLoader(datasets[-1], num_workers=workers, batch_size=batch,
                       pin_memory=True))
    return loaders

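# Hypothetical usage sketch (not from the original source): build one loader
# per MSD task and draw a single batch; data_shape is an assumed target size.
loaders = load_pretrain_datasets(data_shape=(128, 128, 128), batch=2)
for task_id, loader in enumerate(loaders):
    batch = next(iter(loader))
    print(task_id, batch['img']['data'].shape)  # (B, 1, 128, 128, 128)
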
def get_dataset(
        input_path,
        tta_iterations=0,
        interpolation='bspline',
        tolerance=0.1,
        mni_transform_path=None,
        ):
    if mni_transform_path is None:
        image = tio.ScalarImage(input_path)
    else:
        affine = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: affine})
    subject = tio.Subject({IMAGE_NAME: image})
    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.,
    ])
    hist_std = tio.HistogramStandardization({IMAGE_NAME: landmarks})
    preprocess_transforms = [
        tio.ToCanonical(),
        hist_std,
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]
    zooms = nib.load(input_path).header.get_zooms()
    pixdim = np.array(zooms)
    diff_to_1_iso = np.abs(pixdim - 1)
    if np.any(diff_to_1_iso > tolerance) or mni_transform_path is not None:
        kwargs = {'image_interpolation': interpolation}
        if mni_transform_path is not None:
            kwargs['pre_affine_name'] = TO_MNI
            kwargs['target'] = tio.datasets.Colin27().t1.path
        resample_transform = tio.Resample(**kwargs)
        preprocess_transforms.append(resample_transform)
    preprocess_transforms.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess_transform = tio.Compose(preprocess_transforms)
    no_aug_dataset = tio.SubjectsDataset(
        [subject], transform=preprocess_transform)
    aug_subjects = tta_iterations * [subject]
    if not aug_subjects:
        return no_aug_dataset
    augment_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    aug_dataset = tio.SubjectsDataset(
        aug_subjects, transform=augment_transform)
    dataset = torch.utils.data.ConcatDataset((no_aug_dataset, aug_dataset))
    return dataset

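# Hypothetical usage sketch (not part of the original source): 't1.nii.gz' is
# a made-up path; with tta_iterations > 0 the function returns a
# ConcatDataset of one preprocessed subject plus the augmented copies.
dataset = get_dataset('t1.nii.gz', tta_iterations=4)
print(len(dataset))  # 5: 1 preprocessed subject + 4 augmented copies
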
def load_kidney_seg(data_shape, batch=3, workers=4, transform=None):
    # take the input transform and apply it after clipping, normalization,
    # and resizing
    if transform is None:
        transform = tio.RandomFlip(p=0.)

    # preprocess all
    clippy = Lambda(lambda x: torch.clip(x, -80, 300),
                    types_to_apply=[tio.INTENSITY])
    normal = RescaleIntensity((0., 1.))
    resize = Lambda(lambda x: torch.squeeze(
        interpolate(torch.unsqueeze(x, dim=0), data_shape), dim=0))
    rounding = Lambda(lambda x: torch.round(x), types_to_apply=[tio.LABEL])
    transform = tio.Compose([clippy, normal, resize, rounding, transform])

    subject_list = []
    for i in range(210):
        pt_image = 'data/case_{:05d}/imaging.nii.gz'.format(i)
        pt_label = 'data/case_{:05d}/segmentation.nii.gz'.format(i)
        subject_list.append(
            tio.Subject(img=tio.ScalarImage(pt_image),
                        label=tio.LabelMap(pt_label)))
    dataset = tio.SubjectsDataset(subject_list, transform=transform)
    return DataLoader(dataset, num_workers=workers, batch_size=batch,
                      pin_memory=True)

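# Hypothetical usage sketch (not part of the original source): the loader
# yields batches of cases already clipped, rescaled, and resized.
loader = load_kidney_seg(data_shape=(96, 96, 96), batch=2, workers=0)
batch = next(iter(loader))
print(batch['img']['data'].shape, batch['label']['data'].shape)
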
def cache(dataset, resection_params, augment=True,
          caches_dir='/tmp/val_set_cache', num_workers=12):
    caches_dir = Path(caches_dir)
    wm_lesion_p = resection_params['wm_lesion_p']
    clot_p = resection_params['clot_p']
    shape = resection_params['shape']
    texture = resection_params['texture']
    augment_string = '_no_augmentation' if not augment else ''
    dir_name = (
        f'wm_{wm_lesion_p}_clot_{clot_p}_{shape}_{texture}{augment_string}'
    )
    cache_dir = caches_dir / dir_name
    image_dir = cache_dir / 'image'
    label_dir = cache_dir / 'label'
    if not cache_dir.is_dir():
        print('Caching validation set')
        image_dir.mkdir(parents=True)
        label_dir.mkdir(parents=True)
        loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            collate_fn=lambda x: x[0],
        )
        for subject in tqdm(loader):
            image_path = image_dir / subject.image.path.name
            # the label has no path because it was created, not loaded
            label_path = label_dir / subject.image.path.name
            subject.image.save(image_path)
            subject.label.save(label_path)
    subjects = []
    for im_path, label_path in zip(sglob(image_dir), sglob(label_dir)):
        subject = tio.Subject(
            image=tio.ScalarImage(im_path),
            label=tio.LabelMap(label_path),
        )
        subjects.append(subject)
    return tio.SubjectsDataset(subjects)

def get_subjects(path, structures, transform):
    """
    Browse the path folder to build a dataset. The folder must contain the
    subjects with the CT and masks.

    :param path: root folder.
    :type path: str
    :param structures: list of structures.
    :type structures: list[str]
    :param transform: transforms to be applied.
    :type transform: :class:`tio.transforms.Transform`
    :return: Base TorchIO dataset.
    :rtype: :class:`tio.SubjectsDataset`
    """
    subject_ids = os.listdir(path)
    subjects = []
    for subject_id in subject_ids:
        ct_path = os.path.join(path, subject_id, 'ct.nii')
        structures_path_dict = {
            k: os.path.join(path, subject_id, k + '.nii') for k in structures
        }
        subject = tio.Subject(
            ct=tio.ScalarImage(ct_path),
        )
        label_map = torch.zeros(subject['ct'].shape, dtype=torch.long)
        for i, (k, v) in enumerate(structures_path_dict.items()):
            label_map += tio.LabelMap(v).data * (i + 1)
        # voxels belonging to more than one structure would exceed the
        # largest label, so reset them to background
        label_map[label_map > len(structures)] = 0
        subject.add_image(
            tio.LabelMap(tensor=label_map, affine=subject['ct'].affine),
            'label_map')
        subjects.append(subject)
    return tio.SubjectsDataset(subjects, transform=transform)

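# Hypothetical usage sketch (not part of the original source): the folder
# layout and structure names are illustrative assumptions.
import torchio as tio

transform = tio.Compose([tio.ToCanonical(), tio.Resample(2)])
dataset = get_subjects('/data/ct_subjects', ['heart', 'lungs'], transform)
subject = dataset[0]
print(subject['ct'].shape, subject['label_map'].data.unique())
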
def __init__(
        self,
        fold,
        num_folds,
        datasets_dir,
        dataset_name,
        train_batch_size,
        num_workers,
        use_public_landmarks=False,
        pseudo_dirname=None,
        split_seed=42,
        log=None,
        verbose=True,
        ):
    super().__init__(datasets_dir, train_batch_size, num_workers)
    self.resection_params = None
    real_dataset_dir = self.datasets_dir / 'real' / dataset_name
    real_subjects = get_real_resection_subjects(real_dataset_dir)
    train_subjects, val_subjects = self.split_subjects(
        real_subjects, fold, num_folds, split_seed)
    self.train_dataset = tio.SubjectsDataset(train_subjects)
    if use_public_landmarks:
        self.landmarks_path = get_landmarks_path()
    else:
        self.landmarks_path = get_landmarks_path(dataset=self.train_dataset)
    train_transform = self.get_train_transform(resect=False)
    self.train_dataset.set_transform(train_transform)
    test_transform = get_test_transform(self.landmarks_path)
    self.val_dataset = tio.SubjectsDataset(
        val_subjects, transform=test_transform)
    if pseudo_dirname is not None:
        pseudo_dir = self.datasets_dir / 'real' / pseudo_dirname
        pseudo_dataset = get_real_resection_dataset(
            pseudo_dir, transform=train_transform)
        self.train_dataset = torch.utils.data.ConcatDataset(
            (self.train_dataset, pseudo_dataset))
    self.train_loader = self.get_train_loader(self.train_dataset)
    self.val_loader = self.test_loader = self.get_val_loader(
        self.val_dataset)
    self.log = log
    if verbose:
        self.print_lengths(test=False)

def test_from_batch(self):
    dataset = tio.SubjectsDataset([self.sample_subject])
    loader = DataLoader(dataset)
    batch = tio.utils.get_first_item(loader)
    new_dataset = tio.SubjectsDataset.from_batch(batch)
    self.assertTensorEqual(
        dataset[0].t1.data,
        new_dataset[0].t1.data,
    )

def test_label_probabilities(self):
    labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, 1, -1)
    subject = torchio.Subject(
        label=torchio.Image(tensor=labels, type=torchio.LABEL),
    )
    sample = torchio.SubjectsDataset([subject])[0]
    probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
    sampler = LabelSampler(5, 'label', label_probabilities=probs_dict)
    probabilities = sampler.get_probability_map(sample)
    fixture = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
    assert torch.all(probabilities.squeeze().eq(fixture))

def get_sample(self, image_shape):
    t1 = torch.rand(*image_shape)
    prob = torch.zeros_like(t1)
    prob[0, 3, 3, 3] = 1
    subject = torchio.Subject(
        t1=torchio.ScalarImage(tensor=t1),
        prob=torchio.ScalarImage(tensor=prob),
    )
    subject = torchio.SubjectsDataset([subject])[0]
    return subject

def createTIODynDS(path_gt, path_corrupt, is_infer=False, p=1,
                   transforms=None, **kwargs):
    # avoid a mutable default argument: a shared list would accumulate
    # transforms across calls
    if transforms is None:
        transforms = []
    files_gt = glob(path_gt + "/**/*.nii", recursive=True) + glob(
        path_gt + "/**/*.nii.gz", recursive=True)
    if path_corrupt:
        files_inp = glob(path_corrupt + "/**/*.nii", recursive=True) + glob(
            path_corrupt + "/**/*.nii.gz", recursive=True)
        corruptFly = False
    else:
        files_inp = files_gt.copy()
        corruptFly = True
    subjects = []
    inp_dicts, files_inp = __process_TPs(files_inp)
    gt_dicts, _ = __process_TPs(files_gt)
    for filename in files_inp:
        inp_files = [d for d in inp_dicts if filename in d['filename']]
        gt_files = [d for d in gt_dicts if filename in d['filename']]
        # sort the time points so consecutive pairs line up deterministically
        tps = sorted(set(dic["tp"] for dic in inp_files))
        tp_prev = tps.pop(0)
        for tp in tps:
            inp_tp_prev = [d for d in inp_files if tp_prev == d['tp']]
            gt_tp_prev = [d for d in gt_files if tp_prev == d['tp']]
            inp_tp = [d for d in inp_files if tp == d['tp']]
            gt_tp = [d for d in gt_files if tp == d['tp']]
            tp_prev = tp
            if len(gt_tp_prev) > 0 and len(gt_tp) > 0:
                subjects.append(
                    tio.Subject(
                        gt_tp_prev=tio.ScalarImage(gt_tp_prev[0]['path']),
                        inp_tp_prev=tio.ScalarImage(inp_tp_prev[0]['path']),
                        gt=tio.ScalarImage(gt_tp[0]['path']),
                        inp=tio.ScalarImage(inp_tp[0]['path']),
                        filename=filename,
                        tp=tp,
                        tag="CorruptNGT",
                    ))
            else:
                print("Warning: not implemented if GT is missing. "
                      "Skipping subject time point.")
                continue
    if corruptFly:
        moco = MotionCorrupter(**kwargs)
        transforms.append(tio.Lambda(moco.perform, p=p))
    transforms.append(ProcessTIOSubsTPs())
    transform = tio.Compose(transforms)
    subjects_dataset = tio.SubjectsDataset(subjects, transform=transform)
    return subjects_dataset

def calculate_ssim(global_list):
    # relies on module-level test_transform and test_network
    global_ssim = []
    for i, scale_factors in enumerate(global_list):
        set_ssim = []
        test_dataset = tio.SubjectsDataset(scale_factors,
                                           transform=test_transform)
        for sample in tqdm(test_dataset):
            _, _, _, ssim_val = test_network(sample)
            set_ssim.append(ssim_val)
        global_ssim.append(set_ssim)
    return global_ssim

def __init__(self, images_dir, labels_dir):
    self.subjects = []
    if (hp.in_class == 1) and (hp.out_class == 1):
        images_dir = Path(images_dir)
        self.image_paths = sorted(images_dir.glob(hp.fold_arch))
        labels_dir = Path(labels_dir)
        self.label_paths = sorted(labels_dir.glob(hp.fold_arch))
        for image_path, label_path in zip(self.image_paths,
                                          self.label_paths):
            subject = tio.Subject(
                source=tio.ScalarImage(image_path),
                label=tio.LabelMap(label_path),
            )
            self.subjects.append(subject)
    else:
        images_dir = Path(images_dir)
        self.image_paths = sorted(images_dir.glob(hp.fold_arch))
        artery_labels_dir = Path(labels_dir + '/artery')
        self.artery_label_paths = sorted(
            artery_labels_dir.glob(hp.fold_arch))
        lung_labels_dir = Path(labels_dir + '/lung')
        self.lung_label_paths = sorted(lung_labels_dir.glob(hp.fold_arch))
        trachea_labels_dir = Path(labels_dir + '/trachea')
        self.trachea_label_paths = sorted(
            trachea_labels_dir.glob(hp.fold_arch))
        vein_labels_dir = Path(labels_dir + '/vein')
        self.vein_label_paths = sorted(vein_labels_dir.glob(hp.fold_arch))
        for (image_path, artery_label_path, lung_label_path,
             trachea_label_path, vein_label_path) in zip(
                 self.image_paths, self.artery_label_paths,
                 self.lung_label_paths, self.trachea_label_paths,
                 self.vein_label_paths):
            subject = tio.Subject(
                source=tio.ScalarImage(image_path),
                artery=tio.LabelMap(artery_label_path),
                lung=tio.LabelMap(lung_label_path),
                trachea=tio.LabelMap(trachea_label_path),
                vein=tio.LabelMap(vein_label_path),
            )
            self.subjects.append(subject)
    self.transforms = self.transform()
    self.training_set = tio.SubjectsDataset(self.subjects,
                                            transform=self.transforms)

def training_network(landmarks, dataset, subjects):
    training_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0,)),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    validation_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    training_split_ratio = 0.9
    num_subjects = len(dataset)
    num_training_subjects = int(training_split_ratio * num_subjects)
    training_subjects = subjects[:num_training_subjects]
    validation_subjects = subjects[num_training_subjects:]

    training_set = tio.SubjectsDataset(training_subjects,
                                       transform=training_transform)
    validation_set = tio.SubjectsDataset(validation_subjects,
                                         transform=validation_transform)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set

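# Hypothetical sketch (not part of the original source) of one way the
# landmarks argument could be produced: HistogramStandardization.train
# computes landmarks from a list of image paths; mri_paths is an assumption.
landmarks = HistogramStandardization.train(mri_paths)
training_set, validation_set = training_network(landmarks, dataset, subjects)
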
def __init__(self, imgs_path):
    self.img_list = get_listdir(imgs_path)
    self.img_list.sort()
    self.subjects = []
    for image_path in self.img_list:
        subject = torchio.Subject(
            source=torchio.ScalarImage(image_path),
        )
        self.subjects.append(subject)
    self.transforms = self.transform()
    self.test_set = torchio.SubjectsDataset(self.subjects,
                                            transform=self.transforms)

def create_trainDS(path, p=1, **kwargs):
    files = glob(path + "/**/*.nii", recursive=True) + glob(
        path + "/**/*.nii.gz", recursive=True)
    subjects = []
    for file in files:
        subjects.append(
            tio.Subject(
                im=tio.ScalarImage(file),
                filename=os.path.basename(file),
            ))
    moco = MotionCorrupter(**kwargs)
    transforms = [tio.Lambda(moco.perform, p=p)]
    transform = tio.Compose(transforms)
    subjects_dataset = tio.SubjectsDataset(subjects, transform=transform)
    return subjects_dataset

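# Hypothetical usage sketch (not part of the original source): the path is an
# assumption, and any MotionCorrupter keyword arguments would be forwarded
# through **kwargs; each sample is corrupted on the fly with probability p.
ds = create_trainDS('/data/train_niftis', p=0.75)
sample = ds[0]
print(sample['filename'], sample['im']['data'].shape)
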
def create_trainDS_precorrupt(path_gt, path_corrupt, p=1, norm_mode=0):
    files = glob(path_gt + "/**/*.nii", recursive=True) + glob(
        path_gt + "/**/*.nii.gz", recursive=True)
    subjects = []
    for file in files:
        subjects.append(
            tio.Subject(
                im=tio.ScalarImage(file),
                filename=os.path.basename(file),
            ))
    transforms = [
        ReadCorrupted(path_corrupt=path_corrupt, p=p, norm_mode=norm_mode),
    ]
    transform = tio.Compose(transforms)
    subjects_dataset = tio.SubjectsDataset(subjects, transform=transform)
    return subjects_dataset

def setUp(self):
    """Set up test fixtures, if any."""
    self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
    self.dir.mkdir(exist_ok=True)
    random.seed(42)
    np.random.seed(42)
    registration_matrix = np.array([
        [1, 0, 0, 10],
        [0, 1, 0, 0],
        [0, 0, 1.2, 0],
        [0, 0, 0, 1],
    ])
    subject_a = tio.Subject(
        t1=tio.ScalarImage(self.get_image_path('t1_a')),
    )
    subject_b = tio.Subject(
        t1=tio.ScalarImage(self.get_image_path('t1_b')),
        label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
    )
    subject_c = tio.Subject(
        label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
    )
    subject_d = tio.Subject(
        t1=tio.ScalarImage(
            self.get_image_path('t1_d'),
            pre_affine=registration_matrix,
        ),
        t2=tio.ScalarImage(self.get_image_path('t2_d')),
        label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
    )
    subject_a4 = tio.Subject(
        t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
    )
    self.subjects_list = [
        subject_a,
        subject_a4,
        subject_b,
        subject_c,
        subject_d,
    ]
    self.dataset = tio.SubjectsDataset(self.subjects_list)
    self.sample_subject = self.dataset[-1]  # subject_d

def SubjectsDataset():
    images_dir = dataset_dir / 'image'
    labels_dir = dataset_dir / 'label'
    image_paths = sorted(images_dir.glob('*.nii.gz'))
    label_paths = sorted(labels_dir.glob('*.nii.gz'))
    assert len(image_paths) == len(label_paths)

    subjects = []
    for image_path, label_path in zip(image_paths, label_paths):
        subject = tio.Subject(
            mri=tio.ScalarImage(image_path),
            brain=tio.LabelMap(label_path),
        )
        subjects.append(subject)
    # subjects = np.array(subjects)
    dataset = tio.SubjectsDataset(subjects)
    print('Dataset size:', len(dataset), 'subjects')
    # => Dataset size: 566 subjects
    return dataset, subjects

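# Hypothetical usage sketch (not part of the original source): assumes the
# global dataset_dir used above points at a folder with 'image' and 'label'
# subfolders.
dataset, subjects = SubjectsDataset()
first = dataset[0]
print(first['mri'].shape, first['brain'].shape)
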
def test_batch_history(self):
    # https://github.com/fepegar/torchio/discussions/743
    subject = self.sample_subject
    transform = tio.Compose([
        tio.RandomAffine(),
        tio.CropOrPad(5),
        tio.OneHot(),
    ])
    dataset = tio.SubjectsDataset([subject], transform=transform)
    loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=tio.utils.history_collate,
    )
    batch = tio.utils.get_first_item(loader)
    transformed: tio.Subject = tio.utils.get_subjects_from_batch(batch)[0]
    inverse = transformed.apply_inverse_transform()
    images1 = subject.get_images(intensity_only=False)
    images2 = inverse.get_images(intensity_only=False)
    for image1, image2 in zip(images1, images2):
        assert image1.shape == image2.shape

def __init__(self, path, images=None, labels=None, transforms=None):
    self.transforms = transforms
    self.subjects = []
    self.images = images
    self.labels = labels
    self.subject_folder_names = os.listdir(path)
    self.subject_folders = [
        f"{path}/{folder}/" for folder in self.subject_folder_names
    ]
    for subject_folder in self.subject_folders:
        subject_files = os.listdir(subject_folder)
        subject_data = {}
        attributes_file = "attributes.json"
        if attributes_file in subject_files:
            with open(f"{subject_folder}/{attributes_file}") as f:
                subject_data = json.load(f)
            subject_files.remove(attributes_file)
        # map each file's base name (up to the first dot) to its file name
        file_map = {file[:file.find(".")]: file for file in subject_files}
        all_names = []
        if images is not None:
            all_names += images
        if labels is not None:
            all_names += labels
        # skip subjects that are missing any requested image or label
        missing_name = any(name not in file_map for name in all_names)
        if missing_name:
            continue
        if images is not None:
            for name in images:
                subject_data[name] = tio.ScalarImage(
                    subject_folder + file_map[name])
        if labels is not None:
            for name in labels:
                subject_data[name] = tio.LabelMap(
                    subject_folder + file_map[name])
        self.subjects.append(tio.Subject(**subject_data))
    self.subject_dataset = tio.SubjectsDataset(self.subjects,
                                               transform=transforms)

def main(image_dir, label_dir, checkpoint_path, output_dir, landmarks_path,
         df_path, batch_size, num_workers, multi_gpu):
    import torch
    import torchio as tio
    import models
    import datasets
    import engine
    import utils
    fps = get_paths(image_dir)
    lfps = get_paths(label_dir)
    assert len(fps) == len(lfps)
    # the key must be 'image', as in get_test_transform
    subjects = [
        tio.Subject(image=tio.ScalarImage(fp), label=tio.LabelMap(lfp))
        for fp, lfp in zip(fps, lfps)
    ]
    transform = datasets.get_test_transform(landmarks_path)
    dataset = tio.SubjectsDataset(subjects, transform)
    checkpoint = torch.load(checkpoint_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = models.get_unet().to(device)
    if multi_gpu:
        model = torch.nn.DataParallel(model)
        model.module.load_state_dict(checkpoint['model'])
    else:
        model.load_state_dict(checkpoint['model'])
    output_dir = Path(output_dir)
    model.eval()
    torch.set_grad_enabled(False)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers)
    output_dir.mkdir(parents=True)
    evaluator = engine.Evaluator()
    df = evaluator.infer(model, loader, output_dir)
    df.to_csv(df_path)
    med, iqr = 100 * utils.get_median_iqr(df.Dice)
    print(f'{med:.1f} ({iqr:.1f})')
    return 0

def get_pseudo_loader(
        threshold,
        percentile,
        metric,
        summary_path,
        dataset_name,
        num_workers,
        batch_size=2,
        remove_zero_volume=False,
        ):
    subjects = []
    subject_ids = get_certain_subjects(
        threshold,
        percentile,
        metric,
        summary_path,
        remove_zero_volume=remove_zero_volume,
    )
    dataset_dir = Path('/home/fernando/datasets/real/') / dataset_name
    assert dataset_dir.is_dir()
    image_dir = dataset_dir / 'image'
    label_dir = dataset_dir / 'label'
    for subject_id in subject_ids:
        image_path = list(image_dir.glob(f'{subject_id}_*'))[0]
        label_path = list(label_dir.glob(f'{subject_id}_*'))[0]
        subject = tio.Subject(
            image=tio.ScalarImage(image_path),
            label=tio.LabelMap(label_path),
        )
        subjects.append(subject)
    transform = get_train_transform(get_landmarks_path())
    dataset = tio.SubjectsDataset(subjects, transform=transform)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=num_workers,
    )
    return loader

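# Hypothetical usage sketch (not part of the original source): the metric
# name, summary path, and dataset name are illustrative assumptions.
loader = get_pseudo_loader(
    threshold=0.9,
    percentile=95,
    metric='dice',
    summary_path='summary.csv',
    dataset_name='epilepsy',
    num_workers=4,
)
batch = next(iter(loader))
print(batch['image']['data'].shape)
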
def test_dataset(self, modalities):
    """
    Testing the Dataset class

    Note that test is not for
    """
    output_path_mod = pathlib.Path(
        self.output_path,
        'temp_folder_' + '_'.join(modalities.split(','))).as_posix()
    comp_path = pathlib.Path(output_path_mod).resolve().joinpath(
        'dataset.csv').as_posix()
    comp_table = pd.read_csv(comp_path, index_col=0)
    print(comp_path, comp_table)

    # Loading from nrrd files
    subjects_nrrd = Dataset.load_image(output_path_mod, ignore_multi=True)
    # Loading files directly
    # subjects_direct = Dataset.load_directly(self.input_path, modalities=modalities, ignore_multi=True)
    # The number of subjects is equal to the number of components, which is 2 for this dataset
    # assert len(subjects_nrrd) == len(subjects_direct) == 2, "There was some error in generation of subject object"
    # assert subjects_nrrd[0].keys() == subjects_direct[0].keys()
    # del subjects_direct

    # To check if all metadata items are present in the keys
    # temp_nrrd = subjects_nrrd[0]
    # columns_shdbe_present = set([col if col.split("_")[0] == "metadata" else "mod_" + "_".join(col.split("_")[1:]) for col in list(comp_table.columns) if col.split("_")[0] in ["folder", "metadata"]])
    # print(columns_shdbe_present)
    # assert set(temp_nrrd.keys()).issubset(columns_shdbe_present), "Not all items present in dictionary, some fault in going through the different columns in a single component"

    transforms = tio.Compose([
        tio.Resample(4),
        tio.CropOrPad((96, 96, 40)),
        select_roi_names(["larynx"]),
        tio.OneHot(),
    ])

    # Form the dataset and dataloader
    test_set = tio.SubjectsDataset(subjects_nrrd, transform=transforms)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=2, shuffle=True, collate_fn=collate_fn)

    # Check that the test set is correct
    assert len(test_set) == 2

    # Get items from the test loader; if this fails, there is some error in
    # the formation of the test data
    data = next(iter(test_loader))
    A = [
        item[1].shape == (2, 1, 96, 96, 40)
        if "RTSTRUCT" not in item[0]
        else item[1].shape == (2, 2, 96, 96, 40)
        for item in data.items()
    ]
    assert all(A), ("There is some problem in the transformation or the "
                    "formation of the subject object")

def main(input_path, checkpoint_path, output_dir, landmarks_path, batch_size,
         num_workers, resample):
    import torch
    from tqdm import tqdm
    import torchio as tio
    import models
    import datasets
    fps = get_paths(input_path)
    # the key must be 'image', as in get_test_transform
    subjects = [tio.Subject(image=tio.ScalarImage(fp)) for fp in fps]
    transform = tio.Compose((
        tio.ToCanonical(),
        datasets.get_test_transform(landmarks_path),
    ))
    if resample:
        transform = tio.Compose((
            tio.Resample(),
            transform,
            # tio.CropOrPad((264, 268, 144)),  # for BITE?
        ))
    dataset = tio.SubjectsDataset(subjects, transform)
    checkpoint = torch.load(checkpoint_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = models.get_unet().to(device)
    model.load_state_dict(checkpoint['model'])
    output_dir = Path(output_dir)
    model.eval()
    torch.set_grad_enabled(False)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers)
    output_dir.mkdir(exist_ok=True, parents=True)
    for batch in tqdm(loader):
        inputs = batch['image'][tio.DATA].float().to(device)
        seg = model(inputs).softmax(dim=1)[:, 1:].cpu() > 0.5
        for tensor, affine, path in zip(
                seg, batch['image'][tio.AFFINE], batch['image'][tio.PATH]):
            image = tio.LabelMap(tensor=tensor, affine=affine.numpy())
            path = Path(path)
            out_path = output_dir / path.name.replace('.nii', '_seg_cnn.nii')
            image.save(out_path)
    return 0

def createTIODS(path_gt, path_corrupt, is_infer=False, p=1, transforms=None,
                **kwargs):
    # avoid a mutable default argument: a shared list would accumulate
    # transforms across calls
    if transforms is None:
        transforms = []
    files_gt = glob(path_gt + "/**/*.nii", recursive=True) + glob(
        path_gt + "/**/*.nii.gz", recursive=True)
    if path_corrupt:
        files_inp = glob(path_corrupt + "/**/*.nii", recursive=True) + glob(
            path_corrupt + "/**/*.nii.gz", recursive=True)
        corruptFly = False
    else:
        files_inp = files_gt.copy()
        corruptFly = True
    subjects = []
    for file in files_inp:
        filename = os.path.basename(file)
        gt_files = [f for f in files_gt if filename in f]
        if len(gt_files) > 0:
            gt_path = gt_files[0]
            files_gt.remove(gt_path)
            subjects.append(
                tio.Subject(
                    gt=tio.ScalarImage(gt_path),
                    inp=tio.ScalarImage(file),
                    filename=filename,
                    tag="CorruptNGT",
                ))
    if corruptFly:
        moco = MotionCorrupter(**kwargs)
        transforms.append(tio.Lambda(moco.perform, p=p))
    transform = tio.Compose(transforms)
    subjects_dataset = tio.SubjectsDataset(subjects, transform=transform)
    return subjects_dataset

def apply_transforms(self, image, labels):
    # inputs = np.asarray(image, dtype=np.float32)
    inputs = image
    inputs = torch.tensor(inputs, dtype=torch.float, requires_grad=False)
    labels = torch.tensor(labels, dtype=torch.long, requires_grad=False)

    # Expected input is (C x W x H x D)
    inputs = inputs.unsqueeze(0)
    inputs = torch.moveaxis(inputs, 1, -1)
    labels = labels.unsqueeze(0)
    labels = torch.moveaxis(labels, 1, -1)

    subject_a = tio.Subject(
        one_image=tio.ScalarImage(tensor=inputs),  # *** must be tensors!
        a_segmentation=tio.LabelMap(tensor=labels))
    subjects_list = [subject_a]
    subjects_dataset = tio.SubjectsDataset(subjects_list,
                                           transform=self.transforms)
    subject_sample = subjects_dataset[0]
    X = subject_sample['one_image']['data'].numpy()
    Y = subject_sample['a_segmentation']['data'].numpy()

    # Re-arrange channels for PyTorch into (D, H, W)
    X = X[0]
    X = np.moveaxis(X, -1, 0)
    Y = Y[0]
    Y = np.moveaxis(Y, -1, 0)

    # DEBUG
    # plot_max(X)
    # plot_max(Y)
    return X, Y

def predict_agg_3d(
        input_array,
        model3d,
        patch_size=(128, 224, 224),
        patch_overlap=(12, 12, 12),
        nb=True,
        device=0,
        debug_verbose=False,
        fpn=False,
        overlap_mode="crop",
        ):
    import torchio as tio
    from torchio import LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    print(input_array.shape)
    img_tens = torch.FloatTensor(input_array[:]).unsqueeze(0)
    print(f"Predict and aggregate on volume of {img_tens.shape}")
    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, type=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, type=tio.LABEL),
    )
    img_dataset = tio.SubjectsDataset([one_subject])
    img_sample = img_dataset[-1]
    batch_size = 1
    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler, overlap_mode=overlap_mode)
    if nb:
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm
    with torch.no_grad():
        for patches_batch in tqdm(patch_loader):
            input_tensor = patches_batch["img"]["data"]
            locations = patches_batch[LOCATION]
            inputs_t = input_tensor.to(device)
            if fpn:
                outputs = model3d(inputs_t)[0]
            else:
                outputs = model3d(inputs_t)
            if debug_verbose:
                print(f"inputs_t: {inputs_t.shape}")
                print(f"outputs: {outputs.shape}")
            output = outputs[:, 0:1, :]
            # output = torch.sigmoid(output)
            aggregator1.add_batch(output, locations)
    return aggregator1

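# Hypothetical usage sketch (not part of the original source): model3d is an
# assumed trained network; the aggregated prediction is read back from the
# returned GridAggregator with get_output_tensor().
import numpy as np

volume = np.random.rand(160, 256, 256).astype(np.float32)  # stand-in volume
aggregator = predict_agg_3d(volume, model3d, device='cuda')
prediction = aggregator.get_output_tensor()  # shape: (1, 160, 256, 256)
print(prediction.shape)
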
torch.manual_seed(0)

batch_size = 4
subject = tio.datasets.FPG()
subject.remove_image('seg')
subjects = 4 * [subject]

transform = tio.Compose((
    tio.ToCanonical(),
    tio.RandomGamma(p=0.75),
    tio.RandomBlur(p=0.5),
    tio.RandomFlip(),
    tio.RescaleIntensity((-1, 1)),
))

dataset = tio.SubjectsDataset(subjects, transform=transform)
transformed = dataset[0]

print('Applied transforms:')  # noqa: T001
pprint.pprint(transformed.history)  # noqa: T003
print('\nComposed transform to reproduce history:')  # noqa: T001
print(transformed.get_composed_history())  # noqa: T001
print('\nComposed transform to invert applied transforms when possible:')  # noqa: T001, E501
print(transformed.get_inverse_transform(ignore_intensity=False))  # noqa: T001

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=tio.utils.history_collate,
)