Example #1
def serial_sample_generator_old(model: FlowModel, action_fn: ActionFn,
                                batch_size: int, num_samples: int):
    """Yield `(x, q, logq, logp)` one sample at a time.

    Proposals are drawn from the flow in batches of `batch_size` and
    served serially; a new batch is generated once the current one is
    exhausted.
    """
    prior = model.prior
    layers = model.layers
    layers.eval()
    x, q, logq, logp = None, None, None, None
    for i in range(num_samples):
        batch_i = i % batch_size
        if batch_i == 0:
            # we're out of samples to propose, generate a new batch
            _, x, logq = apply_flow_to_prior(prior,
                                             layers,
                                             batch_size=batch_size)
            logp = -action_fn(x)      # target log-density, log p = -S(x) up to a constant
            q = qed.batch_charges(x)  # topological charge per configuration

        yield x[batch_i], q[batch_i], logq[batch_i], logp[batch_i]
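A minimal usage sketch, assuming a trained `model` and an `action_fn` from the surrounding project (neither is defined in this excerpt); it drains the generator and collects the log importance weights log(p/q):

# Hypothetical driver; `model` and `action_fn` are assumptions.
gen = serial_sample_generator_old(model, action_fn,
                                  batch_size=64, num_samples=1024)
logw = [logp_i - logq_i for _, _, logq_i, logp_i in gen]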
Example #2
    def __init__(
            self,
            flow: nn.ModuleList,
            config: TrainConfig,
            lfconfig: lfConfig,
    ):
        super().__init__()
        self.flow = flow                    # layers of a `FlowModel`
        self.config = config                # Training config
        self.lfconfig = lfconfig            # lfConfig object
        self.dt = self.lfconfig.dt          # step size
        self.tau = self.lfconfig.tau        # trajectory length
        self.nstep = self.lfconfig.nstep    # number of leapfrog steps
        self._denom = (self.config.beta * self.config.volume)

        action_fn = qed.BatchAction(config.beta)
        self._action_fn = lambda x: action_fn(x)
        # Topological charge, detached from the autograd graph
        self._charge_fn = lambda x: qed.batch_charges(x=x).detach()
        # Batch-summed actions for the transformed and plain-HMC variants
        self._action_sum = lambda x: self.action(x).sum(-1)
        self._action_sum_hmc = lambda x: self._action_fn(x).sum(-1)
        # Average plaquette: -S(x) / (beta * volume)
        self._plaq_fn = lambda x: (
            ((-1.) * self._action_fn(x) / self._denom).detach()
        )
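`_plaq_fn` relies on the U(1) identity S(x) = -beta * sum_P cos(theta_P), so -S(x) / (beta * volume) is the lattice-averaged plaquette. A self-contained sketch of the plaquette angle on a 2D lattice of angle-valued links; the sign and indexing conventions of the project's `qed` module are an assumption here:

import torch

def plaq_angles(x: torch.Tensor) -> torch.Tensor:
    """theta_P(n) = theta_0(n) + theta_1(n+e0) - theta_0(n+e1) - theta_1(n).

    Assumes x has shape (batch, 2, L, L): two link angles per site.
    """
    return (x[:, 0]
            + torch.roll(x[:, 1], -1, dims=1)
            - torch.roll(x[:, 0], -1, dims=2)
            - x[:, 1])

# With S = -beta * sum_P cos(theta_P), the average plaquette is
# torch.cos(plaq_angles(x)).mean((1, 2)), matching -S / (beta * volume).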
Example #3
File: hmc.py Project: nftqcd/fthmc
def run_hmc(
    param: Param,
    x: Optional[torch.Tensor] = None,
    plot_metrics: bool = True,
    figsize: Optional[tuple[int, int]] = None,
    use_title: bool = True,
    save_data: bool = True,
    nplot: int = 10,
):
    """Run generic HMC.

    Explicitly, we perform `param.nrun` independent experiments, where each
    experiment consists of generating `param.ntraj` trajectories.
    """
    logdir = param.logdir
    if os.path.isdir(logdir):
        logdir = io.tstamp_dir(logdir)

    data_dir = os.path.join(logdir, 'data')
    plots_dir = os.path.join(logdir, 'plots')
    check_else_make_dir([plots_dir, data_dir])

    action = qed.BatchAction(param.beta)
    logger.log(repr(param))

    dt_run = 0.
    histories = {}
    run_times = []
    fields_arr = []
    ylabels = ['acc', 'dq', 'plaq']
    xlabels = len(ylabels) * ['trajectory']
    plots = {}
    if in_notebook():
        plots = init_live_plots(
            param=param,
            figsize=figsize,
            use_title=use_title,  # config=config,
            xlabels=xlabels,
            ylabels=ylabels)

    for n in range(param.nrun):
        t0 = time.time()

        hstr = f'RUN: {n}, last took: {int(dt_run//60)} m {dt_run%60:.4g} s'
        logger.rule(hstr)

        x = param.initializer()
        p = (-1.) * action(x) / (param.beta * param.volume)  # avg plaquette
        q = qed.batch_charges(x)                             # topological charge

        logger.print_metrics({'plaq': p, 'q': q})
        xarr = []
        history = {}
        for i in range(param.ntraj):
            t1 = time.time()
            dH, exp_mdH, acc, x = qed.hmc(param, x, verbose=False)

            try:
                qold = history['q'][-1]
            except KeyError:
                qold = q
            qnew = qed.batch_charges(x)
            dq = torch.sqrt((qnew - qold)**2)  # |qnew - qold|

            plaq = (-1.) * action(x) / (param.beta * param.volume)

            xarr.append(x)

            metrics = {
                'traj': n * param.ntraj + i + 1,
                'dt': time.time() - t1,
                'acc': acc.to(DTYPE),
                'dH': dH,
                'plaq': plaq,
                'q': int(qnew),
                'dq': dq,
            }
            for k, v in metrics.items():
                try:
                    history[k].append(v)
                except KeyError:
                    history[k] = [v]

            if (i - 1) % param.nprint == 0:
                _ = logger.print_metrics(metrics)

            if in_notebook() and i % nplot == 0:
                data = {
                    'dq': history['dq'],
                    'acc': history['acc'],
                    'plaq': history['plaq'],
                }
                update_plots(plots, data)

        dt_run = time.time() - t0
        run_times.append(dt_run)
        histories[n] = history
        fields_arr.append(xarr)

        if plot_metrics:
            outdir = os.path.join(plots_dir, f'run{n}')
            check_else_make_dir(outdir)
            plot_history(history,
                         param,
                         therm_frac=THERM_FRAC,
                         xlabel='traj',
                         outdir=outdir,
                         num_chains=CHAINS_TO_PLOT)

    run_times_strs = [f'{dt:.4f}' for dt in run_times]
    dt_strs = [f'{dt/param.ntraj:.4f}' for dt in run_times]

    logger.log(f'Run times: {run_times_strs}')
    logger.log(f'Per trajectory: {dt_strs}')

    hfile = os.path.join(logdir, 'hmc_histories.z')
    io.save_history(histories, hfile, name='hmc_histories')

    xfile = os.path.join(logdir, 'hmc_fields_arr.z')
    savez(fields_arr, xfile, name='hmc_fields_arr')

    return fields_arr, histories
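The `try`/`except KeyError` appends used throughout these examples are equivalent to the `collections.defaultdict(list)` idiom; a minimal sketch of the same bookkeeping:

from collections import defaultdict

history = defaultdict(list)
for k, v in metrics.items():  # `metrics` as built inside the loop above
    history[k].append(v)      # no KeyError handling needed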
Example #4
    def run(
            self,
            x: Optional[torch.Tensor] = None,
            nprint: int = 25,
            nplot: int = 25,
            window: int = 10,
            num_trajs: int = 1024,
            writer: Optional[SummaryWriter] = None,
            plotdir: Optional[str] = None,
            **kwargs,
    ):
        if x is None:
            x = self.initializer()
        else:
            assert isinstance(x, torch.Tensor)

        logger.log(f'Running ftHMC with tau={self.tau}, nsteps={self.nstep}')
        history = {}  # type: dict[str, list[torch.Tensor]]
        q = qed.batch_charges(x)
        p = (-1.) * self.action(x) / self._denom
        logger.print_metrics({'plaq': p, 'q': q})

        plots = {}
        if in_notebook():
            plots = init_live_plots(config=self.config,
                                    ylabels=['acc', 'dq', 'plaq'],
                                    xlabels=3 * ['trajectory'], **kwargs)

        # ------------------------------------------------------------
        # TODO: Create directories for `FieldTransformation` and
        # `save_live_plots` along with metrics, other plots to dirs
        # ------------------------------------------------------------
        for i in range(num_trajs):
            x, metrics_ = self.hmc(x, step=i)
            try:
                qold = history['q'][i-1]
            except KeyError:
                qold = q
            x_phys, _ = self.flow_forward(x)
            lmetrics = self.lattice_metrics(x_phys, qold)
            metrics = {**metrics_, **lmetrics}

            if writer is not None:
                write_summaries(metrics, writer=writer,
                                step=i, pre='ftHMC')

            for key, val in metrics.items():
                try:
                    history[key].append(val)
                except KeyError:
                    history[key] = [val]

            if (i - 1) % nplot == 0 and in_notebook() and plots != {}:
                data = {
                    k: history[k] for k in ['dq', 'acc', 'plaq']
                }
                plotter.update_plots(plots, data, window=window)

            if i % nprint == 0:
                logger.print_metrics(metrics)

        if plotdir is not None and in_notebook():
            plotter.save_live_plots(plots, outdir=plotdir)

        plotter.plot_history(history,
                             therm_frac=0.0,
                             outdir=plotdir,
                             config=self.config,
                             lfconfig=self.lfconfig,
                             xlabel='Trajectory')

        return history
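A hypothetical driver for this method, assuming the enclosing class is the `FieldTransformation` named in the TODO above and that `flow`, `config`, and `lfconfig` are built as in Example #2:

# All names below are assumptions carried over from Example #2;
# the class name is inferred from the TODO comment.
ft = FieldTransformation(flow=flow, config=config, lfconfig=lfconfig)
history = ft.run(num_trajs=512, nprint=50, nplot=50)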
Example #5
def train_step(
        model: FlowModel,
        config: TrainConfig,
        action: ActionFn,
        optimizer: optim.Optimizer,
        batch_size: int,
        scheduler: Any = None,
        scaler: Optional[GradScaler] = None,
        pre_model: Optional[FlowModel] = None,
        dkl_factor: float = 1.,
        xi: Optional[torch.Tensor] = None,
):
    """Perform a single training step.

    TODO: Add `torch.device` to arguments for DDP.
    """
    t0 = time.time()
    optimizer.zero_grad()

    if pre_model is not None:
        # Propose x through a pre-trained flow, then invert that flow
        # to obtain the matching latent sample xi
        pre_xi = pre_model.prior.sample_n(batch_size)
        x = qed.ft_flow(pre_model.layers, pre_xi)
        xi = qed.ft_flow_inv(pre_model.layers, x)

    x, xi, logq = apply_flow_to_prior(model.prior,
                                      model.layers,
                                      xi=xi, batch_size=batch_size)
    logp = (-1.) * action(x)    # target log-density, up to a constant
    dkl = calc_dkl(logp, logq)  # reverse KL divergence estimate

    ess = calc_ess(logp, logq)  # effective sample size per configuration
    qi = qed.batch_charges(xi)
    q = qed.batch_charges(x)
    plaq = logp / (config.beta * config.volume)  # average plaquette
    dq = torch.sqrt((q - qi) ** 2)               # |q - qi|

    loss_dkl = dkl_factor * dkl

    if scaler is not None:
        scaler.scale(loss_dkl).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss_dkl.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step(loss_dkl)

    metrics = {
        'dt': time.time() - t0,
        'ess': grab(ess),
        'logp': grab(logp),
        'logq': grab(logq),
        'loss_dkl': grab(loss_dkl),
        'q': grab(q),
        'dq': grab(dq),
        'plaq': grab(plaq),
    }

    return metrics
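`calc_dkl` and `calc_ess` are not shown in this excerpt; a minimal sketch of the standard estimators they presumably implement (reverse KL and per-configuration effective sample size), offered as an assumption rather than the project's actual code:

import torch

def calc_dkl(logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
    # Reverse KL estimator: E_q[log q - log p]
    return (logq - logp).mean()

def calc_ess(logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
    # Normalized ESS in (0, 1] from importance weights w ~ p/q
    logw = logp - logq
    logw = logw - logw.max()   # stabilize the exponentials
    w = torch.exp(logw)
    return (w.sum() ** 2) / ((w ** 2).sum() * len(w))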