예제 #1
0
def test_plot_data():
    """Plot one Voronoi map of dTEC per time step from the DDS5 h5parm.

    Reads TEC (and the ``const`` / ``weights_tec`` soltabs) for the antenna
    selection ``[0, 50]`` over the first 100 time steps. For every time it
    builds an ENU frame located at the first selected antenna and plots the
    dTEC of the second selected antenna against the east/north components of
    each direction.
    """
    dp = DataPack(
        '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
        readonly=True)
    with dp:
        select = dict(pol=slice(0, 1, 1), ant=[0, 50], time=slice(0, 100, 1))
        dp.current_solset = 'sol000'
        dp.select(**select)
        tec_mean, axes = dp.tec
        tec_mean = tec_mean[0, ...]
        # const / weights_tec are read but not used below; the `axes` used
        # for the direction/antenna/time lookups come from weights_tec.
        const_mean, axes = dp.const
        const_mean = const_mean[0, ...]
        tec_std, axes = dp.weights_tec
        tec_std = tec_std[0, ...]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    ref_ant = antennas[0]

    for i, time in enumerate(times):
        frame = ENU(obstime=time, location=ref_ant.earth_location)
        # Always transform the original directions. Overwriting them with
        # the ENU result would make every step chain through the previous
        # step's frame.
        k = directions.transform_to(frame).cartesian.xyz.value.T

        # After dropping pol, tec_mean is presumably [dir, ant, time]
        # (TODO confirm); ant index 1 is the second selected antenna.
        dtec = tec_mean[:, 1, i]
        ax = plot_vornoi_map(k[:, 0:2], dtec)
        ax.set_title(time)
        plt.show()
def itrs_to_enu_6D(X, ref_location=None):
    """
    Map 6D (time, ra, dec, itrs.x, itrs.y, itrs.z) rows into a local ENU frame.

    :param X: float array [b0,...,bB,6]
        Columns are [time, ra, dec, itrs.x, itrs.y, itrs.z]. All rows must share
        one time value, which is divided by 86400 and read as MJD.
    :param ref_location: float array [3], optional
        ITRS position (in ``dist_type`` units) used as the ENU origin. When
        omitted, the antenna position of the first row is used.
    :return: float array [b0,...,bB,7]
        Columns are [time, 3 ENU direction components, 3 ENU antenna
        components expressed in ``dist_type``].
    :raises ValueError: if X holds more than one distinct time.
    """
    unique_times = np.unique(X[..., 0])
    if unique_times.size > 1:
        raise ValueError("Times should be the same.")

    batch_shape = X.shape[:-1]
    rows = X.reshape((-1, 6))
    origin = rows[0, 3:] if ref_location is None else ref_location

    obstime = at.Time(unique_times / 86400., format='mjd')
    origin_itrs = ac.ITRS(x=origin[0] * dist_type,
                          y=origin[1] * dist_type,
                          z=origin[2] * dist_type)
    enu_frame = ENU(location=origin_itrs, obstime=obstime)

    # Antenna positions: ITRS -> ENU, kept in dist_type units.
    antenna_itrs = ac.ITRS(x=rows[:, 3] * dist_type,
                           y=rows[:, 4] * dist_type,
                           z=rows[:, 5] * dist_type,
                           obstime=obstime)
    antenna_enu = antenna_itrs.transform_to(enu_frame)
    antenna_xyz = antenna_enu.cartesian.xyz.to(dist_type).value.T

    # Pointing directions: ICRS -> ENU unit vectors.
    direction_icrs = ac.ICRS(ra=rows[:, 1] * angle_type,
                             dec=rows[:, 2] * angle_type)
    direction_xyz = direction_icrs.transform_to(enu_frame).cartesian.xyz.value.T

    stacked = np.concatenate([rows[:, 0:1], direction_xyz, antenna_xyz], axis=1)
    return stacked.reshape(batch_shape + (7, ))
