def test_equivalence(self):
  hb = transform_chain(
      ['precondition_by_rms', 'polyak_hb', 'add_decayed_weights'],
      [{'decay': 0.3}, {'decay': 0.5}, {'weight_decay': 0.1}],
      learning_rate=1.0)
  ema = transform_chain(
      ['precondition_by_rms', 'first_moment_ema', 'add_decayed_weights'],
      [{'decay': 0.3}, {'decay': 0.5}, {'weight_decay': 0.05}],
      learning_rate=2.0)
  targets = _optimizer_loop(hb)
  results = _optimizer_loop(ema)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
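# The two chains above agree because Polyak heavy ball with decay d and
# learning rate lr matches a first-moment EMA with the same decay and
# learning rate lr / (1 - d): here 1.0 / (1 - 0.5) = 2.0. The weight decay is
# halved (0.1 -> 0.05) so that learning_rate * weight_decay stays fixed.
#
# `_optimizer_loop` is a helper defined elsewhere in this file. A minimal
# sketch of the shape these tests assume (a few gradient steps on a small
# quadratic, collecting the params after each step) could look like the
# following; the loss, initial params, and step count are illustrative
# assumptions, not the actual helper.
def _optimizer_loop_sketch(optimizer, iterations=5):
  params = {'w': jnp.array([1.0, 2.0])}
  opt_state = optimizer.init(params)
  # Scalar quadratic loss so jax.grad applies directly.
  loss = lambda p: optax.l2_loss(p['w'], jnp.zeros(2)).sum()
  results = []
  for _ in range(iterations):
    grads = jax.grad(loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    results.append(params)
  return results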
def test_debias_true(self):
  adam = transform_chain(['scale_by_adam'], [{'b1': 0.0}])
  precondition_by_rms = transform_chain(
      ['precondition_by_rms'], [{'debias': True}])
  targets = _optimizer_loop(adam)
  results = _optimizer_loop(precondition_by_rms)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
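# With b1 = 0.0, scale_by_adam's first moment is just the current gradient,
# so Adam reduces to RMS preconditioning with a bias-corrected second moment.
# That is what `debias=True` is asserted to reproduce here.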
def test_mask_dim_1(self):
  """Test mask dimension."""
  optimizer = transform_chain(
      ['nesterov'],
      masks=[lambda p: jax.tree_map(lambda x: x.ndim != 1, p)])
  params = {'w': jnp.array([1, 2, 3]), 'b': jnp.ones((2, 2))}
  state = optimizer.init(params)
  update, state = optimizer.update(params, state, params)
  optimizer2 = transform_chain(['nesterov'])
  params2 = {'b': jnp.ones((2, 2))}
  state2 = optimizer2.init(params2)
  update2, state2 = optimizer2.update(params2, state2, params2)
  chex.assert_trees_all_close(update['b'], update2['b'])
  chex.assert_trees_all_close(update['w'], params['w'])
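# Leaves where the mask is False are expected to pass through the chain
# untouched; since the "gradients" fed in are the params themselves, the
# masked 1-d leaf 'w' comes back equal to params['w'], while the 2-d leaf
# 'b' matches an unmasked nesterov run.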
def test_correctness(self):
  """Testing correctness via an independent flax.optim run."""
  target_solution = [
      {'w': jnp.array([0.65, 0.58000004])},
      {'w': jnp.array([0.26849997, 0.12220004])},
      {'w': jnp.array([0.09766498, -0.08280197])},
      {'w': jnp.array([0.17850482, 0.01420582])},
      {'w': jnp.array([0.38620475, 0.2634457])},
  ]
  optimizer = transform_chain(
      ['polyak_hb'], [{'decay': 0.7}], learning_rate=0.01)
  results = _optimizer_loop(optimizer)
  for target, result in zip(target_solution, results):
    chex.assert_trees_all_close(target, result)
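# The targets were generated with flax.optim; the same trajectory should be
# reproducible with optax's momentum SGD (a hedged cross-check, not part of
# the original test):
#   optax.sgd(learning_rate=0.01, momentum=0.7, nesterov=False)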
def test_correctness(self):
  """Testing correctness via an independent flax.optim run."""
  target_solution = [
      {'w': jnp.array([0.40500003, 0.286])},
      {'w': jnp.array([0.255515, 0.106618])},
      {'w': jnp.array([0.31884143, 0.18260972])},
      {'w': jnp.array([0.40163627, 0.28196353])},
      {'w': jnp.array([0.43924114, 0.32708937])},
  ]
  optimizer = transform_chain(
      ['nesterov'], [{'decay': 0.7}], learning_rate=0.01)
  results = _optimizer_loop(optimizer)
  for target, result in zip(target_solution, results):
    chex.assert_trees_all_close(target, result)
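# As above, but with Nesterov momentum; the optax equivalent would be
#   optax.sgd(learning_rate=0.01, momentum=0.7, nesterov=True)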
def test_output_modality_1(self):
  decays = [0.19, 0.75, 1.0]
  scales = [0.9, 0.5, 1.0]
  decay_distribution = [0.34, 0.34, 0.32]
  ks_opt = transform_chain(
      ['precondition_by_layered_adaptive_rms'],
      [{
          'decays': decays,
          'scales': scales,
          'decay_distribution': decay_distribution,
          'eps_root': 0.0
      }],
      learning_rate=1.0)
  scales = jnp.array([0.9, 0.5, 1.0, 1.0])
  betas = jnp.array([0.19, 0.75, 1.0, 1.0])
  one_minus_betas = jnp.array([0.81, 0.25, 1.0, 1.0])
  params = {'w': jnp.ones((4,))}
  opt_state = ks_opt.init(params)
  # Step 1.
  grads = {'w': 2 * jnp.ones((4,))}
  true_nu = one_minus_betas * (grads['w']**2)
  true_updates = {
      'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu)
  }
  opt_updates, opt_state = ks_opt.update(grads, opt_state)
  chex.assert_trees_all_close(true_updates, opt_updates)
  params = optax.apply_updates(params, opt_updates)
  # Step 2.
  grads = {'w': jnp.ones((4,))}
  true_nu = one_minus_betas * (grads['w']**2) + betas * true_nu
  true_updates = {
      'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu)
  }
  opt_updates, opt_state = ks_opt.update(grads, opt_state)
  chex.assert_trees_all_close(true_updates, opt_updates)
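# The expansion from 3 decays to 4 per-weight values presumably follows
# decay_distribution: floor(0.34 * 4) = 1 weight each for decays 0.19 and
# 0.75, with the remaining 2 weights assigned the final decay 1.0. Note that
# one_minus_betas uses 1.0 (not 0.0) for beta = 1.0, i.e. a decay of 1.0 is
# treated as an unnormalized running sum of squared gradients.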
def test_correctness(self):
  """Testing correctness via optax.adam."""

  def amsgrad():
    adam = optax.scale_by_adam()

    def init_fn(params):
      return adam.init(params)

    def update_fn(updates, state, params=None):
      prev_nu = state.nu
      _, state = adam.update(updates, state, params)
      curr_nu = state.nu
      # AMSGrad keeps the running elementwise max of the second moment.
      nu_hat = jax.tree_map(jnp.maximum, curr_nu, prev_nu)
      updates = jax.tree_map(
          lambda m, v: m / (jnp.sqrt(v + 0.0) + 1e-8), state.mu, nu_hat)
      return updates, optax.ScaleByAdamState(
          count=state.count, mu=state.mu, nu=nu_hat)

    return optax.GradientTransformation(init_fn, update_fn)

  true_amsgrad = amsgrad()
  ks_amsgrad = transform_chain(['scale_by_amsgrad'])
  targets = _optimizer_loop(true_amsgrad)
  results = _optimizer_loop(ks_amsgrad)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
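# The nu_hat max is what makes AMSGrad's per-parameter effective learning
# rate non-increasing over time. In the reference, `v + 0.0` spells out an
# eps_root of 0.0 inside the square root, with eps = 1e-8 added outside;
# this presumably mirrors the defaults of 'scale_by_amsgrad' under test.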
def test_correctness(self):
  """Testing correctness via independent implementation."""

  def ema(decay, debias=True):

    def init_fn(params):
      del params
      return {'w': jnp.zeros((2,)), 'count': 0}

    def update_fn(updates, state, params=None):
      del params
      state['count'] += 1
      state['w'] = (1 - decay) * updates['w'] + decay * state['w']
      if debias:
        update = {'w': state['w'] / (1 - decay**state['count'])}
      else:
        update = {'w': state['w']}
      return update, state

    return optax.GradientTransformation(init_fn, update_fn)

  decay = 0.7
  learning_rate = 0.01
  true_ema = optax.chain(ema(decay), optax.scale(-1. * learning_rate))
  ks_ema = transform_chain(
      ['first_moment_ema'], [{
          'decay': decay,
          'debias': True,
      }],
      learning_rate=learning_rate)
  targets = _optimizer_loop(true_ema)
  results = _optimizer_loop(ks_ema)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
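# As the construction suggests, transform_chain folds its learning_rate into
# a trailing scale step, which is why the reference pairs the hand-rolled EMA
# with optax.scale(-learning_rate).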
def test_no_op(self):
  """Test no-op."""
  optimizer = transform_chain(
      ['nesterov'],
      masks=[lambda p: jax.tree_map(lambda x: x.ndim != 1, p)])
  params = {'w': jnp.array([1, 2, 3])}
  state = optimizer.init(params)
  update, state = optimizer.update(params, state, params)
  chex.assert_trees_all_close(params, update)
def test_debias_false(self):
  rms_prop = optax.scale_by_rms()
  precondition_by_rms = transform_chain(
      ['precondition_by_rms'], [{
          'eps': 0,
          'eps_root': 1e-8,
          'decay': 0.9,
          'debias': False
      }])
  targets = _optimizer_loop(rms_prop)
  results = _optimizer_loop(precondition_by_rms)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
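# optax.scale_by_rms adds its eps inside the square root (g * rsqrt(nu +
# eps)), so matching its default decay of 0.9 and eps of 1e-8 requires
# eps = 0 and eps_root = 1e-8 on the precondition_by_rms side.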
def test_with_0_momentum_yogi(self):
  optax_yogi = optax.yogi(learning_rate=1.0, b1=0.0, b2=0.9, eps=1e-8)
  precondition_by_yogi = transform_chain(
      ['precondition_by_yogi'], [{
          'eps': 1e-8,
          'eps_root': 1e-6,
          'b2': 0.9,
          'debias': True
      }],
      learning_rate=1.0)
  targets = _optimizer_loop(optax_yogi)
  results = _optimizer_loop(precondition_by_yogi)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
def test_adagrad(self):
  true_adagrad = optax.adagrad(0.7, initial_accumulator_value=0.3)
  ks_adagrad = transform_chain(
      ['precondition_by_rss', 'first_moment_ema'],
      [{'initial_accumulator_value': 0.3}, {'decay': 0.0}],
      learning_rate=0.7)
  targets = _optimizer_loop(true_adagrad)
  results = _optimizer_loop(ks_adagrad)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
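# AdaGrad decomposes into sum-of-squares preconditioning plus a trivial
# first moment: with decay = 0.0 the EMA is m_t = g_t, a pass-through, so
# the chain reduces to precondition_by_rss scaled by the learning rate.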
def test_dummy_step(self):
  """Test dummy step."""
  num_weights = 100
  xs = jnp.ones((num_weights,))
  ys = 1
  optimizer = transform_chain(['nesterov', 'polyak_hb'], [{}, {}])
  params = {'w': jnp.ones((num_weights,))}
  opt_state = optimizer.init(flax.core.FrozenDict(params))
  compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  self.assertTrue(params)
def test_construction_no_hps(self):
  self.assertTrue(transform_chain(['nesterov', 'polyak_hb']))
def test_construction(self):
  self.assertTrue(transform_chain(['nesterov', 'polyak_hb'], [{}, {}]))
def test_correctness(self):
  """Testing correctness via independent implementation."""
  rms_decay = 0.9
  eps_root = 0.0
  eps = 1e-8
  moment_decay = 0.1

  class State(NamedTuple):
    nu: optax.Updates
    trace: optax.Params
    count: chex.Array

  def twisted_adam():

    def init_fn(params):
      return State(
          nu=jax.tree_map(jnp.zeros_like, params),
          trace=jax.tree_map(jnp.zeros_like, params),
          count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params=None):
      del params
      count = state.count + jnp.array(1, jnp.int32)
      nu = {
          'w': (1 - rms_decay) * (updates['w']**2) +
               rms_decay * state.nu['w']
      }
      updates = {
          'w': updates['w'] / (jax.lax.sqrt(nu['w'] + eps_root) + eps)
      }
      updates = {'w': updates['w'] * jnp.sqrt(1 - rms_decay**count)}
      trace = {
          'w': (1 - moment_decay) * updates['w'] +
               moment_decay * state.trace['w']
      }
      updates = {'w': trace['w']}
      updates = {'w': updates['w'] / (1 - moment_decay**count)}
      return updates, State(nu=nu, count=count, trace=trace)

    return optax.GradientTransformation(init_fn, update_fn)

  true_twisted_adam = twisted_adam()
  ks_twisted_adam = transform_chain(
      ['precondition_by_rms', 'first_moment_ema'],
      [
          {'decay': rms_decay, 'eps': eps, 'eps_root': eps_root,
           'debias': True},
          {'decay': moment_decay, 'debias': True},
      ])
  targets = _optimizer_loop(true_twisted_adam)
  results = _optimizer_loop(ks_twisted_adam)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
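# Note the debiasing conventions the reference encodes: the RMS step applies
# its bias correction as a multiplicative sqrt(1 - decay**count) factor after
# eps is added to the denominator, rather than by dividing nu itself, while
# the EMA step divides by (1 - decay**count) directly. The test passing
# implies precondition_by_rms and first_moment_ema use the same conventions.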