def create_predict_graph(vis_graph_list, model_graph: delayed, vis_slices=1, facets=1, context='2d', **kwargs):
    """Predict, iterating over both the scattered vis_graph_list and image

    For every visibility graph the model is scattered into ``facets ** 2``
    sub-images and the visibility into ``vis_slices`` slices. Each
    (facet, slice) pair is predicted independently. The slices are gathered
    back into a full visibility per facet, and the per-facet visibilities are
    then summed with sum_predict_results.

    :param vis_graph_list: List of delayed visibilities
    :param model_graph: Model used to determine image parameters
    :param vis_slices: Number of vis slices (w stack or timeslice)
    :param facets: Number of facets per axis (the model is scattered into facets**2 parts)
    :param context: Imaging context name passed to imaging_context
    :param kwargs: Parameters for functions in graphs
    :return: List of vis_graphs, one per entry of vis_graph_list
    """
    c = imaging_context(context)
    image_iter = c['image_iterator']
    vis_iter = c['vis_iterator']

    def predict_ignore_none(vis, model):
        # Empty slices arrive as None and are propagated unchanged
        if vis is None:
            return None
        predicted = copy_visibility(vis)
        return predict_context(predicted, model, context=context, **kwargs)

    def gather_vis(results, vis):
        # Gather the per-slice results back into one full visibility.
        # The same input vis is gathered once per facet, so write into a new
        # object rather than the input: otherwise every facet result would
        # alias (and overwrite) the same data.
        if isinstance(vis, BlockVisibility):
            # NOTE(review): assumes coalesce_visibility returns a new object - confirm
            avis = coalesce_visibility(vis, **kwargs)
        else:
            avis = copy_visibility(vis)
        # results holds one entry per item yielded by vis_iter (see scatter_vis),
        # including empty slices, so index it by position rather than by a
        # counter that skips None rows.
        for i, rows in enumerate(vis_iter(avis, vis_slices=vis_slices, **kwargs)):
            if rows is not None and results[i] is not None:
                # Single indexing step: ``a[rows][...] = x`` would write into a
                # temporary copy when rows is a mask or index array.
                avis.data['vis'][rows] = results[i].data['vis']
        if isinstance(vis, BlockVisibility):
            return decoalesce_visibility(avis, **kwargs)
        return avis

    def scatter_vis(vis):
        # Scatter along the visibility iteration axis. The rows are computed
        # on the coalesced visibility, so the slices must be cut from it too.
        if isinstance(vis, BlockVisibility):
            avis = coalesce_visibility(vis, **kwargs)
        else:
            avis = vis
        return [create_visibility_from_rows(avis, rows)
                for rows in vis_iter(avis, vis_slices=vis_slices, **kwargs)]

    def scatter_image(im):
        # Scatter across the image iteration axis (e.g. facets)
        return list(image_iter(im, facets=facets, **kwargs))

    # The model is shared by every visibility graph, so scatter it only once
    model_graphs = delayed(scatter_image, nout=facets ** 2)(model_graph)

    results_vis_graph_list = list()
    for vis_graph in vis_graph_list:
        sub_vis_graphs = delayed(scatter_vis, nout=vis_slices)(vis_graph)

        facet_vis_graphs = list()
        for sub_model_graph in model_graphs:
            # Predict every slice for this facet, then gather the slices back
            # into a full visibility; gather_vis expects results indexed by slice.
            sub_model_results = [
                delayed(predict_ignore_none, pure=True, nout=1)(sub_vis_graph, sub_model_graph)
                for sub_vis_graph in sub_vis_graphs
            ]
            facet_vis_graphs.append(delayed(gather_vis, nout=1)(sub_model_results, vis_graph))

        # Each facet yields a full visibility for its part of the model; sum them
        results_vis_graph_list.append(delayed(sum_predict_results)(facet_vis_graphs))
    return results_vis_graph_list
