def plot_brats_batches(brats_dataloader: DataLoader, plot_n_batches: int, **kwargs) -> None:
    """Plot batches of a BraTS dataloader.

    For every batch, the scans, segmentations and brain masks are concatenated along dim 2 and
    displayed as one grid. Pixels outside the segmentation / mask are painted with -3.5.

    Keyword Args:
        nrow: kwarg to change number of rows
        uppercase_keys: if True, uses the keys 'Scan', 'Seg' and 'Mask' instead of 'scan', 'seg'
            and 'mask' to support legacy hdf5 datasets

    All keyword arguments (including the two above) are also forwarded to ``imshow_grid``.
    """
    LOG.info('Plotting BraTS2017 Dataset [scan & segmentation]')
    # These depend only on kwargs, so compute them once instead of once per batch.
    nrow_kwarg = {'nrow': kwargs['nrow']} if 'nrow' in kwargs else {}
    uppercase_keys = kwargs.get('uppercase_keys', False)
    scan_key = 'Scan' if uppercase_keys else 'scan'
    seg_key = 'Seg' if uppercase_keys else 'seg'
    mask_key = 'Mask' if uppercase_keys else 'mask'
    for sample in islice(brats_dataloader, plot_n_batches):
        background = -3.5 * torch.ones_like(sample[scan_key])
        # torch.where needs a boolean condition: cast the mask exactly like the segmentation so
        # that uint8 / float masks work too (a no-op for masks that already are boolean).
        mask = torch.where(sample[mask_key].type(torch.BoolTensor),
                           sample[mask_key].type(torch.FloatTensor),
                           background)
        seg = torch.where(sample[seg_key].type(torch.BoolTensor),
                          sample[seg_key].type(torch.FloatTensor),
                          background)
        grid = make_grid(torch.cat((sample[scan_key].type(torch.FloatTensor),
                                    seg.type(torch.FloatTensor),
                                    mask.type(torch.FloatTensor)),
                                   dim=2),
                         padding=0,
                         **nrow_kwarg)
        # NOTE(review): 'nrow' / 'uppercase_keys' are forwarded to imshow_grid as well — confirm it accepts them.
        imshow_grid(grid, one_channel=True, plt_show=True, axis='off', **kwargs)
        plt.show()
def visualize_ensemble_residuals(result_tuple: Tuple[BatchInferenceResult], **kwargs) -> None:
    """Plot the mean and the standard deviation of the residual maps of an ensemble.

    Every residual is set to 0 outside the mask of the first result before the statistics are
    computed; the standard-deviation map is masked the same way afterwards. Mean and std are
    concatenated along dim 2 and shown as a single grid with the 'afmhot' colormap, which
    overrides any ``cmap`` passed in kwargs.

    Args:
        result_tuple: one BatchInferenceResult per ensemble member; the mask of the first entry
            is used for all members.
        kwargs: forwarded to ``imshow_grid``.
    """
    shared_mask = result_tuple[0].mask
    masked_residuals = [
        mask_background_to_value(member.residual, shared_mask, value=0)
        for member in result_tuple
    ]
    residual_stack = torch.stack(masked_residuals)
    residual_mean = residual_stack.mean(dim=0)
    residual_std = mask_background_to_value(residual_stack.std(dim=0), shared_mask, value=0)
    combined = torch.cat((residual_mean, residual_std), dim=2)
    residual_grid = torchvision.utils.make_grid(combined, padding=0, normalize=False)
    plot_kwargs = {**kwargs, 'cmap': 'afmhot'}
    imshow_grid(residual_grid, **plot_kwargs)
