def run_svgd(key, lr, full_data=False, progress_bar=False):
    key, subkey = random.split(key)
    init_particles = ravel(*sample_from_prior(subkey, n_particles))
    svgd_opt = optax.sgd(lr)

    svgd_grad = models.KernelGradient(
        get_target_logp=lambda batch: get_minibatch_logp(*batch), scaled=True)
    particles = models.Particles(key,
                                 svgd_grad.gradient,
                                 init_particles,
                                 custom_optimizer=svgd_opt)

    test_batches = get_batches(x_test, y_test, 2 * NUM_VALS) if full_data else get_batches(
        x_val, y_val, 2 * NUM_VALS)
    train_batches = get_batches(xx, yy, NUM_STEPS + 1) if full_data else get_batches(
        x_train, y_train, NUM_STEPS + 1)

    for i, batch in tqdm(enumerate(train_batches),
                         total=NUM_STEPS,
                         disable=not progress_bar):
        particles.step(batch)
        if i % (NUM_STEPS // NUM_VALS) == 0:
            test_logp = get_minibatch_logp(*next(test_batches))
            stepdata = {
                "accuracy": compute_test_accuracy(unravel(particles.particles)[0]),
                "test_logp": test_logp(particles.particles),
            }
            metrics.append_to_log(particles.rundata, stepdata)

    particles.done()
    return particles
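
# Illustration (an assumption, not models.KernelGradient itself): the kernelized
# Stein / SVGD direction applied above is, for particles {x_j} with scores
# grad_logp and an RBF kernel k,
#     phi(x_i) = (1/n) * sum_j [ k(x_j, x_i) * grad_logp(x_j) + grad_{x_j} k(x_j, x_i) ].
# A minimal JAX sketch with a fixed bandwidth, assuming `np` is jax.numpy as
# elsewhere in this module; the names below are hypothetical.
def svgd_direction_sketch(particle_array, grad_logp, bandwidth=1.0):
    """particle_array, grad_logp: arrays of shape (n, d)."""
    def rbf(x, y):
        return np.exp(-np.sum((x - y) ** 2) / (2 * bandwidth ** 2))

    def phi(x_i):
        k_vals = vmap(lambda x_j: rbf(x_j, x_i))(particle_array)                     # k(x_j, x_i), shape (n,)
        k_grads = vmap(lambda x_j: grad(rbf, argnums=0)(x_j, x_i))(particle_array)   # d k / d x_j, shape (n, d)
        return (k_vals[:, None] * grad_logp + k_grads).mean(axis=0)

    return vmap(phi)(particle_array)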
def run_sgld(key, lr, full_data=False, progress_bar=False):
    key, subkey = random.split(key)
    init_particles = ravel(*sample_from_prior(subkey, n_particles))
    key, subkey = random.split(key)
    # sgld_opt = utils.scaled_sgld(subkey, lr, schedule)
    sgld_opt = utils.sgld(lr, 0)

    def energy_gradient(data, particles, aux=True):
        """data = [batch_x, batch_y]"""
        xbatch, ybatch = data
        logp = get_minibatch_logp(xbatch, ybatch)
        logprob, grads = value_and_grad(logp)(particles)
        if aux:
            return -grads, {"logp": logprob}
        else:
            return -grads

    particles = models.Particles(key,
                                 energy_gradient,
                                 init_particles,
                                 custom_optimizer=sgld_opt)

    test_batches = get_batches(x_test, y_test, 2 * NUM_VALS) if full_data else get_batches(
        x_val, y_val, 2 * NUM_VALS)
    train_batches = get_batches(xx, yy, NUM_STEPS + 1) if full_data else get_batches(
        x_train, y_train, NUM_STEPS + 1)

    for i, batch_xy in tqdm(enumerate(train_batches),
                            total=NUM_STEPS,
                            disable=not progress_bar):
        particles.step(batch_xy)
        if i % (NUM_STEPS // NUM_VALS) == 0:
            test_logp = get_minibatch_logp(*next(test_batches))
            stepdata = {
                "accuracy": compute_test_accuracy(unravel(particles.particles)[0]),
                "train_accuracy": compute_train_accuracy(unravel(particles.particles)[0]),
                "test_logp": np.mean(test_logp(particles.particles)),
            }
            metrics.append_to_log(particles.rundata, stepdata)

    particles.done()
    return particles
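
# Illustration (an assumption, not the repo's utils.sgld implementation): the
# optimizer above is used as an optax-style transformation. A minimal version of
# the standard Langevin update
#     theta <- theta - lr * grad_E(theta) + sqrt(2 * lr) * N(0, I)
# could look like the sketch below, assuming params and gradients are a single
# flat array as in run_sgld; the second argument of utils.sgld is treated here as
# a PRNG seed, which is a guess.
def sgld_sketch(learning_rate: float, seed: int = 0) -> optax.GradientTransformation:
    def init_fn(params):
        return random.PRNGKey(seed)  # optimizer state: the noise PRNG key

    def update_fn(energy_grads, key, params=None):
        key, subkey = random.split(key)
        noise = (2 * learning_rate) ** 0.5 * random.normal(subkey, energy_grads.shape)
        # SGD step on the energy plus injected Gaussian noise
        return -learning_rate * energy_grads + noise, key

    return optax.GradientTransformation(init_fn, update_fn)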
def run_neural_svgd(key, plr, full_data=False, progress_bar=False):
    key, subkey = random.split(key)
    init_particles = ravel(*sample_from_prior(subkey, n_particles))
    nsvgd_opt = optax.sgd(plr)

    key1, key2 = random.split(key)
    neural_grad = models.SteinNetwork(
        target_dim=init_particles.shape[1],
        # get_target_logp=lambda batch: get_minibatch_logp(*batch),
        learning_rate=neural_lr,
        key=key1,
        aux=False,
        lambda_reg=lambda_reg)
    particles = models.Particles(key2,
                                 neural_grad.gradient,
                                 init_particles,
                                 custom_optimizer=nsvgd_opt)

    test_batches = get_batches(x_test, y_test, 2 * NUM_VALS) if full_data else get_batches(
        x_val, y_val, 2 * NUM_VALS)
    train_batches = get_batches(xx, yy, NUM_STEPS + 1) if full_data else get_batches(
        x_train, y_train, NUM_STEPS + 1)

    @jit
    def v_dlogp(particles, batch):
        logp = get_minibatch_logp(*batch)
        return vmap(grad(logp))(particles)

    # Warmup on first batch
    # key, subkey = random.split(key)
    # neural_grad.warmup(key=subkey,
    #                    sample_split_particles=sample_tv,
    #                    next_data=lambda: next(get_batches(x_train, y_train, n_steps=100+1)),  # note: lambda always returns first batch
    #                    n_iter=3)
    first_batch = next(get_batches(x_train, y_train, n_steps=100 + 1))
    for _ in range(3):
        key, subkey = random.split(key)
        split_particles = sample_tv(subkey)
        split_dlogp = [v_dlogp(x, first_batch) for x in split_particles]
        neural_grad.train(split_particles,
                          split_dlogp,
                          n_steps=30,
                          early_stopping=True)

    for i, data_batch in tqdm(enumerate(train_batches),
                              total=NUM_STEPS,
                              disable=not progress_bar):
        key, subkey = random.split(key)
        split_particles = particles.next_batch(subkey)
        split_dlogp = [v_dlogp(x, data_batch) for x in split_particles]
        neural_grad.train(split_particles, split_dlogp, n_steps=10)
        particles.step(neural_grad.get_params())
        if i % (NUM_STEPS // NUM_VALS) == 0:
            test_logp = get_minibatch_logp(*next(test_batches))
            train_logp = get_minibatch_logp(*data_batch)
            stepdata = {
                "accuracy": compute_test_accuracy(unravel(particles.particles)[0]),
                "test_logp": test_logp(particles.particles),
                "training_logp": train_logp(particles.particles),
            }
            metrics.append_to_log(particles.rundata, stepdata)

    neural_grad.done()
    particles.done()
    return particles, neural_grad
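
# Usage sketch (hedged, not part of the original source): the neural (NVGD) run
# returns both the particle set and the trained Stein network; the rundata keys
# are assumed to be accumulated as lists by metrics.append_to_log.
#
#     nvgd_particles, stein_net = run_neural_svgd(random.PRNGKey(0), plr=1e-3,
#                                                 progress_bar=True)
#     accuracy_trace = nvgd_particles.rundata["accuracy"]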
def train(key,
          particle_stepsize: float = 1e-3,
          evaluate_every: int = 10,
          n_iter: int = 400,
          n_samples: int = 100,
          results_file: str = cfg.results_path + 'svgd-bnn.csv',
          overwrite_file: bool = False,
          optimizer="sgd"):
    """
    Initialize model; training; evaluation.
    Returns the final test accuracy.

    Args:
        particle_stepsize: learning rate of BNN
        evaluate_every: compute metrics every `evaluate_every` steps
        n_iter: number of train-update iterations
        results_file: csv file to save accuracy to
        overwrite_file: whether to overwrite an existing results file
    """
    csv_string = f"{particle_stepsize},"

    # initialize particles and the dynamics model
    key, subkey = random.split(key)
    init_particles = vmap(bnn.init_flat_params)(random.split(subkey, n_samples))

    if optimizer == "sgd":
        opt = optax.sgd(particle_stepsize)
    elif optimizer == "adam":
        opt = optax.adam(particle_stepsize)
    else:
        raise ValueError("optimizer must be adam or sgd")

    key, subkey1, subkey2 = random.split(key, 3)
    svgd_grad = models.KernelGradient(get_target_logp=bnn.get_minibatch_logp,
                                      scaled=False,
                                      lambda_reg=LAMBDA_REG)
    particles = models.Particles(key=subkey2,
                                 gradient=svgd_grad.gradient,
                                 init_samples=init_particles,
                                 custom_optimizer=opt)

    def evaluate(step_counter, ps):
        stepdata = {
            "accuracy": bnn.compute_acc_from_flat(ps),
            "step_counter": step_counter,
        }
        with open(results_file, "a") as file:
            file.write(csv_string + f"{step_counter},{stepdata['accuracy']}\n")
        return stepdata

    if not os.path.isfile(results_file) or overwrite_file:
        with open(results_file, "w") as file:
            file.write("particle_stepsize,step,accuracy\n")

    print("Training...")
    for step_counter in tqdm(range(n_iter), disable=on_cluster):
        train_batch = next(data.train_batches)
        particles.step(train_batch)
        if (step_counter + 1) % evaluate_every == 0:
            metrics.append_to_log(particles.rundata,
                                  evaluate(step_counter, particles.particles))
        if step_counter % data.steps_per_epoch == 0:
            print(f"Starting epoch {step_counter // data.steps_per_epoch + 1}")

    # final eval
    final_eval = evaluate(-1, particles.particles)
    particles.done()
    return final_eval['accuracy']
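
# Usage sketch (hedged, not part of the original source): assuming `cfg`, `data`,
# and `bnn` are set up as in the surrounding module, a small step-size sweep
# could look like this.
#
#     for stepsize in (1e-3, 1e-2):
#         acc = train(random.PRNGKey(0),
#                     particle_stepsize=stepsize,
#                     overwrite_file=True)
#         print(f"stepsize={stepsize}: final accuracy {acc}")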
def train(key,
          meta_lr: float = DEFAULT_META_LR,
          particle_stepsize: float = 1e-3,
          evaluate_every: int = 10,
          n_iter: int = 200,
          n_samples: int = cfg.n_samples + 1,  # add 1 to account for dummy val set
          particle_steps_per_iter: int = 1,
          max_train_steps_per_iter: int = DEFAULT_MAX_TRAIN_STEPS,
          patience: int = DEFAULT_PATIENCE,
          dropout: bool = True,
          results_file: str = cfg.results_path + 'nvgd-bnn.csv',
          overwrite_file: bool = False,
          early_stopping: bool = True,
          optimizer: str = "sgd",
          hidden_sizes=[DEFAULT_LAYER_SIZE] * 3,
          use_hypernetwork: bool = False):
    """
    Initialize model; training; evaluation.
    Returns the final test accuracy and the particle run data.

    Args:
        meta_lr: learning rate of Stein network
        particle_stepsize: learning rate of BNN
        evaluate_every: compute metrics every `evaluate_every` steps
        n_iter: number of train-update iterations
        particle_steps_per_iter: num particle updates after training Stein network
        max_train_steps_per_iter: cutoff for Stein network training iteration
        patience: early stopping criterion
        dropout: use dropout during training of the Stein network
        results_file: csv file to save accuracy to
        overwrite_file: whether to overwrite an existing results file
        use_hypernetwork: whether to use net-to-net hypernetwork to model
            the witness function
    """
    # csv_string = f"{meta_lr},{particle_stepsize}," \
    #              f"{patience},{max_train_steps_per_iter},"

    # initialize particles and the dynamics model
    key, subkey = random.split(key)
    init_particles = vmap(bnn.init_flat_params)(random.split(subkey, n_samples))

    if optimizer == "sgd":
        opt = optax.sgd(particle_stepsize)
    elif optimizer == "adam":
        opt = optax.adam(particle_stepsize)
    else:
        raise ValueError("optimizer must be sgd or adam")

    key, subkey1, subkey2 = random.split(key, 3)
    neural_grad = models.SteinNetwork(target_dim=init_particles.shape[1],
                                      learning_rate=meta_lr,
                                      key=subkey1,
                                      sizes=hidden_sizes + [init_particles.shape[1]],
                                      aux=False,
                                      use_hutchinson=True,
                                      lambda_reg=LAMBDA_REG,
                                      patience=patience,
                                      dropout=dropout,
                                      particle_unravel=nets.cnn_unravel,
                                      hypernet=use_hypernetwork)
    particles = models.Particles(key=subkey2,
                                 gradient=neural_grad.gradient,
                                 init_samples=init_particles,
                                 custom_optimizer=opt)

    minibatch_vdlogp = vmap(value_and_grad(bnn.minibatch_logp), (0, None))

    @jit
    def split_vdlogp(split_particles, train_batch):
        """returns tuple (split_logp, split_dlogp)"""
        train_out, val_out = [
            minibatch_vdlogp(x, train_batch) for x in split_particles
        ]
        return tuple(zip(train_out, val_out))

    def step(split_particles, split_dlogp):
        """one iteration of the particle trajectory simulation"""
        neural_grad.train(split_particles=split_particles,
                          split_dlogp=split_dlogp,
                          n_steps=max_train_steps_per_iter,
                          early_stopping=early_stopping)
        for _ in range(particle_steps_per_iter):
            particles.step(neural_grad.get_params())
        return

    def evaluate(step_counter, ps, logp):
        ll = logp.mean()
        stepdata = {
            "accuracy": bnn.compute_acc_from_flat(ps),
            "step_counter": step_counter,
            "loglikelihood": ll,
        }
        with open(results_file, "a") as file:
            file.write(f"{step_counter},{stepdata['accuracy']},{ll}\n")
        return stepdata

    if not os.path.isfile(results_file) or overwrite_file:
        with open(results_file, "w") as file:
            file.write("step_counter,accuracy,loglikelihood\n")

    print("Training...")
    for step_counter in tqdm(range(n_iter), disable=on_cluster):
        key, subkey = random.split(key)
        train_batch = next(data.train_batches)
        n_train_particles = 3 * n_samples // 4 if early_stopping else n_samples - 1
        split_particles = particles.next_batch(key, n_train_particles=n_train_particles)
        split_logp, split_dlogp = split_vdlogp(split_particles, train_batch)
        step(split_particles, split_dlogp)

        if (step_counter + 1) % evaluate_every == 0:
            eval_ps = particles.particles if early_stopping else split_particles[0]
            metrics.append_to_log(particles.rundata,
                                  evaluate(step_counter, eval_ps, split_logp[0]))
        if step_counter % data.steps_per_epoch == 0:
            print(f"Starting epoch {step_counter // data.steps_per_epoch + 1}")

    neural_grad.done()
    particles.done()
    # final eval; pass the per-particle log-likelihoods from the last training batch
    final_eval = evaluate(-1, particles.particles, split_logp[0])
    return round(float(final_eval['accuracy']), 4), particles.rundata
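
# Usage sketch (hedged, not part of the original source): a typical NVGD run,
# returning the rounded final test accuracy together with the particle run data.
#
#     accuracy, rundata = train(random.PRNGKey(0),
#                               meta_lr=DEFAULT_META_LR,
#                               particle_stepsize=1e-3,
#                               overwrite_file=True)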