def get_torchio_dataset(inputs, targets, transform):
    """Build a ``torchio.SubjectsDataset`` from paired image / label path lists.

    Arguments:
        * inputs (list): paths to MR images.
        * targets (list): paths to the ground-truth segmentations of those
          images, paired with ``inputs`` by position (``zip`` stops at the
          shorter list).
        * transform (False / torchio transform): transform attached to the
          dataset; any falsy value means no transform. Intensity transforms
          are not applied to the ``torchio.LABEL`` image.

    Output:
        * dataset (torchio.SubjectsDataset): one ``torchio.Subject`` per pair,
          with keys ``'MRI'`` (intensity image) and ``'LABEL'`` (label map).
    """
    subjects = [
        torchio.Subject({
            'MRI': torchio.Image(image_path, torchio.INTENSITY),
            # intensity transformations won't be applied to torchio.LABEL
            'LABEL': torchio.Image(label_path, torchio.LABEL),
        })
        for image_path, label_path in zip(inputs, targets)
    ]
    if not transform:
        return torchio.SubjectsDataset(subjects)
    return torchio.SubjectsDataset(subjects, transform=transform)
def build(self):
    """Build torchio train / validation datasets of MR images with scalar targets.

    Targets are read from the ``label`` column of the CSV at ``self.data``;
    row ``j`` is paired with the ``j``-th ``*.nii`` file found recursively
    under ``self.dataset_dir`` (sorted by path). 20 % of the images
    (``random_state=42``) go to the validation set.

    Each subject has two entries: ``'features'`` (the MR image, intensity
    type) and ``'targets'`` (the matching value from the CSV).

    Returns:
        tuple: ``(train_data, test_data)``, both ``torchio.ImagesDataset``.
    """
    from catalyst.utils import split_dataframe_train_test

    SEED = 42
    data = pd.read_csv(self.data)
    ab = data.label

    # NOTE(review): this pipeline is built but never handed to the datasets
    # below, so no transform is applied. Confirm whether that is intended.
    transforms = [
        RescaleIntensity((0, 1)),
        RandomAffine(),
        transformss.ToTensor(),
    ]
    transform = Compose(transforms)

    dataset_dir = Path(self.dataset_dir)
    images_dir = dataset_dir
    labels_dir = dataset_dir
    image_paths = sorted(images_dir.glob('**/*.nii'))
    label_paths = sorted(labels_dir.glob('**/*.nii'))
    assert len(image_paths) == len(label_paths)

    # These two names are arbitrary
    MRI = 'features'
    BRAIN = 'targets'

    # Split *indices* instead of the paths. The split shuffles, and each
    # image must keep the target from its own CSV row; the previous
    # sequential counter handed shuffled images the targets of unrelated
    # rows. Splitting range(n) with the same random_state selects exactly
    # the same images as splitting the n paths did.
    train_idx, valid_idx = split_dataframe_train_test(
        list(range(len(image_paths))), test_size=0.2, random_state=SEED)

    def _make_dataset(indices):
        # One subject per selected image, paired with its own target.
        subjects = [
            torchio.Subject({
                MRI: torchio.Image(image_paths[j], torchio.INTENSITY),
                BRAIN: ab[j],
            })
            for j in indices
        ]
        return torchio.ImagesDataset(subjects)

    train_data = _make_dataset(train_idx)
    test_data = _make_dataset(valid_idx)
    return train_data, test_data
def get_sample(self, image_shape):
    """Return the single sample of a one-subject dataset built from random data.

    The subject has two images:

    * ``t1``: ``torch.rand`` noise of shape ``image_shape``;
    * ``prob``: zeros of the same shape, except a 1 at index ``[3, 3, 3]``.
    """
    noise = torch.rand(*image_shape)
    probability_map = torch.zeros_like(noise)
    probability_map[3, 3, 3] = 1
    dataset = torchio.ImagesDataset([
        torchio.Subject(
            t1=torchio.Image(tensor=noise),
            prob=torchio.Image(tensor=probability_map),
        )
    ])
    return dataset[0]
def pad_3d_if_required(instance, size):
    r"""Zero-pad ``instance`` at the end of its last axis up to ``size[-1]``.

    Nothing happens when ``instance.shape[-1] >= size[-1]``. Otherwise the
    subject returned by ``instance.get_subject()`` is padded with
    ``torchio.transforms.Pad`` (only the final entry of the 6-tuple padding
    is non-zero). ``instance.x`` / ``instance.y`` are replaced by new
    intensity / label images and ``instance.shape`` is updated.

    Returns:
        The same ``instance`` object, mutated in place when padding occurred.
    """
    missing = size[-1] - instance.shape[-1]
    if missing <= 0:
        return instance

    padder = torchio.transforms.Pad(padding=(0, 0, 0, 0, 0, missing), padding_mode=0)
    padded = padder(instance.get_subject())
    instance.x = torchio.Image(tensor=padded.x.tensor, type=torchio.INTENSITY)
    instance.y = torchio.Image(tensor=padded.y.tensor, type=torchio.LABEL)
    instance.shape = padded.shape
    return instance
def __init__(self, root_dir, img_range=(0,0)):
    """Index the images under ``root_dir`` and build a patch queue over them.

    Files in ``root_dir + TRAIN_DIR`` are identified by the first 4-digit
    number in their name and sorted numerically. The files at positions
    ``img_range[0]`` through ``img_range[1]`` (inclusive) in that order are
    loaded. A label map from ``LABEL_DIR`` is attached to each subject only
    when that directory exists (``self.is_labeled``).

    Args:
        root_dir: dataset root directory; a trailing ``/`` is appended if
            missing.
        img_range: inclusive ``(first, last)`` position range into the
            sorted file list.
    """
    self.root_dir = root_dir
    self.img_range = img_range
    subject_lists = []

    # check whether a label directory exists
    if self.root_dir[-1] != '/':
        self.root_dir += '/'
    self.is_labeled = os.path.isdir(self.root_dir + LABEL_DIR)

    self.files = [re.findall('[0-9]{4}', filename)[0]
                  for filename in os.listdir(self.root_dir + TRAIN_DIR)]
    self.files = sorted(self.files, key=lambda f: int(f))

    # store all subjects in the list
    for img_num in range(img_range[0], img_range[1] + 1):
        file_id = self.files[img_num]
        img_file = os.path.join(self.root_dir, TRAIN_DIR, IMG_PREFIX + file_id + EXT)
        images = [torchio.Image('t1', img_file, torchio.INTENSITY)]
        # Only reference a label file when the label directory exists;
        # otherwise the subject would point at a nonexistent file.
        label_file = None
        if self.is_labeled:
            label_file = os.path.join(self.root_dir, LABEL_DIR, LABEL_PREFIX + file_id + EXT)
            images.append(torchio.Image('label', label_file, torchio.LABEL))
        subject_lists.append(torchio.Subject(*images))
        print(img_file)
        if label_file is not None:
            print(label_file)

    # Define transforms for data normalization and augmentation
    mtransforms = (
        ZNormalization(),
        # transforms.RandomNoise(std_range=(0, 0.25)),
        # transforms.RandomFlip(axes=(0,)),
    )
    self.subjects = torchio.ImagesDataset(subject_lists, transform=transforms.Compose(mtransforms))
    self.dataset = torchio.Queue(
        subjects_dataset=self.subjects,
        max_length=2,
        samples_per_volume=675,
        sampler_class=torchio.sampler.ImageSampler,
        patch_size=(240, 240, 3),
        num_workers=4,
        shuffle_subjects=False,
        shuffle_patches=True,
    )
    print("Dataset details\n  Images: {}".format(self.img_range[1] - self.img_range[0] + 1))
def get_original_subjects():
    """Build ADNI subjects and collect the image / label paths used for visualization.

    No augmentation is applied and no DataLoader is built here.

    Returns:
        tuple: ``(subjects, visual_img_path_list, visual_label_path_list)``

        * ``subjects``: list of ``tio.Subject``, one per record yielded by
          ``get_path(datasets)``, each with an ``img`` intensity image and a
          ``label`` label map.
        * ``visual_img_path_list`` / ``visual_label_path_list``: the
          ``img_path`` / ``label_path`` of every record yielded by
          ``get_1069_path(datasets)``, in the same order.
    """
    # Both environments (COMPUTECANADA or not) currently read the same
    # dataset directory; reintroduce a switch if they ever diverge.
    datasets = [ADNI_DATASET_DIR_1]

    subjects = [
        tio.Subject(
            img=tio.Image(path=mri.img_path, type=tio.INTENSITY),
            label=tio.Image(path=mri.label_path, type=tio.LABEL),
        )
        for mri in get_path(datasets)
    ]

    visual_img_path_list = []
    visual_label_path_list = []
    for mri in get_1069_path(datasets):
        visual_img_path_list.append(mri.img_path)
        visual_label_path_list.append(mri.label_path)

    print(f"{ctime()}: getting number of subjects {len(subjects)}")
    print(
        f"{ctime()}: getting number of path for visualizationg {len(visual_img_path_list)}"
    )
    return subjects, visual_img_path_list, visual_label_path_list
def infer_with_patches(self, model_inference_function, features):
    """Infer on ``features`` patch by patch and fuse the patch outputs.

    Each ``features[:, i, :, :, :]`` becomes an intensity image keyed
    ``str(i)`` in a torchio subject. A ``GridSampler`` with patch size
    ``self.psize`` walks that subject, and per-patch predictions are stitched
    back together by a ``GridAggregator``.

    Args:
        model_inference_function: list of callables. The first is called with
            ``model_inference_function=<rest of the list>`` and
            ``features=<patch>``, which supports recursive chaining of
            similar functions.
        features: tensor indexed as ``[:, modality, :, :, :]``. Dim 0 is
            presumably the batch and is handed to torchio as the image's
            channel dimension (TODO confirm batch size is always 1).

    Returns:
        The aggregated output tensor with an extra leading dimension.
    """
    modality_keys = [str(modality) for modality in range(features.shape[1])]
    subject = torchio.Subject({
        key: torchio.Image(tensor=features[:, modality, :, :, :], type=torchio.INTENSITY)
        for modality, key in enumerate(modality_keys)
    })

    grid_sampler = torchio.inference.GridSampler(subject, self.psize)
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=1)
    aggregator = torchio.inference.GridAggregator(grid_sampler)

    for patches_batch in patch_loader:
        # concatenate the modalities along the channel axis of the batch
        patch = torch.cat(
            [patches_batch[key][torchio.DATA] for key in modality_keys], dim=1)
        prediction = model_inference_function[0](
            model_inference_function=model_inference_function[1:],
            features=patch,
        )
        aggregator.add_batch(prediction, patches_batch[torchio.LOCATION])

    # add a leading dimension to the fused mask
    return aggregator.get_output_tensor().unsqueeze(0)
