def test_variational_mode(self, devices):
    """Check the output shapes produced by the encoder's variational mode.

    A small ``DetectionEncoder`` is built, noisy images are simulated on top
    of a constant per-band background, the padded tiles are encoded, and the
    shapes of ``n_sources`` and ``locs`` in the resulting catalog are checked.

    Arguments:
        devices: Fixture whose ``device`` attribute is the torch device to use.
    """
    device = devices.device

    # problem dimensions
    batch_size, n_bands = 2, 2
    n_tiles_h, n_tiles_w = 3, 5
    ptile_slen, tile_slen = 10, 2
    max_detections = 4
    background = (10.0, 20.0)
    n_ptiles = batch_size * n_tiles_h * n_tiles_w

    # get encoder
    star_encoder: DetectionEncoder = DetectionEncoder(
        LogBackgroundTransform(),
        channel=8,
        dropout=0,
        spatial_dropout=0,
        hidden=64,
        ptile_slen=ptile_slen,
        tile_slen=tile_slen,
        n_bands=n_bands,
        mean_detections=0.48,
        max_detections=max_detections,
    ).to(device)

    with torch.no_grad():
        star_encoder.eval()

        # simulate images just large enough to hold the requested tile grid
        height = ptile_slen + (n_tiles_h - 1) * tile_slen
        width = ptile_slen + (n_tiles_w - 1) * tile_slen
        images = torch.randn(batch_size, n_bands, height, width, device=device)

        # broadcast the per-band background level over every pixel
        per_band = torch.tensor(background, device=device)
        background_tensor = per_band.reshape(1, -1, 1, 1).expand(*images.shape)
        images *= background_tensor.sqrt()
        images += background_tensor

        stacked = torch.cat((images, background_tensor), dim=1)
        image_ptiles = get_images_in_tiles(
            stacked, star_encoder.tile_slen, star_encoder.ptile_slen
        )
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")

        var_params = star_encoder.encode(image_ptiles)
        catalog = star_encoder.variational_mode(var_params)

        # one entry per padded tile
        assert catalog["n_sources"].size() == torch.Size([n_ptiles])
        expected_locs_shape = torch.Size([n_ptiles, max_detections, 2])
        assert catalog["locs"].shape == expected_locs_shape
def validation_step(self, batch, batch_idx):
    """Pytorch lightning method.

    Logs the validation losses and the detection metrics obtained by comparing
    the encoder's variational mode against the true catalog, then returns the
    unmodified ``batch``.
    """
    n_in_batch = len(batch["images"])
    losses = self._get_loss(batch)

    # log all losses
    self.log("val/loss", losses["loss"], batch_size=n_in_batch)
    for loss_name in ("counter_loss", "locs_loss", "star_params_loss"):
        self.log(f"val/{loss_name}", losses[loss_name].mean(), batch_size=n_in_batch)

    # ground truth, truncated to the number of sources the encoder can detect
    limit = self.max_detections
    catalog_dict = {
        field: batch[field][:, :, :, 0:limit]
        for field in ("locs", "log_fluxes", "galaxy_bools")
    }
    catalog_dict["n_sources"] = batch["n_sources"].clamp(max=limit)
    true_tile_catalog = TileCatalog(self.tile_slen, catalog_dict)
    true_full_catalog = true_tile_catalog.to_full_params()

    # encoder estimate on the same images
    stacked = torch.cat((batch["images"], batch["background"]), dim=1)
    image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    est_catalog_dict = self.variational_mode(self.encode(image_ptiles))
    est_tile_catalog = TileCatalog.from_flat_dict(
        true_tile_catalog.tile_slen,
        true_tile_catalog.n_tiles_h,
        true_tile_catalog.n_tiles_w,
        est_catalog_dict,
    )
    est_full_catalog = est_tile_catalog.to_full_params()

    metrics = self.val_detection_metrics(true_full_catalog, est_full_catalog)
    for metric_name in ("precision", "recall", "f1", "avg_distance"):
        self.log(f"val/{metric_name}", metrics[metric_name], batch_size=n_in_batch)

    return batch
