Пример #1
0
def visualize(batch, path, n_samples, figsize=(12, 12)):
    """Save a square grid showing band 0 of the first ``n_samples`` batch images.

    Args:
        batch: mapping with an ``"images"`` entry holding a 4-D tensor; element
            ``images[i][0]`` is drawn for each sample (presumably layout
            (n, bands, h, w) — TODO confirm).
        path: destination passed to ``Figure.savefig``.
        n_samples: number of images to draw; must be a positive perfect square so
            the grid is ``sqrt(n_samples) x sqrt(n_samples)``.
        figsize: matplotlib figure size in inches.
    """
    # Integer check avoids relying on float `sqrt(x) % 1` being exactly 0.
    nrows = int(round(math.sqrt(n_samples)))
    assert n_samples > 0 and nrows * nrows == n_samples, "n_samples must be a perfect square"

    images = batch["images"]
    assert len(images.shape) == 4
    assert len(images) >= n_samples, "batch holds fewer than n_samples images"

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid (n_samples == 1).
    fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=figsize, squeeze=False)
    for ax, image in zip(axes.flat, images[:n_samples]):
        # first band of the image, in numpy format.
        plot_image(fig, ax, image[0].cpu().numpy())

    fig.tight_layout()
    fig.savefig(path, bbox_inches="tight")
    # The figure is not returned, so release it instead of leaking it in pyplot.
    plt.close(fig)
Пример #2
0
    def make_plots(self, batch, n_samples=16):
        """Produce informative plots demonstrating encoder performance.

        Draws a square grid of the first ``n_samples`` images (band 0), overlaying
        true locations ("+", colored by true galaxy_bools) and estimated locations
        ("x", colored by predicted galaxy probabilities). The figure is logged
        through ``self.logger`` when one is set, then closed.

        Args:
            batch: dict holding the true tile entries ("locs", "n_sources",
                "galaxy_bools", "star_bools"), the predictions
                ("pred_galaxy_bools", "pred_star_bools", "pred_galaxy_probs")
                and "images".
            n_samples: number of images to plot; must be a positive perfect
                square. Nothing is drawn if the batch has fewer entries.
        """
        nrows = int(round(n_samples**0.5))  # for figure
        assert n_samples > 0 and nrows * nrows == n_samples, "n_samples must be a perfect square"
        if n_samples > len(batch["n_sources"]):  # do nothing if low on samples.
            return

        # extract non-params entries so that 'to_full_params' works.
        true_param_names = {"locs", "n_sources", "galaxy_bools", "star_bools"}
        true_tile_params = TileCatalog(
            self.tile_slen, {k: v for k, v in batch.items() if k in true_param_names}
        )
        true_params = true_tile_params.to_full_params()

        # prediction: true locations with the predicted classification.
        tile_est = true_tile_params.copy()
        shape = tile_est["galaxy_bools"].shape
        tile_est["galaxy_bools"] = batch["pred_galaxy_bools"].reshape(*shape)
        tile_est["star_bools"] = batch["pred_star_bools"].reshape(*shape)
        tile_est["galaxy_probs"] = batch["pred_galaxy_probs"].reshape(*shape)
        est = tile_est.to_full_params()

        # setup figure and axes; squeeze=False keeps `axes` 2-D even for a 1x1 grid.
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12), squeeze=False)
        axes = axes.flatten()

        bp = self.border_padding
        labels = ("t. gal", "p. gal", "t. star", "p. star")
        for ii in range(n_samples):
            ax = axes[ii]
            image = batch["images"][ii, 0].cpu().numpy()
            true_plocs = true_params.plocs[ii].cpu().numpy().reshape(-1, 2)
            true_gbools = true_params["galaxy_bools"][ii].cpu().numpy().reshape(-1)
            est_plocs = est.plocs[ii].cpu().numpy().reshape(-1, 2)
            est_gprobs = est["galaxy_probs"][ii].cpu().numpy().reshape(-1)
            slen, _ = image.shape
            reporting.plot_image(fig, ax, image, colorbar=False)
            reporting.plot_locs(ax, bp, slen, true_plocs, true_gbools, m="+", s=30, cmap="cool")
            reporting.plot_locs(ax, bp, slen, est_plocs, est_gprobs, m="x", s=20, cmap="bwr")
            if ii == 0:
                # only the first panel carries the legend.
                reporting.add_legend(ax, labels)
        fig.tight_layout()

        title = f"Epoch:{self.current_epoch}/Validation Images"
        if self.logger is not None:
            self.logger.experiment.add_figure(title, fig)
        # Close unconditionally so the figure does not leak when no logger is set.
        plt.close(fig)
