    def apply_strategies(self, flag_windows, vis_windows):
        # Runs every entry of ``self.strategies`` in order against
        # ``flag_windows`` / ``vis_windows`` and returns the final flags.
        #
        # NOTE(review): this listing appears to be truncated. The ``try:``
        # that should precede ``except KeyError`` is not present, several
        # flagger calls end on an open parenthesis, and the ``else:`` lines
        # before the last assignment / final ``raise`` are missing, so the
        # block does not parse as shown. Restore from the upstream file.

        # Snapshot of the incoming flags; OR'd back in by the
        # "combine_with_input_flags" task below.
        original = flag_windows.copy()

        # Run flagger strategies
        for strategy in self.strategies:
                task = strategy['task']
            except KeyError:
                raise ValueError("strategy has no 'task': %s" % strategy)

            if task == "sum_threshold":
                new_flags = sum_threshold_flagger(vis_windows, flag_windows,
                # sum threshold builds upon any flags that came previous
                flag_windows = da.logical_or(new_flags, flag_windows)
            elif task == "uvcontsub_flagger":
                new_flags = uvcontsub_flagger(vis_windows, flag_windows,
                # this task discards previous flags by default during its
                # second iteration. The original flags from MS should be or'd
                # back in afterwards. Flags from steps prior to this one serves
                # only as a "initial guess"
                flag_windows = new_flags
            elif task == "flag_autos":
                new_flags = flag_autos(flag_windows, self.ubl)
                flag_windows = da.logical_or(new_flags, flag_windows)
            elif task == "combine_with_input_flags":
                # or's in original flags from the measurement set
                # (if -if option has not been specified,
                # in which case this option will do nothing)
                flag_windows = da.logical_or(flag_windows, original)
            elif task == "unflag":
                # Discards every flag accumulated so far.
                flag_windows = da.zeros_like(flag_windows)
            elif task == "flag_nans_zeros":
                flag_windows = flag_nans_and_zeros(vis_windows, flag_windows)
            elif task == "apply_static_mask":
                new_flags = apply_static_mask(flag_windows,
                # override option will override any flags computed previously
                # this may not be desirable so use with care or in combination
                # with combine_with_input_flags option!
                if strategy['kwargs']["accumulation_mode"].strip() == "or":
                    flag_windows = da.logical_or(new_flags, flag_windows)
                    flag_windows = new_flags

                # NOTE(review): the message and ``task`` are passed as two
                # separate arguments, so '%s' is never interpolated; this
                # probably wants ``%`` instead of ``,``.
                raise ValueError("Task '%s' does not name a valid task", task)

        return flag_windows
def _true_color_dask(r, g, b, nodata, c, th):
    # Builds a 4-band array by stacking red, green, blue and alpha along a
    # new last axis. Each colour band goes through ``_normalize_data`` with
    # ``pixel_max``, ``c`` and ``th`` and is cast to uint8.
    pixel_max = 255

    # Alpha is 0 where ``r`` is NaN or ``r <= nodata``.
    # NOTE(review): the expression below is cut off (the "else" value and the
    # closing parenthesis are missing from this listing) -- restore from the
    # upstream file.
    alpha = da.where(da.logical_or(da.isnan(r), r <= nodata), 0,

    red = (_normalize_data(r, pixel_max, c, th)).astype(np.uint8)
    green = (_normalize_data(g, pixel_max, c, th)).astype(np.uint8)
    blue = (_normalize_data(b, pixel_max, c, th)).astype(np.uint8)

    out = da.stack([red, green, blue, alpha], axis=-1)
    return out
def test_github_98():
    """Check that a ``taql_where`` antenna selection is honoured.

    Opens a locally stored measurement set (the test is skipped when the
    path does not exist), selects rows where either antenna is 5, and
    asserts that two datasets come back with DATA_DESC_ID 0 and 1 and that
    every returned row really involves antenna 5.
    """
    ms_path = "/home/sperkins/data/AF0236_spw01.ms/"
    if not os.path.exists(ms_path):
        pytest.skip("AF0236_spw01.ms on which this "
                    "test depends is not present")

    wanted_columns = ['DATA', 'ANTENNA1', 'ANTENNA2']
    selection = 'ANTENNA1 == 5 || ANTENNA2 == 5'
    datasets = xds_from_ms(ms_path,
                           columns=wanted_columns,
                           taql_where=selection)

    # One dataset per data description is expected, in order.
    assert len(datasets) == 2
    for position in (0, 1):
        assert datasets[position].DATA_DESC_ID == position

    for dataset in datasets:
        first_is_5 = dataset.ANTENNA1.data == 5
        second_is_5 = dataset.ANTENNA2.data == 5
        involves_5 = da.logical_or(first_is_5, second_is_5)
        # Evaluate the row mask and its reduction in a single compute call.
        row_mask, every_row = dask.compute(involves_5, da.all(involves_5))
        assert every_row.item() is True
        assert len(row_mask) > 0
def test_arithmetic():
    """Compare elementwise ``da`` operations against their NumPy results.

    Three small NumPy arrays (float32 ``x``, int64 ``y``, int32 ``z``) are
    wrapped as chunked ``da`` arrays ``a``, ``b`` and ``c``. Every operator
    and ufunc is then evaluated on both sides and compared with ``eq``.

    The NumPy side of each comparison is always built from ``x``/``y``/``z``
    so that the expected value never depends on the arrays under test.
    """
    x = np.arange(5).astype('f4') + 2
    y = np.arange(5).astype('i8') + 2
    z = np.arange(5).astype('i4') + 2
    a = da.from_array(x, chunks=(2,))
    b = da.from_array(y, chunks=(2,))
    c = da.from_array(z, chunks=(2,))

    # array <op> array
    assert eq(a + b, x + y)
    assert eq(a * b, x * y)
    assert eq(a - b, x - y)
    assert eq(a / b, x / y)
    assert eq(b & b, y & y)
    assert eq(b | b, y | y)
    assert eq(b ^ b, y ^ y)
    assert eq(a // b, x // y)
    assert eq(a ** b, x ** y)
    assert eq(a % b, x % y)
    assert eq(a > b, x > y)
    assert eq(a < b, x < y)
    assert eq(a >= b, x >= y)
    assert eq(a <= b, x <= y)
    assert eq(a == b, x == y)
    assert eq(a != b, x != y)

    # array <op> scalar
    assert eq(a + 2, x + 2)
    assert eq(a * 2, x * 2)
    assert eq(a - 2, x - 2)
    assert eq(a / 2, x / 2)
    assert eq(b & True, y & True)
    assert eq(b | True, y | True)
    assert eq(b ^ True, y ^ True)
    assert eq(a // 2, x // 2)
    assert eq(a ** 2, x ** 2)
    assert eq(a % 2, x % 2)
    assert eq(a > 2, x > 2)
    assert eq(a < 2, x < 2)
    assert eq(a >= 2, x >= 2)
    assert eq(a <= 2, x <= 2)
    assert eq(a == 2, x == 2)
    assert eq(a != 2, x != 2)

    # scalar <op> array (reflected operators)
    assert eq(2 + b, 2 + y)
    assert eq(2 * b, 2 * y)
    assert eq(2 - b, 2 - y)
    assert eq(2 / b, 2 / y)
    assert eq(True & b, True & y)
    assert eq(True | b, True | y)
    assert eq(True ^ b, True ^ y)
    assert eq(2 // b, 2 // y)
    assert eq(2 ** b, 2 ** y)
    assert eq(2 % b, 2 % y)
    assert eq(2 > b, 2 > y)
    assert eq(2 < b, 2 < y)
    assert eq(2 >= b, 2 >= y)
    assert eq(2 <= b, 2 <= y)
    assert eq(2 == b, 2 == y)
    assert eq(2 != b, 2 != y)

    # unary operators
    assert eq(-a, -x)
    assert eq(abs(a), abs(x))
    assert eq(~(a == b), ~(x == y))

    # exponentials and logarithms
    assert eq(da.logaddexp(a, b), np.logaddexp(x, y))
    assert eq(da.logaddexp2(a, b), np.logaddexp2(x, y))
    assert eq(da.exp(b), np.exp(y))
    assert eq(da.log(a), np.log(x))
    assert eq(da.log10(a), np.log10(x))
    assert eq(da.log1p(a), np.log1p(x))
    assert eq(da.expm1(b), np.expm1(y))
    assert eq(da.sqrt(a), np.sqrt(x))
    assert eq(da.square(a), np.square(x))

    # trigonometric and hyperbolic functions
    assert eq(da.sin(a), np.sin(x))
    assert eq(da.cos(b), np.cos(y))
    assert eq(da.tan(a), np.tan(x))
    assert eq(da.arcsin(b/10), np.arcsin(y/10))
    assert eq(da.arccos(b/10), np.arccos(y/10))
    assert eq(da.arctan(b/10), np.arctan(y/10))
    assert eq(da.arctan2(b*10, a), np.arctan2(y*10, x))
    assert eq(da.hypot(b, a), np.hypot(y, x))
    assert eq(da.sinh(a), np.sinh(x))
    assert eq(da.cosh(b), np.cosh(y))
    assert eq(da.tanh(a), np.tanh(x))
    assert eq(da.arcsinh(b*10), np.arcsinh(y*10))
    assert eq(da.arccosh(b*10), np.arccosh(y*10))
    assert eq(da.arctanh(b/10), np.arctanh(y/10))
    assert eq(da.deg2rad(a), np.deg2rad(x))
    assert eq(da.rad2deg(a), np.rad2deg(x))

    # logical functions and element-wise extrema
    assert eq(da.logical_and(a < 1, b < 4), np.logical_and(x < 1, y < 4))
    assert eq(da.logical_or(a < 1, b < 4), np.logical_or(x < 1, y < 4))
    assert eq(da.logical_xor(a < 1, b < 4), np.logical_xor(x < 1, y < 4))
    assert eq(da.logical_not(a < 1), np.logical_not(x < 1))
    # The reference side uses ``x`` (not ``a``) so the expected value is a
    # plain NumPy result.
    assert eq(da.maximum(a, 5 - a), np.maximum(x, 5 - x))
    assert eq(da.minimum(a, 5 - a), np.minimum(x, 5 - x))
    assert eq(da.fmax(a, 5 - a), np.fmax(x, 5 - x))
    assert eq(da.fmin(a, 5 - a), np.fmin(x, 5 - x))

    # floating point inspection / manipulation
    assert eq(da.isreal(a + 1j * b), np.isreal(x + 1j * y))
    assert eq(da.iscomplex(a + 1j * b), np.iscomplex(x + 1j * y))
    assert eq(da.isfinite(a), np.isfinite(x))
    assert eq(da.isinf(a), np.isinf(x))
    assert eq(da.isnan(a), np.isnan(x))
    assert eq(da.signbit(a - 3), np.signbit(x - 3))
    assert eq(da.copysign(a - 3, b), np.copysign(x - 3, y))
    assert eq(da.nextafter(a - 3, b), np.nextafter(x - 3, y))
    assert eq(da.ldexp(c, c), np.ldexp(z, z))
    assert eq(da.fmod(a * 12, b), np.fmod(x * 12, y))
    assert eq(da.floor(a * 0.5), np.floor(x * 0.5))
    assert eq(da.ceil(a), np.ceil(x))
    assert eq(da.trunc(a / 2), np.trunc(x / 2))

    assert eq(da.degrees(b), np.degrees(y))
    assert eq(da.radians(a), np.radians(x))

    assert eq(da.rint(a + 0.3), np.rint(x + 0.3))
    assert eq(da.fix(a - 2.5), np.fix(x - 2.5))

    # complex helpers, as functions and as attributes/methods
    assert eq(da.angle(a + 1j), np.angle(x + 1j))
    assert eq(da.real(a + 1j), np.real(x + 1j))
    assert eq((a + 1j).real, np.real(x + 1j))
    assert eq(da.imag(a + 1j), np.imag(x + 1j))
    assert eq((a + 1j).imag, np.imag(x + 1j))
    assert eq(da.conj(a + 1j * b), np.conj(x + 1j * y))
    assert eq((a + 1j * b).conj(), (x + 1j * y).conj())

    assert eq(da.clip(b, 1, 4), np.clip(y, 1, 4))
    assert eq(da.fabs(b), np.fabs(y))
    assert eq(da.sign(b - 2), np.sign(y - 2))

    # functions returning two outputs
    l1, l2 = da.frexp(a)
    r1, r2 = np.frexp(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    l1, l2 = da.modf(a)
    r1, r2 = np.modf(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    assert eq(da.around(a, -1), np.around(x, -1))
 def _mask_invalid(self, data, header):
     """Return ``data`` with sentinel count values replaced by NaN.

     The two sentinels are read from ``header['block5']``: the first
     element of ``count_value_outside_scan_pixels`` and of
     ``count_value_error_pixels``. Matching elements become float32 NaN;
     all other elements are passed through unchanged.
     """
     block5 = header['block5']
     is_outside_scan = data == block5["count_value_outside_scan_pixels"][0]
     is_error_pixel = data == block5["count_value_error_pixels"][0]
     to_blank = da.logical_or(is_outside_scan, is_error_pixel)
     fill_value = np.float32(np.nan)
     return da.where(to_blank, fill_value, data)
def _jones2col(**kw):
    # Applies the gains stored under ``<gain_table>::G`` to a column of the
    # first measurement set via ``corrupt_vis`` and schedules a write of the
    # result to ``args.mueller_column``.
    #
    # NOTE(review): this listing appears to be truncated. Several ``else:``
    # and ``try:`` lines are missing, one ``da.ones(`` call is left open, and
    # nothing visible appends ``out_ds`` to ``out_data``, so the block does
    # not parse as shown. Restore from the upstream file.

    # Normalise ``ms`` to a list, then lock the config against new keys.
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from daskms.experimental.zarr import xds_from_zarr
    from daskms import xds_from_ms, xds_to_table
    import dask.array as da
    import dask
    from africanus.calibration.utils import chunkify_rows
    from africanus.calibration.utils.dask import corrupt_vis

    # get net gains
    G = xds_from_zarr(args.gain_table + '::G')

    # chunking info
    t_chunks = G[0].t_chunk.data
    if len(t_chunks) > 1:
        # Differences between chunk offsets; they must all be equal.
        t_chunks = G[0].t_chunk.data[1:-1] - G[0].t_chunk.data[0:-2]
        assert (t_chunks == t_chunks[0]).all()
        utpc = t_chunks[0]
        # NOTE(review): the repeated assignment below looks like the body of
        # a missing ``else:`` branch -- confirm.
        utpc = t_chunks[0]
    times = xds_from_ms(args.ms[0], columns=['TIME'])[0].get('TIME').data.compute()
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(times, utimes_per_chunk=utpc, daskify_idx=True)

    f_chunks = G[0].f_chunk.data
    if len(f_chunks) > 1:
        f_chunks = G[0].f_chunk.data[1:-1] - G[0].f_chunk.data[0:-2]
        assert (f_chunks == f_chunks[0]).all()
        chan_chunks = f_chunks[0]
        # NOTE(review): as shown, ``chan_chunks`` always ends up -1 when the
        # condition holds; an ``else:`` line appears to be missing.
        if f_chunks[0]:
            chan_chunks = f_chunks[0]
            chan_chunks = -1

    columns = ('DATA', 'FLAG', 'FLAG_ROW', 'ANTENNA1', 'ANTENNA2')
    if args.acol is not None:
        columns += (args.acol,)

    # open MS
    # NOTE(review): ``columns`` built above is not passed to this call in the
    # visible code.
    xds = xds_from_ms(args.ms[0], chunks={'row': row_chunks, 'chan': chan_chunks},
                      group_cols=('FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'))

    # Current hack probably only works for single field and DDID
    # NOTE(review): the ``try:`` for the ``except`` below is missing.
        assert len(xds) == len(G)
    except Exception as e:
        raise ValueError("Number of datasets in gains do not "
                            "match those in MS")

    # assuming scans are aligned
    out_data = []
    for g, ds in zip(G, xds):
        # NOTE(review): the ``try:`` for the ``except`` below is missing.
            assert g.SCAN_NUMBER == ds.SCAN_NUMBER
        except Exception as e:
            raise ValueError("Scans not aligned")

        nrow = ds.dims['row']
        nchan = ds.dims['chan']
        ncorr = ds.dims['corr']

        # need to swap axes for africanus
        jones = da.swapaxes(g.gains.data, 1, 2)
        flag = ds.FLAG.data
        frow = ds.FLAG_ROW.data
        ant1 = ds.ANTENNA1.data
        ant2 = ds.ANTENNA2.data

        # Rows with ant1 == ant2 are treated as flagged; a channel is
        # flagged when either the first or the last correlation is flagged.
        frow = (frow | (ant1 == ant2))
        flag = (flag[:, :, 0] | flag[:, :, -1])
        flag = da.logical_or(flag, frow[:, None])

        # NOTE(review): the ``else:`` before ``da.ones`` and the end of that
        # call are missing from this listing.
        if args.acol is not None:
            acol = ds.get(args.acol).data.reshape(nrow, nchan, 1, ncorr)
            acol = da.ones((nrow, nchan, 1, ncorr),
                           chunks=(row_chunks, chan_chunks, 1, -1),

        cvis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones, acol)

        # compare where unflagged
        if args.compareto is not None:
            flag = flag.compute()
            vis = ds.get(args.compareto).values[~flag]
            print("Max abs difference = ", np.abs(cvis.compute()[~flag] - vis).max())

        out_ds = ds.assign(**{args.mueller_column: (("row", "chan", "corr"), cvis)})

    # NOTE(review): ``writes`` is created but not computed or returned in the
    # visible code.
    writes = xds_to_table(out_data, args.ms[0], columns=[args.mueller_column])
# File: psf.py  Project: ratt-ru/pfb-clean
def _psf(**kw):
    # Builds per-dataset PSF images with ``vis2im`` for every measurement set
    # in ``args.ms``, writes the weights/UVW datasets to
    # ``<output_filename>.zarr`` and, unless ``args.mock`` is set, saves the
    # stitched PSF cube and its band-summed (MFS) version as FITS files.
    #
    # NOTE(review): this listing appears to be heavily truncated. Many
    # ``else:`` lines, call continuations, ``print(`` openers and ``append``
    # calls are missing, so the block does not parse as shown. Restore from
    # the upstream file before changing behaviour.

    # Normalise ``ms`` to a list, then lock the config against new keys.
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    # Largest channel chunk and highest frequency over every MS and SPW.
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex but we not longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        # NOTE(review): ``('UVW')`` is a plain string, not a 1-tuple --
        # confirm that ``xds_from_ms`` accepts that.
        xds = xds_from_ms(ims, columns=('UVW'), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    # NOTE(review): an ``else:`` appears to be missing before the
    # ``super_resolution_factor`` lines; as shown they would overwrite a
    # user-supplied cell size.
    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            # NOTE(review): the ratio is passed as a second exception
            # argument rather than formatted into the message.
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    # NOTE(review): an ``else:`` appears to be missing before ``nx = args.nx``.
    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        # Force an even number of pixels.
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    # NOTE(review): the first ``plan_row_chunk`` call is left open and the
    # ``else:`` before the second one is missing.
    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,

        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    # An explicit ``row_chunks`` overrides the planned value; -1 means all
    # rows in one chunk.
    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    # NOTE(review): the two lines below look like the arguments of a
    # ``print(`` call whose opening and closing lines are missing.
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            # NOTE(review): the ``chunks[ims].append({`` / ``})`` wrapper
            # around these dict entries appears to be missing.
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            # NOTE(review): the body of this ``if`` is missing.
            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            # Weights with fewer than three dimensions are broadcast.
            # NOTE(review): the ``broadcast_to`` call is left open.
            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],

            # NOTE(review): the ``else:`` before the plain-weights
            # assignments appears to be missing.
            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            # NOTE(review): the test inspects ``xds[0]`` rather than ``ds``,
            # and the ``else:`` before the second ``frow`` line is missing.
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            # NOTE(review): the remaining ``vis2im`` arguments are missing.
            psf = vis2im(uvw,


            # NOTE(review): several entries and the closing brace of this
            # dict are missing.
            data_vars = {
                'FIELD_ID': (('row', ),
                'DATA_DESC_ID': (('row', ),
                (('row', 'chan'), weights.rechunk({0: args.row_out_chunk
                                                   })),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)


    # NOTE(review): the ``xds_to_zarr`` call is left open.
    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        # NOTE(review): both ``set_wcs`` / ``save_fits`` pairs are left open.
        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
        save_fits(args.output_filename + '_psf.fits',

        # Sum over the first axis and scale so the peak is 1.
        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
        save_fits(args.output_filename + '_psf_mfs.fits',

    print("All done here.", file=log)
def test_arithmetic():
    """Compare elementwise ``da`` operations against their NumPy results.

    Three small NumPy arrays (float32 ``x``, int64 ``y``, int32 ``z``) are
    wrapped as chunked ``da`` arrays ``a``, ``b`` and ``c``. Every operator
    and ufunc is then evaluated on both sides and compared with ``eq``.

    The NumPy side of each comparison is always built from ``x``/``y``/``z``
    so that the expected value never depends on the arrays under test.
    """
    x = np.arange(5).astype('f4') + 2
    y = np.arange(5).astype('i8') + 2
    z = np.arange(5).astype('i4') + 2
    a = da.from_array(x, chunks=(2, ))
    b = da.from_array(y, chunks=(2, ))
    c = da.from_array(z, chunks=(2, ))

    # array <op> array
    assert eq(a + b, x + y)
    assert eq(a * b, x * y)
    assert eq(a - b, x - y)
    assert eq(a / b, x / y)
    assert eq(b & b, y & y)
    assert eq(b | b, y | y)
    assert eq(b ^ b, y ^ y)
    assert eq(a // b, x // y)
    assert eq(a**b, x**y)
    assert eq(a % b, x % y)
    assert eq(a > b, x > y)
    assert eq(a < b, x < y)
    assert eq(a >= b, x >= y)
    assert eq(a <= b, x <= y)
    assert eq(a == b, x == y)
    assert eq(a != b, x != y)

    # array <op> scalar
    assert eq(a + 2, x + 2)
    assert eq(a * 2, x * 2)
    assert eq(a - 2, x - 2)
    assert eq(a / 2, x / 2)
    assert eq(b & True, y & True)
    assert eq(b | True, y | True)
    assert eq(b ^ True, y ^ True)
    assert eq(a // 2, x // 2)
    assert eq(a**2, x**2)
    assert eq(a % 2, x % 2)
    assert eq(a > 2, x > 2)
    assert eq(a < 2, x < 2)
    assert eq(a >= 2, x >= 2)
    assert eq(a <= 2, x <= 2)
    assert eq(a == 2, x == 2)
    assert eq(a != 2, x != 2)

    # scalar <op> array (reflected operators)
    assert eq(2 + b, 2 + y)
    assert eq(2 * b, 2 * y)
    assert eq(2 - b, 2 - y)
    assert eq(2 / b, 2 / y)
    assert eq(True & b, True & y)
    assert eq(True | b, True | y)
    assert eq(True ^ b, True ^ y)
    assert eq(2 // b, 2 // y)
    assert eq(2**b, 2**y)
    assert eq(2 % b, 2 % y)
    assert eq(2 > b, 2 > y)
    assert eq(2 < b, 2 < y)
    assert eq(2 >= b, 2 >= y)
    assert eq(2 <= b, 2 <= y)
    assert eq(2 == b, 2 == y)
    assert eq(2 != b, 2 != y)

    # unary operators
    assert eq(-a, -x)
    assert eq(abs(a), abs(x))
    assert eq(~(a == b), ~(x == y))

    # exponentials and logarithms
    assert eq(da.logaddexp(a, b), np.logaddexp(x, y))
    assert eq(da.logaddexp2(a, b), np.logaddexp2(x, y))
    assert eq(da.exp(b), np.exp(y))
    assert eq(da.log(a), np.log(x))
    assert eq(da.log10(a), np.log10(x))
    assert eq(da.log1p(a), np.log1p(x))
    assert eq(da.expm1(b), np.expm1(y))
    assert eq(da.sqrt(a), np.sqrt(x))
    assert eq(da.square(a), np.square(x))

    # trigonometric and hyperbolic functions
    assert eq(da.sin(a), np.sin(x))
    assert eq(da.cos(b), np.cos(y))
    assert eq(da.tan(a), np.tan(x))
    assert eq(da.arcsin(b / 10), np.arcsin(y / 10))
    assert eq(da.arccos(b / 10), np.arccos(y / 10))
    assert eq(da.arctan(b / 10), np.arctan(y / 10))
    assert eq(da.arctan2(b * 10, a), np.arctan2(y * 10, x))
    assert eq(da.hypot(b, a), np.hypot(y, x))
    assert eq(da.sinh(a), np.sinh(x))
    assert eq(da.cosh(b), np.cosh(y))
    assert eq(da.tanh(a), np.tanh(x))
    assert eq(da.arcsinh(b * 10), np.arcsinh(y * 10))
    assert eq(da.arccosh(b * 10), np.arccosh(y * 10))
    assert eq(da.arctanh(b / 10), np.arctanh(y / 10))
    assert eq(da.deg2rad(a), np.deg2rad(x))
    assert eq(da.rad2deg(a), np.rad2deg(x))

    # logical functions and element-wise extrema
    assert eq(da.logical_and(a < 1, b < 4), np.logical_and(x < 1, y < 4))
    assert eq(da.logical_or(a < 1, b < 4), np.logical_or(x < 1, y < 4))
    assert eq(da.logical_xor(a < 1, b < 4), np.logical_xor(x < 1, y < 4))
    assert eq(da.logical_not(a < 1), np.logical_not(x < 1))
    # The reference side uses ``x`` (not ``a``) so the expected value is a
    # plain NumPy result.
    assert eq(da.maximum(a, 5 - a), np.maximum(x, 5 - x))
    assert eq(da.minimum(a, 5 - a), np.minimum(x, 5 - x))
    assert eq(da.fmax(a, 5 - a), np.fmax(x, 5 - x))
    assert eq(da.fmin(a, 5 - a), np.fmin(x, 5 - x))

    # floating point inspection / manipulation
    assert eq(da.isreal(a + 1j * b), np.isreal(x + 1j * y))
    assert eq(da.iscomplex(a + 1j * b), np.iscomplex(x + 1j * y))
    assert eq(da.isfinite(a), np.isfinite(x))
    assert eq(da.isinf(a), np.isinf(x))
    assert eq(da.isnan(a), np.isnan(x))
    assert eq(da.signbit(a - 3), np.signbit(x - 3))
    assert eq(da.copysign(a - 3, b), np.copysign(x - 3, y))
    assert eq(da.nextafter(a - 3, b), np.nextafter(x - 3, y))
    assert eq(da.ldexp(c, c), np.ldexp(z, z))
    assert eq(da.fmod(a * 12, b), np.fmod(x * 12, y))
    assert eq(da.floor(a * 0.5), np.floor(x * 0.5))
    assert eq(da.ceil(a), np.ceil(x))
    assert eq(da.trunc(a / 2), np.trunc(x / 2))

    assert eq(da.degrees(b), np.degrees(y))
    assert eq(da.radians(a), np.radians(x))

    assert eq(da.rint(a + 0.3), np.rint(x + 0.3))
    assert eq(da.fix(a - 2.5), np.fix(x - 2.5))

    # complex helpers, as functions and as attributes/methods
    assert eq(da.angle(a + 1j), np.angle(x + 1j))
    assert eq(da.real(a + 1j), np.real(x + 1j))
    assert eq((a + 1j).real, np.real(x + 1j))
    assert eq(da.imag(a + 1j), np.imag(x + 1j))
    assert eq((a + 1j).imag, np.imag(x + 1j))
    assert eq(da.conj(a + 1j * b), np.conj(x + 1j * y))
    assert eq((a + 1j * b).conj(), (x + 1j * y).conj())

    assert eq(da.clip(b, 1, 4), np.clip(y, 1, 4))
    assert eq(da.fabs(b), np.fabs(y))
    assert eq(da.sign(b - 2), np.sign(y - 2))

    # functions returning two outputs
    l1, l2 = da.frexp(a)
    r1, r2 = np.frexp(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    l1, l2 = da.modf(a)
    r1, r2 = np.modf(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    assert eq(da.around(a, -1), np.around(x, -1))
    def reflectance_from_tbs(self, sun_zenith, tb_near_ir, tb_thermal, **kwargs):
        # Derives a reflectance from near-IR and thermal brightness
        # temperatures plus the sun zenith angle, stores it in ``self._r3x``
        # and returns it. Optional keyword arguments read below are
        # ``tb_ir_co2`` and ``lut`` (defaulting to ``self.lut``).
        #
        # NOTE(review): this listing appears to be truncated. The docstring
        # delimiters around the text that follows are missing, as is every
        # ``else:`` line (each flag / array assignment below is immediately
        # overwritten), so the block does not parse as shown. Restore from
        # the upstream file.
        The relfectance calculated is without units and should be between 0 and 1.


          sun_zenith: Sun zenith angle for every pixel - in degrees

          tb_near_ir: The 3.7 (or 3.9 or equivalent) IR Tb's at every pixel

          tb_thermal: The 10.8 (or 11 or 12 or equivalent) IR Tb's at every
                      pixel (Kelvin)

          tb_ir_co2: The 13.4 micron channel (or similar - co2 absorption band)
                     brightness temperatures at every pixel. If None, no CO2
                     absorption correction will be applied.

        # Check for dask arrays
        if hasattr(tb_near_ir, 'compute') or hasattr(tb_thermal, 'compute'):
            compute = False
            compute = True
        # Remember whether either input carries a ``mask`` attribute so the
        # result can be wrapped as a masked array at the end.
        if hasattr(tb_near_ir, 'mask') or hasattr(tb_thermal, 'mask'):
            is_masked = True
            is_masked = False

        # Scalars are wrapped into 1-element arrays.
        if np.isscalar(tb_near_ir):
            tb_nir = np.array([tb_near_ir, ])
            tb_nir = np.asanyarray(tb_near_ir)

        if np.isscalar(tb_thermal):
            tb_therm = np.array([tb_thermal, ])
            tb_therm = np.asanyarray(tb_thermal)

        if tb_therm.shape != tb_nir.shape:
            errmsg = 'Dimensions do not match! {0} and {1}'.format(
                str(tb_therm.shape), str(tb_nir.shape))
            raise ValueError(errmsg)

        tb_ir_co2 = kwargs.get('tb_ir_co2')
        lut = kwargs.get('lut', self.lut)

        if tb_ir_co2 is None:
            co2corr = False
            tbco2 = None
            co2corr = True
            if np.isscalar(tb_ir_co2):
                tbco2 = np.array([tb_ir_co2, ])
                tbco2 = np.asanyarray(tb_ir_co2)

        if not self.rsr:
            raise NotImplementedError("Reflectance calculations without "
                                      "rsr not yet supported!")

        # Assume rsr is in microns!!!
        # FIXME!
        self._rad3x_t11 = self.tb2radiance(tb_therm, lut=lut)['radiance']
        thermal_emiss_one = self._rad3x_t11 * self.rsr_integral

        l_nir = self.tb2radiance(tb_nir, lut=lut)['radiance'] * self.rsr_integral

        # Only log the intermediate values for very small inputs.
        if thermal_emiss_one.ravel().shape[0] < 10:
            LOG.info('thermal_emiss_one = %s', str(thermal_emiss_one))
        if l_nir.ravel().shape[0] < 10:
            LOG.info('l_nir = %s', str(l_nir))

        # Angles outside [0, 88] are replaced by 88 for the cosine below and
        # masked out of the final result via ``sunzmask``.
        sunzmask = (sun_zenith < 0.0) | (sun_zenith > 88.0)
        sunz = where(sunzmask, 88.0, sun_zenith)

        mu0 = np.cos(np.deg2rad(sunz))
        # mu0 = np.where(np.less(mu0, 0.1), 0.1, mu0)
        self._rad3x = l_nir
        self._solar_radiance = self.solar_flux * mu0 / np.pi

        # CO2 correction to the 3.9 radiance, only if tbs of a co2 band around
        # 13.4 micron is provided:
        if co2corr:
            self.derive_rad39_corr(tb_therm, tbco2)
            LOG.info("CO2 correction applied...")
            self._rad3x_correction = 1.0

        nomin = l_nir - thermal_emiss_one * self._rad3x_correction
        denom = self._solar_radiance - thermal_emiss_one * self._rad3x_correction
        data = nomin / denom
        # Mask pixels whose denominator falls below EPSILON.
        mask = (self._solar_radiance - thermal_emiss_one *
                self._rad3x_correction) < EPSILON

        # Fold the sun-zenith mask and NaN inputs into ``mask`` via ``out=``.
        logical_or(sunzmask, mask, out=mask)
        logical_or(mask, np.isnan(tb_nir), out=mask)

        self._r3x = where(mask, np.nan, data)

        # Reflectances should be between 0 and 1, but values above 1 is
        # perfectly possible and okay! (Multiply by 100 to get reflectances
        # in percent)
        if hasattr(self._r3x, 'compute') and compute:
            res = self._r3x.compute()
            res = self._r3x
        if is_masked:
            res = np.ma.masked_array(res, mask=np.isnan(res))
        return res
def main(args):
    # Overview (from the visible statements): for every measurement set in
    # args.ms, weighted residuals (data - model) of the first and last
    # correlation are combined into a Stokes I residual, whitened with the
    # square root of the summed weights, and samples whose amplitude exceeds
    # args.sigma_cut * (mean amplitude) are flagged. If args.scale_weights is
    # set the weights are rescaled using the mean amplitude recomputed with the
    # new flags. Flags / weights are assigned to args.flag_out_column and
    # args.weight_out_column. With args.report_means the mean amplitudes are
    # recomputed from those output columns.
    # NOTE(review): the three lines of free text below read like a docstring
    # whose quote delimiters are not present in this listing, and several
    # multi-line statements further down are missing their closing lines
    # (and some `else:` / `raise` lines). Confirm against the full file.
    Flags outliers in data given a model and rescale weights so that whitened residuals have a
    mean amplitude of sqrt(2). 
    Flags and weights are computed per chunk of data
    # phase centre of the first dataset seen; later datasets are compared to it
    radec_ref = None
    writes = []
    for ims in args.ms:
        # datasets are grouped per (FIELD_ID, DATA_DESC_ID); the "row"/"chan"
        # entries are chunk sizes taken from the command-line arguments
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              "row": args.row_chunks,
                              "chan": args.chan_chunks
                          columns=('UVW', args.data_column, args.weight_column,
                                   args.model_column, args.flag_column,

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if radec_ref is None:
                radec_ref = radec

            # NOTE(review): the body of this check is not visible here
            if not np.array_equal(radec, radec_ref):

            # load in data and compute whitened residuals
            data = getattr(ds, args.data_column).data
            model = getattr(ds, args.model_column).data
            flag = getattr(ds, args.flag_column).data
            # a flagged row flags every channel and correlation of that row
            flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
            weights = getattr(ds, args.weight_column).data
            # weights with fewer than 3 dims get a channel axis inserted and
            # are broadcast (target shape not visible in this listing)
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],

            if args.trim_channels:
                flag = trim_chans(flag, args.trim_channels)

            # Stokes I vis
            # flagged samples get zero weight
            weights = (~flag) * weights
            resid_vis = (data - model) * weights
            # summed weight of the first and last correlation
            wsums = (weights[:, :, 0] + weights[:, :, -1])
            # weighted average of the first/last correlation residuals where
            # wsums is non-zero (the fill value argument is not visible here)
            resid_vis_I = da.where(
                wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,

            # whiten and take abs
            white_resid = resid_vis_I * da.sqrt(wsums)
            abs_resid_vis_I = (white_resid).__abs__()

            # mean amp
            sum_amp = da.sum(abs_resid_vis_I)
            # number of samples with positive summed weight
            count = da.sum(wsums > 0)
            mean_amp = sum_amp / count

            # a sample counts as already flagged if either the first or the
            # last correlation is flagged
            flag_legacy = flag[:, :, 0] | flag[:, :, -1]
            # new outlier flags: amplitude above sigma_cut times the mean
            flag_I = da.logical_or(abs_resid_vis_I > args.sigma_cut * mean_amp,

            # new flags
            updated_flag = da.broadcast_to(flag_I[:, :, None],

            # scale weights (whitened residuals should have mean amplitude of 1/sqrt(2))
            # NOTE(review): the free text at the top of this function says
            # sqrt(2) while the comment above says 1/sqrt(2) — confirm which
            # normalisation is intended.
            if args.scale_weights:
                # recompute mean amp with new flags
                weights = (~updated_flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amp = sum_amp / count
                updated_weight = 2**0.5 * weights / mean_amp**2
                # NOTE(review): as listed, the next assignment overrides the
                # rescaled weights; presumably it belongs to an `else:` branch
                # that is not visible here.
                updated_weight = weights

            ds = ds.assign(**{
                args.weight_out_column: (("row", "chan", "corr"),
            ds = ds.assign(**{
                args.flag_out_column: (("row", "chan", "corr"), updated_flag)

                columns=[args.flag_out_column, args.weight_out_column]))

    # NOTE(review): the body of this `with` block is not visible here
    with ProgressBar():

    # report new mean amp
    if args.report_means:
        radec_ref = None
        mean_amps = []
        for ims in args.ms:
            # same grouping as above, but reading the *output* weight and flag
            # columns
            xds = xds_from_ms(
                group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                    "row": args.row_chunks,
                    "chan": args.chan_chunks
                columns=('UVW', args.data_column, args.weight_out_column,
                         args.model_column, args.flag_out_column, 'FLAG_ROW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if radec_ref is None:
                    radec_ref = radec

                # NOTE(review): the body of this check is not visible here
                if not np.array_equal(radec, radec_ref):

                # load in data and compute whitened residuals
                data = getattr(ds, args.data_column).data
                model = getattr(ds, args.model_column).data
                flag = getattr(ds, args.flag_out_column).data
                flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
                weights = getattr(ds, args.weight_out_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],

                # Stokes I vis
                weights = (~flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,

                # whiten and take abs
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()

                # mean amp
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                # one (lazy) mean amplitude per dataset
                mean_amps.append(sum_amp / count)

        # evaluate all the lazy mean amplitudes in a single dask.compute call
        mean_amps = dask.compute(mean_amps)[0]

    def read_band(self, key, info):
        """Read the data."""
        # Overview (from the visible statements): the header blocks are read
        # sequentially from self.filename into the `header` dict, the image
        # counts that follow are memory-mapped as a
        # (number_of_lines, number_of_columns) array of '<u2' values, counts
        # equal to the "outside scan" / "error pixel" markers are replaced by
        # NaN, the result is calibrated via self.calibrate(res, key.calibration)
        # and returned as an xarray.DataArray with dims ('y', 'x') masked by
        # self.geo_mask(). `info` only supplies the 'units' attribute.
        # NOTE(review): several multi-line statements below are missing their
        # closing lines in this listing (and at least one `else:`); confirm
        # against the full file.
        tic = datetime.now()
        header = {}
        with open(self.filename, "rb") as fp_:

            # the fixed-size blocks are stored back to back, so the read order
            # below determines which bytes each block receives
            header['block1'] = np.fromfile(
                fp_, dtype=_BASIC_INFO_TYPE, count=1)
            header["block2"] = np.fromfile(fp_, dtype=_DATA_INFO_TYPE, count=1)
            header["block3"] = np.fromfile(fp_, dtype=_PROJ_INFO_TYPE, count=1)
            header["block4"] = np.fromfile(fp_, dtype=_NAV_INFO_TYPE, count=1)
            header["block5"] = np.fromfile(fp_, dtype=_CAL_INFO_TYPE, count=1)
            logger.debug("Band number = " +
            logger.debug('Time_interval: %s - %s',
                         str(self.start_time), str(self.end_time))
            band_number = header["block5"]['band_number'][0]
            # NOTE(review): as listed both reads execute for band_number < 7;
            # presumably the second one belongs to an `else:` branch that is
            # not visible here.
            if band_number < 7:
                cal = np.fromfile(fp_, dtype=_VISCAL_INFO_TYPE, count=1)
                cal = np.fromfile(fp_, dtype=_IRCAL_INFO_TYPE, count=1)

            header['calibration'] = cal

            header["block6"] = np.fromfile(
                fp_, dtype=_INTER_CALIBRATION_INFO_TYPE, count=1)
            header["block7"] = np.fromfile(
                fp_, dtype=_SEGMENT_INFO_TYPE, count=1)
            header["block8"] = np.fromfile(
                fp_, dtype=_NAVIGATION_CORRECTION_INFO_TYPE, count=1)
            # 8 The navigation corrections:
            # block 8 is followed by a variable number of correction records
            ncorrs = header["block8"]['numof_correction_info_data'][0]
            dtype = np.dtype([
                ("line_number_after_rotation", "<u2"),
                ("shift_amount_for_column_direction", "f4"),
                ("shift_amount_for_line_direction", "f4"),
            corrections = []
            for i in range(ncorrs):
                corrections.append(np.fromfile(fp_, dtype=dtype, count=1))
            # skip 40 bytes relative to the current position (whence=1);
            # presumably the block's spare area — TODO confirm
            fp_.seek(40, 1)
            header['navigation_corrections'] = corrections
            header["block9"] = np.fromfile(fp_,
            numobstimes = header["block9"]['number_of_observation_times'][0]

            dtype = np.dtype([
                ("line_number", "<u2"),
                ("observation_time", "f8"),
            lines_and_times = []
            # NOTE(review): the loop body is not visible in this listing
            for i in range(numobstimes):
            header['observation_time_information'] = lines_and_times
            fp_.seek(40, 1)

            header["block10"] = np.fromfile(fp_,
            dtype = np.dtype([
                ("line_number", "<u2"),
                ("numof_error_pixels_per_line", "<u2"),
            num_err_info_data = header["block10"][
            err_info_data = []
            for i in range(num_err_info_data):
                err_info_data.append(np.fromfile(fp_, dtype=dtype, count=1))
            header['error_information_data'] = err_info_data
            fp_.seek(40, 1)

            # read one _SPARE_TYPE record and discard it, which advances the
            # file position to the start of the image data
            np.fromfile(fp_, dtype=_SPARE_TYPE, count=1)

            nlines = int(header["block2"]['number_of_lines'][0])
            ncols = int(header["block2"]['number_of_columns'][0])

            # map the counts lazily, starting at the current file offset
            res = da.from_array(np.memmap(self.filename, offset=fp_.tell(),
                                          dtype='<u2',  shape=(nlines, ncols), mode='r'),

        # counts equal to either marker value from block 5 are treated as
        # invalid and replaced by NaN
        invalid = da.logical_or(res == header['block5']["count_value_outside_scan_pixels"][0],
                                res == header['block5']["count_value_error_pixels"][0])
        res = da.where(invalid, np.float32(np.nan), res)
        self._header = header

        logger.debug("Reading time " + str(datetime.now() - tic))
        res = self.calibrate(res, key.calibration)
        # altitude = (distance earth centre -> satellite) - equatorial radius,
        # scaled by 1000 (presumably km -> m; verify the header units)
        new_info = dict(units=info['units'],
                        satellite_altitude=float(self.nav_info['distance_earth_center_to_satellite'] -
                                                 self.proj_info['earth_equatorial_radius']) * 1000)
        res = xr.DataArray(res, attrs=new_info, dims=['y', 'x'])
        res = res.where(self.geo_mask())
        return res
    def get_value(self, group, corr, extras, flag, flag_row, chanslice):
        """Return this axis' mapped values for one group as a dask masked array.

        The column data comes from ``self.get_column_data(group)``. 3-D data
        is reduced over correlation, the axis' mapping function is applied,
        array data is sliced in channel, and the result is masked with the
        applicable flags plus every non-finite value.

        Parameters
        ----------
        group
            Passed straight to ``self.get_column_data``.
        corr
            Key into ``self.ms.corr_data_mappers`` / ``corr_flag_mappers``;
            ignored when ``self.corr`` is set.
        extras
            Mapping from which the keyword arguments named in
            ``self.mapper.extras`` are looked up.
        flag
            Flag array, or None when no flags should be applied.
        flag_row
            Row flags, used in place of ``flag`` for 1-D data when the mapper
            has no axis set.
        chanslice
            Index applied along ``chan`` when the data is an
            ``xarray.DataArray`` with that dimension.

        Returns
        -------
        The result of ``dama.masked_array(values, mask)`` where ``mask`` is the
        non-finite mask, or'ed with the flags when there are any.

        Raises
        ------
        TypeError
            For a 3-D column when the mapper has an axis set, when flags and
            data end up with different shapes, or when the column switches
            between discrete and continuous between calls.
        """
        coldata = self.get_column_data(group)
        # correlation may be pre-set by plot type, or may be passed to us
        corr = self.corr if self.corr is not None else corr
        # apply correlation reduction
        if coldata is not None and coldata.ndim == 3:
            assert corr is not None
            # the mapper can't have a specific axis set
            if self.mapper.axis is not None:
                raise TypeError(f"{self.name}: unexpected column with ndim=3")
            coldata = self.ms.corr_data_mappers[corr](coldata)
        # apply mapping function
        mapper = self.mapper
        # complex values with an identity mapper get an amp mapper assigned to them by default
        if np.iscomplexobj(coldata) and mapper is data_mappers["_"]:
            mapper = data_mappers["amp"]
        coldata = mapper.mapper(
            coldata, **{name: extras[name]
                        for name in self.mapper.extras})
        # for a constant axis, compute minmax on the fly
        if mapper.const and self._minmax_autorange:
            if np.isscalar(coldata):
                min1 = max1 = coldata
            else:
                min1, max1 = coldata.data.min(), coldata.data.max()
            # widen the running range: the lower bound can only decrease and
            # the upper bound can only increase (it was previously updated
            # with min(), which shrank the range instead of growing it)
            self.minmax = (
                min(self.minmax[0], min1) if self.minmax[0] is not None else min1,
                max(self.minmax[1], max1) if self.minmax[1] is not None else max1,
            )
        # scalar is just a scalar
        if np.isscalar(coldata):
            coldata = da.array(coldata)
            flag = None
        else:
            # apply channel slicing, if there's a channel axis in the array (and the array is a DataArray)
            if type(coldata) is xarray.DataArray and 'chan' in coldata.dims:
                coldata = coldata[dict(chan=chanslice)]
            # determine flags -- start with original flags
            if flag is not None:
                if coldata.ndim == 2:
                    flag = self.ms.corr_flag_mappers[corr](flag)
                elif coldata.ndim == 1:
                    if not self.mapper.axis:
                        flag = flag_row
                    elif self.mapper.axis == 1:
                        flag = None
                # shapes must now match
                if flag is not None and coldata.shape != flag.shape:
                    raise TypeError(f"{self.name}: unexpected column shape")

        # NOTE(review): `dtype is bool` is an identity test against the Python
        # type, which a NumPy dtype instance does not satisfy, so in practice
        # only the integer test selects this branch -- confirm whether
        # `== bool` was intended before changing it.
        if coldata.dtype is bool or np.issubdtype(coldata.dtype, np.integer):
            if self._is_discrete is False:
                raise TypeError(
                    f"{self.label}: column changed from continuous-valued to discrete. This is a bug, or a very weird MS."
                )
            self._is_discrete = True
            # do we need to apply a remapping?
            if self.subset_remapper is not None:
                if type(coldata) is not dask.array.core.Array:  # could be xarray backed by dask array
                    coldata = coldata.data
                coldata = self.subset_remapper[coldata]
                # remapped values at or beyond the number of subset indices
                # are treated as flagged
                bad_bins = da.greater_equal(coldata, len(self.subset_indices))
                if flag is None:
                    flag = bad_bins
                else:
                    flag = da.logical_or(flag.data, bad_bins)
        else:
            if self._is_discrete is True:
                raise TypeError(
                    f"{self.label}: column chnaged from discrete to continuous-valued. This is a bug, or a very weird MS."
                )
            self._is_discrete = False

        # Ensure dask arrays for creating dask masked arrays
        if isinstance(coldata, xarray.DataArray):
            coldata = coldata.data

        if isinstance(flag, xarray.DataArray):
            flag = flag.data

        # non-finite values are always masked, with or without flags
        bad_data = da.logical_not(da.isfinite(coldata))
        if flag is not None:
            return dama.masked_array(coldata, da.logical_or(flag, bad_data))
        return dama.masked_array(coldata, bad_data)
def _residual(ms, stack, **kw):
    # Overview (from the visible statements): resolves defaults for threads,
    # memory limit, workers and gridder threads, sets up the client via
    # pfb.set_client, derives the cell size and image dimensions from the
    # largest |u|, |v| and frequency found in the measurement sets, plans the
    # row chunking from a per-row memory estimate, builds one image per
    # dataset, and (unless args.mock) computes them, stitches them into bands
    # and saves '<output_filename>_dirty.fits'.
    # NOTE(review): many multi-line statements below are missing their closing
    # lines in this listing, and several `if` blocks are followed by a
    # same-level assignment that presumably belongs to an `else:` branch not
    # visible here. Confirm against the full file.
    args = OmegaConf.create(kw)
    # struct mode: presumably makes access to unknown keys an error — see the
    # OmegaConf documentation
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
        import psutil
        # first field of psutil.virtual_memory() divided by 1e9
        # (presumably total memory in GB — see the psutil documentation)
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
        mem_limit = args.mem_limit

    nband = args.nband
    # default: one worker per imaging band
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import residual as im2residim
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    # largest channel count of any band and highest frequency over all
    # measurement sets / spectral windows
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    # bytes of visibility data per row for the largest channel chunk
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    # (counted at half the item size of the data column)
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        # NOTE(review): ('UVW') is a plain string, not a one-element tuple;
        # confirm that xds_from_ms accepts a bare column name here.
        xds = xds_from_ms(ims, columns=('UVW'), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    # NOTE(review): uv_max is only bound inside the loops above, so this
    # line fails if no dataset was visited.
    uv_max = uv_max.compute()
    del uvw

    # image size
    # cell size derived from the largest baseline coordinate and highest
    # frequency (presumably the Nyquist limit, in radians — TODO confirm)
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        # cell_size / 3600 * pi / 180, i.e. arcseconds -> radians
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            # NOTE(review): ValueError receives three separate arguments
            # here, so the factor is not formatted into the message string.
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        # field_of_view * 3600 (presumably degrees -> arcseconds, matching
        # the units of cell_size — TODO confirm)
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        # force an even number of pixels
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,

    # an explicit args.row_chunks overrides the planned value; -1 means all
    # rows in a single chunk
    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            # NOTE(review): the body of this check is not visible here
            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            # only the first and last correlations are used
            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                # NOTE(review): as listed, the next two assignments override
                # the imaging-weighted values; presumably they belong to an
                # `else:` branch that is not visible here.
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of mueller term.
            # Phases modify data amplitudes modify weights.
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            # samples with zero summed weight are set to 0 instead of
            # keeping the result of the division
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            # NOTE(review): the test looks at xds[0] rather than the current
            # ds, and the second assignment presumably belongs to an `else:`
            # branch that is not visible here.
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            # mask is True where neither correlation nor the row is flagged
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            # NOTE(review): `vis2im` is not among the names imported inside
            # this function (only `residual as im2residim` is); presumably it
            # is defined at module level — confirm.
            dirty = vis2im(uvw,


    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf', optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
        save_fits(args.output_filename + '_dirty.fits',

    print("All done here.", file=log)