# Imports assumed by the excerpts below (not part of the original snippets);
# project-local helpers such as select_field_id, valid_field_ids,
# chan_metadata, dask_chan_avg, id_full_like and concatenate_row_chunks
# are expected to be in scope.
from itertools import product

import dask
import dask.array as da
import numpy as np
import pytest
import casacore.tables as pt
from numpy.testing import assert_array_equal

from daskms import Dataset, xds_from_ms, xds_from_table, xds_to_table


def test_select_field():
    datasets = []

    for field_name in ["DEEP2"]:
        name = np.asarray([field_name], dtype=object)
        ds = Dataset({"NAME": (("row",), da.from_array(name, chunks=1))})
        datasets.append(ds)

    # No field selection, single field id returned
    assert select_field_id(datasets, None) == 0

    datasets = []

    for field_name in ["PKS-1934", "3C286", "DEEP2"]:
        name = np.asarray([field_name], dtype=object)
        ds = Dataset({"NAME": (("row",), da.from_array(name, chunks=1))})
        datasets.append(ds)

    # No field selection over multiple fields, ValueError raised
    with pytest.raises(ValueError):
        assert select_field_id(datasets, None) == [0, 1, 2]

    # Fields are selectable by name or by index
    assert select_field_id(datasets, "PKS-1934") == 0
    assert select_field_id(datasets, "0") == 0
    assert select_field_id(datasets, "3C286") == 1
    assert select_field_id(datasets, "1") == 1
    assert select_field_id(datasets, "DEEP2") == 2
    assert select_field_id(datasets, "2") == 2
def test_ms_create_and_update(Dataset, tmp_path, chunks):
    """ Test that we can update and append at the same time """
    filename = str(tmp_path / "create-and-update.ms")

    rs = np.random.RandomState(42)

    # Create a dataset of 10 rows with DATA and DATA_DESC_ID
    dims = ("row", "chan", "corr")
    row, chan, corr = tuple(sum(chunks[d]) for d in dims)
    ms_datasets = []

    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 0, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row",), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    ms_datasets = xds_from_ms(filename)

    # Now add another dataset (different DDID), with no ROWID
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 1, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row",), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    # Rows have been added and additional data is present
    with pt.table(filename, ack=False, readonly=True) as T:
        first_data_desc_id = da.full(row,
                                     ms_datasets[0].DATA_DESC_ID,
                                     chunks=chunks['row'])

        ds_data = da.concatenate([ms_datasets[0].DATA.data,
                                  ms_datasets[1].DATA.data])
        ds_ddid = da.concatenate([first_data_desc_id,
                                  ms_datasets[1].DATA_DESC_ID.data])

        assert_array_equal(T.getcol("DATA"), ds_data)
        assert_array_equal(T.getcol("DATA_DESC_ID"), ds_ddid)
def average_spw(spw_ds, chan_bin_size):
    """
    Parameters
    ----------
    spw_ds : list of Datasets
        list of Datasets, each describing a single Spectral Window
    chan_bin_size : int
        Number of channels in an averaging bin

    Returns
    -------
    spw_ds : list of Datasets
        list of Datasets, each describing an averaged Spectral Window
    """
    new_spw_ds = []

    for r, spw in enumerate(spw_ds):
        # Get the dataset variables as a mutable dictionary
        dv = dict(spw.data_vars)

        # Extract arrays we wish to average
        chan_freq = dv['CHAN_FREQ'].data[0]
        chan_width = dv['CHAN_WIDTH'].data[0]
        effective_bw = dv['EFFECTIVE_BW'].data[0]
        resolution = dv['RESOLUTION'].data[0]

        # Construct channel metadata
        chan_arrays = (chan_freq, chan_width, effective_bw, resolution)
        chan_meta = chan_metadata((), chan_arrays, chan_bin_size)
        # Average channel based data
        avg = dask_chan_avg(chan_meta,
                            chan_freq=chan_freq,
                            chan_width=chan_width,
                            effective_bw=effective_bw,
                            resolution=resolution,
                            chan_bin_size=chan_bin_size)

        num_chan = da.full((1,), avg.chan_freq.shape[0], dtype=np.int32)

        # These columns change, re-create them
        dv['NUM_CHAN'] = (("row",), num_chan)
        dv['CHAN_FREQ'] = (("row", "chan"), avg.chan_freq[None, :])
        dv['CHAN_WIDTH'] = (("row", "chan"), avg.chan_width[None, :])
        dv['EFFECTIVE_BW'] = (("row", "chan"), avg.effective_bw[None, :])
        dv['RESOLUTION'] = (("row", "chan"), avg.resolution[None, :])

        # But re-use all the others
        new_spw_ds.append(Dataset(dv))

    return new_spw_ds
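# A minimal usage sketch for average_spw, assuming an existing Measurement
# Set at "test.ms" whose SPECTRAL_WINDOW subtable carries EFFECTIVE_BW and
# RESOLUTION columns; the paths and bin size are hypothetical.
def example_average_spw():
    # group_cols="__row__" yields one dataset per SPECTRAL_WINDOW row,
    # matching the per-SPW structure average_spw expects
    spw_ds = xds_from_table("test.ms::SPECTRAL_WINDOW", group_cols="__row__")
    # Halve the channel resolution by averaging pairs of channels
    avg_spw_ds = average_spw(spw_ds, chan_bin_size=2)
    # Writing back follows the usual xds_to_table pattern
    # (assuming averaged.ms already exists)
    dask.compute(xds_to_table(avg_spw_ds, "averaged.ms::SPECTRAL_WINDOW",
                              columns="ALL"))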
def test_select_fields():
    datasets = []

    for field_name in ["PKS-1934", "3C286", "DEEP2"]:
        name = np.asarray([field_name], dtype=object)
        ds = Dataset({"NAME": (("row",), da.from_array(name, chunks=1))})
        datasets.append(ds)

    # No field selection, all fields returned
    assert valid_field_ids(datasets, None) == [0, 1, 2]
    # Fields are selectable by name or index; out of range ids are dropped
    assert valid_field_ids(datasets, "PKS-1934") == [0]
    assert valid_field_ids(datasets, "3C286, DEEP2") == [1, 2]
    assert valid_field_ids(datasets, "1, DEEP2") == [1, 2]
    assert valid_field_ids(datasets, "0, 1, 2") == [0, 1, 2]
    assert valid_field_ids(datasets, "2, 3") == [2]
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, \
        chan_chunks = chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW',)
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column,)
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column,)
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError("Requested cell size too small. "
                             "Super resolution factor = %f"
                             % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of "
          "%i chunks per node"
          % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({'row': row_chunk,
                                'chan': chan_chunks[ims][spw]['chan']})

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :], data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :],
                        data_shape, chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data ==
                                           ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row',),
                             da.full_like(ds.TIME.data, ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row',),
                                 da.full_like(ds.TIME.data, ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT': (('row', 'chan'),
                           weights.rechunk({0: args.row_out_chunk})),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan',), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets, args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename +
    #                '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename +
    #                '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename +
        #                         '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits', psf, hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits', psf_mfs, hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
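# The image-size logic in _psf follows the Nyquist criterion: the smallest
# resolvable angular scale is 1/(2 * uv_max) wavelengths at the highest
# frequency. A self-contained sketch of that calculation, with illustrative
# uv_max/max_freq values (not taken from any real MS):
def example_cell_size(uv_max=5000.0, max_freq=1.5e9,
                      super_resolution_factor=2.0):
    lightspeed = 299792458.0  # m/s
    # Nyquist cell size in radians
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)
    # Oversample by the super resolution factor, then convert to arcseconds
    cell_rad = cell_N / super_resolution_factor
    cell_size = cell_rad * 60 * 60 * 180 / np.pi
    return cell_size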
def ms_create(ms_table_name, info, ant_pos, vis_array, baselines,
              timestamps, pol_feeds, sources):
    ''' Create a Measurement Set from some TART observations

    Parameters
    ----------

    ms_table_name : string
        The name of the MS top level directory. I think this only works
        in the local directory.

    info : JSON
        "info": {
            "info": {
                "L0_frequency": 1571328000.0,
                "bandwidth": 2500000.0,
                "baseband_frequency": 4092000.0,
                "location": {
                    "alt": 270.0,
                    "lat": -45.85177,
                    "lon": 170.5456
                },
                "name": "Signal Hill - Dunedin",
                "num_antenna": 24,
                "operating_frequency": 1575420000.0,
                "sampling_frequency": 16368000.0
            }
        },

    Returns
    -------
    None
    '''
    epoch_s = timestamp_to_ms_epoch(timestamps)
    LOGGER.info("Time {}".format(epoch_s))

    try:
        loc = info['location']
    except KeyError:
        loc = info

    # Sort out the coordinate frames using astropy
    # https://casa.nrao.edu/casadocs/casa-5.4.1/reference-material/coordinate-frames
    iers.conf.iers_auto_url = \
        'https://astroconda.org/aux/astropy_mirror/iers_a_1/finals2000A.all'
    iers.conf.auto_max_age = None

    location = EarthLocation.from_geodetic(lon=loc['lon']*u.deg,
                                           lat=loc['lat']*u.deg,
                                           height=loc['alt']*u.m,
                                           ellipsoid='WGS84')
    obstime = Time(timestamps)
    local_frame = AltAz(obstime=obstime, location=location)

    phase_altaz = SkyCoord(alt=90.0*u.deg, az=0.0*u.deg, obstime=obstime,
                           frame='altaz', location=location)
    phase_j2000 = phase_altaz.transform_to('fk5')

    # Get the stokes enums for the polarization types
    corr_types = [[MS_STOKES_ENUMS[p_f] for p_f in pol_feeds]]

    LOGGER.info("Pol Feeds {}".format(pol_feeds))
    LOGGER.info("Correlation Types {}".format(corr_types))
    num_freq_channels = [1]

    ant_table = MSTable(ms_table_name, 'ANTENNA')
    feed_table = MSTable(ms_table_name, 'FEED')
    field_table = MSTable(ms_table_name, 'FIELD')
    pol_table = MSTable(ms_table_name, 'POLARIZATION')
    obs_table = MSTable(ms_table_name, 'OBSERVATION')
    # SOURCE is an optional MS sub-table
    src_table = MSTable(ms_table_name, 'SOURCE')

    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))

    ms_datasets = []
    ddid_datasets = []
    spw_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    num_ant = len(ant_pos)
    position = da.asarray(ant_pos)
    diameter = da.ones(num_ant) * 0.025
    offset = da.zeros((num_ant, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(num_ant)],
                     dtype=object)
    stations = np.array([info['name'] for i in range(num_ant)],
                        dtype=object)

    dataset = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row",), diameter),
        'NAME': (("row",), da.from_array(names, chunks=num_ant)),
        'STATION': (("row",), da.from_array(stations, chunks=num_ant)),
    })
    ant_table.append(dataset)

    ####################### Create a FEED dataset ############################
    # There is one feed per antenna, so this should be quite similar to the
    # ANTENNA
    num_pols = len(pol_feeds)
    pol_types = pol_feeds
    pol_responses = [POL_RESPONSES[ct] for ct in pol_feeds]

    LOGGER.info("Pol Types {}".format(pol_types))
    LOGGER.info("Pol Responses {}".format(pol_responses))

    antenna_ids = da.asarray(range(num_ant))
    feed_ids = da.zeros(num_ant)
    num_receptors = da.zeros(num_ant) + num_pols
    polarization_types = np.array([pol_types for i in range(num_ant)],
                                  dtype=object)
    receptor_angles = np.array([[0.0] for i in range(num_ant)])
    pol_response = np.array([pol_responses for i in range(num_ant)])

    beam_offset = np.array([[[0.0, 0.0]] for i in range(num_ant)])

    dataset = Dataset({
        'ANTENNA_ID': (("row",), antenna_ids),
        'FEED_ID': (("row",), feed_ids),
        'NUM_RECEPTORS': (("row",), num_receptors),
        'POLARIZATION_TYPE': (("row", "receptors",),
                              da.from_array(polarization_types,
                                            chunks=num_ant)),
        'RECEPTOR_ANGLE': (("row", "receptors",),
                           da.from_array(receptor_angles, chunks=num_ant)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=num_ant)),
        'BEAM_OFFSET': (("row", "receptors", "radec"),
                        da.from_array(beam_offset, chunks=num_ant)),
    })
    feed_table.append(dataset)

    ############################ FIELD dataset ###############################
    direction = [[phase_j2000.ra.radian, phase_j2000.dec.radian]]
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=object), chunks=1)
    # Zero order polynomial in time for phase center.
    field_num_poly = da.zeros(1)

    dir_dims = ("row", 'field-poly', 'field-dir',)

    dataset = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row",), field_num_poly),
        'NAME': (("row",), field_name),
    })
    field_table.append(dataset)

    ######################### OBSERVATION dataset ############################
    dataset = Dataset({
        'TELESCOPE_NAME': (("row",),
                           da.asarray(np.asarray(['TART'], dtype=object),
                                      chunks=1)),
        'OBSERVER': (("row",),
                     da.asarray(np.asarray(['Tim'], dtype=object),
                                chunks=1)),
        "TIME_RANGE": (("row", "obs-exts"),
                       da.asarray(np.array([[epoch_s, epoch_s + 1]]),
                                  chunks=1)),
    })
    obs_table.append(dataset)

    ########################### SOURCE datasets ##############################
    for src in sources:
        name = src['name']

        # Convert to J2000
        dir_altaz = SkyCoord(alt=src['el']*u.deg, az=src['az']*u.deg,
                             obstime=obstime, frame='altaz',
                             location=location)
        dir_j2000 = dir_altaz.transform_to('fk5')
        direction = [dir_j2000.ra.radian, dir_j2000.dec.radian]
        # LOGGER.info("SOURCE: {}, timestamp: {}".format(name, timestamps))
        dask_num_lines = da.full((1,), 1, dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object), chunks=1)
        dask_time = da.asarray(np.array([epoch_s]))
        dataset = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            "TIME": (("row",), dask_time),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_table.append(dataset)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for corr_type in corr_types:
        corr_prod = [[i, i] for i in range(len(corr_type))]

        corr_prod = np.array(corr_prod)
        LOGGER.info("Corr Prod {}".format(corr_prod))
        LOGGER.info("Corr Type {}".format(corr_type))

        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        LOGGER.info("NUM_CORR {}".format(dask_num_corr))
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        dask_corr_product = da.asarray(corr_prod)[None, :]
        LOGGER.info("Dask Corr Prod {}".format(dask_corr_product.shape))
        LOGGER.info("Dask Corr Type {}".format(dask_corr_type.shape))
        dataset = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
            "CORR_PRODUCT": (("row", "corr", "corrprod_idx"),
                             dask_corr_product),
        })
        pol_table.append(dataset)

    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_freq_channels:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        dask_chan_width = da.full((1, num_chan), 2.5e6/num_chan)

        dataset = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
            "EFFECTIVE_BW": (("row", "chan"), dask_chan_width),
            "RESOLUTION": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(dataset)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_freq_channels)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))

    # Now create the associated MS dataset
    # vis_data, baselines = cal_vis.get_all_visibility()
    # vis_array = np.array(vis_data, dtype=np.complex64)
    chunks = {
        "row": (vis_array.shape[0],),
    }
    baselines = np.array(baselines)
    # LOGGER.info(f"baselines {baselines}")
    bl_pos = np.array(ant_pos)[baselines]
    # Use the - sign to get the same orientation as our tart projections.
    uu_a, vv_a, ww_a = -(bl_pos[:, 1] - bl_pos[:, 0]).T  # /constants.L1_WAVELENGTH
    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        # LOGGER.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_table.datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        LOGGER.info("Data size %s %s %s" % (row, chan, corr))

        # np_data = vis_array.reshape((row, chan, corr))
        np_data = np.zeros((row, chan, corr), dtype=np.complex128)
        for i in range(corr):
            np_data[:, :, i] = vis_array.reshape((row, chan))
        # np_data = np.array([vis_array.reshape((row, chan, 1))
        #                     for i in range(corr)])

        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Single flag category; the array axes match the declared
        # ('row', 'flagcat', 'chan', 'corr') dims
        flag_categories = da.from_array(0.05*np.ones((row, 1, chan, corr)))
        flag_data = np.zeros((row, chan, corr), dtype=np.bool_)

        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'FLAG': (dims, da.from_array(flag_data)),
            'TIME': (("row", "corr"),
                     da.from_array(epoch_s*np.ones((row, corr)))),
            'TIME_CENTROID': (("row", "corr"),
                              da.from_array(epoch_s*np.ones((row, corr)))),
            'WEIGHT': (("row", "corr"),
                       da.from_array(0.95*np.ones((row, corr)))),
            'WEIGHT_SPECTRUM': (dims,
                                da.from_array(
                                    0.95*np.ones_like(np_data,
                                                      dtype=np.float64))),
            'SIGMA_SPECTRUM': (dims,
                               da.from_array(
                                   np.ones_like(np_data,
                                                dtype=np.float64)*0.05)),
            'SIGMA': (("row", "corr"),
                      da.from_array(0.05*np.ones((row, corr)))),
            'UVW': (("row", "uvw",), uvw_data),
            'FLAG_CATEGORY': (('row', 'flagcat', 'chan', 'corr'),
                              flag_categories),
            'ANTENNA1': (("row",), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row",), da.from_array(baselines[:, 1])),
            'FEED1': (("row",), da.from_array(baselines[:, 0])),
            'FEED2': (("row",), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row",), dask_ddid)
        })

        ms_datasets.append(dataset)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")

    dask.compute(ms_writes)

    ant_table.write()
    feed_table.write()
    field_table.write()
    pol_table.write()
    obs_table.write()
    src_table.write()

    dask.compute(spw_writes)
    dask.compute(ddid_writes)
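# A hypothetical invocation of the TART ms_create above, with synthetic
# inputs. The info dict mirrors the structure in the docstring; the antenna
# positions, visibilities, source list, timestamp and the 'RR' feed key are
# all made up for illustration and not taken from a real observation.
def example_ms_create():
    n_ant = 3
    ant_pos = np.random.random((n_ant, 3)) * 10.0
    # One baseline per antenna pair
    baselines = [(i, j) for i in range(n_ant) for j in range(i + 1, n_ant)]
    vis_array = np.ones(len(baselines), dtype=np.complex64)
    info = {'name': 'Signal Hill - Dunedin',
            'operating_frequency': 1575420000.0,
            'location': {'alt': 270.0, 'lat': -45.85177, 'lon': 170.5456}}
    sources = [{'name': 'GPS-01', 'el': 45.0, 'az': 120.0}]
    ms_create('tart.ms', info, ant_pos, vis_array, baselines,
              timestamps='2019-08-12T00:00:00', pol_feeds=['RR'],
              sources=sources)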
def output_dataset(avg, field_id, data_desc_id, scan_number,
                   group_row_chunks):
    """
    Parameters
    ----------
    avg : namedtuple
        Result of :func:`average`
    field_id : int
        FIELD_ID for this averaged data
    data_desc_id : int
        DATA_DESC_ID for this averaged data
    scan_number : int
        SCAN_NUMBER for this averaged data

    Returns
    -------
    Dataset
        Dataset containing averaged data
    """
    # Create ID columns
    field_id = id_full_like(avg.time, fill_value=field_id)
    data_desc_id = id_full_like(avg.time, fill_value=data_desc_id)
    scan_number = id_full_like(avg.time, fill_value=scan_number)

    # Single flag category, equal to flags
    flag_cats = avg.flag[:, None, :, :]

    out_ds = {
        # Explicitly zero these columns? But this happens anyway
        # "ARRAY_ID": (("row",), zeros),
        # "OBSERVATION_ID": (("row",), zeros),
        # "PROCESSOR_ID": (("row",), zeros),
        # "STATE_ID": (("row",), zeros),

        "ANTENNA1": (("row",), avg.antenna1),
        "ANTENNA2": (("row",), avg.antenna2),
        "DATA_DESC_ID": (("row",), data_desc_id),
        "FIELD_ID": (("row",), field_id),
        "SCAN_NUMBER": (("row",), scan_number),

        "FLAG_ROW": (("row",), avg.flag_row),
        "FLAG_CATEGORY": (("row", "flagcat", "chan", "corr"), flag_cats),

        "TIME": (("row",), avg.time),
        "INTERVAL": (("row",), avg.interval),
        "TIME_CENTROID": (("row",), avg.time_centroid),
        "EXPOSURE": (("row",), avg.exposure),

        "UVW": (("row", "[uvw]"), avg.uvw),
        "WEIGHT": (("row", "corr"), avg.weight),
        "SIGMA": (("row", "corr"), avg.sigma),

        "DATA": (("row", "chan", "corr"), avg.vis),
        "FLAG": (("row", "chan", "corr"), avg.flag),
    }

    # Add optionally averaged columns
    if avg.weight_spectrum is not None:
        out_ds['WEIGHT_SPECTRUM'] = (("row", "chan", "corr"),
                                     avg.weight_spectrum)

    if avg.sigma_spectrum is not None:
        out_ds['SIGMA_SPECTRUM'] = (("row", "chan", "corr"),
                                    avg.sigma_spectrum)

    # Concatenate row chunks together
    if group_row_chunks > 1:
        grc = group_row_chunks
        out_ds = {k: (dims, concatenate_row_chunks(data, group_every=grc))
                  for k, (dims, data) in out_ds.items()}

    return Dataset(out_ds)
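# The FLAG_CATEGORY construction above just inserts a length-1 "flagcat"
# axis between the row and channel axes. A tiny self-contained sketch of
# the same indexing trick on a toy flag array:
def example_flag_category_axis():
    flag = np.zeros((10, 4, 2), dtype=np.bool_)   # (row, chan, corr)
    flag_cats = flag[:, None, :, :]               # (row, flagcat, chan, corr)
    assert flag_cats.shape == (10, 1, 4, 2)
    return flag_cats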
def test_ms_create(Dataset, tmp_path, chunks, num_chans, corr_types,
                   sources):
    # Set up
    rs = np.random.RandomState(42)

    ms_path = tmp_path / "create.ms"

    ms_table_name = str(ms_path)
    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # For comparison
    all_data_desc_id = []
    all_data = []

    # Create ANTENNA dataset of 64 antennas
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = 64
    position = da.random.random((na, 3)) * 10000
    offset = da.random.random((na, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=object)
    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'NAME': (("row",), da.from_array(names, chunks=na)),
    })
    ant_datasets.append(ds)

    # Create SOURCE datasets
    for s, (name, direction, rest_freq) in enumerate(sources):
        dask_num_lines = da.full((1,), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object))
        ds = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })
        pol_datasets.append(ds)

    # Create multiple MeerKAT L-band SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.linspace(.856e9, 2 * .856e9, num_chan,
                                     chunks=num_chan)[None, :]
        dask_chan_width = da.full((1, num_chan), .856e9 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_chans)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))

    # Now create the associated MS dataset
    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        np_data = (rs.normal(size=(row, chan, corr)) +
                   1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)
        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'DATA_DESC_ID': (("row",), dask_ddid)
        })
        ms_datasets.append(dataset)
        all_data.append(dask_data)
        all_data_desc_id.append(dask_ddid)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes, ant_writes, pol_writes,
                 spw_writes, ddid_writes, source_writes)

    # Check ANTENNA table correctly created
    with pt.table(ant_table_name, ack=False) as A:
        assert_array_equal(A.getcol("NAME"), names)
        assert_array_equal(A.getcol("POSITION"), position)
        assert_array_equal(A.getcol("OFFSET"), offset)

        required_desc = pt.required_ms_desc("ANTENNA")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(A.colnames()) == set(required_columns)

    # Check POLARIZATION table correctly created
    with pt.table(pol_table_name, ack=False) as P:
        for r, corr_type in enumerate(corr_types):
            assert_array_equal(P.getcol("CORR_TYPE", startrow=r, nrow=1),
                               [corr_type])
            assert_array_equal(P.getcol("NUM_CORR", startrow=r, nrow=1),
                               [len(corr_type)])

        required_desc = pt.required_ms_desc("POLARIZATION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(P.colnames()) == set(required_columns)

    # Check SPECTRAL_WINDOW table correctly created
    with pt.table(spw_table_name, ack=False) as S:
        for r, num_chan in enumerate(num_chans):
            assert_array_equal(
                S.getcol("NUM_CHAN", startrow=r, nrow=1)[0], num_chan)
            assert_array_equal(
                S.getcol("CHAN_FREQ", startrow=r, nrow=1)[0],
                np.linspace(.856e9, 2 * .856e9, num_chan))
            assert_array_equal(
                S.getcol("CHAN_WIDTH", startrow=r, nrow=1)[0],
                np.full(num_chan, .856e9 / num_chan))

        required_desc = pt.required_ms_desc("SPECTRAL_WINDOW")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(S.colnames()) == set(required_columns)

    # We should get a cartesian product out
    with pt.table(ddid_table_name, ack=False) as D:
        spw_id, pol_id = zip(*product(range(len(num_chans)),
                                      range(len(corr_types))))
        assert_array_equal(pol_id, D.getcol("POLARIZATION_ID"))
        assert_array_equal(spw_id, D.getcol("SPECTRAL_WINDOW_ID"))

        required_desc = pt.required_ms_desc("DATA_DESCRIPTION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(D.colnames()) == set(required_columns)

    with pt.table(src_table_name, ack=False) as S:
        for r, (name, direction, rest_freq) in enumerate(sources):
            assert_array_equal(S.getcol("NAME", startrow=r, nrow=1)[0],
                               [name])
            assert_array_equal(S.getcol("REST_FREQUENCY", startrow=r,
                                        nrow=1),
                               [rest_freq])
            assert_array_equal(S.getcol("DIRECTION", startrow=r, nrow=1),
                               [direction])

    with pt.table(ms_table_name, ack=False) as T:
        # DATA_DESC_ID's are all the same shape
        assert_array_equal(T.getcol("DATA_DESC_ID"),
                           da.concatenate(all_data_desc_id))

        # DATA is variably shaped (on DATA_DESC_ID) so we
        # compare each one separately.
        for ddid, data in enumerate(all_data):
            ms_data = T.getcol("DATA", startrow=ddid * row, nrow=row)
            assert_array_equal(ms_data, data)

        required_desc = pt.required_ms_desc()
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        # Check we have the required columns
        assert set(T.colnames()) == required_columns.union(
            ["DATA", "DATA_DESC_ID"])
def output_dataset(avg, field_id, data_desc_id, scan_number,
                   group_row_chunks, viscolumn=""):
    """
    Parameters
    ----------
    avg : namedtuple
        Result of :func:`average`
    field_id : int
        FIELD_ID for this averaged data
    data_desc_id : int
        DATA_DESC_ID for this averaged data
    scan_number : int
        SCAN_NUMBER for this averaged data
    group_row_chunks : int
        Number of row chunks to group together when concatenating

    Returns
    -------
    Dataset
        Dataset containing averaged data
    """
    # Create ID columns
    fid = field_id
    ddid = data_desc_id
    scn = scan_number
    field_id = id_full_like(avg.time, fill_value=field_id)
    data_desc_id = id_full_like(avg.time, fill_value=data_desc_id)
    scan_number = id_full_like(avg.time, fill_value=scan_number)

    # Single flag category, equal to flags
    flag_cats = avg.flag[:, None, :, :]

    out_ds = {
        "ANTENNA1": (("row",), avg.antenna1),
        "ANTENNA2": (("row",), avg.antenna2),
        "DATA_DESC_ID": (("row",), data_desc_id),
        "FIELD_ID": (("row",), field_id),
        "SCAN_NUMBER": (("row",), scan_number),
        "FLAG_ROW": (("row",), avg.flag_row),
        # "FLAG_CATEGORY": (("row", "flagcat", "chan", "corr"), flag_cats),
        "TIME": (("row",), avg.time),
        "INTERVAL": (("row",), avg.interval),
        "TIME_CENTROID": (("row",), avg.time_centroid),
        "EXPOSURE": (("row",), avg.exposure),
        "UVW": (("row", "uvw"), avg.uvw),
        "WEIGHT": (("row", "corr"), avg.weight),
        "SIGMA": (("row", "corr"), avg.sigma),
        viscolumn: (("row", "chan", "corr"), avg.vis),
        "FLAG": (("row", "chan", "corr"), avg.flag),
    }

    # Add optionally averaged columns
    if avg.weight_spectrum is not None:
        out_ds["WEIGHT_SPECTRUM"] = (("row", "chan", "corr"),
                                     avg.weight_spectrum)

    if avg.sigma_spectrum is not None:
        out_ds["SIGMA_SPECTRUM"] = (("row", "chan", "corr"),
                                    avg.sigma_spectrum)

    # Concatenate row chunks together
    if group_row_chunks > 1:
        grc = group_row_chunks
        # Remove items whose values are None
        out_ds = {k: (dims, concatenate_row_chunks(data, group_every=grc))
                  for k, (dims, data) in out_ds.items()
                  if data is not None}

    return Dataset(out_ds, attrs={
        "DATA_DESC_ID": ddid,
        "FIELD_ID": fid,
        "SCAN_NUMBER": scn
    })
def ms_create(ms_table_name, info, ant_pos, cal_vis, timestamps, corr_types,
              sources):
    '''
    "info": {
        "info": {
            "L0_frequency": 1571328000.0,
            "bandwidth": 2500000.0,
            "baseband_frequency": 4092000.0,
            "location": {
                "alt": 270.0,
                "lat": -45.85177,
                "lon": 170.5456
            },
            "name": "Signal Hill - Dunedin",
            "num_antenna": 24,
            "operating_frequency": 1575420000.0,
            "sampling_frequency": 16368000.0
        }
    },
    '''
    num_chans = [1]

    rs = np.random.RandomState(42)

    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    feed_table_name = "::".join((ms_table_name, "FEED"))
    field_table_name = "::".join((ms_table_name, "FIELD"))
    obs_table_name = "::".join((ms_table_name, "OBSERVATION"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    feed_datasets = []
    field_datasets = []
    obs_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = len(ant_pos)
    position = da.asarray(ant_pos)  # da.zeros((na, 3))
    diameter = da.ones(na) * 0.025
    offset = da.zeros((na, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=object)
    stations = np.array([info['name'] for i in range(na)], dtype=object)

    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row",), diameter),
        'NAME': (("row",), da.from_array(names, chunks=na)),
        'STATION': (("row",), da.from_array(stations, chunks=na)),
    })
    ant_datasets.append(ds)

    ####################### Create a FEED dataset ############################
    # There is one feed per antenna, so this should be quite similar to the
    # ANTENNA
    antenna_ids = da.asarray(range(na))
    feed_ids = da.zeros(na)
    num_receptors = da.ones(na)
    polarization_types = np.array([['XX'] for i in range(na)], dtype=object)
    receptor_angles = np.array([[0.0] for i in range(na)])
    pol_response = np.array([[[0.0 + 1.0j, 1.0 - 1.0j]] for i in range(na)])

    beam_offset = np.array([[[0.0, 0.0]] for i in range(na)])

    ds = Dataset({
        'ANTENNA_ID': (("row",), antenna_ids),
        'FEED_ID': (("row",), feed_ids),
        'NUM_RECEPTORS': (("row",), num_receptors),
        'POLARIZATION_TYPE': (("row", "receptors",),
                              da.from_array(polarization_types, chunks=na)),
        'RECEPTOR_ANGLE': (("row", "receptors",),
                           da.from_array(receptor_angles, chunks=na)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=na)),
        'BEAM_OFFSET': (("row", "receptors", "radec"),
                        da.from_array(beam_offset, chunks=na)),
    })
    feed_datasets.append(ds)

    ############################ FIELD dataset ###############################
    direction = [[np.radians(90.0), np.radians(0.0)]]  # Phase Center in J2000
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=object))
    # Zero order polynomial in time for phase center.
    field_num_poly = da.zeros(1)

    dir_dims = ("row", 'field-poly', 'field-dir',)

    ds = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row",), field_num_poly),
        'NAME': (("row",), field_name),
    })
    field_datasets.append(ds)

    ######################### OBSERVATION dataset ############################
    ds = Dataset({
        'TELESCOPE_NAME': (("row",),
                           da.asarray(np.asarray(['TART'], dtype=object))),
        'OBSERVER': (("row",),
                     da.asarray(np.asarray(['Tim'], dtype=object))),
    })
    obs_datasets.append(ds)

    ######################## Create SOURCE datasets ##########################
    for s, src in enumerate(sources):
        name = src['name']
        rest_freq = [info['operating_frequency']]
        # FIXME these are in elevation and azimuth. Not in J2000.
        direction = [np.radians(src['el']), np.radians(src['az'])]
        # logger.info("SOURCE: {}, timestamp: {}".format(name, timestamps))
        dask_num_lines = da.full((1,), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object))
        dask_time = da.asarray(np.asarray([timestamps], dtype=object))
        ds = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            # FIXME. Causes an error. Need to sort out TIME data fields
            # "TIME": (("row",), dask_time),
            # "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            # "CORR_PRODUCT": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })
        pol_datasets.append(ds)

    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        dask_chan_width = da.full((1, num_chan), 2.5e6 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_chans)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))

    # Now create the associated MS dataset
    vis_data, baselines = cal_vis.get_all_visibility()
    vis_array = np.array(vis_data, dtype=np.complex64)
    chunks = {
        "row": (vis_array.shape[0],),
    }
    baselines = np.array(baselines)
    bl_pos = np.array(ant_pos)[baselines]
    # Use the - sign to get the same orientation as our tart projections.
    uu_a, vv_a, ww_a = -(bl_pos[:, 1] -
                         bl_pos[:, 0]).T / constants.L1_WAVELENGTH
    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        logger.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        logger.info("Data size {}".format((row, chan, corr)))
        np_data = vis_array.reshape((row, chan, corr))
        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)

        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'WEIGHT': (("row", "corr"),
                       da.from_array(0.95 * np.ones((row, corr)))),
            'WEIGHT_SPECTRUM': (dims,
                                da.from_array(
                                    0.95 * np.ones_like(np_data,
                                                        dtype=np.float64))),
            'SIGMA_SPECTRUM': (dims,
                               da.from_array(
                                   np.ones_like(np_data,
                                                dtype=np.float64) * 0.05)),
            'UVW': (("row", "uvw",), uvw_data),
            'ANTENNA1': (("row",), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row",), da.from_array(baselines[:, 1])),
            'FEED1': (("row",), da.from_array(baselines[:, 0])),
            'FEED2': (("row",), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row",), dask_ddid)
        })

        ms_datasets.append(dataset)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    feed_writes = xds_to_table(feed_datasets, feed_table_name, columns="ALL")
    field_writes = xds_to_table(field_datasets, field_table_name,
                                columns="ALL")
    obs_writes = xds_to_table(obs_datasets, obs_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes)
    dask.compute(ant_writes)
    dask.compute(feed_writes)
    dask.compute(field_writes)
    dask.compute(obs_writes)
    dask.compute(pol_writes)
    dask.compute(spw_writes)
    dask.compute(ddid_writes)
    dask.compute(source_writes)
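# Both ms_create variants derive UVW coordinates as the (negated) difference
# of the two antenna positions in each baseline. A toy sketch of that
# geometry, with made-up positions and baseline indices:
def example_baseline_uvw():
    ant_pos = np.array([[0.0, 0.0, 0.0],
                        [5.0, 0.0, 0.0],
                        [0.0, 5.0, 0.0]])
    baselines = np.array([[0, 1], [0, 2], [1, 2]])
    bl_pos = ant_pos[baselines]                    # (nbl, 2, 3)
    # The - sign matches the orientation of the TART projections
    uu_a, vv_a, ww_a = -(bl_pos[:, 1] - bl_pos[:, 0]).T
    return np.array([uu_a, vv_a, ww_a]).T          # (nbl, 3) UVW array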
def output_dataset(avg, field_id, data_desc_id, scan_number,
                   group_row_chunks):
    """
    Parameters
    ----------
    avg : namedtuple
        Result of :func:`average`
    field_id : int
        FIELD_ID for this averaged data
    data_desc_id : int
        DATA_DESC_ID for this averaged data
    scan_number : int
        SCAN_NUMBER for this averaged data

    Returns
    -------
    Dataset
        Dataset containing averaged data
    """
    # Create ID columns
    field_id = id_full_like(avg.time, fill_value=field_id)
    data_desc_id = id_full_like(avg.time, fill_value=data_desc_id)
    scan_number = id_full_like(avg.time, fill_value=scan_number)

    flag_cats = flag_categories(avg.flag)
    zeros = id_full_like(avg.antenna1, fill_value=0)

    out_ds = {
        # Explicitly zero these columns? But this happens anyway
        "ARRAY_ID": (("row",), zeros),
        "OBSERVATION_ID": (("row",), zeros),
        "PROCESSOR_ID": (("row",), zeros),
        "STATE_ID": (("row",), zeros),

        "ANTENNA1": (("row",), avg.antenna1),
        "ANTENNA2": (("row",), avg.antenna2),
        "DATA_DESC_ID": (("row",), data_desc_id),
        "FIELD_ID": (("row",), field_id),
        "SCAN_NUMBER": (("row",), scan_number),
        "FLAG_ROW": (("row",), avg.flag_row),
        "FLAG_CATEGORY": (("row", "flagcat", "chan", "corr"), flag_cats),
        "TIME": (("row",), avg.time),
        "INTERVAL": (("row",), avg.interval),
        "TIME_CENTROID": (("row",), avg.time_centroid),
        "EXPOSURE": (("row",), avg.exposure),
        "UVW": (("row", "[uvw]"), avg.uvw),
        "WEIGHT": (("row", "corr"), avg.weight),
        "SIGMA": (("row", "corr"), avg.sigma),
        # "DATA": (("row", "chan", "corr"), avg.vis),
        "FLAG": (("row", "chan", "corr"), avg.flag),
    }

    if avg.visibilities is not None:
        if isinstance(avg.visibilities, dict):
            for column, data in avg.visibilities.items():
                out_ds[column] = (("row", "chan", "corr"), data)
        elif isinstance(avg.visibilities, da.Array):
            out_ds["DATA"] = (("row", "chan", "corr"), avg.visibilities)
        else:
            raise TypeError(f"Unknown visibility type "
                            f"{type(avg.visibilities)}")

    if hasattr(avg, "offsets"):
        num_chan = da.map_blocks(np.diff, avg.offsets)
        out_ds['NUM_CHAN'] = (("row",), num_chan)

    if hasattr(avg, "decorr_chan_width"):
        out_ds['DECORR_CHAN_WIDTH'] = (("row",), avg.decorr_chan_width)

    # Add optionally averaged columns
    if avg.weight_spectrum is not None:
        out_ds['WEIGHT_SPECTRUM'] = (("row", "chan", "corr"),
                                     avg.weight_spectrum)

    if avg.sigma_spectrum is not None:
        out_ds['SIGMA_SPECTRUM'] = (("row", "chan", "corr"),
                                    avg.sigma_spectrum)

    # Concatenate row chunks together
    if group_row_chunks > 1:
        grc = group_row_chunks
        out_ds = {k: (dims, row_concatenate(data, group_every=grc))
                  for k, (dims, data) in out_ds.items()}

    return Dataset(out_ds)
def bda_average_spw(out_datasets, ddid_ds, spw_ds):
    """
    Parameters
    ----------
    out_datasets : list of Datasets
        list of Datasets
    ddid_ds : Dataset
        DATA_DESCRIPTION dataset
    spw_ds : list of Datasets
        list of Datasets, each describing a single Spectral Window

    Returns
    -------
    output_ds : list of Datasets
        list of Datasets
    spw_ds : list of Datasets
        list of Datasets, each describing an averaged Spectral Window
    """
    channelisations = []

    # Over the entire set of datasets, determine the complete
    # set of channelisations, per input DDID and
    # reduce down to a single object
    for out_ds in out_datasets:
        transform = da.blockwise(_channelisations, ("row",),
                                 out_ds.DATA_DESC_ID.data, ("row",),
                                 out_ds.NUM_CHAN.data, ("row",),
                                 ddid_ds.SPECTRAL_WINDOW_ID.data, ("ddid",),
                                 ddid_ds.POLARIZATION_ID.data, ("ddid",),
                                 meta=np.empty((0,), dtype=object))

        result = da.reduction(transform,
                              chunk=_noop,
                              combine=combine,
                              aggregate=combine,
                              concatenate=False,
                              keepdims=True,
                              meta=np.empty((0,), dtype=object),
                              dtype=object)

        channelisations.append(result)

    # Final reduction object, note the aggregate method
    # which generates the mapping
    ddid_chan_map = da.reduction(da.concatenate(channelisations),
                                 chunk=_noop,
                                 combine=combine,
                                 aggregate=aggregate,
                                 concatenate=False,
                                 keepdims=False,
                                 meta=np.empty((), dtype=object),
                                 dtype=object)

    def _squeeze_tuplify(*args):
        return tuple(a.squeeze() for a in args)

    chan_freqs = da.blockwise(_squeeze_tuplify, ("row", "chan"),
                              *(a for spw in spw_ds
                                for a in (spw.CHAN_FREQ.data,
                                          ("row", "chan"))),
                              concatenate=False,
                              align_arrays=False,
                              adjust_chunks={"chan": lambda c: np.nan},
                              meta=np.empty((0, 0), dtype=object))

    chan_widths = da.blockwise(_squeeze_tuplify, ("row", "chan"),
                               *(a for spw in spw_ds
                                 for a in (spw.CHAN_WIDTH.data,
                                           ("row", "chan"))),
                               concatenate=False,
                               align_arrays=False,
                               adjust_chunks={"chan": lambda c: np.nan},
                               meta=np.empty((0, 0), dtype=object))

    ref_freqs = da.blockwise(_squeeze_tuplify, ("row",),
                             *(a for spw in spw_ds
                               for a in (spw.REF_FREQUENCY.data, ("row",))),
                             concatenate=False,
                             align_arrays=False,
                             meta=np.empty((0,), dtype=object))

    meas_freq_refs = da.blockwise(_squeeze_tuplify, ("row",),
                                  *(a for spw in spw_ds
                                    for a in (spw.MEAS_FREQ_REF.data,
                                              ("row",))),
                                  concatenate=False,
                                  align_arrays=False,
                                  meta=np.empty((0,), dtype=object))

    result = da.blockwise(ddid_and_spw_factory, ("row", "chan"),
                          chan_freqs, ("row", "chan"),
                          chan_widths, ("row", "chan"),
                          ref_freqs, ("row",),
                          meas_freq_refs, ("row",),
                          ddid_chan_map, (),
                          meta=np.empty((0, 0), dtype=object))

    # There should only be one chunk
    assert result.npartitions == 1

    chan_freq = da.blockwise(getitem, ("row", "chan"),
                             result, ("row", "chan"),
                             0, None,
                             dtype=np.float64)

    chan_width = da.blockwise(getitem, ("row", "chan"),
                              result, ("row", "chan"),
                              1, None,
                              dtype=np.float64)

    num_chan = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            2, None,
                            dtype=np.int32)

    ref_freq = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            3, None,
                            dtype=np.float64)

    meas_freq_refs = da.blockwise(lambda d, i: d[0][i], ("row",),
                                  result, ("row", "chan"),
                                  4, None,
                                  dtype=np.float64)

    total_bw = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            5, None,
                            dtype=np.float64)

    spectral_window_id = da.blockwise(lambda d, i: d[0][i], ("row",),
                                      result, ("row", "chan"),
                                      6, None,
                                      dtype=np.int32)

    polarization_id = da.blockwise(lambda d, i: d[0][i], ("row",),
                                   result, ("row", "chan"),
                                   7, None,
                                   dtype=np.int32)

    ddid_map = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            8, None,
                            dtype=np.int32)

    for o, out_ds in enumerate(out_datasets):
        data_desc_id = da.blockwise(_new_ddids, ("row",),
                                    out_ds.DATA_DESC_ID.data, ("row",),
                                    out_ds.NUM_CHAN.data, ("row",),
                                    ddid_map, ("ddid",),
                                    dtype=out_ds.DATA_DESC_ID.dtype)

        dv = dict(out_ds.data_vars)
        dv["DATA_DESC_ID"] = (("row",), data_desc_id)
        del dv["NUM_CHAN"]
        del dv["DECORR_CHAN_WIDTH"]

        out_datasets[o] = Dataset(dv, out_ds.coords, out_ds.attrs)

    out_spw_ds = Dataset({
        "CHAN_FREQ": (("row", "chan"), chan_freq),
        "CHAN_WIDTH": (("row", "chan"), chan_width),
        "EFFECTIVE_BW": (("row", "chan"), chan_width),
        "RESOLUTION": (("row", "chan"), chan_width),
        "NUM_CHAN": (("row",), num_chan),
        "REF_FREQUENCY": (("row",), ref_freq),
        "TOTAL_BANDWIDTH": (("row",), total_bw)
    })

    out_ddid_ds = Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), spectral_window_id),
        "POLARIZATION_ID": (("row",), polarization_id),
    })

    return out_datasets, [out_spw_ds], out_ddid_ds
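# The chunk/combine/aggregate pattern used above is generic dask. A minimal
# self-contained illustration of da.reduction merging per-chunk results
# (here: sets of values) into a single object, analogous to how the
# channelisation mappings are reduced; the functions are illustrative only:
def example_object_reduction():
    x = da.from_array(np.array([1, 2, 2, 3, 3, 3]), chunks=2)

    def to_set(block, axis=None, keepdims=None):
        # Per-chunk partial result
        return set(block.tolist())

    def union(sets, axis=None, keepdims=None):
        # Merge a list of per-chunk partial results
        out = set()
        for s in sets:
            out |= s
        return out

    return da.reduction(x, chunk=to_set, combine=union, aggregate=union,
                        concatenate=False, dtype=object,
                        meta=np.empty((), dtype=object)).compute()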