def linear_model_ground_truth(model, design, observation_labels, target_labels, eig=True):
    """Compute the exact (analytic) posterior entropy or EIG for a linear model.

    Every design matrix in ``design`` (leading dimensions are treated as batch
    dimensions) is run through ``analytic_posterior_cov``. The Gaussian
    differential entropy ``0.5 * log det(2 * pi * e * C)`` of the sub-covariance
    ``C`` restricted to ``target_labels`` is then taken.

    :param model: linear model exposing ``w_sds`` (a mapping of tensors) and
        ``obs_sd``.
    :param torch.Tensor design: design tensor whose last two dimensions form one
        design matrix; any leading dimensions are flattened and restored.
    :param observation_labels: unused here; kept for interface compatibility.
    :param target_labels: a site name or list of site names whose entropy is
        measured.
    :param bool eig: if true, return ``prior_entropy - posterior_entropy`` (the
        expected information gain); otherwise return the posterior entropy.
    :return: tensor of shape ``design.shape[:-2]``.
    """
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Diagonal prior covariance from the squared entries of model.w_sds.
    w_sd = torch.cat(list(model.w_sds.values()), dim=-1)
    prior_cov = torch.diag(w_sd ** 2)

    design_shape = design.shape
    # Flatten all leading batch dims so each design matrix is handled once.
    flat_designs = design.reshape(-1, design_shape[-2], design_shape[-1])
    posterior_covs = [
        analytic_posterior_cov(prior_cov, x, model.obs_sd)
        for x in torch.unbind(flat_designs)
    ]

    target_indices = get_indices(target_labels, tensors=model.w_sds)
    target_posterior_covs = [
        S[target_indices, :][:, target_indices] for S in posterior_covs
    ]

    entropies = [
        0.5 * torch.logdet(2 * math.pi * math.e * C) for C in target_posterior_covs
    ]
    # torch.stack keeps the dtype, device and autograd history of the
    # per-design entropies, whereas torch.tensor(list) would rebuild a fresh,
    # detached tensor. stack rejects an empty list, so an empty batch yields an
    # empty tensor matching the design's dtype/device instead.
    if entropies:
        output = torch.stack(entropies)
    else:
        output = design.new_empty(0)

    if eig:
        prior_entropy = mean_field_entropy(model, [design], whitelist=target_labels)
        output = prior_entropy - output
    return output.reshape(design_shape[:-2])
def loss_fn(design, num_particles, evaluation=False, **kwargs):
    """Posterior loss for ``num_particles`` joint samples at ``design``.

    ``(y, theta)`` is sampled from ``model`` and the guide is conditioned on the
    sampled ``theta`` and traced. The function returns ``(agg_loss, loss)``:

    * when ``evaluation`` and ``analytic_entropy`` are both truthy, ``loss`` is
      the guide's mean-field entropy summed over the particle dimension and
      divided by ``num_particles``, and ``agg_loss`` is its sum;
    * otherwise both come from ``_safe_mean_terms`` applied to the negative
      guide log-probability of the target sites.

    Extra ``kwargs`` are accepted and ignored.
    """
    particle_design = lexpand(design, num_particles)

    # Joint sample from p(y, theta | d).
    model_trace = poutine.trace(model).get_trace(particle_design)
    nodes = model_trace.nodes
    observations = {label: nodes[label]["value"] for label in observation_labels}
    targets = {label: nodes[label]["value"] for label in target_labels}

    # Evaluate the sampled theta under q(theta | y, d).
    guide_args = (observations, particle_design, observation_labels, target_labels)
    guide_given_theta = pyro.condition(guide, data=targets)
    guide_trace = poutine.trace(guide_given_theta).get_trace(*guide_args)
    guide_trace.compute_log_prob()

    if not (evaluation and analytic_entropy):
        neg_log_q = -sum(guide_trace.nodes[label]["log_prob"] for label in target_labels)
        agg, per_design = _safe_mean_terms(neg_log_q)
        return agg, per_design

    per_particle = mean_field_entropy(guide, list(guide_args), whitelist=target_labels)
    loss = per_particle.sum(0) / num_particles
    return loss.sum(), loss
def _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_kwargs): mean_field = prior_entropy_kwargs.get("mean_field", True) if eig: if mean_field: try: prior_entropy = mean_field_entropy(model, [design], whitelist=target_labels) except NotImplemented: prior_entropy = monte_carlo_entropy(model, design, target_labels, **prior_entropy_kwargs) else: prior_entropy = monte_carlo_entropy(model, design, target_labels, **prior_entropy_kwargs) return prior_entropy - ape else: return ape
def posterior_entropy(y_dist, design):
    """Fit the variational guide to one sampled ``y`` and return its entropy.

    A value is drawn from ``y_dist`` and row ``i`` is bound to the ``i``-th
    observation label to condition ``model``. ``svi_num_steps`` SVI steps are
    run with ``vi_parameters``, and the mean-field entropy of
    ``vi_parameters["guide"]`` over ``target_labels`` is returned.
    """
    # y_dist must be sampled *inside* this function.
    y_sample = pyro.sample("conditioning_y", y_dist)
    observed = {}
    for idx, label in enumerate(observation_labels):
        observed[label] = y_sample[idx, ...]
    model_given_y = pyro.condition(model, data=observed)

    optimiser = SVI(model_given_y, **vi_parameters)
    with poutine.block():
        for _step in range(svi_num_steps):
            optimiser.step(design)

    # Read off the entropy of the fitted guide.
    with poutine.block():
        fitted_guide = vi_parameters["guide"]
        return mean_field_entropy(fitted_guide, [design], whitelist=target_labels)
def posterior_entropy(y_dist, design):
    """Fit ``guide`` to one sampled ``y`` by MAP optimisation and return its entropy.

    A value is drawn from ``y_dist`` and row ``i`` conditions the ``i``-th
    observation label. The guide is put in training mode, optimised for
    ``num_steps`` SVI steps, and handed the final loss via ``guide.finalize``.
    Its mean-field entropy over ``target_labels`` is then returned.
    """
    # y_dist must be sampled *inside* this function.
    y_sample = pyro.sample("conditioning_y", y_dist)
    observed = {label: y_sample[idx, ...] for idx, label in enumerate(observation_labels)}
    model_given_y = pyro.condition(model, data=observed)

    # SVI is used only as an optimiser here, to run the MAP fit.
    guide.train()
    map_svi = SVI(model_given_y, guide=guide, loss=loss, optim=optim)
    with poutine.block():
        for _step in range(num_steps):
            map_svi.step(design)

    # Evaluate the final loss, give it to the guide, then read off the entropy.
    with poutine.block():
        guide.finalize(loss(model_given_y, guide, design), target_labels)
        return mean_field_entropy(guide, [design], whitelist=target_labels)
def test_guide_entropy(guide, args, expected_entropy):
    """Check that the mean-field entropy of ``guide`` run on ``args`` equals ``expected_entropy``."""
    actual_entropy = mean_field_entropy(guide, args)
    assert_equal(actual_entropy, expected_entropy)