Example #1
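This and the following examples appear to be pytest tests exercising xarrayms' xds_from_ms and xds_to_table against a CASA Measurement Set. The snippets are shown without their module preamble; a minimal set of imports they rely on would look like the sketch below (the xarrayms module path for MS_SCHEMA and ColumnSchema is an assumption based on how they are used, and the orderby_clause and where_clause helpers used later are sketched after Example #6).

# Imports assumed by the examples below; the xarrayms path for
# MS_SCHEMA and ColumnSchema is an assumption, not confirmed by the source
import dask
import dask.array as da
import numpy as np
import xarray as xr
import pyrap.tables as pt  # python-casacore; casacore.tables also works
from unittest.mock import patch

from xarrayms import xds_from_ms, xds_to_table
from xarrayms.table_schemas import MS_SCHEMA, ColumnSchema  # assumed path
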
def test_ms_write(ms, group_cols, index_cols, select_cols):
    # Zero everything to be sure
    with pt.table(ms, readonly=False, lockoptions='auto', ack=False) as table:
        table.putcol("STATE_ID", np.full(table.nrows(), 0, dtype=np.int32))
        data = np.zeros_like(table.getcol("DATA"))
        data_dtype = data.dtype
        table.putcol("DATA", data)

    xds = xds_from_ms(ms,
                      columns=select_cols,
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 2})

    written_states = []
    written_data = []
    writes = []

    # Write out STATE_ID and DATA
    for i, ds in enumerate(xds):
        dims = ds.dims
        chunks = ds.chunks
        state = da.arange(i, i + dims["row"], chunks=chunks["row"])
        written_states.append(state.astype(np.int32))

        data = da.arange(i, i + dims["row"] * dims["chan"] * dims["corr"])
        data = data.reshape(dims["row"], dims["chan"], dims["corr"])
        data = data.rechunk((chunks["row"], chunks["chan"], chunks["corr"]))
        written_data.append(data.astype(data_dtype))

        state = xr.DataArray(state, dims=['row'])
        data = xr.DataArray(data, dims=['row', 'chan', 'corr'])
        nds = ds.assign(STATE_ID=state, DATA=data)
        write = xds_to_table(nds, ms, ["STATE_ID", "DATA"])
        writes.append(write)

    # Do all writes in parallel
    dask.compute(writes)

    xds = xds_from_ms(ms,
                      columns=select_cols,
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 2})

    # Check that state and data have been correctly written
    it = enumerate(zip(xds, written_states, written_data))
    for i, (ds, state, data) in it:
        assert np.all(ds.STATE_ID.data.compute() == state.compute())
        assert np.all(ds.DATA.data.compute() == data.compute())
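The ms, group_cols, index_cols and select_cols arguments are not defined in any snippet; they are presumably pytest fixtures supplied by a conftest. A hypothetical, minimal stand-in might look like this (fixture bodies and column choices are illustrative only, not the project's actual fixtures):

import pytest

@pytest.fixture
def ms(tmp_path):
    # A real conftest would create or copy a small Measurement Set here
    return str(tmp_path / "test.ms")

@pytest.fixture(params=[["FIELD_ID", "DATA_DESC_ID"]])
def group_cols(request):
    return request.param

@pytest.fixture(params=[["TIME", "ANTENNA1", "ANTENNA2"]])
def index_cols(request):
    return request.param

@pytest.fixture(params=[["STATE_ID", "DATA"]])
def select_cols(request):
    return request.param
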
Example #2
def test_multireadwrite(ms, group_cols, index_cols):
    xds = xds_from_ms(ms, group_cols=group_cols, index_cols=index_cols)

    nds = [ds.copy() for ds in xds]
    writes = [xds_to_table(sds, ms, sds.data_vars.keys()) for sds in nds]

    # Do all writes in parallel
    dask.compute(writes)
Example #3
def _proc_map_fn(args):
    # Runs in a worker process: bump STATE_ID for the i'th group
    # and write it back synchronously
    ms, i = args
    xds = xds_from_ms(ms, columns=["STATE_ID"], group_cols=["FIELD_ID"])
    xds[i] = xds[i].assign(STATE_ID=xds[i].STATE_ID + i)
    write = xds_to_table(xds[i], ms, ["STATE_ID"])
    write.compute(scheduler='sync')
    return True
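_proc_map_fn takes a single (ms, i) tuple, which suggests it is meant to be mapped over a process pool so that several processes write to the same MS concurrently. A sketch of a driver, with the pool size and group count as assumptions:

from multiprocessing import Pool

def run_parallel_writes(ms, n_groups, n_procs=4):
    # Each worker process bumps STATE_ID for one FIELD_ID group
    with Pool(n_procs) as pool:
        results = pool.map(_proc_map_fn, [(ms, i) for i in range(n_groups)])
    assert all(results)
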
Example #4
def test_table_schema(ms, group_cols, index_cols):
    # Test default MS Schema
    xds = xds_from_ms(ms,
                      columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "chan", "corr")

    # Test custom column schema specified by ColumnSchema object
    table_schema = MS_SCHEMA.copy()
    table_schema['DATA'] = ColumnSchema(("my-chan", "my-corr"))

    xds = xds_from_ms(ms,
                      columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      table_schema=table_schema,
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "my-chan", "my-corr")

    # Test custom column schema specified by tuple object
    table_schema['DATA'] = ("my-chan", "my-corr")

    xds = xds_from_ms(ms,
                      columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      table_schema=table_schema,
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "my-chan", "my-corr")

    # Test custom column schema specified by a list combining
    # a registered schema name ("MS") with a dict of overrides
    table_schema = {"DATA": ("my-chan", "my-corr")}

    xds = xds_from_ms(ms,
                      columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      table_schema=["MS", table_schema],
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "my-chan", "my-corr")
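For orientation, the ColumnSchema object used above only needs to carry the non-row dimension names of a column; conceptually it could be as simple as the namedtuple below (the real xarrayms definition may differ):

from collections import namedtuple

# Conceptual stand-in for xarrayms' ColumnSchema
ColumnSchema = namedtuple("ColumnSchema", ["dims"])

schema = ColumnSchema(("my-chan", "my-corr"))
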
Example #5
def test_row_query(ms, index_cols):
    xds = xds_from_ms(ms,
                      columns=index_cols,
                      group_cols="__row__",
                      index_cols=index_cols,
                      chunks={"row": 2})

    with pt.table(ms, readonly=False, ack=False) as table:
        # Get the expected row ordering by lexically
        # sorting the indexing columns
        cols = [(name, table.getcol(name)) for name in index_cols]
        expected_rows = np.lexsort(tuple(c for n, c in reversed(cols)))

        assert len(xds) == table.nrows()

        for ds, expected_row in zip(xds, expected_rows):
            assert ds.table_row == expected_row
Example #6
def test_ms_read(ms, group_cols, index_cols, select_cols):
    xds = xds_from_ms(ms,
                      columns=select_cols,
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 2})

    order = orderby_clause(index_cols)

    with pt.table(ms, lockoptions='auto', ack=False) as T:  # noqa
        for ds in xds:
            group_col_values = [getattr(ds, c) for c in group_cols]
            where = where_clause(group_cols, group_col_values)
            query = "SELECT * FROM $T %s %s" % (where, order)

            with pt.taql(query) as Q:
                for c in select_cols:
                    np_data = Q.getcol(c)
                    dask_data = getattr(ds, c).data.compute()
                    assert np.all(np_data == dask_data)
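test_ms_read above and test_fragmented_ms below splice the results of where_clause and orderby_clause into TaQL SELECT statements. These helpers are not shown in the snippets; plausible implementations, inferred purely from how their output is used, would be:

# Inferred sketches, not the project's actual helpers
def orderby_clause(index_cols):
    # ["TIME", "ANTENNA1"] -> "ORDERBY TIME, ANTENNA1"
    return "ORDERBY " + ", ".join(index_cols) if index_cols else ""

def where_clause(group_cols, group_values):
    # (["FIELD_ID"], [0]) -> "WHERE FIELD_ID=0"
    if not group_cols:
        return ""
    terms = ["%s=%s" % (c, v) for c, v in zip(group_cols, group_values)]
    return "WHERE " + " AND ".join(terms)
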
Example #7
def test_unfragmented_ms(ms, group_cols, index_cols):
    from xarrayms.xarray_ms import get_row_runs
    patch_target = "xarrayms.xarray_ms.get_row_runs"

    def mock_row_runs(*args, **kwargs):
        """ Calls get_row_runs and does some testing """
        row_runs, row_resorts = get_row_runs(*args, **kwargs)
        # Do some checks to ensure that fragmentation was handled
        assert kwargs['min_frag_level'] is False
        assert row_resorts.compute() is None
        return row_runs, row_resorts

    with patch(patch_target, side_effect=mock_row_runs) as patch_fn:
        xds = xds_from_ms(
            ms,
            columns=index_cols,  # noqa
            group_cols=group_cols,
            index_cols=index_cols,
            min_frag_level=False,
            chunks={"row": 1e9})

        # Check that get_row_runs was patched and called exactly once.
        # Mock has no `called_once_with` method (asserting on it would
        # always pass), so test call_count instead.
        assert patch_fn.call_count == 1
Example #8
def test_fragmented_ms(ms, group_cols, index_cols):
    select_cols = index_cols + ["STATE_ID"]

    # Zero everything to be sure
    with pt.table(ms, readonly=False, lockoptions='auto', ack=False) as table:
        table.putcol("STATE_ID", np.full(table.nrows(), 0, dtype=np.int32))

    # Patch the get_row_runs function to check that it is called
    # and that resorting is invoked. An unfragmented MS has a
    # fragmentation ratio of 1.00, so a threshold just below that
    # forces the fragmentation-handling path.
    min_frag_level = 0.9999
    from xarrayms.xarray_ms import get_row_runs
    patch_target = "xarrayms.xarray_ms.get_row_runs"

    def mock_row_runs(*args, **kwargs):
        """ Calls get_row_runs and does some testing """
        row_runs, row_resorts = get_row_runs(*args, **kwargs)
        # Do some checks to ensure that fragmentation was handled
        assert kwargs['min_frag_level'] == min_frag_level
        assert isinstance(row_resorts.compute(), np.ndarray)
        return row_runs, row_resorts

    with patch(patch_target, side_effect=mock_row_runs) as patch_fn:
        xds = xds_from_ms(ms,
                          columns=select_cols,
                          group_cols=group_cols,
                          index_cols=index_cols,
                          min_frag_level=min_frag_level,
                          chunks={"row": 1e9})

    # Check that mock_row_runs was called exactly once. Mock has no
    # `called_once_with` method (asserting on it would always pass),
    # so test call_count instead.
    assert patch_fn.call_count == 1

    order = orderby_clause(index_cols)
    written_states = []

    with pt.table(ms, readonly=True, lockoptions='auto', ack=False) as table:
        for i, ds in enumerate(xds):
            group_col_values = [getattr(ds, c) for c in group_cols]
            where = where_clause(group_cols, group_col_values)
            query = "SELECT * FROM $table %s %s" % (where, order)

            # Check that each column is correctly read
            with pt.taql(query) as Q:
                for c in select_cols:
                    np_data = Q.getcol(c)
                    dask_data = getattr(ds, c).data.compute()
                    assert np.all(np_data == dask_data)

            # Now write some data to the STATE_ID column
            state = da.arange(i, i + ds.dims['row'], chunks=ds.chunks['row'])
            written_states.append(state)
            state = xr.DataArray(state, dims=['row'])
            nds = ds.assign(STATE_ID=state)

            with patch(patch_target, side_effect=mock_row_runs) as patch_fn:
                xds_to_table(nds,
                             ms,
                             "STATE_ID",
                             min_frag_level=min_frag_level).compute()

            assert patch_fn.call_count == 1

    # Check that state has been correctly written
    xds = list(
        xds_from_ms(ms,
                    columns=select_cols,
                    group_cols=group_cols,
                    index_cols=index_cols,
                    min_frag_level=min_frag_level,
                    chunks={"row": 1e9}))

    for i, (ds, expected) in enumerate(zip(xds, written_states)):
        assert np.all(ds.STATE_ID.data.compute() == expected)