def __init__(self, psf, imsize, nthreads=1, backward_undersize=None):
    """
    Set up FFT-based convolution with the PSF.

    Pads/shifts the PSF and caches its real-to-complex FFT (``psfhat``)
    together with the padding/unpadding slices needed to convolve image
    cubes of shape ``imsize`` with it.  If ``backward_undersize`` is
    given, a second, smaller transform (``psfhatb`` etc.) is prepared
    for the backward step by cropping the PSF before the FFT;
    otherwise the backward attributes alias the forward ones.

    Parameters
    ----------
    psf : array of shape (nband, nx_psf, ny_psf)
        Point spread function per imaging band.
    imsize : tuple (nband, nx, ny)
        Shape of the images the operator acts on.
    nthreads : int
        Number of threads passed to the FFTs.
    backward_undersize : float or None
        If not None, fractional padded size (relative to nx/ny) used to
        build the reduced backward-step transform.
    """
    self.nthreads = nthreads
    self.nband, nx_psf, ny_psf = psf.shape
    _, nx, ny = imsize
    # zero padding taking the image up to the PSF grid size
    # (split as evenly as possible between left and right)
    npad_xl = (nx_psf - nx) // 2
    npad_xr = nx_psf - nx - npad_xl
    npad_yl = (ny_psf - ny) // 2
    npad_yr = ny_psf - ny - npad_yl
    self.padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
    self.ax = (1, 2)  # FFT over the two spatial axes only
    # NOTE(review): slice(a, -b) is empty when b == 0, i.e. when the PSF
    # grid equals the image grid — presumably the PSF is always strictly
    # larger; confirm with callers.
    self.unpad_x = slice(npad_xl, -npad_xr)
    self.unpad_y = slice(npad_yl, -npad_yr)
    self.lastsize = ny + np.sum(self.padding[-1])
    self.psf = psf
    # shift the PSF origin to pixel (0, 0) before transforming
    psf_pad = iFs(psf, axes=self.ax)
    self.psfhat = r2c(psf_pad, axes=self.ax, forward=True,
                      nthreads=nthreads, inorm=0)
    # LB - failed experiment?
    # self.psfhatinv = 1/(self.psfhat + 1.0)
    if backward_undersize is not None:
        # set up for backward step
        # smaller FFT-friendly padded grid for the backward application
        nx_psfb = good_size(int(backward_undersize * nx))
        ny_psfb = good_size(int(backward_undersize * ny))
        npad_xlb = (nx_psfb - nx) // 2
        npad_xrb = nx_psfb - nx - npad_xlb
        npad_ylb = (ny_psfb - ny) // 2
        npad_yrb = ny_psfb - ny - npad_ylb
        self.paddingb = ((0, 0), (npad_xlb, npad_xrb), (npad_ylb, npad_yrb))
        self.unpad_xb = slice(npad_xlb, -npad_xrb)
        self.unpad_yb = slice(npad_ylb, -npad_yrb)
        self.lastsizeb = ny + np.sum(self.paddingb[-1])
        # crop the PSF down to the backward padded size before the FFT
        xlb = (nx_psf - nx_psfb) // 2
        xrb = nx_psf - nx_psfb - xlb
        ylb = (ny_psf - ny_psfb) // 2
        yrb = ny_psf - ny_psfb - ylb
        psf_padb = iFs(psf[:, slice(xlb, -xrb), slice(ylb, -yrb)],
                       axes=self.ax)
        self.psfhatb = r2c(psf_padb, axes=self.ax, forward=True,
                           nthreads=nthreads, inorm=0)
    else:
        # backward step reuses the forward transform unchanged
        self.paddingb = self.padding
        self.unpad_xb = self.unpad_x
        self.unpad_yb = self.unpad_y
        self.lastsizeb = self.lastsize
        self.psfhatb = self.psfhat
