def test_scatter_gather_facet(self):
    """Round-trip scatter/gather of image facets without overlap.

    For several facet counts, scatter a test image into facets, check the
    pixel dimensions of every facet, fill the facets with ones, gather them
    back, and verify that neither the gathered image nor the flat is empty.
    """
    m31original = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
    assert numpy.max(numpy.abs(m31original.data)), "Original is empty"

    for nraster in [1, 4, 8]:
        m31model = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
        image_list = image_scatter_facets(m31model, facets=nraster)
        for patch in image_list:
            # With no overlap each facet is an even subdivision of the model.
            expected_x = m31model.data.shape[3] // nraster
            expected_y = m31model.data.shape[2] // nraster
            assert patch.data.shape[3] == expected_x, \
                "Number of pixels in each patch: %d not as expected: %d" % (patch.data.shape[3], expected_x)
            assert patch.data.shape[2] == expected_y, \
                "Number of pixels in each patch: %d not as expected: %d" % (patch.data.shape[2], expected_y)
            # Overwrite in place so the gathered image is non-zero everywhere.
            patch.data[...] = 1.0

        m31reconstructed = create_empty_image_like(m31model)
        m31reconstructed = image_gather_facets(image_list, m31reconstructed, facets=nraster)
        flat = image_gather_facets(image_list, m31reconstructed, facets=nraster, return_flat=True)

        assert numpy.max(numpy.abs(flat.data)), "Flat is empty for %d" % nraster
        assert numpy.max(numpy.abs(m31reconstructed.data)), "Raster is empty for %d" % nraster
def test_scatter_gather_facet_overlap_taper(self):
    """Round-trip scatter/gather of image facets with overlap and taper.

    Exercises several (facets, overlap) combinations under both a linear
    taper and no taper; checks facet sizes, writes the reconstructed and
    flat images to FITS, and verifies both are non-empty.
    """
    m31original = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
    assert numpy.max(numpy.abs(m31original.data)), "Original is empty"

    for taper in ['linear', None]:
        for nraster, overlap in [(1, 0), (4, 8), (8, 8), (8, 16)]:
            m31model = create_test_image(polarisation_frame=PolarisationFrame('stokesI'))
            image_list = image_scatter_facets(m31model, facets=nraster, overlap=overlap, taper=taper)
            for patch in image_list:
                # Each facet is padded by the overlap on both sides.
                expected_x = 2 * overlap + m31model.data.shape[3] // nraster
                expected_y = 2 * overlap + m31model.data.shape[2] // nraster
                assert patch.data.shape[3] == expected_x, \
                    "Number of pixels in each patch: %d not as expected: %d" % (patch.data.shape[3], expected_x)
                assert patch.data.shape[2] == expected_y, \
                    "Number of pixels in each patch: %d not as expected: %d" % (patch.data.shape[2], expected_y)

            m31reconstructed = create_empty_image_like(m31model)
            m31reconstructed = image_gather_facets(image_list, m31reconstructed, facets=nraster,
                                                   overlap=overlap, taper=taper)
            flat = image_gather_facets(image_list, m31reconstructed, facets=nraster,
                                       overlap=overlap, taper=taper, return_flat=True)

            export_image_to_fits(
                m31reconstructed,
                "%s/test_image_gather_scatter_%dnraster_%doverlap_%s_reconstructed.fits"
                % (self.dir, nraster, overlap, taper))
            export_image_to_fits(
                flat,
                "%s/test_image_gather_scatter_%dnraster_%doverlap_%s_flat.fits"
                % (self.dir, nraster, overlap, taper))

            assert numpy.max(numpy.abs(flat.data)), "Flat is empty for %d" % nraster
            assert numpy.max(numpy.abs(m31reconstructed.data)), "Raster is empty for %d" % nraster
