Example #1
0
def test_plot_data():
    """Show one Voronoi map of dTEC per time step from a hard-coded DDS5 h5parm.

    Selects the first polarisation, ``ant=[0, 50]`` and the first 100 time
    steps of solset ``sol000``. For every time step it projects the
    directions into an ENU frame anchored at the first selected antenna and
    plots the dTEC of the second selected antenna over those directions.
    """
    dp = DataPack(
        '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
        readonly=True)
    with dp:
        select = dict(pol=slice(0, 1, 1), ant=[0, 50], time=slice(0, 100, 1))
        dp.current_solset = 'sol000'
        dp.select(**select)
        tec_mean, axes = dp.tec
        # Drop the polarisation axis; presumably leaves (dir, ant, time) —
        # TODO confirm against DataPack.tec.
        tec_mean = tec_mean[0, ...]
        const_mean, axes = dp.const
        const_mean = const_mean[0, ...]
        tec_std, axes = dp.weights_tec
        tec_std = tec_std[0, ...]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    ref_ant = antennas[0]

    for i, time in enumerate(times):
        frame = ENU(obstime=time, location=ref_ant.earth_location)
        # Transform the original directions every step rather than
        # re-transforming the previous step's ENU output.
        k = directions.transform_to(frame).cartesian.xyz.value.T

        dtec = tec_mean[:, 1, i]
        # Only the first two ENU components are used for the 2-D map.
        ax = plot_vornoi_map(k[:, 0:2], dtec)
        ax.set_title(time)
        plt.show()
Example #2
0
def _read_solutions(h5parm, select, readonly=True):
    """Load ``sol000`` phase/amplitude together with their direction and time axes.

    Args:
        h5parm: Path of the h5parm to read.
        select: Keyword selection forwarded to ``DataPack.select``.
        readonly: Forwarded to the ``DataPack`` constructor.

    Returns:
        ``(directions, time_seconds, phase, amplitude)`` where ``directions``
        is an ``(Ndir, 2)`` array of (ra, dec) in radians, ``time_seconds`` is
        the time axis as MJD multiplied by 86400, and ``phase`` / ``amplitude``
        are the arrays returned by ``DataPack.phase`` / ``DataPack.amplitude``.
    """
    datapack = DataPack(h5parm, readonly=readonly)
    datapack.current_solset = 'sol000'
    datapack.select(**select)
    axes = datapack.axes_phase
    _, directions = datapack.get_directions(axes['dir'])
    radec = np.stack([directions.ra.rad, directions.dec.rad], axis=1)
    _, times = datapack.get_times(axes['time'])
    time_seconds = times.mjd * 86400.
    phase, _ = datapack.phase
    amplitude, _ = datapack.amplitude
    return radec, time_seconds, phase, amplitude


def _nearest_index(reference, targets):
    """Return, for each value in ``targets``, the index of the closest ``reference`` value.

    Ties resolve to the lowest index (``np.argmin`` semantics). The result is
    always an integer array, even when ``targets`` is empty.
    """
    reference = np.asarray(reference)
    targets = np.asarray(targets)
    return np.argmin(np.abs(reference[None, :] - targets[:, None]), axis=1)


