Exemple #1
0
def main(args):
    """Interpolate the primary beam onto the grid of a FITS image and save it.

    The header of ``args.image`` supplies the l/m coordinates (shifted by the
    reference values returned by ``data_from_header``) and the frequency
    axis. The cube returned by ``interpolate_beam`` is written to
    ``args.output_filename`` with that same header.

    Parameters
    ----------
    args : namespace-like
        Must expose ``image`` and ``output_filename``. The object is also
        forwarded untouched to ``interpolate_beam``.

    Raises
    ------
    ValueError
        If neither the 4th nor the 3rd axis is labelled as frequency.
    """
    # get coord info
    hdr = fits.getheader(args.image)
    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m

    # Use .get() so that an image without a CTYPE4 card (only three axes)
    # falls through to the CTYPE3 check instead of dying with a KeyError.
    if str(hdr.get("CTYPE4", "")).lower() == 'freq':
        freq_axis = 4
    elif str(hdr.get("CTYPE3", "")).lower() == 'freq':
        freq_axis = 3
    else:
        raise ValueError("Freq axis must be 3rd or 4th")
    # only the coordinates are needed here, not the reference frequency
    freqs, _ = data_from_header(hdr, axis=freq_axis)

    xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

    # interpolate primary beam to fits header and optionally average over time
    beam_image = interpolate_beam(xx, yy, freqs, args)

    # save power beam
    save_fits(args.output_filename, beam_image, hdr)
    print("Wrote interpolated beam cube to %s \n" % args.output_filename)
Exemple #2
0
def sara(psf,
         model,
         residual,
         mask=None,
         beam_image=None,
         hessian=None,
         wsum=1,
         adapt_sig21=True,
         hdr=None,
         hdr_mfs=None,
         outfile=None,
         cpsf=None,
         nthreads=1,
         sig_21=1e-6,
         sigma_frac=100,
         maxit=10,
         tol=1e-3,
         gamma=0.99,
         psi_levels=2,
         psi_basis=None,
         alpha=None,
         pdtol=1e-6,
         pdmaxit=250,
         pdverbose=1,
         positivity=True,
         cgtol=1e-6,
         cgminit=25,
         cgmaxit=150,
         cgverbose=1,
         pmtol=1e-5,
         pmmaxit=50,
         pmverbose=1):
    """Run the SARA minor cycle and return the updated model.

    Each iteration solves for an update with ``pcg``, takes a step of size
    ``gamma``, passes the result through ``primal_dual`` with ``prox_21``,
    recomputes the residual with ``resid_func`` and re-weights the l21
    weights per basis.

    Parameters
    ----------
    psf : array
        Passed to ``PSF`` to build the convolution operator.
    model : array
        Current model; must be compatible with ``residual``.
    residual : array
        Must be three dimensional, unpacked as ``(nband, nx, ny)``.
    mask : array, optional
        Shape ``(nx, ny)`` or ``(1, nx, ny)``; multiplied into images.
        ``None`` means no masking.
    beam_image : array, optional
        Shape ``(nband, nx, ny)``; multiplied into images. ``None`` means
        no beam.
    hessian : callable, optional
        Used to form ``dirty`` and handed to ``resid_func``. When ``None``
        the PSF convolution is used and ``wsum`` is reset to ``1.0``.
        The operator used by ``pcg``/``primal_dual``/``power_method`` is
        always the PSF based approximation.
    wsum : float
        Normalisation applied to ``hessian`` output and to the saved
        residual cube.
    adapt_sig21 : bool
        Refit the per-basis ``sigmas`` after every iteration.
    hdr, hdr_mfs : FITS headers, optional
        Required (asserted) when ``outfile`` is given.
    outfile : str, optional
        Prefix for per-iteration FITS products; nothing is written if None.
    cpsf : array
        Required. Passed to ``PSF`` to build the operator used for the
        variance map and the initial l2 norms.
    alpha : ignored
        Accepted for interface compatibility; the value is recomputed from
        the residual and whatever is passed in is overwritten.
    psi_levels, psi_basis, nthreads
        Forwarded to ``DaskPSI``.
    maxit, tol, gamma, sig_21, sigma_frac, positivity
        Outer-loop controls.
    pd*, cg*, pm*
        Forwarded to ``primal_dual``, ``pcg`` and ``power_method``.

    Returns
    -------
    model : array
        The model after the last iteration.

    Raises
    ------
    ValueError
        On a residual that is not 3-D, on a beam or mask of the wrong
        shape, or when ``cpsf`` is missing.
    """
    if len(residual.shape) != 3:
        raise ValueError("Residual must have shape (nband, nx, ny)")

    nband, nx, ny = residual.shape

    # beam operator
    if beam_image is None:

        def beam(x):
            return x
    else:
        # validate the array that was passed in (the nested function named
        # ``beam`` does not exist yet and has no shape)
        if getattr(beam_image, "shape", None) != (nband, nx, ny):
            raise ValueError("Beam has incorrect shape")

        def beam(x):
            return beam_image * x

    # mask operator; keep the array under its own name so that the nested
    # function does not shadow it
    if mask is None:

        def apply_mask(x):
            return x
    else:
        mask_shape = getattr(mask, "shape", None)
        if mask_shape == (nx, ny):
            mask_array = mask[None]  # add the band axis
        elif mask_shape == (1, nx, ny):
            mask_array = mask
        else:
            raise ValueError("Mask has incorrect shape")

        def apply_mask(x):
            return mask_array * x

    # PSF operator
    psfo = PSF(psf, residual.shape, nthreads=nthreads)

    if cpsf is None:
        raise ValueError("cpsf must be provided")
    cpsfo = PSF(cpsf, residual.shape, nthreads=nthreads)

    # note the MFS residual (and hence rms) is taken before mask/beam
    residual_mfs = np.sum(residual, axis=0)
    residual = apply_mask(beam(residual))
    rms = np.std(residual_mfs)

    # wavelet dictionary
    psi_kwargs = dict(imsize=residual.shape,
                      nlevels=psi_levels,
                      nthreads=nthreads)
    if psi_basis is not None:
        if not isinstance(psi_basis, list):
            psi_basis = [psi_basis]
        psi_kwargs['bases'] = psi_basis
    psi = DaskPSI(**psi_kwargs)

    # set alpha's and sig21's
    # this assumes that the model has been initialised using NNLS
    alpha = np.zeros(psi.nbasis)
    sigmas = np.zeros(psi.nbasis)
    # residual normalised by its per-band maximum
    band_peaks = np.amax(residual.reshape(-1, nx * ny), axis=1)
    resid_comps = psi.hdot(residual / band_peaks[:, None, None])
    l2_norm = np.linalg.norm(psi.hdot(cpsfo.convolve(model)), axis=1)
    for m in range(psi.nbasis):
        alpha[m] = np.std(resid_comps[m])
        _, sigmas[m] = expon.fit(l2_norm[m], floc=0.0)
        print("Basis %i, alpha %f, sigma %f" % (m, alpha[m], sigmas[m]),
              file=log)

    # l21 weights and dual
    weights21 = np.ones((psi.nbasis, psi.nmax), dtype=residual.dtype)
    for m in range(psi.nbasis):
        weights21[m] *= sigmas[m] / sig_21
    dual = np.zeros((psi.nbasis, nband, psi.nmax), dtype=residual.dtype)

    # use PSF to approximate Hessian if not passed in
    if hessian is None:
        hessian = psfo.convolve
        wsum = 1.0

    #  preconditioning operator
    if model.any():
        varmap = np.maximum(rms, sigma_frac * cpsfo.convolve(model))
    else:
        varmap = np.ones(model.shape) * sigma_frac * rms

    def hess_approx(x):
        # PSF based approximation of the regularised Hessian. ``varmap`` is
        # looked up at call time, so updates made in the loop are seen here.
        return apply_mask(beam(psfo.convolve(apply_mask(beam(x))))) \
            + x / varmap

    beta, betavec = power_method(hess_approx,
                                 residual.shape,
                                 tol=pmtol,
                                 maxit=pmmaxit,
                                 verbosity=pmverbose)

    if model.any():
        dirty = residual + hessian(apply_mask(beam(model))) / wsum
    else:
        dirty = residual

    # deconvolve
    for i in range(maxit):
        x = pcg(hess_approx,
                apply_mask(beam(residual)),
                np.zeros_like(residual),
                M=lambda x: x * varmap,
                tol=cgtol,
                maxit=cgmaxit,
                minit=cgminit,
                verbosity=cgverbose)

        # update model
        modelp = model
        model = modelp + gamma * x

        model, dual = primal_dual(hess_approx,
                                  model,
                                  modelp,
                                  dual,
                                  sig_21,
                                  psi,
                                  weights21,
                                  beta,
                                  prox_21,
                                  tol=pdtol,
                                  maxit=pdmaxit,
                                  report_freq=50,
                                  mask=apply_mask,
                                  verbosity=pdverbose,
                                  positivity=positivity)

        # get residual
        residual, residual_mfs = resid_func(model, dirty, hessian,
                                            apply_mask, beam, wsum)
        model_mfs = np.mean(model, axis=0)
        x_mfs = np.mean(x, axis=0)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        # update variance map (positivity constraint optional)
        varmap = np.maximum(rms, sigma_frac * cpsfo.convolve(model))

        # update spectral norm
        beta, betavec = power_method(hess_approx,
                                     residual.shape,
                                     b0=betavec,
                                     tol=pmtol,
                                     maxit=pmmaxit,
                                     verbosity=pmverbose)

        print("Iter %i: peak residual = %f, rms = %f, eps = %f" %
              (i + 1, rmax, rms, eps),
              file=log)

        # reweight
        l2_norm = np.linalg.norm(psi.hdot(model), axis=1)
        for m in range(psi.nbasis):
            if adapt_sig21:
                _, sigmas[m] = expon.fit(l2_norm[m], floc=0.0)
                # both placeholders need a value
                print('basis %i, sigma %f' % (m, sigmas[m]), file=log)

            weights21[m] = alpha[m] / (alpha[m] +
                                       l2_norm[m]) * sigmas[m] / sig_21

        # save current iteration
        if outfile is not None:
            assert hdr is not None
            assert hdr_mfs is not None

            prefix = outfile + str(i + 1)

            save_fits(prefix + '_model_mfs.fits', model_mfs, hdr_mfs)

            save_fits(prefix + '_model.fits', model, hdr)

            save_fits(prefix + '_update.fits', x, hdr)

            # band-averaged update is an MFS product, so it gets the MFS
            # header like the other *_mfs files
            save_fits(prefix + '_update_mfs.fits', x_mfs, hdr_mfs)

            save_fits(prefix + '_residual_mfs.fits', residual_mfs, hdr_mfs)

            save_fits(prefix + '_residual.fits', residual * wsum, hdr)

        if eps < tol:
            print("Success, convergence after %i iterations" % (i + 1),
                  file=log)
            break

    return model
