def load_inference_data(dirs, search_strs, inference_str='inference'):
    """Load inference data from each directory in `dirs`.

    Args:
        dirs: Iterable of log directories to search.
        search_strs: Iterable of data-file stems (e.g. 'charges') to load.
        inference_str: Name of the inference subdirectory within each dir.

    Returns:
        dict mapping each search string to a dict keyed by
        `(beta, num_steps, eps)` tuples of the loaded data.
    """
    data = {s: {} for s in search_strs}
    for d in dirs:
        print(f'Looking in dir: {d}...')
        run_dir = Path(os.path.join(d, inference_str))
        if not run_dir.is_dir():
            continue

        for rd in (x for x in run_dir.iterdir() if x.is_dir()):
            print(f'...looking in run_dir: {rd}...')
            rp_file = os.path.join(str(rd), 'run_params.z')
            if not os.path.isfile(rp_file):
                continue

            params = io.loadz(rp_file)
            key = (params['beta'], params['num_steps'], params['eps'])
            data_dir = os.path.join(str(rd), 'run_data')
            if not os.path.isdir(data_dir):
                continue

            for search_str in search_strs:
                dfile = os.path.join(data_dir, f'{search_str}.z')
                if os.path.isfile(dfile):
                    # `data[search_str]` always exists (pre-initialized
                    # above), so plain assignment suffices; the original
                    # `try/except KeyError` branch was unreachable.
                    data[search_str][key] = io.loadz(dfile)

    return data
def load_and_run(
        args: AttrDict,
        x: tf.Tensor = None,
        runs_dir: str = None,
) -> tuple:
    """Load trained model from checkpoint and run inference.

    Args:
        args: Flags/arguments; must contain `log_dir`.
        x: Optional batch of configurations; if None, the final training
            configuration is restored from disk.
        runs_dir: Optional directory in which to store inference runs.

    Returns:
        (dynamics, run_data, x); (None, None, None) on non-chief ranks.
    """
    if not IS_CHIEF:
        return None, None, None

    io.print_dict(args)
    ckpt_dir = os.path.join(args.log_dir, 'training', 'checkpoints')
    flags = restore_from_train_flags(args)
    eps_file = os.path.join(args.log_dir, 'training', 'train_data', 'eps.z')
    flags.eps = io.loadz(eps_file)[-1]
    dynamics = build_dynamics(flags)

    ckpt = tf.train.Checkpoint(dynamics=dynamics,
                               optimizer=dynamics.optimizer)
    manager = tf.train.CheckpointManager(ckpt, max_to_keep=5,
                                         directory=ckpt_dir)
    if manager.latest_checkpoint:
        io.log(f'Restored model from: {manager.latest_checkpoint}')
        status = ckpt.restore(manager.latest_checkpoint)
        status.assert_existing_objects_matched()
        if x is None:
            # BUG FIX: only fall back to the saved configuration when the
            # caller did not supply one — the original unconditionally
            # clobbered a caller-provided `x`.
            xfile = os.path.join(args.log_dir, 'training',
                                 'train_data', 'x_rank0.z')
            io.log(f'Restored x from: {xfile}.')
            x = io.loadz(xfile)

    dynamics, run_data, x = run(dynamics, args, x=x, runs_dir=runs_dir)
    return dynamics, run_data, x
def main(args, random_start=True):
    """Run inference on trained model from `log_dir/checkpoints/`.

    Args:
        args: Flags/arguments (AttrDict-like); may contain `log_dir`,
            `beta`, `eps`, `overwrite`, etc.
        random_start: If True, start inference from a random configuration;
            otherwise restore the final training configuration from disk.
    """
    if not IS_CHIEF:
        return

    io.print_flags(args)
    # Skip runs whose output already exists unless `overwrite` was passed.
    skip = not args.get('overwrite', False)

    # If no `log_dir` specified, run generic HMC
    log_dir = args.get('log_dir', None)
    if log_dir is None:
        io.log('`log_dir` not specified, running generic HMC...')
        _ = run_hmc(args=args, hmc_dir=None, skip_existing=skip)
        return

    # Otherwise, load training flags
    train_flags_file = os.path.join(log_dir, 'training', 'FLAGS.z')
    train_flags = io.loadz(train_flags_file)
    beta = args.get('beta', None)
    eps = args.get('eps', None)
    if beta is None:
        io.log('Using `beta_final` from training flags')
        beta = train_flags['beta_final']
    if eps is None:
        # Use the final (trained) step size saved at the end of training.
        eps_file = os.path.join(log_dir, 'training', 'train_data', 'eps.z')
        io.log(f'Loading `eps` from {eps_file}')
        eps_arr = io.loadz(eps_file)
        eps = tf.cast(eps_arr[-1], TF_FLOAT)

    # Update `args` with values from training flags
    args.update({
        'eps': eps,
        'beta': beta,
        'num_steps': int(train_flags['num_steps']),
        'lattice_shape': train_flags['lattice_shape'],
    })

    # Run generic HMC using the trained step-size loaded above.
    _ = run_hmc(args=args, hmc_dir=None, skip_existing=skip)

    # `x` will be randomly initialized if passed as `None`
    x = None
    if not random_start:
        # Load the last configuration from the end of training run
        x_file = os.path.join(args.log_dir, 'training',
                              'train_data', 'x_rank0.z')
        x = io.loadz(x_file) if os.path.isfile(x_file) else None

    # Run inference on trained model from `args.log_dir`
    args['hmc'] = False  # Ensure we're running L2HMC
    _ = load_and_run(args, x=x)

    return
