class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    """Checks optax aliases against the jax.experimental.optimizers references.

    Each named case pairs an optax optimizer (with a constant LR or an LR
    schedule) with the reference jax.experimental implementation and a
    comparison tolerance.
    """

    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        # Constant learning rate.
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum',
         alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.,),
         optimizers.adagrad(LR, 0.), 1e-5),
        # Scheduled learning rate.
        # BUG FIX: these five cases previously reused the names above;
        # parameterized.named_parameters requires unique test names and raises
        # a duplicate-name error otherwise.
        ('sgd_scheduler', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_scheduler', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_scheduler', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_scheduler',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_scheduler', alias.adagrad(LR_SCHED, 0., 0.,),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Runs both optimizers for STEPS updates and compares final params."""
        # Reference: experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
        jax_params = get_params(state)

        # Under test: optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
    """Look up a jax optimizer factory by (case-insensitive) name.

    Args:
        optimizer: one of 'adagrad', 'adam', 'rmsprop', 'momentum', 'sgd'
            (any casing).
        sched: step size or learning-rate schedule passed to the factory.
        b1: Adam first-moment decay (used only for 'adam').
        b2: Adam second-moment decay (used only for 'adam').

    Returns:
        The (init, update, get_params) triple produced by the factory.

    Raises:
        ValueError: if `optimizer` is not a recognized name.
            (Previously raised a bare `Exception`; `ValueError` is more
            specific and still caught by callers handling `Exception`.)
    """
    # Lower-case once instead of on every comparison.
    name = optimizer.lower()
    if name == 'adagrad':
        return optimizers.adagrad(sched)
    elif name == 'adam':
        return optimizers.adam(sched, b1, b2)
    elif name == 'rmsprop':
        return optimizers.rmsprop(sched)
    elif name == 'momentum':
        # 0.9 is the fixed momentum mass used throughout this module.
        return optimizers.momentum(sched, 0.9)
    elif name == 'sgd':
        return optimizers.sgd(sched)
    else:
        raise ValueError('Invalid optimizer: {}'.format(optimizer))
def optimizer(name="adam", momentum_mass=0.9, rmsprop_gamma=0.9,
              rmsprop_eps=1e-8, adam_b1=0.9, adam_b2=0.997, adam_eps=1e-8,
              learning_rate=1e-3):
    """Return the optimizer, by name.

    Args:
        name: one of "sgd", "momentum", "rmsprop", "adam".
        momentum_mass: momentum mass (used only for "momentum").
        rmsprop_gamma: RMSProp decay (used only for "rmsprop").
        rmsprop_eps: RMSProp epsilon (used only for "rmsprop").
        adam_b1: Adam first-moment decay (used only for "adam").
        adam_b2: Adam second-moment decay (used only for "adam").
        adam_eps: Adam epsilon (used only for "adam").
        learning_rate: step size passed to the optimizer factory.
            BUG FIX: this was previously read from a name not defined in the
            function, which raises NameError at call time unless a
            module-level `learning_rate` happens to exist — confirm no caller
            relied on such a global before depending on the 1e-3 default.

    Returns:
        The optimizer triple produced by the matching factory.

    Raises:
        ValueError: if `name` is not a known optimizer.
    """
    if name == "sgd":
        return optimizers.sgd(learning_rate)
    if name == "momentum":
        return optimizers.momentum(learning_rate, mass=momentum_mass)
    if name == "rmsprop":
        return optimizers.rmsprop(
            learning_rate, gamma=rmsprop_gamma, eps=rmsprop_eps)
    if name == "adam":
        return optimizers.adam(learning_rate, b1=adam_b1, b2=adam_b2,
                               eps=adam_eps)
    raise ValueError("Unknown optimizer %s" % str(name))
def test_rmsprop(self):
    """RMSProp in optix must match the jax.experimental reference."""
    decay, eps = .9, 0.1

    # Reference run: experimental/optimizers.py
    init_fn, update_fn, params_fn = optimizers.rmsprop(LR, decay, eps)
    reference_state = init_fn(self.init_params)
    for step_index in range(STEPS):
        reference_state = update_fn(
            step_index, self.per_step_updates, reference_state)
    reference_params = params_fn(reference_state)

    # Run under test: experimental/optix.py
    rmsprop_opt = optix.rmsprop(LR, decay, eps)
    candidate_params = self.init_params
    optix_state = rmsprop_opt.init(candidate_params)
    for _ in range(STEPS):
        step_updates, optix_state = rmsprop_opt.update(
            self.per_step_updates, optix_state)
        candidate_params = optix.apply_updates(candidate_params, step_updates)

    # The two parameter trees must agree leaf-by-leaf.
    for expected, actual in zip(tree_leaves(reference_params),
                                tree_leaves(candidate_params)):
        np.testing.assert_allclose(expected, actual, rtol=1e-5)
def get_optimizer(self, optim=None, stage='learn', step_size=None):
    """Select a jax optimizer triple by numeric code.

    Codes: 1 momentum, 2 rmsprop, 3 adagrad, 4 nesterov, 5 sgd,
    anything else adam. Defaults come from the instance (`optim_learn`
    for the 'learn' stage, `optim_proj` otherwise, and `self.step_size`).

    Returns:
        The (opt_init, opt_update, get_params) triple.
    """
    if optim is None:
        optim = self.optim_learn if stage == 'learn' else self.optim_proj
    if step_size is None:
        step_size = self.step_size

    # Guard-clause dispatch; each branch announces its choice when verbose.
    if optim == 1:
        if self.verb > 2:
            print("With momentum optimizer")
        return momentum(step_size=step_size, mass=0.95)
    if optim == 2:
        if self.verb > 2:
            print("With rmsprop optimizer")
        return rmsprop(step_size, gamma=0.9, eps=1e-8)
    if optim == 3:
        if self.verb > 2:
            print("With adagrad optimizer")
        return adagrad(step_size, momentum=0.9)
    if optim == 4:
        if self.verb > 2:
            print("With Nesterov optimizer")
        return nesterov(step_size, 0.9)
    if optim == 5:
        if self.verb > 2:
            print("With SGD optimizer")
        return sgd(step_size)
    # Any unrecognized code falls back to adam, as in the original logic.
    if self.verb > 2:
        print("With adam optimizer")
    return adam(step_size)
def get_optimizer(
    learning_rate: float = 1e-4,
    optimizer: str = "sdg",
    optimizer_kwargs: "dict | None" = None,
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer.

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (allowed: "adam", "adamax",
            "adagrad", "rmsprop", "sgd"). The historical default "sdg" is a
            typo for "sgd"; it — like any unrecognized name — selects plain
            SGD, which preserves the original behavior. Defaults to "sdg".
        optimizer_kwargs (dict, optional): Additional keyword arguments
            passed to the optimizer factory. Defaults to None.
            (Annotation fixed: the parameter is optional, not a plain dict.)

    Returns:
        JaxOptimizer: wraps (opt_init, jitted opt_update, get_params).
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)

    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}

    # Dispatch by lower-cased name; anything not listed (including the
    # default "sdg" and the correctly-spelled "sgd") falls back to SGD.
    factories = {
        "adam": optimizers.adam,
        "adagrad": optimizers.adagrad,
        "adamax": optimizers.adamax,
        "rmsprop": optimizers.rmsprop,
    }
    factory = factories.get(optimizer.lower(), optimizers.sgd)
    opt_init, opt_update, get_params = factory(learning_rate, **optimizer_kwargs)

    # Compile the update step once so every training iteration reuses it.
    opt_update = jit(opt_update)
    return JaxOptimizer(opt_init, opt_update, get_params)
def _JaxRmsProp(machine, learning_rate=0.001, beta=0.9, epscut=1.0e-7):
    """Wrap a jax RMSProp optimizer (step size, decay, epsilon) for `machine`."""
    rmsprop_optimizer = jaxopt.rmsprop(learning_rate, beta, epscut)
    return Wrap(machine, rmsprop_optimizer)
class AliasTest(chex.TestCase):
    """End-to-end checks for the optax optimizer aliases."""

    def setUp(self):
        super().setUp()  # modernized from super(AliasTest, self).setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, .9, 0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.,),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Runs both optimizers for STEPS updates and compares final params."""
        # Reference: experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
        jax_params = get_params(state)

        # Under test: optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(1e-2, 0.0)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        # BUG FIX: this case constructed alias.adamw(1e-1), duplicating the
        # 'adamw' case above and leaving lamb entirely untested.
        ('lamb', alias.lamb(1e-1)),
        ('rmsprop', alias.rmsprop(1e-1)),
        ('fromage', transform.scale_by_fromage(-1e-2)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_parabel(self, opt):
        """Each optimizer should minimize a simple quadratic."""
        initial_params = jnp.array([-1.0, 10.0, 1.0])
        final_params = jnp.array([1.0, -1.0, 1.0])

        @jax.grad
        def get_updates(params):
            return jnp.sum((params - final_params)**2)

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(1000):
            params, state = step(params, state)
        chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(2e-3, 0.2)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        # BUG FIX: same 'lamb'/adamw copy-paste slip as in test_parabel.
        ('lamb', alias.lamb(1e-1)),
        ('rmsprop', alias.rmsprop(5e-3)),
        ('fromage', transform.scale_by_fromage(-5e-3)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_rosenbrock(self, opt):
        """Each optimizer should reach the Rosenbrock minimum at (a, a**2)."""
        a = 1.0
        b = 100.0
        initial_params = jnp.array([0.0, 0.0])
        final_params = jnp.array([a, a**2])

        @jax.grad
        def get_updates(params):
            return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)
        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
def __init__(self, learning_rate):
    # Base-class setup first; it presumably records the rate as `self.lr`,
    # which is read below — confirm against the parent class.
    super().__init__(learning_rate)
    # rmsprop yields the standard (init, update, get_params) triple.
    init_fn, update_fn, params_fn = rmsprop(step_size=self.lr)
    self.opt_init = init_fn
    self.opt_update = update_fn
    self.get_params = params_fn
def main():
    """Fit a 1-D Gaussian-process regression model by maximizing the marginal
    likelihood with RMSProp, then plot the posterior predictions and the loss
    curve to figures/jaxgp/examples/1d_example.png."""
    # 50 training points plus a test grid; ytest is unused below.
    X, y, Xtest, ytest = get_data(50)

    # PRIOR FUNCTIONS (mean, covariance)
    mu_f = zero_mean
    cov_f = functools.partial(gram, rbf_kernel)
    gp_priors = (mu_f, cov_f)

    # Kernel, Likelihood parameters
    params = {
        "gamma": 2.0,
        # 'length_scale': 1.0,
        # 'var_f': 1.0,
        "likelihood_noise": 1.0,
    }
    # saturate parameters with likelihoods
    # (presumably maps to an unconstrained space — confirm saturate()'s contract)
    params = saturate(params)

    # LOSS FUNCTION (jitted marginal likelihood with the priors baked in)
    mll_loss = jax.jit(functools.partial(marginal_likelihood, gp_priors))

    # GRADIENT LOSS FUNCTION
    dloss = jax.jit(jax.grad(mll_loss))

    # STEP FUNCTION
    @jax.jit
    def step(params, X, y, opt_state):
        # calculate loss
        loss = mll_loss(params, X, y)
        # calculate gradient of loss
        grads = dloss(params, X, y)
        # update optimizer state
        # NOTE(review): the iteration index is hard-coded to 0; RMSProp's
        # update ignores it here, but confirm if a schedule is ever used.
        opt_state = opt_update(0, grads, opt_state)
        # update params
        params = get_params(opt_state)
        return params, opt_state, loss

    # TRAINING PARAMETERS
    n_epochs = 500
    learning_rate = 0.01
    losses = list()

    # initialize optimizer
    opt_init, opt_update, get_params = optimizers.rmsprop(
        step_size=learning_rate)

    # initialize parameters
    opt_state = opt_init(params)

    # get initial parameters
    params = get_params(opt_state)

    postfix = {}

    with tqdm.trange(n_epochs) as bar:
        for i in bar:
            # 1 step - optimize function
            params, opt_state, value = step(params, X, y, opt_state)

            # update params shown in the progress bar (softplus-transformed)
            postfix = {}
            for ikey in params.keys():
                postfix[ikey] = f"{jax.nn.softplus(params[ikey]):.2f}"

            # save loss values
            losses.append(value.mean())

            # update progress bar
            postfix["Loss"] = f"{onp.array(losses[-1]):.2f}"
            bar.set_postfix(postfix)

    # saturate params
    params = saturate(params)

    # Posterior Predictions
    mu_y, var_y = posterior(params, gp_priors, X, y, Xtest, True, False)

    # Uncertainty: 95% confidence band from the predictive variance
    uncertainty = 1.96 * jnp.sqrt(var_y.squeeze())

    # Left panel: data, predictive mean and band; right panel: loss curve.
    fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
    ax[0].scatter(X, y, c="red", label="Training Data")
    ax[0].plot(
        Xtest.squeeze(),
        mu_y.squeeze(),
        label=r"Predictive Mean",
        color="black",
        linewidth=3,
    )
    ax[0].fill_between(
        Xtest.squeeze(),
        mu_y.squeeze() + uncertainty,
        mu_y.squeeze() - uncertainty,
        alpha=0.3,
        color="darkorange",
        label=f"Predictive Std (95% Confidence)",
    )
    ax[0].legend(fontsize=12)
    ax[1].plot(losses, label="losses")
    plt.tight_layout()
    fig.savefig("figures/jaxgp/examples/1d_example.png")
    plt.show()
def main():
    """Train a DQN agent on SpaceInvaders-v0 with a jax Q-network.

    Phases: (1) fill replay memory with random play, (2) set up the eval/next
    Q-networks and RMSProp, (3) play NUM_GAMES with epsilon-greedy actions,
    learning from sampled mini-batches when LEARN is set.
    """
    env = gym.make('SpaceInvaders-v0')
    memory = deque(maxlen=MEM_SIZE)

    # fill memory with random interactions with the environment
    while len(memory) < MEM_SIZE:
        observation = env.reset()
        # Rolling stack of the last STACK_SIZE preprocessed (185x95) frames.
        frames = deque([np.zeros((185, 95)) for _ in range(STACK_SIZE)],
                       maxlen=STACK_SIZE)
        frames.append(preprocess(observation))
        state = stack_frames(frames)
        done = False
        while not done:
            # 0 no action, 1 fire, 2 move right, 3 move left, 4 move right fire, 5 move left fire
            action = env.action_space.sample()
            observation_, reward, done, info = env.step(action)
            frames.append(preprocess(observation_))
            state_ = stack_frames(frames)
            memory = store_transition(memory, state, action, reward, state_)
            state = state_
    print('done initializing memory')

    init_Q, pred_Q = DeepQNetwork()  # two separate Q-Table approximations (eval and next)
    # initialize parameters, not committing to a batch size (NHWC)
    # we choose 3 channels as we want to pass stacks of 4 consecutive frames
    in_shape = (-1, 185, 95, STACK_SIZE)
    if LOAD:
        path = os.path.join(WEIGHTS_PATH, "params_Q_eval.npy")
        params_Q_eval = load_params(path)
    else:
        _, params_Q_eval = init_Q(in_shape)
    params_Q_next = params_Q_eval.copy()

    # Initialize RMSProp optimizer
    # NOTE(review): a 2-tuple unpack with module-level optimizers.get_params
    # matches the old jax optimizers API; newer jax returns a 3-tuple
    # (init, update, get_params) — confirm against the pinned jax version.
    opt_init, opt_update = optimizers.rmsprop(ALPHA)
    opt_state = opt_init(params_Q_eval)
    opt_step = 0

    # Define a simple mean-squared-error loss
    def loss(params, batch):
        inputs, targets = batch
        predictions = pred_Q(params, inputs)
        return np.mean((predictions - targets) ** 2)

    # Define a compiled update step
    @jit
    def step(j, opt_state, batch):
        params = optimizers.get_params(opt_state)
        g = grad(loss)(params, batch)
        return opt_update(j, g, opt_state)

    # One learning step: sample a batch, build Q-targets from the target
    # network, and apply an optimizer update to the eval network.
    def learn(opt_step, opt_state, params_Q_eval, params_Q_next):
        mini_batch = sample(memory, BATCH_SIZE)
        # Refresh the target network every TAU optimizer steps.
        if opt_step % TAU == 0:
            params_Q_next = params_Q_eval.copy()
        input_states = np.stack([transition[0] for transition in mini_batch])
        next_states = np.stack([transition[3] for transition in mini_batch])
        predicted_Q = pred_Q(params_Q_eval, input_states)
        predicted_Q_next = pred_Q(params_Q_next, next_states)
        max_action = np.argmax(predicted_Q_next, axis=1)
        rewards = np.array([transition[2] for transition in mini_batch])
        Q_target = onp.array(predicted_Q)
        # NOTE(review): fancy indexing with the max_action vector assigns the
        # same column set to EVERY row, not row i -> column max_action[i];
        # the per-row form would be Q_target[np.arange(len(...)), max_action].
        # Left unchanged pending confirmation of intent.
        Q_target[:, max_action] = rewards + GAMMA * np.max(predicted_Q_next, axis=1)
        opt_state = step(opt_step, opt_state, (input_states, Q_target))
        params_Q_eval = optimizers.get_params(opt_state)
        return opt_state, params_Q_eval, params_Q_next

    scores = []
    eps_history = []
    eps = EPS_START if LEARN else 0
    for i in range(NUM_GAMES):
        print('starting game ', i + 1, 'epsilon: %.4f' % eps)
        eps_history.append(eps)
        done = False
        observation = env.reset()
        frames = deque([np.zeros((185, 95)) for _ in range(STACK_SIZE)],
                       maxlen=STACK_SIZE)
        frames.append(preprocess(observation))
        state = stack_frames(frames)
        score = 0
        while not done:
            # Epsilon-greedy action from the eval network.
            action = choose_action(env,
                                   state.reshape((1, 185, 95, STACK_SIZE)),
                                   pred_Q, params_Q_eval, eps)
            observation_, reward, done, info = env.step(action)
            score += reward
            if RENDER:
                env.render()
            if LEARN:
                # NOTE(review): appends the PRE-step observation; the random
                # warm-up loop above appends observation_ (post-step). This
                # looks like a bug — confirm before changing.
                frames.append(preprocess(observation))
                state_ = stack_frames(frames)
                memory = store_transition(memory, state, action, reward, state_)
                state = state_
                opt_state, params_Q_eval, params_Q_next = learn(
                    opt_step, opt_state, params_Q_eval, params_Q_next)
                opt_step += 1
                # Linear epsilon decay after a 500-step warm-up.
                if opt_step > 500:
                    if eps - 1e-4 > EPS_END:
                        eps -= 1e-4
                    else:
                        eps = EPS_END
        if LEARN:
            # Checkpoint the eval-network weights after each game.
            out_path = os.path.join(WEIGHTS_PATH, 'params_Q_eval_' + str(i))
            onp.save(out_path, params_Q_eval)
        scores.append(score)
        print('score: ', score)
def RmsProp(step_size, gamma=0.9, eps=1e-8):
    """Build an RMSProp optimizer and adapt it via OptimizerFromExperimental."""
    experimental_optimizer = experimental.rmsprop(step_size, gamma, eps)
    return OptimizerFromExperimental(experimental_optimizer)