def test_threshold(nii_seg):
    """Check `threshold_predictions` on both a raw numpy array and a nibabel image.

    The fixture voxels at [4:7, 8, 4] hold the soft values
    [0.33333333, 0.66666667, 1.]; with the default threshold they are expected
    to become [0, 1, 1], whatever the container type of the input.
    """
    expected_voxels = np.array([0, 1, 1])

    # Array input: work on a copy of the fixture data (presumably because
    # threshold_predictions may modify its input in place).
    raw_voxels = np.copy(np.asanyarray(nii_seg.dataobj))
    binarized_arr = imed_postpro.threshold_predictions(raw_voxels)
    assert isinstance(binarized_arr, np.ndarray)
    assert np.array_equal(binarized_arr[4:7, 8, 4], expected_voxels)

    # Nibabel input: the output must stay a Nifti1Image.
    binarized_nii = imed_postpro.threshold_predictions(nii_seg)
    assert isinstance(binarized_nii, nib.nifti1.Nifti1Image)
    assert np.array_equal(binarized_nii.get_fdata()[4:7, 8, 4], expected_voxels)
def combine_predictions(fname_lst, fname_hard, fname_prob, thr=0.5):
    """Combine predictions from Monte Carlo simulations.

    Combine predictions from Monte Carlo simulations and save the resulting as:
        (1) `fname_prob`, a soft segmentation obtained by averaging the Monte Carlo samples.
        (2) `fname_hard`, a hard segmentation obtained by thresholding the soft one with `thr`.

    Both outputs reuse the affine of the first sample of `fname_lst`.

    Args:
        fname_lst (list of str): List of the Monte Carlo samples.
        fname_hard (str): Filename for the output hard segmentation.
        fname_prob (str): Filename for the output soft segmentation.
        thr (float): Between 0 and 1. Used to threshold the soft segmentation and generate the hard segmentation.

    Raises:
        ValueError: If `fname_lst` is empty.
    """
    if len(fname_lst) == 0:
        raise ValueError("combine_predictions: fname_lst is empty, at least one Monte Carlo sample is required.")

    # collect all MC simulations
    mc_data = np.array([nib.load(fname).get_fdata() for fname in fname_lst])
    affine = nib.load(fname_lst[0]).affine

    # average over all the MC simulations
    data_prob = np.mean(mc_data, axis=0)

    # save prob segmentation. This is done BEFORE thresholding on purpose: threshold_predictions
    # may modify its input in place (NOTE(review): assumption, its implementation is not visible here).
    nib_prob = nib.Nifti1Image(data_prob, affine)
    nib.save(nib_prob, fname_prob)

    # binarize the soft segmentation with `thr` (a threshold, not an argmax)
    data_hard = imed_postpro.threshold_predictions(data_prob, thr=thr).astype(np.uint8)

    # save hard segmentation
    nib_hard = nib.Nifti1Image(data_hard, affine)
    nib.save(nib_hard, fname_hard)
def save_color_labels(gt_data, binarize, gt_filename, output_filename, slice_axis):
    """Saves labels encoded in RGB in specified output file.

    Each class receives a reproducible pseudo-random RGB color. The colors of all classes are
    weighted by the class values and summed, clipped to [0, 255], and cast to unsigned bytes.

    Args:
        gt_data (ndarray): Input image with dimensions (Number of classes, height, width, depth).
        binarize (bool): If True binarizes gt_data to 0 and 1 values, else soft values are kept.
        gt_filename (str): GT path and filename.
        output_filename (str): Name of the output file where the colored labels are saved.
        slice_axis (int): Indicates the axis used to extract slices: "axial": 2, "sagittal": 0, "coronal": 1.

    Returns:
        ndarray: RGB labels, a structured array of shape (height, width, depth) with uint8 fields 'R', 'G', 'B'.
    """
    n_class, h, w, d = gt_data.shape

    if binarize:
        gt_data = imed_postpro.threshold_predictions(gt_data)

    # Keep always the same color labels. A dedicated RandomState seeded with 6 draws exactly the
    # colors that the former global ``np.random.seed(6)`` produced, without resetting numpy's global
    # RNG (which silently changed every later random draw of the calling program).
    rng = np.random.RandomState(6)

    # Generate color labels
    multi_labeled_pred = np.zeros((h, w, d, 3))
    for label in range(n_class):
        color = rng.randint(0, 256, size=3)
        # (h, w, d, 1) * (3,) -> (h, w, d, 3): the same per-channel products as r/g/b * gt_data[label]
        multi_labeled_pred += np.asarray(gt_data[label, ])[..., np.newaxis] * color

    # Overlapping classes (or soft values) can sum above 255; clip so the uint8 cast cannot wrap around.
    rgb_dtype = np.dtype([('R', 'u1'), ('G', 'u1'), ('B', 'u1')])
    multi_labeled_pred = np.clip(multi_labeled_pred, 0, 255).astype('u1').view(
        dtype=rgb_dtype).reshape((h, w, d))

    imed_inference.pred_to_nib([multi_labeled_pred], [], gt_filename,
                               output_filename, slice_axis=slice_axis, kernel_dim='3d', bin_thr=-1,
                               discard_noise=False)

    return multi_labeled_pred
