def test_apply_gradient(self):
  optimizer_def = optim.Adam(learning_rate=0.1,
                             beta1=0.2,
                             beta2=0.9,
                             eps=0.01,
                             weight_decay=0.0)
  params = onp.array([1.])
  state = optim.OptimizerState(
      1, _AdamParamState(onp.array([0.1]), onp.array([0.9])))
  grads = onp.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2, _AdamParamState(onp.array([3.22]), onp.array([2.41])))
  expected_new_params = onp.array([0.906085])
  onp.testing.assert_allclose(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
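# Hand-check of the expected constants above, added for clarity (not part of
# the original test). It follows the standard Adam update; the step count in
# OptimizerState is 1, so bias correction uses t = 2:
#   grad_ema    = 0.2 * 0.1 + 0.8 * 4.0  = 3.22
#   grad_sq_ema = 0.9 * 0.9 + 0.1 * 16.0 = 2.41
#   grad_ema_corr    = 3.22 / (1 - 0.2**2) ≈ 3.354167
#   grad_sq_ema_corr = 2.41 / (1 - 0.9**2) ≈ 12.684211
#   new_param = 1.0 - 0.1 * 3.354167 / (sqrt(12.684211) + 0.01) ≈ 0.906085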
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.Adam(learning_rate=0.1,
                             beta1=0.2,
                             beta2=0.9,
                             eps=0.01,
                             weight_decay=0.0)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdamHyperParams(0.1, 0.2, 0.9, 0.01, 0.0)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _AdamParamState(onp.zeros((1,)), onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def apply_param_gradient(self, step, hyper_params, param, state, grad, hessian):
  """Like Adam's update, but takes an additional `hessian` argument."""
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  hessian = average_magnitude(hessian)  # helper defined elsewhere in this module
  hessian_sq = jax.lax.square(hessian)
  # First moment tracks the gradient; second moment tracks the squared
  # Hessian diagonal instead of the squared gradient.
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * hessian_sq

  # bias correction
  t = step + 1.
  grad_ema_corr = grad_ema / (1 - beta1**t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

  denom = jnp.sqrt(grad_sq_ema_corr)**hyper_params.hessian_power + hyper_params.eps
  new_param = param - hyper_params.learning_rate * grad_ema_corr / denom
  new_param -= hyper_params.learning_rate * weight_decay * param
  new_state = _AdamParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
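# A minimal sketch of where the `hessian` argument could come from, using
# Hutchinson's estimator: for a Rademacher probe z (entries ±1), z * (H @ z)
# is an unbiased estimate of diag(H). The names `loss_fn` and
# `diag_hessian_estimate` are illustrative assumptions, not part of the
# optimizer code above; the Hessian-vector product is formed with
# forward-over-reverse autodiff, so H is never materialized.
import jax


def diag_hessian_estimate(loss_fn, param, key):
  # Rademacher probe: ±1 entries, cast to the parameter dtype.
  z = jax.random.rademacher(key, param.shape).astype(param.dtype)
  # Hessian-vector product H @ z via jvp of the gradient function.
  _, hvp = jax.jvp(jax.grad(loss_fn), (param,), (z,))
  return z * hvp  # unbiased estimate of the Hessian diagonal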
def init_param_state(self, param):
  return _AdamParamState(jnp.zeros_like(param), jnp.zeros_like(param))
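# A toy single-step replay of the update rule above, reusing
# diag_hessian_estimate from the earlier sketch. The loss and the
# hyper-parameter values are illustrative assumptions, and the
# average_magnitude step is omitted since it is the identity for the
# constant Hessian diagonal used here.
import jax
import jax.numpy as jnp

loss_fn = lambda p: jnp.sum(p ** 2)  # Hessian is diag(2, 2, 2), so the
                                     # Hutchinson estimate is exact here
param = jnp.ones((3,))
grad = jax.grad(loss_fn)(param)      # = 2 * param
hess = diag_hessian_estimate(loss_fn, param, jax.random.PRNGKey(0))

beta1, beta2, lr, eps, hessian_power = 0.9, 0.999, 0.1, 1e-8, 1.0
grad_ema = (1. - beta1) * grad               # EMAs start from zeros (init_param_state)
grad_sq_ema = (1. - beta2) * hess ** 2
grad_ema_corr = grad_ema / (1 - beta1)       # bias correction with t = 1
grad_sq_ema_corr = grad_sq_ema / (1 - beta2)
denom = jnp.sqrt(grad_sq_ema_corr) ** hessian_power + eps
new_param = param - lr * grad_ema_corr / denom
print(new_param)                             # each entry ≈ 1 - 0.1 * 2 / 2 = 0.9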