Exemplo n.º 1
0
def plot_patient_histograms(dataloader: DataLoader,
                            n_batches: int,
                            accumulate_batches: bool = False,
                            bins: int = 40,
                            uppercase_keys: bool = False):
    """Plot the batch-wise intensity histograms of the scan pixels inside the mask.

    Arguments
        dataloader: a hdf5 dataloader whose batches are indexable by 'scan'/'mask'
                    (or 'Scan'/'Mask'); the entries are expected to support
                    .cpu().detach().numpy() (presumably torch tensors)
        n_batches: how many batches to take into account; iteration stops once this many
                   batches have been read (a value < 1 never triggers the stop, so every
                   batch of the dataloader is used)
        accumulate_batches: if True, stack all values from all batches and report one histogram
                            if False, do one histogram for every batch in the figure
        bins: number of bins in histograms
        uppercase_keys: if True supports legacy upper case keys ('Scan', 'Mask')

    Raises
        ValueError: if the dataloader yields no batches at all.
    """
    scan_key, mask_key = ('Scan', 'Mask') if uppercase_keys else ('scan', 'mask')

    accumulated_values = []
    for idx, batch in enumerate(dataloader):
        mask = batch[mask_key].cpu().detach().numpy()
        scan = batch[scan_key].cpu().detach().numpy()
        # Keep only the intensities of pixels covered by the (non-zero) mask.
        accumulated_values.append(scan[mask != 0].flatten())
        if idx + 1 == n_batches:
            break

    if not accumulated_values:
        raise ValueError('The dataloader did not yield any batches to plot.')

    if accumulate_batches:
        arrays = [np.concatenate(accumulated_values)]
        title = 'Accumulated Intensities Histogram'
        label_kwargs = {}
    else:
        arrays = accumulated_values
        title = 'Batch-wise intensity Histograms'
        label_kwargs = {
            'labels': [
                f'Batch {batch_no + 1}'
                for batch_no in range(len(accumulated_values))
            ]
        }

    plot_multi_histogram(
        arrays=arrays,
        plot_density=False,  # KDE
        title=title,
        xlabel='Pixel Intensity',
        hist_kwargs=dict(bins=bins),
        figsize=(12, 8),
        **label_kwargs,
    )
    plt.show()
