예제 #1
0
def get_data_model_fns(args):
    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
        args.dataset_name, dtype, truncate_to=args.subset_train_to)

    net_apply, net_init = models.get_model(args.model_name, data_info)
    net_apply = precision_utils.rewrite_high_precision(net_apply)

    (likelihood_factory, predict_fn, ensemble_upd_fn, metrics_fns,
     tabulate_metrics) = train_utils.get_task_specific_fns(task, data_info)
    log_likelihood_fn = likelihood_factory(args.temperature)
    log_prior_fn, log_prior_diff_fn = losses.make_gaussian_log_prior(
        args.weight_decay, args.temperature)

    key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
    init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
    params, net_state = net_init(net_init_key, init_data, True)

    param_types = tree_utils.tree_get_types(params)
    assert all([
        p_type == dtype for p_type in param_types
    ]), ("Params data types {} do not match specified data type {}".format(
        param_types, dtype))

    return (train_set, test_set, net_apply, params, net_state, key,
            log_likelihood_fn, log_prior_fn, log_prior_diff_fn, predict_fn,
            ensemble_upd_fn, metrics_fns, tabulate_metrics)
def run_visualization():

  subdirname = "posterior_visualization"
  dirname = os.path.join(args.dir, subdirname)
  os.makedirs(dirname, exist_ok=True)
  cmd_args_utils.save_cmd(dirname, None)

  dtype = jnp.float64 if args.use_float64 else jnp.float32
  train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
      args.dataset_name, dtype)

  net_apply, net_init = models.get_model(args.model_name, data_info)
  net_apply = precision_utils.rewrite_high_precision(net_apply)
  init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
  net_init_key = jax.random.PRNGKey(0)
  params, net_state = net_init(net_init_key, init_data, True)

  (likelihood_factory, predict_fn, ensemble_upd_fn, metrics_fns,
   tabulate_metrics) = train_utils.get_task_specific_fns(task, data_info)
  log_likelihood_fn = likelihood_factory(args.temperature)
  log_prior_fn, log_prior_diff_fn = losses.make_gaussian_log_prior(
      args.weight_decay, args.temperature)

  def eval(params, net_state, dataset):
    likelihood, _ = log_likelihood_fn(net_apply, params, net_state, dataset,
                                      True)
    prior = log_prior_fn(params)
    likelihood = jax.lax.psum(likelihood, axis_name="i")
    log_prob = likelihood + prior
    return log_prob, likelihood, prior

  params1 = load_params(args.checkpoint1)
  params2 = load_params(args.checkpoint2)
  params3 = load_params(args.checkpoint3)

  # for params in [params1, params2, params3]:
  #   print(jax.pmap(eval, axis_name='i', in_axes=(None, None, 0))
  #         (params, net_state, train_set))

  u_vec, u_norm, v_vec, v_norm, origin = get_u_v_o(params1, params2, params3)

  u_ts = onp.linspace(args.limit_left, args.limit_right, args.grid_size)
  v_ts = onp.linspace(args.limit_bottom, args.limit_top, args.grid_size)
  n_u, n_v = len(u_ts), len(v_ts)
  log_probs = onp.zeros((n_u, n_v))
  log_likelihoods = onp.zeros((n_u, n_v))
  log_priors = onp.zeros((n_u, n_v))
  grid = onp.zeros((n_u, n_v, 2))

  @functools.partial(jax.pmap, axis_name="i", in_axes=(None, 0))
  def eval_row_of_plot(u_t_, dataset):

    def loop_body(_, v_t_):
      params = jax.tree_multimap(
          lambda u, v, o: o + u * u_t_ * u_norm + v * v_t_ * v_norm, u_vec,
          v_vec, origin)
      logprob, likelihood, prior = eval(params, net_state, dataset)
      arr = jnp.array([logprob, likelihood, prior])
      return None, arr

    _, vals = jax.lax.scan(loop_body, None, v_ts)
    row_logprobs, row_likelihoods, row_priors = jnp.split(vals, [1, 2], axis=1)
    return row_logprobs, row_likelihoods, row_priors

  for u_i, u_t in enumerate(tqdm.tqdm(u_ts)):
    log_probs_i, likelihoods_i, priors_i = eval_row_of_plot(u_t, train_set)
    log_probs_i, likelihoods_i, priors_i = map(
        lambda arr: arr[0], [log_probs_i, likelihoods_i, priors_i])
    log_probs[u_i] = log_probs_i[:, 0]
    log_likelihoods[u_i] = likelihoods_i[:, 0]
    log_priors[u_i] = priors_i[:, 0]
    grid[u_i, :, 0] = onp.array([u_t] * n_v)
    grid[u_i, :, 1] = v_ts

  onp.savez(
      os.path.join(dirname, "surface_plot.npz"),
      log_probs=log_probs,
      log_priors=log_priors,
      log_likelihoods=log_likelihoods,
      grid=grid,
      u_norm=u_norm,
      v_norm=v_norm)

  plt.contour(grid[:, :, 0], grid[:, :, 1], log_probs, zorder=1)
  plt.contourf(grid[:, :, 0], grid[:, :, 1], log_probs, zorder=0, alpha=0.55)
  plt.plot([0., 1., 0.5], [0., 0., 1.], "ro", ms=20, markeredgecolor="k")
  plt.colorbar()
  plt.savefig(os.path.join(dirname, "log_prob.pdf"), bbox_inches="tight")
