def serial_sample_generator_old(model: FlowModel, action_fn: ActionFn, batch_size: int, num_samples: int):
    """Yield `num_samples` flow samples one at a time, refilling in batches.

    A new batch of `batch_size` configurations is drawn from the flow each
    time the previous batch has been fully consumed (and on the first step).

    Args:
        model: Flow model; its `prior` and `layers` attributes are used.
        action_fn: Action function; the target log-density is `-action_fn(x)`.
        batch_size: Number of configurations generated per call to the flow.
        num_samples: Total number of items to yield.

    Yields:
        Tuples `(x[j], q[j], logq[j], logp[j])` for the current slot `j`
        of the current batch, where `q` comes from `qed.batch_charges`.
    """
    prior, layers = model.prior, model.layers
    # Put the flow layers in eval mode. Note that autograd is NOT disabled here.
    layers.eval()
    x = q = logq = logp = None
    for sample_idx in range(num_samples):
        slot = sample_idx % batch_size
        if not slot:
            # The previous batch is exhausted (or this is the first sample):
            # propose a fresh batch from the flow.
            _, x, logq = apply_flow_to_prior(prior, layers, batch_size=batch_size)
            logp = -action_fn(x)
            q = qed.batch_charges(x)
        yield x[slot], q[slot], logq[slot], logp[slot]
def __init__(
        self,
        flow: nn.ModuleList,
        config: TrainConfig,
        lfconfig: lfConfig,
):
    """Store the flow, the configs and leapfrog settings; build helper closures.

    Args:
        flow: The layers of a `FlowModel`.
        config: Training configuration. Its `beta` and `volume` define the
            plaquette normalisation `self._denom`, and `beta` parameterises
            the batch action.
        lfconfig: Leapfrog configuration providing `dt`, `tau` and `nstep`.
    """
    super().__init__()
    self.flow = flow          # layers of a `FlowModel`
    self.config = config      # training config
    self.lfconfig = lfconfig  # lfConfig object

    # Integrator settings copied off the leapfrog config.
    self.dt = lfconfig.dt        # step size
    self.tau = lfconfig.tau      # trajectory length
    self.nstep = lfconfig.nstep  # number of leapfrog steps

    # Normalisation used to turn an action into an (average) plaquette.
    self._denom = config.beta * config.volume

    batch_action = qed.BatchAction(config.beta)

    # NOTE: the action object is wrapped in a function instead of being
    # assigned directly, presumably so it is not registered on the module.
    def _action_fn(x):
        return batch_action(x)

    def _charge_fn(x):
        return qed.batch_charges(x=x).detach()

    def _action_sum(x):
        return self.action(x).sum(-1)

    def _action_sum_hmc(x):
        return self._action_fn(x).sum(-1)

    def _plaq_fn(x):
        return ((-1.) * self._action_fn(x) / self._denom).detach()

    self._action_fn = _action_fn
    self._charge_fn = _charge_fn
    self._action_sum = _action_sum
    self._action_sum_hmc = _action_sum_hmc
    self._plaq_fn = _plaq_fn