예제 #3
0
def visualisation(h5parm, ant=None, time=None):
    """Show a grid of Voronoi dTEC maps: one row per antenna, one column per time.

    :param h5parm: path of the h5parm to read (opened read-only).
    :param ant: antenna selection passed to ``DataPack.select``.
    :param time: time selection passed to ``DataPack.select``.

    The first selected antenna is the ENU frame origin and is not plotted;
    row ``a`` of the grid shows selected antenna ``a + 1``.
    """
    with DataPack(h5parm, readonly=True) as dp:
        dp.current_solset = 'sol000'
        dp.select(ant=ant, time=time)
        dtec, axes = dp.tec
        dtec = dtec[0]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])

    # `time` is the caller's selection (None, a slice, ...), not an epoch, so
    # the frame is anchored at the first selected time instead.
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    directions = directions.transform_to(frame)
    t = times.mjd * 86400.
    t -= t[0]
    # Positions of the plotted antennas: everything except the first one.
    x = antennas.cartesian.xyz.to(au.km).value.T[1:, :]
    k = directions.cartesian.xyz.value.T
    logger.info(f"Directions: {directions}")
    logger.info(f"Antennas: {x} {antenna_labels}")
    logger.info(f"Times: {t}")
    Na = x.shape[0]
    logger.info(f"Number of antenna to plot: {Na}")
    Nt = t.shape[0]

    fig, axs = plt.subplots(Na,
                            Nt,
                            sharex=True,
                            sharey=True,
                            figsize=(2 * Nt, 2 * Na),
                            squeeze=False)

    for a in range(Na):
        for i in range(Nt):
            # x[a] is selected antenna a + 1, so index dtec the same way to
            # keep the plotted data aligned with the logged positions.
            ax = plot_vornoi_map(k[:, 0:2],
                                 dtec[:, a + 1, i],
                                 ax=axs[a][i],
                                 colorbar=False)
            if a == (Na - 1):
                ax.set_xlabel(r"$k_{\rm east}$")
            if i == 0:
                ax.set_ylabel(r"$k_{\rm north}$")
            if a == 0:
                ax.set_title(f"Time: {int(t[i])} sec")

    plt.show()

    def transform(X, ref_location=None, ref_antenna=None, ref_direction=None):
        """
        Convert the given coordinates from ITRS to ENU.

        Not called by ``visualisation`` itself. The reference location,
        antenna and direction are explicit keyword arguments; previously
        they were free names that are not defined in this scope.

        :param X: float array [Nd, Na, 6]
            The coordinates are ordered [time, ra, dec, itrs.x, itrs.y, itrs.z]
        :param ref_location: float array [3], optional
            ITRS origin of the ENU frame (``dist_type`` units); defaults to
            the antenna position of the first row.
        :param ref_antenna: float array [3], optional
            ITRS reference-antenna position; when given, its ENU position is
            appended (3 extra columns) to every row.
        :param ref_direction: float array [2], optional
            (ra, dec) reference direction; when given, its ENU unit vector is
            appended (3 extra columns) to every row.
        :return: float array [Nd, Na, 7(10(13))]
            The transformed coordinates.
        :raises ValueError: if X holds more than one distinct time.
        """
        time = np.unique(X[..., 0])
        if time.size > 1:
            raise ValueError("Times should be the same.")
        shape = X.shape[:-1]
        X = X.reshape((-1, 6))
        if ref_location is None:
            ref_location = X[0, 3:6]
        obstime = at.Time(time / 86400., format='mjd')
        location = ac.ITRS(x=ref_location[0] * dist_type,
                           y=ref_location[1] * dist_type,
                           z=ref_location[2] * dist_type)
        enu = ENU(location=location, obstime=obstime)
        antennas = ac.ITRS(x=X[:, 3] * dist_type,
                           y=X[:, 4] * dist_type,
                           z=X[:, 5] * dist_type,
                           obstime=obstime)
        antennas = antennas.transform_to(enu).cartesian.xyz.to(
            dist_type).value.T
        directions = ac.ICRS(ra=X[:, 1] * angle_type, dec=X[:, 2] * angle_type)
        directions = directions.transform_to(enu).cartesian.xyz.value.T
        result = np.concatenate([X[:, 0:1], directions, antennas], axis=1)
        # Build the optional reference columns only when requested.
        if ref_antenna is not None:
            ref_ant = ac.ITRS(x=ref_antenna[0] * dist_type,
                              y=ref_antenna[1] * dist_type,
                              z=ref_antenna[2] * dist_type,
                              obstime=obstime)
            ref_ant = ref_ant.transform_to(enu).cartesian.xyz.to(
                dist_type).value.T
            result = np.concatenate(
                [result, np.tile(ref_ant, (result.shape[0], 1))], axis=-1)
        if ref_direction is not None:
            ref_dir = ac.ICRS(ra=ref_direction[0] * angle_type,
                              dec=ref_direction[1] * angle_type)
            ref_dir = ref_dir.transform_to(enu).cartesian.xyz.value.T
            result = np.concatenate(
                [result, np.tile(ref_dir, (result.shape[0], 1))], axis=-1)
        result = result.reshape(shape + result.shape[-1:])
        return result
