def _residual(ms, stack, **kw):
    """Grid Stokes I visibilities from one or more measurement sets into an
    image cube and write it to FITS.

    NOTE(review): despite its name this function does not (yet) subtract a
    model. It grids the weighted Stokes I data with the wgridder ``dirty``
    operator and writes ``<output_filename>_dirty.fits``; the wgridder
    ``residual`` operator is imported but not used. TODO: wire in a model.

    Parameters
    ----------
    ms : iterable of str
        Measurement set paths (converted to a list).
    stack
        Forwarded unchanged to ``pfb.set_client``.
    **kw
        Options, wrapped in a struct-mode ``OmegaConf`` config. Unset resource
        options (``nthreads``, ``mem_limit``, ``nworkers``,
        ``nthreads_per_worker``, ``ngridder_threads``) are filled in here.

    Raises
    ------
    ValueError
        If ``nthreads`` or ``mem_limit`` is missing while a distributed
        scheduler (``host_address``) is requested, or if the requested cell
        size is smaller than the Nyquist cell size.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        # 100% of memory by default (in GB)
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        # Integer division gives 0 when there are more dask threads than
        # cores. ducc0 treats nthreads=0 as "use all hardware threads",
        # which would oversubscribe every worker, so clamp to at least 1.
        ngridder_threads = max(1, ngridder_threads)
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print(' %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    # TODO: im2residim is not used yet (no model subtraction implemented)
    from africanus.gridding.wgridder.dask import residual as im2residim
    # the body grids data with the dirty operator; this name was previously
    # never imported, raising a NameError when the graph was built
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    (freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping,
     chan_chunks) = chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        # ('UVW',) is a tuple; a bare ('UVW') is just the string 'UVW'
        xds = xds_from_ms(ims, columns=('UVW', ), chunks={'row': -1})
        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())

    uv_max = da.maximum(u_max, v_max).compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :], data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :], data.shape,
                        chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of mueller term.
            # Phases modify data amplitudes modify weights.
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            dirty = vis2im(uvw, freqs[ims][spw], data,
                           freq_bin_idx[ims][spw], freq_bin_counts[ims][spw],
                           nx, ny, cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    if not args.mock:
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits', dirty, hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)
def compute_weights(self, robust): from pfb.utils.weighting import compute_counts, counts_to_weights # compute counts counts = [] for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=('UVW')) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue spw = ds.DATA_DESC_ID # not optimal, need to use spw freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] uvw = ds.UVW.data count = compute_counts(uvw, freq, freq_bin_idx, freq_bin_counts, self.nx, self.ny, self.cell, self.cell, np.float32) counts.append(count) counts = dask.compute(counts)[0] counts = accumulate_dirty(counts, self.nband, self.band_mapping) counts = da.from_array(counts, chunks=(1, -1, -1)) # convert counts to weights writes = [] for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=self.columns) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] out_data = [] for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue spw = 
ds.DATA_DESC_ID # this is not correct, need to use spw freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] uvw = ds.UVW.data weights = counts_to_weights(counts, uvw, freq, freq_bin_idx, freq_bin_counts, self.nx, self.ny, self.cell, self.cell, np.float32, robust) # hack to get shape and chunking info data = getattr(ds, self.data_column).data weights = da.broadcast_to(weights[:, :, None], data.shape, chunks=data.chunks) out_ds = ds.assign(**{ self.imaging_weight_column: (("row", "chan", "corr"), weights) }) out_data.append(out_ds) writes.append( xds_to_table(out_data, ims, columns=[self.imaging_weight_column])) dask.compute(writes)
def make_dirty(self):
    """Grid the Stokes I data of all matching datasets into a dirty image cube.

    For every (FIELD_ID, DATA_DESC_ID) dataset of every MS in ``self.ms``
    whose PHASE_DIR equals ``self.radec`` this:

    * combines the first and last correlations into Stokes I using the
      (optionally imaging-weighted) weights as a weighted average,
    * applies the adjoint of the Mueller column, if configured (its phase
      rotates the data, its amplitude scales the weights),
    * keeps only visibilities where both correlations are unflagged,

    and grids the result with ``vis2im``. The per-dataset images are computed
    on the single-threaded scheduler and merged per band with
    ``accumulate_dirty``.

    Returns
    -------
    The dirty cube cast to ``self.real_type``.
    """
    print("Making dirty", file=log)
    dirties = []
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims], columns=self.columns)

        # FIELD is the only subtable consulted below; the other subtables
        # were previously read and computed but never used
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        fields = dask.compute(fields)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            # skip datasets pointed at a different phase centre
            if not np.array_equal(radec, self.radec):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data

            data = getattr(ds, self.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, self.weight_column).data
            if len(weights.shape) < 3:
                # WEIGHT has no channel axis: broadcast it over channels
                weights = da.broadcast_to(weights[:, None, :], data.shape,
                                          chunks=data.chunks)

            if self.imaging_weight_column is not None:
                imaging_weights = getattr(ds, self.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :], data.shape,
                        chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of mueller term.
            # Phases modify data amplitudes modify weights.
            if self.mueller_column is not None:
                mueller = getattr(ds, self.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, self.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 convention uses uint8 mask not flag
            flag = ~(flagxx | flagyy)

            dirty = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts,
                           self.nx, self.ny, self.cell,
                           weights=weights,
                           flag=flag.astype(np.uint8),
                           nthreads=self.nthreads,
                           epsilon=self.epsilon,
                           do_wstacking=self.do_wstacking,
                           double_accum=True)
            dirties.append(dirty)

    dirties = dask.compute(dirties, scheduler='single-threaded')[0]

    return accumulate_dirty(dirties, self.nband,
                            self.band_mapping).astype(self.real_type)
from dask.diagnostics import Profiler, ProgressBar


def create_parser():
    """Build the command-line parser for this inspection script.

    Arguments: a positional measurement set path, ``-c/--chunks`` (row chunk
    size, default 10000) and ``-s/--scheduler`` (default "threaded").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("ms")
    parser.add_argument("-c", "--chunks", default=10000, type=int)
    parser.add_argument("-s", "--scheduler", default="threaded")
    return parser


# NOTE(review): this runs at import time and parses sys.argv; a
# `if __name__ == "__main__":` guard would make the module importable.
args = create_parser().parse_args()

# scheduler_context is defined outside this view; presumably it selects the
# dask scheduler named by args.scheduler - TODO confirm.
with scheduler_context(args):
    # Create a dataset representing the entire antenna table
    ant_table = '::'.join((args.ms, 'ANTENNA'))

    for ant_ds in xds_from_table(ant_table):
        print(dask.compute(ant_ds.NAME.data,
                           ant_ds.POSITION.data,
                           ant_ds.DISH_DIAMETER.data))

    # Create datasets representing each row of the spw table
    spw_table = '::'.join((args.ms, 'SPECTRAL_WINDOW'))

    for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
        print(spw_ds)
        print(spw_ds.NUM_CHAN.values)
        print(spw_ds.CHAN_FREQ.values)

    # Create datasets from a partitioning of the MS (row chunks of args.chunks)
    datasets = list(xds_from_ms(args.ms, chunks={'row': args.chunks}))
