Example #1
def train_hmc(flags):
    """Main method for training HMC model."""
    hflags = AttrDict(dict(flags).copy())
    lr_config = AttrDict(hflags.pop('lr_config', None))
    config = AttrDict(hflags.pop('dynamics_config', None))
    net_config = AttrDict(hflags.pop('network_config', None))
    hflags.train_steps = hflags.pop('hmc_steps', None)
    hflags.beta_init = hflags.beta_final

    config.update({
        'hmc': True,
        'use_ncp': False,
        'aux_weight': 0.,
        'zero_init': False,
        'separate_networks': False,
        'use_conv_net': False,
        'directional_updates': False,
        'use_scattered_xnet_update': False,
        'use_tempered_traj': False,
        'gauge_eq_masks': False,
    })

    hflags.profiler = False
    hflags.make_summaries = True

    lr_config = LearningRateConfig(
        warmup_steps=0,
        decay_rate=0.9,
        decay_steps=hflags.train_steps // 10,
        lr_init=lr_config.get('lr_init', None),
    )

    train_dirs = io.setup_directories(hflags, 'training_hmc')
    dynamics = GaugeDynamics(hflags, config, net_config, lr_config)
    dynamics.save_config(train_dirs.config_dir)
    x, train_data = train_dynamics(dynamics, hflags, dirs=train_dirs)
    if IS_CHIEF:
        output_dir = os.path.join(train_dirs.train_dir, 'outputs')
        io.check_else_make_dir(output_dir)
        train_data.save_data(output_dir)

        params = {
            'eps': dynamics.eps,
            'num_steps': dynamics.config.num_steps,
            'beta_init': hflags.beta_init,
            'beta_final': hflags.beta_final,
            'lattice_shape': dynamics.config.lattice_shape,
            'net_weights': NET_WEIGHTS_HMC,
        }
        plot_data(train_data,
                  train_dirs.train_dir,
                  hflags,
                  thermalize=True,
                  params=params)

    return x, train_data, dynamics.eps.numpy()
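A rough usage sketch (not from the source): the keys below mirror what train_hmc pops and reads above; all values are illustrative placeholders, and the remaining DynamicsConfig/NetworkConfig fields are omitted.

# Illustrative only -- keys follow the accesses in train_hmc; values are placeholders.
flags = AttrDict({
    'hmc_steps': 2000,                       # copied into hflags.train_steps
    'beta_final': 4.0,                       # HMC trains at fixed beta (beta_init = beta_final)
    'lr_config': {'lr_init': 1e-3},          # only lr_init survives; the rest is rebuilt above
    'dynamics_config': {'num_steps': 10},    # other DynamicsConfig fields omitted here
    'network_config': {},                    # popped but unused for pure HMC
})
x, train_data, eps = train_hmc(flags)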
Example #2
def short_training(
        train_steps: int,
        beta: float,
        log_dir: str,
        dynamics: GaugeDynamics,
        x: tf.Tensor = None,
):
    """Perform a brief training run prior to running inference."""
    ckpt_dir = os.path.join(log_dir, 'training', 'checkpoints')
    ckpt = tf.train.Checkpoint(dynamics=dynamics, optimizer=dynamics.optimizer)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=5)
    current_step = 0
    if manager.latest_checkpoint:
        io.log(f'Restored model from: {manager.latest_checkpoint}')
        ckpt.restore(manager.latest_checkpoint)
        current_step = dynamics.optimizer.iterations.numpy()

    if x is None:
        x = convert_to_angle(tf.random.normal(dynamics.x_shape))

    train_data = DataContainer(current_step+train_steps, print_steps=1)

    dynamics.compile(loss=dynamics.calc_losses,
                     optimizer=dynamics.optimizer,
                     experimental_run_tf_function=False)

    x, metrics = dynamics.train_step((x, tf.constant(beta)))

    header = train_data.get_header(metrics, skip=SKEYS,
                                   prepend=['{:^12s}'.format('step')])
    io.log(header.split('\n'))
    for step in range(current_step, current_step + train_steps):
        start = time.time()
        x, metrics = dynamics.train_step((x, tf.constant(beta)))
        metrics.dt = time.time() - start
        train_data.update(step, metrics)
        data_str = train_data.print_metrics(metrics)
        logger.info(data_str)

        # logger.print_metrics(metrics)
        #  data_str = train_data.get_fstr(step, metrics, skip=SKEYS)
        #  io.log(data_str)

    return dynamics, train_data, x
Example #3
def run_profiler(dynamics: GaugeDynamics,
                 inputs: tuple[tf.Tensor, Union[float, tf.Tensor]],
                 logdir: str,
                 steps: int = 10):
    logger.debug(f'Running {steps} profiling steps!')
    x, beta = inputs
    beta = tf.constant(beta)
    metrics = None
    for _ in range(steps):
        x, metrics = dynamics.train_step((x, beta))

    tf.profiler.experimental.start(logdir=logdir, options=OPTIONS)
    for step in range(steps):
        with tf.profiler.experimental.Trace('train', step_num=step, _r=1):
            x, metrics = dynamics.train_step((x, beta))

    tf.profiler.experimental.stop(save=True)
    logger.debug('Done!')

    return x, metrics
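A hedged usage sketch: run_profiler warms up with `steps` untraced updates, then records `steps` traced ones between tf.profiler.experimental.start/stop. The `dynamics` object and the logdir path below are assumptions, not taken from the source.

# Hypothetical call; `dynamics` is an already-built GaugeDynamics from this module.
x0 = tf.random.uniform(dynamics.x_shape, minval=-PI, maxval=PI)
x, metrics = run_profiler(dynamics, (x0, 1.0), logdir='/tmp/l2hmc_profile', steps=10)
# The resulting trace can then be inspected in TensorBoard's Profile tab:
#   tensorboard --logdir=/tmp/l2hmc_profile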
Example #4
def run_md(
    dynamics: GaugeDynamics,
    inputs: tuple,
    md_steps: int,
):
    x, beta = inputs
    beta = tf.constant(beta)
    logger.debug(f'Running {md_steps} MD updates!')
    for _ in range(md_steps):
        mc_states, _ = dynamics.md_update((x, beta), training=True)
        x = mc_states.out.x  # type: tf.Tensor
    logger.debug('Done!')

    return x
Example #5
def trace_train_step(
    dynamics: GaugeDynamics,
    writer: SummaryWriter,
    outdir: str,
    x: tf.Tensor = None,
    beta: float = None,
    graph: bool = True,
    profiler: bool = True,
):
    if x is None:
        x = tf.random.uniform(dynamics.x_shape, minval=-PI, maxval=PI)
    if beta is None:
        beta = 1.

    # Bracket the function call with
    # tf.summary.trace_on() and tf.summary.trace_export()
    tf.summary.trace_on(graph=graph, profiler=profiler)
    # Call only one tf.function when tracing
    x, metrics = dynamics.train_step((x, beta))
    with writer.as_default():
        tf.summary.trace_export(name='dynamics_train_step',
                                step=0,
                                profiler_outdir=outdir)
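tf.summary.trace_export needs an active default writer, which trace_train_step obtains via writer.as_default(); a writer can be created with tf.summary.create_file_writer. A minimal sketch, assuming `dynamics` exists and using a placeholder output directory:

# Hypothetical invocation; `dynamics` is assumed to be constructed elsewhere.
outdir = '/tmp/l2hmc_trace'                       # placeholder output directory
writer = tf.summary.create_file_writer(outdir)    # standard tf.summary writer
trace_train_step(dynamics, writer, outdir, graph=True, profiler=True)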
Example #6
def run_dynamics(
        dynamics: GaugeDynamics,
        flags: AttrDict,
        x: tf.Tensor = None,
        save_x: bool = False,
        md_steps: int = 0,
) -> tuple[DataContainer, tf.Tensor, list]:
    """Run inference on trained dynamics."""
    if not IS_CHIEF:
        return None, None, None

    # Setup
    print_steps = flags.get('print_steps', 5)
    beta = flags.get('beta', flags.get('beta_final', None))

    test_step = dynamics.test_step
    if flags.get('compile', True):
        test_step = tf.function(dynamics.test_step)
        io.log('Compiled `dynamics.test_step` using tf.function!')

    if x is None:
        x = tf.random.uniform(shape=dynamics.x_shape,
                              minval=-PI, maxval=PI,
                              dtype=TF_FLOAT)

    run_data = DataContainer(flags.run_steps)

    template = '\n'.join([f'beta: {beta}',
                          f'eps: {dynamics.eps.numpy():.4g}',
                          f'net_weights: {dynamics.net_weights}'])
    io.log(f'Running inference with:\n {template}')

    # Run `md_steps` MD updates (w/o accept/reject) to ensure chains don't get stuck
    if md_steps > 0:
        for _ in range(md_steps):
            mc_states, _ = dynamics.md_update((x, beta), training=False)
            x = mc_states.out.x

    try:
        x, metrics = test_step((x, tf.constant(beta)))
    except Exception as exception:  # pylint:disable=broad-except
        io.log(f'Exception: {exception}')
        test_step = dynamics.test_step
        x, metrics = test_step((x, tf.constant(beta)))

    header = run_data.get_header(metrics,
                                 skip=['charges'],
                                 prepend=['{:^12s}'.format('step')])
    #  io.log(header)
    io.log(header.split('\n'), should_print=True)
    # -------------------------------------------------------------

    x_arr = []

    def timed_step(x: tf.Tensor, beta: tf.Tensor):
        start = time.time()
        x, metrics = test_step((x, tf.constant(beta)))
        metrics.dt = time.time() - start
        if save_x:
            x_arr.append(x.numpy())

        return x, metrics

    steps = tf.range(flags.run_steps, dtype=tf.int64)
    if NUM_NODES == 1:
        ctup = (CBARS['red'], CBARS['green'], CBARS['red'], CBARS['reset'])
        steps = tqdm(steps, desc='running', unit='step',
                     bar_format=("%s{l_bar}%s{bar}%s{r_bar}%s" % ctup))

    for step in steps:
        x, metrics = timed_step(x, beta)
        run_data.update(step, metrics)

        if step % print_steps == 0:
            summarize_dict(metrics, step, prefix='testing')
            data_str = run_data.get_fstr(step, metrics, skip=['charges'])
            io.log(data_str, should_print=True)

        if (step + 1) % 1000 == 0:
            io.log(header, should_print=True)

    return run_data, x, x_arr
Example #7
def run_dynamics(
        dynamics: GaugeDynamics,
        flags: dict[str, Any],
        writer: tf.summary.SummaryWriter = None,
        x: tf.Tensor = None,
        beta: float = None,
        save_x: bool = False,
        md_steps: int = 0,
        # window: int = 0,
        #  should_track: bool = False,
) -> InferenceResults:
    """Run inference on trained dynamics."""
    if not IS_CHIEF:
        return InferenceResults(None, None, None, None, None)

    # -- Setup -----------------------------
    print_steps = flags.get('print_steps', 5)
    if beta is None:
        beta = flags.get('beta', flags.get('beta_final', None))  # type: float
        if beta is None:
            logger.warning('beta unspecified! setting to 1')
            beta = 1.
        assert beta is not None and isinstance(beta, float)

    test_step = dynamics.test_step
    if flags.get('compile', True):
        test_step = tf.function(dynamics.test_step)
        io.log('Compiled `dynamics.test_step` using tf.function!')

    if x is None:
        x = tf.random.uniform(shape=dynamics.x_shape,
                              minval=-PI, maxval=PI,
                              dtype=TF_FLOAT)
    assert tf.is_tensor(x)

    run_steps = flags.get('run_steps', 20000)
    run_data = DataContainer(run_steps)

    template = '\n'.join([f'beta={beta}',
                          f'net_weights={dynamics.net_weights}'])
    logger.info(f'Running inference with {template}')

    # Run `md_steps` MD updates (w/o accept/reject)
    # to ensure chains don't get stuck
    if md_steps > 0:
        for _ in range(md_steps):
            mc_states, _ = dynamics.md_update((x, beta), training=False)
            x = mc_states.out.x

    try:
        x, metrics = test_step((x, tf.constant(beta)))
    except Exception as err:  # pylint:disable=broad-except
        logger.warning(err)
        #  io.log(f'Exception: {exception}')
        test_step = dynamics.test_step
        x, metrics = test_step((x, tf.constant(beta)))

    x_arr = []

    def timed_step(x: tf.Tensor, beta: tf.Tensor):
        start = time.time()
        x, metrics = test_step((x, tf.constant(beta)))
        metrics.dt = time.time() - start
        if 'sin_charges' not in metrics:
            charges = dynamics.lattice.calc_both_charges(x=x)
            metrics['charges'] = charges.intQ
            metrics['sin_charges'] = charges.sinQ
        if save_x:
            x_arr.append(x.numpy())

        return x, metrics

    summary_steps = max(run_steps // 100, 50)

    if writer is not None:
        writer.set_as_default()

    steps = tf.range(run_steps, dtype=tf.int64)
    keep_ = ['step', 'dt', 'loss', 'accept_prob', 'beta',
             'dq_int', 'dq_sin', 'dQint', 'dQsin', 'plaqs', 'p4x4']

    beta = tf.constant(beta, dtype=TF_FLOAT)  # type: tf.Tensor
    data_strs = []
    for idx, step in enumerate(steps):
        x, metrics = timed_step(x, beta)
        run_data.update(step, metrics)  # update data after every accept/reject

        if step % summary_steps == 0:
            update_summaries(step, metrics, dynamics)
            # summarize_dict(metrics, step, prefix='testing')

        if step % print_steps == 0:
            pre = [f'{step}/{steps[-1]}']
            ms = run_data.print_metrics(metrics,
                                        pre=pre, keep=keep_)
            data_strs.append(ms)

    return InferenceResults(dynamics=dynamics, x=x, x_arr=x_arr,
                            run_data=run_data, data_strs=data_strs)
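A minimal call sketch for this version of run_dynamics: only run_steps, print_steps, beta, and compile are read from `flags`, so a plain dict suffices; the values and the trained `dynamics` below are assumptions.

# Hypothetical inference call on an already-trained GaugeDynamics.
flags = {'run_steps': 1000, 'print_steps': 5, 'beta': 4.0, 'compile': True}
results = run_dynamics(dynamics, flags, writer=None, save_x=False, md_steps=10)
run_data, x_final = results.run_data, results.x   # fields of InferenceResults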
Example #8
def train_dynamics(
    dynamics: GaugeDynamics,
    inputs: dict[str, Any],
    dirs: dict[str, str] = None,
    x: tf.Tensor = None,
    steps_dict: dict[str, int] = None,
    save_metrics: bool = True,
    custom_betas: Union[list, np.ndarray] = None,
    window: int = 0,
) -> tuple[tf.Tensor, DataContainer]:
    """Train model."""
    configs = inputs['configs']
    steps = configs.get('steps', [])
    min_lr = configs.get('min_lr', 1e-5)
    patience = configs.get('patience', 10)
    factor = configs.get('reduce_lr_factor', 0.5)

    save_steps = configs.get('save_steps', 10000)  # type: int
    print_steps = configs.get('print_steps', 1000)  # type: int
    logging_steps = configs.get('logging_steps', 500)  # type: int
    steps_per_epoch = configs.get('steps_per_epoch', 1000)  # type: int
    if steps_dict is not None:
        save_steps = steps_dict.get('save', 10000)  # type: int
        print_steps = steps_dict.get('print', 1000)  # type: int
        logging_steps = steps_dict.get('logging_steps', 500)  # type: int
        steps_per_epoch = steps_dict.get('steps_per_epoch', 1000)  # type: int

    # -- Helper functions for training, logging, saving, etc. --------------
    #  step_times = []
    timer = StepTimer(evals_per_step=dynamics.config.num_steps)

    def train_step(x: tf.Tensor, beta: tf.Tensor):
        #  start = time.time()
        timer.start()
        x, metrics = dynamics.train_step((x, tf.constant(beta)))
        dt = timer.stop()
        metrics.dt = dt
        return x, metrics

    def should_print(step: int) -> bool:
        return IS_CHIEF and step % print_steps == 0

    def should_log(step: int) -> bool:
        return IS_CHIEF and step % logging_steps == 0

    def should_save(step: int) -> bool:
        return step % save_steps == 0 and ckpt is not None

    xshape = dynamics._xshape
    xr = tf.random.uniform(xshape, -PI, PI)
    x = inputs.get('x', xr) if x is None else x
    assert x is not None

    if custom_betas is None:
        betas = np.array(inputs.get('betas', None))
        assert betas is not None and betas.shape[0] > 0

        steps = np.array(inputs.get('steps'))
        assert steps is not None and steps.shape[0] > 0
    else:
        betas = np.array(custom_betas)
        start = dynamics.optimizer.iterations
        nsteps = len(betas)
        steps = np.arange(start, start + nsteps)

    dirs = inputs.get('dirs', None) if dirs is None else dirs  # type: dict
    assert dirs is not None

    manager = inputs['manager']  # type: tf.train.CheckpointManager
    ckpt = inputs['checkpoint']  # type: tf.train.Checkpoint
    train_data = inputs['train_data']  # type: DataContainer

    #  tf.compat.v1.autograph.experimental.do_not_convert(dynamics.train_step)

    # -- Setup dynamic learning rate schedule -----------------
    assert dynamics.lr_config is not None
    warmup_steps = dynamics.lr_config.warmup_steps
    reduce_lr = ReduceLROnPlateau(monitor='loss',
                                  mode='min',
                                  warmup_steps=warmup_steps,
                                  factor=factor,
                                  min_lr=min_lr,
                                  verbose=1,
                                  patience=patience)
    reduce_lr.set_model(dynamics)

    # -- Setup summary writer -----------
    writer = inputs.get('writer', None)  # type: tf.summary.SummaryWriter
    if IS_CHIEF and writer is not None:
        writer.set_as_default()

    # -- Run profiler? ----------------------------------------
    if configs.get('profiler', False):
        if RANK == 0:
            sdir = dirs['summary_dir']
            #  trace_train_step(dynamics,
            #                   graph=True,
            #                   profiler=True,
            #                   outdir=sdir,
            #                   writer=writer)
            x, metrics = run_profiler(dynamics, (x, betas[0]),
                                      logdir=sdir,
                                      steps=5)
    else:
        x, metrics = dynamics.train_step((x, betas[0]))

    # -- Run MD update to not get stuck ----------------------
    md_steps = configs.get('md_steps', 0)
    if md_steps > 0:
        x = run_md(dynamics, (x, betas[0]), md_steps)

    warmup_steps = dynamics.lr_config.warmup_steps
    total_steps = steps[-1]
    if len(steps) != len(betas):
        betas = betas[steps[0]:]

    keep = [
        'dt', 'loss', 'accept_prob', 'beta', 'Hwb_start', 'Hwf_start',
        'Hwb_mid', 'Hwf_mid', 'Hwb_end', 'Hwf_end', 'xeps', 'veps', 'dq',
        'dq_sin', 'plaqs', 'p4x4', 'charges', 'sin_charges'
    ]
    plots = {}
    if in_notebook():
        plots = plotter.init_plots(configs, figsize=(9, 3), dpi=125)

    # -- Training loop ---------------------------------------------------
    data_strs = []
    logdir = dirs['log_dir']
    data_dir = dirs['data_dir']
    logfile = os.path.join(logdir, 'training', 'train_log.txt')

    assert x is not None
    assert manager is not None
    assert len(steps) == len(betas)
    for step, beta in zip(steps, betas):
        x, metrics = train_step(x, beta)

        # ----------------------------------------------------------------
        # TODO: Run inference when beta hits an integer
        # >>> beta_inf = {i: False for i in np.arange(beta_final)}
        # >>> if any(np.isclose(beta, np.array(list(beta_inf.keys())))):
        # >>>     run_inference(...)
        # ----------------------------------------------------------------

        if (step + 1) > warmup_steps and (step + 1) % steps_per_epoch == 0:
            reduce_lr.on_epoch_end(step + 1, {'loss': metrics.loss})

        # -- Save checkpoints and dump configs `x` from each rank --------
        if should_save(step + 1):
            train_data.update(step, metrics)
            train_data.dump_configs(x,
                                    data_dir,
                                    rank=RANK,
                                    local_rank=LOCAL_RANK)
            if IS_CHIEF:
                _ = timer.save_and_write(logdir, mode='w')
                # -- Save CheckpointManager ------------------------------
                manager.save()
                mstr = f'Checkpoint saved to: {manager.latest_checkpoint}'
                logger.info(mstr)
                with open(logfile, 'w') as f:
                    f.writelines('\n'.join(data_strs))

                # -- Save train_data and free consumed memory ------------
                train_data.save_and_flush(data_dir,
                                          logfile,
                                          rank=RANK,
                                          mode='a')
                if not dynamics.config.hmc:
                    # -- Save network weights ----------------------------
                    dynamics.save_networks(logdir)
                    logger.info(f'Networks saved to: {logdir}')

        # -- Print current training state and metrics -------------------
        if should_print(step):
            train_data.update(step, metrics)
            keep_ = [
                'step', 'dt', 'loss', 'accept_prob', 'beta', 'dq_int',
                'dq_sin', 'dQint', 'dQsin', 'plaqs', 'p4x4'
            ]
            pre = [f'{step:>4g}/{total_steps:<4g}']
            #  data_str = logger.print_metrics(metrics, window=50,
            #                                  pre=pre, keep=keep_)
            data_str = train_data.print_metrics(metrics,
                                                window=window,
                                                pre=pre,
                                                keep=keep_)
            data_strs.append(data_str)

        if in_notebook() and step % PLOT_STEPS == 0 and IS_CHIEF:
            train_data.update(step, metrics)
            if len(train_data.data.keys()) == 0:
                update_plots(metrics,
                             plots,
                             logging_steps=configs['logging_steps'])
            else:
                update_plots(train_data.data,
                             plots,
                             logging_steps=configs['logging_steps'])

        # -- Update summary objects ---------------------
        if should_log(step):
            train_data.update(step, metrics)
            if writer is not None:
                update_summaries(step, metrics, dynamics)
                writer.flush()

    # -- Dump config objects -------------------------------------------------
    train_data.dump_configs(x, data_dir, rank=RANK, local_rank=LOCAL_RANK)
    if IS_CHIEF:
        manager.save()
        logger.info(f'Checkpoint saved to: {manager.latest_checkpoint}')

        with open(logfile, 'w') as f:
            f.writelines('\n'.join(data_strs))

        if save_metrics:
            train_data.save_and_flush(data_dir, logfile, rank=RANK, mode='a')

        if not dynamics.config.hmc:
            dynamics.save_networks(logdir)

        if writer is not None:
            writer.flush()
            writer.close()

        #  ngrad_evals =  SIZE * dynamics.config.num_steps * len(step_times)
        #  eval_rate = ngrad_evals / np.sum(step_times)
        #  outstr = '\n'.join([f'ngrad_evals: {ngrad_evals}',
        #                      f'sum(step_times): {np.sum(step_times)}',
        #                      f'eval rate: {eval_rate}'])
        #  with open(Path(logdir).joinpath('eval_rate.txt'), 'a') as f:
        #      f.write(outstr)
        #
        #  csvfile = Path(logdir).joinpath('dt_train.csv')
        #  pd.DataFrame(step_times).to_csv(csvfile, mode='a')

    return x, train_data
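The `inputs` dictionary is built outside this function; here is a sketch of the keys train_dynamics actually reads, inferred from the accesses in the body above. All names on the right-hand side are assumed to exist already and are placeholders.

# Keys inferred from the accesses inside train_dynamics; values are placeholders.
inputs = {
    'configs': configs,        # dict read via .get(): steps, save_steps, print_steps, md_steps, ...
    'betas': betas,            # per-step beta schedule (array-like), same length as 'steps'
    'steps': steps,            # global training-step indices
    'x': x_init,               # optional initial configuration (tf.Tensor)
    'dirs': dirs,              # must provide 'log_dir', 'data_dir', 'log_file', 'summary_dir'
    'manager': manager,        # tf.train.CheckpointManager
    'checkpoint': ckpt,        # tf.train.Checkpoint
    'train_data': train_data,  # DataContainer instance
    'writer': writer,          # optional tf.summary.SummaryWriter
}
x_out, train_data = train_dynamics(dynamics, inputs, dirs=dirs)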