Exemplo n.º 1
0
    def make_scatter_contours_plot(self, meas):
        """Plot reconstructed magnitudes against true magnitudes.

        Args:
            meas: Mapping providing at least ``"true_mags"`` (x axis) and
                ``"recon_mags"`` (y axis).

        Returns:
            The matplotlib figure. Both axes span magnitudes 15-24, and a red
            ``y = x`` line marks where reconstructed equals true magnitude.
        """
        sns.set_theme(style="darkgrid")
        set_rc_params(
            fontsize=22,
            legend_fontsize="small",
            tick_label_size="small",
            label_size="medium",
        )
        fig, ax = plt.subplots(1, 1, figsize=(8, 8))

        lo, hi = 15, 24
        ticks = tuple(range(lo, hi))  # 15, 16, ..., 23
        true_mags = meas["true_mags"]
        recon_mags = meas["recon_mags"]
        make_scatter_contours(
            ax,
            true_mags,
            recon_mags,
            xlabel=r"$m^{\rm true}$",
            ylabel=r"$m^{\rm recon}$",
            xticks=ticks,
            yticks=ticks,
            xlims=(lo, hi),
            ylims=(lo, hi),
        )

        # Identity line: points on it have m_recon == m_true.
        ax.plot([lo, hi], [lo, hi], color="r", lw=2)
        plt.tight_layout()
        return fig
Exemplo n.º 2
0
    def make_magnitude_prob_scatter_figure(data):
        """Scatter true magnitude against the estimated galaxy probability.

        Each object is coloured by whether its estimated class (``egbool``)
        agrees with its true class (``tgbool``): agreeing objects are blue
        ``+`` markers ("correct"), disagreeing ones red ``x`` markers
        ("incorrect"). Dashed horizontal guides are drawn at probabilities
        0.5, 0.1 and 0.9.

        Args:
            data: Mapping with array-valued keys ``"tgbool"``, ``"egbool"``,
                ``"tmag"`` and ``"egprob"``; the bool arrays are cast with
                ``.astype(bool)`` and used as masks into the other two.

        Returns:
            The matplotlib figure.
        """
        set_rc_params(tick_label_size=22, label_size=30)
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))

        true_is_galaxy = data["tgbool"].astype(bool)
        est_is_galaxy = data["egbool"].astype(bool)
        mags = data["tmag"]
        galaxy_probs = data["egprob"]
        agrees = np.equal(true_is_galaxy, est_is_galaxy)

        groups = (
            (agrees, "+", "b", "correct"),
            (~agrees, "x", "r", "incorrect"),
        )
        for mask, marker, color, label in groups:
            ax.scatter(mags[mask],
                       galaxy_probs[mask],
                       marker=marker,
                       c=color,
                       label=label,
                       alpha=0.5)

        for level in (0.5, 0.1, 0.9):
            ax.axhline(level, linestyle="--")

        ax.set_xlabel("True Magnitude")
        ax.set_ylabel("Estimated Probability of Galaxy")
        ax.legend(loc="best", prop={"size": 22})

        return fig