예제 #3
0
def train_model():

    subdirname = (
        "model_{}_wd_{}_stepsize_{}_trajlen_{}_burnin_{}_{}_mh_{}_temp_{}_"
        "seed_{}".format(args.model_name, args.weight_decay, args.step_size,
                         args.trajectory_len, args.num_burn_in_iterations,
                         args.burn_in_step_size_factor, not args.no_mh,
                         args.temperature, args.seed))
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(dirname)
    cmd_args_utils.save_cmd(dirname, tf_writer)
    num_devices = len(jax.devices())

    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
        args.dataset_name, dtype, truncate_to=args.subset_train_to)

    net_apply, net_init = models.get_model(args.model_name, data_info)
    net_apply = precision_utils.rewrite_high_precision(net_apply)

    checkpoint_dict, status = checkpoint_utils.initialize(
        dirname, args.init_checkpoint)

    if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
        print("Continuing the run from the last saved checkpoint")
        (start_iteration, params, net_state, key, step_size, _, num_ensembled,
         ensemble_predictions) = (
             checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))

    else:
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        start_iteration = 0
        num_ensembled = 0
        ensemble_predictions = None
        step_size = args.step_size

        if status == checkpoint_utils.InitStatus.INIT_CKPT:
            print("Resuming the run from the provided init_checkpoint")
            _, params, net_state, _, _, _, _, _ = (
                checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
        elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
            print("Starting from random initialization with provided seed")
            key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed),
                                                 2)
            init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
            params, net_state = net_init(net_init_key, init_data, True)
            net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        else:
            raise ValueError(
                "Unknown initialization status: {}".format(status))

    # manually convert all params to dtype
    params = jax.tree_map(lambda p: p.astype(dtype), params)

    param_types = tree_utils.tree_get_types(params)
    assert all([
        p_type == dtype for p_type in param_types
    ]), ("Params data types {} do not match specified data type {}".format(
        param_types, dtype))

    trajectory_len = args.trajectory_len

    (likelihood_factory, predict_fn, ensemble_upd_fn, metrics_fns,
     tabulate_metrics) = train_utils.get_task_specific_fns(task, data_info)
    log_likelihood_fn = likelihood_factory(args.temperature)
    log_prior_fn, log_prior_diff_fn = losses.make_gaussian_log_prior(
        args.weight_decay, args.temperature)

    update, get_log_prob_and_grad = train_utils.make_hmc_update(
        net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn,
        args.max_num_leapfrog_steps, args.target_accept_rate,
        args.step_size_adaptation_speed)

    log_prob, state_grad, log_likelihood, net_state = (get_log_prob_and_grad(
        train_set, params, net_state))

    assert log_prob.dtype == dtype, (
        "log_prob data type {} does not match specified data type {}".format(
            log_prob.dtype, dtype))

    grad_types = tree_utils.tree_get_types(state_grad)
    assert all([
        g_type == dtype for g_type in grad_types
    ]), ("Gradient data types {} do not match specified data type {}".format(
        grad_types, dtype))

    for iteration in range(start_iteration, args.num_iterations):

        # do a linear ramp-down of the step-size in the burn-in phase
        if iteration < args.num_burn_in_iterations:
            alpha = iteration / (args.num_burn_in_iterations - 1)
            initial_step_size = args.step_size
            final_step_size = args.burn_in_step_size_factor * args.step_size
            step_size = final_step_size * alpha + initial_step_size * (1 -
                                                                       alpha)
        in_burnin = (iteration < args.num_burn_in_iterations)
        do_mh_correction = (not args.no_mh) and (not in_burnin)

        start_time = time.time()
        (params, net_state, log_likelihood, state_grad, step_size, key,
         accept_prob, accepted) = (update(train_set, params, net_state,
                                          log_likelihood, state_grad, key,
                                          step_size, trajectory_len,
                                          do_mh_correction))
        iteration_time = time.time() - start_time

        # Evaluation
        net_state, test_predictions = onp.asarray(
            predict_fn(net_apply, params, net_state, test_set))
        net_state, train_predictions = onp.asarray(
            predict_fn(net_apply, params, net_state, train_set))
        test_stats = train_utils.evaluate_metrics(test_predictions,
                                                  test_set[1], metrics_fns)
        train_stats = train_utils.evaluate_metrics(train_predictions,
                                                   train_set[1], metrics_fns)
        train_stats["prior"] = log_prior_fn(params)

        # Ensembling
        if ((not in_burnin) and accepted) or args.no_mh:
            ensemble_predictions = ensemble_upd_fn(ensemble_predictions,
                                                   num_ensembled,
                                                   test_predictions)
            ensemble_stats = train_utils.evaluate_metrics(
                ensemble_predictions, test_set[1], metrics_fns)
            num_ensembled += 1
        else:
            ensemble_stats = {}

        # Save the checkpoint
        checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
        checkpoint_path = os.path.join(dirname, checkpoint_name)
        checkpoint_dict = checkpoint_utils.make_hmc_checkpoint_dict(
            iteration, params, net_state, key, step_size, accepted,
            num_ensembled, ensemble_predictions)
        checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

        # Logging
        other_logs = {
            "telemetry/iteration": iteration,
            "telemetry/iteration_time": iteration_time,
            "telemetry/accept_prob": accept_prob,
            "telemetry/accepted": accepted,
            "telemetry/num_ensembled": num_ensembled,
            "hypers/step_size": step_size,
            "hypers/trajectory_len": trajectory_len,
            "hypers/weight_decay": args.weight_decay,
            "hypers/temperature": args.temperature,
            "debug/do_mh_correction": float(do_mh_correction),
            "debug/in_burnin": float(in_burnin)
        }
        logging_dict = logging_utils.make_logging_dict(train_stats, test_stats,
                                                       ensemble_stats)
        logging_dict.update(other_logs)

        with tf_writer.as_default():
            for stat_name, stat_val in logging_dict.items():
                tf.summary.scalar(stat_name, stat_val, step=iteration)
        tabulate_dict = OrderedDict()
        tabulate_dict["i"] = iteration
        tabulate_dict["t"] = iteration_time
        tabulate_dict["accept_p"] = accept_prob
        tabulate_dict["accepted"] = accepted
        for metric_name in tabulate_metrics:
            if metric_name in logging_dict:
                tabulate_dict[metric_name] = logging_dict[metric_name]
            else:
                tabulate_dict[metric_name] = None

        table = logging_utils.make_table(tabulate_dict,
                                         iteration - start_iteration,
                                         args.tabulate_freq)
        print(table)