def run_hmc(
        param: Param,
        x: torch.Tensor = None,
        plot_metrics: bool = True,
        figsize: tuple[int, int] = None,
        use_title: bool = True,
        save_data: bool = True,
        nplot: int = 10,
):
    """Run generic HMC.

    Explicitly, we perform `param.nrun` independent experiments, where each
    experiment consists of generating `param.ntraj` trajectories.

    Args:
        param: HMC parameters. Attributes read here are `logdir`, `beta`,
            `volume`, `nrun`, `ntraj`, `nprint` and `initializer`; the
            object is also passed on to `qed.hmc` and the plotting helpers.
        x: Optional starting configuration. When given, every run starts from
            a copy of it; otherwise each run starts from `param.initializer()`.
        plot_metrics: If True, save history plots for each run under
            `<logdir>/plots/run{n}`.
        figsize: Figure size for the live plots (notebook only).
        use_title: Forwarded to the live plots (notebook only).
        save_data: If True, save the histories to `<logdir>/hmc_histories.z`
            and the configurations to `<logdir>/hmc_fields_arr.z`.
        nplot: Refresh the live plots every `nplot` trajectories
            (notebook only).

    Returns:
        `(fields_arr, histories)`. `fields_arr[n]` is the list of
        configurations produced in run `n`. `histories[n]` maps each metric
        name to its list of per-trajectory values for run `n`.
    """
    logdir = param.logdir
    if os.path.isdir(logdir):
        # The directory already exists; switch to a timestamped variant.
        logdir = io.tstamp_dir(logdir)

    data_dir = os.path.join(logdir, 'data')
    plots_dir = os.path.join(logdir, 'plots')
    check_else_make_dir([plots_dir, data_dir])

    action = qed.BatchAction(param.beta)
    logger.log(repr(param))

    x0 = x  # optional user-supplied starting configuration
    dt_run = 0.  # wall time of the previous run, shown in each run header
    histories = {}
    run_times = []
    fields_arr = []

    ylabels = ['acc', 'dq', 'plaq']
    xlabels = len(ylabels) * ['trajectory']
    plots = {}
    if in_notebook():
        plots = init_live_plots(
            param=param,
            figsize=figsize,
            use_title=use_title,
            xlabels=xlabels,
            ylabels=ylabels,
        )

    for n in range(param.nrun):
        t0 = time.time()
        hstr = f'RUN: {n}, last took: {int(dt_run//60)} m {dt_run%60:.4g} s'
        logger.rule(hstr)

        x = param.initializer() if x0 is None else x0.clone()
        p = (-1.) * action(x) / (param.beta * param.volume)
        q = qed.batch_charges(x)
        logger.print_metrics({'plaq': p, 'q': q})

        xarr = []
        history = {}
        for i in range(param.ntraj):
            t1 = time.time()
            dH, _, acc, x = qed.hmc(param, x, verbose=False)
            # Charge after the previous trajectory (the initial one at i == 0).
            try:
                qold = history['q'][-1]
            except KeyError:
                qold = q
            qnew = qed.batch_charges(x)
            dq = torch.sqrt((qnew - qold)**2)
            plaq = (-1.) * action(x) / (param.beta * param.volume)
            xarr.append(x)
            metrics = {
                'traj': n * param.ntraj + i + 1,
                'dt': time.time() - t1,
                'acc': acc.to(DTYPE),
                'dH': dH,
                'plaq': plaq,
                # NOTE(review): int() only accepts a single-element tensor;
                # this will raise if batch_charges returns several chains.
                'q': int(qnew),
                'dq': dq,
            }
            for k, v in metrics.items():
                history.setdefault(k, []).append(v)

            # Prints at i = 1, 1 + nprint, 1 + 2 * nprint, ...
            if (i - 1) % param.nprint == 0:
                _ = logger.print_metrics(metrics)
            if in_notebook() and i % nplot == 0:
                data = {
                    'dq': history['dq'],
                    'acc': history['acc'],
                    'plaq': history['plaq'],
                }
                update_plots(plots, data)

        dt = time.time() - t0
        dt_run = dt  # reported in the next run's header
        run_times.append(dt)
        histories[n] = history
        fields_arr.append(xarr)
        if plot_metrics:
            outdir = os.path.join(plots_dir, f'run{n}')
            check_else_make_dir(outdir)
            plot_history(history, param, therm_frac=THERM_FRAC,
                         xlabel='traj', outdir=outdir,
                         num_chains=CHAINS_TO_PLOT)

    run_times_strs = [f'{dt:.4f}' for dt in run_times]
    dt_strs = [f'{dt/param.ntraj:.4f}' for dt in run_times]
    logger.log(f'Run times: {run_times_strs}')
    logger.log(f'Per trajectory: {dt_strs}')

    if save_data:
        hfile = os.path.join(logdir, 'hmc_histories.z')
        io.save_history(histories, hfile, name='hmc_histories')
        xfile = os.path.join(logdir, 'hmc_fields_arr.z')
        savez(fields_arr, xfile, name='hmc_fields_arr')

    return fields_arr, histories
def run(
        self,
        x: torch.Tensor = None,
        nprint: int = 25,
        nplot: int = 25,
        window: int = 10,
        num_trajs: int = 1024,
        writer: Optional[SummaryWriter] = None,
        plotdir: str = None,
        **kwargs,
):
    """Generate `num_trajs` field-transformed HMC trajectories.

    Args:
        x: Starting configuration. When None, `self.initializer()` is used.
            When given, it is asserted to be a `torch.Tensor`.
        nprint: Print the metrics every `nprint` trajectories.
        nplot: Refresh the live plots (notebook only) every `nplot`
            trajectories.
        window: Forwarded to `plotter.update_plots`.
        num_trajs: Number of trajectories to generate.
        writer: Optional summary writer. When given, each trajectory's
            metrics are written once via `write_summaries`, with
            `step=i` and `pre='ftHMC'`.
        plotdir: Output directory for the live plots and the final history
            plots.
        **kwargs: Forwarded to `init_live_plots`.

    Returns:
        A dict mapping each metric name to its list of per-trajectory values.
    """
    if x is not None:
        assert isinstance(x, torch.Tensor)
    else:
        x = self.initializer()

    logger.log(f'Running ftHMC with tau={self.tau}, nsteps={self.nstep}')
    history = {}  # type: dict[str, list[torch.Tensor]]
    q = qed.batch_charges(x)
    p = (-1.) * self.action(x) / self._denom
    logger.print_metrics({'plaq': p, 'q': q})

    plots = {}
    if in_notebook():
        plots = init_live_plots(config=self.config,
                                ylabels=['acc', 'dq', 'plaq'],
                                xlabels=3 * ['trajectory'],
                                **kwargs)
    # ------------------------------------------------------------
    # TODO: Create directories for `FieldTransformation` and
    # `save_live_plots` along with metrics, other plots to dirs
    # ------------------------------------------------------------
    for i in range(num_trajs):
        x, metrics_ = self.hmc(x, step=i)
        # Charge after the previous trajectory (the initial one at i == 0).
        try:
            qold = history['q'][-1]
        except KeyError:
            qold = q

        # Lattice observables are measured on the flowed (physical) field.
        x_phys, _ = self.flow_forward(x)
        lmetrics = self.lattice_metrics(x_phys, qold)
        metrics = {**metrics_, **lmetrics}

        # Write the summaries once per trajectory, not once per metric key.
        if writer is not None:
            write_summaries(metrics, writer=writer, step=i, pre='ftHMC')
        for key, val in metrics.items():
            history.setdefault(key, []).append(val)

        if (i - 1) % nplot == 0 and in_notebook() and plots != {}:
            data = {k: history[k] for k in ['dq', 'acc', 'plaq']}
            plotter.update_plots(plots, data, window=window)
        if i % nprint == 0:
            logger.print_metrics(metrics)

    if plotdir is not None and in_notebook():
        plotter.save_live_plots(plots, outdir=plotdir)

    plotter.plot_history(history, therm_frac=0.0, outdir=plotdir,
                         config=self.config, lfconfig=self.lfconfig,
                         xlabel='Trajectory')
    return history
