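# NOTE: `exponential_decay` as called in run() below takes a `lowest` floor, so it
# is this repo's own schedule, not jax.example_libraries.optimizers.exponential_decay
# (which has no such argument). A minimal sketch of what such a schedule could look
# like; the name `_exponential_decay_sketch` is illustrative only, and it uses the
# module's `jnp` alias for jax.numpy:
def _exponential_decay_sketch(step_size, decay_steps, decay_rate, lowest):
    def schedule(itr_):
        # Decay multiplicatively every `decay_steps` updates, clipped at `lowest`.
        return jnp.maximum(step_size * decay_rate ** (itr_ / decay_steps), lowest)
    return schedule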
def run():
    """
    Run the experiment.
    """
    ds_train, ds_test, meta = init_physionet_data(rng, parse_args)
    num_batches = meta["num_batches"]
    num_test_batches = meta["num_test_batches"]

    model = init_model(ode_kwargs)
    forward = lambda *args: model["forward"](*args)[1:]
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    lr_schedule = exponential_decay(step_size=parse_args.lr,
                                    decay_steps=1,
                                    decay_rate=0.999,
                                    lowest=parse_args.lr / 10)
    opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule)
    opt_state = opt_init(model["params"])

    def get_kl_coef(epoch_):
        """
        Annealing schedule for the KL coefficient: zero for the first 10 epochs,
        then 1 - 0.99 ** (epoch - 10), rising toward 1.
        """
        return max(0., 1 - 0.99 ** (epoch_ - 10))

    @jax.jit
    def update(_itr, _opt_state, _batch, kl_coef):
        """
        Update the params based on the grad for the current batch.
        """
        return opt_update(_itr, grad_fn(get_params(_opt_state), _batch, kl_coef), _opt_state)

    @jax.jit
    def sep_losses(_opt_state, _batch, kl_coef):
        """
        Convenience function for calculating the losses separately.
        """
        params = get_params(_opt_state)
        # This must read from `_batch` (the argument), not the training-loop
        # variable `batch`, which jit would otherwise silently close over.
        preds, rec_r, gen_r, z0_params, nfe = forward(params,
                                                      _batch["observed_data"],
                                                      _batch["observed_tp"],
                                                      _batch["tp_to_predict"],
                                                      _batch["observed_mask"])
        likelihood_ = _likelihood(preds, _batch["observed_data"], _batch["observed_mask"])
        mse_ = _mse(preds, _batch["observed_data"], _batch["observed_mask"])
        kl_ = _kl_div(z0_params)
        return -logsumexp(likelihood_ - kl_coef * kl_, axis=0), likelihood_, kl_, mse_, rec_r, gen_r

    def evaluate_loss(opt_state, ds_test, kl_coef):
        """
        Convenience function for evaluating the loss over the test set in smaller batches.
        """
        loss, likelihood, kl, mse, rec_r, gen_r, rec_nfe, gen_nfe = [], [], [], [], [], [], [], []
        for test_batch_num in range(num_test_batches):
            test_batch = next(ds_test)

            batch_loss, batch_likelihood, batch_kl, batch_mse, batch_rec_r, batch_gen_r = \
                sep_losses(opt_state, test_batch, kl_coef)

            if count_nfe:
                nfes = model["nfe"](get_params(opt_state),
                                    test_batch["observed_data"],
                                    test_batch["observed_tp"],
                                    test_batch["tp_to_predict"],
                                    test_batch["observed_mask"])
                rec_nfe.append(nfes["rec"])
                gen_nfe.append(nfes["gen"])
            else:
                rec_nfe.append(0)
                gen_nfe.append(0)

            loss.append(batch_loss)
            likelihood.append(batch_likelihood)
            kl.append(batch_kl)
            mse.append(batch_mse)
            rec_r.append(batch_rec_r)
            gen_r.append(batch_gen_r)

        loss = jnp.array(loss)
        likelihood = jnp.array(likelihood)
        kl = jnp.array(kl)
        mse = jnp.array(mse)
        rec_r = jnp.array(rec_r)
        gen_r = jnp.array(gen_r)
        rec_nfe = jnp.array(rec_nfe)
        gen_nfe = jnp.array(gen_nfe)

        return jnp.mean(loss), jnp.mean(likelihood), jnp.mean(kl), jnp.mean(mse), \
            jnp.mean(rec_r), jnp.mean(gen_r), jnp.mean(rec_nfe), jnp.mean(gen_nfe)

    itr = 0
    info = collections.defaultdict(dict)
    for epoch in range(parse_args.nepochs):
        for i in range(num_batches):
            batch = next(ds_train)
            itr += 1

            opt_state = update(itr, opt_state, batch, get_kl_coef(epoch))

            if itr % parse_args.test_freq == 0:
                loss_, likelihood_, kl_, mse_, rec_r_, gen_r_, rec_nfe_, gen_nfe_ = \
                    evaluate_loss(opt_state, ds_test, get_kl_coef(epoch))
                print_str = 'Iter {:04d} | Loss {:.6f} | ' \
                            'Likelihood {:.6f} | KL {:.6f} | MSE {:.6f} | ' \
                            'Enc. r {:.6f} | Dec. r {:.6f} | ' \
                            'Enc. NFE {:.6f} | Dec. NFE {:.6f}'. \
                    format(itr, loss_, likelihood_, kl_, mse_, rec_r_, gen_r_, rec_nfe_, gen_nfe_)
                print(print_str)

                outfile = open("%s/reg_%s_lam_%.12e_info.txt" % (dirname, reg, lam), "a")
                outfile.write(print_str + "\n")
                outfile.close()

                info[itr]["loss"] = loss_
                info[itr]["likelihood"] = likelihood_
                info[itr]["kl"] = kl_
                info[itr]["mse"] = mse_
                info[itr]["rec_r"] = rec_r_
                info[itr]["gen_r"] = gen_r_
                info[itr]["rec_nfe"] = rec_nfe_
                info[itr]["gen_nfe"] = gen_nfe_

            if itr % parse_args.save_freq == 0:
                param_filename = "%s/reg_%s_lam_%.12e_%d_fargs.pickle" % (dirname, reg, lam, itr)
                fargs = get_params(opt_state)
                outfile = open(param_filename, "wb")
                pickle.dump(fargs, outfile)
                outfile.close()

                outfile = open("%s/reg_%s_lam_%.12e_iter.txt" % (dirname, reg, lam), "a")
                outfile.write("Iter: {:04d}\n".format(itr))
                outfile.close()

    meta = {
        "info": info,
        "args": parse_args
    }
    outfile = open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname, reg, lam), "wb")
    pickle.dump(meta, outfile)
    outfile.close()
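# Example (a sketch, not part of the original script): reload a pickled parameter
# checkpoint and the experiment metadata written by run() above. `load_checkpoint`
# is a hypothetical helper; the filename patterns mirror the save paths in run().
def load_checkpoint(dirname_, reg_, lam_, itr_):
    # Parameters saved at iteration `itr_` by the save_freq branch of run().
    with open("%s/reg_%s_lam_%.12e_%d_fargs.pickle" % (dirname_, reg_, lam_, itr_), "rb") as f:
        params = pickle.load(f)
    # Metadata (the `info` dict of logged metrics plus the argparse namespace).
    with open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname_, reg_, lam_), "rb") as f:
        meta_ = pickle.load(f)
    return params, meta_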