Beispiel #1
0
def plot_phase_vs_time_per_datapack(datapacks, output_folder, solsets='sol000',
                                    ant_sel=None, time_sel=None, dir_sel=None, freq_sel=None, pol_sel=None):
    """
    Overlay phase-vs-time scatter plots of several datapacks and save one PNG per (pol, dir, ant, freq).

    Args:
        datapacks: a datapack filename, or a list/tuple of them.
        output_folder: directory the PNG files are written to; created if it does not exist.
        solsets: a solset name, or a list/tuple of names paired element-wise with ``datapacks``.
            A single name is used for every datapack.
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``DataPack.select``.

    Raises:
        ValueError: if no datapack is given, if ``solsets`` cannot be paired with ``datapacks``,
            or if the selected phase arrays do not all have the same shape.
    """
    if not isinstance(solsets, (list, tuple)):
        solsets = [solsets]
    if not isinstance(datapacks, (list, tuple)):
        datapacks = [datapacks]
    if not datapacks:
        raise ValueError("At least one datapack is required.")
    # Broadcast a single solset (e.g. the default) over all datapacks; zip() alone would
    # silently plot only the first datapack.
    if len(solsets) == 1:
        solsets = list(solsets) * len(datapacks)
    if len(solsets) != len(datapacks):
        raise ValueError("Got {} solsets for {} datapacks.".format(len(solsets), len(datapacks)))

    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    phases = []
    for solset, path in zip(solsets, datapacks):
        with DataPack(path, readonly=True) as dp:
            dp.current_solset = solset
            dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
            phase, axes = dp.phase
            _, times = dp.get_times(axes['time'])
        phases.append(phase)

    reference_shape = phases[0].shape
    for path, phase in zip(datapacks, phases):
        if phase.shape != reference_shape:
            raise ValueError("Phase shape {} of {} does not match {}.".format(phase.shape, path, reference_shape))
    Npol, Nd, Na, Nf, Nt = reference_shape

    # Axis labels and times come from the last datapack read; the shape check above only
    # guarantees that all grids have the same size.
    t_mjd = times.mjd
    fig, ax = plt.subplots()
    try:
        for p in range(Npol):
            for d in range(Nd):
                for a in range(Na):
                    for f in range(Nf):
                        ax.cla()
                        for path, solset, phase in zip(datapacks, solsets, phases):
                            label = "{} {} {} {:.1f}MHz {}:{}".format(os.path.basename(path), solset,
                                                                      axes['pol'][p], axes['freq'][f] / 1e6,
                                                                      axes['ant'][a], axes['dir'][d])
                            ax.scatter(t_mjd, phase[p, d, a, f, :], marker='+', alpha=0.3,
                                       label=label)

                        ax.set_xlabel('Time [mjd]')
                        ax.set_ylabel('Phase deviation [rad.]')
                        ax.legend()
                        filename = "{}_{}_{}_{}MHz.png".format(axes['ant'][a], axes['dir'][d], axes['pol'][p],
                                                               axes['freq'][f] / 1e6)
                        # Save `fig` explicitly; it must stay open until every panel is written.
                        fig.savefig(os.path.join(output_folder, filename))
    finally:
        plt.close(fig)
Beispiel #2
0
def visualisation(h5parm, ant=None, time=None):
    """
    Show dTEC Voronoi maps over direction for every selected antenna except the first, at every selected time.

    Args:
        h5parm: path of the h5parm to read (solset 'sol000', opened read-only).
        ant: antenna selection forwarded to ``DataPack.select``.
        time: time selection forwarded to ``DataPack.select``.

    The figure has one row per plotted antenna and one column per time and is shown with
    ``plt.show()``; nothing is saved.
    """
    with DataPack(h5parm, readonly=True) as dp:
        dp.current_solset = 'sol000'
        dp.select(ant=ant, time=time)
        dtec, axes = dp.tec
        dtec = dtec[0]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])

    # `time` is the selection argument (None / index / slice), not an epoch; the ENU frame
    # needs a real observation time, so use the first selected timestamp.
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    directions = directions.transform_to(frame)
    t = times.mjd * 86400.
    t -= t[0]
    # The first antenna is left out; drop it from dTEC as well so that row `a` of the figure,
    # row `a` of `x` and dtec[:, a, :] all refer to the same antenna.
    x = antennas.cartesian.xyz.to(au.km).value.T[1:, :]
    dtec = dtec[:, 1:, :]
    k = directions.cartesian.xyz.value.T
    logger.info(f"Directions: {directions}")
    logger.info(f"Antennas: {x} {antenna_labels}")
    logger.info(f"Times: {t}")
    Na = x.shape[0]
    logger.info(f"Number of antenna to plot: {Na}")
    Nt = t.shape[0]
    if Na == 0 or Nt == 0:
        logger.warning("Nothing to plot: need at least two antennas and one time.")
        return

    fig, axs = plt.subplots(Na,
                            Nt,
                            sharex=True,
                            sharey=True,
                            figsize=(2 * Nt, 2 * Na),
                            squeeze=False)

    for a in range(Na):
        for i in range(Nt):
            ax = plot_vornoi_map(k[:, 0:2],
                                 dtec[:, a, i],
                                 ax=axs[a][i],
                                 colorbar=False)
            if a == (Na - 1):
                ax.set_xlabel(r"$k_{\rm east}$")
            if i == 0:
                ax.set_ylabel(r"$k_{\rm north}$")
            if a == 0:
                ax.set_title(f"Time: {int(t[i])} sec")

    plt.show()