Exemplo n.º 2
0
def plot_abnormal_pixel_distribution(
        data_loader: DataLoader, **hist_kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """For a dataset with given ground truth, plot the sample-wise distribution of pixel counts.

    Three histograms are drawn in one figure, each counted per sample: the number of
    normal pixels, the number of abnormal pixels and the total number of masked pixels
    (mask size).

    Note: the counts are produced by get_n_normal_abnormal_pixels. According to the
    original documentation only pixels within the brain mask and only samples with
    abnormal pixels are considered; that filtering happens inside the helper
    (not visible here) — confirm there.

    Arguments
        data_loader: dataloader over the dataset with ground truth
        **hist_kwargs: forwarded as keyword arguments to plot_multi_histogram

    Returns
        The (figure, axes) pair returned by plot_multi_histogram.
    """
    normal_pixels, abnormal_pixels, total_masked_pixels = get_n_normal_abnormal_pixels(
        data_loader)

    fig, ax = plot_multi_histogram(
        arrays=[
            np.array(normal_pixels),
            np.array(abnormal_pixels),
            np.array(total_masked_pixels)
        ],
        labels=['Normal pixels', 'Abnormal pixels', 'Mask Size'],
        plot_density=False,
        title=
        'Distribution of the sample-wise number of normal / abnormal pixels',
        xlabel='Number of pixels',
        ylabel='Frequency',
        **hist_kwargs)
    return fig, ax
Exemplo n.º 3
0
def plot_ood_scores(ood_dataset_dict: dict,
                    score_label: str = 'WAIC',
                    dataset_name_filters: list = None,
                    modes_to_include: list = None,
                    do_save: bool = True) -> None:
    """Plot OOD score distribution histograms for different datasets, all, healthy and/or lesional.

    Arguments
        ood_dataset_dict: a dictionary with dataset names as keys and a dict like
                         {'all': [scores], 'healthy': [scores], 'lesional': [scores]}
        score_label: the name of the OOD score used; used as x-axis label and file name
        dataset_name_filters: a list of words; datasets whose name contains any of them are excluded
        modes_to_include: a list with 'all', 'healthy', 'lesional' potential entries, if None,
                          all will be considered. Datasets without lesional samples are always
                          plotted as a single histogram of their 'healthy' scores, regardless
                          of this argument.
        do_save: if True, save the figure as DATA_DIR_PATH / 'plots' / '<score_label>.png',
                 creating the 'plots' directory if it does not exist yet
    """
    if dataset_name_filters is None:
        dataset_name_filters = []
    if modes_to_include is None:
        modes_to_include = ['all', 'lesional', 'healthy']

    score_lists = []
    list_labels = []

    for dataset_name, sub_dict in ood_dataset_dict.items():
        if any(filter_word in dataset_name
               for filter_word in dataset_name_filters):
            continue
        if len(sub_dict['lesional']) == 0:
            # Healthy-only dataset: one histogram labelled with the dataset name only.
            list_labels.append(f'{dataset_name}')
            score_lists.append(sub_dict['healthy'])
            continue
        for mode in modes_to_include:
            list_labels.append(f'{dataset_name} {mode}')
            score_lists.append(sub_dict[mode])

    fig, _ = plot_multi_histogram(score_lists,
                                  list_labels,
                                  plot_density=False,
                                  figsize=(12, 6),
                                  xlabel=score_label,
                                  ylabel='Slice-wise frequency',
                                  hist_kwargs={'bins': 17})
    if do_save:
        save_path = DATA_DIR_PATH / 'plots' / f'{score_label}.png'
        # savefig does not create missing directories.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path)
        LOG.info(f'Saved OOD score figure at: {save_path}')
def plot_gaussian_annulus_distribution(latent_space_dims: int = 128,
                                       n_samples: int = 1000) -> plt.Figure:
    """Plot the distribution of distances to the origin of standard normal samples.

    Draws n_samples points from a standard normal distribution in latent_space_dims
    dimensions and plots a histogram of their Euclidean norms. By the Gaussian annulus
    effect these norms concentrate around sqrt(latent_space_dims), so the x-ticks are
    placed at sqrt(latent_space_dims) (rounded to one decimal) and at the integers
    roughly one unit below and above it. For the default of 128 dimensions this gives
    the ticks [10, 11.3, 13].

    Arguments
        latent_space_dims: dimensionality of the sampled normal vectors
        n_samples: number of sampled vectors

    Returns
        The figure returned by plot_multi_histogram.
    """
    import math

    latent_samples = torch.normal(mean=0,
                                  std=torch.ones((
                                      n_samples,
                                      latent_space_dims,
                                  )))
    norms = torch.norm(latent_samples, dim=1).numpy()

    fig, ax = plot_multi_histogram([norms], ['Distance to origin'],
                                   show_data_ticks=True,
                                   show_histograms=True)
    center = round(math.sqrt(latent_space_dims), 1)
    ax.set_xticks([math.floor(center - 1), center, math.ceil(center + 1)])
    return fig
Exemplo n.º 5
0
def plot_fraction_of_abnormal_pixels(
        data_loader: DataLoader, **hist_kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """For a dataset with given ground truth, plot the distribution of fraction of abnormal pixels in an image.

    The fraction is computed per sample as (number of abnormal pixels) / (number of
    masked pixels). A green dashed vertical line marks the 5th percentile of the fractions.

    Note: the counts come from get_n_normal_abnormal_pixels; according to the original
    documentation only pixels within the brain mask and only samples with abnormal
    pixels are considered there (not visible here — confirm in that helper). Samples
    with an empty mask are skipped, since their fraction is undefined.

    Arguments
        data_loader: dataloader over the dataset with ground truth
        **hist_kwargs: forwarded as keyword arguments to plot_multi_histogram

    Returns
        The (figure, axes) pair returned by plot_multi_histogram.

    Raises
        ValueError: if no sample has a non-empty mask.
    """
    _, abnormal_pixels, total_masked_pixels = get_n_normal_abnormal_pixels(
        data_loader)
    abnormal = np.asarray(abnormal_pixels, dtype=float)
    total = np.asarray(total_masked_pixels, dtype=float)

    # Avoid division by zero for samples whose mask is empty.
    has_mask = total > 0
    fractions = abnormal[has_mask] / total[has_mask]
    if fractions.size == 0:
        raise ValueError(
            'No samples with a non-empty mask to compute abnormal pixel fractions from.'
        )

    percentile_5 = np.percentile(fractions, q=5)
    fig, ax = plot_multi_histogram(
        arrays=[fractions],
        labels=None,
        plot_density=True,
        kde_bandwidth=0.02,
        xlabel='Fraction of abnormal pixels from all pixels within brain masks',
        ylabel='Frequency',
        **hist_kwargs)
    # axvline spans the full height of the axes, independent of the density scale.
    ax.axvline(percentile_5, color='g', linestyle='--', linewidth=2)
    return fig, ax
Exemplo n.º 6
0
def plot_roc_prc_slice_wise_lesion_detection_dose_kde(train_loader_name,
                                                      test_loader_name,
                                                      post_kde_stat,
                                                      fpr,
                                                      tpr,
                                                      au_roc,
                                                      recall,
                                                      precision,
                                                      au_prc,
                                                      healthy_id_kde,
                                                      healthy_post_kde,
                                                      lesional_post_kde,
                                                      id_healthy_stat=None,
                                                      healthy_stat=None,
                                                      lesional_stat=None,
                                                      save_figs: bool = True,
                                                      show_legend: bool = True,
                                                      show_title: bool = True):
    """Plot (and optionally save) slice-wise lesion detection figures for one statistic.

    Creates:
      - a histogram of the post-KDE values for in-distribution healthy, OOD healthy and
        OOD lesional data,
      - if all three raw statistic lists are given, the same histogram for the raw
        statistic values,
      - a ROC curve and a precision-recall curve.
    Finally au_roc and au_prc are printed with two decimals.

    Figures are saved under DEFAULT_PLOT_DIR_PATH when save_figs is True. Legend labels
    are omitted when show_legend is False; curve titles are omitted when show_title is False.
    """
    legend_labels = [
        f'{train_loader_name} healthy', f'{test_loader_name} healthy OOD',
        f'{test_loader_name} lesional OOD'
    ] if show_legend else None

    # Options shared by both histogram figures.
    common_hist_options = dict(plot_density=False,
                               show_data_ticks=False,
                               legend_pos='upper left',
                               figsize=(3, 3))

    kde_xlabel = ('waic' if post_kde_stat == 'waic' else
                  '$DoSE_{KDE}($' + LABEL_MAP[post_kde_stat] + ')')
    kde_arrays = [
        np.array(values)
        for values in (healthy_id_kde, healthy_post_kde, lesional_post_kde)
    ]
    kde_hist_fig, _ = plot_multi_histogram(kde_arrays,
                                           legend_labels,
                                           xlabel=kde_xlabel,
                                           **common_hist_options)
    if save_figs:
        kde_name = f'{train_loader_name}_{test_loader_name}_kde_hist_{post_kde_stat}'
        save_fig(kde_hist_fig, DEFAULT_PLOT_DIR_PATH / f'{kde_name}.png')

    raw_stats = (id_healthy_stat, healthy_stat, lesional_stat)
    if all(stat is not None for stat in raw_stats):
        stat_name = f'{train_loader_name}_{test_loader_name}_stat_hist_{post_kde_stat}'
        stat_hist_fig, _ = plot_multi_histogram(
            [np.array(stat) for stat in raw_stats],
            legend_labels,
            xlabel=LABEL_MAP[post_kde_stat],
            **common_hist_options)
        if save_figs:
            save_fig(stat_hist_fig,
                     DEFAULT_PLOT_DIR_PATH / f'{stat_name}.png')

    def curve_title(prefix):
        # Title of the ROC / PRC figure, or None when titles are disabled.
        if not show_title:
            return None
        return f'{prefix} SW Anomaly Detection {test_loader_name} {post_kde_stat}'

    roc_fig = plot_roc_curve(fpr,
                             tpr,
                             au_roc,
                             title=curve_title('ROC'),
                             figsize=(3, 3))
    if save_figs:
        save_fig(roc_fig,
                 DEFAULT_PLOT_DIR_PATH / f'roc_sw_{test_loader_name}_{post_kde_stat}')

    prc_fig = plot_precision_recall_curve(recall,
                                          precision,
                                          au_prc,
                                          title=curve_title('PRC'),
                                          figsize=(3, 3))
    if save_figs:
        save_fig(prc_fig,
                 DEFAULT_PLOT_DIR_PATH / f'prc_sw_{test_loader_name}_{post_kde_stat}')

    print(f'\tau_roc: {au_roc:.2f}')
    print(f'\tau_prc: {au_prc:.2f}')
Exemplo n.º 7
0
def calculate_and_update_ood_performance(metric_name: str,
                                         ood_dict: Dict[str, Dict[str, Union[
                                             List[float], List[bool]]]],
                                         eval_cfg: EvaluationConfig,
                                         results: EvaluationResult,
                                         invert_scores: bool = False) -> None:
    """Calculate OOD Detection performance, update results and create plots for given metric and ood_dict.

    OOD samples are the positive class (label 1.0), in-distribution samples the negative
    class (label 0.0). Up to three modes are evaluated:
      - 'all': every OOD score against every in-distribution score,
      - 'healthy': the non-lesional OOD scores against a random subset of in-distribution
        scores of the same size (capped at the number of in-distribution scores),
      - 'lesional': the lesional OOD scores against a random subset of in-distribution
        scores of the same size (capped likewise).
    'healthy' and 'lesional' are skipped when the OOD data contains no lesional samples.
    The subsets are drawn with random.sample, so those two modes depend on the global
    random state.

    Arguments
        metric_name: used to calculate the correct ood scores based on the metric
        ood_dict: the actual ood numbers, this should be in the format
                  ood_dict = {
                                'ood': {
                                    'ood_scores': [...],
                                    'is_lesional': [...]
                                },
                                'in_distribution': {
                                    'ood_scores': [...],
                                    'is_lesional': [...]
                                }
                            }
                  It is not modified by this function.
        eval_cfg: used for configuration; plots are only created if eval_cfg.do_plots is truthy
        results: this function appends one OODDetectionResult per evaluated mode to
                 results.ood_detection_results.results and saves plots to results.plot_dir_path
        invert_scores: some metrics have high score for in-distr. while others have low, so
                       scores are multiplied by -1 for the calculations if True (histograms
                       show the original, non-inverted values)
    """

    def _for_calculation(scores: list) -> list:
        # Local copy (optionally negated) so the caller's ood_dict is never mutated.
        return [-val for val in scores] if invert_scores else list(scores)

    def _labelled(ood_part: list, id_part: list) -> Tuple[list, list]:
        # Concatenate scores; OOD scores get label 1.0, in-distribution scores 0.0.
        return (ood_part + id_part,
                len(ood_part) * [1.0] + len(id_part) * [0.0])

    def _calculate_ood_performance(y_true: list, y_pred_proba: list) -> dict:
        fpr, tpr, _, au_roc = calculate_roc(y_true, y_pred_proba)
        precision, recall, _, au_prc = calculate_prc(y_true, y_pred_proba)
        return {
            'tpr': tpr,
            'fpr': fpr,
            'au_roc': au_roc,
            'precision': precision,
            'recall': recall,
            'au_prc': au_prc
        }

    def _update_results(ood_scores_: dict, mode_: str, metric_: str) -> None:
        ood_result = OODDetectionResult()
        ood_result.au_roc = ood_scores_['au_roc']
        ood_result.au_prc = ood_scores_['au_prc']
        ood_result.mode = mode_
        ood_result.metric = metric_
        results.ood_detection_results.results.append(ood_result)

    # Out-of-distribution data (positives), split into lesional and healthy samples.
    LOG.info('Inference on out-of-distribution data...')
    ood_scores = _for_calculation(ood_dict['ood']['ood_scores'])
    lesional_list = ood_dict['ood']['is_lesional']
    lesional_ood_scores = [
        score for score, is_lesional in zip(ood_scores, lesional_list)
        if is_lesional
    ]
    healthy_ood_scores = [
        score for score, is_lesional in zip(ood_scores, lesional_list)
        if not is_lesional
    ]

    # In-distribution data (negatives) has no lesional samples, so each OOD subset is
    # compared against an equally sized random subset of in-distribution scores. The
    # subset size is capped at len(id_scores), and the labels are built from the actual
    # subset so that y_true and y_pred_proba always have the same length.
    LOG.info('Inference on in-distribution data...')
    id_scores = _for_calculation(ood_dict['in_distribution']['ood_scores'])
    id_scores_for_lesional = random.sample(
        id_scores, min(len(lesional_ood_scores), len(id_scores)))
    id_scores_for_healthy = random.sample(
        id_scores, min(len(healthy_ood_scores), len(id_scores)))

    # Insertion order ('all', 'healthy', 'lesional') determines the plot label order.
    mode_data = {
        'all': _labelled(ood_scores, id_scores),
        'healthy': _labelled(healthy_ood_scores, id_scores_for_healthy),
        'lesional': _labelled(lesional_ood_scores, id_scores_for_lesional),
    }

    has_lesional = any(lesional_list)
    # Filled with the ood scores dict for each evaluated mode.
    metrics_dict = {}
    for mode, (y_pred_proba, y_true) in mode_data.items():
        if mode != 'all' and not has_lesional:
            continue
        scores = _calculate_ood_performance(y_true, y_pred_proba)
        _update_results(scores, mode, metric_name)
        metrics_dict[mode] = scores

    if not eval_cfg.do_plots:
        return

    # ROC & PRC curves
    labels = list(metrics_dict)
    score_dicts = list(metrics_dict.values())
    roc_fig = plot_multi_roc_curves(
        [score_dict['fpr'] for score_dict in score_dicts],
        [score_dict['tpr'] for score_dict in score_dicts],
        [score_dict['au_roc'] for score_dict in score_dicts],
        labels,
        title=f'{metric_name} ROC Curve OOD Detection',
        figsize=(6, 6))
    prc_fig = plot_multi_prc_curves(
        [score_dict['recall'] for score_dict in score_dicts],
        [score_dict['precision'] for score_dict in score_dicts],
        [score_dict['au_prc'] for score_dict in score_dicts],
        labels,
        title=f'{metric_name} PR Curve OOD Detection',
        figsize=(6, 6))
    save_fig(roc_fig, results.plot_dir_path / f'ood_roc_{metric_name}.png')
    save_fig(prc_fig, results.plot_dir_path / f'ood_prc_{metric_name}.png')
    plt.close(roc_fig)
    plt.close(prc_fig)

    # Slice-wise OOD score distributions: ood healthy, ood lesional, in-distr. healthy.
    arrays = []
    hist_labels = []
    for scores_list, label in zip(
            [healthy_ood_scores, lesional_ood_scores, id_scores_for_healthy],
            ['healthy OOD', 'lesional OOD', 'healthy in-distr.']):
        if scores_list:
            array = np.array(scores_list)
            # Undo the inversion so the histogram shows the original metric values.
            arrays.append(-1 * array if invert_scores else array)
            hist_labels.append(label)

    fig, _ = plot_multi_histogram(
        arrays,
        hist_labels,
        show_data_ticks=False,
        plot_density=False,
        title=f'OOD scores {results.test_set_name}',
        xlabel=f'{metric_name}')
    save_fig(fig, results.plot_dir_path / f'ood_scores_{metric_name}.png')
    plt.close(fig)