コード例 #1
0
 def dump_configs(x, data_dir, rank=0, local_rank=0):
     """Write the configuration object `x` to a per-rank compressed file."""
     target = os.path.join(data_dir, f'x_rank{rank}-{local_rank}.z')
     io.log(f'Saving configs from rank {rank}-{local_rank} to: {target}.')
     # Make sure the containing directory exists before dumping.
     parent, _ = os.path.split(target)
     io.check_else_make_dir(parent)
     joblib.dump(x, target)
コード例 #2
0
 def save_data(
         self,
         data_dir: Union[str, Path],
         skip_keys: list[str] = None,
         use_hdf5: bool = True,
         save_dataset: bool = True,
         # compression: str = 'gzip',
 ):
     """Save `self.data` entries to individual files in `output_dir`."""
     data_dir = Path(data_dir)
     io.check_else_make_dir(str(data_dir))
     saved = False
     if use_hdf5:
         # Preferred path: one hdf5 file per rank, written via h5py.
         hfile = data_dir.joinpath(f'data_rank{RANK}.hdf5')
         self.to_h5pyfile(hfile, skip_keys=skip_keys)
         saved = True
     if save_dataset:
         # Also (or instead) persist a hierarchically-grouped netCDF4 file.
         self.save_dataset(data_dir)
         saved = True
     if not saved:
         # Fallback: write each `{key: val}` entry of `self.data` to its
         # own compressed file named `f'{key}.z'`.
         skip = [] if skip_keys is None else skip_keys
         logger.info(f'Saving individual entries to {data_dir}.')
         for key, val in self.data.items():
             if key in skip:
                 continue
             out_file = os.path.join(data_dir, f'{key}.z')
             io.savez(np.array(val), out_file)
コード例 #3
0
ファイル: gauge_dynamics.py プロジェクト: cphysics/l2hmc-qcd
 def save_networks(self, log_dir):
     """Save networks to disk under `log_dir/training/models`."""
     models_dir = os.path.join(log_dir, 'training', 'models')
     io.check_else_make_dir(models_dir)
     # Persist the (possibly trained) step size alongside the networks.
     io.savez(self.eps.numpy(), os.path.join(models_dir, 'eps.z'),
              name='eps')
     if not self.config.separate_networks:
         # Single shared network pair for all leapfrog steps.
         xpath = os.path.join(models_dir, 'dynamics_xnet')
         vpath = os.path.join(models_dir, 'dynamics_vnet')
         io.log(f'Saving `xnet` to {xpath}.')
         io.log(f'Saving `vnet` to {vpath}.')
         self.xnet.save(xpath)
         self.vnet.save(vpath)
         return

     # One (xnet, vnet) pair per leapfrog step.
     for idx in range(self.config.num_steps):
         xf = os.path.join(models_dir, f'dynamics_xnet{idx}')
         vf = os.path.join(models_dir, f'dynamics_vnet{idx}')
         io.log(f'Saving `xnet{idx}` to {xf}.')
         io.log(f'Saving `vnet{idx}` to {vf}.')
         self.xnet[idx].save(xf)
         self.vnet[idx].save(vf)
コード例 #4
0
 def save_dataset(self, out_dir, therm_frac=0.):
     """Save `self.data` as `xr.Dataset` to `out_dir/dataset.nc`."""
     ds = self.get_dataset(therm_frac)
     io.check_else_make_dir(out_dir)
     target = Path(out_dir).joinpath('dataset.nc')
     # Append to an existing file; otherwise create a fresh one.
     mode = 'a' if target.is_file() else 'w'
     logger.debug(f'Saving dataset to: {target}.')
     ds.to_netcdf(str(target), mode=mode)
コード例 #5
0
def setup_directories(configs: dict) -> dict:
    """Setup directories for training.

    Creates `logdir` (and its subdirectories) via `io.setup_directories`,
    records them under `configs['dirs']` / `configs['log_dir']` /
    `configs['logdir']`, and — unless `ensure_new` is set — looks for a
    previous checkpointed run to restore from.

    Args:
        configs: configuration dict; consulted for `ensure_new`,
            `logdir`/`log_dir`, `restore_from` and `restore_dir`.

    Returns:
        The same `configs` dict, updated in place.

    Raises:
        ValueError: if `ensure_new` is set but the requested `logdir`
            already exists and is non-empty.
    """
    ensure_new = configs.get('ensure_new', False)
    logdir = configs.get('logdir', configs.get('log_dir', None))
    # BUG FIX: only call `os.listdir` when the directory actually exists;
    # the original called it unconditionally and raised FileNotFoundError
    # for a not-yet-existing `logdir`.
    if logdir is not None and os.path.isdir(logdir):
        if len(os.listdir(logdir)) > 0 and ensure_new:
            raise ValueError(
                f'Nonempty `logdir`, but `ensure_new={ensure_new}')

    # Create `logdir`, `logdir/training/...`' etc
    dirs = io.setup_directories(configs, timestamps=TSTAMPS)
    configs['dirs'] = dirs
    logdir = dirs.get('logdir', dirs.get('log_dir', None))  # type: str
    configs['log_dir'] = logdir
    configs['logdir'] = logdir

    restore_dir = configs.get('restore_from', None)
    if restore_dir is None and not ensure_new:
        # Look for a previous run (with checkpoints) we could resume from.
        candidate = look_for_previous_logdir(logdir)
        if candidate != logdir and candidate.is_dir():
            nckpts = len(list(candidate.rglob('checkpoint')))
            if nckpts > 0:
                restore_dir = candidate
                configs['restore_from'] = restore_dir

    if restore_dir is not None and not ensure_new:
        try:
            restored = load_configs_from_logdir(restore_dir)
            if restored is not None:
                io.save_dict(restored, logdir, name='restored_train_configs')
        except FileNotFoundError:
            logger.warning(f'Unable to load configs from {restore_dir}')

    if RANK == 0:
        io.check_else_make_dir(logdir)
        # NOTE(review): this reads the key 'restore_dir', not 'restore_from'
        # as above -- preserved as-is, but confirm which key callers set.
        restore_dir = configs.get('restore_dir', None)
        if restore_dir is not None and not ensure_new:
            try:
                restored = load_configs_from_logdir(restore_dir)
                if restored is not None:
                    io.save_dict(restored,
                                 logdir,
                                 name='restored_train_configs')
            except FileNotFoundError:
                logger.warning(f'Unable to load configs from {restore_dir}')

    return configs
コード例 #6
0
    def save_data(self, data_dir, rank=0):
        """Save `self.data` entries to individual files in `output_dir`."""
        if rank != 0:
            # Only the chief rank writes to disk.
            return

        io.check_else_make_dir(data_dir)
        # One compressed `{key}.z` file per tracked quantity.
        for key, val in self.data.items():
            io.savez(np.array(val), os.path.join(data_dir, f'{key}.z'))