def load_charge_data(dirs, hmc=False):
    """Load topological charge data from each directory in `dirs`.

    Args:
        dirs: Iterable of `Path`-like directories searched recursively for
            `dq.z`, `charges.z` and `run_params.z` files.
        hmc: Kept for interface compatibility (unused here).

    Returns:
        dict keyed by `beta`, each value a dict mapping trajectory length
        to a `ChargeData(q, dq, params)` entry.
    """
    data = {}
    for d in dirs:
        print(f'Looking in dir: {d}...')
        if 'inference_hmc' in str(d):
            print(f'Skipping {str(d)}...')
            continue

        dqfile = sorted(d.rglob('dq.z'))
        qfile = sorted(d.rglob('charges.z'))
        rpfile = sorted(d.rglob('run_params.z'))
        if len(dqfile) == 0:
            continue

        for dqf, qf, rpf in zip(dqfile, qfile, rpfile):
            params = io.loadz(rpf)
            # BUG FIX: the original condition was
            # `if 'xeps' and 'veps' in params.keys()`, which parses as
            # `'xeps' and ('veps' in ...)` and therefore only tested for
            # 'veps'.
            if 'xeps' in params.keys() and 'veps' in params.keys():
                xeps = np.array([i.numpy() for i in params['xeps']])
                veps = np.array([i.numpy() for i in params['veps']])
                eps = (np.mean(xeps) + np.mean(veps)) / 2.
            elif 'eps' in params.keys():
                eps = params['eps']

            # NOTE(review): if neither key exists, `eps` is undefined here
            # (NameError) — presumably run_params always contain one.
            params['eps'] = eps
            params = RunParams(**params)
            qarr = io.loadz(qf)
            dqarr = io.loadz(dqf)
            print('...loading data for (beta, num_steps, eps): '
                  f'({params.beta}, {params.num_steps}, {params.eps:.3g})')
            charge_data = ChargeData(q=qarr, dq=dqarr, params=params)
            try:
                data[params.beta].update({params.traj_len: charge_data})
            except KeyError:
                data[params.beta] = {params.traj_len: charge_data}

    return data
def _get_important_params(path):
    """Get important parameters by finding and loading from `run_params.z`.

    Args:
        path: Directory expected to contain `run_params.z`.

    Returns:
        dict with 'lf', 'eps', 'traj_len', 'beta' and the full
        'run_params' dict.

    Raises:
        FileNotFoundError: If `run_params.z` cannot be located.
        ValueError: If no step size can be determined.
    """
    params_file = os.path.join(path, 'run_params.z')
    try:
        params = io.loadz(params_file)
    except FileNotFoundError:
        logger.info(f'Unable to locate {params_file}.')
        raise  # re-raise with original traceback

    lf = params.get('num_steps', None)
    beta = params.get('beta', params.get('beta_final', None))
    xeps = params.get('xeps', None)
    traj_len = params.get('traj_len', None)
    if xeps is None:
        eps = params.get('eps', params.get('eps_avg', None))
        if eps is None:
            # (was an f-string with no placeholders)
            raise ValueError('Unable to determine `eps`.')
        eps = np.mean(eps)
        # Recompute trajectory length from the (mean) step size.
        traj_len = lf * eps
    else:
        # Per-leapfrog step sizes: average for `eps`, sum for the total
        # trajectory length.
        eps = np.mean(xeps)
        traj_len = np.sum(xeps)

    return {
        'lf': lf,
        'eps': eps,
        'traj_len': traj_len,
        'beta': beta,
        'run_params': params,
    }
def _find_configs(log_dir):
    """Locate and load run configs from `log_dir`.

    Tries a top-level `configs.z` first, then recursively searches for any
    `*configs.z*` file, then any `*FLAGS.z*` file.  Returns the loaded
    configs, or None if nothing was found.
    """
    configs_file = os.path.join(log_dir, 'configs.z')
    if os.path.isfile(configs_file):
        return io.loadz(configs_file)

    for pattern in ('*configs.z*', '*FLAGS.z*'):
        matches = [f for f in Path(log_dir).rglob(pattern) if f.is_file()]
        if matches:
            return io.loadz(matches[0])

    return None
def _load_inference_data(log_dir, fnames, inference_str='inference'):
    """Helper function for loading inference data from `log_dir`.

    Returns a `(key, data)` pair where `key` is the
    `(beta, eps, num_steps)` tuple from `run_params.z` and `data` is a
    list of loaded arrays (one per entry in `fnames`).  Returns None when
    the expected directories/files are missing.
    """
    run_dir = os.path.join(log_dir, inference_str)
    if not os.path.isdir(run_dir):
        return None

    data_dir = os.path.join(run_dir, 'run_data')
    rp_file = os.path.join(run_dir, 'run_params.z')
    if not (os.path.isfile(rp_file) and os.path.isdir(data_dir)):
        return None

    run_params = io.loadz(rp_file)
    key = (run_params['beta'],
           run_params['eps'],
           run_params['num_steps'])
    data = [io.loadz(os.path.join(data_dir, f'{fname}.z'))
            for fname in fnames]
    return key, data
def __init__(self, log_dir=None, n_boot=5000, therm_frac=0.25,
             nw_include=None, calc_stats=True, filter_str=None,
             runs_np=False):
    """Initialization method.

    Args:
        log_dir: Root log directory containing `parameters.pkl` and run dirs.
        n_boot: Number of bootstrap samples used for statistics.
        therm_frac: Fraction of each chain discarded as thermalization.
        nw_include: Optional collection of net-weight tuples to keep.
        calc_stats: Whether statistics should be computed.
        filter_str: Optional filter forwarded to `io.get_run_dirs`.
        runs_np: Flag forwarded to `io.get_run_dirs`.
    """
    self._log_dir = log_dir
    self._n_boot = n_boot
    self._therm_frac = therm_frac
    self._nw_include = nw_include
    self._calc_stats = calc_stats
    self.run_dirs = io.get_run_dirs(log_dir, filter_str, runs_np)
    self._params = io.loadz(os.path.join(self._log_dir, 'parameters.pkl'))
    # Training weights in (x, v) x (scale, translation, transformation)
    # order, as read from `parameters.pkl`.
    self._train_weights = (
        self._params['x_scale_weight'],
        self._params['x_translation_weight'],
        self._params['x_transformation_weight'],
        self._params['v_scale_weight'],
        self._params['v_translation_weight'],
        self._params['v_transformation_weight'],
    )
    # Human-readable title and filename-safe rendering of the weights.
    _tws_title = ', '.join((str(i) for i in self._train_weights))
    self._tws_title = f'({_tws_title})'
    self._tws_fname = ''.join((io.strf(i) for i in self._train_weights))
