Example 1
import math

import torch


def linear_model_ground_truth(model,
                              design,
                              observation_labels,
                              target_labels,
                              eig=True):
    # Closed-form EIG (or APE, if eig=False) for a linear-Gaussian model;
    # analytic_posterior_cov, get_indices and mean_field_entropy are Pyro
    # OED helpers assumed to be in scope.
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Diagonal prior covariance assembled from the per-parameter prior sds
    w_sd = torch.cat(list(model.w_sds.values()), dim=-1)
    prior_cov = torch.diag(w_sd**2)
    design_shape = design.shape
    # One posterior covariance per design in the batch
    posterior_covs = [
        analytic_posterior_cov(prior_cov, x, model.obs_sd)
        for x in torch.unbind(
            design.reshape(-1, design_shape[-2], design_shape[-1]))
    ]
    # Restrict to the rows/columns of the target parameters
    target_indices = get_indices(target_labels, tensors=model.w_sds)
    target_posterior_covs = [
        S[target_indices, :][:, target_indices] for S in posterior_covs
    ]
    # Entropy of a multivariate Gaussian: 0.5 * logdet(2*pi*e*C)
    output = torch.tensor([
        0.5 * torch.logdet(2 * math.pi * math.e * C)
        for C in target_posterior_covs
    ])
    if eig:
        # EIG = prior entropy - average posterior entropy
        prior_entropy = mean_field_entropy(model, [design],
                                           whitelist=target_labels)
        output = prior_entropy - output

    return output.reshape(design.shape[:-2])
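
The helper analytic_posterior_cov is not shown above. For the conjugate linear-Gaussian model y = Xw + eps with w ~ N(0, prior_cov) and eps ~ N(0, obs_sd^2 * I), it has a standard closed form; a minimal sketch under that assumption (illustrative, not Pyro's exact implementation):

import torch

def analytic_posterior_cov(prior_cov, x, obs_sd):
    # Conjugate Bayesian linear regression: the posterior covariance of w
    # is (prior_cov^{-1} + x^T x / obs_sd^2)^{-1}.
    posterior_precision = (torch.inverse(prior_cov)
                           + x.transpose(-1, -2).matmul(x) / obs_sd**2)
    return torch.inverse(posterior_precision)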
Example 2
    # Closure returned by the enclosing loss factory: model, guide,
    # observation_labels, target_labels and analytic_entropy are captured
    # from the surrounding scope.
    def loss_fn(design, num_particles, evaluation=False, **kwargs):

        # Replicate the design along a new leading particle dimension
        expanded_design = lexpand(design, num_particles)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
        theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

        # Run through q(theta | y, d)
        conditional_guide = pyro.condition(guide, data=theta_dict)
        cond_trace = poutine.trace(conditional_guide).get_trace(
            y_dict, expanded_design, observation_labels, target_labels)
        cond_trace.compute_log_prob()
        if evaluation and analytic_entropy:
            loss = mean_field_entropy(
                guide,
                [y_dict, expanded_design, observation_labels, target_labels],
                whitelist=target_labels).sum(0) / num_particles
            agg_loss = loss.sum()
        else:
            # Posterior (Barber-Agakov) loss: the negative log-probability
            # of the sampled targets under the conditioned guide
            terms = -sum(cond_trace.nodes[l]["log_prob"]
                         for l in target_labels)
            agg_loss, loss = _safe_mean_terms(terms)

        return agg_loss, loss
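
Here lexpand (from pyro.contrib.util) left-expands a tensor with extra broadcast dimensions, so expanded_design has shape (num_particles,) + design.shape and every particle shares the same design. Minimising the average of -log q(theta | y, d) tightens the lower bound EIG >= H[p(theta)] - E[-log q(theta | y, d)]. A sketch matching the expansion behaviour in plain torch:

import torch

def lexpand(a, *dimensions):
    # Add broadcast dimensions on the left: lexpand(a, n) views `a` with
    # shape (n, *a.shape) without copying data.
    return a.expand(tuple(dimensions) + a.shape)

design = torch.randn(3, 2)
assert lexpand(design, 10).shape == (10, 3, 2)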
Example 3
def _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_kwargs):
    # Convert average posterior entropy (APE) to EIG via
    # EIG(d) = H[p(theta)] - APE(d).
    mean_field = prior_entropy_kwargs.get("mean_field", True)
    if eig:
        if mean_field:
            try:
                prior_entropy = mean_field_entropy(model, [design], whitelist=target_labels)
            except NotImplementedError:
                prior_entropy = monte_carlo_entropy(model, design, target_labels, **prior_entropy_kwargs)
        else:
            prior_entropy = monte_carlo_entropy(model, design, target_labels, **prior_entropy_kwargs)
        return prior_entropy - ape
    else:
        return ape
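
The identity being applied is EIG(d) = H[p(theta)] - APE(d), where APE is the average posterior entropy E_y H[p(theta | y, d)]. A one-dimensional sanity check under a conjugate model (illustrative only): for y = theta + eps with theta ~ N(0, s0^2) and eps ~ N(0, s^2), the posterior variance does not depend on y, so APE equals the posterior entropy and the EIG reduces to 0.5 * log(1 + s0^2 / s^2):

import math

s0, s = 2.0, 1.0
post_var = 1.0 / (1.0 / s0**2 + 1.0 / s**2)
entropy = lambda v: 0.5 * math.log(2 * math.pi * math.e * v)
eig = entropy(s0**2) - entropy(post_var)
assert abs(eig - 0.5 * math.log(1 + s0**2 / s**2)) < 1e-9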
Example 4
def posterior_entropy(y_dist, design):
    # Inner function of a variational EIG estimator: model, vi_parameters,
    # svi_num_steps, observation_labels and target_labels are captured
    # from the enclosing scope.
    # Important that y_dist is sampled *within* the function
    y = pyro.sample("conditioning_y", y_dist)
    y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
    conditioned_model = pyro.condition(model, data=y_dict)
    # Fit the variational posterior q(theta | y, d) by SVI
    svi = SVI(conditioned_model, **vi_parameters)
    with poutine.block():
        for _ in range(svi_num_steps):
            svi.step(design)
    # Recover the entropy of the fitted guide
    with poutine.block():
        guide = vi_parameters["guide"]
        entropy = mean_field_entropy(guide, [design], whitelist=target_labels)
    return entropy
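
This inner function computes a single Monte Carlo term of the APE: it draws y ~ p(y | d), fits the variational posterior by SVI, and returns the entropy of the fitted guide. The enclosing estimator averages such terms over draws of y; a simplified, hypothetical sketch of that outer average (num_samples and the plain loop are illustrative, and Pyro's actual estimator may organise the sampling differently):

# Hypothetical outer loop: average the posterior entropy over y-samples
ape_estimate = sum(posterior_entropy(y_dist, design)
                   for _ in range(num_samples)) / num_samples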
Example 5
def posterior_entropy(y_dist, design):
    # Laplace variant: guide, loss, optim, num_steps, model,
    # observation_labels and target_labels are captured from the
    # enclosing scope.
    # Important that y_dist is sampled *within* the function
    y = pyro.sample("conditioning_y", y_dist)
    y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
    conditioned_model = pyro.condition(model, data=y_dict)
    # Here just using SVI to run the MAP optimization
    guide.train()
    svi = SVI(conditioned_model, guide=guide, loss=loss, optim=optim)
    with poutine.block():
        for _ in range(num_steps):
            svi.step(design)
    # Recover the entropy of the Laplace approximation around the MAP point
    with poutine.block():
        final_loss = loss(conditioned_model, guide, design)
        guide.finalize(final_loss, target_labels)
        entropy = mean_field_entropy(guide, [design], whitelist=target_labels)
    return entropy
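
Compared with Example 4, SVI is only used here to locate the MAP point; guide.finalize is then assumed to build a Gaussian (Laplace) approximation around it from the curvature of the loss, and the returned entropy is that of the approximating Gaussian. For reference, the entropy of a Laplace approximation N(theta_map, H^{-1}), with H the Hessian of the negative log posterior at the MAP, can be computed directly (a sketch, not Pyro's code):

import math

import torch

def laplace_entropy(hessian):
    # Entropy of N(theta_map, H^{-1}) = 0.5 * logdet(2*pi*e * H^{-1})
    cov = torch.inverse(hessian)
    return 0.5 * torch.logdet(2 * math.pi * math.e * cov)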
Example 6
def test_guide_entropy(guide, args, expected_entropy):
    # The analytic mean-field entropy of the guide should match the
    # expected closed-form value
    assert_equal(mean_field_entropy(guide, args), expected_entropy)
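
Assuming mean_field_entropy sums the analytic entropies of the guide's sample sites (as its use in the earlier examples suggests), a guide with a single Normal(0, 1) site should give 0.5 * log(2*pi*e) ~= 1.419. A hypothetical invocation along those lines (normal_guide is illustrative; assert_equal comes from the test utilities):

import math

import torch
import pyro
import pyro.distributions as dist

def normal_guide():
    pyro.sample("w", dist.Normal(0., 1.))

# Entropy of N(0, 1) is 0.5 * log(2*pi*e) ~= 1.4189
test_guide_entropy(normal_guide, [],
                   torch.tensor(0.5 * math.log(2 * math.pi * math.e)))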