def _prepare():
    """Build a transformed LeNet and initialize it on the CIFAR-10 train set.

    Returns:
        A tuple of (transformed network, initial params, train set, test set,
        the batch used for initialization, the PRNG key used for init).
    """
    network = hk.transform(models.lenet_fn)
    train_set, test_set, _ = data.make_ds_pmap_fullbatch(name="cifar10")
    rng = jax.random.PRNGKey(0)
    # Initialize from the first per-device shard of each array in the dataset
    # pytree (data is laid out for pmap with a leading device axis).
    sample_batch = jax.tree_map(lambda leaf: leaf[0], train_set)
    initial_params = network.init(rng, sample_batch)
    return network, initial_params, train_set, test_set, sample_batch, rng
def train_model():
    """Run full-batch SGLD training, periodically evaluating and ensembling.

    Reads all hyperparameters from the module-level `args` namespace, writes
    checkpoints and TensorBoard summaries under `args.dir/<run-name>`.
    """
    # Run directory name encodes the key hyperparameters of this run.
    subdirname = (
        "sgld_wd_{}_stepsizes_{}_{}_batchsize_{}_epochs{}_{}_temp_{}_seed_{}".
        format(args.weight_decay, args.init_step_size, args.final_step_size,
               args.batch_size, args.num_epochs, args.num_burnin_epochs,
               args.temperature, args.seed))
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(dirname)
    cmd_args_utils.save_cmd(dirname, tf_writer)
    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
        args.dataset_name, dtype)
    net_apply, net_init = models.get_model(args.model_name, num_classes)
    net_apply = precision_utils.rewrite_high_precision(net_apply)
    log_likelihood_fn = nn_loss.make_xent_log_likelihood(
        num_classes, args.temperature)
    log_prior_fn, _ = nn_loss.make_gaussian_log_prior(args.weight_decay,
                                                      args.temperature)
    # train_set[1] holds the labels; its total size is the number of examples.
    # NOTE(review): assumes train_set is an (inputs, labels) tuple — confirm
    # against data.make_ds_pmap_fullbatch.
    num_data = jnp.size(train_set[1])
    num_batches = num_data // args.batch_size
    num_devices = len(jax.devices())
    burnin_steps = num_batches * args.num_burnin_epochs
    lr_schedule = train_utils.make_cosine_lr_schedule_with_burnin(
        args.init_step_size, args.final_step_size, burnin_steps)
    optimizer = sgmcmc.sgld_gradient_update(lr_schedule, args.seed)
    checkpoint_dict, status = checkpoint_utils.initialize(
        dirname, args.init_checkpoint)
    if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
        # Preempted run: restore the complete training state, including the
        # ensemble accumulated so far.
        print("Continuing the run from the last saved checkpoint")
        (start_iteration, params, net_state, opt_state, key, num_ensembled,
         ensemble_predicted_probs) = (
             checkpoint_utils.parse_sgmcmc_checkpoint_dict(checkpoint_dict))
    else:
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        start_iteration = 0
        num_ensembled = 0
        ensemble_predicted_probs = None
        # One PRNG key per device for the pmapped training epoch.
        key = jax.random.split(key, num_devices)
        if status == checkpoint_utils.InitStatus.INIT_CKPT:
            # Warm start: take params and net_state from the given checkpoint,
            # but reinitialize the optimizer state and iteration counters.
            print("Resuming the run from the provided init_checkpoint")
            _, params, net_state, _, _, _, _ = (
                checkpoint_utils.parse_sgmcmc_checkpoint_dict(checkpoint_dict))
            opt_state = optimizer.init(params)
        elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
            print("Starting from random initialization with provided seed")
            # Initialize with a single example (leading [:1]) from device 0.
            init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
            params, net_state = net_init(net_init_key, init_data, True)
            opt_state = optimizer.init(params)
            # Replicate net_state across devices for pmap.
            net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        else:
            raise ValueError(
                "Unknown initialization status: {}".format(status))
    sgmcmc_train_epoch, evaluate = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)
    ensemble_acc = None
    # Sanity check: parameters must already be in the requested precision.
    param_types = tree_utils._get_types(params)
    assert all([
        p_type == dtype for p_type in param_types
    ]), ("Params data types {} do not match specified data type {}".format(
        param_types, dtype))
    for iteration in range(start_iteration, args.num_epochs):
        start_time = time.time()
        params, net_state, opt_state, logprob_avg, key = sgmcmc_train_epoch(
            params, net_state, opt_state, train_set, key)
        iteration_time = time.time() - start_time
        # Console/tabular logging; eval fields are filled in below only on
        # evaluation iterations.
        tabulate_dict = OrderedDict()
        tabulate_dict["iteration"] = iteration
        tabulate_dict["step_size"] = lr_schedule(opt_state.count)
        tabulate_dict["train_logprob"] = logprob_avg
        tabulate_dict["train_acc"] = None
        tabulate_dict["test_logprob"] = None
        tabulate_dict["test_acc"] = None
        tabulate_dict["ensemble_acc"] = ensemble_acc
        tabulate_dict["n_ens"] = num_ensembled
        tabulate_dict["time"] = iteration_time
        with tf_writer.as_default():
            tf.summary.scalar("train/log_prob_running", logprob_avg,
                              step=iteration)
            tf.summary.scalar("hypers/step_size", lr_schedule(opt_state.count),
                              step=iteration)
            tf.summary.scalar("debug/iteration_time", iteration_time,
                              step=iteration)
        # Checkpoint on the save schedule and always on the final epoch.
        if iteration % args.save_freq == 0 or iteration == args.num_epochs - 1:
            checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
            checkpoint_path = os.path.join(dirname, checkpoint_name)
            checkpoint_dict = checkpoint_utils.make_sgmcmc_checkpoint_dict(
                iteration, params, net_state, opt_state, key, num_ensembled,
                ensemble_predicted_probs)
            checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)
        # Full-batch evaluation on the eval schedule and on the final epoch.
        if (iteration % args.eval_freq == 0) or (
                iteration == args.num_epochs - 1):
            test_log_prob, test_acc, test_ce, _ = evaluate(
                params, net_state, test_set)
            train_log_prob, train_acc, train_ce, prior = (evaluate(
                params, net_state, train_set))
            tabulate_dict["train_logprob"] = train_log_prob
            tabulate_dict["test_logprob"] = test_log_prob
            tabulate_dict["train_acc"] = train_acc
            tabulate_dict["test_acc"] = test_acc
            with tf_writer.as_default():
                tf.summary.scalar("train/log_prob", train_log_prob,
                                  step=iteration)
                tf.summary.scalar("test/log_prob", test_log_prob,
                                  step=iteration)
                tf.summary.scalar("train/log_likelihood", train_ce,
                                  step=iteration)
                tf.summary.scalar("test/log_likelihood", test_ce,
                                  step=iteration)
                tf.summary.scalar("train/accuracy", train_acc, step=iteration)
                tf.summary.scalar("test/accuracy", test_acc, step=iteration)
        # After burn-in, add the current sample to the posterior ensemble
        # every `ensemble_freq` epochs.
        if ((iteration > args.num_burnin_epochs) and
            ((iteration - args.num_burnin_epochs) % args.ensemble_freq == 0)):
            ensemble_predicted_probs, ensemble_acc, num_ensembled = (
                train_utils.update_ensemble(net_apply, params, net_state,
                                            test_set, num_ensembled,
                                            ensemble_predicted_probs))
            tabulate_dict["ensemble_acc"] = ensemble_acc
            tabulate_dict["n_ens"] = num_ensembled
            test_labels = onp.asarray(test_set[1])
            ensemble_nll = metrics.nll(ensemble_predicted_probs, test_labels)
            ensemble_calibration = metrics.calibration_curve(
                ensemble_predicted_probs, test_labels)
            with tf_writer.as_default():
                tf.summary.scalar("test/ens_accuracy", ensemble_acc,
                                  step=iteration)
                tf.summary.scalar("test/ens_ece", ensemble_calibration["ece"],
                                  step=iteration)
                tf.summary.scalar("test/ens_nll", ensemble_nll, step=iteration)
                tf.summary.scalar("debug/n_ens", num_ensembled, step=iteration)
        table = tabulate_utils.make_table(tabulate_dict,
                                          iteration - start_iteration,
                                          args.tabulate_freq)
        print(table)
