Example #1
def make_imaging_weight(vis_dataset, imaging_weights_parms, storage_parms):
    """
    Creates the imaging weight data variable that has dimensions time x baseline x chan x pol (matches the visibility data variable).
    The weight density can be averaged over channels or calculated independently for each channel using imaging_weights_parms['chan_mode'].
    The following imaging weighting schemes are documented: 'natural', 'uniform', 'briggs', 'briggs_abs'.
    NOTE(review): the statement that applied the non-natural schemes is missing from this copy of the function
    (its `if` had no body), so any imaging_weights_parms['weighting'] other than 'natural' currently raises
    NotImplementedError instead of silently returning natural weights. Restore that call to re-enable them.
    The imaging_weights_parms['imsize'] and imaging_weights_parms['cell'] should usually be the same values that will be used for subsequent synthesis blocks (for example making the psf).
    To achieve something similar to 'superuniform' weighting in CASA tclean imaging_weights_parms['imsize'] and imaging_weights_parms['cell'] can be varied relative to the values used in subsequent synthesis blocks.

    The starting weights are taken from WEIGHT_SPECTRUM if present (even when WEIGHT is also present),
    otherwise from WEIGHT, otherwise every visibility gets a weight of one.

    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset. It is modified in place: the imaging weight data variable is added to it.
    imaging_weights_parms : dictionary
    imaging_weights_parms['weighting'] : {'natural', 'uniform', 'briggs', 'briggs_abs'}, default = natural
        Weighting scheme used for creating the imaging weights.
    imaging_weights_parms['imsize'] : list of int, length = 2
        The size of the grid for gridding the imaging weights. Used when imaging_weights_parms['weighting'] is not 'natural'.
    imaging_weights_parms['cell']  : list of number, length = 2, units = arcseconds
        The size of the pixels in the fft of the grid (the image domain pixel size). Used when imaging_weights_parms['weighting'] is not 'natural'.
    imaging_weights_parms['robust'] : number, acceptable range [-2,2], default = 0.5
        Robustness parameter for Briggs weighting.
        robust = -2.0 maps to uniform weighting.
        robust = +2.0 maps to natural weighting.
    imaging_weights_parms['briggs_abs_noise'] : number, default=1.0
        Noise parameter for imaging_weights_parms['weighting']='briggs_abs' mode weighting.
    imaging_weights_parms['chan_mode'] : {'continuum'/'cube'}, default = 'continuum'
        When 'cube' the weights are calculated independently for each channel (perchanweightdensity=True in CASA tclean) and when 'continuum' a common weight density is calculated for all channels.
    imaging_weights_parms['uvw_name'] : str, default ='UVW'
        The name of uvw data variable that will be used to grid the weights. Used when imaging_weights_parms['weighting'] is not 'natural'.
    imaging_weights_parms['data_name'] : str, default = 'DATA'
        The name of the visibility data variable whose dimensions will be used to construct the imaging weight data variable.
    imaging_weights_parms['imaging_weight_name'] : str, default ='IMAGING_WEIGHT'
        The name that will be used for the imaging weight data variable.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        If storage_parms['to_disk'] is True only the dask graph associated with the function is executed and the resulting data variables are saved to an existing zarr file on disk.
        Note that graphs on unrelated data to this function will not be executed or saved.
    storage_parms['outfile'] : str
        The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc, default=Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.

    Returns
    -------
    vis_dataset : xarray.core.dataset.Dataset
        The vis_dataset will contain a new data variable for the imaging weights the name is defined by the input parameter imaging_weights_parms['imaging_weight_name'].

    Raises
    ------
    AssertionError
        If the parameter checks fail.
    NotImplementedError
        If imaging_weights_parms['weighting'] is not 'natural' (see the NOTE above).
    """
    print('######################### Start make_imaging_weights #########################')
    import copy
    import xarray as xr
    import dask.array as da
    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms
    from ._imaging_utils._check_imaging_parms import _check_imaging_weights_parms

    # Deep-copy so nothing done below (including the _check_* calls) can mutate the caller's dicts.
    _imaging_weights_parms = copy.deepcopy(imaging_weights_parms)
    _storage_parms = copy.deepcopy(storage_parms)

    assert(_check_imaging_weights_parms(vis_dataset,_imaging_weights_parms)), "######### ERROR: imaging_weights_parms checking failed"
    assert(_check_storage_parms(_storage_parms,'dataset.vis.zarr','make_imaging_weights')), "######### ERROR: storage_parms checking failed"

    if _imaging_weights_parms['weighting'] != 'natural':
        # NOTE(review): the body of this branch (the step that turns the natural weights into
        # uniform/briggs weights) is missing. Fail before vis_dataset is modified rather than
        # silently returning natural weights; restore the original call here.
        raise NotImplementedError(
            "make_imaging_weight: weighting '{}' cannot be applied because the non-natural weighting step is missing; only 'natural' is available.".format(
                _imaging_weights_parms['weighting']))

    imaging_weight_name = _imaging_weights_parms['imaging_weight_name']
    vis_data = vis_dataset[_imaging_weights_parms['data_name']]

    # Choose the source of the starting weights: WEIGHT_SPECTRUM takes precedence over WEIGHT;
    # with neither present, every visibility gets unit weight.
    weight_present = 'WEIGHT' in vis_dataset.data_vars
    weight_spectrum_present = 'WEIGHT_SPECTRUM' in vis_dataset.data_vars
    if weight_present and weight_spectrum_present:
        print('Both WEIGHT and WEIGHT_SPECTRUM data variables found, will use WEIGHT_SPECTRUM to calculate', imaging_weight_name)
        imaging_weight = _match_array_shape(vis_dataset.WEIGHT_SPECTRUM, vis_data)
    elif weight_present:
        print('WEIGHT data variable found, will use WEIGHT to calculate ', imaging_weight_name)
        imaging_weight = _match_array_shape(vis_dataset.WEIGHT, vis_data)
    elif weight_spectrum_present:
        print('WEIGHT_SPECTRUM  data variable found, will use WEIGHT_SPECTRUM to calculate ', imaging_weight_name)
        imaging_weight = _match_array_shape(vis_dataset.WEIGHT_SPECTRUM, vis_data)
    else:
        print('No WEIGHT or WEIGHT_SPECTRUM data variable found,  will assume all weights are unity to calculate ', imaging_weight_name)
        # Unit weights chunked like the visibility data variable (only this branch needs its chunking).
        imaging_weight = da.ones(vis_data.shape, chunks=vis_data.data.chunksize)

    # Adds the new data variable to the caller's dataset (in-place modification).
    vis_dataset[imaging_weight_name] = xr.DataArray(imaging_weight, dims=vis_data.dims)

    list_xarray_data_variables = [vis_dataset[imaging_weight_name]]
    return _store(vis_dataset, list_xarray_data_variables, _storage_parms)