コード例 #7
0
ファイル: training_utils.py プロジェクト: cphysics/l2hmc-qcd
def train_hmc(flags):
    """Main method for training HMC model."""
    hflags = AttrDict(dict(flags).copy())
    lr_config = AttrDict(hflags.pop('lr_config', None))
    config = AttrDict(hflags.pop('dynamics_config', None))
    net_config = AttrDict(hflags.pop('network_config', None))
    hflags.train_steps = hflags.pop('hmc_steps', None)
    hflags.beta_init = hflags.beta_final

    # Force plain-HMC behavior: no networks, no auxiliary tricks.
    hmc_overrides = {
        'hmc': True,
        'use_ncp': False,
        'aux_weight': 0.,
        'zero_init': False,
        'separate_networks': False,
        'use_conv_net': False,
        'directional_updates': False,
        'use_scattered_xnet_update': False,
        'use_tempered_traj': False,
        'gauge_eq_masks': False,
    }
    config.update(hmc_overrides)

    hflags.profiler = False
    hflags.make_summaries = True

    # Rebuild the LR schedule, keeping only the initial rate from the input.
    lr_config = LearningRateConfig(
        warmup_steps=0,
        decay_rate=0.9,
        decay_steps=hflags.train_steps // 10,
        lr_init=lr_config.get('lr_init', None),
    )

    train_dirs = io.setup_directories(hflags, 'training_hmc')
    dynamics = GaugeDynamics(hflags, config, net_config, lr_config)
    dynamics.save_config(train_dirs.config_dir)
    x, train_data = train_dynamics(dynamics, hflags, dirs=train_dirs)

    if IS_CHIEF:
        # Chief rank persists training data and produces summary plots.
        output_dir = os.path.join(train_dirs.train_dir, 'outputs')
        io.check_else_make_dir(output_dir)
        train_data.save_data(output_dir)

        plot_params = {
            'eps': dynamics.eps,
            'num_steps': dynamics.config.num_steps,
            'beta_init': hflags.beta_init,
            'beta_final': hflags.beta_final,
            'lattice_shape': dynamics.config.lattice_shape,
            'net_weights': NET_WEIGHTS_HMC,
        }
        plot_data(train_data,
                  train_dirs.train_dir,
                  hflags,
                  thermalize=True,
                  params=plot_params)

    return x, train_data, dynamics.eps.numpy()
コード例 #8
0
 def __init__(self, steps, header=None, dirs=None, print_steps=100):
     """Initialize bookkeeping and create any output directories up front."""
     self.steps = steps
     self.print_steps = print_steps
     self.dirs = dirs
     self.data_strs = [header]
     self.steps_arr = []
     self.data = AttrDict(defaultdict(list))
     if dirs is not None:
         # Entries whose key mentions 'file' are file paths, not directories.
         dirnames = [v for k, v in dirs.items() if 'file' not in k]
         io.check_else_make_dir(dirnames)
コード例 #9
0
def plot_data(train_data, out_dir, flags, thermalize=False, params=None):
    """Plot each entry of `train_data.data`, writing images to `out_dir/plots`."""
    out_dir = os.path.join(out_dir, 'plots')
    io.check_else_make_dir(out_dir)

    title = get_title_str_from_params(params) if params is not None else None

    # Prefer the logging frequency recorded alongside the training flags.
    logging_steps = flags.get('logging_steps', 1)
    flags_file = os.path.join(out_dir, 'FLAGS.z')
    if os.path.isfile(flags_file):
        train_flags = io.loadz(flags_file)
        logging_steps = train_flags.get('logging_steps', 1)

    data_dict = {}
    for key, val in train_data.data.items():
        if key == 'x':
            continue

        arr = np.array(val)
        steps = logging_steps * np.arange(len(np.array(val)))

        if thermalize or key == 'dt':
            # Drop the thermalization (burn-in) portion of the chain.
            arr, steps = therm_arr(arr, therm_frac=0.33)

        labels = ('MC Step', key)
        data = (steps, arr)

        if arr.ndim == 1:
            # Scalar time series: simple line plot with running average.
            lplot_fname = os.path.join(out_dir, f'{key}.png')
            _, _ = mcmc_lineplot(data, labels, title, lplot_fname,
                                 show_avg=True)
        elif arr.ndim > 1:
            # Per-chain series: trace plot over (chain, draw).
            data_dict[key] = data
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            tplot_fname = os.path.join(out_dir, f'{key}_traceplot.png')
            _ = mcmc_traceplot(key, data_arr, title, tplot_fname)

        plt.close('all')

    # Aggregate plots over everything collected above.
    _ = mcmc_avg_lineplots(data_dict, title, out_dir)
    _ = plot_charges(*data_dict['charges'], out_dir=out_dir, title=title)
    _ = plot_energy_distributions(data_dict, out_dir=out_dir, title=title)

    plt.close('all')
コード例 #10
0
    def plot_dataset(
            self,
            out_dir: str = None,
            therm_frac: float = 0.,
            num_chains: int = None,
            ridgeplots: bool = True,
            profile: bool = False,
            skip: Union[list[str], str] = None,
            cmap: str = 'viridis_r',
    ):
        """Create trace plot + histogram for each entry in self.data."""
        tdict = {}
        dataset = self.get_dataset(therm_frac)
        for key, val in dataset.data_vars.items():
            if skip is not None and key in skip:
                continue

            t0 = time.time()
            try:
                std = np.std(val.values.flatten())
            except TypeError:
                # Non-numeric entries can't be summarized; skip them.
                continue
            if std < 1e-2:
                # Essentially-constant series are not worth plotting.
                continue

            # Restrict to the first `num_chains` chains (None keeps all).
            if len(val.shape) == 2:    # shape: (chain, draw)
                val = val[:num_chains, :]
            if len(val.shape) == 3:    # shape: (chain, leapfrogs, draw)
                val = val[:num_chains, :, :]

            fig, ax = plt.subplots(constrained_layout=True,
                                   figsize=set_size())
            _ = val.plot(ax=ax)
            if out_dir is not None:
                io.check_else_make_dir(out_dir)
                out_file = os.path.join(out_dir, f'{key}_xrPlot.svg')
                fig.savefig(out_file, dpi=400, bbox_inches='tight')
                plt.close('all')
                plt.clf()

            if profile:
                # Record per-key plotting time for profiling callers.
                tdict[key] = time.time() - t0

        if out_dir is not None:
            out_dir = os.path.join(out_dir, 'ridgeplots')
            io.check_else_make_dir(out_dir)

        if ridgeplots:
            make_ridgeplots(dataset, num_chains=num_chains,
                            out_dir=out_dir, cmap=cmap)

        return dataset, tdict
コード例 #11
0
def mcmc_avg_lineplots(data, title=None, out_dir=None):
    """Plot trace of avg.

    For each `key -> val` in `data`, draw (1) the chain-averaged trace and
    (2) a KDE of all samples.  `val` is either a `(steps, arr)` /
    `(arr, steps)` pair or a bare array (steps default to `arange`).
    """
    fig, axes = None, None
    for idx, (key, val) in enumerate(data.items()):
        fig, axes = plt.subplots(ncols=2, figsize=set_size(subplots=(1, 2)),
                                 constrained_layout=True)
        axes = axes.flatten()
        if len(val) == 2:
            # Orient the pair: the higher-rank element is the data array.
            if len(val[0].shape) > len(val[1].shape):
                arr, steps = val
            else:
                steps, arr = val
        else:
            arr = val
            steps = np.arange(arr.shape[0])

        if isinstance(arr, xr.DataArray):
            arr = arr.values

        if len(arr.shape) == 3:
            # ====
            # TODO: Create separate plots for each leapfrog?
            arr = np.mean(arr, axis=1)

        xlabel = 'MC Step'
        # BUG FIX: inspect `arr.shape`, not `val.shape` -- `val` may be a
        # tuple (no `.shape` attribute; see `plot_data`, which stores
        # `(steps, arr)` pairs), and this branch decides how to reduce `arr`.
        if len(arr.shape) == 1:
            avg = arr
            ylabel = ' '.join(key.split('_'))

        else:
            avg = np.mean(arr, axis=1)
            ylabel = ' '.join(key.split('_')) + r" avg"

        _ = axes[0].plot(steps, avg, color=COLORS[idx])
        _ = axes[0].set_xlabel(xlabel)
        _ = axes[0].set_ylabel(ylabel)
        _ = sns.kdeplot(arr.flatten(), ax=axes[1],
                        color=COLORS[idx], fill=True)
        _ = axes[1].set_xlabel(ylabel)
        _ = axes[1].set_ylabel('')
        if title is not None:
            _ = fig.suptitle(title)

        if out_dir is not None:
            dir_ = os.path.join(out_dir, 'avg_lineplots')
            io.check_else_make_dir(dir_)
            fpath = os.path.join(dir_, f'{key}_avg.pdf')
            savefig(fig, fpath)

    return fig, axes
