def viz_results(self, x, y_true, y_pred, save=True):
    """Plot input, predicted mask and ground-truth mask for the first sample.

    A figure is only drawn when ``self.dataset_config("dataset_type")`` is
    ``"tif"``; for any other dataset type nothing is plotted and ``None`` is
    returned.

    Args:
        x: Batch of input images; only ``x[0]`` is plotted.
        y_true: Batch of ground-truth class masks; only ``y_true[0]`` is used.
        y_pred: Batch of predicted class masks; only ``y_pred[0]`` is used.
        save: If True, write the figure to
            ``<self.hparams.savedir>/<self.current_epoch>.pdf`` (creating the
            directory if needed) and return ``None``. If False, return the
            figure without saving it.

    Returns:
        The ``viz.Fig`` when ``save`` is False and a figure was drawn,
        otherwise ``None``.
    """
    if self.dataset_config("dataset_type") != "tif":
        return None

    # Read the overlay configuration once instead of once per panel.
    num_classes = self.dataset_config("classes")
    class_colors = self.dataset_config("class_colors")

    fig = viz.Fig(1, 3, f"Epoch {self.current_epoch}", figsize=(8, 3))

    # Every panel uses the same fixed intensity range, so the background
    # image looks identical under each overlay. Previously only the "Input"
    # panel pinned vmin/vmax and the other two were auto-scaled.
    fig.plot_img(0, 0, x[0], vmin=0, vmax=1, title="Input")
    overlay_panels = (("Prediction", y_pred[0]), ("Ground Truth", y_true[0]))
    for col, (title, mask) in enumerate(overlay_panels, start=1):
        fig.plot_img(0, col, x[0], vmin=0, vmax=1, title=title)
        fig.plot_overlay_class_mask(
            0,
            col,
            mask,
            num_classes=num_classes,
            colors=class_colors,
            alpha=0.5,
        )

    if not save:
        return fig
    os.makedirs(self.hparams.savedir, exist_ok=True)
    fig.save(
        os.path.join(self.hparams.savedir, f"{self.current_epoch}.pdf"),
    )
    return None
def make_overview():
    """Render the registration overview grid and save it as PDF and PNG.

    There is one row per dataset in ``DATASET_ORDER``:
    - column 0 shows the moving image;
    - column 1 shows the fixed image;
    - column ``j + 2`` shows the image aligned by the model trained with the
      ``j``-th loss function of ``LOSS_FUNTION_ORDER``.

    Models whose ``./weights/<dataset>/registration/<loss>`` directory does
    not exist are skipped. If no model exists for a dataset, its row is left
    empty. The figure is written to ``./out/plots/img_sample.pdf`` and
    ``./out/plots/img_sample.png``.

    Raises:
        ValueError: If ``DATASET_ORDER`` contains a dataset that has no
            plotting function configured below.
    """
    # dataset -> (plotting function, index of the sample to visualize)
    dataset_plot_config = {
        "platelet-em": (plot_platelet, 5),
        "brain-mri": (plot_brainmri, 0),
        "phc-u373": (plot_phc, 9),
    }

    fig = viz.Fig(5, 9, None, figsize=(9, 5))
    # adjust subplot spacing
    fig.fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Per-loss-function highlight colour (only passed for platelet-em).
    # Pad with None so zip() below never silently drops trailing loss
    # functions; the title loop at the end labels every one of them.
    highlight_colors = [None, 'r', None, None, None, '#31e731']
    highlight_colors += [None] * (len(LOSS_FUNTION_ORDER) - len(highlight_colors))

    for row, dataset in enumerate(DATASET_ORDER):
        try:
            plotfun, sample_idx = dataset_plot_config[dataset]
        except KeyError:
            raise ValueError(
                f"No plotting function configured for dataset {dataset!r}; "
                f"supported: {sorted(dataset_plot_config)}"
            ) from None

        # Outputs of the most recently loaded model for THIS dataset. They
        # are used for the moving/fixed columns. Reset per dataset so a row
        # never shows images left over from the previous dataset.
        last_sample = None
        for j, (loss_function, highlight_color) in enumerate(
                zip(LOSS_FUNTION_ORDER, highlight_colors)):
            path = os.path.join("./weights/", dataset, "registration",
                                loss_function)
            if not os.path.isdir(path):
                continue
            # load model
            checkpoint_path = os.path.join(path, "weights.ckpt")
            model = RegistrationModel.load_from_checkpoint(
                checkpoint_path=checkpoint_path)
            # run model
            I_0, S_0, I_m, S_m, I_1, S_1, inv_flow = get_img(model, sample_idx)
            # plot aligned image
            kwargs = {
                'highlight_color': highlight_color
            } if dataset == "platelet-em" else {}
            plotfun(fig, row, j + 2, model, I_m, S_m, inv_flow=inv_flow,
                    **kwargs)
            last_sample = (model, I_0, S_0, I_1, S_1)

        if last_sample is None:
            # No trained model was found for this dataset; nothing to show.
            continue
        # plot moving and fixed image
        model, I_0, S_0, I_1, S_1 = last_sample
        plotfun(fig, row, 0, model, I_0, S_0)
        plotfun(fig, row, 1, model, I_1, S_1)

    # label loss function columns
    for col, lossfun in enumerate(LOSS_FUNTION_ORDER):
        fig.axs[0, col + 2].set_title(
            LOSS_FUNTION_CONFIG[lossfun]["display_name"])
    fig.axs[0, 0].set_title("Moving")
    fig.axs[0, 1].set_title("Fixed")

    os.makedirs("./out/plots", exist_ok=True)
    fig.save("./out/plots/img_sample.pdf", close=False)
    fig.save("./out/plots/img_sample.png")
