def test_multi_optimizer(self):
  params = {'a': 0., 'b': 0.}
  opt_a = optim.GradientDescent(learning_rate=1.)
  opt_b = optim.GradientDescent(learning_rate=10.)
  t_a = traverse_util.t_identity['a']
  t_b = traverse_util.t_identity['b']
  optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
  state = optimizer_def.init_state(params)
  expected_hyper_params = [
      _GradientDescentHyperParams(1.),
      _GradientDescentHyperParams(10.)
  ]
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = [optim.OptimizerState(0, [()])] * 2
  self.assertEqual(state, expected_state)
  grads = {'a': -1., 'b': -2.}
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_params = {'a': 1., 'b': 20.}
  expected_state = [optim.OptimizerState(1, [()])] * 2
  self.assertEqual(new_state, expected_state)
  self.assertEqual(new_params, expected_params)
  # override learning_rate
  hp = optimizer_def.update_hyper_params(learning_rate=2.)
  new_params, new_state = optimizer_def.apply_gradient(
      hp, params, state, grads)
  expected_params = {'a': 2., 'b': 4.}
  self.assertEqual(new_params, expected_params)
def test_grad_var(self):
  model_size = 10
  example_grads = [{
      'layer1': np.ones(model_size),
      'layer2': 3 * np.ones(model_size)
  }, {
      'layer1': 2 * np.ones(model_size),
      'layer2': np.ones(model_size)
  }]
  eval_config = {'ema_beta': 0.5}
  training_metrics_grabber = utils.TrainingMetricsGrabber.create(
      example_grads[0], eval_config)
  # For the purposes of this test, we create fake optimizers to satisfy
  # the metrics grabber API.
  fake_model = nn.Model(None, example_grads[0])
  new_optimizer = optimizers.GradientDescent(
      learning_rate=None).create(fake_model)
  old_optimizer = optimizers.GradientDescent(
      learning_rate=None).create(fake_model)
  for grad in example_grads:
    training_metrics_grabber = training_metrics_grabber.update(
        grad, old_optimizer, new_optimizer)

  for layer in ['layer1', 'layer2']:
    # With ema_beta = 0.5, the EMA recurrence unrolls over the two updates to
    #   ema_2 = 0.5 * (0.5 * 0 + 0.5 * g_0) + 0.5 * g_1
    #         = 1/4 * 0 + 1/4 * g_0 + 1/2 * g_1.
    expected_grad_ema = (
        1 / 4 * np.zeros(model_size) + 1 / 4 * example_grads[0][layer] +
        1 / 2 * example_grads[1][layer])
    self.assertArraysAllClose(expected_grad_ema,
                              training_metrics_grabber.state[layer].grad_ema)
def test_multiple_hparams(self):
  opt_dict = {
      'hp-beta1':
          optim.GradientDescent(learning_rate=1).create(
              np.array([0.0, 0.1, 0.2, 0.9])),
      'hp-learning_rate':
          optim.GradientDescent(learning_rate=1).create(
              np.array([0.0, 0.1, 0.2, 0.9])),
  }
  gv_dict = {
      'beta1': {
          'learning_rate_scalar': 1,
          'activation_fn': 'linear',
          'activation_ceiling': None,
          'activation_floor': None,
          'clip_min': None,
          'clip_max': None,
      },
      'learning_rate': {
          'learning_rate_scalar': 1,
          'activation_fn': 'linear',
          'activation_ceiling': None,
          'activation_floor': None,
          'clip_min': None,
          'clip_max': None,
      },
  }
  out = guided_parameters.get_activated_hparams(opt_dict, gv_dict)
  self.assertListEqual(list(out['hp-beta1']), [0.0, 0.1, 0.2, 0.9])
  self.assertListEqual(list(out['hp-learning_rate']), [0.0, 0.1, 0.2, 0.9])
def test_multi_guided_parameter(self):
  opt_dict = {
      'hp-beta1':
          optim.GradientDescent(learning_rate=1).create(
              np.array([0.0, 0.1, 0.2, 0.9])),
      'dp-ex_index':
          optim.GradientDescent(learning_rate=1).create(
              np.array([1, 2, 3, 4, 5])),
  }
  gv_dict = {
      'beta1': {
          'learning_rate_scalar': 1,
          'activation_fn': 'linear',
          'activation_ceiling': None,
          'activation_floor': None,
      },
      'ex_index': {
          'learning_rate_scalar': 1,
          'activation_fn': 'linear',
          'activation_ceiling': None,
          'activation_floor': None,
      },
  }
  raw_vars_dict, act_fn_dict = guided_parameters.get_raw_vars_and_act_fns(
      opt_dict, gv_dict)
  self.assertListEqual(list(raw_vars_dict['hp-beta1']), [0.0, 0.1, 0.2, 0.9])
  self.assertListEqual(list(raw_vars_dict['dp-ex_index']), [1, 2, 3, 4, 5])
  self.assertCountEqual(list(act_fn_dict), ['hp-beta1', 'dp-ex_index'])
def test_multi_optimizer_multiple_matches(self):
  params = {'a': {'x': 0., 'y': 0.}, 'b': {'y': 0, 'z': 0.}}
  opt_a = optim.GradientDescent(learning_rate=1.)
  opt_b = optim.GradientDescent(learning_rate=10.)
  t_a = optim.ModelParamTraversal(
      lambda path, _: path.endswith('/x') or path.endswith('/y'))
  t_b = optim.ModelParamTraversal(
      lambda path, value: value.dtype == jnp.int32 or path.endswith('/z'))
  optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b))
  # 'b/y' is an integer, so it is matched both by t_a (path ends in '/y')
  # and by t_b (int32 dtype); initialization must therefore fail.
  with self.assertRaisesRegex(
      ValueError, r"Multiple optimizers match.*'y': \[0, 1\]"):
    jax.jit(optimizer_def.init_state)(params)
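# A minimal sketch (not from the original suite) of the complementary success
# case: when the traversal predicates are disjoint, every parameter is claimed
# by exactly one sub-optimizer and init_state succeeds. The parameter tree and
# function name here are hypothetical.
def example_disjoint_multi_optimizer():
  params = {'a': {'x': 0., 'y': 0.}, 'b': {'z': 0.}}
  t_x = optim.ModelParamTraversal(lambda path, _: path.endswith('/x'))
  t_rest = optim.ModelParamTraversal(lambda path, _: not path.endswith('/x'))
  optimizer_def = optim.MultiOptimizer(
      (t_x, optim.GradientDescent(learning_rate=1.)),
      (t_rest, optim.GradientDescent(learning_rate=10.)))
  # Each leaf matches exactly one traversal, so no "Multiple optimizers
  # match" error is raised.
  return optimizer_def.init_state(params)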
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.GradientDescent(learning_rate=0.1)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _GradientDescentHyperParams(0.1)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(0, ())
  self.assertEqual(state, expected_state)
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam', alias.adam(LR), optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb', alias.lamb(LR), optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
def test_apply_gradient(self):
  optimizer_def = optim.GradientDescent(learning_rate=0.1)
  params = onp.ones((1,))
  state = optim.OptimizerState(0, ())
  grads = onp.array([3.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(1, ())
  expected_new_params = onp.array([0.7])
  self.assertEqual(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
def test_ignores_model(self):
  opt_dict = {
      'model':
          optim.GradientDescent(learning_rate=1).create(
              np.array([0.0, 0.1, 0.2, 0.9])),
  }
  gv_dict = {}
  raw_vars_dict, act_fn_dict = guided_parameters.get_raw_vars_and_act_fns(
      opt_dict, gv_dict)
  self.assertEqual(raw_vars_dict, {})
  self.assertEqual(act_fn_dict, {})
def test_optimizer_with_focus(self):
  params = {'a': 0., 'b': 0.}
  opt_def = optim.GradientDescent(learning_rate=1.)
  t_a = traverse_util.t_identity['a']
  optimizer = opt_def.create(params, focus=t_a)
  expected_state = [optim.OptimizerState(0, [()])]
  self.assertEqual(optimizer.state, expected_state)
  grads = {'a': -1., 'b': -2.}
  new_optimizer = optimizer.apply_gradient(grads)
  expected_params = {'a': 1., 'b': 0.}
  expected_state = [optim.OptimizerState(1, [()])]
  self.assertEqual(new_optimizer.state, expected_state)
  self.assertEqual(new_optimizer.target, expected_params)
def test_ignores_model(self):
  opt_dict = {
      'model':
          optim.GradientDescent(learning_rate=1).create(
              np.array([0.0, 0.1, 0.2, 0.9])),
  }
  gv_dict = {}
  out = guided_parameters.get_activated_hparams(opt_dict, gv_dict)
  self.assertEqual(
      out, {
          'hp-beta1': None,
          'hp-decay_rate': None,
          'hp-eps': None,
          'hp-learning_rate': None,
          'hp-weight_decay': None,
          'hp-label_smoothing': None,
      })
def get_optimizer(hparams):
  """Constructs the optimizer from the given HParams.

  Args:
    hparams: Hyperparameters.

  Returns:
    A flax optimizer.
  """
  if hparams.optimizer == 'sgd':
    return optimizers.GradientDescent(
        learning_rate=hparams.lr_hparams['initial_learning_rate'])
  if hparams.optimizer == 'nesterov':
    return optimizers.Momentum(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta=hparams.opt_hparams.get('momentum', 0.9),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        nesterov=True)
  if hparams.optimizer == 'momentum':
    return optimizers.Momentum(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta=hparams.opt_hparams.get('momentum', 0.9),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        nesterov=False)
  if hparams.optimizer == 'adam':
    return optimizers.Adam(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta1=hparams.opt_hparams.get('beta1', 0.9),
        beta2=hparams.opt_hparams.get('beta2', 0.999),
        eps=hparams.opt_hparams.get('epsilon', 1e-8),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
    )
  if hparams.optimizer == 'rmsprop':
    return optimizers.RMSProp(
        learning_rate=hparams.lr_hparams.get('initial_learning_rate'),
        beta2=hparams.opt_hparams.get('beta2', 0.9),
        eps=hparams.opt_hparams.get('epsilon', 1e-8))
  raise NotImplementedError('Optimizer {} not implemented'.format(
      hparams.optimizer))
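# A minimal usage sketch for get_optimizer. The SimpleNamespace below is a
# hypothetical stand-in for whatever hyperparameter container the trainer
# actually passes; field values are illustrative only.
def example_get_optimizer_usage(model_params):
  import types
  hparams = types.SimpleNamespace(
      optimizer='adam',
      lr_hparams={'initial_learning_rate': 1e-3},
      opt_hparams={'beta1': 0.9, 'beta2': 0.999, 'epsilon': 1e-8})
  optimizer_def = get_optimizer(hparams)
  # OptimizerDef.create wraps the parameter pytree into a flax optimizer.
  return optimizer_def.create(model_params)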
def create_optimizer(model_params, learning_rate: float):
  optimizer_def = optim.GradientDescent(learning_rate)
  model_optimizer = optimizer_def.create(model_params)
  return model_optimizer
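# One illustrative gradient-descent step with create_optimizer: with
# learning_rate=0.1 and a unit gradient, the parameter moves from 0. to -0.1.
# Scalar params keep the sketch dependency-free; real params are pytrees of
# arrays.
def example_create_optimizer_step():
  optimizer = create_optimizer({'w': 0.}, learning_rate=0.1)
  new_optimizer = optimizer.apply_gradient({'w': 1.})
  return new_optimizer.target  # {'w': -0.1}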
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([0., 0.3, 500., 5.]),
                             jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam', alias.adam(LR), optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb', alias.lamb(LR), optim.LAMB(LR)),
      ('lars',
       alias.lars(LR, weight_decay=.5, trust_coefficient=0.003,
                  momentum=0.9, eps=1e-3),
       optim.LARS(LR, weight_decay=.5, trust_coefficient=0.003,
                  beta=0.9, eps=1e-3)),
      ('adafactor',
       alias.adafactor(learning_rate=LR / 10., factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0, decay_rate=0.8,
                       min_dim_size_to_factor=2),
       optim.Adafactor(learning_rate=LR / 10., factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0, decay_rate=0.8,
                       min_dim_size_to_factor=2)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
def get_optimizer(hps):
  """Constructs the optimizer from the given HParams."""
  weight_decay = hps.opt_hparams.get('weight_decay', 0)
  if hps.optimizer == 'sgd':
    return optimizers.GradientDescent(learning_rate=None)
  elif hps.optimizer == 'nesterov':
    return optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['momentum'],
        nesterov=True,
        weight_decay=weight_decay)
  elif hps.optimizer == 'momentum':
    return optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['momentum'],
        nesterov=False,
        weight_decay=weight_decay)
  elif hps.optimizer == 'lamb':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.LAMB(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'adam':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.Adam(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'lars':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    return optimizers.LARS(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=weight_decay)
  elif hps.optimizer == 'mlperf_lars_resnet':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    weight_opt_def = optimizers.LARS(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=weight_decay)
    other_opt_def = optimizers.Momentum(
        learning_rate=None,
        beta=hps.opt_hparams['beta'],
        weight_decay=0,
        nesterov=False)

    def filter_weights(key, _):
      return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
      return 'bias' in key or 'scale' in key

    weight_traversal = optimizers.ModelParamTraversal(filter_weights)
    other_traversal = optimizers.ModelParamTraversal(filter_other)
    return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                     (other_traversal, other_opt_def))
  elif hps.optimizer == 'mlperf_lamb':
    assert hps.l2_decay_factor is None or weight_decay == 0.0
    weight_opt_def = optimizers.LAMB(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['lamb_weight_decay'])
    other_opt_def = optimizers.Adam(
        learning_rate=None,
        beta1=hps.opt_hparams['beta1'],
        beta2=hps.opt_hparams['beta2'],
        eps=hps.opt_hparams['epsilon'],
        weight_decay=hps.opt_hparams['adam_weight_decay'])

    def filter_weights(key, _):
      return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
      return 'bias' in key or 'scale' in key

    weight_traversal = optimizers.ModelParamTraversal(filter_weights)
    other_traversal = optimizers.ModelParamTraversal(filter_other)
    return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                     (other_traversal, other_opt_def))
  else:
    raise NotImplementedError('Optimizer {} not implemented'.format(
        hps.optimizer))
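# A minimal sketch of driving the 'mlperf_lars_resnet' branch above: LARS
# updates the weights while plain Momentum handles biases and scales. The hps
# object is a hypothetical stand-in containing only the fields that branch
# reads; values are illustrative.
def example_mlperf_lars_usage():
  import types
  hps = types.SimpleNamespace(
      optimizer='mlperf_lars_resnet',
      l2_decay_factor=None,
      opt_hparams={'beta': 0.9, 'weight_decay': 1e-4})
  # Returns a MultiOptimizer pairing a LARS OptimizerDef with a Momentum
  # OptimizerDef over complementary parameter traversals.
  return get_optimizer(hps)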