def plot_latent_reconstructions_2d_grid(
        model: torch.nn.Module,
        dim1: int,
        dim2: int,
        grid_size: int = 16,
        save_path: Optional[Path] = None,
        **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    outputs = []
    sample_generator = latent_space_analysis.latent_space_2d_grid(
        dim1=dim1, dim2=dim2, grid_size=grid_size)
    for bp in tqdm(sample_generator,
                   desc='Inferring latent space samples',
                   total=grid_size):
        outputs.append(infer_latent_space_samples(model, bp))
    stacked = torch.cat(outputs, dim=0)
    fig, ax = plot_vae_output(stacked,
                              figsize=(20, 20),
                              one_channel=True,
                              axis='off',
                              nrow=grid_size,
                              transpose_grid=False,
                              add_colorbar=False,
                              vmax=3,
                              **kwargs)
    if save_path is not None:
        save_fig(fig, save_path)
    return fig, ax
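# All latent-space plotting helpers in this listing rely on infer_latent_space_samples,
# which is not shown here. A minimal sketch of what it presumably does, assuming the
# model exposes a VAE-style decoder (hypothetical attribute name):
def infer_latent_space_samples(model: torch.nn.Module,
                               latent_samples: torch.Tensor) -> torch.Tensor:
    """Decode a batch of latent vectors into image space without tracking gradients."""
    model.eval()
    with torch.no_grad():
        return model.decoder(latent_samples)  # assumption: the model has a .decoder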
def run_loss_term_histograms(model: nn.Module, train_dataloader: DataLoader,
                             val_dataloader: DataLoader,
                             eval_cfg: EvaluationConfig,
                             results: EvaluationResult) -> EvaluationResult:
    """Produce loss term histograms and store in the output."""
    if eval_cfg.do_plots:
        LOG.info('Producing model loss-term statistics histograms...')
        val_generator = yield_inference_batches(
            val_dataloader,
            model,
            residual_threshold=results.pixel_anomaly_result.best_threshold,
            max_batches=eval_cfg.use_n_batches,
            progress_bar_suffix='loss-terms val set')
        train_generator = yield_inference_batches(
            train_dataloader,
            model,
            residual_threshold=results.pixel_anomaly_result.best_threshold,
            max_batches=eval_cfg.use_n_batches,
            progress_bar_suffix='loss-terms train set')

        figs_axes = plot_loss_histograms(
            output_generators=[train_generator, val_generator],
            names=[results.train_set_name, results.test_set_name],
            figsize=(15, 6),
            ylabel='Frequency',
            plot_density=True,
            show_data_ticks=False,
            kde_bandwidth=[0.009, 0.009 * 5],
            show_histograms=False)
        for idx, (fig, _) in enumerate(figs_axes):
            save_fig(
                fig,
                results.plot_dir_path / f'loss_term_distributions_{idx}.png')
    return results
def plot_latent_reconstructions_multiple_dims(
        model: torch.nn.Module,
        latent_space_dims: int = 128,
        n_samples_per_dim: int = 16,
        save_path: Optional[Path] = None,
        **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Plot latent space sample reconstructions for multiple dimensions, i.e. cover all cases where all dimensions
    but one are fixed."""
    outputs = []
    sample_generator = latent_space_analysis.yield_samples_all_dims(
        latent_space_dims, n_samples_per_dim)
    for bp in tqdm(sample_generator,
                   desc='Inferring latent space samples',
                   total=latent_space_dims):
        outputs.append(infer_latent_space_samples(model, bp))
    stacked = torch.cat(outputs, dim=0)
    fig, ax = plot_vae_output(stacked,
                              figsize=(20, 20),
                              one_channel=True,
                              axis='off',
                              nrow=n_samples_per_dim,
                              transpose_grid=True,
                              add_colorbar=False,
                              vmax=3,
                              **kwargs)
    if save_path is not None:
        save_fig(fig, save_path)
    return fig, ax
def run_segmentation_performance(eval_cfg: EvaluationConfig,
                                 results: EvaluationResult,
                                 val_dataloader: DataLoader,
                                 model: nn.Module) -> EvaluationResult:
    if not has_segmentation(val_dataloader.dataset):  # TODO: Fix typing issue.
        LOG.warning(
            f'Skipping segmentation performance for {val_dataloader.dataset.name}')
        return results
    perf_cfg = eval_cfg.seg_performance_config
    LOG.info(
        f'Running segmentation performance '
        f'(residual threshold = {results.pixel_anomaly_result.best_threshold:.2f})'
    )

    # Calculate scores for multiple pixel thresholds
    if perf_cfg.do_multiple_thresholds:
        pixel_thresholds = np.linspace(perf_cfg.min_val, perf_cfg.max_val,
                                       perf_cfg.num_values)
        LOG.info(
            f'Calculating segmentation (Dice / IoU) performance for residual thresholds: {list(pixel_thresholds)}'
        )
        mean_dice_scores, std_dice_scores = mean_std_dice_scores(
            val_dataloader, model, pixel_thresholds, eval_cfg.use_n_batches)
        if perf_cfg.do_iou:
            mean_iou_scores, std_iou_scores = mean_std_iou_scores(
                val_dataloader, model, pixel_thresholds,
                eval_cfg.use_n_batches)
        else:
            mean_iou_scores, std_iou_scores = None, None
        if eval_cfg.do_plots:
            threshold = results.pixel_anomaly_result.best_threshold
            segmentation_performance_fig = plot_segmentation_performance_vs_threshold(
                thresholds=pixel_thresholds,
                dice_scores=mean_dice_scores,
                dice_stds=std_dice_scores,
                iou_scores=mean_iou_scores,
                iou_stds=std_iou_scores,
                train_set_threshold=threshold)
            save_fig(segmentation_performance_fig,
                     results.plot_dir_path / 'seg_performance_vs_thresh.png')

    # If the best threshold has already been determined, also compute the segmentation score at that threshold
    if results.pixel_anomaly_result.best_threshold is not None:
        best_mean_dice_score, best_std_dice_score = mean_std_dice_scores(
            val_dataloader, model,
            [results.pixel_anomaly_result.best_threshold],
            eval_cfg.use_n_batches)
        results.pixel_anomaly_result.per_patient_dice_score_mean = best_mean_dice_score[0]
        results.pixel_anomaly_result.per_patient_dice_score_std = best_std_dice_score[0]
        LOG.info(
            f'Calculated best dice score (t={results.pixel_anomaly_result.best_threshold:.2f}): '
            f'{best_mean_dice_score[0]:.2f} +- {best_std_dice_score[0]:.2f}')
    return results
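# mean_std_dice_scores is not part of this listing. Per threshold it presumably
# binarizes the residual maps and averages a per-patient Dice coefficient; a minimal
# sketch of the underlying metric (an assumption, not the project's exact helper):
def dice_coefficient(prediction: torch.Tensor, ground_truth: torch.Tensor) -> float:
    """Dice score 2|A ∩ B| / (|A| + |B|) for binary segmentation masks."""
    pred, truth = prediction.bool(), ground_truth.bool()
    intersection = (pred & truth).sum().item()
    denominator = pred.sum().item() + truth.sum().item()
    return 2 * intersection / denominator if denominator > 0 else 1.0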
def run_residual_threshold_evaluation(
        model: nn.Module, train_dataloader: DataLoader,
        eval_cfg: EvaluationConfig,
        results: EvaluationResult) -> EvaluationResult:
    """Search for best threshold given an accepted FPR and update the results dict."""
    thresh_cfg = eval_cfg.thresh_search_config

    # Find the best threshold via golden-section search (GSS)
    objective = partial(calculate_fpr_minus_accepted,
                        accepted_fpr=thresh_cfg.accepted_fpr,
                        data_loader=train_dataloader,
                        model=model,
                        use_ground_truth=False,
                        n_batches_per_thresh=eval_cfg.use_n_batches)
    best_threshold = golden_section_search(objective,
                                           low=thresh_cfg.gss_lower_val,
                                           up=thresh_cfg.gss_upper_val,
                                           tolerance=thresh_cfg.gss_tolerance,
                                           return_mean=True)
    results.pixel_anomaly_result.best_threshold = best_threshold  # stored where downstream steps read it

    # Create threshold plots
    if eval_cfg.do_plots:
        pixel_thresholds = np.linspace(thresh_cfg.min_val, thresh_cfg.max_val,
                                       thresh_cfg.num_values)
        LOG.info(f'Producing FPR vs residual threshold plots with '
                 f'accepted FPR ({thresh_cfg.accepted_fpr:.2f}) '
                 f'checking on pixel thresholds {list(pixel_thresholds)}...')
        thresholds, train_false_positive_rates = threshold_vs_fpr(
            train_dataloader,
            model,
            thresholds=pixel_thresholds,
            use_ground_truth=False,
            n_batches_per_thresh=eval_cfg.use_n_batches)
        fpr_vs_threshold_fig = plot_fpr_vs_residual_threshold(
            accepted_fpr=thresh_cfg.accepted_fpr,
            calculated_threshold=best_threshold,
            thresholds=pixel_thresholds,
            fpr_train=train_false_positive_rates)
        save_fig(fpr_vs_threshold_fig,
                 results.plot_dir_path / 'fpr_vs_threshold.png')
    LOG.info(f'Calculated residual threshold: {best_threshold}')
    return results
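# golden_section_search is not included in this listing. A minimal sketch of a generic
# golden-section minimizer with the same signature as the call above, assuming the
# objective (here presumably |FPR - accepted_fpr|) is unimodal on [low, up]:
import math

def golden_section_search(func, low: float, up: float,
                          tolerance: float = 1e-3,
                          return_mean: bool = False) -> float:
    """Narrow [low, up] around the minimizer of a unimodal function."""
    inv_phi = (math.sqrt(5) - 1) / 2  # 1 / golden ratio, ~0.618
    x1 = up - inv_phi * (up - low)
    x2 = low + inv_phi * (up - low)
    f1, f2 = func(x1), func(x2)
    while abs(up - low) > tolerance:
        if f1 < f2:  # minimum lies in [low, x2]
            up, x2, f2 = x2, x1, f1
            x1 = up - inv_phi * (up - low)
            f1 = func(x1)
        else:  # minimum lies in [x1, up]
            low, x1, f1 = x1, x2, f2
            x2 = low + inv_phi * (up - low)
            f2 = func(x2)
    return (low + up) / 2 if return_mean else low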
def plot_latent_reconstructions_one_dim_changing(
        trained_model: torch.nn.Module,
        change_dim_idx: int,
        n_samples: int,
        save_path: Optional[Path] = None,
        **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Sampling (0, 0, ..., 0) in latent space and then varying one single dimension and check the impact."""
    latent_samples = latent_space_analysis.get_latent_samples_one_dim(
        dim_idx=change_dim_idx, n_samples=n_samples)
    output = infer_latent_space_samples(trained_model, latent_samples)
    fig, ax = plot_vae_output(output,
                              figsize=(20, 20),
                              vmax=3,
                              one_channel=True,
                              axis='off',
                              nrow=n_samples,
                              add_colorbar=False,
                              **kwargs)
    if save_path is not None:
        save_fig(fig, save_path)
    return fig, ax
def plot_random_latent_space_samples(model: torch.nn.Module,
                                     n_samples: int = 16,
                                     latent_space_dims: int = 128,
                                     save_path: Optional[Path] = None,
                                     **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Sample zero mean unit variance samples from latent space and visualize."""
    latent_samples = torch.randn(n_samples, latent_space_dims)
    output = infer_latent_space_samples(model, latent_samples)
    fig, ax = plot_vae_output(output,
                              figsize=(20, 20),
                              one_channel=True,
                              axis='off',
                              add_colorbar=False,
                              vmax=3,
                              **kwargs)
    if save_path is not None:
        save_fig(fig, save_path)
    return fig, ax
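# Hypothetical usage of the sampling helper above on a trained model (the checkpoint
# path and loading mechanism are placeholders, not part of the original code):
trained_vae = torch.load('path/to/trained_vae.pt', map_location='cpu')
fig, ax = plot_random_latent_space_samples(trained_vae,
                                           n_samples=64,
                                           latent_space_dims=128,
                                           save_path=Path('random_latent_samples.png'))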
def plot_roc_prc_slice_wise_lesion_detection_dose_kde(train_loader_name,
                                                      test_loader_name,
                                                      post_kde_stat,
                                                      fpr,
                                                      tpr,
                                                      au_roc,
                                                      recall,
                                                      precision,
                                                      au_prc,
                                                      healthy_id_kde,
                                                      healthy_post_kde,
                                                      lesional_post_kde,
                                                      id_healthy_stat=None,
                                                      healthy_stat=None,
                                                      lesional_stat=None,
                                                      save_figs: bool = True,
                                                      show_legend: bool = True,
                                                      show_title: bool = True):
    kde_hist_fig_title = f'{train_loader_name}_{test_loader_name}_kde_hist_{post_kde_stat}'
    array_labels = None if not show_legend else [
        f'{train_loader_name} healthy', f'{test_loader_name} healthy OOD',
        f'{test_loader_name} lesional OOD'
    ]
    if post_kde_stat == 'waic':
        xlabel = 'waic'
    else:
        xlabel = '$DoSE_{KDE}($' + LABEL_MAP[post_kde_stat] + ')'
    kde_hist_fig, _ = plot_multi_histogram(
        [
            np.array(healthy_id_kde),
            np.array(healthy_post_kde),
            np.array(lesional_post_kde)
        ],
        array_labels,
        plot_density=False,
        show_data_ticks=False,
        legend_pos='upper left',
        figsize=(3, 3),
        xlabel=xlabel)
    if save_figs:
        save_fig(kde_hist_fig,
                 DEFAULT_PLOT_DIR_PATH / f'{kde_hist_fig_title}.png')

    if healthy_stat is not None and lesional_stat is not None and id_healthy_stat is not None:
        stat_hist_fig_title = f'{train_loader_name}_{test_loader_name}_stat_hist_{post_kde_stat}'
        stat_hist_fig, _ = plot_multi_histogram(
            [
                np.array(id_healthy_stat),
                np.array(healthy_stat),
                np.array(lesional_stat)
            ],
            array_labels,
            plot_density=False,
            show_data_ticks=False,
            legend_pos='upper left',
            figsize=(3, 3),
            xlabel=LABEL_MAP[post_kde_stat])
        if save_figs:
            save_fig(stat_hist_fig,
                     DEFAULT_PLOT_DIR_PATH / f'{stat_hist_fig_title}.png')

    roc_fig = plot_roc_curve(
        fpr,
        tpr,
        au_roc,
        title=f'ROC SW Anomaly Detection {test_loader_name} {post_kde_stat}'
        if show_title else None,
        figsize=(3, 3))
    if save_figs:
        save_fig(
            roc_fig, DEFAULT_PLOT_DIR_PATH /
            f'roc_sw_{test_loader_name}_{post_kde_stat}.png')

    prc_fig = plot_precision_recall_curve(
        recall,
        precision,
        au_prc,
        title=f'PRC SW Anomaly Detection {test_loader_name} {post_kde_stat}'
        if show_title else None,
        figsize=(3, 3))
    if save_figs:
        save_fig(
            prc_fig, DEFAULT_PLOT_DIR_PATH /
            f'prc_sw_{test_loader_name}_{post_kde_stat}.png')

    print(f'\tau_roc: {au_roc:.2f}')
    print(f'\tau_prc: {au_prc:.2f}')
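# The KDE inputs above are precomputed elsewhere. A minimal sketch of how such
# DoSE_KDE-style per-slice scores can be obtained (an assumption, not necessarily the
# exact recipe used here): fit a KDE on an in-distribution summary statistic and score
# test slices by their negative log-density, so low-density slices get high OOD scores.
def dose_kde_scores(train_stat: np.ndarray, test_stat: np.ndarray) -> np.ndarray:
    from scipy.stats import gaussian_kde
    kde = gaussian_kde(train_stat)
    density = np.clip(kde(test_stat), 1e-12, None)  # avoid log(0)
    return -np.log(density)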
def calculate_and_update_ood_performance(metric_name: str,
                                         ood_dict: Dict[str, Dict[str, Union[
                                             List[float], List[bool]]]],
                                         eval_cfg: EvaluationConfig,
                                         results: EvaluationResult,
                                         invert_scores: bool = False) -> None:
    """Calculate OOD Detection performance, update results and create plots for given metric and ood_dict.
    Arguments
        metric_name: used to calculate the correct ood scores based on the metric
        ood_dict: the actual ood numbers, this should be in the format
                  ood_dict = {
                                'ood': {
                                    'ood_scores': [...],
                                    'is_lesional': [...]
                                },
                                'in_distribution': {
                                    'ood_scores': [...],
                                    'is_lesional': [...]
                                }
                            }
        eval_cfg: used for configuration
        results: this function will update the results object
        invert_scores: some metrics have high score for in-distr. while others have low, so multiply by -1 if true

    """
    # Only for calculations, will invert back for plotting
    if invert_scores:
        for ood_key in ['ood', 'in_distribution']:
            ood_dict[ood_key]['ood_scores'] = [
                -val for val in ood_dict[ood_key]['ood_scores']
            ]

    y_pred_proba_all = []  # OOD scores for both in-distribution and OOD data
    y_true_all = []
    y_pred_proba_lesional = []
    y_true_lesional = []
    y_pred_proba_healthy = []
    y_true_healthy = []

    # First the out-of-distribution scores
    LOG.info('Processing out-of-distribution scores...')
    ood_scores = ood_dict['ood']['ood_scores']
    lesional_list = ood_dict['ood']['is_lesional']
    y_pred_proba_all.extend(ood_scores)
    y_true_all.extend(len(ood_scores) * [1.0])

    lesional_indices = [
        idx for idx, is_lesional in enumerate(lesional_list) if is_lesional
    ]
    normal_indices = [
        idx for idx, is_lesional in enumerate(lesional_list) if not is_lesional
    ]

    lesional_ood_scores = [ood_scores[idx] for idx in lesional_indices]
    healthy_ood_scores = [ood_scores[idx] for idx in normal_indices]

    n_lesional_samples = len(lesional_ood_scores)
    y_pred_proba_lesional.extend(lesional_ood_scores)
    y_true_lesional.extend(n_lesional_samples * [1.0])

    n_healthy_samples = len(healthy_ood_scores)
    y_pred_proba_healthy.extend(healthy_ood_scores)
    y_true_healthy.extend(n_healthy_samples * [1.0])

    # Now the in-distribution scores
    LOG.info('Processing in-distribution scores...')
    ood_scores = ood_dict['in_distribution']['ood_scores']

    y_pred_proba_all.extend(ood_scores)
    y_true_all.extend(len(ood_scores) * [0.0])

    # The in-distribution data has no lesional samples, so draw the same number of
    # in-distribution scores as negative counterparts (capped at the number available)
    n_id_lesional = min(n_lesional_samples, len(ood_scores))
    y_pred_proba_lesional.extend(random.sample(ood_scores, n_id_lesional))
    y_true_lesional.extend(n_id_lesional * [0.0])

    # We can sample at most len(ood_scores) values from the in-distribution data
    n_id_healthy = min(n_healthy_samples, len(ood_scores))
    y_pred_proba_healthy.extend(random.sample(ood_scores, n_id_healthy))
    y_true_healthy.extend(n_id_healthy * [0.0])

    def calculate_ood_performance(mode: str, y_true: list,
                                  y_pred_proba: list) -> dict:
        if mode not in ['all', 'lesional', 'healthy']:
            raise ValueError(
                f'Given mode ({mode}) is not valid! Choose from all, lesional, healthy.'
            )
        fpr, tpr, _, au_roc = calculate_roc(y_true, y_pred_proba)
        precision, recall, _, au_prc = calculate_prc(y_true, y_pred_proba)
        out_dict = {
            'tpr': tpr,
            'fpr': fpr,
            'au_roc': au_roc,
            'precision': precision,
            'recall': recall,
            'au_prc': au_prc
        }
        return out_dict

    def update_results(ood_scores_: dict, mode_: str, metric_: str) -> None:
        ood_result = OODDetectionResult()
        ood_result.au_roc = ood_scores_['au_roc']
        ood_result.au_prc = ood_scores_['au_prc']
        ood_result.mode = mode_
        ood_result.metric = metric_
        results.ood_detection_results.results.append(ood_result)

    # Maps each evaluated mode ('all', 'healthy', 'lesional') to its scores dict
    metrics_dict = {}
    for mode, y_true, y_pred_proba in zip(
        ['all', 'healthy', 'lesional'],
        [y_true_all, y_true_healthy, y_true_lesional],
        [y_pred_proba_all, y_pred_proba_healthy, y_pred_proba_lesional]):
        if mode in ['healthy', 'lesional'] and not any(lesional_list):
            continue
        scores = calculate_ood_performance(mode, y_true, y_pred_proba)
        update_results(scores, mode, metric_name)
        metrics_dict[mode] = scores

    # ROC & PRC curves
    if eval_cfg.do_plots:
        labels = list(metrics_dict.keys())
        fprs = [score_dict['fpr'] for score_dict in metrics_dict.values()]
        tprs = [score_dict['tpr'] for score_dict in metrics_dict.values()]
        au_rocs = [
            score_dict['au_roc'] for score_dict in metrics_dict.values()
        ]
        recalls = [
            score_dict['recall'] for score_dict in metrics_dict.values()
        ]
        precisions = [
            score_dict['precision'] for score_dict in metrics_dict.values()
        ]
        au_prcs = [
            score_dict['au_prc'] for score_dict in metrics_dict.values()
        ]
        roc_fig = plot_multi_roc_curves(
            fprs,
            tprs,
            au_rocs,
            labels,
            title=f'{metric_name} ROC Curve OOD Detection',
            figsize=(6, 6))
        prc_fig = plot_multi_prc_curves(
            recalls,
            precisions,
            au_prcs,
            labels,
            title=f'{metric_name} PR Curve OOD Detection',
            figsize=(6, 6))
        save_fig(roc_fig, results.plot_dir_path / f'ood_roc_{metric_name}.png')
        save_fig(prc_fig, results.plot_dir_path / f'ood_prc_{metric_name}.png')
        plt.close(roc_fig)
        plt.close(prc_fig)

    # Slice-wise OOD score distributions
    if eval_cfg.do_plots:
        # ood healthy, ood lesional, in healthy
        ood_scores_healthy = [
            score for idx, score in enumerate(y_pred_proba_healthy)
            if y_true_healthy[idx] == 1.0
        ]
        ood_scores_lesional = [
            score for idx, score in enumerate(y_pred_proba_lesional)
            if y_true_lesional[idx] == 1.0
        ]
        id_scores_healthy = [
            score for idx, score in enumerate(y_pred_proba_healthy)
            if y_true_healthy[idx] == 0.0
        ]

        arrays = []
        labels = []
        for scores_list, label in zip(
            [ood_scores_healthy, ood_scores_lesional, id_scores_healthy],
            ['healthy OOD', 'lesional OOD', 'healthy in-distr.']):
            if len(scores_list) > 0:
                array = np.array(scores_list)
                if invert_scores:
                    array = -1 * array
                arrays.append(array)
                labels.append(label)

        fig, _ = plot_multi_histogram(
            arrays,
            labels,
            show_data_ticks=False,
            plot_density=False,
            title=f'OOD scores {results.test_set_name}',
            xlabel=metric_name)
        save_fig(fig, results.plot_dir_path / f'ood_scores_{metric_name}.png')
        plt.close(fig)
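# calculate_roc and calculate_prc are not part of this listing. Sketches that match
# the return orders used throughout, assuming they wrap scikit-learn (an assumption):
def calculate_roc(y_true, y_pred_proba):
    from sklearn import metrics
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred_proba)
    return fpr, tpr, thresholds, metrics.auc(fpr, tpr)

def calculate_prc(y_true, y_pred_proba):
    from sklearn import metrics
    precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_pred_proba)
    # average precision summarizes the area under the PR curve
    return precision, recall, thresholds, metrics.average_precision_score(y_true, y_pred_proba)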
def run_anomaly_detection_performance(
        eval_config: EvaluationConfig, model: nn.Module,
        val_dataloader: DataLoader,
        results: EvaluationResult) -> EvaluationResult:
    """Perform pixel-level anomaly detection performance evaluation."""
    LOG.info(
        'Calculating pixel-wise anomaly detection performance (ROC, PRC)...')
    if not has_segmentation(val_dataloader.dataset):  # TODO: Fix typing issue.
        LOG.warning(
            f'Skipping anomaly detection for {val_dataloader.dataset.name}')
        return results
    anomaly_scores = yield_anomaly_predictions(
        val_dataloader,
        model,
        eval_config.use_n_batches,
        results.pixel_anomaly_result.best_threshold,
        use_mask=eval_config.use_masked_loss)

    # Step 1 - Pixel-wise anomaly detection
    y_true = anomaly_scores.pixel_wise.y_true
    y_pred_proba = anomaly_scores.pixel_wise.y_pred_proba
    fpr, tpr, threshs, au_roc = calculate_roc(y_true, y_pred_proba)
    precision, recall, threshs, au_prc = calculate_prc(y_true, y_pred_proba)
    results.pixel_anomaly_result.au_prc = au_prc
    results.pixel_anomaly_result.au_roc = au_roc
    LOG.info(
        f'Pixel-wise anomaly detection performance: AUROC {au_roc:.2f}, AUPRC {au_prc:.2f}'
    )

    if eval_config.do_plots:
        roc_fig = plot_roc_curve(
            fpr,
            tpr,
            au_roc,
            # calculated_threshold=results.pixel_anomaly_result.best_threshold,
            # thresholds=threshs,
            title='ROC Curve Pixel-wise Anomaly Detection',
            figsize=(6, 6))
        save_fig(roc_fig, results.plot_dir_path / 'pixel_wise_roc.png')
        plt.close(roc_fig)

        prc_fig = plot_precision_recall_curve(
            recall,
            precision,
            au_prc,
            # calculated_threshold=results.pixel_anomaly_result.best_threshold,
            # thresholds=threshs,
            title='PR Curve Pixel-wise Anomaly Detection',
            figsize=(6, 6))
        save_fig(prc_fig, results.plot_dir_path / 'pixel_wise_prc.png')
        plt.close(prc_fig)
        """ Buggy :-(
        if anomaly_scores.pixel_wise.y_pred is not None:
            conf_matrix = confusion_matrix(anomaly_scores.pixel_wise.y_true, anomaly_scores.pixel_wise.y_pred)
            confusion_matrix_fig, _ = plot_confusion_matrix(conf_matrix, categories=['normal', 'anomaly'],
                                                            cbar=False, cmap='YlOrRd_r', figsize=(6, 6))
            confusion_matrix_fig.savefig(results.plot_dir_path / 'pixel_wise_confusion_matrix.png')
            plt.close(confusion_matrix_fig)
        """
    # TODO: Calculate Recall, Precision, Accuracy, F1 score from confusion matrix.

    # Step 2 - Slice-wise anomaly detection
    LOG.info(
        'Calculating slice-wise anomaly detection performance (ROC, PRC)...')

    reverse_score_map = {
        'REC_TERM': False,
        'KL_TERM': False,
        'ELBO': False
    }  # whether to negate the scores

    for criteria in SliceWiseCriteria:
        y_true = anomaly_scores.slice_wise[criteria.name].anomaly_score.y_true
        y_pred_proba = anomaly_scores.slice_wise[
            criteria.name].anomaly_score.y_pred_proba

        if reverse_score_map[criteria.name]:
            y_pred_proba = [-val for val in y_pred_proba]
        fpr, tpr, _, au_roc = calculate_roc(y_true, y_pred_proba)
        precision, recall, _, au_prc = calculate_prc(y_true, y_pred_proba)
        anomaly_result_for_criteria = SliceAnomalyDetectionResult()
        anomaly_result_for_criteria.criteria = criteria.name
        anomaly_result_for_criteria.au_prc = au_prc
        anomaly_result_for_criteria.au_roc = au_roc
        results.slice_anomaly_results.results.append(
            anomaly_result_for_criteria)
        LOG.info(
            f'Slice-wise anomaly detection metric {criteria.name}: AUROC {au_roc:.2f}, AUPRC {au_prc:.2f}'
        )

        if eval_config.do_plots:
            roc_fig = plot_roc_curve(
                fpr,
                tpr,
                au_roc,
                title=f'ROC Curve Sample-wise Anomaly Detection [{criteria.name}]',
                figsize=(6, 6))
            save_fig(
                roc_fig, results.plot_dir_path /
                f'sample_wise_roc_{criteria.name}.png')
            plt.close(roc_fig)

            prc_fig = plot_precision_recall_curve(
                recall,
                precision,
                au_prc,
                title=f'PR Curve Sample-wise Anomaly Detection [{criteria.name}]',
                figsize=(6, 6))
            save_fig(
                prc_fig, results.plot_dir_path /
                f'sample_wise_prc_{criteria.name}.png')
            plt.close(prc_fig)

    return results
def plot_stacked_scan_reconstruction_batches(batch_generator: Generator[
    BatchInferenceResult, None, None],
                                             plot_n_batches: int = 3,
                                             nrow: int = 8,
                                             show_mask: bool = False,
                                             save_dir_path: Optional[Path] = None,
                                             cut_tensor_to: Optional[int] = None,
                                             mask_background: bool = True,
                                             close_fig: bool = False,
                                             **kwargs) -> None:
    """Plot the scan and reconstruction batches. Horizontally aligned are the samples from one batch.
    Vertically aligned are input image, ground truth segmentation, reconstruction, residual image, residual with
    applied threshold, ground truth and predicted segmentation in same image.
    Args:
        batch_generator: a PyTorch DataLoader as defined in the uncertify dataloaders module
        plot_n_batches: limit plotting to this amount of batches
        save_dir_path: path to directory in which to store the resulting plots - will be created if not existent
        nrow: numbers of samples in one row, default is 8
        show_mask: plots the brain mask
        cut_tensor_to: choose 8 if you need only first 8 samples but for plotting but the original tensor is way larger
        close_fig: if True, will not show in notebook, handy for use in large pipeline
        kwargs: additional keyword arguments for plotting functions
    """
    if save_dir_path is not None:
        save_dir_path.mkdir(exist_ok=True)
    with torch.no_grad():
        for batch_idx, batch in enumerate(
                itertools.islice(batch_generator, plot_n_batches)):
            mask = batch.mask
            max_val = torch.max(batch.scan)
            min_val = torch.min(batch.scan)

            scan = normalize_to_0_1(batch.scan, min_val, max_val)
            reconstruction = normalize_to_0_1(batch.reconstruction, min_val,
                                              max_val)
            residual = normalize_to_0_1(batch.residual)
            thresholded = batch.residuals_thresholded

            if mask_background:
                scan = mask_background_to_zero(scan, mask)
                reconstruction = mask_background_to_zero(reconstruction, mask)
                residual = mask_background_to_zero(residual, mask)
                thresholded = mask_background_to_zero(thresholded, mask)

            if batch.segmentation is not None:
                seg = mask_background_to_zero(batch.segmentation, mask)
                stacked = torch.cat(
                    (scan, seg, reconstruction, residual, thresholded), dim=2)
            else:
                stacked = torch.cat(
                    (scan, reconstruction, residual, thresholded), dim=2)
            if show_mask:
                stacked = torch.cat((stacked, mask.float()), dim=2)
            if cut_tensor_to is not None:
                stacked = stacked[:cut_tensor_to, ...]
            grid = torchvision.utils.make_grid(stacked,
                                               padding=0,
                                               normalize=False,
                                               nrow=nrow)
            describe = scipy.stats.describe(grid.numpy().flatten())
            print_scipy_stats_description(describe, 'normalized_grid')
            fig, ax = imshow_grid(grid,
                                  one_channel=True,
                                  vmax=1.0,
                                  vmin=0.0,
                                  plt_show=False,
                                  **kwargs)
            ax.set_axis_off()
            plt.show()
            if save_dir_path is not None:
                img_file_name = f'batch_{batch_idx}.png'
                save_fig(fig, save_dir_path / img_file_name)
                if close_fig:
                    plt.close(fig)
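# normalize_to_0_1 and mask_background_to_zero are not shown in this listing; minimal
# sketches consistent with how they are called above (assumptions about their behavior):
def normalize_to_0_1(tensor: torch.Tensor, min_val=None, max_val=None) -> torch.Tensor:
    """Linearly rescale a tensor to [0, 1], optionally using externally supplied bounds."""
    min_val = tensor.min() if min_val is None else min_val
    max_val = tensor.max() if max_val is None else max_val
    return (tensor - min_val) / (max_val - min_val)

def mask_background_to_zero(tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Zero out all pixels outside the (boolean) brain mask."""
    return torch.where(mask.bool(), tensor, torch.zeros_like(tensor))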