Beispiel #3
0
def get_data(solution_file, select=None):
    """
    Read phase, phase-outlier flags and amplitudes from a DDS4 solution file and smooth the amplitudes.

    Args:
        solution_file: path of the h5parm to read (opened read-only).
        select: optional dict of selections forwarded to ``DataPack.select``. It must keep a
            leading pol axis, because index 0 of that axis is taken. Defaults to the subset used
            so far: first pol, antenna 50, all directions, first 100 times.

    Returns:
        gain_outliers: bool array [Nd, Na, Nf, Nt], True where the phase weight equals 1.
        phase: [Nd, Na, Nf, Nt] phases of the first selected pol.
        amp: [Nd, Na, Nf, Nt] amplitudes, smoothed in log space with cubic polynomial fits
            (``poly_smooth``) along time and then along frequency. Outliers get weight 0.
        times: [Nt] seconds relative to the first selected time.
        freqs: [Nf] frequencies in Hz.
    """
    logger.info("Getting DDS4 data.")
    if select is None:
        select = dict(pol=slice(0, 1, 1),
                      ant=slice(50, 51),
                      dir=slice(0, None, 1),
                      time=slice(0, 100, 1))
    with DataPack(solution_file, readonly=True) as h:
        h.select(**select)
        phase, axes = h.phase
        phase = phase[0, ...]
        gain_outliers, _ = h.weights_phase
        gain_outliers = gain_outliers[0, ...] == 1
        amp, axes = h.amplitude
        amp = amp[0, ...]
        _, freqs = h.get_freqs(axes['freq'])
        freqs = freqs.to(au.Hz).value
        _, times = h.get_times(axes['time'])
        # Make times relative in float64 NumPy *before* converting to JAX. Absolute MJD seconds
        # (~5e9) are spaced several minutes apart in float32, which JAX uses unless x64 is enabled.
        times = times.mjd * 86400.
        times = jnp.asarray(times - times[0], dtype=jnp.float64)
        logger.info("Shape: {}".format(phase.shape))

        (Nd, Na, Nf, Nt) = phase.shape

        @jit
        def smooth(amp, outliers):
            '''
            Smooth amplitudes in log space, first along time then along frequency.

            Args:
                amp: [Nt, Nf]
                outliers: [Nt, Nf] bool, outliers get zero weight in the fits.
            '''
            weights = jnp.where(outliers, 0., 1.)
            log_amp = jnp.log(amp)
            log_amp = vmap(lambda log_amp, weights: poly_smooth(
                times, log_amp, deg=3, weights=weights))(log_amp.T,
                                                         weights.T).T
            log_amp = vmap(lambda log_amp, weights: poly_smooth(
                freqs, log_amp, deg=3, weights=weights))(log_amp, weights)
            amp = jnp.exp(log_amp)
            return amp

        logger.info("Smoothing amplitudes")
        amp = chunked_pmap(smooth,
                           amp.reshape((Nd * Na, Nf, Nt)).transpose((0, 2, 1)),
                           gain_outliers.reshape((Nd * Na, Nf, Nt)).transpose(
                               (0, 2, 1)))  # Nd*Na,Nt,Nf
        amp = amp.transpose((0, 2, 1)).reshape((Nd, Na, Nf, Nt))
    return gain_outliers, phase, amp, times, freqs
