def test_negative(self):
  """Check that a piecewise constant schedule handles negative values."""
  # Start at -0.1, multiply by 2 from step 3 and by 0.5 from step 6.
  fn = self.variant(
      schedule.piecewise_constant_schedule(-0.1, {3: 2., 6: 0.5}))

  # Evaluate the schedule over the first ten steps, in order.
  actual = np.array([fn(step) for step in range(10)])

  # Magnitudes: 0.1 for steps 0-2, 0.2 for steps 3-5, 0.1 for steps 6-9;
  # every value stays negative.
  magnitudes = np.array([0.1] * 3 + [0.2] * 3 + [0.1] * 4)
  np.testing.assert_allclose(-magnitudes, actual, atol=1e-3)
def test_updates(self):
  """Check that an injected step_size schedule sets the update magnitude."""
  # transform.scale is stateless, so only the injected hyperparameter
  # changes from one step to the next.
  step_size_fn = schedule.piecewise_constant_schedule(
      3.0, {2: 5, 8: 2, 13: 1.5})
  optim = schedule.inject_hyperparams(transform.scale)(
      step_size=step_size_fn)

  params = [jnp.zeros([], dtype=jnp.float32)]
  grads = [jnp.ones([], dtype=jnp.float32)]
  state = self.variant(optim.init)(params)
  update_fn = self.variant(optim.update)

  # Expected step size for each step index: 3.0, then x5 from index 2,
  # x2 from index 8 and x1.5 from index 13.
  expected_step_size = [3.0] * 2 + [15.0] * 6 + [30.0] * 5 + [45.0] * 3

  # The k-th call (k = 1, 2, ...) is compared with expected_step_size[k].
  # Entry 0 is skipped, so 15 updates are checked.
  for want in expected_step_size[1:]:
    updates, state = update_fn(grads, state, params=params)
    np.testing.assert_almost_equal(updates[0], want)
def test_hyperparams_state(self):
  """Check that injected schedules on a stateful transform are stored in state."""
  optim = schedule.inject_hyperparams(transform.trace)(  # stateful
      decay=schedule.piecewise_constant_schedule(0.8, {4: 0.5, 10: 1.25}),
      nesterov=True)
  params = [jnp.zeros([2, 3]) for _ in range(3)]
  state = self.variant(optim.init)(params)
  update_fn = self.variant(optim.update)

  # Decay is 0.8 for steps 0-3, 0.8 * 0.5 = 0.4 for steps 4-9,
  # and 0.4 * 1.25 = 0.5 from step 10 onwards.
  expected_mom = [0.8] * 4 + [0.4] * 6 + [0.5] * 2
  # `jax.tree_map` is deprecated and removed in newer JAX releases;
  # `jax.tree_util.tree_map` is the supported equivalent.
  grads = jax.tree_util.tree_map(jnp.ones_like, params)

  # Before each update, the stored decay must equal the value expected
  # for the current step.
  for mom in expected_mom:
    np.testing.assert_almost_equal(state.hyperparams['decay'], mom)
    _, state = update_fn(grads, state)
  # After the final update, the decay is still in the last segment.
  np.testing.assert_almost_equal(state.hyperparams['decay'], expected_mom[-1])