def update_plot(
    y: Metric,
    fig: plt.Figure,
    ax: plt.Axes,
    line: list[plt.Line2D],
    display_id: DisplayHandle,
    window: int = 15,
    logging_steps: int = 1,
):
    """Refresh a single live line plot with the smoothed history in `y`.

    Outside of a notebook (as reported by `in_notebook()`) this is a no-op.

    Args:
        y: Metric history; converted with `np.array`. A 2-D array is averaged
            over its last axis before smoothing.
        fig: Figure that is redrawn and pushed to `display_id`.
        ax: Axes whose data limits are recomputed after the update.
        line: Sequence of line artists; only `line[0]` is updated.
        display_id: Display handle whose `update` method receives `fig`.
        window: Window passed through to `moving_average`.
        logging_steps: Spacing between consecutive points on the x-axis.
    """
    if in_notebook():
        values = np.array(y)
        if values.ndim == 2:
            # Collapse the trailing axis so a single curve is drawn.
            values = values.mean(-1)

        smoothed = moving_average(values.squeeze(), window=window)
        curve = line[0]
        curve.set_ydata(smoothed)
        curve.set_xdata(logging_steps * np.arange(smoothed.shape[0]))

        # Recompute limits from the new data, then redraw and push the
        # figure to the existing display handle.
        ax.relim()
        ax.autoscale_view()
        fig.canvas.draw()
        display_id.update(fig)
def init_plots(configs: dict = None, ylabels: list[str] = None, **kwargs):
    """Create the live plots used while training inside a notebook.

    Args:
        configs: Forwarded to `init_live_plot` for the two `dq_*` plots.
        ylabels: Y-axis labels for the joint plot; defaults to
            `['loss', 'beta']` when omitted.
        **kwargs: Forwarded unchanged to every plot-initialisation call.

    Returns:
        A dict with keys `'dq_int'`, `'dq_sin'` and `'loss'`. Every value is
        `None` when not running in a notebook.
    """
    plots = {'dq_int': None, 'dq_sin': None, 'loss': None}
    if not in_notebook():
        return plots

    for name in ('dq_int', 'dq_sin'):
        plots[name] = init_live_plot(configs=configs, ylabel=name,
                                     xlabel='Training Step', **kwargs)

    if ylabels is None:
        ylabels = ['loss', 'beta']

    plots['loss'] = init_live_joint_plots(ylabels=ylabels,
                                          xlabel='Train Step',
                                          colors=['C0', 'C1'],
                                          use_title=False, **kwargs)
    return plots
def test(
    dynamics: GaugeDynamics,
    steps: Steps,
    beta: Union[list[float], float],
    x: torch.Tensor = None,
    skip: Union[str, list[str]] = None,
    keep: Union[str, list[str]] = None,
    history: History = None,
    nchains_test: int = None,
) -> History:
    """Run inference with `dynamics` and record per-step metrics.

    The model is put in eval mode and `test_step` is called `steps.test`
    times at a fixed inverse coupling: the last entry of the `beta` schedule.

    Args:
        dynamics: Model to evaluate.
        steps: Object providing `train` (schedule length) and `test`
            (number of inference steps).
        beta: Either a single float, expanded to a schedule of length
            `steps.train`, or a list of floats of that length.
        x: Initial configuration; a random one of shape
            `dynamics.config.x_shape` is drawn when omitted.
        skip: Forwarded to `history.metrics_summary`.
        keep: Forwarded to `history.metrics_summary`.
        history: Existing `History` to append to; a new one is created when
            omitted.
        nchains_test: If given, only the first `nchains_test` rows of the
            flattened `x` are evolved.

    Returns:
        The `History` holding the collected metrics.

    Raises:
        AssertionError: If `beta` is not a list of floats of length
            `steps.train` after normalisation.
    """
    logger.log(80 * '-')
    logger.log('Running inference...')
    should_print = not in_notebook()
    dynamics.eval()

    if history is None:
        history = History()

    if x is None:
        x = random_angle(dynamics.config.x_shape, requires_grad=True)

    if isinstance(beta, float):
        beta = np.array(steps.train * [beta], dtype=np.float32).tolist()

    assert isinstance(beta, list)
    assert len(beta) == steps.train
    assert isinstance(beta[0], float)

    # Inference runs at the final value of the training schedule.
    test_beta = beta[-1]

    x = x.reshape(x.shape[0], -1).to(DEVICE)
    if nchains_test is not None:
        x = x[:nchains_test, :]

    for step in range(steps.test):
        x, metrics = test_step((x, test_beta), dynamics, timer=history.timer)
        history.update(metrics, step)
        pre = [f'{step}/{steps.test}']
        mstr = history.metrics_summary(window=0, pre=pre,
                                       keep=keep, skip=skip)
        # NOTE(review): the summary is logged only when running inside a
        # notebook (`should_print` is False there) -- confirm this is the
        # intended condition.
        if not should_print:
            logger.log(mstr)

    rate = history.timer.get_eval_rate(dynamics.config.num_steps)
    # The previous message said "Done training!", copied from `train`.
    logger.log(f'Done running inference! took: {rate["total_time"]}')
    logger.log('Timing info:')
    for key, val in rate.items():
        logger.log(f' - {key}={val}')

    return history