Esempio n. 2
0
def degrid_handle(reppre_ifft, telescope_data, context, vis_slices, facets,
                  nfrequency, **kwargs):
    """Build the Spark pipeline that predicts visibilities from model images.

    :param reppre_ifft: Pair RDD of model images, joined by key with telescope_data
    :param telescope_data: Pair RDD of visibilities; each value is coalesced before use
    :param context: Imaging context name: "2d", "facets", "facets_slice",
        "facets_timeslice", "facets_wstack", "slice", "timeslice" or "wstack"
    :param vis_slices: Number of visibility slices (slice-type contexts)
    :param facets: Number of facets per axis (facet-type contexts)
    :param nfrequency: Not used by this function; kept for interface compatibility
    :param kwargs: Parameters for the kernels; "parallelism" is read with
        get_parameter and passed as the partition count of the "2d" join
    :return: RDD produced by the chained Spark transformations for the context
    :raises ValueError: If context is not one of the supported names
    """
    parallelism = get_parameter(kwargs, "parallelism")
    c = imaging_context(context)

    if context == "2d":
        telescope_data = telescope_data.mapValues(
            lambda vis: coalesce_visibility(vis, **kwargs))
        return reppre_ifft.join(telescope_data, parallelism).map(
            lambda record: degrid_kernel(record, c["predict"], context=context, **kwargs))

    if context == "facets":
        telescope_data = telescope_data.mapValues(
            lambda vis: coalesce_visibility(vis, **kwargs))
        # Predict each facet separately, then sum the facet contributions per key
        return (reppre_ifft
                .flatMap(lambda im: scatter_image_flatmap(im, facets=facets, image_iter=c["image_iterator"], **kwargs), True)
                .join(telescope_data)
                .map(lambda record: degrid_kernel(record, c["predict"], context=context, **kwargs), True)
                .reduceByKey(sum_predict_vis_reduce_kernel))

    if context in ("facets_slice", "facets_timeslice", "facets_wstack"):
        # Keep the un-scattered (coalesced) visibilities: the slices are
        # gathered back into them after prediction.
        telescope_data_origin = telescope_data.map(
            lambda vis: (vis[0], coalesce_visibility(vis[1], **kwargs)))
        telescope_data = telescope_data_origin.flatMap(
            lambda vis: scatter_vis_flatmap(vis,
                                            vis_slices=vis_slices,
                                            vis_iter=c["vis_iterator"],
                                            **kwargs), True)
        # TODO: this pipeline could be optimized
        return (reppre_ifft
                .flatMap(lambda im: scatter_image_flatmap(im, facets=facets, image_iter=c["image_iterator"], **kwargs), True)
                .join(telescope_data)
                .map(lambda record: degrid_kernel(record, c["predict"], context=context, **kwargs))
                .combineByKey(gather_vis_createCombiner_kernel, gather_vis_mergeValue_kernel, gather_vis_mergeCombiner_kernel)
                .map(change_key)
                .join(telescope_data_origin)
                .mapValues(lambda data: gather_vis_kernel(data, vis_slices=vis_slices, vis_iter=c["vis_iterator"], **kwargs))
                .reduceByKey(sum_predict_vis_reduce_kernel))

    if context in ("slice", "timeslice", "wstack"):
        telescope_data_origin = telescope_data.map(
            lambda vis: (vis[0], coalesce_visibility(vis[1], **kwargs)))
        telescope_data = telescope_data_origin.flatMap(
            lambda vis: scatter_vis_flatmap(vis,
                                            vis_slices=vis_slices,
                                            vis_iter=c["vis_iterator"],
                                            **kwargs), True)
        return (reppre_ifft
                .join(telescope_data)
                .map(lambda record: degrid_kernel(record, c["predict"], context=context, **kwargs))
                .combineByKey(gather_vis_createCombiner_kernel, gather_vis_mergeValue_kernel, gather_vis_mergeCombiner_kernel)
                .join(telescope_data_origin)
                .mapValues(lambda data: gather_vis_kernel(data, vis_slices=vis_slices, vis_iter=c["vis_iterator"], **kwargs)))

    # Previously fell through and returned None, failing later with an obscure error
    raise ValueError("Unsupported imaging context for degrid_handle: %s" % context)