Beispiel #4
0
    def visualise_grid(self):
        """
        Interactively show first-pol TEC at time index 0 of ``self._input_datapack`` (solset 'sol000').

        First, for directions 0-4, scatter the antenna layout (East/North in km, ENU frame
        centred on the first antenna) coloured by TEC. Then, for antennas 0, 50, 150, 200 and 250,
        show that antenna on top of the last layout plot. Below it, a Voronoi map of the antenna's
        TEC over direction (RA/Dec in degrees) is drawn.
        """

        def _load(**selection):
            # Read TEC (pol axis dropped), directions, antennas and times for one selection.
            with DataPack(self._input_datapack, readonly=True) as dp:
                dp.current_solset = 'sol000'
                dp.select(pol=slice(0, 1, 1), time=0, **selection)
                tec, axes = dp.tec
                _, directions = dp.get_directions(axes['dir'])
                _, antennas = dp.get_antennas(axes['ant'])
                _, times = dp.get_times(axes['time'])
            return tec[0], directions, antennas, times

        def _to_enu(antennas, reference, obstime):
            # Express antenna positions in an ENU frame located at `reference`.
            frame = ENU(location=reference.earth_location, obstime=obstime)
            return ac.ITRS(*antennas.cartesian.xyz, obstime=obstime).transform_to(frame)

        for d in range(0, 5):
            tec_grid, directions_grid, antennas_grid, times_grid = _load(dir=d)
            ref_ant = antennas_grid[0]
            antennas_grid = _to_enu(antennas_grid, ref_ant, times_grid[0])

            ant_pos = antennas_grid.cartesian.xyz.to(au.km).value.T
            plt.scatter(ant_pos[:, 0],
                        ant_pos[:, 1],
                        c=tec_grid[0, :, 0],
                        cmap=plt.cm.PuOr)
            plt.xlabel('East [km]')
            plt.ylabel("North [km]")
            plt.title(f"Direction {repr(directions_grid)}")
            plt.show()

        # Background layout for the per-antenna figures: the last direction plotted above.
        ant_scatter_args = (ant_pos[:, 0], ant_pos[:, 1], tec_grid[0, :, 0])

        for a in [0, 50, 150, 200, 250]:
            tec_grid, directions_grid, antennas_grid, times_grid = _load(ant=a)
            antennas_grid = _to_enu(antennas_grid, ref_ant, times_grid[0])
            _ant_pos = antennas_grid.cartesian.xyz.to(au.km).value.T[0]

            fig, axs = plt.subplots(2, 1, figsize=(4, 8))
            axs[0].scatter(*ant_scatter_args[0:2],
                           c=ant_scatter_args[2],
                           cmap=plt.cm.PuOr,
                           alpha=0.5)
            axs[0].scatter(*_ant_pos[0:2], marker='x', c='red')
            axs[0].set_xlabel('East [km]')
            axs[0].set_ylabel('North [km]')
            pos = 180 / np.pi * np.stack(
                [wrap(directions_grid.ra.rad),
                 wrap(directions_grid.dec.rad)],
                axis=-1)
            plot_vornoi_map(pos, tec_grid[:, 0, 0], fov_circle=True, ax=axs[1])
            axs[1].set_xlabel('RA(2000) [deg]')
            axs[1].set_ylabel('DEC(2000) [deg]')
            plt.show()
Beispiel #5
0
def _mark_outliers(ax, times, outliers):
    """Draw red vertical lines at the outlier times without changing the y-limits of ``ax``."""
    ylim = ax.get_ylim()
    ax.vlines(times[outliers],
              *ylim,
              colors='red',
              label='outliers',
              alpha=0.5)
    ax.set_ylim(*ylim)


def _plot_solutions(data_plot_dir, times, tec_mean, tec_std, tec_outliers, const_mean, const_std, ia, id_):
    """Save the tec / const / tec-uncertainty time series of one (antenna, direction) pair."""
    outliers = tec_outliers[id_, ia, :]
    fig, axs = plt.subplots(3, 1, sharex=True)
    axs[0].plot(times, tec_mean[id_, ia, :], c='black', label='tec')
    _mark_outliers(axs[0], times, outliers)

    axs[1].plot(times,
                const_mean[id_, ia, :],
                c='black',
                label='const')
    axs[1].fill_between(
        times,
        const_mean[id_, ia, :] - const_std[id_, ia, :],
        const_mean[id_, ia, :] + const_std[id_, ia, :],
        color='black',
        alpha=0.2)
    _mark_outliers(axs[1], times, outliers)

    axs[2].plot(times,
                tec_std[id_, ia, :],
                c='black',
                label='tec_std')
    _mark_outliers(axs[2], times, outliers)

    for ax, ylabel in zip(axs, ["DTEC [mTECU]", "const [rad]", "DTEC uncert [mTECU]"]):
        ax.legend()
        ax.set_ylabel(ylabel)
    axs[2].set_xlabel("time [s]")

    fig.savefig(
        os.path.join(
            data_plot_dir,
            'solutions_ant{:02d}_dir{:02d}.png'.format(ia, id_)))
    plt.close("all")