def _affine(data):
    """Apply the enclosing-scope ``transform`` to an image / label pair.

    Every value of ``data`` is first converted to a ``torch.Tensor`` (in
    place). ``data['image']`` may be 3-D, or 4-D with channels on the last
    axis. In the 4-D case each channel is wrapped as its own intensity image
    (``ch0``, ``ch1``, ...) so a single transform call is applied to all of
    them, and the results are re-stacked on the last axis.

    Returns:
        ``data`` itself, with ``'image'`` and ``'label'`` replaced by the
        transformed numpy arrays and every value squeezed.
    """
    for key in data:
        data[key] = torch.Tensor(data[key])

    image = data['image']
    images = {'label': torchio.Image(tensor=data['label'], type=torchio.LABEL)}
    channel_keys = None
    if len(image.shape) == 4:
        # We need to separate out the case of a 4D (multi-channel) image
        channel_keys = [f'ch{idx}' for idx in range(image.shape[-1])]
        for idx, key in enumerate(channel_keys):
            images[key] = torchio.Image(tensor=image[..., idx], type=torchio.INTENSITY)
    else:
        assert len(image.shape) == 3
        images['image'] = torchio.Image(tensor=image, type=torchio.INTENSITY)

    transformed = transform(torchio.Subject(**images))

    if channel_keys is None:
        data['image'] = transformed.image.numpy()
    else:
        # aggregate the per-channel results back into one array
        data['image'] = np.stack(
            tuple(getattr(transformed, key).numpy() for key in channel_keys),
            axis=-1,
        )
    data['label'] = transformed.label.numpy()

    for key in data:
        data[key] = data[key].squeeze()
    return data
def test_label_probabilities(self):
    """``LabelSampler`` spreads each label's weight over that label's voxels.

    With weights ``{0: 0, 1: 50, 2: 25, 3: 25}``, label 1 (3 voxels) expects
    50/100 / 3 = 2/12 per voxel and label 2 (1 voxel) expects 25/100 = 3/12.
    Label 0 gets 0, and label 3 does not occur in the map.
    """
    label_map = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, 1, -1)
    dataset = torchio.SubjectsDataset([
        torchio.Subject(label=torchio.Image(tensor=label_map, type=torchio.LABEL)),
    ])
    sample = dataset[0]

    weights = {0: 0, 1: 50, 2: 25, 3: 25}
    sampler = LabelSampler(5, 'label', label_probabilities=weights)
    probabilities = sampler.get_probability_map(sample)

    expected = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
    assert torch.all(probabilities.squeeze().eq(expected))
def get_torchio_dataset(inputs, targets, transform):
    """Create a torchio dataset from paired lists of image and label files.

    Args:
        inputs: paths to the MR images.
        targets: paths to the matching label maps, paired with ``inputs`` by
            position (``zip`` stops at the shorter list).
        transform: torchio transform for the dataset, or a falsy value for
            none.

    Returns:
        ``torchio.ImagesDataset`` whose subjects are keyed by the
        module-level ``MRI`` (intensity image) and ``LABEL`` (label map)
        names.
    """
    subjects = [
        torchio.Subject({
            MRI: torchio.Image(image_path, torchio.INTENSITY),
            LABEL: torchio.Image(label_path, torchio.LABEL),
        })
        for image_path, label_path in zip(inputs, targets)
    ]
    if transform:
        return torchio.ImagesDataset(subjects, transform=transform)
    return torchio.ImagesDataset(subjects)