def invert_serial(vis, im: Image, dopsf=False, normalize=True, context='2d', vis_slices=1,
                  facets=1, overlap=0, taper=None, **kwargs):
    """ Invert using algorithm specified by context:

    * 2d: Two-dimensional transform
    * wstack: wstacking with either vis_slices or wstack (spacing between w planes) set
    * wprojection: w projection with wstep (spacing between w places) set, also kernel='wprojection'
    * timeslice: snapshot imaging with either vis_slices or timeslice set. timeslice='auto' does every time
    * facets: Faceted imaging with facets facets on each axis
    * facets_wprojection: facets AND wprojection
    * facets_wstack: facets AND wstacking
    * wprojection_wstack: wprojection and wstacking

    :param vis: Visibility (or BlockVisibility, which is converted first)
    :param im: Template image used to determine image characteristics
    :param dopsf: Make the psf instead of the dirty image (False)
    :param normalize: Normalize by the sum of weights (True)
    :param context: Imaging context e.g. '2d', 'timeslice', etc.
    :param vis_slices: Number of visibility slices for the iterator
    :param facets: Number of facets (per axis)
    :param overlap: Facet overlap in pixels
    :param taper: Taper applied at facet edges (e.g. 'linear' or None)
    :param kwargs: Passed through to the context-specific invert function
    :return: Image, sum of weights
    """
    # Look up the visibility iterator and invert function for this context.
    c = imaging_context(context)
    vis_iter = c['vis_iterator']
    invert = c['invert']

    if not isinstance(vis, Visibility):
        svis = convert_blockvisibility_to_visibility(vis)
    else:
        svis = vis

    resultimage = create_empty_image_like(im)
    # totalwt stays None until at least one non-empty vis slice is processed;
    # this is what the assertion below checks.
    totalwt = None
    for rows in vis_iter(svis, vis_slices=vis_slices):
        if numpy.sum(rows):
            visslice = create_visibility_from_rows(svis, rows)
            sumwt = 0.0
            workimage = create_empty_image_like(im)
            # Invert each facet of this slice; invert is called with
            # normalize=False, normalization happens once at the end.
            for dpatch in image_scatter_facets(workimage, facets=facets, overlap=overlap,
                                               taper=taper):
                result, sumwt = invert(visslice, dpatch, dopsf, normalize=False, facets=facets,
                                       vis_slices=vis_slices, **kwargs)
                # Ensure that we fill in the elements of dpatch instead of
                # creating a new numpy array: scatter returns views by reference.
                dpatch.data[...] = result.data[...]
            # Assume that sumwt is the same for all patches, so only the last
            # facet's sumwt is accumulated — once per visibility slice.
            if totalwt is None:
                totalwt = sumwt
            else:
                totalwt += sumwt
            resultimage.data += workimage.data

    assert totalwt is not None, "No valid data found for imaging"
    if normalize:
        resultimage = normalize_sumwt(resultimage, totalwt)

    return resultimage, totalwt
def gather_image_iteration_results(results, template_model):
    # Gather per-facet (image, sumwt) invert results back into a single image,
    # accumulating the sum of weights across facets.
    #
    # NOTE(review): `facets` is a free variable here — it must be resolved
    # from an enclosing scope or a module-level name; confirm it is defined
    # wherever this function is used (the nested copy of this function in
    # invert_list_serial_workflow closes over its `facets` argument).
    result = create_empty_image_like(template_model)
    i = 0
    # One weight accumulator entry per (channel, polarisation).
    sumwt = numpy.zeros([template_model.nchan, template_model.npol])
    for dpatch in image_scatter_facets(result, facets=facets):
        assert i < len(
            results), "Too few results in gather_image_iteration_results"
        # None entries (e.g. from empty visibility slices) leave the facet empty.
        if results[i] is not None:
            assert len(results[i]) == 2, results[i]
            # Fill the facet in place: dpatch is a view into `result`.
            dpatch.data[...] = results[i][0].data[...]
            sumwt += results[i][1]
        i += 1
    return result, sumwt
