def test_regularized_training(self):
  """Test that adding regularization penalty to the training loss works."""
  np.random.seed(0)
  # Set up the problem of recovering w given x and
  #   y = x . w + noise
  # with the a priori assumption that w is sparse. There are fewer examples
  # than dimensions (x is a wide matrix), so the problem is underdetermined
  # without the sparsity assumption.
  num_examples, num_dim = 8, 10
  x = np.random.randn(num_examples, num_dim).astype(np.float32)
  true_w = np.zeros((num_dim, 2), np.float32)
  true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
  true_w[[3, 5], 1] = [4.0, 5.0]
  y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

  # Get the least squares estimate for w. It isn't very accurate.
  least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
  least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

  # Get a better estimate by solving the L1 regularized problem
  #   argmin_w ||x . w - y||_2^2 + c ||w||_1.
  w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)

  def model_fun(batch):
    x = batch['x']
    return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

  model = hk.transform(model_fun)

  def loss_fun(params, batch):
    """Training loss with L1 regularization penalty term."""
    y_predicted, penalties = model.apply(params, batch)
    return hk_util.l2_loss(y_predicted - batch['y']) + penalties

  batch = {'x': x, 'y': y}
  params = model.init(jax.random.PRNGKey(0), batch)
  optimizer = optix.chain(  # Gradient descent with decreasing learning rate.
      optix.trace(decay=0.0, nesterov=False),
      optix.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
  opt_state = optimizer.init(params)

  @jax.jit
  def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fun)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optix.apply_updates(params, updates)
    return new_params, opt_state

  for _ in range(1000):
    params, opt_state = train_step(params, opt_state, batch)

  l1_w = params['linear']['w']
  l1_w_error = hk_util.l2_loss(l1_w - true_w).item()
  # The L1-regularized estimate is much more accurate.
  self.assertGreater(least_squares_w_error, 4.0)
  self.assertLess(l1_w_error, 1.0)
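
# For comparison with the make_optimizer variants below: the test above folds
# the descent sign into the schedule (-0.05 / jnp.sqrt(1 + i)). An equivalent
# chain keeps the schedule positive and appends optix.scale(-1), which is the
# pattern the following snippets use. A minimal sketch, not from the source:
optimizer = optix.chain(
    optix.trace(decay=0.0, nesterov=False),
    optix.scale_by_schedule(lambda i: 0.05 / jnp.sqrt(1 + i)),
    optix.scale(-1))
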
def make_optimizer():
  """SGD with momentum (optionally Nesterov) and a custom lr schedule."""
  return optix.chain(
      optix.trace(decay=FLAGS.optimizer_momentum,
                  nesterov=FLAGS.optimizer_use_nesterov),
      optix.scale_by_schedule(lr_schedule),
      optix.scale(-1))
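
# This variant reads its hyperparameters from absl flags and a module-level
# lr_schedule defined elsewhere in the source. A minimal sketch of the flag
# definitions it assumes; the default values here are illustrative, not taken
# from the source:
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_float('optimizer_momentum', 0.9, 'Momentum decay rate.')
flags.DEFINE_bool('optimizer_use_nesterov', True,
                  'Whether to use Nesterov momentum.')
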
def make_optimizer(lr_schedule, momentum_decay):
  """Make SGD optimizer with momentum and a learning-rate schedule."""
  return optix.chain(
      optix.trace(decay=momentum_decay, nesterov=False),
      optix.scale_by_schedule(lr_schedule),
      optix.scale(-1))

def make_optimizer(lr_schedule, momentum_decay):
  """Make SGD optimizer with momentum."""
  # No optix.scale(-1) here: updates are added as-is, so each step ascends
  # the objective. Use this to maximize log-prob instead of minimizing loss.
  return optix.chain(
      optix.trace(decay=momentum_decay, nesterov=False),
      optix.scale_by_schedule(lr_schedule))
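
# A minimal usage sketch for the ascent-style optimizer above, mirroring the
# train_step pattern from the test; log_prob_fun, params, and batch are
# hypothetical stand-ins, not names from the source:
optimizer = make_optimizer(lr_schedule=lambda i: 0.1 / jnp.sqrt(1 + i),
                           momentum_decay=0.9)
opt_state = optimizer.init(params)

@jax.jit
def ascent_step(params, opt_state, batch):
  grads = jax.grad(log_prob_fun)(params, batch)
  # Positive schedule, no scale(-1): apply_updates adds the scaled gradient,
  # increasing log_prob_fun rather than decreasing it.
  updates, opt_state = optimizer.update(grads, opt_state)
  return optix.apply_updates(params, updates), opt_state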