Example #1
def statistics(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
               rng: random.PRNGKey):
    # Split pseudo-random number key.
    rng, rng_sample, rng_xobs, rng_kl = random.split(rng, 4)
    # Compute comparison statistics.
    _, xsph, _ = ode_forward(rng_sample, net_params, 10000, 4)
    xobs = rejection_sampling(rng_xobs, len(xsph), 4, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    # Estimate KL(q || p) and the effective sample size with self-normalized
    # importance weights.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xsph)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target, xsph
    # Estimate KL(p || q) at exact samples drawn from the target by rejection.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xobs)
    log_approx = jnp.log(approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    method = 'deqode ({})'.format('ELBO' if args.elbo_loss else 'KL')
    print(
        r'{} - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(method, mean_mse, cov_mse, klqp, klpq, ress))
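
The same self-normalized importance-weighting computation recurs throughout these examples. A minimal, self-contained helper that factors it out might look as follows; the name importance_diagnostics is hypothetical and not part of the original code.

import jax.numpy as jnp

def importance_diagnostics(log_approx, log_target):
    # Importance weights w = target / approx, with the target possibly unnormalized.
    w = jnp.exp(log_target - log_approx)
    # Monte Carlo estimate of the target's normalizing constant.
    Z = jnp.nanmean(w)
    # Self-normalized estimate of KL(q || p); log(Z) corrects for the missing normalizer.
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    # Effective sample size of the weights, as a percentage of the nominal sample count.
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100. * ess / len(w)
    return klqp, ress
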
Example #2
def main():
    # Set pseudo-random number generator keys.
    rng = random.PRNGKey(args.seed)
    rng, rng_net = random.split(rng, 2)
    rng, rng_sample, rng_xobs, rng_basis = random.split(rng, 4)
    rng, rng_fwd, rng_rev = random.split(rng, 3)
    rng, rng_kl = random.split(rng, 2)

    # Initialize the parameters of the ambient vector field network.
    _, params = net_init(rng_net, (-1, 4))
    opt_state = opt_init(params)

    for it in range(args.num_steps):
        opt_state, kl = step(opt_state, it, args.num_samples)
        print('iter.: {} - kl: {:.4f}'.format(it, kl))

    params = get_params(opt_state)
    count = lambda x: jnp.prod(jnp.array(x.shape))
    num_params = jnp.array(
        tree_util.tree_map(count,
                           tree_util.tree_flatten(params)[0])).sum()
    print('number of parameters: {}'.format(num_params))

    # Compute comparison statistics.
    xsph, log_approx = manifold_ode_log_prob(params, rng_sample, 10000)
    xobs = rejection_sampling(rng_xobs, len(xsph), 3, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target
    # Evaluate the model's log-density at the rejection samples via the reverse ODE.
    log_approx = manifold_reverse_ode_log_prob(params, rng_kl, xobs)
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    print(
        r'manode - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(mean_mse, cov_mse, klqp, klpq, ress))
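
The parameter count in main() uses a tree_map of per-leaf sizes over the flattened parameter pytree. The snippet below exercises the same idiom on a toy pytree (two dense layers, shapes chosen only for illustration):

import jax.numpy as jnp
from jax import tree_util

toy_params = [(jnp.zeros((4, 8)), jnp.zeros((8,))), (jnp.zeros((8, 3)), jnp.zeros((3,)))]
count = lambda x: jnp.prod(jnp.array(x.shape))
num_params = jnp.array(
    tree_util.tree_map(count, tree_util.tree_flatten(toy_params)[0])).sum()
print('number of parameters: {}'.format(num_params))  # 32 + 8 + 24 + 3 = 67
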
Example #3
def loss(rng: jnp.ndarray, thetax: jnp.ndarray, thetay: jnp.ndarray,
         thetad: jnp.ndarray, paramsm: Sequence[jnp.ndarray], netm: Callable,
         num_samples: int):
    """KL(q || p) loss function to minimize. This is computable up to a constant
    when the target density is known up to proportionality.

    Args:
        rng: Pseudo-random number generator seed.
        thetax: Unconstrained x-coordinates of the spline intervals for the radial coordinate.
        thetay: Unconstrained y-coordinates of the spline intervals for the radial coordinate.
        thetad: Unconstrained derivatives at internal points for the radial coordinate.
        paramsm: Parameters of the neural network giving the conditional
            distribution of the angular parameter.
        netm: Neural network to compute the angular parameter.
        num_samples: Number of samples to draw.

    Returns:
        out: A Monte Carlo estimate of KL(q || p).

    """
    xk, yk, delta = spline_unconstrained_transform(thetax, thetay, thetad)
    (ra, ang), (raunif, angunif), w = mobius_spline_sample(
        rng, num_samples, xk, yk, delta, paramsm, netm)
    xsph = torus2sphere(ra, ang)
    mslp = mobius_spline_log_prob(ra, raunif, ang, angunif, w, xk, yk, delta)
    t = embedded_sphere_density(xsph)
    lt = jnp.log(t)
    return jnp.mean(mslp - lt)
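
The docstring's claim that KL(q || p) is computable up to a constant follows from writing the normalized target as p(x) = p̃(x) / Z with unknown normalizer Z:

    KL(q || p) = E_q[log q(x) - log p̃(x)] + log Z,

so the Monte Carlo average of mslp - lt returned above differs from the true divergence only by the constant log Z and can be minimized in its place.
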
Example #4
def loss(rng: jnp.ndarray, bij_params: Sequence[jnp.ndarray],
         bij_fns: Sequence[Callable], deq_params: Sequence[jnp.ndarray],
         deq_fn: Callable, num_samples: int) -> float:
    """Loss function composed of the evidence lower bound and score matching
    loss.

    Args:
        rng: Pseudo-random number generator seed.
        bij_params: List of arrays parameterizing the RealNVP bijectors.
        bij_fns: List of functions that compute the shift and scale of the RealNVP
            affine transformation.
        deq_params: Parameters of the mean and scale functions used in
            the log-normal dequantizer.
        deq_fn: Function that computes the mean and scale of the dequantization
            distribution.
        num_samples: Number of samples to draw using rejection sampling.

    Returns:
        out: The negative ELBO if `args.elbo_loss` is set, otherwise a Monte
            Carlo estimate of the KL divergence between the target and the
            approximate density.

    """
    rng, rng_rej, rng_loss = random.split(rng, 3)
    xsph = rejection_sampling(rng_rej, num_samples, num_dims,
                              embedded_sphere_density)
    if args.elbo_loss:
        nelbo = negative_elbo(rng_loss, bij_params, bij_fns, deq_params,
                              deq_fn, xsph).mean()
        return nelbo
    else:
        log_is = importance_log_density(rng_loss, bij_params, bij_fns,
                                        deq_params, deq_fn,
                                        args.num_importance, xsph)
        log_target = jnp.log(embedded_sphere_density(xsph))
        return jnp.mean(log_target - log_is)
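
Both loss functions in this example draw exact samples from the target with rejection_sampling, whose implementation is not shown. A minimal sketch of how such a sampler could look, assuming a uniform-on-the-sphere proposal and a known upper bound on the unnormalized density (the name and signature below are illustrative, not the original):

import jax.numpy as jnp
from jax import random

def sphere_rejection_sampling(rng, num_samples, num_dims, density, bound, oversample=10):
    rng_prop, rng_acc = random.split(rng)
    num_prop = oversample * num_samples
    # Uniform proposals on the unit sphere: normalized Gaussian draws.
    z = random.normal(rng_prop, (num_prop, num_dims))
    x = z / jnp.linalg.norm(z, axis=-1, keepdims=True)
    # Accept each proposal with probability density(x) / bound.
    u = random.uniform(rng_acc, (num_prop,))
    accept = u < density(x) / bound
    # May yield fewer than num_samples if the acceptance rate is low.
    return x[accept][:num_samples]
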
Example #5
def loss(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
         rng: random.PRNGKey, num_samples: int) -> float:
    """Return the negative ELBO or an importance sampling estimate of the KL
    divergence, depending on the `elbo_loss` flag."""
    rng, rng_rej, rng_loss = random.split(rng, 3)
    xsph = rejection_sampling(rng_rej, num_samples, 4, embedded_sphere_density)
    if args.elbo_loss:
        nelbo = negative_elbo(rng_loss, net_params, deq_params, xsph).mean()
        return nelbo
    else:
        log_is = importance_log_density(rng_loss, net_params, deq_params,
                                        args.num_importance, xsph)
        log_target = jnp.log(embedded_sphere_density(xsph))
        return jnp.mean(log_target - log_is)
Example #6
def kl_divergence(params: List, rng: random.PRNGKey, num_samples: int) -> float:
    """Computes the KL divergence between the target density and the neural
    manifold ODE's distribution on the sphere. Note that the target density is
    unnormalized.

    Args:
        params: Parameters of the neural manifold ODE.
        rng: Pseudo-random number generator key.
        num_samples: Number of samples used to estimate the KL divergence.

    Returns:
        div: The estimated KL divergence.

    """
    s, log_prob = manifold_ode_log_prob(params, rng, num_samples)
    log_prob_target = jnp.log(embedded_sphere_density(s))
    div = jnp.mean(log_prob - log_prob_target)
    return div
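
An objective like kl_divergence above is what the step() call in Example #2's main() minimizes. A minimal sketch of such a step, assuming the jax.example_libraries.optimizers Adam interface already used there; the learning rate and the per-iteration key handling are assumptions:

from functools import partial
from jax import jit, value_and_grad, random
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(1e-3)

@partial(jit, static_argnums=(2,))
def step(opt_state, it, num_samples):
    # Derive a fresh key for this iteration (the original key handling is not shown).
    rng_it = random.fold_in(random.PRNGKey(args.seed), it)
    params = get_params(opt_state)
    div, grads = value_and_grad(kl_divergence)(params, rng_it, num_samples)
    return opt_update(it, grads, opt_state), div
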
Example #7
def kl_divergence(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
                  rng: random.PRNGKey, num_samples: int) -> float:
    """Computes the KL divergence between the target density and the neural
    manifold ODE's distribution on the sphere. Note that the target density is
    unnormalized.

    Args:
        net_params: Parameters of the ambient vector field network.
        deq_params: Parameters of the dequantization distribution.
        rng: Pseudo-random number generator key.
        num_samples: Number of samples used to estimate the KL divergence.

    Returns:
        div: The estimated KL divergence.

    """
    rng, rng_fwd, rng_is = random.split(rng, 3)
    _, xsph, _ = ode_forward(rng_fwd, net_params, num_samples, 4)
    log_prob = importance_log_density(rng_is, net_params, deq_params, 10, xsph)
    log_prob_target = jnp.log(embedded_sphere_density(xsph))
    div = jnp.mean(log_prob - log_prob_target)
    return div
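
Every example scores samples against embedded_sphere_density, an unnormalized target density on the sphere whose definition is not included here. A purely illustrative stand-in (a mixture of von Mises-Fisher-style bumps on the unit sphere in R^3) could be:

import jax.numpy as jnp

def example_embedded_sphere_density(xsph):
    # Illustrative only; this is an assumption, not the target used in these examples.
    mus = jnp.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
    kappa = 10.
    return jnp.exp(kappa * (xsph @ mus.T)).sum(axis=-1)
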
Example #8
num_params = num_theta + num_paramsm
print('number of parameters: {}'.format(num_params))

# Train normalizing flow on the sphere.
(thetax, thetay, thetad,
 paramsm), trace = train(rng_train, thetax, thetay, thetad, paramsm, netm,
                         args.num_samples, args.num_steps, args.lr)
num_samples = 100000
xk, yk, delta = spline_unconstrained_transform(thetax, thetay, thetad)

# Compute comparison statistics.
(ra, ang), (raunif, angunif), w = mobius_spline_sample(rng_ms, num_samples, xk,
                                                       yk, delta, paramsm,
                                                       netm)
xsph = torus2sphere(ra, ang)
log_approx = mobius_spline_log_prob(ra, raunif, ang, angunif, w, xk, yk, delta)
approx = jnp.exp(log_approx)
target = embedded_sphere_density(xsph)
log_target = jnp.log(target)
w = target / approx
Z = jnp.mean(w)
kl = jnp.mean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.sum(w)) / jnp.sum(jnp.square(w))
ress = 100 * ess / len(w)
xobs = rejection_sampling(rng_xobs, len(xsph), 3, embedded_sphere_density)
mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
print(
    r'normalizing - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - Rel. ESS: {:.2f}%'
    .format(mean_mse, cov_mse, kl, ress))
# Estimate parameters of the dequantizer and ambient flow.
(bij_params, deq_params), trace = train(rng_train, bij_params, deq_params,
                                        args.num_steps, args.lr,
                                        args.num_batch)

# Sample using dequantization and rejection sampling.
xamb, xsph = sample_ambient(rng_xamb, 100000, bij_params, bij_fns, num_dims)
xobs = rejection_sampling(rng_xobs, len(xsph), num_dims,
                          embedded_sphere_density)

# Compute comparison statistics.
mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
approx = importance_density(rng_kl, bij_params, deq_params, 1000, xsph)
target = embedded_sphere_density(xsph)
w = target / approx
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
del w, Z, log_approx, approx, log_target, target
approx = importance_density(rng_kl, bij_params, deq_params, 1000, xobs)
target = embedded_sphere_density(xobs)
w = approx / target
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)