def validation_step(self, batch, batch_idx):
    """Pytorch lightning method.

    Logs the losses returned by `_get_loss`, then builds the true catalog
    from `batch` (keeping at most `self.max_detections` sources per tile),
    builds the variational-mode estimate from the images, and logs the
    detection metrics comparing the two.

    Args:
        batch: Mapping with at least the keys "images", "background",
            "locs", "log_fluxes", "galaxy_bools" and "n_sources".
        batch_idx: Not used by this method.

    Returns:
        The `batch` that was passed in.
    """
    n_images = len(batch["images"])
    losses = self._get_loss(batch)

    # log all losses
    self.log("val/loss", losses["loss"], batch_size=n_images)
    for loss_name in ("counter_loss", "locs_loss", "star_params_loss"):
        self.log(f"val/{loss_name}", losses[loss_name].mean(), batch_size=n_images)

    # ground truth, truncated along the per-tile source axis
    truth = {
        key: batch[key][:, :, :, 0 : self.max_detections]
        for key in ("locs", "log_fluxes", "galaxy_bools")
    }
    truth["n_sources"] = batch["n_sources"].clamp(max=self.max_detections)
    true_tile_catalog = TileCatalog(self.tile_slen, truth)
    true_full_catalog = true_tile_catalog.to_full_params()

    # estimate: encode padded tiles of the image stacked with its background
    stacked = torch.cat((batch["images"], batch["background"]), dim=1)
    ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    ptiles = rearrange(ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    mode_dict = self.variational_mode(self.encode(ptiles))
    est_tile_catalog = TileCatalog.from_flat_dict(
        true_tile_catalog.tile_slen,
        true_tile_catalog.n_tiles_h,
        true_tile_catalog.n_tiles_w,
        mode_dict,
    )
    est_full_catalog = est_tile_catalog.to_full_params()

    metrics = self.val_detection_metrics(true_full_catalog, est_full_catalog)
    for metric_name in ("precision", "recall", "f1", "avg_distance"):
        self.log(f"val/{metric_name}", metrics[metric_name], batch_size=n_images)

    return batch
def get_catalog(self, hlims, wlims) -> FullCatalog:
    """Return the full catalog of the tiles covering a pixel window.

    The pixel limits are shifted by `self.bp`, converted to tile indices
    (floor for the start, ceil for the end), and every tensor of
    `self.tile_catalog` is sliced to that tile range.

    Args:
        hlims: `(start, end)` pixel limits along the height axis.
        wlims: `(start, end)` pixel limits along the width axis.

    Returns:
        The cropped tile catalog converted with `to_full_params()`.
    """

    def _tile_range(lims):
        # shift by the border padding, then round outwards to whole tiles
        lo, hi = lims[0] - self.bp, lims[1] - self.bp
        return int(np.floor(lo / self.tile_slen)), int(np.ceil(hi / self.tile_slen))

    row_lo, row_hi = _tile_range(hlims)
    col_lo, col_hi = _tile_range(wlims)

    cropped = {
        name: tensor[:, row_lo:row_hi, col_lo:col_hi]
        for name, tensor in self.tile_catalog.to_dict().items()
    }
    return TileCatalog(self.tile_slen, cropped).to_full_params()
def make_plots(self, batch, n_samples=16):
    """Produce informative plots demonstrating encoder performance.

    Draws a square grid of `n_samples` images from `batch`, overlaying the
    true source locations and the predicted ones, and hands the figure to
    `self.logger` when a logger is configured.

    Args:
        batch: Mapping with "images", the true catalog entries ("locs",
            "n_sources", "galaxy_bools", "star_bools") and the predictions
            "pred_galaxy_bools", "pred_star_bools", "pred_galaxy_probs".
        n_samples: Number of images to plot; must be a perfect square.

    Returns:
        None. Nothing is plotted when the batch has fewer than `n_samples`
        entries.
    """
    assert n_samples ** (0.5) % 1 == 0
    if n_samples > len(batch["n_sources"]):  # do nothing if low on samples.
        return
    nrows = int(n_samples**0.5)  # for figure

    # keep only the catalog entries so that 'to_full_params' works.
    true_param_names = {"locs", "n_sources", "galaxy_bools", "star_bools"}
    true_tile_params = TileCatalog(
        self.tile_slen, {k: v for k, v in batch.items() if k in true_param_names}
    )
    true_params = true_tile_params.to_full_params()

    # prediction: same catalog with the predicted classifications swapped in
    tile_est = true_tile_params.copy()
    shape = tile_est["galaxy_bools"].shape
    tile_est["galaxy_bools"] = batch["pred_galaxy_bools"].reshape(*shape)
    tile_est["star_bools"] = batch["pred_star_bools"].reshape(*shape)
    tile_est["galaxy_probs"] = batch["pred_galaxy_probs"].reshape(*shape)
    est = tile_est.to_full_params()

    bp = self.border_padding
    labels = ("t. gal", "p. gal", "t. star", "p. star")

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so that
    # `flatten()` also works when n_samples == 1.
    fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12), squeeze=False)
    try:
        axes = axes.flatten()

        for ii in range(n_samples):
            ax = axes[ii]
            image = batch["images"][ii, 0].cpu().numpy()
            true_plocs = true_params.plocs[ii].cpu().numpy().reshape(-1, 2)
            true_gbools = true_params["galaxy_bools"][ii].cpu().numpy().reshape(-1)
            est_plocs = est.plocs[ii].cpu().numpy().reshape(-1, 2)
            est_gprobs = est["galaxy_probs"][ii].cpu().numpy().reshape(-1)
            slen, _ = image.shape
            reporting.plot_image(fig, ax, image, colorbar=False)
            reporting.plot_locs(ax, bp, slen, true_plocs, true_gbools, m="+", s=30, cmap="cool")
            reporting.plot_locs(ax, bp, slen, est_plocs, est_gprobs, m="x", s=20, cmap="bwr")
            if ii == 0:
                # the legend is only attached to the first panel
                reporting.add_legend(ax, labels)

        fig.tight_layout()

        title = f"Epoch:{self.current_epoch}/Validation Images"
        if self.logger is not None:
            self.logger.experiment.add_figure(title, fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # the figure leaks whenever no logger is set (or plotting fails).
        plt.close(fig)
def stats_for_threshold(
    true_plocs: Tensor,
    est_tile_cat: TileCatalog,
    threshold: Optional[float] = None,
):
    """Count matched and unmatched detections, optionally at a threshold.

    When `threshold` is given, `est_tile_cat.n_sources` is overwritten (in
    place) with the tiles whose "n_source_log_probs" is at least
    `log(threshold)`. The estimated catalog is then matched to `true_plocs`
    with `reporting.match_by_locs`.

    Args:
        true_plocs: True locations; only batch element 0 is matched and
            `shape[1]` is taken as the number of true sources.
        est_tile_cat: Estimated tile catalog (modified when `threshold` is
            not None).
        threshold: Probability cutoff for turning a tile "on", or None to
            keep the catalog's current `n_sources`.

    Returns:
        A dict with "tp", "fp", "true_matches" (bool per true source) and
        "est_tile_matches" (bool tensor in tile layout).
    """
    tile_slen, max_sources = est_tile_cat.tile_slen, est_tile_cat.max_sources

    if threshold is not None:
        log_probs = rearrange(
            est_tile_cat["n_source_log_probs"], "n nth ntw 1 1 -> n nth ntw"
        )
        est_tile_cat.n_sources = log_probs >= math.log(threshold)

    full_est = est_tile_cat.to_full_params()
    n_true = true_plocs.shape[1]
    n_est = int(full_est.plocs.shape[1])
    true_hits = torch.zeros(n_true, dtype=torch.bool)

    # nothing to match: every estimate (if any) is a false positive
    if n_true == 0 or n_est == 0:
        no_tile_hits = torch.zeros(*est_tile_cat.n_sources.shape, 1, 1, dtype=torch.bool)
        return {
            "tp": torch.tensor(0.0),
            "fp": torch.tensor(float(n_est)),
            "true_matches": true_hits,
            "est_tile_matches": no_tile_hits,
        }

    est_hits = torch.zeros(n_est, dtype=torch.bool)
    true_idx, est_idx, is_match, _ = reporting.match_by_locs(
        true_plocs[0], full_est.plocs[0], 1.0
    )
    true_hits[true_idx] = is_match
    est_hits[est_idx] = is_match

    # push the per-source match flags back into tile layout
    full_est["matched"] = est_hits.reshape(1, -1, 1)
    tile_hits = full_est.to_tile_params(tile_slen, max_sources)["matched"]

    n_tp = is_match.sum()
    n_fp = torch.tensor(n_est) - n_tp
    return {
        "tp": n_tp,
        "fp": n_fp,
        "true_matches": true_hits,
        "est_tile_matches": tile_hits,
    }
def test_galaxy_blend(get_sdss_galaxies_config, devices):
    """Check shapes and zero-padding of catalogs from the blended dataset.

    Instantiates the `galsim_blended_galaxies` dataset with one batch of
    four images and at most three sources, converts each batch to a full
    catalog, and asserts the expected tensor shapes. Entries past each
    image's `n_sources` must be zero for "snr", "blendedness" and `plocs`.
    """
    overrides = {
        "datasets.galsim_blended_galaxies.num_workers": 0,
        "datasets.galsim_blended_galaxies.batch_size": 4,
        "datasets.galsim_blended_galaxies.n_batches": 1,
        "datasets.galsim_blended_galaxies.prior.max_n_sources": 3,
    }
    cfg = get_sdss_galaxies_config(overrides, devices)
    blend_ds: GalsimBlends = instantiate(cfg.datasets.galsim_blended_galaxies)

    for batch in blend_ds.train_dataloader():
        # everything left after removing the images is catalog data
        images = batch.pop("images")
        batch.pop("background")
        full_cat = TileCatalog(4, batch).to_full_params()

        max_n = full_cat.max_sources
        n_sources = full_cat.n_sources
        plocs = full_cat.plocs
        field_names = ("galaxy_params", "snr", "blendedness", "ellips", "mags", "fluxes")
        fields = {name: full_cat[name] for name in field_names}
        snr, blendedness = fields["snr"], fields["blendedness"]

        assert images.shape == (4, 1, 88, 88)  # 40 + 24 * 2
        assert fields["galaxy_params"].shape == (4, max_n, 7)
        assert plocs.shape == (4, max_n, 2)
        assert snr.shape == (4, max_n, 1)
        assert blendedness.shape == (4, max_n, 1)
        assert n_sources.shape == (4,)
        assert fields["ellips"].shape == (4, max_n, 2)
        assert fields["mags"].shape == (4, max_n, 1)
        assert fields["fluxes"].shape == (4, max_n, 1)

        # check empty if no sources
        for n, snr_i, blend_i, plocs_i in zip(n_sources, snr, blendedness, plocs):
            assert torch.all(snr_i[n:] == 0)
            assert torch.all(blend_i[n:] == 0)
            assert torch.all(plocs_i[n:] == 0)
def get_map_estimate(
    image_encoder: DetectionEncoder, images, background, slen: int, wlen: int = None
):
    """Return the variational-mode estimate of the catalog on a full image.

    NOTE: `slen * wlen` is the size of the image without border padding.

    Besides computing the estimate, this function runs two sanity checks:
    encoding only the first `25**2` padded tiles must reproduce the
    corresponding rows of the full encoding, and the tile catalog must
    survive a round trip through `to_full_params` / `to_tile_params`.

    Args:
        image_encoder: Encoder providing `tile_slen`, `ptile_slen`,
            `border_padding`, `encode` and `variational_mode`.
        images: Image tensor; the last two axes are height and width.
        background: Background tensor, concatenated to `images` on dim 1.
        slen: Height of the image without border padding.
        wlen: Width of the image without border padding; defaults to `slen`.

    Returns:
        The estimated catalog in full-parameter form.
    """
    wlen = slen if wlen is None else wlen
    assert isinstance(slen, int) and isinstance(wlen, int)

    # check image compatibility
    border_h = (images.shape[-2] - slen) / 2
    border_w = (images.shape[-1] - wlen) / 2
    assert border_h == border_w, "border paddings on each dimension differ."
    assert slen % image_encoder.tile_slen == 0, "incompatible slen"
    assert wlen % image_encoder.tile_slen == 0, "incompatible wlen"
    assert border_h == image_encoder.border_padding, "incompatible border"

    # obtain estimates per tile, then on the full image.
    stacked = torch.cat((images, background), dim=1)
    image_ptiles = get_images_in_tiles(
        stacked, image_encoder.tile_slen, image_encoder.ptile_slen
    )
    _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    var_params = image_encoder.encode(image_ptiles)

    # encoding a prefix of the tiles must agree with the full encoding
    tile_cutoff = 25**2
    var_params2 = image_encoder.encode(image_ptiles[:tile_cutoff])
    for key in ("n_source_log_probs", "per_source_params"):
        assert torch.allclose(
            var_params[key][:tile_cutoff],
            var_params2[key],
            atol=1e-5,
        )

    tile_map = TileCatalog.from_flat_dict(
        image_encoder.tile_slen,
        n_tiles_h,
        n_tiles_w,
        image_encoder.variational_mode(var_params),
    )
    full_map = tile_map.to_full_params()

    # round-trip check: full -> tile must give back the same tile catalog
    tile_map_tilde = full_map.to_tile_params(image_encoder.tile_slen, tile_map.max_sources)
    assert tile_map.equals(tile_map_tilde, exclude=("n_source_log_probs",), atol=1e-5)

    return full_map
def get_positive_negative_stats(
    true_cat: FullCatalog,
    est_tile_cat: TileCatalog,
    mag_max: float = np.inf,
    n_jobs: int = 10,
):
    """Compute detection statistics over a sweep of probability thresholds.

    The true catalog is restricted to magnitudes up to `mag_max`, then
    `stats_for_threshold` is evaluated in parallel for 99 thresholds evenly
    spaced from 0.01 to 0.99.

    Args:
        true_cat: True catalog in full-parameter form.
        est_tile_cat: Estimated tile catalog; a copy is made before use.
        mag_max: Upper magnitude limit applied to `true_cat`.
        n_jobs: Number of joblib workers. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        A dict with one stacked tensor (indexed by threshold) per key
        returned by `stats_for_threshold`, plus "n_obj", the number of true
        sources after the magnitude cut.
    """
    true_cat = true_cat.apply_mag_bin(-np.inf, mag_max)
    thresholds = np.linspace(0.01, 0.99, 99)
    est_tile_cat = est_tile_cat.copy()

    res = Parallel(n_jobs=n_jobs)(
        delayed(stats_for_threshold)(true_cat.plocs, est_tile_cat, t)
        for t in tqdm(thresholds)
    )

    # stack the per-threshold results key by key
    out: Dict[str, Union[int, Tensor]] = {
        k: torch.stack([r[k] for r in res]) for k in res[0]
    }
    out["n_obj"] = true_cat.plocs.shape[1]
    return out
def _get_loss(self, batch):
    """Compute the loss for one batch.

    The galaxy parameters are encoded from padded image tiles at the
    catalog locations, the scene is re-rendered with them, and the negative
    Normal log-likelihood of the images is summed (optionally cropped at
    the border). The divergence term, restricted to sources flagged in
    "galaxy_bools", is subtracted from that sum.

    Args:
        batch: Mapping with "images", "background" and the tile-catalog
            entries (every other key is passed to `TileCatalog`).

    Returns:
        `recon_losses.sum() - divergence_loss`.
    """
    images: Tensor = batch["images"]
    background: Tensor = batch["background"]
    catalog_entries = {
        name: value
        for name, value in batch.items()
        if name not in {"images", "background"}
    }
    tile_catalog = TileCatalog(self.tile_slen, catalog_entries)

    stacked = torch.cat((images, background), dim=1)
    image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    locs = rearrange(tile_catalog.locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
    galaxy_params, pq_divergence = self._encode(image_ptiles, locs)

    # draw fully reconstructed image.
    # NOTE: Assume recon_mean = recon_var per poisson approximation.
    tile_catalog["galaxy_params"] = rearrange(
        galaxy_params,
        "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
        nth=tile_catalog.n_tiles_h,
        ntw=tile_catalog.n_tiles_w,
    )
    recon_mean = self.image_decoder.render_images(tile_catalog)
    recon_mean = rearrange(recon_mean, "(ns n) c h w -> ns n c h w", ns=1)
    recon_mean += background.unsqueeze(0)
    assert not torch.isnan(recon_mean).any()
    assert not torch.isinf(recon_mean).any()

    likelihood = Normal(recon_mean, recon_mean.sqrt())
    recon_losses = -likelihood.log_prob(images.unsqueeze(0))
    if self.crop_loss_at_border:
        crop = self.border_padding * 2
        recon_losses = recon_losses[:, :, :, crop:(-crop), crop:(-crop)]
    assert not torch.isnan(recon_losses).any()
    assert not torch.isinf(recon_losses).any()

    # For divergence loss, we only evaluate tiles with a galaxy in them
    galaxy_bools = rearrange(
        tile_catalog["galaxy_bools"], "n nth ntw ms 1 -> 1 (n nth ntw) ms"
    )
    divergence_loss = (pq_divergence * galaxy_bools).sum()
    return recon_losses.sum() - divergence_loss
def make_plots_of_marginal_class(
    outdir: Path,
    encoder: Encoder,
    decoder: ImageDecoder,
    frame: Frame,
    tile_map_recon: TileCatalog,
    detections_at_mode,
):
    """Save a figure and a CSV row for every matched, marginal detection.

    A detection is "marginal" when `exp(galaxy_probs)` lies in [0.4, 0.6].
    For each marginal detection that is also flagged as matched, a figure
    centred near the detection is written to `outdir` as
    `h{h}_w{w}.png`, and a line is added to
    `outdir / "marginal_detections.csv"`.

    Args:
        outdir: Directory receiving the PNG files and the CSV.
        encoder: Only `encoder.border_padding` is read.
        decoder: Passed through to `create_figure_at_point`.
        frame: Image frame; sky coordinates are only computed when it is
            an `SDSSFrame`.
        tile_map_recon: Tile catalog; its "matched" entry is overwritten
            with `detections_at_mode["est_tile_matches"]`.
        detections_at_mode: Mapping providing "est_tile_matches".
    """
    tile_map_recon["matched"] = detections_at_mode["est_tile_matches"]
    bp = encoder.border_padding
    full_map = tile_map_recon.to_full_params()

    galaxy_probs = full_map["galaxy_probs"][0, :, 0].exp()
    marginal = (galaxy_probs <= 0.6) & (galaxy_probs >= 0.4)

    size = 40
    csv_lines = ["fname,in_coadd,ra,dec,prob"]
    # wrap the tensor (not the enumerate object) so tqdm knows the total
    plocs = tqdm(full_map.plocs[0], desc="marginal classifications")
    for i, ploc in enumerate(plocs):
        # plain Python bool, so the CSV gets "True" rather than a tensor repr
        matched = full_map["matched"][0, i, 0].item()
        if not (marginal[i] and matched):
            continue

        h = ploc[0].item()
        w = ploc[1].item()
        # NOTE(review): the offset 24 is hard-coded; presumably it equals the
        # border padding `bp` -- confirm before replacing it.
        h_topleft = max(int(h - (size / 2.0)), 0) + 24
        w_topleft = max(int(w - (size / 2.0)), 0) + 24
        fig = create_figure_at_point(
            h_topleft,
            w_topleft,
            size,
            bp,
            tile_map_recon,
            frame,
            decoder,
        )
        fname = f"h{h_topleft}_w{w_topleft}.png"
        fig.savefig(outdir / fname)
        # release the figure; otherwise one figure per detection stays alive
        plt.close(fig)

        if isinstance(frame, SDSSFrame):
            ra, dec = frame.wcs.wcs_pix2world(w + 24, h + 24, 0)
        else:
            ra, dec = None, None
        prob = full_map["galaxy_probs"][0, i, 0].exp().item()
        csv_lines.append(f"{fname},{matched},{ra},{dec},{prob}")

    out_csv = outdir / "marginal_detections.csv"
    out_csv.write_text("\n".join(csv_lines))
def render_large_scene(
    self, tile_catalog: TileCatalog, batch_size: Optional[int] = None
) -> Tensor:
    """Render a scene in horizontal strips of tile-rows.

    The catalog is cropped to consecutive groups of tile-rows, each group
    is rendered with `self.render_images`, and the result is accumulated
    into a CPU tensor. Neighbouring strips overlap by the border padding,
    which is why strips are added rather than assigned.

    Args:
        tile_catalog: Tile catalog; `locs` must have five axes whose second
            and third are the tile-row and tile-column counts.
        batch_size: Approximate number of tiles rendered per strip.
            Defaults to `75**2 + 500 * 2`.

    Returns:
        A tensor of shape `(1, 1, h, w)` with
        `h = n_tiles_h * tile_slen + 2 * border_padding` (likewise `w`).
    """
    if batch_size is None:
        batch_size = 75**2 + 500 * 2

    _, n_tiles_h, n_tiles_w, _, _ = tile_catalog.locs.shape
    tile_slen = tile_catalog.tile_slen
    pad = 2 * self.border_padding

    # Always render at least one tile-row per strip: a batch_size smaller
    # than one row would otherwise give a zero step and make range() raise.
    n_rows_per_batch = max(1, batch_size // n_tiles_w)

    h = n_tiles_h * tile_slen + pad
    w = n_tiles_w * tile_slen + pad
    scene = torch.zeros(1, 1, h, w)

    for row in range(0, n_tiles_h, n_rows_per_batch):
        end_row = row + n_rows_per_batch
        start_h = row * tile_slen
        end_h = end_row * tile_slen + pad
        tile_cat_row = tile_catalog.crop((row, end_row), (0, None))
        img_row = self.render_images(tile_cat_row)
        # slicing clips end_h to the scene height on the last strip
        scene[:, :, start_h:end_h] += img_row.cpu()
    return scene
def make_images_of_example_blend(
    blend_dir: Path, encoder: Encoder, decoder: ImageDecoder, frame: Frame
):
    """Save images of one hard-coded 40x40 window and two edited renders.

    The window starting at (h, w) = (1400, 1710) is reconstructed and the
    image, reconstruction and normalized residual are written to
    `blend_dir`. Then, for each of two hard-coded tile positions, that tile
    is switched off (no sources, no galaxy) in a copy of the catalog and
    the re-rendered scene is saved as `galaxy_{i}.png`.

    Args:
        blend_dir: Output directory for the PNG files.
        encoder: Used for reconstruction and for its `border_padding`.
        decoder: Used for reconstruction and for rendering edited catalogs.
        frame: Provides `image` and `background`.
    """
    slen = 40
    h, w = 1400, 1710
    h_end, w_end = h + slen, w + slen

    recon, tile_map_recon = reconstruct_scene_at_coordinates(
        encoder,
        decoder,
        frame.image,
        frame.background,
        (h, h_end),
        (w, w_end),
    )
    bp = encoder.border_padding
    img = frame.image[:, :, h:h_end, w:w_end]
    bg = frame.background[:, :, h:h_end, w:w_end]
    resid = (img - recon) / recon.sqrt()

    for fname, panel in (("img.png", img), ("recon.png", recon), ("resid.png", resid)):
        plt.imsave(blend_dir / fname, panel[0, 0])

    # tile (row, column) positions to switch off, one per output image
    masks = ((4, 5), (3, 7))
    for i, (h_mask, w_mask) in enumerate(masks):
        edited = tile_map_recon.to_dict()
        # clone before editing so the reconstruction's tensors stay intact
        n_sources = edited["n_sources"].clone()
        galaxy_bools = edited["galaxy_bools"].clone()
        galaxy_bools[:, h_mask, w_mask] = 0.0
        n_sources[0, h_mask, w_mask] = 0
        edited["n_sources"] = n_sources
        edited["galaxy_bools"] = galaxy_bools

        tile_map_one_galaxy = TileCatalog(tile_map_recon.tile_slen, edited)
        rendered = decoder.render_images(tile_map_one_galaxy).detach().cpu()
        recon_one_galaxy = rendered[:, :, bp:-bp, bp:-bp] + bg
        plt.imsave(
            blend_dir / f"galaxy_{i}.png",
            recon_one_galaxy[0, 0, :, :],
            vmax=recon.max(),
        )
def sample_prior(
    self, tile_slen: int, batch_size: int, n_tiles_h: int, n_tiles_w: int
) -> TileCatalog:
    """Samples latent variables from the prior of an astronomical image.

    Args:
        tile_slen: Side length of catalog tiles.
        batch_size: The number of samples to draw.
        n_tiles_h: Number of tiles height-wise; must be positive.
        n_tiles_w: Number of tiles width-wise; must be positive.

    Returns:
        A `TileCatalog` holding the sampled per-tile quantities
        "n_sources", "locs", "galaxy_bools", "star_bools", "galaxy_params",
        "fluxes" and "log_fluxes".
    """
    assert n_tiles_h > 0
    assert n_tiles_w > 0

    # the number of sources per tile determines which source slots are "on"
    n_sources = self._sample_n_sources(batch_size, n_tiles_h, n_tiles_w)
    is_on = get_is_on_from_n_sources(n_sources, self.max_sources)
    locs = self._sample_locs(is_on)

    galaxy_bools, star_bools = self._sample_n_galaxies_and_stars(is_on)
    galaxy_params = self._sample_galaxy_params(galaxy_bools)
    fluxes = self._sample_fluxes(star_bools)
    log_fluxes = self._get_log_fluxes(fluxes)

    # per tile quantities.
    sampled = {
        "n_sources": n_sources,
        "locs": locs,
        "galaxy_bools": galaxy_bools,
        "star_bools": star_bools,
        "galaxy_params": galaxy_params,
        "fluxes": fluxes,
        "log_fluxes": log_fluxes,
    }
    return TileCatalog(tile_slen, sampled)
def variational_mode(self, image: Tensor, background: Tensor) -> TileCatalog:
    """Get maximum a posteriori of catalog from image padded tiles.

    Note that, strictly speaking, this is not the true MAP of the
    variational distribution of the catalog. Sequential estimation is used
    instead: the MAP of the locations is estimated first and then plugged
    into the binary and galaxy encoders, which are therefore conditioned on
    the location MAP. The true MAP would require optimizing over the entire
    catalog jointly, which is not tractable.

    Args:
        image: An astronomical image, with shape `n * n_bands * h * w`.
        background: Background associated with image, with shape
            `n * n_bands * h * w`.

    Returns:
        A `TileCatalog` built from the output of
        `self.sample(image, background, None)` with the leading axis of
        every entry squeezed away. Per the original documentation it holds
        the output of DetectionEncoder.variational_mode(), 'galaxy_bools',
        'star_bools' and 'galaxy_probs' from BinaryEncoder, and
        'galaxy_params' from GalaxyEncoder.
    """
    tile_map_dict = self.sample(image, background, None)

    # tiles cover the image minus the border padding on both sides
    tile_slen = self.detection_encoder.tile_slen
    margin = 2 * self.border_padding
    n_tiles_h = (image.shape[2] - margin) // tile_slen
    n_tiles_w = (image.shape[3] - margin) // tile_slen

    squeezed = {name: value.squeeze(0) for name, value in tile_map_dict.items()}
    return TileCatalog.from_flat_dict(tile_slen, n_tiles_h, n_tiles_w, squeezed)
def forward_tile(self, tile_cat: TileCatalog):
    """Call this object on the full-parameter form of a tile catalog.

    Args:
        tile_cat: Catalog in tile layout; converted with `to_full_params()`.

    Returns:
        Whatever calling `self` on the converted catalog returns.
    """
    return self(tile_cat.to_full_params())
def validation_epoch_end(self, outputs, kind="validation", max_n_samples=16):  # pylint: disable=too-many-statements
    """Pytorch lightning method.

    Plots a square grid of images from the last output batch with the true
    galaxies, true stars and estimated sources overlaid, and sends the
    figure to `self.logger` when one is set. Nothing is plotted when
    `self.n_bands > 1`.

    Args:
        outputs: Sequence of batches; only the last one is used.
        kind: Either "validation" or "testing"; selects the figure title.
        max_n_samples: Upper bound on the number of images plotted.

    Raises:
        NotImplementedError: If a logger is set and `kind` is neither
            "validation" nor "testing".
    """
    batch: Dict[str, Tensor] = outputs[-1]
    if self.n_bands > 1:
        return

    # largest perfect square that fits in the batch, capped at max_n_samples
    n_samples = min(int(math.sqrt(len(batch["n_sources"]))) ** 2, max_n_samples)
    nrows = int(n_samples**0.5)  # for figure

    # ground truth, truncated along the per-tile source axis
    truth = {
        key: batch[key][:, :, :, 0 : self.max_detections]
        for key in ("locs", "log_fluxes", "galaxy_bools", "star_bools")
    }
    truth["n_sources"] = batch["n_sources"].clamp(max=self.max_detections)
    true_tile_catalog = TileCatalog(self.tile_slen, truth)
    true_cat = true_tile_catalog.to_full_params()

    # estimate from padded tiles of the image stacked with its background
    stacked = torch.cat((batch["images"], batch["background"]), dim=1)
    ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    ptiles = rearrange(ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    mode_dict = self.variational_mode(self.encode(ptiles))
    est_tile_catalog = TileCatalog.from_flat_dict(
        true_tile_catalog.tile_slen,
        true_tile_catalog.n_tiles_h,
        true_tile_catalog.n_tiles_w,
        mode_dict,
    )
    est_cat = est_tile_catalog.to_full_params()

    # setup figure and axes.
    fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12))
    axes = axes.flatten() if nrows > 1 else [axes]

    images = batch["images"]
    assert images.shape[-2] == images.shape[-1]
    bp = self.border_padding

    # (color, marker, size, label) for the legend proxies on the first panel
    legend_entries = (
        ("r", "x", 20, "t.gal"),
        ("c", "x", 20, "t.star"),
        ("b", "+", 30, "p.source"),
    )

    for idx, ax in enumerate(axes):
        n_true = true_cat.n_sources[idx].item()
        n_est = est_cat.n_sources[idx].item()
        ax.set_xlabel(f"True num: {n_true}; Est num: {n_est}")

        # add white border showing where centers of stars and galaxies can be
        ax.axvline(bp, color="w")
        ax.axvline(images.shape[-1] - bp, color="w")
        ax.axhline(bp, color="w")
        ax.axhline(images.shape[-2] - bp, color="w")

        # plot image first
        image = images[idx, 0].cpu().numpy()
        lo = image.min().item()
        hi = image.max().item()
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        im = ax.matshow(image, vmin=lo, vmax=hi, cmap="viridis")
        fig.colorbar(im, cax=cax, orientation="vertical")

        true_cat.plot_plocs(ax, idx, "galaxy", bp=bp, color="r", marker="x", s=20)
        true_cat.plot_plocs(ax, idx, "star", bp=bp, color="c", marker="x", s=20)
        est_cat.plot_plocs(ax, idx, "all", bp=bp, color="b", marker="+", s=30)

        if idx == 0:
            for color, marker, size, label in legend_entries:
                ax.scatter(None, None, color=color, marker=marker, s=size, label=label)
            ax.legend(
                bbox_to_anchor=(0.0, 1.2, 1.0, 0.102),
                loc="lower left",
                ncol=2,
                mode="expand",
                borderaxespad=0.0,
            )

    fig.tight_layout()
    if self.logger:
        if kind == "validation":
            title = f"Epoch:{self.current_epoch}/Validation Images"
            self.logger.experiment.add_figure(title, fig)
        elif kind == "testing":
            self.logger.experiment.add_figure("Test Images", fig)
        else:
            raise NotImplementedError()
    plt.close(fig)
def compute_data(self, blend_file: Path, encoder: Encoder, decoder: ImageDecoder):
    """Match true and estimated sources and collect their measurements.

    Loads simulated blends from `blend_file`, runs `encoder.variational_mode`
    on the images, matches true and estimated locations image by image with
    `match_by_locs`, and gathers the matched sources' values.

    Args:
        blend_file: File readable by `torch.load` holding "images",
            "background" (shape `(1, slen, slen)`) and the tile-catalog
            entries.
        encoder: Provides `variational_mode(images, background)`.
        decoder: Provides `tile_slen` and is passed to the flux, magnitude
            and ellipticity setters of the estimated catalog.

    Returns:
        A dict of tensors over all matched sources: "snr", "blendedness",
        "true_mags", "est_mags" (1-D) and "true_ellips", "est_ellips"
        (two columns).
    """
    blend_data = torch.load(blend_file)
    images = blend_data.pop("images")
    background = blend_data.pop("background")
    n_batches, _, slen, _ = images.shape
    assert background.shape == (1, slen, slen)

    # prepare background: one (shared) background per image
    background = background.unsqueeze(0)
    background = background.expand(n_batches, 1, slen, slen)

    # first create FullCatalog from simulated data
    tile_cat = TileCatalog(decoder.tile_slen, blend_data).cpu()
    full_truth = tile_cat.to_full_params()

    print("INFO: BLISS posterior inference on images.")
    tile_est = encoder.variational_mode(images, background)
    tile_est.set_all_fluxes_and_mags(decoder)
    tile_est.set_galaxy_ellips(decoder, scale=0.393)
    tile_est = tile_est.cpu()
    full_est = tile_est.to_full_params()

    snr = []
    blendedness = []
    true_mags = []
    true_ellips1 = []
    true_ellips2 = []
    est_mags = []
    est_ellips1 = []
    est_ellips2 = []
    for ii in tqdm(range(n_batches), desc="Matching batches"):
        true_plocs_ii, est_plocs_ii = full_truth.plocs[ii], full_est.plocs[ii]

        tindx, eindx, dkeep, _ = match_by_locs(true_plocs_ii, est_plocs_ii)

        snr_ii = full_truth["snr"][ii][tindx][dkeep]
        blendedness_ii = full_truth["blendedness"][ii][tindx][dkeep]
        true_mag_ii = full_truth["mags"][ii][tindx][dkeep]
        est_mag_ii = full_est["mags"][ii][eindx][dkeep]
        true_ellips_ii = full_truth["ellips"][ii][tindx][dkeep]
        est_ellips_ii = full_est["ellips"][ii][eindx][dkeep]

        # One bulk conversion per quantity instead of an `.item()` call per
        # match; assumes one scalar per matched source for snr/blendedness/
        # mags, as the per-element `.item()` calls did before.
        snr.extend(snr_ii.reshape(-1).tolist())
        blendedness.extend(blendedness_ii.reshape(-1).tolist())
        true_mags.extend(true_mag_ii.reshape(-1).tolist())
        est_mags.extend(est_mag_ii.reshape(-1).tolist())
        true_ellips1.extend(true_ellips_ii[:, 0].tolist())
        true_ellips2.extend(true_ellips_ii[:, 1].tolist())
        est_ellips1.extend(est_ellips_ii[:, 0].tolist())
        est_ellips2.extend(est_ellips_ii[:, 1].tolist())

    # stack the two ellipticity components as columns
    true_ellips = torch.vstack([torch.tensor(true_ellips1), torch.tensor(true_ellips2)])
    true_ellips = true_ellips.T.reshape(-1, 2)

    est_ellips = torch.vstack([torch.tensor(est_ellips1), torch.tensor(est_ellips2)])
    est_ellips = est_ellips.T.reshape(-1, 2)

    return {
        "snr": torch.tensor(snr),
        "blendedness": torch.tensor(blendedness),
        "true_mags": torch.tensor(true_mags),
        "est_mags": torch.tensor(est_mags),
        "true_ellips": true_ellips,
        "est_ellips": est_ellips,
    }
def __init__(
    self,
    dataset: SimulatedDataset,
    coadd: str,
    n_tiles_h,
    n_tiles_w,
    cache_dir=None,
):
    """Build (or load from cache) a frame simulated from a coadd catalog.

    The coadd catalog is read inside a window of `n_tiles_h` by `n_tiles_w`
    tiles, converted to a tile catalog, and rendered with
    `dataset.simulate_image_from_catalog`. When `cache_dir` is given, the
    result is stored in / loaded from `cache_dir / "simulated_frame.pt"`.

    Args:
        dataset: Simulated dataset; moved to the CPU. Supplies the decoder,
            the prior and the tile side length.
        coadd: Path of the coadd catalog file; its parent directory is used
            as the SDSS data directory.
        n_tiles_h: Number of tiles height-wise.
        n_tiles_w: Number of tiles width-wise.
        cache_dir: Optional directory for the cached simulated frame.
    """
    dataset.to("cpu")
    self.bp = dataset.image_decoder.border_padding
    self.tile_slen = dataset.tile_slen
    self.coadd_file = Path(coadd)
    self.sdss_dir = self.coadd_file.parent

    # the WCS is taken from one hard-coded SDSS field
    sdss_data = SloanDigitalSkySurvey(
        sdss_dir=self.sdss_dir,
        run=94,
        camcol=1,
        fields=(12,),
        bands=(2,),
    )
    wcs = sdss_data[0]["wcs"][0]

    sim_frame_path = None if cache_dir is None else Path(cache_dir) / "simulated_frame.pt"

    if sim_frame_path and sim_frame_path.exists():
        tile_catalog_dict, image, background = torch.load(sim_frame_path)
        tile_catalog = TileCatalog(self.tile_slen, tile_catalog_dict)
    else:
        hlim = (self.bp, self.bp + n_tiles_h * self.tile_slen)
        wlim = (self.bp, self.bp + n_tiles_w * self.tile_slen)
        full_coadd_cat = CoaddFullCatalog.from_file(coadd, wcs, hlim, wlim, band="r")

        galaxy_prior = dataset.image_prior.galaxy_prior
        if galaxy_prior is not None:
            sampled_params = galaxy_prior.sample(full_coadd_cat.n_sources, "cpu")
            full_coadd_cat["galaxy_params"] = sampled_params.unsqueeze(0)
        full_coadd_cat.plocs = full_coadd_cat.plocs + 0.5

        tile_catalog = full_coadd_cat.to_tile_params(
            self.tile_slen, dataset.image_prior.max_sources
        )
        tile_catalog.set_all_fluxes_and_mags(dataset.image_decoder)

        # the tile catalog must round-trip to the coadd catalog
        round_trip = tile_catalog.to_full_params()
        assert round_trip.equals(
            full_coadd_cat, exclude=("galaxy_fluxes", "star_bools", "fluxes", "mags")
        )

        print("INFO: started generating frame")
        image, background = dataset.simulate_image_from_catalog(tile_catalog)
        print("INFO: done generating frame")
        if sim_frame_path:
            torch.save((tile_catalog.to_dict(), image, background), sim_frame_path)

    self.tile_catalog = tile_catalog
    self.image = image
    self.background = background
    assert self.image.shape[0] == 1
    assert self.background.shape[0] == 1
def _make_plots(self, batch, n_samples=5):
    """Plot truth, reconstruction and residual for the worst samples.

    A random subset of `n_samples` images is drawn from `batch`, galaxy
    parameters are encoded at the true locations, the images are
    re-rendered, and the samples are shown in order of decreasing mean
    absolute normalized residual (one row per sample: truth,
    reconstruction, residual). The figure is sent to `self.logger` when
    one is set, and then closed.

    Args:
        batch: Mapping with "images", "background", "locs", "galaxy_bools",
            "star_bools", "fluxes", "log_fluxes" and "n_sources". It is not
            modified.
        n_samples: Maximum number of samples to plot.
    """
    # validate worst reconstruction images.
    n_samples = min(len(batch["n_sources"]), n_samples)
    samples = np.random.choice(len(batch["n_sources"]), n_samples, replace=False)
    keys = [
        "images",
        "background",
        "locs",
        "galaxy_bools",
        "star_bools",
        "fluxes",
        "log_fluxes",
        "n_sources",
    ]
    # subsample into a new dict so the caller's batch is left untouched
    batch = {k: batch[k][samples] for k in keys}

    images = batch["images"]
    background = batch["background"]
    tile_locs = batch["locs"]

    # obtain map estimates
    image_ptiles = get_images_in_tiles(
        torch.cat((images, background), dim=1),
        self.tile_slen,
        self.ptile_slen,
    )
    _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    locs = rearrange(tile_locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
    z, _ = self._encode(image_ptiles, locs)
    galaxy_params = rearrange(
        z,
        "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
        ns=1,
        nth=n_tiles_h,
        ntw=n_tiles_w,
    )

    tile_est = TileCatalog(
        self.tile_slen,
        {
            "n_sources": batch["n_sources"],
            "locs": batch["locs"],
            "galaxy_bools": batch["galaxy_bools"],
            "star_bools": batch["star_bools"],
            "fluxes": batch["fluxes"],
            "log_fluxes": batch["log_fluxes"],
            "galaxy_params": galaxy_params,
        },
    )
    est = tile_est.to_full_params()

    # draw all reconstruction images.
    # render_images automatically accounts for tiles with no galaxies.
    recon_images = self.image_decoder.render_images(tile_est)
    recon_images += background
    residuals = (images - recon_images) / torch.sqrt(recon_images)

    # draw worst `n_samples` examples as measured by absolute avg. residual error.
    # NOTE(review): the ranking uses the residuals before the border is
    # zeroed below -- confirm whether the cropped residuals should be ranked.
    worst_indices = residuals.abs().mean(dim=(1, 2, 3)).argsort(descending=True)[:n_samples]

    if self.crop_loss_at_border:
        crop = self.border_padding * 2
        residuals[:, :, :crop, :] = 0.0
        residuals[:, :, -crop:, :] = 0.0
        residuals[:, :, :, :crop] = 0.0
        residuals[:, :, :, -crop:] = 0.0

    # Plotting only works on square images
    assert images.shape[-2] == images.shape[-1]
    bp = self.border_padding
    slen = images.shape[-1] - 2 * bp
    labels = ("t. gal", None, "t. star", None)

    figsize = (12, 4 * n_samples)
    fig, axes = plt.subplots(nrows=n_samples, ncols=3, figsize=figsize, squeeze=False)

    for i, idx in enumerate(worst_indices):
        true_ax = axes[i, 0]
        recon_ax = axes[i, 1]
        res_ax = axes[i, 2]

        # add titles to axes in the first row
        if i == 0:
            true_ax.set_title("Truth", size=18)
            recon_ax.set_title("Reconstruction", size=18)
            res_ax.set_title("Residual", size=18)

        # vmin, vmax should be shared between reconstruction and true images.
        if self.max_flux_valid_plots is None:
            vmax = np.ceil(max(images[idx].max().item(), recon_images[idx].max().item()))
        else:
            vmax = self.max_flux_valid_plots
        vmin = np.floor(min(images[idx].min().item(), recon_images[idx].min().item()))
        vrange = (vmin, vmax)

        # Everything in this row refers to sample `idx` (the i-th worst),
        # not to the row number `i`.
        image = images[idx, 0].cpu().numpy()
        recon_image = recon_images[idx, 0].detach().cpu().numpy()
        plocs = est.plocs[idx].cpu().numpy().reshape(-1, 2)
        probs = est["galaxy_bools"][idx].cpu().numpy().reshape(-1)

        # plot!
        plot_image(fig, true_ax, image, vrange=vrange, colorbar=True)
        plot_locs(true_ax, bp, slen, plocs, probs, cmap="cool")
        # the reconstruction panel shows the rendered image, not the truth
        plot_image(fig, recon_ax, recon_image, vrange=vrange, colorbar=True)
        plot_locs(recon_ax, bp, slen, plocs, probs, cmap="cool")

        residuals_idx = residuals[idx, 0].cpu().numpy()
        res_vmax = np.ceil(residuals_idx.max())
        res_vmin = np.floor(residuals_idx.min())

        if self.crop_loss_at_border:
            # white guide lines at the border-related offsets
            edge = (recon_images.shape[-1] - slen) // 2
            eff_slen = slen - edge
            for b in (edge, edge * 2):
                recon_ax.axvline(b, color="w")
                recon_ax.axvline(b + eff_slen, color="w")
                recon_ax.axhline(b, color="w")
                recon_ax.axhline(b + eff_slen, color="w")

        plot_image(fig, res_ax, residuals_idx, vrange=(res_vmin, res_vmax))
        if self.crop_loss_at_border:
            for b in (edge, edge * 2):
                res_ax.axvline(b, color="w")
                res_ax.axvline(b + eff_slen, color="w")
                res_ax.axhline(b, color="w")
                res_ax.axhline(b + eff_slen, color="w")

        if i == 0:
            add_legend(true_ax, labels)

    fig.tight_layout()
    if self.logger:
        self.logger.experiment.add_figure(
            f"Epoch:{self.current_epoch}/Worst Validation Images", fig
        )
    plt.close(fig)
def create_figure_at_point(
    h: int,
    w: int,
    size: int,
    bp: int,
    tile_map_recon: TileCatalog,
    frame: Frame,
    dec: ImageDecoder,
    est_catalog: Optional[FullCatalog] = None,
    show_tiles=False,
    use_image_bounds=False,
    **kwargs,
):
    """Create a figure of image, reconstruction and residual near a point.

    The window is snapped to the tile grid: `(h, w)` is first pulled back so
    that `size + bp` pixels fit inside the frame, then rounded down to a
    tile boundary, and `size // tile_slen` tiles are taken in each
    direction.

    Args:
        h: Requested top pixel row of the window.
        w: Requested left pixel column of the window.
        size: Requested side length of the window in pixels.
        bp: Border padding in pixels.
        tile_map_recon: Tile catalog used for the reconstruction.
        frame: Provides `image` and `background`.
        dec: Decoder used to render the cropped catalog.
        est_catalog: Optional catalog, cropped to the same window and
            passed to `create_figure` as `coadd_objects`.
        show_tiles: If true, the cropped tile catalog is passed to
            `create_figure` as `tile_map`.
        use_image_bounds: If true, the colour range is the min/max of the
            cropped image; otherwise it is fixed to (800, 1200).
        **kwargs: Forwarded to `create_figure`.

    Returns:
        Whatever `create_figure` returns.
    """
    tile_slen = tile_map_recon.tile_slen

    # keep the window (plus padding) inside the frame
    h = min(h, frame.image.shape[2] - size - bp)
    w = min(w, frame.image.shape[3] - size - bp)

    n_tiles = size // tile_slen
    row0 = (h - bp) // tile_slen
    col0 = (w - bp) // tile_slen
    tile_rows = (row0, row0 + n_tiles)
    tile_cols = (col0, col0 + n_tiles)
    hlims = bp + (tile_rows[0] * tile_slen), bp + (tile_rows[1] * tile_slen)
    wlims = bp + (tile_cols[0] * tile_slen), bp + (tile_cols[1] * tile_slen)

    tile_map_cropped = tile_map_recon.crop(tile_rows, tile_cols)
    full_map_cropped = tile_map_cropped.to_full_params()

    img_cropped = frame.image[0, 0, hlims[0] : hlims[1], wlims[0] : wlims[1]]
    bg_cropped = frame.background[0, 0, hlims[0] : hlims[1], wlims[0] : wlims[1]]
    with torch.no_grad():
        recon_cropped = dec.render_images(tile_map_cropped.to(dec.device))
        # drop the rendered border and add the background back
        recon_cropped = recon_cropped.to("cpu")[0, 0, bp:-bp, bp:-bp] + bg_cropped
        resid_cropped = (img_cropped - recon_cropped) / recon_cropped.sqrt()

    est_catalog_cropped = None
    if est_catalog is not None:
        tile_est_catalog = est_catalog.to_tile_params(
            tile_map_cropped.tile_slen, tile_map_cropped.max_sources
        )
        est_catalog_cropped = tile_est_catalog.crop(tile_rows, tile_cols).to_full_params()

    tile_map = tile_map_cropped if show_tiles else None

    if use_image_bounds:
        vmin = img_cropped.min().item()
        vmax = img_cropped.max().item()
    else:
        vmin, vmax = 800, 1200

    return create_figure(
        img_cropped,
        recon_cropped,
        resid_cropped,
        map_recon=full_map_cropped,
        coadd_objects=est_catalog_cropped,
        tile_map=tile_map,
        vmin=vmin,
        vmax=vmax,
        **kwargs,
    )