def restore_flags(flags, train_dir):
    """Update `FLAGS` using restored flags from `log_dir`."""
    rf_file = os.path.join(train_dir, 'FLAGS.z')
    loaded = AttrDict(dict(io.loadz(rf_file)))
    io.log(f'Restoring FLAGS from: {rf_file}...')
    flags.update(loaded)
    return flags
def load_configs_from_logdir(logdir: Union[str, Path]) -> dict[str, Any]:
    """Load training configs from `logdir`.

    Prefers the compressed `train_configs.z`; falls back to
    `train_configs.json` when the `.z` file is missing or unreadable.
    """
    configs_file = os.path.join(logdir, 'train_configs.z')
    try:
        configs = io.loadz(configs_file)
    except (FileNotFoundError, EOFError, KeyError, ValueError):
        # BUG FIX: a missing `.z` file raises FileNotFoundError, which the
        # original except tuple did not include — defeating the JSON
        # fallback entirely.
        configs_file = os.path.join(logdir, 'train_configs.json')
        with open(configs_file, 'r') as f:
            configs = json.load(f)

    return configs
def plot_data(train_data, out_dir, flags, thermalize=False, params=None):
    """Plot training data: line plots, traceplots, charge/energy plots.

    Args:
        train_data: DataContainer-like object with a `.data` dict of arrays.
        out_dir: Directory under which a `plots/` subdirectory is created.
        flags: Dict-like flags; `logging_steps` is read (default 1).
        thermalize: If True, drop the first 33% of every chain.
        params: Optional params dict used to build the figure title.
    """
    out_dir = os.path.join(out_dir, 'plots')
    io.check_else_make_dir(out_dir)
    title = None if params is None else get_title_str_from_params(params)

    logging_steps = flags.get('logging_steps', 1)
    flags_file = os.path.join(out_dir, 'FLAGS.z')
    if os.path.isfile(flags_file):
        # NOTE(review): this looks for FLAGS.z inside the *plots* dir
        # (`out_dir` was just re-pointed above) — confirm that is intended.
        train_flags = io.loadz(flags_file)
        logging_steps = train_flags.get('logging_steps', 1)

    data_dict = {}
    for key, val in train_data.data.items():
        if key == 'x':
            continue

        arr = np.array(val)
        # MC-step axis, scaled by how often data was actually logged.
        steps = logging_steps * np.arange(len(np.array(val)))

        if thermalize or key == 'dt':
            arr, steps = therm_arr(arr, therm_frac=0.33)

        labels = ('MC Step', key)
        data = (steps, arr)
        if len(arr.shape) == 1:
            # Single chain: line plot with running average.
            lplot_fname = os.path.join(out_dir, f'{key}.png')
            _, _ = mcmc_lineplot(data, labels, title,
                                 lplot_fname, show_avg=True)
        elif len(arr.shape) > 1:
            # Multiple chains: keep for the combined plots below, and
            # make a per-key traceplot.
            data_dict[key] = data
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            tplot_fname = os.path.join(out_dir, f'{key}_traceplot.png')
            _ = mcmc_traceplot(key, data_arr, title, tplot_fname)

        plt.close('all')

    _ = mcmc_avg_lineplots(data_dict, title, out_dir)
    _ = plot_charges(*data_dict['charges'], out_dir=out_dir, title=title)
    _ = plot_energy_distributions(data_dict, out_dir=out_dir, title=title)
    plt.close('all')
def train(flags: AttrDict, x: tf.Tensor = None, restore_x: bool = False):
    """Train model.

    Args:
        flags (AttrDict): Flags controlling the training run.
        x (tf.Tensor): Optional initial batch of configurations.
        restore_x (bool): If True, try to restore `x` from the train dir
            for this (rank, local_rank) before falling back to random init.

    Returns:
        x (tf.Tensor): Batch of configurations
        dynamics (GaugeDynamics): Dynamics object.
        train_data (DataContainer): Object containing train data.
        flags (AttrDict): AttrDict containing flags used.
    """
    dirs = io.setup_directories(flags)
    flags.update({'dirs': dirs})

    if restore_x:
        x = None
        try:
            xfile = os.path.join(dirs.train_dir, 'train_data',
                                 f'x_rank{RANK}-{LOCAL_RANK}.z')
            x = io.loadz(xfile)
        except FileNotFoundError:
            io.log(f'Unable to restore x from {xfile}. Using random init.')

    if x is None:
        # Random init, flattened to (batch, -1) as the dynamics expect.
        x = tf.random.normal(flags.dynamics_config['lattice_shape'])
        x = tf.reshape(x, (x.shape[0], -1))

    dynamics = build_dynamics(flags)
    dynamics.save_config(dirs.config_dir)

    io.log('\n'.join([120 * '*', 'Training L2HMC sampler...']))
    x, train_data = train_dynamics(dynamics, flags, dirs, x=x)

    if IS_CHIEF:
        # Only the chief rank saves outputs and makes plots.
        output_dir = os.path.join(dirs.train_dir, 'outputs')
        train_data.save_data(output_dir)
        params = {
            'beta_init': train_data.data.beta[0],
            'beta_final': train_data.data.beta[-1],
            'eps': dynamics.eps.numpy(),
            'lattice_shape': dynamics.config.lattice_shape,
            'num_steps': dynamics.config.num_steps,
            'net_weights': dynamics.net_weights,
        }
        plot_data(train_data, dirs.train_dir, flags,
                  thermalize=True, params=params)

    io.log('\n'.join(['Done training model', 120 * '*']))
    io.save_dict(dict(flags), dirs.log_dir, 'configs')

    return x, dynamics, train_data, flags