def viz_results(self, I_0, I_m, I_1, S_0, S_m, S_1, flow, save=True):
    """Plot a registration result for the first sample of the batch.

    Layout (row, col):
        (0, 0) ``I_0`` with mask ``S_0``;
        (0, 1) warped image ``I_m`` with mask ``S_m``;
        (1, 1) ``I_1`` with mask ``S_1``;
        (1, 0) transform grid of ``flow``;
        (0, 2) pixels where ``S_0`` and ``S_1`` differ;
        (1, 2) pixels where ``S_m`` and ``S_1`` differ.

    Only index ``[0]`` of each argument is plotted. The ``.long()`` call on
    the mask comparisons suggests the masks are torch tensors; this is
    assumed, not checked here.

    Args:
        I_0, I_m, I_1: Image batches (moving, warped, fixed).
        S_0, S_m, S_1: Corresponding class-mask batches.
        flow: Batch of transforms passed to ``plot_transform_grid``.
        save: If True, write the figure to
            ``<self.hparams.savedir>/<self.current_epoch>.pdf`` and return
            ``None``. If False, return the figure without saving it.

    Returns:
        The ``viz.Fig`` when ``save`` is False, otherwise ``None``.
    """
    num_classes = self.dataset_config("classes")
    class_colors = self.dataset_config("class_colors")

    fig = viz.Fig(2, 3, f"Epoch {self.current_epoch}", figsize=(9, 6))

    # (row, col, image batch, mask batch, title). The LaTeX titles are raw
    # strings: "\c" and "\P" are invalid escape sequences in normal strings
    # (deprecated; SyntaxWarning since Python 3.12). The runtime value is
    # unchanged.
    overlay_panels = (
        (0, 0, I_0, S_0, "$I_0$"),
        (0, 1, I_m, S_m, r"$I_0 \circ \Phi$"),
        (1, 1, I_1, S_1, "$I_1$"),
    )
    for row, col, img, seg, title in overlay_panels:
        fig.plot_img(row, col, img[0], vmin=0, vmax=1, title=title)
        fig.plot_overlay_class_mask(
            row,
            col,
            seg[0],
            num_classes=num_classes,
            colors=class_colors,
            alpha=0.2,
        )

    fig.plot_transform_grid(1, 0, flow[0], title=r"$\Phi$", interval=15,
                            linewidth=0.1)
    # Binary maps of label disagreement before / after registration.
    fig.plot_img(0, 2, (S_0[0] != S_1[0]).long(), vmin=0, vmax=1,
                 title="Diff")
    fig.plot_img(1, 2, (S_m[0] != S_1[0]).long(), vmin=0, vmax=1,
                 title="Diff Registered")

    if not save:
        return fig
    os.makedirs(self.hparams.savedir, exist_ok=True)
    fig.save(
        os.path.join(self.hparams.savedir, f"{self.current_epoch}.pdf"),
    )
    return None
