コード例 #1
0
def train_model():
  """Train a mean-field variational inference (MFVI) model.

  All configuration is read from the module-level `args` object. The function
  builds the data/model functions, re-parameterizes the network with
  `vi.get_mfvi_model_fn`, optionally overrides the variational mean from
  `args.mean_init_checkpoint`, and then runs `args.num_epochs` training
  epochs. Each epoch may evaluate the mean network and the VI ensemble, save a
  checkpoint, write scalars and a histogram of the variational standard
  deviations to TensorBoard, and print a table row.

  Returns:
    None. Results are written to the training directory and to stdout.
  """
  # Initialize training directory
  dirname, tf_writer = get_dirname_tfwriter(args)

  # Initialize data, model, losses and metrics
  (train_set, test_set, net_apply, params, net_state, key, log_likelihood_fn, _,
   _, predict_fn, ensemble_upd_fn, metrics_fns,
   tabulate_metrics) = script_utils.get_data_model_fns(args)

  # Convert the model to MFVI parameterization
  net_apply, mean_apply, _, params, net_state = vi.get_mfvi_model_fn(
      net_apply, params, net_state, seed=0, sigma_init=args.vi_sigma_init)
  prior_kl = vi.make_kl_with_gaussian_prior(args.weight_decay, args.temperature)
  vi_ensemble_predict_fn = make_vi_ensemble_predict_fn(predict_fn,
                                                       ensemble_upd_fn, args)

  # Initialize step-size schedule and optimizer
  num_batches, total_steps = script_utils.get_num_batches_total_steps(
      args, train_set)
  num_devices = len(jax.devices())
  lr_schedule = optim_utils.make_cosine_lr_schedule(args.init_step_size,
                                                    total_steps)
  optimizer = get_optimizer(lr_schedule, args)

  # Initialize variables
  opt_state = optimizer.init(params)
  # Produce one copy of net_state per device.
  net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
  key = jax.random.split(key, num_devices)
  # The default dict uses iteration -1 so that, after the increment below,
  # a run that is not resumed starts at iteration 0.
  init_dict = checkpoint_utils.make_sgd_checkpoint_dict(-1, params, net_state,
                                                        opt_state, key)
  init_dict = script_utils.get_initialization_dict(dirname, args, init_dict)
  start_iteration, params, net_state, opt_state, key = (
      checkpoint_utils.parse_sgd_checkpoint_dict(init_dict))
  start_iteration += 1

  # Loading mean checkpoint
  # NOTE(review): this override also runs when the run was resumed from a
  # saved checkpoint, which would replace the restored means -- confirm that
  # this is intended.
  if args.mean_init_checkpoint:
    print("Initializing VI mean from the provided checkpoint")
    ckpt_dict = checkpoint_utils.load_checkpoint(args.mean_init_checkpoint)
    mean_params = checkpoint_utils.parse_sgd_checkpoint_dict(ckpt_dict)[1]
    params["mean"] = mean_params

  # Define train epoch
  sgd_train_epoch = script_utils.time_fn(
      train_utils.make_sgd_train_epoch(net_apply, log_likelihood_fn, prior_kl,
                                       optimizer, num_batches))

  # Train
  for iteration in range(start_iteration, args.num_epochs):
    is_last_epoch = iteration == args.num_epochs - 1

    (params, net_state, opt_state, elbo_avg, key), iteration_time = (
        sgd_train_epoch(params, net_state, opt_state, train_set, key))

    # Evaluate the model
    train_stats = {"ELBO": elbo_avg, "KL": prior_kl(params)}
    test_stats, ensemble_stats = {}, {}
    if iteration % args.eval_freq == 0 or is_last_epoch:
      # Evaluate the mean
      _, _, _, test_stats, mean_train_stats = (
          script_utils.evaluate(mean_apply, params, net_state, train_set,
                                test_set, predict_fn, metrics_fns, prior_kl))
      train_stats.update(mean_train_stats)
      # The KL term is already reported under "KL"; drop the "prior" entry
      # (presumably added by script_utils.evaluate) if it is present.
      train_stats.pop("prior", None)

      # Evaluate the ensemble
      # NOTE(review): onp.asarray is applied to the returned pair before it is
      # unpacked -- confirm this is the intended conversion.
      net_state, ensemble_predictions = onp.asarray(
          vi_ensemble_predict_fn(net_apply, params, net_state, test_set))
      ensemble_stats = train_utils.evaluate_metrics(ensemble_predictions,
                                                    test_set[1], metrics_fns)

    # Save checkpoint
    if iteration % args.save_freq == 0 or is_last_epoch:
      checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
      checkpoint_path = os.path.join(dirname, checkpoint_name)
      checkpoint_dict = checkpoint_utils.make_sgd_checkpoint_dict(
          iteration, params, net_state, opt_state, key)
      checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

    # Log results
    # The step size is needed for both TensorBoard and the printed table, so
    # evaluate the schedule once per epoch.
    step_size = lr_schedule(opt_state[-1].count)
    other_logs = script_utils.get_common_logs(iteration, iteration_time, args)
    other_logs["hypers/step_size"] = step_size
    other_logs["hypers/momentum"] = args.momentum_decay
    logging_dict = logging_utils.make_logging_dict(train_stats, test_stats,
                                                   ensemble_stats)
    logging_dict.update(other_logs)
    script_utils.write_to_tensorboard(tf_writer, logging_dict, iteration)
    # Add a histogram of MFVI stds
    with tf_writer.as_default():
      # jax.tree_util is used instead of the top-level jax.tree_map /
      # jax.tree_leaves aliases, which newer JAX releases no longer provide.
      stds = jax.tree_util.tree_map(jax.nn.softplus,
                                    params["inv_softplus_std"])
      stds = jnp.concatenate(
          [std.reshape(-1) for std in jax.tree_util.tree_leaves(stds)])
      tf.summary.histogram("MFVI/param_stds", stds, step=iteration)

    tabulate_dict = script_utils.get_tabulate_dict(tabulate_metrics,
                                                   logging_dict)
    tabulate_dict["lr"] = step_size
    table = logging_utils.make_table(tabulate_dict, iteration - start_iteration,
                                     args.tabulate_freq)
    print(table)
