def train_model():
    """Run the mean-field VI (MFVI) training loop configured by ``args``.

    All settings are read from the module-level ``args``. The function:

    1. creates the run directory and TensorBoard writer;
    2. builds data, model, likelihood, prediction and metric functions;
    3. converts the model to its MFVI parameterization and builds the KL
       term against a Gaussian prior;
    4. builds a cosine step-size schedule and the optimizer;
    5. builds an initial checkpoint dict (iteration ``-1``) and passes it
       through ``script_utils.get_initialization_dict`` (presumably this
       restores a previous run when one exists -- confirm in script_utils);
    6. trains for epochs ``start_iteration .. args.num_epochs - 1``,
       periodically evaluating, checkpointing and logging.

    Returns:
        None. Results are written to checkpoints, TensorBoard and stdout.
    """
    # Initialize training directory
    dirname, tf_writer = get_dirname_tfwriter(args)

    # Initialize data, model, losses and metrics
    (train_set, test_set, net_apply, params, net_state, key,
     log_likelihood_fn, _, _, predict_fn, ensemble_upd_fn, metrics_fns,
     tabulate_metrics) = script_utils.get_data_model_fns(args)

    # Convert the model to MFVI parameterization
    net_apply, mean_apply, _, params, net_state = vi.get_mfvi_model_fn(
        net_apply, params, net_state, seed=0, sigma_init=args.vi_sigma_init)
    prior_kl = vi.make_kl_with_gaussian_prior(
        args.weight_decay, args.temperature)
    vi_ensemble_predict_fn = make_vi_ensemble_predict_fn(
        predict_fn, ensemble_upd_fn, args)

    # Initialize step-size schedule and optimizer
    num_batches, total_steps = script_utils.get_num_batches_total_steps(
        args, train_set)
    num_devices = len(jax.devices())
    lr_schedule = optim_utils.make_cosine_lr_schedule(
        args.init_step_size, total_steps)
    optimizer = get_optimizer(lr_schedule, args)

    # Initialize variables
    opt_state = optimizer.init(params)
    # Produce one copy of net_state per device via pmap.
    net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    key = jax.random.split(key, num_devices)
    init_dict = checkpoint_utils.make_sgd_checkpoint_dict(
        -1, params, net_state, opt_state, key)
    init_dict = script_utils.get_initialization_dict(dirname, args, init_dict)
    start_iteration, params, net_state, opt_state, key = (
        checkpoint_utils.parse_sgd_checkpoint_dict(init_dict))
    # The checkpoint stores the last finished iteration (-1 for the fresh
    # dict built above), so training starts one past it.
    start_iteration += 1

    # Loading mean checkpoint
    if args.mean_init_checkpoint:
        print("Initializing VI mean from the provided checkpoint")
        ckpt_dict = checkpoint_utils.load_checkpoint(args.mean_init_checkpoint)
        # Element 1 of the parsed SGD checkpoint tuple is the params entry.
        mean_params = checkpoint_utils.parse_sgd_checkpoint_dict(ckpt_dict)[1]
        params["mean"] = mean_params

    # Define train epoch
    sgd_train_epoch = script_utils.time_fn(
        train_utils.make_sgd_train_epoch(
            net_apply, log_likelihood_fn, prior_kl, optimizer, num_batches))

    # Train
    for iteration in range(start_iteration, args.num_epochs):
        is_last_epoch = iteration == args.num_epochs - 1

        (params, net_state, opt_state, elbo_avg, key), iteration_time = (
            sgd_train_epoch(params, net_state, opt_state, train_set, key))

        # Evaluate the model
        train_stats = {"ELBO": elbo_avg, "KL": prior_kl(params)}
        test_stats, ensemble_stats = {}, {}
        if (iteration % args.eval_freq == 0) or is_last_epoch:
            # Evaluate the mean
            _, test_predictions, train_predictions, test_stats, train_stats_ = (
                script_utils.evaluate(
                    mean_apply, params, net_state, train_set, test_set,
                    predict_fn, metrics_fns, prior_kl))
            train_stats.update(train_stats_)
            # The KL value is already reported under "KL"; drop the entry
            # that evaluate() stores under "prior".
            del train_stats["prior"]

            # Evaluate the ensemble
            # NOTE(review): onp.asarray is applied to the whole value returned
            # by vi_ensemble_predict_fn before it is unpacked into two names;
            # confirm this is intended rather than converting only the
            # predictions.
            net_state, ensemble_predictions = onp.asarray(
                vi_ensemble_predict_fn(net_apply, params, net_state, test_set))
            ensemble_stats = train_utils.evaluate_metrics(
                ensemble_predictions, test_set[1], metrics_fns)

        # Save checkpoint
        if iteration % args.save_freq == 0 or is_last_epoch:
            checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
            checkpoint_path = os.path.join(dirname, checkpoint_name)
            checkpoint_dict = checkpoint_utils.make_sgd_checkpoint_dict(
                iteration, params, net_state, opt_state, key)
            checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

        # Log results
        other_logs = script_utils.get_common_logs(
            iteration, iteration_time, args)
        other_logs["hypers/step_size"] = lr_schedule(opt_state[-1].count)
        other_logs["hypers/momentum"] = args.momentum_decay
        logging_dict = logging_utils.make_logging_dict(
            train_stats, test_stats, ensemble_stats)
        logging_dict.update(other_logs)
        script_utils.write_to_tensorboard(tf_writer, logging_dict, iteration)

        # Add a histogram of MFVI stds
        with tf_writer.as_default():
            # jax.tree_util.* is used instead of the top-level jax.tree_map /
            # jax.tree_leaves aliases, which newer JAX releases no longer
            # provide; the tree_util functions exist in old and new versions.
            stds = jax.tree_util.tree_map(
                jax.nn.softplus, params["inv_softplus_std"])
            stds = jnp.concatenate(
                [std.reshape(-1) for std in jax.tree_util.tree_leaves(stds)])
            tf.summary.histogram("MFVI/param_stds", stds, step=iteration)

        tabulate_dict = script_utils.get_tabulate_dict(
            tabulate_metrics, logging_dict)
        tabulate_dict["lr"] = lr_schedule(opt_state[-1].count)
        table = logging_utils.make_table(
            tabulate_dict, iteration - start_iteration, args.tabulate_freq)
        print(table)
