Example #1
File: decoder.py  Project: prob-ml/bliss
    def render_large_scene(
        self, tile_catalog: TileCatalog, batch_size: Optional[int] = None
    ) -> Tensor:
        """Render a large scene in horizontal strips of tile rows to bound memory use."""
        if batch_size is None:
            batch_size = 75**2 + 500 * 2  # default number of tiles rendered per strip

        _, n_tiles_h, n_tiles_w, _, _ = tile_catalog.locs.shape
        n_rows_per_batch = batch_size // n_tiles_w
        # full image size: tiles times tile side length, plus padding on both edges
        h = tile_catalog.locs.shape[1] * tile_catalog.tile_slen + 2 * self.border_padding
        w = tile_catalog.locs.shape[2] * tile_catalog.tile_slen + 2 * self.border_padding
        scene = torch.zeros(1, 1, h, w)
        for row in range(0, n_tiles_h, n_rows_per_batch):
            # render one strip of tile rows; consecutive strips overlap by
            # 2 * border_padding pixels, and the += accumulates flux that
            # spills across strip boundaries
            end_row = row + n_rows_per_batch
            start_h = row * tile_catalog.tile_slen
            end_h = end_row * tile_catalog.tile_slen + 2 * self.border_padding
            tile_cat_row = tile_catalog.crop((row, end_row), (0, None))
            img_row = self.render_images(tile_cat_row)
            scene[:, :, start_h:end_h] += img_row.cpu()
        return scene
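
A minimal calling sketch for the method above. The decoder instance (decoder), the pre-built tile_catalog, and the alternative batch size are assumptions made for illustration, not taken from the project:

import torch

# assumed to exist: `decoder`, an instance of the decoder class above, and
# `tile_catalog`, a TileCatalog covering the whole scene (both hypothetical names)
with torch.no_grad():
    scene = decoder.render_large_scene(tile_catalog)  # default batch size
    scene_small = decoder.render_large_scene(tile_catalog, batch_size=4096)  # smaller strips
print(scene.shape)  # torch.Size([1, 1, H, W]) for the full padded scene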
Example #2
def create_figure_at_point(
    h: int,
    w: int,
    size: int,
    bp: int,
    tile_map_recon: TileCatalog,
    frame: Frame,
    dec: ImageDecoder,
    est_catalog: Optional[FullCatalog] = None,
    show_tiles: bool = False,
    use_image_bounds: bool = False,
    **kwargs,
):
    """Build a reconstruction figure for a square, tile-aligned cutout near pixel (h, w)."""
    tile_slen = tile_map_recon.tile_slen
    # clamp the requested corner so the cutout stays inside the image
    if h + size + bp > frame.image.shape[2]:
        h = frame.image.shape[2] - size - bp
    if w + size + bp > frame.image.shape[3]:
        w = frame.image.shape[3] - size - bp
    # snap the cutout to the tile grid and compute its pixel limits
    h_tile = (h - bp) // tile_slen
    w_tile = (w - bp) // tile_slen
    n_tiles = size // tile_slen
    hlims = bp + (h_tile * tile_slen), bp + ((h_tile + n_tiles) * tile_slen)
    wlims = bp + (w_tile * tile_slen), bp + ((w_tile + n_tiles) * tile_slen)

    # crop the reconstructed tile catalog to the cutout and lift it to full-image parameters
    tile_map_cropped = tile_map_recon.crop(
        (h_tile, h_tile + n_tiles), (w_tile, w_tile + n_tiles)
    )
    full_map_cropped = tile_map_cropped.to_full_params()

    img_cropped = frame.image[0, 0, hlims[0]:hlims[1], wlims[0]:wlims[1]]
    bg_cropped = frame.background[0, 0, hlims[0]:hlims[1], wlims[0]:wlims[1]]
    with torch.no_grad():
        # render the model image for the cutout, strip its border padding, add the
        # background, and compute the residual normalized by sqrt of the reconstruction
        recon_cropped = dec.render_images(tile_map_cropped.to(dec.device))
        recon_cropped = recon_cropped.to("cpu")[0, 0, bp:-bp, bp:-bp] + bg_cropped
        resid_cropped = (img_cropped - recon_cropped) / recon_cropped.sqrt()

    # optionally crop the externally estimated catalog to the same tile window
    if est_catalog is not None:
        tile_est_catalog = est_catalog.to_tile_params(
            tile_map_cropped.tile_slen, tile_map_cropped.max_sources
        )
        tile_est_catalog_cropped = tile_est_catalog.crop(
            (h_tile, h_tile + n_tiles), (w_tile, w_tile + n_tiles)
        )
        est_catalog_cropped = tile_est_catalog_cropped.to_full_params()
    else:
        est_catalog_cropped = None

    # overlay the tile grid only when requested
    tile_map = tile_map_cropped if show_tiles else None

    # color scale: the cutout's own intensity range, or fixed defaults otherwise
    if use_image_bounds:
        vmin = img_cropped.min().item()
        vmax = img_cropped.max().item()
    else:
        vmin = 800
        vmax = 1200
    return create_figure(
        img_cropped,
        recon_cropped,
        resid_cropped,
        map_recon=full_map_cropped,
        coadd_objects=est_catalog_cropped,
        tile_map=tile_map,
        vmin=vmin,
        vmax=vmax,
        **kwargs,
    )
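
A hedged usage sketch for create_figure_at_point. The frame, dec, tile_map_recon, and coadd_catalog objects, the pixel coordinates, and the assumption that create_figure returns a matplotlib Figure are all illustrative rather than taken from the project:

# assumed to exist: frame (Frame), dec (ImageDecoder),
# tile_map_recon (TileCatalog), coadd_catalog (FullCatalog) -- all hypothetical names
fig = create_figure_at_point(
    h=1200,                 # pixel row of the cutout corner (illustrative)
    w=800,                  # pixel column of the cutout corner (illustrative)
    size=300,               # side length of the square cutout in pixels
    bp=dec.border_padding,  # border padding used by the decoder
    tile_map_recon=tile_map_recon,
    frame=frame,
    dec=dec,
    est_catalog=coadd_catalog,
    show_tiles=True,
    use_image_bounds=True,
)
fig.savefig("cutout_reconstruction.png")  # assumes create_figure returns a matplotlib Figure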