Beispiel #1
0
    def test_log_prob_fn(self):
        """Test the log-prob function constructed by train_utils.

        Checks that, on the full training set, the returned log-probability
        equals likelihood + prior, and that when the training set is split
        into two halves along axis 1 the likelihoods of the halves add up to
        the full likelihood while the prior is the same for every dataset.
        """
        weight_decay = 30.
        net, params, train_set, test_set, _, _ = self._prepare()
        log_likelihood_fn = nn_loss.xent_log_likelihood
        log_prior_fn, log_prior_diff = nn_loss.make_gaussian_log_prior(
            weight_decay=weight_decay)

        def get_log_prob(dataset):
            # Only the third function returned by the factory is under test.
            _, _, fn = train_utils.make_hmc_update_eval_fns(
                net, dataset, test_set, log_likelihood_fn, log_prior_fn,
                log_prior_diff)
            return fn

        log_prob, _, likelihood, prior = get_log_prob(train_set)(params)
        self.assertEqual(log_prob, likelihood + prior)

        # Split at the midpoint of axis 1. Using the full axis length here
        # would make the "first half" the whole dataset and the "second half"
        # empty, turning the additivity checks below into tautologies.
        n_split = train_set[0].shape[1] // 2
        first_half = jax.tree_map(lambda x: x[:, :n_split], train_set)
        second_half = jax.tree_map(lambda x: x[:, n_split:], train_set)

        log_prob_fh, _, likelihood_fh, prior_fh = (
            get_log_prob(first_half)(params))
        log_prob_sh, _, likelihood_sh, prior_sh = (
            get_log_prob(second_half)(params))

        # The halves are evaluated by separate calls, so their sum may differ
        # from the full-data value by floating-point rounding; compare with a
        # tolerance relative to the magnitude of the likelihood.
        tol = 1e-4 * max(1., abs(float(likelihood)))
        self.assertAlmostEqual(float(likelihood),
                               float(likelihood_fh + likelihood_sh),
                               delta=tol)
        # The prior is expected not to depend on the data, so it must match
        # exactly.
        self.assertEqual(prior, prior_fh)
        self.assertEqual(prior, prior_sh)
        # Each half's log-prob includes the prior once; subtract one copy.
        self.assertAlmostEqual(float(log_prob),
                               float(log_prob_fh + log_prob_sh - prior_sh),
                               delta=tol)
Beispiel #2
0
    def test_gaussian_prior_difference_precision(self):
        """Check the Gaussian log-prior difference against a NumPy reference.

        Two parameter sets are compared: the one returned by ``_prepare`` and
        a second one freshly initialized from a new key. The difference
        function returned by ``make_gaussian_log_prior`` must agree with the
        reference value to within an absolute error of 1e-2.
        """
        weight_decay = 30.
        net, params, _, _, init_data, init_key = self._prepare()
        (init_key,) = jax.random.split(init_key, 1)
        other_params = net.init(init_key, init_data)
        _, prior_diff_fn = nn_loss.make_gaussian_log_prior(
            weight_decay=weight_decay)

        def reference_prior_diff(tree_a, tree_b):
            """Return -0.5 * weight_decay * sum(a**2 - b**2) over all leaves."""
            total = 0
            leaf_pairs = zip(jax.tree_leaves(tree_a), jax.tree_leaves(tree_b))
            for leaf_a, leaf_b in leaf_pairs:
                squares_gap = onp.array(leaf_a)**2 - onp.array(leaf_b)**2
                total = total + onp.sum(squares_gap)
            return -0.5 * weight_decay * total

        actual = prior_diff_fn(params, other_params)
        expected = reference_prior_diff(params, other_params)
        self.assertLess(jnp.abs(expected - actual), 1e-2)
