Example #1
 def test_multi_optimizer(self):
   params = {'a': 0., 'b': 0.}
   opt_a = optim.GradientDescent(learning_rate=1.)
   opt_b = optim.GradientDescent(learning_rate=10.)
   t_a = traverse_util.t_identity['a']
   t_b = traverse_util.t_identity['b']
   optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
   state = optimizer_def.init_state(params)
   expected_hyper_params = [
       _GradientDescentHyperParams(1.),
       _GradientDescentHyperParams(10.)
   ]
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
   expected_state = [optim.OptimizerState(0, [()])] * 2
   self.assertEqual(state, expected_state)
   grads = {'a': -1., 'b': -2.}
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   expected_params = {'a': 1., 'b': 20.}
   expected_state = [optim.OptimizerState(1, [()])] * 2
   self.assertEqual(new_state, expected_state)
   self.assertEqual(new_params, expected_params)
   # override learning_rate
   hp = optimizer_def.update_hyper_params(learning_rate=2.)
   new_params, new_state = optimizer_def.apply_gradient(
       hp, params, state, grads)
   expected_params = {'a': 2., 'b': 4.}
   self.assertEqual(new_params, expected_params)
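A minimal sketch of the higher-level workflow this test exercises, assuming the legacy flax.optim API used above (MultiOptimizer.create wraps the params, and apply_gradient returns a new optimizer whose target holds the updated params); the values mirror the test:

params = {'a': 0., 'b': 0.}
optimizer_def = optim.MultiOptimizer(
    (traverse_util.t_identity['a'], optim.GradientDescent(learning_rate=1.)),
    (traverse_util.t_identity['b'], optim.GradientDescent(learning_rate=10.)))
optimizer = optimizer_def.create(params)                 # Optimizer wrapping params + state
optimizer = optimizer.apply_gradient({'a': -1., 'b': -2.})
assert optimizer.target == {'a': 1., 'b': 20.}           # same values the test asserts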
Example #2
  def test_grad_var(self):
    model_size = 10
    example_grads = [{
        'layer1': np.ones(model_size),
        'layer2': 3 * np.ones(model_size)
    }, {
        'layer1': 2 * np.ones(model_size),
        'layer2': np.ones(model_size)
    }]
    eval_config = {'ema_beta': 0.5}
    training_metrics_grabber = utils.TrainingMetricsGrabber.create(
        example_grads[0], eval_config)

     # For the purposes of this test, we create fake optimizers to satisfy
     # the metrics grabber API.
    fake_model = nn.Model(None, example_grads[0])
    new_optimizer = optimizers.GradientDescent(
        learning_rate=None).create(fake_model)
    old_optimizer = optimizers.GradientDescent(
        learning_rate=None).create(fake_model)

    for grad in example_grads:
      training_metrics_grabber = training_metrics_grabber.update(
          grad, old_optimizer, new_optimizer)

    for layer in ['layer1', 'layer2']:
      expected_grad_ema = 1 / 4 * np.zeros(model_size) + 1 / 4 * example_grads[
          0][layer] + 1 / 2 * example_grads[1][layer]

      self.assertArraysAllClose(expected_grad_ema,
                                training_metrics_grabber.state[layer].grad_ema)
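As a standalone sanity check of the arithmetic asserted above (plain NumPy, no TrainingMetricsGrabber involved), the expected value follows from the standard EMA recursion ema = beta * ema + (1 - beta) * grad with ema_beta = 0.5:

import numpy as np

ema_beta = 0.5
g0, g1 = np.ones(10), 2 * np.ones(10)   # the 'layer1' gradients from the test
ema = np.zeros(10)
for g in (g0, g1):
  ema = ema_beta * ema + (1 - ema_beta) * g
np.testing.assert_allclose(ema, 0.25 * g0 + 0.5 * g1)  # 1/4 * g0 + 1/2 * g1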
Example #3
 def test_multiple_hparams(self):
   opt_dict = {
       'hp-beta1':
           optim.GradientDescent(learning_rate=1
                                ).create(np.array([0.0, 0.1, 0.2, 0.9])),
       'hp-learning_rate':
           optim.GradientDescent(learning_rate=1
                                ).create(np.array([0.0, 0.1, 0.2, 0.9])),
   }
   gv_dict = {
       'beta1': {
           'learning_rate_scalar': 1,
           'activation_fn': 'linear',
           'activation_ceiling': None,
           'activation_floor': None,
           'clip_min': None,
           'clip_max': None,
       },
       'learning_rate': {
           'learning_rate_scalar': 1,
           'activation_fn': 'linear',
           'activation_ceiling': None,
           'activation_floor': None,
           'clip_min': None,
           'clip_max': None,
       }
   }
   out = guided_parameters.get_activated_hparams(opt_dict, gv_dict)
   self.assertListEqual(list(out['hp-beta1']), [0.0, 0.1, 0.2, 0.9])
   self.assertListEqual(list(out['hp-learning_rate']), [0.0, 0.1, 0.2, 0.9])
Example #4
 def test_multi_guided_parameter(self):
   opt_dict = {
       'hp-beta1':
           optim.GradientDescent(learning_rate=1
                                ).create(np.array([0.0, 0.1, 0.2, 0.9])),
       'dp-ex_index':
           optim.GradientDescent(learning_rate=1
                                ).create(np.array([1, 2, 3, 4, 5])),
   }
   gv_dict = {
       'beta1': {
           'learning_rate_scalar': 1,
           'activation_fn': 'linear',
           'activation_ceiling': None,
           'activation_floor': None,
       },
       'ex_index': {
           'learning_rate_scalar': 1,
           'activation_fn': 'linear',
           'activation_ceiling': None,
           'activation_floor': None,
       }
   }
   raw_vars_dict, act_fn_dict = guided_parameters.get_raw_vars_and_act_fns(
       opt_dict, gv_dict)
   self.assertListEqual(list(raw_vars_dict['hp-beta1']), [0.0, 0.1, 0.2, 0.9])
   self.assertListEqual(list(raw_vars_dict['dp-ex_index']), [1, 2, 3, 4, 5])
   self.assertCountEqual(list(act_fn_dict), ['hp-beta1', 'dp-ex_index'])
Example #5
 def test_multi_optimizer_multiple_matches(self):
     params = {'a': {'x': 0., 'y': 0.}, 'b': {'y': 0, 'z': 0.}}
     opt_a = optim.GradientDescent(learning_rate=1.)
     opt_b = optim.GradientDescent(learning_rate=10.)
     t_a = optim.ModelParamTraversal(
         lambda path, _: path.endswith('/x') or path.endswith('/y'))
     t_b = optim.ModelParamTraversal(lambda path, value: value.dtype == jnp.
                                     int32 or path.endswith('/z'))
     optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
     with self.assertRaisesRegex(
             ValueError, r"Multiple optimizers match.*'y': \[0, 1\]"):
         jax.jit(optimizer_def.init_state)(params)
Example #6
 def test_init_state(self):
   params = onp.zeros((1,))
   optimizer_def = optim.GradientDescent(learning_rate=0.1)
   state = optimizer_def.init_state(params)
   expected_hyper_params = _GradientDescentHyperParams(0.1)
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
   expected_state = optim.OptimizerState(0, ())
   self.assertEqual(state, expected_state)
Example #7
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
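For reference, the optax side of this equivalence can also be written against the public optax package (optax.sgd and optax.apply_updates) instead of the internal alias/update modules imported by the test; the learning rate and step count below are placeholders, not the LR and STEPS constants of the original test module:

import jax.numpy as jnp
import optax

params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))

tx = optax.sgd(learning_rate=0.01)     # placeholder learning rate
state = tx.init(params)
for _ in range(3):                     # placeholder step count
  updates, state = tx.update(grads, state, params)
  params = optax.apply_updates(params, updates)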
Example #8
 def test_apply_gradient(self):
   optimizer_def = optim.GradientDescent(learning_rate=0.1)
   params = onp.ones((1,))
   state = optim.OptimizerState(0, ())
   grads = onp.array([3.])
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   expected_new_state = optim.OptimizerState(1, ())
   expected_new_params = onp.array([0.7])
   self.assertEqual(new_params, expected_new_params)
   self.assertEqual(new_state, expected_new_state)
Example #9
 def test_ignores_model(self):
   opt_dict = {
       'model':
           optim.GradientDescent(learning_rate=1
                                ).create(np.array([0.0, 0.1, 0.2, 0.9])),
   }
   gv_dict = {}
   raw_vars_dict, act_fn_dict = guided_parameters.get_raw_vars_and_act_fns(
       opt_dict, gv_dict)
   self.assertEqual(raw_vars_dict, {})
   self.assertEqual(act_fn_dict, {})
Example #10
 def test_optimizer_with_focus(self):
   params = {'a': 0., 'b': 0.}
   opt_def = optim.GradientDescent(learning_rate=1.)
   t_a = traverse_util.t_identity['a']
   optimizer = opt_def.create(params, focus=t_a)
   expected_state = [optim.OptimizerState(0, [()])]
   self.assertEqual(optimizer.state, expected_state)
   grads = {'a': -1., 'b': -2.}
   new_optimizer = optimizer.apply_gradient(grads)
   expected_params = {'a': 1., 'b': 0.}
   expected_state = [optim.OptimizerState(1, [()])]
   self.assertEqual(new_optimizer.state, expected_state)
   self.assertEqual(new_optimizer.target, expected_params)
Example #11
 def test_ignores_model(self):
   opt_dict = {
       'model':
           optim.GradientDescent(learning_rate=1
                                ).create(np.array([0.0, 0.1, 0.2, 0.9])),
   }
   gv_dict = {}
   out = guided_parameters.get_activated_hparams(opt_dict, gv_dict)
   self.assertEqual(
       out, {
           'hp-beta1': None,
           'hp-decay_rate': None,
           'hp-eps': None,
           'hp-learning_rate': None,
           'hp-weight_decay': None,
           'hp-label_smoothing': None,
       })
Example #12
def get_optimizer(hparams):
    """Constructs  the optimizer from the given HParams.

  Args:
    hparams: Hyper parameters.

  Returns:
    A flax optimizer.
  """
    if hparams.optimizer == 'sgd':
        return optimizers.GradientDescent(
            learning_rate=hparams.lr_hparams['initial_learning_rate'])
    if hparams.optimizer == 'nesterov':
        return optimizers.Momentum(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta=hparams.opt_hparams.get('momentum', 0.9),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
            nesterov=True)
    if hparams.optimizer == 'momentum':
        return optimizers.Momentum(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta=hparams.opt_hparams.get('momentum', 0.9),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
            nesterov=False)
    if hparams.optimizer == 'adam':
        return optimizers.Adam(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta1=hparams.opt_hparams.get('beta1', 0.9),
            beta2=hparams.opt_hparams.get('beta2', 0.999),
            eps=hparams.opt_hparams.get('epsilon', 1e-8),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        )
    if hparams.optimizer == 'rmsprop':
        return optimizers.RMSProp(
            learning_rate=hparams.lr_hparams.get('initial_learning_rate'),
            beta2=hparams.opt_hparams.get('beta2', 0.9),
            eps=hparams.opt_hparams.get('epsilon', 1e-8))
    else:
        raise NotImplementedError('Optimizer {} not implemented'.format(
            hparams.optimizer))
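A hedged usage sketch: the hparams object below is a hypothetical stand-in (a plain namespace) exposing only the attributes get_optimizer reads; the real configuration class from the source project is not shown in this example.

from types import SimpleNamespace

hparams = SimpleNamespace(
    optimizer='adam',
    lr_hparams={'initial_learning_rate': 1e-3},
    opt_hparams={'beta1': 0.9, 'beta2': 0.999, 'epsilon': 1e-8,
                 'weight_decay': 0.0},
)
optimizer_def = get_optimizer(hparams)   # -> optimizers.Adam(learning_rate=1e-3, ...)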
Example #13
def create_optimizer(model_params, learning_rate: float):
    optimizer_def = optim.GradientDescent(learning_rate)
    model_optimizer = optimizer_def.create(model_params)
    return model_optimizer
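A minimal training-step sketch built on create_optimizer, assuming the legacy flax.optim Optimizer interface (.target holds the parameters, .apply_gradient returns a new Optimizer); the loss function and batch below are illustrative placeholders.

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    preds = batch['x'] @ params
    return jnp.mean((preds - batch['y']) ** 2)

optimizer = create_optimizer(jnp.zeros((3,)), learning_rate=0.1)
batch = {'x': jnp.ones((4, 3)), 'y': jnp.ones((4,))}
grads = jax.grad(loss_fn)(optimizer.target, batch)
optimizer = optimizer.apply_gradient(grads)   # new Optimizer with updated .target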
Example #14
class FlaxOptimizersEquivalenceTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([0., 0.3, 500.,
                                            5.]), jnp.array([300., 3.]))

    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
        ('momentum', alias.sgd(LR, momentum=0.9), optim.Momentum(
            LR, beta=0.9)),  # Different names.
        ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
         optim.Momentum(LR, beta=0.9, nesterov=True)),
        ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
        ('centered_rmsprop', alias.rmsprop(
            LR, centered=True), optim.RMSProp(LR, centered=True)),
        ('adam', alias.adam(LR), optim.Adam(LR)),
        ('adam_w', alias.adamw(LR, weight_decay=1e-4),
         optim.Adam(LR, weight_decay=1e-4)),  # Different name.
        (
            'adagrad',
            alias.adagrad(LR,
                          initial_accumulator_value=0.),  # Different default!
            optim.Adagrad(LR)),
        ('lamb', alias.lamb(LR), optim.LAMB(LR)),
        ('lars',
         alias.lars(LR,
                    weight_decay=.5,
                    trust_coefficient=0.003,
                    momentum=0.9,
                    eps=1e-3),
         optim.LARS(
             LR, weight_decay=.5, trust_coefficient=0.003, beta=0.9,
             eps=1e-3)),
        ('adafactor',
         alias.adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2),
         optim.Adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2)),
    )
    def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

        # flax/optim
        flax_params = self.init_params
        flax_optimizer = flax_optimizer.create(flax_params)
        for _ in range(STEPS):
            flax_optimizer = flax_optimizer.apply_gradient(
                self.per_step_updates)
            flax_params = flax_optimizer.target

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)
        for _ in range(STEPS):
            updates, state = optax_optimizer.update(self.per_step_updates,
                                                    state, optax_params)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
Example #15
def get_optimizer(hps):
    """Constructs the optimizer from the given HParams."""
    if 'weight_decay' in hps.opt_hparams:
        weight_decay = hps.opt_hparams['weight_decay']
    else:
        weight_decay = 0

    if hps.optimizer == 'sgd':
        return optimizers.GradientDescent(learning_rate=None)
    elif hps.optimizer == 'nesterov':
        return optimizers.Momentum(learning_rate=None,
                                   beta=hps.opt_hparams['momentum'],
                                   nesterov=True,
                                   weight_decay=weight_decay)
    elif hps.optimizer == 'momentum':
        return optimizers.Momentum(learning_rate=None,
                                   beta=hps.opt_hparams['momentum'],
                                   nesterov=False,
                                   weight_decay=weight_decay)
    elif hps.optimizer == 'lamb':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        return optimizers.LAMB(learning_rate=None,
                               beta1=hps.opt_hparams['beta1'],
                               beta2=hps.opt_hparams['beta2'],
                               eps=hps.opt_hparams['epsilon'],
                               weight_decay=weight_decay)
    elif hps.optimizer == 'adam':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        return optimizers.Adam(learning_rate=None,
                               beta1=hps.opt_hparams['beta1'],
                               beta2=hps.opt_hparams['beta2'],
                               eps=hps.opt_hparams['epsilon'],
                               weight_decay=weight_decay)
    elif hps.optimizer == 'lars':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        return optimizers.LARS(learning_rate=None,
                               beta=hps.opt_hparams['beta'],
                               weight_decay=weight_decay)
    elif hps.optimizer == 'mlperf_lars_resnet':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        weight_opt_def = optimizers.LARS(learning_rate=None,
                                         beta=hps.opt_hparams['beta'],
                                         weight_decay=weight_decay)
        other_opt_def = optimizers.Momentum(learning_rate=None,
                                            beta=hps.opt_hparams['beta'],
                                            weight_decay=0,
                                            nesterov=False)

        def filter_weights(key, _):
            return 'bias' not in key and 'scale' not in key

        def filter_other(key, _):
            return 'bias' in key or 'scale' in key

        weight_traversal = optimizers.ModelParamTraversal(filter_weights)
        other_traversal = optimizers.ModelParamTraversal(filter_other)
        return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    elif hps.optimizer == 'mlperf_lamb':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        weight_opt_def = optimizers.LAMB(
            learning_rate=None,
            beta1=hps.opt_hparams['beta1'],
            beta2=hps.opt_hparams['beta2'],
            eps=hps.opt_hparams['epsilon'],
            weight_decay=hps.opt_hparams['lamb_weight_decay'])
        other_opt_def = optimizers.Adam(
            learning_rate=None,
            beta1=hps.opt_hparams['beta1'],
            beta2=hps.opt_hparams['beta2'],
            eps=hps.opt_hparams['epsilon'],
            weight_decay=hps.opt_hparams['adam_weight_decay'])

        def filter_weights(key, _):
            return 'bias' not in key and 'scale' not in key

        def filter_other(key, _):
            return 'bias' in key or 'scale' in key

        weight_traversal = optimizers.ModelParamTraversal(filter_weights)
        other_traversal = optimizers.ModelParamTraversal(filter_other)
        return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    else:
        raise NotImplementedError('Optimizer {} not implemented'.format(
            hps.optimizer))
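A small standalone illustration of how the filter_weights / filter_other predicates above partition parameters by path (the paths are illustrative; in the real code they come from optimizers.ModelParamTraversal): kernels go to the LARS/LAMB branch, while bias and scale parameters go to the other branch.

def filter_weights(key, _):
    return 'bias' not in key and 'scale' not in key

def filter_other(key, _):
    return 'bias' in key or 'scale' in key

example_paths = ['/conv1/kernel', '/conv1/bias', '/bn1/scale', '/dense/kernel']
weight_paths = [p for p in example_paths if filter_weights(p, None)]
other_paths = [p for p in example_paths if filter_other(p, None)]
# weight_paths -> ['/conv1/kernel', '/dense/kernel']
# other_paths  -> ['/conv1/bias', '/bn1/scale']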