def dump_configs(x, data_dir, rank=0, local_rank=0):
    """Save configs `x` separately for each rank."""
    outfile = os.path.join(data_dir, f'x_rank{rank}-{local_rank}.z')
    io.log(f'Saving configs from rank {rank}-{local_rank} to: {outfile}.')
    # Make sure the parent directory exists before dumping.
    io.check_else_make_dir(os.path.dirname(outfile))
    joblib.dump(x, outfile)
def save_data(
        self,
        data_dir: Union[str, Path],
        skip_keys: list[str] = None,
        use_hdf5: bool = True,
        save_dataset: bool = True,
        # compression: str = 'gzip',
):
    """Save `self.data` entries to individual files in `output_dir`."""
    data_dir = Path(data_dir)
    io.check_else_make_dir(str(data_dir))
    saved = False
    if use_hdf5:
        # ------------ Save using h5py -------------------
        self.to_h5pyfile(data_dir.joinpath(f'data_rank{RANK}.hdf5'),
                         skip_keys=skip_keys)
        saved = True
    if save_dataset:
        # ----- Save to `netCDF4` file using hierarchical grouping ------
        self.save_dataset(data_dir)
        saved = True
    if not saved:
        # Fallback: save each `{key: val}` entry in `self.data` as a
        # compressed file named `f'{key}.z'`.
        keys_to_skip = skip_keys if skip_keys is not None else []
        logger.info(f'Saving individual entries to {data_dir}.')
        for key, val in self.data.items():
            if key in keys_to_skip:
                continue
            io.savez(np.array(val), os.path.join(data_dir, f'{key}.z'))
def save_networks(self, log_dir):
    """Save networks to disk."""
    models_dir = os.path.join(log_dir, 'training', 'models')
    io.check_else_make_dir(models_dir)
    io.savez(self.eps.numpy(), os.path.join(models_dir, 'eps.z'), name='eps')
    if self.config.separate_networks:
        # One (xnet, vnet) pair per leapfrog step.
        for idx in range(self.config.num_steps):
            xf = os.path.join(models_dir, f'dynamics_xnet{idx}')
            vf = os.path.join(models_dir, f'dynamics_vnet{idx}')
            io.log(f'Saving `xnet{idx}` to {xf}.')
            io.log(f'Saving `vnet{idx}` to {vf}.')
            self.xnet[idx].save(xf)  # type: tf.keras.models.Model
            self.vnet[idx].save(vf)  # type: tf.keras.models.Model
    else:
        # Single shared network for all leapfrog steps.
        xpath = os.path.join(models_dir, 'dynamics_xnet')
        vpath = os.path.join(models_dir, 'dynamics_vnet')
        io.log(f'Saving `xnet` to {xpath}.')
        io.log(f'Saving `vnet` to {vpath}.')
        self.xnet.save(xpath)
        self.vnet.save(vpath)
def save_dataset(self, out_dir, therm_frac=0.):
    """Save `self.data` as `xr.Dataset` to `out_dir/dataset.nc`."""
    ds = self.get_dataset(therm_frac)
    io.check_else_make_dir(out_dir)
    outfile = Path(out_dir).joinpath('dataset.nc')
    logger.debug(f'Saving dataset to: {outfile}.')
    # Append when the file already exists; otherwise create it.
    ds.to_netcdf(str(outfile), mode='a' if outfile.is_file() else 'w')
def setup_directories(configs: dict) -> dict:
    """Setup directories for training.

    Creates the `logdir` tree, records it in `configs` (under both
    'log_dir' and 'logdir'), and — unless `ensure_new` — tries to locate
    a previous checkpointed run to restore from.

    Raises:
        ValueError: if `ensure_new` is set but `logdir` already exists
            and is non-empty.
    """
    logfile = os.path.join(os.getcwd(), 'log_dirs.txt')  # noqa: F841
    ensure_new = configs.get('ensure_new', False)
    logdir = configs.get('logdir', configs.get('log_dir', None))
    if logdir is not None:
        # BUGFIX: only list the directory if it actually exists;
        # `os.listdir` raises FileNotFoundError on a missing path.
        logdir_nonempty = False
        if os.path.isdir(logdir):
            logdir_nonempty = len(os.listdir(logdir)) > 0
        if logdir_nonempty and ensure_new:
            raise ValueError(
                f'Nonempty `logdir`, but `ensure_new={ensure_new}')
    # Create `logdir`, `logdir/training/...`, etc.
    dirs = io.setup_directories(configs, timestamps=TSTAMPS)
    configs['dirs'] = dirs
    logdir = dirs.get('logdir', dirs.get('log_dir', None))  # type: str
    configs['log_dir'] = logdir
    configs['logdir'] = logdir
    restore_dir = configs.get('restore_from', None)
    if restore_dir is None and not ensure_new:
        # Look for a previous run (with checkpoints) we can restore from.
        candidate = look_for_previous_logdir(logdir)
        if candidate != logdir and candidate.is_dir():
            if len(list(candidate.rglob('checkpoint'))) > 0:
                restore_dir = candidate
                configs['restore_from'] = restore_dir
    if restore_dir is not None and not ensure_new:
        try:
            restored = load_configs_from_logdir(restore_dir)
            if restored is not None:
                io.save_dict(restored, logdir, name='restored_train_configs')
        except FileNotFoundError:
            logger.warning(f'Unable to load configs from {restore_dir}')
    if RANK == 0:
        io.check_else_make_dir(logdir)
    # Same treatment for an explicitly-provided 'restore_dir' entry.
    restore_dir = configs.get('restore_dir', None)
    if restore_dir is not None and not ensure_new:
        try:
            restored = load_configs_from_logdir(restore_dir)
            if restored is not None:
                io.save_dict(restored, logdir, name='restored_train_configs')
        except FileNotFoundError:
            logger.warning(f'Unable to load configs from {restore_dir}')
    return configs
def save_data(self, data_dir, rank=0):
    """Save `self.data` entries to individual files in `output_dir`."""
    # Only rank 0 writes to disk.
    if rank != 0:
        return
    io.check_else_make_dir(data_dir)
    for key, val in self.data.items():
        io.savez(np.array(val), os.path.join(data_dir, f'{key}.z'))
