class ExperimentalOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum',
       alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.), optimizers.adagrad(LR, 0.), 1e-5),
      # The scheduled variants need distinct names; duplicate names make
      # parameterized.named_parameters raise an error.
      ('sgd_scheduled', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam_scheduled', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop_scheduled', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum_scheduled',
       alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad_scheduled', alias.adagrad(LR_SCHED, 0., 0.),
       optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                    rtol):

    # experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    state = opt_init(jax_params)
    for i in range(STEPS):
      state = opt_update(i, self.per_step_updates, state)
      jax_params = get_params(state)

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, state = step(self.per_step_updates, state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
def registration_individuals(x, y, aar_names, max_iter=10000, aars=None):
    if aars is None:
        aars = range(0, len(aar_names))

    aar_indices = [y == aar for aar in aars]
    uti_indices = [np.triu_indices(sum(y == aar), k=1) for aar in aars]

    def cost_function(x, y):
        # Sum of pairwise distances between points within each AAR.
        def foo(x, uti):
            dr = x[:, uti[0]] - x[:, uti[1]]
            return np.sqrt(np.sum(dr * dr, axis=0)).sum()
        return sum([foo(x[:, aar_indices[aar]], uti_indices[aar])
                    for aar in range(0, len(aars))])

    def transform(param, x):
        # Rigid transform per individual: rotation by theta, then translation.
        thetas = param[0:len(x)]
        delta_ps = np.reshape(param[len(x):], (2, len(x)))
        return np.hstack(
            [np.dot(np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta), np.cos(theta)]]), x_s)
             + np.expand_dims(delta_p, 1)
             for theta, delta_p, x_s in zip(thetas, delta_ps.T, x)])

    def func(param, x, y):
        return cost_function(transform(param, x), y)

    loss = lambda param: func(param, x, y)

    opt_init, opt_update, get_params = optimizers.adagrad(step_size=1,
                                                          momentum=0.9)

    @jit
    def step(i, opt_state):
        params = get_params(opt_state)
        g = grad(loss)(params)
        return opt_update(i, g, opt_state)

    net_params = numpy.hstack(
        (numpy.random.uniform(-numpy.pi, numpy.pi, len(x)),
         numpy.zeros(2 * len(x))))
    previous_value = loss(net_params)
    logging.info('Iteration 0: loss = %f' % (previous_value))
    opt_state = opt_init(net_params)
    for i in range(max_iter):
        opt_state = step(i, opt_state)
        if i > 0 and i % 10 == 0:
            net_params = get_params(opt_state)
            current_value = loss(net_params)
            logging.info('Iteration %d: loss = %f' % (i + 1, current_value))
            if numpy.isclose(previous_value / current_value, 1):
                logging.info('Converged after %d iterations' % (i + 1))
                net_params = get_params(opt_state)
                return transform(net_params, x)
            previous_value = current_value
    logging.warning('Not converged after %d iterations' % (i + 1))
    net_params = get_params(opt_state)
    return transform(net_params, x)
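# A hypothetical, self-contained call of registration_individuals. The shapes
# are inferred from the code above (`x` is a list of 2 x n_s coordinate
# arrays, one per individual; `y` labels each column of the concatenated
# coordinates with an AAR index); the random data is purely illustrative.
import numpy

def _demo_registration():
    rng = numpy.random.default_rng(0)
    x = [rng.normal(size=(2, 20)) for _ in range(3)]  # 3 individuals
    y = numpy.tile(numpy.repeat([0, 1], 10), 3)       # 10 points per AAR each
    return registration_individuals(x, y, aar_names=['AAR0', 'AAR1'],
                                    max_iter=100)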
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
    if optimizer.lower() == 'adagrad':
        return optimizers.adagrad(sched)
    elif optimizer.lower() == 'adam':
        return optimizers.adam(sched, b1, b2)
    elif optimizer.lower() == 'rmsprop':
        return optimizers.rmsprop(sched)
    elif optimizer.lower() == 'momentum':
        return optimizers.momentum(sched, 0.9)
    elif optimizer.lower() == 'sgd':
        return optimizers.sgd(sched)
    else:
        raise Exception('Invalid optimizer: {}'.format(optimizer))
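# A minimal, self-contained usage sketch for the factory above: it returns
# the standard jax.experimental.optimizers (opt_init, opt_update, get_params)
# triple, driven here on a toy quadratic objective. The objective and the
# concrete values are illustrative assumptions, not part of the original code.
import jax.numpy as jnp
from jax import grad

def _demo_get_optimizer():
    opt_init, opt_update, get_params = get_optimizer('adagrad', 1e-1)
    loss_fn = lambda p: jnp.sum(p ** 2)  # toy quadratic objective
    opt_state = opt_init(jnp.array([1.0, -2.0]))
    for i in range(100):
        g = grad(loss_fn)(get_params(opt_state))
        opt_state = opt_update(i, g, opt_state)  # step index, grads, state
    return get_params(opt_state)  # moves toward the minimum at the origin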
def adagrad(df, x0, lr=1, steps=jnp.inf, tol=0, gtol=0, momentum=0.9,
            verbose=True):
    opt_init, opt_update, get_params = optimizers.adagrad(step_size=lr,
                                                          momentum=momentum)
    x = jnp.array(x0)
    opt_state = opt_init(x)
    i = 0
    y = dy = jnp.inf
    obj = []
    xs = []
    dys = []
    if verbose:
        logger.info("Starting Adagrad with:")
        logger.info((" x0 =" + len(x) * " {:+.2e},").format(*x)[:-1])
        logger.info(" lr = {:.2e}".format(lr))
        logger.info(" momentum = {:.2e}".format(momentum))
        # `steps` may be jnp.inf (a float), so avoid the integer format spec.
        logger.info(" steps = {}".format(steps))
        logger.info(" tol = {:.2e}".format(tol))
        logger.info(" gtol = {:.2e}".format(gtol))
    while (i < steps) and (jnp.abs(y) > tol) and jnp.any(jnp.abs(dy) > gtol):
        x = get_params(opt_state)
        y, dy = df(x)
        opt_state = opt_update(i, dy, opt_state)
        xs.append(x)
        obj.append(y)
        dys.append(dy)
        i = i + 1
        if verbose:
            logger.info("iteration {:3d}".format(i))
            logger.info(" f(x) = {:.2e}".format(y))
            logger.info((" x =" + len(x) * " {:+.2e},").format(*x)[:-1])
            logger.info(
                (" df/dx =" + len(dy) * " {:+.2e},").format(*dy)[:-1])
    obj = jnp.array(obj)
    xs = jnp.array(xs)
    dys = jnp.array(dys)
    result = {"f": obj, "dfdx": dys, "x": xs}
    return result
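# An illustrative call of the driver above. The unpacking `y, dy = df(x)`
# implies that `df` returns the pair (value, gradient), so jax.value_and_grad
# fits naturally; the shifted quadratic objective is a made-up example.
import jax.numpy as jnp
from jax import value_and_grad

def _demo_adagrad():
    df = value_and_grad(lambda x: jnp.sum((x - 1.0) ** 2))
    result = adagrad(df, x0=[0.0, 0.0], lr=0.5, steps=200, verbose=False)
    return result["x"][-1]  # final iterate; approaches [1., 1.]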
def get_optimizer(self, optim=None, stage='learn', step_size=None):
    if optim is None:
        if stage == 'learn':
            optim = self.optim_learn
        else:
            optim = self.optim_proj
    if step_size is None:
        step_size = self.step_size
    if optim == 1:
        if self.verb > 2:
            print("With momentum optimizer")
        opt_init, opt_update, get_params = momentum(step_size=step_size,
                                                    mass=0.95)
    elif optim == 2:
        if self.verb > 2:
            print("With rmsprop optimizer")
        opt_init, opt_update, get_params = rmsprop(step_size, gamma=0.9,
                                                   eps=1e-8)
    elif optim == 3:
        if self.verb > 2:
            print("With adagrad optimizer")
        opt_init, opt_update, get_params = adagrad(step_size, momentum=0.9)
    elif optim == 4:
        if self.verb > 2:
            print("With Nesterov optimizer")
        opt_init, opt_update, get_params = nesterov(step_size, 0.9)
    elif optim == 5:
        if self.verb > 2:
            print("With SGD optimizer")
        opt_init, opt_update, get_params = sgd(step_size)
    else:
        if self.verb > 2:
            print("With adam optimizer")
        opt_init, opt_update, get_params = adam(step_size)
    return opt_init, opt_update, get_params
def get_optimizer(
    learning_rate: float = 1e-4, optimizer="sgd", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer.

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (allowed types: "adam",
            "adamax", "adagrad", "rmsprop", "sgd"). Defaults to "sgd".
        optimizer_kwargs (dict, optional): Additional keyword arguments that
            are passed to the optimizer. Defaults to None.

    Returns:
        JaxOptimizer
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    optimizer = optimizer.lower()
    if optimizer == "adam":
        opt_init, opt_update, get_params = optimizers.adam(
            learning_rate, **optimizer_kwargs)
    elif optimizer == "adagrad":
        opt_init, opt_update, get_params = optimizers.adagrad(
            learning_rate, **optimizer_kwargs)
    elif optimizer == "adamax":
        opt_init, opt_update, get_params = optimizers.adamax(
            learning_rate, **optimizer_kwargs)
    elif optimizer == "rmsprop":
        opt_init, opt_update, get_params = optimizers.rmsprop(
            learning_rate, **optimizer_kwargs)
    else:
        # Any unrecognized value falls back to plain SGD.
        opt_init, opt_update, get_params = optimizers.sgd(
            learning_rate, **optimizer_kwargs)
    opt_update = jit(opt_update)
    return JaxOptimizer(opt_init, opt_update, get_params)
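# Example invocations of the factory above (hypothetical hyperparameter
# values). The `optimizer_kwargs` keys mirror the jax.experimental.optimizers
# signatures, e.g. adam(step_size, b1=0.9, b2=0.999, eps=1e-8) and
# adagrad(step_size, momentum=0.9).
adam_opt = get_optimizer(1e-3, "adam", {"b1": 0.9, "b2": 0.999})
adagrad_opt = get_optimizer(1e-2, "adagrad", {"momentum": 0.9})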
elif args.linear_lr_decay:

    def lr_schedule(epoch):
        return args.lr * (args.epochs - epoch) / args.epochs

    lr = lr_schedule
else:
    lr = args.lr

if args.optimizer == 'sgd':
    opt_init, opt_apply, get_params = myopt.sgd(lr)
elif args.optimizer == 'momentum':
    opt_init, opt_apply, get_params = myopt.momentum(
        lr, args.momentum, weight_decay=args.weight_decay)
elif args.optimizer == 'adagrad':
    opt_init, opt_apply, get_params = optimizers.adagrad(lr, args.momentum)
elif args.optimizer == 'adam':
    opt_init, opt_apply, get_params = optimizers.adam(lr)

state = opt_init(params)

if args.loss == 'logistic':
    loss = lambda fx, y: np.mean(-np.sum(logsoftmax(fx) * y, axis=1))
elif args.loss == 'squared':
    loss = lambda fx, y: np.mean(np.sum((fx - y)**2, axis=1))

value_and_grad_loss = jit(
    value_and_grad(lambda params, x, y: loss(f(params, x), y)))
loss_fn = jit(lambda params, x, y: loss(f(params, x), y))
accuracy_sum = jit(
    lambda fx, y: np.sum(np.argmax(fx, axis=1) == np.argmax(y, axis=1)))
def _JaxAdaGrad(machine, learning_rate=0.001, epscut=1.0e-7):
    # Note: `epscut` is not forwarded; jax's adagrad takes no epsilon argument.
    return Wrap(machine, jaxopt.adagrad(learning_rate))
class AliasTest(chex.TestCase):

  def setUp(self):
    super(AliasTest, self).setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, .9, 0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.), optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                    rtol):

    # experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    state = opt_init(jax_params)
    for i in range(STEPS):
      state = opt_update(i, self.per_step_updates, state)
      jax_params = get_params(state)

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, state = step(self.per_step_updates, state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

  @parameterized.named_parameters(
      ('sgd', alias.sgd(1e-2, 0.0)),
      ('adam', alias.adam(1e-1)),
      ('adamw', alias.adamw(1e-1)),
      ('lamb', alias.lamb(1e-1)),  # was alias.adamw, duplicating 'adamw'
      ('rmsprop', alias.rmsprop(1e-1)),
      ('fromage', transform.scale_by_fromage(-1e-2)),
      ('adabelief', alias.adabelief(1e-1)),
  )
  def test_parabola(self, opt):
    initial_params = jnp.array([-1.0, 10.0, 1.0])
    final_params = jnp.array([1.0, -1.0, 1.0])

    @jax.grad
    def get_updates(params):
      return jnp.sum((params - final_params)**2)

    @jax.jit
    def step(params, state):
      updates, state = opt.update(get_updates(params), state, params)
      params = update.apply_updates(params, updates)
      return params, state

    params = initial_params
    state = opt.init(params)
    for _ in range(1000):
      params, state = step(params, state)

    chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

  @parameterized.named_parameters(
      ('sgd', alias.sgd(2e-3, 0.2)),
      ('adam', alias.adam(1e-1)),
      ('adamw', alias.adamw(1e-1)),
      ('lamb', alias.lamb(1e-1)),  # was alias.adamw, duplicating 'adamw'
      ('rmsprop', alias.rmsprop(5e-3)),
      ('fromage', transform.scale_by_fromage(-5e-3)),
      ('adabelief', alias.adabelief(1e-1)),
  )
  def test_rosenbrock(self, opt):
    a = 1.0
    b = 100.0
    initial_params = jnp.array([0.0, 0.0])
    final_params = jnp.array([a, a**2])

    @jax.grad
    def get_updates(params):
      return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

    @jax.jit
    def step(params, state):
      updates, state = opt.update(get_updates(params), state, params)
      params = update.apply_updates(params, updates)
      return params, state

    params = initial_params
    state = opt.init(params)
    for _ in range(10000):
      params, state = step(params, state)

    chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
def __init__(self, learning_rate):
    super().__init__(learning_rate)
    self.opt_init, self.opt_update, self.get_params = adagrad(
        step_size=self.lr)
def Adagrad(step_size=0.001, momentum=0.9):
    return OptimizerFromExperimental(experimental.adagrad(step_size, momentum))