def threshold_analysis(model_path, ds_lst, model_params, testing_params, metric="dice", increment=0.1,
                       fname_out="thr.png", cuda_available=True):
    """Run a threshold analysis to find the optimal threshold on a sub-dataset.

    The model is run once over the concatenation of ``ds_lst``. Its predictions are then
    binarized at every threshold ``increment, 2 * increment, ...`` (strictly below 1.0) and
    scored against the binarized ground truth.

    Args:
        model_path (str): Model path.
        ds_lst (list): List of loaders.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters. Must contain "batch_size" and an
            "uncertainty" dict. The caller's dict is not modified: uncertainty is disabled on a
            private copy.
        metric (str): Choice between "dice" and "recall_specificity". If "recall_specificity",
            then a ROC analysis is performed.
        increment (float): Increment between tested thresholds. Must lie in the open interval
            (0, 1).
        fname_out (str): Plot output filename.
        cuda_available (bool): If True, CUDA is available. If False, the model is loaded onto
            the CPU.

    Returns:
        float: optimal threshold. Among equally good thresholds, the highest one is returned.

    Raises:
        ValueError: If ``metric`` is not supported, if ``increment`` is outside (0, 1), or if
            every tested threshold yields a NaN score.
    """
    if metric not in ["dice", "recall_specificity"]:
        raise ValueError(
            '\nChoice of metric for threshold analysis: dice, recall_specificity.'
        )
    # Validate before the expensive inference pass: a step of 0 makes np.arange raise, and a
    # negative step or a step >= 1 leaves no threshold to test.
    if not 0 < increment < 1:
        raise ValueError(
            'Threshold increment must be in the open interval (0, 1), got {}.'.format(increment)
        )

    # Disable uncertainty on a copy so the caller's testing_params is left untouched.
    testing_params = dict(testing_params)
    testing_params["uncertainty"] = dict(testing_params["uncertainty"], applied=False)

    # Load model. Map it onto the CPU when CUDA is unavailable so that a checkpoint saved from
    # a GPU can still be deserialized.
    model = torch.load(model_path, map_location=None if cuda_available else "cpu")
    # Eval mode
    model.eval()

    # Thresholds increment, 2 * increment, ... below 1.0; the leading 0.0 is excluded.
    thr_list = list(np.arange(0.0, 1.0, increment))[1:]

    # One metric manager per threshold, each tracking recall, Dice and specificity.
    metric_fns = [
        imed_metrics.recall_score,
        imed_metrics.dice_score,
        imed_metrics.specificity_score
    ]
    metric_dict = {thr: imed_metrics.MetricManager(metric_fns) for thr in thr_list}

    # Load. Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
    loader = DataLoader(ConcatDataset(ds_lst), batch_size=testing_params["batch_size"],
                        shuffle=False, pin_memory=cuda_available, sampler=None,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # Run inference once; every threshold is evaluated on these same predictions.
    preds_npy, gt_npy = run_inference(loader, model, model_params, testing_params, ofolder=None,
                                      cuda_available=cuda_available)

    print('\nRunning threshold analysis to find optimal threshold')
    # Make sure the GT is binarized
    gt_npy = [threshold_predictions(gt, thr=0.5) for gt in gt_npy]

    # Move threshold
    for thr in tqdm(thr_list, desc="Search"):
        # Threshold a deep copy so preds_npy is never altered, in case threshold_predictions
        # works in place (not visible here).
        preds_thr = [threshold_predictions(copy.deepcopy(pred), thr=thr) for pred in preds_npy]
        metric_dict[thr](preds_thr, gt_npy)

    # Get results
    tpr_list, fpr_list, dice_list = [], [], []
    for thr in thr_list:
        result_thr = metric_dict[thr].get_results()
        tpr_list.append(result_thr["recall_score"])
        fpr_list.append(1 - result_thr["specificity_score"])
        dice_list.append(result_thr["dice_score"])

    # Get optimal threshold: maximize Dice, or TPR - FPR (Youden's J) for the ROC analysis.
    if metric == "dice":
        diff_list = dice_list
    else:
        diff_list = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]
    scores = np.asarray(diff_list, dtype=float)
    if np.all(np.isnan(scores)):
        raise ValueError('Threshold analysis failed: every tested threshold gave a NaN score.')
    # NaN scores are ignored; among equal best scores keep the last (highest) threshold.
    optimal_idx = int(np.flatnonzero(scores == np.nanmax(scores))[-1])
    optimal_threshold = thr_list[optimal_idx]
    print('\tOptimal threshold: {}'.format(optimal_threshold))

    # Save plot
    print('\tSaving plot: {}'.format(fname_out))
    if metric == "dice":
        # Run plot
        imed_metrics.plot_dice_thr(thr_list, dice_list, optimal_idx, fname_out)
    else:
        # Add (0, 0) and (1, 1) as ROC end points; shift the optimal index accordingly.
        tpr_list = [0.0] + tpr_list + [1.0]
        fpr_list = [0.0] + fpr_list + [1.0]
        optimal_idx += 1
        # Run plot
        imed_metrics.plot_roc_curve(tpr_list, fpr_list, optimal_idx, fname_out)

    return optimal_threshold
def test_dice_plot():
    """Plot a Dice-versus-threshold curve and check the PNG is written under the temp dir."""
    out_png = Path(__tmp_dir__) / "test_dice.png"
    thresholds = [0.1, 0.3, 0.5, 0.7]
    dice_scores = [0.6, 0.7, 0.8, 0.75]
    # Index of the best Dice score (0.8, reached at threshold 0.5).
    best_idx = 2
    imed_metrics.plot_dice_thr(thresholds, dice_scores, best_idx, str(out_png))
    assert out_png.is_file()
def test_dice_plot():
    """Check that ``plot_dice_thr`` writes its figure to the requested path.

    The figure is written inside a throw-away temporary directory, so the test no longer
    leaves ``test_dice.png`` behind in the current working directory.
    """
    import tempfile

    thr_list = [0.1, 0.3, 0.5, 0.7]
    dice_list = [0.6, 0.7, 0.8, 0.75]
    with tempfile.TemporaryDirectory() as tmp_dir:
        fname_out = os.path.join(tmp_dir, "test_dice.png")
        # Index 2 marks threshold 0.5, where dice_list reaches its maximum (0.8).
        imed_metrics.plot_dice_thr(thr_list, dice_list, 2, fname_out)
        assert os.path.isfile(fname_out)
def test_dice_plot():
    """Verify that plot_dice_thr saves its figure into the shared test temp directory."""
    target = os.path.join(__tmp_dir__, "test_dice.png")
    # Position 2 holds the highest Dice value (0.8, at threshold 0.5).
    best_idx = 2
    imed_metrics.plot_dice_thr(
        [0.1, 0.3, 0.5, 0.7],
        [0.6, 0.7, 0.8, 0.75],
        best_idx,
        target,
    )
    assert os.path.isfile(target)