def _psf(**kw):
    """Grid the Stokes I imaging weights of one or more measurement sets into
    a PSF cube.

    For every dataset sharing the phase centre of the first dataset seen,
    the (optionally imaging-weighted and Mueller-scaled) weights of the first
    and last correlations are summed and gridded with ``vis2im`` in place of
    the visibilities. Auto-correlations, FLAG_ROW and per-correlation flags
    are masked.

    Unless ``mock`` is set, this computes and writes:

    * ``<output_filename>.zarr`` - FIELD_ID, DATA_DESC_ID, the combined
      WEIGHT and UVW per dataset, rechunked to ``row_out_chunk`` rows,
    * ``<output_filename>_psf.fits`` - the per-band PSF cube,
    * ``<output_filename>_psf_mfs.fits`` - the band sum normalised by its
      maximum.

    Parameters
    ----------
    **kw
        Options, wrapped in a struct-mode ``OmegaConf`` config. ``ms`` may be
        a single path or a list of paths.

    Raises
    ------
    ValueError
        If the requested cell size is smaller than the Nyquist cell size.
    """
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    # accept a single MS path as well as a list of them
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    (freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping,
     chan_chunks) = chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        # ('UVW',) is a tuple; a bare ('UVW') is just the string 'UVW'
        xds = xds_from_ms(ims, columns=('UVW', ), chunks={'row': -1})
        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())

    uv_max = da.maximum(u_max, v_max).compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :], data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :], data_shape,
                        chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            # the weights are gridded in place of the visibilities
            psf = vis2im(uvw, freqs[ims][spw], weights.astype(data_type),
                         freq_bin_idx[ims][spw], freq_bin_counts[ims][spw],
                         nx, ny, cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)
            psfs.append(psf)

            # NOTE(review): full_like(TIME, ...) inherits TIME's dtype, so the
            # ID columns are written as floats - confirm this is intended.
            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data, ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data, ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT': (('row', 'chan'),
                           weights.rechunk({0: args.row_out_chunk})),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets, args.output_filename + '.zarr',
                         columns='ALL')

    if not args.mock:
        # the zarr writes are computed together with the PSFs
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits', psf, hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits', psf_mfs, hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
def main(args):
    """Joint point-source / extended-emission ("fluff") deconvolution driver.

    The model has two components: ``model[0]`` is restricted to the point
    mask and ``model[1]`` to the extended-emission mask. ``phi`` maps the
    model to an image and ``phih`` is its adjoint. Each major iteration:

    1. solves a preconditioned CG system (``pcg``) for an update direction,
    2. takes a step of size ``gamma`` and regularises with ``primal_dual``
       over the ``DaskTheta`` wavelet dictionary,
    3. recomputes the residual with the gridder ``R``,

    and stops once the relative model change drops below ``tol``. PSF, dirty,
    model and residual products are written to FITS under ``args.outfile``.
    Optionally, a polynomial spectral model is fitted to the final image
    (``interp_model``), and either that or the raw model is written back to
    the MS (``write_model``).

    Raises
    ------
    ValueError
        If the mask or point mask does not have shape ``(nx, ny)``.
    """
    # get max uv coords and channel frequencies over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW', ),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())

            spw = spws[ds.DATA_DESC_ID]
            # atleast_1d: a single-channel SPW squeezes to a 0-d array, which
            # cannot be turned into a list or concatenated
            all_freqs.append(np.atleast_1d(spw.CHAN_FREQ.data.squeeze()))

    # previously only u_max was computed here, ignoring the v extent
    uv_max = da.maximum(u_max, v_max).compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)[0]
    # concatenate first so SPWs with different channel counts are supported
    freq = np.unique(np.concatenate(all_freqs))
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms, args.nx, args.ny, args.cell_size, nband=args.nband,
                nthreads=args.nthreads, do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks, data_column=args.data_column,
                weight_column=args.weight_column, epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column, flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf_array = load_fits(args.psf)
        except Exception:
            # header mismatch or unreadable file: recompute instead
            psf_array = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    # the PSF is (nband, 2*nx, 2*ny), hence 4*nx*ny pixels per band
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    # normalisation for more intuitive sig_21 values
    psf_max_mean = wsum / counts
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    # LB - is this the right thing to do?
    psf_max[psf_max < 1e-15] = 1e-15

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    # save only the central (nx, ny) region
    save_fits(
        args.outfile + '_psf_mfs.fits',
        psf_mfs[args.nx // 2:3 * args.nx // 2, args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty)
        except Exception:
            # header mismatch or unreadable file: recompute instead
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    residual = dirty.copy()

    model = np.zeros((2, args.nband, args.nx, args.ny))
    recompute_residual = False
    if args.beta0 is not None:
        compare_headers(hdr, fits.getheader(args.beta0))
        model[0] = load_fits(args.beta0).squeeze()
        recompute_residual = True
    if args.alpha0 is not None:
        compare_headers(hdr, fits.getheader(args.alpha0))
        model[1] = load_fits(args.alpha0).squeeze()
        recompute_residual = True

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    # point mask (np.bool was removed in NumPy 1.24; use the builtin)
    pmask = load_fits(args.point_mask, dtype=bool)[None, :, :]
    if pmask.shape != (1, args.nx, args.ny):
        raise ValueError("Mask has incorrect shape")

    # set up splitting operator: model components -> image and its adjoint
    def phi(x):
        return x[0] * pmask + x[1] * mask

    def phih(x):
        return np.concatenate(((pmask * x)[None], (mask * x)[None]), axis=0)

    if recompute_residual:
        image = phi(model)
        residual = R.make_residual(image) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum

    # Gaussian "prior" used for preconditioning extended emission
    A = Gauss(args.sig_l2a, args.nband, args.nx, args.ny, args.nthreads)

    # preconditioning matrix
    def hess(x):
        return phih(psf.convolve(phi(x))) + np.concatenate(
            (x[0:1] / args.sig_l2b**2, A.idot(x[1])[None]), axis=0)

    def M_func(x):
        return np.concatenate(
            (x[0:1] * args.sig_l2b**2, A.convolve(x[1])[None]), axis=0)

    par_shape = phih(dirty).shape
    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess, par_shape, tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    theta = DaskTheta(args.nband, args.nx, args.ny, nthreads=args.nthreads)
    weights_21 = np.ones((theta.nbasis + 1, theta.nmax), dtype=np.float64)
    tmp = np.pad(pmask.ravel(), (0, theta.nmax - args.nx * args.ny),
                 mode='constant')
    weights_21[0] = np.where(tmp, args.sig_21b / args.sig_21a, 1e15)
    dual = np.zeros((theta.nbasis + 1, args.nband, theta.nmax),
                    dtype=np.float64)

    # Reporting (the guard also covers maxit <= 0, where arange is empty)
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if not report_iters or report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # deconvolve
    eps = 1.0
    i = 0
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess, phih(residual), np.zeros(par_shape, dtype=np.float64),
                M=M_func, tol=args.cgtol, maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_point_update.fits', x[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff_update.fits', x[1], hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess, model, modelp, dual, args.sig_21a,
                                  theta, weights_21, beta, tol=args.pdtol,
                                  maxit=args.pdmaxit, report_freq=100,
                                  mask=mask, positivity=args.positivity,
                                  gamma=args.gamma)

        # get residual
        image = phi(model)
        residual = R.make_residual(image) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', image, hdr)
            save_fits(args.outfile + str(i) + '_point.fits', model[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff.fits', model[1], hdr)

            model_mfs = np.mean(image, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

        if eps < args.tol:
            break

    if args.interp_model:
        # Rebuild from the final model. model_mfs was previously only set on
        # report iterations (stale or undefined after an early break), and
        # image is undefined if the loop never ran without a starting model.
        # This must happen before `mask` is rebound below, since phi uses it.
        image = phi(model)
        model_mfs = np.mean(image, axis=0)
        mask = np.where(model_mfs > 1e-10, 1, 0)
        # no .squeeze(): with a single component it would collapse (1, 2) to
        # (2,) and break the column indexing below
        idx = np.argwhere(mask)
        Ix = idx[:, 0]
        Iy = idx[:, 1]

        # get components
        beta = image[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for k in range(1, args.spectral_poly_order + 1):
            Xdesign[:, k - 1] = (whigh**k - wlow**k) / (k * wdiff)

        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        np.savez(args.outfile + "spectral_comps", comps=comps,
                 ref_freq=ref_freq, mask=np.any(model, axis=0))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, mask, args.row_chunks,
                                    args.chan_chunks)
        else:
            R.write_model(model)
def read_ms(ms, num_vis, res_arcmin, chunks=50000, channel=0, field_id=0):
    """Load a random subset of one field's visibilities from a Measurement Set.

    Uses dask-ms to load the data needed to create a telescope operator
    (uvw positions, visibilities and their sigmas) for a single channel and
    the first polarisation product.

    res_arcmin: Used to calculate the maximum baselines to consider.
    We want two pixels per smallest fringe
        pix_res > fringe / 2

        u sin(theta) = n (for nth fringe)
        at small angles: theta = 1/u, or bl_max = 1 / theta

        d sin(theta) = lambda / 2
        d / lambda = 1 / (2 sin(theta))
        bl_max = lambda / 2sin(theta)

    Args:
        ms: Path to the Measurement Set.
        num_vis: Maximum number of visibilities to return.
        res_arcmin: Target resolution in arcminutes. Rows whose baseline
            length exceeds the corresponding limit are discarded.
        chunks: Row chunk size used when reading the main table.
        channel: Index of the channel to read.
        field_id: FIELD_ID of the rows to read. It also selects the row of
            the FIELD subtable used for the phase centre.

    Returns:
        Tuple ``(u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp,
        rms_arr)`` for the selected (unflagged, short enough) rows, sorted
        by row index.

    Raises:
        RuntimeError: If no dataset has the requested ``field_id``.
        Any other exception raised while reading is logged and re-raised.
    """
    try:
        # Antenna table: positions are only logged.
        ant_table = "::".join((ms, "ANTENNA"))
        for ant_ds in xds_from_table(ant_table):
            ant_p = np.array(ant_ds.POSITION.data)
            logger.info("Antenna Positions {}".format(ant_p.shape))

        # Field table: FIELD_ID in the main table indexes rows of this
        # subtable, so take the phase centre of the requested field.
        field_table = "::".join((ms, "FIELD"))
        for field_ds in xds_from_table(field_table):
            phase_dir = np.array(
                field_ds.PHASE_DIR.data)[int(field_id)].flatten()
            name = field_ds.NAME.data.compute()
            logger.info("Field {}: Phase Dir {}".format(
                name, np.degrees(phase_dir)))

        # One dataset per spectral-window row; the last one read wins.
        spw_table = "::".join((ms, "SPECTRAL_WINDOW"))
        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            logger.info("CHAN_FREQ.values: {}".format(
                spw_ds.CHAN_FREQ.values.shape))
            frequencies = dask.compute(spw_ds.CHAN_FREQ.values)[0].flatten()
            frequency = frequencies[channel]
            logger.info("Frequencies = {}".format(frequencies))
            logger.info("Frequency = {}".format(frequency))
            logger.info("NUM_CHAN = %f" % np.array(spw_ds.NUM_CHAN.values)[0])

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(ms, chunks={"row": chunks}))
        logger.info("DataSets: N={}".format(len(datasets)))

        pol = 0
        uvw = None  # set once a dataset with the requested FIELD_ID is found

        def read_np_array(arr, title, dtype=np.float32):
            """Materialise a (dask) array as numpy, logging the elapsed time."""
            tic = time.perf_counter()
            logger.info("Reading {}...".format(title))
            ret = np.array(arr, dtype=dtype)
            toc = time.perf_counter()
            logger.info("Elapsed {:04f} seconds".format(toc - tic))
            return ret

        for ds in datasets:
            logger.info("DATASET field_id={} shape: {}".format(
                ds.FIELD_ID, ds.DATA.data.shape))
            logger.info("UVW shape: {}".format(ds.UVW.data.shape))
            logger.info("SIGMA shape: {}".format(ds.SIGMA.data.shape))

            if int(field_id) != int(ds.FIELD_ID):
                continue

            uvw = read_np_array(ds.UVW.data, "UVW")
            flags = read_np_array(ds.FLAG.data[:, channel, pol], "FLAGS",
                                  dtype=np.int32)

            # Longest baseline needed for the requested resolution.
            bl_max = get_resolution_max_baseline(res_arcmin, frequency)
            logger.info("Resolution Max UVW: {:g} meters".format(bl_max))
            logger.info("Flags: {}".format(flags.shape))

            # Now report the recommended resolution from the data.
            # 1.0 / 2*np.sin(theta) = limit_u
            limit_uvw = np.max(np.abs(uvw), 0)
            res_limit = get_baseline_resolution(limit_uvw[0], frequency)
            logger.info("Nyquist resolution: {:g} arcmin".format(
                np.degrees(res_limit) * 60.0))

            # Keep unflagged rows whose baseline length is below bl_max.
            bl = np.sqrt(uvw[:, 0]**2 + uvw[:, 1]**2 + uvw[:, 2]**2)
            good_data = np.where((flags == 0) & (bl < bl_max))[0]

            logger.info("Good Data {}".format(good_data.shape))
            logger.info("Maximum UVW: {}".format(limit_uvw))
            logger.info("Minimum UVW: {}".format(np.min(np.abs(uvw), 0)))

            for axis in range(3):
                p05, p50, p95 = np.percentile(np.abs(uvw[:, axis]),
                                              [5, 50, 95])
                logger.info(" U[{}]: {:5.2f} {:5.2f} {:5.2f}".format(
                    axis, p05, p50, p95))

            n_max = len(good_data)
            if n_max <= num_vis:
                # Fewer usable rows than requested: keep all *good* rows
                # (not the first n_max rows of the table).
                indices = good_data
            else:
                indices = np.random.choice(good_data, num_vis, replace=False)
            # sort the indices to keep them in order (speeds up IO)
            indices = np.sort(indices)

            # Now read the remaining data
            sigma = read_np_array(ds.SIGMA.data[indices, pol], "SIGMA")
            cv_vis = read_np_array(ds.DATA.data[indices, channel, pol],
                                   "DATA", dtype=np.complex64)
            epoch_seconds = np.array(ds.TIME.data)[0]

        if uvw is None:
            raise RuntimeError("FIELD_ID ({}) is invalid".format(field_id))

        hdr = {
            "CTYPE1": ("RA---SIN", "Right ascension angle cosine"),
            "CRVAL1": np.degrees(phase_dir)[0],
            "CUNIT1": "deg ",
            "CTYPE2": ("DEC--SIN", "Declination angle cosine "),
            "CRVAL2": np.degrees(phase_dir)[1],
            "CUNIT2": "deg ",
            "CTYPE3": "FREQ ",  # / Central frequency ",
            "CRPIX3": 1.0,
            "CRVAL3": "{}".format(frequency),
            "CDELT3": 10026896.158854,
            "CUNIT3": "Hz ",
            "EQUINOX": "2000.",
            "DATE-OBS": "{}".format(epoch_seconds),
            "BTYPE": "Intensity",
        }

        u_arr = uvw[indices, 0].T
        v_arr = uvw[indices, 1].T
        w_arr = uvw[indices, 2].T
        rms_arr = sigma.T

        logger.info("Max vis {}".format(np.max(np.abs(cv_vis))))

        # Convert from reduced Julian Date to timestamp.
        timestamp = datetime.datetime(
            1858, 11, 17, 0, 0, 0,
            tzinfo=datetime.timezone.utc) + datetime.timedelta(
                seconds=epoch_seconds)
    except Exception as e:
        # Log with traceback and propagate. Swallowing the error used to
        # make the return below fail with an unrelated UnboundLocalError.
        logger.exception("Exception {}".format(e))
        raise

    return u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp, rms_arr
def get_field_names(myms):
    """Read the FIELD subtable of *myms*.

    Returns a ``(field_ids, field_names)`` pair holding the values of the
    SOURCE_ID and NAME columns of the first FIELD dataset.
    """
    field_ds = xms.xds_from_table(myms + '::FIELD',
                                  columns=['NAME', 'SOURCE_ID'])[0]
    return field_ds.SOURCE_ID.values, field_ds.NAME.values
def __init__(self, ms_name, nx, ny, cell_size, nband=None, nthreads=8,
             do_wstacking=1, Stokes='I', row_chunks=100000,
             optimise_chunks=True, epsilon=1e-5,
             data_column='CORRECTED_DATA',
             weight_column='WEIGHT_SPECTRUM',
             model_column="MODEL_DATA", flag_column='FLAG',
             imaging_weight_column=None):
    """Scan the measurement set(s) and build the channel <-> band mapping.

    Args:
        ms_name: Path to a Measurement Set, or a list of paths.
        nx, ny: Image size in pixels.
        cell_size: Pixel size in arcseconds (stored in radians as
            ``self.cell``).
        nband: Number of output bands. Defaults to the number of unique
            channel frequencies.
        nthreads, do_wstacking, epsilon: Stored for later gridding calls.
        Stokes: Only ``'I'`` is supported.
        row_chunks, optimise_chunks: Accepted but not used here.
        data_column, weight_column, model_column, flag_column,
        imaging_weight_column: Names of the MS columns to use.

    Only datasets whose PHASE_DIR equals that of the first dataset seen
    are used. Per-MS frequency tables are keyed by DATA_DESC_ID.

    Raises:
        NotImplementedError: If ``Stokes`` is not ``'I'``.
    """
    if Stokes != 'I':
        raise NotImplementedError("Only Stokes I currently supported")
    self.nx = nx
    self.ny = ny
    # arcseconds -> radians
    self.cell = cell_size * np.pi/60/60/180
    self.nthreads = nthreads
    self.do_wstacking = do_wstacking
    self.epsilon = epsilon
    self.data_column = data_column
    self.weight_column = weight_column
    self.model_column = model_column
    self.flag_column = flag_column

    if isinstance(ms_name, list):
        self.ms = ms_name
    else:
        self.ms = [ms_name]

    # first pass through data to determine freq_mapping
    self.radec = None
    self.freq = {}
    all_freqs = []
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={"row": -1}, columns=('TIME',))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                              group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]

        # DATA_DESC_ID indexes the DATA_DESCRIPTION table. Its
        # SPECTRAL_WINDOW_ID column gives the spectral-window row, which
        # need not equal the DATA_DESC_ID itself.
        spw_ids = ddids[0].SPECTRAL_WINDOW_ID.data

        self.freq[ims] = {}
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match (first phase centre seen is the reference)
            if self.radec is None:
                self.radec = radec
            if not np.array_equal(radec, self.radec):
                continue

            spw = spws[int(spw_ids[ds.DATA_DESC_ID])]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            self.freq[ims][ds.DATA_DESC_ID] = tmp_freq
            all_freqs.append(tmp_freq)

    # freq mapping. Concatenate instead of building a nested array so that
    # spectral windows with different channel counts are handled.
    all_freqs = dask.compute(all_freqs)[0]
    ufreqs = np.unique(np.concatenate(
        [np.atleast_1d(f) for f in all_freqs]))  # returns ascending sorted
    self.nchan = ufreqs.size
    if nband is None:
        self.nband = self.nchan
    else:
        self.nband = nband

    # bin edges
    fmin = ufreqs[0]
    fmax = ufreqs[-1]
    fbins = np.linspace(fmin, fmax, self.nband+1)
    self.freq_out = np.zeros(self.nband)
    for band in range(self.nband):
        indl = ufreqs >= fbins[band]
        indu = ufreqs < fbins[band + 1] + 1e-6
        self.freq_out[band] = np.mean(ufreqs[indl & indu])

    # chan <-> band mapping
    self.band_mapping = {}
    self.chunks = {}
    self.freq_bin_idx = {}
    self.freq_bin_counts = {}
    for ims in self.freq:
        self.freq_bin_idx[ims] = {}
        self.freq_bin_counts[ims] = {}
        self.band_mapping[ims] = {}
        self.chunks[ims] = []
        for spw in self.freq[ims]:
            freq = np.atleast_1d(dask.compute(self.freq[ims][spw])[0])
            band_map = np.zeros(freq.size, dtype=np.int32)
            for band in range(self.nband):
                indl = freq >= fbins[band]
                indu = freq < fbins[band + 1] + 1e-6
                band_map = np.where(indl & indu, band, band_map)
            # to dask arrays
            bands, bin_counts = np.unique(band_map, return_counts=True)
            self.band_mapping[ims][spw] = tuple(bands)
            self.chunks[ims].append({'row': (-1,),
                                     'chan': tuple(bin_counts)})
            self.freq[ims][spw] = da.from_array(freq,
                                                chunks=tuple(bin_counts))
            bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
            self.freq_bin_idx[ims][spw] = da.from_array(bin_idx, chunks=1)
            self.freq_bin_counts[ims][spw] = da.from_array(bin_counts,
                                                           chunks=1)

    self.imaging_weight_column = imaging_weight_column
    if imaging_weight_column is not None:
        self.columns = (self.data_column, self.weight_column,
                        self.imaging_weight_column, self.flag_column, 'UVW')
    else:
        self.columns = (self.data_column, self.weight_column,
                        self.flag_column, 'UVW')
def both(args):
    """Generate model data, corrupted visibilities and gains (phase-only or normal)

    Builds a point-source model from ``args.sky_model`` and simulates gains
    (``args.mode`` is ``"phase"`` or ``"normal"``). The diagonal gains are
    saved to ``args.out`` (``<mode>.npy`` by default). The function then
    writes to ``args.ms``:

    - one model-visibility column per source;
    - gain-corrupted ``CLEAN_DATA``;
    - when ``args.sigma_n > 0``, ``NOISE`` and noisy ``DATA``.

    Raises:
        ValueError: Unknown phase convention, sky-model name or gain mode.
        NotImplementedError: The MS does not have 4 correlations.
    """
    # Set thread count to cpu count
    if args.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(args.ncpu))
    else:
        import multiprocessing
        args.ncpu = multiprocessing.cpu_count()

    # Get full time column and compute row chunks
    ms = xds_from_table(args.ms)[0]
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        ms.TIME, args.utimes_per_chunk)

    # Convert time rows to dask arrays
    tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts, chunks=(args.utimes_per_chunk))

    # Time axis
    n_time = tbin_idx.size

    # No. of antennas axis
    n_ant = (np.maximum(ms.ANTENNA1.data.max(),
                        ms.ANTENNA2.data.max()) + 1).compute()

    # Validate the phase convention up front; UVW itself is read (and its
    # sign adjusted) per row chunk further down.
    if args.phase_convention not in ('CASA', 'CODEX'):
        raise ValueError("Unknown sign convention for phase")

    # Get rest of dimensions from the flag column
    n_row, n_freq, n_corr = ms.FLAG.data.shape

    # Raise error if correlation axis too small
    if n_corr != 4:
        raise NotImplementedError("Only 4 correlations "
                                  "currently supported")

    # Get phase direction
    radec0_table = xds_from_table(args.ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(args.ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_freq

    # Check for sky-model
    if args.sky_model == 'MODEL-1.txt':
        args.sky_model = MODEL_1
    elif args.sky_model == 'MODEL-4.txt':
        args.sky_model = MODEL_4
    elif args.sky_model == 'MODEL-50.txt':
        args.sky_model = MODEL_50
    else:
        # `NotImplemented` is a constant, not an exception class; calling
        # it raised a TypeError. Raise ValueError, as `jones` does.
        raise ValueError(f"Sky-model {args.sky_model} not in "
                         "kalcal/datasets/sky_model/")

    # Build source model from lsm
    lsm = Tigger.load(args.sky_model)

    # Direction axis
    n_dir = len(lsm.sources)

    # Per-source model fluxes, correlation slots filled in I, Q, U, V order
    model = np.zeros((n_dir, n_freq, n_corr), dtype=np.float64)

    # Coordinates and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    source_names = []

    for d, source in enumerate(lsm.sources):
        source_names.append(source.name)

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

        # Get spectrum (only spi currently supported); flat if absent
        tmp_spec = source.spectrum
        spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
        ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

        # getattr guards against flux objects without Q/U/V attributes
        # (assumed possible for unpolarised Tigger sources -- verify).
        for c, stokes in enumerate("IQUV"):
            flux = getattr(source.flux, stokes, None)
            if flux:
                model[d, :, c] = flux * (freq/ref_freq)**spi

    # Dask to NP
    t = tbin_idx.compute()
    nu = freq.compute()

    # Generate gains
    print('==> Both-mode')
    if args.mode == "phase":
        jones = phase_gains(lm, nu, n_time, n_ant, args.alpha_std)
    elif args.mode == "normal":
        jones = normal_gains(t, nu, lm, n_ant, n_corr, args.sigma_f,
                             args.lt, args.lnu, args.ls)
    else:
        raise ValueError("Only normal and phase modes available.")
    print()

    # Reduce jones to diagonals only
    jones = jones[:, :, :, :, (0, -1)]

    # Jones to complex
    jones = jones.astype(np.complex128)
    jones_shape = jones.shape

    # Generate filename
    if args.out == "":
        args.out = f"{args.mode}.npy"

    # Save gains and settings to file
    with open(args.out, 'wb') as fh:
        np.save(fh, jones)

    # Build dask graph
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones_da = da.from_array(jones, chunks=(args.utimes_per_chunk,)
                             + jones_shape[1::])

    # Load data in chunks and apply gains to each chunk
    cols = ['ANTENNA1', 'ANTENNA2', 'UVW']
    xds = xds_from_ms(args.ms, columns=cols, chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data

    # Adjust UVW based on phase-convention (validated above)
    if args.phase_convention == 'CASA':
        uvw = -xds.UVW.data.astype(np.float64)
    else:
        uvw = xds.UVW.data.astype(np.float64)

    # Get model visibilities
    model_vis = np.zeros((n_row, n_freq, n_dir, n_corr), dtype=np.complex128)
    for s in range(n_dir):
        model_vis[:, :, s] = im_to_vis(
            model[s].reshape((1, n_freq, n_corr)),
            uvw,
            lm[s].reshape((1, 2)),
            freq,
            dtype=np.complex64, convention='fourier')

    # NP to Dask
    model_vis = da.from_array(model_vis,
                              chunks=(row_chunks, n_freq, n_dir, n_corr))

    # Convert Stokes to corr
    in_schema = ['I', 'Q', 'U', 'V']
    out_schema = [['RR', 'RL'], ['LR', 'LL']]
    model_vis = convert(model_vis, in_schema, out_schema)

    # Apply gains
    data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                       jones_da, model_vis).reshape(
                           (n_row, n_freq, n_corr))

    # Assign model visibilities, one column per source
    out_names = []
    for d, name in enumerate(source_names):
        xds = xds.assign(**{name: (
            ("row", "chan", "corr"),
            model_vis[:, :, d].reshape(
                n_row, n_freq, n_corr).astype(np.complex64))})
        out_names += [name]

    # Assign noise free visibilities to 'CLEAN_DATA'
    xds = xds.assign(**{'CLEAN_DATA': (("row", "chan", "corr"),
                                       data.astype(np.complex64))})
    out_names += ['CLEAN_DATA']

    # Get noise realisation
    if args.sigma_n > 0.0:
        # Noise matrix
        noise = (da.random.normal(loc=0.0, scale=args.sigma_n,
                                  size=(n_row, n_freq, n_corr),
                                  chunks=(row_chunks, n_freq, n_corr))
                 + 1.0j*da.random.normal(loc=0.0, scale=args.sigma_n,
                                         size=(n_row, n_freq, n_corr),
                                         chunks=(row_chunks, n_freq, n_corr))
                 )/np.sqrt(2.0)

        # Zero matrix for off-diagonals
        zero = da.zeros_like(noise[:, :, 0])

        # Dask to NP
        noise = noise.compute()
        zero = zero.compute()

        # Remove noise on off-diagonals
        noise[:, :, 1] = zero[:, :]
        noise[:, :, 2] = zero[:, :]

        # NP to Dask
        noise = da.from_array(noise, chunks=(row_chunks, n_freq, n_corr))

        # Assign noise to 'NOISE'
        xds = xds.assign(**{'NOISE': (("row", "chan", "corr"),
                                      noise.astype(np.complex64))})
        out_names += ['NOISE']

        # Add noise to data and assign to 'DATA'
        noisy_data = data + noise
        xds = xds.assign(**{'DATA': (("row", "chan", "corr"),
                                     noisy_data.astype(np.complex64))})
        out_names += ['DATA']

    # Create a write to the table
    write = xds_to_table(xds, args.ms, out_names)

    # Submit all graph computations in parallel
    with ProgressBar():
        write.compute()

    print(f"==> Applied Jones to MS: {args.ms} <--> {args.out}")
def get_chan_freqs(myms):
    """Return the CHAN_FREQ column of the first SPECTRAL_WINDOW dataset of *myms*."""
    spw_datasets = xms.xds_from_table(myms + '::SPECTRAL_WINDOW',
                                      columns=['CHAN_FREQ'])
    first_spw = spw_datasets[0]
    return first_spw.CHAN_FREQ
def jones(args):
    """Generate jones matrix only, but based off of a measurement set.

    Simulates phase-only or normal gains (``args.mode``) for the source
    positions in ``args.sky_model``. The diagonal gains are saved to
    ``args.out`` (``<mode>.npy`` by default). The MS is only read.

    Raises:
        ValueError: Unknown phase convention, sky-model name or gain mode.
        NotImplementedError: The MS does not have 4 correlations.
    """
    # Set thread count to cpu count
    if args.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(args.ncpu))
    else:
        import multiprocessing
        args.ncpu = multiprocessing.cpu_count()

    # Get full time column and compute time-bin indices
    ms = xds_from_table(args.ms)[0]
    _, tbin_idx, _ = chunkify_rows(ms.TIME, args.utimes_per_chunk)

    # Convert time rows to dask arrays
    tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk))

    # Time axis
    n_time = tbin_idx.size

    # No. of antennas axis
    n_ant = (np.maximum(ms.ANTENNA1.data.max(),
                        ms.ANTENNA2.data.max()) + 1).compute()

    # Validate phase convention (UVW itself is not needed for gains)
    if args.phase_convention not in ('CASA', 'CODEX'):
        raise ValueError("Unknown sign convention for phase")

    # Get rest of dimensions from the flag column
    _, n_freq, n_corr = ms.FLAG.data.shape

    # Raise error if correlation axis too small
    if n_corr != 4:
        raise NotImplementedError("Only 4 correlations "
                                  "currently supported")

    # Get phase direction
    radec0_table = xds_from_table(args.ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(args.ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_freq

    # Check for sky-model
    if args.sky_model == 'MODEL-1.txt':
        args.sky_model = MODEL_1
    elif args.sky_model == 'MODEL-4.txt':
        args.sky_model = MODEL_4
    elif args.sky_model == 'MODEL-50.txt':
        args.sky_model = MODEL_50
    else:
        raise ValueError(f"Sky-model {args.sky_model} not in "
                         "kalcal/datasets/sky_model/")

    # Build source model from lsm
    lsm = Tigger.load(args.sky_model)

    # Direction axis
    n_dir = len(lsm.sources)

    # Source positions in lm coordinates
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    for d, source in enumerate(lsm.sources):
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

    # Pass computed numpy arrays to the gain simulators, as `both` does,
    # rather than lazy dask arrays.
    t = tbin_idx.compute()
    nu = freq.compute()

    # Generate gains
    print('==> Jones-only mode')
    if args.mode == "phase":
        jones = phase_gains(lm, nu, n_time, n_ant, args.alpha_std)
    elif args.mode == "normal":
        jones = normal_gains(t, nu, lm, n_ant, n_corr, args.sigma_f,
                             args.lt, args.lnu, args.ls)
    else:
        raise ValueError("Only normal and phase modes available.")

    # Reduce jones to diagonals only
    jones = jones[:, :, :, :, (0, -1)]

    # Jones to complex
    jones = jones.astype(np.complex128)

    # Generate filename
    if args.out == "":
        args.out = f"{args.mode}.npy"

    # Save gains and settings to file
    with open(args.out, 'wb') as fh:
        np.save(fh, jones)

    print(f"==> Created Jones data: {args.out}")
def new(ms, sky_model, **kwargs):
    """Generate a jones matrix based on a given sky-model either as
    phase-only or normal gains, as an .npy file.

    Options are loaded from ``kwargs["yaml"]`` when that key is present and
    not None; otherwise they are built from ``kwargs`` itself. ``sky_model``
    is first looked up (lower-cased) in ``sky_models``. Otherwise it is used
    as-is as the path to a Tigger sky model.

    Raises:
        ValueError: If ``options.type`` is neither ``"phase"`` nor
            ``"normal"``.
    """
    # Options to attributed dictionary
    yaml_file = kwargs.get("yaml")
    if yaml_file is not None:
        options = ocf.load(yaml_file)
    else:
        options = ocf.create(kwargs)

    # Set to struct
    ocf.set_struct(options, True)

    # Change path to sky model if chosen
    try:
        sky_model = sky_models[sky_model.lower()]
    except (KeyError, AttributeError):
        # Own sky model reference
        pass

    # Validate the gain type before any work is done and, in particular,
    # before the output file is opened (and truncated) for writing.
    if options.type not in ("phase", "normal"):
        raise ValueError(f"Unknown gain type '{options.type}', "
                         "expected 'phase' or 'normal'")

    # Load ms
    MS = xds_from_ms(ms)[0]

    # Get dimensions (correlations need to be adapted)
    dims = ocf.create(dict(MS.sizes))
    n_chan = dims.chan
    n_corr = dims.corr

    # Get time-bin indices (first row of each unique time)
    _, tbin_indices, _ = np.unique(MS.TIME,
                                   return_index=True,
                                   return_counts=True)

    # Set time dimension
    n_time = len(tbin_indices)

    # Get antenna arrays (dask ignored for now)
    ant1 = MS.ANTENNA1.data.compute()
    ant2 = MS.ANTENNA2.data.compute()

    # Set antenna dimension
    n_ant = np.max((np.max(ant1), np.max(ant2))) + 1

    # Build source model from lsm
    lsm = Tigger.load(sky_model)

    # Set direction axis as per source
    n_dir = len(lsm.sources)

    # Get phase direction
    radec0_table = xds_from_table(ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_chan

    # Source positions in lm coordinates
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    for d, source in enumerate(lsm.sources):
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

    # Direction independent gains: keep only the first direction
    if options.die:
        lm = np.array(lm[0]).reshape((1, -1))
        n_dir = 1

    dims_msg = (f"n_time={n_time}, n_ant={n_ant}, n_chan={n_chan}, "
                f"n_dir={n_dir}, n_corr={n_corr})")

    # Choose between phase-only or normal (validated above)
    if options.type == "phase":
        print("==> Simulating `phase-only` gains, with dimensions ("
              + dims_msg)
        jones = phase_gains(lm, freq, n_time, n_ant, n_chan,
                            n_dir, n_corr, options.std)
    else:
        # With normal selected, get differentials
        lt, lnu, ls = options.diffs

        print("==> Simulating `normal` gains, with dimensions ("
              + dims_msg)
        jones = normal_gains(tbin_indices, freq, lm, n_time, n_ant, n_chan,
                             n_dir, n_corr, options.std, lt, lnu, ls)

    # Output jones to .npy file
    if options.out_file is None:
        gains_file = options.type + ".npy"
    else:
        gains_file = options.out_file

    with open(gains_file, 'wb') as fh:
        np.save(fh, jones)

    print(f"==> Completed and gains saved to: {gains_file}")
def read_ms(ms, num_vis, res_arcmin, chunks=50000, channel=0):
    '''
    Use dask-ms to load the necessary data to create a telescope operator
    (will use uvw positions, and antenna positions)

    -- res_arcmin: Used to calculate the maximum baselines to consider.
    We want two pixels per smallest fringe
        pix_res > fringe / 2

        u sin(theta) = n (for nth fringe)
        at small angles: theta = 1/u, or u_max = 1 / theta

        d sin(theta) = lambda / 2
        d / lambda = 1 / (2 sin(theta))
        u_max = lambda / 2sin(theta)

    Rows of the last dataset in the MS are kept when they are unflagged
    (for ``channel`` and the first polarisation) and every |u|, |v|, |w| is
    below ``u_max``. At most ``num_vis`` of them are sampled at random
    without replacement and returned in row order.

    Returns ``(u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp)``.
    '''
    with scheduler_context():
        # Antenna table: positions are only logged.
        ant_table = '::'.join((ms, 'ANTENNA'))
        for ant_ds in xds_from_table(ant_table):
            ant_p = np.array(ant_ds.POSITION.data)
            logger.info("Antenna Positions {}".format(ant_p.shape))

        # Field table: phase centre of the first field row.
        field_table = '::'.join((ms, 'FIELD'))
        for field_ds in xds_from_table(field_table):
            phase_dir = np.array(field_ds.PHASE_DIR.data)[0].flatten()
            logger.info("Phase Dir {}".format(np.degrees(phase_dir)))

        # One dataset per spectral-window row; the last one read wins.
        spw_table = '::'.join((ms, 'SPECTRAL_WINDOW'))
        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            logger.info("CHAN_FREQ.values: {}".format(
                spw_ds.CHAN_FREQ.values.shape))
            frequencies = dask.compute(spw_ds.CHAN_FREQ.values)[0].flatten()
            frequency = frequencies[channel]
            logger.info("Frequencies = {}".format(frequencies))
            logger.info("Frequency = {}".format(frequency))
            logger.info("NUM_CHAN = %f" % np.array(spw_ds.NUM_CHAN.values)[0])

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(ms, chunks={'row': chunks}))

        pol = 0

        # Only the last dataset's values survive this loop.
        for ds in datasets:
            logger.info("DATA shape: {}".format(ds.DATA.data.shape))
            logger.info("UVW shape: {}".format(ds.UVW.data.shape))

            uvw = np.array(ds.UVW.data)  # UVW is stored in meters!
            flags = np.array(ds.FLAG.data)
            # Slice before materialising so the full DATA cube is not
            # held in memory.
            cv_vis = np.array(ds.DATA.data[:, channel, pol])
            epoch_seconds = np.array(ds.TIME.data)[0]

        # NOW REMOVE DATA THAT DOESN'T FIT THE IMAGE RESOLUTION
        u_max = get_resolution_max_baseline(res_arcmin, frequency)
        logger.info("Resolution Max UVW: {:g}".format(u_max))
        logger.info("Flags: {}".format(flags.shape))

        # Now report the recommended resolution from the data.
        # 1.0 / 2*np.sin(theta) = limit_u
        limit_uvw = np.max(np.abs(uvw), 0)
        res_limit = get_baseline_resolution(limit_uvw[0], frequency)
        logger.info("Nyquist resolution: {:g} arcmin".format(
            np.degrees(res_limit) * 60.0))

        good_mask = ((flags[:, channel, pol] == 0)
                     & (np.max(np.abs(uvw), 1) < u_max))
        good_data = np.where(good_mask)[0]
        logger.info("Good Data {}".format(good_data.shape))
        logger.info("Maximum UVW: {}".format(limit_uvw))
        logger.info("Minimum UVW: {}".format(np.min(np.abs(uvw), 0)))

        n_max = len(good_data)
        if n_max <= num_vis:
            indices = good_data
        else:
            # Sample without replacement so no visibility is duplicated.
            indices = np.random.choice(good_data, num_vis, replace=False)
        # sort the indices to keep them in order
        indices = np.sort(indices)

        hdr = {
            'CTYPE1': ('RA---SIN', "Right ascension angle cosine"),
            'CRVAL1': np.degrees(phase_dir)[0],
            'CUNIT1': 'deg ',
            'CTYPE2': ('DEC--SIN', "Declination angle cosine "),
            'CRVAL2': np.degrees(phase_dir)[1],
            'CUNIT2': 'deg ',
            'CTYPE3': 'FREQ ',  # / Central frequency ",
            'CRPIX3': 1.,
            'CRVAL3': "{}".format(frequency),
            'CDELT3': 10026896.158854,
            'CUNIT3': 'Hz ',
            'EQUINOX': '2000.',
            'DATE-OBS': "{}".format(epoch_seconds),
            'BTYPE': 'Intensity'
        }

        u_arr = uvw[indices, 0]
        v_arr = uvw[indices, 1]
        w_arr = uvw[indices, 2]
        cv_vis = cv_vis[indices]

        # Convert from reduced Julian Date to timestamp.
        timestamp = datetime.datetime(
            1858, 11, 17, 0, 0, 0,
            tzinfo=datetime.timezone.utc) + datetime.timedelta(
                seconds=epoch_seconds)

        return u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp
def main(args):
    """Make dirty and PSF images from ``args.table_name`` and save them as FITS.

    The cell size is set from the Nyquist limit of the longest u/v
    coordinate and highest frequency, unless ``args.cell_size`` (arcsec) is
    given. The image size defaults to ``args.fov`` (degrees) rounded up
    with ``next_fast_len``.

    Writes these images, with ``args.outfile`` as prefix:
    - ``_dirty.fits`` and ``_psf.fits`` (per band);
    - ``_dirty_mfs.fits`` and ``_psf_mfs.fits`` (band sums, each normalised
      by the sum of the per-band PSF peaks).
    """
    real_type = np.float32 if args.precision > 1e-6 else np.float64

    from africanus.constants import c as lightspeed

    # get max uv coords over all fields (tuples, not bare strings, for
    # group_cols / columns)
    xds = xds_from_table(args.table_name, group_cols=('FIELD_ID',),
                         columns=('UVW',), chunks={'row': -1})
    uvw = np.concatenate([ds.UVW.data.compute() for ds in xds])
    u_max = np.abs(uvw[:, 0]).max()
    v_max = np.abs(uvw[:, 1]).max()
    del uvw  # potentially large; only the maxima are needed below

    # get Nyquist cell size
    freq = xds_from_table(args.table_name +
                          "::FREQ")[0].FREQ.data.compute().squeeze()
    uv_max = np.maximum(u_max, v_max)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    # Default image size: field of view in pixels, rounded to a fast FFT size
    if args.nx is None or args.ny is None:
        from scipy.fftpack import next_fast_len
        npix = next_fast_len(int(args.fov * 3600 / args.cell_size))
        if args.nx is None:
            args.nx = npix
        if args.ny is None:
            args.ny = npix

    if args.channels_out is None:
        args.channels_out = freq.size

    print("Image size set to (%i, %i, %i)" % (args.channels_out,
                                              args.nx, args.ny))

    # init gridder
    R = OutMemGridder(args.table_name, args.nx, args.ny, args.cell_size, freq,
                      nband=args.channels_out, field=args.field,
                      precision=args.precision, ncpu=args.ncpu,
                      do_wstacking=args.do_wstacking,
                      data_column=args.data_column,
                      weight_column=args.weight_column)
    freq_out = R.freq_out

    # get headers
    radec = xds_from_table(args.table_name +
                           "::RADEC")[0].RADEC.data.compute().squeeze()
    cell_deg = args.cell_size / 3600
    hdr = set_wcs(cell_deg, cell_deg, args.nx, args.ny, radec, freq_out)
    hdr_mfs = set_wcs(cell_deg, cell_deg, args.nx, args.ny, radec,
                      np.mean(freq_out))
    hdr_psf = set_wcs(cell_deg, cell_deg, 2 * args.nx, 2 * args.ny, radec,
                      freq_out)
    hdr_psf_mfs = set_wcs(cell_deg, cell_deg, 2 * args.nx, 2 * args.ny,
                          radec, np.mean(freq_out))

    # make psf  LB - TODO: undersized psfs
    psf = R.make_psf()
    nband = R.nband
    # PSF is (nband, 2*nx, 2*ny): per-band peak
    psf_max = np.amax(psf.reshape(nband, 4 * args.nx * args.ny), axis=1)

    # make dirty
    dirty = R.make_dirty()

    # save dirty and psf images
    save_fits(args.outfile + '_dirty.fits', dirty, hdr, dtype=real_type)
    save_fits(args.outfile + '_psf.fits', psf, hdr_psf, dtype=real_type)

    # MFS images
    wsum = np.sum(psf_max)
    dirty_mfs = np.sum(dirty, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    psf_mfs = np.sum(psf, axis=0) / wsum
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    rmax = np.abs(dirty_mfs).max()
    rms = np.std(dirty_mfs)
    print("Peak of dirty is %f and rms is %f" % (rmax, rms))
def make_psf(self):
    """Grid the combined weights of every matching dataset into per-band PSFs.

    For each dataset whose phase centre equals ``self.radec``:
    - the XX and YY (first and last correlation) weights, times the
      imaging weights when configured, are summed into Stokes I weights;
    - these weights are gridded with ``vis2im`` onto a (2*nx, 2*ny) grid;
    - only samples where both correlations are unflagged are kept.

    The gridded images are accumulated into bands with ``accumulate_dirty``
    and returned as float64.
    """
    print("Making PSF")
    psfs = []
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims], columns=self.columns)

        # only the FIELD subtable is needed here
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        fields = dask.compute(fields)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            # per-MS frequency tables are keyed by DATA_DESC_ID
            ddid = ds.DATA_DESC_ID
            freq_bin_idx = self.freq_bin_idx[ims][ddid]
            freq_bin_counts = self.freq_bin_counts[ims][ddid]
            freq = self.freq[ims][ddid]

            uvw = ds.UVW.data
            flag = getattr(ds, self.flag_column).data

            weights = getattr(ds, self.weight_column).data
            if weights.ndim < 3:
                # per-row weights: broadcast over the channel axis
                weights = da.broadcast_to(weights[:, None, :],
                                          flag.shape, chunks=flag.chunks)

            if self.imaging_weight_column is not None:
                imaging_weights = getattr(
                    ds, self.imaging_weight_column).data
                if imaging_weights.ndim < 3:
                    # Broadcast to the flag shape. `data` is not defined
                    # yet at this point, so it cannot supply the shape.
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :],
                        flag.shape, chunks=flag.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # weighted sum corr to Stokes I; gridding the weights gives the PSF
            vis = (weightsxx + weightsyy).astype(np.complex64)

            # only keep data where both corrs are unflagged
            keep = ~(flag[:, :, 0] | flag[:, :, -1])

            # ducc0 convention
            psf = vis2im(uvw, freq, vis, freq_bin_idx, freq_bin_counts,
                         2*self.nx, 2*self.ny, self.cell,
                         flag=keep.astype(np.uint8),
                         nthreads=self.nthreads, epsilon=self.epsilon,
                         do_wstacking=self.do_wstacking)
            psfs.append(psf)

    psfs = dask.compute(psfs)[0]

    return accumulate_dirty(psfs, self.nband,
                            self.band_mapping).astype(np.float64)
