Ejemplo n.º 1
0
def log_logistic_volume_flow_marginal_estimate(recon_x_mu, recon_x_logvar, x,
                                               zk, z0, z0_mu, z0_logvar):
    """Importance-sampled estimate of the negative marginal log-likelihood.

    For every (batch element, sample) pair, a log importance weight is built:

        log w = log p(x | zk) + log p(zk) - log q(z0 | x)

    The weights are combined with ``log_mean_exp`` over the sample axis.
    The result is the negated batch mean.

    No log|det J| correction is applied between z0 and zk. Presumably this is
    because the flow is volume-preserving (as the function name suggests), so
    q(zk | x) == q(z0 | x).

    Args:
        recon_x_mu, recon_x_logvar: decoder parameters for each sample. They
            must hold ``batch_size * n_samples * input_dim`` elements and be
            viewable as ``(batch_size * n_samples, input_dim)``.
        x: input batch of shape ``(batch_size, input_dim)``.
        zk: final flow latents. These must be viewable as
            ``(batch_size * n_samples, z_dim)``.
        z0: initial latents of shape ``(batch_size, n_samples, z_dim)``.
        z0_mu, z0_logvar: parameters of q(z0 | x), with the same element
            count as ``z0``.

    Returns:
        Scalar tensor: ``-mean_b log_mean_exp_s(log w[b, s])``.
    """
    batch_size, n_samples, z_dim = z0.size()
    input_dim = x.size(1)
    n_rows = batch_size * n_samples

    # Tile each input row once per sample so it lines up with the flattened
    # per-sample reconstructions: (batch, input_dim) -> (batch * n_samples, input_dim).
    x_flat = x.unsqueeze(1).repeat(1, n_samples, 1).view(n_rows, input_dim)

    log_px_given_zk = logistic_256_log_pdf(
        x_flat,
        recon_x_mu.view(n_rows, input_dim),
        recon_x_logvar.view(n_rows, input_dim),
    )
    log_q_z0 = gaussian_log_pdf(
        z0.view(n_rows, z_dim),
        z0_mu.view(n_rows, z_dim),
        z0_logvar.view(n_rows, z_dim),
    )
    log_pz_k = unit_gaussian_log_pdf(zk.view(n_rows, z_dim))

    # log q(zk | x) is taken to be log q(z0 | x); see the docstring.
    log_w = (log_px_given_zk + log_pz_k - log_q_z0).view(batch_size, n_samples)
    return -torch.mean(log_mean_exp(log_w, dim=1))
Ejemplo n.º 2
0
def gaussian_free_energy_bound(recon_x_mu,
                               recon_x_logvar,
                               x,
                               zk,
                               z0,
                               z_mu,
                               z_logvar,
                               log_abs_det_jacobian=None,
                               beta=1.):
    """Negative (beta-weighted) free-energy bound, averaged over samples and batch.

    For each sample index ``i`` along dimension 1, the per-example term is:

        log p(x | zk_i) + beta * (log p(zk_i) - log q(zk_i | x))

    where ``log q(zk_i | x) = log q(z0_i | x) - log|det J_i|`` when
    ``log_abs_det_jacobian`` is given, and ``log q(z0_i | x)`` otherwise.
    The per-example terms are averaged over samples, negated, and then
    averaged over the batch.

    Args:
        recon_x_mu, recon_x_logvar: decoder parameters, indexed as ``[:, i]``
            per sample.
        x: input batch, passed unchanged to ``logistic_256_log_pdf`` for every
            sample. Presumably shaped like ``recon_x_mu[:, i]``; confirm
            against callers.
        zk, z0: final and initial flow latents. Their sizes must match.
        z_mu, z_logvar: parameters of q(z0 | x), indexed as ``[:, i]``.
        log_abs_det_jacobian: optional per-sample log|det J| of the flow,
            indexed as ``[:, i]``.
        beta: weight on the latent (KL-like) part of the bound.

    Returns:
        Scalar tensor holding the negated, averaged bound.
    """
    assert z0.size() == zk.size()
    n_samples = recon_x_mu.size(1)

    bound = 0
    # `range` rather than `xrange`: `xrange` does not exist on Python 3.
    for i in range(n_samples):
        log_p_x_given_z = logistic_256_log_pdf(x, recon_x_mu[:, i],
                                               recon_x_logvar[:, i])
        log_q_z0_given_x = gaussian_log_pdf(z0[:, i], z_mu[:, i],
                                            z_logvar[:, i])
        log_p_zk = unit_gaussian_log_pdf(zk[:, i])

        if log_abs_det_jacobian is not None:
            # Change of variables: the flow's log-det shifts the density of zk.
            log_q_zk_given_x = log_q_z0_given_x - log_abs_det_jacobian[:, i]
        else:
            log_q_zk_given_x = log_q_z0_given_x

        bound += log_p_x_given_z + beta * (log_p_zk - log_q_zk_given_x)

    bound = bound / float(n_samples)
    bound = -bound
    return torch.mean(bound)
