예제 #1
0
    def validation_step(self, batch, batch_idx):
        """Pytorch lightning method.

        Logs the losses computed by `_get_loss`, then builds the true catalog
        from the batch and the estimated catalog from the encoder's
        variational mode, and logs the detection metrics comparing the two.

        Args:
            batch: Mapping of tensors; this method reads "images",
                "background", "locs", "log_fluxes", "galaxy_bools" and
                "n_sources".
            batch_idx: Index of the batch (not used in this method).

        Returns:
            The input batch.
        """
        n_images = len(batch["images"])
        losses = self._get_loss(batch)

        # log all losses
        self.log("val/loss", losses["loss"], batch_size=n_images)
        for loss_name in ("counter_loss", "locs_loss", "star_params_loss"):
            self.log(f"val/{loss_name}",
                     losses[loss_name].mean(),
                     batch_size=n_images)

        # True catalog: keep only the first `max_detections` entries along the
        # fourth axis and cap the source counts at the same value.
        limit = self.max_detections
        truth_dict = {
            key: batch[key][:, :, :, 0:limit]
            for key in ("locs", "log_fluxes", "galaxy_bools")
        }
        truth_dict["n_sources"] = batch["n_sources"].clamp(max=limit)
        true_tiles = TileCatalog(self.tile_slen, truth_dict)
        true_full = true_tiles.to_full_params()

        # Estimated catalog: encode the padded tiles of image + background.
        stacked = torch.cat((batch["images"], batch["background"]), dim=1)
        ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
        ptiles = rearrange(ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        est_dict = self.variational_mode(self.encode(ptiles))
        est_tiles = TileCatalog.from_flat_dict(
            true_tiles.tile_slen,
            true_tiles.n_tiles_h,
            true_tiles.n_tiles_w,
            est_dict,
        )
        est_full = est_tiles.to_full_params()

        metrics = self.val_detection_metrics(true_full, est_full)
        for metric_name in ("precision", "recall", "f1", "avg_distance"):
            self.log(f"val/{metric_name}",
                     metrics[metric_name],
                     batch_size=n_images)
        return batch
예제 #2
0
파일: inference.py 프로젝트: prob-ml/bliss
 def get_catalog(self, hlims, wlims) -> FullCatalog:
     """Return the full catalog of the tiles covering a pixel window.

     The limits are shifted by `self.bp` and converted to tile indices
     (floor for the lower limit, ceil for the upper one); every tensor of
     `self.tile_catalog` is then cropped to that tile range.

     Args:
         hlims: Pair (start, end) of limits along the height axis.
         wlims: Pair (start, end) of limits along the width axis.

     Returns:
         The cropped tile catalog converted with `to_full_params()`.
     """
     tile_slen, offset = self.tile_slen, self.bp
     row_lo = int(np.floor((hlims[0] - offset) / tile_slen))
     row_hi = int(np.ceil((hlims[1] - offset) / tile_slen))
     col_lo = int(np.floor((wlims[0] - offset) / tile_slen))
     col_hi = int(np.ceil((wlims[1] - offset) / tile_slen))
     rows, cols = slice(row_lo, row_hi), slice(col_lo, col_hi)
     cropped = {
         name: value[:, rows, cols]
         for name, value in self.tile_catalog.to_dict().items()
     }
     return TileCatalog(tile_slen, cropped).to_full_params()
예제 #3
0
파일: binary.py 프로젝트: prob-ml/bliss
    def make_plots(self, batch, n_samples=16):
        """Produce informative plots demonstrating encoder performance.

        Draws the first `n_samples` images of the batch on a square grid,
        overlaying the true locations (coloured by the true galaxy flag) and
        the same locations coloured by the predicted galaxy probability. The
        figure is sent to the logger when one is configured and is always
        closed afterwards so figures do not accumulate across epochs.

        Args:
            batch: Mapping of tensors; reads "images", "n_sources", "locs",
                "galaxy_bools", "star_bools", "pred_galaxy_bools",
                "pred_star_bools" and "pred_galaxy_probs".
            n_samples: Number of images to plot; must be a perfect square.

        Returns:
            None. Nothing is plotted when the batch holds fewer than
            `n_samples` entries.
        """

        assert n_samples ** (0.5) % 1 == 0
        if n_samples > len(batch["n_sources"]):  # do nothing if low on samples.
            return
        nrows = int(n_samples**0.5)  # for figure

        # extract non-params entries so that 'to_full_params' works.
        true_param_names = {"locs", "n_sources", "galaxy_bools", "star_bools"}
        true_tile_params = TileCatalog(
            self.tile_slen, {k: v for k, v in batch.items() if k in true_param_names}
        )
        true_params = true_tile_params.to_full_params()
        # prediction: start from the true catalog and swap in predicted classes.
        tile_est = true_tile_params.copy()
        shape = tile_est["galaxy_bools"].shape
        tile_est["galaxy_bools"] = batch["pred_galaxy_bools"].reshape(*shape)
        tile_est["star_bools"] = batch["pred_star_bools"].reshape(*shape)
        tile_est["galaxy_probs"] = batch["pred_galaxy_probs"].reshape(*shape)
        est = tile_est.to_full_params()
        # setup figure and axes
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12))
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
        axes = axes.flatten() if nrows > 1 else [axes]

        legend_labels = ("t. gal", "p. gal", "t. star", "p. star")
        bp = self.border_padding
        for ii in range(n_samples):
            ax = axes[ii]
            image = batch["images"][ii, 0].cpu().numpy()
            true_plocs = true_params.plocs[ii].cpu().numpy().reshape(-1, 2)
            true_gbools = true_params["galaxy_bools"][ii].cpu().numpy().reshape(-1)
            est_plocs = est.plocs[ii].cpu().numpy().reshape(-1, 2)
            est_gprobs = est["galaxy_probs"][ii].cpu().numpy().reshape(-1)
            slen, _ = image.shape
            reporting.plot_image(fig, ax, image, colorbar=False)
            reporting.plot_locs(ax, bp, slen, true_plocs, true_gbools, m="+", s=30, cmap="cool")
            reporting.plot_locs(ax, bp, slen, est_plocs, est_gprobs, m="x", s=20, cmap="bwr")
            if ii == 0:  # only the first panel carries the legend
                reporting.add_legend(ax, legend_labels)
        fig.tight_layout()

        title = f"Epoch:{self.current_epoch}/Validation Images"
        if self.logger is not None:
            self.logger.experiment.add_figure(title, fig)
        # Release the figure; closing an already-closed figure is a no-op.
        plt.close(fig)
