import pyro
from pyro import poutine
from pyro.contrib.oed.search import Search
# `mean_field_entropy` is assumed to live in pyro.contrib.util, as in recent Pyro releases
from pyro.contrib.util import mean_field_entropy
from pyro.infer import SVI, EmpiricalMarginal, Importance


def _vi_ape(model, design, observation_labels, target_labels,
            vi_parameters, is_parameters, y_dist=None):
    svi_num_steps = vi_parameters.pop('num_steps')

    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
        conditioned_model = pyro.condition(model, data=y_dict)

        # Fit the variational posterior for this sampled y
        svi = SVI(conditioned_model, **vi_parameters)
        with poutine.block():
            for _ in range(svi_num_steps):
                svi.step(design)

        # Recover the entropy of the fitted guide
        with poutine.block():
            guide = vi_parameters["guide"]
            entropy = mean_field_entropy(guide, [design], whitelist=target_labels)

        return entropy

    if y_dist is None:
        y_dist = EmpiricalMarginal(Importance(model, **is_parameters).run(design),
                                   sites=observation_labels)

    # Calculate the expected posterior entropy under this distribution of y
    loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
    loss = loss_dist.mean

    return loss
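# A toy illustration (not from the source) of the Search/EmpiricalMarginal
# pattern used in `_vi_ape`: Search enumerates the support of each discrete
# sample site in the function, and EmpiricalMarginal averages the function's
# return value under the resulting weighted traces. The demo name and numbers
# are hypothetical.
def _search_marginal_demo():
    import torch
    import pyro.distributions as dist

    def squared():
        x = pyro.sample("x", dist.Categorical(torch.tensor([0.25, 0.75])))
        return x.float() ** 2

    # Search visits both support points of x; EmpiricalMarginal then averages
    # the return values under the trace weights: 0.25 * 0 + 0.75 * 1 = 0.75
    return EmpiricalMarginal(Search(squared).run()).mean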
def _laplace_vi_ape(model, design, observation_labels, target_labels,
                    guide, loss, optim, num_steps, final_num_samples, y_dist=None):

    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
        conditioned_model = pyro.condition(model, data=y_dict)

        # Here just using SVI to run the MAP optimization
        guide.train()
        svi = SVI(conditioned_model, guide=guide, loss=loss, optim=optim)
        with poutine.block():
            for _ in range(num_steps):
                svi.step(design)

        # Recover the entropy: `finalize` builds the Laplace (Gaussian)
        # approximation around the MAP point from the final loss
        with poutine.block():
            final_loss = loss(conditioned_model, guide, design)
            guide.finalize(final_loss, target_labels)
            entropy = mean_field_entropy(guide, [design], whitelist=target_labels)

        return entropy

    if y_dist is None:
        y_dist = EmpiricalMarginal(Importance(model, num_samples=final_num_samples).run(design),
                                   sites=observation_labels)

    # Calculate the expected posterior entropy under this distribution of y
    loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
    ape = loss_dist.mean

    return ape
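# A minimal sketch (assumptions, not Pyro API) of the guide interface that
# `_laplace_vi_ape` relies on: during SVI the guide is a point mass at the
# current MAP estimate, and `finalize` converts the curvature (Hessian) of
# the loss at that point into a Gaussian posterior approximation. The class
# name, constructor, and internals are all illustrative.
import torch
from torch import nn
import pyro.distributions as dist


class MAPLaplaceGuide(nn.Module):

    def __init__(self, target_label, dim):
        super().__init__()
        self.target_label = target_label
        self.loc = nn.Parameter(torch.zeros(dim))
        self.scale_tril = None

    def forward(self, design):
        if self.scale_tril is None:
            # MAP phase: a point mass at the current estimate
            pyro.sample(self.target_label, dist.Delta(self.loc).to_event(1))
        else:
            # After `finalize`: the Laplace (Gaussian) approximation
            pyro.sample(self.target_label,
                        dist.MultivariateNormal(self.loc, scale_tril=self.scale_tril))

    def finalize(self, loss, target_labels):
        # The Hessian of the loss at the MAP acts as the Gaussian precision
        grad = torch.autograd.grad(loss, [self.loc], create_graph=True)[0]
        hess = torch.stack([torch.autograd.grad(g, [self.loc], retain_graph=True)[0]
                            for g in grad])
        self.scale_tril = torch.linalg.cholesky(torch.inverse(hess))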
def vi_ape(model, design, observation_labels, target_labels,
           vi_parameters, is_parameters, y_dist=None):
    """Estimates the average posterior entropy (APE) loss function using
    variational inference (VI).

    The APE loss function estimated by this method is defined as

        :math:`APE(d)=E_{Y\\sim p(y|\\theta, d)}[H(p(\\theta|Y, d))]`

    where :math:`H[p(x)]` is the `differential entropy
    <https://en.wikipedia.org/wiki/Differential_entropy>`_.
    The APE is related to expected information gain (EIG) by the equation
    :math:`EIG(d)=H[p(\\theta)]-APE(d)`; in particular, minimising the APE
    is equivalent to maximising the EIG.

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites present in
        `model`. These sites are regarded as future observations; all other
        sites are regarded as latent variables over which a posterior is to
        be inferred.
    :param list target_labels: A subset of the sample sites over which the
        posterior entropy is to be measured.
    :param dict vi_parameters: Variational inference parameters which should
        include: `optim`, an instance of :class:`pyro.Optim`; `guide`, a guide
        function compatible with `model`; `num_steps`, the number of VI steps
        to make; and `loss`, the loss function to use for VI.
    :param dict is_parameters: Importance sampling parameters for the marginal
        distribution of :math:`Y`. May include `num_samples`, the number of
        samples to draw from the marginal.
    :param pyro.distributions.Distribution y_dist: (optional) the distribution
        assumed for the response variable :math:`Y`.
    :return: Loss function estimate
    :rtype: `torch.Tensor`
    """
    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if target_labels is not None and isinstance(target_labels, str):
        target_labels = [target_labels]

    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
        conditioned_model = pyro.condition(model, data=y_dict)
        SVI(conditioned_model, **vi_parameters).run(design)
        # Recover the entropy of the fitted guide
        return mean_field_entropy(vi_parameters["guide"], [design],
                                  whitelist=target_labels)

    if y_dist is None:
        y_dist = EmpiricalMarginal(Importance(model, **is_parameters).run(design),
                                   sites=observation_labels)

    # Calculate the expected posterior entropy under this distribution of y
    loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
    loss = loss_dist.mean

    return loss
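# A usage sketch (hypothetical model, guide, and numbers) showing how the
# arguments of `vi_ape` fit together for a one-parameter Gaussian model in
# which the scalar `design` scales the signal strength of the measurement.
def _vi_ape_demo():
    import torch
    from torch.distributions import constraints
    import pyro.distributions as dist
    from pyro.infer import Trace_ELBO
    from pyro.optim import Adam

    def model(design):
        theta = pyro.sample("theta", dist.Normal(0., 1.))
        return pyro.sample("y", dist.Normal(theta * design, 1.))

    def guide(design):
        # Mean-field Gaussian posterior over theta
        loc = pyro.param("q_loc", torch.tensor(0.))
        scale = pyro.param("q_scale", torch.tensor(1.),
                           constraint=constraints.positive)
        pyro.sample("theta", dist.Normal(loc, scale))

    design = torch.tensor(2.)
    return vi_ape(model, design,
                  observation_labels=["y"], target_labels=["theta"],
                  vi_parameters={"guide": guide, "optim": Adam({"lr": 0.01}),
                                 "loss": Trace_ELBO(), "num_steps": 500},
                  is_parameters={"num_samples": 10})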