def test_scale_by_fromage(self):
  schedule = lambda count: 1.0 / (count + 1.0)
  fromage = transform.scale_by_fromage(step_size_factor_fn=schedule)
  params = self.init_params
  state = fromage.init(params)
  previous_norm = jnp.inf
  transform_fn = self.variant(fromage.update)
  for _ in range(STEPS):
    # Apply a step of fromage.
    updates, state = transform_fn(self.per_step_updates, state, params)
    # Updates should get smaller due to the learning rate schedule.
    norm = transform.global_norm(updates)
    self.assertLess(norm, previous_norm)
    previous_norm = norm
class AliasTest(chex.TestCase):

  def setUp(self):
    super(AliasTest, self).setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, .9, 0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.), optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                    rtol):
    # Run the reference optimizer from jax.experimental.optimizers.
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    state = opt_init(jax_params)
    for i in range(STEPS):
      state = opt_update(i, self.per_step_updates, state)
    jax_params = get_params(state)

    # Run the corresponding optax optimizer.
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, state = step(self.per_step_updates, state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

  @parameterized.named_parameters(
      ('sgd', alias.sgd(1e-2, 0.0)),
      ('adam', alias.adam(1e-1)),
      ('adamw', alias.adamw(1e-1)),
      ('lamb', alias.lamb(1e-1)),
      ('rmsprop', alias.rmsprop(1e-1)),
      ('fromage', transform.scale_by_fromage(-1e-2)),
      ('adabelief', alias.adabelief(1e-1)),
  )
  def test_parabola(self, opt):
    initial_params = jnp.array([-1.0, 10.0, 1.0])
    final_params = jnp.array([1.0, -1.0, 1.0])

    @jax.grad
    def get_updates(params):
      return jnp.sum((params - final_params)**2)

    @jax.jit
    def step(params, state):
      updates, state = opt.update(get_updates(params), state, params)
      params = update.apply_updates(params, updates)
      return params, state

    params = initial_params
    state = opt.init(params)
    for _ in range(1000):
      params, state = step(params, state)

    chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

  @parameterized.named_parameters(
      ('sgd', alias.sgd(2e-3, 0.2)),
      ('adam', alias.adam(1e-1)),
      ('adamw', alias.adamw(1e-1)),
      ('lamb', alias.lamb(1e-1)),
      ('rmsprop', alias.rmsprop(5e-3)),
      ('fromage', transform.scale_by_fromage(-5e-3)),
      ('adabelief', alias.adabelief(1e-1)),
  )
  def test_rosenbrock(self, opt):
    # Rosenbrock function with its global minimum at (a, a**2).
    a = 1.0
    b = 100.0
    initial_params = jnp.array([0.0, 0.0])
    final_params = jnp.array([a, a**2])

    @jax.grad
    def get_updates(params):
      return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

    @jax.jit
    def step(params, state):
      updates, state = opt.update(get_updates(params), state, params)
      params = update.apply_updates(params, updates)
      return params, state

    params = initial_params
    state = opt.init(params)
    for _ in range(10000):
      params, state = step(params, state)

    chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
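
# Conventional absl test entry point so the file can be run directly. This is
# an assumed addition (not shown in the original excerpt) and relies on
# `from absl.testing import absltest` being among the file's imports.
if __name__ == '__main__':
  absltest.main()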