def test_init_state(self):
  params = np.zeros((3, 2))
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=0)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdafactorHyperParams(
      0.1, True, True, None, 0.8, 0, 1.0, None, 0, 1e-30, 1e-3)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0,
      _AdafactorParamState(
          np.zeros((2,)), np.zeros((3,)), np.zeros((1,)), np.zeros((1,))))
  check_eq(state, expected_state)

  # unfactorized
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, decay_rate=0.8, beta1=0.0, min_dim_size_to_factor=32)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdafactorHyperParams(
      0.1, True, True, 0.0, 0.8, 0, 1.0, None, 32, 1e-30, 1e-3)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0,
      _AdafactorParamState(
          np.zeros((1,)), np.zeros((1,)), np.zeros((3, 2)), np.zeros((3, 2))))
  check_eq(state, expected_state)
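# Note on the expected shapes above: with min_dim_size_to_factor=0 the (3, 2)
# second-moment estimate is kept in factored form as two vectors, shapes (2,)
# and (3,), while the full-size slots hold (1,) placeholders. With
# min_dim_size_to_factor=32 (larger than either dimension) no factoring takes
# place: the full (3, 2) statistics are stored and the factored slots collapse
# to (1,) placeholders instead.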
def test_multi_optimizer(self):
  params = {'a': 0., 'b': 0.}
  opt_a = optim.GradientDescent(learning_rate=1.)
  opt_b = optim.GradientDescent(learning_rate=10.)
  t_a = traverse_util.t_identity['a']
  t_b = traverse_util.t_identity['b']
  optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
  state = optimizer_def.init_state(params)
  expected_hyper_params = [
      _GradientDescentHyperParams(1.),
      _GradientDescentHyperParams(10.)
  ]
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = [optim.OptimizerState(0, [()])] * 2
  self.assertEqual(state, expected_state)
  grads = {'a': -1., 'b': -2.}
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_params = {'a': 1., 'b': 20.}
  expected_state = [optim.OptimizerState(1, [()])] * 2
  self.assertEqual(new_state, expected_state)
  self.assertEqual(new_params, expected_params)

  # override learning_rate
  hp = optimizer_def.update_hyper_params(learning_rate=2.)
  new_params, new_state = optimizer_def.apply_gradient(
      hp, params, state, grads)
  expected_params = {'a': 2., 'b': 4.}
  self.assertEqual(new_params, expected_params)
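# The expected values follow from plain SGD applied per sub-tree:
#   a: 0. - 1.  * (-1.) = 1.     b: 0. - 10. * (-2.) = 20.
# Overriding learning_rate=2. replaces the rate of *both* sub-optimizers:
#   a: 0. - 2. * (-1.) = 2.      b: 0. - 2. * (-2.) = 4.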
def test_apply_gradient(self):
  optimizer_def = optim.GradientDescent(learning_rate=0.1)
  params = onp.ones((1,))
  state = optim.OptimizerState(0, ())
  grads = onp.array([3.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(1, ())
  expected_new_params = onp.array([0.7])
  self.assertEqual(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
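# Plain gradient descent: new_param = param - lr * grad = 1. - 0.1 * 3. = 0.7,
# and the step counter advances from 0 to 1.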
def test_apply_gradient(self):
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  params = np.ones((1,))
  state = optim.OptimizerState(0, _MomentumParamState(np.array([1.])))
  grads = np.array([3.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      1, _MomentumParamState(np.array([3.2])))
  expected_new_params = np.array([1. - 0.32])
  self.assertEqual(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
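# The expected constants are reproduced by the classical momentum rule:
#   momentum_new = beta * momentum + grad    = 0.2 * 1. + 3.  = 3.2
#   param_new    = param - lr * momentum_new = 1. - 0.1 * 3.2 = 0.68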
def test_apply_gradient(self):
  optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
  params = onp.array([1.])
  state = optim.OptimizerState(1, _RMSPropParamState(onp.array([0.1])))
  grads = onp.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2, _RMSPropParamState(onp.array([1.69])))
  expected_new_params = onp.array([0.6946565])
  onp.testing.assert_allclose(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
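# The expected constants match the RMSProp update:
#   v_new     = beta2 * v + (1 - beta2) * grad**2 = 0.9 * 0.1 + 0.1 * 16. = 1.69
#   param_new = param - lr * grad / (sqrt(v_new) + eps)
#             = 1. - 0.1 * 4. / (1.3 + 0.01) ≈ 0.6946565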
def test_apply_gradient(self):
  optimizer_def = optim.Adagrad(learning_rate=0.1, eps=0.01)
  params = np.array([1.])
  state = optim.OptimizerState(1, _AdagradParamState(np.array([0.1])))
  grads = np.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2, _AdagradParamState(np.array([16.1])))
  expected_new_params = np.array([0.9005588])
  np.testing.assert_allclose(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
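# Adagrad accumulates the raw squared gradient:
#   G_new     = G + grad**2 = 0.1 + 16. = 16.1
#   param_new = param - lr * grad / (sqrt(G_new) + eps)
#             = 1. - 0.4 / (sqrt(16.1) + 0.01) ≈ 0.9005588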
def test_optimizer_with_focus(self):
  params = {'a': 0., 'b': 0.}
  opt_def = optim.GradientDescent(learning_rate=1.)
  t_a = traverse_util.t_identity['a']
  optimizer = opt_def.create(params, focus=t_a)
  expected_state = [optim.OptimizerState(0, [()])]
  self.assertEqual(optimizer.state, expected_state)
  grads = {'a': -1., 'b': -2.}
  new_optimizer = optimizer.apply_gradient(grads)
  # Only the focused sub-tree 'a' is updated; 'b' keeps its initial value.
  expected_params = {'a': 1., 'b': 0.}
  expected_state = [optim.OptimizerState(1, [()])]
  self.assertEqual(new_optimizer.state, expected_state)
  self.assertEqual(new_optimizer.target, expected_params)
def test_apply_gradient(self):
  optimizer_def = optim.Adadelta(
      learning_rate=0.1, rho=0.9, eps=1e-6, weight_decay=0.1)
  params = np.array([1.])
  state = optim.OptimizerState(
      1, _AdadeltaParamState(np.zeros((1,)), np.zeros((1,))))
  grads = np.array([1.])
  new_param, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2, _AdadeltaParamState(np.array([0.1]), np.array([9.999902e-7])))
  expected_new_params = np.array([0.9896838])
  np.testing.assert_allclose(new_param, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
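# The expected constants are consistent with the Adadelta recurrences:
#   sq_avg_new    = rho * sq_avg + (1 - rho) * grad**2 = 0.1 * 1. = 0.1
#   update        = sqrt(acc_delta + eps) / sqrt(sq_avg_new + eps) * grad
#                 = 1e-3 / sqrt(0.100001) ≈ 0.0031623
#   acc_delta_new = rho * acc_delta + (1 - rho) * update**2 ≈ 9.999902e-7
#   param_new     = param - lr * (update + weight_decay * param)
#                 = 1. - 0.1 * (0.0031623 + 0.1) ≈ 0.9896838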
def test_empty_optimizer(self):
  params = {}
  optimizer_def = optim.Momentum(learning_rate=0.1)
  optimizer = optimizer_def.create(params)
  new_optimizer = optimizer.apply_gradient({})
  expected_state = optim.OptimizerState(1, {})
  self.assertEqual(new_optimizer.state, expected_state)
def test_apply_gradient(self):
  optimizer_def = optim.Adam(
      learning_rate=0.1, beta1=0.2, beta2=0.9, eps=0.01, weight_decay=0.0)
  params = onp.array([1.])
  state = optim.OptimizerState(
      1, _AdamParamState(onp.array([0.1]), onp.array([0.9])))
  grads = onp.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2, _AdamParamState(onp.array([3.22]), onp.array([2.41])))
  expected_new_params = onp.array([0.906085])
  onp.testing.assert_allclose(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
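# The expected constants match the Adam update at step t = 2 (the incoming
# state holds step = 1):
#   m_new = beta1 * m + (1 - beta1) * grad    = 0.2 * 0.1 + 0.8 * 4.  = 3.22
#   v_new = beta2 * v + (1 - beta2) * grad**2 = 0.9 * 0.9 + 0.1 * 16. = 2.41
#   m_hat = 3.22 / (1 - 0.2**2) ≈ 3.3542;  v_hat = 2.41 / (1 - 0.9**2) ≈ 12.684
#   param_new = 1. - 0.1 * m_hat / (sqrt(v_hat) + eps) ≈ 0.906085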
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.GradientDescent(learning_rate=0.1)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _GradientDescentHyperParams(0.1)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(0, ())
  self.assertEqual(state, expected_state)
def test_create(self):
  params = onp.ones((1,))
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  optimizer = optimizer_def.create(params)
  expected_state = optim.OptimizerState(
      0, _MomentumParamState(onp.zeros((1,))))
  self.assertEqual(optimizer.optimizer_def, optimizer_def)
  self.assertEqual(optimizer.state, expected_state)
  self.assertEqual(optimizer.target, params)
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _MomentumHyperParams(0.1, 0.2, 0, False)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _MomentumParamState(onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def test_apply_gradient(self):
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0)
  params = onp.ones((3, 2), onp.float32)
  state = optim.OptimizerState(
      1,
      _AdafactorParamState(
          onp.array([0.9, 0.9]), onp.array([0.1, 0.1, 0.1]),
          onp.zeros((1,)), onp.zeros((1,))))
  grads = onp.ones((3, 2), onp.float32)
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2,
      _AdafactorParamState(
          onp.array([0.9574349, 0.9574349]),
          onp.array([0.6169143, 0.6169143, 0.6169143]),
          onp.zeros((1,)), onp.zeros((1,))))
  expected_new_params = 0.9 * onp.ones((3, 2))
  onp.testing.assert_allclose(new_params, expected_new_params)
  check_eq(new_state, expected_new_state, rtol=1e-6)

  # unfactored with momentum
  optimizer_def = optim.Adafactor(
      learning_rate=0.1, beta1=0.0, decay_rate=0.8, min_dim_size_to_factor=32)
  params = onp.ones((3, 2), onp.float32)
  state = optim.OptimizerState(
      1,
      _AdafactorParamState(
          onp.zeros((1,)), onp.zeros((1,)),
          0.5 * onp.ones((3, 2)), onp.zeros((3, 2))))
  grads = onp.ones((3, 2), onp.float32)
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_params = 0.9 * onp.ones((3, 2))
  onp.testing.assert_allclose(new_params, expected_new_params)
  expected_new_state = optim.OptimizerState(
      2,
      _AdafactorParamState(
          onp.array([0.0]), onp.array([0.0]),
          0.787174 * onp.ones((3, 2)),
          0.1 * onp.ones((3, 2))))
  check_eq(new_state, expected_new_state, rtol=1e-6)
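# The factored statistics follow Adafactor's step-dependent decay at t = 2:
#   beta2_t = 1 - t**(-decay_rate) = 1 - 2**(-0.8) ≈ 0.42565
#   the (2,)-shaped statistic: beta2_t * 0.9 + (1 - beta2_t) * 1. ≈ 0.9574349
#   the (3,)-shaped statistic: beta2_t * 0.1 + (1 - beta2_t) * 1. ≈ 0.6169143
# Because every gradient entry is equal, the normalized update is 1 and each
# parameter moves by exactly lr = 0.1, from 1.0 to 0.9. In the unfactored
# branch, v_new = beta2_t * 0.5 + (1 - beta2_t) * 1. ≈ 0.787174, and with
# beta1 = 0.0 the momentum slot holds the applied step of 0.1.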
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _RMSPropHyperParams(0.1, 0.9, 0.01)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _RMSPropParamState(onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.Adagrad(learning_rate=0.1, eps=0.01)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdagradHyperParams(0.1, 0.01)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _AdagradParamState(onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def test_apply_gradient_centered(self):
  optimizer_def = optim.RMSProp(
      learning_rate=0.1, beta2=0.9, eps=0.01, centered=True)
  params = np.array([1.])
  state = optim.OptimizerState(
      1, _RMSPropParamState(np.array([0.1]), np.array([0.1])))
  grads = np.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      2, _RMSPropParamState(np.array([1.69]), np.array([0.49])))
  expected_new_params = np.array([0.670543], dtype=np.float32)
  np.testing.assert_allclose(new_params, expected_new_params, rtol=1e-6)
  np.testing.assert_allclose(new_state.param_states.v,
                             expected_new_state.param_states.v)
  np.testing.assert_allclose(new_state.param_states.mg,
                             expected_new_state.param_states.mg)
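# Centered RMSProp also tracks the mean gradient (mg) and subtracts its square
# from the second moment before taking the root:
#   v_new  = 0.9 * 0.1 + 0.1 * 16. = 1.69
#   mg_new = 0.9 * 0.1 + 0.1 * 4.  = 0.49
#   param_new = 1. - 0.1 * 4. / (sqrt(1.69 - 0.49**2) + 0.01) ≈ 0.670543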
def maybe_copy_model_from_pretraining(optimizer, pretrain_optimizer, step,
                                      adam_opt_def):
  """Copy model parameters from pretraining."""
  if step < FLAGS.num_pretrain_steps:
    optimizer = jax_utils.unreplicate(optimizer)
    # Serialize the pretraining target together with this optimizer's own
    # param states, resetting the step counter to the current step.
    state_dict = adam_opt_def.state_dict(
        target=jax_utils.unreplicate(pretrain_optimizer).target,
        state=optim.OptimizerState(
            jnp.asarray(step, dtype=jnp.int32),
            optimizer.state.param_states))
    optimizer = optimizer.restore_state(state_dict)
    optimizer = jax_utils.replicate(optimizer)
  return optimizer
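# A minimal usage sketch (`train_step` and the loop structure are illustrative
# assumptions, not from the source): while `step` is below
# FLAGS.num_pretrain_steps, the fine-tuning optimizer's target is overwritten
# each iteration with the pretraining target, so fine-tuning only diverges
# from the pretrained model once pretraining ends:
#
#   for step in range(num_train_steps):
#     optimizer = maybe_copy_model_from_pretraining(
#         optimizer, pretrain_optimizer, step, adam_opt_def)
#     optimizer = train_step(optimizer, next(batches))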
def test_init_state(self):
  params = np.zeros((1,))
  optimizer_def = optim.Adadelta(
      learning_rate=0.1, rho=0.9, eps=1e-6, weight_decay=0.1)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdadeltaHyperParams(0.1, 0.9, 1e-6, 0.1)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _AdadeltaParamState(np.zeros((1,)), np.zeros((1,))))
  self.assertEqual(state, expected_state)
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.Adam(
      learning_rate=0.1, beta1=0.2, beta2=0.9, eps=0.01, weight_decay=0.0)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _AdamHyperParams(0.1, 0.2, 0.9, 0.01, 0.0)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _AdamParamState(onp.zeros((1,)), onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def test_momentum_with_weight_norm(self):
  params = onp.ones((2, 2)) * 2.
  optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
  state = optimizer_def.init_state(params)
  self.assertEqual(
      jax.tree_map(onp.shape, state),
      optim.OptimizerState(
          step=(),
          param_states=_WeightNormParamState(
              direction_state=_MomentumParamState(momentum=(2, 2)),
              scale_state=_MomentumParamState(momentum=(1, 2)),
              mult=(1, 2))))
  grads = onp.ones((2, 2))
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
  onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2**0.5)
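# WeightNorm reparameterizes each column as scale * direction. With all
# entries equal to 2., each column norm (the stored `mult`) starts at
# sqrt(2**2 + 2**2) = 2 * sqrt(2); after the step every entry is 1.9, so the
# new mult is sqrt(2 * 1.9**2) = 1.9 * sqrt(2), as asserted.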
def test_init_state(self):
  # Create an optimizer def and check the params are wired through.
  optimizer_def = shampoo.Shampoo(
      learning_rate=0.1,
      beta1=0.9,
      beta2=0.9,
      diagonal_epsilon=0.0,
      matrix_epsilon=1e-1,
      exponent_override=2,
      weight_decay=1e-4,
      start_preconditioning_step=1,
      preconditioning_compute_steps=1,
      statistics_compute_steps=1,
      no_preconditioning_for_layers_with_dim_gt=8192,
      best_effort_shape_interpretation=True,
      block_size=8,
      graft_type=shampoo.LayerwiseGrafting.SGD,
      nesterov=False,
      batch_axis_name=None)
  expected_hyper_params = shampoo._ShampooHyperParams(
      learning_rate=0.1,
      beta1=0.9,
      beta2=0.9,
      diagonal_eps=0.0,
      matrix_eps=1e-1,
      exponent_override=2,
      weight_decay=1e-4,
      start_preconditioning_step=1,
      preconditioning_compute_steps=1,
      statistics_compute_steps=1,
      no_preconditioning_for_layers_with_dim_gt=8192,
      best_effort_shape_interpretation=True,
      block_size=8,
      graft_type=shampoo.LayerwiseGrafting.SGD,
      nesterov=False,
      batch_axis_name=None)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)

  params = np.zeros((1,))
  state = optimizer_def.init_state(params)
  zeros_like_param = np.zeros((1,))
  expected_state = optim.OptimizerState(
      0,
      shampoo._ShampooDefaultParamState(zeros_like_param, [], [],
                                        zeros_like_param, zeros_like_param))
  self.assertEqual(state, expected_state)

  params = np.zeros((8,))
  state = optimizer_def.init_state(params)
  identity = np.eye(8)
  statistic = identity * 1e-1  # I * matrix_epsilon
  preconditioner = identity
  self.assertLen(state.param_states.statistics, 1)
  self.assertLen(state.param_states.preconditioners, 1)
  np.testing.assert_allclose(state.param_states.statistics[0], statistic)
  np.testing.assert_allclose(state.param_states.preconditioners[0],
                             preconditioner)

  params = np.zeros((8, 8))
  state = optimizer_def.init_state(params)
  self.assertLen(state.param_states.statistics, 2)
  self.assertLen(state.param_states.preconditioners, 2)
  np.testing.assert_allclose(state.param_states.statistics[0], statistic)
  np.testing.assert_allclose(state.param_states.statistics[1], statistic)
  np.testing.assert_allclose(state.param_states.preconditioners[0],
                             preconditioner)
  np.testing.assert_allclose(state.param_states.preconditioners[1],
                             preconditioner)

  params = np.zeros((16, 16))
  state = optimizer_def.init_state(params)
  self.assertLen(state.param_states.statistics, 8)
  self.assertLen(state.param_states.preconditioners, 8)
  for i in range(8):
    np.testing.assert_allclose(state.param_states.statistics[i], statistic)
    np.testing.assert_allclose(state.param_states.preconditioners[i],
                               preconditioner)

  # Test best_effort_shape_interpretation:
  # (3, 2, 16) will be reshaped to (6, 16), and the last dim will be split
  # into two blocks of (6, 8) each.
  params = np.zeros((3, 2, 16))
  state = optimizer_def.init_state(params)
  identity_left = np.eye(6)
  statistic_left = identity_left * 1e-1  # I * matrix_epsilon
  preconditioner_left = identity_left
  identity_right = np.eye(8)
  statistic_right = identity_right * 1e-1  # I * matrix_epsilon
  preconditioner_right = identity_right
  self.assertLen(state.param_states.statistics, 4)
  self.assertLen(state.param_states.preconditioners, 4)
  for i in range(4):
    if i % 2 == 0:
      np.testing.assert_allclose(state.param_states.statistics[i],
                                 statistic_left)
      np.testing.assert_allclose(state.param_states.preconditioners[i],
                                 preconditioner_left)
    else:
      np.testing.assert_allclose(state.param_states.statistics[i],
                                 statistic_right)
      np.testing.assert_allclose(state.param_states.preconditioners[i],
                                 preconditioner_right)
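# Block counts implied by block_size=8 (as asserted above):
#   (8,)       -> a single 8x8 statistic/preconditioner pair.
#   (8, 8)     -> one statistic per dimension -> 2.
#   (16, 16)   -> 2 x 2 = 4 blocks of (8, 8), two statistics each -> 8.
#   (3, 2, 16) -> merged to (6, 16), split into two (6, 8) blocks, each with
#                 a (6, 6) left and an (8, 8) right statistic -> 4.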