def get_map_estimate(
    image_encoder: DetectionEncoder, images, background, slen: int, wlen: int = None
):
    """Return the encoder's variational-mode estimate as a full catalog.

    ``slen`` x ``wlen`` is the size of the image without border padding;
    ``wlen`` defaults to ``slen``. Along the way the function asserts that:

    * the image is compatible with the encoder's tiling and border padding,
    * encoding only the first ``25**2`` padded tiles reproduces the leading
      rows of encoding all of them, and
    * converting the tile catalog to full parameters and back is lossless
      (ignoring ``n_source_log_probs``).
    """
    if wlen is None:
        wlen = slen
    assert isinstance(slen, int) and isinstance(wlen, int)

    # check image compatibility
    border_h = (images.shape[-2] - slen) / 2
    border_w = (images.shape[-1] - wlen) / 2
    tile_slen = image_encoder.tile_slen
    assert border_h == border_w, "border paddings on each dimension differ."
    assert slen % tile_slen == 0, "incompatible slen"
    assert wlen % tile_slen == 0, "incompatible wlen"
    assert border_h == image_encoder.border_padding, "incompatible border"

    # obtain estimates per tile, then on the full image.
    stacked = torch.cat((images, background), dim=1)
    image_ptiles = get_images_in_tiles(
        stacked, image_encoder.tile_slen, image_encoder.ptile_slen
    )
    _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    var_params = image_encoder.encode(image_ptiles)

    # encoding a prefix of the tiles must agree with encoding every tile
    tile_cutoff = 25**2
    var_params2 = image_encoder.encode(image_ptiles[:tile_cutoff])
    for key in ("n_source_log_probs", "per_source_params"):
        assert torch.allclose(
            var_params[key][:tile_cutoff],
            var_params2[key],
            atol=1e-5,
        )

    tile_map_dict = image_encoder.variational_mode(var_params)
    tile_map = TileCatalog.from_flat_dict(
        image_encoder.tile_slen, n_tiles_h, n_tiles_w, tile_map_dict
    )
    full_map = tile_map.to_full_params()

    # round trip: full catalog -> tile catalog must give back the same tiles
    tile_map_tilde = full_map.to_tile_params(image_encoder.tile_slen, tile_map.max_sources)
    assert tile_map.equals(tile_map_tilde, exclude=("n_source_log_probs",), atol=1e-5)

    return full_map
def _make_ptile_loader(self, image: Tensor, background: Tensor, n_tiles_h: int):
    """Yield padded tiles chunk by chunk.

    The image and background are concatenated along dim 1 and moved to
    ``self.device``. The result is walked in groups of
    ``self.n_images_per_batch`` images and ``self.n_rows_per_batch`` tile rows.
    Each crop is cut into padded tiles, and all leading dimensions are
    flattened so that only the last three dimensions of the tiles remain.
    """
    stacked = torch.cat((image, background), dim=1).to(self.device)
    n_images = image.shape[0]
    for first_image in range(0, n_images, self.n_images_per_batch):
        for first_row in range(0, n_tiles_h, self.n_rows_per_batch):
            last_image = first_image + self.n_images_per_batch
            last_row = first_row + self.n_rows_per_batch

            # pixel extent of the selected tile rows, plus the border on both sides
            tile_slen = self.detection_encoder.tile_slen
            top = first_row * tile_slen
            bottom = last_row * tile_slen + 2 * self.border_padding

            crop = stacked[first_image:last_image, :, top:bottom, :]
            ptiles = get_images_in_tiles(
                crop,
                self.detection_encoder.tile_slen,
                self.detection_encoder.ptile_slen,
            )
            yield ptiles.reshape(-1, *ptiles.shape[-3:])
