Ejemplo n.º 1
0
def run_loss_term_histograms(model: nn.Module, train_dataloader: DataLoader,
                             val_dataloader: DataLoader,
                             eval_cfg: EvaluationConfig,
                             results: EvaluationResult) -> EvaluationResult:
    """Plot the distributions of the model's loss terms on the train and validation sets.

    Does nothing unless ``eval_cfg.do_plots`` is set. Otherwise an inference batch generator is created for
    each data loader (using ``results.pixel_anomaly_result.best_threshold`` as residual threshold and at most
    ``eval_cfg.use_n_batches`` batches), both are handed to ``plot_loss_histograms`` and every figure it returns
    is saved to ``results.plot_dir_path`` as ``loss_term_distributions_<idx>.png``.

    Returns:
        the ``results`` object that was passed in.
    """
    if not eval_cfg.do_plots:
        return results

    LOG.info('Producing model loss-term statistics histograms...')
    threshold = results.pixel_anomaly_result.best_threshold
    batch_limit = eval_cfg.use_n_batches
    val_batches = yield_inference_batches(
        val_dataloader,
        model,
        residual_threshold=threshold,
        max_batches=batch_limit,
        progress_bar_suffix='loss-terms val set')
    train_batches = yield_inference_batches(
        train_dataloader,
        model,
        residual_threshold=threshold,
        max_batches=batch_limit,
        progress_bar_suffix='loss-terms train set')

    # Generators and names are listed in the same order: train first, then the test/val set.
    figures_and_axes = plot_loss_histograms(
        output_generators=[train_batches, val_batches],
        names=[f'{results.train_set_name}', f'{results.test_set_name}'],
        figsize=(15, 6),
        ylabel='Frequency',
        plot_density=True,
        show_data_ticks=False,
        kde_bandwidth=[0.009, 0.009 * 5],
        show_histograms=False)
    for fig_number, (figure, _axes) in enumerate(figures_and_axes):
        target_path = results.plot_dir_path / f'loss_term_distributions_{fig_number}.png'
        save_fig(figure, target_path)
    return results
Ejemplo n.º 2
0
def infer_ensembles(ensemble_models: List[VariationalAutoEncoder],
                    dataloader: DataLoader, use_n_batches: int,
                    residual_threshold: float, **kwargs) -> List:
    """Lazily yield one inference-batch generator per ensemble member.

    NOTE: despite the ``List`` return annotation this is a generator function; iterate over it
    (or wrap it in ``list(...)``) to obtain the per-model generators.

    Arguments:
        ensemble_models: the ensemble members, each of which is run on the same ``dataloader``
        dataloader: data source handed to ``yield_inference_batches`` for every model
        use_n_batches: forwarded as third positional argument of ``yield_inference_batches``
                       (presumably its batch limit — verify against its signature)
        residual_threshold: forwarded as fourth positional argument of ``yield_inference_batches``
        **kwargs: further keyword arguments for ``yield_inference_batches``
    Yields:
        result_generators: one result_generator per model over which one can iterate to get inference batch results
    """
    yield from (yield_inference_batches(dataloader, member, use_n_batches,
                                        residual_threshold, **kwargs)
                for member in ensemble_models)
