def test_add_datavars(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    augmented_datasets = [ds.assign({"DUMMY": (("row", "chan", "corr"),
                                               da.zeros_like(ds.DATA.data))})
                          for ds in rechunked_datasets]
    dask.compute(xds_to_zarr(augmented_datasets, store, rechunk=True))

    augmented_datasets = xds_from_zarr(store)

    assert all(ds.DUMMY.chunks == cds.DATA.chunks
               for ds, cds in zip(augmented_datasets, chunked_datasets))
def test_rechunking(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    dask.compute(xds_to_zarr(rechunked_datasets, store, rechunk=True))

    rechunked_datasets = xds_from_zarr(store)

    assert all(ds.equals(rds)
               for ds, rds in zip(rechunked_datasets, ref_datasets))
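# Both rechunking tests above rely on `prechunking` and `postchunking`
# fixtures supplied elsewhere in the suite. A minimal sketch of how such
# fixtures might be parametrized is given below; the specific chunk sizes
# are illustrative assumptions, not the suite's actual values.
import pytest


@pytest.fixture(params=[{"row": 2, "chan": 4}], ids=str)
def prechunking(request):
    # Chunking applied before the initial zarr write.
    return request.param


@pytest.fixture(params=[{"row": 3, "chan": 5}], ids=str)
def postchunking(request):
    # Chunking applied after reading back, before the rechunk=True write.
    return request.param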
def test_basic_roundtrip(tmp_path):
    path = tmp_path / "test.zarr"
    # We need >10 datasets so that group names cross the single- to
    # double-digit boundary, exercising consistent ordering on read.
    xdsl = [Dataset({"x": (("y",), da.ones(i))}) for i in range(1, 12)]
    dask.compute(xds_to_zarr(xdsl, path))
    xdsl = xds_from_zarr(path)
    dask.compute(xds_to_zarr(xdsl, path))
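# Presumably the >10 dataset requirement above guards against lexicographic
# group ordering: zarr group names are strings, and a naive string sort
# interleaves double-digit names with single-digit ones. A toy demonstration
# of the hazard:
def demo_lexicographic_ordering():
    names = sorted(str(i) for i in range(12))
    # "10" and "11" sort before "2" when compared as strings.
    assert names[:5] == ["0", "1", "10", "11", "2"]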
def test_xds_to_zarr(ms, spw_table, ant_table, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"
    spw_store = zarr_store.parent / f"{zarr_store.name}::SPECTRAL_WINDOW"
    ant_store = zarr_store.parent / f"{zarr_store.name}::ANTENNA"

    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(**{
            "chan": (("chan",), np.arange(chan)),
            "corr": (("corr",), np.arange(corr)),
        })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_zarr(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Dataset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
def zarr_tester(ms, spw_table, ant_table, zarr_store, spw_store, ant_store):
    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(**{
            "chan": (("chan",), np.arange(chan)),
            "corr": (("corr",), np.arange(corr)),
        })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store.url,
                                   storage_options=zarr_store.storage_options)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_storage_ms(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Dataset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
def test_xds_to_zarr_coords(tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_coords") / "test.zarr"

    data = da.ones((100, 16, 4), chunks=(10, 4, 1), dtype=np.complex64)
    rowid = da.arange(100, chunks=10)

    data_vars = {"DATA": (("row", "chan", "corr"), data)}
    coords = {
        "ROWID": (("row",), rowid),
        "chan": (("chan",), np.arange(16)),
        "foo": (("foo",), np.arange(4)),
    }

    ds = [Dataset(data_vars, coords=coords)]

    writes = xds_to_zarr(ds, zarr_store)
    dask.compute(writes)

    rds = xds_from_zarr(zarr_store)
    assert len(ds) == len(rds)

    for ods, nds in zip(ds, rds):
        for c, v in ods.data_vars.items():
            assert_array_equal(v.data, getattr(nds, c).data)

        for c, v in ods.coords.items():
            assert_array_equal(v.data, getattr(nds, c).data)
def test_zarr_string_array(tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("string-arrays") / "test.zarr"

    data = ["hello", "this", "strange new world",
            "full of", "interesting", "stuff"]
    data = np.array(data, dtype=object).reshape(3, 2)
    data = da.from_array(data, chunks=((2, 1), (1, 1)))

    datasets = [Dataset({"DATA": (("x", "y"), data)})]

    writes = xds_to_zarr(datasets, zarr_store)
    dask.compute(writes)

    new_datasets = xds_from_zarr(zarr_store)

    assert len(new_datasets) == len(datasets)

    for nds, ds in zip(new_datasets, datasets):
        assert_array_equal(nds.DATA.data, ds.DATA.data)
def xds_to_storage_table(xds, store, **kwargs):
    if not isinstance(store, DaskMSStore):
        store = DaskMSStore(store, **kwargs.pop("storage_options", {}))

    typ = store.type()

    if typ == "casa":
        filter_kwargs(xds_to_table, kwargs)
        return xds_to_table(xds, store, **kwargs)
    elif typ == "zarr":
        from daskms.experimental.zarr import xds_to_zarr
        filter_kwargs(xds_to_zarr, kwargs)
        return xds_to_zarr(xds, store, **kwargs)
    elif typ == "parquet":
        from daskms.experimental.arrow import xds_to_parquet
        filter_kwargs(xds_to_parquet, kwargs)
        return xds_to_parquet(xds, store, **kwargs)
    else:
        raise TypeError(f"Unknown dataset {typ}")
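# `filter_kwargs` above is assumed to prune keyword arguments that the
# selected writer does not accept, so kwargs aimed at one backend do not
# break another. A minimal sketch of such a helper, assuming inspect-based
# signature checks (the real daskms helper may differ, e.g. by warning
# about what it drops):
import inspect


def filter_kwargs_sketch(fn, kwargs):
    # Mutates kwargs in place, dropping keys absent from fn's signature.
    valid = set(inspect.signature(fn).parameters)
    for key in [k for k in kwargs if k not in valid]:
        kwargs.pop(key)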
def test_storage_zarr(ms, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr") / "test.zarr"

    oxdsl = xds_from_ms(ms)

    writes = xds_to_zarr(oxdsl, zarr_store)
    dask.compute(writes)

    oxdsl = xds_from_zarr(zarr_store)

    writes = xds_to_storage_table(oxdsl, zarr_store)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_zarr(zarr_store))[0]

    assert all(xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl))
def test_multiprocess_create(ms, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"

    ms_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ms_datasets):
        ms_datasets[i] = ds.chunk({"row": 1})

    writes = xds_to_zarr(ms_datasets, zarr_store)
    dask.compute(writes, scheduler="processes")

    zarr_datasets = xds_from_zarr(zarr_store)

    for zds, msds in zip(zarr_datasets, ms_datasets):
        for k, v in msds.data_vars.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.coords.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.attrs.items():
            assert_array_equal(v, getattr(zds, k))
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, \
        chan_chunks = chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])

    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column,
               args.weight_column,
               args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW',)
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column,)
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column,)
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError("Requested cell size too small. "
                             "Super resolution factor = %f"
                             % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks per node"
          % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({'row': row_chunk,
                                'chan': chan_chunks[ims][spw]['chan']})

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape, chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :],
                        data_shape, chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data ==
                                           ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row',),
                             da.full_like(ds.TIME.data, ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row',),
                                 da.full_like(ds.TIME.data, ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT': (('row', 'chan'),
                           weights.rechunk({0: args.row_out_chunk})),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan',), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets, args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename +
    #                '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename +
    #                '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # with performance_report(filename=args.output_filename +
        #                         '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits', psf, hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits', psf_mfs, hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
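# The row-chunk planning in _psf reduces to simple arithmetic: from the
# per-worker memory limit, subtract the image-band footprint, then divide
# what remains by the per-row visibility footprint, scaled by the number of
# threads reading rows concurrently. A hypothetical standalone version of
# that calculation; the real pfb.utils.misc.plan_row_chunk may differ in
# detail.
def plan_row_chunk_sketch(mem_limit_gb, band_size, nrow,
                          memory_per_row, nthreads):
    # Bytes left over once the image band is resident in memory.
    available = mem_limit_gb * 1e9 - band_size
    # Rows that fit in the remainder, with nthreads rows in flight at once.
    rows = int(available / (memory_per_row * nthreads))
    # Clamp to a sane range: at least one row, at most the full table.
    return max(1, min(nrow, rows))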