def test_plot_data():
    """Show a Voronoi map of TEC over directions for each selected time step.

    Opens a hard-coded DDS5 h5parm read-only and selects, in solset ``sol000``,
    the first polarisation, antennas 0 and 50, and the first 100 time steps.
    For every time step it:

    - builds an ENU frame anchored at the first selected antenna;
    - projects the directions into that frame;
    - plots ``tec[0, :, 1, i]`` (the second selected antenna) against the first
      two ENU components of each direction.
    """
    datapack = DataPack(
        '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
        readonly=True)
    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(pol=slice(0, 1, 1), ant=[0, 50], time=slice(0, 100, 1))
        tec, axes = datapack.tec
        # const and weights_tec are read but not plotted.
        const, axes = datapack.const
        tec_weights, axes = datapack.weights_tec
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    # Drop the leading (single-entry) axis selected via pol=slice(0, 1, 1).
    tec_mean = tec[0, ...]

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    origin = antennas[0].earth_location

    for step, obstime in enumerate(times):
        enu_frame = ENU(obstime=obstime, location=origin)
        directions = directions.transform_to(enu_frame)
        direction_vectors = directions.cartesian.xyz.value.T
        ax = plot_vornoi_map(direction_vectors[:, 0:2], tec_mean[:, 1, step])
        ax.set_title(obstime)
        plt.show()
def _read_phase_amplitude(h5parm, select, message, readonly=True):
    """Read directions, times, phase and amplitude from solset ``sol000``.

    Args:
        h5parm: path to the h5parm file.
        select: selection kwargs passed to ``DataPack.select``.
        message: log message emitted after the datapack is constructed.
        readonly: whether to open the datapack read-only.

    Returns:
        ``(directions, times, phase, amplitude)``:

        - ``directions``: ``(Ndir, 2)`` array of (ra, dec) in radians.
        - ``times``: the time axis as MJD seconds (``mjd * 86400``).
        - ``phase``, ``amplitude``: the arrays returned by ``DataPack.phase``
          and ``DataPack.amplitude``.
    """
    datapack = DataPack(h5parm, readonly=readonly)
    logger.info(message)
    # Scope all file access so the h5parm is closed before it is reopened
    # by make_soltab below.
    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(**select)
        axes = datapack.axes_phase
        _, directions = datapack.get_directions(axes['dir'])
        _, times = datapack.get_times(axes['time'])
        phase, _ = datapack.phase
        amplitude, _ = datapack.amplitude
    radec = np.stack([directions.ra.rad, directions.dec.rad], axis=1)
    return radec, times.mjd * 86400., phase, amplitude


def _nearest_time_indices(reference_times, target_times):
    """For each target time, return the index of the closest reference time."""
    return np.asarray(
        [np.argmin(np.abs(reference_times - t)) for t in target_times])


def _nearest_direction_indices(reference_radec, target_radec):
    """For each target (ra, dec), return the index of the reference direction
    with the smallest great-circle separation."""
    return np.asarray([
        jnp.argmin(
            great_circle_sep(reference_radec[:, 0], reference_radec[:, 1], ra,
                             dec))
        for ra, dec in zip(target_radec[:, 0], target_radec[:, 1])
    ])