def __getitem__(self, index):
    """Return the specific processed data corresponding to index (input, ground truth, roi and metadata).

    The stored pair is either deep-copied from memory or unpickled from a disk-cache file. When several
    raters are available, one is picked at random. The transforms are then run on the ROI first (full-slice
    mode only), then on the input and then on the ground truth.

    Args:
        index (int): Slice index.

    Returns:
        dict: Keys 'input', 'gt', 'roi', 'input_metadata', 'gt_metadata' and 'roi_metadata'. In 2D-patch
        mode 'roi' and 'roi_metadata' are None. The dict may be altered by `dropout_input` when
        `self.is_input_dropout` is set.
    """

    # copy.deepcopy is used to have different coordinates for reconstruction for a given handler with patch,
    # to allow a different rater at each iteration of training, and to clean transforms params from previous
    # transforms i.e. remove params from previous iterations so that the coming transforms are different
    if self.is_2d_patch:
        # In patch mode, self.indexes holds patch coordinates pointing to a handler
        coord = self.indexes[index]
        if self.disk_cache:
            # With disk cache, handlers are objects exposing `.open()` (presumably pathlib.Path) to a pickle
            # written by this dataset. NOTE(review): pickle.load must only ever be used on trusted cache files.
            with self.handlers[coord['handler_index']].open(mode="rb") as f:
                seg_pair_slice, roi_pair_slice = pickle.load(f)
        else:
            seg_pair_slice, roi_pair_slice = copy.deepcopy(self.handlers[coord['handler_index']])
    else:
        if self.disk_cache:
            with self.indexes[index].open(mode="rb") as f:
                seg_pair_slice, roi_pair_slice = pickle.load(f)
        else:
            seg_pair_slice, roi_pair_slice = copy.deepcopy(self.indexes[index])

    # In case multiple raters
    if seg_pair_slice['gt'] and isinstance(seg_pair_slice['gt'][0], list):
        # Randomly pick a rater
        idx_rater = random.randint(0, len(seg_pair_slice['gt'][0]) - 1)
        # Use it as ground truth for this iteration
        # Note: in case of multi-class: the same rater is used across classes
        for idx_class in range(len(seg_pair_slice['gt'])):
            seg_pair_slice['gt'][idx_class] = seg_pair_slice['gt'][idx_class][idx_rater]
            seg_pair_slice['gt_metadata'][idx_class] = seg_pair_slice['gt_metadata'][idx_class][idx_rater]

    # Missing metadata are replaced by empty lists so the loops / updates below still work
    metadata_input = seg_pair_slice['input_metadata'] if seg_pair_slice['input_metadata'] is not None else []
    metadata_roi = roi_pair_slice['gt_metadata'] if roi_pair_slice['gt_metadata'] is not None else []
    metadata_gt = seg_pair_slice['gt_metadata'] if seg_pair_slice['gt_metadata'] is not None else []

    if self.is_2d_patch:
        # No ROI handling in patch mode; `coord` already comes from self.indexes
        stack_roi, metadata_roi = None, None
    else:
        # Set coordinates to the slices full size
        coord = {}
        coord['x_min'], coord['x_max'] = 0, seg_pair_slice["input"][0].shape[0]
        coord['y_min'], coord['y_max'] = 0, seg_pair_slice["input"][0].shape[1]

        # Run transforms on ROI
        # ROI goes first because params of ROICrop are needed for the followings
        stack_roi, metadata_roi = self.transform(sample=roi_pair_slice["gt"],
                                                 metadata=metadata_roi,
                                                 data_type="roi")
        # Update metadata_input with metadata_roi
        metadata_input = imed_loader_utils.update_metadata(metadata_roi, metadata_input)

    # Add coordinates of slices or patches to input metadata
    for metadata in metadata_input:
        metadata['coord'] = [
            coord["x_min"], coord["x_max"],
            coord["y_min"], coord["y_max"]
        ]

    # Extract image and gt slices or patches from coordinates
    stack_input = np.asarray(seg_pair_slice["input"])[:,
                                                      coord['x_min']:coord['x_max'],
                                                      coord['y_min']:coord['y_max']]
    if seg_pair_slice["gt"]:
        stack_gt = np.asarray(seg_pair_slice["gt"])[:,
                                                    coord['x_min']:coord['x_max'],
                                                    coord['y_min']:coord['y_max']]
    else:
        stack_gt = []

    # Run transforms on image slices or patches
    stack_input, metadata_input = self.transform(sample=list(stack_input),
                                                 metadata=metadata_input,
                                                 data_type="im")
    # Update metadata_gt with metadata_input
    metadata_gt = imed_loader_utils.update_metadata(metadata_input, metadata_gt)
    if self.task == "segmentation":
        # Run transforms on gt slices or patches
        stack_gt, metadata_gt = self.transform(sample=list(stack_gt),
                                               metadata=metadata_gt,
                                               data_type="gt")
        # Make sure stack_gt is binarized
        if stack_gt is not None and not self.soft_gt:
            stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5).astype(np.uint8)
    else:
        # Force no transformation on labels for classification task
        # stack_gt is a tensor of size 1x1, values: 0 or 1
        # "expand(1)" is necessary to be compatible with segmentation convention: n_labelxhxwxd
        stack_gt = torch.from_numpy(seg_pair_slice["gt"][0]).expand(1)

    data_dict = {
        'input': stack_input,
        'gt': stack_gt,
        'roi': stack_roi,
        'input_metadata': metadata_input,
        'gt_metadata': metadata_gt,
        'roi_metadata': metadata_roi
    }

    # Input-level dropout to train with missing modalities
    if self.is_input_dropout:
        data_dict = dropout_input(data_dict)

    return data_dict
