def run_solve(flags):
    """Simulate an example datapack and run the LMC phase-only solver on it.

    Args:
        flags: namespace-like object. The attributes read here are
            ``output_dir``, ``sim_Nd``, ``sim_Nf``, ``sim_Nt``,
            ``sim_time_corr``, ``sim_dir_corr``, ``sim_tec_scale``,
            ``sim_tec_noise`` and ``sim_clobber``. Every attribute of
            ``flags`` is additionally forwarded to ``solver.run`` as a
            keyword argument.
    """
    # Resolve the output directory and create it if it is missing.
    output_folder = os.path.abspath(flags.output_dir)
    os.makedirs(output_folder, exist_ok=True)

    # The datapack file name encodes the simulation parameters.
    basename = "datapack_{:.1e},_{:.1e}_{:.1e}_{:.1e}.hdf5".format(
        flags.sim_tec_scale,
        flags.sim_tec_noise,
        flags.sim_time_corr,
        flags.sim_dir_corr,
    )
    filename = os.path.join(output_folder, basename)

    simulation_options = dict(
        pols=['XX'],
        time_corr=flags.sim_time_corr,
        # Scaled by pi/180 -- presumably a degrees-to-radians conversion.
        dir_corr=flags.sim_dir_corr * np.pi / 180.,
        tec_scale=flags.sim_tec_scale,
        tec_noise=flags.sim_tec_noise,
        name=filename,
        clobber=flags.sim_clobber,
    )
    datapack = make_example_datapack(
        flags.sim_Nd, flags.sim_Nf, flags.sim_Nt, **simulation_options)
    logging.info(datapack)

    solver = LMCPhaseOnlySolver(output_folder, datapack)
    solver.run(**vars(flags))
예제 #2
0
 def _freq_sel(s):
     logging.info("Parsing {}".format(s))
     if s.lower() == 'none':
         return None
     elif '/' in s:#slice
         s = s.split("/")
         assert len(s) == 3, "Proper slice notations is 'start/stop/step'"
         return slice(int(s[0]) if s[0].lower() != 'none' else None, 
                 int(s[1]) if s[1].lower() != 'none' else None, 
                 int(s[2])if s[2].lower() != 'none' else None)
     else:
         return s
예제 #3
0
def import_data(ndppp_dd_sols, out_datapack, clobber, ant_sel, time_sel,
                freq_sel, pol_sel, dir_sel):
    """Create a datapack from the direction dependent NDPPP solutions.

    The ``tec`` and ``scalarphase`` tables of ``ndppp_dd_sols`` are read
    under the given selection, combined as
    ``tec_conversion * tec / freq + scalarphase`` on a fixed frequency grid,
    passed through ``_wrap`` and stored as a ``'phase'`` table in
    ``out_datapack``.

    Args:
        ndppp_dd_sols: input handed to ``DataPack(..., readonly=True)``.
        out_datapack: path of the datapack to create.
        clobber: if truthy, an existing ``out_datapack`` is deleted first.
        ant_sel, time_sel, freq_sel, pol_sel, dir_sel: forwarded to
            ``select`` on the input datapack.

    Raises:
        ValueError: if ``out_datapack`` exists and ``clobber`` is falsy.
    """
    if os.path.exists(out_datapack):
        logging.info("{} exists".format(out_datapack))
        if not clobber:
            raise ValueError(
                "{} already exists and non clobber".format(out_datapack))
        logging.info("Deleting old datapack")
        os.unlink(out_datapack)

    new = np.newaxis

    with DataPack(ndppp_dd_sols, readonly=True) as f_dd:
        f_dd.select(ant=ant_sel,
                    time=time_sel,
                    freq=freq_sel,
                    dir=dir_sel,
                    pol=pol_sel)
        # The output frequency grid is hard-coded, not taken from the input.
        freqs = np.array([125., 135., 145., 155., 165.]) * 1e6

        with DataPack(out_datapack) as out:
            patch_names, directions = f_dd.sources
            # Antennas are read from the input but not used below.
            _antenna_labels, _antennas = f_dd.antennas
            out.add_antennas()  # original note: default is lofar
            out.add_sources(directions, patch_names=patch_names)

            # Original note on both tables: (npol), nt, na, nd, 1
            tec, _tec_axes = f_dd.tec
            scalarphase, axes = f_dd.scalarphase

            has_pol_axis = 'pol' in axes.keys()
            # Drop the trailing axis and reverse the remaining data axes:
            # -> (npol,) nd, na, nt
            axis_order = (0, 3, 2, 1) if has_pol_axis else (2, 1, 0)
            tec = tec[..., 0].transpose(axis_order)
            scalarphase = scalarphase[..., 0].transpose(axis_order)

            if not has_pol_axis:
                # Add a leading length-1 polarisation axis and label it.
                tec = tec[new]
                scalarphase = scalarphase[new]
                axes['pol'] = ['XX']

            # Insert a frequency axis before the last axis and broadcast.
            freq_grid = freqs[new, new, new, :, new]
            phase = (tec_conversion * tec[:, :, :, new, :] / freq_grid
                     + scalarphase[:, :, :, new, :])

            out.add_freq_dep_tab('phase',
                                 axes['time'],
                                 freqs,
                                 pols=axes['pol'],
                                 ants=axes['ant'],
                                 dirs=axes['dir'],
                                 vals=_wrap(phase))
    logging.info("Done importing data")
