Example #1
0
    def viz_results(self, x, y_true, y_pred, save=True):
        """Draw input, prediction overlay and ground-truth overlay side by side.

        Only datasets whose ``dataset_type`` is ``"tif"`` are visualized; for
        any other type nothing is drawn and ``None`` is returned.

        Args:
            x: batch of input images; only ``x[0]`` is drawn.
            y_true: batch of ground-truth class masks; only ``y_true[0]`` is used.
            y_pred: batch of predicted class masks; only ``y_pred[0]`` is used.
            save: if true, write the figure to ``<savedir>/<current_epoch>.pdf``
                and return ``None``; otherwise return the figure object.
        """
        if self.dataset_config("dataset_type") != "tif":
            return None

        fig = viz.Fig(1, 3, f"Epoch {self.current_epoch}", figsize=(8, 3))
        fig.plot_img(0, 0, x[0], vmin=0, vmax=1, title="Input")

        # Columns 1 and 2 repeat the input image with a class-mask overlay.
        panels = ((1, "Prediction", y_pred), (2, "Ground Truth", y_true))
        for col, panel_title, masks in panels:
            fig.plot_img(0, col, x[0], title=panel_title)
            fig.plot_overlay_class_mask(
                0,
                col,
                masks[0],
                num_classes=self.dataset_config("classes"),
                colors=self.dataset_config("class_colors"),
                alpha=0.5,
            )

        if not save:
            return fig
        os.makedirs(self.hparams.savedir, exist_ok=True)
        target = os.path.join(self.hparams.savedir, f"{self.current_epoch}.pdf")
        fig.save(target)
        return None
def make_overview():
    """Render the registration overview grid and save it as PDF and PNG.

    One row per dataset in ``DATASET_ORDER``:

    - column 0 shows the moving image;
    - column 1 shows the fixed image;
    - column ``2 + j`` shows the image aligned by the model trained with
      ``LOSS_FUNTION_ORDER[j]``.

    Loss functions without a weight directory
    (``./weights/<dataset>/registration/<loss>``) are skipped. If a dataset
    has no trained model at all, its moving/fixed columns are left empty.

    The figure is written to ``./out/plots/img_sample.pdf`` and
    ``./out/plots/img_sample.png``.

    Raises:
        ValueError: if ``DATASET_ORDER`` contains a dataset with no
            configured plotting function.
    """
    # dataset -> (plotting function, index of the sample passed to get_img)
    dataset_setup = {
        "platelet-em": (plot_platelet, 5),
        "brain-mri": (plot_brainmri, 0),
        "phc-u373": (plot_phc, 9),
    }
    # Highlight colour per loss-function column (platelet-em only). Columns
    # past the end of this list get no highlight instead of being dropped.
    highlight_colors = [None, 'r', None, None, None, '#31e731']

    fig = viz.Fig(5, 9, None, figsize=(9, 5))
    # adjust subplot spacing
    fig.fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for i, dataset in enumerate(DATASET_ORDER):
        try:
            plotfun, sample_idx = dataset_setup[dataset]
        except KeyError:
            raise ValueError(
                f"no plotting function configured for dataset '{dataset}'"
            ) from None

        # (model, I_0, S_0, I_1, S_1) from the last model loaded for THIS
        # dataset; used for the moving/fixed columns.
        reference = None
        for j, loss_function in enumerate(LOSS_FUNTION_ORDER):
            path = os.path.join("./weights/", dataset, "registration",
                                loss_function)
            if not os.path.isdir(path):
                continue
            # load model
            checkpoint_path = os.path.join(path, "weights.ckpt")
            model = RegistrationModel.load_from_checkpoint(
                checkpoint_path=checkpoint_path)

            # run model
            I_0, S_0, I_m, S_m, I_1, S_1, inv_flow = get_img(model, sample_idx)
            reference = (model, I_0, S_0, I_1, S_1)

            # plot aligned image
            kwargs = {}
            if dataset == "platelet-em":
                kwargs["highlight_color"] = (highlight_colors[j]
                                             if j < len(highlight_colors)
                                             else None)
            plotfun(fig,
                    i,
                    j + 2,
                    model,
                    I_m,
                    S_m,
                    inv_flow=inv_flow,
                    **kwargs)

        if reference is None:
            # No trained model for this dataset: nothing valid to plot.
            continue

        # plot moved and fixed image
        model, I_0, S_0, I_1, S_1 = reference
        plotfun(fig, i, 0, model, I_0, S_0)
        plotfun(fig, i, 1, model, I_1, S_1)

    # label loss function columns
    for j, lossfun in enumerate(LOSS_FUNTION_ORDER):
        fig.axs[0, j + 2].set_title(
            LOSS_FUNTION_CONFIG[lossfun]["display_name"])
    fig.axs[0, 0].set_title("Moving")
    fig.axs[0, 1].set_title("Fixed")

    os.makedirs("./out/plots", exist_ok=True)
    fig.save("./out/plots/img_sample.pdf", close=False)
    fig.save("./out/plots/img_sample.png")
