Exemplo n.º 1
0
def deconvolve_channel_list_serial_workflow(dirty_list, psf_list,
                                            model_imagelist, subimages,
                                            **kwargs):
    """Deconvolve channel by channel and add the result to the model.

    ``dirty_list[0]`` is split into ``subimages`` pieces with
    ``image_scatter_channels``. Each piece is deconvolved with
    ``deconvolve_cube`` against ``psf_list[0]``, and only the first element of
    each ``deconvolve_cube`` result is kept. The pieces are gathered into an
    empty image shaped like ``model_imagelist``, and the model data is then
    added to it in place.

    :param dirty_list: Sequence whose first element is scattered over
        channels (presumably the dirty image - confirm against callers)
    :param psf_list: Sequence whose first element is the PSF Image; must be
        the size of a facet
    :param model_imagelist: Current model Image, added to the result
    :param subimages: Number of channel sub-images to scatter into
    :param kwargs: Parameters forwarded to ``deconvolve_cube``
    :return: Image holding the gathered deconvolution output plus the model
    """
    def deconvolve_subimage(dirty, psf):
        assert isinstance(dirty, Image)
        assert isinstance(psf, Image)
        comp = deconvolve_cube(dirty, psf, **kwargs)
        return comp[0]

    def add_model(sum_model, model):
        # Validate the image actually being modified, not an outer variable.
        assert isinstance(sum_model, Image)
        assert isinstance(model, Image)
        sum_model.data += model.data
        return sum_model

    output = create_empty_image_like(model_imagelist)
    channel_images = image_scatter_channels(dirty_list[0], subimages=subimages)
    results = [
        deconvolve_subimage(channel_image, psf_list[0])
        for channel_image in channel_images
    ]
    result = image_gather_channels(results, output, subimages=subimages)
    return add_model(result, model_imagelist)
 def test_gather_channel(self):
     """Round-trip a test cube through channel scatter and gather.

     For 128 and 16 channels, the gathered cube must have the original shape
     and its pixel data must match exactly (maximum absolute difference 0.0).
     """
     for channel_count in (128, 16):
         original = create_test_image(
             polarisation_frame=PolarisationFrame('stokesI'),
             frequency=numpy.linspace(1e8, 1.1e8, channel_count))
         per_channel = image_scatter_channels(original,
                                              subimages=channel_count)
         rebuilt = image_gather_channels(per_channel, None,
                                         subimages=channel_count)
         assert original.shape == rebuilt.shape
         residual = original.data - rebuilt.data
         max_abs_diff = numpy.max(numpy.abs(residual))
         assert max_abs_diff == 0.0, \
             "Scatter gather failed for %d" % channel_count