def train_hmc(flags):
    """Main method for training HMC model.

    Pops the HMC-relevant sub-configs out of a copy of `flags`, forces
    the dynamics config into plain-HMC mode (all network machinery
    disabled), trains, and — on the chief rank only — saves and plots
    the training data.

    Returns:
        Tuple `(x, train_data, eps)`: the final configuration, the
        accumulated training data, and the step size as a numpy value.
    """
    # Work on a copy so the caller's `flags` is not mutated by the pops.
    hflags = AttrDict(dict(flags).copy())
    lr_config = AttrDict(hflags.pop('lr_config', None))
    config = AttrDict(hflags.pop('dynamics_config', None))
    net_config = AttrDict(hflags.pop('network_config', None))
    hflags.train_steps = hflags.pop('hmc_steps', None)
    hflags.beta_init = hflags.beta_final
    # Force generic HMC: disable every L2HMC-specific feature.
    config.update({
        'hmc': True,
        'use_ncp': False,
        'aux_weight': 0.,
        'zero_init': False,
        'separate_networks': False,
        'use_conv_net': False,
        'directional_updates': False,
        'use_scattered_xnet_update': False,
        'use_tempered_traj': False,
        'gauge_eq_masks': False,
    })
    hflags.profiler = False
    hflags.make_summaries = True
    # NOTE(review): assumes `hflags.train_steps` is not None here — the
    # integer division below raises TypeError otherwise; confirm callers
    # always provide 'hmc_steps'.
    lr_config = LearningRateConfig(
        warmup_steps=0,
        decay_rate=0.9,
        decay_steps=hflags.train_steps // 10,
        lr_init=lr_config.get('lr_init', None),
    )
    train_dirs = io.setup_directories(hflags, 'training_hmc')
    dynamics = GaugeDynamics(hflags, config, net_config, lr_config)
    dynamics.save_config(train_dirs.config_dir)
    x, train_data = train_dynamics(dynamics, hflags, dirs=train_dirs)
    if IS_CHIEF:
        # Only the chief rank writes outputs / plots.
        output_dir = os.path.join(train_dirs.train_dir, 'outputs')
        io.check_else_make_dir(output_dir)
        train_data.save_data(output_dir)
        params = {
            'eps': dynamics.eps,
            'num_steps': dynamics.config.num_steps,
            'beta_init': hflags.beta_init,
            'beta_final': hflags.beta_final,
            'lattice_shape': dynamics.config.lattice_shape,
            'net_weights': NET_WEIGHTS_HMC,
        }
        plot_data(train_data, train_dirs.train_dir, hflags,
                  thermalize=True, params=params)
    return x, train_data, dynamics.eps.numpy()
def __init__(self, steps, header=None, dirs=None, print_steps=100):
    """Initialize the container; create any directories listed in `dirs`."""
    self.steps = steps
    self.print_steps = print_steps
    self.dirs = dirs
    self.data_strs = [header]
    self.steps_arr = []
    self.data = AttrDict(defaultdict(list))
    if dirs is not None:
        # Create directory entries only; entries whose key mentions
        # 'file' are file paths, not directories.
        dirnames = [v for k, v in dirs.items() if 'file' not in k]
        io.check_else_make_dir(dirnames)
def plot_data(train_data, out_dir, flags, thermalize=False, params=None):
    """Plot every entry of `train_data.data` to `out_dir/plots`.

    1D entries get a lineplot; higher-rank entries get a traceplot and
    are collected into `data_dict` for the summary plots at the end.

    Args:
        train_data: container with a `.data` dict of arrays.
        out_dir: base output directory (plots go to `out_dir/plots`).
        flags: dict-like; `logging_steps` is read (and overridden by a
            saved `FLAGS.z` if one exists in the plot dir).
        thermalize: drop the first third of each series if True.
        params: optional run params used to build the title string.
    """
    out_dir = os.path.join(out_dir, 'plots')
    io.check_else_make_dir(out_dir)
    title = None if params is None else get_title_str_from_params(params)
    logging_steps = flags.get('logging_steps', 1)
    flags_file = os.path.join(out_dir, 'FLAGS.z')
    if os.path.isfile(flags_file):
        # Prefer the logging interval recorded at training time.
        train_flags = io.loadz(flags_file)
        logging_steps = train_flags.get('logging_steps', 1)
    data_dict = {}
    for key, val in train_data.data.items():
        if key == 'x':
            continue
        arr = np.array(val)
        steps = logging_steps * np.arange(len(arr))
        if thermalize or key == 'dt':
            arr, steps = therm_arr(arr, therm_frac=0.33)
        labels = ('MC Step', key)
        data = (steps, arr)
        if len(arr.shape) == 1:
            lplot_fname = os.path.join(out_dir, f'{key}.png')
            _, _ = mcmc_lineplot(data, labels, title,
                                 lplot_fname, show_avg=True)
        elif len(arr.shape) > 1:
            data_dict[key] = data
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            tplot_fname = os.path.join(out_dir, f'{key}_traceplot.png')
            _ = mcmc_traceplot(key, data_arr, title, tplot_fname)
        plt.close('all')
    _ = mcmc_avg_lineplots(data_dict, title, out_dir)
    # BUGFIX: guard against a missing 'charges' entry — previously this
    # raised KeyError whenever no multi-chain charges data was collected.
    if 'charges' in data_dict:
        _ = plot_charges(*data_dict['charges'], out_dir=out_dir, title=title)
    _ = plot_energy_distributions(data_dict, out_dir=out_dir, title=title)
    plt.close('all')
def plot_dataset(
        self,
        out_dir: str = None,
        therm_frac: float = 0.,
        num_chains: int = None,
        ridgeplots: bool = True,
        profile: bool = False,
        skip: Union[list[str], str] = None,
        cmap: str = 'viridis_r',
):
    """Create trace plot + histogram for each entry in self.data.

    Args:
        out_dir: where to save plots (nothing saved if None).
        therm_frac: fraction of each series dropped as thermalization.
        num_chains: keep only the first `num_chains` chains per variable.
        ridgeplots: also produce ridgeplots (saved to `out_dir/ridgeplots`).
        profile: record per-key plotting time in the returned dict.
        skip: keys to skip.
        cmap: palette passed through to `make_ridgeplots`.

    Returns:
        Tuple `(dataset, tdict)` — the `xr.Dataset` that was plotted and
        the per-key timing dict (empty unless `profile`).
    """
    tdict = {}
    dataset = self.get_dataset(therm_frac)
    for key, val in dataset.data_vars.items():
        if skip is not None:
            # NOTE(review): if `skip` is a plain string this is a
            # substring test, not an exact match — confirm intended.
            if key in skip:
                continue
        t0 = time.time()
        try:
            std = np.std(val.values.flatten())
        except TypeError:
            # Non-numeric entries cannot be plotted; skip them.
            continue
        # if np.std(val.values.flatten()) < 1e-2:
        if std < 1e-2:
            # Effectively-constant quantities make useless plots.
            continue
        if len(val.shape) == 2:  # shape: (chain, draw)
            val = val[:num_chains, :]
        if len(val.shape) == 3:  # shape: (chain, leapfrogs, draw)
            val = val[:num_chains, :, :]
        fig, ax = plt.subplots(constrained_layout=True, figsize=set_size())
        _ = val.plot(ax=ax)
        if out_dir is not None:
            io.check_else_make_dir(out_dir)
            out_file = os.path.join(out_dir, f'{key}_xrPlot.svg')
            fig.savefig(out_file, dpi=400, bbox_inches='tight')
            plt.close('all')
            plt.clf()
        if profile:
            # Record how long plotting this key took.
            tdict[key] = time.time() - t0
    if out_dir is not None:
        out_dir = os.path.join(out_dir, 'ridgeplots')
        io.check_else_make_dir(out_dir)
    if ridgeplots:
        make_ridgeplots(dataset, num_chains=num_chains,
                        out_dir=out_dir, cmap=cmap)
    return dataset, tdict
