class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    """Compares optax optimizers with their `optimizers` counterparts.

    Every case applies the same fixed per-step updates `STEPS` times with both
    implementations and requires the final parameters to agree within `rtol`.
    """

    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum',
         alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.,),
         optimizers.adagrad(LR, 0.), 1e-5),
        # The LR_SCHED cases carry their own suffix: each name becomes part of
        # the generated test-method name, so reusing the names above would
        # make the two groups collide.
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(LR_SCHED, 0., 0.,),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol):
        """Run both optimizers on identical updates and compare the results.

        Args:
            optax_optimizer: object exposing `init(params)` and
                `update(updates, state)`.
            jax_optimizer: `(opt_init, opt_update, get_params)` triple.
            rtol: relative tolerance for the final comparison.
        """
        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        # Wrapped with self.variant so the update runs under every chex variant
        # selected by @chex.all_variants.
        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
def args_to_op(optimizer_string, lr, mom=0.9, var=0.999, eps=1e-7):
    """Build the optimizer selected by `optimizer_string` (case-insensitive).

    Args:
        optimizer_string: one of "gd", "sgd", "momentum" or "adam".
        lr: forwarded as the first argument of the chosen `op` constructor.
        mom: forwarded to `op.momentum` and `op.adam`.
        var: forwarded to `op.adam`.
        eps: forwarded to `op.adam`.

    Returns:
        Whatever the selected `op.*` constructor returns.

    Raises:
        KeyError: if the lower-cased name is not one of the supported keys.
    """
    # Zero-argument builders keep construction lazy: only the selected
    # optimizer is ever instantiated.
    builders = {
        "gd": lambda: op.sgd(lr),
        "sgd": lambda: op.sgd(lr),
        "momentum": lambda: op.momentum(lr, mom),
        "adam": lambda: op.adam(lr, mom, var, eps),
    }
    selected = builders[optimizer_string.lower()]
    return selected()
def _define_optimizer_JAX(dataset, learn_rate): if 'purchases' == dataset: return optimizers.adam(learn_rate) # return optimizers.sgd(learn_rate) # temporarily error.... elif 'fashion_mnist' == dataset: return optimizers.sgd(learn_rate) elif 'cifar10' == dataset: return optimizers.sgd(learn_rate) else: assert False, ('Error: unknown dataset - {}'.format(dataset))
def train_devise(
    weights,
    X,
    y,
    label_embeddings,
    margin,
    n_epochs=2,
    learning_rate=0.001,
    batch_size=16,
):
    """Optimise `weights` with SGD on a similarity-based hinge loss.

    Args:
        weights: initial parameters handed to the optimizer.
        X: inputs; used to build the loss and the data loader.
        y: labels; used to build the loss and the data loader.
        label_embeddings: forwarded to `get_similarity_based_hinge_loss`.
        margin: forwarded to `get_similarity_based_hinge_loss`.
        n_epochs: number of passes over the loader.
        learning_rate: forwarded to `optimizers.sgd`.
        batch_size: batch size of the loader.

    Returns:
        The parameters after the last update (`weights` unchanged if the
        loader yields nothing).
    """
    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    opt_state = opt_init(weights)
    loader = torch_util.get_dataloader(np.array(X), np.array(y), batch_size=batch_size)
    loss_fn = get_similarity_based_hinge_loss(label_embeddings, X, y, margin)
    gradient_fn = jax.jit(jax.grad(loss_fn))

    # One global step counter: the previous per-epoch index restarted at 0
    # every epoch, which feeds a wrong step to any schedule-based learning rate.
    step = 0
    for _ in tqdm.tqdm(range(n_epochs)):
        # NOTE(review): `loss_fn` is built from the full (X, y) above, so the
        # loader only determines how many updates are taken; the batch content
        # never reaches the gradient. TODO confirm whether minibatch gradients
        # were intended.
        for _batch in loader:
            grads = gradient_fn(weights)
            opt_state = opt_update(step, grads, opt_state)
            weights = get_params(opt_state)
            step += 1
    return weights
def fit_mixture(data, num_components=3, verbose=False, num_samples=5000) -> LogisticMixtureParams:
    """Fit mixture components to `data` by SGD on the negative log-density.

    Args:
        data: observations; converted with `np.array`.
        num_components: forwarded to `initialize_components`.
        verbose: if true, print the components and log score every 500 steps.
        num_samples: number of SGD steps to run.

    Returns:
        `structure_mixture_params` applied to the final components. If a NaN
        gradient is encountered, the components from before that gradient are
        returned.
    """
    # the data might be something weird, like a pandas dataframe column;
    # turn it into a regular old numpy array
    data_as_np_array = np.array(data)
    step_size = 0.01
    components = initialize_components(num_components)
    (init_fun, update_fun, get_params) = sgd(step_size)
    opt_state = init_fun(components)
    for i in tqdm(range(num_samples)):
        components = get_params(opt_state)
        # Negated because SGD minimises and the log-density is maximised.
        grads = -grad_mixture_logpdf(data_as_np_array, components)
        if np.any(np.isnan(grads)):
            print("Encoutered nan gradient, stopping early")
            print(grads)
            print(components)
            break
        grads = clip_grads(grads, 1.0)
        opt_state = update_fun(i, grads, opt_state)
        if i % 500 == 0 and verbose:
            pprint(components)
            score = mixture_logpdf(data_as_np_array, components)
            print(f"Log score: {score:.3f}")
    else:
        # The loop reads the parameters at the *start* of each iteration, so
        # without this the result of the final update would be thrown away.
        components = get_params(opt_state)
    return structure_mixture_params(components)