Example #2
def make_gridding_convolution_function(vis_dataset, global_dataset, gcf_parms, grid_parms, storage_parms):
    """
    Currently creates a gcf to correct for the primary beams of antennas and supports heterogeneous arrays (antennas with different dish sizes).
    Only the airy disk and ALMA airy disk model is implemented.
    In the future support will be added for beam squint, pointing corrections, w projection, and including a prolate spheroidal term.

    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset. Its antennas, chan and pol coordinates, the chunking of its DATA variable
        and attrs['ddi'] are read.
    global_dataset : xarray.core.dataset.Dataset
        Global dataset; FIELD_PHASE_DIR[:, :, vis_dataset.attrs['ddi']] supplies the field phase directions.
    gcf_parms : dictionary
    gcf_parms['function'] : {'alma_airy'/'airy'}, default = 'alma_airy'
        The primary beam model used (a function of the dish diameter and blockage diameter).
    gcf_parms['list_dish_diameters']  : list of number, units = meter
        A list of unique antenna dish diameters.
    gcf_parms['list_blockage_diameters']  : list of number, units = meter
        A list of unique feed blockage diameters (must be the same length as gcf_parms['list_dish_diameters']).
    gcf_parms['unique_ant_indx']  : list of int
        A list that has indices for the gcf_parms['list_dish_diameters'] and gcf_parms['list_blockage_diameters'] lists, for each antenna.
    gcf_parms['image_phase_center']  : list of number, length = 2, units = radians
        The mosaic image phase center.
    gcf_parms['a_chan_num_chunk']  : int, default = 3
        The number of chunks in the channel dimension of the gridding convolution function data variable.
    gcf_parms['oversampling']  : list of int, length = 2, default = [10,10]
        The oversampling of the gridding convolution function.
    gcf_parms['max_support']  : list of int, length = 2, default = [15,15]
        The maximum allowable support of the gridding convolution function.
    gcf_parms['support_cut_level']  : number, default = 0.025
        The attenuation at which to truncate the gridding convolution function.
    gcf_parms['chan_tolerance_factor']  : number, default = 0.005
        It is the fractional bandwidth at which the frequency dependence of the primary beam can be ignored and determines the number of frequencies for which to calculate a gridding convolution function. Number of channels equals the fractional bandwidth divided by gcf_parms['chan_tolerance_factor'].
    grid_parms : dictionary
    grid_parms['image_size'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell_size']  : list of number, length = 2, units = arcseconds
        The image cell size.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        Must be False; appending is not available for this function.
    storage_parms['outfile'] : str
        The zarr file to create.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc, default=Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.

    Returns
    -------
    gcf_dataset : xarray.core.dataset.Dataset
        Dataset with CONV_KERNEL, WEIGHT_CONV_KERNEL, SUPPORT, PS_CORR_IMAGE, CF_BASELINE_MAP, CF_CHAN_MAP,
        CF_POL_MAP and PHASE_GRADIENT data variables and 'cell_uv' / 'oversampling' attributes, passed through _store.
    """
    print('######################### Start make_gridding_convolution_function #########################')
    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_parms, _check_gcf_parms
    from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel_2D, _create_prolate_spheroidal_image_2D
    from ._imaging_utils._remove_padding import _remove_padding
    import numpy as np
    import dask.array as da
    import copy
    import xarray as xr
    import itertools
    import dask
    import dask.array.fft as dafft

    # Deep-copy so nothing below (including the _check_* calls) mutates the caller's dicts.
    _gcf_parms = copy.deepcopy(gcf_parms)
    _grid_parms = copy.deepcopy(grid_parms)
    _storage_parms = copy.deepcopy(storage_parms)

    # Dataset-derived values read later from _gcf_parms. The misspelled key 'basline_ant' is kept
    # because the helpers that receive _gcf_parms may read it under that name.
    _gcf_parms['basline_ant'] = vis_dataset.antennas.values # n_baseline x 2 (ant pair)
    _gcf_parms['freq_chan'] = vis_dataset.chan.values
    _gcf_parms['pol'] = vis_dataset.pol.values
    _gcf_parms['vis_data_chunks'] = vis_dataset.DATA.chunks
    _gcf_parms['field_phase_dir'] = np.array(global_dataset.FIELD_PHASE_DIR.values[:,:,vis_dataset.attrs['ddi']])

    assert(_check_gcf_parms(_gcf_parms)), "######### ERROR: gcf_parms checking failed"
    assert(_check_grid_parms(_grid_parms)), "######### ERROR: grid_parms checking failed"
    assert(_check_storage_parms(_storage_parms,'dataset.gcf.zarr','make_gcf')), "######### ERROR: user_storage_parms checking failed"
    assert(not _storage_parms['append']), "######### ERROR: storage_parms['append'] = True is not available for make_gridding_convolution_function"

    # Select the primary beam model.
    if _gcf_parms['function'] == 'airy':
        from ._imaging_utils._make_pb_symmetric import _airy_disk_rorder
        pb_func = _airy_disk_rorder
    elif _gcf_parms['function'] == 'alma_airy':
        from ._imaging_utils._make_pb_symmetric import _alma_airy_disk_rorder
        pb_func = _alma_airy_disk_rorder
    else:
        assert(False), "######### ERROR: Only airy and alma_airy function has been implemented"

    #For now only a_term works
    _gcf_parms['a_term'] = True
    _gcf_parms['ps_term'] = False
    # Assumes max_support/oversampling are numpy arrays after _check_gcf_parms (list + 1 would fail) -- TODO confirm.
    _gcf_parms['resize_conv_size'] = (_gcf_parms['max_support'] + 1)*_gcf_parms['oversampling']

    if _gcf_parms['ps_term']:
        ps_term = _create_prolate_spheroidal_kernel_2D(_gcf_parms['oversampling'],np.array([7,7])) #This is only used with a_term == False. Support is hardcoded to 7 until old ps code is replaced by a general function.
        center = _grid_parms['image_center']
        center_embed = np.array(ps_term.shape)//2
        ps_term_padded = np.zeros(_grid_parms['image_size'])
        ps_term_padded[center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        ps_term_padded_ifft = dafft.fftshift(dafft.ifft2(dafft.ifftshift(da.from_array(ps_term_padded))))

        ps_image = da.from_array(_remove_padding(_create_prolate_spheroidal_image_2D(_grid_parms['image_size_padded']),_grid_parms['image_size']),chunks=_grid_parms['image_size'])
        #Effectively no mapping needed if ps_term == True and a_term == False
        cf_baseline_map = np.zeros((len(_gcf_parms['basline_ant']),),dtype=int)
        cf_chan_map = np.zeros((len(_gcf_parms['freq_chan']),),dtype=int)
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int)

    if _gcf_parms['a_term']:
        n_unique_ant = len(_gcf_parms['list_dish_diameters'])
        cf_baseline_map,pb_ant_pairs = create_cf_baseline_map(_gcf_parms['unique_ant_indx'],_gcf_parms['basline_ant'],n_unique_ant)
        cf_chan_map, pb_freq = create_cf_chan_map(_gcf_parms['freq_chan'],_gcf_parms['chan_tolerance_factor'])
        # At most a_chan_num_chunk chunks along frequency; cast so dask receives an integer chunk size.
        pb_freq = da.from_array(pb_freq,chunks=int(np.ceil(len(pb_freq)/_gcf_parms['a_chan_num_chunk'])))
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int) #create_cf_pol_map(), currently treating all pols the same
        pb_pol = da.from_array(np.array([0]),1)
        n_chunks_in_each_dim = [pb_freq.numblocks[0],pb_pol.numblocks[0]]
        iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]))
        chan_chunk_sizes = pb_freq.chunks
        pol_chunk_sizes = pb_pol.chunks
        padded_uv_shape = (_grid_parms['image_size_padded'][0],_grid_parms['image_size_padded'][1])
        list_baseline_pb = []
        list_weight_baseline_pb_sqrd = []
        for c_chan, c_pol in iter_chunks_indx:
            pattern_chunk_shape = (len(pb_ant_pairs),chan_chunk_sizes[0][c_chan],pol_chunk_sizes[0][c_pol]) + padded_uv_shape
            # 'ipower' is changed between the two calls; this relies on dask.delayed capturing the
            # dict's current items when each task is built.
            _gcf_parms['ipower'] = 1
            delayed_baseline_pb = dask.delayed(make_baseline_patterns)(pb_freq.partitions[c_chan],pb_pol.partitions[c_pol],dask.delayed(pb_ant_pairs),dask.delayed(pb_func),dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
            list_baseline_pb.append(da.from_delayed(delayed_baseline_pb,pattern_chunk_shape,dtype=np.double))
            _gcf_parms['ipower'] = 2
            delayed_weight_baseline_pb_sqrd = dask.delayed(make_baseline_patterns)(pb_freq.partitions[c_chan],pb_pol.partitions[c_pol],dask.delayed(pb_ant_pairs),dask.delayed(pb_func),dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
            list_weight_baseline_pb_sqrd.append(da.from_delayed(delayed_weight_baseline_pb_sqrd,pattern_chunk_shape,dtype=np.double))
        baseline_pb = da.concatenate(list_baseline_pb,axis=1)
        weight_baseline_pb_sqrd = da.concatenate(list_weight_baseline_pb_sqrd,axis=1)

    #Combine patterns and fft to obtain the gridding convolutional kernel
    dataset_dict = {}
    # NOTE(review): nothing is ever appended to this list before it is handed to _store; confirm that
    # _store does not need the gcf data variables listed here when storage_parms['to_disk'] is True.
    list_xarray_data_variables = []
    if _gcf_parms['a_term'] and _gcf_parms['ps_term']:
        conv_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(ps_term_padded_ifft*baseline_pb, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_weight_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_kernel, weight_conv_kernel, conv_support = _resize_conv_kernels_per_chunk(
            conv_kernel, conv_weight_kernel, n_chunks_in_each_dim, chan_chunk_sizes, pol_chunk_sizes,
            len(pb_ant_pairs), _gcf_parms, _grid_parms)
        dataset_dict['SUPPORT'] = xr.DataArray(conv_support, dims=['conv_baseline','conv_chan','conv_pol','xy'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l','m'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(weight_conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
    elif _gcf_parms['ps_term']:
        # ps_term only: a single fixed-support kernel shared by all baselines, channels and polarizations.
        support = np.array([7,7])
        dataset_dict['SUPPORT'] = xr.DataArray(support[None,None,None,:], dims=['conv_baseline','conv_chan','conv_pol','xy'])
        conv_kernel = np.zeros((1,1,1,_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]))
        center = _gcf_parms['resize_conv_size']//2
        center_embed = np.array(ps_term.shape)//2
        conv_kernel[0,0,0,center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l','m'])
    elif _gcf_parms['a_term']:
        conv_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(baseline_pb, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_weight_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_kernel, weight_conv_kernel, conv_support = _resize_conv_kernels_per_chunk(
            conv_kernel, conv_weight_kernel, n_chunks_in_each_dim, chan_chunk_sizes, pol_chunk_sizes,
            len(pb_ant_pairs), _gcf_parms, _grid_parms)
        dataset_dict['SUPPORT'] = xr.DataArray(conv_support, dims=['conv_baseline','conv_chan','conv_pol','xy'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(weight_conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(da.from_array(np.ones(_grid_parms['image_size']),chunks=_grid_parms['image_size']), dims=['l','m'])
    else:
        assert(False), "######### ERROR: At least 'a_term' or 'ps_term' must be true."

    #Make phase gradient (one for each field)
    field_phase_dir = _gcf_parms['field_phase_dir']
    field_phase_dir = da.from_array(field_phase_dir,chunks=(int(np.ceil(len(field_phase_dir)/_gcf_parms['a_chan_num_chunk'])),2))
    phase_gradient = da.blockwise(make_phase_gradient, ("n_field","n_x","n_y"), field_phase_dir, ("n_field","2"), gcf_parms=_gcf_parms, grid_parms=_grid_parms, dtype=complex,  new_axes={"n_x": _gcf_parms['resize_conv_size'][0], "n_y": _gcf_parms['resize_conv_size'][1]})

    coords = { 'u': np.arange(_gcf_parms['resize_conv_size'][0]), 'v': np.arange(_gcf_parms['resize_conv_size'][1]), 'xy':np.arange(2), 'field':np.arange(field_phase_dir.shape[0]),'l':np.arange(_grid_parms['image_size'][0]),'m':np.arange(_grid_parms['image_size'][1])}
    # The maps are chunked like the baseline / chan / pol axes of the visibility DATA variable.
    dataset_dict['CF_BASELINE_MAP'] = xr.DataArray(cf_baseline_map, dims=('baseline')).chunk(_gcf_parms['vis_data_chunks'][1])
    dataset_dict['CF_CHAN_MAP'] = xr.DataArray(cf_chan_map, dims=('chan')).chunk(_gcf_parms['vis_data_chunks'][2])
    dataset_dict['CF_POL_MAP'] = xr.DataArray(cf_pol_map, dims=('pol')).chunk(_gcf_parms['vis_data_chunks'][3])
    dataset_dict['CONV_KERNEL'] = xr.DataArray(conv_kernel, dims=('conv_baseline','conv_chan','conv_pol','u','v'))
    dataset_dict['PHASE_GRADIENT'] = xr.DataArray(phase_gradient, dims=('field','u','v'))
    gcf_dataset = xr.Dataset(dataset_dict, coords=coords)
    gcf_dataset.attrs['cell_uv'] =1/(_grid_parms['image_size_padded']*_grid_parms['cell_size']*_gcf_parms['oversampling'])
    gcf_dataset.attrs['oversampling'] = _gcf_parms['oversampling']
    return _store(gcf_dataset,list_xarray_data_variables,_storage_parms)


def _resize_conv_kernels_per_chunk(conv_kernel, conv_weight_kernel, n_chunks_in_each_dim, chan_chunk_sizes, pol_chunk_sizes, n_baseline_patterns, gcf_parms, grid_parms):
    """Apply resize_and_calc_support to every (chan, pol) chunk of the kernels and reassemble the pieces.

    Returns (conv_kernel, weight_conv_kernel, conv_support) as dask arrays concatenated along axis 1
    (the channel axis); the kernels have trailing shape gcf_parms['resize_conv_size'] and the support
    has a trailing axis of length 2.
    """
    import itertools
    import numpy as np
    import dask
    import dask.array as da

    resize_conv_size = gcf_parms['resize_conv_size']
    kernel_uv_shape = (resize_conv_size[0], resize_conv_size[1])
    list_conv_kernel = []
    list_weight_conv_kernel = []
    list_conv_support = []
    for c_chan, c_pol in itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1])):
        delayed_kernels_and_support = dask.delayed(resize_and_calc_support)(conv_kernel.partitions[:,c_chan,c_pol,:,:],conv_weight_kernel.partitions[:,c_chan,c_pol,:,:],dask.delayed(gcf_parms),dask.delayed(grid_parms))
        chunk_shape = (n_baseline_patterns, chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol])
        list_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[0], chunk_shape + kernel_uv_shape, dtype=np.double))
        list_weight_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[1], chunk_shape + kernel_uv_shape, dtype=np.double))
        # Builtin int instead of the np.int alias, which was removed in NumPy 1.24.
        list_conv_support.append(da.from_delayed(delayed_kernels_and_support[2], chunk_shape + (2,), dtype=int))
    return (da.concatenate(list_conv_kernel, axis=1),
            da.concatenate(list_weight_conv_kernel, axis=1),
            da.concatenate(list_conv_support, axis=1))
Example #3
def phase_rotate_sgraph(vis_dataset, global_dataset, rotation_parms, sel_parms, storage_parms):
    """
    Rotate uvw with faceting style rephasing for multifield mosaic.
    The specified phasecenter and field phase centers are assumed to be in the same frame.
    This does not support east-west arrays, ephemeris objects or objects within the nearfield.
    (no refocus).

    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset. Must carry attrs['ddi'], a 'chan' coordinate and the sel_parms['data_in'] /
        sel_parms['uvw_in'] data variables. The rotated data variables are added to it in place.
    global_dataset : xarray.core.dataset.Dataset
        Global dataset; its 'field' dimension and FIELD_PHASE_DIR[:, :, vis_dataset.attrs['ddi']] give the field phase centers.
    rotation_parms : dictionary
    rotation_parms['image_phase_center'] : list of number, length = 2, units = radians
        The (ra, dec) phase center to rotate to.
    rotation_parms['common_tangent_reprojection'] : bool
        If True, elements [2, 0:2] of each per-field uvw rotation matrix are zeroed
        (common tangent rotation for joint mosaics, see FTMachine::girarUVW in CASA).
    sel_parms : dictionary
    sel_parms['uvw_in'] : str, default = 'UVW' (as passed to _check_sel_parms)
        Name of the input uvw data variable.
    sel_parms['uvw_out'] : str, default = 'UVW_ROT'
        Name of the output uvw data variable; must differ from sel_parms['uvw_in'].
    sel_parms['data_in'] : str, default = 'DATA'
        Name of the input visibility data variable.
    sel_parms['data_out'] : str, default = 'DATA_ROT'
        Name of the output visibility data variable; must differ from sel_parms['data_in'].
    storage_parms : dictionary
        Storage options, validated by _check_storage_parms and forwarded to _store.

    Returns
    -------
    vis_dataset : xarray.core.dataset.Dataset
        vis_dataset with the sel_parms['data_out'] and sel_parms['uvw_out'] data variables added, as returned by _store.
    """
    #based on UVWMachine and FTMachine
    #Important: Can not applyflags before calling rotate (uvw coordinates are also flagged). This will destroy the rotation transform.
    #Performance improvements apply_rotation_matrix (jit code)
    from ngcasa._ngcasa_utils._store import _store
    from scipy.spatial.transform import Rotation as R
    import numpy as np
    import copy
    import dask.array as da
    import xarray as xr
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms, _check_sel_parms, _check_existence_sel_parms
    from ._imaging_utils._check_imaging_parms import _check_rotation_parms
    import dask
    import itertools

    # Work on validated copies; everything below must read from these, not from the caller's dicts,
    # because the defaults are only applied by the _check_* calls.
    _sel_parms = copy.deepcopy(sel_parms)
    _rotation_parms = copy.deepcopy(rotation_parms)
    _storage_parms = copy.deepcopy(storage_parms)
    assert(_check_sel_parms(_sel_parms,{'uvw_in':'UVW','uvw_out':'UVW_ROT','data_in':'DATA','data_out':'DATA_ROT'})), "######### ERROR: sel_parms checking failed"
    assert(_check_existence_sel_parms(vis_dataset,{'uvw_in':_sel_parms['uvw_in'],'data_in':_sel_parms['data_in']})), "######### ERROR: sel_parms checking failed"
    assert(_check_rotation_parms(_rotation_parms)), "######### ERROR: rotation_parms checking failed"
    assert(_check_storage_parms(_storage_parms,'dataset.vis.zarr','phase_rotate')), "######### ERROR: storage_parms checking failed"
    assert(_sel_parms['uvw_out'] != _sel_parms['uvw_in']), "######### ERROR: sel_parms checking failed sel_parms['uvw_out'] can not be the same as sel_parms['uvw_in']."
    assert(_sel_parms['data_out'] != _sel_parms['data_in']), "######### ERROR: sel_parms checking failed sel_parms['data_out'] can not be the same as sel_parms['data_in']."

    #Phase center
    ra_image = _rotation_parms['image_phase_center'][0]
    dec_image = _rotation_parms['image_phase_center'][1]
    rotmat_image_phase_center = R.from_euler('XZ',[[np.pi/2 - dec_image, - ra_image + np.pi/2]]).as_matrix()[0]
    image_phase_center_cosine = _directional_cosine([ra_image,dec_image])

    n_fields = global_dataset.dims['field']
    uvw_rotmat = np.zeros((n_fields,3,3),np.double)
    phase_rotation = np.zeros((n_fields,3),np.double)
    # NOTE(review): assumes the last axis of FIELD_PHASE_DIR is the ddi -- confirm.
    fields_phase_center = global_dataset.FIELD_PHASE_DIR.values[:,:,vis_dataset.attrs['ddi']]

    #Create a rotation matrix for each field
    for i_field in range(n_fields):
        field_phase_center = fields_phase_center[i_field,:]
        # Define rotation to a coordinate system with pole towards in-direction
        # and X-axis W; by rotating around z-axis over -(90-long); and around
        # x-axis (lat-90).
        rotmat_field_phase_center = R.from_euler('ZX',[[-np.pi/2 + field_phase_center[0],field_phase_center[1] - np.pi/2]]).as_matrix()[0]
        uvw_rotmat[i_field,:,:] = np.matmul(rotmat_image_phase_center,rotmat_field_phase_center).T
        if _rotation_parms['common_tangent_reprojection'] == True:
            uvw_rotmat[i_field,2,0:2] = 0.0 # (Common tangent rotation needed for joint mosaics, see last part of FTMachine::girarUVW in CASA)
        field_phase_center_cosine = _directional_cosine(field_phase_center)
        phase_rotation[i_field,:] = np.matmul(rotmat_image_phase_center,(image_phase_center_cosine - field_phase_center_cosine))

    data_in = vis_dataset[_sel_parms['data_in']]
    uvw_in = vis_dataset[_sel_parms['uvw_in']]
    chunk_sizes = data_in.chunks
    # NOTE(review): freq_chan and uvw_rotmat are computed but never passed to apply_phasor below, so the
    # per-field uvw rotation and the channel frequencies do not reach it. Check apply_phasor's signature;
    # arguments appear to be missing from this call.
    freq_chan = da.from_array(vis_dataset.coords['chan'].values, chunks=(chunk_sizes[2][0]))
    n_chunks_in_each_dim = data_in.data.numblocks
    iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]),
                                         np.arange(n_chunks_in_each_dim[2]), np.arange(n_chunks_in_each_dim[3]))
    list_of_vis_data = ndim_list(n_chunks_in_each_dim)
    list_of_uvw = ndim_list(n_chunks_in_each_dim[0:2]+(1,))
    # Loop-invariant: wrap the per-field phase rotation and the flag once, not once per chunk.
    delayed_phase_rotation = dask.delayed(phase_rotation)
    delayed_common_tangent = dask.delayed(_rotation_parms['common_tangent_reprojection'])
    for c_time, c_baseline, c_chan, c_pol in iter_chunks_indx:
        vis_data_and_uvw = dask.delayed(apply_phasor)(
            data_in.data.partitions[c_time, c_baseline, c_chan, c_pol],
            uvw_in.data.partitions[c_time, c_baseline, 0],
            delayed_phase_rotation, delayed_common_tangent)
        list_of_vis_data[c_time][c_baseline][c_chan][c_pol] = da.from_delayed(vis_data_and_uvw[0], (chunk_sizes[0][c_time], chunk_sizes[1][c_baseline], chunk_sizes[2][c_chan], chunk_sizes[3][c_pol]),dtype=np.complex128)
        list_of_uvw[c_time][c_baseline][0] = da.from_delayed(vis_data_and_uvw[1],(chunk_sizes[0][c_time], chunk_sizes[1][c_baseline], 3),dtype=np.float64)

    # Adds the rotated data variables to the caller's dataset (in-place modification).
    vis_dataset[_sel_parms['data_out']] = xr.DataArray(da.block(list_of_vis_data), dims=data_in.dims)
    vis_dataset[_sel_parms['uvw_out']] = xr.DataArray(da.block(list_of_uvw), dims=uvw_in.dims)
    list_xarray_data_variables = [vis_dataset[_sel_parms['uvw_out']],vis_dataset[_sel_parms['data_out']]]
    return _store(vis_dataset,list_xarray_data_variables,_storage_parms)
