def test_momentum_with_weight_norm(self):
  params = onp.ones((2, 2)) * 2.
  optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
  state = optimizer_def.init_state(params)
  # WeightNorm splits each parameter into a direction (same shape as the
  # kernel) and a per-column scale, so the scale and mult states are (1, 2).
  self.assertEqual(
      jax.tree_map(onp.shape, state),
      optim.OptimizerState(
          step=(),
          param_states=_WeightNormParamState(
              direction_state=_MomentumParamState(momentum=(2, 2)),
              scale_state=_MomentumParamState(momentum=(1, 2)),
              mult=(1, 2))))
  grads = onp.ones((2, 2))
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  # One step with lr=0.1 on unit gradients moves every entry from 2.0 to
  # 1.9; the learned scale tracks the per-column norm 1.9 * sqrt(2).
  onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
  onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2 ** 0.5)
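
# Standalone numpy sketch (an illustration, not part of the original test):
# the (1, 2)-shaped `mult` state holds one scale per column of the 2x2
# kernel. After the update every entry equals 1.9, so each column's L2 norm,
# and therefore `mult`, is sqrt(1.9**2 * 2) = 1.9 * sqrt(2).
import numpy as onp

updated = onp.full((2, 2), 1.9)
per_column_norm = onp.linalg.norm(updated, axis=0, keepdims=True)  # (1, 2)
onp.testing.assert_allclose(per_column_norm, 1.9 * 2 ** 0.5)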
def create_optimizer(config, model, learning_rate, train_size, sampler_rng):
  """Creates the optimizer definition based on config flags."""
  if config.optimizer == 'adam':
    optimizer_def = optim.Adam(learning_rate=learning_rate,
                               beta1=config.momentum)
  elif config.optimizer == 'momentum':
    optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                   beta=config.momentum)
  elif config.optimizer == 'sym_euler':
    optimizer_def = sym_euler_sgmcmc.SymEulerSGMCMC(
        train_size, sampler_rng, learning_rate=learning_rate,
        beta=config.momentum, temperature=config.base_temp,
        step_size_factor=1.)
  else:
    raise ValueError('Invalid value %s for config.optimizer.' %
                     config.optimizer)

  if config.weight_norm == 'none':
    pass
  elif config.weight_norm == 'learned':
    optimizer_def = optim.WeightNorm(optimizer_def)
  elif config.weight_norm in ['fixed', 'ws_sqrt', 'learned_b', 'ws']:
    # These variants are applied in the model layers directly, not here.
    pass
  else:
    raise ValueError('Invalid value %s for config.weight_norm.' %
                     config.weight_norm)

  optimizer = optimizer_def.create(model)
  if not config.debug_run:
    # Replicate the optimizer across devices for pmap-based training.
    optimizer = optimizer.replicate()
  return optimizer
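
# Hedged usage sketch. It assumes the pre-Linen flax.optim API, in which
# OptimizerDef.create accepts any parameter pytree; the config fields below
# mirror the ones read in create_optimizer, but the values, the stand-in
# params, and the use of SimpleNamespace as a config are illustrative only.
import jax.numpy as jnp
from types import SimpleNamespace

config = SimpleNamespace(optimizer='momentum', momentum=0.9,
                         weight_norm='learned', debug_run=True)
params = {'kernel': jnp.ones((3, 3))}  # Stand-in for a real model.
# train_size and sampler_rng are only consumed by the 'sym_euler' branch.
optimizer = create_optimizer(config, params, learning_rate=0.1,
                             train_size=50_000, sampler_rng=None)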