def make_fig(args):
    """Draw a comparison grid of ISIC18 and LIDC test samples.

    The first ``args.images_each`` columns show ISIC18 samples and the next
    ``args.images_each`` columns show LIDC samples.  Each column is drawn by
    ``plot_col`` using the four models loaded from the checkpoint directories
    below.  The finished figure is written to every path in
    ``args.output_file``.
    """
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    def load_models(checkpoint_dirs):
        # One model per checkpoint directory, moved to the selected device.
        return [util.load_model_from_checkpoint(ckpt).to(device)
                for ckpt in checkpoint_dirs]

    isic_models = load_models([
        './trained_models/isic18/softm/',
        './trained_models/isic18/ensemble/',
        './trained_models/isic18/mcdropout/',
        './trained_models/isic18/punet/',
    ])
    lidc_models = load_models([
        './trained_models/lidc/softm/',
        './trained_models/lidc/ensemble/',
        './trained_models/lidc/mcdropout/',
        './trained_models/lidc/punet/',
    ])

    per_dataset = args.images_each
    total_cols = per_dataset * 2
    fig = Fig(
        rows=5,
        cols=total_cols,
        title=None,
        figsize=(5, 3),
        background=True,
    )
    plt.tight_layout()

    palette = np.array(sns.color_palette("Paired")) * 255
    color_gt, color_model = palette[1], palette[7]

    # Fixed seed so the same sample indices are chosen on every run.
    # NOTE(review): np.random.choice samples WITH replacement by default, so a
    # sample may appear in more than one column.
    np.random.seed(7)
    img_indices = np.random.choice(range(500), total_cols)

    for col in range(total_cols):
        # Left half: ISIC18 columns; right half: LIDC columns.  The trailing
        # integer (4 vs. 2) is interpreted by plot_col, whose definition is
        # not visible here.
        if col < per_dataset:
            models, col_param = isic_models, 4
        else:
            models, col_param = lidc_models, 2
        plot_col(fig, col, img_indices[col], device, models, col_param,
                 color_gt, color_model)

    # Remove outer margins and keep only thin gaps between panels.
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1.,
                        wspace=0.06, hspace=0.06)
    for path in args.output_file:
        fig.save(path, close=False)
def make_fig(model_checkpoints, output_file):
    """Plot several models' predictions for the first sample of each model's datamodule.

    Args:
        model_checkpoints: Sequence of checkpoint paths.  Models are loaded one
            at a time.  Each model's predictions go into its own figure row,
            starting at row 1.  The figure itself is created by
            ``set_up_figure`` from the first model's sample.
        output_file: Path the finished figure is saved to.  Its parent
            directory is created if necessary.

    Raises:
        ValueError: If ``model_checkpoints`` is empty.  Without this check
            there is no figure to save, and the call fails later with an
            unbound ``fig``.
    """
    if len(model_checkpoints) == 0:
        raise ValueError('make_fig needs at least one model checkpoint')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fig = None
    for row, model_checkpoint in enumerate(
            tqdm(model_checkpoints, desc='Plotting Model outputs...')):
        model = util.load_model_from_checkpoint(model_checkpoint).to(device)
        datamodule = util.load_datamodule_for_model(model, batch_size=1)
        sample = load_sample(datamodule, idx=0, device=device)
        if row == 0:
            # set_up_figure needs a sample, so the figure is built lazily.
            fig = set_up_figure(len(model_checkpoints), sample)
        predictions = predict(model, sample[0])
        plot_predictions(fig, row + 1, predictions, model.model_shortname())

    os.makedirs("./plots/", exist_ok=True)
    # output_file may live outside ./plots/; make sure its directory exists.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.save(output_file)
def make_fig(model_checkpoint, output_file, indices):
    """Plot one model's predictions for several training samples.

    Args:
        model_checkpoint: Sequence of checkpoint paths; only the first entry
            is used.
        output_file: Path the finished figure is saved to.  Its parent
            directory is created if necessary.
        indices: Training-set indices to plot.  Each index gets its own
            figure row, starting at row 1, labelled with the index.

    Raises:
        ValueError: If ``indices`` is empty.  Without this check there is no
            figure to save, and the call fails later with an unbound ``fig``.
    """
    if len(indices) == 0:
        raise ValueError('make_fig needs at least one sample index')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = util.load_model_from_checkpoint(model_checkpoint[0]).to(device)
    datamodule = util.load_datamodule_for_model(model, batch_size=1)

    fig = None
    for row, idx in enumerate(indices):
        sample = load_train_sample(datamodule, idx=idx, device=device)
        if row == 0:
            # set_up_figure needs a sample, so the figure is built lazily.
            fig = set_up_figure(len(indices), sample)
        predictions = predict(model, sample[0])
        plot_predictions(fig, row + 1, predictions, str(idx))

    os.makedirs("./plots/", exist_ok=True)
    # output_file may live outside ./plots/; make sure its directory exists.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.save(output_file)