def _get_loss(self, batch):
    """Return the reconstruction loss minus the masked divergence term.

    The batch entries other than ``"images"`` and ``"background"`` form the
    tile catalog. Galaxy parameters are encoded from the padded tiles at the
    catalog locations, the image is re-rendered from the catalog, and the
    negative Normal log-likelihood of the observed images is summed.

    Arguments:
        batch: Mapping with ``"images"``, ``"background"`` and the tile-catalog
            entries (must include ``"locs"`` and ``"galaxy_bools"``).

    Returns:
        ``recon_losses.sum() - divergence_loss`` as a scalar tensor.
    """
    images: Tensor = batch["images"]
    background: Tensor = batch["background"]
    tile_catalog = TileCatalog(
        self.tile_slen,
        {k: v for k, v in batch.items() if k not in {"images", "background"}},
    )

    image_ptiles = get_images_in_tiles(
        torch.cat((images, background), dim=1),
        self.tile_slen,
        self.ptile_slen,
    )
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    locs = rearrange(tile_catalog.locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
    galaxy_params, pq_divergence = self._encode(image_ptiles, locs)

    # draw fully reconstructed image.
    # NOTE: Assume recon_mean = recon_var per poisson approximation.
    tile_catalog["galaxy_params"] = rearrange(
        galaxy_params,
        "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
        nth=tile_catalog.n_tiles_h,
        ntw=tile_catalog.n_tiles_w,
    )
    recon_mean = self.image_decoder.render_images(tile_catalog)
    recon_mean = rearrange(recon_mean, "(ns n) c h w -> ns n c h w", ns=1)
    recon_mean += background.unsqueeze(0)

    assert not torch.any(torch.isnan(recon_mean))
    assert not torch.any(torch.isinf(recon_mean))
    recon_losses = -Normal(recon_mean, recon_mean.sqrt()).log_prob(images.unsqueeze(0))

    if self.crop_loss_at_border:
        bp = self.border_padding * 2
        height, width = recon_losses.shape[-2:]
        # Use explicit end indices: `bp:-bp` selects nothing when bp == 0
        # (because -0 == 0), which would silently drop the reconstruction term.
        recon_losses = recon_losses[:, :, :, bp : height - bp, bp : width - bp]
    assert not torch.any(torch.isnan(recon_losses))
    assert not torch.any(torch.isinf(recon_losses))

    # For divergence loss, we only evaluate tiles with a galaxy in them
    galaxy_bools = rearrange(tile_catalog["galaxy_bools"], "n nth ntw ms 1 -> 1 (n nth ntw) ms")
    divergence_loss = (pq_divergence * galaxy_bools).sum()
    return recon_losses.sum() - divergence_loss
def test_sample(self, devices):
    """Smoke test: sampling from the encoder's variational parameters runs.

    Arguments:
        devices: Fixture whose ``device`` attribute is the torch device to use.
    """
    device = devices.device

    # problem dimensions
    n_bands = 2
    ptile_slen, tile_slen = 10, 2
    max_detections = 4
    n_samples = 5
    background = (10.0, 20.0)

    # noisy image on top of a constant per-band background
    background_tensor = torch.tensor(background).view(1, -1, 1, 1).to(device)
    noise = torch.randn(1, n_bands, 4 * ptile_slen, 4 * ptile_slen).to(device)
    images = noise * background_tensor.sqrt() + background_tensor
    background_tensor = background_tensor.expand(*images.shape)

    star_encoder: DetectionEncoder = DetectionEncoder(
        LogBackgroundTransform(),
        channel=8,
        dropout=0,
        spatial_dropout=0,
        hidden=64,
        ptile_slen=ptile_slen,
        tile_slen=tile_slen,
        n_bands=n_bands,
        mean_detections=0.48,
        max_detections=max_detections,
    ).to(device)

    stacked = torch.cat((images, background_tensor), dim=1)
    image_ptiles = get_images_in_tiles(
        stacked, star_encoder.tile_slen, star_encoder.ptile_slen
    )
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")

    var_params = star_encoder.encode(image_ptiles)
    star_encoder.sample(var_params, n_samples)
def get_prediction(self, batch: Dict[str, Tensor]):
    """Return loss, accuracy, binary probabilities, and MAP classifications for given batch.

    Every returned per-source quantity is flattened to one dimension. The loss
    and accuracy only count sources that are "on" according to
    ``batch["n_sources"]``.
    """
    # padded tiles and per-tile locations, flattened over (image, tile row, tile column)
    stacked = torch.cat((batch["images"], batch["background"]), dim=1)
    image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    locs = rearrange(batch["locs"], "n nth ntw ms hw -> 1 (n nth ntw) ms hw")

    galaxy_probs = self.forward(image_ptiles, locs).reshape(-1)
    galaxy_bools = batch["galaxy_bools"].reshape(-1)
    is_on = get_is_on_from_n_sources(batch["n_sources"], self.max_sources).reshape(-1)

    # we need to calculate cross entropy loss, only for "on" sources
    per_source_loss = BCELoss(reduction="none")(galaxy_probs, galaxy_bools)
    loss = (per_source_loss * is_on).sum()

    # get predictions for calculating metrics
    pred_galaxy_bools = (galaxy_probs > 0.5).float() * is_on
    n_correct = (pred_galaxy_bools.eq(galaxy_bools) * is_on).sum()
    acc = n_correct / batch["n_sources"].sum()

    # an "on" source that is not predicted to be a galaxy is a star
    pred_star_bools = (1 - pred_galaxy_bools) * is_on

    # finally organize quantities and return as a dictionary
    return {
        "loss": loss,
        "acc": acc,
        "galaxy_bools": pred_galaxy_bools,
        "star_bools": pred_star_bools,
        "galaxy_probs": galaxy_probs,
    }
def validation_epoch_end(self, outputs, kind="validation", max_n_samples=16):
    # pylint: disable=too-many-statements
    """Pytorch lightning method.

    Draws a square grid of images from the last batch in ``outputs``. True
    galaxies and stars and the estimated sources are overlaid on each image,
    and the figure is sent to the logger if there is one. Nothing is done when
    the model has more than one band.
    """
    batch: Dict[str, Tensor] = outputs[-1]
    if self.n_bands > 1:
        return

    # largest perfect square of examples available, capped at `max_n_samples`
    n_available = len(batch["n_sources"])
    n_samples = min(int(math.sqrt(n_available)) ** 2, max_n_samples)
    nrows = int(n_samples**0.5)  # for figure

    # ground truth, truncated to the number of sources the encoder can detect
    limit = self.max_detections
    catalog_dict = {
        field: batch[field][:, :, :, 0:limit]
        for field in ("locs", "log_fluxes", "galaxy_bools", "star_bools")
    }
    catalog_dict["n_sources"] = batch["n_sources"].clamp(max=limit)
    true_tile_catalog = TileCatalog(self.tile_slen, catalog_dict)
    true_cat = true_tile_catalog.to_full_params()

    # encoder estimate on the same images
    stacked = torch.cat((batch["images"], batch["background"]), dim=1)
    image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    est_catalog_dict = self.variational_mode(self.encode(image_ptiles))
    est_tile_catalog = TileCatalog.from_flat_dict(
        true_tile_catalog.tile_slen,
        true_tile_catalog.n_tiles_h,
        true_tile_catalog.n_tiles_w,
        est_catalog_dict,
    )
    est_cat = est_tile_catalog.to_full_params()

    # setup figure and axes.
    fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12))
    axes = axes.flatten() if nrows > 1 else [axes]

    images = batch["images"]
    assert images.shape[-2] == images.shape[-1]
    bp = self.border_padding
    legend_entries = (
        ("r", "x", 20, "t.gal"),
        ("c", "x", 20, "t.star"),
        ("b", "+", 30, "p.source"),
    )

    for idx, ax in enumerate(axes):
        n_true = true_cat.n_sources[idx].item()
        n_est = est_cat.n_sources[idx].item()
        ax.set_xlabel(f"True num: {n_true}; Est num: {n_est}")

        # add white border showing where centers of stars and galaxies can be
        ax.axvline(bp, color="w")
        ax.axvline(images.shape[-1] - bp, color="w")
        ax.axhline(bp, color="w")
        ax.axhline(images.shape[-2] - bp, color="w")

        # plot image first
        image = images[idx, 0].cpu().numpy()
        vmin = image.min().item()
        vmax = image.max().item()
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        im = ax.matshow(image, vmin=vmin, vmax=vmax, cmap="viridis")
        fig.colorbar(im, cax=cax, orientation="vertical")

        true_cat.plot_plocs(ax, idx, "galaxy", bp=bp, color="r", marker="x", s=20)
        true_cat.plot_plocs(ax, idx, "star", bp=bp, color="c", marker="x", s=20)
        est_cat.plot_plocs(ax, idx, "all", bp=bp, color="b", marker="+", s=30)

        if idx == 0:
            # empty scatters only exist to create the legend handles
            for color, marker, size, label in legend_entries:
                ax.scatter(None, None, color=color, marker=marker, s=size, label=label)
            ax.legend(
                bbox_to_anchor=(0.0, 1.2, 1.0, 0.102),
                loc="lower left",
                ncol=2,
                mode="expand",
                borderaxespad=0.0,
            )

    fig.tight_layout()
    if self.logger:
        if kind == "validation":
            title = f"Epoch:{self.current_epoch}/Validation Images"
            self.logger.experiment.add_figure(title, fig)
        elif kind == "testing":
            self.logger.experiment.add_figure("Test Images", fig)
        else:
            raise NotImplementedError()
    plt.close(fig)