Exemplo n.º 3
0
    def reconstruction_figure(self, images, recons, residuals):
        """Plot images, reconstructions and residuals, one example per row.

        The figure has ``self.n_examples`` rows and three columns: the image,
        its reconstruction, and the residual. Within a row, the image and the
        reconstruction share a colour range (min/max over both). All residual
        panels share one colour range taken over the whole ``residuals``
        batch, so residuals are comparable across rows.

        Args:
            images: Batch of images; ``images[i, 0]`` is the i-th example.
            recons: Batch of reconstructions, indexed like ``images``.
            residuals: Batch of residuals, indexed like ``images``.

        Returns:
            The matplotlib figure.

        Raises:
            AssertionError: if the first dimension of any input differs from
                ``self.n_examples``, or the second (band) dimension is not 1.
        """
        # Validate before touching global rc state or creating a figure, so a
        # failed check does not leave an orphaned figure registered in pyplot.
        assert images.shape[0] == recons.shape[0] == residuals.shape[
            0] == self.n_examples
        assert images.shape[1] == recons.shape[1] == residuals.shape[
            1] == 1, "1 band only."

        pad = 6.0
        set_rc_params(fontsize=22,
                      tick_label_size="small",
                      legend_fontsize="small")
        # squeeze=False keeps `axes` 2-D even when n_examples == 1; with the
        # default squeeze=True a single row comes back as a 1-D array and
        # `axes[i, 0]`-style indexing raises IndexError.
        fig, axes = plt.subplots(nrows=self.n_examples,
                                 ncols=3,
                                 figsize=(12, 20),
                                 squeeze=False)

        # One residual colour range shared by every row.
        vmin_res = residuals.min().item()
        vmax_res = residuals.max().item()

        for i in range(self.n_examples):
            ax_true, ax_recon, ax_res = axes[i]

            # Only the top row gets column titles.
            if i == 0:
                ax_true.set_title("Images $x$", pad=pad)
                ax_recon.set_title(r"Reconstruction $\tilde{x}$", pad=pad)
                ax_res.set_title(
                    r"Residual $\left(x - \tilde{x}\right) / \sqrt{\tilde{x}}$",
                    pad=pad)

            image = images[i, 0]
            recon = recons[i, 0]
            residual = residuals[i, 0]

            # Standardize the colour range of the image and its reconstruction.
            vmin = min(image.min().item(), recon.min().item())
            vmax = max(image.max().item(), recon.max().item())

            reporting.plot_image(fig, ax_true, image, vrange=(vmin, vmax))
            reporting.plot_image(fig, ax_recon, recon, vrange=(vmin, vmax))
            reporting.plot_image(fig,
                                 ax_res,
                                 residual,
                                 vrange=(vmin_res, vmax_res))

        plt.subplots_adjust(hspace=-0.4)
        plt.tight_layout()

        return fig
Exemplo n.º 4
0
    def make_classification_figure(
            mags,
            data,
            cuts_or_bins="cuts",
            xlims=(18, 24),
            ylims=(0.5, 1.05),
            ratio=2,
            where_step="mid",
            n_gap=50,
    ):
        """Two-panel figure of classification accuracy versus magnitude.

        Bottom panel: galaxy, star and overall classification accuracy
        plotted against ``mags``. Top panel (height ratio ``1 : ratio``,
        shared x axis): step histograms of matched coadd galaxy and star
        counts.

        Args:
            mags: Magnitude values used as the x coordinate of every series.
            data: Mapping with keys ``"class_acc"``, ``"galaxy_acc"``,
                ``"star_acc"``, ``"n_matches_coadd_gal"`` and
                ``"n_matches_coadd_star"``.
            cuts_or_bins: Plural word used in the x label; its trailing
                character is dropped ("cuts" -> "cut", "bins" -> "bin").
            xlims: x limits of the accuracy panel.
            ylims: y limits of the accuracy panel.
            ratio: Height of the accuracy panel relative to the count panel.
            where_step: ``where`` argument forwarded to ``Axes.step``.
            n_gap: Tick spacing of the count axis; its upper limit is the
                largest count rounded up to a multiple of ``n_gap``.

        Returns:
            The matplotlib figure.
        """
        class_acc, galaxy_acc, star_acc = (data["class_acc"],
                                           data["galaxy_acc"],
                                           data["star_acc"])
        set_rc_params(tick_label_size=22, label_size=30)
        fig, (ax_counts, ax_acc) = plt.subplots(
            2,
            1,
            figsize=(10, 10),
            gridspec_kw={"height_ratios": [1, ratio]},
            sharex=True,
        )

        # Bottom panel: accuracy curves.
        format_plot(ax_acc,
                    xlabel=r"\rm magnitude " + cuts_or_bins[:-1],
                    ylabel="classification accuracy")
        accuracy_series = (
            (galaxy_acc, r"\rm galaxy"),
            (star_acc, r"\rm star"),
            (class_acc, r"\rm overall"),
        )
        for values, label in accuracy_series:
            ax_acc.plot(mags, values, "-o", label=label)
        ax_acc.set_xlim(xlims)
        ax_acc.set_ylim(ylims)
        ax_acc.legend(loc="lower left", prop={"size": 18})

        # Top panel: counts of matched coadd objects.
        gal_counts = data["n_matches_coadd_gal"]
        star_counts = data["n_matches_coadd_star"]
        count_series = (
            (gal_counts, r"\rm matched coadd galaxies"),
            (star_counts, r"\rm matched coadd stars"),
        )
        for counts, label in count_series:
            ax_counts.step(mags, counts, label=label, where=where_step)

        top = np.ceil(max(max(gal_counts), max(star_counts)) / n_gap) * n_gap
        format_plot(ax_counts,
                    yticks=np.arange(0, top, n_gap),
                    ylabel=r"\rm Counts")
        ax_counts.legend(loc="best", prop={"size": 16})
        ax_counts.set_ylim((0, top))
        plt.subplots_adjust(hspace=0)

        return fig
