Example #1
    def _rmatvec(self, x):
        # restore the FFT-domain shape and chunking expected by the adjoint
        if self.reshape:
            x = da.reshape(x, self.dims_fft)
        if self.chunks[1] is not None:
            x = x.rechunk(self.chunks[1])
        if not self.reshape:
            # 1D input: inverse transform along the last axis
            if self.real:
                y = sqrt(self.nfft) * da.fft.irfft(x, n=self.nfft, axis=-1)
                y = da.real(y)
            else:
                y = sqrt(self.nfft) * da.fft.ifft(x, n=self.nfft, axis=-1)
            if self.nfft != self.dims[self.dir]:
                y = y[:self.dims[self.dir]]
            if self.fftshift:
                y = da.fft.fftshift(y)
        else:
            # N-dimensional input: inverse transform along the chosen axis
            if self.real:
                y = sqrt(self.nfft) * da.fft.irfft(
                    x, n=self.nfft, axis=self.dir)
                y = da.real(y)
            else:
                y = sqrt(self.nfft) * da.fft.ifft(
                    x, n=self.nfft, axis=self.dir)
            if self.nfft != self.dims[self.dir]:
                # trim zero-padding back to the original length
                y = da.take(y,
                            np.arange(0, self.dims[self.dir]),
                            axis=self.dir)
            if self.fftshift:
                y = da.fft.fftshift(y, axes=self.dir)
            y = y.ravel()
        y = y.astype(self.dtype)
        return y
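The adjoint above multiplies the inverse FFT by sqrt(nfft), the counterpart of the sqrt(1/nfft) factor in the forward transform of Example #3 below, so the pair is unitary. A minimal standalone sketch of that scaling convention, outside the operator class (note dask's FFT requires the transformed axis to sit in a single chunk):

import numpy as np
import dask.array as da
from math import sqrt

nfft = 16
x = da.from_array(np.random.rand(nfft), chunks=nfft)  # one chunk along the FFT axis
fwd = sqrt(1. / nfft) * da.fft.fft(x, n=nfft, axis=-1)
adj = sqrt(nfft) * da.fft.ifft(fwd, n=nfft, axis=-1)
# a unitary pair: the round trip reproduces the input
assert np.allclose(da.real(adj).compute(), x.compute())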
Example #2
def check_dmd_dask(D, mu, Phi, show_warning=True):
    """
    Checks how close the approximation using DMD is to the original data.

    Returns:
        None if the difference is within the tolerance;
        displays a warning otherwise.
    """
    X = D[:, 0:-1]
    Y = D[:, 1:]
    # Y_est = Phi @ diag(mu) @ pinv(Phi) @ X, grouped right-to-left so the
    # large square product Phi @ diag(mu) @ pinv(Phi) is never formed
    Phi_inv = pinv_SVD(Phi)
    PhiMu = da.dot(Phi, da.diag(mu))
    Y_est = da.dot(PhiMu, da.dot(Phi_inv, X))
    diff = da.real(Y - Y_est)
    res = da.fabs(diff)
    rtol = 1.e-8
    atol = 1.e-5

    # same closeness criterion as np.allclose: |a - b| < atol + rtol * |b|
    if da.all(res < atol + rtol * da.fabs(da.real(Y_est))).compute():
        return None
    elif show_warning:
        warn('dmd result does not satisfy Y=AX')
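pinv_SVD is not shown on this page; a hedged sketch of what such a helper could look like, built from dask's SVD (the actual helper may differ, and da.linalg.svd expects the input to be chunked along a single dimension):

import dask.array as da

def pinv_SVD(A):
    # Moore-Penrose pseudo-inverse via the SVD: A+ = V diag(1/s) U^H
    U, s, Vh = da.linalg.svd(A)
    return da.dot(Vh.conj().T, da.dot(da.diag(1.0 / s), U.conj().T))

The closeness test above mirrors np.allclose semantics, |a - b| < atol + rtol * |b|, but stays lazy until the single .compute() call.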
Example #3
    def _matvec(self, x):
        # forward FFT, scaled by 1/sqrt(nfft) (unitary pair with _rmatvec)
        if self.reshape:
            x = da.reshape(x, self.dims)
        if self.chunks[0] is not None:
            x = x.rechunk(self.chunks[0])
        if not self.reshape:
            # 1D input: transform along the last axis
            if self.fftshift:
                x = da.fft.ifftshift(x)
            if self.real:
                y = sqrt(1. / self.nfft) * da.fft.rfft(
                    da.real(x), n=self.nfft, axis=-1)
            else:
                y = sqrt(1. / self.nfft) * da.fft.fft(x, n=self.nfft, axis=-1)
        else:
            # N-dimensional input: transform along the chosen axis
            if self.fftshift:
                x = da.fft.ifftshift(x, axes=self.dir)
            if self.real:
                y = sqrt(1. / self.nfft) * da.fft.rfft(
                    da.real(x), n=self.nfft, axis=self.dir)
            else:
                y = sqrt(1. / self.nfft) * da.fft.fft(
                    x, n=self.nfft, axis=self.dir)
            y = y.ravel()
        y = y.astype(self.cdtype)
        return y
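Together with the adjoint in Example #1, this forward transform passes the standard dot test for linear operators, <F x, y> = <x, F^H y>: with these scalings, sqrt(n) * ifft is exactly the adjoint of sqrt(1/n) * fft. A small sketch:

import numpy as np
import dask.array as da
from math import sqrt

n = 32
x = da.from_array(np.random.rand(n), chunks=n)
y = da.from_array(np.random.rand(n) + 1j * np.random.rand(n), chunks=n)
Fx = sqrt(1. / n) * da.fft.fft(x, n=n)   # forward, as in _matvec
FHy = sqrt(n) * da.fft.ifft(y, n=n)      # adjoint, as in _rmatvec
lhs = da.sum(Fx * da.conj(y)).compute()
rhs = da.sum(x * da.conj(FHy)).compute()
assert np.isclose(lhs, rhs)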
Example #4
    def _rmatvec(self, x):
        # apply forward fft
        x = da.reshape(x, self.dimsd)
        y = sqrt(1. / self.nt) * da.fft.rfft(x, n=self.nt, axis=0)
        y = y.astype(self.cdtype)
        y = y[:self.nfmax]

        # apply batched matrix mult
        y = y.rechunk((self.G.chunks[0], self.nr, self.nv))
        if self.saveGt:
            if self.conj:
                y = y.conj()
            y = da.matmul(self.GT, y)
            if self.conj:
                y = y.conj()
        else:
            if self.conj:
                y = da.matmul(y.transpose(0, 2, 1), self.G).transpose(0, 2, 1)
            else:
                y = da.matmul(y.transpose(0, 2, 1).conj(),
                              self.G).transpose(0, 2, 1).conj()
        if not self.prescaled:
            y *= self.dr * self.dt * np.sqrt(self.nt)

        # apply inverse fft
        y = da.pad(y, ((0, self.nfft - self.nfmax), (0, 0), (0, 0)),
                   mode='constant')
        y = y.rechunk(self.dimsdf)
        y = sqrt(self.nt) * da.fft.irfft(y, n=self.nt, axis=0)
        if self.twosided:
            y = da.fft.fftshift(y, axes=0)
        y = y.astype(self.dtype)
        y = da.real(y)
        return y.ravel()
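When saveGt is False, the branch above avoids ever storing G^T by using the batched identity matmul(G^T, y) = matmul(y^T, G)^T (the surrounding conj calls then supply the Hermitian part of the adjoint). A quick sketch of that identity with batched dask arrays:

import numpy as np
import dask.array as da

G = da.random.random((4, 5, 6)) + 1j * da.random.random((4, 5, 6))
y = da.random.random((4, 5, 3)) + 1j * da.random.random((4, 5, 3))
lhs = da.matmul(G.transpose(0, 2, 1), y)                     # uses G^T directly
rhs = da.matmul(y.transpose(0, 2, 1), G).transpose(0, 2, 1)  # never forms G^T
assert np.allclose(lhs.compute(), rhs.compute())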
Example #5
def nearestPD(A, threads=1):
    """
    Find the nearest positive-definite matrix to input

    A Python/NumPy port, by Ahmed Fasih, of John D'Errico's `nearestSPD` MATLAB
    code [1], which credits [2].

    [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd

    [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite
    matrix" (1988): https://doi.org/10.1016/0024-3795(88)90223-6
    """
    isPD = lambda x: da.all(np.linalg.eigvals(x) > 0).compute()  # np.linalg.eigvals materializes the dask array
    B = (A + A.T) / 2
    _, s, V = da.linalg.svd(B)
    H = da.dot(V.T, da.dot(da.diag(s), V))
    A2 = (B + H) / 2
    A3 = (A2 + A2.T) / 2
    if isPD(A3):
        return A3
    spacing = da.spacing(da.linalg.norm(A))
    # The above is different from [1]. It appears that MATLAB's `chol` Cholesky
    # decomposition will accept matrices with an exactly-zero eigenvalue, whereas
    # NumPy's will not. So where [1] uses `eps(mineig)` (where `eps` is MATLAB
    # for `np.spacing`), we use the above definition. CAVEAT: our `spacing`
    # will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on
    # the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas
    # `spacing` will, for Gaussian random matrices of small dimension, be on
    # the order of 1e-16. In practice, both ways converge, as the unit test
    # below suggests.
    eye_chunk = estimate_chunks((A.shape[0], A.shape[0]), threads=threads)[0]
    I = da.eye(A.shape[0], chunks=eye_chunk)
    k = 1
    while not isPD(A3):
        mineig = da.min(da.real(np.linalg.eigvals(A3)))
        A3 += I * (-mineig * k**2 + spacing)
        k += 1
    return A3
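The long comment above turns on the difference between MATLAB's chol and NumPy's Cholesky decomposition. For reference, a hedged sketch of the Cholesky-based definiteness check that reasoning alludes to (an alternative to the eigvals-based isPD lambda):

import numpy as np

def is_pd_chol(M):
    # Cholesky succeeds iff M is (numerically) positive definite
    try:
        np.linalg.cholesky(M)
        return True
    except np.linalg.LinAlgError:
        return False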
Example #6
def test_arithmetic():
    x = np.arange(5).astype('f4') + 2
    y = np.arange(5).astype('i8') + 2
    z = np.arange(5).astype('i4') + 2
    a = da.from_array(x, chunks=(2,))
    b = da.from_array(y, chunks=(2,))
    c = da.from_array(z, chunks=(2,))
    assert eq(a + b, x + y)
    assert eq(a * b, x * y)
    assert eq(a - b, x - y)
    assert eq(a / b, x / y)
    assert eq(b & b, y & y)
    assert eq(b | b, y | y)
    assert eq(b ^ b, y ^ y)
    assert eq(a // b, x // y)
    assert eq(a ** b, x ** y)
    assert eq(a % b, x % y)
    assert eq(a > b, x > y)
    assert eq(a < b, x < y)
    assert eq(a >= b, x >= y)
    assert eq(a <= b, x <= y)
    assert eq(a == b, x == y)
    assert eq(a != b, x != y)

    assert eq(a + 2, x + 2)
    assert eq(a * 2, x * 2)
    assert eq(a - 2, x - 2)
    assert eq(a / 2, x / 2)
    assert eq(b & True, y & True)
    assert eq(b | True, y | True)
    assert eq(b ^ True, y ^ True)
    assert eq(a // 2, x // 2)
    assert eq(a ** 2, x ** 2)
    assert eq(a % 2, x % 2)
    assert eq(a > 2, x > 2)
    assert eq(a < 2, x < 2)
    assert eq(a >= 2, x >= 2)
    assert eq(a <= 2, x <= 2)
    assert eq(a == 2, x == 2)
    assert eq(a != 2, x != 2)

    assert eq(2 + b, 2 + y)
    assert eq(2 * b, 2 * y)
    assert eq(2 - b, 2 - y)
    assert eq(2 / b, 2 / y)
    assert eq(True & b, True & y)
    assert eq(True | b, True | y)
    assert eq(True ^ b, True ^ y)
    assert eq(2 // b, 2 // y)
    assert eq(2 ** b, 2 ** y)
    assert eq(2 % b, 2 % y)
    assert eq(2 > b, 2 > y)
    assert eq(2 < b, 2 < y)
    assert eq(2 >= b, 2 >= y)
    assert eq(2 <= b, 2 <= y)
    assert eq(2 == b, 2 == y)
    assert eq(2 != b, 2 != y)

    assert eq(-a, -x)
    assert eq(abs(a), abs(x))
    assert eq(~(a == b), ~(x == y))

    assert eq(da.logaddexp(a, b), np.logaddexp(x, y))
    assert eq(da.logaddexp2(a, b), np.logaddexp2(x, y))
    assert eq(da.exp(b), np.exp(y))
    assert eq(da.log(a), np.log(x))
    assert eq(da.log10(a), np.log10(x))
    assert eq(da.log1p(a), np.log1p(x))
    assert eq(da.expm1(b), np.expm1(y))
    assert eq(da.sqrt(a), np.sqrt(x))
    assert eq(da.square(a), np.square(x))

    assert eq(da.sin(a), np.sin(x))
    assert eq(da.cos(b), np.cos(y))
    assert eq(da.tan(a), np.tan(x))
    assert eq(da.arcsin(b/10), np.arcsin(y/10))
    assert eq(da.arccos(b/10), np.arccos(y/10))
    assert eq(da.arctan(b/10), np.arctan(y/10))
    assert eq(da.arctan2(b*10, a), np.arctan2(y*10, x))
    assert eq(da.hypot(b, a), np.hypot(y, x))
    assert eq(da.sinh(a), np.sinh(x))
    assert eq(da.cosh(b), np.cosh(y))
    assert eq(da.tanh(a), np.tanh(x))
    assert eq(da.arcsinh(b*10), np.arcsinh(y*10))
    assert eq(da.arccosh(b*10), np.arccosh(y*10))
    assert eq(da.arctanh(b/10), np.arctanh(y/10))
    assert eq(da.deg2rad(a), np.deg2rad(x))
    assert eq(da.rad2deg(a), np.rad2deg(x))

    assert eq(da.logical_and(a < 1, b < 4), np.logical_and(x < 1, y < 4))
    assert eq(da.logical_or(a < 1, b < 4), np.logical_or(x < 1, y < 4))
    assert eq(da.logical_xor(a < 1, b < 4), np.logical_xor(x < 1, y < 4))
    assert eq(da.logical_not(a < 1), np.logical_not(x < 1))
    assert eq(da.maximum(a, 5 - a), np.maximum(x, 5 - x))
    assert eq(da.minimum(a, 5 - a), np.minimum(x, 5 - x))
    assert eq(da.fmax(a, 5 - a), np.fmax(x, 5 - x))
    assert eq(da.fmin(a, 5 - a), np.fmin(x, 5 - x))

    assert eq(da.isreal(a + 1j * b), np.isreal(x + 1j * y))
    assert eq(da.iscomplex(a + 1j * b), np.iscomplex(x + 1j * y))
    assert eq(da.isfinite(a), np.isfinite(x))
    assert eq(da.isinf(a), np.isinf(x))
    assert eq(da.isnan(a), np.isnan(x))
    assert eq(da.signbit(a - 3), np.signbit(x - 3))
    assert eq(da.copysign(a - 3, b), np.copysign(x - 3, y))
    assert eq(da.nextafter(a - 3, b), np.nextafter(x - 3, y))
    assert eq(da.ldexp(c, c), np.ldexp(z, z))
    assert eq(da.fmod(a * 12, b), np.fmod(x * 12, y))
    assert eq(da.floor(a * 0.5), np.floor(x * 0.5))
    assert eq(da.ceil(a), np.ceil(x))
    assert eq(da.trunc(a / 2), np.trunc(x / 2))

    assert eq(da.degrees(b), np.degrees(y))
    assert eq(da.radians(a), np.radians(x))

    assert eq(da.rint(a + 0.3), np.rint(x + 0.3))
    assert eq(da.fix(a - 2.5), np.fix(x - 2.5))

    assert eq(da.angle(a + 1j), np.angle(x + 1j))
    assert eq(da.real(a + 1j), np.real(x + 1j))
    assert eq((a + 1j).real, np.real(x + 1j))
    assert eq(da.imag(a + 1j), np.imag(x + 1j))
    assert eq((a + 1j).imag, np.imag(x + 1j))
    assert eq(da.conj(a + 1j * b), np.conj(x + 1j * y))
    assert eq((a + 1j * b).conj(), (x + 1j * y).conj())

    assert eq(da.clip(b, 1, 4), np.clip(y, 1, 4))
    assert eq(da.fabs(b), np.fabs(y))
    assert eq(da.sign(b - 2), np.sign(y - 2))

    l1, l2 = da.frexp(a)
    r1, r2 = np.frexp(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    l1, l2 = da.modf(a)
    r1, r2 = np.modf(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    assert eq(da.around(a, -1), np.around(x, -1))
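The eq helper itself is not shown. A minimal sketch of what it needs to do, evaluating the lazy dask expression and comparing it against the NumPy reference (dask's real test helper, assert_eq, additionally checks dtypes and metadata):

import numpy as np
import dask.array as da

def eq(a, b):
    a = a.compute() if isinstance(a, da.Array) else a
    b = b.compute() if isinstance(b, da.Array) else b
    return np.allclose(a, b)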
Example #7
def make_gridding_convolution_function(vis_dataset, global_dataset, gcf_parms, grid_parms, storage_parms):
    """
    Currently creates a gcf to correct for the primary beams of the antennas and supports heterogeneous arrays (antennas with different dish sizes).
    Only the airy disk and ALMA airy disk models are implemented.
    In the future, support will be added for beam squint, pointing corrections, w projection, and the inclusion of a prolate spheroidal term.
    
    Parameters
    ----------
    vis_dataset : xarray.core.dataset.Dataset
        Input visibility dataset.
    gcf_parms : dictionary
    gcf_parms['function'] : {'alma_airy'/'airy'}, default = 'alma_airy'
        The primary beam model used (a function of the dish diameter and blockage diameter).
    gcf_parms['list_dish_diameters']  : list of number, units = meter
        A list of unique antenna dish diameters.
    gcf_parms['list_blockage_diameters']  : list of number, units = meter
        A list of unique feed blockage diameters (must be the same length as gcf_parms['list_dish_diameters']).
    gcf_parms['unique_ant_indx']  : list of int
        A list of indices into the gcf_parms['list_dish_diameters'] and gcf_parms['list_blockage_diameters'] lists, one for each antenna.
    gcf_parms['image_phase_center']  : list of number, length = 2, units = radians
        The mosaic image phase center.
    gcf_parms['a_chan_num_chunk']  : int, default = 3
        The number of chunks in the channel dimension of the gridding convolution function data variable.
    gcf_parms['oversampling']  : list of int, length = 2, default = [10,10]
        The oversampling of the gridding convolution function.
    gcf_parms['max_support']  : list of int, length = 2, default = [15,15]
        The maximum allowable support of the gridding convolution function.
    gcf_parms['support_cut_level']  : number, default = 0.025
        The attenuation at which to truncate the gridding convolution function.
    gcf_parms['chan_tolerance_factor']  : number, default = 0.005
        The fractional bandwidth at which the frequency dependence of the primary beam can be ignored; it determines the number of frequencies for which to calculate a gridding convolution function. The number of channels equals the fractional bandwidth divided by gcf_parms['chan_tolerance_factor'].
    grid_parms : dictionary
    grid_parms['image_size'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell_size']  : list of number, length = 2, units = arcseconds
        The image cell size.
    storage_parms : dictionary
    storage_parms['to_disk'] : bool, default = False
        If true the dask graph is executed and saved to disk in the zarr format.
    storage_parms['append'] : bool, default = False
        If storage_parms['to_disk'] is True only the dask graph associated with the function is executed and the resulting data variables are saved to an existing zarr file on disk.
        Note that graphs on unrelated data to this function will not be executed or saved.
    storage_parms['outfile'] : str
        The zarr file to create or append to.
    storage_parms['chunks_on_disk'] : dict of int, default = {}
        The chunk size to use when writing to disk. This is ignored if storage_parms['append'] is True. The default will use the chunking of the input dataset.
    storage_parms['chunks_return'] : dict of int, default = {}
        The chunk size of the dataset that is returned. The default will use the chunking of the input dataset.
    storage_parms['graph_name'] : str
        The time to compute and save the data is stored in the attribute section of the dataset and storage_parms['graph_name'] is used in the label.
    storage_parms['compressor'] : numcodecs.blosc.Blosc, default = Blosc(cname='zstd', clevel=2, shuffle=0)
        The compression algorithm to use. Available compression algorithms can be found at https://numcodecs.readthedocs.io/en/stable/blosc.html.
    Returns
    -------
    gcf_dataset : xarray.core.dataset.Dataset
            
    """
    print('######################### Start make_gridding_convolution_function #########################')
    
    from ngcasa._ngcasa_utils._store import _store
    from ngcasa._ngcasa_utils._check_parms import _check_storage_parms
    from ._imaging_utils._check_imaging_parms import _check_pb_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_parms, _check_gcf_parms
    from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel_2D, _create_prolate_spheroidal_image_2D
    from ._imaging_utils._remove_padding import _remove_padding
    import numpy as np
    import dask.array as da
    import copy, os
    import xarray as xr
    import itertools
    import dask
    import dask.array.fft as dafft
    
    import matplotlib.pylab as plt
    
    _gcf_parms = copy.deepcopy(gcf_parms)
    _grid_parms = copy.deepcopy(grid_parms)
    _storage_parms = copy.deepcopy(storage_parms)
    
    _gcf_parms['basline_ant'] = vis_dataset.antennas.values # n_baseline x 2 (ant pair)
    _gcf_parms['freq_chan'] = vis_dataset.chan.values
    _gcf_parms['pol'] = vis_dataset.pol.values
    _gcf_parms['vis_data_chunks'] = vis_dataset.DATA.chunks
    _gcf_parms['field_phase_dir'] = np.array(global_dataset.FIELD_PHASE_DIR.values[:,:,vis_dataset.attrs['ddi']])
    
    assert(_check_gcf_parms(_gcf_parms)), "######### ERROR: gcf_parms checking failed"
    assert(_check_grid_parms(_grid_parms)), "######### ERROR: grid_parms checking failed"
    assert(_check_storage_parms(_storage_parms,'dataset.gcf.zarr','make_gcf')), "######### ERROR: user_storage_parms checking failed"
    
    assert(not _storage_parms['append']), "######### ERROR: storage_parms['append'] = True is not available for make_gridding_convolution_function"
        
    if _gcf_parms['function'] == 'airy':
        from ._imaging_utils._make_pb_symmetric import _airy_disk_rorder
        pb_func = _airy_disk_rorder
    elif _gcf_parms['function'] == 'alma_airy':
        from ._imaging_utils._make_pb_symmetric import _alma_airy_disk_rorder
        pb_func = _alma_airy_disk_rorder
    else:
        assert(False), "######### ERROR: Only the airy and alma_airy functions have been implemented"
        
    #For now only a_term works
    _gcf_parms['a_term'] =  True
    _gcf_parms['ps_term'] =  False
        
    _gcf_parms['resize_conv_size'] = (_gcf_parms['max_support'] + 1)*_gcf_parms['oversampling']
    #resize_conv_size = _gcf_parms['resize_conv_size']
    
    if _gcf_parms['ps_term'] == True:
        '''
        ps_term = _create_prolate_spheroidal_kernel_2D(_gcf_parms['oversampling'],np.array([7,7])) #This is only used with a_term == False. Support is hardcoded to 7 until old ps code is replaced by a general function.
        center = _grid_parms['image_center']
        center_embed = np.array(ps_term.shape)//2
        ps_term_padded = np.zeros(_grid_parms['image_size'])
        ps_term_padded[center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        ps_term_padded_ifft = dafft.fftshift(dafft.ifft2(dafft.ifftshift(da.from_array(ps_term_padded))))

        ps_image = da.from_array(_remove_padding(_create_prolate_spheroidal_image_2D(_grid_parms['image_size_padded']),_grid_parms['image_size']),chunks=_grid_parms['image_size'])
        
        #Effectively no mapping needed if ps_term == True and a_term == False
        cf_baseline_map = np.zeros((len(_gcf_parms['basline_ant']),),dtype=int)
        cf_chan_map = np.zeros((len(_gcf_parms['freq_chan']),),dtype=int)
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int)
        '''
    
    if _gcf_parms['a_term'] == True:
        n_unique_ant = len(_gcf_parms['list_dish_diameters'])
        cf_baseline_map,pb_ant_pairs = create_cf_baseline_map(_gcf_parms['unique_ant_indx'],_gcf_parms['basline_ant'],n_unique_ant)
        
        cf_chan_map, pb_freq = create_cf_chan_map(_gcf_parms['freq_chan'],_gcf_parms['chan_tolerance_factor'])
        pb_freq = da.from_array(pb_freq,chunks=np.ceil(len(pb_freq)/_gcf_parms['a_chan_num_chunk'] ))
        
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int) #create_cf_pol_map(), currently treating all pols the same
        pb_pol = da.from_array(np.array([0]),1)
        
        n_chunks_in_each_dim = [pb_freq.numblocks[0],pb_pol.numblocks[0]]
        iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]))
        chan_chunk_sizes = pb_freq.chunks
        pol_chunk_sizes = pb_pol.chunks
        
        #print(pb_freq, pb_pol,pol_chunk_sizes)
        list_baseline_pb = []
        list_weight_baseline_pb_sqrd = []
        for c_chan, c_pol in iter_chunks_indx:
                #print('chan, pol ',c_chan,c_pol)
                _gcf_parms['ipower'] = 1
                delayed_baseline_pb = dask.delayed(make_baseline_patterns)(pb_freq.partitions[c_chan],pb_pol.partitions[c_pol],dask.delayed(pb_ant_pairs),dask.delayed(pb_func),dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                
                list_baseline_pb.append(da.from_delayed(delayed_baseline_pb,(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_grid_parms['image_size_padded'][0],_grid_parms['image_size_padded'][1]),dtype=np.double))
                              
                _gcf_parms['ipower'] = 2
                delayed_weight_baseline_pb_sqrd = dask.delayed(make_baseline_patterns)(pb_freq.partitions[c_chan],pb_pol.partitions[c_pol],dask.delayed(pb_ant_pairs),dask.delayed(pb_func),dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                
                list_weight_baseline_pb_sqrd.append(da.from_delayed(delayed_weight_baseline_pb_sqrd,(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_grid_parms['image_size_padded'][0],_grid_parms['image_size_padded'][1]),dtype=np.double))
               
        
        baseline_pb = da.concatenate(list_baseline_pb,axis=1)
        weight_baseline_pb_sqrd = da.concatenate(list_weight_baseline_pb_sqrd,axis=1)
    
    #Combine patterns and fft to obtain the gridding convolutional kernel
    #print(weight_baseline_pb_sqrd)

    dataset_dict = {}
    list_xarray_data_variables = []
    if (_gcf_parms['a_term'] == True) and (_gcf_parms['ps_term'] == True):
        conv_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(ps_term_padded_ifft*baseline_pb, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_weight_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        
        
        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
                delayed_kernels_and_support = dask.delayed(resize_and_calc_support)(conv_kernel.partitions[:,c_chan,c_pol,:,:],conv_weight_kernel.partitions[:,c_chan,c_pol,:,:],dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                list_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[0],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                list_weight_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[1],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                list_conv_support.append(da.from_delayed(delayed_kernels_and_support[2],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],2),dtype=int))
                
        
        conv_kernel = da.concatenate(list_conv_kernel,axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel,axis=1)
        conv_support = da.concatenate(list_conv_support,axis=1)
        
    
        dataset_dict['SUPPORT'] = xr.DataArray(conv_support, dims=['conv_baseline','conv_chan','conv_pol','xy'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l','m'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(weight_conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
    elif (_gcf_parms['a_term'] == False) and (_gcf_parms['ps_term'] == True):
        support = np.array([7,7])
        dataset_dict['SUPPORT'] = xr.DataArray(support[None,None,None,:], dims=['conv_baseline','conv_chan','conv_pol','xy'])
        conv_kernel = np.zeros((1,1,1,_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]))
        center = _gcf_parms['resize_conv_size']//2
        center_embed = np.array(ps_term.shape)//2
        conv_kernel[0,0,0,center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l','m'])
        ##Enabled for test
        #dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
    elif (_gcf_parms['a_term'] == True) and (_gcf_parms['ps_term'] == False):
        conv_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(baseline_pb, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        conv_weight_kernel = da.real(dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd, axes=(3, 4)), axes=(3, 4)), axes=(3, 4)))
        
        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(np.arange(n_chunks_in_each_dim[0]), np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
                delayed_kernels_and_support = dask.delayed(resize_and_calc_support)(conv_kernel.partitions[:,c_chan,c_pol,:,:],conv_weight_kernel.partitions[:,c_chan,c_pol,:,:],dask.delayed(_gcf_parms),dask.delayed(_grid_parms))
                list_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[0],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                list_weight_conv_kernel.append(da.from_delayed(delayed_kernels_and_support[1],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],_gcf_parms['resize_conv_size'][0],_gcf_parms['resize_conv_size'][1]),dtype=np.double))
                list_conv_support.append(da.from_delayed(delayed_kernels_and_support[2],(len(pb_ant_pairs),chan_chunk_sizes[0][c_chan], pol_chunk_sizes[0][c_pol],2),dtype=int))
                
        
        conv_kernel = da.concatenate(list_conv_kernel,axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel,axis=1)
        conv_support = da.concatenate(list_conv_support,axis=1)
        
    
        dataset_dict['SUPPORT'] = xr.DataArray(conv_support, dims=['conv_baseline','conv_chan','conv_pol','xy'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(weight_conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(da.from_array(np.ones(_grid_parms['image_size']),chunks=_grid_parms['image_size']), dims=['l','m'])
    else:
        assert(False), "######### ERROR: At least 'a_term' or 'ps_term' must be true."
    
    ###########################################################
    #Make phase gradient (one for each field)
    field_phase_dir = _gcf_parms['field_phase_dir']
    field_phase_dir = da.from_array(field_phase_dir,chunks=(np.ceil(len(field_phase_dir)/_gcf_parms['a_chan_num_chunk']),2))
    
    phase_gradient = da.blockwise(make_phase_gradient, ("n_field","n_x","n_y"), field_phase_dir, ("n_field","2"), gcf_parms=_gcf_parms, grid_parms=_grid_parms, dtype=complex,  new_axes={"n_x": _gcf_parms['resize_conv_size'][0], "n_y": _gcf_parms['resize_conv_size'][1]})
    

    ###########################################################
    
    #coords = {'baseline': np.arange(n_unique_ant), 'chan': pb_freq, 'pol' : pb_pol, 'u': np.arange(resize_conv_size[0]), 'v': np.arange(resize_conv_size[1]), 'xy':np.arange(2), 'field':np.arange(field_phase_dir.shape[0]),'l':np.arange(_gridding_convolution_parms['imsize'][0]),'m':np.arange(_gridding_convolution_parms['imsize'][1])}
        
    #coords = { 'conv_chan': pb_freq, 'conv_pol' : pb_pol, 'u': np.arange(resize_conv_size[0]), 'v': np.arange(resize_conv_size[1]), 'xy':np.arange(2), 'field':np.arange(field_phase_dir.shape[0]),'l':np.arange(_gridding_convolution_parms['imsize'][0]),'m':np.arange(_gridding_convolution_parms['imsize'][1])}
    
    coords = { 'u': np.arange(_gcf_parms['resize_conv_size'][0]), 'v': np.arange(_gcf_parms['resize_conv_size'][1]), 'xy':np.arange(2), 'field':np.arange(field_phase_dir.shape[0]),'l':np.arange(_grid_parms['image_size'][0]),'m':np.arange(_grid_parms['image_size'][1])}
    
    dataset_dict['CF_BASELINE_MAP'] = xr.DataArray(cf_baseline_map, dims=('baseline')).chunk(_gcf_parms['vis_data_chunks'][1])
    dataset_dict['CF_CHAN_MAP'] = xr.DataArray(cf_chan_map, dims=('chan')).chunk(_gcf_parms['vis_data_chunks'][2])
    dataset_dict['CF_POL_MAP'] = xr.DataArray(cf_pol_map, dims=('pol')).chunk(_gcf_parms['vis_data_chunks'][3])
    
        
    dataset_dict['CONV_KERNEL'] = xr.DataArray(conv_kernel, dims=('conv_baseline','conv_chan','conv_pol','u','v'))
    dataset_dict['PHASE_GRADIENT'] = xr.DataArray(phase_gradient, dims=('field','u','v'))
    
    gcf_dataset = xr.Dataset(dataset_dict, coords=coords)
    gcf_dataset.attrs['cell_uv'] = 1 / (_grid_parms['image_size_padded'] * _grid_parms['cell_size'] * _gcf_parms['oversampling'])
    gcf_dataset.attrs['oversampling'] = _gcf_parms['oversampling']
    
    
    #list_xarray_data_variables = [gcf_dataset['A_TERM'],gcf_dataset['WEIGHT_A_TERM'],gcf_dataset['A_SUPPORT'],gcf_dataset['WEIGHT_A_SUPPORT'],gcf_dataset['PHASE_GRADIENT']]
    return _store(gcf_dataset,list_xarray_data_variables,_storage_parms)
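Most of the kernel sizes above follow from two documented defaults: with max_support = [15, 15] and oversampling = [10, 10], resize_conv_size = (max_support + 1) * oversampling per axis. A one-liner confirming the arithmetic:

import numpy as np

max_support = np.array([15, 15])     # documented default
oversampling = np.array([10, 10])    # documented default
print((max_support + 1) * oversampling)   # -> [160 160]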
Example #8
    def image_tikhonov(self,
                       vis_arr,
                       sphere,
                       alpha,
                       scale=True,
                       usedask=False):
        n_s = sphere.pixels.shape[0]
        n_v = self.u_arr.shape[0]

        lambduh = alpha / np.sqrt(n_s)
        if not usedask:
            gamma = self.make_gamma(sphere)
            logger.info("augmented: {}".format(gamma.shape))

            vis_aux = vis_to_real(vis_arr)
            logger.info("vis mean: {} shape: {}".format(
                np.mean(vis_aux), vis_aux.shape))

            tol = min(alpha / 1e4, 1e-10)
            logger.info("Solving tol={} ...".format(tol))

            # reg = linear_model.ElasticNet(alpha=alpha/np.sqrt(n_s),
            # tol=1e-6,
            # l1_ratio = 0.01,
            # max_iter=100000,
            # positive=True)
            if False:
                (
                    sky,
                    lstop,
                    itn,
                    r1norm,
                    r2norm,
                    anorm,
                    acond,
                    arnorm,
                    xnorm,
                    var,
                ) = scipy.sparse.linalg.lsqr(gamma,
                                             vis_aux,
                                             damp=alpha,
                                             show=True)
                logger.info(
                    "Alpha: {}: Iterations: {}: rnorm: {}: xnorm: {}".format(
                        alpha, itn, r2norm, xnorm))
            else:
                reg = linear_model.Ridge(alpha=alpha,
                                         tol=tol,
                                         solver="lsqr",
                                         max_iter=100000)

                reg.fit(gamma, vis_aux)
                logger.info("    Solve Complete, iter={}".format(reg.n_iter_))

                sky = da.from_array(reg.coef_)

                residual = vis_aux - gamma @ sky

                sky, residual_norm, solution_norm = da.compute(
                    sky,
                    np.linalg.norm(residual)**2,
                    np.linalg.norm(sky)**2)

                score = reg.score(gamma, vis_aux)
                logger.info("Alpha: {}: Loss: {}: rnorm: {}: snorm: {}".format(
                    alpha, score, residual_norm, solution_norm))

        else:
            from dask_ml.linear_model import LinearRegression
            import dask_glm
            from dask.distributed import Client, LocalCluster
            from dask.diagnostics import ProgressBar
            import dask

            logger.info("Starting Dask Client")

            if True:
                cluster = LocalCluster(dashboard_address=":8231",
                                       processes=False)
                client = Client(cluster)
            else:
                client = Client("tcp://localhost:8786")

            logger.info("Client = {}".format(client))

            harmonic_list = []
            p2j = 2 * np.pi * 1.0j

            dl = sphere.l
            dm = sphere.m
            dn = sphere.n

            n_arr_minus_1 = dn - 1

            du = self.u_arr
            dv = self.v_arr
            dw = self.w_arr

            for u, v, w in zip(du, dv, dw):
                harmonic = da.from_array(
                    np.exp(p2j * (u * dl + v * dm + w * n_arr_minus_1)) /
                    np.sqrt(sphere.npix),
                    chunks=(n_s, ),
                )
                harmonic = client.persist(harmonic)
                harmonic_list.append(harmonic)

            gamma = da.stack(harmonic_list)
            logger.info("Gamma Shape: {}".format(gamma.shape))
            # gamma = gamma.reshape((n_v, n_s))
            gamma = gamma.conj()
            gamma = client.persist(gamma)

            logger.info("Gamma Shape: {}".format(gamma.shape))

            logger.info("Building Augmented Operator...")
            proj_operator_real = da.real(gamma)
            proj_operator_imag = da.imag(gamma)
            proj_operator = da.block([[proj_operator_real],
                                      [proj_operator_imag]])

            proj_operator = client.persist(proj_operator)

            logger.info("Proj Operator shape {}".format(proj_operator.shape))
            vis_aux = da.from_array(
                np.array(
                    np.concatenate((np.real(vis_arr), np.imag(vis_arr))),
                    dtype=np.float32,
                ))

            # logger.info("Solving...")

            en = dask_glm.regularizers.ElasticNet(weight=0.01)
            en = dask_glm.regularizers.L2()
            # dT = da.from_array(proj_operator, chunks=(-1, 'auto'))
            ##dT = da.from_array(proj_operator, chunks=(-1, 'auto'))
            # dv = da.from_array(vis_aux)

            dask.config.set({"array.chunk-size": "1024MiB"})
            A = da.rechunk(proj_operator, chunks=("auto", n_s))
            A = client.persist(A)
            y = vis_aux  # da.rechunk(vis_aux, chunks=('auto', n_s))
            y = client.persist(y)
            # sky = dask_glm.algorithms.proximal_grad(A, y, regularizer=en, lambduh=alpha, max_iter=10000)

            logger.info("Rechunking completed.. A= {}.".format(A.shape))
            reg = LinearRegression(
                penalty=en,
                C=1.0 / lambduh,
                fit_intercept=False,
                solver="lbfgs",
                max_iter=1000,
                tol=1e-8,
            )
            sky = reg.fit(A, y)
            sky = reg.coef_
            score = reg.score(proj_operator, vis_aux)
            logger.info("Loss function: {}".format(score.compute()))

        logger.info("Solving Complete: sky = {}".format(sky.shape))

        sphere.set_visible_pixels(sky, scale=False)
        return sky.reshape(-1, 1)
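Stacking da.real(gamma) over da.imag(gamma) converts the complex system into an equivalent real least-squares problem that the real-valued regression solvers can handle. A small NumPy sketch of the equivalence for a real solution vector:

import numpy as np

rng = np.random.default_rng(0)
G = rng.normal(size=(6, 4)) + 1j * rng.normal(size=(6, 4))
v = rng.normal(size=6) + 1j * rng.normal(size=6)
s = rng.normal(size=4)                # real sky vector
A = np.block([[G.real], [G.imag]])    # augmented operator
b = np.concatenate([v.real, v.imag])  # augmented visibilities
# the complex residual norm equals the stacked real residual norm
assert np.isclose(np.linalg.norm(G @ s - v), np.linalg.norm(A @ s - b))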
Example #9
        assert mapper is not None
        self.fullname, self.unit, self.mapper, self.column, self.extras = fullname, unit, mapper, column, extras
        self.conjugate = conjugate
        self.axis = axis


_identity = lambda x: x

# this dict maps short axis names into full DataMapper objects
data_mappers = OrderedDict(
    _=DataMapper("", "", _identity),
    amp=DataMapper("Amplitude", "", abs),
    logamp=DataMapper("Log-amplitude", "", lambda x: da.log10(abs(x))),
    phase=DataMapper(
        "Phase", "deg",
        lambda x: da.arctan2(da.imag(x), da.real(x)) * 180 / math.pi),
    real=DataMapper("Real", "", da.real),
    imag=DataMapper("Imag", "", da.imag),
    TIME=DataMapper("Time", "s", axis=0, column="TIME", mapper=_identity),
    ROW=DataMapper("Row number",
                   "",
                   column=False,
                   axis=0,
                   extras=["rows"],
                   mapper=lambda x, rows: rows),
    BASELINE=DataMapper("Baseline",
                        "",
                        column=False,
                        axis=0,
                        extras=["baselines"],
                        mapper=lambda x, baselines: baselines),
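Each DataMapper pairs a label and unit with a dask-aware mapping function. A tiny sketch of what the "phase" mapper above computes on a complex dask array:

import math
import numpy as np
import dask.array as da

vis = da.from_array(np.array([1 + 1j, -1 + 0j]), chunks=2)
phase = da.arctan2(da.imag(vis), da.real(vis)) * 180 / math.pi
print(phase.compute())   # -> [ 45. 180.]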
Example #10
    def image_tikhonov(self, vis_arr, sphere, alpha, scale=True, usedask=False):
        n_s = sphere.pixels.shape[0]
        n_v = self.u_arr.shape[0]
        
        lambduh = alpha/np.sqrt(n_s)
        if not usedask:
            gamma = self.make_gamma(sphere)
            logger.info("Building Augmented Operator...")
            proj_operator_real = np.real(gamma).astype(np.float32)
            proj_operator_imag = np.imag(gamma).astype(np.float32)
            gamma = None
            proj_operator = np.block([[proj_operator_real], [proj_operator_imag]])
            proj_operator_real = None
            proj_operator_imag = None 
            logger.info('augmented: {}'.format(proj_operator.shape))
            
            vis_aux = np.array(np.concatenate((np.real(vis_arr), np.imag(vis_arr))), dtype=np.float32)
            logger.info('vis mean: {} shape: {}'.format(np.mean(vis_aux), vis_aux.shape))

            logger.info("Solving...")
            reg = linear_model.ElasticNet(alpha=lambduh, l1_ratio=0.05, max_iter=10000, positive=True)
            reg.fit(proj_operator, vis_aux)
            sky = reg.coef_
            
            score = reg.score(proj_operator, vis_aux)
            logger.info('Loss function: {}'.format(score))
            
        else:
            from dask_ml.linear_model import LinearRegression
            import dask_glm
            import dask.array as da
            from dask.distributed import Client, LocalCluster
            from dask.diagnostics import ProgressBar
            import dask
            
            logger.info('Starting Dask Client')
            
            if True:
                cluster = LocalCluster(dashboard_address=':8231', processes=False)
                client = Client(cluster)
            else:
                client = Client('tcp://localhost:8786')
                
            logger.info("Client = {}".format(client))
            
            harmonic_list = []
            p2j = 2*np.pi*1.0j
            
            dl = sphere.l
            dm = sphere.m
            dn = sphere.n
        
            n_arr_minus_1 = dn - 1

            du = self.u_arr
            dv = self.v_arr
            dw = self.w_arr
        
            for u, v, w in zip(du, dv, dw):
                harmonic = da.from_array(np.exp(p2j*(u*dl + v*dm + w*n_arr_minus_1)) / np.sqrt(sphere.npix), chunks=(n_s,))
                harmonic = client.persist(harmonic)
                harmonic_list.append(harmonic)

            gamma = da.stack(harmonic_list)
            logger.info('Gamma Shape: {}'.format(gamma.shape))
            #gamma = gamma.reshape((n_v, n_s))
            gamma = gamma.conj()
            gamma = client.persist(gamma)
            
            logger.info('Gamma Shape: {}'.format(gamma.shape))
            
            logger.info("Building Augmented Operator...")
            proj_operator_real = da.real(gamma)
            proj_operator_imag = da.imag(gamma)
            proj_operator = da.block([[proj_operator_real], [proj_operator_imag]])
            
            proj_operator = client.persist(proj_operator)
            
            logger.info("Proj Operator shape {}".format(proj_operator.shape))
            vis_aux = da.from_array(np.array(np.concatenate((np.real(vis_arr), np.imag(vis_arr))), dtype=np.float32))
            
            #logger.info("Solving...")

            
            en = dask_glm.regularizers.ElasticNet(weight=0.01)
            en =  dask_glm.regularizers.L2()
            #dT = da.from_array(proj_operator, chunks=(-1, 'auto'))
            ##dT = da.from_array(proj_operator, chunks=(-1, 'auto'))
            #dv = da.from_array(vis_aux)
            

            dask.config.set({'array.chunk-size': '1024MiB'})
            A = da.rechunk(proj_operator, chunks=('auto', n_s))
            A = client.persist(A)
            y = vis_aux # da.rechunk(vis_aux, chunks=('auto', n_s))
            y = client.persist(y)
            #sky = dask_glm.algorithms.proximal_grad(A, y, regularizer=en, lambduh=alpha, max_iter=10000)

            logger.info("Rechunking completed.. A= {}.".format(A.shape))
            reg =  LinearRegression(penalty=en, C=1.0/lambduh,  
                                    fit_intercept=False, 
                                    solver='lbfgs', 
                                    max_iter=1000, tol=1e-8 )
            sky = reg.fit(A, y)
            sky = reg.coef_
            score = reg.score(proj_operator, vis_aux)
            logger.info('Loss function: {}'.format(score.compute()))

        logger.info("Solving Complete: sky = {}".format(sky.shape))

        sphere.set_visible_pixels(sky, scale=True)
        return sky.reshape(-1,1)
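The dask branch again relies on da.block to stack the real and imaginary parts of gamma row-wise; a one-liner showing the shape behavior the augmentation depends on:

import numpy as np
import dask.array as da

g = da.from_array(np.arange(6).reshape(2, 3) * (1 + 1j), chunks=(2, 3))
aug = da.block([[da.real(g)], [da.imag(g)]])  # rows double: (2, 3) -> (4, 3)
assert aug.shape == (4, 3)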
Example #11
def make_gridding_convolution_function(mxds, gcf_parms, grid_parms, sel_parms):
    """
    Currently creates a gcf to correct for the primary beams of the antennas and supports heterogeneous arrays (antennas with different dish sizes).
    Only the airy disk and CASA airy disk models are implemented.
    In the future, support will be added for beam squint, pointing corrections, w projection, and the inclusion of a prolate spheroidal term.
    
    Parameters
    ----------
    mxds : xarray.core.dataset.Dataset
        Input multi-xarray dataset; the visibility xds to use is selected with sel_parms['xds'].
    gcf_parms : dictionary
    gcf_parms['function'] : {'casa_airy'/'airy'}, default = 'casa_airy'
        The primary beam model used (a function of the dish diameter and blockage diameter).
    gcf_parms['list_dish_diameters']  : list of number, units = meter
        A list of unique antenna dish diameters.
    gcf_parms['list_blockage_diameters']  : list of number, units = meter
        A list of unique feed blockage diameters (must be the same length as gcf_parms['list_dish_diameters']).
    gcf_parms['unique_ant_indx']  : list of int
        A list of indices into the gcf_parms['list_dish_diameters'] and gcf_parms['list_blockage_diameters'] lists, one for each antenna.
    gcf_parms['image_phase_center']  : list of number, length = 2, units = radians
        The mosaic image phase center.
    gcf_parms['a_chan_num_chunk']  : int, default = 3
        The number of chunks in the channel dimension of the gridding convolution function data variable.
    gcf_parms['oversampling']  : list of int, length = 2, default = [10,10]
        The oversampling of the gridding convolution function.
    gcf_parms['max_support']  : list of int, length = 2, default = [15,15]
        The maximum allowable support of the gridding convolution function.
    gcf_parms['support_cut_level']  : number, default = 0.025
        The attenuation at which to truncate the gridding convolution function.
    gcf_parms['chan_tolerance_factor']  : number, default = 0.005
        The fractional bandwidth at which the frequency dependence of the primary beam can be ignored; it determines the number of frequencies for which to calculate a gridding convolution function. The number of channels equals the fractional bandwidth divided by gcf_parms['chan_tolerance_factor'].
    grid_parms : dictionary
    grid_parms['image_size'] : list of int, length = 2
        The image size (no padding).
    grid_parms['cell_size']  : list of number, length = 2, units = arcseconds
        The image cell size.
    sel_parms : dictionary
    sel_parms['xds'] : str
        The name of the visibility xds in mxds to use (required; xds names are not fixed, so there is no default).
    Returns
    -------
    gcf_dataset : xarray.core.dataset.Dataset
            
    """
    print(
        '######################### Start make_gridding_convolution_function #########################'
    )

    from ._imaging_utils._check_imaging_parms import _check_pb_parms
    from cngi._utils._check_parms import _check_sel_parms, _check_existence_sel_parms
    from ._imaging_utils._check_imaging_parms import _check_grid_parms, _check_gcf_parms
    from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel_2D, _create_prolate_spheroidal_image_2D
    from ._imaging_utils._remove_padding import _remove_padding
    import numpy as np
    import dask.array as da
    import copy, os
    import xarray as xr
    import itertools
    import dask
    import dask.array.fft as dafft
    import time

    import matplotlib.pylab as plt

    #Deep copy so that inputs are not modified
    _mxds = mxds.copy(deep=True)
    _gcf_parms = copy.deepcopy(gcf_parms)
    _grid_parms = copy.deepcopy(grid_parms)
    _sel_parms = copy.deepcopy(sel_parms)

    ##############Parameter Checking and Set Defaults##############
    assert (
        'xds' in _sel_parms
    ), "######### ERROR: xds must be specified in sel_parms"  #Can't have a default since xds names are not fixed.
    _vis_dataset = _mxds.attrs[sel_parms['xds']]

    _check_sel_parms(_vis_dataset, _sel_parms)

    #_gcf_parms['basline_ant'] = np.unique([_vis_dataset.ANTENNA1.max(axis=0), _vis_dataset.ANTENNA2.max(axis=0)], axis=0).T
    _gcf_parms['basline_ant'] = np.array(
        [_vis_dataset.ANTENNA1.values, _vis_dataset.ANTENNA2.values]).T

    _gcf_parms['freq_chan'] = _vis_dataset.chan.values
    _gcf_parms['pol'] = _vis_dataset.pol.values
    _gcf_parms['vis_data_chunks'] = _vis_dataset.DATA.chunks

    _gcf_parms['field_phase_dir'] = mxds.FIELD.PHASE_DIR[:,
                                                         0, :].data.compute()
    field_id = mxds.FIELD.field_id.data  #.compute()

    #print(_gcf_parms['field_phase_dir'])
    #_gcf_parms['field_phase_dir'] = np.array(global_dataset.FIELD_PHASE_DIR.values[:,:,vis_dataset.attrs['ddi']])

    assert (_check_gcf_parms(_gcf_parms)
            ), "######### ERROR: gcf_parms checking failed"
    assert (_check_grid_parms(_grid_parms)
            ), "######### ERROR: grid_parms checking failed"

    if _gcf_parms['function'] == 'airy':
        from ._imaging_utils._make_pb_symmetric import _airy_disk_rorder
        pb_func = _airy_disk_rorder
    elif _gcf_parms['function'] == 'casa_airy':
        from ._imaging_utils._make_pb_symmetric import _casa_airy_disk_rorder
        pb_func = _casa_airy_disk_rorder
    else:
        assert (
            False
        ), "######### ERROR: Only airy and casa_airy function has been implemented"

    #For now only a_term works
    _gcf_parms['a_term'] = True
    _gcf_parms['ps_term'] = False

    _gcf_parms['resize_conv_size'] = (_gcf_parms['max_support'] +
                                      1) * _gcf_parms['oversampling']
    #resize_conv_size = _gcf_parms['resize_conv_size']

    if _gcf_parms['ps_term'] == True:
        '''
        ps_term = _create_prolate_spheroidal_kernel_2D(_gcf_parms['oversampling'],np.array([7,7])) #This is only used with a_term == False. Support is hardcoded to 7 until old ps code is replaced by a general function.
        center = _grid_parms['image_center']
        center_embed = np.array(ps_term.shape)//2
        ps_term_padded = np.zeros(_grid_parms['image_size'])
        ps_term_padded[center[0]-center_embed[0]:center[0]+center_embed[0],center[1]-center_embed[1] : center[1]+center_embed[1]] = ps_term
        ps_term_padded_ifft = dafft.fftshift(dafft.ifft2(dafft.ifftshift(da.from_array(ps_term_padded))))

        ps_image = da.from_array(_remove_padding(_create_prolate_spheroidal_image_2D(_grid_parms['image_size_padded']),_grid_parms['image_size']),chunks=_grid_parms['image_size'])

        #Effectively no mapping needed if ps_term == True and a_term == False
        cf_baseline_map = np.zeros((len(_gcf_parms['basline_ant']),),dtype=int)
        cf_chan_map = np.zeros((len(_gcf_parms['freq_chan']),),dtype=int)
        cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int)
        '''

    if _gcf_parms['a_term'] == True:
        n_unique_ant = len(_gcf_parms['list_dish_diameters'])

        cf_baseline_map, pb_ant_pairs = create_cf_baseline_map(
            _gcf_parms['unique_ant_indx'], _gcf_parms['basline_ant'],
            n_unique_ant)

        cf_chan_map, pb_freq = create_cf_chan_map(
            _gcf_parms['freq_chan'], _gcf_parms['chan_tolerance_factor'])
        #print('****',pb_freq)
        pb_freq = da.from_array(
            pb_freq,
            chunks=np.ceil(len(pb_freq) / _gcf_parms['a_chan_num_chunk']))

        cf_pol_map = np.zeros(
            (len(_gcf_parms['pol']), ), dtype=int
        )  #create_cf_pol_map(), currently treating all pols the same
        pb_pol = da.from_array(np.array([0]), 1)

        n_chunks_in_each_dim = [pb_freq.numblocks[0], pb_pol.numblocks[0]]
        iter_chunks_indx = itertools.product(
            np.arange(n_chunks_in_each_dim[0]),
            np.arange(n_chunks_in_each_dim[1]))
        chan_chunk_sizes = pb_freq.chunks
        pol_chunk_sizes = pb_pol.chunks

        #print(pb_freq, pb_pol,pol_chunk_sizes)
        list_baseline_pb = []
        list_weight_baseline_pb_sqrd = []
        for c_chan, c_pol in iter_chunks_indx:
            #print('chan, pol ',c_chan,c_pol)
            _gcf_parms['ipower'] = 1
            delayed_baseline_pb = dask.delayed(make_baseline_patterns)(
                pb_freq.partitions[c_chan], pb_pol.partitions[c_pol],
                dask.delayed(pb_ant_pairs), dask.delayed(pb_func),
                dask.delayed(_gcf_parms), dask.delayed(_grid_parms))

            list_baseline_pb.append(
                da.from_delayed(
                    delayed_baseline_pb,
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _grid_parms['image_size_padded'][0],
                     _grid_parms['image_size_padded'][1]),
                    dtype=np.double))

            _gcf_parms['ipower'] = 2
            delayed_weight_baseline_pb_sqrd = dask.delayed(
                make_baseline_patterns)(pb_freq.partitions[c_chan],
                                        pb_pol.partitions[c_pol],
                                        dask.delayed(pb_ant_pairs),
                                        dask.delayed(pb_func),
                                        dask.delayed(_gcf_parms),
                                        dask.delayed(_grid_parms))

            list_weight_baseline_pb_sqrd.append(
                da.from_delayed(
                    delayed_weight_baseline_pb_sqrd,
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _grid_parms['image_size_padded'][0],
                     _grid_parms['image_size_padded'][1]),
                    dtype=np.double))

        baseline_pb = da.concatenate(list_baseline_pb, axis=1)
        weight_baseline_pb_sqrd = da.concatenate(list_weight_baseline_pb_sqrd,
                                                 axis=1)

    #Combine patterns and fft to obtain the gridding convolutional kernel

    dataset_dict = {}
    list_xarray_data_variables = []
    if (_gcf_parms['a_term'] == True) and (_gcf_parms['ps_term'] == True):
        conv_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(ps_term_padded_ifft *
                                                      baseline_pb,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))
        conv_weight_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))

        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(
            np.arange(n_chunks_in_each_dim[0]),
            np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
            delayed_kernels_and_support = dask.delayed(
                resize_and_calc_support)(
                    conv_kernel.partitions[:, c_chan, c_pol, :, :],
                    conv_weight_kernel.partitions[:, c_chan, c_pol, :, :],
                    dask.delayed(_gcf_parms), dask.delayed(_grid_parms))
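            # resize_and_calc_support returns a (kernel, weight_kernel,
            # support) tuple; indexing the Delayed object yields one
            # dependent Delayed per element, each wrapped below with an
            # explicit output shape.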
            list_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[0],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_weight_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[1],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_conv_support.append(
                da.from_delayed(
                    delayed_kernels_and_support[2],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol], 2),
                    dtype=int))

        conv_kernel = da.concatenate(list_conv_kernel, axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel, axis=1)
        conv_support = da.concatenate(list_conv_support, axis=1)
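        # conv_support holds the per-kernel support along x and y (the
        # trailing length-2 'xy' axis).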

        dataset_dict['SUPPORT'] = xr.DataArray(
            conv_support,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'xy'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l', 'm'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(
            weight_conv_kernel,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'u', 'v'])
    elif not _gcf_parms['a_term'] and _gcf_parms['ps_term']:
        support = np.array([7, 7])
        dataset_dict['SUPPORT'] = xr.DataArray(
            support[None, None, None, :],
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'xy'])
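        # No A-term: embed the precomputed prolate-spheroidal kernel at the
        # centre of an otherwise empty (1, 1, 1, nx, ny) array, with the
        # support hard-coded to 7x7.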
        conv_kernel = np.zeros((1, 1, 1, _gcf_parms['resize_conv_size'][0],
                                _gcf_parms['resize_conv_size'][1]))
        center = _gcf_parms['resize_conv_size'] // 2
        center_embed = np.array(ps_term.shape) // 2
        conv_kernel[0, 0, 0,
                    center[0] - center_embed[0]:center[0] + center_embed[0],
                    center[1] - center_embed[1]:center[1] +
                    center_embed[1]] = ps_term
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(ps_image, dims=['l', 'm'])
        # Uncomment to store the PS-only kernel as WEIGHT_CONV_KERNEL for testing:
        #dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(conv_kernel, dims=['conv_baseline','conv_chan','conv_pol','u','v'])
    elif _gcf_parms['a_term'] and not _gcf_parms['ps_term']:
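        # Same FFT recipe as the combined branch above, but without the PS
        # term multiplied in.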
        conv_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(baseline_pb,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))
        conv_weight_kernel = da.real(
            dafft.fftshift(dafft.fft2(dafft.ifftshift(weight_baseline_pb_sqrd,
                                                      axes=(3, 4)),
                                      axes=(3, 4)),
                           axes=(3, 4)))


        list_conv_kernel = []
        list_weight_conv_kernel = []
        list_conv_support = []
        iter_chunks_indx = itertools.product(
            np.arange(n_chunks_in_each_dim[0]),
            np.arange(n_chunks_in_each_dim[1]))
        for c_chan, c_pol in iter_chunks_indx:
            delayed_kernels_and_support = dask.delayed(
                resize_and_calc_support)(
                    conv_kernel.partitions[:, c_chan, c_pol, :, :],
                    conv_weight_kernel.partitions[:, c_chan, c_pol, :, :],
                    dask.delayed(_gcf_parms), dask.delayed(_grid_parms))
            list_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[0],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_weight_conv_kernel.append(
                da.from_delayed(
                    delayed_kernels_and_support[1],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol],
                     _gcf_parms['resize_conv_size'][0],
                     _gcf_parms['resize_conv_size'][1]),
                    dtype=np.double))
            list_conv_support.append(
                da.from_delayed(
                    delayed_kernels_and_support[2],
                    (len(pb_ant_pairs), chan_chunk_sizes[0][c_chan],
                     pol_chunk_sizes[0][c_pol], 2),
                    dtype=int))

        conv_kernel = da.concatenate(list_conv_kernel, axis=1)
        weight_conv_kernel = da.concatenate(list_weight_conv_kernel, axis=1)
        conv_support = da.concatenate(list_conv_support, axis=1)


        dataset_dict['SUPPORT'] = xr.DataArray(
            conv_support,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'xy'])
        dataset_dict['WEIGHT_CONV_KERNEL'] = xr.DataArray(
            weight_conv_kernel,
            dims=['conv_baseline', 'conv_chan', 'conv_pol', 'u', 'v'])
        dataset_dict['PS_CORR_IMAGE'] = xr.DataArray(
            da.from_array(np.ones(_grid_parms['image_size']),
                          chunks=_grid_parms['image_size']),
            dims=['l', 'm'])
    else:
        raise ValueError(
            "######### ERROR: At least one of 'a_term' or 'ps_term' must be True."
        )

    ###########################################################
    #Make phase gradient (one for each field)
    field_phase_dir = _gcf_parms['field_phase_dir']
    field_phase_dir = da.from_array(
        field_phase_dir,
        chunks=(int(np.ceil(len(field_phase_dir) /
                            _gcf_parms['a_chan_num_chunk'])), 2))
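    # One phase gradient per field; the pointing list is chunked so that the
    # number of blockwise tasks stays close to a_chan_num_chunk.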

    phase_gradient = da.blockwise(make_phase_gradient,
                                  ("n_field", "n_x", "n_y"),
                                  field_phase_dir, ("n_field", "2"),
                                  gcf_parms=_gcf_parms,
                                  grid_parms=_grid_parms,
                                  dtype=complex,
                                  new_axes={
                                      "n_x": _gcf_parms['resize_conv_size'][0],
                                      "n_y": _gcf_parms['resize_conv_size'][1]
                                  })
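    # da.blockwise maps make_phase_gradient over the field chunks: the
    # length-2 pointing axis ("2") is consumed, and new_axes appends the full
    # (n_x, n_y) kernel plane, giving one complex gradient image per field.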

    ###########################################################


    coords = {
        'u': np.arange(_gcf_parms['resize_conv_size'][0]),
        'v': np.arange(_gcf_parms['resize_conv_size'][1]),
        'xy': np.arange(2),
        'field_id': field_id,
        'l': np.arange(_grid_parms['image_size'][0]),
        'm': np.arange(_grid_parms['image_size'][1])
    }
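    # The conv_baseline/conv_chan/conv_pol dimensions carry no coordinates;
    # the CF_*_MAP variables below translate visibility baseline, channel and
    # polarization indices into indices along these axes.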

    dataset_dict['CF_BASELINE_MAP'] = xr.DataArray(
        cf_baseline_map,
        dims='baseline').chunk(_gcf_parms['vis_data_chunks'][1])
    dataset_dict['CF_CHAN_MAP'] = xr.DataArray(
        cf_chan_map, dims='chan').chunk(_gcf_parms['vis_data_chunks'][2])
    dataset_dict['CF_POL_MAP'] = xr.DataArray(
        cf_pol_map, dims='pol').chunk(_gcf_parms['vis_data_chunks'][3])

    dataset_dict['CONV_KERNEL'] = xr.DataArray(conv_kernel,
                                               dims=('conv_baseline',
                                                     'conv_chan', 'conv_pol',
                                                     'u', 'v'))
    dataset_dict['PHASE_GRADIENT'] = xr.DataArray(phase_gradient,
                                                  dims=('field_id', 'u', 'v'))

    gcf_dataset = xr.Dataset(dataset_dict, coords=coords)
    gcf_dataset.attrs['cell_uv'] = 1 / (_grid_parms['image_size_padded'] *
                                        _grid_parms['cell_size'] *
                                        _gcf_parms['oversampling'])
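    # cell_uv is the uv-plane cell size of the oversampled kernel: the
    # reciprocal of the padded image extent (image_size_padded * cell_size),
    # divided by the oversampling factor.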
    gcf_dataset.attrs['oversampling'] = _gcf_parms['oversampling']


    print(
        '#########################  Created graph for make_gridding_convolution_function #########################'
    )

    return gcf_dataset
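
# Usage sketch (illustrative; assumes _gcf_parms/_grid_parms and the inputs
# have been prepared as above, and the argument names are hypothetical):
#   gcf = make_gridding_convolution_function(vis_dataset, gcf_parms, grid_parms)
#   gcf.SUPPORT.compute()  # per-kernel (x, y) support
#   gcf.CONV_KERNEL.isel(conv_baseline=0, conv_chan=0, conv_pol=0).compute()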