def test_plot_data():
    """Show a Voronoi map of DTEC over the sky for each selected time step.

    Opens a hard-coded H5parm on the original author's machine (read-only),
    selects the first polarisation, antennas [0, 50] and the first 100 time
    steps of solset 'sol000', then for every time step projects the
    directions into a local ENU frame anchored at the first selected antenna
    and plots the DTEC of the second selected antenna (index 1) with
    `plot_vornoi_map`. One figure is shown per time step.
    """
    datapack = DataPack(
        '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
        readonly=True)
    with datapack:
        selection = {
            'pol': slice(0, 1, 1),
            'ant': [0, 50],
            'time': slice(0, 100, 1),
        }
        datapack.current_solset = 'sol000'
        datapack.select(**selection)
        # Keep only the first entry of the leading axis (presumably the single
        # selected polarisation — TODO confirm soltab axis order).
        tec_values, _ = datapack.tec
        tec_mean = tec_values[0, ...]
        const_values, _ = datapack.const
        const_mean = const_values[0, ...]  # loaded but not plotted
        std_values, axes = datapack.weights_tec
        tec_std = std_values[0, ...]  # loaded but not plotted
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    # The ENU origin stays at the first selected antenna for every time step.
    ref_location = antennas[0].earth_location

    for step, obstime in enumerate(times):
        enu = ENU(obstime=obstime, location=ref_location)
        antennas = antennas.transform_to(enu)
        directions = directions.transform_to(enu)
        antenna_xyz_km = antennas.cartesian.xyz.to(au.km).value.T  # not used by the plot
        direction_xyz = directions.cartesian.xyz.value.T
        ax = plot_vornoi_map(direction_xyz[:, 0:2], tec_mean[:, 1, step])
        ax.set_title(obstime)
        plt.show()
def create_empty_datapack_spec(directions: ac.ICRS,
                               Nf,
                               Nt,
                               pols=None,
                               start_time=None,
                               time_resolution=30.,
                               min_freq=122.,
                               max_freq=166.,
                               array_file=None,
                               save_name='test_datapack.h5',
                               clobber=False) -> DataPack:
    """
    Creates an empty datapack with phase, amplitude and DTEC soltabs.

    A solset 'sol000' is added with one patch per entry of `directions`, and
    three soltabs are filled with placeholder values:
    'phase000' (zeros), 'amplitude000' (ones) and 'tec000' (zeros).

    Args:
        directions: sky directions in ICRS; one patch is created per direction,
            so Nd = len(directions).
        Nf: number of frequencies.
        Nt: number of times.
        pols: polarisations as a list or tuple, e.g. ['XX', ...].
            Defaults to ['XX'].
        start_time: start time in modified Julian days (mjs/86400).
            Defaults to 2019-01-01T00:00:00.000.
        time_resolution: time step in seconds.
        min_freq: minimum frequency in MHz.
        max_freq: maximum frequency in MHz.
        array_file: array file, else `DataPack.lofar_array_hba` is used.
        save_name: where to save the H5parm (made absolute).
        clobber: if True, an existing file at `save_name` is deleted first;
            if False an existing file is opened as-is.

    Returns:
        DataPack

    Raises:
        AssertionError: if `pols` is not a list or tuple.
    """
    logger.info("=== Creating empty datapack ===")
    Nd = len(directions)
    save_name = os.path.abspath(save_name)
    if os.path.isfile(save_name) and clobber:
        os.unlink(save_name)

    if array_file is None:
        array_file = DataPack.lofar_array_hba

    if start_time is None:
        start_time = at.Time("2019-01-01T00:00:00.000", format='isot').mjd

    if pols is None:
        pols = ['XX']
    assert isinstance(pols, (tuple, list))

    time0 = at.Time(start_time, format='mjd')

    datapack = DataPack(save_name, readonly=False)
    with datapack:
        datapack.add_solset('sol000', array_file=array_file)
        # Patch directions are stored as (RA, Dec) pairs in radians.
        datapack.set_directions(
            None, np.stack([directions.ra.rad, directions.dec.rad], axis=1))
        patch_names, _ = datapack.directions
        antenna_labels, _ = datapack.antennas
        _, antennas = datapack.get_antennas(antenna_labels)
        antennas = antennas.cartesian.xyz.to(au.km).value.T
        Na = antennas.shape[0]

        # Regular time grid starting at time0, expressed in modified Julian seconds.
        times = at.Time(time0.mjd + (np.arange(Nt) * time_resolution) / 86400.,
                        format='mjd').mjd * 86400.
        # Linearly spaced frequencies, MHz -> Hz.
        freqs = np.linspace(min_freq, max_freq, Nf) * 1e6

        Npol = len(pols)
        dtecs = np.zeros((Npol, Nd, Na, Nt))
        phase = np.zeros((Npol, Nd, Na, Nf, Nt))
        amp = np.ones_like(phase)

        datapack.add_soltab('phase000', values=phase, ant=antenna_labels,
                            dir=patch_names, time=times, freq=freqs, pol=pols)
        datapack.add_soltab('amplitude000', values=amp, ant=antenna_labels,
                            dir=patch_names, time=times, freq=freqs, pol=pols)
        # DTEC has no frequency axis.
        datapack.add_soltab('tec000', values=dtecs, ant=antenna_labels,
                            dir=patch_names, time=times, pol=pols)
    return datapack