Пример #3
0
    def plot_reconstruction(self,
                            images,
                            recon_mean,
                            residuals,
                            n_examples=10,
                            mode="random",
                            width=10,
                            pad=6.0):
        """Plot image / reconstruction / residual triplets for a few examples.

        Args:
            images, recon_mean, residuals: 4-D tensors with a single band
                (``shape[1] == 1``); ``residuals`` is summed over dims 1-3 in
                "worst" mode.
            n_examples: number of rows to draw; must not exceed ``images.size(0)``.
            mode: "random" picks ``n_examples`` distinct random images; "worst"
                picks those with the largest summed absolute residual, ordered
                from least to most extreme.
            width: figure width in inches; the height scales with ``n_examples``.
            pad: title padding passed to ``set_title``.

        Returns:
            The matplotlib figure.

        Raises:
            NotImplementedError: if ``mode`` is not "random" or "worst".
        """
        assert images.size(0) >= n_examples
        assert images.shape[1] == recon_mean.shape[1] == residuals.shape[
            1] == 1, "1 band only."

        # Choose examples before allocating the figure so a bad mode does not
        # leave an orphaned figure behind.
        if mode == "random":
            # sample without replacement so no example is shown twice.
            indices = torch.randperm(len(images))[:n_examples]
        elif mode == "worst":
            # get indices where absolute residual is the largest.
            absolute_residual = residuals.abs().sum(dim=(1, 2, 3))
            indices = absolute_residual.argsort()[-n_examples:]
        else:
            raise NotImplementedError(
                f"Specified mode '{mode}' has not been implemented.")

        figsize = (width, width * n_examples / 3)
        # squeeze=False keeps `axes` 2-D even when n_examples == 1.
        fig, axes = plt.subplots(nrows=n_examples,
                                 ncols=3,
                                 figsize=figsize,
                                 squeeze=False)

        # one residual range shared by every selected example.
        vmin_res = residuals[indices].min().item()
        vmax_res = residuals[indices].max().item()

        for i, idx in enumerate(indices):
            ax_true, ax_recon, ax_res = axes[i]

            # only add titles to the first axes.
            if i == 0:
                ax_true.set_title("Images $x$", pad=pad)
                ax_recon.set_title(r"Reconstruction $\tilde{x}$", pad=pad)
                ax_res.set_title(
                    r"Residual $\left(x - \tilde{x}\right) / \sqrt{\tilde{x}}$",
                    pad=pad)

            # standardize ranges of true and reconstruction
            image = images[idx, 0].detach().cpu().numpy()
            recon = recon_mean[idx, 0].detach().cpu().numpy()
            residual = residuals[idx, 0].detach().cpu().numpy()
            vmin = min(image.min().item(), recon.min().item())
            vmax = max(image.max().item(), recon.max().item())

            # plot images
            plot_image(fig, ax_true, image, vrange=(vmin, vmax))
            plot_image(fig, ax_recon, recon, vrange=(vmin, vmax))
            plot_image(fig, ax_res, residual, vrange=(vmin_res, vmax_res))

        fig.tight_layout()

        return fig
