Example #1
0
    def optimize_params(self, p0, num_iters, step_size, tolerance, verbose):
        """Minimise ``self.negative_log_evidence`` with JAX's Adam optimizer.

        Parameters
        ----------
        p0
            Initial parameters; passed to ``opt_init`` and to
            ``self.print_progress_header``.
        num_iters : int
            Maximum number of optimisation steps.
        step_size : float or callable
            Adam learning rate (a constant or a step-indexed schedule).
        tolerance : int
            Early-stopping window length. Once more than ``tolerance`` steps
            are buffered, stop if the cost increased on every step in the
            window, or decreased by less than 1e-5 on every step. A falsy
            value (0 or None) disables early stopping.
        verbose : int
            Print progress every ``verbose`` steps; 0 disables all printing.

        Returns
        -------
        params
            The oldest buffered parameters if the cost has been increasing,
            otherwise the most recent parameters. ``p0`` is returned unchanged
            when ``num_iters`` is not positive.
        """
        if num_iters <= 0:
            # Nothing to optimise; the buffers below would stay empty and the
            # final ``params_list[-1]`` lookup would raise IndexError.
            return p0

        opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
        opt_state = opt_init(p0)

        @jit
        def step(_i, _opt_state):
            p = get_params(_opt_state)
            g = grad(self.negative_log_evidence)(p)
            return opt_update(_i, g, _opt_state)

        # Sliding window of the most recent parameters and their costs.
        cost_list = []
        params_list = []

        if verbose:
            self.print_progress_header(p0)

        for i in range(num_iters):

            opt_state = step(i, opt_state)
            params_list.append(get_params(opt_state))
            cost_list.append(self.negative_log_evidence(params_list[-1]))

            if verbose and i % verbose == 0:
                self.print_progress(i, params_list[-1], cost_list[-1])

            # tolerance = 0/None: no early stop. Without this guard the window
            # would hold a single cost, ``jnp.all`` of the empty difference
            # array would be True, and the loop would stop after one step.
            if tolerance and len(params_list) > tolerance:
                costs = jnp.array(cost_list)

                if jnp.all(costs[1:] - costs[:-1] > 0):
                    # Cost rose on every step of the window: return the
                    # oldest (lowest-cost) parameters in it.
                    params = params_list[0]
                    if verbose:
                        print('Stop: cost has been monotonically increasing for {} steps.'.format(tolerance))
                    break
                elif jnp.all(costs[:-1] - costs[1:] < 1e-5):
                    params = params_list[-1]
                    if verbose:
                        print('Stop: cost has been stop changing for {} steps.'.format(tolerance))
                    break
                else:
                    # Slide the window forward by one step.
                    params_list.pop(0)
                    cost_list.pop(0)

        else:

            params = params_list[-1]
            if verbose:
                print('Stop: reached {0} steps, final cost={1:.5f}.'.format(num_iters, cost_list[-1]))

        return params
Example #2
0
    def __init__(self,
                 vocab_size,
                 output_size,
                 loss_fn,
                 rng,
                 temperature=1,
                 learning_rate=0.001,
                 conv_depth=300,
                 n_conv_layers=2,
                 n_dense_units=300,
                 n_dense_layers=0,
                 kernel_size=5,
                 across_batch=False,
                 add_pos_encoding=False,
                 mean_over_pos=False,
                 model_fn=build_model_stax):
        """Build jitted train/eval model functions, initial params and Adam.

        ``model_fn`` is called twice with identical architecture settings,
        first with ``mode="train"`` and then with ``mode="eval"``; both apply
        functions it returns are wrapped in ``jax.jit``. Initial parameters
        come from a subkey split off ``rng`` and an input shape of
        ``(-1, -1, vocab_size)``. ``loss_fn`` is bound to ``self.run_model``
        via its ``run_model_fn`` keyword and differentiated with ``jax.grad``.
        """
        self.output_size = output_size
        self.temperature = temperature

        # Keep the caller's key; a subkey is split off for initialisation.
        self.rng = rng

        architecture = dict(
            output_size=output_size,
            n_dense_units=n_dense_units,
            n_dense_layers=n_dense_layers,
            conv_depth=conv_depth,
            n_conv_layers=n_conv_layers,
            across_batch=across_batch,
            kernel_size=kernel_size,
            add_pos_encoding=add_pos_encoding,
            mean_over_pos=mean_over_pos,
        )

        self._model_init, train_apply = model_fn(**architecture, mode="train")
        self._model_train = jax.jit(train_apply)

        _, eval_apply = model_fn(**architecture, mode="eval")
        self._model_predict = jax.jit(eval_apply)

        self.rng, init_key = jrand.split(self.rng)
        _, self.params = self._model_init(init_key, (-1, -1, vocab_size))

        # Optimizer triple: state constructor, update rule, parameter getter.
        (self.make_state,
         self._opt_update_state,
         self._get_params) = adam(learning_rate)

        self.loss_fn = functools.partial(loss_fn, run_model_fn=self.run_model)
        self.loss_grad_fn = jax.grad(self.loss_fn)

        # Number of optimization steps taken so far.
        self._step_idx = 0
