def deep_nelbo(deep_ps, our_ps, batch, lhs):
    """Return the negative mean ELBO of the two-layer (deep) flow for one cell.

    ``batch`` is pushed through the first layer (``deep_ps`` with
    ``static_deep_cs_collection[0]``), mapped into the cell given by
    ``lhs[0]`` / ``lhs[1]`` via ``uniform_to_cell``, and then pushed through
    the second layer (``our_ps`` with ``cs``). The per-sample objective is the
    sum of both layers' ``log_det_J_u`` terms plus ``input_fun`` evaluated at
    the final samples.

    NOTE(review): ``static_deep_cs_collection``, ``cs`` and ``input_fun`` are
    free names resolved outside this function — confirm where they are bound.

    Args:
        deep_ps: parameters of the first (deep) layer.
        our_ps: parameters of the second layer.
        batch: base samples fed to the first layer.
        lhs: indexable pair whose two entries are passed to
            ``uniform_to_cell`` as the cell bounds (presumably low, high).

    Returns:
        ``-mean(log_det_J_layer_1 + log_det_J_layer_2 + input_fun(z))``.
    """
    inner_cs = static_deep_cs_collection[0]
    pushed = pushforward(deep_ps, inner_cs, batch)
    cell_ws = uniform_to_cell(pushed, lhs[0], lhs[1])

    first_layer_ldj = log_det_J_u(deep_ps, inner_cs, batch)
    second_layer_ldj = log_det_J_u(our_ps, cs, cell_ws)
    log_target = input_fun(pushforward(our_ps, cs, cell_ws))

    per_sample = first_layer_ldj + second_layer_ldj + log_target
    return -np.mean(per_sample)
def compute_uniform_q(ps, cs, log_p, dim, n_samples, tol, rng, num_boxes=1):
    """Combine per-cell ELBO estimates into one log-scale estimate.

    ``init_double_deep_collection`` is asked for ``2**dim`` cells. For each
    ``(low, high)`` pair in the returned ``lhs``:

    1. Draw a uniform batch with ``sample_our_uniform``.
    2. Map the batch into that cell with ``uniform_to_cell``.
    3. Push it through the flow ``(ps, cs)``.
    4. Average ``log_p(z) + log_det_J_u`` over the batch.

    The per-cell averages are combined as
    ``logsumexp(elbos) - log(len(lhs))``.

    Args:
        ps, cs: flow parameters and constants passed to ``pushforward`` /
            ``log_det_J_u``.
        log_p: callable returning the target log density at samples ``z``.
        dim: sample dimensionality.
        n_samples: number of samples drawn per cell.
        tol: forwarded to ``sample_our_uniform`` as its ``tol`` argument.
            Previously this was ignored in favour of a hard-coded ``1e-4``.
        rng: random state passed to the sampling/initialisation helpers.
        num_boxes: forwarded to ``init_double_deep_collection``.

    Returns:
        Scalar ``logsumexp(elbos) - log(len(lhs))``.
    """
    num_cells = 2**dim
    # Only the cell bounds are used here; the deep parameters/constants are
    # discarded. The call is kept because it receives ``rng`` and presumably
    # advances it. The literal 4 is passed through unchanged — its meaning is
    # defined by init_double_deep_collection.
    _, _, lhs = init_double_deep_collection(dim, 4, num_boxes, num_cells, rng)

    elbos = []
    for low, high in lhs:
        batch = sample_our_uniform(n_samples, d=dim, rng=rng, tol=tol)
        ws_batch = uniform_to_cell(batch, low, high)
        zs = pushforward(ps, cs, ws_batch)
        elbo_samples = log_p(zs) + log_det_J_u(ps, cs, ws_batch)
        elbos.append(elbo_samples.mean())
    return logsumexp(elbos) - np.log(len(lhs))
