Beispiel #1
0
def solve_image(vis: Visibility,
                model: Image,
                components=None,
                predict=predict_2d,
                invert=invert_2d,
                **kwargs) -> (Visibility, Image, Image):
    """Solve for image using deconvolve_cube and specified predict, invert

    This is a major cycle / minor cycle algorithm. In every major cycle the
    full current model (image plus optional sky components) is predicted into
    a zeroed visibility, subtracted from the observed data, and the residual
    is inverted to give a new dirty image for deconvolve_cube. The components
    are therefore removed prior to every deconvolution, not only the first.

    See also arguments for predict, invert, deconvolve_cube functions.

    :param vis: Observed visibility (its data are only read)
    :param model: Model image; clean components are added to model.data in place
    :param components: Optional sky components subtracted in every major cycle
    :param predict: Predict function e.g. predict_2d, predict_wstack
    :param invert: Invert function e.g. invert_2d, invert_wstack
    :param kwargs: 'nmajor' (default 5) and 'threshold' (default 0.0) are read
        here; all kwargs are also passed to predict, invert and deconvolve_cube
    :return: residual Visibility, model Image, residual (dirty) Image
    """
    nmajor = get_parameter(kwargs, 'nmajor', 5)
    log.info("solve_image: Performing %d major cycles" % nmajor)

    visres = copy_visibility(vis)

    def _subtract_model():
        # Predict into a fresh zeroed copy so that neither the observed data
        # nor a previous cycle's prediction (or any uvw changes a predict
        # function makes) is carried into this cycle, and add the components
        # every time so they stay subtracted.
        vispred = copy_visibility(vis, zero=True)
        vispred = predict(vispred, model, **kwargs)
        if components is not None:
            vispred = predict_skycomponent_visibility(vispred, components)
        visres.data['vis'] = vis.data['vis'] - vispred.data['vis']

    _subtract_model()
    dirty, sumwt = invert(visres, model, **kwargs)
    psf, sumwt = invert(visres, model, dopsf=True, **kwargs)

    thresh = get_parameter(kwargs, "threshold", 0.0)

    for i in range(nmajor):
        log.info("solve_image: Start of major cycle %d" % i)
        cc, _ = deconvolve_cube(dirty, psf, **kwargs)
        model.data += cc.data
        _subtract_model()
        dirty, sumwt = invert(visres, model, **kwargs)
        if numpy.abs(dirty.data).max() < 1.1 * thresh:
            log.info("Reached stopping threshold %.6f Jy" % thresh)
            break
        log.info("solve_image: End of major cycle")

    log.info("solve_image: End of major cycles")
    return visres, model, dirty
Beispiel #2
0
 def _predict_base(self, context='2d', extra='', fluxthreshold=1.0):
     """Predict self.model in the given context and check it against self.componentvis.

     The model is predicted into a zeroed copy of self.componentvis. The
     residual (model prediction minus component visibilities, with w set to
     zero on the residual's uvw) is stored in self.residualvis and passed to
     self._checkdirty under the name 'predict_<context><extra>'.
     """
     reference = self.componentvis

     self.modelvis = copy_visibility(reference, zero=True)
     self.modelvis.data['vis'] *= 0.0
     self.modelvis = predict_function(self.modelvis, self.model, context=context, **self.params)

     residual = copy_visibility(reference, zero=True)
     self.residualvis = residual
     residual.data['uvw'][:, 2] = 0.0
     residual.data['vis'] = self.modelvis.data['vis'] - reference.data['vis']

     check_name = 'predict_%s%s' % (context, extra)
     self._checkdirty(self.residualvis, check_name, fluxthreshold=fluxthreshold)
 def test_calibrate_function(self):
     """Corrupt the data with two gain tables and check calibrate_function recovers T and B terms."""
     self.actualSetup('stokesI', 'stokesI', f=[100.0])

     # Corrupting gain tables: one on 'auto' timeslices, one with timeslice=1e5
     gt = create_gaintable_from_blockvisibility(self.vis)
     log.info("Created gain table: %s" % (gaintable_summary(gt)))
     gt = simulate_gaintable(gt, phase_error=10.0, amplitude_error=0.1, timeslice='auto')
     bgt = simulate_gaintable(gt, phase_error=0.1, amplitude_error=0.01, timeslice=1e5)

     # The uncorrupted data serve as the calibration model
     original = copy_visibility(self.vis)
     for table, slice_spec in ((bgt, 1e5), (gt, 'auto')):
         self.vis = apply_gaintable(self.vis, table, timeslice=slice_spec)

     controls = create_calibration_controls()
     for term in ('T', 'B'):
         controls[term]['first_selfcal'] = 0
     calibrated_vis, gaintables = calibrate_function(self.vis, original, context='TB', controls=controls)

     checks = (('T', 3e-2, "Max T residual = %s"), ('B', 6e-5, "Max B residual = %s"))
     for term, limit, message in checks:
         residual = numpy.max(gaintables[term].residual)
         assert residual < limit, message % (residual)
 def predict_ignore_none(vis, model):
     """Predict model into a copy of vis with predict_context; None is passed straight through."""
     if vis is None:
         return None
     return predict_context(copy_visibility(vis), model, context=context, **kwargs)
