def train_hmc(flags):
    """Main method for training HMC model."""
    hflags = AttrDict(dict(flags).copy())
    lr_config = AttrDict(hflags.pop('lr_config', None))
    config = AttrDict(hflags.pop('dynamics_config', None))
    net_config = AttrDict(hflags.pop('network_config', None))
    hflags.train_steps = hflags.pop('hmc_steps', None)
    hflags.beta_init = hflags.beta_final
    # Force a plain HMC configuration (no networks, no auxiliary tricks)
    config.update({
        'hmc': True,
        'use_ncp': False,
        'aux_weight': 0.,
        'zero_init': False,
        'separate_networks': False,
        'use_conv_net': False,
        'directional_updates': False,
        'use_scattered_xnet_update': False,
        'use_tempered_traj': False,
        'gauge_eq_masks': False,
    })
    hflags.profiler = False
    hflags.make_summaries = True
    lr_config = LearningRateConfig(
        warmup_steps=0,
        decay_rate=0.9,
        decay_steps=hflags.train_steps // 10,
        lr_init=lr_config.get('lr_init', None),
    )
    train_dirs = io.setup_directories(hflags, 'training_hmc')
    dynamics = GaugeDynamics(hflags, config, net_config, lr_config)
    dynamics.save_config(train_dirs.config_dir)
    x, train_data = train_dynamics(dynamics, hflags, dirs=train_dirs)
    if IS_CHIEF:
        output_dir = os.path.join(train_dirs.train_dir, 'outputs')
        io.check_else_make_dir(output_dir)
        train_data.save_data(output_dir)
        params = {
            'eps': dynamics.eps,
            'num_steps': dynamics.config.num_steps,
            'beta_init': hflags.beta_init,
            'beta_final': hflags.beta_final,
            'lattice_shape': dynamics.config.lattice_shape,
            'net_weights': NET_WEIGHTS_HMC,
        }
        plot_data(train_data, train_dirs.train_dir, hflags,
                  thermalize=True, params=params)

    return x, train_data, dynamics.eps.numpy()

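# Illustrative sketch (not called anywhere in this module) of how `train_hmc`
# is typically driven. The flag names mirror the keys popped above; the
# concrete values are placeholders, and a real run supplies the full flag set
# produced elsewhere by the project's setup/argument-parsing code.
def _example_train_hmc():
    flags = AttrDict({
        'beta_final': 4.0,
        'hmc_steps': 500,                 # becomes hflags.train_steps above
        'lr_config': {'lr_init': 1e-3},
        'dynamics_config': {'num_steps': 10,
                            'lattice_shape': (128, 16, 16, 2)},
        'network_config': {},
    })
    x, train_data, eps = train_hmc(flags)
    return x, train_data, eps
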
def short_training(
        train_steps: int,
        beta: float,
        log_dir: str,
        dynamics: GaugeDynamics,
        x: tf.Tensor = None,
):
    """Perform a brief training run prior to running inference."""
    ckpt_dir = os.path.join(log_dir, 'training', 'checkpoints')
    ckpt = tf.train.Checkpoint(dynamics=dynamics,
                               optimizer=dynamics.optimizer)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=5)
    current_step = 0
    if manager.latest_checkpoint:
        io.log(f'Restored model from: {manager.latest_checkpoint}')
        ckpt.restore(manager.latest_checkpoint)
        current_step = dynamics.optimizer.iterations.numpy()

    if x is None:
        x = convert_to_angle(tf.random.normal(dynamics.x_shape))

    train_data = DataContainer(current_step + train_steps, print_steps=1)
    dynamics.compile(loss=dynamics.calc_losses,
                     optimizer=dynamics.optimizer,
                     experimental_run_tf_function=False)
    # Run a single step up front to build the graph and obtain a metrics
    # template for constructing the printed header.
    x, metrics = dynamics.train_step((x, tf.constant(beta)))
    header = train_data.get_header(metrics, skip=SKEYS,
                                   prepend=['{:^12s}'.format('step')])
    io.log(header.split('\n'))
    for step in range(current_step, current_step + train_steps):
        start = time.time()
        x, metrics = dynamics.train_step((x, tf.constant(beta)))
        metrics.dt = time.time() - start
        train_data.update(step, metrics)
        data_str = train_data.print_metrics(metrics)
        logger.info(data_str)

    return dynamics, train_data, x

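# Minimal, self-contained sketch of the checkpoint/restore pattern used in
# `short_training` above, with a plain Keras model standing in for
# `GaugeDynamics`. The names and paths below are illustrative only and are
# not part of this module.
def _example_checkpoint_restore(ckpt_dir: str = '/tmp/ckpts'):
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    optimizer = tf.keras.optimizers.Adam()
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=5)
    if manager.latest_checkpoint:
        # Resume: restore weights + optimizer state and continue counting
        # steps from `optimizer.iterations`, exactly as above.
        ckpt.restore(manager.latest_checkpoint)
        current_step = optimizer.iterations.numpy()
    else:
        current_step = 0

    return manager, current_step
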
def run_profiler(
        dynamics: GaugeDynamics,
        inputs: tuple[tf.Tensor, Union[float, tf.Tensor]],
        logdir: str,
        steps: int = 10,
):
    """Run `steps` training steps under the TensorFlow profiler."""
    logger.debug(f'Running {steps} profiling steps!')
    x, beta = inputs
    beta = tf.constant(beta)
    metrics = None
    # Warmup steps (not traced) so graph construction isn't profiled
    for _ in range(steps):
        x, metrics = dynamics.train_step((x, beta))

    tf.profiler.experimental.start(logdir=logdir, options=OPTIONS)
    for step in range(steps):
        with tf.profiler.experimental.Trace('train', step_num=step, _r=1):
            x, metrics = dynamics.train_step((x, beta))

    tf.profiler.experimental.stop(save=True)
    logger.debug('Done!')

    return x, metrics

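# Self-contained sketch of the profiling pattern used in `run_profiler`:
# run a few un-traced warmup steps first (so compilation overhead is
# excluded), then bracket the measured steps with start()/stop() and label
# each one with a `Trace` context. The tiny `tf.function` below is an
# illustrative stand-in for `dynamics.train_step`.
def _example_profile(logdir: str = '/tmp/tf_profile', steps: int = 10):
    @tf.function
    def step_fn(x):
        return tf.reduce_sum(tf.sin(x))

    x = tf.random.uniform((1024,))
    for _ in range(steps):          # warmup (not profiled)
        step_fn(x)

    tf.profiler.experimental.start(logdir=logdir)
    for step in range(steps):
        with tf.profiler.experimental.Trace('train', step_num=step, _r=1):
            step_fn(x)

    tf.profiler.experimental.stop(save=True)
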
def run_md(
        dynamics: GaugeDynamics,
        inputs: tuple,
        md_steps: int,
) -> tf.Tensor:
    """Run `md_steps` MD updates (no accept/reject) on `x`."""
    x, beta = inputs
    beta = tf.constant(beta)
    logger.debug(f'Running {md_steps} MD updates!')
    for _ in range(md_steps):
        mc_states, _ = dynamics.md_update((x, beta), training=True)
        x = mc_states.out.x  # type: tf.Tensor

    logger.debug('Done!')

    return x

def trace_train_step(
        dynamics: GaugeDynamics,
        writer: SummaryWriter,
        outdir: str,
        x: tf.Tensor = None,
        beta: float = None,
        graph: bool = True,
        profiler: bool = True,
):
    """Trace a single `dynamics.train_step` and export the graph/profile."""
    if x is None:
        x = tf.random.uniform(dynamics.x_shape, *(-PI, PI))
    if beta is None:
        beta = 1.

    # Bracket the function call with
    # `tf.summary.trace_on()` and `tf.summary.trace_export()`
    tf.summary.trace_on(graph=graph, profiler=profiler)
    # Call only one tf.function while tracing
    x, metrics = dynamics.train_step((x, beta))
    with writer.as_default():
        tf.summary.trace_export(name='dynamics_train_step', step=0,
                                profiler_outdir=outdir)

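# Self-contained sketch of the graph-tracing pattern used in
# `trace_train_step`: enable tracing, call exactly one `tf.function`, then
# export the trace to a summary writer so the graph appears in TensorBoard.
# The output path and traced function below are illustrative stand-ins.
def _example_trace_graph(outdir: str = '/tmp/tf_trace'):
    writer = tf.summary.create_file_writer(outdir)

    @tf.function
    def step_fn(x):
        return tf.reduce_mean(tf.square(x))

    tf.summary.trace_on(graph=True, profiler=True)
    step_fn(tf.random.uniform((8, 8)))      # call only one tf.function
    with writer.as_default():
        tf.summary.trace_export(name='step_fn_trace', step=0,
                                profiler_outdir=outdir)
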
def run_dynamics(
        dynamics: GaugeDynamics,
        flags: AttrDict,
        x: tf.Tensor = None,
        save_x: bool = False,
        md_steps: int = 0,
) -> tuple[DataContainer, tf.Tensor, list]:
    """Run inference on trained dynamics."""
    if not IS_CHIEF:
        return None, None, None

    # -- Setup --------------------------------------------------------------
    print_steps = flags.get('print_steps', 5)
    beta = flags.get('beta', flags.get('beta_final', None))

    test_step = dynamics.test_step
    if flags.get('compile', True):
        test_step = tf.function(dynamics.test_step)
        io.log('Compiled `dynamics.test_step` using tf.function!')

    if x is None:
        x = tf.random.uniform(shape=dynamics.x_shape,
                              minval=-PI, maxval=PI, dtype=TF_FLOAT)

    run_data = DataContainer(flags.run_steps)

    template = '\n'.join([f'beta: {beta}',
                          f'eps: {dynamics.eps.numpy():.4g}',
                          f'net_weights: {dynamics.net_weights}'])
    io.log(f'Running inference with:\n {template}')

    # Run `md_steps` MD updates (w/o accept/reject) to ensure chains
    # don't get stuck
    if md_steps > 0:
        for _ in range(md_steps):
            mc_states, _ = dynamics.md_update((x, beta), training=False)
            x = mc_states.out.x

    try:
        x, metrics = test_step((x, tf.constant(beta)))
    except Exception as exception:  # pylint:disable=broad-except
        io.log(f'Exception: {exception}')
        test_step = dynamics.test_step
        x, metrics = test_step((x, tf.constant(beta)))

    header = run_data.get_header(metrics, skip=['charges'],
                                 prepend=['{:^12s}'.format('step')])
    io.log(header.split('\n'), should_print=True)
    # ------------------------------------------------------------------------
    x_arr = []

    def timed_step(x: tf.Tensor, beta: tf.Tensor):
        start = time.time()
        x, metrics = test_step((x, tf.constant(beta)))
        metrics.dt = time.time() - start
        if save_x:
            x_arr.append(x.numpy())

        return x, metrics

    steps = tf.range(flags.run_steps, dtype=tf.int64)
    if NUM_NODES == 1:
        ctup = (CBARS['red'], CBARS['green'], CBARS['red'], CBARS['reset'])
        steps = tqdm(steps, desc='running', unit='step',
                     bar_format=("%s{l_bar}%s{bar}%s{r_bar}%s" % ctup))

    for step in steps:
        x, metrics = timed_step(x, beta)
        run_data.update(step, metrics)

        if step % print_steps == 0:
            summarize_dict(metrics, step, prefix='testing')
            data_str = run_data.get_fstr(step, metrics, skip=['charges'])
            io.log(data_str, should_print=True)

        if (step + 1) % 1000 == 0:
            io.log(header, should_print=True)

    return run_data, x, x_arr

def run_dynamics(
        dynamics: GaugeDynamics,
        flags: dict[str, Any],
        writer: tf.summary.SummaryWriter = None,
        x: tf.Tensor = None,
        beta: float = None,
        save_x: bool = False,
        md_steps: int = 0,
) -> InferenceResults:
    """Run inference on trained dynamics."""
    if not IS_CHIEF:
        return InferenceResults(None, None, None, None, None)

    # -- Setup --------------------------------------------------------------
    print_steps = flags.get('print_steps', 5)
    if beta is None:
        beta = flags.get('beta', flags.get('beta_final', None))  # type: float

    if beta is None:
        logger.warning('beta unspecified! Setting beta = 1.')
        beta = 1.

    assert beta is not None and isinstance(beta, float)

    test_step = dynamics.test_step
    if flags.get('compile', True):
        test_step = tf.function(dynamics.test_step)
        io.log('Compiled `dynamics.test_step` using tf.function!')

    if x is None:
        x = tf.random.uniform(dynamics.x_shape, minval=-PI,
                              maxval=PI, dtype=TF_FLOAT)

    assert tf.is_tensor(x)

    run_steps = flags.get('run_steps', 20000)
    run_data = DataContainer(run_steps)

    template = '\n'.join([f'beta={beta}',
                          f'net_weights={dynamics.net_weights}'])
    logger.info(f'Running inference with {template}')

    # Run `md_steps` MD updates (w/o accept/reject)
    # to ensure chains don't get stuck
    if md_steps > 0:
        for _ in range(md_steps):
            mc_states, _ = dynamics.md_update((x, beta), training=False)
            x = mc_states.out.x

    try:
        x, metrics = test_step((x, tf.constant(beta)))
    except Exception as err:  # pylint:disable=broad-except
        logger.warning(err)
        test_step = dynamics.test_step
        x, metrics = test_step((x, tf.constant(beta)))

    x_arr = []

    def timed_step(x: tf.Tensor, beta: tf.Tensor):
        start = time.time()
        x, metrics = test_step((x, tf.constant(beta)))
        metrics.dt = time.time() - start
        if 'sin_charges' not in metrics:
            charges = dynamics.lattice.calc_both_charges(x=x)
            metrics['charges'] = charges.intQ
            metrics['sin_charges'] = charges.sinQ
        if save_x:
            x_arr.append(x.numpy())

        return x, metrics

    summary_steps = max(run_steps // 100, 50)

    if writer is not None:
        writer.set_as_default()

    steps = tf.range(run_steps, dtype=tf.int64)
    keep_ = ['step', 'dt', 'loss', 'accept_prob', 'beta',
             'dq_int', 'dq_sin', 'dQint', 'dQsin', 'plaqs', 'p4x4']
    beta = tf.constant(beta, dtype=TF_FLOAT)  # type: tf.Tensor
    data_strs = []
    for step in steps:
        # Update data after every accept/reject step
        x, metrics = timed_step(x, beta)
        run_data.update(step, metrics)
        if step % summary_steps == 0:
            update_summaries(step, metrics, dynamics)

        if step % print_steps == 0:
            pre = [f'{step}/{steps[-1]}']
            ms = run_data.print_metrics(metrics, pre=pre, keep=keep_)
            data_strs.append(ms)

    return InferenceResults(dynamics=dynamics, x=x, x_arr=x_arr,
                            run_data=run_data, data_strs=data_strs)

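# Sketch of the compile-with-eager-fallback pattern used by both
# `run_dynamics` variants above: prefer the `tf.function`-compiled step for
# speed, but fall back to the eager version if the traced call raises.
# `step_fn` is an illustrative stand-in for `dynamics.test_step`.
def _example_compiled_step_with_fallback(step_fn, x, beta: float):
    compiled = tf.function(step_fn)
    try:
        x, metrics = compiled((x, tf.constant(beta)))
    except Exception as err:            # pylint:disable=broad-except
        logger.warning(err)
        compiled = step_fn              # drop back to eager execution
        x, metrics = compiled((x, tf.constant(beta)))

    return x, metrics, compiled
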
def train_dynamics(
        dynamics: GaugeDynamics,
        inputs: dict[str, Any],
        dirs: dict[str, str] = None,
        x: tf.Tensor = None,
        steps_dict: dict[str, int] = None,
        save_metrics: bool = True,
        custom_betas: Union[list, np.ndarray] = None,
        window: int = 0,
) -> tuple[tf.Tensor, DataContainer]:
    """Train model."""
    configs = inputs['configs']
    steps = configs.get('steps', [])
    min_lr = configs.get('min_lr', 1e-5)
    patience = configs.get('patience', 10)
    factor = configs.get('reduce_lr_factor', 0.5)
    save_steps = configs.get('save_steps', 10000)            # type: int
    print_steps = configs.get('print_steps', 1000)           # type: int
    logging_steps = configs.get('logging_steps', 500)        # type: int
    steps_per_epoch = configs.get('steps_per_epoch', 1000)   # type: int
    if steps_dict is not None:
        save_steps = steps_dict.get('save', 10000)                 # type: int
        print_steps = steps_dict.get('print', 1000)                # type: int
        logging_steps = steps_dict.get('logging_steps', 500)       # type: int
        steps_per_epoch = steps_dict.get('steps_per_epoch', 1000)  # type: int

    # -- Helper functions for training, logging, saving, etc. ---------------
    timer = StepTimer(evals_per_step=dynamics.config.num_steps)

    def train_step(x: tf.Tensor, beta: tf.Tensor):
        timer.start()
        x, metrics = dynamics.train_step((x, tf.constant(beta)))
        metrics.dt = timer.stop()
        return x, metrics

    def should_print(step: int) -> bool:
        return IS_CHIEF and step % print_steps == 0

    def should_log(step: int) -> bool:
        return IS_CHIEF and step % logging_steps == 0

    def should_save(step: int) -> bool:
        return step % save_steps == 0 and ckpt is not None

    xshape = dynamics._xshape
    xr = tf.random.uniform(xshape, -PI, PI)
    x = inputs.get('x', xr) if x is None else x
    assert x is not None

    if custom_betas is None:
        betas = np.array(inputs.get('betas', None))
        assert betas is not None and betas.shape[0] > 0
        steps = np.array(inputs.get('steps'))
        assert steps is not None and steps.shape[0] > 0
    else:
        betas = np.array(custom_betas)
        start = dynamics.optimizer.iterations
        nsteps = len(betas)
        steps = np.arange(start, start + nsteps)

    dirs = inputs.get('dirs', None) if dirs is None else dirs  # type: dict
    assert dirs is not None
    manager = inputs['manager']        # type: tf.train.CheckpointManager
    ckpt = inputs['checkpoint']        # type: tf.train.Checkpoint
    train_data = inputs['train_data']  # type: DataContainer

    # -- Setup dynamic learning rate schedule --------------------------------
    assert dynamics.lr_config is not None
    warmup_steps = dynamics.lr_config.warmup_steps
    reduce_lr = ReduceLROnPlateau(monitor='loss', mode='min',
                                  warmup_steps=warmup_steps,
                                  factor=factor, min_lr=min_lr,
                                  verbose=1, patience=patience)
    reduce_lr.set_model(dynamics)

    # -- Setup summary writer -------------------------------------------------
    writer = inputs.get('writer', None)  # type: tf.summary.SummaryWriter
    if IS_CHIEF and writer is not None:
        writer.set_as_default()

    # -- Run profiler? --------------------------------------------------------
    if configs.get('profiler', False):
        if RANK == 0:
            sdir = dirs['summary_dir']
            x, metrics = run_profiler(dynamics, (x, betas[0]),
                                      logdir=sdir, steps=5)
    else:
        x, metrics = dynamics.train_step((x, betas[0]))

    # -- Run MD updates (no accept/reject) so chains don't get stuck ---------
    md_steps = configs.get('md_steps', 0)
    if md_steps > 0:
        x = run_md(dynamics, (x, betas[0]), md_steps)

    warmup_steps = dynamics.lr_config.warmup_steps
    total_steps = steps[-1]
    if len(steps) != len(betas):
        betas = betas[steps[0]:]

    keep = [
        'dt', 'loss', 'accept_prob', 'beta',
        'Hwb_start', 'Hwf_start', 'Hwb_mid',
        'Hwf_mid', 'Hwb_end', 'Hwf_end',
        'xeps', 'veps', 'dq', 'dq_sin',
        'plaqs', 'p4x4', 'charges', 'sin_charges'
    ]
    plots = {}
    if in_notebook():
        plots = plotter.init_plots(configs, figsize=(9, 3), dpi=125)

    # -- Training loop --------------------------------------------------------
    data_strs = []
    logdir = dirs['log_dir']
    data_dir = dirs['data_dir']
    logfile = os.path.join(logdir, 'training', 'train_log.txt')
    assert x is not None
    assert manager is not None
    assert len(steps) == len(betas)
    for step, beta in zip(steps, betas):
        x, metrics = train_step(x, beta)

        # TODO: Run inference when beta hits an integer, e.g.
        #   beta_inf = {i: False for i in np.arange(beta_final)}
        #   if any(np.isclose(beta, np.array(list(beta_inf.keys())))):
        #       run_inference(...)

        if (step + 1) > warmup_steps and (step + 1) % steps_per_epoch == 0:
            reduce_lr.on_epoch_end(step + 1, {'loss': metrics.loss})

        # -- Save checkpoints and dump configs `x` from each rank ------------
        if should_save(step + 1):
            train_data.update(step, metrics)
            train_data.dump_configs(x, data_dir, rank=RANK,
                                    local_rank=LOCAL_RANK)
            if IS_CHIEF:
                _ = timer.save_and_write(logdir, mode='w')
                # -- Save CheckpointManager -----------------------------------
                manager.save()
                logger.info(f'Checkpoint saved to: {manager.latest_checkpoint}')
                with open(logfile, 'w') as f:
                    f.writelines('\n'.join(data_strs))

                # -- Save train_data and free consumed memory -----------------
                train_data.save_and_flush(data_dir, logfile,
                                          rank=RANK, mode='a')
                if not dynamics.config.hmc:
                    # -- Save network weights ---------------------------------
                    dynamics.save_networks(logdir)
                    logger.info(f'Networks saved to: {logdir}')

        # -- Print current training state and metrics ------------------------
        if should_print(step):
            train_data.update(step, metrics)
            keep_ = [
                'step', 'dt', 'loss', 'accept_prob', 'beta',
                'dq_int', 'dq_sin', 'dQint', 'dQsin', 'plaqs', 'p4x4'
            ]
            pre = [f'{step:>4g}/{total_steps:<4g}']
            data_str = train_data.print_metrics(metrics, window=window,
                                                pre=pre, keep=keep_)
            data_strs.append(data_str)

        if in_notebook() and step % PLOT_STEPS == 0 and IS_CHIEF:
            train_data.update(step, metrics)
            if len(train_data.data.keys()) == 0:
                update_plots(metrics, plots,
                             logging_steps=configs['logging_steps'])
            else:
                update_plots(train_data.data, plots,
                             logging_steps=configs['logging_steps'])

        # -- Update summary objects -------------------------------------------
        if should_log(step):
            train_data.update(step, metrics)
            if writer is not None:
                update_summaries(step, metrics, dynamics)
                writer.flush()

    # -- Dump config objects ----------------------------------------------------
    train_data.dump_configs(x, data_dir, rank=RANK, local_rank=LOCAL_RANK)
    if IS_CHIEF:
        manager.save()
        logger.info(f'Checkpoint saved to: {manager.latest_checkpoint}')
        with open(logfile, 'w') as f:
            f.writelines('\n'.join(data_strs))

        if save_metrics:
            train_data.save_and_flush(data_dir, logfile,
                                      rank=RANK, mode='a')

        if not dynamics.config.hmc:
            dynamics.save_networks(logdir)

        if writer is not None:
            writer.flush()
            writer.close()

    return x, train_data

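# Illustrative sketch (not used anywhere in this module) of the `inputs`
# dict that `train_dynamics` expects, based only on the keys it reads above.
# The concrete values are placeholders; in practice this dict is assembled
# by the project's setup code before training starts.
def _example_train_dynamics_inputs(dynamics: GaugeDynamics,
                                   dirs: dict, writer=None):
    ckpt = tf.train.Checkpoint(dynamics=dynamics,
                               optimizer=dynamics.optimizer)
    manager = tf.train.CheckpointManager(
        ckpt, os.path.join(dirs['log_dir'], 'training', 'checkpoints'),
        max_to_keep=5)
    train_steps = 1000
    inputs = {
        'configs': {'save_steps': 500, 'print_steps': 100,
                    'logging_steps': 50, 'md_steps': 10, 'profiler': False},
        'betas': np.linspace(1.0, 4.0, train_steps),  # one beta per train step
        'steps': np.arange(train_steps),
        'dirs': dirs,   # expects at least 'log_dir', 'data_dir', 'summary_dir'
        'manager': manager,
        'checkpoint': ckpt,
        'train_data': DataContainer(train_steps),
        'writer': writer,
    }
    return inputs
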