def plot_phase_vs_time_per_datapack(datapacks, output_folder, solsets='sol000', ant_sel=None, time_sel=None,
                                    dir_sel=None, freq_sel=None, pol_sel=None):
    """Scatter-plot phase against time for several datapack/solset pairs on shared axes.

    One PNG is written per (pol, dir, ant, freq) combination. Every datapack/solset pair is drawn
    on the same axes so solutions can be compared visually.

    Args:
        datapacks: an h5parm filename or a list/tuple of filenames.
        output_folder: directory receiving the PNG files (created if missing).
        solsets: a solset name or a list/tuple of names. A length-1 ``solsets`` is broadcast over all
            datapacks, and a length-1 ``datapacks`` is broadcast over all solsets.
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``DataPack.select``.

    Raises:
        ValueError: if ``datapacks`` and ``solsets`` differ in length and neither has length 1.
        AssertionError: if the selected phase arrays do not all have the same shape.
    """
    if not isinstance(solsets, (list, tuple)):
        solsets = [solsets]
    if not isinstance(datapacks, (list, tuple)):
        datapacks = [datapacks]
    solsets = list(solsets)
    datapacks = list(datapacks)
    # zip() would silently truncate to the shorter list (e.g. the default single solset with
    # several datapacks would only ever plot the first datapack), so broadcast explicitly.
    if len(solsets) == 1 and len(datapacks) > 1:
        solsets = solsets * len(datapacks)
    elif len(datapacks) == 1 and len(solsets) > 1:
        datapacks = datapacks * len(solsets)
    elif len(solsets) != len(datapacks):
        raise ValueError("Got {} datapacks but {} solsets; lengths must match or one must be 1.".format(
            len(datapacks), len(solsets)))

    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    phases = []
    times_per_pack = []
    axes = None
    for solset, filename in zip(solsets, datapacks):
        with DataPack(filename, readonly=True) as dp:
            dp.current_solset = solset
            dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
            phase, axes = dp.phase
            _, times = dp.get_times(axes['time'])
        phases.append(phase)
        times_per_pack.append(times)

    for phase in phases:
        for s, S in zip(phase.shape, phases[0].shape):
            assert s == S

    Npol, Nd, Na, Nf, Nt = phases[0].shape
    fig, ax = plt.subplots()
    for p in range(Npol):
        for d in range(Nd):
            for a in range(Na):
                for f in range(Nf):
                    ax.cla()
                    for i, (solset, phase, times) in enumerate(zip(solsets, phases, times_per_pack)):
                        label = "{} {} {} {:.1f}MHz {}:{}".format(os.path.basename(datapacks[i]), solset,
                                                                  axes['pol'][p], axes['freq'][f] / 1e6,
                                                                  axes['ant'][a], axes['dir'][d])
                        # Each datapack is plotted against its own time axis.
                        ax.scatter(times.mjd, phase[p, d, a, f, :], marker='+', alpha=0.3, label=label)
                    ax.set_xlabel('Time [mjd]')
                    ax.set_ylabel('Phase deviation [rad.]')
                    ax.legend()
                    filename = "{}_{}_{}_{}MHz.png".format(axes['ant'][a], axes['dir'][d], axes['pol'][p],
                                                          axes['freq'][f] / 1e6)
                    fig.savefig(os.path.join(output_folder, filename))
    plt.close('all')
