Example #1
0
def plot_brats_batches(brats_dataloader: DataLoader, plot_n_batches: int,
                       **kwargs) -> None:
    """Plot batches of a BraTS dataloader.

    For each of the first ``plot_n_batches`` batches the scan, the segmentation and the brain mask are concatenated
    along dim 2 and shown as one grid. Pixels where the segmentation / mask is zero are set to -3.5.

    Args:
        brats_dataloader: dataloader whose samples are indexable by the scan, segmentation and mask keys below
        plot_n_batches: maximum number of batches to plot

    Keyword Args:
        nrow: kwarg to change number of rows
        uppercase_keys: if True, uses 'Scan'/'Seg'/'Mask' instead of 'scan'/'seg'/'mask' to support legacy hdf5
            datasets

    All keyword arguments (including the two above) are additionally forwarded to ``imshow_grid``.
    """
    LOG.info('Plotting BraTS2017 Dataset [scan & segmentation]')
    # These only depend on kwargs, so resolve them once instead of once per batch.
    nrow_kwarg = {'nrow': kwargs['nrow']} if 'nrow' in kwargs else {}
    uppercase_keys = kwargs.get('uppercase_keys', False)
    scan_key = 'Scan' if uppercase_keys else 'scan'
    seg_key = 'Seg' if uppercase_keys else 'seg'
    mask_key = 'Mask' if uppercase_keys else 'mask'

    for sample in islice(brats_dataloader, plot_n_batches):
        scan = sample[scan_key]
        background = -3.5 * torch.ones_like(scan)
        # torch.where expects a boolean condition: cast the mask exactly like the segmentation so that masks
        # stored as integer / float tensors work as well.
        mask = torch.where(sample[mask_key].type(torch.BoolTensor),
                           sample[mask_key].type(torch.FloatTensor),
                           background)
        seg = torch.where(sample[seg_key].type(torch.BoolTensor),
                          sample[seg_key].type(torch.FloatTensor),
                          background)
        grid = make_grid(torch.cat((scan.type(torch.FloatTensor),
                                    seg.type(torch.FloatTensor),
                                    mask.type(torch.FloatTensor)),
                                   dim=2),
                         padding=0,
                         **nrow_kwarg)
        imshow_grid(grid,
                    one_channel=True,
                    plt_show=True,
                    axis='off',
                    **kwargs)
        plt.show()
Example #2
0
def visualize_ensemble_residuals(result_tuple: Tuple[BatchInferenceResult],
                                 **kwargs) -> None:
    """Plot the per-pixel mean and standard deviation of the residual maps of an ensemble.

    Every member's residual is background-masked to 0 with the first member's mask. The mean and the (again
    background-masked) standard deviation over the members are concatenated along dim 2 and shown with the
    'afmhot' colormap.

    Args:
        result_tuple: one BatchInferenceResult per ensemble member
        kwargs: forwarded to ``imshow_grid``; any 'cmap' entry is overridden with 'afmhot'
    """
    brain_mask = result_tuple[0].mask
    masked_residuals = [
        mask_background_to_value(member.residual, brain_mask, value=0)
        for member in result_tuple
    ]
    residual_stack = torch.stack(masked_residuals)

    residual_mean = torch.mean(residual_stack, dim=0)
    residual_std = torch.std(residual_stack, dim=0)
    residual_std = mask_background_to_value(residual_std, brain_mask, value=0)

    mean_and_std = torch.cat((residual_mean, residual_std), dim=2)
    kwargs['cmap'] = 'afmhot'
    residual_grid = torchvision.utils.make_grid(mean_and_std,
                                                padding=0,
                                                normalize=False)
    imshow_grid(residual_grid, **kwargs)