def mcmc_avg_lineplots(data, title=None, out_dir=None):
    """Plot trace of avg.

    For each entry in `data`, draw (1) the chain-averaged trace and
    (2) a KDE of all values. Entries may be `(steps, arr)` pairs (in
    either order, disambiguated by rank) or bare arrays.

    Returns:
        `(fig, axes)` of the last entry plotted (both None if `data`
        is empty).
    """
    fig, axes = None, None
    for idx, (key, val) in enumerate(data.items()):
        fig, axes = plt.subplots(ncols=2, figsize=set_size(subplots=(1, 2)),
                                 constrained_layout=True)
        axes = axes.flatten()
        if len(val) == 2:
            # A (steps, arr) pair; the data array has the higher rank.
            if len(val[0].shape) > len(val[1].shape):
                arr, steps = val
            else:
                steps, arr = val
        else:
            arr = val
            steps = np.arange(arr.shape[0])
        if isinstance(arr, xr.DataArray):
            arr = arr.values
        if len(arr.shape) == 3:
            # TODO: Create separate plots for each leapfrog?
            arr = np.mean(arr, axis=1)
        xlabel = 'MC Step'
        # BUGFIX: inspect `arr`, not `val` — when `val` is a
        # (steps, arr) tuple, `val.shape` raised AttributeError.
        if len(arr.shape) == 1:
            avg = arr
            ylabel = ' '.join(key.split('_'))
        else:
            avg = np.mean(arr, axis=1)
            ylabel = ' '.join(key.split('_')) + r" avg"
        _ = axes[0].plot(steps, avg, color=COLORS[idx])
        _ = axes[0].set_xlabel(xlabel)
        _ = axes[0].set_ylabel(ylabel)
        _ = sns.kdeplot(arr.flatten(), ax=axes[1],
                        color=COLORS[idx], fill=True)
        _ = axes[1].set_xlabel(ylabel)
        _ = axes[1].set_ylabel('')
        if title is not None:
            _ = fig.suptitle(title)
        if out_dir is not None:
            dir_ = os.path.join(out_dir, 'avg_lineplots')
            io.check_else_make_dir(dir_)
            fpath = os.path.join(dir_, f'{key}_avg.pdf')
            savefig(fig, fpath)
    return fig, axes
def energy_traceplot(key, arr, out_dir=None, title=None):
    """Create one traceplot per leapfrog slice of `arr` (chain, lf, draw)."""
    if out_dir is not None:
        out_dir = os.path.join(out_dir, 'energy_traceplots')
        io.check_else_make_dir(out_dir)
    for idx in range(arr.shape[1]):
        lf_slice = arr[:, idx, :]
        steps = np.arange(lf_slice.shape[0])
        chains = np.arange(lf_slice.shape[1])
        data_arr = xr.DataArray(lf_slice.T,
                                dims=['chain', 'draw'],
                                coords=[chains, steps])
        new_key = f'{key}_lf{idx}'
        if out_dir is not None:
            fname = os.path.join(out_dir, f'{new_key}_traceplot.pdf')
            _ = mcmc_traceplot(new_key, data_arr, title, fname)
def run_hmc(
        args: AttrDict,
        hmc_dir: str = None,
        skip_existing: bool = False,
) -> (GaugeDynamics, DataContainer, tf.Tensor):
    """Run HMC using `inference_args` on a model specified by `params`.

    NOTE:
    -----
    args should be a dict with the following keys:
        - 'hmc'
        - 'eps'
        - 'beta'
        - 'num_steps'
        - 'run_steps'
        - 'lattice_shape'
    """
    if not IS_CHIEF:
        return None, None, None

    if hmc_dir is None:
        # Default to a per-month subdirectory of the global HMC logs dir.
        hmc_dir = os.path.join(GAUGE_LOGS_DIR, 'hmc_logs',
                               io.get_timestamp('%Y_%m'))
    io.check_else_make_dir(hmc_dir)

    if skip_existing:
        # Compare this run's fstr (the part before the first '-')
        # against those of every existing run directory.
        existing_fstrs = []
        for entry in os.listdir(hmc_dir):
            _, tail = os.path.split(os.path.join(hmc_dir, entry))
            existing_fstrs.append(tail.split('-')[0])
        if io.get_run_dir_fstr(args) in existing_fstrs:
            io.log('ERROR:Existing run found! Skipping.')
            return None, None, None

    dynamics = build_dynamics(args)
    dynamics, run_data, x = run(dynamics, args, runs_dir=hmc_dir)
    return dynamics, run_data, x
def __init__(
        self,
        steps: int = None,
        # header: str = None,
        dirs: dict[str, Union[str, Path]] = None,
        print_steps: int = 100,
):
    """Initialize the container and create any directories in `dirs`."""
    self.steps = steps
    self.print_steps = print_steps
    self.data_strs = []
    self.steps_arr = []
    self.data = {}
    # self.data = AttrDict(defaultdict(list))
    if dirs is None:
        dirs = {}
    else:
        # Entries whose key mentions 'file' are file paths, not dirs.
        io.check_else_make_dir([v for k, v in dirs.items()
                                if 'file' not in k])
    self.dirs = {key: Path(val).resolve() for key, val in dirs.items()}
def write_to_csv(self, log_dir, run_dir, hmc=False):
    """Write data averages to bulk csv file for comparing runs."""
    _, run_str = os.path.split(run_dir)
    avg_data = {'log_dir': log_dir, 'run_dir': run_str, 'hmc': hmc}
    for key, val in dict(sorted(self.data.items())).items():
        try:
            avg_data[key] = np.mean(val)
        except AttributeError:
            raise AttributeError(f'Unable to call `val.mean()` on {val}')

    avg_df = pd.DataFrame(avg_data, index=[0])
    csv_file = os.path.join(BASE_DIR, 'logs', 'GaugeModel_logs',
                            'inference.csv')
    io.check_else_make_dir(os.path.dirname(csv_file))
    logger.info(f'Appending inference results to {csv_file}.')
    # Write the header only on first creation; append otherwise.
    if os.path.isfile(csv_file):
        avg_df.to_csv(csv_file, header=False, index=False, mode='a')
    else:
        avg_df.to_csv(csv_file, header=True, index=False, mode='w')