def restore_flags(flags, train_dir):
    """Update `FLAGS` using restored flags from `log_dir`."""
    rf_file = os.path.join(train_dir, 'FLAGS.z')
    if not os.path.isfile(rf_file):
        return flags

    try:
        restored = AttrDict(io.loadz(rf_file))
        logger.info(f'Restoring FLAGS from: {rf_file}...')
        flags.update(restored)
    except (FileNotFoundError, EOFError):
        # Best effort: fall back to the flags we were given.
        pass

    return flags
def load_from_dir(d, fnames=None):
    """Load data files from the run directories inside `d`.

    Args:
        d: Parent directory whose subdirectories are run dirs.
        fnames: Mapping of key -> file-name pattern to glob for; defaults
            to dq / charges / run_params `.z` files.

    Returns:
        dict mapping each key to loaded data.

    NOTE(review): `data` is rebound on every iteration, so only the *last*
    run directory's data is returned (and a NameError occurs if `d` has no
    subdirectories).  Also, each value in `files` is a *list* from
    `sorted(...glob...)` and the whole list is passed to `io.loadz` —
    verify `io.loadz` accepts a list, or whether `v[0]` was intended.
    """
    if fnames is None:
        fnames = {
            'dq': 'dq.z',
            'charges': 'charges.z',
            'run_params': 'run_params.z'
        }

    darr = [x for x in Path(d).iterdir() if x.is_dir()]
    for rd in darr:
        files = {k: sorted(rd.glob(f'*{v}*')) for k, v in fnames.items()}
        data = {k: io.loadz(v) for k, v in files.items()}

    return data
def load_data(data_dir):
    """Load all `.z` data files from `data_dir` into an AttrDict.

    Files whose key contains 'x_rank' (saved configurations) are skipped.

    Args:
        data_dir: Directory containing `<key>.z` files.

    Returns:
        AttrDict mapping each key to its loaded data.
    """
    contents = os.listdir(data_dir)
    fnames = [i for i in contents if i.endswith('.z')]
    # BUG FIX: `i.rstrip('.z')` strips *characters* from the set {'.','z'}
    # (e.g. 'xz.z' -> 'x'); slice off the literal '.z' suffix instead.
    keys = [i[:-len('.z')] for i in fnames]
    data_files = [os.path.join(data_dir, i) for i in fnames]
    data = {}
    for key, val in zip(keys, data_files):
        if 'x_rank' in key:
            continue

        io.log(f'Restored {key} from {val}.')
        data[key] = io.loadz(val)

    return AttrDict(data)
def get_observables(self, run_dir=None):
    """Get all observables from inference_data in `run_dir`.

    Returns:
        `(observables, run_params)` when the run is skipped (excluded
        net weights or low acceptance), otherwise the `observables` dict.
        NOTE(review): the two return shapes are inconsistent — confirm
        whether the final return should also be `(observables, run_params)`.
    """
    run_params = io.loadz(os.path.join(run_dir, 'run_params.pkl'))
    beta = run_params['beta']
    net_weights = tuple([int(i) for i in run_params['net_weights']])
    keep = True
    if self._nw_include is not None:
        keep = net_weights in self._nw_include

    # If none (< 10 %) of the proposed configs are rejected,
    # don't bother loading data and calculating statistics.
    px = self._load_sqz('px.pkl')
    avg_px = np.mean(px)
    if avg_px < 0.1 or not keep:
        io.log(f'Skipping! nw: {net_weights}, avg_px: {avg_px:.3g}')
        return None, run_params

    io.log(f'Loading data for net_weights: {net_weights}...')
    io.log(f' run_dir: {run_dir}')

    # load charges, plaqs data
    charges = self._load_sqz('charges.pkl')
    plaqs = self._load_sqz('plaqs.pkl')
    dplq = u1_plaq_exact(beta) - plaqs

    # thermalize configs
    px, _ = therm_arr(px, self._therm_frac)
    dplq, _ = therm_arr(dplq, self._therm_frac)
    # BUG FIX: `np.insert` returns a single ndarray; the original
    # `charges, _ = np.insert(...)` tuple-unpacking raised (or mangled
    # the data) instead of prepending the initial charge.
    charges = np.insert(charges, 0, 0, axis=0)
    charges, _ = therm_arr(charges)
    dq, _ = calc_tunneling_rate(charges)
    dq = dq.T
    dx = self._get_dx('dx.pkl')
    # BUG FIX: was `self.get_dx(...)`; the helper is `_get_dx`
    # (as used for 'dx' and 'dxb').
    dxf = self._get_dx('dxf.pkl')
    dxb = self._get_dx('dxb.pkl')
    observables = {
        'plaqs_diffs': dplq,
        'accept_prob': px,
        'tunneling_rate': dq,
    }
    _names = ['dx', 'dxf', 'dxb']
    _vals = [dx, dxf, dxb]
    for name, val in zip(_names, _vals):
        if val is not None:
            observables[name] = val

    return observables
def load_from_dir(d, fnames=None):
    """Load `.z` data matching each name in `fnames` from run dirs in `d`.

    Args:
        d: Parent directory whose subdirectories are searched recursively.
        fnames: File-name stems to look for; defaults to ['dq', 'charges'].

    Returns:
        dict mapping each fname to the last matching file's loaded data
        (an empty dict when no file matched).

    NOTE(review): later matches overwrite earlier ones, both across files
    and across subdirectories — confirm only the last match is wanted.
    """
    if fnames is None:
        fnames = ['dq', 'charges']

    darr = [x for x in Path(d).iterdir() if x.is_dir()]
    data = {}
    for p in darr:
        for fname in fnames:
            data[fname] = {}
            # BUG FIX: `Path.rglob` returns a generator, which has no
            # `len()` — the original raised TypeError; materialize first.
            files = list(p.rglob(f'*{fname}*'))
            if len(files) > 0:
                for f in files:
                    x = io.loadz(f)
                    data[fname] = x

    return data
