Example #1
File: eig.py  Project: youisbaby/pyro
import pyro
import pyro.poutine as poutine
from pyro.infer import SVI, EmpiricalMarginal, Importance
from pyro.contrib.oed.search import Search
# NB: the home of mean_field_entropy varies across pyro versions; it is
# assumed here to live in pyro.contrib.util.
from pyro.contrib.util import mean_field_entropy


def _vi_ape(model, design, observation_labels, target_labels, vi_parameters, is_parameters, y_dist=None):
    # Pop num_steps so the remaining entries match SVI's keyword
    # arguments (guide, optim, loss)
    svi_num_steps = vi_parameters.pop('num_steps')

    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
        conditioned_model = pyro.condition(model, data=y_dict)
        svi = SVI(conditioned_model, **vi_parameters)
        # block() hides the inner optimisation's sample sites from the
        # enclosing Search
        with poutine.block():
            for _ in range(svi_num_steps):
                svi.step(design)
        # Recover the entropy
        with poutine.block():
            guide = vi_parameters["guide"]
            entropy = mean_field_entropy(guide, [design], whitelist=target_labels)
        return entropy

    if y_dist is None:
        y_dist = EmpiricalMarginal(Importance(model, **is_parameters).run(design),
                                   sites=observation_labels)

    # Calculate the expected posterior entropy under this distribution of y
    loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
    loss = loss_dist.mean

    return loss
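
Aside: the `y[i, ...]` unpacking in `posterior_entropy` works because `EmpiricalMarginal`, given a list of `sites`, stacks their values along a new leading dimension in list order. A quick stand-alone illustration of that behaviour (the toy model and site names are hypothetical, not from the source):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import EmpiricalMarginal, Importance

def toy_model(design):
    # Two observation sites with identical shapes, as eig.py requires
    pyro.sample("y1", dist.Normal(design, 1.))
    pyro.sample("y2", dist.Normal(-design, 1.))

posterior = Importance(toy_model, num_samples=5).run(torch.tensor(0.))
marginal = EmpiricalMarginal(posterior, sites=["y1", "y2"])
y = marginal()                 # shape (2,): one slot per site, in `sites` order
y1, y2 = y[0, ...], y[1, ...]  # the same indexing _vi_ape uses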
Example #2
File: eig.py  Project: youisbaby/pyro
# (imports as in Example #1)
def _laplace_vi_ape(model, design, observation_labels, target_labels, guide, loss, optim, num_steps,
                    final_num_samples, y_dist=None):

    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
        conditioned_model = pyro.condition(model, data=y_dict)
        # Here just using SVI to run the MAP optimization
        guide.train()  # put the (nn.Module) guide into training mode
        svi = SVI(conditioned_model, guide=guide, loss=loss, optim=optim)
        with poutine.block():
            for _ in range(num_steps):
                svi.step(design)
        # Recover the entropy
        with poutine.block():
            final_loss = loss(conditioned_model, guide, design)
            # finalize() turns the fitted MAP point into a Gaussian (Laplace)
            # approximation whose entropy is available in closed form
            guide.finalize(final_loss, target_labels)
            entropy = mean_field_entropy(guide, [design], whitelist=target_labels)
        return entropy

    if y_dist is None:
        y_dist = EmpiricalMarginal(Importance(model, num_samples=final_num_samples).run(design),
                                   sites=observation_labels)

    # Calculate the expected posterior entropy under this distribution of y
    loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
    ape = loss_dist.mean

    return ape
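
The `guide` this variant expects is not an ordinary mean-field guide: it must be callable as `guide(design)` during the MAP fit, support `.train()` (any `nn.Module` does), and expose a `finalize(loss, target_labels)` that turns the fitted MAP point into a Gaussian approximation whose entropy `mean_field_entropy` can then evaluate. A minimal single-latent sketch of that interface, assuming `loss` behaves like `Trace_ELBO().differentiable_loss` (everything below is hypothetical, not the guide the source project ships):

import torch
from torch import nn
import pyro
import pyro.distributions as dist

class ToyLaplaceGuide(nn.Module):
    """Hypothetical single-latent guide with the interface assumed above."""

    def __init__(self):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(1))  # MAP point, fitted by SVI
        self.scale = None                        # filled in by finalize()

    def forward(self, design):
        if self.scale is None:
            # During the MAP fit: a point-mass (Delta) guide, so maximising
            # the ELBO is equivalent to maximising the log joint density
            pyro.sample("theta", dist.Delta(self.loc).to_event(1))
        else:
            # After finalize(): the Gaussian Laplace approximation, whose
            # entropy is available in closed form
            pyro.sample("theta", dist.Normal(self.loc.detach(), self.scale).to_event(1))

    def finalize(self, loss, target_labels):
        # Laplace step: the curvature of the loss at the MAP point is the
        # posterior precision; invert it to get the posterior variance
        g, = torch.autograd.grad(loss, [self.loc], create_graph=True)
        h, = torch.autograd.grad(g.sum(), [self.loc])
        self.scale = h.clamp(min=1e-6).reciprocal().sqrt()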
Example #3
File: eig.py  Project: zyxue/pyro
# (imports as in Example #1, but with mean_field_guide_entropy in place of
#  mean_field_entropy in this older fork)
def vi_ape(model,
           design,
           observation_labels,
           target_labels,
           vi_parameters,
           is_parameters,
           y_dist=None):
    """Estimates the average posterior entropy (APE) loss function using
    variational inference (VI).

    The APE loss function estimated by this method is defined as

        :math:`APE(d)=E_{Y\\sim p(y|d)}[H(p(\\theta|Y, d))]`

    where :math:`H[p(x)]` is the `differential entropy
    <https://en.wikipedia.org/wiki/Differential_entropy>`_.
    The APE is related to expected information gain (EIG) by the equation

        :math:`EIG(d)=H[p(\\theta)]-APE(d)`

    In particular, minimising the APE is equivalent to maximising the EIG.

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured.
    :param dict vi_parameters: Variational inference parameters which should include:
        `optim`: an instance of :class:`pyro.optim.PyroOptim`, `guide`: a guide function
        compatible with `model`, `num_steps`: the number of VI steps to make,
        and `loss`: the loss function to use for VI
    :param dict is_parameters: Importance sampling parameters for the
        marginal distribution of :math:`Y`. May include `num_samples`: the number
        of samples to draw from the marginal.
    :param pyro.distributions.Distribution y_dist: (optional) the distribution
        assumed for the response variable :math:`Y`
    :return: Loss function estimate
    :rtype: `torch.Tensor`

    """

    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if target_labels is not None and isinstance(target_labels, str):
        target_labels = [target_labels]

    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {
            label: y[i, ...]
            for i, label in enumerate(observation_labels)
        }
        conditioned_model = pyro.condition(model, data=y_dict)
        # Old-style SVI: with num_steps supplied in vi_parameters, run()
        # performs the whole optimisation
        SVI(conditioned_model, **vi_parameters).run(design)
        # Recover the entropy
        return mean_field_guide_entropy(vi_parameters["guide"], [design],
                                        whitelist=target_labels)

    if y_dist is None:
        y_dist = EmpiricalMarginal(Importance(model,
                                              **is_parameters).run(design),
                                   sites=observation_labels)

    # Calculate the expected posterior entropy under this distribution of y
    loss_dist = EmpiricalMarginal(
        Search(posterior_entropy).run(y_dist, design))
    loss = loss_dist.mean

    return loss
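
For reference, a minimal end-to-end invocation consistent with the docstring above. The model, guide, and all settings below are hypothetical, and the optimiser/loss import paths may differ across the pyro versions these snippets were taken from:

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import Trace_ELBO
from pyro.optim import Adam

def model(design):
    # One latent ("theta", the inference target) and one future
    # observation ("y", the observation label)
    theta = pyro.sample("theta", dist.Normal(0., 1.))
    return pyro.sample("y", dist.Normal(theta * design, 1.))

def guide(design):
    loc = pyro.param("loc", torch.tensor(0.))
    scale = pyro.param("scale", torch.tensor(1.), constraint=constraints.positive)
    pyro.sample("theta", dist.Normal(loc, scale))

ape = vi_ape(model,
             design=torch.tensor(1.),
             observation_labels="y",    # promoted to ["y"] internally
             target_labels="theta",
             vi_parameters={"guide": guide, "optim": Adam({"lr": 0.05}),
                            "loss": Trace_ELBO(), "num_steps": 100},
             is_parameters={"num_samples": 10})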