def run_from_log_dir(log_dir: str, net_weights: NetWeights, run_steps=5000):
    """Run inference with custom `net_weights`, loading networks from `log_dir`."""
    configs = load_configs_from_log_dir(log_dir)
    # Migrate old-style configs: 'lattice_shape' was renamed 'x_shape'.
    if 'x_shape' not in configs['dynamics_config'].keys():
        x_shape = configs['dynamics_config'].pop('lattice_shape')
        configs['dynamics_config']['x_shape'] = x_shape
    beta = configs['beta_final']
    nwstr = 'nw' + ''.join([f'{int(i)}' for i in net_weights])
    run_dir = os.path.join(PROJECT_DIR, 'l2hmc_function_tests', 'inference',
                           f'beta{beta}', f'{nwstr}')
    if os.path.isdir(run_dir):
        # NOTE(review): this logs "SKIPPING!" but execution falls through
        # and the run proceeds anyway — looks like a missing `return`
        # here; confirm intended behavior before changing.
        io.log(f'EXISTING RUN FOUND AT: {run_dir}, SKIPPING!',
               style='bold red')
    io.check_else_make_dir(run_dir)
    # Redirect outputs to `run_dir` while remembering the original logdir.
    log_dir = configs.get('log_dir', None)
    configs['log_dir_orig'] = log_dir
    configs['log_dir'] = run_dir
    configs['run_steps'] = run_steps
    configs = AttrDict(configs)
    dynamics = build_dynamics(configs)
    # Load the trained networks from the ORIGINAL log dir.
    xnet, vnet = dynamics._load_networks(log_dir)
    dynamics.xnet = xnet
    dynamics.vnet = vnet
    io.log(f'Original dynamics.net_weights: {dynamics.net_weights}')
    io.log(f'Setting `dynamics.net_weights` to: {net_weights}')
    dynamics._set_net_weights(net_weights)
    dynamics.net_weights = net_weights
    io.log(f'Now, dynamics.net_weights: {dynamics.net_weights}')
    # Brief training pass before running inference.
    dynamics, train_data, x = short_training(1000, beta, log_dir=log_dir,
                                             dynamics=dynamics, x=None)
    inference_results = run(dynamics, configs, beta=beta, runs_dir=run_dir,
                            md_steps=500, make_plots=True,
                            therm_frac=0.2, num_chains=16)
    return inference_results
def run(
        dynamics: GaugeDynamics,
        args: AttrDict,
        x: tf.Tensor = None,
        runs_dir: str = None
) -> (GaugeDynamics, DataContainer, tf.Tensor):
    """Run inference.

    Builds the run directory tree under `runs_dir` (derived from
    `args.log_dir` if None), runs `run_dynamics`, then saves data,
    csv summaries, run params, and plots.

    Returns:
        `(dynamics, run_data, x)`, or `(None, None, None)` on
        non-chief ranks.
    """
    if not IS_CHIEF:
        # Only the chief rank runs inference and writes output.
        return None, None, None
    if runs_dir is None:
        if dynamics.config.hmc:
            runs_dir = os.path.join(args.log_dir, 'inference_hmc')
        else:
            runs_dir = os.path.join(args.log_dir, 'inference')
    eps = dynamics.eps
    if hasattr(eps, 'numpy'):
        # Unwrap a tf variable/tensor into a plain numpy value.
        eps = eps.numpy()
    args.eps = eps
    io.check_else_make_dir(runs_dir)
    run_dir = io.make_run_dir(args, runs_dir)
    data_dir = os.path.join(run_dir, 'run_data')
    summary_dir = os.path.join(run_dir, 'summaries')
    log_file = os.path.join(run_dir, 'run_log.txt')
    io.check_else_make_dir([run_dir, data_dir, summary_dir])
    writer = tf.summary.create_file_writer(summary_dir)
    writer.set_as_default()
    run_steps = args.get('run_steps', 2000)
    beta = args.get('beta', None)
    if beta is None:
        beta = args.get('beta_final', None)
    if x is None:
        # Random initial configuration, mapped onto angles.
        x = convert_to_angle(tf.random.normal(shape=dynamics.x_shape))
    run_data, x, _ = run_dynamics(dynamics, args, x, save_x=False)
    run_data.flush_data_strs(log_file, mode='a')
    run_data.write_to_csv(args.log_dir, run_dir, hmc=dynamics.config.hmc)
    io.save_inference(run_dir, run_data)
    if args.get('save_run_data', True):
        run_data.save_data(data_dir)
    run_params = {
        'eps': eps,
        'beta': beta,
        'run_steps': run_steps,
        'plaq_weight': dynamics.plaq_weight,
        'charge_weight': dynamics.charge_weight,
        'lattice_shape': dynamics.lattice_shape,
        'num_steps': dynamics.config.num_steps,
        'net_weights': dynamics.net_weights,
        'input_shape': dynamics.x_shape,
    }
    run_params.update(dynamics.params)
    io.save_params(run_params, run_dir, name='run_params')
    # Inference records every step, so plot with unit logging interval.
    args.logging_steps = 1
    plot_data(run_data, run_dir, args, thermalize=True, params=run_params)
    return dynamics, run_data, x