Example #4
def make_image_with_gcf(vis_dataset, gcf_dataset, img_dataset, grid_parms,
                        norm_parms, sel_parms, storage_parms):
    """
    Creates a cube or continuum dirty image from the user specified visibility, uvw and imaging weight data. A gridding convolution function (gcf_dataset), primary beam image (img_dataset) and a primary beam weight image (img_dataset) must be supplied.

    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset.
    gcf_dataset : xarray.core.dataset.Dataset
        Input gridding convolution dataset.
    img_dataset : xarray.core.dataset.Dataset
        Input image dataset.
    grid_parms : dictionary
    grid_parms['image_size'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell_size']  : list of number, length = 2, units = arcseconds
        The image cell size.
    grid_parms['chan_mode'] : {'continuum'/'cube'}, default = 'continuum'
        Create a continuum or cube image.
    grid_parms['fft_padding'] : number, acceptable range [1,100], default = 1.2
        The factor that determines how much the gridded visibilities are padded before the fft is done.
    norm_parms : dictionary
    norm_parms['norm_type'] : {'none'/'flat_noise'/'flat_sky'}, default = 'flat_sky'
        Gridded (and FT'd) images represent the PB-weighted sky image.
        Qualitatively it can be approximated as two instances of the PB
        applied to the sky image (one naturally present in the data
        and one introduced during gridding via the convolution functions).
        norm_type='flat_noise' : Divide the raw image by sqrt(sel_parms['weight_pb']) so that
                                 the input to the minor cycle represents the
                                 product of the sky and PB. The noise is 'flat'
                                 across the region covered by each PB.
        norm_type='flat_sky' : Divide the raw image by sel_parms['weight_pb'] so that the input
                               to the minor cycle represents only the sky.
                               The noise is higher in the outer regions of the
                               primary beam where the sensitivity is low.
        norm_type='none' : No normalization after gridding and FFT.
    sel_parms : dictionary
    sel_parms['uvw'] : str, default ='UVW'
        The name of uvw data variable that will be used to grid the visibilities.
    sel_parms['data'] : str, default = 'DATA'
        The name of the visibility data to be gridded.
    sel_parms['imaging_weight'] : str, default ='IMAGING_WEIGHT'
        The name of the imaging weights to be used.
    sel_parms['image'] : str, default ='IMAGE'
        The created image name.
    sel_parms['sum_weight'] : str, default ='SUM_WEIGHT'
        The created sum of weights name.
    sel_parms['pb'] : str, default ='PB'
        The primary beam image to use for normalization.
    sel_parms['weight_pb'] : str, default ='WEIGHT_PB'
        The primary beam weight image to use for normalization.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        If storage_parms['to_disk'] is True only the dask graph associated with the function is executed and the resulting data variables are saved to an existing zarr file on disk.
        Note that graphs on unrelated data to this function will not be executed or saved.
    storage_parms['outfile'] : str
        The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc,default=Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.

    Returns
    -------
    image_dataset : xarray.core.dataset.Dataset
        The image_dataset will contain the image created and the sum of weights.

    Raises
    ------
    AssertionError
        If any of the parameter dictionaries fail validation.
    ValueError
        If grid_parms['chan_mode'] is neither 'continuum' nor 'cube'.
    """
    print('######################### Start make_image_with_gcf #########################')
    import copy

    import numpy as np
    import dask.array as da
    import dask.array.fft as dafft
    import xarray as xr

    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms, _check_sel_parms, _check_existence_sel_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_parms, _check_norm_parms
    from ._imaging_utils._remove_padding import _remove_padding
    from ._imaging_utils._aperture_grid import _graph_aperture_grid
    from ._imaging_utils._normalize import _normalize

    # Work on copies: the _check_* helpers fill in defaults and keys are added
    # below, so the caller's dictionaries must not be touched.
    _grid_parms = copy.deepcopy(grid_parms)
    _storage_parms = copy.deepcopy(storage_parms)
    _sel_parms = copy.deepcopy(sel_parms)
    _norm_parms = copy.deepcopy(norm_parms)

    assert (_check_sel_parms(
        _sel_parms, {
            'uvw': 'UVW',
            'data': 'DATA',
            'imaging_weight': 'IMAGING_WEIGHT',
            'sum_weight': 'SUM_WEIGHT',
            'image': 'IMAGE',
            'pb': 'PB',
            'weight_pb': 'WEIGHT_PB'
        })), "######### ERROR: sel_parms checking failed"
    assert (_check_existence_sel_parms(
        vis_dataset, {
            'uvw': _sel_parms['uvw'],
            'data': _sel_parms['data'],
            'imaging_weight': _sel_parms['imaging_weight']
        })), "######### ERROR: sel_parms checking failed"
    # The primary beam images are needed by the normalization step below.
    assert (_check_existence_sel_parms(img_dataset, {
        'pb': _sel_parms['pb'],
        'weight_pb': _sel_parms['weight_pb']
    })), "######### ERROR: sel_parms checking failed"
    assert (_check_grid_parms(_grid_parms)
            ), "######### ERROR: grid_parms checking failed"
    assert (_check_norm_parms(_norm_parms)
            ), "######### ERROR: norm_parms checking failed"
    assert (_check_storage_parms(
        _storage_parms, 'dirty_image.img.zarr',
        'make_image')), "######### ERROR: storage_parms checking failed"

    # Grid the visibilities (not the weights) with the supplied aperture GCF.
    _grid_parms['grid_weights'] = False
    _grid_parms['do_psf'] = False
    _grid_parms['oversampling'] = np.array(gcf_dataset.oversampling)

    grids_and_sum_weights = _graph_aperture_grid(vis_dataset, gcf_dataset,
                                                 _grid_parms, _sel_parms)
    uncorrected_dirty_image = dafft.fftshift(
        dafft.ifft2(dafft.ifftshift(grids_and_sum_weights[0], axes=(0, 1)),
                    axes=(0, 1)),
        axes=(0, 1))

    # Remove padding. The inverse FFT scales by 1/(number of padded pixels),
    # so multiply that factor back in. image_size_padded is expected to be
    # filled in by _check_grid_parms (TODO confirm).
    uncorrected_dirty_image = _remove_padding(
        uncorrected_dirty_image, _grid_parms['image_size']).real * (
            _grid_parms['image_size_padded'][0] *
            _grid_parms['image_size_padded'][1])

    normalized_image = _normalize(uncorrected_dirty_image,
                                  grids_and_sum_weights[1], img_dataset,
                                  gcf_dataset, 'forward', _norm_parms,
                                  _sel_parms)

    data_variable = vis_dataset[_sel_parms['data']]
    if _grid_parms['chan_mode'] == 'continuum':
        # A single output channel at the mean frequency / mean channel width.
        freq_coords = [da.mean(vis_dataset.coords['chan'].values)]
        chan_width = da.from_array([da.mean(vis_dataset['chan_width'].data)],
                                   chunks=(1, ))
    elif _grid_parms['chan_mode'] == 'cube':
        freq_coords = vis_dataset.coords['chan'].values
        chan_width = vis_dataset['chan_width'].data
    else:
        raise ValueError("grid_parms['chan_mode'] must be 'continuum' or 'cube', got "
                         + repr(_grid_parms['chan_mode']))

    ###Create Image Dataset
    # NOTE(review): this is the size of the first pol chunk of the data; it
    # equals the number of polarizations only when pol is a single chunk.
    n_imag_pol = data_variable.chunks[3][0]

    coords = {
        'd0': np.arange(_grid_parms['image_size'][0]),
        'd1': np.arange(_grid_parms['image_size'][1]),
        'chan': freq_coords,
        'pol': np.arange(n_imag_pol),
        'chan_width': ('chan', chan_width),
    }
    img_dataset = img_dataset.assign_coords(coords)
    img_dataset[_sel_parms['sum_weight']] = xr.DataArray(
        grids_and_sum_weights[1], dims=['chan', 'pol'])
    img_dataset[_sel_parms['image']] = xr.DataArray(
        normalized_image, dims=['d0', 'd1', 'chan', 'pol'])

    list_xarray_data_variables = [
        img_dataset[_sel_parms['image']], img_dataset[_sel_parms['sum_weight']]
    ]
    return _store(img_dataset, list_xarray_data_variables, _storage_parms)