def compute_importance_sampling(ps, cs, log_p, dim, n_samples, tol, rng):
    """Importance-sampling estimate of the normalising constant of ``exp(log_p)``.

    Uniform base samples ``ws`` are drawn with ``sample_our_uniform`` and
    pushed through the flow ``(ps, cs)`` to get ``zs``. Each importance weight
    is ``p(z) / q(z)``.

    The ELBO code in this module (``compute_base_elbo``, ``elbo_loss``) uses
    ``log_p(z) + log_det_J_u`` as ``log p - log q``. That implies
    ``log q(z) = -log_det_J_u``, assuming the uniform base has unit density.
    The log weight is therefore ``log_p(z) + log_det_J_u``. The previous
    version subtracted the Jacobian term, which is inconsistent with that
    convention.

    Args:
        ps, cs: flow parameters and constants.
        log_p: callable returning the target log density at ``zs``.
        dim: sample dimensionality.
        n_samples: number of samples.
        tol: forwarded to ``sample_our_uniform``. Previously this was ignored
            in favour of a hard-coded ``1e-4``.
        rng: random state for sampling.

    Returns:
        ``(mean, std)`` of the importance weights.
    """
    ws = sample_our_uniform(n_samples, d=dim, rng=rng, tol=tol)
    zs = pushforward(ps, cs, ws)
    # log(p/q) = log p - log q = log_p(z) + log|det J|  (same as the ELBO integrand)
    log_weights = log_p(zs) + log_det_J_u(ps, cs, ws)
    # NOTE(review): exp may overflow for large log-weights; consider a
    # logsumexp-based estimate if that becomes an issue.
    weights = np.exp(log_weights)
    return weights.mean(), weights.std()
def compute_base_elbo(ps, cs, log_p, dim, n_samples, tol, rng):
    """Monte Carlo ELBO of the flow ``(ps, cs)`` against ``log_p``.

    Uniform base samples are drawn with ``sample_our_uniform`` and pushed
    through the flow. The per-sample ELBO integrand is
    ``log_p(z) + log_det_J_u``.

    Args:
        ps, cs: flow parameters and constants.
        log_p: callable returning the target log density at ``zs``.
        dim: sample dimensionality.
        n_samples: number of samples.
        tol: forwarded to ``sample_our_uniform``. Previously this was ignored
            in favour of a hard-coded ``1e-4``.
        rng: random state for sampling.

    Returns:
        ``(mean, std)`` of the per-sample ELBO values.
    """
    ws = sample_our_uniform(n_samples, d=dim, rng=rng, tol=tol)
    zs = pushforward(ps, cs, ws)
    elbo_samples = log_p(zs) + log_det_J_u(ps, cs, ws)
    return elbo_samples.mean(), elbo_samples.std()
def elbo_loss(params, cs, ws, p_density):
    """Negative *summed* ELBO over the samples ``ws``, for use as a loss.

    Args:
        params, cs: flow parameters and constants passed to ``pushforward``
            and ``log_det_J_u``.
        ws: base samples.
        p_density: callable evaluated at the pushed-forward samples. It is
            added to ``log_det_J_u``, so it is presumably a log density.

    Returns:
        ``-sum(p_density(z) + log_det_J_u(params, cs, ws))``.
    """
    log_target = p_density(pushforward(params, cs, ws))
    log_jacobian = log_det_J_u(params, cs, ws)
    return -np.sum(log_target + log_jacobian)
 def shallow_nelbo(our_ps, batch):
     """Negative mean ELBO of the single-layer flow ``(our_ps, cs)``.

     ``batch`` is reshaped to rows of length ``dim`` before use.

     NOTE(review): ``dim``, ``cs`` and ``input_fun`` are free names. This def
     is indented one column, so it was likely lifted out of an enclosing
     function that binds them — confirm.

     Returns:
         ``-mean(log_det_J_u(our_ps, cs, ws) + input_fun(z))``.
     """
     ws = batch.reshape((-1, dim))
     log_jacobian = log_det_J_u(our_ps, cs, ws)
     log_target = input_fun(pushforward(our_ps, cs, ws))
     return -np.mean(log_jacobian + log_target)