コード例 #2
0
def train_model():
    """Train a network with the SGD optimizer and log progress.

    Configuration comes from the module-level `args` object. The function sets
    up the output directory, data/model functions, the step-size schedule and
    optimizer, restores state through `script_utils.get_initialization_dict`,
    and then runs epochs up to `args.num_epochs`. Evaluation and checkpointing
    happen every `args.eval_freq` / `args.save_freq` epochs and on the final
    epoch; every epoch is logged to TensorBoard and printed as a table row.

    Returns:
        None.
    """
    # Output directory and TensorBoard writer
    dirname, tf_writer = get_dirname_tfwriter(args)

    # Data, network and task-specific functions
    (train_set, test_set, net_apply, params, net_state, key, log_likelihood_fn,
     log_prior_fn, _, predict_fn, ensemble_upd_fn, metrics_fns,
     tabulate_metrics) = script_utils.get_data_model_fns(args)

    # Step-size schedule and optimizer
    num_batches, total_steps = script_utils.get_num_batches_total_steps(
        args, train_set)
    num_devices = len(jax.devices())
    lr_schedule = optim_utils.make_cosine_lr_schedule(
        args.init_step_size, total_steps)
    optimizer = optim_utils.make_sgd_optimizer(
        lr_schedule, momentum_decay=args.momentum_decay)

    # Training state: optimizer state, per-device net_state copies and keys
    opt_state = optimizer.init(params)
    net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    key = jax.random.split(key, num_devices)

    # The default dict carries iteration -1, so a run that is not resumed
    # starts at iteration 0 after the increment below.
    default_dict = checkpoint_utils.make_sgd_checkpoint_dict(
        -1, params, net_state, opt_state, key)
    restored_dict = script_utils.get_initialization_dict(
        dirname, args, default_dict)
    start_iteration, params, net_state, opt_state, key = (
        checkpoint_utils.parse_sgd_checkpoint_dict(restored_dict))
    start_iteration += 1

    # Timed single-epoch training function
    epoch_fn = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)
    sgd_train_epoch = script_utils.time_fn(epoch_fn)

    # Main loop
    for iteration in range(start_iteration, args.num_epochs):
        is_final = iteration == args.num_epochs - 1

        epoch_result, iteration_time = sgd_train_epoch(
            params, net_state, opt_state, train_set, key)
        params, net_state, opt_state, logprob_avg, key = epoch_result

        # Evaluation (periodic, and always on the final epoch)
        train_stats = {"log_prob": logprob_avg}
        test_stats = {}
        if iteration % args.eval_freq == 0 or is_final:
            _, _, _, test_stats, eval_train_stats = script_utils.evaluate(
                net_apply, params, net_state, train_set, test_set, predict_fn,
                metrics_fns, log_prior_fn)
            train_stats.update(eval_train_stats)

        # Checkpointing (periodic, and always on the final epoch)
        if iteration % args.save_freq == 0 or is_final:
            checkpoint_path = os.path.join(
                dirname, checkpoint_utils.make_checkpoint_name(iteration))
            checkpoint_utils.save_checkpoint(
                checkpoint_path,
                checkpoint_utils.make_sgd_checkpoint_dict(
                    iteration, params, net_state, opt_state, key))

        # TensorBoard logging
        other_logs = script_utils.get_common_logs(
            iteration, iteration_time, args)
        other_logs["hypers/step_size"] = lr_schedule(opt_state[-1].count)
        other_logs["hypers/momentum"] = args.momentum_decay
        logging_dict = logging_utils.make_logging_dict(
            train_stats, test_stats, {})
        logging_dict.update(other_logs)
        script_utils.write_to_tensorboard(tf_writer, logging_dict, iteration)

        # Console table
        tabulate_dict = script_utils.get_tabulate_dict(
            tabulate_metrics, logging_dict)
        tabulate_dict["lr"] = lr_schedule(opt_state[-1].count)
        rows_printed = iteration - start_iteration
        print(logging_utils.make_table(
            tabulate_dict, rows_printed, args.tabulate_freq))