def make_fig(args):
    """Plot several models' predictions for one dataset sample.

    All models in ``args.model_checkpoints`` are loaded up front.  Sample
    ``args.sample_idx`` is taken from the first model's datamodule.  Each
    model's predictions are drawn into its own figure row, starting at row 1.
    The figure is saved to every path in ``args.output_file``.
    """
    # NOTE(review): evaluation is pinned to CPU even when CUDA is available
    # (the CUDA selection was commented out) -- confirm this is intended.
    device = 'cpu'
    models = []
    for ckpt in args.model_checkpoints:
        models.append(util.load_model_from_checkpoint(ckpt).to(device))

    reference_model = models[0]
    datamodule = util.load_datamodule_for_model(reference_model, batch_size=1)
    x, ys = load_sample(datamodule, idx=args.sample_idx, device=device)

    # NOTE(review): the first argument is hard-coded to 4 while len(models)
    # rows are plotted below; if it is the model-row count, figures break for
    # a number of checkpoints other than 4 -- confirm.
    fig = set_up_figure(4, x, ys)
    progress = tqdm(models, desc='Plotting Model outputs...')
    for fig_row, model in enumerate(progress, start=1):
        plot_predictions(fig, fig_row, predict(model, x))

    # No outer margins except a strip at the top; thin vertical gaps only.
    plt.subplots_adjust(left=0, bottom=0, right=1, top=0.9,
                        wspace=0.0, hspace=0.06)
    os.makedirs("./plots/", exist_ok=True)
    for path in args.output_file:
        fig.save(path, close=False)
def _plot_uncertainty_figure(model, x, y, sample_cnt, color_gt, color_model):
    """Build and return a 1x2 figure for one sample.

    Left panel: the input image.  Right panel: the uncertainty map computed
    by ``util.entropy`` from the model's pixel-wise probabilities, drawn in
    grey with a colorbar.  Both panels get the outline of the model's argmax
    prediction (``color_model``) and of the ground truth ``y`` (``color_gt``).
    A two-entry legend is added for the two outline colours.
    """
    from matplotlib import pyplot as plt
    from matplotlib.patches import Rectangle

    fig = Fig(
        rows=1,
        cols=2,
        title=None,
        figsize=None,
        background=True,
    )

    # Re-seed before every sample so sampling-based models produce the same
    # output for a given sample on every run.
    pl.seed_everything(42)
    with torch.no_grad():
        # assumes class probabilities along dim 1 -- TODO confirm
        p = model.pixel_wise_probabaility(x, sample_cnt=sample_cnt)
        _, y_pred = p.max(dim=1, keepdim=True)
        uncertainty = util.entropy(p)

    fig.plot_img(0, 0, x[0], vmin=0, vmax=1)
    fig.plot_overlay(0, 1, uncertainty[0], alpha=1, vmin=None, vmax=None,
                     cmap='Greys', colorbar=True,
                     colorbar_label="Model Uncertainty")

    # Prediction outlines on both panels, then ground-truth outlines.
    for col in (0, 1):
        fig.plot_contour(0, col, y_pred[0], contour_class=1, width=2,
                         rgba=color_model)
    for col in (0, 1):
        fig.plot_contour(0, col, y[0], contour_class=1, width=2,
                         rgba=color_gt)

    legend_entries = [(color_gt, "GT Annotation"),
                      (color_model, "Model Prediction")]
    # Colours are 0-255 RGB; matplotlib expects 0-1.
    handles = [Rectangle((0, 0), 1, 1, color=[v / 255 for v in c])
               for c, _ in legend_entries]
    labels = [name for _, name in legend_entries]
    plt.legend(handles, labels, ncol=len(legend_entries))
    return fig