def test_optimize_rotation(self):
    """Run 1500 SGD steps on `rmsd.opt_rot_rmsd` and require a result below 0.1."""
    start = onp.array([
        [1.0, 0.2, 3.3],     # H
        [-0.6, -1.1, -0.9],  # C
        [3.4, 5.5, 0.2],     # H
        [3.6, 5.6, 0.6],     # H
    ], dtype=onp.float64)

    # Target: the same coordinates shifted by uniform noise scaled to [0, 10).
    target = start + onp.random.rand(start.shape[0], 3) * 10

    opt_init, opt_update, get_params = optimizers.sgd(5e-2)
    # Differentiate with respect to the first argument only; the result is a
    # one-element tuple.
    rmsd_grad = jax.jit(jax.grad(rmsd.opt_rot_rmsd, argnums=(0,)))

    state = opt_init(start)
    for step in range(1500):
        (gradient,) = rmsd_grad(get_params(state), target)
        state = opt_update(step, gradient, state)

    optimised = get_params(state)
    assert rmsd.opt_rot_rmsd(optimised, target) < 0.1
def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel):
    """Train with SGD at 0.5x and 3x `predict.max_learning_rate`.

    At 0.5x the loss must drop below a tenth of its initial value; at 3x it
    must grow more than tenfold (a NaN ratio is accepted as well).
    """
    key = stateless_uniform(shape=[2], seed=[0, 0],
                            minval=None, maxval=None, dtype=tf.int32)

    keys = tf_random_split(key)
    key, x_seed = keys[0], keys[1]
    if len(train_shape) == 2:
        train_shape = (train_shape[0] * 5, train_shape[1] * 10)
    else:
        train_shape = (16, 8, 8, 3)
    x_train = np.asarray(normal(train_shape, seed=x_seed))

    keys = tf_random_split(key)
    key, y_seed = keys[0], keys[1]
    uniform_draws = stateless_uniform(shape=(train_shape[0], out_logits),
                                      seed=y_seed, minval=0, maxval=1)
    y_train = np.asarray(uniform_draws < 0.5, np.float32)

    # Regress to an MSE loss. `apply_fn` and `get_params` are bound later,
    # inside the loop below; the closures pick them up when called.
    def loss(params, x):
        return 0.5 * np.mean((apply_fn(params, x) - y_train)**2)

    grad_loss = jit(grad(loss))

    def get_loss(opt_state):
        return loss(get_params(opt_state), x_train)

    steps = 20

    for lr_factor in [0.5, 3.]:
        params, apply_fn, ntk = fn_and_kernel(key, train_shape[1:], network,
                                              out_logits)
        g_dd = ntk(x_train, None, 'ntk')

        max_lr = predict.max_learning_rate(g_dd, y_train_size=y_train.size)
        step_size = max_lr * lr_factor
        opt_init, opt_update, get_params = optimizers.sgd(step_size)
        opt_state = opt_init(params)

        init_loss = get_loss(opt_state)

        for step in range(steps):
            params = get_params(opt_state)
            opt_state = opt_update(step, grad_loss(params, x_train), opt_state)

        trained_loss = get_loss(opt_state)
        # The small constant guards against division by an exactly-zero loss.
        loss_ratio = trained_loss / (init_loss + 1e-12)
        if lr_factor != 3.:
            self.assertLess(loss_ratio, 0.1)
        elif not math.isnan(loss_ratio):
            self.assertGreater(loss_ratio, 10.)
def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                        fn_and_kernel, momentum, learning_rate, t, loss):
    """Compare a `predict` gradient-descent predictor with real training.

    The network is trained for `t` optimizer steps and its train/test outputs
    are compared with the predictor evaluated at time `t`.
    """
    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    # Regress to an MSE loss.
    def loss_fn(y, y_hat):
        return 0.5 * np.mean((y - y_hat)**2)

    def train_loss(params, x):
        return loss_fn(f(params, x), y_train)

    grad_loss = jit(grad(train_loss))

    trace_axes = () if g_dd.ndim == 4 else (-1,)

    if loss == 'mse_analytic':
        # This predictor is built without a momentum argument, so momentum
        # cases are skipped.
        if momentum is not None:
            raise absltest.SkipTest(momentum)
        predictor = predict.gradient_descent_mse(
            g_dd, y_train, learning_rate=learning_rate, trace_axes=trace_axes)
    elif loss == 'mse':
        predictor = predict.gradient_descent(
            loss_fn, g_dd, y_train, learning_rate=learning_rate,
            momentum=momentum, trace_axes=trace_axes)
    else:
        raise NotImplementedError(loss)

    predictor = jit(predictor)

    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)

    self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
    self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
    if loss == 'mse_analytic':
        self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

    optimizer = (optimizers.sgd(learning_rate) if momentum is None
                 else optimizers.momentum(learning_rate, momentum))
    opt_init, opt_update, get_params = optimizer

    opt_state = opt_init(params)
    for step in range(t):
        params = get_params(opt_state)
        opt_state = opt_update(step, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)

    fx_train_nn = f(params, x_train)
    fx_test_nn = f(params, x_test)
    fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

    self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
    self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