Exemple #3
0
def _psf(**kw):
    """Grid the PSF for one or more measurement sets and write the products.

    The keyword arguments are wrapped in an OmegaConf config (struct mode is
    switched on after ``ms`` has been normalised to a list). The function

    1. derives a per-row memory estimate and a row chunk size,
    2. finds the largest |u|, |v| over all datasets to set the cell size,
    3. builds one ``vis2im`` PSF graph per dataset whose phase direction
       matches the first one encountered,
    4. writes FIELD_ID / DATA_DESC_ID / WEIGHT / UVW to
       ``<output_filename>.zarr``, and
    5. unless ``mock`` is set, computes everything and saves
       ``<output_filename>_psf.fits`` and ``<output_filename>_psf_mfs.fits``.

    Keys read from the config: ``ms``, ``nband``, ``data_column``,
    ``weight_column``, ``flag_column``, ``imaging_weight_column``,
    ``mueller_column``, ``cell_size``, ``super_resolution_factor``, ``nx``,
    ``ny``, ``field_of_view``, ``psf_oversize``, ``output_type``,
    ``host_address``, ``mem_limit``, ``nworkers``, ``nthreads_per_worker``,
    ``row_chunks``, ``row_out_chunk``, ``nvthreads``, ``epsilon``,
    ``wstack``, ``double_accum``, ``output_filename`` and ``mock``.

    Raises
    ------
    ValueError
        If the measurement sets yield no datasets, or if the requested
        cell size gives a super resolution factor below one.
    """
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    (freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping,
     chan_chunks) = chan_to_band_mapping(ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    first_ds = xds[0]
    ncorr = first_ds.dims['corr']
    nrow = first_ds.dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(first_ds, args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += first_ds.UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += first_ds.ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += first_ds.TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in first_ds:
        columns += ('FLAG_ROW', )
        memory_per_row += first_ds.FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    u_max = 0.0
    v_max = 0.0
    found_data = False
    for ims in args.ms:
        # one-element tuple: ('UVW') without the comma is just a string
        uvw_xds = xds_from_ms(ims, columns=('UVW', ), chunks={'row': -1})

        for ds in uvw_xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            found_data = True

    if not found_data:
        # without this guard the cell size computation below would fail on
        # an undefined uv_max
        raise ValueError("No datasets found in the supplied measurement sets")

    uv_max = da.maximum(u_max, v_max).compute()

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            # single formatted message (two positional args would render
            # as a tuple)
            raise ValueError("Requested cell size too small. "
                             "Super resolution factor = %f" %
                             (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    # an explicit request overrides the planned value; -1 means all rows
    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    # xds_from_ms expects a list per ds
    chunks = {}
    for ims in args.ms:
        chunks[ims] = [{
            'row': row_chunk,
            'chan': chan_chunks[ims][spw]['chan']
        } for spw in freqs[ims]]

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data (only ``fields`` is used below)
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_col = getattr(ds, args.data_column).data
            data_type = data_col.dtype
            data_shape = data_col.shape
            data_chunks = data_col.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(
                        imaging_weights[:, None, :],
                        data_shape,
                        chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx = weightsxx * da.absolute(mueller[:, :, 0])**2
                weightsyy = weightsyy * da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            autocorr = (ds.ANTENNA1.data == ds.ANTENNA2.data)
            if 'FLAG_ROW' in ds:
                frow = ds.FLAG_ROW.data | autocorr
            else:
                frow = autocorr

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT':
                (('row', 'chan'), weights.rechunk({0: args.row_out_chunk
                                                   })),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_datasets.append(Dataset(data_vars, coords))

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        cell_deg = cell_size / 3600
        hdr = set_wcs(cell_deg, cell_deg, nx, ny, radec, freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        # MFS PSF normalised to unit peak
        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
Exemple #4
0
def _main(dest=sys.stdout):
    from pfb.parser import create_parser
    args = create_parser().parse_args()

    if not args.nthreads:
        import multiprocessing
        args.nthreads = multiprocessing.cpu_count()

    if not args.mem_limit:
        import psutil
        args.mem_limit = int(psutil.virtual_memory()[0] /
                             1e9)  # 100% of memory by default

    import numpy as np
    import numba
    import numexpr
    import dask
    import dask.array as da
    from daskms import xds_from_ms, xds_from_table
    from astropy.io import fits
    from pfb.utils.fits import (set_wcs, load_fits, save_fits, compare_headers,
                                data_from_header)
    from pfb.utils.restoration import fitcleanbeam
    from pfb.utils.misc import Gaussian2D
    from pfb.operators.gridder import Gridder
    from pfb.operators.psf import PSF
    from pfb.deconv.sara import sara
    from pfb.deconv.clean import clean
    from pfb.deconv.spotless import spotless
    from pfb.deconv.nnls import nnls
    from pfb.opt.pcg import pcg

    if not isinstance(args.ms, list):
        args.ms = [args.ms]

    pyscilog.log_to_file(args.outfile + '.log')
    pyscilog.enable_memory_logging(level=3)

    GD = vars(args)
    print('Input Options:', file=log)
    for key in GD.keys():
        print('     %25s = %s' % (key, GD[key]), file=log)

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW'),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list([tmp_freq]))

    uv_max = u_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=dest)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size, file=dest)

    if args.nx is None or args.ny is None:
        from ducc0.fft import good_size
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = good_size(npix)
        args.ny = good_size(npix)

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny),
          file=dest)

    # mask
    if args.mask is not None:
        mask_array = load_fits(args.mask, dtype=args.real_type).squeeze()
        if mask_array.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape.")
        # add freq axis
        mask_array = mask_array[None]

        def mask(x):
            return mask_array * x
    else:
        mask_array = None

        def mask(x):
            return x

    # init gridder
    R = Gridder(
        args.ms,
        args.nx,
        args.ny,
        args.cell_size,
        nband=args.nband,
        nthreads=args.nthreads,
        do_wstacking=args.do_wstacking,
        row_chunks=args.row_chunks,
        psf_oversize=args.psf_oversize,
        data_column=args.data_column,
        epsilon=args.epsilon,
        weight_column=args.weight_column,
        imaging_weight_column=args.imaging_weight_column,
        model_column=args.model_column,
        flag_column=args.flag_column,
        weighting=args.weighting,
        robust=args.robust,
        mem_limit=int(
            0.8 * args.mem_limit))  # assumes gridding accounts for 80% memory
    freq_out = R.freq_out
    radec = R.radec

    print("PSF size set to (%i, %i, %i)" % (args.nband, R.nx_psf, R.ny_psf),
          file=dest)

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600, R.nx_psf,
                      R.ny_psf, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          R.nx_psf, R.ny_psf, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf = load_fits(args.psf, dtype=args.real_type).squeeze()
        except BaseException:
            raise
            psf = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf, hdr_psf)
    else:
        psf = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf, hdr_psf)

    # Normalising by wsum (so that the PSF always sums to 1) results in the
    # most intuitive sig_21 values and by far the least bookkeeping.
    # However, we won't save the cubes that way as it destroys information
    # about the noise in image space. Note only the MFS images will have the
    # usual units of Jy/beam.
    wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf_mfs = np.zeros(psf_mfs.shape, dtype=args.real_type)
    cpsf = np.zeros(psf.shape, dtype=args.real_type)

    lpsf = np.arange(-R.nx_psf / 2, R.nx_psf / 2)
    mpsf = np.arange(-R.ny_psf / 2, R.ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)

    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)
    hdr_psf_mfs = add_beampars(hdr_psf_mfs, GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    GaussPars = list(GaussPars)
    for b in range(args.nband):
        GaussPars[b] = list(GaussPars[b])
        GaussPars[b][0] *= args.cell_size / 3600
        GaussPars[b][1] *= args.cell_size / 3600
        GaussPars[b] = tuple(GaussPars[b])
    GaussPars = tuple(GaussPars)
    hdr_psf = add_beampars(hdr_psf, GaussPar, GaussPars)

    save_fits(args.outfile + '_cpsf.fits', cpsf, hdr_psf)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty).squeeze()
        except BaseException:
            raise
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    quit()
    # initial model and residual
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=args.real_type).squeeze()
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual,
                                         dtype=args.real_type).squeeze()
                except BaseException:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits', residual,
                              hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
            residual /= wsum
        except BaseException:
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    residual_mfs = np.sum(residual, axis=0)
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # smooth beam
    if args.beam_model is not None:
        if args.beam_model[-5:] == '.fits':
            beam_image = load_fits(args.beam_model,
                                   dtype=args.real_type).squeeze()
            if beam_image.shape != (args.nband, args.nx, args.ny):
                raise ValueError("Beam has incorrect shape")

        elif args.beam_model == "JimBeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            else:
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            beam_image = np.zeros((args.nband, args.nx, args.ny),
                                  dtype=args.real_type)

            l_coord, ref_l = data_from_header(hdr, axis=1)
            l_coord -= ref_l
            m_coord, ref_m = data_from_header(hdr, axis=2)
            m_coord -= ref_m
            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

            for v in range(args.nband):
                beam_image[v] = beam.I(xx, yy, freq_out[v])

        def beam(x):
            return beam_image * x
    else:
        beam_image = None

        def beam(x):
            return x

    if args.init_nnls:
        print("Initialising with NNLS", file=log)
        model = nnls(psf,
                     model,
                     residual,
                     mask=mask_array,
                     beam_image=beam_image,
                     hdr=hdr,
                     hdr_mfs=hdr_mfs,
                     outfile=args.outfile,
                     maxit=1,
                     nthreads=args.nthreads)

        residual = R.make_residual(beam(mask(model))) / wsum
        residual_mfs = np.sum(residual, axis=0)

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    redo_dirty = False
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms),
          file=dest)
    for i in range(0, args.maxit):
        # run minor cycle of choice
        modelp = model.copy()
        if args.deconv_mode == 'sara':
            model = sara(psf,
                         model,
                         residual,
                         mask=mask_array,
                         beam_image=beam_image,
                         hessian=R.convolve,
                         wsum=wsum,
                         adapt_sig21=args.adapt_sig21,
                         hdr=hdr,
                         hdr_mfs=hdr_mfs,
                         outfile=args.outfile,
                         cpsf=cpsf,
                         nthreads=args.nthreads,
                         sig_21=args.sig_21,
                         sigma_frac=args.sigma_frac,
                         maxit=args.minormaxit,
                         tol=args.minortol,
                         gamma=args.gamma,
                         psi_levels=args.psi_levels,
                         psi_basis=args.psi_basis,
                         pdtol=args.pdtol,
                         pdmaxit=args.pdmaxit,
                         pdverbose=args.pdverbose,
                         positivity=args.positivity,
                         cgtol=args.cgtol,
                         cgminit=args.cgminit,
                         cgmaxit=args.cgmaxit,
                         cgverbose=args.cgverbose,
                         pmtol=args.pmtol,
                         pmmaxit=args.pmmaxit,
                         pmverbose=args.pmverbose)

        elif args.deconv_mode == 'clean':
            model = clean(psf,
                          model,
                          residual,
                          mask=mask_array,
                          beam=beam_image,
                          nthreads=args.nthreads,
                          maxit=args.minormaxit,
                          gamma=args.gamma,
                          peak_factor=args.peak_factor,
                          threshold=args.threshold,
                          hbgamma=args.hbgamma,
                          hbpf=args.hbpf,
                          hbmaxit=args.hbmaxit,
                          hbverbose=args.hbverbose)
        elif args.deconv_mode == 'spotless':
            model = spotless(psf,
                             model,
                             residual,
                             mask=mask_array,
                             beam_image=beam_image,
                             hessian=R.convolve,
                             wsum=wsum,
                             adapt_sig21=args.adapt_sig21,
                             cpsf=cpsf_mfs,
                             hdr=hdr,
                             hdr_mfs=hdr_mfs,
                             outfile=args.outfile,
                             sig_21=args.sig_21,
                             sigma_frac=args.sigma_frac,
                             nthreads=args.nthreads,
                             gamma=args.gamma,
                             peak_factor=args.peak_factor,
                             maxit=args.minormaxit,
                             tol=args.minortol,
                             threshold=args.threshold,
                             positivity=args.positivity,
                             hbgamma=args.hbgamma,
                             hbpf=args.hbpf,
                             hbmaxit=args.hbmaxit,
                             hbverbose=args.hbverbose,
                             pdtol=args.pdtol,
                             pdmaxit=args.pdmaxit,
                             pdverbose=args.pdverbose,
                             cgtol=args.cgtol,
                             cgminit=args.cgminit,
                             cgmaxit=args.cgmaxit,
                             cgverbose=args.cgverbose,
                             pmtol=args.pmtol,
                             pmmaxit=args.pmmaxit,
                             pmverbose=args.pmverbose)
        else:
            raise ValueError("Unknown deconvolution mode ", args.deconv_mode)

        # get residual
        if redo_dirty:
            # Need to do this if weights or Jones has changed
            # (eg. if we change robustness factor, reweight or calibrate)
            psf = R.make_psf()
            wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf),
                            axis=1)
            wsum = np.sum(wsums)
            psf /= wsum
            dirty = R.make_dirty() / wsum

        # compute in image space
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        residual_mfs = np.sum(residual, axis=0)

        # save current iteration
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_major' + str(i + 1) + '_model_mfs.fits',
                  model_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_model.fits', model,
                  hdr)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual_mfs.fits',
                  residual_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual.fits',
                  residual * wsum, hdr)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("At iteration %i peak of residual is %f, rms is %f, current "
              "eps is %f" % (i + 1, rmax, rms, eps),
              file=dest)

        if eps < args.tol:
            break

    if args.mop_flux:
        print("Mopping flux", file=dest)

        # vague Gaussian prior on x
        def hess(x):
            return mask(beam(R.convolve(mask(beam(x))))) / wsum + 1e-6 * x

        def M(x):
            return x / 1e-6  # preconditioner

        x = pcg(hess,
                mask(beam(residual)),
                np.zeros(residual.shape, dtype=residual.dtype),
                M=M,
                tol=0.1 * args.cgtol,
                maxit=args.cgmaxit,
                minit=args.cgminit,
                verbosity=args.cgverbose)

        model += x
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        save_fits(args.outfile + '_mopped_model.fits', model, hdr)
        save_fits(args.outfile + '_mopped_residual.fits', residual, hdr)
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_mopped_model_mfs.fits', model_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
        save_fits(args.outfile + '_mopped_residual_mfs.fits', residual_mfs,
                  hdr_mfs)

        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After mopping flux peak of residual is %f, rms is %f" %
              (rmax, rms),
              file=dest)

    # if args.interp_model:
    #     nband = args.nband
    #     order = args.spectral_poly_order
    #     phi.trim_fat(model)
    #     I = np.argwhere(phi.mask).squeeze()
    #     Ix = I[:, 0]
    #     Iy = I[:, 1]
    #     npix = I.shape[0]

    #     # get components
    #     beta = model[:, Ix, Iy]

    #     # fit integrated polynomial to model components
    #     # we are given frequencies at bin centers, convert to bin edges
    #     ref_freq = np.mean(freq_out)
    #     delta_freq = freq_out[1] - freq_out[0]
    #     wlow = (freq_out - delta_freq/2.0)/ref_freq
    #     whigh = (freq_out + delta_freq/2.0)/ref_freq
    #     wdiff = whigh - wlow

    #     # set design matrix for each component
    #     Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
    #     for i in range(1, args.spectral_poly_order+1):
    #         Xdesign[:, i-1] = (whigh**i - wlow**i)/(i*wdiff)

    #     weights = psf_max[:, None]
    #     dirty_comps = Xdesign.T.dot(weights*beta)

    #     hess_comps = Xdesign.T.dot(weights*Xdesign)

    #     comps = np.linalg.solve(hess_comps, dirty_comps)

    #     np.savez(args.outfile + "spectral_comps", comps=comps, ref_freq=ref_freq, mask=np.any(model, axis=0))

    if args.write_model:
        print("Writing model", file=dest)
        R.write_model(model)

    if args.make_restored:
        print("Making restored", file=dest)
        cpsfo = PSF(cpsf, residual.shape, nthreads=args.nthreads)
        restored = cpsfo.convolve(model)

        # residual needs to be in Jy/beam before adding to convolved model
        wsums = np.amax(psf.reshape(-1, R.nx_psf * R.ny_psf), axis=1)
        restored += residual / wsums[:, None, None]

        save_fits(args.outfile + '_restored.fits', restored, hdr)
        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
