Example #1
0
File: test_sdss.py, Project: prob-ml/bliss
    def test_sdss(self, devices, get_sdss_galaxies_config):
        """Smoke-test SloanDigitalSkySurvey on run 3900, camcol 6.

        Checks that a single-field survey yields numpy arrays for the
        per-frame quantities, the expected field number and band-3 gain, and a
        list of WCS objects; then checks a two-field survey has length 2.
        """
        config = get_sdss_galaxies_config({}, devices)
        data_dir = config.paths.sdss

        single_field = SloanDigitalSkySurvey(
            data_dir, run=3900, camcol=6, fields=[269], bands=range(5)
        )
        frame = single_field[0]
        array_keys = (
            "image",
            "background",
            "gain",
            "nelec_per_nmgy_list",
            "calibration",
        )
        for key in array_keys:
            assert isinstance(frame[key], np.ndarray)

        assert frame["field"] == 269
        assert frame["gain"][3] == pytest.approx(4.76)
        assert isinstance(frame["wcs"], list)

        two_fields = SloanDigitalSkySurvey(
            data_dir, run=3900, camcol=6, fields=[269, 745], bands=range(5)
        )
        assert len(two_fields) == 2
Example #2
0
File: inference.py, Project: prob-ml/bliss
    def __init__(self, sdss_dir: str, pixel_scale: float, coadd_file: str):
        """Load one fixed SDSS frame and mask two pixel regions in it.

        The frame is run 94, camcol 1, field 12 with ``bands=(2, )``. Its WCS
        and the pixel scale are kept on the instance. The first entry of the
        image and background get two leading size-1 axes and are passed
        through ``apply_mask``.

        Args:
            sdss_dir: Root directory of the SDSS data.
            pixel_scale: Pixel scale stored unchanged on ``self.pixel_scale``.
            coadd_file: Coadd file path stored unchanged on ``self.coadd_file``.
        """
        survey = SloanDigitalSkySurvey(
            sdss_dir=sdss_dir,
            run=94,
            camcol=1,
            fields=(12, ),
            bands=(2, ),
        )
        self.data = survey[0]
        self.wcs = self.data["wcs"][0]
        self.pixel_scale = pixel_scale

        # Two leading size-1 axes are prepended (presumably batch and channel).
        raw_image = torch.from_numpy(self.data["image"][0])[None, None]
        raw_background = torch.from_numpy(self.data["background"][0])[None, None]

        # Pixel boxes handed to apply_mask; their coordinate order is whatever
        # apply_mask expects (not visible here).
        masked_regions = ((1200, 1360, 1700, 1900), (280, 400, 1220, 1320))
        self.image, self.background = apply_mask(
            raw_image,
            raw_background,
            regions=masked_regions,
            mask_bg_val=865.0,
        )
        self.coadd_file = coadd_file