Example #5
def make_image(vis_dataset, img_dataset, grid_parms, sel_parms, storage_parms):
    """
    Creates a cube or continuum dirty image from the user specified visibility, uvw and imaging weight data. Only the prolate spheroidal convolutional gridding function is supported. See make_image_with_gcf function for creating an image with A-projection.

    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset.
    img_dataset : xarray.core.dataset.Dataset
        Image dataset that the created image and sum of weights are added to.
    grid_parms : dictionary
    grid_parms['image_size'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell_size']  : list of number, length = 2, units = arcseconds
        The image cell size.
    grid_parms['chan_mode'] : {'continuum'/'cube'}, default = 'continuum'
        Create a continuum or cube image.
    grid_parms['fft_padding'] : number, acceptable range [1,100], default = 1.2
        The factor that determines how much the gridded visibilities are padded before the fft is done.
    sel_parms : dictionary
    sel_parms['uvw'] : str, default ='UVW'
        The name of uvw data variable that will be used to grid the visibilities.
    sel_parms['data'] : str, default = 'DATA'
        The name of the visibility data to be gridded.
    sel_parms['imaging_weight'] : str, default ='IMAGING_WEIGHT'
        The name of the imaging weights to be used.
    sel_parms['image'] : str, default ='IMAGE'
        The created image name.
    sel_parms['sum_weight'] : str, default ='SUM_WEIGHT'
        The created sum of weights name.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        If storage_parms['to_disk'] is True only the dask graph associated with the function is executed and the resulting data variables are saved to an existing zarr file on disk.
        Note that graphs on unrelated data to this function will not be executed or saved.
    storage_parms['outfile'] : str
        The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc,default=Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.

    Returns
    -------
    image_dataset : xarray.core.dataset.Dataset
        The image_dataset will contain the image created and the sum of weights.

    Raises
    ------
    AssertionError
        If any of the parameter dictionaries fail validation.
    ValueError
        If grid_parms['chan_mode'] is neither 'continuum' nor 'cube'.
    """
    print('######################### Start make_image #########################')
    import copy

    import numpy as np
    import dask.array as da
    import dask.array.fft as dafft
    import xarray as xr

    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms, _check_sel_parms, _check_existence_sel_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_parms
    from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel, _create_prolate_spheroidal_kernel_1D
    from ._imaging_utils._standard_grid import _graph_standard_grid
    from ._imaging_utils._remove_padding import _remove_padding

    # Work on copies so the caller's dictionaries are never mutated.
    _grid_parms = copy.deepcopy(grid_parms)
    _storage_parms = copy.deepcopy(storage_parms)
    _sel_parms = copy.deepcopy(sel_parms)
    assert(_check_sel_parms(_sel_parms,{'uvw':'UVW','data':'DATA','imaging_weight':'IMAGING_WEIGHT','sum_weight':'SUM_WEIGHT','image':'IMAGE','pb':'PB','weight_pb':'WEIGHT_PB'})), "######### ERROR: sel_parms checking failed"
    assert(_check_existence_sel_parms(vis_dataset,{'uvw':_sel_parms['uvw'],'data':_sel_parms['data'],'imaging_weight':_sel_parms['imaging_weight']})), "######### ERROR: sel_parms checking failed"
    assert(_check_grid_parms(_grid_parms)), "######### ERROR: grid_parms checking failed"
    assert(_check_storage_parms(_storage_parms,'dirty_image.img.zarr','make_image')), "######### ERROR: storage_parms checking failed"

    # Prolate spheroidal gridding kernel with fixed oversampling and support
    # (any user-supplied values for these two keys are overridden).
    _grid_parms['oversampling'] = 100
    _grid_parms['support'] = 7
    _, correcting_cgk_image = _create_prolate_spheroidal_kernel(_grid_parms['oversampling'], _grid_parms['support'], _grid_parms['image_size_padded'])
    cgk_1D = _create_prolate_spheroidal_kernel_1D(_grid_parms['oversampling'], _grid_parms['support'])

    _grid_parms['complex_grid'] = True
    _grid_parms['do_psf'] = False
    grids_and_sum_weights = _graph_standard_grid(vis_dataset, cgk_1D, _grid_parms, _sel_parms)
    uncorrected_dirty_image = dafft.fftshift(dafft.ifft2(dafft.ifftshift(grids_and_sum_weights[0], axes=(0, 1)), axes=(0, 1)), axes=(0, 1))

    # Remove padding. The inverse FFT scales by 1/(number of padded pixels),
    # so multiply that factor back in.
    correcting_cgk_image = _remove_padding(correcting_cgk_image,_grid_parms['image_size'])
    uncorrected_dirty_image = _remove_padding(uncorrected_dirty_image,_grid_parms['image_size']).real * (_grid_parms['image_size_padded'][0] * _grid_parms['image_size_padded'][1])

    def correct_image(uncorrected_dirty_image, sum_weights, correcting_cgk):
        """Divide by the sum of weights (zeros treated as 1) and by the kernel correction image."""
        # np.where builds a new array, so the dask input block is never
        # mutated (https://docs.dask.org/en/latest/delayed-best-practices.html).
        safe_sum_weights = np.where(sum_weights == 0, 1, sum_weights)
        return (uncorrected_dirty_image / safe_sum_weights) / correcting_cgk

    # Pre-broadcast both operands to 4-D (d0, d1, chan, pol) so map_blocks
    # aligns their blocks with the image on the intended axes.
    corrected_dirty_image = da.map_blocks(correct_image, uncorrected_dirty_image, grids_and_sum_weights[1][None, None, :, :],correcting_cgk_image[:, :, None, None])

    data_variable = vis_dataset[_sel_parms['data']]
    if _grid_parms['chan_mode'] == 'continuum':
        # A single output channel at the mean frequency / mean channel width.
        freq_coords = [da.mean(vis_dataset.coords['chan'].values)]
        chan_width = da.from_array([da.mean(vis_dataset['chan_width'].data)],chunks=(1,))
    elif _grid_parms['chan_mode'] == 'cube':
        freq_coords = vis_dataset.coords['chan'].values
        chan_width = vis_dataset['chan_width'].data
    else:
        raise ValueError("grid_parms['chan_mode'] must be 'continuum' or 'cube', got "
                         + repr(_grid_parms['chan_mode']))

    ###Create Image Dataset
    # NOTE(review): this is the size of the first pol chunk of the data; it
    # equals the number of polarizations only when pol is a single chunk.
    n_imag_pol = data_variable.chunks[3][0]
    coords = {'d0': np.arange(_grid_parms['image_size'][0]), 'd1': np.arange(_grid_parms['image_size'][1]),
              'chan': freq_coords, 'pol': np.arange(n_imag_pol), 'chan_width' : ('chan',chan_width)}
    img_dataset = img_dataset.assign_coords(coords)
    img_dataset[_sel_parms['sum_weight']] = xr.DataArray(grids_and_sum_weights[1], dims=['chan','pol'])
    img_dataset[_sel_parms['image']] = xr.DataArray(corrected_dirty_image, dims=['d0', 'd1', 'chan', 'pol'])
    list_xarray_data_variables = [img_dataset[_sel_parms['image']],img_dataset[_sel_parms['sum_weight']]]
    return _store(img_dataset,list_xarray_data_variables,_storage_parms)