コード例 #3
0
def train_model():
    """Run SGMCMC training with ensembling of test predictions.

    Configuration comes from the module-level `args` object. The optimizer is
    built with `sgmcmc.sgld_gradient_update` using the schedule from
    `get_lr_schedule` and the preconditioner from `get_preconditioner`. For
    each epoch, `is_eval_ens_save_epoch` decides whether to evaluate, to fold
    the current test predictions into the running ensemble, and to save a
    checkpoint. Every epoch is logged to TensorBoard and printed as a table
    row.

    Returns:
        None.
    """
    # Output directory and TensorBoard writer
    dirname, tf_writer = get_dirname_tfwriter(args)

    # Data, network and task-specific functions
    (train_set, test_set, net_apply, params, net_state, key, log_likelihood_fn,
     log_prior_fn, _, predict_fn, ensemble_upd_fn, metrics_fns,
     tabulate_metrics) = script_utils.get_data_model_fns(args)

    # Step-size schedule, preconditioner and optimizer
    num_batches, total_steps = script_utils.get_num_batches_total_steps(
        args, train_set)
    num_devices = len(jax.devices())
    lr_schedule = get_lr_schedule(num_batches, args)
    preconditioner = get_preconditioner(args)
    optimizer = sgmcmc.sgld_gradient_update(
        lr_schedule,
        momentum_decay=args.momentum_decay,
        seed=args.seed,
        preconditioner=preconditioner)

    # Training state: optimizer state, per-device net_state copies and keys
    opt_state = optimizer.init(params)
    net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    key = jax.random.split(key, num_devices)

    # The default dict carries iteration -1, zero ensembled samples and no
    # stored predictions; a run that is not resumed starts at iteration 0
    # after the increment below.
    default_dict = checkpoint_utils.make_sgmcmc_checkpoint_dict(
        -1, params, net_state, opt_state, key, 0, None, None)
    restored_dict = script_utils.get_initialization_dict(
        dirname, args, default_dict)
    (start_iteration, params, net_state, opt_state, key, num_ensembled, _,
     ensemble_predictions) = checkpoint_utils.parse_sgmcmc_checkpoint_dict(
         restored_dict)
    start_iteration += 1

    # Timed single-epoch training function
    epoch_fn = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)
    sgmcmc_train_epoch = script_utils.time_fn(epoch_fn)

    # Main loop
    for iteration in range(start_iteration, args.num_epochs):
        epoch_result, iteration_time = sgmcmc_train_epoch(
            params, net_state, opt_state, train_set, key)
        params, net_state, opt_state, logprob_avg, key = epoch_result

        is_evaluation_epoch, is_ensembling_epoch, is_save_epoch = (
            is_eval_ens_save_epoch(iteration, args))

        # Evaluation; ensembling epochs need the test predictions as well
        train_stats = {"log_prob": logprob_avg}
        test_stats = {}
        if is_evaluation_epoch or is_ensembling_epoch:
            (_, test_predictions, train_predictions, test_stats,
             eval_train_stats) = script_utils.evaluate(
                 net_apply, params, net_state, train_set, test_set,
                 predict_fn, metrics_fns, log_prior_fn)
            train_stats.update(eval_train_stats)

        # Ensembling: off ensembling epochs there are no ensemble stats and
        # no test predictions are carried forward to the checkpoint.
        if not is_ensembling_epoch:
            ensemble_stats = {}
            test_predictions = None
        else:
            ensemble_predictions = ensemble_upd_fn(
                ensemble_predictions, num_ensembled, test_predictions)
            ensemble_stats = train_utils.evaluate_metrics(
                ensemble_predictions, test_set[1], metrics_fns)
            num_ensembled += 1

        # Checkpointing
        if is_save_epoch:
            checkpoint_path = os.path.join(
                dirname, checkpoint_utils.make_checkpoint_name(iteration))
            checkpoint_utils.save_checkpoint(
                checkpoint_path,
                checkpoint_utils.make_sgmcmc_checkpoint_dict(
                    iteration, params, net_state, opt_state, key,
                    num_ensembled, test_predictions, ensemble_predictions))

        # TensorBoard logging
        other_logs = script_utils.get_common_logs(
            iteration, iteration_time, args)
        other_logs["hypers/step_size"] = lr_schedule(opt_state.count)
        other_logs["hypers/momentum"] = args.momentum_decay
        other_logs["telemetry/num_ensembled"] = num_ensembled
        logging_dict = logging_utils.make_logging_dict(
            train_stats, test_stats, ensemble_stats)
        logging_dict.update(other_logs)
        script_utils.write_to_tensorboard(tf_writer, logging_dict, iteration)

        # Console table
        tabulate_dict = script_utils.get_tabulate_dict(
            tabulate_metrics, logging_dict)
        tabulate_dict["lr"] = lr_schedule(opt_state.count)
        rows_printed = iteration - start_iteration
        print(logging_utils.make_table(
            tabulate_dict, rows_printed, args.tabulate_freq))