Пример #4
0
    def reconstruction_figure(self, images, recons, residuals):
        """Plot image / reconstruction / residual rows for ``self.n_examples`` examples.

        Args:
            images, recons, residuals: tensors whose first dim equals
                ``self.n_examples`` and whose second dim (band) is 1; element
                ``[i, 0]`` of each is drawn in row ``i``.

        Returns:
            The matplotlib figure. The residual color range is shared by all rows;
            image and reconstruction share a per-row range.
        """
        # Validate before allocating the figure so a failed check does not leak it.
        assert images.shape[0] == recons.shape[0] == residuals.shape[
            0] == self.n_examples
        assert images.shape[1] == recons.shape[1] == residuals.shape[
            1] == 1, "1 band only."

        pad = 6.0
        set_rc_params(fontsize=22,
                      tick_label_size="small",
                      legend_fontsize="small")
        # squeeze=False keeps `axes` 2-D even when n_examples == 1.
        fig, axes = plt.subplots(nrows=self.n_examples,
                                 ncols=3,
                                 figsize=(12, 20),
                                 squeeze=False)

        # pick standard ranges for residuals
        vmin_res = residuals.min().item()
        vmax_res = residuals.max().item()

        for i in range(self.n_examples):
            ax_true, ax_recon, ax_res = axes[i]

            # only add titles to the first axes.
            if i == 0:
                ax_true.set_title("Images $x$", pad=pad)
                ax_recon.set_title(r"Reconstruction $\tilde{x}$", pad=pad)
                ax_res.set_title(
                    r"Residual $\left(x - \tilde{x}\right) / \sqrt{\tilde{x}}$",
                    pad=pad)

            # standardize ranges of true and reconstruction
            image = images[i, 0]
            recon = recons[i, 0]
            residual = residuals[i, 0]

            vmin = min(image.min().item(), recon.min().item())
            vmax = max(image.max().item(), recon.max().item())

            # plot images
            reporting.plot_image(fig, ax_true, image, vrange=(vmin, vmax))
            reporting.plot_image(fig, ax_recon, recon, vrange=(vmin, vmax))
            reporting.plot_image(fig,
                                 ax_res,
                                 residual,
                                 vrange=(vmin_res, vmax_res))

        # tight_layout recomputes all subplot spacing, so a preceding
        # subplots_adjust(hspace=...) would have no effect.
        fig.tight_layout()

        return fig