def _imshow_panel(ax, image, vmin, vmax, cmap):
    """Draw a (freq, time) image with the styling shared by all comparison panels."""
    ax.imshow(image,
              vmin=vmin,
              vmax=vmax,
              cmap=cmap,
              aspect='auto',
              origin='lower',
              interpolation='nearest')


def _plot_data_comparison(data_plot_dir, phase_obs, gain_outliers, phase_mean, phase_uncert, dphase, ia, id_):
    """Save the 4-panel data / model / residual / uncertainty comparison of one (antenna, direction) pair."""
    fig, axs = plt.subplots(4, 1, sharex=True, sharey=True)

    # phase data, with the input gain outliers overlaid
    _imshow_panel(axs[0], phase_obs[id_, ia, :, :], -jnp.pi, jnp.pi, 'twilight')
    _imshow_panel(axs[0], jnp.where(gain_outliers[id_, ia, :, :], 1., jnp.nan), 0., 1., 'bone')
    add_colorbar_to_axes(axs[0], "twilight", vmin=-jnp.pi, vmax=jnp.pi)

    # phase posterior, with infinite-uncertainty samples overlaid
    _imshow_panel(axs[1], phase_mean[id_, ia, :, :], -jnp.pi, jnp.pi, 'twilight')
    _imshow_panel(axs[1], jnp.where(jnp.isinf(phase_uncert[id_, ia, :, :]), 1., jnp.nan), 0., 1., 'bone')
    add_colorbar_to_axes(axs[1], "twilight", vmin=-jnp.pi, vmax=jnp.pi)

    # wrapped model - data residual
    _imshow_panel(axs[2], dphase[id_, ia, :, :], -0.5, 0.5, 'PuOr')
    add_colorbar_to_axes(axs[2], "PuOr", vmin=-0.5, vmax=0.5)

    # phase uncertainty
    _imshow_panel(axs[3], phase_uncert[id_, ia, :, :], 0., 0.8, 'PuOr')
    add_colorbar_to_axes(axs[3], "PuOr", vmin=0., vmax=0.8)

    titles = ["phase data [rad]", "phase model [rad]", "phase diff. [rad]", "phase uncert [rad]"]
    for ax, title in zip(axs, titles):
        ax.set_ylabel("freq [MHz]")
        ax.set_title(title)
    axs[3].set_xlabel("time [s]")

    fig.savefig(
        os.path.join(
            data_plot_dir,
            'data_comparison_ant{:02d}_dir{:02d}.png'.format(
                ia, id_)))
    plt.close("all")


def _animate_results(dds5_h5parm, working_dir, num_processes):
    """Render animations of the stored tec, const, clock and amplitude soltabs into ``working_dir``."""
    common = dict(phase_wrap=False,
                  plot_crosses=False,
                  plot_facet_idx=True,
                  labels_in_radec=True,
                  per_timestep_scale=True,
                  solset='sol000',
                  cmap=plt.cm.PuOr)
    per_observable = [
        ('tec_plots', dict(vmin=-60, vmax=60., observable='tec')),
        ('const_plots', dict(vmin=-np.pi, vmax=np.pi, observable='const')),
        ('clock_plots', dict(vmin=None, vmax=None, observable='clock')),
        ('amplitude_plots', dict(log_scale=True, observable='amplitude')),
    ]
    for subdir, kwargs in per_observable:
        animate_datapack(dds5_h5parm,
                         os.path.join(working_dir, subdir),
                         num_processes=num_processes,
                         **kwargs,
                         **common)