Exemplo n.º 5
0
def make_detection_figure(
        mags,
        data,
        xlims=(18, 24),
        ylims=(0.5, 1.05),
        ratio=2,
        where_step="mid",
        n_gap=50,
):
    """Two-panel figure of detection metrics versus magnitude cut.

    Bottom panel: recall, precision and F1 score plotted against ``mags``.
    Top panel (height ratio ``1 : ratio``, shared x axis): step histograms
    of true (solid) and predicted (dashed) galaxy and star counts. Galaxies
    use ``CB_color_cycle[3]`` and stars ``CB_color_cycle[4]``.

    Args:
        mags: Magnitude values used as the x coordinate of every series.
        data: Mapping with keys ``"precision"``, ``"recall"``, ``"f1"``,
            ``"tgcount"``, ``"tscount"``, ``"egcount"`` and ``"escount"``.
        xlims: x limits of the metric panel.
        ylims: y limits of the metric panel.
        ratio: Height of the metric panel relative to the count panel.
        where_step: ``where`` argument forwarded to ``Axes.step``.
        n_gap: Tick spacing of the count axis; its upper limit is the
            largest count rounded up to a multiple of ``n_gap``.

    Returns:
        The matplotlib figure.
    """
    precision, recall, f1_score = data["precision"], data["recall"], data["f1"]
    counts = {
        key: data[key]
        for key in ("tgcount", "tscount", "egcount", "escount")
    }

    set_rc_params(tick_label_size=22, label_size=30)
    fig, (ax_counts, ax_metric) = plt.subplots(
        2,
        1,
        figsize=(10, 10),
        gridspec_kw={"height_ratios": [1, ratio]},
        sharex=True,
    )

    # Bottom panel: precision / recall / F1.
    lowest = min(min(precision), min(recall))
    format_plot(ax_metric,
                xlabel=r"\rm magnitude cut",
                ylabel="metric",
                yticks=np.arange(np.round(lowest, 1), 1.1, 0.1))
    metric_series = (
        (recall, r"\rm recall"),
        (precision, r"\rm precision"),
        (f1_score, r"\rm f1 score"),
    )
    for values, label in metric_series:
        ax_metric.plot(mags, values, "-o", label=label)
    ax_metric.legend(loc="lower left", prop={"size": 22})
    ax_metric.set_xlim(xlims)
    ax_metric.set_ylim(ylims)

    # Top panel: true (solid) vs predicted (dashed) counts.
    galaxy_color, star_color = CB_color_cycle[3], CB_color_cycle[4]
    count_series = (
        ("tgcount", "coadd galaxies", galaxy_color, None),
        ("tscount", "coadd stars", star_color, None),
        ("egcount", "pred. galaxies", galaxy_color, "--"),
        ("escount", "pred. stars", star_color, "--"),
    )
    for key, label, color, linestyle in count_series:
        extra = {} if linestyle is None else {"ls": linestyle}
        ax_counts.step(mags,
                       counts[key],
                       label=label,
                       where=where_step,
                       color=color,
                       **extra)

    top = max(max(values) for values in counts.values())
    top = np.ceil(top / n_gap) * n_gap
    ax_counts.set_ylim((0, top))
    format_plot(ax_counts,
                yticks=np.arange(0, top, n_gap),
                ylabel=r"\rm Counts")
    ax_counts.legend(loc="best", prop={"size": 16})
    plt.subplots_adjust(hspace=0)

    return fig
