def test_sdss(self, devices, get_sdss_galaxies_config):
    """Check SDSS loading for run 3900 / camcol 6.

    A single field (269) must expose array-valued image/background/gain/
    calibration entries, the expected field id and gain, and a list of WCS
    objects. Requesting two fields must yield a dataset of length 2.
    """
    config = get_sdss_galaxies_config({}, devices)
    root = config.paths.sdss

    single_field = SloanDigitalSkySurvey(
        root,
        run=3900,
        camcol=6,
        fields=[269],
        bands=range(5),
    )
    frame = single_field[0]

    array_keys = ("image", "background", "gain", "nelec_per_nmgy_list", "calibration")
    for key in array_keys:
        assert isinstance(frame[key], np.ndarray)
    assert frame["field"] == 269
    assert frame["gain"][3] == pytest.approx(4.76)
    assert isinstance(frame["wcs"], list)

    two_fields = SloanDigitalSkySurvey(
        root,
        run=3900,
        camcol=6,
        fields=[269, 745],
        bands=range(5),
    )
    assert len(two_fields) == 2
def __init__(self, sdss_dir: str, pixel_scale: float, coadd_file: str):
    """Load SDSS run 94 / camcol 1 / field 12 (band index 2) and mask parts of it.

    Args:
        sdss_dir: Directory passed to ``SloanDigitalSkySurvey``.
        pixel_scale: Stored unchanged on ``self.pixel_scale``.
        coadd_file: Stored unchanged on ``self.coadd_file``.
    """
    survey = SloanDigitalSkySurvey(
        sdss_dir=sdss_dir,
        run=94,
        camcol=1,
        fields=(12,),
        bands=(2,),
    )
    self.data = survey[0]
    self.wcs = self.data["wcs"][0]
    self.pixel_scale = pixel_scale

    # Prepend two singleton dimensions to the first (only) band's arrays.
    raw_image = torch.from_numpy(self.data["image"][0])[None, None]
    raw_background = torch.from_numpy(self.data["background"][0])[None, None]

    # NOTE(review): these regions are hard-coded for this specific field; the
    # meaning/order of each 4-tuple is defined by apply_mask — confirm there.
    masked_regions = ((1200, 1360, 1700, 1900), (280, 400, 1220, 1320))
    self.image, self.background = apply_mask(
        raw_image,
        raw_background,
        regions=masked_regions,
        mask_bg_val=865.0,
    )
    self.coadd_file = coadd_file