def load_data(data_dir: Union[str, Path]):
    """Load all `.z` data files from `data_dir` into an AttrDict.

    Files whose key contains 'x_rank' (saved configurations) are skipped.

    Args:
        data_dir: Directory containing `<key>.z` files.

    Returns:
        AttrDict mapping each key to its loaded data.
    """
    # TODO: Update to use h5py for `.hdf5` files
    contents = os.listdir(data_dir)
    fnames = [i for i in contents if i.endswith('.z')]
    # BUG FIX: `i.rstrip('.z')` strips *characters* from the set {'.','z'}
    # (e.g. 'xz.z' -> 'x'); slice off the literal '.z' suffix instead.
    keys = [i[:-len('.z')] for i in fnames]
    data_files = [os.path.join(data_dir, i) for i in fnames]
    data = {}
    for key, val in zip(keys, data_files):
        if 'x_rank' in key:
            continue

        data[key] = io.loadz(val)
        if VERBOSE:
            logger.info(f'Restored {key} from {val}.')

    return AttrDict(data)
def restore(self, data_dir, rank=0, local_rank=0, step=None, x_shape=None):
    """Restore `self.data` from `data_dir`."""
    if step is not None:
        self.steps += step

    x_file = os.path.join(data_dir, f'x_rank{rank}-{local_rank}.z')
    try:
        x = io.loadz(x_file)
        io.log(f'Restored `x` from: {x_file}.', should_print=True)
    except FileNotFoundError:
        # Fall back to a fresh random-normal batch of configurations.
        io.log(f'Unable to load `x` from {x_file}.', level='WARNING')
        io.log('Using random normal init.', level='WARNING')
        x = tf.random.normal(x_shape)

    restored = self.load_data(data_dir)
    for key, val in restored.items():
        self.data[key] = np.array(val).tolist()

    return x
def test_resume_training(log_dir: str):
    """Test restoring a training session from a checkpoint.

    Args:
        log_dir: Log directory of a previous run containing
            `training/FLAGS.z`.

    Returns:
        AttrDict bundling the trained state and inference results.
    """
    flags = AttrDict(
        dict(io.loadz(os.path.join(log_dir, 'training', 'FLAGS.z'))))
    flags.log_dir = log_dir
    # NOTE(review): this doubles the restored step count
    # (`train_steps += train_steps`, since the key exists after loadz) —
    # confirm the intent was to extend training by the previous number of
    # steps rather than add a fixed increment.
    flags.train_steps += flags.get('train_steps', 10)

    x, dynamics, train_data, flags = train(flags)
    # NOTE(review): `beta` is read but never used below.
    beta = flags.get('beta', 1.)
    dynamics, run_data, x = run(dynamics, flags, x=x)

    return AttrDict({
        'x': x,
        'flags': flags,
        'log_dir': flags.log_dir,
        'dynamics': dynamics,
        'run_data': run_data,
        'train_data': train_data,
    })
def __init__(self, params, run_params, run_data, energy_data):
    """Initialization method.

    Args:
        params: Dict of model parameters; must contain `log_dir`.
        run_params: Parameters of the inference run.
        run_data: Inference run data.
        energy_data: Energy data, sorted via `self._sort_energy_data`.
    """
    self._params = params
    self._run_params = run_params
    self._run_data = run_data
    self._energy_data = self._sort_energy_data(energy_data)
    self._log_dir = params.get('log_dir', None)
    # NOTE(review): `self._params` assigned from the argument above is
    # immediately overwritten by the loaded `parameters.pkl` — confirm
    # the argument was meant to be discarded.
    self._params = io.loadz(os.path.join(self._log_dir, 'parameters.pkl'))
    # Training weights in (x, v) x (scale, translation, transformation)
    # order, as read from `parameters.pkl`.
    self._train_weights = (
        self._params['x_scale_weight'],
        self._params['x_translation_weight'],
        self._params['x_transformation_weight'],
        self._params['v_scale_weight'],
        self._params['v_translation_weight'],
        self._params['v_transformation_weight'],
    )
    _tws_title = ', '.join((str(i) for i in self._train_weights))
    self._tws_title = f'({_tws_title})'
    # NOTE(review): a sibling class uses `io.strf` (no underscore) here —
    # verify which helper actually exists in `io`.
    self._tws_fname = ''.join((io._strf(i) for i in self._train_weights))
def restore(self, data_dir, rank=0, local_rank=0, step=None, x_shape=None):
    """Restore `self.data` from `data_dir`."""
    if step is not None:
        self.steps += step

    x_file = os.path.join(data_dir, f'x_rank{rank}-{local_rank}.z')
    try:
        x = io.loadz(x_file)
        logger.info(f'Restored `x` from: {x_file}.')
    except FileNotFoundError:
        # Fall back to uniform angles in [-pi, pi).
        logger.warning(f'Unable to load `x` from {x_file}.')
        x = np.random.uniform(-np.pi, np.pi, size=x_shape)

    restored = self.load_data(data_dir)
    for key, val in restored.items():
        arr = np.array(val)
        shape = arr.shape
        self.data[key] = arr.tolist()
        logger.debug(f'Restored: train_data.data[{key}].shape={shape}')

    return x