def train_model():
    """Run full-batch SGD training with a cosine step-size schedule.

    Reads all hyperparameters from the module-level `args` namespace, writes
    checkpoints and TensorBoard summaries under `args.dir/<run-name>`.
    """
    # Run directory name encodes the key hyperparameters of this run.
    subdirname = "sgd_wd_{}_stepsize_{}_batchsize_{}_momentum_{}_seed_{}".format(
        args.weight_decay, args.init_step_size, args.batch_size,
        args.momentum_decay, args.seed)
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(dirname)
    cmd_args_utils.save_cmd(dirname, tf_writer)
    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
        args.dataset_name, dtype)
    net_apply, net_init = models.get_model(args.model_name, num_classes)
    # Temperature 1. everywhere: plain maximum-a-posteriori SGD objective.
    log_likelihood_fn = nn_loss.make_xent_log_likelihood(num_classes, 1.)
    log_prior_fn, _ = (nn_loss.make_gaussian_log_prior(args.weight_decay, 1.))
    # train_set[1] holds the labels; its total size is the number of examples.
    # NOTE(review): assumes train_set is an (inputs, labels) tuple — confirm
    # against data.make_ds_pmap_fullbatch.
    num_data = jnp.size(train_set[1])
    num_batches = num_data // args.batch_size
    num_devices = len(jax.devices())
    total_steps = num_batches * args.num_epochs
    lr_schedule = train_utils.make_cosine_lr_schedule(args.init_step_size,
                                                      total_steps)
    optimizer = train_utils.make_optimizer(lr_schedule,
                                           momentum_decay=args.momentum_decay)
    checkpoint_dict, status = checkpoint_utils.initialize(
        dirname, args.init_checkpoint)
    if status == checkpoint_utils.InitStatus.INIT_RANDOM:
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        print("Starting from random initialization with provided seed")
        # Initialize with a single example (leading [:1]) from device 0.
        init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
        params, net_state = net_init(net_init_key, init_data, True)
        opt_state = optimizer.init(params)
        # Replicate net_state across devices and give each device its own key.
        net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        key = jax.random.split(key, num_devices)
        start_iteration = 0
    else:
        # Both INIT_CKPT and LOADED_PREEMPTED restore the full state here.
        start_iteration, params, net_state, opt_state, key = (
            checkpoint_utils.parse_sgd_checkpoint_dict(checkpoint_dict))
        if status == checkpoint_utils.InitStatus.INIT_CKPT:
            print("Resuming the run from the provided init_checkpoint")
            # TODO: fix -- we should only load the parameters in this case
        elif status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
            print("Continuing the run from the last saved checkpoint")
    sgd_train_epoch, evaluate = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)
    for iteration in range(start_iteration, args.num_epochs):
        start_time = time.time()
        params, net_state, opt_state, logprob_avg, key = sgd_train_epoch(
            params, net_state, opt_state, train_set, key)
        iteration_time = time.time() - start_time
        # Console/tabular logging; eval fields are filled in below only on
        # evaluation iterations. opt_state[-1].count is the step counter of
        # the last transform in the optimizer chain.
        tabulate_dict = OrderedDict()
        tabulate_dict["iteration"] = iteration
        tabulate_dict["step_size"] = lr_schedule(opt_state[-1].count)
        tabulate_dict["train_logprob"] = logprob_avg
        tabulate_dict["train_acc"] = None
        tabulate_dict["test_logprob"] = None
        tabulate_dict["test_acc"] = None
        tabulate_dict["time"] = iteration_time
        with tf_writer.as_default():
            tf.summary.scalar("train/log_prob_running", logprob_avg,
                              step=iteration)
            tf.summary.scalar("hypers/step_size",
                              lr_schedule(opt_state[-1].count), step=iteration)
            tf.summary.scalar("debug/iteration_time", iteration_time,
                              step=iteration)
        # Checkpoint on the save schedule and always on the final epoch.
        if iteration % args.save_freq == 0 or iteration == args.num_epochs - 1:
            checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
            checkpoint_path = os.path.join(dirname, checkpoint_name)
            checkpoint_dict = checkpoint_utils.make_sgd_checkpoint_dict(
                iteration, params, net_state, opt_state, key)
            checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)
        # Full-batch evaluation on the eval schedule and on the final epoch.
        if (iteration % args.eval_freq == 0) or (
                iteration == args.num_epochs - 1):
            test_log_prob, test_acc, test_ce, _ = evaluate(
                params, net_state, test_set)
            train_log_prob, train_acc, train_ce, prior = (evaluate(
                params, net_state, train_set))
            tabulate_dict["train_logprob"] = train_log_prob
            tabulate_dict["test_logprob"] = test_log_prob
            tabulate_dict["train_acc"] = train_acc
            tabulate_dict["test_acc"] = test_acc
            with tf_writer.as_default():
                tf.summary.scalar("train/log_prob", train_log_prob,
                                  step=iteration)
                tf.summary.scalar("test/log_prob", test_log_prob,
                                  step=iteration)
                tf.summary.scalar("train/log_likelihood", train_ce,
                                  step=iteration)
                tf.summary.scalar("test/log_likelihood", test_ce,
                                  step=iteration)
                tf.summary.scalar("train/accuracy", train_acc, step=iteration)
                tf.summary.scalar("test/accuracy", test_acc, step=iteration)
        table = tabulate_utils.make_table(tabulate_dict,
                                          iteration - start_iteration,
                                          args.tabulate_freq)
        print(table)