def from_file(cls, sdss_path, run, camcol, field, band):
    """Build a catalog of stars and galaxies from an SDSS ``photoObj`` file.

    Selection and photometry:
        * Rows with ``objc_type == 3`` are kept as galaxies.
        * Rows with ``objc_type == 6`` are kept as stars.
        * Rows with ``thing_id == -1``, and all other object types, are dropped.
        * Star fluxes/magnitudes come from ``psfflux``/``psfmag``.
        * Galaxy fluxes/magnitudes come from ``cmodelflux``/``cmodelmag``.
        * In both cases only column ``band`` is kept.

    Args:
        sdss_path: Root SDSS directory. The photoObj file is read from
            ``<sdss_path>/<run>/<camcol>/<field>/``.
        run: SDSS run.
        camcol: SDSS camcol.
        field: SDSS field.
        band: Band index. It selects the flux/mag column and the frame whose
            WCS and image shape are used.

    Returns:
        ``cls(height, width, d)``:
            * ``height``/``width`` are dims 1 and 2 of the frame's ``"image"``.
            * ``d["plocs"]`` has shape (1, n, 2) and is stacked as (second,
              first) output of ``wcs_world2pix`` plus 0.5 — presumably
              (row, col).
            * ``d["n_sources"]`` holds n.
            * ``galaxy_bools``/``star_bools`` (as float), ``fluxes`` and
              ``mags`` each have shape (1, n, 1).
    """
    sdss_path = Path(sdss_path)
    camcol_dir = sdss_path / str(run) / str(camcol) / str(field)
    po_path = camcol_dir / f"photoObj-{run:06d}-{camcol:d}-{field:04d}.fits"
    po_fits = fits.getdata(po_path)

    objc_type = column_to_tensor(po_fits, "objc_type")
    thing_id = column_to_tensor(po_fits, "thing_id")
    ras = column_to_tensor(po_fits, "ra")
    decs = column_to_tensor(po_fits, "dec")

    galaxy_bools = (objc_type == 3) & (thing_id != -1)
    star_bools = (objc_type == 6) & (thing_id != -1)

    # Zero the photometry that does not apply to each row, then merge the star
    # (PSF) and galaxy (cmodel) columns into one flux table and one mag table.
    star_mask = star_bools.reshape(-1, 1)
    galaxy_mask = galaxy_bools.reshape(-1, 1)
    fluxes = (column_to_tensor(po_fits, "psfflux") * star_mask
              + column_to_tensor(po_fits, "cmodelflux") * galaxy_mask)
    mags = (column_to_tensor(po_fits, "psfmag") * star_mask
            + column_to_tensor(po_fits, "cmodelmag") * galaxy_mask)

    keep = galaxy_bools | star_bools
    galaxy_bools = galaxy_bools[keep]
    star_bools = star_bools[keep]
    ras = ras[keep]
    decs = decs[keep]
    fluxes = fluxes[keep][:, band]
    mags = mags[keep][:, band]

    sdss = SloanDigitalSkySurvey(sdss_path, run, camcol, fields=(field, ), bands=(band, ))
    # Index the dataset once and reuse the frame for both the WCS and the image
    # shape, rather than calling sdss[0] repeatedly.
    frame = sdss[0]
    wcs: WCS = frame["wcs"][0]

    pts = []
    prs = []
    for ra, dec in zip(ras, decs):
        pt, pr = wcs.wcs_world2pix(ra, dec, 0)
        pts.append(float(pt))
        prs.append(float(pr))
    pts = torch.tensor(pts) + 0.5  # For consistency with BLISS
    prs = torch.tensor(prs) + 0.5
    plocs = torch.stack((prs, pts), dim=-1)
    nobj = plocs.shape[0]

    d = {
        "plocs": plocs.reshape(1, nobj, 2),
        "n_sources": torch.tensor((nobj, )),
        "galaxy_bools": galaxy_bools.reshape(1, nobj, 1).float(),
        "star_bools": star_bools.reshape(1, nobj, 1).float(),
        "fluxes": fluxes.reshape(1, nobj, 1),
        "mags": mags.reshape(1, nobj, 1),
    }
    height = frame["image"].shape[1]
    width = frame["image"].shape[2]
    return cls(height, width, d)
def __init__(self, sdss_dir, run, camcol, field, bands):
    """Register the SDSS background for one field as a non-persistent buffer.

    Args:
        sdss_dir: Directory passed to ``SloanDigitalSkySurvey``.
        run: SDSS run.
        camcol: SDSS camcol.
        field: SDSS field (a single field is loaded).
        bands: Band indices to load; the background is reshaped to
            (1, len(bands), h, w).

    Sets ``self.height`` and ``self.width`` from the background's last two dims.
    """
    super().__init__()
    survey = SloanDigitalSkySurvey(
        sdss_dir=sdss_dir,
        run=run,
        camcol=camcol,
        fields=(field,),
        bands=bands,
    )
    bg = rearrange(
        torch.from_numpy(survey[0]["background"]),
        "c h w -> 1 c h w",
        c=len(bands),
    )
    # persistent=False keeps the buffer out of the module's state_dict.
    self.register_buffer("background", bg, persistent=False)
    _, _, self.height, self.width = bg.shape