Esempio n. 3
0
def create_invert_graph(vis_graph_list,
                        template_model_graph,
                        dopsf=False,
                        normalize=True,
                        facets=1,
                        vis_slices=None,
                        context="2d",
                        **kwargs):
    """Invert visibilities to images by delegating to invert_handle.

    :param vis_graph_list: Visibilities to invert (passed through to invert_handle)
    :param template_model_graph: Template images (passed through to invert_handle)
    :param dopsf: Make the PSF instead of the dirty image
    :param normalize: Normalize by sumwt
    :param facets: Number of facets per axis
    :param vis_slices: Number of visibility slices
    :param context: Imaging context name
    :param kwargs: Parameters forwarded to invert_handle
    :return: Whatever invert_handle returns for this context
    """
    # invert_handle resolves the imaging context itself, so no lookup is needed here
    return invert_handle(template_model_graph,
                         vis_graph_list,
                         context=context,
                         dopsf=dopsf,
                         normalize=normalize,
                         facets=facets,
                         vis_slices=vis_slices,
                         **kwargs)
Esempio n. 4
0
File: bags.py Progetto: Jxt1/arlo
def invert_bag(vis_bag, model_bag, dopsf=False, context='2d', **kwargs):
    """ Construct a bag to invert a bag of visibilities to a bag of (image, weight) tuples

    Call directly - don't use via bag.map

    :param vis_bag: Bag of visibilities
    :param model_bag: Model bag passed to safe_invert_list alongside each scattered visibility list
    :param dopsf: Make the PSF instead of the dirty image (forwarded to safe_invert_list)
    :param context: Imaging context name; the context must define a 'scatter' function
    :param kwargs: Parameters passed to the scatter and invert steps
    :return: Bag of results reduced with sum_invert_bag_results
    """
    c = imaging_context(context)
    # Lazy %-args: the message is only formatted if INFO logging is enabled
    log.info('Imaging context is %s', c)
    assert c['scatter'] is not None
    return vis_bag. \
        map(c['scatter'], **kwargs). \
        map(safe_invert_list, model_bag, c['invert'], dopsf=dopsf, **kwargs). \
        map(sum_invert_bag_results)