# Double precision is enabled for all subsequent JAX computations.
config.update("jax_enable_x64", True)

# Load DTEC (and its weights) for the first polarisation, antenna 50 and the
# first 9 time steps of solset 'sol000' from a hard-coded H5parm.
dp = DataPack(
    '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
    readonly=True)
with dp:
    select = {'pol': slice(0, 1, 1), 'ant': [50], 'time': slice(0, 9, 1)}
    dp.current_solset = 'sol000'
    dp.select(**select)
    tec_mean, axes = dp.tec
    # Leading axis dropped (presumably the single selected polarisation —
    # TODO confirm soltab axis order).
    dtec = jnp.asarray(tec_mean[0, :, :, :])
    tec_std, axes = dp.weights_tec
    dtec_uncert = jnp.asarray(tec_std[0, :, :, :])
    patch_names, directions = dp.get_directions(axes['dir'])
    antenna_labels, antennas = dp.get_antennas(axes['ant'])
    timestamps, times = dp.get_times(axes['time'])

# Times become seconds elapsed since the first selected time step.
times = times.mjd
times -= times[0]
times *= 86400.
# Directions become an array of (RA, Dec) pairs in degrees.
directions = jnp.stack([directions.ra.deg, directions.dec.deg], axis=1)
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches):
    """Fit a NeuralTomographicKernel to a TomographicKernel by gradient descent.

    Geometry is taken from the first time step of solset 'sol000' (first
    polarisation, all antennas). Antennas and directions are expressed in a
    local ENU frame anchored at the first antenna. The datapack's directions
    are augmented with 250 random unit directions, and every antenna x
    direction x time combination forms the coordinate array X.

    Each gradient step draws a random batch of rows of X plus random layer and
    kernel hyperparameters. It then minimises the mean squared difference
    between the two kernel matrices, divided by the layer width squared.

    Args:
        datapack: DataPack providing antennas, directions and times.
        batch_size: number of rows of X per batch.
        learning_rate: step size of plain gradient descent.
        num_batches: number of gradient steps.

    Returns:
        (final_params, losses): the trained neural-kernel parameters and the
        per-step loss values. The loss curve is also plotted on a log scale.
    """
    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        axes = datapack.axes_tec
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    # Local ENU frame centred on the first antenna at the first time.
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T  # one row per antenna, km
    k = directions.cartesian.xyz.value.T  # one row per direction
    # Seconds relative to the middle time. Avoid in-place ops on the array
    # returned by Time.mjd, which may be astropy's cached (read-only) value.
    mjd = times.mjd
    t = (mjd - mjd[len(mjd) // 2]) * 86400.

    # Extra random directions inside the bounding box of k, normalised to unit length.
    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar = kstar / jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(
        make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))

    x0 = jnp.asarray(x[0, :])
    ref_ant = x0
    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        """Normalised squared error between the two kernels on a random batch."""
        keys = random.split(key, 5)
        indices = random.permutation(keys[0], jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]
        # Horizontal wind in [-200, 200] per component (vertical fixed at 0),
        # divided by 1000 (presumably m/s -> km/s to match x in km).
        wind_velocity = random.uniform(
            keys[1], shape=(3,),
            minval=jnp.asarray([-200., -200., 0.]),
            maxval=jnp.asarray([200., 200., 0.])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        lengthscale = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch, X_batch, bottom, width, lengthscale, sigma,
                   wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch, X_batch, bottom, width, lengthscale,
                                 sigma, wind_velocity=wind_velocity)
        return jnp.mean((K - neural_K) ** 2) / width ** 2

    # Split once so initialisation and training never reuse the same PRNG key.
    init_key, train_key = random.split(random.PRNGKey(42))
    init_params = neural_kernel.init_params(init_key)

    def train_one_batch(params, key):
        """One plain gradient-descent step; returns (new_params, loss_value)."""
        loss_value, grads = value_and_grad(lambda p: loss(p, key))(params)
        params = tree_multimap(lambda p, g: p - learning_rate * g, params, grads)
        return params, loss_value

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(train_key)

    plt.plot(losses)
    plt.yscale('log')
    plt.show()
    return final_params, losses