def plot_vae_output(decoder_output_batch: Tensor,
                    nrow: int = 8,
                    transpose_grid: bool = False,
                    **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Show a batch of VAE decoder outputs as a single image grid.

    Args:
        decoder_output_batch: batch of decoder outputs, passed to ``torchvision.utils.make_grid``
        nrow: number of images per grid row
        transpose_grid: if True, dims 1 and 2 of the grid are swapped before plotting
        kwargs: forwarded to ``imshow_grid``

    Returns:
        the ``(fig, ax)`` pair returned by ``imshow_grid``
    """
    output_grid = torchvision.utils.make_grid(decoder_output_batch, nrow=nrow, padding=0)
    # Swapping dims 1 and 2 exchanges the rows and columns of the grid.
    displayed_grid = output_grid.transpose(dim0=1, dim1=2) if transpose_grid else output_grid
    figure, axes = imshow_grid(displayed_grid, **kwargs)
    return figure, axes
Example #4
0
def visualize_ensemble_predictions(ensemble_results: Generator[
    Tuple[BatchInferenceResult], None, None], **kwargs) -> None:
    """Plot each ensemble member's reconstruction next to the input scan, then mean/std and residual plots.

    For every tuple of per-member results, the first member's scan and all members' reconstructions are
    background-masked to BACKGROUND_VAL (using the first member's mask), concatenated along dim 2 and shown with
    ``vmin=V_MIN`` / ``vmax=V_MAX``. Afterwards ``visualize_mean_std_prediction`` and
    ``visualize_ensemble_residuals`` are called on the same tuple.

    Args:
        ensemble_results: generator yielding one tuple of BatchInferenceResult per batch (one entry per member)
        kwargs: forwarded to ``imshow_grid`` and to both follow-up plotting functions
    """
    for member_results in ensemble_results:
        reference = member_results[0]
        brain_mask = reference.mask

        panels = [mask_background_to_value(reference.scan, brain_mask, BACKGROUND_VAL)]
        panels.extend(
            mask_background_to_value(member.reconstruction, brain_mask, BACKGROUND_VAL)
            for member in member_results)
        # Scan first, then one reconstruction per member, all joined along dim 2.
        scan_and_reconstructions = torch.cat(panels, dim=2)

        grid = torchvision.utils.make_grid(scan_and_reconstructions,
                                           padding=0,
                                           normalize=False)
        imshow_grid(grid, vmin=V_MIN, vmax=V_MAX, **kwargs)

        visualize_mean_std_prediction(member_results, **kwargs)
        visualize_ensemble_residuals(member_results, **kwargs)
Example #5
0
def plot_camcan_batches(camcan_dataloader: DataLoader, plot_n_batches: int,
                        **kwargs) -> None:
    """Plot batches of a CamCAN dataloader (scans only).

    Args:
        camcan_dataloader: dataloader whose samples are indexable by the scan key below
        plot_n_batches: maximum number of batches to plot

    Keyword Args:
        nrow: kwarg to change number of rows
        uppercase_keys: if True, changes 'scan' to 'Scan' to support legacy hdf5 datasets

    All keyword arguments (including the two above) are additionally forwarded to ``imshow_grid``.
    """
    LOG.info('Plotting CamCAN Dataset [scan only]')
    grid_kwargs = {'nrow': kwargs['nrow']} if 'nrow' in kwargs else {}
    scan_key = 'Scan' if kwargs.get('uppercase_keys', False) else 'scan'

    for batch in islice(camcan_dataloader, plot_n_batches):
        scans = batch[scan_key].type(torch.FloatTensor)
        scan_grid = make_grid(scans, padding=0, **grid_kwargs)
        imshow_grid(scan_grid,
                    one_channel=True,
                    plt_show=True,
                    axis='off',
                    **kwargs)
        plt.show()
Example #6
0
def visualize_mean_std_prediction(result_tuple: Tuple[BatchInferenceResult],
                                  **kwargs) -> None:
    """Plot the per-pixel mean and standard deviation of an ensemble's reconstructions as two separate grids.

    Reconstructions are background-masked to BACKGROUND_VAL with the first member's mask. The mean is shown with
    ``vmin=V_MIN`` / ``vmax=V_MAX``; the standard deviation is background-masked to 0 and shown with the 'hot'
    colormap.

    Args:
        result_tuple: one BatchInferenceResult per ensemble member
        kwargs: forwarded to both ``imshow_grid`` calls ('cmap' is overridden for the std plot)
    """
    brain_mask = result_tuple[0].mask
    reconstructions = torch.stack([
        mask_background_to_value(member.reconstruction, brain_mask, value=BACKGROUND_VAL)
        for member in result_tuple
    ])
    mean_rec = reconstructions.mean(dim=0)
    std_rec = mask_background_to_value(reconstructions.std(dim=0), brain_mask, value=0)

    mean_grid = torchvision.utils.make_grid(mean_rec, padding=0)
    imshow_grid(mean_grid, vmin=V_MIN, vmax=V_MAX, **kwargs)

    kwargs['cmap'] = 'hot'
    std_grid = torchvision.utils.make_grid(std_rec, padding=0)
    imshow_grid(std_grid, **kwargs)
Example #7
0
def plot_most_least_ood(ood_dict: dict,
                        dataset_name: str,
                        n_most: int = 16,
                        do_lesional: bool = True,
                        small_score_is_more_odd: bool = True) -> None:
    """For healthy and lesional samples, plot the ones which are most and least OOD.

    Each slice is normalized to [0, 1] and background-masked with the corresponding entry of
    ``ood_dict[dataset_name]['masks']`` before the ``n_most`` largest- and smallest-scoring slices are arranged in
    grids. ``small_score_is_more_odd`` decides which end of the score range is labelled "Most OOD". The scores of
    the displayed slices are printed before each grid.

    Args:
        ood_dict: dict mapping dataset names to dicts with the keys 'healthy' / 'lesional' (scores),
            'healthy_scans' / 'lesional_scans' and 'masks'
        dataset_name: which dataset entry of ``ood_dict`` to plot
        n_most: number of slices per grid
        do_lesional: if True, also plot the lesional samples
        small_score_is_more_odd: if True, small scores are treated as more OOD
    """
    dataset_ood = ood_dict[dataset_name]

    def create_ood_grids(split: str):
        """Return (most OOD grid, least OOD grid, most OOD scores, least OOD scores) for 'healthy'/'lesional'."""
        scores = dataset_ood[split]
        raw_slices = dataset_ood[f'{split}_scans']
        masks = dataset_ood['masks']

        normalized = [normalize_to_0_1(raw) for raw in raw_slices]
        prepared = [mask_background_to_zero(img, msk) for img, msk in zip(normalized, masks)]

        top_indices = get_indices_of_n_largest_items(scores, n_most)
        bottom_indices = get_indices_of_n_smallest_items(scores, n_most)
        top_scores = [scores[i] for i in top_indices]
        bottom_scores = [scores[i] for i in bottom_indices]

        def to_grid(indices):
            picked = [prepared[i] for i in indices]
            return normalize_to_0_1(
                torchvision.utils.make_grid(picked, padding=0, normalize=False))

        top_grid = to_grid(top_indices)
        bottom_grid = to_grid(bottom_indices)
        if small_score_is_more_odd:
            return bottom_grid, top_grid, bottom_scores, top_scores
        return top_grid, bottom_grid, top_scores, bottom_scores

    LOG.debug('Creating healthy grids...')
    most_healthy, least_healthy, most_healthy_scores, least_healthy_scores = create_ood_grids('healthy')
    panels = [
        (most_healthy_scores, most_healthy, f'Most OOD Healthy {dataset_name}'),
        (least_healthy_scores, least_healthy, f'Least OOD Healthy {dataset_name}'),
    ]
    if do_lesional:
        LOG.debug('Creating lesional grids...')
        most_lesional, least_lesional, most_lesional_scores, least_lesional_scores = create_ood_grids('lesional')
        panels.append((most_lesional_scores, most_lesional, f'Most OOD Lesional {dataset_name}'))
        panels.append((least_lesional_scores, least_lesional, f'Least OOD Lesional {dataset_name}'))

    for scores_to_print, grid, title in panels:
        print(scores_to_print)
        imshow_grid(grid,
                    one_channel=True,
                    figsize=(12, 8),
                    title=title,
                    axis='off')
Example #8
0
def _image_batch_like(scans, n_images: int) -> torch.Tensor:
    """Return zeros of shape (n_images, 1, H, W), with H and W taken from the last two dims of ``scans[0]``.

    Falls back to the previously hard-coded 128 x 128 when ``scans`` is empty.
    """
    height_width = tuple(scans[0].shape[-2:]) if len(scans) > 0 else (128, 128)
    return torch.zeros(size=[n_images, 1, *height_width])


def plot_ood_samples_over_range(metrics_ood_dict: dict, dataset_name: str,
                                mode: str, stat_type: str, start_val: float,
                                end_val: float, n_values: int,
                                **plt_kwargs) -> None:
    """Plot scans and reconstructions whose OOD value is closest to evenly spaced reference values.

    ``n_values`` reference values are spaced linearly from ``start_val`` to ``end_val``. For each reference value the
    healthy and the lesional sample with the closest OOD value are picked. Four grids are shown (healthy scans,
    healthy reconstructions, lesional scans, lesional reconstructions). Each grid is filled from top left to bottom
    right, row by row, with increasing reference value.

    Args:
        metrics_ood_dict: nested dict in the format score -> dataset -> ... For the 'dose' modes the values are read
            from ``[f'{mode}_healthy'][stat_type]`` / ``[f'{mode}_lesional'][stat_type]``, for 'waic' from
            ``['healthy']`` / ``['lesional']``. Images are read from 'healthy_scans', 'lesional_scans',
            'healthy_reconstructions' and 'lesional_reconstructions'.
        dataset_name: key of the dataset inside ``metrics_ood_dict[main_mode]``
        mode: one of 'dose_kde', 'dose_stat' or 'waic'
        stat_type: statistic to use for the 'dose' modes
        start_val: first reference value
        end_val: last reference value
        n_values: number of reference values, i.e. images per grid
        plt_kwargs: forwarded to ``imshow_grid``; 'nrow' (default 8) also sets the number of images per grid row

    Raises:
        ValueError: if ``mode`` is not supported
        KeyError: (re-raised) if ``metrics_ood_dict`` lacks one of the required keys
    """
    supported_modes = ('dose_kde', 'dose_stat', 'waic')
    if mode not in supported_modes:
        raise ValueError(
            f'Unsupported mode "{mode}". Choose one of {", ".join(supported_modes)}.')
    main_mode = 'dose' if 'dose' in mode else 'waic'

    try:
        ood_dict = metrics_ood_dict[main_mode][dataset_name]
        if main_mode == 'dose':
            healthy_values = ood_dict[f'{mode}_healthy'][stat_type]
            lesional_values = ood_dict[f'{mode}_lesional'][stat_type]
        else:
            healthy_values = ood_dict['healthy']
            lesional_values = ood_dict['lesional']
        healthy_scans = ood_dict['healthy_scans']
        lesional_scans = ood_dict['lesional_scans']
        healthy_recs = ood_dict['healthy_reconstructions']
        lesional_recs = ood_dict['lesional_reconstructions']
    except KeyError:
        print(
            'The metrics_ood_dict does not have the correct keys to generate your plot. '
            f'Given:\n{print_dict_tree(metrics_ood_dict)}')
        raise

    # 'waic' has no statistic type, so don't append one (previously left a dangling space in titles).
    mode_description = f'{main_mode} {stat_type}' if main_mode == 'dose' else main_mode
    value_range = f'[{start_val:.1f}, {end_val:.1f}]'
    LOG.info(
        f'Plotting scans with {mode_description} from {start_val:.1f}-{end_val:.1f}.')
    LOG.info(
        f'Min/max {mode_description} from all healthy scans is: {min(healthy_values):.1f}/{max(healthy_values):.1f}'
    )
    LOG.info(
        f'Min/max {mode_description} from all lesional scans is: {min(lesional_values):.1f}/{max(lesional_values):.1f}'
    )

    # Single-channel tensors to be filled with the picked images; spatial size is taken from the data.
    healthy_img_batch = _image_batch_like(healthy_scans, n_values)
    lesional_img_batch = _image_batch_like(lesional_scans, n_values)
    healthy_img_rec_batch = _image_batch_like(healthy_recs, n_values)
    lesional_img_rec_batch = _image_batch_like(lesional_recs, n_values)

    ref_values = np.linspace(start_val, end_val, n_values)
    picked_healthy_values = []
    picked_lesional_values = []
    for img_idx, ref_val in enumerate(ref_values):
        closest_healthy_idx = get_idx_of_closest_value(healthy_values, value=ref_val)
        closest_lesional_idx = get_idx_of_closest_value(lesional_values, value=ref_val)

        healthy_img_batch[img_idx] = healthy_scans[closest_healthy_idx]
        healthy_img_rec_batch[img_idx] = healthy_recs[closest_healthy_idx]
        lesional_img_batch[img_idx] = lesional_scans[closest_lesional_idx]
        lesional_img_rec_batch[img_idx] = lesional_recs[closest_lesional_idx]
        picked_healthy_values.append(healthy_values[closest_healthy_idx])
        picked_lesional_values.append(lesional_values[closest_lesional_idx])

    nrow = plt_kwargs.get('nrow', 8)
    batches_and_titles = (
        (healthy_img_batch, f'Healthy {mode_description} {value_range}'),
        (healthy_img_rec_batch, f'Healthy reconstructions {mode_description} {value_range}'),
        (lesional_img_batch, f'Lesional {mode_description} {value_range}'),
        (lesional_img_rec_batch, f'Lesional reconstructions {mode_description} {value_range}'),
    )
    for img_batch, title in batches_and_titles:
        grid = torchvision.utils.make_grid(img_batch, nrow=nrow)
        imshow_grid(grid, title=title, **plt_kwargs)

    LOG.info(f'Healthy values:\n{picked_healthy_values}')
    LOG.info(f'Lesional values:\n{picked_lesional_values}')
Example #9
0
def plot_samples_close_to_score(ood_dict: dict,
                                dataset_name: str,
                                min_score: float,
                                max_score: float,
                                n: int = 32,
                                do_lesional: bool = True,
                                show_ground_truth: bool = False,
                                print_score: bool = False) -> None:
    """Arrange slices in a grid such that each slice displayed is closest to the interpolation OOD score from
    linspace which goes from min_score to max_score with n samples.

    Args:
        ood_dict: dict mapping dataset names to dicts with the keys 'healthy' / 'lesional' (scores),
            'healthy_scans' / 'lesional_scans', 'masks' and, only if ``show_ground_truth`` is set,
            'healthy_segmentations' / 'lesional_segmentations'
        dataset_name: which dataset entry of ``ood_dict`` to plot
        min_score: first reference score
        max_score: last reference score
        n: number of reference scores / slices per grid
        do_lesional: if True, also plot the lesional samples
        show_ground_truth: if True, additionally plot the segmentations of the picked slices
        print_score: if True, log the scores of the picked slices
    """
    ood_dict = ood_dict[dataset_name]
    ref_scores = np.linspace(min_score, max_score, n)

    def create_ood_grids(healthy_lesional: str):
        """Return (slices_grid, segmentations_grid) for the 'healthy' or 'lesional' slices closest to ref_scores.

        ``segmentations_grid`` is None unless ``show_ground_truth`` is set.
        """
        scores = ood_dict[healthy_lesional]
        slices = ood_dict[f'{healthy_lesional}_scans']
        masks = ood_dict['masks']
        # Only look up segmentations when they are plotted, so dicts without stored segmentations still work.
        segmentations = (ood_dict[f'{healthy_lesional}_segmentations']
                         if show_ground_truth else None)
        final_scores = []
        final_slices = []
        final_masks = []
        final_segmentations = []

        for ref_score in ref_scores:
            scores_idx = get_idx_of_closest_value(scores, ref_score)
            final_scores.append(scores[scores_idx])
            final_slices.append(slices[scores_idx])
            final_masks.append(masks[scores_idx])
            if show_ground_truth:
                final_segmentations.append(segmentations[scores_idx])

        final_slices = [normalize_to_0_1(s) for s in final_slices]
        final_slices = [
            mask_background_to_zero(s, m)
            for s, m in zip(final_slices, final_masks)
        ]

        slices_grid = torchvision.utils.make_grid(final_slices,
                                                  padding=0,
                                                  normalize=False)
        segmentations_grid = None
        if show_ground_truth:
            segmentations_grid = torchvision.utils.make_grid(
                final_segmentations, padding=0, normalize=False)
        if print_score:
            formatted_scores = [f'{val:.2f}' for val in final_scores]
            LOG.info(f'Scores: {formatted_scores}')
        return slices_grid, segmentations_grid

    healthy_slices_grid, healthy_segmentations_grid = create_ood_grids(
        'healthy')
    imshow_grid(healthy_slices_grid,
                one_channel=True,
                figsize=(12, 8),
                title=f'Healthy {dataset_name} {min_score}-{max_score}',
                axis='off')
    if show_ground_truth:
        imshow_grid(
            healthy_segmentations_grid,
            one_channel=True,
            figsize=(12, 8),
            title=
            f'Healthy Ground Truth {dataset_name} {min_score}-{max_score}',
            axis='off')
    if do_lesional:
        lesional_slices_grid, lesional_segmentations_grid = create_ood_grids(
            'lesional')
        imshow_grid(lesional_slices_grid,
                    one_channel=True,
                    figsize=(12, 8),
                    title=f'Lesional {dataset_name} {min_score}-{max_score}',
                    axis='off')
        if show_ground_truth:
            imshow_grid(
                lesional_segmentations_grid,
                one_channel=True,
                figsize=(12, 8),
                title=
                f'Lesional Ground Truth {dataset_name} {min_score}-{max_score}',
                axis='off')
Example #10
0
def plot_stacked_scan_reconstruction_batches(batch_generator: Generator[
    BatchInferenceResult, None, None],
                                             plot_n_batches: int = 3,
                                             nrow: int = 8,
                                             show_mask: bool = False,
                                             save_dir_path: Path = None,
                                             cut_tensor_to: int = None,
                                             mask_background: bool = True,
                                             close_fig: bool = False,
                                             **kwargs) -> None:
    """Plot the scan and reconstruction batches. Horizontally aligned are the samples from one batch.
    Vertically aligned are the input image, the ground truth segmentation (if the batch has one), the
    reconstruction, the residual image and the residual with applied threshold (plus the brain mask if
    ``show_mask``).

    Scan and reconstruction are normalized to [0, 1] with the scan's min/max; the residual is normalized on its own.

    Args:
        batch_generator: generator yielding BatchInferenceResult objects
        plot_n_batches: limit plotting to this amount of batches
        nrow: numbers of samples in one row, default is 8
        show_mask: additionally plots the brain mask
        save_dir_path: path to directory in which to store the resulting plots - will be created (including missing
            parent directories) if not existent
        cut_tensor_to: choose 8 if you need only the first 8 samples for plotting but the original tensor is way larger
        mask_background: if True, sets the background of scan, reconstruction, residual and thresholded residual to 0
        close_fig: if True, figures are closed instead of shown (so they will not show in a notebook), handy for use
            in a large pipeline; they are still saved if ``save_dir_path`` is given
        kwargs: additional keyword arguments for ``imshow_grid``
    """
    if save_dir_path is not None:
        save_dir_path.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        for batch_idx, batch in enumerate(
                itertools.islice(batch_generator, plot_n_batches)):
            mask = batch.mask
            max_val = torch.max(batch.scan)
            min_val = torch.min(batch.scan)

            # Scan and reconstruction share the scan's value range so their intensities stay comparable.
            scan = normalize_to_0_1(batch.scan, min_val, max_val)
            reconstruction = normalize_to_0_1(batch.reconstruction, min_val,
                                              max_val)
            residual = normalize_to_0_1(batch.residual)
            thresholded = batch.residuals_thresholded

            if mask_background:
                scan = mask_background_to_zero(scan, mask)
                reconstruction = mask_background_to_zero(reconstruction, mask)
                residual = mask_background_to_zero(residual, mask)
                thresholded = mask_background_to_zero(thresholded, mask)

            if batch.segmentation is not None:
                seg = mask_background_to_zero(batch.segmentation, mask)
                stacked = torch.cat(
                    (scan, seg, reconstruction, residual, thresholded), dim=2)
            else:
                stacked = torch.cat(
                    (scan, reconstruction, residual, thresholded), dim=2)
            if show_mask:
                stacked = torch.cat((stacked, mask.type(torch.FloatTensor)),
                                    dim=2)
            if cut_tensor_to is not None:
                stacked = stacked[:cut_tensor_to, ...]
            grid = torchvision.utils.make_grid(stacked,
                                               padding=0,
                                               normalize=False,
                                               nrow=nrow)
            describe = scipy.stats.describe(grid.numpy().flatten())
            print_scipy_stats_description(describe, 'normalized_grid')
            fig, ax = imshow_grid(grid,
                                  one_channel=True,
                                  vmax=1.0,
                                  vmin=0.0,
                                  plt_show=False,
                                  **kwargs)
            ax.set_axis_off()
            # Save before showing: depending on the backend, plt.show() may close the figure.
            if save_dir_path is not None:
                save_fig(fig, save_dir_path / f'batch_{batch_idx}.png')
            if close_fig:
                # Closing (instead of showing) keeps the figure out of notebooks and frees it in long pipelines.
                plt.close(fig)
            else:
                plt.show()