def testIssue758(self):
    """Reverse- and forward-mode Jacobians through a scanned SGD loop agree."""
    # code from https://github.com/google/jax/issues/758
    # this is more of a scan + jacfwd/jacrev test, but it lives here to use the
    # optimizers.py code

    def harmonic_bond(conf, params):
        return np.sum(conf * params)

    opt_init, opt_update, get_params = optimizers.sgd(5e-2)

    x0 = onp.array([0.5], dtype=onp.float64)

    def minimize_structure(test_params):
        energy_fn = functools.partial(harmonic_bond, params=test_params)
        grad_fn = grad(energy_fn, argnums=(0,))
        opt_state = opt_init(x0)

        def apply_carry(carry, _):
            # The carry is (step index, optimizer state).
            i, x = carry
            g = grad_fn(get_params(x))[0]
            new_state = opt_update(i, g, x)
            new_carry = (i + 1, new_state)
            return new_carry, _

        # The scanned-over array is empty along its second axis; it only
        # fixes the trip count at 75.
        carry_final, _ = lax.scan(apply_carry, (0, opt_state), np.zeros((75, 0)))
        trip, opt_final = carry_final
        assert trip == 75
        return opt_final

    initial_params = np.float64(0.5)
    minimize_structure(initial_params)

    def loss(test_params):
        opt_final = minimize_structure(test_params)
        return 1.0 - get_params(opt_final)[0]

    J1 = jacrev(loss, argnums=(0,))(initial_params)
    J2 = jacfwd(loss, argnums=(0,))(initial_params)
    self.assertAllClose(J1, J2, check_dtypes=True, rtol=1e-6)
def testMaxLearningRate(self, train_shape, network, out_logits,
                        fn_and_kernel, name):
    """Train with SGD at 0.5x and 3x `predict.max_learning_rate`.

    At 0.5x the loss must drop below a tenth of its initial value; at 3x it
    must grow more than tenfold (a NaN ratio is accepted as well).
    """
    key = random.PRNGKey(0)
    key, x_key = random.split(key)
    if len(train_shape) == 2:
        train_shape = (train_shape[0] * 5, train_shape[1] * 10)
    else:
        train_shape = (16, 8, 8, 3)
    x_train = random.normal(x_key, train_shape)

    key, y_key = random.split(key)
    y_train = np.array(
        random.bernoulli(y_key, shape=(train_shape[0], out_logits)),
        np.float32)

    # 'theoretical' passes the real output count; every other name passes -1.
    num_outputs = out_logits if name == 'theoretical' else -1

    for lr_factor in [0.5, 3.]:
        params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

        # Regress to an MSE loss.
        def loss(params, x):
            return 0.5 * np.mean((f(params, x) - y_train) ** 2)

        grad_loss = jit(grad(loss))

        g_dd = ntk(x_train, None, 'ntk')

        steps = 20
        max_lr = predict.max_learning_rate(g_dd, num_outputs=num_outputs)
        step_size = max_lr * lr_factor
        opt_init, opt_update, get_params = optimizers.sgd(step_size)
        opt_state = opt_init(params)

        def get_loss(opt_state):
            return loss(get_params(opt_state), x_train)

        init_loss = get_loss(opt_state)

        for step in range(steps):
            params = get_params(opt_state)
            opt_state = opt_update(step, grad_loss(params, x_train), opt_state)

        trained_loss = get_loss(opt_state)
        # The small constant guards against division by an exactly-zero loss.
        loss_ratio = trained_loss / (init_loss + 1e-12)
        if lr_factor != 3.:
            self.assertLess(loss_ratio, 0.1)
        elif not math.isnan(loss_ratio):
            self.assertGreater(loss_ratio, 10.)
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
    """Compare an ensemble of SGD-trained networks with the NTK prediction.

    Trains `ensemble_size` networks under `vmap`, then checks the empirical
    mean and covariance of their outputs against
    `predict.gradient_descent_mse_ensemble` on the train and test inputs.
    """
    training_steps = 1000
    learning_rate = 0.1
    ensemble_size = 1024

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(128, W_std=1.2, b_std=0.05),
        stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))

    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    opt_update = jit(opt_update)

    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)
    predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, learning_rate=learning_rate, diag_reg=0.)

    train = (x_train, y_train)
    ensemble_key = random.split(key, ensemble_size)

    @jit
    def loss(params, x, y):
        return 0.5 * np.mean((apply_fn(params, x) - y)**2)

    @jit
    def grad_loss(state, x, y):
        return grad(loss)(get_params(state), x, y)

    def train_network(key):
        _, params = init_fn(key, (-1,) + train_shape[1:])
        opt_state = opt_init(params)
        for step in range(training_steps):
            opt_state = opt_update(step, grad_loss(opt_state, *train), opt_state)
        return get_params(opt_state)

    params = vmap(train_network)(ensemble_key)
    rtol = 0.08

    for label in [None, 'x_test']:
        with self.subTest(x=label):
            # `None` makes the predictor use the training inputs.
            on_train = label is None
            x = None if on_train else x_test
            x_fin = x_train if on_train else x_test
            ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

            mean_emp = np.mean(ensemble_fx, axis=0, keepdims=True)
            mean_subtracted = ensemble_fx - mean_emp
            normaliser = mean_subtracted.shape[0] * mean_subtracted.shape[-1]
            cov_emp = np.einsum(
                'ijk,ilk->jl', mean_subtracted, mean_subtracted,
                optimize=True) / normaliser

            ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
            self._assertAllClose(mean_emp, ntk.mean, rtol)
            self._assertAllClose(cov_emp, ntk.covariance, rtol)
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
    """Build the optimizer named by `optimizer` (case-insensitive).

    Args:
        optimizer: 'adagrad', 'adam', 'rmsprop', 'momentum' or 'sgd'.
        sched: forwarded as the first argument of the chosen constructor.
        b1: forwarded to `optimizers.adam`.
        b2: forwarded to `optimizers.adam`.

    Returns:
        Whatever the selected `optimizers.*` constructor returns.

    Raises:
        Exception: if the name is not recognised.
    """
    # Zero-argument builders keep construction lazy: only the selected
    # optimizer is ever instantiated.
    builders = {
        'adagrad': lambda: optimizers.adagrad(sched),
        'adam': lambda: optimizers.adam(sched, b1, b2),
        'rmsprop': lambda: optimizers.rmsprop(sched),
        'momentum': lambda: optimizers.momentum(sched, 0.9),
        'sgd': lambda: optimizers.sgd(sched),
    }
    build = builders.get(optimizer.lower())
    if build is None:
        raise Exception('Invalid optimizer: {}'.format(optimizer))
    return build()