def main(data_dir, working_dir, obs_num):
    """Merge the slow-calibration gains into the screen and smoothed solutions.

    Reads from ``data_dir``:

    * ``L{obs_num}_DDS5_full_merged.h5`` (smoothed),
    * ``L{obs_num}_DDS6_full_merged.h5`` (screen),
    * ``L{obs_num}_DDS7_full_slow_merged.h5`` (slow).

    Writes solsets ``screen_slow`` and ``smoothed_slow`` to
    ``working_dir/L{obs_num}_DDS8_full_merged.h5``, which is passed to
    ``link_overwrite`` together with the same basename in ``data_dir``.

    Slow phases are added to, and slow amplitudes multiplied with, the
    screen/smoothed ones. Each screen/smoothed time uses the slow solution at
    the nearest time. Each screen direction uses the slow solution of the
    nearest calibrator by great-circle separation. Smoothed solutions are
    assumed to share the slow direction axis. NaNs in the screen+slow result
    are replaced by 0 (phase) and 1 (amplitude).
    """
    logger.info("Merging slow solutions into screen and smoothed.")

    smoothed_h5parm = os.path.join(data_dir,
                                   'L{}_DDS5_full_merged.h5'.format(obs_num))
    screen_h5parm = os.path.join(data_dir,
                                 'L{}_DDS6_full_merged.h5'.format(obs_num))
    slow_h5parm = os.path.join(data_dir,
                               'L{}_DDS7_full_slow_merged.h5'.format(obs_num))
    merged_h5parm = os.path.join(working_dir,
                                 'L{}_DDS8_full_merged.h5'.format(obs_num))
    linked_merged_h5parm = os.path.join(data_dir,
                                        os.path.basename(merged_h5parm))
    link_overwrite(merged_h5parm, linked_merged_h5parm)

    select = dict(pol=slice(0, 1, 1))

    logger.info("Getting slow000/phase000+amplitude000")
    directions_slow, time_slow, phase_slow, amplitude_slow = _read_solutions(
        slow_h5parm, select)

    logger.info("Getting smoothed sol000/phase000+amplitude000")
    _, time_smoothed, phase_smoothed, amplitude_smoothed = _read_solutions(
        smoothed_h5parm, select)

    logger.info("Getting screen_posterior/phase000+amplitude000")
    # NOTE(review): opened writable as before, although nothing here writes
    # through this handle.
    (directions_screen, time_screen, phase_screen,
     amplitude_screen) = _read_solutions(screen_h5parm, select, readonly=False)

    ###
    # Create the output solsets (structure copied from the inputs' sol000/phase000)

    logger.info("Creating screen_slow/phase000+amplitude000")
    make_soltab(screen_h5parm,
                from_solset='sol000',
                to_solset='screen_slow',
                from_soltab='phase000',
                to_soltab=['phase000', 'amplitude000'],
                remake_solset=True,
                to_datapack=merged_h5parm)
    logger.info("Creating smoothed_slow/phase000+amplitude000")
    make_soltab(smoothed_h5parm,
                from_solset='sol000',
                to_solset='smoothed_slow',
                from_soltab='phase000',
                to_soltab=['phase000', 'amplitude000'],
                remake_solset=True,
                to_datapack=merged_h5parm)

    logger.info("Creating time mapping")
    # Separate mappings: the smoothed and screen time axes need not coincide.
    time_map_screen = _nearest_index(time_slow, time_screen)
    time_map_smoothed = _nearest_index(time_slow, time_smoothed)
    logger.info("Creating direction mapping")
    dir_map = np.asarray([
        jnp.argmin(
            great_circle_sep(directions_slow[:, 0], directions_slow[:, 1], ra,
                             dec))
        for (ra, dec) in zip(directions_screen[:, 0], directions_screen[:, 1])
    ])
    #TODO: see if only applying slow on calibrators and screen elsewhere gets rid of artefacts.
    #TODO: see if only applying tec screen gets rid of artefacts (include slow if the above experiment doesn't work)
    #TODO: visibility flagging based on tec outliers (update imaging command)
    # The time axis is last; the direction axis is second (after pol).
    phase_smooth_slow = phase_slow[..., time_map_smoothed] + phase_smoothed
    amplitude_smooth_slow = (amplitude_slow[..., time_map_smoothed] *
                             amplitude_smoothed)

    phase_screen_slow = (phase_screen +
                         phase_slow[..., time_map_screen][:, dir_map, ...])
    amplitude_screen_slow = (
        amplitude_screen * amplitude_slow[..., time_map_screen][:, dir_map,
                                                                ...])

    logger.info("Phase screen+slow contains {} nans.".format(
        np.isnan(phase_screen_slow).sum()))
    phase_screen_slow = np.where(np.isnan(phase_screen_slow), 0.,
                                 phase_screen_slow)
    logger.info("Amplitude screen+slow contains {} nans.".format(
        np.isnan(amplitude_screen_slow).sum()))
    amplitude_screen_slow = np.where(np.isnan(amplitude_screen_slow), 1.,
                                     amplitude_screen_slow)

    logger.info("Saving results to {}".format(merged_h5parm))
    datapack = DataPack(merged_h5parm, readonly=False)
    datapack.current_solset = 'screen_slow'
    datapack.select(**select)
    datapack.phase = phase_screen_slow
    datapack.amplitude = amplitude_screen_slow

    datapack.current_solset = 'smoothed_slow'
    datapack.select(**select)
    datapack.phase = phase_smooth_slow
    datapack.amplitude = amplitude_smooth_slow
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches, n_screen=250, plot=True):
    """Train a ``NeuralTomographicKernel`` to reproduce a ``TomographicKernel``.

    Coordinates are built from the datapack's ``sol000`` antennas and
    directions at the first time step, in an ENU frame anchored at the first
    antenna. Antenna positions are in km. The directions are the calibrator
    directions plus ``n_screen`` random unit vectors inside their bounding
    box. Each step draws a random batch of ``batch_size`` coordinates and
    random kernel hyperparameters. It minimises the mean squared difference
    between the two kernel matrices, divided by ``width**2``, using plain
    gradient descent with step ``learning_rate``.

    Args:
        datapack: Source of antenna, direction and time axes.
        batch_size: Number of coordinate rows per loss evaluation.
        learning_rate: Gradient-descent step size.
        num_batches: Number of optimisation steps.
        n_screen: Number of random extra directions added to the calibrators.
        plot: If True, show the loss curve (log scale) with matplotlib.

    Returns:
        ``(final_params, losses)``: the trained neural-kernel parameters and
        the per-step loss values.
    """
    with datapack:
        select = dict(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        datapack.current_solset = 'sol000'
        datapack.select(**select)
        axes = datapack.axes_tec
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    # Seconds relative to the middle time step.
    t = times.mjd
    t -= t[len(t) // 2]
    t *= 86400.
    # Random extra directions, drawn uniformly in the calibrators' bounding
    # box and then normalised to unit length.
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar /= jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(
        make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))
    # The first antenna serves as both origin and reference antenna.
    x0 = jnp.asarray(x[0, :])
    ref_ant = x0

    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        """Scaled MSE between reference and neural kernels on a random batch."""
        keys = random.split(key, 5)
        indices = random.permutation(keys[0],
                                     jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]

        # Horizontal wind only (z bounds are both 0); divided by 1000 —
        # presumably m/s -> km/s, TODO confirm units.
        wind_velocity = random.uniform(keys[1],
                                       shape=(3, ),
                                       minval=jnp.asarray([-200., -200., 0.]),
                                       maxval=jnp.asarray([200., 200., 0.
                                                           ])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        l = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch,
                   X_batch,
                   bottom,
                   width,
                   l,
                   sigma,
                   wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch,
                                 X_batch,
                                 bottom,
                                 width,
                                 l,
                                 sigma,
                                 wind_velocity=wind_velocity)

        # Normalise by width**2 so thick layers don't dominate the loss.
        return jnp.mean((K - neural_K)**2) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        """One gradient-descent step; returns (new_params, loss)."""
        l, g = value_and_grad(lambda params: loss(params, key))(params)
        params = tree_multimap(lambda p, g: p - learning_rate * g, params, g)
        return params, l

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(
            random.PRNGKey(42))

    if plot:
        plt.plot(losses)
        plt.yscale('log')
        plt.show()
    return final_params, losses