def train_step(
        model: FlowModel,
        config: TrainConfig,
        action: ActionFn,
        optimizer: optim.Optimizer,
        batch_size: int,
        scheduler: Any = None,
        scaler: GradScaler = None,
        pre_model: FlowModel = None,
        dkl_factor: float = 1.,
        xi: torch.Tensor = None,
):
    """Perform a single training step.

    Draws a batch from the flow and computes the loss
    `dkl_factor * calc_dkl(logp, logq)`, where `logp = -action(x)` and
    `logq` is the flow log-density. It then back-propagates the loss and
    steps the optimizer (through `scaler` when one is given).

    Args:
        model: Flow model being trained (`model.prior`, `model.layers`).
        config: Training config; `beta * volume` normalises the plaquette.
        action: Action function; the target log-density is `-action(x)`.
        optimizer: Optimizer whose gradients are zeroed and which is stepped.
        batch_size: Number of samples drawn in this step.
        scheduler: Optional LR scheduler. `scheduler.step(loss)` is called
            with the loss, which suits metric-driven schedulers such as
            `ReduceLROnPlateau`.
        scaler: Optional gradient scaler. When given, the backward pass and
            optimizer step go through `scaler.scale` / `scaler.step` /
            `scaler.update`.
        pre_model: Optional flow model used to produce the latent input. A
            prior sample is pushed through `pre_model.layers` with
            `qed.ft_flow` and mapped back with `qed.ft_flow_inv`. When given,
            it overrides `xi`.
        dkl_factor: Multiplier applied to the loss.
        xi: Optional latent configuration forwarded to `apply_flow_to_prior`.

    Returns:
        A dict with keys `dt`, `ess`, `logp`, `logq`, `loss_dkl`, `q`, `dq`
        and `plaq`. Every value except `dt` is passed through `grab`.
        `dq` is `|q - qi|`, the charge change between the latent `xi` and the
        flowed `x`.

    TODO: Add `torch.device` to arguments for DDP.
    """
    t0 = time.time()
    optimizer.zero_grad()

    if pre_model is not None:
        # NOTE(review): this path is not wrapped in torch.no_grad(). If
        # pre_model's parameters require grad, backward() below may also
        # accumulate gradients on them. Confirm that this is intended.
        pre_xi = pre_model.prior.sample_n(batch_size)
        x = qed.ft_flow(pre_model.layers, pre_xi)
        xi = qed.ft_flow_inv(pre_model.layers, x)

    x, xi, logq = apply_flow_to_prior(model.prior, model.layers,
                                      xi=xi, batch_size=batch_size)
    logp = (-1.) * action(x)
    dkl = calc_dkl(logp, logq)
    ess = calc_ess(logp, logq)

    qi = qed.batch_charges(xi)
    q = qed.batch_charges(x)
    plaq = logp / (config.beta * config.volume)
    dq = torch.sqrt((q - qi) ** 2)

    loss_dkl = dkl_factor * dkl
    if scaler is not None:
        scaler.scale(loss_dkl).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss_dkl.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step(loss_dkl)

    return {
        'dt': time.time() - t0,
        'ess': grab(ess),
        'logp': grab(logp),
        'logq': grab(logq),
        'loss_dkl': grab(loss_dkl),
        'q': grab(q),
        'dq': grab(dq),
        'plaq': grab(plaq),
    }