예제 #1
0
def get_positive_negative_stats(
    true_cat: FullCatalog,
    est_tile_cat: TileCatalog,
    mag_max: float = np.inf,
):
    """Evaluate ``stats_for_threshold`` over a sweep of 99 thresholds.

    The true catalog is first restricted with ``apply_mag_bin(-inf, mag_max)``.
    ``stats_for_threshold`` is then run for every threshold in
    ``np.linspace(0.01, 0.99, 99)`` using 10 joblib workers, against a copy of
    the estimated tile catalog.

    Args:
        true_cat: Ground-truth catalog; only its ``plocs`` (after the
            magnitude cut) are handed to ``stats_for_threshold``.
        est_tile_cat: Estimated tile catalog. It is copied before use.
        mag_max: Upper bound passed to ``apply_mag_bin``; the lower bound is
            always ``-inf``.

    Returns:
        A dict with one entry per key of the per-threshold result, each being
        the per-threshold values stacked with ``torch.stack`` (one slice per
        threshold, in threshold order), plus ``"n_obj"`` set to
        ``plocs.shape[1]`` of the magnitude-cut true catalog (presumably the
        number of true objects -- confirm against ``FullCatalog``).
    """
    cut_cat = true_cat.apply_mag_bin(-np.inf, mag_max)
    tile_cat = est_tile_cat.copy()
    cutoffs = np.linspace(0.01, 0.99, 99)

    # One delayed job per threshold; tqdm reports progress as jobs are dispatched.
    jobs = (
        delayed(stats_for_threshold)(cut_cat.plocs, tile_cat, cutoff)
        for cutoff in tqdm(cutoffs)
    )
    per_threshold = Parallel(n_jobs=10)(jobs)

    # Every per-threshold result is expected to expose the same keys as the first.
    out: Dict[str, Union[int, Tensor]] = {
        key: torch.stack([stats[key] for stats in per_threshold])
        for key in per_threshold[0]
    }
    out["n_obj"] = cut_cat.plocs.shape[1]
    return out
예제 #2
0
파일: binary.py 프로젝트: prob-ml/bliss
    def make_plots(self, batch, n_samples=16):
        """Produce informative plots demonstrating encoder performance.

        The first ``n_samples`` images of ``batch`` are drawn on a square grid
        of axes. Each image is overlaid with the true locations (marker "+",
        colored by the true ``galaxy_bools``) and the estimated catalog's
        locations (marker "x", colored by the predicted ``galaxy_probs``).
        The estimated catalog is a copy of the true tile catalog in which only
        ``galaxy_bools``, ``star_bools`` and ``galaxy_probs`` are replaced by
        the predictions.

        Args:
            batch: Mapping read for ``"images"``, ``"n_sources"``,
                ``"pred_galaxy_bools"``, ``"pred_star_bools"`` and
                ``"pred_galaxy_probs"``; entries named ``"locs"``,
                ``"n_sources"``, ``"galaxy_bools"`` and ``"star_bools"`` are
                forwarded to the true ``TileCatalog``.
            n_samples: Number of images to plot. Must be a perfect square so
                the images fill a square grid.

        Returns:
            None. Nothing is plotted when the batch holds fewer than
            ``n_samples`` entries. Otherwise the figure is passed to
            ``self.logger.experiment.add_figure`` when a logger is set.

        Raises:
            AssertionError: If ``n_samples`` is not a perfect square.
        """

        assert n_samples ** (0.5) % 1 == 0
        if n_samples > len(batch["n_sources"]):  # do nothing if low on samples.
            return
        nrows = int(n_samples**0.5)  # side length of the square grid of axes

        # keep only the ground-truth entries so the tile catalog can be
        # converted to full params.
        true_param_names = {"locs", "n_sources", "galaxy_bools", "star_bools"}
        true_tile_params = TileCatalog(
            self.tile_slen, {k: v for k, v in batch.items() if k in true_param_names}
        )
        true_params = true_tile_params.to_full_params()
        # prediction: start from the true tile catalog and overwrite only the
        # classification fields, reshaped to the true `galaxy_bools` shape.
        tile_est = true_tile_params.copy()
        shape = tile_est["galaxy_bools"].shape
        tile_est["galaxy_bools"] = batch["pred_galaxy_bools"].reshape(*shape)
        tile_est["star_bools"] = batch["pred_star_bools"].reshape(*shape)
        tile_est["galaxy_probs"] = batch["pred_galaxy_probs"].reshape(*shape)
        est = tile_est.to_full_params()
        # setup figure and axes
        # squeeze=False always yields a 2D array of axes; without it a 1x1
        # grid (n_samples == 1) comes back as a bare Axes with no `flatten`.
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12), squeeze=False)
        axes = axes.flatten()

        # loop-invariant values; the legend is only attached to the first axes.
        labels = ("t. gal", "p. gal", "t. star", "p. star")
        bp = self.border_padding
        for ii in range(n_samples):
            ax = axes[ii]
            image = batch["images"][ii, 0].cpu().numpy()
            true_plocs = true_params.plocs[ii].cpu().numpy().reshape(-1, 2)
            true_gbools = true_params["galaxy_bools"][ii].cpu().numpy().reshape(-1)
            est_plocs = est.plocs[ii].cpu().numpy().reshape(-1, 2)
            est_gprobs = est["galaxy_probs"][ii].cpu().numpy().reshape(-1)
            slen, _ = image.shape
            reporting.plot_image(fig, ax, image, colorbar=False)
            reporting.plot_locs(ax, bp, slen, true_plocs, true_gbools, m="+", s=30, cmap="cool")
            reporting.plot_locs(ax, bp, slen, est_plocs, est_gprobs, m="x", s=20, cmap="bwr")
            if ii == 0:
                reporting.add_legend(ax, labels)
        fig.tight_layout()

        title = f"Epoch:{self.current_epoch}/Validation Images"
        if self.logger is not None:
            self.logger.experiment.add_figure(title, fig)
        else:
            # Nobody receives the figure; close it so pyplot does not keep it
            # alive across repeated calls.
            plt.close(fig)