コード例 #1
0
ファイル: utils.py プロジェクト: AlexandreAdam/Censai
def reconstruction_plot(y_true, y_pred):
    """Plot ground truth, prediction and residual images for a batch.

    The figure has 9 columns: columns 0-2 show channel 0 of ``y_true``,
    columns 3-5 show channel 0 of ``y_pred`` and columns 6-8 show the
    difference ``y_true - y_pred``. Each row holds three consecutive
    examples, so ``batch_size // 3`` rows are drawn and trailing examples
    that do not fill a complete row are left out.

    A batch with fewer than three examples still produces one row; the
    example index wraps around the batch, so examples are repeated.

    Args:
        y_true: Batch of reference images, indexed as ``y_true[k, ..., 0]``;
            ``y_true.shape[0]`` is taken as the batch size.
        y_pred: Batch of predicted images, indexed the same way as ``y_true``.

    Returns:
        The matplotlib figure holding the grid of images.

    Raises:
        ValueError: If the batch is empty.
    """
    batch_size = y_true.shape[0]
    if batch_size < 1:
        raise ValueError("reconstruction_plot requires a non-empty batch")
    # Always draw at least one row so that batches of 1 or 2 examples work;
    # the modulo below then wraps the example index around the batch.
    n_rows = max(1, batch_size // 3)
    # squeeze=False keeps ``axs`` two-dimensional even for a single row,
    # otherwise ``axs[i, j]`` fails when only one row is requested.
    fig, axs = plt.subplots(n_rows, 9, figsize=(27, 3 * n_rows),
                            squeeze=False)

    for i in range(n_rows):
        for j in range(3):
            k = (i * 3 + j) % batch_size
            truth = y_true[k, ..., 0]
            prediction = y_pred[k, ..., 0]

            ax = axs[i, j]
            ax.imshow(truth, cmap="hot", origin="lower")
            ax.axis("off")

            ax = axs[i, j + 3]
            ax.imshow(prediction, cmap="hot", origin="lower")
            ax.axis("off")

            ax = axs[i, j + 6]
            ax.imshow(truth - prediction,
                      cmap="seismic",
                      norm=CenteredNorm(),
                      origin="lower")
            ax.axis("off")

    # Titles sit on the middle column of each group of three.
    axs[0, 1].set_title("Ground Truth", size=20)
    axs[0, 4].set_title("Prediction", size=20)
    axs[0, 7].set_title("Residual", size=20)
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
コード例 #2
0
ファイル: utils.py プロジェクト: AlexandreAdam/Censai
def vae_residual_plot(y_true, y_pred):
    """Draw a single example next to its prediction and their difference.

    Three panels are laid out side by side, each with its own colorbar:
    channel 0 of ``y_true``, channel 0 of ``y_pred`` and the residual
    ``y_true[..., 0] - y_pred[..., 0]`` on a zero-centered diverging scale.

    Args:
        y_true: Reference image, indexed as ``y_true[..., 0]``.
        y_pred: Predicted image, indexed as ``y_pred[..., 0]``.

    Returns:
        The matplotlib figure containing the three panels.
    """
    figure, panels = plt.subplots(1, 3, figsize=(12, 4))

    # One entry per panel, left to right: (image, imshow keyword arguments).
    layout = (
        (y_true[..., 0], {"cmap": "hot"}),
        (y_pred[..., 0], {"cmap": "hot"}),
        (y_true[..., 0] - y_pred[..., 0],
         {"cmap": "seismic", "norm": CenteredNorm()}),
    )
    for panel, (image, options) in zip(panels, layout):
        mappable = panel.imshow(image, origin="lower", **options)
        # Carve a thin strip off the right of the panel for its colorbar.
        bar_axis = make_axes_locatable(panel).append_axes(
            "right", size="5%", pad=0.05)
        plt.colorbar(mappable, cax=bar_axis)
        panel.axis("off")

    for panel, heading in zip(panels,
                              ("Ground Truth", "Prediction", "Residual")):
        panel.set_title(heading, size=15)
    return figure
コード例 #3
0
ファイル: utils.py プロジェクト: AlexandreAdam/Censai
def lens_residual_plot(lens_true, lens_pred, title=""):
    """Draw a lens image next to its prediction and their difference.

    Three panels are laid out side by side, each with its own colorbar:
    channel 0 of ``lens_true``, channel 0 of ``lens_pred`` and channel 0 of
    ``lens_true - lens_pred`` on a zero-centered diverging scale.

    Args:
        lens_true: Reference lens image, indexed as ``lens_true[..., 0]``.
        lens_pred: Predicted lens image, indexed as ``lens_pred[..., 0]``.
        title: Text used as the figure-level title (empty by default).

    Returns:
        The matplotlib figure containing the three panels.
    """
    figure, panels = plt.subplots(1, 3, figsize=(12, 4))

    # One entry per panel, left to right: (image, imshow keyword arguments).
    layout = (
        (lens_true[..., 0], {"cmap": "hot"}),
        (lens_pred[..., 0], {"cmap": "hot"}),
        ((lens_true - lens_pred)[..., 0],
         {"cmap": "seismic", "norm": CenteredNorm()}),
    )
    for panel, (image, options) in zip(panels, layout):
        mappable = panel.imshow(image, origin="lower", **options)
        # Carve a thin strip off the right of the panel for its colorbar.
        bar_axis = make_axes_locatable(panel).append_axes(
            "right", size="5%", pad=0.05)
        plt.colorbar(mappable, cax=bar_axis)
        panel.axis("off")

    figure.suptitle(f"{title}", size=20)
    for panel, heading in zip(panels,
                              ("Ground Truth", "Predictions", "Residuals")):
        panel.set_title(heading, size=15)
    plt.subplots_adjust(wspace=.2, hspace=.2)
    return figure
コード例 #4
0
ファイル: utils.py プロジェクト: AlexandreAdam/Censai
def deflection_angles_residual_plot(y_true, y_pred):
    """Plot both deflection-angle components with their signed residuals.

    The figure is a 2x3 grid. Row ``i`` uses component ``i`` of the last
    axis (labelled alpha_x for row 0 and alpha_y for row 1); the columns are
    ground truth, prediction and residual, each with its own colorbar.

    The residual is the signed difference ``y_true - y_pred``. It is drawn
    with a diverging colormap centered on zero, so the sign has to be kept:
    taking the absolute value would leave half of the colormap unused and
    hide whether the prediction over- or under-shoots.

    Args:
        y_true: Reference deflection angles, indexed as ``y_true[..., i]``
            for ``i`` in ``(0, 1)``.
        y_pred: Predicted deflection angles, indexed like ``y_true``.

    Returns:
        The matplotlib figure containing the six panels.
    """
    fig, axs = plt.subplots(2, 3, figsize=(12, 8))
    for i in range(2):
        truth = y_true[..., i]
        prediction = y_pred[..., i]
        # One entry per column: (image, imshow keyword arguments).
        columns = (
            (truth, {"cmap": "jet"}),
            (prediction, {"cmap": "jet"}),
            (truth - prediction,
             {"cmap": "seismic", "norm": CenteredNorm()}),
        )
        for ax, (image, options) in zip(axs[i], columns):
            im = ax.imshow(image, origin="lower", **options)
            # Carve a thin strip off the right of the panel for its colorbar.
            cax = make_axes_locatable(ax).append_axes(
                "right", size="5%", pad=0.05)
            fig.colorbar(im, cax=cax)
            ax.axis("off")

    axs[0, 0].set_title("Ground Truth")
    axs[0, 1].set_title("Prediction")
    axs[0, 2].set_title("Residual")
    # Use the figure's own methods rather than the pyplot "current figure"
    # so the annotations always land on the figure created above.
    fig.subplots_adjust(wspace=.2, hspace=.2)
    for height, row_label in ((0.7, r"$\alpha_x$"), (0.3, r"$\alpha_y$")):
        fig.text(0.1,
                 height,
                 row_label,
                 va="center",
                 ha="center",
                 size=15,
                 rotation=90)
    return fig
コード例 #5
0
def residual_plot(dataset, rim, N):
    """Plot selected RIM outputs, the ground truth and the final residual.

    One row is drawn per example ``dataset[j]`` for ``j`` in ``range(N)``.
    Columns 0-2 show entries 0, 1 and -1 of the model output (titled as
    steps 1, 2 and ``rim.steps``), column 3 shows the ground truth and
    column 4 shows ``out[-1] - y`` on a zero-centered diverging scale with a
    colorbar.

    Args:
        dataset: Indexable object; ``dataset[j]`` must unpack into
            ``(X, Y, noise_std)``.
        rim: Model object providing ``steps``, ``log_floor``,
            ``inverse_link_function`` and ``call(X, noise_std)``; ``call``
            must return ``(out, chi_squared)``.
        N: Number of examples (rows) to draw.

    Returns:
        The matplotlib figure containing the ``N`` x 5 grid.
    """
    # squeeze=False keeps ``axs`` two-dimensional when N == 1; with the
    # default squeezing ``axs[j, col]`` would raise for a single row.
    fig, axs = plt.subplots(N, 5, figsize=(6 * 4, N * 4), squeeze=False)
    step_indices = [0, 1, -1]
    step_labels = [1, 2, rim.steps]
    # Lower color limit shared by every non-residual panel.
    # NOTE(review): presumably the images are in log10 space -- confirm.
    vmin = np.log10(rim.log_floor)
    for j in range(N):
        X, Y, noise_std = dataset[j]
        y = rim.inverse_link_function(Y[0, ..., 0])
        out, chi_squared = rim.call(X, noise_std)
        # Keep the first entry of axis 1 and channel 0 of the last axis.
        # NOTE(review): axis 0 looks like the step axis -- confirm.
        out = out[:, 0, ..., 0]
        for plot_i, (i, step) in enumerate(zip(step_indices, step_labels)):
            ax = axs[j, plot_i]
            ax.imshow(out[i],
                      cmap="hot",
                      origin="lower",
                      vmin=vmin,
                      vmax=0)
            ax.axis("off")
            chi_text = fr"$\chi^2_\nu$ = {chi_squared[i, 0]:.1e}"
            # Only the first row carries the step number in its title.
            if j == 0:
                ax.set_title(f"Step {step} \n" + chi_text)
            else:
                ax.set_title(chi_text)

        axs[j, 3].imshow(y,
                         cmap="hot",
                         origin="lower",
                         vmin=vmin,
                         vmax=0)
        axs[j, 3].axis("off")

        im = axs[j, 4].imshow(out[-1] - y,
                              cmap="seismic",
                              norm=CenteredNorm(),
                              origin="lower")
        # Carve a thin strip off the right of the panel for its colorbar.
        cax = make_axes_locatable(axs[j, 4]).append_axes(
            "right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
        axs[j, 4].axis("off")
        if j == 0:
            axs[j, 3].set_title("Ground Truth")
            axs[j, 4].set_title("Residuals")
    return fig
コード例 #6
0
ファイル: utils.py プロジェクト: AlexandreAdam/Censai
def rim_residual_plot(lens_true, source_true, kappa_true, lens_pred,
                      source_pred, kappa_pred, chi_squared):
    """Compare true and predicted lens, source and kappa maps in a 3x3 grid.

    Rows are lens, source and kappa; columns are ground truth, prediction
    and residual (true minus predicted). Channel 0 of every input is shown
    and every panel gets its own colorbar. Lens and source panels use a
    fixed 0-1 range, kappa panels use a log scale from 0.1 to 100, lens and
    source residuals use a zero-centered scale and the kappa residual uses
    a symmetric log scale between -100 and 100.

    Args:
        lens_true: Reference lens image, indexed as ``lens_true[..., 0]``.
        source_true: Reference source image, indexed the same way.
        kappa_true: Reference kappa map, indexed the same way.
        lens_pred: Predicted lens image.
        source_pred: Predicted source image.
        kappa_pred: Predicted kappa map.
        chi_squared: Value formatted with ``' .3e'`` into the figure title.

    Returns:
        The matplotlib figure containing the nine panels.
    """
    figure, grid = plt.subplots(3, 3, figsize=(12, 12))

    def draw_panel(row, col, image, **imshow_options):
        # Show one image with a colorbar attached to its right edge.
        panel = grid[row, col]
        mappable = panel.imshow(image, origin="lower", **imshow_options)
        bar_axis = make_axes_locatable(panel).append_axes(
            "right", size="5%", pad=0.05)
        plt.colorbar(mappable, cax=bar_axis)
        panel.axis("off")

    # Columns 0 and 1: ground truth, then prediction, drawn top to bottom.
    triplets = ((lens_true, source_true, kappa_true),
                (lens_pred, source_pred, kappa_pred))
    for col, (lens, source, kappa) in enumerate(triplets):
        draw_panel(0, col, lens[..., 0], cmap="hot", vmin=0, vmax=1)
        draw_panel(1, col, source[..., 0], cmap="bone", vmin=0, vmax=1)
        draw_panel(2, col, kappa[..., 0],
                   cmap="hot",
                   norm=LogNorm(vmin=1e-1, vmax=100))

    # Column 2: residuals on diverging scales.
    draw_panel(0, 2, (lens_true - lens_pred)[..., 0],
               cmap="seismic",
               norm=CenteredNorm())
    draw_panel(1, 2, (source_true - source_pred)[..., 0],
               cmap="seismic",
               norm=CenteredNorm())
    draw_panel(2, 2, (kappa_true - kappa_pred)[..., 0],
               cmap="seismic",
               norm=SymLogNorm(linthresh=1e-1,
                               base=10,
                               vmax=100,
                               vmin=-100))

    headings = ("Ground Truth", "Predictions", "Residuals")
    for col, heading in enumerate(headings):
        grid[0, col].set_title(heading, size=15)
    figure.suptitle(fr"$\chi^2$ = {chi_squared: .3e}", size=20)
    plt.subplots_adjust(wspace=.4, hspace=.2)

    # Row labels, written vertically along the left edge of the figure.
    row_labels = ((0.75, r"Lens"), (0.5, r"Source"), (0.22, r"$\kappa$"))
    for height, text in row_labels:
        plt.figtext(0.1,
                    height,
                    text,
                    va="center",
                    ha="center",
                    size=15,
                    rotation=90)

    return figure