def get_image_patches(input_img_name, mod_nb, gmpm=None, use_coronal=False, use_sagital=False, input_mask_name=None, augment=True, h=16, w=32, coef=.2, record_results=False, pred_labels=None):
    """Normalize an MR image and cut it into patches via ``get_patches_and_labels``.

    Preprocessing is histogram standardization with the landmarks in
    ``./data/t1_landmarks_{mod_nb}.npy``, followed by z-normalization masked
    with ``ZNormalization.mean``. The mask is loaded from ``input_mask_name``
    and binarized (``> 0``, as float). Without a mask file an all-zero mask
    of the image's shape is used.

    All other arguments are forwarded unchanged to ``get_patches_and_labels``.

    Returns:
        ``(all_patches, all_labels)`` normally, or ``(side_mask_np,
        mid_mask_np)`` when ``record_results`` is true.
    """
    mri_image = torchio.Image(input_img_name, torchio.INTENSITY)

    # torchio normalization
    landmarks_path = Path(f'./data/t1_landmarks_{mod_nb}.npy')
    preprocessing = torchio.transforms.Compose([
        HistogramStandardization({'mri': landmarks_path}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])
    normalized = preprocessing(torchio.Subject({'mri': mri_image}))
    target_np = normalized['mri'].data[0].numpy()

    if input_mask_name is None:
        mask_np = np.zeros_like(target_np)
    else:
        mask_np = (nib.load(input_mask_name).get_fdata() > 0).astype('float')

    all_patches, all_labels, side_mask_np, mid_mask_np = get_patches_and_labels(
        target_np, gmpm, mask_np,
        use_coronal=use_coronal, use_sagital=use_sagital,
        h=h, w=w, coef=coef, augment=augment,
        record_results=record_results, pred_labels=pred_labels)

    if record_results:
        return side_mask_np, mid_mask_np
    return all_patches, all_labels
def upload_raw_data(img_path, table_path, names):
    """Load each subject's T1w brain image together with its age and gender.

    For every ``name`` the image is read from
    ``<img_path>/<name>/T1w/T1w_acpc_dc_restore_brain.nii.gz``. Age and
    gender come from the first row of the CSV at ``table_path`` whose
    ``Subject`` equals ``name`` (IndexError if there is none).

    Returns:
        dict with keys ``'images'`` (list of ``torchio.Subject``),
        ``'genders'`` and ``'ages'``, all in ``names`` order.
    """
    table = pd.read_csv(table_path)
    subjects, ages, genders = [], [], []
    for name in names:
        t1_file = os.path.join(img_path, str(name), 'T1w',
                               'T1w_acpc_dc_restore_brain.nii.gz')
        subjects.append(torchio.Subject(torchio.Image('MRI', t1_file, torchio.INTENSITY)))

        is_subject = table.Subject == name
        ages.append(table.Age.values[is_subject][0])
        genders.append(table.Gender.values[is_subject][0])

    return {'images': subjects, 'genders': genders, 'ages': ages}
def load_subject_(self, index):
    """Return ``(sample, subject)`` for ``index``, loading and caching on first use.

    ``index`` wraps around ``self.patients``. The subject holds the ``mr``
    and ``trus`` scalar images read from ``<sample>/mr.mhd`` and
    ``<sample>/trus.mhd``. When ``self.load_mask`` is set, it also holds the
    ``mr_tree`` label map from ``<sample>/mr_tree.mhd``. Loaded subjects are
    cached in ``self.subjects`` keyed by ``sample``.
    """
    sample = self.patients[index % len(self.patients)]

    # load the mr and trus files if they haven't already been loaded
    if sample not in self.subjects:
        # trus is a ScalarImage in both cases; previously the no-mask branch
        # used an untyped torchio.Image, which was inconsistent with the
        # other branch.
        images = {
            'mr': torchio.ScalarImage(sample + "/mr.mhd"),
            'trus': torchio.ScalarImage(sample + "/trus.mhd"),
        }
        if self.load_mask:
            images['mr_tree'] = torchio.LabelMap(sample + "/mr_tree.mhd")
        self.subjects[sample] = torchio.Subject(**images)

    return sample, self.subjects[sample]
def detection_pipeline(self, input_img_name, input_mask_name=None, save_mask_name='pred_mask.nii.gz', probs=False):
    """Predict side / mid masks for an image, optionally scoring and saving them.

    The image is loaded, wrapped as a torchio subject and preprocessed with
    ``self.transform``.

    * ``probs=False``: ``self.get_mask`` produces the masks. If
      ``input_mask_name`` is given, the IoU against that (binarized) ground
      truth is printed and returned. The combined mask (element-wise maximum
      of side and mid) is saved to ``save_mask_name`` unless it is ``None``.
    * ``probs=True``: ``self.get_prob_masks`` produces the masks. Unless
      ``save_mask_name`` is ``None``, they are saved separately as
      ``side_<name>`` / ``mid_<name>`` under
      ``./data/predicted_masks/<experiment_name>``.

    Returns:
        ``(side_mask_np, mid_mask_np, iou)``; ``iou`` is ``None`` when it was
        not computed.
    """
    import numpy as np  # local import: needed to combine the two masks

    img = nib.load(input_img_name)
    subject = torchio.Subject({
        'mri': torchio.Image(input_img_name, torchio.INTENSITY),
    })
    zimage = self.transform(subject)
    img_np = zimage['mri'].data[0].numpy()

    if not probs:
        side_mask_np, mid_mask_np = self.get_mask(img_np)
        if input_mask_name is not None:
            true_mask = nib.load(input_mask_name)
            true_mask_np = true_mask.get_fdata() > 0
            iou = self.get_iou(side_mask_np, mid_mask_np, true_mask_np)
            print('Intersection over union = {:.5f}'.format(iou))
        else:
            iou = None
        if save_mask_name is not None:
            # The saved mask used to be an undefined `pred_mask_np` (NameError
            # on every call). Save the union of both predictions instead.
            # NOTE(review): confirm the union is the intended combined mask.
            pred_mask_np = np.maximum(side_mask_np, mid_mask_np)
            self.save_nii_mask(pred_mask_np, img, save_mask_name)
        return side_mask_np, mid_mask_np, iou

    side_mask_np, mid_mask_np = self.get_prob_masks(img_np)
    if save_mask_name is not None:
        out_dir = f'./data/predicted_masks/{self.experiment_name}'
        base_name = os.path.basename(save_mask_name)
        self.save_nii_mask(side_mask_np, img, os.path.join(out_dir, 'side_' + base_name))
        self.save_nii_mask(mid_mask_np, img, os.path.join(out_dir, 'mid_' + base_name))
    return side_mask_np, mid_mask_np, None
def __getitem__(self, idx):
    """Return ``(imgbatch, label)`` for the image at ``idx``, or ``None`` on failure.

    The image is transformed with the module-level ``transform`` and reshaped
    to a 224^3 volume. ``imgbatch`` is a list of three tensors: the central
    slice along each axis, repeated 3 times on a new leading axis, giving
    shape ``(3, 224, 224)``. The label comes from ``get_label(imagepath,
    csvpath)``.

    If loading or transforming the image fails, a warning is emitted and
    ``None`` is returned. Callers must filter these out, for example with a
    custom ``collate_fn``.
    """
    import warnings

    imgsize = 224
    if torch.is_tensor(idx):
        idx = idx.tolist()
    imagepath = self.imagepaths[idx]
    label = get_label(imagepath, csvpath)

    try:
        subject = torchio.Subject(
            {'mri': torchio.Image(imagepath, torchio.INTENSITY)})
        transformed_subject = transform(subject)
        imgdata = transformed_subject['mri'].data.reshape(imgsize, imgsize, imgsize)

        # create imgbatch with three different perspectives
        mid = imgsize // 2
        views = (imgdata[mid, :, :], imgdata[:, mid, :], imgdata[:, :, mid])
        imgbatch = [torch.stack([view, view, view], 0) for view in views]
        return imgbatch, torch.tensor(label)
    except Exception as err:
        # Best-effort skipping of unreadable images is kept, but it is no
        # longer silent. A bare except also swallowed KeyboardInterrupt.
        warnings.warn(f"Skipping {imagepath}: {err!r}")
        return None
def compute_from_aggregating(self, input, target, if_path: bool, type_as_tensor=None, whether_to_return_img=False, result: pl.EvalResult = None):
    """Predict a whole volume patch by patch and aggregate the argmax labels.

    The volume is preprocessed with ``get_val_transform()`` and split by a
    torchio ``GridSampler`` (``self.patch_size``,
    ``self.hparams.patch_overlap``). Each patch is passed through ``self``,
    and the channel-wise argmax is stitched back together by a
    ``GridAggregator``.

    Args:
        input: image file path when ``if_path`` is true, else an image tensor
            (squeezed before wrapping in a torchio image).
        target: label file path or label tensor, matching ``input``.
        if_path: whether ``input`` / ``target`` are file paths.
        type_as_tensor: path mode only; mapping whose ``'val_dice'`` entry
            supplies the dtype/device patches are cast to.
        whether_to_return_img: tensor mode only; also return the image.
        result: unused.

    Returns:
        * path mode: ``(preprocessed image data, aggregated labels,
          preprocessed label data)``.
        * tensor mode with ``whether_to_return_img``: ``(image data,
          aggregated labels, label data)`` of the un-preprocessed subject.
        * tensor mode otherwise: ``(aggregated labels, label data, stacked
          per-patch Dice losses)``.
    """
    transform = get_val_transform()
    patch_overlap = self.hparams.patch_overlap  # is there any constraint?

    if if_path:
        cur_img_subject = torchio.Subject(
            img=torchio.Image(input, type=torchio.INTENSITY))
        cur_label_subject = torchio.Subject(
            img=torchio.Image(target, type=torchio.LABEL))
        preprocessed_img = transform(cur_img_subject)
        preprocessed_label = transform(cur_label_subject)

        grid_sampler = torchio.inference.GridSampler(
            preprocessed_img,
            self.patch_size,
            patch_overlap,
        )
        patch_loader = torch.utils.data.DataLoader(grid_sampler)
        aggregator = torchio.inference.GridAggregator(grid_sampler)
        for patches_batch in patch_loader:
            input_tensor = patches_batch['img'][torchio.DATA]
            # used to convert tensor to CUDA
            input_tensor = input_tensor.type_as(type_as_tensor['val_dice'])
            locations = patches_batch[torchio.LOCATION]
            preds = self(input_tensor)  # use cuda
            labels = preds.argmax(dim=torchio.CHANNELS_DIMENSION, keepdim=True)
            aggregator.add_batch(labels, locations)
        output_tensor = aggregator.get_output_tensor()  # not using cuda!
        # `if_path` is always true here, so this branch always returns the
        # 3-tuple (the former 2-tuple alternative was unreachable).
        return preprocessed_img.img.data, output_tensor, preprocessed_label.img.data

    cur_subject = torchio.Subject(
        img=torchio.Image(tensor=input.squeeze(), type=torchio.INTENSITY),
        label=torchio.Image(tensor=target.squeeze(), type=torchio.LABEL))
    preprocessed_subject = transform(cur_subject)

    grid_sampler = torchio.inference.GridSampler(
        preprocessed_subject,
        self.patch_size,
        patch_overlap,
    )
    patch_loader = torch.utils.data.DataLoader(grid_sampler)
    aggregator = torchio.inference.GridAggregator(grid_sampler)

    # The loss configuration is loop-invariant, so build it once.
    diceloss = DiceLoss(
        include_background=self.hparams.include_background, to_onehot_y=True)
    dice_loss = []
    for patches_batch in patch_loader:
        input_tensor = patches_batch['img'][torchio.DATA]
        target_tensor = patches_batch['label'][torchio.DATA]
        # used to convert tensor to CUDA
        input_tensor = input_tensor.type_as(input)
        locations = patches_batch[torchio.LOCATION]
        preds_tensor = self(input_tensor)  # use cuda

        # Compute the loss here
        dice_loss.append(diceloss(input=preds_tensor, target=target_tensor))

        labels = preds_tensor.argmax(dim=torchio.CHANNELS_DIMENSION, keepdim=True)
        aggregator.add_batch(labels, locations)
    output_tensor = aggregator.get_output_tensor()  # not using cuda!!!!

    # NOTE(review): unlike path mode, these are the *un-preprocessed* subject
    # tensors. Confirm they share the aggregated output's shape.
    if whether_to_return_img:
        return cur_subject['img'].data, output_tensor, cur_subject['label'].data
    return output_tensor, cur_subject['label'].data, torch.stack(dice_loss)
def validate_network(model, valid_dataloader, scheduler, params, epoch=0, mode="validation"):
    """
    Function to validate a network for a single epoch

    Parameters
    ----------
    model : if parameters["model"]["type"] == torch, this is a torch.model, otherwise this is OV exec_net
        The model to process the input image with, it should support appropriate dimensions.
    valid_dataloader : torch.DataLoader
        The dataloader for the validation epoch
    scheduler : learning-rate scheduler or None
        Stepped once at the end of the epoch (with the average loss for the
        plateau-style types). ``None`` is treated as inference mode for
        choosing the output directory.
    params : dict
        The parameters passed by the user yaml. Note that this function
        writes ``params["subject_spacing"]`` for every subject and toggles
        ``params["medcam_enabled"]``.
    epoch : int
        Current epoch; used as an output sub-directory (outside inference)
        and in the written prediction rows.
    mode: str
        The mode of validation, used to write outputs, if requested

    Returns
    -------
    average_epoch_valid_loss : float
        Validation loss for the current epoch (0 when no ground truth was seen)
    average_epoch_valid_metric : dict
        Validation metrics for the current epoch (empty when no ground truth was seen)
    """
    print("*" * 20)
    print("Starting " + mode + " : ")
    print("*" * 20)
    # Initialize a few things
    total_epoch_valid_loss = 0
    total_epoch_valid_metric = {}
    average_epoch_valid_metric = {}

    # "per_label" metrics accumulate arrays; all others accumulate scalars
    for metric in params["metrics"]:
        if "per_label" in metric:
            total_epoch_valid_metric[metric] = []
        else:
            total_epoch_valid_metric[metric] = 0

    logits_list = []
    subject_id_list = []
    is_classification = params.get("problem_type") == "classification"
    is_inference = mode == "inference"

    # automatic mixed precision - https://pytorch.org/docs/stable/amp.html
    if params["verbose"]:
        if params["model"]["amp"]:
            print("Using Automatic mixed precision", flush=True)

    if scheduler is None:
        current_output_dir = params["output_dir"]  # this is in inference mode
    else:  # this is useful for inference
        current_output_dir = os.path.join(params["output_dir"], "output_" + mode)

    if not (is_inference):
        current_output_dir = os.path.join(current_output_dir, str(epoch))

    pathlib.Path(current_output_dir).mkdir(parents=True, exist_ok=True)

    # Set the model to valid
    if params["model"]["type"] == "torch":
        model.eval()

    # # putting stuff in individual arrays for correlation analysis
    # all_targets = []
    # all_predics = []
    if params["medcam_enabled"] and params["model"]["type"] == "torch":
        model.enable_medcam()
        params["medcam_enabled"] = True

    # NOTE(review): `outputToWrite` / `file_to_write` are only defined here,
    # i.e. when saving/inferring on a non-segmentation problem; later uses
    # rely on the same conditions holding.
    if params["save_output"] or is_inference:
        if params["problem_type"] != "segmentation":
            outputToWrite = "Epoch,SubjectID,PredictedValue\n"
            file_to_write = os.path.join(current_output_dir, "output_predictions.csv")
            if os.path.exists(file_to_write):
                # do not overwrite an existing predictions file
                file_to_write = os.path.join(
                    current_output_dir,
                    "output_predictions_" + get_unique_timestamp() + ".csv",
                )

    for batch_idx, (subject) in enumerate(
        tqdm(valid_dataloader, desc="Looping over " + mode + " data")
    ):
        if params["verbose"]:
            print("== Current subject:", subject["subject_id"], flush=True)

        # ensure spacing is always present in params and is always subject-specific
        if "spacing" in subject:
            params["subject_spacing"] = subject["spacing"]
        else:
            params["subject_spacing"] = None

        # constructing a new dict because torchio.GridSampler requires torchio.Subject, which requires torchio.Image to be present in initial dict, which the loader does not provide
        subject_dict = {}
        label_ground_truth = None
        label_present = False
        # this is when we want the dataloader to pick up properties of GaNDLF's DataLoader, such as pre-processing and augmentations, if appropriate
        if "label" in subject:
            if subject["label"] != ["NA"]:
                subject_dict["label"] = torchio.Image(
                    path=subject["label"]["path"],
                    type=torchio.LABEL,
                    tensor=subject["label"]["data"].squeeze(0),
                    affine=subject["label"]["affine"].squeeze(0),
                )
                label_present = True
                label_ground_truth = subject_dict["label"]["data"]

        if "value_keys" in params:  # for regression/classification
            for key in params["value_keys"]:
                subject_dict["value_" + key] = subject[key]
                # NOTE(review): recomputed identically on every iteration
                label_ground_truth = torch.cat(
                    [subject[key] for key in params["value_keys"]], dim=0
                )

        for key in params["channel_keys"]:
            subject_dict[key] = torchio.Image(
                path=subject[key]["path"],
                type=subject[key]["type"],
                tensor=subject[key]["data"].squeeze(0),
                affine=subject[key]["affine"].squeeze(0),
            )

        # regression/classification problem AND label is present
        if (params["problem_type"] != "segmentation") and label_present:
            # average the model output over q_samples_per_volume label-sampled patches
            sampler = torchio.data.LabelSampler(params["patch_size"])
            tio_subject = torchio.Subject(subject_dict)
            generator = sampler(tio_subject, num_patches=params["q_samples_per_volume"])
            pred_output = 0
            for patch in generator:
                image = torch.cat(
                    [patch[key][torchio.DATA] for key in params["channel_keys"]], dim=0
                )
                valuesToPredict = torch.cat(
                    [patch["value_" + key] for key in params["value_keys"]], dim=0
                )
                image = image.unsqueeze(0)
                image = image.float().to(params["device"])
                ## special case for 2D
                if image.shape[-1] == 1:
                    image = torch.squeeze(image, -1)
                if params["model"]["type"] == "torch":
                    pred_output += model(image)
                elif params["model"]["type"] == "openvino":
                    pred_output += torch.from_numpy(
                        model(
                            inputs={params["model"]["IO"][0][0]: image.cpu().numpy()}
                        )[params["model"]["IO"][1][0]]
                    )
                else:
                    raise Exception(
                        "Model type not supported. Please only use 'torch' or 'openvino'."
                    )
            pred_output = pred_output.cpu() / params["q_samples_per_volume"]
            pred_output /= params["scaling_factor"]

            if is_inference and is_classification:
                logits_list.append(pred_output)
                subject_id_list.append(subject.get("subject_id")[0])

            if params["save_output"] or is_inference:
                outputToWrite += (
                    str(epoch)
                    + ","
                    + subject["subject_id"][0]
                    + ","
                    + str(pred_output.cpu().max().item())
                    + "\n"
                )
            # `image` is the last sampled patch
            final_loss, final_metric = get_loss_and_metrics(
                image, valuesToPredict, pred_output, params
            )
            # # Non network validing related
            total_epoch_valid_loss += final_loss.detach().cpu().item()
            for metric in final_metric.keys():
                if isinstance(total_epoch_valid_metric[metric], list):
                    # first array-valued result replaces the empty list
                    if len(total_epoch_valid_metric[metric]) == 0:
                        total_epoch_valid_metric[metric] = np.array(
                            final_metric[metric]
                        )
                    else:
                        total_epoch_valid_metric[metric] += np.array(
                            final_metric[metric]
                        )
                else:
                    total_epoch_valid_metric[metric] += final_metric[metric]

        else:  # for segmentation problems OR regression/classification when no label is present
            grid_sampler = torchio.inference.GridSampler(
                torchio.Subject(subject_dict),
                params["patch_size"],
                patch_overlap=params["inference_mechanism"]["patch_overlap"],
            )
            patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=1)
            aggregator = torchio.inference.GridAggregator(
                grid_sampler,
                overlap_mode=params["inference_mechanism"]["grid_aggregator_overlap"],
            )

            if params["medcam_enabled"]:
                attention_map_aggregator = torchio.inference.GridAggregator(
                    grid_sampler,
                    overlap_mode=params["inference_mechanism"]["grid_aggregator_overlap"],
                )

            output_prediction = 0  # this is used for regression/classification
            current_patch = 0
            for patches_batch in patch_loader:
                if params["verbose"]:
                    print(
                        "=== Current patch:",
                        current_patch,
                        ", time : ",
                        get_date_time(),
                        ", location :",
                        patches_batch[torchio.LOCATION],
                        flush=True,
                    )
                current_patch += 1
                # concatenate the channels of this patch along dim 1
                image = (
                    torch.cat(
                        [
                            patches_batch[key][torchio.DATA]
                            for key in params["channel_keys"]
                        ],
                        dim=1,
                    )
                    .float()
                    .to(params["device"])
                )

                # calculate metrics if ground truth is present
                label = None
                if params["problem_type"] != "segmentation":
                    label = label_ground_truth
                elif "label" in patches_batch:
                    label = patches_batch["label"][torchio.DATA]

                if label is not None:
                    label = label.to(params["device"])
                    if params["verbose"]:
                        print(
                            "=== Validation shapes : label:",
                            label.shape,
                            ", image:",
                            image.shape,
                            flush=True,
                        )

                if is_inference:
                    result = step(model, image, None, params, train=False)
                else:
                    result = step(model, image, label, params, train=True)

                # get the current attention map and add it to its aggregator
                if params["medcam_enabled"]:
                    _, _, output, attention_map = result
                    attention_map_aggregator.add_batch(
                        attention_map, patches_batch[torchio.LOCATION]
                    )
                else:
                    _, _, output = result

                if params["problem_type"] == "segmentation":
                    aggregator.add_batch(
                        output.detach().cpu(), patches_batch[torchio.LOCATION]
                    )
                else:
                    if torch.is_tensor(output):
                        # this probably needs customization for classification (majority voting or median, perhaps?)
                        output_prediction += output.detach().cpu()
                    else:
                        output_prediction += output

            # save outputs
            if params["problem_type"] == "segmentation":
                output_prediction = aggregator.get_output_tensor()
                output_prediction = output_prediction.unsqueeze(0)
                if params["save_output"]:
                    # channel "1" supplies the spatial metadata (spacing, origin, ...)
                    img_for_metadata = torchio.Image(
                        type=subject["1"]["type"],
                        tensor=subject["1"]["data"].squeeze(0),
                        affine=subject["1"]["affine"].squeeze(0),
                    ).as_sitk()
                    ext = get_filename_extension_sanitized(subject["1"]["path"][0])
                    jpg_detected = False
                    if ext in [".jpg", ".jpeg"]:
                        jpg_detected = True
                    pred_mask = output_prediction.numpy()
                    # '0' because validation/testing dataloader always has batch size of '1'
                    pred_mask = reverse_one_hot(
                        pred_mask[0], params["model"]["class_list"]
                    )
                    pred_mask = np.swapaxes(pred_mask, 0, 2)

                    # perform numpy-specific postprocessing here
                    for postprocessor in params["data_postprocessing"]:
                        pred_mask = global_postprocessing_dict[postprocessor](
                            pred_mask, params
                        ).numpy()
                    if jpg_detected:
                        pred_mask = pred_mask.astype(np.uint8)
                    else:
                        pred_mask = pred_mask.astype(np.uint16)

                    ## special case for 2D
                    if image.shape[-1] > 1:
                        result_image = sitk.GetImageFromArray(pred_mask)
                    else:
                        result_image = sitk.GetImageFromArray(pred_mask.squeeze(0))
                    result_image.CopyInformation(img_for_metadata)

                    # this handles cases that need resampling/resizing
                    if "resample" in params["data_preprocessing"]:
                        result_image = resample_image(
                            result_image,
                            img_for_metadata.GetSpacing(),
                            interpolator=sitk.sitkNearestNeighbor,
                        )
                    sitk.WriteImage(
                        result_image,
                        os.path.join(
                            current_output_dir, subject["subject_id"][0] + "_seg" + ext
                        ),
                    )
            else:
                # final regression output
                output_prediction = output_prediction / len(patch_loader)
                if params["save_output"]:
                    outputToWrite += (
                        str(epoch)
                        + ","
                        + subject["subject_id"][0]
                        + ","
                        + str(output_prediction)
                        + "\n"
                    )

            # get the final attention map and save it
            if params["medcam_enabled"] and params["model"]["type"] == "torch":
                attention_map = attention_map_aggregator.get_output_tensor()
                for i, n in enumerate(attention_map):
                    model.save_attention_map(
                        n.squeeze(), raw_input=image[i].squeeze(-1)
                    )

            output_prediction = output_prediction.squeeze(-1)
            if is_inference and is_classification:
                logits_list.append(output_prediction)
                subject_id_list.append(subject.get("subject_id")[0])

            # we cast to float32 because float16 was causing nan
            if label_ground_truth is not None:
                # this is for RGB label
                if label_ground_truth.shape[0] == 3:
                    label_ground_truth = label_ground_truth[0, ...].unsqueeze(0)
                # we always want the ground truth to be in the same format as the prediction
                label_ground_truth = label_ground_truth.unsqueeze(0)
                if label_ground_truth.shape[-1] == 1:
                    label_ground_truth = label_ground_truth.squeeze(-1)
                final_loss, final_metric = get_loss_and_metrics(
                    image,
                    label_ground_truth,
                    output_prediction.to(torch.float32),
                    params,
                )
                if params["verbose"]:
                    print(
                        "Full image " + mode + ":: Loss: ",
                        final_loss,
                        "; Metric: ",
                        final_metric,
                        flush=True,
                    )

                # # Non network validing related
                # loss.cpu().data.item()
                total_epoch_valid_loss += final_loss.cpu().item()
                for metric in final_metric.keys():
                    if isinstance(total_epoch_valid_metric[metric], list):
                        if len(total_epoch_valid_metric[metric]) == 0:
                            total_epoch_valid_metric[metric] = np.array(
                                final_metric[metric]
                            )
                        else:
                            total_epoch_valid_metric[metric] += np.array(
                                final_metric[metric]
                            )
                    else:
                        total_epoch_valid_metric[metric] += final_metric[metric]

        if label_ground_truth is not None:
            if params["verbose"]:
                # For printing information at halftime during an epoch
                # NOTE(review): `len(...) / 2` is a float, so for an odd number
                # of batches this condition never holds.
                if ((batch_idx + 1) % (len(valid_dataloader) / 2) == 0) and (
                    (batch_idx + 1) < len(valid_dataloader)
                ):
                    print(
                        "\nHalf-Epoch Average " + mode + " loss : ",
                        total_epoch_valid_loss / (batch_idx + 1),
                    )
                    for metric in params["metrics"]:
                        if isinstance(total_epoch_valid_metric[metric], np.ndarray):
                            to_print = (
                                total_epoch_valid_metric[metric] / (batch_idx + 1)
                            ).tolist()
                        else:
                            to_print = total_epoch_valid_metric[metric] / (
                                batch_idx + 1
                            )
                        print(
                            "Half-Epoch Average " + mode + " " + metric + " : ",
                            to_print,
                        )

    if params["medcam_enabled"] and params["model"]["type"] == "torch":
        model.disable_medcam()
        params["medcam_enabled"] = False

    # `label_ground_truth` here refers to the last subject of the loop
    if label_ground_truth is not None:
        average_epoch_valid_loss = total_epoch_valid_loss / len(valid_dataloader)
        print("     Epoch Final   " + mode + " loss : ", average_epoch_valid_loss)
        for metric in params["metrics"]:
            if isinstance(total_epoch_valid_metric[metric], np.ndarray):
                to_print = (
                    total_epoch_valid_metric[metric] / len(valid_dataloader)
                ).tolist()
            else:
                to_print = total_epoch_valid_metric[metric] / len(valid_dataloader)
            average_epoch_valid_metric[metric] = to_print
            print(
                "     Epoch Final   " + mode + " " + metric + " : ",
                average_epoch_valid_metric[metric],
            )
    else:
        average_epoch_valid_loss, average_epoch_valid_metric = 0, {}

    if scheduler is not None:
        # plateau-style schedulers need the monitored value
        if params["scheduler"]["type"] in [
            "reduce_on_plateau",
            "reduce-on-plateau",
            "plateau",
            "reduceonplateau",
        ]:
            scheduler.step(average_epoch_valid_loss)
        else:
            scheduler.step()

    # write the predictions, if appropriate
    if params["save_output"]:
        if is_inference and is_classification and logits_list:
            class_list = [str(c) for c in params["model"]["class_list"]]
            logit_tensor = torch.cat(logits_list)
            current_fold_dir = params["current_fold_dir"]
            logit_tensor = logit_tensor.detach().cpu().numpy()
            columns = ["SubjectID"] + class_list
            logits_df = pd.DataFrame(columns=columns)
            logits_df.SubjectID = subject_id_list
            logits_df[class_list] = logit_tensor
            logits_file = os.path.join(current_fold_dir, "logits.csv")
            if os.path.isfile(logits_file):
                # do not overwrite an existing logits file
                logits_file = os.path.join(
                    current_fold_dir, "logits_" + get_unique_timestamp() + ".csv"
                )
            logits_df.to_csv(logits_file, index=False, sep=",")

        if "value_keys" in params:
            file = open(file_to_write, "w")
            file.write(outputToWrite)
            file.close()

    return average_epoch_valid_loss, average_epoch_valid_metric
