Exemplo n.º 1
0
    def __init__(self, psf, imsize, nthreads=1, backward_undersize=None):
        """Precompute padding information and PSF transforms for convolution.

        Parameters
        ----------
        psf : array
            PSF cube, unpacked as ``(nband, nx_psf, ny_psf)``.
        imsize : tuple
            Image shape, unpacked as ``(_, nx, ny)``.
        nthreads : int
            Number of threads passed to the r2c transforms.
        backward_undersize : float or None
            If given, the backward operator uses a smaller padded grid of
            ``good_size(int(backward_undersize * n))`` pixels per spatial
            axis, with the PSF cropped about its centre to match. Otherwise
            all backward (``*b``) attributes alias the forward ones.
        """
        self.nthreads = nthreads
        self.nband, nx_psf, ny_psf = psf.shape
        _, nx, ny = imsize
        self.ax = (1, 2)

        # Pad the image up to the PSF grid; any odd excess goes on the right.
        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        self.padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        # Explicit stop indices: slice(a, -0) is empty, which would happen
        # whenever the right pad is zero (PSF grid the same size as image).
        self.unpad_x = slice(npad_xl, npad_xl + nx)
        self.unpad_y = slice(npad_yl, npad_yl + ny)
        self.lastsize = ny + np.sum(self.padding[-1])
        self.psf = psf
        psf_pad = iFs(psf, axes=self.ax)
        self.psfhat = r2c(psf_pad,
                          axes=self.ax,
                          forward=True,
                          nthreads=nthreads,
                          inorm=0)

        if backward_undersize is not None:
            # set up a smaller grid for the backward step
            nx_psfb = good_size(int(backward_undersize * nx))
            ny_psfb = good_size(int(backward_undersize * ny))
            npad_xlb = (nx_psfb - nx) // 2
            npad_xrb = nx_psfb - nx - npad_xlb
            npad_ylb = (ny_psfb - ny) // 2
            npad_yrb = ny_psfb - ny - npad_ylb
            self.paddingb = ((0, 0), (npad_xlb, npad_xrb), (npad_ylb,
                                                            npad_yrb))
            self.unpad_xb = slice(npad_xlb, npad_xlb + nx)
            self.unpad_yb = slice(npad_ylb, npad_ylb + ny)
            self.lastsizeb = ny + np.sum(self.paddingb[-1])

            # Crop the PSF symmetrically down to the backward grid size.
            xlb = (nx_psf - nx_psfb) // 2
            ylb = (ny_psf - ny_psfb) // 2
            psf_padb = iFs(psf[:, xlb:xlb + nx_psfb, ylb:ylb + ny_psfb],
                           axes=self.ax)
            self.psfhatb = r2c(psf_padb,
                               axes=self.ax,
                               forward=True,
                               nthreads=nthreads,
                               inorm=0)
        else:
            self.paddingb = self.padding
            self.unpad_xb = self.unpad_x
            self.unpad_yb = self.unpad_y
            self.lastsizeb = self.lastsize
            self.psfhatb = self.psfhat
Exemplo n.º 2
0
def _hessian(x, psfhat, padding, nthreads, unpad_x, unpad_y, lastsize):
    """Apply a PSF convolution to ``x`` using real FFTs over axes 1 and 2.

    The steps are:

    1. Zero-pad ``x`` by ``padding``.
    2. Apply ``iFs`` and a forward r2c transform.
    3. Multiply by ``psfhat``.
    4. Transform back with c2r, using ``lastsize`` as the last-axis length.
    5. Apply ``Fs``.
    6. Crop with ``unpad_x`` / ``unpad_y``.
    """
    axes = (1, 2)
    padded = np.pad(x, padding, mode='constant')
    spectrum = r2c(iFs(padded, axes=axes),
                   axes=axes,
                   nthreads=nthreads,
                   forward=True,
                   inorm=0)
    product = spectrum * psfhat
    image = c2r(product,
                axes=axes,
                forward=False,
                lastsize=lastsize,
                inorm=2,
                nthreads=nthreads)
    centred = Fs(image, axes=axes)
    return centred[:, unpad_x, unpad_y]
