def test_set_to_zero_returns_tree_of_correct_zero_arrays(self):
  """Tests that zero transform returns a tree of zeros of correct shape."""
  grads = ({'a': np.ones((3, 4)), 'b': 1.}, np.ones((1, 2, 3)))
  updates, _ = self.variant(base.set_to_zero().update)(grads,
                                                       base.EmptyState())
  correct_zeros = ({'a': np.zeros((3, 4)), 'b': 0.}, np.zeros((1, 2, 3)))
  chex.assert_trees_all_close(updates, correct_zeros, rtol=0)
def test_output_modality_1(self):
  decays = [0.19, 0.75, 1.0]
  scales = [0.9, 0.5, 1.0]
  decay_distribution = [0.34, 0.34, 0.32]
  ks_opt = transform_chain(
      ['precondition_by_layered_adaptive_rms'],
      [{
          'decays': decays,
          'scales': scales,
          'decay_distribution': decay_distribution,
          'eps_root': 0.0
      }],
      learning_rate=1.0)
  scales = jnp.array([0.9, 0.5, 1.0, 1.0])
  betas = jnp.array([0.19, 0.75, 1.0, 1.0])
  one_minus_betas = jnp.array([0.81, 0.25, 1.0, 1.0])
  params = {'w': jnp.ones((4,))}
  opt_state = ks_opt.init(params)

  # step 1
  grads = {'w': 2 * jnp.ones((4,))}
  true_nu = one_minus_betas * (grads['w']**2)
  true_updates = {
      'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu)
  }
  opt_updates, opt_state = ks_opt.update(grads, opt_state)
  chex.assert_trees_all_close(true_updates, opt_updates)
  params = optax.apply_updates(params, opt_updates)

  # step 2
  grads = {'w': jnp.ones((4,))}
  true_nu = one_minus_betas * (grads['w']**2) + betas * true_nu
  true_updates = {
      'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu)
  }
  opt_updates, opt_state = ks_opt.update(grads, opt_state)
  chex.assert_trees_all_close(true_updates, opt_updates)
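# The expected arrays above reflect how decay_distribution is apparently
# materialized over the four weights: floor(0.34 * 4) = 1 weight gets decay
# 0.19 / scale 0.9, one more gets decay 0.75 / scale 0.5, and the remaining
# two weights fall into the last bucket with decay 1.0 / scale 1.0, giving
# betas = [0.19, 0.75, 1.0, 1.0].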
def test_correctness(self): """Testing correctness via optax.adam.""" def amsgrad(): adam = optax.scale_by_adam() def init_fn(params): return adam.init(params) def update_fn(updates, state, params=None): prev_nu = state.nu _, state = adam.update(updates, state, params) curr_nu = state.nu nu_hat = jax.tree_multimap(jnp.maximum, curr_nu, prev_nu) updates = jax.tree_multimap( lambda m, v: m / (jnp.sqrt(v + 0.0) + 1e-8), state.mu, nu_hat) return updates, optax.ScaleByAdamState(count=state.count, mu=state.mu, nu=nu_hat) return optax.GradientTransformation(init_fn, update_fn) true_amsgrad = amsgrad() ks_amsgrad = transform_chain(['scale_by_amsgrad']) targets = _optimizer_loop(true_amsgrad) results = _optimizer_loop(ks_amsgrad) for target, result in zip(targets, results): chex.assert_trees_all_close(target, result)
def test_correctness(self): """Testing correctness via independent implementation.""" def ema(decay, debias=True): def init_fn(params): del params return {'w': jnp.zeros((2, )), 'count': 0} def update_fn(updates, state, params=None): del params state['count'] += 1 state['w'] = ((1 - decay) * updates['w'] + decay * state['w']) if debias: update = {'w': state['w'] / (1 - decay**state['count'])} else: update = {'w': state['w']} return update, state return optax.GradientTransformation(init_fn, update_fn) decay = 0.7 learning_rate = 0.01 true_ema = optax.chain(ema(decay), optax.scale(-1. * learning_rate)) ks_ema = transform_chain(['first_moment_ema'], [{ 'decay': decay, 'debias': True, }], learning_rate=learning_rate) targets = _optimizer_loop(true_ema) results = _optimizer_loop(ks_ema) for target, result in zip(targets, results): chex.assert_trees_all_close(target, result)
def test_common_clipping_norm(self, l2_norm_clip):
  l2_norms_threshold = jax.tree_map(lambda p: l2_norm_clip, self.params)
  dp_agg = optimizers.dp_aggregate(
      l2_norms_threshold=l2_norms_threshold,
      base_sensitivity=1.,
      noise_multiplier=0.,
      init_rng=jax.random.PRNGKey(42))
  state = dp_agg.init(self.params)
  update_fn = self.variant(dp_agg.update)

  # Shape of the three arrays below is (self.batch_size,).
  norms = [
      jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
      for g in jax.tree_leaves(self.per_eg_grads)
  ]
  divisors = [jnp.maximum(norm / l2_norm_clip, 1.) for norm in norms]
  # Since the values of all the parameters are the same within each example,
  # we can easily compute what the values of the gradients should be:
  expected_val = [
      jnp.sum(jnp.arange(self.batch_size) / div) for div in divisors
  ]
  expected_tree = jax.tree_unflatten(
      jax.tree_structure(self.params), expected_val)
  expected_tree = jax.tree_map(
      lambda val, p: jnp.broadcast_to(val, p.shape), expected_tree,
      self.params)

  for _ in range(3):
    updates, state = update_fn(self.per_eg_grads, state, self.params)
    chex.assert_trees_all_close(updates, expected_tree, rtol=2e-7)
def test_ema_accumulator(self, decay, debias):
  """Test ema_accumulator."""
  grads = jax.random.uniform(jax.random.PRNGKey(0), (10, 10))
  grads = {'updates': grads, 'variables': {}, 'moments': {}, 'output': None}
  nth_grads = preconditioner.nth_power(power=[1, 2])
  accumulator = preconditioner.ema_accumulator(decay, debias)
  grads, _ = nth_grads.update(grads, None)
  state = accumulator.init(grads['variables'])

  for i in range(5):
    moments, count = state
    grads = jax.random.uniform(jax.random.PRNGKey(i + 1), (10, 10))
    grads = {'updates': grads, 'variables': {}, 'moments': {}, 'output': None}
    grads, _ = nth_grads.update(grads, None)
    updates, state = accumulator.update(grads, state)

    actual = jax.tree_map(lambda g, t: (1 - decay) * g + decay * t,
                          grads['variables'], moments)
    if debias:
      count += jnp.array(1, dtype=jnp.int32)
      beta = jnp.array(1, dtype=jnp.int32) - decay**count
      actual = jax.tree_map(lambda t: t / beta.astype(t.dtype), actual)  # pylint: disable=cell-var-from-loop
    chex.assert_trees_all_close(updates['moments'], actual)
def test_precondition_by_rms(self, decay, eps, eps_root, debias):
  """Test precondition_by_rms."""
  actual_rms = transform.precondition_by_rms(
      decay=decay, eps=eps, eps_root=eps_root, debias=debias)
  decon_rms = preconditioner.preconditioner(
      preconditioner.nth_power, preconditioner.ema_accumulator,
      preconditioner.rexp_updater, {'power': 2},
      {'decay': decay, 'debias': debias},
      {'eps': eps, 'eps_root': eps_root})
  params = jax.random.uniform(jax.random.PRNGKey(0), (10, 10))
  init = self.variant(decon_rms.init)
  update = self.variant(decon_rms.update)
  actual_state = actual_rms.init(params)
  decon_state = init(params)

  for i in range(5):
    grads = jax.random.uniform(jax.random.PRNGKey(i + 1), (10, 10))
    actual_updates, actual_state = actual_rms.update(grads, actual_state)
    decon_updates, decon_state = update(grads, decon_state)
    actual_params = optax.apply_updates(params, actual_updates)
    decon_params = optax.apply_updates(params, decon_updates)
    chex.assert_trees_all_close(actual_params, decon_params, atol=1e-4)
def test_correctness(self): """Testing correctness via an independent flax.optim run.""" target_solution = [ { 'w': jnp.array([0.40500003, 0.286]) }, { 'w': jnp.array([0.255515, 0.106618]) }, { 'w': jnp.array([0.31884143, 0.18260972]) }, { 'w': jnp.array([0.40163627, 0.28196353]) }, { 'w': jnp.array([0.43924114, 0.32708937]) }, ] optimizer = transform_chain(['nesterov'], [{ 'decay': 0.7 }], learning_rate=0.01) results = _optimizer_loop(optimizer) for target, result in zip(target_solution, results): chex.assert_trees_all_close(target, result)
def testLossIsNearZeroAtOrigin(self):
  # Check that the loss is near-zero when x is near-zero.
  _, loss, x, _, _, _, _, _ = self._precompute_lossfun_inputs()
  loss_near_zero = loss[jnp.abs(x) < 1e-5]
  chex.assert_trees_all_close(
      loss_near_zero, jnp.zeros_like(loss_near_zero), atol=1e-5)
def test_equivalence(self):
  hb = transform_chain(
      ['precondition_by_rms', 'polyak_hb', 'add_decayed_weights'],
      [{'decay': 0.3}, {'decay': 0.5}, {'weight_decay': 0.1}],
      learning_rate=1.0)
  ema = transform_chain(
      ['precondition_by_rms', 'first_moment_ema', 'add_decayed_weights'],
      [{'decay': 0.3}, {'decay': 0.5}, {'weight_decay': 0.05}],
      learning_rate=2.0)
  targets = _optimizer_loop(hb)
  results = _optimizer_loop(ema)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
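# The two chains above match because of a rescaling identity: heavy-ball
# momentum m_t = g_t + beta * m_{t-1} is 1 / (1 - beta) times the EMA
# m_t = (1 - beta) * g_t + beta * m_{t-1}. With beta = 0.5 the EMA chain
# therefore needs twice the learning rate (2.0 vs. 1.0) and, assuming
# transform_chain applies learning_rate as a final scale after
# add_decayed_weights, half the decoupled weight decay (0.05 vs. 0.1).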
def test_correctness(self): """Testing correctness via an independent flax.optim run.""" target_solution = [ { 'w': jnp.array([0.65, 0.58000004]) }, { 'w': jnp.array([0.26849997, 0.12220004]) }, { 'w': jnp.array([0.09766498, -0.08280197]) }, { 'w': jnp.array([0.17850482, 0.01420582]) }, { 'w': jnp.array([0.38620475, 0.2634457]) }, ] optimizer = transform_chain(['polyak_hb'], [{ 'decay': 0.7 }], learning_rate=0.01) results = _optimizer_loop(optimizer) for target, result in zip(target_solution, results): chex.assert_trees_all_close(target, result)
def test_batch(self):
  """Test that the BatchNorm layer is indeed ignored.

  Code taken from: https://github.com/google/flax/issues/932
  """
  key = jax.random.PRNGKey(0)
  x = jnp.ones((5, 4, 4, 3))
  y = jax.random.uniform(key, (5, 4, 4, 7))
  foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))
  tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask())

  @self.variant
  def train_step(params, x, y):
    y1, new_batch_stats = Foo(filters=7, train=True).apply(
        params, x, mutable=['batch_stats'])
    return jnp.abs(y - y1).sum(), new_batch_stats

  state = self.variant(tx.init)(foo_vars['params'])
  grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
  updates, state = self.variant(tx.update)(dict(grads['params']), state)
  chex.assert_trees_all_close(updates['BatchNorm_0'],
                              grads['params']['BatchNorm_0'])
def testDerivativeIsMonotonicWrtX(self):
  # Check that the loss increases monotonically with |x|.
  _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
  # This is just to suppress a warning below.
  d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
  mask = jnp.isfinite(alpha) & (
      jnp.abs(d_x) > (300. * jnp.finfo(jnp.float32).eps))
  chex.assert_trees_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))
def test_stateless_with_tree_map_no_params(self):
  updates = {'linear': jnp.full((5, 3), 3.0)}
  opt = base.stateless_with_tree_map(lambda g, _: g * 2.0)
  state = opt.init(None)
  update_fn = self.variant(opt.update)
  new_updates, _ = update_fn(updates, state)
  expected_updates = {'linear': jnp.full((5, 3), 6.0)}
  chex.assert_trees_all_close(new_updates, expected_updates)
def testLossIsQuadraticNearOrigin(self):
  # Check that the loss is well-approximated by a quadratic bowl when
  # |x| < scale.
  _, loss, x, _, scale, _, _, _ = self._precompute_lossfun_inputs()
  mask = jnp.abs(x) < (0.5 * scale)
  loss_quad = 0.5 * jnp.square(x / scale)
  chex.assert_trees_all_close(
      loss_quad[mask], loss[mask], rtol=1e-5, atol=1e-2)
def test_no_op(self):
  """Tests that a fully masked transform chain is a no-op."""
  # The mask is False for the 1-D 'w', so nesterov is never applied.
  optimizer = transform_chain(
      ['nesterov'], masks=[lambda p: jax.tree_map(lambda x: x.ndim != 1, p)])
  params = {'w': jnp.array([1, 2, 3])}
  state = optimizer.init(params)
  update, state = optimizer.update(params, state, params)
  chex.assert_trees_all_close(params, update)
def test_stateless_with_tree_map(self):
  params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1,))}
  updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1,), 2.0)}
  opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p)
  state = opt.init(params)
  update_fn = self.variant(opt.update)
  new_updates, _ = update_fn(updates, state, params)
  expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])}
  chex.assert_trees_all_close(new_updates, expected_updates)
def test_debias_true(self):
  # With b1 = 0, Adam keeps no first moment, so it reduces to debiased RMS
  # preconditioning and should match precondition_by_rms(debias=True).
  adam = transform_chain(['scale_by_adam'], [{'b1': 0.0}])
  precondition_by_rms = transform_chain(['precondition_by_rms'],
                                        [{'debias': True}])
  targets = _optimizer_loop(adam)
  results = _optimizer_loop(precondition_by_rms)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
def test_stateless_no_params(self):
  updates = {'linear': jnp.full((5, 3), 3.0)}

  @base.stateless
  def opt(g, _):
    return jax.tree_map(lambda g_: g_ * 2, g)

  state = opt.init(None)
  update_fn = self.variant(opt.update)
  new_updates, _ = update_fn(updates, state)
  expected_updates = {'linear': jnp.full((5, 3), 6.0)}
  chex.assert_trees_all_close(new_updates, expected_updates)
def test_stateless(self):
  params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1,))}
  updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1,), 2.0)}

  @base.stateless
  def opt(g, p):
    return jax.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p)

  state = opt.init(params)
  update_fn = self.variant(opt.update)
  new_updates, _ = update_fn(updates, state, params)
  expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])}
  chex.assert_trees_all_close(new_updates, expected_updates)
def testAlphaEqualsNegativeInfinity(self):
  # Check that alpha == -Infinity reproduces Welsch aka Leclerc loss.
  x = np.linspace(-15, 15, 1000, dtype=np.float64)
  alpha = -float('inf')
  scale = 1.7

  # Our loss.
  loss = self.variant(general.lossfun)(x, alpha, scale)

  # Welsch/Leclerc loss.
  loss_true = 1. - np.exp(-0.5 * np.square(x / scale))

  chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
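# For reference: every testAlphaEquals* special case above and below is an
# instance (or limit) of the general robust loss of Barron (2019), which
# general.lossfun appears to implement:
#
#   rho(x, alpha, c) = (|alpha - 2| / alpha)
#                      * (((x / c)**2 / |alpha - 2| + 1)**(alpha / 2) - 1)
#
# with removable singularities at alpha == 0 (Cauchy) and alpha == 2 (L2),
# and limits 1 - exp(-0.5 * (x / c)**2) at alpha == -inf (Welsch) and
# exp(0.5 * (x / c)**2) - 1 at alpha == +inf. A minimal sketch of the generic
# finite-alpha branch, assuming alpha not in {0, 2} (hypothetical helper, not
# part of the library):
#
#   def reference_lossfun(x, alpha, scale):
#     b = abs(alpha - 2.)
#     return (b / alpha) * (((x / scale)**2 / b + 1.)**(alpha / 2.) - 1.)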
def testLossIsScaleInvariant(self):
  # Check that loss(mult * x, alpha, mult * scale) == loss(x, alpha, scale).
  (num_samples, loss, x, alpha, scale, _, _, _) = (
      self._precompute_lossfun_inputs())
  # Random log-normally distributed scalings in ~(0.2, 20).
  rng = random.PRNGKey(0)
  mult = jnp.maximum(0.2, jnp.exp(random.normal(rng, shape=[num_samples])))
  # Compute the scaled loss.
  loss_scaled = general.lossfun(mult * x, alpha, mult * scale)
  chex.assert_trees_all_close(loss, loss_scaled, atol=1e-4, rtol=1e-4)
def testAlphaEqualsZero(self):
  # Check that alpha == 0 reproduces Cauchy aka Lorentzian loss.
  x = np.linspace(-15, 15, 1000, dtype=np.float64)
  alpha = 0.
  scale = 1.7

  # Our loss.
  loss = self.variant(general.lossfun)(x, alpha, scale)

  # Cauchy/Lorentzian loss.
  loss_true = np.log(0.5 * np.square(x / scale) + 1)

  chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
def testAlphaEqualsNegativeTwo(self):
  # Check that alpha == -2 reproduces Geman-McClure loss.
  x = np.linspace(-15, 15, 1000, dtype=np.float64)
  alpha = -2.
  scale = 1.7

  # Our loss.
  loss = self.variant(general.lossfun)(x, alpha, scale)

  # Geman-McClure loss.
  loss_true = 2. * np.square(x / scale) / (np.square(x / scale) + 4.)

  chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
def testAlphaEqualsOne(self):
  # Check that alpha == 1 reproduces Charbonnier aka pseudo-Huber loss.
  x = np.linspace(-15, 15, 1000, dtype=np.float64)
  alpha = 1.
  scale = 1.7

  # Our loss.
  loss = self.variant(general.lossfun)(x, alpha, scale)

  # Charbonnier loss.
  loss_true = np.sqrt(np.square(x / scale) + 1) - 1

  chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
def testAlphaEqualsTwo(self):
  # Check that alpha == 2 reproduces L2 loss.
  x = np.linspace(-15, 15, 1000, dtype=np.float64)
  alpha = 2.
  scale = 1.7

  # Our loss.
  loss = self.variant(general.lossfun)(x, alpha, scale)

  # L2 Loss.
  loss_true = 0.5 * np.square(x / scale)

  chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
def testAlphaEqualsInfinity(self):
  # Check that alpha == Infinity takes the correct form.
  x = np.linspace(-15, 15, 1000, dtype=np.float64)
  alpha = float('inf')
  scale = 1.7

  # Our loss.
  loss = self.variant(general.lossfun)(x, alpha, scale)

  # The true loss.
  loss_true = jnp.exp(0.5 * jnp.square(x / scale)) - 1.

  chex.assert_trees_all_close(loss, loss_true, atol=1e-4, rtol=1e-4)
def test_debias_false(self):
  rms_prop = optax.scale_by_rms()
  precondition_by_rms = transform_chain(
      ['precondition_by_rms'],
      [{'eps': 0, 'eps_root': 1e-8, 'decay': 0.9, 'debias': False}])
  targets = _optimizer_loop(rms_prop)
  results = _optimizer_loop(precondition_by_rms)
  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
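# The eps bookkeeping above is the point of the test: optax.scale_by_rms
# (decay=0.9, eps=1e-8 by default) adds its epsilon inside the square root,
# so precondition_by_rms only reproduces it with eps=0 and eps_root=1e-8.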
def testAlphaEqualsFour(self):
  # Check that alpha == 4 reproduces a quartic.
  x = np.linspace(-15, 15, 1000, dtype=np.float64)
  alpha = 4.
  scale = 1.7

  # Our loss.
  loss = self.variant(general.lossfun)(x, alpha, scale)

  # The true loss.
  loss_true = np.square(np.square(x / scale)) / 8 + np.square(x / scale) / 2

  chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
def test_add_decayed_weights_with_mask(self):
  """Tests that the mask specified in hps is applied to add_decayed_weights."""

  class Foo(nn.Module):
    """Dummy model."""
    train: bool
    filters: int

    @nn.compact
    def __call__(self, x):
      x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
      x = nn.BatchNorm(
          use_running_average=not self.train,
          momentum=0.9,
          epsilon=1e-5,
          dtype=jnp.float32)(x)
      return x

  tx = from_hparams(
      ml_collections.ConfigDict({
          '0': {
              'element': 'add_decayed_weights',
              'hps': {
                  'weight_decay': 1e-4,
                  'mask': 'bias_bn'
              }
          }
      }))
  key = jax.random.PRNGKey(0)
  x = jnp.ones((5, 4, 4, 3))
  y = jax.random.uniform(key, (5, 4, 4, 7))
  foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))

  @self.variant
  def train_step(params, x, y):
    y1, new_batch_stats = Foo(filters=7, train=True).apply(
        params, x, mutable=['batch_stats'])
    return jnp.abs(y - y1).sum(), new_batch_stats

  state = self.variant(tx.init)(foo_vars['params'])
  grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
  updates, state = self.variant(tx.update)(dict(grads['params']), state,
                                           foo_vars['params'])
  chex.assert_trees_all_close(updates['BatchNorm_0'],
                              grads['params']['BatchNorm_0'])
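# Note: the assertion above only verifies that BatchNorm parameters are
# excluded from weight decay, i.e. their updates are the raw gradients;
# 'bias_bn' is presumably a named mask selecting bias and BatchNorm
# parameters.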