def gridsampler_pipeline(
    input_array,
    entity_pts,
    patch_size=(64, 64, 64),
    patch_overlap=(0, 0, 0),
    batch_size=1,
):
    """Run the patch ``Pipeline`` over a grid of patches and aggregate its outputs.

    ``input_array`` is wrapped as a torchio subject and split into patches by
    a ``GridSampler``. For each patch:

    * the volume and ``entity_pts`` are cropped around the patch location;
    * a debug figure of the cropped volume is drawn;
    * the ops pipeline is run on a payload;
    * the payload's ``total_mask`` and ``prediction`` layers are aggregated
      into two full-size outputs.

    NOTE(review): the payload fed to the pipeline is currently random
    placeholder data, not ``cropped_vol`` / ``offset_pts``. A new matplotlib
    figure is opened per patch and never closed.

    Returns:
        tuple: ``([output_tensor1, output_tensor2], payloads)``. These are the
        aggregated ``total_mask`` and ``prediction`` tensors and the list of
        per-patch payloads.
    """
    import torchio as tio
    from torchio import IMAGE, LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    logger.debug("Starting up gridsampler pipeline...")
    entity_pts = entity_pts.astype(np.int32)
    img_tens = torch.FloatTensor(input_array)

    # torchio's image-type keyword is `type=`. The former `label=` keyword was
    # stored as plain metadata, so both images ended up typed as intensity.
    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, type=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, type=tio.LABEL),
    )
    img_dataset = tio.ImagesDataset([one_subject])
    img_sample = img_dataset[-1]
    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler)
    aggregator2 = GridAggregator(grid_sampler)

    pipeline = Pipeline({
        "p": 1,
        "ordered_ops": [
            make_masks,
            make_features,
            make_sr,
            make_seg_sr,
            make_seg_cnn,
        ],
    })

    payloads = []
    with torch.no_grad():
        for patches_batch in patch_loader:
            locations = patches_batch[LOCATION]
            # only the first location of the batch drives the crop
            loc_arr = np.array(locations[0])
            loc = (loc_arr[0], loc_arr[1], loc_arr[2])
            logger.debug(f"Location: {loc}")

            # Prepare region data (IMG (Float Volume) AND GEOMETRY (3d Point))
            cropped_vol, offset_pts = crop_vol_and_pts_centered(
                input_array,
                entity_pts,
                location=loc,
                patch_size=patch_size,
                offset=True,
                debug_verbose=True,
            )

            plt.figure(figsize=(12, 12))
            plt.imshow(cropped_vol[cropped_vol.shape[0] // 2, :], cmap="gray")
            plt.scatter(offset_pts[:, 1], offset_pts[:, 2])

            logger.debug(f"Number of offset_pts: {offset_pts.shape}")
            logger.debug(
                f"Allocating memory for no. voxels: {cropped_vol.shape[0] * cropped_vol.shape[1] * cropped_vol.shape[2]}"
            )

            # payload = Patch(
            #     {"in_array": cropped_vol},
            #     offset_pts,
            #     None,
            # )
            payload = Patch(
                {"total_mask": np.random.random((4, 4))},
                {"total_anno": np.random.random((4, 4))},
                {"points": np.random.random((4, 3))},
            )
            pipeline.init_payload(payload)
            for step in pipeline:
                logger.debug(step)

            # Aggregation (Output: large volume aggregated from many smaller volumes)
            output_tensor = (
                torch.FloatTensor(payload.annotation_layers["total_mask"])
                .unsqueeze(0)
                .unsqueeze(1)
            )
            logger.debug(f"Aggregating output tensor of shape: {output_tensor.shape}")
            aggregator1.add_batch(output_tensor, locations)

            output_tensor = (
                torch.FloatTensor(payload.annotation_layers["prediction"])
                .unsqueeze(0)
                .unsqueeze(1)
            )
            logger.debug(f"Aggregating output tensor of shape: {output_tensor.shape}")
            aggregator2.add_batch(output_tensor, locations)
            payloads.append(payload)

    output_tensor1 = aggregator1.get_output_tensor()
    logger.debug(output_tensor1.shape)
    output_tensor2 = aggregator2.get_output_tensor()
    logger.debug(output_tensor2.shape)
    return [output_tensor1, output_tensor2], payloads
def get_subjects(use_cropped_resampled_data: bool = True):
    """Build torchio subjects from the cropped image/label folders.

    Args:
        use_cropped_resampled_data: if True, read from
            ``cropped_resample_img_folder`` / ``cropped_resample_label_folder``;
            otherwise from ``cropped_img_folder`` / ``cropped_label_folder``.

    Returns:
        ``(subjects, visual_img_path_list, visual_label_path_list)`` where
        ``subjects`` pairs the sorted image and label paths position by
        position, and the two visualization lists hold the paths whose file
        names appear in a seeded random sample of 150 entries from the
        ``filename`` column of ``ADNI_MALPEM_baseline_1069.csv``.

    Raises:
        ValueError: from ``random.sample`` if the CSV lists fewer than 150
            distinct file names.
    """
    if use_cropped_resampled_data:
        # using in the cropping folder
        img_folder = cropped_resample_img_folder
        label_folder = cropped_resample_label_folder
    else:
        img_folder = cropped_img_folder
        label_folder = cropped_label_folder

    img_path_list = sorted(
        Path(f) for f in glob(f"{str(img_folder)}/**/*.nii*", recursive=True))
    label_path_list = sorted(
        Path(f)
        for f in glob(f"{str(label_folder)}/**/*.nii.gz", recursive=True))

    subjects = [
        tio.Subject(
            img=tio.Image(path=img_path, type=tio.INTENSITY),
            label=tio.Image(path=label_path, type=tio.LABEL),
            # store the dataset name to help plot the image later
            # dataset=mri.dataset
        ) for img_path, label_path in zip(img_path_list, label_path_list)
    ]

    fine_tune_set_file = Path(__file__).resolve(
    ).parent.parent.parent / "ADNI_MALPEM_baseline_1069.csv"
    file_df = pd.read_csv(fine_tune_set_file, sep=',')
    images_baseline_set = set(file_df['filename'])

    random.seed(42)
    # random.sample() rejects sets on Python 3.11+ (deprecated since 3.9), and
    # set iteration order of strings varies with hash randomization, so the
    # seeded draw was not reproducible. Sampling from a sorted sequence fixes
    # both; key=str tolerates a stray non-string (e.g. NaN) in the column.
    images_baseline_set = set(
        random.sample(sorted(images_baseline_set, key=str), 150))

    # used for visualization; image files are matched by "<name>.gz"
    visual_img_path_list = [
        img_path for img_path in img_path_list
        if img_path.name + ".gz" in images_baseline_set
    ]
    visual_label_path_list = [
        label_path for label_path in label_path_list
        if label_path.name in images_baseline_set
    ]

    print(f"{ctime()}: getting number of subjects {len(subjects)}")
    print(
        f"{ctime()}: getting number of path for visualizationg {len(visual_img_path_list)}"
    )
    return subjects, visual_img_path_list, visual_label_path_list
def generate_dataset(data_path,
                     data_root='',
                     ref_path=None,
                     nb_subjects=5,
                     resampling='mni',
                     masking_method='label'):
    """
    Generate a torchio dataset from a csv file defining paths to subjects.

    Rows with a missing ``suj`` value are dropped, then ``nb_subjects`` rows
    are drawn without replacement (``np.random.seed(0)``, which also reseeds
    NumPy's global RNG). Column index 1 of each drawn row is taken as a
    directory, prefixed with ``data_root``, and searched for the T1 image
    (``s*.nii.gz``), the mask (``niw_Mean*``) and an affine text file
    (``aff*.txt``) whose inverse is attached to the T1 image as
    ``coregistration``.

    :param data_path: path to a csv file
    :param data_root: string prefix concatenated (not path-joined) in front of
        each directory read from the csv
    :param ref_path: resampling target used when ``resampling == 'mm'``
    :param nb_subjects: number of csv rows to sample
    :param resampling: ``'mni'`` resamples to ``target='ref'`` using the
        stored coregistration, ``'mm'`` resamples to ``ref_path``; any other
        value skips resampling
    :param masking_method: forwarded to ``RescaleIntensity``
    :return: ``torchio.ImagesDataset`` over the sampled subjects with a
        ``Compose`` of the (optional) resampling and intensity rescaling
    """
    ds = pd.read_csv(data_path)
    ds = ds.dropna(subset=['suj'])
    np.random.seed(0)
    subject_idx = np.random.choice(range(len(ds)), nb_subjects, replace=False)
    directories = ds.iloc[subject_idx, 1]
    dir_list = directories.tolist()
    dir_list = map(lambda partial_dir: data_root + partial_dir, dir_list)

    subject_list = []
    for directory in dir_list:
        # NOTE(review): glob order is filesystem-dependent; [0] picks an
        # arbitrary match if several files fit, and raises IndexError if none.
        img_path = glob.glob(os.path.join(directory, 's*.nii.gz'))[0]
        mask_path = glob.glob(os.path.join(directory, 'niw_Mean*'))[0]
        coregistration_path = glob.glob(os.path.join(directory, 'aff*.txt'))[0]

        coregistration = np.loadtxt(coregistration_path, delimiter=' ')
        coregistration = np.linalg.inv(coregistration)

        subject = torchio.Subject(
            t1=torchio.Image(img_path, torchio.INTENSITY, coregistration=coregistration),
            label=torchio.Image(mask_path, torchio.LABEL),
            # NOTE(review): the 'mni' branch below resamples to target='ref',
            # but this 'ref' image is commented out, so the subject has no
            # 'ref' key — confirm how the torchio version in use resolves it.
            #ref=torchio.Image(ref_path, torchio.INTENSITY)
            # coregistration=coregistration,
        )
        print('adding img {} \n mask {}\n'.format(img_path, mask_path))
        subject_list.append(subject)

    transforms = [
        # Resample(1),
        RescaleIntensity((0, 1), (0, 99), masking_method=masking_method),
    ]

    # Resampling, when requested, is inserted first so it runs before rescaling.
    if resampling == 'mni':
        # resampling_transform = ResampleWithFoV(
        #     target=nib.load(ref_path), image_interpolation=Interpolation.BSPLINE, coregistration_key='coregistration'
        # )
        resampling_transform = Resample(
            target='ref', image_interpolation=Interpolation.BSPLINE, coregistration='coregistration')
        transforms.insert(0, resampling_transform)
    elif resampling == 'mm':
        # resampling_transform = ResampleWithFoV(target=nib.load(ref_path), image_interpolation=Interpolation.BSPLINE)
        resampling_transform = Resample(
            target=ref_path, image_interpolation=Interpolation.BSPLINE)
        transforms.insert(0, resampling_transform)

    transform = Compose(transforms)

    return torchio.ImagesDataset(subject_list, transform=transform)
# Script: per-subject tissue/structure volumes from HCP partial-volume (PVE)
# label maps, one label map file per structure.
from nibabel.viewers import OrthoSlicer3D as ov
import glob
import sys

# Single-subject path; overwritten by the loop variable below.
dr = '/network/lustre/dtlake01/opendata/data/HCP/raw_data/nii/727553/T1w/ROI_PVE_1mm/'
# NOTE(review): glob order is filesystem-dependent; wrap in sorted() if a
# stable subject order matters downstream.
dres = glob.glob('/network/lustre/dtlake01/opendata/data/HCP/raw_data/nii/*/T1w/ROI_PVE*')
df, df_seuil = pd.DataFrame(), pd.DataFrame()
for dr in dres:
    # .../<subject>/T1w/ROI_PVE_<res> -> subject id and resolution folder name
    subject = Path(dr).parent.parent.name
    resolution = Path(dr).name
    print("Suj {} {}".format(subject,resolution))
    dr += '/'
    label_list = ['GM', 'WM', 'CSF', 'L_Accu', 'L_Caud', 'L_Pall', 'L_Thal', 'L_Amyg', 'L_Hipp', 'L_Puta', 'R_Amyg', 'R_Hipp',
                  'R_Puta', 'R_Accu', 'R_Caud', 'R_Pall', 'R_Thal', 'BrStem', 'cereb_GM', 'cereb_WM', 'skull', 'skin', 'background']

    # One multi-channel LABEL image: channel ii is the map for label_list[ii].
    suj = [torchio.Subject (label=torchio.Image(type = torchio.LABEL, path=[dr + ll + '.nii.gz' for ll in label_list]))]
    PV = suj[0].label.data
    # Alternatives: torchio.SubjectsDataset(suj)[0]['label']['data'],
    # or nb.load(ff).get_fdata(), or sample0['label']['data'].

    # Zero out near-zero partial-volume values (threshold 0.001).
    tbin = PV > 0.001
    PV[~tbin] = 0

    # Voxel edge length inferred from the folder name (presumably mm).
    res = 1.4 if '14mm' in resolution else 2.8 if '28mm' in resolution else 0.7 if '07mm' in resolution else 1
    voxel_volume = res * res * res
    dd = dict(subject=subject, resolution=resolution)

    # get global volume: sum of partial-volume fractions * voxel volume,
    # divided by 1000 (presumably mm^3 -> cm^3 / mL)
    for ii, ll in enumerate(label_list):
        dd[ll + '_vol'] = torch.sum(PV[ii]).numpy() * voxel_volume / 1000

    # NOTE(review): the body of this loop continues beyond this excerpt.
    for label_index in range(0,10):
        #print('label {}'.format(label_list[label_index]) )
def test_no_type(self):
    """An Image built from a bare tensor with no ``type`` must emit a UserWarning."""
    untyped_tensor = torch.rand(1, 2, 3, 4)
    with self.assertWarns(UserWarning):
        tio.Image(tensor=untyped_tensor)
# Build the training subjects from the cropped image/label folders and wrap
# them in a DataLoader. Images and labels are paired by sorted path order.
img_path_folder = DATA_ROOT / "all_different_size_img" / "cropped" / "img"
label_path_folder = DATA_ROOT / "all_different_size_img" / "cropped" / "label"

img_path_list = sorted(
    Path(f) for f in glob(f"{str(img_path_folder)}/**/*.nii.gz", recursive=True))
label_path_list = sorted(
    Path(f) for f in glob(f"{str(label_path_folder)}/**/*.nii.gz", recursive=True))

# zip() would silently drop the surplus and mis-pair every following
# image/label, so refuse to continue on a count mismatch.
if len(img_path_list) != len(label_path_list):
    raise ValueError(
        f"Found {len(img_path_list)} images but {len(label_path_list)} labels "
        f"under {img_path_folder} / {label_path_folder}")

subjects = []
for img_path, label_path in zip(img_path_list, label_path_list):
    subject = tio.Subject(
        img=tio.Image(path=img_path, type=tio.INTENSITY),
        label=tio.Image(path=label_path, type=tio.LABEL),
    )
    subjects.append(subject)

print(f"get {len(subjects)} of subject!")

training_transform = get_train_transforms()

training_set = tio.ImagesDataset(subjects, transform=training_transform)

loader = DataLoader(
    training_set,
    batch_size=2,
    # num_workers=multiprocessing.cpu_count())
    num_workers=8)
def get_metrics_save_mask(model,
                          device,
                          loader,
                          psize,
                          channel_keys,
                          value_keys,
                          class_list,
                          loss_fn,
                          is_segmentation,
                          scaling_factor=1,
                          weights=None,
                          save_mask=False,
                          outputDir=None,
                          with_roi=False):
    '''
    Evaluate ``model`` patch-wise on every subject of ``loader``.

    For each subject a torchio ``GridSampler`` with patch size ``psize`` is
    run over the channel images. Segmentation predictions are stitched with
    a ``GridAggregator``; regression/classification predictions are averaged
    over all patches.

    Args:
        model: network; put in eval mode here.
        device: device the input patches are moved to.
        loader: iterable of collated subject dicts; ``len(loader.dataset)`` is
            used to average the metrics.
        psize: grid patch size.
        channel_keys: keys of the input image channels in each subject.
        value_keys: keys of the regression/classification targets.
        class_list: segmentation classes (passed to ``one_hot``, ``loss_fn``,
            ``MCD`` and ``reverse_one_hot``).
        loss_fn: segmentation loss, called as
            ``loss_fn(pred, mask, len(class_list), weights)``.
        is_segmentation: selects segmentation vs. regression handling.
        scaling_factor: predictions are divided by this before being written.
        weights: forwarded to ``loss_fn``.
        save_mask: write predicted masks / predictions to ``outputDir``.
        outputDir: output directory; created if given. Required when
            ``save_mask`` is True.
        with_roi: unused.

    Returns:
        ``(avg_dice, avg_loss)``; ``avg_dice`` is 1 for non-segmentation runs.

    Raises:
        ValueError: if ``save_mask`` is True but ``outputDir`` is None.
    '''
    if save_mask and outputDir is None:
        raise ValueError("outputDir must be provided when save_mask is True")
    if outputDir is not None:
        Path(outputDir).mkdir(parents=True, exist_ok=True)
    outputToWrite = 'SubjectID,PredictedValue\n'
    model.eval()
    with torch.no_grad():
        total_loss = total_dice = 0
        for subject in loader:
            # constructing a new dict because torchio.GridSampler requires
            # torchio.Subject, which requires torchio.Image to be present in
            # the initial dict, which the loader does not provide
            subject_dict = {}
            if 'label' in subject and subject['label'] != ['NA']:
                subject_dict['label'] = torchio.Image(
                    subject['label']['path'], type=torchio.LABEL)

            for key in value_keys:  # for regression/classification
                subject_dict['value_' + key] = subject[key]

            for key in channel_keys:
                subject_dict[key] = torchio.Image(subject[key]['path'],
                                                  type=torchio.INTENSITY)

            grid_sampler = torchio.inference.GridSampler(
                torchio.Subject(subject_dict), psize)
            patch_loader = torch.utils.data.DataLoader(grid_sampler,
                                                       batch_size=1)
            aggregator = torchio.inference.GridAggregator(grid_sampler)

            pred_output = 0  # running sum of regression outputs
            num_patches = 0
            for patches_batch in patch_loader:
                image = torch.cat(
                    [patches_batch[key][torchio.DATA] for key in channel_keys],
                    dim=1)
                if len(value_keys) > 0:
                    valuesToPredict = torch.cat(
                        [patches_batch['value_' + key] for key in value_keys],
                        dim=0)
                locations = patches_batch[torchio.LOCATION]
                image = image.float().to(device)
                ## special case for 2D
                if image.shape[-1] == 1:
                    model_2d = True
                    image = torch.squeeze(image, -1)
                    locations = torch.squeeze(locations, -1)
                else:
                    model_2d = False

                if is_segmentation:
                    # get the predicted mask and stitch it into the volume
                    pred_mask = model(image)
                    if model_2d:
                        pred_mask = pred_mask.unsqueeze(-1)
                    aggregator.add_batch(pred_mask, locations)
                else:
                    # accumulate, averaged over all patches below
                    pred_output += model(image)
                    num_patches += 1

            if is_segmentation:
                pred_mask = aggregator.get_output_tensor()
                pred_mask = pred_mask.cpu()  # the validation is done on CPU
                # increasing the number of dimension of the mask
                pred_mask = pred_mask.unsqueeze(0)
            else:
                # Divide by the number of patches; len(locations) is only the
                # size of the last batch (always 1 here), not the patch count.
                pred_output = pred_output / num_patches
                pred_output = pred_output.cpu()
                # loss = loss_fn(pred_output.double(), valuesToPredict.double(), len(class_list), weights).cpu().data.item() # this would need to be customized for regression/classification
                loss = torch.nn.MSELoss()(
                    pred_output.double(), valuesToPredict.double()).cpu().data.item(
                    )  # this needs to be revisited for multi-class output
                total_loss += loss

            if is_segmentation:
                # No ground truth: skip loss/dice rather than indexing the
                # missing 'label' entry.
                if 'label' not in subject_dict:
                    print(
                        "Ground Truth Mask not found. Generating the Segmentation based one the METADATA of one of the modalities, The Segmentation will be named accordingly"
                    )
                else:
                    mask = subject_dict['label'][
                        torchio.DATA]  # get the label image
                    if mask.dim() == 4:
                        # increasing the number of dimension of the mask
                        mask = mask.unsqueeze(0)
                    mask = one_hot(mask, class_list)
                    loss = loss_fn(pred_mask.double(), mask.double(),
                                   len(class_list), weights).cpu().data.item()
                    total_loss += loss
                    # Computing the dice score
                    curr_dice = MCD(pred_mask.double(), mask.double(),
                                    len(class_list)).cpu().data.item()
                    total_dice += curr_dice

            if save_mask:
                patient_name = subject['subject_id'][0]

                if is_segmentation:
                    path_to_metadata = subject['path_to_metadata'][0]
                    inputImage = sitk.ReadImage(path_to_metadata)
                    # os.path.splitext would reduce ".nii.gz" to ".gz"
                    if path_to_metadata.endswith('.nii.gz'):
                        ext = '.nii.gz'
                    else:
                        _, ext = os.path.splitext(path_to_metadata)
                    pred_mask = pred_mask.numpy()
                    pred_mask = reverse_one_hot(pred_mask[0], class_list)
                    if not model_2d:
                        result_image = sitk.GetImageFromArray(
                            np.swapaxes(pred_mask, 0, 2))
                    else:
                        # NOTE(review): this is still a numpy array, so the
                        # CopyInformation call below cannot succeed for 2D.
                        result_image = pred_mask
                    result_image.CopyInformation(inputImage)
                    # if parameters['resize'] is not None:
                    #     originalSize = inputImage.GetSize()
                    #     result_image = resize_image(resize_image, originalSize, sitk.sitkNearestNeighbor) # change this for resample
                    sitk.WriteImage(
                        result_image,
                        os.path.join(outputDir, patient_name + '_seg' + ext))
                elif len(value_keys) > 0:
                    outputToWrite += patient_name + ',' + str(
                        pred_output / scaling_factor) + '\n'

    if len(value_keys) > 0 and outputDir is not None:
        with open(os.path.join(outputDir, "output_predictions.csv"),
                  'w') as file:
            file.write(outputToWrite)

    # calculate average loss and dice
    avg_loss = total_loss / len(loader.dataset)
    if is_segmentation:
        avg_dice = total_dice / len(loader.dataset)
    else:
        avg_dice = 1  # we don't care about this for regression/classification
    return avg_dice, avg_loss
# Select the evaluation metric from EVAL_METRIC; unknown values abort the run.
if EVAL_METRIC == "MeanIoU":
    print("Using MeanIoU")
    eval_criterion = MeanIoU()
elif EVAL_METRIC == "GenericAveragePrecision":
    print("Using GenericAveragePrecision")
    eval_criterion = GenericAveragePrecision()
else:
    print("No evaluation metric specified, exiting")
    sys.exit(1)

# Create model and optimizer
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not already
# been initialised in this process — confirm nothing touched the GPU earlier.
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_DEVICE
unet = create_unet_on_device(DEVICE_NUM, MODEL_DICT)
optimizer = torch.optim.AdamW(unet.parameters(), lr=STARTING_LR)

# NOTE(review): torchio sets the image type via ``type=``; ``label=`` here is
# stored as an extra metadata key, so both images are left untyped (torchio
# warns and treats them as intensity images). Likely intended: type=...
train_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(train_data), label=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(train_seg), label=torchio.LABEL),
)
valid_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(valid_data), label=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(valid_seg), label=torchio.LABEL),
)

