Example #1
0
    def test_multi_steps(self):
        """MultiSteps should only change the parameters every `k_steps` calls.

        Runs `5 * k_steps` optimisation steps of `alias.adam` wrapped in
        `wrappers.MultiSteps` on an `hk.Linear` model. On the first
        `k_steps - 1` steps of each cycle the parameters and the loss must be
        unchanged and `has_updated` must be False; on the last step of each
        cycle the loss must strictly decrease and `has_updated` must be True.
        """
        batch_size = 32
        x_size = 7
        # Parameters should be updated only every `k_steps` optimisation steps.
        k_steps = 4
        data = jnp.ones([batch_size, x_size])

        def get_loss(x):
            loss = jnp.sum(hk.Linear(10)(x)**2)
            return loss

        loss_init, loss_apply = hk.without_apply_rng(hk.transform(get_loss))
        params = loss_init(jax.random.PRNGKey(1915), data)

        ms_opt = wrappers.MultiSteps(alias.adam(1e-4), k_steps)
        opt_init, opt_update = ms_opt.gradient_transformation()

        # Put the training in one function, to check that the update is indeed
        # jittable.
        def train_step(data, opt_state, params):
            grad = jax.grad(loss_apply)(params, data)
            updates, opt_state = opt_update(grad, opt_state, params)
            return updates, opt_state

        # Wrap once, outside the loop, so the same (possibly jitted) callable
        # is reused on every step instead of being re-wrapped each iteration.
        variant_train_step = self.variant(train_step)

        opt_state = opt_init(params)

        prev_loss = loss_apply(params, data)
        for idx in range(5 * k_steps):
            updates, opt_state = variant_train_step(data, opt_state, params)
            new_params = update.apply_updates(params, updates)
            new_loss = loss_apply(new_params, data)
            if idx % k_steps < k_steps - 1:
                # The parameters should not have changed and the loss should be
                # constant. `jax.tree_util.tree_map` with several trees is the
                # replacement for the removed `jax.tree_multimap`.
                jax.tree_util.tree_map(np.testing.assert_array_equal,
                                       new_params, params)
                np.testing.assert_equal(new_loss, prev_loss)
                self.assertFalse(ms_opt.has_updated(opt_state))
            else:
                # This is a step where parameters should actually have been
                # updated, and the loss should accordingly go down.
                np.testing.assert_array_less(new_loss, prev_loss)
                prev_loss = new_loss
                self.assertTrue(ms_opt.has_updated(opt_state))
            params = new_params
Example #2
0
 def test_multi_steps_every_k_schedule(self):
     """Checks MultiSteps driven by an `every_k` schedule that varies over time.

     The schedule requests one mini-step per update while `grad_step` is below
     2, and three mini-steps per update from then on.
     """

     def every_k_schedule(grad_step):
         return jnp.where(grad_step < 2, 1, 3)

     ms_opt = wrappers.MultiSteps(alias.sgd(1e-4), every_k_schedule)
     opt_init, opt_update = ms_opt.gradient_transformation()
     params = {'a': jnp.zeros([])}
     state = opt_init(params)
     zero_grad = {'a': jnp.zeros([])}
     self.assertFalse(ms_opt.has_updated(state))

     # Expected `has_updated` flag after each successive `opt_update` call:
     # two updates of a single mini-step each, then five cycles of three
     # mini-steps where only the final mini-step of the cycle updates.
     expected_flags = [True] * 2 + [False, False, True] * 5
     for should_update in expected_flags:
         _, state = opt_update(zero_grad, state, params)
         check = self.assertTrue if should_update else self.assertFalse
         check(ms_opt.has_updated(state))