Example #1
def clean(**kw):
    '''
    Single-scale clean.

    If the optional weight-table argument points to a valid weight table
    (created by the psf worker), the algorithm will approximate gradients
    using the diagonal Mueller weights assumption (exact for Stokes I
    imaging), i.e.

    IR = ID - R.H W R x

    otherwise it is a pure image space algorithm i.e.

    IR = ID - PSF.convolve(x)

    The latter is exact in the absence of wide-field effects and is usually
    much faster.

    If a host address is provided, the computation can be distributed
    over imaging band and row. When using a distributed scheduler, both
    mem-limit and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster, the default is to use:

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        nvthreads = nthreads//(nworkers*nthreads_per_worker)
    else:
        nvthreads = nthreads//nthreads_per_worker
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        # numpy imports have to happen after this step
        from pfb import set_client
        set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _clean(**args)
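
The thread split described in the docstring can be made concrete with a small helper. This is a minimal sketch of the rule above, using a hypothetical compute_nvthreads function that is not part of pfb:

def compute_nvthreads(nthreads, nworkers, nthreads_per_worker, local_cluster=True):
    # On a LocalCluster all workers share one node, so the gridder threads are
    # divided by the total number of dask threads; on a distributed cluster
    # nthreads is per node and only the per-worker thread count matters.
    if local_cluster:
        return max(1, nthreads // (nworkers * nthreads_per_worker))
    return max(1, nthreads // nthreads_per_worker)

# e.g. 16 threads, 4 imaging bands, 1 thread per worker -> 4 gridder threads
print(compute_nvthreads(16, 4, 1))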
Example #2
def jones2col(**kw):
    '''
    Write the product of diagonal Jones matrices to a 'Mueller' column.
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')
    from glob import glob
    ms = glob(args.ms)
    if len(ms) != 1:
        raise ValueError(f"There must be exactly one MS at {args.ms}")
    args.ms = ms

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _jones2col(**args)
Example #3
def dirty(**kw):
    '''
    Create a dirty image from a list of measurement sets.
    The dirty image cube is not normalised by wsum as this destroys
    information. The MFS image is written out in units of Jy/beam.
    The normalisation factors can be obtained by making a psf image
    using the psf worker (see pfbworkers psf --help).

    If a host address is provided, the computation can be distributed
    over imaging band and row. When using a distributed scheduler, both
    mem-limit and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster, the default is to use:

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        ngridder-threads = nthreads//(nworkers*nthreads_per_worker)
    else:
        ngridder-threads = nthreads//nthreads_per_worker

    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')
    from glob import glob
    ms = glob(args.ms)
    if len(ms) == 0:
        raise ValueError(f"No MS at {args.ms}")
    args.ms = ms

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _dirty(**args)
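
Because the dirty cube is written out unnormalised, the per-band wsum factors have to be recovered from the PSF peaks before the bands can be compared. A minimal numpy sketch, assuming dirty and psf cubes of shape (nband, nx, ny) already loaded from FITS:

import numpy as np

# toy cubes standing in for images loaded with pfb.utils.fits.load_fits
nband, nx, ny = 4, 128, 128
dirty = np.random.randn(nband, nx, ny)
psf = np.abs(np.random.randn(nband, 2 * nx, 2 * ny))

# the per-band normalisation factor is the peak of that band's PSF
wsums = psf.reshape(nband, -1).max(axis=1)
wsum = wsums.sum()

dirty_norm = dirty / wsums[:, None, None]   # per-band images in Jy/beam
dirty_mfs = dirty.sum(axis=0) / wsum        # MFS image in Jy/beam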
Example #4
def forward(**kw):
    '''
    Extract flux at model locations.

    Will write out the result of solving

    x = (R.H W R + sigmainv**2 I)^{-1} ID

    assuming that R.H W R can be approximated as a convolution with the PSF.

    If a host address is provided, the computation can be distributed
    over imaging band and row. When using a distributed scheduler, both
    mem-limit and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster, the default is to use:

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        ngridder-threads = nthreads//(nworkers*nthreads_per_worker)
    else:
        ngridder-threads = nthreads//nthreads_per_worker
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _forward(**args)
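
The solve described in the docstring, x = (R.H W R + sigmainv**2 I)^{-1} ID with R.H W R approximated by a convolution with the PSF, can be sketched with a plain conjugate gradient loop. This is a toy illustration of the same normal equations, not pfb's implementation:

import numpy as np
from scipy.signal import fftconvolve

def hess(x, psf, sigmainv):
    # approximate R.H W R x by convolution with the PSF, plus the Tikhonov term
    return fftconvolve(x, psf, mode='same') + sigmainv**2 * x

def cg(A, b, x0, tol=1e-6, maxit=200):
    # plain conjugate gradient for the symmetric system A x = b
    x = x0.copy()
    r = b - A(x)
    p = r.copy()
    rnorm = np.vdot(r, r)
    for _ in range(maxit):
        Ap = A(p)
        alpha = rnorm / np.vdot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        rnext = np.vdot(r, r)
        if np.sqrt(rnext) < tol:
            break
        p = r + (rnext / rnorm) * p
        rnorm = rnext
    return x

# toy problem: Gaussian "PSF" and a single point source in the dirty image
nx = ny = 64
l, m = np.meshgrid(np.arange(nx) - nx // 2, np.arange(ny) - ny // 2, indexing='ij')
psf = np.exp(-(l**2 + m**2) / (2 * 3.0**2))
ID = np.zeros((nx, ny))
ID[nx // 2, ny // 2] = 1.0
x = cg(lambda v: hess(v, psf, sigmainv=1e-3), ID, np.zeros_like(ID))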
Example #5
def _predict(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms.utils import dataset_type
    mstype = dataset_type(ms[0])
    if mstype == 'casa':
        from daskms import xds_to_table
    elif mstype == 'zarr':
        from daskms.experimental.zarr import xds_to_zarr as xds_to_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import model as im2vis
    from pfb.utils.fits import load_fits
    from pfb.utils.misc import restore_corrs, plan_row_chunk
    from astropy.io import fits

    # always returns 4D
    # gridder expects freq axis
    model = np.atleast_3d(load_fits(args.model).squeeze())
    nband, nx, ny = model.shape
    hdr = fits.getheader(args.model)
    cell_d = np.abs(hdr['CDELT1'])
    cell_rad = np.deg2rad(cell_d)

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # degridder memory budget
    max_chan_chunk = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())

    # assumes number of correlations are the same across MS/SPW
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    if args.output_type is not None:
        output_type = np.dtype(args.output_type)
    else:
        output_type = np.result_type(np.dtype(args.real_type), np.complex64)
    data_bytes = output_type.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row  # model
    memory_per_row += 3 * 8  # uvw

    if mstype == 'zarr':
        if args.model_column in xds[0].keys():
            model_chunks = getattr(xds[0], args.model_column).data.chunks
        else:
            model_chunks = xds[0].DATA.data.chunks
            print('Chunking model same as data', file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    model = da.from_array(model.astype(args.real_type),
                          chunks=(1, nx, ny),
                          name=False)
    writes = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=('UVW',))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            uvw = clone(ds.UVW.data)

            bands = band_mapping[ims][spw]
            model_b = model[list(bands), :, :]
            vis = im2vis(uvw,
                         freqs[ims][spw],
                         model_b,
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         cell_rad,
                         nthreads=ngridder_threads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack)

            model_vis = restore_corrs(vis, ncorr)
            if mstype == 'zarr':
                model_vis = model_vis.rechunk(model_chunks)
                uvw = uvw.rechunk((model_chunks[0], 3))

            out_ds = ds.assign(
                **{
                    args.model_column: (("row", "chan", "corr"), model_vis),
                    'UVW': (("row", "three"), uvw)
                })
            # out_ds = ds.assign(**{args.model_column: (("row", "chan", "corr"), model_vis)})
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims, columns=[args.model_column]))

    dask.visualize(*writes,
                   filename=args.output_filename + '_predict_graph.pdf',
                   optimize_graph=False,
                   collapse_outputs=True)

    if not args.mock:
        with performance_report(filename=args.output_filename +
                                '_predict_per.html'):
            dask.compute(writes, optimize_graph=False)

    print("All done here.", file=log)
Example #6
def psf(**kw):
    '''
    Create a psf image from a list of measurement sets and write out the
    Mueller weights.

    The psf image cube is not normalised by wsum as this destroys
    information. The MFS image is written out in units of Jy/beam
    and should have a peak of one otherwise something has gone wrong.

    The --field-of-view and --super-resolution-factor options
    (equivalently --cell-size, --nx and --ny) pertain to the size of
    the image (eg. dirty and model). The size of the PSF output image
    is controlled by the --psf-oversize option.

    The Stokes I weights required to apply the Hessian are also written out
    to a zarr data set called output-filename.zarr. This data set does not
    adhere to the MSv2 specs and is only meant to be used to apply the
    Hessian. In particular, the weights written out are a combination of
    imaging weights and the "Mueller" weights.

    If a host address is provided, the computation can be distributed
    over imaging band and row. When using a distributed scheduler, both
    mem-limit and nthreads are per node and have to be specified.

    When using a local cluster, mem-limit and nthreads refer to the global
    memory and threads available, respectively. By default the gridder will
    use all available resources.

    Disclaimer - Memory budgeting is still very crude!

    On a local cluster, the default is to use:

        nworkers = nband
        nthreads-per-worker = 1

    They have to be specified in ~/.config/dask/jobqueue.yaml in the
    distributed case.

    if LocalCluster:
        ngridder-threads = nthreads//(nworkers*nthreads_per_worker)
    else:
        ngridder-threads = nthreads//nthreads_per_worker

    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')
    from glob import glob
    ms = glob(args.ms)
    if len(ms) == 0:
        raise ValueError(f"No MS at {args.ms}")
    args.ms = ms

    if args.nworkers is None:
        args.nworkers = args.nband

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _psf(**args)
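
The docstring's claim that the normalised MFS PSF should peak at one can be checked directly: each band's PSF peaks at its wsum, so summing over bands and dividing by the total gives a unit peak. A toy numpy check, not tied to pfb's data layout:

import numpy as np

# toy PSF cube: per-band Gaussians with different weights, all peaking at the
# image centre, mimicking a real PSF cube of shape (nband, nx_psf, ny_psf)
nband, nx, ny = 4, 128, 128
l, m = np.meshgrid(np.arange(nx) - nx // 2, np.arange(ny) - ny // 2, indexing='ij')
wsums_true = np.array([1.0, 2.0, 3.0, 4.0])
psf = wsums_true[:, None, None] * np.exp(-(l**2 + m**2) / (2 * 5.0**2))[None]

wsums = psf.reshape(nband, -1).max(axis=1)       # recovers wsums_true
psf_mfs = psf.sum(axis=0) / wsums.sum()
assert np.isclose(psf_mfs.max(), 1.0)            # unit peak after normalisation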
Example #7
def _restore(stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from astropy.io import fits
    mhdr = fits.getheader(args.model)

    from pfb.utils.fits import load_fits
    model = load_fits(args.model).squeeze()  # drop Stokes axis

    # check images compatible
    rhdr = fits.getheader(args.residual)

    from pfb.utils.fits import compare_headers
    compare_headers(mhdr, rhdr)
    residual = load_fits(args.residual).squeeze()

    # fit restoring psf
    from pfb.utils.misc import fitcleanbeam, Gaussian2D
    psf = load_fits(args.psf, dtype=args.real_type).squeeze()

    nband, nx_psf, ny_psf = psf.shape
    wsums = np.amax(psf.reshape(args.nband, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf_mfs = np.zeros(psf_mfs.shape, dtype=args.real_type)
    cpsf = np.zeros(psf.shape, dtype=args.real_type)

    lpsf = np.arange(-nx_psf / 2, nx_psf / 2)
    mpsf = np.arange(-ny_psf / 2, ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)

    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars, save_fits
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)
    hdr_psf_mfs = add_beampars(fits.getheader(args.psf), GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    if args.beam is not None:
        bhdr = fits.getheader(args.beam)
        compare_headers(mhdr, bhdr)
        beam = load_fits(args.beam).squeeze()
        model = np.where(beam > args.pb_min, model / beam, 0.0)

    nband, nx, ny = model.shape
    gaussparf = ()
    if nband > 1:
        for b in range(nband):
            gaussparf += (rhdr['BMAJ' + str(b)], rhdr['BMIN' + str(b)],
                          rhdr['BPA' + str(b)])
    else:
        gaussparf += (rhdr['BMAJ'], rhdr['BMIN'], rhdr['BPA'])

    # if args.convolve_residuals:

    cellx = np.abs(mhdr['CDELT1'])
    celly = np.abs(mhdr['CDELT2'])

    from pfb.utils.restoration import restore_image
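
The snippet above ends before the actual restoration, but the step it builds toward is simple: convolve the model with the fitted clean beam and add the residual. A toy version with an elliptical Gaussian beam (position angle ignored), not pfb's restore_image:

import numpy as np
from scipy.signal import fftconvolve

nx = ny = 128
l, m = np.meshgrid(np.arange(nx) - nx // 2, np.arange(ny) - ny // 2, indexing='ij')

# unit-peak elliptical Gaussian clean beam, axes in pixels (toy values)
emaj, emin = 5.0, 3.0
clean_beam = np.exp(-0.5 * ((l / emaj)**2 + (m / emin)**2))

model = np.zeros((nx, ny))
model[nx // 2, ny // 2] = 1.0                      # single clean component
residual = 0.01 * np.random.randn(nx, ny)

restored = fftconvolve(model, clean_beam, mode='same') + residual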
Example #8
def spifit(**kw):
    """
    Spectral index fitter

    """
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')
    from glob import glob
    from omegaconf import ListConfig
    # image is either a string or a list of strings that we want to glob on
    if isinstance(args.image, str):
        image = sorted(glob(args.image))
    elif isinstance(args.image, (list, ListConfig)):
        image = []
        for i in range(len(args.image)):
            image.append(sorted(glob(args.image[i])))

    # make sure it's not empty
    if len(image) == 0:
        raise ValueError(f"No image at {args.image}")
    args.image = image

    # same goes for the residual except that it may also be None
    if isinstance(args.residual, str):
        residual = sorted(glob(args.residual))
    elif isinstance(args.residual, (list, ListConfig)):
        residual = []
        for i in range(len(args.residual)):
            residual.append(sorted(glob(args.residual[i])))

    if args.residual is not None:
        if len(residual) == 0:
            raise ValueError(f"No residual at {args.residual}")
        args.residual = residual
        # we also need the same number of residuals as images
        if len(args.image) != len(args.residual):
            raise ValueError("The number of images and residuals must match")
    else:
        print("No residual passed in!", file=log)

    # and finally the beam model
    if isinstance(args.beam_model, str):
        beam_model = sorted(glob(args.beam_model))
    elif isinstance(args.beam_model, (list, ListConfig)):
        beam_model = []
        for i in range(len(args.beam_model)):
            beam_model.append(sorted(glob(args.beam_model[i])))

    if args.beam_model is not None:
        if len(beam_model) == 0:
            raise ValueError(f"No beam model at {args.beam_model}")
        args.beam_model = beam_model

        if len(args.image) != len(args.beam_model):
            raise ValueError("The number of images and beam models must match")
    else:
        print("Not doing any form of primary beam correction", file=log)

    # LB - TODO: can we sort them along freq at this point already?

    OmegaConf.set_struct(args, True)

    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _spifit(**args)
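
A spectral index fit of the form I(nu) = I0 * (nu/nu0)**alpha reduces to linear least squares in log space. The sketch below fits alpha per pixel on a toy cube and is not pfb's spifit implementation:

import numpy as np

nband, npix = 8, 100
freqs = np.linspace(0.9e9, 1.7e9, nband)
nu0 = freqs.mean()

# noiseless toy cube following I(nu) = I0 * (nu/nu0)**alpha
alpha_true, i0_true = -0.7, 1.0
image = i0_true * (freqs[:, None] / nu0)**alpha_true * np.ones((nband, npix))

# least squares fit of [alpha, log(I0)] per pixel in log-log space
A = np.vstack([np.log(freqs / nu0), np.ones(nband)]).T
coeffs, *_ = np.linalg.lstsq(A, np.log(image), rcond=None)
alpha, logI0 = coeffs[0], coeffs[1]
assert np.allclose(alpha, alpha_true)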
Example #9
def binterp(**kw):
    """
    Beam interpolator

    Interpolate beams and stack cubes one MS and one spectral window at a time.

    """
    args = OmegaConf.create(kw)
    import os
    from glob import glob
    image = sorted(glob(args.image))
    if len(image) == 0:
        raise ValueError(f"No image at {args.image}")
    args.image = image

    if args.output_dir is None:
        args.output_dir = os.path.dirname(args.image[0])

    pyscilog.log_to_file(args.output_dir + args.postfix.strip('fits') + 'log')

    if args.ms is not None:
        ms = glob(args.ms)
        if len(ms) != 1:
            raise ValueError(
                f"There must be exactly one MS matching {args.ms} if provided")
        args.ms = ms[0]

    if not isinstance(args.beam_model, str):
        raise ValueError("Only string beam patterns allowed")
    else:
        # we are either using JimBeam or globbing for beam patterns
        if args.beam_model.lower() == 'jimbeam':
            args.beam_model = args.beam_model.lower()
            band = args.band.lower()
            if band != 'l' and band != 'uhf':
                raise ValueError("Only l or uhf band supported with "
                                 "JimBeam")
            else:
                print("Using %s band beam model" % args.band, file=log)
        elif args.beam_model.lower().endswith('.fits'):
            beam_model = glob(args.beam_model)
            if len(beam_model) == 0:
                raise ValueError(f"No beam model at {args.beam_model}")
        else:
            raise ValueError("Unknown beam model provided. "
                             "Either use JimBeam or pass in the fits beam "
                             "patterns")

    OmegaConf.set_struct(args, True)
    with ExitStack() as stack:
        from pfb import set_client
        args = set_client(args, stack, log)

        # TODO - prettier config printing
        print('Input Options:', file=log)
        for key in args.keys():
            print('     %25s = %s' % (key, args[key]), file=log)

        return _binterp(**args)
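
The interpolation the docstring refers to is, at its core, a 1D interpolation of the beam cube along frequency onto the imaging band frequencies. A minimal scipy sketch with made-up shapes, not pfb's binterp:

import numpy as np
from scipy.interpolate import interp1d

# toy beam cube sampled at 16 frequencies on a 64x64 grid
beam_freqs = np.linspace(0.8e9, 1.8e9, 16)
beam_cube = np.exp(-np.random.rand(16, 64, 64))

# frequencies of the imaging bands we want the beam at
image_freqs = np.linspace(0.9e9, 1.7e9, 8)

interp = interp1d(beam_freqs, beam_cube, axis=0, kind='linear',
                  bounds_error=False, fill_value='extrapolate')
beam_on_image_grid = interp(image_freqs)   # shape (8, 64, 64)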
Example #10
def _residual(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data.shape,
                                                      chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of mueller term.
            # Phases modify data amplitudes modify weights.
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx,
                           ny,
                           cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf', optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits',
                  dirty,
                  hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)