def train_model():
    """Train the model with the SGD optimizer and log per-epoch results.

    Every setting comes from the module-level ``args``. The function sets up
    the run directory, data/model/metric functions, a cosine step-size
    schedule and the SGD optimizer, then runs epochs from the one following
    the initialization checkpoint's iteration up to ``args.num_epochs - 1``.
    Evaluation runs every ``args.eval_freq`` epochs and checkpointing every
    ``args.save_freq`` epochs; both also run on the final epoch.

    Returns:
        None. Output goes to checkpoints, TensorBoard and stdout.
    """
    # Run directory and TensorBoard writer.
    dirname, tf_writer = get_dirname_tfwriter(args)

    # Data, model, losses and metrics.
    (train_set, test_set, net_apply, params, net_state, key,
     log_likelihood_fn, log_prior_fn, _, predict_fn, ensemble_upd_fn,
     metrics_fns, tabulate_metrics) = script_utils.get_data_model_fns(args)

    # Step-size schedule and optimizer.
    num_batches, total_steps = script_utils.get_num_batches_total_steps(
        args, train_set)
    num_devices = len(jax.devices())
    step_schedule = optim_utils.make_cosine_lr_schedule(
        args.init_step_size, total_steps)
    optimizer = optim_utils.make_sgd_optimizer(
        step_schedule, momentum_decay=args.momentum_decay)

    # Optimizer state, per-device net_state copies and per-device RNG keys.
    opt_state = optimizer.init(params)
    net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    key = jax.random.split(key, num_devices)

    # Build a fresh checkpoint dict (iteration -1) and let script_utils decide
    # what to initialize from.
    fresh_ckpt = checkpoint_utils.make_sgd_checkpoint_dict(
        -1, params, net_state, opt_state, key)
    chosen_ckpt = script_utils.get_initialization_dict(
        dirname, args, fresh_ckpt)
    last_done, params, net_state, opt_state, key = (
        checkpoint_utils.parse_sgd_checkpoint_dict(chosen_ckpt))
    first_epoch = last_done + 1

    # One timed training epoch.
    epoch_fn = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)
    timed_epoch_fn = script_utils.time_fn(epoch_fn)

    final_epoch = args.num_epochs - 1
    for epoch in range(first_epoch, args.num_epochs):
        epoch_result, epoch_time = timed_epoch_fn(
            params, net_state, opt_state, train_set, key)
        params, net_state, opt_state, logprob_avg, key = epoch_result

        # Evaluate the model on evaluation epochs and on the last epoch.
        train_stats = {"log_prob": logprob_avg}
        test_stats = {}
        if epoch % args.eval_freq == 0 or epoch == final_epoch:
            (_, test_predictions, train_predictions, test_stats,
             evaluated_train_stats) = script_utils.evaluate(
                 net_apply, params, net_state, train_set, test_set,
                 predict_fn, metrics_fns, log_prior_fn)
            train_stats.update(evaluated_train_stats)

        # Save a checkpoint on save epochs and on the last epoch.
        if epoch % args.save_freq == 0 or epoch == final_epoch:
            ckpt_path = os.path.join(
                dirname, checkpoint_utils.make_checkpoint_name(epoch))
            ckpt_payload = checkpoint_utils.make_sgd_checkpoint_dict(
                epoch, params, net_state, opt_state, key)
            checkpoint_utils.save_checkpoint(ckpt_path, ckpt_payload)

        # Log results to TensorBoard.
        extra_logs = script_utils.get_common_logs(epoch, epoch_time, args)
        extra_logs["hypers/step_size"] = step_schedule(opt_state[-1].count)
        extra_logs["hypers/momentum"] = args.momentum_decay
        log_record = logging_utils.make_logging_dict(
            train_stats, test_stats, {})
        log_record.update(extra_logs)
        script_utils.write_to_tensorboard(tf_writer, log_record, epoch)

        # Print the tabulated summary row.
        table_row = script_utils.get_tabulate_dict(
            tabulate_metrics, log_record)
        table_row["lr"] = step_schedule(opt_state[-1].count)
        print(logging_utils.make_table(
            table_row, epoch - first_epoch, args.tabulate_freq))
