Example #1
0
def _fits_axis_extent(hdr, axis):
    """Return offsets of the first and last pixel centres along a FITS axis.

    Offsets are relative to the reference pixel (``CRVAL`` is not added) and
    are expressed in the units of ``CDELT<axis>``.

    Args:
        hdr: FITS header describing the cube.
        axis: 1-based FITS axis number.

    Returns:
        ``(first, last)`` coordinate offsets of pixel 1 and pixel ``NAXIS``.

    Raises:
        ValueError: if ``CUNIT<axis>`` is not degrees.
    """
    if hdr['CUNIT%d' % axis].lower() != "deg":
        raise ValueError("Beam image units must be in degrees")
    npix = hdr['NAXIS%d' % axis]
    refpix = hdr['CRPIX%d' % axis]
    delta = hdr['CDELT%d' % axis]
    # FITS pixel indices run from 1 to npix inclusive, so the last pixel
    # centre sits at (npix - refpix) * delta (not npix + 1).
    return (1 - refpix) * delta, (npix - refpix) * delta


def make_power_beam(args, lm_source, freqs, use_dask):
    """Load real/imaginary FITS beam patterns and form a power beam cube.

    Files are found with ``glob(args.beam_model + '**_**.fits')``. A file is
    assigned to a correlation when the lower-case correlation label appears in
    the last 10 characters of its path, and to the real/imaginary part when
    ``'re'``/``'im'`` appears in the last 7 characters (e.g. ``*_xx_re.fits``).
    The power beam is ``(re1**2 + im1**2 + re2**2 + im2**2) / 2``.

    Args:
        args: options namespace; reads ``beam_model`` (path prefix) and
            ``corr_type`` (``'linear'`` -> XX/YY, ``'circular'`` -> LL/RR).
        lm_source: array whose columns 0 and 1 hold source l and m
            coordinates; all of them must lie inside the beam extents.
        freqs: frequencies to cover; only ``freqs[0]`` and ``freqs[-1]`` are
            compared against the beam band.
        use_dask: if True return single-chunk dask arrays, else numpy arrays.

    Returns:
        ``(beam_amp, beam_extents, bfreqs)``. ``beam_amp`` is the power beam
        transposed with axes ``(1, 2, 0)`` and given two trailing singleton
        axes. Each pattern is assumed to be laid out as ``[nchan, nx, ny]``
        (as the frequency-axis error message below states).
        ``beam_extents`` is ``[[l_min, l_max], [m_min, m_max]]``: the
        first/last pixel centre offsets in degrees. ``bfreqs`` holds the beam
        channel frequencies taken from the FITS header.

    Raises:
        KeyError: unknown ``corr_type``.
        NotImplementedError: a matching file is neither a re nor an im pattern.
        ValueError: missing patterns, non-degree axes, a beam too small to
            cover ``lm_source``, or a third axis that is not frequency.
    """
    print("Loading fits beam patterns from %s" % args.beam_model)
    from glob import glob
    paths = glob(args.beam_model + '**_**.fits')

    if args.corr_type == 'linear':
        corr1, corr2 = 'XX', 'YY'
    elif args.corr_type == 'circular':
        corr1, corr2 = 'LL', 'RR'
    else:
        raise KeyError("Unknown corr_type supplied. Only 'linear' or 'circular' supported")

    # (correlation, 're' | 'im') -> loaded pattern
    patterns = {}
    beam_hdr = None
    for path in paths:
        if corr1.lower() in path[-10:]:
            corr = corr1
        elif corr2.lower() in path[-10:]:
            corr = corr2
        else:
            continue
        if 're' in path[-7:]:
            part = 're'
        elif 'im' in path[-7:]:
            part = 'im'
        else:
            raise NotImplementedError("Only re/im patterns supported")
        patterns[(corr, part)] = load_fits(path)
        # the first real corr1 pattern provides the cube geometry
        if beam_hdr is None and corr == corr1 and part == 're':
            beam_hdr = fits.getheader(path)

    missing = ["%s_%s" % (corr.lower(), part)
               for corr in (corr1, corr2)
               for part in ('re', 'im')
               if (corr, part) not in patterns]
    if missing:
        raise ValueError("Could not find beam patterns %s for beam model %s"
                         % (", ".join(missing), args.beam_model))

    # get power beam
    beam_amp = (patterns[(corr1, 're')]**2 + patterns[(corr1, 'im')]**2 +
                patterns[(corr2, 're')]**2 + patterns[(corr2, 'im')]**2) / 2.0

    # get cube in correct shape for interpolation code
    beam_amp = np.ascontiguousarray(np.transpose(beam_amp, (1, 2, 0))
                                    [:, :, :, None, None])

    # get cube info
    l_min, l_max = _fits_axis_extent(beam_hdr, 1)
    m_min, m_max = _fits_axis_extent(beam_hdr, 2)

    if (l_min > lm_source[:, 0].min() or m_min > lm_source[:, 1].min() or
            l_max < lm_source[:, 0].max() or m_max < lm_source[:, 1].max()):
        raise ValueError("The supplied beam is not large enough")

    beam_extents = np.array([[l_min, l_max], [m_min, m_max]])

    # get frequencies
    if beam_hdr["CTYPE3"].lower() != 'freq':
        raise ValueError(
            "Cubes are assumed to be in format [nchan, nx, ny]")
    nchan = beam_hdr['NAXIS3']
    refpix = beam_hdr['CRPIX3']
    delta = beam_hdr['CDELT3']  # assumes units are Hz
    freq0 = beam_hdr['CRVAL3']
    bfreqs = freq0 + np.arange(1 - refpix, 1 + nchan - refpix) * delta
    if bfreqs[0] > freqs[0] or bfreqs[-1] < freqs[-1]:
        warnings.warn("The supplied beam does not have sufficient "
                      "bandwidth. Beam frequencies:")
        with np.printoptions(precision=2):
            print(bfreqs)

    if use_dask:
        return (da.from_array(beam_amp, chunks=beam_amp.shape),
                da.from_array(beam_extents, chunks=beam_extents.shape),
                da.from_array(bfreqs, chunks=bfreqs.shape))
    return beam_amp, beam_extents, bfreqs