def pred_to_nib(data_lst, z_lst, fname_ref, fname_out, slice_axis, debug=False, kernel_dim='2d', bin_thr=0.5,
                discard_noise=True, postprocessing=None):
    """Save the network predictions as nibabel object.

    Based on the header of `fname_ref` image, it creates a nibabel object from the Network predictions (`data_lst`).

    Args:
        data_lst (list of np arrays): Predictions, either 2D slices either 3D patches.
        z_lst (list of ints): Slice indexes to reconstruct a 3D volume for 2D slices. Slices of the reference
            volume missing from `z_lst` are filled with zeros; for duplicated indexes the first one is used.
        fname_ref (str): Filename of the input image: its header is copied to the output nibabel object.
        fname_out (str): If not None, then the generated nibabel object is saved with this filename.
        slice_axis (int): Indicates the axis used for the 2D slice extraction: Sagittal: 0, Coronal: 1, Axial: 2.
        debug (bool): If True, extended verbosity and intermediate outputs.
        kernel_dim (str): Indicates whether the predictions were done on 2D or 3D patches. Choices: '2d', '3d'.
        bin_thr (float): If positive, then the segmentation is binarized with this given threshold. Otherwise, a soft
            segmentation is output.
        discard_noise (bool): If True, predictions that are lower than 0.01 are set to zero (only used when
            `bin_thr` is negative).
        postprocessing (dict): Contains postprocessing steps to be applied.

    Returns:
        NibabelObject: Object containing the Network prediction.
    """
    # Load reference nibabel object
    nib_ref = nib.load(fname_ref)
    nib_ref_can = nib.as_closest_canonical(nib_ref)

    if kernel_dim == '2d':
        # Position of each slice index in data_lst. setdefault keeps the FIRST occurrence, i.e. the element
        # `z_lst.index(z)` would return, but with an O(1) lookup instead of a linear scan per slice.
        z_to_pos = {}
        for pos, z in enumerate(z_lst):
            z_to_pos.setdefault(z, pos)

        # complete missing z with zeros
        tmp_lst = []
        for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
            if z not in z_to_pos:
                tmp_lst.append(np.zeros(data_lst[0].shape))
            else:
                tmp_lst.append(data_lst[z_to_pos[z]])

        if debug:
            print("Len {}".format(len(tmp_lst)))
            for arr in tmp_lst:
                print("Shape element lst {}".format(arr.shape))

        # create data and stack on depth dimension
        arr_pred_ref_space = np.stack(tmp_lst, axis=-1)

    else:
        arr_pred_ref_space = data_lst[0]

    n_channel = arr_pred_ref_space.shape[0]
    oriented_volumes = []
    if len(arr_pred_ref_space.shape) == 4:
        # Multi-channel prediction: reorient each channel independently
        for i in range(n_channel):
            oriented_volumes.append(
                imed_loader_utils.reorient_image(arr_pred_ref_space[i, ], slice_axis, nib_ref, nib_ref_can))
        # transpose to locate the channel dimension at the end to properly see image on viewer
        arr_pred_ref_space = np.asarray(oriented_volumes).transpose((1, 2, 3, 0))
    else:
        arr_pred_ref_space = imed_loader_utils.reorient_image(arr_pred_ref_space, slice_axis, nib_ref, nib_ref_can)

    if bin_thr >= 0:
        arr_pred_ref_space = imed_postpro.threshold_predictions(arr_pred_ref_space, thr=bin_thr)
    elif discard_noise:  # discard noise
        arr_pred_ref_space[arr_pred_ref_space <= 1e-2] = 0

    if postprocessing:
        fname_prefix = fname_out.split("_pred.nii.gz")[0] if fname_out is not None else None
        postpro = imed_postpro.Postprocessing(postprocessing,
                                              arr_pred_ref_space,
                                              nib_ref.header['pixdim'][1:4],
                                              fname_prefix)
        arr_pred_ref_space = postpro.apply()

    # create nibabel object
    nib_pred = nib.Nifti1Image(arr_pred_ref_space, nib_ref.affine)

    # save as nifti file
    if fname_out is not None:
        nib.save(nib_pred, fname_out)

    return nib_pred
