Code example #1
def setup_optimizers():
    model = init_model(ode_kwargs)
    forward = lambda *args: model["forward"](*args)[1:]
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))
    # Decay the learning rate geometrically each step, floored at 1e-3.
    lr_schedule = exponential_decay(step_size=1e-2,
                                    decay_steps=1,
                                    decay_rate=0.999,
                                    lowest=1e-2 / 10)
    opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule)
    return model, grad_fn, opt_init, opt_update, get_params
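The triple returned by optimizers.adamax follows the init/update/get-params interface of jax.experimental.optimizers. A minimal sketch of driving setup_optimizers in a training loop, assuming loss_fn(forward, params, batch) is scalar-valued and that num_steps and train_batches are placeholder names:

model, grad_fn, opt_init, opt_update, get_params = setup_optimizers()
opt_state = opt_init(model["params"])
for step in range(num_steps):
    batch = next(train_batches)
    grads = grad_fn(get_params(opt_state), batch)   # gradient w.r.t. params
    opt_state = opt_update(step, grads, opt_state)  # one adamax step
trained_params = get_params(opt_state)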
Code example #2
File: nt.py Project: kjappelbaum/pyepal
def get_optimizer(
    learning_rate: float = 1e-4, optimizer: str = "sgd", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (allowed types: "adam",
            "adamax", "adagrad", "rmsprop", "sgd"). Defaults to "sgd".
        optimizer_kwargs (dict, optional): Additional keyword arguments
            that are passed to the optimizer. Defaults to None.

    Returns:
        JaxOptimizer
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    optimizer = optimizer.lower()
    if optimizer == "adam":
        opt_init, opt_update, get_params = optimizers.adam(learning_rate, **optimizer_kwargs)
    elif optimizer == "adagrad":
        opt_init, opt_update, get_params = optimizers.adagrad(learning_rate, **optimizer_kwargs)
    elif optimizer == "adamax":
        opt_init, opt_update, get_params = optimizers.adamax(learning_rate, **optimizer_kwargs)
    elif optimizer == "rmsprop":
        opt_init, opt_update, get_params = optimizers.rmsprop(learning_rate, **optimizer_kwargs)
    else:
        opt_init, opt_update, get_params = optimizers.sgd(learning_rate, **optimizer_kwargs)

    opt_update = jit(opt_update)

    return JaxOptimizer(opt_init, opt_update, get_params)
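A hypothetical call site for get_optimizer, assuming the JaxOptimizer dataclass exposes the triple under the field names opt_init, opt_update, and get_params:

import jax.numpy as jnp

opt = get_optimizer(learning_rate=1e-3, optimizer="adamax")
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}  # any pytree of arrays works
opt_state = opt.opt_init(params)
grads = {"w": jnp.ones(3), "b": jnp.ones(())}     # stand-in gradients
opt_state = opt.opt_update(0, grads, opt_state)   # (step index, grads, state)
params = opt.get_params(opt_state)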
Code example #3
File: __init__.py Project: vlpap/netket
def _JaxAdaMax(machine, alpha=0.001, beta1=0.9, beta2=0.999, epscut=1.0e-7):
    return Wrap(machine, jaxopt.adamax(alpha, beta1, beta2, epscut))
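_JaxAdaMax only renames the hyperparameters; the underlying call is jax.experimental.optimizers.adamax, whose signature is adamax(step_size, b1=0.9, b2=0.999, eps=1e-8). The equivalent direct call, without the netket Wrap:

from jax.experimental import optimizers as jaxopt

# alpha -> step_size, beta1 -> b1, beta2 -> b2, epscut -> eps
opt_init, opt_update, get_params = jaxopt.adamax(0.001, b1=0.9, b2=0.999,
                                                 eps=1.0e-7)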
Code example #4
def run():
    """
    Run the experiment.
    """

    ds_train, ds_test, meta = init_physionet_data(rng, parse_args)
    num_batches = meta["num_batches"]
    num_test_batches = meta["num_test_batches"]

    model = init_model(ode_kwargs)
    forward = lambda *args: model["forward"](*args)[1:]
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    lr_schedule = exponential_decay(step_size=parse_args.lr,
                                    decay_steps=1,
                                    decay_rate=0.999,
                                    lowest=parse_args.lr / 10)
    opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule)
    opt_state = opt_init(model["params"])

    def get_kl_coef(epoch_):
        """
        Annealing schedule for the KL coefficient (ramps from 0 toward 1 after epoch 10).
        """
        return max(0., 1 - 0.99**(epoch_ - 10))

    @jax.jit
    def update(_itr, _opt_state, _batch, kl_coef):
        """
        Update the params based on grad for current batch.
        """
        return opt_update(_itr, grad_fn(get_params(_opt_state), _batch,
                                        kl_coef), _opt_state)

    @jax.jit
    def sep_losses(_opt_state, _batch, kl_coef):
        """
        Convenience function for calculating losses separately.
        """
        params = get_params(_opt_state)
        preds, rec_r, gen_r, z0_params, nfe = forward(params,
                                                      _batch["observed_data"],
                                                      _batch["observed_tp"],
                                                      _batch["tp_to_predict"],
                                                      _batch["observed_mask"])
        likelihood_ = _likelihood(preds, _batch["observed_data"],
                                  _batch["observed_mask"])
        mse_ = _mse(preds, _batch["observed_data"], _batch["observed_mask"])
        kl_ = _kl_div(z0_params)
        return -logsumexp(likelihood_ - kl_coef * kl_,
                          axis=0), likelihood_, kl_, mse_, rec_r, gen_r

    def evaluate_loss(opt_state, ds_test, kl_coef):
        """
        Convenience function for evaluating loss over the test set in smaller batches.
        """
        loss, likelihood, kl, mse, rec_r, gen_r, rec_nfe, gen_nfe = [], [], [], [], [], [], [], []

        for test_batch_num in range(num_test_batches):
            test_batch = next(ds_test)

            batch_loss, batch_likelihood, batch_kl, batch_mse, batch_rec_r, batch_gen_r = \
                sep_losses(opt_state, test_batch, kl_coef)

            if count_nfe:
                nfes = model["nfe"](get_params(opt_state),
                                    test_batch["observed_data"],
                                    test_batch["observed_tp"],
                                    test_batch["tp_to_predict"],
                                    test_batch["observed_mask"])
                rec_nfe.append(nfes["rec"])
                gen_nfe.append(nfes["gen"])
            else:
                rec_nfe.append(0)
                gen_nfe.append(0)

            loss.append(batch_loss)
            likelihood.append(batch_likelihood)
            kl.append(batch_kl)
            mse.append(batch_mse)
            rec_r.append(batch_rec_r)
            gen_r.append(batch_gen_r)

        loss = jnp.array(loss)
        likelihood = jnp.array(likelihood)
        kl = jnp.array(kl)
        mse = jnp.array(mse)
        rec_r = jnp.array(rec_r)
        gen_r = jnp.array(gen_r)
        rec_nfe = jnp.array(rec_nfe)
        gen_nfe = jnp.array(gen_nfe)

        return jnp.mean(loss), jnp.mean(likelihood), jnp.mean(kl), jnp.mean(mse), jnp.mean(rec_r), jnp.mean(gen_r), \
               jnp.mean(rec_nfe), jnp.mean(gen_nfe)

    itr = 0
    info = collections.defaultdict(dict)
    for epoch in range(parse_args.nepochs):
        for i in range(num_batches):
            batch = next(ds_train)

            itr += 1

            opt_state = update(itr, opt_state, batch, get_kl_coef(epoch))

            if itr % parse_args.test_freq == 0:
                loss_, likelihood_, kl_, mse_, rec_r_, gen_r_, rec_nfe_, gen_nfe_ = \
                    evaluate_loss(opt_state, ds_test, get_kl_coef(epoch))

                print_str = 'Iter {:04d} | Loss {:.6f} | ' \
                            'Likelihood {:.6f} | KL {:.6f} | MSE {:.6f} | Enc. r {:.6f} | Dec. r {:.6f} | ' \
                            'Enc. NFE {:.6f} | Dec. NFE {:.6f}'.\
                    format(itr, loss_, likelihood_, kl_, mse_, rec_r_, gen_r_, rec_nfe_, gen_nfe_)

                print(print_str)

                with open("%s/reg_%s_lam_%.12e_info.txt"
                          % (dirname, reg, lam), "a") as outfile:
                    outfile.write(print_str + "\n")

                info[itr]["loss"] = loss_
                info[itr]["likelihood"] = likelihood_
                info[itr]["kl"] = kl_
                info[itr]["mse"] = mse_
                info[itr]["rec_r"] = rec_r_
                info[itr]["gen_r"] = gen_r_
                info[itr]["rec_nfe"] = rec_nfe_
                info[itr]["gen_nfe"] = gen_nfe_

            if itr % parse_args.save_freq == 0:
                param_filename = "%s/reg_%s_lam_%.12e_%d_fargs.pickle" % (
                    dirname, reg, lam, itr)
                fargs = get_params(opt_state)
                with open(param_filename, "wb") as outfile:
                    pickle.dump(fargs, outfile)

            with open("%s/reg_%s_lam_%.12e_iter.txt"
                      % (dirname, reg, lam), "a") as outfile:
                outfile.write("Iter: {:04d}\n".format(itr))
    meta = {"info": info, "args": parse_args}
    outfile = open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname, reg, lam),
                   "wb")
    pickle.dump(meta, outfile)
    outfile.close()
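The exponential_decay used here takes a lowest floor that the stock jax.experimental.optimizers.exponential_decay(step_size, decay_steps, decay_rate) does not, so it is a project-local helper. A plausible sketch of its behavior, assuming geometric decay clipped from below (not the project's verbatim definition):

import jax.numpy as jnp

def exponential_decay(step_size, decay_steps, decay_rate, lowest):
    def schedule(i):
        # Learning rate shrinks geometrically with the step count i,
        # but never drops below `lowest`.
        return jnp.maximum(step_size * decay_rate ** (i / decay_steps), lowest)
    return schedule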