def test_multi_steps(self):
    """Check that `MultiSteps` only changes the parameters every `k_steps` calls.

    A small linear model is trained with Adam wrapped in `MultiSteps`. On
    every call that is not the last of a group of `k_steps`, the emitted
    updates must leave both the parameters and the loss unchanged; on the last
    call of each group the loss must strictly decrease.
    """
    batch_size = 32
    x_size = 7
    # Parameters should be updated only every `k_steps` optimisation steps.
    k_steps = 4
    data = jnp.ones([batch_size, x_size])

    def get_loss(x):
        loss = jnp.sum(hk.Linear(10)(x)**2)
        return loss

    loss_init, loss_apply = hk.without_apply_rng(hk.transform(get_loss))
    params = loss_init(jax.random.PRNGKey(1915), data)

    ms_opt = wrappers.MultiSteps(alias.adam(1e-4), k_steps)
    opt_init, opt_update = ms_opt.gradient_transformation()

    # Put the training in one function, to check that the update is indeed
    # jittable.
    def train_step(data, opt_state, params):
        grad = jax.grad(loss_apply)(params, data)
        updates, opt_state = opt_update(grad, opt_state, params)
        return updates, opt_state

    # `train_step` never changes between iterations, so wrap it once here
    # instead of re-wrapping it on every pass through the loop.
    train_step_fn = self.variant(train_step)

    opt_state = opt_init(params)

    prev_loss = loss_apply(params, data)
    for idx in range(5 * k_steps):
        updates, opt_state = train_step_fn(data, opt_state, params)
        new_params = update.apply_updates(params, updates)
        new_loss = loss_apply(new_params, data)
        if idx % k_steps < k_steps - 1:
            # The parameters should not have changed and the loss should be
            # constant.
            # `jax.tree_multimap` is deprecated/removed; `tree_map` accepts
            # several trees and applies the function leaf-wise across them.
            jax.tree_util.tree_map(
                np.testing.assert_array_equal, new_params, params)
            np.testing.assert_equal(new_loss, prev_loss)
            self.assertFalse(ms_opt.has_updated(opt_state))
        else:
            # This is a step where parameters should actually have been
            # updated, and the loss should accordingly go down.
            np.testing.assert_array_less(new_loss, prev_loss)
            prev_loss = new_loss
            self.assertTrue(ms_opt.has_updated(opt_state))
        params = new_params
def test_multi_steps_every_k_schedule(self):
    """Exercise `MultiSteps` with a mini-step count that changes over time.

    The schedule asks for 1 mini-step per update while the gradient step is
    below 2, and 3 mini-steps per update afterwards. `has_updated` is checked
    after every call to the update function.
    """

    def every_k_schedule(grad_step):
        # Non-trivial schedule: 1 for the first two gradient steps, then 3.
        return jnp.where(grad_step < 2, 1, 3)

    multi_steps = wrappers.MultiSteps(alias.sgd(1e-4), every_k_schedule)
    init_fn, update_fn = multi_steps.gradient_transformation()

    weights = {'a': jnp.zeros([])}
    state = init_fn(weights)
    grads = {'a': jnp.zeros([])}

    # Nothing has been applied right after initialisation.
    self.assertFalse(multi_steps.has_updated(state))

    # Phase 1: the first two steps use a single mini-step per update, so each
    # call is itself an update.
    for _ in range(2):
        _, state = update_fn(grads, state, weights)
        self.assertTrue(multi_steps.has_updated(state))

    # Phase 2: three mini-steps per update; only the third call of each group
    # reports an update.
    mini_steps_per_update = 3
    for _ in range(5):
        for mini_step in range(mini_steps_per_update):
            _, state = update_fn(grads, state, weights)
            if mini_step < mini_steps_per_update - 1:
                self.assertFalse(multi_steps.has_updated(state))
            else:
                self.assertTrue(multi_steps.has_updated(state))