def _hessian(x, psfhat, padding, nthreads, unpad_x, unpad_y, lastsize):
    """
    Apply the PSF-convolution (Hessian approximation) operator to the
    image cube ``x`` via zero-padded FFTs.

    ``x`` is padded with ``padding``, shifted, transformed, multiplied
    by the precomputed ``psfhat`` and transformed back; the result is
    shifted and cropped with ``unpad_x``/``unpad_y``.
    """
    spatial_axes = (1, 2)
    padded = np.pad(x, padding, mode='constant')
    shifted = iFs(padded, axes=spatial_axes)
    ft = r2c(shifted, axes=spatial_axes, nthreads=nthreads,
             forward=True, inorm=0)
    # multiply in Fourier space; ft is freshly allocated so the
    # in-place product is safe
    ft *= psfhat
    conv = c2r(ft, axes=spatial_axes, forward=False,
               lastsize=lastsize, inorm=2, nthreads=nthreads)
    return Fs(conv, axes=spatial_axes)[:, unpad_x, unpad_y]
def convolve2gaussres(image, xx, yy, gaussparf, nthreads, gausspari=None,
                      pfrac=0.5, norm_kernel=False):
    """
    Convolves the image to a specified resolution.

    Parameters
    ----------
    image : (nband, nx, ny) array to convolve
    xx/yy : coordinates on the grid in the same units as gaussparf
    gaussparf : tuple containing Gaussian parameters of desired resolution
        (emaj, emin, pa)
    gausspari : initial resolution. By default it is assumed that the
        image is a clean component image with no associated resolution.
        If specified, it must be a tuple containing gausspars for each
        imaging band in the same format.
    nthreads : number of threads to use for the FFT's
    pfrac : padding used for the FFT based convolution.
        Will pad by pfrac/2 on both sides of image
    norm_kernel : whether to normalise the convolution kernel

    Returns
    -------
    image : the convolved cube, same shape as the input
    gausskern : the target-resolution kernel (unpadded)
    """
    nband, nx, ny = image.shape
    padding, unpad_x, unpad_y = get_padding_info(nx, ny, pfrac)
    ax = (1, 2)  # axes over which to perform fft
    lastsize = ny + np.sum(padding[-1])

    # target-resolution kernel and its FFT
    gausskern = Gaussian2D(xx, yy, gaussparf, normalise=norm_kernel)
    gausskern = np.pad(gausskern[None], padding, mode='constant')
    gausskernhat = r2c(iFs(gausskern, axes=ax), axes=ax, forward=True,
                       nthreads=nthreads, inorm=0)

    image = np.pad(image, padding, mode='constant')
    imhat = r2c(iFs(image, axes=ax), axes=ax, forward=True,
                nthreads=nthreads, inorm=0)

    # convolve to desired resolution
    if gausspari is None:
        imhat *= gausskernhat
    else:
        # deconvolve the per-band initial resolution and convolve to the
        # common target resolution
        for i in range(nband):
            thiskern = Gaussian2D(xx, yy, gausspari[i],
                                  normalise=norm_kernel)
            thiskern = np.pad(thiskern[None], padding, mode='constant')
            thiskernhat = r2c(iFs(thiskern, axes=ax), axes=ax, forward=True,
                              nthreads=nthreads, inorm=0)
            # BUG FIX: the previous np.where(cond, a/b, 0) form evaluated
            # the division everywhere, emitting divide-by-zero/invalid
            # warnings and computing inf/nan values before masking them
            # out; divide only where the denominator is non-zero.
            convkernhat = np.zeros_like(thiskernhat)
            np.divide(gausskernhat, thiskernhat, out=convkernhat,
                      where=np.abs(thiskernhat) > 0.0)
            imhat[i] *= convkernhat[0]

    image = Fs(c2r(imhat, axes=ax, forward=False, lastsize=lastsize,
                   inorm=2, nthreads=nthreads),
               axes=ax)[:, unpad_x, unpad_y]

    return image, gausskern[:, unpad_x, unpad_y]
