def test_regularized_training(self):
    """Check that an L1 penalty added to the training loss improves recovery.

    A sparse weight matrix is recovered from fewer examples than dimensions.
    The plain least-squares estimate is expected to have a large error, while
    the estimate trained with the L1 regularization penalty is expected to be
    close to the true weights.
    """
    np.random.seed(0)

    # Problem: recover w from x and
    #   y = x . w + noise
    # under the prior that w is sparse. x is a wide matrix (fewer examples
    # than dimensions), so without the sparsity assumption the problem is
    # underdetermined.
    num_examples, num_dim = 8, 10
    features = np.random.randn(num_examples, num_dim).astype(np.float32)

    true_w = np.zeros((num_dim, 2), np.float32)
    true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
    true_w[[3, 5], 1] = [4.0, 5.0]

    noise = np.random.randn(num_examples, 2)
    targets = np.dot(features, true_w) + 1e-3 * noise

    # Baseline: the least squares estimate for w. It isn't very accurate.
    least_squares_w, *_ = np.linalg.lstsq(features, targets, rcond=None)
    least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

    # Better estimate: solve the L1 regularized problem
    #   argmin_w ||x . w - y||_2^2 + c ||w||_1.
    def w_regularizer(w):
        """L1 penalty on the weights, scaled by 4.0."""
        return 4.0 * hk_util.l1_loss(w)

    def model_fun(batch):
        """Bias-free linear map with two outputs and the L1 regularizer."""
        layer = hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)
        return layer(batch['x'])

    model = hk.transform(model_fun)

    def loss_fun(params, batch):
        """Training loss with L1 regularization penalty term."""
        y_predicted, penalties = model.apply(params, batch)
        residual = y_predicted - batch['y']
        return hk_util.l2_loss(residual) + penalties

    batch = {'x': features, 'y': targets}
    params = model.init(jax.random.PRNGKey(0), batch)

    def step_size(i):
        """Negative step size that shrinks as 1 / sqrt(1 + i)."""
        return -0.05 / jnp.sqrt(1 + i)

    # Gradient descent with decreasing learning rate.
    optimizer = optix.chain(
        optix.trace(decay=0.0, nesterov=False),
        optix.scale_by_schedule(step_size),
    )
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
        """Apply one optimizer update and return the new params and state."""
        grads = jax.grad(loss_fun)(params, batch)
        updates, next_opt_state = optimizer.update(grads, opt_state)
        return optix.apply_updates(params, updates), next_opt_state

    for _ in range(1000):
        params, opt_state = train_step(params, opt_state, batch)

    l1_w = params['linear']['w']
    l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

    # The L1-regularized estimate is much more accurate.
    self.assertGreater(least_squares_w_error, 4.0)
    self.assertLess(l1_w_error, 1.0)
def test_apply_every(self):
    """Check `apply_every(k)` chained with SGD against plain SGD.

    Both optimizers start from `self.init_params` and are fed
    `self.per_step_updates` on every one of `STEPS` iterations. The test
    asserts that:

    * on every k-th step (``i % k == k - 1``) the parameters produced by the
      `apply_every` chain match the plain SGD parameters, and
    * on all other steps the `apply_every` chain emits an all-zero update.
    """
    # The frequency of the application of sgd.
    k = 4
    zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

    # experimental/optix.py sgd
    optix_sgd_params = self.init_params
    sgd = optix.sgd(LR, 0.0)
    state_sgd = sgd.init(optix_sgd_params)

    # experimental/optix.py sgd apply every
    optix_sgd_apply_every_params = self.init_params
    sgd_apply_every = optix.chain(
        optix.apply_every(k=k),
        optix.trace(decay=0, nesterov=False),
        optix.scale(-LR))
    state_sgd_apply_every = sgd_apply_every.init(
        optix_sgd_apply_every_params)

    for i in range(STEPS):
        # Apply a step of sgd.
        updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd)
        optix_sgd_params = optix.apply_updates(optix_sgd_params, updates_sgd)

        # Apply a step of sgd_apply_every.
        updates_sgd_apply_every, state_sgd_apply_every = (
            sgd_apply_every.update(
                self.per_step_updates, state_sgd_apply_every))
        optix_sgd_apply_every_params = optix.apply_updates(
            optix_sgd_apply_every_params, updates_sgd_apply_every)

        if i % k == k - 1:
            # Check equivalence. The relative tolerance was previously 100,
            # which allows an error of 100 * |expected| and therefore made
            # this comparison pass for almost any value; use a meaningful
            # tolerance so a real mismatch is detected.
            for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                            tree_leaves(optix_sgd_params)):
                np.testing.assert_allclose(x, y, atol=1e-6, rtol=1e-5)
        else:
            # Check update is zero.
            for x, y in zip(tree_leaves(updates_sgd_apply_every),
                            tree_leaves(zero_update)):
                np.testing.assert_allclose(x, y, atol=1e-10, rtol=1e-5)
def make_optimizer():
    """Build the flag-configured SGD optimizer with a custom lr schedule.

    The chain consists of, in order: a momentum trace whose decay and
    nesterov setting come from `FLAGS.optimizer_momentum` and
    `FLAGS.optimizer_use_nesterov`, scaling by the module-level
    `lr_schedule`, and a final scaling by -1.
    """
    momentum = optix.trace(
        decay=FLAGS.optimizer_momentum,
        nesterov=FLAGS.optimizer_use_nesterov,
    )
    scheduled = optix.scale_by_schedule(lr_schedule)
    negate = optix.scale(-1)
    return optix.chain(momentum, scheduled, negate)
def make_optimizer(lr_schedule, momentum_decay):
    """Chain a momentum trace, a schedule-based scaling and a sign flip.

    Args:
      lr_schedule: passed to `optix.scale_by_schedule`.
      momentum_decay: passed as `decay` to `optix.trace` (nesterov disabled).

    Returns:
      The result of `optix.chain` over the three transformations.
    """
    transforms = (
        optix.trace(decay=momentum_decay, nesterov=False),
        optix.scale_by_schedule(lr_schedule),
        optix.scale(-1),
    )
    return optix.chain(*transforms)
def make_optimizer(lr_schedule, momentum_decay):
    """Make SGD optimizer with momentum.

    Args:
      lr_schedule: passed to `optix.scale_by_schedule`.
      momentum_decay: passed as `decay` to `optix.trace` (nesterov disabled).

    Returns:
      The result of `optix.chain` over the momentum trace and the schedule
      scaling.
    """
    momentum = optix.trace(decay=momentum_decay, nesterov=False)
    scheduled = optix.scale_by_schedule(lr_schedule)
    # Maximize log-prob instead of minimizing loss: the chain has no
    # trailing `scale(-1)`, so updates are not negated here.
    return optix.chain(momentum, scheduled)