예제 #4
0
def test_new_solver():
    """Run ``PhaseOnlySolver`` once per kernel combination and record results.

    Iterates over the 18 names formed from ('product' | 'sum') and two
    choices out of ('rbf', 'm32', 'm52'), installs ``create_kern(name)`` as
    the solver's ``_build_kernel``, solves with fixed settings and appends
    one line per run to ``kern_opt_res.csv`` in the working directory.
    """
    import itertools

    # Alternative option set, left disabled:
    #    opt = {'initial_learning_rate': 0.0469346965745387, 'learning_rate_steps': 2.3379450095649053, 'learning_rate_decay': 2.3096977604598385, 'minibatch_size': 257, 'dof_ratio': 15.32485312998133, 'gamma_start': 1.749795137201838e-05, 'gamma_add': 0.00014740343452076625, 'gamma_mul': 1.0555893705407017, 'gamma_max': 0.1063958902418518, 'gamma_fallback': 0.15444066000616663}

    opt = {
        'initial_learning_rate': 0.030035792298837113,
        'learning_rate_steps': 2.3915384159241064,
        'learning_rate_decay': 2.6685242978751798,
        'minibatch_size': 128,
        'dof_ratio': 10.,
        'gamma_start': 6.876944103773131e-05,
        'gamma_add': 1e-4,
        'gamma_mul': 1.04,
        'gamma_max': 0.14,
        'gamma_fallback': 0.1,
        'priors': {
            'kern_time_ls': 50.,
            'kern_dir_ls': 0.80
        }
    }

    datapack = '/net/lofar1/data1/albert/git/bayes_tec/scripts/data/killms_datapack_2.hdf5'
    run_dir = 'run_dir_killms_kern_opt'

    # Settings shared by every run; the entries of ``opt`` are merged in.
    solve_kwargs = dict(
        output_solset="posterior_sol_kern_opt",
        solset='sol000',
        jitter=1e-6,
        tec_scale=0.005,
        screen_res=30,
        remake_posterior_solsets=False,
        iterations=500,
        intra_op_threads=0,
        inter_op_threads=0,
        ant_sel="RS210HBA",
        time_sel=slice(50, 150, 1),
        pol_sel=slice(0, 1, 1),
        debug=False,
        W_diag=True,
        freq_sel=slice(0, 48, 1),
        plot_level=-1,
        return_likelihood=True,
        num_likelihood_samples=100,
    )
    solve_kwargs.update(opt)

    combiners = ['product', 'sum']
    kernel_kinds = ['rbf', 'm32', 'm52']

    res = []
    for combo in itertools.product(combiners, kernel_kinds, kernel_kinds):
        name = "_".join(combo)
        logging.info("Running {}".format(name))

        solver = PhaseOnlySolver(run_dir, datapack)
        solver._build_kernel = create_kern(name)
        lik = solver.solve(**solve_kwargs)

        row = [name, -lik[0] / 1e6, lik[1] / 1e6]
        res.append(row)
        logging.info("{} results {}".format(name, res))

        # Append after every run; the row is written without list brackets.
        line = str(row).replace('[', '').replace(']', '')
        with open("kern_opt_res.csv", 'a') as f:
            f.write("{}\n".format(line))
예제 #5
0
    optional.add_argument("--plot_screen", type="bool", default=False,
                      help="Whether to plot screen. Expects properly shaped array.")


    optional.add_argument("--num_processes", type=int, default=1,
                      help="Number of parallel plots")
    optional.add_argument("--tec_eval_freq", type=float, default=None,
                      help="Freq to eval tec at.")
    optional.add_argument("--output_folder", type=str, default="./figs",
                       help="""The output folder.""")
    optional.add_argument("--observable", type=str, default="phase",
                       help="""The soltab to plot""")
    optional.add_argument("--phase_wrap", type="bool", default=True,
                       help="""Whether to wrap the observable""")
    optional.add_argument("--solset", type=str, default="sol000",
                       help="""The solset to plot""")

    optional.add_argument("--vmin", type=float, default=None,
                       help="""The min value if phase_wrap is False""")
    optional.add_argument("--vmax", type=float, default=None,
                       help="""The max value if phase_wrap is False""")


if __name__ == '__main__':
    # Script entry point: build the CLI, parse it while tolerating unknown
    # arguments, log the parsed options and hand them to run_plot.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_args(parser)
    flags, unparsed = parser.parse_known_args()
    options = vars(flags)
    logging.info(options)
    run_plot(**options)