Пример #1
0
 def test_negative(self):
     """Check piecewise constant schedule of negative values.

     The schedule starts at -0.1 with boundaries {3: 2., 6: 0.5}. The test
     expects -0.1 for steps 0-2, -0.2 for steps 3-5 and -0.1 for steps 6-9,
     i.e. each boundary's scale multiplies the running value and the
     negative sign is preserved.
     """
     # Wrap the schedule function for the current test variant.
     schedule_fn = self.variant(
         schedule.piecewise_constant_schedule(-0.1, {
             3: 2.,
             6: 0.5
         }))
     # Evaluate the schedule at steps 0..9.
     generated_vals = [schedule_fn(count) for count in range(10)]
     expected_vals = -1 * np.array(
         [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1])
     # numpy's signature is assert_allclose(actual, desired): passing the
     # generated values first keeps the ACTUAL/DESIRED labels in failure
     # messages correct and makes rtol relative to the expected values.
     np.testing.assert_allclose(np.array(generated_vals),
                                expected_vals,
                                atol=1e-3)
Пример #2
0
    def test_updates(self):
        """Check that a scheduled step size is injected into ``transform.scale``.

        ``transform.scale`` is stateless, so with a unit gradient each update
        is asserted to equal the step size expected for that step.
        """
        step_size_schedule = schedule.piecewise_constant_schedule(
            3.0, {
                2: 5,
                8: 2,
                13: 1.5
            })
        optim = schedule.inject_hyperparams(transform.scale)(  # stateless
            step_size=step_size_schedule)

        params = [jnp.zeros([], dtype=jnp.float32)]
        grads = [jnp.ones([], dtype=jnp.float32)]
        state = self.variant(optim.init)(params)
        update_fn = self.variant(optim.update)

        # Entry c: 3.0 before count 2, then x5 = 15.0, x2 = 30.0 from count 8
        # and x1.5 = 45.0 from count 13.
        expected_step_size = [3.0] * 2 + [15.0] * 6 + [30.0] * 5 + [45.0] * 3

        # The update produced at step i is compared with entry i + 1, so
        # entry 0 is never checked. NOTE(review): this offset encodes how
        # this version of inject_hyperparams evaluates the schedule relative
        # to the step count; confirm against the library before changing it.
        for step_size in expected_step_size[1:]:
            updates, state = update_fn(grads, state, params=params)
            np.testing.assert_almost_equal(updates[0], step_size)
Пример #3
0
    def test_hyperparams_state(self):
        """Check that scheduled hyperparameters are exposed in the optimizer state.

        ``transform.trace`` is stateful. Its ``decay`` follows a piecewise
        constant schedule, and ``state.hyperparams['decay']`` is checked
        before every update and once more after the last one.
        """
        optim = schedule.inject_hyperparams(transform.trace)(  # stateful
            decay=schedule.piecewise_constant_schedule(0.8, {
                4: 0.5,
                10: 1.25
            }),
            nesterov=True)

        params = [jnp.zeros([2, 3]) for _ in range(3)]
        state = self.variant(optim.init)(params)
        update_fn = self.variant(optim.update)

        # expected_mom[k] is the decay expected in the state after k updates:
        # 0.8, then 0.8 * 0.5 = 0.4, then 0.4 * 1.25 = 0.5.
        expected_mom = [0.8] * 4 + [0.4] * 6 + [0.5] * 2
        # jax.tree_map is a deprecated alias that newer JAX releases no longer
        # provide; jax.tree_util.tree_map is the stable spelling.
        grads = jax.tree_util.tree_map(jnp.ones_like, params)
        for expected_decay in expected_mom:
            np.testing.assert_almost_equal(state.hyperparams['decay'],
                                           expected_decay)
            _, state = update_fn(grads, state)

        # After the final update the decay must still be the last value.
        np.testing.assert_almost_equal(state.hyperparams['decay'],
                                       expected_mom[-1])