Example #1
  def test_regularized_training(self):
    """Test that adding regularization penalty to the training loss works."""
    np.random.seed(0)
    # Set up the problem of recovering w given x and
    #   y = x . w + noise
    # with the a priori assumption that w is sparse. There are fewer examples
    # than dimensions (x is a wide matrix), so the problem is underdetermined
    # without the sparsity assumption.
    num_examples, num_dim = 8, 10
    x = np.random.randn(num_examples, num_dim).astype(np.float32)
    true_w = np.zeros((num_dim, 2), np.float32)
    true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
    true_w[[3, 5], 1] = [4.0, 5.0]
    y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

    # Get the least squares estimate for w. It isn't very accurate.
    least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
    least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

    # Get a better estimate by solving the L1 regularized problem
    #  argmin_w ||x . w - y||_2^2 + c ||w||_1.
    w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)
    def model_fun(batch):
      x = batch['x']
      return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

    model = hk.transform(model_fun)

    def loss_fun(params, batch):
      """Training loss with L1 regularization penalty term."""
      y_predicted, penalties = model.apply(params, batch)
      return hk_util.l2_loss(y_predicted - batch['y']) + penalties

    batch = {'x': x, 'y': y}
    params = model.init(jax.random.PRNGKey(0), batch)
    optimizer = optix.chain(  # Gradient descent with decreasing learning rate.
        optix.trace(decay=0.0, nesterov=False),
        optix.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
      grads = jax.grad(loss_fun)(params, batch)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optix.apply_updates(params, updates)
      return new_params, opt_state

    for _ in range(1000):
      params, opt_state = train_step(params, opt_state, batch)

    l1_w = params['linear']['w']
    l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

    # The L1-regularized estimate is much more accurate.
    self.assertGreater(least_squares_w_error, 4.0)
    self.assertLess(l1_w_error, 1.0)
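
Because decay=0.0 makes optix.trace a pass-through (no momentum is accumulated), the chained optimizer above reduces to plain gradient descent with step size 0.05 / sqrt(1 + i). Below is a minimal sketch of the same transform without the momentum stage, assuming the optix module used in the snippet (historically importable as from jax.experimental import optix):

import jax.numpy as jnp
from jax.experimental import optix

# Equivalent optimizer: the schedule already carries the minus sign, so
# optix.apply_updates() subtracts the scaled gradient from the parameters.
plain_decaying_sgd = optix.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1.0 + i))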
Example #2
    def test_apply_every(self):
        # Frequency (in steps) at which the SGD update is actually applied.
        k = 4
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # experimental/optix.py sgd
        optix_sgd_params = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state_sgd = sgd.init(optix_sgd_params)

        # experimental/optix.py sgd apply every
        optix_sgd_apply_every_params = self.init_params
        sgd_apply_every = optix.chain(optix.apply_every(k=k),
                                      optix.trace(decay=0, nesterov=False),
                                      optix.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optix_sgd_apply_every_params)
        for i in range(STEPS):
            # Apply a step of sgd
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optix_sgd_params = optix.apply_updates(optix_sgd_params,
                                                   updates_sgd)

            # Apply a step of sgd_apply_every
            updates_sgd_apply_every, state_sgd_apply_every = sgd_apply_every.update(
                self.per_step_updates, state_sgd_apply_every)
            optix_sgd_apply_every_params = optix.apply_updates(
                optix_sgd_apply_every_params, updates_sgd_apply_every)
            if i % k == k - 1:
                # Check equivalence.
                for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                                tree_leaves(optix_sgd_params)):
                    np.testing.assert_allclose(x, y, atol=1e-6, rtol=100)
            else:
                # Check the update is zero.
                for x, y in zip(tree_leaves(updates_sgd_apply_every),
                                tree_leaves(zero_update)):
                    np.testing.assert_allclose(x, y, atol=1e-10, rtol=1e-5)
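
The branch structure of the assertions above follows from how optix.apply_every(k) works: it accumulates the incoming updates and only emits the accumulated sum on every k-th call, returning zeros in between. Below is a minimal sketch of that behaviour in isolation, assuming the same optix module as the test (historically from jax.experimental import optix):

import jax.numpy as jnp
from jax.experimental import optix

grad_accumulator = optix.apply_every(k=4)
state = grad_accumulator.init(jnp.zeros(2))
for step in range(8):
  updates, state = grad_accumulator.update(jnp.ones(2), state)
  # `updates` is all zeros except on every 4th call (steps 3 and 7 here),
  # where the accumulated sum of the last four gradients is emitted at once.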
Example #3
def make_optimizer():
    """SGD with nesterov momentum and a custom lr schedule."""
    return optix.chain(
        optix.trace(decay=FLAGS.optimizer_momentum,
                    nesterov=FLAGS.optimizer_use_nesterov),
        optix.scale_by_schedule(lr_schedule), optix.scale(-1))
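
lr_schedule and the FLAGS.* values are defined elsewhere in the source file; the trailing optix.scale(-1) flips the sign of the schedule-scaled update so that optix.apply_updates performs descent. A minimal sketch of a compatible schedule follows, purely illustrative (the constants are assumptions, not the original flags):

import jax.numpy as jnp

def lr_schedule(step):
  """Maps the step count to a learning rate; a stand-in for the original schedule."""
  return 0.1 * jnp.exp(-1e-4 * step)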
Example #4
def make_optimizer(lr_schedule, momentum_decay):
  return optix.chain(optix.trace(decay=momentum_decay, nesterov=False),
                     optix.scale_by_schedule(lr_schedule),
                     optix.scale(-1))
Example #5
def make_optimizer(lr_schedule, momentum_decay):
    """Make SGD optimizer with momentum."""
    # Maximize log-prob instead of minimizing loss
    return optix.chain(optix.trace(decay=momentum_decay, nesterov=False),
                       optix.scale_by_schedule(lr_schedule))
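
Unlike the previous two examples, this chain omits the final optix.scale(-1): the comment indicates the objective is a log-probability to be maximized, so adding the positive, schedule-scaled update via optix.apply_updates performs gradient ascent. Below is a minimal sketch of the corresponding update step, where log_prob_fn, params, batch and the momentum value are hypothetical placeholders:

import jax
from jax.experimental import optix  # assumed import path for this optix API

optimizer = make_optimizer(lr_schedule, momentum_decay=0.9)
opt_state = optimizer.init(params)

def ascent_step(params, opt_state, batch):
  grads = jax.grad(log_prob_fn)(params, batch)
  updates, opt_state = optimizer.update(grads, opt_state)
  # No scale(-1) in the chain, so the update keeps the gradient's sign and
  # apply_updates() moves the parameters uphill on the log-probability.
  return optix.apply_updates(params, updates), opt_state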