def predict_serial(vis, model: Image, context='2d', vis_slices=1, facets=1, overlap=0,
                   taper=None, **kwargs) -> Visibility:
    """Predict visibilities using algorithm specified by context

    * 2d: Two-dimensional transform
    * wstack: wstacking with either vis_slices or wstack (spacing between w planes) set
    * wprojection: w projection with wstep (spacing between w places) set, also kernel='wprojection'
    * timeslice: snapshot imaging with either vis_slices or timeslice set. timeslice='auto' does every time
    * facets: Faceted imaging with facets facets on each axis
    * facets_wprojection: facets AND wprojection
    * facets_wstack: facets AND wstacking
    * wprojection_wstack: wprojection and wstacking

    :param vis: Visibility (or BlockVisibility, which is converted first and back)
    :param model: Model image, used to determine image characteristics
    :param context: Imaging context e.g. '2d', 'timeslice', etc.
    :param vis_slices: Number of visibility slices for the iterator
    :param facets: Number of facets (per axis)
    :param overlap: Facet overlap in pixels
    :param taper: Taper applied at facet edges (e.g. 'linear' or None)
    :param kwargs: Passed through to the context-specific predict function
    :return: Visibility with predicted data (same type as the input `vis`)
    """
    # Look up the visibility iterator and predict function for this context.
    c = imaging_context(context)
    vis_iter = c['vis_iterator']
    predict = c['predict']

    if not isinstance(vis, Visibility):
        svis = convert_blockvisibility_to_visibility(vis)
    else:
        svis = vis

    result = copy_visibility(vis, zero=True)

    for rows in vis_iter(svis, vis_slices=vis_slices):
        if numpy.sum(rows):
            visslice = create_visibility_from_rows(svis, rows)
            visslice.data['vis'][...] = 0.0
            # Predict this slice from each facet and accumulate in place into
            # the corresponding rows of svis.
            for dpatch in image_scatter_facets(model, facets=facets, overlap=overlap,
                                               taper=taper):
                # NOTE(review): this zeroing looks redundant — `result` is
                # rebound by the very next statement; confirm before removing.
                result.data['vis'][...] = 0.0
                result = predict(visslice, dpatch, **kwargs)
                svis.data['vis'][rows] += result.data['vis']
    # Convert back so the caller gets the same visibility type it passed in.
    if not isinstance(vis, Visibility):
        svis = convert_visibility_to_blockvisibility(svis)
    return svis
