Пример #1
0
    def test_variational_mode(self, devices):
        """Check the shapes of the catalog returned by ``variational_mode``.

        Random background-scaled images are cut into padded tiles and encoded; the
        resulting catalog must hold one source count per padded tile and a
        ``(n_ptiles, max_detections, 2)`` tensor of locations.

        Arguments:
            devices: Object whose ``device`` attribute is the torch device to run on.
        """
        device = devices.device

        # problem dimensions
        batch_size, n_bands = 2, 2
        n_tiles_h, n_tiles_w = 3, 5
        ptile_slen, tile_slen = 10, 2
        max_detections = 4
        background = (10.0, 20.0)

        # image sized so that the tiling is expected to give an
        # n_tiles_h x n_tiles_w grid of padded tiles per image
        img_h = ptile_slen + (n_tiles_h - 1) * tile_slen
        img_w = ptile_slen + (n_tiles_w - 1) * tile_slen
        n_ptiles = batch_size * n_tiles_h * n_tiles_w

        # build the encoder under test
        star_encoder: DetectionEncoder = DetectionEncoder(
            LogBackgroundTransform(),
            channel=8,
            dropout=0,
            spatial_dropout=0,
            hidden=64,
            ptile_slen=ptile_slen,
            tile_slen=tile_slen,
            n_bands=n_bands,
            mean_detections=0.48,
            max_detections=max_detections,
        ).to(device)

        with torch.no_grad():
            star_encoder.eval()

            # standard-normal noise, rescaled in place below
            images = torch.randn(batch_size, n_bands, img_h, img_w, device=device)

            per_band = torch.tensor(background, device=device).reshape(1, -1, 1, 1)
            background_tensor = per_band.expand(*images.shape)

            # give every band mean and variance equal to its background level
            images *= background_tensor.sqrt()
            images += background_tensor

            stacked = torch.cat((images, background_tensor), dim=1)
            image_ptiles = get_images_in_tiles(
                stacked,
                star_encoder.tile_slen,
                star_encoder.ptile_slen,
            )
            image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")

            var_params = star_encoder.encode(image_ptiles)
            catalog = star_encoder.variational_mode(var_params)

            assert catalog["n_sources"].size() == torch.Size([n_ptiles])
            correct_locs_shape = torch.Size([n_ptiles, max_detections, 2])
            assert catalog["locs"].shape == correct_locs_shape
Пример #2
0
    def validation_step(self, batch, batch_idx):
        """Pytorch lightning method: log validation losses and detection metrics.

        The losses come from ``self._get_loss``; the detection metrics compare the
        true catalog (truncated to ``self.max_detections`` sources per tile) with
        the variational mode predicted from the batch images.

        Arguments:
            batch: Dictionary with at least "images", "background", "locs",
                "log_fluxes", "galaxy_bools" and "n_sources".
            batch_idx: Index of the batch (unused).

        Returns:
            The input ``batch``.
        """
        n_images = len(batch["images"])
        losses = self._get_loss(batch)

        # log all losses; the per-tile losses are averaged first
        self.log("val/loss", losses["loss"], batch_size=n_images)
        per_tile_losses = (
            ("val/counter_loss", "counter_loss"),
            ("val/locs_loss", "locs_loss"),
            ("val/star_params_loss", "star_params_loss"),
        )
        for tag, key in per_tile_losses:
            self.log(tag, losses[key].mean(), batch_size=n_images)

        # ground truth, keeping at most `max_detections` sources per tile
        max_det = self.max_detections
        catalog_dict = {
            key: batch[key][:, :, :, 0:max_det]
            for key in ("locs", "log_fluxes", "galaxy_bools")
        }
        catalog_dict["n_sources"] = batch["n_sources"].clamp(max=max_det)
        true_tile_catalog = TileCatalog(self.tile_slen, catalog_dict)
        true_full_catalog = true_tile_catalog.to_full_params()

        # prediction: encode the padded tiles and take the variational mode
        stacked = torch.cat((batch["images"], batch["background"]), dim=1)
        image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        est_catalog_dict = self.variational_mode(self.encode(image_ptiles))
        est_tile_catalog = TileCatalog.from_flat_dict(
            true_tile_catalog.tile_slen,
            true_tile_catalog.n_tiles_h,
            true_tile_catalog.n_tiles_w,
            est_catalog_dict,
        )
        est_full_catalog = est_tile_catalog.to_full_params()

        metrics = self.val_detection_metrics(true_full_catalog, est_full_catalog)
        metric_tags = (
            ("val/precision", "precision"),
            ("val/recall", "recall"),
            ("val/f1", "f1"),
            ("val/avg_distance", "avg_distance"),
        )
        for tag, key in metric_tags:
            self.log(tag, metrics[key], batch_size=n_images)
        return batch
