Example #1
import dask.array as da
import dask.array.fft as dff  # assumed: the dff alias refers to dask's FFT module

def apply_filter_vector_dask_true(_filter, arr, chunk=5):
    # Apply (1 - _filter) in the Fourier domain to every slice of a 3-D stack using dask.
    out = arr.copy()
    # pad_next_square_size and reverse_padding are helpers assumed to be defined elsewhere
    # in this module; the first pads each slice to a square, the second undoes the padding.
    tc, slc = pad_next_square_size(out)
    tc = da.from_array(tc, (-1, -1, chunk))  # chunk only along the stacking axis
    _filter = da.from_array(_filter)
    temp = dff.ifft2(
        da.multiply(dff.ifftshift(1 - _filter[:, :, None]),
                    dff.fft2(tc, axes=(0, 1))),
        axes=(0, 1),
    ).real
    return reverse_padding(arr, temp, slc)
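
A minimal usage sketch, assuming pad_next_square_size returns a same-sized result for an already-square input and that _filter is a 2-D mask matching the padded slice size (all names and shapes below are illustrative):

import numpy as np

arr = np.random.rand(64, 64, 10)  # hypothetical 3-D image stack
mask = np.zeros((64, 64))         # 1 - mask == 1 everywhere, i.e. an identity filter
result = apply_filter_vector_dask_true(mask, arr, chunk=5)
# result may still be a lazy dask array; call .compute() on it if a concrete array is needed.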
Example #2
def make_gridding_convolution_function(vis_dataset, global_dataset, gcf_parms, grid_parms, storage_parms):
    """
    Currently creates a gcf to correct for the primary beams of the antennas and supports heterogeneous arrays (antennas with different dish sizes).
    Only the airy disk and ALMA airy disk models are implemented.
    In the future, support will be added for beam squint, pointing corrections, w projection, and a prolate spheroidal term.
    
    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset.
    gcf_parms : dictionary
    gcf_parms['function'] : {'alma_airy'/'airy'}, default = 'alma_airy'
        The primary beam model used (a function of the dish diameter and blockage diameter).
    gcf_parms['list_dish_diameters']  : list of number, units = meter
        A list of unique antenna dish diameters.
    gcf_parms['list_blockage_diameters']  : list of number, units = meter
        A list of unique feed blockage diameters (must be the same length as gcf_parms['list_dish_diameters']).
    gcf_parms['unique_ant_indx']  : list of int
        A list of indices into the gcf_parms['list_dish_diameters'] and gcf_parms['list_blockage_diameters'] lists, one entry per antenna.
    gcf_parms['image_phase_center']  : list of number, length = 2, units = radians
        The mosaic image phase center.
    gcf_parms['a_chan_num_chunk']  : int, default = 3
        The number of chunks in the channel dimension of the gridding convolution function data variable.
    gcf_parms['oversampling']  : list of int, length = 2, default = [10,10]
        The oversampling of the gridding convolution function.
    gcf_parms['max_support']  : list of int, length = 2, default = [15,15]
        The maximum allowable support of the gridding convolution function.
    gcf_parms['support_cut_level']  : number, default = 0.025
        The attenuation level at which to truncate the gridding convolution function.
    gcf_parms['chan_tolerance_factor']  : number, default = 0.005
        The fractional bandwidth at which the frequency dependence of the primary beam can be ignored; it determines the number of frequencies for which a gridding convolution function is calculated. The number of channels equals the fractional bandwidth divided by gcf_parms['chan_tolerance_factor'].
    grid_parms : dictionary
    grid_parms['image_size'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell_size']  : list of number, length = 2, units = arcseconds
        The image cell size.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If True, the dask graph is executed and the result is saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        If storage_parms['to_disk'] is True, only the dask graph associated with this function is executed and the resulting data variables are saved to an existing zarr file on disk.
        Note that graphs on data unrelated to this function will not be executed or saved.
    storage_parms['outfile'] : str
        The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time taken to compute and save the data is stored in the attributes of the dataset; storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc, default = Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.
    Returns
    -------
    gcf_dataset : xarray.core.dataset.Dataset
            
    """
    print('######################### Start make_gridding_convolution_function #########################')
    
    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms
    from ._imaging_utils._check_imaging_parms import _check_pb_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_parms, _check_gcf_parms
    from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel_2D, _create_prolate_spheroidal_image_2D
    from ._imaging_utils._remove_padding import _remove_padding
    import numpy as np
    import dask.array as da
    import copy, os
    import xarray as xr
    import itertools
    import dask
    import dask.array.fft as dafft
    
    import matplotlib.pylab as plt
    
    _gcf_parms = copy.deepcopy(gcf_parms)
    _grid_parms = copy.deepcopy(grid_parms)
    _storage_parms = copy.deepcopy(storage_parms)
    
    _gcf_parms['basline_ant'] = vis_dataset.antennas.values # n_baseline x 2 (ant pair)
    _gcf_parms['freq_chan'] = vis_dataset.chan.values
    _gcf_parms['pol'] = vis_dataset.pol.values
    _gcf_parms['vis_data_chunks'] = vis_dataset.DATA.chunks
    _gcf_parms['field_phase_dir'] = np.array(global_dataset.FIELD_PHASE_DIR.values[:,:,vis_dataset.attrs['ddi']])
    
    assert(_check_gcf_parms(_gcf_parms)), "######### ERROR: gcf_parms checking failed"
    assert(_check_grid_parms(_grid_parms)), "######### ERROR: grid_parms checking failed"
    assert(_check_storage_parms(_storage_parms,'dataset.gcf.zarr','make_gcf')), "######### ERROR: user_storage_parms checking failed"
    
    assert(not _storage_parms['append']), "######### ERROR: storage_parms['append'] = True is not available for make_gridding_convolution_function"
        
    if _gcf_parms['function'] == 'airy':
        from ._imaging_utils._make_pb_symmetric import _airy_disk_rorder
        pb_func = _airy_disk_rorder
    elif _gcf_parms['function'] == 'alma_airy':
        from ._imaging_utils._make_pb_symmetric import _alma_airy_disk_rorder
        pb_func = _alma_airy_disk_rorder
    else:
        assert(False), "######### ERROR: Only the airy and alma_airy functions have been implemented"
        
    #For now only a_term works
    _gcf_parms['a_term'] =  True
    _gcf_parms['ps_term'] =  False
        
    _gcf_parms['resize_conv_size'] = (_gcf_parms['max_support'] + 1)*_gcf_parms['oversampling']
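    # With the documented defaults max_support = [15,15] and oversampling = [10,10], this
    # yields a (15+1)*10 = 160 pixel kernel per axis (both parms are assumed to have been
    # converted to numpy arrays by the parameter checks, so the arithmetic is elementwise).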
    
    if _gcf_parms['ps_term'] == True:
        '''
        ps_term = _create_prolate_spheroidal_kernel_2D(_gcf_parms['oversampling'],np.array([7,7])) #This is only used with a_term == False. Support is hardcoded to 7 until old ps code is replaced by a general function.
        center = _grid_parms['image_center']
        center_embed = np.array(ps_term.shape)//2
        ps_term_padded = np.zeros(_grid_parms['image_size'])
        ps_term_padded[center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        ps_term_padded_ifft = dafft.fftshift(dafft.ifft2(dafft.ifftshift(da.from_array(ps_term_padded))))

        ps_image = da.from_array(_remove_padding(_create_prolate_spheroidal_image_2D(_grid_parms['image_size_padded']),_grid_parms['image_size']),chunks=_grid_parms['image_size'])
        
        #Effectively no mapping needed if ps_term == True and a_term == False
        cf_baseline_map = np.zeros((len(_gcf_parms['basline_ant']),),dtype=int)
        cf_chan_map = np.zeros((len(_gcf_parms['freq_chan']),),dtype=int)
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int)
        '''
    
    if _gcf_parms['a_term'] == True:
        n_unique_ant = len(_gcf_parms['list_dish_diameters'])
        cf_baseline_map,pb_ant_pairs = create_cf_baseline_map(_gcf_parms['unique_ant_indx'],_gcf_parms['basline_ant'],n_unique_ant)
        
        cf_chan_map, pb_freq = create_cf_chan_map(_gcf_parms['freq_chan'],_gcf_parms['chan_tolerance_factor'])
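        # Per the docstring, the number of gcf frequencies is roughly the fractional
        # bandwidth divided by chan_tolerance_factor; e.g. a 1% fractional bandwidth with
        # the default factor of 0.005 gives about 0.01/0.005 = 2 frequencies.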
        pb_freq = da.from_array(pb_freq,chunks=np.ceil(len(pb_freq)/_gcf_parms['a_chan_num_chunk'] ))
        
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int) #create_cf_pol_map(), currently treating all pols the same
        pb_pol = da.from_array(np.array([0]),1)
        
        n_chunks_in_each_dim = [pb_freq.numblocks[0],pb_pol.numblocks[0]]
        iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]))
        chan_chunk_sizes = pb_freq.chunks
        pol_chunk_sizes = pb_pol.chunks
        
        list_baseline_pb = []
        list_weight_baseline_pb_sqrd = []
        for c_chan, c_pol in iter_chunks_indx:
                _gcf_parms['ipower'] = 1
                delayed_baseline_pb = dask.delayed(make_baseline_patterns)(pb_freq.partitions[c_chan],pb_pol.partitions[c_pol],dask.delayed(pb_ant_pairs),dask.delayed(pb_func),dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                
                list_baseline_pb.append(da.from_delayed(delayed_baseline_pb,(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_grid_parms['image_size_padded'][0],_grid_parms['image_size_padded'][1]),dtype=np.double))
                              
                _gcf_parms['ipower'] = 2
                delayed_weight_baseline_pb_sqrd = dask.delayed(make_baseline_patterns)(pb_freq.partitions[c_chan],pb_pol.partitions[c_pol],dask.delayed(pb_ant_pairs),dask.delayed(pb_func),dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                
                list_weight_baseline_pb_sqrd.append(da.from_delayed(delayed_weight_baseline_pb_sqrd,(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_grid_parms['image_size_padded'][0],_grid_parms['image_size_padded'][1]),dtype=np.double))
               
        
        baseline_pb = da.concatenate(list_baseline_pb,axis=1)
        weight_baseline_pb_sqrd = da.concatenate(list_weight_baseline_pb_sqrd,axis=1)
    
    #Combine patterns and fft to obtain the gridding convolutional kernel

    dataset_dict = {}
    list_xarray_data_variables = []
    if (_gcf_parms['a_term'] == True) and (_gcf_parms['ps_term'] == True):
        conv_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(ps_term_padded_ifft*baseline_pb, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_weight_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        
        
        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
                delayed_kernels_and_support = dask.delayed(resize_and_calc_support)(conv_kernel.partitions[:,c_chan,c_pol,:,:],conv_weight_kernel.partitions[:,c_chan,c_pol,:,:],dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                list_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[0],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                list_weight_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[1],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                # np.int was removed from recent NumPy releases; use the builtin int
                list_conv_support.append(da.from_delayed(delayed_kernels_and_support[2],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],2),dtype=int))
                
        
        conv_kernel = da.concatenate(list_conv_kernel,axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel,axis=1)
        conv_support = da.concatenate(list_conv_support,axis=1)
        
    
        dataset_dict['SUPPORT'] = xr.DataArray(conv_support, dims=['conv_baseline','conv_chan','conv_pol','xy'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l','m'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(weight_conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
    elif (_gcf_parms['a_term'] == False) and (_gcf_parms['ps_term'] == True):
        support = np.array([7,7])
        dataset_dict['SUPPORT'] = xr.DataArray(support[None,None,None,:], dims=['conv_baseline','conv_chan','conv_pol','xy'])
        conv_kernel = np.zeros((1,1,1,_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]))
        center = _gcf_parms['resize_conv_size']//2
        center_embed = np.array(ps_term.shape)//2
        conv_kernel[0,0,0,center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l','m'])
        ##Enabled for test
        #dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
    elif (_gcf_parms['a_term'] == True) and (_gcf_parms['ps_term'] == False):
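        # The baseline primary-beam patterns live in the image plane; the
        # ifftshift/fft2/fftshift chain over the last two axes transforms them to the
        # uv plane, where they serve as the gridding convolution kernels.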
        conv_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(baseline_pb, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_weight_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        
        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
                delayed_kernels_and_support = dask.delayed(resize_and_calc_support)(conv_kernel.partitions[:,c_chan,c_pol,:,:],conv_weight_kernel.partitions[:,c_chan,c_pol,:,:],dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                list_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[0],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                list_weight_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[1],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                list_conv_support.append(da.from_delayed(delayed_kernels_and_support[2],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],2),dtype=int))
                
        
        conv_kernel = da.concatenate(list_conv_kernel,axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel,axis=1)
        conv_support = da.concatenate(list_conv_support,axis=1)
        
    
        dataset_dict['SUPPORT'] = xr.DataArray(conv_support, dims=['conv_baseline','conv_chan','conv_pol','xy'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(weight_conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(da.from_array(np.ones(_grid_parms['image_size']),chunks=_grid_parms['image_size']), dims=['l','m'])
    else:
        assert(False), "######### ERROR: At least 'a_term' or 'ps_term' must be true."
    
    ###########################################################
    #Make phase gradient (one for each field)
    field_phase_dir = _gcf_parms['field_phase_dir']
    field_phase_dir = da.from_array(field_phase_dir,chunks=(np.ceil(len(field_phase_dir)/_gcf_parms['a_chan_num_chunk']),2))
    
    phase_gradient = da.blockwise(make_phase_gradient, ("n_field","n_x","n_y"), field_phase_dir, ("n_field","2"), gcf_parms=_gcf_parms, grid_parms=_grid_parms, dtype=complex,  new_axes={"n_x": _gcf_parms['resize_conv_size'][0], "n_y": _gcf_parms['resize_conv_size'][1]})
    

    ###########################################################
    
    
    coords = { 'u': np.arange(_gcf_parms['resize_conv_size'][0]), 'v': np.arange(_gcf_parms['resize_conv_size'][1]), 'xy':np.arange(2), 'field':np.arange(field_phase_dir.shape[0]),'l':np.arange(_grid_parms['image_size'][0]),'m':np.arange(_grid_parms['image_size'][1])}
    
    dataset_dict['CF_BASELINE_MAP'] = xr.DataArray(cf_baseline_map, dims=('baseline')).chunk(_gcf_parms['vis_data_chunks'][1])
    dataset_dict['CF_CHAN_MAP'] = xr.DataArray(cf_chan_map, dims=('chan')).chunk(_gcf_parms['vis_data_chunks'][2])
    dataset_dict['CF_POL_MAP'] = xr.DataArray(cf_pol_map, dims=('pol')).chunk(_gcf_parms['vis_data_chunks'][3])
    
        
    dataset_dict['CONV_KERNEL'] = xr.DataArray(conv_kernel, dims=('conv_baseline','conv_chan','conv_pol','u','v'))
    dataset_dict['PHASE_GRADIENT'] = xr.DataArray(phase_gradient, dims=('field','u','v'))
    
    gcf_dataset = xr.Dataset(dataset_dict, coords=coords)
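    # cell_uv is the uv-plane cell size of the oversampled kernel: the reciprocal of the
    # padded image extent, scaled down further by the oversampling factor.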
    gcf_dataset.attrs['cell_uv'] = 1/(_grid_parms['image_size_padded']*_grid_parms['cell_size']*_gcf_parms['oversampling'])
    gcf_dataset.attrs['oversampling'] = _gcf_parms['oversampling']
    
    
    return _store(gcf_dataset,list_xarray_data_variables,_storage_parms)
def make_gridding_convolution_function(mxds, gcf_parms, grid_parms, sel_parms):
    """
    Currently creates a gcf to correct for the primary beams of the antennas and supports heterogeneous arrays (antennas with different dish sizes).
    Only the airy disk and CASA airy disk models are implemented.
    In the future, support will be added for beam squint, pointing corrections, w projection, and a prolate spheroidal term.
    
    Parameters
    ----------
    mxds : xarray.core.dataset.Dataset
        Input multi-xarray Dataset; the visibility dataset to use is selected with sel_parms['xds'].
    gcf_parms : dictionary
    gcf_parms['function'] : {'casa_airy'/'airy'}, default = 'casa_airy'
        The primary beam model used (a function of the dish diameter and blockage diameter).
    gcf_parms['list_dish_diameters']  : list of number, units = meter
        A list of unique antenna dish diameters.
    gcf_parms['list_blockage_diameters']  : list of number, units = meter
        A list of unique feed blockage diameters (must be the same length as gcf_parms['list_dish_diameters']).
    gcf_parms['unique_ant_indx']  : list of int
        A list of indices into the gcf_parms['list_dish_diameters'] and gcf_parms['list_blockage_diameters'] lists, one entry per antenna.
    gcf_parms['image_phase_center']  : list of number, length = 2, units = radians
        The mosaic image phase center.
    gcf_parms['a_chan_num_chunk']  : int, default = 3
        The number of chunks in the channel dimension of the gridding convolution function data variable.
    gcf_parms['oversampling']  : list of int, length = 2, default = [10,10]
        The oversampling of the gridding convolution function.
    gcf_parms['max_support']  : list of int, length = 2, default = [15,15]
        The maximum allowable support of the gridding convolution function.
    gcf_parms['support_cut_level']  : number, default = 0.025
        The attenuation level at which to truncate the gridding convolution function.
    gcf_parms['chan_tolerance_factor']  : number, default = 0.005
        The fractional bandwidth at which the frequency dependence of the primary beam can be ignored; it determines the number of frequencies for which a gridding convolution function is calculated. The number of channels equals the fractional bandwidth divided by gcf_parms['chan_tolerance_factor'].
    grid_parms : dictionary
    grid_parms['image_size'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell_size']  : list of number, length = 2, units = arcseconds
        The image cell size.
    sel_parms : dictionary
    sel_parms['xds'] : str
        The name of the visibility dataset in mxds to use (must be specified; there is no default since xds names are not fixed).
    Returns
    -------
    gcf_dataset : xarray.core.dataset.Dataset
            
    """
    print(
        '######################### Start make_gridding_convolution_function #########################'
    )

    from ._imaging_utils._check_imaging_parms import _check_pb_parms
    from cngi._utils._check_parms import _check_sel_parms, _check_existence_sel_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_parms, _check_gcf_parms
    from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel_2D, _create_prolate_spheroidal_image_2D
    from ._imaging_utils._remove_padding import _remove_padding
    import numpy as np
    import dask.array as da
    import copy, os
    import xarray as xr
    import itertools
    import dask
    import dask.array.fft as dafft
    import time

    import matplotlib.pylab as plt

    #Deep copy so that inputs are not modified
    _mxds = mxds.copy(deep=True)
    _gcf_parms = copy.deepcopy(gcf_parms)
    _grid_parms = copy.deepcopy(grid_parms)
    _sel_parms = copy.deepcopy(sel_parms)

    ##############Parameter Checking and Set Defaults##############
    assert (
        'xds' in _sel_parms
    ), "######### ERROR: xds must be specified in sel_parms"  #Can't have a default since xds names are not fixed.
    _vis_dataset = _mxds.attrs[sel_parms['xds']]

    _check_sel_parms(_vis_dataset, _sel_parms)

    #_gcf_parms['basline_ant'] = np.unique([_vis_dataset.ANTENNA1.max(axis=0), _vis_dataset.ANTENNA2.max(axis=0)], axis=0).T
    _gcf_parms['basline_ant'] = np.array(
        [_vis_dataset.ANTENNA1.values, _vis_dataset.ANTENNA2.values]).T

    _gcf_parms['freq_chan'] = _vis_dataset.chan.values
    _gcf_parms['pol'] = _vis_dataset.pol.values
    _gcf_parms['vis_data_chunks'] = _vis_dataset.DATA.chunks

    _gcf_parms['field_phase_dir'] = mxds.FIELD.PHASE_DIR[:,
                                                         0, :].data.compute()
    field_id = mxds.FIELD.field_id.data  #.compute()


    assert (_check_gcf_parms(_gcf_parms)
            ), "######### ERROR: gcf_parms checking failed"
    assert (_check_grid_parms(_grid_parms)
            ), "######### ERROR: grid_parms checking failed"

    if _gcf_parms['function'] == 'airy':
        from ._imaging_utils._make_pb_symmetric import _airy_disk_rorder
        pb_func = _airy_disk_rorder
    elif _gcf_parms['function'] == 'casa_airy':
        from ._imaging_utils._make_pb_symmetric import _casa_airy_disk_rorder
        pb_func = _casa_airy_disk_rorder
    else:
        assert (
            False
        ), "######### ERROR: Only the airy and casa_airy functions have been implemented"

    #For now only a_term works
    _gcf_parms['a_term'] = True
    _gcf_parms['ps_term'] = False

    _gcf_parms['resize_conv_size'] = (_gcf_parms['max_support'] +
                                      1) * _gcf_parms['oversampling']

    if _gcf_parms['ps_term'] == True:
        '''
        ps_term = _create_prolate_spheroidal_kernel_2D(_gcf_parms['oversampling'],np.array([7,7])) #This is only used with a_term == False. Support is hardcoded to 7 until old ps code is replaced by a general function.
        center = _grid_parms['image_center']
        center_embed = np.array(ps_term.shape)//2
        ps_term_padded = np.zeros(_grid_parms['image_size'])
        ps_term_padded[center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        ps_term_padded_ifft = dafft.fftshift(dafft.ifft2(dafft.ifftshift(da.from_array(ps_term_padded))))

        ps_image = da.from_array(_remove_padding(_create_prolate_spheroidal_image_2D(_grid_parms['image_size_padded']),_grid_parms['image_size']),chunks=_grid_parms['image_size'])

        #Effectively no mapping needed if ps_term == True and a_term == False
        cf_baseline_map = np.zeros((len(_gcf_parms['basline_ant']),),dtype=int)
        cf_chan_map = np.zeros((len(_gcf_parms['freq_chan']),),dtype=int)
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int)
        '''

    if _gcf_parms['a_term'] == True:
        n_unique_ant = len(_gcf_parms['list_dish_diameters'])

        cf_baseline_map, pb_ant_pairs = create_cf_baseline_map(
            _gcf_parms['unique_ant_indx'], _gcf_parms['basline_ant'],
            n_unique_ant)

        cf_chan_map, pb_freq = create_cf_chan_map(
            _gcf_parms['freq_chan'], _gcf_parms['chan_tolerance_factor'])
        pb_freq = da.from_array(
            pb_freq,
            chunks=np.ceil(len(pb_freq) / _gcf_parms['a_chan_num_chunk']))

        cf_pol_map = np.zeros(
            (len(_gcf_parms['pol']), ), dtype=int
        )  #create_cf_pol_map(), currently treating all pols the same
        pb_pol = da.from_array(np.array([0]), 1)
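        # A single polarization entry: all polarizations currently share one gridding
        # convolution function (see cf_pol_map above).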

        n_chunks_in_each_dim = [pb_freq.numblocks[0], pb_pol.numblocks[0]]
        iter_chunks_indx = itertools.product(
            np.arange(n_chunks_in_each_dim[0]),
            np.arange(n_chunks_in_each_dim[1]))
        chan_chunk_sizes = pb_freq.chunks
        pol_chunk_sizes = pb_pol.chunks

        list_baseline_pb = []
        list_weight_baseline_pb_sqrd = []
        for c_chan, c_pol in iter_chunks_indx:
            _gcf_parms['ipower'] = 1
            delayed_baseline_pb = dask.delayed(make_baseline_patterns)(
                pb_freq.partitions[c_chan], pb_pol.partitions[c_pol],
                dask.delayed(pb_ant_pairs), dask.delayed(pb_func),
                dask.delayed(_gcf_parms), dask.delayed(_grid_parms))

            list_baseline_pb.append(
                da.from_delayed(
                    delayed_baseline_pb,
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _grid_parms['image_size_padded'][0],
                     _grid_parms['image_size_padded'][1]),
                    dtype=np.double))

            _gcf_parms['ipower'] = 2
            delayed_weight_baseline_pb_sqrd = dask.delayed(
                make_baseline_patterns)(pb_freq.partitions[c_chan],
                                        pb_pol.partitions[c_pol],
                                        dask.delayed(pb_ant_pairs),
                                        dask.delayed(pb_func),
                                        dask.delayed(_gcf_parms),
                                        dask.delayed(_grid_parms))

            list_weight_baseline_pb_sqrd.append(
                da.from_delayed(
                    delayed_weight_baseline_pb_sqrd,
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _grid_parms['image_size_padded'][0],
                     _grid_parms['image_size_padded'][1]),
                    dtype=np.double))

        baseline_pb = da.concatenate(list_baseline_pb, axis=1)
        weight_baseline_pb_sqrd = da.concatenate(list_weight_baseline_pb_sqrd,
                                                 axis=1)

    #Combine patterns and fft to obtain the gridding convolutional kernel

    dataset_dict = {}
    list_xarray_data_variables = []
    if (_gcf_parms['a_term'] == True) and (_gcf_parms['ps_term'] == True):
        conv_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(ps_term_padded_ifft *
                                                      baseline_pb,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))
        conv_weight_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))

        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(
            np.arange(n_chunks_in_each_dim[0]),
            np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
            delayed_kernels_and_support = dask.delayed(
                resize_and_calc_support)(
                    conv_kernel.partitions[:, c_chan, c_pol, :, :],
                    conv_weight_kernel.partitions[:, c_chan, c_pol, :, :],
                    dask.delayed(_gcf_parms), dask.delayed(_grid_parms))
            list_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[0],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_weight_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[1],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_conv_support.append(
                da.from_delayed(
                    delayed_kernels_and_support[2],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol], 2),
                    dtype=int))  # np.int was removed from recent NumPy

        conv_kernel = da.concatenate(list_conv_kernel, axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel, axis=1)
        conv_support = da.concatenate(list_conv_support, axis=1)

        dataset_dict['SUPPORT'] = xr.DataArray(
            conv_support,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'xy'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l', 'm'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(
            weight_conv_kernel,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'u', 'v'])
    elif (_gcf_parms['a_term'] == False) and (_gcf_parms['ps_term'] == True):
        support = np.array([7, 7])
        dataset_dict['SUPPORT'] = xr.DataArray(
            support[None, None, None, :],
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'xy'])
        conv_kernel = np.zeros((1, 1, 1, _gcf_parms['resize_conv_size'][0],
                                _gcf_parms['resize_conv_size'][1]))
        center = _gcf_parms['resize_conv_size'] // 2
        center_embed = np.array(ps_term.shape) // 2
        conv_kernel[0, 0, 0,
                    center[0] - center_embed[0]:center[0] + center_embed[0],
                    center[1] - center_embed[1]:center[1] +
                    center_embed[1]] = ps_term
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l', 'm'])
        ##Enabled for test
        #dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
    elif (_gcf_parms['a_term'] == True) and (_gcf_parms['ps_term'] == False):
        conv_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(baseline_pb,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))
        conv_weight_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))

        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(
            np.arange(n_chunks_in_each_dim[0]),
            np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
            delayed_kernels_and_support = dask.delayed(
                resize_and_calc_support)(
                    conv_kernel.partitions[:, c_chan, c_pol, :, :],
                    conv_weight_kernel.partitions[:, c_chan, c_pol, :, :],
                    dask.delayed(_gcf_parms), dask.delayed(_grid_parms))
            list_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[0],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_weight_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[1],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_conv_support.append(
                da.from_delayed(
                    delayed_kernels_and_support[2],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol], 2),
                    dtype=int))

        conv_kernel = da.concatenate(list_conv_kernel, axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel, axis=1)
        conv_support = da.concatenate(list_conv_support, axis=1)

        dataset_dict['SUPPORT'] = xr.DataArray(
            conv_support,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'xy'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(
            weight_conv_kernel,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'u', 'v'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(da.from_array(
            np.ones(_grid_parms['image_size']),
            chunks=_grid_parms['image_size']),
                                                     dims=['l', 'm'])
    else:
        assert (
            False
        ), "######### ERROR: At least 'a_term' or 'ps_term' must be true."

    ###########################################################
    #Make phase gradient (one for each field)
    field_phase_dir = _gcf_parms['field_phase_dir']
    field_phase_dir = da.from_array(
        field_phase_dir,
        chunks=(np.ceil(len(field_phase_dir) / _gcf_parms['a_chan_num_chunk']),
                2))

    phase_gradient = da.blockwise(make_phase_gradient,
                                  ("n_field", "n_x", "n_y"),
                                  field_phase_dir, ("n_field", "2"),
                                  gcf_parms=_gcf_parms,
                                  grid_parms=_grid_parms,
                                  dtype=complex,
                                  new_axes={
                                      "n_x": _gcf_parms['resize_conv_size'][0],
                                      "n_y": _gcf_parms['resize_conv_size'][1]
                                  })
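    # phase_gradient has shape (n_field, resize_conv_size[0], resize_conv_size[1]): one
    # complex phase screen per field, presumably applied during gridding to shift each
    # field to the common mosaic phase center.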

    ###########################################################

    coords = {
        'u': np.arange(_gcf_parms['resize_conv_size'][0]),
        'v': np.arange(_gcf_parms['resize_conv_size'][1]),
        'xy': np.arange(2),
        'field_id': field_id,
        'l': np.arange(_grid_parms['image_size'][0]),
        'm': np.arange(_grid_parms['image_size'][1])
    }

    dataset_dict['CF_BASELINE_MAP'] = xr.DataArray(
        cf_baseline_map,
        dims=('baseline')).chunk(_gcf_parms['vis_data_chunks'][1])
    dataset_dict['CF_CHAN_MAP'] = xr.DataArray(
        cf_chan_map, dims=('chan')).chunk(_gcf_parms['vis_data_chunks'][2])
    dataset_dict['CF_POL_MAP'] = xr.DataArray(cf_pol_map, dims=('pol')).chunk(
        _gcf_parms['vis_data_chunks'][3])

    dataset_dict['CONV_KERNEL'] = xr.DataArray(conv_kernel,
                                               dims=('conv_baseline',
                                                     'conv_chan', 'conv_pol',
                                                     'u', 'v'))
    dataset_dict['PHASE_GRADIENT'] = xr.DataArray(phase_gradient,
                                                  dims=('field_id', 'u', 'v'))

    gcf_dataset = xr.Dataset(dataset_dict, coords=coords)
    gcf_dataset.attrs['cell_uv'] = 1 / (_grid_parms['image_size_padded'] *
                                        _grid_parms['cell_size'] *
                                        _gcf_parms['oversampling'])
    gcf_dataset.attrs['oversampling'] = _gcf_parms['oversampling']


    print(
        '#########################  Created graph for make_gridding_convolution_function #########################'
    )

    return gcf_dataset
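
A hedged call sketch for this mxds-based version. All values below are illustrative; mxds is assumed to come from CNGI (e.g. via cngi.dio.read_vis), and the dictionary keys follow the docstring above:

import numpy as np
from cngi.dio import read_vis  # assumed CNGI reader that returns an mxds

mxds = read_vis('my_measurement_set.vis.zarr')  # hypothetical input file

gcf_parms = {
    'function': 'casa_airy',
    'list_dish_diameters': np.array([12.0]),        # meters
    'list_blockage_diameters': np.array([0.75]),    # meters
    'unique_ant_indx': np.zeros(4, dtype=int),      # one entry per antenna (4 here, hypothetical)
    'image_phase_center': np.array([2.93, -0.46]),  # radians, hypothetical mosaic phase center
}
grid_parms = {
    'image_size': np.array([500, 500]),
    'cell_size': np.array([0.08, 0.08]),  # arcseconds
}
sel_parms = {'xds': 'xds0'}  # which visibility dataset inside mxds to use

gcf_dataset = make_gridding_convolution_function(mxds, gcf_parms, grid_parms, sel_parms)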