Example #3
0
    def viz_results(self, I_0, I_m, I_1, S_0, S_m, S_1, flow, save=True):
        """Plot a registration result for the first sample of a batch.

        Layout of the 2 x 3 grid:
            (0, 0) ``I_0`` with mask ``S_0``
            (0, 1) warped image ``I_m`` with mask ``S_m``
            (0, 2) pixels where ``S_0`` and ``S_1`` disagree
            (1, 0) transformation grid of ``flow``
            (1, 1) ``I_1`` with mask ``S_1``
            (1, 2) pixels where ``S_m`` and ``S_1`` disagree

        Args:
            I_0, I_m, I_1: image batches; only element 0 is drawn.
            S_0, S_m, S_1: class-mask batches; only element 0 is drawn.
            flow: transformation batch; only ``flow[0]`` is drawn.
            save: if true, write the figure to ``<savedir>/<current_epoch>.pdf``
                and return ``None``; otherwise return the figure.
        """
        num_classes = self.dataset_config("classes")
        class_colors = self.dataset_config("class_colors")

        # make figure
        fig = viz.Fig(2, 3, f"Epoch {self.current_epoch}", figsize=(9, 6))

        # Raw strings: "\c" and "\P" are invalid escape sequences in ordinary
        # string literals (DeprecationWarning, SyntaxWarning on 3.12+).
        overlays = (
            (0, 0, I_0, S_0, r"$I_0$"),
            (0, 1, I_m, S_m, r"$I_0 \circ \Phi$"),
            (1, 1, I_1, S_1, r"$I_1$"),
        )
        for row, col, image, mask, title in overlays:
            fig.plot_img(row, col, image[0], vmin=0, vmax=1, title=title)
            fig.plot_overlay_class_mask(
                row,
                col,
                mask[0],
                num_classes=num_classes,
                colors=class_colors,
                alpha=0.2,
            )

        fig.plot_transform_grid(1,
                                0,
                                flow[0],
                                title=r"$\Phi$",
                                interval=15,
                                linewidth=0.1)
        # 0/1 maps of pixels whose label differs from the S_1 mask
        fig.plot_img(0,
                     2, (S_0[0] != S_1[0]).long(),
                     vmin=0,
                     vmax=1,
                     title="Diff")
        fig.plot_img(1,
                     2, (S_m[0] != S_1[0]).long(),
                     vmin=0,
                     vmax=1,
                     title="Diff Registered")

        if not save:
            return fig
        os.makedirs(self.hparams.savedir, exist_ok=True)
        fig.save(
            os.path.join(self.hparams.savedir, f"{self.current_epoch}.pdf"))
        return None
def make_detail():
    """Render a two-row platelet-em detail view and save it as PDF and PNG.

    For each ``(loss function, highlight colour)`` pair (``"ncc2"`` and
    ``"deepsim"``), load the model from
    ``./weights/platelet-em/registration/<loss>/weights.ckpt``. Losses whose
    directory is missing are skipped. Run the model on sample 5 and draw the
    aligned image in its own row, titled with the loss's display name.

    The figure is written to ``./out/plots/img_sample_detail.pdf`` and
    ``./out/plots/img_sample_detail.png``.
    """
    # detail view
    fig = viz.Fig(2, 1, None, figsize=(1.5, 3))
    # adjust subplot spacing
    fig.fig.subplots_adjust(hspace=0.3, wspace=0.05)

    dataset = "platelet-em"
    sample_idx = 5
    # One row per entry. Loss functions and colours are kept as pairs so
    # they cannot get out of sync; the lowercase name avoids shadowing the
    # module-level LOSS_FUNTION_ORDER constant.
    detail_losses = (("ncc2", 'r'), ("deepsim", '#31e731'))

    for row, (loss_function, highlight_color) in enumerate(detail_losses):
        path = os.path.join("./weights/", dataset, "registration",
                            loss_function)
        if not os.path.isdir(path):
            continue
        # load model
        checkpoint_path = os.path.join(path, "weights.ckpt")
        model = RegistrationModel.load_from_checkpoint(
            checkpoint_path=checkpoint_path)

        # run model; only the aligned image/mask and the flow are drawn here
        _, _, I_m, S_m, _, _, inv_flow = get_img(model, sample_idx)

        # plot aligned image
        plot_platelet_detail(
            fig,
            row,
            0,
            model,
            I_m,
            S_m,
            inv_flow=inv_flow,
            title=LOSS_FUNTION_CONFIG[loss_function]["display_name"],
            highlight_color=highlight_color,
        )

    os.makedirs("./out/plots", exist_ok=True)
    fig.save("./out/plots/img_sample_detail.pdf", close=False)
    fig.save("./out/plots/img_sample_detail.png")
