class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
  """Compares optax aliases with the JAX `experimental/optimizers.py` API.

  Every case feeds the same fixed per-step updates, for `STEPS` steps, to an
  optax `GradientTransformation` and to an `(opt_init, opt_update, get_params)`
  optimizer triple, then checks that the resulting parameters agree.
  """

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0),
       optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum',
       alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.,),
       optimizers.adagrad(LR, 0.), 1e-5),
      # The same optimizers with the optax learning rate passed as `LR_SCHED`
      # (defined elsewhere; presumably a schedule equivalent to constant `LR`).
      # These cases need names of their own: absl's `named_parameters` raises
      # `DuplicateTestNameError` when two cases share a name.
      ('sgd_schedule', alias.sgd(LR_SCHED, 0.0),
       optimizers.sgd(LR), 1e-5),
      ('adam_schedule', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop_schedule', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum_schedule',
       alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad_schedule', alias.adagrad(LR_SCHED, 0., 0.,),
       optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol):
    """Runs both optimizers for `STEPS` steps and compares final params.

    Args:
      optax_optimizer: optax `GradientTransformation` under test.
      jax_optimizer: `(opt_init, opt_update, get_params)` triple from the JAX
        experimental optimizers module.
      rtol: relative tolerance for the final comparison.
    """
    # experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    state = opt_init(jax_params)
    for i in range(STEPS):
      state = opt_update(i, self.per_step_updates, state)
      jax_params = get_params(state)

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, state = step(self.per_step_updates, state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
class FlaxOptimizersEquivalenceTest(chex.TestCase):
  """Compares optax aliases with their `flax.optim` counterparts.

  Each case applies the same fixed gradients `STEPS` times through both
  libraries and checks that the resulting parameters agree (rtol=1e-4).
  """

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  def _run_flax(self, optimizer_def):
    """Creates a flax optimizer on `init_params` and applies the updates."""
    flax_opt = optimizer_def.create(self.init_params)
    for _ in range(STEPS):
      flax_opt = flax_opt.apply_gradient(self.per_step_updates)
    return flax_opt.target

  def _run_optax(self, tx):
    """Runs `tx` from `init_params` for `STEPS` steps; returns the params."""
    params = self.init_params
    opt_state = tx.init(params)
    for _ in range(STEPS):
      step_updates, opt_state = tx.update(
          self.per_step_updates, opt_state, params)
      params = update.apply_updates(params, step_updates)
    return params

  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
      # optax `momentum` corresponds to flax `beta`.
      ('momentum', alias.sgd(LR, momentum=0.9), optim.Momentum(LR, beta=0.9)),
      ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
      ('centered_rmsprop', alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam', alias.adam(LR), optim.Adam(LR)),
      # flax expresses AdamW as `Adam` with `weight_decay`.
      ('adam_w', alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),
      # Accumulator pinned to 0. because the libraries' defaults differ.
      ('adagrad', alias.adagrad(LR, initial_accumulator_value=0.),
       optim.Adagrad(LR)),
      ('lamb', alias.lamb(LR), optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
    """Final params from flax and optax match for the same gradients."""
    flax_result = self._run_flax(flax_optimizer)
    optax_result = self._run_optax(optax_optimizer)
    chex.assert_tree_all_close(flax_result, optax_result, rtol=1e-4)
def test_flatten(self):
  """Wrapping SGD in `wrappers.flatten` must not change a single step."""

  def fresh_params():
    return (jnp.array([1., 2.]), jnp.array([3., 4.]))

  grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  def one_step(tx):
    # One init + update + apply cycle of `tx`, starting from fresh params.
    params = fresh_params()
    opt_state = tx.init(params)
    step_updates, _ = tx.update(grads, opt_state)
    return update.apply_updates(params, step_updates)

  sgd = alias.sgd(1e-2, 0.0)
  unflattened = one_step(sgd)
  flattened = one_step(wrappers.flatten(sgd))

  # Both paths must give the same parameters.
  chex.assert_tree_all_close(unflattened, flattened, atol=1e-7, rtol=1e-7)
def test_explicit_dtype(self, dtype):
  """Accumulators of sgd/adam/adamw are created with the requested dtype."""
  # canonicalize_dtype maps None to JAX's default float dtype (float32).
  wanted = jax.dtypes.canonicalize_dtype(dtype)

  def zero_params():
    return jnp.array([0.0, 0.0])

  # SGD with momentum: the first of two init states holds the trace.
  sgd_tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype)
  momentum_state, _ = sgd_tx.init(zero_params())
  self.assertEqual(wanted, momentum_state.trace.dtype)

  # Adam: the first of two init states holds the first moment `mu`.
  adam_tx = alias.adam(0.1, mu_dtype=dtype)
  adam_moments, _ = adam_tx.init(zero_params())
  self.assertEqual(wanted, adam_moments.mu.dtype)

  # AdamW: three init states, the first again holding `mu`.
  adamw_tx = alias.adamw(0.1, mu_dtype=dtype)
  adamw_moments, _, _ = adamw_tx.init(zero_params())
  self.assertEqual(wanted, adamw_moments.mu.dtype)
def test_apply_if_finite_pmap(self):
  """`apply_if_finite` under pmap, with NaNs produced inside the step.

  Unlike in `test_apply_if_finite`:
  * pmap is applied to the gradient computation and the optimisation;
  * the NaNs are caused inside the function and do not come from the inputs.
  """
  half = jnp.ones([1]) / 2.
  two = jnp.ones([1]) * 2.  # arctanh(2) is NaN.

  def forward(x):
    return jnp.arctanh(x) * hk.get_parameter(
        'p', [], init=hk.initializers.Constant(0.))

  model = hk.without_apply_rng(hk.transform(forward))
  # The checks below show that after 2 consecutive rejected non-finite
  # updates, the next non-finite update is accepted.
  opt = wrappers.apply_if_finite(alias.sgd(1.), 2)

  def train_step(params, opt_state, x):
    grads = jax.grad(model.apply)(params, x)
    grads = jax.lax.psum(grads, axis_name='i')
    step_updates, next_opt_state = opt.update(grads, opt_state, params)
    next_params = update.apply_updates(params, step_updates)
    return next_params, next_opt_state

  train_step = jax.pmap(train_step, axis_name='i')

  params = model.init(jax.random.PRNGKey(1905), half)
  opt_state = opt.init(params)
  # Add a leading axis of size 1 so the state can be fed through pmap.
  params = jax.tree_map(lambda x: x[None], params)
  opt_state = jax.tree_map(lambda x: x[None], opt_state)

  # Two rounds of: one accepted finite update, then two rejected NaN updates.
  for _ in range(2):
    params, opt_state = train_step(params, opt_state, half)
    self.assertTrue(bool(opt_state.last_finite))
    for rejected_so_far in range(1, 3):
      params, opt_state = train_step(params, opt_state, two)
      self.assertFalse(bool(opt_state.last_finite))
      self.assertEqual(rejected_so_far, int(opt_state.notfinite_count))

  # A third consecutive NaN update hits the maximum and is accepted.
  params, opt_state = train_step(params, opt_state, two)
  self.assertEqual(5, int(opt_state.total_notfinite))
def test_apply_every(self):
  """apply_every(k)-wrapped SGD matches plain SGD every k-th step only."""
  k = 4  # How often the wrapped sgd actually applies an update.
  zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

  # Reference: plain optax sgd, applied on every step.
  sgd = alias.sgd(LR, 0.0)
  sgd_params = self.init_params
  sgd_state = sgd.init(sgd_params)

  # Candidate: sgd assembled by hand behind `apply_every`.
  every_k_params = self.init_params
  every_k = combine.chain(
      transform.apply_every(k=k),
      transform.trace(decay=0, nesterov=False),
      transform.scale(-LR))
  every_k_state = every_k.init(every_k_params)
  every_k_update = self.variant(every_k.update)

  for step in range(STEPS):
    sgd_updates, sgd_state = sgd.update(self.per_step_updates, sgd_state)
    sgd_params = update.apply_updates(sgd_params, sgd_updates)

    every_k_updates, every_k_state = every_k_update(
        self.per_step_updates, every_k_state)
    every_k_params = update.apply_updates(every_k_params, every_k_updates)

    if (step + 1) % k == 0:
      # End of a k-step window: both optimizers must agree.
      chex.assert_tree_all_close(every_k_params, sgd_params,
                                 atol=1e-6, rtol=1e-5)
    else:
      # Inside a window the wrapped optimizer emits exactly zero.
      chex.assert_tree_all_close(every_k_updates, zero_update,
                                 atol=0.0, rtol=0.0)
def test_multi_steps_every_k_schedule(self):
  """MultiSteps honours an `every_k` schedule that changes over time."""

  def every_k(grad_step):
    # 1 mini-step per update for the first two updates, 3 afterwards.
    return jnp.where(grad_step < 2, 1, 3)

  ms_opt = wrappers.MultiSteps(alias.sgd(1e-4), every_k)
  init_fn, update_fn = ms_opt.gradient_transformation()

  params = dict(a=jnp.zeros([]))
  state = init_fn(params)
  grads = dict(a=jnp.zeros([]))
  self.assertFalse(ms_opt.has_updated(state))

  # Phase 1: every call completes an update.
  for _ in range(2):
    _, state = update_fn(grads, state, params)
    self.assertTrue(ms_opt.has_updated(state))

  # Phase 2: two accumulating calls, then the third completes an update.
  for _ in range(5):
    for _ in range(2):
      _, state = update_fn(grads, state, params)
      self.assertFalse(ms_opt.has_updated(state))
    _, state = update_fn(grads, state, params)
    self.assertTrue(ms_opt.has_updated(state))
def test_labels_mismatch(self, use_extra_label, use_fn):
  """multi_transform rejects labels that have no matching transform."""
  # The labels from the label tree/fn must be a subset of the keys of
  # `transforms`. The converse is not required: label 2 is never used here.
  params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
  params = jax.tree_map(jnp.asarray, params)
  # A prefix of `params`; label 3 has no transform.
  label_tree = {'a': 3 if use_extra_label else 0, 'b': [1, 0], 'c': 1}
  transforms = {
      0: alias.sgd(1.),
      1: alias.adam(1., b1=0., b2=0.),
      2: transform.trace(1.0),
  }
  labels = (lambda _: label_tree) if use_fn else label_tree
  init_fn, update_fn = combine.multi_transform(transforms, labels)

  if use_extra_label:
    with self.assertRaises(ValueError):
      self.variant(init_fn)(params)
    return

  state = self.variant(init_fn)(params)
  scaled = jax.tree_map(lambda x: x / 10.0, params)
  self.variant(update_fn)(scaled, state)
def _build_sgd():
  """Returns `alias.sgd` with learning rate 1.0 and other args at defaults."""
  learning_rate = 1.
  return alias.sgd(learning_rate)
class AliasTest(chex.TestCase):
  """Checks that optax aliases minimise simple deterministic objectives."""

  def _minimise(self, opt_name, opt, loss_fn, initial_params, num_steps):
    """Runs `num_steps` jitted steps of `opt` on `loss_fn`; returns params."""
    grad_fn = jax.grad(loss_fn)

    @jax.jit
    def step(params, opt_state):
      grads = grad_fn(params)
      if opt_name == 'dpsgd':
        # Adds a leading axis of size 1 (presumably dpsgd expects
        # per-example gradients).
        grads = grads[None]
      step_updates, opt_state = opt.update(grads, opt_state, params)
      params = update.apply_updates(params, step_updates)
      return params, opt_state

    params = initial_params
    opt_state = opt.init(params)
    for _ in range(num_steps):
      params, opt_state = step(params, opt_state)
    return params

  @parameterized.parameters(
      ('sgd', lambda: alias.sgd(1e-2, 0.0)),
      ('adam', lambda: alias.adam(1e-1)),
      ('adamw', lambda: alias.adamw(1e-1)),
      # NOTE(review): this 'lamb' case builds `alias.adamw`, so LAMB itself is
      # not exercised here; switching to `alias.lamb` likely needs a re-tuned
      # learning rate.
      ('lamb', lambda: alias.adamw(1e-1)),
      ('rmsprop', lambda: alias.rmsprop(1e-1)),
      ('rmsprop_momentum', lambda: alias.rmsprop(5e-2, momentum=0.9)),
      ('fromage', lambda: alias.fromage(1e-2)),
      ('adabelief', lambda: alias.adabelief(1e-1)),
      ('radam', lambda: alias.radam(1e-1)),
      ('sm3', lambda: alias.sm3(1.0)),
      ('yogi', lambda: alias.yogi(1.0)),
      ('dpsgd', lambda: alias.dpsgd(1e-2, 10.0, 0.001, 0)))
  def test_parabel(self, opt_name, opt):
    """Minimises sum((params - target)**2) within 1000 steps."""
    opt = opt()
    initial_params = jnp.array([-1.0, 10.0, 1.0])
    final_params = jnp.array([1.0, -1.0, 1.0])

    def loss(params):
      return jnp.sum((params - final_params)**2)

    params = self._minimise(opt_name, opt, loss, initial_params, 1000)
    chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

  @parameterized.parameters(
      ('sgd', lambda: alias.sgd(2e-3, 0.2)),
      ('adam', lambda: alias.adam(1e-1)),
      ('adamw', lambda: alias.adamw(1e-1)),
      # NOTE(review): as above, 'lamb' actually builds `alias.adamw`.
      ('lamb', lambda: alias.adamw(1e-1)),
      ('rmsprop', lambda: alias.rmsprop(5e-3)),
      ('rmsprop_momentum', lambda: alias.rmsprop(5e-3, momentum=0.9)),
      ('fromage', lambda: alias.fromage(5e-3)),
      ('adabelief', lambda: alias.adabelief(1e-1)),
      ('radam', lambda: alias.radam(1e-3)),
      ('sm3', lambda: alias.sm3(1.0)),
      ('yogi', lambda: alias.yogi(1.0)),
      ('dpsgd', lambda: alias.dpsgd(2e-3, 10., 0.001, 0, 0.2)))
  def test_rosenbrock(self, opt_name, opt):
    """Minimises the Rosenbrock function (a=1, b=100) within 10000 steps."""
    opt = opt()
    a = 1.0
    b = 100.0
    initial_params = jnp.array([0.0, 0.0])
    final_params = jnp.array([a, a**2])

    def loss(params):
      return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

    params = self._minimise(opt_name, opt, loss, initial_params, 10000)
    chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
class AliasTest(chex.TestCase):
  """Equivalence and convergence checks for optax aliases."""

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0),
       optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, .9, 0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.,),
       optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol):
    """optax and `experimental/optimizers.py` agree after `STEPS` steps."""
    # experimental/optimizers.py: an (opt_init, opt_update, get_params) triple.
    opt_init, opt_update, get_params = jax_optimizer
    jax_params = self.init_params
    jax_state = opt_init(jax_params)
    for i in range(STEPS):
      jax_state = opt_update(i, self.per_step_updates, jax_state)
      jax_params = get_params(jax_state)

    # optax, with the update function wrapped by the active chex variant.
    optax_params = self.init_params
    optax_state = optax_optimizer.init(optax_params)
    variant_update = self.variant(
        lambda grads, opt_state: optax_optimizer.update(grads, opt_state))
    for _ in range(STEPS):
      step_updates, optax_state = variant_update(
          self.per_step_updates, optax_state)
      optax_params = update.apply_updates(optax_params, step_updates)

    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

  def _minimise(self, opt, loss_fn, initial_params, num_steps):
    """Runs `num_steps` jitted steps of `opt` on `loss_fn`; returns params."""
    grad_fn = jax.grad(loss_fn)

    @jax.jit
    def step(params, opt_state):
      step_updates, opt_state = opt.update(grad_fn(params), opt_state, params)
      params = update.apply_updates(params, step_updates)
      return params, opt_state

    params = initial_params
    opt_state = opt.init(params)
    for _ in range(num_steps):
      params, opt_state = step(params, opt_state)
    return params

  @parameterized.named_parameters(
      ('sgd', alias.sgd(1e-2, 0.0)),
      ('adam', alias.adam(1e-1)),
      ('adamw', alias.adamw(1e-1)),
      # NOTE(review): 'lamb' builds `alias.adamw`, so LAMB is not exercised.
      ('lamb', alias.adamw(1e-1)),
      ('rmsprop', alias.rmsprop(1e-1)),
      # Negative step size, presumably so the bare transform descends.
      ('fromage', transform.scale_by_fromage(-1e-2)),
      ('adabelief', alias.adabelief(1e-1)),
  )
  def test_parabel(self, opt):
    """Minimises sum((params - target)**2) within 1000 steps."""
    initial_params = jnp.array([-1.0, 10.0, 1.0])
    final_params = jnp.array([1.0, -1.0, 1.0])

    def loss(params):
      return jnp.sum((params - final_params)**2)

    params = self._minimise(opt, loss, initial_params, 1000)
    chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

  @parameterized.named_parameters(
      ('sgd', alias.sgd(2e-3, 0.2)),
      ('adam', alias.adam(1e-1)),
      ('adamw', alias.adamw(1e-1)),
      # NOTE(review): 'lamb' builds `alias.adamw`, so LAMB is not exercised.
      ('lamb', alias.adamw(1e-1)),
      ('rmsprop', alias.rmsprop(5e-3)),
      ('fromage', transform.scale_by_fromage(-5e-3)),
      ('adabelief', alias.adabelief(1e-1)),
  )
  def test_rosenbrock(self, opt):
    """Minimises the Rosenbrock function (a=1, b=100) within 10000 steps."""
    a = 1.0
    b = 100.0
    initial_params = jnp.array([0.0, 0.0])
    final_params = jnp.array([a, a**2])

    def loss(params):
      return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

    params = self._minimise(opt, loss, initial_params, 10000)
    chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
