Example #1
def test_storage_ms(ms):
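    # Round-trip check: write the datasets straight back to the same MS and
    # verify that re-reading produces identical datasets.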

    oxdsl = xds_from_ms(ms)

    writes = xds_to_storage_table(oxdsl, ms)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_ms(ms))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
Example #2
    def _derive_row_chunking(self, args):
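        # Only the TIME and INTERVAL columns are needed here to derive a row
        # chunking consistent with the requested time bin size.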
        datasets = xds_from_ms(args.ms,
                               group_cols=GROUP_COLS,
                               columns=["TIME", "INTERVAL"],
                               chunks={'row': args.row_chunks})

        return dataset_chunks(datasets, args.time_bin_secs, args.row_chunks)
Example #3
def test_add_datavars(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
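    # Add a new DUMMY data variable (zeros shaped like DATA) and rewrite the
    # existing store with rechunk=True.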
    augmented_datasets = [ds.assign({"DUMMY": (("row", "chan", "corr"),
                                    da.zeros_like(ds.DATA.data))})
                          for ds in rechunked_datasets]
    dask.compute(xds_to_zarr(augmented_datasets, store, rechunk=True))

    augmented_datasets = xds_from_zarr(store)

    assert all([ds.DUMMY.chunks == cds.DATA.chunks
                for ds, cds in zip(augmented_datasets, chunked_datasets)])
Example #4
def test_rechunking(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
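    # Rewrite the rechunked datasets into the existing store with
    # rechunk=True and check the data still matches the reference datasets.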
    dask.compute(xds_to_zarr(rechunked_datasets, store, rechunk=True))

    rechunked_datasets = xds_from_zarr(store)

    assert all([ds.equals(rds)
                for ds, rds in zip(rechunked_datasets, ref_datasets)])
Example #5
def test_keyword_write(ms):
    datasets = xds_from_ms(ms)

    # Add to table keywords
    writes = xds_to_table([], ms, [], table_keywords={'bob': 'qux'})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getkeywords()['bob'] == 'qux'

    # Add to column keywords
    writes = xds_to_table(datasets,
                          ms, [],
                          column_keywords={'STATE_ID': {
                              'bob': 'qux'
                          }})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getcolkeywords("STATE_ID")['bob'] == 'qux'

    # Remove from column and table keywords
    from daskms.writes import DELKW
    writes = xds_to_table(datasets,
                          ms, [],
                          table_keywords={'bob': DELKW},
                          column_keywords={'STATE_ID': {
                              'bob': DELKW
                          }})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert 'bob' not in T.getkeywords()
        assert 'bob' not in T.getcolkeywords("STATE_ID")
Example #6
def test_keyword_read(keyword_ms, table_kw, column_kw):
    # Create an example MS
    with pt.table(keyword_ms, ack=False, readonly=True) as T:
        desc = T._getdesc(actual=True)

    ret = xds_from_ms(keyword_ms,
                      table_keywords=table_kw,
                      column_keywords=column_kw)

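    # When table_keywords/column_keywords are requested, xds_from_ms returns
    # a tuple with the keyword dictionaries appended; otherwise just a list.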
    if isinstance(ret, tuple):
        ret_pos = 1

        if table_kw is True:
            assert desc["_keywords_"] == ret[ret_pos]
            ret_pos += 1

        if column_kw is True:
            colkw = ret[ret_pos]

            for column, keywords in colkw.items():
                assert desc[column]['keywords'] == keywords

            ret_pos += 1
    else:
        assert table_kw is False
        assert column_kw is False
        assert isinstance(ret, list)
Example #7
    def _input_datasets(self, args, row_chunks):
        # Set up row chunks
        chunks = [{'row': rc} for rc in row_chunks]

        main_ds, tabkw, colkw = xds_from_ms(args.ms,
                                            group_cols=GROUP_COLS,
                                            table_keywords=True,
                                            column_keywords=True,
                                            chunks=chunks)

        # Figure out non SPW + SORTED sub-tables to just copy
        subtables = {
            k
            for k, v in tabkw.items()
            if k not in ("SPECTRAL_WINDOW", "SORTED_TABLE")
            and isinstance(v, str) and v.startswith("Table: ")
        }

        spw_ds = xds_from_table("::".join((args.ms, "SPECTRAL_WINDOW")),
                                group_cols="__row__")

        field_ds = xds_from_table("::".join((args.ms, "FIELD")),
                                  group_cols="__row__")

        ddid_ds = xds_from_table("::".join((args.ms, "DATA_DESCRIPTION")))
        assert len(ddid_ds) == 1
        ddid_ds = dask.compute(ddid_ds)[0]

        return main_ds, spw_ds, ddid_ds, field_ds, subtables
Example #8
def test_cached_array(ms):
    ds = xds_from_ms(ms, group_cols=[], chunks={'row': 1, 'chan': 4})[0]

    data = ds.DATA.data
    cached_data = cached_array(data)
    assert_array_almost_equal(cached_data, data)

    # 2 x row blocks + row x chan x corr blocks
    assert len(_key_cache) == data.numblocks[0] * 2 + data.npartitions
    # rows, row runs and data array caches
    assert len(_array_cache_cache) == 3

    # Pickling works
    pickled_data = pickle.loads(pickle.dumps(cached_data))
    assert_array_almost_equal(pickled_data, data)

    # Same underlying caching is re-used
    # 2 x row blocks + row x chan x corr blocks
    assert len(_key_cache) == data.numblocks[0] * 2 + data.npartitions
    # rows, row runs and data array caches
    assert len(_array_cache_cache) == 3

    del pickled_data, cached_data, data, ds
    gc.collect()

    assert len(_key_cache) == 0
    assert len(_array_cache_cache) == 0
Example #9
def predict(args):
    # Convert source data into dask arrays
    sky_model = parse_sky_model(args.sky_model, args.model_chunks)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]

        # Select single dataset row out
        corrs = pol.NUM_CORR.data[0]

        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Generate visibility expressions for each source type
        source_vis = [
            vis_factory(args, stype, sky_model, time_index, xds, field, spw,
                        pol) for stype in sky_model.keys()
        ]

        # Sum visibilities together
        vis = sum(source_vis)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(MODEL_DATA=(("row", "chan", "corr"), vis))
        # Create a write to the table
        write = xds_to_table(xds, args.ms, ['MODEL_DATA'])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        dask.compute(writes)
Example #10
def test_ms_create_and_update(Dataset, tmp_path, chunks):
    """ Test that we can update and append at the same time """
    filename = str(tmp_path / "create-and-update.ms")

    rs = np.random.RandomState(42)

    # Create a dataset of 10 rows with DATA and DATA_DESC_ID
    dims = ("row", "chan", "corr")
    row, chan, corr = tuple(sum(chunks[d]) for d in dims)
    ms_datasets = []
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 0, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row", ), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    ms_datasets = xds_from_ms(filename)

    # Now add another dataset (different DDID), with no ROWID
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)
    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 1, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row", ), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    # Rows have been added and additional data is present
    with pt.table(filename, ack=False, readonly=True) as T:
        first_data_desc_id = da.full(row,
                                     ms_datasets[0].DATA_DESC_ID,
                                     chunks=chunks['row'])
        ds_data = da.concatenate(
            [ms_datasets[0].DATA.data, ms_datasets[1].DATA.data])
        ds_ddid = da.concatenate(
            [first_data_desc_id, ms_datasets[1].DATA_DESC_ID.data])
        assert_array_equal(T.getcol("DATA"), ds_data)
        assert_array_equal(T.getcol("DATA_DESC_ID"), ds_ddid)
Example #11
    def _derive_row_chunking(self, args):
        datasets = xds_from_ms(args.ms,
                               group_cols=GROUP_COLS,
                               columns=["TIME", "INTERVAL", "UVW"],
                               chunks={'row': args.row_chunks},
                               taql_where=args.taql_where)

        return dataset_chunks(datasets,
                              args.time_bin_secs,
                              args.row_chunks,
                              bda=args.command)
Example #12
def test_xds_to_zarr(ms, spw_table, ant_table, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"
    spw_store = zarr_store.parent / f"{zarr_store.name}::SPECTRAL_WINDOW"
    ant_store = zarr_store.parent / f"{zarr_store.name}::ANTENNA"

    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(
            **{
                "chan": (("chan", ), np.arange(chan)),
                "corr": (("corr", ), np.arange(corr)),
            })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_zarr(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Dataset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
Example #13
def test_write_table_proxy_keyword(ms):
    datasets = xds_from_ms(ms)

    # Test that we get a TableProxy if requested
    writes, tp = xds_to_table(datasets, ms, [], table_proxy=True)
    assert isinstance(writes, list) and isinstance(writes[0], Dataset)
    assert isinstance(tp, TableProxy)
    assert tp.nrows().result() == 10

    writes = xds_to_table(datasets, ms, [], table_proxy=False)
    assert isinstance(writes, list) and isinstance(writes[0], Dataset)
Example #14
    def convolve(self, x):
        # print("Applying Hessian", file=log)
        x = da.from_array(x.astype(self.real_type),
                          chunks=(1, self.nx, self.ny),
                          name=False)

        convolvedims = []
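        # Apply the Hessian operator to the model cube for every dataset in
        # every MS and accumulate the per-band results.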
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID

                bands = self.band_mapping[ims][spw]
                model = x[list(bands), :, :]
                convolvedim = hessian(self.uvws[ims][spw],
                                      self.freq[ims][spw],
                                      model,
                                      self.freq_bin_idx[ims][spw],
                                      self.freq_bin_counts[ims][spw],
                                      self.cell,
                                      weights=self.stokes_weights[ims][spw],
                                      nthreads=self.nthreads // self.nband,
                                      epsilon=self.epsilon,
                                      do_wstacking=self.do_wstacking,
                                      double_accum=True)

                convolvedims.append(convolvedim)

        convolvedims = dask.compute(convolvedims)[0]

        return accumulate_dirty(convolvedims, self.nband,
                                self.band_mapping).astype(self.real_type)
Example #15
def zarr_tester(ms, spw_table, ant_table,
                zarr_store, spw_store, ant_store):

    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(**{
            "chan": (("chan",), np.arange(chan)),
            "corr": (("corr",), np.arange(corr)),
        })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store.url,
                                   storage_options=zarr_store.storage_options)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_storage_ms(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Dataset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
Example #16
    def write_model(self, x):
        print("Writing model data")
        x = da.from_array(x.astype(np.float32), chunks=(1, self.nx, self.ny))
        writes = []
        for ims in self.ms:
            xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=('MODEL_DATA', 'UVW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            out_data = []
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                model_vis = getattr(ds, 'MODEL_DATA').data

                bands = self.band_mapping[ims][spw]
                model = x[list(bands), :, :]
                vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                             self.cell, nthreads=self.nthreads, epsilon=self.epsilon,
                             do_wstacking=self.do_wstacking)

                model_vis = populate_model(vis, model_vis)
                
                out_ds = ds.assign(**{self.model_column: (("row", "chan", "corr"),
                                                          model_vis)})
                out_data.append(out_ds)
            writes.append(xds_to_table(out_data, ims, columns=[self.model_column]))
        dask.compute(writes, scheduler='single-threaded')
Example #17
    def write_component_model(self, comps, ref_freq, mask, row_chunks, chan_chunks):
        print("Writing model data at full freq resolution")
        order, npix = comps.shape
        comps = da.from_array(comps, chunks=(-1, -1))
        mask = da.from_array(mask.squeeze(), chunks=(-1, -1))
        writes = []
        for ims in self.ms:
            xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks={'row':(row_chunks,), 'chan':(chan_chunks,)},
                              columns=('MODEL_DATA', 'UVW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            pols = dask.compute(pols)[0]

            out_data = []
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = spws[ds.DATA_DESC_ID]
                freq = spw.CHAN_FREQ.data.squeeze()
                freq_bin_idx = da.arange(0, freq.size, 1, chunks=freq.chunks, dtype=np.int64)
                freq_bin_counts = da.ones(freq.size, chunks=freq.chunks, dtype=np.int64)

                uvw = ds.UVW.data

                model_vis = getattr(ds, 'MODEL_DATA').data

                model = model_from_comps(comps, freq, mask, ref_freq)
                
                vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                             self.cell, nthreads=self.nthreads, epsilon=self.epsilon,
                             do_wstacking=self.do_wstacking)

                model_vis = populate_model(vis, model_vis)
                
                out_ds = ds.assign(**{self.model_column: (("row", "chan", "corr"),
                                                          model_vis)})
                out_data.append(out_ds)
            writes.append(xds_to_table(out_data, ims, columns=[self.model_column]))
        dask.compute(writes, scheduler='single-threaded')
Example #18
def test_xarray_to_zarr(ms, tmp_path_factory):
    store = tmp_path_factory.mktemp("zarr_store")
    datasets = xds_from_ms(ms)
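    # Assign integer coordinates, then write each dataset to its own zarr
    # store using xarray's native to_zarr.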

    for i, ds in enumerate(datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        datasets[i] = ds.assign_coords(row=np.arange(row),
                                       chan=np.arange(chan),
                                       corr=np.arange(corr))

    for i, ds in enumerate(datasets):
        ds.to_zarr(str(store / f"ds-{i}.zarr"))
Example #19
def test_unified_schema(ms):
    datasets = xds_from_ms(ms)
    assert len(datasets) == 3

    from daskms.experimental.arrow.arrow_schema import ArrowSchema

    schema = ArrowSchema.from_datasets(datasets)

    for ds in datasets:
        for column, var in ds.data_vars.items():
            s = schema.data_vars[column]
            assert s.dims == var.dims[1:]
            assert s.shape == var.shape[1:]
            assert np.dtype(s.dtype) == var.dtype
            assert isinstance(var.data, s.type)

    schema.to_arrow_schema()
Example #20
def test_xds_to_parquet(ms, tmp_path_factory, spw_table, ant_table):
    store = tmp_path_factory.mktemp("parquet_store") / "out.parquet"
    # antenna_store = store.parent / f"{store.name}::ANTENNA"
    # spw_store = store.parent / f"{store.name}::SPECTRAL_WINDOW"

    datasets = xds_from_ms(ms)

    # We can test row chunking if xarray is installed
    if xarray is not None:
        datasets = [ds.chunk({"row": 1}) for ds in datasets]

    # spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    # ant_datasets = xds_from_table(ant_table, group_cols="__row__")

    writes = []
    writes.extend(xds_to_parquet(datasets, store))
    # TODO(sjperkins)
    # Fix arrow shape unification errors
    # writes.extend(xds_to_parquet(spw_datasets, spw_store))
    # writes.extend(xds_to_parquet(ant_datasets, antenna_store))
    dask.compute(writes)

    pq_datasets = xds_from_parquet(store, chunks={"row": 1})
    assert len(datasets) == len(pq_datasets)

    for ds, pq_ds in zip(datasets, pq_datasets):
        for column, var in ds.data_vars.items():
            pq_var = getattr(pq_ds, column)
            assert_array_equal(var.data, pq_var.data)
            assert var.dims == pq_var.dims

        for column, var in ds.coords.items():
            pq_var = getattr(pq_ds, column)
            assert_array_equal(var.data, pq_var.data)
            assert var.dims == pq_var.dims

        partitions = ds.attrs[DASKMS_PARTITION_KEY]
        pq_partitions = pq_ds.attrs[DASKMS_PARTITION_KEY]
        assert partitions == pq_partitions

        for field, dtype in partitions:
            assert getattr(ds, field) == getattr(pq_ds, field)
Example #21
def test_expressions(ms):
    datasets = xds_from_ms(ms)

    for i, ds in enumerate(datasets):
        dims = ds.DATA.dims
        datasets[i] = ds.assign(DIR1_DATA=(dims, ds.DATA.data),
                                DIR2_DATA=(dims, ds.DATA.data),
                                DIR3_DATA=(dims, ds.DATA.data))

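    # Expected result computed directly with dask arrays, for comparison
    # against the string expression evaluated by data_column_expr below.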
    results = [
        ds.DATA.data /
        (-ds.DIR1_DATA.data + ds.DIR2_DATA.data + ds.DIR3_DATA.data) * 4
        for ds in datasets
    ]

    string = "DATA / (-DIR1_DATA + DIR2_DATA + DIR3_DATA)*4"
    expressions = data_column_expr(string, datasets)

    for i, (ds, expr) in enumerate(zip(datasets, expressions)):
        assert_array_equal(results[i], expr)
Example #22
def test_github_98():
    ms = "/home/sperkins/data/AF0236_spw01.ms/"

    if not os.path.exists(ms):
        pytest.skip("AF0236_spw01.ms on which this "
                    "test depends is not present")

    datasets = xds_from_ms(ms,
                           columns=['DATA', 'ANTENNA1', 'ANTENNA2'],
                           group_cols=['DATA_DESC_ID'],
                           taql_where='ANTENNA1 == 5 || ANTENNA2 == 5')

    assert len(datasets) == 2
    assert datasets[0].DATA_DESC_ID == 0
    assert datasets[1].DATA_DESC_ID == 1

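    # The TAQL filter should leave only rows in which antenna 5 appears on
    # one side of the baseline.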
    for ds in datasets:
        expr = da.logical_or(ds.ANTENNA1.data == 5, ds.ANTENNA2.data == 5)
        expr, equal = dask.compute(expr, da.all(expr))
        assert equal.item() is True
        assert len(expr) > 0
Example #23
def test_storage_parquet(ms, tmp_path_factory):

    parquet_store = tmp_path_factory.mktemp("parquet") / "test.parquet"

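    # Write the MS datasets to parquet, read them back and write again via
    # xds_to_storage_table, then compare against a fresh read.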
    oxdsl = xds_from_ms(ms)

    writes = xds_to_parquet(oxdsl, parquet_store)

    dask.compute(writes)

    oxdsl = xds_from_parquet(parquet_store)

    writes = xds_to_storage_table(oxdsl, parquet_store)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_parquet(parquet_store))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
Example #24
def test_storage_zarr(ms, tmp_path_factory):

    zarr_store = tmp_path_factory.mktemp("zarr") / "test.zarr"

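    # Write the MS datasets to zarr, read them back and write again via
    # xds_to_storage_table, then compare against a fresh read.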
    oxdsl = xds_from_ms(ms)

    writes = xds_to_zarr(oxdsl, zarr_store)

    dask.compute(writes)

    oxdsl = xds_from_zarr(zarr_store)

    writes = xds_to_storage_table(oxdsl, zarr_store)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_zarr(zarr_store))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
Example #25
def test_multiprocess_create(ms, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"

    ms_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ms_datasets):
        ms_datasets[i] = ds.chunk({"row": 1})

    writes = xds_to_zarr(ms_datasets, zarr_store)

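    # Computing the writes with the multiprocessing scheduler checks that the
    # zarr store can be written from separate processes.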
    dask.compute(writes, scheduler="processes")

    zarr_datasets = xds_from_zarr(zarr_store)

    for zds, msds in zip(zarr_datasets, ms_datasets):
        for k, v in msds.data_vars.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.coords.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.attrs.items():
            assert_array_equal(v, getattr(zds, k))
Example #26
def _jones2col(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from daskms.experimental.zarr import xds_from_zarr
    from daskms import xds_from_ms, xds_to_table
    import dask.array as da
    import dask
    from africanus.calibration.utils import chunkify_rows
    from africanus.calibration.utils.dask import corrupt_vis

    # get net gains
    G = xds_from_zarr(args.gain_table + '::G')

    # chunking info
    t_chunks = G[0].t_chunk.data
    if len(t_chunks) > 1:
        t_chunks = G[0].t_chunk.data[1:-1] - G[0].t_chunk.data[0:-2]
        assert (t_chunks == t_chunks[0]).all()
        utpc = t_chunks[0]
    else:
        utpc = t_chunks[0]
    times = xds_from_ms(args.ms[0], columns=['TIME'])[0].get('TIME').data.compute()
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(times, utimes_per_chunk=utpc, daskify_idx=True)

    f_chunks = G[0].f_chunk.data
    if len(f_chunks) > 1:
        f_chunks = G[0].f_chunk.data[1:-1] - G[0].f_chunk.data[0:-2]
        assert (f_chunks == f_chunks[0]).all()
        chan_chunks = f_chunks[0]
    else:
        if f_chunks[0]:
            chan_chunks = f_chunks[0]
        else:
            chan_chunks = -1

    columns = ('DATA', 'FLAG', 'FLAG_ROW', 'ANTENNA1', 'ANTENNA2')
    if args.acol is not None:
        columns += (args.acol,)

    # open MS
    xds = xds_from_ms(args.ms[0], chunks={'row': row_chunks, 'chan': chan_chunks},
                      columns=columns,
                      group_cols=('FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'))

    # Current hack probably only works for single field and DDID
    if len(xds) != len(G):
        raise ValueError("Number of datasets in the gain table does not "
                         "match the number of datasets in the MS")

    # assuming scans are aligned
    out_data = []
    for g, ds in zip(G, xds):
        if g.SCAN_NUMBER != ds.SCAN_NUMBER:
            raise ValueError("Scans are not aligned")

        nrow = ds.dims['row']
        nchan = ds.dims['chan']
        ncorr = ds.dims['corr']

        # need to swap axes for africanus
        jones = da.swapaxes(g.gains.data, 1, 2)
        flag = ds.FLAG.data
        frow = ds.FLAG_ROW.data
        ant1 = ds.ANTENNA1.data
        ant2 = ds.ANTENNA2.data

        frow = (frow | (ant1 == ant2))
        flag = (flag[:, :, 0] | flag[:, :, -1])
        flag = da.logical_or(flag, frow[:, None])

        if args.acol is not None:
            acol = ds.get(args.acol).data.reshape(nrow, nchan, 1, ncorr)
        else:
            acol = da.ones((nrow, nchan, 1, ncorr),
                           chunks=(row_chunks, chan_chunks, 1, -1),
                           dtype=jones.dtype)

        cvis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones, acol)

        # compare where unflagged
        if args.compareto is not None:
            flag = flag.compute()
            vis = ds.get(args.compareto).values[~flag]
            print("Max abs difference = ", np.abs(cvis.compute()[~flag] - vis).max())
            quit()

        out_ds = ds.assign(**{args.mueller_column: (("row", "chan", "corr"), cvis)})
        out_data.append(out_ds)

    writes = xds_to_table(out_data, args.ms[0], columns=[args.mueller_column])
    dask.compute(writes)
Example #27
def main(args):
    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),  # trailing comma: one-element tuple
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list(tmp_freq))

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)[0]
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                optimise_chunks=True,
                data_column=args.data_column,
                weight_column=args.weight_column,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        compare_headers(hdr_psf, fits.getheader(args.psf))
        psf_array = load_fits(args.psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    psf_max[psf_max < 1e-15] = 1e-15

    if args.dirty is not None:
        compare_headers(hdr, fits.getheader(args.dirty))
        dirty = load_fits(args.dirty)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= psf_max_mean

    # mfs residual
    wsum = np.sum(psf_max)
    dirty_mfs = np.sum(dirty, axis=0) / wsum
    rmax = np.abs(dirty_mfs).max()
    rms = np.std(dirty_mfs)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)
        if mask.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((args.nx, args.ny), dtype=np.int64)

    if args.point_mask is not None:
        pmask = load_fits(args.point_mask, dtype=bool)
        if pmask.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        pmask = None

    # Reporting
    print("At iteration 0 peak of residual is %f and rms is %f" % (rmax, rms))
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # set up point sources
    phi = Dirac(args.nband, args.nx, args.ny, mask=pmask)
    dual = np.zeros((args.nband, args.nx, args.ny), dtype=np.float64)
    weights_21 = np.where(phi.mask, 1, np.inf)

    # preconditioning matrix
    def hess(beta):
        return phi.hdot(psf.convolve(
            phi.dot(beta))) + beta / args.sig_l2**2  # vague prior on beta

    # get new spectral norm
    L = power_method(hess, dirty.shape, tol=args.pmtol, maxit=args.pmmaxit)

    # deconvolve
    eps = 1.0
    i = 0
    residual = dirty.copy()
    model = np.zeros(dirty.shape, dtype=dirty.dtype)
    for i in range(1, args.maxit):
        # find point source candidates
        if args.do_clean:
            model_tmp = hogbom(mask[None] * residual / psf_max[:, None, None],
                               psf_array / psf_max[:, None, None],
                               gamma=args.cgamma,
                               pf=args.peak_factor)
            phi.update_locs(np.any(model_tmp, axis=0))
            # get new spectral norm
            L = power_method(hess,
                             model.shape,
                             tol=args.pmtol,
                             maxit=args.pmmaxit)
        else:
            model_tmp = np.zeros_like(residual, dtype=residual.dtype)

        # solve for beta updates
        x = pcg(hess,
                phi.hdot(residual),
                phi.hdot(model_tmp),
                M=lambda x: x * args.sig_l2**2,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        modelp = model.copy()
        model += args.gamma * x

        # impose sparsity and positivity in point sources
        weights_21 = np.where(phi.mask, 1, 1e10)  # 1e10 for effective infinity
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21,
                                  phi,
                                  weights_21,
                                  L,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  axis=0,
                                  positivity=args.positivity,
                                  report_freq=100)

        # update Dirac dictionary (remove zero components)
        phi.trim_fat(model)

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(mask * residual_mfs).max()
        rms = np.std(mask * residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits',
                      residual / psf_max[:, None, None], hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

        if eps < args.tol:
            print("We have convergence!")
            break

    # final iteration with only a positivity constraint on pixel locs
    tmp = phi.hdot(model)
    x = pcg(hess,
            phi.hdot(residual),
            np.zeros_like(tmp, dtype=tmp.dtype),
            M=lambda x: x * args.sig_l2**2,
            tol=args.cgtol,
            maxit=args.cgmaxit,
            verbosity=args.cgverbose)

    modelp = model.copy()
    model += args.gamma * x
    model, dual = primal_dual(hess,
                              model,
                              modelp,
                              dual,
                              0.0,
                              phi,
                              weights_21,
                              L,
                              tol=args.pdtol,
                              maxit=args.pdmaxit,
                              axis=0,
                              report_freq=100)

    # get residual
    residual = R.make_residual(model) / psf_max_mean

    # check stopping criteria
    residual_mfs = np.sum(residual, axis=0) / wsum
    rmax = np.abs(mask * residual_mfs).max()
    rms = np.std(mask * residual_mfs)
    print("At final iteration peak of residual is %f and rms is %f" %
          (rmax, rms))

    save_fits(args.outfile + '_model.fits', model, hdr)

    model_mfs = np.mean(model, axis=0)
    save_fits(args.outfile + '_model_mfs.fits', model_mfs, hdr_mfs)

    save_fits(args.outfile + '_residual.fits',
              residual / psf_max[:, None, None], hdr)

    save_fits(args.outfile + '_residual_mfs.fits', residual_mfs, hdr_mfs)

    if args.interp_model:
        nband = args.nband
        order = args.spectral_poly_order
        phi.trim_fat(model)
        I = np.argwhere(phi.mask).squeeze()
        Ix = I[:, 0]
        Iy = I[:, 1]
        npix = I.shape[0]

        # get components
        beta = model[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for i in range(1, args.spectral_poly_order + 1):
            Xdesign[:, i - 1] = (whigh**i - wlow**i) / (i * wdiff)

        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        np.savez(args.outfile + "spectral_comps",
                 comps=comps,
                 ref_freq=ref_freq,
                 mask=np.any(model, axis=0))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, phi.mask, args.row_chunks,
                                    args.chan_chunks)
        else:
            R.write_model(model)
Example #28
def _predict(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms.utils import dataset_type
    mstype = dataset_type(ms[0])
    if mstype == 'casa':
        from daskms import xds_to_table
    elif mstype == 'zarr':
        from daskms.experimental.zarr import xds_to_zarr as xds_to_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import model as im2vis
    from pfb.utils.fits import load_fits
    from pfb.utils.misc import restore_corrs, plan_row_chunk
    from astropy.io import fits

    # always returns 4D
    # gridder expects freq axis
    model = np.atleast_3d(load_fits(args.model).squeeze())
    nband, nx, ny = model.shape
    hdr = fits.getheader(args.model)
    cell_d = np.abs(hdr['CDELT1'])
    cell_rad = np.deg2rad(cell_d)

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # degridder memory budget
    max_chan_chunk = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())

    # assumes number of correlations are the same across MS/SPW
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    if args.output_type is not None:
        output_type = np.dtype(args.output_type)
    else:
        output_type = np.result_type(np.dtype(args.real_type), np.complex64)
    data_bytes = output_type.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row  # model
    memory_per_row += 3 * 8  # uvw

    if mstype == 'zarr':
        if args.model_column in xds[0].keys():
            model_chunks = getattr(xds[0], args.model_column).data.chunks
        else:
            model_chunks = xds[0].DATA.data.chunks
            print('Chunking model same as data')

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    model = da.from_array(model.astype(args.real_type),
                          chunks=(1, nx, ny),
                          name=False)
    writes = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=('UVW',))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match (radec is only set from the first field)
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            uvw = clone(ds.UVW.data)

            bands = band_mapping[ims][spw]
            # select only the bands mapped to this SPW, keeping the full
            # model cube intact for subsequent datasets
            model_b = model[list(bands), :, :]
            vis = im2vis(uvw,
                         freqs[ims][spw],
                         model_b,
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         cell_rad,
                         nthreads=ngridder_threads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack)

            model_vis = restore_corrs(vis, ncorr)
            if mstype == 'zarr':
                model_vis = model_vis.rechunk(model_chunks)
                uvw = uvw.rechunk((model_chunks[0], 3))

            out_ds = ds.assign(
                **{
                    args.model_column: (("row", "chan", "corr"), model_vis),
                    'UVW': (("row", "three"), uvw)
                })
            # out_ds = ds.assign(**{args.model_column: (("row", "chan", "corr"), model_vis)})
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims, columns=[args.model_column]))

    dask.visualize(*writes,
                   filename=args.output_filename + '_predict_graph.pdf',
                   optimize_graph=False,
                   collapse_outputs=True)

    if not args.mock:
        with performance_report(filename=args.output_filename +
                                '_predict_per.html'):
            dask.compute(writes, optimize_graph=False)

    print("All done here.", file=log)
Example #29
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((args.ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
            print(ant_ds)
            print(
                dask.compute(ant_ds.NAME.data, ant_ds.POSITION.data,
                             ant_ds.DISH_DIAMETER.data))

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((args.ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            print(spw_ds)
            print(spw_ds.NUM_CHAN.values)
            print(spw_ds.CHAN_FREQ.values)

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(args.ms, chunks={'row': args.chunks}))

        for ds in datasets:
            print(ds)

            # Try writing the STATE_ID column back
            write = xds_to_table(ds, args.ms, 'STATE_ID')
            with ProgressBar(), Profiler() as prof:
                write.compute()

            # Profile
            prof.visualize(file_path="chunked.html")
Example #30
    def make_psf(self):
        print("Making PSF", file=log)
        psfs = []
        self.stokes_weights = {}
        self.uvws = {}
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]
            self.stokes_weights[ims] = {}
            self.uvws[ims] = {}

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                # this is not correct, need to use spw
                spw = ds.DATA_DESC_ID

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                flag = getattr(ds, self.flag_column).data

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              flag.shape,
                                              chunks=flag.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            flag.shape,
                            chunks=flag.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # for the PSF we need to scale the weights by the
                # Mueller amplitudes squared
                if self.mueller_column is not None:
                    mueller = getattr(ds, self.mueller_column).data
                    weightsxx *= da.absolute(mueller[:, :, 0])**2
                    weightsyy *= da.absolute(mueller[:, :, -1])**2

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy

                # only keep data where both corrs are unflagged
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~(flagxx | flagyy)  # ducc0 convention

                weights *= flag

                data = weights.astype(np.complex64)

                psf = vis2im(uvw,
                             freq,
                             data,
                             freq_bin_idx,
                             freq_bin_counts,
                             self.nx_psf,
                             self.ny_psf,
                             self.cell,
                             flag=flag.astype(np.uint8),
                             nthreads=self.nthreads,
                             epsilon=self.epsilon,
                             do_wstacking=self.do_wstacking,
                             double_accum=True)

                psfs.append(psf)

                # assumes that stokes weights and uvw fit into memory
                # self.stokes_weights[ims][spw] = dask.persist(weights.rechunk({0:-1}))[0]
                # self.uvws[ims][spw] = dask.persist(uvw.rechunk({0:-1}))[0]

                # for comparison with numpy implementation
                # self.stokes_weights[ims][spw] = dask.compute(weights)[0]
                # self.uvws[ims][spw] = dask.compute(uvw)[0]

        psfs = dask.compute(psfs, scheduler='single-threaded')[0]
        return accumulate_dirty(psfs, self.nband,
                                self.band_mapping).astype(self.real_type)