def make_detail():
    """Save a two-row close-up comparing the "ncc2" and "deepsim" models.

    For platelet-em sample 5:
    - each model's aligned image is drawn in its own row;
    - each row is titled with the loss function's display name and framed
      in that loss function's highlight colour.

    Missing weight directories are skipped, which leaves that row empty.
    The result is written to ``./out/plots/img_sample_detail.pdf`` and
    ``./out/plots/img_sample_detail.png``.
    """
    # (loss function, highlight colour) per row, top to bottom
    detail_rows = (("ncc2", 'r'), ("deepsim", '#31e731'))
    dataset = "platelet-em"
    sample_idx = 5

    fig = viz.Fig(2, 1, None, figsize=(1.5, 3))
    fig.fig.subplots_adjust(hspace=0.3, wspace=0.05)

    for row, (loss_name, frame_color) in enumerate(detail_rows):
        weights_dir = os.path.join("./weights/", dataset, "registration",
                                   loss_name)
        if not os.path.isdir(weights_dir):
            continue

        ckpt = os.path.join(weights_dir, "weights.ckpt")
        model = RegistrationModel.load_from_checkpoint(checkpoint_path=ckpt)

        outputs = get_img(model, sample_idx)
        warped_img, warped_seg, inv_flow = outputs[2], outputs[3], outputs[6]

        display_name = LOSS_FUNTION_CONFIG[loss_name]["display_name"]
        plot_platelet_detail(
            fig, row, 0, model, warped_img, warped_seg,
            inv_flow=inv_flow,
            title=display_name,
            highlight_color=frame_color,
        )

    os.makedirs("./out/plots", exist_ok=True)
    fig.save("./out/plots/img_sample_detail.pdf", close=False)
    fig.save("./out/plots/img_sample_detail.png")
def main(hparams):
    """Evaluate a registration checkpoint and write an animated flow overview.

    Steps:
    1. Load the model from ``hparams.weights``.
    2. Run the Lightning test loop on it.
    3. For every pair in the model's test set, plot the first image ``I_0``
       overlaid with the negated predicted flow as vectors.
    4. Save all frames as one multi-frame image at ``hparams.out``, creating
       the parent directory if it has one.

    Raises:
        ValueError: If the test set is empty, so there are no frames to write.
    """
    import torch  # local import: only needed for no_grad() below

    # load model
    model = RegistrationModel.load_from_checkpoint(
        checkpoint_path=hparams.weights)
    model.eval()
    print(
        f"Evaluating model for dataset {model.hparams.dataset}, loss {model.hparams.loss}, lambda {model.hparams.lam}"
    )

    # init trainer
    trainer = pl.Trainer()
    # test (pass in the model)
    trainer.test(model)

    # create grid animation
    test_set = model.test_dataloader().dataset
    images = []
    # Inference only: eval() does not disable autograd, so no_grad() avoids
    # building a gradient graph for every frame.
    with torch.no_grad():
        for i in tqdm(range(len(test_set)), desc="creating tif image"):
            # Segmentations are not needed for the flow animation.
            (I_0, _), (I_1, _) = test_set[i]
            # add a batch dimension
            I_0, I_1 = I_0.unsqueeze(0), I_1.unsqueeze(0)
            flow = model.forward(I_0, I_1)

            fig = viz.Fig(1, 1, None, figsize=(3, 3))
            fig.plot_img(0, 0, I_0[0], vmin=0, vmax=1)
            fig.plot_transform_vec(0, 0, -flow[0], interval=10,
                                   arrow_length=1.0, linewidth=2.0,
                                   overlay=True)
            # extract the axis we are interested in
            images.append(fig.save_ax_to_PIL(0, 0))

    if not images:
        raise ValueError(
            f"Test set is empty; no frames to write to {hparams.out}")

    out_dir = os.path.dirname(hparams.out)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    images[0].save(hparams.out, save_all=True, append_images=images[1:])
def make_detail_all():
    """Save a one-row close-up of platelet-em sample 5 for every loss function.

    Column ``j`` shows the aligned image produced by the model trained with
    ``LOSS_FUNTION_ORDER[j]``, titled with that loss function's display
    name. Loss functions without a weight directory leave their column
    empty. The figure has a fixed 1x6 layout. The result is written to
    ``./out/plots/img_sample_detail_all.pdf`` and
    ``./out/plots/img_sample_detail_all.png``.
    """
    fig = viz.Fig(1, 6, None, figsize=(9, 2))
    fig.fig.subplots_adjust(hspace=0.3, wspace=0.05)

    dataset, sample_idx = "platelet-em", 5

    for col, loss_name in enumerate(LOSS_FUNTION_ORDER):
        weights_dir = os.path.join("./weights/", dataset, "registration",
                                   loss_name)
        if os.path.isdir(weights_dir):
            ckpt = os.path.join(weights_dir, "weights.ckpt")
            model = RegistrationModel.load_from_checkpoint(
                checkpoint_path=ckpt)

            outputs = get_img(model, sample_idx)
            warped_img, warped_seg, inv_flow = (
                outputs[2], outputs[3], outputs[6])

            plot_platelet_detail(
                fig, 0, col, model, warped_img, warped_seg,
                inv_flow=inv_flow,
                title=LOSS_FUNTION_CONFIG[loss_name]["display_name"],
            )

    os.makedirs("./out/plots", exist_ok=True)
    fig.save("./out/plots/img_sample_detail_all.pdf", close=False)
    fig.save("./out/plots/img_sample_detail_all.png")