Example #1
def test_rechunking(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    dask.compute(xds_to_zarr(rechunked_datasets, store, rechunk=True))

    rechunked_datasets = xds_from_zarr(store)

    assert all([ds.equals(rds)
                for ds, rds in zip(rechunked_datasets, ref_datasets)])
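
The second write in this test is the important step: the store already holds data chunked according to prechunking, so writing the differently-chunked datasets back requires rechunk=True. A minimal sketch of that step in isolation (the chunk size is arbitrary):

# Re-open, rechunk in memory, then write back; rechunk=True lets
# xds_to_zarr adapt the dask chunks to the chunking already on disk.
datasets = [ds.chunk({"row": 2000}) for ds in xds_from_zarr(store)]
dask.compute(xds_to_zarr(datasets, store, rechunk=True))
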
Example #2
def test_add_datavars(ms, tmp_path_factory, prechunking, postchunking):
    store = tmp_path_factory.mktemp("zarr_store")
    ref_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ref_datasets):
        chunks = ds.chunks
        row = sum(chunks["row"])
        chan = sum(chunks["chan"])
        corr = sum(chunks["corr"])

        ref_datasets[i] = ds.assign_coords(
            row=np.arange(row),
            chan=np.arange(chan),
            corr=np.arange(corr),
            dummy=np.arange(10)  # Orphan coordinate.
        )

    chunked_datasets = [ds.chunk(prechunking) for ds in ref_datasets]
    dask.compute(xds_to_zarr(chunked_datasets, store))

    rechunked_datasets = [ds.chunk(postchunking)
                          for ds in xds_from_zarr(store)]
    augmented_datasets = [ds.assign({"DUMMY": (("row", "chan", "corr"),
                                    da.zeros_like(ds.DATA.data))})
                          for ds in rechunked_datasets]
    dask.compute(xds_to_zarr(augmented_datasets, store, rechunk=True))

    augmented_datasets = xds_from_zarr(store)

    assert all([ds.DUMMY.chunks == cds.DATA.chunks
                for ds, cds in zip(augmented_datasets, chunked_datasets)])
Example #3
def test_xds_to_zarr_coords(tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_coords") / "test.zarr"

    data = da.ones((100, 16, 4), chunks=(10, 4, 1), dtype=np.complex64)
    rowid = da.arange(100, chunks=10)

    data_vars = {"DATA": (("row", "chan", "corr"), data)}
    coords = {
        "ROWID": (("row",), rowid),
        "chan": (("chan",), np.arange(16)),
        "foo": (("foo",), np.arange(4)),
    }

    ds = [Dataset(data_vars, coords=coords)]

    writes = xds_to_zarr(ds, zarr_store)
    dask.compute(writes)

    rds = xds_from_zarr(zarr_store)
    assert len(ds) == len(rds)

    for ods, nds in zip(ds, rds):
        for c, v in ods.data_vars.items():
            assert_array_equal(v.data, getattr(nds, c).data)

        for c, v in ods.coords.items():
            assert_array_equal(v.data, getattr(nds, c).data)
Example #4
def test_basic_roundtrip(tmp_path):

    path = tmp_path / "test.zarr"

    # We need >10 datasets to be sure roundtripping is consistent.
    xdsl = [Dataset({'x': (('y',), da.ones(i))}) for i in range(1, 12)]
    dask.compute(xds_to_zarr(xdsl, path))

    xdsl = xds_from_zarr(path)
    dask.compute(xds_to_zarr(xdsl, path))
Example #5
def test_xds_to_zarr(ms, spw_table, ant_table, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"
    spw_store = zarr_store.parent / f"{zarr_store.name}::SPECTRAL_WINDOW"
    ant_store = zarr_store.parent / f"{zarr_store.name}::ANTENNA"

    ms_datasets = xds_from_ms(ms)
    spw_datasets = xds_from_table(spw_table, group_cols="__row__")
    ant_datasets = xds_from_table(ant_table)

    for i, ds in enumerate(ms_datasets):
        dims = ds.dims
        row, chan, corr = (dims[d] for d in ("row", "chan", "corr"))

        ms_datasets[i] = ds.assign_coords(
            **{
                "chan": (("chan", ), np.arange(chan)),
                "corr": (("corr", ), np.arange(corr)),
            })

    main_zarr_writes = xds_to_zarr(ms_datasets, zarr_store)
    assert len(ms_datasets) == len(main_zarr_writes)

    for ms_ds, zw_ds in zip(ms_datasets, main_zarr_writes):
        for k, _ in ms_ds.attrs[DASKMS_PARTITION_KEY]:
            assert getattr(ms_ds, k) == getattr(zw_ds, k)

    writes = [main_zarr_writes]
    writes.extend(xds_to_zarr(spw_datasets, spw_store))
    writes.extend(xds_to_zarr(ant_datasets, ant_store))
    dask.compute(writes)

    zarr_datasets = xds_from_zarr(zarr_store, chunks={"row": 1})

    for ms_ds, zarr_ds in zip(ms_datasets, zarr_datasets):
        # Check data variables
        assert ms_ds.data_vars, "MS Dataset has no variables"

        for name, var in ms_ds.data_vars.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check coordinates
        assert ms_ds.coords, "MS Dataset has no coordinates"

        for name, var in ms_ds.coords.items():
            zdata = getattr(zarr_ds, name).data
            assert type(zdata) is type(var.data)  # noqa
            assert_array_equal(var.data, zdata)

        # Check dataset attributes
        for k, v in ms_ds.attrs.items():
            zattr = getattr(zarr_ds, k)
            assert_array_equal(zattr, v)
Example #6
def test_storage_zarr(ms, tmp_path_factory):

    zarr_store = tmp_path_factory.mktemp("zarr") / "test.zarr"

    oxdsl = xds_from_ms(ms)

    writes = xds_to_zarr(oxdsl, zarr_store)

    dask.compute(writes)

    oxdsl = xds_from_zarr(zarr_store)

    writes = xds_to_storage_table(oxdsl, zarr_store)

    oxdsl = dask.compute(oxdsl)[0]

    dask.compute(writes)

    xdsl = dask.compute(xds_from_zarr(zarr_store))[0]

    assert all([xds.equals(oxds) for xds, oxds in zip(xdsl, oxdsl)])
Example #7
def xds_from_storage_table(store, **kwargs):
    from daskms.utils import dataset_type
    typ = dataset_type(store)

    if typ == "casa":
        return xds_from_table(store, **kwargs)
    elif typ == "zarr":
        from daskms.experimental.zarr import xds_from_zarr
        return xds_from_zarr(store, **kwargs)
    elif typ == "parquet":
        from daskms.experimental.arrow import xds_from_parquet
        return xds_from_parquet(store, **kwargs)
    else:
        raise TypeError(f"Unknown dataset {typ}")
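
Given the dispatch above, callers never pick a reader by hand: the store's on-disk format determines which function is used. A minimal usage sketch (the paths are hypothetical):

# Hypothetical stores; dataset_type infers the format and the
# matching reader is invoked.
casa_datasets = xds_from_storage_table("observations.ms")
zarr_datasets = xds_from_storage_table("observations.zarr")
arrow_datasets = xds_from_storage_table("observations.parquet")
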
Example #8
def xds_from_storage_ms(store, **kwargs):
    if not isinstance(store, DaskMSStore):
        store = DaskMSStore(store, **kwargs.pop("storage_options", {}))

    typ = store.type()

    if typ == "casa":
        return xds_from_ms(store, **kwargs)
    elif typ == "zarr":
        from daskms.experimental.zarr import xds_from_zarr
        return xds_from_zarr(store, **kwargs)
    elif typ == "parquet":
        from daskms.experimental.arrow import xds_from_parquet
        return xds_from_parquet(store, **kwargs)
    else:
        raise TypeError(f"Unknown dataset {typ}")
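
This variant differs from Example #7 in that it first wraps the input in a DaskMSStore, so remote locations can be addressed via storage_options. A sketch under that assumption (bucket and credentials are placeholders):

# Placeholder S3 location and credentials; storage_options is popped
# from the kwargs and forwarded to DaskMSStore.
datasets = xds_from_storage_ms(
    "s3://bucket/observations.zarr",
    storage_options={"key": "...", "secret": "..."},
)
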
Example #9
def test_zarr_string_array(tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("string-arrays") / "test.zarr"

    data = ["hello", "this", "strange new world",
            "full of", "interesting", "stuff"]
    data = np.array(data, dtype=object).reshape(3, 2)
    data = da.from_array(data, chunks=((2, 1), (1, 1)))

    datasets = [Dataset({"DATA": (("x", "y"), data)})]
    writes = xds_to_zarr(datasets, zarr_store)
    dask.compute(writes)

    new_datasets = xds_from_zarr(zarr_store)

    assert len(new_datasets) == len(datasets)

    for nds, ds in zip(new_datasets, datasets):
        assert_array_equal(nds.DATA.data, ds.DATA.data)
Example #10
        # Excerpt: a nested closure (cf. _clean in Example #14). nx, ny,
        # row_chunk, bin_counts, freqs, freq_bin_idx, freq_bin_counts,
        # cell_rad and args are captured from the enclosing scope.
        def convolver(x):
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            xds = xds_from_zarr(args.weight_table,
                                chunks={
                                    'row': row_chunk,
                                    'chan': bin_counts
                                })[0]

            convolvedim = hessian(xds.UVW.data,
                                  freqs,
                                  model,
                                  freq_bin_idx,
                                  freq_bin_counts,
                                  cell_rad,
                                  weights=xds.WEIGHT.data.astype(
                                      args.output_type),
                                  nthreads=args.nvthreads,
                                  epsilon=args.epsilon,
                                  do_wstacking=args.wstack,
                                  double_accum=args.double_accum)
            return convolvedim
Example #11
def test_multiprocess_create(ms, tmp_path_factory):
    zarr_store = tmp_path_factory.mktemp("zarr_store") / "test.zarr"

    ms_datasets = xds_from_ms(ms)

    for i, ds in enumerate(ms_datasets):
        ms_datasets[i] = ds.chunk({"row": 1})

    writes = xds_to_zarr(ms_datasets, zarr_store)

    dask.compute(writes, scheduler="processes")

    zds = xds_from_zarr(zarr_store)

    for zarr_ds, ms_ds in zip(zds, ms_datasets):
        for k, v in ms_ds.data_vars.items():
            assert_array_equal(v, getattr(zarr_ds, k))

        for k, v in ms_ds.coords.items():
            assert_array_equal(v, getattr(zarr_ds, k))

        for k, v in ms_ds.attrs.items():
            assert_array_equal(v, getattr(zarr_ds, k))
Example #12
def _jones2col(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from daskms.experimental.zarr import xds_from_zarr
    from daskms import xds_from_ms, xds_to_table
    import dask.array as da
    import dask
    from africanus.calibration.utils import chunkify_rows
    from africanus.calibration.utils.dask import corrupt_vis

    # get net gains
    G = xds_from_zarr(args.gain_table + '::G')

    # chunking info
    t_chunks = G[0].t_chunk.data
    if len(t_chunks) > 1:
        t_chunks = t_chunks[1:-1] - t_chunks[0:-2]
        assert (t_chunks == t_chunks[0]).all()
    utpc = t_chunks[0]
    times = xds_from_ms(args.ms[0],
                        columns=['TIME'])[0].get('TIME').data.compute()
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        times, utimes_per_chunk=utpc, daskify_idx=True)

    f_chunks = G[0].f_chunk.data
    if len(f_chunks) > 1:
        f_chunks = f_chunks[1:-1] - f_chunks[0:-2]
        assert (f_chunks == f_chunks[0]).all()
        chan_chunks = f_chunks[0]
    else:
        chan_chunks = f_chunks[0] if f_chunks[0] else -1

    columns = ('DATA', 'FLAG', 'FLAG_ROW', 'ANTENNA1', 'ANTENNA2')
    if args.acol is not None:
        columns += (args.acol,)

    # open MS
    xds = xds_from_ms(args.ms[0],
                      chunks={'row': row_chunks, 'chan': chan_chunks},
                      columns=columns,
                      group_cols=('FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'))

    # Current hack probably only works for single field and DDID
    if len(xds) != len(G):
        raise ValueError("Number of datasets in gains do not "
                         "match those in MS")

    # assuming scans are aligned
    out_data = []
    for g, ds in zip(G, xds):
        if g.SCAN_NUMBER != ds.SCAN_NUMBER:
            raise ValueError("Scans not aligned")

        nrow = ds.dims['row']
        nchan = ds.dims['chan']
        ncorr = ds.dims['corr']

        # need to swap axes for africanus
        jones = da.swapaxes(g.gains.data, 1, 2)
        flag = ds.FLAG.data
        frow = ds.FLAG_ROW.data
        ant1 = ds.ANTENNA1.data
        ant2 = ds.ANTENNA2.data

        frow = (frow | (ant1 == ant2))
        flag = (flag[:, :, 0] | flag[:, :, -1])
        flag = da.logical_or(flag, frow[:, None])

        if args.acol is not None:
            acol = ds.get(args.acol).data.reshape(nrow, nchan, 1, ncorr)
        else:
            acol = da.ones((nrow, nchan, 1, ncorr),
                           chunks=(row_chunks, chan_chunks, 1, -1),
                           dtype=jones.dtype)

        cvis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones, acol)

        # compare where unflagged
        if args.compareto is not None:
            flag = flag.compute()
            vis = ds.get(args.compareto).values[~flag]
            print("Max abs difference = ", np.abs(cvis.compute()[~flag] - vis).max())
            quit()

        out_ds = ds.assign(**{args.mueller_column: (("row", "chan", "corr"), cvis)})
        out_data.append(out_ds)

    writes = xds_to_table(out_data, args.ms[0], columns=[args.mueller_column])
    dask.compute(writes)
Example #13
def _forward(**kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading residual", file=log)
    residual = load_fits(args.residual, dtype=args.output_type).squeeze()
    nband, nx, ny = residual.shape
    hdr = fits.getheader(args.residual)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    wsums = np.amax(psf.reshape(-1, nx_psf*ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    assert (psf_mfs.max() - 1.0) < 1e-4

    residual /= wsum
    residual_mfs = np.sum(residual, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m
    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs, hdr_mfs,
              dtype=args.output_type)

    rms = np.std(residual_mfs)
    rmax = np.abs(residual_mfs).max()

    print("Initial peak residual = %f, rms = %f" % (rmax, rms), file=log)

    # load beam
    if args.beam_model is not None:
        if args.beam_model.endswith('.fits'):  # beam already interpolated
            bhdr = fits.getheader(args.beam_model)
            l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
            l_coord_beam -= ref_lb
            if not np.array_equal(l_coord_beam, l_coord):
                raise ValueError("l coordinates of beam model do not match "
                                 "those of image. Use power_beam_maker to "
                                 "interpolate to fits header.")

            m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
            m_coord_beam -= ref_mb
            if not np.array_equal(m_coord_beam, m_coord):
                raise ValueError("m coordinates of beam model do not match "
                                 "those of image. Use power_beam_maker to "
                                 "interpolate to fits header.")

            freq_beam, _ = data_from_header(bhdr, axis=3)  # frequency axis
            if not np.array_equal(freq_out, freq_beam):
                raise ValueError("Freqs of beam model do not match those of "
                                 "image. Use power_beam_maker to interpolate "
                                 "to fits header.")

            beam_image = load_fits(args.beam_model, dtype=args.output_type).squeeze()
        elif args.beam_model.lower() == "jimbeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            elif args.band.lower() == 'uhf':
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            else:
                raise ValueError("Unkown band %s"%args.band[i])

            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
            beam_image = np.zeros(residual.shape, dtype=args.output_type)
            for v in range(freq_out.size):
                # freq must be in MHz
                beam_image[v] = beam.I(xx, yy, freq_out[v]/1e6).astype(args.output_type)
    else:
        beam_image = np.ones((nband, nx, ny), dtype=args.output_type)

    if args.mask is not None:
        mask = load_fits(args.mask).squeeze()
        assert mask.shape == (nx, ny)
        beam_image *= mask[None, :, :]

    beam_image = da.from_array(beam_image, chunks=(1, -1, -1))

    # if weight table is provided we use the vis space Hessian approximation
    if args.weight_table is not None:
        print("Solving for update using vis space approximation", file=log)
        normfact = wsum
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freq = xds.chan.data
        nchan = freq.size

        # bin edges
        fmin = freq.min()
        fmax = freq.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_map = np.zeros(freq.size, dtype=np.int32)
        for band in range(nband):
            indl = freq >= fbins[band]
            indu = freq < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freq = da.from_array(freq, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)
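        # Worked example of the mapping above: with 6 equally spaced channels
        # and nband=2, band_map comes out as [0, 0, 0, 1, 1, 1], so
        # bin_counts = (3, 3) and bin_idx = [0, 3], i.e. each band records
        # where its channels start and how many it owns.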

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit/args.nworkers, band_size, nrow,
                                       memory_per_row, args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row, args.nthreads_per_worker)

        print("nrows = %i, row chunks set to %i for a total of %i chunks per node" %
              (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1), dtype=residual.dtype)

        xds = xds_from_zarr(args.weight_table,
                            chunks={'row': -1,  # row_chunk
                                    'chan': bin_counts})[0]

        from pfb.opt.pcg import pcg_wgt

        model = pcg_wgt(xds.UVW.data,
                        xds.WEIGHT.data.astype(args.output_type),
                        residual,
                        x0,
                        beam_image,
                        freq,
                        freq_bin_idx,
                        freq_bin_counts,
                        cell_rad,
                        args.wstack,
                        args.epsilon,
                        args.double_accum,
                        args.nvthreads,
                        args.sigmainv,
                        wsum,
                        args.cg_tol,
                        args.cg_maxit,
                        args.cg_minit,
                        args.cg_verbose,
                        args.cg_report_freq,
                        args.backtrack).compute()

    else:  # we use the image space approximation
        print("Solving for update using image space approximation", file=log)
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx)//2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny)//2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        unpad_x = slice(npad_xl, -npad_xr)
        unpad_y = slice(npad_yl, -npad_yr)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        psfhat = r2c(psf_pad, axes=(1, 2), forward=True,
                     nthreads=args.nvthreads, inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))
        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1))

        from pfb.opt.pcg import pcg_psf

        model = pcg_psf(psfhat,
                        residual,
                        x0,
                        beam_image,
                        args.sigmainv,
                        args.nvthreads,
                        padding,
                        unpad_x,
                        unpad_y,
                        lastsize,
                        args.cg_tol,
                        args.cg_maxit,
                        args.cg_minit,
                        args.cg_verbose,
                        args.cg_report_freq,
                        args.backtrack).compute()

    print("Saving results", file=log)
    save_fits(args.output_filename + '_update.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_update_mfs.fits', model_mfs, hdr_mfs)

    print("All done here.", file=log)
Example #14
def _clean(**kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading dirty", file=log)
    dirty = load_fits(args.dirty, dtype=args.output_type).squeeze()
    nband, nx, ny = dirty.shape
    hdr = fits.getheader(args.dirty)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    wsums = np.amax(psf.reshape(-1, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    assert (psf_mfs.max() - 1.0) < 1e-4

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_dirty_mfs.fits',
              dirty_mfs,
              hdr_mfs,
              dtype=args.output_type)

    # set up Hessian approximation
    if args.weight_table is not None:
        normfact = wsum
        from africanus.gridding.wgridder.dask import hessian
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freqs = xds.chan.data
        nchan = freqs.size

        # bin edges
        fmin = freqs.min()
        fmax = freqs.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_map = np.zeros(freqs.size, dtype=np.int32)
        for band in range(nband):
            indl = freqs >= fbins[band]
            indu = freqs < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freqs = da.from_array(freqs, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)
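        # As in Example #13: 6 equally spaced channels with nband=2 give
        # band_map [0, 0, 0, 1, 1, 1], bin_counts = (3, 3), bin_idx = [0, 3].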

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit / args.nworkers,
                                       band_size, nrow, memory_per_row,
                                       args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row,
                                       args.nthreads_per_worker)

        print(
            "nrows = %i, row chunks set to %i for a total of %i chunks per node"
            % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
            file=log)

        def convolver(x):
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            xds = xds_from_zarr(args.weight_table,
                                chunks={
                                    'row': row_chunk,
                                    'chan': bin_counts
                                })[0]

            convolvedim = hessian(xds.UVW.data,
                                  freqs,
                                  model,
                                  freq_bin_idx,
                                  freq_bin_counts,
                                  cell_rad,
                                  weights=xds.WEIGHT.data.astype(
                                      args.output_type),
                                  nthreads=args.nvthreads,
                                  epsilon=args.epsilon,
                                  do_wstacking=args.wstack,
                                  double_accum=args.double_accum)
            return convolvedim
    else:
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        unpad_x = slice(npad_xl, -npad_xr)
        unpad_y = slice(npad_yl, -npad_yr)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        psfhat = r2c(psf_pad,
                     axes=(1, 2),
                     forward=True,
                     nthreads=args.nvthreads,
                     inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))

        def convolver(x):
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            convolvedim = hessian(model, psfhat, padding, args.nvthreads,
                                  unpad_x, unpad_y, lastsize)
            return convolvedim

        # psfo = PSF(psf, dirty.shape, nthreads=args.nthreads)
        # def convolver(x): return psfo.convolve(x)

    rms = np.std(dirty_mfs)
    rmax = np.abs(dirty_mfs).max()

    print("Iter %i: peak residual = %f, rms = %f" % (0, rmax, rms), file=log)

    residual = dirty.copy()
    residual_mfs = dirty_mfs.copy()
    model = np.zeros_like(residual)
    for k in range(args.nmiter):
        print("Running Hogbom", file=log)
        x = hogbom(residual,
                   psf,
                   gamma=args.hb_gamma,
                   pf=args.hb_peak_factor,
                   maxit=args.hb_maxit,
                   verbosity=args.hb_verbose,
                   report_freq=args.hb_report_freq)

        model += x
        print("Getting residual", file=log)

        convimage = convolver(model)
        dask.visualize(convimage,
                       filename=args.output_filename + '_hessian' + str(k) +
                       '_graph.pdf',
                       optimize_graph=False)
        with performance_report(filename=args.output_filename + '_hessian' +
                                str(k) + '_per.html'):
            convimage = dask.compute(convimage, optimize_graph=False)[0]
        ne.evaluate('dirty - convimage/normfact',
                    out=residual,
                    casting='same_kind')
        ne.evaluate('sum(residual, axis=0)',
                    out=residual_mfs,
                    casting='same_kind')

        rms = np.std(residual_mfs)
        rmax = np.abs(residual_mfs).max()

        print("Iter %i: peak residual = %f, rms = %f" % (k + 1, rmax, rms),
              file=log)

    print("Saving results", file=log)
    save_fits(args.output_filename + '_model.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_model_mfs.fits', model_mfs, hdr_mfs)
    save_fits(args.output_filename + '_residual.fits',
              residual * wsums[:, None, None], hdr)
    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    print("All done here.", file=log)