def _get_loss(self, batch: Dict[str, Tensor]):
    """Compute the detection loss and its per-tile components.

    Returns a dictionary with the scalar ``"loss"`` (the mean over padded
    tiles) and the unreduced ``"counter_loss"``, ``"locs_loss"`` and
    ``"star_params_loss"`` vectors.
    """
    limit = self.max_detections

    # true catalog, truncated to `limit` sources and flattened over tiles
    per_source_patterns = {
        "locs": "n nth ntw ns hw -> (n nth ntw) ns hw",
        "log_fluxes": "n nth ntw ns b -> (n nth ntw) ns b",
        "galaxy_bools": "n nth ntw ns 1 -> (n nth ntw) ns 1",
    }
    true_catalog = {
        field: rearrange(batch[field][:, :, :, 0:limit], pattern)
        for field, pattern in per_source_patterns.items()
    }
    true_catalog["n_sources"] = rearrange(
        batch["n_sources"].clamp(max=limit), "n nth ntw -> (n nth ntw)"
    )
    true_catalog["is_on_array"] = get_is_on_from_n_sources(true_catalog["n_sources"], limit)

    stacked = torch.cat((batch["images"], batch["background"]), dim=1)
    image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    dist_params = self.encode(image_ptiles)

    # source-count loss; the identity rearrange insists on a 2-D input
    nslp_flat = rearrange(dist_params["n_source_log_probs"], "n_ptiles ns -> n_ptiles ns")
    n_true = true_catalog["n_sources"]
    counter_loss = F.nll_loss(nslp_flat, n_true.reshape(-1), reduction="none")

    # per-source parameters given the true number of sources in each tile
    pred = self._encode_for_n_sources(
        dist_params["per_source_params"],
        rearrange(n_true, "n_ptiles -> 1 n_ptiles"),
    )
    locs_log_probs_all = _get_params_logprob_all_combs(
        true_catalog["locs"],
        pred["loc_mean"].squeeze(0),
        pred["loc_sd"].squeeze(0),
    )
    star_params_log_probs_all = _get_params_logprob_all_combs(
        true_catalog["log_fluxes"],
        pred["log_flux_mean"].squeeze(0),
        pred["log_flux_sd"].squeeze(0),
    )

    galaxy_bools = rearrange(true_catalog["galaxy_bools"], "n_ptiles ns 1 -> n_ptiles ns")
    locs_loss, star_params_loss = _get_min_perm_loss(
        locs_log_probs_all,
        star_params_log_probs_all,
        galaxy_bools,
        true_catalog["is_on_array"],
    )

    # location losses of 1e6 or more are multiplied by zero and so drop out
    keep_locs = (locs_loss.detach() < 1e6).float()
    loss_vec = locs_loss * keep_locs + counter_loss + star_params_loss

    return {
        "loss": loss_vec.mean(),
        "counter_loss": counter_loss,
        "locs_loss": locs_loss,
        "star_params_loss": star_params_loss,
    }