Beispiel #3
0
def train_model():
    """Train a network with the SGLD gradient update, logging every epoch.

    All configuration is read from the module-level ``args``. The function:

    * creates a run directory under ``args.dir`` whose name encodes the
      hyper-parameters, plus a TensorFlow summary writer for it;
    * builds the dataset, model, log-likelihood, log-prior, step-size
      schedule and the SGLD gradient update;
    * restores the state from a checkpoint or initializes it, depending on
      the status returned by ``checkpoint_utils.initialize``;
    * runs epochs from ``start_iteration`` up to ``args.num_epochs``, writing
      summaries, periodic checkpoints, periodic train/test evaluations and,
      after the burn-in epochs, ensemble metrics on the test set;
    * prints a progress table row after every epoch.

    Raises:
        ValueError: if the initialization status is not recognized.
    """
    name_template = (
        "sgld_wd_{}_stepsizes_{}_{}_batchsize_{}_epochs{}_{}_temp_{}_seed_{}")
    run_dir = os.path.join(
        args.dir,
        name_template.format(args.weight_decay, args.init_step_size,
                             args.final_step_size, args.batch_size,
                             args.num_epochs, args.num_burnin_epochs,
                             args.temperature, args.seed))
    os.makedirs(run_dir, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(run_dir)
    cmd_args_utils.save_cmd(run_dir, tf_writer)

    if args.use_float64:
        dtype = jnp.float64
    else:
        dtype = jnp.float32
    train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
        args.dataset_name, dtype)

    net_apply, net_init = models.get_model(args.model_name, num_classes)
    net_apply = precision_utils.rewrite_high_precision(net_apply)

    log_likelihood_fn = nn_loss.make_xent_log_likelihood(
        num_classes, args.temperature)
    log_prior_fn, _ = nn_loss.make_gaussian_log_prior(args.weight_decay,
                                                      args.temperature)

    num_batches = jnp.size(train_set[1]) // args.batch_size
    num_devices = len(jax.devices())

    # The schedule is parameterized by the number of burn-in steps.
    lr_schedule = train_utils.make_cosine_lr_schedule_with_burnin(
        args.init_step_size, args.final_step_size,
        num_batches * args.num_burnin_epochs)
    optimizer = sgmcmc.sgld_gradient_update(lr_schedule, args.seed)

    checkpoint_dict, status = checkpoint_utils.initialize(
        run_dir, args.init_checkpoint)
    init_status = checkpoint_utils.InitStatus

    if status == init_status.LOADED_PREEMPTED:
        # Everything, including the ensemble state, comes from the checkpoint.
        print("Continuing the run from the last saved checkpoint")
        restored = checkpoint_utils.parse_sgmcmc_checkpoint_dict(
            checkpoint_dict)
        (start_iteration, params, net_state, opt_state, key, num_ensembled,
         ensemble_predicted_probs) = restored
    else:
        # Fresh run: counters, ensemble state and RNG keys start from scratch.
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        start_iteration = 0
        num_ensembled = 0
        ensemble_predicted_probs = None
        key = jax.random.split(key, num_devices)

        if status == init_status.INIT_CKPT:
            # Only parameters and network state are taken from the checkpoint.
            print("Resuming the run from the provided init_checkpoint")
            restored = checkpoint_utils.parse_sgmcmc_checkpoint_dict(
                checkpoint_dict)
            _, params, net_state, _, _, _, _ = restored
            opt_state = optimizer.init(params)
        elif status == init_status.INIT_RANDOM:
            print("Starting from random initialization with provided seed")
            # Initialize on the first example of the first leading-axis slice.
            init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
            params, net_state = net_init(net_init_key, init_data, True)
            opt_state = optimizer.init(params)
            # Replicate the network state, one copy per device.
            net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        else:
            raise ValueError(
                "Unknown initialization status: {}".format(status))

    sgmcmc_train_epoch, evaluate = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)
    ensemble_acc = None

    param_types = tree_utils._get_types(params)
    assert all(p_type == dtype for p_type in param_types), (
        "Params data types {} do not match specified data type {}".format(
            param_types, dtype))

    last_epoch = args.num_epochs - 1
    for iteration in range(start_iteration, args.num_epochs):
        is_last = iteration == last_epoch

        tic = time.time()
        params, net_state, opt_state, logprob_avg, key = sgmcmc_train_epoch(
            params, net_state, opt_state, train_set, key)
        iteration_time = time.time() - tic
        step_size = lr_schedule(opt_state.count)

        # Evaluation columns stay None on epochs without an evaluation.
        tabulate_dict = OrderedDict([
            ("iteration", iteration),
            ("step_size", step_size),
            ("train_logprob", logprob_avg),
            ("train_acc", None),
            ("test_logprob", None),
            ("test_acc", None),
            ("ensemble_acc", ensemble_acc),
            ("n_ens", num_ensembled),
            ("time", iteration_time),
        ])

        with tf_writer.as_default():
            for tag, value in (
                ("train/log_prob_running", logprob_avg),
                ("hypers/step_size", step_size),
                ("debug/iteration_time", iteration_time),
            ):
                tf.summary.scalar(tag, value, step=iteration)

        if iteration % args.save_freq == 0 or is_last:
            checkpoint_path = os.path.join(
                run_dir, checkpoint_utils.make_checkpoint_name(iteration))
            checkpoint_dict = checkpoint_utils.make_sgmcmc_checkpoint_dict(
                iteration, params, net_state, opt_state, key, num_ensembled,
                ensemble_predicted_probs)
            checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

        if iteration % args.eval_freq == 0 or is_last:
            test_log_prob, test_acc, test_ce, _ = evaluate(
                params, net_state, test_set)
            train_log_prob, train_acc, train_ce, prior = evaluate(
                params, net_state, train_set)

            tabulate_dict["train_logprob"] = train_log_prob
            tabulate_dict["test_logprob"] = test_log_prob
            tabulate_dict["train_acc"] = train_acc
            tabulate_dict["test_acc"] = test_acc
            with tf_writer.as_default():
                for tag, value in (
                    ("train/log_prob", train_log_prob),
                    ("test/log_prob", test_log_prob),
                    ("train/log_likelihood", train_ce),
                    ("test/log_likelihood", test_ce),
                    ("train/accuracy", train_acc),
                    ("test/accuracy", test_acc),
                ):
                    tf.summary.scalar(tag, value, step=iteration)

        # After burn-in, fold the current sample into the ensemble every
        # ``args.ensemble_freq`` epochs.
        epochs_past_burnin = iteration - args.num_burnin_epochs
        if (iteration > args.num_burnin_epochs
                and epochs_past_burnin % args.ensemble_freq == 0):
            ensemble_predicted_probs, ensemble_acc, num_ensembled = (
                train_utils.update_ensemble(net_apply, params, net_state,
                                            test_set, num_ensembled,
                                            ensemble_predicted_probs))
            tabulate_dict["ensemble_acc"] = ensemble_acc
            tabulate_dict["n_ens"] = num_ensembled
            test_labels = onp.asarray(test_set[1])

            ensemble_nll = metrics.nll(ensemble_predicted_probs, test_labels)
            ensemble_calibration = metrics.calibration_curve(
                ensemble_predicted_probs, test_labels)

            with tf_writer.as_default():
                for tag, value in (
                    ("test/ens_accuracy", ensemble_acc),
                    ("test/ens_ece", ensemble_calibration["ece"]),
                    ("test/ens_nll", ensemble_nll),
                    ("debug/n_ens", num_ensembled),
                ):
                    tf.summary.scalar(tag, value, step=iteration)

        print(
            tabulate_utils.make_table(tabulate_dict,
                                      iteration - start_iteration,
                                      args.tabulate_freq))