Exemplo n.º 6
0
    def create_figures(self, data):  # pylint: disable=too-many-statements
        """Make figures related to reconstruction in SDSS.

        For each scene name in ``self.scenes``, build a three-panel figure
        (original image, reconstruction, residual). Every panel is overlaid
        with the locations in ``coplocs`` ("+" markers, ``cool`` colormap)
        and ``plocs`` ("x" markers, ``bwr`` colormap). For scenes whose name
        contains ``"blend"``, each ``coplocs`` location on the original-image
        panel is annotated with the matching ``prob_n_sources`` value.

        Args:
            data: Mapping from scene name to a mapping whose eight values, in
                iteration order, are: true image, reconstruction, residual,
                coplocs, cogbools, plocs, gprobs, prob_n_sources.

        Returns:
            Dict mapping each scene name to its figure.
        """
        figures = {}

        title_pad = 6.0
        set_rc_params(fontsize=22,
                      tick_label_size="small",
                      legend_fontsize="small")
        for scene_name, coords in self.scenes.items():
            size = coords["size"]
            (true, recon, res, coplocs, cogbools, plocs, gprobs,
             prob_n_sources) = data[scene_name].values()
            assert size == true.shape[-1] == recon.shape[-1] == res.shape[-1]
            fig, (ax_true, ax_recon, ax_res) = plt.subplots(nrows=1,
                                                            ncols=3,
                                                            figsize=(28, 12))

            titled_axes = (
                (ax_true, "Original Image"),
                (ax_recon, "Reconstruction"),
                (ax_res, "Residual"),
            )
            for ax, title in titled_axes:
                ax.set_title(title, pad=title_pad)

            # Marker size and line width scale relative to a 300-pixel scene.
            marker_size = 55 * 300 / size
            plus_marker_size = marker_size * 1.5
            line_width = 2 * np.sqrt(300 / size)

            image_range = (800, 1100)
            residual_range = (-5, 5)
            legend_labels = [
                "Coadd Galaxies", "Coadd Stars", "BLISS Galaxies",
                "BLISS Stars"
            ]

            panels = (
                (ax_true, true, image_range, {}),
                (ax_recon, recon, image_range, {}),
                (ax_res, res, residual_range, {"alpha": 0.5}),
            )
            for ax, img, vrange, extra in panels:
                reporting.plot_image(fig, ax, img, vrange)
                reporting.plot_locs(ax, 0, size, coplocs, cogbools, "+",
                                    plus_marker_size, line_width,
                                    cmap="cool", **extra)
                reporting.plot_locs(ax, 0, size, plocs, gprobs, "x",
                                    marker_size, line_width,
                                    cmap="bwr", **extra)
                if ax is ax_recon:
                    reporting.add_legend(ax_recon, legend_labels,
                                         s=marker_size)

            plt.subplots_adjust(hspace=-0.4)
            plt.tight_layout()

            # For blends, annotate each coadd location with its probability.
            if "blend" in scene_name:
                for idx, loc in enumerate(coplocs.reshape(-1, 2)):
                    prob = prob_n_sources[idx].item()
                    # annotate takes (x, y); loc[1] is used as x, loc[0] as y.
                    anchor = (loc[1] + 0.5, loc[0] + 0.5)
                    label = r"$\boldsymbol{" + f"{prob:.2f}" + "}$"
                    ax_true.annotate(label, anchor, color="lime")

            figures[scene_name] = fig

        return figures
Exemplo n.º 7
0
def _in_image_bounds(locs, height, width):
    """Return a mask of rows of ``locs`` lying strictly inside the image.

    Column 0 of ``locs`` is checked against ``(0, height)`` and column 1
    against ``(0, width)``, matching how the callers scatter ``locs[:, 1]``
    on the x axis and ``locs[:, 0]`` on the y axis.
    """
    return ((locs[:, 0] > 0) & (locs[:, 0] < height) & (locs[:, 1] > 0) &
            (locs[:, 1] < width))