def visualisation(h5parm, ant=None, time=None):
    """Show Voronoi maps of differential TEC over directions, one panel per (antenna, time).

    Reads solset 'sol000' of ``h5parm``. Rows are the selected antennas excluding the first
    (reference) antenna; columns are the selected time steps. Directions are projected into the
    ENU frame at the first antenna and first selected time.

    Args:
        h5parm: path of the h5parm file.
        ant: antenna selection forwarded to ``DataPack.select``.
        time: time selection forwarded to ``DataPack.select``.
    """
    with DataPack(h5parm, readonly=True) as dp:
        dp.current_solset = 'sol000'
        dp.select(ant=ant, time=time)
        dtec, axes = dp.tec
        dtec = dtec[0]  # drop the pol axis -> [Nd, Na, Nt]
        _, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        _, times = dp.get_times(axes['time'])

    # ``time`` is only the selection argument (often None or a slice); the frame needs an actual
    # observation time, so use the first selected timestamp.
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    directions = directions.transform_to(frame)
    t = times.mjd * 86400.
    t -= t[0]
    # Row j of x is antenna j + 1: the first (reference) antenna is excluded from the plot.
    x = antennas.cartesian.xyz.to(au.km).value.T[1:, :]
    k = directions.cartesian.xyz.value.T
    logger.info(f"Directions: {directions}")
    logger.info(f"Antennas: {x} {antenna_labels}")
    logger.info(f"Times: {t}")
    Na = x.shape[0]
    logger.info(f"Number of antenna to plot: {Na}")
    if Na == 0:
        logger.warning("No non-reference antenna selected; nothing to plot.")
        return
    Nt = t.shape[0]
    fig, axs = plt.subplots(Na, Nt, sharex=True, sharey=True, figsize=(2 * Nt, 2 * Na), squeeze=False)
    for a in range(Na):
        for i in range(Nt):
            # Offset by one so the plotted TEC belongs to the same antenna as row a of x.
            ax = plot_vornoi_map(k[:, 0:2], dtec[:, a + 1, i], ax=axs[a][i], colorbar=False)
            if a == (Na - 1):
                ax.set_xlabel(r"$k_{\rm east}$")
            if i == 0:
                ax.set_ylabel(r"$k_{\rm north}$")
            if a == 0:
                ax.set_title(f"Time: {int(t[i])} sec")
    plt.show()
def get_data(solution_file, select=None):
    """Load DDS4 phases, outlier flags and amplitudes, and smooth the log-amplitudes.

    Args:
        solution_file: path of the DDS4 h5parm.
        select: optional dict of keyword arguments for ``DataPack.select``. When None, the original
            fixed selection is used: first pol, antenna 50, all directions, first 100 times.

    Returns:
        gain_outliers: bool array [Nd, Na, Nf, Nt], True where the first-pol phase weights equal 1.
        phase: first-pol phases [Nd, Na, Nf, Nt].
        amp: amplitudes [Nd, Na, Nf, Nt] after a degree-3 polynomial smooth of log-amplitude
            along time and then frequency, with outliers given zero weight.
        times: seconds relative to the first selected time, [Nt].
        freqs: frequencies in Hz, [Nf].
    """
    logger.info("Getting DDS4 data.")
    if select is None:
        # NOTE(review): narrow debug-style selection kept as the default for backward compatibility.
        select = dict(pol=slice(0, 1, 1), ant=slice(50, 51), dir=slice(0, None, 1), time=slice(0, 100, 1))
    with DataPack(solution_file, readonly=True) as h:
        h.select(**select)
        phase, axes = h.phase
        phase = phase[0, ...]
        gain_outliers, _ = h.weights_phase
        # NOTE(review): weights equal to 1 are treated as outliers here — confirm this convention.
        gain_outliers = gain_outliers[0, ...] == 1
        amp, axes = h.amplitude
        amp = amp[0, ...]
        _, freqs = h.get_freqs(axes['freq'])
        freqs = freqs.to(au.Hz).value
        _, times = h.get_times(axes['time'])
    # float64 is only honoured when JAX x64 mode is enabled.
    times = jnp.asarray(times.mjd, dtype=jnp.float64) * 86400.
    times = times - times[0]
    logger.info("Shape: {}".format(phase.shape))

    (Nd, Na, Nf, Nt) = phase.shape

    @jit
    def smooth(amp, outliers):
        '''
        Smooth amplitudes

        Args:
            amp: [Nt, Nf]
            outliers: [Nt, Nf]
        '''
        weights = jnp.where(outliers, 0., 1.)
        log_amp = jnp.log(amp)
        # Smooth along time for every frequency channel ...
        log_amp = vmap(lambda log_amp, weights: poly_smooth(
            times, log_amp, deg=3, weights=weights))(log_amp.T, weights.T).T
        # ... then along frequency for every time step.
        log_amp = vmap(lambda log_amp, weights: poly_smooth(
            freqs, log_amp, deg=3, weights=weights))(log_amp, weights)
        amp = jnp.exp(log_amp)
        return amp

    logger.info("Smoothing amplitudes")
    amp = chunked_pmap(smooth, amp.reshape((Nd * Na, Nf, Nt)).transpose((0, 2, 1)),
                       gain_outliers.reshape((Nd * Na, Nf, Nt)).transpose(
                           (0, 2, 1)))  # Nd*Na,Nt,Nf
    amp = amp.transpose((0, 2, 1)).reshape((Nd, Na, Nf, Nt))
    return gain_outliers, phase, amp, times, freqs
