class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (
        jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (
        jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
      ('lars',
       alias.lars(LR, weight_decay=.5, trust_coefficient=0.003,
                  momentum=0.9, eps=1e-3),
       optim.LARS(LR, weight_decay=.5, trust_coefficient=0.003,
                  beta=0.9, eps=1e-3)),
      ('adafactor',
       alias.adafactor(learning_rate=LR / 10.,
                       factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0,
                       decay_rate=0.8,
                       min_dim_size_to_factor=2),
       optim.Adafactor(learning_rate=LR / 10.,
                       factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0,
                       decay_rate=0.8,
                       min_dim_size_to_factor=2)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
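# A minimal, self-contained sketch of the same equivalence check outside the
# test harness above, assuming the public optax API and a flax version that
# still ships flax.optim. The learning rate, step count, and parameter/gradient
# values are illustrative only and are not taken from the tests above.
import jax.numpy as jnp
import optax
from flax import optim

params = {'w': jnp.array([1., 2.]), 'b': jnp.array([3., 4.])}
grads = {'w': jnp.array([500., 5.]), 'b': jnp.array([300., 3.])}

# flax.optim: the optimizer definition is bound to the params via `create`,
# and `apply_gradient` returns a new optimizer carrying the updated target.
flax_opt = optim.GradientDescent(learning_rate=1e-2).create(params)
for _ in range(10):
  flax_opt = flax_opt.apply_gradient(grads)

# optax: the optimizer is a pure (init, update) pair; params are updated
# explicitly with optax.apply_updates.
tx = optax.sgd(1e-2)
state = tx.init(params)
optax_params = params
for _ in range(10):
  updates, state = tx.update(grads, state, optax_params)
  optax_params = optax.apply_updates(optax_params, updates)

# flax_opt.target and optax_params should now match up to numerical tolerance.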
def get_optimizer(hps):
  """Constructs the optimizer from the given HParams."""
  if 'weight_decay' in hps.opt_hparams:
    weight_decay = hps.opt_hparams['weight_decay']
  else:
    weight_decay = 0

  if hps.optimizer == 'sgd':
    return optimizers.GradientDescent(learning_rate=None)
  elif hps.optimizer == 'nesterov':
    return optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['momentum'],
        nesterov=True,
        weight_decay=weight_decay)
  elif hps.optimizer == 'momentum':
    return optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['momentum'],
        nesterov=False,
        weight_decay=weight_decay)
  elif hps.optimizer == 'lamb':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.LAMB(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'adam':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.Adam(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'lars':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.LARS(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'mlperf_lars_resnet':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    weight_opt_def = optimizers.LARS(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=weight_decay)
    other_opt_def = optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=0,
        nesterov=False)

    def filter_weights(key, _):
      return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
      return 'bias' in key or 'scale' in key

    weight_traversal = optimizers.ModelParamTraversal(filter_weights)
    other_traversal = optimizers.ModelParamTraversal(filter_other)
    return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                     (other_traversal, other_opt_def))
  elif hps.optimizer == 'mlperf_lamb':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    weight_opt_def = optimizers.LAMB(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['lamb_weight_decay'])
    other_opt_def = optimizers.Adam(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['adam_weight_decay'])

    def filter_weights(key, _):
      return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
      return 'bias' in key or 'scale' in key

    weight_traversal = optimizers.ModelParamTraversal(filter_weights)
    other_traversal = optimizers.ModelParamTraversal(filter_other)
    return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                     (other_traversal, other_opt_def))
  else:
    raise NotImplementedError('Optimizer {} not implemented'.format(
        hps.optimizer))
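# Minimal usage sketch for get_optimizer. Assumptions: `optimizers` above is
# bound to flax.optim (or a compatible wrapper), and `hps` may be any
# attribute-style container exposing `optimizer`, `opt_hparams`, and
# `l2_decay_factor`; the container type and values below are illustrative,
# not defaults from any real config.
import types

hps = types.SimpleNamespace(
    optimizer='nesterov',
    l2_decay_factor=None,
    opt_hparams={'momentum': 0.9, 'weight_decay': 1e-4},
)

optimizer_def = get_optimizer(hps)  # a flax.optim Momentum definition here
# The definition is built with learning_rate=None, so the learning rate is
# expected to be supplied later, e.g. as a per-step hyperparameter override:
#   optimizer = optimizer_def.create(params)
#   optimizer = optimizer.apply_gradient(grads, learning_rate=current_lr)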