def main(unused_argv):
    """Train a network with SGD and compare it with the analytic NTK prediction.

    Prints a train and a test summary via `util.print_summary`.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
                   batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # The loop reads `params` before each update, so refresh once more here;
    # otherwise the summary would evaluate the network one step short of
    # `train_steps`.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test,
                       loss)
def optimizer(name="adam",
              momentum_mass=0.9,
              rmsprop_gamma=0.9,
              rmsprop_eps=1e-8,
              adam_b1=0.9,
              adam_b2=0.997,
              adam_eps=1e-8):
    """Return the optimizer, by name.

    `learning_rate` is not a parameter; it is read from the enclosing scope.

    Args:
        name: "sgd", "momentum", "rmsprop" or "adam".
        momentum_mass: `mass` for `optimizers.momentum`.
        rmsprop_gamma: `gamma` for `optimizers.rmsprop`.
        rmsprop_eps: `eps` for `optimizers.rmsprop`.
        adam_b1: `b1` for `optimizers.adam`.
        adam_b2: `b2` for `optimizers.adam`.
        adam_eps: `eps` for `optimizers.adam`.

    Raises:
        ValueError: if `name` is not one of the supported names.
    """
    if name == "sgd":
        chosen = optimizers.sgd(learning_rate)
    elif name == "momentum":
        chosen = optimizers.momentum(learning_rate, mass=momentum_mass)
    elif name == "rmsprop":
        chosen = optimizers.rmsprop(
            learning_rate, gamma=rmsprop_gamma, eps=rmsprop_eps)
    elif name == "adam":
        chosen = optimizers.adam(
            learning_rate, b1=adam_b1, b2=adam_b2, eps=adam_eps)
    else:
        raise ValueError("Unknown optimizer %s" % str(name))
    return chosen
def testTracedStepSize(self):
    """`optimizers.sgd` can be built inside `jit` from a traced step size."""
    squared_norm = lambda x: jnp.dot(x, x)

    initial_step_size = 0.1
    init_fun = optimizers.sgd(initial_step_size)[0]
    opt_state = init_fun(jnp.ones(2))

    @jit
    def update(opt_state, step_size):
        # `step_size` is a tracer here; constructing the optimizer from it is
        # what this test exercises.
        _, update_fun, get_params = optimizers.sgd(step_size)
        gradient = grad(squared_norm)(get_params(opt_state))
        return update_fun(0, gradient, opt_state)

    update(opt_state, 0.9)  # doesn't crash
def main(unused_argv):
    """Train a network with SGD and compare it with the analytic NTK prediction.

    Prints a train and a test summary via `util.print_summary`.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    # (A leftover `ipdb.set_trace()` breakpoint used to sit between the two
    # kernel computations; it halted every run and has been removed.)
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # The loop reads `params` before each update, so refresh once more here;
    # otherwise the summary would evaluate the network one step short of
    # `train_steps`.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
def testTracedStepSize(self):
    """`optimizers.sgd` can be built inside `jit` from a traced step size."""

    def loss(x, _):
        return np.dot(x, x)

    x0 = np.ones(2)
    step_size = 0.1

    init_fun, _ = optimizers.sgd(step_size)
    opt_state = init_fun(x0)

    @jit
    def update(opt_state, step_size):
        # `step_size` is a tracer here; constructing the optimizer from it is
        # what this test exercises.
        _, update_fun = optimizers.sgd(step_size)
        x = optimizers.get_params(opt_state)
        g = grad(loss)(x, None)
        return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash
def relaxed_gd(X, phi_y, W_0, lr, beta_, lambda_, tol, N_max, verbose=True):
    """Run gradient descent on relaxed_loss.

    Args:
        X: inputs; `X.shape[1]` is the number of features.
        phi_y: targets; `phi_y.shape[1]` is the output dimension.
        W_0: initial predictor of shape `(phi_y.shape[1], X.shape[1])`.
        lr: step size forwarded to `optimizers.sgd`.
        beta_: forwarded to `relaxed_predict_loss`.
        lambda_: forwarded to `relaxed_predict_loss`.
        tol: tolerance compared with the last change in loss. It is only
            reported in `status["tol_max"]`; it does not stop the loop.
        N_max: number of gradient steps to run.
        verbose: if true, print progress and the final status.

    Returns:
        `(params, status)`: the final parameters and a dict with the keys
        "max_steps", "tol_max", "diff", "step_index" and "final_value".
        "final_value" is None when no step was taken (`N_max <= 0`).

    Raises:
        AssertionError: if `W_0` does not have the expected shape.
    """
    print(f"Lambda {lambda_}")
    dim_H_ = phi_y.shape[1]
    n_features_ = X.shape[1]
    predictor_shape = (
        dim_H_,
        n_features_,
    )
    assert W_0.shape == predictor_shape

    opt_init, opt_update, get_params = optimizers.sgd(lr)
    opt_state = opt_init(W_0)

    def step(opt_state, X, phi_y, beta, lambda_, i):
        value, grads = jax.value_and_grad(relaxed_predict_loss)(
            get_params(opt_state), X, phi_y, beta, lambda_)
        if (i < 10 or i % 100 == 0) and verbose:
            print(f"Step {i} value: \t\t{value.item()}")
        opt_state = opt_update(i, grads, opt_state)
        return value, opt_state

    old_value = jnp.inf
    diff = jnp.inf
    value = None
    step_index = 0
    # NOTE(review): the loop always runs N_max steps, so "max_steps" below is
    # always true for N_max >= 0 -- confirm whether stopping on `tol` was
    # intended.
    while step_index < N_max:
        value, opt_state = step(opt_state, X, phi_y, beta_, lambda_, step_index)
        diff = abs(old_value - value)
        old_value = value
        step_index += 1

    # float() instead of .item(): with no steps taken `diff` is still the plain
    # Python float it was initialised to, which has no .item(), and `value`
    # was never assigned.
    diff_value = float(diff)
    status = {
        "max_steps": step_index >= N_max,
        "tol_max": tol >= diff_value,
        "diff": diff_value,
        "step_index": step_index,
        "final_value": None if value is None else float(value),
    }
    if verbose:
        print(status)
    return get_params(opt_state), status
