def __init__(self, subtable):
    """
    Parameters
    ----------
    subtable : str
        Name of the Measurement Set sub-table. Must be a member
        of ``SUBTABLES``.

    Raises
    ------
    ValueError
        If ``subtable`` is not a valid Measurement Set sub-table name.
    """
    if subtable not in SUBTABLES:
        # f-string for consistency with the rest of the file
        # (message text unchanged)
        raise ValueError(f"'{subtable}' is not a valid "
                         f"Measurement Set sub-table")

    self.subtable = subtable

    # Required column descriptions for this sub-table, per casacore
    self.DEFAULT_TABLE_DESC = pt.required_ms_desc(subtable)
    self.REQUIRED_FIELDS = set(self.DEFAULT_TABLE_DESC.keys())
def test_ms_builder(tmp_path, variables, chunks, fixed):
    """Build a descriptor + dminfo from dask-backed variables, create a
    Measurement Set from them, and verify columns and (when ``fixed``)
    the Data Manager groups survive the round-trip."""
    def make_variable(dims, dtype):
        # Derive the variable's shape and chunking from the chunk dict
        full_shape = tuple(sum(chunks[d]) for d in dims)
        dim_chunks = tuple(chunks[d] for d in dims)
        data = da.random.random(full_shape, chunks=dim_chunks).astype(dtype)
        return [Variable(dims, data, {})]

    variables = {name: make_variable(dims, dtype)
                 for name, dims, dtype in variables}
    var_names = set(variables)

    builder = MSDescriptorBuilder(fixed)
    tab_desc = builder.descriptor(variables, builder.default_descriptor())
    dminfo = builder.dminfo(tab_desc)

    # These columns must always be present on an MS
    required_cols = {c for c in pt.required_ms_desc()
                     if not c.startswith('_')}

    filename = str(tmp_path / "test_plugin.ms")

    with pt.table(filename, tab_desc, dminfo=dminfo,
                  ack=False, nrow=10) as T:
        # We got required + the extra columns we asked for
        assert set(T.colnames()) == var_names | required_cols

        if not fixed:
            return

        original_dminfo = {g['NAME']: g for g in dminfo.values()}
        table_dminfo = {g['NAME']: g for g in T.getdminfo().values()}
        assert len(original_dminfo) == len(table_dminfo)

        # Each group written to the table must match what we supplied
        for name, group in table_dminfo.items():
            expected = original_dminfo[name]
            assert expected['TYPE'] == group['TYPE']
            assert set(expected['COLUMNS']) == set(group['COLUMNS'])

            if group['TYPE'] == 'TiledColumnStMan':
                assert_array_equal(expected['SPEC']['DEFAULTTILESHAPE'],
                                   group['SPEC']['DEFAULTTILESHAPE'])
def test_ms_builder(tmp_path, variables, fixed):
    """Exercise MSDescriptorBuilder against pre-built variables.

    Checks that the created Measurement Set contains the required and
    requested columns and, when ``fixed`` is True, that each requested
    column received its own TiledColumnStMan Data Manager Group whose
    specification survives the round-trip through casacore.
    """
    var_names = set(variables.keys())

    builder = MSDescriptorBuilder(fixed)
    default_desc = builder.default_descriptor()
    tab_desc = builder.descriptor(variables, default_desc)
    dminfo = builder.dminfo(tab_desc)

    # These columns must always be present on an MS
    required_cols = {k for k in pt.required_ms_desc().keys()
                     if not k.startswith('_')}

    filename = str(tmp_path / "test_plugin.ms")

    with pt.table(filename, tab_desc, dminfo=dminfo, ack=False, nrow=10) as T:
        # We got required + the extra columns we asked for
        assert set(T.colnames()) == set.union(var_names, required_cols)

        if fixed:
            original_dminfo = {v['NAME']: v for v in dminfo.values()}
            table_dminfo = {v['NAME']: v for v in T.getdminfo().values()}

            for column in variables.keys():
                try:
                    column_group = table_dminfo[column + "_GROUP"]
                except KeyError as e:
                    # BUG FIX: chain the original KeyError so the real
                    # failure context is preserved (PEP 3134)
                    raise ValueError(f"{column} should be fixed but no "
                                     f"Data Manager Group was created") from e

                assert column in column_group["COLUMNS"]
                assert column_group["TYPE"] == "TiledColumnStMan"

            assert len(original_dminfo) == len(table_dminfo)

            for dm_name, dm_group in table_dminfo.items():
                odm_group = original_dminfo[dm_name]
                assert odm_group['TYPE'] == dm_group['TYPE']
                assert set(odm_group['COLUMNS']) == set(dm_group['COLUMNS'])

                if dm_group['TYPE'] == 'TiledColumnStMan':
                    original_tile = odm_group['SPEC']['DEFAULTTILESHAPE']
                    table_tile = dm_group['SPEC']['DEFAULTTILESHAPE']
                    assert_array_equal(original_tile, table_tile)