Ejemplo n.º 3
0
def calculate_mean_false_positive_rate(
        threshold: float,
        data_loader: DataLoader,
        model: torch.nn.Module,
        use_ground_truth: bool = False,
        n_batches_per_thresh: int = None) -> float:
    """Calculate the mean false positive rate (mean taken over batches) for a given threshold.

    If use_ground_truth=True, the data_loader must provide tensors with ground truth segmentation under "seg" key.
    Else, every "anomaly" pixel in the thresholded residual is considered an outlier. In this setting, only healthy
    samples (i.e. from the training data) should be used to make the assumption hold.

    Arguments:
        threshold: residual threshold used to binarize the residuals
        data_loader: data source for inference
        model: the model used for inference
        use_ground_truth: compare predictions against ``batch.segmentation`` instead of an all-zero reference
        n_batches_per_thresh: if not None, only the first n batches are evaluated
    Returns:
        the mean of the per-batch false positive rates, or NaN if no batch was evaluated
    Raises:
        KeyError: if use_ground_truth=True but a batch carries no segmentation
    """
    result_generator = yield_inference_batches(
        data_loader,
        model,
        residual_threshold=threshold,
        max_batches=n_batches_per_thresh,
        progress_bar_suffix=f'(FPR, t={threshold:.3f})')
    if n_batches_per_thresh is not None:
        # Defensive cap on the number of evaluated batches, independent of max_batches handling upstream.
        result_generator = itertools.islice(result_generator,
                                            n_batches_per_thresh)
    n_batches_label = n_batches_per_thresh if n_batches_per_thresh is not None else "all"
    per_batch_fpr = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(result_generator):
            pred_np = batch.residuals_thresholded[batch.mask].numpy().astype(int)
            if use_ground_truth:
                # A missing segmentation is None (indexing it would raise TypeError, not KeyError), so check
                # explicitly to surface a meaningful error instead.
                if batch.segmentation is None:
                    message = ('When use_ground_truth=True, the data_loader must '
                               'provide batches under "seg" key. Exit.')
                    LOG.error(message)
                    raise KeyError(message)
                gt_np = batch.segmentation[batch.mask].numpy().astype(int)
            else:
                # No ground truth: treat every pixel as healthy, so any positive prediction counts as an outlier.
                gt_np = np.zeros_like(pred_np)
            per_batch_fpr.append(false_positive_rate(pred_np, gt_np))
            if (batch_idx + 1) % 50 == 0:
                LOG.info(
                    f'Threshold: {threshold:.2f} - {batch_idx + 1} of '
                    f'{n_batches_label} batches done.'
                )
    if not per_batch_fpr:
        LOG.warning(f'No batches evaluated for t={threshold:.3f}; mean FPR is NaN.')
        return float('nan')
    mean_fpr = float(np.mean(per_batch_fpr))
    LOG.info(
        f'Mean FPR: {mean_fpr:.3f} for t={threshold:.3f} (use_ground_truth={use_ground_truth})'
    )
    return mean_fpr
Ejemplo n.º 4
0
def create_reconstruction_samples(dataloader: DataLoader,
                                  model: nn.Module,
                                  results: EvaluationResult,
                                  n_batches: int = 3) -> None:
    """Render a few stacked scan/reconstruction example plots into ``results.img_dir_path``.

    Inference uses ``results.pixel_anomaly_result.best_threshold`` as residual threshold. Plotting is delegated to
    ``plot_stacked_scan_reconstruction_batches`` with ``plot_n_batches=n_batches``, ``nrow=8``,
    ``cut_tensor_to=8``, a gray colormap, no background masking and ``close_fig=True``.
    """
    LOG.info('Creating some example reconstruction plots.')
    best_threshold = results.pixel_anomaly_result.best_threshold
    plot_stacked_scan_reconstruction_batches(
        yield_inference_batches(dataloader, model, residual_threshold=best_threshold),
        plot_n_batches=n_batches,
        nrow=8,
        save_dir_path=results.img_dir_path,
        cut_tensor_to=8,
        cmap='gray',
        mask_background=False,
        close_fig=True)
Ejemplo n.º 5
0
def mean_std_seg_score(data_loader: DataLoader,
                       model: torch.nn.Module,
                       residual_threshold: float,
                       score_type: str = 'dice',
                       max_n_batches: int = None) -> Tuple[float, float]:
    """Calculate the mean (over multiple / all batches) segmentation score for a given residual threshold.

    Arguments:
        data_loader: an uncertify data loader which yields dicts (with 'scan', 'mask', etc.)
        model: a trained pytorch model
        residual_threshold: the threshold from 0 to 1 for pixel-wise anomaly detection
        score_type: either 'dice' or 'iou'
        max_n_batches: if not None, take first max_n_batches only from the data_loader for calculation
    Returns:
        a tuple of (mean_seg_score, std_seg_score) for one threshold across multiple batches;
        both are NaN if no batch was evaluated
    Raises:
        ValueError: if score_type is not in VALID_SEGMENTATION_SCORE_TYPES
        RuntimeError: if score_type is valid but has no scoring function here (should not happen)
    """
    if score_type not in VALID_SEGMENTATION_SCORE_TYPES:
        raise ValueError(
            f'Provided score_type ({score_type}) invalid. Choose from: {VALID_SEGMENTATION_SCORE_TYPES}'
        )
    score_functions = {'dice': dice, 'iou': intersection_over_union}
    score_function = score_functions.get(score_type)
    if score_function is None:
        raise RuntimeError(
            f'Arrived at a score_type ({score_type}) which is invalid. Should not happen.'
        )

    batch_generator = yield_inference_batches(data_loader, model,
                                              max_n_batches,
                                              residual_threshold)
    per_batch_scores = []
    # Wrap the whole iteration: the generator presumably runs the model forward pass lazily when it is advanced,
    # so a no_grad block around the numpy scoring alone would not cover inference.
    with torch.no_grad():
        for batch in batch_generator:
            prediction_batch = batch.residuals_thresholded[batch.mask]
            ground_truth_batch = batch.segmentation[batch.mask]
            per_batch_scores.append(
                score_function(prediction_batch.numpy(), ground_truth_batch.numpy()))
    if not per_batch_scores:
        return float('nan'), float('nan')
    # When a batch represents a single patient this actually gives back the mean and std for patient-wise scores
    seg_score_mean = float(np.mean(per_batch_scores))
    seg_score_std = float(np.std(per_batch_scores))
    return seg_score_mean, seg_score_std
