Example 1
def save_config(self, config_dir: str):
    """Helper method for saving configuration objects."""
    io.save_dict(self.config, config_dir, name='dynamics_config')
    io.save_dict(self.net_config, config_dir, name='network_config')
    io.save_dict(self.lr_config, config_dir, name='lr_config')
    io.save_dict(self.params, config_dir, name='dynamics_params')
    if self.conv_config is not None and self.config.use_conv_net:
        io.save_dict(self.conv_config, config_dir, name='conv_config')
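A minimal usage sketch for the method above (the path is illustrative; `build_dynamics` and `io.check_else_make_dir` appear in the later examples):

# Persist every configuration dict for a freshly built dynamics object,
# so a later run can rebuild it from disk.
dynamics = build_dynamics(configs)                         # see Examples 3 and 5
config_dir = os.path.join(logdir, 'training', 'configs')   # illustrative path
io.check_else_make_dir(config_dir)                         # create dir if missing
dynamics.save_config(config_dir)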
Example 2
def setup_directories(configs: dict) -> dict:
    """Setup directories for training."""
    ensure_new = configs.get('ensure_new', False)
    logdir = configs.get('logdir', configs.get('log_dir', None))
    if logdir is not None and ensure_new:
        # NOTE: `os.listdir` raises if `logdir` does not exist, so check first.
        logdir_nonempty = os.path.isdir(logdir) and len(os.listdir(logdir)) > 0
        if logdir_nonempty:
            raise ValueError(
                f'Nonempty `logdir`, but `ensure_new={ensure_new}`')

    # Create `logdir`, `logdir/training/...`, etc.
    dirs = io.setup_directories(configs, timestamps=TSTAMPS)
    configs['dirs'] = dirs
    logdir = dirs.get('logdir', dirs.get('log_dir', None))  # type: str
    configs['log_dir'] = logdir
    configs['logdir'] = logdir

    restore_dir = configs.get('restore_from', None)
    if restore_dir is None and not ensure_new:
        candidate = look_for_previous_logdir(logdir)
        if candidate is not None and str(candidate) != str(logdir):
            if candidate.is_dir():
                nckpts = len(list(candidate.rglob('checkpoint')))
                if nckpts > 0:
                    restore_dir = candidate
                    configs['restore_from'] = restore_dir

    if RANK == 0:
        io.check_else_make_dir(logdir)
        if restore_dir is not None and not ensure_new:
            try:
                restored = load_configs_from_logdir(restore_dir)
                if restored is not None:
                    io.save_dict(restored, logdir,
                                 name='restored_train_configs')
            except FileNotFoundError:
                logger.warning(f'Unable to load configs from {restore_dir}')

    return configs
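A hedged usage sketch of `setup_directories`; the keys shown are the ones the function reads, the values are made up:

# Resolve and create the log-directory tree; if a previous run with
# saved checkpoints is found, its path is recorded under 'restore_from'.
configs = {
    'logdir': './logs/run0',   # 'log_dir' is accepted as an alias
    'ensure_new': False,       # True raises on a nonempty logdir
}
configs = setup_directories(configs)
print(configs['dirs'], configs.get('restore_from'))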
Example 3
def train(flags: AttrDict, x: tf.Tensor = None, restore_x: bool = False):
    """Train model.

    Returns:
        x (tf.Tensor): Batch of configurations
        dynamics (GaugeDynamics): Dynamics object.
        train_data (DataContainer): Object containing train data.
        flags (AttrDict): AttrDict containing flags used.
    """
    dirs = io.setup_directories(flags)
    flags.update({'dirs': dirs})

    if restore_x:
        x = None
        xfile = os.path.join(dirs.train_dir, 'train_data',
                             f'x_rank{RANK}-{LOCAL_RANK}.z')
        try:
            x = io.loadz(xfile)
        except FileNotFoundError:
            io.log(f'Unable to restore x from {xfile}. Using random init.')

    if x is None:
        x = tf.random.normal(flags.dynamics_config['lattice_shape'])
        x = tf.reshape(x, (x.shape[0], -1))

    dynamics = build_dynamics(flags)
    dynamics.save_config(dirs.config_dir)

    io.log('\n'.join([120 * '*', 'Training L2HMC sampler...']))
    x, train_data = train_dynamics(dynamics, flags, dirs, x=x)

    if IS_CHIEF:
        output_dir = os.path.join(dirs.train_dir, 'outputs')
        train_data.save_data(output_dir)

        params = {
            'beta_init': train_data.data.beta[0],
            'beta_final': train_data.data.beta[-1],
            'eps': dynamics.eps.numpy(),
            'lattice_shape': dynamics.config.lattice_shape,
            'num_steps': dynamics.config.num_steps,
            'net_weights': dynamics.net_weights,
        }
        plot_data(train_data,
                  dirs.train_dir,
                  flags,
                  thermalize=True,
                  params=params)

    io.log('\n'.join(['Done training model', 120 * '*']))
    io.save_dict(dict(flags), dirs.log_dir, 'configs')

    return x, dynamics, train_data, flags
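The random-init branch above flattens each lattice configuration into one vector per chain; a self-contained sketch of that step (the shape is illustrative, not from a real config):

import tensorflow as tf

# E.g. (batch, Lt, Lx, 2): two link directions per site on a 2D lattice.
lattice_shape = (4, 8, 8, 2)
x = tf.random.normal(lattice_shape)
x = tf.reshape(x, (x.shape[0], -1))   # -> (4, 128): one flat vector per chain
print(x.shape)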
Example 4
def run(
        dynamics: GaugeDynamics,
        configs: dict[str, Any],
        x: tf.Tensor = None,
        beta: float = None,
        runs_dir: str = None,
        make_plots: bool = True,
        therm_frac: float = 0.33,
        num_chains: int = 16,
        save_x: bool = False,
        md_steps: int = 50,
        skip_existing: bool = False,
        save_dataset: bool = True,
        use_hdf5: bool = False,
        skip_keys: list[str] = None,
        run_steps: int = None,
) -> InferenceResults:
    """Run inference. (Note: Higher-level than `run_dynamics`)."""
    if not IS_CHIEF:
        return InferenceResults(None, None, None, None, None)

    if num_chains > 16:
        logger.warning(f'Reducing `num_chains` from {num_chains} to 16.')
        num_chains = 16

    if run_steps is None:
        run_steps = configs.get('run_steps', 50000)

    if beta is None:
        beta = configs.get('beta_final', configs.get('beta', None))

    assert beta is not None

    logdir = configs.get('log_dir', configs.get('logdir', None))
    if runs_dir is None:
        rs = 'inference_hmc' if dynamics.config.hmc else 'inference'
        runs_dir = os.path.join(logdir, rs)

    io.check_else_make_dir(runs_dir)
    run_dir = io.make_run_dir(configs=configs, base_dir=runs_dir,
                              beta=beta, skip_existing=skip_existing)
    logger.info(f'run_dir: {run_dir}')
    data_dir = os.path.join(run_dir, 'run_data')
    summary_dir = os.path.join(run_dir, 'summaries')
    io.check_else_make_dir([run_dir, data_dir, summary_dir])
    writer = tf.summary.create_file_writer(summary_dir)
    writer.set_as_default()

    configs['logging_steps'] = 1
    if x is None:
        x = convert_to_angle(tf.random.uniform(shape=dynamics.x_shape,
                                               minval=-PI, maxval=PI))

    # == RUN DYNAMICS =======================================================
    nw = dynamics.net_weights
    inf_type = 'HMC' if dynamics.config.hmc else 'inference'
    logger.info(', '.join([f'Running {inf_type}', f'beta={beta}', f'nw={nw}']))
    t0 = time.time()
    results = run_dynamics(dynamics, flags=configs, beta=beta, x=x,
                           writer=writer, save_x=save_x, md_steps=md_steps)
    logger.info(f'Done running {inf_type}. Took: {time.time() - t0:.4f} s')
    # =======================================================================

    run_data = results.run_data
    run_data.update_dirs({'log_dir': logdir, 'run_dir': run_dir})
    run_params = {
        'hmc': dynamics.config.hmc,
        'run_dir': run_dir,
        'beta': beta,
        'run_steps': run_steps,
        'plaq_weight': dynamics.plaq_weight,
        'charge_weight': dynamics.charge_weight,
        'x_shape': dynamics.x_shape,
        'num_steps': dynamics.config.num_steps,
        'net_weights': dynamics.net_weights,
        'input_shape': dynamics.x_shape,
    }

    inf_log_fpath = os.path.join(run_dir, 'inference_log.txt')
    with open(inf_log_fpath, 'a') as f:
        f.write('\n'.join(results.data_strs))

    traj_len = None
    if hasattr(dynamics, 'xeps') and hasattr(dynamics, 'veps'):
        # Separable, per-leapfrog-step step sizes.
        xeps_avg = tf.reduce_mean(dynamics.xeps)
        veps_avg = tf.reduce_mean(dynamics.veps)
        traj_len = tf.reduce_sum(dynamics.xeps)
        run_params.update({
            'xeps': dynamics.xeps,
            'traj_len': traj_len,
            'veps': dynamics.veps,
            'xeps_avg': xeps_avg,
            'veps_avg': veps_avg,
            'eps_avg': (xeps_avg + veps_avg) / 2.,
        })

    elif hasattr(dynamics, 'eps'):
        # Single fixed step size.
        traj_len = dynamics.config.num_steps * dynamics.eps
        run_params.update({
            'eps': dynamics.eps,
            'traj_len': traj_len,
            'eps_avg': dynamics.eps,
        })

    io.save_params(run_params, run_dir, name='run_params')

    save_times = {}
    plot_times = {}
    if make_plots:
        output = plot_data(data_container=run_data,
                           configs=configs,
                           params=run_params,
                           out_dir=run_dir,
                           hmc=dynamics.config.hmc,
                           therm_frac=therm_frac,
                           num_chains=num_chains,
                           profile=True,
                           cmap='crest',
                           logging_steps=1)

        save_times = io.SortedDict(**output['save_times'])
        plot_times = io.SortedDict(**output['plot_times'])
        dt1 = io.SortedDict(**plot_times['data_container.plot_dataset'])
        plot_times['data_container.plot_dataset'] = dt1

        tint_data = {
            'beta': beta,
            'run_dir': run_dir,
            'traj_len': traj_len,
            'run_params': run_params,
            'eps': run_params.get('eps_avg'),
            'lf': run_params['num_steps'],
            'narr': output['tint_dict']['narr'],
            'tint': output['tint_dict']['tint'],
        }

        t0 = time.time()
        tint_file = os.path.join(run_dir, 'tint_data.z')
        io.savez(tint_data, tint_file, 'tint_data')
        save_times['savez_tint_data'] = time.time() - t0

    t0 = time.time()

    logfile = os.path.join(run_dir, 'inference.log')
    run_data.save_and_flush(data_dir=data_dir,
                            log_file=logfile,
                            use_hdf5=use_hdf5,
                            skip_keys=skip_keys,
                            save_dataset=save_dataset)

    save_times['run_data.save_and_flush'] = time.time() - t0

    t0 = time.time()
    try:
        run_data.write_to_csv(logdir, run_dir, hmc=dynamics.config.hmc)
    except TypeError:
        logger.warning('Unable to write to csv. Continuing...')

    save_times['run_data.write_to_csv'] = time.time() - t0

    profdir = os.path.join(run_dir, 'profile_info')
    io.check_else_make_dir(profdir)
    io.save_dict(plot_times, profdir, name='plot_times')
    io.save_dict(save_times, profdir, name='save_times')

    return InferenceResults(dynamics=results.dynamics,
                            x=results.x, x_arr=results.x_arr,
                            run_data=results.run_data,
                            data_strs=results.data_strs)
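The step-size bookkeeping in `run` reduces to a few tensor reductions; a minimal sketch of the separable case with made-up values:

import tensorflow as tf

# One step size per leapfrog step for the x- and v-updates.
xeps = tf.constant([0.10, 0.12, 0.11])
veps = tf.constant([0.09, 0.11, 0.10])
traj_len = tf.reduce_sum(xeps)                                # total trajectory length
eps_avg = (tf.reduce_mean(xeps) + tf.reduce_mean(veps)) / 2.  # ~0.105
print(float(traj_len), float(eps_avg))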
Example 5
def run_inference_from_log_dir(
        log_dir: str,
        run_steps: int = 50000,
        beta: float = None,
        eps: float = None,
        make_plots: bool = True,
        train_steps: int = 10,
        therm_frac: float = 0.33,
        batch_size: int = 16,
        num_chains: int = 16,
        x: tf.Tensor = None,
) -> InferenceResults:
    """Run inference by loading networks in from `log_dir`."""
    configs = _find_configs(log_dir)
    if configs is None:
        raise FileNotFoundError(
            f'Unable to load configs from `log_dir`: {log_dir}. Exiting'
        )

    if eps is None:
        try:
            eps_file = os.path.join(log_dir, 'training', 'models', 'eps.z')
            eps = io.loadz(eps_file)
        except FileNotFoundError:
            eps = configs.get('dynamics_config', {}).get('eps', None)

    if eps is not None:
        configs['dynamics_config']['eps'] = eps

    if beta is not None:
        configs.update({'beta': beta, 'beta_final': beta})

    if batch_size is not None:
        batch_size = int(batch_size)
        prev_shape = configs['dynamics_config']['x_shape']
        new_shape = (batch_size, *prev_shape[1:])
        configs['dynamics_config']['x_shape'] = new_shape

    configs = AttrDict(configs)
    dynamics = build_dynamics(configs)
    xnet, vnet = dynamics._load_networks(log_dir)
    dynamics.xnet = xnet
    dynamics.vnet = vnet

    if train_steps > 0:
        dynamics, _, x = short_training(train_steps, configs.beta_final,
                                        log_dir, dynamics, x=x)
    else:
        dynamics.compile(loss=dynamics.calc_losses,
                         optimizer=dynamics.optimizer,
                         experimental_run_tf_function=False)

    if x is None:
        x = convert_to_angle(tf.random.normal(dynamics.x_shape))

    configs['run_steps'] = run_steps
    configs['print_steps'] = max(run_steps // 100, 1)
    configs['md_steps'] = 100
    runs_dir = os.path.join(log_dir, 'LOADED', 'inference')
    io.check_else_make_dir(runs_dir)
    io.save_dict(configs, runs_dir, name='inference_configs')
    inference_results = run(dynamics=dynamics,
                            configs=configs, x=x, beta=beta,
                            runs_dir=runs_dir, make_plots=make_plots,
                            therm_frac=therm_frac, num_chains=num_chains)

    return inference_results
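Overriding `batch_size` above only rewrites the leading entry of `x_shape`; a tiny sketch (shape values are made up):

prev_shape = (64, 8, 8, 2)          # (batch, Lt, Lx, dims)
batch_size = 16
new_shape = (batch_size, *prev_shape[1:])
assert new_shape == (16, 8, 8, 2)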
Example 6
def train(
    configs: dict[str, Any],
    x: tf.Tensor = None,
    num_chains: int = 32,
    make_plots: bool = True,
    steps_dict: dict[str, int] = None,
    skip_keys: list[str] = None,
    save_metrics: bool = True,
    save_dataset: bool = False,
    use_hdf5: bool = False,
    custom_betas: Union[list, np.ndarray] = None,
    **kwargs,
) -> TrainOutputs:
    """Train model.

    Returns:
        train_outputs: Dataclass with attributes:
          - x: tf.Tensor
          - logdir: str
          - configs: dict[str, Any]
          - data: DataContainer
          - dynamics: GaugeDynamics
    """
    start = time.time()
    configs = setup_directories(configs)
    try_restore = kwargs.pop('try_restore', True)
    config = setup(configs, x=x, try_restore=try_restore)
    dynamics = config['dynamics']
    dirs = config['dirs']
    configs = config['configs']
    train_data = config['train_data']

    dynamics.save_config(dirs['config_dir'])
    if RANK == 0:
        logfile = os.path.join(os.getcwd(), 'log_dirs.txt')
        logdir = configs.get('logdir', configs.get('log_dir', None))  # str
        io.save_dict(configs, logdir, name='train_configs')
        io.write(f'{logdir}', logfile, 'a')

        restore_from = configs.get('restore_from', None)
        if restore_from is not None:
            restored = load_configs_from_logdir(restore_from)
            if restored is not None:
                io.save_dict(restored, logdir, name='restored_train_configs')

    # -- Train dynamics -----------------------------------------
    logger.info(f'Starting training at: {io.get_timestamp("%x %X")}')
    t0 = time.time()
    x, train_data = train_dynamics(dynamics,
                                   config,
                                   dirs,
                                   x=x,
                                   steps_dict=steps_dict,
                                   custom_betas=custom_betas,
                                   save_metrics=save_metrics)
    logger.info(f'Done training. Took: {time.time() - t0:.4f}s')

    # ------------------------------------

    if IS_CHIEF:
        logdir = dirs['logdir']
        train_dir = dirs['train_dir']
        io.save_dict(dict(configs), dirs['log_dir'], 'configs')
        train_data.save_and_flush(
            dirs['data_dir'],
            log_file=dirs['log_file'],
            use_hdf5=use_hdf5,
            save_dataset=save_dataset,
            skip_keys=skip_keys)
        logger.info(f'Done training model! took: {time.time() - start:.4f}s')

        b0 = configs.get('beta_init', None)
        b1 = configs.get('beta_final', None)
        if make_plots:
            params = {
                'beta_init': train_data.data.get('beta', [b0])[0],
                'beta_final': train_data.data.get('beta', [b1])[-1],
                'x_shape': dynamics.config.x_shape,
                'num_steps': dynamics.config.num_steps,
                'net_weights': dynamics.net_weights,
            }
            t0 = time.time()
            logging_steps = configs.get('logging_steps', None)  # type: int
            _ = plot_data(data_container=train_data,
                          configs=configs,
                          params=params,
                          out_dir=train_dir,
                          therm_frac=0,
                          cmap='flare',
                          num_chains=num_chains,
                          logging_steps=logging_steps,
                          **kwargs)
            dt = time.time() - t0
            logger.debug(
                f'Time spent plotting: {dt:.4f}s = {int(dt // 60)}m {dt % 60:.4f}s')

        if not dynamics.config.hmc:
            dynamics.save_networks(logdir)

    return TrainOutputs(x, dirs['log_dir'], configs, train_data, dynamics)
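A hedged end-to-end sketch of this entry point; the keys shown are read by `train` and its helpers above, but a real run also needs the full dynamics/network configs:

# Illustrative top-level call (values are made up).
configs = {
    'logdir': './logs/u1_run',
    'ensure_new': False,
    'beta_init': 4.0,
    'beta_final': 5.0,
    'logging_steps': 10,
}
out = train(configs, num_chains=32, make_plots=False)
print(out.logdir)   # where checkpoints, configs, and plots land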