def main(data_dir, working_dir, obs_num):
    """Merge slow solutions into the screen and smoothed solutions.

    Reads phase/amplitude from solset ``sol000`` of three h5parms in ``data_dir``:

    - DDS5 (smoothed);
    - DDS6 (screen);
    - DDS7 (slow).

    It then creates the ``screen_slow`` and ``smoothed_slow`` solsets in
    ``working_dir/L{obs_num}_DDS8_full_merged.h5``, which is linked into
    ``data_dir`` via ``link_overwrite``, and writes:

    - ``screen_slow``: screen phase + slow phase, and screen amplitude * slow
      amplitude. NaNs are replaced by 0 (phase) and 1 (amplitude).
    - ``smoothed_slow``: smoothed phase + slow phase, and smoothed
      amplitude * slow amplitude.

    Slow solutions are resampled onto each target's own time axis (nearest
    time) and directions (nearest great-circle separation).
    """
    logger.info("Merging slow solutions into screen and smoothed.")
    smoothed_h5parm = os.path.join(
        data_dir, 'L{}_DDS5_full_merged.h5'.format(obs_num))
    screen_h5parm = os.path.join(
        data_dir, 'L{}_DDS6_full_merged.h5'.format(obs_num))
    slow_h5parm = os.path.join(
        data_dir, 'L{}_DDS7_full_slow_merged.h5'.format(obs_num))
    merged_h5parm = os.path.join(
        working_dir, 'L{}_DDS8_full_merged.h5'.format(obs_num))
    linked_merged_h5parm = os.path.join(data_dir,
                                        os.path.basename(merged_h5parm))
    link_overwrite(merged_h5parm, linked_merged_h5parm)
    select = dict(pol=slice(0, 1, 1))

    directions_slow, time_slow, phase_slow, amplitude_slow = \
        _read_phase_amplitude(slow_h5parm, select,
                              "Getting slow000/phase000+amplitude000")
    directions_smoothed, time_smoothed, phase_smoothed, amplitude_smoothed = \
        _read_phase_amplitude(smoothed_h5parm, select,
                              "Getting smoothed/phase000+amplitude000")
    # The screen file is only read here, so there is no need to open it writable.
    directions_screen, time_screen, phase_screen, amplitude_screen = \
        _read_phase_amplitude(screen_h5parm, select,
                              "Getting screen_posterior/phase000+amplitude000")

    logger.info("Creating screen_slow/phase000+amplitude000")
    make_soltab(screen_h5parm,
                from_solset='sol000',
                to_solset='screen_slow',
                from_soltab='phase000',
                to_soltab=['phase000', 'amplitude000'],
                remake_solset=True,
                to_datapack=merged_h5parm)
    logger.info("Creating smoothed_slow/phase000+amplitude000")
    make_soltab(smoothed_h5parm,
                from_solset='sol000',
                to_solset='smoothed_slow',
                from_soltab='phase000',
                to_soltab=['phase000', 'amplitude000'],
                remake_solset=True,
                to_datapack=merged_h5parm)

    logger.info("Creating time mapping")
    # Each target gets its own time map. Previously the screen map was
    # also applied to the smoothed solutions, which is only correct when
    # both share the same time axis.
    time_map_screen = _nearest_time_indices(time_slow, time_screen)
    time_map_smoothed = _nearest_time_indices(time_slow, time_smoothed)
    logger.info("Creating direction mapping")
    dir_map_screen = _nearest_direction_indices(directions_slow,
                                                directions_screen)
    # Identity when slow and smoothed list the same directions in the same order.
    dir_map_smoothed = _nearest_direction_indices(directions_slow,
                                                  directions_smoothed)

    # TODO: see if only applying slow on calibrators and screen elsewhere gets rid of artefacts.
    # TODO: see if only applying tec screen gets rid of artefacts (include slow if the above experiment doesn't work)
    # TODO: visibility flagging based on tec outliers (update imaging command)
    # Slow solutions are indexed on the last axis by time and on axis 1 by
    # direction, matching the original indexing.
    phase_smooth_slow = (phase_slow[..., time_map_smoothed][:, dir_map_smoothed, ...]
                         + phase_smoothed)
    amplitude_smooth_slow = (
        amplitude_slow[..., time_map_smoothed][:, dir_map_smoothed, ...]
        * amplitude_smoothed)

    phase_screen_slow = (phase_screen
                         + phase_slow[..., time_map_screen][:, dir_map_screen, ...])
    amplitude_screen_slow = (
        amplitude_screen
        * amplitude_slow[..., time_map_screen][:, dir_map_screen, ...])

    logger.info("Phase screen+slow contains {} nans.".format(
        np.isnan(phase_screen_slow).sum()))
    phase_screen_slow = np.where(np.isnan(phase_screen_slow), 0.,
                                 phase_screen_slow)
    logger.info("Amplitude screen+slow contains {} nans.".format(
        np.isnan(amplitude_screen_slow).sum()))
    amplitude_screen_slow = np.where(np.isnan(amplitude_screen_slow), 1.,
                                     amplitude_screen_slow)

    logger.info("Saving results to {}".format(merged_h5parm))
    datapack = DataPack(merged_h5parm, readonly=False)
    # Scope the writes so the merged file is closed once both solsets are written.
    with datapack:
        datapack.current_solset = 'screen_slow'
        datapack.select(**select)
        datapack.phase = phase_screen_slow
        datapack.amplitude = amplitude_screen_slow
        datapack.current_solset = 'smoothed_slow'
        datapack.select(**select)
        datapack.phase = phase_smooth_slow
        datapack.amplitude = amplitude_smooth_slow
# Scratch snippet: load TEC and its weights for antenna 50 over the first nine
# time steps of the DDS5 solutions, with 64-bit JAX types enabled.
# NOTE(review): if this runs at module level, it opens the h5parm on import.
# NOTE(review): newer JAX releases removed the ``jax.config`` submodule;
# ``from jax import config`` is the portable spelling — confirm the target JAX version.
from jax.config import config