Beispiel #5
0
 def predict_and_sum(vis, model, **kwargs):
     """Return a copy of vis holding the prediction of model, or None when vis is None."""
     if vis is None:
         return None
     fresh = copy_visibility(vis)
     return predict(fresh, model, **kwargs)
Beispiel #6
0
def sum_visibility(vis: Visibility, direction: SkyCoord) -> numpy.array:
    """ Direct Fourier summation in a given direction

    Each visibility is multiplied by the conjugate point-source phasor towards
    ``direction`` and accumulated, weighted, per (channel, polarisation).
    Entries with no positive total weight get zero flux.

    :param vis: Visibility to be summed (only read; not modified)
    :param direction: Direction of summation
    :return: flux[nchan, npol], weight[nchan, npol]
    """
    # TODO: Convert to Visibility or remove?

    # vis is only read below, so the previous full copy of it was unnecessary.
    l, m, _ = skycoord_to_lmn(direction, vis.phasecentre)
    phasor = numpy.conjugate(simulate_point(vis.uvw, l, m))

    # Need to put correct mapping here
    _, frequency = get_frequency_map(vis, None)
    frequency = list(frequency)

    nchan = max(frequency) + 1
    npol = vis.polarisation_frame.npol

    flux = numpy.zeros([nchan, npol])
    weight = numpy.zeros([nchan, npol])

    for v, wt, p, ic in zip(vis.vis, vis.weight, phasor, frequency):
        for pol in range(npol):
            flux[ic, pol] += numpy.real(wt[pol] * v[pol] * p)
            weight[ic, pol] += wt[pol]

    positive = weight > 0.0
    flux[positive] = flux[positive] / weight[positive]
    flux[weight <= 0.0] = 0.0
    return flux, weight
