Example #1
    def __init__(self, psf, imsize, nthreads=1, backward_undersize=None):
        # Pre-compute padding info and the Fourier transform of the PSF so it
        # can later be applied as an FFT-based convolution. np, iFs (ifftshift),
        # r2c and good_size are assumed to be imported at module level.
        self.nthreads = nthreads
        self.nband, nx_psf, ny_psf = psf.shape
        _, nx, ny = imsize
        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        self.padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        self.ax = (1, 2)
        self.unpad_x = slice(npad_xl, -npad_xr)
        self.unpad_y = slice(npad_yl, -npad_yr)
        self.lastsize = ny + np.sum(self.padding[-1])
        self.psf = psf
        # cache the real-to-complex FFT of the (ifftshifted) PSF
        psf_pad = iFs(psf, axes=self.ax)
        self.psfhat = r2c(psf_pad,
                          axes=self.ax,
                          forward=True,
                          nthreads=nthreads,
                          inorm=0)

        # LB - failed experiment?
        # self.psfhatinv = 1/(self.psfhat + 1.0)

        if backward_undersize is not None:
            # set up for backward step
            nx_psfb = good_size(int(backward_undersize * nx))
            ny_psfb = good_size(int(backward_undersize * ny))
            npad_xlb = (nx_psfb - nx) // 2
            npad_xrb = nx_psfb - nx - npad_xlb
            npad_ylb = (ny_psfb - ny) // 2
            npad_yrb = ny_psfb - ny - npad_ylb
            self.paddingb = ((0, 0), (npad_xlb, npad_xrb), (npad_ylb,
                                                            npad_yrb))
            self.unpad_xb = slice(npad_xlb, -npad_xrb)
            self.unpad_yb = slice(npad_ylb, -npad_yrb)
            self.lastsizeb = ny + np.sum(self.paddingb[-1])

            # crop the PSF to the smaller grid used for the backward step
            xlb = (nx_psf - nx_psfb) // 2
            xrb = nx_psf - nx_psfb - xlb
            ylb = (ny_psf - ny_psfb) // 2
            yrb = ny_psf - ny_psfb - ylb
            psf_padb = iFs(psf[:, slice(xlb, -xrb),
                               slice(ylb, -yrb)],
                           axes=self.ax)
            self.psfhatb = r2c(psf_padb,
                               axes=self.ax,
                               forward=True,
                               nthreads=nthreads,
                               inorm=0)
        else:
            self.paddingb = self.padding
            self.unpad_xb = self.unpad_x
            self.unpad_yb = self.unpad_y
            self.lastsizeb = self.lastsize
            self.psfhatb = self.psfhat
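A minimal sketch (not part of the original listing) of how the cached padding and psfhat could be used to apply the PSF as an FFT-based convolution. It assumes Fs/iFs are numpy.fft.fftshift/ifftshift and r2c/c2r come from ducc0.fft, as in the constructor above.

    def convolve(self, x):
        # zero-pad the image up to the PSF grid and shift to FFT ordering
        xhat = iFs(np.pad(x, self.padding, mode='constant'), axes=self.ax)
        xhat = r2c(xhat, axes=self.ax, forward=True,
                   nthreads=self.nthreads, inorm=0)
        # multiply by the PSF transform and invert (c2r needs lastsize)
        xhat = c2r(xhat * self.psfhat, axes=self.ax, forward=False,
                   lastsize=self.lastsize, inorm=2, nthreads=self.nthreads)
        # undo the shift and remove the padding
        return Fs(xhat, axes=self.ax)[:, self.unpad_x, self.unpad_y]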
Example #2
def get_padding_info(nx, ny, pfrac):
    # Pad each spatial axis by a fraction pfrac of its size, rounded up to an
    # FFT-friendly length; return the pad widths and the slices that undo them.
    from ducc0.fft import good_size
    npad_x = int(pfrac * nx)
    nfft = good_size(nx + npad_x, True)
    npad_xl = (nfft - nx) // 2
    npad_xr = nfft - nx - npad_xl

    npad_y = int(pfrac * ny)
    nfft = good_size(ny + npad_y, True)
    npad_yl = (nfft - ny) // 2
    npad_yr = nfft - ny - npad_yl
    padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
    unpad_x = slice(npad_xl, -npad_xr)
    unpad_y = slice(npad_yl, -npad_yr)
    return padding, unpad_x, unpad_y
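Hypothetical usage of get_padding_info: pad an image cube before an FFT-based operation and recover the original region afterwards with the returned slices.

import numpy as np

image = np.random.randn(8, 128, 128)               # (nband, nx, ny)
padding, unpad_x, unpad_y = get_padding_info(128, 128, pfrac=0.5)
padded = np.pad(image, padding, mode='constant')   # FFT-friendly size
recovered = padded[:, unpad_x, unpad_y]            # back to (8, 128, 128)
assert recovered.shape == image.shape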
Example #3
# Assumed module-level setup for this benchmark snippet (the original script
# defines these at module scope):
import numpy as np
import matplotlib.pyplot as plt
import ducc0.fft as duccfft

rng = np.random.default_rng()


def _l2error(a, b):
    # relative L2 error between two arrays (assumed helper)
    return np.sqrt(np.sum(np.abs(a - b) ** 2) / np.sum(np.abs(a) ** 2))


def bench_nd(ndim,
             nmax,
             nthr,
             ntry,
             tp,
             funcs,
             nrepeat,
             ttl="",
             filename="",
             nice_sizes=True):
    print("{}D, type {}, max extent is {}:".format(ndim, tp, nmax))
    results = [[] for i in range(len(funcs))]
    for n in range(ntry):
        shp = rng.integers(nmax // 3, nmax + 1, ndim)
        if nice_sizes:
            shp = np.array([duccfft.good_size(sz) for sz in shp])
        print("  {0:4d}/{1}: shape={2} ...".format(n, ntry, shp),
              end=" ",
              flush=True)
        a = (rng.random(shp) - 0.5 + 1j * (rng.random(shp) - 0.5)).astype(tp)
        output = []
        for func, res in zip(funcs, results):
            tmp = func(a, nrepeat, nthr)
            res.append(tmp[0])
            output.append(tmp[1])
        print("{0:5.2e}/{1:5.2e} = {2:5.2f}  L2 error={3}".format(
            results[0][n], results[1][n], results[0][n] / results[1][n],
            _l2error(output[0], output[1])))
    results = np.array(results)
    plt.title("{}: {}D, {}, max_extent={}".format(ttl, ndim, str(tp), nmax))
    plt.xlabel("time ratio")
    plt.ylabel("counts")
    plt.hist(results[0, :] / results[1, :], bins="auto")
    if filename != "":
        plt.savefig(filename)
    plt.show()
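bench_nd expects every entry in funcs to return a (time, output) pair; a hypothetical wrapper around ducc0's c2c transform that satisfies this contract could look as follows.

import time
import ducc0.fft as duccfft

def bench_ducc_c2c(a, nrepeat, nthr):
    # time nrepeat forward transforms and keep the fastest run
    tmin = 1e38
    out = None
    for _ in range(nrepeat):
        t0 = time.time()
        out = duccfft.c2c(a, forward=True, nthreads=nthr)
        tmin = min(tmin, time.time() - t0)
    return tmin, out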
Example #4
def _main(dest=sys.stdout):
    # NOTE: this snippet assumes module-level imports (sys, pyscilog) and a
    # pyscilog logger named `log`, set up in the original script.
    from pfb.parser import create_parser
    args = create_parser().parse_args()

    if not args.nthreads:
        import multiprocessing
        args.nthreads = multiprocessing.cpu_count()

    if not args.mem_limit:
        import psutil
        args.mem_limit = int(psutil.virtual_memory()[0] /
                             1e9)  # 100% of memory by default

    import numpy as np
    import numba
    import numexpr
    import dask
    import dask.array as da
    from daskms import xds_from_ms, xds_from_table
    from astropy.io import fits
    from pfb.utils.fits import (set_wcs, load_fits, save_fits, compare_headers,
                                data_from_header)
    from pfb.utils.restoration import fitcleanbeam
    from pfb.utils.misc import Gaussian2D
    from pfb.operators.gridder import Gridder
    from pfb.operators.psf import PSF
    from pfb.deconv.sara import sara
    from pfb.deconv.clean import clean
    from pfb.deconv.spotless import spotless
    from pfb.deconv.nnls import nnls
    from pfb.opt.pcg import pcg

    if not isinstance(args.ms, list):
        args.ms = [args.ms]

    pyscilog.log_to_file(args.outfile + '.log')
    pyscilog.enable_memory_logging(level=3)

    GD = vars(args)
    print('Input Options:', file=log)
    for key in GD.keys():
        print('     %25s = %s' % (key, GD[key]), file=log)

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append([tmp_freq])

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=dest)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size, file=dest)

    if args.nx is None or args.ny is None:
        from ducc0.fft import good_size
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = good_size(npix)
        args.ny = good_size(npix)

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny),
          file=dest)

    # mask
    if args.mask is not None:
        mask_array = load_fits(args.mask, dtype=args.real_type).squeeze()
        if mask_array.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape.")
        # add freq axis
        mask_array = mask_array[None]

        def mask(x):
            return mask_array * x
    else:
        mask_array = None

        def mask(x):
            return x

    # init gridder
    R = Gridder(
        args.ms,
        args.nx,
        args.ny,
        args.cell_size,
        nband=args.nband,
        nthreads=args.nthreads,
        do_wstacking=args.do_wstacking,
        row_chunks=args.row_chunks,
        psf_oversize=args.psf_oversize,
        data_column=args.data_column,
        epsilon=args.epsilon,
        weight_column=args.weight_column,
        imaging_weight_column=args.imaging_weight_column,
        model_column=args.model_column,
        flag_column=args.flag_column,
        weighting=args.weighting,
        robust=args.robust,
        mem_limit=int(
            0.8 * args.mem_limit))  # assumes gridding accounts for 80% memory
    freq_out = R.freq_out
    radec = R.radec

    print("PSF size set to (%i, %i, %i)" % (args.nband, R.nx_psf, R.ny_psf),
          file=dest)

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600, R.nx_psf,
                      R.ny_psf, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          R.nx_psf, R.ny_psf, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf = load_fits(args.psf, dtype=args.real_type).squeeze()
        except BaseException:
            # fall back to computing the PSF if the provided one is unusable
            psf = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf, hdr_psf)
    else:
        psf = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf, hdr_psf)

    # Normalising by wsum (so that the PSF always sums to 1) results in the
    # most intuitive sig_21 values and by far the least bookkeeping.
    # However, we won't save the cubes that way as it destroys information
    # about the noise in image space. Note only the MFS images will have the
    # usual units of Jy/beam.
    wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf_mfs = np.zeros(psf_mfs.shape, dtype=args.real_type)
    cpsf = np.zeros(psf.shape, dtype=args.real_type)

    lpsf = np.arange(-R.nx_psf / 2, R.nx_psf / 2)
    mpsf = np.arange(-R.ny_psf / 2, R.ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)

    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)
    hdr_psf_mfs = add_beampars(hdr_psf_mfs, GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    GaussPars = list(GaussPars)
    for b in range(args.nband):
        GaussPars[b] = list(GaussPars[b])
        GaussPars[b][0] *= args.cell_size / 3600
        GaussPars[b][1] *= args.cell_size / 3600
        GaussPars[b] = tuple(GaussPars[b])
    GaussPars = tuple(GaussPars)
    hdr_psf = add_beampars(hdr_psf, GaussPar, GaussPars)

    save_fits(args.outfile + '_cpsf.fits', cpsf, hdr_psf)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty).squeeze()
        except BaseException:
            # fall back to computing the dirty image if the header check fails
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    # initial model and residual
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=args.real_type).squeeze()
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual,
                                         dtype=args.real_type).squeeze()
                except BaseException:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits', residual,
                              hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
            residual /= wsum
        except BaseException:
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    residual_mfs = np.sum(residual, axis=0)
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # smooth beam
    if args.beam_model is not None:
        if args.beam_model[-5:] == '.fits':
            beam_image = load_fits(args.beam_model,
                                   dtype=args.real_type).squeeze()
            if beam_image.shape != (args.nband, args.nx, args.ny):
                raise ValueError("Beam has incorrect shape")

        elif args.beam_model == "JimBeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            else:
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            beam_image = np.zeros((args.nband, args.nx, args.ny),
                                  dtype=args.real_type)

            l_coord, ref_l = data_from_header(hdr, axis=1)
            l_coord -= ref_l
            m_coord, ref_m = data_from_header(hdr, axis=2)
            m_coord -= ref_m
            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

            for v in range(args.nband):
                beam_image[v] = beam.I(xx, yy, freq_out[v])
        else:
            raise ValueError("Unsupported beam model %s" % args.beam_model)

        def beam(x):
            return beam_image * x
    else:
        beam_image = None

        def beam(x):
            return x

    if args.init_nnls:
        print("Initialising with NNLS", file=log)
        model = nnls(psf,
                     model,
                     residual,
                     mask=mask_array,
                     beam_image=beam_image,
                     hdr=hdr,
                     hdr_mfs=hdr_mfs,
                     outfile=args.outfile,
                     maxit=1,
                     nthreads=args.nthreads)

        residual = R.make_residual(beam(mask(model))) / wsum
        residual_mfs = np.sum(residual, axis=0)

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    redo_dirty = False
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms),
          file=dest)
    for i in range(0, args.maxit):
        # run minor cycle of choice
        modelp = model.copy()
        if args.deconv_mode == 'sara':
            model = sara(psf,
                         model,
                         residual,
                         mask=mask_array,
                         beam_image=beam_image,
                         hessian=R.convolve,
                         wsum=wsum,
                         adapt_sig21=args.adapt_sig21,
                         hdr=hdr,
                         hdr_mfs=hdr_mfs,
                         outfile=args.outfile,
                         cpsf=cpsf,
                         nthreads=args.nthreads,
                         sig_21=args.sig_21,
                         sigma_frac=args.sigma_frac,
                         maxit=args.minormaxit,
                         tol=args.minortol,
                         gamma=args.gamma,
                         psi_levels=args.psi_levels,
                         psi_basis=args.psi_basis,
                         pdtol=args.pdtol,
                         pdmaxit=args.pdmaxit,
                         pdverbose=args.pdverbose,
                         positivity=args.positivity,
                         cgtol=args.cgtol,
                         cgminit=args.cgminit,
                         cgmaxit=args.cgmaxit,
                         cgverbose=args.cgverbose,
                         pmtol=args.pmtol,
                         pmmaxit=args.pmmaxit,
                         pmverbose=args.pmverbose)

        elif args.deconv_mode == 'clean':
            model = clean(psf,
                          model,
                          residual,
                          mask=mask_array,
                          beam=beam_image,
                          nthreads=args.nthreads,
                          maxit=args.minormaxit,
                          gamma=args.gamma,
                          peak_factor=args.peak_factor,
                          threshold=args.threshold,
                          hbgamma=args.hbgamma,
                          hbpf=args.hbpf,
                          hbmaxit=args.hbmaxit,
                          hbverbose=args.hbverbose)
        elif args.deconv_mode == 'spotless':
            model = spotless(psf,
                             model,
                             residual,
                             mask=mask_array,
                             beam_image=beam_image,
                             hessian=R.convolve,
                             wsum=wsum,
                             adapt_sig21=args.adapt_sig21,
                             cpsf=cpsf_mfs,
                             hdr=hdr,
                             hdr_mfs=hdr_mfs,
                             outfile=args.outfile,
                             sig_21=args.sig_21,
                             sigma_frac=args.sigma_frac,
                             nthreads=args.nthreads,
                             gamma=args.gamma,
                             peak_factor=args.peak_factor,
                             maxit=args.minormaxit,
                             tol=args.minortol,
                             threshold=args.threshold,
                             positivity=args.positivity,
                             hbgamma=args.hbgamma,
                             hbpf=args.hbpf,
                             hbmaxit=args.hbmaxit,
                             hbverbose=args.hbverbose,
                             pdtol=args.pdtol,
                             pdmaxit=args.pdmaxit,
                             pdverbose=args.pdverbose,
                             cgtol=args.cgtol,
                             cgminit=args.cgminit,
                             cgmaxit=args.cgmaxit,
                             cgverbose=args.cgverbose,
                             pmtol=args.pmtol,
                             pmmaxit=args.pmmaxit,
                             pmverbose=args.pmverbose)
        else:
            raise ValueError("Unknown deconvolution mode ", args.deconv_mode)

        # get residual
        if redo_dirty:
            # Need to do this if weights or Jones has changed
            # (eg. if we change robustness factor, reweight or calibrate)
            psf = R.make_psf()
            wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf),
                            axis=1)
            wsum = np.sum(wsums)
            psf /= wsum
            dirty = R.make_dirty() / wsum

        # compute in image space
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        residual_mfs = np.sum(residual, axis=0)

        # save current iteration
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_major' + str(i + 1) + '_model_mfs.fits',
                  model_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_model.fits', model,
                  hdr)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual_mfs.fits',
                  residual_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual.fits',
                  residual * wsum, hdr)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("At iteration %i peak of residual is %f, rms is %f, current "
              "eps is %f" % (i + 1, rmax, rms, eps),
              file=dest)

        if eps < args.tol:
            break

    if args.mop_flux:
        print("Mopping flux", file=dest)

        # vague Gaussian prior on x
        def hess(x):
            return mask(beam(R.convolve(mask(beam(x))))) / wsum + 1e-6 * x

        def M(x):
            return x / 1e-6  # preconditioner

        x = pcg(hess,
                mask(beam(residual)),
                np.zeros(residual.shape, dtype=residual.dtype),
                M=M,
                tol=0.1 * args.cgtol,
                maxit=args.cgmaxit,
                minit=args.cgminit,
                verbosity=args.cgverbose)

        model += x
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        save_fits(args.outfile + '_mopped_model.fits', model, hdr)
        save_fits(args.outfile + '_mopped_residual.fits', residual, hdr)
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_mopped_model_mfs.fits', model_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
        save_fits(args.outfile + '_mopped_residual_mfs.fits', residual_mfs,
                  hdr_mfs)

        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After mopping flux peak of residual is %f, rms is %f" %
              (rmax, rms),
              file=dest)

    # if args.interp_model:
    #     nband = args.nband
    #     order = args.spectral_poly_order
    #     phi.trim_fat(model)
    #     I = np.argwhere(phi.mask).squeeze()
    #     Ix = I[:, 0]
    #     Iy = I[:, 1]
    #     npix = I.shape[0]

    #     # get components
    #     beta = model[:, Ix, Iy]

    #     # fit integrated polynomial to model components
    #     # we are given frequencies at bin centers, convert to bin edges
    #     ref_freq = np.mean(freq_out)
    #     delta_freq = freq_out[1] - freq_out[0]
    #     wlow = (freq_out - delta_freq/2.0)/ref_freq
    #     whigh = (freq_out + delta_freq/2.0)/ref_freq
    #     wdiff = whigh - wlow

    #     # set design matrix for each component
    #     Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
    #     for i in range(1, args.spectral_poly_order+1):
    #         Xdesign[:, i-1] = (whigh**i - wlow**i)/(i*wdiff)

    #     weights = psf_max[:, None]
    #     dirty_comps = Xdesign.T.dot(weights*beta)

    #     hess_comps = Xdesign.T.dot(weights*Xdesign)

    #     comps = np.linalg.solve(hess_comps, dirty_comps)

    #     np.savez(args.outfile + "spectral_comps", comps=comps, ref_freq=ref_freq, mask=np.any(model, axis=0))

    if args.write_model:
        print("Writing model", file=dest)
        R.write_model(model)

    if args.make_restored:
        print("Making restored", file=dest)
        cpsfo = PSF(cpsf, residual.shape, nthreads=args.nthreads)
        restored = cpsfo.convolve(model)

        # residual needs to be in Jy/beam before adding to convolved model
        wsums = np.amax(psf.reshape(-1, R.nx_psf * R.ny_psf), axis=1)
        restored += residual / wsums[:, None, None]

        save_fits(args.outfile + '_restored.fits', restored, hdr)
        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
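Examples #4 and #5 both set the image cell size from the Nyquist criterion cell_N = 1 / (2 * uv_max * freq_max / c) and only then apply the super resolution factor. A back-of-the-envelope check with assumed numbers:

import numpy as np

lightspeed = 299792458.0   # m/s, the value behind africanus.constants.c
uv_max = 5000.0            # assumed longest baseline component in metres
freq_max = 1.4e9           # assumed highest channel frequency in Hz
cell_N = 1.0 / (2 * uv_max * freq_max / lightspeed)   # radians
print("Nyquist cell size = %.2f arcsec" % (cell_N * 3600 * 180 / np.pi))
# prints roughly 4.4 arcsec for these numbers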
Example #5
def _residual(ms, stack, **kw):
    # NOTE: as in the previous example, OmegaConf, pyscilog and the
    # module-level logger `log` are assumed to be set up at import time.
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data.shape,
                                                      chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply the adjoint of the Mueller term:
            # phases modify the data, amplitudes modify the weights
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx,
                           ny,
                           cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf', optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits',
                  dirty,
                  hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)