def train_model():
    """Train a network with SGD, logging and checkpointing every epoch.

    All configuration is read from the module-level ``args``. The function
    creates a run directory under ``args.dir``, builds the dataset, model,
    losses, cosine step-size schedule and optimizer, initializes or restores
    the training state, and then runs epochs from ``start_iteration`` up to
    ``args.num_epochs``. Each epoch writes TensorFlow summaries and prints a
    progress table row; checkpoints and train/test evaluations happen every
    ``args.save_freq`` / ``args.eval_freq`` epochs and on the last epoch.

    Initialization depends on the status returned by
    ``checkpoint_utils.initialize``:

    * ``LOADED_PREEMPTED``: the full state (iteration, params, network state,
      optimizer state, RNG key) is restored from the checkpoint.
    * ``INIT_CKPT``: only the parameters and network state are taken from the
      checkpoint; the iteration counter, optimizer state and RNG key start
      fresh.
    * ``INIT_RANDOM``: everything is initialized from ``args.seed``.

    Raises:
        ValueError: if the initialization status is not recognized.
    """
    subdirname = "sgd_wd_{}_stepsize_{}_batchsize_{}_momentum_{}_seed_{}".format(
        args.weight_decay, args.init_step_size, args.batch_size,
        args.momentum_decay, args.seed)
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(dirname)
    cmd_args_utils.save_cmd(dirname, tf_writer)

    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
        args.dataset_name, dtype)

    net_apply, net_init = models.get_model(args.model_name, num_classes)

    log_likelihood_fn = nn_loss.make_xent_log_likelihood(num_classes, 1.)
    log_prior_fn, _ = nn_loss.make_gaussian_log_prior(args.weight_decay, 1.)

    num_data = jnp.size(train_set[1])
    num_batches = num_data // args.batch_size
    num_devices = len(jax.devices())

    total_steps = num_batches * args.num_epochs
    lr_schedule = train_utils.make_cosine_lr_schedule(args.init_step_size,
                                                      total_steps)
    optimizer = train_utils.make_optimizer(lr_schedule,
                                           momentum_decay=args.momentum_decay)

    checkpoint_dict, status = checkpoint_utils.initialize(
        dirname, args.init_checkpoint)

    if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
        print("Continuing the run from the last saved checkpoint")
        start_iteration, params, net_state, opt_state, key = (
            checkpoint_utils.parse_sgd_checkpoint_dict(checkpoint_dict))
    else:
        # Fresh run: iteration counter and RNG keys start from the seed.
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        key = jax.random.split(key, num_devices)
        start_iteration = 0

        if status == checkpoint_utils.InitStatus.INIT_CKPT:
            print("Resuming the run from the provided init_checkpoint")
            # Only the parameters and network state come from the init
            # checkpoint; reusing its iteration counter, optimizer state or
            # key would continue the old run instead of starting a new one.
            _, params, net_state, _, _ = (
                checkpoint_utils.parse_sgd_checkpoint_dict(checkpoint_dict))
            opt_state = optimizer.init(params)
        elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
            print("Starting from random initialization with provided seed")
            # Initialize on the first example of the first leading-axis slice.
            init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
            params, net_state = net_init(net_init_key, init_data, True)
            opt_state = optimizer.init(params)
            # Replicate the network state, one copy per device.
            net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        else:
            raise ValueError(
                "Unknown initialization status: {}".format(status))

    sgd_train_epoch, evaluate = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)

    for iteration in range(start_iteration, args.num_epochs):
        is_last = iteration == args.num_epochs - 1

        start_time = time.time()
        params, net_state, opt_state, logprob_avg, key = sgd_train_epoch(
            params, net_state, opt_state, train_set, key)
        iteration_time = time.time() - start_time
        # The step counter is read from the last element of the optimizer
        # state.
        step_size = lr_schedule(opt_state[-1].count)

        # Evaluation columns stay None on epochs without an evaluation.
        tabulate_dict = OrderedDict()
        tabulate_dict["iteration"] = iteration
        tabulate_dict["step_size"] = step_size
        tabulate_dict["train_logprob"] = logprob_avg
        tabulate_dict["train_acc"] = None
        tabulate_dict["test_logprob"] = None
        tabulate_dict["test_acc"] = None
        tabulate_dict["time"] = iteration_time

        with tf_writer.as_default():
            tf.summary.scalar("train/log_prob_running",
                              logprob_avg,
                              step=iteration)
            tf.summary.scalar("hypers/step_size", step_size, step=iteration)
            tf.summary.scalar("debug/iteration_time",
                              iteration_time,
                              step=iteration)

        if iteration % args.save_freq == 0 or is_last:
            checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
            checkpoint_path = os.path.join(dirname, checkpoint_name)
            checkpoint_dict = checkpoint_utils.make_sgd_checkpoint_dict(
                iteration, params, net_state, opt_state, key)
            checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

        if iteration % args.eval_freq == 0 or is_last:
            test_log_prob, test_acc, test_ce, _ = evaluate(
                params, net_state, test_set)
            train_log_prob, train_acc, train_ce, _ = evaluate(
                params, net_state, train_set)

            tabulate_dict["train_logprob"] = train_log_prob
            tabulate_dict["test_logprob"] = test_log_prob
            tabulate_dict["train_acc"] = train_acc
            tabulate_dict["test_acc"] = test_acc

            with tf_writer.as_default():
                tf.summary.scalar("train/log_prob",
                                  train_log_prob,
                                  step=iteration)
                tf.summary.scalar("test/log_prob",
                                  test_log_prob,
                                  step=iteration)
                tf.summary.scalar("train/log_likelihood",
                                  train_ce,
                                  step=iteration)
                tf.summary.scalar("test/log_likelihood",
                                  test_ce,
                                  step=iteration)
                tf.summary.scalar("train/accuracy", train_acc, step=iteration)
                tf.summary.scalar("test/accuracy", test_acc, step=iteration)

        table = tabulate_utils.make_table(tabulate_dict,
                                          iteration - start_iteration,
                                          args.tabulate_freq)
        print(table)