Beispiel #7
0
def peel_skycomponent_blockvisibility(vis: BlockVisibility, sc: Union[Skycomponent, List[Skycomponent]], remove=True)\
        -> (BlockVisibility, List[GainTable]):
    """ Peel a collection of components.

    Sequentially solve the gain towards each Skycomponent and optionally remove the corrupted visibility from the
    observed visibility.

    :param vis: BlockVisibility to be processed; modified in place when remove is True
    :param sc: Skycomponent or list of Skycomponents (only shape 'Point' is supported)
    :param remove: If True, subtract each gain-corrupted component prediction from vis
    :return: subtracted visibility and list of GainTables (one per component, in order)
    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Iterable

    assert isinstance(
        vis, BlockVisibility), "vis is not a BlockVisibility: %r" % vis

    if not isinstance(sc, Iterable):
        sc = [sc]

    gtlist = []
    for comp in sc:
        assert comp.shape == 'Point', "Cannot handle shape %s" % comp.shape

        # Predict this component alone, solve for the gains mapping it onto the
        # (possibly already partly peeled) data, and corrupt the prediction.
        modelvis = copy_visibility(vis, zero=True)
        modelvis = predict_skycomponent_blockvisibility(modelvis, comp)
        gt = solve_gaintable(vis, modelvis, phase_only=False)
        modelvis = apply_gaintable(modelvis, gt)
        if remove:
            vis.data['vis'] -= modelvis.data['vis']
        gtlist.append(gt)

    return vis, gtlist
Beispiel #8
0
 def core_solve(self,
                spf,
                dpf,
                phase_error=0.1,
                amplitude_error=0.0,
                leakage=0.0,
                phase_only=True,
                niter=200,
                crosspol=False,
                residual_tol=1e-6,
                f=None):
     """Corrupt self.vis with a simulated gain table, solve for the gains and check the solution.

     :param spf: first frame argument passed to actualSetup
     :param dpf: second frame argument passed to actualSetup
     :param phase_error: passed to simulate_gaintable
     :param amplitude_error: passed to simulate_gaintable
     :param leakage: passed to simulate_gaintable
     :param phase_only: passed to solve_gaintable
     :param niter: passed to solve_gaintable
     :param crosspol: passed to solve_gaintable
     :param residual_tol: the maximum solver residual must be below this value
     :param f: fluxes passed to actualSetup; None means [100.0, 50.0, -10.0, 40.0]
     """
     # A None default avoids sharing one mutable list between calls.
     if f is None:
         f = [100.0, 50.0, -10.0, 40.0]
     self.actualSetup(spf, dpf, f=f)
     gt = create_gaintable_from_blockvisibility(self.vis)
     log.info("Created gain table: %s" % (gaintable_summary(gt)))
     gt = simulate_gaintable(gt,
                             phase_error=phase_error,
                             amplitude_error=amplitude_error,
                             leakage=leakage)
     original = copy_visibility(self.vis)
     vis = apply_gaintable(self.vis, gt)
     # Solve on the corrupted visibility returned by apply_gaintable rather
     # than self.vis, so the test does not depend on in-place application.
     gtsol = solve_gaintable(vis,
                             original,
                             phase_only=phase_only,
                             niter=niter,
                             crosspol=crosspol,
                             tol=1e-6)
     vis = apply_gaintable(vis, gtsol, inverse=True)
     residual = numpy.max(gtsol.residual)
     assert residual < residual_tol, "%s %s Max residual = %s" % (spf, dpf,
                                                                  residual)
     log.debug(qa_gaintable(gt))
     # The solved gains must differ measurably from unity
     assert numpy.max(numpy.abs(gtsol.gain - 1.0)) > 0.1
def safe_predict_list(vis_list, model, predict=predict_2d, **kwargs):
    """ Predicts a list of visibilities to obtain a list of visibilities

    None entries in vis_list are skipped; every other entry is copied with
    copy_visibility before being passed to predict.

    Can be used in bag.map()

    :param vis_list: Iterable of visibilities (entries may be None)
    :param model: Model Image to predict from
    :param predict: Predict function e.g. predict_2d
    :param kwargs: Passed through to predict
    :return: List of visibilities
    """
    # collections.Iterable was removed in Python 3.10; use collections.abc.
    from collections.abc import Iterable

    assert isinstance(vis_list, Iterable
                      ), "Visibility list is not Iterable: %s" % str(vis_list)

    assert isinstance(model, Image), "Model is not an image: %s" % model

    result = [predict(copy_visibility(v), model, **kwargs) for v in vis_list if v is not None]

    assert len(
        result
    ) > 0, "Visibility after concatenation is empty, input list is %s" % str(
        vis_list)

    return result
Beispiel #10
0
def calibrate_visibility(vt: Visibility,
                         model: Image = None,
                         components=None,
                         predict=predict_2d,
                         **kwargs) -> Visibility:
    """ calibrate Visibility with respect to model and optionally components

    The model image and/or components are predicted into a zeroed copy of vt.
    Observed and predicted data are decoalesced to BlockVisibility, a gain
    table is solved for, and it is applied to the observed data.

    :param vt: Visibility (decoalesce_visibility requires vt.blockvis and vt.cindex)
    :param model: Model image
    :param components: Sky components
    :param predict: Predict function used for the model image e.g. predict_2d
    :param kwargs: Passed to predict and solve_gaintable; 'inverse' (default
        False) is passed to apply_gaintable
    :return: Calibrated visibility
    """
    assert model is not None or components is not None, "calibration requires a model or skycomponents"

    vtpred = copy_visibility(vt, zero=True)
    if model is not None:
        vtpred = predict(vtpred, model, **kwargs)
    if components is not None:
        vtpred = predict_skycomponent_visibility(vtpred, components)

    bvt = decoalesce_visibility(vt)
    # With overwrite=False, decoalesce_visibility fills the .blockvis template
    # in place. vtpred is a copy of vt and may share that template, so
    # decoalesce the prediction into a fresh BlockVisibility to avoid
    # overwriting the observed data held in bvt.
    bvtpred = decoalesce_visibility(vtpred, overwrite=True)
    gt = solve_gaintable(bvt, bvtpred, **kwargs)
    bvt = apply_gaintable(bvt,
                          gt,
                          inverse=get_parameter(kwargs, "inverse", False))
    return convert_blockvisibility_to_visibility(bvt)
Beispiel #11
0
def calibrate_blockvisibility(bvt: BlockVisibility,
                              model: Image = None,
                              components=None,
                              predict=predict_2d,
                              **kwargs) -> BlockVisibility:
    """ calibrate BlockVisibility with respect to model and optionally components

    The model image and/or components are predicted into a zeroed copy of bvt.
    A gain table is solved for against that prediction and applied to bvt.

    :param bvt: BlockVisibility
    :param model: Model image
    :param components: Sky components
    :param predict: Predict function used for the model image e.g. predict_2d
    :param kwargs: Passed to predict, solve_gaintable and apply_gaintable
    :return: Calibrated BlockVisibility

    """
    assert model is not None or components is not None, "calibration requires a model or skycomponents"

    # Work on a zeroed copy. decoalesce_visibility fills the coalesced
    # visibility's .blockvis template in place, so coalescing bvt itself and
    # decoalescing the prediction could overwrite the observed data.
    bvtpred = copy_visibility(bvt, zero=True)
    if model is not None:
        vtpred = convert_blockvisibility_to_visibility(bvtpred)
        vtpred = predict(vtpred, model, **kwargs)
        bvtpred = decoalesce_visibility(vtpred)
    if components is not None:
        bvtpred = predict_skycomponent_blockvisibility(bvtpred, components)

    gt = solve_gaintable(bvt, bvtpred, **kwargs)
    return apply_gaintable(bvt, gt, **kwargs)
Beispiel #12
0
def subtract_visibility(vis, model_vis, inplace=False):
    """ Subtract model_vis from vis, returning new visibility

    vis and model_vis must both be Visibility or both be BlockVisibility, and
    their vis arrays must have identical shapes.

    :param vis: Observed visibility
    :param model_vis: Model visibility to subtract
    :param inplace: If True, write the difference into vis and return vis;
        otherwise return a copy of vis holding the difference
    :return: Residual visibility
    """
    for vis_type in (Visibility, BlockVisibility):
        if isinstance(vis, vis_type):
            assert isinstance(model_vis, vis_type), model_vis
            break
    else:
        raise RuntimeError("Types of vis and model visibility are invalid")

    observed_shape = vis.vis.shape
    model_shape = model_vis.vis.shape
    assert observed_shape == model_shape, "Observed %s and model visibilities %s have different shapes" \
        % (observed_shape, model_shape)

    target = vis if inplace else copy_visibility(vis)
    target.data['vis'] = target.data['vis'] - model_vis.data['vis']
    return target
Beispiel #13
0
 def zerovis(vis):
     """Return a copy of vis with every visibility value set to zero, or None if vis is None."""
     if vis is None:
         return None
     blank = copy_visibility(vis)
     blank.data['vis'][...] = 0.0
     return blank
Beispiel #14
0
 def predict_facets_and_accumulate(vis, model, **kwargs):
     """Predict model into a copy of vis and return that copy; vis itself is not written.

     Returns None when vis is None. Despite the name, no accumulation happens here.
     """
     if vis is None:
         return None
     target = copy_visibility(vis)
     return predict(target, model, **kwargs)
Beispiel #15
0
 def subtract_vis(vis, model_vis):
     """Return a copy of vis with model_vis subtracted, or None if either input is None."""
     if vis is None or model_vis is None:
         return None
     assert vis.vis.shape == model_vis.vis.shape
     difference = copy_visibility(vis)
     difference.data['vis'][...] -= model_vis.data['vis'][...]
     return difference
Beispiel #16
0
 def predict_facets_and_accumulate(vis, model, **kwargs):
     """Predict model and add the prediction into vis in place.

     :return: vis itself (with the prediction added), or None when vis is None
     """
     if vis is None:
         return None
     prediction = predict(copy_visibility(vis), model, **kwargs)
     vis.data['vis'] += prediction.data['vis']
     return vis
Beispiel #17
0
 def test_copy_visibility(self):
     """Writing to a copied Visibility must not affect the original, and vice versa."""
     self.vis = create_visibility(self.lowcore, self.times, self.frequency,
                                  channel_bandwidth=self.channel_bandwidth,
                                  phasecentre=self.phasecentre,
                                  weight=1.0,
                                  polarisation_frame=PolarisationFrame("stokesIQUV"))
     duplicate = copy_visibility(self.vis)
     self.vis.data['vis'] = 0.0
     duplicate.data['vis'] = 1.0
     assert (duplicate.data['vis'][0, 0].real == 1.0)
     assert (self.vis.data['vis'][0, 0].real == 0.0)
 def test_apply_gaintable_null(self):
     """Check that inverse-applying an all-zero gain table leaves vis[:, 0, 1, ...] unchanged."""
     frames = [('stokesI', 'stokesI'), ('stokesIQUV', 'linear'), ('stokesIQUV', 'circular')]
     for sky_frame, data_frame in frames:
         self.actualSetup(sky_frame, data_frame)
         table = create_gaintable_from_blockvisibility(self.vis, timeslice='auto')
         table.data['gain'] *= 0.0
         original = copy_visibility(self.vis)
         result = apply_gaintable(self.vis, table, inverse=True, timeslice='auto')
         difference = result.vis[:, 0, 1, ...] - original.vis[:, 0, 1, ...]
         error = numpy.max(numpy.abs(difference))
         assert error < 1e-12, "Error = %s" % (error)
Beispiel #19
0
 def test_create_gaintable_from_visibility(self):
     """Applying a gain table simulated with phase_error=0.1 must change the visibilities."""
     for sky_frame, data_frame in [('stokesIQUV', 'linear'), ('stokesIQUV', 'circular')]:
         self.actualSetup(sky_frame, data_frame)
         table = create_gaintable_from_blockvisibility(self.vis)
         log.info("Created gain table: %s" % (gaintable_summary(table)))
         table = simulate_gaintable(table, phase_error=0.1)
         original = copy_visibility(self.vis)
         corrupted = apply_gaintable(self.vis, table)
         change = numpy.abs(corrupted.vis - original.vis)
         assert numpy.max(change) > 0.0
 def test_apply_gaintable_only(self):
     """Applying a simulated gain table on 'auto' timeslices must change the visibilities by more than 10."""
     cases = (('stokesI', 'stokesI'), ('stokesIQUV', 'linear'), ('stokesIQUV', 'circular'))
     for sky_frame, data_frame in cases:
         self.actualSetup(sky_frame, data_frame)
         table = create_gaintable_from_blockvisibility(self.vis, timeslice='auto')
         log.info("Created gain table: %s" % (gaintable_summary(table)))
         table = simulate_gaintable(table, phase_error=0.1, amplitude_error=0.01, timeslice='auto')
         original = copy_visibility(self.vis)
         corrupted = apply_gaintable(self.vis, table, timeslice='auto')
         error = numpy.max(numpy.abs(corrupted.vis - original.vis))
         assert error > 10.0, "Error = %f" % (error)
Beispiel #21
0
 def test_apply_gaintable_and_inverse_both(self):
     """Applying a gain table and then its inverse must restore the original visibilities."""
     for spf, dpf in [('stokesIQUV', 'linear'), ('stokesIQUV', 'circular')]:
         self.actualSetup(spf, dpf)
         gt = create_gaintable_from_blockvisibility(self.vis)
         log.info("Created gain table: %s" % (gaintable_summary(gt)))
         gt = simulate_gaintable(gt, phase_error=0.1, amplitude_error=0.1)
         original = copy_visibility(self.vis)
         vis = apply_gaintable(self.vis, gt)
         # Undo the corruption on the visibility returned above rather than on
         # self.vis, so the round trip does not rely on apply_gaintable
         # modifying its input in place.
         vis = apply_gaintable(vis, gt, inverse=True)
         error = numpy.max(numpy.abs(vis.vis - original.vis))
         assert error < 1e-12, "Error = %s" % (error)
Beispiel #22
0
 def test_solve_gaintable_scalar(self):
     """Solve a phase-only gain table for stokesI data and check the solver residual."""
     self.actualSetup('stokesI', 'stokesI', f=[100.0])
     table = create_gaintable_from_blockvisibility(self.vis)
     log.info("Created gain table: %s" % (gaintable_summary(table)))
     table = simulate_gaintable(table, phase_error=10.0, amplitude_error=0.0)
     original = copy_visibility(self.vis)
     self.vis = apply_gaintable(self.vis, table)
     solution = solve_gaintable(self.vis, original, phase_only=True, niter=200)
     residual = numpy.max(solution.residual)
     assert residual < 3e-8, "Max residual = %s" % (residual)
     # The solved gains must differ measurably from unity
     assert numpy.max(numpy.abs(solution.gain - 1.0)) > 0.1
 def test_create_gaintable_from_visibility_interval(self):
     """For several timeslice settings, applying a simulated gain table must change non-zero data."""
     frames = [('stokesIQUV', 'linear')]
     for timeslice in [10.0, 'auto', 1e5]:
         for sky_frame, data_frame in frames:
             self.actualSetup(sky_frame, data_frame)
             table = create_gaintable_from_blockvisibility(self.vis, timeslice=timeslice)
             log.info("Created gain table: %s" % (gaintable_summary(table)))
             table = simulate_gaintable(table, phase_error=1.0, timeslice=timeslice)
             original = copy_visibility(self.vis)
             corrupted = apply_gaintable(self.vis, table, timeslice=timeslice)
             for magnitude in (numpy.abs(original.vis),
                               numpy.abs(corrupted.vis),
                               numpy.abs(corrupted.vis - original.vis)):
                 assert numpy.max(magnitude) > 0.0
Beispiel #24
0
    def test_predict_2d(self):
        """Check predict_2d against direct component evaluation with all w set to zero.

        With w=0 the two-dimensional transform should agree exactly with the
        component transform, which is a good check on the grid correction in
        the image->vis direction.
        """
        self.actualSetUp()

        # Direct evaluation of the components, with w forced to zero
        self.componentvis.data['uvw'][:, 2] = 0.0
        self.componentvis.data['vis'][...] = 0.0
        self.componentvis = predict_skycomponent_visibility(self.componentvis, self.components)

        # 2D prediction of the model image, also with w = 0
        modelvis = copy_visibility(self.componentvis, zero=True)
        modelvis.data['uvw'][:, 2] = 0.0
        self.modelvis = modelvis
        self.modelvis = predict_2d(self.modelvis, self.model, **self.params)

        residualvis = copy_visibility(self.componentvis, zero=True)
        self.residualvis = residualvis
        residualvis.data['uvw'][:, 2] = 0.0
        residualvis.data['vis'] = self.modelvis.data['vis'] - self.componentvis.data['vis']

        self._checkdirty(self.residualvis, 'predict_2d')
def sum_predict_results(results):
    """ Sum a set of predict results

    None entries are ignored. The first non-None result is copied with
    copy_visibility, and the vis column of every later non-None result is
    added into that copy.

    :param results: List of visibilities to be summed
    :return: summed visibility, or None if there is no non-None result
    """
    present = (item for item in results if item is not None)
    first = next(present, None)
    if first is None:
        return None
    total = copy_visibility(first)
    for item in present:
        total.data['vis'] += item.data['vis']
    return total
def residual_image(vis: Visibility, model: Image, invert_residual=invert_2d, predict_residual=predict_2d,
                   **kwargs) -> tuple:
    """Calculate residual image and visibility

    The model is predicted into a zeroed copy of vis and subtracted from the
    observed visibilities. The residual is then inverted with dopsf=False.

    :param vis: Visibility to be inverted (only read)
    :param model: Model image; also passed as the image argument to invert_residual
    :param invert_residual: invert to be used (default invert_2d)
    :param predict_residual: predict to be used (default predict_2d)
    :param kwargs: Passed through to predict_residual and invert_residual
    :return: residual visibility, residual image, sum of weights
    """
    visres = copy_visibility(vis, zero=True)
    visres = predict_residual(visres, model, **kwargs)
    visres.data['vis'] = vis.data['vis'] - visres.data['vis']
    # Force dopsf=False without colliding with a 'dopsf' entry in kwargs
    # (passing the keyword twice would raise TypeError).
    invert_kwargs = dict(kwargs, dopsf=False)
    dirty, sumwt = invert_residual(visres, model, **invert_kwargs)
    return visres, dirty, sumwt
Beispiel #27
0
def subtract_visibility(vis, model_vis, inplace=False):
    """ Subtract model_vis from vis, returning new visibility

    :param vis: Observed Visibility or BlockVisibility
    :param model_vis: Model visibility whose vis array has the same shape as vis.vis
    :param inplace: If True, write the difference into vis and return vis;
        otherwise return a copy of vis holding the difference
    :return: Residual visibility
    """
    assert isinstance(vis, Visibility) or isinstance(vis, BlockVisibility), vis
    # Guard against silent numpy broadcasting between mismatched arrays.
    assert vis.vis.shape == model_vis.vis.shape, "Observed %s and model visibilities %s have different shapes" \
        % (vis.vis.shape, model_vis.vis.shape)

    if inplace:
        vis.data['vis'] = vis.data['vis'] - model_vis.data['vis']
        return vis

    residual_vis = copy_visibility(vis)
    residual_vis.data['vis'] = residual_vis.data['vis'] - model_vis.data['vis']
    return residual_vis
Beispiel #28
0
def spectral_line_imaging(vis: Visibility,
                          model: Image,
                          continuum_model: Image = None,
                          continuum_components=None,
                          predict=predict_2d,
                          invert=invert_2d,
                          deconvolve_spectral=False,
                          **kwargs) -> (Visibility, Image, Image):
    """Spectral line imaging from calibrated (DIE) data

    A continuum model can be subtracted, and deconvolution is optional.

    If deconvolve_spectral is True then solve_image is used to deconvolve.
    If deconvolve_spectral is False then the residual image after continuum
    subtraction is calculated, and the input model is returned as the
    spectral model.

    :param vis: Visibility (only read)
    :param model: spectral model image (also used as the imaging template)
    :param continuum_model: model continuum image to be subtracted
    :param continuum_components: model components to be subtracted
    :param predict: Predict function e.g. predict_2d
    :param invert: Invert function e.g. invert_wprojection
    :param deconvolve_spectral: If True, deconvolve with solve_image
    :return: Residual visibility, spectral model image, spectral residual image
    """
    # Predict the continuum into a zeroed copy. With no continuum model or
    # components the subtraction then leaves vis unchanged, and the observed
    # data never leak into the prediction.
    vis_no_continuum = copy_visibility(vis, zero=True)
    if continuum_model is not None:
        vis_no_continuum = predict(vis_no_continuum, continuum_model, **kwargs)
    if continuum_components is not None:
        vis_no_continuum = predict_skycomponent_visibility(
            vis_no_continuum, continuum_components)
    # vis_no_continuum currently holds the continuum prediction; make it the residual.
    vis_no_continuum.data[
        'vis'] = vis.data['vis'] - vis_no_continuum.data['vis']

    if deconvolve_spectral:
        log.info(
            "spectral_line_imaging: Deconvolving continuum subtracted visibility"
        )
        vis_no_continuum, spectral_model, spectral_residual = solve_image(
            vis_no_continuum, model, predict=predict, invert=invert, **kwargs)
    else:
        log.info(
            "spectral_line_imaging: Making dirty image from continuum subtracted visibility"
        )
        # invert returns (image, sum of weights): the dirty image is the
        # residual, and the model is returned unchanged.
        spectral_residual, _ = invert(vis_no_continuum, model, **kwargs)
        spectral_model = model

    return vis_no_continuum, spectral_model, spectral_residual
Beispiel #29
0
def decoalesce_visibility(vis: Visibility, overwrite=False) -> BlockVisibility:
    """ Decoalesce the visibilities to the original values (opposite of coalesce_visibility)

    This relies upon the block vis and the index being part of the vis.

    'uv': Needs the original image used in coalesce_visibility
    'tb': Needs the index generated by coalesce_visibility

    Note the meaning of ``overwrite``: when True, a copy of vis.blockvis is
    made and filled; when False (the default), vis.blockvis itself is filled
    in place and returned.

    :param vis: (Coalesced visibility)
    :param overwrite: If True, fill a copy of vis.blockvis instead of vis.blockvis itself
    :return: BlockVisibility with its vis column overwritten (all other columns come from vis.blockvis)
    """
    # Decoalesce: the metadata is unchanged; vis values are filled in according to cindex.
    assert type(vis) is Visibility, "vis is not a Visibility: %r" % vis
    assert type(vis.blockvis) is BlockVisibility, "No blockvisibility in vis"
    assert vis.cindex is not None, "No reverse index in Visibility"

    if overwrite:
        log.debug(
            'decoalesce_visibility: Created new Visibility for decoalesced data'
        )
        decomp_vis = copy_visibility(vis.blockvis)
    else:
        log.debug(
            'decoalesce_visibility: Filled decoalesced data into template')
        decomp_vis = vis.blockvis

    block = decomp_vis.data['vis']
    npol = block.shape[-1]
    # Only the element count is needed; no scratch array has to be allocated.
    nelements = block.size
    assert numpy.max(vis.cindex) < nelements

    # Flat elements [i*npol, (i+1)*npol) of the block visibility take the
    # coalesced row vis.cindex[i].
    coalesced = vis.data['vis']
    for i in range(nelements // npol):
        start = i * npol
        block.flat[start:start + npol] = coalesced[vis.cindex[i]]

    log.debug('decoalesce_visibility: Coalesced %s, decoalesced %s' %
              (vis_summary(vis), vis_summary(decomp_vis)))

    return decomp_vis
def predict_wstack_single(vis, model, remove=True, **kwargs) -> Visibility:
    """ Predict using a single w slice.

    This processes a single w plane. The mean w is subtracted from the w
    coordinates. The model is multiplied separately by the real and imaginary
    parts of the w term at that mean w (create_w_term_like), and each product
    is predicted with predict_2d_base. The two predictions are combined as
    real - 1j * imag.

    If vis is not a Visibility it is first coalesced with
    coalesce_visibility(vis, **kwargs). Otherwise vis itself is used and
    modified in place: its vis column is zeroed and overwritten, and its w
    coordinates are shifted.

    :param vis: Visibility (or BlockVisibility) to be predicted
    :param model: model image
    :param remove: If True (default) the mean w stays subtracted from the
        returned uvw; if False it is added back after prediction
    :return: resulting visibility (in place works). If the input was a
        BlockVisibility, the prediction is decoalesced back with
        decoalesce_visibility (default overwrite=False)
    """

    if not isinstance(vis, Visibility):
        log.debug("predict_wstack_single: Coalescing")
        avis = coalesce_visibility(vis, **kwargs)
    else:
        avis = vis

    log.debug("predict_wstack_single: predicting using single w slice")

    # Start from zero; avis may be the caller's own Visibility (see above).
    avis.data['vis'] *= 0.0
    # We might want to do wprojection so we remove the average w
    w_average = numpy.average(avis.w)
    avis.data['uvw'][..., 2] -= w_average
    # Second zeroed, w-shifted buffer, used for the imaginary-part prediction
    tempvis = copy_visibility(avis)

    # Calculate the (complex) w beam at the average w. Its real and imaginary
    # parts are applied to the model separately below, because
    # predict_2d_base is given a real-valued image each time.
    workimage = copy_image(model)
    w_beam = create_w_term_like(model, w_average, vis.phasecentre)

    # Do the real part
    workimage.data = w_beam.data.real * model.data
    avis = predict_2d_base(avis, workimage, **kwargs)

    # and now the imaginary part
    workimage.data = w_beam.data.imag * model.data
    tempvis = predict_2d_base(tempvis, workimage, **kwargs)
    # Combine as real - 1j * imag.
    # NOTE(review): the minus sign depends on the sign convention of
    # create_w_term_like / predict_2d_base — confirm before changing.
    avis.data['vis'] -= 1j * tempvis.data['vis']

    if not remove:
        # Restore the original w coordinates
        avis.data['uvw'][..., 2] += w_average

    if isinstance(vis, BlockVisibility) and isinstance(avis, Visibility):
        log.debug("imaging.predict decoalescing post prediction")
        return decoalesce_visibility(avis)
    else:
        return avis