def setup_optimizers():
    """Initialize the model, the gradient function, and an Adamax optimizer."""
    global opt_init, opt_update, get_params  # optimizer triple shared at module level

    model = init_model(ode_kwargs)
    # `forward` and `loss_fn` are assumed to be defined at module level.
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    lr_seq = exponential_decay(step_size=1e-2, decay_steps=1,
                               decay_rate=0.999, lowest=1e-2 / 10)
    opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_seq)

    return model, grad_fn
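# For context, `jax.experimental.optimizers.adamax` returns an
# (opt_init, opt_update, get_params) triple. The sketch below shows how such a
# triple is typically consumed; the `loss` and `params` pytrees here are
# illustrative placeholders, not taken from the snippet above.
import jax
import jax.numpy as jnp
from jax.experimental import optimizers


def loss(params, x):
    return jnp.sum((params["w"] * x) ** 2)


params = {"w": jnp.ones(3)}
opt_init, opt_update, get_params = optimizers.adamax(step_size=1e-2)
opt_state = opt_init(params)                        # wrap params in optimizer state

for step in range(10):
    grads = jax.grad(loss)(get_params(opt_state), jnp.arange(3.0))
    opt_state = opt_update(step, grads, opt_state)  # one Adamax update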
def get_optimizer(
    learning_rate: float = 1e-4, optimizer="sgd", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer.

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (allowed types: "adam", "adamax",
            "adagrad", "rmsprop", "sgd"). Defaults to "sgd".
        optimizer_kwargs (dict, optional): Additional keyword arguments that are
            passed to the optimizer. Defaults to None.

    Returns:
        JaxOptimizer
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}

    optimizer = optimizer.lower()
    if optimizer == "adam":
        opt_init, opt_update, get_params = optimizers.adam(learning_rate, **optimizer_kwargs)
    elif optimizer == "adagrad":
        opt_init, opt_update, get_params = optimizers.adagrad(learning_rate, **optimizer_kwargs)
    elif optimizer == "adamax":
        opt_init, opt_update, get_params = optimizers.adamax(learning_rate, **optimizer_kwargs)
    elif optimizer == "rmsprop":
        opt_init, opt_update, get_params = optimizers.rmsprop(learning_rate, **optimizer_kwargs)
    else:
        # Fall back to plain SGD for "sgd" or any unrecognized name.
        opt_init, opt_update, get_params = optimizers.sgd(learning_rate, **optimizer_kwargs)

    opt_update = jit(opt_update)
    return JaxOptimizer(opt_init, opt_update, get_params)
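# `JaxOptimizer` is referenced above but not defined in this snippet. A minimal
# sketch of how such a dataclass might look, assuming it simply bundles the
# (opt_init, opt_update, get_params) triple -- an assumption for illustration,
# not the project's actual definition.
from dataclasses import dataclass
from typing import Callable


@dataclass
class JaxOptimizer:
    opt_init: Callable    # params -> optimizer state
    opt_update: Callable  # (step, grads, state) -> new state
    get_params: Callable  # state -> params

# Example usage (hypothetical `params`):
# optimizer = get_optimizer(learning_rate=1e-3, optimizer="adamax")
# opt_state = optimizer.opt_init(params)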
def _JaxAdaMax(machine, alpha=0.001, beta1=0.9, beta2=0.999, epscut=1.0e-7):
    """Wrap `machine` with a JAX AdaMax optimizer (step size `alpha`, epsilon cutoff `epscut`)."""
    return Wrap(machine, jaxopt.adamax(alpha, beta1, beta2, epscut))
def run():
    """
    Run the experiment.
    """
    ds_train, ds_test, meta = init_physionet_data(rng, parse_args)
    num_batches = meta["num_batches"]
    num_test_batches = meta["num_test_batches"]

    model = init_model(ode_kwargs)
    forward = lambda *args: model["forward"](*args)[1:]
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    lr_schedule = exponential_decay(step_size=parse_args.lr, decay_steps=1,
                                    decay_rate=0.999, lowest=parse_args.lr / 10)
    opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule)
    opt_state = opt_init(model["params"])

    def get_kl_coef(epoch_):
        """
        Tuning schedule for KL coefficient. (annealing)
        """
        return max(0., 1 - 0.99 ** (epoch_ - 10))

    @jax.jit
    def update(_itr, _opt_state, _batch, kl_coef):
        """
        Update the params based on grad for current batch.
        """
        return opt_update(_itr, grad_fn(get_params(_opt_state), _batch, kl_coef), _opt_state)

    @jax.jit
    def sep_losses(_opt_state, _batch, kl_coef):
        """
        Convenience function for calculating losses separately.
        """
        params = get_params(_opt_state)
        preds, rec_r, gen_r, z0_params, nfe = forward(params,
                                                      _batch["observed_data"],
                                                      _batch["observed_tp"],
                                                      _batch["tp_to_predict"],
                                                      _batch["observed_mask"])
        likelihood_ = _likelihood(preds, _batch["observed_data"], _batch["observed_mask"])
        mse_ = _mse(preds, _batch["observed_data"], _batch["observed_mask"])
        kl_ = _kl_div(z0_params)
        return -logsumexp(likelihood_ - kl_coef * kl_, axis=0), likelihood_, kl_, mse_, rec_r, gen_r

    def evaluate_loss(opt_state, ds_test, kl_coef):
        """
        Convenience function for evaluating loss over the test set in smaller batches.
        """
        loss, likelihood, kl, mse, rec_r, gen_r, rec_nfe, gen_nfe = [], [], [], [], [], [], [], []
        for test_batch_num in range(num_test_batches):
            test_batch = next(ds_test)

            batch_loss, batch_likelihood, batch_kl, batch_mse, batch_rec_r, batch_gen_r = \
                sep_losses(opt_state, test_batch, kl_coef)

            if count_nfe:
                nfes = model["nfe"](get_params(opt_state),
                                    test_batch["observed_data"],
                                    test_batch["observed_tp"],
                                    test_batch["tp_to_predict"],
                                    test_batch["observed_mask"])
                rec_nfe.append(nfes["rec"])
                gen_nfe.append(nfes["gen"])
            else:
                rec_nfe.append(0)
                gen_nfe.append(0)

            loss.append(batch_loss)
            likelihood.append(batch_likelihood)
            kl.append(batch_kl)
            mse.append(batch_mse)
            rec_r.append(batch_rec_r)
            gen_r.append(batch_gen_r)

        loss = jnp.array(loss)
        likelihood = jnp.array(likelihood)
        kl = jnp.array(kl)
        mse = jnp.array(mse)
        rec_r = jnp.array(rec_r)
        gen_r = jnp.array(gen_r)
        rec_nfe = jnp.array(rec_nfe)
        gen_nfe = jnp.array(gen_nfe)

        return jnp.mean(loss), jnp.mean(likelihood), jnp.mean(kl), jnp.mean(mse), \
            jnp.mean(rec_r), jnp.mean(gen_r), jnp.mean(rec_nfe), jnp.mean(gen_nfe)

    itr = 0
    info = collections.defaultdict(dict)
    for epoch in range(parse_args.nepochs):
        for i in range(num_batches):
            batch = next(ds_train)

            itr += 1

            # One jitted Adamax step on the current training batch.
            opt_state = update(itr, opt_state, batch, get_kl_coef(epoch))

            if itr % parse_args.test_freq == 0:
                loss_, likelihood_, kl_, mse_, rec_r_, gen_r_, rec_nfe_, gen_nfe_ = \
                    evaluate_loss(opt_state, ds_test, get_kl_coef(epoch))

                print_str = 'Iter {:04d} | Loss {:.6f} | ' \
                            'Likelihood {:.6f} | KL {:.6f} | MSE {:.6f} | Enc. r {:.6f} | Dec. r {:.6f} | ' \
                            'Enc. NFE {:.6f} | Dec. NFE {:.6f}'. \
                    format(itr, loss_, likelihood_, kl_, mse_, rec_r_, gen_r_, rec_nfe_, gen_nfe_)

                print(print_str)

                outfile = open("%s/reg_%s_lam_%.12e_info.txt" % (dirname, reg, lam), "a")
                outfile.write(print_str + "\n")
                outfile.close()

                info[itr]["loss"] = loss_
                info[itr]["likelihood"] = likelihood_
                info[itr]["kl"] = kl_
                info[itr]["mse"] = mse_
                info[itr]["rec_r"] = rec_r_
                info[itr]["gen_r"] = gen_r_
                info[itr]["rec_nfe"] = rec_nfe_
                info[itr]["gen_nfe"] = gen_nfe_

            if itr % parse_args.save_freq == 0:
                # Checkpoint the current parameters and record the iteration.
                param_filename = "%s/reg_%s_lam_%.12e_%d_fargs.pickle" % (dirname, reg, lam, itr)
                fargs = get_params(opt_state)
                outfile = open(param_filename, "wb")
                pickle.dump(fargs, outfile)
                outfile.close()

                outfile = open("%s/reg_%s_lam_%.12e_iter.txt" % (dirname, reg, lam), "a")
                outfile.write("Iter: {:04d}\n".format(itr))
                outfile.close()

        meta = {"info": info, "args": parse_args}
        outfile = open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname, reg, lam), "wb")
        pickle.dump(meta, outfile)
        outfile.close()
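# The `exponential_decay(..., lowest=...)` schedule used above is not the stock
# jax.experimental.optimizers.exponential_decay (which takes no `lowest`
# argument). Below is a plausible sketch of such a clamped schedule, offered as
# an assumption about the helper's behaviour rather than its actual
# implementation; `optimizers.adamax` accepts either a float or a callable
# schedule like this for `step_size`.
import jax.numpy as jnp


def exponential_decay(step_size, decay_steps, decay_rate, lowest):
    def schedule(i):
        lr = step_size * decay_rate ** (i / decay_steps)  # standard exponential decay
        return jnp.maximum(lr, lowest)                    # never decay below `lowest`
    return schedule

# e.g. with step_size=1e-2, decay_steps=1, decay_rate=0.999, lowest=1e-3,
# the learning rate shrinks by 0.1% per step and is floored at 1e-3.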