def train_model():
    """Train with the ``sgmcmc.sgld_gradient_update`` optimizer and ensemble.

    Every setting comes from the module-level ``args``. After building the
    data/model/metric functions, step-size schedule, preconditioner and
    optimizer, the function runs epochs from the one following the
    initialization checkpoint's iteration up to ``args.num_epochs - 1``.
    ``is_eval_ens_save_epoch`` decides per epoch whether to evaluate, to fold
    the current test predictions into the running ensemble, and to save a
    checkpoint.

    Returns:
        None. Output goes to checkpoints, TensorBoard and stdout.
    """
    # Run directory and TensorBoard writer.
    dirname, tf_writer = get_dirname_tfwriter(args)

    # Data, model, losses and metrics.
    (train_set, test_set, net_apply, params, net_state, key,
     log_likelihood_fn, log_prior_fn, _, predict_fn, ensemble_upd_fn,
     metrics_fns, tabulate_metrics) = script_utils.get_data_model_fns(args)

    # Step-size schedule, preconditioner and optimizer.
    num_batches, total_steps = script_utils.get_num_batches_total_steps(
        args, train_set)
    num_devices = len(jax.devices())
    step_schedule = get_lr_schedule(num_batches, args)
    preconditioner = get_preconditioner(args)
    optimizer = sgmcmc.sgld_gradient_update(
        step_schedule,
        momentum_decay=args.momentum_decay,
        seed=args.seed,
        preconditioner=preconditioner)

    # Optimizer state, per-device net_state copies and per-device RNG keys.
    opt_state = optimizer.init(params)
    net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    key = jax.random.split(key, num_devices)

    # Fresh checkpoint dict: iteration -1, nothing ensembled yet, no
    # predictions stored.
    fresh_ckpt = checkpoint_utils.make_sgmcmc_checkpoint_dict(
        -1, params, net_state, opt_state, key, 0, None, None)
    chosen_ckpt = script_utils.get_initialization_dict(
        dirname, args, fresh_ckpt)
    (last_done, params, net_state, opt_state, key, num_ensembled, _,
     ensemble_predictions) = checkpoint_utils.parse_sgmcmc_checkpoint_dict(
         chosen_ckpt)
    first_epoch = last_done + 1

    # One timed training epoch.
    epoch_fn = train_utils.make_sgd_train_epoch(
        net_apply, log_likelihood_fn, log_prior_fn, optimizer, num_batches)
    timed_epoch_fn = script_utils.time_fn(epoch_fn)

    for epoch in range(first_epoch, args.num_epochs):
        epoch_result, epoch_time = timed_epoch_fn(
            params, net_state, opt_state, train_set, key)
        params, net_state, opt_state, logprob_avg, key = epoch_result

        do_eval, do_ensemble, do_save = is_eval_ens_save_epoch(epoch, args)

        # Evaluate the model; ensembling needs fresh test predictions too.
        train_stats = {"log_prob": logprob_avg}
        test_stats = {}
        if do_eval or do_ensemble:
            (_, test_predictions, train_predictions, test_stats,
             evaluated_train_stats) = script_utils.evaluate(
                 net_apply, params, net_state, train_set, test_set,
                 predict_fn, metrics_fns, log_prior_fn)
            train_stats.update(evaluated_train_stats)

        # Ensemble predictions.
        if not do_ensemble:
            ensemble_stats = {}
            # Outside ensembling epochs no test predictions are checkpointed.
            test_predictions = None
        else:
            ensemble_predictions = ensemble_upd_fn(
                ensemble_predictions, num_ensembled, test_predictions)
            ensemble_stats = train_utils.evaluate_metrics(
                ensemble_predictions, test_set[1], metrics_fns)
            num_ensembled += 1

        # Save checkpoint.
        if do_save:
            ckpt_path = os.path.join(
                dirname, checkpoint_utils.make_checkpoint_name(epoch))
            ckpt_payload = checkpoint_utils.make_sgmcmc_checkpoint_dict(
                epoch, params, net_state, opt_state, key, num_ensembled,
                test_predictions, ensemble_predictions)
            checkpoint_utils.save_checkpoint(ckpt_path, ckpt_payload)

        # Log results to TensorBoard.
        extra_logs = script_utils.get_common_logs(epoch, epoch_time, args)
        extra_logs["hypers/step_size"] = step_schedule(opt_state.count)
        extra_logs["hypers/momentum"] = args.momentum_decay
        extra_logs["telemetry/num_ensembled"] = num_ensembled
        log_record = logging_utils.make_logging_dict(
            train_stats, test_stats, ensemble_stats)
        log_record.update(extra_logs)
        script_utils.write_to_tensorboard(tf_writer, log_record, epoch)

        # Print the tabulated summary row.
        table_row = script_utils.get_tabulate_dict(
            tabulate_metrics, log_record)
        table_row["lr"] = step_schedule(opt_state.count)
        print(logging_utils.make_table(
            table_row, epoch - first_epoch, args.tabulate_freq))
