Exemplo n.º 1
0
def marginal(guide, num_samples=100):
    """Summarise the guide by plugging posterior means into ``truncate``.

    Draws ``num_samples`` samples from ``guide`` with ``Predictive`` (run on
    the module-level ``data``). It averages the ``'mu'`` and ``'beta'`` sites
    over dim 0, converts the averaged ``beta`` into mixture weights with
    ``mix_weights``, and hands everything, together with the module-level
    ``alpha``, to ``truncate``.

    Note: this plugs averaged parameters into ``truncate``; it does not
    average the output of ``truncate`` over samples.

    Args:
        guide: Program passed to ``Predictive`` to draw samples from.
        num_samples: Number of samples to draw.

    Returns:
        The first of the two values returned by ``truncate``, presumably the
        retained cluster centres (confirm against ``truncate``).
    """
    # Call the module instance rather than ``.forward`` so that any registered
    # nn.Module hooks run (assumes Predictive is a torch.nn.Module, as the
    # original ``.forward`` call implies).
    posterior_predictive = Predictive(guide, num_samples=num_samples)
    posterior_samples = posterior_predictive(data)

    # dim 0 is assumed to be the sample dimension added by Predictive; detach
    # so the summaries carry no autograd graph.
    mu_mean = posterior_samples['mu'].detach().mean(dim=0)

    beta_mean = posterior_samples['beta'].detach().mean(dim=0)
    weights_mean = mix_weights(beta_mean)

    centers, _ = truncate(alpha, mu_mean, weights_mean)

    return centers
Exemplo n.º 2
0
def marginal(guide, num_samples=25):
    """Return plug-in posterior mean and covariance estimates from the guide.

    Draws ``num_samples`` samples from ``guide`` with ``Predictive`` (run on
    the module-level ``data``) and averages the ``'mu'``, ``'prec'`` and
    ``'corr_chol'`` sites over dim 0. The covariance is then rebuilt as
    ``S @ S.T`` with ``S = diag(1 / sqrt(prec_mean)) @ corr_chol_mean``.

    Note: this is the covariance implied by the averaged parameters. It is
    not the posterior mean of the covariance itself.

    Args:
        guide: Program passed to ``Predictive`` to draw samples from.
        num_samples: Number of samples to draw.

    Returns:
        Tuple ``(mu_mean, sigma_mean)``. ``sigma_mean`` is a 2-D tensor
        built with ``torch.mm``.
    """
    # Call the module instance rather than ``.forward`` so that any registered
    # nn.Module hooks run (assumes Predictive is a torch.nn.Module, as the
    # original ``.forward`` call implies).
    posterior_predictive = Predictive(guide, num_samples=num_samples)
    posterior_samples = posterior_predictive(data)

    # NOTE(review): ``.squeeze()`` drops *every* size-1 dimension. If the data
    # dimension D is 1, these become 0-d tensors and ``torch.diag`` /
    # ``torch.mm`` below will fail. Confirm D > 1 is always the case.
    mu_mean = posterior_samples['mu'].detach().mean(dim=0).squeeze()
    prec_mean = posterior_samples['prec'].detach().mean(dim=0).squeeze()
    corr_chol_mean = posterior_samples['corr_chol'].detach().mean(dim=0).squeeze()

    # Per-dimension standard deviations from the averaged precisions, used to
    # scale the rows of the averaged Cholesky factor.
    _std_mean = torch.sqrt(1. / prec_mean)
    _sigma_chol_mean = torch.mm(torch.diag(_std_mean), corr_chol_mean)
    sigma_mean = torch.mm(_sigma_chol_mean, _sigma_chol_mean.T)

    return mu_mean, sigma_mean
Exemplo n.º 3
0
def marginal(guide, num_samples=25):
    """Summarise a mixture guide by plugging posterior means into ``truncate``.

    Draws ``num_samples`` samples from ``guide`` with ``Predictive`` (run on
    the module-level ``data``) and averages over dim 0:

    - the ``'mu'``, ``'prec'`` and ``'beta'`` sites;
    - the per-component ``'corr_chol_<t>'`` sites for ``t`` in ``range(T)``.

    The averaged ``beta`` is converted to mixture weights with
    ``mix_weights``. Everything, together with the module-level ``alpha``, is
    then passed to ``truncate``.

    Args:
        guide: Program passed to ``Predictive`` to draw samples from.
        num_samples: Number of samples to draw.

    Returns:
        The first two of the three values returned by ``truncate``, named
        ``(centers, sigmas)`` here. Presumably these are the retained
        component means and covariances (confirm against ``truncate``).
    """
    # Call the module instance rather than ``.forward`` so that any registered
    # nn.Module hooks run (assumes Predictive is a torch.nn.Module, as the
    # original ``.forward`` call implies).
    posterior_predictive = Predictive(guide, num_samples=num_samples)
    posterior_samples = posterior_predictive(data)

    mu_mean = posterior_samples['mu'].detach().mean(dim=0)
    prec_mean = posterior_samples['prec'].detach().mean(dim=0)

    # Averaged Cholesky factors, one (D, D) slice per component. Each
    # component has its own sample site, so they are gathered one by one.
    # The buffer uses torch's default dtype/device, as before.
    corr_chol_mean = torch.zeros(T, D, D)
    for t in range(T):
        corr_chol_mean[t, ...] = posterior_samples['corr_chol_{}'.format(
            t)].detach().mean(dim=0)

    beta_mean = posterior_samples['beta'].detach().mean(dim=0)
    weights_mean = mix_weights(beta_mean)

    centers, sigmas, _ = truncate(alpha, mu_mean, prec_mean, corr_chol_mean,
                                  weights_mean)

    return centers, sigmas