예제 #5
0
def test_tomographic_kernel():
    """Sample and plot a dTEC screen drawn from ``TomographicKernel``.

    Builds an example datapack and expresses antennas and directions in an
    ENU frame located at antenna 0. It forms ray coordinates for antenna 50
    only and evaluates the kernel covariance over them. It then plots the
    covariance and one Cholesky-sampled dTEC realisation as a Voronoi map.
    """
    datapack = make_example_datapack(500, 24, 1, clobber=True)
    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(pol=slice(0, 1, 1), ant=slice(0, None, 1))
        tec_mean, axes = datapack.tec
        tec_mean = tec_mean[0, ...]
        patch_names, directions = datapack.get_directions(axes['dir'])
        antenna_labels, antennas = datapack.get_antennas(axes['ant'])
        timestamps, times = datapack.get_times(axes['time'])

    epoch = times[0]
    itrs_antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=epoch)
    frame = ENU(obstime=epoch, location=itrs_antennas[0].earth_location)
    enu_antennas = itrs_antennas.transform_to(frame)
    enu_directions = directions.transform_to(frame)

    antenna_km = enu_antennas.cartesian.xyz.to(au.km).value.T
    k = enu_directions.cartesian.xyz.value.T
    coords = make_coord_array(antenna_km[50:51, :], k)
    x0 = enu_antennas[0].cartesian.xyz.to(au.km).value
    print(k.shape)

    kernel = TomographicKernel(x0, x0, RBF(), S_marg=25)

    def covariance(points):
        return kernel(points,
                      points,
                      bottom=200.,
                      width=50.,
                      fed_kernel_params=dict(l=7., sigma=1.))

    K = jit(covariance)(jnp.asarray(coords))
    plt.imshow(K)
    plt.colorbar()
    plt.show()

    # Small diagonal jitter keeps the Cholesky factorisation stable.
    L = jnp.linalg.cholesky(K + 1e-6 * jnp.eye(K.shape[0]))
    print(L)
    white_noise = random.normal(random.PRNGKey(24532), shape=(K.shape[0], ))
    dtec = L @ white_noise
    print(jnp.std(dtec))

    ax = plot_vornoi_map(k[:, 0:2], dtec)
    ax.set_xlabel(r"$k_{\rm east}$")
    ax.set_ylabel(r"$k_{\rm north}$")
    ax.set_xlim(-0.1, 0.1)
    ax.set_ylim(-0.1, 0.1)
    plt.show()