コード例 #12
0
def energy_traceplot(key, arr, out_dir=None, title=None):
    """Trace-plot each leapfrog slice `arr[:, idx, :]` of an energy array.

    Args:
        key: base name used for the per-leapfrog keys (`f'{key}_lf{idx}'`).
        arr: 3D array; axis 1 is indexed per plot (leapfrog steps, judging
            by the `_lf{idx}` naming -- confirm).
        out_dir: directory to save plots into (nothing saved when None).
        title: optional plot title.
    """
    if out_dir is not None:
        out_dir = os.path.join(out_dir, 'energy_traceplots')
        io.check_else_make_dir(out_dir)

    for idx in range(arr.shape[1]):
        arr_ = arr[:, idx, :]
        steps = np.arange(arr_.shape[0])
        chains = np.arange(arr_.shape[1])
        data_arr = xr.DataArray(arr_.T,
                                dims=['chain', 'draw'],
                                coords=[chains, steps])
        new_key = f'{key}_lf{idx}'
        # BUG FIX: `tplot_fname` was only bound when `out_dir` was given,
        # causing an UnboundLocalError for `out_dir=None`.
        # NOTE(review): assumes `mcmc_traceplot` treats `fname=None` as
        # "don't save" -- confirm.
        tplot_fname = None
        if out_dir is not None:
            tplot_fname = os.path.join(out_dir,
                                       f'{new_key}_traceplot.pdf')

        _ = mcmc_traceplot(new_key, data_arr, title, tplot_fname)
コード例 #13
0
def run_hmc(
        args: AttrDict,
        hmc_dir: str = None,
        skip_existing: bool = False,
) -> (GaugeDynamics, DataContainer, tf.Tensor):
    """Run HMC using `inference_args` on a model specified by `params`.

    NOTE:
    -----
    args should be a dict with the following keys:
        - 'hmc'
        - 'eps'
        - 'beta'
        - 'num_steps'
        - 'run_steps'
        - 'lattice_shape'
    """
    if not IS_CHIEF:
        return None, None, None

    if hmc_dir is None:
        # Default location: HMC runs grouped by month under `hmc_logs`.
        hmc_dir = os.path.join(GAUGE_LOGS_DIR, 'hmc_logs',
                               io.get_timestamp('%Y_%m'))

    io.check_else_make_dir(hmc_dir)

    if skip_existing:
        # A run's identity is the `fstr` prefix of its directory name
        # (everything before the first '-').
        existing_fstrs = [
            entry.split('-')[0] for entry in os.listdir(hmc_dir)
        ]
        if io.get_run_dir_fstr(args) in existing_fstrs:
            io.log('ERROR:Existing run found! Skipping.')
            return None, None, None

    dynamics = build_dynamics(args)
    dynamics, run_data, x = run(dynamics, args, runs_dir=hmc_dir)

    return dynamics, run_data, x
コード例 #14
0
    def __init__(
            self,
            steps: int = None,
            # header: str = None,
            dirs: dict[str, Union[str, Path]] = None,
            print_steps: int = 100,
    ):
        """Initialize bookkeeping; create and resolve output directories."""
        self.steps = steps
        self.data_strs = []
        self.steps_arr = []
        self.print_steps = print_steps
        self.data = {}
        # self.data = AttrDict(defaultdict(list))

        if dirs is None:
            dirs = {}
        else:
            # Entries whose key mentions 'file' are file paths, not
            # directories, so only the rest get created here.
            io.check_else_make_dir(
                [v for k, v in dirs.items() if 'file' not in k])

        self.dirs = {k: Path(v).resolve() for k, v in dirs.items()}
コード例 #15
0
    def write_to_csv(self, log_dir, run_dir, hmc=False):
        """Write data averages to bulk csv file for comparing runs."""
        _, run_str = os.path.split(run_dir)
        avg_data = {
            'log_dir': log_dir,
            'run_dir': run_str,
            'hmc': hmc,
        }

        # Record the mean of every tracked observable, in sorted key order.
        for key, val in dict(sorted(self.data.items())).items():
            try:
                avg_data[key] = np.mean(val)
            except AttributeError:
                raise AttributeError(f'Unable to call `val.mean()` on {val}')

        avg_df = pd.DataFrame(avg_data, index=[0])
        csv_file = os.path.join(BASE_DIR, 'logs', 'GaugeModel_logs',
                                'inference.csv')
        head, _ = os.path.split(csv_file)
        io.check_else_make_dir(head)
        logger.info(f'Appending inference results to {csv_file}.')
        # The first write creates the header; later writes append rows.
        if os.path.isfile(csv_file):
            avg_df.to_csv(csv_file, header=False, index=False, mode='a')
        else:
            avg_df.to_csv(csv_file, header=True, index=False, mode='w')
コード例 #16
0
def run_from_log_dir(log_dir: str, net_weights: NetWeights, run_steps=5000):
    """Run inference with the trained networks loaded from `log_dir`.

    Loads configs, restores the x/v networks, overrides `net_weights`,
    performs a short (1000-step) training pass, then runs inference into a
    `beta{...}/nw{...}` run directory under `PROJECT_DIR`.
    """
    configs = load_configs_from_log_dir(log_dir)
    # Older configs used 'lattice_shape'; migrate to the newer 'x_shape' key.
    if 'x_shape' not in configs['dynamics_config'].keys():
        x_shape = configs['dynamics_config'].pop('lattice_shape')
        configs['dynamics_config']['x_shape'] = x_shape

    beta = configs['beta_final']
    # e.g. net_weights (1, 0, 1, ...) -> 'nw101...'
    nwstr = 'nw' + ''.join([f'{int(i)}' for i in net_weights])
    run_dir = os.path.join(PROJECT_DIR, 'l2hmc_function_tests',
                           'inference', f'beta{beta}', f'{nwstr}')
    if os.path.isdir(run_dir):
        # NOTE(review): despite the 'SKIPPING!' message, execution falls
        # through and the existing run directory is reused -- confirm
        # whether an early `return` was intended here.
        io.log(f'EXISTING RUN FOUND AT: {run_dir}, SKIPPING!', style='bold red')

    io.check_else_make_dir(run_dir)
    # Keep the original log_dir around, then point configs at the run dir.
    log_dir = configs.get('log_dir', None)
    configs['log_dir_orig'] = log_dir
    configs['log_dir'] = run_dir
    configs['run_steps'] = run_steps
    configs = AttrDict(configs)

    dynamics = build_dynamics(configs)
    # Restore trained networks from the *original* log dir.
    xnet, vnet = dynamics._load_networks(log_dir)
    dynamics.xnet = xnet
    dynamics.vnet = vnet
    io.log(f'Original dynamics.net_weights: {dynamics.net_weights}')
    io.log(f'Setting `dynamics.net_weights` to: {net_weights}')
    dynamics._set_net_weights(net_weights)
    dynamics.net_weights = net_weights
    io.log(f'Now, dynamics.net_weights: {dynamics.net_weights}')
    # Brief re-training pass before inference -- presumably to
    # re-equilibrate under the new net_weights; TODO confirm.
    dynamics, train_data, x = short_training(1000, beta, log_dir=log_dir,
                                             dynamics=dynamics, x=None)
    inference_results = run(dynamics, configs, beta=beta, runs_dir=run_dir,
                            md_steps=500, make_plots=True, therm_frac=0.2,
                            num_chains=16)

    return inference_results