def _deal_with_new_data(tint_file: str, save: bool = False):
    """Load autocorrelation data from `tint_file` and bundle it with the
    owning run's important parameters.

    Args:
        tint_file: Path to a `.z` file containing tint/narr arrays.
        save: If True, also write the bundled data to `tint_data.z`
            alongside the run.

    Returns:
        dict with tint/narr arrays plus run metadata.

    Raises:
        ValueError: If the file lacks the expected arrays, or propagated
            (with FileNotFoundError) from parameter loading.
    """
    loaded = io.loadz(tint_file)
    # Locate the run directory that owns this file.
    if 'tau_int_data' in tint_file:
        head = os.path.dirname(os.path.split(tint_file)[0])
    else:
        head, _ = os.path.split(str(tint_file))

    tint_arr = loaded.get('tint', loaded.get('tau_int', None))
    narr = loaded.get('narr', loaded.get('draws', None))
    if tint_arr is None or narr is None:
        raise ValueError(f'Unable to load from {tint_file}.')

    try:
        params = _get_important_params(head)
    except (ValueError, FileNotFoundError) as err:
        logger.info(f'Error loading params from {head}, skipping!')
        raise err

    data = {
        'tint': tint_arr,
        'narr': narr,
        'run_dir': head,
        'lf': params.get('lf', None),
        'eps': params.get('eps', None),
        'traj_len': params.get('traj_len', None),
        'beta': params.get('beta', None),
        'run_params': params.get('run_params', None),
    }
    if save:
        outfile = os.path.join(head, 'tint_data.z')
        logger.info(f'Saving tint data to: {outfile}.')
        io.savez(data, outfile, 'tint_data')

    return data
def calc_tau_int_from_dir(
        input_path: str,
        hmc: bool = False,
        px_cutoff: float = None,
        therm_frac: float = 0.2,
        num_pts: int = 50,
        nstart: int = 100,
        make_plot: bool = True,
        save_data: bool = True,
        keep_charges: bool = False,
):
    """Calculate integrated autocorrelation times for runs in `input_path`.

    NOTE: `load_charges_from_dir` returns `output_arr`:
      `output_arr` is a list of dicts, each of which look like:
        output = {
            'beta': beta,
            'lf': lf,
            'eps': eps,
            'traj_len': lf * eps,
            'qarr': charges,
            'run_params': params,
            'run_dir': run_dir,
        }

    Returns:
        dict keyed by run_dir with tint/narr arrays and run metadata, or
        None when no charge data could be loaded.
    """
    output_arr = load_charges_from_dir(input_path, hmc=hmc)
    if output_arr is None:
        logger.info(', '.join([
            'WARNING: Skipping entry!',
            f'\t unable to load charge data from {input_path}',
        ]))
        return None

    tint_data = {}
    for output in output_arr:
        run_dir = output['run_dir']
        beta = output['beta']
        lf = output['lf']
        run_params = output['run_params']
        if hmc:
            data_dir = os.path.join(input_path, 'run_data')
            plot_dir = os.path.join(input_path, 'plots', 'tau_int_plots')
        else:
            run_dir = output['run_params']['run_dir']
            data_dir = os.path.join(run_dir, 'run_data')
            plot_dir = os.path.join(run_dir, 'plots')

        outfile = os.path.join(data_dir, 'tau_int_data.z')
        outfile1 = os.path.join(data_dir, 'tint_data.z')
        fdraws = os.path.join(plot_dir, 'tau_int_vs_draws.pdf')
        ftlen = os.path.join(plot_dir, 'tau_int_vs_traj_len.pdf')
        c1 = os.path.isfile(outfile)
        c11 = os.path.isfile(outfile1)
        c2 = os.path.isfile(fdraws)
        c3 = os.path.isfile(ftlen)
        if c1 or c11 or c2 or c3:
            # NOTE(review): `outfile` is loaded even when only the plot
            # files (c2/c3) or `outfile1` exist — a missing `outfile`
            # raises FileNotFoundError here.  It is also loaded *twice*,
            # and the `n`/`tint` read here are recomputed below anyway.
            loaded = io.loadz(outfile)
            output.update(loaded)
            logger.info(', '.join([
                'WARNING: Loading existing data'
                f'\t Found existing data at: {outfile}.',
            ]))
            loaded = io.loadz(outfile)
            n = loaded.get('draws', loaded.get('narr', None))
            tint = loaded.get('tau_int', loaded.get('tint', None))
            output.update(loaded)

        # Determine the (scalar) step size for this run.
        xeps_check = 'xeps' in output['run_params'].keys()
        veps_check = 'veps' in output['run_params'].keys()
        if xeps_check and veps_check:
            xeps = tf.reduce_mean(output['run_params']['xeps'])
            veps = tf.reduce_mean(output['run_params']['veps'])
            eps = tf.reduce_mean([xeps, veps]).numpy()
        else:
            eps = output['eps']
            if isinstance(eps, list):
                eps = tf.reduce_mean(eps)
            elif tf.is_tensor(eps):
                try:
                    eps = eps.numpy()
                except AttributeError:
                    eps = tf.reduce_mean(eps)

        traj_len = lf * eps
        # Drop thermalization, then estimate tau_int across chains.
        qarr, _ = therm_arr(output['qarr'], therm_frac=therm_frac)
        n, tint = calc_autocorr(qarr.T, num_pts=num_pts, nstart=nstart)
        tint_data[run_dir] = {
            'run_dir': run_dir,
            'run_params': run_params,
            'lf': lf,
            'eps': eps,
            'traj_len': traj_len,
            'narr': n,
            'tint': tint,
            'qarr.shape': qarr.shape,
        }
        if save_data:
            # NOTE(review): saves the *accumulated* dict (all runs so
            # far) into each run's outfile — confirm intended.
            io.savez(tint_data, outfile, 'tint_data')

        if make_plot:
            io.check_else_make_dir(plot_dir)
            prefix = 'HMC' if hmc else 'L2HMC'
            xlabel = 'draws'
            ylabel = r'$\tau_{\mathrm{int}}$ (estimate)'
            title = (f'{prefix}, '
                     + r'$\beta=$' + f'{beta}, '
                     + r'$N_{\mathrm{lf}}=$' + f'{lf}, '
                     + r'$\varepsilon=$' + f'{eps:.2g}, '
                     + r'$\lambda=$' + f'{traj_len:.2g}')
            # tau_int vs number of draws, one curve per chain.
            _, ax = plt.subplots(constrained_layout=True)
            best = []
            for t in tint.T:
                _ = ax.plot(n, t, marker='.', color='k')
                best.append(t[-1])
            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(xlabel)
            _ = ax.set_title(title)
            _ = ax.set_xscale('log')
            _ = ax.set_yscale('log')
            _ = ax.grid(alpha=0.4)
            logger.info(f'Saving figure to: {fdraws}')
            _ = plt.savefig(fdraws, dpi=400, bbox_inches='tight')
            plt.close('all')
            # Final tau_int estimate vs trajectory length.
            _, ax = plt.subplots()
            for b in best:
                _ = ax.plot(traj_len, b, marker='.', color='k')
            _ = ax.set_ylabel(ylabel)
            _ = ax.set_xlabel(r'trajectory length, $\lambda$')
            _ = ax.set_title(title)
            _ = ax.set_yscale('log')
            _ = ax.grid(True, alpha=0.4)
            logger.info(f'Saving figure to: {ftlen}')
            _ = plt.savefig(ftlen, dpi=400, bbox_inches='tight')
            plt.close('all')

    return tint_data
