class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
  # NOTE: relies on module-level names defined elsewhere in the test file
  # (chex, jnp, parameterized, alias, optimizers, update, LR, LR_SCHED, STEPS).

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum', alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.),
       optimizers.adagrad(LR, 0.), 1e-5),
      # Scheduled-learning-rate variants; names must be unique within
      # named_parameters, hence the `_scheduled` suffix.
      ('sgd_scheduled', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam_scheduled', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop_scheduled', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum_scheduled',
       alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad_scheduled', alias.adagrad(LR_SCHED, 0., 0.),
       optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                    rtol):
    # jax/experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    state = opt_init(jax_params)
    for i in range(STEPS):
      state = opt_update(i, self.per_step_updates, state)
      jax_params = get_params(state)

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, state = step(self.per_step_updates, state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
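# A minimal standalone sketch (not part of the test above) of the optax update
# pattern the test exercises, written against optax's public API rather than the
# internal `alias`/`update` modules; the hyperparameter values are illustrative only.
import jax.numpy as jnp
import optax

optimizer = optax.rmsprop(learning_rate=1e-2, decay=0.9, eps=0.1, momentum=0.9)

params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
opt_state = optimizer.init(params)

grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))
for _ in range(5):
  # Transform the raw gradients, then apply the resulting updates to the params.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)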
def __init__(
    self,
    obs_shape: Shape,
    n_actions: int,
    hparams: HParams,
    seed: int = 0,
    preprocess: Callable[[Observation], Observation] = lambda x: x,
    logging: bool = False,
):
    # public:
    self.n_actions = n_actions
    self.obs_shape = obs_shape
    self.memory = ReplayBuffer(hparams.replay_memory_size)
    self.rng = PRNGSequence(seed)
    self.preprocess = preprocess
    network = Cnn(n_actions, hidden_size=hparams.hidden_size)
    optimiser = Optimiser(
        *rmsprop_momentum(
            step_size=hparams.learning_rate,
            gamma=hparams.squared_gradient_momentum,
            momentum=hparams.gradient_momentum,
            eps=hparams.min_squared_gradient,
        )
    )
    super().__init__(network, optimiser, hparams, logging)

    # private:
    k = next(self.rng)
    _, params_target = self.network.init(k, (-1, *obs_shape))
    _, params_online = self.network.init(k, (-1, *obs_shape))
    self._opt_state = self.optimiser.init(params_online)
    self._params_target = params_target
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    hparams: HParams,
    preprocess: Callable = lambda x: x,
    logging: bool = False,
):
    # public:
    self.obs_spec = obs_spec
    self.action_spec = action_spec
    self.rng = jax.random.PRNGKey(hparams.seed)
    self.memory = OnlineBuffer(1, hparams.seed)
    self.preprocess = preprocess
    network = Cnn(action_spec.num_values)
    optimiser = Optimiser(
        *rmsprop_momentum(
            step_size=hparams.learning_rate,
            gamma=hparams.squared_gradient_momentum,
            momentum=hparams.gradient_momentum,
            eps=hparams.min_squared_gradient,
        )
    )
    super().__init__(network, optimiser, hparams, logging)

    # private:
    _, params = self.network.init(self.rng, (-1, *obs_spec.shape))
    self._opt_state = self.optimiser.init(params)
def RmsPropMomentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
  return OptimizerFromExperimental(
      experimental.rmsprop_momentum(step_size, gamma, eps, momentum))
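# A hypothetical usage sketch (not from the source repo): it assumes `experimental`
# above refers to jax.experimental.optimizers (now jax.example_libraries.optimizers
# in recent JAX), whose factories return an (init_fun, update_fun, get_params) triple.
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers as experimental

init_fun, update_fun, get_params = experimental.rmsprop_momentum(
    step_size=1e-3, gamma=0.9, eps=1e-8, momentum=0.9)

params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}
opt_state = init_fun(params)

def loss(p, x):
  # Toy quadratic loss on a linear model, just to produce gradients.
  return jnp.sum((x @ p['w'] + p['b']) ** 2)

for step in range(10):
  x = jnp.ones((4, 3))
  grads = jax.grad(loss)(get_params(opt_state), x)
  opt_state = update_fun(step, grads, opt_state)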