def get_catalog(self, hlims, wlims) -> FullCatalog:
    """Return a FullCatalog restricted to the pixel window given by `hlims`/`wlims`."""
    # Convert pixel limits on the padded image to coordinates on the unpadded image.
    h, h_end = hlims[0] - self.bp, hlims[1] - self.bp
    w, w_end = wlims[0] - self.bp, wlims[1] - self.bp

    # Convert pixel coordinates to half-open tile-index ranges.
    hlims_tile = int(np.floor(h / self.tile_slen)), int(np.ceil(h_end / self.tile_slen))
    wlims_tile = int(np.floor(w / self.tile_slen)), int(np.ceil(w_end / self.tile_slen))

    # Crop every entry of the tile catalog to the selected tiles.
    tile_dict = {}
    for k, v in self.tile_catalog.to_dict().items():
        tile_dict[k] = v[:, hlims_tile[0]:hlims_tile[1], wlims_tile[0]:wlims_tile[1]]
    tile_cat = TileCatalog(self.tile_slen, tile_dict)
    return tile_cat.to_full_params()
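# Illustrative sketch (not part of the original module): the pixel-to-tile index
# conversion used by `get_catalog` above, factored out as a standalone helper.
# The name `pixel_to_tile_range` and its argument names are assumptions for
# illustration only.
def pixel_to_tile_range(pix_lims, bp, tile_slen):
    """Map pixel limits on the padded image to a half-open tile-index range."""
    start = int(np.floor((pix_lims[0] - bp) / tile_slen))
    end = int(np.ceil((pix_lims[1] - bp) / tile_slen))
    return start, end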
def make_plots_of_marginal_class(
    outdir: Path,
    encoder: Encoder,
    decoder: ImageDecoder,
    frame: Frame,
    tile_map_recon: TileCatalog,
    detections_at_mode,
):
    """Save cutout figures and a CSV for detections whose galaxy probability is marginal."""
    tile_map_recon["matched"] = detections_at_mode["est_tile_matches"]
    bp = encoder.border_padding
    full_map = tile_map_recon.to_full_params()

    # A detection is "marginal" if its galaxy probability lies in [0.4, 0.6].
    marginal_galaxy = full_map["galaxy_probs"][0, :, 0].exp() <= 0.6
    marginal_star = full_map["galaxy_probs"][0, :, 0].exp() >= 0.4
    marginal = marginal_galaxy & marginal_star

    csv_lines = ["fname,in_coadd,ra,dec,prob"]
    for i, ploc in tqdm(enumerate(full_map.plocs[0]), desc="marginal classifications"):
        if marginal[i] and full_map["matched"][0, i, 0].item():
            size = 40
            h = ploc[0].item()
            w = ploc[1].item()
            # hard-coded 24-pixel border offset into the full frame
            h_topleft = max(int(h - (size / 2.0)), 0) + 24
            w_topleft = max(int(w - (size / 2.0)), 0) + 24
            fig = create_figure_at_point(
                h_topleft, w_topleft, size, bp, tile_map_recon, frame, decoder
            )
            fname = f"h{h_topleft}_w{w_topleft}.png"
            fig.savefig(outdir / fname)
            matched = full_map["matched"][0, i, 0]
            if isinstance(frame, SDSSFrame):
                ra, dec = frame.wcs.wcs_pix2world(w + 24, h + 24, 0)
            else:
                ra, dec = None, None
            prob = full_map["galaxy_probs"][0, i, 0].exp().item()
            csv_lines.append(f"{fname},{matched},{ra},{dec},{prob}")

    out_csv = outdir / "marginal_detections.csv"
    out_csv.write_text("\n".join(csv_lines))
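# Illustrative sketch (assumed helper, not in the original code): the "marginal" band
# used by `make_plots_of_marginal_class` above, factored out so the 0.4/0.6 bounds are
# not repeated. `marginal_mask` and its argument names are hypothetical.
def marginal_mask(galaxy_log_probs, lo: float = 0.4, hi: float = 0.6):
    """Boolean mask of sources whose galaxy probability lies in [lo, hi]."""
    probs = galaxy_log_probs.exp()
    return (probs >= lo) & (probs <= hi)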
def stats_for_threshold(
    true_plocs: Tensor,
    est_tile_cat: TileCatalog,
    threshold: Optional[float] = None,
):
    """Compute true/false positive counts, optionally after re-thresholding detections."""
    tile_slen = est_tile_cat.tile_slen
    max_sources = est_tile_cat.max_sources

    # Optionally re-threshold the per-tile detection probabilities.
    if threshold is not None:
        log_probs = rearrange(est_tile_cat["n_source_log_probs"], "n nth ntw 1 1 -> n nth ntw")
        est_tile_cat.n_sources = log_probs >= math.log(threshold)

    est_cat = est_tile_cat.to_full_params()
    number_true = true_plocs.shape[1]
    number_est = int(est_cat.plocs.shape[1])

    true_matches = torch.zeros(true_plocs.shape[1], dtype=torch.bool)
    est_tile_matches = torch.zeros(*est_tile_cat.n_sources.shape, 1, 1, dtype=torch.bool)
    if number_true == 0 or number_est == 0:
        return {
            "tp": torch.tensor(0.0),
            "fp": torch.tensor(float(number_est)),
            "true_matches": true_matches,
            "est_tile_matches": est_tile_matches,
        }

    # Match estimated locations to true locations (1-pixel slack).
    est_matches = torch.zeros(est_cat.plocs.shape[1], dtype=torch.bool)
    row_indx, col_indx, d, _ = reporting.match_by_locs(true_plocs[0], est_cat.plocs[0], 1.0)
    true_matches[row_indx] = d
    est_matches[col_indx] = d
    est_cat["matched"] = est_matches.reshape(1, -1, 1)
    est_tile_matches = est_cat.to_tile_params(tile_slen, max_sources)["matched"]
    tp = d.sum()
    fp = torch.tensor(number_est) - tp
    return {"tp": tp, "fp": fp, "true_matches": true_matches, "est_tile_matches": est_tile_matches}
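# Illustrative sketch (assumed helper, not in the original code): convert the `tp`/`fp`
# counts returned by `stats_for_threshold` above into precision and recall, given the
# number of true sources. The name `precision_recall_from_stats` is hypothetical.
def precision_recall_from_stats(stats, number_true: int):
    tp = float(stats["tp"].item())
    fp = float(stats["fp"].item())
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / number_true if number_true > 0 else 0.0
    return precision, recall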
def test_galaxy_blend(get_sdss_galaxies_config, devices):
    overrides = {
        "datasets.galsim_blended_galaxies.num_workers": 0,
        "datasets.galsim_blended_galaxies.batch_size": 4,
        "datasets.galsim_blended_galaxies.n_batches": 1,
        "datasets.galsim_blended_galaxies.prior.max_n_sources": 3,
    }
    cfg = get_sdss_galaxies_config(overrides, devices)
    blend_ds: GalsimBlends = instantiate(cfg.datasets.galsim_blended_galaxies)

    for b in blend_ds.train_dataloader():
        images, _ = b.pop("images"), b.pop("background")
        tile_cat = TileCatalog(4, b)
        full_cat = tile_cat.to_full_params()
        max_n_sources = full_cat.max_sources
        n_sources = full_cat.n_sources
        plocs = full_cat.plocs
        params = full_cat["galaxy_params"]
        snr = full_cat["snr"]
        blendedness = full_cat["blendedness"]
        ellips = full_cat["ellips"]
        mags = full_cat["mags"]
        fluxes = full_cat["fluxes"]

        # check shapes
        assert images.shape == (4, 1, 88, 88)  # 40 + 24 * 2
        assert params.shape == (4, max_n_sources, 7)
        assert plocs.shape == (4, max_n_sources, 2)
        assert snr.shape == (4, max_n_sources, 1)
        assert blendedness.shape == (4, max_n_sources, 1)
        assert n_sources.shape == (4,)
        assert ellips.shape == (4, max_n_sources, 2)
        assert mags.shape == (4, max_n_sources, 1)
        assert fluxes.shape == (4, max_n_sources, 1)

        # check entries beyond the number of sources in each image are zero
        for ii, n in enumerate(n_sources):
            assert torch.all(snr[ii, n:] == 0)
            assert torch.all(blendedness[ii, n:] == 0)
            assert torch.all(plocs[ii, n:] == 0)
def compute_data(self, blend_file: Path, encoder: Encoder, decoder: ImageDecoder):
    blend_data = torch.load(blend_file)
    images = blend_data.pop("images")
    background = blend_data.pop("background")
    n_batches, _, slen, _ = images.shape
    assert background.shape == (1, slen, slen)

    # prepare background to match the image batch shape
    background = background.unsqueeze(0)
    background = background.expand(n_batches, 1, slen, slen)

    # first create FullCatalog from simulated data
    tile_cat = TileCatalog(decoder.tile_slen, blend_data).cpu()
    full_truth = tile_cat.to_full_params()

    print("INFO: BLISS posterior inference on images.")
    tile_est = encoder.variational_mode(images, background)
    tile_est.set_all_fluxes_and_mags(decoder)
    tile_est.set_galaxy_ellips(decoder, scale=0.393)
    tile_est = tile_est.cpu()
    full_est = tile_est.to_full_params()

    snr = []
    blendedness = []
    true_mags = []
    true_ellips1 = []
    true_ellips2 = []
    est_mags = []
    est_ellips1 = []
    est_ellips2 = []
    for ii in tqdm(range(n_batches), desc="Matching batches"):
        # match true and estimated sources by location within each image
        true_plocs_ii, est_plocs_ii = full_truth.plocs[ii], full_est.plocs[ii]
        tindx, eindx, dkeep, _ = match_by_locs(true_plocs_ii, est_plocs_ii)

        snr_ii = full_truth["snr"][ii][tindx][dkeep]
        blendedness_ii = full_truth["blendedness"][ii][tindx][dkeep]
        true_mag_ii = full_truth["mags"][ii][tindx][dkeep]
        est_mag_ii = full_est["mags"][ii][eindx][dkeep]
        true_ellips_ii = full_truth["ellips"][ii][tindx][dkeep]
        est_ellips_ii = full_est["ellips"][ii][eindx][dkeep]

        n_matches = len(snr_ii)
        for jj in range(n_matches):
            snr.append(snr_ii[jj].item())
            blendedness.append(blendedness_ii[jj].item())
            true_mags.append(true_mag_ii[jj].item())
            est_mags.append(est_mag_ii[jj].item())
            true_ellips1.append(true_ellips_ii[jj][0].item())
            true_ellips2.append(true_ellips_ii[jj][1].item())
            est_ellips1.append(est_ellips_ii[jj][0].item())
            est_ellips2.append(est_ellips_ii[jj][1].item())

    true_ellips = torch.vstack([torch.tensor(true_ellips1), torch.tensor(true_ellips2)])
    true_ellips = true_ellips.T.reshape(-1, 2)
    est_ellips = torch.vstack([torch.tensor(est_ellips1), torch.tensor(est_ellips2)])
    est_ellips = est_ellips.T.reshape(-1, 2)

    return {
        "snr": torch.tensor(snr),
        "blendedness": torch.tensor(blendedness),
        "true_mags": torch.tensor(true_mags),
        "est_mags": torch.tensor(est_mags),
        "true_ellips": true_ellips,
        "est_ellips": est_ellips,
    }
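# Illustrative sketch (assumed helper, not in the original code): a quick summary of the
# ellipticity residuals in the dictionary returned by `compute_data` above, e.g. as a
# sanity check before plotting. The function name is hypothetical.
def summarize_ellip_residuals(data):
    res = data["est_ellips"] - data["true_ellips"]
    return {
        "n_matches": int(res.shape[0]),
        "mean_g1_residual": res[:, 0].mean().item(),
        "mean_g2_residual": res[:, 1].mean().item(),
    }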
def __init__(
    self,
    dataset: SimulatedDataset,
    coadd: str,
    n_tiles_h,
    n_tiles_w,
    cache_dir=None,
):
    dataset.to("cpu")
    self.bp = dataset.image_decoder.border_padding
    self.tile_slen = dataset.tile_slen
    self.coadd_file = Path(coadd)
    self.sdss_dir = self.coadd_file.parent

    # SDSS frame 94/1/12 is loaded only for its WCS, which places coadd sources in pixels.
    run = 94
    camcol = 1
    field = 12
    bands = (2,)
    sdss_data = SloanDigitalSkySurvey(
        sdss_dir=self.sdss_dir,
        run=run,
        camcol=camcol,
        fields=(field,),
        bands=bands,
    )
    wcs = sdss_data[0]["wcs"][0]

    if cache_dir is not None:
        sim_frame_path = Path(cache_dir) / "simulated_frame.pt"
    else:
        sim_frame_path = None

    if sim_frame_path and sim_frame_path.exists():
        # load the cached tile catalog, image, and background
        tile_catalog_dict, image, background = torch.load(sim_frame_path)
        tile_catalog = TileCatalog(self.tile_slen, tile_catalog_dict)
    else:
        # build a tile catalog from the coadd catalog restricted to the frame
        hlim = (self.bp, self.bp + n_tiles_h * self.tile_slen)
        wlim = (self.bp, self.bp + n_tiles_w * self.tile_slen)
        full_coadd_cat = CoaddFullCatalog.from_file(coadd, wcs, hlim, wlim, band="r")
        if dataset.image_prior.galaxy_prior is not None:
            full_coadd_cat["galaxy_params"] = dataset.image_prior.galaxy_prior.sample(
                full_coadd_cat.n_sources, "cpu"
            ).unsqueeze(0)
        full_coadd_cat.plocs = full_coadd_cat.plocs + 0.5
        max_sources = dataset.image_prior.max_sources
        tile_catalog = full_coadd_cat.to_tile_params(self.tile_slen, max_sources)
        tile_catalog.set_all_fluxes_and_mags(dataset.image_decoder)

        # round-trip check: the tile catalog should reproduce the full catalog
        fc = tile_catalog.to_full_params()
        assert fc.equals(
            full_coadd_cat, exclude=("galaxy_fluxes", "star_bools", "fluxes", "mags")
        )

        print("INFO: started generating frame")
        image, background = dataset.simulate_image_from_catalog(tile_catalog)
        print("INFO: done generating frame")
        if sim_frame_path:
            torch.save((tile_catalog.to_dict(), image, background), sim_frame_path)

    self.tile_catalog = tile_catalog
    self.image = image
    self.background = background
    assert self.image.shape[0] == 1
    assert self.background.shape[0] == 1
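# Usage sketch (hypothetical; the enclosing class name and all paths/sizes below are
# assumptions for illustration). The coadd catalog supplies source locations, the SDSS
# frame supplies only its WCS, and the pixels are simulated by `dataset`.
# frame = SimulatedFrame(  # assumed class name
#     dataset,
#     coadd="data/coadd_catalog_94_1_12.fits",
#     n_tiles_h=88,
#     n_tiles_w=88,
#     cache_dir="data/cache",
# )
# full_cat = frame.get_catalog(
#     hlims=(frame.bp, frame.bp + 100), wlims=(frame.bp, frame.bp + 100)
# )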
def _make_plots(self, batch, n_samples=5):
    # validate worst reconstruction images.
    n_samples = min(len(batch["n_sources"]), n_samples)
    samples = np.random.choice(len(batch["n_sources"]), n_samples, replace=False)
    keys = [
        "images",
        "background",
        "locs",
        "galaxy_bools",
        "star_bools",
        "fluxes",
        "log_fluxes",
        "n_sources",
    ]
    for k in keys:
        batch[k] = batch[k][samples]

    # extract non-params entries so that `get_full_params` works.
    images = batch["images"]
    background = batch["background"]
    tile_locs = batch["locs"]

    # obtain map estimates of galaxy parameters per padded tile
    image_ptiles = get_images_in_tiles(
        torch.cat((images, background), dim=1),
        self.tile_slen,
        self.ptile_slen,
    )
    _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    locs = rearrange(tile_locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
    z, _ = self._encode(image_ptiles, locs)
    galaxy_params = rearrange(
        z,
        "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
        ns=1,
        nth=n_tiles_h,
        ntw=n_tiles_w,
    )

    tile_est = TileCatalog(
        self.tile_slen,
        {
            "n_sources": batch["n_sources"],
            "locs": batch["locs"],
            "galaxy_bools": batch["galaxy_bools"],
            "star_bools": batch["star_bools"],
            "fluxes": batch["fluxes"],
            "log_fluxes": batch["log_fluxes"],
            "galaxy_params": galaxy_params,
        },
    )
    est = tile_est.to_full_params()

    # draw all reconstruction images.
    # render_images automatically accounts for tiles with no galaxies.
    recon_images = self.image_decoder.render_images(tile_est)
    recon_images += background
    residuals = (images - recon_images) / torch.sqrt(recon_images)

    # draw worst `n_samples` examples as measured by absolute avg. residual error.
    worst_indices = residuals.abs().mean(dim=(1, 2, 3)).argsort(descending=True)[:n_samples]

    if self.crop_loss_at_border:
        bp = self.border_padding * 2
        residuals[:, :, :bp, :] = 0.0
        residuals[:, :, -bp:, :] = 0.0
        residuals[:, :, :, :bp] = 0.0
        residuals[:, :, :, -bp:] = 0.0

    figsize = (12, 4 * n_samples)
    fig, axes = plt.subplots(nrows=n_samples, ncols=3, figsize=figsize, squeeze=False)

    for i, idx in enumerate(worst_indices):
        true_ax = axes[i, 0]
        recon_ax = axes[i, 1]
        res_ax = axes[i, 2]

        # add titles to axes in the first row
        if i == 0:
            true_ax.set_title("Truth", size=18)
            recon_ax.set_title("Reconstruction", size=18)
            res_ax.set_title("Residual", size=18)

        # vmin, vmax should be shared between reconstruction and true images.
        if self.max_flux_valid_plots is None:
            vmax = np.ceil(max(images[idx].max().item(), recon_images[idx].max().item()))
        else:
            vmax = self.max_flux_valid_plots
        vmin = np.floor(min(images[idx].min().item(), recon_images[idx].min().item()))
        vrange = (vmin, vmax)

        # plot!
        labels = ("t. gal", None, "t. star", None)
        # Plotting only works on square images
        assert images.shape[-2] == images.shape[-1]
        slen = images.shape[-1] - 2 * self.border_padding
        bp = self.border_padding
        image = images[idx, 0].cpu().numpy()
        recon = recon_images[idx, 0].cpu().numpy()
        plocs = est.plocs[idx].cpu().numpy().reshape(-1, 2)
        probs = est["galaxy_bools"][idx].cpu().numpy().reshape(-1)
        plot_image(fig, true_ax, image, vrange=vrange, colorbar=True)
        plot_locs(true_ax, bp, slen, plocs, probs, cmap="cool")
        plot_image(fig, recon_ax, recon, vrange=vrange, colorbar=True)
        plot_locs(recon_ax, bp, slen, plocs, probs, cmap="cool")

        residuals_idx = residuals[idx, 0].cpu().numpy()
        res_vmax = np.ceil(residuals_idx.max())
        res_vmin = np.floor(residuals_idx.min())
        if self.crop_loss_at_border:
            bp = (recon_images.shape[-1] - slen) // 2
            eff_slen = slen - bp
            for b in (bp, bp * 2):
                recon_ax.axvline(b, color="w")
                recon_ax.axvline(b + eff_slen, color="w")
                recon_ax.axhline(b, color="w")
                recon_ax.axhline(b + eff_slen, color="w")
        plot_image(fig, res_ax, residuals_idx, vrange=(res_vmin, res_vmax))
        if self.crop_loss_at_border:
            for b in (bp, bp * 2):
                res_ax.axvline(b, color="w")
                res_ax.axvline(b + eff_slen, color="w")
                res_ax.axhline(b, color="w")
                res_ax.axhline(b + eff_slen, color="w")
        if i == 0:
            add_legend(true_ax, labels)

    fig.tight_layout()
    if self.logger:
        self.logger.experiment.add_figure(
            f"Epoch:{self.current_epoch}/Worst Validation Images", fig
        )
    plt.close(fig)
def forward_tile(self, tile_cat: TileCatalog):
    full_cat = tile_cat.to_full_params()
    return self(full_cat)