コード例 #4
0
ファイル: run_hmc.py プロジェクト: yan0626/google-research
def train_model():
    """Run HMC sampling for a neural network and log every iteration.

    Configuration comes from the module-level `args` object. The function
    creates a run directory whose name encodes the main hyper-parameters,
    loads the full-batch dataset, initializes or restores the chain state, and
    then performs `args.num_iterations` HMC updates. During the first
    `args.num_burn_in_iterations` iterations the step size is interpolated
    linearly from `args.step_size` to
    `args.burn_in_step_size_factor * args.step_size` and the
    Metropolis-Hastings correction is disabled. Every iteration is evaluated,
    checkpointed, written to TensorBoard and printed as a table row.

    Returns:
        None.

    Raises:
        ValueError: if `checkpoint_utils.initialize` reports an unknown status.
        AssertionError: if parameters, log-probability or gradients do not
            have the dtype selected by `args.use_float64`.
    """
    subdirname = (
        "model_{}_wd_{}_stepsize_{}_trajlen_{}_burnin_{}_{}_mh_{}_temp_{}_"
        "seed_{}".format(args.model_name, args.weight_decay, args.step_size,
                         args.trajectory_len, args.num_burn_in_iterations,
                         args.burn_in_step_size_factor, not args.no_mh,
                         args.temperature, args.seed))
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(dirname)
    cmd_args_utils.save_cmd(dirname, tf_writer)
    num_devices = len(jax.devices())

    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
        args.dataset_name, dtype, truncate_to=args.subset_train_to)

    net_apply, net_init = models.get_model(args.model_name, data_info)
    net_apply = precision_utils.rewrite_high_precision(net_apply)

    checkpoint_dict, status = checkpoint_utils.initialize(
        dirname, args.init_checkpoint)

    if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
        print("Continuing the run from the last saved checkpoint")
        # NOTE(review): unlike the other training scripts, start_iteration is
        # not advanced past the restored iteration here. If the checkpoint
        # stores the index of the iteration that was just completed, that
        # iteration is repeated on resume -- confirm against
        # checkpoint_utils.parse_hmc_checkpoint_dict.
        (start_iteration, params, net_state, key, step_size, _, num_ensembled,
         ensemble_predictions) = (
             checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))

    else:
        # Fresh chain state. The key split is done once here; the random-init
        # branch below reuses net_init_key rather than recomputing the same
        # split from the same seed.
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        start_iteration = 0
        num_ensembled = 0
        ensemble_predictions = None
        step_size = args.step_size

        if status == checkpoint_utils.InitStatus.INIT_CKPT:
            print("Resuming the run from the provided init_checkpoint")
            # Only the parameters and network state are taken from the
            # provided checkpoint; the remaining fields are discarded.
            _, params, net_state, _, _, _, _, _ = (
                checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
        elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
            print("Starting from random initialization with provided seed")
            # Index 0 along the leading axis, then a length-1 slice, of every
            # leaf of the training set (presumably: first device shard, first
            # example -- TODO confirm).
            init_data = jax.tree_util.tree_map(lambda elem: elem[0][:1],
                                               train_set)
            params, net_state = net_init(net_init_key, init_data, True)
            # Produce one copy of net_state per device.
            net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        else:
            raise ValueError(
                "Unknown initialization status: {}".format(status))

    # manually convert all params to dtype
    # jax.tree_util.tree_map is used instead of the top-level jax.tree_map
    # alias, which newer JAX releases no longer provide.
    params = jax.tree_util.tree_map(lambda p: p.astype(dtype), params)

    param_types = tree_utils.tree_get_types(params)
    assert all(p_type == dtype for p_type in param_types), (
        "Params data types {} do not match specified data type {}".format(
            param_types, dtype))

    trajectory_len = args.trajectory_len

    (likelihood_factory, predict_fn, ensemble_upd_fn, metrics_fns,
     tabulate_metrics) = train_utils.get_task_specific_fns(task, data_info)
    log_likelihood_fn = likelihood_factory(args.temperature)
    log_prior_fn, log_prior_diff_fn = losses.make_gaussian_log_prior(
        args.weight_decay, args.temperature)

    update, get_log_prob_and_grad = train_utils.make_hmc_update(
        net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn,
        args.max_num_leapfrog_steps, args.target_accept_rate,
        args.step_size_adaptation_speed)

    # Log-probability, gradient and likelihood at the starting point; the
    # likelihood and gradient are passed into the first update call.
    log_prob, state_grad, log_likelihood, net_state = (get_log_prob_and_grad(
        train_set, params, net_state))

    assert log_prob.dtype == dtype, (
        "log_prob data type {} does not match specified data type {}".format(
            log_prob.dtype, dtype))

    grad_types = tree_utils.tree_get_types(state_grad)
    assert all(g_type == dtype for g_type in grad_types), (
        "Gradient data types {} do not match specified data type {}".format(
            grad_types, dtype))

    for iteration in range(start_iteration, args.num_iterations):

        in_burnin = iteration < args.num_burn_in_iterations
        # do a linear ramp-down of the step-size in the burn-in phase
        if in_burnin:
            # With a single burn-in iteration `num_burn_in_iterations - 1` is
            # zero; clamp the denominator so alpha is 0 (initial step size)
            # instead of raising ZeroDivisionError.
            ramp_steps = max(args.num_burn_in_iterations - 1, 1)
            alpha = iteration / ramp_steps
            initial_step_size = args.step_size
            final_step_size = args.burn_in_step_size_factor * args.step_size
            step_size = (final_step_size * alpha +
                         initial_step_size * (1 - alpha))
        do_mh_correction = (not args.no_mh) and (not in_burnin)

        start_time = time.time()
        (params, net_state, log_likelihood, state_grad, step_size, key,
         accept_prob, accepted) = (update(train_set, params, net_state,
                                          log_likelihood, state_grad, key,
                                          step_size, trajectory_len,
                                          do_mh_correction))
        iteration_time = time.time() - start_time

        # Evaluation
        # NOTE(review): onp.asarray is applied to the returned pair before it
        # is unpacked -- confirm this is the intended conversion.
        net_state, test_predictions = onp.asarray(
            predict_fn(net_apply, params, net_state, test_set))
        net_state, train_predictions = onp.asarray(
            predict_fn(net_apply, params, net_state, train_set))
        test_stats = train_utils.evaluate_metrics(test_predictions,
                                                  test_set[1], metrics_fns)
        train_stats = train_utils.evaluate_metrics(train_predictions,
                                                   train_set[1], metrics_fns)
        train_stats["prior"] = log_prior_fn(params)

        # Ensembling
        # A sample is added when it was accepted after burn-in, or on every
        # iteration when the MH correction is disabled via args.no_mh.
        if ((not in_burnin) and accepted) or args.no_mh:
            ensemble_predictions = ensemble_upd_fn(ensemble_predictions,
                                                   num_ensembled,
                                                   test_predictions)
            ensemble_stats = train_utils.evaluate_metrics(
                ensemble_predictions, test_set[1], metrics_fns)
            num_ensembled += 1
        else:
            ensemble_stats = {}

        # Save the checkpoint
        checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
        checkpoint_path = os.path.join(dirname, checkpoint_name)
        checkpoint_dict = checkpoint_utils.make_hmc_checkpoint_dict(
            iteration, params, net_state, key, step_size, accepted,
            num_ensembled, ensemble_predictions)
        checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

        # Logging
        other_logs = {
            "telemetry/iteration": iteration,
            "telemetry/iteration_time": iteration_time,
            "telemetry/accept_prob": accept_prob,
            "telemetry/accepted": accepted,
            "telemetry/num_ensembled": num_ensembled,
            "hypers/step_size": step_size,
            "hypers/trajectory_len": trajectory_len,
            "hypers/weight_decay": args.weight_decay,
            "hypers/temperature": args.temperature,
            "debug/do_mh_correction": float(do_mh_correction),
            "debug/in_burnin": float(in_burnin)
        }
        logging_dict = logging_utils.make_logging_dict(train_stats, test_stats,
                                                       ensemble_stats)
        logging_dict.update(other_logs)

        with tf_writer.as_default():
            for stat_name, stat_val in logging_dict.items():
                tf.summary.scalar(stat_name, stat_val, step=iteration)

        tabulate_dict = OrderedDict()
        tabulate_dict["i"] = iteration
        tabulate_dict["t"] = iteration_time
        tabulate_dict["accept_p"] = accept_prob
        tabulate_dict["accepted"] = accepted
        for metric_name in tabulate_metrics:
            # Metrics that were not logged this iteration show up as None.
            tabulate_dict[metric_name] = logging_dict.get(metric_name)

        table = logging_utils.make_table(tabulate_dict,
                                         iteration - start_iteration,
                                         args.tabulate_freq)
        print(table)