Ejemplo n.º 6
0
def calculate_confusion_matrix(data_loader: DataLoader,
                               model: torch.nn.Module,
                               residual_threshold: float,
                               max_n_batches: int = None,
                               normalize: bool = False) -> np.ndarray:
    """Calculate the confusion matrix for a given threshold over multiple batches of data.

    The layout of the confusion matrix follows the convention by scikit-learn, which is used to calculate sub matrices:
    rows are the ground truth label, columns the predicted label, labels ordered [0, 1], i.e. [[TN, FP], [FN, TP]]
    when 1 is the positive class. Only pixels selected by ``batch.mask`` are counted.

    Arguments:
        data_loader: data source whose batches carry ground truth segmentations
        model: the model used for inference
        residual_threshold: threshold used to binarize the residuals into predictions
        max_n_batches: if not None, limits the number of evaluated batches
        normalize: False / None returns raw pixel counts. True divides by the total count (same as 'all').
            The scikit-learn modes 'true' (per row), 'pred' (per column) and 'all' are accepted as well.
            Normalization is applied once to the accumulated counts, not per batch.
    Returns:
        a 2x2 float array
    Raises:
        ValueError: if normalize is a string other than 'true', 'pred' or 'all'
    """
    valid_modes = ('true', 'pred', 'all')
    if isinstance(normalize, str):
        if normalize not in valid_modes:
            raise ValueError(f'normalize must be a bool or one of {valid_modes}, got {normalize!r}.')
        normalize_mode = normalize
    else:
        normalize_mode = 'all' if normalize else None

    batch_generator = yield_inference_batches(data_loader, model,
                                              max_n_batches,
                                              residual_threshold)
    confusion_matrix = np.zeros((2, 2))  # accumulated raw pixel counts
    with torch.no_grad():
        for batch in batch_generator:
            y_pred = batch.residuals_thresholded[batch.mask].flatten().numpy().astype(int)
            y_true = batch.segmentation[batch.mask].flatten().numpy().astype(int)
            if y_true.size == 0:
                continue  # nothing inside the mask, contributes no counts
            # labels=[0, 1] forces a 2x2 result even when a batch contains only one class; otherwise scikit-learn
            # returns a 1x1 matrix which would silently broadcast into all four cells of the accumulator.
            # Counts are summed un-normalized: a sum of per-batch normalized matrices has no meaning.
            confusion_matrix += sklearn_metrics.confusion_matrix(
                y_true, y_pred, labels=[0, 1])

    if normalize_mode is not None:
        with np.errstate(divide='ignore', invalid='ignore'):
            if normalize_mode == 'true':
                confusion_matrix = confusion_matrix / confusion_matrix.sum(axis=1, keepdims=True)
            elif normalize_mode == 'pred':
                confusion_matrix = confusion_matrix / confusion_matrix.sum(axis=0, keepdims=True)
            else:
                confusion_matrix = confusion_matrix / confusion_matrix.sum()
        # Like scikit-learn, report 0 instead of NaN for rows/columns without any samples.
        confusion_matrix = np.nan_to_num(confusion_matrix)
    return confusion_matrix