Exemplo n.º 3
0
def deconvolve_list_serial_workflow(dirty_list,
                                    psf_list,
                                    model_imagelist,
                                    prefix='',
                                    **kwargs):
    """Deconvolve a list of channel images facet by facet, adding to the model.

    Each channel's image is scattered into facets. The facets are then
    regathered across channels, giving one spectral cube per facet; this
    avoids constructing the entire spectral cube. Each facet cube is
    deconvolved with ``deconvolve_cube`` against the channel-gathered PSF and
    added to the matching model facet. The facets are gathered back into one
    image, which is rescattered into ``len(dirty_list)`` channels.

    A facet is only cleaned when its peak exceeds 1.1 times the global
    threshold from ``threshold_list``. Otherwise a copy of its model facet is
    used.

    :param dirty_list: One entry per channel; element ``[0]`` of each entry is
        scattered into facets (presumably the image of an (image, sumwt)
        pair - confirm against callers)
    :param psf_list: Per-channel PSFs; passed through ``remove_sumwt`` and
        then gathered over channels
    :param model_imagelist: Current model, gathered over channels
    :param prefix: Optional label prepended to the per-facet log prefix
    :param kwargs: Parameters for functions in components. Read here:
        ``deconvolve_facets`` (default 1), ``deconvolve_overlap`` (default 0),
        ``deconvolve_taper`` (default None), ``threshold`` (default 0.0),
        ``fractional_threshold`` (default 0.1) and ``nmoments`` (default 0).
        All entries are forwarded to ``deconvolve_cube``, with ``threshold``
        replaced by the global threshold.
    :return: (list of ``len(dirty_list)`` channel images of the result,
        flat image from ``image_gather_facets(..., return_flat=True)``)
    """
    import time

    nchan = len(dirty_list)
    nmoments = get_parameter(kwargs, "nmoments", 0)
    use_moment0 = nmoments > 0

    def deconvolve(dirty, psf, model, facet, gthreshold):
        """Clean one facet cube if its peak exceeds 1.1 * gthreshold.

        Returns the cleaned facet, with ``model`` added when the leading axis
        lengths match. Otherwise returns a copy of ``model``.
        """
        starttime = time.time()
        if prefix == '':
            lprefix = "facet %d" % facet
        else:
            lprefix = "%s, facet %d" % (prefix, facet)

        if use_moment0:
            # Zeroth-moment peak divided by dirty.data.shape[0]; presumably a
            # channel-averaged peak matching the moment-based global threshold.
            moment0 = calculate_image_frequency_moments(dirty)
            this_peak = numpy.max(numpy.abs(
                moment0.data[0, ...])) / dirty.data.shape[0]
        else:
            this_peak = numpy.max(numpy.abs(dirty.data[0, ...]))

        # Keep the ">" form: a NaN peak must fall through to "not cleaning".
        if this_peak > 1.1 * gthreshold:
            log.info(
                "deconvolve_list_serial_workflow %s: cleaning - peak %.6f > 1.1 * threshold %.6f"
                % (lprefix, this_peak, gthreshold))
            # Override the threshold on a copy instead of mutating the
            # kwargs shared by every facet.
            cube_kwargs = dict(kwargs, threshold=gthreshold)
            result, _ = deconvolve_cube(dirty, psf, prefix=lprefix,
                                        **cube_kwargs)

            if result.data.shape[0] == model.data.shape[0]:
                result.data += model.data
            else:
                log.warning(
                    "deconvolve_list_serial_workflow %s: Initial model %s and clean result %s do not have the same shape"
                    % (lprefix, str(
                        model.data.shape[0]), str(result.data.shape[0])))

            flux = numpy.sum(result.data[0, 0, ...])
            log.info(
                '### %s, %.6f, %.6f, True, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
                % (lprefix, this_peak, flux, time.time() - starttime))

            return result
        else:
            log.info(
                "deconvolve_list_serial_workflow %s: Not cleaning - peak %.6f <= 1.1 * threshold %.6f"
                % (lprefix, this_peak, gthreshold))
            log.info(
                '### %s, %.6f, %.6f, False, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
                % (lprefix, this_peak, 0.0, time.time() - starttime))

            return copy_image(model)

    deconvolve_facets = get_parameter(kwargs, 'deconvolve_facets', 1)
    deconvolve_overlap = get_parameter(kwargs, 'deconvolve_overlap', 0)
    deconvolve_taper = get_parameter(kwargs, 'deconvolve_taper', None)
    # With overlap the facet count is (deconvolve_facets - 2)**2; presumably
    # this mirrors how image_scatter_facets handles overlap - confirm.
    if deconvolve_overlap > 0:
        deconvolve_number_facets = (deconvolve_facets - 2)**2
    else:
        deconvolve_number_facets = deconvolve_facets**2

    model_imagelist = image_gather_channels(model_imagelist)

    # Scatter the separate channel images into deconvolve facets and then
    # gather channels for each facet. This avoids constructing the entire
    # spectral cube.
    scattered_channels_facets_dirty_list = [
        image_scatter_facets(d[0], facets=deconvolve_facets,
                             overlap=deconvolve_overlap,
                             taper=deconvolve_taper)
        for d in dirty_list
    ]

    # Transpose (channel, facet) -> (facet, channel) and gather each facet.
    scattered_facets_list = [
        image_gather_channels([
            channel_facets[facet]
            for channel_facets in scattered_channels_facets_dirty_list
        ]) for facet in range(deconvolve_number_facets)
    ]

    psf_list = remove_sumwt(psf_list)
    psf_list = image_gather_channels(psf_list)

    scattered_model_imagelist = \
        image_scatter_facets(model_imagelist,
                             facets=deconvolve_facets,
                             overlap=deconvolve_overlap)

    # Work out the threshold. Need to find global peak over all dirty images.
    threshold = get_parameter(kwargs, "threshold", 0.0)
    fractional_threshold = get_parameter(kwargs, "fractional_threshold", 0.1)

    # Find the global threshold. This uses the peak in the average on the
    # frequency axis since we want to use it in a stopping criterion in a
    # moment clean.
    global_threshold = threshold_list(scattered_facets_list,
                                      threshold,
                                      fractional_threshold,
                                      use_moment0=use_moment0,
                                      prefix=prefix)

    scattered_results_list = [
        deconvolve(d, psf_list, m, facet, global_threshold)
        for facet, (d, m) in enumerate(
            zip(scattered_facets_list, scattered_model_imagelist))
    ]

    # Gather the results back into one image, correcting for overlaps as
    # necessary. The taper function is used to feather the facets together.
    gathered_results_list = image_gather_facets(scattered_results_list,
                                                model_imagelist,
                                                facets=deconvolve_facets,
                                                overlap=deconvolve_overlap,
                                                taper=deconvolve_taper)
    flat_list = image_gather_facets(scattered_results_list,
                                    model_imagelist,
                                    facets=deconvolve_facets,
                                    overlap=deconvolve_overlap,
                                    taper=deconvolve_taper,
                                    return_flat=True)

    return image_scatter_channels(gathered_results_list,
                                  subimages=nchan), flat_list