def visualise_grid(self):
    """Interactively show TEC on the antenna layout and per-antenna TEC maps over directions.

    Uses solset 'sol000' of ``self._input_datapack`` at time index 0 and the first polarisation.
    First, for directions 0-4, scatters every antenna's TEC on an East/North layout (km) in the
    ENU frame of the first antenna. Then, for antennas 0, 50, 150, 200 and 250, shows that
    antenna's position (red cross) over the layout coloured by the last direction's TEC, above a
    Voronoi map of the antenna's TEC across directions in RA/DEC degrees.
    """

    def _load(**sel):
        # Read TEC (pol axis dropped), directions, antennas and times for one selection.
        with DataPack(self._input_datapack, readonly=True) as dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1), time=0, **sel)
            tec, axes = dp.tec
            _, directions = dp.get_directions(axes['dir'])
            _, antennas = dp.get_antennas(axes['ant'])
            _, times = dp.get_times(axes['time'])
        return tec[0], directions, antennas, times

    def _enu_km(antennas, times, ref_ant):
        # Antenna positions in km in the ENU frame centred on ref_ant at the first time.
        frame = ENU(location=ref_ant.earth_location, obstime=times[0])
        enu = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0]).transform_to(frame)
        return enu.cartesian.xyz.to(au.km).value.T

    for d in range(0, 5):
        tec_grid, directions_grid, antennas_grid, times_grid = _load(dir=d)
        ref_ant = antennas_grid[0]
        ant_pos = _enu_km(antennas_grid, times_grid, ref_ant)
        plt.scatter(ant_pos[:, 0], ant_pos[:, 1], c=tec_grid[0, :, 0], cmap=plt.cm.PuOr)
        plt.xlabel('East [km]')
        plt.ylabel("North [km]")
        plt.title(f"Direction {repr(directions_grid)}")
        plt.show()
    # Background layout for the per-antenna figures: coloured by the last direction's TEC.
    ant_scatter_args = (ant_pos[:, 0], ant_pos[:, 1], tec_grid[0, :, 0])

    for a in [0, 50, 150, 200, 250]:
        tec_grid, directions_grid, antennas_grid, times_grid = _load(ant=a)
        _ant_pos = _enu_km(antennas_grid, times_grid, ref_ant)[0]
        fig, axs = plt.subplots(2, 1, figsize=(4, 8))
        axs[0].scatter(*ant_scatter_args[0:2], c=ant_scatter_args[2], cmap=plt.cm.PuOr, alpha=0.5)
        axs[0].scatter(*_ant_pos[0:2], marker='x', c='red')
        axs[0].set_xlabel('East [km]')
        axs[0].set_ylabel('North [km]')
        pos = 180 / np.pi * np.stack(
            [wrap(directions_grid.ra.rad), wrap(directions_grid.dec.rad)], axis=-1)
        plot_vornoi_map(pos, tec_grid[:, 0, 0], fov_circle=True, ax=axs[1])
        axs[1].set_xlabel('RA(2000) [deg]')
        axs[1].set_ylabel('DEC(2000) [deg]')
        plt.show()