Ejemplo n.º 7
0
def aggregate_slice_wise_statistics(model: nn.Module,
                                    data_loader: DataLoader,
                                    statistics: Iterable[str],
                                    max_n_batches: int = None,
                                    residual_threshold: float = None,
                                    health_state: str = 'all') -> dict:
    """Evaluate slice wise statistics and return aggregated results in a statistics-dict.

    Only slices selected by ``define_health_state_mask(health_state, batch)`` whose brain mask is not empty
    (``batch.slice_wise_is_empty`` is False) are kept.

    Arguments:
        model: the model used for inference
        data_loader: data source; its ``dataset.name`` is shown in the progress bar
        statistics: names of the statistics to compute, each must be a key of STATISTICS_FUNCTIONS
        max_n_batches: forwarded to yield_inference_batches (batch limit)
        residual_threshold: forwarded to yield_inference_batches
        health_state: forwarded to define_health_state_mask to select slices
    Returns:
        statistics_dict: maps each key to a list with one entry per kept slice:
            - one key per requested statistic holding that statistic's slice-wise values
            - 'is_lesional': slice-wise lesion flags (all False for batches without slice_wise_is_lesional)
            - 'scans', 'reconstructions', 'masks': slice-wise tensors
            - 'segmentations': slice-wise ground truth, zero tensors shaped like a scan when unavailable
    Raises:
        ValueError: if any requested statistic is not a key of STATISTICS_FUNCTIONS
    """
    # Materialize once: both the validity check and the per-batch loop iterate over `statistics`, so a one-shot
    # iterator (e.g. a generator) would otherwise be exhausted by the check and no statistic would be computed.
    statistics = list(statistics)
    unknown_statistics = [item for item in statistics if item not in STATISTICS_FUNCTIONS]
    if unknown_statistics:
        raise ValueError(f'Need to provide valid statistics ({list(STATISTICS_FUNCTIONS)}), '
                         f'unknown: {unknown_statistics}!')

    statistics_dict = defaultdict(list)  # statistic name (and 'is_lesional') -> slice-wise values
    slices_keep_mask = []  # one flag per slice: whether it ends up in the result
    slice_wise_scans = []  # tracked for later visualization
    slice_wise_seg_maps = []
    slice_wise_masks = []
    slice_wise_reconstructions = []

    batch_generator = yield_inference_batches(
        data_loader,
        model,
        max_n_batches,
        residual_threshold,
        progress_bar_suffix=f'(aggr. slice statistics '
        f'{data_loader.dataset.name})')
    for batch in batch_generator:
        batch_size, _, _, _ = batch.scan.shape

        # Track which slices we keep based on slice health state and emptiness of brain mask
        health_state_mask = define_health_state_mask(health_state, batch)
        is_not_empty_mask = np.invert(batch.slice_wise_is_empty)
        slices_keep_mask.extend(np.logical_and(health_state_mask, is_not_empty_mask))

        # Track lesional slices; batches without this information count as non-lesional
        if batch.slice_wise_is_lesional is not None:
            statistics_dict['is_lesional'].extend(batch.slice_wise_is_lesional)
        else:
            statistics_dict['is_lesional'].extend(np.zeros(len(batch.scan), dtype=bool))

        # Add the actual statistics
        for statistic in statistics:
            statistics_dict[statistic].extend(STATISTICS_FUNCTIONS[statistic](batch))

        # Track scans and (possibly placeholder) ground truth segmentation for visualizations later on
        slice_wise_masks.extend(batch.mask)
        slice_wise_reconstructions.extend(batch.reconstruction)
        slice_wise_scans.extend(batch.scan)
        if batch.segmentation is not None:
            slice_wise_seg_maps.extend(batch.segmentation)
        else:
            slice_wise_seg_maps.extend(torch.zeros_like(batch.scan[0]) for _ in range(batch_size))

    # Apply the keep mask to filter out empty slices or slices from other health states
    keep_slice_indices = [idx for idx, keep_slice_flag in enumerate(slices_keep_mask) if keep_slice_flag]

    def _select_kept(values: list) -> list:
        """Return the entries of a slice-wise list that belong to kept slices."""
        return [values[idx] for idx in keep_slice_indices]

    statistics_dict = {key: _select_kept(values) for key, values in statistics_dict.items()}
    statistics_dict.update({
        'scans': _select_kept(slice_wise_scans),
        'segmentations': _select_kept(slice_wise_seg_maps),
        'masks': _select_kept(slice_wise_masks),
        'reconstructions': _select_kept(slice_wise_reconstructions),
    })
    return statistics_dict
