Example 1
 def test_apply_gradient(self):
   optimizer_def = optim.Adam(learning_rate=0.1,
                              beta1=0.2,
                              beta2=0.9,
                              eps=0.01,
                              weight_decay=0.0)
   params = onp.array([1.])
   state = optim.OptimizerState(
       1, _AdamParamState(onp.array([0.1]), onp.array([0.9])))
   grads = onp.array([4.])
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   expected_new_state = optim.OptimizerState(
       2, _AdamParamState(onp.array([3.22]), onp.array([2.41])))
   expected_new_params = onp.array([0.906085])
   onp.testing.assert_allclose(new_params, expected_new_params)
   self.assertEqual(new_state, expected_new_state)
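As a sanity check, the expected values in this test follow from the standard Adam update with bias correction (weight_decay is 0 here, so the decay term drops out). The snippet below is an illustrative NumPy recomputation, not part of the original test:

import numpy as np

# Recompute Example 1's expected values: the state at step 1 holds
# grad_ema=0.1 and grad_sq_ema=0.9, and the incoming gradient is 4.0.
lr, beta1, beta2, eps = 0.1, 0.2, 0.9, 0.01
param, grad = np.array([1.0]), np.array([4.0])

grad_ema = beta1 * 0.1 + (1 - beta1) * grad        # -> [3.22]
grad_sq_ema = beta2 * 0.9 + (1 - beta2) * grad**2  # -> [2.41]

t = 2.0                                            # step + 1
grad_ema_corr = grad_ema / (1 - beta1**t)
grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)
new_param = param - lr * grad_ema_corr / (np.sqrt(grad_sq_ema_corr) + eps)
print(new_param)                                   # ~[0.906085]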
Example 2
  def test_init_state(self):
    params = onp.zeros((1,))
    optimizer_def = optim.Adam(learning_rate=0.1,
                               beta1=0.2,
                               beta2=0.9,
                               eps=0.01,
                               weight_decay=0.0)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdamHyperParams(0.1, 0.2, 0.9, 0.01, 0.0)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = optim.OptimizerState(
        0, _AdamParamState(onp.zeros((1,)), onp.zeros((1,))))
    self.assertEqual(state, expected_state)
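Both test snippets above are excerpts from a test class and reference names they do not import. A plausible import preamble, assuming the deprecated flax.optim package (the exact module paths are a guess and may differ between flax versions):

import numpy as onp                # the tests use the `onp` alias for NumPy
from flax import optim             # optim.Adam, optim.OptimizerState
# private dataclasses referenced in the assertions (location assumed):
from flax.optim.adam import _AdamHyperParams, _AdamParamState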
Example 3
    def apply_param_gradient(self, step, hyper_params, param, state, grad,
                             hessian):
        """takes an additional hessian parameter"""
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay

        # reduce the Hessian estimate with average_magnitude (defined elsewhere)
        # and square it; it replaces the squared gradient in the second-moment EMA
        hessian = average_magnitude(hessian)
        hessian_sq = jax.lax.square(hessian)
        # exponential moving averages, as in Adam
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * hessian_sq

        # bias correction
        t = step + 1.
        grad_ema_corr = grad_ema / (1 - beta1**t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

        # scale by the bias-corrected second moment's square root raised to
        # hessian_power, then apply decoupled (AdamW-style) weight decay
        denom = jnp.sqrt(
            grad_sq_ema_corr)**hyper_params.hessian_power + hyper_params.eps
        new_param = param - hyper_params.learning_rate * grad_ema_corr / denom
        new_param -= hyper_params.learning_rate * weight_decay * param
        new_state = _AdamParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
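This update mirrors Adam, except that the second-moment EMA tracks a reduced, squared Hessian estimate and its square root is raised to hyper_params.hessian_power (similar in spirit to AdaHessian-style methods). Neither average_magnitude nor hessian_power is defined in the excerpt. The standalone sketch below reproduces the same arithmetic on concrete numbers; the stand-in average_magnitude (mean absolute value) and all hyperparameter values are assumptions for illustration:

import jax.numpy as jnp

def average_magnitude(x):
  # stand-in for the undefined helper: reduce to the mean absolute value
  return jnp.mean(jnp.abs(x))

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
weight_decay, hessian_power = 0.0, 1.0
step = 0
param, grad = jnp.array([1.0]), jnp.array([0.5])
hessian = jnp.array([2.0])                     # per-parameter curvature estimate
grad_ema, grad_sq_ema = jnp.zeros(1), jnp.zeros(1)

hessian_sq = average_magnitude(hessian) ** 2
grad_ema = beta1 * grad_ema + (1. - beta1) * grad
grad_sq_ema = beta2 * grad_sq_ema + (1. - beta2) * hessian_sq

t = step + 1.
denom = jnp.sqrt(grad_sq_ema / (1 - beta2**t)) ** hessian_power + eps
new_param = param - lr * (grad_ema / (1 - beta1**t)) / denom
new_param -= lr * weight_decay * param         # decoupled (AdamW-style) weight decay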
Example 4
 def init_param_state(self, param):
     return _AdamParamState(jnp.zeros_like(param), jnp.zeros_like(param))
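For context, init_param_state and apply_param_gradient are the two per-parameter hooks of the (now deprecated) flax.optim.OptimizerDef interface: init_state in Example 2 maps init_param_state over every parameter leaf, and apply_gradient in Example 1 does the same with apply_param_gradient. A rough, hypothetical skeleton under that assumption (the names below are illustrative, not the original code):

import jax.numpy as jnp
from flax import optim, struct   # deprecated flax.optim API

@struct.dataclass
class _SketchParamState:
  grad_ema: jnp.ndarray
  grad_sq_ema: jnp.ndarray

class SketchOptimizer(optim.OptimizerDef):
  def init_param_state(self, param):
    # zero-initialize both moving averages with the parameter's shape and dtype
    return _SketchParamState(jnp.zeros_like(param), jnp.zeros_like(param))

  def apply_param_gradient(self, step, hyper_params, param, state, grad):
    ...  # per-parameter update, e.g. as in Examples 1 and 3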