def _burn_in_step_size(iteration, num_burn_in_iterations, base_step_size,
                       final_factor):
    """Return the linearly interpolated step size for a burn-in iteration.

    The value moves linearly from ``base_step_size`` at iteration 0 to
    ``final_factor * base_step_size`` at the last burn-in iteration.

    Args:
        iteration: current iteration index, expected to be less than
            ``num_burn_in_iterations``.
        num_burn_in_iterations: total number of burn-in iterations.
        base_step_size: step size used at the start of burn-in.
        final_factor: multiplier applied to ``base_step_size`` to obtain the
            step size at the end of burn-in.

    Returns:
        The interpolated step size.
    """
    # With a single burn-in iteration the ramp has zero length; clamp the
    # denominator so that iteration 0 yields base_step_size instead of
    # raising ZeroDivisionError.
    ramp_len = max(num_burn_in_iterations - 1, 1)
    alpha = iteration / ramp_len
    final_step_size = final_factor * base_step_size
    return final_step_size * alpha + base_step_size * (1 - alpha)


def train_model():
    """Run the HMC sampling loop configured by the module-level ``args``.

    The function:

    1. creates a run directory whose name encodes the main hyperparameters,
       plus a TensorBoard writer;
    2. loads the data and model for the requested dtype;
    3. initializes state from a preempted run, from a provided init
       checkpoint, or from a random initialization, depending on the status
       returned by ``checkpoint_utils.initialize``;
    4. builds likelihood, prior and the HMC update function, and checks
       that log-prob and gradients have the requested dtype;
    5. iterates from ``start_iteration`` to ``args.num_iterations - 1``:
       HMC update, evaluation, ensembling, checkpointing and logging.

    Raises:
        ValueError: if ``checkpoint_utils.initialize`` returns a status other
            than LOADED_PREEMPTED, INIT_CKPT or INIT_RANDOM.

    Returns:
        None. Output goes to checkpoints, TensorBoard and stdout.
    """
    subdirname = (
        "model_{}_wd_{}_stepsize_{}_trajlen_{}_burnin_{}_{}_mh_{}_temp_{}_"
        "seed_{}".format(
            args.model_name, args.weight_decay, args.step_size,
            args.trajectory_len, args.num_burn_in_iterations,
            args.burn_in_step_size_factor, not args.no_mh, args.temperature,
            args.seed))
    dirname = os.path.join(args.dir, subdirname)
    os.makedirs(dirname, exist_ok=True)
    tf_writer = tf.summary.create_file_writer(dirname)
    cmd_args_utils.save_cmd(dirname, tf_writer)

    num_devices = len(jax.devices())
    dtype = jnp.float64 if args.use_float64 else jnp.float32
    train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
        args.dataset_name, dtype, truncate_to=args.subset_train_to)

    net_apply, net_init = models.get_model(args.model_name, data_info)
    net_apply = precision_utils.rewrite_high_precision(net_apply)

    checkpoint_dict, status = checkpoint_utils.initialize(
        dirname, args.init_checkpoint)

    if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
        print("Continuing the run from the last saved checkpoint")
        (start_iteration, params, net_state, key, step_size, _, num_ensembled,
         ensemble_predictions) = (
             checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
    else:
        # Fresh run: derive both RNG keys from the seed once. The original
        # repeated this split in the random-init branch; it is deterministic,
        # so the repeat produced the same keys and is dropped.
        key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
        start_iteration = 0
        num_ensembled = 0
        ensemble_predictions = None
        step_size = args.step_size

        if status == checkpoint_utils.InitStatus.INIT_CKPT:
            print("Resuming the run from the provided init_checkpoint")
            # Only params and net_state are taken from the init checkpoint.
            _, params, net_state, _, _, _, _, _ = (
                checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
        elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
            print("Starting from random initialization with provided seed")
            # Use a one-example slice of every leaf as the init input
            # (elem[0] presumably selects the first device's shard -- confirm
            # against make_ds_pmap_fullbatch).
            init_data = jax.tree_util.tree_map(
                lambda elem: elem[0][:1], train_set)
            params, net_state = net_init(net_init_key, init_data, True)
            net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
        else:
            raise ValueError(
                "Unknown initialization status: {}".format(status))

    # manually convert all params to dtype
    params = jax.tree_util.tree_map(lambda p: p.astype(dtype), params)
    param_types = tree_utils.tree_get_types(params)
    assert all(p_type == dtype for p_type in param_types), (
        "Params data types {} do not match specified data type {}".format(
            param_types, dtype))

    trajectory_len = args.trajectory_len

    (likelihood_factory, predict_fn, ensemble_upd_fn, metrics_fns,
     tabulate_metrics) = train_utils.get_task_specific_fns(task, data_info)
    log_likelihood_fn = likelihood_factory(args.temperature)
    log_prior_fn, log_prior_diff_fn = losses.make_gaussian_log_prior(
        args.weight_decay, args.temperature)

    update, get_log_prob_and_grad = train_utils.make_hmc_update(
        net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn,
        args.max_num_leapfrog_steps, args.target_accept_rate,
        args.step_size_adaptation_speed)

    # Initial log-prob / gradient, also used to sanity-check dtypes.
    log_prob, state_grad, log_likelihood, net_state = (
        get_log_prob_and_grad(train_set, params, net_state))

    assert log_prob.dtype == dtype, (
        "log_prob data type {} does not match specified data type {}".format(
            log_prob.dtype, dtype))

    grad_types = tree_utils.tree_get_types(state_grad)
    assert all(g_type == dtype for g_type in grad_types), (
        "Gradient data types {} do not match specified data type {}".format(
            grad_types, dtype))

    for iteration in range(start_iteration, args.num_iterations):
        in_burnin = iteration < args.num_burn_in_iterations

        # do a linear ramp-down of the step-size in the burn-in phase
        if in_burnin:
            step_size = _burn_in_step_size(
                iteration, args.num_burn_in_iterations, args.step_size,
                args.burn_in_step_size_factor)

        # The MH correction is skipped during burn-in and when disabled.
        do_mh_correction = (not args.no_mh) and (not in_burnin)

        start_time = time.time()
        (params, net_state, log_likelihood, state_grad, step_size, key,
         accept_prob, accepted) = (
             update(train_set, params, net_state, log_likelihood, state_grad,
                    key, step_size, trajectory_len, do_mh_correction))
        iteration_time = time.time() - start_time

        # Evaluation
        net_state, test_predictions = onp.asarray(
            predict_fn(net_apply, params, net_state, test_set))
        net_state, train_predictions = onp.asarray(
            predict_fn(net_apply, params, net_state, train_set))
        test_stats = train_utils.evaluate_metrics(
            test_predictions, test_set[1], metrics_fns)
        train_stats = train_utils.evaluate_metrics(
            train_predictions, train_set[1], metrics_fns)
        train_stats["prior"] = log_prior_fn(params)

        # Ensembling: add the sample if it was accepted after burn-in, or
        # unconditionally when MH correction is disabled.
        if ((not in_burnin) and accepted) or args.no_mh:
            ensemble_predictions = ensemble_upd_fn(
                ensemble_predictions, num_ensembled, test_predictions)
            ensemble_stats = train_utils.evaluate_metrics(
                ensemble_predictions, test_set[1], metrics_fns)
            num_ensembled += 1
        else:
            ensemble_stats = {}

        # Save the checkpoint
        checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
        checkpoint_path = os.path.join(dirname, checkpoint_name)
        checkpoint_dict = checkpoint_utils.make_hmc_checkpoint_dict(
            iteration, params, net_state, key, step_size, accepted,
            num_ensembled, ensemble_predictions)
        checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

        # Logging
        other_logs = {
            "telemetry/iteration": iteration,
            "telemetry/iteration_time": iteration_time,
            "telemetry/accept_prob": accept_prob,
            "telemetry/accepted": accepted,
            "telemetry/num_ensembled": num_ensembled,
            "hypers/step_size": step_size,
            "hypers/trajectory_len": trajectory_len,
            "hypers/weight_decay": args.weight_decay,
            "hypers/temperature": args.temperature,
            "debug/do_mh_correction": float(do_mh_correction),
            "debug/in_burnin": float(in_burnin),
        }
        logging_dict = logging_utils.make_logging_dict(
            train_stats, test_stats, ensemble_stats)
        logging_dict.update(other_logs)

        with tf_writer.as_default():
            for stat_name, stat_val in logging_dict.items():
                tf.summary.scalar(stat_name, stat_val, step=iteration)

        tabulate_dict = OrderedDict()
        tabulate_dict["i"] = iteration
        tabulate_dict["t"] = iteration_time
        tabulate_dict["accept_p"] = accept_prob
        tabulate_dict["accepted"] = accepted
        for metric_name in tabulate_metrics:
            # Metrics missing from this iteration's logs are shown as None.
            tabulate_dict[metric_name] = (
                logging_dict[metric_name]
                if metric_name in logging_dict else None)

        table = logging_utils.make_table(
            tabulate_dict, iteration - start_iteration, args.tabulate_freq)
        print(table)