Example #1
class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999,
                            1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(
            LR, decay=.9, eps=0.1), optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum', alias.rmsprop(LR, decay=.9, eps=0.1,
                                           momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(
            LR,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(LR_SCHED, 0., 0.),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
Example #2
def registration_individuals(x,y,aar_names,max_iter=10000,aars=None):

  if aars is None:
    aars = range(0, len(aar_names))
    
  aar_indices = [y == aar for aar in aars]
  uti_indices = [np.triu_indices(sum(y == aar),k=1) for aar in aars]
    
  def cost_function(x,y):
    def foo(x,uti): 
      dr = (x[:,uti[0]]-x[:,uti[1]])
      return np.sqrt(np.sum(dr*dr,axis=0)).sum()
    return sum([foo(x[:,aar_indices[aar]],uti_indices[aar]) for aar in range(0,len(aars))])

  def transform(param,x):
    thetas = param[0:len(x)]
    delta_ps = np.reshape(param[len(x):],(2,len(x)))
    return np.hstack([np.dot(np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]]),x_s)+np.expand_dims(delta_p,1) for theta,delta_p,x_s in zip(thetas,delta_ps.T,x)])

  def func(param,x,y):
    value = cost_function(transform(param,x),y)
    return value

  loss = lambda param: func(param,x,y)
    
  opt_init, opt_update, get_params = optimizers.adagrad(step_size=1,momentum=0.9)

  @jit
  def step(i, opt_state):
    params = get_params(opt_state)
    g = grad(loss)(params)
    return opt_update(i, g, opt_state)

  net_params = numpy.hstack((numpy.random.uniform(-numpy.pi,numpy.pi,len(x)),numpy.zeros(2*len(x))))
  previous_value = loss(net_params)
  logging.info('Iteration 0: loss = %f'%(previous_value))
  opt_state = opt_init(net_params)
  for i in range(max_iter):
    opt_state = step(i, opt_state)
    if i > 0 and i % 10 == 0:
      net_params = get_params(opt_state)
      current_value = loss(net_params)
      logging.info('Iteration %d: loss = %f'%(i+1,current_value))

      if numpy.isclose(previous_value/current_value,1):
          logging.info('Converged after %d iterations'%(i+1))
          net_params = get_params(opt_state)
          return transform(net_params,x)

      previous_value = current_value

  logging.warning('Not converged after %d iterations'%(i+1))
  net_params = get_params(opt_state)
  return transform(net_params,x)
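
A brief usage sketch of the registration routine above, on synthetic data. It assumes x is a list of 2×N landmark arrays (one per individual) and y gives an AAR label for every column of their concatenation; the individual count, labels, and aar_names below are illustrative only.

import numpy

# Two synthetic individuals with five 2-D landmarks each (illustrative data).
x = [numpy.random.uniform(size=(2, 5)) for _ in range(2)]
# One AAR label per landmark, in concatenation order.
y = numpy.array([0, 0, 1, 1, 2] * 2)

aligned = registration_individuals(x, y, aar_names=['aar_a', 'aar_b', 'aar_c'],
                                    max_iter=200)
print(aligned.shape)  # (2, total number of landmarks)
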
Example #3
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
  if optimizer.lower() == 'adagrad':
    return optimizers.adagrad(sched)
  elif optimizer.lower() == 'adam':
    return optimizers.adam(sched, b1, b2)
  elif optimizer.lower() == 'rmsprop':
    return optimizers.rmsprop(sched)
  elif optimizer.lower() == 'momentum':
    return optimizers.momentum(sched, 0.9)
  elif optimizer.lower() == 'sgd':
    return optimizers.sgd(sched)
  else:
    raise ValueError('Invalid optimizer: {}'.format(optimizer))
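
A short usage sketch for this helper, assuming optimizers here is jax.experimental.optimizers and that sched may be a plain step size as well as a schedule; the loss and step count are illustrative only.

import jax
import jax.numpy as jnp

# Pick the optimizer triple by name; a constant step size serves as `sched`.
opt_init, opt_update, get_params = get_optimizer('adagrad', sched=1e-2)

loss = lambda p: jnp.sum((p - 1.0) ** 2)  # toy quadratic objective

opt_state = opt_init(jnp.zeros(3))
for i in range(100):
  grads = jax.grad(loss)(get_params(opt_state))
  opt_state = opt_update(i, grads, opt_state)
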
Example #4
def adagrad(df,
            x0,
            lr=1,
            steps=jnp.inf,
            tol=0,
            gtol=0,
            momentum=0.9,
            verbose=True):
    opt_init, opt_update, get_params = optimizers.adagrad(step_size=lr,
                                                          momentum=momentum)
    x = jnp.array(x0)
    opt_state = opt_init(x)
    i = 0
    y = dy = jnp.inf
    obj = []
    xs = []
    dys = []
    if verbose:
        logger.info("Starting Adagrad with:")
        logger.info(("    x0       =" + len(x) * " {:+.2e},").format(*x)[:-1])
        logger.info("    lr       = {:.2e}".format(lr))
        logger.info("    momentum = {:.2e}".format(momentum))
        logger.info("    steps    = {:3d}".format(steps))
        logger.info("    tol      = {:.2e}".format(tol))
        logger.info("    gtol     = {:.2e}".format(gtol))
    while (i < steps) and (jnp.abs(y) > tol) and jnp.any(jnp.abs(dy) > gtol):
        x = get_params(opt_state)
        y, dy = df(x)
        opt_state = opt_update(i, dy, opt_state)
        xs.append(x)
        obj.append(y)
        dys.append(dy)
        i = i + 1
        if verbose:
            logger.info("iteration {:3d}".format(i))
            logger.info("    f(x)  = {:.2e}".format(y))
            logger.info(("    x     =" + len(x) * " {:+.2e},").format(*x)[:-1])
            logger.info(
                ("    df/dx =" + len(dy) * " {:+.2e},").format(*dy)[:-1])
    obj = jnp.array(obj)
    xs = jnp.array(xs)
    dys = jnp.array(dys)
    result = {"f": obj, "dfdx": dys, "x": xs}
    return result
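
A minimal usage sketch for the adagrad wrapper above, assuming df returns the objective value together with its gradient, e.g. via jax.value_and_grad; the quadratic objective and starting point are illustrative only.

import jax
import jax.numpy as jnp

# df(x) -> (value, gradient), as expected by the loop above.
df = jax.value_and_grad(lambda x: jnp.sum(x ** 2))

result = adagrad(df, x0=jnp.array([1.0, -2.0]), lr=0.1, steps=100, verbose=False)
print(result["x"][-1], result["f"][-1])  # final iterate and objective value
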
Example #5
    def get_optimizer(self, optim=None, stage='learn', step_size=None):

        if optim is None:
            if stage == 'learn':
                optim = self.optim_learn
            else:
                optim = self.optim_proj
        if step_size is None:
            step_size = self.step_size

        if optim == 1:
            if self.verb > 2:
                print("With momentum optimizer")
            opt_init, opt_update, get_params = momentum(step_size=step_size,
                                                        mass=0.95)
        elif optim == 2:
            if self.verb > 2:
                print("With rmsprop optimizer")
            opt_init, opt_update, get_params = rmsprop(step_size,
                                                       gamma=0.9,
                                                       eps=1e-8)
        elif optim == 3:
            if self.verb > 2:
                print("With adagrad optimizer")
            opt_init, opt_update, get_params = adagrad(step_size, momentum=0.9)
        elif optim == 4:
            if self.verb > 2:
                print("With Nesterov optimizer")
            opt_init, opt_update, get_params = nesterov(step_size, 0.9)
        elif optim == 5:
            if self.verb > 2:
                print("With SGD optimizer")
            opt_init, opt_update, get_params = sgd(step_size)
        else:
            if self.verb > 2:
                print("With adam optimizer")
            opt_init, opt_update, get_params = adam(step_size)

        return opt_init, opt_update, get_params
Example #6
def get_optimizer(
    learning_rate: float = 1e-4, optimizer="sgd", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (Allowed types: "adam",
            "adamax", "adagrad", "rmsprop", "sgd"). Defaults to "sgd".
        optimizer_kwargs (dict, optional): Additional keyword arguments
            that are passed to the optimizer. Defaults to None.

    Returns:
        JaxOptimizer
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    optimizer = optimizer.lower()
    if optimizer == "adam":
        opt_init, opt_update, get_params = optimizers.adam(learning_rate, **optimizer_kwargs)
    elif optimizer == "adagrad":
        opt_init, opt_update, get_params = optimizers.adagrad(learning_rate, **optimizer_kwargs)
    elif optimizer == "adamax":
        opt_init, opt_update, get_params = optimizers.adamax(learning_rate, **optimizer_kwargs)
    elif optimizer == "rmsprop":
        opt_init, opt_update, get_params = optimizers.rmsprop(learning_rate, **optimizer_kwargs)
    else:
        opt_init, opt_update, get_params = optimizers.sgd(learning_rate, **optimizer_kwargs)

    opt_update = jit(opt_update)

    return JaxOptimizer(opt_init, opt_update, get_params)
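
A short usage sketch, assuming the JaxOptimizer dataclass simply exposes the (opt_init, opt_update, get_params) triple it is constructed with as attributes of those names; the attribute names and the toy loss are assumptions.

import jax
import jax.numpy as jnp

optimizer = get_optimizer(learning_rate=1e-2, optimizer="adagrad")

loss = lambda p: jnp.sum(p ** 2)  # toy objective, illustrative only

state = optimizer.opt_init(jnp.ones(4))
for i in range(50):
    grads = jax.grad(loss)(optimizer.get_params(state))
    state = optimizer.opt_update(i, grads, state)
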
Example #7
elif args.linear_lr_decay:

    def lr_schedule(epoch):
        return args.lr * (args.epochs - epoch) / args.epochs

    lr = lr_schedule
else:
    lr = args.lr

if args.optimizer == 'sgd':
    opt_init, opt_apply, get_params = myopt.sgd(lr)
elif args.optimizer == 'momentum':
    opt_init, opt_apply, get_params = myopt.momentum(
        lr, args.momentum, weight_decay=args.weight_decay)
elif args.optimizer == 'adagrad':
    opt_init, opt_apply, get_params = optimizers.adagrad(lr, args.momentum)
elif args.optimizer == 'adam':
    opt_init, opt_apply, get_params = optimizers.adam(lr)

state = opt_init(params)

if args.loss == 'logistic':
    loss = lambda fx, y: np.mean(-np.sum(logsoftmax(fx) * y, axis=1))
elif args.loss == 'squared':
    loss = lambda fx, y: np.mean(np.sum((fx - y)**2, axis=1))
value_and_grad_loss = jit(
    value_and_grad(lambda params, x, y: loss(f(params, x), y)))
loss_fn = jit(lambda params, x, y: loss(f(params, x), y))
accuracy_sum = jit(
    lambda fx, y: np.sum(np.argmax(fx, axis=1) == np.argmax(y, axis=1)))
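
What typically follows this setup is a plain training loop over the optimizer triple; a hedged sketch, where train_batches is an assumed iterable of (inputs, targets) batches not shown in the snippet.

# Hypothetical training loop; `train_batches` yields (x_batch, y_batch) pairs.
for step_idx, (x_batch, y_batch) in enumerate(train_batches):
    value, grads = value_and_grad_loss(get_params(state), x_batch, y_batch)
    state = opt_apply(step_idx, grads, state)
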
Example #8
def _JaxAdaGrad(machine, learning_rate=0.001, epscut=1.0e-7):
    return Wrap(machine, jaxopt.adagrad(learning_rate))
Example #9
class AliasTest(chex.TestCase):
    def setUp(self):
        super(AliasTest, self).setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999,
                            1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, .9, 0.1), optimizers.rmsprop(
            LR, .9, 0.1), 1e-5),
        ('adagrad', alias.adagrad(
            LR,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(1e-2, 0.0)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        ('lamb', alias.adamw(1e-1)),
        ('rmsprop', alias.rmsprop(1e-1)),
        ('fromage', transform.scale_by_fromage(-1e-2)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_parabel(self, opt):
        initial_params = jnp.array([-1.0, 10.0, 1.0])
        final_params = jnp.array([1.0, -1.0, 1.0])

        @jax.grad
        def get_updates(params):
            return jnp.sum((params - final_params)**2)

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(1000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(2e-3, 0.2)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        ('lamb', alias.adamw(1e-1)),
        ('rmsprop', alias.rmsprop(5e-3)),
        ('fromage', transform.scale_by_fromage(-5e-3)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_rosenbrock(self, opt):
        a = 1.0
        b = 100.0
        initial_params = jnp.array([0.0, 0.0])
        final_params = jnp.array([a, a**2])

        @jax.grad
        def get_updates(params):
            return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
Example #10
    def __init__(self, learning_rate):
        super().__init__(learning_rate)
        self.opt_init, self.opt_update, self.get_params = adagrad(
            step_size=self.lr)
Example #11
def Adagrad(step_size=0.001, momentum=0.9):
    return OptimizerFromExperimental(experimental.adagrad(step_size, momentum))