def _load_sqz(self, fname):
    """Load `fname` from the observables dir, squeezed to minimal rank."""
    path = os.path.join(self._obs_dir, fname)
    return np.squeeze(np.array(io.loadz(path)))
def run_inference_from_log_dir(
        log_dir: str,
        run_steps: int = 50000,
        beta: float = None,
        eps: float = None,
        make_plots: bool = True,
        train_steps: int = 10,
        therm_frac: float = 0.33,
        batch_size: int = 16,
        num_chains: int = 16,
        x: tf.Tensor = None,
) -> InferenceResults:
    """Run inference by loading networks in from `log_dir`.

    Args:
        log_dir: Log directory of a previous training run.
        run_steps: Number of inference steps.
        beta: Optional coupling; overrides the saved `beta`/`beta_final`.
        eps: Optional step size; overrides the saved value.
        make_plots: Whether to produce plots.
        train_steps: If > 0, run a short (re-)training before inference.
        therm_frac: Fraction of each chain discarded as thermalization.
        batch_size: Batch size to use (replaces the saved x_shape batch dim).
        num_chains: Number of chains to use when plotting.
        x: Optional initial configurations; random angles if None.

    Returns:
        InferenceResults from `run(...)`.

    Raises:
        FileNotFoundError: If no configs can be located in `log_dir`.
    """
    configs = _find_configs(log_dir)
    if configs is None:
        raise FileNotFoundError(
            f'Unable to load configs from `log_dir`: {log_dir}. Exiting'
        )

    if eps is not None:
        configs['dynamics_config']['eps'] = eps
    else:
        try:
            eps_file = os.path.join(log_dir, 'training', 'models', 'eps.z')
            eps = io.loadz(eps_file)
        except FileNotFoundError:
            # BUG FIX: `configs.get('dynamics_config', None).get(...)`
            # raised AttributeError when 'dynamics_config' was absent.
            dcfg = configs.get('dynamics_config', None)
            eps = None if dcfg is None else dcfg.get('eps', None)

    if beta is not None:
        configs.update({'beta': beta, 'beta_final': beta})

    if batch_size is not None:
        # Keep the lattice dims, replace only the batch dimension.
        batch_size = int(batch_size)
        prev_shape = configs['dynamics_config']['x_shape']
        new_shape = (batch_size, *prev_shape[1:])
        configs['dynamics_config']['x_shape'] = new_shape

    configs = AttrDict(configs)
    dynamics = build_dynamics(configs)
    xnet, vnet = dynamics._load_networks(log_dir)
    dynamics.xnet = xnet
    dynamics.vnet = vnet
    if train_steps > 0:
        dynamics, train_data, x = short_training(train_steps,
                                                 configs.beta_final,
                                                 log_dir, dynamics, x=x)
    else:
        dynamics.compile(loss=dynamics.calc_losses,
                         optimizer=dynamics.optimizer,
                         experimental_run_tf_function=False)

    if x is None:
        x = convert_to_angle(tf.random.normal(dynamics.x_shape))

    configs['run_steps'] = run_steps
    configs['print_steps'] = max((run_steps // 100, 1))
    configs['md_steps'] = 100

    runs_dir = os.path.join(log_dir, 'LOADED', 'inference')
    io.check_else_make_dir(runs_dir)
    io.save_dict(configs, runs_dir, name='inference_configs')
    inference_results = run(dynamics=dynamics, configs=configs, x=x,
                            beta=beta, runs_dir=runs_dir,
                            make_plots=make_plots,
                            therm_frac=therm_frac,
                            num_chains=num_chains)

    return inference_results
def main(args):
    """Main method for training.

    Restores from `args.log_dir` when given, otherwise starts a fresh
    training session; optionally seeds the step size from a short HMC
    run, then trains and (optionally) runs inference plus an HMC baseline.
    """
    hmc_steps = args.get('hmc_steps', 0)
    tf.keras.backend.set_floatx('float32')
    log_file = os.path.join(os.getcwd(), 'log_dirs.txt')

    x = None
    log_dir = args.get('log_dir', None)
    beta_init = args.get('beta_init', None)
    beta_final = args.get('beta_final', None)
    if log_dir is not None:  # we want to restore from latest checkpoint
        train_steps = args.get('train_steps', None)
        args = restore_flags(args, os.path.join(args.log_dir, 'training'))
        args.train_steps = train_steps  # use newly passed value
        args.restore = True
        # Prefer any freshly-passed beta values over the restored ones.
        # (The original re-assigned `args.train_steps` a second time here;
        # the duplicate was removed.)
        if beta_init != args.get('beta_init', None):
            args.beta_init = beta_init
        if beta_final != args.get('beta_final', None):
            args.beta_final = beta_final
    else:  # New training session
        timestamps = AttrDict({
            'month': io.get_timestamp('%Y_%m'),
            # BUG FIX: was '%Y-%M-%d-%H%M%S' — '%M' is *minutes*; use the
            # month directive '%m' as in the 'second' entry below.
            'time': io.get_timestamp('%Y-%m-%d-%H%M%S'),
            'hour': io.get_timestamp('%Y-%m-%d-%H'),
            'minute': io.get_timestamp('%Y-%m-%d-%H%M'),
            'second': io.get_timestamp('%Y-%m-%d-%H%M%S'),
        })
        args.log_dir = io.make_log_dir(args, 'GaugeModel', log_file,
                                       timestamps=timestamps)
        io.write(f'{args.log_dir}', log_file, 'a')
        args.restore = False

    if hmc_steps > 0:
        # Short HMC run to find a reasonable step size for training.
        x, _, eps = train_hmc(args)
        args.dynamics_config['eps'] = eps

    dynamics_config = args.get('dynamics_config', None)
    if dynamics_config is not None:
        log_dir = dynamics_config.get('log_dir', None)
        if log_dir is not None:
            # Reuse the trained step size from a previous run if present.
            eps_file = os.path.join(log_dir, 'training', 'models', 'eps.z')
            if os.path.isfile(eps_file):
                io.log(f'Loading eps from: {eps_file}')
                eps = io.loadz(eps_file)
                args.dynamics_config['eps'] = eps

    _, dynamics, _, args = train(args, x=x)

    # ==== Run inference on trained model
    if args.get('run_steps', 5000) > 0:
        # ==== Run with random start
        dynamics, _, _ = run(dynamics, args)

        # ==== Run HMC baseline with a fixed step size
        args.hmc = True
        args.dynamics_config['eps'] = 0.15
        hmc_dir = os.path.join(args.log_dir, 'inference_hmc')
        _ = run_hmc(args=args, hmc_dir=hmc_dir)
def load_configs_from_log_dir(log_dir):
    """Load the saved `configs.z` from `log_dir`."""
    return io.loadz(os.path.join(log_dir, 'configs.z'))
def load_charges_from_dir(
        d: str,
        hmc: bool = False,
        px_cutoff: float = None
):
    """Load charge data from `d`.

    Args:
        d: Directory searched recursively for `charges.z`,
            `accept_prob.z` and `run_params.z` files.
        hmc: If False, directories containing 'inference_hmc' are skipped.
        px_cutoff: If given, bail out (return None) when the average
            late-time acceptance probability falls below this value.

    Returns:
        List of per-run dicts (beta, lf, eps, traj_len, qarr, run_params,
        run_dir), or None when nothing usable is found.

    Raises:
        ValueError: If a run's step size cannot be determined.
    """
    logger.info(f'Looking in {d}...')
    if not os.path.isdir(os.path.abspath(d)):
        logger.info(', '.join([
            'WARNING: Skipping entry!',
            f'{d} is not a directory.',
        ]))
        return None

    if 'old' in str(d):
        return None

    if 'inference_hmc' in str(d) and not hmc:
        return None

    qfs = [x for x in Path(d).rglob('charges.z') if x.is_file()]
    pxfs = [x for x in Path(d).rglob('accept_prob.z') if x.is_file()]
    rpfs = [x for x in Path(d).rglob('run_params.z') if x.is_file()]
    num_runs = len(qfs)
    if num_runs == 0:
        return None

    output_arr = []
    for qf, pxf, rpf in zip(qfs, pxfs, rpfs):
        params = io.loadz(rpf)
        beta = params['beta']
        lf = params['num_steps']
        run_dir, _ = os.path.split(rpf)
        if px_cutoff is not None:
            # Skip the whole directory on poor late-time acceptance.
            px = io.loadz(pxf)
            midpt = px.shape[0] // 2
            px_avg = np.mean(px[midpt:])
            if px_avg < px_cutoff:
                logger.info(', '.join([
                    f'{WSTR}: Bad acceptance prob.',
                    f'px_avg: {px_avg:.3g} < 0.1',
                    f'dir: {d}',
                ]))
                return None

        # BUG FIX: was `if 'xeps' and 'veps' in params.keys()`, which only
        # tested for 'veps' ('xeps' evaluated as a truthy literal).
        if 'xeps' in params.keys() and 'veps' in params.keys():
            xeps = np.mean(params['xeps'])
            veps = np.mean(params['veps'])
            eps = np.mean([xeps, veps])
        else:
            eps = params.get('eps', None)
            if eps is None:
                raise ValueError('Unable to determine eps.')

        logger.info('Loading data for: ' + ', '.join([
            f'beta: {str(beta)}',
            f'lf: {str(lf)}',
            f'eps: {str(eps)}',
            f'run_dir: {run_dir}',
        ]))
        charges = io.loadz(qf)
        charges = np.array(charges, dtype=int)
        output = {
            'beta': beta,
            'lf': lf,
            'eps': eps,
            'traj_len': lf * eps,
            'qarr': charges,
            'run_params': params,
            'run_dir': run_dir,
        }
        output_arr.append(output)

    return output_arr
def restore_from_train_flags(args):
    """Populate entries in `args` using the training `FLAGS` from `log_dir`."""
    flags_file = os.path.join(args.log_dir, 'training', 'FLAGS.z')
    return AttrDict(dict(io.loadz(flags_file)))