Example #5
0
def main(hparams):
    """Evaluate a registration checkpoint and export a flow animation.

    Steps:

    1. Load the model from ``hparams.weights`` and run the Lightning test
       loop on it.
    2. For every pair in the model's test dataset, draw ``I_0`` with the
       negated predicted flow as a vector field.
    3. Write all frames as one multi-frame image to ``hparams.out``.

    Raises:
        ValueError: if the test dataset is empty (there is no frame to save).
    """
    import torch

    # load model
    model = RegistrationModel.load_from_checkpoint(
        checkpoint_path=hparams.weights)
    model.eval()

    print(
        f"Evaluating model for dataset {model.hparams.dataset}, loss {model.hparams.loss}, lambda {model.hparams.lam}"
    )

    # init trainer
    trainer = pl.Trainer()

    # test (pass in the model)
    trainer.test(model)

    # create grid animation, one frame per test pair
    test_set = model.test_dataloader().dataset
    if len(test_set) == 0:
        raise ValueError("test set is empty; no animation frames to write")

    images = []
    for i in tqdm(range(len(test_set)), desc="creating tif image"):
        (I_0, _), (I_1, _) = test_set[i]
        # add a leading batch dimension for the model
        I_0, I_1 = I_0.unsqueeze(0), I_1.unsqueeze(0)
        # inference only: no autograd graph is needed for plotting
        with torch.no_grad():
            flow = model.forward(I_0, I_1)

        fig = viz.Fig(1, 1, None, figsize=(3, 3))
        fig.plot_img(0, 0, I_0[0], vmin=0, vmax=1)
        fig.plot_transform_vec(0,
                               0,
                               -flow[0],
                               interval=10,
                               arrow_length=1.0,
                               linewidth=2.0,
                               overlay=True)
        # extract the axis we are interested in
        images.append(fig.save_ax_to_PIL(0, 0))

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(hparams.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    images[0].save(hparams.out, save_all=True, append_images=images[1:])
def make_detail_all():
    """Draw the aligned platelet-em sample for every loss function in one row.

    Column ``j`` holds the output of the model trained with
    ``LOSS_FUNTION_ORDER[j]``, titled with that loss's display name. A loss
    whose weight directory does not exist leaves its column empty.

    The figure is saved to ``./out/plots/img_sample_detail_all.pdf`` and
    ``./out/plots/img_sample_detail_all.png``.
    """
    fig = viz.Fig(1, 6, None, figsize=(9, 2))
    # adjust subplot spacing
    fig.fig.subplots_adjust(hspace=0.3, wspace=0.05)

    draw = plot_platelet_detail
    dataset, sample_idx = "platelet-em", 5
    weights_root = os.path.join("./weights/", dataset, "registration")

    for col, loss_function in enumerate(LOSS_FUNTION_ORDER):
        model_dir = os.path.join(weights_root, loss_function)
        if os.path.isdir(model_dir):
            model = RegistrationModel.load_from_checkpoint(
                checkpoint_path=os.path.join(model_dir, "weights.ckpt"))
            _, _, warped_img, warped_seg, _, _, inv_flow = get_img(
                model, sample_idx)
            display_name = LOSS_FUNTION_CONFIG[loss_function]["display_name"]
            draw(fig, 0, col, model, warped_img, warped_seg,
                 inv_flow=inv_flow, title=display_name)

    os.makedirs("./out/plots", exist_ok=True)
    fig.save("./out/plots/img_sample_detail_all.pdf", close=False)
    fig.save("./out/plots/img_sample_detail_all.png")