Exemplo n.º 3
0
def convolve2gaussres(image, xx, yy, gaussparf, nthreads, gausspari=None, pfrac=0.5, norm_kernel=False):
    """
    Convolve an image cube to a specified Gaussian resolution.

    Parameters
    ----------
    image - (nband, nx, ny) array to convolve.
    xx/yy - coordinates on the grid in the same units as gaussparf.
    gaussparf - tuple containing Gaussian parameters of desired resolution (emaj, emin, pa).
    nthreads - number of threads to use for the FFTs.
    gausspari - initial resolution. By default it is assumed that the image is a clean
                component image with no associated resolution. If gausspari is specified,
                it must contain one gausspars tuple (same format as gaussparf) per imaging band.
    pfrac - padding used for the FFT based convolution. Will pad by pfrac/2 on both sides of image.
    norm_kernel - passed as ``normalise`` to Gaussian2D when building the kernels.

    Returns
    -------
    image - the convolved (nband, nx, ny) cube.
    gausskern - the target-resolution kernel cropped back to shape (1, nx, ny).
    """
    nband, nx, ny = image.shape
    padding, unpad_x, unpad_y = get_padding_info(nx, ny, pfrac)
    ax = (1, 2)  # axes over which to perform fft
    lastsize = ny + np.sum(padding[-1])

    gausskern = Gaussian2D(xx, yy, gaussparf, normalise=norm_kernel)
    gausskern = np.pad(gausskern[None], padding, mode='constant')
    gausskernhat = r2c(iFs(gausskern, axes=ax), axes=ax, forward=True, nthreads=nthreads, inorm=0)

    image = np.pad(image, padding, mode='constant')
    imhat = r2c(iFs(image, axes=ax), axes=ax, forward=True, nthreads=nthreads, inorm=0)

    # convolve to desired resolution
    if gausspari is None:
        imhat *= gausskernhat
    else:
        for i in range(nband):
            thiskern = Gaussian2D(xx, yy, gausspari[i], normalise=norm_kernel)
            thiskern = np.pad(thiskern[None], padding, mode='constant')
            thiskernhat = r2c(iFs(thiskern, axes=ax), axes=ax, forward=True, nthreads=nthreads, inorm=0)

            # Divide only where the initial kernel is non-zero. np.where would
            # evaluate the full division first, emitting divide-by-zero
            # warnings and NaNs that are then discarded.
            convkernhat = np.zeros_like(gausskernhat)
            np.divide(gausskernhat, thiskernhat, out=convkernhat,
                      where=np.abs(thiskernhat) > 0.0)

            imhat[i] *= convkernhat[0]

    image = Fs(c2r(imhat, axes=ax, forward=False, lastsize=lastsize, inorm=2, nthreads=nthreads), axes=ax)[:, unpad_x, unpad_y]

    return image, gausskern[:, unpad_x, unpad_y]