Пример #3
0
def get_map_estimate(
    image_encoder: DetectionEncoder, images, background, slen: int, wlen: int = None
):
    """Return the encoder's variational-mode catalog for whole images.

    Arguments:
        image_encoder: Detection encoder used to encode the padded tiles.
        images: Image tensor whose last two dimensions are height and width,
            including border padding.
        background: Background tensor, concatenated with ``images`` along dim 1.
        slen: Image height without border padding.
        wlen: Image width without border padding; defaults to ``slen``.

    Returns:
        The full-image catalog produced by ``TileCatalog.to_full_params``.

    Raises:
        AssertionError: If the sizes are incompatible with the encoder, or if one
            of the internal consistency checks fails.
    """
    # NOTE: slen*wlen is the size of the image without border padding
    wlen = slen if wlen is None else wlen
    assert isinstance(slen, int) and isinstance(wlen, int)

    # check image compatibility
    pad_h = (images.shape[-2] - slen) / 2
    pad_w = (images.shape[-1] - wlen) / 2
    assert pad_h == pad_w, "border paddings on each dimension differ."
    tile_slen = image_encoder.tile_slen
    assert slen % tile_slen == 0, "incompatible slen"
    assert wlen % tile_slen == 0, "incompatible wlen"
    assert pad_h == image_encoder.border_padding, "incompatible border"

    # obtain estimates per tile, then on the full image.
    stacked = torch.cat((images, background), dim=1)
    image_ptiles = get_images_in_tiles(stacked, tile_slen, image_encoder.ptile_slen)
    _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
    var_params = image_encoder.encode(image_ptiles)

    # sanity check: encoding only the first tiles must reproduce the
    # corresponding rows of the full encoding
    tile_cutoff = 25**2
    head_params = image_encoder.encode(image_ptiles[:tile_cutoff])
    for key in ("n_source_log_probs", "per_source_params"):
        assert torch.allclose(
            var_params[key][:tile_cutoff],
            head_params[key],
            atol=1e-5,
        )

    tile_map = TileCatalog.from_flat_dict(
        tile_slen,
        n_tiles_h,
        n_tiles_w,
        image_encoder.variational_mode(var_params),
    )
    full_map = tile_map.to_full_params()

    # sanity check: converting back to tiles must give the same tile catalog
    tile_map_tilde = full_map.to_tile_params(image_encoder.tile_slen, tile_map.max_sources)
    assert tile_map.equals(tile_map_tilde, exclude=("n_source_log_probs",), atol=1e-5)

    return full_map
Пример #4
0
 def _make_ptile_loader(self, image: Tensor, background: Tensor,
                        n_tiles_h: int):
     """Yield padded tiles for successive chunks of images and tile rows.

     ``image`` and ``background`` are concatenated along dim 1 and moved to
     ``self.device``. The result is walked in chunks of
     ``self.n_images_per_batch`` images and ``self.n_rows_per_batch`` tile
     rows; each chunk is cut into padded tiles.

     Arguments:
         image: Image tensor; dim 0 indexes images and dim 2 is the height.
         background: Background tensor, concatenated with ``image`` on dim 1.
         n_tiles_h: Number of tile rows to cover.

     Yields:
         The padded tiles of one chunk, with all leading dimensions flattened
         into one and the last three dimensions kept.
     """
     stacked = torch.cat((image, background), dim=1).to(self.device)
     for first_img in range(0, image.shape[0], self.n_images_per_batch):
         img_slice = slice(first_img, first_img + self.n_images_per_batch)
         for first_row in range(0, n_tiles_h, self.n_rows_per_batch):
             last_row = first_row + self.n_rows_per_batch
             # pixel rows of this chunk, including the border on both sides
             top = first_row * self.detection_encoder.tile_slen
             bottom = last_row * self.detection_encoder.tile_slen + 2 * self.border_padding
             ptiles = get_images_in_tiles(
                 stacked[img_slice, :, top:bottom, :],
                 self.detection_encoder.tile_slen,
                 self.detection_encoder.ptile_slen,
             )
             yield ptiles.reshape(-1, *ptiles.shape[-3:])