def create_figure(
    true,
    recon,
    res,
    coadd_objects: Optional[FullCatalog] = None,
    map_recon: Optional[FullCatalog] = None,
    include_residuals: bool = True,
    colorbar=True,
    scatter_size: int = 100,
    scatter_on_true: bool = True,
    tile_map=None,
    vmin=800,
    vmax=1200,
):
    """Make figures related to detection and classification in SDSS.

    Panels, left to right:

    * the original image ``true``;
    * the reconstruction ``recon``, or, when ``tile_map`` is given, a map of
      ``tile_map.is_on_array`` with each tile upsampled to 4x4 pixels and
      tile borders drawn in purple;
    * optionally (``include_residuals``) the residual ``res``.

    True locations from ``coadd_objects`` are drawn as "+" markers and
    predicted locations from ``map_recon`` as "x" markers. Predicted
    locations outside the image are dropped before plotting.

    Args:
        true: 2-D original image.
        recon: 2-D reconstruction.
        res: 2-D residual image.
        coadd_objects: Optional catalog of true objects. ``plocs`` is used
            for locations, ``"galaxy_bools"`` to split galaxies from stars,
            and an optional ``"mismatched"`` entry to mark unmatched objects.
        map_recon: Optional catalog of predictions. Batch element 0 of
            ``plocs``, ``"star_bools"`` and ``"galaxy_bools"`` is used.
            Objects that are neither star nor galaxy are drawn in white.
        include_residuals: Whether to add the residual panel.
        colorbar: Forwarded to ``reporting.plot_image``.
        scatter_size: Marker size for all scatter overlays.
        scatter_on_true: Whether to also overlay markers on the original
            image panel.
        tile_map: Optional object exposing ``is_on_array``. It is rearranged
            with the pattern ``"1 nth ntw 1 -> nth ntw 1 1"`` and shown in
            place of the reconstruction.
        vmin: Lower colour limit for the image/reconstruction panels.
        vmax: Upper colour limit for the image/reconstruction panels.

    Returns:
        The matplotlib figure.

    Raises:
        AssertionError: if ``true``, ``recon`` or ``res`` is not 2-D.
    """
    # Validate before changing global style state or creating a figure, so a
    # failed check does not leave an orphaned figure registered in pyplot.
    assert len(true.shape) == len(recon.shape) == len(res.shape) == 2

    # matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*"
    # and 3.8 removed the old names, so prefer the new name and fall back
    # to the old one on older matplotlib.
    try:
        plt.style.use("seaborn-v0_8-colorblind")
    except OSError:
        plt.style.use("seaborn-colorblind")

    true_gal_col = "m"
    true_star_col = "b"
    pred_gal_col = "c"
    pred_star_col = "r"
    pad = 6.0
    set_rc_params(fontsize=22,
                  tick_label_size="small",
                  legend_fontsize="small")
    ncols = 2 + include_residuals
    figsize = (20 + 10 * include_residuals, 12)
    fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=figsize)

    # Image extent, used to drop predicted locations that fall outside it.
    # Rows (locs[:, 0], y axis) are bounded by the height and columns
    # (locs[:, 1], x axis) by the width, so non-square images work too.
    height, width = true.shape[-2], true.shape[-1]

    ax_true = axes[0]
    ax_recon = axes[1]

    ax_true.set_title("Original Image", pad=pad)
    ax_recon.set_title("Reconstruction", pad=pad)

    # plot images
    reporting.plot_image(fig,
                         ax_true,
                         true,
                         vrange=(vmin, vmax),
                         colorbar=colorbar,
                         cmap="gist_gray")
    if tile_map is None:
        reporting.plot_image(fig,
                             ax_recon,
                             recon,
                             vrange=(vmin, vmax),
                             colorbar=colorbar,
                             cmap="gist_gray")
    else:
        # Upsample each tile's on/off flag to a 4x4 block of pixels.
        is_on_array = rearrange(tile_map.is_on_array,
                                "1 nth ntw 1 -> nth ntw 1 1")
        is_on_array = repeat(is_on_array,
                             "nth ntw 1 1 -> nth ntw h w",
                             h=4,
                             w=4)
        is_on_array = rearrange(is_on_array, "nth ntw h w -> (nth h) (ntw w)")
        ax_recon.matshow(is_on_array, vmin=0, vmax=1, cmap="gist_gray")
        # Draw tile borders every 4 pixels.
        for grid in range(3, is_on_array.shape[-1], 4):
            ax_recon.axvline(grid + 0.5, color="purple")
            ax_recon.axhline(grid + 0.5, color="purple")

    if include_residuals:
        ax_res = axes[2]
        ax_res.set_title("Residual", pad=pad)
        vmin_res, vmax_res = -6.0, 6.0
        reporting.plot_image(fig, ax_res, res, vrange=(vmin_res, vmax_res))

    if coadd_objects is not None:
        # The 0.5 shift is applied to the catalog locations before plotting.
        locs_true = coadd_objects.plocs - 0.5
        true_galaxy_bools = coadd_objects["galaxy_bools"]
        locs_galaxies_true = locs_true[true_galaxy_bools.squeeze(-1) > 0.5]
        locs_stars_true = locs_true[true_galaxy_bools.squeeze(-1) < 0.5]

        mismatched = coadd_objects.get("mismatched")
        if mismatched is not None:
            locs_mismatched_true = locs_true[mismatched.squeeze(-1) > 0.5]
        else:
            locs_mismatched_true = None
        if locs_galaxies_true.shape[0] > 0:
            if scatter_on_true:
                ax_true.scatter(
                    locs_galaxies_true[:, 1],
                    locs_galaxies_true[:, 0],
                    color=true_gal_col,
                    marker="+",
                    s=scatter_size,
                    label="COADD Galaxies",
                )
            ax_recon.scatter(
                locs_galaxies_true[:, 1],
                locs_galaxies_true[:, 0],
                color=true_gal_col,
                marker="+",
                s=scatter_size,
                label="SDSS Galaxies",
            )
        if locs_stars_true.shape[0] > 0:
            if scatter_on_true:
                ax_true.scatter(
                    locs_stars_true[:, 1],
                    locs_stars_true[:, 0],
                    color=true_star_col,
                    marker="+",
                    s=scatter_size,
                    label="COADD Stars",
                )
            ax_recon.scatter(
                locs_stars_true[:, 1],
                locs_stars_true[:, 0],
                color=true_star_col,
                marker="+",
                s=scatter_size,
                label="SDSS Stars",
            )
        if locs_mismatched_true is not None:
            if scatter_on_true:
                ax_true.scatter(
                    locs_mismatched_true[:, 1],
                    locs_mismatched_true[:, 0],
                    color="orange",
                    marker="+",
                    s=scatter_size,
                    label="Unmatched",
                )
            ax_recon.scatter(
                locs_mismatched_true[:, 1],
                locs_mismatched_true[:, 0],
                color="orange",
                marker="+",
                s=scatter_size,
                label="Unmatched",
            )

    if map_recon is not None:
        locs_pred = map_recon.plocs[0] - 0.5
        star_bools = map_recon["star_bools"][0, :, 0] > 0.5
        galaxy_bools = map_recon["galaxy_bools"][0, :, 0] > 0.5
        locs_galaxies = locs_pred[galaxy_bools, :]
        locs_stars = locs_pred[star_bools, :]
        # Predictions flagged as neither star nor galaxy.
        locs_extra = locs_pred[(~galaxy_bools) & (~star_bools), :]
        if locs_galaxies.shape[0] > 0:
            locs_galaxies = locs_galaxies[_in_image_bounds(
                locs_galaxies, height, width)]
            if scatter_on_true:
                ax_true.scatter(
                    locs_galaxies[:, 1],
                    locs_galaxies[:, 0],
                    color=pred_gal_col,
                    marker="x",
                    s=scatter_size,
                    label="Predicted Galaxy",
                )
            ax_recon.scatter(
                locs_galaxies[:, 1],
                locs_galaxies[:, 0],
                color=pred_gal_col,
                marker="x",
                s=scatter_size,
                label="Predicted Galaxy",
                alpha=0.6,
            )
            if include_residuals:
                ax_res.scatter(locs_galaxies[:, 1],
                               locs_galaxies[:, 0],
                               color=pred_gal_col,
                               marker="x",
                               s=scatter_size)
        if locs_stars.shape[0] > 0:
            locs_stars = locs_stars[_in_image_bounds(locs_stars, height,
                                                     width)]
            if scatter_on_true:
                ax_true.scatter(
                    locs_stars[:, 1],
                    locs_stars[:, 0],
                    color=pred_star_col,
                    marker="x",
                    s=scatter_size,
                    alpha=0.6,
                    label="Predicted Star",
                )
            ax_recon.scatter(
                locs_stars[:, 1],
                locs_stars[:, 0],
                color=pred_star_col,
                marker="x",
                s=scatter_size,
                label="Predicted Star",
                alpha=0.6,
            )
            if include_residuals:
                ax_res.scatter(locs_stars[:, 1],
                               locs_stars[:, 0],
                               color=pred_star_col,
                               marker="x",
                               s=scatter_size)

        if locs_extra.shape[0] > 0:
            locs_extra = locs_extra[_in_image_bounds(locs_extra, height,
                                                     width)]
            if scatter_on_true:
                ax_true.scatter(
                    locs_extra[:, 1],
                    locs_extra[:, 0],
                    color="w",
                    marker="x",
                    s=scatter_size,
                    alpha=0.6,
                    label="Predicted Object (below 0.5)",
                )
            ax_recon.scatter(
                locs_extra[:, 1],
                locs_extra[:, 0],
                color="w",
                marker="x",
                s=scatter_size,
                label="Predicted Object (below 0.5)",
                alpha=0.6,
            )
            if include_residuals:
                ax_res.scatter(locs_extra[:, 1],
                               locs_extra[:, 0],
                               color="w",
                               marker="x",
                               s=scatter_size)

    # Single legend below the reconstruction panel.
    ax_recon.legend(
        bbox_to_anchor=(0.0, -0.1, 1.0, 0.5),
        loc="lower left",
        ncol=4,
        mode="expand",
        borderaxespad=0.0,
    )
    plt.subplots_adjust(hspace=-0.4)
    plt.tight_layout()

    return fig
