Exemplo n.º 1
0
    def test_numeric_static_args(self, static_args):
        """Checks that names listed in `static_args` are kept out of the state.

        An injected `scale_by_adam` is built with numeric `b1`/`b2`. After one
        `init`/`update` round trip, none of the names in `static_args` may
        appear among the keys of `state.hyperparams`.
        """
        make_adam = schedule.inject_hyperparams(
            transform.scale_by_adam, static_args=static_args)
        optim = make_adam(b1=0.9, b2=0.95)

        params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))]
        state = self.variant(optim.init)(params)
        # The parameters double as the gradients for this single step.
        _, state = self.variant(optim.update)(params, state)

        injected_names = set(state.hyperparams.keys())
        assert not injected_names & set(static_args)
Exemplo n.º 2
0
    def test_overriding_hyperparam(self):
        """Checks that values written into `state.hyperparams` are honoured.

        Starting from `clip_by_global_norm(0.1)`, the `max_norm` entry of the
        state is overwritten before each update. Clipping an all-ones gradient
        must then yield an update whose global norm is close to the value
        just written.
        """
        clip = schedule.inject_hyperparams(transform.clip_by_global_norm)(0.1)
        params = jnp.zeros((3, 5, 7))
        grads = jnp.ones_like(params)

        state = self.variant(clip.init)(params)
        step = self.variant(clip.update)

        for max_norm in range(5):
            state.hyperparams['max_norm'] = max_norm
            updates, state = step(grads, state)
            clipped_norm = jnp.linalg.norm(updates.ravel())
            assert np.isclose(clipped_norm, max_norm)
Exemplo n.º 3
0
    def test_constant_hyperparams(self):
        """Checks constant and non-passed hyperparameters of an injected Adam.

        `scale_by_adam` is injected with `b1=0` and `b2=0` only. `eps` and
        `eps_root` are not passed explicitly; the test expects them to be
        present in the state with values 1e-8 and 0.0. Each update is expected
        to be (close to) the incoming gradients.
        """
        optim = schedule.inject_hyperparams(transform.scale_by_adam)(b1=0.,
                                                                     b2=0.)

        params = [jnp.zeros([2, 3]) for _ in range(3)]
        state = self.variant(optim.init)(params)
        update_fn = self.variant(optim.update)

        grads = jax.tree_map(jnp.ones_like, params)
        for _ in range(5):
            updates, state = update_fn(grads, state, params)
            # Check membership before indexing. An `in` check placed after
            # `state.hyperparams['eps']` could never fail, because a missing
            # key would already have raised KeyError.
            assert 'eps' in state.hyperparams
            assert 'eps_root' in state.hyperparams
            np.testing.assert_almost_equal(state.hyperparams['b1'], 0.0)
            np.testing.assert_almost_equal(state.hyperparams['b2'], 0.0)
            np.testing.assert_almost_equal(state.hyperparams['eps'], 1e-8)
            np.testing.assert_almost_equal(state.hyperparams['eps_root'], 0.0)
            chex.assert_tree_all_close(updates, grads)
Exemplo n.º 4
0
    def test_updates(self):
        """Checks that a scheduled `step_size` drives a stateless transform.

        `transform.scale` is injected with a piecewise-constant step size that
        starts at 3.0 and has boundaries at counts 2, 8 and 13. A scalar
        gradient of 1.0 is fed for 15 steps, and each update is compared with
        the expected step size.
        """
        step_size_fn = schedule.piecewise_constant_schedule(
            3.0, {2: 5, 8: 2, 13: 1.5})
        optim = schedule.inject_hyperparams(transform.scale)(  # stateless
            step_size=step_size_fn)

        params = [jnp.zeros([], dtype=jnp.float32)]
        state = self.variant(optim.init)(params)
        update_fn = self.variant(optim.update)
        # Values are 3, 3*5, 3*5*2 and 3*5*2*1.5. The segment lengths line up
        # with the boundaries (2, 8, 13), so entry c presumably holds the step
        # size at count c.
        expected_step_size = [3.0] * 2 + [15.0] * 6 + [30.0] * 5 + [45.0] * 3

        grads = [jnp.ones([], dtype=jnp.float32)]
        # The k-th call (0-based) is compared against entry k + 1, so the
        # table is walked from index 1. Its 16 entries give exactly 15 updates.
        for expected in expected_step_size[1:]:
            updates, state = update_fn(grads, state, params=params)
            np.testing.assert_almost_equal(updates[0], expected)
Exemplo n.º 5
0
    def test_hyperparams_state(self):
        """Checks that a scheduled `decay` is tracked in `state.hyperparams`.

        `transform.trace` (with `nesterov=True`) is injected with a
        piecewise-constant decay that starts at 0.8 and has boundaries at
        counts 4 and 10. The decay stored in the state is compared with the
        expected value before each of 12 updates, and once more after the
        last one.
        """
        decay_schedule = schedule.piecewise_constant_schedule(
            0.8, {4: 0.5, 10: 1.25})
        optim = schedule.inject_hyperparams(transform.trace)(  # stateful
            decay=decay_schedule, nesterov=True)

        params = [jnp.zeros([2, 3]) for _ in range(3)]
        state = self.variant(optim.init)(params)
        update_fn = self.variant(optim.update)

        # 0.8, then 0.8 * 0.5 = 0.4, then 0.4 * 1.25 = 0.5.
        expected_decay = [0.8] * 4 + [0.4] * 6 + [0.5] * 2
        grads = jax.tree_map(jnp.ones_like, params)
        for expected in expected_decay:
            np.testing.assert_almost_equal(state.hyperparams['decay'],
                                           expected)
            _, state = update_fn(grads, state)

        # After the final update the state should still hold the last value.
        np.testing.assert_almost_equal(state.hyperparams['decay'],
                                       expected_decay[-1])
Exemplo n.º 6
0
 def test_static_args_error(self, static_args):
     """Checks that wrapping `transform.scale` with these `static_args` fails.

     `schedule.inject_hyperparams(transform.scale, static_args=...)` is
     expected to raise ValueError for the parameterized `static_args`.
     """
     self.assertRaises(ValueError, schedule.inject_hyperparams,
                       transform.scale, static_args=static_args)