def save_config(self, config_dir: str):
    """Save every configuration dict owned by this object into `config_dir`."""
    # Map each output name to the attribute it serializes; insertion order
    # preserves the original save order.
    configs = {
        'dynamics_config': self.config,
        'network_config': self.net_config,
        'lr_config': self.lr_config,
        'dynamics_params': self.params,
    }
    for name, cfg in configs.items():
        io.save_dict(cfg, config_dir, name=name)

    # The conv config is only meaningful when a conv net is actually in use.
    if self.conv_config is not None and self.config.use_conv_net:
        io.save_dict(self.conv_config, config_dir, name='conv_config')
def setup_directories(configs: dict) -> dict:
    """Setup directories for training.

    Creates the run directories via `io.setup_directories`, records them in
    `configs['dirs']` (mirrored under both 'log_dir' and 'logdir'), and
    resolves a checkpoint directory to restore from, if any.

    Args:
        configs: Training configuration dict. Keys consulted: 'ensure_new',
            'logdir'/'log_dir', 'restore_from', 'restore_dir'.

    Returns:
        The same `configs` dict, updated in place.

    Raises:
        ValueError: If `ensure_new` is truthy but the requested `logdir`
            already exists and is non-empty.
    """
    def _save_restored(restore_dir, logdir):
        # Best-effort: persist the configs found in `restore_dir` alongside
        # the new run so the provenance of restored weights is recorded.
        try:
            restored = load_configs_from_logdir(restore_dir)
            if restored is not None:
                io.save_dict(restored, logdir, name='restored_train_configs')
        except FileNotFoundError:
            logger.warning(f'Unable to load configs from {restore_dir}')

    ensure_new = configs.get('ensure_new', False)
    logdir = configs.get('logdir', configs.get('log_dir', None))
    if logdir is not None and ensure_new:
        # BUG FIX: only call `os.listdir` when the directory actually exists;
        # the original called it unconditionally and raised FileNotFoundError
        # for a fresh (not-yet-created) logdir.
        if os.path.isdir(logdir) and len(os.listdir(logdir)) > 0:
            raise ValueError(
                f'Nonempty `logdir`, but `ensure_new={ensure_new}`')

    # Create `logdir`, `logdir/training/...`, etc.
    dirs = io.setup_directories(configs, timestamps=TSTAMPS)
    configs['dirs'] = dirs
    logdir = dirs.get('logdir', dirs.get('log_dir', None))  # type: str
    configs['log_dir'] = logdir
    configs['logdir'] = logdir

    restore_dir = configs.get('restore_from', None)
    if restore_dir is None and not ensure_new:
        # No explicit restore target: fall back to the most recent previous
        # logdir, but only if it actually contains checkpoints.
        candidate = look_for_previous_logdir(logdir)
        if candidate != logdir and candidate.is_dir():
            nckpts = len(list(candidate.rglob('checkpoint')))
            if nckpts > 0:
                restore_dir = candidate

    configs['restore_from'] = restore_dir
    if restore_dir is not None and not ensure_new:
        _save_restored(restore_dir, logdir)

    if RANK == 0:
        io.check_else_make_dir(logdir)

    # NOTE(review): 'restore_dir' is consulted separately from 'restore_from';
    # presumably a legacy key -- confirm whether both are still needed.
    restore_dir = configs.get('restore_dir', None)
    if restore_dir is not None and not ensure_new:
        _save_restored(restore_dir, logdir)

    return configs
def train(flags: AttrDict, x: tf.Tensor = None, restore_x: bool = False):
    """Train model.

    Returns:
        x (tf.Tensor): Batch of configurations
        dynamics (GaugeDynamics): Dynamics object.
        train_data (DataContainer): Object containing train data.
        flags (AttrDict): AttrDict containing flags used.
    """
    dirs = io.setup_directories(flags)
    flags.update({'dirs': dirs})

    if restore_x:
        # Try to resume from this rank's previously-saved configurations.
        x = None
        xfile = os.path.join(dirs.train_dir, 'train_data',
                             f'x_rank{RANK}-{LOCAL_RANK}.z')
        try:
            x = io.loadz(xfile)
        except FileNotFoundError:
            io.log(f'Unable to restore x from {xfile}. Using random init.')

    if x is None:
        # Fresh random initialization, flattened to (batch, -1).
        x = tf.random.normal(flags.dynamics_config['lattice_shape'])
        x = tf.reshape(x, (x.shape[0], -1))

    dynamics = build_dynamics(flags)
    dynamics.save_config(dirs.config_dir)

    io.log('\n'.join([120 * '*', 'Training L2HMC sampler...']))
    x, train_data = train_dynamics(dynamics, flags, dirs, x=x)

    if IS_CHIEF:
        # Only the chief rank saves outputs and produces plots.
        output_dir = os.path.join(dirs.train_dir, 'outputs')
        train_data.save_data(output_dir)
        params = {
            'beta_init': train_data.data.beta[0],
            'beta_final': train_data.data.beta[-1],
            'eps': dynamics.eps.numpy(),
            'lattice_shape': dynamics.config.lattice_shape,
            'num_steps': dynamics.config.num_steps,
            'net_weights': dynamics.net_weights,
        }
        plot_data(train_data, dirs.train_dir, flags,
                  thermalize=True, params=params)

    io.log('\n'.join(['Done training model', 120 * '*']))
    io.save_dict(dict(flags), dirs.log_dir, 'configs')

    return x, dynamics, train_data, flags
