def test_chain(self):
  transformations = [
      transform.scale_by_adam(),
      transform.trace(decay=0, nesterov=False),
      transform.scale(-LR)]

  # Apply updates with chain.
  chain_params = self.init_params
  chained_transforms = combine.chain(*transformations)
  state = chained_transforms.init(chain_params)

  @self.variant
  def update_fn(updates, state):
    return chained_transforms.update(updates, state)

  for _ in range(STEPS):
    updates, state = update_fn(self.per_step_updates, state)
    chain_params = update.apply_updates(chain_params, updates)

  # Manually apply sequence of transformations.
  manual_params = self.init_params
  states = [t.init(manual_params) for t in transformations]
  for _ in range(STEPS):
    updates = self.per_step_updates
    new_states = []
    for t, s in zip(transformations, states):
      updates, state = t.update(updates, s)
      new_states.append(state)
    manual_params = update.apply_updates(manual_params, updates)
    states = new_states

  # Check equivalence.
  chex.assert_tree_all_close(manual_params, chain_params, rtol=1e-4)
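# Hedged aside (not part of the test above): `combine.chain` applies its
# transformations in order, feeding the updates emitted by one into the next,
# which is exactly what the manual loop re-implements. In the public optax
# namespace the chain under test reads roughly as below; the learning rate is
# an illustrative stand-in for the module-level LR constant.
import optax

learning_rate = 1e-2  # Illustrative stand-in for LR.
adam_like_opt = optax.chain(
    optax.scale_by_adam(),
    optax.trace(decay=0, nesterov=False),
    optax.scale(-learning_rate))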
def test_flatten(self):
  def init_params():
    return (jnp.array([1., 2.]), jnp.array([3., 4.]))

  per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  # First calculate new params without flattening
  optax_sgd_params = init_params()
  sgd = alias.sgd(1e-2, 0.0)
  state_sgd = sgd.init(optax_sgd_params)
  updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd)
  sgd_params_no_flatten = update.apply_updates(optax_sgd_params, updates_sgd)

  # And now calculate new params with flattening
  optax_sgd_params = init_params()
  sgd = wrappers.flatten(sgd)
  state_sgd = sgd.init(optax_sgd_params)
  updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd)
  sgd_params_flatten = update.apply_updates(optax_sgd_params, updates_sgd)

  # Test that both give the same result
  chex.assert_tree_all_close(
      sgd_params_no_flatten, sgd_params_flatten, atol=1e-7, rtol=1e-7)
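# Hedged standalone sketch of the pattern exercised above, written against the
# public optax namespace rather than the internal module aliases; the params
# and grads values simply mirror the test data.
import jax.numpy as jnp
import optax

params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))

# `optax.flatten` wraps an inner transformation so that it sees a single
# flattened vector instead of the original parameter tree; the resulting
# updates are identical up to numerical precision.
opt = optax.flatten(optax.sgd(1e-2))
state = opt.init(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)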
def test_keep_params_nonnegative(self):
  grads = (jnp.array([500., -500., 0.]),
           jnp.array([500., -500., 0.]),
           jnp.array([500., -500., 0.]))

  params = (jnp.array([-1., -1., -1.]),
            jnp.array([1., 1., 1.]),
            jnp.array([0., 0., 0.]))

  # vanilla sgd
  opt = combine.chain(
      transform.trace(decay=0, nesterov=False), transform.scale(-LR))
  opt_state = opt.init(params)

  updates, _ = opt.update(grads, opt_state, params)
  new_params = update.apply_updates(params, updates)

  chex.assert_tree_all_close(new_params, (jnp.array([-6., 4., -1.]),
                                          jnp.array([-4., 6., 1.]),
                                          jnp.array([-5., 5., 0.])))

  # sgd with keeping parameters non-negative
  opt = combine.chain(
      transform.trace(decay=0, nesterov=False), transform.scale(-LR),
      constrain.keep_params_nonnegative())
  opt_state = opt.init(params)

  updates, _ = opt.update(grads, opt_state, params)
  new_params = update.apply_updates(params, updates)

  chex.assert_tree_all_close(new_params, (jnp.array([0., 4., 0.]),
                                          jnp.array([0., 6., 1.]),
                                          jnp.array([0., 5., 0.])))
def test_apply_if_finite(self, opt_builder):
  one = jnp.ones([])
  nan = jnp.array(jnp.nan)

  def fn(x):
    return x * hk.get_parameter('p', [], init=hk.initializers.Constant(0.))

  fn = hk.without_apply_rng(hk.transform(fn))
  params = fn.init(jax.random.PRNGKey(1905), one)
  opt = wrappers.apply_if_finite(opt_builder(), 2)
  state = opt.init(params)
  grads_fn = jax.grad(self.variant(fn.apply))

  # Do one successful param update
  grads = grads_fn(params, one)
  updates, state = opt.update(grads, state, params)
  params = update.apply_updates(params, updates)
  # We know exactly what should be the value of params since we are
  # effectively using sgd in all cases.
  self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
  self.assertTrue(bool(state.last_finite))

  # Check 2 rejected param updates
  for step in range(2):
    grads = grads_fn(params, nan)
    updates, state = opt.update(grads, state, params)
    params = update.apply_updates(params, updates)
    self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
    self.assertFalse(bool(state.last_finite))
    self.assertEqual(step + 1, int(state.notfinite_count))

  # Next successful param update
  grads = grads_fn(params, one)
  updates, state = opt.update(grads, state, params)
  params = update.apply_updates(params, updates)
  self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
  self.assertTrue(bool(state.last_finite))

  # Again 2 rejected param updates
  for step in range(2):
    grads = grads_fn(params, nan)
    updates, state = opt.update(grads, state, params)
    params = update.apply_updates(params, updates)
    self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
    self.assertFalse(bool(state.last_finite))
    self.assertEqual(step + 1, int(state.notfinite_count))

  # Next param update with NaN is accepted since we reached maximum
  grads = grads_fn(params, nan)
  updates, state = opt.update(grads, state, params)
  params = update.apply_updates(params, updates)
  self.assertTrue(bool(jnp.isnan(jax.tree_flatten(params)[0][0])))
  self.assertEqual(5, int(state.total_notfinite))
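# Hedged standalone sketch of the wrapper under test: `optax.apply_if_finite`
# drops updates whose gradients contain NaN or inf and only lets them through
# once the given number of consecutive non-finite steps has been reached
# (2 here, mirroring the test). The inner sgd learning rate is illustrative.
import optax

robust_opt = optax.apply_if_finite(optax.sgd(1e-1), 2)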
def test_update_requires_params(self):
  weight_decay = 0.1
  mask = {'a': True,
          'b': [False, True],
          'c': {'d': True, 'e': (False, True)}}
  params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
  input_updates = jax.tree_util.tree_map(lambda x: x / 10., params)

  correct_updates = jax.tree_util.tree_multimap(
      lambda m, u, p: u + weight_decay * p if m else u,
      mask, input_updates, params)

  init_fn, update_fn = wrappers.masked(
      transform.additive_weight_decay(weight_decay), mask)
  update_fn = self.variant(update_fn)

  state = init_fn(params)
  updates, state = update_fn(input_updates, state, params)
  chex.assert_tree_all_close(updates, correct_updates)

  params = update.apply_updates(params, updates)

  # Test repeated application
  new_correct_updates = jax.tree_util.tree_multimap(
      lambda m, u, p: u + weight_decay * p if m else u,
      mask, correct_updates, params)
  updates, state = update_fn(correct_updates, state, params)
  chex.assert_tree_all_close(updates, new_correct_updates)
def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol):
  # experimental/optimizers.py
  jax_params = self.init_params
  opt_init, opt_update, get_params = jax_optimizer
  state = opt_init(jax_params)
  for i in range(STEPS):
    state = opt_update(i, self.per_step_updates, state)
  jax_params = get_params(state)

  # optax
  optax_params = self.init_params
  state = optax_optimizer.init(optax_params)

  @self.variant
  def step(updates, state):
    return optax_optimizer.update(updates, state)

  for _ in range(STEPS):
    updates, state = step(self.per_step_updates, state)
    optax_params = update.apply_updates(optax_params, updates)

  # Check equivalence.
  chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
def update_fn(updates, state, params):
  # The test optimizer does not use the parameters, but we check that they
  # have been passed correctly.
  chex.assert_trees_all_equal_shapes(updates, params)
  aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
  updates = jax.tree_map(lambda u: step_size * u, updates)
  return updates, TestOptimizerState(aggregate_grads, is_reset=False)
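# Hedged aside (not from the test file): an (init_fn, update_fn) pair like the
# one above is normally bundled into an `optax.GradientTransformation` so it
# can be used like any built-in optimizer. The toy transformation below keeps
# no state and uses an illustrative constant step size, unlike the stateful
# test optimizer above.
import jax
import optax

toy_step_size = 0.1  # Illustrative constant.

def _toy_init_fn(params):
  del params  # This toy transformation keeps no state.
  return ()

def _toy_update_fn(updates, state, params=None):
  del params  # Unused, like most simple transformations.
  return jax.tree_map(lambda u: toy_step_size * u, updates), state

toy_optimizer = optax.GradientTransformation(_toy_init_fn, _toy_update_fn)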
def test_apply_every(self):
  # The frequency of the application of sgd
  k = 4
  zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

  # optax sgd
  optax_sgd_params = self.init_params
  sgd = alias.sgd(LR, 0.0)
  state_sgd = sgd.init(optax_sgd_params)

  # optax sgd plus apply every
  optax_sgd_apply_every_params = self.init_params
  sgd_apply_every = combine.chain(
      transform.apply_every(k=k),
      transform.trace(decay=0, nesterov=False),
      transform.scale(-LR))
  state_sgd_apply_every = sgd_apply_every.init(optax_sgd_apply_every_params)
  transform_fn = self.variant(sgd_apply_every.update)

  for i in range(STEPS):
    # Apply a step of sgd
    updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd)
    optax_sgd_params = update.apply_updates(optax_sgd_params, updates_sgd)

    # Apply a step of sgd_apply_every
    updates_sgd_apply_every, state_sgd_apply_every = transform_fn(
        self.per_step_updates, state_sgd_apply_every)
    optax_sgd_apply_every_params = update.apply_updates(
        optax_sgd_apply_every_params, updates_sgd_apply_every)

    # Every k steps, check equivalence.
    if i % k == k - 1:
      chex.assert_tree_all_close(
          optax_sgd_apply_every_params, optax_sgd_params,
          atol=1e-6, rtol=1e-5)
    # Otherwise, check update is zero.
    else:
      chex.assert_tree_all_close(
          updates_sgd_apply_every, zero_update, atol=0.0, rtol=0.0)
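# Hedged aside: in the public namespace the accumulating optimizer above reads
# roughly as below; `optax.apply_every(k)` sums incoming updates and emits
# them only on every k-th step, producing zero updates in between, which is
# what the test asserts. The learning rate is an illustrative stand-in for LR.
import optax

sgd_every_4 = optax.chain(
    optax.apply_every(k=4),
    optax.trace(decay=0, nesterov=False),
    optax.scale(-1e-2))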
def do_update(loss_fun, optimizer, params, opt_state):
  loss, grads = jax.value_and_grad(loss_fun)(params)
  # Complex gradients need to be conjugated before being added to parameters
  grads = jax.tree_map(lambda x: x.conj(), grads)
  updates, opt_state = self.variant(optimizer.update)(
      grads, opt_state, params)
  params = update.apply_updates(params, updates)
  return loss, grads, params, opt_state
def step(params, state):
  updates = get_updates(params)
  if opt_name == 'dpsgd':
    updates = updates[None]
  # Complex gradients need to be conjugated before being added to parameters
  # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
  updates = jax.tree_map(lambda x: x.conj(), updates)
  updates, state = opt.update(updates, state, params)
  params = update.apply_updates(params, updates)
  return params, state
def loop(self, optimizer, num_steps, params):
  """Performs a given number of optimizer steps."""
  init_fn, update_fn = optimizer
  # Use the chex variant to check various function versions (jit, pmap, etc).
  step = self.variant(update_fn)
  opt_state = self.variant(init_fn)(params)
  for _ in range(num_steps):
    updates, opt_state = step(self.grads, opt_state, params)
    params = update.apply_updates(params, updates)
  return params, opt_state
def test_float32_input_outputs(self, transform_constr, transform_kwargs):
  initial_params = (
      jnp.array([1., 2.], dtype=jnp.float32),
      jnp.array([3., 4.], dtype=jnp.float32))
  updates = (
      jnp.array([10., 21.], dtype=jnp.float32),
      jnp.array([33., 42.], dtype=jnp.float32))
  scaler = transform_constr(**transform_kwargs)
  init_fn = self.variant(scaler.init)
  update_fn = self.variant(scaler.update)

  initial_state = init_fn(initial_params)
  updates, new_state = update_fn(
      updates, initial_state, params=initial_params)
  new_params = update.apply_updates(initial_params, updates)

  self._assert_dtype_equals(initial_state, new_state)
  self._assert_dtype_equals(initial_params, new_params)
def test_multi_steps(self):
  batch_size = 32
  x_size = 7
  # Parameters should be updated only every `k_steps` optimisation steps.
  k_steps = 4
  data = jnp.ones([batch_size, x_size])

  def get_loss(x):
    loss = jnp.sum(hk.Linear(10)(x)**2)
    return loss

  loss_init, loss_apply = hk.without_apply_rng(hk.transform(get_loss))
  params = loss_init(jax.random.PRNGKey(1915), data)

  ms_opt = wrappers.MultiSteps(alias.adam(1e-4), k_steps)
  opt_init, opt_update = ms_opt.gradient_transformation()

  # Put the training in one function, to check that the update is indeed
  # jittable.
  def train_step(data, opt_state, params):
    grad = jax.grad(loss_apply)(params, data)
    updates, opt_state = opt_update(grad, opt_state, params)
    return updates, opt_state

  opt_state = opt_init(params)

  prev_loss = loss_apply(params, data)
  for idx in range(5 * k_steps):
    updates, opt_state = self.variant(train_step)(data, opt_state, params)
    new_params = update.apply_updates(params, updates)
    new_loss = loss_apply(new_params, data)
    if idx % k_steps < k_steps - 1:
      # The parameters should not have changed and the loss should be
      # constant.
      jax.tree_multimap(np.testing.assert_array_equal, new_params, params)
      np.testing.assert_equal(new_loss, prev_loss)
      self.assertFalse(ms_opt.has_updated(opt_state))
    else:
      # This is a step where parameters should actually have been updated,
      # and the loss should accordingly go down.
      np.testing.assert_array_less(new_loss, prev_loss)
      prev_loss = new_loss
      self.assertTrue(ms_opt.has_updated(opt_state))
    params = new_params
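# Hedged standalone sketch of the wrapper under test, using the public
# namespace: `optax.MultiSteps` accumulates gradients and applies the inner
# optimizer only every k-th call, as the assertions above check; the second
# positional argument is the number of accumulation steps.
import optax

k_steps = 4
ms_opt = optax.MultiSteps(optax.adam(1e-4), k_steps)
opt_init, opt_update = ms_opt.gradient_transformation()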
def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
  # flax/optim
  flax_params = self.init_params
  flax_optimizer = flax_optimizer.create(flax_params)
  for _ in range(STEPS):
    flax_optimizer = flax_optimizer.apply_gradient(self.per_step_updates)
  flax_params = flax_optimizer.target

  # optax
  optax_params = self.init_params
  state = optax_optimizer.init(optax_params)
  for _ in range(STEPS):
    updates, state = optax_optimizer.update(
        self.per_step_updates, state, optax_params)
    optax_params = update.apply_updates(optax_params, updates)

  # Check equivalence.
  chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
def fn_update(params, opt_state, x):
  grads = jax.grad(fn.apply)(params, x)
  grads = jax.lax.psum(grads, axis_name='i')
  updates, new_opt_state = opt.update(grads, opt_state, params)
  new_params = update.apply_updates(params, updates)
  return new_params, new_opt_state
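# Hedged note (not part of the test): a step function like `fn_update` above
# only works under `jax.pmap` with a matching `axis_name`, since
# `jax.lax.psum(..., axis_name='i')` sums gradients across the devices mapped
# over that axis. A typical (illustrative) call site looks like:
#
#   pmapped_update = jax.pmap(fn_update, axis_name='i')
#   params, opt_state = pmapped_update(params, opt_state, x)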
def step(params, state):
  updates = get_updates(params)
  if opt_name == 'dpsgd':
    updates = updates[None]
  updates, state = opt.update(updates, state, params)
  params = update.apply_updates(params, updates)
  return params, state
def step(params, state):
  updates, state = opt.update(get_updates(params), state, params)
  params = update.apply_updates(params, updates)
  return params, state
def update_fn(updates, state, params=None):
  del params  # unused by the test optimizer
  aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
  updates = jax.tree_map(lambda u: step_size * u, updates)
  return updates, TestOptimizerState(aggregate_grads, is_reset=False)