def gather_image_iteration_results(results, template_model):
    # facets, overlap and taper are free variables supplied by the enclosing
    # workflow's scope
    result = create_empty_image_like(template_model)
    flat = create_empty_image_like(template_model)
    i = 0
    sumwt = numpy.zeros([template_model.nchan, template_model.npol])
    for dpatch in image_scatter_facets(result,
                                       facets=facets,
                                       overlap=overlap,
                                       taper=taper):
        assert i < len(
            results), "Too few results in gather_image_iteration_results"
        if results[i] is not None:
            assert len(results[i]) == 2, results[i]
            # Accumulate each facet image and its sum of weights
            dpatch.data[...] += results[i][0].data[...]
            sumwt += results[i][1]
            i += 1
    # Gather the flat (overlap weight) image used to normalise the facets
    flat = image_gather_facets(results,
                               flat,
                               facets=facets,
                               overlap=overlap,
                               taper=taper,
                               return_flat=True)
    # Normalise where the flat is significant, zero elsewhere
    result.data[flat.data > 0.5] /= flat.data[flat.data > 0.5]
    result.data[flat.data <= 0.5] = 0.0
    return result, sumwt
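The results argument is a list of per-facet (image, sumwt) pairs, in the same order that image_scatter_facets yields the facets. A hypothetical sketch of the call pattern, with invert_facet and vis standing in for the actual per-facet invert function and visibility, and facets, overlap and taper taken from the enclosing scope:

scattered_model = image_scatter_facets(template_model, facets=facets,
                                       overlap=overlap, taper=taper)
facet_results = [invert_facet(vis, facet_model)  # each entry is (Image, sumwt)
                 for facet_model in scattered_model]
dirty, sumwt = gather_image_iteration_results(facet_results, template_model)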
Example #2
def gather_image_iteration_results(results, template_model):
    # Variant without overlap handling: facets is a free variable from the
    # enclosing scope, and each facet is written exactly once, so no flat
    # normalisation is needed.
    result = create_empty_image_like(template_model)
    i = 0
    sumwt = numpy.zeros([template_model.nchan, template_model.npol])
    for dpatch in image_scatter_facets(result, facets=facets):
        assert i < len(results), "Too few results in gather_image_iteration_results"
        if results[i] is not None:
            assert len(results[i]) == 2, results[i]
            dpatch.data[...] = results[i][0].data[...]
            sumwt += results[i][1]
            i += 1
    return result, sumwt
Example #3
def deconvolve_list_serial_workflow(dirty_list,
                                    psf_list,
                                    model_imagelist,
                                    prefix='',
                                    mask=None,
                                    **kwargs):
    """Create a graph for deconvolution, adding to the model

    :param dirty_list: list of dirty images
    :param psf_list: list of psfs
    :param model_imagelist: list of models
    :param prefix: Informative prefix to log messages
    :param mask: Mask for deconvolution
    :param kwargs: Parameters for functions
    :return: List of deconvolved images

    For example::

        dirty_imagelist = invert_list_serial_workflow(vis_list, model_imagelist, context='2d',
                                                          dopsf=False, normalize=True)
        psf_imagelist = invert_list_serial_workflow(vis_list, model_imagelist, context='2d',
                                                        dopsf=True, normalize=True)
        dec_imagelist = deconvolve_list_serial_workflow(dirty_imagelist, psf_imagelist,
                model_imagelist, niter=1000, fractional_threshold=0.01,
                scales=[0, 3, 10], algorithm='mmclean', nmoment=3, nchan=freqwin,
                threshold=0.1, gain=0.7)

    """
    nchan = len(dirty_list)
    nmoment = get_parameter(kwargs, "nmoment", 0)

    assert isinstance(dirty_list, list), dirty_list
    assert isinstance(psf_list, list), psf_list
    assert isinstance(model_imagelist, list), model_imagelist

    def deconvolve(dirty, psf, model, facet, gthreshold, msk=None):
        if prefix == '':
            lprefix = "subimage %d" % facet
        else:
            lprefix = "%s, subimage %d" % (prefix, facet)

        if nmoment > 0:
            moment0 = calculate_image_frequency_moments(dirty)
            this_peak = numpy.max(numpy.abs(
                moment0.data[0, ...])) / dirty.data.shape[0]
        else:
            ref_chan = dirty.data.shape[0] // 2
            this_peak = numpy.max(numpy.abs(dirty.data[ref_chan, ...]))

        if this_peak > 1.1 * gthreshold:
            kwargs['threshold'] = gthreshold
            result, _ = deconvolve_cube(dirty,
                                        psf,
                                        prefix=lprefix,
                                        mask=msk,
                                        **kwargs)

            if result.data.shape[0] == model.data.shape[0]:
                result.data += model.data
            return result
        else:

            return copy_image(model)

    deconvolve_facets = get_parameter(kwargs, 'deconvolve_facets', 1)
    deconvolve_overlap = get_parameter(kwargs, 'deconvolve_overlap', 0)
    deconvolve_taper = get_parameter(kwargs, 'deconvolve_taper', None)
    if deconvolve_facets > 1 and deconvolve_overlap > 0:
        deconvolve_number_facets = (deconvolve_facets - 2)**2
    else:
        deconvolve_number_facets = deconvolve_facets**2
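    # e.g. deconvolve_facets=8 with a non-zero overlap gives (8 - 2)**2 = 36
    # facets to deconvolve, versus 8**2 = 64 when there is no overlap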

    model_imagelist = image_gather_channels(model_imagelist)

    # Scatter the separate channel images into deconvolve facets and then gather channels for each facet.
    # This avoids constructing the entire spectral cube.
    dirty_list_trimmed = remove_sumwt(dirty_list)
    scattered_channels_facets_dirty_list = \
        [image_scatter_facets(d, facets=deconvolve_facets,
                              overlap=deconvolve_overlap,
                              taper=deconvolve_taper)
         for d in dirty_list_trimmed]

    # Now we do a transpose and gather
    scattered_facets_list = [
        image_gather_channels([
            scattered_channels_facets_dirty_list[chan][facet]
            for chan in range(nchan)
        ]) for facet in range(deconvolve_number_facets)
    ]

    psf_list_trimmed = remove_sumwt(psf_list)
    psf_list_trimmed = image_gather_channels(psf_list_trimmed)

    scattered_model_imagelist = \
        image_scatter_facets(model_imagelist,
                             facets=deconvolve_facets,
                             overlap=deconvolve_overlap)

    # Work out the threshold. Need to find global peak over all dirty_list images
    threshold = get_parameter(kwargs, "threshold", 0.0)
    fractional_threshold = get_parameter(kwargs, "fractional_threshold", 0.1)
    nmoment = get_parameter(kwargs, "nmoment", 0)
    use_moment0 = nmoment > 0

    # Find the global threshold. This uses the peak in the average on the frequency axis since we
    # want to use it in a stopping criterion in a moment clean
    global_threshold = threshold_list(scattered_facets_list,
                                      threshold,
                                      fractional_threshold,
                                      use_moment0=use_moment0,
                                      prefix=prefix)

    facet_list = numpy.arange(deconvolve_number_facets).astype('int')
    if mask is None:
        scattered_results_list = [
            deconvolve(d, psf_list_trimmed, m, facet, global_threshold)
            for d, m, facet in zip(scattered_facets_list,
                                   scattered_model_imagelist, facet_list)
        ]
    else:
        mask_list = \
            image_scatter_facets(mask,
                                 facets=deconvolve_facets,
                                 overlap=deconvolve_overlap)
        scattered_results_list = [
            deconvolve(d, psf_list_trimmed, m, facet, global_threshold,
                       msk) for d, m, facet, msk in zip(
                           scattered_facets_list, scattered_model_imagelist,
                           facet_list, mask_list)
        ]

    # Gather the results back into one image, correcting for overlaps as necessary.
    # The taper function is used to feather the facets together
    gathered_results_list = image_gather_facets(scattered_results_list,
                                                model_imagelist,
                                                facets=deconvolve_facets,
                                                overlap=deconvolve_overlap,
                                                taper=deconvolve_taper)

    return image_scatter_channels(gathered_results_list, subimages=nchan)
Example #4
def predict_list_serial_workflow(vis_list,
                                 model_imagelist,
                                 context,
                                 vis_slices=1,
                                 facets=1,
                                 gcfcf=None,
                                 **kwargs):
    """Predict, iterating over both the scattered vis_list and image

    The visibility and image are scattered, the visibility is predicted on each part, and then the
    parts are assembled.

    :param vis_list: list of vis
    :param model_imagelist: Model used to determine image parameters
    :param vis_slices: Number of vis slices (w stack or timeslice)
    :param facets: Number of facets (per axis)
    :param context: Type of processing e.g. 2d, wstack, timeslice or facets
    :param gcfcf: tuple containing grid correction and convolution function
    :param kwargs: Parameters for functions in components
    :return: List of predicted visibilities, one per entry in vis_list
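
    For example, with ``vis_list`` and ``model_imagelist`` already built::

        predicted_vislist = predict_list_serial_workflow(vis_list, model_imagelist,
                                                         context='2d', vis_slices=1)
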
   """

    assert len(vis_list) == len(
        model_imagelist), "Model must be the same length as the vis_list"

    # Predict_2d does not clear the vis so we have to do it here.
    vis_list = zero_list_serial_workflow(vis_list)

    c = imaging_context(context)
    vis_iter = c['vis_iterator']
    predict = c['predict']

    if facets % 2 == 0 or facets == 1:
        actual_number_facets = facets
    else:
        actual_number_facets = facets - 1

    def predict_ignore_none(vis, model, g):
        if vis is not None:
            assert isinstance(vis, Visibility) or isinstance(
                vis, BlockVisibility), vis
            assert isinstance(model, Image), model
            return predict(vis, model, context=context, gcfcf=g, **kwargs)
        else:
            return None

    if gcfcf is None:
        gcfcf = [create_pswf_convolutionfunction(m) for m in model_imagelist]

    # Loop over all frequency windows
    if facets == 1:
        image_results_list = list()
        for ivis, sub_vis_list in enumerate(vis_list):
            if len(gcfcf) > 1:
                g = gcfcf[ivis]
            else:
                g = gcfcf[0]
            # Loop over sub visibility
            vis_predicted = copy_visibility(sub_vis_list, zero=True)
            for rows in vis_iter(sub_vis_list, vis_slices):
                row_vis = create_visibility_from_rows(sub_vis_list, rows)
                row_vis_predicted = predict_ignore_none(
                    row_vis, model_imagelist[ivis], g)
                if row_vis_predicted is not None:
                    vis_predicted.data['vis'][
                        rows, ...] = row_vis_predicted.data['vis']
            image_results_list.append(vis_predicted)

        return image_results_list
    else:
        image_results_list = list()
        for ivis, sub_vis_list in enumerate(vis_list):
            # Create the graph to divide an image into facets. This is by reference.
            facet_lists = image_scatter_facets(model_imagelist[ivis],
                                               facets=facets)
            facet_vis_lists = list()
            sub_vis_lists = visibility_scatter(sub_vis_list, vis_iter,
                                               vis_slices)

            # Loop over sub visibility
            for sub_sub_vis_list in sub_vis_lists:
                facet_vis_results = list()
                # Loop over facets
                for facet_list in facet_lists:
                    # Predict visibility for this subvisibility from this facet
                    facet_vis_list = predict_ignore_none(
                        sub_sub_vis_list, facet_list, None)
                    facet_vis_results.append(facet_vis_list)
                # Sum the current sub-visibility over all facets
                facet_vis_lists.append(sum_predict_results(facet_vis_results))
            # Sum all sub-visibilities
            image_results_list.append(
                visibility_gather(facet_vis_lists, sub_vis_list, vis_iter))
        return image_results_list
Example #5
def invert_list_serial_workflow(vis_list,
                                template_model_imagelist,
                                dopsf=False,
                                normalize=True,
                                facets=1,
                                vis_slices=1,
                                context='2d',
                                gcfcf=None,
                                **kwargs):
    """ Sum results from invert, iterating over the scattered image and vis_list

    :param vis_list: list of vis
    :param template_model_imagelist: list of template models
    :param dopsf: Make the PSF instead of the dirty image
    :param facets: Number of facets
    :param normalize: Normalize by sumwt
    :param vis_slices: Number of slices
    :param context: Imaging context
    :param gcfcf: tuple containing grid correction and convolution function
    :param kwargs: Parameters for functions in components
    :return: List of (image, sumwt) tuples, one per vis in vis_list

    For example::

        model_list = [create_image_from_visibility(v, npixel=npixel, cellsize=cellsize,
                                                   polarisation_frame=pol_frame)
                      for v in vis_list]

        dirty_list = invert_list_serial_workflow(vis_list, template_model_imagelist=model_list, context='wstack',
                                                    vis_slices=51)
        dirty, sumwt = dirty_list[centre]

   """

    if not isinstance(template_model_imagelist, collections.abc.Iterable):
        template_model_imagelist = [template_model_imagelist]

    c = imaging_context(context)
    vis_iter = c['vis_iterator']
    invert = c['invert']

    if facets % 2 == 0 or facets == 1:
        actual_number_facets = facets
    else:
        actual_number_facets = max(1, (facets - 1))

    def gather_image_iteration_results(results, template_model):
        result = create_empty_image_like(template_model)
        i = 0
        sumwt = numpy.zeros([template_model.nchan, template_model.npol])
        for dpatch in image_scatter_facets(result, facets=facets):
            assert i < len(
                results), "Too few results in gather_image_iteration_results"
            if results[i] is not None:
                assert len(results[i]) == 2, results[i]
                dpatch.data[...] = results[i][0].data[...]
                sumwt += results[i][1]
                i += 1
        return result, sumwt

    def invert_ignore_none(vis, model, gg):
        if vis is not None:

            return invert(vis,
                          model,
                          context=context,
                          dopsf=dopsf,
                          normalize=normalize,
                          gcfcf=gg,
                          **kwargs)
        else:
            return create_empty_image_like(model), numpy.zeros(
                [model.nchan, model.npol])

    # Create a default gcf/cf only in the non-faceted case; with facets > 1 it
    # is left as None so the invert function computes it for each facet
    if gcfcf is None and facets == 1:
        gcfcf = [create_pswf_convolutionfunction(template_model_imagelist[0])]

    # Loop over all vis_lists independently
    results_vislist = list()
    if facets == 1:
        for ivis, sub_vis_list in enumerate(vis_list):
            if len(gcfcf) > 1:
                g = gcfcf[ivis]
            else:
                g = gcfcf[0]
            # Iterate within each vis_list
            result_image = create_empty_image_like(
                template_model_imagelist[ivis])
            result_sumwt = numpy.zeros([
                template_model_imagelist[ivis].nchan,
                template_model_imagelist[ivis].npol
            ])
            for rows in vis_iter(sub_vis_list, vis_slices):
                row_vis = create_visibility_from_rows(sub_vis_list, rows)
                result = invert_ignore_none(row_vis,
                                            template_model_imagelist[ivis], g)
                if result is not None:
                    result_image.data += result[1][:, :, numpy.newaxis,
                                                   numpy.newaxis] * result[0].data
                    result_sumwt += result[1]
            result_image = normalize_sumwt(result_image, result_sumwt)
            results_vislist.append((result_image, result_sumwt))
    else:
        for ivis, sub_vis_list in enumerate(vis_list):
            # Create the graph to divide an image into facets. This is by reference.
            facet_lists = image_scatter_facets(template_model_imagelist[ivis],
                                               facets=facets)
            # Create the graph to divide the visibility into slices. This is by copy.
            sub_sub_vis_lists = visibility_scatter(sub_vis_list,
                                                   vis_iter,
                                                   vis_slices=vis_slices)

            # Iterate within each vis_list
            vis_results = list()
            for sub_sub_vis_list in sub_sub_vis_lists:
                facet_vis_results = list()
                for facet_list in facet_lists:
                    facet_vis_results.append(
                        invert_ignore_none(sub_sub_vis_list, facet_list, None))
                vis_results.append(
                    gather_image_iteration_results(
                        facet_vis_results, template_model_imagelist[ivis]))
            results_vislist.append(sum_invert_results(vis_results))

    return results_vislist
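
Taken together, these workflows form a simple serial imaging cycle. The sketch below is assembled from the docstring examples above and assumes vis_list, model_list and freqwin are already defined: it inverts to dirty images and PSFs, deconvolves, and predicts the deconvolved model back onto the visibilities.

dirty_list = invert_list_serial_workflow(vis_list, model_list, context='2d',
                                         dopsf=False, normalize=True)
psf_list = invert_list_serial_workflow(vis_list, model_list, context='2d',
                                       dopsf=True, normalize=True)
deconvolved_list = deconvolve_list_serial_workflow(dirty_list, psf_list, model_list,
                                                   niter=1000, fractional_threshold=0.01,
                                                   scales=[0, 3, 10], algorithm='mmclean',
                                                   nmoment=3, nchan=freqwin,
                                                   threshold=0.1, gain=0.7)
predicted_vislist = predict_list_serial_workflow(vis_list, deconvolved_list,
                                                 context='2d')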