class FlaxOptimizersEquivalenceTest(chex.TestCase):
  """Compares optax aliases with their `flax.optim` counterparts.

  Each case applies the same fixed gradients `STEPS` times through both
  libraries and checks that the resulting parameters agree (rtol=2e-4).
  """

  def setUp(self):
    super().setUp()
    self.init_params = (
        jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (
        jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

  def _run_flax(self, optimizer_def):
    """Creates a flax optimizer on `init_params` and applies the updates."""
    flax_opt = optimizer_def.create(self.init_params)
    for _ in range(STEPS):
      flax_opt = flax_opt.apply_gradient(self.per_step_updates)
    return flax_opt.target

  def _run_optax(self, tx):
    """Runs `tx` from `init_params` for `STEPS` steps; returns the params."""
    params = self.init_params
    opt_state = tx.init(params)
    for _ in range(STEPS):
      step_updates, opt_state = tx.update(
          self.per_step_updates, opt_state, params)
      params = update.apply_updates(params, step_updates)
    return params

  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
      # optax `momentum` corresponds to flax `beta`.
      ('momentum', alias.sgd(LR, momentum=0.9), optim.Momentum(LR, beta=0.9)),
      ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
      ('centered_rmsprop', alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam', alias.adam(LR), optim.Adam(LR)),
      # flax expresses AdamW as `Adam` with `weight_decay`.
      ('adam_w', alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),
      # Accumulator pinned to 0. because the libraries' defaults differ.
      ('adagrad', alias.adagrad(LR, initial_accumulator_value=0.),
       optim.Adagrad(LR)),
      ('lamb', alias.lamb(LR), optim.LAMB(LR)),
      ('lars',
       alias.lars(LR, weight_decay=.5, trust_coefficient=0.003,
                  momentum=0.9, eps=1e-3),
       optim.LARS(LR, weight_decay=.5, trust_coefficient=0.003,
                  beta=0.9, eps=1e-3)),
      ('adafactor',
       alias.adafactor(
           learning_rate=LR / 10.,
           factored=True,
           multiply_by_parameter_scale=True,
           clipping_threshold=1.0,
           decay_rate=0.8,
           min_dim_size_to_factor=2),
       optim.Adafactor(
           learning_rate=LR / 10.,
           factored=True,
           multiply_by_parameter_scale=True,
           clipping_threshold=1.0,
           decay_rate=0.8,
           min_dim_size_to_factor=2)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
    """Final params from flax and optax match for the same gradients."""
    flax_result = self._run_flax(flax_optimizer)
    optax_result = self._run_optax(optax_optimizer)
    chex.assert_tree_all_close(flax_result, optax_result, rtol=2e-4)
class AliasTest(chex.TestCase):
  """Convergence and dtype checks for optax aliases."""

  @parameterized.product(
      (
          dict(opt_name='sgd', opt=lambda: alias.sgd(1e-3, 0.9)),
          dict(opt_name='adafactor', opt=lambda: alias.adafactor(5e-3)),
          dict(opt_name='adagrad', opt=lambda: alias.adagrad(1.0)),
          dict(opt_name='adam', opt=lambda: alias.adam(1e-1)),
          dict(opt_name='adamw', opt=lambda: alias.adamw(1e-1)),
          dict(opt_name='lars', opt=lambda: alias.lars(1.0)),
          dict(opt_name='lamb', opt=lambda: alias.lamb(1e-3)),
          dict(opt_name='noisy_sgd',
               opt=lambda: alias.noisy_sgd(1e-3, eta=1e-4)),
          dict(opt_name='rmsprop', opt=lambda: alias.rmsprop(5e-3)),
          dict(opt_name='rmsprop_momentum',
               opt=lambda: alias.rmsprop(5e-3, momentum=0.9)),
          dict(opt_name='fromage', opt=lambda: alias.fromage(5e-3)),
          dict(opt_name='adabelief', opt=lambda: alias.adabelief(1e-2)),
          dict(opt_name='radam', opt=lambda: alias.radam(5e-3)),
          dict(opt_name='sm3', opt=lambda: alias.sm3(1.0)),
          dict(opt_name='yogi', opt=lambda: alias.yogi(1e-1)),
          dict(opt_name='dpsgd',
               opt=lambda: alias.dpsgd(1e-3, 10.0, 0.001, 0, 0.2)),
      ),
      target=(_setup_parabola, _setup_rosenbrock),
      dtype=(jnp.float32, jnp.complex64),
  )
  def test_optimization(self, opt_name, opt, target, dtype):
    """Each optimizer reaches the target's optimum within 10000 steps."""
    lacks_complex_support = opt_name in ('fromage', 'noisy_sgd', 'sm3')
    if lacks_complex_support and jnp.iscomplexobj(dtype):
      self.skipTest(f'{opt_name} does not support complex parameters.')

    opt = opt()
    initial_params, final_params, get_updates = target(dtype)

    @jax.jit
    def step(params, opt_state):
      grads = get_updates(params)
      if opt_name == 'dpsgd':
        # Adds a leading axis of size 1 (presumably dpsgd expects
        # per-example gradients).
        grads = grads[None]
      # Complex gradients need to be conjugated before being added to parameters
      # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
      grads = jax.tree_map(lambda g: g.conj(), grads)
      step_updates, opt_state = opt.update(grads, opt_state, params)
      params = update.apply_updates(params, step_updates)
      return params, opt_state

    params = initial_params
    opt_state = opt.init(params)
    for _ in range(10000):
      params, opt_state = step(params, opt_state)

    chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)

  @parameterized.named_parameters([
      ('float32', 'float32'),
      ('bfloat16', 'bfloat16'),
      ('complex64', 'complex64'),
      ('None', None),
  ])
  def test_explicit_dtype(self, dtype):
    """Accumulators of sgd/adam/adamw are created with the requested dtype."""
    # canonicalize_dtype maps None to JAX's default float dtype (float32).
    wanted = jax.dtypes.canonicalize_dtype(dtype)

    def zero_params():
      return jnp.array([0.0, 0.0])

    # SGD with momentum: the first of two init states holds the trace.
    sgd_tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype)
    momentum_state, _ = sgd_tx.init(zero_params())
    self.assertEqual(wanted, momentum_state.trace.dtype)

    # Adam: the first of two init states holds the first moment `mu`.
    adam_tx = alias.adam(0.1, mu_dtype=dtype)
    adam_moments, _ = adam_tx.init(zero_params())
    self.assertEqual(wanted, adam_moments.mu.dtype)

    # AdamW: three init states, the first again holding `mu`.
    adamw_tx = alias.adamw(0.1, mu_dtype=dtype)
    adamw_moments, _, _ = adamw_tx.init(zero_params())
    self.assertEqual(wanted, adamw_moments.mu.dtype)
def test_empty(self, container):
  """multi_transform maps an empty container to an empty container."""

  def label_everything_zero(_):
    return 0

  init_fn, update_fn = combine.multi_transform(
      {0: alias.sgd(1.)}, label_everything_zero)
  grads = container()
  state = init_fn(container())
  updates, _ = update_fn(grads, state)
  self.assertEqual(updates, container())