Exemple #5
0
def spotless(
        psf,
        model,
        residual,
        mask=None,
        beam_image=None,
        hessian=None,
        wsum=1,
        adapt_sig21=False,
        cpsf=None,
        hdr=None,
        hdr_mfs=None,
        outfile=None,
        nthreads=1,
        sig_21=1e-3,
        sigma_frac=100,
        maxit=10,
        tol=1e-4,
        peak_factor=0.01,
        threshold=0.0,
        positivity=True,
        gamma=0.9999,
        hbgamma=0.1,
        hbpf=0.1,
        hbmaxit=5000,
        hbverbose=1,
        pdtol=1e-4,
        pdmaxit=250,
        pdverbose=1,  # primal dual options
        cgtol=1e-4,
        cgminit=15,
        cgmaxit=150,
        cgverbose=1,  # pcg options
        pmtol=1e-4,
        pmmaxit=50,
        pmverbose=1):  # power method options
    """
    Modified clean algorithm.

    Every iteration runs hogbom on the masked residual to find point source
    candidates, solves for an update with pcg, adds it to the model, passes
    the model through primal_dual and recomputes the residual with resid_func.

    psf        - PSF image i.e. R.H W where W contains the weights.
                 Shape must be >= residual.shape
    model      - current intrinsic model. The first update is added to this
                 array in place; use the returned array as the result.
    residual   - apparent residual image i.e. R.H W (V - R A x),
                 shape (nband, nx, ny)
    mask       - optional array of shape (nx, ny) or (1, nx, ny) that is
                 multiplied into images
    beam_image - optional array of shape (nband, nx, ny) that is multiplied
                 into images
    hessian    - callable applied to the masked, beam attenuated model when
                 forming the dirty image. Defaults to convolution with the
                 PSF, in which case wsum is reset to 1
    wsum       - the output of hessian is divided by this
    adapt_sig21 - if set, sig_21 is re-estimated from the residual after
                 every iteration that did not converge
    cpsf       - optional clean beam; a leading axis is added when its shape
                 is not (1,) + psf.shape[1:]
    hdr, hdr_mfs, outfile - when outfile is given the model, update and
                 residual of every iteration are written with save_fits
                 using these headers
    maxit, tol, threshold, peak_factor - stopping criteria: at most maxit
                 iterations, stopping early once the peak residual drops
                 below max(peak_factor * initial peak, threshold) or the
                 relative model change drops below tol
    gamma      - factor multiplying the pcg update before it is added
    hb*, cg*, pd*, pm* - options forwarded to hogbom, pcg, primal_dual and
                 power_method respectively

    Returns the updated model.

    Raises ValueError if residual, beam_image or mask have an unexpected
    shape, or if outfile is given without hdr and hdr_mfs.

    Note that peak finding happens in apparent residual because that
    is where it is easiest to accommodate convolution by the PSF.
    However, the beam and the mask have to be applied to the residual
    before we solve for the pre-conditioned updates.

    """

    if len(residual.shape) != 3:
        raise ValueError("Residual must have shape (nband, nx, ny)")

    nband, nx, ny = residual.shape

    # beam operators (identity when no beam is supplied)
    if beam_image is None:

        def apply_beam(x):
            return x

        def apply_beam_inv(x):
            return x
    else:
        # check the array that was passed in, not the operator defined below
        if getattr(beam_image, "shape", None) != (nband, nx, ny):
            raise ValueError("Beam has incorrect shape")

        def apply_beam(x):
            return beam_image * x

        def apply_beam_inv(x):
            # only divide where the beam is not negligible
            return np.where(beam_image > 0.01, x / beam_image, x)

    # mask operator (identity when no mask is supplied). The array is kept
    # under its own name so that the operator does not shadow it.
    if mask is None:

        def apply_mask(x):
            return x
    else:
        mask_ndim = getattr(mask, "ndim", None)
        if mask_ndim == 2 and mask.shape == (nx, ny):
            mask_array = mask[None]
        elif mask_ndim == 3 and mask.shape == (1, nx, ny):
            mask_array = mask
        else:
            raise ValueError("Mask has incorrect shape")

        def apply_mask(x):
            return mask_array * x

    # fail before doing any work if per-iteration output cannot be written
    if outfile is not None and (hdr is None or hdr_mfs is None):
        raise ValueError("hdr and hdr_mfs are required when outfile is set")

    # PSF operator
    psfo = PSF(psf, residual.shape, nthreads=nthreads, backward_undersize=1.2)

    # set up point sources
    phi = Dirac(nband, nx, ny, mask=np.any(model, axis=0))
    dual = np.zeros((nband, nx, ny), dtype=np.float64)

    # clean beam
    if cpsf is not None:
        if cpsf.shape != (1, ) + psf.shape[1::]:
            cpsf = cpsf[None, :, :]
        # NOTE(review): cpsfo is not referenced again in this function
        cpsfo = PSF(cpsf, residual.shape, nthreads=nthreads)

    residual_mfs = np.sum(residual, axis=0)
    rmax = np.abs(residual_mfs).max()

    #  preconditioning operator
    # varmap is rebound every iteration; hessb and hessf pick up the
    # current value because they look it up when called
    varmap = np.ones(model.shape) * (sigma_frac * rmax)

    def hessb(x):
        attenuated = apply_mask(apply_beam(phi.dot(x)))
        return phi.hdot(apply_mask(apply_beam(
            psfo.convolveb(attenuated)))) + x / varmap

    def hessf(x):
        attenuated = apply_mask(apply_beam(phi.dot(x)))
        return phi.hdot(apply_mask(apply_beam(
            psfo.convolve(attenuated)))) + x / varmap

    beta, betavec = power_method(hessb,
                                 residual.shape,
                                 tol=pmtol,
                                 maxit=pmmaxit,
                                 verbosity=pmverbose)

    if hessian is None:
        # psf is a plain array, the convolution lives on the PSF operator
        hessian = psfo.convolve
        wsum = 1.0

    if model.any():
        dirty = residual + hessian(apply_mask(apply_beam(model))) / wsum
    else:
        dirty = residual

    if adapt_sig21:
        from scipy.stats import skew, kurtosis

    # deconvolve
    threshold = np.maximum(peak_factor * rmax, threshold)
    alpha = sig_21
    for i in range(0, maxit):
        # find point source candidates
        modelu = hogbom(apply_mask(residual),
                        psf,
                        gamma=hbgamma,
                        pf=hbpf,
                        maxit=hbmaxit,
                        verbosity=hbverbose)

        phi.update_locs(modelu)

        # solve for beta updates
        x = pcg(hessf,
                phi.hdot(apply_mask(apply_beam(residual))),
                phi.hdot(apply_beam_inv(modelu)),
                M=lambda x: x * (sigma_frac * rmax),
                tol=cgtol,
                maxit=cgmaxit,
                minit=cgminit,
                verbosity=cgverbose)

        modelp = model.copy()
        model += gamma * x

        weights_21 = np.where(phi.mask, alpha /
                              (alpha + np.abs(np.mean(modelp, axis=0))),
                              1e10)  # 1e10 for effective infinity
        beta, betavec = power_method(hessb,
                                     model.shape,
                                     b0=betavec,
                                     tol=pmtol,
                                     maxit=pmmaxit,
                                     verbosity=pmverbose)

        model, dual = primal_dual(hessb,
                                  model,
                                  modelp,
                                  dual,
                                  sig_21,
                                  phi,
                                  weights_21,
                                  beta,
                                  prox_21m,
                                  tol=pdtol,
                                  maxit=pdmaxit,
                                  axis=0,
                                  positivity=positivity,
                                  report_freq=50,
                                  verbosity=pdverbose)

        # update Dirac dictionary (remove zero components)
        phi.trim_fat(model)
        residual, residual_mfs = resid_func(model, dirty, hessian, apply_mask,
                                            apply_beam, wsum)

        model_mfs = np.mean(model, axis=0)

        # check stopping criteria
        rmax = np.abs(apply_mask(residual_mfs)).max()
        rms = np.std(apply_mask(residual_mfs))
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        # update variance map (positivity constraint optional)
        varmap = np.maximum(rmax * sigma_frac, sigma_frac * (rmax + model))

        print("Iter %i: peak residual = %f, rms = %f, eps = %f" %
              (i + 1, rmax, rms, eps),
              file=log)

        # save current iteration
        if outfile is not None:
            prefix = outfile + str(i + 1)

            save_fits(prefix + '_model_mfs.fits', model_mfs, hdr_mfs)

            save_fits(prefix + '_model.fits', model, hdr)

            save_fits(prefix + '_update.fits', x, hdr)

            save_fits(prefix + '_residual_mfs.fits', residual_mfs, hdr_mfs)

            save_fits(prefix + '_residual.fits', residual * wsum, hdr)

        if rmax < threshold or eps < tol:
            print("Success, convergence after %i iterations" % (i + 1),
                  file=log)
            break

        if adapt_sig21:
            # sig_21 should be set to the std of the image noise
            alpha = rms
            tmp = residual_mfs
            z = tmp / alpha
            k = 0
            # clip 3 sigma outliers until the normalised residual looks
            # Gaussian (small skew and excess kurtosis) or 10 passes are up
            while (np.abs(skew(z.ravel(), nan_policy='omit')) > 0.05 or
                   np.abs(kurtosis(z.ravel(), fisher=True,
                                   nan_policy='omit')) > 0.5) and k < 10:
                # eliminate outliers
                tmp = np.where(np.abs(z) < 3, residual_mfs, np.nan)
                alpha = np.nanstd(tmp)
                z = tmp / alpha
                print(alpha, skew(z.ravel(), nan_policy='omit'),
                      kurtosis(z.ravel(), fisher=True, nan_policy='omit'))
                k += 1

            sig_21 = alpha
            print("alpha set to %f" % (alpha), file=log)

    return model