Пример #5
0
    def _get_loss(self, batch):
        """Compute the loss of the galaxy encoder on one batch.

        The galaxy parameters are encoded from the padded tiles at the true
        locations, the images are re-rendered with the image decoder, and the
        loss is the summed negative Normal log-likelihood of the observed images
        minus the summed divergence term of the sources flagged as galaxies.

        Arguments:
            batch: Dictionary with "images", "background" and the tile-catalog
                entries (at least "locs" and "galaxy_bools").

        Returns:
            Scalar tensor ``recon_losses.sum() - divergence_loss``.

        Raises:
            AssertionError: If the reconstruction mean or the reconstruction
                losses contain NaN or infinite values.
        """
        images: Tensor = batch["images"]
        background: Tensor = batch["background"]
        tile_catalog = TileCatalog(
            self.tile_slen,
            {k: v for k, v in batch.items() if k not in {"images", "background"}},
        )

        image_ptiles = get_images_in_tiles(
            torch.cat((images, background), dim=1),
            self.tile_slen,
            self.ptile_slen,
        )
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        locs = rearrange(tile_catalog.locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
        galaxy_params, pq_divergence = self._encode(image_ptiles, locs)

        # draw fully reconstructed image.
        # NOTE: Assume recon_mean = recon_var per poisson approximation.
        tile_catalog["galaxy_params"] = rearrange(
            galaxy_params,
            "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
            nth=tile_catalog.n_tiles_h,
            ntw=tile_catalog.n_tiles_w,
        )
        recon_mean = self.image_decoder.render_images(tile_catalog)
        recon_mean = rearrange(recon_mean, "(ns n) c h w -> ns n c h w", ns=1)
        recon_mean += background.unsqueeze(0)

        assert torch.isfinite(recon_mean).all(), "non-finite reconstruction mean"
        recon_losses = -Normal(recon_mean, recon_mean.sqrt()).log_prob(images.unsqueeze(0))
        if self.crop_loss_at_border:
            bp = self.border_padding * 2
            # Use explicit end indices: the negative-stop form `bp:-bp` becomes
            # `0:0` (an empty slice) when bp == 0, which would silently discard
            # the whole reconstruction loss.
            height, width = recon_losses.shape[-2:]
            recon_losses = recon_losses[:, :, :, bp : height - bp, bp : width - bp]
        assert torch.isfinite(recon_losses).all(), "non-finite reconstruction loss"

        # For divergence loss, we only evaluate tiles with a galaxy in them
        galaxy_bools = rearrange(tile_catalog["galaxy_bools"], "n nth ntw ms 1 -> 1 (n nth ntw) ms")
        divergence_loss = (pq_divergence * galaxy_bools).sum()
        return recon_losses.sum() - divergence_loss
Пример #6
0
    def test_sample(self, devices):
        """Smoke-test ``DetectionEncoder.sample`` on one random image.

        The test only checks that encoding the padded tiles and drawing
        ``n_samples`` samples runs without raising.

        Arguments:
            devices: Object whose ``device`` attribute is the torch device to run on.
        """
        device = devices.device

        # problem dimensions
        ptile_slen, tile_slen = 10, 2
        n_bands = 2
        max_detections = 4
        n_samples = 5
        background = (10.0, 20.0)
        img_slen = 4 * ptile_slen

        background_tensor = torch.tensor(background).view(1, -1, 1, 1).to(device)

        # standard-normal noise with mean and variance set to the background level
        noise = torch.randn(1, n_bands, img_slen, img_slen).to(device)
        images = noise * background_tensor.sqrt() + background_tensor
        background_tensor = background_tensor.expand(*images.shape)

        star_encoder: DetectionEncoder = DetectionEncoder(
            LogBackgroundTransform(),
            channel=8,
            dropout=0,
            spatial_dropout=0,
            hidden=64,
            ptile_slen=ptile_slen,
            tile_slen=tile_slen,
            n_bands=n_bands,
            mean_detections=0.48,
            max_detections=max_detections,
        ).to(device)

        stacked = torch.cat((images, background_tensor), dim=1)
        image_ptiles = get_images_in_tiles(
            stacked,
            star_encoder.tile_slen,
            star_encoder.ptile_slen,
        )
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")

        var_params = star_encoder.encode(image_ptiles)
        star_encoder.sample(var_params, n_samples)
Пример #7
0
    def get_prediction(self, batch: Dict[str, Tensor]):
        """Return loss, accuracy, binary probabilities, and MAP classifications for given batch.

        Arguments:
            batch: Dictionary with "images", "background", "locs", "galaxy_bools"
                and "n_sources".

        Returns:
            Dictionary with the summed binary cross-entropy over "on" sources
            ("loss"), the fraction of correct classifications ("acc"), and the
            flattened "galaxy_bools", "star_bools" and "galaxy_probs" predictions.
        """
        true_galaxy = batch["galaxy_bools"].reshape(-1)
        locs = rearrange(batch["locs"], "n nth ntw ms hw -> 1 (n nth ntw) ms hw")

        stacked = torch.cat((batch["images"], batch["background"]), dim=1)
        image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        galaxy_probs = self.forward(image_ptiles, locs).reshape(-1)

        # mask selecting the source slots that are "on"
        is_on = get_is_on_from_n_sources(batch["n_sources"], self.max_sources).reshape(-1)

        # cross entropy is accumulated only over "on" sources
        criterion = BCELoss(reduction="none")
        loss = (criterion(galaxy_probs, true_galaxy) * is_on).sum()

        # hard predictions used for the metrics
        pred_galaxy = (galaxy_probs > 0.5).float() * is_on
        n_correct = (pred_galaxy.eq(true_galaxy) * is_on).sum()
        acc = n_correct / batch["n_sources"].sum()

        # an "on" source that is not predicted to be a galaxy is a star
        pred_star = (1 - pred_galaxy) * is_on

        return {
            "loss": loss,
            "acc": acc,
            "galaxy_bools": pred_galaxy,
            "star_bools": pred_star,
            "galaxy_probs": galaxy_probs,
        }
Пример #8
0
    def validation_epoch_end(self,
                             outputs,
                             kind="validation",
                             max_n_samples=16):
        # pylint: disable=too-many-statements
        """Pytorch lightning method: plot true vs. estimated sources for the last batch.

        Nothing is done when the encoder has more than one band. Otherwise a
        square grid of images from the last element of ``outputs`` is drawn, with
        true galaxies, true stars and predicted sources overlaid, and the figure
        is sent to the logger (if any) before being closed.

        Arguments:
            outputs: Sequence of batches; only the last one is plotted.
            kind: Either "validation" or "testing"; selects the figure title.
            max_n_samples: Upper bound on the number of images considered.

        Raises:
            NotImplementedError: If a logger is set and ``kind`` is neither
                "validation" nor "testing".
        """
        batch: Dict[str, Tensor] = outputs[-1]
        if self.n_bands > 1:
            return

        # largest perfect square not exceeding the batch size, capped by the limit
        n_square = int(math.sqrt(len(batch["n_sources"])))**2
        n_samples = min(n_square, max_n_samples)
        nrows = int(n_samples**0.5)  # for figure

        # ground truth, keeping at most `max_detections` sources per tile
        max_det = self.max_detections
        catalog_dict = {
            key: batch[key][:, :, :, 0:max_det]
            for key in ("locs", "log_fluxes", "galaxy_bools", "star_bools")
        }
        catalog_dict["n_sources"] = batch["n_sources"].clamp(max=max_det)
        true_tile_catalog = TileCatalog(self.tile_slen, catalog_dict)
        true_cat = true_tile_catalog.to_full_params()

        # prediction: encode the padded tiles and take the variational mode
        stacked = torch.cat((batch["images"], batch["background"]), dim=1)
        image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        dist_params = self.encode(image_ptiles)

        est_catalog_dict = self.variational_mode(dist_params)
        est_tile_catalog = TileCatalog.from_flat_dict(
            true_tile_catalog.tile_slen,
            true_tile_catalog.n_tiles_h,
            true_tile_catalog.n_tiles_w,
            est_catalog_dict,
        )
        est_cat = est_tile_catalog.to_full_params()

        # setup figure and axes.
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=(12, 12))
        axes = axes.flatten() if nrows > 1 else [axes]

        # (catalog, source type, color, marker, marker size, legend label)
        overlays = (
            (true_cat, "galaxy", "r", "x", 20, "t.gal"),
            (true_cat, "star", "c", "x", 20, "t.star"),
            (est_cat, "all", "b", "+", 30, "p.source"),
        )

        images = batch["images"]
        assert images.shape[-2] == images.shape[-1]
        bp = self.border_padding
        for idx, ax in enumerate(axes):
            n_true = true_cat.n_sources[idx].item()
            n_est = est_cat.n_sources[idx].item()
            ax.set_xlabel(f"True num: {n_true}; Est num: {n_est}")

            # add white border showing where centers of stars and galaxies can be
            for pos in (bp, images.shape[-1] - bp):
                ax.axvline(pos, color="w")
            for pos in (bp, images.shape[-2] - bp):
                ax.axhline(pos, color="w")

            # plot image first, with a colorbar spanning its own value range
            image = images[idx, 0].cpu().numpy()
            lo, hi = image.min().item(), image.max().item()
            cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
            im = ax.matshow(image, vmin=lo, vmax=hi, cmap="viridis")
            fig.colorbar(im, cax=cax, orientation="vertical")

            for cat, src_type, color, marker, size, _ in overlays:
                cat.plot_plocs(ax, idx, src_type, bp=bp, color=color, marker=marker, s=size)

            if idx == 0:
                # empty scatters exist only to create the legend entries
                for _, _, color, marker, size, label in overlays:
                    ax.scatter(None, None, color=color, marker=marker, s=size, label=label)
                ax.legend(
                    bbox_to_anchor=(0.0, 1.2, 1.0, 0.102),
                    loc="lower left",
                    ncol=2,
                    mode="expand",
                    borderaxespad=0.0,
                )

        fig.tight_layout()
        if self.logger:
            if kind not in ("validation", "testing"):
                raise NotImplementedError()
            if kind == "validation":
                title = f"Epoch:{self.current_epoch}/Validation Images"
            else:
                title = "Test Images"
            self.logger.experiment.add_figure(title, fig)
        plt.close(fig)
