def get_positive_negative_stats(
    true_cat: FullCatalog,
    est_tile_cat: TileCatalog,
    mag_max: float = np.inf,
    thresholds=None,
    n_jobs: int = 10,
):
    """Compute detection statistics over a sweep of probability thresholds.

    For every threshold, ``stats_for_threshold(true_cat.plocs, est_tile_cat, t)``
    is evaluated (in parallel via joblib), and the per-threshold results are
    stacked along a new leading dimension.

    Args:
        true_cat: True catalog. It is first passed through
            ``apply_mag_bin(-np.inf, mag_max)`` (presumably keeping only
            objects no fainter than ``mag_max`` -- confirm against FullCatalog).
        est_tile_cat: Estimated tile catalog. A copy is handed to the workers,
            so the caller's object is not passed in directly.
        mag_max: Upper magnitude bound forwarded to ``apply_mag_bin``.
        thresholds: 1D sequence of thresholds to evaluate. Defaults to
            ``np.linspace(0.01, 0.99, 99)`` (the previous hard-coded grid).
        n_jobs: Number of joblib workers. Defaults to 10 (the previous
            hard-coded value).

    Returns:
        Dict mapping every key returned by ``stats_for_threshold`` to
        ``torch.stack`` of its per-threshold values (index 0 = threshold
        index), plus ``"n_obj"`` = ``true_cat.plocs.shape[1]`` after the
        magnitude cut.

    Raises:
        ValueError: If ``thresholds`` is empty.
    """
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)
    if len(thresholds) == 0:
        # Without at least one result there are no keys to stack.
        raise ValueError("`thresholds` must contain at least one value.")

    true_cat = true_cat.apply_mag_bin(-np.inf, mag_max)
    est_tile_cat = est_tile_cat.copy()

    res = Parallel(n_jobs=n_jobs)(
        delayed(stats_for_threshold)(true_cat.plocs, est_tile_cat, t)
        for t in tqdm(thresholds)
    )

    # All results are assumed to share the keys of the first one.
    out: Dict[str, Union[int, Tensor]] = {
        k: torch.stack([r[k] for r in res]) for k in res[0]
    }
    out["n_obj"] = true_cat.plocs.shape[1]
    return out
def make_plots(self, batch, n_samples=16):
    """Produce informative plots demonstrating encoder performance.

    Draws a ``sqrt(n_samples) x sqrt(n_samples)`` grid of validation images.
    On each image it overlays the true locations (marker ``+``, with their
    ``galaxy_bools``) and the predicted locations (marker ``x``, with their
    ``galaxy_probs``) using ``reporting.plot_locs``. If ``self.logger`` is
    set, the figure is logged via ``self.logger.experiment.add_figure``. The
    figure is always closed afterwards.

    Args:
        batch: Mapping holding ``images``, the true tile entries ``locs``,
            ``n_sources``, ``galaxy_bools`` and ``star_bools``, and the
            predictions ``pred_galaxy_bools``, ``pred_star_bools`` and
            ``pred_galaxy_probs``.
        n_samples: Number of images to plot. Must be a positive perfect square.

    Returns:
        None. Nothing is drawn when ``batch["n_sources"]`` has fewer than
        ``n_samples`` entries.
    """
    assert n_samples > 0, "n_samples must be positive"
    nrows = int(round(n_samples**0.5))  # the grid is nrows x nrows
    assert nrows * nrows == n_samples, "n_samples must be a perfect square"
    if n_samples > len(batch["n_sources"]):
        # Do nothing if the batch is too small to fill the grid.
        return

    # Keep only the true parameter entries so that `to_full_params` works;
    # `batch` also carries images and predictions.
    true_param_names = {"locs", "n_sources", "galaxy_bools", "star_bools"}
    true_tile_params = TileCatalog(
        self.tile_slen, {k: v for k, v in batch.items() if k in true_param_names}
    )
    true_params = true_tile_params.to_full_params()

    # Predictions: reuse the true locs/n_sources, swap in predicted classes.
    tile_est = true_tile_params.copy()
    shape = tile_est["galaxy_bools"].shape
    tile_est["galaxy_bools"] = batch["pred_galaxy_bools"].reshape(*shape)
    tile_est["star_bools"] = batch["pred_star_bools"].reshape(*shape)
    tile_est["galaxy_probs"] = batch["pred_galaxy_probs"].reshape(*shape)
    est = tile_est.to_full_params()

    # squeeze=False always yields a 2D array of Axes; otherwise nrows == 1
    # returns a bare Axes object that has no `.flatten()`.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=nrows, figsize=(12, 12), squeeze=False
    )
    axes = axes.flatten()
    bp = self.border_padding
    labels = ("t. gal", "p. gal", "t. star", "p. star")

    for ii in range(n_samples):
        ax = axes[ii]
        image = batch["images"][ii, 0].cpu().numpy()
        true_plocs = true_params.plocs[ii].cpu().numpy().reshape(-1, 2)
        true_gbools = true_params["galaxy_bools"][ii].cpu().numpy().reshape(-1)
        est_plocs = est.plocs[ii].cpu().numpy().reshape(-1, 2)
        est_gprobs = est["galaxy_probs"][ii].cpu().numpy().reshape(-1)
        slen, _ = image.shape
        reporting.plot_image(fig, ax, image, colorbar=False)
        reporting.plot_locs(
            ax, bp, slen, true_plocs, true_gbools, m="+", s=30, cmap="cool"
        )
        reporting.plot_locs(
            ax, bp, slen, est_plocs, est_gprobs, m="x", s=20, cmap="bwr"
        )
        if ii == 0:
            # Only the first panel carries the legend.
            reporting.add_legend(ax, labels)
    fig.tight_layout()

    title = f"Epoch:{self.current_epoch}/Validation Images"
    if self.logger is not None:
        self.logger.experiment.add_figure(title, fig)
    # Always release the figure; without a logger it would otherwise stay
    # registered with pyplot and leak once per call. Closing twice is harmless.
    plt.close(fig)