def _load_field_table(ims):
    """Return the computed FIELD subtable of ``ims``, one dataset per row."""
    fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
    return dask.compute(fields)[0]


def _load_vis(ds, data_column, model_column, flag_column, weight_column):
    """Return ``(data, model, flag, weights)`` dask arrays from ``ds``.

    FLAG_ROW is OR-ed into the element-wise flags. Weights with fewer than
    three dimensions get a middle axis inserted and are broadcast to the
    shape/chunks of ``data``.
    """
    data = getattr(ds, data_column).data
    model = getattr(ds, model_column).data
    flag = da.logical_or(getattr(ds, flag_column).data,
                         ds.FLAG_ROW.data[:, None, None])
    weights = getattr(ds, weight_column).data
    if weights.ndim < 3:
        weights = da.broadcast_to(weights[:, None, :], data.shape,
                                  chunks=data.chunks)
    return data, model, flag, weights


def _whitened_stokes_i(data, model, weights):
    """Return ``(abs_resid, mean_amp)`` for the whitened Stokes I residual.

    The Stokes I residual is the weighted average of the first and last
    correlations of ``data - model`` (zero where the summed weight is
    zero). It is whitened by multiplying with the square root of the summed
    weights. ``mean_amp`` is the sum of the whitened amplitudes divided by
    the number of elements with a positive summed weight.
    """
    resid_vis = (data - model) * weights
    wsums = weights[:, :, 0] + weights[:, :, -1]
    resid_vis_I = da.where(
        wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums, 0.0j)
    abs_resid = abs(resid_vis_I * da.sqrt(wsums))
    mean_amp = da.sum(abs_resid) / da.sum(wsums > 0)
    return abs_resid, mean_amp