# Enable 64-bit types in JAX for the jnp arrays built below.
config.update("jax_enable_x64", True)
dp = DataPack(
    '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
    readonly=True)
with dp:
    # First polarisation, antenna 50 only, first nine time steps.
    select = dict(pol=slice(0, 1, 1), ant=[50], time=slice(0, 9, 1))
    dp.current_solset = 'sol000'
    dp.select(**select)
    tec_mean, axes = dp.tec
    # Drop the leading axis (presumably pol, selected to length 1 above).
    dtec = jnp.asarray(tec_mean[0, :, :, :])
    tec_std, axes = dp.weights_tec
    # NOTE(review): weights_tec is used as a TEC uncertainty here — confirm its meaning.
    dtec_uncert = jnp.asarray(tec_std[0, :, :, :])
    patch_names, directions = dp.get_directions(axes['dir'])
    antenna_labels, antennas = dp.get_antennas(axes['ant'])
    timestamps, times = dp.get_times(axes['time'])
# Disabled: projection of antennas and directions into an ENU frame anchored
# at the first antenna at the first time.
# antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
# ref_ant = antennas[0]
# frame = ENU(obstime=times[0], location=ref_ant.earth_location)
# antennas = antennas.transform_to(frame)
# ref_ant = antennas[0]
# directions = directions.transform_to(frame)
# x = antennas.cartesian.xyz.to(au.km).value.T[1:2, :]
# k = directions.cartesian.xyz.value.T
# Convert the time axis to seconds elapsed since the first selected time.
times = times.mjd
times -= times[0]
times *= 86400.
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches):
    """Fit a ``NeuralTomographicKernel`` to a ``TomographicKernel`` by gradient descent.

    Builds a coordinate array from the first time step of solset ``sol000``
    in ``datapack``. The coordinates combine:

    - antennas in an ENU frame anchored at the first antenna, in km;
    - calibrator directions plus 250 random unit vectors drawn inside their
      bounding box;
    - times in seconds relative to the middle selected time.

    The neural kernel is then trained to reproduce the analytic kernel on
    random mini-batches. Layer ``bottom``/``width``, length scale and
    horizontal wind velocity are redrawn at every step. The loss curve is
    plotted on a log scale.

    Args:
        datapack: DataPack to read antennas, directions and times from.
        batch_size: number of coordinates per mini-batch.
        learning_rate: plain gradient-descent step size.
        num_batches: number of training steps.

    Returns:
        ``(final_params, losses)``: the trained neural-kernel parameters and
        the per-step loss history. The original version discarded both and
        returned ``None``.
    """
    with datapack:
        select = dict(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        datapack.current_solset = 'sol000'
        datapack.select(**select)
        axes = datapack.axes_tec
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    t = times.mjd
    t -= t[len(t) // 2]
    t *= 86400.

    # Extra "screen" directions sampled in the bounding box of k, then normalised
    # to unit length.
    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar /= jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(
        make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))

    # The first antenna serves as both x0 and the reference antenna.
    x0 = jnp.asarray(x[0, :])
    ref_ant = x0
    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        """Mean squared kernel mismatch on a random batch, scaled by 1/width**2."""
        keys = random.split(key, 5)
        indices = random.permutation(keys[0], jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]
        # Horizontal wind only: the vertical component has min == max == 0.
        wind_velocity = random.uniform(
            keys[1], shape=(3,),
            minval=jnp.asarray([-200., -200., 0.]),
            maxval=jnp.asarray([200., 200., 0.])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        length_scale = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch, X_batch, bottom, width, length_scale, sigma,
                   wind_velocity=wind_velocity)
        # NOTE(review): set_params mutates the kernel object inside a traced
        # function; this relies on NeuralTomographicKernel tolerating tracers.
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch, X_batch, bottom, width, length_scale,
                                 sigma, wind_velocity=wind_velocity)
        return jnp.mean((K - neural_K)**2) / width**2

    # Split once so the init key and the training keys are independent;
    # reusing PRNGKey(42) for both is JAX key reuse.
    init_key, train_key = random.split(random.PRNGKey(42))
    init_params = neural_kernel.init_params(init_key)

    def train_one_batch(params, key):
        """One plain gradient-descent step; returns (new_params, loss)."""
        loss_value, grads = value_and_grad(lambda p: loss(p, key))(params)
        params = tree_multimap(lambda p, grad: p - learning_rate * grad,
                               params, grads)
        return params, loss_value

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(train_key)

    plt.plot(losses)
    plt.yscale('log')
    plt.show()
    return final_params, losses