예제 #1
0
파일: inference.py 프로젝트: prob-ml/bliss
 def get_catalog(self, hlims, wlims) -> FullCatalog:
     """Return the full catalog for the tiles covering a height/width window.

     Args:
         hlims: Pair ``(start, end)`` along the height axis. ``self.bp`` is
             subtracted from both values before they are converted to tile
             indices.
         wlims: Pair ``(start, end)`` along the width axis, handled the same
             way as ``hlims``.

     Returns:
         ``to_full_params()`` of a ``TileCatalog`` built from the selected
         slice of ``self.tile_catalog``.
     """
     tile_slen = self.tile_slen

     def tile_bounds(lims):
         # Floor the start and ceil the end so that tiles only partially
         # covered by the window are still included.
         first = int(np.floor((lims[0] - self.bp) / tile_slen))
         last = int(np.ceil((lims[1] - self.bp) / tile_slen))
         return first, last

     h_first, h_last = tile_bounds(hlims)
     w_first, w_last = tile_bounds(wlims)
     cropped = {
         key: value[:, h_first:h_last, w_first:w_last]
         for key, value in self.tile_catalog.to_dict().items()
     }
     return TileCatalog(tile_slen, cropped).to_full_params()
예제 #2
0
def make_plots_of_marginal_class(
    outdir: Path,
    encoder: Encoder,
    decoder: ImageDecoder,
    frame: Frame,
    tile_map_recon: TileCatalog,
    detections_at_mode,
):
    """Save a figure and a CSV row for each matched, marginally classified source.

    A source is "marginal" when ``exp(galaxy_probs)`` lies in ``[0.4, 0.6]``.
    For every marginal source that is also flagged as matched, a figure is
    written to ``outdir`` and a line is added to
    ``outdir / "marginal_detections.csv"``.

    Note:
        ``tile_map_recon`` is modified in place: its ``"matched"`` entry is
        set from ``detections_at_mode["est_tile_matches"]``.

    Args:
        outdir: Directory receiving the PNG figures and the CSV file.
        encoder: Only ``encoder.border_padding`` is read here.
        decoder: Passed through to ``create_figure_at_point``.
        frame: Passed through to ``create_figure_at_point``; when it is an
            ``SDSSFrame`` its WCS is used to compute ra/dec.
        tile_map_recon: Tile catalog of the reconstruction.
        detections_at_mode: Mapping providing ``"est_tile_matches"``.
    """
    tile_map_recon["matched"] = detections_at_mode["est_tile_matches"]
    bp = encoder.border_padding
    full_map = tile_map_recon.to_full_params()

    # Exponentiate the stored values once instead of once per comparison.
    galaxy_probs = full_map["galaxy_probs"][0, :, 0].exp()
    marginal = (galaxy_probs <= 0.6) & (galaxy_probs >= 0.4)

    size = 40
    # NOTE(review): this fixed offset looks like it is meant to equal the
    # border padding (`bp`); kept as the literal 24 to preserve behavior.
    offset = 24

    csv_lines = ["fname,in_coadd,ra,dec,prob"]
    for i, ploc in tqdm(enumerate(full_map.plocs[0]),
                        desc="marginal classifications"):
        if not marginal[i]:
            continue
        # Convert to a plain Python bool so the CSV column contains
        # "True"/"False" rather than the tensor repr ("tensor(True)").
        matched = bool(full_map["matched"][0, i, 0].item())
        if not matched:
            continue

        h = ploc[0].item()
        w = ploc[1].item()
        h_topleft = max(int(h - (size / 2.0)), 0) + offset
        w_topleft = max(int(w - (size / 2.0)), 0) + offset
        fig = create_figure_at_point(
            h_topleft,
            w_topleft,
            size,
            bp,
            tile_map_recon,
            frame,
            decoder,
        )
        fname = f"h{h_topleft}_w{w_topleft}.png"
        fig.savefig(outdir / fname)

        if isinstance(frame, SDSSFrame):
            ra, dec = frame.wcs.wcs_pix2world(w + offset, h + offset, 0)
        else:
            ra, dec = None, None
        prob = galaxy_probs[i].item()
        csv_lines.append(f"{fname},{matched},{ra},{dec},{prob}")

    out_csv = outdir / "marginal_detections.csv"
    out_csv.write_text("\n".join(csv_lines))