Exemplo n.º 4
0
    def __init__(self, sigma0, nband, nx, ny, nthreads=8):
        """Set up the prior covariance operator for an (nband, nx, ny) cube.

        The constructor does the following:

        - Builds ``K`` with ``make_kernel`` on a (2*nx, 2*ny) grid.
        - Computes its r2c transform ``Khat`` and a thresholded inverse
          ``Khatinv``.
        - Builds per-axis factors with ``expsq`` (and their pseudo-inverses)
          for Kronecker matrix-vector products.

        Parameters
        ----------
        sigma0 : float
            Amplitude; the l-axis factor carries ``sigma0**2`` (assuming
            ``expsq``'s third argument is an amplitude — TODO confirm).
        nband, nx, ny : int
            Cube dimensions.
        nthreads : int
            Threads passed to the r2c transform.
        """
        self.nthreads = nthreads
        self.nx = nx
        self.ny = ny
        nx_psf = 2 * self.nx
        ny_psf = 2 * self.ny
        # Asymmetric split so the padded grid is exactly (nx_psf, ny_psf) and
        # matches Khat even for odd nx/ny (identical to a symmetric split for
        # even sizes).
        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        self.padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        self.ax = (1, 2)

        self.unpad_x = slice(npad_xl, npad_xl + nx)
        self.unpad_y = slice(npad_yl, npad_yl + ny)
        self.lastsize = ny + np.sum(self.padding[-1])

        # set length scales
        length_scale = 0.5

        K = make_kernel(nx_psf, ny_psf, sigma0, length_scale)

        self.K = K
        K_pad = iFs(self.K, axes=self.ax)
        self.Khat = r2c(K_pad,
                        axes=self.ax,
                        forward=True,
                        nthreads=nthreads,
                        inorm=0)
        # 1/Khat where Re(Khat) > 1e-14, else 1e-14; computed without dividing
        # by the small/zero entries.
        self.Khatinv = np.full_like(self.Khat, 1e-14)
        np.divide(1.0, self.Khat, out=self.Khatinv,
                  where=self.Khat.real > 1e-14)

        # get covariance in each dimension
        # pixel coordinates (exactly n points per axis, also for odd n)
        self.Kv = mock_array(nband)  # np.eye(nband) * sigma0**2
        self.Kvinv = mock_array(nband)  # np.eye(nband) / sigma0**2
        l_coord = np.arange(-(nx // 2), nx - nx // 2)
        if nx == ny:
            Kunit = expsq(l_coord, l_coord, 1.0, length_scale)
            Kunitinv = np.linalg.pinv(Kunit, hermitian=True, rcond=1e-12)
            # Only the l factor carries sigma0**2, as in the nx != ny branch.
            # Kl/Km must be distinct arrays; previously they aliased, so an
            # in-place scale applied sigma0**2 to both factors.
            self.Kl = Kunit * sigma0**2
            self.Km = Kunit
            self.Klinv = Kunitinv / sigma0**2
            self.Kminv = Kunitinv
        else:
            m_coord = np.arange(-(ny // 2), ny - ny // 2)

            self.Kl = expsq(l_coord, l_coord, sigma0, length_scale)
            self.Km = expsq(m_coord, m_coord, 1.0, length_scale)
            self.Klinv = np.linalg.pinv(self.Kl, hermitian=True, rcond=1e-12)
            self.Kminv = np.linalg.pinv(self.Km, hermitian=True, rcond=1e-12)

        # Kronecker matrices for "fast" matrix vector products
        self.Kkron = (self.Kv, self.Kl, self.Km)
        self.Kinvkron = (self.Kvinv, self.Klinv, self.Kminv)
Exemplo n.º 5
0
 def convolveb(self, x):
     """Convolve ``x`` with the backward-step PSF.

     Uses the ``paddingb``, ``psfhatb``, ``lastsizeb``, ``unpad_xb`` and
     ``unpad_yb`` attributes of ``self``. ``x`` is zero-padded, transformed
     with r2c over ``self.ax``, multiplied by ``psfhatb``, transformed back
     with c2r, shifted with ``Fs`` and cropped.
     """
     axes = self.ax
     workers = self.nthreads
     padded = np.pad(x, self.paddingb, mode='constant')
     spectrum = r2c(iFs(padded, axes=axes),
                    axes=axes,
                    nthreads=workers,
                    forward=True,
                    inorm=0)
     image = c2r(spectrum * self.psfhatb,
                 axes=axes,
                 forward=False,
                 lastsize=self.lastsizeb,
                 inorm=2,
                 nthreads=workers)
     centred = Fs(image, axes=axes)
     return centred[:, self.unpad_xb, self.unpad_yb]
Exemplo n.º 6
0
 def __init__(self, psf, nthreads, sigma0=1.0):
     """Precompute padding and the PSF transform for an image half the PSF size.

     Parameters
     ----------
     psf : array
         PSF cube, unpacked as ``(nband, nx_psf, ny_psf)``. The image size is
         taken to be ``(nx_psf // 2, ny_psf // 2)``.
     nthreads : int
         Threads passed to the r2c transform.
     sigma0 : float
         Not used by this constructor.
     """
     self.nthreads = nthreads
     self.nband, nx_psf, ny_psf = psf.shape
     nx = nx_psf // 2
     ny = ny_psf // 2
     # Asymmetric split so the padded image is exactly (nx_psf, ny_psf) and
     # matches psfhat; a symmetric (n_psf - n)//2 split falls one pixel short
     # whenever n_psf - n is odd.
     npad_xl = (nx_psf - nx) // 2
     npad_xr = nx_psf - nx - npad_xl
     npad_yl = (ny_psf - ny) // 2
     npad_yr = ny_psf - ny - npad_yl
     self.padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
     self.ax = (1, 2)
     self.unpad_x = slice(npad_xl, npad_xl + nx)
     self.unpad_y = slice(npad_yl, npad_yl + ny)
     self.lastsize = ny + np.sum(self.padding[-1])
     self.psf = psf
     psf_pad = iFs(psf, axes=self.ax)
     self.psfhat = r2c(psf_pad,
                       axes=self.ax,
                       forward=True,
                       nthreads=nthreads,
                       inorm=0)
Exemplo n.º 7
0
def _forward(**kw):
    """Solve for a model update from a residual image and PSF.

    ``kw`` is wrapped in a struct-mode OmegaConf config (``args``). The
    residual and PSF are loaded from FITS and divided by the sum of the
    per-band PSF peaks.

    - If ``args.weight_table`` is set, the update is solved with ``pcg_wgt``
      (visibility-space Hessian approximation).
    - Otherwise it is solved with ``pcg_psf`` (image-space approximation).

    Writes ``<output_filename>_residual_mfs.fits``, ``_update.fits`` and
    ``_update_mfs.fits``.

    Raises
    ------
    NotImplementedError
        If the two cell sizes in the header differ.
    ValueError
        If a FITS beam model's coordinates or frequencies do not match the
        image, or an unknown band is requested for the JimBeam model.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading residual", file=log)
    residual = load_fits(args.residual, dtype=args.output_type).squeeze()
    nband, nx, ny = residual.shape
    hdr = fits.getheader(args.residual)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    # per-band PSF peaks; their sum normalises PSF and residual
    wsums = np.amax(psf.reshape(-1, nx_psf*ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # abs(): a peak well below one is as wrong as one above it
    assert np.abs(psf_mfs.max() - 1.0) < 1e-4

    residual /= wsum
    residual_mfs = np.sum(residual, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    l_coord, ref_l = data_from_header(hdr, axis=1)
    l_coord -= ref_l
    m_coord, ref_m = data_from_header(hdr, axis=2)
    m_coord -= ref_m
    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs, hdr_mfs,
              dtype=args.output_type)

    rms = np.std(residual_mfs)
    rmax = np.abs(residual_mfs).max()

    print("Initial peak residual = %f, rms = %f" % (rmax, rms), file=log)

    # load beam
    if args.beam_model is not None:
        if args.beam_model.endswith('.fits'):  # beam already interpolated
            bhdr = fits.getheader(args.beam_model)
            l_coord_beam, ref_lb = data_from_header(bhdr, axis=1)
            l_coord_beam -= ref_lb
            if not np.array_equal(l_coord_beam, l_coord):
                raise ValueError("l coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header.")

            m_coord_beam, ref_mb = data_from_header(bhdr, axis=2)
            m_coord_beam -= ref_mb
            if not np.array_equal(m_coord_beam, m_coord):
                raise ValueError("m coordinates of beam model do not match those of image. Use power_beam_maker to interpolate to fits header.")

            # frequency is axis 3, as for the image header above
            freq_beam, _ = data_from_header(bhdr, axis=3)
            if not np.array_equal(freq_out, freq_beam):
                raise ValueError("Freqs of beam model do not match those of image. Use power_beam_maker to interpolate to fits header.")

            beam_image = load_fits(args.beam_model, dtype=args.output_type).squeeze()
        elif args.beam_model.lower() == "jimbeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            elif args.band.lower() == 'uhf':
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            else:
                raise ValueError("Unknown band %s" % args.band)

            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')
            beam_image = np.zeros(residual.shape, dtype=args.output_type)
            for v in range(freq_out.size):
                # freq must be in MHz
                beam_image[v] = beam.I(xx, yy, freq_out[v]/1e6).astype(args.output_type)
    else:
        beam_image = np.ones((nband, nx, ny), dtype=args.output_type)

    if args.mask is not None:
        mask = load_fits(args.mask).squeeze()
        assert mask.shape == (nx, ny)
        beam_image *= mask[None, :, :]

    beam_image = da.from_array(beam_image, chunks=(1, -1, -1))

    # if weight table is provided we use the vis space Hessian approximation
    if args.weight_table is not None:
        print("Solving for update using vis space approximation", file=log)
        normfact = wsum
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freq = xds.chan.data
        nchan = freq.size

        # bin edges
        fmin = freq.min()
        fmax = freq.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_map = np.zeros(freq.size, dtype=np.int32)
        for band in range(nband):
            indl = freq >= fbins[band]
            indu = freq < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freq = da.from_array(freq, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit/args.nworkers, band_size, nrow,
                                       memory_per_row, args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row, args.nthreads_per_worker)

        print("nrows = %i, row chunks set to %i for a total of %i chunks per node" %
              (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1), dtype=residual.dtype)

        xds = xds_from_zarr(args.weight_table, chunks={'row': -1, #row_chunk,
                            'chan': bin_counts})[0]

        from pfb.opt.pcg import pcg_wgt

        model = pcg_wgt(xds.UVW.data,
                        xds.WEIGHT.data.astype(args.output_type),
                        residual,
                        x0,
                        beam_image,
                        freq,
                        freq_bin_idx,
                        freq_bin_counts,
                        cell_rad,
                        args.wstack,
                        args.epsilon,
                        args.double_accum,
                        args.nvthreads,
                        args.sigmainv,
                        wsum,
                        args.cg_tol,
                        args.cg_maxit,
                        args.cg_minit,
                        args.cg_verbose,
                        args.cg_report_freq,
                        args.backtrack).compute()

    else:  # we use the image space approximation
        print("Solving for update using image space approximation", file=log)
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx)//2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny)//2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        # explicit stops: slice(a, -0) would be empty when a right pad is 0
        unpad_x = slice(npad_xl, npad_xl + nx)
        unpad_y = slice(npad_yl, npad_yl + ny)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        # `nthreads` is not defined in this function; use the same thread
        # count that is passed to pcg_psf below.
        psfhat = r2c(psf_pad, axes=(1, 2), forward=True,
                     nthreads=args.nvthreads, inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))
        residual = da.from_array(residual, chunks=(1, -1, -1))
        x0 = da.zeros((nband, nx, ny), chunks=(1, -1, -1), dtype=residual.dtype)

        from pfb.opt.pcg import pcg_psf

        model = pcg_psf(psfhat,
                        residual,
                        x0,
                        beam_image,
                        args.sigmainv,
                        args.nvthreads,
                        padding,
                        unpad_x,
                        unpad_y,
                        lastsize,
                        args.cg_tol,
                        args.cg_maxit,
                        args.cg_minit,
                        args.cg_verbose,
                        args.cg_report_freq,
                        args.backtrack).compute()

    print("Saving results", file=log)
    save_fits(args.output_filename + '_update.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_update_mfs.fits', model_mfs, hdr_mfs)

    print("All done here.", file=log)
Exemplo n.º 8
0
def rfftn(a, axes=None, inorm=0, nthreads=1):
    """Forward real-to-complex FFT of ``a`` over ``axes`` via ``fft.r2c``.

    ``inorm`` and ``nthreads`` are passed through unchanged.
    """
    spectrum = fft.r2c(a,
                       axes=axes,
                       forward=True,
                       inorm=inorm,
                       nthreads=nthreads)
    return spectrum
Exemplo n.º 9
0
def _clean(**kw):
    """Run major/minor-cycle Hogbom clean on a dirty image.

    ``kw`` is wrapped in a struct-mode OmegaConf config (``args``). The dirty
    image and PSF are loaded from FITS and divided by the sum of the per-band
    PSF peaks.

    Each of ``args.nmiter`` major iterations:

    1. Runs ``hogbom`` on the residual.
    2. Adds the result to the model.
    3. Recomputes ``residual = dirty - convolver(model) / normfact``.

    ``convolver`` uses the visibility-space Hessian when
    ``args.weight_table`` is set, else the PSF-based Hessian.

    Writes the model, MFS model, per-band residual and MFS residual to FITS,
    plus a graph PDF and performance report per iteration.

    Raises
    ------
    NotImplementedError
        If the two cell sizes in the header differ.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)

    import numpy as np
    import numexpr as ne
    import dask
    import dask.array as da
    from dask.distributed import performance_report
    from pfb.utils.fits import load_fits, set_wcs, save_fits, data_from_header
    from pfb.opt.hogbom import hogbom
    from astropy.io import fits

    print("Loading dirty", file=log)
    dirty = load_fits(args.dirty, dtype=args.output_type).squeeze()
    nband, nx, ny = dirty.shape
    hdr = fits.getheader(args.dirty)

    print("Loading psf", file=log)
    psf = load_fits(args.psf, dtype=args.output_type).squeeze()
    _, nx_psf, ny_psf = psf.shape
    hdr_psf = fits.getheader(args.psf)

    # per-band PSF peaks; their sum normalises PSF and dirty image
    wsums = np.amax(psf.reshape(-1, nx_psf * ny_psf), axis=1)
    wsum = np.sum(wsums)

    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # abs(): a peak well below one is as wrong as one above it
    assert np.abs(psf_mfs.max() - 1.0) < 1e-4

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)

    # get info required to set WCS
    ra = np.deg2rad(hdr['CRVAL1'])
    dec = np.deg2rad(hdr['CRVAL2'])
    radec = [ra, dec]

    cell_deg = np.abs(hdr['CDELT1'])
    if cell_deg != np.abs(hdr['CDELT2']):
        raise NotImplementedError('cell sizes have to be equal')
    cell_rad = np.deg2rad(cell_deg)

    freq_out, ref_freq = data_from_header(hdr, axis=3)

    hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)

    save_fits(args.output_filename + '_dirty_mfs.fits',
              dirty_mfs,
              hdr_mfs,
              dtype=args.output_type)

    # set up Hessian approximation
    if args.weight_table is not None:
        normfact = wsum
        from africanus.gridding.wgridder.dask import hessian
        from pfb.utils.misc import plan_row_chunk
        from daskms.experimental.zarr import xds_from_zarr

        xds = xds_from_zarr(args.weight_table)[0]
        nrow = xds.row.size
        freqs = xds.chan.data
        nchan = freqs.size

        # bin edges
        fmin = freqs.min()
        fmax = freqs.max()
        fbins = np.linspace(fmin, fmax, nband + 1)

        # chan <-> band mapping
        band_map = np.zeros(freqs.size, dtype=np.int32)
        for band in range(nband):
            indl = freqs >= fbins[band]
            indu = freqs < fbins[band + 1] + 1e-6
            band_map = np.where(indl & indu, band, band_map)

        # to dask arrays
        bands, bin_counts = np.unique(band_map, return_counts=True)
        band_mapping = tuple(bands)
        chan_chunks = {'chan': tuple(bin_counts)}
        freqs = da.from_array(freqs, chunks=tuple(bin_counts))
        bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
        freq_bin_idx = da.from_array(bin_idx, chunks=1)
        freq_bin_counts = da.from_array(bin_counts, chunks=1)

        max_chan_chunk = bin_counts.max()
        bin_counts = tuple(bin_counts)
        # the first factor of 3 accounts for the intermediate visibilities
        # produced in Hessian (i.e. complex data + real weights)
        memory_per_row = (3 * max_chan_chunk * xds.WEIGHT.data.itemsize +
                          3 * xds.UVW.data.itemsize)

        # get approx image size
        pixel_bytes = np.dtype(args.output_type).itemsize
        band_size = nx * ny * pixel_bytes

        if args.host_address is None:
            # nworker bands on single node
            row_chunk = plan_row_chunk(args.mem_limit / args.nworkers,
                                       band_size, nrow, memory_per_row,
                                       args.nthreads_per_worker)
        else:
            # single band per node
            row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                       memory_per_row,
                                       args.nthreads_per_worker)

        print(
            "nrows = %i, row chunks set to %i for a total of %i chunks per node"
            % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
            file=log)

        def convolver(x):
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            xds = xds_from_zarr(args.weight_table,
                                chunks={
                                    'row': row_chunk,
                                    'chan': bin_counts
                                })[0]

            convolvedim = hessian(xds.UVW.data,
                                  freqs,
                                  model,
                                  freq_bin_idx,
                                  freq_bin_counts,
                                  cell_rad,
                                  weights=xds.WEIGHT.data.astype(
                                      args.output_type),
                                  nthreads=args.nvthreads,
                                  epsilon=args.epsilon,
                                  do_wstacking=args.wstack,
                                  double_accum=args.double_accum)
            return convolvedim
    else:
        normfact = 1.0
        from pfb.operators.psf import hessian
        from ducc0.fft import r2c
        iFs = np.fft.ifftshift

        npad_xl = (nx_psf - nx) // 2
        npad_xr = nx_psf - nx - npad_xl
        npad_yl = (ny_psf - ny) // 2
        npad_yr = ny_psf - ny - npad_yl
        padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr))
        # explicit stops: slice(a, -0) would be empty when a right pad is 0
        unpad_x = slice(npad_xl, npad_xl + nx)
        unpad_y = slice(npad_yl, npad_yl + ny)
        lastsize = ny + np.sum(padding[-1])
        psf_pad = iFs(psf, axes=(1, 2))
        # `nthreads`/`nvthreads` are not defined in this function; use the
        # configured args.nvthreads as the weight-table branch does.
        psfhat = r2c(psf_pad,
                     axes=(1, 2),
                     forward=True,
                     nthreads=args.nvthreads,
                     inorm=0)

        psfhat = da.from_array(psfhat, chunks=(1, -1, -1))

        def convolver(x):
            model = da.from_array(x, chunks=(1, nx, ny), name=False)

            convolvedim = hessian(model, psfhat, padding, args.nvthreads,
                                  unpad_x, unpad_y, lastsize)
            return convolvedim

    rms = np.std(dirty_mfs)
    rmax = np.abs(dirty_mfs).max()

    print("Iter %i: peak residual = %f, rms = %f" % (0, rmax, rms), file=log)

    residual = dirty.copy()
    residual_mfs = dirty_mfs.copy()
    model = np.zeros_like(residual)
    for k in range(args.nmiter):
        print("Running Hogbom", file=log)
        x = hogbom(residual,
                   psf,
                   gamma=args.hb_gamma,
                   pf=args.hb_peak_factor,
                   maxit=args.hb_maxit,
                   verbosity=args.hb_verbose,
                   report_freq=args.hb_report_freq)

        model += x
        print("Getting residual", file=log)

        convimage = convolver(model)
        dask.visualize(convimage,
                       filename=args.output_filename + '_hessian' + str(k) +
                       '_graph.pdf',
                       optimize_graph=False)
        with performance_report(filename=args.output_filename + '_hessian' +
                                str(k) + '_per.html'):
            convimage = dask.compute(convimage, optimize_graph=False)[0]
        ne.evaluate('dirty - convimage/normfact',
                    out=residual,
                    casting='same_kind')
        ne.evaluate('sum(residual, axis=0)',
                    out=residual_mfs,
                    casting='same_kind')

        rms = np.std(residual_mfs)
        rmax = np.abs(residual_mfs).max()

        print("Iter %i: peak residual = %f, rms = %f" % (k + 1, rmax, rms),
              file=log)

    print("Saving results", file=log)
    save_fits(args.output_filename + '_model.fits', model, hdr)
    model_mfs = np.mean(model, axis=0)
    save_fits(args.output_filename + '_model_mfs.fits', model_mfs, hdr_mfs)
    # NOTE(review): dirty was divided by the scalar wsum above; confirm that
    # rescaling by the per-band wsums (rather than wsum) is intended here.
    save_fits(args.output_filename + '_residual.fits',
              residual * wsums[:, None, None], hdr)
    # Separate file name: previously this overwrote the per-band residual.
    save_fits(args.output_filename + '_residual_mfs.fits', residual_mfs,
              hdr_mfs)

    print("All done here.", file=log)