Example #1
0
def test_plot_data():
    """
    Show one Voronoi map of TEC per selected time step.

    Opens a hard-coded H5parm read-only, selects the first polarisation,
    antennas 0 and 50 and the first 100 time steps of solset 'sol000', then
    for every time plots the TEC of the second selected antenna against the
    first two ENU direction components.

    Each figure is displayed with ``plt.show()``; nothing is returned.
    """
    selection = dict(pol=slice(0, 1, 1), ant=[0, 50], time=slice(0, 100, 1))
    dp = DataPack(
        '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
        readonly=True)
    with dp:
        dp.current_solset = 'sol000'
        dp.select(**selection)
        # Keep only the first (and only selected) polarisation of each soltab.
        tec_values, _ = dp.tec
        tec_mean = tec_values[0, ...]
        # const and weights are read but not used by the plot below.
        const_values, _ = dp.const
        const_mean = const_values[0, ...]
        weight_values, axes = dp.weights_tec
        tec_std = weight_values[0, ...]
        _, directions = dp.get_directions(axes['dir'])
        _, antennas = dp.get_antennas(axes['ant'])
        _, times = dp.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    # The first selected antenna fixes the origin of every ENU frame.
    origin = antennas[0].earth_location

    for step, obstime in enumerate(times):
        enu_frame = ENU(obstime=obstime, location=origin)
        # Coordinates are re-bound on purpose: each step transforms the
        # result of the previous step into the new frame.
        antennas = antennas.transform_to(enu_frame)
        directions = directions.transform_to(enu_frame)
        # Antenna positions in km; computed but not used by the plot.
        x = antennas.cartesian.xyz.to(au.km).value.T
        k = directions.cartesian.xyz.value.T

        # Index 1 on the antenna axis is the second selected antenna.
        axis = plot_vornoi_map(k[:, 0:2], tec_mean[:, 1, step])
        axis.set_title(obstime)
        plt.show()