Example #3
0
File: catalog.py, Project: prob-ml/bliss
    def from_file(cls, sdss_path, run, camcol, field, band):
        """Build a catalog from the SDSS photoObj table of a single field.

        Reads ``photoObj-{run:06d}-{camcol:d}-{field:04d}.fits`` from
        ``sdss_path/run/camcol/field``. It keeps objects whose ``objc_type``
        is 3 (treated as galaxies) or 6 (treated as stars) and whose
        ``thing_id`` is not -1. Their (ra, dec) are converted to pixel
        locations with the WCS of the matching SDSS frame.

        Args:
            sdss_path: Root directory of the SDSS data (str or path-like).
            run: SDSS run number.
            camcol: SDSS camcol.
            field: SDSS field number.
            band: Index of the band column kept for fluxes/magnitudes; also the
                only band loaded to obtain the WCS and the image size.

        Returns:
            ``cls(height, width, d)`` where ``d`` maps ``plocs``,
            ``n_sources``, ``galaxy_bools``, ``star_bools``, ``fluxes`` and
            ``mags`` to tensors with a leading batch dimension of size 1.
        """
        sdss_path = Path(sdss_path)
        field_dir = sdss_path / str(run) / str(camcol) / str(field)
        po_path = field_dir / f"photoObj-{run:06d}-{camcol:d}-{field:04d}.fits"
        po_fits = fits.getdata(po_path)
        objc_type = column_to_tensor(po_fits, "objc_type")
        thing_id = column_to_tensor(po_fits, "thing_id")
        ras = column_to_tensor(po_fits, "ra")
        decs = column_to_tensor(po_fits, "dec")

        # Objects without a thing_id (== -1) are dropped from both classes.
        has_thing_id = thing_id != -1
        galaxy_bools = (objc_type == 3) & has_thing_id
        star_bools = (objc_type == 6) & has_thing_id

        # Stars take PSF photometry and galaxies take cmodel photometry. Each
        # column is zeroed outside its class, so the sum picks one value per
        # object; objects in neither class are removed by `keep` below.
        star_mask = star_bools.reshape(-1, 1)
        galaxy_mask = galaxy_bools.reshape(-1, 1)
        star_fluxes = column_to_tensor(po_fits, "psfflux") * star_mask
        star_mags = column_to_tensor(po_fits, "psfmag") * star_mask
        galaxy_fluxes = column_to_tensor(po_fits, "cmodelflux") * galaxy_mask
        galaxy_mags = column_to_tensor(po_fits, "cmodelmag") * galaxy_mask
        fluxes = star_fluxes + galaxy_fluxes
        mags = star_mags + galaxy_mags

        keep = galaxy_bools | star_bools
        galaxy_bools = galaxy_bools[keep]
        star_bools = star_bools[keep]
        ras = ras[keep]
        decs = decs[keep]
        fluxes = fluxes[keep][:, band]
        mags = mags[keep][:, band]

        sdss = SloanDigitalSkySurvey(sdss_path,
                                     run,
                                     camcol,
                                     fields=(field, ),
                                     bands=(band, ))
        # Index the survey once; the frame supplies both the WCS and image size.
        frame = sdss[0]
        wcs: WCS = frame["wcs"][0]

        # One vectorized WCS call for all objects instead of a per-object
        # Python loop (origin=0 -> 0-based pixel coordinates). astropy returns
        # one array per pixel axis, first axis first.
        xs, ys = wcs.wcs_world2pix(ras, decs, 0)
        # Match the dtype the former list-of-floats construction produced.
        dtype = torch.get_default_dtype()
        pts = torch.tensor(xs, dtype=dtype) + 0.5  # For consistency with BLISS
        prs = torch.tensor(ys, dtype=dtype) + 0.5
        plocs = torch.stack((prs, pts), dim=-1)
        nobj = plocs.shape[0]

        d = {
            "plocs": plocs.reshape(1, nobj, 2),
            "n_sources": torch.tensor((nobj, )),
            "galaxy_bools": galaxy_bools.reshape(1, nobj, 1).float(),
            "star_bools": star_bools.reshape(1, nobj, 1).float(),
            "fluxes": fluxes.reshape(1, nobj, 1),
            "mags": mags.reshape(1, nobj, 1),
        }

        height = frame["image"].shape[1]
        width = frame["image"].shape[2]

        return cls(height, width, d)
Example #4
0
File: background.py, Project: prob-ml/bliss
 def __init__(self, sdss_dir, run, camcol, field, bands):
     """Load one SDSS field's background and register it as a buffer.

     Args:
         sdss_dir: Root directory of the SDSS data.
         run: SDSS run.
         camcol: SDSS camcol.
         field: SDSS field.
         bands: Bands to load; the background must have one channel per band.
     """
     super().__init__()
     survey = SloanDigitalSkySurvey(
         sdss_dir=sdss_dir,
         run=run,
         camcol=camcol,
         fields=(field, ),
         bands=bands,
     )
     frame_background = survey[0]["background"]
     # Prepend a batch axis; einops also checks the channel count is len(bands).
     batched = rearrange(
         torch.from_numpy(frame_background),
         "c h w -> 1 c h w",
         c=len(bands),
     )
     # persistent=False keeps the buffer out of the module's state_dict.
     self.register_buffer("background", batched, persistent=False)
     *_, self.height, self.width = self.background.shape