def mosaic_rotate_uvw(vis_dataset, global_dataset, user_rotation_parms,
    # NOTE(review): this copy of the signature looks truncated -- the last line
    # of the body reads `storage_parms`, which is not among the visible
    # parameters. Confirm the full parameter list against the original source.
    ********* Experimental Function *************
    Rotate uvw with faceting style rephasing for multifield mosaic.
    The specified phasecenter and field phase centers are assumed to be in the same frame.
    This does not support east-west arrays, emphemeris objects or objects within the nearfield.
    (no refocus).
    vis_dataset : xarray.core.dataset.Dataset
        input Visibility Dataset
    psf_dataset : xarray.core.dataset.Dataset
    #based on UVWMachine and FTMachine
    # NOTE(review): `psf_dataset` above is not a parameter of this function;
    # that entry appears to be copied from another function's documentation.

    from ngcasa._ngcasa_utils._store import _store
    from scipy.spatial.transform import Rotation as R
    import numpy as np
    import copy
    import dask.array as da
    # NOTE(review): `xr` is used near the end of this function but is not
    # imported here, and `_directional_cosine` is not defined here either;
    # presumably both are expected at module scope -- verify.

    # Copy so the caller's dictionary is not mutated.
    rotation_parms = copy.deepcopy(user_rotation_parms)

    # if append true then rotation_parms['uvw_out_name'] != rotation_parms['uvw_in_name']

    #I think this should be included in vis_dataset. There should also be a better pythonic way to do the loop inside apply_rotation_matrix.
    # Despite its name, this first helper maps each entry of
    # `vis_data_field_names` to the position of that name in `field_names`
    # (entries matching no name stay 0). The name is rebound further down by a
    # second `apply_rotation_matrix` that performs the actual uvw rotation.
    def apply_rotation_matrix(vis_data_field_names, field_names):
        # NOTE(review): the `np.int` alias was removed in NumPy 1.24; this line
        # raises AttributeError on current NumPy (use `int` or `np.int64`).
        field_indx = np.zeros(vis_data_field_names.shape, np.int)
        for i_field, field_name in enumerate(field_names):
            field_indx[vis_data_field_names == field_name] = i_field

        return field_indx

    field_indx = da.map_blocks(apply_rotation_matrix,
    # NOTE(review): the remaining arguments of this map_blocks call are missing
    # in this copy of the function.

    #Create Rotation Matrices
    # `image_phase_center` is read via `.ra.radian` / `.dec.radian`; presumably
    # an astropy SkyCoord-like object -- TODO confirm.
    ra_image = rotation_parms['image_phase_center'].ra.radian
    dec_image = rotation_parms['image_phase_center'].dec.radian
    rotmat_image_phase_center = R.from_euler(
        'XZ', [[np.pi / 2 - dec_image, -ra_image + np.pi / 2]]).as_matrix()[0]
    image_phase_center_cosine = _directional_cosine([ra_image, dec_image])

    # One 3x3 uvw rotation matrix and one 3-vector phase rotation per field.
    n_fields = global_dataset.dims['field']
    uvw_rotmat = np.zeros((n_fields, 3, 3), np.double)
    phase_rotation = np.zeros((n_fields, 3), np.double)

    for i_field in range(n_fields):
        #Not sure if last dimention in FIELD_PHASE_DIR is the ddi number
        field_phase_center = global_dataset.FIELD_PHASE_DIR.values[
            i_field, :, vis_dataset.attrs['ddi']]
        # Define rotation to a coordinate system with pole towards in-direction
        # and X-axis W; by rotating around z-axis over -(90-long); and around
        # x-axis (lat-90).
        # NOTE(review): lines are missing from the from_euler call below and
        # from the uvw_rotmat assignment that follows it; as written the
        # brackets do not balance.
        rotmat_field_phase_center = R.from_euler('ZX', [[
            -np.pi / 2 + field_phase_center[0],
            field_phase_center[1] - np.pi / 2
        uvw_rotmat[i_field, :, :] = np.matmul(rotmat_image_phase_center,
            i_field, 2, 0:
            2] = 0.0  #Not sure why this should be done (see last part of FTMachine::girarUVW in CASA)

        field_phase_center_cosine = _directional_cosine(field_phase_center)
        # NOTE(review): np.matmul needs two operands; only one is visible here.
        phase_rotation[i_field, :] = np.matmul(
            (image_phase_center_cosine - field_phase_center_cosine))

    # Rotates each time row of `uvw` by the matrix of the field selected by
    # field_indx[i_time] (assumes one field index per time row -- TODO confirm).
    # NOTE(review): writes into the `uvw` block in place; dask best practices
    # advise against mutating inputs inside map_blocks. The print is debug output.
    def apply_rotation_matrix(uvw, field_indx, uvw_rotmat):
        print(uvw.shape, field_indx.shape, uvw_rotmat.shape)
        for i_time in range(uvw.shape[0]):
            #uvw[i_time,:,0:2] = -uvw[i_time,:,0:2] this gives the same result as casa (in the ftmachines uvw(negateUV(vb)) is used). In ngcasa we don't do this since the uvw definition in the gridder and vis.zarr are the same.
            uvw[i_time, :, :] = np.matmul(uvw[i_time, :, :],
                                          uvw_rotmat[field_indx[i_time], :, :])
        return uvw

    uvw = da.map_blocks(apply_rotation_matrix,
    # NOTE(review): the remaining arguments of this map_blocks call are missing.

    # For every (time, baseline) row, dot the (u, v) components with the first
    # two components of that row's field phase-rotation vector; the result
    # has a trailing length-1 axis.
    def calc_uv_phase_direction(uvw, field_indx, phase_rotation):
        phase_direction = np.zeros(uvw.shape[0:2] + (1, ), np.double)
        for i_time in range(uvw.shape[0]):
            phase_direction[i_time, :, 0] = np.matmul(
                uvw[i_time, :, 0:2], phase_rotation[field_indx[i_time], 0:2])

        return phase_direction

    # NOTE(review): the function/array arguments of this map_blocks call are
    # missing, and 'UVW' is hard-coded rather than rotation_parms['uvw_in_name'].
    phase_direction = da.map_blocks(
        chunks=vis_dataset['UVW'].data.chunksize[0:2] + (1, ))[:, :, 0]

    vis_dataset[rotation_parms['uvw_out_name']] = xr.DataArray(
        uvw, dims=vis_dataset[rotation_parms['uvw_in_name']].dims)

    # NOTE(review): the contents of this list are missing in this copy;
    # `phase_direction` is not stored anywhere in the visible code.
    list_xarray_data_variables = [
    return _store(vis_dataset, list_xarray_data_variables, storage_parms)
def make_psf(vis_dataset, grid_parms, storage_parms):
    """
    Creates a cube or continuum point spread function (psf) image from the user specified uvw and imaging weight data. Only the prolate spheroidal convolutional gridding function is supported (this will change in a future release).

    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset.
    grid_parms : dictionary
    grid_parms['imsize'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell']  : list of number, length = 2, units = arcseconds
        The image cell size.
    grid_parms['chan_mode'] : {'continuum'/'cube'}, default = 'continuum'
        Create a continuum or cube image.
    grid_parms['oversampling'] : int, default = 100
        The oversampling used for the convolutional gridding kernel. This will be removed in a later release and incorporated in the function that creates gridding convolutional kernels.
    grid_parms['support'] : int, default = 7
        The full support used for convolutional gridding kernel. This will be removed in a later release and incorporated in the function that creates gridding convolutional kernels.
    grid_parms['fft_padding'] : number, acceptable range [1,100], default = 1.2
        The factor that determines how much the gridded weights are padded before the fft is done.
    grid_parms['uvw_name'] : str, default ='UVW'
        The name of uvw data variable that will be used to grid the imaging weights.
    grid_parms['imaging_weight_name'] : str, default ='IMAGING_WEIGHT'
        The name of the imaging weights to be gridded.
    grid_parms['image_name'] : str, default ='PSF'
        The created image name.
    grid_parms['sum_weight_name'] : str, default ='PSF_SUM_WEIGHT'
        The created sum of weights name.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        If storage_parms['to_disk'] is True only the dask graph associated with the function is executed and the resulting data variables are saved to an existing zarr file on disk.
        Note that graphs on unrelated data to this function will not be executed or saved.
    storage_parms['outfile'] : str
        The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc,default=Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.

    Returns
    -------
    image_dataset : xarray.core.dataset.Dataset
        The image_dataset will contain the image created and the sum of weights.

    Raises
    ------
    AssertionError
        If grid_parms or storage_parms fail validation.
    ValueError
        If grid_parms['chan_mode'] is neither 'continuum' nor 'cube'.
    """
    print('######################### Start make_psf #########################')
    import copy

    import numpy as np
    import dask.array as da
    import dask.array.fft as dafft
    import xarray as xr

    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_params
    from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel, _create_prolate_spheroidal_kernel_1D
    from ._imaging_utils._standard_grid import _graph_standard_grid
    from ._imaging_utils._remove_padding import _remove_padding

    # Work on copies so the caller's dictionaries are never mutated.
    _grid_parms = copy.deepcopy(grid_parms)
    _storage_parms = copy.deepcopy(storage_parms)

    # Grid the imaging weights rather than the visibilities.
    _grid_parms['do_psf'] = True

    assert (_check_grid_params(vis_dataset, _grid_parms)
            ), "######### ERROR: grid_parms checking failed"
    assert (_check_storage_parms(
        _storage_parms, 'psf.img.zarr',
        'make_psf')), "######### ERROR: storage_parms checking failed"

    # Creating gridding kernel (only the image-domain correction of the 2-D
    # kernel is needed; gridding uses the 1-D kernel).
    _, correcting_cgk_image = _create_prolate_spheroidal_kernel(
        _grid_parms['oversampling'], _grid_parms['support'],
        _grid_parms['imsize_padded'])
    cgk_1D = _create_prolate_spheroidal_kernel_1D(_grid_parms['oversampling'],
                                                  _grid_parms['support'])

    grids_and_sum_weights = _graph_standard_grid(vis_dataset, cgk_1D,
                                                 _grid_parms)
    uncorrected_psf_image = dafft.fftshift(
        dafft.ifft2(dafft.ifftshift(grids_and_sum_weights[0], axes=(0, 1)),
                    axes=(0, 1)),
        axes=(0, 1))

    # Remove padding. The inverse FFT scales by 1/(number of padded pixels),
    # so multiply that factor back in.
    correcting_cgk_image = _remove_padding(correcting_cgk_image,
                                           _grid_parms['imsize'])
    uncorrected_psf_image = _remove_padding(
        uncorrected_psf_image, _grid_parms['imsize']).real * (
            _grid_parms['imsize_padded'][0] * _grid_parms['imsize_padded'][1])

    #############Move this to Normalizer#############
    def correct_image(uncorrected_psf_image, sum_weights, correcting_cgk):
        """Divide by the sum of weights (zeros treated as 1) and by the kernel correction image."""
        # np.where builds a new array, so the dask input block is not mutated.
        safe_sum_weights = np.where(sum_weights == 0, 1, sum_weights)
        return (uncorrected_psf_image / safe_sum_weights) / correcting_cgk

    # Pre-broadcast both operands to 4-D (d0, d1, chan, pol). map_blocks aligns
    # lower-dimensional inputs on trailing axes, so passing the 2-D (d0, d1)
    # correction image directly would pair it with (chan, pol).
    corrected_psf_image = da.map_blocks(
        correct_image, uncorrected_psf_image,
        grids_and_sum_weights[1][None, None, :, :],
        correcting_cgk_image[:, :, None, None])

    if _grid_parms['chan_mode'] == 'continuum':
        # A single output channel at the mean frequency / mean channel width.
        freq_coords = [da.mean(vis_dataset.coords['chan'].values)]
        chan_width = da.from_array([da.mean(vis_dataset['chan_width'].data)],
                                   chunks=(1, ))
    elif _grid_parms['chan_mode'] == 'cube':
        freq_coords = vis_dataset.coords['chan'].values
        chan_width = vis_dataset['chan_width'].data
    else:
        raise ValueError("grid_parms['chan_mode'] must be 'continuum' or 'cube', got "
                         + repr(_grid_parms['chan_mode']))

    ###Create PSF Image Dataset
    # NOTE(review): this is the size of the first pol chunk of DATA; it equals
    # the number of polarizations only when pol is a single chunk.
    n_imag_pol = vis_dataset.DATA.chunks[3][0]
    coords = {
        'd0': np.arange(_grid_parms['imsize'][0]),
        'd1': np.arange(_grid_parms['imsize'][1]),
        'chan': freq_coords,
        'pol': np.arange(n_imag_pol),
        'chan_width': ('chan', chan_width),
    }
    image_dict = {
        _grid_parms['sum_weight_name']: xr.DataArray(
            grids_and_sum_weights[1], dims=['chan', 'pol']),
        _grid_parms['image_name']: xr.DataArray(
            corrected_psf_image, dims=['d0', 'd1', 'chan', 'pol']),
    }
    image_dataset = xr.Dataset(image_dict, coords=coords)

    list_xarray_data_variables = [
        image_dataset[_grid_parms['image_name']],
        image_dataset[_grid_parms['sum_weight_name']],
    ]
    return _store(image_dataset, list_xarray_data_variables, _storage_parms)
def make_pb(img_dataset,pb_parms, storage_parms):
    """
    The make_pb function currently supports rotationally symmetric airy disk primary beams. Primary beams can be generated for any number of dishes.
    The pb_parms['list_dish_diameters'] and pb_parms['list_blockage_diameters'] must be specified for each dish.

    Parameters
    ----------
    img_dataset : xarray.core.dataset.Dataset
        Input image dataset.
    pb_parms : dictionary
    pb_parms['function'] : {'airy'}, default='airy'
        Only the airy disk function is currently supported.
    pb_parms['imsize'] : list of int, length = 2
        The image size (no padding).
    pb_parms['cell']  : list of number, length = 2, units = arcseconds
        The image cell size.
    pb_parms['list_dish_diameters'] : list of number
        The list of dish diameters.
    pb_parms['list_blockage_diameters'] : list of number
        The list of blockage diameters for each dish.
    pb_parms['pb_name'] : str, default = 'PB'
        The created PB name.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        If storage_parms['to_disk'] is True only the dask graph associated with the function is executed and the resulting data variables are saved to an existing zarr file on disk.
        Note that graphs on unrelated data to this function will not be executed or saved.
    storage_parms['outfile'] : str
        The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc,default=Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.

    Returns
    -------
    img_xds : xarray.core.dataset.Dataset
        img_dataset with the primary beam variable (dims d0, d1, chan, pol, dish_type) added.

    Raises
    ------
    AssertionError
        If pb_parms or storage_parms fail validation.
    ValueError
        If pb_parms['function'] is not 'airy'.
    """
    import copy

    import numpy as np
    import dask.array as da
    import xarray as xr

    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms
    from ._imaging_utils._check_imaging_parms import _check_pb_parms

    # Work on copies so the caller's dictionaries are never mutated.
    _pb_parms =  copy.deepcopy(pb_parms)
    _storage_parms = copy.deepcopy(storage_parms)
    assert(_check_pb_parms(img_dataset,_pb_parms)), "######### ERROR: pb_parms checking failed"
    assert(_check_storage_parms(_storage_parms,'dataset.img.zarr','make_pb')), "######### ERROR: storage_parms checking failed"

    if _pb_parms['function'] == 'airy':
        from ._imaging_utils._make_pb_1d import _airy_disk
        pb_func = _airy_disk
    else:
        raise ValueError("Only the airy function has been implemented, got pb_parms['function'] = "
                         + repr(_pb_parms['function']))
    _pb_parms['ipower'] = 2
    _pb_parms['center_indx'] = []

    # Chunk the frequencies like the image's chan_width coordinate so each PB
    # block covers the same channels as the image blocks.
    chan_chunk_size = img_dataset.chan_width.chunks[0][0]
    freq_coords = da.from_array(img_dataset.coords['chan'].values, chunks=(chan_chunk_size))
    pol = img_dataset.pol.values  # numpy (unchunked), so every block receives all polarizations

    n_dish_types = len(_pb_parms['list_dish_diameters'])
    chunksize = (_pb_parms['imsize'][0],_pb_parms['imsize'][1]) + freq_coords.chunksize + (len(pol),) + (n_dish_types,)
    # Only chan comes from the 1-D input; d0, d1, pol and dish_type are new axes.
    pb = da.map_blocks(pb_func, freq_coords, pol, _pb_parms, chunks=chunksize ,new_axis=[0,1,3,4], dtype=np.double)

    ## Add PB to img_dataset
    img_dataset[_pb_parms['pb_name']] = xr.DataArray(pb, dims=['d0', 'd1', 'chan', 'pol','dish_type'])
    img_dataset = img_dataset.assign_coords({'dish_type': np.arange(n_dish_types)})
    list_xarray_data_variables = [img_dataset[_pb_parms['pb_name']]]
    return _store(img_dataset,list_xarray_data_variables,_storage_parms)
# Rotate the uvw coordinates of vis_dataset and phase-rotate its visibilities
# toward rotation_parms['image_phase_center']. The results are attached to
# vis_dataset as sel_parms['uvw_out'] and sel_parms['data_out'] and then passed,
# with the storage parameters, to _store, whose result is returned.
# NOTE(review): this copy of the function is damaged. Continuation lines are
# missing throughout: the rest of the signature, the docstring quotes and
# section headers, and the closing parts of several calls and brackets. The
# code is left byte-identical; each gap is flagged below with
# "NOTE(review): missing". Restore it from upstream before use. storage_parms is
# read in the body but is not among the visible parameters, so the lost part of
# the signature presumably declares it.
def phase_rotate(vis_dataset, global_dataset, rotation_parms, sel_parms,
    Rotate uvw coordinates and phase rotate visibilities. For a joint mosaics rotation_parms['common_tangent_reprojection'] must be true.
    The specified phasecenter and field phase centers are assumed to be in the same frame.
    East-west arrays, emphemeris objects or objects within the nearfield are not supported.
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset.
    global_dataset : xarray.core.dataset.Dataset
        Input global dataset.
    rotation_parms : dictionary
    rotation_parms['image_phase_center'] : list of number, length = 2, units = radians
       The phase center to rotate to (right ascension and declination).
    rotation_parms['common_tangent_reprojection']  : bool, default = True
       If true common tangent reprojection is used (should be true if a joint mosaic image is being created).
    rotation_parms['single_precision'] : bool, default = True
       If rotation_parms['single_precision'] is true then the output visibilities are cast from 128 bit complex to 64 bit complex. Mathematical operations are always done in double precision.
    sel_parms : dictionary
    sel_parms['uvw_in'] : str, default = 'UVW'
        The uvw data variable to rotate.
    sel_parms['uvw_out'] : str, default = 'UVW_ROT'
        The output uvw data variable (must not be the same as sel_parms['uvw_in']).
    sel_parms['data_in'] : str, default = 'DATA'
        The visbility data variable to phase rotate.
    sel_parms['data_out'] : str, default = 'DATA_ROT'
        The output visibility data variable (must not be the same as sel_parms['data_in']).
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
       If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
       If storage_parms['to_disk'] is True only the dask graph associated with the function is executed and the resulting data variables are saved to an existing zarr file on disk.
       Note that graphs on unrelated data to this function will not be executed or saved.
    storage_parms['outfile'] : str
       The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
       The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
       The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
       The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc,default=Blosc(cname='zstd', clevel=2, shuffle=0)
       The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.
    psf_dataset : xarray.core.dataset.Dataset
    # Based on CASA's UVWMachine and FTMachine.
    # NOTE(review): missing - the closing docstring quotes. The last docstring
    # line names the return value "psf_dataset", but this function returns
    # _store(vis_dataset, ...) at the end; the label looks like a copy-paste
    # slip - confirm.

    # NOTE(review): this bare string literal is a no-op as written. Presumably
    # the line that opened a print( call around it was lost - confirm upstream.
        '######################### Start phase_rotate #########################'

    from ngcasa._ngcasa_utils._store import _store
    from scipy.spatial.transform import Rotation as R
    import numpy as np
    import copy
    import dask.array as da
    import xarray as xr
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms, _check_sel_parms, _check_existence_sel_parms
    from ._imaging_utils._check_imaging_parms import _check_rotation_parms
    # NOTE(review): time, numba, double and dask are not referenced in the
    # visible body.
    import time
    import numba
    from numba import double
    import dask

    # Work on deep copies so the caller's parameter dicts are left untouched by
    # the checks below.
    _sel_parms = copy.deepcopy(sel_parms)
    _rotation_parms = copy.deepcopy(rotation_parms)
    _storage_parms = copy.deepcopy(storage_parms)

    # Parameter validation. The dict passed to _check_sel_parms matches the
    # defaults documented for sel_parms. NOTE: these are assert statements,
    # which are skipped entirely when Python runs with -O.
    assert (_check_sel_parms(
        _sel_parms, {
            'uvw_in': 'UVW',
            'uvw_out': 'UVW_ROT',
            'data_in': 'DATA',
            'data_out': 'DATA_ROT'
        })), "######### ERROR: sel_parms checking failed"
    assert (_check_existence_sel_parms(vis_dataset, {
        'uvw_in': _sel_parms['uvw_in'],
        'data_in': _sel_parms['data_in']
    })), "######### ERROR: sel_parms checking failed"
    assert (_check_rotation_parms(_rotation_parms)
            ), "######### ERROR: rotation_parms checking failed"
    assert (_check_storage_parms(
        _storage_parms, 'dataset.vis.zarr',
        'phase_rotate')), "######### ERROR: storage_parms checking failed"

    # Outputs must not overwrite their inputs.
    assert (
        _sel_parms['uvw_out'] != _sel_parms['uvw_in']
    ), "######### ERROR: sel_parms checking failed sel_parms['uvw_out'] can not be the same as sel_parms['uvw_in']."
    assert (
        _sel_parms['data_out'] != _sel_parms['data_in']
    ), "######### ERROR: sel_parms checking failed sel_parms['data_out'] can not be the same as sel_parms['data_in']."

    # Image phase center (first element used as RA, second as Dec).
    ra_image = _rotation_parms['image_phase_center'][0]
    dec_image = _rotation_parms['image_phase_center'][1]

    # Rotation built from the image phase center (Euler 'XZ' angles
    # pi/2 - dec and pi/2 - ra). It is reused for every field in the loop below.
    rotmat_image_phase_center = R.from_euler(
        'XZ', [[np.pi / 2 - dec_image, -ra_image + np.pi / 2]]).as_matrix()[0]
    image_phase_center_cosine = _directional_cosine([ra_image, dec_image])

    # Per-field outputs: a 3x3 uvw rotation matrix and a length-3 phase-rotation
    # vector for each field.
    n_fields = global_dataset.dims['field']
    # NOTE(review): field_names is only referenced in a commented-out print
    # inside apply_rotation_matrix.
    field_names = global_dataset.field
    uvw_rotmat = np.zeros((n_fields, 3, 3), np.double)
    phase_rotation = np.zeros((n_fields, 3), np.double)

    fields_phase_center = global_dataset.FIELD_PHASE_DIR.values[:, :,
        # NOTE(review): missing - the index for the last axis of
        # FIELD_PHASE_DIR and the closing bracket.

    #Create a rotation matrix for each field
    for i_field in range(n_fields):
        # Not sure if the last dimension in FIELD_PHASE_DIR is the ddi number
        field_phase_center = fields_phase_center[i_field, :]
        # Define rotation to a coordinate system with pole towards in-direction
        # and X-axis W; by rotating around z-axis over -(90-long); and around
        # x-axis (lat-90).
        rotmat_field_phase_center = R.from_euler('ZX', [[
            -np.pi / 2 + field_phase_center[0],
            field_phase_center[1] - np.pi / 2
            # NOTE(review): missing - the closing "]]" and ")" of this call and,
            # by analogy with rotmat_image_phase_center above, presumably
            # ".as_matrix()[0]"; confirm upstream.
        uvw_rotmat[i_field, :, :] = np.matmul(rotmat_image_phase_center,
                                              # NOTE(review): missing - the second
                                              # operand (presumably
                                              # rotmat_field_phase_center, which is
                                              # otherwise unused) and the closing ")".

        if _rotation_parms['common_tangent_reprojection'] == True:
            # NOTE(review): missing - the start of this subscripted assignment
            # (the array being indexed). Given the CASA reference below, it
            # presumably zeroes [i_field, 2, 0:2] of uvw_rotmat - confirm.
                i_field, 2, 0:
                2] = 0.0  # (Common tangent rotation needed for joint mosaics, see last part of FTMachine::girarUVW in CASA)

        field_phase_center_cosine = _directional_cosine(field_phase_center)
        phase_rotation[i_field, :] = np.matmul(
            # NOTE(review): missing - the first matmul operand; only the
            # direction-cosine difference is visible.
            (image_phase_center_cosine - field_phase_center_cosine))

    #Apply rotation matrix to uvw
    # For each time index i_time, right-multiply the 2-D slab uvw[i_time, :, :]
    # (last axis of size 3, since it is multiplied by a 3x3 matrix) by the
    # matrix of that row's field, uvw_rotmat[field_id[i_time, 0, 0]].
    # The incoming block is modified in place and returned.
    # NOTE(review): this writes into the block dask passes in. Consider copying
    # first if the source uvw array must stay untouched.
    def apply_rotation_matrix(uvw, field_id, uvw_rotmat):
        for i_time in range(uvw.shape[0]):
            #print('The index is',np.where(field_names==field[i_time])[1][0],field_id[i_time,0,0])
            #uvw[i_time,:,0:2] = -uvw[i_time,:,0:2] this gives the same result as casa (in the ftmachines uvw(negateUV(vb)) is used). In ngcasa we don't do this since the uvw definition in the gridder and vis.zarr are the same.
            #uvw[i_time,:,:] = np.matmul(uvw[i_time,:,:],uvw_rotmat[field_id[i_time],:,:])
            uvw[i_time, :, :] = uvw[i_time, :, :] @ uvw_rotmat[field_id[
                i_time, 0, 0], :, :]
        return uvw

    # NOTE(review): missing - apply_rotation_matrix takes uvw as its first
    # argument, but the first visible block argument is field_id. Presumably the
    # uvw input (sel_parms['uvw_in']) was on a lost line. Also missing: the
    # trailing arguments (presumably uvw_rotmat) and the closing ")".
    uvw = da.map_blocks(apply_rotation_matrix,
                        vis_dataset.field_id.data[:, None, None],

    chan_chunk_size = vis_dataset[_sel_parms['data_in']].chunks[2][0]
    # NOTE(review): missing - the remaining from_array arguments (presumably
    # chunks sized by chan_chunk_size, which is otherwise unused) and the ")".
    freq_chan = da.from_array(vis_dataset.coords['chan'].values,

    # NOTE(review): apply_phasor is not defined in this function (presumably a
    # module-level helper). This call is truncated. phase_rotation and
    # rotation_parms['single_precision'] have no visible consumer, and the input
    # visibilities (sel_parms['data_in']) are never passed. Presumably they went
    # in the lost argument lines, along with the closing ")" - confirm upstream.
    vis_rot = da.map_blocks(apply_phasor,
                            uvw[:, :, :, None],
                            vis_dataset.field_id.data[:, None, None, None],
                            freq_chan[None, None, :, None],

    # Attach the rotated uvw and visibilities under the output names, reusing
    # the dims of the corresponding input variables.
    vis_dataset[_sel_parms['uvw_out']] = xr.DataArray(
        uvw, dims=vis_dataset[_sel_parms['uvw_in']].dims)
    vis_dataset[_sel_parms['data_out']] = xr.DataArray(
        vis_rot, dims=vis_dataset[_sel_parms['data_in']].dims)

    # Only the two new data variables are handed to _store.
    list_xarray_data_variables = [
        vis_dataset[_sel_parms['uvw_out']], vis_dataset[_sel_parms['data_out']]
        # NOTE(review): missing - the closing "]" of this list.
    return _store(vis_dataset, list_xarray_data_variables, _storage_parms)
# Variant of phase_rotate intended for multi-field mosaics. It rotates uvw and
# phase-rotates the visibilities toward rotation_parms['image_phase_center'].
# The results are attached as sel_parms['uvw_out'] / sel_parms['data_out'] and
# passed to _store, whose result is returned.
# NOTE(review): despite the name, nothing in the visible body uses numba. The
# @jit decorator on apply_rotation_matrix is commented out, so it runs as plain
# Python. Apart from comments, the visible statements match phase_rotate.
# NOTE(review): this copy is damaged. Continuation lines are missing: the rest
# of the signature, the docstring quotes, and the closing parts of several
# calls and brackets. The code is left byte-identical; gaps are flagged below
# with "NOTE(review): missing". storage_parms is read in the body but is not
# among the visible parameters, so the lost part of the signature presumably
# declares it.
def phase_rotate_numba(vis_dataset, global_dataset, rotation_parms, sel_parms,
    Rotate uvw with faceting style rephasing for multifield mosaic.
    The specified phasecenter and field phase centers are assumed to be in the same frame.
    This does not support east-west arrays, emphemeris objects or objects within the nearfield.
    (no refocus).
    vis_dataset : xarray.core.dataset.Dataset
        input Visibility Dataset
    psf_dataset : xarray.core.dataset.Dataset
    # Based on CASA's UVWMachine and FTMachine.
    # NOTE(review): missing - the closing docstring quotes. The docstring
    # documents only vis_dataset, and it names the return value "psf_dataset",
    # although this function returns _store(vis_dataset, ...) - confirm.

    #Important: Can not applyflags before calling rotate (uvw coordinates are also flagged). This will destroy the rotation transform.
    #Performance improvements apply_rotation_matrix (jit code)

    #print('1. numba',vis_dataset.DATA[:,0,0,0].values)

    from ngcasa._ngcasa_utils._store import _store
    from scipy.spatial.transform import Rotation as R
    # NOTE(review): scipy, time, numba, double and dask are not referenced in
    # the visible body.
    import scipy
    import numpy as np
    import copy
    import dask.array as da
    import xarray as xr
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms, _check_sel_parms, _check_existence_sel_parms
    from ._imaging_utils._check_imaging_parms import _check_rotation_parms
    import time
    import numba
    from numba import double
    import dask

    # Work on deep copies so the caller's parameter dicts are left untouched by
    # the checks below.
    _sel_parms = copy.deepcopy(sel_parms)
    _rotation_parms = copy.deepcopy(rotation_parms)
    _storage_parms = copy.deepcopy(storage_parms)

    # Parameter validation via assert (skipped entirely under python -O).
    assert (_check_sel_parms(
        _sel_parms, {
            'uvw_in': 'UVW',
            'uvw_out': 'UVW_ROT',
            'data_in': 'DATA',
            'data_out': 'DATA_ROT'
        })), "######### ERROR: sel_parms checking failed"
    assert (_check_existence_sel_parms(vis_dataset, {
        'uvw_in': _sel_parms['uvw_in'],
        'data_in': _sel_parms['data_in']
    })), "######### ERROR: sel_parms checking failed"
    assert (_check_rotation_parms(_rotation_parms)
            ), "######### ERROR: rotation_parms checking failed"
    # NOTE(review): the storage check is labelled 'phase_rotate', the same label
    # phase_rotate uses. Confirm whether 'phase_rotate_numba' was intended.
    assert (_check_storage_parms(
        _storage_parms, 'dataset.vis.zarr',
        'phase_rotate')), "######### ERROR: storage_parms checking failed"

    # Outputs must not overwrite their inputs.
    assert (
        _sel_parms['uvw_out'] != _sel_parms['uvw_in']
    ), "######### ERROR: sel_parms checking failed sel_parms['uvw_out'] can not be the same as sel_parms['uvw_in']."
    assert (
        _sel_parms['data_out'] != _sel_parms['data_in']
    ), "######### ERROR: sel_parms checking failed sel_parms['data_out'] can not be the same as sel_parms['data_in']."

    # Image phase center (first element used as RA, second as Dec).
    ra_image = _rotation_parms['image_phase_center'][0]
    dec_image = _rotation_parms['image_phase_center'][1]

    # Rotation built from the image phase center (Euler 'XZ' angles
    # pi/2 - dec and pi/2 - ra). It is reused for every field in the loop below.
    rotmat_image_phase_center = R.from_euler(
        'XZ', [[np.pi / 2 - dec_image, -ra_image + np.pi / 2]]).as_matrix()[0]
    image_phase_center_cosine = _directional_cosine([ra_image, dec_image])

    # Per-field outputs: a 3x3 uvw rotation matrix and a length-3 phase-rotation
    # vector for each field.
    n_fields = global_dataset.dims['field']
    # NOTE(review): field_names is only referenced in a commented-out print
    # inside apply_rotation_matrix.
    field_names = global_dataset.field
    uvw_rotmat = np.zeros((n_fields, 3, 3), np.double)
    phase_rotation = np.zeros((n_fields, 3), np.double)

    fields_phase_center = global_dataset.FIELD_PHASE_DIR.values[:, :,
        # NOTE(review): missing - the index for the last axis of
        # FIELD_PHASE_DIR and the closing bracket.


    #Create a rotation matrix for each field
    for i_field in range(n_fields):
        # Not sure if the last dimension in FIELD_PHASE_DIR is the ddi number
        field_phase_center = fields_phase_center[i_field, :]
        # Define rotation to a coordinate system with pole towards in-direction
        # and X-axis W; by rotating around z-axis over -(90-long); and around
        # x-axis (lat-90).
        rotmat_field_phase_center = R.from_euler('ZX', [[
            -np.pi / 2 + field_phase_center[0],
            field_phase_center[1] - np.pi / 2
            # NOTE(review): missing - the closing "]]" and ")" of this call and,
            # by analogy with rotmat_image_phase_center above, presumably
            # ".as_matrix()[0]"; confirm upstream.
        uvw_rotmat[i_field, :, :] = np.matmul(rotmat_image_phase_center,
                                              # NOTE(review): missing - the second
                                              # operand (presumably
                                              # rotmat_field_phase_center, which is
                                              # otherwise unused) and the closing ")".

        if _rotation_parms['common_tangent_reprojection'] == True:
            # NOTE(review): missing - the start of this subscripted assignment
            # (the array being indexed). Given the CASA reference below, it
            # presumably zeroes [i_field, 2, 0:2] of uvw_rotmat - confirm.
                i_field, 2, 0:
                2] = 0.0  # (Common tangent rotation needed for joint mosaics, see last part of FTMachine::girarUVW in CASA)

        field_phase_center_cosine = _directional_cosine(field_phase_center)
        phase_rotation[i_field, :] = np.matmul(
            # NOTE(review): missing - the first matmul operand; only the
            # direction-cosine difference is visible.
            (image_phase_center_cosine - field_phase_center_cosine))

    #Apply rotation matrix to uvw
    # For each time index i_time, right-multiply the 2-D slab uvw[i_time, :, :]
    # (last axis of size 3) by uvw_rotmat[field_id[i_time, 0, 0]]. The incoming
    # block is modified in place and returned. The jit decorator below is
    # commented out, so this is a plain Python loop.
    # NOTE(review): this writes into the block dask passes in. Consider copying
    # first if the source uvw array must stay untouched.
    #@jit(nopython=True, cache=True, nogil=True)
    def apply_rotation_matrix(uvw, field_id, uvw_rotmat):
        for i_time in range(uvw.shape[0]):
            #print('The index is',np.where(field_names==field[i_time])[1][0],field_id[i_time,0,0])
            #uvw[i_time,:,0:2] = -uvw[i_time,:,0:2] this gives the same result as casa (in the ftmachines uvw(negateUV(vb)) is used). In ngcasa we don't do this since the uvw definition in the gridder and vis.zarr are the same.
            #uvw[i_time,:,:] = np.matmul(uvw[i_time,:,:],uvw_rotmat[field_id[i_time],:,:])
            uvw[i_time, :, :] = uvw[i_time, :, :] @ uvw_rotmat[field_id[
                i_time, 0, 0], :, :]
        return uvw

    # NOTE(review): missing - apply_rotation_matrix takes uvw as its first
    # argument, but the first visible block argument is field_id. Presumably the
    # uvw input (sel_parms['uvw_in']) was on a lost line. Also missing: the
    # trailing arguments (presumably uvw_rotmat) and the closing ")".
    uvw = da.map_blocks(apply_rotation_matrix,
                        vis_dataset.field_id.data[:, None, None],


    chan_chunk_size = vis_dataset[_sel_parms['data_in']].chunks[2][0]
    # NOTE(review): missing - the remaining from_array arguments (presumably
    # chunks sized by chan_chunk_size, which is otherwise unused) and the ")".
    freq_chan = da.from_array(vis_dataset.coords['chan'].values,

    #print('2. numba',vis_dataset[_sel_parms['data_in']][:,0,0,0].values)
    # NOTE(review): apply_phasor is not defined in this function (presumably a
    # module-level helper). This call is truncated. phase_rotation and
    # rotation_parms['single_precision'] have no visible consumer, and the input
    # visibilities (sel_parms['data_in']) are never passed. Presumably they went
    # in the lost argument lines, along with the closing ")" - confirm upstream.
    vis_rot = da.map_blocks(apply_phasor,
                            uvw[:, :, :, None],
                            vis_dataset.field_id.data[:, None, None, None],
                            freq_chan[None, None, :, None],


    # Attach the rotated uvw and visibilities under the output names, reusing
    # the dims of the corresponding input variables.
    vis_dataset[_sel_parms['uvw_out']] = xr.DataArray(
        uvw, dims=vis_dataset[_sel_parms['uvw_in']].dims)
    vis_dataset[_sel_parms['data_out']] = xr.DataArray(
        vis_rot, dims=vis_dataset[_sel_parms['data_in']].dims)

    #    dask.visualize(vis_dataset[_sel_parms['uvw_out']],filename='uvw_rot_dataset')
    #    dask.visualize(vis_dataset[_sel_parms['data_out']],filename='vis_rot_dataset')
    #    dask.visualize(vis_dataset,filename='vis_dataset_before_append')

    # Only the two new data variables are handed to _store.
    list_xarray_data_variables = [
        vis_dataset[_sel_parms['uvw_out']], vis_dataset[_sel_parms['data_out']]
        # NOTE(review): missing - the closing "]" of this list.
    return _store(vis_dataset, list_xarray_data_variables, _storage_parms)