Example #1
def update_plot(
    y: Metric,
    fig: plt.Figure,
    ax: plt.Axes,
    line: list[plt.Line2D],
    display_id: DisplayHandle,
    window: int = 15,
    logging_steps: int = 1,
):
    if not in_notebook():
        return

    y = np.array(y)
    if y.ndim == 2:
        y = y.mean(-1)

    yavg = moving_average(y.squeeze(), window=window)
    line[0].set_ydata(yavg)
    line[0].set_xdata(logging_steps * np.arange(yavg.shape[0]))
    ax.relim()
    ax.autoscale_view()
    fig.canvas.draw()
    display_id.update(fig)
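
This pattern mutates an existing `Line2D` in place instead of replotting. Below is a minimal, self-contained sketch of the same mechanic using plain matplotlib; the `moving_average` here is a hypothetical trailing-average stand-in, not necessarily the repo's helper.

import numpy as np
import matplotlib.pyplot as plt


def moving_average(x: np.ndarray, window: int = 10) -> np.ndarray:
    """Trailing moving average (hypothetical stand-in for the repo helper)."""
    if len(x) < window:
        return x
    return np.convolve(x, np.ones(window) / window, mode='valid')


fig, ax = plt.subplots()
(line,) = ax.plot([], [])             # empty line, updated in place below
y = np.random.randn(500).cumsum()     # fake metric history
yavg = moving_average(y, window=15)
line.set_ydata(yavg)
line.set_xdata(np.arange(yavg.shape[0]))
ax.relim()                            # recompute data limits from the line
ax.autoscale_view()                   # rescale axes to the new limits
fig.canvas.draw()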
Example #2
def init_plots(configs: dict = None, ylabels: list[str] = None, **kwargs):
    loss = None
    dq_int = None
    dq_sin = None
    if in_notebook():
        dq_int = init_live_plot(configs=configs,
                                ylabel='dq_int',
                                xlabel='Training Step',
                                **kwargs)
        dq_sin = init_live_plot(configs=configs,
                                ylabel='dq_sin',
                                xlabel='Training Step',
                                **kwargs)

        c0 = ['C0', 'C1']
        if ylabels is None:
            ylabels = ['loss', 'beta']
        loss = init_live_joint_plots(ylabels=ylabels,
                                     xlabel='Train Step',
                                     colors=c0,
                                     use_title=False,
                                     **kwargs)
    return {
        'dq_int': dq_int,
        'dq_sin': dq_sin,
        'loss': loss,
    }
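
Every plotting helper here gates on `in_notebook()`, returning `None` placeholders outside a notebook so callers can branch on them. The repo's implementation isn't shown; a common recipe with the same behavior is sketched below.

def in_notebook() -> bool:
    """Best-effort check for a Jupyter/Colab kernel (assumed recipe,
    not necessarily the repo's implementation)."""
    try:
        from IPython import get_ipython
        shell = get_ipython()
        return (shell is not None
                and shell.__class__.__name__ == 'ZMQInteractiveShell')
    except ImportError:
        return False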
Example #3
def test(
    dynamics: GaugeDynamics,
    steps: Steps,
    beta: Union[list[float], float],
    x: torch.Tensor = None,
    skip: Union[str, list[str]] = None,
    keep: Union[str, list[str]] = None,
    history: History = None,
    nchains_test: int = None,
) -> History:
    """Run training and evaluate the trained model."""
    logger.log(80 * '-')
    logger.log(f'Running inference...')
    should_print = not in_notebook()

    dynamics.eval()

    if history is None:
        history = History()

    if x is None:
        x = random_angle(dynamics.config.x_shape, requires_grad=True)

    if isinstance(beta, float):
        beta = np.array(steps.train * [beta], dtype=np.float32).tolist()

    assert isinstance(beta, list)
    assert len(beta) == steps.train
    assert isinstance(beta[0], float)

    test_logs = []
    test_beta = beta[-1]
    x = x.reshape(x.shape[0], -1).to(DEVICE)
    if nchains_test is not None:
        x = x[:nchains_test, :]

    for step in range(steps.test):
        x, metrics = test_step((x, test_beta), dynamics, timer=history.timer)
        history.update(metrics, step)
        pre = [f'{step}/{steps.test}']
        mstr = history.metrics_summary(window=0, pre=pre, keep=keep, skip=skip)
        if not should_print:
            logger.log(mstr)

        test_logs.append(mstr)

    rate = history.timer.get_eval_rate(dynamics.config.num_steps)
    logger.log(f'Done with inference! took: {rate["total_time"]}')
    logger.log('Timing info:')
    for key, val in rate.items():
        logger.log(f' - {key}={val}')

    return history
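
Both `test` above and `train` in Example #5 normalize `beta` the same way: a scalar is broadcast to one value per step, while a list must already have the right length. A self-contained sketch of that normalization:

from typing import Union

import numpy as np


def broadcast_beta(beta: Union[float, list], num_steps: int) -> list:
    """Expand a scalar inverse coupling to one value per step."""
    if isinstance(beta, float):
        return np.full(num_steps, beta, dtype=np.float32).tolist()
    assert len(beta) == num_steps, 'need exactly one beta per step'
    return [float(b) for b in beta]


assert broadcast_beta(1.0, 3) == [1.0, 1.0, 1.0]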
Example #4
def update_joint_plots(
    plot_data1: LivePlotData,
    plot_data2: LivePlotData,
    display_id: DisplayHandle,
    window: int = 15,
    logging_steps: int = 1,
    fig: plt.Figure = None,
):
    if not in_notebook():
        return

    if fig is None:
        fig = plt.gcf()

    plot_obj1 = plot_data1.plot_obj
    plot_obj2 = plot_data2.plot_obj

    x1 = np.array(plot_data1.data).squeeze()  # type: np.ndarray
    x2 = np.array(plot_data2.data).squeeze()  # type: np.ndarray

    x1avg = x1.mean(-1) if x1.ndim == 2 else x1
    x2avg = x2.mean(-1) if x2.ndim == 2 else x2

    y1 = moving_average(x1avg, window=window)
    y2 = moving_average(x2avg, window=window)

    plot_obj1.line[0].set_ydata(y1)
    plot_obj1.line[0].set_xdata(logging_steps * np.arange(y1.shape[0]))

    plot_obj2.line[0].set_ydata(y2)
    plot_obj2.line[0].set_xdata(logging_steps * np.arange(y2.shape[0]))

    plot_obj1.ax.relim()
    plot_obj2.ax.relim()

    plot_obj1.ax.autoscale_view()
    plot_obj2.ax.autoscale_view()

    fig.canvas.draw()
    display_id.update(fig)  # need to force colab to update plot
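
The final `display_id.update(fig)` is what redraws the figure in place on Jupyter/Colab instead of appending a new output cell. A minimal sketch of the `DisplayHandle` lifecycle (only meaningful inside a running notebook kernel):

import matplotlib.pyplot as plt
from IPython.display import display

fig, ax = plt.subplots()
display_id = display(fig, display_id=True)   # returns a DisplayHandle
ax.plot([0, 1], [0, 1])                      # mutate the figure...
fig.canvas.draw()
display_id.update(fig)                       # ...and redraw it in place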
Example #5
def train(
    dynamics: GaugeDynamics,
    optimizer: optim.Optimizer,
    steps: Steps,
    beta: Union[list[float], float],
    window: int = 10,
    x: torch.Tensor = None,
    skip: Union[str, list[str]] = None,
    keep: Union[str, list[str]] = None,
    history: History = None,
) -> History:
    """Train dynamics."""
    dynamics.train()

    if x is None:
        x = random_angle(dynamics.config.x_shape, requires_grad=True)
        x = x.reshape(x.shape[0], -1)

    train_logs = []
    if history is None:
        history = History()

    should_print = (not in_notebook())

    if isinstance(beta, list):
        assert len(beta) == steps.train
    elif isinstance(beta, float):
        beta = np.array(steps.train * [beta], dtype=np.float32).tolist()
    else:
        raise TypeError(
            f'beta expected to be `float | list[float]`,\n got: {type(beta)}')

    assert (isinstance(beta, list) and isinstance(beta[0], float))

    for step, b in zip(range(steps.train), beta):
        x, metrics = train_step((to_u1(x), b),
                                dynamics=dynamics,
                                optimizer=optimizer,
                                timer=history.timer)
        if (step + 1) % steps.log == 0:
            history.update(metrics, step)
            pre = [f'{int((step / steps.train) * 100)}%']
            mstr = history.metrics_summary(window=window,
                                           pre=pre,
                                           keep=keep,
                                           skip=skip)
            if not should_print:
                logger.log(mstr)

            train_logs.append(mstr)

    logger.log(80 * '-')
    rate = history.timer.get_eval_rate(dynamics.config.num_steps)
    logger.log(f'Done training! took: {rate["total_time"]}')
    logger.log('Timing info:')
    for key, val in rate.items():
        logger.log(f' - {key}={val}')

    return history
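
The loop assumes two angular helpers, `random_angle` and `to_u1`, whose definitions are not shown. Below are hedged stand-ins consistent with how they are called: drawing uniform angles and wrapping values back onto [-π, π).

import math

import torch


def random_angle(shape, requires_grad: bool = False) -> torch.Tensor:
    """Uniform angles in (-pi, pi) (assumed behavior, not the repo's code)."""
    x = torch.empty(shape).uniform_(-math.pi, math.pi)
    return x.requires_grad_(requires_grad)


def to_u1(x: torch.Tensor) -> torch.Tensor:
    """Wrap values back onto [-pi, pi) (assumed behavior)."""
    return ((x + math.pi) % (2 * math.pi)) - math.pi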
Example #6
def plot_data(
        data_container: "DataContainer",  # noqa: F821
        configs: dict = None,
        out_dir: str = None,
        therm_frac: float = 0,
        params: Union[dict, AttrDict] = None,
        hmc: bool = None,
        num_chains: int = 32,
        profile: bool = False,
        cmap: str = 'crest',
        verbose: bool = False,
        logging_steps: int = 1,
) -> dict:
    """Plot data from `data_container.data`."""
    if verbose:
        keep_strs = list(data_container.data.keys())
    else:
        keep_strs = [
            'charges', 'plaqs', 'accept_prob',
            'Hf_start', 'Hf_mid', 'Hf_end',
            'Hb_start', 'Hb_mid', 'Hb_end',
            'Hwf_start', 'Hwf_mid', 'Hwf_end',
            'Hwb_start', 'Hwb_mid', 'Hwb_end',
            'xeps_start', 'xeps_mid', 'xeps_end',
            'veps_start', 'veps_mid', 'veps_end'
        ]

    with_jupyter = in_notebook()
    # -- TODO: --------------------------------------
    #  * Get rid of unnecessary `params` argument,
    #    all of the entries exist in `configs`.
    # ----------------------------------------------
    if num_chains > 16:
        logger.warning(
            f'Reducing `num_chains` from {num_chains} to 16 for plotting.'
        )
        num_chains = 16

    plot_times = {}
    save_times = {}

    title = None
    if params is not None:
        try:
            title = get_title_str_from_params(params)
        except Exception:
            title = None
    else:
        if configs is not None:
            params = {
                'beta_init': configs['beta_init'],
                'beta_final': configs['beta_final'],
                'x_shape': configs['dynamics_config']['x_shape'],
                'num_steps': configs['dynamics_config']['num_steps'],
                'net_weights': configs['dynamics_config']['net_weights'],
            }
        else:
            params = {}

    tstamp = io.get_timestamp('%Y-%m-%d-%H%M%S')
    plotdir = None
    if out_dir is not None:
        plotdir = os.path.join(out_dir, f'plots_{tstamp}')
        io.check_else_make_dir(plotdir)

    tint_data = {}
    output = {}
    if 'charges' in data_container.data:
        if configs is not None:
            lf = configs['dynamics_config']['num_steps']  # type: int
        else:
            lf = 0
        qarr = np.array(data_container.data['charges'])
        t0 = time.time()
        tint_dict, _ = plot_autocorrs_vs_draws(qarr, num_pts=20,
                                               nstart=1000, therm_frac=0.2,
                                               out_dir=plotdir, lf=lf)
        plot_times['plot_autocorrs_vs_draws'] = time.time() - t0

        tint_data = deepcopy(params)
        tint_data.update({
            'narr': tint_dict['narr'],
            'tint': tint_dict['tint'],
            'run_params': params,
        })

        run_dir = params.get('run_dir', None)
        if run_dir is not None:
            if os.path.isdir(str(Path(run_dir))):
                tint_file = os.path.join(run_dir, 'tint_data.z')
                t0 = time.time()
                io.savez(tint_data, tint_file, 'tint_data')
                save_times['tint_data'] = time.time() - t0

        t0 = time.time()
        qsteps = logging_steps * np.arange(qarr.shape[0])
        _ = plot_charges(qsteps, qarr, out_dir=plotdir, title=title)
        plot_times['plot_charges'] = time.time() - t0

        output.update({
            'tint_dict': tint_dict,
            'charges_steps': qsteps,
            'charges_arr': qarr,
        })

    hmc = params.get('hmc', False) if hmc is None else hmc

    data_dict = {}
    data_vars = {}
    for key, val in data_container.data.items():
        if key in SKEYS and key not in keep_strs:
            continue
        # Skip logdet-related quantities when the data came from an HMC run.
        # (A `continue` inside an inner loop would only skip that loop, so
        # the membership test has to happen at this level.)
        if hmc and any(s in key for s in ('Hw', 'ld', 'sld', 'sumlogdet')):
            continue

        arr = np.array(val)
        steps = logging_steps * np.arange(len(arr))

        if therm_frac > 0:
            if logging_steps == 1:
                arr, steps = therm_arr(arr, therm_frac=therm_frac)
            else:
                drop = int(therm_frac * arr.shape[0])
                arr = arr[drop:]
                steps = steps[drop:]

        # -- NOTE: arr.shape: (draws,) = (D,) -------------------------------
        if arr.ndim == 1:
            data_dict[key] = xr.DataArray(arr, dims=['draw'], coords=[steps])

        # -- NOTE: arr.shape: (draws, chains) = (D, C) ----------------------
        elif arr.ndim == 2:
            chains = np.arange(arr.shape[1])
            data_arr = xr.DataArray(arr.T,
                                    dims=['chain', 'draw'],
                                    coords=[chains, steps])
            data_dict[key] = data_arr
            data_vars[key] = data_arr

        # -- NOTE: arr.shape: (draws, leapfrogs, chains) = (D, L, C) ---------
        elif arr.ndim == 3:
            _, leapfrogs_, chains_ = arr.shape
            chains = np.arange(chains_)
            leapfrogs = np.arange(leapfrogs_)
            data_dict[key] = xr.DataArray(arr.T,  # NOTE: [arr.T] = (C, L, D)
                                          dims=['chain', 'leapfrog', 'draw'],
                                          coords=[chains, leapfrogs, steps])

    plotdir_xr = None
    if plotdir is not None:
        plotdir_xr = os.path.join(plotdir, 'xarr_plots')

    t0 = time.time()
    dataset, dtplot = data_container.plot_dataset(plotdir_xr,
                                                  num_chains=num_chains,
                                                  therm_frac=therm_frac,
                                                  ridgeplots=True,
                                                  cmap=cmap,
                                                  profile=profile)

    if not with_jupyter:
        plt.close('all')

    plot_times['data_container.plot_dataset'] = {'total': time.time() - t0}
    for key, val in dtplot.items():
        plot_times['data_container.plot_dataset'][key] = val

    if not hmc and 'Hwf' in data_dict:
        t0 = time.time()
        _ = plot_energy_distributions(data_dict, out_dir=plotdir, title=title)
        plot_times['plot_energy_distributions'] = time.time() - t0

    t0 = time.time()
    _ = mcmc_avg_lineplots(data_dict, title, plotdir)
    plot_times['mcmc_avg_lineplots'] = time.time() - t0

    if not with_jupyter:
        plt.close('all')

    output.update({
        'data_container': data_container,
        'data_dict': data_dict,
        'data_vars': data_vars,
        'out_dir': plotdir,
        'save_times': save_times,
        'plot_times': plot_times,
    })

    return output
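
The shape dispatch above packs each metric into an `xarray.DataArray` whose leading dimension indexes chains, which is what the downstream trace and ridge plots expect. A self-contained illustration of the two-dimensional case:

import numpy as np
import xarray as xr

arr = np.random.randn(100, 4)        # fake (draws, chains) metric history
steps = np.arange(arr.shape[0])
chains = np.arange(arr.shape[1])
da = xr.DataArray(arr.T,             # transpose so dims are (chain, draw)
                  dims=['chain', 'draw'],
                  coords=[chains, steps])
print(da.mean('chain').shape)        # (100,): chain-averaged trace per draw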
Example #7
def train_dynamics(
    dynamics: GaugeDynamics,
    inputs: dict[str, Any],
    dirs: dict[str, str] = None,
    x: tf.Tensor = None,
    steps_dict: dict[str, int] = None,
    save_metrics: bool = True,
    custom_betas: Union[list, np.ndarray] = None,
    window: int = 0,
) -> tuple[tf.Tensor, DataContainer]:
    """Train model."""
    configs = inputs['configs']
    steps = configs.get('steps', [])
    min_lr = configs.get('min_lr', 1e-5)
    patience = configs.get('patience', 10)
    factor = configs.get('reduce_lr_factor', 0.5)

    save_steps = configs.get('save_steps', 10000)  # type: int
    print_steps = configs.get('print_steps', 1000)  # type: int
    logging_steps = configs.get('logging_steps', 500)  # type: int
    steps_per_epoch = configs.get('steps_per_epoch', 1000)  # type: int
    if steps_dict is not None:
        save_steps = steps_dict.get('save', 10000)  # type: int
        print_steps = steps_dict.get('print', 1000)  # type: int
        logging_steps = steps_dict.get('logging_steps', 500)  # type: int
        steps_per_epoch = steps_dict.get('steps_per_epoch', 1000)  # type: int

    # -- Helper functions for training, logging, saving, etc. --------------
    timer = StepTimer(evals_per_step=dynamics.config.num_steps)

    def train_step(x: tf.Tensor, beta: tf.Tensor):
        timer.start()
        x, metrics = dynamics.train_step((x, tf.constant(beta)))
        dt = timer.stop()
        metrics.dt = dt
        return x, metrics

    def should_print(step: int) -> bool:
        return IS_CHIEF and step % print_steps == 0

    def should_log(step: int) -> bool:
        return IS_CHIEF and step % logging_steps == 0

    def should_save(step: int) -> bool:
        return step % save_steps == 0 and ckpt is not None

    xshape = dynamics._xshape
    xr = tf.random.uniform(xshape, -PI, PI)
    x = inputs.get('x', xr) if x is None else x
    assert x is not None

    if custom_betas is None:
        # Check for None *before* wrapping in np.array: `np.array(None)` is
        # a 0-d object array, so the original `is not None` test never fired.
        betas = inputs.get('betas', None)
        assert betas is not None
        betas = np.array(betas)
        assert betas.shape[0] > 0

        steps = inputs.get('steps', None)
        assert steps is not None
        steps = np.array(steps)
        assert steps.shape[0] > 0
    else:
        betas = np.array(custom_betas)
        start = dynamics.optimizer.iterations
        nsteps = len(betas)
        steps = np.arange(start, start + nsteps)

    dirs = inputs.get('dirs', None) if dirs is None else dirs  # type: dict
    assert dirs is not None

    manager = inputs['manager']  # type: tf.train.CheckpointManager
    ckpt = inputs['checkpoint']  # type: tf.train.Checkpoint
    train_data = inputs['train_data']  # type: DataContainer

    # -- Setup dynamic learning rate schedule -----------------
    assert dynamics.lr_config is not None
    warmup_steps = dynamics.lr_config.warmup_steps
    reduce_lr = ReduceLROnPlateau(monitor='loss',
                                  mode='min',
                                  warmup_steps=warmup_steps,
                                  factor=factor,
                                  min_lr=min_lr,
                                  verbose=1,
                                  patience=patience)
    reduce_lr.set_model(dynamics)

    # -- Setup summary writer -----------
    writer = inputs.get('writer', None)  # type: tf.summary.SummaryWriter
    if IS_CHIEF and writer is not None:
        writer.set_as_default()

    # -- Run profiler? ----------------------------------------
    if configs.get('profiler', False):
        if RANK == 0:
            sdir = dirs['summary_dir']
            x, metrics = run_profiler(dynamics, (x, betas[0]),
                                      logdir=sdir,
                                      steps=5)
    else:
        x, metrics = dynamics.train_step((x, betas[0]))

    # -- Run MD update to not get stuck ----------------------
    md_steps = configs.get('md_steps', 0)
    if md_steps > 0:
        x = run_md(dynamics, (x, betas[0]), md_steps)

    total_steps = steps[-1]
    if len(steps) != len(betas):
        betas = betas[steps[0]:]

    keep = [
        'dt', 'loss', 'accept_prob', 'beta', 'Hwb_start', 'Hwf_start',
        'Hwb_mid', 'Hwf_mid', 'Hwb_end', 'Hwf_end', 'xeps', 'veps', 'dq',
        'dq_sin', 'plaqs', 'p4x4', 'charges', 'sin_charges'
    ]
    plots = {}
    if in_notebook():
        plots = plotter.init_plots(configs, figsize=(9, 3), dpi=125)

    # -- Training loop ---------------------------------------------------
    data_strs = []
    logdir = dirs['log_dir']
    data_dir = dirs['data_dir']
    logfile = os.path.join(logdir, 'training', 'train_log.txt')

    assert x is not None
    assert manager is not None
    assert len(steps) == len(betas)
    for step, beta in zip(steps, betas):
        x, metrics = train_step(x, beta)

        # ----------------------------------------------------------------
        # TODO: Run inference when beta hits an integer
        # >>> beta_inf = {i: False, for i in np.arange(beta_final)}
        # >>> if any(np.isclose(beta, np.array(list(beta_inf.keys())))):
        # >>>     run_inference(...)
        # ----------------------------------------------------------------

        if (step + 1) > warmup_steps and (step + 1) % steps_per_epoch == 0:
            reduce_lr.on_epoch_end(step + 1, {'loss': metrics.loss})

        # -- Save checkpoints and dump configs `x` from each rank --------
        if should_save(step + 1):
            train_data.update(step, metrics)
            train_data.dump_configs(x,
                                    data_dir,
                                    rank=RANK,
                                    local_rank=LOCAL_RANK)
            if IS_CHIEF:
                _ = timer.save_and_write(logdir, mode='w')
                # -- Save CheckpointManager ------------------------------
                manager.save()
                mstr = f'Checkpoint saved to: {manager.latest_checkpoint}'
                logger.info(mstr)
                with open(logfile, 'w') as f:
                    f.writelines('\n'.join(data_strs))

                # -- Save train_data and free consumed memory ------------
                train_data.save_and_flush(data_dir,
                                          logfile,
                                          rank=RANK,
                                          mode='a')
                if not dynamics.config.hmc:
                    # -- Save network weights ----------------------------
                    dynamics.save_networks(logdir)
                    logger.info(f'Networks saved to: {logdir}')

        # -- Print current training state and metrics -------------------
        if should_print(step):
            train_data.update(step, metrics)
            keep_ = [
                'step', 'dt', 'loss', 'accept_prob', 'beta', 'dq_int',
                'dq_sin', 'dQint', 'dQsin', 'plaqs', 'p4x4'
            ]
            pre = [f'{step:>4g}/{total_steps:<4g}']
            data_str = train_data.print_metrics(metrics,
                                                window=window,
                                                pre=pre,
                                                keep=keep_)
            data_strs.append(data_str)

        if in_notebook() and step % PLOT_STEPS == 0 and IS_CHIEF:
            train_data.update(step, metrics)
            if not train_data.data:
                update_plots(metrics,
                             plots,
                             logging_steps=configs['logging_steps'])
            else:
                update_plots(train_data.data,
                             plots,
                             logging_steps=configs['logging_steps'])

        # -- Update summary objects ---------------------
        if should_log(step):
            train_data.update(step, metrics)
            if writer is not None:
                update_summaries(step, metrics, dynamics)
                writer.flush()

    # -- Dump config objects -------------------------------------------------
    train_data.dump_configs(x, data_dir, rank=RANK, local_rank=LOCAL_RANK)
    if IS_CHIEF:
        manager.save()
        logger.info(f'Checkpoint saved to: {manager.latest_checkpoint}')

        with open(logfile, 'w') as f:
            f.writelines('\n'.join(data_strs))

        if save_metrics:
            train_data.save_and_flush(data_dir, logfile, rank=RANK, mode='a')

        if not dynamics.config.hmc:
            dynamics.save_networks(logdir)

        if writer is not None:
            writer.flush()
            writer.close()

    return x, train_data
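
Each update is wrapped in a `StepTimer` via `start()`/`stop()`. Its implementation isn't shown here; a minimal sketch consistent with how it is called:

import time


class StepTimer:
    """Accumulates per-step wall times (assumed sketch, not the repo's class)."""

    def __init__(self, evals_per_step: int = 1):
        self.evals_per_step = evals_per_step
        self.times = []
        self._t0 = 0.0

    def start(self) -> None:
        self._t0 = time.time()

    def stop(self) -> float:
        dt = time.time() - self._t0
        self.times.append(dt)
        return dt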
Example #8
    k: io.get_timestamp(v)
    for k, v in dict(zip(names, formats)).items()
}

PlotData = plotter.LivePlotData

OPTIONS = tf.profiler.experimental.ProfilerOptions(
    host_tracer_level=2,
    python_tracer_level=1,
    device_tracer_level=1,
    delay_ms=None,
)

if in_notebook():
    logger = logging.getLogger('jupyter')
else:
    logger = logging.getLogger('l2hmc')


def update_plots(history: dict,
                 plots: dict,
                 window: int = 1,
                 logging_steps: int = 1):
    lpdata = PlotData(history['loss'], plots['loss']['plot_obj1'])
    bpdata = PlotData(history['beta'], plots['loss']['plot_obj2'])
    fig_loss = plots['loss']['fig']
    id_loss = plots['loss']['display_id']
    plotter.update_joint_plots(lpdata,
                               bpdata,
                               display_id=id_loss,
                               window=window,
                               logging_steps=logging_steps,
                               fig=fig_loss)