Example #1
    def __init__(self, subtable):
        if subtable not in SUBTABLES:
            raise ValueError("'%s' is not a valid Measurement Set "
                             "sub-table" % subtable)

        self.subtable = subtable
        self.DEFAULT_TABLE_DESC = pt.required_ms_desc(subtable)
        self.REQUIRED_FIELDS = set(self.DEFAULT_TABLE_DESC.keys())
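
This constructor appears to belong to the sub-table descriptor builder exercised in Example #4 (MSSubTableDescriptorBuilder). A minimal sketch of how the validation behaves, assuming that class name and that SUBTABLES holds the standard sub-table names:

# Sketch, assuming MSSubTableDescriptorBuilder is the enclosing class
builder = MSSubTableDescriptorBuilder("ANTENNA")
# Required ANTENNA columns, plus '_'-prefixed metadata keys
print(sorted(builder.REQUIRED_FIELDS))

try:
    MSSubTableDescriptorBuilder("NOT_A_SUBTABLE")
except ValueError as e:
    print(e)  # 'NOT_A_SUBTABLE' is not a valid Measurement Set sub-table
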
Example #2
def test_ms_builder(tmp_path, variables, chunks, fixed):
    def _variable_factory(dims, dtype):
        shape = tuple(sum(chunks[d]) for d in dims)
        achunks = tuple(chunks[d] for d in dims)
        dask_array = da.random.random(shape, chunks=achunks).astype(dtype)
        return [Variable(dims, dask_array, {})]

    variables = {
        n: _variable_factory(dims, dtype)
        for n, dims, dtype in variables
    }
    var_names = set(variables.keys())

    builder = MSDescriptorBuilder(fixed)
    default_desc = builder.default_descriptor()
    tab_desc = builder.descriptor(variables, default_desc)
    dminfo = builder.dminfo(tab_desc)

    # These columns must always be present on an MS
    required_cols = {
        k
        for k in pt.required_ms_desc().keys() if not k.startswith('_')
    }

    filename = str(tmp_path / "test_plugin.ms")

    with pt.table(filename, tab_desc, dminfo=dminfo, ack=False, nrow=10) as T:
        # We got required + the extra columns we asked for
        assert set(T.colnames()) == set.union(var_names, required_cols)

        if fixed:
            original_dminfo = {v['NAME']: v for v in dminfo.values()}
            table_dminfo = {v['NAME']: v for v in T.getdminfo().values()}

            assert len(original_dminfo) == len(table_dminfo)

            for dm_name, dm_group in table_dminfo.items():
                odm_group = original_dminfo[dm_name]
                assert odm_group['TYPE'] == dm_group['TYPE']
                assert set(odm_group['COLUMNS']) == set(dm_group['COLUMNS'])

                if dm_group['TYPE'] == 'TiledColumnStMan':
                    original_tile = odm_group['SPEC']['DEFAULTTILESHAPE']
                    table_tile = dm_group['SPEC']['DEFAULTTILESHAPE']
                    assert_array_equal(original_tile, table_tile)
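
The variables, chunks and fixed arguments arrive via pytest parametrisation that is not shown here. A hypothetical set of values consistent with how the test consumes them:

# Hypothetical parametrisation, for illustration only
chunks = {"row": (2, 2), "chan": (4,), "corr": (2,)}
fixed = True
variables = [
    ("DATA", ("row", "chan", "corr"), np.complex64),
    ("MODEL_DATA", ("row", "chan", "corr"), np.complex64),
]
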
Example #3
def test_ms_builder(tmp_path, variables, fixed):
    var_names = set(variables.keys())

    builder = MSDescriptorBuilder(fixed)
    default_desc = builder.default_descriptor()
    tab_desc = builder.descriptor(variables, default_desc)
    dminfo = builder.dminfo(tab_desc)

    # These columns must always be present on an MS
    required_cols = {k for k in pt.required_ms_desc().keys()
                     if not k.startswith('_')}

    filename = str(tmp_path / "test_plugin.ms")

    with pt.table(filename, tab_desc, dminfo=dminfo, ack=False, nrow=10) as T:
        # We got required + the extra columns we asked for

        assert set(T.colnames()) == set.union(var_names, required_cols)

        if fixed:
            original_dminfo = {v['NAME']: v for v in dminfo.values()}
            table_dminfo = {v['NAME']: v for v in T.getdminfo().values()}

            for column in variables.keys():
                try:
                    column_group = table_dminfo[column + "_GROUP"]
                except KeyError:
                    raise ValueError(f"{column} should be fixed but no "
                                     f"Data Manager Group was created")

                assert column in column_group["COLUMNS"]
                assert column_group["TYPE"] == "TiledColumnStMan"

            assert len(original_dminfo) == len(table_dminfo)

            for dm_name, dm_group in table_dminfo.items():
                odm_group = original_dminfo[dm_name]
                assert odm_group['TYPE'] == dm_group['TYPE']
                assert set(odm_group['COLUMNS']) == set(dm_group['COLUMNS'])

                if dm_group['TYPE'] == 'TiledColumnStMan':
                    original_tile = odm_group['SPEC']['DEFAULTTILESHAPE']
                    table_tile = dm_group['SPEC']['DEFAULTTILESHAPE']
                    assert_array_equal(original_tile, table_tile)
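
When fixed is True, the test requires the builder to place each supplied column in its own <column>_GROUP data manager. An illustrative dminfo group for a DATA column (example values, not the builder's exact output):

# Illustrative dminfo group for a fixed DATA column
example_group = {
    "NAME": "DATA_GROUP",          # <column>_GROUP naming checked above
    "TYPE": "TiledColumnStMan",    # tiled storage manager for fixed shapes
    "COLUMNS": ["DATA"],
    "SPEC": {"DEFAULTTILESHAPE": [4, 4096, 16]},  # example tile shape
}
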
Example #4
def test_ms_subtable_builder(tmp_path, table):
    A = da.zeros((10, 20, 30), chunks=(2, 20, 30), dtype=np.int32)
    variables = {"FOO": Variable(("row", "chan", "corr"), A, {})}
    var_names = set(variables.keys())

    builder = MSSubTableDescriptorBuilder(table)
    default_desc = builder.default_descriptor()
    tab_desc = builder.descriptor(variables, default_desc)
    dminfo = builder.dminfo(tab_desc)

    # These columns must always be present on an MS
    required_cols = {
        k
        for k in pt.required_ms_desc(table).keys() if not k.startswith('_')
    }

    filename = str(tmp_path / f"{table}.table")

    with pt.table(filename, tab_desc, dminfo=dminfo, ack=False) as T:
        T.addrows(10)

        # We got required + the extra columns we asked for
        assert set(T.colnames()) == set.union(var_names, required_cols)
Example #5
def kat_ms_desc_and_dminfo(nbl, nchan, ncorr, model_data=False):
    """
    Creates Table Description and Data Manager Information objects that
    describe a MeasurementSet suitable for holding MeerKAT data.

    Creates additional DATA, IMAGING_WEIGHT and possibly
    MODEL_DATA and CORRECTED_DATA columns.

    Columns are given fixed shapes defined by the arguments to this function.

    :param nbl: Number of baselines.
    :param nchan: Number of channels.
    :param ncorr: Number of correlations.
    :param model_data: Boolean indicating whether MODEL_DATA and
                       CORRECTED_DATA should be added to the Measurement Set.
    :return: Returns a tuple containing a table description describing
            the extra columns and hypercolumns, as well as a Data Manager
            description.
    """

    if casacore_binding != 'pyrap':
        raise ValueError("kat_ms_desc_and_dminfo requires the "
                         "casacore binding to operate")

    # Columns that will be modified.
    # We want to keep things like their
    # keywords, dims and shapes
    modify_columns = {
        "WEIGHT", "SIGMA", "FLAG", "FLAG_CATEGORY", "UVW", "ANTENNA1",
        "ANTENNA2"
    }

    # Get the required table descriptor for an MS
    table_desc = tables.required_ms_desc("MAIN")

    # Take columns we wish to modify
    extra_table_desc = {
        c: d
        for c, d in table_desc.items() if c in modify_columns
    }

    # Used to set the SPEC for each Data Manager Group
    dmgroup_spec = {}

    def dmspec(coldesc, tile_mem_limit=None):
        """
        Create data manager spec for a given column description,
        mostly by adding a DEFAULTTILESHAPE that fits into the
        supplied memory limit.
        """

        # Choose 4MB if none given
        if tile_mem_limit is None:
            tile_mem_limit = 4 * 1024 * 1024

        # Get the reversed column shape. DEFAULTTILESHAPE lives deep in
        # casacore and it's necessary to specify the reversed ordering here.
        # ntilerows is the dimension that will change least quickly.
        rev_shape = list(reversed(coldesc["shape"]))

        ntilerows = 1
        np_dtype = MS_TO_NP_TYPE_MAP[coldesc["valueType"].upper()]
        nbytes = np.dtype(np_dtype).itemsize

        # Try to bump up the number of rows in our tiles while they're
        # below the memory limit for the tile
        while np.prod(rev_shape + [2 * ntilerows]) * nbytes < tile_mem_limit:
            ntilerows *= 2

        return {"DEFAULTTILESHAPE": np.int32(rev_shape + [ntilerows])}

    # Update existing columns with shape and data manager information
    dm_group = 'UVW'
    shape = [3]
    extra_table_desc["UVW"].update(options=0,
                                   shape=shape,
                                   ndim=len(shape),
                                   dataManagerGroup=dm_group,
                                   dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["UVW"])

    dm_group = 'Weight'
    shape = [ncorr]
    extra_table_desc["WEIGHT"].update(options=4,
                                      shape=shape,
                                      ndim=len(shape),
                                      dataManagerGroup=dm_group,
                                      dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["WEIGHT"])

    dm_group = 'Sigma'
    shape = [ncorr]
    extra_table_desc["SIGMA"].update(options=4,
                                     shape=shape,
                                     ndim=len(shape),
                                     dataManagerGroup=dm_group,
                                     dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["SIGMA"])

    dm_group = 'Flag'
    shape = [nchan, ncorr]
    extra_table_desc["FLAG"].update(options=4,
                                    shape=shape,
                                    ndim=len(shape),
                                    dataManagerGroup=dm_group,
                                    dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["FLAG"])

    dm_group = 'FlagCategory'
    shape = [1, nchan, ncorr]
    extra_table_desc["FLAG_CATEGORY"].update(
        options=4,
        keywords={},
        shape=shape,
        ndim=len(shape),
        dataManagerGroup=dm_group,
        dataManagerType='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(extra_table_desc["FLAG_CATEGORY"])

    # Create new columns for integration into the MS
    additional_columns = []

    dm_group = 'Data'
    shape = [nchan, ncorr]
    desc = tables.tablecreatearraycoldesc("DATA",
                                          0 + 0j,
                                          comment="The Visibility DATA Column",
                                          options=4,
                                          valuetype='complex',
                                          keywords={"UNIT": "Jy"},
                                          shape=shape,
                                          ndim=len(shape),
                                          datamanagergroup=dm_group,
                                          datamanagertype='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(desc["desc"])
    additional_columns.append(desc)

    dm_group = 'ImagingWeight'
    shape = [nchan]
    desc = tables.tablecreatearraycoldesc(
        "IMAGING_WEIGHT",
        0,
        comment="Weight set by imaging task (e.g. uniform weighting)",
        options=4,
        valuetype='float',
        shape=shape,
        ndim=len(shape),
        datamanagergroup=dm_group,
        datamanagertype='TiledColumnStMan')
    dmgroup_spec[dm_group] = dmspec(desc["desc"])
    additional_columns.append(desc)

    # Add MODEL_DATA and CORRECTED_DATA if requested
    if model_data:
        dm_group = 'ModelData'
        shape = [nchan, ncorr]
        desc = tables.tablecreatearraycoldesc(
            "MODEL_DATA",
            0 + 0j,
            comment="The Visibility MODEL_DATA Column",
            options=4,
            valuetype='complex',
            keywords={"UNIT": "Jy"},
            shape=shape,
            ndim=len(shape),
            datamanagergroup=dm_group,
            datamanagertype='TiledColumnStMan')
        dmgroup_spec[dm_group] = dmspec(desc["desc"])
        additional_columns.append(desc)

        dm_group = 'CorrectedData'
        shape = [nchan, ncorr]
        desc = tables.tablecreatearraycoldesc(
            "CORRECTED_DATA",
            0 + 0j,
            comment="The Visibility CORRECTED_DATA Column",
            options=4,
            valuetype='complex',
            keywords={"UNIT": "Jy"},
            shape=shape,
            ndim=len(shape),
            datamanagergroup=dm_group,
            datamanagertype='TiledColumnStMan')
        dmgroup_spec[dm_group] = dmspec(desc["desc"])
        additional_columns.append(desc)

    # Update extra table description with additional columns
    extra_table_desc.update(tables.maketabdesc(additional_columns))

    # Update the original table descriptor with modifications/additions
    # Need this to construct a complete Data Manager specification
    # that includes the original columns
    table_desc.update(extra_table_desc)

    # Construct DataManager Specification
    dminfo = tables.makedminfo(table_desc, dmgroup_spec)

    return extra_table_desc, dminfo
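
A sketch of how the returned description and data manager info might be used, assuming the casacore binding is imported as tables and that its default_ms helper accepts an extra table description and dminfo:

# Sketch: create a MeerKAT-shaped MS (2016 baselines, 4096 channels,
# 4 correlations); argument values are illustrative
extra_desc, dminfo = kat_ms_desc_and_dminfo(nbl=2016, nchan=4096,
                                            ncorr=4, model_data=False)
ms = tables.default_ms("meerkat.ms", extra_desc, dminfo)
print(ms.colnames())  # includes DATA and IMAGING_WEIGHT
ms.close()
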
Example #6
    def __init__(self, fixed=True):
        super().__init__()
        self.DEFAULT_MS_DESC = pt.required_ms_desc()
        self.REQUIRED_FIELDS = set(self.DEFAULT_MS_DESC.keys())
        self.fixed = fixed
        self.ms_dims = None
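
required_ms_desc() returns the required MAIN table description; its keys are the required column names plus '_'-prefixed metadata entries, which is why the tests above filter those out:

import pyrap.tables as pt

desc = pt.required_ms_desc()
# Required MAIN-table columns, e.g. TIME, ANTENNA1, ANTENNA2, UVW, ...
print(sorted(k for k in desc.keys() if not k.startswith("_")))
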
Example #7
def test_ms_create(Dataset, tmp_path, chunks, num_chans, corr_types, sources):
    # Set up
    rs = np.random.RandomState(42)

    ms_path = tmp_path / "create.ms"

    ms_table_name = str(ms_path)
    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # For comparison
    all_data_desc_id = []
    all_data = []

    # Create ANTENNA dataset of 64 antennas
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = 64
    position = da.random.random((na, 3)) * 10000
    offset = da.random.random((na, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=object)
    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'NAME': (("row", ), da.from_array(names, chunks=na)),
    })
    ant_datasets.append(ds)

    # Create SOURCE datasets
    for s, (name, direction, rest_freq) in enumerate(sources):
        dask_num_lines = da.full((1, ), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object))
        ds = Dataset({
            "NUM_LINES": (("row", ), dask_num_lines),
            "NAME": (("row", ), dask_name),
            "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARIZATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1, ), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row", ), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })

        pol_datasets.append(ds)

    # Create multiple MeerKAT L-band SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1, ), num_chan, dtype=np.int32)
        dask_chan_freq = da.linspace(.856e9,
                                     2 * .856e9,
                                     num_chan,
                                     chunks=num_chan)[None, :]
        dask_chan_width = da.full((1, num_chan), .856e9 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row", ), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })

        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(
        *product(range(len(num_chans)), range(len(corr_types))))
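    # e.g. with 2 SPWs and 2 polarisation setups, product yields
    # spw_ids == (0, 0, 1, 1) and pol_ids == (0, 1, 0, 1)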
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(
        Dataset({
            "SPECTRAL_WINDOW_ID": (("row", ), dask_spw_ids),
            "POLARIZATION_ID": (("row", ), dask_pol_ids),
        }))

    # Now create the associated MS dataset
    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        np_data = (rs.normal(size=(row, chan, corr)) +
                   1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

        data_chunks = (chunks['row'], chan, corr)
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'DATA_DESC_ID': (("row", ), dask_ddid)
        })
        ms_datasets.append(dataset)
        all_data.append(dask_data)
        all_data_desc_id.append(dask_ddid)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes, ant_writes, pol_writes, spw_writes, ddid_writes,
                 source_writes)

    # Check ANTENNA table correctly created
    with pt.table(ant_table_name, ack=False) as A:
        assert_array_equal(A.getcol("NAME"), names)
        assert_array_equal(A.getcol("POSITION"), position)
        assert_array_equal(A.getcol("OFFSET"), offset)

        required_desc = pt.required_ms_desc("ANTENNA")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(A.colnames()) == set(required_columns)

    # Check POLARIZATION table correctly created
    with pt.table(pol_table_name, ack=False) as P:
        for r, corr_type in enumerate(corr_types):
            assert_array_equal(P.getcol("CORR_TYPE", startrow=r, nrow=1),
                               [corr_type])
            assert_array_equal(P.getcol("NUM_CORR", startrow=r, nrow=1),
                               [len(corr_type)])

        required_desc = pt.required_ms_desc("POLARIZATION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(P.colnames()) == set(required_columns)

    # Check SPECTRAL_WINDOW table correctly created
    with pt.table(spw_table_name, ack=False) as S:
        for r, num_chan in enumerate(num_chans):
            assert_array_equal(
                S.getcol("NUM_CHAN", startrow=r, nrow=1)[0], num_chan)
            assert_array_equal(
                S.getcol("CHAN_FREQ", startrow=r, nrow=1)[0],
                np.linspace(.856e9, 2 * .856e9, num_chan))
            assert_array_equal(
                S.getcol("CHAN_WIDTH", startrow=r, nrow=1)[0],
                np.full(num_chan, .856e9 / num_chan))

        required_desc = pt.required_ms_desc("SPECTRAL_WINDOW")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(S.colnames()) == set(required_columns)

    # We should get a cartesian product out
    with pt.table(ddid_table_name, ack=False) as D:
        spw_id, pol_id = zip(
            *product(range(len(num_chans)), range(len(corr_types))))
        assert_array_equal(pol_id, D.getcol("POLARIZATION_ID"))
        assert_array_equal(spw_id, D.getcol("SPECTRAL_WINDOW_ID"))

        required_desc = pt.required_ms_desc("DATA_DESCRIPTION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(D.colnames()) == set(required_columns)

    with pt.table(src_table_name, ack=False) as S:
        for r, (name, direction, rest_freq) in enumerate(sources):
            assert_array_equal(S.getcol("NAME", startrow=r, nrow=1)[0], [name])
            assert_array_equal(S.getcol("REST_FREQUENCY", startrow=r, nrow=1),
                               [rest_freq])
            assert_array_equal(S.getcol("DIRECTION", startrow=r, nrow=1),
                               [direction])

    with pt.table(ms_table_name, ack=False) as T:
        # DATA_DESC_ID's are all the same shape
        assert_array_equal(T.getcol("DATA_DESC_ID"),
                           da.concatenate(all_data_desc_id))

        # DATA is variably shaped (on DATA_DESC_ID) so we
        # compare each one separately.
        for ddid, data in enumerate(all_data):
            ms_data = T.getcol("DATA", startrow=ddid * row, nrow=row)
            assert_array_equal(ms_data, data)

        required_desc = pt.required_ms_desc()
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        # Check we have the required columns
        assert set(T.colnames()) == required_columns.union(
            ["DATA", "DATA_DESC_ID"])