def main(data_dir, working_dir, obs_num, ncpu, plot_results):
    """
    Smooth DDS4 solutions via tec+const inference, store them in a DDS5 h5parm, and optionally plot them.

    Reads ``L{obs_num}_DDS4_full_merged.h5`` from ``data_dir`` and prepares
    ``L{obs_num}_DDS5_full_merged.h5`` in ``working_dir``. That file is tied to a same-named path
    in ``data_dir`` via ``link_overwrite`` (presumably a symlink; confirm its direction). It then
    writes smoothed phase, ``phase_uncert`` (as weights_phase), amplitude, tec, tec outliers,
    ``tec_std`` (as weights_tec) and const into solset 'sol000' for the first pol.

    Args:
        data_dir: directory holding the DDS4 h5parm.
        working_dir: directory for the DDS5 h5parm and all plots.
        obs_num: observation number used in the file names.
        ncpu: number of XLA host devices requested; also sets the animation worker count.
        plot_results: if true, write per-(antenna, direction) PNGs and solution animations.
    """
    # NOTE(review): XLA_FLAGS only takes effect if set before JAX initialises its backend.
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
    logger.info("Performing data smoothing via tec+const+clock inference.")
    dds4_h5parm = os.path.join(data_dir,
                               'L{}_DDS4_full_merged.h5'.format(obs_num))
    dds5_h5parm = os.path.join(working_dir,
                               'L{}_DDS5_full_merged.h5'.format(obs_num))
    linked_dds5_h5parm = os.path.join(
        data_dir, 'L{}_DDS5_full_merged.h5'.format(obs_num))
    logger.info("Looking for {}".format(dds4_h5parm))
    link_overwrite(dds5_h5parm, linked_dds5_h5parm)
    prepare_soltabs(dds4_h5parm, dds5_h5parm)
    gain_outliers, phase_obs, amp, times, freqs = get_data(
        solution_file=dds4_h5parm)
    phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers, const_mean, const_std = \
        solve_and_smooth(gain_outliers, phase_obs, times, freqs)
    logger.info("Storing smoothed phase, amplitudes, tec, const, and clock")
    with DataPack(dds5_h5parm, readonly=False) as h:
        h.current_solset = 'sol000'
        h.select(pol=slice(0, 1, 1))
        h.phase = np.asarray(phase_mean)[None, ...]
        h.weights_phase = np.asarray(phase_uncert)[None, ...]
        h.amplitude = np.asarray(amp)[None, ...]
        h.tec = np.asarray(tec_mean)[None, ...]
        h.tec_outliers = np.asarray(tec_outliers)[None, ...]
        h.weights_tec = np.asarray(tec_std)[None, ...]
        h.const = np.asarray(const_mean)[None, ...]

    if not plot_results:
        return

    diagnostic_data_dir = os.path.join(working_dir, 'diagnostic')
    os.makedirs(diagnostic_data_dir, exist_ok=True)

    logger.info("Plotting results.")
    data_plot_dir = os.path.join(working_dir, 'data_plots')
    os.makedirs(data_plot_dir, exist_ok=True)
    Nd, Na, Nf, Nt = phase_mean.shape
    # The residual covers the whole array; compute it once rather than per (antenna, direction).
    dphase = wrap(wrap(phase_mean) - phase_obs)
    for ia in range(Na):
        for id_ in range(Nd):
            _plot_solutions(data_plot_dir, times, tec_mean, tec_std, tec_outliers,
                            const_mean, const_std, ia, id_)
            _plot_data_comparison(data_plot_dir, phase_obs, gain_outliers, phase_mean,
                                  phase_uncert, dphase, ia, id_)

    # (ncpu * 2) // 3 is 0 for ncpu == 1; always ask for at least one worker.
    _animate_results(dds5_h5parm, working_dir, max(1, (ncpu * 2) // 3))
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches):
    """
    Fit a NeuralTomographicKernel to TomographicKernel evaluations by plain gradient descent.

    Coordinates are built from the datapack's antennas and directions (solset 'sol000', first pol,
    first time). Antennas are in ENU km relative to the first antenna. The directions are
    augmented with 250 random unit directions. Each step draws a random batch of coordinates and
    random bottom/width/length-scale/wind parameters. The loss is the mean squared difference
    between the two kernel matrices divided by width**2. The loss curve is shown with
    ``plt.show()``.

    Args:
        datapack: DataPack used as a context manager to read the geometry.
        batch_size: number of coordinates per batch.
        learning_rate: step size of the gradient-descent update.
        num_batches: number of gradient steps.

    Returns:
        (final_params, losses): the trained neural-kernel parameters and the per-step losses.
    """
    # tree_multimap was removed from JAX; tree_map accepts several trees.
    from jax.tree_util import tree_map

    with datapack:
        select = dict(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        datapack.current_solset = 'sol000'
        datapack.select(**select)
        axes = datapack.axes_tec
        patch_names, directions = datapack.get_directions(axes['dir'])
        antenna_labels, antennas = datapack.get_antennas(axes['ant'])
        timestamps, times = datapack.get_times(axes['time'])

    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    # Seconds relative to the middle selected time.
    t = times.mjd
    t -= t[len(t) // 2]
    t *= 86400.
    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar /= jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(
        make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))
    # The first antenna (ENU, km) serves as both origin and reference antenna of the kernels.
    x0 = jnp.asarray(x[0, :])
    ref_ant = x0

    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        keys = random.split(key, 5)
        indices = random.permutation(keys[0],
                                     jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]

        wind_velocity = random.uniform(keys[1],
                                       shape=(3, ),
                                       minval=jnp.asarray([-200., -200., 0.]),
                                       maxval=jnp.asarray([200., 200., 0.
                                                           ])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        l = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch,
                   X_batch,
                   bottom,
                   width,
                   l,
                   sigma,
                   wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch,
                                 X_batch,
                                 bottom,
                                 width,
                                 l,
                                 sigma,
                                 wind_velocity=wind_velocity)

        return jnp.mean((K - neural_K)**2) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        loss_value, grads = value_and_grad(lambda params: loss(params, key))(params)
        params = tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return params, loss_value

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(
            random.PRNGKey(42))

    plt.plot(losses)
    plt.yscale('log')
    plt.show()
    return final_params, losses
Beispiel #7
0
def plot_solution_residuals(datapack, output_folder, data_solset='sol000', solution_solset='posterior_sol',
                            ant_sel=None, time_sel=None, dir_sel=None, freq_sel=None, pol_sel=None):
    """
    Plot wrapped residuals between observed phases and TEC-model phases, one PNG per (pol, antenna).

    For each direction the figure holds an image of the residual over (time, frequency). Below it
    is the frequency-averaged residual with a +-1 std band (std over frequency).

    Args:
        datapack: datapack filename, or an object with a ``filename`` attribute.
        output_folder: directory the PNG files are written to; created if missing.
        data_solset: solset holding the observed ``phase``.
        solution_solset: solset holding the ``tec`` solutions.
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``DataPack.select``
            for both solsets.
    """
    def _wrap(phi):
        return np.angle(np.exp(1j * phi))

    if not isinstance(datapack, str):
        datapack = datapack.filename

    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    with DataPack(datapack, readonly=True) as dp:
        dp.switch_solset(data_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        phase, axes = dp.phase
        _, times = dp.get_times(axes['time'])
        _, freqs = dp.get_freqs(axes['freq'])
        Npol, Nd, Na, Nf, Nt = phase.shape

        dp.switch_solset(solution_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        tec, _ = dp.tec
        # TEC -> phase: phi = -8.448e9 * TEC / nu, broadcast over a new frequency axis.
        phase_pred = -8.448e9 * tec[..., None, :] / freqs[:, None]

        res = _wrap(_wrap(phase) - _wrap(phase_pred))

        t_sec = times.mjd * 86400.
        extent = (t_sec[0], t_sec[-1], freqs[0], freqs[-1])
        norm = plt.Normalize(-1., 1.)
        M = int(np.ceil(np.sqrt(Nd)))

        for p in range(Npol):
            for a in range(Na):
                # squeeze=False keeps `axs` 2-D even for a single direction (2x1 grid).
                fig, axs = plt.subplots(nrows=2 * M, ncols=M, sharex=True, figsize=(M * 4, M * 4),
                                        gridspec_kw={'height_ratios': [1.5, 1] * M}, squeeze=False)
                fig.subplots_adjust(wspace=0., hspace=0.)
                fig.subplots_adjust(right=0.85)
                cbar_ax = fig.add_axes([0.875, 0.15, 0.025, 0.7])

                img = None
                for row in range(0, 2 * M, 2):
                    for col in range(M):
                        ax1 = axs[row][col]
                        ax2 = axs[row + 1][col]

                        d = col + row // 2 * M
                        if d >= Nd:
                            # Hide the unused cells of the square grid.
                            ax1.axis('off')
                            ax2.axis('off')
                            continue

                        img = ax1.imshow(res[p, d, a, :, :], origin='lower', aspect='auto',
                                         extent=extent, cmap=plt.cm.jet, norm=norm)
                        ax1.text(0.05, 0.95, axes['dir'][d], horizontalalignment='left', verticalalignment='top',
                                 transform=ax1.transAxes, backgroundcolor=(1., 1., 1., 0.5))
                        ax1.set_ylabel('frequency [Hz]')

                        mean = res[p, d, a, :, :].mean(0)
                        std = res[p, d, a, :, :].std(0)
                        ax2.plot(t_sec, mean, label=r'$\mathbb{E}_\nu[\delta\phi]$')
                        ax2.fill_between(t_sec, mean - std, mean + std, alpha=0.5,
                                         label=r'$\pm\sigma_{\delta\phi}$')
                        ax2.set_xlabel('Time [mjs]')
                        ax2.set_xlim(t_sec[0], t_sec[-1])
                        ax2.set_ylim(-np.pi, np.pi)

                if img is not None:
                    fig.colorbar(img, cax=cbar_ax, orientation='vertical', label='phase dev. [rad]')
                filename = "{}_v_{}_{}_{}.png".format(data_solset, solution_solset, axes['ant'][a], axes['pol'][p])
                fig.savefig(os.path.join(output_folder, filename))
                plt.close(fig)
Beispiel #8
0
def plot_freq_vs_time(datapack, output_folder, solset='sol000', soltab='phase', phase_wrap=True, log_scale=False,
                      ant_sel=None, time_sel=None, dir_sel=None, freq_sel=None, pol_sel=None, vmin=None, vmax=None):
    """
    Save one grid image per direction showing a soltab over (time, frequency) for every antenna.

    Args:
        datapack: datapack filename, or a DataPack (its ``filename`` is reopened read-only).
        output_folder: directory the PNG files are written to; created if missing.
        solset: solset to read.
        soltab: soltab attribute to read, e.g. 'phase'. Names starting with 'weights_' are
            converted to uncertainties sqrt(|1/w|) and never phase-wrapped.
        phase_wrap: wrap values onto (-pi, pi] and use the fixed colour range [-pi, pi].
        log_scale: plot log10 of the values.
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``DataPack.select``.
            When the soltab has a pol axis, only the first selected pol is plotted.
        vmin, vmax: colour limits. When None (and not phase-wrapping), the 1st/99th percentiles
            of the plotted values are used. Overridden when phase wrapping is in effect.
    """
    if isinstance(datapack, DataPack):
        datapack = datapack.filename

    with DataPack(datapack, readonly=True) as dp:
        dp.switch_solset(solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        obs, axes = getattr(dp, soltab)
        if soltab.startswith('weights_'):
            obs = np.sqrt(np.abs(1. / obs))  # uncert from weights = 1/var
            phase_wrap = False
        if 'pol' in axes:
            # plot only first pol selected
            obs = obs[0, ...]

        # obs is dir, ant, freq, time
        _, antennas = dp.get_antennas(axes['ant'])
        _, directions = dp.get_sources(axes['dir'])
        _, times = dp.get_times(axes['time'])
        _, freqs = dp.get_freqs(axes['freq'])

        if phase_wrap:
            obs = np.angle(np.exp(1j * obs))
            vmin = -np.pi
            vmax = np.pi
            cmap = phase_cmap
        else:
            cmap = plt.cm.bone
        if log_scale:
            obs = np.log10(obs)
        if not phase_wrap:
            # Default limits must come from the values actually drawn, i.e. after log10.
            vmin = vmin if vmin is not None else np.percentile(obs.flatten(), 1)
            vmax = vmax if vmax is not None else np.percentile(obs.flatten(), 99)

        Na = len(antennas)
        Nd = len(directions)
        M = int(np.ceil(np.sqrt(Na)))
        extent = (times[0].mjd * 86400., times[-1].mjd * 86400., freqs[0], freqs[-1])

        output_folder = os.path.abspath(output_folder)
        os.makedirs(output_folder, exist_ok=True)
        for k in range(Nd):
            filename = os.path.join(output_folder, "{}_{}_dir_{}.png".format(solset, soltab, k))
            logger.info("Plotting {}".format(filename))
            # squeeze=False keeps `axs` 2-D even for a single antenna (1x1 grid).
            fig, axs = plt.subplots(nrows=M, ncols=M, figsize=(4 * M, 4 * M), sharex=True, sharey=True,
                                    squeeze=False)
            im = None
            for ant_idx in range(Na):
                i, j = divmod(ant_idx, M)
                im = axs[i][j].imshow(obs[k, ant_idx, :, :], origin='lower', cmap=cmap, aspect='auto',
                                      vmin=vmin, vmax=vmax, extent=extent)
            fig.tight_layout()
            fig.subplots_adjust(right=0.8)
            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
            if im is not None:
                fig.colorbar(im, cax=cbar_ax)
            fig.savefig(filename)
            # Close each figure right away instead of accumulating one per direction.
            plt.close(fig)
Beispiel #9
0
def plot_data_vs_solution(datapack, output_folder, data_solset='sol000', solution_solset='posterior_sol',
                          show_prior_uncert=False,
                          ant_sel=None, time_sel=None, dir_sel=None, freq_sel=None, pol_sel=None):
    """
    Compare observed phases with TEC-solution phases over time, one PNG per (pol, dir, ant, freq).

    Args:
        datapack: datapack filename, or a DataPack (its ``filename`` is reopened read-only).
        output_folder: directory the PNG files are written to; created if missing.
        data_solset: solset holding the observed ``phase`` and ``weights_phase``.
        solution_solset: solset holding ``tec`` and ``weights_tec``.
        show_prior_uncert: also draw a +-1 std band around the data. The std is derived from the
            phase weights as sqrt(|1/w|).
        ant_sel, time_sel, dir_sel, freq_sel, pol_sel: selections forwarded to ``DataPack.select``
            for both solsets.

    Raises:
        ValueError: if the data phases and the TEC-derived phases have different shapes.
    """
    def _wrap(phi):
        return np.angle(np.exp(1j * phi))

    if isinstance(datapack, DataPack):
        datapack = datapack.filename

    output_folder = os.path.abspath(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    with DataPack(datapack, readonly=True) as dp:
        # Data: observed phases; weights are treated as inverse variances.
        dp.switch_solset(data_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        weights, axes = dp.weights_phase
        _, freqs = dp.get_freqs(axes['freq'])
        phase, _ = dp.phase
        _, times = dp.get_times(axes['time'])
        data_phase = _wrap(phase)
        data_std = np.sqrt(np.abs(1. / weights))

        # Solution: TEC mapped to phase with phi = -8.448e9 * TEC / nu.
        tec_conversion = -8.4480e9 / freqs[None, None, None, :, None]

        dp.switch_solset(solution_solset)
        dp.select(ant=ant_sel, time=time_sel, dir=dir_sel, freq=freq_sel, pol=pol_sel)
        weights, _ = dp.weights_tec
        tec, _ = dp.tec
        sol_std = np.sqrt(np.abs(1. / weights))[:, :, :, None, :] * np.abs(tec_conversion)
        sol_phase = _wrap(tec[:, :, :, None, :] * tec_conversion)

        if sol_phase.shape != data_phase.shape:
            raise ValueError("Solution phase shape {} does not match data phase shape {}.".format(
                sol_phase.shape, data_phase.shape))
        Npol, Nd, Na, Nf, Nt = data_phase.shape

        t_mjd = times.mjd
        fig, ax = plt.subplots()
        for p in range(Npol):
            for d in range(Nd):
                for a in range(Na):
                    for f in range(Nf):
                        ax.cla()
                        # Data
                        label = "{} {} {:.1f}MHz {}:{}".format(data_solset, axes['pol'][p], axes['freq'][f] / 1e6,
                                                               axes['ant'][a], axes['dir'][d])
                        if show_prior_uncert:
                            # The band spans +-1 estimated std, so label it that way.
                            ax.fill_between(t_mjd, data_phase[p, d, a, f, :] - data_std[p, d, a, f, :],
                                            data_phase[p, d, a, f, :] + data_std[p, d, a, f, :], alpha=0.5,
                                            label=r'Data $\pm\hat{\sigma}_\phi$')
                        ax.scatter(t_mjd, data_phase[p, d, a, f, :], marker='+', alpha=0.3, color='black',
                                   label=label)

                        # Solution
                        ax.fill_between(t_mjd, sol_phase[p, d, a, f, :] - sol_std[p, d, a, f, :],
                                        sol_phase[p, d, a, f, :] + sol_std[p, d, a, f, :], alpha=0.5,
                                        label=r'$\pm\hat{\sigma}_\phi$')
                        ax.scatter(t_mjd, sol_phase[p, d, a, f, :], label="Solution: {}".format(solution_solset),
                                   marker='.', s=5.)

                        ax.set_xlabel('Time [mjd]')
                        ax.set_ylabel('Phase deviation [rad.]')
                        ax.legend()
                        filename = "{}_v_{}_{}_{}_{}_{}MHz.png".format(data_solset, solution_solset, axes['ant'][a],
                                                                       axes['dir'][d], axes['pol'][p],
                                                                       axes['freq'][f] / 1e6)
                        ax.set_ylim(-np.pi, np.pi)
                        fig.savefig(os.path.join(output_folder, filename))
        plt.close(fig)
Beispiel #10
0
 def __init__(self, datapack):
     """
     Store the datapack this object works on.

     Args:
         datapack: a filename string, in which case a read-only ``DataPack`` is opened on it.
             Any other value is stored unchanged (presumably already a DataPack; confirm).
     """
     if isinstance(datapack, str):
         # A path was given: wrap it in a read-only DataPack.
         datapack = DataPack(filename=datapack, readonly=True)
     self.datapack = datapack