def restore_list_mpi_workflow(model_imagelist, psf_imagelist, residual_imagelist,
                              comm=MPI.COMM_WORLD, **kwargs):
    """ Create a graph to calculate the restored image

    The input lists are only read on rank 0. There they are split into
    ``comm.Get_size()`` contiguous chunks with ``numpy.array_split``, scattered
    over ``comm``, restored locally with ``restore_list_serial_workflow_nosumwt``
    and gathered back on rank 0.

    :param model_imagelist: Model list (rank0)
    :param psf_imagelist: PSF list (rank0)
    :param residual_imagelist: Residual list (rank0); None or empty means no residuals
    :param comm: MPI communicator used for scatter/gather
    :param kwargs: Parameters for functions in components; forwarded to
        restore_list_serial_workflow_nosumwt
    :return: on rank 0, the restored images concatenated with numpy.concatenate;
        on every other rank an empty list
    """
    from workflows.serial.imaging.imaging_serial import restore_list_serial_workflow_nosumwt

    rank = comm.Get_rank()
    size = comm.Get_size()

    # TODO: Parallelize further and check the dask version.
    if rank == 0:
        # The sumwt component is removed before scattering to reduce communication.
        psf_list = remove_sumwt(psf_imagelist)
        if residual_imagelist is not None and len(residual_imagelist) > 0:
            residual_list = remove_sumwt(residual_imagelist)
        else:
            residual_list = []
        # comm.scatter only uses the root's send buffer, so the inputs are split
        # here; non-root ranks may pass placeholders (e.g. None) for the lists.
        model_chunks = numpy.array_split(model_imagelist, size)
        psf_chunks = numpy.array_split(psf_list, size)
        residual_chunks = numpy.array_split(residual_list, size)
    else:
        model_chunks = psf_chunks = residual_chunks = None

    sub_model_imagelist = comm.scatter(model_chunks, root=0)
    sub_psf_list = comm.scatter(psf_chunks, root=0)
    sub_residual_list = comm.scatter(residual_chunks, root=0)

    # Forward kwargs so restore parameters reach the per-image restore step.
    sub_result_list = restore_list_serial_workflow_nosumwt(
        sub_model_imagelist, sub_psf_list, sub_residual_list, **kwargs)

    result_list = comm.gather(sub_result_list, root=0)
    if rank == 0:
        return numpy.concatenate(result_list)
    return list()
