Example 1: MultiOptimizer with a separate GradientDescent learning rate per parameter
 def test_multi_optimizer(self):
   params = {'a': 0., 'b': 0.}
   opt_a = optim.GradientDescent(learning_rate=1.)
   opt_b = optim.GradientDescent(learning_rate=10.)
   t_a = traverse_util.t_identity['a']
   t_b = traverse_util.t_identity['b']
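   # traversals selecting only the 'a' / 'b' entries of the params dict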
   optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
   state = optimizer_def.init_state(params)
   expected_hyper_params = [
       _GradientDescentHyperParams(1.),
       _GradientDescentHyperParams(10.)
   ]
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
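   # one OptimizerState per sub-optimizer, each starting at step 0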
   expected_state = [optim.OptimizerState(0, [()])] * 2
   self.assertEqual(state, expected_state)
   grads = {'a': -1., 'b': -2.}
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
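   # per-parameter updates: a = 0 - 1 * (-1) = 1,  b = 0 - 10 * (-2) = 20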
   expected_params = {'a': 1., 'b': 20.}
   expected_state = [optim.OptimizerState(1, [()])] * 2
   self.assertEqual(new_state, expected_state)
   self.assertEqual(new_params, expected_params)
   # override learning_rate for every sub-optimizer at once
   hp = optimizer_def.update_hyper_params(learning_rate=2.)
   new_params, new_state = optimizer_def.apply_gradient(
       hp, params, state, grads)
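   # starting again from the original params/state, both entries now use
   # learning_rate=2: a = 0 - 2 * (-1) = 2,  b = 0 - 2 * (-2) = 4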
   expected_params = {'a': 2., 'b': 4.}
   self.assertEqual(new_params, expected_params)
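Below is a minimal sketch of the same setup driven through the higher-level optimizer wrapper instead of the raw init_state/apply_gradient calls. It assumes flax.optim's OptimizerDef.create and Optimizer.apply_gradient / Optimizer.target API, which the test above does not exercise, and simply replays the test's numbers.

from flax import optim, traverse_util

params = {'a': 0., 'b': 0.}
optimizer_def = optim.MultiOptimizer(
    (traverse_util.t_identity['a'], optim.GradientDescent(learning_rate=1.)),
    (traverse_util.t_identity['b'], optim.GradientDescent(learning_rate=10.)))
# assumed: OptimizerDef.create wraps params and optimizer state into an Optimizer
optimizer = optimizer_def.create(params)
# assumed: Optimizer.apply_gradient returns a new Optimizer with an updated target
optimizer = optimizer.apply_gradient({'a': -1., 'b': -2.})
# optimizer.target should now be {'a': 1., 'b': 20.}, matching expected_params above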
Example 2: GradientDescent init_state and default hyper-parameters
 def test_init_state(self):
   params = onp.zeros((1,))
   optimizer_def = optim.GradientDescent(learning_rate=0.1)
   state = optimizer_def.init_state(params)
   expected_hyper_params = _GradientDescentHyperParams(0.1)
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
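   # step counter starts at 0; plain gradient descent keeps no per-parameter state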
   expected_state = optim.OptimizerState(0, ())
   self.assertEqual(state, expected_state)
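Below is a minimal sketch of the update step that typically follows init_state. It reuses the apply_gradient call pattern from Example 1; the gradient value (ones) is illustrative and not taken from the test.

import numpy as onp
from flax import optim

params = onp.zeros((1,))
optimizer_def = optim.GradientDescent(learning_rate=0.1)
state = optimizer_def.init_state(params)
grads = onp.ones((1,))  # illustrative gradient
new_params, new_state = optimizer_def.apply_gradient(
    optimizer_def.hyper_params, params, state, grads)
# plain gradient descent: new_params should equal params - 0.1 * grads,
# i.e. [-0.1], and new_state.step should advance from 0 to 1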