def train_model():
    """Run full-batch HMC sampling with optional Metropolis-Hastings correction.

    Reads all hyperparameters from the module-level `args` namespace, writes
    per-iteration checkpoints and TensorBoard summaries under
    `args.dir/<run-name>`. Accepted post-burn-in samples (or every sample when
    `args.no_mh` is set) are folded into a running predictive ensemble.
    """
    # Run directory name encodes the key hyperparameters of this run.
    subdirname = (
        "model_{}_wd_{}_stepsize_{}_trajlen_{}_burnin_{}_{}_mh_{}_temp_{}_"
        "seed_{}".format(args.model_name, args.weight_decay, args.step_size,
                         args.trajectory_len, args.num_burn_in_iterations,
                         args.burn_in_step_size_factor, not args.no_mh,
                         args.temperature, args.seed))
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(dirname)
    cmd_args_utils.save_cmd(dirname, tf_writer)
    num_devices = len(jax.devices())
    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
        args.dataset_name, dtype)
    net_apply, net_init = models.get_model(args.model_name, num_classes)
    checkpoint_dict, status = checkpoint_utils.initialize(
        dirname, args.init_checkpoint)
    if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
        # Preempted run: restore the complete sampler state, including the
        # adapted step size and the ensemble accumulated so far.
        print("Continuing the run from the last saved checkpoint")
        (start_iteration, params, net_state, key, step_size, _, num_ensembled,
         ensemble_predicted_probs) = (
             checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
    else:
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        start_iteration = 0
        num_ensembled = 0
        ensemble_predicted_probs = None
        step_size = args.step_size
        if status == checkpoint_utils.InitStatus.INIT_CKPT:
            # Warm start: take only params and net_state from the checkpoint.
            print("Resuming the run from the provided init_checkpoint")
            _, params, net_state, _, _, _, _, _ = (
                checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
        elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
            print("Starting from random initialization with provided seed")
            key, net_init_key = jax.random.split(
                jax.random.PRNGKey(args.seed), 2)
            # Initialize with a single example (leading [:1]) from device 0,
            # then replicate net_state across devices for pmap.
            init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
            params, net_state = net_init(net_init_key, init_data, True)
            net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        else:
            raise ValueError("Unknown initialization status: {}".format(status))
    # Manually convert all params to dtype; checkpoints may store a different
    # precision than the one requested for this run.
    params = jax.tree_map(lambda p: p.astype(dtype), params)
    param_types = tree_utils._get_types(params)
    assert all([p_type == dtype for p_type in param_types]), (
        "Params data types {} do not match specified data type {}".format(
            param_types, dtype))
    trajectory_len = args.trajectory_len
    log_likelihood_fn = nn_loss.make_xent_log_likelihood(
        num_classes, args.temperature)
    log_prior_fn, log_prior_diff_fn = nn_loss.make_gaussian_log_prior(
        args.weight_decay, args.temperature)
    update, get_log_prob_and_grad, evaluate = train_utils.make_hmc_update(
        net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn,
        args.max_num_leapfrog_steps, args.target_accept_rate,
        args.step_size_adaptation_speed)
    # Initial log-prob and gradient at the starting parameters; also sanity
    # checks that the whole computation runs in the requested precision.
    log_prob, state_grad, log_likelihood, net_state = (
        get_log_prob_and_grad(train_set, params, net_state))
    assert log_prob.dtype == dtype, (
        "log_prob data type {} does not match specified data type {}".format(
            log_prob.dtype, dtype))
    grad_types = tree_utils._get_types(state_grad)
    assert all([g_type == dtype for g_type in grad_types]), (
        "Gradient data types {} do not match specified data type {}".format(
            grad_types, dtype))
    ensemble_acc = 0
    for iteration in range(start_iteration, args.num_iterations):
        # Do a linear ramp of the step size during the burn-in phase, from
        # args.step_size toward burn_in_step_size_factor * args.step_size.
        # NOTE(review): divides by num_burn_in_iterations - 1; breaks when
        # exactly one burn-in iteration is configured — confirm intended range.
        if iteration < args.num_burn_in_iterations:
            alpha = iteration / (args.num_burn_in_iterations - 1)
            initial_step_size = args.step_size
            final_step_size = args.burn_in_step_size_factor * args.step_size
            step_size = final_step_size * alpha + initial_step_size * (1 - alpha)
        in_burnin = (iteration < args.num_burn_in_iterations)
        # MH correction is skipped during burn-in and when explicitly disabled.
        do_mh_correction = (not args.no_mh) and (not in_burnin)
        start_time = time.time()
        (params, net_state, log_likelihood, state_grad, step_size, key,
         accept_prob, accepted) = (
             update(train_set, params, net_state, log_likelihood, state_grad,
                    key, step_size, trajectory_len, do_mh_correction))
        iteration_time = time.time() - start_time
        # Checkpoint every iteration: HMC iterations are expensive enough that
        # no save_freq throttling is applied here.
        checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
        checkpoint_path = os.path.join(dirname, checkpoint_name)
        checkpoint_dict = checkpoint_utils.make_hmc_checkpoint_dict(
            iteration, params, net_state, key, step_size, accepted,
            num_ensembled, ensemble_predicted_probs)
        checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)
        # Fold the sample into the ensemble: accepted post-burn-in samples, or
        # every sample when MH correction is disabled.
        if ((not in_burnin) and accepted) or args.no_mh:
            ensemble_predicted_probs, ensemble_acc, num_ensembled = (
                train_utils.update_ensemble(net_apply, params, net_state,
                                            test_set, num_ensembled,
                                            ensemble_predicted_probs))
        test_log_prob, test_acc, test_ce, _ = evaluate(params, net_state,
                                                       test_set)
        train_log_prob, train_acc, train_ce, prior = (
            evaluate(params, net_state, train_set))
        tabulate_dict = OrderedDict()
        tabulate_dict["iteration"] = iteration
        tabulate_dict["step_size"] = step_size
        # NOTE(review): log_prob here is the value computed before the loop
        # and is never refreshed inside it — confirm this is intended.
        tabulate_dict["train_logprob"] = log_prob
        tabulate_dict["train_acc"] = train_acc
        tabulate_dict["test_acc"] = test_acc
        tabulate_dict["test_ce"] = test_ce
        tabulate_dict["accept_prob"] = accept_prob
        tabulate_dict["accepted"] = accepted
        tabulate_dict["ensemble_acc"] = ensemble_acc
        tabulate_dict["n_ens"] = num_ensembled
        tabulate_dict["time"] = iteration_time
        with tf_writer.as_default():
            tf.summary.scalar("train/log_prob", train_log_prob, step=iteration)
            tf.summary.scalar("test/log_prob", test_log_prob, step=iteration)
            tf.summary.scalar("train/log_likelihood", train_ce, step=iteration)
            tf.summary.scalar("test/log_likelihood", test_ce, step=iteration)
            tf.summary.scalar("train/accuracy", train_acc, step=iteration)
            tf.summary.scalar("test/accuracy", test_acc, step=iteration)
            # Fixed: this scalar was previously written three times per step.
            tf.summary.scalar("test/ens_accuracy", ensemble_acc, step=iteration)
            if num_ensembled > 0:
                test_labels = onp.asarray(test_set[1])
                ensemble_nll = metrics.nll(ensemble_predicted_probs,
                                           test_labels)
                ensemble_calibration = metrics.calibration_curve(
                    ensemble_predicted_probs, test_labels)
                tf.summary.scalar("test/ens_ece", ensemble_calibration["ece"],
                                  step=iteration)
                tf.summary.scalar("test/ens_nll", ensemble_nll, step=iteration)
            tf.summary.scalar("telemetry/log_prior", prior, step=iteration)
            tf.summary.scalar("telemetry/accept_prob", accept_prob,
                              step=iteration)
            tf.summary.scalar("telemetry/accepted", accepted, step=iteration)
            tf.summary.scalar("telemetry/n_ens", num_ensembled, step=iteration)
            tf.summary.scalar("telemetry/iteration_time", iteration_time,
                              step=iteration)
            tf.summary.scalar("hypers/step_size", step_size, step=iteration)
            tf.summary.scalar("hypers/trajectory_len", trajectory_len,
                              step=iteration)
            tf.summary.scalar("hypers/weight_decay", args.weight_decay,
                              step=iteration)
            tf.summary.scalar("hypers/temperature", args.temperature,
                              step=iteration)
            tf.summary.scalar("debug/do_mh_correction", float(do_mh_correction),
                              step=iteration)
            tf.summary.scalar("debug/in_burnin", float(in_burnin),
                              step=iteration)
        table = tabulate_utils.make_table(
            tabulate_dict, iteration - start_iteration, args.tabulate_freq)
        print(table)
def run_visualization():
    """Evaluate the log-posterior on a 2-D plane through three checkpoints.

    The plane is spanned by vectors derived from three parameter checkpoints
    (via `get_u_v_o`); the grid of log-prob / log-likelihood / log-prior
    values is saved to `surface_plot.npz` and a contour plot is written to
    `log_prob.pdf` under `args.dir/posterior_visualization`.
    """
    subdirname = "posterior_visualization"
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    cmd_args_utils.save_cmd(dirname, None)
    train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
        name=args.dataset_name)
    net_apply, _ = models.get_model(args.model_name, num_classes)
    # Stateless network assumed here: empty network state.
    net_state = FlatMapping({})
    log_likelihood_fn = nn_loss.make_xent_log_likelihood(num_classes)
    log_prior_fn, _ = (nn_loss.make_gaussian_log_prior(
        weight_decay=args.weight_decay))
    _, likelihood_prior_and_acc_fn = (
        train_utils.make_perdevice_log_prob_acc_grad_fns(
            net_apply, log_likelihood_fn, log_prior_fn))

    # Renamed from `eval` to avoid shadowing the builtin.
    def eval_log_prob(params, net_state, dataset):
        """Per-device log prob/likelihood/prior at `params` (psum over 'i')."""
        likelihood, prior, _, _ = likelihood_prior_and_acc_fn(
            params, net_state, dataset, is_training=True)
        likelihood = jax.lax.psum(likelihood, axis_name='i')
        log_prob = likelihood + prior
        return log_prob, likelihood, prior

    params1 = load_params(args.checkpoint1)
    params2 = load_params(args.checkpoint2)
    params3 = load_params(args.checkpoint3)
    # for params in [params1, params2, params3]:
    #     print(jax.pmap(eval_log_prob, axis_name='i', in_axes=(None, None, 0))
    #           (params, net_state, train_set))
    u_vec, u_norm, v_vec, v_norm, origin = get_u_v_o(params1, params2, params3)
    u_ts = onp.linspace(args.limit_left, args.limit_right, args.grid_size)
    v_ts = onp.linspace(args.limit_bottom, args.limit_top, args.grid_size)
    n_u, n_v = len(u_ts), len(v_ts)
    log_probs = onp.zeros((n_u, n_v))
    log_likelihoods = onp.zeros((n_u, n_v))
    log_priors = onp.zeros((n_u, n_v))
    grid = onp.zeros((n_u, n_v, 2))

    @functools.partial(jax.pmap, axis_name='i', in_axes=(None, 0))
    def eval_row_of_plot(u_t_, dataset):
        """Evaluate one grid row (fixed u, scanning over all v values)."""
        def loop_body(_, v_t_):
            # Point on the plane: origin + u_t * u + v_t * v (scaled by norms).
            params = jax.tree_multimap(
                lambda u, v, o: o + u * u_t_ * u_norm + v * v_t_ * v_norm,
                u_vec, v_vec, origin)
            logprob, likelihood, prior = eval_log_prob(params, net_state,
                                                       dataset)
            arr = jnp.array([logprob, likelihood, prior])
            return None, arr
        _, vals = jax.lax.scan(loop_body, None, v_ts)
        # vals has shape (n_v, 3); split into three (n_v, 1) columns.
        row_logprobs, row_likelihoods, row_priors = jnp.split(vals, [1, 2],
                                                              axis=1)
        return row_logprobs, row_likelihoods, row_priors

    for u_i, u_t in enumerate(tqdm.tqdm(u_ts)):
        log_probs_i, likelihoods_i, priors_i = eval_row_of_plot(u_t, train_set)
        # Drop the leading pmap device axis (values are psum-reduced, so
        # device 0 carries the full result).
        log_probs_i, likelihoods_i, priors_i = map(
            lambda arr: arr[0], [log_probs_i, likelihoods_i, priors_i])
        log_probs[u_i] = log_probs_i[:, 0]
        log_likelihoods[u_i] = likelihoods_i[:, 0]
        log_priors[u_i] = priors_i[:, 0]
        grid[u_i, :, 0] = onp.array([u_t] * n_v)
        grid[u_i, :, 1] = v_ts
    onp.savez(os.path.join(dirname, "surface_plot.npz"),
              log_probs=log_probs, log_priors=log_priors,
              log_likelihoods=log_likelihoods, grid=grid,
              u_norm=u_norm, v_norm=v_norm)
    plt.contour(grid[:, :, 0], grid[:, :, 1], log_probs, zorder=1)
    plt.contourf(grid[:, :, 0], grid[:, :, 1], log_probs, zorder=0, alpha=0.55)
    # Mark the three checkpoints in plane coordinates.
    plt.plot([0., 1., 0.5], [0., 0., 1.], "ro", ms=20, markeredgecolor="k")
    plt.colorbar()
    plt.savefig(os.path.join(dirname, "log_prob.pdf"), bbox_inches="tight")