Example #3
0
    def __init__(self, optimizer_name, model, lr):
        """Set up optimizer state and a jitted update step for ``model``.

        Parameters
        ----------
        optimizer_name : str
            Must be contained in ``jax_optimizers``; only ``"adam"`` has a
            construction branch here.
        model
            Object exposing ``net_params`` (initial parameters) and
            ``loss_func(params, data)``.
        lr : float or schedule
            Step size passed to the optimizer.

        Raises
        ------
        ValueError
            If ``optimizer_name`` is unknown or has no construction branch.
        """
        if optimizer_name not in jax_optimizers:
            raise ValueError(
                "Optimizer {} is not implemented yet".format(optimizer_name))

        if optimizer_name == "adam":
            opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
        else:
            # Listed in ``jax_optimizers`` but not wired up here; fail with the
            # same error type rather than an UnboundLocalError below.
            raise ValueError(
                "Optimizer {} is not implemented yet".format(optimizer_name))

        self.opt_state = opt_init(model.net_params)
        self.get_params = get_params

        @jit
        def step(i, _opt_state, data):
            _params = get_params(_opt_state)
            # The enclosing ``jit`` already compiles the gradient computation.
            g = jax.grad(model.loss_func)(_params, data)
            return opt_update(i, g, _opt_state)

        self._step = step
        self.iter_cnt = 0