def __getitem__(self, index):
    """Return the specific index pair subvolume (input, ground truth).

    The whole volume of the handler is transformed and the subvolume described by the coordinates of
    `self.indexes[index]` is then cropped out of it.

    Args:
        index (int): Subvolume index.

    Returns:
        dict: Keys 'input' and 'gt' (cropped subvolumes, 'gt' is None when there is no ground truth),
        'input_metadata' (with a 'coord' entry [x_min, x_max, y_min, y_max, z_min, z_max]) and 'gt_metadata'.
    """
    coord = self.indexes[index]
    seg_pair, _ = self.handlers[coord['handler_index']]

    # Clean transforms params from previous transforms
    # i.e. remove params from previous iterations so that the coming transforms are different
    # Use copy to have different coordinates for reconstruction for a given handler
    metadata_input = imed_loader_utils.clean_metadata(copy.deepcopy(seg_pair['input_metadata']))
    metadata_gt = imed_loader_utils.clean_metadata(copy.deepcopy(seg_pair['gt_metadata']))

    # Run transforms on images
    stack_input, metadata_input = self.transform(sample=seg_pair['input'],
                                                 metadata=metadata_input,
                                                 data_type="im")
    # Update metadata_gt with metadata_input
    metadata_gt = imed_loader_utils.update_metadata(metadata_input, metadata_gt)

    # Run transforms on ground truth
    stack_gt, metadata_gt = self.transform(sample=seg_pair['gt'],
                                           metadata=metadata_gt,
                                           data_type="gt")
    # Make sure stack_gt is binarized
    if stack_gt is not None and not self.soft_gt:
        stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5)

    # add coordinates to metadata to reconstruct volume
    for metadata in metadata_input:
        metadata['coord'] = [coord["x_min"], coord["x_max"],
                             coord["y_min"], coord["y_max"],
                             coord["z_min"], coord["z_max"]]

    # Crop the subvolume directly. (The former torch.zeros preallocations were always overwritten by
    # these slices, and the loops around the assignments re-assigned the same slice once per channel.)
    subvolumes = {
        'input': stack_input[:,
                             coord['x_min']:coord['x_max'],
                             coord['y_min']:coord['y_max'],
                             coord['z_min']:coord['z_max']],
        'gt': stack_gt[:,
                       coord['x_min']:coord['x_max'],
                       coord['y_min']:coord['y_max'],
                       coord['z_min']:coord['z_max']] if stack_gt is not None else None,
        'input_metadata': metadata_input,
        'gt_metadata': metadata_gt
    }

    return subvolumes
