class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
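# The equivalence tests above (and the extended version further below) rely on
# module-level imports and constants that are not shown here. A sketch of what
# they are assumed to look like; the exact values of LR and STEPS in the
# original test file may differ:
#
# from absl.testing import parameterized
# import chex
# from flax import optim
# import jax.numpy as jnp
# from optax._src import alias
# from optax._src import update
#
# LR = 1e-2   # Learning rate shared by every optimizer pair.
# STEPS = 50  # Number of update steps applied in each test.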
def create_optimizer(name='adam', learning_rate=6.25e-5, beta1=0.9,
                     beta2=0.999, eps=1.5e-4):
  """Creates an optimizer for training.

  Currently, only the Adam and RMSProp optimizers are supported.

  Args:
    name: str, name of the optimizer to create.
    learning_rate: float, learning rate to use in the optimizer.
    beta1: float, beta1 parameter for the optimizer.
    beta2: float, beta2 parameter for the optimizer.
    eps: float, epsilon parameter for the optimizer.

  Returns:
    A flax optimizer.
  """
  if name == 'adam':
    logging.info('Creating Adam optimizer with settings lr=%f, beta1=%f, '
                 'beta2=%f, eps=%f', learning_rate, beta1, beta2, eps)
    return optim.Adam(learning_rate=learning_rate, beta1=beta1, beta2=beta2,
                      eps=eps)
  elif name == 'rmsprop':
    logging.info('Creating RMSProp optimizer with settings lr=%f, beta2=%f, '
                 'eps=%f', learning_rate, beta2, eps)
    return optim.RMSProp(learning_rate=learning_rate, beta2=beta2, eps=eps)
  else:
    raise ValueError('Unsupported optimizer {}'.format(name))
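# A minimal usage sketch for `create_optimizer`, assuming `optim` is
# `flax.optim` and `logging` is `absl.logging`; the parameter pytree here is
# purely illustrative.
import jax.numpy as jnp


def _example_create_and_init():
  params = {'w': jnp.zeros((4, 2)), 'b': jnp.zeros((2,))}
  optimizer_def = create_optimizer(name='rmsprop', learning_rate=1e-3)
  # `OptimizerDef.create` bundles the target params with the optimizer state
  # so that `apply_gradient` can be called on the returned Optimizer object.
  optimizer = optimizer_def.create(params)
  return optimizer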
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
  state = optimizer_def.init_state(params)

  expected_hyper_params = _RMSPropHyperParams(0.1, 0.9, 0.01)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _RMSPropParamState(onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def test_apply_gradient(self):
  optimizer_def = optim.RMSProp(learning_rate=0.1, beta2=0.9, eps=0.01)
  params = onp.array([1.])
  state = optim.OptimizerState(1, _RMSPropParamState(onp.array([0.1])))
  grads = onp.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)

  expected_new_state = optim.OptimizerState(
      2, _RMSPropParamState(onp.array([1.69])))
  expected_new_params = onp.array([0.6946565])
  onp.testing.assert_allclose(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
def test_apply_gradient_centered(self):
  optimizer_def = optim.RMSProp(
      learning_rate=0.1, beta2=0.9, eps=0.01, centered=True)
  params = np.array([1.])
  state = optim.OptimizerState(
      1, _RMSPropParamState(np.array([0.1]), np.array([0.1])))
  grads = np.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)

  expected_new_state = optim.OptimizerState(
      2, _RMSPropParamState(np.array([1.69]), np.array([0.49])))
  expected_new_params = np.array([0.670543], dtype=np.float32)
  np.testing.assert_allclose(new_params, expected_new_params, rtol=1e-6)
  np.testing.assert_allclose(new_state.param_states.v,
                             expected_new_state.param_states.v)
  np.testing.assert_allclose(new_state.param_states.mg,
                             expected_new_state.param_states.mg)
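# Worked check of the expected values in the two RMSProp tests above. This is
# a sketch of the standard RMSProp arithmetic (step = lr * g / (sqrt(v) + eps),
# with the centered variant subtracting the squared gradient mean), not the
# library implementation itself.
import numpy as np

lr, beta2, eps = 0.1, 0.9, 0.01
v, mg, g, p = 0.1, 0.1, 4., 1.

v_new = beta2 * v + (1. - beta2) * g**2          # 0.9*0.1 + 0.1*16 = 1.69
p_plain = p - lr * g / (np.sqrt(v_new) + eps)    # 1 - 0.4/1.31 ~ 0.6946565

mg_new = beta2 * mg + (1. - beta2) * g           # 0.9*0.1 + 0.1*4 = 0.49
p_centered = p - lr * g / (np.sqrt(v_new - mg_new**2) + eps)  # ~ 0.670543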
def get_optimizer(hparams):
  """Constructs the optimizer from the given HParams.

  Args:
    hparams: Hyper parameters.

  Returns:
    A flax optimizer.
  """
  if hparams.optimizer == 'sgd':
    return optimizers.GradientDescent(
        learning_rate=hparams.lr_hparams['initial_learning_rate'])
  if hparams.optimizer == 'nesterov':
    return optimizers.Momentum(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta=hparams.opt_hparams.get('momentum', 0.9),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        nesterov=True)
  if hparams.optimizer == 'momentum':
    return optimizers.Momentum(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta=hparams.opt_hparams.get('momentum', 0.9),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        nesterov=False)
  if hparams.optimizer == 'adam':
    return optimizers.Adam(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta1=hparams.opt_hparams.get('beta1', 0.9),
        beta2=hparams.opt_hparams.get('beta2', 0.999),
        eps=hparams.opt_hparams.get('epsilon', 1e-8),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0))
  if hparams.optimizer == 'rmsprop':
    return optimizers.RMSProp(
        learning_rate=hparams.lr_hparams.get('initial_learning_rate'),
        beta2=hparams.opt_hparams.get('beta2', 0.9),
        eps=hparams.opt_hparams.get('epsilon', 1e-8))
  else:
    raise NotImplementedError(
        'Optimizer {} not implemented'.format(hparams.optimizer))
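# A minimal sketch of how `get_optimizer` might be driven. The hparams
# container here is hypothetical: any object exposing `.optimizer`,
# `.lr_hparams`, and `.opt_hparams` (e.g. an ml_collections.ConfigDict)
# would work.
from types import SimpleNamespace

example_hparams = SimpleNamespace(
    optimizer='momentum',
    lr_hparams={'initial_learning_rate': 0.1},
    opt_hparams={'momentum': 0.9, 'weight_decay': 1e-4},
)
optimizer_def = get_optimizer(example_hparams)  # -> optimizers.Momentum(...)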
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (
        jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (
        jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
      ('lars',
       alias.lars(LR, weight_decay=.5, trust_coefficient=0.003,
                  momentum=0.9, eps=1e-3),
       optim.LARS(LR, weight_decay=.5, trust_coefficient=0.003,
                  beta=0.9, eps=1e-3)),
      ('adafactor',
       alias.adafactor(learning_rate=LR / 10.,
                       factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0,
                       decay_rate=0.8,
                       min_dim_size_to_factor=2),
       optim.Adafactor(learning_rate=LR / 10.,
                       factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0,
                       decay_rate=0.8,
                       min_dim_size_to_factor=2)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
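# For reference, the optax side of the loop above only needs the gradient
# transformation (`init`/`update`) plus an additive apply step. A minimal
# sketch of what `update.apply_updates` amounts to (the library version also
# handles dtype casting); not the library source itself.
import jax


def apply_updates_sketch(params, updates):
  # Add each update leaf to the corresponding parameter leaf of the pytree.
  return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)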