def update_joint_plots(
    plot_data1: LivePlotData,
    plot_data2: LivePlotData,
    display_id: DisplayHandle,
    window: int = 15,
    logging_steps: int = 1,
    fig: plt.Figure = None,
):
    """Refresh two live curves that share one figure.

    Outside of a notebook (as reported by `in_notebook()`) this is a no-op.

    Args:
        plot_data1: First curve; its `data` is plotted on `plot_obj.line[0]`.
        plot_data2: Second curve, handled the same way.
        display_id: Display handle whose `update` method receives the figure.
        window: Window passed through to `moving_average`.
        logging_steps: Spacing between consecutive points on the x-axis.
        fig: Figure to redraw; the current figure is used when omitted.
    """
    if not in_notebook():
        return

    figure = plt.gcf() if fig is None else fig

    def _smoothed(raw) -> np.ndarray:
        # 2-D histories are averaged over the last axis before smoothing.
        values = np.array(raw).squeeze()
        if values.ndim == 2:
            values = np.mean(values, -1)
        return moving_average(values, window=window)

    first, second = plot_data1.plot_obj, plot_data2.plot_obj
    y_first = _smoothed(plot_data1.data)
    y_second = _smoothed(plot_data2.data)

    first.line[0].set_ydata(np.array(y_first))
    first.line[0].set_xdata(logging_steps * np.arange(y_first.shape[0]))
    second.line[0].set_ydata(y_second)
    second.line[0].set_xdata(logging_steps * np.arange(y_second.shape[0]))

    for obj in (first, second):
        obj.ax.relim()
    for obj in (first, second):
        obj.ax.autoscale_view()

    figure.canvas.draw()
    # Pushing the figure through the handle forces Colab to refresh it.
    display_id.update(figure)
