Example #1
    def test_init_state(self):
        params = np.zeros((3, 2))
        optimizer_def = optim.Adafactor(learning_rate=0.1,
                                        decay_rate=0.8,
                                        beta1=None,
                                        min_dim_size_to_factor=0)
        state = optimizer_def.init_state(params)

        expected_hyper_params = _AdafactorHyperParams(0.1, True, True, None,
                                                      0.8, 0, 1.0, None, 0,
                                                      1e-30, 1e-3)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
        expected_state = optim.OptimizerState(
            0,
            _AdafactorParamState(np.zeros((2, )), np.zeros((3, )),
                                 np.zeros((1, )), np.zeros((1, ))))
        check_eq(state, expected_state)

        # unfactorized
        optimizer_def = optim.Adafactor(learning_rate=0.1,
                                        decay_rate=0.8,
                                        beta1=0.0,
                                        min_dim_size_to_factor=32)
        state = optimizer_def.init_state(params)

        expected_hyper_params = _AdafactorHyperParams(0.1, True, True, 0.0,
                                                      0.8, 0, 1.0, None, 32,
                                                      1e-30, 1e-3)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
        expected_state = optim.OptimizerState(
            0,
            _AdafactorParamState(np.zeros((1, )), np.zeros((1, )),
                                 np.zeros((3, 2)), np.zeros((3, 2))))
        check_eq(state, expected_state)
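A note on the two cases above: the (3, 2) parameter is factored into separate row/column second-moment accumulators (shapes (2,) and (3,)) when min_dim_size_to_factor=0, but kept as a single full-shape accumulator plus momentum when the threshold is 32 and beta1=0.0. The sketch below restates that shape rule; should_factor is a hypothetical helper written for illustration, not part of flax.optim.

# Hedged sketch of the factoring rule exercised above (an assumption drawn from
# the expected shapes): factor only when the parameter is at least 2-D and its
# smaller trailing dimension reaches the threshold.
def should_factor(shape, min_dim_size_to_factor):
    return len(shape) >= 2 and min(shape[-2:]) >= min_dim_size_to_factor

assert should_factor((3, 2), 0)        # factored: accumulators of shape (2,) and (3,)
assert not should_factor((3, 2), 32)   # unfactored: full (3, 2) v and m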
Example #2
 def test_multi_optimizer(self):
   params = {'a': 0., 'b': 0.}
   opt_a = optim.GradientDescent(learning_rate=1.)
   opt_b = optim.GradientDescent(learning_rate=10.)
   t_a = traverse_util.t_identity['a']
   t_b = traverse_util.t_identity['b']
   optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
   state = optimizer_def.init_state(params)
   expected_hyper_params = [
       _GradientDescentHyperParams(1.),
       _GradientDescentHyperParams(10.)
   ]
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
   expected_state = [optim.OptimizerState(0, [()])] * 2
   self.assertEqual(state, expected_state)
   grads = {'a': -1., 'b': -2.}
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   expected_params = {'a': 1., 'b': 20.}
   expected_state = [optim.OptimizerState(1, [()])] * 2
   self.assertEqual(new_state, expected_state)
   self.assertEqual(new_params, expected_params)
   # override learning_rate
   hp = optimizer_def.update_hyper_params(learning_rate=2.)
   new_params, new_state = optimizer_def.apply_gradient(
       hp, params, state, grads)
   expected_params = {'a': 2., 'b': 4.}
   self.assertEqual(new_params, expected_params)
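The expected values follow from one gradient-descent step per traversal; a quick arithmetic check mirroring the numbers above:

# 'a' uses learning_rate=1., 'b' uses learning_rate=10.
assert 0. - 1. * (-1.) == 1.
assert 0. - 10. * (-2.) == 20.
# After update_hyper_params(learning_rate=2.) both traversals use the override.
assert 0. - 2. * (-1.) == 2.
assert 0. - 2. * (-2.) == 4.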
Example #3
 def test_apply_gradient(self):
   optimizer_def = optim.GradientDescent(learning_rate=0.1)
   params = onp.ones((1,))
   state = optim.OptimizerState(0, ())
   grads = onp.array([3.])
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   expected_new_state = optim.OptimizerState(1, ())
   expected_new_params = onp.array([0.7])
   self.assertEqual(new_params, expected_new_params)
   self.assertEqual(new_state, expected_new_state)
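The expected parameter is the gradient-descent rule p - lr * g applied once:

# Gradient descent sanity check for the numbers above.
assert abs((1. - 0.1 * 3.) - 0.7) < 1e-12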
Example #4
 def test_apply_gradient(self):
     optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
     params = np.ones((1, ))
     state = optim.OptimizerState(0, _MomentumParamState(np.array([1.])))
     grads = np.array([3.])
     new_params, new_state = optimizer_def.apply_gradient(
         optimizer_def.hyper_params, params, state, grads)
     expected_new_state = optim.OptimizerState(
         1, _MomentumParamState(np.array([3.2])))
     expected_new_params = np.array([1. - 0.32])
     self.assertEqual(new_params, expected_new_params)
     self.assertEqual(new_state, expected_new_state)
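Worked arithmetic behind the expected values, assuming the momentum update implied by the numbers (m_new = beta * m + g, p_new = p - lr * m_new):

# Momentum sanity check for the numbers above.
m_new = 0.2 * 1. + 3.        # 3.2, matches the expected state
p_new = 1. - 0.1 * m_new     # 0.68 == 1. - 0.32, matches the expected params
assert abs(m_new - 3.2) < 1e-12 and abs(p_new - 0.68) < 1e-12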
Example #5
 def test_apply_gradient(self):
     optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
     params = onp.array([1.])
     state = optim.OptimizerState(1, _RMSPropParamState(onp.array([0.1])))
     grads = onp.array([4.])
     new_params, new_state = optimizer_def.apply_gradient(
         optimizer_def.hyper_params, params, state, grads)
     expected_new_state = optim.OptimizerState(
         2, _RMSPropParamState(onp.array([1.69])))
     expected_new_params = onp.array([0.6946565])
     onp.testing.assert_allclose(new_params, expected_new_params)
     self.assertEqual(new_state, expected_new_state)
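The expected values are reproduced by the usual RMSProp update with eps added outside the square root (the placement consistent with 0.6946565):

# RMSProp sanity check: v_new = beta2 * v + (1 - beta2) * g**2,
# p_new = p - lr * g / (sqrt(v_new) + eps).
v_new = 0.9 * 0.1 + 0.1 * 4.**2              # 1.69, matches the expected state
p_new = 1. - 0.1 * 4. / (v_new**0.5 + 0.01)  # ~0.6946565
assert abs(v_new - 1.69) < 1e-12 and abs(p_new - 0.6946565) < 1e-6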
Example #6
 def test_apply_gradient(self):
     optimizer_def = optim.Adagrad(learning_rate=0.1, eps=0.01)
     params = np.array([1.])
     state = optim.OptimizerState(1, _AdagradParamState(np.array([0.1])))
     grads = np.array([4.])
     new_params, new_state = optimizer_def.apply_gradient(
         optimizer_def.hyper_params, params, state, grads)
     expected_new_state = optim.OptimizerState(
         2, _AdagradParamState(np.array([16.1])))
     expected_new_params = np.array([0.9005588])
     np.testing.assert_allclose(new_params, expected_new_params)
     self.assertEqual(new_state, expected_new_state)
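The same kind of check for Adagrad, with the accumulated squared gradient and eps again outside the square root (matching 0.9005588):

# Adagrad sanity check: G_new = G + g**2, p_new = p - lr * g / (sqrt(G_new) + eps).
G_new = 0.1 + 4.**2                           # 16.1, matches the expected state
p_new = 1. - 0.1 * 4. / (G_new**0.5 + 0.01)   # ~0.9005588
assert abs(G_new - 16.1) < 1e-12 and abs(p_new - 0.9005588) < 1e-6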
Example #7
 def test_optimizer_with_focus(self):
   params = {'a': 0., 'b': 0.}
   opt_def = optim.GradientDescent(learning_rate=1.)
   t_a = traverse_util.t_identity['a']
   optimizer = opt_def.create(params, focus=t_a)
   expected_state = [optim.OptimizerState(0, [()])]
   self.assertEqual(optimizer.state, expected_state)
   grads = {'a': -1., 'b': -2.}
   new_optimizer = optimizer.apply_gradient(grads)
   expected_params = {'a': 1., 'b': 0.}
   expected_state = [optim.OptimizerState(1, [()])]
   self.assertEqual(new_optimizer.state, expected_state)
   self.assertEqual(new_optimizer.target, expected_params)
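A short reading of the expected values: because focus=t_a wraps only the 'a' leaf, a single sub-state is created and only 'a' receives a gradient-descent step.

# a: 0. - 1. * (-1.) = 1.; 'b' is outside the focus traversal and stays at 0.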
Example #8
 def test_apply_gradient(self):
     optimizer_def = optim.Adadelta(learning_rate=0.1,
                                    rho=0.9,
                                    eps=1e-6,
                                    weight_decay=0.1)
     params = np.array([1.])
     state = optim.OptimizerState(
         1, _AdadeltaParamState(np.zeros((1, )), np.zeros((1, ))))
     grads = np.array([1.])
     new_param, new_state = optimizer_def.apply_gradient(
         optimizer_def.hyper_params, params, state, grads)
     expected_new_state = optim.OptimizerState(
         2, _AdadeltaParamState(np.array([0.1]), np.array([9.999902e-7])))
     expected_new_params = np.array([0.9896838])
     np.testing.assert_allclose(new_param, expected_new_params)
     self.assertEqual(new_state, expected_new_state)
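One reading of the expected numbers that reproduces them (decayed squared-gradient and squared-delta accumulators, with weight decay added at the update step; this is an interpretation of the values above, not a claim about the flax source):

# Adadelta sanity check (rho=0.9, eps=1e-6, weight_decay=0.1, g=1, p=1).
sq_avg = 0.9 * 0. + 0.1 * 1.**2                        # 0.1, matches the state
delta = 1. * (0. + 1e-6)**0.5 / (sq_avg + 1e-6)**0.5   # ~0.0031623
acc_delta = 0.9 * 0. + 0.1 * delta**2                  # ~9.99990e-7, matches the state
p_new = 1. - 0.1 * (delta + 0.1 * 1.)                  # ~0.9896838, matches the params
assert abs(p_new - 0.9896838) < 1e-6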
Example #9
 def test_empty_optimizer(self):
   params = {}
   optimizer_def = optim.Momentum(learning_rate=0.1)
   optimizer = optimizer_def.create(params)
   new_optimizer = optimizer.apply_gradient({})
   expected_state = optim.OptimizerState(1, {})
   self.assertEqual(new_optimizer.state, expected_state)
Example #10
 def test_apply_gradient(self):
   optimizer_def = optim.Adam(learning_rate=0.1,
                              beta1=0.2,
                              beta2=0.9,
                              eps=0.01,
                              weight_decay=0.0)
   params = onp.array([1.])
   state = optim.OptimizerState(
       1, _AdamParamState(onp.array([0.1]), onp.array([0.9])))
   grads = onp.array([4.])
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   expected_new_state = optim.OptimizerState(
       2, _AdamParamState(onp.array([3.22]), onp.array([2.41])))
   expected_new_params = onp.array([0.906085])
   onp.testing.assert_allclose(new_params, expected_new_params)
   self.assertEqual(new_state, expected_new_state)
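Worked check of the expected values using the standard Adam update with bias correction at t = step + 1 = 2 and eps added outside the square root (the placement consistent with 0.906085):

# Adam sanity check (beta1=0.2, beta2=0.9, eps=0.01, g=4, t=2).
m = 0.2 * 0.1 + 0.8 * 4.                          # 3.22, matches the state
v = 0.9 * 0.9 + 0.1 * 4.**2                       # 2.41, matches the state
m_hat = m / (1. - 0.2**2)                         # bias-corrected first moment
v_hat = v / (1. - 0.9**2)                         # bias-corrected second moment
p_new = 1. - 0.1 * m_hat / (v_hat**0.5 + 0.01)    # ~0.906085
assert abs(p_new - 0.906085) < 1e-6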
Example #11
 def test_init_state(self):
   params = onp.zeros((1,))
   optimizer_def = optim.GradientDescent(learning_rate=0.1)
   state = optimizer_def.init_state(params)
   expected_hyper_params = _GradientDescentHyperParams(0.1)
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
   expected_state = optim.OptimizerState(0, ())
   self.assertEqual(state, expected_state)
Example #12
 def test_create(self):
   params = onp.ones((1,))
   optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
   optimizer = optimizer_def.create(params)
   expected_state = optim.OptimizerState(
       0, _MomentumParamState(onp.zeros((1,))))
   self.assertEqual(optimizer.optimizer_def, optimizer_def)
   self.assertEqual(optimizer.state, expected_state)
   self.assertEqual(optimizer.target, params)
Example #13
 def test_init_state(self):
   params = onp.zeros((1,))
   optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
   state = optimizer_def.init_state(params)
   expected_hyper_params = _MomentumHyperParams(0.1, 0.2, 0, False)
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
   expected_state = optim.OptimizerState(
       0, _MomentumParamState(onp.zeros((1,))))
   self.assertEqual(state, expected_state)
Example #14
  def test_apply_gradient(self):
    optimizer_def = optim.Adafactor(learning_rate=0.1, decay_rate=0.8,
                                    min_dim_size_to_factor=0)
    params = onp.ones((3, 2), onp.float32)
    state = optim.OptimizerState(
        1, _AdafactorParamState(onp.array([0.9, 0.9]),
                                onp.array([0.1, 0.1, 0.1]),
                                onp.zeros((1,)),
                                onp.zeros((1,))))
    grads = onp.ones((3, 2), onp.float32)
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_state = optim.OptimizerState(
        2, _AdafactorParamState(
            onp.array([0.9574349, 0.9574349]),
            onp.array([0.6169143, 0.6169143, 0.6169143]),
            onp.zeros((1,)),
            onp.zeros((1,))))
    expected_new_params = 0.9 * onp.ones((3, 2))
    onp.testing.assert_allclose(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

    # unfactored w momentum
    optimizer_def = optim.Adafactor(learning_rate=0.1,
                                    beta1=0.0, decay_rate=0.8,
                                    min_dim_size_to_factor=32)
    params = onp.ones((3, 2), onp.float32)
    state = optim.OptimizerState(
        1, _AdafactorParamState(onp.zeros(1,),
                                onp.zeros(1,),
                                0.5*onp.ones((3, 2)),
                                onp.zeros((3, 2))))
    grads = onp.ones((3, 2), onp.float32)
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_params = 0.9 * onp.ones((3, 2))
    onp.testing.assert_allclose(new_params, expected_new_params)
    expected_new_state = optim.OptimizerState(
        2, _AdafactorParamState(
            onp.array([0.0]),
            onp.array([0.0]),
            0.787174 * onp.ones((3, 2)),
            0.1 * onp.ones((3,2))))
    check_eq(new_state, expected_new_state, rtol=1e-6)
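The second-moment values in both cases are consistent with a step-dependent decay of the form 1 - t**(-decay_rate) evaluated at t = 2 (reconstructed from the numbers, stated as an assumption rather than a description of the implementation):

# Adafactor decay check at t = 2 with decay_rate = 0.8 and grad**2 = 1.
beta_hat = 1. - 2.**(-0.8)                                        # ~0.425651
assert abs(beta_hat * 0.9 + (1. - beta_hat) - 0.9574349) < 1e-5   # factored, start 0.9
assert abs(beta_hat * 0.1 + (1. - beta_hat) - 0.6169143) < 1e-5   # factored, start 0.1
assert abs(beta_hat * 0.5 + (1. - beta_hat) - 0.787174) < 1e-5    # unfactored v, start 0.5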
Example #15
    def test_init_state(self):
        params = onp.zeros((1, ))
        optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
        state = optimizer_def.init_state(params)

        expected_hyper_params = _RMSPropHyperParams(0.1, 0.9, 0.01)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
        expected_state = optim.OptimizerState(
            0, _RMSPropParamState(onp.zeros((1, ))))
        self.assertEqual(state, expected_state)
Example #16
  def test_init_state(self):
    params = onp.zeros((1,))
    optimizer_def = optim.Adagrad(learning_rate=0.1, eps=0.01)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdagradHyperParams(0.1, 0.01)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = optim.OptimizerState(
        0, _AdagradParamState(onp.zeros((1,))))
    self.assertEqual(state, expected_state)
Example #17
 def test_apply_gradient_centered(self):
     optimizer_def = optim.RMSProp(learning_rate=0.1,
                                   beta2=0.9,
                                   eps=0.01,
                                   centered=True)
     params = np.array([1.])
     state = optim.OptimizerState(
         1, _RMSPropParamState(np.array([0.1]), np.array([0.1])))
     grads = np.array([4.])
     new_params, new_state = optimizer_def.apply_gradient(
         optimizer_def.hyper_params, params, state, grads)
     expected_new_state = optim.OptimizerState(
         2, _RMSPropParamState(np.array([1.69]), np.array([0.49])))
     expected_new_params = np.array([0.670543], dtype=np.float32)
     np.testing.assert_allclose(new_params, expected_new_params, rtol=1e-6)
     np.testing.assert_allclose(new_state.param_states.v,
                                expected_new_state.param_states.v)
     np.testing.assert_allclose(new_state.param_states.mg,
                                expected_new_state.param_states.mg)
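The centered variant subtracts the squared first-moment estimate inside the square root; with eps again outside the root this reproduces the expected value:

# Centered RMSProp sanity check: v = beta2 * v + (1 - beta2) * g**2,
# mg = beta2 * mg + (1 - beta2) * g, p_new = p - lr * g / (sqrt(v - mg**2) + eps).
v = 0.9 * 0.1 + 0.1 * 4.**2                          # 1.69, matches the state
mg = 0.9 * 0.1 + 0.1 * 4.                            # 0.49, matches the state
p_new = 1. - 0.1 * 4. / ((v - mg**2)**0.5 + 0.01)    # ~0.670543
assert abs(p_new - 0.670543) < 1e-6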
Example #18
def maybe_copy_model_from_pretraining(optimizer, pretrain_optimizer, step,
                                      adam_opt_def):
  """Copy model parameters from pretraining."""
  if step < FLAGS.num_pretrain_steps:
    optimizer = jax_utils.unreplicate(optimizer)
    state_dict = adam_opt_def.state_dict(
        target=jax_utils.unreplicate(pretrain_optimizer).target,
        state=optim.OptimizerState(jnp.asarray(step, dtype=jnp.int32),
                                   optimizer.state.param_states))
    optimizer = optimizer.restore_state(state_dict)
    optimizer = jax_utils.replicate(optimizer)
  return optimizer
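In short (a reading of the body above): while step is still below FLAGS.num_pretrain_steps, the helper swaps the pretraining optimizer's target into the given optimizer through adam_opt_def's state_dict/restore_state round trip, keeps the optimizer's existing param_states, pins the step counter to step as an int32, and re-replicates the result; afterwards the optimizer is returned unchanged.

# Rough shape of the intended call site (illustrative only; the surrounding
# training loop and names such as num_train_steps are not part of the source):
#   for step in range(num_train_steps):
#     optimizer = maybe_copy_model_from_pretraining(
#         optimizer, pretrain_optimizer, step, adam_opt_def)
#     ... run the usual training step with `optimizer` ...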
Example #19
    def test_init_state(self):
        params = np.zeros((1, ))
        optimizer_def = optim.Adadelta(learning_rate=0.1,
                                       rho=0.9,
                                       eps=1e-6,
                                       weight_decay=0.1)
        state = optimizer_def.init_state(params)

        expected_hyper_params = _AdadeltaHyperParams(0.1, 0.9, 1e-6, 0.1)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
        expected_state = optim.OptimizerState(
            0, _AdadeltaParamState(np.zeros((1, )), np.zeros((1, ))))
        self.assertEqual(state, expected_state)
Example #20
  def test_init_state(self):
    params = onp.zeros((1,))
    optimizer_def = optim.Adam(learning_rate=0.1,
                               beta1=0.2,
                               beta2=0.9,
                               eps=0.01,
                               weight_decay=0.0)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdamHyperParams(0.1, 0.2, 0.9, 0.01, 0.0)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = optim.OptimizerState(
        0, _AdamParamState(onp.zeros((1,)), onp.zeros((1,))))
    self.assertEqual(state, expected_state)
Example #21
 def test_momentum_with_weight_norm(self):
   params = onp.ones((2, 2)) * 2.
   optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
   state = optimizer_def.init_state(params)
   self.assertEqual(jax.tree_map(onp.shape, state), optim.OptimizerState(
       step=(),
       param_states=_WeightNormParamState(
           direction_state=_MomentumParamState(momentum=(2, 2)),
           scale_state=_MomentumParamState(momentum=(1, 2)),
           mult=(1, 2)
       )
   ))
   grads = onp.ones((2, 2))
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
   onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2 ** 0.5)
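A small consistency check on the asserted mult value, reading it as the per-column L2 norm of the updated 2x2 kernel (each column is [1.9, 1.9] after the step):

# Weight-norm scale check for the numbers above.
assert abs((2 * 1.9**2)**0.5 - 1.9 * 2**0.5) < 1e-12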
Example #22
    def test_init_state(self):
        # Create an optimizer def and check the params are wired through.
        optimizer_def = shampoo.Shampoo(
            learning_rate=0.1,
            beta1=0.9,
            beta2=0.9,
            diagonal_epsilon=0.0,
            matrix_epsilon=1e-1,
            exponent_override=2,
            weight_decay=1e-4,
            start_preconditioning_step=1,
            preconditioning_compute_steps=1,
            statistics_compute_steps=1,
            no_preconditioning_for_layers_with_dim_gt=8192,
            best_effort_shape_interpretation=True,
            block_size=8,
            graft_type=shampoo.LayerwiseGrafting.SGD,
            nesterov=False,
            batch_axis_name=None)
        expected_hyper_params = shampoo._ShampooHyperParams(
            learning_rate=0.1,
            beta1=0.9,
            beta2=0.9,
            diagonal_eps=0.0,
            matrix_eps=1e-1,
            exponent_override=2,
            weight_decay=1e-4,
            start_preconditioning_step=1,
            preconditioning_compute_steps=1,
            statistics_compute_steps=1,
            no_preconditioning_for_layers_with_dim_gt=8192,
            best_effort_shape_interpretation=True,
            block_size=8,
            graft_type=shampoo.LayerwiseGrafting.SGD,
            nesterov=False,
            batch_axis_name=None)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)

        params = np.zeros((1, ))
        state = optimizer_def.init_state(params)
        zeros_like_param = np.zeros((1, ))
        expected_state = optim.OptimizerState(
            0,
            shampoo._ShampooDefaultParamState(zeros_like_param, [], [],
                                              zeros_like_param,
                                              zeros_like_param))
        self.assertEqual(state, expected_state)

        params = np.zeros((8, ))
        state = optimizer_def.init_state(params)
        identity = np.eye(8)
        statistic = identity * 1e-1  #  I * matrix_epsilon
        preconditioner = identity
        self.assertLen(state.param_states.statistics, 1)
        self.assertLen(state.param_states.preconditioners, 1)
        np.testing.assert_allclose(state.param_states.statistics[0], statistic)
        np.testing.assert_allclose(state.param_states.preconditioners[0],
                                   preconditioner)

        params = np.zeros((8, 8))
        state = optimizer_def.init_state(params)
        identity = np.eye(8)
        statistic = identity * 1e-1  #  I * matrix_epsilon
        preconditioner = identity
        self.assertLen(state.param_states.statistics, 2)
        self.assertLen(state.param_states.preconditioners, 2)
        np.testing.assert_allclose(state.param_states.statistics[0], statistic)
        np.testing.assert_allclose(state.param_states.statistics[1], statistic)
        np.testing.assert_allclose(state.param_states.preconditioners[0],
                                   preconditioner)
        np.testing.assert_allclose(state.param_states.preconditioners[1],
                                   preconditioner)

        params = np.zeros((16, 16))
        state = optimizer_def.init_state(params)
        zeros_like_param = np.zeros((8, ))
        identity = np.eye(8)
        statistic = identity * 1e-1  #  I * matrix_epsilon
        preconditioner = identity
        self.assertLen(state.param_states.statistics, 8)
        self.assertLen(state.param_states.preconditioners, 8)
        for i in range(8):
            np.testing.assert_allclose(state.param_states.statistics[i],
                                       statistic)
            np.testing.assert_allclose(state.param_states.preconditioners[i],
                                       preconditioner)

        # Test best_effort_shape_interpretation
        # (3, 2, 16) will be reshaped to (6, 16)
        # Last dim will be split into two (6, 8) and (6, 8)
        params = np.zeros((3, 2, 16))
        state = optimizer_def.init_state(params)
        zeros_like_param = np.zeros((8, ))
        identity_left = np.eye(6)
        statistic_left = identity_left * 1e-1  #  I * matrix_epsilon
        preconditioner_left = identity_left
        identity_right = np.eye(8)
        statistic_right = identity_right * 1e-1  #  I * matrix_epsilon
        preconditioner_right = identity_right
        self.assertLen(state.param_states.statistics, 4)
        self.assertLen(state.param_states.preconditioners, 4)
        for i in range(4):
            if i % 2 == 0:
                np.testing.assert_allclose(state.param_states.statistics[i],
                                           statistic_left)
                np.testing.assert_allclose(
                    state.param_states.preconditioners[i], preconditioner_left)
            else:
                np.testing.assert_allclose(state.param_states.statistics[i],
                                           statistic_right)
                np.testing.assert_allclose(
                    state.param_states.preconditioners[i],
                    preconditioner_right)
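To summarize the block-count assertions above (a reading of the expected values with block_size=8 and best_effort_shape_interpretation=True, not a restatement of the shampoo source):

# Statistics/preconditioner counts checked in this test:
#   (1,)       -> no preconditioning: empty statistics and preconditioners
#   (8,)       -> one block, one axis                    -> 1 statistic
#   (8, 8)     -> one 8x8 block, two axes                -> 2 statistics
#   (16, 16)   -> four 8x8 blocks, two axes each         -> 8 statistics
#   (3, 2, 16) -> reshaped to (6, 16), split along the last axis into two
#                 (6, 8) blocks, two axes each           -> 4 statistics
#                 (even indices 6x6, odd indices 8x8)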