Example #1
# Assumed imports and constants for this snippet: `alias` and `update` are
# optax's internal modules, `optim` is the legacy flax.optim API, and LR/STEPS
# are module-level constants (the values below are placeholders).
from absl.testing import parameterized
import chex
from flax import optim
import jax.numpy as jnp
from optax._src import alias
from optax._src import update

LR = 1e-2
STEPS = 50


class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
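
For reference, the same update loop written against optax's public API. This is a minimal standalone sketch; the learning rate, momentum, and step count are placeholder values, and optax.sgd / optax.apply_updates are the public counterparts of alias.sgd and update.apply_updates used above:

import jax.numpy as jnp
import optax

params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))

opt = optax.sgd(learning_rate=1e-2, momentum=0.9)
state = opt.init(params)
for _ in range(10):
  # Transform the raw gradients, then apply the resulting updates to the params.
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)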
Example #2
# Assumes the same imports and LR/STEPS constants as Example #1.
class FlaxOptimizersEquivalenceTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (
            jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
        ('momentum', alias.sgd(LR, momentum=0.9),
         optim.Momentum(LR, beta=0.9)),  # Different names.
        ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
         optim.Momentum(LR, beta=0.9, nesterov=True)),
        ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
        ('centered_rmsprop', alias.rmsprop(LR, centered=True),
         optim.RMSProp(LR, centered=True)),
        ('adam', alias.adam(LR), optim.Adam(LR)),
        ('adam_w', alias.adamw(LR, weight_decay=1e-4),
         optim.Adam(LR, weight_decay=1e-4)),  # Different name.
        ('adagrad',
         alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
         optim.Adagrad(LR)),
        ('lamb', alias.lamb(LR), optim.LAMB(LR)),
        ('lars',
         alias.lars(LR,
                    weight_decay=.5,
                    trust_coefficient=0.003,
                    momentum=0.9,
                    eps=1e-3),
         optim.LARS(
             LR, weight_decay=.5, trust_coefficient=0.003, beta=0.9,
             eps=1e-3)),
        ('adafactor',
         alias.adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2),
         optim.Adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2)),
    )
    def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

        # flax/optim
        flax_params = self.init_params
        flax_optimizer = flax_optimizer.create(flax_params)
        for _ in range(STEPS):
            flax_optimizer = flax_optimizer.apply_gradient(
                self.per_step_updates)
            flax_params = flax_optimizer.target

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)
        for _ in range(STEPS):
            updates, state = optax_optimizer.update(self.per_step_updates,
                                                    state, optax_params)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
Example #3
# Assumption: `optimizers` refers to the legacy flax.optim module
# (GradientDescent, Momentum, LAMB, LARS, MultiOptimizer, etc.).
from flax import optim as optimizers


def get_optimizer(hps):
    """Constructs the optimizer from the given HParams."""
    if 'weight_decay' in hps.opt_hparams:
        weight_decay = hps.opt_hparams['weight_decay']
    else:
        weight_decay = 0

    if hps.optimizer == 'sgd':
        return optimizers.GradientDescent(learning_rate=None)
    elif hps.optimizer == 'nesterov':
        return optimizers.Momentum(learning_rate=None,
                                   beta=hps.opt_hparams['momentum'],
                                   nesterov=True,
                                   weight_decay=weight_decay)
    elif hps.optimizer == 'momentum':
        return optimizers.Momentum(learning_rate=None,
                                   beta=hps.opt_hparams['momentum'],
                                   nesterov=False,
                                   weight_decay=weight_decay)
    elif hps.optimizer == 'lamb':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        return optimizers.LAMB(learning_rate=None,
                               beta1=hps.opt_hparams['beta1'],
                               beta2=hps.opt_hparams['beta2'],
                               eps=hps.opt_hparams['epsilon'],
                               weight_decay=weight_decay)
    elif hps.optimizer == 'adam':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        return optimizers.Adam(learning_rate=None,
                               beta1=hps.opt_hparams['beta1'],
                               beta2=hps.opt_hparams['beta2'],
                               eps=hps.opt_hparams['epsilon'],
                               weight_decay=weight_decay)
    elif hps.optimizer == 'lars':
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        return optimizers.LARS(learning_rate=None,
                               beta=hps.opt_hparams['beta'],
                               weight_decay=weight_decay)
    elif hps.optimizer == 'mlperf_lars_resnet':
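        # Split the parameters: LARS for weight/kernel params, plain Momentum
        # for bias and scale params.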
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        weight_opt_def = optimizers.LARS(learning_rate=None,
                                         beta=hps.opt_hparams['beta'],
                                         weight_decay=weight_decay)
        other_opt_def = optimizers.Momentum(learning_rate=None,
                                            beta=hps.opt_hparams['beta'],
                                            weight_decay=0,
                                            nesterov=False)

        def filter_weights(key, _):
            return 'bias' not in key and 'scale' not in key

        def filter_other(key, _):
            return 'bias' in key or 'scale' in key

        weight_traversal = optimizers.ModelParamTraversal(filter_weights)
        other_traversal = optimizers.ModelParamTraversal(filter_other)
        return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    elif hps.optimizer == 'mlperf_lamb':
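        # Split the parameters: LAMB for weight/kernel params, Adam for bias
        # and scale params, each with its own weight decay from opt_hparams.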
        assert hps.l2_decay_factor is None or weight_decay == 0.0
        weight_opt_def = optimizers.LAMB(
            learning_rate=None,
            beta1=hps.opt_hparams['beta1'],
            beta2=hps.opt_hparams['beta2'],
            eps=hps.opt_hparams['epsilon'],
            weight_decay=hps.opt_hparams['lamb_weight_decay'])
        other_opt_def = optimizers.Adam(
            learning_rate=None,
            beta1=hps.opt_hparams['beta1'],
            beta2=hps.opt_hparams['beta2'],
            eps=hps.opt_hparams['epsilon'],
            weight_decay=hps.opt_hparams['adam_weight_decay'])

        def filter_weights(key, _):
            return 'bias' not in key and 'scale' not in key

        def filter_other(key, _):
            return 'bias' in key or 'scale' in key

        weight_traversal = optimizers.ModelParamTraversal(filter_weights)
        other_traversal = optimizers.ModelParamTraversal(filter_other)
        return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    else:
        raise NotImplementedError('Optimizer {} not implemented'.format(
            hps.optimizer))
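
A rough usage sketch, assuming `hps` is any object with attribute access whose fields mirror what get_optimizer reads; the concrete values below are hypothetical:

import types

hps = types.SimpleNamespace(
    optimizer='nesterov',
    l2_decay_factor=None,
    opt_hparams={'momentum': 0.9, 'weight_decay': 1e-4},
)
opt_def = get_optimizer(hps)  # a flax.optim OptimizerDef (here: Momentum)
# optimizer = opt_def.create(model_params)  # bind it to a parameter pytree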