def test_apply_gradient(self):
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  params = onp.ones((1,))
  state = optim.OptimizerState(0, _MomentumParamState(onp.array([1.])))
  grads = onp.array([3.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      1, _MomentumParamState(onp.array([3.2])))
  expected_new_params = onp.array([1. - 0.32])
  self.assertEqual(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
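# For reference, the expected constants above follow from the conventional
# momentum update (a sketch of the arithmetic, not an assertion about the
# library internals): with beta = 0.2, momentum m = 1.0, and gradient g = 3.0,
#   new_m = beta * m + g = 0.2 * 1.0 + 3.0 = 3.2
#   new_p = p - learning_rate * new_m = 1.0 - 0.1 * 3.2 = 1.0 - 0.32 = 0.68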
def test_momentum_with_weight_norm(self):
  params = onp.ones((2, 2)) * 2.
  optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
  state = optimizer_def.init_state(params)
  self.assertEqual(jax.tree_map(onp.shape, state), optim.OptimizerState(
      step=(),
      param_states=_WeightNormParamState(
          direction_state=_MomentumParamState(momentum=(2, 2)),
          scale_state=_MomentumParamState(momentum=(1, 2)),
          mult=(1, 2),
      )))
  grads = onp.ones((2, 2))
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
  onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2 ** 0.5)
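# A plausible reading of the expected values (assuming weight normalization
# factorizes each column into a direction and a scale): with zero initial
# momentum, the first step in this fully symmetric case reduces to
# p - learning_rate * g, so every entry moves 2.0 - 0.1 * 1.0 = 1.9, and the
# stored multiplier tracks the per-column L2 norm of the new parameters:
#   sqrt(1.9 ** 2 + 1.9 ** 2) = 1.9 * 2 ** 0.5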
def test_create(self):
  params = onp.ones((1,))
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  optimizer = optimizer_def.create(params)
  expected_state = optim.OptimizerState(
      0, _MomentumParamState(onp.zeros((1,))))
  self.assertEqual(optimizer.optimizer_def, optimizer_def)
  self.assertEqual(optimizer.state, expected_state)
  self.assertEqual(optimizer.target, params)
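# Typical downstream usage of the created optimizer (a sketch, not part of
# this test): Optimizer.apply_gradient returns a new Optimizer bundling the
# updated target and state, so a training step looks roughly like
#   optimizer = optimizer.apply_gradient(grads)
# where grads is a pytree with the same structure as optimizer.target.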
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _MomentumHyperParams(0.1, 0.2, 0, False)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _MomentumParamState(onp.zeros((1,))))
  self.assertEqual(state, expected_state)
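# Note on the positional fields of _MomentumHyperParams(0.1, 0.2, 0, False):
# these presumably correspond to learning_rate, beta, weight_decay, and
# nesterov (an assumption inferred from the constructor arguments used above;
# only the first two are set explicitly in this test).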