def __getitem__(self, index):
    """Return the specific processed data corresponding to index (input, ground truth, roi and metadata).

    Transforms are run on the ROI first, then on the input and then on the ground truth. Each step's
    metadata is propagated to the next one.

    Args:
        index (int): Slice index.

    Returns:
        dict: Keys 'input', 'gt', 'roi', 'input_metadata', 'gt_metadata' and 'roi_metadata'.
    """
    seg_pair_slice, roi_pair_slice = self.indexes[index]

    # Clean transforms params from previous transforms
    # i.e. remove params from previous iterations so that the coming transforms are different
    # NOTE(review): unlike the 3D dataset variant, no copy.deepcopy is made here before clean_metadata,
    # so any in-place change to the metadata (by clean_metadata or the transforms) would persist in
    # self.indexes across epochs — confirm this is intended.
    metadata_input = imed_loader_utils.clean_metadata(seg_pair_slice['input_metadata'])
    metadata_roi = imed_loader_utils.clean_metadata(roi_pair_slice['gt_metadata'])
    metadata_gt = imed_loader_utils.clean_metadata(seg_pair_slice['gt_metadata'])

    # Run transforms on ROI
    # ROI goes first because params of ROICrop are needed for the followings
    stack_roi, metadata_roi = self.transform(sample=roi_pair_slice["gt"],
                                             metadata=metadata_roi,
                                             data_type="roi")

    # Update metadata_input with metadata_roi
    metadata_input = imed_loader_utils.update_metadata(metadata_roi, metadata_input)

    # Run transforms on images
    stack_input, metadata_input = self.transform(sample=seg_pair_slice["input"],
                                                 metadata=metadata_input,
                                                 data_type="im")

    # Update metadata_gt with metadata_input
    metadata_gt = imed_loader_utils.update_metadata(metadata_input, metadata_gt)

    if self.task == "segmentation":
        # Run transforms on ground truth
        stack_gt, metadata_gt = self.transform(sample=seg_pair_slice["gt"],
                                               metadata=metadata_gt,
                                               data_type="gt")

        # Make sure stack_gt is binarized
        if stack_gt is not None and not self.soft_gt:
            stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5)

    else:
        # Force no transformation on labels for classification task
        # stack_gt is a tensor of size 1x1, values: 0 or 1
        # "expand(1)" is necessary to be compatible with segmentation convention: n_labelxhxwxd
        stack_gt = torch.from_numpy(seg_pair_slice["gt"][0]).expand(1)

    data_dict = {
        'input': stack_input,
        'gt': stack_gt,
        'roi': stack_roi,
        'input_metadata': metadata_input,
        'gt_metadata': metadata_gt,
        'roi_metadata': metadata_roi
    }

    return data_dict
def threshold_analysis(model_path, ds_lst, model_params, testing_params, metric="dice", increment=0.1,
                       fname_out="thr.png", cuda_available=True):
    """Run a threshold analysis to find the optimal threshold on a sub-dataset.

    The tested thresholds are ``np.arange(0, 1, increment)`` without its first value (0). Predictions of the
    model on `ds_lst` are binarized with each threshold and scored against the binarized ground truth.

    Args:
        model_path (str): Model path.
        ds_lst (list): List of loaders.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters. Not modified: uncertainty is disabled on a copy.
        metric (str): Choice between "dice" and "recall_specificity". If "recall_specificity", then a ROC analysis
            is performed.
        increment (float): Increment between tested thresholds.
        fname_out (str): Plot output filename.
        cuda_available (bool): If True, CUDA is available.

    Returns:
        float: optimal threshold. When several thresholds reach the best score, the highest one is returned.

    Raises:
        ValueError: If `metric` is not supported, or if the score is undefined (NaN) for every tested threshold.
    """
    if metric not in ["dice", "recall_specificity"]:
        raise ValueError('\nChoice of metric for threshold analysis: dice, recall_specificity.')

    # Adjust some testing parameters. Copies are used so that disabling uncertainty here does not leak
    # into the caller's `testing_params` (which is typically reused afterwards for testing).
    testing_params = dict(testing_params)
    testing_params["uncertainty"] = dict(testing_params["uncertainty"], applied=False)

    # Load model; remap tensors to CPU when CUDA is not used so that GPU-saved models can still be loaded
    model = torch.load(model_path, map_location=None if cuda_available else torch.device("cpu"))
    # Eval mode
    model.eval()

    # List of thresholds
    thr_list = list(np.arange(0.0, 1.0, increment))[1:]

    # Init metric manager for each thr
    metric_fns = [imed_metrics.recall_score,
                  imed_metrics.dice_score,
                  imed_metrics.specificity_score]
    metric_dict = {thr: imed_metrics.MetricManager(metric_fns) for thr in thr_list}

    # Load
    loader = DataLoader(ConcatDataset(ds_lst), batch_size=testing_params["batch_size"],
                        shuffle=False, pin_memory=True, sampler=None,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # Run inference
    preds_npy, gt_npy = run_inference(loader, model, model_params,
                                      testing_params,
                                      ofolder=None,
                                      cuda_available=cuda_available)

    print('\nRunning threshold analysis to find optimal threshold')
    # Make sure the GT is binarized
    gt_npy = [threshold_predictions(gt, thr=0.5) for gt in gt_npy]
    # Move threshold
    for thr in tqdm(thr_list, desc="Search"):
        # Each threshold works on a fresh copy of the soft predictions
        # (presumably threshold_predictions modifies its input in place)
        preds_thr = [threshold_predictions(copy.deepcopy(pred), thr=thr) for pred in preds_npy]
        metric_dict[thr](preds_thr, gt_npy)

    # Get results
    tpr_list, fpr_list, dice_list = [], [], []
    for thr in thr_list:
        result_thr = metric_dict[thr].get_results()
        tpr_list.append(result_thr["recall_score"])
        fpr_list.append(1 - result_thr["specificity_score"])
        dice_list.append(result_thr["dice_score"])

    # Get optimal threshold
    if metric == "dice":
        diff_list = dice_list
    else:
        diff_list = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]

    scores = np.asarray(diff_list, dtype=float)
    if np.all(np.isnan(scores)):
        raise ValueError('\nThreshold analysis: the score is undefined for every tested threshold.')
    # Highest threshold reaching the best score; NaN scores (undefined metric) are ignored
    optimal_idx = int(np.flatnonzero(scores == np.nanmax(scores))[-1])
    optimal_threshold = thr_list[optimal_idx]
    print('\tOptimal threshold: {}'.format(optimal_threshold))

    # Save plot
    print('\tSaving plot: {}'.format(fname_out))
    if metric == "dice":
        # Run plot
        imed_metrics.plot_dice_thr(thr_list, dice_list, optimal_idx, fname_out)
    else:
        # Add 0 and 1 as extrema
        tpr_list = [0.0] + tpr_list + [1.0]
        fpr_list = [0.0] + fpr_list + [1.0]
        optimal_idx += 1
        # Run plot
        imed_metrics.plot_roc_curve(tpr_list, fpr_list, optimal_idx, fname_out)

    return optimal_threshold