Ejemplo n.º 3
0
def gaussian_elbo_loss(recon_x_mu, recon_x_logvar, x, z, z_mu, z_logvar):
    """Negative ELBO with an analytic Gaussian KL, averaged over samples and batch.

    For each sample index ``i`` along dimension 1, the per-example loss is:

        -log p(x | recon_i) + KL(N(z_mu_i, exp(z_logvar_i)) || N(0, I))

    The KL uses the closed form ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``,
    summed over dimension 1 of the per-sample slice. The losses are averaged
    over samples and then over the batch.

    Args:
        recon_x_mu, recon_x_logvar: decoder parameters, indexed as ``[:, i]``
            per sample.
        x: input batch, passed unchanged to ``logistic_256_log_pdf``.
        z: unused. It is kept only so the signature matches the other losses.
        z_mu, z_logvar: parameters of q(z | x), indexed as ``[:, i]``.
            Presumably shaped ``(batch, n_samples, z_dim)``; confirm against
            callers.

    Returns:
        Scalar tensor holding the averaged negative ELBO.
    """
    n_samples = recon_x_mu.size(1)

    elbo = 0
    # `range` rather than `xrange`: `xrange` does not exist on Python 3.
    for i in range(n_samples):
        # Negative discretised-logistic log-likelihood (formerly named `BCE`,
        # although it is not a binary cross-entropy).
        neg_log_px = -logistic_256_log_pdf(x, recon_x_mu[:, i],
                                           recon_x_logvar[:, i])
        mu_i = z_mu[:, i]
        logvar_i = z_logvar[:, i]
        kld = -0.5 * (1 + logvar_i - mu_i.pow(2) - logvar_i.exp())
        kld = torch.sum(kld, dim=1)

        elbo += neg_log_px + kld

    elbo = elbo / float(n_samples)
    return torch.mean(elbo)
Ejemplo n.º 4
0
def weighted_gaussian_elbo_loss(recon_x_mu, recon_x_logvar, x, z, z_mu,
                                z_logvar):
    """Negative importance-weighted bound, averaged over the batch.

    For each sample index ``i`` along dimension 1, a log importance weight is
    built:

        log w_i = log p(x | z_i) + log p(z_i) - log q(z_i | x)

    The weights are stacked along a new dimension 1 and combined with
    ``log_mean_exp`` over that dimension. The result is the negated batch mean.

    Args:
        recon_x_mu, recon_x_logvar: decoder parameters, indexed as ``[:, i]``
            per sample.
        x: input batch, passed unchanged to ``logistic_256_log_pdf``.
        z: latent samples, indexed as ``[:, i]``.
        z_mu, z_logvar: parameters of q(z | x), indexed as ``[:, i]``.

    Returns:
        Scalar tensor: ``-mean_b log_mean_exp_i(log w_i[b])``.
    """
    n_samples = recon_x_mu.size(1)

    log_ws = []
    # `range` rather than `xrange`: `xrange` does not exist on Python 3.
    for i in range(n_samples):
        log_p_x_given_z = logistic_256_log_pdf(x, recon_x_mu[:, i],
                                               recon_x_logvar[:, i])
        log_q_z_given_x = gaussian_log_pdf(z[:, i], z_mu[:, i], z_logvar[:, i])
        log_p_z = unit_gaussian_log_pdf(z[:, i])

        log_ws_i = log_p_x_given_z + log_p_z - log_q_z_given_x
        # Add a sample axis so the per-sample weights can be concatenated.
        log_ws.append(log_ws_i.unsqueeze(1))

    log_ws = torch.cat(log_ws, dim=1)
    log_ws = log_mean_exp(log_ws, dim=1)
    return -torch.mean(log_ws)