def main(data_dir, working_dir, obs_num, ncpu, plot_results):
    """Smooth DDS4 solutions via tec+const inference and store them in a DDS5 h5parm.

    Reads ``L{obs_num}_DDS4_full_merged.h5`` from ``data_dir``, smooths amplitudes (``get_data``),
    solves for tec/const (``solve_and_smooth``) and writes phase, phase uncertainty, amplitude, tec,
    tec outliers, tec uncertainty and const into ``L{obs_num}_DDS5_full_merged.h5`` in ``working_dir``.

    Args:
        data_dir: directory containing the DDS4 h5parm; ``link_overwrite`` is also given a DDS5 path here.
        working_dir: directory for the DDS5 h5parm and all plot output.
        obs_num: observation number used in the file names.
        ncpu: host device count passed to XLA; animations use about two thirds of it (at least 1).
        plot_results: when true, write per-antenna/direction diagnostic plots and animations.
    """
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
    logger.info("Performing data smoothing via tec+const+clock inference.")
    dds4_h5parm = os.path.join(data_dir, 'L{}_DDS4_full_merged.h5'.format(obs_num))
    dds5_h5parm = os.path.join(working_dir, 'L{}_DDS5_full_merged.h5'.format(obs_num))
    linked_dds5_h5parm = os.path.join(data_dir, 'L{}_DDS5_full_merged.h5'.format(obs_num))
    logger.info("Looking for {}".format(dds4_h5parm))
    link_overwrite(dds5_h5parm, linked_dds5_h5parm)
    prepare_soltabs(dds4_h5parm, dds5_h5parm)
    gain_outliers, phase_obs, amp, times, freqs = get_data(solution_file=dds4_h5parm)
    phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers, const_mean, const_std = \
        solve_and_smooth(gain_outliers, phase_obs, times, freqs)

    logger.info("Storing smoothed phase, amplitudes, tec, const, and clock")
    with DataPack(dds5_h5parm, readonly=False) as h:
        h.current_solset = 'sol000'
        h.select(pol=slice(0, 1, 1))
        # The leading None restores the pol axis that get_data dropped.
        h.phase = np.asarray(phase_mean)[None, ...]
        # NOTE(review): the uncertainty is stored as weights_phase, while readers elsewhere treat
        # weights as 1/variance — confirm the intended convention.
        h.weights_phase = np.asarray(phase_uncert)[None, ...]
        h.amplitude = np.asarray(amp)[None, ...]
        h.tec = np.asarray(tec_mean)[None, ...]
        h.tec_outliers = np.asarray(tec_outliers)[None, ...]
        h.weights_tec = np.asarray(tec_std)[None, ...]
        h.const = np.asarray(const_mean)[None, ...]

    if not plot_results:
        return

    diagnostic_data_dir = os.path.join(working_dir, 'diagnostic')
    os.makedirs(diagnostic_data_dir, exist_ok=True)
    logger.info("Plotting results.")
    data_plot_dir = os.path.join(working_dir, 'data_plots')
    os.makedirs(data_plot_dir, exist_ok=True)
    Nd, Na = phase_mean.shape[:2]
    # Loop-invariant wrapped model-minus-data residual for all (dir, ant, freq, time).
    dphase = wrap(wrap(phase_mean) - phase_obs)
    img_kw = dict(aspect='auto', origin='lower', interpolation='nearest')
    for ia in range(Na):
        for idir in range(Nd):
            outlier_times = times[tec_outliers[idir, ia, :]]

            # Solutions: tec, const (+/- std) and tec uncertainty, with tec outliers marked.
            fig, axs = plt.subplots(3, 1, sharex=True)
            axs[0].plot(times, tec_mean[idir, ia, :], c='black', label='tec')
            axs[1].plot(times, const_mean[idir, ia, :], c='black', label='const')
            axs[1].fill_between(
                times, const_mean[idir, ia, :] - const_std[idir, ia, :],
                const_mean[idir, ia, :] + const_std[idir, ia, :], color='black', alpha=0.2)
            axs[2].plot(times, tec_std[idir, ia, :], c='black', label='tec_std')
            for ax in axs:
                # Draw outlier markers across the current y-range without letting them rescale it.
                ylim = ax.get_ylim()
                ax.vlines(outlier_times, *ylim, colors='red', label='outliers', alpha=0.5)
                ax.set_ylim(*ylim)
                ax.legend()
            axs[0].set_ylabel("DTEC [mTECU]")
            axs[1].set_ylabel("const [rad]")
            axs[2].set_ylabel("DTEC uncert [mTECU]")
            axs[2].set_xlabel("time [s]")
            fig.savefig(os.path.join(data_plot_dir, 'solutions_ant{:02d}_dir{:02d}.png'.format(ia, idir)))
            plt.close("all")

            # Data vs model: observed phase (+input outliers), model phase (+inf-uncertainty
            # samples), wrapped residual, and phase uncertainty.
            # NOTE(review): no imshow extent is set, so the axes show array indices despite the
            # "freq [MHz]" / "time [s]" labels.
            fig, axs = plt.subplots(4, 1, sharex=True, sharey=True)
            axs[0].imshow(phase_obs[idir, ia, :, :], vmin=-jnp.pi, vmax=jnp.pi, cmap='twilight', **img_kw)
            axs[0].imshow(jnp.where(gain_outliers[idir, ia, :, :], 1., jnp.nan), vmin=0., vmax=1.,
                          cmap='bone', **img_kw)
            add_colorbar_to_axes(axs[0], "twilight", vmin=-jnp.pi, vmax=jnp.pi)
            axs[1].imshow(phase_mean[idir, ia, :, :], vmin=-jnp.pi, vmax=jnp.pi, cmap='twilight', **img_kw)
            axs[1].imshow(jnp.where(jnp.isinf(phase_uncert[idir, ia, :, :]), 1., jnp.nan), vmin=0., vmax=1.,
                          cmap='bone', **img_kw)
            add_colorbar_to_axes(axs[1], "twilight", vmin=-jnp.pi, vmax=jnp.pi)
            axs[2].imshow(dphase[idir, ia, :, :], vmin=-0.5, vmax=0.5, cmap='PuOr', **img_kw)
            add_colorbar_to_axes(axs[2], "PuOr", vmin=-0.5, vmax=0.5)
            axs[3].imshow(phase_uncert[idir, ia, :, :], vmin=0., vmax=0.8, cmap='PuOr', **img_kw)
            add_colorbar_to_axes(axs[3], "PuOr", vmin=0., vmax=0.8)
            for ax in axs:
                ax.set_ylabel("freq [MHz]")
            axs[3].set_xlabel("time [s]")
            axs[0].set_title("phase data [rad]")
            axs[1].set_title("phase model [rad]")
            axs[2].set_title("phase diff. [rad]")
            axs[3].set_title("phase uncert [rad]")
            fig.savefig(os.path.join(data_plot_dir, 'data_comparison_ant{:02d}_dir{:02d}.png'.format(ia, idir)))
            plt.close("all")

    # (2 * ncpu) // 3 is 0 for ncpu == 1, so always request at least one worker process.
    num_processes = max(1, (ncpu * 2) // 3)
    common_kwargs = dict(phase_wrap=False, plot_crosses=False, plot_facet_idx=True, labels_in_radec=True,
                         per_timestep_scale=True, solset='sol000', cmap=plt.cm.PuOr)
    animations = [
        ('tec_plots', dict(vmin=-60, vmax=60., observable='tec')),
        ('const_plots', dict(vmin=-np.pi, vmax=np.pi, observable='const')),
        ('clock_plots', dict(vmin=None, vmax=None, observable='clock')),
        ('amplitude_plots', dict(log_scale=True, observable='amplitude')),
    ]
    for subdir, kwargs in animations:
        animate_datapack(dds5_h5parm, os.path.join(working_dir, subdir), num_processes=num_processes,
                         **kwargs, **common_kwargs)
def train_neural_network(datapack: DataPack, batch_size, learning_rate, num_batches):
    """Fit ``NeuralTomographicKernel`` to ``TomographicKernel`` by plain gradient descent.

    Geometry comes from the first time step of solset 'sol000'. Antennas are converted to km in
    the ENU frame of the first antenna. Directions are the observed ones plus 250 random unit
    vectors drawn inside their bounding box. Times are in seconds relative to the middle time.
    Each step samples ``batch_size`` coordinate rows and random wind/layer hyperparameters. It then
    takes one SGD step on the mean squared kernel difference divided by ``width**2``. The loss
    curve is shown on a log scale.

    Args:
        datapack: DataPack used (as a context manager) to read the geometry.
        batch_size: number of coordinate rows per step.
        learning_rate: gradient-descent step size.
        num_batches: number of optimisation steps.

    Returns:
        ``(final_params, losses)`` where ``losses`` holds one loss value per step.
    """
    # tree_map accepts several trees and replaces the deprecated/removed tree_multimap.
    from jax.tree_util import tree_map

    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        axes = datapack.axes_tec
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    mjd = times.mjd
    t = (mjd - mjd[len(mjd) // 2]) * 86400.

    n_screen = 250
    # Extra random directions inside the bounding box of the observed ones, projected to unit length.
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0), maxval=jnp.max(k, axis=0))
    kstar /= jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))
    x0 = jnp.asarray(x[0, :])  # first antenna doubles as origin and reference antenna
    kernel = TomographicKernel(x0, x0, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, x0)

    def loss(params, key):
        keys = random.split(key, 5)
        indices = random.permutation(keys[0], jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]
        # Horizontal wind up to 200 (divided by 1000, matching the km coordinates); no vertical component.
        wind_velocity = random.uniform(keys[1], shape=(3, ),
                                       minval=jnp.asarray([-200., -200., 0.]),
                                       maxval=jnp.asarray([200., 200., 0.])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        l = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch, X_batch, bottom, width, l, sigma, wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch, X_batch, bottom, width, l, sigma, wind_velocity=wind_velocity)
        return jnp.mean((K - neural_K)**2) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        batch_loss, grads = value_and_grad(lambda params: loss(params, key))(params)
        params = tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return params, batch_loss

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(random.PRNGKey(42))

    plt.plot(losses)
    plt.yscale('log')
    plt.show()
    return final_params, losses