def privately_train(rng, params, predict, X, y):
    """Generic train function called for each slice.

    Fine-tunes `params` on `X` and `y` with clipped, noised SGD. The settings
    `l2_norm_clip`, `noise_multiplier`, `batch_size`, `step_size` and
    `iterations` are taken from `private_training_parameters`.

    Args:
        rng: PRNG key; split once per iteration.
        params: parameters to be trained.
        predict: forwarded to `loss`.
        X: training inputs.
        y: training outputs.

    Returns:
        The parameters after `iterations` private updates.
    """
    def _setting(name):
        """Read one training setting, falling back to a module global."""
        # `locals().update(private_training_parameters)` was used here before,
        # but updating locals() inside a function creates no local or closure
        # variables, so the settings were only ever resolved as module globals.
        # The fallback keeps any setup that relied on such globals working.
        try:
            return private_training_parameters[name]
        except KeyError:
            return globals()[name]

    l2_norm_clip = _setting('l2_norm_clip')
    noise_multiplier = _setting('noise_multiplier')
    batch_size = _setting('batch_size')
    step_size = _setting('step_size')
    iterations = _setting('iterations')

    def clipped_grad(params, l2_norm_clip, single_example_batch):
        """Gradient for one example, scaled so its norm is at most the clip."""
        grads = grad(loss)(params, predict, single_example_batch)

        nonempty_grads, tree_def = tree_flatten(grads)
        total_grad_norm = np.linalg.norm(
            [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
        # A divisor of at least 1 leaves gradients inside the clip untouched.
        divisor = np.max((total_grad_norm / l2_norm_clip, 1.))
        normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
        return tree_unflatten(tree_def, normalized_nonempty_grads)

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Sum per-example clipped gradients, add noise, divide by batch size."""
        # Add batch dimension for when each example is separated
        batch = (np.expand_dims(batch[0], 1), np.expand_dims(batch[1], 1))

        clipped_grads = vmap(clipped_grad, (None, None, 0))(
            params, l2_norm_clip, batch)
        clipped_grads_flat, grads_treedef = tree_flatten(clipped_grads)
        aggregated_clipped_grads = [np.sum(g, 0) for g in clipped_grads_flat]

        # One independent key per gradient leaf.
        rngs = random.split(rng, len(aggregated_clipped_grads))
        noised_aggregated_clipped_grads = [
            g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape)
            for r, g in zip(rngs, aggregated_clipped_grads)]
        normalized_noised_aggregated_clipped_grads = [
            g / batch_size for g in noised_aggregated_clipped_grads]
        return tree_unflatten(grads_treedef,
                              normalized_noised_aggregated_clipped_grads)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)
        grads = private_grad(params, batch, rng, l2_norm_clip,
                             noise_multiplier, batch_size)
        return opt_update(i, grads, opt_state)

    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    opt_state = opt_init(params)

    batches = data_stream(rng, batch_size, X, y)
    itercount = itertools.count()

    for _ in range(iterations):
        temp, rng = random.split(rng)
        opt_state = private_update(temp, next(itercount), opt_state,
                                   next(batches))

    return get_params(opt_state)
def run_optimiser(Niters, l_rate, x_data, y_data, params_IC):
    """Run `Niters` optimizer steps inside `lax.scan`.

    The optimizer is selected by `optimiser_type` from the enclosing scope
    ("sgd" or "adam").

    Args:
        Niters: number of scan iterations.
        l_rate: step size forwarded to the optimizer constructor.
        x_data: forwarded to `val_and_grad_loss`.
        y_data: forwarded to `val_and_grad_loss`.
        params_IC: initial parameters.

    Returns:
        `(params, loss_array)`: the final parameters and the per-step losses
        stacked by the scan.

    Raises:
        ValueError: if `optimiser_type` is neither "sgd" nor "adam".
    """
    if optimiser_type not in ("sgd", "adam"):
        raise ValueError("Optimiser not added.")
    make_optimizer = optimizers.sgd if optimiser_type == "sgd" else optimizers.adam
    opt_init, opt_update, get_params = make_optimizer(l_rate)

    @progress_bar_scan(Niters)
    def body(state, step):
        # The loss is evaluated at the parameters *before* this step's update.
        loss_val, loss_grad = val_and_grad_loss(get_params(state), x_data, y_data)
        return opt_update(step, loss_grad, state), loss_val

    initial_state = opt_init(params_IC)
    final_state, loss_array = lax.scan(body, initial_state, jnp.arange(Niters))
    return get_params(final_state), loss_array
def test_sgd(self):
    """`optimizers.sgd` and `optix.sgd` yield the same parameters."""
    fixed_updates = self.per_step_updates

    # experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = optimizers.sgd(LR)
    jax_state = opt_init(jax_params)
    for step in range(STEPS):
        jax_state = opt_update(step, fixed_updates, jax_state)
        jax_params = get_params(jax_state)

    # experimental/optix.py
    optix_params = self.init_params
    optix_sgd = optix.sgd(LR, 0.0)
    optix_state = optix_sgd.init(optix_params)
    for _ in range(STEPS):
        updates, optix_state = optix_sgd.update(fixed_updates, optix_state)
        optix_params = optix.apply_updates(optix_params, updates)

    # Check equivalence.
    leaf_pairs = zip(tree_leaves(jax_params), tree_leaves(optix_params))
    for expected, actual in leaf_pairs:
        np.testing.assert_allclose(expected, actual, rtol=1e-5)
def get_optimizer(self, optim=None, stage='learn', step_size=None):
    """Return the `(opt_init, opt_update, get_params)` triple for `optim`.

    Args:
        optim: optimizer code. 1 = momentum, 2 = rmsprop, 3 = adagrad,
            4 = nesterov, 5 = sgd; any other value selects adam. When None,
            `self.optim_learn` is used for stage 'learn' and
            `self.optim_proj` otherwise.
        stage: only consulted when `optim` is None.
        step_size: optimizer step size; defaults to `self.step_size`.

    Returns:
        The optimizer triple. The choice is printed when `self.verb > 2`.
    """
    if optim is None:
        optim = self.optim_learn if stage == 'learn' else self.optim_proj
    if step_size is None:
        step_size = self.step_size

    # Pick the announcement and a lazy constructor first; print and build once.
    if optim == 1:
        announcement = "With momentum optimizer"
        build = lambda: momentum(step_size=step_size, mass=0.95)
    elif optim == 2:
        announcement = "With rmsprop optimizer"
        build = lambda: rmsprop(step_size, gamma=0.9, eps=1e-8)
    elif optim == 3:
        announcement = "With adagrad optimizer"
        build = lambda: adagrad(step_size, momentum=0.9)
    elif optim == 4:
        announcement = "With Nesterov optimizer"
        build = lambda: nesterov(step_size, 0.9)
    elif optim == 5:
        announcement = "With SGD optimizer"
        build = lambda: sgd(step_size)
    else:
        announcement = "With adam optimizer"
        build = lambda: adam(step_size)

    if self.verb > 2:
        print(announcement)
    opt_init, opt_update, get_params = build()
    return opt_init, opt_update, get_params
def gamma2kappa(g1, g2, kappa_shape, obj, step_size=0.01, n_iter=1000):
    """Minimise `obj` over a pair of maps with plain SGD.

    Both maps start as zeros of shape `kappa_shape`.

    Args:
        g1: forwarded to `obj` as its second argument.
        g2: forwarded to `obj` as its third argument.
        kappa_shape: shape of each of the two optimised arrays.
        obj: scalar objective called as `obj((kE, kB), g1, g2)`.
        step_size: SGD step size. This argument used to be ignored in favour
            of a hard-coded 0.001; it is now honoured.
        n_iter: number of SGD steps.

    Returns:
        `(kEhat, kBhat)`: the two arrays after `n_iter` steps.
    """
    # Set up optimizer
    init_params = (np.zeros(kappa_shape), np.zeros(kappa_shape))
    opt_init, opt_update, get_params = optimizers.sgd(step_size=step_size)
    opt_state = opt_init(init_params)

    # Built once; differentiates with respect to the (kE, kB) tuple.
    grad_obj = grad(obj)

    @jit
    def update(i, opt_state):
        params = get_params(opt_state)
        return opt_update(i, grad_obj(params, g1, g2), opt_state)

    # Loop
    for t in range(n_iter):
        opt_state = update(t, opt_state)

    kEhat, kBhat = get_params(opt_state)
    return kEhat, kBhat
def get_optimizer(
    learning_rate: float = 1e-4, optimizer="sdg", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer

    Calling this function enables `jax_enable_x64` in the JAX config.

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type, case-insensitive
            (Allowed types: "adam", "adamax", "adagrad", "rmsprop", "sgd").
            The historical spelling "sdg" is accepted as an alias for "sgd"
            and remains the default. Any other name falls back to SGD and
            emits a warning.
        optimizer_kwargs (dict, optional): Additional
            keyword arguments that are passed to the optimizer.
            Defaults to None.

    Returns:
        JaxOptimizer
    """
    import warnings  # pylint:disable=import-outside-toplevel

    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    optimizer = optimizer.lower()

    # Constructor names rather than constructors, so only the selected
    # attribute is looked up on the `optimizers` module.
    constructor_names = {
        "adam": "adam",
        "adagrad": "adagrad",
        "adamax": "adamax",
        "rmsprop": "rmsprop",
        "sgd": "sgd",
        "sdg": "sgd",
    }
    if optimizer not in constructor_names:
        # Unknown names used to select SGD silently, which hid typos such as
        # "adm". The fallback is kept for backward compatibility.
        warnings.warn(
            f"Unknown optimizer {optimizer!r}; falling back to SGD.",
            stacklevel=2,
        )
    make_optimizer = getattr(optimizers, constructor_names.get(optimizer, "sgd"))

    opt_init, opt_update, get_params = make_optimizer(learning_rate, **optimizer_kwargs)
    opt_update = jit(opt_update)
    return JaxOptimizer(opt_init, opt_update, get_params)
def run_vmc(steps, step_size, diag_shift, n_samples):
    """Build an `nk.Vmc` driver and report every iteration through `output`.

    Args:
        steps: forwarded to `vmc.iter`.
        step_size: step size for `jaxopt.sgd`.
        diag_shift: forwarded to `nk.optimizer.SR`.
        n_samples: samples per step; a tenth of it, capped at 200, is used
            as `n_discard`.
    """
    sr = nk.optimizer.SR(lsq_solver="BDCSVD", diag_shift=diag_shift)
    sr.store_rank_enabled = False  # not supported by BDCSVD
    sr.store_covariance_matrix_enabled = True

    # The original kept `nk.optimizer.Sgd(step_size)` as a commented-out
    # alternative to this optimizer.
    opt = jaxopt.sgd(step_size)
    n_discard = min(n_samples // 10, 200)

    vmc = nk.Vmc(
        hamiltonian=ha,
        sampler=sa,
        optimizer=opt,
        n_samples=n_samples,
        n_discard=n_discard,
        sr=sr,
    )

    # Only rank 0 prints the run description and the header.
    if mpi_rank == 0:
        print(vmc.info())
        print(HEADER_STRING)

    for step in vmc.iter(steps, 1):
        output(vmc, step)
def test_validate_optimizers():
    """Make sure we correctly check/standardize the optimizers."""
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    from pyepal.models.nt import JaxOptimizer  # pylint:disable=import-outside-toplevel

    opt_init, opt_update, get_params = optimizers.sgd(1e-3)
    opt_update = jit(opt_update)

    single_optimizer = JaxOptimizer(opt_init, opt_update, get_params)
    optimizer_pair = [
        JaxOptimizer(opt_init, opt_update, get_params)
        for _ in range(2)
    ]

    # A bare init function is rejected.
    with pytest.raises(ValueError):
        validate_optimizers(opt_init, 2)

    # A one-element list is rejected when two are requested.
    with pytest.raises(ValueError):
        validate_optimizers([single_optimizer], 2)

    # A two-element list is returned as-is.
    assert validate_optimizers(optimizer_pair, 2) == optimizer_pair

    # A single, unwrapped optimizer is rejected.
    with pytest.raises(ValueError):
        validate_optimizers(single_optimizer, 2)
def update(opt_state, step_size):
    """Apply one SGD step of size `step_size` to `opt_state`.

    The optimizer is rebuilt on every call, and the step index passed to it is
    always 0.
    """
    update_fun, get_params = optimizers.sgd(step_size)[1:]
    current = get_params(opt_state)
    gradient = grad(loss)(current)
    return update_fun(0, gradient, opt_state)
st.pyplot()
plt.close()

st.header("Comparison against finite SGD-NNs")

""" Finally, let's bring back the practioner loved Neural Networks for a comparison. """

learning_rate = st.slider("Learning rate", 1e-4, 1.0, 0.1, step=1e-4, format="%.4f")

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_update = jit(opt_update)


# Mean-squared-error loss, halved.
@jit
def loss(params, x, y):
    return 0.5 * np.mean((apply_fn(params, x) - y)**2)


# Gradient of the loss, taken at the parameters stored in an optimizer state.
@jit
def grad_loss(state, x, y):
    return grad(loss)(get_params(state), x, y)


train_losses = []
test_losses = []

opt_state = opt_init(params)

# Record both losses after every update.
for i in range(training_steps):
    opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    current_params = get_params(opt_state)
    train_losses.append(loss(current_params, *train))
    test_losses.append(loss(current_params, *test))
def main(_):
    """Fine-tune a checkpointed model on selected pool points and log results.

    Creates `FLAGS.work_dir`, opens `stdout.log` inside it and delegates the
    actual work to `_finetune_and_evaluate`. The log is closed even if the
    run fails.
    """
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        # Kept as best-effort: presumably the directory already exists --
        # TODO confirm this is the only error worth ignoring here.
        pass

    log_path = '{}/stdout.log'.format(FLAGS.work_dir)
    with gfile.Open(log_path, 'w+') as stdout_log:
        _finetune_and_evaluate(stdout_log)


def _finetune_and_evaluate(stdout_log):
    """Load data and checkpoint, pick pool points, fine-tune and evaluate.

    Args:
        stdout_log: writable, already-open log file.

    Raises:
        ValueError: if `FLAGS.uncertain` names no known selection strategy.
    """
    # use mean/std of svhn train
    train_images, _, _ = datasets.get_dataset_split(
        name=FLAGS.train_split.split('-')[0],
        split=FLAGS.train_split.split('-')[1],
        shuffle=False)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    del train_images

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    test_images = (test_images - train_mu) / train_std  # normalize w train mu/std
    pool_images = (pool_images - train_mu) / train_std  # normalize w train mu/std

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4),
                                             data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    # BEGIN: load ckpt
    opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    # NOTE: pickle.load can execute arbitrary code; only point these flags at
    # trusted checkpoint files.
    fixed_params = None
    if FLAGS.pretrained_dir is not None:
        with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
            pretrained_opt_state = optimizers.pack_optimizer_state(
                pickle.load(fpre))
        fixed_params = get_params(pretrained_opt_state)[:7]

    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    # Opened for binary reading. The previous mode string 'wr' is not a valid
    # read mode, although the file is only ever read (unpickled) here.
    with gfile.Open(ckpt_dir, 'rb') as fckpt:
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    if fixed_params is not None:
        # combine fixed pretrained params and dpsgd trained last layers
        params = fixed_params + params
        opt_state = opt_init(params)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(
        params, shape_as_image(test_images, test_labels),
        return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size),
            opt_state)
    # END: setup for dp model

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    pool_embeddings = [
        apply_fn_0(params[:-1], pool_images[b_i:b_i + FLAGS.batch_size])
        for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_extra most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        stdout_log.write('all {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff), onp.min(pool_probs_diff),
            onp.max(pool_probs_diff)))

        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
        stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff[pool_uncertain_indices]),
            onp.min(pool_probs_diff[pool_uncertain_indices]),
            onp.max(pool_probs_diff[pool_uncertain_indices])))

    elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
        pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]

    else:
        # Previously fell through and failed later with a NameError on
        # `pool_uncertain_indices`.
        raise ValueError(
            'Unknown FLAGS.uncertain value: {!r}'.format(FLAGS.uncertain))
    # END: on pool data
    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s',
                 len(finetune_images), FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)

    # Created once for the whole run. Resetting the key and the step counter
    # every epoch made `fold_in(key, i)` produce the same noise keys in every
    # epoch of DP-SGD.
    itercount = itertools.count()
    key = random.PRNGKey(FLAGS.seed)

    for epoch in range(1, FLAGS.epochs + 1):
        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images, Y=finetune_labels,
                                  image_size=32, image_channels=3,
                                  label_dim=1, label_format='numeric')

        batches = data.minibatcher(finetune, FLAGS.batch_size,
                                   transform=augmentation)

        start_time = time.time()

        for _ in range(num_batches):
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        # Keep the previous epoch's predictions for the visual diff below.
        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(
            params, shape_as_image(test_images, test_labels),
            return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)',
                     test_loss, 100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=train_mu,
                                            sigma=train_std)
        # END: finetune model with extra data, evaluate and save
