Example #1
0
 def test_apply_gradient(self):
     """Check one Momentum update step against hand-computed values.

     Starting from params = 1.0, momentum = 1.0 and a gradient of 3.0, with
     learning_rate=0.1 and beta=0.2, the expected values below are consistent
     with momentum' = beta * momentum + grad = 3.2 and
     params' = params - learning_rate * momentum' = 1.0 - 0.32.
     The step counter is expected to advance from 0 to 1.
     """
     optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
     params = np.ones((1, ))
     state = optim.OptimizerState(0, _MomentumParamState(np.array([1.])))
     grads = np.array([3.])
     new_params, new_state = optimizer_def.apply_gradient(
         optimizer_def.hyper_params, params, state, grads)
     # The integer step counter can be compared exactly.
     self.assertEqual(new_state.step, 1)
     # Floating-point results are compared with a tolerance instead of
     # assertEqual. Exact equality is brittle (e.g. 0.1 * 3.2 and 0.32 are
     # not the same double). assertEqual on ndarrays also only works for
     # single-element arrays, whose elementwise `==` result is truthy.
     np.testing.assert_allclose(new_state.param_states.momentum,
                                np.array([3.2]))
     np.testing.assert_allclose(new_params, np.array([1. - 0.32]))
Example #2
0
 def test_momentum_with_weight_norm(self):
   """Check Momentum wrapped in WeightNorm on a 2x2 parameter matrix.

   First verifies the leaf shapes of the freshly initialized state tree.
   Then applies one all-ones gradient step and checks the updated params
   and the `mult` entry of the new state.
   """
   params = onp.ones((2, 2)) * 2.
   optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
   state = optimizer_def.init_state(params)
   # Replace every leaf of the state tree with its shape, so only the tree
   # structure and leaf shapes are compared, not the initial values.
   # jax.tree_util.tree_map is used because the jax.tree_map alias is
   # deprecated (and removed in newer JAX releases).
   state_shapes = jax.tree_util.tree_map(onp.shape, state)
   self.assertEqual(state_shapes, optim.OptimizerState(
       step=(),
       param_states=_WeightNormParamState(
           direction_state=_MomentumParamState(momentum=(2, 2)),
           scale_state=_MomentumParamState(momentum=(1, 2)),
           mult=(1, 2)
       )
   ))
   grads = onp.ones((2, 2))
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
   # 1.9 * sqrt(2) is the L2 norm of a column holding two 1.9 entries,
   # i.e. of each column of the expected new_params.
   onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2 ** 0.5)
Example #3
0
 def test_create(self):
   """`create` keeps the definition and target and starts from zero momentum.

   The expected state has step 0 and a momentum buffer of zeros with the
   same shape as the target.
   """
   target = onp.ones((1,))
   momentum_def = optim.Momentum(learning_rate=0.1, beta=0.2)
   opt = momentum_def.create(target)
   zero_momentum = _MomentumParamState(onp.zeros((1,)))
   self.assertEqual(opt.optimizer_def, momentum_def)
   self.assertEqual(opt.state, optim.OptimizerState(0, zero_momentum))
   self.assertEqual(opt.target, target)
Example #4
0
 def test_init_state(self):
   """`init_state` returns step 0 with a zero momentum buffer.

   Also checks that the definition's hyper params hold the constructor's
   learning_rate and beta, followed by trailing fields 0 and False
   (NOTE(review): presumably defaults of other Momentum options — confirm).
   """
   momentum_def = optim.Momentum(learning_rate=0.1, beta=0.2)
   zero_params = onp.zeros((1,))
   initial_state = momentum_def.init_state(zero_params)
   self.assertEqual(
       momentum_def.hyper_params,
       _MomentumHyperParams(0.1, 0.2, 0, False))
   self.assertEqual(
       initial_state,
       optim.OptimizerState(0, _MomentumParamState(onp.zeros((1,)))))