def plot_solution_residuals(datapack, output_folder, data_solset='sol000', solution_solset='posterior_sol',
                            ant_sel=None, time_sel=None, dir_sel=None, freq_sel=None, pol_sel=None):
    """Plot wrapped residuals between data phase and TEC-predicted phase.

    The predicted phase is ``-8.448e9 * tec / freq`` from ``solution_solset``. One PNG is written
    per (pol, antenna). Each direction gets a freq-vs-time residual image and, below it, the
    frequency-averaged residual with a +/-1 std band.

    Args:
        datapack: DataPack or filename.
        output_folder: directory receiving the PNG files (created if missing).
        data_solset: solset holding the observed phase.
        solution_solset: solset holding the TEC solution.
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``select``.
    """
    def _wrap(phi):
        return np.angle(np.exp(1j * phi))

    if not isinstance(datapack, str):
        datapack = datapack.filename
    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    with DataPack(datapack, readonly=True) as dp:
        dp.switch_solset(data_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        phase, axes = dp.phase
        _, times = dp.get_times(axes['time'])
        _, freqs = dp.get_freqs(axes['freq'])
        dp.switch_solset(solution_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        tec, _ = dp.tec

    Npol, Nd, Na, Nf, Nt = phase.shape
    # tec [..., Nt] -> [..., 1, Nt] broadcast against freqs [Nf, 1] -> phase-shaped prediction.
    phase_pred = -8.448e9 * tec[..., None, :] / freqs[:, None]
    res = _wrap(_wrap(phase) - _wrap(phase_pred))
    t_sec = times.mjd * 86400.
    M = int(np.ceil(np.sqrt(Nd)))
    norm = plt.Normalize(-1., 1.)

    for p in range(Npol):
        for a in range(Na):
            # squeeze=False keeps axs 2-D even when M == 1 (a single direction).
            fig, axs = plt.subplots(nrows=2 * M, ncols=M, sharex=True, figsize=(M * 4, 1 * M * 4),
                                    gridspec_kw={'height_ratios': [1.5, 1] * M}, squeeze=False)
            fig.subplots_adjust(wspace=0., hspace=0.)
            fig.subplots_adjust(right=0.85)
            cbar_ax = fig.add_axes([0.875, 0.15, 0.025, 0.7])
            img = None
            for d in range(Nd):
                # Direction d occupies an image row and a line-plot row in grid column d % M.
                row, col = 2 * (d // M), d % M
                ax1 = axs[row][col]
                ax2 = axs[row + 1][col]
                img = ax1.imshow(res[p, d, a, :, :], origin='lower', aspect='auto',
                                 extent=(t_sec[0], t_sec[-1], freqs[0], freqs[-1]),
                                 cmap=plt.cm.jet, norm=norm)
                ax1.text(0.05, 0.95, axes['dir'][d], horizontalalignment='left', verticalalignment='top',
                         transform=ax1.transAxes, backgroundcolor=(1., 1., 1., 0.5))
                ax1.set_ylabel('frequency [Hz]')
                mean = res[p, d, a, :, :].mean(0)
                std = res[p, d, a, :, :].std(0)
                ax2.plot(t_sec, mean, label=r'$\mathbb{E}_\nu[\delta\phi]$')
                ax2.fill_between(t_sec, mean - std, mean + std, alpha=0.5, label=r'$\pm\sigma_{\delta\phi}$')
                ax2.set_xlabel('Time [mjs]')
                ax2.set_xlim(t_sec[0], t_sec[-1])
                ax2.set_ylim(-np.pi, np.pi)
            if img is not None:
                fig.colorbar(img, cax=cbar_ax, orientation='vertical', label='phase dev. [rad]')
            filename = "{}_v_{}_{}_{}.png".format(data_solset, solution_solset, axes['ant'][a], axes['pol'][p])
            fig.savefig(os.path.join(output_folder, filename))
            plt.close('all')
def plot_freq_vs_time(datapack, output_folder, solset='sol000', soltab='phase', phase_wrap=True, log_scale=False,
                      ant_sel=None, time_sel=None, dir_sel=None, freq_sel=None, pol_sel=None, vmin=None, vmax=None):
    """Plot freq-vs-time images of a soltab, one PNG per direction with one panel per antenna.

    Args:
        datapack: DataPack or filename.
        output_folder: directory receiving ``{solset}_{soltab}_dir_{k}.png`` (created if missing).
        solset: solset to read.
        soltab: soltab attribute name; ``weights_*`` soltabs are converted to uncertainties
            ``sqrt(1/weights)`` and never phase-wrapped.
        phase_wrap: wrap values into [-pi, pi] and use fixed limits and ``phase_cmap``.
        log_scale: plot ``log10`` of the values.
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``select``.
        vmin, vmax: colour limits when not phase-wrapping. When None, the 1st/99th percentiles of
            the finite plotted values are used (after any log10).
    """
    if isinstance(datapack, DataPack):
        datapack = datapack.filename
    with DataPack(datapack, readonly=True) as dp:
        dp.switch_solset(solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        obs, axes = dp.__getattr__(soltab)
        if soltab.startswith('weights_'):
            obs = np.sqrt(np.abs(1. / obs))  # uncert from weights = 1/var
            phase_wrap = False
        if 'pol' in axes:
            # plot only first pol selected
            obs = obs[0, ...]
        # obs is dir, ant, freq, time
        _, antennas = dp.get_antennas(axes['ant'])
        _, directions = dp.get_sources(axes['dir'])
        _, times = dp.get_times(axes['time'])
        _, freqs = dp.get_freqs(axes['freq'])

    if phase_wrap:
        obs = np.angle(np.exp(1j * obs))
        vmin, vmax = -np.pi, np.pi
        cmap = phase_cmap
    else:
        cmap = plt.cm.bone
    if log_scale:
        obs = np.log10(obs)
    if not phase_wrap:
        # Default limits must come from the values actually drawn (i.e. after log10); non-finite
        # values (e.g. log10(0)) would otherwise poison the percentiles.
        finite = obs[np.isfinite(obs)]
        if finite.size:
            vmin = vmin if vmin is not None else np.percentile(finite, 1)
            vmax = vmax if vmax is not None else np.percentile(finite, 99)

    Na = len(antennas)
    Nd = len(directions)
    M = max(1, int(np.ceil(np.sqrt(Na))))
    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)
    # Span the full frequency range (first to last channel).
    extent = (times[0].mjd * 86400., times[-1].mjd * 86400., freqs[0], freqs[-1])
    for k in range(Nd):
        filename = os.path.join(output_folder, "{}_{}_dir_{}.png".format(solset, soltab, k))
        logger.info("Plotting {}".format(filename))
        # squeeze=False keeps axs 2-D even for a single antenna (M == 1).
        fig, axs = plt.subplots(nrows=M, ncols=M, figsize=(4 * M, 4 * M), sharex=True, sharey=True,
                                squeeze=False)
        im = None
        for ant_idx in range(Na):
            im = axs[ant_idx // M][ant_idx % M].imshow(obs[k, ant_idx, :, :], origin='lower', cmap=cmap,
                                                       aspect='auto', vmin=vmin, vmax=vmax, extent=extent)
        plt.tight_layout()
        fig.subplots_adjust(right=0.8)
        if im is not None:
            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
            fig.colorbar(im, cax=cbar_ax)
        fig.savefig(filename)
        plt.close('all')
def plot_data_vs_solution(datapack, output_folder, data_solset='sol000', solution_solset='posterior_sol',
                          show_prior_uncert=False, ant_sel=None, time_sel=None, dir_sel=None, freq_sel=None,
                          pol_sel=None):
    """Overlay observed phase and TEC-predicted phase against time, one PNG per (pol, dir, ant, freq).

    The data uncertainty is ``sqrt(1/weights_phase)``. The solution phase is
    ``-8.448e9 * tec / freq`` with uncertainty ``sqrt(1/weights_tec) * |8.448e9 / freq|``. Both
    phases are wrapped into [-pi, pi].

    Args:
        datapack: DataPack or filename.
        output_folder: directory receiving the PNG files (created if missing).
        data_solset: solset with the observed phase.
        solution_solset: solset with the TEC solution.
        show_prior_uncert: also draw a +/-2 sigma band around the data.
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``select``.
    """
    def _wrap(phi):
        return np.angle(np.exp(1j * phi))

    if isinstance(datapack, DataPack):
        datapack = datapack.filename
    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    with DataPack(datapack, readonly=True) as dp:
        dp.switch_solset(data_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        weights, axes = dp.weights_phase
        _, freqs = dp.get_freqs(axes['freq'])
        phase, _ = dp.phase
        _, times = dp.get_times(axes['time'])
        data_phase = _wrap(phase)
        data_std = np.sqrt(np.abs(1. / weights))
        # TEC -> phase factor, shaped to broadcast onto [Npol, Nd, Na, Nf, Nt].
        tec_conversion = -8.4480e9 / freqs[None, None, None, :, None]

        dp.switch_solset(solution_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        weights, _ = dp.weights_tec
        tec, _ = dp.tec
        sol_std = np.sqrt(np.abs(1. / weights))[:, :, :, None, :] * np.abs(tec_conversion)
        sol_phase = _wrap(tec[:, :, :, None, :] * tec_conversion)

    for s, S in zip(sol_phase.shape, data_phase.shape):
        assert s == S

    Npol, Nd, Na, Nf, Nt = data_phase.shape
    t_mjd = times.mjd
    fig, ax = plt.subplots()
    for p in range(Npol):
        for d in range(Nd):
            for a in range(Na):
                for f in range(Nf):
                    ax.cla()
                    # Data
                    y = data_phase[p, d, a, f, :]
                    label = "{} {} {:.1f}MHz {}:{}".format(data_solset, axes['pol'][p], axes['freq'][f] / 1e6,
                                                           axes['ant'][a], axes['dir'][d])
                    if show_prior_uncert:
                        # Band width now matches its +/-2 sigma legend label.
                        dy = 2. * data_std[p, d, a, f, :]
                        ax.fill_between(t_mjd, y - dy, y + dy, alpha=0.5, label=r'$\pm2\hat{\sigma}_\phi$')
                    ax.scatter(t_mjd, y, marker='+', alpha=0.3, color='black', label=label)
                    # Solution
                    y = sol_phase[p, d, a, f, :]
                    dy = sol_std[p, d, a, f, :]
                    ax.fill_between(t_mjd, y - dy, y + dy, alpha=0.5, label=r'$\pm\hat{\sigma}_\phi$')
                    ax.scatter(t_mjd, y, label="Solution: {}".format(solution_solset), marker='.', s=5.)
                    ax.set_xlabel('Time [mjd]')
                    ax.set_ylabel('Phase deviation [rad.]')
                    ax.legend()
                    filename = "{}_v_{}_{}_{}_{}_{}MHz.png".format(data_solset, solution_solset, axes['ant'][a],
                                                                   axes['dir'][d], axes['pol'][p],
                                                                   axes['freq'][f] / 1e6)
                    ax.set_ylim(-np.pi, np.pi)
                    fig.savefig(os.path.join(output_folder, filename))
    plt.close('all')
def __init__(self, datapack):
    """Store the datapack to work on.

    Args:
        datapack: a DataPack instance, or a filename that is opened as a read-only DataPack.
    """
    self.datapack = DataPack(filename=datapack, readonly=True) if isinstance(datapack, str) else datapack