def test_enu():
    """Check ENU frame transforms against ITRS and AltAz.

    Scatters 10 antennas (unit-variance offsets in ``dist_type`` units) about
    a point on the Earth's surface. The test then checks that:

    * they sit within 100 ``dist_type`` of the ENU origin;
    * ENU <-> AltAz round trips are consistent;
    * the east/north/up axes line up with AltAz (az=90,alt=0) /
      (az=0,alt=0) / (alt=90), both for dimensionless and dimensionful
      coordinates.
    """
    # The previous version indexed row 1 of a random (10, 3) array, giving a
    # 1-D array that `antennas[0, 0]` could not index. Its points were also
    # scattered about the Earth's centre, where an ENU frame is meaningless.
    # Use a seeded scatter about a surface point instead (approximately the
    # LOFAR core).
    rng = np.random.default_rng(0)
    core = ac.EarthLocation.from_geodetic(lon=6.869837 * au.deg,
                                          lat=52.915122 * au.deg,
                                          height=0. * au.m)
    core_xyz = np.array([c.to(dist_type).value
                         for c in (core.x, core.y, core.z)])
    antennas = core_xyz + rng.normal(size=[10, 3])

    obstime = at.Time("2018-01-01T00:00:00.000", format='isot')
    location = ac.ITRS(x=antennas[0, 0] * dist_type,
                       y=antennas[0, 1] * dist_type,
                       z=antennas[0, 2] * dist_type)
    enu = ENU(obstime=obstime, location=location.earth_location)
    altaz = ac.AltAz(obstime=obstime, location=location.earth_location)
    lofar_antennas = ac.ITRS(x=antennas[:, 0] * dist_type,
                             y=antennas[:, 1] * dist_type,
                             z=antennas[:, 2] * dist_type,
                             obstime=obstime)

    def to_dist(coord):
        """Cartesian xyz of `coord` as a plain array in dist_type units."""
        return coord.cartesian.xyz.to(dist_type).value

    assert np.all(
        np.linalg.norm(to_dist(lofar_antennas.transform_to(enu)), axis=0) <
        100.)
    assert np.all(
        np.isclose(
            to_dist(lofar_antennas.transform_to(enu)),
            to_dist(
                lofar_antennas.transform_to(enu).transform_to(
                    altaz).transform_to(enu))))
    assert np.all(
        np.isclose(
            to_dist(lofar_antennas.transform_to(altaz)),
            to_dist(
                lofar_antennas.transform_to(altaz).transform_to(
                    enu).transform_to(altaz))))

    # (east, north, up) in ENU paired with the matching (az, alt) in degrees.
    cardinals = [
        ((0., 1., 0.), (0, 0)),  # north
        ((1., 0., 0.), (90, 0)),  # east
        ((0., 0., 1.), (0, 90)),  # up
    ]

    # Dimensionless unit vectors.
    for (east, north, up), (az, alt) in cardinals:
        point_enu = ac.SkyCoord(east=east, north=north, up=up, frame=enu)
        point_altaz = ac.SkyCoord(az=az * au.deg,
                                  alt=alt * au.deg,
                                  distance=1.,
                                  frame=altaz)
        assert np.all(
            np.isclose(point_enu.transform_to(altaz).cartesian.xyz.value,
                       point_altaz.cartesian.xyz.value))
        assert np.all(
            np.isclose(point_enu.cartesian.xyz.value,
                       point_altaz.transform_to(enu).cartesian.xyz.value))

    # Same checks with lengths carrying dist_type units.
    for (east, north, up), (az, alt) in cardinals:
        point_enu = ac.SkyCoord(east=east * dist_type,
                                north=north * dist_type,
                                up=up * dist_type,
                                frame=enu)
        point_altaz = ac.SkyCoord(az=az * au.deg,
                                  alt=alt * au.deg,
                                  distance=1. * dist_type,
                                  frame=altaz)
        assert np.all(
            np.isclose(to_dist(point_enu.transform_to(altaz)),
                       to_dist(point_altaz)))
        assert np.all(
            np.isclose(to_dist(point_enu),
                       to_dist(point_altaz.transform_to(enu))))
