Esempio n. 1
0
def threshold_analysis(model_path,
                       ds_lst,
                       model_params,
                       testing_params,
                       metric="dice",
                       increment=0.1,
                       fname_out="thr.png",
                       cuda_available=True):
    """Run a threshold analysis to find the optimal threshold on a sub-dataset.

    Inference is run once on the concatenation of ``ds_lst``. The predictions are then binarized at each tested
    threshold (``increment``, ``2 * increment``, ... below 1.0; 0.0 itself is skipped) and scored against the
    ground truth passed through ``threshold_predictions`` with ``thr=0.5``. The optimal threshold maximizes the
    Dice score (``metric="dice"``) or ``recall - (1 - specificity)`` (``metric="recall_specificity"``). When
    several thresholds reach the best score, the highest one is returned. A plot is saved to ``fname_out``.

    Note:
        ``testing_params["uncertainty"]["applied"]`` is set to False in place, so the caller's dictionary is
        modified.

    Args:
        model_path (str): Model path.
        ds_lst (list): List of loaders.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters. Must contain the keys "uncertainty" and "batch_size".
        metric (str): Choice between "dice" and "recall_specificity". If "recall_specificity", then a ROC analysis
            is performed.
        increment (float): Increment between tested thresholds. Must be strictly between 0 and 1.
        fname_out (str): Plot output filename.
        cuda_available (bool): If True, CUDA is available.

    Returns:
        float: optimal threshold.

    Raises:
        ValueError: If ``metric`` is not supported, if ``increment`` is not strictly between 0 and 1, or if the
            selected score is undefined (NaN) for every tested threshold.
    """
    if metric not in ["dice", "recall_specificity"]:
        raise ValueError(
            '\nChoice of metric for threshold analysis: dice, recall_specificity.'
        )

    # Validate before the (expensive) model loading and inference: an increment outside (0, 1) yields no
    # threshold to test and would otherwise only fail at the very end with an obscure NumPy error.
    if not 0 < increment < 1:
        raise ValueError(
            'increment must be strictly between 0 and 1, got: {}'.format(
                increment))

    # Adjust some testing parameters
    testing_params["uncertainty"]["applied"] = False

    # Load model. Without CUDA, tensors saved from a GPU must be remapped to the CPU or loading fails.
    model = torch.load(model_path,
                       map_location=None if cuda_available else "cpu")
    # Eval mode
    model.eval()

    # List of thresholds (the first value, 0.0, is dropped)
    thr_list = list(np.arange(0.0, 1.0, increment))[1:]

    # Init one metric manager per threshold
    metric_fns = [
        imed_metrics.recall_score, imed_metrics.dice_score,
        imed_metrics.specificity_score
    ]
    metric_dict = {
        thr: imed_metrics.MetricManager(metric_fns)
        for thr in thr_list
    }

    # Load
    loader = DataLoader(ConcatDataset(ds_lst),
                        batch_size=testing_params["batch_size"],
                        shuffle=False,
                        pin_memory=True,
                        sampler=None,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # Run inference
    preds_npy, gt_npy = run_inference(loader,
                                      model,
                                      model_params,
                                      testing_params,
                                      ofolder=None,
                                      cuda_available=cuda_available)

    print('\nRunning threshold analysis to find optimal threshold')
    # Make sure the GT is binarized
    gt_npy = [threshold_predictions(gt, thr=0.5) for gt in gt_npy]
    # Move threshold
    for thr in tqdm(thr_list, desc="Search"):
        # Deep copy so that the raw predictions stay available for the next threshold
        preds_thr = [
            threshold_predictions(copy.deepcopy(pred), thr=thr)
            for pred in preds_npy
        ]
        metric_dict[thr](preds_thr, gt_npy)

    # Get results
    tpr_list, fpr_list, dice_list = [], [], []
    for thr in thr_list:
        result_thr = metric_dict[thr].get_results()
        tpr_list.append(result_thr["recall_score"])
        fpr_list.append(1 - result_thr["specificity_score"])
        dice_list.append(result_thr["dice_score"])

    # Get optimal threshold
    if metric == "dice":
        diff_list = dice_list
    else:
        diff_list = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]

    # NOTE(review): assumes each result is a scalar score -- confirm against MetricManager.get_results.
    scores = np.asarray(diff_list, dtype=float)
    if np.isnan(scores).all():
        raise ValueError(
            'The {} score is undefined (NaN) for every tested threshold.'.
            format(metric))
    # NaN scores are ignored; among ties the last index (i.e. the highest threshold) is kept.
    optimal_idx = int(np.flatnonzero(scores == np.nanmax(scores))[-1])
    optimal_threshold = thr_list[optimal_idx]
    print('\tOptimal threshold: {}'.format(optimal_threshold))

    # Save plot
    print('\tSaving plot: {}'.format(fname_out))
    if metric == "dice":
        # Run plot
        imed_metrics.plot_dice_thr(thr_list, dice_list, optimal_idx, fname_out)
    else:
        # Add 0 and 1 as extrema
        tpr_list = [0.0] + tpr_list + [1.0]
        fpr_list = [0.0] + fpr_list + [1.0]
        # Shift the index to account for the extremum inserted at the front
        optimal_idx += 1
        # Run plot
        imed_metrics.plot_roc_curve(tpr_list, fpr_list, optimal_idx, fname_out)

    return optimal_threshold
Esempio n. 2
0
def test_dice_plot():
    """Check that ``imed_metrics.plot_dice_thr`` writes the plot file into ``__tmp_dir__``."""
    # (threshold, Dice) pairs; the pair at index 2 carries the highest Dice value
    samples = [(0.1, 0.6), (0.3, 0.7), (0.5, 0.8), (0.7, 0.75)]
    thresholds = [thr for thr, _ in samples]
    dice_scores = [dice for _, dice in samples]
    best_idx = 2

    output_file = Path(__tmp_dir__) / "test_dice.png"
    imed_metrics.plot_dice_thr(thresholds, dice_scores, best_idx,
                               str(output_file))

    assert output_file.is_file()
Esempio n. 3
0
def test_dice_plot():
    """Check that ``imed_metrics.plot_dice_thr`` writes the requested plot file.

    The plot is written inside a throw-away temporary directory rather than the current working directory, so
    that a stale ``test_dice.png`` left by an earlier run cannot make the assertion pass and no artifact is
    left behind after the test.
    """
    import tempfile

    thr_list = [0.1, 0.3, 0.5, 0.7]
    dice_list = [0.6, 0.7, 0.8, 0.75]
    with tempfile.TemporaryDirectory() as tmp_dir:
        fname_out = os.path.join(tmp_dir, "test_dice.png")
        # Index 2 points at the highest Dice value (0.8)
        imed_metrics.plot_dice_thr(thr_list, dice_list, 2, fname_out)
        # Checked inside the context manager: the directory is removed on exit
        assert os.path.isfile(fname_out)
Esempio n. 4
0
def test_dice_plot():
    """Check that ``imed_metrics.plot_dice_thr`` creates ``test_dice.png`` inside ``__tmp_dir__``."""
    thresholds = [0.1, 0.3, 0.5, 0.7]
    dice_scores = [0.6, 0.7, 0.8, 0.75]
    # Position of the best Dice score (0.8), i.e. index 2
    best_idx = dice_scores.index(max(dice_scores))

    plot_path = os.path.join(__tmp_dir__, "test_dice.png")
    imed_metrics.plot_dice_thr(thresholds, dice_scores, best_idx, plot_path)

    assert os.path.isfile(plot_path)