Esempio n. 5
0
def invert_handle(template_model_graph, vis_graph_list, context, dopsf,
                  normalize, facets, vis_slices, **kwargs):
    """Build the Spark pipeline that inverts visibilities to images.

    :param template_model_graph: Pair RDD of template images, joined by key with vis_graph_list
    :param vis_graph_list: Pair RDD of visibilities; each value is coalesced before use
    :param context: Imaging context name: "2d", "facets", "facets_slice",
        "facets_wstack", "facets_timeslice", "slice", "timeslice" or "wstack"
    :param dopsf: Make the PSF instead of the dirty image (forwarded to invert_kernel)
    :param normalize: Normalize by sumwt (forwarded to invert_kernel)
    :param facets: Number of facets per axis (facet-type contexts)
    :param vis_slices: Number of visibility slices (slice-type contexts)
    :param kwargs: Parameters for the kernels; "parallelism" is read with
        get_parameter and passed as the partition count of the "2d" join
    :return: RDD produced by the chained Spark transformations for the context
    :raises ValueError: If context is not one of the supported names
    """
    c = imaging_context(context)
    parallelism = get_parameter(kwargs, "parallelism")
    # (wcs, polarisation frame, shape) per key, used to rebuild full images
    # when facets are gathered
    image_metadata = template_model_graph.mapValues(
        lambda im: (im.wcs, im.polarisation_frame, im.shape))

    def normalized(im_sumwt):
        # Apply normalize_sumwt to a summed (image, sumwt) pair, keeping sumwt
        return normalize_sumwt(im_sumwt[0], im_sumwt[1]), im_sumwt[1]

    if context == "2d":
        visibility = vis_graph_list.mapValues(
            lambda vis: coalesce_visibility(vis, **kwargs))
        return template_model_graph.join(visibility, parallelism).map(
            lambda record: invert_kernel(record, c["invert"], dopsf, normalize, context, **kwargs))

    if context == "facets":
        visibility = vis_graph_list.mapValues(
            lambda vis: coalesce_visibility(vis, **kwargs))
        return (template_model_graph
                .flatMap(lambda im: scatter_image_flatmap(im, facets=facets, image_iter=c["image_iterator"], **kwargs), True)
                .join(visibility)
                .map(lambda record: invert_kernel(record, c["invert"], dopsf, normalize, context, **kwargs), True)
                .combineByKey(gather_img_createConbiner_kernel, gather_img_margeValue_kernel, gather_img_mergeCombiner)
                .join(image_metadata)
                .mapValues(lambda data: gather_image_kernel(data, facets=facets, image_iter=c["image_iterator"], **kwargs)))

    if context in ("facets_slice", "facets_wstack"):
        # Sum over slices per facet first, then gather the facets
        visibility = vis_graph_list.mapValues(lambda vis: coalesce_visibility(vis, **kwargs)) \
            .flatMap(lambda vis: scatter_vis_flatmap(vis, vis_slices=vis_slices, vis_iter=c["vis_iterator"], **kwargs), True)
        return (template_model_graph
                .flatMap(lambda im: scatter_image_flatmap(im, facets=facets, image_iter=c["image_iterator"], **kwargs), True)
                .join(visibility)
                .map(lambda record: invert_kernel(record, c["invert"], dopsf, normalize, context, **kwargs))
                .reduceByKey(sum_inver_image_reduce_kernel)
                .map(change_key)
                .mapValues(normalized)
                .combineByKey(gather_img_createConbiner_kernel, gather_img_margeValue_kernel, gather_img_mergeCombiner)
                .join(image_metadata)
                .mapValues(lambda data: gather_image_kernel(data, facets=facets, image_iter=c["image_iterator"], **kwargs)))

    if context == "facets_timeslice":
        # Gather the facets per slice first, then sum over slices
        visibility = vis_graph_list.mapValues(lambda vis: coalesce_visibility(vis, **kwargs)) \
            .flatMap(lambda vis: scatter_vis_flatmap(vis, vis_slices=vis_slices, vis_iter=c["vis_iterator"], **kwargs), True)
        return (template_model_graph
                .flatMap(lambda im: scatter_image_flatmap(im, facets=facets, image_iter=c["image_iterator"], **kwargs), True)
                .join(visibility)
                .map(lambda record: invert_kernel(record, c["invert"], dopsf, normalize, context, **kwargs))
                .combineByKey(gather_img_createConbiner_kernel, gather_img_margeValue_kernel, gather_img_mergeCombiner)
                .join(image_metadata)
                .mapValues(lambda data: gather_image_kernel(data, facets=facets, image_iter=c["image_iterator"], **kwargs))
                .map(change_key)
                .reduceByKey(sum_inver_image_reduce_kernel)
                .mapValues(normalized))

    if context in ("slice", "timeslice", "wstack"):
        visibility = vis_graph_list.mapValues(lambda vis: coalesce_visibility(vis, **kwargs)) \
            .flatMap(lambda vis: scatter_vis_flatmap(vis, vis_slices=vis_slices, vis_iter=c["vis_iterator"], **kwargs), True)
        return (template_model_graph
                .join(visibility)
                .map(lambda record: invert_kernel(record, c["invert"], dopsf, normalize, context, **kwargs))
                .reduceByKey(sum_inver_image_reduce_kernel)
                .mapValues(normalized))

    # Previously fell through and returned None, failing later with an obscure error
    raise ValueError("Unsupported imaging context for invert_handle: %s" % context)