예제 #7
0
def test_compare_with_forward_model():
    """Compare a Monte-Carlo TEC covariance with ``TomographicKernel``.

    Rays from one antenna (index 20) through every direction are discretised
    between heights ``bottom`` and ``bottom + width`` above the reference
    antenna. Free-electron-density samples drawn from ``M12`` at those points
    are integrated along each ray, and the empirical covariance of the
    results is plotted. The same quantity is then computed analytically with
    ``TomographicKernel(compute_tec=True)`` and plotted.
    """
    dp = make_example_datapack(5, 24, 1, clobber=True)
    with dp:
        select = dict(pol=slice(0, 1, 1), ant=slice(0, None, 1))
        dp.current_solset = 'sol000'
        dp.select(**select)
        tec_mean, axes = dp.tec
        tec_mean = tec_mean[0, ...]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    ref_ant = antennas[0]
    earth_centre = ac.ITRS(x=0 * au.m,
                           y=0 * au.m,
                           z=0. * au.m,
                           obstime=times[0])
    frame = ENU(obstime=times[0], location=ref_ant.earth_location)
    antennas = antennas.transform_to(frame)
    earth_centre = earth_centre.transform_to(frame)
    ref_ant = antennas[0]
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T[20:21, :]
    k = directions.cartesian.xyz.value.T
    X = make_coord_array(x, k)
    x0 = ref_ant.cartesian.xyz.to(au.km).value
    earth_centre = earth_centre.cartesian.xyz.to(au.km).value
    bottom = 200.
    width = 50.
    l = 10.
    sigma = 1.
    fed_kernel_params = dict(l=l, sigma=sigma)
    S_marg = 1000
    fed_kernel = M12()

    def get_points_on_rays(X):
        """Return S_marg + 1 points along one ray and the per-step length."""
        x = X[0:3]
        k = X[3:6]
        # Ray parameters where the ray enters / leaves the layer, measured
        # from the reference antenna's up coordinate.
        smin = (bottom - (x[2] - x0[2])) / k[2]
        smax = (bottom + width - (x[2] - x0[2])) / k[2]
        ds = (smax - smin)
        _x = x + k * smin
        _k = k * ds
        t = jnp.linspace(0., 1., S_marg + 1)
        return _x + _k * t[:, None], ds / S_marg

    points, ds = vmap(get_points_on_rays)(X)
    points = points.reshape((-1, 3))
    plt.scatter(points[:, 1], points[:, 2], marker='.')
    plt.show()

    K = fed_kernel(points, points, l, sigma)
    plt.imshow(K)
    plt.show()
    L = jnp.linalg.cholesky(K + 1e-6 * jnp.eye(K.shape[0]))
    Z = L @ random.normal(random.PRNGKey(1245214), (L.shape[0], 3000))
    # Points were generated ray by ray (one ray per row of X), so the leading
    # axis is the number of rays. The old code used the number of directions,
    # which only matched because a single antenna is selected above.
    Z = Z.reshape((X.shape[0], -1, Z.shape[1]))
    Y = jnp.sum(Z * ds[:, None, None], axis=1)
    K = jnp.mean(Y[:, None, :] * Y[None, :, :], axis=2)
    plt.imshow(K)
    plt.colorbar()
    plt.title("Directly Computed TEC Covariance")
    plt.show()

    kernel = TomographicKernel(x0,
                               earth_centre,
                               fed_kernel,
                               S_marg=200,
                               compute_tec=True)
    X1 = GeodesicTuple(x=X[:, 0:3],
                       k=X[:, 3:6],
                       t=jnp.zeros_like(X[:, :1]),
                       ref_x=x0)
    print(X1)
    K = kernel(X1, X1, bottom, width, fed_kernel_params)
    plt.imshow(K)
    plt.colorbar()
    plt.title("Analytic TEC Covariance")
    plt.show()
    print("Analytic TEC Covariance", K)