Example #2
0
def main(args):
    """Deconvolve an image cube as a sparse set of point sources.

    Grids the measurement set(s) in ``args.ms``, forms (or loads) the PSF
    and dirty cubes, then alternates preconditioned conjugate-gradient
    updates with a primal-dual step that imposes an l21 sparsity penalty
    on the Dirac (point source) dictionary ``phi``. Intermediate and final
    products are written as FITS files prefixed with ``args.outfile``.
    Optionally a polynomial spectrum is fitted to every surviving component
    (``args.interp_model``) and the model is written back to the MS
    (``args.write_model``).

    Args:
        args: parsed command-line options. ``cell_size``, ``nx``, ``ny`` and
            ``nband`` are filled in when they are None.

    Raises:
        ValueError: if a supplied mask or point mask has the wrong shape.
    """
    # get max uv coords over all fields
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())

            # NOTE(review): DATA_DESC_ID is used directly as the SPW index,
            # which assumes DATA_DESCRIPTION rows map 1:1 onto SPWs.
            spw = spws[ds.DATA_DESC_ID]
            # ravel (not squeeze) keeps single-channel SPWs one-dimensional
            all_freqs.append(spw.CHAN_FREQ.data.ravel())

    # the Nyquist cell must resolve the longest baseline in u OR v
    u_max, v_max = dask.compute(u_max, v_max)
    uv_max = max(u_max, v_max)

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)[0]
    freq = np.unique(np.concatenate(all_freqs))
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                optimise_chunks=True,
                data_column=args.data_column,
                weight_column=args.weight_column,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers (the PSF is twice the image size in each direction)
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)

    # psf
    if args.psf is not None:
        compare_headers(hdr_psf, fits.getheader(args.psf))
        psf_array = load_fits(args.psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    # normalise by the mean PSF peak over non-empty bands
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    # floor avoids dividing by zero for empty bands below
    psf_max[psf_max < 1e-15] = 1e-15

    if args.dirty is not None:
        compare_headers(hdr, fits.getheader(args.dirty))
        dirty = load_fits(args.dirty)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= psf_max_mean

    # mfs residual
    wsum = np.sum(psf_max)
    dirty_mfs = np.sum(dirty, axis=0) / wsum
    rmax = np.abs(dirty_mfs).max()
    rms = np.std(dirty_mfs)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)
        if mask.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((args.nx, args.ny), dtype=np.int64)

    if args.point_mask is not None:
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        pmask = load_fits(args.point_mask, dtype=bool)
        if pmask.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        pmask = None

    # Reporting
    print("At iteration 0 peak of residual is %f and rms is %f" % (rmax, rms))
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # set up point sources
    phi = Dirac(args.nband, args.nx, args.ny, mask=pmask)
    dual = np.zeros((args.nband, args.nx, args.ny), dtype=np.float64)
    # same "effective infinity" as inside the loop, so the final
    # primal-dual call is well defined even when the loop never runs
    weights_21 = np.where(phi.mask, 1, 1e10)

    # preconditioning matrix
    def hess(beta):
        return phi.hdot(psf.convolve(
            phi.dot(beta))) + beta / args.sig_l2**2  # vague prior on beta

    # get new spectral norm
    L = power_method(hess, dirty.shape, tol=args.pmtol, maxit=args.pmmaxit)

    # deconvolve
    eps = 1.0
    i = 0
    residual = dirty.copy()
    model = np.zeros(dirty.shape, dtype=dirty.dtype)
    for i in range(1, args.maxit):
        # find point source candidates
        if args.do_clean:
            model_tmp = hogbom(mask[None] * residual / psf_max[:, None, None],
                               psf_array / psf_max[:, None, None],
                               gamma=args.cgamma,
                               pf=args.peak_factor)
            phi.update_locs(np.any(model_tmp, axis=0))
            # get new spectral norm
            L = power_method(hess,
                             model.shape,
                             tol=args.pmtol,
                             maxit=args.pmmaxit)
        else:
            model_tmp = np.zeros_like(residual, dtype=residual.dtype)

        # solve for beta updates
        x = pcg(hess,
                phi.hdot(residual),
                phi.hdot(model_tmp),
                M=lambda x: x * args.sig_l2**2,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        modelp = model.copy()
        model += args.gamma * x

        # impose sparsity and positivity in point sources
        weights_21 = np.where(phi.mask, 1, 1e10)  # 1e10 for effective infinity
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21,
                                  phi,
                                  weights_21,
                                  L,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  axis=0,
                                  positivity=args.positivity,
                                  report_freq=100)

        # update Dirac dictionary (remove zero components)
        phi.trim_fat(model)

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(mask * residual_mfs).max()
        rms = np.std(mask * residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits',
                      residual / psf_max[:, None, None], hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

        if eps < args.tol:
            print("We have convergence!")
            break

    # final iteration with only a positivity constraint on pixel locs
    tmp = phi.hdot(model)
    x = pcg(hess,
            phi.hdot(residual),
            np.zeros_like(tmp, dtype=tmp.dtype),
            M=lambda x: x * args.sig_l2**2,
            tol=args.cgtol,
            maxit=args.cgmaxit,
            verbosity=args.cgverbose)

    modelp = model.copy()
    model += args.gamma * x
    model, dual = primal_dual(hess,
                              model,
                              modelp,
                              dual,
                              0.0,
                              phi,
                              weights_21,
                              L,
                              tol=args.pdtol,
                              maxit=args.pdmaxit,
                              axis=0,
                              report_freq=100)

    # get residual
    residual = R.make_residual(model) / psf_max_mean

    # check stopping criteria
    residual_mfs = np.sum(residual, axis=0) / wsum
    rmax = np.abs(mask * residual_mfs).max()
    rms = np.std(mask * residual_mfs)
    print("At final iteration peak of residual is %f and rms is %f" %
          (rmax, rms))

    save_fits(args.outfile + '_model.fits', model, hdr)

    model_mfs = np.mean(model, axis=0)
    save_fits(args.outfile + '_model_mfs.fits', model_mfs, hdr_mfs)

    save_fits(args.outfile + '_residual.fits',
              residual / psf_max[:, None, None], hdr)

    save_fits(args.outfile + '_residual_mfs.fits', residual_mfs, hdr_mfs)

    if args.interp_model:
        phi.trim_fat(model)
        # no squeeze: argwhere already returns (npix, ndim), and squeezing
        # would collapse it to 1-D when exactly one component survives
        I = np.argwhere(phi.mask)
        Ix = I[:, 0]
        Iy = I[:, 1]

        # get components
        beta = model[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component (band-averaged monomials)
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for i in range(1, args.spectral_poly_order + 1):
            Xdesign[:, i - 1] = (whigh**i - wlow**i) / (i * wdiff)

        # weighted least squares with per-band PSF peaks as weights
        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        np.savez(args.outfile + "spectral_comps",
                 comps=comps,
                 ref_freq=ref_freq,
                 mask=np.any(model, axis=0))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, phi.mask, args.row_chunks,
                                    args.chan_chunks)
        else:
            R.write_model(model)
Example #3
0
def main(args):
    """Deconvolve an image cube as point sources plus extended emission.

    The image is modelled as ``phi(model) = model[0] * pmask + model[1] * mask``.
    ``model[0]`` is the point component, restricted to ``args.point_mask``;
    ``model[1]`` is the extended ("fluff") component, restricted to
    ``args.mask``.

    Each major iteration does two things:
      * solves for an update with preconditioned conjugate gradients, using
        a ``Gauss`` prior on the fluff component;
      * runs a primal-dual step that imposes an l21 penalty in the
        ``DaskTheta`` dictionary.

    Products are written as FITS files prefixed with ``args.outfile``.
    Optionally a polynomial spectrum is fitted to every pixel of the final
    image with significant flux (``args.interp_model``), and the model is
    written back to the MS (``args.write_model``).

    Args:
        args: parsed command-line options. ``cell_size``, ``nx``, ``ny`` and
            ``nband`` are filled in when they are None.

    Raises:
        ValueError: if the mask or point mask has the wrong shape.
    """
    # get max uv coords over all fields
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())

            # NOTE(review): DATA_DESC_ID is used directly as the SPW index,
            # which assumes DATA_DESCRIPTION rows map 1:1 onto SPWs.
            spw = spws[ds.DATA_DESC_ID]
            # ravel (not squeeze) keeps single-channel SPWs one-dimensional
            all_freqs.append(spw.CHAN_FREQ.data.ravel())

    # the Nyquist cell must resolve the longest baseline in u OR v
    u_max, v_max = dask.compute(u_max, v_max)
    uv_max = max(u_max, v_max)

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)[0]
    freq = np.unique(np.concatenate(all_freqs))
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                data_column=args.data_column,
                weight_column=args.weight_column,
                epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers (the PSF is twice the image size in each direction)
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)

    # psf: reuse the supplied file when usable, otherwise recompute it
    psf_array = None
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf_array = load_fits(args.psf)
        except Exception as e:
            print("Could not use PSF %s (%s), recomputing" % (args.psf, e))
    if psf_array is None:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts  # normalisation for more intuitive sig_21 values
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    # NOTE(review): floor on empty bands; original author questioned
    # whether this is the right thing to do
    psf_max[psf_max < 1e-15] = 1e-15

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # dirty: reuse the supplied file when usable, otherwise recompute it
    dirty = None
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty)
        except Exception as e:
            print("Could not use dirty image %s (%s), recomputing"
                  % (args.dirty, e))
    if dirty is None:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    residual = dirty.copy()

    # model[0] = point component, model[1] = extended component
    model = np.zeros((2, args.nband, args.nx, args.ny))
    recompute_residual = False
    if args.beta0 is not None:
        compare_headers(hdr, fits.getheader(args.beta0))
        model[0] = load_fits(args.beta0).squeeze()
        recompute_residual = True

    if args.alpha0 is not None:
        compare_headers(hdr, fits.getheader(args.alpha0))
        model[1] = load_fits(args.alpha0).squeeze()
        recompute_residual = True

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    # point mask (builtin bool: the np.bool alias was removed in NumPy 1.24)
    pmask = load_fits(args.point_mask, dtype=bool)[None, :, :]
    if pmask.shape != (1, args.nx, args.ny):
        raise ValueError("Mask has incorrect shape")

    # set up splitting operator and its adjoint
    phi = lambda x: x[0] * pmask + x[1] * mask
    phih = lambda x: np.concatenate(
        ((pmask * x)[None], (mask * x)[None]), axis=0)

    # current image; kept up to date so it exists even if the loop never runs
    image = phi(model)
    if recompute_residual:
        residual = R.make_residual(image) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum

    # Gaussian "prior" used for preconditioning extended emission
    A = Gauss(args.sig_l2a, args.nband, args.nx, args.ny, args.nthreads)

    #  preconditioning matrix
    def hess(x):
        return phih(psf.convolve(phi(x))) + np.concatenate(
            (x[0:1] / args.sig_l2b**2, A.idot(x[1])[None]), axis=0)

    M_func = lambda x: np.concatenate(
        (x[0:1] * args.sig_l2b**2, A.convolve(x[1])[None]), axis=0)

    par_shape = phih(dirty).shape
    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess,
                            par_shape,
                            tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    theta = DaskTheta(args.nband, args.nx, args.ny, nthreads=args.nthreads)
    weights_21 = np.ones((theta.nbasis + 1, theta.nmax), dtype=np.float64)
    tmp = np.pad(pmask.ravel(), (0, theta.nmax - args.nx * args.ny),
                 mode='constant')
    weights_21[0] = np.where(tmp, args.sig_21b / args.sig_21a, 1e15)
    dual = np.zeros((theta.nbasis + 1, args.nband, theta.nmax),
                    dtype=np.float64)

    # Reporting
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # deconvolve
    eps = 1.0
    i = 0
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess,
                phih(residual),
                np.zeros(par_shape, dtype=np.float64),
                M=M_func,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_point_update.fits', x[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff_update.fits', x[1], hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21a,
                                  theta,
                                  weights_21,
                                  beta,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  report_freq=100,
                                  mask=mask,
                                  positivity=args.positivity,
                                  gamma=args.gamma)

        # get residual
        image = phi(model)
        residual = R.make_residual(image) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', image, hdr)

            save_fits(args.outfile + str(i) + '_point.fits', model[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff.fits', model[1], hdr)

            model_mfs = np.mean(image, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

        if eps < args.tol:
            break

    if args.interp_model:
        # recompute from the final image: the in-loop model_mfs is only set
        # on report iterations and may be stale after early convergence
        model_mfs = np.mean(image, axis=0)
        comp_mask = np.where(model_mfs > 1e-10, 1, 0)
        # no squeeze: keeps (npix, 2) even when a single pixel survives
        I = np.argwhere(comp_mask)
        Ix = I[:, 0]
        Iy = I[:, 1]

        # get components
        comp_beta = image[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component (band-averaged monomials)
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for i in range(1, args.spectral_poly_order + 1):
            Xdesign[:, i - 1] = (whigh**i - wlow**i) / (i * wdiff)

        # weighted least squares with per-band PSF peaks as weights
        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * comp_beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        # save the pixel mask that the columns of comps correspond to
        np.savez(args.outfile + "spectral_comps",
                 comps=comps,
                 ref_freq=ref_freq,
                 mask=comp_mask.astype(bool))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, comp_mask,
                                    args.row_chunks, args.chan_chunks)
        else:
            # write the image, not the 2-component parameter array
            R.write_model(image)
Example #4
0
def main(args):
    """Fit a power-law spectral index map to a model cube.

    The model cube in ``args.model`` is optionally divided by a primary beam
    (``args.beam_model``, masked below ``args.pb_min``), optionally convolved
    to the Gaussian resolution ``gaussparf`` and optionally has convolved
    residuals added back in.  Pixels whose minimum over frequency exceeds a
    threshold are passed to ``fit_spi_components`` and the products selected
    by the characters in ``args.products`` are written as FITS files:
    'c' clean psf (and convolved residual), 'm' convolved model,
    'I' reconstructed cube, 'a' alpha map, 'e' alpha error map,
    'i' I0 map, 'k' I0 error map.

    Parameters
    ----------
    args : object
        Parsed command-line options (presumably an ``argparse.Namespace``).
        Attributes read include ``model``, ``residual``, ``beam_model``,
        ``psf_pars``, ``circ_psf``, ``ref_freq``, ``output_filename``,
        ``products``, ``threshold``, ``maxDR``, ``channel_weights``,
        ``pb_min``, ``dont_convolve``, ``add_convolved_residuals``, ``ncpu``,
        ``padding_frac`` and ``out_dtype``.

    Raises
    ------
    RuntimeError
        If ``psf_pars`` is not given and no beam parameters can be read from
        the residual header.
    ValueError
        If the model has no frequency axis in position 3 or 4, has fewer than
        two bands, the beam/residual grids do not match the model, or no
        pixel exceeds the threshold.
    """
    import os

    if args.psf_pars is None:
        print("Attempting to take psf_pars from residual fits header")
        no_psf_msg = ("Either provide a residual with beam "
                      "information or pass them in using --psf_pars "
                      "argument")
        if args.residual is None:
            raise RuntimeError(no_psf_msg)
        try:
            rhdr = fits.getheader(args.residual)
        except (KeyError, OSError) as err:
            raise RuntimeError(no_psf_msg) from err
        if 'BMAJ1' in rhdr.keys():
            gaussparf = (rhdr['BMAJ1'], rhdr['BMIN1'], rhdr['BPA1'])
        elif 'BMAJ' in rhdr.keys():
            gaussparf = (rhdr['BMAJ'], rhdr['BMIN'], rhdr['BPA'])
        else:
            # without this branch gaussparf would be unbound further down
            raise RuntimeError(no_psf_msg)
    else:
        gaussparf = tuple(args.psf_pars)

    if args.circ_psf:
        # tuples are immutable so rebuild instead of assigning in place
        e = (gaussparf[0] + gaussparf[1]) / 2.0
        gaussparf = (e, e, gaussparf[2])

    print("Using emaj = %3.2e, emin = %3.2e, PA = %3.2e \n" % gaussparf)

    # load model image
    model = load_fits(args.model, dtype=args.out_dtype)
    model = model.squeeze()
    orig_shape = model.shape
    mhdr = fits.getheader(args.model)

    l_coord, ref_l = data_from_header(mhdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(mhdr, axis=2)
    m_coord -= ref_m
    # .get avoids a KeyError for headers with only three axes
    if mhdr.get("CTYPE4", "").lower() == 'freq':
        freq_axis = 4
    elif mhdr.get("CTYPE3", "").lower() == 'freq':
        freq_axis = 3
    else:
        raise ValueError("Freq axis must be 3rd or 4th")

    # single-plane shape used for 2D products (alpha, I0, clean psf, ...)
    mfs_shape = list(orig_shape)
    mfs_shape[0] = 1
    mfs_shape = tuple(mfs_shape)
    freqs, ref_freq = data_from_header(mhdr, axis=freq_axis)

    nband = freqs.size
    if nband < 2:
        raise ValueError("Can't produce alpha map from a single band image")
    npix_l = l_coord.size
    npix_m = m_coord.size

    # update cube psf-pars
    for i in range(1, nband + 1):
        mhdr['BMAJ' + str(i)] = gaussparf[0]
        mhdr['BMIN' + str(i)] = gaussparf[1]
        mhdr['BPA' + str(i)] = gaussparf[2]

    if args.ref_freq is not None and args.ref_freq != ref_freq:
        ref_freq = args.ref_freq
        print(
            'Provided reference frequency does not match that of fits file. Will overwrite.'
        )

    print("Cube frequencies:")
    with np.printoptions(precision=2):
        print(freqs)
    print("Reference frequency is %3.2e Hz \n" % ref_freq)

    # LB - new header for cubes if ref_freqs differ
    new_hdr = set_header_info(mhdr, ref_freq, freq_axis, args, gaussparf)

    # save next to model if no outfile is provided
    if args.output_filename is None:
        # strip the extension (e.g. .fits) from the model filename; splitext
        # copes with extension-less names and dots in directory names
        outfile = os.path.splitext(args.model)[0]
    else:
        outfile = args.output_filename

    xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

    # load beam
    if args.beam_model is not None:
        bhdr = fits.getheader(args.beam_model)
        l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
        l_coord_beam -= ref_lb
        if not np.array_equal(l_coord_beam, l_coord):
            raise ValueError(
                "l coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header."
            )

        m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
        m_coord_beam -= ref_mb
        if not np.array_equal(m_coord_beam, m_coord):
            raise ValueError(
                "m coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header."
            )

        freqs_beam, _ = data_from_header(bhdr, axis=freq_axis)
        if not np.array_equal(freqs, freqs_beam):
            raise ValueError(
                "Freqs of beam model do not match those of image. Use power_beam_maker to interpolate to fits header."
            )

        beam_image = load_fits(args.beam_model,
                               dtype=args.out_dtype).reshape(model.shape)
    else:
        beam_image = np.ones(model.shape, dtype=args.out_dtype)

    # do beam correction LB - TODO: use forward model instead
    # pixels where the beam drops below pb_min in any band are zeroed
    beammin = np.amin(beam_image, axis=0)[None, :, :]
    model = np.where(beammin >= args.pb_min, model / beam_image, 0.0)

    if not args.dont_convolve:
        print("Computing clean beam")
        # convolve model to desired resolution
        model, gausskern = convolve2gaussres(model, xx, yy, gaussparf,
                                             args.ncpu, None,
                                             args.padding_frac)

        # save clean beam
        if 'c' in args.products:
            name = outfile + '.clean_psf.fits'
            save_fits(name,
                      gausskern.reshape(mfs_shape),
                      new_hdr,
                      dtype=args.out_dtype)
            print("Wrote clean psf to %s \n" % name)

        # save convolved model
        if 'm' in args.products:
            name = outfile + '.convolved_model.fits'
            save_fits(name,
                      model.reshape(orig_shape),
                      new_hdr,
                      dtype=args.out_dtype)
            print("Wrote convolved model to %s \n" % name)

    # add in residuals and set threshold
    if args.residual is not None:
        resid = load_fits(args.residual, dtype=args.out_dtype).squeeze()
        rhdr = fits.getheader(args.residual)
        l_res, ref_lb = data_from_header(rhdr, axis=1)
        l_res -= ref_lb
        if not np.array_equal(l_res, l_coord):
            raise ValueError(
                "l coordinates of residual do not match those of model")

        m_res, ref_mb = data_from_header(rhdr, axis=2)
        m_res -= ref_mb
        if not np.array_equal(m_res, m_coord):
            raise ValueError(
                "m coordinates of residual do not match those of model")

        freqs_res, _ = data_from_header(rhdr, axis=freq_axis)
        if not np.array_equal(freqs, freqs_res):
            raise ValueError("Freqs of residual do not match those of model")

        # convolve residual to same resolution as model
        # (needs the per-band intrinsic resolution BMAJi/BMINi/BPAi)
        gausspari = ()
        for i in range(1, nband + 1):
            key = 'BMAJ' + str(i)
            if key in rhdr.keys():
                emaj = rhdr[key]
                emin = rhdr['BMIN' + str(i)]
                pa = rhdr['BPA' + str(i)]
                gausspari += ((emaj, emin, pa), )
            else:
                print(
                    "Can't find Gausspars in residual header, unable to add residuals back in"
                )
                gausspari = None
                break

        if gausspari is not None and args.add_convolved_residuals:
            resid, _ = convolve2gaussres(resid,
                                         xx,
                                         yy,
                                         gaussparf,
                                         args.ncpu,
                                         gausspari,
                                         args.padding_frac,
                                         norm_kernel=True)
            model += resid
            print("Convolved residuals added to convolved model")

            if 'c' in args.products:
                name = outfile + '.convolved_residual.fits'
                # the convolved residual now has the gaussparf resolution, so
                # use the same header as the convolved model rather than the
                # original residual header (which carries the old beam)
                save_fits(name,
                          resid.reshape(orig_shape),
                          new_hdr,
                          dtype=args.out_dtype)
                print("Wrote convolved residuals to %s" % name)

        counts = np.sum(resid != 0)
        rms = np.sqrt(np.sum(resid**2) / counts)
        rms_cube = np.std(resid.reshape(nband, npix_l * npix_m),
                          axis=1).ravel()
        threshold = args.threshold * rms
        print("Setting cutoff threshold as %i times the rms "
              "of the residual " % args.threshold)
        del resid
    else:
        print("No residual provided. Setting  threshold i.t.o dynamic range. "
              "Max dynamic range is %i " % args.maxDR)
        threshold = model.max() / args.maxDR
        rms_cube = None

    print("Threshold set to %f Jy. \n" % threshold)

    # get pixels above threshold in every band
    minimage = np.amin(model, axis=0)
    maskindices = np.argwhere(minimage > threshold)
    if not maskindices.size:
        raise ValueError("No components found above threshold. "
                         "Try lowering your threshold. "
                         "Max of convolved model is %3.2e" % model.max())
    # (ncomps, nband) spectra of the selected pixels
    fitcube = model[:, maskindices[:, 0], maskindices[:, 1]].T

    # set weights for fit
    if rms_cube is not None:
        print("Using RMS in each imaging band to determine weights. \n")
        weights = np.where(rms_cube > 0, 1.0 / rms_cube**2, 0.0)
        # normalise
        weights /= weights.max()
    else:
        if args.channel_weights is not None:
            weights = np.array(args.channel_weights)
            print("Using provided channel weights \n")
        else:
            print(
                "No residual or channel weights provided. Using equal weights. \n"
            )
            weights = np.ones(nband, dtype=np.float64)

    ncomps, _ = fitcube.shape
    # a chunk size of 0 is invalid in dask; happens when ncomps < ncpu
    comp_chunk = max(ncomps // args.ncpu, 1)
    fitcube = da.from_array(fitcube.astype(np.float64),
                            chunks=(comp_chunk, nband))
    weights = da.from_array(weights.astype(np.float64), chunks=(nband))
    freqsdask = da.from_array(freqs.astype(np.float64), chunks=(nband))

    print("Fitting %i components" % ncomps)
    alpha, alpha_err, Iref, i0_err = fit_spi_components(
        fitcube, weights, freqsdask, np.float64(ref_freq)).compute()
    print("Done. Writing output. \n")

    # scatter the fitted parameters back onto the image grid
    alphamap = np.zeros(model[0].shape, dtype=model.dtype)
    alpha_err_map = np.zeros(model[0].shape, dtype=model.dtype)
    i0map = np.zeros(model[0].shape, dtype=model.dtype)
    i0_err_map = np.zeros(model[0].shape, dtype=model.dtype)
    alphamap[maskindices[:, 0], maskindices[:, 1]] = alpha
    alpha_err_map[maskindices[:, 0], maskindices[:, 1]] = alpha_err
    i0map[maskindices[:, 0], maskindices[:, 1]] = Iref
    i0_err_map[maskindices[:, 0], maskindices[:, 1]] = i0_err

    if 'I' in args.products:
        # get the reconstructed cube I0 * (nu / ref_freq)**alpha
        Irec_cube = i0map[None, :, :] * \
            (freqs[:, None, None]/ref_freq)**alphamap[None, :, :]
        name = outfile + '.Irec_cube.fits'
        save_fits(name,
                  Irec_cube.reshape(orig_shape),
                  mhdr,
                  dtype=args.out_dtype)
        print("Wrote reconstructed cube to %s" % name)

    # save alpha map
    if 'a' in args.products:
        name = outfile + '.alpha.fits'
        save_fits(name,
                  alphamap.reshape(mfs_shape),
                  mhdr,
                  dtype=args.out_dtype)
        print("Wrote alpha map to %s" % name)

    # save alpha error map
    if 'e' in args.products:
        name = outfile + '.alpha_err.fits'
        save_fits(name,
                  alpha_err_map.reshape(mfs_shape),
                  mhdr,
                  dtype=args.out_dtype)
        print("Wrote alpha error map to %s" % name)

    # save I0 map
    if 'i' in args.products:
        name = outfile + '.I0.fits'
        save_fits(name, i0map.reshape(mfs_shape), mhdr, dtype=args.out_dtype)
        print("Wrote I0 map to %s" % name)

    # save I0 error map
    if 'k' in args.products:
        name = outfile + '.I0_err.fits'
        save_fits(name,
                  i0_err_map.reshape(mfs_shape),
                  mhdr,
                  dtype=args.out_dtype)
        print("Wrote I0 error map to %s" % name)

    print(' \n ')

    print("All done here")
Example #5
0
def main(args):
    """Convolve an image cube to a common Gaussian resolution.

    The intrinsic per-channel resolution is read from the image header
    (``BMAJi``/``BMINi``/``BPAi`` for cubes, ``BMAJ``/``BMIN``/``BPA`` for a
    single channel). The target resolution is ``args.psf_pars`` if given,
    otherwise the first channel's beam. The convolved image is optionally
    divided by a primary beam (masked below ``args.pb_min``). Then the clean
    psf and the convolved image are written to FITS.

    Parameters
    ----------
    args : object
        Parsed command-line options (presumably an ``argparse.Namespace``).
        Attributes read include ``image``, ``psf_pars``, ``circ_psf``,
        ``beam_model``, ``pb_min``, ``output_filename``, ``ncpu`` and
        ``padding_frac``.

    Raises
    ------
    ValueError
        If the image has no frequency axis in position 3 or 4, no psf
        parameters are available, or the beam grid does not match the image.
    """
    import os

    # read coords from fits file
    hdr = fits.getheader(args.image)
    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m
    # .get avoids a KeyError for headers with only three axes
    if hdr.get("CTYPE4", "").lower() == 'freq':
        freq_axis = 4
    elif hdr.get("CTYPE3", "").lower() == 'freq':
        freq_axis = 3
    else:
        raise ValueError("Freq axis must be 3rd or 4th")
    freqs, ref_freq = data_from_header(hdr, axis=freq_axis)

    nchan = freqs.size
    gausspari = ()
    if nchan > 1:
        # cubes carry one beam per channel as BMAJ1, BMIN1, BPA1, ...
        for i in range(1, nchan + 1):
            key = 'BMAJ' + str(i)
            if key not in hdr.keys():
                # a partial set cannot describe every channel, so treat the
                # beam information as missing altogether
                if gausspari:
                    print("Incomplete per-channel psf parameters in fits "
                          "file. Ignoring them.")
                gausspari = ()
                break
            emaj = hdr[key]
            emin = hdr['BMIN' + str(i)]
            pa = hdr['BPA' + str(i)]
            gausspari += ((emaj, emin, pa), )
    else:
        # single-channel images store the beam without an index suffix
        if 'BMAJ' in hdr.keys():
            emaj = hdr['BMAJ']
            emin = hdr['BMIN']
            pa = hdr['BPA']
            gausspari = ((emaj, emin, pa), )

    if len(gausspari) == 0 and args.psf_pars is None:
        raise ValueError("No psf parameters in fits file and none passed in.")

    if len(gausspari) == 0:
        print(
            "No psf parameters in fits file. Convolving model to resolution specified by psf-pars."
        )
        gaussparf = tuple(args.psf_pars)
    else:
        if args.psf_pars is None:
            gaussparf = gausspari[0]
        else:
            gaussparf = tuple(args.psf_pars)

    if args.circ_psf:
        # tuples are immutable so rebuild instead of assigning in place
        e = (gaussparf[0] + gaussparf[1]) / 2.0
        gaussparf = (e, e, gaussparf[2])

    print("Using emaj = %3.2e, emin = %3.2e, PA = %3.2e \n" % gaussparf)

    # update header with the target resolution
    if nchan > 1:
        for i in range(1, nchan + 1):
            hdr['BMAJ' + str(i)] = gaussparf[0]
            hdr['BMIN' + str(i)] = gaussparf[1]
            hdr['BPA' + str(i)] = gaussparf[2]
    else:
        hdr['BMAJ'] = gaussparf[0]
        hdr['BMIN'] = gaussparf[1]
        hdr['BPA'] = gaussparf[2]

    # coordinate grid
    xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

    # convolve image
    imagei = load_fits(args.image, dtype=np.float32).squeeze()
    print(imagei.shape)
    image, gausskern = convolve2gaussres(imagei, xx, yy, gaussparf, args.ncpu,
                                         gausspari, args.padding_frac)

    # load beam and correct
    if args.beam_model is not None:
        bhdr = fits.getheader(args.beam_model)
        l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
        l_coord_beam -= ref_lb
        if not np.array_equal(l_coord_beam, l_coord):
            raise ValueError(
                "l coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header."
            )

        m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
        m_coord_beam -= ref_mb
        if not np.array_equal(m_coord_beam, m_coord):
            raise ValueError(
                "m coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header."
            )

        freqs_beam, _ = data_from_header(bhdr, axis=freq_axis)
        if not np.array_equal(freqs, freqs_beam):
            raise ValueError(
                "Freqs of beam model do not match those of image. Use power_beam_maker to interpolate to fits header."
            )

        beam_image = load_fits(args.beam_model, dtype=np.float32).squeeze()

        image = np.where(beam_image >= args.pb_min, image / beam_image, 0.0)

    # save next to the input image if no outfile is provided
    if args.output_filename is None:
        # strip the extension (e.g. .fits) from the image filename; the old
        # slicing kept the trailing '.' and read a non-input attribute
        outfile = os.path.splitext(args.image)[0]
    else:
        outfile = args.output_filename

    # save images
    name = outfile + '.clean_psf.fits'
    save_fits(name, gausskern, hdr)
    print("Wrote clean psf to %s \n" % name)

    name = outfile + '.convolved.fits'
    save_fits(name, image, hdr)
    print("Wrote convolved model to %s \n" % name)

    print("All done here")
Example #6
0
def main(args):
    """Grid the measurement sets in ``args.ms`` and deconvolve them.

    Steps:

    1. Scan the UVW extents and channel frequencies of every measurement set
       to get the Nyquist cell size. Derive the cell size, image size and
       number of bands when they are not given.
    2. Build a ``Gridder`` and produce (or load, when the header matches) the
       PSF (twice the image size in each dimension) and dirty image.
    3. Run major cycles ``i = 1 .. args.maxit - 1``. Each cycle solves
       ``hess(x) = mask * residual`` with ``pcg``, steps the model by
       ``args.gamma * x``, and refines it with ``primal_dual`` using the
       wavelet dictionary ``psi`` and weights ``weights_21``. The weights
       are reweighted on the iterations in ``reweight_iters``.
    4. Optionally write the model back and make a restored image.

    Intermediate products are written as FITS files prefixed with
    ``args.outfile``.

    Parameters
    ----------
    args : object
        Parsed command-line options (presumably an ``argparse.Namespace``).
        ``cell_size``, ``nx``, ``ny``, ``nband`` and
        ``reweight_alpha_percent`` may be overwritten in place.

    Raises
    ------
    ValueError
        If the mask image (``args.mask``) is not ``(nx, ny)``.
    """
    # get max uv coords and channel frequencies over all fields
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),  # tuple, not a bare string
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())

            # NOTE(review): DATA_DESC_ID is used directly as the spectral
            # window index, which assumes a one-to-one DATA_DESCRIPTION ->
            # SPECTRAL_WINDOW mapping - TODO confirm.
            spw = spws[ds.DATA_DESC_ID]
            all_freqs.append(spw.CHAN_FREQ.data)

    # the Nyquist limit is set by the larger of the u and v extents
    # (previously only u_max was used, ignoring v_max)
    uv_max = max(dask.compute(u_max, v_max))

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)[0]
    # flatten first so spws with different channel counts are supported
    freq = np.unique(np.concatenate([np.ravel(f) for f in all_freqs]))
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                data_column=args.data_column,
                weight_column=args.weight_column,
                epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf - reuse the file only if its header matches, otherwise recompute
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf_array = load_fits(args.psf)
        except Exception as e:
            print("Unable to use psf from %s (%s). Recomputing." %
                  (args.psf, e))
            psf_array = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts  # normalisation for more intuitive sig_21 values
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    psf_max[psf_max < 1e-15] = 1e-15  # LB - is this the right thing to do?

    # the psf is 2nx x 2ny; keep the central nx x ny for the MFS image
    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # dirty - reuse the file only if its header matches, otherwise recompute
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty)
        except Exception as e:
            print("Unable to use dirty image from %s (%s). Recomputing." %
                  (args.dirty, e))
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    # starting model; on any failure fall back to an empty model
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=np.float64)
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual, dtype=np.float64)
                except Exception as e:
                    print("Unable to use residual from %s (%s). Recomputing." %
                          (args.first_residual, e))
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits', residual,
                              hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
        except Exception as e:
            print("Unable to use initial model from %s (%s). "
                  "Starting from zero." % (args.x0, e))
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # mask (2D image broadcast over bands)
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    #  preconditioning matrix
    def hess(x):
        return mask * psf.convolve(mask * x) + x / args.sig_l2**2

    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess,
                            dirty.shape,
                            tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    if args.psi_basis is None:
        print("Using Dirac + db1-4 dictionary")
        psi = DaskPSI(args.nband,
                      args.nx,
                      args.ny,
                      nlevels=args.psi_levels,
                      nthreads=args.nthreads)
    else:
        if not isinstance(args.psi_basis, list):
            args.psi_basis = list(args.psi_basis)
        print("Using ", args.psi_basis, " dictionary")
        psi = DaskPSI(args.nband,
                      args.nx,
                      args.ny,
                      nlevels=args.psi_levels,
                      nthreads=args.nthreads,
                      bases=args.psi_basis)
    weights_21 = np.ones((psi.nbasis, psi.nmax), dtype=np.float64)
    dual = np.zeros((psi.nbasis, args.nband, psi.nmax), dtype=np.float64)

    # Reweighting
    if args.reweight_iters is not None:
        if not isinstance(args.reweight_iters, list):
            reweight_iters = [args.reweight_iters]
        else:
            reweight_iters = list(args.reweight_iters)
    else:
        reweight_iters = list(
            np.arange(args.reweight_start, args.reweight_end,
                      args.reweight_freq))
        reweight_iters.append(args.reweight_end)

    # Reporting (always include the final iteration)
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if not report_iters or report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    def M(x):
        """Preconditioner: inverts the ``x / sig_l2**2`` term of ``hess``."""
        return x * args.sig_l2**2

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess,
                mask * residual,
                np.zeros(dirty.shape, dtype=np.float64),
                M=M,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                minit=args.cgminit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_update.fits', x, hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21,
                                  psi,
                                  weights_21,
                                  beta,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  report_freq=100,
                                  mask=mask,
                                  positivity=args.positivity)

        # reweighting
        if i in reweight_iters:
            v = psi.hdot(model)
            l2_norm = norm(v, axis=1)
            l2_norm = np.where(l2_norm < args.sig_21 * weights_21, 0.0,
                               l2_norm)
            for m in range(psi.nbasis):
                l2_m = l2_norm[m]
                nonzero = l2_m[l2_m != 0]
                if nonzero.size:
                    alpha = np.percentile(nonzero,
                                          args.reweight_alpha_percent)
                    alpha = np.maximum(alpha, args.reweight_alpha_min)
                else:
                    # the percentile of an empty set is undefined, so fall
                    # back to the floor value
                    alpha = args.reweight_alpha_min
                print("Reweighting - ", m, alpha)
                weights_21[m] = alpha / (l2_norm[m] + alpha)
            args.reweight_alpha_percent *= args.reweight_alpha_ff

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # convergence diagnostics (reported only; there is no early stop)
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

    if args.write_model:
        R.write_model(model)

    if args.make_restored:
        x = pcg(hess,
                residual,
                np.zeros(dirty.shape, dtype=np.float64),
                M=M,
                tol=args.cgtol,
                maxit=args.cgmaxit)
        restored = model + x

        # get residual
        residual = R.make_residual(restored) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After restoring peak of residual is %f and rms is %f" %
              (rmax, rms))

        # save current iteration
        save_fits(args.outfile + '_restored.fits', restored, hdr)

        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)

        save_fits(args.outfile + '_restored_residual.fits', residual, hdr)

        save_fits(args.outfile + '_restored_residual_mfs.fits', residual_mfs,
                  hdr_mfs)