def train(
    dynamics: GaugeDynamics,
    optimizer: optim.Optimizer,
    steps: Steps,
    beta: Union[list[float], float],
    window: int = 10,
    x: torch.Tensor = None,
    skip: Union[str, list[str]] = None,
    keep: Union[str, list[str]] = None,
    history: History = None,
) -> History:
    """Train `dynamics` for `steps.train` steps along a `beta` schedule.

    Args:
        dynamics: Model to train; switched to train mode first.
        optimizer: Optimizer handed to `train_step`.
        steps: Object providing `train` (number of steps) and `log`
            (logging interval).
        beta: A single float (expanded to a constant schedule) or a list of
            floats with one entry per training step.
        window: Window forwarded to `history.metrics_summary`.
        x: Initial configuration; a random one of shape
            `dynamics.config.x_shape` is drawn when omitted.
        skip: Forwarded to `history.metrics_summary`.
        keep: Forwarded to `history.metrics_summary`.
        history: Existing `History` to append to; a new one is created when
            omitted.

    Returns:
        The `History` holding the collected metrics.

    Raises:
        TypeError: If `beta` is neither a float nor a list.
        AssertionError: If the schedule has the wrong length or its first
            entry is not a float.
    """
    dynamics.train()

    if x is None:
        x = random_angle(dynamics.config.x_shape, requires_grad=True)
    x = x.reshape(x.shape[0], -1)

    train_logs = []
    history = History() if history is None else history
    # Summaries are logged only when running inside a notebook.
    log_summaries = in_notebook()

    if isinstance(beta, float):
        beta = np.array(steps.train * [beta], dtype=np.float32).tolist()
    elif isinstance(beta, list):
        assert len(beta) == steps.train
    else:
        raise TypeError(
            f'beta expected to be `float | list[float]`,\n got: {type(beta)}')

    assert isinstance(beta, list)
    assert isinstance(beta[0], float)

    for step, b in zip(range(steps.train), beta):
        x, metrics = train_step((to_u1(x), b),
                                dynamics=dynamics,
                                optimizer=optimizer,
                                timer=history.timer)
        if (step + 1) % steps.log != 0:
            continue

        history.update(metrics, step)
        percent_done = int((step / steps.train) * 100)
        mstr = history.metrics_summary(window=window,
                                       pre=[f'{percent_done}%'],
                                       keep=keep, skip=skip)
        if log_summaries:
            logger.log(mstr)
        train_logs.append(mstr)

    logger.log(80 * '-')
    rate = history.timer.get_eval_rate(dynamics.config.num_steps)
    logger.log(f'Done training! took: {rate["total_time"]}')
    logger.log('Timing info:')
    for name, value in rate.items():
        logger.log(f' - {name}={value}')

    return history