Пример #9
0
    def _get_loss(self, batch: Dict[str, Tensor]):
        """Compute the detection-encoder losses for one batch.

        The true catalog is truncated to ``self.max_detections`` sources per tile
        and flattened over tiles. The total loss is the mean over padded tiles of
        the counter loss plus the location and star-parameter losses returned by
        ``_get_min_perm_loss``; location losses of 1e6 or more are masked out.

        Arguments:
            batch: Dictionary with "images", "background", "locs", "log_fluxes",
                "galaxy_bools" and "n_sources".

        Returns:
            Dictionary with the scalar "loss" and the per-tile "counter_loss",
            "locs_loss" and "star_params_loss".
        """
        max_det = self.max_detections

        # flatten the (image, tile row, tile column) dimensions of the truth
        flatten_patterns = (
            ("locs", "n nth ntw ns hw -> (n nth ntw) ns hw"),
            ("log_fluxes", "n nth ntw ns b -> (n nth ntw) ns b"),
            ("galaxy_bools", "n nth ntw ns 1 -> (n nth ntw) ns 1"),
        )
        true_catalog = {
            key: rearrange(batch[key][:, :, :, 0:max_det], pattern)
            for key, pattern in flatten_patterns
        }
        true_catalog["n_sources"] = rearrange(
            batch["n_sources"].clamp(max=max_det),
            "n nth ntw -> (n nth ntw)",
        )
        true_catalog["is_on_array"] = get_is_on_from_n_sources(
            true_catalog["n_sources"], max_det)

        stacked = torch.cat((batch["images"], batch["background"]), dim=1)
        image_ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        dist_params = self.encode(image_ptiles)

        # negative log-probability of the true source count in each padded tile
        count_log_probs = rearrange(dist_params["n_source_log_probs"],
                                    "n_ptiles ns -> n_ptiles ns")
        counter_loss = F.nll_loss(count_log_probs,
                                  true_catalog["n_sources"].reshape(-1),
                                  reduction="none")

        # per-source parameters conditioned on the true number of sources
        pred = self._encode_for_n_sources(
            dist_params["per_source_params"],
            rearrange(true_catalog["n_sources"], "n_ptiles -> 1 n_ptiles"),
        )
        locs_log_probs_all = _get_params_logprob_all_combs(
            true_catalog["locs"],
            pred["loc_mean"].squeeze(0),
            pred["loc_sd"].squeeze(0),
        )
        star_params_log_probs_all = _get_params_logprob_all_combs(
            true_catalog["log_fluxes"],
            pred["log_flux_mean"].squeeze(0),
            pred["log_flux_sd"].squeeze(0),
        )

        is_galaxy = rearrange(true_catalog["galaxy_bools"],
                              "n_ptiles ns 1 -> n_ptiles ns")
        locs_loss, star_params_loss = _get_min_perm_loss(
            locs_log_probs_all,
            star_params_log_probs_all,
            is_galaxy,
            true_catalog["is_on_array"],
        )

        # drop location losses of 1e6 or more from the total
        keep_locs = (locs_loss.detach() < 1e6).float()
        loss_vec = locs_loss * keep_locs + counter_loss + star_params_loss

        return {
            "loss": loss_vec.mean(),
            "counter_loss": counter_loss,
            "locs_loss": locs_loss,
            "star_params_loss": star_params_loss,
        }
