# Assumed context for these tests (legacy, pre-Linen flax.optim API):
import numpy as np
from flax import optim
# _AdafactorHyperParams and _AdafactorParamState are the private named tuples
# from flax's adafactor module; check_eq is the suite's pytree-equality
# helper. Exact import paths vary by flax version, so they are elided here.


def test_init_state(self):
  params = np.zeros((3, 2))

  # Factored case: min_dim_size_to_factor=0 factors even a (3, 2) param
  # into row/column second-moment statistics.
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=0)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdafactorHyperParams(
      0.1, True, True, None, 0.8, 0, 1.0, None, 0, 1e-30, 1e-3)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _AdafactorParamState(
          np.zeros((2,)), np.zeros((3,)), np.zeros((1,)), np.zeros((1,))))
  check_eq(state, expected_state)

  # Unfactored case: the threshold of 32 exceeds both dims, so a full
  # (3, 2) second moment and momentum buffer are kept instead.
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, decay_rate=0.8, beta1=0.0, min_dim_size_to_factor=32)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdafactorHyperParams(
      0.1, True, True, 0.0, 0.8, 0, 1.0, None, 32, 1e-30, 1e-3)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _AdafactorParamState(
          np.zeros((1,)), np.zeros((1,)), np.zeros((3, 2)), np.zeros((3, 2))))
  check_eq(state, expected_state)
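# The shape expectations above follow from Adafactor's factoring rule. A
# minimal sketch of that rule, assuming (as the legacy implementation
# appears to) that factoring applies when the param has rank >= 2 and its
# second-largest dim reaches min_dim_size_to_factor; _is_factored below is
# an illustrative helper, not a flax.optim API:
#
#   def _is_factored(shape, min_dim_size_to_factor):
#     return len(shape) >= 2 and sorted(shape)[-2] >= min_dim_size_to_factor
#
# For a (3, 2) param, sorted(shape)[-2] == 2, so the state is factored with
# threshold 0 (stats of shapes (2,) and (3,)) and unfactored with 32.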
def test_apply_gradient(self):
  # Factored case: the row/column second-moment statistics are decayed
  # toward the squared gradient, and the update is clipped then scaled
  # by the learning rate.
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0)
  params = np.ones((3, 2), np.float32)
  state = optim.OptimizerState(
      1, _AdafactorParamState(
          np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
          np.zeros((1,)), np.zeros((1,))))
  grads = np.ones((3, 2), np.float32)
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2, _AdafactorParamState(
          np.array([0.9574349, 0.9574349]),
          np.array([0.6169143, 0.6169143, 0.6169143]),
          np.zeros((1,)), np.zeros((1,))))
  expected_new_params = 0.9 * np.ones((3, 2))
  np.testing.assert_allclose(new_params, expected_new_params)
  check_eq(new_state, expected_new_state, rtol=1e-6)

  # Unfactored case with momentum: beta1=0.0 still allocates and updates
  # a full momentum buffer alongside the (3, 2) second moment.
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, beta1=0.0, decay_rate=0.8, min_dim_size_to_factor=32)
  params = np.ones((3, 2), np.float32)
  state = optim.OptimizerState(
      1, _AdafactorParamState(
          np.zeros((1,)), np.zeros((1,)),
          0.5 * np.ones((3, 2)), np.zeros((3, 2))))
  grads = np.ones((3, 2), np.float32)
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_params = 0.9 * np.ones((3, 2))
  np.testing.assert_allclose(new_params, expected_new_params)
  expected_new_state = optim.OptimizerState(
      2, _AdafactorParamState(
          np.array([0.0]), np.array([0.0]),
          0.787174 * np.ones((3, 2)), 0.1 * np.ones((3, 2))))
  check_eq(new_state, expected_new_state, rtol=1e-6)
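# Sanity-check sketch (not part of the original tests): the second-moment
# constants asserted above follow from Adafactor's step-dependent decay,
# assuming beta2_t = 1 - t**(-decay_rate) as in the Adafactor paper. The
# helper name decayed_second_moment is illustrative, not a flax.optim API.
def decayed_second_moment(prev_v, grad_sq, step, decay_rate=0.8):
  beta2 = 1.0 - step ** (-decay_rate)  # ~0.425651 at step 2
  return beta2 * prev_v + (1.0 - beta2) * grad_sq

# Reproduces the step-2 values expected in test_apply_gradient.
assert np.isclose(decayed_second_moment(0.9, 1.0, 2), 0.9574349)
assert np.isclose(decayed_second_moment(0.1, 1.0, 2), 0.6169143)
assert np.isclose(decayed_second_moment(0.5, 1.0, 2), 0.787174)

# Typical use of this legacy API outside tests (pre-Linen flax.optim):
#   optimizer = optimizer_def.create(params)      # wrap params + init state
#   optimizer = optimizer.apply_gradient(grads)   # one Adafactor step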