def plot_data(
    data_container: "DataContainer",  # noqa:F821
    configs: dict = None,
    out_dir: str = None,
    therm_frac: float = 0,
    params: Union[dict, AttrDict] = None,
    hmc: bool = None,
    num_chains: int = 32,
    profile: bool = False,
    cmap: str = 'crest',
    verbose: bool = False,
    logging_steps: int = 1,
) -> dict:
    """Plot data from `data_container.data`.

    Args:
        data_container: Object exposing a `data` mapping (name -> history)
            and a `plot_dataset` method.
        configs: Optional run configuration. Used to build `params` when
            `params` is omitted and to read the number of leapfrog steps.
        out_dir: If given, plots go to a timestamped `plots_*` subdirectory.
        therm_frac: Fraction of leading draws dropped before building the
            arrays; `0` keeps everything.
        params: Run parameters used for the plot title and stored alongside
            the autocorrelation data.
        hmc: Whether the data came from generic HMC; falls back to
            `params.get('hmc', False)` when `None`.
        num_chains: Number of chains to plot; capped at 16.
        profile: Forwarded to `data_container.plot_dataset`.
        cmap: Forwarded to `data_container.plot_dataset`.
        verbose: If true, every key of `data_container.data` is eligible for
            plotting instead of the built-in short list.
        logging_steps: Number of MC steps between consecutive records; used
            to scale the step axis.

    Returns:
        A dict with keys `data_container`, `data_dict`, `data_vars`,
        `out_dir`, `save_times` and `plot_times`; when `'charges'` is present
        in the data it also holds `tint_dict`, `charges_steps` and
        `charges_arr`.
    """
    if verbose:
        keep_strs = list(data_container.data.keys())
    else:
        keep_strs = [
            'charges', 'plaqs', 'accept_prob',
            'Hf_start', 'Hf_mid', 'Hf_end',
            'Hb_start', 'Hb_mid', 'Hb_end',
            'Hwf_start', 'Hwf_mid', 'Hwf_end',
            'Hwb_start', 'Hwb_mid', 'Hwb_end',
            'xeps_start', 'xeps_mid', 'xeps_end',
            'veps_start', 'veps_mid', 'veps_end'
        ]

    with_jupyter = in_notebook()
    # -- TODO: --------------------------------------
    #  * Get rid of unnecessary `params` argument,
    #    all of the entries exist in `configs`.
    # ----------------------------------------------
    if num_chains > 16:
        logger.warning(
            f'Reducing `num_chains` from {num_chains} to 16 for plotting.'
        )
        num_chains = 16

    plot_times = {}
    save_times = {}

    title = None
    if params is not None:
        # Building the title is best-effort; a failure leaves it unset.
        # `Exception` (not a bare `except`) so that KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            title = get_title_str_from_params(params)
        except Exception:
            title = None
    elif configs is not None:
        params = {
            'beta_init': configs['beta_init'],
            'beta_final': configs['beta_final'],
            'x_shape': configs['dynamics_config']['x_shape'],
            'num_steps': configs['dynamics_config']['num_steps'],
            'net_weights': configs['dynamics_config']['net_weights'],
        }
    else:
        params = {}

    tstamp = io.get_timestamp('%Y-%m-%d-%H%M%S')
    plotdir = None
    if out_dir is not None:
        plotdir = os.path.join(out_dir, f'plots_{tstamp}')
        io.check_else_make_dir(plotdir)

    tint_data = {}
    output = {}
    if 'charges' in data_container.data:
        if configs is not None:
            lf = configs['dynamics_config']['num_steps']  # type: int
        else:
            lf = 0
        qarr = np.array(data_container.data['charges'])
        t0 = time.time()
        tint_dict, _ = plot_autocorrs_vs_draws(qarr, num_pts=20,
                                               nstart=1000, therm_frac=0.2,
                                               out_dir=plotdir, lf=lf)
        plot_times['plot_autocorrs_vs_draws'] = time.time() - t0

        tint_data = deepcopy(params)
        tint_data.update({
            'narr': tint_dict['narr'],
            'tint': tint_dict['tint'],
            'run_params': params,
        })

        run_dir = params.get('run_dir', None)
        if run_dir is not None:
            if os.path.isdir(str(Path(run_dir))):
                tint_file = os.path.join(run_dir, 'tint_data.z')
                t0 = time.time()
                io.savez(tint_data, tint_file, 'tint_data')
                save_times['tint_data'] = time.time() - t0

        t0 = time.time()
        qsteps = logging_steps * np.arange(qarr.shape[0])
        _ = plot_charges(qsteps, qarr, out_dir=plotdir, title=title)
        plot_times['plot_charges'] = time.time() - t0

        output.update({
            'tint_dict': tint_dict,
            'charges_steps': qsteps,
            'charges_arr': qarr,
        })

    hmc = params.get('hmc', False) if hmc is None else hmc

    # Substrings identifying logdet-related entries, which are skipped for
    # data generated from an HMC run.
    hmc_skip_strs = ('Hw', 'ld', 'sld', 'sumlogdet')

    data_dict = {}
    data_vars = {}
    for key, val in data_container.data.items():
        if key in SKEYS and key not in keep_strs:
            continue

        # The previous implementation issued `continue` inside an inner
        # `for skip_str in ...` loop, which only advanced that inner loop
        # and therefore never skipped the key.
        if hmc and any(skip_str in key for skip_str in hmc_skip_strs):
            continue

        arr = np.array(val)
        steps = logging_steps * np.arange(len(arr))

        # Drop the thermalization portion exactly once. (A second,
        # duplicated `therm_arr` call used to follow this block.)
        if therm_frac > 0:
            if logging_steps == 1:
                arr, steps = therm_arr(arr, therm_frac=therm_frac)
            else:
                drop = int(therm_frac * arr.shape[0])
                arr = arr[drop:]
                steps = steps[drop:]

        # -- NOTE: arr.shape: (draws,) = (D,) -------------------------------
        if len(arr.shape) == 1:
            data_dict[key] = xr.DataArray(arr, dims=['draw'], coords=[steps])

        # -- NOTE: arr.shape: (draws, chains) = (D, C) ----------------------
        elif len(arr.shape) == 2:
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            data_dict[key] = data_arr
            data_vars[key] = data_arr

        # -- NOTE: arr.shape: (draws, leapfrogs, chains) = (D, L, C) --------
        elif len(arr.shape) == 3:
            _, leapfrogs_, chains_ = arr.shape
            chains = np.arange(chains_)
            leapfrogs = np.arange(leapfrogs_)
            data_dict[key] = xr.DataArray(arr.T,  # NOTE: [arr.T] = (C, L, D)
                                          dims=['chain', 'leapfrog', 'draw'],
                                          coords=[chains, leapfrogs, steps])

    plotdir_xr = None
    if plotdir is not None:
        plotdir_xr = os.path.join(plotdir, 'xarr_plots')

    t0 = time.time()
    dataset, dtplot = data_container.plot_dataset(plotdir_xr,
                                                  num_chains=num_chains,
                                                  therm_frac=therm_frac,
                                                  ridgeplots=True,
                                                  cmap=cmap,
                                                  profile=profile)

    if not with_jupyter:
        plt.close('all')

    plot_times['data_container.plot_dataset'] = {'total': time.time() - t0}
    for key, val in dtplot.items():
        plot_times['data_container.plot_dataset'][key] = val

    if not hmc and 'Hwf' in data_dict:
        t0 = time.time()
        _ = plot_energy_distributions(data_dict,
                                      out_dir=plotdir, title=title)
        plot_times['plot_energy_distributions'] = time.time() - t0

    t0 = time.time()
    _ = mcmc_avg_lineplots(data_dict, title, plotdir)
    plot_times['mcmc_avg_lineplots'] = time.time() - t0

    if not with_jupyter:
        plt.close('all')

    output.update({
        'data_container': data_container,
        'data_dict': data_dict,
        'data_vars': data_vars,
        'out_dir': plotdir,
        'save_times': save_times,
        'plot_times': plot_times,
    })

    return output