Пример #10
0
 def get_images_in_ptiles(self, images):
     """Cut ``images`` into padded tiles using the detection encoder's sizes.

     Thin wrapper around ``get_images_in_tiles`` that supplies the detection
     encoder's ``tile_slen`` and ``ptile_slen``.
     """
     tile_slen = self.detection_encoder.tile_slen
     ptile_slen = self.detection_encoder.ptile_slen
     return get_images_in_tiles(images, tile_slen, ptile_slen)
Пример #11
0
    def _make_plots(self, batch, n_samples=5):
        """Plot truth, reconstruction and residual panels for a random subset of ``batch``.

        Up to ``n_samples`` examples are drawn at random without replacement and
        re-rendered with the image decoder from the true catalog plus the encoded
        galaxy parameters. They are shown one per row, ordered from largest to
        smallest mean absolute normalized residual. The figure is sent to the
        logger (if any) and then closed.

        Arguments:
            batch: Dictionary with at least "images", "background", "locs",
                "galaxy_bools", "star_bools", "fluxes", "log_fluxes" and
                "n_sources". It is not modified.
            n_samples: Maximum number of examples to plot.
        """
        n_samples = min(len(batch["n_sources"]), n_samples)
        samples = np.random.choice(len(batch["n_sources"]), n_samples, replace=False)
        keys = [
            "images",
            "background",
            "locs",
            "galaxy_bools",
            "star_bools",
            "fluxes",
            "log_fluxes",
            "n_sources",
        ]
        # Subsample into a local dict so the caller's batch is left untouched.
        sub = {k: batch[k][samples] for k in keys}

        images = sub["images"]
        background = sub["background"]
        tile_locs = sub["locs"]

        # encode galaxy parameters at the true locations
        image_ptiles = get_images_in_tiles(
            torch.cat((images, background), dim=1),
            self.tile_slen,
            self.ptile_slen,
        )
        _, n_tiles_h, n_tiles_w, _, _, _ = image_ptiles.shape
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        locs = rearrange(tile_locs, "n nth ntw ns hw -> 1 (n nth ntw) ns hw")
        z, _ = self._encode(image_ptiles, locs)
        galaxy_params = rearrange(
            z,
            "ns (n nth ntw) ms d -> (ns n) nth ntw ms d",
            ns=1,
            nth=n_tiles_h,
            ntw=n_tiles_w,
        )

        tile_est = TileCatalog(
            self.tile_slen,
            {
                "n_sources": sub["n_sources"],
                "locs": sub["locs"],
                "galaxy_bools": sub["galaxy_bools"],
                "star_bools": sub["star_bools"],
                "fluxes": sub["fluxes"],
                "log_fluxes": sub["log_fluxes"],
                "galaxy_params": galaxy_params,
            },
        )
        est = tile_est.to_full_params()

        # draw all reconstruction images.
        # render_images automatically accounts for tiles with no galaxies.
        recon_images = self.image_decoder.render_images(tile_est)
        recon_images += background
        residuals = (images - recon_images) / torch.sqrt(recon_images)

        # order examples by absolute avg. residual error, worst first.
        worst_indices = residuals.abs().mean(dim=(1, 2, 3)).argsort(descending=True)[:n_samples]

        if self.crop_loss_at_border:
            crop = self.border_padding * 2
            # Guard against crop == 0: the slice `-0:` selects every row/column
            # and would zero the entire residual image.
            if crop > 0:
                residuals[:, :, :crop, :] = 0.0
                residuals[:, :, -crop:, :] = 0.0
                residuals[:, :, :, :crop] = 0.0
                residuals[:, :, :, -crop:] = 0.0

        figsize = (12, 4 * n_samples)
        fig, axes = plt.subplots(nrows=n_samples, ncols=3, figsize=figsize, squeeze=False)

        labels = ("t. gal", None, "t. star", None)
        # Plotting only works on square images
        assert images.shape[-2] == images.shape[-1]
        bp = self.border_padding
        slen = images.shape[-1] - 2 * bp

        # positions of the white guide lines marking the cropped region
        crop_lines = ()
        if self.crop_loss_at_border:
            inner_bp = (recon_images.shape[-1] - slen) // 2
            eff_slen = slen - inner_bp
            crop_lines = tuple(
                pos for b in (inner_bp, inner_bp * 2) for pos in (b, b + eff_slen)
            )

        for i, idx in enumerate(worst_indices.tolist()):
            true_ax, recon_ax, res_ax = axes[i, 0], axes[i, 1], axes[i, 2]

            # add titles to axes in the first row
            if i == 0:
                true_ax.set_title("Truth", size=18)
                recon_ax.set_title("Reconstruction", size=18)
                res_ax.set_title("Residual", size=18)

            # vmin, vmax should be shared between reconstruction and true images.
            if self.max_flux_valid_plots is None:
                vmax = np.ceil(max(images[idx].max().item(), recon_images[idx].max().item()))
            else:
                vmax = self.max_flux_valid_plots
            vmin = np.floor(min(images[idx].min().item(), recon_images[idx].min().item()))
            vrange = (vmin, vmax)

            # Index every per-example quantity with `idx` (the ranked example),
            # not with the row counter `i`, so all three panels of a row show
            # the same example.
            image = images[idx, 0].cpu().numpy()
            recon_image = recon_images[idx, 0].detach().cpu().numpy()
            plocs = est.plocs[idx].cpu().numpy().reshape(-1, 2)
            probs = est["galaxy_bools"][idx].cpu().numpy().reshape(-1)

            plot_image(fig, true_ax, image, vrange=vrange, colorbar=True)
            plot_locs(true_ax, bp, slen, plocs, probs, cmap="cool")
            # the reconstruction panel shows the rendered image, not the truth
            plot_image(fig, recon_ax, recon_image, vrange=vrange, colorbar=True)
            plot_locs(recon_ax, bp, slen, plocs, probs, cmap="cool")
            for pos in crop_lines:
                recon_ax.axvline(pos, color="w")
                recon_ax.axhline(pos, color="w")

            residuals_idx = residuals[idx, 0].cpu().numpy()
            res_vmax = np.ceil(residuals_idx.max())
            res_vmin = np.floor(residuals_idx.min())
            plot_image(fig, res_ax, residuals_idx, vrange=(res_vmin, res_vmax))
            for pos in crop_lines:
                res_ax.axvline(pos, color="w")
                res_ax.axhline(pos, color="w")

            if i == 0:
                add_legend(true_ax, labels)

        fig.tight_layout()
        if self.logger:
            self.logger.experiment.add_figure(
                f"Epoch:{self.current_epoch}/Worst Validation Images", fig)
        plt.close(fig)