def test_storage_ms(ms):
    oxdsl = xds_from_ms(ms)

    writes = xds_to_storage_table(oxdsl, ms)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_ms(ms))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
def _derive_row_chunking(self, args):
    datasets = xds_from_ms(args.ms, group_cols=GROUP_COLS,
                           columns=["TIME", "INTERVAL"],
                           chunks={'row': args.row_chunks})

    return dataset_chunks(datasets, args.time_bin_secs, args.row_chunks)
def test_add_datavars(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    augmented_datasets = [ds.assign({"DUMMY": (("row", "chan", "corr"),
                                               da.zeros_like(ds.DATA.data))})
                          for ds in rechunked_datasets]
    dask.compute(xds_to_zarr(augmented_datasets, store, rechunk=True))

    augmented_datasets = xds_from_zarr(store)

    assert all([ds.DUMMY.chunks == cds.DATA.chunks
                for ds, cds in zip(augmented_datasets, chunked_datasets)])
def test_rechunking(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    dask.compute(xds_to_zarr(rechunked_datasets, store, rechunk=True))

    rechunked_datasets = xds_from_zarr(store)

    assert all([ds.equals(rds)
                for ds, rds in zip(rechunked_datasets, ref_datasets)])
def test_keyword_write(ms):
    datasets = xds_from_ms(ms)

    # Add to table keywords
    writes = xds_to_table([], ms, [], table_keywords={'bob': 'qux'})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getkeywords()['bob'] == 'qux'

    # Add to column keywords
    writes = xds_to_table(datasets, ms, [],
                          column_keywords={'STATE_ID': {'bob': 'qux'}})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getcolkeywords("STATE_ID")['bob'] == 'qux'

    # Remove from column and table keywords
    from daskms.writes import DELKW
    writes = xds_to_table(datasets, ms, [],
                          table_keywords={'bob': DELKW},
                          column_keywords={'STATE_ID': {'bob': DELKW}})
    dask.compute(writes)

    with pt.table(ms, ack=False, readonly=True) as T:
        assert 'bob' not in T.getkeywords()
        assert 'bob' not in T.getcolkeywords("STATE_ID")
def test_keyword_read(keyword_ms, table_kw, column_kw):
    # Create an example MS
    with pt.table(keyword_ms, ack=False, readonly=True) as T:
        desc = T._getdesc(actual=True)

    ret = xds_from_ms(keyword_ms,
                      table_keywords=table_kw,
                      column_keywords=column_kw)

    if isinstance(ret, tuple):
        ret_pos = 1

        if table_kw is True:
            assert desc["_keywords_"] == ret[ret_pos]
            ret_pos += 1

        if column_kw is True:
            colkw = ret[ret_pos]

            for column, keywords in colkw.items():
                assert desc[column]['keywords'] == keywords

            ret_pos += 1
    else:
        assert table_kw is False
        assert column_kw is False
        assert isinstance(ret, list)
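# Hedged usage sketch (not part of the test suite above): when both keyword
# flags are True, xds_from_ms returns a (datasets, table_keywords,
# column_keywords) tuple, as exercised in test_keyword_read above and in
# _input_datasets below. "example.ms" is a placeholder path.
def example_read_with_keywords(ms="example.ms"):
    datasets, tabkw, colkw = xds_from_ms(ms,
                                         table_keywords=True,
                                         column_keywords=True)
    # tabkw is a dict of table keywords; colkw maps column name -> keyword dict
    return datasets, tabkw, colkw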
def _input_datasets(self, args, row_chunks):
    # Set up row chunks
    chunks = [{'row': rc} for rc in row_chunks]

    main_ds, tabkw, colkw = xds_from_ms(args.ms,
                                        group_cols=GROUP_COLS,
                                        table_keywords=True,
                                        column_keywords=True,
                                        chunks=chunks)

    # Figure out non SPW + SORTED sub-tables to just copy
    subtables = {k for k, v in tabkw.items()
                 if k not in ("SPECTRAL_WINDOW", "SORTED_TABLE")
                 and isinstance(v, str)
                 and v.startswith("Table: ")}

    spw_ds = xds_from_table("::".join((args.ms, "SPECTRAL_WINDOW")),
                            group_cols="__row__")
    field_ds = xds_from_table("::".join((args.ms, "FIELD")),
                              group_cols="__row__")
    ddid_ds = xds_from_table("::".join((args.ms, "DATA_DESCRIPTION")))
    assert len(ddid_ds) == 1
    ddid_ds = dask.compute(ddid_ds)[0]

    return main_ds, spw_ds, ddid_ds, field_ds, subtables
def test_cached_array(ms):
    ds = xds_from_ms(ms, group_cols=[], chunks={'row': 1, 'chan': 4})[0]

    data = ds.DATA.data
    cached_data = cached_array(data)
    assert_array_almost_equal(cached_data, data)

    # 2 x row blocks + row x chan x corr blocks
    assert len(_key_cache) == data.numblocks[0] * 2 + data.npartitions
    # rows, row runs and data array caches
    assert len(_array_cache_cache) == 3

    # Pickling works
    pickled_data = pickle.loads(pickle.dumps(cached_data))
    assert_array_almost_equal(pickled_data, data)

    # Same underlying caching is re-used
    # 2 x row blocks + row x chan x corr blocks
    assert len(_key_cache) == data.numblocks[0] * 2 + data.npartitions
    # rows, row runs and data array caches
    assert len(_array_cache_cache) == 3

    del pickled_data, cached_data, data, ds
    gc.collect()

    assert len(_key_cache) == 0
    assert len(_array_cache_cache) == 0
def predict(args):
    # Convert source data into dask arrays
    sky_model = parse_sky_model(args.sky_model, args.model_chunks)

    # Get the support tables
    tables = support_tables(args, ["FIELD", "DATA_DESCRIPTION",
                                   "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]

        # Select single dataset row out
        corrs = pol.NUM_CORR.data[0]

        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Generate visibility expressions for each source type
        source_vis = [vis_factory(args, stype, sky_model, time_index,
                                  xds, field, spw, pol)
                      for stype in sky_model.keys()]

        # Sum visibilities together
        vis = sum(source_vis)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4,))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(MODEL_DATA=(("row", "chan", "corr"), vis))
        # Create a write to the table
        write = xds_to_table(xds, args.ms, ['MODEL_DATA'])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        dask.compute(writes)
def test_ms_create_and_update(Dataset, tmp_path, chunks):
    """ Test that we can update and append at the same time """
    filename = str(tmp_path / "create-and-update.ms")

    rs = np.random.RandomState(42)

    # Create a dataset of 10 rows with DATA and DATA_DESC_ID
    dims = ("row", "chan", "corr")
    row, chan, corr = tuple(sum(chunks[d]) for d in dims)

    ms_datasets = []
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 0, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row",), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    ms_datasets = xds_from_ms(filename)

    # Now add another dataset (different DDID), with no ROWID
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)
    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 1, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row",), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    # Rows have been added and additional data is present
    with pt.table(filename, ack=False, readonly=True) as T:
        first_data_desc_id = da.full(row,
                                     ms_datasets[0].DATA_DESC_ID,
                                     chunks=chunks['row'])
        ds_data = da.concatenate([ms_datasets[0].DATA.data,
                                  ms_datasets[1].DATA.data])
        ds_ddid = da.concatenate([first_data_desc_id,
                                  ms_datasets[1].DATA_DESC_ID.data])
        assert_array_equal(T.getcol("DATA"), ds_data)
        assert_array_equal(T.getcol("DATA_DESC_ID"), ds_ddid)
def _derive_row_chunking(self, args):
    datasets = xds_from_ms(args.ms, group_cols=GROUP_COLS,
                           columns=["TIME", "INTERVAL", "UVW"],
                           chunks={'row': args.row_chunks},
                           taql_where=args.taql_where)

    return dataset_chunks(datasets, args.time_bin_secs,
                          args.row_chunks, bda=args.command)
def test_xds_to_zarr(ms, spw_table, ant_table, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"
    spw_store = zarr_store.parent / f"{zarr_store.name}::SPECTRAL_WINDOW"
    ant_store = zarr_store.parent / f"{zarr_store.name}::ANTENNA"

    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(**{
            "chan": (("chan",), np.arange(chan)),
            "corr": (("corr",), np.arange(corr)),
        })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_zarr(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Dataset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
def test_write_table_proxy_keyword(ms):
    datasets = xds_from_ms(ms)

    # Test that we get a TableProxy if requested
    writes, tp = xds_to_table(datasets, ms, [], table_proxy=True)
    assert isinstance(writes, list) and isinstance(writes[0], Dataset)
    assert isinstance(tp, TableProxy)
    assert tp.nrows().result() == 10

    writes = xds_to_table(datasets, ms, [], table_proxy=False)
    assert isinstance(writes, list) and isinstance(writes[0], Dataset)
def convolve(self, x):
    # print("Applying Hessian", file=log)
    x = da.from_array(x.astype(self.real_type),
                      chunks=(1, self.nx, self.ny), name=False)
    convolvedims = []
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims], columns=self.columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = ds.DATA_DESC_ID
            bands = self.band_mapping[ims][spw]
            model = x[list(bands), :, :]
            convolvedim = hessian(self.uvws[ims][spw],
                                  self.freq[ims][spw],
                                  model,
                                  self.freq_bin_idx[ims][spw],
                                  self.freq_bin_counts[ims][spw],
                                  self.cell,
                                  weights=self.stokes_weights[ims][spw],
                                  nthreads=self.nthreads // self.nband,
                                  epsilon=self.epsilon,
                                  do_wstacking=self.do_wstacking,
                                  double_accum=True)
            convolvedims.append(convolvedim)

    convolvedims = dask.compute(convolvedims)[0]

    return accumulate_dirty(convolvedims, self.nband,
                            self.band_mapping).astype(self.real_type)
def zarr_tester(ms, spw_table, ant_table, zarr_store, spw_store, ant_store):
    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(**{
            "chan": (("chan",), np.arange(chan)),
            "corr": (("corr",), np.arange(corr)),
        })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store.url,
                                   storage_options=zarr_store.storage_options)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_storage_ms(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Dataset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
def write_model(self, x):
    print("Writing model data")
    x = da.from_array(x.astype(np.float32), chunks=(1, self.nx, self.ny))
    writes = []
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims],
                          columns=('MODEL_DATA', 'UVW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data

            model_vis = getattr(ds, 'MODEL_DATA').data

            bands = self.band_mapping[ims][spw]
            model = x[list(bands), :, :]
            vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                         self.cell, nthreads=self.nthreads,
                         epsilon=self.epsilon,
                         do_wstacking=self.do_wstacking)

            model_vis = populate_model(vis, model_vis)

            out_ds = ds.assign(**{self.model_column: (("row", "chan", "corr"),
                                                      model_vis)})
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims, columns=[self.model_column]))

    dask.compute(writes, scheduler='single-threaded')
def write_component_model(self, comps, ref_freq, mask, row_chunks,
                          chan_chunks):
    print("Writing model data at full freq resolution")
    order, npix = comps.shape
    comps = da.from_array(comps, chunks=(-1, -1))
    mask = da.from_array(mask.squeeze(), chunks=(-1, -1))
    writes = []
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={'row': (row_chunks,),
                                  'chan': (chan_chunks,)},
                          columns=('MODEL_DATA', 'UVW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            spw = spws[ds.DATA_DESC_ID]
            freq = spw.CHAN_FREQ.data.squeeze()
            freq_bin_idx = da.arange(0, freq.size, 1,
                                     chunks=freq.chunks, dtype=np.int64)
            freq_bin_counts = da.ones(freq.size, chunks=freq.chunks,
                                      dtype=np.int64)

            uvw = ds.UVW.data

            model_vis = getattr(ds, 'MODEL_DATA').data

            model = model_from_comps(comps, freq, mask, ref_freq)

            vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                         self.cell, nthreads=self.nthreads,
                         epsilon=self.epsilon,
                         do_wstacking=self.do_wstacking)

            model_vis = populate_model(vis, model_vis)

            out_ds = ds.assign(**{self.model_column: (("row", "chan", "corr"),
                                                      model_vis)})
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims, columns=[self.model_column]))

    dask.compute(writes, scheduler='single-threaded')
def test_xarray_to_zarr(ms, tmp_path_factory):
    store = tmp_path_factory.mktemp("zarr_store")
    datasets = xds_from_ms(ms)

    for i, ds in enumerate(datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        datasets[i] = ds.assign_coords(row=np.arange(row),
                                       chan=np.arange(chan),
                                       corr=np.arange(corr))

    for i, ds in enumerate(datasets):
        ds.to_zarr(str(store / f"ds-{i}.zarr"))
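# Hedged sketch (assumes xarray is installed): the per-dataset zarr stores
# written by test_xarray_to_zarr above can be re-opened with plain xarray,
# independently of dask-ms. The helper name and arguments are illustrative.
def example_open_xarray_zarr(store, n_datasets):
    import xarray as xr
    return [xr.open_zarr(str(store / f"ds-{i}.zarr"))
            for i in range(n_datasets)]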
def test_unified_schema(ms):
    datasets = xds_from_ms(ms)
    assert len(datasets) == 3

    from daskms.experimental.arrow.arrow_schema import ArrowSchema
    schema = ArrowSchema.from_datasets(datasets)

    for ds in datasets:
        for column, var in ds.data_vars.items():
            s = schema.data_vars[column]
            assert s.dims == var.dims[1:]
            assert s.shape == var.shape[1:]
            assert np.dtype(s.dtype) == var.dtype
            assert isinstance(var.data, s.type)

    schema.to_arrow_schema()
def test_xds_to_parquet(ms, tmp_path_factory, spw_table, ant_table):
    store = tmp_path_factory.mktemp("parquet_store") / "out.parquet"
    # antenna_store = store.parent / f"{store.name}::ANTENNA"
    # spw_store = store.parent / f"{store.name}::SPECTRAL_WINDOW"

    datasets = xds_from_ms(ms)

    # We can test row chunking if xarray is installed
    if xarray is not None:
        datasets = [ds.chunk({"row": 1}) for ds in datasets]

    # spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    # ant_datasets = xds_from_table(ant_table, group_cols="__row__")

    writes = []
    writes.extend(xds_to_parquet(datasets, store))
    # TODO(sjperkins)
    # Fix arrow shape unification errors
    # writes.extend(xds_to_parquet(spw_datasets, spw_store))
    # writes.extend(xds_to_parquet(ant_datasets, antenna_store))
    dask.compute(writes)

    pq_datasets = xds_from_parquet(store, chunks={"row": 1})
    assert len(datasets) == len(pq_datasets)

    for ds, pq_ds in zip(datasets, pq_datasets):
        for column, var in ds.data_vars.items():
            pq_var = getattr(pq_ds, column)
            assert_array_equal(var.data, pq_var.data)
            assert var.dims == pq_var.dims

        for column, var in ds.coords.items():
            pq_var = getattr(pq_ds, column)
            assert_array_equal(var.data, pq_var.data)
            assert var.dims == pq_var.dims

        partitions = ds.attrs[DASKMS_PARTITION_KEY]
        pq_partitions = pq_ds.attrs[DASKMS_PARTITION_KEY]
        assert partitions == pq_partitions

        for field, dtype in partitions:
            assert getattr(ds, field) == getattr(pq_ds, field)
def test_expressions(ms):
    datasets = xds_from_ms(ms)

    for i, ds in enumerate(datasets):
        dims = ds.DATA.dims
        datasets[i] = ds.assign(DIR1_DATA=(dims, ds.DATA.data),
                                DIR2_DATA=(dims, ds.DATA.data),
                                DIR3_DATA=(dims, ds.DATA.data))

    results = [ds.DATA.data /
               (-ds.DIR1_DATA.data + ds.DIR2_DATA.data + ds.DIR3_DATA.data) * 4
               for ds in datasets]

    string = "DATA / (-DIR1_DATA + DIR2_DATA + DIR3_DATA)*4"
    expressions = data_column_expr(string, datasets)

    for i, (ds, expr) in enumerate(zip(datasets, expressions)):
        assert_array_equal(results[i], expr)
def test_github_98():
    ms = "/home/sperkins/data/AF0236_spw01.ms/"

    if not os.path.exists(ms):
        pytest.skip("AF0236_spw01.ms on which this "
                    "test depends is not present")

    datasets = xds_from_ms(ms,
                           columns=['DATA', 'ANTENNA1', 'ANTENNA2'],
                           group_cols=['DATA_DESC_ID'],
                           taql_where='ANTENNA1 == 5 || ANTENNA2 == 5')

    assert len(datasets) == 2
    assert datasets[0].DATA_DESC_ID == 0
    assert datasets[1].DATA_DESC_ID == 1

    for ds in datasets:
        expr = da.logical_or(ds.ANTENNA1.data == 5, ds.ANTENNA2.data == 5)
        expr, equal = dask.compute(expr, da.all(expr))

        assert equal.item() is True
        assert len(expr) > 0
def test_storage_parquet(ms, tmp_path_factory):
    parquet_store = tmp_path_factory.mktemp("parquet") / "test.parquet"

    oxdsl = xds_from_ms(ms)

    writes = xds_to_parquet(oxdsl, parquet_store)
    dask.compute(writes)

    oxdsl = xds_from_parquet(parquet_store)

    writes = xds_to_storage_table(oxdsl, parquet_store)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_parquet(parquet_store))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
def test_storage_zarr(ms, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr") / "test.zarr"

    oxdsl = xds_from_ms(ms)

    writes = xds_to_zarr(oxdsl, zarr_store)
    dask.compute(writes)

    oxdsl = xds_from_zarr(zarr_store)

    writes = xds_to_storage_table(oxdsl, zarr_store)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_zarr(zarr_store))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
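# Hedged sketch of the round-trip pattern shared by test_storage_ms,
# test_storage_parquet and test_storage_zarr above: xds_to_storage_table
# writes datasets back to whatever storage format the target already uses.
# "datasets" and "store" are placeholders for an already-read dataset list
# and an existing MS, parquet or zarr target.
def example_storage_roundtrip(datasets, store):
    writes = xds_to_storage_table(datasets, store)
    dask.compute(writes)
    return dask.compute(datasets)[0]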
def test_multiprocess_create(ms, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"

    ms_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ms_datasets):
        ms_datasets[i] = ds.chunk({"row": 1})

    writes = xds_to_zarr(ms_datasets, zarr_store)
    dask.compute(writes, scheduler="processes")

    zds = xds_from_zarr(zarr_store)

    for zds, msds in zip(zds, ms_datasets):
        for k, v in msds.data_vars.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.coords.items():
            assert_array_equal(v, getattr(zds, k))

        for k, v in msds.attrs.items():
            assert_array_equal(v, getattr(zds, k))
def _jones2col(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from daskms.experimental.zarr import xds_from_zarr
    from daskms import xds_from_ms, xds_to_table
    import dask.array as da
    import dask
    from africanus.calibration.utils import chunkify_rows
    from africanus.calibration.utils.dask import corrupt_vis

    # get net gains
    G = xds_from_zarr(args.gain_table + '::G')

    # chunking info
    t_chunks = G[0].t_chunk.data
    if len(t_chunks) > 1:
        t_chunks = G[0].t_chunk.data[1:-1] - G[0].t_chunk.data[0:-2]
        assert (t_chunks == t_chunks[0]).all()
        utpc = t_chunks[0]
    else:
        utpc = t_chunks[0]

    times = xds_from_ms(args.ms[0],
                        columns=['TIME'])[0].get('TIME').data.compute()
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(times,
                                                      utimes_per_chunk=utpc,
                                                      daskify_idx=True)

    f_chunks = G[0].f_chunk.data
    if len(f_chunks) > 1:
        f_chunks = G[0].f_chunk.data[1:-1] - G[0].f_chunk.data[0:-2]
        assert (f_chunks == f_chunks[0]).all()
        chan_chunks = f_chunks[0]
    else:
        if f_chunks[0]:
            chan_chunks = f_chunks[0]
        else:
            chan_chunks = -1

    columns = ('DATA', 'FLAG', 'FLAG_ROW', 'ANTENNA1', 'ANTENNA2')
    if args.acol is not None:
        columns += (args.acol,)

    # open MS
    xds = xds_from_ms(args.ms[0],
                      chunks={'row': row_chunks, 'chan': chan_chunks},
                      columns=columns,
                      group_cols=('FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'))

    # Current hack probably only works for single field and DDID
    try:
        assert len(xds) == len(G)
    except Exception as e:
        raise ValueError("Number of datasets in gains does not "
                         "match those in MS")

    # assuming scans are aligned
    out_data = []
    for g, ds in zip(G, xds):
        try:
            assert g.SCAN_NUMBER == ds.SCAN_NUMBER
        except Exception as e:
            raise ValueError("Scans not aligned")

        nrow = ds.dims['row']
        nchan = ds.dims['chan']
        ncorr = ds.dims['corr']

        # need to swap axes for africanus
        jones = da.swapaxes(g.gains.data, 1, 2)
        flag = ds.FLAG.data
        frow = ds.FLAG_ROW.data
        ant1 = ds.ANTENNA1.data
        ant2 = ds.ANTENNA2.data

        frow = (frow | (ant1 == ant2))
        flag = (flag[:, :, 0] | flag[:, :, -1])
        flag = da.logical_or(flag, frow[:, None])

        if args.acol is not None:
            acol = ds.get(args.acol).data.reshape(nrow, nchan, 1, ncorr)
        else:
            acol = da.ones((nrow, nchan, 1, ncorr),
                           chunks=(row_chunks, chan_chunks, 1, -1),
                           dtype=jones.dtype)

        cvis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones, acol)

        # compare where unflagged
        if args.compareto is not None:
            flag = flag.compute()
            vis = ds.get(args.compareto).values[~flag]
            print("Max abs difference = ",
                  np.abs(cvis.compute()[~flag] - vis).max())
            quit()

        out_ds = ds.assign(**{args.mueller_column: (("row", "chan", "corr"),
                                                    cvis)})
        out_data.append(out_ds)

    writes = xds_to_table(out_data, args.ms[0], columns=[args.mueller_column])
    dask.compute(writes)
def main(args):
    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW'), chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list(tmp_freq))

    uv_max = u_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms, args.nx, args.ny, args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                optimise_chunks=True,
                data_column=args.data_column,
                weight_column=args.weight_column,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                  args.nx, args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      args.nx, args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        compare_headers(hdr_psf, fits.getheader(args.psf))
        psf_array = load_fits(args.psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    psf_max[psf_max < 1e-15] = 1e-15

    if args.dirty is not None:
        compare_headers(hdr, fits.getheader(args.dirty))
        dirty = load_fits(args.dirty)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= psf_max_mean

    # mfs residual
    wsum = np.sum(psf_max)
    dirty_mfs = np.sum(dirty, axis=0) / wsum
    rmax = np.abs(dirty_mfs).max()
    rms = np.std(dirty_mfs)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(args.outfile + '_psf_mfs.fits',
              psf_mfs[args.nx // 2:3 * args.nx // 2,
                      args.ny // 2:3 * args.ny // 2],
              hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)
        if mask.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((args.nx, args.ny), dtype=np.int64)

    if args.point_mask is not None:
        pmask = load_fits(args.point_mask, dtype=np.bool)
        if pmask.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        pmask = None

    # Reporting
    print("At iteration 0 peak of residual is %f and rms is %f" % (rmax, rms))
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # set up point sources
    phi = Dirac(args.nband, args.nx, args.ny, mask=pmask)
    dual = np.zeros((args.nband, args.nx, args.ny), dtype=np.float64)
    weights_21 = np.where(phi.mask, 1, np.inf)

    # preconditioning matrix
    def hess(beta):
        return phi.hdot(psf.convolve(phi.dot(beta))) + \
            beta / args.sig_l2**2  # vague prior on beta

    # get new spectral norm
    L = power_method(hess, dirty.shape, tol=args.pmtol, maxit=args.pmmaxit)

    # deconvolve
    eps = 1.0
    i = 0
    residual = dirty.copy()
    model = np.zeros(dirty.shape, dtype=dirty.dtype)
    for i in range(1, args.maxit):
        # find point source candidates
        if args.do_clean:
            model_tmp = hogbom(mask[None] * residual / psf_max[:, None, None],
                               psf_array / psf_max[:, None, None],
                               gamma=args.cgamma, pf=args.peak_factor)
            phi.update_locs(np.any(model_tmp, axis=0))
            # get new spectral norm
            L = power_method(hess, model.shape, tol=args.pmtol,
                             maxit=args.pmmaxit)
        else:
            model_tmp = np.zeros_like(residual, dtype=residual.dtype)

        # solve for beta updates
        x = pcg(hess, phi.hdot(residual), phi.hdot(model_tmp),
                M=lambda x: x * args.sig_l2**2, tol=args.cgtol,
                maxit=args.cgmaxit, verbosity=args.cgverbose)

        modelp = model.copy()
        model += args.gamma * x

        # impose sparsity and positivity in point sources
        weights_21 = np.where(phi.mask, 1, 1e10)  # 1e10 for effective infinity
        model, dual = primal_dual(hess, model, modelp, dual, args.sig_21,
                                  phi, weights_21, L, tol=args.pdtol,
                                  maxit=args.pdmaxit, axis=0,
                                  positivity=args.positivity,
                                  report_freq=100)

        # update Dirac dictionary (remove zero components)
        phi.trim_fat(model)

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(mask * residual_mfs).max()
        rms = np.std(mask * residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits',
                      model_mfs, hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits',
                      residual / psf_max[:, None, None], hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print("At iteration %i peak of residual is %f, rms is %f, "
              "current eps is %f" % (i, rmax, rms, eps))

        if eps < args.tol:
            print("We have convergence!")
            break

    # final iteration with only a positivity constraint on pixel locs
    tmp = phi.hdot(model)
    x = pcg(hess, phi.hdot(residual), np.zeros_like(tmp, dtype=tmp.dtype),
            M=lambda x: x * args.sig_l2**2, tol=args.cgtol,
            maxit=args.cgmaxit, verbosity=args.cgverbose)

    modelp = model.copy()
    model += args.gamma * x
    model, dual = primal_dual(hess, model, modelp, dual, 0.0,
                              phi, weights_21, L, tol=args.pdtol,
                              maxit=args.pdmaxit, axis=0, report_freq=100)

    # get residual
    residual = R.make_residual(model) / psf_max_mean

    # check stopping criteria
    residual_mfs = np.sum(residual, axis=0) / wsum
    rmax = np.abs(mask * residual_mfs).max()
    rms = np.std(mask * residual_mfs)

    print("At final iteration peak of residual is %f and rms is %f" %
          (rmax, rms))

    save_fits(args.outfile + '_model.fits', model, hdr)

    model_mfs = np.mean(model, axis=0)
    save_fits(args.outfile + '_model_mfs.fits', model_mfs, hdr_mfs)

    save_fits(args.outfile + '_residual.fits',
              residual / psf_max[:, None, None], hdr)

    save_fits(args.outfile + '_residual_mfs.fits', residual_mfs, hdr_mfs)

    if args.interp_model:
        nband = args.nband
        order = args.spectral_poly_order
        phi.trim_fat(model)
        I = np.argwhere(phi.mask).squeeze()
        Ix = I[:, 0]
        Iy = I[:, 1]
        npix = I.shape[0]

        # get components
        beta = model[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for i in range(1, args.spectral_poly_order + 1):
            Xdesign[:, i - 1] = (whigh**i - wlow**i) / (i * wdiff)

        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        np.savez(args.outfile + "spectral_comps", comps=comps,
                 ref_freq=ref_freq, mask=np.any(model, axis=0))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, phi.mask,
                                    args.row_chunks, args.chan_chunks)
        else:
            R.write_model(model)
def _predict(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError("You have to specify nthreads when using a "
                             "distributed scheduler")
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError("You have to specify mem-limit when using a "
                             "distributed scheduler")
        import psutil
        # 100% of memory by default
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)

    print('Input Options:', file=log)
    for key in kw.keys():
        print(' %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms.utils import dataset_type
    mstype = dataset_type(ms[0])
    if mstype == 'casa':
        from daskms import xds_to_table
    elif mstype == 'zarr':
        from daskms.experimental.zarr import xds_to_zarr as xds_to_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import model as im2vis
    from pfb.utils.fits import load_fits
    from pfb.utils.misc import restore_corrs, plan_row_chunk
    from astropy.io import fits

    # always returns 4D
    # gridder expects freq axis
    model = np.atleast_3d(load_fits(args.model).squeeze())
    nband, nx, ny = model.shape
    hdr = fits.getheader(args.model)
    cell_d = np.abs(hdr['CDELT1'])
    cell_rad = np.deg2rad(cell_d)

    # chan <-> band mapping
    (freqs, freq_bin_idx, freq_bin_counts, freq_out,
     band_mapping, chan_chunks) = chan_to_band_mapping(ms, nband=nband)

    # degridder memory budget
    max_chan_chunk = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())

    # assumes number of correlations are the same across MS/SPW
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    if args.output_type is not None:
        output_type = np.dtype(args.output_type)
    else:
        output_type = np.result_type(np.dtype(args.real_type), np.complex64)
    data_bytes = output_type.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row  # model
    memory_per_row += 3 * 8  # uvw

    if mstype == 'zarr':
        if args.model_column in xds[0].keys():
            model_chunks = getattr(xds[0], args.model_column).data.chunks
        else:
            model_chunks = xds[0].DATA.data.chunks
            print('Chunking model same as data')

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks per node"
          % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({'row': row_chunk,
                                'chan': chan_chunks[ims][spw]['chan']})

    model = da.from_array(model.astype(args.real_type),
                          chunks=(1, nx, ny), name=False)
    writes = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=('UVW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            uvw = clone(ds.UVW.data)

            bands = band_mapping[ims][spw]
            model = model[list(bands), :, :]
            vis = im2vis(uvw, freqs[ims][spw], model,
                         freq_bin_idx[ims][spw], freq_bin_counts[ims][spw],
                         cell_rad, nthreads=ngridder_threads,
                         epsilon=args.epsilon, do_wstacking=args.wstack)

            model_vis = restore_corrs(vis, ncorr)
            if mstype == 'zarr':
                model_vis = model_vis.rechunk(model_chunks)
                uvw = uvw.rechunk((model_chunks[0], 3))

            out_ds = ds.assign(**{
                args.model_column: (("row", "chan", "corr"), model_vis),
                'UVW': (("row", "three"), uvw)
            })
            # out_ds = ds.assign(**{args.model_column: (("row", "chan", "corr"),
            #                                           model_vis)})
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims, columns=[args.model_column]))

    dask.visualize(*writes,
                   filename=args.output_filename + '_predict_graph.pdf',
                   optimize_graph=False, collapse_outputs=True)

    if not args.mock:
        with performance_report(filename=args.output_filename +
                                '_predict_per.html'):
            dask.compute(writes, optimize_graph=False)

    print("All done here.", file=log)
# Create a dataset representing the entire antenna table
ant_table = '::'.join((args.ms, 'ANTENNA'))

for ant_ds in xds_from_table(ant_table):
    print(ant_ds)
    print(dask.compute(ant_ds.NAME.data,
                       ant_ds.POSITION.data,
                       ant_ds.DISH_DIAMETER.data))

# Create datasets representing each row of the spw table
spw_table = '::'.join((args.ms, 'SPECTRAL_WINDOW'))

for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
    print(spw_ds)
    print(spw_ds.NUM_CHAN.values)
    print(spw_ds.CHAN_FREQ.values)

# Create datasets from a partitioning of the MS
datasets = list(xds_from_ms(args.ms, chunks={'row': args.chunks}))

for ds in datasets:
    print(ds)

    # Try write the STATE_ID column back
    write = xds_to_table(ds, args.ms, 'STATE_ID')
    with ProgressBar(), Profiler() as prof:
        write.compute()

    # Profile
    prof.visualize(file_path="chunked.html")
def make_psf(self):
    print("Making PSF", file=log)
    psfs = []
    self.stokes_weights = {}
    self.uvws = {}
    for ims in self.ms:
        xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks=self.chunks[ims], columns=self.columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        self.stokes_weights[ims] = {}
        self.uvws[ims] = {}

        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()
            if not np.array_equal(radec, self.radec):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            freq_bin_idx = self.freq_bin_idx[ims][spw]
            freq_bin_counts = self.freq_bin_counts[ims][spw]
            freq = self.freq[ims][spw]

            uvw = ds.UVW.data

            flag = getattr(ds, self.flag_column).data
            weights = getattr(ds, self.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          flag.shape, chunks=flag.chunks)

            if self.imaging_weight_column is not None:
                imaging_weights = getattr(ds, self.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :],
                        flag.shape, chunks=flag.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # for the PSF we need to scale the weights by the
            # Mueller amplitudes squared
            if self.mueller_column is not None:
                mueller = getattr(ds, self.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # only keep data where both corrs are unflagged
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            flag = ~(flagxx | flagyy)  # ducc0 convention
            weights *= flag

            data = weights.astype(np.complex64)

            psf = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts,
                         self.nx_psf, self.ny_psf, self.cell,
                         flag=flag.astype(np.uint8),
                         nthreads=self.nthreads,
                         epsilon=self.epsilon,
                         do_wstacking=self.do_wstacking,
                         double_accum=True)

            psfs.append(psf)

            # assumes that stokes weights and uvw fit into memory
            # self.stokes_weights[ims][spw] = dask.persist(weights.rechunk({0: -1}))[0]
            # self.uvws[ims][spw] = dask.persist(uvw.rechunk({0: -1}))[0]

            # for comparison with numpy implementation
            # self.stokes_weights[ims][spw] = dask.compute(weights)[0]
            # self.uvws[ims][spw] = dask.compute(uvw)[0]

            # import pdb
            # pdb.set_trace()

    psfs = dask.compute(psfs, scheduler='single-threaded')[0]
    return accumulate_dirty(psfs, self.nband,
                            self.band_mapping).astype(self.real_type)