def test_numeric_static_args(self, static_args):
    """Check that arguments listed in ``static_args`` are not injected.

    Wraps ``transform.scale_by_adam`` with ``schedule.inject_hyperparams``,
    runs one init/update cycle and verifies that none of the names given in
    ``static_args`` shows up as a key of ``state.hyperparams``.

    Args:
      static_args: Either a single argument name (``str``) or an iterable of
        argument names; it is forwarded unchanged to ``inject_hyperparams``
        and both forms are handled by the final check.
    """
    optim = schedule.inject_hyperparams(
        transform.scale_by_adam, static_args=static_args)(b1=0.9, b2=0.95)

    # Leaves of different rank; the gradients reuse the very same pytree.
    params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))]
    grads = params
    state = self.variant(optim.init)(params)
    _, state = self.variant(optim.update)(grads, state)

    # A bare string must count as ONE name: set('b1') would be {'b', '1'},
    # which never intersects the hyperparameter keys and would make this
    # check pass vacuously.
    if isinstance(static_args, str):
        static_names = {static_args}
    else:
        static_names = set(static_args)

    leaked = static_names.intersection(state.hyperparams.keys())
    assert not leaked, (
        f'static args were injected as hyperparameters: {sorted(leaked)}')
def test_overriding_hyperparam(self):
    """Check that a hyperparameter overwritten in the state is honoured.

    ``transform.clip_by_global_norm`` is wrapped with an initial value of
    0.1. Before every update ``state.hyperparams['max_norm']`` is replaced
    by hand, and the norm of the flattened updates must then be close to
    the value that was written.
    """
    params = jnp.zeros((3, 5, 7))
    grads = jnp.ones_like(params)

    optim = schedule.inject_hyperparams(transform.clip_by_global_norm)(0.1)
    state = self.variant(optim.init)(params)
    update_fn = self.variant(optim.update)

    for max_norm in range(5):
        # Override the injected value directly in the optimizer state.
        state.hyperparams['max_norm'] = max_norm
        updates, state = update_fn(grads, state)
        update_norm = jnp.linalg.norm(updates.ravel())
        assert np.isclose(update_norm, max_norm)
def test_constant_hyperparams(self):
    """Check that constant hyperparameters stay unchanged across updates.

    ``transform.scale_by_adam`` is wrapped with ``b1=0.`` and ``b2=0.``.
    After each of five updates the test asserts that ``b1``, ``b2``, ``eps``
    and ``eps_root`` in ``state.hyperparams`` hold the expected constants
    and that the produced updates are close to the all-ones gradients.
    """
    optim = schedule.inject_hyperparams(transform.scale_by_adam)(b1=0., b2=0.)

    params = [jnp.zeros([2, 3]) for _ in range(3)]
    state = self.variant(optim.init)(params)
    update_fn = self.variant(optim.update)

    # `jax.tree_map` is a deprecated top-level alias that newer JAX releases
    # no longer provide; `jax.tree_util.tree_map` works on old and new ones.
    grads = jax.tree_util.tree_map(jnp.ones_like, params)

    # Values every injected hyperparameter is expected to hold at each step.
    expected_hyperparams = (
        ('b1', 0.0),
        ('b2', 0.0),
        ('eps', 1e-8),
        ('eps_root', 0.0),
    )
    for _ in range(5):
        updates, state = update_fn(grads, state, params)
        for name, expected in expected_hyperparams:
            np.testing.assert_almost_equal(state.hyperparams[name], expected)
        assert 'eps' in state.hyperparams
        chex.assert_tree_all_close(updates, grads)
def test_updates(self):
    """Check the updates produced when ``step_size`` follows a schedule.

    ``transform.scale`` is wrapped so that its ``step_size`` comes from a
    piecewise-constant schedule starting at 3.0. With a gradient of one,
    each update equals the step size in effect, so the update returned by
    call ``i`` (0-based) is compared against entry ``i + 1`` of the
    expected sequence.
    """
    step_size_schedule = schedule.piecewise_constant_schedule(
        3.0, {2: 5, 8: 2, 13: 1.5})
    optim = schedule.inject_hyperparams(transform.scale)(  # stateless
        step_size=step_size_schedule)

    params = [jnp.zeros([], dtype=jnp.float32)]
    state = self.variant(optim.init)(params)
    update_fn = self.variant(optim.update)

    expected_step_size = [3.0] * 2 + [15.0] * 6 + [30.0] * 5 + [45.0] * 3
    grads = [jnp.ones([], dtype=jnp.float32)]

    # Entry 0 is skipped: the remaining 15 entries line up with 15 updates.
    for expected in expected_step_size[1:]:
        updates, state = update_fn(grads, state, params=params)
        np.testing.assert_almost_equal(updates[0], expected)
def test_hyperparams_state(self):
    """Check the scheduled hyperparameter value recorded in the state.

    ``transform.trace`` is wrapped with a piecewise-constant schedule for
    ``decay`` (starting at 0.8) and a plain ``nesterov=True``. Before each
    of twelve updates ``state.hyperparams['decay']`` is compared with the
    matching entry of the expected sequence, and once more against the last
    entry after the final update.
    """
    decay_schedule = schedule.piecewise_constant_schedule(
        0.8, {4: 0.5, 10: 1.25})
    optim = schedule.inject_hyperparams(transform.trace)(  # stateful
        decay=decay_schedule, nesterov=True)

    params = [jnp.zeros([2, 3]) for _ in range(3)]
    state = self.variant(optim.init)(params)
    update_fn = self.variant(optim.update)

    expected_mom = [0.8] * 4 + [0.4] * 6 + [0.5] * 2

    # `jax.tree_map` is a deprecated top-level alias that newer JAX releases
    # no longer provide; `jax.tree_util.tree_map` works on old and new ones.
    grads = jax.tree_util.tree_map(jnp.ones_like, params)

    for expected in expected_mom:
        # The value is read BEFORE the update, i.e. as left by the
        # previous step (or by init for the first iteration).
        np.testing.assert_almost_equal(state.hyperparams['decay'], expected)
        _, state = update_fn(grads, state)

    np.testing.assert_almost_equal(
        state.hyperparams['decay'], expected_mom[-1])
def test_static_args_error(self, static_args):
    """Check that wrapping with the given ``static_args`` raises.

    Calling ``schedule.inject_hyperparams`` on ``transform.scale`` with
    these ``static_args`` must raise ``ValueError`` at wrap time; the
    returned factory is never invoked.
    """
    self.assertRaises(
        ValueError,
        schedule.inject_hyperparams,
        transform.scale,
        static_args=static_args,
    )