def get_images_in_ptiles(self, images):
    """Run get_images_in_tiles with the detection encoder's tile_slen and ptile_slen."""
    tile_slen = self.detection_encoder.tile_slen
    ptile_slen = self.detection_encoder.ptile_slen
    return get_images_in_tiles(images, tile_slen, ptile_slen)
def _make_plots(self, batch, n_samples=5):
    """Plot truth, reconstruction and residuals for a random subset of a batch.

    Up to ``n_samples`` examples are drawn without replacement and encoded at
    their true locations, then re-rendered. They are plotted one per row,
    ordered from the largest to the smallest mean absolute residual. The
    figure is sent to the logger if there is one, then closed.

    Arguments:
        batch: Mapping holding ``"images"``, ``"background"`` and the tile
            catalog entries listed in ``keys`` below. It is not modified.
        n_samples: Maximum number of examples to plot.
    """
    # validate worst reconstruction images.
    n_total = len(batch["n_sources"])
    n_samples = min(n_total, n_samples)
    samples = np.random.choice(n_total, n_samples, replace=False)
    keys = [
        "images",
        "background",
        "locs",
        "galaxy_bools",
        "star_bools",
        "fluxes",
        "log_fluxes",
        "n_sources",
    ]
    # Subsample into a shallow copy so the caller's batch is left untouched.
    batch = dict(batch)
    for k in keys:
        batch[k] = batch[k][samples]

    # extract non-params entries so that 'get_full_params' works.
    images = batch["images"]
    background = batch["background"]
    tile_locs = batch["locs"]

    # obtain map estimates
    image_ptiles = get_images_in_tiles(
        torch.cat((images, background), dim=1),
        self.tile_slen,
        self.ptile_slen,
    )
    _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    locs = rearrange(tile_locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
    z, _ = self._encode(image_ptiles, locs)
    galaxy_params = rearrange(
        z,
        "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
        ns=1,
        nth=n_tiles_h,
        ntw=n_tiles_w,
    )

    tile_est = TileCatalog(
        self.tile_slen,
        {
            "n_sources": batch["n_sources"],
            "locs": batch["locs"],
            "galaxy_bools": batch["galaxy_bools"],
            "star_bools": batch["star_bools"],
            "fluxes": batch["fluxes"],
            "log_fluxes": batch["log_fluxes"],
            "galaxy_params": galaxy_params,
        },
    )
    est = tile_est.to_full_params()

    # draw all reconstruction images.
    # render_images automatically accounts for tiles with no galaxies.
    recon_images = self.image_decoder.render_images(tile_est)
    recon_images += background
    residuals = (images - recon_images) / torch.sqrt(recon_images)

    # draw worst `n_samples` examples as measured by absolute avg. residual error.
    # (the ranking is computed before the border of the residuals is blanked.)
    worst_indices = residuals.abs().mean(dim=(1, 2, 3)).argsort(descending=True)[:n_samples]

    if self.crop_loss_at_border:
        bp = self.border_padding * 2
        # Guard against bp == 0: `-0:` addresses the whole axis and would
        # zero every residual instead of none.
        if bp > 0:
            residuals[:, :, :bp, :] = 0.0
            residuals[:, :, -bp:, :] = 0.0
            residuals[:, :, :, :bp] = 0.0
            residuals[:, :, :, -bp:] = 0.0

    # Plotting only works on square images
    assert images.shape[-2] == images.shape[-1]
    border = self.border_padding
    slen = images.shape[-1] - 2 * border
    labels = ("t. gal", None, "t. star", None)
    if self.crop_loss_at_border:
        # guide lines marking the region excluded from the loss
        crop_bp = (recon_images.shape[-1] - slen) // 2
        eff_slen = slen - crop_bp
        guide_offsets = (crop_bp, crop_bp * 2)
    else:
        guide_offsets = ()

    figsize = (12, 4 * n_samples)
    fig, axes = plt.subplots(nrows=n_samples, ncols=3, figsize=figsize, squeeze=False)

    # plain ints keep the indexing valid for tensors and catalog entries alike
    for i, idx in enumerate(worst_indices.tolist()):
        true_ax = axes[i, 0]
        recon_ax = axes[i, 1]
        res_ax = axes[i, 2]

        # add titles to axes in the first row
        if i == 0:
            true_ax.set_title("Truth", size=18)
            recon_ax.set_title("Reconstruction", size=18)
            res_ax.set_title("Residual", size=18)

        # vmin, vmax should be shared between reconstruction and true images.
        if self.max_flux_valid_plots is None:
            vmax = np.ceil(max(images[idx].max().item(), recon_images[idx].max().item()))
        else:
            vmax = self.max_flux_valid_plots
        vmin = np.floor(min(images[idx].min().item(), recon_images[idx].min().item()))
        vrange = (vmin, vmax)

        # plot! Every panel of row `i` shows example `idx`, so truth,
        # reconstruction and residual all refer to the same image.
        image = images[idx, 0].cpu().numpy()
        recon_image = recon_images[idx, 0].cpu().numpy()
        plocs = est.plocs[idx].cpu().numpy().reshape(-1, 2)
        probs = est["galaxy_bools"][idx].cpu().numpy().reshape(-1)
        plot_image(fig, true_ax, image, vrange=vrange, colorbar=True)
        plot_locs(true_ax, border, slen, plocs, probs, cmap="cool")
        plot_image(fig, recon_ax, recon_image, vrange=vrange, colorbar=True)
        plot_locs(recon_ax, border, slen, plocs, probs, cmap="cool")

        residuals_idx = residuals[idx, 0].cpu().numpy()
        res_vmax = np.ceil(residuals_idx.max())
        res_vmin = np.floor(residuals_idx.min())
        for b in guide_offsets:
            recon_ax.axvline(b, color="w")
            recon_ax.axvline(b + eff_slen, color="w")
            recon_ax.axhline(b, color="w")
            recon_ax.axhline(b + eff_slen, color="w")
        plot_image(fig, res_ax, residuals_idx, vrange=(res_vmin, res_vmax))
        for b in guide_offsets:
            res_ax.axvline(b, color="w")
            res_ax.axvline(b + eff_slen, color="w")
            res_ax.axhline(b, color="w")
            res_ax.axhline(b + eff_slen, color="w")
        if i == 0:
            add_legend(true_ax, labels)

    fig.tight_layout()
    if self.logger:
        self.logger.experiment.add_figure(
            f"Epoch:{self.current_epoch}/Worst Validation Images", fig
        )
    plt.close(fig)