コード例 #17
0
def run(
        dynamics: GaugeDynamics,
        args: AttrDict,
        x: tf.Tensor = None,
        runs_dir: str = None
) -> (GaugeDynamics, DataContainer, tf.Tensor):
    """Run inference."""
    if not IS_CHIEF:
        return None, None, None

    if runs_dir is None:
        subdir = 'inference_hmc' if dynamics.config.hmc else 'inference'
        runs_dir = os.path.join(args.log_dir, subdir)

    # Record the (numpy) step size on `args` for downstream consumers.
    eps = dynamics.eps
    if hasattr(eps, 'numpy'):
        eps = eps.numpy()
    args.eps = eps

    # Run-directory layout: run_data/, summaries/, run_log.txt.
    io.check_else_make_dir(runs_dir)
    run_dir = io.make_run_dir(args, runs_dir)
    data_dir = os.path.join(run_dir, 'run_data')
    summary_dir = os.path.join(run_dir, 'summaries')
    log_file = os.path.join(run_dir, 'run_log.txt')
    io.check_else_make_dir([run_dir, data_dir, summary_dir])
    writer = tf.summary.create_file_writer(summary_dir)
    writer.set_as_default()

    run_steps = args.get('run_steps', 2000)
    beta = args.get('beta', None)
    beta = args.get('beta_final', None) if beta is None else beta

    if x is None:
        x = convert_to_angle(tf.random.normal(shape=dynamics.x_shape))

    run_data, x, _ = run_dynamics(dynamics, args, x, save_x=False)

    # Persist logs, the bulk csv summary, and (optionally) the raw data.
    run_data.flush_data_strs(log_file, mode='a')
    run_data.write_to_csv(args.log_dir, run_dir, hmc=dynamics.config.hmc)
    io.save_inference(run_dir, run_data)
    if args.get('save_run_data', True):
        run_data.save_data(data_dir)

    run_params = dict(
        eps=eps,
        beta=beta,
        run_steps=run_steps,
        plaq_weight=dynamics.plaq_weight,
        charge_weight=dynamics.charge_weight,
        lattice_shape=dynamics.lattice_shape,
        num_steps=dynamics.config.num_steps,
        net_weights=dynamics.net_weights,
        input_shape=dynamics.x_shape,
    )
    run_params.update(dynamics.params)
    io.save_params(run_params, run_dir, name='run_params')

    args.logging_steps = 1
    plot_data(run_data, run_dir, args, thermalize=True, params=run_params)

    return dynamics, run_data, x
コード例 #18
0
def calc_tau_int_from_dir(
        input_path: str,
        hmc: bool = False,
        px_cutoff: float = None,
        therm_frac: float = 0.2,
        num_pts: int = 50,
        nstart: int = 100,
        make_plot: bool = True,
        save_data: bool = True,
        keep_charges: bool = False,
):
    """Estimate integrated autocorrelation times of the topological charge
    for every run found under `input_path`, optionally saving and plotting.

    NOTE: `load_charges_from_dir` returns `output_arr`:
      `output_arr` is a list of dicts, each of which look like:
            output = {
                'beta': beta,
                'lf': lf,
                'eps': eps,
                'traj_len': lf * eps,
                'qarr': charges,
                'run_params': params,
                'run_dir': run_dir,
            }

    Returns:
        dict mapping `run_dir` -> tau_int summary, or None when no charge
        data could be loaded.

    NOTE(review): `px_cutoff` and `keep_charges` are unused in this body --
    confirm whether they are vestigial.
    """
    output_arr = load_charges_from_dir(input_path, hmc=hmc)
    if output_arr is None:
        logger.info(', '.join([
            'WARNING: Skipping entry!',
            f'\t unable to load charge data from {input_path}',
        ]))

        return None

    tint_data = {}
    for output in output_arr:
        run_dir = output['run_dir']
        beta = output['beta']
        lf = output['lf']
        #  eps = output['eps']
        #  traj_len = output['traj_len']
        run_params = output['run_params']

        # Resolve where this run's data and plots live on disk.
        if hmc:
            data_dir = os.path.join(input_path, 'run_data')
            plot_dir = os.path.join(input_path, 'plots', 'tau_int_plots')
        else:
            run_dir = output['run_params']['run_dir']
            data_dir = os.path.join(run_dir, 'run_data')
            plot_dir = os.path.join(run_dir, 'plots')

        outfile = os.path.join(data_dir, 'tau_int_data.z')
        outfile1 = os.path.join(data_dir, 'tint_data.z')
        fdraws = os.path.join(plot_dir, 'tau_int_vs_draws.pdf')
        ftlen = os.path.join(plot_dir, 'tau_int_vs_traj_len.pdf')
        c1 = os.path.isfile(outfile)
        c11 = os.path.isfile(outfile1)
        c2 = os.path.isfile(fdraws)
        c3 = os.path.isfile(ftlen)
        # NOTE(review): if only the plot files exist (c2/c3 true) but
        # `outfile` is missing, `io.loadz(outfile)` below raises.  Also
        # `outfile` is loaded twice, and the `n`/`tint` read here are
        # unconditionally recomputed by `calc_autocorr` further down.
        if c1 or c11 or c2 or c3:
            loaded = io.loadz(outfile)
            output.update(loaded)
            # NOTE(review): missing comma after the first string below, so
            # the two literals concatenate into a single list element.
            logger.info(', '.join([
                'WARNING: Loading existing data'
                f'\t Found existing data at: {outfile}.',
            ]))
            loaded = io.loadz(outfile)
            n = loaded.get('draws', loaded.get('narr', None))
            tint = loaded.get('tau_int', loaded.get('tint', None))
            output.update(loaded)

        # Average step size: prefer the separate x/v step sizes if recorded.
        xeps_check = 'xeps' in output['run_params'].keys()
        veps_check = 'veps' in output['run_params'].keys()
        if xeps_check and veps_check:
            xeps = tf.reduce_mean(output['run_params']['xeps'])
            veps = tf.reduce_mean(output['run_params']['veps'])
            eps = tf.reduce_mean([xeps, veps]).numpy()
        else:
            eps = output['eps']
            if isinstance(eps, list):
                eps = tf.reduce_mean(eps)
            elif tf.is_tensor(eps):
                try:
                    eps = eps.numpy()
                except AttributeError:
                    eps = tf.reduce_mean(eps)

        traj_len = lf * eps

        # Drop thermalization, then estimate tau_int on the charge series.
        qarr, _ = therm_arr(output['qarr'], therm_frac=therm_frac)

        n, tint = calc_autocorr(qarr.T, num_pts=num_pts, nstart=nstart)
        #  output.update({
        #      'draws': n,
        #      'tau_int': tint,
        #      'qarr.shape': qarr.shape,
        #  })
        tint_data[run_dir] = {
            'run_dir': run_dir,
            'run_params': run_params,
            'lf': lf,
            'eps': eps,
            'traj_len': traj_len,
            'narr': n,
            'tint': tint,
            'qarr.shape': qarr.shape,
        }
        if save_data:
            # NOTE(review): this saves the *accumulated* `tint_data` dict
            # (all runs so far), overwriting `outfile` on every iteration.
            io.savez(tint_data, outfile, 'tint_data')
            #  io.savez(output, outfile, name='tint_data')

        if make_plot:
            #  fbeta = os.path.join(plot_dir, 'tau_int_vs_beta.pdf')
            io.check_else_make_dir(plot_dir)
            prefix = 'HMC' if hmc else 'L2HMC'
            xlabel = 'draws'
            ylabel = r'$\tau_{\mathrm{int}}$ (estimate)'
            title = (f'{prefix}, '
                     + r'$\beta=$' + f'{beta}, '
                     + r'$N_{\mathrm{lf}}=$' + f'{lf}, '
                     + r'$\varepsilon=$' + f'{eps:.2g}, '
                     + r'$\lambda=$' + f'{traj_len:.2g}')

            # tau_int vs number of draws; one curve per chain (column).
            _, ax = plt.subplots(constrained_layout=True)
            best = []
            for t in tint.T:
                _ = ax.plot(n, t, marker='.', color='k')
                best.append(t[-1])

            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(xlabel)
            _ = ax.set_title(title)

            _ = ax.set_xscale('log')
            _ = ax.set_yscale('log')
            _ = ax.grid(alpha=0.4)
            logger.info(f'Saving figure to: {fdraws}')
            _ = plt.savefig(fdraws, dpi=400, bbox_inches='tight')
            plt.close('all')

            # Final (largest-n) tau_int estimates vs trajectory length.
            _, ax = plt.subplots()
            for b in best:
                _ = ax.plot(traj_len, b, marker='.', color='k')
            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(r'trajectory length, $\lambda$')
            _ = ax.set_title(title)
            _ = ax.set_yscale('log')
            _ = ax.grid(True, alpha=0.4)
            logger.info(f'Saving figure to: {ftlen}')
            _ = plt.savefig(ftlen, dpi=400, bbox_inches='tight')
            plt.close('all')

    return tint_data