def run_inference(pred_folder, im_lst, thr_pred, gt_folder, target_suf, param_eval, unc_name=None, thr_unc=None):
    """Evaluate the predictions saved in `pred_folder` against their ground truth.

    For each image prefix, the matching predictions are averaged into a soft segmentation. When both
    `unc_name` and `thr_unc` are given, the soft segmentation is first cleaned with the uncertainty map.
    It is then thresholded at `thr_pred` and evaluated with `imed_utils.Evaluation3DMetrics`.

    Args:
        pred_folder (str): Folder containing the predictions (and the uncertainty maps).
        im_lst (list of str): Image filename prefixes; the subject id is the part before the first '_'.
        thr_pred (float): Threshold used to binarize the soft segmentation.
        gt_folder (str): Root folder of the ground truths (`<gt_folder>/<subject>/anat/<prefix><target_suf>.nii.gz`).
        target_suf (str): Ground-truth filename suffix.
        param_eval (dict): Parameters passed to `Evaluation3DMetrics`.
        unc_name (str): Suffix of the uncertainty map file. Uncertainty is used only if `thr_unc` is also given.
            In that case the MC samples (`<prefix>_pred_*`) are used instead of the single prediction
            (`<prefix>_pred.*`).
        thr_unc (float): Voxels whose uncertainty is above this value are set to 0 in the soft segmentation.

    Returns:
        pandas.DataFrame: One row per image with the evaluation results and an 'image_id' column.
    """
    use_unc = unc_name is not None and thr_unc is not None
    rows = []
    # loop across images
    for fname_pref in im_lst:
        if use_unc:
            logger.debug(thr_unc)
            # uncertainty map
            fname_unc = os.path.join(pred_folder, fname_pref + unc_name + '.nii.gz')
            # np.asanyarray(img.dataobj) returns what the removed `get_data()` returned
            data_unc = np.asanyarray(nib.load(fname_unc).dataobj)
            # list MC samples
            pred_pattern = fname_pref + '_pred_'
        else:
            pred_pattern = fname_pref + '_pred.'

        data_pred_lst = np.array([np.asanyarray(nib.load(os.path.join(pred_folder, f)).dataobj)
                                  for f in os.listdir(pred_folder) if pred_pattern in f])

        # ground-truth fname
        fname_gt = os.path.join(gt_folder, fname_pref.split('_')[0], 'anat', fname_pref + target_suf + '.nii.gz')
        nib_gt = nib.load(fname_gt)
        data_gt = np.asanyarray(nib_gt.dataobj)

        # soft prediction
        data_soft = np.mean(data_pred_lst, axis=0)

        if use_unc:
            logger.debug("thr")
            # discard uncertain lesions from data_soft
            data_soft[data_unc > thr_unc] = 0

        data_hard = imed_postpro.threshold_predictions(data_soft, thr=thr_pred).astype(np.uint8)

        evaluation = imed_utils.Evaluation3DMetrics(data_pred=data_hard,
                                                    data_gt=data_gt,
                                                    dim_lst=nib_gt.header['pixdim'][1:4],
                                                    params=param_eval)

        results_pred, _ = evaluation.run_eval()

        # save results of this fname_pred (results_pred is presumably a dict of metric values)
        results_pred['image_id'] = fname_pref.split('_')[0]
        rows.append(results_pred)

    # DataFrame.append was removed in pandas 2.0: build the frame once from the collected rows
    return pd.DataFrame(rows)