def deconvolve_list_serial_workflow(dirty_list, psf_list, model_imagelist, prefix='', mask=None, **kwargs):
    """Create a graph for deconvolution, adding to the model

    Each channel image is scattered into facets, the channels of each facet are
    gathered, and a facet is cleaned with deconvolve_cube only when its peak
    exceeds 1.1 * the global threshold; otherwise a copy of the model facet is
    kept. The facets are then gathered back and split into channels.

    NOTE(review): if this module also contains the later, mask-less
    ``deconvolve_list_serial_workflow``, that definition replaces this one at
    import time — confirm which variant is intended.

    :param dirty_list: list of dirty images (passed through remove_sumwt before use)
    :param psf_list: list of PSF images (passed through remove_sumwt before use)
    :param model_imagelist: list of model images, one per channel
    :param prefix: Informative prefix to log messages
    :param mask: Mask for deconvolution; scattered into facets like the model
    :param kwargs: Parameters for functions in components
    :return: (graph for the deconvolution, graph for the flat)
    """
    import logging

    assert isinstance(dirty_list, list), dirty_list
    assert isinstance(psf_list, list), psf_list
    assert isinstance(model_imagelist, list), model_imagelist

    nchan = len(dirty_list)
    nmoment = get_parameter(kwargs, "nmoment", 0)
    use_moment0 = nmoment > 0

    def deconvolve(dirty, psf, model, facet, gthreshold, msk=None):
        """Clean one facet if its peak exceeds 1.1 * gthreshold, else return a copy of model."""
        if prefix == '':
            lprefix = "facet %d" % facet
        else:
            lprefix = "%s, facet %d" % (prefix, facet)

        if use_moment0:
            # Peak of the moment-0 image, normalised by the number of channels.
            moment0 = calculate_image_frequency_moments(dirty)
            this_peak = numpy.max(numpy.abs(moment0.data[0, ...])) / dirty.data.shape[0]
        else:
            ref_chan = dirty.data.shape[0] // 2
            this_peak = numpy.max(numpy.abs(dirty.data[ref_chan, ...]))

        if this_peak <= 1.1 * gthreshold:
            return copy_image(model)

        # Per-call copy so the shared kwargs dict is not mutated by the closure.
        facet_kwargs = dict(kwargs)
        facet_kwargs['threshold'] = gthreshold
        result, _ = deconvolve_cube(dirty, psf, prefix=lprefix, mask=msk, **facet_kwargs)
        if result.data.shape[0] == model.data.shape[0]:
            result.data += model.data
        else:
            # Previously this case silently dropped the initial model.
            logging.getLogger(__name__).warning(
                "deconvolve_list_serial_workflow %s: Initial model %s and clean result %s do not have the same shape"
                % (lprefix, str(model.data.shape[0]), str(result.data.shape[0])))
        return result

    deconvolve_facets = get_parameter(kwargs, 'deconvolve_facets', 1)
    deconvolve_overlap = get_parameter(kwargs, 'deconvolve_overlap', 0)
    deconvolve_taper = get_parameter(kwargs, 'deconvolve_taper', None)
    # NOTE(review): with overlap the facet count is (facets - 2)**2, presumably
    # because edge facets are dropped by image_scatter_facets — confirm.
    if deconvolve_facets > 1 and deconvolve_overlap > 0:
        deconvolve_number_facets = (deconvolve_facets - 2) ** 2
    else:
        deconvolve_number_facets = deconvolve_facets ** 2

    model_imagelist = image_gather_channels(model_imagelist)

    # Scatter the separate channel images into deconvolve facets and then gather
    # channels for each facet. This avoids constructing the entire spectral cube.
    dirty_list_trimmed = remove_sumwt(dirty_list)
    scattered_channels_facets_dirty_list = [
        image_scatter_facets(d, facets=deconvolve_facets, overlap=deconvolve_overlap,
                             taper=deconvolve_taper)
        for d in dirty_list_trimmed]

    # Transpose [chan][facet] -> [facet] and gather the channels of each facet.
    scattered_facets_list = [
        image_gather_channels([scattered_channels_facets_dirty_list[chan][facet]
                               for chan in range(nchan)])
        for facet in range(deconvolve_number_facets)]

    psf_list_trimmed = image_gather_channels(remove_sumwt(psf_list))

    scattered_model_imagelist = image_scatter_facets(model_imagelist, facets=deconvolve_facets,
                                                     overlap=deconvolve_overlap)

    # Find the global threshold. This uses the peak in the average on the frequency
    # axis since we want to use it in a stopping criterion in a moment clean.
    threshold = get_parameter(kwargs, "threshold", 0.0)
    fractional_threshold = get_parameter(kwargs, "fractional_threshold", 0.1)
    global_threshold = threshold_list(scattered_facets_list, threshold, fractional_threshold,
                                      use_moment0=use_moment0, prefix=prefix)

    facet_list = numpy.arange(deconvolve_number_facets).astype('int')
    if mask is None:
        # deconvolve's msk defaults to None, so this matches the unmasked path.
        mask_list = [None] * deconvolve_number_facets
    else:
        mask_list = image_scatter_facets(mask, facets=deconvolve_facets,
                                         overlap=deconvolve_overlap)
    scattered_results_list = [
        deconvolve(d, psf_list_trimmed, m, facet, global_threshold, msk)
        for d, m, facet, msk in zip(scattered_facets_list, scattered_model_imagelist,
                                    facet_list, mask_list)]

    # Gather the results back into one image, correcting for overlaps as necessary.
    # The taper function is used to feather the facets together.
    gathered_results_list = image_gather_facets(scattered_results_list, model_imagelist,
                                                facets=deconvolve_facets,
                                                overlap=deconvolve_overlap,
                                                taper=deconvolve_taper)
    flat_list = image_gather_facets(scattered_results_list, model_imagelist,
                                    facets=deconvolve_facets, overlap=deconvolve_overlap,
                                    taper=deconvolve_taper, return_flat=True)

    return image_scatter_channels(gathered_results_list, subimages=nchan), flat_list