def calc_tau_int_from_dir(
        input_path: str,
        hmc: bool = False,
        px_cutoff: float = None,
        therm_frac: float = 0.2,
        num_pts: int = 50,
        nstart: int = 100,
        make_plot: bool = True,
        save_data: bool = True,
        keep_charges: bool = False,
):
    """Compute integrated autocorrelation times for runs under `input_path`.

    NOTE: `load_charges_from_dir` returns `output_arr`:
      `output_arr` is a list of dicts, each of which look like:
        output = {
            'beta': beta,
            'lf': lf,
            'eps': eps,
            'traj_len': lf * eps,
            'qarr': charges,
            'run_params': params,
            'run_dir': run_dir,
        }
    """
    output_arr = load_charges_from_dir(input_path, hmc=hmc)
    if output_arr is None:
        logger.info(', '.join([
            'WARNING: Skipping entry!',
            f'\t unable to load charge data from {input_path}',
        ]))
        return None

    tint_data = {}
    for output in output_arr:
        run_dir = output['run_dir']
        beta = output['beta']
        lf = output['lf']
        # eps = output['eps']
        # traj_len = output['traj_len']
        run_params = output['run_params']
        if hmc:
            data_dir = os.path.join(input_path, 'run_data')
            plot_dir = os.path.join(input_path, 'plots', 'tau_int_plots')
        else:
            run_dir = output['run_params']['run_dir']
            data_dir = os.path.join(run_dir, 'run_data')
            plot_dir = os.path.join(run_dir, 'plots')

        outfile = os.path.join(data_dir, 'tau_int_data.z')
        outfile1 = os.path.join(data_dir, 'tint_data.z')
        fdraws = os.path.join(plot_dir, 'tau_int_vs_draws.pdf')
        ftlen = os.path.join(plot_dir, 'tau_int_vs_traj_len.pdf')
        c1 = os.path.isfile(outfile)
        c11 = os.path.isfile(outfile1)
        c2 = os.path.isfile(fdraws)
        c3 = os.path.isfile(ftlen)
        if c1 or c11 or c2 or c3:
            # NOTE(review): `outfile` is loaded (twice) whenever ANY of
            # the four artifacts exist — if only the plot files exist,
            # `io.loadz(outfile)` will fail. Also `n`/`tint` loaded here
            # are overwritten by `calc_autocorr` below. Confirm intent.
            loaded = io.loadz(outfile)
            output.update(loaded)
            logger.info(', '.join([
                'WARNING: Loading existing data'
                f'\t Found existing data at: {outfile}.',
            ]))
            loaded = io.loadz(outfile)
            n = loaded.get('draws', loaded.get('narr', None))
            tint = loaded.get('tau_int', loaded.get('tint', None))
            output.update(loaded)

        xeps_check = 'xeps' in output['run_params'].keys()
        veps_check = 'veps' in output['run_params'].keys()
        if xeps_check and veps_check:
            # Separate x/v step sizes: average them into a single eps.
            xeps = tf.reduce_mean(output['run_params']['xeps'])
            veps = tf.reduce_mean(output['run_params']['veps'])
            eps = tf.reduce_mean([xeps, veps]).numpy()
        else:
            eps = output['eps']
            if isinstance(eps, list):
                eps = tf.reduce_mean(eps)
            elif tf.is_tensor(eps):
                try:
                    eps = eps.numpy()
                except AttributeError:
                    eps = tf.reduce_mean(eps)

        traj_len = lf * eps
        # Drop thermalization, then estimate tau_int from the charges.
        qarr, _ = therm_arr(output['qarr'], therm_frac=therm_frac)
        n, tint = calc_autocorr(qarr.T, num_pts=num_pts, nstart=nstart)
        # output.update({
        #     'draws': n,
        #     'tau_int': tint,
        #     'qarr.shape': qarr.shape,
        # })
        tint_data[run_dir] = {
            'run_dir': run_dir,
            'run_params': run_params,
            'lf': lf,
            'eps': eps,
            'traj_len': traj_len,
            'narr': n,
            'tint': tint,
            'qarr.shape': qarr.shape,
        }
        if save_data:
            io.savez(tint_data, outfile, 'tint_data')
            # io.savez(output, outfile, name='tint_data')

        if make_plot:
            # fbeta = os.path.join(plot_dir, 'tau_int_vs_beta.pdf')
            io.check_else_make_dir(plot_dir)
            prefix = 'HMC' if hmc else 'L2HMC'
            xlabel = 'draws'
            ylabel = r'$\tau_{\mathrm{int}}$ (estimate)'
            title = (f'{prefix}, '
                     + r'$\beta=$' + f'{beta}, '
                     + r'$N_{\mathrm{lf}}=$' + f'{lf}, '
                     + r'$\varepsilon=$' + f'{eps:.2g}, '
                     + r'$\lambda=$' + f'{traj_len:.2g}')
            # tau_int vs number of draws, one curve per chain.
            _, ax = plt.subplots(constrained_layout=True)
            best = []
            for t in tint.T:
                _ = ax.plot(n, t, marker='.', color='k')
                best.append(t[-1])
            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(xlabel)
            _ = ax.set_title(title)
            _ = ax.set_xscale('log')
            _ = ax.set_yscale('log')
            _ = ax.grid(alpha=0.4)
            logger.info(f'Saving figure to: {fdraws}')
            _ = plt.savefig(fdraws, dpi=400, bbox_inches='tight')
            plt.close('all')
            # Final tau_int estimate vs trajectory length.
            _, ax = plt.subplots()
            for b in best:
                _ = ax.plot(traj_len, b, marker='.', color='k')
            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(r'trajectory length, $\lambda$')
            _ = ax.set_title(title)
            _ = ax.set_yscale('log')
            _ = ax.grid(True, alpha=0.4)
            logger.info(f'Saving figure to: {ftlen}')
            _ = plt.savefig(ftlen, dpi=400, bbox_inches='tight')
            plt.close('all')

    return tint_data
def run(
        dynamics: GaugeDynamics,
        configs: dict[str, Any],
        x: tf.Tensor = None,
        beta: float = None,
        runs_dir: str = None,
        make_plots: bool = True,
        therm_frac: float = 0.33,
        num_chains: int = 16,
        save_x: bool = False,
        md_steps: int = 50,
        skip_existing: bool = False,
        save_dataset: bool = True,
        use_hdf5: bool = False,
        skip_keys: list[str] = None,
        run_steps: int = None,
) -> InferenceResults:
    """Run inference. (Note: Higher-level than `run_dynamics`)."""
    if not IS_CHIEF:
        # Non-chief ranks return an empty result and do no work.
        return InferenceResults(None, None, None, None, None)
    if num_chains > 16:
        logger.warning(f'Reducing `num_chains` from: {num_chains} to {16}.')
        num_chains = 16
    if run_steps is None:
        run_steps = configs.get('run_steps', 50000)
    if beta is None:
        beta = configs.get('beta_final', configs.get('beta', None))
    assert beta is not None
    logdir = configs.get('log_dir', configs.get('logdir', None))
    if runs_dir is None:
        rs = 'inference_hmc' if dynamics.config.hmc else 'inference'
        runs_dir = os.path.join(logdir, rs)
    io.check_else_make_dir(runs_dir)
    run_dir = io.make_run_dir(configs=configs, base_dir=runs_dir,
                              beta=beta, skip_existing=skip_existing)
    logger.info(f'run_dir: {run_dir}')
    data_dir = os.path.join(run_dir, 'run_data')
    summary_dir = os.path.join(run_dir, 'summaries')
    log_file = os.path.join(run_dir, 'run_log.txt')
    io.check_else_make_dir([run_dir, data_dir, summary_dir])
    writer = tf.summary.create_file_writer(summary_dir)
    writer.set_as_default()
    configs['logging_steps'] = 1
    if x is None:
        # Uniform random angles as the initial configuration.
        x = convert_to_angle(tf.random.uniform(shape=dynamics.x_shape,
                                               minval=-PI, maxval=PI))
    # == RUN DYNAMICS =======================================================
    nw = dynamics.net_weights
    inf_type = 'HMC' if dynamics.config.hmc else 'inference'
    logger.info(', '.join([f'Running {inf_type}',
                           f'beta={beta}',
                           f'nw={nw}']))
    t0 = time.time()
    results = run_dynamics(dynamics, flags=configs, beta=beta, x=x,
                           writer=writer, save_x=save_x, md_steps=md_steps)
    logger.info(f'Done running {inf_type}. took: {time.time() - t0:.4f} s')
    # =======================================================================
    run_data = results.run_data
    run_data.update_dirs({'log_dir': logdir, 'run_dir': run_dir})
    run_params = {
        'hmc': dynamics.config.hmc,
        'run_dir': run_dir,
        'beta': beta,
        'run_steps': run_steps,
        'plaq_weight': dynamics.plaq_weight,
        'charge_weight': dynamics.charge_weight,
        'x_shape': dynamics.x_shape,
        'num_steps': dynamics.config.num_steps,
        'net_weights': dynamics.net_weights,
        'input_shape': dynamics.x_shape,
    }
    inf_log_fpath = os.path.join(run_dir, 'inference_log.txt')
    with open(inf_log_fpath, 'a') as f:
        f.writelines('\n'.join(results.data_strs))
    # NOTE(review): this line reads `dynamics.xeps` unconditionally — it
    # would raise AttributeError before the `hasattr` checks below if
    # the dynamics object has only `eps`; confirm.
    traj_len = dynamics.config.num_steps * tf.reduce_mean(dynamics.xeps)
    if hasattr(dynamics, 'xeps') and hasattr(dynamics, 'veps'):
        xeps_avg = tf.reduce_mean(dynamics.xeps)
        veps_avg = tf.reduce_mean(dynamics.veps)
        traj_len = tf.reduce_sum(dynamics.xeps)
        run_params.update({
            'xeps': dynamics.xeps,
            'traj_len': traj_len,
            'veps': dynamics.veps,
            'xeps_avg': xeps_avg,
            'veps_avg': veps_avg,
            'eps_avg': (xeps_avg + veps_avg) / 2.,
        })
    elif hasattr(dynamics, 'eps'):
        run_params.update({
            'eps': dynamics.eps,
        })
    io.save_params(run_params, run_dir, name='run_params')
    save_times = {}
    plot_times = {}
    if make_plots:
        output = plot_data(data_container=run_data, configs=configs,
                           params=run_params, out_dir=run_dir,
                           hmc=dynamics.config.hmc, therm_frac=therm_frac,
                           num_chains=num_chains, profile=True,
                           cmap='crest', logging_steps=1)
        save_times = io.SortedDict(**output['save_times'])
        plot_times = io.SortedDict(**output['plot_times'])
        dt1 = io.SortedDict(**plot_times['data_container.plot_dataset'])
        plot_times['data_container.plot_dataset'] = dt1
        # NOTE(review): 'eps_avg' only exists when the xeps/veps branch
        # above ran — this raises KeyError for eps-only dynamics; confirm.
        tint_data = {
            'beta': beta,
            'run_dir': run_dir,
            'traj_len': traj_len,
            'run_params': run_params,
            'eps': run_params['eps_avg'],
            'lf': run_params['num_steps'],
            'narr': output['tint_dict']['narr'],
            'tint': output['tint_dict']['tint'],
        }
        t0 = time.time()
        tint_file = os.path.join(run_dir, 'tint_data.z')
        io.savez(tint_data, tint_file, 'tint_data')
        save_times['savez_tint_data'] = time.time() - t0
    t0 = time.time()
    logfile = os.path.join(run_dir, 'inference.log')
    run_data.save_and_flush(data_dir=data_dir, log_file=logfile,
                            use_hdf5=use_hdf5, skip_keys=skip_keys,
                            save_dataset=save_dataset)
    save_times['run_data.flush_data_strs'] = time.time() - t0
    t0 = time.time()
    try:
        run_data.write_to_csv(logdir, run_dir, hmc=dynamics.config.hmc)
    except TypeError:
        logger.warning(f'Unable to write to csv. Continuing...')
    save_times['run_data.write_to_csv'] = time.time() - t0
    # t0 = time.time()
    # io.save_inference(run_dir, run_data)
    # save_times['io.save_inference'] = time.time() - t0
    profdir = os.path.join(run_dir, 'profile_info')
    io.check_else_make_dir(profdir)
    io.save_dict(plot_times, profdir, name='plot_times')
    io.save_dict(save_times, profdir, name='save_times')
    return InferenceResults(dynamics=results.dynamics,
                            x=results.x,
                            x_arr=results.x_arr,
                            run_data=results.run_data,
                            data_strs=results.data_strs)