Exemple #6
0
def _forward(**kw):
    """
    Solve for an update image from a residual cube and write it to FITS.

    The residual and PSF are loaded from args.residual and args.psf and
    normalised by wsum (the sum over bands of the per-band PSF maxima). An
    optional beam (args.beam_model: a '.fits' file that matches the image
    coordinates, or 'jimbeam' together with args.band) and an optional mask
    (args.mask) are multiplied together into beam_image. The update is then
    computed with pcg_wgt when args.weight_table is given and with pcg_psf
    otherwise, and saved as <output_filename>_update.fits and
    <output_filename>_update_mfs.fits. The band-summed residual is saved as
    <output_filename>_residual_mfs.fits.

    kw holds the options; it is turned into a struct-mode OmegaConf object
    so that reading an option that was not passed in raises.

    Raises ValueError if the PSF is not normalised after division by wsum,
    if the beam model is not recognised or does not match the image, or if
    the mask has the wrong shape. Raises NotImplementedError if the two cell
    sizes differ.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import dask.array as da
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from astropy.io import fits

    print("Loading residual", file=log)
    residual = load_fits(args.residual, dtype=args.output_type).squeeze()
    nband, nx, ny = residual.shape
    hdr = fits.getheader(args.residual)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape

    wsums = np.amax(psf.reshape(-1, nx_psf*ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # written as "not <" so that a nan peak (e.g. wsum == 0) also fails
    if not np.abs(psf_mfs.max() - 1.0) < 1e-4:
        raise ValueError("PSF does not peak at one after normalisation")

    residual /= wsum
    residual_mfs = np.sum(residual, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m
    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs, hdr_mfs,
              dtype=args.output_type)

    rms = np.std(residual_mfs)
    rmax = np.abs(residual_mfs).max()

    print("Initial peak residual = %f, rms = %f" % (rmax, rms), file=log)

    # load beam
    if args.beam_model is not None:
        if args.beam_model.endswith('.fits'):  # beam already interpolated
            bhdr = fits.getheader(args.beam_model)
            l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
            l_coord_beam -= ref_lb
            if not np.array_equal(l_coord_beam, l_coord):
                raise ValueError("l coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header.")

            m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
            m_coord_beam -= ref_mb
            if not np.array_equal(m_coord_beam, m_coord):
                raise ValueError("m coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header.")

            # locate the frequency axis of the beam cube from its own header
            if str(bhdr.get('CTYPE4', '')).lower() == 'freq':
                freq_axis = 4
            elif str(bhdr.get('CTYPE3', '')).lower() == 'freq':
                freq_axis = 3
            else:
                raise ValueError("Freq axis must be 3rd or 4th")

            freq_beam, _ = data_from_header(bhdr, axis=freq_axis)
            if not np.array_equal(freq_out, freq_beam):
                raise ValueError("Freqs of beam model do not match those of image. Use power_beam_maker to interpolate to fits header.")

            beam_image = load_fits(args.beam_model, dtype=args.output_type).squeeze()
        elif args.beam_model.lower() == "jimbeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            elif args.band.lower() == 'uhf':
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            else:
                raise ValueError("Unknown band %s" % args.band)

            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
            beam_image = np.zeros(residual.shape, dtype=args.output_type)
            for v in range(freq_out.size):
                # freq must be in MHz
                beam_image[v] = beam.I(xx, yy, freq_out[v]/1e6).astype(args.output_type)
        else:
            # without this beam_image would be undefined further down
            raise ValueError("Unknown beam model %s" % args.beam_model)
    else:
        beam_image = np.ones((nband, nx, ny), dtype=args.output_type)

    if args.mask is not None:
        mask = load_fits(args.mask).squeeze()
        if mask.shape != (nx, ny):
            raise ValueError("Mask has incorrect shape")
        # the mask is folded into the beam so both are applied together
        beam_image *= mask[None, :, :]

    beam_image = da.from_array(beam_image, chunks=(1, -1, -1))

    # if weight table is provided we use the vis space Hessian approximation
    if args.weight_table is not None:
        print("Solving for update using vis space approximation", file=log)
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freq = xds.chan.data

        # bin edges
        fmin = freq.min()
        fmax = freq.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_map = np.zeros(freq.size, dtype=np.int32)
        for band in range(nband):
            indl = freq >= fbins[band]
            # small tolerance so that the last channel lands in the last bin
            indu = freq < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        _, bin_counts = np.unique(band_map, return_counts=True)
        freq = da.from_array(freq, chunks=tuple(bin_counts))
        # index of the first channel of every band
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit/args.nworkers, band_size, nrow,
                                       memory_per_row, args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row, args.nthreads_per_worker)

        print("nrows = %i, row chunks set to %i for a total of %i chunks per node" %
              (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1), dtype=residual.dtype)

        # NOTE: rows are currently read as a single chunk, row_chunk is only
        # reported above
        xds = xds_from_zarr(args.weight_table, chunks={'row': -1, #row_chunk,
                            'chan': bin_counts})[0]

        from pfb.opt.pcg import pcg_wgt

        model = pcg_wgt(xds.UVW.data,
                        xds.WEIGHT.data.astype(args.output_type),
                        residual,
                        x0,
                        beam_image,
                        freq,
                        freq_bin_idx,
                        freq_bin_counts,
                        cell_rad,
                        args.wstack,
                        args.epsilon,
                        args.double_accum,
                        args.nvthreads,
                        args.sigmainv,
                        wsum,
                        args.cg_tol,
                        args.cg_maxit,
                        args.cg_minit,
                        args.cg_verbose,
                        args.cg_report_freq,
                        args.backtrack).compute()

    else:  # we use the image space approximation
        print("Solving for update using image space approximation", file=log)
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx)//2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny)//2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        # explicit stop index: a negative stop of -0 would select nothing
        # when there is no padding on the right
        unpad_x = slice(npad_xl, npad_xl + nx)
        unpad_y = slice(npad_yl, npad_yl + ny)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        psfhat = r2c(psf_pad, axes=(1, 2), forward=True,
                     nthreads=args.nvthreads, inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))
        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1))

        from pfb.opt.pcg import pcg_psf

        model = pcg_psf(psfhat,
                        residual,
                        x0,
                        beam_image,
                        args.sigmainv,
                        args.nvthreads,
                        padding,
                        unpad_x,
                        unpad_y,
                        lastsize,
                        args.cg_tol,
                        args.cg_maxit,
                        args.cg_minit,
                        args.cg_verbose,
                        args.cg_report_freq,
                        args.backtrack).compute()

    print("Saving results", file=log)
    save_fits(args.output_filename + '_update.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_update_mfs.fits', model_mfs, hdr_mfs)

    print("All done here.", file=log)
Exemple #7
0
def _spifit(**kw):
    """
    Fit a per-pixel spectral index model to one or more image cubes.

    For every input image the model cube is loaded, optionally convolved to
    a common Gaussian resolution, and optionally has convolved residuals
    added back in. The cubes are then concatenated and sorted along
    frequency, and ``fit_spi_components`` is run on every pixel whose value
    exceeds the threshold in all bands. The requested products are written
    to FITS files prefixed with ``args.output_filename``.

    Parameters
    ----------
    **kw
        Options, wrapped in a struct-mode ``OmegaConf`` config. Keys read
        here: ``image``, ``residual``, ``beam_model``, ``psf_pars``,
        ``circ_psf``, ``dont_convolve``, ``add_convolved_residuals``,
        ``padding_frac``, ``nthreads``, ``out_dtype``, ``ref_freq``,
        ``products``, ``threshold``, ``maxdr``, ``pb_min``,
        ``band_weights`` and ``output_filename``.

        ``products`` selects the outputs by letter: ``m`` convolved model,
        ``b`` power beam, ``r`` convolved residual, ``I`` reconstructed
        cube, ``a`` alpha map, ``e`` alpha error map, ``i`` I0 map and
        ``k`` I0 error map.

    Raises
    ------
    ValueError
        If no beam parameters can be found in the headers, if the frequency
        axis is not the 3rd or 4th FITS axis, if beam or residual
        coordinates do not match those of the image, if the number of
        supplied band weights differs from the number of bands, or if no
        pixel lies above the threshold.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import dask.array as da
    import numpy as np
    from astropy.io import fits
    from africanus.model.spi.dask import fit_spi_components
    from pfb.utils.fits import load_fits, save_fits, data_from_header, set_wcs
    from pfb.utils.misc import convolve2gaussres

    # get max gausspars
    gaussparf = None
    if args.psf_pars is None:
        # beam parameters are read from the residual headers when residuals
        # are supplied and from the image headers otherwise
        ppsource = args.image if args.residual is None else args.residual

        for image in ppsource:
            pphdr = fits.getheader(image)

            # per-band keywords may be numbered from 0 or from 1; the
            # un-numbered keywords are the last resort
            for suffix, idx0 in (('0', 0), ('1', 1), ('', 0)):
                if 'BMAJ' + suffix in pphdr.keys():
                    gausspars = [pphdr['BMAJ' + suffix],
                                 pphdr['BMIN' + suffix],
                                 pphdr['BPA' + suffix]]
                    freq_idx0 = idx0
                    break
            else:
                raise ValueError("No beam parameters found in residual."
                                "You will have to provide them manually.")

            if gaussparf is None:
                gaussparf = gausspars
            else:
                # we need to take the max in both directions
                # (the position angle of the first header is kept)
                gaussparf[0] = np.maximum(gaussparf[0], gausspars[0])
                gaussparf[1] = np.maximum(gaussparf[1], gausspars[1])
    else:
        freq_idx0 = 0  # assumption
        gaussparf = list(args.psf_pars)

    if args.circ_psf:
        # circularise using the larger of the two axes
        e = np.maximum(gaussparf[0], gaussparf[1])
        gaussparf[0] = e
        gaussparf[1] = e
        gaussparf[2] = 0.0

    gaussparf = tuple(gaussparf)
    print("Using emaj = %3.2e, emin = %3.2e, PA = %3.2e \n" % gaussparf, file=log)

    # get required data products
    image_dict = {}
    for i, image_name in enumerate(args.image):
        image_dict[i] = {}

        # load model image
        model = load_fits(image_name, dtype=args.out_dtype).squeeze()
        mhdr = fits.getheader(image_name)

        if model.ndim < 3:
            model = model[None, :, :]

        l_coord, ref_l = data_from_header(mhdr, axis=1)
        l_coord -= ref_l
        m_coord, ref_m = data_from_header(mhdr, axis=2)
        m_coord -= ref_m
        if mhdr["CTYPE4"].lower() == 'freq':
            freq_axis = 4
        elif mhdr["CTYPE3"].lower() == 'freq':
            freq_axis = 3
        else:
            raise ValueError("Freq axis must be 3rd or 4th")

        freqs, _ = data_from_header(mhdr, axis=freq_axis)

        image_dict[i]['freqs'] = freqs

        nband = freqs.size
        npix_l = l_coord.size
        npix_m = m_coord.size

        xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

        # load beam
        if args.beam_model is not None:
            bhdr = fits.getheader(args.beam_model[i])
            l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
            l_coord_beam -= ref_lb
            if not np.array_equal(l_coord_beam, l_coord):
                raise ValueError("l coordinates of beam model do not match "
                                 "those of image. Use binterp to make "
                                 "compatible beam images")

            m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
            m_coord_beam -= ref_mb
            if not np.array_equal(m_coord_beam, m_coord):
                raise ValueError("m coordinates of beam model do not match "
                                 "those of image. Use binterp to make "
                                 "compatible beam images")
            freqs_beam, _ = data_from_header(bhdr, axis=freq_axis)
            if not np.array_equal(freqs, freqs_beam):
                raise ValueError("Freq coordinates of beam model do not match "
                                 "those of image. Use binterp to make "
                                 "compatible beam images")
            beam_image = load_fits(args.beam_model[i],
                                   dtype=args.out_dtype).squeeze()

            if beam_image.ndim < 3:
                beam_image = beam_image[None, :, :]

        else:
            # no beam model: use a unit beam
            beam_image = np.ones(model.shape, dtype=args.out_dtype)

        image_dict[i]['beam'] = beam_image

        if not args.dont_convolve:
            print("Convolving model %i"%i, file=log)
            # convolve model to desired resolution
            model, _ = convolve2gaussres(model, xx, yy, gaussparf,
                                         args.nthreads, None,
                                         args.padding_frac)

        image_dict[i]['model'] = model

        # add in residuals and set threshold
        if args.residual is not None:
            msg = "of residual do not match those of model"
            rhdr = fits.getheader(args.residual[i])
            l_res, ref_lb = data_from_header(rhdr, axis=1)
            l_res -= ref_lb
            if not np.array_equal(l_res, l_coord):
                raise ValueError("l coordinates " + msg)

            m_res, ref_mb = data_from_header(rhdr, axis=2)
            m_res -= ref_mb
            if not np.array_equal(m_res, m_coord):
                raise ValueError("m coordinates " + msg)

            freqs_res, _ = data_from_header(rhdr, axis=freq_axis)
            if not np.array_equal(freqs, freqs_res):
                raise ValueError("Freqs " + msg)

            resid = load_fits(args.residual[i],
                              dtype=args.out_dtype).squeeze()
            if resid.ndim < 3:
                resid = resid[None, :, :]

            # convolve residual to same resolution as model
            gausspari = ()
            for b in range(nband):
                idx = str(b + freq_idx0)
                if 'BMAJ' + idx in rhdr.keys():
                    # each of the three parameters has its own keyword
                    emaj = rhdr['BMAJ' + idx]
                    emin = rhdr['BMIN' + idx]
                    pa = rhdr['BPA' + idx]
                    gausspari += ((emaj, emin, pa),)
                elif 'BMAJ' in rhdr.keys():
                    emaj = rhdr['BMAJ']
                    emin = rhdr['BMIN']
                    pa = rhdr['BPA']
                    gausspari += ((emaj, emin, pa),)
                else:
                    print("Can't find Gausspars in residual header, "
                          "unable to add residuals back in", file=log)
                    gausspari = None
                    break

            if gausspari is not None and args.add_convolved_residuals:
                print("Convolving residuals %i"%i, file=log)
                resid, _ = convolve2gaussres(resid, xx, yy, gaussparf,
                                             args.nthreads, gausspari,
                                             args.padding_frac,
                                             norm_kernel=False)
                # in-place add, so the array already stored in image_dict
                # picks up the residuals as well
                model += resid
                print("Convolved residuals added to convolved model %i"%i,
                      file=log)


            image_dict[i]['resid'] = resid

        else:
            image_dict[i]['resid'] = None

    # concatenate images along frequency here
    freqs = []
    model = []
    beam_image = []
    resid = []
    for i in image_dict.keys():
        freqs.append(image_dict[i]['freqs'])
        model.append(image_dict[i]['model'])
        beam_image.append(image_dict[i]['beam'])
        resid.append(image_dict[i]['resid'])
    freqs = np.concatenate(freqs, axis=0)
    Isort = np.argsort(freqs)
    freqs = freqs[Isort]

    model = np.concatenate(model, axis=0)
    model = model[Isort]

    # create header (coordinate info is taken from the last image loaded)
    cell_deg = mhdr['CDELT1']
    ra = np.deg2rad(mhdr['CRVAL1'])
    dec = np.deg2rad(mhdr['CRVAL2'])
    radec = [ra, dec]
    nband, nx, ny = model.shape
    hdr = set_wcs(cell_deg, cell_deg, nx, ny, radec, freqs)
    for i in range(1, nband+1):
        hdr['BMAJ' + str(i)] = gaussparf[0]
        hdr['BMIN' + str(i)] = gaussparf[1]
        hdr['BPA' + str(i)] = gaussparf[2]
    if args.ref_freq is None:
        ref_freq = np.mean(freqs)
    else:
        ref_freq = args.ref_freq
    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)
    hdr_mfs['BMAJ'] = gaussparf[0]
    hdr_mfs['BMIN'] = gaussparf[1]
    hdr_mfs['BPA'] = gaussparf[2]

    # save convolved model
    if 'm' in args.products:
        name = args.output_filename + '.convolved_model.fits'
        save_fits(name, model, hdr, dtype=args.out_dtype)
        print("Wrote convolved model to %s" % name, file=log)

    beam_image = np.concatenate(beam_image, axis=0)
    beam_image = beam_image[Isort]

    if 'b' in args.products:
        name = args.output_filename + '.power_beam.fits'
        save_fits(name, beam_image, hdr, dtype=args.out_dtype)
        print("Wrote average power beam to %s" % name, file=log)

    if resid[0] is not None:
        resid = np.concatenate(resid, axis=0)
        resid = resid[Isort]

        if 'r' in args.products:
            name = args.output_filename + '.convolved_residual.fits'
            save_fits(name, resid, hdr, dtype=args.out_dtype)
            print("Wrote convolved residuals to %s" % name, file=log)

        # get threshold (rms over the non-zero residual pixels)
        counts = np.sum(resid != 0)
        rms = np.sqrt(np.sum(resid**2)/counts)
        rms_cube = np.std(resid.reshape(nband, npix_l*npix_m), axis=1).ravel()
        threshold = args.threshold * rms
    else:
        print("No residual provided. Setting  threshold i.t.o dynamic range. "
              "Max dynamic range is %i " % args.maxdr, file=log)
        threshold = model.max()/args.maxdr
        rms_cube = None

    print("Threshold set to %f Jy. \n" % threshold, file=log)

    # beam cut off
    beam_min = np.amin(beam_image, axis=0)
    model = np.where(beam_min[None] > args.pb_min, model, 0.0)

    # get pixels above threshold (in every band)
    minimage = np.amin(model, axis=0)
    maskindices = np.argwhere(minimage > threshold)
    if not maskindices.size:
        raise ValueError("No components found above threshold. "
                        "Try lowering your threshold."
                        "Max of convolved model is %3.2e" % model.max())
    fitcube = model[:, maskindices[:, 0], maskindices[:, 1]].T
    beam_comps = beam_image[:, maskindices[:, 0], maskindices[:, 1]].T

    # set weights for fit
    if rms_cube is not None:
        print("Using RMS in each imaging band to determine weights.", file=log)
        weights = np.where(rms_cube > 0, 1.0/rms_cube**2, 0.0)
        # normalise
        weights /= weights.max()
    elif args.band_weights is not None:
        weights = np.array(args.band_weights)
        # explicit check rather than assert so it survives python -O
        if weights.size != nband:
            raise ValueError("Inconsistent weighst provided.")
        print("Using provided channel weights.", file=log)
    else:
        print("No residual or channel weights provided. Using equal weights.", file=log)
        weights = np.ones(nband, dtype=np.float64)

    ncomps, _ = fitcube.shape
    # guard against a zero chunk size when ncomps < nthreads
    comp_chunk = max(1, ncomps // args.nthreads)
    fitcube = da.from_array(fitcube.astype(np.float64),
                            chunks=(comp_chunk, nband))
    beam_comps = da.from_array(beam_comps.astype(np.float64),
                               chunks=(comp_chunk, nband))
    weights = da.from_array(weights.astype(np.float64), chunks=(nband))
    freqsdask = da.from_array(freqs.astype(np.float64), chunks=(nband))

    print("Fitting %i components" % ncomps, file=log)
    alpha, alpha_err, Iref, i0_err = fit_spi_components(fitcube, weights, freqsdask,
                                        np.float64(ref_freq), beam=beam_comps).compute()
    print("Done. Writing output.", file=log)

    # maps are NaN everywhere except at the fitted pixels
    map_shape = model[0].shape
    alphamap = np.full(map_shape, np.nan, dtype=model.dtype)
    alpha_err_map = np.full(map_shape, np.nan, dtype=model.dtype)
    i0map = np.full(map_shape, np.nan, dtype=model.dtype)
    i0_err_map = np.full(map_shape, np.nan, dtype=model.dtype)
    alphamap[maskindices[:, 0], maskindices[:, 1]] = alpha
    alpha_err_map[maskindices[:, 0], maskindices[:, 1]] = alpha_err
    i0map[maskindices[:, 0], maskindices[:, 1]] = Iref
    i0_err_map[maskindices[:, 0], maskindices[:, 1]] = i0_err

    if 'I' in args.products:
        # get the reconstructed cube
        Irec_cube = i0map[None, :, :] * \
            (freqs[:, None, None]/ref_freq)**alphamap[None, :, :]
        name = args.output_filename + '.Irec_cube.fits'
        save_fits(name, Irec_cube, hdr, dtype=args.out_dtype)
        print("Wrote reconstructed cube to %s" % name, file=log)

    # All 2-D maps share the single-plane header built above, so they carry
    # the same WCS and beam keywords as the alpha map.

    # save alpha map
    if 'a' in args.products:
        name = args.output_filename + '.alpha.fits'
        save_fits(name, alphamap, hdr_mfs, dtype=args.out_dtype)
        print("Wrote alpha map to %s" % name, file=log)

    # save alpha error map
    if 'e' in args.products:
        name = args.output_filename + '.alpha_err.fits'
        save_fits(name, alpha_err_map, hdr_mfs, dtype=args.out_dtype)
        print("Wrote alpha error map to %s" % name, file=log)

    # save I0 map
    if 'i' in args.products:
        name = args.output_filename + '.I0.fits'
        save_fits(name, i0map, hdr_mfs, dtype=args.out_dtype)
        print("Wrote I0 map to %s" % name, file=log)

    # save I0 error map
    if 'k' in args.products:
        name = args.output_filename + '.I0_err.fits'
        save_fits(name, i0_err_map, hdr_mfs, dtype=args.out_dtype)
        print("Wrote I0 error map to %s" % name, file=log)

    print("All done here", file=log)
Exemple #8
0
def nnls(**kw):
    '''
    Minor cycle implementing non-negative least squares.

    Loads the dirty and PSF cubes named in the options, scales both by the
    sum of the per-band PSF peaks and runs FISTA with a prox that zeroes
    every entry below ``min_value``. The resulting model and residual are
    written to ``<output_filename>_model.fits`` and
    ``<output_filename>_residual.fits``.

    Options read from ``kw``: ``output_filename``, ``dirty``, ``psf``,
    ``x0``, ``min_value``, ``nthreads``, ``pm_tol``, ``pm_maxit``,
    ``pm_verbose``, ``pm_report_freq``, ``fista_tol``, ``fista_maxit``,
    ``fista_verbose`` and ``fista_report_freq``.
    '''
    args = OmegaConf.create(kw)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    print('Input Options:', file=log)
    for option, value in kw.items():
        print('     %25s = %s' % (option, value), file=log)

    from astropy.io import fits
    import numpy as np
    from pfb.utils.fits import load_fits, save_fits
    from pfb.operators.psf import PSF
    from pfb.opt.power_method import power_method
    from pfb.opt.fista import fista

    def prox(x):
        # clip in place: everything below the floor becomes zero
        x[x < args.min_value] = 0.0
        return x

    # inputs; the tuple unpacking doubles as a check that both are 3-D
    dirty = load_fits(args.dirty).squeeze()
    nband, nx, ny = dirty.shape
    hdr = fits.getheader(args.dirty)

    psf = load_fits(args.psf).squeeze()
    _, nx_psf, ny_psf = psf.shape
    psf_hdr = fits.getheader(args.psf)  # read but not used further

    # normalise by the sum over bands of each band's PSF peak
    band_peaks = psf.max(axis=(1, 2))
    wsum = band_peaks.sum()

    psf /= wsum
    psf_mfs = psf.sum(axis=0)

    # NOTE(review): one-sided check, it only fails when the peak exceeds
    # 1 + 1e-4. Confirm whether an absolute difference was intended.
    assert (psf_mfs.max() - 1.0) < 1e-4

    dirty /= wsum

    psfo = PSF(psf, dirty.shape, nthreads=args.nthreads)

    # largest eigenvalue estimate of the convolution operator
    beta, _ = power_method(psfo.convolve,
                           dirty.shape,
                           tol=args.pm_tol,
                           maxit=args.pm_maxit,
                           verbosity=args.pm_verbose,
                           report_freq=args.pm_report_freq)

    def fprime(x):
        # objective value and gradient at x
        model_conv = psfo.convolve(x)
        value = np.vdot(x, model_conv - 2 * dirty)
        grad = 2 * (model_conv - dirty)
        return value, grad

    if args.x0 is not None:
        x0 = load_fits(args.x0, dtype=dirty.dtype).squeeze()
    else:
        x0 = np.zeros_like(dirty)

    model = fista(x0,
                  beta,
                  fprime,
                  prox,
                  tol=args.fista_tol,
                  maxit=args.fista_maxit,
                  verbosity=args.fista_verbose,
                  report_freq=args.fista_report_freq)

    # unattenuated residual
    residual = dirty - psfo.convolve(model)

    save_fits(args.output_filename + '_model.fits', model, hdr)
    save_fits(args.output_filename + '_residual.fits', residual, hdr)
Exemple #9
0
def nnls(psf,
         model,
         residual,
         mask=None,
         beam_image=None,
         hessian=None,
         wsum=None,
         gamma=0.95,
         hdr=None,
         hdr_mfs=None,
         outfile=None,
         nthreads=1,
         maxit=1,
         tol=1e-3,
         pmtol=1e-5,
         pmmaxit=50,
         pmverbose=1,
         ftol=1e-5,
         fmaxit=250,
         fverbose=3):
    """
    Non-negative least squares minor cycle.

    Each iteration solves for an update with FISTA, adds ``gamma`` times
    that update to ``model``, and recomputes the residual. Iteration stops
    after ``maxit`` passes or when the relative model change drops below
    ``tol``.

    Parameters
    ----------
    psf : array
        PSF cube handed to the ``PSF`` operator.
    model : array
        Current model, shape ``(nband, nx, ny)``. Updated IN PLACE.
    residual : array
        Residual cube of shape ``(nband, nx, ny)``.
    mask : array, optional
        Multiplicative mask of shape ``(nx, ny)`` or ``(1, nx, ny)``.
    beam_image : array, optional
        Multiplicative beam of shape ``(nband, nx, ny)``.
    hessian : callable, optional
        Operator used to form the dirty image and the residual. When
        omitted, ``psfo.convolve`` is used and ``wsum`` is reset to 1.
    wsum : float, optional
        Normalisation applied to ``hessian`` output.
    gamma : float
        Step fraction applied to each FISTA update.
    hdr, hdr_mfs : FITS headers, optional
        Required when ``outfile`` is given.
    outfile : str, optional
        Prefix for per-iteration FITS outputs.
    nthreads : int
        Passed to the ``PSF`` operator.
    maxit, tol : int, float
        Outer loop limits.
    pmtol, pmmaxit, pmverbose
        Power method settings.
    ftol, fmaxit, fverbose
        FISTA settings.

    Returns
    -------
    array
        The updated model (the same object as ``model``).

    Raises
    ------
    ValueError
        If ``residual`` has more than three dimensions, or if the beam or
        mask has an incompatible shape.
    """
    if len(residual.shape) > 3:
        raise ValueError("Residual must have shape (nband, nx, ny)")

    nband, nx, ny = residual.shape

    if beam_image is None:

        def beam(x):
            return x
    else:
        # validate the array itself; ``beam`` is the closure defined below
        if np.shape(beam_image) != (nband, nx, ny):
            raise ValueError("Beam has incorrect shape")

        def beam(x):
            return beam_image * x

    # The mask operator needs its own name: rebinding ``mask`` to a closure
    # would make the closure multiply by itself instead of by the array.
    if mask is None:

        def apply_mask(x):
            return x
    else:
        mask_image = np.asarray(mask)
        if mask_image.shape == (nx, ny):
            # add a leading band axis so the mask broadcasts over bands
            mask_image = mask_image[None]
        elif mask_image.shape != (1, nx, ny):
            raise ValueError("Mask has incorrect shape")

        def apply_mask(x):
            return mask_image * x

    # PSF operator
    psfo = PSF(psf, residual.shape, nthreads=nthreads)

    residual_mfs = np.sum(residual, axis=0)
    residual = apply_mask(beam(residual))
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)

    if hessian is None:
        hessian = psfo.convolve
        wsum = 1

    def hess(x):
        return apply_mask(beam(psfo.convolve(apply_mask(beam(x)))))

    beta, betavec = power_method(hess,
                                 residual.shape,
                                 tol=pmtol,
                                 maxit=pmmaxit,
                                 verbosity=pmverbose)

    # reconstruct the dirty image when starting from a non-zero model
    if model.any():
        dirty = residual + hessian(apply_mask(beam(model))) / wsum
    else:
        dirty = residual

    for i in range(maxit):
        # solve for an update against the current residual
        fprime = partial(value_and_grad,
                         dirty=residual,
                         psfo=psfo,
                         mask=apply_mask,
                         beam=beam)

        x = fista(np.zeros_like(model),
                  beta,
                  fprime,
                  prox,
                  tol=ftol,
                  maxit=fmaxit,
                  verbosity=fverbose)

        modelp = model.copy()
        model += gamma * x

        residual, residual_mfs = resid_func(model, dirty, hessian,
                                            apply_mask, beam, wsum)
        model_mfs = np.mean(model, axis=0)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("Iter %i: peak residual = %f, rms = %f, eps = %f" %
              (i + 1, rmax, rms, eps),
              file=log)

        # save current iteration
        if outfile is not None:
            assert hdr is not None
            assert hdr_mfs is not None

            save_fits(outfile + str(i + 1) + '_NNLS_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(outfile + str(i + 1) + '_NNLS_model.fits', model, hdr)

            save_fits(outfile + str(i + 1) + '_NNLS_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        if eps < tol:
            print("Success, convergence after %i iterations" % (i + 1),
                  file=log)
            break

    return model
Exemple #10
0
def _binterp(**kw):
    """
    Evaluate a primary beam model on the grid of each input image.

    Only the case ``ms is None`` with ``beam_model == 'jimbeam'`` is
    implemented: for every image in ``args.image`` the katbeam ``JimBeam``
    model for the requested band is evaluated at the image's (l, m) grid
    and frequencies, and the result is saved as a FITS cube using the
    image's own header.

    Parameters
    ----------
    **kw
        Options, wrapped in a struct-mode ``OmegaConf`` config. Keys read
        here: ``ms``, ``beam_model``, ``band``, ``image``, ``out_dtype``,
        ``output_dir`` and ``postfix``.

    Raises
    ------
    ValueError
        If the band is neither ``'l'`` nor ``'uhf'``, or if the frequency
        axis of an image is not the 3rd or 4th FITS axis.
    NotImplementedError
        For any beam model other than ``'jimbeam'``.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    from astropy.io import fits
    from pfb.utils.fits import save_fits, data_from_header

    # NOTE(review): nothing is done when an MS is supplied; the function
    # only logs completion in that case.
    if args.ms is None:
        if args.beam_model.lower() == 'jimbeam':
            from katbeam import JimBeam

            # the beam model depends only on the band, so build it once
            band = args.band.lower()
            if band == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            elif band == 'uhf':
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            else:
                raise ValueError("Unkown band %s" % args.band)

            for image in args.image:
                mhdr = fits.getheader(image)
                l_coord, ref_l = data_from_header(mhdr, axis=1)
                l_coord -= ref_l
                m_coord, ref_m = data_from_header(mhdr, axis=2)
                m_coord -= ref_m
                if mhdr["CTYPE4"].lower() == 'freq':
                    freq_axis = 4
                    stokes_axis = 3
                elif mhdr["CTYPE3"].lower() == 'freq':
                    freq_axis = 3
                    stokes_axis = 4
                else:
                    raise ValueError("Freq axis must be 3rd or 4th")

                freq, _ = data_from_header(mhdr, axis=freq_axis)

                xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
                beam_image = np.zeros((freq.size, l_coord.size, m_coord.size),
                                      dtype=args.out_dtype)
                for v in range(freq.size):
                    # freq must be in MHz
                    beam_image[v] = beam.I(xx, yy, freq[v] / 1e6).astype(
                        args.out_dtype)

                # drop the leading len(output_dir) characters from the image
                # name when output_dir occurs in it, so the directory is not
                # repeated in the output path
                # NOTE(review): this is a substring test, not a prefix test
                if args.output_dir in image:
                    idx = len(args.output_dir)
                    iname = image[idx::]
                    outname = iname + '.' + args.postfix
                else:
                    outname = image + '.' + args.postfix

                # insert the singleton Stokes axis: array axis 1 when Stokes
                # is FITS axis 3, array axis 0 when it is FITS axis 4
                beam_image = np.expand_dims(beam_image,
                                            axis=3 - stokes_axis + 1)
                save_fits(args.output_dir + outname,
                          beam_image,
                          mhdr,
                          dtype=args.out_dtype)

        else:
            raise NotImplementedError("Not there yet, sorry")

    print("All done here.", file=log)


# @jit(nopython=True, nogil=True, cache=True)
# def _unflagged_counts(flags, time_idx, out):
#     for i in range(time_idx.size):
#         ilow = time_idx[i]
#         ihigh = time_idx[i+1]
#         out[i] = np.sum(~flags[ilow:ihigh])
#     return out

# def extract_dde_info(args, freqs):
#     """
#     Computes parallactic angles, antenna scaling and pointing information
#     required for beam interpolation.
#     """
#     # get ms info required to compute parallactic angles and weighted sum
#     nband = freqs.size
#     if args.ms is not None:
#         utimes = []
#         unflag_counts = []
#         ant_pos = None
#         phase_dir = None
#         for ms_name in args.ms:
#             # get antenna positions
#             ant = xds_from_table(ms_name + '::ANTENNA')[0].compute()
#             if ant_pos is None:
#                 ant_pos = ant['POSITION'].data
#             else:  # check all are the same
#                 tmp = ant['POSITION']
#                 if not np.array_equal(ant_pos, tmp):
#                     raise ValueError(
#                         "Antenna positions not the same across measurement sets")

#             # get phase center for field
#             field = xds_from_table(ms_name + '::FIELD')[0].compute()
#             if phase_dir is None:
#                 phase_dir = field['PHASE_DIR'][args.field].data.squeeze()
#             else:
#                 tmp = field['PHASE_DIR'][args.field].data.squeeze()
#                 if not np.array_equal(phase_dir, tmp):
#                     raise ValueError(
#                         'Phase direction not the same across measurement sets')

#             # get unique times and count flags
#             xds = xds_from_ms(ms_name, columns=["TIME", "FLAG_ROW"], group_cols=[
#                               "FIELD_ID"])[args.field]
#             utime, time_idx = np.unique(
#                 xds.TIME.data.compute(), return_index=True)
#             ntime = utime.size
#             # extract subset of times
#             if args.sparsify_time > 1:
#                 I = np.arange(0, ntime, args.sparsify_time)
#                 utime = utime[I]
#                 time_idx = time_idx[I]
#                 ntime = utime.size

#             utimes.append(utime)

#             flags = xds.FLAG_ROW.data.compute()
#             unflag_count = _unflagged_counts(flags.astype(
#                 np.int32), time_idx, np.zeros(ntime, dtype=np.int32))
#             unflag_counts.append(unflag_count)

#         utimes = np.concatenate(utimes)
#         unflag_counts = np.concatenate(unflag_counts)
#         ntimes = utimes.size

#         # compute parallactic angles
#         parangles = parallactic_angles(utimes, ant_pos, phase_dir)

#         # mean over antenna nant -> 1
#         parangles = np.mean(parangles, axis=1, keepdims=True)
#         nant = 1

#         # beam_cube_dde requirements
#         ant_scale = np.ones((nant, nband, 2), dtype=np.float64)
#         point_errs = np.zeros((ntimes, nant, nband, 2), dtype=np.float64)

#         return (parangles,
#                 da.from_array(ant_scale, chunks=ant_scale.shape),
#                 point_errs,
#                 unflag_counts,
#                 True)
#     else:
#         ntimes = 1
#         nant = 1
#         parangles = np.zeros((ntimes, nant,), dtype=np.float64)
#         ant_scale = np.ones((nant, nband, 2), dtype=np.float64)
#         point_errs = np.zeros((ntimes, nant, nband, 2), dtype=np.float64)
#         unflag_counts = np.array([1])

#         return (parangles, ant_scale, point_errs, unflag_counts, False)

# def make_power_beam(args, lm_source, freqs, use_dask):
#     print("Loading fits beam patterns from %s" % args.beam_model)
#     from glob import glob
#     paths = glob(args.beam_model + '**_**.fits')
#     beam_hdr = None
#     if args.corr_type == 'linear':
#         corr1 = 'XX'
#         corr2 = 'YY'
#     elif args.corr_type == 'circular':
#         corr1 = 'LL'
#         corr2 = 'RR'
#     else:
#         raise KeyError(
#             "Unknown corr_type supplied. Only 'linear' or 'circular' supported")

#     for path in paths:
#         if corr1.lower() in path[-10::]:
#             if 're' in path[-7::]:
#                 corr1_re = load_fits(path)
#                 if beam_hdr is None:
#                     beam_hdr = fits.getheader(path)
#             elif 'im' in path[-7::]:
#                 corr1_im = load_fits(path)
#             else:
#                 raise NotImplementedError("Only re/im patterns supported")
#         elif corr2.lower() in path[-10::]:
#             if 're' in path[-7::]:
#                 corr2_re = load_fits(path)
#             elif 'im' in path[-7::]:
#                 corr2_im = load_fits(path)
#             else:
#                 raise NotImplementedError("Only re/im patterns supported")

#     # get power beam
#     beam_amp = (corr1_re**2 + corr1_im**2 + corr2_re**2 + corr2_im**2)/2.0

#     # get cube in correct shape for interpolation code
#     beam_amp = np.ascontiguousarray(np.transpose(beam_amp, (1, 2, 0))
#                                     [:, :, :, None, None])
#     # get cube info
#     if beam_hdr['CUNIT1'].lower() != "deg":
#         raise ValueError("Beam image units must be in degrees")
#     npix_l = beam_hdr['NAXIS1']
#     refpix_l = beam_hdr['CRPIX1']
#     delta_l = beam_hdr['CDELT1']
#     l_min = (1 - refpix_l)*delta_l
#     l_max = (1 + npix_l - refpix_l)*delta_l

#     if beam_hdr['CUNIT2'].lower() != "deg":
#         raise ValueError("Beam image units must be in degrees")
#     npix_m = beam_hdr['NAXIS2']
#     refpix_m = beam_hdr['CRPIX2']
#     delta_m = beam_hdr['CDELT2']
#     m_min = (1 - refpix_m)*delta_m
#     m_max = (1 + npix_m - refpix_m)*delta_m

#     if (l_min > lm_source[:, 0].min() or m_min > lm_source[:, 1].min() or
#             l_max < lm_source[:, 0].max() or m_max < lm_source[:, 1].max()):
#         raise ValueError("The supplied beam is not large enough")

#     beam_extents = np.array([[l_min, l_max], [m_min, m_max]])

#     # get frequencies
#     if beam_hdr["CTYPE3"].lower() != 'freq':
#         raise ValueError(
#             "Cubes are assumed to be in format [nchan, nx, ny]")
#     nchan = beam_hdr['NAXIS3']
#     refpix = beam_hdr['CRPIX3']
#     delta = beam_hdr['CDELT3']  # assumes units are Hz
#     freq0 = beam_hdr['CRVAL3']
#     bfreqs = freq0 + np.arange(1 - refpix, 1 + nchan - refpix) * delta
#     if bfreqs[0] > freqs[0] or bfreqs[-1] < freqs[-1]:
#         warnings.warn("The supplied beam does not have sufficient "
#                       "bandwidth. Beam frequencies:")
#         with np.printoptions(precision=2):
#             print(bfreqs)

#     if use_dask:
#         return (da.from_array(beam_amp, chunks=beam_amp.shape),
#                 da.from_array(beam_extents, chunks=beam_extents.shape),
#                 da.from_array(bfreqs, bfreqs.shape))
#     else:
#         return beam_amp, beam_extents, bfreqs

# def interpolate_beam(ll, mm, freqs, args):
#     """
#     Interpolate beam to image coordinates and optionally compute average
#     over time if MS is provided
#     """
#     nband = freqs.size
#     print("Interpolating beam")
#     parangles, ant_scale, point_errs, unflag_counts, use_dask = extract_dde_info(
#         args, freqs)

#     lm_source = np.vstack((ll.ravel(), mm.ravel())).T
#     beam_amp, beam_extents, bfreqs = make_power_beam(
#         args, lm_source, freqs, use_dask)

#     # interpolate beam
#     if use_dask:
#         from africanus.rime.dask import beam_cube_dde
#         lm_source = da.from_array(lm_source, chunks=lm_source.shape)
#         freqs = da.from_array(freqs, chunks=freqs.shape)
#         # compute ncpu images at a time to avoid memory errors
#         ntimes = parangles.shape[0]
#         I = np.arange(0, ntimes, args.ncpu)
#         nchunks = I.size
#         I = np.append(I, ntimes)
#         beam_image = np.zeros((ll.size, 1, nband), dtype=beam_amp.dtype)
#         for i in range(nchunks):
#             ilow = I[i]
#             ihigh = I[i+1]
#             part_parangles = da.from_array(
#                 parangles[ilow:ihigh], chunks=(1, 1))
#             part_point_errs = da.from_array(
#                 point_errs[ilow:ihigh], chunks=(1, 1, freqs.size, 2))
#             # interpolate and remove redundant axes
#             part_beam_image = beam_cube_dde(beam_amp, beam_extents, bfreqs,
#                                             lm_source, part_parangles, part_point_errs,
#                                             ant_scale, freqs).compute()[:, :, 0, :, 0, 0]
#             # weighted sum over time
#             beam_image += np.sum(part_beam_image *
#                                  unflag_counts[None, ilow:ihigh, None], axis=1, keepdims=True)
#         # normalise by sum of weights
#         beam_image /= np.sum(unflag_counts)
#         # remove time axis
#         beam_image = beam_image[:, 0, :]
#     else:
#         from africanus.rime.fast_beam_cubes import beam_cube_dde
#         beam_image = beam_cube_dde(beam_amp, beam_extents, bfreqs,
#                                    lm_source, parangles, point_errs,
#                                    ant_scale, freqs).squeeze()

#     # swap source and freq axes and reshape to image shape
#     beam_source = np.transpose(beam_image, axes=(1, 0))
#     return beam_source.squeeze().reshape((freqs.size, *ll.shape))

# def main(args):
#     # get coord info
#     hdr = fits.getheader(args.image)
#     l_coord, ref_l = data_from_header(hdr, axis=1)
#     l_coord -= ref_l
#     m_coord, ref_m = data_from_header(hdr, axis=2)
#     m_coord -= ref_m
#     if hdr["CTYPE4"].lower() == 'freq':
#         freq_axis = 4
#     elif hdr["CTYPE3"].lower() == 'freq':
#         freq_axis = 3
#     else:
#         raise ValueError("Freq axis must be 3rd or 4th")
#     freqs, ref_freq = data_from_header(hdr, axis=freq_axis)

#     xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

#     # interpolate primary beam to fits header and optionally average over time
#     beam_image = interpolate_beam(xx, yy, freqs, args)

#     # save power beam
#     save_fits(args.output_filename, beam_image, hdr)
#     print("Wrote interpolated beam cube to %s \n" % args.output_filename)

#     return
Exemple #11
0
def test_forwardmodel(do_beam, do_gains, tmp_path_factory):
    """End-to-end check that the forward worker recovers a known sky model.

    The test downloads a small measurement set, simulates visibilities for
    a handful of random point sources (optionally attenuated by a primary
    beam and corrupted by smooth random gains), writes them to the MS, runs
    the dirty, psf and forward workers and finally compares the inferred
    model with the input at the source positions.

    Parameters
    ----------
    do_beam : bool
        If truthy, attenuate the model by the katbeam ``JimBeam`` pattern
        and tell the forward worker to use the 'JimBeam' beam model.
    do_gains : bool
        If truthy, corrupt the visibilities with random diagonal Jones
        terms and pass the corresponding Mueller column to the workers.
    tmp_path_factory
        Object providing ``mktemp``; presumably the pytest fixture of the
        same name (the parametrisation is not visible in this chunk).
    """
    test_dir = tmp_path_factory.mktemp("test_pfb")

    packratt.get('/test/ms/2021-06-24/elwood/test_ascii_1h60.0s.MS.tar',
                 str(test_dir))

    import numpy as np
    np.random.seed(420)
    from numpy.testing import assert_allclose
    from pyrap.tables import table

    ms = table(str(test_dir / 'test_ascii_1h60.0s.MS'), readonly=False)
    spw = table(str(test_dir / 'test_ascii_1h60.0s.MS::SPECTRAL_WINDOW'))

    utime = np.unique(ms.getcol('TIME'))

    freq = spw.getcol('CHAN_FREQ').squeeze()
    # the subtable is only needed for the channel frequencies
    spw.close()
    freq0 = np.mean(freq)

    ntime = utime.size
    nchan = freq.size
    nant = np.maximum(
        ms.getcol('ANTENNA1').max(),
        ms.getcol('ANTENNA2').max()) + 1

    ncorr = ms.getcol('FLAG').shape[-1]

    uvw = ms.getcol('UVW')
    nrow = uvw.shape[0]
    u_max = abs(uvw[:, 0]).max()
    v_max = abs(uvw[:, 1]).max()
    uv_max = np.maximum(u_max, v_max)

    # image size
    from africanus.constants import c as lightspeed
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    srf = 2.0
    cell_rad = cell_N / srf
    # radians -> degrees (fov below is divided by this value directly)
    cell_size = cell_rad * 180 / np.pi
    print("Cell size set to %5.5e degrees" % cell_size)

    fov = 2
    npix = int(fov / cell_size)
    if npix % 2:
        npix += 1

    nx = npix
    ny = npix

    print("Image size set to (%i, %i, %i)" % (nchan, nx, ny))

    # model: point sources at random pixels with power-law spectra
    model = np.zeros((nchan, nx, ny), dtype=np.float64)
    nsource = 10
    Ix = np.random.randint(0, npix, nsource)
    Iy = np.random.randint(0, npix, nsource)
    alpha = -0.7 + 0.1 * np.random.randn(nsource)
    I0 = 1.0 + np.abs(np.random.randn(nsource))
    for i in range(nsource):
        model[:, Ix[i], Iy[i]] = I0[i] * (freq / freq0)**alpha[i]

    if do_beam:
        # primary beam
        from katbeam import JimBeam
        beam = JimBeam('MKAT-AA-L-JIM-2020')
        l_coord = -np.arange(-(nx // 2), nx // 2) * cell_size
        m_coord = np.arange(-(ny // 2), ny // 2) * cell_size
        xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
        pbeam = np.zeros((nchan, nx, ny), dtype=np.float64)
        for i in range(nchan):
            pbeam[i] = beam.I(xx, yy, freq[i] / 1e6)  # freq in MHz
        model_att = pbeam * model
        bm = 'JimBeam'
    else:
        model_att = model
        bm = None

    # model vis: degrid one channel at a time into the first correlation
    # and copy it to the last one
    from ducc0.wgridder import dirty2ms
    model_vis = np.zeros((nrow, nchan, ncorr), dtype=np.complex128)
    for c in range(nchan):
        model_vis[:, c:c + 1, 0] = dirty2ms(uvw,
                                            freq[c:c + 1],
                                            model_att[c],
                                            pixsize_x=cell_rad,
                                            pixsize_y=cell_rad,
                                            epsilon=1e-8,
                                            do_wstacking=True,
                                            nthreads=8)
        model_vis[:, c, -1] = model_vis[:, c, 0]

    ms.putcol('MODEL_DATA', model_vis.astype(np.complex64))

    if do_gains:
        t = (utime - utime.min()) / (utime.max() - utime.min())
        nu = 2.5 * (freq / freq0 - 1.0)

        # squared-exponential covariances in time and frequency; their
        # Cholesky factors are used to draw smooth random gain realisations
        from africanus.gps.utils import abs_diff
        tt = abs_diff(t, t)
        lt = 0.25
        Kt = 0.1 * np.exp(-tt**2 / (2 * lt**2))
        Lt = np.linalg.cholesky(Kt + 1e-10 * np.eye(ntime))
        vv = abs_diff(nu, nu)
        lv = 0.1
        Kv = 0.1 * np.exp(-vv**2 / (2 * lv**2))
        Lv = np.linalg.cholesky(Kv + 1e-10 * np.eye(nchan))
        L = (Lt, Lv)

        from pfb.utils.misc import kron_matvec

        jones = np.zeros((ntime, nant, nchan, 1, ncorr), dtype=np.complex128)
        for p in range(nant):
            for c in [0, -1]:  # for now only diagonal
                xi_amp = np.random.randn(ntime, nchan)
                amp = np.exp(-nu[None, :]**2 +
                             kron_matvec(L, xi_amp).reshape(ntime, nchan))
                xi_phase = np.random.randn(ntime, nchan)
                phase = kron_matvec(L, xi_phase).reshape(ntime, nchan)
                jones[:, p, :, 0, c] = amp * np.exp(1.0j * phase)

        # corrupted vis
        model_vis = model_vis.reshape(nrow, nchan, 1, 2, 2)
        from africanus.calibration.utils import chunkify_rows
        time = ms.getcol('TIME')
        row_chunks, tbin_idx, tbin_counts = chunkify_rows(time, ntime)
        ant1 = ms.getcol('ANTENNA1')
        ant2 = ms.getcol('ANTENNA2')

        from africanus.calibration.utils import corrupt_vis
        vis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones,
                          model_vis).reshape(nrow, nchan, ncorr)

        # corrupting a unit model on the diagonal yields the gain products
        # themselves, which are stored as the Mueller column
        model_vis[:, :, 0, 0, 0] = 1.0 + 0j
        model_vis[:, :, 0, -1, -1] = 1.0 + 0j
        muellercol = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones,
                                 model_vis).reshape(nrow, nchan, ncorr)

        ms.putcol('DATA', vis.astype(np.complex64))
        ms.putcol('CORRECTED_DATA', muellercol.astype(np.complex64))
        mcol = 'CORRECTED_DATA'
    else:
        ms.putcol('DATA', model_vis.astype(np.complex64))
        mcol = None

    # close the writable table in both branches so the columns written
    # above are released before the workers reopen the MS
    ms.close()

    from pfb.workers.grid.dirty import _dirty
    _dirty(ms=str(test_dir / 'test_ascii_1h60.0s.MS'),
           data_column="DATA",
           weight_column='WEIGHT',
           imaging_weight_column=None,
           flag_column='FLAG',
           mueller_column=mcol,
           row_chunks=None,
           epsilon=1e-5,
           wstack=True,
           mock=False,
           double_accum=True,
           output_filename=str(test_dir / 'test'),
           nband=nchan,
           field_of_view=fov,
           super_resolution_factor=srf,
           cell_size=None,
           nx=None,
           ny=None,
           output_type='f4',
           nworkers=1,
           nthreads_per_worker=1,
           nvthreads=8,
           mem_limit=8,
           nthreads=8,
           host_address=None)

    from pfb.workers.grid.psf import _psf
    _psf(ms=str(test_dir / 'test_ascii_1h60.0s.MS'),
         data_column="DATA",
         weight_column='WEIGHT',
         imaging_weight_column=None,
         flag_column='FLAG',
         mueller_column=mcol,
         row_out_chunk=-1,
         row_chunks=None,
         epsilon=1e-5,
         wstack=True,
         mock=False,
         psf_oversize=2,
         double_accum=True,
         output_filename=str(test_dir / 'test'),
         nband=nchan,
         field_of_view=fov,
         super_resolution_factor=srf,
         cell_size=None,
         nx=None,
         ny=None,
         output_type='f4',
         nworkers=1,
         nthreads_per_worker=1,
         nvthreads=8,
         mem_limit=8,
         nthreads=8,
         host_address=None)

    # solve for model using pcg and mask
    mask = np.any(model, axis=0)
    from astropy.io import fits
    from pfb.utils.fits import save_fits
    hdr = fits.getheader(str(test_dir / 'test_dirty.fits'))
    save_fits(str(test_dir / 'test_model.fits'), model, hdr)
    save_fits(str(test_dir / 'test_mask.fits'), mask, hdr)

    from pfb.workers.deconv.forward import _forward
    _forward(residual=str(test_dir / 'test_dirty.fits'),
             psf=str(test_dir / 'test_psf.fits'),
             mask=str(test_dir / 'test_mask.fits'),
             beam_model=bm,
             band='L',
             weight_table=str(test_dir / 'test.zarr'),
             output_filename=str(test_dir / 'test'),
             nband=nchan,
             output_type='f4',
             epsilon=1e-5,
             sigmainv=0.0,
             wstack=True,
             double_accum=True,
             cg_tol=1e-6,
             cg_minit=10,
             cg_maxit=100,
             cg_verbose=0,
             cg_report_freq=10,
             backtrack=False,
             nworkers=1,
             nthreads_per_worker=1,
             nvthreads=1,
             mem_limit=8,
             nthreads=1,
             host_address=None)

    # get inferred model
    from pfb.utils.fits import load_fits
    model_inferred = load_fits(str(test_dir / 'test_update.fits')).squeeze()

    for i in range(nsource):
        if do_beam:
            # compare in apparent (beam attenuated) flux
            beam = pbeam[:, Ix[i], Iy[i]]
            assert_allclose(
                0.0,
                beam *
                (model_inferred[:, Ix[i], Iy[i]] - model[:, Ix[i], Iy[i]]),
                atol=1e-4)
        else:
            assert_allclose(0.0,
                            model_inferred[:, Ix[i], Iy[i]] -
                            model[:, Ix[i], Iy[i]],
                            atol=1e-4)
Exemple #12
0
def _residual(ms, stack, **kw):
    """Grid the visibilities of one or more measurement sets into an image.

    The options in ``kw`` are wrapped in a struct-mode OmegaConf object, so
    reading an option that was not supplied raises. Thread, worker and
    memory settings that are ``None`` are filled in with defaults, the
    image geometry is derived from the maximum uv extent and frequency,
    and one gridding task is built per dataset. Unless ``mock`` is set the
    tasks are computed, stitched into bands and written to
    ``<output_filename>_dirty.fits``.

    Parameters
    ----------
    ms : iterable of str
        Measurement set paths (subtable names are built by appending
        ``"::<NAME>"`` to each entry).
    stack
        Passed through to ``pfb.set_client``; presumably an ``ExitStack``
        managing the dask client -- confirm against the caller.
    **kw
        Worker options (columns, image size, threading, memory limit, ...).

    Raises
    ------
    ValueError
        If ``nthreads`` or ``mem_limit`` is missing while ``host_address``
        is given, or if the requested cell size gives a super resolution
        factor below one.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import residual as im2residim
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        # one-element tuple: ('UVW') without the comma is just a string
        xds = xds_from_ms(ims, columns=('UVW', ), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        # arcseconds -> radians
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            # format the factor into the message; passing it as a second
            # argument would render the exception text as a tuple
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        # field_of_view is multiplied by 3600 to match cell_size (arcsec)
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    # an explicit row_chunks option overrides the planned value
    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data.shape,
                                                      chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply adjoint of mueller term.
            # Phases modify data amplitudes modify weights.
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - turn off this stupid warning
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            # NOTE(review): `vis2im` is not imported inside this function
            # (only `residual as im2residim` is, and that name is unused);
            # it must be provided at module level or this call raises
            # NameError -- confirm which gridder is intended here.
            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx,
                           ny,
                           cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf', optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        # cell_size is in arcseconds; set_wcs is given degrees
        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits',
                  dirty,
                  hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)
Exemple #13
0
def _clean(**kw):
    """Run major cycles of Hogbom clean on a dirty image cube.

    The dirty image and PSF cubes are loaded from FITS, normalised by the
    sum of the per-band PSF peaks, and cleaned for ``nmiter`` major
    iterations. After each Hogbom run the residual is recomputed by
    convolving the accumulated model, either with the gridder based
    Hessian (when ``weight_table`` is given) or with an FFT based PSF
    convolution.

    The dirty MFS image, the model cube, the model MFS image, the residual
    cube and the residual MFS image are written next to
    ``output_filename``.

    Parameters
    ----------
    **kw
        Worker options. They are wrapped in a struct-mode OmegaConf object,
        so reading an option that was not supplied raises.

    Raises
    ------
    NotImplementedError
        If the two cell sizes in the dirty image header differ.
    AssertionError
        If the peak of the normalised MFS PSF is not one to within 1e-4.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading dirty", file=log)
    dirty = load_fits(args.dirty, dtype=args.output_type).squeeze()
    nband, nx, ny = dirty.shape
    hdr = fits.getheader(args.dirty)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    # per-band PSF peaks and their total
    wsums = np.amax(psf.reshape(-1, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # wsum is the sum of the per-band peaks, so the MFS peak can never
    # exceed one; the absolute value is needed for this to check anything
    assert abs(psf_mfs.max() - 1.0) < 1e-4

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_dirty_mfs.fits',
              dirty_mfs,
              hdr_mfs,
              dtype=args.output_type)

    # set up Hessian approximation
    if args.weight_table is not None:
        normfact = wsum
        from africanus.gridding.wgridder.dask import hessian
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freqs = xds.chan.data
        nchan = freqs.size

        # bin edges
        fmin = freqs.min()
        fmax = freqs.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_map = np.zeros(freqs.size, dtype=np.int32)
        for band in range(nband):
            indl = freqs >= fbins[band]
            indu = freqs < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freqs = da.from_array(freqs, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit / args.nworkers,
                                       band_size, nrow, memory_per_row,
                                       args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row,
                                       args.nthreads_per_worker)

        print(
            "nrows = %i, row chunks set to %i for a total of %i chunks per node"
            % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
            file=log)

        def convolver(x):
            """Return the lazy gridder-based Hessian applied to ``x``."""
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            xds = xds_from_zarr(args.weight_table,
                                chunks={
                                    'row': row_chunk,
                                    'chan': bin_counts
                                })[0]

            convolvedim = hessian(xds.UVW.data,
                                  freqs,
                                  model,
                                  freq_bin_idx,
                                  freq_bin_counts,
                                  cell_rad,
                                  weights=xds.WEIGHT.data.astype(
                                      args.output_type),
                                  nthreads=args.nvthreads,
                                  epsilon=args.epsilon,
                                  do_wstacking=args.wstack,
                                  double_accum=args.double_accum)
            return convolvedim
    else:
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        # padding that takes the image up to the PSF size
        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        # `-0` as a stop would give an empty slice when there is no right
        # padding, so fall back to an open-ended slice in that case
        unpad_x = slice(npad_xl, -npad_xr or None)
        unpad_y = slice(npad_yl, -npad_yr or None)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        # thread counts come from the options; bare `nthreads`/`nvthreads`
        # are not defined in this function
        psfhat = r2c(psf_pad,
                     axes=(1, 2),
                     forward=True,
                     nthreads=args.nthreads,
                     inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))

        def convolver(x):
            """Return the lazy PSF-convolution Hessian applied to ``x``."""
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            convolvedim = hessian(model, psfhat, padding, args.nvthreads,
                                  unpad_x, unpad_y, lastsize)
            return convolvedim

        # psfo = PSF(psf, dirty.shape, nthreads=args.nthreads)
        # def convolver(x): return psfo.convolve(x)

    rms = np.std(dirty_mfs)
    rmax = np.abs(dirty_mfs).max()

    print("Iter %i: peak residual = %f, rms = %f" % (0, rmax, rms), file=log)

    residual = dirty.copy()
    residual_mfs = dirty_mfs.copy()
    model = np.zeros_like(residual)
    for k in range(args.nmiter):
        print("Running Hogbom", file=log)
        x = hogbom(residual,
                   psf,
                   gamma=args.hb_gamma,
                   pf=args.hb_peak_factor,
                   maxit=args.hb_maxit,
                   verbosity=args.hb_verbose,
                   report_freq=args.hb_report_freq)

        model += x
        print("Getting residual", file=log)

        convimage = convolver(model)
        dask.visualize(convimage,
                       filename=args.output_filename + '_hessian' + str(k) +
                       '_graph.pdf',
                       optimize_graph=False)
        with performance_report(filename=args.output_filename + '_hessian' +
                                str(k) + '_per.html'):
            convimage = dask.compute(convimage, optimize_graph=False)[0]
        # the names inside the expression strings refer to the local
        # variables of this function, so those must not be renamed
        ne.evaluate('dirty - convimage/normfact',
                    out=residual,
                    casting='same_kind')
        ne.evaluate('sum(residual, axis=0)',
                    out=residual_mfs,
                    casting='same_kind')

        rms = np.std(residual_mfs)
        rmax = np.abs(residual_mfs).max()

        print("Iter %i: peak residual = %f, rms = %f" % (k + 1, rmax, rms),
              file=log)

    print("Saving results", file=log)
    save_fits(args.output_filename + '_model.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_model_mfs.fits', model_mfs, hdr_mfs)
    save_fits(args.output_filename + '_residual.fits',
              residual * wsums[:, None, None], hdr)
    # separate file name so the MFS residual does not overwrite the cube
    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    print("All done here.", file=log)