Example #1
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
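
The test above relies on a few names defined elsewhere in the original module: the learning rate `LR`, the step count `STEPS`, and the optax-internal submodules `alias` (which hosts `sgd`, `adam`, `rmsprop`, ...) and `update` (which hosts `apply_updates`). A minimal, hypothetical preamble that would make the snippet self-contained could look like this; the constant values are placeholders, not the originals:

# Hypothetical preamble; LR/STEPS values are placeholders, not the originals.
import chex
import jax.numpy as jnp
from absl.testing import parameterized
from flax import optim            # legacy flax.optim API
from optax._src import alias      # optax sgd, adam, rmsprop, ... live here
from optax._src import update     # provides apply_updates

LR = 1e-2    # placeholder learning rate
STEPS = 5    # placeholder number of update steps
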
Example #2
def create_optimizer(name='adam',
                     learning_rate=6.25e-5,
                     beta1=0.9,
                     beta2=0.999,
                     eps=1.5e-4):
    """Create an optimizer for training.

  Currently, only the Adam optimizer is supported.

  Args:
    name: str, name of the optimizer to create.
    learning_rate: float, learning rate to use in the optimizer.
    beta1: float, beta1 parameter for the optimizer.
    beta2: float, beta2 parameter for the optimizer.
    eps: float, epsilon parameter for the optimizer.

  Returns:
    A flax optimizer.
  """
    if name == 'adam':
        logging.info(
            'Creating Adam optimizer with settings lr=%f, beta1=%f, '
            'beta2=%f, eps=%f', learning_rate, beta1, beta2, eps)
        return optim.Adam(learning_rate=learning_rate,
                          beta1=beta1,
                          beta2=beta2,
                          eps=eps)
    elif name == 'rmsprop':
        logging.info(
            'Creating RMSProp optimizer with settings lr=%f, beta2=%f, '
            'eps=%f', learning_rate, beta2, eps)
        return optim.RMSProp(learning_rate=learning_rate, beta2=beta2, eps=eps)
    else:
        raise ValueError('Unsupported optimizer {}'.format(name))
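
A minimal usage sketch, assuming the imports the function depends on (`from absl import logging`, `from flax import optim`) are in scope and using a toy parameter pytree:

import jax.numpy as jnp

params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}        # toy parameter pytree
optimizer_def = create_optimizer(name='rmsprop', learning_rate=1e-3)
optimizer = optimizer_def.create(params)                   # binds params to the optimizer
# optimizer.apply_gradient(grads) then returns an updated Optimizer.
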
Example #3
    def test_init_state(self):
        params = onp.zeros((1, ))
        optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
        state = optimizer_def.init_state(params)

        expected_hyper_params = _RMSPropHyperParams(0.1, 0.9, 0.01)
        self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
        expected_state = optim.OptimizerState(
            0, _RMSPropParamState(onp.zeros((1, ))))
        self.assertEqual(state, expected_state)
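
In training code one would rarely call `init_state` directly; the legacy `OptimizerDef.create` method wraps it and returns an `Optimizer` bundling the target parameters with this state (as the equivalence tests in the other examples do). A small sketch, assuming the same `optim.RMSProp` definition and legacy flax API as above:

optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
optimizer = optimizer_def.create(onp.zeros((1,)))   # Optimizer(target, state)
assert optimizer.state.step == 0                    # matches expected_state above
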
Example #4
    def test_apply_gradient(self):
        optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
        params = onp.array([1.])
        state = optim.OptimizerState(1, _RMSPropParamState(onp.array([0.1])))
        grads = onp.array([4.])
        new_params, new_state = optimizer_def.apply_gradient(
            optimizer_def.hyper_params, params, state, grads)
        expected_new_state = optim.OptimizerState(
            2, _RMSPropParamState(onp.array([1.69])))
        expected_new_params = onp.array([0.6946565])
        onp.testing.assert_allclose(new_params, expected_new_params)
        self.assertEqual(new_state, expected_new_state)
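
The expected numbers in this test follow from the standard (uncentered) RMSProp update with `eps` added outside the square root; a quick hand check using only the values from the test:

import math

lr, beta2, eps = 0.1, 0.9, 0.01
v, g, p = 0.1, 4.0, 1.0                           # accumulator, gradient, param

v_new = beta2 * v + (1 - beta2) * g ** 2          # 0.9*0.1 + 0.1*16 = 1.69
p_new = p - lr * g / (math.sqrt(v_new) + eps)     # 1 - 0.4/1.31 ≈ 0.6946565
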
Example #5
    def test_apply_gradient_centered(self):
        optimizer_def = optim.RMSProp(learning_rate=0.1,
                                      beta2=0.9,
                                      eps=0.01,
                                      centered=True)
        params = np.array([1.])
        state = optim.OptimizerState(
            1, _RMSPropParamState(np.array([0.1]), np.array([0.1])))
        grads = np.array([4.])
        new_params, new_state = optimizer_def.apply_gradient(
            optimizer_def.hyper_params, params, state, grads)
        expected_new_state = optim.OptimizerState(
            2, _RMSPropParamState(np.array([1.69]), np.array([0.49])))
        expected_new_params = np.array([0.670543], dtype=np.float32)
        np.testing.assert_allclose(new_params, expected_new_params, rtol=1e-6)
        np.testing.assert_allclose(new_state.param_states.v,
                                   expected_new_state.param_states.v)
        np.testing.assert_allclose(new_state.param_states.mg,
                                   expected_new_state.param_states.mg)
Example #6
def get_optimizer(hparams):
    """Constructs the optimizer from the given HParams.

    Args:
      hparams: Hyperparameters.

    Returns:
      A flax optimizer.
    """
    if hparams.optimizer == 'sgd':
        return optimizers.GradientDescent(
            learning_rate=hparams.lr_hparams['initial_learning_rate'])
    if hparams.optimizer == 'nesterov':
        return optimizers.Momentum(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta=hparams.opt_hparams.get('momentum', 0.9),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
            nesterov=True)
    if hparams.optimizer == 'momentum':
        return optimizers.Momentum(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta=hparams.opt_hparams.get('momentum', 0.9),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
            nesterov=False)
    if hparams.optimizer == 'adam':
        return optimizers.Adam(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta1=hparams.opt_hparams.get('beta1', 0.9),
            beta2=hparams.opt_hparams.get('beta2', 0.999),
            eps=hparams.opt_hparams.get('epsilon', 1e-8),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        )
    if hparams.optimizer == 'rmsprop':
        return optimizers.RMSProp(
            learning_rate=hparams.lr_hparams.get('initial_learning_rate'),
            beta2=hparams.opt_hparams.get('beta2', 0.9),
            eps=hparams.opt_hparams.get('epsilon', 1e-8))
    else:
        raise NotImplementedError('Optimizer {} not implemented'.format(
            hparams.optimizer))
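
`get_optimizer` only reads three fields from `hparams`: `optimizer`, `lr_hparams`, and `opt_hparams` (with `optimizers` assumed to be the legacy `flax.optim` module imported under that name). A sketch with a hypothetical stand-in for the project's HParams object:

from types import SimpleNamespace

hparams = SimpleNamespace(
    optimizer='adam',
    lr_hparams={'initial_learning_rate': 1e-3},
    opt_hparams={'beta1': 0.9, 'beta2': 0.999,
                 'epsilon': 1e-8, 'weight_decay': 1e-4},
)
optimizer_def = get_optimizer(hparams)   # -> flax optim.Adam definition
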
Example #7
class FlaxOptimizersEquivalenceTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([0., 0.3, 500.,
                                            5.]), jnp.array([300., 3.]))

    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
        ('momentum', alias.sgd(LR, momentum=0.9), optim.Momentum(
            LR, beta=0.9)),  # Different names.
        ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
         optim.Momentum(LR, beta=0.9, nesterov=True)),
        ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
        ('centered_rmsprop', alias.rmsprop(
            LR, centered=True), optim.RMSProp(LR, centered=True)),
        ('adam', alias.adam(LR), optim.Adam(LR)),
        ('adam_w', alias.adamw(LR, weight_decay=1e-4),
         optim.Adam(LR, weight_decay=1e-4)),  # Different name.
        (
            'adagrad',
            alias.adagrad(LR,
                          initial_accumulator_value=0.),  # Different default!
            optim.Adagrad(LR)),
        ('lamb', alias.lamb(LR), optim.LAMB(LR)),
        ('lars',
         alias.lars(LR,
                    weight_decay=.5,
                    trust_coefficient=0.003,
                    momentum=0.9,
                    eps=1e-3),
         optim.LARS(
             LR, weight_decay=.5, trust_coefficient=0.003, beta=0.9,
             eps=1e-3)),
        ('adafactor',
         alias.adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2),
         optim.Adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2)),
    )
    def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

        # flax/optim
        flax_params = self.init_params
        flax_optimizer = flax_optimizer.create(flax_params)
        for _ in range(STEPS):
            flax_optimizer = flax_optimizer.apply_gradient(
                self.per_step_updates)
            flax_params = flax_optimizer.target

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)
        for _ in range(STEPS):
            updates, state = optax_optimizer.update(self.per_step_updates,
                                                    state, optax_params)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
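
Outside of these equivalence tests, the optax-side loop is normally written against the public optax API rather than the internal `alias`/`update` submodules; a short sketch with placeholder values for `LR` and `STEPS`:

import jax.numpy as jnp
import optax

LR, STEPS = 1e-2, 5                       # placeholders
params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
grads = (jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

tx = optax.adam(LR)
state = tx.init(params)
for _ in range(STEPS):
    updates, state = tx.update(grads, state, params)
    params = optax.apply_updates(params, updates)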