예제 #4
0
def stats_for_threshold(
    true_plocs: Tensor,
    est_tile_cat: TileCatalog,
    threshold: Optional[float] = None,
):
    """Count true and false positives of an estimated catalog.

    When `threshold` is given, `est_tile_cat.n_sources` is overwritten with
    the mask of tiles whose "n_source_log_probs" is at least
    `log(threshold)`; note that this modifies the catalog passed in. The
    resulting catalog is matched against `true_plocs` by location.

    Args:
        true_plocs: True locations; only the first batch entry is matched.
        est_tile_cat: Estimated tile catalog.
        threshold: Optional probability threshold for a tile to be "on".

    Returns:
        Dict with "tp", "fp", "true_matches" (one flag per true source) and
        "est_tile_matches" (match flags in tile layout).
    """
    # Read these before `n_sources` is (possibly) overwritten below.
    slen, capacity = est_tile_cat.tile_slen, est_tile_cat.max_sources

    if threshold is not None:
        tile_log_probs = rearrange(est_tile_cat["n_source_log_probs"],
                                   "n nth ntw 1 1 -> n nth ntw")
        est_tile_cat.n_sources = tile_log_probs >= math.log(threshold)

    full_est = est_tile_cat.to_full_params()
    n_true = true_plocs.shape[1]
    n_est = int(full_est.plocs.shape[1])
    matched_true = torch.zeros(n_true, dtype=torch.bool)

    # Nothing to match: every estimated source is a false positive.
    if n_true == 0 or n_est == 0:
        no_tile_matches = torch.zeros(*est_tile_cat.n_sources.shape,
                                      1,
                                      1,
                                      dtype=torch.bool)
        return {
            "tp": torch.tensor(0.0),
            "fp": torch.tensor(float(n_est)),
            "true_matches": matched_true,
            "est_tile_matches": no_tile_matches,
        }

    matched_est = torch.zeros(n_est, dtype=torch.bool)
    true_idx, est_idx, keep, _ = reporting.match_by_locs(
        true_plocs[0], full_est.plocs[0], 1.0)
    matched_true[true_idx] = keep
    matched_est[est_idx] = keep

    # Carry the per-source match flags back to tile layout.
    full_est["matched"] = matched_est.reshape(1, -1, 1)
    tile_matches = full_est.to_tile_params(slen, capacity)["matched"]

    n_hits = keep.sum()
    return {
        "tp": n_hits,
        "fp": torch.tensor(n_est) - n_hits,
        "true_matches": matched_true,
        "est_tile_matches": tile_matches,
    }
예제 #5
0
def test_galaxy_blend(get_sdss_galaxies_config, devices):
    """Check shapes and zero-padding of a simulated blended-galaxy batch.

    Instantiates the `galsim_blended_galaxies` dataset with a single batch of
    four images and at most three sources, converts each batch to a full
    catalog and asserts the shape of every per-source tensor, plus that the
    entries past each image's source count are zero.
    """
    overrides = {
        "datasets.galsim_blended_galaxies.num_workers": 0,
        "datasets.galsim_blended_galaxies.batch_size": 4,
        "datasets.galsim_blended_galaxies.n_batches": 1,
        "datasets.galsim_blended_galaxies.prior.max_n_sources": 3,
    }
    cfg = get_sdss_galaxies_config(overrides, devices)
    blend_ds: GalsimBlends = instantiate(cfg.datasets.galsim_blended_galaxies)

    for batch in blend_ds.train_dataloader():
        # Everything left after removing the pixels is catalog data.
        images = batch.pop("images")
        batch.pop("background")
        full_cat = TileCatalog(4, batch).to_full_params()

        max_n_sources = full_cat.max_sources
        n_sources = full_cat.n_sources
        plocs = full_cat.plocs
        fields = {
            name: full_cat[name]
            for name in ("galaxy_params", "snr", "blendedness", "ellips",
                         "mags", "fluxes")
        }

        def per_source(width):
            # expected shape of a per-source tensor with `width` features
            return (4, max_n_sources, width)

        assert images.shape == (4, 1, 88, 88)  # 40 + 24 * 2
        assert fields["galaxy_params"].shape == per_source(7)
        assert plocs.shape == per_source(2)
        assert fields["snr"].shape == per_source(1)
        assert fields["blendedness"].shape == per_source(1)
        assert n_sources.shape == (4, )
        assert fields["ellips"].shape == per_source(2)
        assert fields["mags"].shape == per_source(1)
        assert fields["fluxes"].shape == per_source(1)

        # entries beyond the number of sources must be empty (zero)
        for row, count in enumerate(n_sources):
            for padded in (fields["snr"], fields["blendedness"], plocs):
                assert torch.all(padded[row, count:] == 0)