Example #4
0
        return Trace_ELBO().loss(random.PRNGKey(0), {}, model, guide, x)

    def renyi_loss_fn(x):
        return RenyiELBO(alpha=alpha,
                         num_particles=10).loss(random.PRNGKey(0), {}, model,
                                                guide, x)

    elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.0)
    renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.0)
    assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
    assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize(
    "optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
def test_beta_bernoulli(elbo, optimizer):
    """Beta-Bernoulli SVI test, parametrized over ELBO and optimizer.

    The model places a Beta(1, 1) prior on a Bernoulli success probability
    and observes 8 ones and 2 zeros; the guide is a Beta distribution with
    positive learnable ``alpha_q`` / ``beta_q`` both initialised to 1.0.

    NOTE(review): in this excerpt ``elbo``, ``optimizer``, ``model`` and
    ``guide`` are defined but never used, so the body looks truncated
    (presumably an SVI loop and assertions follow) -- confirm against the
    full source.
    """
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        # Beta(1, 1) prior on the success probability, then one Bernoulli
        # observation per data point inside a plate.
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        # Variational Beta posterior; both concentrations are constrained
        # to be positive.
        alpha_q = numpyro.param("alpha_q",
                                1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
Example #5
0
    def optimize_params(self,
                        p0,
                        extra,
                        num_epochs,
                        num_iters,
                        metric,
                        step_size,
                        tolerance,
                        verbose,
                        return_model=None) -> dict:
        """
        Gradient descent using JAX optimizer, and verbose logging.

        Runs Adam on ``self.cost`` for up to ``num_iters`` steps, recording
        the train (and, when ``extra`` is given, dev) cost and metric at every
        step, and returns one of the visited parameter sets.

        Parameters
        ----------
        p0 :
            Initial parameters handed to ``opt_init``.
        extra : dict or None
            Dev-set data. When not None it is passed to ``self.forwardpass`` /
            ``self.cost`` and ``extra['y']`` is used for the dev metric.
        num_epochs : int
            Only 1 is supported.
        num_iters : int
            Maximum number of optimisation steps (must be >= 1).
        metric : str or None
            Name passed to ``self.compute_score``; None skips metric tracking.
        step_size : float or schedule
            Adam learning rate.
        tolerance : int
            Early-stop window; 0 disables early stopping. After a 300-step
            warm-up, stop if the dev cost increased on each of the last
            ``tolerance`` steps, or the train cost failed to decrease by at
            least 1e-5 on each of them.
        verbose : int
            Print progress every ``verbose`` steps; 0 disables printing.
        return_model : str or None
            'best_dev_cost', 'best_train_cost', 'best_dev_metric',
            'best_train_metric' or 'last' (anything else falls back to
            'last'). None means 'best_dev_cost' with a dev set, otherwise
            'best_train_cost'.

        Returns
        -------
        The selected parameters. Cost/metric histories, the selected dev
        metric and the elapsed time are stored on ``self``.

        Raises
        ------
        NotImplementedError
            If ``num_epochs != 1``.
        ValueError
            If ``num_iters < 1``.
        """
        if return_model is None:
            if extra is not None:
                return_model = 'best_dev_cost'
            else:
                return_model = 'best_train_cost'

        assert (extra is not None) or (
            'dev'
            not in return_model), 'Cannot use dev set if dev set is not given.'

        if num_epochs != 1:
            raise NotImplementedError()

        if num_iters < 1:
            # At least one step is needed to have any parameters to return.
            raise ValueError('num_iters must be at least 1.')

        # preallocation
        cost_train = np.zeros(num_iters)
        cost_dev = np.zeros(num_iters)
        metric_train = np.zeros(num_iters)
        metric_dev = np.zeros(num_iters)
        params_list = []

        if verbose:
            self.print_progress_header(c_train=True,
                                       c_dev=extra,
                                       m_train=metric is not None,
                                       m_dev=metric is not None and extra)

        time_start = time.time()

        opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
        opt_state = opt_init(p0)

        @jit
        def step(_i, _opt_state):
            p = get_params(_opt_state)
            g = grad(self.cost)(p)
            return opt_update(_i, g, _opt_state)

        c_dev = None
        m_train = None
        m_dev = None
        y_pred_dev = None

        for i in range(num_iters):

            opt_state = step(i, opt_state)
            params = get_params(opt_state)
            params_list.append(params)

            y_pred_train = self.forwardpass(p=params, extra=None)
            c_train = self.cost(p=params, precomputed=y_pred_train)
            cost_train[i] = c_train

            if extra is not None:
                y_pred_dev = self.forwardpass(p=params, extra=extra)
                c_dev = self.cost(p=params,
                                  extra=extra,
                                  precomputed=y_pred_dev)
                cost_dev[i] = c_dev

            if metric is not None:

                m_train = self.compute_score(self.y, y_pred_train, metric)
                metric_train[i] = m_train

                if extra is not None:
                    m_dev = self.compute_score(extra['y'], y_pred_dev, metric)
                    metric_dev[i] = m_dev

            time_elapsed = time.time() - time_start
            if verbose and (i % int(verbose) == 0):
                self.print_progress(i,
                                    time_elapsed,
                                    c_train=c_train,
                                    c_dev=c_dev,
                                    m_train=m_train,
                                    m_dev=m_dev)

            # tolerance = 0: no early stop. The window holds tolerance + 1
            # costs ending at the current step (i.e. `tolerance` changes);
            # `i >= tolerance` keeps its start from going negative, which
            # would otherwise produce an empty slice and a spurious stop.
            if tolerance and i > 300 and i >= tolerance:

                total_time_elapsed = time.time() - time_start
                cost_train_slice = cost_train[i - tolerance:i + 1]
                cost_dev_slice = cost_dev[i - tolerance:i + 1]

                # Only meaningful with a dev set; otherwise cost_dev is all zeros.
                if extra is not None and np.all(
                        cost_dev_slice[1:] - cost_dev_slice[:-1] > 0):
                    stop = 'dev_stop'
                    if verbose:
                        print(
                            f'Stop at {i} steps: ' +
                            f'cost (dev) has been monotonically increasing for {tolerance} steps.\n'
                        )
                    break

                if np.all(cost_train_slice[:-1] - cost_train_slice[1:] < 1e-5):
                    stop = 'train_stop'
                    if verbose:
                        print(
                            f'Stop at {i} steps: ' +
                            f'cost (train) has been changing less than 1e-5 for {tolerance} steps.\n'
                        )
                    break

        else:
            total_time_elapsed = time.time() - time_start
            stop = 'maxiter_stop'

            if verbose:
                print('Stop: reached {0} steps.\n'.format(num_iters))

        if return_model == 'best_dev_cost':
            best = np.argmin(cost_dev[:i + 1])

        elif return_model == 'best_train_cost':
            best = np.argmin(cost_train[:i + 1])

        elif return_model == 'best_dev_metric':
            if metric in ['mse', 'gcv']:
                best = np.argmin(metric_dev[:i + 1])
            else:
                best = np.argmax(metric_dev[:i + 1])

        elif return_model == 'best_train_metric':
            if metric in ['mse', 'gcv']:
                best = np.argmin(metric_train[:i + 1])
            else:
                best = np.argmax(metric_train[:i + 1])

        else:
            if return_model != 'last':
                print(
                    'Provided `return_model` is not supported. Fallback to `last`'
                )
            if stop == 'dev_stop':
                # Dev cost rose throughout the window, so its start is the
                # lowest-dev-cost point in it.
                best = i - tolerance
            else:
                best = i

        params = params_list[best]
        metric_dev_opt = metric_dev[best]

        self.cost_train = cost_train[:i + 1]
        self.cost_dev = cost_dev[:i + 1]
        self.metric_train = metric_train[:i + 1]
        self.metric_dev = metric_dev[:i + 1]
        self.metric_dev_opt = metric_dev_opt
        self.total_time_elapsed = total_time_elapsed

        return params
Example #6
0
    def initialize_parametric_nonlinearity(self,
                                           init_to='exponential',
                                           method=None,
                                           params_dict=None):
        """Fit a parametric nonlinearity to a reference curve.

        The reference curve ``y0`` is evaluated on
        ``x0 = linspace(-xrange, xrange, nx)``. It is then approximated either
        by a spline basis (``method='spline'``, closed-form least squares) or
        by a small stax network trained with Adam (``method='nn'``).

        Parameters
        ----------
        init_to : str
            Reference curve: 'exponential', 'softplus', 'relu',
            'nonparametric' (``self.fnl_nonparametric``) or 'gaussian'.
        method : str or None
            'spline' or 'nn'. None uses ``self.output_nonlinearity`` or, if
            that is falsy, ``self.filter_nonlinearity``. A non-None value
            overwrites ``self.output_nonlinearity``, and also
            ``self.filter_nonlinearity`` when that is not None.
        params_dict : dict or None
            Optional settings (defaults in brackets): 'xrange' [5], 'nx'
            [1000]; for 'spline': 'smooth' ['cr'], 'df' [7], 'degree' [3,
            'bs' only]; for 'nn': 'random_seed' [2046], 'step_size' [0.01],
            'layer_sizes' [[10, 10, 1]], 'num_subunits' [1], 'num_iters'
            [1000].

        Sets ``self.nl_xrange``, ``self.nl_params`` and ``self.fnl_fitted`` (a
        callable ``(params, x_new)`` whose output is clipped below at 0); the
        spline path also sets ``self.nl_basis``.

        Raises
        ------
        NotImplementedError
            For an unknown ``init_to``, ``smooth`` or ``method``.
        """
        if method is None:  # if no methods specified, use defaults.
            method = self.output_nonlinearity or self.filter_nonlinearity
        else:  # otherwise, overwrite the default nonlinearity.
            self.output_nonlinearity = method
            if self.filter_nonlinearity is not None:
                self.filter_nonlinearity = method

        assert method is not None

        # prepare data
        if params_dict is None:
            params_dict = {}
        xrange = params_dict.get('xrange', 5)
        nx = params_dict.get('nx', 1000)
        x0 = jnp.linspace(-xrange, xrange, nx)

        if init_to == 'exponential':
            y0 = jnp.exp(x0)
        elif init_to == 'softplus':
            y0 = softplus(x0)
        elif init_to == 'relu':
            y0 = relu(x0)
        elif init_to == 'nonparametric':
            y0 = self.fnl_nonparametric(x0)
        elif init_to == 'gaussian':
            # ``scipy.signal.gaussian`` was removed in SciPy 1.13; the window
            # function lives in ``scipy.signal.windows``.
            from scipy.signal.windows import gaussian
            y0 = gaussian(nx, nx / 10)
        else:
            raise NotImplementedError(init_to)

        # fit nonlin
        if method == 'spline':
            smooth = params_dict.get('smooth', 'cr')
            df = params_dict.get('df', 7)
            if smooth == 'cr':
                X = cr(x0, df)
            elif smooth == 'cc':
                X = cc(x0, df)
            elif smooth == 'bs':
                deg = params_dict.get('degree', 3)
                X = bs(x0, df, deg)
            else:
                raise NotImplementedError(smooth)

            # Least-squares basis coefficients via the normal equations.
            opt_params = jnp.linalg.pinv(X.T @ X) @ X.T @ y0

            self.nl_basis = X

            def _nl(_opt_params, x_new):
                return jnp.maximum(interp1d(x0, X @ _opt_params)(x_new), 0)

        elif method == 'nn':

            def loss(_params, _data):
                x = _data['x']
                y = _data['y']
                yhat = _predict(_params, x)
                return jnp.mean((y - yhat)**2)

            random_seed = params_dict.get('random_seed', 2046)
            key = random.PRNGKey(random_seed)

            step_size = params_dict.get('step_size', 0.01)
            layer_sizes = params_dict.get('layer_sizes', [10, 10, 1])

            # Dense -> BatchNorm -> Relu per layer, minus the trailing Relu.
            layers = []
            for layer_size in layer_sizes:
                layers.extend(
                    [Dense(layer_size),
                     BatchNorm(axis=(0, 1)), Relu])
            layers.pop(-1)

            init_random_params, _predict = stax.serial(*layers)

            num_subunits = params_dict.get('num_subunits', 1)
            _, init_params = init_random_params(key, (-1, num_subunits))

            opt_init, opt_update, get_params = optimizers.adam(step_size)
            opt_state = opt_init(init_params)

            @jit
            def step(_i, _opt_state, _data):
                p = get_params(_opt_state)
                g = grad(loss)(p, _data)
                return opt_update(_i, g, _opt_state)

            num_iters = params_dict.get('num_iters', 1000)
            if num_subunits == 1:
                data = {'x': x0.reshape(-1, 1), 'y': y0.reshape(-1, 1)}
            else:
                # Every subunit sees the same input grid.
                data = {
                    'x': jnp.vstack([x0 for _ in range(num_subunits)]).T,
                    'y': y0.reshape(-1, 1)
                }

            for i in range(num_iters):
                opt_state = step(i, opt_state, data)
            opt_params = get_params(opt_state)

            def _nl(_opt_params, x_new):
                if len(x_new.shape) == 1:
                    x_new = x_new.reshape(-1, 1)
                return jnp.maximum(_predict(_opt_params, x_new), 0)
        else:
            raise NotImplementedError(method)

        self.nl_xrange = x0
        self.nl_params = opt_params
        self.fnl_fitted = _nl
Example #7
0
class OptimizersEquivalenceTest(chex.TestCase):
    """Compare optax optimizers with their jax.example_libraries counterparts.

    Both implementations receive the same fixed per-step updates for
    ``STEPS`` steps from the same initial parameters; the resulting
    parameters must agree to each case's relative tolerance.
    """

    def setUp(self):
        super().setUp()
        # Initial parameters: a tuple pytree of two arrays.
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4., 5.]))
        # Identical "gradients" fed to both libraries on every step.
        self.per_step_updates = (jnp.array([500., 5.]),
                                 jnp.array([300., 3., 1.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999,
                            1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(
            LR, decay=.9, eps=0.1), optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum', alias.rmsprop(LR, decay=.9, eps=0.1,
                                           momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(
            LR,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
        # The same optimizers with optax given LR_SCHED. Parameterized test
        # names must be unique, hence the ``_sched`` suffix.
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999,
                                  1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(
            LR_SCHED,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
        ('sm3', alias.sm3(LR), optimizers.sm3(LR), 1e-2),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Run both optimizers on identical updates and compare the params.

        ``jax_optimizer`` is an ``(init, update, get_params)`` triple;
        ``optax_optimizer`` exposes ``init``/``update``, and its ``update``
        call is wrapped by the current chex variant.
        """

        # example_libraries/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
Example #8
0
def main():
    """Train the network from ``get_net()`` with Adam, resuming if possible.

    If ``get_last_ckpt_step()`` returns a non-negative step, training resumes
    from that checkpoint; otherwise the log is cleared. Each step samples a
    batch of spins, minimises an objective built from ``log_q + beta *
    energy`` (with ``beta`` optionally annealed linearly over
    ``args.beta_anneal_step`` steps), logs statistics every
    ``args.print_step`` steps and checkpoints every ``args.save_step``
    steps. Returns None.
    """
    start_time = time()
    init_out_dir()
    last_step = get_last_ckpt_step()
    if last_step >= 0:
        my_log(f'\nCheckpoint found: {last_step}\n')
    else:
        clear_log()
    print_args()

    net_init, net_apply, net_init_cache, net_apply_fast = get_net()

    rng, rng_net = jrand.split(jrand.PRNGKey(args.seed))
    in_shape = (args.batch_size, args.L, args.L, 1)
    out_shape, params_init = net_init(rng_net, in_shape)

    _, cache_init = net_init_cache(params_init, jnp.zeros(in_shape), (-1, -1))

    # Sampling uses the cached fast apply function.
    sample_fun = get_sample_fun(net_apply_fast, cache_init)
    log_q_fun = get_log_q_fun(net_apply)

    need_beta_anneal = args.beta_anneal_step > 0

    opt_init, opt_update, get_params = optimizers.adam(args.lr)

    @jit
    def update(step, opt_state, rng):
        params = get_params(opt_state)
        rng, rng_sample = jrand.split(rng)
        spins = sample_fun(args.batch_size, params, rng_sample)
        # Per-site quantities returned for logging.
        log_q = log_q_fun(params, spins) / args.L**2
        energy = energy_fun(spins) / args.L**2

        def neg_log_Z_fun(params, spins):
            log_q = log_q_fun(params, spins) / args.L**2
            energy = energy_fun(spins) / args.L**2
            beta = args.beta
            if need_beta_anneal:
                # Ramp beta linearly from 0 up to args.beta.
                beta *= jnp.minimum(step / args.beta_anneal_step, 1)
            neg_log_Z = log_q + beta * energy
            return neg_log_Z

        loss_fun = partial(expect,
                           log_q_fun,
                           neg_log_Z_fun,
                           mean_grad_expected_is_zero=True)
        grads = grad(loss_fun)(params, spins, spins)
        opt_state = opt_update(step, grads, opt_state)

        return spins, log_q, energy, opt_state, rng

    if last_step >= 0:
        params_init = load_ckpt(last_step)

    opt_state = opt_init(params_init)

    my_log('Training...')
    for step in range(last_step + 1, args.max_step + 1):
        spins, log_q, energy, opt_state, rng = update(step, opt_state, rng)

        if args.print_step and step % args.print_step == 0:
            # Use the final beta, not the annealed beta
            free_energy = log_q / args.beta + energy
            my_log(', '.join([
                f'step = {step}',
                f'F = {free_energy.mean():.8g}',
                f'F_std = {free_energy.std():.8g}',
                f'S = {-log_q.mean():.8g}',
                f'E = {energy.mean():.8g}',
                f'time = {time() - start_time:.3f}',
            ]))

        if args.save_step and step % args.save_step == 0:
            params = get_params(opt_state)
            save_ckpt(params, step)


def cost(params):
    """Return the negative fidelity of ``circuit(params)`` with the target.

    The target is the column vector (0, 1/sqrt(2), 1/sqrt(2), 0)^T, so
    minimising this cost maximises |<target|state>|^2.
    """
    state = circuit(params)
    # Exact 1/sqrt(2) amplitudes keep the target unit-norm, so the best
    # achievable fidelity is exactly 1 (0.707 capped it at about 0.9997).
    op_qs = jnp.array([[0.0], [1.0], [1.0], [0.0]]) / jnp.sqrt(2.0)
    fid = jnp.abs(jnp.dot(jnp.conjugate(op_qs).T, state))**2
    return -jnp.real(fid)[0][0]


# Three random initial parameters, each drawn uniformly from [0, 2*pi).
# No seed is set in the visible code, so these differ between runs.
init_params = [2 * jnp.pi * np.random.rand() for _ in range(3)]

# Adam optimizer triple; its state starts from init_params.
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
opt_state = opt_init(init_params)


def step(i, opt_state, opt_update):
    """Apply one ``opt_update`` using the gradient of ``cost`` at the current params."""
    current = get_params(opt_state)
    gradient = grad(cost)(current)
    return opt_update(i, gradient, opt_state)


# Loop bookkeeping: epoch counter and its upper limit, a list for loss
# history, and the cost at the initial parameters.
epoch, epoch_max = 0, 250

loss_hist = []
loss = cost(init_params)
Example #10
0
    def __init__(self,
                 n=2**7 - 1,
                 rng=None,
                 channels=8,
                 loss=loss_gmres,
                 iter_gmres=None,
                 training_iter=500,
                 name='net',
                 model_dir=None,
                 lr=3e-4,
                 k=0.0,
                 n_test=10,
                 beta1=0.9,
                 beta2=0.999,
                 lr_og=3e-3,
                 flaxd=False):
        """Build the preconditioner network and its optimizer.

        Parameters
        ----------
        n : int
            Grid size; the network input shape is ``(-1, n, n, 1)``.
        rng : PRNGKey or None
            Key for parameter initialisation; defaults to ``random.PRNGKey(1)``.
        channels : int
            Number of inner channels of the network.
        loss : callable
            Training loss, stored as ``self.loss``.
        iter_gmres : callable or None
            Maps a step index to a GMRES iteration count. None (the default)
            draws a random count from [5, 10, 10, 10, 10, 15, 15, 15, 20, 25]
            on every call. A user-supplied callable was previously ignored.
        training_iter, name, model_dir, k, n_test
            Stored on ``self`` unchanged.
        lr, lr_og : float
            Adam step size; on the stax path ``lr_og`` is used for steps
            ``< 100`` and ``lr`` afterwards. The Flax path uses ``lr`` only.
        beta1, beta2 : float
            Adam moment decay rates.
        flaxd : bool
            Use the Flax model/optimizer instead of stax + ``optimizers``.
        """
        self.n = n
        self.n_test = n_test
        self.mesh = meshes.Mesh(n)
        self.in_shape = (-1, n, n, 1)
        self.inner_channels = channels

        if iter_gmres is None:

            def iter_gmres(i):
                # Random GMRES depth; the step index `i` is unused.
                return onp.random.choice([5, 10, 10, 10, 10, 15, 15, 15, 20, 25])

        self.iter_gmres = iter_gmres
        self.training_iter = training_iter
        self.name = name
        self.k = k
        self.model_dir = model_dir
        self.test_loss = loss_gmresR_flax if flaxd else loss_gmresR
        self.beta1 = beta1
        self.beta2 = beta2
        if rng is None:
            rng = random.PRNGKey(1)
        if not flaxd:

            def conv():
                # 3x3 UnbiasedConv with 'SAME' padding and inner channel count.
                return UnbiasedConv(self.inner_channels, (3, 3), padding='SAME')

            self.net_init, self.net_apply = stax.serial(
                UNetBlock(1, (3, 3),
                          stax.serial(
                              conv(),
                              conv(),
                              UNetBlock(self.inner_channels, (3, 3),
                                        stax.serial(conv(), conv(), conv()),
                                        strides=(2, 2),
                                        padding='VALID'),
                              conv(),
                              conv(),
                          ),
                          strides=(2, 2),
                          padding='VALID'), )
            out_shape, net_params = self.net_init(rng, self.in_shape)
        else:
            model_def = flax_cnn.new_CNN.partial(
                inner_channels=self.inner_channels)
            out_shape, net_params = model_def.init_by_shape(
                rng, [(self.in_shape, np.float32)])
            self.model_def = model_def
            self.model = nn.Model(model_def, net_params)
            self.net_apply = lambda param, x: nn.Model(model_def, param)(x)
        self.out_shape = out_shape
        self.net_params = net_params
        self.loss = loss
        self.lr_og = lr_og
        self.lr = lr
        if flaxd:
            self.step = self.step_flax
            self.optimizer = flax.optim.Adam(learning_rate=lr,
                                             beta1=beta1,
                                             beta2=beta2).create(self.model)
        else:
            # Step size lr_og for the first 100 steps, lr afterwards.
            self.opt_init, self.opt_update, self.get_params = optimizers.adam(
                step_size=lambda i: np.where(i < 100, lr_og, lr),
                b1=beta1,
                b2=beta2)
            self.opt_state = self.opt_init(self.net_params)
            self.step = self.step_notflax
        self.alpha = lambda i: 0.0
        self.flaxd = flaxd
        if flaxd:
            self.preconditioner = self.preconditioner_flaxed
        else:
            self.preconditioner = self.preconditioner_unflaxed
Example #11
0
    def optimize(self, p0, num_iters, metric, step_size, tolerance, verbose,
                 return_model):
        """Workhorse of optimization.

        p0: dict
            A dictionary of the initial model parameters to be optimized.

        num_iters: int
            Maximum number of iterations (must be >= 1).

        metric: str
            Method of model evaluation. Can be
            `mse`, `corrcoeff`, `r2`

        step_size: float or jax scheduler
            Learning rate.

        tolerance: int
            Tolerance for early stop (0 disables it). After a 300-step
            warm-up, stop if the training cost changed by less than 1e-5 on
            each of the last (tolerance) steps, or if the dev cost increased
            on each of them.

        verbose: int
            Print progress every `verbose` steps. If verbose=0, nothing is printed.

        return_model: str or None
            'best_dev_cost', 'best_train_cost', 'best_dev_metric',
            'best_train_metric' or 'last' (anything else falls back to
            'last'). None picks 'best_dev_cost' when a dev set exists,
            otherwise 'best_train_cost'.

        Returns the selected parameters. Per-step histories, the selected dev
        metric, elapsed time, all visited parameters and the final predictions
        are stored on `self`.

        Raises ValueError if num_iters < 1.
        """
        if return_model is None:
            if 'dev' in self.y:
                return_model = 'best_dev_cost'
            else:
                return_model = 'best_train_cost'

        assert ('dev' in self.y) or (
            'dev'
            not in return_model), 'Cannot use dev set if dev set is not given.'

        if num_iters < 1:
            # Without at least one step, the predictions and histories used
            # after the loop would be undefined.
            raise ValueError('num_iters must be at least 1.')

        opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
        opt_state = opt_init(p0)

        @jit
        def step(_i, _opt_state):
            # Returns the cost at the parameters *before* this update, plus
            # the updated optimizer state.
            p = get_params(_opt_state)
            l, g = value_and_grad(self.cost)(p)
            return l, opt_update(_i, g, _opt_state)

        # NaN-filled so entries never written (no dev set) stand out.
        cost_train = np.full(num_iters, np.nan)
        cost_dev = np.full(num_iters, np.nan)
        metric_train = np.full(num_iters, np.nan)
        metric_dev = np.full(num_iters, np.nan)
        params_list = []

        extra = 'dev' in self.y

        if verbose:
            self.print_progress_header(c_train=True,
                                       c_dev=extra,
                                       m_train=metric is not None,
                                       m_dev=metric is not None and extra)

        time_start = time.time()

        for i in range(num_iters):
            # NOTE(review): cost_train[i] is the cost *before* step i, while
            # params_list[i] holds the parameters *after* it, so selections
            # based on cost_train are offset by one step.
            cost_train[i], opt_state = step(i, opt_state)

            params = get_params(opt_state)
            params_list.append(params)

            y_pred_train = self.forwardpass(p=params, kind='train')
            metric_train[i] = self.compute_score(self.y['train'], y_pred_train,
                                                 metric)

            if 'dev' in self.y:
                y_pred_dev = self.forwardpass(p=params, kind='dev')
                cost_dev[i] = self.cost(p=params,
                                        kind='dev',
                                        precomputed=y_pred_dev,
                                        penalize=False)
                metric_dev[i] = self.compute_score(self.y['dev'], y_pred_dev,
                                                   metric)

            time_elapsed = time.time() - time_start
            if verbose:
                if i % int(verbose) == 0:
                    self.print_progress(i,
                                        time_elapsed,
                                        c_train=cost_train[i],
                                        c_dev=cost_dev[i],
                                        m_train=metric_train[i],
                                        m_dev=metric_dev[i])

            # tolerance = 0: no early stop. The window holds tolerance + 1
            # costs ending at the current step (i.e. `tolerance` changes);
            # `i >= tolerance` keeps its start non-negative.
            if tolerance and i > 300 and i >= tolerance:

                total_time_elapsed = time.time() - time_start
                window = slice(i - tolerance, i + 1)

                if 'dev' in self.y and np.all(
                        np.diff(cost_dev[window]) > 0):
                    stop = 'dev_stop'
                    if verbose:
                        print(
                            'Stop at {0} steps: cost (dev) has been monotonically increasing for {1} steps.'
                            .format(i, tolerance))
                        print('Total time elapsed: {0:.3f}s.\n'.format(
                            total_time_elapsed))
                    break

                # Compare the change by magnitude: the training cost is
                # normally decreasing, and a signed `diff < 1e-5` would be
                # true for every decrease.
                if np.all(np.abs(np.diff(cost_train[window])) < 1e-5):
                    stop = 'train_stop'
                    if verbose:
                        print(
                            'Stop at {0} steps: cost (train) has been changing less than 1e-5 for {1} steps.'
                            .format(i, tolerance))
                        print('Total time elapsed: {0:.3f}s.\n'.format(
                            total_time_elapsed))
                    break

        else:
            total_time_elapsed = time.time() - time_start
            stop = 'maxiter_stop'
            if verbose:
                print('Stop: reached {0} steps.'.format(num_iters))
                print('Total time elapsed: {0:.3f}s.\n'.format(
                    total_time_elapsed))

        if return_model == 'best_dev_cost':
            best = np.argmin(cost_dev[:i + 1])

        elif return_model == 'best_train_cost':
            best = np.argmin(cost_train[:i + 1])

        elif return_model == 'best_dev_metric':
            if metric in ['mse', 'gcv']:
                best = np.argmin(metric_dev[:i + 1])
            else:
                best = np.argmax(metric_dev[:i + 1])

        elif return_model == 'best_train_metric':
            if metric in ['mse', 'gcv']:
                best = np.argmin(metric_train[:i + 1])
            else:
                best = np.argmax(metric_train[:i + 1])

        else:
            if return_model != 'last':
                print(
                    'Provided `return_model` is not supported. Fallback to `last`'
                )
            if stop == 'dev_stop':
                # Dev cost rose throughout the window, so its start is the
                # lowest-dev-cost point in it.
                best = i - tolerance
            else:
                best = i

        params = params_list[best]
        metric_dev_opt = metric_dev[best]

        self.cost_train = cost_train[:i + 1]
        self.cost_dev = cost_dev[:i + 1]
        self.metric_train = metric_train[:i + 1]
        self.metric_dev = metric_dev[:i + 1]
        self.metric_dev_opt = metric_dev_opt
        self.total_time_elapsed = total_time_elapsed

        # Every visited parameter set is kept; this may use a lot of memory.
        self.all_params = params_list[:i + 1]

        self.y_pred['opt'].update({'train': y_pred_train})
        if 'dev' in self.y:
            self.y_pred['opt'].update({'dev': y_pred_dev})

        return params