def get_sdss_data(sdss_dir, sdss_pixel_scale, *, run=94, camcol=1, field=12, band=2):
    """Load one band of a single SDSS field.

    The defaults (run 94, camcol 1, field 12, band index 2) reproduce the frame
    this helper has always loaded. Callers may select another frame through the
    keyword-only arguments.

    Args:
        sdss_dir: Directory passed to ``SloanDigitalSkySurvey``.
        sdss_pixel_scale: Returned unchanged under ``"pixel_scale"``.
        run: SDSS run (keyword-only).
        camcol: SDSS camcol (keyword-only).
        field: SDSS field (keyword-only).
        band: Band index to load (keyword-only).

    Returns:
        Dict with the band's ``"image"``, ``"background"`` and ``"wcs"``
        entries, plus ``"pixel_scale"``.
    """
    sdss_data = SloanDigitalSkySurvey(
        sdss_dir=sdss_dir,
        run=run,
        camcol=camcol,
        fields=(field,),
        bands=(band,),
    )
    # Index once and reuse the frame for every key; each __getitem__ call may
    # reload the frame from disk.
    frame = sdss_data[0]
    return {
        "image": frame["image"][0],
        "background": frame["background"][0],
        "wcs": frame["wcs"][0],
        "pixel_scale": sdss_pixel_scale,
    }
def __init__(
    self,
    detection_encoder: DetectionEncoder,
    binary_encoder: BinaryEncoder,
    location_ckpt: str,
    binary_ckpt: str,
    sdss_dir: str = "data/sdss",
    run: int = 94,
    camcol: int = 1,
    field: int = 12,
    bands: Tuple[int, ...] = (2, ),
    bp: int = 24,
    slen: int = 80,
    h_start: Optional[int] = None,
    w_start: Optional[int] = None,
    scene_size: Optional[int] = None,
    stride_factor: float = 0.5,
    prerender_device: str = "cpu",
) -> None:
    """Initializes SDSSBlendedGalaxies.

    Args:
        detection_encoder: A DetectionEncoder model.
        binary_encoder: A BinaryEncoder model.
        location_ckpt: Path of saved state_dict for location encoder.
        binary_ckpt: Path of saved state_dict for binary encoder.
        sdss_dir: Location of data storage for SDSS. Defaults to "data/sdss".
        run: SDSS run.
        camcol: SDSS camcol.
        field: SDSS field.
        bands: SDSS bands of image to use. Only the first band's image and
            background are used.
        bp: How much border padding around each chunk.
        slen: Side-length of each chunk. ``self.slen`` is set to
            ``slen + 2 * bp``, and the chunk kernel size is
            ``self.slen + 2 * bp``.
        h_start: Starting height-point of image. If None, start at `bp`.
        w_start: Starting width-point of image. If None, start at `bp`.
        scene_size: Total size of the scene to use. If None, use maximum
            possible size.
        stride_factor: How much should chunks overlap? If 1.0, no overlap.
        prerender_device: Device to use to prerender chunks.
    """
    super().__init__()
    sdss_data = SloanDigitalSkySurvey(
        sdss_dir=sdss_dir,
        run=run,
        camcol=camcol,
        fields=(field, ),
        bands=bands,
    )
    frame = sdss_data[0]  # index once; reused for image and background
    image = torch.from_numpy(frame["image"][0])
    image = rearrange(image, "h w -> 1 1 h w")
    background = torch.from_numpy(frame["background"][0])
    background = rearrange(background, "h w -> 1 1 h w")

    self.bp = bp
    self.slen = slen + 2 * bp
    self.kernel_size = self.slen + 2 * self.bp
    self.stride = int(self.slen * stride_factor)
    assert self.stride > 0
    self.prerender_device = prerender_device

    if h_start is None:
        h_start = self.bp
    if w_start is None:
        w_start = self.bp
    assert h_start >= self.bp
    assert w_start >= self.bp

    if scene_size is None:
        scene_size = min(image.shape[2] - h_start, image.shape[3] - w_start) - self.bp

    # Crop image AND background to the same scene window (plus `bp` border),
    # so chunks cut from each cover identical pixels. Previously only the
    # image was cropped, leaving the background misaligned with it.
    h_window = slice(h_start - self.bp, h_start + scene_size + self.bp)
    w_window = slice(w_start - self.bp, w_start + scene_size + self.bp)
    image = image[:, :, h_window, w_window]
    background = background[:, :, h_window, w_window]

    # Checkpoints are loaded onto CPU regardless of where they were saved.
    detection_encoder.load_state_dict(
        torch.load(location_ckpt, map_location=torch.device("cpu")))
    binary_encoder.load_state_dict(
        torch.load(binary_ckpt, map_location=torch.device("cpu")))
    self.encoder = Encoder(detection_encoder.eval(), binary_encoder.eval())
    self.chunks, self.bgs, self.catalogs = self._prerender_chunks(image, background)