예제 #6
0
def get_map_estimate(
    image_encoder: DetectionEncoder, images, background, slen: int, wlen: int = None
):
    """Return the encoder's mode estimate for a full image as a full catalog.

    Besides computing the estimate, this asserts two consistency properties:
    encoding only the first `25**2` padded tiles reproduces the matching
    slice of the full encoding, and converting the tile estimate to a full
    catalog and back yields an equal tile catalog.

    Args:
        image_encoder: Encoder providing `tile_slen`, `ptile_slen`,
            `border_padding`, `encode` and `variational_mode`.
        images: Images including border padding.
        background: Backgrounds, concatenated with `images` along dim 1.
        slen: Image height without border padding.
        wlen: Image width without border padding; defaults to `slen`.

    Returns:
        The estimated catalog converted with `to_full_params()`.
    """
    # NOTE: slen*wlen is size of the image without border padding
    wlen = slen if wlen is None else wlen
    assert isinstance(slen, int) and isinstance(wlen, int)

    # check image compatibility
    tile_slen = image_encoder.tile_slen
    pad_h = (images.shape[-2] - slen) / 2
    pad_w = (images.shape[-1] - wlen) / 2
    assert pad_h == pad_w, "border paddings on each dimension differ."
    assert slen % tile_slen == 0, "incompatible slen"
    assert wlen % tile_slen == 0, "incompatible wlen"
    assert pad_h == image_encoder.border_padding, "incompatible border"

    # obtain estimates per tile, then on the full image.
    stacked = torch.cat((images, background), dim=1)
    ptiles = get_images_in_tiles(stacked, tile_slen, image_encoder.ptile_slen)
    _, n_tiles_h, n_tiles_w, _, _, _ = ptiles.shape
    flat_ptiles = rearrange(ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    var_params = image_encoder.encode(flat_ptiles)

    # Encoding a prefix of the tiles must agree with the full encoding.
    n_check = 25**2
    var_params_prefix = image_encoder.encode(flat_ptiles[:n_check])
    for key in ("n_source_log_probs", "per_source_params"):
        assert torch.allclose(
            var_params[key][:n_check],
            var_params_prefix[key],
            atol=1e-5,
        )

    tile_map = TileCatalog.from_flat_dict(
        tile_slen,
        n_tiles_h,
        n_tiles_w,
        image_encoder.variational_mode(var_params),
    )
    full_map = tile_map.to_full_params()

    # Round trip tile -> full -> tile must preserve the catalog.
    round_trip = full_map.to_tile_params(tile_slen, tile_map.max_sources)
    assert tile_map.equals(round_trip, exclude=("n_source_log_probs",), atol=1e-5)

    return full_map
예제 #7
0
def get_positive_negative_stats(
    true_cat: FullCatalog,
    est_tile_cat: TileCatalog,
    mag_max: float = np.inf,
    n_jobs: int = 10,
):
    """Evaluate `stats_for_threshold` over a grid of detection thresholds.

    The true catalog is restricted to magnitudes up to `mag_max`, and the
    statistics are computed in parallel for the 99 thresholds
    0.01, 0.02, ..., 0.99.

    Args:
        true_cat: True full catalog.
        est_tile_cat: Estimated tile catalog; a copy is used so that the
            caller's object is not the one handed to the workers.
        mag_max: Upper magnitude limit applied to `true_cat`.
        n_jobs: Number of parallel jobs passed to `joblib.Parallel`
            (defaults to the previously hard-coded value of 10).

    Returns:
        Dict with every key returned by `stats_for_threshold`, each stacked
        over thresholds, plus "n_obj", the number of true objects
        (`true_cat.plocs.shape[1]` after the magnitude cut).
    """
    true_cat = true_cat.apply_mag_bin(-np.inf, mag_max)
    thresholds = np.linspace(0.01, 0.99, 99)
    est_tile_cat = est_tile_cat.copy()

    res = Parallel(n_jobs=n_jobs)(
        delayed(stats_for_threshold)(true_cat.plocs, est_tile_cat, t)
        for t in tqdm(thresholds))
    out: Dict[str, Union[int, Tensor]] = {
        k: torch.stack([r[k] for r in res]) for k in res[0]
    }
    out["n_obj"] = true_cat.plocs.shape[1]
    return out
예제 #8
0
    def _get_loss(self, batch):
        """Compute the loss of a batch.

        The galaxy parameters are encoded from the padded image tiles at the
        catalog locations, the images are re-rendered from the catalog with
        those parameters, and the negative Normal log-likelihood of the
        observed images is summed (optionally cropped at the border).

        Args:
            batch: Mapping with "images" and "background"; every other entry
                is passed to `TileCatalog` as catalog data.

        Returns:
            Summed reconstruction loss minus the summed divergence term
            restricted to sources flagged in "galaxy_bools".
        """
        images: Tensor = batch["images"]
        background: Tensor = batch["background"]
        tile_catalog = TileCatalog(self.tile_slen, {
            k: v
            for k, v in batch.items() if k not in {"images", "background"}
        })

        image_ptiles = get_images_in_tiles(
            torch.cat((images, background), dim=1),
            self.tile_slen,
            self.ptile_slen,
        )
        image_ptiles = rearrange(image_ptiles,
                                 "n nth ntw b h w -> (n nth ntw) b h w")
        locs = rearrange(tile_catalog.locs,
                         "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
        galaxy_params, pq_divergence = self._encode(image_ptiles, locs)
        # draw fully reconstructed image.
        # NOTE: Assume recon_mean = recon_var per poisson approximation.
        tile_catalog["galaxy_params"] = rearrange(
            galaxy_params,
            "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
            nth=tile_catalog.n_tiles_h,
            ntw=tile_catalog.n_tiles_w,
        )
        recon_mean = self.image_decoder.render_images(tile_catalog)
        recon_mean = rearrange(recon_mean, "(ns n) c h w -> ns n c h w", ns=1)
        recon_mean += background.unsqueeze(0)

        assert not torch.any(torch.isnan(recon_mean))
        assert not torch.any(torch.isinf(recon_mean))
        recon_losses = -Normal(recon_mean, recon_mean.sqrt()).log_prob(
            images.unsqueeze(0))
        if self.crop_loss_at_border:
            bp = self.border_padding * 2
            height, width = recon_losses.shape[-2:]
            # Use explicit end indices: `bp:-bp` selects nothing when bp == 0,
            # which would silently drop the whole reconstruction loss.
            recon_losses = recon_losses[:, :, :, bp:height - bp,
                                        bp:width - bp]
        assert not torch.any(torch.isnan(recon_losses))
        assert not torch.any(torch.isinf(recon_losses))

        # For divergence loss, we only evaluate tiles with a galaxy in them
        galaxy_bools = rearrange(tile_catalog["galaxy_bools"],
                                 "n nth ntw ms 1 -> 1 (n nth ntw) ms")
        divergence_loss = (pq_divergence * galaxy_bools).sum()
        return recon_losses.sum() - divergence_loss
예제 #9
0
def make_plots_of_marginal_class(
    outdir: Path,
    encoder: Encoder,
    decoder: ImageDecoder,
    frame: Frame,
    tile_map_recon: TileCatalog,
    detections_at_mode,
):
    """Save figures and a CSV for matched sources with a marginal class.

    A source is "marginal" when `exp(galaxy_probs)` lies in [0.4, 0.6]. For
    every marginal source that is also flagged as matched, a figure is saved
    to `outdir` and a row is appended to `marginal_detections.csv`.

    Args:
        outdir: Directory receiving the PNG files and the CSV.
        encoder: Provides `border_padding`.
        decoder: Passed through to `create_figure_at_point`.
        frame: Image frame; sky coordinates are only computed for `SDSSFrame`.
        tile_map_recon: Tile catalog; its "matched" entry is overwritten with
            `detections_at_mode["est_tile_matches"]`.
        detections_at_mode: Mapping providing "est_tile_matches".
    """
    tile_map_recon["matched"] = detections_at_mode["est_tile_matches"]
    bp = encoder.border_padding
    full_map = tile_map_recon.to_full_params()
    # NOTE(review): exp() is applied to "galaxy_probs", which suggests the
    # stored values are log-probabilities - confirm.
    galaxy_probs = full_map["galaxy_probs"][0, :, 0].exp()
    marginal = (galaxy_probs <= 0.6) & (galaxy_probs >= 0.4)
    csv_lines = ["fname,in_coadd,ra,dec,prob"]
    for i, ploc in tqdm(enumerate(full_map.plocs[0]),
                        desc="marginal classifications"):
        # Plain Python value: formatting the tensor itself would write
        # "tensor(True)" (or "tensor(True, device='cuda:0')", whose comma
        # breaks the column layout) into the CSV.
        matched = full_map["matched"][0, i, 0].item()
        if not (marginal[i] and matched):
            continue

        size = 40
        h = ploc[0].item()
        w = ploc[1].item()
        # NOTE(review): 24 is presumably the border padding (`bp`) - confirm
        # before replacing the literal.
        h_topleft = max(int(h - (size / 2.0)), 0) + 24
        w_topleft = max(int(w - (size / 2.0)), 0) + 24
        fig = create_figure_at_point(
            h_topleft,
            w_topleft,
            size,
            bp,
            tile_map_recon,
            frame,
            decoder,
        )
        fname = f"h{h_topleft}_w{w_topleft}.png"
        fig.savefig(outdir / fname)

        if isinstance(frame, SDSSFrame):
            ra, dec = frame.wcs.wcs_pix2world(w + 24, h + 24, 0)
        else:
            ra, dec = None, None
        prob = full_map["galaxy_probs"][0, i, 0].exp().item()
        csv_lines.append(f"{fname},{matched},{ra},{dec},{prob}")
    out_csv = outdir / "marginal_detections.csv"
    out_csv.write_text("\n".join(csv_lines))
예제 #10
0
파일: decoder.py 프로젝트: prob-ml/bliss
    def render_large_scene(
        self, tile_catalog: TileCatalog, batch_size: Optional[int] = None
    ) -> Tensor:
        """Render a scene in horizontal strips of tile rows.

        The catalog is cropped to groups of whole tile rows, each group is
        rendered with `render_images`, and the result is accumulated into a
        CPU tensor; consecutive strips overlap by the border padding, hence
        the `+=`.

        Args:
            tile_catalog: Tile catalog of the scene.
            batch_size: Approximate number of tiles rendered at once;
                defaults to `75**2 + 500 * 2`.

        Returns:
            Tensor of shape `(1, 1, h, w)`, where `h` and `w` are the tiled
            extent plus twice the border padding.
        """
        if batch_size is None:
            batch_size = 75**2 + 500 * 2

        _, n_tiles_h, n_tiles_w, _, _ = tile_catalog.locs.shape
        # Always render at least one row per strip; a zero step would make
        # `range` raise when batch_size is smaller than one row of tiles.
        n_rows_per_batch = max(1, batch_size // n_tiles_w)
        h = n_tiles_h * tile_catalog.tile_slen + 2 * self.border_padding
        w = n_tiles_w * tile_catalog.tile_slen + 2 * self.border_padding
        scene = torch.zeros(1, 1, h, w)
        for row in range(0, n_tiles_h, n_rows_per_batch):
            end_row = row + n_rows_per_batch
            start_h = row * tile_catalog.tile_slen
            end_h = end_row * tile_catalog.tile_slen + 2 * self.border_padding
            tile_cat_row = tile_catalog.crop((row, end_row), (0, None))
            img_row = self.render_images(tile_cat_row)
            scene[:, :, start_h:end_h] += img_row.cpu()
        return scene
예제 #11
0
def make_images_of_example_blend(blend_dir: Path, encoder: Encoder,
                                 decoder: ImageDecoder, frame: Frame):
    """Save images illustrating the reconstruction of one example blend.

    A fixed 40x40 window of the frame (top-left corner at h=1400, w=1710) is
    reconstructed; the image, reconstruction and residual are written to
    `blend_dir`. Then, for each of two fixed tiles, the scene is re-rendered
    with that tile's sources switched off and saved as `galaxy_<i>.png`.
    """
    slen = 40
    h, w = 1400, 1710
    h_end, w_end = h + slen, w + slen
    recon, tile_map_recon = reconstruct_scene_at_coordinates(
        encoder,
        decoder,
        frame.image,
        frame.background,
        (h, h_end),
        (w, w_end),
    )
    bp = encoder.border_padding

    window = (slice(None), slice(None), slice(h, h_end), slice(w, w_end))
    img = frame.image[window]
    bg = frame.background[window]
    resid = (img - recon) / recon.sqrt()
    for stem, picture in (("img", img), ("recon", recon), ("resid", resid)):
        plt.imsave(blend_dir / f"{stem}.png", picture[0, 0])

    # Tiles (row, column) whose sources are removed, one at a time.
    masked_tiles = ((4, 5), (3, 7))
    for idx, (tile_h, tile_w) in enumerate(masked_tiles):
        one_gal = tile_map_recon.to_dict()
        # Clone before editing so the original catalog tensors stay intact.
        for key in ("n_sources", "galaxy_bools"):
            one_gal[key] = one_gal[key].clone()
        one_gal["galaxy_bools"][:, tile_h, tile_w] = 0.0
        one_gal["n_sources"][0, tile_h, tile_w] = 0

        masked_catalog = TileCatalog(tile_map_recon.tile_slen, one_gal)
        rendered = decoder.render_images(masked_catalog).detach().cpu()
        rendered = rendered[:, :, bp:-bp, bp:-bp] + bg
        plt.imsave(blend_dir / f"galaxy_{idx}.png",
                   rendered[0, 0, :, :],
                   vmax=recon.max())
예제 #12
0
파일: prior.py 프로젝트: prob-ml/bliss
    def sample_prior(self, tile_slen: int, batch_size: int, n_tiles_h: int,
                     n_tiles_w: int) -> TileCatalog:
        """Sample latent variables from the prior of an astronomical image.

        Args:
            tile_slen: Side length of catalog tiles.
            batch_size: The number of samples to draw.
            n_tiles_h: Number of tiles height-wise; must be positive.
            n_tiles_w: Number of tiles width-wise; must be positive.

        Returns:
            A `TileCatalog` built from the sampled "n_sources", "locs",
            "galaxy_bools", "star_bools", "galaxy_params", "fluxes" and
            "log_fluxes" tensors.
        """
        assert n_tiles_h > 0
        assert n_tiles_w > 0

        # The sampling calls are kept in this exact order: each step consumes
        # the result of an earlier one.
        n_sources = self._sample_n_sources(batch_size, n_tiles_h, n_tiles_w)
        is_on = get_is_on_from_n_sources(n_sources, self.max_sources)

        # per tile quantities.
        sampled = {"n_sources": n_sources, "locs": self._sample_locs(is_on)}
        (sampled["galaxy_bools"],
         sampled["star_bools"]) = self._sample_n_galaxies_and_stars(is_on)
        sampled["galaxy_params"] = self._sample_galaxy_params(
            sampled["galaxy_bools"])
        sampled["fluxes"] = self._sample_fluxes(sampled["star_bools"])
        sampled["log_fluxes"] = self._get_log_fluxes(sampled["fluxes"])
        return TileCatalog(tile_slen, sampled)
예제 #13
0
    def variational_mode(self, image: Tensor,
                         background: Tensor) -> TileCatalog:
        """Get the maximum a posteriori catalog for an image, in tiles.

        Strictly speaking this is not the true MAP of the variational
        distribution of the catalog: estimation is sequential. The MAP of the
        locations is estimated first and then plugged into the binary and
        galaxy encoders, so those encoders are conditioned on the location
        MAP. Optimizing over the whole catalog jointly is not tractable.

        Args:
            image: An astronomical image,
                with shape `n * n_bands * h * w`.
            background: Background associated with image,
                with shape `n * n_bands * h * w`.

        Returns:
            A `TileCatalog` holding the maximum a posteriori of the catalog
            in tiles. It comprises:
                - The output of DetectionEncoder.variational_mode()
                - 'galaxy_bools', 'star_bools', and 'galaxy_probs' from BinaryEncoder.
                - 'galaxy_params' from GalaxyEncoder.
        """
        # Passing None as the third argument of `sample` requests the mode.
        mode_dict = self.sample(image, background, None)

        tile_slen = self.detection_encoder.tile_slen
        padding = 2 * self.border_padding
        n_tiles_h = (image.shape[2] - padding) // tile_slen
        n_tiles_w = (image.shape[3] - padding) // tile_slen

        # Drop the leading axis of every entry before building the catalog.
        squeezed = {}
        for name, value in mode_dict.items():
            squeezed[name] = value.squeeze(0)
        return TileCatalog.from_flat_dict(tile_slen, n_tiles_h, n_tiles_w,
                                          squeezed)
예제 #14
0
 def forward_tile(self, tile_cat: TileCatalog):
     """Convert a tile catalog to a full catalog and call this object on it."""
     return self(tile_cat.to_full_params())
예제 #15
0
    def validation_epoch_end(self,
                             outputs,
                             kind="validation",
                             max_n_samples=16):
        # pylint: disable=too-many-statements
        """Pytorch lightning method.

        Plots a square grid of images from the last output batch with the
        true and estimated source locations overlaid, and sends the figure to
        the logger when one is configured. Does nothing when `n_bands > 1`.

        Args:
            outputs: Sequence of batches; only the last one is plotted.
            kind: "validation" or "testing"; selects the figure title.
            max_n_samples: Upper bound on the number of images plotted.

        Raises:
            NotImplementedError: If a logger is set and `kind` is neither
                "validation" nor "testing".
        """
        batch: Dict[str, Tensor] = outputs[-1]
        if self.n_bands > 1:
            return
        # Largest perfect square that fits in the batch, capped by the limit.
        largest_square = int(math.sqrt(len(batch["n_sources"])))**2
        n_samples = min(largest_square, max_n_samples)
        nrows = int(n_samples**0.5)  # for figure

        # True catalog, truncated along the fourth axis to `max_detections`.
        limit = self.max_detections
        truth_dict = {
            key: batch[key][:, :, :, 0:limit]
            for key in ("locs", "log_fluxes", "galaxy_bools", "star_bools")
        }
        truth_dict["n_sources"] = batch["n_sources"].clamp(max=limit)
        true_tiles = TileCatalog(self.tile_slen, truth_dict)
        true_cat = true_tiles.to_full_params()

        # Estimated catalog from the encoder's variational mode.
        stacked = torch.cat((batch["images"], batch["background"]), dim=1)
        ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
        ptiles = rearrange(ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        est_dict = self.variational_mode(self.encode(ptiles))
        est_tiles = TileCatalog.from_flat_dict(
            true_tiles.tile_slen,
            true_tiles.n_tiles_h,
            true_tiles.n_tiles_w,
            est_dict,
        )
        est_cat = est_tiles.to_full_params()

        # setup figure and axes.
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12))
        axes = axes.flatten() if nrows > 1 else [axes]

        images = batch["images"]
        assert images.shape[-2] == images.shape[-1]
        bp = self.border_padding

        # (catalog, source kind, color, marker, size, legend label)
        marker_specs = (
            (true_cat, "galaxy", "r", "x", 20, "t.gal"),
            (true_cat, "star", "c", "x", 20, "t.star"),
            (est_cat, "all", "b", "+", 30, "p.source"),
        )
        for idx, ax in enumerate(axes):
            true_count = true_cat.n_sources[idx].item()
            est_count = est_cat.n_sources[idx].item()
            ax.set_xlabel(f"True num: {true_count}; Est num: {est_count}")

            # add white border showing where centers of stars and galaxies can be
            for draw_line, extent in ((ax.axvline, images.shape[-1]),
                                      (ax.axhline, images.shape[-2])):
                draw_line(bp, color="w")
                draw_line(extent - bp, color="w")

            # plot image first
            image = images[idx, 0].cpu().numpy()
            lowest, highest = image.min().item(), image.max().item()
            cax = make_axes_locatable(ax).append_axes("right",
                                                      size="5%",
                                                      pad=0.05)
            im = ax.matshow(image, vmin=lowest, vmax=highest, cmap="viridis")
            fig.colorbar(im, cax=cax, orientation="vertical")

            for cat, which, color, marker, size, _ in marker_specs:
                cat.plot_plocs(ax,
                               idx,
                               which,
                               bp=bp,
                               color=color,
                               marker=marker,
                               s=size)

            if idx == 0:
                # Empty scatters only provide the legend handles.
                for _, _, color, marker, size, label in marker_specs:
                    ax.scatter(None,
                               None,
                               color=color,
                               marker=marker,
                               s=size,
                               label=label)
                ax.legend(
                    bbox_to_anchor=(0.0, 1.2, 1.0, 0.102),
                    loc="lower left",
                    ncol=2,
                    mode="expand",
                    borderaxespad=0.0,
                )

        fig.tight_layout()
        if self.logger:
            if kind not in ("validation", "testing"):
                raise NotImplementedError()
            if kind == "validation":
                title = f"Epoch:{self.current_epoch}/Validation Images"
            else:
                title = "Test Images"
            self.logger.experiment.add_figure(title, fig)
        plt.close(fig)
예제 #16
0
    def compute_data(self, blend_file: Path, encoder: Encoder,
                     decoder: ImageDecoder):
        """Match true and estimated sources of a blend file and collect stats.

        Loads the simulated blends, runs the encoder's variational mode on
        the images, matches true and estimated sources image by image with
        `match_by_locs`, and gathers the quantities of the matched pairs.

        Args:
            blend_file: File readable by `torch.load` holding "images",
                "background" and the tile-catalog tensors.
            encoder: Encoder used for inference.
            decoder: Provides `tile_slen` and is used to set fluxes,
                magnitudes and ellipticities on the estimate.

        Returns:
            Dict of tensors over all matched pairs: "snr", "blendedness",
            "true_mags", "est_mags" (1-D) and "true_ellips", "est_ellips"
            (two columns each).
        """
        # NOTE(security): torch.load unpickles the file; only load trusted data.
        blend_data = torch.load(blend_file)
        images = blend_data.pop("images")
        background = blend_data.pop("background")
        n_batches, _, slen, _ = images.shape
        assert background.shape == (1, slen, slen)

        # prepare background
        background = background.unsqueeze(0)
        background = background.expand(n_batches, 1, slen, slen)

        # first create FullCatalog from simulated data
        tile_cat = TileCatalog(decoder.tile_slen, blend_data).cpu()
        full_truth = tile_cat.to_full_params()

        print("INFO: BLISS posterior inference on images.")
        tile_est = encoder.variational_mode(images, background)
        tile_est.set_all_fluxes_and_mags(decoder)
        tile_est.set_galaxy_ellips(decoder, scale=0.393)
        tile_est = tile_est.cpu()
        full_est = tile_est.to_full_params()

        snr, blendedness = [], []
        true_mags, est_mags = [], []
        true_ellips1, true_ellips2 = [], []
        est_ellips1, est_ellips2 = [], []
        for ii in tqdm(range(n_batches), desc="Matching batches"):
            tindx, eindx, dkeep, _ = match_by_locs(full_truth.plocs[ii],
                                                   full_est.plocs[ii])

            # Keep only the pairs accepted by the matcher; `tolist()` turns
            # each column into plain Python numbers in one pass.
            true_ellips_ii = full_truth["ellips"][ii][tindx][dkeep]
            est_ellips_ii = full_est["ellips"][ii][eindx][dkeep]
            snr.extend(
                full_truth["snr"][ii][tindx][dkeep].reshape(-1).tolist())
            blendedness.extend(full_truth["blendedness"][ii][tindx]
                               [dkeep].reshape(-1).tolist())
            true_mags.extend(
                full_truth["mags"][ii][tindx][dkeep].reshape(-1).tolist())
            est_mags.extend(
                full_est["mags"][ii][eindx][dkeep].reshape(-1).tolist())
            true_ellips1.extend(true_ellips_ii[:, 0].tolist())
            true_ellips2.extend(true_ellips_ii[:, 1].tolist())
            est_ellips1.extend(est_ellips_ii[:, 0].tolist())
            est_ellips2.extend(est_ellips_ii[:, 1].tolist())

        # One row per matched pair, columns are the two ellipticity components.
        true_ellips = torch.stack(
            [torch.tensor(true_ellips1),
             torch.tensor(true_ellips2)], dim=1)
        est_ellips = torch.stack(
            [torch.tensor(est_ellips1),
             torch.tensor(est_ellips2)], dim=1)

        return {
            "snr": torch.tensor(snr),
            "blendedness": torch.tensor(blendedness),
            "true_mags": torch.tensor(true_mags),
            "est_mags": torch.tensor(est_mags),
            "true_ellips": true_ellips,
            "est_ellips": est_ellips,
        }
예제 #17
0
파일: inference.py 프로젝트: prob-ml/bliss
    def __init__(
        self,
        dataset: SimulatedDataset,
        coadd: str,
        n_tiles_h,
        n_tiles_w,
        cache_dir=None,
    ):
        """Build (or load from cache) a frame simulated from a coadd catalog.

        Args:
            dataset: Simulated dataset; it is moved to the CPU, and its
                decoder, prior and `simulate_image_from_catalog` are used.
            coadd: Path of the coadd catalog file; its parent directory is
                used as the SDSS data directory.
            n_tiles_h: Number of tiles along the height of the frame.
            n_tiles_w: Number of tiles along the width of the frame.
            cache_dir: Optional directory; when given, the simulated frame is
                read from / written to `simulated_frame.pt` inside it.
        """
        dataset.to("cpu")
        self.bp = dataset.image_decoder.border_padding
        self.tile_slen = dataset.tile_slen
        self.coadd_file = Path(coadd)
        self.sdss_dir = self.coadd_file.parent
        # Fixed SDSS field; it is read only to obtain the WCS below.
        run = 94
        camcol = 1
        field = 12
        bands = (2, )
        sdss_data = SloanDigitalSkySurvey(
            sdss_dir=self.sdss_dir,
            run=run,
            camcol=camcol,
            fields=(field, ),
            bands=bands,
        )
        wcs = sdss_data[0]["wcs"][0]
        if cache_dir is not None:
            sim_frame_path = Path(cache_dir) / "simulated_frame.pt"
        else:
            sim_frame_path = None
        if sim_frame_path and sim_frame_path.exists():
            # Cached frame available: reuse it instead of simulating again.
            # NOTE(security): torch.load unpickles the file; the cache
            # directory must be trusted.
            tile_catalog_dict, image, background = torch.load(sim_frame_path)
            tile_catalog = TileCatalog(self.tile_slen, tile_catalog_dict)
        else:
            # Limits of the tiled region, offset by the border padding.
            hlim = (self.bp, self.bp + n_tiles_h * self.tile_slen)
            wlim = (self.bp, self.bp + n_tiles_w * self.tile_slen)
            full_coadd_cat = CoaddFullCatalog.from_file(coadd,
                                                        wcs,
                                                        hlim,
                                                        wlim,
                                                        band="r")
            if dataset.image_prior.galaxy_prior is not None:
                # Galaxy parameters are drawn from the prior for every source.
                full_coadd_cat[
                    "galaxy_params"] = dataset.image_prior.galaxy_prior.sample(
                        full_coadd_cat.n_sources, "cpu").unsqueeze(0)
            # NOTE(review): presumably shifts to a pixel-center convention -
            # confirm.
            full_coadd_cat.plocs = full_coadd_cat.plocs + 0.5
            max_sources = dataset.image_prior.max_sources
            tile_catalog = full_coadd_cat.to_tile_params(
                self.tile_slen, max_sources)
            tile_catalog.set_all_fluxes_and_mags(dataset.image_decoder)
            # Sanity check: tile -> full round trip must reproduce the coadd
            # catalog, apart from the excluded entries.
            fc = tile_catalog.to_full_params()
            assert fc.equals(full_coadd_cat,
                             exclude=("galaxy_fluxes", "star_bools", "fluxes",
                                      "mags"))
            print("INFO: started generating frame")
            image, background = dataset.simulate_image_from_catalog(
                tile_catalog)
            print("INFO: done generating frame")
            if sim_frame_path:
                # Store the result so later runs can skip the simulation.
                torch.save((tile_catalog.to_dict(), image, background),
                           sim_frame_path)

        self.tile_catalog = tile_catalog
        self.image = image
        self.background = background
        # Exactly one image and one background are expected along dim 0.
        assert self.image.shape[0] == 1
        assert self.background.shape[0] == 1
예제 #18
0
    def _make_plots(self, batch, n_samples=5):
        """Plot truth, reconstruction and residual for sampled validation images.

        Randomly draws up to ``n_samples`` examples from ``batch`` (without
        replacement), encodes and re-renders them, and draws one figure row per
        example, ordered from the largest to the smallest mean absolute
        residual. The figure is sent to ``self.logger`` when a logger is
        attached and is always closed afterwards.

        Args:
            batch: Mapping of batched tensors. Only the keys listed in ``keys``
                below are read; the mapping itself is not modified.
            n_samples: Maximum number of examples (figure rows) to draw.
        """
        n_samples = min(len(batch["n_sources"]), n_samples)
        if n_samples <= 0:
            # Nothing to draw; a figure with zero rows cannot be created.
            return
        samples = np.random.choice(len(batch["n_sources"]),
                                   n_samples,
                                   replace=False)
        keys = [
            "images",
            "background",
            "locs",
            "galaxy_bools",
            "star_bools",
            "fluxes",
            "log_fluxes",
            "n_sources",
        ]
        # Sub-sample into a fresh dict so the caller's batch is left untouched.
        sub = {k: batch[k][samples] for k in keys}

        # Entries that are used directly rather than as catalog parameters.
        images = sub["images"]
        background = sub["background"]
        tile_locs = sub["locs"]

        # obtain map estimates
        image_ptiles = get_images_in_tiles(
            torch.cat((images, background), dim=1),
            self.tile_slen,
            self.ptile_slen,
        )
        _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
        image_ptiles = rearrange(image_ptiles,
                                 "n nth ntw b h w -> (n nth ntw) b h w")
        locs = rearrange(tile_locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
        z, _ = self._encode(image_ptiles, locs)
        galaxy_params = rearrange(
            z,
            "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
            ns=1,
            nth=n_tiles_h,
            ntw=n_tiles_w,
        )

        # Catalog made of the batch's own parameters plus the encoded galaxy
        # parameters.
        tile_est = TileCatalog(
            self.tile_slen,
            {
                "n_sources": sub["n_sources"],
                "locs": sub["locs"],
                "galaxy_bools": sub["galaxy_bools"],
                "star_bools": sub["star_bools"],
                "fluxes": sub["fluxes"],
                "log_fluxes": sub["log_fluxes"],
                "galaxy_params": galaxy_params,
            },
        )
        est = tile_est.to_full_params()

        # draw all reconstruction images.
        # render_images automatically accounts for tiles with no galaxies.
        recon_images = self.image_decoder.render_images(tile_est)
        recon_images += background
        residuals = (images - recon_images) / torch.sqrt(recon_images)

        # Order examples by mean absolute residual, worst first. The ranking
        # uses the residuals before the border is zeroed out below.
        worst_indices = residuals.abs().mean(dim=(1, 2, 3)).argsort(
            descending=True)[:n_samples]
        # Plain ints so the same index can address every per-sample container.
        worst_indices = worst_indices.tolist()

        if self.crop_loss_at_border:
            bp = self.border_padding * 2
            residuals[:, :, :bp, :] = 0.0
            residuals[:, :, -bp:, :] = 0.0
            residuals[:, :, :, :bp] = 0.0
            residuals[:, :, :, -bp:] = 0.0

        # Loop-invariant plotting settings.
        # Plotting only works on square images
        assert images.shape[-2] == images.shape[-1]
        bp = self.border_padding
        slen = images.shape[-1] - 2 * bp
        labels = ("t. gal", None, "t. star", None)
        if self.crop_loss_at_border:
            crop_bp = (recon_images.shape[-1] - slen) // 2
            eff_slen = slen - crop_bp

        figsize = (12, 4 * n_samples)
        fig, axes = plt.subplots(nrows=n_samples,
                                 ncols=3,
                                 figsize=figsize,
                                 squeeze=False)

        for row, idx in enumerate(worst_indices):
            true_ax = axes[row, 0]
            recon_ax = axes[row, 1]
            res_ax = axes[row, 2]

            # add titles to axes in the first row
            if row == 0:
                true_ax.set_title("Truth", size=18)
                recon_ax.set_title("Reconstruction", size=18)
                res_ax.set_title("Residual", size=18)

            # vmin, vmax should be shared between reconstruction and true images.
            if self.max_flux_valid_plots is None:
                vmax = np.ceil(
                    max(images[idx].max().item(),
                        recon_images[idx].max().item()))
            else:
                vmax = self.max_flux_valid_plots
            vmin = np.floor(
                min(images[idx].min().item(), recon_images[idx].min().item()))
            vrange = (vmin, vmax)

            # Every panel in this row shows sample `idx` (not the row number),
            # so truth, reconstruction, residual and colour range all agree.
            true_image = images[idx, 0].cpu().numpy()
            recon_image = recon_images[idx, 0].cpu().numpy()
            plocs = est.plocs[idx].cpu().numpy().reshape(-1, 2)
            probs = est["galaxy_bools"][idx].cpu().numpy().reshape(-1)

            plot_image(fig, true_ax, true_image, vrange=vrange, colorbar=True)
            plot_locs(true_ax, bp, slen, plocs, probs, cmap="cool")
            # The reconstruction panel shows the rendered image, not the truth.
            plot_image(fig, recon_ax, recon_image, vrange=vrange, colorbar=True)
            plot_locs(recon_ax, bp, slen, plocs, probs, cmap="cool")

            residuals_idx = residuals[idx, 0].cpu().numpy()
            res_vmax = np.ceil(residuals_idx.max())
            res_vmin = np.floor(residuals_idx.min())
            plot_image(fig, res_ax, residuals_idx, vrange=(res_vmin, res_vmax))

            if self.crop_loss_at_border:
                # White guide lines on the reconstruction and residual panels.
                for ax in (recon_ax, res_ax):
                    for b in (crop_bp, crop_bp * 2):
                        ax.axvline(b, color="w")
                        ax.axvline(b + eff_slen, color="w")
                        ax.axhline(b, color="w")
                        ax.axhline(b + eff_slen, color="w")

            if row == 0:
                add_legend(true_ax, labels)

        fig.tight_layout()
        if self.logger:
            self.logger.experiment.add_figure(
                f"Epoch:{self.current_epoch}/Worst Validation Images", fig)
        plt.close(fig)
예제 #19
0
def create_figure_at_point(
    h: int,
    w: int,
    size: int,
    bp: int,
    tile_map_recon: TileCatalog,
    frame: Frame,
    dec: ImageDecoder,
    est_catalog: Optional[FullCatalog] = None,
    show_tiles=False,
    use_image_bounds=False,
    **kwargs,
):
    """Build a figure comparing a square crop of a frame with its reconstruction.

    The requested corner ``(h, w)`` is clamped to the frame, snapped down to
    the tile grid of ``tile_map_recon``, and a crop of ``size // tile_slen``
    tiles per side is taken from the frame image, its background and the tile
    catalog. The cropped catalog is rendered with ``dec``, the background is
    added, and the residual ``(image - recon) / sqrt(recon)`` is computed.

    Args:
        h: Row of the crop's top-left corner, in frame pixel coordinates.
        w: Column of the crop's top-left corner, in frame pixel coordinates.
        size: Requested side length of the crop in pixels; it is rounded down
            to a whole number of tiles.
        bp: Border width in pixels; it offsets the tile grid inside the frame
            and is trimmed from each side of the rendered image.
        tile_map_recon: Tile catalog to crop and render.
        frame: Object exposing ``image`` and ``background`` tensors, indexed
            here as ``[0, 0, rows, cols]``.
        dec: Decoder providing ``render_images`` and ``device``.
        est_catalog: Optional full catalog, cropped to the same tiles and
            forwarded as ``coadd_objects``.
        show_tiles: When true, the cropped tile catalog is forwarded as
            ``tile_map``; otherwise ``None`` is forwarded.
        use_image_bounds: When true, the colour range is the min/max of the
            cropped image; otherwise the fixed range 800-1200 is used.
        **kwargs: Forwarded unchanged to ``create_figure``.

    Returns:
        Whatever ``create_figure`` returns for the cropped inputs.
    """
    tile_slen = tile_map_recon.tile_slen

    # Keep the crop (plus border) inside the lower/right edge of the frame ...
    if h + size + bp > frame.image.shape[2]:
        h = frame.image.shape[2] - size - bp
    if w + size + bp > frame.image.shape[3]:
        w = frame.image.shape[3] - size - bp
    # ... and inside the upper/left border: a corner before `bp` would give
    # negative tile indices, which wrap around when used as slice bounds.
    h = max(h, bp)
    w = max(w, bp)

    # Snap the corner down to the tile grid and express the crop in tiles.
    h_tile = (h - bp) // tile_slen
    w_tile = (w - bp) // tile_slen
    n_tiles = size // tile_slen
    tile_hlims = (h_tile, h_tile + n_tiles)
    tile_wlims = (w_tile, w_tile + n_tiles)
    hlims = bp + (h_tile * tile_slen), bp + ((h_tile + n_tiles) * tile_slen)
    wlims = bp + (w_tile * tile_slen), bp + ((w_tile + n_tiles) * tile_slen)

    tile_map_cropped = tile_map_recon.crop(tile_hlims, tile_wlims)
    full_map_cropped = tile_map_cropped.to_full_params()

    img_cropped = frame.image[0, 0, hlims[0]:hlims[1], wlims[0]:wlims[1]]
    bg_cropped = frame.background[0, 0, hlims[0]:hlims[1], wlims[0]:wlims[1]]
    with torch.no_grad():
        recon_full = dec.render_images(tile_map_cropped.to(dec.device))
        recon_full = recon_full.to("cpu")
        # Trim the border with explicit end indices: `bp:-bp` would select
        # nothing when bp == 0.
        h_end = recon_full.shape[2] - bp
        w_end = recon_full.shape[3] - bp
        recon_cropped = recon_full[0, 0, bp:h_end, bp:w_end] + bg_cropped
        resid_cropped = (img_cropped - recon_cropped) / recon_cropped.sqrt()

    if est_catalog is not None:
        # Re-tile the comparison catalog so it can be cropped to the same tiles.
        tile_est_catalog = est_catalog.to_tile_params(
            tile_map_cropped.tile_slen, tile_map_cropped.max_sources)
        tile_est_catalog_cropped = tile_est_catalog.crop(
            tile_hlims, tile_wlims)
        est_catalog_cropped = tile_est_catalog_cropped.to_full_params()
    else:
        est_catalog_cropped = None

    tile_map = tile_map_cropped if show_tiles else None

    if use_image_bounds:
        vmin = img_cropped.min().item()
        vmax = img_cropped.max().item()
    else:
        # Fixed default display range.
        vmin = 800
        vmax = 1200

    return create_figure(
        img_cropped,
        recon_cropped,
        resid_cropped,
        map_recon=full_map_cropped,
        coadd_objects=est_catalog_cropped,
        tile_map=tile_map,
        vmin=vmin,
        vmax=vmax,
        **kwargs,
    )