def __init__(self, sigma0, nband, nx, ny, nthreads=8):
    """
    Set up a stationary Gaussian-process style prior operator.

    Builds the FFT of a squared-exponential kernel on a 2x zero-padded
    grid (``Khat``/``Khatinv``) plus separable covariance factors over
    the band axis and the two spatial axes (``Kkron``/``Kinvkron``) for
    Kronecker-structured matrix-vector products.

    Parameters
    ----------
    sigma0 : float
        Overall amplitude (standard deviation) of the prior.
    nband, nx, ny : int
        Dimensions of the image cube.
    nthreads : int
        Number of threads used by the FFT.
    """
    self.nthreads = nthreads
    self.nx = nx
    self.ny = ny
    # pad by a factor of two in each spatial dimension
    nx_psf = 2 * self.nx
    npad_x = (nx_psf - nx) // 2
    ny_psf = 2 * self.ny
    npad_y = (ny_psf - ny) // 2
    self.padding = ((0, 0), (npad_x, npad_x), (npad_y, npad_y))
    self.ax = (1, 2)
    self.unpad_x = slice(npad_x, -npad_x)
    self.unpad_y = slice(npad_y, -npad_y)
    self.lastsize = ny + np.sum(self.padding[-1])
    # set length scales
    length_scale = 0.5
    K = make_kernel(nx_psf, ny_psf, sigma0, length_scale)
    self.K = K
    K_pad = iFs(self.K, axes=self.ax)
    self.Khat = r2c(K_pad, axes=self.ax, forward=True, nthreads=nthreads,
                    inorm=0)
    # regularised inverse — avoid dividing by (near-)zero modes
    self.Khatinv = np.where(self.Khat.real > 1e-14, 1.0 / self.Khat, 1e-14)

    # get covariance in each dimension
    # pixel coordinates
    self.Kv = mock_array(nband)  # np.eye(nband) * sigma0**2
    self.Kvinv = mock_array(nband)  # np.eye(nband) / sigma0**2
    if nx == ny:
        # Square grid: l and m share one kernel evaluation, but they must
        # NOT share storage — only the l-factor carries the sigma0**2
        # amplitude, mirroring the nx != ny branch below.
        # BUG FIX: previously self.Kl and self.Km (and Klinv/Kminv)
        # aliased the same arrays, so the in-place `self.Kl *= sigma0**2`
        # and `self.Klinv /= sigma0**2` silently scaled Km and Kminv too.
        l_coord = m_coord = np.arange(-(nx // 2), nx // 2)
        Km = expsq(l_coord, l_coord, 1.0, length_scale)
        Kminv = np.linalg.pinv(Km, hermitian=True, rcond=1e-12)
        self.Km = Km
        self.Kminv = Kminv
        self.Kl = Km * sigma0**2
        self.Klinv = Kminv / sigma0**2
    else:
        l_coord = np.arange(-(nx // 2), nx // 2)
        m_coord = np.arange(-(ny // 2), ny // 2)
        # NOTE(review): assumes expsq(x, x, sigma0, l) equals
        # sigma0**2 * expsq(x, x, 1.0, l) — consistent with the scaling
        # applied in the square-grid branch; confirm in expsq.
        self.Kl = expsq(l_coord, l_coord, sigma0, length_scale)
        self.Km = expsq(m_coord, m_coord, 1.0, length_scale)
        self.Klinv = np.linalg.pinv(self.Kl, hermitian=True, rcond=1e-12)
        self.Kminv = np.linalg.pinv(self.Km, hermitian=True, rcond=1e-12)
    # Kronecker matrices for "fast" matrix vector products
    self.Kkron = (self.Kv, self.Kl, self.Km)
    self.Kinvkron = (self.Kvinv, self.Klinv, self.Kminv)
def convolveb(self, x):
    """
    Convolve the cube ``x`` with the backward-step PSF via zero-padded
    FFTs, returning an array with the same shape as ``x``.
    """
    padded = np.pad(x, self.paddingb, mode='constant')
    ft = r2c(iFs(padded, axes=self.ax), axes=self.ax,
             nthreads=self.nthreads, forward=True, inorm=0)
    # pointwise product in Fourier space; ft is freshly allocated so the
    # in-place multiply is safe
    ft *= self.psfhatb
    conv = c2r(ft, axes=self.ax, forward=False, lastsize=self.lastsizeb,
               inorm=2, nthreads=self.nthreads)
    return Fs(conv, axes=self.ax)[:, self.unpad_xb, self.unpad_yb]
def __init__(self, psf, nthreads, sigma0=1.0):
    """
    Cache the shifted/FFT'd PSF and the padding bookkeeping needed for
    FFT-based convolution.  The image grid is taken to be half the PSF
    grid in each spatial dimension.

    Parameters
    ----------
    psf : array of shape (nband, nx_psf, ny_psf)
    nthreads : int — threads used by the FFT
    sigma0 : float — accepted for interface compatibility; not used here
    """
    self.nthreads = nthreads
    self.nband, nx_psf, ny_psf = psf.shape
    # image is half the PSF size; pad symmetrically back up to PSF size
    nx, ny = nx_psf // 2, ny_psf // 2
    pad_x, pad_y = (nx_psf - nx) // 2, (ny_psf - ny) // 2
    self.padding = ((0, 0), (pad_x, pad_x), (pad_y, pad_y))
    self.ax = (1, 2)
    self.unpad_x = slice(pad_x, -pad_x)
    self.unpad_y = slice(pad_y, -pad_y)
    self.lastsize = ny + np.sum(self.padding[-1])
    self.psf = psf
    # shift the PSF origin to pixel (0, 0) and cache its FFT
    self.psfhat = r2c(iFs(psf, axes=self.ax), axes=self.ax, forward=True,
                      nthreads=nthreads, inorm=0)
def _forward(**kw):
    """
    Solve for the update image from a residual cube and PSF with a
    preconditioned conjugate gradient solver, using either a visibility
    space Hessian approximation (when ``weight_table`` is supplied) or
    an image space (PSF) approximation.

    All inputs arrive through ``kw`` (OmegaConf argument set); results
    are written to fits files under ``args.output_filename``.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading residual", file=log)
    residual = load_fits(args.residual, dtype=args.output_type).squeeze()
    nband, nx, ny = residual.shape
    hdr = fits.getheader(args.residual)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    # normalise by the sum of the weights (peak of the unnormalised psf)
    wsums = np.amax(psf.reshape(-1, nx_psf*ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)
    assert (psf_mfs.max() - 1.0) < 1e-4
    residual /= wsum
    residual_mfs = np.sum(residual, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m
    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)
    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs,
              hdr_mfs, dtype=args.output_type)

    rms = np.std(residual_mfs)
    rmax = np.abs(residual_mfs).max()
    print("Initial peak residual = %f, rms = %f" % (rmax, rms), file=log)

    # load beam
    if args.beam_model is not None:
        if args.beam_model.endswith('.fits'):
            # beam already interpolated onto the image grid
            bhdr = fits.getheader(args.beam_model)
            l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
            l_coord_beam -= ref_lb
            if not np.array_equal(l_coord_beam, l_coord):
                raise ValueError("l coordinates of beam model do not match "
                                 "those of image. Use power_beam_maker to "
                                 "interpolate to fits header.")
            m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
            m_coord_beam -= ref_mb
            if not np.array_equal(m_coord_beam, m_coord):
                raise ValueError("m coordinates of beam model do not match "
                                 "those of image. Use power_beam_maker to "
                                 "interpolate to fits header.")
            # BUG FIX: `freq_axis` was undefined; the image frequency axis
            # is read at axis=3 above, so do the same for the beam header
            freq_beam, _ = data_from_header(bhdr, axis=3)
            if not np.array_equal(freq_out, freq_beam):
                raise ValueError("Freqs of beam model do not match those of "
                                 "image. Use power_beam_maker to "
                                 "interpolate to fits header.")
            beam_image = load_fits(args.beam_model,
                                   dtype=args.output_type).squeeze()
        elif args.beam_model.lower() == "jimbeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            elif args.band.lower() == 'uhf':
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            else:
                # BUG FIX: previously referenced an undefined index `i`
                # via args.band[i]
                raise ValueError("Unkown band %s" % args.band)
            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
            beam_image = np.zeros(residual.shape, dtype=args.output_type)
            for v in range(freq_out.size):
                # freq must be in MHz
                beam_image[v] = beam.I(xx, yy, freq_out[v]/1e6).astype(
                    args.output_type)
    else:
        beam_image = np.ones((nband, nx, ny), dtype=args.output_type)

    if args.mask is not None:
        mask = load_fits(args.mask).squeeze()
        assert mask.shape == (nx, ny)
        beam_image *= mask[None, :, :]

    beam_image = da.from_array(beam_image, chunks=(1, -1, -1))

    # if weight table is provided we use the vis space Hessian approximation
    if args.weight_table is not None:
        print("Solving for update using vis space approximation", file=log)
        normfact = wsum
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freq = xds.chan.data
        nchan = freq.size

        # bin edges
        fmin = freq.min()
        fmax = freq.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_mapping = {}
        chan_chunks = {}
        freq_bin_idx = {}
        freq_bin_counts = {}
        band_map = np.zeros(freq.size, dtype=np.int32)
        for band in range(nband):
            indl = freq >= fbins[band]
            # small tolerance so the top channel lands in the last band
            indu = freq < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freq = da.from_array(freq, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit/args.nworkers,
                                       band_size, nrow, memory_per_row,
                                       args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row,
                                       args.nthreads_per_worker)

        print("nrows = %i, row chunks set to %i for a total of %i chunks "
              "per node" %
              (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1),
                      dtype=residual.dtype)

        xds = xds_from_zarr(args.weight_table,
                            chunks={'row': -1,  # row_chunk,
                                    'chan': bin_counts})[0]

        from pfb.opt.pcg import pcg_wgt
        model = pcg_wgt(xds.UVW.data,
                        xds.WEIGHT.data.astype(args.output_type),
                        residual, x0, beam_image, freq,
                        freq_bin_idx, freq_bin_counts, cell_rad,
                        args.wstack, args.epsilon, args.double_accum,
                        args.nvthreads, args.sigmainv, wsum,
                        args.cg_tol, args.cg_maxit, args.cg_minit,
                        args.cg_verbose, args.cg_report_freq,
                        args.backtrack).compute()
    else:
        # we use the image space approximation
        print("Solving for update using image space approximation",
              file=log)
        normfact = 1.0
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx)//2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny)//2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        unpad_x = slice(npad_xl, -npad_xr)
        unpad_y = slice(npad_yl, -npad_yr)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        # BUG FIX: `nthreads` was undefined here; use the configured
        # per-worker thread count
        psfhat = r2c(psf_pad, axes=(1, 2), forward=True,
                     nthreads=args.nvthreads, inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))
        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1))

        from pfb.opt.pcg import pcg_psf
        model = pcg_psf(psfhat, residual, x0, beam_image,
                        args.sigmainv, args.nvthreads, padding,
                        unpad_x, unpad_y, lastsize,
                        args.cg_tol, args.cg_maxit, args.cg_minit,
                        args.cg_verbose, args.cg_report_freq,
                        args.backtrack).compute()

    print("Saving results", file=log)
    save_fits(args.output_filename + '_update.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_update_mfs.fits', model_mfs, hdr_mfs)

    print("All done here.", file=log)
def rfftn(a, axes=None, inorm=0, nthreads=1):
    """Forward real-to-complex FFT of ``a`` — thin wrapper over ``fft.r2c``."""
    kwargs = dict(axes=axes, forward=True, inorm=inorm, nthreads=nthreads)
    return fft.r2c(a, **kwargs)
def _clean(**kw):
    """
    Run Hogbom-based major/minor cycle cleaning on a dirty image cube.

    The residual after each major cycle is computed by convolving the
    accumulated model with either a visibility space Hessian (when a
    ``weight_table`` is supplied) or an FFT/PSF image space Hessian.
    Model and residual products are written to fits files under
    ``args.output_filename``.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading dirty", file=log)
    dirty = load_fits(args.dirty, dtype=args.output_type).squeeze()
    nband, nx, ny = dirty.shape
    hdr = fits.getheader(args.dirty)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    # normalise by the sum of the weights (peak of the unnormalised psf)
    wsums = np.amax(psf.reshape(-1, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)
    assert (psf_mfs.max() - 1.0) < 1e-4
    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)
    save_fits(args.output_filename + '_dirty_mfs.fits', dirty_mfs, hdr_mfs,
              dtype=args.output_type)

    # set up Hessian approximation
    if args.weight_table is not None:
        normfact = wsum
        from africanus.gridding.wgridder.dask import hessian
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freqs = xds.chan.data
        nchan = freqs.size

        # bin edges
        fmin = freqs.min()
        fmax = freqs.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_mapping = {}
        chan_chunks = {}
        freq_bin_idx = {}
        freq_bin_counts = {}
        band_map = np.zeros(freqs.size, dtype=np.int32)
        for band in range(nband):
            indl = freqs >= fbins[band]
            # small tolerance so the top channel lands in the last band
            indu = freqs < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freqs = da.from_array(freqs, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit / args.nworkers,
                                       band_size, nrow, memory_per_row,
                                       args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row,
                                       args.nthreads_per_worker)

        print("nrows = %i, row chunks set to %i for a total of %i chunks "
              "per node" %
              (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

        def convolver(x):
            # convolve the model with the psf via degridding/gridding
            model = da.from_array(x, chunks=(1, nx, ny), name=False)
            xds = xds_from_zarr(args.weight_table,
                                chunks={'row': row_chunk,
                                        'chan': bin_counts})[0]
            convolvedim = hessian(xds.UVW.data, freqs, model,
                                  freq_bin_idx, freq_bin_counts, cell_rad,
                                  weights=xds.WEIGHT.data.astype(
                                      args.output_type),
                                  nthreads=args.nvthreads,
                                  epsilon=args.epsilon,
                                  do_wstacking=args.wstack,
                                  double_accum=args.double_accum)
            return convolvedim
    else:
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        unpad_x = slice(npad_xl, -npad_xr)
        unpad_y = slice(npad_yl, -npad_yr)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        # BUG FIX: `nthreads` was undefined here; use the configured
        # per-worker thread count
        psfhat = r2c(psf_pad, axes=(1, 2), forward=True,
                     nthreads=args.nvthreads, inorm=0)
        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))

        def convolver(x):
            # convolve the model with the psf via FFTs
            model = da.from_array(x, chunks=(1, nx, ny), name=False)
            # BUG FIX: `nvthreads` was undefined; use args.nvthreads
            convolvedim = hessian(model, psfhat, padding, args.nvthreads,
                                  unpad_x, unpad_y, lastsize)
            return convolvedim

    rms = np.std(dirty_mfs)
    rmax = np.abs(dirty_mfs).max()
    print("Iter %i: peak residual = %f, rms = %f" % (0, rmax, rms),
          file=log)

    residual = dirty.copy()
    residual_mfs = dirty_mfs.copy()
    model = np.zeros_like(residual)
    for k in range(args.nmiter):
        print("Running Hogbom", file=log)
        x = hogbom(residual, psf, gamma=args.hb_gamma,
                   pf=args.hb_peak_factor, maxit=args.hb_maxit,
                   verbosity=args.hb_verbose,
                   report_freq=args.hb_report_freq)
        model += x

        print("Getting residual", file=log)
        convimage = convolver(model)
        dask.visualize(convimage,
                       filename=args.output_filename + '_hessian' + str(k) +
                       '_graph.pdf',
                       optimize_graph=False)
        with performance_report(filename=args.output_filename + '_hessian' +
                                str(k) + '_per.html'):
            convimage = dask.compute(convimage, optimize_graph=False)[0]
        ne.evaluate('dirty - convimage/normfact', out=residual,
                    casting='same_kind')
        ne.evaluate('sum(residual, axis=0)', out=residual_mfs,
                    casting='same_kind')

        rms = np.std(residual_mfs)
        rmax = np.abs(residual_mfs).max()
        print("Iter %i: peak residual = %f, rms = %f" % (k + 1, rmax, rms),
              file=log)

    print("Saving results", file=log)
    save_fits(args.output_filename + '_model.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_model_mfs.fits', model_mfs, hdr_mfs)
    save_fits(args.output_filename + '_residual.fits',
              residual * wsums[:, None, None], hdr)
    # BUG FIX: previously also wrote to '_residual.fits', overwriting the
    # residual cube saved on the line above
    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    print("All done here.", file=log)