Exemplo n.º 8
0
    def make_scatter_bin_plots(self, meas):
        """Plot measurement errors against log10 SNR in a 2x2 grid.

        Panels, in ``axes.flatten()`` order:

        1. magnitude error ``recon_mags - true_mags``;
        2. relative flux error ``(recon_fluxes - true_fluxes) / recon_fluxes``;
        3. first ellipticity component error (column 0 of the ellipticities);
        4. second ellipticity component error (column 1).

        Every panel shares the x axis ``log10(meas["snr"])`` over (0, 3) and
        is drawn by ``scatter_bin_plot`` with ``delta=0.2``.

        Args:
            meas: Mapping with keys ``"snr"``, ``"true_mags"``,
                ``"recon_mags"``, ``"true_fluxes"``, ``"recon_fluxes"``,
                ``"true_ellips"`` and ``"recon_ellips"``. The ellipticity
                arrays are indexed as ``[:, 0]`` and ``[:, 1]``.

        Returns:
            The matplotlib figure.
        """
        # Apply rc params before creating the figure, as the other figure
        # builders do, so settings read at figure/axes creation take effect.
        set_rc_params(fontsize=24)
        fig, axes = plt.subplots(2, 2, figsize=(16, 16))
        ax1, ax2, ax3, ax4 = axes.flatten()

        xticks = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
        xlims = (0, 3)
        xlabel = r"$\log_{10} \text{SNR}$"
        # The x coordinate is the same for every panel; compute it once.
        log_snr = np.log10(meas["snr"])

        # magnitudes
        true_mags, recon_mags = meas["true_mags"], meas["recon_mags"]
        scatter_bin_plot(
            ax1,
            log_snr,
            recon_mags - true_mags,
            delta=0.2,
            xlims=xlims,
            xlabel=xlabel,
            ylabel=r"\rm $m^{\rm recon} - m^{\rm true}$",
            xticks=xticks,
        )

        # fluxes (relative error, normalised by the reconstructed flux)
        true_fluxes, recon_fluxes = meas["true_fluxes"], meas["recon_fluxes"]
        scatter_bin_plot(
            ax2,
            log_snr,
            (recon_fluxes - true_fluxes) / recon_fluxes,
            delta=0.2,
            xlims=xlims,
            xlabel=xlabel,
            ylabel=r"\rm $(f^{\rm recon} - f^{\rm true}) / f^{\rm recon}$",
            xticks=xticks,
        )

        # ellipticities, one panel per component
        true_ellips, recon_ellips = meas["true_ellips"], meas["recon_ellips"]
        ellip_panels = (
            (ax3, 0, r"$g_{1}^{\rm recon} - g_{1}^{\rm true}$"),
            (ax4, 1, r"$g_{2}^{\rm recon} - g_{2}^{\rm true}$"),
        )
        for ax, component, ylabel in ellip_panels:
            scatter_bin_plot(
                ax,
                log_snr,
                recon_ellips[:, component] - true_ellips[:, component],
                delta=0.2,
                xlims=xlims,
                xticks=xticks,
                xlabel=xlabel,
                ylabel=ylabel,
            )

        plt.tight_layout()

        return fig