def __init__(
    self,
    dataset: SimulatedDataset,
    coadd: str,
    n_tiles_h,
    n_tiles_w,
    cache_dir=None,
):
    """Simulate a frame from a coadd catalog, or load it from a cache.

    Args:
        dataset: Simulated dataset providing the image decoder, image prior and
            tile size. It is moved to CPU as a side effect.
        coadd: Path to the coadd catalog file. Its parent directory is used as
            the SDSS directory.
        n_tiles_h: Number of tiles spanned by the catalog window in height.
        n_tiles_w: Number of tiles spanned by the catalog window in width.
        cache_dir: If given, ``<cache_dir>/simulated_frame.pt`` is loaded when it
            exists. Otherwise the frame is simulated and saved there.
    """
    dataset.to("cpu")
    self.bp = dataset.image_decoder.border_padding
    self.tile_slen = dataset.tile_slen
    self.coadd_file = Path(coadd)
    self.sdss_dir = self.coadd_file.parent

    sim_frame_path = None if cache_dir is None else Path(cache_dir) / "simulated_frame.pt"

    if sim_frame_path is not None and sim_frame_path.exists():
        # NOTE: torch.load unpickles the file; only point cache_dir at trusted data.
        tile_catalog_dict, image, background = torch.load(sim_frame_path)
        tile_catalog = TileCatalog(self.tile_slen, tile_catalog_dict)
    else:
        # The SDSS frame is only needed for its WCS when simulating, so load it
        # here instead of on every construction (including cache hits).
        run = 94
        camcol = 1
        field = 12
        bands = (2, )
        sdss_data = SloanDigitalSkySurvey(
            sdss_dir=self.sdss_dir,
            run=run,
            camcol=camcol,
            fields=(field, ),
            bands=bands,
        )
        wcs = sdss_data[0]["wcs"][0]

        # Catalog window in pixels, offset by the decoder's border padding.
        hlim = (self.bp, self.bp + n_tiles_h * self.tile_slen)
        wlim = (self.bp, self.bp + n_tiles_w * self.tile_slen)
        full_coadd_cat = CoaddFullCatalog.from_file(coadd, wcs, hlim, wlim, band="r")
        if dataset.image_prior.galaxy_prior is not None:
            full_coadd_cat["galaxy_params"] = dataset.image_prior.galaxy_prior.sample(
                full_coadd_cat.n_sources, "cpu").unsqueeze(0)
        full_coadd_cat.plocs = full_coadd_cat.plocs + 0.5
        max_sources = dataset.image_prior.max_sources
        tile_catalog = full_coadd_cat.to_tile_params(self.tile_slen, max_sources)
        tile_catalog.set_all_fluxes_and_mags(dataset.image_decoder)

        # Sanity check: converting back to a full catalog must round-trip,
        # ignoring the fields recomputed by set_all_fluxes_and_mags.
        fc = tile_catalog.to_full_params()
        assert fc.equals(full_coadd_cat,
                         exclude=("galaxy_fluxes", "star_bools", "fluxes", "mags"))

        print("INFO: started generating frame")
        image, background = dataset.simulate_image_from_catalog(tile_catalog)
        print("INFO: done generating frame")

        if sim_frame_path is not None:
            torch.save((tile_catalog.to_dict(), image, background), sim_frame_path)

    self.tile_catalog = tile_catalog
    self.image = image
    self.background = background
    assert self.image.shape[0] == 1
    assert self.background.shape[0] == 1