Esempio n. 6
0
File: bags.py Progetto: Jxt1/arlo
def predict_bag(vis_bag, model_bag, context='2d', **kwargs):
    """Construct a bag to predict a bag of visibilities.

    Each visibility is copied with zeroed data, scattered by the imaging
    context's scatter function, predicted with safe_predict_list, and the
    pieces are then concatenated. The sort order of the data is not
    necessarily preserved.

    Call directly - don't use via bag.map

    :param vis_bag: Bag of visibilities
    :param model_bag: Model bag passed to safe_predict_list alongside each scattered list
    :param context: Imaging context name; the context must define a 'scatter' function
    :param kwargs: Parameters passed to the scatter and predict steps
    :return: Bag of concatenated predicted visibilities
    """
    ctx = imaging_context(context)
    assert ctx['scatter'] is not None

    zeroed = vis_bag.map(copy_visibility, zero=True)
    scattered = zeroed.map(ctx['scatter'], **kwargs)
    predicted = scattered.map(safe_predict_list, model_bag, ctx['predict'], **kwargs)
    return predicted.map(concatenate_visibility)
def create_invert_graph(vis_graph_list, template_model_graph: delayed, dopsf=False, normalize=True,
                        facets=1, vis_slices=1, context='2d', **kwargs) -> delayed:
    """ Sum results from invert, iterating over the scattered image and vis_graph_list

    The template model is scattered into ``facets ** 2`` sub-images and each
    visibility into ``vis_slices`` slices. If the context's ``inner`` value is
    'vis', the slices are summed per facet and the facets are then gathered
    into one image. Otherwise the facets are gathered into an image per slice
    and the slice images are then summed.

    :param vis_graph_list: List of delayed visibilities
    :param template_model_graph: Model used to determine image parameters
    :param dopsf: Make the PSF instead of the dirty image
    :param facets: Number of facets
    :param normalize: Normalize by sumwt
    :param vis_slices: Number of slices
    :param context: Imaging context
    :param kwargs: Parameters for functions in graphs
    :return: List of delayed invert results, one per entry of vis_graph_list
    """
    c = imaging_context(context)
    image_iter = c['image_iterator']
    vis_iter = c['vis_iterator']
    inner = c['inner']

    def scatter_vis(vis):
        # Scatter along the visibility iteration axis. The rows are computed
        # on the coalesced visibility, so the slices must be cut from it too.
        if isinstance(vis, BlockVisibility):
            avis = coalesce_visibility(vis, **kwargs)
        else:
            avis = vis
        return [create_visibility_from_rows(avis, rows)
                for rows in vis_iter(avis, vis_slices=vis_slices, **kwargs)]

    def scatter_image_iteration(im):
        return [subim for subim in image_iter(im, facets=facets, **kwargs)]

    def gather_image_iteration_results(results, template_model):
        # Paste each facet's image into a fresh image shaped like the template.
        # results is indexed by facet position, so the index must advance for
        # every facet even when that facet's result is missing.
        result = create_empty_image_like(template_model)
        for i, dpatch in enumerate(image_iter(result, facets=facets, **kwargs)):
            if results[i] is not None:
                # assumes image_iter yields views into result's data - TODO confirm
                dpatch.data[...] = results[i][0].data[...]
        return result, results[0][1]

    def invert_ignore_none(vis, model):
        # An empty slice contributes an empty image with zero weights
        if vis is not None:
            return invert_context(vis, model, context=context, dopsf=dopsf, normalize=normalize,
                                  **kwargs)
        return create_empty_image_like(model), numpy.zeros([model.nchan, model.npol])

    # Scatter the model in e.g. facets
    model_graphs = delayed(scatter_image_iteration, nout=facets ** 2)(template_model_graph)

    # Loop over all vis_graphs independently
    results_vis_graph_list = list()
    for vis_graph in vis_graph_list:
        sub_vis_graphs = delayed(scatter_vis, nout=vis_slices)(vis_graph)
        if inner == 'vis':
            # Per facet: sum over visibility slices; then gather the facets
            model_results = list()
            for model_graph in model_graphs:
                model_vis_results = [delayed(invert_ignore_none, pure=True)(sub_vis_graph, model_graph)
                                     for sub_vis_graph in sub_vis_graphs]
                model_results.append(delayed(sum_invert_results)(model_vis_results))
            results_vis_graph_list.append(delayed(gather_image_iteration_results)(model_results,
                                                                                  template_model_graph))
        else:
            # Per slice: gather the facets into an image; then sum over slices
            vis_results = list()
            for sub_vis_graph in sub_vis_graphs:
                model_vis_results = [delayed(invert_ignore_none, pure=True)(sub_vis_graph, model_graph)
                                     for model_graph in model_graphs]
                vis_results.append(delayed(gather_image_iteration_results)(model_vis_results,
                                                                           template_model_graph))
            results_vis_graph_list.append(delayed(sum_invert_results)(vis_results))

    return results_vis_graph_list