Example #7
0
def _reweight_alpha(l2_norm, percent, alpha_min):
    """Return the offset used when recomputing l21 weights for one norm set.

    The offset is the ``percent``-th percentile of the nonzero entries of
    ``l2_norm``, floored at ``alpha_min``. If every entry is zero there is
    nothing to take a percentile of, so ``alpha_min`` is returned directly.
    """
    nonzero = l2_norm[l2_norm != 0]
    if nonzero.size == 0:
        return alpha_min
    return np.maximum(np.percentile(nonzero, percent), alpha_min)


def main(args):
    """Deconvolve a dirty image cube using an l2 prior and a weighted l21 prox.

    Reads the dirty image (``args.dirty``) and PSF (``args.psf``) from FITS,
    then runs up to ``args.maxit`` iterations. Each iteration does two steps:

      1. It solves ``pcg(psf.hess, residual)`` with ``K.dot`` (a ``Prior``
         built from ``args.sig_l2``) as preconditioner, giving an update ``x``.
      2. It applies a proximal step to ``model + args.gamma * x``. This is
         ``simple_pd`` over the ``PSI`` dictionary when ``args.use_psi`` is
         set, and ``prox_21`` with positivity otherwise.

    Iteration stops early once the relative model change drops below
    ``args.tol``. The l21 weights are recomputed at the reweighting
    iterations. Model, update and residual products are written at the
    reporting iterations. At the end, the model and the residual divided
    per band by the PSF peak are written under ``args.outfile``.

    Side effects: prints progress, writes FITS files, and multiplies
    ``args.reweight_alpha_percent`` by ``args.reweight_alpha_ff`` at every
    reweighting step.

    Raises:
        ValueError: if the dirty image and PSF frequency axes differ.
    """
    # load dirty and psf
    dirty = load_fits(args.dirty)
    real_type = dirty.dtype
    hdr = fits.getheader(args.dirty)
    freq = data_from_header(hdr, axis=3)
    l_coord = data_from_header(hdr, axis=1)
    m_coord = data_from_header(hdr, axis=2)

    nband, nx, ny = dirty.shape
    psf_array = load_fits(args.psf)
    hdr_psf = fits.getheader(args.psf)
    # Use an explicit check: an assert is stripped under `python -O`, and the
    # previous bare except also turned unrelated header errors into this one.
    if not np.array_equal(freq, data_from_header(hdr_psf, axis=3)):
        raise ValueError("Fits frequency axes dont match")

    print("Image size is (%i, %i, %i)" % (nband, nx, ny))

    # Per-band PSF peak. The reshape requires 4*nx*ny pixels per band
    # (presumably a PSF twice the image size along each axis; TODO confirm).
    psf_max = np.amax(psf_array.reshape(nband, 4 * nx * ny), axis=1)
    wsum = np.sum(psf_max)
    # Clip only after wsum is taken, so near-empty bands cannot cause a
    # division by zero in the final residual normalisation.
    psf_max[psf_max < 1e-15] = 1e-15

    dirty_mfs = np.sum(dirty, axis=0) / wsum
    rmax = np.abs(dirty_mfs).max()
    rms = np.std(dirty_mfs)
    print("Peak of dirty is %f and rms is %f" % (rmax, rms))

    psf_mfs = np.sum(psf_array, axis=0) / wsum

    # set operators
    psf = PSF(psf_array, args.ncpu, sigma0=args.sig_l2)
    K = Prior(args.sig_l2, nband, nx, ny, nthreads=args.ncpu)

    # get Lipschitz constant
    if args.beta is None:
        from pfb.opt import power_method
        beta = power_method(psf.hess,
                            dirty.shape,
                            tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print("beta = ", beta)

    # Reweighting
    if args.reweight_iters is not None:
        reweight_iters = args.reweight_iters
    else:
        reweight_iters = list(
            np.arange(args.reweight_start, args.reweight_end,
                      args.reweight_freq))
        reweight_iters.append(args.reweight_end)

    # Reporting. Always include the last iteration. The list is empty when
    # maxit <= 0, so guard before indexing it.
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters and report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # set up wavelet basis
    if args.use_psi:
        psi = PSI(nband, nx, ny, nlevels=args.psi_levels)
        weights_21 = np.ones((psi.nbasis, psi.nmax), dtype=real_type)
    else:
        psi = None
        weights_21 = np.ones(nx * ny, dtype=real_type)

    # The dual variable is only consumed by simple_pd. It can only be sized
    # when a wavelet basis exists; psi is None otherwise.
    if psi is not None:
        dual = np.zeros((psi.nbasis, nband, psi.nmax), dtype=real_type)
    else:
        dual = None

    # initialise model
    if args.x0 is None:
        model = np.zeros(dirty.shape, dtype=real_type)
        residual = dirty
    else:
        compare_headers(hdr, fits.getheader(args.x0))
        model = load_fits(args.x0).astype(real_type)
        residual = dirty - psf.convolve(model)

    residual_mfs = np.sum(residual, axis=0) / wsum
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    print("At iteration 0 peak of residual is %f and rms is %f" % (rmax, rms))

    # deconvolve
    for k in range(args.maxit):
        x = pcg(psf.hess,
                residual,
                np.zeros(dirty.shape, dtype=real_type),
                M=K.dot,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        modelp = model.copy()
        model = modelp + args.gamma * x

        if args.use_psi:
            model, dual = simple_pd(psf.hess,
                                    model,
                                    modelp,
                                    dual,
                                    args.sig_21,
                                    psi,
                                    weights_21,
                                    beta,
                                    tol=args.pdtol,
                                    maxit=args.pdmaxit,
                                    report_freq=10)
        else:
            model = prox_21(model,
                            args.sig_21,
                            weights_21,
                            psi=psi,
                            positivity=True)

        # convergence check: relative change, with a guard for a zero/NaN norm
        normx = norm(model)
        if np.isnan(normx) or normx == 0.0:
            normx = 1.0

        eps = norm(model - modelp) / normx
        if eps < args.tol:
            break

        # reweighting
        if k in reweight_iters:
            if psi is None:
                # One l2 norm per pixel, taken across bands. This branch
                # previously used the undefined names `npix` and `alpha`.
                l2_norm = norm(model.reshape(nband, nx * ny), axis=0)
                alpha = _reweight_alpha(l2_norm,
                                        args.reweight_alpha_percent,
                                        args.reweight_alpha_min)
                print("Reweighting - ", alpha)
                weights_21 = 1.0 / (l2_norm + alpha)
            else:
                v = psi.hdot(model)
                l2_norm = norm(v, axis=1)
                for m in range(psi.nbasis):
                    alpha = _reweight_alpha(l2_norm[m],
                                            args.reweight_alpha_percent,
                                            args.reweight_alpha_min)
                    print("Reweighting - ", m, alpha)
                    weights_21[m] = 1.0 / (l2_norm[m] + alpha)
            # Decay the percentile used for alpha after every reweighting.
            args.reweight_alpha_percent *= args.reweight_alpha_ff
            print(" reweight alpha percent = ",
                  args.reweight_alpha_percent)

        # get residual
        residual = dirty - psf.convolve(model)

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        # reporting
        if k in report_iters:
            save_fits(args.outfile + str(k + 1) + '_model.fits',
                      model,
                      hdr,
                      dtype=real_type)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(k + 1) + '_model_mfs.fits', model_mfs,
                      hdr)

            save_fits(args.outfile + str(k + 1) + '_update.fits', x, hdr)

            save_fits(args.outfile + str(k + 1) + '_residual.fits',
                      residual,
                      hdr,
                      dtype=real_type)

            save_fits(args.outfile + str(k + 1) + '_residual_mfs.fits',
                      residual_mfs, hdr)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (k + 1, rmax, rms, eps))

    # save final results
    save_fits(args.outfile + '_model.fits', model, hdr, dtype=real_type)

    residual = dirty - psf.convolve(model)

    save_fits(args.outfile + '_residual.fits',
              residual / psf_max[:, None, None], hdr)