assert facets * facets % size == 0 # Create test image frequency = numpy.array([1e8]) phasecentre = SkyCoord(ra=+15.0 * u.deg, dec=-35.0 * u.deg, frame='icrs', equinox='J2000') model = create_test_image(frequency=frequency, phasecentre=phasecentre, cellsize=0.001, polarisation_frame=PolarisationFrame('stokesI')) # Rank 0 scatters the test image if rank == 0: subimages = image_scatter_facets(model, facets=facets) subimages = numpy.array_split(subimages, size) else: subimages = list() sublist = comm.scatter(subimages, root=0) root_images = imagerooter(sublist) roots = comm.gather(root_images, root=0) if rank == 0: results = sum(roots, []) root_model = create_empty_image_like(model) result = image_gather_facets(results, root_model, facets=facets) numpy.testing.assert_array_almost_equal_nulp(result.data**2,
def invert_list_serial_workflow(vis_list, template_model_imagelist, dopsf=False, normalize=True,
                                facets=1, vis_slices=1, context='2d', **kwargs):
    """ Sum results from invert, iterating over the scattered image and vis_list

    :param vis_list: List of visibilities (None entries yield empty images with zero weight)
    :param template_model_imagelist: Model used to determine image parameters
    :param dopsf: Make the PSF instead of the dirty image
    :param normalize: Normalize by sumwt
    :param facets: Number of facets
    :param vis_slices: Number of slices
    :param context: Imaging context
    :param kwargs: Parameters for functions in components
    :return: List of (image, sumwt) tuple
    """
    # Fix: the ABC aliases (collections.Iterable etc.) were removed from the
    # `collections` namespace in Python 3.10; collections.abc is the supported home.
    import collections.abc
    if not isinstance(template_model_imagelist, collections.abc.Iterable):
        template_model_imagelist = [template_model_imagelist]

    # Look up the visibility iterator and invert function for this context.
    c = imaging_context(context)
    vis_iter = c['vis_iterator']
    invert = c['invert']

    def gather_image_iteration_results(results, template_model):
        # Gather per-facet (image, sumwt) results back into one image,
        # accumulating the sum of weights. Closes over `facets`.
        result = create_empty_image_like(template_model)
        i = 0
        sumwt = numpy.zeros([template_model.nchan, template_model.npol])
        for dpatch in image_scatter_facets(result, facets=facets):
            assert i < len(
                results), "Too few results in gather_image_iteration_results"
            if results[i] is not None:
                assert len(results[i]) == 2, results[i]
                # Fill the facet in place: dpatch is a view into `result`.
                dpatch.data[...] = results[i][0].data[...]
                sumwt += results[i][1]
            i += 1
        return result, sumwt

    def invert_ignore_none(vis, model):
        # Empty (None) visibility slices produce an empty image with zero weight.
        if vis is not None:
            return invert(vis, model, context=context, dopsf=dopsf, normalize=normalize,
                          facets=facets, vis_slices=vis_slices, **kwargs)
        else:
            return create_empty_image_like(model), 0.0

    # Loop over all vis_lists independently
    results_vislist = list()
    for freqwin, vis_list in enumerate(vis_list):
        # Create the graph to divide an image into facets. This is by reference.
        facet_lists = image_scatter_facets(template_model_imagelist[freqwin], facets=facets)
        # Create the graph to divide the visibility into slices. This is by copy.
        sub_vis_lists = visibility_scatter(vis_list, vis_iter, vis_slices=vis_slices)

        # Iterate within each vis_list
        vis_results = list()
        for sub_vis_list in sub_vis_lists:
            facet_vis_results = list()
            for facet_list in facet_lists:
                facet_vis_results.append(
                    invert_ignore_none(sub_vis_list, facet_list))
            vis_results.append(
                gather_image_iteration_results(
                    facet_vis_results, template_model_imagelist[freqwin]))
        results_vislist.append(sum_invert_results(vis_results))
    return results_vislist
def predict_list_serial_workflow(vis_list, model_imagelist, vis_slices=1, facets=1,
                                 context='2d', **kwargs):
    """Predict, iterating over both the scattered vis_list and image

    The visibility and image are scattered, the visibility is predicted on each part, and
    then the parts are assembled.

    :param vis_list:
    :param model_imagelist: Model used to determine image parameters
    :param vis_slices: Number of vis slices (w stack or timeslice)
    :param facets: Number of facets (per axis)
    :param context:
    :param kwargs: Parameters for functions in components
    :return: List of vis_lists
    """
    assert len(vis_list) == len(
        model_imagelist), "Model must be the same length as the vis_list"

    # Look up the visibility iterator and predict function for this context.
    ctx = imaging_context(context)
    vis_iter = ctx['vis_iterator']
    predict = ctx['predict']

    def predict_ignore_none(subvis, facet_model):
        # None sub-visibilities pass straight through as None.
        if subvis is None:
            return None
        return predict(subvis, facet_model, context=context, facets=facets,
                       vis_slices=vis_slices, **kwargs)

    image_results_list_list = list()
    # Loop over all frequency windows. Note that the loop variable deliberately
    # rebinds `vis_list` to the current window's visibility; the gather below
    # relies on that rebinding.
    for freqwin, vis_list in enumerate(vis_list):
        # Divide the model image into facets (views, by reference).
        facet_lists = image_scatter_facets(model_imagelist[freqwin], facets=facets)
        # Divide the visibility into slices (by copy).
        sub_vis_lists = visibility_scatter(vis_list, vis_iter, vis_slices)

        # For each sub-visibility, predict from every facet and sum the facets.
        summed_facet_vis = [
            sum_predict_results([predict_ignore_none(sub_vis, facet)
                                 for facet in facet_lists])
            for sub_vis in sub_vis_lists
        ]

        # Reassemble all sub-visibilities for this window.
        image_results_list_list.append(
            visibility_gather(summed_facet_vis, vis_list, vis_iter))

    return image_results_list_list