Beispiel #5
0
def train_model():
  """Run HMC sampling for a network, logging and checkpointing every step.

  All configuration is read from the module-level ``args``. The function
  creates a run directory under ``args.dir``, builds the dataset and model,
  initializes or restores the chain state, and then performs HMC updates from
  ``start_iteration`` up to ``args.num_iterations``. During the first
  ``args.num_burn_in_iterations`` iterations the step size is interpolated
  linearly from ``args.step_size`` to
  ``args.burn_in_step_size_factor * args.step_size`` and no MH correction is
  applied. Every iteration saves a checkpoint, evaluates on the train and
  test sets, writes TensorFlow summaries and prints a progress table row.

  Raises:
    ValueError: if the initialization status is not recognized.
  """
  subdirname = (
    "model_{}_wd_{}_stepsize_{}_trajlen_{}_burnin_{}_{}_mh_{}_temp_{}_"
    "seed_{}".format(
      args.model_name, args.weight_decay, args.step_size, args.trajectory_len,
      args.num_burn_in_iterations, args.burn_in_step_size_factor,
      not args.no_mh, args.temperature, args.seed
    ))
  dirname = os.path.join(args.dir, subdirname)
  os.makedirs(dirname, exist_ok=True)
  tf_writer = tf.summary.create_file_writer(dirname)
  cmd_args_utils.save_cmd(dirname, tf_writer)
  num_devices = len(jax.devices())

  dtype = jnp.float64 if args.use_float64 else jnp.float32
  train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
    args.dataset_name, dtype)

  net_apply, net_init = models.get_model(args.model_name, num_classes)

  checkpoint_dict, status = checkpoint_utils.initialize(
      dirname, args.init_checkpoint)

  if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
    print("Continuing the run from the last saved checkpoint")
    (start_iteration, params, net_state, key, step_size, _, num_ensembled,
     ensemble_predicted_probs) = (
        checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))

  else:
    # Fresh run: counters, ensemble state, step size and RNG keys start from
    # the command-line values.
    key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
    start_iteration = 0
    num_ensembled = 0
    ensemble_predicted_probs = None
    step_size = args.step_size

    if status == checkpoint_utils.InitStatus.INIT_CKPT:
      # Only parameters and network state are taken from the checkpoint.
      print("Resuming the run from the provided init_checkpoint")
      _, params, net_state, _, _, _, _, _ = (
        checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
    elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
      print("Starting from random initialization with provided seed")
      # Initialize on the first example of the first leading-axis slice.
      init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
      params, net_state = net_init(net_init_key, init_data, True)
      # Replicate the network state, one copy per device.
      net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    else:
      raise ValueError("Unknown initialization status: {}".format(status))

  # manually convert all params to dtype
  params = jax.tree_map(lambda p: p.astype(dtype), params)

  param_types = tree_utils._get_types(params)
  assert all([p_type == dtype for p_type in param_types]), (
    "Params data types {} do not match specified data type {}".format(
      param_types, dtype))

  trajectory_len = args.trajectory_len

  log_likelihood_fn = nn_loss.make_xent_log_likelihood(
    num_classes, args.temperature)
  log_prior_fn, log_prior_diff_fn = nn_loss.make_gaussian_log_prior(
    args.weight_decay, args.temperature)

  update, get_log_prob_and_grad, evaluate = train_utils.make_hmc_update(
    net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn,
    args.max_num_leapfrog_steps, args.target_accept_rate,
    args.step_size_adaptation_speed)

  # Initial log-prob, gradient and likelihood; the gradient and likelihood
  # are threaded through the update calls below.
  log_prob, state_grad, log_likelihood, net_state = (
      get_log_prob_and_grad(train_set, params, net_state))

  assert log_prob.dtype == dtype, (
    "log_prob data type {} does not match specified data type {}".format(
        log_prob.dtype, dtype))

  grad_types = tree_utils._get_types(state_grad)
  assert all([g_type == dtype for g_type in grad_types]), (
    "Gradient data types {} do not match specified data type {}".format(
      grad_types, dtype))

  ensemble_acc = 0

  for iteration in range(start_iteration, args.num_iterations):

    # do a linear ramp of the step-size in the burn-in phase
    in_burnin = (iteration < args.num_burn_in_iterations)
    if in_burnin:
      # Guard the denominator: with a single burn-in iteration the original
      # ``num_burn_in_iterations - 1`` would be zero.
      alpha = iteration / max(args.num_burn_in_iterations - 1, 1)
      initial_step_size = args.step_size
      final_step_size = args.burn_in_step_size_factor * args.step_size
      step_size = final_step_size * alpha + initial_step_size * (1 - alpha)
    do_mh_correction = (not args.no_mh) and (not in_burnin)

    start_time = time.time()
    (params, net_state, log_likelihood, state_grad, step_size, key,
     accept_prob, accepted) = (
        update(train_set, params, net_state, log_likelihood, state_grad,
               key, step_size, trajectory_len, do_mh_correction))
    iteration_time = time.time() - start_time

    checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
    checkpoint_path = os.path.join(dirname, checkpoint_name)
    checkpoint_dict = checkpoint_utils.make_hmc_checkpoint_dict(
        iteration, params, net_state, key, step_size, accepted, num_ensembled,
        ensemble_predicted_probs)
    checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

    # NOTE(review): with args.no_mh set, samples are added to the ensemble
    # during burn-in as well -- confirm this is intended.
    if ((not in_burnin) and accepted) or args.no_mh:
      ensemble_predicted_probs, ensemble_acc, num_ensembled = (
          train_utils.update_ensemble(
              net_apply, params, net_state, test_set, num_ensembled,
              ensemble_predicted_probs))

    test_log_prob, test_acc, test_ce, _ = evaluate(params, net_state, test_set)
    train_log_prob, train_acc, train_ce, prior = (
        evaluate(params, net_state, train_set))

    tabulate_dict = OrderedDict()
    tabulate_dict["iteration"] = iteration
    tabulate_dict["step_size"] = step_size
    # Report the value for the current parameters, not the log-prob computed
    # once before the loop.
    tabulate_dict["train_logprob"] = train_log_prob
    tabulate_dict["train_acc"] = train_acc
    tabulate_dict["test_acc"] = test_acc
    tabulate_dict["test_ce"] = test_ce
    tabulate_dict["accept_prob"] = accept_prob
    tabulate_dict["accepted"] = accepted
    tabulate_dict["ensemble_acc"] = ensemble_acc
    tabulate_dict["n_ens"] = num_ensembled
    tabulate_dict["time"] = iteration_time

    with tf_writer.as_default():
      tf.summary.scalar("train/log_prob", train_log_prob, step=iteration)
      tf.summary.scalar("test/log_prob", test_log_prob, step=iteration)
      tf.summary.scalar("train/log_likelihood", train_ce, step=iteration)
      tf.summary.scalar("test/log_likelihood", test_ce, step=iteration)
      tf.summary.scalar("train/accuracy", train_acc, step=iteration)
      tf.summary.scalar("test/accuracy", test_acc, step=iteration)
      tf.summary.scalar("test/ens_accuracy", ensemble_acc, step=iteration)

      # Ensemble NLL and calibration need at least one ensembled sample.
      if num_ensembled > 0:
        test_labels = onp.asarray(test_set[1])
        ensemble_nll = metrics.nll(ensemble_predicted_probs, test_labels)
        ensemble_calibration = metrics.calibration_curve(
            ensemble_predicted_probs, test_labels)
        tf.summary.scalar(
            "test/ens_ece", ensemble_calibration["ece"], step=iteration)
        tf.summary.scalar("test/ens_nll", ensemble_nll, step=iteration)

      tf.summary.scalar("telemetry/log_prior", prior, step=iteration)
      tf.summary.scalar("telemetry/accept_prob", accept_prob, step=iteration)
      tf.summary.scalar("telemetry/accepted", accepted, step=iteration)
      tf.summary.scalar("telemetry/n_ens", num_ensembled, step=iteration)
      tf.summary.scalar("telemetry/iteration_time", iteration_time,
                        step=iteration)

      tf.summary.scalar("hypers/step_size", step_size, step=iteration)
      tf.summary.scalar("hypers/trajectory_len", trajectory_len,
                        step=iteration)
      tf.summary.scalar("hypers/weight_decay", args.weight_decay,
                        step=iteration)
      tf.summary.scalar("hypers/temperature", args.temperature,
                        step=iteration)

      tf.summary.scalar("debug/do_mh_correction", float(do_mh_correction),
                        step=iteration)
      tf.summary.scalar("debug/in_burnin", float(in_burnin),
                        step=iteration)

    table = tabulate_utils.make_table(
      tabulate_dict, iteration - start_iteration, args.tabulate_freq)
    print(table)
def run_visualization():
    """Evaluate and plot the log-posterior on a 2-D grid of parameters.

    Three checkpoints (``args.checkpoint1..3``) are loaded and passed to
    ``get_u_v_o``, which returns two direction trees with their norms and an
    origin. For every grid point ``(u_t, v_t)`` the parameters
    ``origin + u * u_t * u_norm + v * v_t * v_norm`` are evaluated on the
    training set. The resulting log-probabilities, log-likelihoods, log-priors
    and grid coordinates are stored in ``surface_plot.npz`` and a contour plot
    of the log-probability is saved as ``log_prob.pdf``, both inside
    ``<args.dir>/posterior_visualization``.
    """
    dirname = os.path.join(args.dir, "posterior_visualization")
    os.makedirs(dirname, exist_ok=True)
    cmd_args_utils.save_cmd(dirname, None)

    train_set, _, num_classes = data.make_ds_pmap_fullbatch(
        name=args.dataset_name)

    net_apply, _ = models.get_model(args.model_name, num_classes)
    net_state = FlatMapping({})

    log_likelihood_fn = nn_loss.make_xent_log_likelihood(num_classes)
    log_prior_fn, _ = nn_loss.make_gaussian_log_prior(
        weight_decay=args.weight_decay)

    _, likelihood_prior_and_acc_fn = (
        train_utils.make_perdevice_log_prob_acc_grad_fns(
            net_apply, log_likelihood_fn, log_prior_fn))

    def log_prob_terms(params, state, dataset):
        """Return (log_prob, likelihood, prior) for one device's data shard."""
        # NOTE(review): the network is evaluated with is_training=True --
        # confirm this is intended for visualization.
        likelihood, prior, _, _ = likelihood_prior_and_acc_fn(
            params, state, dataset, is_training=True)
        # Sum the per-device likelihoods over the pmapped axis.
        likelihood = jax.lax.psum(likelihood, axis_name='i')
        return likelihood + prior, likelihood, prior

    anchors = [
        load_params(path)
        for path in (args.checkpoint1, args.checkpoint2, args.checkpoint3)
    ]
    u_vec, u_norm, v_vec, v_norm, origin = get_u_v_o(*anchors)

    u_ts = onp.linspace(args.limit_left, args.limit_right, args.grid_size)
    v_ts = onp.linspace(args.limit_bottom, args.limit_top, args.grid_size)
    n_u = len(u_ts)
    n_v = len(v_ts)
    surface_shape = (n_u, n_v)
    log_probs = onp.zeros(surface_shape)
    log_likelihoods = onp.zeros(surface_shape)
    log_priors = onp.zeros(surface_shape)
    grid = onp.zeros((n_u, n_v, 2))

    def row_values(u_t_, dataset):
        """Evaluate all v positions of the grid row with coordinate u_t_."""

        def at_v(carry, v_t_):
            point = jax.tree_multimap(
                lambda u, v, o: o + u * u_t_ * u_norm + v * v_t_ * v_norm,
                u_vec, v_vec, origin)
            logprob, likelihood, prior = log_prob_terms(
                point, net_state, dataset)
            return carry, jnp.array([logprob, likelihood, prior])

        _, stacked = jax.lax.scan(at_v, None, v_ts)
        # Split the three stacked columns back into separate arrays.
        row_logprobs, row_likelihoods, row_priors = jnp.split(stacked, [1, 2],
                                                              axis=1)
        return row_logprobs, row_likelihoods, row_priors

    # The grid coordinate is broadcast; the dataset is split along axis 0.
    eval_row_of_plot = jax.pmap(row_values, axis_name='i', in_axes=(None, 0))

    for u_i, u_t in enumerate(tqdm.tqdm(u_ts)):
        per_device = eval_row_of_plot(u_t, train_set)
        # Keep only the first device's copy of each result.
        row_lp, row_ll, row_prior = (arr[0] for arr in per_device)
        log_probs[u_i] = row_lp[:, 0]
        log_likelihoods[u_i] = row_ll[:, 0]
        log_priors[u_i] = row_prior[:, 0]
        grid[u_i, :, 0] = u_t
        grid[u_i, :, 1] = v_ts

    onp.savez(os.path.join(dirname, "surface_plot.npz"),
              log_probs=log_probs,
              log_priors=log_priors,
              log_likelihoods=log_likelihoods,
              grid=grid,
              u_norm=u_norm,
              v_norm=v_norm)

    u_coords = grid[:, :, 0]
    v_coords = grid[:, :, 1]
    plt.contour(u_coords, v_coords, log_probs, zorder=1)
    plt.contourf(u_coords, v_coords, log_probs, zorder=0, alpha=0.55)
    # Red markers at (0, 0), (1, 0) and (0.5, 1) -- presumably the positions
    # of the three checkpoints in (u, v) coordinates; verify against
    # get_u_v_o.
    plt.plot([0., 1., 0.5], [0., 0., 1.], "ro", ms=20, markeredgecolor="k")
    plt.colorbar()
    plt.savefig(os.path.join(dirname, "log_prob.pdf"), bbox_inches="tight")
Beispiel #7
0
    def test_accept_prob(self):
        """Test make_accept_prob implementation in hmc.

        The acceptance probabilities computed by ``hmc.make_accept_prob`` for
        a pair of states (in both directions) are compared with reference
        values computed in NumPy extended precision from the energies
        ``kinetic - log_likelihood - log_prior``. The test also requires that
        at least one of the two probabilities is strictly between 0 and 1, so
        the comparison is not trivially satisfied by saturated values.
        """
        weight_decay = 30.
        net, params, train_set, test_set, init_data, init_key = self._prepare()
        init_key, = jax.random.split(init_key, 1)
        params2 = net.init(init_key, init_data)

        # Rescale parameters so accept_prob is not 0 or 1.
        params, params2 = jax.tree_map(lambda p: p * 1e-2, [params, params2])

        log_likelihood_fn = nn_loss.xent_log_likelihood
        log_prior_fn, log_prior_diff = (nn_loss.make_gaussian_log_prior(
            weight_decay=weight_decay))
        _, _, log_prob_and_grad_fn = (train_utils.make_hmc_update_eval_fns(
            net, train_set, test_set, log_likelihood_fn, log_prior_fn,
            log_prior_diff))

        # Only the log-likelihoods are needed from the evaluation.
        _, _, log_likelihood, _ = log_prob_and_grad_fn(params)
        _, _, log_likelihood2, _ = log_prob_and_grad_fn(params2)

        key, key2 = jax.random.split(init_key, 2)
        momentum = hmc.sample_momentum(params, key)
        momentum2 = hmc.sample_momentum(params2, key2)
        momentum, momentum2 = jax.tree_map(lambda p: p * 1e-2,
                                           [momentum, momentum2])

        get_accept_prob = hmc.make_accept_prob(log_prior_diff)
        accept_prob = get_accept_prob(log_likelihood, params, momentum,
                                      log_likelihood2, params2, momentum2)
        accept_prob_reverse = get_accept_prob(log_likelihood2, params2,
                                              momentum2, log_likelihood,
                                              params, momentum)

        # ``onp.longdouble`` is the widest float NumPy offers on the current
        # platform; unlike ``onp.float128`` it is defined everywhere.
        high_precision = onp.longdouble

        def prior_onp_fn(params):
            norm_sq = sum(
                onp.sum(onp.array(p, dtype=high_precision)**2)
                for p in jax.tree_leaves(params))
            return -0.5 * weight_decay * norm_sq

        def kinetic_energy_onp_fn(momentum):
            return sum(
                0.5 * onp.sum(onp.array(m, dtype=high_precision)**2)
                for m in jax.tree_leaves(momentum))

        energy = (kinetic_energy_onp_fn(momentum) -
                  onp.array(log_likelihood, dtype=high_precision) -
                  prior_onp_fn(params))
        energy2 = (kinetic_energy_onp_fn(momentum2) -
                   onp.array(log_likelihood2, dtype=high_precision) -
                   prior_onp_fn(params2))
        energy_diff = energy - energy2
        onp_accept_prob = onp.minimum(1., onp.exp(energy_diff))
        onp_accept_prob_reverse = onp.minimum(1., onp.exp(-energy_diff))

        self.assertLess(
            onp.abs(onp_accept_prob -
                    onp.array(accept_prob, dtype=high_precision)), 1e-4)
        self.assertLess(
            onp.abs(onp_accept_prob_reverse -
                    onp.array(accept_prob_reverse, dtype=high_precision)),
            1e-4)

        # At least one direction must give a non-saturated probability.
        accept_prob_not_1_or_0 = (((accept_prob > 1e-4) and
                                   (accept_prob < 1. - 1e-4))
                                  or ((accept_prob_reverse > 1e-4) and
                                      (accept_prob_reverse < 1. - 1e-4)))
        # A unittest assertion is used because a bare ``assert`` is removed
        # when Python runs with -O.
        self.assertTrue(accept_prob_not_1_or_0)