예제 #3
0
def stats_for_threshold(
    true_plocs: Tensor,
    est_tile_cat: TileCatalog,
    threshold: Optional[float] = None,
):
    """Count true and false positives of an estimated catalog against truth.

    Note:
        When ``threshold`` is given, ``est_tile_cat.n_sources`` is overwritten
        in place with ``n_source_log_probs >= log(threshold)``.

    Args:
        true_plocs: True locations; ``true_plocs[0]`` is matched against the
            estimated locations and ``true_plocs.shape[1]`` is taken as the
            number of true sources.
        est_tile_cat: Estimated tile catalog.
        threshold: Optional probability threshold applied to the per-tile
            source log-probabilities before matching.

    Returns:
        Dict with keys ``"tp"``, ``"fp"``, ``"true_matches"`` (bool flag per
        true source) and ``"est_tile_matches"`` (bool flags in tile layout).
    """
    tile_slen, max_sources = est_tile_cat.tile_slen, est_tile_cat.max_sources

    if threshold is not None:
        log_probs = rearrange(est_tile_cat["n_source_log_probs"],
                              "n nth ntw 1 1 -> n nth ntw")
        est_tile_cat.n_sources = log_probs >= math.log(threshold)

    est_cat = est_tile_cat.to_full_params()
    number_true = true_plocs.shape[1]
    number_est = int(est_cat.plocs.shape[1])
    true_matches = torch.zeros(number_true, dtype=torch.bool)

    # Nothing to match: every estimate (if any) counts as a false positive.
    if number_true == 0 or number_est == 0:
        no_tile_matches = torch.zeros(*est_tile_cat.n_sources.shape,
                                      1,
                                      1,
                                      dtype=torch.bool)
        return {
            "tp": torch.tensor(0.0),
            "fp": torch.tensor(float(number_est)),
            "true_matches": true_matches,
            "est_tile_matches": no_tile_matches,
        }

    row_indx, col_indx, d, _ = reporting.match_by_locs(true_plocs[0],
                                                       est_cat.plocs[0], 1.0)
    true_matches[row_indx] = d

    est_matches = torch.zeros(number_est, dtype=torch.bool)
    est_matches[col_indx] = d
    est_cat["matched"] = est_matches.reshape(1, -1, 1)
    # Round-trip through the tile layout to obtain per-tile match flags.
    tiled = est_cat.to_tile_params(tile_slen, max_sources)

    tp = d.sum()
    return {
        "tp": tp,
        "fp": torch.tensor(number_est) - tp,
        "true_matches": true_matches,
        "est_tile_matches": tiled["matched"],
    }
예제 #4
0
def test_galaxy_blend(get_sdss_galaxies_config, devices):
    """Check shapes and zero-padding of catalogs from the blended-galaxy dataset."""
    batch_size = 4
    tile_slen = 4
    overrides = {
        "datasets.galsim_blended_galaxies.num_workers": 0,
        "datasets.galsim_blended_galaxies.batch_size": batch_size,
        "datasets.galsim_blended_galaxies.n_batches": 1,
        "datasets.galsim_blended_galaxies.prior.max_n_sources": 3,
    }
    cfg = get_sdss_galaxies_config(overrides, devices)
    blend_ds: GalsimBlends = instantiate(cfg.datasets.galsim_blended_galaxies)

    for batch in blend_ds.train_dataloader():
        images = batch.pop("images")
        batch.pop("background")  # not part of the catalog
        full_cat = TileCatalog(tile_slen, batch).to_full_params()

        max_n_sources = full_cat.max_sources
        n_sources = full_cat.n_sources
        plocs = full_cat.plocs
        params = full_cat["galaxy_params"]
        snr = full_cat["snr"]
        blendedness = full_cat["blendedness"]
        ellips = full_cat["ellips"]
        mags = full_cat["mags"]
        fluxes = full_cat["fluxes"]

        per_source = (batch_size, max_n_sources)
        assert images.shape == (batch_size, 1, 88, 88)  # 40 + 24 * 2
        assert params.shape == (*per_source, 7)
        assert plocs.shape == (*per_source, 2)
        assert snr.shape == (*per_source, 1)
        assert blendedness.shape == (*per_source, 1)
        assert n_sources.shape == (batch_size, )
        assert ellips.shape == (*per_source, 2)
        assert mags.shape == (*per_source, 1)
        assert fluxes.shape == (*per_source, 1)

        # Slots beyond the number of sources of each image must be zero.
        for snr_i, blend_i, plocs_i, n in zip(snr, blendedness, plocs,
                                              n_sources):
            assert torch.all(snr_i[n:] == 0)
            assert torch.all(blend_i[n:] == 0)
            assert torch.all(plocs_i[n:] == 0)
