Code Example #1
def test_add_datavars(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    augmented_datasets = [ds.assign({"DUMMY": (("row", "chan", "corr"),
                                    da.zeros_like(ds.DATA.data))})
                          for ds in rechunked_datasets]
    dask.compute(xds_to_zarr(augmented_datasets, store, rechunk=True))

    augmented_datasets = xds_from_zarr(store)

    assert all([ds.DUMMY.chunks == cds.DATA.chunks
                for ds, cds in zip(augmented_datasets, chunked_datasets)])
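Note: the test excerpts in this listing omit their module-level imports and fixtures. The sketch below shows what they appear to assume; the `ms`, `spw_table` and `ant_table` fixtures (paths to a small test Measurement Set and its subtables) come from the project's test setup, and the `prechunking`/`postchunking` parametrisation shown here is a hypothetical illustration rather than the project's actual values.

import dask
import dask.array as da
import numpy as np
import pytest
from numpy.testing import assert_array_equal

from daskms import Dataset, xds_from_ms, xds_from_table, xds_from_storage_ms
# Import path of the partition key constant may differ between dask-ms versions.
from daskms.constants import DASKMS_PARTITION_KEY
from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr


# Hypothetical chunking parametrisation, for illustration only.
@pytest.fixture(params=[{"row": 2, "chan": 4, "corr": 1}])
def prechunking(request):
    return request.param


@pytest.fixture(params=[{"row": 3, "chan": 2, "corr": 1}])
def postchunking(request):
    return request.param
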
Code Example #2
def test_rechunking(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    dask.compute(xds_to_zarr(rechunked_datasets, store, rechunk=True))

    rechunked_datasets = xds_from_zarr(store)

    assert all([ds.equals(rds)
                for ds, rds in zip(rechunked_datasets, ref_datasets)])
Code Example #3
def test_basic_roundtrip(tmp_path):

    path = tmp_path / "test.zarr"

    # We need >10 datasets to be sure roundtripping is consistent.
    xdsl = [Dataset({'x': (('y',), da.ones(i))}) for i in range(1, 12)]
    dask.compute(xds_to_zarr(xdsl, path))

    xdsl = xds_from_zarr(path)
    dask.compute(xds_to_zarr(xdsl, path))
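The ">10 datasets" comment above presumably guards against lexicographic ordering of the on-disk group names ("10" would sort before "2"). A small follow-on check of that property, continuing from the test above and assuming the per-dataset sizes written there (1 through 11 elements) should come back in write order:

xdsl = xds_from_zarr(path)
sizes = [ds.x.data.shape[0] for ds in xdsl]
assert sizes == list(range(1, 12)), sizes  # datasets return in write order
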
Code Example #4
def test_xds_to_zarr(ms, spw_table, ant_table, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"
    spw_store = zarr_store.parent / f"{zarr_store.name}::SPECTRAL_WINDOW"
    ant_store = zarr_store.parent / f"{zarr_store.name}::ANTENNA"

    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(
            **{
                "chan": (("chan", ), np.arange(chan)),
                "corr": (("corr", ), np.arange(corr)),
            })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_zarr(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Datset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
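The `::SPECTRAL_WINDOW` and `::ANTENNA` suffixes on `spw_store` and `ant_store` follow the CASA-style subtable naming, placing the subtables alongside the main store. A hedged sketch of reading one back, assuming `xds_from_zarr` accepts the same `::`-suffixed path used for writing (`CHAN_FREQ` is just an example SPECTRAL_WINDOW column):

spw_read = xds_from_zarr(spw_store)
assert len(spw_read) == len(spw_datasets)

for spw_ds, ref_ds in zip(spw_read, spw_datasets):
    assert_array_equal(spw_ds.CHAN_FREQ.data, ref_ds.CHAN_FREQ.data)
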
Code Example #5
def zarr_tester(ms, spw_table, ant_table,
                zarr_store, spw_store, ant_store):

    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(**{
            "chan": (("chan",), np.arange(chan)),
            "corr": (("corr",), np.arange(corr)),
        })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store.url,
                                   storage_options=zarr_store.storage_options)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_storage_ms(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Datset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
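`zarr_store` in this variant is a store object carrying a URL plus fsspec `storage_options`, so the same tester can run against local and remote backends. A minimal sketch of the equivalent direct read, assuming the storage readers accept the same `storage_options` keyword that `xds_to_zarr` receives above:

zarr_datasets = xds_from_storage_ms(
    zarr_store.url,                                 # e.g. "s3://bucket/test.zarr"
    storage_options=zarr_store.storage_options,     # fsspec credentials etc.
    chunks={"row": 1},
)
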
Code Example #6
def test_xds_to_zarr_coords(tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_coords") / "test.zarr"

    data = da.ones((100, 16, 4), chunks=(10, 4, 1), dtype=np.complex64)
    rowid = da.arange(100, chunks=10)

    data_vars = {"DATA": (("row", "chan", "corr"), data)}
    coords = {
        "ROWID": (("row",), rowid),
        "chan": (("chan",), np.arange(16)),
        "foo": (("foo",), np.arange(4)),
    }

    ds = [Dataset(data_vars, coords=coords)]

    writes = xds_to_zarr(ds, zarr_store)
    dask.compute(writes)

    rds = xds_from_zarr(zarr_store)
    assert len(ds) == len(rds)

    for ods, nds in zip(ds, rds):
        for c, v in ods.data_vars.items():
            assert_array_equal(v.data, getattr(nds, c).data)

        for c, v in ods.coords.items():
            assert_array_equal(v.data, getattr(nds, c).data)
Code Example #7
def test_zarr_string_array(tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("string-arrays") / "test.zarr"

    data = ["hello", "this", "strange new world",
            "full of", "interesting", "stuff"]
    data = np.array(data, dtype=object).reshape(3, 2)
    data = da.from_array(data, chunks=((2, 1), (1, 1)))

    datasets = [Dataset({"DATA": (("x", "y"), data)})]
    writes = xds_to_zarr(datasets, zarr_store)
    dask.compute(writes)

    new_datasets = xds_from_zarr(zarr_store)

    assert len(new_datasets) == len(datasets)

    for nds, ds in zip(new_datasets, datasets):
        assert_array_equal(nds.DATA.data, ds.DATA.data)
Code Example #8
def xds_to_storage_table(xds, store, **kwargs):
    if not isinstance(store, DaskMSStore):
        store = DaskMSStore(store, **kwargs.pop("storage_options", {}))

    typ = store.type()

    if typ == "casa":
        filter_kwargs(xds_to_table, kwargs)
        return xds_to_table(xds, store, **kwargs)
    elif typ == "zarr":
        from daskms.experimental.zarr import xds_to_zarr
        filter_kwargs(xds_to_zarr, kwargs)
        return xds_to_zarr(xds, store, **kwargs)
    elif typ == "parquet":
        from daskms.experimental.arrow import xds_to_parquet
        filter_kwargs(xds_to_parquet, kwargs)
        return xds_to_parquet(xds, store, **kwargs)
    else:
        raise TypeError(f"Unknown dataset {typ}")
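A short usage sketch of the dispatch above, writing back to a store that already holds zarr data (cf. Code Example #9); the path is illustrative. `storage_options` is popped off and handed to `DaskMSStore`, and the remaining keyword arguments are filtered per backend before the call is routed to `xds_to_zarr`, `xds_to_parquet` or `xds_to_table`.

# Hypothetical existing zarr store; DaskMSStore detects its type and the
# write is dispatched to xds_to_zarr.
xdsl = xds_from_zarr("existing_output.zarr")
writes = xds_to_storage_table(xdsl, "existing_output.zarr",
                              storage_options={})   # consumed by DaskMSStore
dask.compute(writes)
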
Code Example #9
def test_storage_zarr(ms, tmp_path_factory):

    zarr_store = tmp_path_factory.mktemp("zarr") / "test.zarr"

    oxdsl = xds_from_ms(ms)

    writes = xds_to_zarr(oxdsl, zarr_store)

    dask.compute(writes)

    oxdsl = xds_from_zarr(zarr_store)

    writes = xds_to_storage_table(oxdsl, zarr_store)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_zarr(zarr_store))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
Code Example #10
def test_multiprocess_create(ms, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"

    ms_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ms_datasets):
        ms_datasets[i] = ds.chunk({"row": 1})

    writes = xds_to_zarr(ms_datasets, zarr_store)

    dask.compute(writes, scheduler="processes")

    zarr_datasets = xds_from_zarr(zarr_store)

    for zds, msds in zip(zarr_datasets, ms_datasets):
        for k, v in msds.data_vars.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.coords.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.attrs.items():
            assert_array_equal(v, getattr(zds, k))
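The processes scheduler above exercises store creation from separate worker processes (the {"row": 1} chunking forces many small, independent writes). The same write graph can also be computed with other dask schedulers; a sketch, not part of the test, using the optional dask.distributed package:

# Threaded scheduler
dask.compute(writes, scheduler="threads")

# Or a local distributed cluster (requires dask.distributed)
from dask.distributed import Client

with Client(processes=True):   # hypothetical local cluster
    dask.compute(writes)
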
Code Example #11
File: psf.py  Project: ratt-ru/pfb-clean
def _psf(**kw):
    # `OmegaConf` and `log` are assumed to be available at module level in
    # psf.py; this excerpt only shows the imports made inside the function.
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    (freqs, freq_bin_idx, freq_bin_counts, freq_out,
     band_mapping, chan_chunks) = chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # the data column is not actually read into memory, just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                f"Super resolution factor = {cell_N / cell_rad}")
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data_shape,
                                                      chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                # why no 'f4'?
                'WEIGHT': (('row', 'chan'),
                           weights.rechunk({0: args.row_out_chunk})),
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)