def run_experiment(level, unc_name, thr_unc_lst, thr_pred_lst, gt_folder, pred_folder, im_lst, target_suf,
                   param_eval):
    """Evaluate the MC predictions for every pair of (uncertainty threshold, prediction threshold).

    For each image whose ground truth exists and whose averaged MC prediction is not empty, each
    threshold pair goes through three steps:
        1. Voxels with uncertainty above `thr_unc` are removed from the soft segmentation.
        2. The result is binarized at `thr_pred`.
        3. The binarized prediction is evaluated.

    Args:
        level (str): 'vox' for voxel-wise recall / precision; any other value uses the lesion-wise
            `get_ltpr` / `get_lfdr` of `Evaluation3DMetrics`. Also passed to `count_retained`.
        unc_name (str): Suffix of the uncertainty map file.
        thr_unc_lst (list of float): Uncertainty thresholds.
        thr_pred_lst (list of float): Prediction thresholds.
        gt_folder (str): Root folder of the ground truths (`<gt_folder>/<subject>/anat/<prefix><target_suf>.nii.gz`).
        pred_folder (str): Folder containing the MC samples (`<prefix>_pred_*`) and uncertainty maps.
        im_lst (list of str): Image filename prefixes.
        target_suf (str): Ground-truth filename suffix.
        param_eval (dict): Parameters passed to `Evaluation3DMetrics`.

    Returns:
        dict: 'tpr' and 'fdr' are nested lists indexed [i_unc][i_pred], holding one value per image, divided
        by 100. 'retained_elt' is indexed [i_unc] and holds one `count_retained` result per image.
    """
    # init results
    res_dct = {
        'tpr': [[[] for _ in thr_pred_lst] for _ in thr_unc_lst],
        'fdr': [[[] for _ in thr_pred_lst] for _ in thr_unc_lst],
        'retained_elt': [[] for _ in thr_unc_lst]
    }

    # loop across images
    for fname_pref in im_lst:
        # uncertainty map (np.asanyarray(img.dataobj) returns what the removed `get_data()` returned)
        fname_unc = os.path.join(pred_folder, fname_pref + unc_name + '.nii.gz')
        data_unc = np.asanyarray(nib.load(fname_unc).dataobj)

        # list MC samples
        data_pred_lst = np.array([np.asanyarray(nib.load(os.path.join(pred_folder, f)).dataobj)
                                  for f in os.listdir(pred_folder) if fname_pref + '_pred_' in f])

        # ground-truth fname; images without ground truth are skipped
        fname_gt = os.path.join(gt_folder, fname_pref.split('_')[0], 'anat', fname_pref + target_suf + '.nii.gz')
        if not os.path.isfile(fname_gt):
            continue
        nib_gt = nib.load(fname_gt)
        data_gt = np.asanyarray(nib_gt.dataobj)
        logger.debug(np.sum(data_gt))

        # soft prediction; empty predictions are skipped
        data_soft = np.mean(data_pred_lst, axis=0)
        if not np.any(data_soft):
            continue

        for i_unc, thr_unc in enumerate(thr_unc_lst):
            # discard uncertain lesions from data_soft
            data_soft_thr_unc = deepcopy(data_soft)
            data_soft_thr_unc[data_unc > thr_unc] = 0
            # `int` replaces the `np.int` alias removed in NumPy 1.24 (it was the builtin int)
            cmpt = count_retained((data_soft > 0).astype(int), (data_soft_thr_unc > 0).astype(int), level)
            res_dct['retained_elt'][i_unc].append(cmpt)
            logger.debug(f"{thr_unc} {cmpt}")

            for i_pred, thr_pred in enumerate(thr_pred_lst):
                # Work on a copy: data_soft_thr_unc is reused for the next prediction threshold
                # (presumably threshold_predictions modifies its input in place)
                data_hard = imed_postpro.threshold_predictions(deepcopy(data_soft_thr_unc),
                                                               thr=thr_pred).astype(np.uint8)

                evaluation = imed_utils.Evaluation3DMetrics(data_pred=data_hard,
                                                            data_gt=data_gt,
                                                            dim_lst=nib_gt.header['pixdim'][1:4],
                                                            params=param_eval)

                if level == 'vox':
                    tpr = imed_metrics.recall_score(evaluation.data_pred, evaluation.data_gt, err_value=np.nan)
                    fdr = 100. - imed_metrics.precision_score(evaluation.data_pred, evaluation.data_gt,
                                                              err_value=np.nan)
                else:
                    tpr, _ = evaluation.get_ltpr()
                    fdr = evaluation.get_lfdr()
                logger.debug(f"{thr_pred} {np.count_nonzero(data_soft_thr_unc)} "
                             f"{np.count_nonzero(data_hard)} {tpr} {fdr}")
                res_dct['tpr'][i_unc][i_pred].append(tpr / 100.)
                res_dct['fdr'][i_unc][i_pred].append(fdr / 100.)

    return res_dct