# Define the transforms for the set of training patches
# (this excerpt ends inside the OneOf mapping).
training_transform = Compose([
    RandomNoise(p=0.2),
    RandomFlip(axes=(0, 1, 2)),
    RandomBlur(p=0.2),
    OneOf({
def _write_sitk_image(image_to_write, image_file, image_for_info_copy,
                      spacing):
    """Copy origin/direction from a reference image, set spacing, and write unless the file exists."""
    image_to_write.SetOrigin(image_for_info_copy.GetOrigin())
    image_to_write.SetDirection(image_for_info_copy.GetDirection())
    image_to_write.SetSpacing(spacing)
    if not os.path.isfile(image_file):
        try:
            sitk.WriteImage(image_to_write, image_file)
        except IOError as err:
            raise IOError(
                "Could not write image file: {}. Make sure that the file is not open and try again."
                .format(image_file)) from err


def preprocess_and_save(data_csv,
                        config_file,
                        output_dir,
                        label_pad_mode="constant",
                        applyaugs=False):
    """
    This function performs preprocessing based on parameters provided and saves the output.

    Each subject is loaded through ``ImagesFromDataFrame``. If a label-based
    patch sampler is configured, the images are padded symmetrically, the
    label with ``label_pad_mode``, and a single label-sampled patch replaces
    the volumes. Every channel (and the label) is written under
    ``output_dir/<subject_id>/`` with the origin/direction of the first
    channel, and a ``data_processed.csv`` listing the new files is written
    to ``output_dir``.

    Args:
        data_csv (str): The CSV file of the training data.
        config_file (str): The YAML file of the training configuration.
        output_dir (str): The output directory.
        label_pad_mode (str): The padding strategy for the label. Defaults to "constant".
        applyaugs (bool): If data augmentation is to be applied before saving the image. Defaults to False.

    Raises:
        ValueError: if ``output_dir`` holds a ``parameters.pkl`` from a
            previous run whose parameters differ from ``config_file``.
        IOError: if an output image cannot be written.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # read the csv
    # don't care if the dataframe gets shuffled or not
    dataframe, headers = parseTrainingCSV(data_csv, train=False)
    parameters = parseConfig(config_file)

    # save the parameters so that the same compute doesn't happen once again
    parameter_file = os.path.join(output_dir, "parameters.pkl")
    if os.path.exists(parameter_file):
        # NOTE: unpickling can execute arbitrary code; output_dir must be trusted.
        with open(parameter_file, "rb") as handle:
            parameters_prev = pickle.load(handle)
        if parameters != parameters_prev:
            raise ValueError(
                "The parameters are not the same as the ones stored in the previous run, please re-check."
            )
    else:
        with open(parameter_file, "wb") as handle:
            pickle.dump(parameters, handle, protocol=pickle.HIGHEST_PROTOCOL)

    parameters = populate_header_in_parameters(parameters, headers)

    data_for_processing = ImagesFromDataFrame(dataframe,
                                              parameters,
                                              train=applyaugs,
                                              apply_zero_crop=True,
                                              loader_type="full")

    dataloader_for_processing = DataLoader(
        data_for_processing,
        batch_size=1,
        pin_memory=False,
    )

    # initialize a new dict for the preprocessed data
    base_df = get_dataframe(data_csv)
    # ensure csv only contains lower case columns
    base_df.columns = base_df.columns.str.lower()
    # only store the column names
    output_columns_to_write = {column: [] for column in base_df.columns}

    # keep a record of the keys which contains only images
    keys_with_images = [
        str(x) for x in parameters["headers"]["channelHeaders"]
    ]

    use_label_sampler = (parameters["patch_sampler"] == "label") or isinstance(
        parameters["patch_sampler"], dict)
    has_label = parameters["headers"]["labelHeader"] is not None

    # give warning if label sampler is present but number of patches to extract is > 1
    if use_label_sampler and parameters["q_samples_per_volume"] > 1:
        print(
            "[WARNING] Label sampling has been enabled but q_samples_per_volume > 1; this has been known to cause issues, so q_samples_per_volume will be hard-coded to 1 during preprocessing. Please contact GaNDLF developers for more information",
            file=sys.stderr,
            flush=True,
        )

    for subject in tqdm(dataloader_for_processing, desc="Looping over data"):
        subject_id = subject["subject_id"][0]
        # initialize the current_output_dir
        current_output_dir = os.path.join(output_dir, str(subject_id))
        Path(current_output_dir).mkdir(parents=True, exist_ok=True)

        output_columns_to_write["subjectid"].append(subject_id)

        # start constructing the torchio.Subject object
        subject_process = {}
        for channel in parameters["headers"]["channelHeaders"]:
            # the "squeeze" is needed because the dataloader automatically
            # constructs 5D tensor considering the batch_size as first
            # dimension, but the constructor needs 4D tensor.
            subject_process[str(channel)] = torchio.Image(
                tensor=subject[str(channel)]["data"].squeeze(0),
                type=torchio.INTENSITY,
                path=subject[str(channel)]["path"],
            )
        if has_label:
            subject_process["label"] = torchio.Image(
                tensor=subject["label"]["data"].squeeze(0),
                type=torchio.LABEL,
                path=subject["label"]["path"],
            )
        subject_dict_to_write = torchio.Subject(subject_process)

        # apply a different padding mode to image and label (so that label information is not duplicated)
        if use_label_sampler:
            # get the padding size from the patch_size
            psize_pad = list(
                np.asarray(np.ceil(np.divide(parameters["patch_size"], 2)),
                           dtype=int))
            # initialize the padder for images
            padder = torchio.transforms.Pad(psize_pad,
                                            padding_mode="symmetric",
                                            include=keys_with_images)
            subject_dict_to_write = padder(subject_dict_to_write)

            if has_label:
                # initialize the padder for label
                padder_label = torchio.transforms.Pad(
                    psize_pad, padding_mode=label_pad_mode, include="label")
                subject_dict_to_write = padder_label(subject_dict_to_write)

                # exactly one patch, matching the warning above
                sampler = torchio.data.LabelSampler(parameters["patch_size"])
                generator = sampler(subject_dict_to_write, num_patches=1)
                for patch in generator:
                    for channel in parameters["headers"]["channelHeaders"]:
                        subject_dict_to_write[str(channel)] = patch[str(
                            channel)]
                    subject_dict_to_write["label"] = patch["label"]

        # write new images
        common_ext = get_filename_extension_sanitized(subject["1"]["path"][0])
        # in cases where the original image has a file format that does not support
        # RGB floats, use the "vtk" format
        if common_ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"):
            common_ext = ".vtk"

        # geometry reference: the first channel on disk, or its in-memory copy
        if subject["1"]["path"][0] != "":
            image_for_info_copy = sitk.ReadImage(subject["1"]["path"][0])
        else:
            image_for_info_copy = subject_dict_to_write["1"].as_sitk()
        correct_spacing_for_info_copy = subject["spacing"][0].tolist()

        for channel in parameters["headers"]["channelHeaders"]:
            image_file = Path(
                os.path.join(
                    current_output_dir,
                    subject_id + "_" + str(channel) + common_ext,
                )).as_posix()
            output_columns_to_write["channel_" + str(channel - 1)].append(
                image_file)
            _write_sitk_image(subject_dict_to_write[str(channel)].as_sitk(),
                              image_file, image_for_info_copy,
                              correct_spacing_for_info_copy)

        # now try to write the label
        if "label" in subject_dict_to_write:
            image_file = Path(
                os.path.join(current_output_dir,
                             subject_id + "_label" + common_ext)).as_posix()
            output_columns_to_write["label"].append(image_file)
            _write_sitk_image(subject_dict_to_write["label"].as_sitk(),
                              image_file, image_for_info_copy,
                              correct_spacing_for_info_copy)

        # ensure prediction headers are getting saved, as well
        prediction_headers = parameters["headers"]["predictionHeaders"]
        if len(prediction_headers) > 1:
            for key in prediction_headers:
                output_columns_to_write["valuetopredict_" + str(key)].append(
                    str(subject["value_" + str(key)].numpy()[0]))
        elif len(prediction_headers) == 1:
            output_columns_to_write["valuetopredict"].append(
                str(subject["value_0"].numpy()[0]))

    path_for_csv = Path(os.path.join(output_dir,
                                     "data_processed.csv")).as_posix()
    print("Writing final csv for subsequent training: ", path_for_csv)
    pd.DataFrame.from_dict(data=output_columns_to_write).to_csv(path_for_csv,
                                                                header=True,
                                                                index=False)
def patch_sampler(img_filenames,
                  labelmap_filenames,
                  patch_size,
                  sampler_type,
                  out_dir,
                  max_patches=None,
                  voxel_spacing=(),
                  patch_overlap=(0, 0, 0),
                  min_labeled_voxels=1.0,
                  label_prob=0.8,
                  save_patches=False,
                  batch_size=None,
                  prepare_batches=False,
                  inference=False):
    """Extract image/labelmap patches from a list of volumes.

    Each image is rescaled to [0, 1] (``tio.RescaleIntensity``), optionally
    resampled to ``voxel_spacing``, and sampled either on a regular grid
    (``sampler_type == 'grid'``) or with a ``tio.data.LabelSampler`` (any
    other value).

    Parameters
    ----------
    img_filenames : list of strings
        Paths to images to extract patches from.
    labelmap_filenames : list of strings
        Paths to the matching labelmaps, same order as ``img_filenames``.
        NOTE(review): patches are only sampled when this is non-empty.
    patch_size : tuple of ints (patch_x, patch_y, patch_z)
        The dimensions of one patch. ``patch_x == 1`` yields 2D output arrays.
    sampler_type : str
        ``'grid'`` for a ``GridSampler``; anything else uses a ``LabelSampler``.
    out_dir : str
        Forwarded to ``save`` when ``save_patches`` is True.
    max_patches : int or None
        Total patch budget, split evenly (floored) across the images.
    voxel_spacing : tuple
        If non-empty, ``util.update_affine`` is called and volumes are
        resampled to this spacing.
    patch_overlap : tuple of ints
        Overlap between grid patches.
    min_labeled_voxels : float between 0 and 1
        Grid sampler only: minimum fraction of non-zero labelmap voxels for a
        patch to be kept. A patch whose center voxel is labeled is kept too.
    label_prob : float
        Label sampler only: probability of centering on label 1 (label 0 gets
        ``1 - label_prob``).
    save_patches : bool
        Flush patches to disk via ``save`` and return ids instead of arrays.
    batch_size : int or None
        Patches per saved chunk when ``prepare_batches`` is True.
    prepare_batches : bool
        Use ``batch_size`` instead of 1 as the save chunk size.
    inference : bool
        Forwarded to ``save``.

    Returns
    -------
    img_ids, label_ids
        When ``save_patches`` is True, as returned by ``save``.
    img_patches, label_patches : array
        Otherwise, shape ``(n_patches, patch_x, patch_y, patch_z, 1)``, or
        ``(n_patches, patch_y, patch_z, 1)`` when ``patch_x == 1``.
    """
    if max_patches is not None:
        max_patches = int(max_patches / len(img_filenames))

    img_patches = []
    label_patches = []
    patch_counter = 0
    save_counter = 0
    img_ids = []
    label_ids = []
    save_size = 1
    if prepare_batches:
        save_size = batch_size

    print(f'\nExtracting patches from: {img_filenames}\n')

    for i in tqdm(range(len(img_filenames)), leave=False):
        if voxel_spacing:
            util.update_affine(img_filenames[i], labelmap_filenames[i])
        if labelmap_filenames:
            subject = tio.Subject(img=tio.Image(img_filenames[i],
                                                type=tio.INTENSITY),
                                  labelmap=tio.LabelMap(labelmap_filenames[i]))

        # Apply transformations
        #transform = tio.ZNormalization()
        #transformed = transform(subject)
        transform = tio.RescaleIntensity((0, 1))
        transformed = transform(subject)
        if voxel_spacing:
            transform = tio.Resample(voxel_spacing)
            transformed = transform(transformed)

        num_img_patches = 0
        if sampler_type == 'grid':
            sampler = tio.data.GridSampler(transformed, patch_size,
                                           patch_overlap)
            min_count = (patch_size[0] * patch_size[1] * patch_size[2] *
                         min_labeled_voxels)
            for patch in sampler:
                img_patch = np.array(patch.img.data)
                label_patch = np.array(patch.labelmap.data)
                labeled_voxels = torch.count_nonzero(
                    patch.labelmap.data) >= min_count
                center = label_patch[0, patch_size[0] // 2,
                                     patch_size[1] // 2,
                                     patch_size[2] // 2] != 0
                if labeled_voxels or center:
                    img_patches.append(img_patch)
                    label_patches.append(label_patch)
                    patch_counter += 1
                    num_img_patches += 1
                    if save_patches:
                        img_patches, label_patches, img_ids, label_ids, save_counter, patch_counter = save(
                            img_patches, label_patches, img_ids, label_ids,
                            save_counter, patch_counter, save_size,
                            patch_size, inference, out_dir)
                    # Stop once this image has contributed max_patches
                    # patches (">" allowed one extra and disagreed with the
                    # label-sampler branch, which yields exactly max_patches).
                    if max_patches is not None:
                        if num_img_patches >= max_patches:
                            break
        else:
            # Define sampler
            one_label = 1.0 - label_prob
            label_probabilities = {0: one_label, 1: label_prob}
            sampler = tio.data.LabelSampler(
                patch_size, label_probabilities=label_probabilities)
            if max_patches is None:
                generator = sampler(transformed)
            else:
                generator = sampler(transformed, max_patches)
            for patch in generator:
                img_patches.append(np.array(patch.img.data))
                label_patches.append(np.array(patch.labelmap.data))
                patch_counter += 1
                if save_patches:
                    img_patches, label_patches, img_ids, label_ids, save_counter, patch_counter = save(
                        img_patches, label_patches, img_ids, label_ids,
                        save_counter, patch_counter, save_size, patch_size,
                        inference, out_dir)

    print('Finished extracting patches.')

    if save_patches:
        return img_ids, label_ids

    if patch_size[0] == 1:
        spatial = (patch_size[1], patch_size[2])
    else:
        spatial = (patch_size[0], patch_size[1], patch_size[2])
    # Images and labels share the same spatial layout; the label reshape used
    # to drop patch_size[0] in the 3D case.
    return (np.array(img_patches).reshape(len(img_patches), *spatial, 1),
            np.array(label_patches).reshape(len(label_patches), *spatial, 1))
batch_size = 2  # Set to 2 for 32Gb Card
print(f"Patch size is {PATCH_SIZE}")
print(f"Free GPU memory is {free_gpu_mem:0.2f} GB. Batch size will be "
      f"{batch_size}.")

# Load model
print(f"Loading model from {MODEL_FILE}")
model_dict = torch.load(MODEL_FILE, map_location='cpu')
unet = create_unet_on_device(DEVICE_NUM, model_dict['model_struc_dict'])
unet.load_state_dict(model_dict['model_state_dict'])
# Bind in both cases: a one-sided `if ...: multilabel = True` leaves the name
# undefined for single-output models and predict_volume() below would raise
# NameError.
multilabel = model_dict['model_struc_dict']['out_channels'] > 1

# Load the data and create a sampler
print(f"Loading data from {DATA_FILE}")
data_tens = tensor_from_hdf5(DATA_FILE, HDF5_PATH)
# ``type=`` sets the torchio image type (``label=`` would only be metadata).
data_subject = torchio.Subject(
    data=torchio.Image(tensor=data_tens, type=torchio.INTENSITY))
print(f"Setting up grid sampler with overlap {PATCH_OVERLAP} and padding "
      f"mode: {PADDING_MODE}")
grid_sampler = GridSampler(data_subject,
                           PATCH_SIZE,
                           PATCH_OVERLAP,
                           padding_mode=PADDING_MODE)

pred_vol = predict_volume(unet, grid_sampler, batch_size, DATA_OUT_FN,
                          multilabel)
fig_out_dir = DATA_OUT_DIR / f'{date.today()}_3d_prediction_figs'
print(f"Creating directory for figures: {fig_out_dir}")
os.makedirs(fig_out_dir, exist_ok=True)
plot_predict_figure(pred_vol, data_tens, fig_out_dir)
def predict_agg_3d(
    input_array,
    model3d,
    patch_size=(128, 224, 224),
    patch_overlap=(12, 12, 12),
    nb=True,
    device=0,
    debug_verbose=False,
    fpn=False,
    overlap_mode="crop",
):
    """Run ``model3d`` over grid patches of a volume and aggregate channel 0.

    Args:
        input_array: volume; a leading channel dimension is added with
            ``unsqueeze(0)`` before it is wrapped in a torchio image.
        model3d: network applied to each batch of patches.
        patch_size: grid patch size.
        patch_overlap: overlap between neighbouring patches.
        nb: use the notebook progress bar (``tqdm.notebook``).
        device: device the patches are moved to before inference.
        debug_verbose: print input/output shapes for every batch.
        fpn: if True, the model returns a sequence and element 0 is used.
        overlap_mode: forwarded to ``GridAggregator``.

    Returns:
        The ``GridAggregator`` holding the first output channel of every
        patch; call ``get_output_tensor()`` on it for the full volume.
    """
    import torchio as tio
    from torchio import LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    print(input_array.shape)
    img_tens = torch.FloatTensor(input_array[:]).unsqueeze(0)
    print(f"Predict and aggregate on volume of {img_tens.shape}")

    # ``type=`` sets the torchio image type; ``label=`` would only be stored
    # as metadata and leave both images untyped.
    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, type=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, type=tio.LABEL),
    )

    img_dataset = tio.SubjectsDataset(
        [
            one_subject,
        ]
    )
    img_sample = img_dataset[-1]

    batch_size = 1

    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler, overlap_mode=overlap_mode)

    if nb:
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm

    with torch.no_grad():
        for patches_batch in tqdm(patch_loader):
            locations = patches_batch[LOCATION]
            inputs_t = patches_batch["img"]["data"].to(device)

            if fpn:
                outputs = model3d(inputs_t)[0]
            else:
                outputs = model3d(inputs_t)

            if debug_verbose:
                print(f"inputs_t: {inputs_t.shape}")
                print(f"outputs: {outputs.shape}")

            # keep only the first output channel
            output = outputs[:, 0:1, :]
            # output = torch.sigmoid(output)

            aggregator1.add_batch(output, locations)

    return aggregator1
def main(
        input_path,
        parcellation_path,
        output_image_path,
        output_label_path,
        min_volume,
        max_volume,
        volumes_path,
        ):
    """Console script for resector.

    Caches per-hemisphere gray-matter and resectable masks plus a noise image
    next to ``input_path`` (each file is only generated if missing), applies
    ``ToCanonical`` followed by ``resector.RandomResection``, and saves the
    resected image and label.

    Resection volumes are drawn from the ``Volume`` column of the CSV at
    ``volumes_path`` when it is given, otherwise from
    ``(min_volume, max_volume)``.

    Returns:
        0 on success.
    """
    import torchio
    import resector

    def write_if_missing(target, build_mask):
        # Generate and write the mask only when no cached file exists.
        if not target.is_file():
            resector.io.write(build_mask(), target)
        return target

    hemispheres = 'left', 'right'
    input_path = Path(input_path)
    out_dir = input_path.parent
    stem = input_path.name.split('.nii')[0]  # assume it's a .nii file

    gm_paths = []
    resectable_paths = []
    for side in hemispheres:
        gm_paths.append(write_if_missing(
            out_dir / f'{stem}_gray_matter_{side}_seg.nii.gz',
            lambda: resector.parcellation.get_gray_matter_mask(
                parcellation_path, side),
        ))
        resectable_paths.append(write_if_missing(
            out_dir / f'{stem}_resectable_{side}_seg.nii.gz',
            lambda: resector.parcellation.get_resectable_hemisphere_mask(
                parcellation_path,
                side,
            ),
        ))

    noise_path = out_dir / f'{stem}_noise.nii.gz'
    if not noise_path.is_file():
        resector.parcellation.make_noise_image(
            input_path,
            parcellation_path,
            noise_path,
        )

    if volumes_path is None:
        resection_kwargs = dict(volumes_range=(min_volume, max_volume))
    else:
        import pandas as pd
        volumes_df = pd.read_csv(volumes_path)
        resection_kwargs = dict(volumes=volumes_df.Volume.values)

    transform = torchio.Compose((
        torchio.ToCanonical(),
        resector.RandomResection(**resection_kwargs),
    ))

    left_resectable, right_resectable = resectable_paths
    left_gm, right_gm = gm_paths
    subject = torchio.Subject(
        image=torchio.Image(input_path, torchio.INTENSITY),
        resection_resectable_left=torchio.Image(left_resectable, torchio.LABEL),
        resection_resectable_right=torchio.Image(right_resectable, torchio.LABEL),
        resection_gray_matter_left=torchio.Image(left_gm, torchio.LABEL),
        resection_gray_matter_right=torchio.Image(right_gm, torchio.LABEL),
        resection_noise=torchio.Image(noise_path, None),
    )

    dataset = torchio.ImagesDataset([subject], transform=transform)
    resected_sample = dataset[0]
    dataset.save_sample(
        resected_sample,
        dict(image=output_image_path, label=output_label_path),
    )
    return 0