예제 #5
0
    def compute_data(self, blend_file: Path, encoder: Encoder,
                     decoder: ImageDecoder):
        """Collect truth/estimate pairs for sources matched in simulated blends.

        Loads the blend file, runs ``encoder.variational_mode`` on its images,
        matches true and estimated sources by location for each image, and
        gathers the quantities of the matched pairs.

        Args:
            blend_file: File loaded with ``torch.load``; must contain
                ``"images"`` and ``"background"`` plus the tile catalog
                entries.
            encoder: Encoder used to obtain the estimated tile catalog.
            decoder: Decoder supplying ``tile_slen`` and used to set fluxes,
                magnitudes and ellipticities on the estimate.

        Returns:
            Dict of tensors with keys ``"snr"``, ``"blendedness"``,
            ``"true_mags"``, ``"est_mags"``, ``"true_ellips"`` and
            ``"est_ellips"``; the two ellipticity entries have two columns.
        """
        blend_data = torch.load(blend_file)
        images = blend_data.pop("images")
        background = blend_data.pop("background")
        n_batches, _, slen, _ = images.shape
        assert background.shape == (1, slen, slen)

        # One shared background, expanded to one copy per image.
        background = background.unsqueeze(0).expand(n_batches, 1, slen, slen)

        # Ground truth: what remains in the file is the tile catalog.
        full_truth = TileCatalog(decoder.tile_slen,
                                 blend_data).cpu().to_full_params()

        print("INFO: BLISS posterior inference on images.")
        tile_est = encoder.variational_mode(images, background)
        tile_est.set_all_fluxes_and_mags(decoder)
        tile_est.set_galaxy_ellips(decoder, scale=0.393)
        full_est = tile_est.cpu().to_full_params()

        snr, blendedness = [], []
        true_mags, est_mags = [], []
        true_ellips1, true_ellips2 = [], []
        est_ellips1, est_ellips2 = [], []

        for ii in tqdm(range(n_batches), desc="Matching batches"):
            tindx, eindx, dkeep, _ = match_by_locs(full_truth.plocs[ii],
                                                   full_est.plocs[ii])

            # Keep only the accepted matches, on the truth and estimate side.
            snr_ii = full_truth["snr"][ii][tindx][dkeep]
            blendedness_ii = full_truth["blendedness"][ii][tindx][dkeep]
            true_mag_ii = full_truth["mags"][ii][tindx][dkeep]
            est_mag_ii = full_est["mags"][ii][eindx][dkeep]
            true_ellips_ii = full_truth["ellips"][ii][tindx][dkeep]
            est_ellips_ii = full_est["ellips"][ii][eindx][dkeep]

            snr.extend(row.item() for row in snr_ii)
            blendedness.extend(row.item() for row in blendedness_ii)
            true_mags.extend(row.item() for row in true_mag_ii)
            est_mags.extend(row.item() for row in est_mag_ii)
            true_ellips1.extend(row[0].item() for row in true_ellips_ii)
            true_ellips2.extend(row[1].item() for row in true_ellips_ii)
            est_ellips1.extend(row[0].item() for row in est_ellips_ii)
            est_ellips2.extend(row[1].item() for row in est_ellips_ii)

        def as_two_columns(first, second):
            # Stack the two components as rows, then transpose to (n, 2).
            stacked = torch.vstack([torch.tensor(first), torch.tensor(second)])
            return stacked.T.reshape(-1, 2)

        return {
            "snr": torch.tensor(snr),
            "blendedness": torch.tensor(blendedness),
            "true_mags": torch.tensor(true_mags),
            "est_mags": torch.tensor(est_mags),
            "true_ellips": as_two_columns(true_ellips1, true_ellips2),
            "est_ellips": as_two_columns(est_ellips1, est_ellips2),
        }