Пример #5
0
    def create_figures(self, data):  # pylint: disable=too-many-statements
        """Make figures related to reconstruction in SDSS.

        For every scene in ``self.scenes`` a 1x3 figure (original | reconstruction
        | residual) is drawn. Coadd locations ("+", colored by ``cogbools``) and
        BLISS locations ("x", colored by ``gprobs``) are overlaid on every panel.
        For scenes whose name contains "blend", each coadd object on the original
        image is also annotated with its ``prob_n_sources`` value.

        Args:
            data: mapping from scene name to a mapping whose values are, in order:
                true image, reconstruction, residual, coadd plocs, coadd galaxy
                bools, BLISS plocs, BLISS galaxy probabilities and per-object
                detection probabilities.

        Returns:
            dict mapping scene name to its matplotlib figure.
        """
        out_figures = {}

        pad = 6.0
        vrange_image = (800, 1100)  # shared by the original and reconstruction panels
        vrange_res = (-5, 5)
        labels = ["Coadd Galaxies", "Coadd Stars", "BLISS Galaxies", "BLISS Stars"]
        set_rc_params(fontsize=22,
                      tick_label_size="small",
                      legend_fontsize="small")
        for figname, scene_coords in self.scenes.items():
            slen = scene_coords["size"]
            dvalues = data[figname].values()
            true, recon, res, coplocs, cogbools, plocs, gprobs, prob_n_sources = dvalues
            assert slen == true.shape[-1] == recon.shape[-1] == res.shape[-1]
            fig, (ax_true, ax_recon, ax_res) = plt.subplots(nrows=1,
                                                            ncols=3,
                                                            figsize=(28, 12))

            ax_true.set_title("Original Image", pad=pad)
            ax_recon.set_title("Reconstruction", pad=pad)
            ax_res.set_title("Residual", pad=pad)

            # marker sizes and line width are scaled relative to a 300-pixel scene.
            s = 55 * 300 / slen  # BLISS marker size
            sp = s * 1.5  # coadd marker size
            lw = 2 * np.sqrt(300 / slen)

            # (axis, image, display range, extra plot_locs kwargs); markers on the
            # residual panel are drawn semi-transparent.
            panels = (
                (ax_true, true, vrange_image, {}),
                (ax_recon, recon, vrange_image, {}),
                (ax_res, res, vrange_res, {"alpha": 0.5}),
            )
            for ax, image, vrange, locs_kwargs in panels:
                reporting.plot_image(fig, ax, image, vrange)
                reporting.plot_locs(ax, 0, slen, coplocs, cogbools, "+", sp, lw,
                                    cmap="cool", **locs_kwargs)
                reporting.plot_locs(ax, 0, slen, plocs, gprobs, "x", s, lw,
                                    cmap="bwr", **locs_kwargs)
                if ax is ax_recon:
                    reporting.add_legend(ax_recon, labels, s=s)
            # single row of axes: a subplots_adjust(hspace=...) would be a no-op here.
            fig.tight_layout()

            # plot probability of detection in each true object for blends
            if "blend" in figname:
                for ii, ploc in enumerate(coplocs.reshape(-1, 2)):
                    prob = prob_n_sources[ii].item()
                    x, y = ploc[1] + 0.5, ploc[0] + 0.5
                    text = r"$\boldsymbol{" + f"{prob:.2f}" + "}$"
                    ax_true.annotate(text, (x, y), color="lime")

            out_figures[figname] = fig

        return out_figures