コード例 #19
0
ファイル: inference_utils.py プロジェクト: saforem2/l2hmc-qcd
def run(
        dynamics: GaugeDynamics,
        configs: dict[str, Any],
        x: tf.Tensor = None,
        beta: float = None,
        runs_dir: str = None,
        make_plots: bool = True,
        therm_frac: float = 0.33,
        num_chains: int = 16,
        save_x: bool = False,
        md_steps: int = 50,
        skip_existing: bool = False,
        save_dataset: bool = True,
        use_hdf5: bool = False,
        skip_keys: list[str] = None,
        run_steps: int = None,
) -> InferenceResults:
    """Run inference. (Note: Higher-level than `run_dynamics`).

    Sets up a run directory under `runs_dir`, runs the dynamics, then saves
    parameters, plots, data and profiling info.  On non-chief ranks an
    all-None `InferenceResults` is returned immediately.
    """
    if not IS_CHIEF:
        return InferenceResults(None, None, None, None, None)

    if num_chains > 16:
        logger.warning(f'Reducing `num_chains` from: {num_chains} to {16}.')
        num_chains = 16

    if run_steps is None:
        run_steps = configs.get('run_steps', 50000)

    if beta is None:
        beta = configs.get('beta_final', configs.get('beta', None))

    assert beta is not None

    logdir = configs.get('log_dir', configs.get('logdir', None))
    if runs_dir is None:
        rs = 'inference_hmc' if dynamics.config.hmc else 'inference'
        runs_dir = os.path.join(logdir, rs)

    # Run-directory layout: run_data/, summaries/, run_log.txt.
    io.check_else_make_dir(runs_dir)
    run_dir = io.make_run_dir(configs=configs, base_dir=runs_dir,
                              beta=beta, skip_existing=skip_existing)
    logger.info(f'run_dir: {run_dir}')
    data_dir = os.path.join(run_dir, 'run_data')
    summary_dir = os.path.join(run_dir, 'summaries')
    log_file = os.path.join(run_dir, 'run_log.txt')
    io.check_else_make_dir([run_dir, data_dir, summary_dir])
    writer = tf.summary.create_file_writer(summary_dir)
    writer.set_as_default()

    configs['logging_steps'] = 1
    if x is None:
        x = convert_to_angle(tf.random.uniform(shape=dynamics.x_shape,
                                               minval=-PI, maxval=PI))

    # == RUN DYNAMICS =======================================================
    nw = dynamics.net_weights
    inf_type = 'HMC' if dynamics.config.hmc else 'inference'
    logger.info(', '.join([f'Running {inf_type}', f'beta={beta}', f'nw={nw}']))
    t0 = time.time()
    results = run_dynamics(dynamics, flags=configs, beta=beta, x=x,
                           writer=writer, save_x=save_x, md_steps=md_steps)
    logger.info(f'Done running {inf_type}. took: {time.time() - t0:.4f} s')
    # =======================================================================

    run_data = results.run_data
    run_data.update_dirs({'log_dir': logdir, 'run_dir': run_dir})
    run_params = {
        'hmc': dynamics.config.hmc,
        'run_dir': run_dir,
        'beta': beta,
        'run_steps': run_steps,
        'plaq_weight': dynamics.plaq_weight,
        'charge_weight': dynamics.charge_weight,
        'x_shape': dynamics.x_shape,
        'num_steps': dynamics.config.num_steps,
        'net_weights': dynamics.net_weights,
        'input_shape': dynamics.x_shape,
    }

    inf_log_fpath = os.path.join(run_dir, 'inference_log.txt')
    with open(inf_log_fpath, 'a') as f:
        f.writelines('\n'.join(results.data_strs))

    # BUG FIX: `traj_len` was previously computed from `dynamics.xeps`
    # *before* the `hasattr` guard below, so a dynamics object without
    # `xeps` crashed despite the guard.  Compute it per-branch instead.
    traj_len = None
    if hasattr(dynamics, 'xeps') and hasattr(dynamics, 'veps'):
        xeps_avg = tf.reduce_mean(dynamics.xeps)
        veps_avg = tf.reduce_mean(dynamics.veps)
        traj_len = tf.reduce_sum(dynamics.xeps)
        run_params.update({
            'xeps': dynamics.xeps,
            'traj_len': traj_len,
            'veps': dynamics.veps,
            'xeps_avg': xeps_avg,
            'veps_avg': veps_avg,
            'eps_avg': (xeps_avg + veps_avg) / 2.,
        })

    elif hasattr(dynamics, 'eps'):
        # Fixed-step-size dynamics: mirror the averaged quantities so
        # `tint_data` below can rely on them.
        eps_avg = tf.reduce_mean(dynamics.eps)
        traj_len = dynamics.config.num_steps * eps_avg
        run_params.update({
            'eps': dynamics.eps,
            'eps_avg': eps_avg,
            'traj_len': traj_len,
        })

    io.save_params(run_params, run_dir, name='run_params')

    save_times = {}
    plot_times = {}
    if make_plots:
        output = plot_data(data_container=run_data,
                           configs=configs,
                           params=run_params,
                           out_dir=run_dir,
                           hmc=dynamics.config.hmc,
                           therm_frac=therm_frac,
                           num_chains=num_chains,
                           profile=True,
                           cmap='crest',
                           logging_steps=1)

        save_times = io.SortedDict(**output['save_times'])
        plot_times = io.SortedDict(**output['plot_times'])
        dt1 = io.SortedDict(**plot_times['data_container.plot_dataset'])
        plot_times['data_container.plot_dataset'] = dt1

        tint_data = {
            'beta': beta,
            'run_dir': run_dir,
            'traj_len': traj_len,
            'run_params': run_params,
            # BUG FIX: use `.get` -- 'eps_avg' is absent when neither
            # step-size branch above ran.
            'eps': run_params.get('eps_avg'),
            'lf': run_params['num_steps'],
            'narr': output['tint_dict']['narr'],
            'tint': output['tint_dict']['tint'],
        }

        t0 = time.time()
        tint_file = os.path.join(run_dir, 'tint_data.z')
        io.savez(tint_data, tint_file, 'tint_data')
        save_times['savez_tint_data'] = time.time() - t0

    t0 = time.time()

    logfile = os.path.join(run_dir, 'inference.log')
    run_data.save_and_flush(data_dir=data_dir,
                            log_file=logfile,
                            use_hdf5=use_hdf5,
                            skip_keys=skip_keys,
                            save_dataset=save_dataset)

    save_times['run_data.flush_data_strs'] = time.time() - t0

    t0 = time.time()
    try:
        run_data.write_to_csv(logdir, run_dir, hmc=dynamics.config.hmc)
    except TypeError:
        logger.warning('Unable to write to csv. Continuing...')

    save_times['run_data.write_to_csv'] = time.time() - t0

    # Persist profiling info (how long saving / plotting took).
    profdir = os.path.join(run_dir, 'profile_info')
    io.check_else_make_dir(profdir)
    io.save_dict(plot_times, profdir, name='plot_times')
    io.save_dict(save_times, profdir, name='save_times')

    return InferenceResults(dynamics=results.dynamics,
                            x=results.x, x_arr=results.x_arr,
                            run_data=results.run_data,
                            data_strs=results.data_strs)
