def test_keyword_write(ms):
    datasets = xds_from_ms(ms)

    # Add to table keywords
    writes = xds_to_table([], ms, [], table_keywords={'bob': 'qux'})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getkeywords()['bob'] == 'qux'

    # Add to column keywords
    writes = xds_to_table(datasets, ms, [],
                          column_keywords={'STATE_ID': {'bob': 'qux'}})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getcolkeywords("STATE_ID")['bob'] == 'qux'

    # Remove from column and table keywords
    from daskms.writes import DELKW
    writes = xds_to_table(datasets, ms, [],
                          table_keywords={'bob': DELKW},
                          column_keywords={'STATE_ID': {'bob': DELKW}})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert 'bob' not in T.getkeywords()
        assert 'bob' not in T.getcolkeywords("STATE_ID")
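# A minimal, hedged sketch of the round trip the test above exercises,
# assuming a CASA Measurement Set at the hypothetical path "test.ms":
# read datasets with xds_from_ms, assign a column, write it back with
# xds_to_table and execute the lazy writes with dask.compute.
import dask
from daskms import xds_from_ms, xds_to_table

ms = "test.ms"  # hypothetical path; substitute a real Measurement Set

datasets = xds_from_ms(ms)
updated = []

for ds in datasets:
    # Copy DATA into a (possibly new) CORRECTED_DATA column
    updated.append(
        ds.assign(CORRECTED_DATA=(("row", "chan", "corr"), ds.DATA.data)))

# Writes are lazy until computed
writes = xds_to_table(updated, ms, ["CORRECTED_DATA"])
dask.compute(writes)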
def execute(self):
    """ Execute the application """
    logger.info("xova {args}", args=" ".join(self.cmdline_args))
    self.args = args = parse_args(self.cmdline_args)
    log_args(args)

    self._create_output_ms(args)

    row_chunks, time_chunks, interval_chunks = self._derive_row_chunking(
        args)

    (main_ds, spw_ds,
     ddid_ds, field_ds,
     subtables) = self._input_datasets(args, row_chunks)

    # Set up Main MS data averaging
    main_ds = average_main(main_ds, field_ds,
                           args.time_bin_secs,
                           args.chan_bin_size,
                           args.fields,
                           args.scan_numbers,
                           args.group_row_chunks,
                           args.respect_flag_row,
                           viscolumn=args.data_column)

    main_writes = xds_to_table(main_ds, args.output, "ALL")

    # Set up SPW data averaging
    spw_ds = average_spw(spw_ds, args.chan_bin_size)
    spw_table = "::".join((args.output, "SPECTRAL_WINDOW"))
    spw_writes = xds_to_table(spw_ds, spw_table, "ALL")

    copy_subtables(args.ms, args.output, subtables)

    self._execute_graph(main_writes, spw_writes)
def test_ms_create_and_update(Dataset, tmp_path, chunks):
    """ Test that we can update and append at the same time """
    filename = str(tmp_path / "create-and-update.ms")

    rs = np.random.RandomState(42)

    # Create a dataset of 10 rows with DATA and DATA_DESC_ID
    dims = ("row", "chan", "corr")
    row, chan, corr = tuple(sum(chunks[d]) for d in dims)

    ms_datasets = []
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 0, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row",), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    ms_datasets = xds_from_ms(filename)

    # Now add another dataset (different DDID), with no ROWID
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)
    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 1, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row",), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    # Rows have been added and additional data is present
    with pt.table(filename, ack=False, readonly=True) as T:
        first_data_desc_id = da.full(row,
                                     ms_datasets[0].DATA_DESC_ID,
                                     chunks=chunks['row'])
        ds_data = da.concatenate([ms_datasets[0].DATA.data,
                                  ms_datasets[1].DATA.data])
        ds_ddid = da.concatenate([first_data_desc_id,
                                  ms_datasets[1].DATA_DESC_ID.data])
        assert_array_equal(T.getcol("DATA"), ds_data)
        assert_array_equal(T.getcol("DATA_DESC_ID"), ds_ddid)
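# The update-versus-append behaviour above hinges on ROWID: datasets read
# with xds_from_ms carry a ROWID coordinate and are written back in place,
# while datasets without one (like the second, freshly-built Dataset) are
# appended as new rows. A small sketch, assuming the MS created above:
import dask
from daskms import xds_from_ms, xds_to_table

filename = "create-and-update.ms"  # hypothetical MS from a test like the above

ds = xds_from_ms(filename)[0]
print("ROWID" in ds.coords)  # True: these rows would be updated in place

# Re-writing the dataset as-is updates the original rows rather than
# appending, because the ROWID coordinate identifies them
writes = xds_to_table([ds], filename, ["DATA"])
dask.compute(writes)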
def test_write_table_proxy_keyword(ms):
    datasets = xds_from_ms(ms)

    # Test that we get a TableProxy if requested
    writes, tp = xds_to_table(datasets, ms, [], table_proxy=True)
    assert isinstance(writes, list) and isinstance(writes[0], Dataset)
    assert isinstance(tp, TableProxy)
    assert tp.nrows().result() == 10

    writes = xds_to_table(datasets, ms, [], table_proxy=False)
    assert isinstance(writes, list) and isinstance(writes[0], Dataset)
def predict(args):
    # Convert source data into dask arrays
    sky_model = parse_sky_model(args.sky_model, args.model_chunks)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW",
               "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]

        # Select single dataset row out
        corrs = pol.NUM_CORR.data[0]

        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Generate visibility expressions for each source type
        source_vis = [vis_factory(args, stype, sky_model, time_index,
                                  xds, field, spw, pol)
                      for stype in sky_model.keys()]

        # Sum visibilities together
        vis = sum(source_vis)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4,))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(MODEL_DATA=(("row", "chan", "corr"), vis))
        # Create a write to the table
        write = xds_to_table(xds, args.ms, ['MODEL_DATA'])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        dask.compute(writes)
def write_component_model(self, comps, ref_freq, mask, row_chunks,
                          chan_chunks):
    print("Writing model data at full freq resolution")
    order, npix = comps.shape
    comps = da.from_array(comps, chunks=(-1, -1))
    mask = da.from_array(mask.squeeze(), chunks=(-1, -1))
    writes = []
    for ims in self.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={'row': (row_chunks,),
                                  'chan': (chan_chunks,)},
                          columns=('MODEL_DATA', 'UVW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                              group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION",
                              group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = spws[ds.DATA_DESC_ID]
            freq = spw.CHAN_FREQ.data.squeeze()
            freq_bin_idx = da.arange(0, freq.size, 1,
                                     chunks=freq.chunks,
                                     dtype=np.int64)
            freq_bin_counts = da.ones(freq.size,
                                      chunks=freq.chunks,
                                      dtype=np.int64)

            uvw = ds.UVW.data
            model_vis = getattr(ds, 'MODEL_DATA').data

            model = model_from_comps(comps, freq, mask, ref_freq)

            vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                         self.cell, nthreads=self.nthreads,
                         epsilon=self.epsilon,
                         do_wstacking=self.do_wstacking)

            model_vis = populate_model(vis, model_vis)

            out_ds = ds.assign(**{self.model_column:
                                  (("row", "chan", "corr"), model_vis)})
            out_data.append(out_ds)
        writes.append(xds_to_table(out_data, ims,
                                   columns=[self.model_column]))
    dask.compute(writes, scheduler='single-threaded')
def write_model(self, x):
    print("Writing model data")
    x = da.from_array(x.astype(np.float32), chunks=(1, self.nx, self.ny))
    writes = []
    for ims in self.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims],
                          columns=('MODEL_DATA', 'UVW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                              group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION",
                              group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data
            model_vis = getattr(ds, 'MODEL_DATA').data

            bands = self.band_mapping[ims][spw]
            model = x[list(bands), :, :]
            vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                         self.cell, nthreads=self.nthreads,
                         epsilon=self.epsilon,
                         do_wstacking=self.do_wstacking)

            model_vis = populate_model(vis, model_vis)

            out_ds = ds.assign(**{self.model_column:
                                  (("row", "chan", "corr"), model_vis)})
            out_data.append(out_ds)
        writes.append(xds_to_table(out_data, ims,
                                   columns=[self.model_column]))
    dask.compute(writes, scheduler='single-threaded')
def execute(self):
    """ Execute the application """
    logger.info("xova version {v}", v=xova_version)
    logger.info("xova {args}", args=" ".join(self.cmdline_args))
    self.args = args = parse_args(self.cmdline_args)
    log_args(args)

    if args.command == "check":
        check_ms(args)
        logger.info("{ms} is conformant".format(ms=args.ms))
        return

    self._maybe_remove_output_ms(args)

    row_chunks, time_chunks, interval_chunks = self._derive_row_chunking(
        args)

    (main_ds, spw_ds,
     ddid_ds, field_ds,
     subtables) = self._input_datasets(args, row_chunks)

    # Set up Main MS data averaging
    if args.command == "timechannel":
        output_ds = average_main(main_ds, field_ds,
                                 args.time_bin_secs,
                                 args.chan_bin_size,
                                 args.fields,
                                 args.scan_numbers,
                                 args.group_row_chunks,
                                 args.respect_flag_row,
                                 viscolumn=args.data_column)

        spw_ds = average_spw(spw_ds, args.chan_bin_size)
    elif args.command == "bda":
        output_ds = bda_average_main(main_ds, field_ds, ddid_ds, spw_ds,
                                     args)
        output_ds, spw_ds, out_ddid_ds = bda_average_spw(
            output_ds, ddid_ds, spw_ds)
    else:
        raise ValueError("Invalid command %s" % args.command)

    main_writes = xds_to_table(output_ds, args.output, "ALL",
                               descriptor="ms(False)")

    spw_table = "::".join((args.output, "SPECTRAL_WINDOW"))
    spw_writes = xds_to_table(spw_ds, spw_table, "ALL")

    if args.command == "bda":
        ddid_table = "::".join((args.output, "DATA_DESCRIPTION"))
        ddid_writes = xds_to_table(out_ddid_ds, ddid_table, "ALL")
        subtables.discard("DATA_DESCRIPTION")
    else:
        ddid_writes = None

    copy_subtables(args.ms, args.output, subtables)

    self._execute_graph(main_writes, spw_writes, ddid_writes)

    if not args.average_uvw_coordinates:
        fixms(args.output)
    else:
        logger.warning(
            "Applying approximation to uvw coordinates as you requested - "
            "the spatial frequencies of your long baseline data may be "
            "severely affected!")
def main(args):
    """
    Flags outliers in data given a model and rescale weights so that
    whitened residuals have a mean amplitude of sqrt(2).

    Flags and weights are computed per chunk of data
    """
    radec_ref = None
    writes = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={"row": args.row_chunks,
                                  "chan": args.chan_chunks},
                          columns=('UVW', args.data_column,
                                   args.weight_column, args.model_column,
                                   args.flag_column, 'FLAG_ROW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                              group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION",
                              group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if radec_ref is None:
                radec_ref = radec

            if not np.array_equal(radec, radec_ref):
                continue

            # load in data and compute whitened residuals
            data = getattr(ds, args.data_column).data
            model = getattr(ds, args.model_column).data
            flag = getattr(ds, args.flag_column).data
            flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.trim_channels:
                flag = trim_chans(flag, args.trim_channels)

            # Stokes I vis
            weights = (~flag) * weights
            resid_vis = (data - model) * weights
            wsums = (weights[:, :, 0] + weights[:, :, -1])
            resid_vis_I = da.where(
                wsums,
                (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                0.0j)

            # whiten and take abs
            white_resid = resid_vis_I * da.sqrt(wsums)
            abs_resid_vis_I = (white_resid).__abs__()

            # mean amp
            sum_amp = da.sum(abs_resid_vis_I)
            count = da.sum(wsums > 0)
            mean_amp = sum_amp / count

            flag_legacy = flag[:, :, 0] | flag[:, :, -1]
            flag_I = da.logical_or(
                abs_resid_vis_I > args.sigma_cut * mean_amp,
                flag_legacy)

            # new flags
            updated_flag = da.broadcast_to(flag_I[:, :, None],
                                           flag.shape,
                                           chunks=flag.chunks)

            # scale weights (whitened residuals should have
            # mean amplitude of 1/sqrt(2))
            if args.scale_weights:
                # recompute mean amp with new flags
                weights = (~updated_flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums,
                    (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amp = sum_amp / count
                updated_weight = 2**0.5 * weights / mean_amp**2
            else:
                updated_weight = weights

            ds = ds.assign(**{args.weight_out_column:
                              (("row", "chan", "corr"), updated_weight)})
            ds = ds.assign(**{args.flag_out_column:
                              (("row", "chan", "corr"), updated_flag)})

            out_data.append(ds)
        writes.append(
            xds_to_table(out_data, ims,
                         columns=[args.flag_out_column,
                                  args.weight_out_column]))

    with ProgressBar():
        dask.compute(writes)

    # report new mean amp
    if args.report_means:
        radec_ref = None
        mean_amps = []
        for ims in args.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks={"row": args.row_chunks,
                                      "chan": args.chan_chunks},
                              columns=('UVW', args.data_column,
                                       args.weight_out_column,
                                       args.model_column,
                                       args.flag_out_column, 'FLAG_ROW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD",
                                    group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION",
                                  group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if radec_ref is None:
                    radec_ref = radec

                if not np.array_equal(radec, radec_ref):
                    continue

                # load in data and compute whitened residuals
                data = getattr(ds, args.data_column).data
                model = getattr(ds, args.model_column).data
                flag = getattr(ds, args.flag_out_column).data
                flag = da.logical_or(flag,
                                     ds.FLAG_ROW.data[:, None, None])
                weights = getattr(ds, args.weight_out_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                # Stokes I vis
                weights = (~flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums,
                    (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)

                # whiten and take abs
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()

                # mean amp
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amps.append(sum_amp / count)

        mean_amps = dask.compute(mean_amps)[0]
        print(mean_amps)
def ms_create(ms_table_name, info, ant_pos, vis_array, baselines,
              timestamps, pol_feeds, sources):
    ''' Create a Measurement Set from some TART observations

    Parameters
    ----------

    ms_table_name : string
        The name of the MS top level directory. I think this only works
        in the local directory.

    info : JSON
        "info": {
            "info": {
                "L0_frequency": 1571328000.0,
                "bandwidth": 2500000.0,
                "baseband_frequency": 4092000.0,
                "location": {
                    "alt": 270.0,
                    "lat": -45.85177,
                    "lon": 170.5456
                },
                "name": "Signal Hill - Dunedin",
                "num_antenna": 24,
                "operating_frequency": 1575420000.0,
                "sampling_frequency": 16368000.0
            }
        },

    Returns
    -------
    None
    '''
    epoch_s = timestamp_to_ms_epoch(timestamps)
    LOGGER.info("Time {}".format(epoch_s))

    try:
        loc = info['location']
    except KeyError:
        loc = info

    # Sort out the coordinate frames using astropy
    # https://casa.nrao.edu/casadocs/casa-5.4.1/reference-material/coordinate-frames
    iers.conf.iers_auto_url = \
        'https://astroconda.org/aux/astropy_mirror/iers_a_1/finals2000A.all'
    iers.conf.auto_max_age = None

    location = EarthLocation.from_geodetic(lon=loc['lon']*u.deg,
                                           lat=loc['lat']*u.deg,
                                           height=loc['alt']*u.m,
                                           ellipsoid='WGS84')
    obstime = Time(timestamps)
    local_frame = AltAz(obstime=obstime, location=location)

    phase_altaz = SkyCoord(alt=90.0*u.deg, az=0.0*u.deg, obstime=obstime,
                           frame='altaz', location=location)
    phase_j2000 = phase_altaz.transform_to('fk5')

    # Get the stokes enums for the polarization types
    corr_types = [[MS_STOKES_ENUMS[p_f] for p_f in pol_feeds]]

    LOGGER.info("Pol Feeds {}".format(pol_feeds))
    LOGGER.info("Correlation Types {}".format(corr_types))

    num_freq_channels = [1]

    ant_table = MSTable(ms_table_name, 'ANTENNA')
    feed_table = MSTable(ms_table_name, 'FEED')
    field_table = MSTable(ms_table_name, 'FIELD')
    pol_table = MSTable(ms_table_name, 'POLARIZATION')
    obs_table = MSTable(ms_table_name, 'OBSERVATION')
    # SOURCE is an optional MS sub-table
    src_table = MSTable(ms_table_name, 'SOURCE')

    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))

    ms_datasets = []
    ddid_datasets = []
    spw_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    num_ant = len(ant_pos)
    position = da.asarray(ant_pos)
    diameter = da.ones(num_ant) * 0.025
    offset = da.zeros((num_ant, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(num_ant)],
                     dtype=object)
    stations = np.array([info['name'] for i in range(num_ant)],
                        dtype=object)

    dataset = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row",), diameter),
        'NAME': (("row",), da.from_array(names, chunks=num_ant)),
        'STATION': (("row",), da.from_array(stations, chunks=num_ant)),
    })
    ant_table.append(dataset)

    ################### Create a FEED dataset ###################################
    # There is one feed per antenna, so this should be quite similar to the
    # ANTENNA
    num_pols = len(pol_feeds)
    pol_types = pol_feeds
    pol_responses = [POL_RESPONSES[ct] for ct in pol_feeds]

    LOGGER.info("Pol Types {}".format(pol_types))
    LOGGER.info("Pol Responses {}".format(pol_responses))

    antenna_ids = da.asarray(range(num_ant))
    feed_ids = da.zeros(num_ant)
    num_receptors = da.zeros(num_ant) + num_pols
    polarization_types = np.array([pol_types for i in range(num_ant)],
                                  dtype=object)
    receptor_angles = np.array([[0.0] for i in range(num_ant)])
    pol_response = np.array([pol_responses for i in range(num_ant)])

    beam_offset = np.array([[[0.0, 0.0]] for i in range(num_ant)])

    dataset = Dataset({
        'ANTENNA_ID': (("row",), antenna_ids),
        'FEED_ID': (("row",), feed_ids),
        'NUM_RECEPTORS': (("row",), num_receptors),
        'POLARIZATION_TYPE': (("row", "receptors",),
                              da.from_array(polarization_types,
                                            chunks=num_ant)),
        'RECEPTOR_ANGLE': (("row", "receptors",),
                           da.from_array(receptor_angles,
                                         chunks=num_ant)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=num_ant)),
        'BEAM_OFFSET': (("row", "receptors", "radec"),
                        da.from_array(beam_offset, chunks=num_ant)),
    })
    feed_table.append(dataset)

    ####################### FIELD dataset #########################################
    direction = [[phase_j2000.ra.radian, phase_j2000.dec.radian]]
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=object), chunks=1)
    # Zero order polynomial in time for phase center.
    field_num_poly = da.zeros(1)

    dir_dims = ("row", 'field-poly', 'field-dir',)

    dataset = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row",), field_num_poly),
        'NAME': (("row",), field_name),
    })
    field_table.append(dataset)

    ######################### OBSERVATION dataset #####################################
    dataset = Dataset({
        'TELESCOPE_NAME': (("row",),
                           da.asarray(np.asarray(['TART'], dtype=object),
                                      chunks=1)),
        'OBSERVER': (("row",),
                     da.asarray(np.asarray(['Tim'], dtype=object),
                                chunks=1)),
        "TIME_RANGE": (("row", "obs-exts"),
                       da.asarray(np.array([[epoch_s, epoch_s + 1]]),
                                  chunks=1)),
    })
    obs_table.append(dataset)

    ######################## SOURCE datasets ########################################
    for src in sources:
        name = src['name']
        # Convert to J2000
        dir_altaz = SkyCoord(alt=src['el']*u.deg, az=src['az']*u.deg,
                             obstime=obstime, frame='altaz',
                             location=location)
        dir_j2000 = dir_altaz.transform_to('fk5')
        direction = [dir_j2000.ra.radian, dir_j2000.dec.radian]
        # LOGGER.info("SOURCE: {}, timestamp: {}".format(name, timestamps))
        dask_num_lines = da.full((1,), 1, dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object), chunks=1)
        dask_time = da.asarray(np.array([epoch_s]))
        dataset = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            "TIME": (("row",), dask_time),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_table.append(dataset)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for corr_type in corr_types:
        corr_prod = [[i, i] for i in range(len(corr_type))]

        corr_prod = np.array(corr_prod)
        LOGGER.info("Corr Prod {}".format(corr_prod))
        LOGGER.info("Corr Type {}".format(corr_type))

        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        LOGGER.info("NUM_CORR {}".format(dask_num_corr))
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        dask_corr_product = da.asarray(corr_prod)[None, :]
        LOGGER.info("Dask Corr Prod {}".format(dask_corr_product.shape))
        LOGGER.info("Dask Corr Type {}".format(dask_corr_type.shape))
        dataset = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
            "CORR_PRODUCT": (("row", "corr", "corrprod_idx"),
                             dask_corr_product),
        })

        pol_table.append(dataset)

    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_freq_channels:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        dask_chan_width = da.full((1, num_chan), 2.5e6/num_chan)

        dataset = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
            "EFFECTIVE_BW": (("row", "chan"), dask_chan_width),
            "RESOLUTION": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(dataset)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_freq_channels)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))

    # Now create the associated MS dataset

    # vis_data, baselines = cal_vis.get_all_visibility()
    # vis_array = np.array(vis_data, dtype=np.complex64)
    chunks = {
        "row": (vis_array.shape[0],),
    }
    baselines = np.array(baselines)
    # LOGGER.info(f"baselines {baselines}")
    bl_pos = np.array(ant_pos)[baselines]
    # Use the - sign to get the same orientation as our tart projections.
    uu_a, vv_a, ww_a = -(bl_pos[:, 1] - bl_pos[:, 0]).T  # /constants.L1_WAVELENGTH

    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        # LOGGER.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_table.datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        LOGGER.info("Data size %s %s %s" % (row, chan, corr))

        # np_data = vis_array.reshape((row, chan, corr))
        np_data = np.zeros((row, chan, corr), dtype=np.complex128)
        for i in range(corr):
            np_data[:, :, i] = vis_array.reshape((row, chan))
        # np_data = np.array([vis_array.reshape((row, chan, 1))
        #                     for i in range(corr)])
        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Shape ordered to match the ('row', 'flagcat', 'chan', 'corr')
        # dims of FLAG_CATEGORY below
        flag_categories = da.from_array(
            0.05*np.ones((row, 1, chan, corr)))
        flag_data = np.zeros((row, chan, corr), dtype=np.bool_)
        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'],
                            dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'FLAG': (dims, da.from_array(flag_data)),
            'TIME': (("row", "corr"),
                     da.from_array(epoch_s*np.ones((row, corr)))),
            'TIME_CENTROID': (("row", "corr"),
                              da.from_array(epoch_s*np.ones((row, corr)))),
            'WEIGHT': (("row", "corr"),
                       da.from_array(0.95*np.ones((row, corr)))),
            'WEIGHT_SPECTRUM': (dims,
                                da.from_array(0.95*np.ones_like(
                                    np_data, dtype=np.float64))),
            'SIGMA_SPECTRUM': (dims,
                               da.from_array(np.ones_like(
                                   np_data, dtype=np.float64)*0.05)),
            'SIGMA': (("row", "corr"),
                      da.from_array(0.05*np.ones((row, corr)))),
            'UVW': (("row", "uvw",), uvw_data),
            'FLAG_CATEGORY': (('row', 'flagcat', 'chan', 'corr'),
                              flag_categories),
            'ANTENNA1': (("row",), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row",), da.from_array(baselines[:, 1])),
            'FEED1': (("row",), da.from_array(baselines[:, 0])),
            'FEED2': (("row",), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row",), dask_ddid)
        })
        ms_datasets.append(dataset)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name,
                               columns="ALL")

    dask.compute(ms_writes)

    ant_table.write()
    feed_table.write()
    field_table.write()
    pol_table.write()
    obs_table.write()
    src_table.write()

    dask.compute(spw_writes)
    dask.compute(ddid_writes)
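# A short, hedged check of the result, assuming an MS written by a function
# like ms_create above under the hypothetical name "tart.ms": reopen it with
# daskms and inspect the per-dataset shapes and SPW frequencies.
import dask
from daskms import xds_from_ms, xds_from_table

ms_name = "tart.ms"  # hypothetical output of ms_create

main_ds = xds_from_ms(ms_name)
spw_ds = xds_from_table("::".join((ms_name, "SPECTRAL_WINDOW")))

for ds in main_ds:
    # One dataset per FIELD_ID/DATA_DESC_ID grouping
    print(ds.attrs, dict(ds.dims))

# CHAN_FREQ is lazy; compute it to inspect the actual frequencies
print(dask.compute(spw_ds[0].CHAN_FREQ.data)[0])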
def write(self):
    writes = xds_to_table(self.datasets, self.table_name, columns="ALL")
    dask.compute(writes)
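# For context, a minimal sketch of the kind of helper this write method
# belongs to; the MSTable wrapper shown here is an assumption modelled on
# the ms_create snippet above, not a daskms class.
import dask
import dask.array as da
from daskms import Dataset, xds_to_table

class MSTable:
    """Hypothetical wrapper collecting Datasets destined for one subtable."""

    def __init__(self, ms_name, subtable):
        self.table_name = "::".join((ms_name, subtable))
        self.datasets = []

    def append(self, dataset):
        self.datasets.append(dataset)

    def write(self):
        writes = xds_to_table(self.datasets, self.table_name, columns="ALL")
        dask.compute(writes)

# Usage: accumulate datasets, then write them in one shot
ant_table = MSTable("tart.ms", "ANTENNA")
ant_table.append(Dataset({'POSITION': (("row", "xyz"), da.zeros((4, 3)))}))
ant_table.write()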
ncomps = idx_nz.size
model_predict = da.from_array(model_predict,
                              chunks=(ncomps, nchan, ncorr))
lm = da.from_array(lm[idx_nz, :], chunks=(ncomps, 2))
ms_freqs = spw_ds.CHAN_FREQ.data

xds = xds_from_ms(args.ms, columns=["UVW", args.colname],
                  chunks={"row": args.row_chunks})[0]
uvw = xds.UVW.data
vis = im_to_vis(model_predict, uvw, lm, ms_freqs)

data = getattr(xds, args.colname)
if data.shape != vis.shape:
    print("Assuming only Stokes I passed in")
    if vis.shape[-1] == 1 and data.shape[-1] == 4:
        tmp_zero = da.zeros(vis.shape, chunks=(args.row_chunks, nchan, 1))
        vis = da.concatenate((vis, tmp_zero, tmp_zero, vis), axis=-1)
    elif vis.shape[-1] == 1 and data.shape[-1] == 2:
        vis = da.concatenate((vis, vis), axis=-1)
    else:
        raise ValueError("Incompatible corr axes")
    vis = vis.rechunk((args.row_chunks, nchan, data.shape[-1]))

# Assign visibilities to MODEL_DATA array on the dataset
xds = xds.assign(**{args.colname: (("row", "chan", "corr"), vis)})
# Create a write to the table
write = xds_to_table(xds, args.ms, [args.colname])

# Submit all graph computations in parallel
with ProgressBar():
    dask.compute(write)
def _predict(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError("You have to specify nthreads when using a "
                             "distributed scheduler")
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError("You have to specify mem-limit when using a "
                             "distributed scheduler")
        import psutil
        # 100% of memory by default
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)

    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms.utils import dataset_type
    mstype = dataset_type(ms[0])
    if mstype == 'casa':
        from daskms import xds_to_table
    elif mstype == 'zarr':
        from daskms.experimental.zarr import xds_to_zarr as xds_to_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import model as im2vis
    from pfb.utils.fits import load_fits
    from pfb.utils.misc import restore_corrs, plan_row_chunk
    from astropy.io import fits

    # always returns 4D
    # gridder expects freq axis
    model = np.atleast_3d(load_fits(args.model).squeeze())
    nband, nx, ny = model.shape
    hdr = fits.getheader(args.model)
    cell_d = np.abs(hdr['CDELT1'])
    cell_rad = np.deg2rad(cell_d)

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, \
        chan_chunks = chan_to_band_mapping(ms, nband=nband)

    # degridder memory budget
    max_chan_chunk = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())

    # assumes number of correlations are the same across MS/SPW
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    if args.output_type is not None:
        output_type = np.dtype(args.output_type)
    else:
        output_type = np.result_type(np.dtype(args.real_type),
                                     np.complex64)
    data_bytes = output_type.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row  # model
    memory_per_row += 3 * 8  # uvw

    if mstype == 'zarr':
        if args.model_column in xds[0].keys():
            model_chunks = getattr(xds[0], args.model_column).data.chunks
        else:
            model_chunks = xds[0].DATA.data.chunks
            print('Chunking model same as data')

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a
    # single imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks "
          "per node" % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
          file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({'row': row_chunk,
                                'chan': chan_chunks[ims][spw]['chan']})

    model = da.from_array(model.astype(args.real_type),
                          chunks=(1, nx, ny), name=False)
    writes = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        # note the trailing comma: a bare ('UVW') is just a string
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=('UVW',))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match (first PHASE_DIR seen is the reference)
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            uvw = clone(ds.UVW.data)

            bands = band_mapping[ims][spw]
            # slice out this SPW's bands without clobbering the full model
            model_b = model[list(bands), :, :]
            vis = im2vis(uvw,
                         freqs[ims][spw],
                         model_b,
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         cell_rad,
                         nthreads=ngridder_threads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack)

            model_vis = restore_corrs(vis, ncorr)
            if mstype == 'zarr':
                model_vis = model_vis.rechunk(model_chunks)
                uvw = uvw.rechunk((model_chunks[0], 3))

            out_ds = ds.assign(**{args.model_column:
                                  (("row", "chan", "corr"), model_vis),
                                  'UVW': (("row", "three"), uvw)})
            # out_ds = ds.assign(**{args.model_column:
            #                       (("row", "chan", "corr"), model_vis)})
            out_data.append(out_ds)
        writes.append(xds_to_table(out_data, ims,
                                   columns=[args.model_column]))

    dask.visualize(*writes,
                   filename=args.output_filename + '_predict_graph.pdf',
                   optimize_graph=False, collapse_outputs=True)

    if not args.mock:
        with performance_report(filename=args.output_filename +
                                '_predict_per.html'):
            dask.compute(writes, optimize_graph=False)

    print("All done here.", file=log)
def predict(args):
    # Convert source data into dask arrays
    sky_model = parse_sky_model(args.sky_model, args.model_chunks)

    # Get the support tables
    tables = support_tables(args)

    ant_ds = tables["ANTENNA"]
    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Perform subtable joins
        ant = ant_ds[0]
        field = field_ds[xds.attrs["FIELD_ID"]]
        ddid = ddid_ds[xds.attrs["DATA_DESC_ID"]]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]

        # Select single dataset row out
        corrs = pol.NUM_CORR.data[0]

        # Generate visibility expressions for each source type
        source_vis = [vis_factory(args, stype, sky_model,
                                  xds, ant, field, spw, pol)
                      for stype in sky_model.keys()]

        # Sum visibilities together
        vis = sum(source_vis)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4,))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = (xds.assign(MODEL_DATA=(("row", "chan", "corr"), vis))
               if args.data_column == "MODEL_DATA" else
               xds.assign(CORRECTED_DATA=(("row", "chan", "corr"), vis)))
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.data_column])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        da.compute(writes)
def _predict(args):
    # get inclusion regions
    include_regions = load_regions(args.within) if args.within else []

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    (comp_type, radec, stokes, spec_coeff,
     ref_freq, log_spec_ind, gaussian_shape) = import_from_wsclean(
        args.sky_model, include_regions=include_regions,
        point_only=args.points_only, num=args.num_sources or None)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW",
               "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    max_num_chan = max([ss.NUM_CHAN.data[0] for ss in spw_ds])
    max_num_corr = max([ss.NUM_CORR.data[0] for ss in pol_ds])

    # Perform resource budgeting
    args.row_chunks, args.model_chunks = get_budget(
        comp_type.shape[0], ms_rows, max_num_chan, max_num_corr,
        ms_datatype, args)

    radec = da.from_array(radec, chunks=(args.model_chunks, 2))
    stokes = da.from_array(stokes, chunks=(args.model_chunks, 4))

    if np.count_nonzero(comp_type == 'GAUSSIAN') > 0:
        gaussian_components = True
        gshape_chunks = (args.model_chunks, 3)
        gaussian_shape = da.from_array(gaussian_shape,
                                       chunks=gshape_chunks)
    else:
        gaussian_components = False

    if args.spectra:
        spec_chunks = (args.model_chunks, spec_coeff.shape[1])
        spec_coeff = da.from_array(spec_coeff, chunks=spec_chunks)
        ref_freq = da.from_array(ref_freq, chunks=(args.model_chunks,))

    # List of write operations
    writes = []

    # Construct a graph for each FIELD and DATA DESCRIPTOR
    datasets = xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks})

    select_fields = valid_field_ids(field_ds, args.fields)

    for xds in filter_datasets(datasets, select_fields):
        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]
        frequency = spw.CHAN_FREQ.data[0]

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data[0][0])

        if args.exp_sign_convention == 'casa':
            uvw = -xds.UVW.data
        elif args.exp_sign_convention == 'thompson':
            uvw = xds.UVW.data
        else:
            raise ValueError("Invalid sign convention '%s'"
                             % args.exp_sign_convention)

        if args.spectra:
            # flux density at reference frequency ...
            # ... for logarithmic polynomial functions
            if log_spec_ind:
                Is = da.log(stokes[:, 0, None]) * frequency[None, :]**0
            # ... or for ordinary polynomial functions
            else:
                Is = stokes[:, 0, None] * frequency[None, :]**0
            # additional terms of SED ...
            for jj in range(spec_coeff.shape[1]):
                # ... for logarithmic polynomial functions
                if log_spec_ind:
                    Is += spec_coeff[:, jj, None] * \
                        da.log((frequency[None, :] /
                                ref_freq[:, None])**(jj+1))
                # ... or for ordinary polynomial functions
                else:
                    Is += spec_coeff[:, jj, None] * \
                        (frequency[None, :]/ref_freq[:, None] - 1)**(jj+1)
            if log_spec_ind:
                Is = da.exp(Is)
            Qs = da.zeros_like(Is)
            Us = da.zeros_like(Is)
            Vs = da.zeros_like(Is)
            # stack along new axis and make it the last axis of the
            # new array
            spectrum = da.stack([Is, Qs, Us, Vs], axis=-1)
            spectrum = spectrum.rechunk(spectrum.chunks[:2] +
                                        (spectrum.shape[2],))

        print('-------------------------------------------')
        print('Nr sources = {0:d}'.format(stokes.shape[0]))
        print('-------------------------------------------')
        print('stokes.shape = {0:}'.format(stokes.shape))
        print('frequency.shape = {0:}'.format(frequency.shape))
        if args.spectra:
            print('Is.shape = {0:}'.format(Is.shape))
        if args.spectra:
            print('spectrum.shape = {0:}'.format(spectrum.shape))

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)

        # If at least one Gaussian component is present in the component
        # list then all sources are modelled as Gaussian components
        # (Delta components have zero width)
        if gaussian_components:
            phase *= gaussian(uvw, frequency, gaussian_shape)

        # (source, frequency, corr_products)
        brightness = convert(spectrum if args.spectra else stokes,
                             ["I", "Q", "U", "V"], corr_schema(pol))

        print('brightness.shape = {0:}'.format(brightness.shape))
        print('phase.shape = {0:}'.format(phase.shape))
        print('-------------------------------------------')
        print('Attempting phase-brightness einsum with "{0:s}"'.format(
            einsum_schema(pol, args.spectra)))

        # (source, row, frequency, corr_products)
        jones = da.einsum(einsum_schema(pol, args.spectra),
                          phase, brightness)
        print('jones.shape = {0:}'.format(jones.shape))
        print('-------------------------------------------')
        if gaussian_components:
            print('Some Gaussian sources found')
        else:
            print('All sources are Delta functions')
        print('-------------------------------------------')

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data,
                          xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4,))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(
            **{args.output_column: (("row", "chan", "corr"), vis)})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    with ExitStack() as stack:
        if sys.stdout.isatty():
            # Default progress bar in user terminal
            stack.enter_context(ProgressBar())
        else:
            # Only display progress after 2 minutes,
            # updating at 5 second intervals
            stack.enter_context(ProgressBar(minimum=2 * 60, dt=5))

        # Submit all graph computations in parallel
        dask.compute(writes)
def simulate(args):
    # get full time column and compute row chunks
    ms = table(args.ms)
    time = ms.getcol('TIME')
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        time, args.utimes_per_chunk)
    # convert to dask arrays
    tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts,
                                chunks=(args.utimes_per_chunk))
    n_time = tbin_idx.size
    ant1 = ms.getcol('ANTENNA1')
    ant2 = ms.getcol('ANTENNA2')
    n_ant = np.maximum(ant1.max(), ant2.max()) + 1
    flag = ms.getcol("FLAG")
    n_row, n_freq, n_corr = flag.shape
    if n_corr == 4:
        model_corr = (2, 2)
        jones_corr = (2,)
    elif n_corr == 2:
        model_corr = (2,)
        jones_corr = (2,)
    elif n_corr == 1:
        model_corr = (1,)
        jones_corr = (1,)
    else:
        raise RuntimeError("Invalid number of correlations")
    ms.close()

    # get phase dir
    radec0 = table(args.ms + '::FIELD').getcol('PHASE_DIR').squeeze()

    # get freqs
    freq = table(
        args.ms + '::SPECTRAL_WINDOW').getcol('CHAN_FREQ')[0].astype(
            np.float64)
    assert freq.size == n_freq

    # get source coordinates from lsm
    lsm = Tigger.load(args.sky_model)
    radec = []
    stokes = []
    spi = []
    ref_freqs = []

    for source in lsm.sources:
        radec.append([source.pos.ra, source.pos.dec])
        stokes.append([source.flux.I])
        tmp_spec = source.spectrum
        spi.append([tmp_spec.spi if tmp_spec is not None else 0.0])
        ref_freqs.append([tmp_spec.freq0 if tmp_spec is not None else 1.0])

    n_dir = len(stokes)
    radec = np.asarray(radec)
    lm = radec_to_lm(radec, radec0)

    # load in the model file
    model = np.zeros((n_freq, n_dir) + model_corr)
    stokes = np.asarray(stokes)
    ref_freqs = np.asarray(ref_freqs)
    spi = np.asarray(spi)
    for d in range(n_dir):
        Stokes_I = stokes[d] * (freq / ref_freqs[d])**spi[d]
        if n_corr == 4:
            model[:, d, 0, 0] = Stokes_I
            model[:, d, 1, 1] = Stokes_I
        elif n_corr == 2:
            model[:, d, 0] = Stokes_I
            model[:, d, 1] = Stokes_I
        else:
            model[:, d, 0] = Stokes_I

    # append antenna columns
    cols = []
    cols.append('ANTENNA1')
    cols.append('ANTENNA2')
    cols.append('UVW')

    # load in gains
    jones, alphas = make_screen(lm, freq, n_time, n_ant, jones_corr[0])
    jones = jones.astype(np.complex128)
    jones_shape = jones.shape
    jones_da = da.from_array(jones, chunks=(args.utimes_per_chunk,)
                             + jones_shape[1::])

    freqs = da.from_array(freq, chunks=(n_freq))
    lm = da.from_array(np.tile(lm[None], (n_time, 1, 1)),
                       chunks=(args.utimes_per_chunk, n_dir, 2))
    # change model to dask array
    tmp_shape = (n_time,)
    for i in range(len(model.shape)):
        tmp_shape += (1,)
    model = da.from_array(np.tile(model[None], tmp_shape),
                          chunks=(args.utimes_per_chunk,) + model.shape)

    # load data in in chunks and apply gains to each chunk
    xds = xds_from_ms(args.ms, columns=cols,
                      chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data
    uvw = xds.UVW.data

    # apply gains
    data = compute_and_corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                                   jones_da, model, uvw, freqs, lm)

    # Assign visibilities to args.out_col and write to ms
    xds = xds.assign(**{args.out_col:
                        (("row", "chan", "corr"),
                         data.reshape(n_row, n_freq, n_corr))})
    # Create a write to the table
    write = xds_to_table(xds, args.ms, [args.out_col])

    # Submit all graph computations in parallel
    with ProgressBar():
        write.compute()

    return jones, alphas
ant2 = xds.ANTENNA2.data

model = []
for col in model_cols:
    model.append(getattr(xds, col).data)
model = da.stack(model, axis=2).rechunk({2: 3})

# reshape the correlation axis
if model.shape[-1] > 2:
    n_row, n_chan, n_dir, n_corr = model.shape
    model = model.reshape(n_row, n_chan, n_dir, 2, 2)
    reshape_vis = True
else:
    reshape_vis = False

# apply gains
corrupted_data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                             jones, model)

if reshape_vis:
    corrupted_data = corrupted_data.reshape(n_row, n_chan, n_corr)

# Assign visibilities to args.out_col and write to ms
xds = xds.assign(**{args.out_col: (("row", "chan", "corr"),
                                   corrupted_data)})
# Create a write to the table
write = xds_to_table(xds, args.ms, [args.out_col])

# Submit all graph computations in parallel
with ProgressBar():
    write.compute()
def both(args):
    """Generate model data, corrupted visibilities and
    gains (phase-only or normal)"""

    # Set thread count to cpu count
    if args.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(args.ncpu))
    else:
        import multiprocessing
        args.ncpu = multiprocessing.cpu_count()

    # Get full time column and compute row chunks
    ms = xds_from_table(args.ms)[0]
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        ms.TIME, args.utimes_per_chunk)

    # Convert time rows to dask arrays
    tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts,
                                chunks=(args.utimes_per_chunk))

    # Time axis
    n_time = tbin_idx.size

    # Get antenna columns
    ant1 = ms.ANTENNA1.data
    ant2 = ms.ANTENNA2.data

    # No. of antennas axis
    n_ant = (np.maximum(ant1.max(), ant2.max()) + 1).compute()

    # Get flag column
    flag = ms.FLAG.data

    # Get convention
    if args.phase_convention == 'CASA':
        uvw = -(ms.UVW.data.astype(np.float64))
    elif args.phase_convention == 'CODEX':
        uvw = ms.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase")

    # Get rest of dimensions
    n_row, n_freq, n_corr = flag.shape

    # Raise error if correlation axis too small
    if n_corr != 4:
        raise NotImplementedError("Only 4 correlations "
                                  "currently supported")

    # Get phase direction
    radec0_table = xds_from_table(args.ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(args.ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_freq

    # Check for sky-model
    if args.sky_model == 'MODEL-1.txt':
        args.sky_model = MODEL_1
    elif args.sky_model == 'MODEL-4.txt':
        args.sky_model = MODEL_4
    elif args.sky_model == 'MODEL-50.txt':
        args.sky_model = MODEL_50
    else:
        raise NotImplementedError(f"Sky-model {args.sky_model} not in "
                                  "kalcal/datasets/sky_model/")

    # Build source model from lsm
    lsm = Tigger.load(args.sky_model)

    # Direction axis
    n_dir = len(lsm.sources)

    # Create initial model array
    model = np.zeros((n_dir, n_freq, n_corr), dtype=np.float64)

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    source_names = []

    # Cycle coordinates creating a source with flux
    for d, source in enumerate(lsm.sources):
        # Extract name
        source_names.append(source.name)

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

        # Get flux - Stokes I
        if source.flux.I:
            I0 = source.flux.I

            # Get spectrum (only spi currently supported)
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 0] = I0 * (freq/ref_freq)**spi

        # Get flux - Stokes Q
        if source.flux.Q:
            Q0 = source.flux.Q

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 1] = Q0 * (freq/ref_freq)**spi

        # Get flux - Stokes U
        if source.flux.U:
            U0 = source.flux.U

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 2] = U0 * (freq/ref_freq)**spi

        # Get flux - Stokes V
        if source.flux.V:
            V0 = source.flux.V

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 3] = V0 * (freq/ref_freq)**spi

    # Generate gains
    jones = None
    jones_shape = None

    # Dask to NP
    t = tbin_idx.compute()
    nu = freq.compute()

    print('==> Both-mode')
    if args.mode == "phase":
        jones = phase_gains(lm, nu, n_time, n_ant, args.alpha_std)
    elif args.mode == "normal":
        jones = normal_gains(t, nu, lm, n_ant, n_corr,
                             args.sigma_f, args.lt, args.lnu, args.ls)
    else:
        raise ValueError("Only normal and phase modes available.")
    print()

    # Reduce jones to diagonals only
    jones = jones[:, :, :, :, (0, -1)]

    # Jones to complex
    jones = jones.astype(np.complex128)

    # Jones shape
    jones_shape = jones.shape

    # Generate filename
    if args.out == "":
        args.out = f"{args.mode}.npy"

    # Save gains and settings to file
    with open(args.out, 'wb') as file:
        np.save(file, jones)

    # Build dask graph
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones_da = da.from_array(jones, chunks=(args.utimes_per_chunk,)
                             + jones_shape[1::])

    # Append antenna columns
    cols = []
    cols.append('ANTENNA1')
    cols.append('ANTENNA2')
    cols.append('UVW')

    # Load data in chunks and apply gains to each chunk
    xds = xds_from_ms(args.ms, columns=cols,
                      chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data

    # Adjust UVW based on phase-convention
    if args.phase_convention == 'CASA':
        uvw = -xds.UVW.data.astype(np.float64)
    elif args.phase_convention == 'CODEX':
        uvw = xds.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase")

    # Get model visibilities
    model_vis = np.zeros((n_row, n_freq, n_dir, n_corr),
                         dtype=np.complex128)

    for s in range(n_dir):
        model_vis[:, :, s] = im_to_vis(
            model[s].reshape((1, n_freq, n_corr)),
            uvw, lm[s].reshape((1, 2)), freq,
            dtype=np.complex64, convention='fourier')

    # NP to Dask
    model_vis = da.from_array(model_vis,
                              chunks=(row_chunks, n_freq, n_dir, n_corr))

    # Convert Stokes to corr
    in_schema = ['I', 'Q', 'U', 'V']
    out_schema = [['RR', 'RL'], ['LR', 'LL']]
    model_vis = convert(model_vis, in_schema, out_schema)

    # Apply gains
    data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                       jones_da, model_vis).reshape(
                           (n_row, n_freq, n_corr))

    # Assign model visibilities
    out_names = []
    for d in range(n_dir):
        xds = xds.assign(
            **{source_names[d]:
               (("row", "chan", "corr"),
                model_vis[:, :, d].reshape(
                    n_row, n_freq, n_corr).astype(np.complex64))})
        out_names += [source_names[d]]

    # Assign noise free visibilities to 'CLEAN_DATA'
    xds = xds.assign(**{'CLEAN_DATA': (("row", "chan", "corr"),
                                       data.astype(np.complex64))})
    out_names += ['CLEAN_DATA']

    # Get noise realisation
    if args.sigma_n > 0.0:
        # Noise matrix
        noise = (da.random.normal(loc=0.0, scale=args.sigma_n,
                                  size=(n_row, n_freq, n_corr),
                                  chunks=(row_chunks, n_freq, n_corr))
                 + 1.0j*da.random.normal(
                     loc=0.0, scale=args.sigma_n,
                     size=(n_row, n_freq, n_corr),
                     chunks=(row_chunks, n_freq, n_corr)))/np.sqrt(2.0)

        # Zero matrix for off-diagonals
        zero = da.zeros_like(noise[:, :, 0])

        # Dask to NP
        noise = noise.compute()
        zero = zero.compute()

        # Remove noise on off-diagonals
        noise[:, :, 1] = zero[:, :]
        noise[:, :, 2] = zero[:, :]

        # NP to Dask
        noise = da.from_array(noise,
                              chunks=(row_chunks, n_freq, n_corr))

        # Assign noise to 'NOISE'
        xds = xds.assign(**{'NOISE': (("row", "chan", "corr"),
                                      noise.astype(np.complex64))})
        out_names += ['NOISE']

        # Add noise to data and assign to 'DATA'
        noisy_data = data + noise
        xds = xds.assign(**{'DATA': (("row", "chan", "corr"),
                                     noisy_data.astype(np.complex64))})
        out_names += ['DATA']

    # Create a write to the table
    write = xds_to_table(xds, args.ms, out_names)

    # Submit all graph computations in parallel
    with ProgressBar():
        write.compute()

    print(f"==> Applied Jones to MS: {args.ms} <--> {args.out}")
def _main(args):
    tic = time.time()

    log.info(banner())

    if args.disable_post_mortem:
        log.warning("Disabling crash debugging with the "
                    "Interactive Python Debugger, as per user request")
        post_mortem_handler.disable_pdb_on_error()

    log.info("Flagging on the {0:s} column".format(args.data_column))
    data_column = args.data_column
    masked_channels = [load_mask(fn, dilate=args.dilate_masks)
                       for fn in collect_masks()]
    GD = args.config

    log_configuration(args)

    # Group datasets by these columns
    group_cols = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"]
    # Index datasets by these columns
    index_cols = ['TIME']

    # Reopen the datasets using the aggregated row ordering
    columns = [data_column, "FLAG", "TIME", "ANTENNA1", "ANTENNA2"]

    if args.subtract_model_column is not None:
        columns.append(args.subtract_model_column)

    xds = list(xds_from_ms(args.ms,
                           columns=tuple(columns),
                           group_cols=group_cols,
                           index_cols=index_cols,
                           chunks={"row": args.row_chunks}))

    # Get support tables
    st = support_tables(args.ms)
    ddid_ds = st["DATA_DESCRIPTION"]
    field_ds = st["FIELD"]
    pol_ds = st["POLARIZATION"]
    spw_ds = st["SPECTRAL_WINDOW"]
    ant_ds = st["ANTENNA"]

    assert len(ant_ds) == 1
    assert len(ddid_ds) == 1

    antspos = ant_ds[0].POSITION.data
    antsnames = ant_ds[0].NAME.data
    fieldnames = [fds.NAME.data[0] for fds in field_ds]

    avail_scans = [ds.SCAN_NUMBER for ds in xds]
    args.scan_numbers = list(set(avail_scans).intersection(
        args.scan_numbers if args.scan_numbers is not None else avail_scans))

    if args.scan_numbers != []:
        log.info("Only considering scans '{0:s}' as "
                 "per user selection criterion".format(
                     ", ".join(map(str, map(int, args.scan_numbers)))))

    if args.field_names != []:
        flatten_field_names = []
        for f in args.field_names:
            # accept comma lists per specification
            flatten_field_names += [x.strip() for x in f.split(",")]
        # iterate over a copy, since resolved index references are
        # appended to the list as names
        for f in list(flatten_field_names):
            if re.match(r"^\d+$", f) and int(f) < len(fieldnames):
                flatten_field_names.append(fieldnames[int(f)])
        flatten_field_names = list(set(
            filter(lambda x: not re.match(r"^\d+$", x),
                   flatten_field_names)))
        log.info("Only considering fields '{0:s}' for flagging per "
                 "user selection criterion.".format(
                     ", ".join(flatten_field_names)))
        if not set(flatten_field_names) <= set(fieldnames):
            raise ValueError("One or more fields cannot be "
                             "found in dataset '{0:s}' "
                             "You specified {1:s}, but "
                             "only {2:s} are available".format(
                                 args.ms,
                                 ",".join(flatten_field_names),
                                 ",".join(fieldnames)))
        field_dict = {fieldnames.index(fn): fn for fn in flatten_field_names}
    else:
        field_dict = {i: fn for i, fn in enumerate(fieldnames)}

    # Lists which hold our dask compute graphs for each dataset
    write_computes = []
    original_stats = []
    final_stats = []

    # Iterate through each dataset
    for ds in xds:
        if ds.FIELD_ID not in field_dict:
            continue

        if (args.scan_numbers is not None and
                ds.SCAN_NUMBER not in args.scan_numbers):
            continue

        log.info("Adding field '{0:s}' scan {1:d} to "
                 "compute graph for processing".format(
                     field_dict[ds.FIELD_ID], ds.SCAN_NUMBER))

        ddid = ddid_ds[ds.attrs['DATA_DESC_ID']]
        spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]]

        nrow, nchan, ncorr = getattr(ds, data_column).data.shape

        # Visibilities from the dataset
        vis = getattr(ds, data_column).data
        if args.subtract_model_column is not None:
            log.info("Forming residual data between '{0:s}' and "
                     "'{1:s}' for flagging.".format(
                         data_column, args.subtract_model_column))
            vismod = getattr(ds, args.subtract_model_column).data
            vis = vis - vismod

        antenna1 = ds.ANTENNA1.data
        antenna2 = ds.ANTENNA2.data
        chan_freq = spw_info.CHAN_FREQ.data[0]
        chan_width = spw_info.CHAN_WIDTH.data[0]

        # Generate unflagged defaults if we should ignore existing flags,
        # otherwise take flags from the dataset
        if args.ignore_flags is True:
            flags = da.full_like(vis, False, dtype=bool)
            log.critical("Completely ignoring measurement set "
                         "flags as per '-if' request. "
                         "Strategy will NOT be OR'ed with the original "
                         "flags, even if specified!")
        else:
            flags = ds.FLAG.data

        # If we're flagging on polarised intensity,
        # we convert visibilities to polarised intensity
        # and any flagged correlation will flag the entire visibility
        if args.flagging_strategy == "polarisation":
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items() if k != "I")
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "total_power":
            if args.subtract_model_column is None:
                log.critical("You requested to flag total quadrature "
                             "power, but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of "
                             "off-axis sources for broadband RFI.")
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items())
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "standard":
            if args.subtract_model_column is None:
                log.critical("You requested to flag per correlation, "
                             "but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of off-axis sources "
                             "for broadband RFI.")
        else:
            raise ValueError("Invalid flagging strategy '%s'" %
                             args.flagging_strategy)

        ubl = unique_baselines(antenna1, antenna2)
        utime, time_inv = da.unique(ds.TIME.data, return_inverse=True)
        utime, ubl = dask.compute(utime, ubl)
        ubl = ubl.view(np.int32).reshape(-1, 2)

        # Stack the baseline index with the unique baselines
        bl_range = np.arange(ubl.shape[0], dtype=ubl.dtype)[:, None]
        ubl = np.concatenate([bl_range, ubl], axis=1)
        ubl = da.from_array(ubl, chunks=(args.baseline_chunks, 3))

        vis_windows, flag_windows = pack_data(time_inv, ubl,
                                              antenna1, antenna2,
                                              vis, flags, utime.shape[0],
                                              backend=args.window_backend,
                                              path=args.temporary_directory)

        original_stats.append(window_stats(flag_windows, ubl, chan_freq,
                                           antsnames, ds.SCAN_NUMBER,
                                           field_dict[ds.FIELD_ID],
                                           ds.attrs['DATA_DESC_ID']))

        with StrategyExecutor(antspos, ubl, chan_freq, chan_width,
                              masked_channels, GD['strategies']) as se:
            flag_windows = se.apply_strategies(flag_windows, vis_windows)

        final_stats.append(window_stats(flag_windows, ubl, chan_freq,
                                        antsnames, ds.SCAN_NUMBER,
                                        field_dict[ds.FIELD_ID],
                                        ds.attrs['DATA_DESC_ID']))

        # Unpack window data for writing back to the MS
        unpacked_flags = unpack_data(antenna1, antenna2, time_inv,
                                     ubl, flag_windows)

        # Flag entire visibility if any correlations are flagged
        equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
        corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))

        if corr_flags.chunks != ds.FLAG.data.chunks:
            raise ValueError("Output flag chunking does not "
                             "match input flag chunking")

        # Create new dataset containing new flags
        new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

        # Write back to original dataset
        writes = xds_to_table(new_ds, args.ms, "FLAG")
        # original should also have .compute called because we need stats
        write_computes.append(writes)

    if len(write_computes) > 0:
        # Combine stats from all datasets
        original_stats = combine_window_stats(original_stats)
        final_stats = combine_window_stats(final_stats)

        with contextlib.ExitStack() as stack:
            # Create dask profiling contexts
            profilers = []

            if can_profile:
                profilers.append(stack.enter_context(Profiler()))
                profilers.append(stack.enter_context(CacheProfiler()))
                profilers.append(stack.enter_context(ResourceProfiler()))

            if sys.stdout.isatty():
                # Interactive terminal, default ProgressBar
                stack.enter_context(ProgressBar())
            else:
                # Non-interactive, emit a bar every 5 minutes so
                # as not to spam the log
                stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

            _, original_stats, final_stats = dask.compute(write_computes,
                                                          original_stats,
                                                          final_stats)

        if can_profile:
            visualize(profilers)

        toc = time.time()

        # Log each summary line
        for line in summarise_stats(final_stats, original_stats):
            log.info(line)

        elapsed = toc - tic
        log.info("Data flagged successfully in "
                 "{0:02.0f}h{1:02.0f}m{2:02.0f}s".format(
                     (elapsed // 60) // 60,
                     (elapsed // 60) % 60,
                     elapsed % 60))
    else:
        log.info("User data selection criteria resulted in empty dataset. "
                 "Nothing to be done. Bye!")
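# A minimal, self-contained sketch of the flag "equalisation" step above:
# if any correlation of a visibility is flagged, all of its correlations
# are flagged. The shapes and chunking below are illustrative assumptions,
# not values taken from a real Measurement Set.
import numpy as np
import dask.array as da

nrow, nchan, ncorr = 100, 64, 4
flags = da.from_array(np.random.rand(nrow, nchan, ncorr) > 0.9,
                      chunks=(25, 64, 4))

# Any flagged correlation flags the entire visibility
equalized = da.sum(flags, axis=2, keepdims=True) > 0
corr_flags = da.broadcast_to(equalized, (nrow, nchan, ncorr))

# Chunking survives the round trip, as the check in _main requires
assert corr_flags.chunks == flags.chunks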
def _jones2col(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, (list, ListConfig)):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from daskms.experimental.zarr import xds_from_zarr
    from daskms import xds_from_ms, xds_to_table
    import dask.array as da
    import dask
    from africanus.calibration.utils import chunkify_rows
    from africanus.calibration.utils.dask import corrupt_vis

    # get net gains
    G = xds_from_zarr(args.gain_table + '::G')

    # chunking info
    t_chunks = G[0].t_chunk.data
    if len(t_chunks) > 1:
        t_chunks = G[0].t_chunk.data[1:-1] - G[0].t_chunk.data[0:-2]
        assert (t_chunks == t_chunks[0]).all()
        utpc = t_chunks[0]
    else:
        utpc = t_chunks[0]

    times = xds_from_ms(args.ms[0],
                        columns=['TIME'])[0].get('TIME').data.compute()
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        times, utimes_per_chunk=utpc, daskify_idx=True)

    f_chunks = G[0].f_chunk.data
    if len(f_chunks) > 1:
        f_chunks = G[0].f_chunk.data[1:-1] - G[0].f_chunk.data[0:-2]
        assert (f_chunks == f_chunks[0]).all()
        chan_chunks = f_chunks[0]
    else:
        if f_chunks[0]:
            chan_chunks = f_chunks[0]
        else:
            chan_chunks = -1

    columns = ('DATA', 'FLAG', 'FLAG_ROW', 'ANTENNA1', 'ANTENNA2')
    if args.acol is not None:
        columns += (args.acol,)

    # open MS
    xds = xds_from_ms(args.ms[0],
                      chunks={'row': row_chunks, 'chan': chan_chunks},
                      columns=columns,
                      group_cols=('FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'))

    # Current hack probably only works for single field and DDID
    if len(xds) != len(G):
        raise ValueError("Number of datasets in gains does not "
                         "match those in MS")

    # assuming scans are aligned
    out_data = []
    for g, ds in zip(G, xds):
        if g.SCAN_NUMBER != ds.SCAN_NUMBER:
            raise ValueError("Scans not aligned")

        nrow = ds.dims['row']
        nchan = ds.dims['chan']
        ncorr = ds.dims['corr']

        # need to swap axes for africanus
        jones = da.swapaxes(g.gains.data, 1, 2)
        flag = ds.FLAG.data
        frow = ds.FLAG_ROW.data
        ant1 = ds.ANTENNA1.data
        ant2 = ds.ANTENNA2.data

        frow = (frow | (ant1 == ant2))
        flag = (flag[:, :, 0] | flag[:, :, -1])
        flag = da.logical_or(flag, frow[:, None])

        if args.acol is not None:
            acol = ds.get(args.acol).data.reshape(nrow, nchan, 1, ncorr)
        else:
            acol = da.ones((nrow, nchan, 1, ncorr),
                           chunks=(row_chunks, chan_chunks, 1, -1),
                           dtype=jones.dtype)

        cvis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones, acol)

        # compare where unflagged
        if args.compareto is not None:
            flag = flag.compute()
            vis = ds.get(args.compareto).values[~flag]
            print("Max abs difference = ",
                  np.abs(cvis.compute()[~flag] - vis).max())
            quit()

        out_ds = ds.assign(**{args.mueller_column:
                              (("row", "chan", "corr"), cvis)})
        out_data.append(out_ds)

    writes = xds_to_table(out_data, args.ms[0],
                          columns=[args.mueller_column])
    dask.compute(writes)
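# An illustrative pure-numpy version of the row chunking performed by
# africanus' chunkify_rows above: rows are grouped so that each dask chunk
# holds an integral number of unique timestamps. The toy TIME column and
# the two-utimes-per-chunk setting are assumptions for demonstration.
import numpy as np

time = np.array([0., 0., 0., 1., 1., 2., 2., 2., 3.])  # toy TIME column
utimes, counts = np.unique(time, return_counts=True)

utimes_per_chunk = 2
row_chunks = tuple(counts[i:i + utimes_per_chunk].sum()
                   for i in range(0, len(utimes), utimes_per_chunk))
tbin_idx = np.concatenate(([0], np.cumsum(counts)[:-1]))  # first row per time
tbin_counts = counts                                      # rows per time

print(row_chunks)  # (5, 4): rows per dask chunk
print(tbin_idx)    # [0 3 5 8]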
def new(ms, sky_model, gains, **kwargs):
    """Generate model visibilities per source (as direction axis)
    for stokes I and Q and generate relevant visibilities."""

    # Options to attribute dictionary
    if kwargs["yaml"] is not None:
        options = ocf.load(kwargs["yaml"])
    else:
        options = ocf.create(kwargs)

    # Set to struct
    ocf.set_struct(options, True)

    # Change path to sky model if chosen
    try:
        sky_model = sky_models[sky_model.lower()]
    except KeyError:
        # Own sky model reference
        pass

    # Set thread count to cpu count
    if options.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(options.ncpu))
    else:
        import multiprocessing
        options.ncpu = multiprocessing.cpu_count()

    # Load gains to corrupt with
    with open(gains, "rb") as file:
        jones = np.load(file)

    # Load dimensions
    n_time, n_ant, n_chan, n_dir, n_corr = jones.shape
    n_row = n_time * (n_ant * (n_ant - 1) // 2)

    # Load ms
    MS = xds_from_ms(ms)[0]

    # Get time-bin indices and counts
    row_chunks, tbin_indices, tbin_counts = chunkify_rows(
        MS.TIME, options.utime)

    # Close and reopen with chunked rows
    MS.close()
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]

    # Get antenna arrays (dask ignored for now)
    ant1 = MS.ANTENNA1.data
    ant2 = MS.ANTENNA2.data

    # Adjust UVW based on phase-convention
    if options.phase_convention.upper() == 'CASA':
        uvw = -MS.UVW.data.astype(np.float64)
    elif options.phase_convention.upper() == 'CODEX':
        uvw = MS.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase.")

    # MS dimensions
    dims = ocf.create(dict(MS.sizes))

    # Close MS
    MS.close()

    # Build source model from lsm
    lsm = Tigger.load(sky_model)

    # Check if dimensions match jones
    assert n_time * (n_ant * (n_ant - 1) // 2) == dims.row
    assert n_time == len(tbin_indices)
    assert n_ant == np.max((np.max(ant1), np.max(ant2))) + 1
    assert n_chan == dims.chan
    assert n_corr == dims.corr

    # If gains are DIE
    if options.die:
        assert n_dir == 1
        n_dir = len(lsm.sources)
    else:
        assert n_dir == len(lsm.sources)

    # Get phase direction
    radec0_table = xds_from_table(ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()
    radec0_table.close()

    # Get frequency column
    freq_table = xds_from_table(ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]
    freq_table.close()

    # Get feed orientation
    feed_table = xds_from_table(ms + '::FEED')[0]
    feeds = feed_table.POLARIZATION_TYPE.data[0].compute()

    # Create initial model array
    model = np.zeros((n_dir, n_chan, n_corr), dtype=np.float64)

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    source_names = []

    # Cycle coordinates creating a source with flux
    print("==> Building model visibilities")
    for d, source in enumerate(lsm.sources):
        # Extract name
        source_names.append(source.name)

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

        # Get flux - Stokes I
        if source.flux.I:
            I0 = source.flux.I
            # Get spectrum (only spi currently supported)
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]
            # Generate model flux
            model[d, :, 0] = I0 * (freq / ref_freq) ** spi

        # Get flux - Stokes Q
        if source.flux.Q:
            Q0 = source.flux.Q
            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]
            # Generate model flux
            model[d, :, 1] = Q0 * (freq / ref_freq) ** spi

        # Get flux - Stokes U
        if source.flux.U:
            U0 = source.flux.U
            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]
            # Generate model flux
            model[d, :, 2] = U0 * (freq / ref_freq) ** spi

        # Get flux - Stokes V
        if source.flux.V:
            V0 = source.flux.V
            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]
            # Generate model flux
            model[d, :, 3] = V0 * (freq / ref_freq) ** spi

    # Close sky-model
    del lsm

    # Build dask graph
    tbin_indices = da.from_array(tbin_indices, chunks=(options.utime))
    tbin_counts = da.from_array(tbin_counts, chunks=(options.utime))
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones = da.from_array(jones, chunks=(options.utime,) + jones.shape[1::])

    # Apply image to visibility for each source
    sources = []
    for s in range(n_dir):
        source_vis = im_to_vis(model[s].reshape((1, n_chan, n_corr)),
                               uvw, lm[s].reshape((1, 2)), freq,
                               dtype=np.complex64, convention='fourier')
        sources.append(source_vis)
    model_vis = da.stack(sources, axis=2)

    # Sum over direction?
    if options.die:
        model_vis = da.sum(model_vis, axis=2, keepdims=True)
        n_dir = 1
        source_names = [options.mname]

    # Select schema based on feed orientation
    if (feeds == ["X", "Y"]).all():
        out_schema = [["XX", "XY"], ["YX", "YY"]]
    elif (feeds == ["R", "L"]).all():
        out_schema = [['RR', 'RL'], ['LR', 'LL']]
    else:
        raise ValueError("Unknown feed orientation implementation.")

    # Convert Stokes to Correlations
    in_schema = ['I', 'Q', 'U', 'V']
    model_vis = convert(model_vis, in_schema, out_schema).reshape(
        (n_row, n_chan, n_dir, n_corr))

    # Apply gains to model_vis
    print("==> Corrupting visibilities")
    data = corrupt_vis(tbin_indices, tbin_counts,
                       ant1, ant2, jones, model_vis)

    # Reopen MS
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]

    # Assign model visibilities
    out_names = []
    for d in range(n_dir):
        MS = MS.assign(**{source_names[d]:
                          (("row", "chan", "corr"),
                           model_vis[:, :, d].astype(np.complex64))})
        out_names += [source_names[d]]

    # Assign noise free visibilities to 'CLEAN_DATA'
    MS = MS.assign(**{'CLEAN_' + options.dname:
                      (("row", "chan", "corr"),
                       data.astype(np.complex64))})
    out_names += ['CLEAN_' + options.dname]

    # Get noise realisation
    if options.std > 0.0:
        # Noise matrix
        print(f"==> Applying noise (std={options.std}) to visibilities")
        noise = []
        for i in range(2):
            real = da.random.normal(loc=0.0, scale=options.std,
                                    size=(n_row, n_chan),
                                    chunks=(row_chunks, n_chan))
            imag = 1.0j * (da.random.normal(loc=0.0, scale=options.std,
                                            size=(n_row, n_chan),
                                            chunks=(row_chunks, n_chan)))
            noise.append(real + imag)

        # Zero matrix for off-diagonals
        zero = da.zeros((n_row, n_chan), chunks=(row_chunks, n_chan))
        noise.insert(1, zero)
        noise.insert(2, zero)

        # NP to Dask
        noise = da.stack(noise, axis=2).rechunk((row_chunks, n_chan, n_corr))

        # Assign noise to 'NOISE'
        MS = MS.assign(**{'NOISE': (("row", "chan", "corr"),
                                    noise.astype(np.complex64))})
        out_names += ['NOISE']

        # Add noise to data and assign to 'DATA'
        noisy_data = data + noise
        MS = MS.assign(**{options.dname:
                          (("row", "chan", "corr"),
                           noisy_data.astype(np.complex64))})
        out_names += [options.dname]

    # Create a write to the table
    write = xds_to_table(MS, ms, out_names)

    # Submit all graph computations in parallel
    print(f"==> Executing `dask-ms` write to `{ms}` "
          f"for the following columns: {', '.join(out_names)}")
    with ProgressBar():
        write.compute()
    print("==> Completed.")
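# A small sketch of the per-source spectral model used above: a power law
# I(nu) = I0 * (nu / nu0)**spi evaluated over the channel frequencies.
# The flux, spectral index and reference frequency are invented values.
import numpy as np

freq = np.linspace(.856e9, 2 * .856e9, 8)  # channel frequencies in Hz
I0, spi, ref_freq = 1.2, -0.7, 1.4e9
stokes_I = I0 * (freq / ref_freq) ** spi
print(stokes_I)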
def read_ms(ms, num_vis, res_arcmin, chunks=50000, channel=0):
    '''
    Use dask-ms to load the necessary data to create a telescope operator
    (will use uvw positions, and antenna positions)

    -- res_arcmin: Used to calculate the maximum baselines to consider.
                   We want two pixels per smallest fringe:
                       pix_res > fringe / 2

                   u sin(theta) = n (for the nth fringe), so at small
                   angles theta = 1 / u, or u_max = 1 / theta

                   In terms of the dish separation d:
                       d sin(theta) = lambda / 2
                       d / lambda = 1 / (2 sin(theta))
                       u_max = 1 / (2 sin(theta))
    '''
    with scheduler_context():
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
            ant_p = np.array(ant_ds.POSITION.data)
            logger.info("Antenna Positions {}".format(ant_p.shape))

        # Create a dataset representing the field
        field_table = '::'.join((ms, 'FIELD'))

        for field_ds in xds_from_table(field_table):
            phase_dir = np.array(field_ds.PHASE_DIR.data)[0].flatten()
            logger.info("Phase Dir {}".format(np.degrees(phase_dir)))

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            logger.info("CHAN_FREQ.values: {}".format(
                spw_ds.CHAN_FREQ.values.shape))
            frequencies = dask.compute(spw_ds.CHAN_FREQ.values)[0].flatten()
            frequency = frequencies[channel]
            logger.info("Frequencies = {}".format(frequencies))
            logger.info("Frequency = {}".format(frequency))
            logger.info("NUM_CHAN = %d" % np.array(spw_ds.NUM_CHAN.values)[0])

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(ms, chunks={'row': chunks}))

        pol = 0

        for ds in datasets:
            logger.info("DATA shape: {}".format(ds.DATA.data.shape))
            logger.info("UVW shape: {}".format(ds.UVW.data.shape))

            uvw = np.array(ds.UVW.data)  # UVW is stored in meters!
            ant1 = np.array(ds.ANTENNA1.data)
            ant2 = np.array(ds.ANTENNA2.data)
            flags = np.array(ds.FLAG.data)
            cv_vis = np.array(ds.DATA.data)[:, channel, pol]
            epoch_seconds = np.array(ds.TIME.data)[0]

            # Try write the STATE_ID column back
            write = xds_to_table(ds, ms, 'STATE_ID')
            with ProgressBar(), Profiler() as prof:
                write.compute()

            # Profile
            # prof.visualize(file_path="chunked.html")

        # NOW REMOVE DATA THAT DOESN'T FIT THE IMAGE RESOLUTION

        u_max = get_resolution_max_baseline(res_arcmin, frequency)

        logger.info("Resolution Max UVW: {:g}".format(u_max))
        logger.info("Flags: {}".format(flags.shape))

        # Now report the recommended resolution from the data:
        # 1.0 / (2 * np.sin(theta)) = limit_u
        limit_uvw = np.max(np.abs(uvw), 0)

        res_limit = get_baseline_resolution(limit_uvw[0], frequency)
        logger.info("Nyquist resolution: {:g} arcmin".format(
            np.degrees(res_limit) * 60.0))

        good_data = np.array(
            np.where((flags[:, channel, pol] == 0)
                     & (np.max(np.abs(uvw), 1) < u_max))).T.reshape((-1,))

        logger.info("Good Data {}".format(good_data.shape))
        logger.info("Maximum UVW: {}".format(limit_uvw))
        logger.info("Minimum UVW: {}".format(np.min(np.abs(uvw), 0)))

        n_ant = len(ant_p)

        good_vis = cv_vis[good_data]
        n_max = len(good_vis)

        indices = np.random.choice(good_data, min(num_vis, n_max))

        hdr = {
            'CTYPE1': ('RA---SIN', "Right ascension angle cosine"),
            'CRVAL1': np.degrees(phase_dir)[0],
            'CUNIT1': 'deg ',
            'CTYPE2': ('DEC--SIN', "Declination angle cosine "),
            'CRVAL2': np.degrees(phase_dir)[1],
            'CUNIT2': 'deg ',
            'CTYPE3': 'FREQ ',  # Central frequency
            'CRPIX3': 1.,
            'CRVAL3': "{}".format(frequency),
            'CDELT3': 10026896.158854,
            'CUNIT3': 'Hz ',
            'EQUINOX': '2000.',
            'DATE-OBS': "{}".format(epoch_seconds),
            'BTYPE': 'Intensity'
        }

        # from astropy.wcs.utils import celestial_frame_to_wcs
        # from astropy.coordinates import FK5
        # frame = FK5(equinox='J2010')
        # wcs = celestial_frame_to_wcs(frame)
        # wcs.to_header()

        u_arr = uvw[indices, 0]
        v_arr = uvw[indices, 1]
        w_arr = uvw[indices, 2]

        cv_vis = cv_vis[indices]

        # Convert the Modified Julian Date (in seconds) to a timestamp.
        # The MJD epoch is 1858-11-17 00:00:00 UTC.
        timestamp = datetime.datetime(
            1858, 11, 17, 0, 0, 0,
            tzinfo=datetime.timezone.utc) + datetime.timedelta(
                seconds=epoch_seconds)

    return u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp
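# A worked example of the docstring's resolution relation: with two pixels
# per fringe, u_max = 1 / (2 sin(theta)) wavelengths for a pixel resolution
# theta. get_resolution_max_baseline above presumably implements something
# equivalent; this stand-alone version is purely illustrative.
import numpy as np

def max_baseline_in_wavelengths(res_arcmin):
    theta = np.radians(res_arcmin / 60.0)  # pixel resolution in radians
    return 1.0 / (2.0 * np.sin(theta))

# A 3 arcmin resolution needs baselines out to ~573 wavelengths
print(max_baseline_in_wavelengths(3.0))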
def compute_weights(self, robust):
    from pfb.utils.weighting import compute_counts, counts_to_weights

    # compute counts
    counts = []
    for ims in self.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims],
                          columns=('UVW',))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                              group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = ds.DATA_DESC_ID  # not optimal, need to use spw

            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data

            count = compute_counts(uvw, freq, freq_bin_idx,
                                   freq_bin_counts, self.nx, self.ny,
                                   self.cell, self.cell, np.float32)
            counts.append(count)

    counts = dask.compute(counts)[0]
    counts = accumulate_dirty(counts, self.nband, self.band_mapping)
    counts = da.from_array(counts, chunks=(1, -1, -1))

    # convert counts to weights
    writes = []
    for ims in self.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims],
                          columns=self.columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                              group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data

            weights = counts_to_weights(counts, uvw, freq,
                                        freq_bin_idx, freq_bin_counts,
                                        self.nx, self.ny,
                                        self.cell, self.cell,
                                        np.float32, robust)

            # hack to get shape and chunking info
            data = getattr(ds, self.data_column).data

            weights = da.broadcast_to(weights[:, :, None],
                                      data.shape, chunks=data.chunks)
            out_ds = ds.assign(**{self.imaging_weight_column:
                                  (("row", "chan", "corr"), weights)})
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims,
                                   columns=[self.imaging_weight_column]))

    dask.compute(writes)
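# A minimal sketch of the weight-broadcasting trick above: imaging weights
# are computed per (row, chan) and broadcast across the correlation axis,
# borrowing the data column's shape and chunking. Shapes are assumptions.
import dask.array as da

nrow, nchan, ncorr = 1000, 64, 4
data = da.zeros((nrow, nchan, ncorr), chunks=(250, 64, 4))
weights = da.ones((nrow, nchan), chunks=(250, 64))

weights = da.broadcast_to(weights[:, :, None], data.shape, chunks=data.chunks)
assert weights.chunks == data.chunks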
def _predict(args):
    import pkg_resources
    version = pkg_resources.get_distribution("crystalball").version
    log.info("Crystalball version {0}", version)

    # get inclusion regions
    include_regions = load_regions(args.within) if args.within else []

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    source_model = import_from_wsclean(args.sky_model,
                                       include_regions=include_regions,
                                       point_only=args.points_only,
                                       num=args.num_sources or None)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # Get the support tables
    tables = support_tables(args, ["FIELD", "DATA_DESCRIPTION",
                                   "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    max_num_chan = max([ss.NUM_CHAN.data[0] for ss in spw_ds])
    max_num_corr = max([ss.NUM_CORR.data[0] for ss in pol_ds])

    # Perform resource budgeting
    nsources = source_model.source_type.shape[0]
    args.row_chunks, args.model_chunks = get_budget(nsources, ms_rows,
                                                    max_num_chan,
                                                    max_num_corr,
                                                    ms_datatype, args)

    source_model = source_model_to_dask(source_model, args.model_chunks)

    # List of write operations
    writes = []

    datasets = xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks})

    field_id = select_field_id(field_ds, args.field)

    for xds in filter_datasets(datasets, field_id):
        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]
        frequency = spw.CHAN_FREQ.data[0]

        lm = radec_to_lm(source_model.radec, field.PHASE_DIR.data[0][0])

        with warnings.catch_warnings():
            # Ignore dask chunk warnings emitted when going from 1D
            # inputs to a 2D space of chunks
            warnings.simplefilter('ignore', category=PerformanceWarning)

            vis = wsclean_predict(xds.UVW.data, lm,
                                  source_model.source_type,
                                  source_model.flux,
                                  source_model.spi,
                                  source_model.log_poly,
                                  source_model.ref_freq,
                                  source_model.gauss_shape,
                                  frequency)

        vis = fill_correlations(vis, pol)

        log.info('Field {0} DDID {1:d} rows {2} chans {3} corrs {4}',
                 field.NAME.values[0], xds.DATA_DESC_ID,
                 vis.shape[0], vis.shape[1], vis.shape[2])

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(**{args.output_column:
                            (("row", "chan", "corr"), vis)})

        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])

        # Add to the list of writes
        writes.append(write)

    with ExitStack() as stack:
        if sys.stdout.isatty():
            # Default progress bar in user terminal
            stack.enter_context(EstimatingProgressBar())
        else:
            # Log progress every 5 minutes
            stack.enter_context(EstimatingProgressBar(minimum=2 * 60, dt=5))

        # Submit all graph computations in parallel
        dask.compute(writes)

    log.info("Finished")
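# A minimal sketch of the progress-reporting pattern used above and in the
# flagging example earlier: pick an interactive progress bar on a TTY and
# a quieter, periodic one otherwise. This version uses dask's stock
# ProgressBar rather than crystalball's EstimatingProgressBar.
import sys
import dask
from contextlib import ExitStack
from dask.diagnostics import ProgressBar

writes = []  # dask write graphs accumulated per dataset

with ExitStack() as stack:
    if sys.stdout.isatty():
        stack.enter_context(ProgressBar())
    else:
        # Emit a bar at most every 5 minutes so as not to spam the log
        stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

    dask.compute(writes)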
def test_ms_create(Dataset, tmp_path, chunks, num_chans, corr_types, sources):
    # Set up
    rs = np.random.RandomState(42)

    ms_path = tmp_path / "create.ms"

    ms_table_name = str(ms_path)
    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # For comparison
    all_data_desc_id = []
    all_data = []

    # Create ANTENNA dataset of 64 antennas
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = 64
    position = da.random.random((na, 3)) * 10000
    offset = da.random.random((na, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=object)
    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'NAME': (("row",), da.from_array(names, chunks=na)),
    })
    ant_datasets.append(ds)

    # Create SOURCE datasets
    for s, (name, direction, rest_freq) in enumerate(sources):
        dask_num_lines = da.full((1,), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object))
        ds = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })
        pol_datasets.append(ds)

    # Create multiple MeerKAT L-band SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.linspace(.856e9, 2 * .856e9, num_chan,
                                     chunks=num_chan)[None, :]
        dask_chan_width = da.full((1, num_chan), .856e9 / num_chan)
        ds = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_chans)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))

    # Now create the associated MS dataset
    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        np_data = (rs.normal(size=(row, chan, corr)) +
                   1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'DATA_DESC_ID': (("row",), dask_ddid)
        })
        ms_datasets.append(dataset)
        all_data.append(dask_data)
        all_data_desc_id.append(dask_ddid)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes, ant_writes, pol_writes,
                 spw_writes, ddid_writes, source_writes)

    # Check ANTENNA table correctly created
    with pt.table(ant_table_name, ack=False) as A:
        assert_array_equal(A.getcol("NAME"), names)
        assert_array_equal(A.getcol("POSITION"), position)
        assert_array_equal(A.getcol("OFFSET"), offset)

        required_desc = pt.required_ms_desc("ANTENNA")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(A.colnames()) == set(required_columns)

    # Check POLARIZATION table correctly created
    with pt.table(pol_table_name, ack=False) as P:
        for r, corr_type in enumerate(corr_types):
            assert_array_equal(P.getcol("CORR_TYPE", startrow=r, nrow=1),
                               [corr_type])
            assert_array_equal(P.getcol("NUM_CORR", startrow=r, nrow=1),
                               [len(corr_type)])

        required_desc = pt.required_ms_desc("POLARIZATION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(P.colnames()) == set(required_columns)

    # Check SPECTRAL_WINDOW table correctly created
    with pt.table(spw_table_name, ack=False) as S:
        for r, num_chan in enumerate(num_chans):
            assert_array_equal(
                S.getcol("NUM_CHAN", startrow=r, nrow=1)[0], num_chan)
            assert_array_equal(
                S.getcol("CHAN_FREQ", startrow=r, nrow=1)[0],
                np.linspace(.856e9, 2 * .856e9, num_chan))
            assert_array_equal(
                S.getcol("CHAN_WIDTH", startrow=r, nrow=1)[0],
                np.full(num_chan, .856e9 / num_chan))

        required_desc = pt.required_ms_desc("SPECTRAL_WINDOW")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(S.colnames()) == set(required_columns)

    # We should get a cartesian product out
    with pt.table(ddid_table_name, ack=False) as D:
        spw_id, pol_id = zip(*product(range(len(num_chans)),
                                      range(len(corr_types))))
        assert_array_equal(pol_id, D.getcol("POLARIZATION_ID"))
        assert_array_equal(spw_id, D.getcol("SPECTRAL_WINDOW_ID"))

        required_desc = pt.required_ms_desc("DATA_DESCRIPTION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(D.colnames()) == set(required_columns)

    with pt.table(src_table_name, ack=False) as S:
        for r, (name, direction, rest_freq) in enumerate(sources):
            assert_array_equal(S.getcol("NAME", startrow=r, nrow=1)[0],
                               [name])
            assert_array_equal(S.getcol("REST_FREQUENCY", startrow=r,
                                        nrow=1), [rest_freq])
            assert_array_equal(S.getcol("DIRECTION", startrow=r, nrow=1),
                               [direction])

    with pt.table(ms_table_name, ack=False) as T:
        # DATA_DESC_ID's are all the same shape
        assert_array_equal(T.getcol("DATA_DESC_ID"),
                           da.concatenate(all_data_desc_id))

        # DATA is variably shaped (on DATA_DESC_ID) so we
        # compare each one separately.
        for ddid, data in enumerate(all_data):
            ms_data = T.getcol("DATA", startrow=ddid * row, nrow=row)
            assert_array_equal(ms_data, data)

        required_desc = pt.required_ms_desc()
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        # Check we have the required columns
        assert set(T.colnames()) == required_columns.union(
            ["DATA", "DATA_DESC_ID"])
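# A small sketch of the sub-table validation idiom from the test above,
# assuming pt is python-casacore's tables module: required_ms_desc returns
# the required table description, and non-column entries (keys starting
# with "_") are filtered out before comparing against colnames().
import casacore.tables as pt

required_desc = pt.required_ms_desc("ANTENNA")
required_columns = set(k for k in required_desc.keys()
                       if not k.startswith("_"))
print(sorted(required_columns))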
def ms_create(ms_table_name, info, ant_pos, cal_vis,
              timestamps, corr_types, sources):
    ''' "info": {
            "info": {
                "L0_frequency": 1571328000.0,
                "bandwidth": 2500000.0,
                "baseband_frequency": 4092000.0,
                "location": {
                    "alt": 270.0,
                    "lat": -45.85177,
                    "lon": 170.5456
                },
                "name": "Signal Hill - Dunedin",
                "num_antenna": 24,
                "operating_frequency": 1575420000.0,
                "sampling_frequency": 16368000.0
            }
        },
    '''
    num_chans = [1]
    rs = np.random.RandomState(42)

    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    feed_table_name = "::".join((ms_table_name, "FEED"))
    field_table_name = "::".join((ms_table_name, "FIELD"))
    obs_table_name = "::".join((ms_table_name, "OBSERVATION"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    feed_datasets = []
    field_datasets = []
    obs_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = len(ant_pos)
    position = da.asarray(ant_pos)  # da.zeros((na, 3))
    diameter = da.ones(na) * 0.025
    offset = da.zeros((na, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=object)
    stations = np.array([info['name'] for i in range(na)], dtype=object)

    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row",), diameter),
        'NAME': (("row",), da.from_array(names, chunks=na)),
        'STATION': (("row",), da.from_array(stations, chunks=na)),
    })
    ant_datasets.append(ds)

    ######################## Create a FEED dataset ########################
    # There is one feed per antenna, so this should be quite similar
    # to the ANTENNA
    antenna_ids = da.asarray(range(na))
    feed_ids = da.zeros(na)
    num_receptors = da.ones(na)
    polarization_types = np.array([['XX'] for i in range(na)], dtype=object)
    receptor_angles = np.array([[0.0] for i in range(na)])
    pol_response = np.array([[[0.0 + 1.0j, 1.0 - 1.0j]] for i in range(na)])
    beam_offset = np.array([[[0.0, 0.0]] for i in range(na)])

    ds = Dataset({
        'ANTENNA_ID': (("row",), antenna_ids),
        'FEED_ID': (("row",), feed_ids),
        'NUM_RECEPTORS': (("row",), num_receptors),
        'POLARIZATION_TYPE': (("row", "receptors"),
                              da.from_array(polarization_types, chunks=na)),
        'RECEPTOR_ANGLE': (("row", "receptors"),
                           da.from_array(receptor_angles, chunks=na)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=na)),
        'BEAM_OFFSET': (("row", "receptors", "radec"),
                        da.from_array(beam_offset, chunks=na)),
    })
    feed_datasets.append(ds)

    ########################### FIELD dataset ###########################
    direction = [[np.radians(90.0), np.radians(0.0)]]  # Phase Center in J2000
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=object))
    # Zero order polynomial in time for phase center.
    field_num_poly = da.zeros(1)

    dir_dims = ("row", 'field-poly', 'field-dir')

    ds = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row",), field_num_poly),
        'NAME': (("row",), field_name),
    })
    field_datasets.append(ds)

    ######################## OBSERVATION dataset ########################
    ds = Dataset({
        'TELESCOPE_NAME': (("row",),
                           da.asarray(np.asarray(['TART'], dtype=object))),
        'OBSERVER': (("row",),
                     da.asarray(np.asarray(['Tim'], dtype=object))),
    })
    obs_datasets.append(ds)

    ######################## Create SOURCE datasets ########################
    for s, src in enumerate(sources):
        name = src['name']
        rest_freq = [info['operating_frequency']]
        # FIXME these are in elevation and azimuth. Not in J2000.
        direction = [np.radians(src['el']), np.radians(src['az'])]
        # logger.info("SOURCE: {}, timestamp: {}".format(name, timestamps))

        dask_num_lines = da.full((1,), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object))
        dask_time = da.asarray(np.asarray([timestamps], dtype=object))
        ds = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            # FIXME. Causes an error. Need to sort out TIME data fields
            # "TIME": (("row",), dask_time),
            # "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            # "CORR_PRODUCT": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })
        pol_datasets.append(ds)

    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        dask_chan_width = da.full((1, num_chan), 2.5e6 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_chans)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))

    # Now create the associated MS dataset
    vis_data, baselines = cal_vis.get_all_visibility()
    vis_array = np.array(vis_data, dtype=np.complex64)
    chunks = {"row": (vis_array.shape[0],)}
    baselines = np.array(baselines)
    bl_pos = np.array(ant_pos)[baselines]
    # Use the - sign to get the same orientation as our tart projections.
    uu_a, vv_a, ww_a = -(bl_pos[:, 1] -
                         bl_pos[:, 0]).T / constants.L1_WAVELENGTH
    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        logger.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        logger.info("Data size {}".format((row, chan, corr)))

        np_data = vis_array.reshape((row, chan, corr))
        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)

        dataset = Dataset({
            'DATA': (dims, dask_data),
            'WEIGHT': (("row", "corr"),
                       da.from_array(0.95 * np.ones((row, corr)))),
            'WEIGHT_SPECTRUM': (dims, da.from_array(
                0.95 * np.ones_like(np_data, dtype=np.float64))),
            'SIGMA_SPECTRUM': (dims, da.from_array(
                np.ones_like(np_data, dtype=np.float64) * 0.05)),
            'UVW': (("row", "uvw"), uvw_data),
            'ANTENNA1': (("row",), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row",), da.from_array(baselines[:, 1])),
            'FEED1': (("row",), da.from_array(baselines[:, 0])),
            'FEED2': (("row",), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row",), dask_ddid)
        })
        ms_datasets.append(dataset)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    feed_writes = xds_to_table(feed_datasets, feed_table_name, columns="ALL")
    field_writes = xds_to_table(field_datasets, field_table_name,
                                columns="ALL")
    obs_writes = xds_to_table(obs_datasets, obs_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes)
    dask.compute(ant_writes)
    dask.compute(feed_writes)
    dask.compute(field_writes)
    dask.compute(obs_writes)
    dask.compute(pol_writes)
    dask.compute(spw_writes)
    dask.compute(ddid_writes)
    dask.compute(source_writes)
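# An illustrative version of the UVW construction above: baseline vectors
# are differences of antenna positions, scaled to wavelengths. The antenna
# layout and wavelength below are made up; the sign flip mirrors the
# projection convention noted in the code.
import numpy as np

wavelength = 0.1905  # assumed observing wavelength in metres
ant_pos = np.random.rand(4, 3) * 100.0  # toy antenna positions
baselines = np.array([(i, j) for i in range(4) for j in range(i + 1, 4)])

bl_pos = ant_pos[baselines]  # (nbl, 2, 3) positions per baseline
uu, vv, ww = -(bl_pos[:, 1] - bl_pos[:, 0]).T / wavelength
uvw = np.array([uu, vv, ww]).T  # (nbl, 3)
print(uvw.shape)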
# Create a dataset representing the entire antenna table
ant_table = '::'.join((args.ms, 'ANTENNA'))

for ant_ds in xds_from_table(ant_table):
    print(ant_ds)
    print(dask.compute(ant_ds.NAME.data,
                       ant_ds.POSITION.data,
                       ant_ds.DISH_DIAMETER.data))

# Create datasets representing each row of the spw table
spw_table = '::'.join((args.ms, 'SPECTRAL_WINDOW'))

for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
    print(spw_ds)
    print(spw_ds.NUM_CHAN.values)
    print(spw_ds.CHAN_FREQ.values)

# Create datasets from a partitioning of the MS
datasets = list(xds_from_ms(args.ms, chunks={'row': args.chunks}))

for ds in datasets:
    print(ds)

    # Try write the STATE_ID column back
    write = xds_to_table(ds, args.ms, 'STATE_ID')
    with ProgressBar(), Profiler() as prof:
        write.compute()

    # Profile
    prof.visualize(file_path="chunked.html")
def main(args):
    # get full time column and compute row chunks
    ms = table(args.ms)
    time = ms.getcol('TIME')
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        time, args.utimes_per_chunk)
    # convert to dask arrays
    tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts, chunks=(args.utimes_per_chunk))
    n_time = tbin_idx.size
    ms.close()

    # get phase dir
    fld = table(args.ms + '::FIELD')
    radec0 = fld.getcol('PHASE_DIR').squeeze().reshape(1, 2)
    radec0 = np.tile(radec0, (n_time, 1))
    fld.close()

    # get freqs
    freqs = table(args.ms + '::SPECTRAL_WINDOW').getcol(
        'CHAN_FREQ')[0].astype(np.float64)
    n_freq = freqs.size
    freqs = da.from_array(freqs, chunks=(n_freq))

    # get source coordinates from lsm
    lsm = Tigger.load(args.sky_model)
    radec = []
    stokes = []
    spi = []
    ref_freqs = []

    for source in lsm.sources:
        radec.append([source.pos.ra, source.pos.dec])
        stokes.append([source.flux.I])
        spi.append(source.spectrum.spi)
        ref_freqs.append(source.spectrum.freq0)

    n_dir = len(stokes)
    radec = np.asarray(radec)
    lm = np.zeros((n_time,) + radec.shape)
    for t in range(n_time):
        lm[t] = radec_to_lm(radec, radec0[t])
    lm = da.from_array(lm, chunks=(args.utimes_per_chunk, n_dir, 2))

    # load in the model file
    n_corr = 1
    model = np.zeros((n_time, n_freq, n_dir, n_corr))
    stokes = np.asarray(stokes)
    ref_freqs = np.asarray(ref_freqs)
    spi = np.asarray(spi)
    for t in range(n_time):
        for d in range(n_dir):
            model[t, :, d, 0] = stokes[d] * (freqs / ref_freqs[d]) ** spi[d]

    # append antenna columns
    cols = []
    cols.append('ANTENNA1')
    cols.append('ANTENNA2')
    cols.append('UVW')

    # load in gains
    jones = np.load(args.gain_file)
    jones = jones.astype(np.complex128)
    jones_shape = jones.shape
    jones = da.from_array(jones, chunks=(args.utimes_per_chunk,)
                          + jones_shape[1::])

    # change model to dask array
    model = da.from_array(model, chunks=(args.utimes_per_chunk,)
                          + model.shape[1::])

    # load data in chunks and apply gains to each chunk
    xds = xds_from_ms(args.ms, columns=cols, chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data
    uvw = xds.UVW.data

    # apply gains
    data = compute_and_corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                                   jones, model, uvw, freqs, lm)

    # Assign visibilities to args.out_col and write to ms
    xds = xds.assign(**{args.out_col: (("row", "chan", "corr"), data)})
    # Create a write to the table
    write = xds_to_table(xds, args.ms, [args.out_col])

    # Submit all graph computations in parallel
    with ProgressBar():
        write.compute()
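# A small sketch of the per-timestep lm computation in main above: source
# RA/Dec coordinates are converted to direction cosines relative to the
# phase centre with africanus' radec_to_lm. The coordinates are invented.
import numpy as np
from africanus.coordinates import radec_to_lm

radec0 = np.array([0.0, np.radians(-30.0)])  # phase centre
radec = np.array([[np.radians(0.1), np.radians(-30.1)],
                  [np.radians(-0.2), np.radians(-29.8)]])

lm = radec_to_lm(radec, radec0)
print(lm)  # (n_dir, 2) direction cosines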