예제 #8
0
    def visualise_grid(self):
        """Plot the TEC of ``self._input_datapack`` on its antenna/direction grid.

        First, for directions 0-4 at time 0, scatter every antenna's ENU
        east/north position coloured by TEC. Then, for antennas 0, 50, 150,
        200 and 250, show the antenna's position over the layout, above a
        Voronoi map of its TEC across all directions in (RA, DEC) degrees.
        """
        for d in range(0, 5):
            with DataPack(self._input_datapack, readonly=True) as dp:
                dp.current_solset = 'sol000'
                dp.select(pol=slice(0, 1, 1), dir=d, time=0)
                tec_grid, axes = dp.tec
                tec_grid = tec_grid[0]
                patch_names, directions_grid = dp.get_directions(axes['dir'])
                antenna_labels, antennas_grid = dp.get_antennas(axes['ant'])
                # The first antenna is the ENU origin for both loops below.
                ref_ant = antennas_grid[0]
                timestamps, times_grid = dp.get_times(axes['time'])
                frame = ENU(location=ref_ant.earth_location,
                            obstime=times_grid[0])
                antennas_grid = ac.ITRS(
                    *antennas_grid.cartesian.xyz,
                    obstime=times_grid[0]).transform_to(frame)

            ant_pos = antennas_grid.cartesian.xyz.to(au.km).value.T
            plt.scatter(ant_pos[:, 0],
                        ant_pos[:, 1],
                        c=tec_grid[0, :, 0],
                        cmap=plt.cm.PuOr)
            plt.xlabel('East [km]')
            plt.ylabel("North [km]")
            plt.title(f"Direction {repr(directions_grid)}")
            plt.show()

        # Background layout for the per-antenna plots, coloured by the last
        # direction plotted above.
        ant_scatter_args = (ant_pos[:, 0], ant_pos[:, 1], tec_grid[0, :, 0])

        for a in [0, 50, 150, 200, 250]:
            with DataPack(self._input_datapack, readonly=True) as dp:
                dp.current_solset = 'sol000'
                dp.select(pol=slice(0, 1, 1), ant=a, time=0)
                tec_grid, axes = dp.tec
                tec_grid = tec_grid[0]
                patch_names, directions_grid = dp.get_directions(axes['dir'])
                antenna_labels, antennas_grid = dp.get_antennas(axes['ant'])
                timestamps, times_grid = dp.get_times(axes['time'])
                frame = ENU(location=ref_ant.earth_location,
                            obstime=times_grid[0])
                antennas_grid = ac.ITRS(
                    *antennas_grid.cartesian.xyz,
                    obstime=times_grid[0]).transform_to(frame)

            _ant_pos = antennas_grid.cartesian.xyz.to(au.km).value.T[0]

            fig, axs = plt.subplots(2, 1, figsize=(4, 8))
            axs[0].scatter(*ant_scatter_args[0:2],
                           c=ant_scatter_args[2],
                           cmap=plt.cm.PuOr,
                           alpha=0.5)
            axs[0].scatter(*_ant_pos[0:2], marker='x', c='red')
            axs[0].set_xlabel('East [km]')
            axs[0].set_ylabel('North [km]')
            # Radians -> degrees, so the axis labels below are in degrees.
            pos = 180 / np.pi * np.stack(
                [wrap(directions_grid.ra.rad),
                 wrap(directions_grid.dec.rad)],
                axis=-1)
            plot_vornoi_map(pos, tec_grid[:, 0, 0], fov_circle=True, ax=axs[1])
            axs[1].set_xlabel('RA(2000) [deg]')
            axs[1].set_ylabel('DEC(2000) [deg]')
            plt.show()
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches):
    """Fit ``NeuralTomographicKernel`` to ``TomographicKernel`` by gradient descent.

    Coordinates are built from every antenna, the datapack's directions plus
    250 random extra directions, and the first time step, all expressed in
    an ENU frame located at antenna 0. Each step samples ``batch_size`` of
    those coordinates and random layer/wind parameters, then takes one plain
    gradient-descent step on the width-normalised mean squared difference
    between the two covariance matrices.

    :param datapack: source of the antenna/direction/time axes.
    :param batch_size: number of coordinates per loss evaluation.
    :param learning_rate: gradient-descent step size.
    :param num_batches: number of training steps.
    :return: the trained neural-kernel parameters (previously discarded;
        callers that ignored the ``None`` return are unaffected).
    """
    # tree_multimap was removed from JAX; tree_map accepts several trees.
    from jax.tree_util import tree_map

    with datapack:
        select = dict(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        datapack.current_solset = 'sol000'
        datapack.select(**select)
        axes = datapack.axes_tec
        patch_names, directions = datapack.get_directions(axes['dir'])
        antenna_labels, antennas = datapack.get_antennas(axes['ant'])
        timestamps, times = datapack.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    ref_ant = antennas[0]
    frame = ENU(obstime=times[0], location=ref_ant.earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    # Copy first so the in-place shifts below cannot write into an array
    # owned (and possibly cached) by `times`.
    t = np.array(times.mjd)
    t -= t[len(t) // 2]
    t *= 86400.
    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar /= jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(
        make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))
    x0 = jnp.asarray(antennas.cartesian.xyz.to(au.km).value.T[0, :])
    ref_ant = x0

    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        """Width-normalised MSE between the reference and neural covariances."""
        keys = random.split(key, 5)
        indices = random.permutation(keys[0],
                                     jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]

        # Horizontal wind only (the vertical component is pinned to 0).
        wind_velocity = random.uniform(keys[1],
                                       shape=(3, ),
                                       minval=jnp.asarray([-200., -200., 0.]),
                                       maxval=jnp.asarray([200., 200., 0.
                                                           ])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        l = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch,
                   X_batch,
                   bottom,
                   width,
                   l,
                   sigma,
                   wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch,
                                 X_batch,
                                 bottom,
                                 width,
                                 l,
                                 sigma,
                                 wind_velocity=wind_velocity)

        return jnp.mean((K - neural_K)**2) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        """One gradient-descent step; returns (new_params, loss)."""
        l, g = value_and_grad(lambda params: loss(params, key))(params)
        params = tree_map(lambda p, g: p - learning_rate * g, params, g)
        return params, l

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(
            random.PRNGKey(42))

    plt.plot(losses)
    plt.yscale('log')
    plt.show()
    return final_params
예제 #10
0
    def run(self, output_h5parm, ncpu, avg_direction_spacing,
            field_of_view_diameter, duration, time_resolution, start_time,
            array_name, phase_tracking):
        """Simulate a dTEC screen and write it to a new h5parm.

        Creates an empty datapack, then builds geodesic coordinates for every
        (time, antenna, direction) in the reference-antenna ENU frame at the
        first time. It evaluates the kernel mean and covariance, draws one
        Gaussian sample via Cholesky (falling back to ``msqrt`` on NaNs), and
        stores it as ``dp.tec``.

        :param output_h5parm: path of the h5parm to create (clobbered).
        :param ncpu: chunk size passed to ``chunked_pmap``.
        :param avg_direction_spacing: used with ``field_of_view_diameter`` to
            choose the number of directions.
        :param field_of_view_diameter: field of view passed to the datapack.
        :param duration: total duration; with ``time_resolution`` gives
            ``int(duration / time_resolution) + 1`` times.
        :param time_resolution: spacing between time steps.
        :param start_time: first time of the datapack.
        :param array_name: key into ``ARRAYS`` selecting the array file.
        :param phase_tracking: phase centre (anything with ``.ra.deg`` and
            ``.dec.deg``).
        """
        Nd = get_num_directions(
            avg_direction_spacing,
            field_of_view_diameter,
        )
        Nf = 2  # 8000
        Nt = int(duration / time_resolution) + 1
        min_freq = 700.
        max_freq = 2000.
        dp = create_empty_datapack(
            Nd,
            Nf,
            Nt,
            pols=None,
            field_of_view_diameter=field_of_view_diameter,
            start_time=start_time,
            time_resolution=time_resolution,
            min_freq=min_freq,
            max_freq=max_freq,
            array_file=ARRAYS[array_name],
            phase_tracking=(phase_tracking.ra.deg, phase_tracking.dec.deg),
            save_name=output_h5parm,
            clobber=True)

        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            axes = dp.axes_tec
            patch_names, directions = dp.get_directions(axes['dir'])
            antenna_labels, antennas = dp.get_antennas(axes['ant'])
            timestamps, times = dp.get_times(axes['time'])
            ref_ant = antennas[0]
            ref_time = times[0]

        Na = len(antennas)
        Nd = len(directions)
        Nt = len(times)

        logger.info(f"Number of directions: {Nd}")
        logger.info(f"Number of antennas: {Na}")
        logger.info(f"Number of times: {Nt}")
        logger.info(f"Reference Ant: {ref_ant}")
        logger.info(f"Reference Time: {ref_time.isot}")

        # Fixed East-North-Up frame at the reference antenna and first time.
        ref_frame = ENU(obstime=ref_time, location=ref_ant.earth_location)

        x0 = ac.ITRS(
            *antennas[0].cartesian.xyz,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        earth_centre_x = ac.ITRS(
            x=0 * au.m, y=0 * au.m, z=0. * au.m,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        self._kernel = TomographicKernel(x0,
                                         earth_centre_x,
                                         M32(),
                                         S_marg=20,
                                         compute_tec=False)

        k = directions.transform_to(ref_frame).cartesian.xyz.value.T

        t = times.mjd * 86400.
        t -= t[0]

        X1 = GeodesicTuple(x=[], k=[], t=[], ref_x=[])

        logger.info("Computing coordinates in frame ...")

        for i, time in tqdm(enumerate(times)):
            x = ac.ITRS(*antennas.cartesian.xyz,
                        obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                            au.km).value.T
            ref_ant_x = ac.ITRS(
                *ref_ant.cartesian.xyz,
                obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                    au.km).value

            # Columns: x (0:3), k (3:6), t (6:7), reference antenna (7:10).
            X = make_coord_array(x,
                                 k,
                                 t[i:i + 1, None],
                                 ref_ant_x[None, :],
                                 flat=True)

            X1.x.append(X[:, 0:3])
            X1.k.append(X[:, 3:6])
            X1.t.append(X[:, 6:7])
            # The reference antenna is a 3-vector; 7:8 kept only its x part.
            X1.ref_x.append(X[:, 7:10])

        X1 = X1._replace(
            x=jnp.concatenate(X1.x, axis=0),
            k=jnp.concatenate(X1.k, axis=0),
            t=jnp.concatenate(X1.t, axis=0),
            ref_x=jnp.concatenate(X1.ref_x, axis=0),
        )

        logger.info(f"Total number of coordinates: {X1.x.shape[0]}")

        def compute_covariance_row(X1: GeodesicTuple, X2: GeodesicTuple):
            """Kernel between a single coordinate X1 and all coordinates X2."""
            K = self._kernel(X1,
                             X2,
                             self._bottom,
                             self._width,
                             self._fed_sigma,
                             self._fed_kernel_params,
                             wind_velocity=self._wind_vector)  # 1, N
            return K[0, :]

        def covariance_row(X):
            return compute_covariance_row(
                tree_map(lambda x: x.reshape((1, -1)), X), X1)

        mean = jit(lambda X1: self._kernel.mean_function(X1,
                                                         self._bottom,
                                                         self._width,
                                                         self._fed_mu,
                                                         wind_velocity=self.
                                                         _wind_vector))(X1)

        cov = chunked_pmap(covariance_row,
                           X1,
                           batch_size=X1.x.shape[0],
                           chunksize=ncpu)

        plt.imshow(cov)
        plt.show()

        Z = random.normal(random.PRNGKey(42), (cov.shape[0], 1),
                          dtype=cov.dtype)

        t0 = default_timer()
        jitter = 1e-6
        logger.info(f"Computing Cholesky with jitter: {jitter}")
        L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
        if np.any(np.isnan(L)):
            logger.info("Numerically instable. Using SVD.")
            L = msqrt(cov)

        logger.info(f"Cholesky took {default_timer() - t0} seconds.")

        # Rows were concatenated time-outermost, so the flat sample is laid
        # out (Nt, Na, Nd). Within a time, antenna-before-direction order
        # follows the original (Na, Nd, .) layout — presumably
        # make_coord_array's order, TODO confirm. Result: (Nd, Na, Nt).
        dtec = (L @ Z + mean[:, None])[:, 0].reshape((Nt, Na, Nd)).transpose(
            (2, 1, 0))

        logger.info(f"Saving result to {output_h5parm}")
        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            dp.tec = np.asarray(dtec[None])