def test_scatter_gather_facet(self):
        """Round-trip a test image through image_scatter_facets / image_gather_facets.

        For 1, 4 and 8 facets per axis, every patch must hold exactly 1/nraster of
        the model's pixels along each of the last two axes. Each patch is then set
        to 1.0, and both the gathered image and the flat image produced by
        image_gather_facets(..., return_flat=True) must contain non-zero data.
        """
        original = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
        assert numpy.max(numpy.abs(original.data)), "Original is empty"

        for nraster in [1, 4, 8]:
            model = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
            patches = image_scatter_facets(model, facets=nraster)
            want_x = model.data.shape[3] // nraster
            want_y = model.data.shape[2] // nraster
            for patch in patches:
                got_x, got_y = patch.data.shape[3], patch.data.shape[2]
                assert got_x == want_x, \
                    "Number of pixels in each patch: %d not as expected: %d" % (got_x, want_x)
                assert got_y == want_y, \
                    "Number of pixels in each patch: %d not as expected: %d" % (got_y, want_y)
                # Overwrite the patch data in place so the gathered images have known content
                patch.data[...] = 1.0

            gathered = image_gather_facets(patches, create_empty_image_like(model), facets=nraster)
            flat = image_gather_facets(patches, gathered, facets=nraster, return_flat=True)

            assert numpy.max(numpy.abs(flat.data)), "Flat is empty for %d" % nraster
            assert numpy.max(numpy.abs(gathered.data)), "Raster is empty for %d" % nraster
 def test_scatter_gather_facet_overlap_taper(self):
     """Scatter/gather round trip with overlapping facets, with and without a taper.

     For taper 'linear' and None and each (nraster, overlap) pair, every patch must
     be 2*overlap pixels larger than 1/nraster of the model along each of the last
     two axes. The gathered and flat images are exported to FITS under self.dir and
     must both contain non-zero data.
     """
     original = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
     assert numpy.max(numpy.abs(original.data)), "Original is empty"

     cases = [(1, 0), (4, 8), (8, 8), (8, 16)]
     for taper in ['linear', None]:
         for nraster, overlap in cases:
             model = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
             patches = image_scatter_facets(model, facets=nraster, overlap=overlap, taper=taper)
             want_x = 2 * overlap + model.data.shape[3] // nraster
             want_y = 2 * overlap + model.data.shape[2] // nraster
             for patch in patches:
                 got_x, got_y = patch.data.shape[3], patch.data.shape[2]
                 assert got_x == want_x, \
                     "Number of pixels in each patch: %d not as expected: %d" % (got_x, want_x)
                 assert got_y == want_y, \
                     "Number of pixels in each patch: %d not as expected: %d" % (got_y, want_y)

             facet_args = dict(facets=nraster, overlap=overlap, taper=taper)
             reconstructed = image_gather_facets(patches, create_empty_image_like(model), **facet_args)
             flat = image_gather_facets(patches, reconstructed, return_flat=True, **facet_args)

             labels = (self.dir, nraster, overlap, taper)
             export_image_to_fits(reconstructed,
                                  "%s/test_image_gather_scatter_%dnraster_%doverlap_%s_reconstructed.fits" % labels)
             export_image_to_fits(flat,
                                  "%s/test_image_gather_scatter_%dnraster_%doverlap_%s_flat.fits" % labels)

             assert numpy.max(numpy.abs(flat.data)), "Flat is empty for %d" % nraster
             assert numpy.max(numpy.abs(reconstructed.data)), "Raster is empty for %d" % nraster