コード例 #20
0
ファイル: inference_utils.py プロジェクト: saforem2/l2hmc-qcd
def run_inference_from_log_dir(
        log_dir: str,
        run_steps: int = 50000,
        beta: float = None,
        eps: float = None,
        make_plots: bool = True,
        train_steps: int = 10,
        therm_frac: float = 0.33,
        batch_size: int = 16,
        num_chains: int = 16,
        x: tf.Tensor = None,
) -> InferenceResults:
    """Run inference by loading networks in from `log_dir`.

    Args:
        log_dir: Directory containing saved configs and trained networks.
        run_steps: Number of inference steps to run.
        beta: If given, overrides both `beta` and `beta_final` in configs.
        eps: If given, overrides the saved leapfrog step size.
        make_plots: Whether to generate plots of the inference results.
        train_steps: Number of short (re-)training steps before inference;
            when 0, the model is only compiled.
        therm_frac: Fraction of each chain discarded as thermalization.
        batch_size: If given, replaces the leading (batch) dim of `x_shape`.
        num_chains: Number of chains used when plotting.
        x: Optional initial configuration; random angles when omitted.

    Returns:
        The `InferenceResults` produced by `run`.

    Raises:
        FileNotFoundError: If no configs can be located in `log_dir`.
    """
    configs = _find_configs(log_dir)
    if configs is None:
        raise FileNotFoundError(
            f'Unable to load configs from `log_dir`: {log_dir}. Exiting'
        )

    if eps is not None:
        configs['dynamics_config']['eps'] = eps
    else:
        # Prefer the step size saved during training; fall back to configs.
        try:
            eps_file = os.path.join(log_dir, 'training', 'models', 'eps.z')
            eps = io.loadz(eps_file)
        except FileNotFoundError:
            # BUG FIX: default to `{}` (not `None`) so a missing
            # 'dynamics_config' entry cannot raise AttributeError on the
            # chained `.get` call.
            eps = configs.get('dynamics_config', {}).get('eps', None)

    if beta is not None:
        configs.update({'beta': beta, 'beta_final': beta})

    if batch_size is not None:
        # Replace only the leading (batch) dimension of the lattice shape.
        batch_size = int(batch_size)
        prev_shape = configs['dynamics_config']['x_shape']
        new_shape = (batch_size, *prev_shape[1:])
        configs['dynamics_config']['x_shape'] = new_shape

    configs = AttrDict(configs)
    dynamics = build_dynamics(configs)
    xnet, vnet = dynamics._load_networks(log_dir)
    dynamics.xnet = xnet
    dynamics.vnet = vnet

    if train_steps > 0:
        # Brief re-training to settle optimizer / step-size state.
        dynamics, train_data, x = short_training(train_steps,
                                                 configs.beta_final,
                                                 log_dir, dynamics, x=x)
    else:
        dynamics.compile(loss=dynamics.calc_losses,
                         optimizer=dynamics.optimizer,
                         experimental_run_tf_function=False)

    if x is None:
        x = convert_to_angle(tf.random.normal(dynamics.x_shape))

    configs['run_steps'] = run_steps
    configs['print_steps'] = max((run_steps // 100, 1))
    configs['md_steps'] = 100
    runs_dir = os.path.join(log_dir, 'LOADED', 'inference')
    io.check_else_make_dir(runs_dir)
    io.save_dict(configs, runs_dir, name='inference_configs')
    inference_results = run(dynamics=dynamics,
                            configs=configs, x=x, beta=beta,
                            runs_dir=runs_dir, make_plots=make_plots,
                            therm_frac=therm_frac, num_chains=num_chains)

    return inference_results
コード例 #21
0
def make_ridgeplots(
        dataset: xr.Dataset,
        num_chains: int = None,
        out_dir: str = None,
        drop_zeros: bool = False,
        cmap: str = 'viridis_r',
        # default_style: dict = None,
):
    """Make a ridgeplot (one stacked KDE row per leapfrog step) for every
    variable in `dataset` that carries a 'leapfrog' dimension.

    Args:
        dataset: `xr.Dataset` whose leapfrog-indexed variables are laid out
            as (chain, leapfrog, draw).
        num_chains: If given, only the first `num_chains` chains are kept
            for plotting.
        out_dir: If given, each figure is saved as '{key}_ridgeplot.svg'
            inside this directory (created if necessary).
        drop_zeros: If True, exact zeros are dropped before the KDE.
        cmap: Seaborn palette name used to color the rows.

    Returns:
        Tuple `(fig, ax, data)` where `fig`/`ax` are the *current*
        matplotlib figure/axes (from the last plot drawn) and `data` maps
        each plotted key to the DataFrame used for its FacetGrid.
    """
    data = {}
    # Transparent axes background so overlapping KDE rows blend visually.
    with sns.axes_style('white', rc={'axes.facecolor': (0, 0, 0, 0)}):
        for key, val in dataset.data_vars.items():
            # Only variables with a per-leapfrog dimension get a ridgeplot.
            if 'leapfrog' in val.coords.dims:
                # Long-form data: one column of samples, one 'lf' label col.
                lf_data = {
                    key: [],
                    'lf': [],
                }
                for lf in val.leapfrog.values:
                    # val.shape = (chain, leapfrog, draw)
                    # x.shape = (chain, draw);  selects data for a single lf
                    x = val[{'leapfrog': lf}].values
                    # if num_chains is not None, keep `num_chains` for plotting
                    if num_chains is not None:
                        x = x[:num_chains, :]

                    x = x.flatten()
                    if drop_zeros:
                        x = x[x != 0]
                    # Tag every sample with its leapfrog index (as a string)
                    # so seaborn can facet the rows by 'lf'.
                    lf_arr = np.array(len(x) * [f'{lf}'])
                    lf_data[key].extend(x)
                    lf_data['lf'].extend(lf_arr)

                lfdf = pd.DataFrame(lf_data)
                data[key] = lfdf

                # Initialize the FacetGrid object: one row per leapfrog.
                ncolors = len(val.leapfrog.values)
                pal = sns.color_palette(cmap, n_colors=ncolors)
                g = sns.FacetGrid(lfdf, row='lf', hue='lf',
                                  aspect=15, height=0.25, palette=pal)

                # Draw the densities in a few steps
                _ = g.map(sns.kdeplot, key, cut=1,
                          shade=True, alpha=0.7, linewidth=1.25)
                _ = g.map(plt.axhline, y=0, lw=1.5, alpha=0.7, clip_on=False)

                # Define and use a simple function to
                # label the plot in axes coords:
                def label(x, color, label):
                    ax = plt.gca()
                    ax.text(0, 0.10, label, fontweight='bold', color=color,
                            ha='left', va='center', transform=ax.transAxes,
                            fontsize='small')

                _ = g.map(label, key)
                # Set the subplots to overlap
                _ = g.fig.subplots_adjust(hspace=-0.75)
                # Remove the axes details that don't play well with overlap
                _ = g.set_titles('')
                _ = g.set(yticks=[])
                _ = g.set(yticklabels=[])
                _ = g.despine(bottom=True, left=True)
                if out_dir is not None:
                    io.check_else_make_dir(out_dir)
                    out_file = os.path.join(out_dir, f'{key}_ridgeplot.svg')
                    #  logger.log(f'Saving figure to: {out_file}.')
                    plt.savefig(out_file, dpi=400, bbox_inches='tight')

            # NOTE(review): figures are intentionally left open here
            # (`plt.close('all')` is disabled) so gcf()/gca() below can
            # return the last plot — but this leaks figures across keys.

    # Return handles to the most recently drawn figure/axes.
    fig = plt.gcf()
    ax = plt.gca()

    return fig, ax, data
コード例 #22
0
def savefig(fig, fpath):
    """Ensure the target directory exists, log, and save `fig` to `fpath`."""
    target_dir = os.path.dirname(fpath)
    io.check_else_make_dir(target_dir)
    io.log(f'Saving figure to: {fpath}.')
    fig.savefig(fpath, dpi=400, bbox_inches='tight')
コード例 #23
0
def savefig(fig, fpath):
    """Save `fig` to `fpath` (creating the directory), then release it."""
    head = os.path.dirname(fpath)
    io.check_else_make_dir(head)
    fig.savefig(fpath, dpi=400, bbox_inches='tight')
    # Free the figure to avoid accumulating open matplotlib figures.
    fig.clf()
    plt.close('all')
コード例 #24
0
    def make_plots(self,
                   run_params,
                   run_data=None,
                   energy_data=None,
                   runs_np=True,
                   out_dir=None):
        """Create trace + KDE plots of lattice observables and energy data.

        Args:
            run_params: Dict of run parameters; must contain 'run_str'.
            run_data: Observables used to build the xarray dataset.
            energy_data: Optional energy data; plotted when provided.
            runs_np: If True, save under 'figures_np', else 'figures_tf'.
            out_dir: Optional extra directory to mirror the plots into.

        Returns:
            Tuple `(dataset, energy_dataset)`; both are None when
            `self._plot_setup` raises FileNotFoundError.
        """
        type_str = 'figures_np' if runs_np else 'figures_tf'
        figs_dir = os.path.join(self._log_dir, type_str)
        fig_dir = os.path.join(figs_dir, run_params['run_str'])
        io.check_else_make_dir(fig_dir)

        dataset = None
        energy_dataset = None
        try:
            fname, title_str = self._plot_setup(run_params)
        except FileNotFoundError:
            return dataset, energy_dataset

        tp_fname = f'{fname}_traceplot'
        pp_fname = f'{fname}_posterior'
        rp_fname = f'{fname}_ridgeplot'

        dataset = self.build_dataset(run_data, run_params)

        tp_out_file = os.path.join(fig_dir, f'{tp_fname}.pdf')
        pp_out_file = os.path.join(fig_dir, f'{pp_fname}.pdf')

        var_names = ['tunneling_rate', 'plaqs_diffs']
        if hasattr(dataset, 'dx'):
            var_names.append('dx')
        var_names.extend(['accept_prob', 'charges_squared', 'charges'])

        # BUG FIX: the sentinels were previously named `tp_out_file_` /
        # `pp_out_file_`, so `tp_out_file1` / `pp_out_file1` (used below)
        # were undefined whenever `out_dir` was None -> NameError.
        tp_out_file1 = None
        pp_out_file1 = None
        if out_dir is not None:
            io.check_else_make_dir(out_dir)
            tp_out_file1 = os.path.join(out_dir, f'{tp_fname}.pdf')
            pp_out_file1 = os.path.join(out_dir, f'{pp_fname}.pdf')

        # Traceplot + posterior plot of observables.
        self._plot_trace(dataset,
                         tp_out_file,
                         var_names=var_names,
                         out_file1=tp_out_file1)

        self._plot_posterior(dataset,
                             pp_out_file,
                             var_names=var_names,
                             out_file1=pp_out_file1)

        # Ridgeplot of plaquette differences.
        rp_out_file = os.path.join(fig_dir, f'{rp_fname}.pdf')
        _ = az.plot_forest(dataset,
                           kind='ridgeplot',
                           var_names=['plaqs_diffs'],
                           ridgeplot_alpha=0.4,
                           ridgeplot_overlap=0.1,
                           combined=False)
        fig = plt.gcf()
        fig.suptitle(title_str, fontsize='x-large', y=1.025)
        self._savefig(fig, rp_out_file)
        if out_dir is not None:
            rp_out_file1 = os.path.join(out_dir, f'{rp_fname}.pdf')
            self._savefig(fig, rp_out_file1)

        # Traceplot + posterior plot of energy data (optional).
        if energy_data is not None:
            energy_dataset = self.energy_plots(energy_data,
                                               run_params,
                                               fname,
                                               out_dir=out_dir)

        return dataset, energy_dataset
コード例 #25
0
def plot_data(
        data_container: "DataContainer",  # noqa:F821
        configs: dict = None,
        out_dir: str = None,
        therm_frac: float = 0,
        params: Union[dict, AttrDict] = None,
        hmc: bool = None,
        num_chains: int = 32,
        profile: bool = False,
        cmap: str = 'crest',
        verbose: bool = False,
        logging_steps: int = 1,
) -> dict:
    """Plot data from `data_container.data`.

    Args:
        data_container: Container whose `.data` dict holds the observables.
        configs: Run configs; used for the title params and leapfrog count.
        out_dir: If given, plots are written under `out_dir/plots_<stamp>`.
        therm_frac: Fraction of each chain dropped as thermalization.
        params: Run parameters used for the title; built from `configs`
            when omitted.
        hmc: Whether the data came from an HMC run; when None, falls back
            to `params.get('hmc', False)`.
        num_chains: Max number of chains to plot (capped at 16).
        profile: Passed through to `data_container.plot_dataset`.
        cmap: Colormap name for the dataset plots.
        verbose: If True, keep every key in `data_container.data`.
        logging_steps: Stride between recorded MC steps.

    Returns:
        Dict with the plotted data, output dir, and plot/save timings.
    """
    if verbose:
        keep_strs = list(data_container.data.keys())
    else:
        keep_strs = [
            'charges', 'plaqs', 'accept_prob',
            'Hf_start', 'Hf_mid', 'Hf_end',
            'Hb_start', 'Hb_mid', 'Hb_end',
            'Hwf_start', 'Hwf_mid', 'Hwf_end',
            'Hwb_start', 'Hwb_mid', 'Hwb_end',
            'xeps_start', 'xeps_mid', 'xeps_end',
            'veps_start', 'veps_mid', 'veps_end'
        ]

    with_jupyter = in_notebook()
    # -- TODO: --------------------------------------
    #  * Get rid of unnecessary `params` argument,
    #    all of the entries exist in `configs`.
    # ----------------------------------------------
    if num_chains > 16:
        logger.warning(
            f'Reducing `num_chains` from {num_chains} to 16 for plotting.'
        )
        num_chains = 16

    plot_times = {}
    save_times = {}

    title = None
    if params is not None:
        try:
            title = get_title_str_from_params(params)
        except Exception:  # BUG FIX: was a bare `except:`
            title = None
    else:
        if configs is not None:
            params = {
                'beta_init': configs['beta_init'],
                'beta_final': configs['beta_final'],
                'x_shape': configs['dynamics_config']['x_shape'],
                'num_steps': configs['dynamics_config']['num_steps'],
                'net_weights': configs['dynamics_config']['net_weights'],
            }
        else:
            params = {}

    tstamp = io.get_timestamp('%Y-%m-%d-%H%M%S')
    plotdir = None
    if out_dir is not None:
        plotdir = os.path.join(out_dir, f'plots_{tstamp}')
        io.check_else_make_dir(plotdir)

    tint_data = {}
    output = {}
    if 'charges' in data_container.data:
        if configs is not None:
            lf = configs['dynamics_config']['num_steps']  # type: int
        else:
            lf = 0
        qarr = np.array(data_container.data['charges'])
        t0 = time.time()
        tint_dict, _ = plot_autocorrs_vs_draws(qarr, num_pts=20,
                                               nstart=1000, therm_frac=0.2,
                                               out_dir=plotdir, lf=lf)
        plot_times['plot_autocorrs_vs_draws'] = time.time() - t0

        tint_data = deepcopy(params)
        tint_data.update({
            'narr': tint_dict['narr'],
            'tint': tint_dict['tint'],
            'run_params': params,
        })

        # Persist integrated-autocorrelation data next to the run, if the
        # run directory actually exists.
        run_dir = params.get('run_dir', None)
        if run_dir is not None:
            if os.path.isdir(str(Path(run_dir))):
                tint_file = os.path.join(run_dir, 'tint_data.z')
                t0 = time.time()
                io.savez(tint_data, tint_file, 'tint_data')
                save_times['tint_data'] = time.time() - t0

        t0 = time.time()
        qsteps = logging_steps * np.arange(qarr.shape[0])
        _ = plot_charges(qsteps, qarr, out_dir=plotdir, title=title)
        plot_times['plot_charges'] = time.time() - t0

        output.update({
            'tint_dict': tint_dict,
            'charges_steps': qsteps,
            'charges_arr': qarr,
        })

    hmc = params.get('hmc', False) if hmc is None else hmc

    data_dict = {}
    data_vars = {}
    for key, val in data_container.data.items():
        if key in SKEYS and key not in keep_strs:
            continue

        # Skip logdet-related observables when data came from an HMC run.
        # BUG FIX: the `continue` previously sat inside an inner loop over
        # the skip strings, so it only advanced that inner loop and the
        # key was never actually skipped.
        if hmc and any(s in key for s in ('Hw', 'ld', 'sld', 'sumlogdet')):
            continue

        arr = np.array(val)
        steps = logging_steps * np.arange(len(arr))

        # Drop thermalization exactly once.
        # BUG FIX: a duplicated block used to call `therm_arr` a second
        # time when `logging_steps == 1`, discarding `therm_frac` of the
        # already-trimmed chain again.
        if therm_frac > 0:
            if logging_steps == 1:
                arr, steps = therm_arr(arr, therm_frac=therm_frac)
            else:
                drop = int(therm_frac * arr.shape[0])
                arr = arr[drop:]
                steps = steps[drop:]

        # -- arr.shape: (draws,) = (D,) -----------------------------------
        if len(arr.shape) == 1:
            data_dict[key] = xr.DataArray(arr, dims=['draw'], coords=[steps])

        # -- arr.shape: (draws, chains) = (D, C) --------------------------
        elif len(arr.shape) == 2:
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            data_dict[key] = data_arr
            data_vars[key] = data_arr

        # -- arr.shape: (draws, leapfrogs, chains) = (D, L, C) ------------
        elif len(arr.shape) == 3:
            _, leapfrogs_, chains_ = arr.shape
            chains = np.arange(chains_)
            leapfrogs = np.arange(leapfrogs_)
            data_dict[key] = xr.DataArray(arr.T,  # NOTE: [arr.T] = (C, L, D)
                                          dims=['chain', 'leapfrog', 'draw'],
                                          coords=[chains, leapfrogs, steps])

    plotdir_xr = None
    if plotdir is not None:
        plotdir_xr = os.path.join(plotdir, 'xarr_plots')

    t0 = time.time()
    dataset, dtplot = data_container.plot_dataset(plotdir_xr,
                                                  num_chains=num_chains,
                                                  therm_frac=therm_frac,
                                                  ridgeplots=True,
                                                  cmap=cmap,
                                                  profile=profile)

    if not with_jupyter:
        plt.close('all')

    # Record the total plot time plus per-plot breakdown from `dtplot`.
    plot_times['data_container.plot_dataset'] = {'total': time.time() - t0}
    for key, val in dtplot.items():
        plot_times['data_container.plot_dataset'][key] = val

    if not hmc and 'Hwf' in data_dict.keys():
        t0 = time.time()
        _ = plot_energy_distributions(data_dict, out_dir=plotdir, title=title)
        plot_times['plot_energy_distributions'] = time.time() - t0

    t0 = time.time()
    _ = mcmc_avg_lineplots(data_dict, title, plotdir)
    plot_times['mcmc_avg_lineplots'] = time.time() - t0

    if not with_jupyter:
        plt.close('all')

    output.update({
        'data_container': data_container,
        'data_dict': data_dict,
        'data_vars': data_vars,
        'out_dir': plotdir,
        'save_times': save_times,
        'plot_times': plot_times,
    })

    return output