def test_ms_subtable_builder(tmp_path, table):
    """Build a descriptor for an MS sub-table and verify the created
    table contains the required plus the requested columns."""
    data = da.zeros((10, 20, 30), chunks=(2, 20, 30), dtype=np.int32)
    variables = {"FOO": Variable(("row", "chan", "corr"), data, {})}
    var_names = set(variables)

    builder = MSSubTableDescriptorBuilder(table)
    tab_desc = builder.descriptor(variables, builder.default_descriptor())
    dminfo = builder.dminfo(tab_desc)

    # These columns must always be present on an MS
    required_cols = {c for c in pt.required_ms_desc(table)
                     if not c.startswith('_')}

    filename = str(tmp_path / f"{table}.table")

    with pt.table(filename, tab_desc, dminfo=dminfo, ack=False) as T:
        T.addrows(10)

        # We got required + the extra columns we asked for
        assert set(T.colnames()) == var_names | required_cols
def kat_ms_desc_and_dminfo(nbl, nchan, ncorr, model_data=False):
    """
    Creates Table Description and Data Manager Information objects that
    describe a MeasurementSet suitable for holding MeerKAT data.

    Creates additional DATA, IMAGING_WEIGHT and possibly
    MODEL_DATA and CORRECTED_DATA columns.

    Columns are given fixed shapes defined by the arguments to this
    function.

    :param nbl: Number of baselines.
    :param nchan: Number of channels.
    :param ncorr: Number of correlations.
    :param model_data: Boolean indicating whether MODEL_DATA and
        CORRECTED_DATA should be added to the Measurement Set.
    :return: Returns a tuple containing a table description describing
        the extra columns and hypercolumns, as well as a Data Manager
        description.
    """
    if casacore_binding != 'pyrap':
        raise ValueError("kat_ms_desc_and_dminfo requires the "
                         "casacore binding to operate")

    # Columns that will be modified.
    # We want to keep things like their keywords, dims and shapes.
    modify_columns = {"WEIGHT", "SIGMA", "FLAG", "FLAG_CATEGORY",
                      "UVW", "ANTENNA1", "ANTENNA2"}

    # Get the required table descriptor for an MS
    table_desc = tables.required_ms_desc("MAIN")

    # Take columns we wish to modify.
    # BUG FIX: dict.iteritems() is Python 2 only and raises
    # AttributeError on Python 3 -- use items() instead.
    extra_table_desc = {c: d for c, d in table_desc.items()
                        if c in modify_columns}

    # Used to set the SPEC for each Data Manager Group
    dmgroup_spec = {}

    def dmspec(coldesc, tile_mem_limit=None):
        """
        Create data manager spec for a given column description,
        mostly by adding a DEFAULTTILESHAPE that fits into the
        supplied memory limit.
        """
        # Choose 4MB if none given
        if tile_mem_limit is None:
            tile_mem_limit = 4 * 1024 * 1024

        # Get the reversed column shape. DEFAULTTILESHAPE is deep in
        # casacore and its necessary to specify their ordering here.
        # ntilerows is the dim that will change least quickly.
        rev_shape = list(reversed(coldesc["shape"]))

        ntilerows = 1
        np_dtype = MS_TO_NP_TYPE_MAP[coldesc["valueType"].upper()]
        nbytes = np.dtype(np_dtype).itemsize

        # Try bump up the number of rows in our tiles while they're
        # below the memory limit for the tile.
        # BUG FIX: np.product is deprecated and removed in NumPy 2.0;
        # np.prod is the supported equivalent.
        while np.prod(rev_shape + [2 * ntilerows]) * nbytes < tile_mem_limit:
            ntilerows *= 2

        return {"DEFAULTTILESHAPE": np.int32(rev_shape + [ntilerows])}

    # Update existing columns with shape and data manager information
    dm_group = 'UVW'
    shape = [3]
    extra_table_desc["UVW"].update(options=0, shape=shape, ndim=len(shape),
                                   dataManagerGroup=dm_group,
                                   dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["UVW"])

    dm_group = 'Weight'
    shape = [ncorr]
    extra_table_desc["WEIGHT"].update(options=4, shape=shape,
                                      ndim=len(shape),
                                      dataManagerGroup=dm_group,
                                      dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["WEIGHT"])

    dm_group = 'Sigma'
    shape = [ncorr]
    extra_table_desc["SIGMA"].update(options=4, shape=shape,
                                     ndim=len(shape),
                                     dataManagerGroup=dm_group,
                                     dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["SIGMA"])

    dm_group = 'Flag'
    shape = [nchan, ncorr]
    extra_table_desc["FLAG"].update(options=4, shape=shape,
                                    ndim=len(shape),
                                    dataManagerGroup=dm_group,
                                    dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["FLAG"])

    dm_group = 'FlagCategory'
    shape = [1, nchan, ncorr]
    extra_table_desc["FLAG_CATEGORY"].update(
        options=4, keywords={}, shape=shape, ndim=len(shape),
        dataManagerGroup=dm_group,
        dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["FLAG_CATEGORY"])

    # Create new columns for integration into the MS
    additional_columns = []

    dm_group = 'Data'
    shape = [nchan, ncorr]
    desc = tables.tablecreatearraycoldesc(
        "DATA", 0 + 0j, comment="The Visibility DATA Column",
        options=4, valuetype='complex', keywords={"UNIT": "Jy"},
        shape=shape, ndim=len(shape),
        datamanagergroup=dm_group,
        datamanagertype='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(desc["desc"])
    additional_columns.append(desc)

    dm_group = 'ImagingWeight'
    shape = [nchan]
    desc = tables.tablecreatearraycoldesc(
        "IMAGING_WEIGHT", 0,
        comment="Weight set by imaging task (e.g. uniform weighting)",
        options=4, valuetype='float', shape=shape, ndim=len(shape),
        datamanagergroup=dm_group,
        datamanagertype='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(desc["desc"])
    additional_columns.append(desc)

    # Add MODEL_DATA and CORRECTED_DATA if requested
    if model_data:
        dm_group = 'ModelData'
        shape = [nchan, ncorr]
        desc = tables.tablecreatearraycoldesc(
            "MODEL_DATA", 0 + 0j,
            comment="The Visibility MODEL_DATA Column",
            options=4, valuetype='complex', keywords={"UNIT": "Jy"},
            shape=shape, ndim=len(shape),
            datamanagergroup=dm_group,
            datamanagertype='TiledColumnStMan')
        dmgroup_spec[dm_group] = dmspec(desc["desc"])
        additional_columns.append(desc)

        dm_group = 'CorrectedData'
        shape = [nchan, ncorr]
        desc = tables.tablecreatearraycoldesc(
            "CORRECTED_DATA", 0 + 0j,
            comment="The Visibility CORRECTED_DATA Column",
            options=4, valuetype='complex', keywords={"UNIT": "Jy"},
            shape=shape, ndim=len(shape),
            datamanagergroup=dm_group,
            datamanagertype='TiledColumnStMan')
        dmgroup_spec[dm_group] = dmspec(desc["desc"])
        additional_columns.append(desc)

    # Update extra table description with additional columns
    extra_table_desc.update(tables.maketabdesc(additional_columns))

    # Update the original table descriptor with modifications/additions.
    # Need this to construct a complete Data Manager specification
    # that includes the original columns.
    table_desc.update(extra_table_desc)

    # Construct DataManager Specification
    dminfo = tables.makedminfo(table_desc, dmgroup_spec)

    return extra_table_desc, dminfo
def __init__(self, fixed=True):
    """
    Parameters
    ----------
    fixed : bool, optional
        Whether column shapes should be treated as fixed.
        Defaults to True.
    """
    super(AbstractDescriptorBuilder, self).__init__()
    self.fixed = fixed
    self.ms_dims = None

    # Required MAIN table description per casacore, and the set of
    # fields every descriptor must supply
    default_desc = pt.required_ms_desc()
    self.DEFAULT_MS_DESC = default_desc
    self.REQUIRED_FIELDS = set(default_desc)
def test_ms_create(Dataset, tmp_path, chunks, num_chans, corr_types, sources):
    """End-to-end creation of a Measurement Set and its sub-tables.

    Builds ANTENNA, SOURCE, POLARIZATION, SPECTRAL_WINDOW and
    DATA_DESCRIPTION datasets plus MAIN table data, writes them all
    with ``xds_to_table`` and then reads each table back with
    python-casacore to verify contents and required columns.
    """
    # Set up
    rs = np.random.RandomState(42)
    ms_path = tmp_path / "create.ms"

    ms_table_name = str(ms_path)
    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # For comparison
    all_data_desc_id = []
    all_data = []

    # Create ANTENNA dataset of 64 antennas
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = 64
    position = da.random.random((na, 3)) * 10000
    offset = da.random.random((na, 3))
    # BUG FIX: np.object was deprecated in NumPy 1.20 and removed in
    # 1.24; the builtin object is the equivalent dtype
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=object)
    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'NAME': (("row",), da.from_array(names, chunks=na)),
    })
    ant_datasets.append(ds)

    # Create SOURCE datasets
    for s, (name, direction, rest_freq) in enumerate(sources):
        dask_num_lines = da.full((1,), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object))
        ds = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })
        pol_datasets.append(ds)

    # Create multiple MeerKAT L-band SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.linspace(.856e9, 2 * .856e9, num_chan,
                                     chunks=num_chan)[None, :]
        dask_chan_width = da.full((1, num_chan), .856e9 / num_chan)
        ds = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_chans)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))

    # Now create the associated MS dataset.
    # Row count is the same for every DDID dataset; hoist it out of the
    # loop (it is also reused when reading DATA back below)
    row = sum(chunks['row'])

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer chan and correlation shape
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        np_data = (rs.normal(size=(row, chan, corr)) +
                   1j * rs.normal(size=(row, chan, corr))
                   ).astype(np.complex64)
        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'DATA_DESC_ID': (("row",), dask_ddid),
        })
        ms_datasets.append(dataset)

        all_data.append(dask_data)
        all_data_desc_id.append(dask_ddid)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes, ant_writes, pol_writes,
                 spw_writes, ddid_writes, source_writes)

    # Check ANTENNA table correctly created
    with pt.table(ant_table_name, ack=False) as A:
        assert_array_equal(A.getcol("NAME"), names)
        assert_array_equal(A.getcol("POSITION"), position)
        assert_array_equal(A.getcol("OFFSET"), offset)

        required_desc = pt.required_ms_desc("ANTENNA")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))
        assert set(A.colnames()) == set(required_columns)

    # Check POLARIZATION table correctly created
    with pt.table(pol_table_name, ack=False) as P:
        for r, corr_type in enumerate(corr_types):
            assert_array_equal(P.getcol("CORR_TYPE", startrow=r, nrow=1),
                               [corr_type])
            assert_array_equal(P.getcol("NUM_CORR", startrow=r, nrow=1),
                               [len(corr_type)])

        required_desc = pt.required_ms_desc("POLARIZATION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))
        assert set(P.colnames()) == set(required_columns)

    # Check SPECTRAL_WINDOW table correctly created
    with pt.table(spw_table_name, ack=False) as S:
        for r, num_chan in enumerate(num_chans):
            assert_array_equal(
                S.getcol("NUM_CHAN", startrow=r, nrow=1)[0], num_chan)
            assert_array_equal(
                S.getcol("CHAN_FREQ", startrow=r, nrow=1)[0],
                np.linspace(.856e9, 2 * .856e9, num_chan))
            assert_array_equal(
                S.getcol("CHAN_WIDTH", startrow=r, nrow=1)[0],
                np.full(num_chan, .856e9 / num_chan))

        required_desc = pt.required_ms_desc("SPECTRAL_WINDOW")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))
        assert set(S.colnames()) == set(required_columns)

    # We should get a cartesian product out
    with pt.table(ddid_table_name, ack=False) as D:
        spw_id, pol_id = zip(*product(range(len(num_chans)),
                                      range(len(corr_types))))
        assert_array_equal(pol_id, D.getcol("POLARIZATION_ID"))
        assert_array_equal(spw_id, D.getcol("SPECTRAL_WINDOW_ID"))

        required_desc = pt.required_ms_desc("DATA_DESCRIPTION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))
        assert set(D.colnames()) == set(required_columns)

    with pt.table(src_table_name, ack=False) as S:
        for r, (name, direction, rest_freq) in enumerate(sources):
            assert_array_equal(S.getcol("NAME", startrow=r, nrow=1)[0],
                               [name])
            assert_array_equal(S.getcol("REST_FREQUENCY", startrow=r, nrow=1),
                               [rest_freq])
            assert_array_equal(S.getcol("DIRECTION", startrow=r, nrow=1),
                               [direction])

    with pt.table(ms_table_name, ack=False) as T:
        # DATA_DESC_ID's are all the same shape
        assert_array_equal(T.getcol("DATA_DESC_ID"),
                           da.concatenate(all_data_desc_id))

        # DATA is variably shaped (on DATA_DESC_ID) so we
        # compare each one separately.
        for ddid, data in enumerate(all_data):
            ms_data = T.getcol("DATA", startrow=ddid * row, nrow=row)
            assert_array_equal(ms_data, data)

        required_desc = pt.required_ms_desc()
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        # Check we have the required columns
        assert set(T.colnames()) == required_columns.union(
            ["DATA", "DATA_DESC_ID"])