def invert_serial(vis, im: Image, dopsf=False, normalize=True, context='2d', vis_slices=1,
                  facets=1, overlap=0, taper=None, **kwargs):
    """Invert visibilities to an image with the algorithm selected by ``context``.

    Supported contexts:

     * 2d: Two-dimensional transform
     * wstack: wstacking with either vis_slices or wstack (spacing between w planes) set
     * wprojection: w projection with wstep (spacing between w planes) set, also kernel='wprojection'
     * timeslice: snapshot imaging with either vis_slices or timeslice set. timeslice='auto' does every time
     * facets: Faceted imaging with facets facets on each axis
     * facets_wprojection: facets AND wprojection
     * facets_wstack: facets AND wstacking
     * wprojection_wstack: wprojection and wstacking

    The visibility is iterated in slices by the context's vis_iterator. For every
    non-empty slice, a fresh image like ``im`` is scattered into facets and each
    facet is inverted without normalisation. The slice images and weights are then
    accumulated, and normalisation (if requested) is applied once at the end.

    :param vis: Visibility, or a BlockVisibility which is converted to Visibility first
    :param im: Template image defining the output geometry
    :param dopsf: Make the psf instead of the dirty image (False)
    :param normalize: Normalize by the sum of weights (True)
    :param context: Imaging context e.g. '2d', 'timeslice', etc.
    :param vis_slices: Number of slices for the vis_iterator (also passed to invert)
    :param facets: Number of facets on each axis
    :param overlap: Facet overlap, passed to image_scatter_facets
    :param taper: Facet taper, passed to image_scatter_facets
    :param kwargs: Passed through to the context's invert function
    :return: Image, sum of weights
    """
    functions = imaging_context(context)
    vis_iter = functions['vis_iterator']
    invert = functions['invert']

    svis = vis if isinstance(vis, Visibility) else convert_blockvisibility_to_visibility(vis)

    resultimage = create_empty_image_like(im)
    totalwt = None

    for rows in vis_iter(svis, vis_slices=vis_slices):
        # Skip slices that select no rows
        if not numpy.sum(rows):
            continue
        visslice = create_visibility_from_rows(svis, rows)
        slice_image = create_empty_image_like(im)
        slice_wt = 0.0
        for facet in image_scatter_facets(slice_image, facets=facets, overlap=overlap, taper=taper):
            facet_image, slice_wt = invert(visslice, facet, dopsf, normalize=False, facets=facets,
                                           vis_slices=vis_slices, **kwargs)
            # Write element-wise so the values end up in slice_image
            # (the facets are assumed to be views into it)
            facet.data[...] = facet_image.data[...]
        # Weights come from the last facet; they are assumed identical for all facets
        if totalwt is None:
            totalwt = slice_wt
        else:
            totalwt += slice_wt
        resultimage.data += slice_image.data

    assert totalwt is not None, "No valid data found for imaging"
    if normalize:
        resultimage = normalize_sumwt(resultimage, totalwt)
    return resultimage, totalwt
Ejemplo n.º 4
0
 def gather_image_iteration_results(results, template_model):
     """Assemble per-facet (image, sumwt) results into a single image.

     A fresh empty image like ``template_model`` is scattered into facets. The image
     data of ``results[i]`` is copied into the i-th facet, so ``results`` must be in
     the same order as the facets from image_scatter_facets. The sumwt values of all
     non-None results are added together.

     NOTE(review): ``facets`` is not a parameter. It is resolved from the enclosing
     (or module) scope, as in the original.

     :param results: Sequence of (image, sumwt) pairs, one per facet; None entries are
         skipped but still consume their facet slot
     :param template_model: Image defining the output geometry; also provides nchan and npol
     :return: (gathered image, summed sumwt array of shape [nchan, npol])
     """
     result = create_empty_image_like(template_model)
     sumwt = numpy.zeros([template_model.nchan, template_model.npol])
     for i, dpatch in enumerate(image_scatter_facets(result, facets=facets)):
         assert i < len(
             results), "Too few results in gather_image_iteration_results"
         # Advance past None entries too, so facet i always pairs with results[i].
         # Previously the index only advanced on non-None results, so one None
         # stalled it and every later facet was silently left empty.
         if results[i] is None:
             continue
         assert len(results[i]) == 2, results[i]
         # Copy element-wise so the values land in ``result`` (dpatch is assumed to be a view into it)
         dpatch.data[...] = results[i][0].data[...]
         sumwt += results[i][1]
     return result, sumwt