def train_dynamics(
    dynamics: GaugeDynamics,
    inputs: dict[str, Any],
    dirs: dict[str, str] = None,
    x: tf.Tensor = None,
    steps_dict: dict[str, int] = None,
    save_metrics: bool = True,
    custom_betas: Union[list, np.ndarray] = None,
    window: int = 0,
) -> tuple[tf.Tensor, DataContainer]:
    """Train `dynamics` over a schedule of `beta` values.

    Args:
        dynamics: Model to train. Its `train_step`, `optimizer`, `config`
            and `lr_config` attributes are used.
        inputs: Training inputs. Must contain `'configs'`, `'manager'`,
            `'checkpoint'` and `'train_data'`; may contain `'x'`, `'betas'`,
            `'steps'`, `'dirs'` and `'writer'`.
        dirs: Output directories (`'log_dir'`, `'data_dir'`, and
            `'summary_dir'` when profiling); falls back to `inputs['dirs']`.
        x: Initial configuration; falls back to `inputs['x']`, then to a
            uniform random sample in `[-PI, PI)`.
        steps_dict: Optional overrides for the save / print / logging /
            epoch intervals (keys `'save'`, `'print'`, `'logging_steps'`,
            `'steps_per_epoch'`).
        save_metrics: Whether to save and flush `train_data` after the final
            step.
        custom_betas: Explicit beta schedule. When given, steps are numbered
            from the optimizer's current iteration count.
        window: Window forwarded to `train_data.print_metrics`.

    Returns:
        The final configuration `x` and the `train_data` container.

    Raises:
        AssertionError: If a required input is missing or the step and beta
            schedules have different lengths.
    """
    configs = inputs['configs']
    min_lr = configs.get('min_lr', 1e-5)
    patience = configs.get('patience', 10)
    factor = configs.get('reduce_lr_factor', 0.5)
    save_steps = configs.get('save_steps', 10000)  # type: int
    print_steps = configs.get('print_steps', 1000)  # type: int
    logging_steps = configs.get('logging_steps', 500)  # type: int
    steps_per_epoch = configs.get('steps_per_epoch', 1000)  # type: int
    if steps_dict is not None:
        save_steps = steps_dict.get('save', 10000)  # type: int
        print_steps = steps_dict.get('print', 1000)  # type: int
        logging_steps = steps_dict.get('logging_steps', 500)  # type: int
        steps_per_epoch = steps_dict.get('steps_per_epoch', 1000)  # type: int

    # -- Helper functions for training, logging, saving, etc. --------------
    timer = StepTimer(evals_per_step=dynamics.config.num_steps)

    def train_step(x: tf.Tensor, beta: tf.Tensor):
        # Time a single training step and attach the duration to the metrics.
        timer.start()
        x, metrics = dynamics.train_step((x, tf.constant(beta)))
        dt = timer.stop()
        metrics.dt = dt
        return x, metrics

    def should_print(step: int) -> bool:
        return IS_CHIEF and step % print_steps == 0

    def should_log(step: int) -> bool:
        return IS_CHIEF and step % logging_steps == 0

    def should_save(step: int) -> bool:
        # `ckpt` is bound further down, before this helper is first called.
        return step % save_steps == 0 and ckpt is not None

    xshape = dynamics._xshape
    x_init = tf.random.uniform(xshape, -PI, PI)
    x = inputs.get('x', x_init) if x is None else x
    assert x is not None

    if custom_betas is None:
        # Check for `None` *before* wrapping in `np.array`; an array is
        # never `None`, which made the original assertions vacuous.
        betas = inputs.get('betas', None)
        assert betas is not None
        betas = np.array(betas)
        assert betas.shape[0] > 0

        steps = inputs.get('steps', None)
        assert steps is not None
        steps = np.array(steps)
        assert steps.shape[0] > 0
    else:
        betas = np.array(custom_betas)
        start = dynamics.optimizer.iterations
        nsteps = len(betas)
        steps = np.arange(start, start + nsteps)

    dirs = inputs.get('dirs', None) if dirs is None else dirs  # type: dict
    assert dirs is not None

    manager = inputs['manager']  # type: tf.train.CheckpointManager
    ckpt = inputs['checkpoint']  # type: tf.train.Checkpoint
    train_data = inputs['train_data']  # type: DataContainer

    # -- Setup dynamic learning rate schedule -----------------
    assert dynamics.lr_config is not None
    warmup_steps = dynamics.lr_config.warmup_steps
    reduce_lr = ReduceLROnPlateau(monitor='loss', mode='min',
                                  warmup_steps=warmup_steps,
                                  factor=factor, min_lr=min_lr,
                                  verbose=1, patience=patience)
    reduce_lr.set_model(dynamics)

    # -- Setup summary writer -----------
    writer = inputs.get('writer', None)  # type: tf.summary.SummaryWriter
    if IS_CHIEF and writer is not None:
        writer.set_as_default()

    # -- Run profiler? ----------------------------------------
    # NOTE(review): the `else` is taken to pair with the outer `if`, i.e.
    # one plain training step is run whenever profiling is disabled --
    # confirm against the original layout.
    if configs.get('profiler', False):
        if RANK == 0:
            sdir = dirs['summary_dir']
            x, _ = run_profiler(dynamics, (x, betas[0]),
                                logdir=sdir, steps=5)
    else:
        x, _ = dynamics.train_step((x, betas[0]))

    # -- Run MD update to not get stuck ----------------------
    md_steps = configs.get('md_steps', 0)
    if md_steps > 0:
        x = run_md(dynamics, (x, betas[0]), md_steps)

    total_steps = steps[-1]
    if len(steps) != len(betas):
        betas = betas[steps[0]:]

    plots = {}
    if in_notebook():
        plots = plotter.init_plots(configs, figsize=(9, 3), dpi=125)

    # Metrics included in the periodically printed summary.
    print_keys = [
        'step', 'dt', 'loss', 'accept_prob', 'beta',
        'dq_int', 'dq_sin', 'dQint', 'dQsin', 'plaqs', 'p4x4'
    ]

    # -- Training loop ---------------------------------------------------
    data_strs = []
    logdir = dirs['log_dir']
    data_dir = dirs['data_dir']
    logfile = os.path.join(logdir, 'training', 'train_log.txt')

    assert x is not None
    assert manager is not None
    assert len(steps) == len(betas)
    for step, beta in zip(steps, betas):
        x, metrics = train_step(x, beta)

        # ----------------------------------------------------------------
        # TODO: Run inference when beta hits an integer
        # ----------------------------------------------------------------
        if (step + 1) > warmup_steps and (step + 1) % steps_per_epoch == 0:
            reduce_lr.on_epoch_end(step + 1, {'loss': metrics.loss})

        # -- Save checkpoints and dump configs `x` from each rank --------
        if should_save(step + 1):
            train_data.update(step, metrics)
            train_data.dump_configs(x, data_dir, rank=RANK,
                                    local_rank=LOCAL_RANK)
            if IS_CHIEF:
                _ = timer.save_and_write(logdir, mode='w')
                # -- Save CheckpointManager ------------------------------
                manager.save()
                mstr = f'Checkpoint saved to: {manager.latest_checkpoint}'
                logger.info(mstr)
                with open(logfile, 'w') as f:
                    f.write('\n'.join(data_strs))
                # -- Save train_data and free consumed memory ------------
                train_data.save_and_flush(data_dir, logfile,
                                          rank=RANK, mode='a')
                if not dynamics.config.hmc:
                    # -- Save network weights ----------------------------
                    dynamics.save_networks(logdir)
                    logger.info(f'Networks saved to: {logdir}')

        # -- Print current training state and metrics -------------------
        if should_print(step):
            train_data.update(step, metrics)
            pre = [f'{step:>4g}/{total_steps:<4g}']
            data_str = train_data.print_metrics(metrics, window=window,
                                                pre=pre, keep=print_keys)
            data_strs.append(data_str)

        if in_notebook() and step % PLOT_STEPS == 0 and IS_CHIEF:
            train_data.update(step, metrics)
            # Use the resolved `logging_steps` (default / `steps_dict`
            # override) rather than `configs['logging_steps']`, which raised
            # `KeyError` when the key was absent from `configs`.
            if len(train_data.data.keys()) == 0:
                update_plots(metrics, plots,
                             logging_steps=logging_steps)
            else:
                update_plots(train_data.data, plots,
                             logging_steps=logging_steps)

        # -- Update summary objects ---------------------
        if should_log(step):
            train_data.update(step, metrics)
            if writer is not None:
                update_summaries(step, metrics, dynamics)
                writer.flush()

    # -- Dump config objects -------------------------------------------------
    train_data.dump_configs(x, data_dir, rank=RANK, local_rank=LOCAL_RANK)
    # NOTE(review): everything below is assumed to be chief-only -- confirm
    # against the original layout.
    if IS_CHIEF:
        manager.save()
        logger.info(f'Checkpoint saved to: {manager.latest_checkpoint}')
        with open(logfile, 'w') as f:
            f.write('\n'.join(data_strs))

        if save_metrics:
            train_data.save_and_flush(data_dir, logfile,
                                      rank=RANK, mode='a')

        if not dynamics.config.hmc:
            dynamics.save_networks(logdir)

        if writer is not None:
            writer.flush()
            writer.close()

    return x, train_data
k: io.get_timestamp(v) for k, v in dict(zip(names, formats)).items() } PlotData = plotter.LivePlotData # logger = Logger() OPTIONS = tf.profiler.experimental.ProfilerOptions( host_tracer_level=2, python_tracer_level=1, device_tracer_level=1, delay_ms=None, ) if in_notebook(): logger = logging.getLogger('jupyter') else: logger = logging.getLogger('l2hmc') def update_plots(history: dict, plots: dict, window: int = 1, logging_steps: int = 1): lpdata = PlotData(history['loss'], plots['loss']['plot_obj1']) bpdata = PlotData(history['beta'], plots['loss']['plot_obj2']) fig_loss = plots['loss']['fig'] id_loss = plots['loss']['display_id'] plotter.update_joint_plots(lpdata, bpdata,