def plot_vae_output(decoder_output_batch: Tensor,
                    nrow: int = 8,
                    transpose_grid: bool = False,
                    **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Arrange a batch of VAE decoder outputs in a grid and plot it.

    Args:
        decoder_output_batch: batch of decoder outputs passed to ``make_grid``.
        nrow: number of images per grid row.
        transpose_grid: if True, swap dims 1 and 2 of the grid, i.e. rows become columns.
        kwargs: forwarded to ``imshow_grid``.

    Returns:
        The figure and axes returned by ``imshow_grid``.
    """
    output_grid = torchvision.utils.make_grid(decoder_output_batch, nrow=nrow, padding=0)
    # swap rows and columns of the grid
    grid_to_show = output_grid.transpose(dim0=1, dim1=2) if transpose_grid else output_grid
    figure, axes = imshow_grid(grid_to_show, **kwargs)
    return figure, axes
def visualize_ensemble_predictions(ensemble_results: Generator[Tuple[BatchInferenceResult], None, None],
                                   **kwargs) -> None:
    """Plot ensemble reconstructions batch by batch.

    For every yielded tuple this shows the scan followed by the reconstruction of every member
    (all concatenated along dim 2, background set to BACKGROUND_VAL, value range
    V_MIN..V_MAX). It then plots the mean/std of the reconstructions and the mean/std of the
    residuals.

    Args:
        ensemble_results: yields one tuple of BatchInferenceResult (one per ensemble member) per
            batch; mask and scan of the first member are used as reference.
        kwargs: forwarded to ``imshow_grid`` and to the two follow-up plotting functions.
    """
    for members in ensemble_results:
        reference = members[0]
        mask = reference.mask
        masked_scan = mask_background_to_value(reference.scan, mask, BACKGROUND_VAL)
        masked_reconstructions = [
            mask_background_to_value(member.reconstruction, mask, BACKGROUND_VAL)
            for member in members
        ]
        # scan first, then every member's reconstruction, joined along dim 2
        column_stack = torch.cat([masked_scan, *masked_reconstructions], dim=2)
        comparison_grid = torchvision.utils.make_grid(column_stack, padding=0, normalize=False)
        imshow_grid(comparison_grid, vmin=V_MIN, vmax=V_MAX, **kwargs)
        visualize_mean_std_prediction(members, **kwargs)
        visualize_ensemble_residuals(members, **kwargs)
def plot_camcan_batches(camcan_dataloader: DataLoader, plot_n_batches: int, **kwargs) -> None:
    """Plot the scans of the first ``plot_n_batches`` batches of a CamCAN dataloader.

    Keyword Args:
        nrow: kwarg to change number of rows
        uppercase_keys: if True, reads the scans from the key 'Scan' instead of 'scan' to support
            legacy hdf5 datasets

    All keyword arguments are also forwarded to ``imshow_grid``.
    """
    LOG.info('Plotting CamCAN Dataset [scan only]')
    grid_kwargs = {'nrow': kwargs['nrow']} if 'nrow' in kwargs else {}
    scan_key = 'Scan' if kwargs.get('uppercase_keys', False) else 'scan'
    for batch in islice(camcan_dataloader, plot_n_batches):
        scans = batch[scan_key].type(torch.FloatTensor)
        scan_grid = make_grid(scans, padding=0, **grid_kwargs)
        imshow_grid(scan_grid, one_channel=True, plt_show=True, axis='off', **kwargs)
        plt.show()
def visualize_mean_std_prediction(result_tuple: Tuple[BatchInferenceResult], **kwargs) -> None:
    """Plot the mean and the standard deviation of an ensemble's reconstructions.

    Reconstructions are set to BACKGROUND_VAL outside the first result's mask before
    aggregation. The mean is plotted with value range V_MIN..V_MAX. The std map, masked to 0
    outside the mask, is plotted in a second figure with the 'hot' colormap.

    Args:
        result_tuple: one BatchInferenceResult per ensemble member.
        kwargs: forwarded to both ``imshow_grid`` calls (``cmap`` is replaced for the std plot).
    """
    shared_mask = result_tuple[0].mask
    reconstruction_stack = torch.stack([
        mask_background_to_value(member.reconstruction, shared_mask, value=BACKGROUND_VAL)
        for member in result_tuple
    ])
    reconstruction_mean = reconstruction_stack.mean(dim=0)
    reconstruction_std = mask_background_to_value(reconstruction_stack.std(dim=0), shared_mask, value=0)
    # mean
    imshow_grid(torchvision.utils.make_grid(reconstruction_mean, padding=0), **kwargs, vmin=V_MIN, vmax=V_MAX)
    # standard deviation
    std_kwargs = {**kwargs, 'cmap': 'hot'}
    imshow_grid(torchvision.utils.make_grid(reconstruction_std, padding=0), **std_kwargs)
def plot_most_least_ood(ood_dict: dict,
                        dataset_name: str,
                        n_most: int = 16,
                        do_lesional: bool = True,
                        small_score_is_more_odd: bool = True) -> None:
    """For healthy and lesional samples, plot the ones which are most and least OOD.

    Args:
        ood_dict: maps ``dataset_name`` to a dict holding the scores ('healthy' / 'lesional'), the
            slices ('healthy_scans' / 'lesional_scans', indexed like the scores) and 'masks'.
        dataset_name: selects the dataset inside ood_dict; also used in the plot titles.
        n_most: number of slices shown per grid.
        do_lesional: if True, the lesional slices are plotted as well.
        small_score_is_more_odd: if True, the smallest scores count as most OOD, otherwise the
            largest ones do.
    """
    dataset_dict = ood_dict[dataset_name]

    def create_ood_grids(healthy_lesional: str):
        """Return (most_ood_grid, least_ood_grid, most_ood_scores, least_ood_scores)."""
        scores = dataset_dict[healthy_lesional]
        slices = dataset_dict[f'{healthy_lesional}_scans']
        # NOTE(review): the same 'masks' entry is used for healthy and lesional slices — confirm.
        masks = dataset_dict['masks']

        def to_grid(indices):
            # Normalize and mask only the slices that are displayed instead of every slice of the
            # dataset; each slice is processed independently, so the picture is unchanged.
            prepared = [mask_background_to_zero(normalize_to_0_1(slices[idx]), masks[idx]) for idx in indices]
            return normalize_to_0_1(torchvision.utils.make_grid(prepared, padding=0, normalize=False))

        largest_score_indices = get_indices_of_n_largest_items(scores, n_most)
        smallest_score_indices = get_indices_of_n_smallest_items(scores, n_most)
        largest_scores = [scores[idx] for idx in largest_score_indices]
        smallest_scores = [scores[idx] for idx in smallest_score_indices]
        largest_values_grid = to_grid(largest_score_indices)
        smallest_values_grid = to_grid(smallest_score_indices)
        if small_score_is_more_odd:
            return smallest_values_grid, largest_values_grid, smallest_scores, largest_scores
        return largest_values_grid, smallest_values_grid, largest_scores, smallest_scores

    LOG.debug('Creating healthy grids...')
    (most_ood_healthy_grid, least_ood_healthy_grid,
     most_ood_scores_healthy, least_ood_scores_healthy) = create_ood_grids('healthy')
    if do_lesional:
        LOG.debug('Creating lesional grids...')
        (most_ood_lesional_grid, least_ood_lesional_grid,
         most_ood_scores_lesional, least_ood_scores_lesional) = create_ood_grids('lesional')

    print(most_ood_scores_healthy)
    imshow_grid(most_ood_healthy_grid, one_channel=True, figsize=(12, 8),
                title=f'Most OOD Healthy {dataset_name}', axis='off')
    print(least_ood_scores_healthy)
    imshow_grid(least_ood_healthy_grid, one_channel=True, figsize=(12, 8),
                title=f'Least OOD Healthy {dataset_name}', axis='off')
    if do_lesional:
        print(most_ood_scores_lesional)
        imshow_grid(most_ood_lesional_grid, one_channel=True, figsize=(12, 8),
                    title=f'Most OOD Lesional {dataset_name}', axis='off')
        print(least_ood_scores_lesional)
        imshow_grid(least_ood_lesional_grid, one_channel=True, figsize=(12, 8),
                    title=f'Least OOD Lesional {dataset_name}', axis='off')
def plot_ood_samples_over_range(metrics_ood_dict: dict, dataset_name: str, mode: str, stat_type: str,
                                start_val: float, end_val: float, n_values: int, **plt_kwargs) -> None:
    """Plot scans and reconstructions whose OOD value is closest to evenly spaced reference values.

    ``n_values`` reference values are spaced linearly from ``start_val`` to ``end_val``. For each
    one, the healthy and the lesional scan with the closest value are picked, together with their
    reconstructions. The grids are filled from top left to bottom right, row by row, with
    increasing reference value.

    Args:
        metrics_ood_dict: nested dict in the format score -> dataset -> ...; the score key is
            'dose' for the modes 'dose_kde' / 'dose_stat' and 'waic' for the mode 'waic'. The
            dataset dict needs 'healthy_scans', 'lesional_scans', 'healthy_reconstructions',
            'lesional_reconstructions' and the values: '{mode}_healthy'[stat_type] /
            '{mode}_lesional'[stat_type] for dose, 'healthy' / 'lesional' for waic.
        dataset_name: dataset key inside the score dict.
        mode: one of 'dose_kde', 'dose_stat', 'waic'.
        stat_type: statistic used for the dose modes.
        start_val: first reference value.
        end_val: last reference value.
        n_values: number of reference values, i.e. images per grid.
        plt_kwargs: forwarded to ``imshow_grid``; 'nrow' (default 8) also sets the grid width.

    Raises:
        ValueError: if ``mode`` is not supported.
        KeyError: if ``metrics_ood_dict`` lacks a required key.
    """
    supported_modes = ['dose_kde', 'dose_stat', 'waic']
    if mode not in supported_modes:
        raise ValueError(f'Choose the mode such that it is one of {supported_modes}, got "{mode}".')
    main_mode = 'dose' if 'dose' in mode else 'waic'

    try:
        ood_dict = metrics_ood_dict[main_mode][dataset_name]
        if main_mode == 'dose':
            healthy_values = ood_dict[f'{mode}_healthy'][stat_type]
            lesional_values = ood_dict[f'{mode}_lesional'][stat_type]
        else:
            healthy_values = ood_dict['healthy']
            lesional_values = ood_dict['lesional']
        healthy_scans = ood_dict['healthy_scans']
        lesional_scans = ood_dict['lesional_scans']
        healthy_recs = ood_dict['healthy_reconstructions']
        lesional_recs = ood_dict['lesional_reconstructions']
    except KeyError:
        print('The metrics_ood_dict does not have the correct keys to generate your plot. '
              f'Given:\n{print_dict_tree(metrics_ood_dict)}')
        raise

    mode_description = f'{main_mode} {stat_type if main_mode == "dose" else ""}'
    LOG.info(f'Plotting scans with {mode_description} from {start_val:.1f}-{end_val:.1f}.')
    LOG.info(
        f'Min/max {mode_description} from all healthy scans is: {min(healthy_values):.1f}/{max(healthy_values):.1f}'
    )
    LOG.info(
        f'Min/max {mode_description} from all lesional scans is: {min(lesional_values):.1f}/{max(lesional_values):.1f}'
    )

    def empty_batch(images):
        # Take the spatial size (last two dims) from the data instead of assuming 128x128 images.
        return torch.zeros(size=[n_values, 1, *torch.as_tensor(images[0]).shape[-2:]])

    # Will fill up tensors with healthy and lesional images and their reconstructions
    healthy_img_batch = empty_batch(healthy_scans)
    lesional_img_batch = empty_batch(lesional_scans)
    healthy_img_rec_batch = empty_batch(healthy_recs)
    lesional_img_rec_batch = empty_batch(lesional_recs)

    # Define reference values and collect the actual values of the picked images
    ref_values = np.linspace(start_val, end_val, n_values)
    picked_healthy_values = []
    picked_lesional_values = []
    for img_idx, ref_val in enumerate(ref_values):
        closest_healthy_idx = get_idx_of_closest_value(healthy_values, value=ref_val)
        closest_lesional_idx = get_idx_of_closest_value(lesional_values, value=ref_val)
        healthy_img_batch[img_idx] = healthy_scans[closest_healthy_idx]
        lesional_img_batch[img_idx] = lesional_scans[closest_lesional_idx]
        healthy_img_rec_batch[img_idx] = healthy_recs[closest_healthy_idx]
        lesional_img_rec_batch[img_idx] = lesional_recs[closest_lesional_idx]
        picked_healthy_values.append(healthy_values[closest_healthy_idx])
        picked_lesional_values.append(lesional_values[closest_lesional_idx])

    nrow = plt_kwargs.get('nrow', 8)
    value_range = f'[{start_val:.1f}, {end_val:.1f}]'
    labelled_batches = (
        ('Healthy', healthy_img_batch),
        ('Healthy reconstructions', healthy_img_rec_batch),
        ('Lesional', lesional_img_batch),
        ('Lesional reconstructions', lesional_img_rec_batch),
    )
    for label, img_batch in labelled_batches:
        grid = torchvision.utils.make_grid(img_batch, nrow=nrow)
        imshow_grid(grid, title=f'{label} {mode_description} {value_range}', **plt_kwargs)

    LOG.info(f'Healthy values:\n{picked_healthy_values}')
    LOG.info(f'Lesional values:\n{picked_lesional_values}')
def plot_samples_close_to_score(ood_dict: dict,
                                dataset_name: str,
                                min_score: float,
                                max_score: float,
                                n: int = 32,
                                do_lesional: bool = True,
                                show_ground_truth: bool = False,
                                print_score: bool = False) -> None:
    """Arrange slices in a grid such that each slice displayed is closest to the interpolation
    OOD score from linspace which goes from min_score to max_score with n samples.

    Args:
        ood_dict: maps ``dataset_name`` to a dict holding the scores ('healthy' / 'lesional'),
            slices ('<group>_scans'), 'masks' and, only needed when ``show_ground_truth`` is set,
            '<group>_segmentations'.
        dataset_name: selects the dataset inside ood_dict; also used in the plot titles.
        min_score: first reference score.
        max_score: last reference score.
        n: number of reference scores / slices per grid.
        do_lesional: if True, the lesional slices are plotted as well.
        show_ground_truth: if True, the segmentations of the picked slices are plotted too.
        print_score: if True, the scores of the picked slices are logged.
    """
    dataset_dict = ood_dict[dataset_name]
    ref_scores = np.linspace(min_score, max_score, n)

    def create_ood_grids(healthy_lesional: str):
        """Return (slices_grid, segmentations_grid); the latter is None without ground truth."""
        scores = dataset_dict[healthy_lesional]
        slices = dataset_dict[f'{healthy_lesional}_scans']
        masks = dataset_dict['masks']
        # Only look up the segmentations when they are plotted, so datasets without them work.
        segmentations = dataset_dict[f'{healthy_lesional}_segmentations'] if show_ground_truth else None

        final_scores = []
        final_slices = []
        final_masks = []
        final_segmentations = []
        for ref_score in ref_scores:
            scores_idx = get_idx_of_closest_value(scores, ref_score)
            final_scores.append(scores[scores_idx])
            final_slices.append(slices[scores_idx])
            final_masks.append(masks[scores_idx])
            if show_ground_truth:
                final_segmentations.append(segmentations[scores_idx])

        final_slices = [normalize_to_0_1(s) for s in final_slices]
        final_slices = [mask_background_to_zero(s, m) for s, m in zip(final_slices, final_masks)]
        slices_grid = torchvision.utils.make_grid(final_slices, padding=0, normalize=False)
        segmentations_grid = None
        if show_ground_truth:
            segmentations_grid = torchvision.utils.make_grid(final_segmentations, padding=0, normalize=False)
        if print_score:
            formatted_scores = [f'{val:.2f}' for val in final_scores]
            LOG.info(f'Scores: {formatted_scores}')
        return slices_grid, segmentations_grid

    healthy_slices_grid, healthy_segmentations_grid = create_ood_grids('healthy')
    imshow_grid(healthy_slices_grid, one_channel=True, figsize=(12, 8),
                title=f'Healthy {dataset_name} {min_score}-{max_score}', axis='off')
    if show_ground_truth:
        imshow_grid(healthy_segmentations_grid, one_channel=True, figsize=(12, 8),
                    title=f'Healthy Ground Truth {dataset_name} {min_score}-{max_score}', axis='off')
    if do_lesional:
        lesional_slices_grid, lesional_segmentations_grid = create_ood_grids('lesional')
        imshow_grid(lesional_slices_grid, one_channel=True, figsize=(12, 8),
                    title=f'Lesional {dataset_name} {min_score}-{max_score}', axis='off')
        if show_ground_truth:
            imshow_grid(lesional_segmentations_grid, one_channel=True, figsize=(12, 8),
                        title=f'Lesional Ground Truth {dataset_name} {min_score}-{max_score}', axis='off')
def plot_stacked_scan_reconstruction_batches(batch_generator: Generator[BatchInferenceResult, None, None],
                                             plot_n_batches: int = 3,
                                             nrow: int = 8,
                                             show_mask: bool = False,
                                             save_dir_path: Path = None,
                                             cut_tensor_to: int = None,
                                             mask_background: bool = True,
                                             close_fig: bool = False,
                                             **kwargs) -> None:
    """Plot the scan and reconstruction batches.

    Horizontally aligned are the samples from one batch. Vertically aligned (concatenated along
    dim 2) are the input image, the ground truth segmentation (only if the batch has one), the
    reconstruction, the residual image, the residual with applied threshold and, if requested,
    the brain mask.

    Args:
        batch_generator: generator yielding BatchInferenceResult objects
        plot_n_batches: limit plotting to this amount of batches
        nrow: numbers of samples in one row, default is 8
        show_mask: also plots the brain mask
        save_dir_path: path to directory in which to store the resulting plots as
            'batch_<idx>.png' - will be created (including missing parents) if not existent
        cut_tensor_to: keep only the first samples of each batch, e.g. 8 if you need only the first
            8 samples for plotting but the original tensor is way larger
        mask_background: if True, scan, reconstruction, residual and thresholded residual are set
            to zero outside the mask (the segmentation is always masked)
        close_fig: if True, the figure is closed without being shown (will not show in notebook),
            handy for use in large pipeline; it is still saved if save_dir_path is given
        kwargs: additional keyword arguments for plotting functions
    """
    if save_dir_path is not None:
        # parents=True: the docstring promises creation even for nested, non-existing paths
        save_dir_path.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        for batch_idx, batch in enumerate(itertools.islice(batch_generator, plot_n_batches)):
            mask = batch.mask
            # Scan and reconstruction are normalized with the scan's range so they stay comparable.
            max_val = torch.max(batch.scan)
            min_val = torch.min(batch.scan)
            scan = normalize_to_0_1(batch.scan, min_val, max_val)
            reconstruction = normalize_to_0_1(batch.reconstruction, min_val, max_val)
            residual = normalize_to_0_1(batch.residual)
            thresholded = batch.residuals_thresholded
            if mask_background:
                scan = mask_background_to_zero(scan, mask)
                reconstruction = mask_background_to_zero(reconstruction, mask)
                residual = mask_background_to_zero(residual, mask)
                thresholded = mask_background_to_zero(thresholded, mask)
            if batch.segmentation is not None:
                seg = mask_background_to_zero(batch.segmentation, mask)
                stacked = torch.cat((scan, seg, reconstruction, residual, thresholded), dim=2)
            else:
                stacked = torch.cat((scan, reconstruction, residual, thresholded), dim=2)
            if show_mask:
                stacked = torch.cat((stacked, mask.type(torch.FloatTensor)), dim=2)
            if cut_tensor_to is not None:
                stacked = stacked[:cut_tensor_to, ...]
            grid = torchvision.utils.make_grid(stacked, padding=0, normalize=False, nrow=nrow)
            describe = scipy.stats.describe(grid.numpy().flatten())
            print_scipy_stats_description(describe, 'normalized_grid')
            fig, ax = imshow_grid(grid, one_channel=True, vmax=1.0, vmin=0.0, plt_show=False, **kwargs)
            ax.set_axis_off()
            # Save before (possibly) showing, so the saved file does not depend on what the backend
            # does with the figure once it has been shown.
            if save_dir_path is not None:
                save_fig(fig, save_dir_path / f'batch_{batch_idx}.png')
            if close_fig:
                plt.close(fig)
            else:
                plt.show()