def run_inference_from_log_dir(
        log_dir: str,
        run_steps: int = 50000,
        beta: float = None,
        eps: float = None,
        make_plots: bool = True,
        train_steps: int = 10,
        therm_frac: float = 0.33,
        batch_size: int = 16,
        num_chains: int = 16,
        x: tf.Tensor = None,
) -> InferenceResults:  # (type: InferenceResults)
    """Run inference by loading networks in from `log_dir`."""
    configs = _find_configs(log_dir)
    if configs is None:
        raise FileNotFoundError(
            f'Unable to load configs from `log_dir`: {log_dir}. Exiting'
        )
    if eps is not None:
        configs['dynamics_config']['eps'] = eps
    else:
        try:
            eps_file = os.path.join(log_dir, 'training', 'models', 'eps.z')
            eps = io.loadz(eps_file)
            # NOTE(review): the `eps` loaded here is never written back
            # into `configs['dynamics_config']` — the dynamics built
            # below uses the configs value; confirm this is intentional.
        except FileNotFoundError:
            eps = configs.get('dynamics_config', None).get('eps', None)
    if beta is not None:
        configs.update({'beta': beta, 'beta_final': beta})
    if batch_size is not None:
        # Resize the leading (batch) dim of the input shape.
        batch_size = int(batch_size)
        prev_shape = configs['dynamics_config']['x_shape']
        new_shape = (batch_size, *prev_shape[1:])
        configs['dynamics_config']['x_shape'] = new_shape
    configs = AttrDict(configs)
    dynamics = build_dynamics(configs)
    # Load the trained networks from `log_dir`.
    xnet, vnet = dynamics._load_networks(log_dir)
    dynamics.xnet = xnet
    dynamics.vnet = vnet
    if train_steps > 0:
        # Brief training pass to rebuild internal state before inference.
        dynamics, train_data, x = short_training(train_steps,
                                                 configs.beta_final,
                                                 log_dir, dynamics, x=x)
    else:
        dynamics.compile(loss=dynamics.calc_losses,
                         optimizer=dynamics.optimizer,
                         experimental_run_tf_function=False)
    _, log_str = os.path.split(log_dir)
    if x is None:
        x = convert_to_angle(tf.random.normal(dynamics.x_shape))
    configs['run_steps'] = run_steps
    configs['print_steps'] = max((run_steps // 100, 1))
    configs['md_steps'] = 100
    runs_dir = os.path.join(log_dir, 'LOADED', 'inference')
    io.check_else_make_dir(runs_dir)
    io.save_dict(configs, runs_dir, name='inference_configs')
    inference_results = run(dynamics=dynamics, configs=configs, x=x,
                            beta=beta, runs_dir=runs_dir,
                            make_plots=make_plots, therm_frac=therm_frac,
                            num_chains=num_chains)
    return inference_results
def make_ridgeplots(
        dataset: xr.Dataset,
        num_chains: int = None,
        out_dir: str = None,
        drop_zeros: bool = False,
        cmap: str = 'viridis_r',
        # default_style: dict = None,
):
    """Make ridgeplots.

    For every variable in `dataset` that has a 'leapfrog' dimension,
    build a per-leapfrog DataFrame and draw an overlapping-KDE
    (ridgeline) FacetGrid; variables WITHOUT a leapfrog dimension are
    skipped entirely.

    Returns:
        `(fig, ax, data)` — the current figure/axes after the last plot
        and a dict mapping variable name -> its per-leapfrog DataFrame.
    """
    data = {}
    # Transparent axes so the stacked KDE rows can visually overlap.
    with sns.axes_style('white', rc={'axes.facecolor': (0, 0, 0, 0)}):
        for key, val in dataset.data_vars.items():
            if 'leapfrog' in val.coords.dims:
                lf_data = {
                    key: [],
                    'lf': [],
                }
                for lf in val.leapfrog.values:
                    # val.shape = (chain, leapfrog, draw)
                    # x.shape = (chain, draw);  selects data for a single lf
                    x = val[{'leapfrog': lf}].values
                    # if num_chains is not None, keep `num_chains` for plotting
                    if num_chains is not None:
                        x = x[:num_chains, :]
                    x = x.flatten()
                    if drop_zeros:
                        x = x[x != 0]
                    # x = val[{'leapfrog': lf}].values.flatten()
                    lf_arr = np.array(len(x) * [f'{lf}'])
                    lf_data[key].extend(x)
                    lf_data['lf'].extend(lf_arr)
                lfdf = pd.DataFrame(lf_data)
                data[key] = lfdf

                # Initialize the FacetGrid object
                ncolors = len(val.leapfrog.values)
                pal = sns.color_palette(cmap, n_colors=ncolors)
                g = sns.FacetGrid(lfdf, row='lf', hue='lf',
                                  aspect=15, height=0.25, palette=pal)

                # Draw the densities in a few steps
                _ = g.map(sns.kdeplot, key, cut=1,
                          shade=True, alpha=0.7, linewidth=1.25)
                _ = g.map(plt.axhline, y=0, lw=1.5, alpha=0.7, clip_on=False)

                # Define and use a simple function to
                # label the plot in axes coords:
                def label(x, color, label):
                    ax = plt.gca()
                    ax.text(0, 0.10, label, fontweight='bold', color=color,
                            ha='left', va='center',
                            transform=ax.transAxes, fontsize='small')

                _ = g.map(label, key)
                # Set the subplots to overlap
                _ = g.fig.subplots_adjust(hspace=-0.75)
                # Remove the axes details that don't play well with overlap
                _ = g.set_titles('')
                _ = g.set(yticks=[])
                _ = g.set(yticklabels=[])
                _ = g.despine(bottom=True, left=True)
                if out_dir is not None:
                    io.check_else_make_dir(out_dir)
                    out_file = os.path.join(out_dir, f'{key}_ridgeplot.svg')
                    # logger.log(f'Saving figure to: {out_file}.')
                    plt.savefig(out_file, dpi=400, bbox_inches='tight')
                #plt.close('all')

    # sns.set(style='whitegrid', palette='bright', context='paper')
    fig = plt.gcf()
    ax = plt.gca()
    return fig, ax, data
def savefig(fig, fpath):
    """Write `fig` to `fpath` at 400 dpi, creating the directory if needed."""
    target_dir = os.path.dirname(fpath)
    io.check_else_make_dir(target_dir)
    io.log(f'Saving figure to: {fpath}.')
    fig.savefig(fpath, dpi=400, bbox_inches='tight')
def savefig(fig, fpath):
    """Save `fig` to `fpath`, then clear it and close all open figures."""
    target_dir = os.path.dirname(fpath)
    io.check_else_make_dir(target_dir)
    # logger.log(f'Saving figure to: {fpath}.')
    fig.savefig(fpath, dpi=400, bbox_inches='tight')
    # Release the figure's memory once it has been written out.
    fig.clf()
    plt.close('all')
def make_plots(self, run_params, run_data=None,
               energy_data=None, runs_np=True, out_dir=None):
    """Create trace + KDE plots of lattice observables and energy data.

    Args:
        run_params: dict of run parameters; must contain 'run_str'.
        run_data: observables to build the dataset from.
        energy_data: if not None, also create energy traceplots/posteriors.
        runs_np: selects the 'figures_np' vs 'figures_tf' output subdir.
        out_dir: optional extra directory to duplicate the figures into.

    Returns:
        (dataset, energy_dataset); both None if `self._plot_setup` raises
        FileNotFoundError.
    """
    type_str = 'figures_np' if runs_np else 'figures_tf'
    figs_dir = os.path.join(self._log_dir, type_str)
    fig_dir = os.path.join(figs_dir, run_params['run_str'])
    io.check_else_make_dir(fig_dir)

    dataset = None
    energy_dataset = None
    try:
        fname, title_str = self._plot_setup(run_params)
    except FileNotFoundError:
        return dataset, energy_dataset

    tp_fname = f'{fname}_traceplot'
    pp_fname = f'{fname}_posterior'
    rp_fname = f'{fname}_ridgeplot'

    dataset = self.build_dataset(run_data, run_params)

    tp_out_file = os.path.join(fig_dir, f'{tp_fname}.pdf')
    pp_out_file = os.path.join(fig_dir, f'{pp_fname}.pdf')

    var_names = ['tunneling_rate', 'plaqs_diffs']
    if hasattr(dataset, 'dx'):
        var_names.append('dx')
    var_names.extend(['accept_prob', 'charges_squared', 'charges'])

    # BUG FIX: these were previously initialized as `tp_out_file_` /
    # `pp_out_file_` but read below as `tp_out_file1` / `pp_out_file1`,
    # raising NameError whenever `out_dir is None`.
    tp_out_file1 = None
    pp_out_file1 = None
    if out_dir is not None:
        io.check_else_make_dir(out_dir)
        tp_out_file1 = os.path.join(out_dir, f'{tp_fname}.pdf')
        pp_out_file1 = os.path.join(out_dir, f'{pp_fname}.pdf')

    ###################################################
    # Create traceplot + posterior plot of observables
    ###################################################
    self._plot_trace(dataset, tp_out_file,
                     var_names=var_names,
                     out_file1=tp_out_file1)
    self._plot_posterior(dataset, pp_out_file,
                         var_names=var_names,
                         out_file1=pp_out_file1)

    # * * * * * * * * * * * * * * * * *
    # Create ridgeplot of plaq diffs  *
    # * * * * * * * * * * * * * * * * *
    rp_out_file = os.path.join(fig_dir, f'{rp_fname}.pdf')
    _ = az.plot_forest(dataset,
                       kind='ridgeplot',
                       var_names=['plaqs_diffs'],
                       ridgeplot_alpha=0.4,
                       ridgeplot_overlap=0.1,
                       combined=False)
    fig = plt.gcf()
    fig.suptitle(title_str, fontsize='x-large', y=1.025)
    self._savefig(fig, rp_out_file)
    if out_dir is not None:
        rp_out_file1 = os.path.join(out_dir, f'{rp_fname}.pdf')
        self._savefig(fig, rp_out_file1)

    # * * * * * * * * * * * * * * * * * * * * * * * * * *
    # Create traceplot + posterior plot of energy data  *
    # * * * * * * * * * * * * * * * * * * * * * * * * * *
    if energy_data is not None:
        energy_dataset = self.energy_plots(energy_data,
                                           run_params,
                                           fname, out_dir=out_dir)

    return dataset, energy_dataset
def plot_data(
        data_container: "DataContainer",  # noqa:F821
        configs: dict = None,
        out_dir: str = None,
        therm_frac: float = 0,
        params: Union[dict, AttrDict] = None,
        hmc: bool = None,
        num_chains: int = 32,
        profile: bool = False,
        cmap: str = 'crest',
        verbose: bool = False,
        logging_steps: int = 1,
) -> dict:
    """Plot data from `data_container.data`.

    Builds xarray DataArrays out of the recorded observables, plots
    topological-charge autocorrelations, the charge history, per-variable
    trace/ridge plots (via `data_container.plot_dataset`), energy
    distributions and running averages.

    Args:
        data_container: holds `.data`, a dict of per-step observables.
        configs: run configuration; used for `lf` and fallback `params`.
        out_dir: if not None, figures go under `out_dir/plots_{tstamp}`.
        therm_frac: fraction of initial draws to drop as thermalization.
        hmc: if True, skip logdet-related quantities; defaults to
            `params.get('hmc', False)` when None.
        num_chains: chains kept for plotting (capped at 16).
        profile / cmap / verbose / logging_steps: plotting knobs.

    Returns:
        dict with the constructed data, output dir, and timing info.
    """
    if verbose:
        keep_strs = list(data_container.data.keys())
    else:
        keep_strs = [
            'charges', 'plaqs', 'accept_prob',
            'Hf_start', 'Hf_mid', 'Hf_end',
            'Hb_start', 'Hb_mid', 'Hb_end',
            'Hwf_start', 'Hwf_mid', 'Hwf_end',
            'Hwb_start', 'Hwb_mid', 'Hwb_end',
            'xeps_start', 'xeps_mid', 'xeps_end',
            'veps_start', 'veps_mid', 'veps_end'
        ]

    with_jupyter = in_notebook()

    # -- TODO: --------------------------------------
    #  * Get rid of unnecessary `params` argument,
    #    all of the entries exist in `configs`.
    # ----------------------------------------------
    if num_chains > 16:
        logger.warning(
            f'Reducing `num_chains` from {num_chains} to 16 for plotting.'
        )
        num_chains = 16

    plot_times = {}
    save_times = {}

    title = None
    if params is not None:
        # Best-effort title; was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit.
        try:
            title = get_title_str_from_params(params)
        except Exception:
            title = None
    else:
        if configs is not None:
            params = {
                'beta_init': configs['beta_init'],
                'beta_final': configs['beta_final'],
                'x_shape': configs['dynamics_config']['x_shape'],
                'num_steps': configs['dynamics_config']['num_steps'],
                'net_weights': configs['dynamics_config']['net_weights'],
            }
        else:
            params = {}

    tstamp = io.get_timestamp('%Y-%m-%d-%H%M%S')
    plotdir = None
    if out_dir is not None:
        plotdir = os.path.join(out_dir, f'plots_{tstamp}')
        io.check_else_make_dir(plotdir)

    tint_data = {}
    output = {}
    if 'charges' in data_container.data:
        if configs is not None:
            lf = configs['dynamics_config']['num_steps']  # type: int
        else:
            lf = 0

        qarr = np.array(data_container.data['charges'])
        t0 = time.time()
        tint_dict, _ = plot_autocorrs_vs_draws(qarr, num_pts=20,
                                               nstart=1000,
                                               therm_frac=0.2,
                                               out_dir=plotdir,
                                               lf=lf)
        plot_times['plot_autocorrs_vs_draws'] = time.time() - t0

        tint_data = deepcopy(params)
        tint_data.update({
            'narr': tint_dict['narr'],
            'tint': tint_dict['tint'],
            'run_params': params,
        })

        run_dir = params.get('run_dir', None)
        if run_dir is not None:
            if os.path.isdir(str(Path(run_dir))):
                tint_file = os.path.join(run_dir, 'tint_data.z')
                t0 = time.time()
                io.savez(tint_data, tint_file, 'tint_data')
                save_times['tint_data'] = time.time() - t0

        t0 = time.time()
        qsteps = logging_steps * np.arange(qarr.shape[0])
        _ = plot_charges(qsteps, qarr, out_dir=plotdir, title=title)
        plot_times['plot_charges'] = time.time() - t0

        output.update({
            'tint_dict': tint_dict,
            'charges_steps': qsteps,
            'charges_arr': qarr,
        })

    hmc = params.get('hmc', False) if hmc is None else hmc

    data_dict = {}
    data_vars = {}
    for key, val in data_container.data.items():
        if key in SKEYS and key not in keep_strs:
            continue

        # Skip logdet-related data when generated from HMC.
        # BUG FIX: the original iterated `for skip_str in [...]` and
        # `continue`d the *inner* loop, so no key was ever skipped.
        if hmc and any(s in key for s in ('Hw', 'ld', 'sld', 'sumlogdet')):
            continue

        arr = np.array(val)
        steps = logging_steps * np.arange(len(arr))

        # BUG FIX: thermalization was applied twice when
        # `logging_steps == 1 and therm_frac > 0` (a second, redundant
        # `therm_arr` call followed this branch); it is now applied once.
        if therm_frac > 0:
            if logging_steps == 1:
                arr, steps = therm_arr(arr, therm_frac=therm_frac)
            else:
                drop = int(therm_frac * arr.shape[0])
                arr = arr[drop:]
                steps = steps[drop:]

        # -- NOTE: arr.shape: (draws,) = (D,) -------------------------------
        if len(arr.shape) == 1:
            data_dict[key] = xr.DataArray(arr, dims=['draw'],
                                          coords=[steps])

        # -- NOTE: arr.shape: (draws, chains) = (D, C) ----------------------
        elif len(arr.shape) == 2:
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            data_dict[key] = data_arr
            data_vars[key] = data_arr

        # -- NOTE: arr.shape: (draws, leapfrogs, chains) = (D, L, C) ---------
        elif len(arr.shape) == 3:
            _, leapfrogs_, chains_ = arr.shape
            chains = np.arange(chains_)
            leapfrogs = np.arange(leapfrogs_)
            data_dict[key] = xr.DataArray(arr.T,  # NOTE: [arr.T] = (C, L, D)
                                          dims=['chain', 'leapfrog', 'draw'],
                                          coords=[chains, leapfrogs, steps])

    plotdir_xr = None
    if plotdir is not None:
        plotdir_xr = os.path.join(plotdir, 'xarr_plots')

    t0 = time.time()
    dataset, dtplot = data_container.plot_dataset(plotdir_xr,
                                                  num_chains=num_chains,
                                                  therm_frac=therm_frac,
                                                  ridgeplots=True,
                                                  cmap=cmap,
                                                  profile=profile)
    if not with_jupyter:
        plt.close('all')

    plot_times['data_container.plot_dataset'] = {'total': time.time() - t0}
    for key, val in dtplot.items():
        plot_times['data_container.plot_dataset'][key] = val

    if not hmc and 'Hwf' in data_dict.keys():
        t0 = time.time()
        _ = plot_energy_distributions(data_dict, out_dir=plotdir, title=title)
        plot_times['plot_energy_distributions'] = time.time() - t0

    t0 = time.time()
    _ = mcmc_avg_lineplots(data_dict, title, plotdir)
    plot_times['mcmc_avg_lineplots'] = time.time() - t0

    if not with_jupyter:
        plt.close('all')

    output.update({
        'data_container': data_container,
        'data_dict': data_dict,
        'data_vars': data_vars,
        'out_dir': plotdir,
        'save_times': save_times,
        'plot_times': plot_times,
    })

    return output