Example #5
0
def get_sdss_data(sdss_dir, sdss_pixel_scale):
    """Load the SDSS frame for run 94, camcol 1, field 12 with ``bands=(2, )``.

    Args:
        sdss_dir: Root directory of the SDSS data.
        sdss_pixel_scale: Pixel scale, passed through unchanged in the result.

    Returns:
        Dict with the first entry of the frame's ``image``, ``background``
        and ``wcs`` lists, plus ``pixel_scale``.
    """
    run = 94
    camcol = 1
    field = 12
    bands = (2, )
    sdss_data = SloanDigitalSkySurvey(
        sdss_dir=sdss_dir,
        run=run,
        camcol=camcol,
        fields=(field, ),
        bands=bands,
    )
    # Index the survey once and reuse the frame rather than re-indexing for
    # every key (indexing may load the frame from disk).
    frame = sdss_data[0]

    return {
        "image": frame["image"][0],
        "background": frame["background"][0],
        "wcs": frame["wcs"][0],
        "pixel_scale": sdss_pixel_scale,
    }
Example #6
0
    def __init__(
        self,
        detection_encoder: DetectionEncoder,
        binary_encoder: BinaryEncoder,
        location_ckpt: str,
        binary_ckpt: str,
        sdss_dir: str = "data/sdss",
        run: int = 94,
        camcol: int = 1,
        field: int = 12,
        bands: Tuple[int, ...] = (2, ),
        bp: int = 24,
        slen: int = 80,
        h_start: Optional[int] = None,
        w_start: Optional[int] = None,
        scene_size: Optional[int] = None,
        stride_factor: float = 0.5,
        prerender_device: str = "cpu",
    ) -> None:
        """Initializes SDSSBlendedGalaxies.

        Loads one SDSS frame and crops a square scene (plus ``bp`` pixels of
        border on every side) out of BOTH the image and its background, so the
        two stay pixel-aligned. It then loads the encoder weights and
        prerenders the chunks.

        Args:
            detection_encoder: A DetectionEncoder model.
            binary_encoder: A BinaryEncoder model.
            location_ckpt: Path of saved state_dict for location encoder.
            binary_ckpt: Path of saved state_dict for binary encoder.
            sdss_dir: Location of data storage for SDSS. Defaults to "data/sdss".
            run: SDSS run.
            camcol: SDSS camcol.
            field: SDSS field.
            bands: SDSS bands of image to use.
            bp: How much border padding around each chunk.
            slen: Side-length of each chunk (stored as ``slen + 2 * bp``).
            h_start: Starting height-point of image. If None, start at `bp`.
            w_start: Starting width-point of image. If None, start at `bp`.
            scene_size: Total size of the scene to use. If None, use maximum possible size.
            stride_factor: How much should chunks overlap? If 1.0, no overlap.
            prerender_device: Device to use to prerender chunks.
        """
        super().__init__()
        sdss_data = SloanDigitalSkySurvey(
            sdss_dir=sdss_dir,
            run=run,
            camcol=camcol,
            fields=(field, ),
            bands=bands,
        )
        frame = sdss_data[0]
        image = torch.from_numpy(frame["image"][0])
        image = rearrange(image, "h w -> 1 1 h w")
        background = torch.from_numpy(frame["background"][0])
        background = rearrange(background, "h w -> 1 1 h w")

        self.bp = bp
        # NOTE(review): self.slen already includes the border on both sides and
        # kernel_size adds it again (slen + 4 * bp in total). Kept unchanged;
        # confirm this matches what _prerender_chunks expects.
        self.slen = slen + 2 * bp
        self.kernel_size = self.slen + 2 * self.bp
        self.stride = int(self.slen * stride_factor)
        assert self.stride > 0
        self.prerender_device = prerender_device

        if h_start is None:
            h_start = self.bp
        if w_start is None:
            w_start = self.bp
        assert h_start >= self.bp
        assert w_start >= self.bp

        if scene_size is None:
            # Largest square scene whose bordered window still fits in the image.
            scene_size = min(image.shape[2] - h_start,
                             image.shape[3] - w_start) - self.bp

        # Crop image AND background with the same window (scene plus `bp` border
        # on each side). Previously only the image was cropped, so the two
        # tensors handed to _prerender_chunks differed in shape and origin.
        rows = slice(h_start - self.bp, h_start + scene_size + self.bp)
        cols = slice(w_start - self.bp, w_start + scene_size + self.bp)
        image = image[:, :, rows, cols]
        background = background[:, :, rows, cols]

        detection_encoder.load_state_dict(
            torch.load(location_ckpt, map_location=torch.device("cpu")))
        binary_encoder.load_state_dict(
            torch.load(binary_ckpt, map_location=torch.device("cpu")))
        self.encoder = Encoder(detection_encoder.eval(), binary_encoder.eval())
        self.chunks, self.bgs, self.catalogs = self._prerender_chunks(
            image, background)