def run(
        dynamics: GaugeDynamics,
        configs: dict[str, Any],
        x: tf.Tensor = None,
        beta: float = None,
        runs_dir: str = None,
        make_plots: bool = True,
        therm_frac: float = 0.33,
        num_chains: int = 16,
        save_x: bool = False,
        md_steps: int = 50,
        skip_existing: bool = False,
        save_dataset: bool = True,
        use_hdf5: bool = False,
        skip_keys: list[str] = None,
        run_steps: int = None,
) -> InferenceResults:
    """Run inference. (Note: Higher-level than `run_dynamics`).

    Sets up the run directory and summary writer, runs the dynamics, saves
    run parameters/data, and (optionally) produces plots and integrated
    autocorrelation-time data.

    Returns:
        InferenceResults; all-`None` fields on non-chief ranks.
    """
    if not IS_CHIEF:
        return InferenceResults(None, None, None, None, None)

    if num_chains > 16:
        logger.warning(f'Reducing `num_chains` from: {num_chains} to {16}.')
        num_chains = 16

    if run_steps is None:
        run_steps = configs.get('run_steps', 50000)

    if beta is None:
        beta = configs.get('beta_final', configs.get('beta', None))
    assert beta is not None

    logdir = configs.get('log_dir', configs.get('logdir', None))
    if runs_dir is None:
        rs = 'inference_hmc' if dynamics.config.hmc else 'inference'
        runs_dir = os.path.join(logdir, rs)

    io.check_else_make_dir(runs_dir)
    run_dir = io.make_run_dir(configs=configs, base_dir=runs_dir,
                              beta=beta, skip_existing=skip_existing)
    logger.info(f'run_dir: {run_dir}')
    data_dir = os.path.join(run_dir, 'run_data')
    summary_dir = os.path.join(run_dir, 'summaries')
    io.check_else_make_dir([run_dir, data_dir, summary_dir])

    writer = tf.summary.create_file_writer(summary_dir)
    writer.set_as_default()
    configs['logging_steps'] = 1

    if x is None:
        x = convert_to_angle(tf.random.uniform(shape=dynamics.x_shape,
                                               minval=-PI, maxval=PI))

    # == RUN DYNAMICS =======================================================
    nw = dynamics.net_weights
    inf_type = 'HMC' if dynamics.config.hmc else 'inference'
    logger.info(', '.join([f'Running {inf_type}',
                           f'beta={beta}', f'nw={nw}']))
    t0 = time.time()
    results = run_dynamics(dynamics, flags=configs, beta=beta, x=x,
                           writer=writer, save_x=save_x, md_steps=md_steps)
    logger.info(f'Done running {inf_type}. took: {time.time() - t0:.4f} s')
    # =======================================================================

    run_data = results.run_data
    run_data.update_dirs({'log_dir': logdir, 'run_dir': run_dir})
    run_params = {
        'hmc': dynamics.config.hmc,
        'run_dir': run_dir,
        'beta': beta,
        'run_steps': run_steps,
        'plaq_weight': dynamics.plaq_weight,
        'charge_weight': dynamics.charge_weight,
        'x_shape': dynamics.x_shape,
        'num_steps': dynamics.config.num_steps,
        'net_weights': dynamics.net_weights,
        'input_shape': dynamics.x_shape,
    }

    inf_log_fpath = os.path.join(run_dir, 'inference_log.txt')
    with open(inf_log_fpath, 'a') as f:
        f.writelines('\n'.join(results.data_strs))

    # BUG FIX: the original computed `traj_len` from `dynamics.xeps` *before*
    # checking `hasattr(dynamics, 'xeps')`, which raised AttributeError for
    # eps-only (HMC) dynamics. Compute it inside the appropriate branch.
    traj_len = None
    if hasattr(dynamics, 'xeps') and hasattr(dynamics, 'veps'):
        xeps_avg = tf.reduce_mean(dynamics.xeps)
        veps_avg = tf.reduce_mean(dynamics.veps)
        traj_len = tf.reduce_sum(dynamics.xeps)
        run_params.update({
            'xeps': dynamics.xeps,
            'traj_len': traj_len,
            'veps': dynamics.veps,
            'xeps_avg': xeps_avg,
            'veps_avg': veps_avg,
            'eps_avg': (xeps_avg + veps_avg) / 2.,
        })
    elif hasattr(dynamics, 'eps'):
        # Single shared step size: trajectory length = num_steps * eps.
        traj_len = dynamics.config.num_steps * dynamics.eps
        run_params.update({
            'eps': dynamics.eps,
        })

    io.save_params(run_params, run_dir, name='run_params')

    save_times = {}
    plot_times = {}
    if make_plots:
        output = plot_data(data_container=run_data, configs=configs,
                           params=run_params, out_dir=run_dir,
                           hmc=dynamics.config.hmc, therm_frac=therm_frac,
                           num_chains=num_chains, profile=True,
                           cmap='crest', logging_steps=1)
        save_times = io.SortedDict(**output['save_times'])
        plot_times = io.SortedDict(**output['plot_times'])
        dt1 = io.SortedDict(**plot_times['data_container.plot_dataset'])
        plot_times['data_container.plot_dataset'] = dt1
        tint_data = {
            'beta': beta,
            'run_dir': run_dir,
            'traj_len': traj_len,
            'run_params': run_params,
            # BUG FIX: 'eps_avg' only exists in the xeps/veps branch above;
            # fall back to the plain 'eps' value for HMC-style dynamics.
            'eps': run_params.get('eps_avg', run_params.get('eps', None)),
            'lf': run_params['num_steps'],
            'narr': output['tint_dict']['narr'],
            'tint': output['tint_dict']['tint'],
        }
        t0 = time.time()
        tint_file = os.path.join(run_dir, 'tint_data.z')
        io.savez(tint_data, tint_file, 'tint_data')
        save_times['savez_tint_data'] = time.time() - t0

    t0 = time.time()
    logfile = os.path.join(run_dir, 'inference.log')
    run_data.save_and_flush(data_dir=data_dir, log_file=logfile,
                            use_hdf5=use_hdf5, skip_keys=skip_keys,
                            save_dataset=save_dataset)
    save_times['run_data.flush_data_strs'] = time.time() - t0

    t0 = time.time()
    try:
        run_data.write_to_csv(logdir, run_dir, hmc=dynamics.config.hmc)
    except TypeError:
        logger.warning('Unable to write to csv. Continuing...')
    save_times['run_data.write_to_csv'] = time.time() - t0

    profdir = os.path.join(run_dir, 'profile_info')
    io.check_else_make_dir(profdir)
    io.save_dict(plot_times, profdir, name='plot_times')
    io.save_dict(save_times, profdir, name='save_times')

    return InferenceResults(dynamics=results.dynamics, x=results.x,
                            x_arr=results.x_arr, run_data=results.run_data,
                            data_strs=results.data_strs)