예제 #6
0
파일: inference.py 프로젝트: prob-ml/bliss
    def __init__(
        self,
        dataset: SimulatedDataset,
        coadd: str,
        n_tiles_h,
        n_tiles_w,
        cache_dir=None,
    ):
        """Create, or load from cache, a frame simulated from a coadd catalog.

        Args:
            dataset: Simulated dataset providing the image decoder, the image
                prior and ``simulate_image_from_catalog``. It is moved to the
                CPU before anything else happens.
            coadd: Path of the coadd catalog file; its parent directory is
                used as the SDSS data directory.
            n_tiles_h: Number of tiles along the height of the frame.
            n_tiles_w: Number of tiles along the width of the frame.
            cache_dir: Optional directory. When given, the frame is loaded
                from ``simulated_frame.pt`` inside it if that file exists,
                and saved there after simulation otherwise.
        """
        dataset.to("cpu")
        self.bp = dataset.image_decoder.border_padding
        self.tile_slen = dataset.tile_slen
        self.coadd_file = Path(coadd)
        self.sdss_dir = self.coadd_file.parent

        # The WCS always comes from one fixed SDSS field.
        sdss_data = SloanDigitalSkySurvey(
            sdss_dir=self.sdss_dir,
            run=94,
            camcol=1,
            fields=(12, ),
            bands=(2, ),
        )
        wcs = sdss_data[0]["wcs"][0]

        sim_frame_path = (None if cache_dir is None else
                          Path(cache_dir) / "simulated_frame.pt")

        if sim_frame_path is not None and sim_frame_path.exists():
            tile_catalog_dict, image, background = torch.load(sim_frame_path)
            tile_catalog = TileCatalog(self.tile_slen, tile_catalog_dict)
        else:
            hlim = (self.bp, self.bp + n_tiles_h * self.tile_slen)
            wlim = (self.bp, self.bp + n_tiles_w * self.tile_slen)
            full_coadd_cat = CoaddFullCatalog.from_file(coadd,
                                                        wcs,
                                                        hlim,
                                                        wlim,
                                                        band="r")
            galaxy_prior = dataset.image_prior.galaxy_prior
            if galaxy_prior is not None:
                sampled = galaxy_prior.sample(full_coadd_cat.n_sources, "cpu")
                full_coadd_cat["galaxy_params"] = sampled.unsqueeze(0)
            full_coadd_cat.plocs = full_coadd_cat.plocs + 0.5

            tile_catalog = full_coadd_cat.to_tile_params(
                self.tile_slen, dataset.image_prior.max_sources)
            tile_catalog.set_all_fluxes_and_mags(dataset.image_decoder)

            # Sanity check: the tile round trip must reproduce the catalog,
            # apart from the fields that were just added or recomputed.
            round_trip = tile_catalog.to_full_params()
            assert round_trip.equals(full_coadd_cat,
                                     exclude=("galaxy_fluxes", "star_bools",
                                              "fluxes", "mags"))

            print("INFO: started generating frame")
            image, background = dataset.simulate_image_from_catalog(
                tile_catalog)
            print("INFO: done generating frame")

            if sim_frame_path is not None:
                torch.save((tile_catalog.to_dict(), image, background),
                           sim_frame_path)

        self.tile_catalog = tile_catalog
        self.image = image
        self.background = background
        assert self.image.shape[0] == 1
        assert self.background.shape[0] == 1