Example #7
0
File: inference.py, Project: prob-ml/bliss
    def __init__(
        self,
        dataset: SimulatedDataset,
        coadd: str,
        n_tiles_h,
        n_tiles_w,
        cache_dir=None,
    ):
        """Simulate a frame from a coadd catalog, or load it from a cache.

        Uses the WCS of SDSS run 94, camcol 1, field 12 (``bands=(2, )``), read
        from the directory containing ``coadd``. It either loads
        ``(tile catalog dict, image, background)`` from
        ``cache_dir/simulated_frame.pt``, or builds a tile catalog from the
        coadd file and simulates the image with ``dataset``. In the latter
        case it writes the cache when ``cache_dir`` is given.

        Args:
            dataset: Simulated dataset providing the tile size, image decoder,
                image prior and ``simulate_image_from_catalog``; it is moved to
                CPU.
            coadd: Path of the coadd catalog file; its parent directory is used
                as the SDSS data directory.
            n_tiles_h: Number of tiles along the height; sets the upper bound of
                ``hlim``, presumably the pixel range of rows to keep.
            n_tiles_w: Number of tiles along the width; sets the upper bound of
                ``wlim``, presumably the pixel range of columns to keep.
            cache_dir: Optional directory holding ``simulated_frame.pt``.
        """
        dataset.to("cpu")
        self.bp = dataset.image_decoder.border_padding
        self.tile_slen = dataset.tile_slen
        self.coadd_file = Path(coadd)
        # SDSS frame files are expected next to the coadd catalog.
        self.sdss_dir = self.coadd_file.parent
        run = 94
        camcol = 1
        field = 12
        bands = (2, )
        sdss_data = SloanDigitalSkySurvey(
            sdss_dir=self.sdss_dir,
            run=run,
            camcol=camcol,
            fields=(field, ),
            bands=bands,
        )
        # Only the WCS of the real frame is used; the pixels are simulated.
        wcs = sdss_data[0]["wcs"][0]
        if cache_dir is not None:
            sim_frame_path = Path(cache_dir) / "simulated_frame.pt"
        else:
            sim_frame_path = None
        if sim_frame_path and sim_frame_path.exists():
            # NOTE: torch.load unpickles the file; only point cache_dir at
            # trusted locations.
            tile_catalog_dict, image, background = torch.load(sim_frame_path)
            tile_catalog = TileCatalog(self.tile_slen, tile_catalog_dict)
        else:
            # Pixel limits start after the border padding and span the
            # requested number of tiles.
            hlim = (self.bp, self.bp + n_tiles_h * self.tile_slen)
            wlim = (self.bp, self.bp + n_tiles_w * self.tile_slen)
            full_coadd_cat = CoaddFullCatalog.from_file(coadd,
                                                        wcs,
                                                        hlim,
                                                        wlim,
                                                        band="r")
            # Galaxy shape parameters come from the prior, one per source.
            if dataset.image_prior.galaxy_prior is not None:
                full_coadd_cat[
                    "galaxy_params"] = dataset.image_prior.galaxy_prior.sample(
                        full_coadd_cat.n_sources, "cpu").unsqueeze(0)
            full_coadd_cat.plocs = full_coadd_cat.plocs + 0.5
            max_sources = dataset.image_prior.max_sources
            tile_catalog = full_coadd_cat.to_tile_params(
                self.tile_slen, max_sources)
            tile_catalog.set_all_fluxes_and_mags(dataset.image_decoder)
            # Sanity check: converting back to a full catalog reproduces the
            # original, except for the flux/mag/star fields recomputed above.
            fc = tile_catalog.to_full_params()
            assert fc.equals(full_coadd_cat,
                             exclude=("galaxy_fluxes", "star_bools", "fluxes",
                                      "mags"))
            print("INFO: started generating frame")
            image, background = dataset.simulate_image_from_catalog(
                tile_catalog)
            print("INFO: done generating frame")
            if sim_frame_path:
                torch.save((tile_catalog.to_dict(), image, background),
                           sim_frame_path)

        self.tile_catalog = tile_catalog
        self.image = image
        self.background = background
        # A single simulated frame (leading dimension of size 1) is expected.
        assert self.image.shape[0] == 1
        assert self.background.shape[0] == 1