def clean(**kw):
    '''
    Single-scale clean.

    If the optional weight-table argument points to a valid weight table
    (created by the psf worker) the algorithm will approximate gradients using
    the diagonal Mueller weights assumption (exact for Stokes I imaging), i.e.

        IR = ID - R.H W R x

    otherwise it is a pure image-space algorithm, i.e.

        IR = ID - PSF.convolve(x)

    The latter is exact in the absence of wide-field effects and is usually
    much faster.

    If a host address is provided the computation can be distributed over
    imaging band and row. When using a distributed scheduler both mem-limit
    and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster the defaults are

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        nvthreads = nthreads//(nworkers*nthreads_per_worker)
    else:
        nvthreads = nthreads//nthreads-per-worker
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        # numpy imports have to happen after this step
        from pfb import set_client
        set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _clean(**args)

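
# Illustrative sketch (hypothetical helper, not part of pfb): the image-space
# update used by `clean` above, IR = ID - PSF.convolve(x), written out with a
# plain FFT convolution for a single imaging band. It assumes `psf` is the
# zero-padded PSF of shape (2*nx, 2*ny) with its peak at pixel (nx, ny), while
# `dirty` and `model` have shape (nx, ny).
def _example_image_space_residual(dirty, psf, model):
    import numpy as np

    nx, ny = model.shape
    # zero-pad the model onto the PSF grid and convolve via the FFT
    model_pad = np.zeros(psf.shape, dtype=float)
    model_pad[0:nx, 0:ny] = model
    conv = np.fft.irfft2(np.fft.rfft2(model_pad) * np.fft.rfft2(psf),
                         s=psf.shape)
    # with the PSF peak at (nx, ny) the convolved model lines up with the
    # dirty image in the lower-right quadrant of the padded grid
    return dirty - conv[nx:2 * nx, ny:2 * ny]
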
def jones2col(**kw):
    '''
    Write the product of diagonal Jones matrices to a 'Mueller' column.
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    from glob import glob
    ms = glob(args.ms)
    try:
        assert len(ms) == 1
        args.ms = ms
    except Exception:
        raise ValueError(f"There must be exactly one MS at {args.ms}")

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _jones2col(**args)

def dirty(**kw):
    '''
    Create a dirty image from a list of measurement sets. The dirty image
    cube is not normalised by wsum as this destroys information. The MFS
    image is written out in units of Jy/beam. The normalisation factors can
    be obtained by making a psf image using the psf worker (see
    pfbworkers psf --help).

    If a host address is provided the computation can be distributed over
    imaging band and row. When using a distributed scheduler both mem-limit
    and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster the defaults are

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        ngridder-threads = nthreads//(nworkers*nthreads_per_worker)
    else:
        ngridder-threads = nthreads//nthreads-per-worker
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    from glob import glob
    ms = glob(args.ms)
    try:
        assert len(ms) > 0
        args.ms = ms
    except Exception:
        raise ValueError(f"No MS at {args.ms}")

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _dirty(**args)

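
# Illustrative sketch (hypothetical helper, not part of pfb): normalising a
# dirty cube produced by `dirty` above. As the docstring notes, the
# normalisation factors come from the PSF; here the per-band wsum is taken as
# the peak of the corresponding unnormalised PSF plane. Both inputs are
# assumed to be numpy arrays with a leading band axis.
def _example_normalise_by_wsum(dirty_cube, psf_cube):
    import numpy as np

    nband = psf_cube.shape[0]
    # wsum per imaging band is the peak of the unnormalised PSF
    wsums = np.amax(psf_cube.reshape(nband, -1), axis=1)
    # avoid divide-by-zero warnings on fully flagged (empty) bands
    normalised = np.zeros_like(dirty_cube)
    np.divide(dirty_cube, wsums[:, None, None], out=normalised,
              where=wsums[:, None, None] > 0)
    return normalised, wsums
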
def forward(**kw):
    '''
    Extract flux at model locations.

    Will write out the result of solving

        x = (R.H W R + sigmainv**2 I)^{-1} ID

    assuming that R.H W R can be approximated as a convolution with the PSF.

    If a host address is provided the computation can be distributed over
    imaging band and row. When using a distributed scheduler both mem-limit
    and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster the defaults are

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        ngridder-threads = nthreads//(nworkers*nthreads_per_worker)
    else:
        ngridder-threads = nthreads//nthreads-per-worker
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _forward(**args)

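
# Illustrative sketch (hypothetical, not the pfb solver): the forward worker
# above solves x = (R.H W R + sigmainv**2 I)^{-1} ID. When R.H W R is
# approximated by a convolution with the PSF the operator is symmetric
# positive definite, so the solve can be done with conjugate gradients.
# `hess` is any callable implementing x -> PSF.convolve(x) + sigmainv**2 * x
# on a numpy array (for example built from FFT convolutions in the style of
# _example_image_space_residual above).
def _example_forward_cg(hess, ID, x0=None, tol=1e-6, maxit=100):
    import numpy as np

    x = np.zeros_like(ID) if x0 is None else x0.copy()
    r = ID - hess(x)              # initial residual
    p = r.copy()                  # search direction
    rnorm = np.vdot(r, r).real
    if rnorm == 0:
        return x
    eps0 = rnorm
    for k in range(maxit):
        Ap = hess(p)
        alpha = rnorm / np.vdot(p, Ap).real
        x = x + alpha * p
        r = r - alpha * Ap
        rnorm_next = np.vdot(r, r).real
        if rnorm_next / eps0 < tol:
            break
        p = r + (rnorm_next / rnorm) * p
        rnorm = rnorm_next
    return x
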
def _predict(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler")
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler")
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)

    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms.utils import dataset_type
    mstype = dataset_type(ms[0])
    if mstype == 'casa':
        from daskms import xds_to_table
    elif mstype == 'zarr':
        from daskms.experimental.zarr import xds_to_zarr as xds_to_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import model as im2vis
    from pfb.utils.fits import load_fits
    from pfb.utils.misc import restore_corrs, plan_row_chunk
    from astropy.io import fits

    # always returns 4D
    # gridder expects freq axis
    model = np.atleast_3d(load_fits(args.model).squeeze())
    nband, nx, ny = model.shape
    hdr = fits.getheader(args.model)
    cell_d = np.abs(hdr['CDELT1'])
    cell_rad = np.deg2rad(cell_d)

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = \
        chan_to_band_mapping(ms, nband=nband)

    # degridder memory budget
    max_chan_chunk = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())

    # assumes number of correlations are the same across MS/SPW
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    if args.output_type is not None:
        output_type = np.dtype(args.output_type)
    else:
        output_type = np.result_type(np.dtype(args.real_type), np.complex64)
    data_bytes = output_type.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row  # model
    memory_per_row += 3 * 8         # uvw

    if mstype == 'zarr':
        if args.model_column in xds[0].keys():
            model_chunks = getattr(xds[0], args.model_column).data.chunks
        else:
            model_chunks = xds[0].DATA.data.chunks
            print('Chunking model same as data')

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks per node" %
          (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    model = da.from_array(model.astype(args.real_type),
                          chunks=(1, nx, ny), name=False)
    writes = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=('UVW',))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            uvw = clone(ds.UVW.data)

            bands = band_mapping[ims][spw]
            model = model[list(bands), :, :]
            vis = im2vis(uvw,
                         freqs[ims][spw],
                         model,
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         cell_rad,
                         nthreads=ngridder_threads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack)

            model_vis = restore_corrs(vis, ncorr)
            if mstype == 'zarr':
                model_vis = model_vis.rechunk(model_chunks)
                uvw = uvw.rechunk((model_chunks[0], 3))

            out_ds = ds.assign(**{
                args.model_column: (("row", "chan", "corr"), model_vis),
                'UVW': (("row", "three"), uvw)
            })
            # out_ds = ds.assign(**{args.model_column: (("row", "chan", "corr"), model_vis)})
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims, columns=[args.model_column]))

    dask.visualize(*writes, filename=args.output_filename + '_predict_graph.pdf',
                   optimize_graph=False, collapse_outputs=True)

    if not args.mock:
        with performance_report(filename=args.output_filename + '_predict_per.html'):
            dask.compute(writes, optimize_graph=False)

    print("All done here.", file=log)

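
# Minimal sketch of the idea behind `restore_corrs` as used in _predict above:
# a Stokes I model visibility is placed on the diagonal correlations
# (XX = YY = I) with zero cross-hands. This is plain numpy and hypothetical;
# the pfb implementation operates on dask arrays and may differ in detail.
def _example_restore_corrs(vis_I, ncorr):
    import numpy as np

    nrow, nchan = vis_I.shape
    model_vis = np.zeros((nrow, nchan, ncorr), dtype=vis_I.dtype)
    model_vis[:, :, 0] = vis_I    # XX (or RR)
    model_vis[:, :, -1] = vis_I   # YY (or LL)
    return model_vis
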
def psf(**kw):
    '''
    Create a psf image from a list of measurement sets and write out the
    Mueller weights. The psf image cube is not normalised by wsum as this
    destroys information. The MFS image is written out in units of Jy/beam
    and should have a peak of one, otherwise something has gone wrong.

    The --field-of-view and --super-resolution-factor options (equivalently
    --cell-size, --nx and --ny) pertain to the size of the image (e.g. dirty
    and model). The size of the PSF output image is controlled by the
    --psf-oversize option.

    The Stokes I weights required to apply the Hessian are also written out
    to a zarr data set called output-filename.zarr. This data set does not
    adhere to the MSv2 specs and is only meant to be used to apply the
    Hessian. In particular, the weights written out are a combination of the
    imaging weights and the "Mueller" weights.

    If a host address is provided the computation can be distributed over
    imaging band and row. When using a distributed scheduler both mem-limit
    and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster the defaults are

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        ngridder-threads = nthreads//(nworkers*nthreads_per_worker)
    else:
        ngridder-threads = nthreads//nthreads-per-worker
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    from glob import glob
    ms = glob(args.ms)
    try:
        assert len(ms) > 0
        args.ms = ms
    except Exception:
        raise ValueError(f"No MS at {args.ms}")

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _psf(**args)

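
# Illustrative sketch of the cell-size logic referred to in the psf and dirty
# docstrings: the critical (Nyquist) cell size follows from the longest
# baseline and the highest frequency, and --super-resolution-factor divides
# it. Hypothetical helper mirroring the arithmetic in _residual further down;
# uv_max is the largest |u| or |v| in metres and max_freq is in Hz, and the
# default factor here is only an example value.
def _example_cell_size(uv_max, max_freq, super_resolution_factor=1.2):
    import numpy as np

    lightspeed = 299792458.0  # m/s
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)    # radians
    cell_rad = cell_N / super_resolution_factor
    cell_size = cell_rad * 60 * 60 * 180 / np.pi           # arcseconds
    return cell_size
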
def _restore(stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler")
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler")
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from astropy.io import fits
    mhdr = fits.getheader(args.model)

    from pfb.utils.fits import load_fits
    model = load_fits(args.model).squeeze()  # drop Stokes axis

    # check images compatible
    rhdr = fits.getheader(args.residual)
    from pfb.utils.fits import compare_headers
    compare_headers(mhdr, rhdr)
    residual = load_fits(args.residual).squeeze()

    # fit restoring psf
    from pfb.utils.misc import fitcleanbeam, Gaussian2D
    psf = load_fits(args.psf, dtype=args.real_type).squeeze()
    nband, nx_psf, ny_psf = psf.shape
    wsums = np.amax(psf.reshape(args.nband, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf_mfs = np.zeros(psf_mfs.shape, dtype=args.real_type)
    cpsf = np.zeros(psf.shape, dtype=args.real_type)

    lpsf = np.arange(-nx_psf / 2, nx_psf / 2)
    mpsf = np.arange(-ny_psf / 2, ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)

    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars, save_fits
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)

    hdr_psf_mfs = add_beampars(hdr_psf_mfs, GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    if args.beam is not None:
        bhdr = fits.getheader(args.beam)
        compare_headers(mhdr, bhdr)
        beam = load_fits(args.beam).squeeze()
        model = np.where(beam > args.pb_min, model / beam, 0.0)

    nband, nx, ny = model.shape
    gaussparf = ()
    if nband > 1:
        for b in range(nband):
            gaussparf += (rhdr['BMAJ' + str(b)], rhdr['BMIN' + str(b)],
                          rhdr['BPA' + str(b)])
    else:
        gaussparf += (rhdr['BMAJ'], rhdr['BMIN'], rhdr['BPA'])

    # if args.convolve_residuals:
    cellx = np.abs(mhdr['CDELT1'])
    celly = np.abs(mhdr['CDELT2'])

    from pfb.utils.restoration import restore_image

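
# Illustrative sketch of the restoration step that _restore above is building
# towards: convolve the model with the fitted clean beam and add the residual.
# Hypothetical helper using scipy; it assumes emaj and emin are the clean-beam
# FWHM in pixels and pa is in degrees (the exact conventions returned by
# fitcleanbeam/Gaussian2D should be checked against pfb.utils.misc).
def _example_restore_band(model, residual, emaj, emin, pa):
    import numpy as np
    from scipy.signal import fftconvolve

    # build the clean beam on a small odd-sized patch so that fftconvolve
    # with mode='same' keeps the result aligned with the model grid
    support = int(np.ceil(5 * emaj)) | 1
    x = np.arange(support) - support // 2
    xx, yy = np.meshgrid(x, x, indexing='ij')
    # rotate coordinates by the position angle and evaluate the Gaussian
    t = np.deg2rad(pa)
    xr = xx * np.cos(t) + yy * np.sin(t)
    yr = -xx * np.sin(t) + yy * np.cos(t)
    sig_maj = emaj / np.sqrt(8 * np.log(2))
    sig_min = emin / np.sqrt(8 * np.log(2))
    beam = np.exp(-0.5 * ((xr / sig_maj) ** 2 + (yr / sig_min) ** 2))
    return fftconvolve(model, beam, mode='same') + residual
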
def spifit(**kw):
    """
    Spectral index fitter
    """
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    from glob import glob
    from omegaconf import ListConfig

    # image is either a string or a list of strings that we want to glob on
    if isinstance(args.image, str):
        image = sorted(glob(args.image))
    elif isinstance(args.image, list) or isinstance(args.image, ListConfig):
        image = []
        for i in range(len(args.image)):
            image.append(sorted(glob(args.image[i])))

    # make sure it's not empty
    try:
        assert len(image) > 0
        args.image = image
    except Exception:
        raise ValueError(f"No image at {args.image}")

    # same goes for the residual except that it may also be None
    if isinstance(args.residual, str):
        residual = sorted(glob(args.residual))
    elif isinstance(args.residual, list) or isinstance(args.residual, ListConfig):
        residual = []
        for i in range(len(args.residual)):
            residual.append(sorted(glob(args.residual[i])))

    if args.residual is not None:
        try:
            assert len(residual) > 0
            args.residual = residual
        except Exception:
            raise ValueError(f"No residual at {args.residual}")
        # we also need the same number of residuals as images
        try:
            assert len(args.image) == len(args.residual)
        except Exception:
            raise ValueError("Number of images and residuals need to match")
    else:
        print("No residual passed in!", file=log)

    # and finally the beam model
    if isinstance(args.beam_model, str):
        beam_model = sorted(glob(args.beam_model))
    elif isinstance(args.beam_model, list) or isinstance(args.beam_model, ListConfig):
        beam_model = []
        for i in range(len(args.beam_model)):
            beam_model.append(sorted(glob(args.beam_model[i])))

    if args.beam_model is not None:
        try:
            assert len(beam_model) > 0
            args.beam_model = beam_model
        except Exception:
            raise ValueError(f"No beam model at {args.beam_model}")
        try:
            assert len(args.image) == len(args.beam_model)
        except Exception:
            raise ValueError("Number of images and beam models need to match")
    else:
        print("Not doing any form of primary beam correction", file=log)

    # LB - TODO: can we sort them along freq at this point already?
    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _spifit(**args)

def binterp(**kw):
    """
    Beam interpolator

    Interpolate beams and stack cubes one MS and one spectral window at a time.
    """
    args = OmegaConf.create(kw)

    from glob import glob
    image = sorted(glob(args.image))
    try:
        assert len(image) > 0
        args.image = image
    except Exception:
        raise ValueError(f"No image at {args.image}")

    if args.output_dir is None:
        args.output_dir = os.path.dirname(args.image[0])

    pyscilog.log_to_file(args.output_dir + args.postfix.strip('fits') + 'log')

    if args.ms is not None:
        ms = glob(args.ms)
        try:
            assert len(ms) == 1
            args.ms = ms[0]
        except Exception:
            raise ValueError(
                f"There must be exactly one MS matching {args.ms} if provided")

    if not isinstance(args.beam_model, str):
        raise ValueError("Only string beam patterns allowed")
    else:
        # we are either using JimBeam or globbing for beam patterns
        if args.beam_model.lower() == 'jimbeam':
            args.beam_model = args.beam_model.lower()
            band = args.band.lower()
            if band != 'l' and band != 'uhf':
                raise ValueError("Only l or uhf band supported with JimBeam")
            else:
                print("Using %s band beam model" % args.band, file=log)
        elif args.beam_model.lower().endswith('.fits'):
            beam_model = glob(args.beam_model)
            try:
                assert len(beam_model) > 0
            except Exception:
                raise ValueError(f"No beam model at {args.beam_model}")
        else:
            raise ValueError("Unknown beam model provided. "
                             "Either use JimBeam or pass in the fits beam "
                             "patterns")

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _binterp(**args)

def _residual(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler")
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler")
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)

    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = \
        chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']

    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:           # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column,
               args.weight_column,
               args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW',)
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column,)
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column,)
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError("Requested cell size too small. "
                             "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks per node" %
          (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:, None, :],
                                                      data.shape,
                                                      chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of Mueller term.
            # Phases modify the data, amplitudes modify the weights.
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx,
                           ny,
                           cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf',
    #                optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits', dirty, hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)

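
# Minimal numpy sketch of the weighted corr-to-Stokes-I reduction performed in
# _residual above: the diagonal correlations are combined with their weights
# and the result is divided by the summed weight wherever it is non-zero.
# Hypothetical helper; the worker does the same thing lazily with dask arrays.
def _example_stokes_I(data, weights):
    import numpy as np

    dataxx, datayy = data[:, :, 0], data[:, :, -1]
    weightsxx, weightsyy = weights[:, :, 0], weights[:, :, -1]
    wI = weightsxx + weightsyy
    dI = weightsxx * dataxx + weightsyy * datayy
    # avoid divide-by-zero warnings on fully flagged samples
    stokes_I = np.zeros_like(dI)
    np.divide(dI, wI, out=stokes_I, where=wI > 0)
    return stokes_I, wI
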