Example #1
class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999,
                            1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(
            LR, decay=.9, eps=0.1), optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum', alias.rmsprop(LR, decay=.9, eps=0.1,
                                           momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.), optimizers.adagrad(LR, 0.), 1e-5),
        # schedule variants; names must be unique within named_parameters
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(LR_SCHED, 0., 0.),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
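
This snippet relies on several module-level names (`LR`, `LR_SCHED`, `STEPS`) and imports that are not shown. A minimal sketch of a plausible setup, assuming the optax source layout for `alias` and `update` and a constant schedule:

# Assumed setup for the test above; the exact values are illustrative.
import chex
import jax.numpy as jnp
from absl.testing import parameterized
from jax.experimental import optimizers
from optax._src import alias, update  # assumed source-layout imports

STEPS = 50                   # number of update steps to compare
LR = 1e-2                    # fixed learning rate
LR_SCHED = lambda step: LR   # constant schedule; any callable step -> lr works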
Example #2
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
  if optimizer.lower() == 'adagrad':
    return optimizers.adagrad(sched)
  elif optimizer.lower() == 'adam':
    return optimizers.adam(sched, b1, b2)
  elif optimizer.lower() == 'rmsprop':
    return optimizers.rmsprop(sched)
  elif optimizer.lower() == 'momentum':
    return optimizers.momentum(sched, 0.9)
  elif optimizer.lower() == 'sgd':
    return optimizers.sgd(sched)
  else:
    raise ValueError('Invalid optimizer: {}'.format(optimizer))
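
A usage sketch for this helper, assuming `optimizers` is `jax.experimental.optimizers`: `sched` may be a constant step size or a schedule such as `optimizers.exponential_decay`.

# Usage sketch (assumed context): build a decaying schedule and unpack
# the usual (init, update, get_params) triple.
from jax.experimental import optimizers

sched = optimizers.exponential_decay(step_size=1e-3, decay_steps=1000,
                                     decay_rate=0.9)
opt_init, opt_update, get_params = get_optimizer('adam', sched)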
Example #3
def optimizer(name="adam",
              momentum_mass=0.9, rmsprop_gamma=0.9, rmsprop_eps=1e-8,
              adam_b1=0.9, adam_b2=0.997, adam_eps=1e-8):
  """Return the optimizer, by name."""
  if name == "sgd":
    return optimizers.sgd(learning_rate)
  if name == "momentum":
    return optimizers.momentum(learning_rate, mass=momentum_mass)
  if name == "rmsprop":
    return optimizers.rmsprop(
        learning_rate, gamma=rmsprop_gamma, eps=rmsprop_eps)
  if name == "adam":
    return optimizers.adam(learning_rate, b1=adam_b1, b2=adam_b2, eps=adam_eps)
  raise ValueError("Unknown optimizer %s" % str(name))
Example #4
    def test_rmsprop(self):
        decay, eps = .9, 0.1

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = optimizers.rmsprop(LR, decay, eps)
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # experimental/optix.py
        optix_params = self.init_params
        rmsprop = optix.rmsprop(LR, decay, eps)
        state = rmsprop.init(optix_params)
        for _ in range(STEPS):
            updates, state = rmsprop.update(self.per_step_updates, state)
            optix_params = optix.apply_updates(optix_params, updates)

        # Check equivalence.
        for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
            np.testing.assert_allclose(x, y, rtol=1e-5)
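
`optix` was the original name of the library that became `optax`; a rough modern equivalent of the second half of this test, assuming current `optax` (a sketch, not the test's actual code):

import optax

optax_params = self.init_params
tx = optax.rmsprop(LR, decay=decay, eps=eps)
state = tx.init(optax_params)
for _ in range(STEPS):
    updates, state = tx.update(self.per_step_updates, state)
    optax_params = optax.apply_updates(optax_params, updates)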
Example #5
    def get_optimizer(self, optim=None, stage='learn', step_size=None):

        if optim is None:
            if stage == 'learn':
                optim = self.optim_learn
            else:
                optim = self.optim_proj
        if step_size is None:
            step_size = self.step_size

        if optim == 1:
            if self.verb > 2:
                print("With momentum optimizer")
            opt_init, opt_update, get_params = momentum(step_size=step_size,
                                                        mass=0.95)
        elif optim == 2:
            if self.verb > 2:
                print("With rmsprop optimizer")
            opt_init, opt_update, get_params = rmsprop(step_size,
                                                       gamma=0.9,
                                                       eps=1e-8)
        elif optim == 3:
            if self.verb > 2:
                print("With adagrad optimizer")
            opt_init, opt_update, get_params = adagrad(step_size, momentum=0.9)
        elif optim == 4:
            if self.verb > 2:
                print("With Nesterov optimizer")
            opt_init, opt_update, get_params = nesterov(step_size, 0.9)
        elif optim == 5:
            if self.verb > 2:
                print("With SGD optimizer")
            opt_init, opt_update, get_params = sgd(step_size)
        else:
            if self.verb > 2:
                print("With adam optimizer")
            opt_init, opt_update, get_params = adam(step_size)

        return opt_init, opt_update, get_params
Example #6
def get_optimizer(
    learning_rate: float = 1e-4, optimizer="sgd", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (allowed types: "adam",
            "adamax", "adagrad", "rmsprop", "sgd"). Defaults to "sgd".
        optimizer_kwargs (dict, optional): Additional keyword arguments
            that are passed to the optimizer. Defaults to None.

    Returns:
        JaxOptimizer
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    optimizer = optimizer.lower()
    if optimizer == "adam":
        opt_init, opt_update, get_params = optimizers.adam(learning_rate, **optimizer_kwargs)
    elif optimizer == "adagrad":
        opt_init, opt_update, get_params = optimizers.adagrad(learning_rate, **optimizer_kwargs)
    elif optimizer == "adamax":
        opt_init, opt_update, get_params = optimizers.adamax(learning_rate, **optimizer_kwargs)
    elif optimizer == "rmsprop":
        opt_init, opt_update, get_params = optimizers.rmsprop(learning_rate, **optimizer_kwargs)
    else:
        opt_init, opt_update, get_params = optimizers.sgd(learning_rate, **optimizer_kwargs)

    opt_update = jit(opt_update)

    return JaxOptimizer(opt_init, opt_update, get_params)
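
`JaxOptimizer` is defined elsewhere in that codebase; a minimal sketch of what such a container plausibly looks like (an assumption, not the project's actual definition):

from dataclasses import dataclass
from typing import Callable

@dataclass
class JaxOptimizer:
    """Hypothetical container for the (init, update, get_params) triple
    returned by jax.experimental.optimizers."""
    opt_init: Callable
    opt_update: Callable
    get_params: Callable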
Example #7
def _JaxRmsProp(machine, learning_rate=0.001, beta=0.9, epscut=1.0e-7):
    # `jaxopt` is presumably jax.experimental.optimizers; `beta` and `epscut`
    # map to its `gamma` (decay) and `eps` arguments.
    return Wrap(machine, jaxopt.rmsprop(learning_rate, beta, epscut))
Example #8
class AliasTest(chex.TestCase):
    def setUp(self):
        super(AliasTest, self).setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999,
                            1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, .9, 0.1), optimizers.rmsprop(
            LR, .9, 0.1), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.), optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(1e-2, 0.0)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        ('lamb', alias.lamb(1e-1)),
        ('rmsprop', alias.rmsprop(1e-1)),
        ('fromage', transform.scale_by_fromage(-1e-2)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_parabola(self, opt):
        initial_params = jnp.array([-1.0, 10.0, 1.0])
        final_params = jnp.array([1.0, -1.0, 1.0])

        @jax.grad
        def get_updates(params):
            return jnp.sum((params - final_params)**2)

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(1000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(2e-3, 0.2)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        ('lamb', alias.lamb(1e-1)),
        ('rmsprop', alias.rmsprop(5e-3)),
        ('fromage', transform.scale_by_fromage(-5e-3)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_rosenbrock(self, opt):
        a = 1.0
        b = 100.0
        initial_params = jnp.array([0.0, 0.0])
        final_params = jnp.array([a, a**2])

        @jax.grad
        def get_updates(params):
            return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
Example #9
    def __init__(self, learning_rate):
        super().__init__(learning_rate)
        self.opt_init, self.opt_update, self.get_params = rmsprop(
            step_size=self.lr)
Example #10
def main():
    X, y, Xtest, ytest = get_data(50)

    # PRIOR FUNCTIONS (mean, covariance)
    mu_f = zero_mean
    cov_f = functools.partial(gram, rbf_kernel)
    gp_priors = (mu_f, cov_f)

    # Kernel, Likelihood parameters
    params = {
        "gamma": 2.0,
        # 'length_scale': 1.0,
        # 'var_f': 1.0,
        "likelihood_noise": 1.0,
    }
    # saturate parameters with likelihoods
    params = saturate(params)

    # LOSS FUNCTION
    mll_loss = jax.jit(functools.partial(marginal_likelihood, gp_priors))

    # GRADIENT LOSS FUNCTION
    dloss = jax.jit(jax.grad(mll_loss))

    # STEP FUNCTION
    @jax.jit
    def step(params, X, y, opt_state):
        # calculate loss
        loss = mll_loss(params, X, y)

        # calculate gradient of loss
        grads = dloss(params, X, y)

        # update optimizer state (the step index is fixed at 0; it only
        # matters when step_size is a schedule)
        opt_state = opt_update(0, grads, opt_state)

        # update params
        params = get_params(opt_state)

        return params, opt_state, loss

    # TRAINING PARAMETERS
    n_epochs = 500
    learning_rate = 0.01
    losses = list()

    # initialize optimizer
    opt_init, opt_update, get_params = optimizers.rmsprop(
        step_size=learning_rate)

    # initialize parameters
    opt_state = opt_init(params)

    # get initial parameters
    params = get_params(opt_state)

    postfix = {}

    with tqdm.trange(n_epochs) as bar:

        for i in bar:
            # 1 step - optimize function
            params, opt_state, value = step(params, X, y, opt_state)

            # update params
            postfix = {}
            for ikey in params.keys():
                postfix[ikey] = f"{jax.nn.softplus(params[ikey]):.2f}"

            # save loss values
            losses.append(value.mean())

            # update progress bar
            postfix["Loss"] = f"{onp.array(losses[-1]):.2f}"
            bar.set_postfix(postfix)
            # saturate params
            params = saturate(params)

    # Posterior Predictions
    mu_y, var_y = posterior(params, gp_priors, X, y, Xtest, True, False)

    # Uncertainty
    uncertainty = 1.96 * jnp.sqrt(var_y.squeeze())

    fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
    ax[0].scatter(X, y, c="red", label="Training Data")
    ax[0].plot(
        Xtest.squeeze(),
        mu_y.squeeze(),
        label=r"Predictive Mean",
        color="black",
        linewidth=3,
    )
    ax[0].fill_between(
        Xtest.squeeze(),
        mu_y.squeeze() + uncertainty,
        mu_y.squeeze() - uncertainty,
        alpha=0.3,
        color="darkorange",
        label=f"Predictive Std (95% Confidence)",
    )
    ax[0].legend(fontsize=12)
    ax[1].plot(losses, label="losses")
    plt.tight_layout()
    fig.savefig("figures/jaxgp/examples/1d_example.png")
    plt.show()
Example #11
def main():
    env = gym.make('SpaceInvaders-v0')

    memory = deque(maxlen=MEM_SIZE)

    # fill memory with random interactions with the environment
    while len(memory) < MEM_SIZE:
        observation = env.reset()
        frames = deque([np.zeros((185, 95)) for _ in range(STACK_SIZE)], maxlen=STACK_SIZE)
        frames.append(preprocess(observation))
        state = stack_frames(frames)
        done = False
        while not done:
            # 0 no action, 1 fire, 2 move right, 3 move left, 4 move right fire, 5 move left fire
            action = env.action_space.sample()
            observation_, reward, done, info = env.step(action)
            frames.append(preprocess(observation_))
            state_ = stack_frames(frames)
            memory = store_transition(memory, state, action, reward, state_)
            state = state_
    print('done initializing memory')

    init_Q, pred_Q = DeepQNetwork()

    # two separate Q-Table approximations (eval and next)
    # initialize parameters, not committing to a batch size (NHWC)
    # we choose 3 channels as we want to pass stacks of 4 consecutive frames
    in_shape = (-1, 185, 95, STACK_SIZE)
    if LOAD:
        path = os.path.join(WEIGHTS_PATH, "params_Q_eval.npy")
        params_Q_eval = load_params(path)
    else:
        _, params_Q_eval = init_Q(in_shape)
    params_Q_next = params_Q_eval.copy()

    # Initialize RMSProp optimizer
    opt_init, opt_update, get_params = optimizers.rmsprop(ALPHA)
    opt_state = opt_init(params_Q_eval)
    opt_step = 0

    # Define a simple mean-squared-error loss
    def loss(params, batch):
        inputs, targets = batch
        predictions = pred_Q(params, inputs)
        return np.mean((predictions - targets) ** 2)

    # Define a compiled update step
    @jit
    def step(j, opt_state, batch):
        params = get_params(opt_state)
        g = grad(loss)(params, batch)
        return opt_update(j, g, opt_state)

    def learn(opt_step, opt_state, params_Q_eval, params_Q_next):
        mini_batch = sample(memory, BATCH_SIZE)

        if opt_step % TAU == 0:
            params_Q_next = params_Q_eval.copy()

        input_states = np.stack([transition[0] for transition in mini_batch])
        next_states = np.stack([transition[3] for transition in mini_batch])

        predicted_Q = pred_Q(params_Q_eval, input_states)
        predicted_Q_next = pred_Q(params_Q_next, next_states)

        max_action = np.argmax(predicted_Q_next, axis=1)
        rewards = np.array([transition[2] for transition in mini_batch])

        Q_target = onp.array(predicted_Q)
        # index each row by its own greedy action; `[:, max_action]` would
        # overwrite whole columns instead of one entry per sample
        batch_index = onp.arange(len(mini_batch))
        Q_target[batch_index, max_action] = rewards + GAMMA * np.max(predicted_Q_next, axis=1)

        opt_state = step(opt_step, opt_state, (input_states, Q_target))
        params_Q_eval = get_params(opt_state)

        return opt_state, params_Q_eval, params_Q_next

    scores = []
    eps_history = []
    eps = EPS_START if LEARN else 0

    for i in range(NUM_GAMES):
        print('starting game ', i + 1, 'epsilon: %.4f' % eps)
        eps_history.append(eps)
        done = False
        observation = env.reset()
        frames = deque([np.zeros((185, 95)) for _ in range(STACK_SIZE)], maxlen=STACK_SIZE)
        frames.append(preprocess(observation))
        state = stack_frames(frames)
        score = 0
        while not done:
            action = choose_action(env, state.reshape((1, 185, 95, STACK_SIZE)),
                                   pred_Q, params_Q_eval, eps)
            observation_, reward, done, info = env.step(action)
            score += reward

            if RENDER:
                env.render()

            if LEARN:
                frames.append(preprocess(observation_))
                state_ = stack_frames(frames)
                memory = store_transition(memory, state, action, reward, state_)
                state = state_
                opt_state, params_Q_eval, params_Q_next = learn(opt_step, opt_state,
                                                                params_Q_eval, params_Q_next)
                opt_step += 1

                if opt_step > 500:
                    if eps - 1e-4 > EPS_END:
                        eps -= 1e-4
                    else:
                        eps = EPS_END

        if LEARN:
            out_path = os.path.join(WEIGHTS_PATH, 'params_Q_eval_' + str(i))

            onp.save(out_path, params_Q_eval)
        scores.append(score)
        print('score: ', score)
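
The snippet assumes a number of module-level constants; plausible values (assumptions, not the original configuration):

# Hypothetical hyperparameters for the snippet above.
MEM_SIZE = 5000       # replay-buffer capacity
STACK_SIZE = 4        # consecutive frames stacked per state
BATCH_SIZE = 32       # transitions sampled per learning step
ALPHA = 2.5e-4        # RMSProp step size
GAMMA = 0.99          # discount factor
TAU = 500             # target-network sync period, in optimizer steps
EPS_START, EPS_END = 1.0, 0.05
NUM_GAMES = 500
LEARN, RENDER, LOAD = True, False, False
WEIGHTS_PATH = './weights'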
Example #12
def RmsProp(step_size, gamma=0.9, eps=1e-8):
    return OptimizerFromExperimental(
        experimental.rmsprop(step_size, gamma, eps))
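
`OptimizerFromExperimental` adapts the `(init, update, get_params)` triple to an object interface, and `experimental` here is presumably `jax.experimental.optimizers`. A minimal sketch of such an adapter (assumed, not the project's actual class):

class OptimizerFromExperimental:
    """Hypothetical adapter around a jax.experimental.optimizers triple."""

    def __init__(self, triple):
        self._init, self._update, self._get_params = triple

    def init(self, params):
        return self._init(params)

    def update(self, step, grads, opt_state):
        return self._update(step, grads, opt_state)

    def get_params(self, opt_state):
        return self._get_params(opt_state)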