예제 #7
0
    def _make_plots(self, batch, n_samples=5):
        """Log truth / reconstruction / residual panels for a validation batch.

        Up to ``n_samples`` entries are drawn at random from ``batch``,
        reconstructed with ``self.image_decoder`` and plotted one per row,
        ordered by decreasing mean absolute residual.

        Note:
            ``batch`` is modified in place: the entries listed in ``keys``
            are replaced by the randomly selected subset.

        Args:
            batch: Mapping containing at least ``"images"``,
                ``"background"``, ``"locs"``, ``"galaxy_bools"``,
                ``"star_bools"``, ``"fluxes"``, ``"log_fluxes"`` and
                ``"n_sources"``.
            n_samples: Maximum number of examples to plot.
        """
        n_samples = min(len(batch["n_sources"]), n_samples)
        samples = np.random.choice(len(batch["n_sources"]),
                                   n_samples,
                                   replace=False)
        keys = [
            "images",
            "background",
            "locs",
            "galaxy_bools",
            "star_bools",
            "fluxes",
            "log_fluxes",
            "n_sources",
        ]
        for k in keys:
            batch[k] = batch[k][samples]

        # Pull out the entries that are not catalog parameters.
        images = batch["images"]
        background = batch["background"]
        tile_locs = batch["locs"]

        # Encode the padded tiles to obtain the galaxy parameters.
        image_ptiles = get_images_in_tiles(
            torch.cat((images, background), dim=1),
            self.tile_slen,
            self.ptile_slen,
        )
        _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
        image_ptiles = rearrange(image_ptiles,
                                 "n nth ntw b h w -> (n nth ntw) b h w")
        locs = rearrange(tile_locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
        z, _ = self._encode(image_ptiles, locs)
        galaxy_params = rearrange(
            z,
            "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
            ns=1,
            nth=n_tiles_h,
            ntw=n_tiles_w,
        )

        tile_est = TileCatalog(
            self.tile_slen,
            {
                "n_sources": batch["n_sources"],
                "locs": batch["locs"],
                "galaxy_bools": batch["galaxy_bools"],
                "star_bools": batch["star_bools"],
                "fluxes": batch["fluxes"],
                "log_fluxes": batch["log_fluxes"],
                "galaxy_params": galaxy_params,
            },
        )
        est = tile_est.to_full_params()

        # Render the reconstructions and form the scaled residuals.
        recon_images = self.image_decoder.render_images(tile_est)
        recon_images += background
        residuals = (images - recon_images) / torch.sqrt(recon_images)

        # Rank examples by mean absolute residual, worst first. The ranking
        # is computed before the border is blanked, as in the original code.
        worst_indices = residuals.abs().mean(dim=(1, 2, 3)).argsort(
            descending=True)[:n_samples]

        if self.crop_loss_at_border:
            crop = self.border_padding * 2
            # Guard against crop == 0: a `-0:` slice would select (and zero)
            # the entire residual image instead of an empty border.
            if crop > 0:
                residuals[:, :, :crop, :] = 0.0
                residuals[:, :, -crop:, :] = 0.0
                residuals[:, :, :, :crop] = 0.0
                residuals[:, :, :, -crop:] = 0.0

        # Plotting only works on square images.
        assert images.shape[-2] == images.shape[-1]
        slen = images.shape[-1] - 2 * self.border_padding
        labels = ("t. gal", None, "t. star", None)

        figsize = (12, 4 * n_samples)
        fig, axes = plt.subplots(nrows=n_samples,
                                 ncols=3,
                                 figsize=figsize,
                                 squeeze=False)

        for i, idx in enumerate(worst_indices):
            # `i` is the figure row; `idx` is the example shown in that row.
            true_ax = axes[i, 0]
            recon_ax = axes[i, 1]
            res_ax = axes[i, 2]

            # Titles only on the first row.
            if i == 0:
                true_ax.set_title("Truth", size=18)
                recon_ax.set_title("Reconstruction", size=18)
                res_ax.set_title("Residual", size=18)

            # The color range is shared by the true and reconstructed images.
            if self.max_flux_valid_plots is None:
                vmax = np.ceil(
                    max(images[idx].max().item(),
                        recon_images[idx].max().item()))
            else:
                vmax = self.max_flux_valid_plots
            vmin = np.floor(
                min(images[idx].min().item(), recon_images[idx].min().item()))
            vrange = (vmin, vmax)

            bp = self.border_padding
            # Index every per-example quantity with `idx` so that all three
            # panels of a row describe the same example.
            true_image = images[idx, 0].cpu().numpy()
            recon_image = recon_images[idx, 0].detach().cpu().numpy()
            plocs = est.plocs[idx].cpu().numpy().reshape(-1, 2)
            probs = est["galaxy_bools"][idx].cpu().numpy().reshape(-1)

            plot_image(fig, true_ax, true_image, vrange=vrange, colorbar=True)
            plot_locs(true_ax, bp, slen, plocs, probs, cmap="cool")
            # The reconstruction panel shows the rendered image.
            plot_image(fig, recon_ax, recon_image, vrange=vrange,
                       colorbar=True)
            plot_locs(recon_ax, bp, slen, plocs, probs, cmap="cool")

            residuals_idx = residuals[idx, 0].cpu().numpy()
            res_vmax = np.ceil(residuals_idx.max())
            res_vmin = np.floor(residuals_idx.min())
            if self.crop_loss_at_border:
                bp = (recon_images.shape[-1] - slen) // 2
                eff_slen = slen - bp
                for b in (bp, bp * 2):
                    recon_ax.axvline(b, color="w")
                    recon_ax.axvline(b + eff_slen, color="w")
                    recon_ax.axhline(b, color="w")
                    recon_ax.axhline(b + eff_slen, color="w")
            plot_image(fig, res_ax, residuals_idx, vrange=(res_vmin, res_vmax))
            if self.crop_loss_at_border:
                for b in (bp, bp * 2):
                    res_ax.axvline(b, color="w")
                    res_ax.axvline(b + eff_slen, color="w")
                    res_ax.axhline(b, color="w")
                    res_ax.axhline(b + eff_slen, color="w")
            if i == 0:
                add_legend(true_ax, labels)

        fig.tight_layout()
        if self.logger:
            self.logger.experiment.add_figure(
                f"Epoch:{self.current_epoch}/Worst Validation Images", fig)
        plt.close(fig)
예제 #8
0
 def forward_tile(self, tile_cat: TileCatalog):
     """Convert ``tile_cat`` to full parameters and call this object on the result."""
     return self(tile_cat.to_full_params())