def _report_mean_amps(args):
    """Print the mean whitened Stokes I amplitude of each matching dataset,
    recomputed from the output flag and weight columns."""
    radec_ref = None
    mean_amps = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={"row": args.row_chunks,
                                  "chan": args.chan_chunks},
                          columns=('UVW', args.data_column,
                                   args.weight_out_column, args.model_column,
                                   args.flag_out_column, 'FLAG_ROW'))
        fields = _load_field_table(ims)
        for ds in xds:
            radec = fields[ds.FIELD_ID].PHASE_DIR.data.squeeze()
            if radec_ref is None:
                radec_ref = radec
            if not np.array_equal(radec, radec_ref):
                continue

            data, model, flag, weights = _load_vis(
                ds, args.data_column, args.model_column,
                args.flag_out_column, args.weight_out_column)
            weights = (~flag) * weights
            _, mean_amp = _whitened_stokes_i(data, model, weights)
            mean_amps.append(mean_amp)

    mean_amps = dask.compute(mean_amps)[0]
    print(mean_amps)


def main(args):
    """
    Flags outliers in data given a model and rescales weights.

    Only datasets whose phase direction matches the first one encountered
    are processed. A sample is flagged when either its first or last
    correlation was already flagged, or when its whitened Stokes I residual
    amplitude exceeds ``args.sigma_cut`` times the mean amplitude. With
    ``args.scale_weights`` the weights are multiplied by
    ``sqrt(2) / mean_amp**2`` (mean amplitude recomputed with the new
    flags). NOTE(review): earlier docs quote both sqrt(2) and 1/sqrt(2) as
    the target mean amplitude; confirm which is intended.

    The new flags and weights are written to ``args.flag_out_column`` and
    ``args.weight_out_column``. If ``args.report_means`` is set, the
    resulting mean amplitudes are printed afterwards.
    """
    radec_ref = None
    writes = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={"row": args.row_chunks,
                                  "chan": args.chan_chunks},
                          columns=('UVW', args.data_column, args.weight_column,
                                   args.model_column, args.flag_column,
                                   'FLAG_ROW'))
        fields = _load_field_table(ims)

        out_data = []
        for ds in xds:
            radec = fields[ds.FIELD_ID].PHASE_DIR.data.squeeze()
            # only process datasets sharing the first phase direction seen
            if radec_ref is None:
                radec_ref = radec
            if not np.array_equal(radec, radec_ref):
                continue

            data, model, flag, weights = _load_vis(
                ds, args.data_column, args.model_column, args.flag_column,
                args.weight_column)
            if args.trim_channels:
                flag = trim_chans(flag, args.trim_channels)

            # zero weights of flagged samples before forming Stokes I
            weights = (~flag) * weights
            abs_resid, mean_amp = _whitened_stokes_i(data, model, weights)

            # new flags: previously flagged or an outlier in amplitude
            flag_legacy = flag[:, :, 0] | flag[:, :, -1]
            flag_I = da.logical_or(abs_resid > args.sigma_cut * mean_amp,
                                   flag_legacy)
            updated_flag = da.broadcast_to(flag_I[:, :, None], flag.shape,
                                           chunks=flag.chunks)

            if args.scale_weights:
                # recompute mean amp with the new flags
                weights = (~updated_flag) * weights
                _, mean_amp = _whitened_stokes_i(data, model, weights)
                updated_weight = 2**0.5 * weights / mean_amp**2
            else:
                updated_weight = weights

            ds = ds.assign(**{
                args.weight_out_column: (("row", "chan", "corr"),
                                         updated_weight),
                args.flag_out_column: (("row", "chan", "corr"),
                                       updated_flag),
            })
            out_data.append(ds)

        writes.append(
            xds_to_table(out_data, ims,
                         columns=[args.flag_out_column,
                                  args.weight_out_column]))

    with ProgressBar():
        dask.compute(writes)

    # report new mean amp
    if args.report_means:
        _report_mean_amps(args)