def create_predict_graph(vis_graph_list,
                         model_graph: delayed,
                         vis_slices=1,
                         facets=1,
                         context='2d',
                         **kwargs):
    """Predict, iterating over both the scattered vis_graph_list and image

    For each entry of vis_graph_list the matching model is scattered into
    ``facets ** 2`` sub-images and the visibility into ``vis_slices``
    slices. Every (facet, slice) pair is predicted, the slices are gathered
    back into a full visibility per facet, and the facet visibilities are
    summed with sum_predict_results.

    :param vis_graph_list: List of delayed visibilities
    :param model_graph: Indexable collection of models; entry ``freqwin`` is
        used for ``vis_graph_list[freqwin]``
    :param vis_slices: Number of vis slices (w stack or timeslice)
    :param facets: Number of facets per axis (the model is scattered into facets**2 parts)
    :param context: Imaging context name passed to imaging_context
    :param kwargs: Parameters for functions in graphs
    :return: List of vis_graphs, one per entry of vis_graph_list
    """
    c = imaging_context(context)
    image_iter = c['image_iterator']
    vis_iter = c['vis_iterator']
    predict = c['predict']

    def predict_ignore_none(vis, model):
        # Empty slices arrive as None and are propagated unchanged
        if vis is not None:
            predicted = copy_visibility(vis)
            predicted = predict(predicted, model, context=context, **kwargs)
            return predicted
        else:
            return None

    def gather_vis(results, vis):
        # Gather across the visibility iteration axis
        assert vis is not None
        if isinstance(vis, BlockVisibility):
            # NOTE(review): assumes coalesce_visibility returns a new object - confirm
            avis = coalesce_visibility(vis, **kwargs)
        else:
            # gather_vis runs once per facet on the same input vis. Writing into
            # the input in place would make every facet result alias one object
            # (and clobber the caller's data), so write into a copy instead.
            avis = copy_visibility(vis)
        for i, rows in enumerate(
                vis_iter(avis, vis_slices=vis_slices, **kwargs)):
            assert i < len(results), "Insufficient results for the gather"
            if rows is not None and results[i] is not None:
                avis.data['vis'][rows] = results[i].data['vis']

        if isinstance(vis, BlockVisibility):
            return decoalesce_visibility(avis, **kwargs)
        return avis

    def scatter_vis(vis):
        # Scatter along the visibility iteration axis
        if isinstance(vis, BlockVisibility):
            avis = coalesce_visibility(vis, **kwargs)
        else:
            avis = vis
        return [
            create_visibility_from_rows(avis, rows)
            for rows in vis_iter(avis, vis_slices=vis_slices, **kwargs)
        ]

    def scatter_image(im):
        # Scatter across image iteration
        return [subim for subim in image_iter(im, facets=facets, **kwargs)]

    results_vis_graph_list = list()
    for freqwin, vis_graph in enumerate(vis_graph_list):
        sub_model_graphs = delayed(scatter_image,
                                   nout=facets**2)(model_graph[freqwin])
        sub_vis_graphs = delayed(scatter_vis, nout=vis_slices)(vis_graph)
        vis_graphs = list()
        for sub_model_graph in sub_model_graphs:
            # Predict every slice for this facet, then gather the slices
            sub_model_results = [
                delayed(predict_ignore_none, pure=True, nout=1)(sub_vis_graph,
                                                                sub_model_graph)
                for sub_vis_graph in sub_vis_graphs
            ]
            vis_graphs.append(
                delayed(gather_vis, nout=1)(sub_model_results, vis_graph))
        # Sum the per-facet full visibilities
        results_vis_graph_list.append(delayed(sum_predict_results)(vis_graphs))

    return results_vis_graph_list