Example #2
0
def create_empty_datapack_spec(directions: ac.ICRS,
                               Nf,
                               Nt,
                               pols=None,
                               start_time=None,
                               time_resolution=30.,
                               min_freq=122.,
                               max_freq=166.,
                               array_file=None,
                               save_name='test_datapack.h5',
                               clobber=False) -> DataPack:
    """
    Creates an empty datapack with phase, amplitude and DTEC.

    Solset 'sol000' is created with three soltabs: 'phase000' (all zeros),
    'amplitude000' (all ones) and 'tec000' (all zeros). Phase and amplitude
    have shape (Npol, Nd, Na, Nf, Nt) and tec has shape (Npol, Nd, Na, Nt),
    where Nd = len(directions) and Na is the number of antennas in the array.

    Args:
        directions: ICRS coordinates, one per direction; their ra/dec in
            radians are stored as the datapack directions.
        Nf: number of frequencies
        Nt: number of times
        pols: polarisations, ['XX', ...]. Must be a list or tuple;
            defaults to ['XX'].
        start_time: start time in modified Julian days (mjs/86400);
            defaults to 2019-01-01T00:00:00.
        time_resolution: time step in seconds.
        min_freq: minimum frequency in MHz
        max_freq: maximum frequency in MHz
        array_file: array file else Lofar HBA is used
        save_name: where to save the H5parm.
        clobber: Whether to overwrite. An existing file at save_name is only
            deleted when this is True.

    Returns:
        DataPack
    """

    logger.info("=== Creating empty datapack ===")
    Nd = len(directions)

    save_name = os.path.abspath(save_name)
    if os.path.isfile(save_name) and clobber:
        os.unlink(save_name)

    if array_file is None:
        array_file = DataPack.lofar_array_hba

    if start_time is None:
        start_time = at.Time("2019-01-01T00:00:00.000", format='isot').mjd

    if pols is None:
        pols = ['XX']
    assert isinstance(pols, (tuple, list))

    time0 = at.Time(start_time, format='mjd')

    datapack = DataPack(save_name, readonly=False)
    with datapack:
        datapack.add_solset('sol000', array_file=array_file)
        datapack.set_directions(
            None, np.stack([directions.ra.rad, directions.dec.rad], axis=1))

        patch_names, _ = datapack.directions
        antenna_labels, _ = datapack.antennas
        _, antennas = datapack.get_antennas(antenna_labels)
        # Only the antenna count is needed to size the arrays below.
        antennas = antennas.cartesian.xyz.to(au.km).value.T
        Na = antennas.shape[0]

        # Regularly spaced time axis starting at time0, in seconds (mjs).
        times = at.Time(time0.mjd + (np.arange(Nt) * time_resolution) / 86400.,
                        format='mjd').mjd * 86400.  # mjs
        # MHz -> Hz
        freqs = np.linspace(min_freq, max_freq, Nf) * 1e6

        Npol = len(pols)
        dtecs = np.zeros((Npol, Nd, Na, Nt))
        phase = np.zeros((Npol, Nd, Na, Nf, Nt))
        amp = np.ones_like(phase)

        datapack.add_soltab('phase000',
                            values=phase,
                            ant=antenna_labels,
                            dir=patch_names,
                            time=times,
                            freq=freqs,
                            pol=pols)
        datapack.add_soltab('amplitude000',
                            values=amp,
                            ant=antenna_labels,
                            dir=patch_names,
                            time=times,
                            freq=freqs,
                            pol=pols)
        datapack.add_soltab('tec000',
                            values=dtecs,
                            ant=antenna_labels,
                            dir=patch_names,
                            time=times,
                            pol=pols)
        return datapack
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches):
    """
    Fit a NeuralTomographicKernel to a TomographicKernel with plain SGD.

    Antenna positions, directions and times are read from solset 'sol000' of
    the datapack (first polarisation, all antennas, first time step) and
    expressed in an ENU frame anchored at the first antenna. Each step draws a
    random batch of coordinates and random kernel hyper-parameters, and
    minimises the mean squared difference between the two kernels.

    Args:
        datapack: DataPack supplying the coordinates.
        batch_size: number of coordinate rows sampled per step.
        learning_rate: gradient descent step size.
        num_batches: number of training steps.

    Returns:
        None. The per-step losses are plotted on a log scale and shown.
    """
    selection = dict(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(**selection)
        axes = datapack.axes_tec
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    # ENU frame at the first time, located at the first antenna.
    first_time = times[0]
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=first_time)
    frame = ENU(obstime=first_time, location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)

    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    # Seconds relative to the middle time sample.
    mjd = times.mjd
    t = (mjd - mjd[len(mjd) // 2]) * 86400.

    # Position of the first antenna; used both as x0 and as reference antenna.
    x0 = jnp.asarray(x[0, :])

    # Extra unit-vector directions drawn inside the bounding box of k.
    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar = kstar / jnp.linalg.norm(kstar, axis=-1, keepdims=True)

    all_directions = jnp.concatenate([k, kstar], axis=0)
    X = jnp.asarray(make_coord_array(x, all_directions, t[:, None]))

    kernel = TomographicKernel(x0, x0, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, x0)

    def loss(params, key):
        # One independent key per random draw below.
        batch_key, wind_key, bottom_key, width_key, scale_key = random.split(
            key, 5)

        shuffled = random.permutation(batch_key, jnp.arange(X.shape[0]))
        X_batch = X[shuffled[:batch_size], :]

        wind_low = jnp.asarray([-200., -200., 0.])
        wind_high = jnp.asarray([200., 200., 0.])
        wind_velocity = random.uniform(
            wind_key, shape=(3, ), minval=wind_low, maxval=wind_high) / 1000.
        bottom = random.uniform(bottom_key, minval=50., maxval=500.)
        width = random.uniform(width_key, minval=40., maxval=300.)
        l = random.uniform(scale_key, minval=1., maxval=30.)
        sigma = 1.

        target_K = kernel(X_batch,
                          X_batch,
                          bottom,
                          width,
                          l,
                          sigma,
                          wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch,
                                 X_batch,
                                 bottom,
                                 width,
                                 l,
                                 sigma,
                                 wind_velocity=wind_velocity)

        squared_error = (target_K - neural_K)**2
        return jnp.mean(squared_error) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        # Differentiates with respect to params only (first argument).
        batch_loss, grads = value_and_grad(loss)(params, key)
        updated = tree_multimap(lambda p, g: p - learning_rate * g, params,
                                grads)
        return updated, batch_loss

    def run_training(key):
        step_keys = random.split(key, num_batches)
        return scan(train_one_batch, init_params, step_keys)

    final_params, losses = jit(run_training)(random.PRNGKey(42))

    plt.plot(losses)
    plt.yscale('log')
    plt.show()