def test_image_orientation(download_data_testing_test_files, loader_parameters):
    """Check that transform + undo-transform + volume reconstruction preserve the image orientation.

    For 2D and 3D datasets and every slice axis, the ground truth is loaded through the dataset and
    undone with the undo transforms. In 2D the slices are re-stacked into a volume. The reconstructed,
    binarized volume must then reach a Dice of at least 0.8 against the ground truth in three
    orientations:
        - the "hwd" orientation returned by SegmentationPair.get_pair_data;
        - RAS (closest canonical) orientation;
        - the original orientation of the NIfTI file.
    """
    device = torch.device("cuda:" + str(GPU_ID) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        logger.info(f"Using GPU ID {device}")

    bids_df = BidsDataframe(loader_parameters, __tmp_dir__, derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    roi_params = loader_parameters["roi_params"]

    train_lst = ['sub-unf01_T1w.nii.gz']

    training_transform_dict = {
        "Resample":
            {
                "wspace": 1.5,
                "hspace": 1,
                "dspace": 3
            },
        "CenterCrop":
            {
                "size": [176, 128, 160]
            },
        "NormalizeInstance": {"applied_to": ['im']}
    }

    # prepare_transforms returns both the forward transforms and their undo counterpart
    tranform_lst, training_undo_transform = imed_transforms.prepare_transforms(training_transform_dict)

    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = BidsDataset(bids_df=bids_df,
                                 subject_file_lst=train_lst,
                                 target_suffix=target_suffix,
                                 contrast_params=contrast_params,
                                 model_params=model_params,
                                 metadata_choice=False,
                                 slice_axis=slice_axis,
                                 transform=tranform_lst,
                                 multichannel=False)
                ds.load_filenames()
            else:
                ds = Bids3DDataset(bids_df=bids_df,
                                   subject_file_lst=train_lst,
                                   target_suffix=target_suffix,
                                   model_params=model_params,
                                   contrast_params=contrast_params,
                                   metadata_choice=False,
                                   slice_axis=slice_axis,
                                   transform=tranform_lst,
                                   multichannel=False)

            loader = DataLoader(ds, batch_size=1,
                                shuffle=True, pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[0]
            segpair = SegmentationPair(input_filename, gt_filename, metadata=metadata, slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get image with original, ras and hwd orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            # Accumulators of undone 2D slices and their slice indexes
            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                # batch["input_metadata"] = batch["input_metadata"][0]  # Take only metadata from one input
                # batch["gt_metadata"] = batch["gt_metadata"][0]  # Take only metadata from one label

                for smp_idx in range(len(batch['gt'])):
                    # undo transformations
                    if dim == '2d':
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                        # add new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))

                    else:
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    # 2D: reconstruct only once, on the last batch; 3D: every sample is a full volume
                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            # Re-order the undone slices by slice index and stack them on the last axis
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
                                tmp_lst.append(pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some differences are generated due to transform and undo transform
                        # (e.g. Resample interpolation), hence a Dice tolerance instead of exact equality
                        assert imed_metrics.dice_score(input_hwd_2, input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2, input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2, input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []