def test_multi_optimizer(self):
  params = {'a': 0., 'b': 0.}
  opt_a = optim.GradientDescent(learning_rate=1.)
  opt_b = optim.GradientDescent(learning_rate=10.)
  # Each traversal selects one parameter; MultiOptimizer pairs it with its
  # own optimizer so 'a' and 'b' are updated with different learning rates.
  t_a = traverse_util.t_identity['a']
  t_b = traverse_util.t_identity['b']
  optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
  state = optimizer_def.init_state(params)
  expected_hyper_params = [
      _GradientDescentHyperParams(1.),
      _GradientDescentHyperParams(10.)
  ]
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = [optim.OptimizerState(0, [()])] * 2
  self.assertEqual(state, expected_state)
  grads = {'a': -1., 'b': -2.}
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  # a: 0. - 1. * -1. == 1.; b: 0. - 10. * -2. == 20.
  expected_params = {'a': 1., 'b': 20.}
  expected_state = [optim.OptimizerState(1, [()])] * 2
  self.assertEqual(new_state, expected_state)
  self.assertEqual(new_params, expected_params)
  # Overriding learning_rate applies to both sub-optimizers.
  hp = optimizer_def.update_hyper_params(learning_rate=2.)
  new_params, new_state = optimizer_def.apply_gradient(
      hp, params, state, grads)
  expected_params = {'a': 2., 'b': 4.}
  self.assertEqual(new_params, expected_params)
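# A rough standalone sketch of the behaviour exercised above: each traversal
# selects a parameter subset, and the paired optimizer steps only that subset
# with its own learning rate. (Hypothetical helper for illustration only,
# not MultiOptimizer's actual implementation.)
def _per_key_sgd_step(params, grads, rates):
  # Plain gradient descent per key: new = old - rate * grad.
  return {k: params[k] - rates[k] * grads[k] for k in params}

# _per_key_sgd_step({'a': 0., 'b': 0.}, {'a': -1., 'b': -2.},
#                   {'a': 1., 'b': 10.}) returns {'a': 1., 'b': 20.},
# matching expected_params in test_multi_optimizer.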
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.GradientDescent(learning_rate=0.1)
  state = optimizer_def.init_state(params)
  # The learning rate is stored as the only hyper parameter.
  expected_hyper_params = _GradientDescentHyperParams(0.1)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  # Gradient descent keeps no per-parameter state, only the step count.
  expected_state = optim.OptimizerState(0, ())
  self.assertEqual(state, expected_state)
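# A minimal end-to-end sketch of the pattern both tests cover, using only
# calls already exercised above (assumes this file's optim module and
# numpy-as-onp import): build a definition, initialise state, take one step.
def _sgd_demo():
  params = onp.zeros((1,))
  optimizer_def = optim.GradientDescent(learning_rate=0.1)
  state = optimizer_def.init_state(params)
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, onp.ones((1,)))
  # With learning_rate=0.1 and grad=1., params move from 0. to -0.1.
  return new_params, new_state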