def deconvolve_list_serial_workflow(dirty_list, psf_list, model_imagelist, prefix='', **kwargs):
    """Create a graph for deconvolution, adding to the model

    Each channel image is scattered into facets, the channels of each facet are
    gathered, and a facet is cleaned with deconvolve_cube only when its peak
    exceeds 1.1 * the global threshold; otherwise a copy of the model facet is
    kept. Per-facet progress and timing are logged.

    :param dirty_list: list of dirty entries; element [0] of each is the image
        (presumably (image, sumwt) pairs — confirm against callers)
    :param psf_list: list of PSF images (passed through remove_sumwt before use)
    :param model_imagelist: list of model images, one per channel
    :param prefix: Informative prefix to log messages
    :param kwargs: Parameters for functions in components. The moment count is
        read from "nmoment" (as in the mask-aware variant), falling back to the
        legacy "nmoments" key.
    :return: (graph for the deconvolution, graph for the flat)
    """
    import time

    nchan = len(dirty_list)
    # Accept "nmoment" (the key used elsewhere for moment-based cleaning) and keep
    # honouring the legacy "nmoments" key for existing callers.
    nmoments = get_parameter(kwargs, "nmoment", get_parameter(kwargs, "nmoments", 0))
    use_moment0 = nmoments > 0

    def deconvolve(dirty, psf, model, facet, gthreshold):
        """Clean one facet if its peak exceeds 1.1 * gthreshold, else return a copy of model."""
        starttime = time.time()
        if prefix == '':
            lprefix = "facet %d" % facet
        else:
            lprefix = "%s, facet %d" % (prefix, facet)

        if use_moment0:
            # Peak of the moment-0 image, normalised by the number of channels.
            moment0 = calculate_image_frequency_moments(dirty)
            this_peak = numpy.max(numpy.abs(moment0.data[0, ...])) / dirty.data.shape[0]
        else:
            this_peak = numpy.max(numpy.abs(dirty.data[0, ...]))

        if this_peak > 1.1 * gthreshold:
            log.info(
                "deconvolve_list_serial_workflow %s: cleaning - peak %.6f > 1.1 * threshold %.6f"
                % (lprefix, this_peak, gthreshold))
            # Per-call copy so the shared kwargs dict is not mutated by the closure.
            facet_kwargs = dict(kwargs)
            facet_kwargs['threshold'] = gthreshold
            result, _ = deconvolve_cube(dirty, psf, prefix=lprefix, **facet_kwargs)

            if result.data.shape[0] == model.data.shape[0]:
                result.data += model.data
            else:
                log.warning(
                    "deconvolve_list_serial_workflow %s: Initial model %s and clean result %s do not have the same shape"
                    % (lprefix, str(model.data.shape[0]), str(result.data.shape[0])))

            flux = numpy.sum(result.data[0, 0, ...])
            log.info(
                '### %s, %.6f, %.6f, True, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
                % (lprefix, this_peak, flux, time.time() - starttime))
            return result

        log.info(
            "deconvolve_list_serial_workflow %s: Not cleaning - peak %.6f <= 1.1 * threshold %.6f"
            % (lprefix, this_peak, gthreshold))
        log.info(
            '### %s, %.6f, %.6f, False, %.3f # cycle, facet, peak, cleaned flux, clean, time?'
            % (lprefix, this_peak, 0.0, time.time() - starttime))
        return copy_image(model)

    deconvolve_facets = get_parameter(kwargs, 'deconvolve_facets', 1)
    deconvolve_overlap = get_parameter(kwargs, 'deconvolve_overlap', 0)
    deconvolve_taper = get_parameter(kwargs, 'deconvolve_taper', None)
    if deconvolve_overlap > 0:
        deconvolve_number_facets = (deconvolve_facets - 2) ** 2
    else:
        deconvolve_number_facets = deconvolve_facets ** 2

    model_imagelist = image_gather_channels(model_imagelist)

    # Scatter the separate channel images into deconvolve facets and then gather
    # channels for each facet. This avoids constructing the entire spectral cube.
    scattered_channels_facets_dirty_list = [
        image_scatter_facets(d[0], facets=deconvolve_facets, overlap=deconvolve_overlap,
                             taper=deconvolve_taper)
        for d in dirty_list]

    # Transpose [chan][facet] -> [facet] and gather the channels of each facet.
    scattered_facets_list = [
        image_gather_channels([scattered_channels_facets_dirty_list[chan][facet]
                               for chan in range(nchan)])
        for facet in range(deconvolve_number_facets)]

    psf_cube = image_gather_channels(remove_sumwt(psf_list))

    scattered_model_imagelist = image_scatter_facets(model_imagelist, facets=deconvolve_facets,
                                                     overlap=deconvolve_overlap)

    # Find the global threshold. This uses the peak in the average on the frequency
    # axis since we want to use it in a stopping criterion in a moment clean.
    threshold = get_parameter(kwargs, "threshold", 0.0)
    fractional_threshold = get_parameter(kwargs, "fractional_threshold", 0.1)
    global_threshold = threshold_list(scattered_facets_list, threshold, fractional_threshold,
                                      use_moment0=use_moment0, prefix=prefix)

    facet_list = numpy.arange(deconvolve_number_facets).astype('int')
    scattered_results_list = [
        deconvolve(d, psf_cube, m, facet, global_threshold)
        for d, m, facet in zip(scattered_facets_list, scattered_model_imagelist, facet_list)]

    # Gather the results back into one image, correcting for overlaps as necessary.
    # The taper function is used to feather the facets together.
    gathered_results_list = image_gather_facets(scattered_results_list, model_imagelist,
                                                facets=deconvolve_facets,
                                                overlap=deconvolve_overlap,
                                                taper=deconvolve_taper)
    flat_list = image_gather_facets(scattered_results_list, model_imagelist,
                                    facets=deconvolve_facets, overlap=deconvolve_overlap,
                                    taper=deconvolve_taper, return_flat=True)

    return image_scatter_channels(gathered_results_list, subimages=nchan), flat_list