def test_ms_write(ms, group_cols, index_cols, select_cols):
    # Zero everything to be sure
    with pt.table(ms, readonly=False, lockoptions='auto', ack=False) as table:
        table.putcol("STATE_ID", np.full(table.nrows(), 0, dtype=np.int32))
        data = np.zeros_like(table.getcol("DATA"))
        data_dtype = data.dtype
        table.putcol("DATA", data)

    xds = xds_from_ms(ms, columns=select_cols,
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 2})

    written_states = []
    written_data = []
    writes = []

    # Write out STATE_ID and DATA
    for i, ds in enumerate(xds):
        dims = ds.dims
        chunks = ds.chunks

        state = da.arange(i, i + dims["row"], chunks=chunks["row"])
        written_states.append(state.astype(np.int32))

        data = da.arange(i, i + dims["row"] * dims["chan"] * dims["corr"])
        data = data.reshape(dims["row"], dims["chan"], dims["corr"])
        data = data.rechunk((chunks["row"], chunks["chan"], chunks["corr"]))
        written_data.append(data.astype(data_dtype))

        state = xr.DataArray(state, dims=['row'])
        data = xr.DataArray(data, dims=['row', 'chan', 'corr'])
        nds = ds.assign(STATE_ID=state, DATA=data)
        write = xds_to_table(nds, ms, ["STATE_ID", "DATA"])
        writes.append(write)

    # Do all writes in parallel
    dask.compute(writes)

    xds = xds_from_ms(ms, columns=select_cols,
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 2})

    # Check that state and data have been correctly written
    it = enumerate(zip(xds, written_states, written_data))

    for i, (ds, state, data) in it:
        assert np.all(ds.STATE_ID.data.compute() == state.compute())
        assert np.all(ds.DATA.data.compute() == data.compute())


def test_multireadwrite(ms, group_cols, index_cols):
    xds = xds_from_ms(ms, group_cols=group_cols, index_cols=index_cols)

    nds = [ds.copy() for ds in xds]
    writes = [xds_to_table(sds, ms, sds.data_vars.keys()) for sds in nds]

    da.compute(writes)


def _proc_map_fn(args):
    ms, i = args
    xds = xds_from_ms(ms, columns=["STATE_ID"], group_cols=["FIELD_ID"])
    xds[i] = xds[i].assign(STATE_ID=xds[i].STATE_ID + i)
    write = xds_to_table(xds[i], ms, ["STATE_ID"])
    write.compute(scheduler='sync')

    return True


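# NOTE: A minimal sketch (an editorial addition, not part of the original
# suite) of how _proc_map_fn above might be driven from a multiprocessing
# Pool. The helper uses the synchronous dask scheduler so that each worker
# process performs its own table writes, and returns True so the parent can
# confirm completion. The function name and the (ms, field_id) argument
# tuples below are illustrative assumptions, not an API defined by xarrayms.
def _example_multiprocess_state_update(ms, nfields=2):
    from multiprocessing import Pool

    with Pool(processes=nfields) as pool:
        # One (ms, i) tuple per FIELD_ID group, matching the argument
        # unpacking performed inside _proc_map_fn
        results = pool.map(_proc_map_fn, [(ms, i) for i in range(nfields)])

    assert all(results)

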
def test_table_schema(ms, group_cols, index_cols):
    # Test default MS schema
    xds = xds_from_ms(ms, columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "chan", "corr")

    # Test custom column schema specified by a ColumnSchema object
    table_schema = MS_SCHEMA.copy()
    table_schema['DATA'] = ColumnSchema(("my-chan", "my-corr"))

    xds = xds_from_ms(ms, columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      table_schema=table_schema,
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "my-chan", "my-corr")

    # Test custom column schema specified by a plain tuple
    table_schema['DATA'] = ("my-chan", "my-corr")

    xds = xds_from_ms(ms, columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      table_schema=table_schema,
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "my-chan", "my-corr")

    # Test a list combining the "MS" schema name with a custom schema dict
    table_schema = {"DATA": ("my-chan", "my-corr")}

    xds = xds_from_ms(ms, columns=["DATA"],
                      group_cols=group_cols,
                      index_cols=index_cols,
                      table_schema=["MS", table_schema],
                      chunks={"row": 1e9})

    assert xds[0].DATA.dims == ("row", "my-chan", "my-corr")


def test_row_query(ms, index_cols):
    xds = xds_from_ms(ms, columns=index_cols,
                      group_cols="__row__",
                      index_cols=index_cols,
                      chunks={"row": 2})

    with pt.table(ms, readonly=False, ack=False) as table:
        # Get the expected row ordering by lexically
        # sorting the indexing columns
        cols = [(name, table.getcol(name)) for name in index_cols]
        expected_rows = np.lexsort(tuple(c for n, c in reversed(cols)))

        assert len(xds) == table.nrows()

        for ds, expected_row in zip(xds, expected_rows):
            assert ds.table_row == expected_row


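# NOTE: A small standalone illustration (an editorial addition, not original
# test code) of the np.lexsort idiom used in test_row_query above. np.lexsort
# treats its *last* key as the primary sort key, which is why the index
# columns are reversed before being passed in. The toy TIME/ANTENNA1 values
# below are made up purely for illustration.
def _example_lexsort_ordering():
    time = np.array([2.0, 1.0, 1.0, 2.0])
    antenna1 = np.array([1, 1, 0, 0])

    # Primary key TIME, secondary key ANTENNA1, mirroring
    # np.lexsort(tuple(c for n, c in reversed(cols)))
    rows = np.lexsort((antenna1, time))

    assert rows.tolist() == [2, 1, 3, 0]

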
def test_ms_read(ms, group_cols, index_cols, select_cols):
    xds = xds_from_ms(ms, columns=select_cols,
                      group_cols=group_cols,
                      index_cols=index_cols,
                      chunks={"row": 2})

    order = orderby_clause(index_cols)

    with pt.table(ms, lockoptions='auto', ack=False) as T:  # noqa
        for ds in xds:
            group_col_values = [getattr(ds, c) for c in group_cols]
            where = where_clause(group_cols, group_col_values)
            query = "SELECT * FROM $T %s %s" % (where, order)

            with pt.taql(query) as Q:
                for c in select_cols:
                    np_data = Q.getcol(c)
                    dask_data = getattr(ds, c).data.compute()
                    assert np.all(np_data == dask_data)


def test_unfragmented_ms(ms, group_cols, index_cols):
    from xarrayms.xarray_ms import get_row_runs
    patch_target = "xarrayms.xarray_ms.get_row_runs"

    def mock_row_runs(*args, **kwargs):
        """ Calls get_row_runs and does some testing """
        row_runs, row_resorts = get_row_runs(*args, **kwargs)

        # Check that no fragmentation handling was required
        assert kwargs['min_frag_level'] is False
        assert row_resorts.compute() is None

        return row_runs, row_resorts

    with patch(patch_target, side_effect=mock_row_runs) as patch_fn:
        xds = xds_from_ms(ms, columns=index_cols,  # noqa
                          group_cols=group_cols,
                          index_cols=index_cols,
                          min_frag_level=False,
                          chunks={"row": 1e9})

    # Check that the patched get_row_runs was invoked
    assert patch_fn.called


def test_fragmented_ms(ms, group_cols, index_cols):
    select_cols = index_cols + ["STATE_ID"]

    # Zero everything to be sure
    with pt.table(ms, readonly=False, lockoptions='auto', ack=False) as table:
        table.putcol("STATE_ID", np.full(table.nrows(), 0, dtype=np.int32))

    # Patch the get_row_runs function to check that it is called
    # and that resorting is invoked.
    # Unfragmented is 1.00, so induce fragmentation handling
    min_frag_level = 0.9999

    from xarrayms.xarray_ms import get_row_runs
    patch_target = "xarrayms.xarray_ms.get_row_runs"

    def mock_row_runs(*args, **kwargs):
        """ Calls get_row_runs and does some testing """
        row_runs, row_resorts = get_row_runs(*args, **kwargs)

        # Do some checks to ensure that fragmentation was handled
        assert kwargs['min_frag_level'] == min_frag_level
        assert isinstance(row_resorts.compute(), np.ndarray)

        return row_runs, row_resorts

    with patch(patch_target, side_effect=mock_row_runs) as patch_fn:
        xds = xds_from_ms(ms, columns=select_cols,
                          group_cols=group_cols,
                          index_cols=index_cols,
                          min_frag_level=min_frag_level,
                          chunks={"row": 1e9})

    # Check that mock_row_runs was called
    assert patch_fn.called

    order = orderby_clause(index_cols)
    written_states = []

    with pt.table(ms, readonly=True, lockoptions='auto', ack=False) as table:
        for i, ds in enumerate(xds):
            group_col_values = [getattr(ds, c) for c in group_cols]
            where = where_clause(group_cols, group_col_values)
            query = "SELECT * FROM $table %s %s" % (where, order)

            # Check that each column is correctly read
            with pt.taql(query) as Q:
                for c in select_cols:
                    np_data = Q.getcol(c)
                    dask_data = getattr(ds, c).data.compute()
                    assert np.all(np_data == dask_data)

            # Now write some data to the STATE_ID column
            state = da.arange(i, i + ds.dims['row'], chunks=ds.chunks['row'])
            written_states.append(state)
            state = xr.DataArray(state, dims=['row'])
            nds = ds.assign(STATE_ID=state)

            with patch(patch_target, side_effect=mock_row_runs) as patch_fn:
                xds_to_table(nds, ms, "STATE_ID",
                             min_frag_level=min_frag_level).compute()

            # Check that mock_row_runs was called for the write
            assert patch_fn.called

    # Check that state has been correctly written
    xds = list(xds_from_ms(ms, columns=select_cols,
                           group_cols=group_cols,
                           index_cols=index_cols,
                           min_frag_level=min_frag_level,
                           chunks={"row": 1e9}))

    for i, (ds, expected) in enumerate(zip(xds, written_states)):
        assert np.all(ds.STATE_ID.data.compute() == expected)