def predict_serial(vis, model: Image, context='2d', vis_slices=1, facets=1, overlap=0, taper=None,
                   **kwargs) -> Visibility:
    """Predict visibilities using algorithm specified by context

     * 2d: Two-dimensional transform
     * wstack: wstacking with either vis_slices or wstack (spacing between w planes) set
     * wprojection: w projection with wstep (spacing between w planes) set, also kernel='wprojection'
     * timeslice: snapshot imaging with either vis_slices or timeslice set. timeslice='auto' does every time
     * facets: Faceted imaging with facets facets on each axis
     * facets_wprojection: facets AND wprojection
     * facets_wstack: facets AND wstacking
     * wprojection_wstack: wprojection and wstacking

    The visibility is iterated in slices by the context's vis_iterator and the model
    is scattered into facets. For every non-empty slice, each facet is predicted
    separately into a zeroed slice. That prediction is then added into the
    corresponding rows of the working Visibility.

    NOTE(review): predictions are *added* to the existing visibility values. When
    ``vis`` is already a Visibility it is modified in place. Pass zeroed
    visibilities to obtain pure model visibilities.

    :param vis: Visibility, or a BlockVisibility (converted to Visibility and back)
    :param model: Model image, used to determine image characteristics
    :param context: Imaging context e.g. '2d', 'timeslice', etc.
    :param vis_slices: Number of slices for the vis_iterator
    :param facets: Number of facets on each axis
    :param overlap: Facet overlap, passed to image_scatter_facets
    :param taper: Facet taper, passed to image_scatter_facets
    :param kwargs: Passed through to the context's predict function
    :return: Visibility (BlockVisibility if one was given) with the predictions added
    """
    c = imaging_context(context)
    vis_iter = c['vis_iterator']
    predict = c['predict']

    if not isinstance(vis, Visibility):
        svis = convert_blockvisibility_to_visibility(vis)
    else:
        svis = vis

    # No scratch copy of the whole visibility is needed: the original
    # copy_visibility(vis, zero=True) result was only zeroed and then overwritten.
    for rows in vis_iter(svis, vis_slices=vis_slices):
        if numpy.sum(rows):
            visslice = create_visibility_from_rows(svis, rows)
            for dpatch in image_scatter_facets(model, facets=facets, overlap=overlap, taper=taper):
                # Start every facet from zero so a predict that writes into (or adds
                # to) visslice does not carry the previous facet's values forward.
                visslice.data['vis'][...] = 0.0
                result = predict(visslice, dpatch, **kwargs)
                svis.data['vis'][rows] += result.data['vis']

    if not isinstance(vis, Visibility):
        svis = convert_visibility_to_blockvisibility(svis)

    return svis
    assert facets * facets % size == 0

    # Create test image
    frequency = numpy.array([1e8])
    phasecentre = SkyCoord(ra=+15.0 * u.deg,
                           dec=-35.0 * u.deg,
                           frame='icrs',
                           equinox='J2000')
    model = create_test_image(frequency=frequency,
                              phasecentre=phasecentre,
                              cellsize=0.001,
                              polarisation_frame=PolarisationFrame('stokesI'))

    # Rank 0 scatters the test image
    if rank == 0:
        subimages = image_scatter_facets(model, facets=facets)
        subimages = numpy.array_split(subimages, size)
    else:
        subimages = list()

    sublist = comm.scatter(subimages, root=0)

    root_images = imagerooter(sublist)

    roots = comm.gather(root_images, root=0)

    if rank == 0:
        results = sum(roots, [])
        root_model = create_empty_image_like(model)
        result = image_gather_facets(results, root_model, facets=facets)
        numpy.testing.assert_array_almost_equal_nulp(result.data**2,
Ejemplo n.º 7
0
def invert_list_serial_workflow(vis_list,
                                template_model_imagelist,
                                dopsf=False,
                                normalize=True,
                                facets=1,
                                vis_slices=1,
                                context='2d',
                                **kwargs):
    """ Sum results from invert, iterating over the scattered image and vis_list

    For each frequency window:
     * the template image is scattered into facets and the visibility into slices;
     * every (slice, facet) pair is inverted;
     * the facets of each slice are gathered back into one image;
     * the per-slice results are combined with sum_invert_results.

    :param vis_list: List of visibilities, one per frequency window
    :param template_model_imagelist: Model(s) used to determine image parameters; a
        single non-iterable image is wrapped in a list
    :param dopsf: Make the PSF instead of the dirty image
    :param normalize: Normalize by sumwt
    :param facets: Number of facets (per axis)
    :param vis_slices: Number of slices
    :param context: Imaging context
    :param kwargs: Parameters for functions in components
    :return: List of (image, sumwt) tuple, one per frequency window
    """
    # collections.Iterable was removed in Python 3.10; collections.abc is the stable home.
    from collections.abc import Iterable

    if not isinstance(template_model_imagelist, Iterable):
        template_model_imagelist = [template_model_imagelist]

    c = imaging_context(context)
    vis_iter = c['vis_iterator']
    invert = c['invert']

    def gather_image_iteration_results(results, template_model):
        """Copy per-facet (image, sumwt) results into one image and sum the weights."""
        result = create_empty_image_like(template_model)
        sumwt = numpy.zeros([template_model.nchan, template_model.npol])
        for i, dpatch in enumerate(image_scatter_facets(result, facets=facets)):
            assert i < len(
                results), "Too few results in gather_image_iteration_results"
            # Advance past None entries too, so facet i always pairs with results[i]
            # (previously a None stalled the index and later facets stayed empty).
            if results[i] is None:
                continue
            assert len(results[i]) == 2, results[i]
            dpatch.data[...] = results[i][0].data[...]
            sumwt += results[i][1]
        return result, sumwt

    def invert_ignore_none(vis, model):
        """Invert one slice onto one facet; a None slice gives an empty image and zero weight."""
        if vis is None:
            return create_empty_image_like(model), 0.0
        return invert(vis,
                      model,
                      context=context,
                      dopsf=dopsf,
                      normalize=normalize,
                      facets=facets,
                      vis_slices=vis_slices,
                      **kwargs)

    # Loop over all frequency windows independently. The loop variable no longer
    # shadows the ``vis_list`` parameter.
    results_vislist = list()
    for freqwin, window_vis in enumerate(vis_list):
        template = template_model_imagelist[freqwin]
        # Divide the image into facets. This is by reference.
        facet_lists = image_scatter_facets(template, facets=facets)
        # Divide the visibility into slices. This is by copy.
        sub_vis_lists = visibility_scatter(window_vis,
                                           vis_iter,
                                           vis_slices=vis_slices)

        vis_results = list()
        for sub_vis_list in sub_vis_lists:
            facet_vis_results = [invert_ignore_none(sub_vis_list, facet_list)
                                 for facet_list in facet_lists]
            vis_results.append(
                gather_image_iteration_results(facet_vis_results, template))
        results_vislist.append(sum_invert_results(vis_results))

    return results_vislist
Ejemplo n.º 8
0
def predict_list_serial_workflow(vis_list,
                                 model_imagelist,
                                 vis_slices=1,
                                 facets=1,
                                 context='2d',
                                 **kwargs):
    """Predict visibilities for each frequency window from the matching model image.

    For each window, the model is scattered into facets (by reference) and the
    visibility into slices (by copy). Each slice is predicted from every facet, and
    the facet predictions for a slice are combined with sum_predict_results. The
    slices are then reassembled with visibility_gather.

    :param vis_list: List of visibilities, one per frequency window
    :param model_imagelist: Model used to determine image parameters; same length as vis_list
    :param vis_slices: Number of vis slices (w stack or timeslice)
    :param facets: Number of facets (per axis)
    :param context: Imaging context
    :param kwargs: Parameters for functions in components
    :return: List of vis_lists
    """
    assert len(vis_list) == len(
        model_imagelist), "Model must be the same length as the vis_list"

    context_fns = imaging_context(context)
    vis_iter = context_fns['vis_iterator']
    predict = context_fns['predict']

    def predict_ignore_none(vis, model):
        # A missing (None) slice predicts to None
        if vis is None:
            return None
        return predict(vis, model, context=context, facets=facets,
                       vis_slices=vis_slices, **kwargs)

    gathered = list()
    for freqwin, window_vis in enumerate(vis_list):
        facet_models = image_scatter_facets(model_imagelist[freqwin], facets=facets)
        slices = visibility_scatter(window_vis, vis_iter, vis_slices)
        # Per slice: predict from every facet, then sum over facets
        summed_slices = [
            sum_predict_results([predict_ignore_none(sub_vis, facet) for facet in facet_models])
            for sub_vis in slices
        ]
        gathered.append(visibility_gather(summed_slices, window_vis, vis_iter))
    return gathered
Ejemplo n.º 9
0
def deconvolve_list_serial_workflow(dirty_list,
                                    psf_list,
                                    model_imagelist,
                                    prefix='',
                                    **kwargs):
    """Deconvolve per-channel dirty images facet by facet, adding to the model

    The dirty images are scattered into deconvolution facets and regrouped so that
    each facet holds all channels. Each facet is then cleaned against a single
    global threshold. Finally the facets are gathered back into one cube
    (feathered with the taper) and split into nchan channel images.

    :param dirty_list: Per-channel dirty images; ``d[0]`` of each element is used as
        the image (presumably (image, sumwt) tuples -- TODO confirm)
    :param psf_list: Per-channel PSFs; passed through remove_sumwt and then gathered
        into one cube with image_gather_channels
    :param model_imagelist: Per-channel model images, gathered into one cube
    :param prefix: Prefix for the log messages
    :param kwargs: Parameters for functions in components. Read here:
        deconvolve_facets (default 1), deconvolve_overlap (0), deconvolve_taper (None),
        threshold (0.0), fractional_threshold (0.1), nmoments (0). All kwargs are
        also forwarded to deconvolve_cube.
    :return: (list of nchan model images from image_scatter_channels,
        flat image from image_gather_facets(..., return_flat=True))
    """
    # One entry per channel
    nchan = len(dirty_list)

    def deconvolve(dirty, psf, model, facet, gthreshold):
        """Clean one facet cube if its peak exceeds 1.1 * gthreshold.

        :return: the clean result with ``model`` added when their first axes match
            (otherwise a warning is logged and the model is not added), or a copy
            of ``model`` when the facet is not cleaned
        """
        import time
        starttime = time.time()
        if prefix == '':
            lprefix = "facet %d" % facet
        else:
            lprefix = "%s, facet %d" % (prefix, facet)

        nmoments = get_parameter(kwargs, "nmoments", 0)

        if nmoments > 0:
            # Peak of the first plane of moment 0, divided by dirty.data.shape[0]
            # (presumably the number of channels -- TODO confirm)
            moment0 = calculate_image_frequency_moments(dirty)
            this_peak = numpy.max(numpy.abs(
                moment0.data[0, ...])) / dirty.data.shape[0]
        else:
            # Peak of the first plane of the dirty cube
            this_peak = numpy.max(numpy.abs(dirty.data[0, ...]))

        # Facets whose peak is at most 10% above the global threshold are not cleaned
        if this_peak > 1.1 * gthreshold:
            log.info(
                "deconvolve_list_serial_workflow %s: cleaning - peak %.6f > 1.1 * threshold %.6f"
                % (lprefix, this_peak, gthreshold))
            # Clean down to the global threshold. This overwrites 'threshold' in the
            # kwargs dict of this call, which is shared by all facets.
            kwargs['threshold'] = gthreshold
            result, _ = deconvolve_cube(dirty, psf, prefix=lprefix, **kwargs)

            # Add the existing model back in only if the first axes agree
            if result.data.shape[0] == model.data.shape[0]:
                result.data += model.data
            else:
                log.warning(
                    "deconvolve_list_serial_workflow %s: Initial model %s and clean result %s do not have the same shape"
                    % (lprefix, str(
                        model.data.shape[0]), str(result.data.shape[0])))

            flux = numpy.sum(result.data[0, 0, ...])
            log.info(
                '### %s, %.6f, %.6f, True, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
                % (lprefix, this_peak, flux, time.time() - starttime))

            return result
        else:
            log.info(
                "deconvolve_list_serial_workflow %s: Not cleaning - peak %.6f <= 1.1 * threshold %.6f"
                % (lprefix, this_peak, gthreshold))
            log.info(
                '### %s, %.6f, %.6f, False, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
                % (lprefix, this_peak, 0.0, time.time() - starttime))

            return copy_image(model)

    deconvolve_facets = get_parameter(kwargs, 'deconvolve_facets', 1)
    deconvolve_overlap = get_parameter(kwargs, 'deconvolve_overlap', 0)
    deconvolve_taper = get_parameter(kwargs, 'deconvolve_taper', None)
    # NOTE(review): with overlap, image_scatter_facets is assumed to yield only the
    # inner (deconvolve_facets - 2)**2 facets -- TODO confirm against its implementation
    if deconvolve_overlap > 0:
        deconvolve_number_facets = (deconvolve_facets - 2)**2
    else:
        deconvolve_number_facets = deconvolve_facets**2

    model_imagelist = image_gather_channels(model_imagelist)

    # Scatter the separate channel images into deconvolve facets and then gather channels for each facet.
    # This avoids constructing the entire spectral cube.
    # scattered_channels_facets_dirty_list[chan][facet] is one facet of one channel.
    scattered_channels_facets_dirty_list = \
        [image_scatter_facets(d[0], facets=deconvolve_facets,
                              overlap=deconvolve_overlap,
                              taper=deconvolve_taper)
         for d in dirty_list]

    # Now we do a transpose and gather: scattered_facets_list[facet] holds that
    # facet for all channels, gathered into one cube
    scattered_facets_list = [
        image_gather_channels([
            scattered_channels_facets_dirty_list[chan][facet]
            for chan in range(nchan)
        ]) for facet in range(deconvolve_number_facets)
    ]

    # The PSF is not faceted: every facet is cleaned with the full PSF cube
    psf_list = remove_sumwt(psf_list)
    psf_list = image_gather_channels(psf_list)

    # NOTE(review): unlike the dirty images, the model is scattered without a
    # taper -- confirm this is intended
    scattered_model_imagelist = \
        image_scatter_facets(model_imagelist,
                             facets=deconvolve_facets,
                             overlap=deconvolve_overlap)

    # Work out the threshold. Need to find global peak over all dirty_list images.
    # These reads happen before any facet is cleaned, so they see the caller's
    # 'threshold' (deconvolve() overwrites that entry later).
    threshold = get_parameter(kwargs, "threshold", 0.0)
    fractional_threshold = get_parameter(kwargs, "fractional_threshold", 0.1)
    nmoments = get_parameter(kwargs, "nmoments", 0)
    use_moment0 = nmoments > 0

    # Find the global threshold. This uses the peak in the average on the frequency axis since we
    # want to use it in a stopping criterion in a moment clean
    global_threshold = threshold_list(scattered_facets_list,
                                      threshold,
                                      fractional_threshold,
                                      use_moment0=use_moment0,
                                      prefix=prefix)

    facet_list = numpy.arange(deconvolve_number_facets).astype('int')
    scattered_results_list = [
        deconvolve(d, psf_list, m, facet, global_threshold) for d, m, facet in
        zip(scattered_facets_list, scattered_model_imagelist, facet_list)
    ]

    # Gather the results back into one image, correcting for overlaps as necessary. The taper function is used to
    # feather the facets together
    gathered_results_list = image_gather_facets(scattered_results_list,
                                                model_imagelist,
                                                facets=deconvolve_facets,
                                                overlap=deconvolve_overlap,
                                                taper=deconvolve_taper)
    flat_list = image_gather_facets(scattered_results_list,
                                    model_imagelist,
                                    facets=deconvolve_facets,
                                    overlap=deconvolve_overlap,
                                    taper=deconvolve_taper,
                                    return_flat=True)

    # Split the gathered cube back into one image per channel
    return image_scatter_channels(gathered_results_list,
                                  subimages=nchan), flat_list
Ejemplo n.º 10
0
            print(facets)
            print(vis_slices)
            return predict(vis,
                           model,
                           context=context,
                           facets=facets,
                           vis_slices=vis_slices)
        else:
            return None

    #pdb.set_trace()
    image_results_list_list = list()
    # Loop over all frequency windows
    for freqwin, vis_lst in enumerate(vis_list):
        # Create the graph to divide an image into facets. This is by reference.
        facet_lists = image_scatter_facets(model_imagelist[freqwin],
                                           facets=facets)
        # Create the graph to divide the visibility into slices. This is by copy.
        sub_vis_lists = visibility_scatter(vis_lst, vis_iter, vis_slices)
        facet_vis_lists = list()

        #pdb.set_trace()
        # Loop over sub visibility
        for sub_vis_list in sub_vis_lists:
            facet_vis_results = list()
            # Loop over facets
            for facet_list in facet_lists:
                # Predict visibility for this subvisibility from this facet
                facet_vis_list = predict_ignore_none(sub_vis_list, facet_list)
                facet_vis_results.append(facet_vis_list)
            # Sum the current sub-visibility over all facets
            facet_vis_lists.append(sum_predict_results(facet_vis_results))