def main(args):
    # get coord info
    hdr = fits.getheader(args.image)
    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m
    if hdr["CTYPE4"].lower() == 'freq':
        freq_axis = 4
    elif hdr["CTYPE3"].lower() == 'freq':
        freq_axis = 3
    else:
        raise ValueError("Freq axis must be 3rd or 4th")
    freqs, ref_freq = data_from_header(hdr, axis=freq_axis)

    xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

    # interpolate primary beam to fits header and optionally average over time
    beam_image = interpolate_beam(xx, yy, freqs, args)

    # save power beam
    save_fits(args.output_filename, beam_image, hdr)
    print("Wrote interpolated beam cube to %s \n" % args.output_filename)

    return
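# Minimal sketch (an assumption for clarity, not the pfb.utils.fits
# implementation) of what data_from_header is taken to return for a world
# axis: the world coordinate of every pixel along that axis plus the
# reference value, so subtracting the reference yields offsets relative to
# the phase centre. The header values below are made up.
def _example_data_from_header():
    import numpy as np
    hdr = {'NAXIS1': 8, 'CRPIX1': 4, 'CRVAL1': 0.0, 'CDELT1': -2.5e-5}
    n, refpix = hdr['NAXIS1'], hdr['CRPIX1']
    # FITS pixel indices are 1-based
    coords = hdr['CRVAL1'] + (np.arange(1, n + 1) - refpix) * hdr['CDELT1']
    return coords, hdr['CRVAL1']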
def sara(psf, model, residual, mask=None, beam_image=None, hessian=None,
         wsum=1, adapt_sig21=True, hdr=None, hdr_mfs=None, outfile=None,
         cpsf=None, nthreads=1, sig_21=1e-6, sigma_frac=100, maxit=10,
         tol=1e-3, gamma=0.99, psi_levels=2, psi_basis=None, alpha=None,
         pdtol=1e-6, pdmaxit=250, pdverbose=1, positivity=True,
         cgtol=1e-6, cgminit=25, cgmaxit=150, cgverbose=1,
         pmtol=1e-5, pmmaxit=50, pmverbose=1):

    if len(residual.shape) > 3:
        raise ValueError("Residual must have shape (nband, nx, ny)")

    nband, nx, ny = residual.shape

    if beam_image is None:
        def beam(x):
            return x

        def beaminv(x):
            return x
    else:
        try:
            assert beam_image.shape == (nband, nx, ny)

            def beam(x):
                return beam_image * x

            def beaminv(x):
                return np.where(beam_image > 0.01, x / beam_image, x)
        except BaseException:
            raise ValueError("Beam has incorrect shape")

    if mask is None:
        def mask(x):
            return x
    else:
        try:
            if mask.ndim == 2:
                assert mask.shape == (nx, ny)
                # keep the array under a separate name so the mask closure
                # does not refer to itself once the name is rebound
                mask_array = mask

                def mask(x):
                    return mask_array[None] * x
            elif mask.ndim == 3:
                assert mask.shape == (1, nx, ny)
                mask_array = mask

                def mask(x):
                    return mask_array * x
            else:
                raise ValueError
        except BaseException:
            raise ValueError("Mask has incorrect shape")

    # PSF operator
    psfo = PSF(psf, residual.shape, nthreads=nthreads)  # backward_undersize=1.2
    if cpsf is None:
        raise ValueError("sara requires a clean psf (cpsf)")
    else:
        cpsfo = PSF(cpsf, residual.shape, nthreads=nthreads)

    residual_mfs = np.sum(residual, axis=0)
    residual = mask(beam(residual))
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)

    # wavelet dictionary
    if psi_basis is None:
        psi = DaskPSI(imsize=residual.shape, nlevels=psi_levels,
                      nthreads=nthreads)
    else:
        if not isinstance(psi_basis, list):
            psi_basis = [psi_basis]
        psi = DaskPSI(imsize=residual.shape, nlevels=psi_levels,
                      nthreads=nthreads, bases=psi_basis)

    # set alphas and sig21s
    # this assumes that the model has been initialised using NNLS
    alpha = np.zeros(psi.nbasis)
    sigmas = np.zeros(psi.nbasis)
    resid_comps = psi.hdot(
        residual / np.amax(residual.reshape(-1, nx * ny), axis=1)[:, None, None])
    l2_norm = np.linalg.norm(psi.hdot(cpsfo.convolve(model)), axis=1)
    for m in range(psi.nbasis):
        alpha[m] = np.std(resid_comps[m])
        _, sigmas[m] = expon.fit(l2_norm[m], floc=0.0)
        print("Basis %i, alpha %f, sigma %f" % (m, alpha[m], sigmas[m]),
              file=log)

    # l21 weights and dual
    weights21 = np.ones((psi.nbasis, psi.nmax), dtype=residual.dtype)
    for m in range(psi.nbasis):
        weights21[m] *= sigmas[m] / sig_21
    dual = np.zeros((psi.nbasis, nband, psi.nmax), dtype=residual.dtype)

    # use PSF to approximate Hessian if not passed in
    if hessian is None:
        hessian = psfo.convolve
        wsum = 1.0

    # preconditioning operator
    if model.any():
        varmap = np.maximum(rms, sigma_frac * cpsfo.convolve(model))
    else:
        varmap = np.ones(model.shape) * sigma_frac * rms

    def hessf(x):
        # return mask(beam(hessian(mask(beam(x))))) / wsum + x / varmap
        return mask(beam(psfo.convolve(mask(beam(x))))) + x / varmap

    def hessb(x):
        return mask(beam(psfo.convolve(mask(beam(x))))) + x / varmap

    beta, betavec = power_method(hessb, residual.shape, tol=pmtol,
                                 maxit=pmmaxit, verbosity=pmverbose)

    if model.any():
        dirty = residual + hessian(mask(beam(model))) / wsum
    else:
        dirty = residual

    # deconvolve
    for i in range(0, maxit):
        x = pcg(hessf, mask(beam(residual)), np.zeros_like(residual),
                M=lambda x: x * varmap, tol=cgtol, maxit=cgmaxit,
                minit=cgminit, verbosity=cgverbose)

        # update model
        modelp = model
        model = modelp + gamma * x
        model, dual = primal_dual(hessb, model, modelp, dual, sig_21, psi,
                                  weights21, beta, prox_21, tol=pdtol,
                                  maxit=pdmaxit, report_freq=50, mask=mask,
                                  verbosity=pdverbose, positivity=positivity)

        # get residual
        residual, residual_mfs = resid_func(model, dirty, hessian, mask,
                                            beam, wsum)

        model_mfs = np.mean(model, axis=0)
        x_mfs = np.mean(x, axis=0)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        # update variance map (positivity constraint optional)
        varmap = np.maximum(rms, sigma_frac * cpsfo.convolve(model))

        # update spectral norm
        beta, betavec = power_method(hessb, residual.shape, b0=betavec,
                                     tol=pmtol, maxit=pmmaxit,
                                     verbosity=pmverbose)

        print("Iter %i: peak residual = %f, rms = %f, eps = %f" %
              (i + 1, rmax, rms, eps), file=log)

        # reweight
        l2_norm = np.linalg.norm(psi.hdot(model), axis=1)
        for m in range(psi.nbasis):
            if adapt_sig21:
                _, sigmas[m] = expon.fit(l2_norm[m], floc=0.0)
                print('basis %i, sigma %f' % (m, sigmas[m]), file=log)

            weights21[m] = alpha[m] / (alpha[m] + l2_norm[m]) * sigmas[m] / sig_21

        # save current iteration
        if outfile is not None:
            assert hdr is not None
            assert hdr_mfs is not None

            save_fits(outfile + str(i + 1) + '_model_mfs.fits',
                      model_mfs, hdr_mfs)

            save_fits(outfile + str(i + 1) + '_model.fits',
                      model, hdr)

            save_fits(outfile + str(i + 1) + '_update.fits',
                      x, hdr)

            save_fits(outfile + str(i + 1) + '_update_mfs.fits',
                      x_mfs, hdr_mfs)

            save_fits(outfile + str(i + 1) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

            save_fits(outfile + str(i + 1) + '_residual.fits',
                      residual * wsum, hdr)

        if eps < tol:
            print("Success, convergence after %i iterations" % (i + 1),
                  file=log)
            break

    return model
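# Illustrative, self-contained sketch of the l21 reweighting rule used in
# sara above (the numbers are made up): wavelet coefficients with a large
# l2-norm across bands get weights close to zero and are weakly penalised,
# while small coefficients are penalised at roughly sigma / sig_21.
def _example_l21_reweight():
    import numpy as np
    alpha, sigma, sig_21 = 0.1, 0.05, 1e-2
    l2_norm = np.array([0.0, 0.1, 1.0, 10.0])
    weights21 = alpha / (alpha + l2_norm) * sigma / sig_21
    return weights21  # monotonically decreasing with coefficient magnitude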
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = \
        chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']

    # we still have to cater for complex valued data because we cast
    # the weights to complex but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW',)
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column,)
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column,)
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError("Requested cell size too small. "
                             "Super resolution factor = %f" %
                             (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks per node" %
          (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({'row': row_chunk,
                                'chan': chan_chunks[ims][spw]['chan']})

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :], data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :], data_shape,
                        chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row',),
                             da.full_like(ds.TIME.data, ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row',),
                                 da.full_like(ds.TIME.data, ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT': (('row', 'chan'),
                           weights.rechunk({0: args.row_out_chunk})),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan',), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets, args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename +
    #                '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename +
    #                '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits', psf, hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits', psf_mfs, hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
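# Small runnable sketch of the Nyquist cell-size rule used in _psf above
# (illustration only; the baseline and frequency values are made up): the
# array resolves angular scales down to lightspeed / (2 * uv_max * max_freq)
# radians, and the pixel size must not exceed that.
def _example_nyquist_cell():
    import numpy as np
    lightspeed = 2.99792458e8   # m/s
    uv_max = 8e3                # assumed longest |u| or |v| coordinate in metres
    max_freq = 1.5e9            # assumed highest channel frequency in Hz
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)  # radians
    cell_size = cell_N * 60 * 60 * 180 / np.pi           # arcseconds
    return cell_size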
def _main(dest=sys.stdout):
    from pfb.parser import create_parser
    args = create_parser().parse_args()

    if not args.nthreads:
        import multiprocessing
        args.nthreads = multiprocessing.cpu_count()

    if not args.mem_limit:
        import psutil
        args.mem_limit = int(psutil.virtual_memory()[0] /
                             1e9)  # 100% of memory by default

    import numpy as np
    import numba
    import numexpr
    import dask
    import dask.array as da
    from daskms import xds_from_ms, xds_from_table
    from astropy.io import fits
    from pfb.utils.fits import (set_wcs, load_fits, save_fits,
                                compare_headers, data_from_header)
    from pfb.utils.restoration import fitcleanbeam
    from pfb.utils.misc import Gaussian2D
    from pfb.operators.gridder import Gridder
    from pfb.operators.psf import PSF
    from pfb.deconv.sara import sara
    from pfb.deconv.clean import clean
    from pfb.deconv.spotless import spotless
    from pfb.deconv.nnls import nnls
    from pfb.opt.pcg import pcg

    if not isinstance(args.ms, list):
        args.ms = [args.ms]

    pyscilog.log_to_file(args.outfile + '.log')
    pyscilog.enable_memory_logging(level=3)

    GD = vars(args)
    print('Input Options:', file=log)
    for key in GD.keys():
        print(' %25s = %s' % (key, GD[key]), file=log)

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',), chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list([tmp_freq]))

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError("Requested cell size too small. "
                             "Super resolution factor = %f" %
                             (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=dest)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size, file=dest)

    if args.nx is None or args.ny is None:
        from ducc0.fft import good_size
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = good_size(npix)
        args.ny = good_size(npix)

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny),
          file=dest)

    # mask
    if args.mask is not None:
        mask_array = load_fits(args.mask, dtype=args.real_type).squeeze()
        if mask_array.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape.")
        # add freq axis
        mask_array = mask_array[None]

        def mask(x):
            return mask_array * x
    else:
        mask_array = None

        def mask(x):
            return x

    # init gridder
    R = Gridder(args.ms, args.nx, args.ny, args.cell_size, nband=args.nband,
                nthreads=args.nthreads, do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks, psf_oversize=args.psf_oversize,
                data_column=args.data_column, epsilon=args.epsilon,
                weight_column=args.weight_column,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column, flag_column=args.flag_column,
                weighting=args.weighting, robust=args.robust,
                mem_limit=int(0.8 * args.mem_limit))  # assumes gridding accounts for 80% memory
    freq_out = R.freq_out
    radec = R.radec

    print("PSF size set to (%i, %i, %i)" % (args.nband, R.nx_psf, R.ny_psf),
          file=dest)

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600, R.nx_psf,
                      R.ny_psf, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          R.nx_psf, R.ny_psf, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf = load_fits(args.psf, dtype=args.real_type).squeeze()
        except BaseException:
            psf = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf, hdr_psf)
    else:
        psf = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf, hdr_psf)

    # Normalising by wsum (so that the PSF always sums to 1) results in the
    # most intuitive sig_21 values and by far the least bookkeeping.
    # However, we won't save the cubes that way as it destroys information
    # about the noise in image space. Note only the MFS images will have the
    # usual units of Jy/beam.
    wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf_mfs = np.zeros(psf_mfs.shape, dtype=args.real_type)
    cpsf = np.zeros(psf.shape, dtype=args.real_type)

    lpsf = np.arange(-R.nx_psf / 2, R.nx_psf / 2)
    mpsf = np.arange(-R.ny_psf / 2, R.ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)

    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)
    hdr_psf_mfs = add_beampars(hdr_psf_mfs, GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    GaussPars = list(GaussPars)
    for b in range(args.nband):
        GaussPars[b] = list(GaussPars[b])
        GaussPars[b][0] *= args.cell_size / 3600
        GaussPars[b][1] *= args.cell_size / 3600
        GaussPars[b] = tuple(GaussPars[b])
    GaussPars = tuple(GaussPars)
    hdr_psf = add_beampars(hdr_psf, GaussPar, GaussPars)

    save_fits(args.outfile + '_cpsf.fits', cpsf, hdr_psf)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty).squeeze()
        except BaseException:
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    # initial model and residual
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=args.real_type).squeeze()
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual,
                                         dtype=args.real_type).squeeze()
                except BaseException:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits',
                              residual, hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits',
                          residual, hdr)
            residual /= wsum
        except BaseException:
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    residual_mfs = np.sum(residual, axis=0)
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    # smooth beam
    if args.beam_model is not None:
        if args.beam_model[-5:] == '.fits':
            beam_image = load_fits(args.beam_model,
                                   dtype=args.real_type).squeeze()
            if beam_image.shape != (args.nband, args.nx, args.ny):
                raise ValueError("Beam has incorrect shape")
        elif args.beam_model == "JimBeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            else:
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            beam_image = np.zeros((args.nband, args.nx, args.ny),
                                  dtype=args.real_type)

            l_coord, ref_l = data_from_header(hdr, axis=1)
            l_coord -= ref_l
            m_coord, ref_m = data_from_header(hdr, axis=2)
            m_coord -= ref_m
            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

            for v in range(args.nband):
                # freq must be in MHz (as in the other workers)
                beam_image[v] = beam.I(xx, yy, freq_out[v] / 1e6)

        def beam(x):
            return beam_image * x
    else:
        beam_image = None

        def beam(x):
            return x

    if args.init_nnls:
        print("Initialising with NNLS", file=log)
        model = nnls(psf, model, residual, mask=mask_array,
                     beam_image=beam_image, hdr=hdr, hdr_mfs=hdr_mfs,
                     outfile=args.outfile, maxit=1, nthreads=args.nthreads)

        residual = R.make_residual(beam(mask(model))) / wsum
        residual_mfs = np.sum(residual, axis=0)

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    redo_dirty = False
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms),
          file=dest)
    for i in range(0, args.maxit):
        # run minor cycle of choice
        modelp = model.copy()
        if args.deconv_mode == 'sara':
            model = sara(psf, model, residual, mask=mask_array,
                         beam_image=beam_image, hessian=R.convolve,
                         wsum=wsum, adapt_sig21=args.adapt_sig21,
                         hdr=hdr, hdr_mfs=hdr_mfs, outfile=args.outfile,
                         cpsf=cpsf, nthreads=args.nthreads,
                         sig_21=args.sig_21, sigma_frac=args.sigma_frac,
                         maxit=args.minormaxit, tol=args.minortol,
                         gamma=args.gamma, psi_levels=args.psi_levels,
                         psi_basis=args.psi_basis, pdtol=args.pdtol,
                         pdmaxit=args.pdmaxit, pdverbose=args.pdverbose,
                         positivity=args.positivity, cgtol=args.cgtol,
                         cgminit=args.cgminit, cgmaxit=args.cgmaxit,
                         cgverbose=args.cgverbose, pmtol=args.pmtol,
                         pmmaxit=args.pmmaxit, pmverbose=args.pmverbose)
        elif args.deconv_mode == 'clean':
            model = clean(psf, model, residual, mask=mask_array,
                          beam=beam_image, nthreads=args.nthreads,
                          maxit=args.minormaxit, gamma=args.gamma,
                          peak_factor=args.peak_factor,
                          threshold=args.threshold, hbgamma=args.hbgamma,
                          hbpf=args.hbpf, hbmaxit=args.hbmaxit,
                          hbverbose=args.hbverbose)
        elif args.deconv_mode == 'spotless':
            model = spotless(psf, model, residual, mask=mask_array,
                             beam_image=beam_image, hessian=R.convolve,
                             wsum=wsum, adapt_sig21=args.adapt_sig21,
                             cpsf=cpsf_mfs, hdr=hdr, hdr_mfs=hdr_mfs,
                             outfile=args.outfile, sig_21=args.sig_21,
                             sigma_frac=args.sigma_frac,
                             nthreads=args.nthreads, gamma=args.gamma,
                             peak_factor=args.peak_factor,
                             maxit=args.minormaxit, tol=args.minortol,
                             threshold=args.threshold,
                             positivity=args.positivity,
                             hbgamma=args.hbgamma, hbpf=args.hbpf,
                             hbmaxit=args.hbmaxit, hbverbose=args.hbverbose,
                             pdtol=args.pdtol, pdmaxit=args.pdmaxit,
                             pdverbose=args.pdverbose, cgtol=args.cgtol,
                             cgminit=args.cgminit, cgmaxit=args.cgmaxit,
                             cgverbose=args.cgverbose, pmtol=args.pmtol,
                             pmmaxit=args.pmmaxit, pmverbose=args.pmverbose)
        else:
            raise ValueError("Unknown deconvolution mode %s" %
                             args.deconv_mode)

        # get residual
        if redo_dirty:
            # Need to do this if weights or Jones has changed
            # (eg. if we change robustness factor, reweight or calibrate)
            psf = R.make_psf()
            wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf),
                            axis=1)
            wsum = np.sum(wsums)
            psf /= wsum
            dirty = R.make_dirty() / wsum

        # compute in image space
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        residual_mfs = np.sum(residual, axis=0)

        # save current iteration
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_major' + str(i + 1) + '_model_mfs.fits',
                  model_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_model.fits',
                  model, hdr)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual_mfs.fits',
                  residual_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual.fits',
                  residual * wsum, hdr)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("At iteration %i peak of residual is %f, rms is %f, current "
              "eps is %f" % (i + 1, rmax, rms, eps), file=dest)

        if eps < args.tol:
            break

    if args.mop_flux:
        print("Mopping flux", file=dest)

        # vague Gaussian prior on x
        def hess(x):
            return mask(beam(R.convolve(mask(beam(x))))) / wsum + 1e-6 * x

        def M(x):
            return x / 1e-6  # preconditioner

        x = pcg(hess, mask(beam(residual)),
                np.zeros(residual.shape, dtype=residual.dtype), M=M,
                tol=0.1 * args.cgtol, maxit=args.cgmaxit, minit=args.cgminit,
                verbosity=args.cgverbose)

        model += x
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        save_fits(args.outfile + '_mopped_model.fits', model, hdr)
        save_fits(args.outfile + '_mopped_residual.fits', residual, hdr)
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_mopped_model_mfs.fits', model_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
        save_fits(args.outfile + '_mopped_residual_mfs.fits', residual_mfs,
                  hdr_mfs)

        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After mopping flux peak of residual is %f, rms is %f" %
              (rmax, rms), file=dest)

    # if args.interp_model:
    #     nband = args.nband
    #     order = args.spectral_poly_order
    #     phi.trim_fat(model)
    #     I = np.argwhere(phi.mask).squeeze()
    #     Ix = I[:, 0]
    #     Iy = I[:, 1]
    #     npix = I.shape[0]

    #     # get components
    #     beta = model[:, Ix, Iy]

    #     # fit integrated polynomial to model components
    #     # we are given frequencies at bin centers, convert to bin edges
    #     ref_freq = np.mean(freq_out)
    #     delta_freq = freq_out[1] - freq_out[0]
    #     wlow = (freq_out - delta_freq/2.0)/ref_freq
    #     whigh = (freq_out + delta_freq/2.0)/ref_freq
    #     wdiff = whigh - wlow

    #     # set design matrix for each component
    #     Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
    #     for i in range(1, args.spectral_poly_order+1):
    #         Xdesign[:, i-1] = (whigh**i - wlow**i)/(i*wdiff)

    #     weights = psf_max[:, None]
    #     dirty_comps = Xdesign.T.dot(weights*beta)
    #     hess_comps = Xdesign.T.dot(weights*Xdesign)
    #     comps = np.linalg.solve(hess_comps, dirty_comps)
    #     np.savez(args.outfile + "spectral_comps", comps=comps,
    #              ref_freq=ref_freq, mask=np.any(model, axis=0))

    if args.write_model:
        print("Writing model", file=dest)
        R.write_model(model)

    if args.make_restored:
        print("Making restored", file=dest)
        cpsfo = PSF(cpsf, residual.shape, nthreads=args.nthreads)
        restored = cpsfo.convolve(model)

        # residual needs to be in Jy/beam before adding to convolved model
        wsums = np.amax(psf.reshape(-1, R.nx_psf * R.ny_psf), axis=1)
        restored += residual / wsums[:, None, None]

        save_fits(args.outfile + '_restored.fits', restored, hdr)
        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
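# Minimal runnable sketch of the "mop flux" solve above (illustration only,
# with a small SPD matrix standing in for the convolution Hessian and the
# preconditioner omitted for brevity): plain conjugate gradients on
# (H + 1e-6 I) x = r, i.e. a vague Gaussian prior on the update.
def _example_mop_flux_cg():
    import numpy as np
    rng = np.random.default_rng(42)
    A = rng.standard_normal((20, 20))
    H = A @ A.T + 20 * np.eye(20)   # SPD stand-in for the Hessian
    r = rng.standard_normal(20)     # stand-in for the masked residual

    def hess(x):
        return H @ x + 1e-6 * x     # Tikhonov term matches hess() above

    x = np.zeros(20)
    res = r - hess(x)
    p = res.copy()
    rsold = res @ res
    for _ in range(200):
        Hp = hess(p)
        alpha = rsold / (p @ Hp)
        x += alpha * p
        res -= alpha * Hp
        rsnew = res @ res
        if np.sqrt(rsnew) < 1e-10:
            break
        p = res + (rsnew / rsold) * p
        rsold = rsnew
    return x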
def spotless(psf, model, residual, mask=None, beam_image=None, hessian=None,
             wsum=1, adapt_sig21=False, cpsf=None, hdr=None, hdr_mfs=None,
             outfile=None, nthreads=1, sig_21=1e-3, sigma_frac=100, maxit=10,
             tol=1e-4, peak_factor=0.01, threshold=0.0, positivity=True,
             gamma=0.9999,
             hbgamma=0.1, hbpf=0.1, hbmaxit=5000, hbverbose=1,
             pdtol=1e-4, pdmaxit=250, pdverbose=1,  # primal dual options
             cgtol=1e-4, cgminit=15, cgmaxit=150, cgverbose=1,  # pcg options
             pmtol=1e-4, pmmaxit=50, pmverbose=1):  # power method options
    """
    Modified clean algorithm:

    psf      - PSF image i.e. R.H W where W contains the weights.
               Shape must be >= residual.shape
    model    - current intrinsic model
    residual - apparent residual image i.e. R.H W (V - R A x)

    Note that peak finding happens in the apparent residual because that is
    where it is easiest to accommodate convolution by the PSF. However, the
    beam and the mask have to be applied to the residual before we solve
    for the pre-conditioned updates.
    """
    if len(residual.shape) > 3:
        raise ValueError("Residual must have shape (nband, nx, ny)")

    nband, nx, ny = residual.shape

    if beam_image is None:
        def beam(x):
            return x

        def beaminv(x):
            return x
    else:
        try:
            assert beam_image.shape == (nband, nx, ny)

            def beam(x):
                return beam_image * x

            def beaminv(x):
                return np.where(beam_image > 0.01, x / beam_image, x)
        except BaseException:
            raise ValueError("Beam has incorrect shape")

    if mask is None:
        def mask(x):
            return x
    else:
        try:
            if mask.ndim == 2:
                assert mask.shape == (nx, ny)
                # keep the array under a separate name so the mask closure
                # does not refer to itself once the name is rebound
                mask_array = mask

                def mask(x):
                    return mask_array[None] * x
            elif mask.ndim == 3:
                assert mask.shape == (1, nx, ny)
                mask_array = mask

                def mask(x):
                    return mask_array * x
            else:
                raise ValueError
        except BaseException:
            raise ValueError("Mask has incorrect shape")

    # PSF operator
    psfo = PSF(psf, residual.shape, nthreads=nthreads,
               backward_undersize=1.2)

    # set up point sources
    phi = Dirac(nband, nx, ny, mask=np.any(model, axis=0))
    dual = np.zeros((nband, nx, ny), dtype=np.float64)

    # clean beam
    if cpsf is not None:
        try:
            assert cpsf.shape == (1,) + psf.shape[1::]
        except Exception as e:
            cpsf = cpsf[None, :, :]
        cpsfo = PSF(cpsf, residual.shape, nthreads=nthreads)

    residual_mfs = np.sum(residual, axis=0)
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)

    # preconditioning operator
    varmap = np.ones(model.shape) * (sigma_frac * rmax)

    def hessb(x):
        return phi.hdot(mask(beam(psfo.convolveb(mask(beam(phi.dot(x))))))) +\
            x / varmap

    def hessf(x):
        return phi.hdot(mask(beam(psfo.convolve(mask(beam(phi.dot(x))))))) +\
            x / varmap

    beta, betavec = power_method(hessb, residual.shape, tol=pmtol,
                                 maxit=pmmaxit, verbosity=pmverbose)

    if hessian is None:
        hessian = psfo.convolve
        wsum = 1.0

    if model.any():
        dirty = residual + hessian(mask(beam(model))) / wsum
    else:
        dirty = residual

    # deconvolve
    threshold = np.maximum(peak_factor * rmax, threshold)
    alpha = sig_21
    for i in range(0, maxit):
        # find point source candidates
        modelu = hogbom(mask(residual), psf, gamma=hbgamma, pf=hbpf,
                        maxit=hbmaxit, verbosity=hbverbose)

        phi.update_locs(modelu)

        # solve for beta updates
        x = pcg(hessf, phi.hdot(mask(beam(residual))),
                phi.hdot(beaminv(modelu)),
                M=lambda x: x * (sigma_frac * rmax), tol=cgtol,
                maxit=cgmaxit, minit=cgminit, verbosity=cgverbose)

        modelp = model.copy()
        model += gamma * x

        weights_21 = np.where(phi.mask,
                              alpha / (alpha + np.abs(np.mean(modelp, axis=0))),
                              1e10)  # 1e10 for effective infinity

        beta, betavec = power_method(hessb, model.shape, b0=betavec,
                                     tol=pmtol, maxit=pmmaxit,
                                     verbosity=pmverbose)

        model, dual = primal_dual(hessb, model, modelp, dual, sig_21, phi,
                                  weights_21, beta, prox_21m, tol=pdtol,
                                  maxit=pdmaxit, axis=0,
                                  positivity=positivity, report_freq=50,
                                  verbosity=pdverbose)

        # update Dirac dictionary (remove zero components)
        phi.trim_fat(model)
        residual, residual_mfs = resid_func(model, dirty, hessian, mask,
                                            beam, wsum)

        model_mfs = np.mean(model, axis=0)

        # check stopping criteria
        rmax = np.abs(mask(residual_mfs)).max()
        rms = np.std(mask(residual_mfs))
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        # update variance map (positivity constraint optional)
        varmap = np.maximum(rmax * sigma_frac, sigma_frac * (rmax + model))

        print("Iter %i: peak residual = %f, rms = %f, eps = %f" %
              (i + 1, rmax, rms, eps), file=log)

        # save current iteration
        if outfile is not None:
            assert hdr is not None
            assert hdr_mfs is not None

            save_fits(outfile + str(i + 1) + '_model_mfs.fits',
                      model_mfs, hdr_mfs)

            save_fits(outfile + str(i + 1) + '_model.fits',
                      model, hdr)

            save_fits(outfile + str(i + 1) + '_update.fits',
                      x, hdr)

            save_fits(outfile + str(i + 1) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

            save_fits(outfile + str(i + 1) + '_residual.fits',
                      residual * wsum, hdr)

        if rmax < threshold or eps < tol:
            print("Success, convergence after %i iterations" % (i + 1),
                  file=log)
            break

        if adapt_sig21:
            # sig_21 should be set to the std of the image noise
            from scipy.stats import skew, kurtosis
            alpha = rms
            tmp = residual_mfs
            z = tmp / alpha
            k = 0
            while (np.abs(skew(z.ravel(), nan_policy='omit')) > 0.05 or
                   np.abs(kurtosis(z.ravel(), fisher=True,
                                   nan_policy='omit')) > 0.5) and k < 10:
                # eliminate outliers
                tmp = np.where(np.abs(z) < 3, residual_mfs, np.nan)
                alpha = np.nanstd(tmp)
                z = tmp / alpha
                print(alpha, skew(z.ravel(), nan_policy='omit'),
                      kurtosis(z.ravel(), fisher=True, nan_policy='omit'))
                k += 1

            sig_21 = alpha
            print("alpha set to %f" % (alpha), file=log)

    return model
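# Runnable sketch of the sigma-clipping idea in the adapt_sig21 branch above
# (illustration only, on synthetic data): iteratively discard >3-sigma pixels
# until the remainder looks Gaussian, giving a robust noise estimate that is
# not biased by bright residual sources.
def _example_robust_sigma():
    import numpy as np
    from scipy.stats import skew, kurtosis
    rng = np.random.default_rng(0)
    resid = rng.normal(0.0, 1.0, 10000)
    resid[:50] += 50.0  # a few bright "sources" contaminate the rms
    alpha = np.std(resid)
    z = resid / alpha
    for _ in range(10):
        if (abs(skew(z, nan_policy='omit')) <= 0.05 and
                abs(kurtosis(z, fisher=True, nan_policy='omit')) <= 0.5):
            break
        resid = np.where(np.abs(z) < 3, resid, np.nan)
        alpha = np.nanstd(resid)
        z = resid / alpha
    return alpha  # close to the true noise std of 1.0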
def _forward(**kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading residual", file=log)
    residual = load_fits(args.residual, dtype=args.output_type).squeeze()
    nband, nx, ny = residual.shape
    hdr = fits.getheader(args.residual)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    wsums = np.amax(psf.reshape(-1, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    assert (psf_mfs.max() - 1.0) < 1e-4

    residual /= wsum
    residual_mfs = np.sum(residual, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m
    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs,
              hdr_mfs, dtype=args.output_type)

    rms = np.std(residual_mfs)
    rmax = np.abs(residual_mfs).max()

    print("Initial peak residual = %f, rms = %f" % (rmax, rms), file=log)

    # load beam
    if args.beam_model is not None:
        if args.beam_model.endswith('.fits'):  # beam already interpolated
            bhdr = fits.getheader(args.beam_model)
            l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
            l_coord_beam -= ref_lb
            if not np.array_equal(l_coord_beam, l_coord):
                raise ValueError("l coordinates of beam model do not match "
                                 "those of image. Use power_beam_maker to "
                                 "interpolate to fits header.")

            m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
            m_coord_beam -= ref_mb
            if not np.array_equal(m_coord_beam, m_coord):
                raise ValueError("m coordinates of beam model do not match "
                                 "those of image. Use power_beam_maker to "
                                 "interpolate to fits header.")

            if bhdr["CTYPE4"].lower() == 'freq':
                freq_axis = 4
            elif bhdr["CTYPE3"].lower() == 'freq':
                freq_axis = 3
            else:
                raise ValueError("Freq axis must be 3rd or 4th")
            freq_beam, _ = data_from_header(bhdr, axis=freq_axis)
            if not np.array_equal(freq_out, freq_beam):
                raise ValueError("Freqs of beam model do not match those of "
                                 "image. Use power_beam_maker to interpolate "
                                 "to fits header.")

            beam_image = load_fits(args.beam_model,
                                   dtype=args.output_type).squeeze()
        elif args.beam_model.lower() == "jimbeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            elif args.band.lower() == 'uhf':
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            else:
                raise ValueError("Unknown band %s" % args.band)

            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
            beam_image = np.zeros(residual.shape, dtype=args.output_type)
            for v in range(freq_out.size):
                # freq must be in MHz
                beam_image[v] = beam.I(xx, yy,
                                       freq_out[v] / 1e6).astype(args.output_type)
    else:
        beam_image = np.ones((nband, nx, ny), dtype=args.output_type)

    if args.mask is not None:
        mask = load_fits(args.mask).squeeze()
        assert mask.shape == (nx, ny)
        beam_image *= mask[None, :, :]

    beam_image = da.from_array(beam_image, chunks=(1, -1, -1))

    # if weight table is provided we use the vis space Hessian approximation
    if args.weight_table is not None:
        print("Solving for update using vis space approximation", file=log)
        normfact = wsum
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freq = xds.chan.data
        nchan = freq.size

        # bin edges
        fmin = freq.min()
        fmax = freq.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_mapping = {}
        chan_chunks = {}
        freq_bin_idx = {}
        freq_bin_counts = {}
        band_map = np.zeros(freq.size, dtype=np.int32)
        for band in range(nband):
            indl = freq >= fbins[band]
            indu = freq < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freq = da.from_array(freq, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit / args.nworkers,
                                       band_size, nrow, memory_per_row,
                                       args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row,
                                       args.nthreads_per_worker)

        print("nrows = %i, row chunks set to %i for a total of %i chunks "
              "per node" %
              (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1),
                      dtype=residual.dtype)

        xds = xds_from_zarr(args.weight_table,
                            chunks={'row': -1,  # row_chunk,
                                    'chan': bin_counts})[0]

        from pfb.opt.pcg import pcg_wgt
        model = pcg_wgt(xds.UVW.data,
                        xds.WEIGHT.data.astype(args.output_type),
                        residual, x0, beam_image, freq, freq_bin_idx,
                        freq_bin_counts, cell_rad, args.wstack, args.epsilon,
                        args.double_accum, args.nvthreads, args.sigmainv,
                        wsum, args.cg_tol, args.cg_maxit, args.cg_minit,
                        args.cg_verbose, args.cg_report_freq,
                        args.backtrack).compute()
    else:
        # we use the image space approximation
        print("Solving for update using image space approximation", file=log)
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        unpad_x = slice(npad_xl, -npad_xr)
        unpad_y = slice(npad_yl, -npad_yr)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        psfhat = r2c(psf_pad, axes=(1, 2), forward=True,
                     nthreads=args.nvthreads, inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))
        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1))

        from pfb.opt.pcg import pcg_psf
        model = pcg_psf(psfhat, residual, x0, beam_image, args.sigmainv,
                        args.nvthreads, padding, unpad_x, unpad_y, lastsize,
                        args.cg_tol, args.cg_maxit, args.cg_minit,
                        args.cg_verbose, args.cg_report_freq,
                        args.backtrack).compute()

    print("Saving results", file=log)
    save_fits(args.output_filename + '_update.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_update_mfs.fits', model_mfs, hdr_mfs)

    print("All done here.", file=log)
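# Runnable numpy-only sketch of the zero-padded FFT convolution that the
# padding / unpad_x / unpad_y bookkeeping above sets up (illustration, not
# the ducc0-based implementation): pad the image to the PSF size, multiply
# spectra, then crop back. A delta-function PSF makes the round trip an
# identity, which is easy to verify.
def _example_fft_convolve():
    import numpy as np
    nx, ny = 16, 16
    nx_psf, ny_psf = 32, 32
    image = np.zeros((nx, ny))
    image[8, 8] = 1.0
    psf = np.zeros((nx_psf, ny_psf))
    psf[nx_psf // 2, ny_psf // 2] = 1.0  # delta PSF => identity convolution
    npad_xl = (nx_psf - nx) // 2
    npad_xr = nx_psf - nx - npad_xl
    npad_yl = (ny_psf - ny) // 2
    npad_yr = ny_psf - ny - npad_yl
    impad = np.pad(image, ((npad_xl, npad_xr), (npad_yl, npad_yr)))
    psfhat = np.fft.rfft2(np.fft.ifftshift(psf))  # PSF centred at the origin
    conv = np.fft.irfft2(np.fft.rfft2(impad) * psfhat, s=psf.shape)
    return conv[npad_xl:-npad_xr, npad_yl:-npad_yr]  # approximately == image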
def _spifit(**kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import dask.array as da
    import numpy as np
    from astropy.io import fits
    from africanus.model.spi.dask import fit_spi_components
    from pfb.utils.fits import load_fits, save_fits, data_from_header, set_wcs
    from pfb.utils.misc import convolve2gaussres

    # get max gausspars
    gaussparf = None
    if args.psf_pars is None:
        if args.residual is None:
            ppsource = args.image
        else:
            ppsource = args.residual

        for image in ppsource:
            try:
                pphdr = fits.getheader(image)
            except Exception as e:
                raise e

            if 'BMAJ0' in pphdr.keys():
                emaj = pphdr['BMAJ0']
                emin = pphdr['BMIN0']
                pa = pphdr['BPA0']
                gausspars = [emaj, emin, pa]
                freq_idx0 = 0
            elif 'BMAJ1' in pphdr.keys():
                emaj = pphdr['BMAJ1']
                emin = pphdr['BMIN1']
                pa = pphdr['BPA1']
                gausspars = [emaj, emin, pa]
                freq_idx0 = 1
            elif 'BMAJ' in pphdr.keys():
                emaj = pphdr['BMAJ']
                emin = pphdr['BMIN']
                pa = pphdr['BPA']
                gausspars = [emaj, emin, pa]
                freq_idx0 = 0
            else:
                raise ValueError("No beam parameters found in residual. "
                                 "You will have to provide them manually.")

            if gaussparf is None:
                gaussparf = gausspars
            else:
                # we need to take the max in both directions
                gaussparf[0] = np.maximum(gaussparf[0], gausspars[0])
                gaussparf[1] = np.maximum(gaussparf[1], gausspars[1])
    else:
        freq_idx0 = 0  # assumption
        gaussparf = list(args.psf_pars)

    if args.circ_psf:
        e = np.maximum(gaussparf[0], gaussparf[1])
        gaussparf[0] = e
        gaussparf[1] = e
        gaussparf[2] = 0.0

    gaussparf = tuple(gaussparf)
    print("Using emaj = %3.2e, emin = %3.2e, PA = %3.2e \n" % gaussparf,
          file=log)

    # get required data products
    image_dict = {}
    for i in range(len(args.image)):
        image_dict[i] = {}

        # load model image
        model = load_fits(args.image[i], dtype=args.out_dtype).squeeze()
        mhdr = fits.getheader(args.image[i])

        if model.ndim < 3:
            model = model[None, :, :]

        l_coord, ref_l = data_from_header(mhdr, axis=1)
        l_coord -= ref_l
        m_coord, ref_m = data_from_header(mhdr, axis=2)
        m_coord -= ref_m

        if mhdr["CTYPE4"].lower() == 'freq':
            freq_axis = 4
            stokes_axis = 3
        elif mhdr["CTYPE3"].lower() == 'freq':
            freq_axis = 3
            stokes_axis = 4
        else:
            raise ValueError("Freq axis must be 3rd or 4th")

        freqs, ref_freq = data_from_header(mhdr, axis=freq_axis)

        image_dict[i]['freqs'] = freqs

        nband = freqs.size
        npix_l = l_coord.size
        npix_m = m_coord.size

        xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

        # load beam
        if args.beam_model is not None:
            bhdr = fits.getheader(args.beam_model[i])
            l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
            l_coord_beam -= ref_lb
            if not np.array_equal(l_coord_beam, l_coord):
                raise ValueError("l coordinates of beam model do not match "
                                 "those of image. Use binterp to make "
                                 "compatible beam images")

            m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
            m_coord_beam -= ref_mb
            if not np.array_equal(m_coord_beam, m_coord):
                raise ValueError("m coordinates of beam model do not match "
                                 "those of image. Use binterp to make "
                                 "compatible beam images")

            freqs_beam, _ = data_from_header(bhdr, axis=freq_axis)
            if not np.array_equal(freqs, freqs_beam):
                raise ValueError("Freq coordinates of beam model do not "
                                 "match those of image. Use binterp to make "
                                 "compatible beam images")

            beam_image = load_fits(args.beam_model[i],
                                   dtype=args.out_dtype).squeeze()

            if beam_image.ndim < 3:
                beam_image = beam_image[None, :, :]
        else:
            beam_image = np.ones(model.shape, dtype=args.out_dtype)

        image_dict[i]['beam'] = beam_image

        if not args.dont_convolve:
            print("Convolving model %i" % i, file=log)
            # convolve model to desired resolution
            model, gausskern = convolve2gaussres(model, xx, yy, gaussparf,
                                                 args.nthreads, None,
                                                 args.padding_frac)

        image_dict[i]['model'] = model

        # add in residuals and set threshold
        if args.residual is not None:
            msg = "of residual do not match those of model"
            rhdr = fits.getheader(args.residual[i])
            l_res, ref_lb = data_from_header(rhdr, axis=1)
            l_res -= ref_lb
            if not np.array_equal(l_res, l_coord):
                raise ValueError("l coordinates " + msg)

            m_res, ref_mb = data_from_header(rhdr, axis=2)
            m_res -= ref_mb
            if not np.array_equal(m_res, m_coord):
                raise ValueError("m coordinates " + msg)

            freqs_res, _ = data_from_header(rhdr, axis=freq_axis)
            if not np.array_equal(freqs, freqs_res):
                raise ValueError("Freqs " + msg)

            resid = load_fits(args.residual[i],
                              dtype=args.out_dtype).squeeze()
            if resid.ndim < 3:
                resid = resid[None, :, :]

            # convolve residual to same resolution as model
            gausspari = ()
            for b in range(nband):
                key = 'BMAJ' + str(b + freq_idx0)
                if key in rhdr.keys():
                    emaj = rhdr['BMAJ' + str(b + freq_idx0)]
                    emin = rhdr['BMIN' + str(b + freq_idx0)]
                    pa = rhdr['BPA' + str(b + freq_idx0)]
                    gausspari += ((emaj, emin, pa),)
                elif 'BMAJ' in rhdr.keys():
                    emaj = rhdr['BMAJ']
                    emin = rhdr['BMIN']
                    pa = rhdr['BPA']
                    gausspari += ((emaj, emin, pa),)
                else:
                    print("Can't find Gausspars in residual header, "
                          "unable to add residuals back in", file=log)
                    gausspari = None
                    break

            if gausspari is not None and args.add_convolved_residuals:
                print("Convolving residuals %i" % i, file=log)
                resid, _ = convolve2gaussres(resid, xx, yy, gaussparf,
                                             args.nthreads, gausspari,
                                             args.padding_frac,
                                             norm_kernel=False)
                model += resid
                print("Convolved residuals added to convolved model %i" % i,
                      file=log)

            image_dict[i]['resid'] = resid
        else:
            image_dict[i]['resid'] = None

    # concatenate images along frequency here
    freqs = []
    model = []
    beam_image = []
    resid = []
    for i in image_dict.keys():
        freqs.append(image_dict[i]['freqs'])
        model.append(image_dict[i]['model'])
        beam_image.append(image_dict[i]['beam'])
        resid.append(image_dict[i]['resid'])

    freqs = np.concatenate(freqs, axis=0)
    Isort = np.argsort(freqs)
    freqs = freqs[Isort]

    model = np.concatenate(model, axis=0)
    model = model[Isort]

    # create header
    cell_deg = mhdr['CDELT1']
    ra = np.deg2rad(mhdr['CRVAL1'])
    dec = np.deg2rad(mhdr['CRVAL2'])
    radec = [ra, dec]
    nband, nx, ny = model.shape
    hdr = set_wcs(cell_deg, cell_deg, nx, ny, radec, freqs)
    for i in range(1, nband + 1):
        hdr['BMAJ' + str(i)] = gaussparf[0]
        hdr['BMIN' + str(i)] = gaussparf[1]
        hdr['BPA' + str(i)] = gaussparf[2]

    if args.ref_freq is None:
        ref_freq = np.mean(freqs)
    else:
        ref_freq = args.ref_freq

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)
    hdr_mfs['BMAJ'] = gaussparf[0]
    hdr_mfs['BMIN'] = gaussparf[1]
    hdr_mfs['BPA'] = gaussparf[2]

    # save convolved model
    if 'm' in args.products:
        name = args.output_filename + '.convolved_model.fits'
        save_fits(name, model, hdr, dtype=args.out_dtype)
        print("Wrote convolved model to %s" % name, file=log)

    beam_image = np.concatenate(beam_image, axis=0)
    beam_image = beam_image[Isort]

    if 'b' in args.products:
        name = args.output_filename + '.power_beam.fits'
        save_fits(name, beam_image, hdr, dtype=args.out_dtype)
        print("Wrote average power beam to %s" % name, file=log)

    if resid[0] is not None:
        resid = np.concatenate(resid, axis=0)
        resid = resid[Isort]

        if 'r' in args.products:
            name = args.output_filename + '.convolved_residual.fits'
            save_fits(name, resid, hdr, dtype=args.out_dtype)
            print("Wrote convolved residuals to %s" % name, file=log)

        # get threshold
        counts = np.sum(resid != 0)
        rms = np.sqrt(np.sum(resid**2) / counts)
        rms_cube = np.std(resid.reshape(nband, npix_l * npix_m),
                          axis=1).ravel()
        threshold = args.threshold * rms
    else:
        print("No residual provided. Setting threshold i.t.o dynamic range. "
              "Max dynamic range is %i " % args.maxdr, file=log)
        threshold = model.max() / args.maxdr
        rms_cube = None

    print("Threshold set to %f Jy. \n" % threshold, file=log)

    # beam cut off
    beam_min = np.amin(beam_image, axis=0)
    model = np.where(beam_min[None] > args.pb_min, model, 0.0)

    # get pixels above threshold
    minimage = np.amin(model, axis=0)
    maskindices = np.argwhere(minimage > threshold)
    nanindices = np.argwhere(minimage <= threshold)
    if not maskindices.size:
        raise ValueError("No components found above threshold. "
                         "Try lowering your threshold. "
                         "Max of convolved model is %3.2e" % model.max())

    fitcube = model[:, maskindices[:, 0], maskindices[:, 1]].T
    beam_comps = beam_image[:, maskindices[:, 0], maskindices[:, 1]].T

    # set weights for fit
    if rms_cube is not None:
        print("Using RMS in each imaging band to determine weights.",
              file=log)
        weights = np.where(rms_cube > 0, 1.0 / rms_cube**2, 0.0)
        # normalise
        weights /= weights.max()
    else:
        if args.band_weights is not None:
            weights = np.array(args.band_weights)
            try:
                assert weights.size == nband
            except Exception as e:
                raise ValueError("Inconsistent weights provided.")
            print("Using provided channel weights.", file=log)
        else:
            print("No residual or channel weights provided. "
                  "Using equal weights.", file=log)
            weights = np.ones(nband, dtype=np.float64)

    ncomps, _ = fitcube.shape
    fitcube = da.from_array(fitcube.astype(np.float64),
                            chunks=(ncomps // args.nthreads, nband))
    beam_comps = da.from_array(beam_comps.astype(np.float64),
                               chunks=(ncomps // args.nthreads, nband))
    weights = da.from_array(weights.astype(np.float64), chunks=(nband))
    freqsdask = da.from_array(freqs.astype(np.float64), chunks=(nband))

    print("Fitting %i components" % ncomps, file=log)
    alpha, alpha_err, Iref, i0_err = fit_spi_components(
        fitcube, weights, freqsdask, np.float64(ref_freq),
        beam=beam_comps).compute()
    print("Done. Writing output.", file=log)

    alphamap = np.zeros(model[0].shape, dtype=model.dtype)
    alphamap[...] = np.nan
    alpha_err_map = np.zeros(model[0].shape, dtype=model.dtype)
    alpha_err_map[...] = np.nan
    i0map = np.zeros(model[0].shape, dtype=model.dtype)
    i0map[...] = np.nan
    i0_err_map = np.zeros(model[0].shape, dtype=model.dtype)
    i0_err_map[...] = np.nan

    alphamap[maskindices[:, 0], maskindices[:, 1]] = alpha
    alpha_err_map[maskindices[:, 0], maskindices[:, 1]] = alpha_err
    i0map[maskindices[:, 0], maskindices[:, 1]] = Iref
    i0_err_map[maskindices[:, 0], maskindices[:, 1]] = i0_err

    if 'I' in args.products:
        # get the reconstructed cube
        Irec_cube = i0map[None, :, :] * \
            (freqs[:, None, None] / ref_freq)**alphamap[None, :, :]
        name = args.output_filename + '.Irec_cube.fits'
        save_fits(name, Irec_cube, hdr, dtype=args.out_dtype)
        print("Wrote reconstructed cube to %s" % name, file=log)

    # save alpha map
    if 'a' in args.products:
        name = args.output_filename + '.alpha.fits'
        save_fits(name, alphamap, hdr_mfs, dtype=args.out_dtype)
        print("Wrote alpha map to %s" % name, file=log)

    # save alpha error map
    if 'e' in args.products:
        name = args.output_filename + '.alpha_err.fits'
        save_fits(name, alpha_err_map, hdr_mfs, dtype=args.out_dtype)
        print("Wrote alpha error map to %s" % name, file=log)

    # save I0 map
    if 'i' in args.products:
        name = args.output_filename + '.I0.fits'
        save_fits(name, i0map, hdr_mfs, dtype=args.out_dtype)
        print("Wrote I0 map to %s" % name, file=log)

    # save I0 error map
    if 'k' in args.products:
        name = args.output_filename + '.I0_err.fits'
        save_fits(name, i0_err_map, hdr_mfs, dtype=args.out_dtype)
        print("Wrote I0 error map to %s" % name, file=log)

    print("All done here", file=log)
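# Runnable sketch of the power-law model being fitted above (illustration,
# not africanus' fit_spi_components implementation):
# I(nu) = I0 * (nu / nu_ref)**alpha is linear in log-space, so a quick
# least-squares fit on synthetic noise-free data recovers alpha and I0.
def _example_spi_fit():
    import numpy as np
    nu = np.linspace(0.9e9, 1.5e9, 8)
    nu_ref = nu.mean()
    alpha_true, i0_true = -0.7, 2.0
    flux = i0_true * (nu / nu_ref)**alpha_true
    # log(I) = alpha * log(nu/nu_ref) + log(I0)
    A = np.vstack([np.log(nu / nu_ref), np.ones_like(nu)]).T
    alpha_fit, logi0 = np.linalg.lstsq(A, np.log(flux), rcond=None)[0]
    return alpha_fit, np.exp(logi0)  # approximately (-0.7, 2.0)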
def nnls(**kw):
    '''
    Minor cycle implementing non-negative least squares
    '''
    from omegaconf import OmegaConf
    from functools import partial
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    print('Input Options:', file=log)
    for key in kw.keys():
        print(' %25s = %s' % (key, kw[key]), file=log)

    from pfb.utils.fits import load_fits
    from astropy.io import fits
    import numpy as np

    def resid_func(x, dirty, psfo):
        """
        Returns the unattenuated residual
        """
        residual = dirty - psfo.convolve(x)
        residual_mfs = np.sum(residual, axis=0)
        return residual, residual_mfs

    def value_and_grad(x, dirty, psfo):
        model_conv = psfo.convolve(x)
        return np.vdot(x, model_conv - 2 * dirty), 2 * (model_conv - dirty)

    def prox(x):
        x[x < args.min_value] = 0.0
        return x

    dirty = load_fits(args.dirty).squeeze()
    nband, nx, ny = dirty.shape
    hdr = fits.getheader(args.dirty)

    psf = load_fits(args.psf).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    wsums = np.amax(psf.reshape(-1, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    assert (psf_mfs.max() - 1.0) < 1e-4

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)

    from pfb.operators.psf import PSF
    psfo = PSF(psf, dirty.shape, nthreads=args.nthreads)

    from pfb.opt.power_method import power_method
    beta, betavec = power_method(psfo.convolve, dirty.shape,
                                 tol=args.pm_tol, maxit=args.pm_maxit,
                                 verbosity=args.pm_verbose,
                                 report_freq=args.pm_report_freq)

    fprime = partial(value_and_grad, dirty=dirty, psfo=psfo)

    from pfb.opt.fista import fista
    if args.x0 is None:
        x0 = np.zeros_like(dirty)
    else:
        x0 = load_fits(args.x0, dtype=dirty.dtype).squeeze()

    model = fista(x0, beta, fprime, prox, tol=args.fista_tol,
                  maxit=args.fista_maxit, verbosity=args.fista_verbose,
                  report_freq=args.fista_report_freq)

    residual, residual_mfs = resid_func(model, dirty, psfo)

    from pfb.utils.fits import save_fits
    save_fits(args.output_filename + '_model.fits', model, hdr)
    save_fits(args.output_filename + '_residual.fits', residual, hdr)
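# Quick illustration (an assumption spelled out for clarity, not part of the
# worker) of why value_and_grad above corresponds to the quadratic data term:
# with H a symmetric convolution operator, f(x) = x.(Hx) - 2 x.dirty has
# gradient 2(Hx - dirty), the same stationary point as |dirty - Hx|^2 when
# H is idempotent-free of scaling issues. A finite-difference check on a
# small symmetric stand-in matrix:
def _example_value_and_grad_check():
    import numpy as np
    rng = np.random.default_rng(1)
    A = rng.standard_normal((6, 6))
    H = A @ A.T  # symmetric stand-in for PSF convolution
    dirty = rng.standard_normal(6)

    def value_and_grad(x):
        hx = H @ x
        return x @ (hx - 2 * dirty), 2 * (hx - dirty)

    x = rng.standard_normal(6)
    val, grad = value_and_grad(x)
    eps = 1e-6
    e0 = np.zeros(6)
    e0[0] = eps
    fd = (value_and_grad(x + e0)[0] - value_and_grad(x - e0)[0]) / (2 * eps)
    return np.isclose(fd, grad[0], rtol=1e-4)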
def nnls(psf, model, residual, mask=None, beam_image=None, hessian=None,
         wsum=None, gamma=0.95, hdr=None, hdr_mfs=None, outfile=None,
         nthreads=1, maxit=1, tol=1e-3, pmtol=1e-5, pmmaxit=50, pmverbose=1,
         ftol=1e-5, fmaxit=250, fverbose=3):

    if len(residual.shape) > 3:
        raise ValueError("Residual must have shape (nband, nx, ny)")

    nband, nx, ny = residual.shape

    if beam_image is None:
        def beam(x):
            return x
    else:
        try:
            assert beam_image.shape == (nband, nx, ny)

            def beam(x):
                return beam_image * x
        except BaseException:
            raise ValueError("Beam has incorrect shape")

    if mask is None:
        def mask(x):
            return x
    else:
        try:
            if mask.ndim == 2:
                assert mask.shape == (nx, ny)
                # keep the array under a separate name so the mask closure
                # does not refer to itself once the name is rebound
                mask_array = mask

                def mask(x):
                    return mask_array[None] * x
            elif mask.ndim == 3:
                assert mask.shape == (1, nx, ny)
                mask_array = mask

                def mask(x):
                    return mask_array * x
            else:
                raise ValueError
        except BaseException:
            raise ValueError("Mask has incorrect shape")

    # PSF operator
    psfo = PSF(psf, residual.shape, nthreads=nthreads)

    residual_mfs = np.sum(residual, axis=0)
    residual = mask(beam(residual))
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)

    if hessian is None:
        hessian = psfo.convolve
        wsum = 1

    def hess(x):
        return mask(beam(psfo.convolve(mask(beam(x)))))

    beta, betavec = power_method(hess, residual.shape, tol=pmtol,
                                 maxit=pmmaxit, verbosity=pmverbose)

    if model.any():
        dirty = residual + hessian(mask(beam(model))) / wsum
    else:
        dirty = residual

    for i in range(maxit):
        fprime = partial(value_and_grad, dirty=residual, psfo=psfo,
                         mask=mask, beam=beam)
        x = fista(np.zeros_like(model), beta, fprime, prox, tol=ftol,
                  maxit=fmaxit, verbosity=fverbose)

        modelp = model.copy()
        model += gamma * x
        residual, residual_mfs = resid_func(model, dirty, hessian, mask,
                                            beam, wsum)

        model_mfs = np.mean(model, axis=0)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("Iter %i: peak residual = %f, rms = %f, eps = %f" %
              (i + 1, rmax, rms, eps), file=log)

        # save current iteration
        if outfile is not None:
            assert hdr is not None
            assert hdr_mfs is not None

            save_fits(outfile + str(i + 1) + '_NNLS_model_mfs.fits',
                      model_mfs, hdr_mfs)

            save_fits(outfile + str(i + 1) + '_NNLS_model.fits',
                      model, hdr)

            save_fits(outfile + str(i + 1) + '_NNLS_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        if eps < tol:
            print("Success, convergence after %i iterations" % (i + 1),
                  file=log)
            break

    return model
def _binterp(**kw): args = OmegaConf.create(kw) OmegaConf.set_struct(args, True) from pfb.utils.fits import save_fits import dask import dask.array as da import numpy as np from numba import jit from astropy.io import fits import warnings from africanus.rime import parallactic_angles from pfb.utils.fits import load_fits, save_fits, data_from_header from daskms import xds_from_ms, xds_from_table if args.ms is None: if args.beam_model.lower() == 'jimbeam': for image in args.image: mhdr = fits.getheader(image) l_coord, ref_l = data_from_header(mhdr, axis=1) l_coord -= ref_l m_coord, ref_m = data_from_header(mhdr, axis=2) m_coord -= ref_m if mhdr["CTYPE4"].lower() == 'freq': freq_axis = 4 stokes_axis = 3 elif mhdr["CTYPE3"].lower() == 'freq': freq_axis = 3 stokes_axis = 4 else: raise ValueError("Freq axis must be 3rd or 4th") freq, ref_freq = data_from_header(mhdr, axis=freq_axis) from katbeam import JimBeam if args.band.lower() == 'l': beam = JimBeam('MKAT-AA-L-JIM-2020') elif args.band.lower() == 'uhf': beam = JimBeam('MKAT-AA-UHF-JIM-2020') else: raise ValueError("Unkown band %s" % args.band[i]) xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij') beam_image = np.zeros((freq.size, l_coord.size, m_coord.size), dtype=args.out_dtype) for v in range(freq.size): # freq must be in MHz beam_image[v] = beam.I(xx, yy, freq[v] / 1e6).astype( args.out_dtype) if args.output_dir in image: idx = len(args.output_dir) iname = image[idx::] outname = iname + '.' + args.postfix else: outname = image + '.' + args.postfix beam_image = np.expand_dims(beam_image, axis=3 - stokes_axis + 1) save_fits(args.output_dir + outname, beam_image, mhdr, dtype=args.out_dtype) else: raise NotImplementedError("Not there yet, sorry") print("All done here.", file=log) # @jit(nopython=True, nogil=True, cache=True) # def _unflagged_counts(flags, time_idx, out): # for i in range(time_idx.size): # ilow = time_idx[i] # ihigh = time_idx[i+1] # out[i] = np.sum(~flags[ilow:ihigh]) # return out # def extract_dde_info(args, freqs): # """ # Computes paralactic angles, antenna scaling and pointing information # required for beam interpolation. 
# """ # # get ms info required to compute paralactic angles and weighted sum # nband = freqs.size # if args.ms is not None: # utimes = [] # unflag_counts = [] # ant_pos = None # phase_dir = None # for ms_name in args.ms: # # get antenna positions # ant = xds_from_table(ms_name + '::ANTENNA')[0].compute() # if ant_pos is None: # ant_pos = ant['POSITION'].data # else: # check all are the same # tmp = ant['POSITION'] # if not np.array_equal(ant_pos, tmp): # raise ValueError( # "Antenna positions not the same across measurement sets") # # get phase center for field # field = xds_from_table(ms_name + '::FIELD')[0].compute() # if phase_dir is None: # phase_dir = field['PHASE_DIR'][args.field].data.squeeze() # else: # tmp = field['PHASE_DIR'][args.field].data.squeeze() # if not np.array_equal(phase_dir, tmp): # raise ValueError( # 'Phase direction not the same across measurement sets') # # get unique times and count flags # xds = xds_from_ms(ms_name, columns=["TIME", "FLAG_ROW"], group_cols=[ # "FIELD_ID"])[args.field] # utime, time_idx = np.unique( # xds.TIME.data.compute(), return_index=True) # ntime = utime.size # # extract subset of times # if args.sparsify_time > 1: # I = np.arange(0, ntime, args.sparsify_time) # utime = utime[I] # time_idx = time_idx[I] # ntime = utime.size # utimes.append(utime) # flags = xds.FLAG_ROW.data.compute() # unflag_count = _unflagged_counts(flags.astype( # np.int32), time_idx, np.zeros(ntime, dtype=np.int32)) # unflag_counts.append(unflag_count) # utimes = np.concatenate(utimes) # unflag_counts = np.concatenate(unflag_counts) # ntimes = utimes.size # # compute paralactic angles # parangles = parallactic_angles(utimes, ant_pos, phase_dir) # # mean over antanna nant -> 1 # parangles = np.mean(parangles, axis=1, keepdims=True) # nant = 1 # # beam_cube_dde requirements # ant_scale = np.ones((nant, nband, 2), dtype=np.float64) # point_errs = np.zeros((ntimes, nant, nband, 2), dtype=np.float64) # return (parangles, # da.from_array(ant_scale, chunks=ant_scale.shape), # point_errs, # unflag_counts, # True) # else: # ntimes = 1 # nant = 1 # parangles = np.zeros((ntimes, nant,), dtype=np.float64) # ant_scale = np.ones((nant, nband, 2), dtype=np.float64) # point_errs = np.zeros((ntimes, nant, nband, 2), dtype=np.float64) # unflag_counts = np.array([1]) # return (parangles, ant_scale, point_errs, unflag_counts, False) # def make_power_beam(args, lm_source, freqs, use_dask): # print("Loading fits beam patterns from %s" % args.beam_model) # from glob import glob # paths = glob(args.beam_model + '**_**.fits') # beam_hdr = None # if args.corr_type == 'linear': # corr1 = 'XX' # corr2 = 'YY' # elif args.corr_type == 'circular': # corr1 = 'LL' # corr2 = 'RR' # else: # raise KeyError( # "Unknown corr_type supplied. 
Only 'linear' or 'circular' supported") # for path in paths: # if corr1.lower() in path[-10::]: # if 're' in path[-7::]: # corr1_re = load_fits(path) # if beam_hdr is None: # beam_hdr = fits.getheader(path) # elif 'im' in path[-7::]: # corr1_im = load_fits(path) # else: # raise NotImplementedError("Only re/im patterns supported") # elif corr2.lower() in path[-10::]: # if 're' in path[-7::]: # corr2_re = load_fits(path) # elif 'im' in path[-7::]: # corr2_im = load_fits(path) # else: # raise NotImplementedError("Only re/im patterns supported") # # get power beam # beam_amp = (corr1_re**2 + corr1_im**2 + corr2_re**2 + corr2_im**2)/2.0 # # get cube in correct shape for interpolation code # beam_amp = np.ascontiguousarray(np.transpose(beam_amp, (1, 2, 0)) # [:, :, :, None, None]) # # get cube info # if beam_hdr['CUNIT1'].lower() != "deg": # raise ValueError("Beam image units must be in degrees") # npix_l = beam_hdr['NAXIS1'] # refpix_l = beam_hdr['CRPIX1'] # delta_l = beam_hdr['CDELT1'] # l_min = (1 - refpix_l)*delta_l # l_max = (1 + npix_l - refpix_l)*delta_l # if beam_hdr['CUNIT2'].lower() != "deg": # raise ValueError("Beam image units must be in degrees") # npix_m = beam_hdr['NAXIS2'] # refpix_m = beam_hdr['CRPIX2'] # delta_m = beam_hdr['CDELT2'] # m_min = (1 - refpix_m)*delta_m # m_max = (1 + npix_m - refpix_m)*delta_m # if (l_min > lm_source[:, 0].min() or m_min > lm_source[:, 1].min() or # l_max < lm_source[:, 0].max() or m_max < lm_source[:, 1].max()): # raise ValueError("The supplied beam is not large enough") # beam_extents = np.array([[l_min, l_max], [m_min, m_max]]) # # get frequencies # if beam_hdr["CTYPE3"].lower() != 'freq': # raise ValueError( # "Cubes are assumed to be in format [nchan, nx, ny]") # nchan = beam_hdr['NAXIS3'] # refpix = beam_hdr['CRPIX3'] # delta = beam_hdr['CDELT3'] # assumes units are Hz # freq0 = beam_hdr['CRVAL3'] # bfreqs = freq0 + np.arange(1 - refpix, 1 + nchan - refpix) * delta # if bfreqs[0] > freqs[0] or bfreqs[-1] < freqs[-1]: # warnings.warn("The supplied beam does not have sufficient " # "bandwidth. 
Beam frequencies:") # with np.printoptions(precision=2): # print(bfreqs) # if use_dask: # return (da.from_array(beam_amp, chunks=beam_amp.shape), # da.from_array(beam_extents, chunks=beam_extents.shape), # da.from_array(bfreqs, bfreqs.shape)) # else: # return beam_amp, beam_extents, bfreqs # def interpolate_beam(ll, mm, freqs, args): # """ # Interpolate beam to image coordinates and optionally compute average # over time if MS is provoded # """ # nband = freqs.size # print("Interpolating beam") # parangles, ant_scale, point_errs, unflag_counts, use_dask = extract_dde_info( # args, freqs) # lm_source = np.vstack((ll.ravel(), mm.ravel())).T # beam_amp, beam_extents, bfreqs = make_power_beam( # args, lm_source, freqs, use_dask) # # interpolate beam # if use_dask: # from africanus.rime.dask import beam_cube_dde # lm_source = da.from_array(lm_source, chunks=lm_source.shape) # freqs = da.from_array(freqs, chunks=freqs.shape) # # compute ncpu images at a time to avoid memory errors # ntimes = parangles.shape[0] # I = np.arange(0, ntimes, args.ncpu) # nchunks = I.size # I = np.append(I, ntimes) # beam_image = np.zeros((ll.size, 1, nband), dtype=beam_amp.dtype) # for i in range(nchunks): # ilow = I[i] # ihigh = I[i+1] # part_parangles = da.from_array( # parangles[ilow:ihigh], chunks=(1, 1)) # part_point_errs = da.from_array( # point_errs[ilow:ihigh], chunks=(1, 1, freqs.size, 2)) # # interpolate and remove redundant axes # part_beam_image = beam_cube_dde(beam_amp, beam_extents, bfreqs, # lm_source, part_parangles, part_point_errs, # ant_scale, freqs).compute()[:, :, 0, :, 0, 0] # # weighted sum over time # beam_image += np.sum(part_beam_image * # unflag_counts[None, ilow:ihigh, None], axis=1, keepdims=True) # # normalise by sum of weights # beam_image /= np.sum(unflag_counts) # # remove time axis # beam_image = beam_image[:, 0, :] # else: # from africanus.rime.fast_beam_cubes import beam_cube_dde # beam_image = beam_cube_dde(beam_amp, beam_extents, bfreqs, # lm_source, parangles, point_errs, # ant_scale, freqs).squeeze() # # swap source and freq axes and reshape to image shape # beam_source = np.transpose(beam_image, axes=(1, 0)) # return beam_source.squeeze().reshape((freqs.size, *ll.shape)) # def main(args): # # get coord info # hdr = fits.getheader(args.image) # l_coord, ref_l = data_from_header(hdr, axis=1) # l_coord -= ref_l # m_coord, ref_m = data_from_header(hdr, axis=2) # m_coord -= ref_m # if hdr["CTYPE4"].lower() == 'freq': # freq_axis = 4 # elif hdr["CTYPE3"].lower() == 'freq': # freq_axis = 3 # else: # raise ValueError("Freq axis must be 3rd or 4th") # freqs, ref_freq = data_from_header(hdr, axis=freq_axis) # xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij') # # interpolate primary beam to fits header and optionally average over time # beam_image = interpolate_beam(xx, yy, freqs, args) # # save power beam # save_fits(args.output_filename, beam_image, hdr) # print("Wrote interpolated beam cube to %s \n" % args.output_filename) # return
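
# The parametrisation below is an assumption (the original decorators are
# not shown in this file); pytest's tmp_path_factory fixture supplies the
# temporary directory and all four beam/gain combinations are exercised.
import pytest

@pytest.mark.parametrize("do_beam", [False, True])
@pytest.mark.parametrize("do_gains", [False, True])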
def test_forwardmodel(do_beam, do_gains, tmp_path_factory):
    test_dir = tmp_path_factory.mktemp("test_pfb")

    packratt.get('/test/ms/2021-06-24/elwood/test_ascii_1h60.0s.MS.tar',
                 str(test_dir))

    import numpy as np
    np.random.seed(420)
    from numpy.testing import assert_allclose
    from pyrap.tables import table

    ms = table(str(test_dir / 'test_ascii_1h60.0s.MS'), readonly=False)
    spw = table(str(test_dir / 'test_ascii_1h60.0s.MS::SPECTRAL_WINDOW'))

    utime = np.unique(ms.getcol('TIME'))
    freq = spw.getcol('CHAN_FREQ').squeeze()
    freq0 = np.mean(freq)

    ntime = utime.size
    nchan = freq.size
    nant = np.maximum(ms.getcol('ANTENNA1').max(),
                      ms.getcol('ANTENNA2').max()) + 1

    ncorr = ms.getcol('FLAG').shape[-1]

    uvw = ms.getcol('UVW')
    nrow = uvw.shape[0]
    u_max = abs(uvw[:, 0]).max()
    v_max = abs(uvw[:, 1]).max()
    uv_max = np.maximum(u_max, v_max)

    # image size
    from africanus.constants import c as lightspeed
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    srf = 2.0
    cell_rad = cell_N / srf
    cell_size = cell_rad * 180 / np.pi
    print("Cell size set to %5.5e arcseconds" % cell_size)

    fov = 2
    npix = int(fov / cell_size)
    if npix % 2:
        npix += 1

    nx = npix
    ny = npix

    print("Image size set to (%i, %i, %i)" % (nchan, nx, ny))

    # model
    model = np.zeros((nchan, nx, ny), dtype=np.float64)
    nsource = 10
    Ix = np.random.randint(0, npix, nsource)
    Iy = np.random.randint(0, npix, nsource)
    alpha = -0.7 + 0.1 * np.random.randn(nsource)
    I0 = 1.0 + np.abs(np.random.randn(nsource))
    for i in range(nsource):
        model[:, Ix[i], Iy[i]] = I0[i] * (freq / freq0)**alpha[i]

    if do_beam:
        # primary beam
        from katbeam import JimBeam
        beam = JimBeam('MKAT-AA-L-JIM-2020')
        l_coord = -np.arange(-(nx // 2), nx // 2) * cell_size
        m_coord = np.arange(-(ny // 2), ny // 2) * cell_size
        xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
        pbeam = np.zeros((nchan, nx, ny), dtype=np.float64)
        for i in range(nchan):
            pbeam[i] = beam.I(xx, yy, freq[i] / 1e6)  # freq in MHz
        model_att = pbeam * model
        bm = 'JimBeam'
    else:
        model_att = model
        bm = None

    # model vis
    from ducc0.wgridder import dirty2ms
    model_vis = np.zeros((nrow, nchan, ncorr), dtype=np.complex128)
    for c in range(nchan):
        model_vis[:, c:c + 1, 0] = dirty2ms(uvw, freq[c:c + 1], model_att[c],
                                            pixsize_x=cell_rad,
                                            pixsize_y=cell_rad,
                                            epsilon=1e-8,
                                            do_wstacking=True,
                                            nthreads=8)
        model_vis[:, c, -1] = model_vis[:, c, 0]

    ms.putcol('MODEL_DATA', model_vis.astype(np.complex64))

    if do_gains:
        t = (utime - utime.min()) / (utime.max() - utime.min())
        nu = 2.5 * (freq / freq0 - 1.0)

        from africanus.gps.utils import abs_diff
        tt = abs_diff(t, t)
        lt = 0.25
        Kt = 0.1 * np.exp(-tt**2 / (2 * lt**2))
        Lt = np.linalg.cholesky(Kt + 1e-10 * np.eye(ntime))
        vv = abs_diff(nu, nu)
        lv = 0.1
        Kv = 0.1 * np.exp(-vv**2 / (2 * lv**2))
        Lv = np.linalg.cholesky(Kv + 1e-10 * np.eye(nchan))
        L = (Lt, Lv)

        from pfb.utils.misc import kron_matvec
        jones = np.zeros((ntime, nant, nchan, 1, ncorr),
                         dtype=np.complex128)
        for p in range(nant):
            for c in [0, -1]:  # for now only diagonal
                xi_amp = np.random.randn(ntime, nchan)
                amp = np.exp(-nu[None, :]**2 +
                             kron_matvec(L, xi_amp).reshape(ntime, nchan))
                xi_phase = np.random.randn(ntime, nchan)
                phase = kron_matvec(L, xi_phase).reshape(ntime, nchan)
                jones[:, p, :, 0, c] = amp * np.exp(1.0j * phase)

        # corrupted vis
        model_vis = model_vis.reshape(nrow, nchan, 1, 2, 2)
        from africanus.calibration.utils import chunkify_rows
        time = ms.getcol('TIME')
        row_chunks, tbin_idx, tbin_counts = chunkify_rows(time, ntime)
        ant1 = ms.getcol('ANTENNA1')
        ant2 = ms.getcol('ANTENNA2')

        from africanus.calibration.utils import corrupt_vis
        vis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                          jones, model_vis).reshape(nrow, nchan, ncorr)

        model_vis[:, :, 0, 0, 0] = 1.0 + 0j
        model_vis[:, :, 0, -1, -1] = 1.0 + 0j
        muellercol = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                                 jones, model_vis).reshape(nrow, nchan, ncorr)

        ms.putcol('DATA', vis.astype(np.complex64))
        ms.putcol('CORRECTED_DATA', muellercol.astype(np.complex64))
        ms.close()
        mcol = 'CORRECTED_DATA'
    else:
        ms.putcol('DATA', model_vis.astype(np.complex64))
        mcol = None

    from pfb.workers.grid.dirty import _dirty
    _dirty(ms=str(test_dir / 'test_ascii_1h60.0s.MS'), data_column="DATA",
           weight_column='WEIGHT', imaging_weight_column=None,
           flag_column='FLAG', mueller_column=mcol, row_chunks=None,
           epsilon=1e-5, wstack=True, mock=False, double_accum=True,
           output_filename=str(test_dir / 'test'), nband=nchan,
           field_of_view=fov, super_resolution_factor=srf,
           cell_size=None, nx=None, ny=None, output_type='f4', nworkers=1,
           nthreads_per_worker=1, nvthreads=8, mem_limit=8, nthreads=8,
           host_address=None)

    from pfb.workers.grid.psf import _psf
    _psf(ms=str(test_dir / 'test_ascii_1h60.0s.MS'), data_column="DATA",
         weight_column='WEIGHT', imaging_weight_column=None,
         flag_column='FLAG', mueller_column=mcol, row_out_chunk=-1,
         row_chunks=None, epsilon=1e-5, wstack=True, mock=False,
         psf_oversize=2, double_accum=True,
         output_filename=str(test_dir / 'test'), nband=nchan,
         field_of_view=fov, super_resolution_factor=srf,
         cell_size=None, nx=None, ny=None, output_type='f4', nworkers=1,
         nthreads_per_worker=1, nvthreads=8, mem_limit=8, nthreads=8,
         host_address=None)

    # solve for model using pcg and mask
    mask = np.any(model, axis=0)
    from astropy.io import fits
    from pfb.utils.fits import save_fits
    hdr = fits.getheader(str(test_dir / 'test_dirty.fits'))
    save_fits(str(test_dir / 'test_model.fits'), model, hdr)
    save_fits(str(test_dir / 'test_mask.fits'), mask, hdr)

    from pfb.workers.deconv.forward import _forward
    _forward(residual=str(test_dir / 'test_dirty.fits'),
             psf=str(test_dir / 'test_psf.fits'),
             mask=str(test_dir / 'test_mask.fits'),
             beam_model=bm, band='L',
             weight_table=str(test_dir / 'test.zarr'),
             output_filename=str(test_dir / 'test'),
             nband=nchan, output_type='f4', epsilon=1e-5,
             sigmainv=0.0, wstack=True, double_accum=True,
             cg_tol=1e-6, cg_minit=10, cg_maxit=100, cg_verbose=0,
             cg_report_freq=10, backtrack=False, nworkers=1,
             nthreads_per_worker=1, nvthreads=1, mem_limit=8, nthreads=1,
             host_address=None)

    # get inferred model
    from pfb.utils.fits import load_fits
    model_inferred = load_fits(str(test_dir / 'test_update.fits')).squeeze()

    for i in range(nsource):
        if do_beam:
            beam = pbeam[:, Ix[i], Iy[i]]
            assert_allclose(0.0, beam * (model_inferred[:, Ix[i], Iy[i]] -
                                         model[:, Ix[i], Iy[i]]),
                            atol=1e-4)
        else:
            assert_allclose(0.0, model_inferred[:, Ix[i], Iy[i]] -
                            model[:, Ix[i], Iy[i]],
                            atol=1e-4)
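
# Sketch (an addition) of why the Cholesky-factor trick in the gains
# simulation above draws from the separable GP: for K = Kt x Kv with
# Cholesky factors Lt and Lv, a sample is (Lt kron Lv) @ xi with
# xi ~ N(0, I), and kron_matvec applies that product without ever forming
# the Kronecker matrix. A minimal self-contained equivalent:
def _kron_matvec_sketch():
    import numpy as np

    def kron_mv(factors, x):
        # apply (A1 kron A2 kron ...) @ x via the reshape/transpose identity
        n = [a.shape[0] for a in factors]
        x = x.reshape(n)
        for i, a in enumerate(factors):
            x = np.moveaxis(np.tensordot(a, x, axes=([1], [i])), 0, i)
        return x.ravel()

    rng = np.random.default_rng(0)
    Lt = np.linalg.cholesky(np.eye(3) + 0.1 * np.ones((3, 3)))
    Lv = np.linalg.cholesky(1.2 * np.eye(4))
    xi = rng.standard_normal(12)
    np.testing.assert_allclose(np.kron(Lt, Lv) @ xi, kron_mv((Lt, Lv), xi))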
def _residual(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError("You have to specify nthreads when using a "
                             "distributed scheduler")
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError("You have to specify mem-limit when using a "
                             "distributed scheduler")
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    # the body grids visibilities to a dirty image, so import the dirty
    # wrapper rather than the unused residual one
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, \
        chan_chunks = chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])

    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']

    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW',)
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column,)
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column,)
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError("Requested cell size too small. "
                             "Super resolution factor = %f" %
                             (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks per "
          "node" % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
          file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({'row': row_chunk,
                                'chan': chan_chunks[ims][spw]['chan']})

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :], data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :], data.shape,
                        chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of mueller term.
            # Phases modify data, amplitudes modify weights.
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data ==
                                           ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx, ny, cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf',
    #                optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits', dirty, hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)
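
# Illustrative invocation sketch (an addition): the worker is normally
# driven by a CLI wrapper, which passes a context stack to set_client; an
# ExitStack is assumed here and all option values are placeholders.
#
#   from contextlib import ExitStack
#   with ExitStack() as stack:
#       _residual(['my.ms'], stack,
#                 output_filename='out', data_column='DATA',
#                 weight_column='WEIGHT', imaging_weight_column=None,
#                 flag_column='FLAG', mueller_column=None, row_chunks=None,
#                 nband=8, field_of_view=2.0, super_resolution_factor=2.0,
#                 cell_size=None, nx=None, ny=None, output_type='f4',
#                 epsilon=1e-5, wstack=True, double_accum=True, mock=False,
#                 nworkers=1, nthreads_per_worker=1, ngridder_threads=None,
#                 nthreads=None, mem_limit=None, host_address=None)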
def _clean(**kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading dirty", file=log)
    dirty = load_fits(args.dirty, dtype=args.output_type).squeeze()
    nband, nx, ny = dirty.shape
    hdr = fits.getheader(args.dirty)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    wsums = np.amax(psf.reshape(-1, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    assert np.abs(psf_mfs.max() - 1.0) < 1e-4

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_dirty_mfs.fits', dirty_mfs, hdr_mfs,
              dtype=args.output_type)

    # set up Hessian approximation
    if args.weight_table is not None:
        normfact = wsum
        from africanus.gridding.wgridder.dask import hessian
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freqs = xds.chan.data
        nchan = freqs.size

        # bin edges
        fmin = freqs.min()
        fmax = freqs.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_mapping = {}
        chan_chunks = {}
        freq_bin_idx = {}
        freq_bin_counts = {}
        band_map = np.zeros(freqs.size, dtype=np.int32)
        for band in range(nband):
            indl = freqs >= fbins[band]
            indu = freqs < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freqs = da.from_array(freqs, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit / args.nworkers,
                                       band_size, nrow, memory_per_row,
                                       args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row,
                                       args.nthreads_per_worker)

        print("nrows = %i, row chunks set to %i for a total of %i chunks "
              "per node" %
              (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

        def convolver(x):
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            xds = xds_from_zarr(args.weight_table,
                                chunks={'row': row_chunk,
                                        'chan': bin_counts})[0]

            convolvedim = hessian(xds.UVW.data,
                                  freqs,
                                  model,
                                  freq_bin_idx,
                                  freq_bin_counts,
                                  cell_rad,
                                  weights=xds.WEIGHT.data.astype(
                                      args.output_type),
                                  nthreads=args.nvthreads,
                                  epsilon=args.epsilon,
                                  do_wstacking=args.wstack,
                                  double_accum=args.double_accum)
            return convolvedim
    else:
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        unpad_x = slice(npad_xl, -npad_xr)
        unpad_y = slice(npad_yl, -npad_yr)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        psfhat = r2c(psf_pad, axes=(1, 2), forward=True,
                     nthreads=args.nvthreads, inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))

        def convolver(x):
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            convolvedim = hessian(model,
                                  psfhat,
                                  padding,
                                  args.nvthreads,
                                  unpad_x,
                                  unpad_y,
                                  lastsize)
            return convolvedim

        # psfo = PSF(psf, dirty.shape, nthreads=args.nthreads)
        # def convolver(x): return psfo.convolve(x)

    rms = np.std(dirty_mfs)
    rmax = np.abs(dirty_mfs).max()

    print("Iter %i: peak residual = %f, rms = %f" % (0, rmax, rms), file=log)

    residual = dirty.copy()
    residual_mfs = dirty_mfs.copy()
    model = np.zeros_like(residual)
    for k in range(args.nmiter):
        print("Running Hogbom", file=log)
        x = hogbom(residual, psf,
                   gamma=args.hb_gamma,
                   pf=args.hb_peak_factor,
                   maxit=args.hb_maxit,
                   verbosity=args.hb_verbose,
                   report_freq=args.hb_report_freq)

        model += x
        print("Getting residual", file=log)

        convimage = convolver(model)
        dask.visualize(convimage,
                       filename=args.output_filename + '_hessian' + str(k) +
                       '_graph.pdf', optimize_graph=False)
        with performance_report(filename=args.output_filename + '_hessian' +
                                str(k) + '_per.html'):
            convimage = dask.compute(convimage, optimize_graph=False)[0]

        ne.evaluate('dirty - convimage/normfact', out=residual,
                    casting='same_kind')
        ne.evaluate('sum(residual, axis=0)', out=residual_mfs,
                    casting='same_kind')

        rms = np.std(residual_mfs)
        rmax = np.abs(residual_mfs).max()

        print("Iter %i: peak residual = %f, rms = %f" % (k + 1, rmax, rms),
              file=log)

    print("Saving results", file=log)
    save_fits(args.output_filename + '_model.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_model_mfs.fits', model_mfs, hdr_mfs)
    save_fits(args.output_filename + '_residual.fits',
              residual * wsums[:, None, None], hdr)
    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    print("All done here.", file=log)
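
# Illustrative invocation sketch (an addition): option names are taken
# from the args lookups in the body above; all values are placeholders.
#
#   _clean(dirty='out_dirty.fits', psf='out_psf.fits', weight_table=None,
#          output_filename='out', output_type='f4', nmiter=5,
#          hb_gamma=0.1, hb_peak_factor=0.15, hb_maxit=5000,
#          hb_verbose=0, hb_report_freq=1000, nvthreads=4,
#          epsilon=1e-5, wstack=True, double_accum=True,
#          mem_limit=8, nworkers=1, nthreads_per_worker=1,
#          host_address=None)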