Example #1
def save_data(
        self,
        data_dir: Union[str, Path],
        skip_keys: list[str] = None,
        use_hdf5: bool = True,
        save_dataset: bool = True,
        # compression: str = 'gzip',
):
    """Save `self.data` entries to individual files in `data_dir`."""
    success = 0
    data_dir = Path(data_dir)
    io.check_else_make_dir(str(data_dir))
    if use_hdf5:
        # ------------ Save using h5py -------------------
        hfile = data_dir.joinpath(f'data_rank{RANK}.hdf5')
        self.to_h5pyfile(hfile, skip_keys=skip_keys)
        success = 1
    if save_dataset:
        # ----- Save to `netCDF4` file using hierarchical grouping ------
        self.save_dataset(data_dir)
        success = 1
    if success == 0:
        # -------------------------------------------------
        # Save each `{key: val}` entry in `self.data` as a
        # compressed file, named `f'{key}.z'`.
        # -------------------------------------------------
        skip_keys = [] if skip_keys is None else skip_keys
        logger.info(f'Saving individual entries to {data_dir}.')
        for key, val in self.data.items():
            if key in skip_keys:
                continue
            out_file = os.path.join(data_dir, f'{key}.z')
            io.savez(np.array(val), out_file)
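
`to_h5pyfile` is a project helper whose implementation isn't shown here. A rough, minimal sketch of the same pattern using plain `h5py` (the function and names below are illustrative, not the project's API):

import h5py
import numpy as np

def save_dict_to_h5(data: dict, fname: str, skip_keys=None):
    """Write each `{key: array}` entry of `data` as a gzip-compressed dataset."""
    skip_keys = skip_keys or []
    with h5py.File(fname, 'w') as f:
        for key, val in data.items():
            if key in skip_keys:
                continue
            f.create_dataset(key, data=np.asarray(val), compression='gzip')
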
Example #2
def save_networks(self, log_dir):
    """Save networks to disk."""
    models_dir = os.path.join(log_dir, 'training', 'models')
    io.check_else_make_dir(models_dir)
    eps_file = os.path.join(models_dir, 'eps.z')
    io.savez(self.eps.numpy(), eps_file, name='eps')
    if self.config.separate_networks:
        xnet_paths = [
            os.path.join(models_dir, f'dynamics_xnet{i}')
            for i in range(self.config.num_steps)
        ]
        vnet_paths = [
            os.path.join(models_dir, f'dynamics_vnet{i}')
            for i in range(self.config.num_steps)
        ]
        for idx, (xf, vf) in enumerate(zip(xnet_paths, vnet_paths)):
            xnet = self.xnet[idx]  # type: tf.keras.models.Model
            vnet = self.vnet[idx]  # type: tf.keras.models.Model
            io.log(f'Saving `xnet{idx}` to {xf}.')
            io.log(f'Saving `vnet{idx}` to {vf}.')
            xnet.save(xf)
            vnet.save(vf)
    else:
        xnet_path = os.path.join(models_dir, 'dynamics_xnet')
        vnet_path = os.path.join(models_dir, 'dynamics_vnet')
        io.log(f'Saving `xnet` to {xnet_path}.')
        io.log(f'Saving `vnet` to {vnet_path}.')
        self.xnet.save(xnet_path)
        self.vnet.save(vnet_path)
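
`tf.keras.models.Model.save` writes a full SavedModel directory that can be restored later with `tf.keras.models.load_model`. A minimal round-trip, with a toy model standing in for `xnet`/`vnet`:

import tensorflow as tf

# Toy stand-in for `xnet` / `vnet`.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
model.save('/tmp/dynamics_xnet')  # writes a SavedModel directory
restored = tf.keras.models.load_model('/tmp/dynamics_xnet')
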
Example #3
def save_data(self, data_dir, rank=0):
    """Save `self.data` entries to individual files in `data_dir`."""
    if rank != 0:
        return

    io.check_else_make_dir(data_dir)
    for key, val in self.data.items():
        out_file = os.path.join(data_dir, f'{key}.z')
        io.savez(np.array(val), out_file)
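
The `rank != 0` guard keeps every process except the chief from writing in a distributed run. Assuming the rank comes from Horovod (an assumption; the snippet only shows a `rank` argument), usage might look like:

import horovod.tensorflow as hvd  # assumption: Horovod supplies the rank

class Container:
    """Hypothetical stand-in exposing `save_data` with a rank guard."""
    def save_data(self, data_dir, rank=0):
        if rank != 0:  # non-chief workers return immediately
            return
        print(f'rank 0 writing to {data_dir}')

hvd.init()
Container().save_data('run_data', rank=hvd.rank())  # only rank 0 writes
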
Example #4
def _deal_with_new_data(tint_file: str, save: bool = False):
    loaded = io.loadz(tint_file)
    if 'tau_int_data' in str(tint_file):  # cast: callers may pass a `Path`
        head = os.path.dirname(os.path.split(tint_file)[0])
    else:
        head, _ = os.path.split(str(tint_file))
    tint_arr = loaded.get('tint', loaded.get('tau_int', None))
    narr = loaded.get('narr', loaded.get('draws', None))
    if tint_arr is None or narr is None:
        raise ValueError(f'Unable to load from {tint_file}.')

    try:
        params = _get_important_params(head)
    except (ValueError, FileNotFoundError):
        logger.info(f'Error loading params from {head}; re-raising!')
        raise

    lf = params.get('lf', None)
    eps = params.get('eps', None)
    traj_len = params.get('traj_len', None)
    beta = params.get('beta', None)
    data = {
        'tint': tint_arr,
        'narr': narr,
        'run_dir': head,
        'lf': lf,
        'eps': eps,
        'traj_len': traj_len,
        'beta': beta,
        'run_params': params.get('run_params', None),
    }
    if save:
        outfile = os.path.join(head, 'tint_data.z')
        logger.info(f'Saving tint data to: {outfile}.')
        io.savez(data, outfile, 'tint_data')

    return data
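
`io.savez` / `io.loadz` are project helpers not shown on this page. A plausible stand-in, assuming they wrap `joblib` (which infers zlib compression from the `.z` suffix):

import joblib

def savez(obj, fpath, name=None):  # `name` kept only for signature parity
    """Serialize `obj` to `fpath`; joblib compresses based on the suffix."""
    joblib.dump(obj, fpath)

def loadz(fpath):
    """Inverse of `savez`."""
    return joblib.load(fpath)
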
Example #5
def save_layer_weights(net, out_file):
    """Save all layer weights from `net` to `out_file`."""
    weights_dict = get_layer_weights(net)
    io.savez(weights_dict, out_file, name=net.name)
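
`get_layer_weights` isn't shown above; a minimal sketch of what such a helper might do for a Keras model, mapping layer names to their weight arrays:

import tensorflow as tf

def get_layer_weights(net: tf.keras.Model) -> dict:
    """Map each layer's name to its list of weight arrays (np.ndarrays)."""
    return {layer.name: layer.get_weights() for layer in net.layers}
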
Example #6
def calc_tau_int_from_dir(
        input_path: str,
        hmc: bool = False,
        px_cutoff: float = None,
        therm_frac: float = 0.2,
        num_pts: int = 50,
        nstart: int = 100,
        make_plot: bool = True,
        save_data: bool = True,
        keep_charges: bool = False,
):
    """
    NOTE: `load_charges_from_dir` returns `output_arr`:
      `output_arr` is a list of dicts, each of which looks like:
            output = {
                'beta': beta,
                'lf': lf,
                'eps': eps,
                'traj_len': lf * eps,
                'qarr': charges,
                'run_params': params,
                'run_dir': run_dir,
            }
    """
    output_arr = load_charges_from_dir(input_path, hmc=hmc)
    if output_arr is None:
        logger.warning(
            f'Skipping entry! Unable to load charge data from {input_path}.'
        )

        return None

    tint_data = {}
    for output in output_arr:
        run_dir = output['run_dir']
        beta = output['beta']
        lf = output['lf']
        #  eps = output['eps']
        #  traj_len = output['traj_len']
        run_params = output['run_params']

        if hmc:
            data_dir = os.path.join(input_path, 'run_data')
            plot_dir = os.path.join(input_path, 'plots', 'tau_int_plots')
        else:
            run_dir = output['run_params']['run_dir']
            data_dir = os.path.join(run_dir, 'run_data')
            plot_dir = os.path.join(run_dir, 'plots')

        outfile = os.path.join(data_dir, 'tau_int_data.z')
        outfile1 = os.path.join(data_dir, 'tint_data.z')
        fdraws = os.path.join(plot_dir, 'tau_int_vs_draws.pdf')
        ftlen = os.path.join(plot_dir, 'tau_int_vs_traj_len.pdf')
        c1 = os.path.isfile(outfile)
        c11 = os.path.isfile(outfile1)
        if c1 or c11:
            # Load whichever data file exists; prefer `tau_int_data.z`.
            existing = outfile if c1 else outfile1
            logger.warning(f'Found existing data at: {existing}; loading.')
            loaded = io.loadz(existing)
            output.update(loaded)

        xeps_check = 'xeps' in output['run_params']
        veps_check = 'veps' in output['run_params']
        if xeps_check and veps_check:
            xeps = tf.reduce_mean(output['run_params']['xeps'])
            veps = tf.reduce_mean(output['run_params']['veps'])
            eps = tf.reduce_mean([xeps, veps]).numpy()
        else:
            eps = output['eps']
            if isinstance(eps, list):
                eps = tf.reduce_mean(eps)
            elif tf.is_tensor(eps):
                try:
                    eps = eps.numpy()
                except AttributeError:
                    eps = tf.reduce_mean(eps)

        traj_len = lf * eps

        qarr, _ = therm_arr(output['qarr'], therm_frac=therm_frac)

        n, tint = calc_autocorr(qarr.T, num_pts=num_pts, nstart=nstart)
        #  output.update({
        #      'draws': n,
        #      'tau_int': tint,
        #      'qarr.shape': qarr.shape,
        #  })
        tint_data[run_dir] = {
            'run_dir': run_dir,
            'run_params': run_params,
            'lf': lf,
            'eps': eps,
            'traj_len': traj_len,
            'narr': n,
            'tint': tint,
            'qarr.shape': qarr.shape,
        }
        if save_data:
            io.savez(tint_data, outfile, 'tint_data')
            #  io.savez(output, outfile, name='tint_data')

        if make_plot:
            #  fbeta = os.path.join(plot_dir, 'tau_int_vs_beta.pdf')
            io.check_else_make_dir(plot_dir)
            prefix = 'HMC' if hmc else 'L2HMC'
            xlabel = 'draws'
            ylabel = r'$\tau_{\mathrm{int}}$ (estimate)'
            title = (f'{prefix}, '
                     + r'$\beta=$' + f'{beta}, '
                     + r'$N_{\mathrm{lf}}=$' + f'{lf}, '
                     + r'$\varepsilon=$' + f'{eps:.2g}, '
                     + r'$\lambda=$' + f'{traj_len:.2g}')

            _, ax = plt.subplots(constrained_layout=True)
            best = []
            for t in tint.T:
                _ = ax.plot(n, t, marker='.', color='k')
                best.append(t[-1])

            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(xlabel)
            _ = ax.set_title(title)

            _ = ax.set_xscale('log')
            _ = ax.set_yscale('log')
            _ = ax.grid(alpha=0.4)
            logger.info(f'Saving figure to: {fdraws}')
            _ = plt.savefig(fdraws, dpi=400, bbox_inches='tight')
            plt.close('all')

            _, ax = plt.subplots()
            for b in best:
                _ = ax.plot(traj_len, b, marker='.', color='k')
            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(r'trajectory length, $\lambda$')
            _ = ax.set_title(title)
            _ = ax.set_yscale('log')
            _ = ax.grid(True, alpha=0.4)
            logger.info(f'Saving figure to: {ftlen}')
            _ = plt.savefig(ftlen, dpi=400, bbox_inches='tight')
            plt.close('all')

    return tint_data
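
`calc_autocorr` is defined elsewhere in the project. For reference, a self-contained sketch of estimating the integrated autocorrelation time of a single chain with a Sokal-style self-consistent window (the cutoff constant `c` is a conventional choice, not the project's):

import numpy as np

def integrated_autocorr_time(x: np.ndarray, c: float = 5.0) -> float:
    """Estimate tau_int of a 1D chain with a self-consistent window."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    n = x.size
    # Normalized autocorrelation function; acf[0] == 1.
    acf = np.correlate(x, x, mode='full')[n - 1:] / (np.var(x) * n)
    tau = 1.0
    for m in range(1, n):
        tau += 2.0 * acf[m]
        if m >= c * tau:  # stop summing once the window exceeds c * tau
            break
    return tau
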
Example #7
def deal_with_new_data(path: str, save: bool = False):
    tint_data = {}
    tint_files = [
        x for x in Path(path).rglob('*tint_data.z*') if x.is_file()
    ]
    tau_int_files = [
        x for x in Path(path).rglob('*tau_int_data.z*') if x.is_file()
    ]
    # Both file sets are handled identically, so process them in one loop.
    for data_file in tint_files + tau_int_files:
        try:
            tint = _deal_with_new_data(data_file, save=save)
        except (ValueError, FileNotFoundError):
            logger.info(f'Unable to get tint data from: {data_file}, skipping!')
            continue

        beta = tint['beta']
        traj_len = tint['traj_len']
        run_dir = tint['run_dir']
        if beta not in tint_data:
            tint_data[beta] = {traj_len: tint}
        elif traj_len not in tint_data[beta]:
            tint_data[beta][traj_len] = tint
        else:
            # If `traj_len` already exists for this `beta`, jitter it by a
            # small amount to create a distinct entry in `tint_data[beta]`.
            traj_len += 1e-6 * np.random.randn()
            tint_data[beta][traj_len] = tint

        if save:
            outfile = os.path.join(run_dir, 'tint_data.z')
            logger.info(f'Saving tint data to: {outfile}.')
            io.savez(tint, outfile, 'tint_data')

    return tint_data
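
The jitter trick above avoids silently overwriting an entry when two runs share the same `traj_len` key. In isolation:

import numpy as np

d = {0.5: 'run-a'}
key = 0.5
if key in d:
    key += 1e-6 * np.random.randn()  # nudge the key off the collision
d[key] = 'run-b'
# `d` now holds both entries (with probability 1).
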
Example #8
def run(
        dynamics: GaugeDynamics,
        configs: dict[str, Any],
        x: tf.Tensor = None,
        beta: float = None,
        runs_dir: str = None,
        make_plots: bool = True,
        therm_frac: float = 0.33,
        num_chains: int = 16,
        save_x: bool = False,
        md_steps: int = 50,
        skip_existing: bool = False,
        save_dataset: bool = True,
        use_hdf5: bool = False,
        skip_keys: list[str] = None,
        run_steps: int = None,
) -> InferenceResults:
    """Run inference. (Note: Higher-level than `run_dynamics`)."""
    if not IS_CHIEF:
        return InferenceResults(None, None, None, None, None)

    if num_chains > 16:
        logger.warning(f'Reducing `num_chains` from {num_chains} to 16.')
        num_chains = 16

    if run_steps is None:
        run_steps = configs.get('run_steps', 50000)

    if beta is None:
        beta = configs.get('beta_final', configs.get('beta', None))

    assert beta is not None

    logdir = configs.get('log_dir', configs.get('logdir', None))
    if runs_dir is None:
        rs = 'inference_hmc' if dynamics.config.hmc else 'inference'
        runs_dir = os.path.join(logdir, rs)

    io.check_else_make_dir(runs_dir)
    run_dir = io.make_run_dir(configs=configs, base_dir=runs_dir,
                              beta=beta, skip_existing=skip_existing)
    logger.info(f'run_dir: {run_dir}')
    data_dir = os.path.join(run_dir, 'run_data')
    summary_dir = os.path.join(run_dir, 'summaries')
    log_file = os.path.join(run_dir, 'run_log.txt')
    io.check_else_make_dir([run_dir, data_dir, summary_dir])
    writer = tf.summary.create_file_writer(summary_dir)
    writer.set_as_default()

    configs['logging_steps'] = 1
    if x is None:
        x = convert_to_angle(tf.random.uniform(shape=dynamics.x_shape,
                                               minval=-PI, maxval=PI))

    # == RUN DYNAMICS =======================================================
    nw = dynamics.net_weights
    inf_type = 'HMC' if dynamics.config.hmc else 'inference'
    logger.info(', '.join([f'Running {inf_type}', f'beta={beta}', f'nw={nw}']))
    t0 = time.time()
    results = run_dynamics(dynamics, flags=configs, beta=beta, x=x,
                           writer=writer, save_x=save_x, md_steps=md_steps)
    logger.info(f'Done running {inf_type}. took: {time.time() - t0:.4f} s')
    # =======================================================================

    run_data = results.run_data
    run_data.update_dirs({'log_dir': logdir, 'run_dir': run_dir})
    run_params = {
        'hmc': dynamics.config.hmc,
        'run_dir': run_dir,
        'beta': beta,
        'run_steps': run_steps,
        'plaq_weight': dynamics.plaq_weight,
        'charge_weight': dynamics.charge_weight,
        'x_shape': dynamics.x_shape,
        'num_steps': dynamics.config.num_steps,
        'net_weights': dynamics.net_weights,
        'input_shape': dynamics.x_shape,
    }

    inf_log_fpath = os.path.join(run_dir, 'inference_log.txt')
    with open(inf_log_fpath, 'a') as f:
        f.writelines('\n'.join(results.data_strs))

    if hasattr(dynamics, 'xeps') and hasattr(dynamics, 'veps'):
        xeps_avg = tf.reduce_mean(dynamics.xeps)
        veps_avg = tf.reduce_mean(dynamics.veps)
        traj_len = tf.reduce_sum(dynamics.xeps)
        run_params.update({
            'xeps': dynamics.xeps,
            'traj_len': traj_len,
            'veps': dynamics.veps,
            'xeps_avg': xeps_avg,
            'veps_avg': veps_avg,
            'eps_avg': (xeps_avg + veps_avg) / 2.,
        })

    elif hasattr(dynamics, 'eps'):
        # `eps_avg` and `traj_len` are needed below when building `tint_data`.
        eps_avg = tf.reduce_mean(dynamics.eps)
        traj_len = dynamics.config.num_steps * eps_avg
        run_params.update({
            'eps': dynamics.eps,
            'eps_avg': eps_avg,
            'traj_len': traj_len,
        })

    io.save_params(run_params, run_dir, name='run_params')

    save_times = {}
    plot_times = {}
    if make_plots:
        output = plot_data(data_container=run_data,
                           configs=configs,
                           params=run_params,
                           out_dir=run_dir,
                           hmc=dynamics.config.hmc,
                           therm_frac=therm_frac,
                           num_chains=num_chains,
                           profile=True,
                           cmap='crest',
                           logging_steps=1)

        save_times = io.SortedDict(**output['save_times'])
        plot_times = io.SortedDict(**output['plot_times'])
        dt1 = io.SortedDict(**plot_times['data_container.plot_dataset'])
        plot_times['data_container.plot_dataset'] = dt1

        tint_data = {
            'beta': beta,
            'run_dir': run_dir,
            'traj_len': traj_len,
            'run_params': run_params,
            'eps': run_params['eps_avg'],
            'lf': run_params['num_steps'],
            'narr': output['tint_dict']['narr'],
            'tint': output['tint_dict']['tint'],
        }

        t0 = time.time()
        tint_file = os.path.join(run_dir, 'tint_data.z')
        io.savez(tint_data, tint_file, 'tint_data')
        save_times['savez_tint_data'] = time.time() - t0

    t0 = time.time()

    logfile = os.path.join(run_dir, 'inference.log')
    run_data.save_and_flush(data_dir=data_dir,
                            log_file=logfile,
                            use_hdf5=use_hdf5,
                            skip_keys=skip_keys,
                            save_dataset=save_dataset)

    save_times['run_data.flush_data_strs'] = time.time() - t0

    t0 = time.time()
    try:
        run_data.write_to_csv(logdir, run_dir, hmc=dynamics.config.hmc)
    except TypeError:
        logger.warning('Unable to write to csv. Continuing...')

    save_times['run_data.write_to_csv'] = time.time() - t0

    # t0 = time.time()
    # io.save_inference(run_dir, run_data)
    # save_times['io.save_inference'] = time.time() - t0
    profdir = os.path.join(run_dir, 'profile_info')
    io.check_else_make_dir(profdir)
    io.save_dict(plot_times, profdir, name='plot_times')
    io.save_dict(save_times, profdir, name='save_times')

    return InferenceResults(dynamics=results.dynamics,
                            x=results.x, x_arr=results.x_arr,
                            run_data=results.run_data,
                            data_strs=results.data_strs)
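
`run` installs the summary writer as the default, so any `tf.summary.*` calls issued inside `run_dynamics` land in `summary_dir`. The mechanism in isolation:

import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/summaries')
writer.set_as_default()
tf.summary.scalar('loss', 0.5, step=0)  # recorded by the default writer
writer.flush()
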
Example #9
def plot_data(
        data_container: "DataContainer",  # noqa: F821
        configs: dict = None,
        out_dir: str = None,
        therm_frac: float = 0,
        params: Union[dict, AttrDict] = None,
        hmc: bool = None,
        num_chains: int = 32,
        profile: bool = False,
        cmap: str = 'crest',
        verbose: bool = False,
        logging_steps: int = 1,
) -> dict:
    """Plot data from `data_container.data`."""
    if verbose:
        keep_strs = list(data_container.data.keys())

    else:
        keep_strs = [
            'charges', 'plaqs', 'accept_prob',
            'Hf_start', 'Hf_mid', 'Hf_end',
            'Hb_start', 'Hb_mid', 'Hb_end',
            'Hwf_start', 'Hwf_mid', 'Hwf_end',
            'Hwb_start', 'Hwb_mid', 'Hwb_end',
            'xeps_start', 'xeps_mid', 'xeps_end',
            'veps_start', 'veps_mid', 'veps_end'
        ]

    with_jupyter = in_notebook()
    # -- TODO: --------------------------------------
    #  * Get rid of unnecessary `params` argument,
    #    all of the entries exist in `configs`.
    # ----------------------------------------------
    if num_chains > 16:
        logger.warning(
            f'Reducing `num_chains` from {num_chains} to 16 for plotting.'
        )
        num_chains = 16

    plot_times = {}
    save_times = {}

    title = None
    if params is not None:
        try:
            title = get_title_str_from_params(params)
        except Exception:
            title = None

    else:
        if configs is not None:
            params = {
                'beta_init': configs['beta_init'],
                'beta_final': configs['beta_final'],
                'x_shape': configs['dynamics_config']['x_shape'],
                'num_steps': configs['dynamics_config']['num_steps'],
                'net_weights': configs['dynamics_config']['net_weights'],
            }
        else:
            params = {}

    tstamp = io.get_timestamp('%Y-%m-%d-%H%M%S')
    plotdir = None
    if out_dir is not None:
        plotdir = os.path.join(out_dir, f'plots_{tstamp}')
        io.check_else_make_dir(plotdir)

    tint_data = {}
    output = {}
    if 'charges' in data_container.data:
        if configs is not None:
            lf = configs['dynamics_config']['num_steps']  # type: int
        else:
            lf = 0
        qarr = np.array(data_container.data['charges'])
        t0 = time.time()
        tint_dict, _ = plot_autocorrs_vs_draws(qarr, num_pts=20,
                                               nstart=1000, therm_frac=0.2,
                                               out_dir=plotdir, lf=lf)
        plot_times['plot_autocorrs_vs_draws'] = time.time() - t0

        tint_data = deepcopy(params)
        tint_data.update({
            'narr': tint_dict['narr'],
            'tint': tint_dict['tint'],
            'run_params': params,
        })

        run_dir = params.get('run_dir', None)
        if run_dir is not None:
            if os.path.isdir(str(Path(run_dir))):
                tint_file = os.path.join(run_dir, 'tint_data.z')
                t0 = time.time()
                io.savez(tint_data, tint_file, 'tint_data')
                save_times['tint_data'] = time.time() - t0

        t0 = time.time()
        qsteps = logging_steps * np.arange(qarr.shape[0])
        _ = plot_charges(qsteps, qarr, out_dir=plotdir, title=title)
        plot_times['plot_charges'] = time.time() - t0

        output.update({
            'tint_dict': tint_dict,
            'charges_steps': qsteps,
            'charges_arr': qarr,
        })

    hmc = params.get('hmc', False) if hmc is None else hmc

    data_dict = {}
    data_vars = {}
    #  charges_steps = []
    #  charges_arr = []
    for key, val in data_container.data.items():
        if key in SKEYS and key not in keep_strs:
            continue
        #  if key == 'x':
        #      continue

        # Skip logdet-related data when plotting results from an HMC run.
        # (A `continue` inside a nested `for skip_str ...` loop would only
        # advance the inner loop, so check with `any(...)` at this level.)
        if hmc and any(s in key for s in ('Hw', 'ld', 'sld', 'sumlogdet')):
            continue

        arr = np.array(val)
        steps = logging_steps * np.arange(len(arr))

        if therm_frac > 0:
            if logging_steps == 1:
                arr, steps = therm_arr(arr, therm_frac=therm_frac)
            else:
                drop = int(therm_frac * arr.shape[0])
                arr = arr[drop:]
                steps = steps[drop:]

        #  if arr.flatten().std() < 1e-2:
        #      logger.warning(f'Skipping plot for: {key}')
        #      logger.warning(f'std({key}) = {arr.flatten().std()} < 1e-2')

        labels = ('MC Step', key)
        data = (steps, arr)

        # -- NOTE: arr.shape: (draws,) = (D,) -------------------------------
        if len(arr.shape) == 1:
            data_dict[key] = xr.DataArray(arr, dims=['draw'], coords=[steps])
            #  if verbose:
            #      plotdir_ = os.path.join(plotdir, f'mcmc_lineplots')
            #      io.check_else_make_dir(plotdir_)
            #      lplot_fname = os.path.join(plotdir_, f'{key}.pdf')
            #      _, _ = mcmc_lineplot(data, labels, title,
            #                           lplot_fname, show_avg=True)
            #      plt.close('all')

        # -- NOTE: arr.shape: (draws, chains) = (D, C) ----------------------
        elif len(arr.shape) == 2:
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            data_dict[key] = data_arr
            data_vars[key] = data_arr
            #  if verbose:
            #      plotdir_ = os.path.join(plotdir, 'traceplots')
            #      tplot_fname = os.path.join(plotdir_, f'{key}_traceplot.pdf')
            #      _ = mcmc_traceplot(key, data_arr, title, tplot_fname)
            #      plt.close('all')

        # -- NOTE: arr.shape: (draws, leapfrogs, chains) = (D, L, C) ---------
        elif len(arr.shape) == 3:
            _, leapfrogs_, chains_ = arr.shape
            chains = np.arange(chains_)
            leapfrogs = np.arange(leapfrogs_)
            data_dict[key] = xr.DataArray(arr.T,  # NOTE: [arr.T] = (C, L, D)
                                          dims=['chain', 'leapfrog', 'draw'],
                                          coords=[chains, leapfrogs, steps])

    plotdir_xr = None
    if plotdir is not None:
        plotdir_xr = os.path.join(plotdir, 'xarr_plots')

    t0 = time.time()
    dataset, dtplot = data_container.plot_dataset(plotdir_xr,
                                                  num_chains=num_chains,
                                                  therm_frac=therm_frac,
                                                  ridgeplots=True,
                                                  cmap=cmap,
                                                  profile=profile)

    if not with_jupyter:
        plt.close('all')

    plot_times['data_container.plot_dataset'] = {'total': time.time() - t0}
    for key, val in dtplot.items():
        plot_times['data_container.plot_dataset'][key] = val

    if not hmc and 'Hwf' in data_dict:
        t0 = time.time()
        _ = plot_energy_distributions(data_dict, out_dir=plotdir, title=title)
        plot_times['plot_energy_distributions'] = time.time() - t0

    t0 = time.time()
    _ = mcmc_avg_lineplots(data_dict, title, plotdir)
    plot_times['mcmc_avg_lineplots'] = time.time() - t0

    if not with_jupyter:
        plt.close('all')

    output.update({
        'data_container': data_container,
        'data_dict': data_dict,
        'data_vars': data_vars,
        'out_dir': plotdir,
        'save_times': save_times,
        'plot_times': plot_times,
    })

    return output
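
The `(chain, draw)` transposition used throughout `plot_data` matches the dimension naming that ArviZ-style tooling expects. The core construction in isolation:

import numpy as np
import xarray as xr

draws, chains = 100, 4
arr = np.random.randn(draws, chains)      # stored as (draws, chains) = (D, C)
da = xr.DataArray(arr.T,                  # transpose to (chain, draw)
                  dims=['chain', 'draw'],
                  coords=[np.arange(chains), np.arange(draws)])
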
Example #10
def setup(
    configs: dict,
    x: tf.Tensor = None,
    betas: list[tf.Tensor] = None,
    strict: bool = False,
    try_restore: bool = False,
):
    """Setup training."""
    train_steps = configs.get('train_steps', None)  # type: int
    save_steps = configs.get('save_steps', None)  # type: int
    print_steps = configs.get('print_steps', None)  # type: int

    beta_init = configs.get('beta_init', None)  # type: float
    beta_final = configs.get('beta_final', None)  # type: float

    dirs = configs.get('dirs', None)  # type: dict[str, Any]
    assert dirs is not None
    logdir = dirs.get('logdir', dirs.get('log_dir', None))

    assert logdir is not None
    assert beta_init is not None and beta_final is not None

    train_data = DataContainer(train_steps, dirs=dirs, print_steps=print_steps)

    # Check if we want to restore from existing directory
    ensure_new = configs.get('ensure_new', False)
    restore_dir = configs.get('restore_from', None)
    datadir = os.path.join(logdir, 'training', 'train_data')
    if ensure_new:
        dynamics = build_dynamics(configs)

    else:
        if restore_dir is not None:
            dynamics = restore_from(configs, restore_dir, strict=strict)
            datadir = os.path.join(restore_dir, 'training', 'train_data')
        else:
            prev_logdir = look_for_previous_logdir(logdir)
            try:
                dynamics = restore_from(configs, prev_logdir, strict=strict)
            except OSError:
                logger.error('Unable to restore dynamics!')
                try:
                    dynamics = restore_from(configs, logdir, strict=strict)
                except OSError:
                    logger.error('Creating new GaugeDynamics...')
                    dynamics = build_dynamics(configs)

    current_step = dynamics.optimizer.iterations.numpy()
    if train_steps <= current_step:
        train_steps = current_step + min(save_steps, print_steps)

    train_data.steps = train_steps

    if os.path.isdir(datadir) and try_restore and not ensure_new:
        try:
            x = train_data.restore(datadir,
                                   step=current_step,
                                   x_shape=dynamics.x_shape,
                                   rank=RANK,
                                   local_rank=LOCAL_RANK)
        except ValueError:
            logger.warning('Unable to restore `x`, re-sampling from [-pi,pi)')
            x = tf.random.uniform(dynamics.x_shape, minval=-PI, maxval=PI)
    else:
        x = tf.random.uniform(dynamics.x_shape, minval=-PI, maxval=PI)

    # Reshape x from [batch_size, Nt, Nx, Nd] --> [batch_size, Nt * Nx * Nd]
    x = tf.cast(tf.reshape(x, (x.shape[0], -1)), tf.keras.backend.floatx())

    # Create checkpoint and checkpoint manager for saving during training
    ckptdir = os.path.join(logdir, 'training', 'checkpoints')
    ckpt = tf.train.Checkpoint(dynamics=dynamics, optimizer=dynamics.optimizer)
    manager = tf.train.CheckpointManager(ckpt, ckptdir, max_to_keep=5)
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        logger.info(f'Restored ckpt from: {manager.latest_checkpoint}')

    # Setup summary writer for logging metrics through tensorboard
    summdir = dirs['summary_dir']
    make_summaries = configs.get('make_summaries', True)
    steps = tf.range(current_step, train_steps, dtype=tf.int64)
    betas = setup_betas(beta_init, beta_final, train_steps, current_step)

    dynamics.compile(loss=dynamics.calc_losses,
                     optimizer=dynamics.optimizer,
                     experimental_run_tf_function=False)

    _ = dynamics.apply_transition((x, tf.constant(betas[0])), training=True)

    writer = None
    if IS_CHIEF:
        plot_models(dynamics, dirs['log_dir'])
        io.savez(configs, os.path.join(dirs['log_dir'], 'train_configs.z'))
        if make_summaries and TF_VERSION == 2:
            try:
                writer = tf.summary.create_file_writer(summdir)
            except AttributeError:
                writer = None
        else:
            writer = None

    return {
        'x': x,
        'betas': betas,
        'dynamics': dynamics,
        'dirs': dirs,
        'steps': steps,
        'writer': writer,
        'manager': manager,
        'configs': configs,
        'checkpoint': ckpt,
        'train_data': train_data,
    }
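
The checkpoint logic in `setup` follows the standard TF2 pattern: `ckpt.restore(manager.latest_checkpoint)` is a no-op when no checkpoint exists yet, so the same code path handles both fresh and resumed runs. Stripped to its essentials:

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
opt = tf.keras.optimizers.Adam()
ckpt = tf.train.Checkpoint(step=step, optimizer=opt)
manager = tf.train.CheckpointManager(ckpt, '/tmp/ckpts', max_to_keep=5)
ckpt.restore(manager.latest_checkpoint)  # silently a no-op on the first run
if manager.latest_checkpoint:
    print(f'Restored from: {manager.latest_checkpoint}')
manager.save()  # writes a new checkpoint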