def run_inference_from_log_dir(
        log_dir: str,
        run_steps: int = 50000,
        beta: float = None,
        eps: float = None,
        make_plots: bool = True,
        train_steps: int = 10,
        therm_frac: float = 0.33,
        batch_size: int = 16,
        num_chains: int = 16,
        x: tf.Tensor = None,
) -> InferenceResults:
    """Run inference by loading networks in from `log_dir`.

    Raises:
        FileNotFoundError: If no configs can be found in `log_dir`.
    """
    configs = _find_configs(log_dir)
    if configs is None:
        raise FileNotFoundError(
            f'Unable to load configs from `log_dir`: {log_dir}. Exiting'
        )

    if eps is not None:
        configs['dynamics_config']['eps'] = eps
    else:
        try:
            eps_file = os.path.join(log_dir, 'training', 'models', 'eps.z')
            eps = io.loadz(eps_file)
        except FileNotFoundError:
            # BUG FIX: the original chained `.get(...)` on a possibly-None
            # value, raising AttributeError when 'dynamics_config' is absent.
            dcfg = configs.get('dynamics_config', None)
            eps = dcfg.get('eps', None) if dcfg is not None else None
        # NOTE(review): `eps` recovered here is never written back into
        # `configs`; presumably `build_dynamics` re-reads it elsewhere --
        # confirm this is intentional.

    if beta is not None:
        configs.update({'beta': beta, 'beta_final': beta})

    if batch_size is not None:
        # Override the batch dimension of the configuration shape.
        batch_size = int(batch_size)
        prev_shape = configs['dynamics_config']['x_shape']
        new_shape = (batch_size, *prev_shape[1:])
        configs['dynamics_config']['x_shape'] = new_shape

    configs = AttrDict(configs)
    dynamics = build_dynamics(configs)
    xnet, vnet = dynamics._load_networks(log_dir)
    dynamics.xnet = xnet
    dynamics.vnet = vnet

    if train_steps > 0:
        # Brief re-training pass to settle the loaded networks.
        dynamics, train_data, x = short_training(train_steps,
                                                 configs.beta_final,
                                                 log_dir, dynamics, x=x)
    else:
        dynamics.compile(loss=dynamics.calc_losses,
                         optimizer=dynamics.optimizer,
                         experimental_run_tf_function=False)

    if x is None:
        x = convert_to_angle(tf.random.normal(dynamics.x_shape))

    configs['run_steps'] = run_steps
    configs['print_steps'] = max((run_steps // 100, 1))
    configs['md_steps'] = 100

    runs_dir = os.path.join(log_dir, 'LOADED', 'inference')
    io.check_else_make_dir(runs_dir)
    io.save_dict(configs, runs_dir, name='inference_configs')
    inference_results = run(dynamics=dynamics, configs=configs, x=x,
                            beta=beta, runs_dir=runs_dir,
                            make_plots=make_plots,
                            therm_frac=therm_frac, num_chains=num_chains)

    return inference_results
def train(
        configs: dict[str, Any],
        x: tf.Tensor = None,
        num_chains: int = 32,
        make_plots: bool = True,
        steps_dict: dict[str, int] = None,
        skip_keys: list[str] = None,
        save_metrics: bool = True,
        save_dataset: bool = False,
        use_hdf5: bool = False,
        custom_betas: Union[list, np.ndarray] = None,
        **kwargs,
) -> TrainOutputs:
    """Train model.

    Returns:
        train_outputs: Dataclass with attributes:
          - x: tf.Tensor
          - logdir: str
          - configs: dict[str, Any]
          - data: DataContainer
          - dynamics: GaugeDynamics
    """
    start = time.time()
    configs = setup_directories(configs)
    try_restore = kwargs.pop('try_restore', True)
    config = setup(configs, x=x, try_restore=try_restore)
    dynamics = config['dynamics']
    dirs = config['dirs']
    configs = config['configs']
    train_data = config['train_data']
    dynamics.save_config(dirs['config_dir'])

    if RANK == 0:
        # Record this logdir globally and persist the (possibly restored)
        # configs alongside it.
        logfile = os.path.join(os.getcwd(), 'log_dirs.txt')
        logdir = configs.get('logdir', configs.get('log_dir', None))  # str
        io.save_dict(configs, logdir, name='train_configs')
        io.write(f'{logdir}', logfile, 'a')
        restore_from = configs.get('restore_from', None)
        if restore_from is not None:
            restored = load_configs_from_logdir(restore_from)
            if restored is not None:
                io.save_dict(restored, logdir,
                             name='restored_train_configs')

    # -- Train dynamics ----------------------------------------------------
    logger.rule('TRAINING')
    logger.info(f'Starting training at: {io.get_timestamp("%x %X")}')
    t0 = time.time()
    x, train_data = train_dynamics(dynamics, config, dirs, x=x,
                                   steps_dict=steps_dict,
                                   custom_betas=custom_betas,
                                   save_metrics=save_metrics)
    # Single message (the original logged the elapsed time twice).
    logger.info(f'DONE TRAINING. TOOK: {time.time() - t0:.4f}')
    # ----------------------------------------------------------------------

    if IS_CHIEF:
        logdir = dirs['logdir']
        train_dir = dirs['train_dir']
        io.save_dict(dict(configs), dirs['log_dir'], 'configs')
        train_data.save_and_flush(
            dirs['data_dir'],
            log_file=dirs['log_file'],
            use_hdf5=use_hdf5,
            save_dataset=save_dataset,
            skip_keys=skip_keys)
        logger.info(f'Done training model! took: {time.time() - start:.4f}s')

        b0 = configs.get('beta_init', None)
        b1 = configs.get('beta_final', None)
        if make_plots:
            params = {
                'beta_init': train_data.data.get('beta', [b0])[0],
                # BUG FIX: final beta is the *last* recorded value; the
                # original used index [0] (the initial beta). This matches
                # the other `train` variant, which uses `beta[-1]`.
                'beta_final': train_data.data.get('beta', [b1])[-1],
                'x_shape': dynamics.config.x_shape,
                'num_steps': dynamics.config.num_steps,
                'net_weights': dynamics.net_weights,
            }
            t0 = time.time()
            logging_steps = configs.get('logging_steps', None)  # type: int
            _ = plot_data(data_container=train_data, configs=configs,
                          params=params, out_dir=train_dir, therm_frac=0,
                          cmap='flare', num_chains=num_chains,
                          logging_steps=logging_steps, **kwargs)
            dt = time.time() - t0
            logger.debug(
                f'Time spent plotting: {dt}s = {dt // 60}m {(dt % 60):.4f}s')

        if not dynamics.config.hmc:
            dynamics.save_networks(logdir)

    return TrainOutputs(x, dirs['log_dir'], configs, train_data, dynamics)