Ejemplo n.º 8
0
def sample_wise_waic_scores(models: Iterable[nn.Module],
                            data_loader: DataLoader,
                            max_n_batches: int = None,
                            return_slices: bool = False) -> ReturnTuple:
    """Compute slice-wise WAIC scores over an ensemble plus per-slice quantities from the first model.

    Every ensemble member is run over the same ``data_loader``. Slices are matched across members purely by their
    running position, so the loader must yield the data in the same order for every model. Per slice, the ELBO
    values of all members (computed as ``rec_err - kl_div``) are reduced to ``mean - variance``, the WAIC score.

    Arguments:
        models: an iterable of trained ensemble models
        data_loader: the pytorch dataloader to receive the data from
        max_n_batches: limit number of batches used in analysis, handy for debugging
        return_slices: whether to aggregate and return the individual scans (should be turned off for large
                       evaluations)
    Returns:
        a ReturnTuple holding, in this order:
            slice_wise_waic_scores: one WAIC score per slice
            slice_wise_is_lesional: True if the slice's segmentation has more than N_ABNORMAL_PIXELS_THRESHOLD_LESIONAL
                                    positive pixels, else False (always False without segmentation)
            slice_wise_scans: list of scan tensors if return_slices is set and slices were seen, otherwise None
            slice_wise_elbo_scores, slice_wise_kl_div, slice_wise_rec_err: per-slice values of the first model
            slice_wise_masks, slice_wise_segmentations (zero tensors when unavailable), slice_wise_reconstructions:
                per-slice tensors of the first model
    """
    LOG.info(f'Getting slice-wise WAIC scores for {data_loader.dataset.name}')
    elbos_per_slice = defaultdict(list)  # running slice position -> one ELBO per ensemble member
    is_lesional_flags, scan_slices = [], []
    elbo_scores, kl_divs, rec_errs = [], [], []
    mask_slices, segmentation_slices, reconstruction_slices = [], [], []

    for model_idx, model in enumerate(models):
        slice_position = 0  # restarts for every member so that slices line up across the ensemble
        batches = yield_inference_batches(
            data_loader,
            model,
            max_n_batches,
            progress_bar_suffix=f'WAIC (ensemble {model_idx})')
        for batch in batches:
            elbos = batch.rec_err - batch.kl_div
            for idx_in_batch, elbo in enumerate(elbos):
                elbos_per_slice[slice_position].append(elbo)
                slice_position += 1
                if model_idx != 0:
                    continue
                # Everything below is collected only once, from the first ensemble member.
                elbo_scores.append(elbos[idx_in_batch])
                kl_divs.append(batch.kl_div[idx_in_batch])
                rec_errs.append(batch.rec_err[idx_in_batch])
                mask_slices.append(batch.mask[idx_in_batch])
                reconstruction_slices.append(batch.reconstruction[idx_in_batch])
                segmentation = batch.segmentation
                if segmentation is None:
                    segmentation_slices.append(torch.zeros_like(batch.scan[idx_in_batch]))
                    is_lesional_flags.append(False)
                else:
                    segmentation_slices.append(segmentation[idx_in_batch])
                    abnormal_pixel_count = float(torch.sum(segmentation[idx_in_batch] > 0))
                    is_lesional_flags.append(abnormal_pixel_count > N_ABNORMAL_PIXELS_THRESHOLD_LESIONAL)
                if return_slices:
                    scan_slices.append(batch.scan[idx_in_batch])

    # WAIC per slice: mean minus variance of the ensemble's ELBO values
    waic_scores = [float(np.mean(slice_elbos)) - float(np.var(slice_elbos))
                   for slice_elbos in elbos_per_slice.values()]

    return ReturnTuple(waic_scores, is_lesional_flags,
                       scan_slices or None, elbo_scores,
                       kl_divs, rec_errs, mask_slices,
                       segmentation_slices, reconstruction_slices)