MAX([SELECT ABS(UVW[1]) FROM {ms}]) as ABS_VMAX, MIN([SELECT UVW[2] FROM {ms}]) AS WMIN, MAX([SELECT UVW[2] FROM {ms}]) AS WMAX """.format(ms=args.ms) with pt.taql(query) as Q: umin = Q.getcol("ABS_UMIN").item() umax = Q.getcol("ABS_UMAX").item() vmin = Q.getcol("ABS_VMIN").item() vmax = Q.getcol("ABS_VMAX").item() wmin = Q.getcol("WMIN").item() wmax = Q.getcol("WMAX").item() xds = list(xds_from_ms(args.ms, chunks={"row": args.chunks}))[0] spw_ds = list( xds_from_table("::".join((args.ms, "SPECTRAL_WINDOW")), group_cols="__row__"))[0] wavelength = (lightspeed / spw_ds.CHAN_FREQ.data[0]).compute() if args.cell_size: cell_size = args.cell_size else: cell_size = estimate_cell_size(umax, vmax, wavelength, factor=3, ny=args.npix, nx=args.npix).max() # Convolution Filter conv_filter = convolution_filter(3, 63, "kaiser-bessel")
def _build_stokes_model(sources, freq, radec0, n_chan, n_corr):
    """Return ``(model, lm, names)`` for the given sky-model sources.

    ``model`` has shape ``(len(sources), n_chan, n_corr)``; indices 0..3 of
    the last axis receive Stokes I, Q, U, V (when present and non-zero)
    scaled by ``(freq / freq0) ** spi``. A flat spectrum is used when the
    source has no spectrum. ``lm`` holds each source position converted with
    ``radec_to_lm`` relative to ``radec0``.
    """
    n_dir = len(sources)
    model = np.zeros((n_dir, n_chan, n_corr), dtype=np.float64)
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    names = []
    for d, source in enumerate(sources):
        names.append(source.name)
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

        # only a spectral index is currently supported
        spec = source.spectrum
        spi = [spec.spi if spec is not None else 0.0]
        ref_freq = [spec.freq0 if spec is not None else 1.0]

        # NOTE(review): unpolarised flux objects may lack Q/U/V, hence getattr
        for c, stokes in enumerate("IQUV"):
            flux = getattr(source.flux, stokes, None)
            if flux:
                model[d, :, c] = flux * (freq / ref_freq)**spi
    return model, lm, names


def new(ms, sky_model, gains, **kwargs):
    """Generate model visibilities per source (as direction axis) from the
    Stokes parameters of a sky model, corrupt them with gains and write the
    results to the measurement set.

    Parameters
    ----------
    ms : str
        Path of the measurement set that is read and written to.
    sky_model : str
        A key of ``sky_models`` (case-insensitive) or a path loadable by
        ``Tigger.load``.
    gains : str
        Path of a file read with ``np.load`` whose array is unpacked as
        ``(n_time, n_ant, n_chan, n_dir, n_corr)``.
    **kwargs
        Options. If ``kwargs["yaml"]`` is given (not None) the options are
        loaded from that file instead.

    Columns written: one per source name (``options.mname`` in DIE mode),
    ``'CLEAN_' + options.dname``, ``'NOISE'`` when ``options.std > 0`` and
    ``options.dname``.

    Raises
    ------
    ValueError
        On an unknown phase convention or feed type, or when the gains do
        not match the measurement set / sky model dimensions.
    """
    # Options to attributed dictionary
    if kwargs.get("yaml") is not None:
        options = ocf.load(kwargs["yaml"])
    else:
        options = ocf.create(kwargs)
    ocf.set_struct(options, True)

    # Resolve a named sky model; anything else is treated as a path
    try:
        sky_model = sky_models[sky_model.lower()]
    except (KeyError, AttributeError, TypeError):
        pass

    # Use a fixed-size thread pool if requested, otherwise record cpu count
    if options.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(options.ncpu))
    else:
        import multiprocessing
        options.ncpu = multiprocessing.cpu_count()

    # Load gains to corrupt with
    with open(gains, "rb") as fh:
        jones = np.load(fh)
    n_time, n_ant, n_chan, n_dir, n_corr = jones.shape
    n_row = n_time * (n_ant * (n_ant - 1) // 2)

    # Get time-bin indices and counts, then reopen with chunked rows
    MS = xds_from_ms(ms)[0]
    row_chunks, tbin_indices, tbin_counts = chunkify_rows(
        MS.TIME, options.utime)
    MS.close()
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]

    ant1 = MS.ANTENNA1.data
    ant2 = MS.ANTENNA2.data

    # Adjust UVW based on phase-convention
    convention = options.phase_convention.upper()
    if convention == 'CASA':
        uvw = -MS.UVW.data.astype(np.float64)
    elif convention == 'CODEX':
        uvw = MS.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase.")

    dims = ocf.create(dict(MS.sizes))
    MS.close()

    lsm = Tigger.load(sky_model)

    # Validate gains against MS / sky model (explicit checks, not asserts,
    # so they survive ``python -O``)
    if n_row != dims.row:
        raise ValueError("Gains imply %d rows but the MS has %d."
                         % (n_row, dims.row))
    if n_time != len(tbin_indices):
        raise ValueError("Gains have %d times but the MS has %d."
                         % (n_time, len(tbin_indices)))
    if n_ant != np.max((np.max(ant1), np.max(ant2))) + 1:
        raise ValueError("Number of antennas in gains does not match MS.")
    if n_chan != dims.chan:
        raise ValueError("Gains have %d channels but the MS has %d."
                         % (n_chan, dims.chan))
    if n_corr != dims.corr:
        raise ValueError("Gains have %d correlations but the MS has %d."
                         % (n_corr, dims.corr))
    if options.die:
        if n_dir != 1:
            raise ValueError("DIE gains must have a single direction.")
        n_dir = len(lsm.sources)
    elif n_dir != len(lsm.sources):
        raise ValueError("Gains have %d directions but the sky model has %d "
                         "sources." % (n_dir, len(lsm.sources)))

    # Phase direction, frequencies and feed types
    radec0_table = xds_from_table(ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()
    radec0_table.close()

    freq_table = xds_from_table(ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]
    freq_table.close()

    feed_table = xds_from_table(ms + '::FEED')[0]
    feeds = feed_table.POLARIZATION_TYPE.data[0].compute()
    feed_table.close()

    print("==> Building model visibilities")
    model, lm, source_names = _build_stokes_model(
        lsm.sources, freq, radec0, n_chan, n_corr)
    del lsm

    # Build dask graph
    tbin_indices = da.from_array(tbin_indices, chunks=(options.utime))
    tbin_counts = da.from_array(tbin_counts, chunks=(options.utime))
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones = da.from_array(jones,
                          chunks=(options.utime, ) + jones.shape[1::])

    # Apply image to visibility for each source
    sources = [
        im_to_vis(model[s].reshape((1, n_chan, n_corr)),
                  uvw,
                  lm[s].reshape((1, 2)),
                  freq,
                  dtype=np.complex64,
                  convention='fourier')
        for s in range(n_dir)
    ]
    model_vis = da.stack(sources, axis=2)

    # Sum over directions for direction-independent gains
    if options.die:
        model_vis = da.sum(model_vis, axis=2, keepdims=True)
        n_dir = 1
        source_names = [options.mname]

    # Select schema based on feed orientation
    if (feeds == ["X", "Y"]).all():
        out_schema = [["XX", "XY"], ["YX", "YY"]]
    elif (feeds == ["R", "L"]).all():
        out_schema = [['RR', 'RL'], ['LR', 'LL']]
    else:
        raise ValueError("Unknown feed orientation implementation.")

    # Convert Stokes to Correlations
    in_schema = ['I', 'Q', 'U', 'V']
    model_vis = convert(model_vis, in_schema, out_schema).reshape(
        (n_row, n_chan, n_dir, n_corr))

    print("==> Corrupting visibilities")
    data = corrupt_vis(tbin_indices, tbin_counts, ant1, ant2, jones,
                       model_vis)

    # Reopen MS and assign model visibilities per direction
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]
    out_names = []
    for d in range(n_dir):
        MS = MS.assign(**{
            source_names[d]: (("row", "chan", "corr"),
                              model_vis[:, :, d].astype(np.complex64))
        })
        out_names += [source_names[d]]

    # Assign noise free visibilities to 'CLEAN_DATA'
    MS = MS.assign(**{
        'CLEAN_' + options.dname: (("row", "chan", "corr"),
                                   data.astype(np.complex64))
    })
    out_names += ['CLEAN_' + options.dname]

    if options.std > 0.0:
        print(f"==> Applying noise (std={options.std}) to visibilities")
        noise = []
        for _ in range(2):
            real = da.random.normal(loc=0.0, scale=options.std,
                                    size=(n_row, n_chan),
                                    chunks=(row_chunks, n_chan))
            imag = 1.0j * (da.random.normal(loc=0.0, scale=options.std,
                                            size=(n_row, n_chan),
                                            chunks=(row_chunks, n_chan)))
            noise.append(real + imag)
        # zero noise on the off-diagonal correlations
        zero = da.zeros((n_row, n_chan), chunks=(row_chunks, n_chan))
        noise.insert(1, zero)
        noise.insert(2, zero)
        noise = da.stack(noise, axis=2).rechunk((row_chunks, n_chan, n_corr))

        MS = MS.assign(
            **{'NOISE': (("row", "chan", "corr"), noise.astype(np.complex64))})
        out_names += ['NOISE']
        noisy_data = data + noise
    else:
        # NOTE(review): previously ``noise`` was undefined here; without
        # noise the output column now receives the noise-free data
        noisy_data = data

    MS = MS.assign(**{
        options.dname: (("row", "chan", "corr"),
                        noisy_data.astype(np.complex64))
    })
    out_names += [options.dname]

    write = xds_to_table(MS, ms, out_names)

    print(f"==> Executing `dask-ms` write to `{ms}` for the following columns: "
          + f"{', '.join(out_names)}")
    with ProgressBar():
        write.compute()

    print("==> Completed.")
def _load_or_make(path, hdr, make, out_name, **load_kwargs):
    """Load a FITS cube from ``path`` when its header matches ``hdr``.

    If ``path`` is None, or comparing headers / loading fails, the cube is
    recomputed with ``make()`` and saved to ``out_name`` with ``hdr``.
    """
    if path is not None:
        try:
            compare_headers(hdr, fits.getheader(path))
            return load_fits(path, **load_kwargs)
        except Exception:
            # best effort: fall through and recompute
            pass
    image = make()
    save_fits(out_name, image, hdr)
    return image


def main(args):
    """Deconvolve the visibilities in ``args.ms``.

    Steps:
    - derive the cell size (Nyquist from the max |u|, |v| and frequency)
      and the image size;
    - make or load the PSF and dirty images;
    - iterate a preconditioned CG update followed by a ``primal_dual`` step
      with optional reweighting of ``weights_21``.

    FITS products are written with prefix ``args.outfile``. Note that
    ``args.cell_size``, ``args.nx``, ``args.ny``, ``args.nband``,
    ``args.psi_basis`` and ``args.reweight_alpha_percent`` may be updated
    in place.
    """
    # get max uv coords and channel frequencies over all fields
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',), chunks={'row': args.row_chunks})
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]
        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            # NOTE(review): DATA_DESC_ID is used as the SPW index here
            spw = spws[ds.DATA_DESC_ID]
            all_freqs.append(spw.CHAN_FREQ.data.squeeze())

    # Nyquist must account for both u and v (previously only u was used)
    uv_max = dask.compute(da.maximum(u_max, v_max))[0]

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(*all_freqs)
    # ravel + concatenate so SPWs with different channel counts work
    freq = np.unique(np.concatenate([np.ravel(f) for f in all_freqs]))
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms, args.nx, args.ny, args.cell_size, nband=args.nband,
                nthreads=args.nthreads, do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks, data_column=args.data_column,
                weight_column=args.weight_column, epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    cell_deg = args.cell_size / 3600
    hdr = set_wcs(cell_deg, cell_deg, args.nx, args.ny, radec, freq_out)
    hdr_mfs = set_wcs(cell_deg, cell_deg, args.nx, args.ny, radec,
                      np.mean(freq_out))
    hdr_psf = set_wcs(cell_deg, cell_deg, 2 * args.nx, 2 * args.ny, radec,
                      freq_out)

    # psf
    psf_array = _load_or_make(args.psf, hdr_psf, R.make_psf,
                              args.outfile + '_psf.fits')
    npix_psf = 4 * args.nx * args.ny
    psf_max = np.amax(psf_array.reshape(args.nband, npix_psf), axis=1)
    psf_max_mean = np.sum(psf_max) / np.sum(psf_max > 0)
    # normalisation for more intuitive sig_21 values
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    # sum of per-band peaks of the normalised PSF; MFS images divide by it
    wsum = np.sum(np.amax(psf_array.reshape(args.nband, npix_psf), axis=1))
    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(args.outfile + '_psf_mfs.fits',
              psf_mfs[args.nx // 2:3 * args.nx // 2,
                      args.ny // 2:3 * args.ny // 2],
              hdr_mfs)

    # dirty
    dirty = _load_or_make(args.dirty, hdr, R.make_dirty,
                          args.outfile + '_dirty.fits')
    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    # initial model and residual; any failure falls back to a zero model
    model = None
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=np.float64)
            residual = _load_or_make(
                args.first_residual, hdr, lambda: R.make_residual(model),
                args.outfile + '_first_residual.fits', dtype=np.float64)
        except Exception:
            model = None
    if model is None:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    # preconditioning matrix
    def hess(x):
        return mask * psf.convolve(mask * x) + x / args.sig_l2**2

    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess, dirty.shape, tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    if args.psi_basis is None:
        print("Using Dirac + db1-4 dictionary")
        psi = DaskPSI(args.nband, args.nx, args.ny, nlevels=args.psi_levels,
                      nthreads=args.nthreads)
    else:
        # a single basis name must not be split into characters
        if isinstance(args.psi_basis, str):
            args.psi_basis = [args.psi_basis]
        elif not isinstance(args.psi_basis, list):
            args.psi_basis = list(args.psi_basis)
        print("Using ", args.psi_basis, " dictionary")
        psi = DaskPSI(args.nband, args.nx, args.ny, nlevels=args.psi_levels,
                      nthreads=args.nthreads, bases=args.psi_basis)
    weights_21 = np.ones((psi.nbasis, psi.nmax), dtype=np.float64)
    dual = np.zeros((psi.nbasis, args.nband, psi.nmax), dtype=np.float64)

    # Reweighting
    if args.reweight_iters is not None:
        if not isinstance(args.reweight_iters, list):
            reweight_iters = [args.reweight_iters]
        else:
            reweight_iters = list(args.reweight_iters)
    else:
        reweight_iters = list(np.arange(args.reweight_start,
                                        args.reweight_end,
                                        args.reweight_freq))
        reweight_iters.append(args.reweight_end)

    # Reporting (always include the last iteration)
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if not report_iters or report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)

    def M(x):
        # preconditioner
        return x * args.sig_l2**2

    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess, mask * residual, np.zeros(dirty.shape, dtype=np.float64),
                M=M, tol=args.cgtol, maxit=args.cgmaxit, minit=args.cgminit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_update.fits', x, hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess, model, modelp, dual, args.sig_21, psi,
                                  weights_21, beta, tol=args.pdtol,
                                  maxit=args.pdmaxit, report_freq=100,
                                  mask=mask, positivity=args.positivity)

        # reweighting
        if i in reweight_iters:
            v = psi.hdot(model)
            l2_norm = norm(v, axis=1)
            l2_norm = np.where(l2_norm < args.sig_21 * weights_21, 0.0,
                               l2_norm)
            for m in range(psi.nbasis):
                indnz = l2_norm[m].nonzero()
                alpha = np.percentile(l2_norm[m, indnz].flatten(),
                                      args.reweight_alpha_percent)
                alpha = np.maximum(alpha, args.reweight_alpha_min)
                print("Reweighting - ", m, alpha)
                weights_21[m] = alpha / (l2_norm[m] + alpha)
            args.reweight_alpha_percent *= args.reweight_alpha_ff

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)
            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)
            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)
            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print("At iteration %i peak of residual is %f, rms is %f, current "
              "eps is %f" % (i, rmax, rms, eps))

    if args.write_model:
        R.write_model(model)

    if args.make_restored:
        x = pcg(hess, residual, np.zeros(dirty.shape, dtype=np.float64), M=M,
                tol=args.cgtol, maxit=args.cgmaxit)
        restored = model + x

        # get residual
        residual = R.make_residual(restored) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        print("After restoring peak of residual is %f and rms is %f"
              % (rmax, rms))

        # save current iteration
        save_fits(args.outfile + '_restored.fits', restored, hdr)
        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)
        save_fits(args.outfile + '_restored_residual.fits', residual, hdr)
        save_fits(args.outfile + '_restored_residual_mfs.fits', residual_mfs,
                  hdr_mfs)
def extract_dde_info(args, freqs):
    """Collect the inputs needed for direction-dependent beam interpolation.

    Returns a 5-tuple ``(parangles, ant_scale, point_errs, unflag_counts,
    has_ms)``.

    With ``args.ms`` set, the unique TIME values of field index
    ``args.field`` in each measurement set are read. They are optionally
    thinned by ``args.sparsify_time`` and then concatenated. Parallactic
    angles are computed for those times and averaged over antennas (axis 1,
    kept). Unflagged counts per time come from ``_unflagged_counts``, and
    ``ant_scale`` is returned as a single-chunk dask array.

    With ``args.ms`` None, single-time placeholders are returned and
    ``has_ms`` is False.

    Raises
    ------
    ValueError
        If antenna positions or the phase direction differ between
        measurement sets.
    """
    nband = freqs.size

    if args.ms is None:
        # placeholders: one time, one (averaged) antenna, zero angle
        parangles = np.zeros((1, 1,), dtype=np.float64)
        ant_scale = np.ones((1, nband, 2), dtype=np.float64)
        point_errs = np.zeros((1, 1, nband, 2), dtype=np.float64)
        unflag_counts = np.array([1])
        return (parangles, ant_scale, point_errs, unflag_counts, False)

    time_chunks = []
    count_chunks = []
    ref_positions = None
    ref_direction = None
    for ms_name in args.ms:
        # antenna positions must agree across all measurement sets
        antenna_ds = xds_from_table(ms_name + '::ANTENNA')[0].compute()
        positions = antenna_ds['POSITION'].data
        if ref_positions is None:
            ref_positions = positions
        elif not np.array_equal(ref_positions, positions):
            raise ValueError("Antenna positions not the same across measurement sets")

        # so must the phase direction of the selected field
        field_ds = xds_from_table(ms_name + '::FIELD')[0].compute()
        direction = field_ds['PHASE_DIR'][args.field].data.squeeze()
        if ref_direction is None:
            ref_direction = direction
        elif not np.array_equal(ref_direction, direction):
            raise ValueError('Phase direction not the same across measurement sets')

        # unique times (optionally thinned) and unflagged rows per time
        ms_ds = xds_from_ms(ms_name, columns=["TIME", "FLAG_ROW"],
                            group_cols=["FIELD_ID"])[args.field]
        times, first_rows = np.unique(ms_ds.TIME.data.compute(),
                                      return_index=True)
        if args.sparsify_time > 1:
            keep = np.arange(0, times.size, args.sparsify_time)
            times = times[keep]
            first_rows = first_rows[keep]
        time_chunks.append(times)

        row_flags = ms_ds.FLAG_ROW.data.compute()
        count_chunks.append(
            _unflagged_counts(row_flags.astype(np.int32), first_rows,
                              np.zeros(times.size, dtype=np.int32)))

    all_times = np.concatenate(time_chunks)
    unflag_counts = np.concatenate(count_chunks)

    # average over antennas: nant -> 1
    parangles = np.mean(
        parallactic_angles(all_times, ref_positions, ref_direction),
        axis=1, keepdims=True)

    # beam_cube_dde requirements for a single (averaged) antenna
    ant_scale = np.ones((1, nband, 2), dtype=np.float64)
    point_errs = np.zeros((all_times.size, 1, nband, 2), dtype=np.float64)
    return (parangles, da.from_array(ant_scale, chunks=ant_scale.shape),
            point_errs, unflag_counts, True)
def _resolve_ids(selection, names):
    """Map a selection of names or ids to ids.

    An empty/None selection selects every index of ``names``. Strings are
    looked up in ``names``; anything else is returned unchanged.
    """
    if not selection:
        return list(range(len(names)))
    if isinstance(selection[0], str):
        name_list = list(names)
        return [name_list.index(name) for name in selection]
    return selection


def antenna_flags_field(msname, fields=None, antennas=None):
    """Compute per-antenna flag statistics over the selected fields.

    Parameters
    ----------
    msname : str
        Measurement set path.
    fields : list of str or int, optional
        Field names (looked up in the FIELD subtable) or field ids; all
        fields when omitted.
    antennas : list of str or int, optional
        Antenna names (looked up in the ANTENNA subtable) or antenna ids;
        all antennas when omitted.

    Returns
    -------
    dict
        Maps each selected antenna id to a dict with keys ``"name"``,
        ``"position"``, ``"array_centre_dist"`` (distance from the array
        centre, or from antenna 0 when the observatory is unknown),
        ``"frac"`` (= sum / counts), ``"sum"`` and ``"counts"`` as reduced
        by ``_get_flags``/``_chunk``/``_combine``/``_aggregate``.
    """
    ds_ant = xds_from_table(msname + "::ANTENNA")[0]
    ds_field = xds_from_table(msname + "::FIELD")[0]
    ds_obs = xds_from_table(msname + "::OBSERVATION")[0]

    ant_names = ds_ant.NAME.data.compute()
    field_names = ds_field.NAME.data.compute()
    ant_positions = ds_ant.POSITION.data.compute()

    try:
        # Get observatory name and centre of array
        obs_name = ds_obs.TELESCOPE_NAME.data.compute()[0]
        me = casacore.measures.measures()
        obs_cofa = me.observatory(obs_name)
        lon, lat, alt = (obs_cofa['m0']['value'],
                         obs_cofa['m1']['value'],
                         obs_cofa['m2']['value'])
        cofa = wgs84_to_ecef(lon, lat, alt)
    except Exception:
        # Otherwise use the first antenna as the reference position
        cofa = ant_positions[0]

    # names are resolved against the MS subtables, not the selection itself
    field_ids = _resolve_ids(fields, field_names)
    ant_ids = _resolve_ids(antennas, ant_names)
    nant = len(ant_ids)

    fields_str = ", ".join(map(str, field_ids))
    ds_mss = xds_from_ms(msname,
                         group_cols=["FIELD_ID", "DATA_DESC_ID"],
                         chunks={'row': 100000},
                         taql_where="FIELD_ID IN [%s]" % fields_str)

    flag_sum_computes = []
    for ds in ds_mss:
        flag_sums = da.blockwise(_get_flags, ("row",),
                                 ant_ids, ("ant",),
                                 ds.ANTENNA1.data, ("row",),
                                 ds.ANTENNA2.data, ("row",),
                                 ds.FLAG.data, ("row", "chan", "corr"),
                                 adjust_chunks={"row": nant},
                                 dtype=numpy.ndarray)

        flags_redux = da.reduction(flag_sums,
                                   chunk=_chunk,
                                   combine=_combine,
                                   aggregate=_aggregate,
                                   concatenate=False,
                                   dtype=numpy.float64)
        flag_sum_computes.append(flags_redux)

    sum_per_field_spw = dask.compute(flag_sum_computes)[0]
    sum_all = sum(sum_per_field_spw)
    fractions = sum_all[:, 0] / sum_all[:, 1]

    stats = {}
    # ``i`` indexes the reduced sums (ordered like ant_ids); ``aid`` indexes
    # the ANTENNA subtable
    for i, aid in enumerate(ant_ids):
        ant_pos = list(ant_positions[aid])
        stats[aid] = {
            "name": ant_names[aid],
            "position": ant_pos,
            "array_centre_dist": _distance(cofa, ant_pos),
            "frac": fractions[i],
            "sum": sum_all[i][0],
            "counts": sum_all[i][1],
        }
    return stats
def _main(dest=sys.stdout):
    """Run the pfb imaging and deconvolution pipeline from the command line.

    Parses the command line with ``pfb.parser.create_parser``, builds a
    ``Gridder`` over the input measurement sets, makes (or loads) the PSF and
    dirty images and fits restoring beams. It then runs up to ``args.maxit``
    major cycles of the minor-cycle algorithm selected by
    ``args.deconv_mode``. Finally it optionally mops up residual flux, writes
    the model to the MS and makes restored images. Products are written as
    FITS files prefixed with ``args.outfile``.

    Parameters
    ----------
    dest : file-like, optional
        Stream receiving progress messages (default ``sys.stdout``). The input
        options are written to the pyscilog ``log`` instead.

    Raises
    ------
    ValueError
        If the requested cell size exceeds the Nyquist cell size, if the mask
        or beam image has the wrong shape, if the beam model is not
        recognised, or if ``args.deconv_mode`` is not one of ``'sara'``,
        ``'clean'`` or ``'spotless'``. Errors from ``compare_headers`` or
        ``load_fits`` when loading ``args.psf``/``args.dirty`` propagate.
    """
    from pfb.parser import create_parser
    args = create_parser().parse_args()

    if not args.nthreads:
        import multiprocessing
        args.nthreads = multiprocessing.cpu_count()

    if not args.mem_limit:
        import psutil
        # 100% of memory by default
        args.mem_limit = int(psutil.virtual_memory()[0] / 1e9)

    import numpy as np
    import numba
    import numexpr
    import dask
    import dask.array as da
    from daskms import xds_from_ms, xds_from_table
    from astropy.io import fits
    from pfb.utils.fits import (set_wcs, load_fits, save_fits, compare_headers,
                                data_from_header)
    from pfb.utils.restoration import fitcleanbeam
    from pfb.utils.misc import Gaussian2D
    from pfb.operators.gridder import Gridder
    from pfb.operators.psf import PSF
    from pfb.deconv.sara import sara
    from pfb.deconv.clean import clean
    from pfb.deconv.spotless import spotless
    from pfb.deconv.nnls import nnls
    from pfb.opt.pcg import pcg

    if not isinstance(args.ms, list):
        args.ms = [args.ms]

    pyscilog.log_to_file(args.outfile + '.log')
    pyscilog.enable_memory_logging(level=3)

    GD = vars(args)
    print('Input Options:', file=log)
    for key in GD.keys():
        print(' %25s = %s' % (key, GD[key]), file=log)

    # get max uv coords and channel frequencies over all fields
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',), chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            spw = spws[ds.DATA_DESC_ID]
            all_freqs.append(np.atleast_1d(spw.CHAN_FREQ.data.squeeze()))

    # The Nyquist cell is set by the largest |u| or |v|; compute both in one
    # pass (the old code built max(u, v) but only computed u_max).
    u_max, v_max = dask.compute(u_max, v_max)
    uv_max = max(u_max, v_max)

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    # concatenate first so SPWs with different channel counts are supported
    freq = np.unique(np.concatenate(all_freqs))
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            # pixels coarser than the Nyquist cell under-sample the uv-plane
            raise ValueError("Requested cell size too large. "
                             "Super resolution factor = %f"
                             % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=dest)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size, file=dest)

    if args.nx is None or args.ny is None:
        from ducc0.fft import good_size
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = good_size(npix)
        args.ny = good_size(npix)

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny),
          file=dest)

    # mask
    if args.mask is not None:
        mask_array = load_fits(args.mask, dtype=args.real_type).squeeze()
        if mask_array.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape.")
        # add freq axis
        mask_array = mask_array[None]

        def mask(x):
            return mask_array * x
    else:
        mask_array = None

        def mask(x):
            return x

    # init gridder
    R = Gridder(
        args.ms,
        args.nx,
        args.ny,
        args.cell_size,
        nband=args.nband,
        nthreads=args.nthreads,
        do_wstacking=args.do_wstacking,
        row_chunks=args.row_chunks,
        psf_oversize=args.psf_oversize,
        data_column=args.data_column,
        epsilon=args.epsilon,
        weight_column=args.weight_column,
        imaging_weight_column=args.imaging_weight_column,
        model_column=args.model_column,
        flag_column=args.flag_column,
        weighting=args.weighting,
        robust=args.robust,
        # assumes gridding accounts for 80% memory
        mem_limit=int(0.8 * args.mem_limit))
    freq_out = R.freq_out
    radec = R.radec

    print("PSF size set to (%i, %i, %i)" % (args.nband, R.nx_psf, R.ny_psf),
          file=dest)

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                  args.nx, args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      args.nx, args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      R.nx_psf, R.ny_psf, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          R.nx_psf, R.ny_psf, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        # A header mismatch or unreadable file is a hard error. The previous
        # try/except re-raised unconditionally, so its "remake the PSF"
        # fallback was unreachable.
        compare_headers(hdr_psf, fits.getheader(args.psf))
        psf = load_fits(args.psf, dtype=args.real_type).squeeze()
    else:
        psf = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf, hdr_psf)

    # Normalising by wsum (so that the PSF always sums to 1) results in the
    # most intuitive sig_21 values and by far the least bookkeeping.
    # However, we won't save the cubes that way as it destroys information
    # about the noise in image space. Note only the MFS images will have the
    # usual units of Jy/beam.
    wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf = np.zeros(psf.shape, dtype=args.real_type)
    lpsf = np.arange(-R.nx_psf / 2, R.nx_psf / 2)
    mpsf = np.arange(-R.ny_psf / 2, R.ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')
    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)
    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars
    # beams were fitted with pixsize=1.0; scale the first two parameters by
    # the cell size (arcsec -> deg) before writing them to the headers
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)
    hdr_psf_mfs = add_beampars(hdr_psf_mfs, GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    GaussPars = list(GaussPars)
    for b in range(args.nband):
        GaussPars[b] = list(GaussPars[b])
        GaussPars[b][0] *= args.cell_size / 3600
        GaussPars[b][1] *= args.cell_size / 3600
        GaussPars[b] = tuple(GaussPars[b])
    GaussPars = tuple(GaussPars)
    hdr_psf = add_beampars(hdr_psf, GaussPar, GaussPars)

    save_fits(args.outfile + '_cpsf.fits', cpsf, hdr_psf)

    # dirty
    if args.dirty is not None:
        # as for the PSF, a mismatch is a hard error
        compare_headers(hdr, fits.getheader(args.dirty))
        dirty = load_fits(args.dirty).squeeze()
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)
    # (a leftover debugging quit() here used to skip everything below)

    # initial model and residual
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=args.real_type).squeeze()
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual,
                                         dtype=args.real_type).squeeze()
                except Exception:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits',
                              residual, hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
            residual /= wsum
        except Exception as e:
            # best effort: fall back to an empty model, but say so
            print("Could not use initial model %s (%s), starting from zero"
                  % (args.x0, e), file=log)
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    residual_mfs = np.sum(residual, axis=0)
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # smooth beam
    if args.beam_model is not None:
        if args.beam_model[-5:] == '.fits':
            beam_image = load_fits(args.beam_model,
                                   dtype=args.real_type).squeeze()
            if beam_image.shape != (args.nband, args.nx, args.ny):
                raise ValueError("Beam has incorrect shape")

        elif args.beam_model == "JimBeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            else:
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            beam_image = np.zeros((args.nband, args.nx, args.ny),
                                  dtype=args.real_type)

            l_coord, ref_l = data_from_header(hdr, axis=1)
            l_coord -= ref_l
            m_coord, ref_m = data_from_header(hdr, axis=2)
            m_coord -= ref_m
            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

            for v in range(args.nband):
                # katbeam's JimBeam.I takes the frequency in MHz
                beam_image[v] = beam.I(xx, yy, freq_out[v] / 1e6)
        else:
            # previously fell through and failed later with a NameError
            raise ValueError("Unknown beam model %s" % args.beam_model)

        def beam(x):
            return beam_image * x
    else:
        beam_image = None

        def beam(x):
            return x

    if args.init_nnls:
        print("Initialising with NNLS", file=log)
        model = nnls(psf, model, residual, mask=mask_array,
                     beam_image=beam_image, hdr=hdr, hdr_mfs=hdr_mfs,
                     outfile=args.outfile, maxit=1, nthreads=args.nthreads)

        residual = R.make_residual(beam(mask(model))) / wsum
        residual_mfs = np.sum(residual, axis=0)

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    redo_dirty = False
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms),
          file=dest)
    for i in range(0, args.maxit):
        # run minor cycle of choice
        modelp = model.copy()
        if args.deconv_mode == 'sara':
            model = sara(psf, model, residual, mask=mask_array,
                         beam_image=beam_image, hessian=R.convolve, wsum=wsum,
                         adapt_sig21=args.adapt_sig21, hdr=hdr,
                         hdr_mfs=hdr_mfs, outfile=args.outfile, cpsf=cpsf,
                         nthreads=args.nthreads, sig_21=args.sig_21,
                         sigma_frac=args.sigma_frac, maxit=args.minormaxit,
                         tol=args.minortol, gamma=args.gamma,
                         psi_levels=args.psi_levels, psi_basis=args.psi_basis,
                         pdtol=args.pdtol, pdmaxit=args.pdmaxit,
                         pdverbose=args.pdverbose, positivity=args.positivity,
                         cgtol=args.cgtol, cgminit=args.cgminit,
                         cgmaxit=args.cgmaxit, cgverbose=args.cgverbose,
                         pmtol=args.pmtol, pmmaxit=args.pmmaxit,
                         pmverbose=args.pmverbose)
        elif args.deconv_mode == 'clean':
            model = clean(psf, model, residual, mask=mask_array,
                          beam=beam_image, nthreads=args.nthreads,
                          maxit=args.minormaxit, gamma=args.gamma,
                          peak_factor=args.peak_factor,
                          threshold=args.threshold, hbgamma=args.hbgamma,
                          hbpf=args.hbpf, hbmaxit=args.hbmaxit,
                          hbverbose=args.hbverbose)
        elif args.deconv_mode == 'spotless':
            model = spotless(psf, model, residual, mask=mask_array,
                             beam_image=beam_image, hessian=R.convolve,
                             wsum=wsum, adapt_sig21=args.adapt_sig21,
                             cpsf=cpsf_mfs, hdr=hdr, hdr_mfs=hdr_mfs,
                             outfile=args.outfile, sig_21=args.sig_21,
                             sigma_frac=args.sigma_frac,
                             nthreads=args.nthreads, gamma=args.gamma,
                             peak_factor=args.peak_factor,
                             maxit=args.minormaxit, tol=args.minortol,
                             threshold=args.threshold,
                             positivity=args.positivity,
                             hbgamma=args.hbgamma, hbpf=args.hbpf,
                             hbmaxit=args.hbmaxit, hbverbose=args.hbverbose,
                             pdtol=args.pdtol, pdmaxit=args.pdmaxit,
                             pdverbose=args.pdverbose, cgtol=args.cgtol,
                             cgminit=args.cgminit, cgmaxit=args.cgmaxit,
                             cgverbose=args.cgverbose, pmtol=args.pmtol,
                             pmmaxit=args.pmmaxit, pmverbose=args.pmverbose)
        else:
            raise ValueError("Unknown deconvolution mode %s"
                             % args.deconv_mode)

        # get residual
        if redo_dirty:
            # Need to do this if weights or Jones has changed
            # (eg. if we change robustness factor, reweight or calibrate)
            psf = R.make_psf()
            wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf),
                            axis=1)
            wsum = np.sum(wsums)
            psf /= wsum
            dirty = R.make_dirty() / wsum

        # compute in image space
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        residual_mfs = np.sum(residual, axis=0)

        # save current iteration
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_major' + str(i + 1) + '_model_mfs.fits',
                  model_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_model.fits',
                  model, hdr)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual_mfs.fits',
                  residual_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual.fits',
                  residual * wsum, hdr)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("At iteration %i peak of residual is %f, rms is %f, current "
              "eps is %f" % (i + 1, rmax, rms, eps), file=dest)

        if eps < args.tol:
            break

    if args.mop_flux:
        print("Mopping flux", file=dest)

        # vague Gaussian prior on x
        def hess(x):
            return mask(beam(R.convolve(mask(beam(x))))) / wsum + 1e-6 * x

        def M(x):
            return x / 1e-6  # preconditioner

        x = pcg(hess, mask(beam(residual)),
                np.zeros(residual.shape, dtype=residual.dtype), M=M,
                tol=0.1 * args.cgtol, maxit=args.cgmaxit, minit=args.cgminit,
                verbosity=args.cgverbose)

        model += x
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        save_fits(args.outfile + '_mopped_model.fits', model, hdr)
        save_fits(args.outfile + '_mopped_residual.fits', residual, hdr)
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_mopped_model_mfs.fits', model_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
        save_fits(args.outfile + '_mopped_residual_mfs.fits', residual_mfs,
                  hdr_mfs)

        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After mopping flux peak of residual is %f, rms is %f"
              % (rmax, rms), file=dest)

    if args.write_model:
        print("Writing model", file=dest)
        R.write_model(model)

    if args.make_restored:
        print("Making restored", file=dest)
        cpsfo = PSF(cpsf, residual.shape, nthreads=args.nthreads)
        restored = cpsfo.convolve(model)

        # residual needs to be in Jy/beam before adding to convolved model
        wsums = np.amax(psf.reshape(-1, R.nx_psf * R.ny_psf), axis=1)
        restored += residual / wsums[:, None, None]

        save_fits(args.outfile + '_restored.fits', restored, hdr)
        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
def _subtable_value(datasets, row, column):
    """Return ``column`` for ``row`` of a computed MS subtable.

    ``datasets`` is the (computed) list returned by ``xds_from_table``. The
    subtable may have been read as a single dataset (rows along the first
    axis) or as one dataset per row (``group_cols="__row__"``). Both layouts
    are handled; in the per-row layout the leading length-1 row axis is kept.
    """
    if len(datasets) == 1:
        return getattr(datasets[0], column).data[row]
    return getattr(datasets[row], column).data


def _band_map(freq, fbins):
    """Assign each channel frequency in ``freq`` to a band index.

    Band ``b`` covers ``[fbins[b], fbins[b + 1] + 1e-6)``. Because later
    bands overwrite earlier ones, a channel lying exactly on an internal edge
    goes to the upper band, and the final edge is inclusive.
    """
    band_map = np.zeros(freq.size, dtype=np.int32)
    for band in range(fbins.size - 1):
        indl = freq >= fbins[band]
        indu = freq < fbins[band + 1] + 1e-6
        band_map = np.where(indl & indu, band, band_map)
    return band_map


def _frequency_bands(ufreqs, nband):
    """Split sorted unique frequencies into ``nband`` equal-width bands.

    Returns ``(fbins, freq_out)``: the ``nband + 1`` band edges and the mean
    of the frequencies that ``_band_map`` assigns to each band. Using the
    same assignment for both keeps ``freq_out`` consistent with the channel
    to band mapping. An empty band yields NaN.
    """
    fbins = np.linspace(ufreqs[0], ufreqs[-1], nband + 1)
    band_map = _band_map(ufreqs, fbins)
    freq_out = np.zeros(nband)
    for band in range(nband):
        freq_out[band] = np.mean(ufreqs[band_map == band])
    return fbins, freq_out


def chan_to_band_mapping(ms_name, nband=None):
    '''
    Construct dictionaries containing per MS and SPW channel to band mapping.
    Only datasets sharing the phase centre of the first dataset encountered
    are used.

    Input:
    ms_name     - ms name or list/tuple of ms names
    nband       - number of imaging bands (default: one per unique channel)

    Output:
    freqs           - dict[MS][DDID] dask arrays of channel frequencies,
                      chunked by imaging band
    freq_bin_idx    - dict[MS][DDID] chunked dask arrays of bin starting indices
    freq_bin_counts - dict[MS][DDID] chunked dask arrays of counts in each bin
    freq_out        - mean frequency of the channels in each band
                      (LB - should a weighted sum rather be computed?)
    band_mapping    - dict[MS][DDID] identifying imaging bands going into degridder
    chan_chunks     - dict[MS] list of dask chunking schemes over channel

    Raises:
    ValueError      - if no frequency information is found
    '''
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask
    import dask.array as da
    from omegaconf import ListConfig

    if not isinstance(ms_name, (list, tuple, ListConfig)):
        ms_name = [ms_name]

    # first pass through data to determine freq_mapping
    radec = None
    freqs = {}
    all_freqs = []
    for ims in ms_name:
        xds = xds_from_ms(ims, chunks={"row": -1}, columns=('TIME', ))

        # subtables needed to resolve phase centres and channel frequencies
        ddids, fields, spws_table = dask.compute(
            xds_from_table(ims + "::DATA_DESCRIPTION"),
            xds_from_table(ims + "::FIELD"),
            xds_from_table(ims + "::SPECTRAL_WINDOW"))

        freqs[ims] = {}
        for ds in xds:
            phase_dir = np.asarray(
                _subtable_value(fields, ds.FIELD_ID, "PHASE_DIR")).squeeze()

            # check fields match
            if radec is None:
                radec = phase_dir
            if not np.array_equal(radec, phase_dir):
                continue

            # DATA_DESC_ID indexes DATA_DESCRIPTION, whose SPECTRAL_WINDOW_ID
            # points at the SPECTRAL_WINDOW row holding the frequencies
            spw_id = np.asarray(_subtable_value(
                ddids, ds.DATA_DESC_ID, "SPECTRAL_WINDOW_ID")).item()
            tmp_freq = np.asarray(
                _subtable_value(spws_table, spw_id, "CHAN_FREQ")).squeeze()
            freqs[ims][ds.DATA_DESC_ID] = tmp_freq
            all_freqs.append(np.atleast_1d(tmp_freq))

    if not all_freqs:
        raise ValueError("No frequency information found in %s"
                         % list(ms_name))

    # freq mapping (concatenate so SPWs may have different channel counts)
    ufreqs = np.unique(np.concatenate(all_freqs))  # sorted ascending
    if nband is None:
        nband = ufreqs.size

    fbins, freq_out = _frequency_bands(ufreqs, nband)

    # chan <-> band mapping
    band_mapping = {}
    chan_chunks = {}
    freq_bin_idx = {}
    freq_bin_counts = {}
    for ims in freqs:
        freq_bin_idx[ims] = {}
        freq_bin_counts[ims] = {}
        band_mapping[ims] = {}
        chan_chunks[ims] = []
        for spw in freqs[ims]:
            freq = np.atleast_1d(freqs[ims][spw])
            band_map = _band_map(freq, fbins)

            # to dask arrays
            # NOTE: chunking by bin counts assumes channels are ordered by
            # frequency within an SPW so each band is a contiguous run
            bands, bin_counts = np.unique(band_map, return_counts=True)
            band_mapping[ims][spw] = tuple(bands)
            chan_chunks[ims].append({'chan': tuple(bin_counts)})
            freqs[ims][spw] = da.from_array(freq, chunks=tuple(bin_counts))
            bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
            freq_bin_idx[ims][spw] = da.from_array(bin_idx, chunks=1)
            freq_bin_counts[ims][spw] = da.from_array(bin_counts, chunks=1)

    return freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks
def __init__(self, ms_name, nx, ny, cell_size, nband=None, nthreads=8,
             do_wstacking=1, Stokes='I', row_chunks=-1, chan_chunks=32,
             optimise_chunks=True, epsilon=1e-5, psf_oversize=2.0,
             weighting=None, robust=None, data_column='CORRECTED_DATA',
             weight_column='WEIGHT_SPECTRUM', mueller_column=None,
             model_column="MODEL_DATA", flag_column='FLAG',
             imaging_weight_column=None, real_type='f4', cdir=None,
             mem_limit=None):
    '''
    Set up a Stokes I gridder over one or more measurement sets.

    Reads the FIELD, DATA_DESCRIPTION and SPECTRAL_WINDOW subtables, keeps
    the datasets sharing the phase centre of the first dataset encountered,
    maps their channels onto ``nband`` equal-width imaging bands and checks
    that the required columns exist. If ``weighting`` is given, imaging
    weights are computed via ``self.compute_weights(robust)``.

    cell_size is in arcseconds (converted to radians in ``self.cell``). The
    PSF grid is ``psf_oversize`` times the image size, rounded up to an even
    number of pixels.

    Raises NotImplementedError for Stokes other than 'I' and ValueError if a
    required column is missing from any measurement set.

    TODO - currently row_chunks and chan_chunks are only used for the
    compute_weights() and write_component_model() methods. All other methods
    assume that the data for a single imaging band per ms and spw fit into
    memory.
    The optimise_chunks argument is a promise to improve this in the future.

    TODO - current IO can probably be massively reduced if we optimize for
    specific Stokes outputs and we optimise the chunking strategy.
    In particular, we can write out the weights for Stokes I imaging in
    advance and then only load precomputed scalar weights in the convolve
    function. Since we currently load in weights, imaging weights and a
    complex "Mueller" term for all 4 correlations, we can in principle
    reduce IO and memory footprint by about a factor of 16.

    # of GB for 8 hr 8 sec 32k observation
    64*(64-1) //2 * 8 * 60 * 60 // 8 * 2**15 * 4 * 8 / 1e9 = 7610 GB
    # of GB for 8 hr 8 sec 4k observation
    64*(64-1) //2 * 8 * 60 * 60 // 8 * 2**12 * 4 * 8 / 1e9 = 951 GB
    '''
    if Stokes != 'I':
        raise NotImplementedError("Only Stokes I currently supported")
    self.nx = nx
    self.ny = ny
    # arcseconds -> radians
    self.cell = cell_size * np.pi / 60 / 60 / 180
    self.nthreads = nthreads
    self.do_wstacking = do_wstacking
    self.epsilon = epsilon
    self.row_chunks = row_chunks
    self.chan_chunks = chan_chunks
    self.psf_oversize = psf_oversize
    self.nx_psf = int(self.psf_oversize * self.nx)
    self.nx_psf += self.nx_psf % 2
    self.ny_psf = int(self.psf_oversize * self.ny)
    self.ny_psf += self.ny_psf % 2
    self.real_type = real_type

    if isinstance(ms_name, list):
        self.ms = ms_name
    else:
        self.ms = [ms_name]

    # first pass through data to determine freq_mapping
    self.radec = None
    self.freq = {}
    self.freq_np = {}
    all_freqs = []
    self.spws = {}
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={"row": -1}, columns=('TIME',))

        # subtables (FIELD and SPECTRAL_WINDOW as one dataset per row)
        ddids, fields, spws = dask.compute(
            xds_from_table(ims + "::DATA_DESCRIPTION"),
            xds_from_table(ims + "::FIELD", group_cols="__row__"),
            xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__"))
        # DATA_DESC_ID -> SPECTRAL_WINDOW row (the two need not coincide)
        ddid_to_spw = ddids[0].SPECTRAL_WINDOW_ID.data

        self.freq[ims] = {}
        self.freq_np[ims] = {}
        self.spws[ims] = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if self.radec is None:
                self.radec = radec

            if not np.array_equal(radec, self.radec):
                continue

            spw = spws[ddid_to_spw[ds.DATA_DESC_ID]]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            # frequency dicts stay keyed by DATA_DESC_ID, as in the
            # make_* methods
            self.freq[ims][ds.DATA_DESC_ID] = tmp_freq
            self.freq_np[ims][ds.DATA_DESC_ID] = dask.compute(tmp_freq)[0]
            all_freqs.append(np.atleast_1d(self.freq_np[ims][ds.DATA_DESC_ID]))
            self.spws[ims].append(ds.DATA_DESC_ID)

    self.data_column = data_column
    self.weight_column = weight_column
    self.model_column = model_column
    self.flag_column = flag_column

    self.columns = (self.data_column, self.weight_column,
                    self.flag_column, 'UVW')

    # TODO - write jones2col if column does not exist
    self.mueller_column = mueller_column
    if mueller_column is not None:
        self.columns += (self.mueller_column, )

    # check that all measurement sets contain the required columns
    for ims in self.ms:
        xds = xds_from_ms(ims)
        for ds in xds:
            for column in self.columns:
                if not hasattr(ds, column):
                    raise ValueError("No column named %s in %s"
                                     % (column, ims))

    # freq mapping (concatenate so SPWs may differ in channel count)
    ufreqs = np.unique(np.concatenate(all_freqs))  # sorted ascending
    self.nchan = ufreqs.size
    if nband is None:
        self.nband = self.nchan
    else:
        self.nband = nband

    # bin edges
    fmin = ufreqs[0]
    fmax = ufreqs[-1]
    fbins = np.linspace(fmin, fmax, self.nband + 1)

    def assign_bands(freq):
        # band b covers [fbins[b], fbins[b+1] + 1e-6); later bands overwrite
        # earlier ones, so a channel on an internal edge goes to the upper
        # band and the final edge is inclusive
        band_map = np.zeros(freq.size, dtype=np.int32)
        for band in range(self.nband):
            indl = freq >= fbins[band]
            indu = freq < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)
        return band_map

    # band centres use the same channel assignment as the band mapping below
    ufreq_bands = assign_bands(ufreqs)
    self.freq_out = np.zeros(self.nband)
    for band in range(self.nband):
        self.freq_out[band] = np.mean(ufreqs[ufreq_bands == band])

    # chan <-> band mapping
    self.band_mapping = {}
    self.chunks = {}
    self.freq_bin_idx = {}
    self.freq_bin_counts = {}
    self.freq_bin_idx_np = {}
    self.freq_bin_counts_np = {}
    for ims in self.freq:
        self.freq_bin_idx[ims] = {}
        self.freq_bin_counts[ims] = {}
        self.freq_bin_idx_np[ims] = {}
        self.freq_bin_counts_np[ims] = {}
        self.band_mapping[ims] = {}
        # NOTE(review): one entry per retained DATA_DESC_ID, later passed as
        # chunks to xds_from_ms grouped by (FIELD_ID, DATA_DESC_ID); this
        # presumably only lines up when every group is retained and each
        # DDID appears once - confirm.
        self.chunks[ims] = []
        for spw in self.freq[ims]:
            freq = np.atleast_1d(self.freq_np[ims][spw])
            band_map = assign_bands(freq)
            # to dask arrays
            bands, bin_counts = np.unique(band_map, return_counts=True)
            self.band_mapping[ims][spw] = tuple(bands)
            self.chunks[ims].append({'row': -1, 'chan': tuple(bin_counts)})
            self.freq[ims][spw] = da.from_array(freq, chunks=tuple(bin_counts))
            bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
            self.freq_bin_idx[ims][spw] = da.from_array(bin_idx, chunks=1)
            self.freq_bin_counts[ims][spw] = da.from_array(bin_counts, chunks=1)
            self.freq_bin_idx_np[ims][spw] = bin_idx
            self.freq_bin_counts_np[ims][spw] = bin_counts

    # compute imaging weights
    if weighting is not None:
        if imaging_weight_column is None:
            self.imaging_weight_column = "IMAGING_WEIGHT_SPECTRUM"
        else:
            # this column is always created if asked
            self.imaging_weight_column = imaging_weight_column
        print("Computing weights", file=log)
        self.compute_weights(robust)
        self.columns += (self.imaging_weight_column, )
    else:
        self.imaging_weight_column = None
def __init__(self, msname=None, log=None):
    """Collect metadata for measurement set ``msname``.

    Reads the main table and the SPECTRAL_WINDOW, FIELD, ANTENNA and
    POLARIZATION subtables. It populates spectral windows, fields, scans,
    antennas, baseline numbering and the correlation/Stokes extraction
    mappers. With a falsy ``msname`` the instance is left uninitialised.
    All casacore tables are closed after reading.

    Parameters
    ----------
    msname : str, optional
        Path to the measurement set.
    log : logger-like, optional
        If given, summary lines are emitted through ``log.info``.
    """
    if not msname:
        return

    self.msname = msname
    self.log = log

    # read what is needed from the main table, then release it
    with table(msname, ack=False) as tab:
        log and log.info(f": MS {msname} contains {tab.nrows()} rows")
        self.valid_columns = set(tab.colnames())
        scan_numbers = sorted(set(tab.getcol("SCAN_NUMBER")))
        used_antennas = list(set(tab.getcol("ANTENNA1")) | set(tab.getcol("ANTENNA2")))

    spw_tab = daskms.xds_from_table(msname + '::SPECTRAL_WINDOW', columns=['CHAN_FREQ'])
    self.chan_freqs = spw_tab[0].CHAN_FREQ  # important for this to be an xarray
    self.nspw = self.chan_freqs.shape[0]
    self.spw = NamedList("spw", list(map(str, range(self.nspw))))
    log and log.info(
        f": {self.chan_freqs.shape} spectral windows and channels")

    with table(msname + '::FIELD', ack=False) as field_tab:
        self.field = NamedList("field", field_tab.getcol("NAME"))
    log and log.info(
        f": {len(self.field)} fields: {' '.join(self.field.names)}")

    log and log.info(
        f": {len(scan_numbers)} scans, first #{scan_numbers[0]}, last #{scan_numbers[-1]}"
    )
    all_scans = NamedList("scan", list(map(str, range(scan_numbers[-1] + 1))))
    self.scan = all_scans.get_subset(scan_numbers)

    with table(msname + '::ANTENNA', ack=False) as ant_tab:
        self.all_antenna = NamedList("antenna", ant_tab.getcol("NAME"))
    self.antenna = self.all_antenna.get_subset(used_antennas)

    # both (p, q) and (q, p) map to the same baseline index
    baselines = [(p, q) for p in self.antenna.numbers
                 for q in self.antenna.numbers if p <= q]
    self.baseline_numbering = {(p, q): i for i, (p, q) in enumerate(baselines)}
    self.baseline_numbering.update({(q, p): i for i, (p, q) in enumerate(baselines)})
    log and log.info(
        f": {len(self.antenna)} antennas: {self.antenna.str_list()}")

    with table(msname + '::POLARIZATION', ack=False) as pol_tab:
        corr_types = pol_tab.getcol("CORR_TYPE", 0, 1).ravel()
    all_corr_labels = [STOKES_TYPES[icorr] for icorr in corr_types]
    self.corr = NamedList("correlation", all_corr_labels.copy())

    # Maps correlation -> callable that extracts that correlation from visibility data
    # By default, populated with slicing functions for 0...3,
    # but can also be extended with "I", "Q", etc.
    self.corr_data_mappers = OrderedDict({
        i: lambda x, icorr=i: x[..., icorr]
        for i in range(len(all_corr_labels))
    })
    # Maps correlation -> callable that extracts that correlation from flag data
    self.corr_flag_mappers = self.corr_data_mappers.copy()

    # add mappings and labels for Stokes parameters
    xx, xy, yx, yy = [
        self.corr.map.get(c) for c in ("XX", "XY", "YX", "YY")
    ]
    rr, rl, lr, ll = [
        self.corr.map.get(c) for c in ("RR", "RL", "LR", "LL")
    ]

    def add_stokes(a, b, I, J, imag=False):
        """Adds mappers for Stokes A and B as the sum/difference of
        components I and J, divided by 2 (or 2j for the difference when
        ``imag``). Labels that already exist are not added again."""
        def _sum(x):
            return (x[..., I] + x[..., J]) / 2

        def _diff(x):
            return (x[..., I] - x[..., J]) / (2j if imag else 2)

        def _or(x):
            return (x[..., I] | x[..., J])

        # De-duplicate against the label list: the mapper dicts are keyed
        # by integer index, so testing a label against them never matched.
        if a not in all_corr_labels:
            self.corr_data_mappers[len(all_corr_labels)] = _sum
            self.corr_flag_mappers[len(all_corr_labels)] = _or
            all_corr_labels.append(a)
        if b not in all_corr_labels:
            self.corr_data_mappers[len(all_corr_labels)] = _diff
            self.corr_flag_mappers[len(all_corr_labels)] = _or
            all_corr_labels.append(b)

    if xx is not None and yy is not None:
        add_stokes("I", "Q", xx, yy)
    if rr is not None and ll is not None:
        add_stokes("I", "V", rr, ll)
    if xy is not None and yx is not None:
        add_stokes("U", "V", xy, yx, True)
    if rl is not None and lr is not None:
        add_stokes("Q", "U", rl, lr, True)

    self.all_corr = NamedList("correlation", all_corr_labels)
    log and log.info(f": corrs/Stokes {' '.join(self.all_corr.names)}")
def make_residual(self, x):
    """Make the Stokes I residual image for model cube ``x``.

    Note: deprecated - Jones/Mueller terms are not applied.

    For every dataset of every MS whose phase centre matches ``self.radec``,
    the XX/YY (first/last correlation) data are combined into weighted
    Stokes I visibilities. These are passed with the per-band slice of the
    model to africanus' wgridder ``residual`` (``im2residim``). Results are
    combined per band with ``accumulate_dirty``.

    Parameters
    ----------
    x : numpy.ndarray
        Model cube; assumed (nband, nx, ny) given the chunking below -
        confirm against callers.

    Returns
    -------
    numpy.ndarray
        Accumulated residual cube cast to ``self.real_type``.
    """
    # Note deprecated (does not support Jones terms)
    print("Making residual", file=log)
    x = da.from_array(x.astype(self.real_type),
                      chunks=(1, self.nx, self.ny), name=False)

    residuals = []
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims], columns=self.columns)

        # Only FIELD is needed (to select on phase centre); the other
        # subtables used to be read and computed here on every call.
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        fields = dask.compute(fields)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data

            data = getattr(ds, self.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, self.weight_column).data
            if len(weights.shape) < 3:
                # WEIGHT (row, corr) -> broadcast over channels
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape, chunks=data.chunks)

            if self.imaging_weight_column is not None:
                imaging_weights = getattr(ds, self.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :],
                        data.shape, chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            data = da.where(weights, data / weights, 0.0j)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, self.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            flag = ~(flagxx | flagyy)  # ducc0 convention

            bands = self.band_mapping[ims][spw]
            model = x[list(bands), :, :]
            residual = im2residim(uvw, freq, model, data, freq_bin_idx,
                                  freq_bin_counts, self.cell, weights=weights,
                                  flag=flag.astype(np.uint8),
                                  nthreads=self.nthreads,
                                  epsilon=self.epsilon,
                                  do_wstacking=self.do_wstacking,
                                  double_accum=True)

            residuals.append(residual)

    residuals = dask.compute(residuals)[0]

    return accumulate_dirty(residuals, self.nband,
                            self.band_mapping).astype(self.real_type)
in [('antenna', "ANTENNA"), ('ddid', "DATA_DESCRIPTION"), ('spw', "SPECTRAL_WINDOW"), ('pol', "POLARIZATION"), ('field', "FIELD")]} with scheduler_context(args): # Get datasets from the main MS # partition by FIELD_ID and DATA_DESC_ID # and sorted by TIME datasets = xds_from_ms(args.ms, group_cols=("FIELD_ID", "DATA_DESC_ID"), index_cols="TIME") # Get the antenna dataset ant_ds = list(xds_from_table(table_name['antenna'])) assert len(ant_ds) == 1 ant_ds = ant_ds[0].rename({'row': 'antenna'}) # Get datasets for DATA_DESCRIPTION, SPECTRAL_WINDOW # POLARIZATION and FIELD, partitioned by row ddid_ds = list(xds_from_table(table_name['ddid'], group_cols="__row__")) spwds = list(xds_from_table(table_name['spw'], group_cols="__row__")) pds = list(xds_from_table(table_name['pol'], group_cols="__row__")) field_ds = list(xds_from_table(table_name['field'], group_cols="__row__")) # For each partitioned dataset from the main MS,
def make_psf(self):
    """Compute the Stokes I point spread function cube.

    For every measurement set in ``self.ms`` the main table is partitioned
    by FIELD_ID and DATA_DESC_ID, and partitions whose PHASE_DIR differs from
    ``self.radec`` are skipped.  For each remaining partition:

    - The weights of the first and last correlations are summed into
      Stokes I weights.  They are first multiplied by the imaging weights
      and the squared Mueller amplitudes when those columns are configured.
    - The summed weights are masked by the flags and cast to complex64.
    - They are gridded with ``vis2im`` onto a ``self.nx_psf`` x
      ``self.ny_psf`` image.

    The per-partition images are computed with dask and combined by
    ``accumulate_dirty``.

    Side effects: rebinds ``self.stokes_weights`` and ``self.uvws`` to new
    dicts holding an empty dict per measurement set.

    Returns
    -------
    numpy.ndarray
        The output of ``accumulate_dirty`` cast to ``self.real_type``.
    """
    print("Making PSF", file=log)
    psfs = []
    # NOTE(review): these per-MS dicts are created but never filled in this
    # method (the code that populated them was commented out).  They are
    # kept in case other code reads them.
    self.stokes_weights = {}
    self.uvws = {}
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims], columns=self.columns)

        # Only the FIELD subtable (phase centres) is used below.  The
        # DATA_DESCRIPTION, SPECTRAL_WINDOW and POLARIZATION subtables used
        # to be read and computed here as well, but were never used.
        fields = dask.compute(
            xds_from_table(ims + "::FIELD", group_cols="__row__"))[0]

        self.stokes_weights[ims] = {}
        self.uvws[ims] = {}
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            # TODO: this is not correct, need to use spw.  DATA_DESC_ID is
            # used directly as the spectral window key; presumably it should
            # be mapped through DATA_DESCRIPTION.SPECTRAL_WINDOW_ID.
            spw = ds.DATA_DESC_ID
            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data
            flag = getattr(ds, self.flag_column).data

            # Weights with fewer than 3 dims get an axis inserted at
            # position 1 and are broadcast to the flag shape.
            weights = getattr(ds, self.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :], flag.shape,
                                          chunks=flag.chunks)
            if self.imaging_weight_column is not None:
                imaging_weights = getattr(
                    ds, self.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :], flag.shape,
                        chunks=flag.chunks)
                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # for the PSF we need to scale the weights by the
            # Mueller amplitudes squared
            if self.mueller_column is not None:
                mueller = getattr(ds, self.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # only keep data where both corrs are unflagged
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 convention: the mask passed on is True where the sample
            # should be used
            flag = ~(flagxx | flagyy)
            weights *= flag

            # the masked Stokes I weights are gridded as the visibilities
            data = weights.astype(np.complex64)
            psf = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts,
                         self.nx_psf, self.ny_psf, self.cell,
                         flag=flag.astype(np.uint8),
                         nthreads=self.nthreads, epsilon=self.epsilon,
                         do_wstacking=self.do_wstacking,
                         double_accum=True)
            psfs.append(psf)

    # NOTE(review): the single-threaded dask scheduler is forced here;
    # confirm whether this is intentional.
    psfs = dask.compute(psfs, scheduler='single-threaded')[0]
    return accumulate_dirty(psfs, self.nband,
                            self.band_mapping).astype(self.real_type)
def full_ms_data(ms_name):
    """Return the first dataset that ``xds_from_table`` yields for ``ms_name``.

    Intended as a helper for pytest.  No grouping arguments are passed, and
    no ``dask.compute`` is applied here.  Any lazily-backed variables
    therefore stay lazy until the caller computes them.  TODO: confirm
    whether callers expect eager, in-memory data.
    """
    table_datasets = xds_from_table(ms_name)
    first_dataset = table_datasets[0]
    return first_dataset