def make_fig(args):
    """Save an image/uncertainty figure for each of the first test samples.

    Uses ``args.model_path`` (checkpoint), ``args.samples`` (``sample_cnt``
    passed to the model) and ``args.output_folder`` (destination).  The
    optional ``args.num_images`` sets how many samples are drawn; it defaults
    to 100.  Each sample ``idx`` is written as ``test_{idx}.png`` and
    ``test_{idx}.pdf``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = util.load_model_from_checkpoint(args.model_path).to(device)
    datamodule = util.load_datamodule_for_model(model, batch_size=1)

    # Loop-invariant: same palette for every figure.
    palette = np.array(sns.color_palette("Paired")) * 255
    color_gt, color_model = palette[1], palette[7]

    # ./plots/ is not written to here; still created so callers relying on
    # its existence keep working.
    os.makedirs("./plots/", exist_ok=True)
    os.makedirs(args.output_folder, exist_ok=True)

    num_images = getattr(args, 'num_images', 100)
    for idx in tqdm(range(num_images), desc='Generating Images'):
        x, y = load_sample(datamodule, idx=idx, device=device)
        fig = _plot_uncertainty_figure(model, x, y, args.samples,
                                       color_gt, color_model)
        fig.save(os.path.join(args.output_folder, f'test_{idx}.png'),
                 close=False)
        fig.save(os.path.join(args.output_folder, f'test_{idx}.pdf'))
def _parse_args():
    """Parse the script options plus every ``pl.Trainer`` command-line flag."""
    parser = ArgumentParser()
    parser.add_argument(
        '--model_path', type=str, help='Path to the trained model.')
    parser.add_argument(
        '--file', type=str, default='./plots/experiment_results.pickl',
        help='File to save the results in.')
    parser = pl.Trainer.add_argparse_args(parser)
    return parser.parse_args()


def _load_results(path):
    """Return the results dict pickled at ``path``, or a new empty dict if the file is missing."""
    if os.path.isfile(path):
        print('Loading existing result file')
        # NOTE: unpickling can execute code embedded in the file -- only point
        # --file at result files written by this script.
        with open(path, 'rb') as f:
            return pickle.load(f)
    print('Creating new result file')
    return {}


def _majority_vote(annotator_sum, annotator_cnt):
    """Per-pixel majority label from summed binary annotator votes (ties count as 1).

    For 4 annotators this equals ``(annotator_sum + 2) // 4``.  Unlike that
    expression, it is also correct for any other number of annotators.
    """
    return (2 * annotator_sum >= annotator_cnt).long()


def _sample_pixel_indices(spatial_shape, n_pixels):
    """Draw ``n_pixels`` grid coordinates uniformly, with replacement, from ``spatial_shape``.

    Uses NumPy's global RNG.  Returns an integer array of shape
    ``(n_pixels, len(spatial_shape))``.
    """
    ndim = len(spatial_shape)
    grid = np.moveaxis(np.indices(spatial_shape), 0, -1).reshape((-1, ndim))
    return grid[np.random.randint(len(grid), size=n_pixels)]


def _record_image_metrics(metrics, model, x, ys):
    """Append image-wise metrics for one batch to ``metrics``.

    Covers GED, sample diversity and dice loss for several sample counts,
    plus the uncertainty/segmentation-error correlation.  Returns the
    pixel-wise uncertainty (16 samples) so pixel metrics can reuse it.
    """
    for sample_count in (1, 4, 8, 16):
        ged, sample_diversity = generalized_energy_distance(
            model, x, ys, sample_count=sample_count)
        metrics[f"test/ged/{sample_count}"].append(ged)
        metrics[f"test/sample_diversity/{sample_count}"].append(
            sample_diversity)
        dice = heatmap_dice_loss(model, x, ys, sample_count=sample_count)
        metrics[f"test/diceloss/{sample_count}"].append(dice)

    uncertainty = model.pixel_wise_uncertainty(x, sample_cnt=16)
    # Correlation (via the file's pearsonr helper) between the uncertainty and
    # the per-pixel BCE of one model sample against one randomly chosen
    # annotator, averaged over 16 draws.
    correl = torch.stack([
        pearsonr(uncertainty, torch.nn.functional.binary_cross_entropy(
            model.sample_prediction(x).float(),
            ys[torch.randint(len(ys), ())].float(),
            reduction='none'))
        for _ in range(16)
    ]).mean(dim=0)
    metrics["test/uncertainty_seg_error_correl"].append(correl)
    return uncertainty


def _record_pixel_metrics(pixel_metrics, model, x, ys, uncertainty,
                          pixels_per_image=1000):
    """Append subsampled pixel-wise uncertainty statistics for every image of the batch.

    To bound memory, only ``pixels_per_image`` random positions are kept.  The
    same positions are used for every image of the batch.
    """
    indices = _sample_pixel_indices(tuple(uncertainty.shape[1:]),
                                    pixels_per_image)
    y_hat = model.sample_prediction(x)  # thresholded model prediction
    stacked = torch.stack(ys)
    annotator_sum = stacked.sum(dim=0)  # number of foreground votes
    annotator_cnt = len(ys)
    annot_uncertainty = util.binary_entropy(stacked.float().mean(dim=0))

    for b in range(uncertainty.shape[0]):
        # Prefix the batch index b to every sampled coordinate.
        idx = torch.as_tensor(
            np.column_stack((np.full(len(indices), b), indices)),
            dtype=torch.long)
        y_hat_s = index_select(y_hat, idx)
        annot_sum_s = index_select(annotator_sum, idx)
        model_unc_s = index_select(uncertainty, idx)
        annot_unc_s = index_select(annot_uncertainty, idx)

        pred_fg = y_hat_s == 1
        pred_bg = y_hat_s == 0
        # tp/fp/fn/tn only count pixels on which all annotators agree.
        all_fg = annot_sum_s == annotator_cnt
        all_bg = annot_sum_s == 0
        pixel_metrics["test/tp_uncertainty"].append(
            model_unc_s[pred_fg & all_fg].cpu().tolist())
        pixel_metrics["test/fp_uncertainty"].append(
            model_unc_s[pred_fg & all_bg].cpu().tolist())
        pixel_metrics["test/fn_uncertainty"].append(
            model_unc_s[pred_bg & all_fg].cpu().tolist())
        pixel_metrics["test/tn_uncertainty"].append(
            model_unc_s[pred_bg & all_bg].cpu().tolist())
        pixel_metrics["test/model_uncertainty"].append(
            model_unc_s.cpu().tolist())
        pixel_metrics["test/annotator_uncertainty"].append(
            annot_unc_s.cpu().tolist())
        majority = _majority_vote(annot_sum_s, annotator_cnt)
        pixel_metrics["test/is_prediction_correct"].append(
            (y_hat_s == majority).cpu().tolist())


def cli_main():
    """Test one trained model and merge its results into a pickled results file.

    Results go to ``results[dataset][model_shortname]``.  That entry holds the
    model names, the metrics returned by ``trainer.test``, and a
    ``'per_sample'`` dict.  ``'per_sample'`` contains image-wise metrics as
    arrays and subsampled pixel-wise metrics as per-image lists.  Entries for
    other models and datasets already in the file are kept.
    """
    pl.seed_everything(1234)
    args = _parse_args()

    # ---- model ----
    model = util.load_model_from_checkpoint(args.model_path)
    checkpoint_path = util.get_checkpoint_path(args.model_path)
    datamodule = util.load_datamodule_for_model(model)
    dataset = model.hparams.dataset
    shortname = model.model_shortname()

    # ---- results file ----
    test_results = _load_results(args.file)
    entry = {'model_name': model.model_name(), 'model_shortname': shortname}
    # Replaces any earlier entry for this model on this dataset.
    test_results.setdefault(dataset, {})[shortname] = entry

    # ---- Lightning test loop ----
    trainer = pl.Trainer.from_argparse_args(args)
    test_metrics = trainer.test(
        model=model, ckpt_path=checkpoint_path, datamodule=datamodule)
    entry.update(test_metrics[0])

    # ---- sample-specific metrics ----
    metrics = defaultdict(list)
    pixel_metrics = defaultdict(list)
    with torch.no_grad():
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        model.eval()
        for x, ys in tqdm(datamodule.test_dataloader(),
                          desc='Collecting sample-individual metrics...'):
            x, ys = util.to_device((x, ys), device)
            assert ys[0].max() <= 1  # annotations must be binary masks
            uncertainty = _record_image_metrics(metrics, model, x, ys)
            _record_pixel_metrics(pixel_metrics, model, x, ys, uncertainty)

    per_sample = {k: torch.cat(v).cpu().numpy() for k, v in metrics.items()}
    per_sample.update(pixel_metrics)
    entry['per_sample'] = per_sample

    # ---- save ----
    # Create the target directory (default ./plots/) so the save cannot fail
    # after the whole evaluation has run.
    out_dir = os.path.dirname(args.file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.file, 'wb') as f:
        pickle.dump(test_results, f)