# Example no. 1
0
class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    """Checks optax optimizers against their `experimental/optimizers.py` twins.

    Each parameterized case feeds the same constant per-step updates through
    an optax optimizer and through the matching JAX experimental optimizer
    for ``STEPS`` steps. It then asserts that both parameter pytrees agree
    to within ``rtol``.
    """

    def setUp(self):
        super().setUp()
        # Two-leaf parameter pytree and a same-structured update pytree that
        # is applied unchanged at every step.
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        # Constant learning rate on both sides.
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999,
                            1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(
            LR, decay=.9, eps=0.1), optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum', alias.rmsprop(LR, decay=.9, eps=0.1,
                                           momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(
            LR,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
        # The optax side uses LR_SCHED while the reference uses constant LR.
        # Presumably LR_SCHED is a schedule that yields LR at every step; it
        # is defined outside this view, so confirm that assumption there.
        # These names must differ from the constant-LR cases above, because
        # absl's parameterized rejects duplicate test names.
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999,
                                  1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(
            LR_SCHED,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Run both optimizers for STEPS steps and compare final params.

        Args:
            optax_optimizer: optax gradient transformation under test.
            jax_optimizer: ``(opt_init, opt_update, get_params)`` triple from
                the experimental optimizers module, used as the reference.
            rtol: relative tolerance for the final parameter comparison.
        """
        # Reference: experimental/optimizers.py. The step index is passed
        # explicitly to opt_update.
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax: the update step runs under each chex variant (jit, no-jit, ...).
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
# Example no. 2
0
    def __init__(
        self,
        obs_shape: Shape,
        n_actions: int,
        hparams: HParams,
        seed: int = 0,
        preprocess: Callable[[Observation], Observation] = lambda x: x,
        logging: bool = False,
    ):
        """Set up replay memory, RNG, CNN, RMSProp-momentum optimiser and params.

        Args:
            obs_shape: shape of one observation. A leading -1 batch dimension
                is prepended when the network is initialised.
            n_actions: number of discrete actions, passed to ``Cnn``.
            hparams: hyperparameters. This reads replay_memory_size,
                hidden_size, learning_rate, squared_gradient_momentum,
                gradient_momentum and min_squared_gradient.
            seed: seed for the ``PRNGSequence`` stored in ``self.rng``.
            preprocess: observation transform; identity by default.
            logging: forwarded to the base-class constructor.
        """
        # public:
        self.n_actions = n_actions
        self.obs_shape = obs_shape
        self.memory = ReplayBuffer(hparams.replay_memory_size)
        self.rng = PRNGSequence(seed)
        self.preprocess = preprocess

        model = Cnn(n_actions, hidden_size=hparams.hidden_size)
        rmsprop_fns = rmsprop_momentum(
            step_size=hparams.learning_rate,
            gamma=hparams.squared_gradient_momentum,
            momentum=hparams.gradient_momentum,
            eps=hparams.min_squared_gradient,
        )
        super().__init__(model, Optimiser(*rmsprop_fns), hparams, logging)

        # private:
        init_key = next(self.rng)
        batched_shape = (-1, *obs_shape)
        # Both networks are initialised with the same key, so the target
        # starts from the same weights as the online network (assuming
        # network.init is deterministic in its key — confirm).
        _, params_target = self.network.init(init_key, batched_shape)
        _, params_online = self.network.init(init_key, batched_shape)
        # The online params are only handed to the optimiser here. Presumably
        # the optimiser state carries them from now on — verify in the base class.
        self._opt_state = self.optimiser.init(params_online)
        self._params_target = params_target
# Example no. 3
0
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        hparams: HParams,
        preprocess: Callable = lambda x: x,
        logging: bool = False,
    ):
        """Set up RNG key, online buffer, CNN, RMSProp-momentum optimiser and params.

        Args:
            obs_spec: observation spec. Its ``shape`` gets a leading -1 batch
                dimension when the network is initialised.
            action_spec: discrete action spec. ``num_values`` sizes the ``Cnn``.
            hparams: hyperparameters. This reads seed, learning_rate,
                squared_gradient_momentum, gradient_momentum and
                min_squared_gradient.
            preprocess: observation transform; identity by default.
            logging: forwarded to the base-class constructor.
        """
        # public:
        self.obs_spec = obs_spec
        self.action_spec = action_spec
        self.rng = jax.random.PRNGKey(hparams.seed)
        self.memory = OnlineBuffer(1, hparams.seed)
        self.preprocess = preprocess

        model = Cnn(action_spec.num_values)
        rmsprop_fns = rmsprop_momentum(
            step_size=hparams.learning_rate,
            gamma=hparams.squared_gradient_momentum,
            momentum=hparams.gradient_momentum,
            eps=hparams.min_squared_gradient,
        )
        super().__init__(model, Optimiser(*rmsprop_fns), hparams, logging)

        # private:
        # NOTE(review): self.rng is used here without being split. If other
        # code later samples with the same key, that is PRNG key reuse — confirm.
        net = self.network
        _, params = net.init(self.rng, (-1, *obs_spec.shape))
        self._opt_state = self.optimiser.init(params)
# Example no. 4
0
def RmsPropMomentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
    """Build an RMSProp-with-momentum optimizer from the experimental API.

    Args:
        step_size: learning rate (or schedule) passed through unchanged.
        gamma: decay rate for the running average of squared gradients.
        eps: numerical-stability constant.
        momentum: momentum coefficient.

    Returns:
        ``OptimizerFromExperimental`` wrapping the result of
        ``experimental.rmsprop_momentum``. That result is presumably an
        (init, update, get_params) triple.
    """
    experimental_opt = experimental.rmsprop_momentum(
        step_size, gamma, eps, momentum)
    return OptimizerFromExperimental(experimental_opt)