Пример #6
0
    def _make_plots(self, batch, n_samples=5):
        """Log a figure of the worst-reconstructed validation images.

        A random subset of at most ``n_samples`` images is taken from ``batch``.
        Galaxy parameters are encoded at the true tile locations, and the images
        are re-rendered with ``self.image_decoder`` (plus background). Rows are
        drawn in order of decreasing mean absolute residual:
        truth | reconstruction | residual, where residual = (x - recon) / sqrt(recon).
        When ``self.crop_loss_at_border`` is set, residuals within
        ``2 * border_padding`` of the edges are zeroed before ranking and plotting.

        Args:
            batch: dict of batched tensors containing at least the keys listed in
                ``keys`` below. It is not modified.
            n_samples: maximum number of examples to plot.
        """
        # validate worst reconstruction images.
        n_samples = min(len(batch["n_sources"]), n_samples)
        if n_samples <= 0:
            return  # nothing to plot; plt.subplots(nrows=0) would raise.
        samples = np.random.choice(len(batch["n_sources"]),
                                   n_samples,
                                   replace=False)
        keys = [
            "images",
            "background",
            "locs",
            "galaxy_bools",
            "star_bools",
            "fluxes",
            "log_fluxes",
            "n_sources",
        ]
        # Subsample into a new dict instead of overwriting the caller's entries.
        batch = {**batch, **{k: batch[k][samples] for k in keys}}

        images = batch["images"]
        background = batch["background"]
        tile_locs = batch["locs"]

        # obtain map estimates
        image_ptiles = get_images_in_tiles(
            torch.cat((images, background), dim=1),
            self.tile_slen,
            self.ptile_slen,
        )
        _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
        image_ptiles = rearrange(image_ptiles,
                                 "n nth ntw b h w -> (n nth ntw) b h w")
        locs = rearrange(tile_locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
        z, _ = self._encode(image_ptiles, locs)
        galaxy_params = rearrange(
            z,
            "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
            ns=1,
            nth=n_tiles_h,
            ntw=n_tiles_w,
        )

        tile_est = TileCatalog(
            self.tile_slen,
            {
                "n_sources": batch["n_sources"],
                "locs": batch["locs"],
                "galaxy_bools": batch["galaxy_bools"],
                "star_bools": batch["star_bools"],
                "fluxes": batch["fluxes"],
                "log_fluxes": batch["log_fluxes"],
                "galaxy_params": galaxy_params,
            },
        )
        est = tile_est.to_full_params()

        # draw all reconstruction images.
        # render_images automatically accounts for tiles with no galaxies.
        recon_images = self.image_decoder.render_images(tile_est)
        recon_images += background
        residuals = (images - recon_images) / torch.sqrt(recon_images)

        # Crop BEFORE ranking so "worst" reflects the pixels that are displayed.
        # Guarded on a positive width: with 0, `-0:` would select (and zero) everything.
        if self.crop_loss_at_border and self.border_padding > 0:
            crop = self.border_padding * 2
            residuals[:, :, :crop, :] = 0.0
            residuals[:, :, -crop:, :] = 0.0
            residuals[:, :, :, :crop] = 0.0
            residuals[:, :, :, -crop:] = 0.0

        # draw worst `n_samples` examples as measured by absolute avg. residual error.
        worst_indices = residuals.abs().mean(dim=(1, 2, 3)).argsort(
            descending=True)[:n_samples]

        # Plotting only works on square images
        assert images.shape[-2] == images.shape[-1]
        bp = self.border_padding
        slen = images.shape[-1] - 2 * bp
        labels = ("t. gal", None, "t. star", None)

        # White guide lines at the border-padding and loss-crop edges.
        if self.crop_loss_at_border:
            crop_bp = (recon_images.shape[-1] - slen) // 2
            eff_slen = slen - crop_bp
            guide_lines = [
                b + offset for b in (crop_bp, crop_bp * 2) for offset in (0, eff_slen)
            ]
        else:
            guide_lines = []

        def _draw_guides(ax):
            """Draw the white crop/padding guide lines (if any) on ``ax``."""
            for pos in guide_lines:
                ax.axvline(pos, color="w")
                ax.axhline(pos, color="w")

        figsize = (12, 4 * n_samples)
        fig, axes = plt.subplots(nrows=n_samples,
                                 ncols=3,
                                 figsize=figsize,
                                 squeeze=False)

        for i, idx in enumerate(worst_indices):
            true_ax, recon_ax, res_ax = axes[i]

            # add titles to axes in the first row
            if i == 0:
                true_ax.set_title("Truth", size=18)
                recon_ax.set_title("Reconstruction", size=18)
                res_ax.set_title("Residual", size=18)

            # vmin, vmax should be shared between reconstruction and true images.
            if self.max_flux_valid_plots is None:
                vmax = np.ceil(
                    max(images[idx].max().item(),
                        recon_images[idx].max().item()))
            else:
                vmax = self.max_flux_valid_plots
            vmin = np.floor(
                min(images[idx].min().item(), recon_images[idx].min().item()))
            vrange = (vmin, vmax)

            # Every panel in this row refers to the same ranked sample `idx`.
            image = images[idx, 0].detach().cpu().numpy()
            recon = recon_images[idx, 0].detach().cpu().numpy()
            plocs = est.plocs[idx].cpu().numpy().reshape(-1, 2)
            probs = est["galaxy_bools"][idx].cpu().numpy().reshape(-1)

            plot_image(fig, true_ax, image, vrange=vrange, colorbar=True)
            plot_locs(true_ax, bp, slen, plocs, probs, cmap="cool")
            plot_image(fig, recon_ax, recon, vrange=vrange, colorbar=True)
            plot_locs(recon_ax, bp, slen, plocs, probs, cmap="cool")
            _draw_guides(recon_ax)

            residuals_idx = residuals[idx, 0].detach().cpu().numpy()
            res_vmax = np.ceil(residuals_idx.max())
            res_vmin = np.floor(residuals_idx.min())
            plot_image(fig, res_ax, residuals_idx, vrange=(res_vmin, res_vmax))
            _draw_guides(res_ax)

            if i == 0:
                add_legend(true_ax, labels)

        fig.tight_layout()
        if self.logger:
            self.logger.experiment.add_figure(
                f"Epoch:{self.current_epoch}/Worst Validation Images", fig)
        plt.close(fig)
Пример #7
0
def create_figure(
    true,
    recon,
    res,
    coadd_objects: Optional[FullCatalog] = None,
    map_recon: Optional[FullCatalog] = None,
    include_residuals: bool = True,
    colorbar=True,
    scatter_size: int = 100,
    scatter_on_true: bool = True,
    tile_map=None,
    vmin=800,
    vmax=1200,
    tile_slen: int = 4,
):
    """Make figures related to detection and classification in SDSS.

    Draws the original image, the reconstruction (or, when ``tile_map`` is given,
    its per-tile ``is_on_array`` as a gridded mask) and optionally the residual
    side by side. True (``coadd_objects``, "+") and predicted (``map_recon``, "x")
    locations are overlaid on top.

    Args:
        true, recon, res: 2-D images (original, reconstruction, residual).
        coadd_objects: reference catalog; its "galaxy_bools" and optional
            "mismatched" entries choose the marker colors.
        map_recon: predicted catalog; only its first batch element is drawn, and
            predictions outside the open interval (0, scene size) are dropped.
        include_residuals: add a residual panel (fixed range -6 .. 6).
        colorbar: forwarded to ``reporting.plot_image`` for the image panels.
        scatter_size: marker size for all scatter overlays.
        scatter_on_true: also overlay the catalogs on the original image.
        tile_map: optional catalog whose ``is_on_array`` replaces the
            reconstruction panel.
        vmin, vmax: display range of the original and reconstruction panels.
        tile_slen: tile side length in pixels used to expand ``tile_map``.

    Returns:
        The matplotlib figure.
    """
    # Validate before touching global style or allocating a figure.
    assert len(true.shape) == len(recon.shape) == len(res.shape) == 2

    # matplotlib 3.6 renamed the seaborn styles to "seaborn-v0_8-*" and 3.8
    # removed the old names; use the new one when this matplotlib provides it.
    style = "seaborn-v0_8-colorblind"
    if style not in plt.style.available:
        style = "seaborn-colorblind"
    plt.style.use(style)

    true_gal_col = "m"
    true_star_col = "b"
    pred_gal_col = "c"
    pred_star_col = "r"
    pad = 6.0
    set_rc_params(fontsize=22,
                  tick_label_size="small",
                  legend_fontsize="small")
    ncols = 2 + include_residuals
    figsize = (20 + 10 * include_residuals, 12)
    fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=figsize)

    # predicted locations must lie strictly inside (0, scene_size) to be drawn.
    scene_size = max(true.shape[-2], true.shape[-1])

    ax_true = axes[0]
    ax_recon = axes[1]
    ax_res = axes[2] if include_residuals else None

    ax_true.set_title("Original Image", pad=pad)
    ax_recon.set_title("Reconstruction", pad=pad)

    def _scatter(ax, locs, **kwargs):
        """Scatter (row, col) locations on ``ax`` as x=col, y=row."""
        ax.scatter(locs[:, 1], locs[:, 0], s=scatter_size, **kwargs)

    # plot images
    reporting.plot_image(fig,
                         ax_true,
                         true,
                         vrange=(vmin, vmax),
                         colorbar=colorbar,
                         cmap="gist_gray")
    if tile_map is None:
        reporting.plot_image(fig,
                             ax_recon,
                             recon,
                             vrange=(vmin, vmax),
                             colorbar=colorbar,
                             cmap="gist_gray")
    else:
        # expand the per-tile on/off mask to pixel resolution.
        is_on_array = rearrange(tile_map.is_on_array,
                                "1 nth ntw 1 -> nth ntw 1 1")
        is_on_array = repeat(is_on_array,
                             "nth ntw 1 1 -> nth ntw h w",
                             h=tile_slen,
                             w=tile_slen)
        is_on_array = rearrange(is_on_array, "nth ntw h w -> (nth h) (ntw w)")
        ax_recon.matshow(is_on_array, vmin=0, vmax=1, cmap="gist_gray")
        # tile boundaries lie between pixels k * tile_slen - 1 and k * tile_slen.
        for grid in range(tile_slen - 1, is_on_array.shape[-1], tile_slen):
            ax_recon.axvline(grid + 0.5, color="purple")
            ax_recon.axhline(grid + 0.5, color="purple")

    if include_residuals:
        ax_res.set_title("Residual", pad=pad)
        vmin_res, vmax_res = -6.0, 6.0
        reporting.plot_image(fig, ax_res, res, vrange=(vmin_res, vmax_res))

    if coadd_objects is not None:
        locs_true = coadd_objects.plocs - 0.5
        true_galaxy_bools = coadd_objects["galaxy_bools"].squeeze(-1)
        locs_galaxies_true = locs_true[true_galaxy_bools > 0.5]
        locs_stars_true = locs_true[true_galaxy_bools < 0.5]

        # (locations, color, label on original, label on reconstruction)
        true_groups = []
        if locs_galaxies_true.shape[0] > 0:
            true_groups.append(
                (locs_galaxies_true, true_gal_col, "COADD Galaxies", "SDSS Galaxies"))
        if locs_stars_true.shape[0] > 0:
            true_groups.append(
                (locs_stars_true, true_star_col, "COADD Stars", "SDSS Stars"))
        mismatched = coadd_objects.get("mismatched")
        if mismatched is not None:
            # drawn even when empty, matching the previous behavior.
            true_groups.append((locs_true[mismatched.squeeze(-1) > 0.5], "orange",
                                "Unmatched", "Unmatched"))

        for locs, color, true_label, recon_label in true_groups:
            if scatter_on_true:
                _scatter(ax_true, locs, color=color, marker="+", label=true_label)
            _scatter(ax_recon, locs, color=color, marker="+", label=recon_label)

    if map_recon is not None:
        locs_pred = map_recon.plocs[0] - 0.5
        star_bools = map_recon["star_bools"][0, :, 0] > 0.5
        galaxy_bools = map_recon["galaxy_bools"][0, :, 0] > 0.5
        # (locations, color, label, extra kwargs for the original-image overlay)
        pred_groups = (
            (locs_pred[galaxy_bools, :], pred_gal_col, "Predicted Galaxy", {}),
            (locs_pred[star_bools, :], pred_star_col, "Predicted Star",
             {"alpha": 0.6}),
            (locs_pred[(~galaxy_bools) & (~star_bools), :], "w",
             "Predicted Object (below 0.5)", {"alpha": 0.6}),
        )
        for locs, color, label, true_kwargs in pred_groups:
            if locs.shape[0] == 0:
                continue
            in_bounds = torch.all((locs > 0) & (locs < scene_size), dim=-1)
            locs = locs[in_bounds]
            if scatter_on_true:
                _scatter(ax_true, locs, color=color, marker="x", label=label,
                         **true_kwargs)
            _scatter(ax_recon, locs, color=color, marker="x", label=label,
                     alpha=0.6)
            if include_residuals:
                _scatter(ax_res, locs, color=color, marker="x")

    ax_recon.legend(
        bbox_to_anchor=(0.0, -0.1, 1.0, 0.5),
        loc="lower left",
        ncol=4,
        mode="expand",
        borderaxespad=0.0,
    )
    # single row of axes: a subplots_adjust(hspace=...) would have no effect.
    fig.tight_layout()

    return fig