def deconvolve_list_serial_workflow(dirty_list, psf_list, model_imagelist, prefix='', **kwargs):
    """Create a graph for deconvolution, adding to the model

    :param dirty_list: List of (dirty image, sumwt) tuples, one per channel
    :param psf_list: List of (psf image, sumwt) tuples, one per channel
    :param model_imagelist: List of model images, one per channel
    :param prefix: Informational prefix used in log messages
    :param kwargs: Parameters for functions in components; notably
        deconvolve_facets, deconvolve_overlap, deconvolve_taper, threshold,
        fractional_threshold, nmoments
    :return: (graph for the deconvolution, graph for the flat)
    """
    nchan = len(dirty_list)

    def deconvolve(dirty, psf, model, facet, gthreshold):
        # Deconvolve a single facet (spectral cube) against the psf, adding
        # the result to the model, unless the facet peak is below threshold.
        import time
        starttime = time.time()
        if prefix == '':
            lprefix = "facet %d" % facet
        else:
            lprefix = "%s, facet %d" % (prefix, facet)

        nmoments = get_parameter(kwargs, "nmoments", 0)

        if nmoments > 0:
            # Moment clean: peak is measured on the frequency-averaged moment-0 image.
            moment0 = calculate_image_frequency_moments(dirty)
            this_peak = numpy.max(numpy.abs(
                moment0.data[0, ...])) / dirty.data.shape[0]
        else:
            this_peak = numpy.max(numpy.abs(dirty.data[0, ...]))

        # Only clean when the facet peak is meaningfully above the global threshold.
        if this_peak > 1.1 * gthreshold:
            log.info(
                "deconvolve_list_serial_workflow %s: cleaning - peak %.6f > 1.1 * threshold %.6f"
                % (lprefix, this_peak, gthreshold))
            # NOTE(review): this mutates the shared kwargs dict; the change is
            # visible to subsequent facets — confirm this is intentional.
            kwargs['threshold'] = gthreshold
            result, _ = deconvolve_cube(dirty, psf, prefix=lprefix, **kwargs)

            if result.data.shape[0] == model.data.shape[0]:
                result.data += model.data
            else:
                log.warning(
                    "deconvolve_list_serial_workflow %s: Initial model %s and clean result %s do not have the same shape"
                    % (lprefix, str(
                        model.data.shape[0]), str(result.data.shape[0])))

            flux = numpy.sum(result.data[0, 0, ...])
            log.info(
                '### %s, %.6f, %.6f, True, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
                % (lprefix, this_peak, flux, time.time() - starttime))
            return result
        else:
            log.info(
                "deconvolve_list_serial_workflow %s: Not cleaning - peak %.6f <= 1.1 * threshold %.6f"
                % (lprefix, this_peak, gthreshold))
            log.info(
                '### %s, %.6f, %.6f, False, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
                % (lprefix, this_peak, 0.0, time.time() - starttime))
            # Below threshold: return an unchanged copy of the model.
            return copy_image(model)

    deconvolve_facets = get_parameter(kwargs, 'deconvolve_facets', 1)
    deconvolve_overlap = get_parameter(kwargs, 'deconvolve_overlap', 0)
    deconvolve_taper = get_parameter(kwargs, 'deconvolve_taper', None)

    # With overlap the scatter produces (facets - 2)^2 interior facets;
    # without overlap it produces facets^2.
    if deconvolve_overlap > 0:
        deconvolve_number_facets = (deconvolve_facets - 2)**2
    else:
        deconvolve_number_facets = deconvolve_facets**2

    model_imagelist = image_gather_channels(model_imagelist)

    # Scatter the separate channel images into deconvolve facets and then gather channels for each facet.
    # This avoids constructing the entire spectral cube.
    scattered_channels_facets_dirty_list = \
        [image_scatter_facets(d[0], facets=deconvolve_facets, overlap=deconvolve_overlap,
                              taper=deconvolve_taper) for d in dirty_list]

    # Now we do a transpose and gather: index [chan][facet] becomes one
    # channel-gathered cube per facet.
    scattered_facets_list = [
        image_gather_channels([
            scattered_channels_facets_dirty_list[chan][facet]
            for chan in range(nchan)
        ]) for facet in range(deconvolve_number_facets)
    ]

    psf_list = remove_sumwt(psf_list)
    psf_list = image_gather_channels(psf_list)

    scattered_model_imagelist = \
        image_scatter_facets(model_imagelist,
                             facets=deconvolve_facets,
                             overlap=deconvolve_overlap)

    # Work out the threshold. Need to find global peak over all dirty_list images
    threshold = get_parameter(kwargs, "threshold", 0.0)
    fractional_threshold = get_parameter(kwargs, "fractional_threshold", 0.1)
    nmoments = get_parameter(kwargs, "nmoments", 0)
    use_moment0 = nmoments > 0

    # Find the global threshold. This uses the peak in the average on the frequency axis since we
    # want to use it in a stopping criterion in a moment clean
    global_threshold = threshold_list(scattered_facets_list,
                                      threshold,
                                      fractional_threshold,
                                      use_moment0=use_moment0,
                                      prefix=prefix)

    facet_list = numpy.arange(deconvolve_number_facets).astype('int')
    scattered_results_list = [
        deconvolve(d, psf_list, m, facet, global_threshold)
        for d, m, facet in zip(scattered_facets_list, scattered_model_imagelist,
                               facet_list)
    ]

    # Gather the results back into one image, correcting for overlaps as necessary. The taper function is used to
    # feather the facets together
    gathered_results_list = image_gather_facets(scattered_results_list,
                                                model_imagelist,
                                                facets=deconvolve_facets,
                                                overlap=deconvolve_overlap,
                                                taper=deconvolve_taper)
    flat_list = image_gather_facets(scattered_results_list,
                                    model_imagelist,
                                    facets=deconvolve_facets,
                                    overlap=deconvolve_overlap,
                                    taper=deconvolve_taper,
                                    return_flat=True)

    return image_scatter_channels(gathered_results_list, subimages=nchan), flat_list
print(facets) print(vis_slices) return predict(vis, model, context=context, facets=facets, vis_slices=vis_slices) else: return None #pdb.set_trace() image_results_list_list = list() # Loop over all frequency windows for freqwin, vis_lst in enumerate(vis_list): # Create the graph to divide an image into facets. This is by reference. facet_lists = image_scatter_facets(model_imagelist[freqwin], facets=facets) # Create the graph to divide the visibility into slices. This is by copy. sub_vis_lists = visibility_scatter(vis_lst, vis_iter, vis_slices) facet_vis_lists = list() #pdb.set_trace() # Loop over sub visibility for sub_vis_list in sub_vis_lists: facet_vis_results = list() # Loop over facets for facet_list in facet_lists: # Predict visibility for this subvisibility from this facet facet_vis_list = predict_ignore_none(sub_vis_list, facet_list) facet_vis_results.append(facet_vis_list) # Sum the current sub-visibility over all facets facet_vis_lists.append(sum_predict_results(facet_vis_results))