def main(_):
  """Trains an MNIST classifier with SGD or DP-SGD and prints test metrics.

  Every epoch performs `num_batches` optimizer steps on freshly permuted
  training data, then prints the test loss and accuracy. When `FLAGS.dpsgd`
  is set, the epsilon reported by `compute_epsilon` for the number of steps
  taken so far is printed as well.

  Args:
    _: Unused positional argument (presumably argv from the app runner).

  Raises:
    NotImplementedError: If `FLAGS.microbatches` is set.
  """
  if FLAGS.microbatches:
    raise NotImplementedError(
        'Microbatches < batch size not currently supported'
    )

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  # Ceiling division: a trailing partial batch still counts as one step.
  num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
  num_batches = num_complete_batches + bool(leftover)
  key = random.PRNGKey(FLAGS.seed)

  def data_stream():
    """Yields (images, labels) minibatches forever, reshuffling each pass."""
    rng = npr.RandomState(FLAGS.seed)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  @jit
  def update(_, i, opt_state, batch):
    """Applies one plain SGD step; the first argument (a key) is ignored."""
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    """Applies one step using the gradient returned by `private_grad`."""
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

  _, init_params = init_random_params(key, (-1, 28, 28, 1))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  print('\nStarting training...')
  for epoch in range(1, FLAGS.epochs + 1):
    start_time = time.time()
    # pylint: disable=no-value-for-parameter
    for _ in range(num_batches):
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(*next(batches), dummy_dim=True))
      else:
        opt_state = update(
            key, next(itercount), opt_state, shape_as_image(*next(batches)))
    # pylint: enable=no-value-for-parameter
    epoch_time = time.time() - start_time
    print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))

    # evaluate test accuracy
    params = get_params(opt_state)
    test_acc = accuracy(params, shape_as_image(test_images, test_labels))
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
        test_loss, 100 * test_acc))

    # determine privacy loss so far
    if FLAGS.dpsgd:
      delta = 1e-5
      # Account for the steps actually executed: each epoch runs `num_batches`
      # updates (the trailing partial batch included) over `num_train`
      # examples, so neither value is hard-coded here.
      eps = compute_epsilon(epoch * num_batches, num_train, delta)
      print(
          'For delta={:.0e}, the current epsilon is: {:.2f}'.format(delta, eps))
    else:
      print('Trained with vanilla non-private SGD optimizer')