def statistics(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
               rng: random.PRNGKey):
    """Print comparison statistics for the dequantization ODE model: moment
    MSEs against rejection samples, importance-sampling estimates of both KL
    directions, and the relative effective sample size."""
    # Split pseudo-random number key.
    rng, rng_sample, rng_xobs, rng_kl = random.split(rng, 4)
    # Compute comparison statistics.
    _, xsph, _ = ode_forward(rng_sample, net_params, 10000, 4)
    xobs = rejection_sampling(rng_xobs, len(xsph), 4, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    # KL(q || p): mean log-ratio under model samples, corrected by the
    # importance-sampling estimate of the target's normalizing constant.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xsph)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target, xsph
    # KL(p || q): mean log-ratio under target samples; here the weights
    # q / p estimate the reciprocal of the normalizing constant.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xobs)
    log_approx = jnp.log(approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    method = 'deqode ({})'.format('ELBO' if args.elbo_loss else 'KL')
    print(
        r'{} - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(method, mean_mse, cov_mse, klqp, klpq, ress))
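# A minimal standalone sketch of the two estimators used throughout these
# scripts (hypothetical helpers, not referenced elsewhere; they assume only
# jax.numpy imported as jnp).
def _importance_kl(log_approx: jnp.ndarray, log_target: jnp.ndarray) -> jnp.ndarray:
    # KL(q || p) up to the unknown normalizer: mean log-ratio under q plus
    # the log of the importance-sampling estimate of the constant.
    w = jnp.exp(log_target - log_approx)
    return jnp.nanmean(log_approx - log_target) + jnp.log(jnp.nanmean(w))

def _relative_ess(w: jnp.ndarray) -> jnp.ndarray:
    # Kish effective sample size of the importance weights, expressed as a
    # percentage of the nominal sample size.
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    return 100.0 * ess / len(w)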
def main():
    # Set pseudo-random number generator keys.
    rng = random.PRNGKey(args.seed)
    rng, rng_net = random.split(rng, 2)
    rng, rng_sample, rng_xobs, rng_basis = random.split(rng, 4)
    rng, rng_fwd, rng_rev = random.split(rng, 3)
    rng, rng_kl = random.split(rng, 2)
    # Initialize the parameters of the ambient vector field network.
    _, params = net_init(rng_net, (-1, 4))
    opt_state = opt_init(params)
    for it in range(args.num_steps):
        opt_state, kl = step(opt_state, it, args.num_samples)
        print('iter.: {} - kl: {:.4f}'.format(it, kl))
    params = get_params(opt_state)
    # Count the number of trainable parameters.
    count = lambda x: jnp.prod(jnp.array(x.shape))
    num_params = jnp.array(
        tree_util.tree_map(count, tree_util.tree_flatten(params)[0])).sum()
    print('number of parameters: {}'.format(num_params))
    # Compute comparison statistics.
    xsph, log_approx = manifold_ode_log_prob(params, rng_sample, 10000)
    xobs = rejection_sampling(rng_xobs, len(xsph), 3, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    # KL(q || p) and relative effective sample size from model samples.
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target
    # KL(p || q) from rejection samples of the target density.
    log_approx = manifold_reverse_ode_log_prob(params, rng_kl, xobs)
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    print(
        r'manode - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(mean_mse, cov_mse, klqp, klpq, ress))
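# The `step` called in the training loop above is defined elsewhere in this
# script. A minimal sketch, assuming `import jax` and that it follows the
# standard jax.example_libraries.optimizers pattern, minimizing the Monte
# Carlo KL estimate (the per-iteration rng handling via fold_in is
# illustrative, not necessarily the author's).
def _step_sketch(opt_state, it: int, num_samples: int, rng=random.PRNGKey(0)):
    params = get_params(opt_state)
    kl, grads = jax.value_and_grad(kl_divergence)(
        params, random.fold_in(rng, it), num_samples)
    return opt_update(it, grads, opt_state), kl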
def loss(rng: jnp.ndarray, thetax: jnp.ndarray, thetay: jnp.ndarray,
         thetad: jnp.ndarray, paramsm: Sequence[jnp.ndarray], netm: Callable,
         num_samples: int):
    """KL(q || p) loss function to minimize. This is computable up to a
    constant when the target density is known up to proportionality.

    Args:
        rng: Pseudo-random number generator seed.
        thetax: Unconstrained x-coordinates of the spline intervals for the
            radial coordinate.
        thetay: Unconstrained y-coordinates of the spline intervals for the
            radial coordinate.
        thetad: Unconstrained derivatives at internal points for the radial
            coordinate.
        paramsm: Parameters of the neural network giving the conditional
            distribution of the angular parameter.
        netm: Neural network to compute the angular parameter.
        num_samples: Number of samples to draw.

    Returns:
        out: A Monte Carlo estimate of KL(q || p).

    """
    xk, yk, delta = spline_unconstrained_transform(thetax, thetay, thetad)
    (ra, ang), (raunif, angunif), w = mobius_spline_sample(
        rng, num_samples, xk, yk, delta, paramsm, netm)
    xsph = torus2sphere(ra, ang)
    mslp = mobius_spline_log_prob(ra, raunif, ang, angunif, w, xk, yk, delta)
    t = embedded_sphere_density(xsph)
    lt = jnp.log(t)
    return jnp.mean(mslp - lt)
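# A minimal gradient sketch (hypothetical helper, assuming `import jax`):
# argnums selects the spline and network parameters in the signature above,
# leaving rng, netm, and num_samples undifferentiated.
def _spline_loss_grads(rng, thetax, thetay, thetad, paramsm, netm, num_samples):
    return jax.grad(loss, argnums=(1, 2, 3, 4))(
        rng, thetax, thetay, thetad, paramsm, netm, num_samples)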
def loss(rng: jnp.ndarray, bij_params: Sequence[jnp.ndarray],
         bij_fns: Sequence[Callable], deq_params: Sequence[jnp.ndarray],
         deq_fn: Callable, num_samples: int) -> float:
    """Loss function for the dequantized RealNVP flow: either the negative
    evidence lower bound or an importance-sampling estimate of the KL
    divergence, depending on `args.elbo_loss`.

    Args:
        rng: Pseudo-random number generator seed.
        bij_params: List of arrays parameterizing the RealNVP bijectors.
        bij_fns: List of functions that compute the shift and scale of the
            RealNVP affine transformation.
        deq_params: Parameters of the mean and scale functions used in the
            log-normal dequantizer.
        deq_fn: Function that computes the mean and scale of the
            dequantization distribution.
        num_samples: Number of samples to draw using rejection sampling.

    Returns:
        out: The negative evidence lower bound or the importance-sampling
            KL estimate.

    """
    rng, rng_rej, rng_loss = random.split(rng, 3)
    xsph = rejection_sampling(rng_rej, num_samples, num_dims,
                              embedded_sphere_density)
    if args.elbo_loss:
        nelbo = negative_elbo(rng_loss, bij_params, bij_fns, deq_params,
                              deq_fn, xsph).mean()
        return nelbo
    else:
        log_is = importance_log_density(rng_loss, bij_params, bij_fns,
                                        deq_params, deq_fn,
                                        args.num_importance, xsph)
        log_target = jnp.log(embedded_sphere_density(xsph))
        return jnp.mean(log_target - log_is)
def loss(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
         rng: random.PRNGKey, num_samples: int) -> float:
    """Training loss on rejection samples from the target density: the
    negative evidence lower bound if `args.elbo_loss` is set, and otherwise
    an importance-sampling estimate of the KL divergence."""
    rng, rng_rej, rng_loss = random.split(rng, 3)
    xsph = rejection_sampling(rng_rej, num_samples, 4, embedded_sphere_density)
    if args.elbo_loss:
        nelbo = negative_elbo(rng_loss, net_params, deq_params, xsph).mean()
        return nelbo
    else:
        log_is = importance_log_density(rng_loss, net_params, deq_params,
                                        args.num_importance, xsph)
        log_target = jnp.log(embedded_sphere_density(xsph))
        return jnp.mean(log_target - log_is)
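# A one-line compilation sketch (assumption: `import jax` is available):
# num_samples determines array shapes inside rejection_sampling, so it must
# be marked static for the loss to be jit-compiled.
loss_jit = jax.jit(loss, static_argnums=(3,))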
def kl_divergence(params: List, rng: random.PRNGKey, num_samples: int) -> float:
    """Computes the KL divergence between the neural manifold ODE's
    distribution on the sphere and the target density. Note that the target
    density is unnormalized, so the estimate is correct only up to an
    additive constant.

    Args:
        params: Parameters of the neural manifold ODE.
        rng: Pseudo-random number generator key.
        num_samples: Number of samples used to estimate the KL divergence.

    Returns:
        div: The estimated KL divergence.

    """
    s, log_prob = manifold_ode_log_prob(params, rng, num_samples)
    log_prob_target = jnp.log(embedded_sphere_density(s))
    div = jnp.mean(log_prob - log_prob_target)
    return div
def kl_divergence(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
                  rng: random.PRNGKey, num_samples: int) -> float:
    """Computes the KL divergence between the dequantization model's
    distribution on the sphere and the target density. Note that the target
    density is unnormalized, so the estimate is correct only up to an
    additive constant.

    Args:
        net_params: Parameters of the ambient vector field network.
        deq_params: Parameters of the dequantization distribution.
        rng: Pseudo-random number generator key.
        num_samples: Number of samples used to estimate the KL divergence.

    Returns:
        div: The estimated KL divergence.

    """
    rng, rng_fwd, rng_is = random.split(rng, 3)
    _, xsph, _ = ode_forward(rng_fwd, net_params, num_samples, 4)
    log_prob = importance_log_density(rng_is, net_params, deq_params, 10, xsph)
    log_prob_target = jnp.log(embedded_sphere_density(xsph))
    div = jnp.mean(log_prob - log_prob_target)
    return div
num_params = num_theta + num_paramsm
print('number of parameters: {}'.format(num_params))

# Train normalizing flow on the sphere.
(thetax, thetay, thetad, paramsm), trace = train(rng_train, thetax, thetay,
                                                 thetad, paramsm, netm,
                                                 args.num_samples,
                                                 args.num_steps, args.lr)
num_samples = 100000
xk, yk, delta = spline_unconstrained_transform(thetax, thetay, thetad)

# Compute comparison statistics.
(ra, ang), (raunif, angunif), w = mobius_spline_sample(rng_ms, num_samples,
                                                       xk, yk, delta, paramsm,
                                                       netm)
xsph = torus2sphere(ra, ang)
log_approx = mobius_spline_log_prob(ra, raunif, ang, angunif, w, xk, yk, delta)
approx = jnp.exp(log_approx)
target = embedded_sphere_density(xsph)
log_target = jnp.log(target)
# Importance weights (`w` is reused here, shadowing the Mobius weights above).
w = target / approx
Z = jnp.mean(w)
kl = jnp.mean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.sum(w)) / jnp.sum(jnp.square(w))
ress = 100 * ess / len(w)
xobs = rejection_sampling(rng_xobs, len(xsph), 3, embedded_sphere_density)
mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
print(
    r'normalizing - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - Rel. ESS: {:.2f}%'
    .format(mean_mse, cov_mse, kl, ress))
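# The same moment statistics appear in each of these scripts; a hypothetical
# shared helper (not referenced elsewhere; assumes only jax.numpy as jnp):
def _moment_mse(xs: jnp.ndarray, xo: jnp.ndarray):
    # Squared error between sample means and between sample covariances.
    mean_mse = jnp.square(jnp.linalg.norm(xs.mean(0) - xo.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xs.T) - jnp.cov(xo.T)))
    return mean_mse, cov_mse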
# Estimate parameters of the dequantizer and ambient flow.
(bij_params, deq_params), trace = train(rng_train, bij_params, deq_params,
                                        args.num_steps, args.lr,
                                        args.num_batch)

# Sample using dequantization and rejection sampling.
xamb, xsph = sample_ambient(rng_xamb, 100000, bij_params, bij_fns, num_dims)
xobs = rejection_sampling(rng_xobs, len(xsph), num_dims,
                          embedded_sphere_density)

# Compute comparison statistics.
mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
# KL(q || p) and relative effective sample size from model samples.
approx = importance_density(rng_kl, bij_params, deq_params, 1000, xsph)
target = embedded_sphere_density(xsph)
w = target / approx
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
del w, Z, log_approx, approx, log_target, target
# KL(p || q) from rejection samples of the target density.
approx = importance_density(rng_kl, bij_params, deq_params, 1000, xobs)
target = embedded_sphere_density(xobs)
w = approx / target
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)