Example #1
    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        N, M = num_particles
        expanded_design = lexpand(design, N)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}

        # Sample M times from q(theta | y, d) for each y
        reexpanded_design = lexpand(expanded_design, M)
        conditional_guide = pyro.condition(guide, data=y_dict)
        guide_trace = poutine.trace(conditional_guide).get_trace(
            y_dict, reexpanded_design, observation_labels, target_labels)
        theta_y_dict = {l: guide_trace.nodes[l]["value"] for l in target_labels}
        theta_y_dict.update(y_dict)
        guide_trace.compute_log_prob()

        # Re-run that through the model to compute the joint
        modelp = pyro.condition(model, data=theta_y_dict)
        model_trace = poutine.trace(modelp).get_trace(reexpanded_design)
        model_trace.compute_log_prob()

        terms = -sum(guide_trace.nodes[l]["log_prob"] for l in target_labels)
        terms += sum(model_trace.nodes[l]["log_prob"] for l in target_labels)
        terms += sum(model_trace.nodes[l]["log_prob"] for l in observation_labels)
        terms = -terms.logsumexp(0) + math.log(M)

        # At eval time, add p(y | theta, d) terms
        if evaluation:
            trace.compute_log_prob()
            terms += sum(trace.nodes[l]["log_prob"] for l in observation_labels)

        return _safe_mean_terms(terms)
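
These loss functions all return through a helper _safe_mean_terms that is not shown in these
excerpts. A hypothetical sketch, inferred from the NaN handling at the end of Example #15 and
from the (agg_loss, loss) pairs returned in Examples #6 and #10, might look like this:

    import torch

    def _safe_mean_terms(terms):
        # Average over the leading particle dimension, ignoring non-finite entries,
        # then sum the per-design losses into a scalar for the optimiser
        mask = torch.isnan(terms) | (terms == float('inf')) | (terms == float('-inf'))
        nonnan = (~mask).sum(0).type_as(terms)
        terms = terms.masked_fill(mask, 0.)
        loss = terms.sum(0) / nonnan
        agg_loss = loss.sum()
        return agg_loss, loss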
Example #2
    def __init__(self, d, w_sizes, y_sizes, regressor_init=0., scale_tril_init=3., use_softplus=True, **kwargs):
        """
        Guide for linear models. No amortisation happens over designs.
        Amortisation over data is taken care of by analytic formulae for
        linear models, which make heavy use of the known, true model structure.

        :param tuple d: the shape by which to expand the guide parameters, e.g. `(num_batches, num_designs)`.
        :param dict w_sizes: map from variable string names to int, indicating the dimension of each
                             weight vector in the linear model.
        :param dict y_sizes: map from observation site names to int, indicating the dimension of each
                             observation; these sizes set the output dimension of the regressor matrices.
        :param float regressor_init: initial value for the regressor matrix used to learn the posterior mean.
        :param float scale_tril_init: initial value for posterior `scale_tril` parameter.
        :param bool use_softplus: whether to transform the regressor by a softplus transform: useful if the
                                  regressor should be nonnegative but close to zero.
        """
        super().__init__()
        # Represent each parameter group as an independent Gaussian,
        # making a (weak) mean-field assumption across groups.
        # To avoid this, combine the labels into a single group.
        if not isinstance(d, (tuple, list, torch.Tensor)):
            d = (d,)
        self.regressor = nn.ParameterDict({l: nn.Parameter(
                regressor_init * torch.ones(*(d + (p, sum(y_sizes.values()))))) for l, p in w_sizes.items()})
        self.scale_tril = nn.ParameterDict({l: nn.Parameter(
                scale_tril_init * lexpand(torch.eye(p), *d)) for l, p in w_sizes.items()})
        self.w_sizes = w_sizes
        self.use_softplus = use_softplus
        self.softplus = nn.Softplus()
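
To make the parameter shapes concrete, here is a standalone check with hypothetical values
d=(10,), w_sizes={"w": 2} and y_sizes={"y": 3}: each regressor has shape d + (p, sum(y_sizes.values()))
and each scale_tril has shape d + (p, p).

    import torch

    d, w_sizes, y_sizes = (10,), {"w": 2}, {"y": 3}
    regressor = {l: torch.zeros(*(d + (p, sum(y_sizes.values())))) for l, p in w_sizes.items()}
    scale_tril = {l: 3. * torch.eye(p).expand(*(d + (p, p))) for l, p in w_sizes.items()}
    assert regressor["w"].shape == (10, 2, 3)
    assert scale_tril["w"].shape == (10, 2, 2)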
Example #3
    def loss_fn(design, num_particles, evaluation=False, **kwargs):

        try:
            pyro.module("h", h)
        except AssertionError:
            pass

        expanded_design = lexpand(design, num_particles)
        model_conditional_trace = poutine.trace(model_conditional).get_trace(
            expanded_design)

        if not evaluation:
            model_marginal_trace = poutine.trace(model_marginal).get_trace(
                expanded_design)

            h_joint = h(expanded_design, model_conditional_trace,
                        observation_labels, target_labels)
            h_independent = h(expanded_design, model_marginal_trace,
                              observation_labels, target_labels)

            terms = torch.nn.functional.softplus(
                -h_joint) + torch.nn.functional.softplus(h_independent)
            return _safe_mean_terms(terms)

        else:
            h_joint = h(expanded_design, model_conditional_trace,
                        observation_labels, target_labels)
            return _safe_mean_terms(h_joint)
Example #4
    def loss_fn(design, num_particles, **kwargs):

        try:
            pyro.module("T", T)
        except AssertionError:
            pass

        expanded_design = lexpand(design, num_particles)

        # Unshuffled data
        unshuffled_trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: unshuffled_trace.nodes[l]["value"] for l in observation_labels}

        # "Shuffled" data: rather than literally shuffling, condition on the sampled y
        # and re-run the model so that theta is resampled independently of y
        conditional_model = pyro.condition(model, data=y_dict)
        shuffled_trace = poutine.trace(conditional_model).get_trace(expanded_design)

        T_joint = T(expanded_design, unshuffled_trace, observation_labels, target_labels)
        T_independent = T(expanded_design, shuffled_trace, observation_labels, target_labels)

        joint_expectation = T_joint.sum(0)/num_particles

        A = T_independent - math.log(num_particles)
        s, _ = torch.max(A, dim=0)
        independent_expectation = s + ewma_log((A - s).exp().sum(dim=0), s)

        loss = joint_expectation - independent_expectation
        # Switch sign, sum over batch dimensions for scalar loss
        agg_loss = -loss.sum()
        return agg_loss, loss
Example #5
    def loss_fn(design, num_particles, evaluation=False, **kwargs):

        expanded_design = lexpand(design, num_particles)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
        theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

        # Run through q(y | d)
        qyd = pyro.condition(marginal_guide, data=y_dict)
        marginal_trace = poutine.trace(qyd).get_trace(
             expanded_design, observation_labels, target_labels)
        marginal_trace.compute_log_prob()

        # Run through q(y | theta, d)
        qythetad = pyro.condition(likelihood_guide, data=y_dict)
        cond_trace = poutine.trace(qythetad).get_trace(
                theta_dict, expanded_design, observation_labels, target_labels)
        cond_trace.compute_log_prob()
        terms = -sum(marginal_trace.nodes[l]["log_prob"] for l in observation_labels)

        # At evaluation time, use the correct estimator: q(y | theta, d) - q(y | d)
        # At training time, use -q(y | theta, d) - q(y | d) so that both guides are trained
        # by maximum likelihood (their gradients point in the same direction)
        if evaluation:
            terms += sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels)
        else:
            terms -= sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels)

        return _safe_mean_terms(terms)
Example #6
    def loss_fn(design, num_particles, evaluation=False, **kwargs):

        expanded_design = lexpand(design, num_particles)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
        theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

        # Run through q(theta | y, d)
        conditional_guide = pyro.condition(guide, data=theta_dict)
        cond_trace = poutine.trace(conditional_guide).get_trace(
            y_dict, expanded_design, observation_labels, target_labels)
        cond_trace.compute_log_prob()
        if evaluation and analytic_entropy:
            loss = mean_field_entropy(
                guide,
                [y_dict, expanded_design, observation_labels, target_labels],
                whitelist=target_labels).sum(0) / num_particles
            agg_loss = loss.sum()
        else:
            terms = -sum(cond_trace.nodes[l]["log_prob"]
                         for l in target_labels)
            agg_loss, loss = _safe_mean_terms(terms)

        return agg_loss, loss
Example #7
def lfire_eig(model, design, observation_labels, target_labels,
              num_y_samples, num_theta_samples, num_steps, classifier, optim, return_history=False,
              final_design=None, final_num_samples=None):
    """Estimates the EIG using the method of Likelihood-Free Inference by Ratio Estimation (LFIRE) as in [1].
    LFIRE is run separately for several samples of :math:`\\theta`.

    [1] Kleinegesse, Steven, and Michael Gutmann. "Efficient Bayesian Experimental Design for Implicit Models."
    arXiv preprint arXiv:1810.09912 (2018).

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured.
    :param int num_y_samples: Number of samples to take in :math:`y` for each :math:`\\theta`.
    :param int num_theta_samples: Number of initial samples in :math:`\\theta` to take. The likelihood ratio
                                  is estimated by LFIRE for each sample.
    :param int num_steps: Number of optimization steps.
    :param function classifier: a Pytorch or Pyro classifier used to distinguish between samples of :math:`y` under
                                :math:`p(y|d)` and samples under :math:`p(y|\\theta,d)` for some :math:`\\theta`.
    :param pyro.optim.Optim optim: Optimiser to use.
    :param bool return_history: If `True`, also returns a tensor giving the loss function
        at each step of the optimization.
    :param torch.Tensor final_design: The final design tensor to evaluate at. If `None`, uses
        `design`.
    :param int final_num_samples: The number of samples to use at the final evaluation. If `None`,
        uses `num_y_samples`.
    :return: EIG estimate, optionally includes full optimization history
    :rtype: torch.Tensor or tuple
    """
    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Take N samples of the model
    expanded_design = lexpand(design, num_theta_samples)
    trace = poutine.trace(model).get_trace(expanded_design)

    theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}
    cond_model = pyro.condition(model, data=theta_dict)

    loss = _lfire_loss(model, cond_model, classifier, observation_labels, target_labels)
    out = opt_eig_ape_loss(expanded_design, loss, num_y_samples, num_steps, optim, return_history,
                           final_design, final_num_samples)
    if return_history:
        return out[0], out[1].sum(0) / num_theta_samples
    else:
        return out.sum(0) / num_theta_samples
Example #8
File: util.py Project: zyxue/pyro
def mc_H_prior(model,
               design,
               observation_labels,
               target_labels,
               num_samples=1000):
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    expanded_design = lexpand(design, num_samples)
    trace = pyro.poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()
    lp = sum(trace.nodes[l]["log_prob"] for l in target_labels)
    return -lp.sum(0) / num_samples
Example #9
def monte_carlo_entropy(model, design, target_labels, num_prior_samples=1000):
    """Computes a Monte Carlo estimate of the entropy of `model` assuming that each of sites in `target_labels` is
    independent and the entropy is to be computed for that subset of sites only.
    """

    if isinstance(target_labels, str):
        target_labels = [target_labels]

    expanded_design = lexpand(design, num_prior_samples)
    trace = pyro.poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()
    lp = sum(trace.nodes[l]["log_prob"] for l in target_labels)
    return -lp.sum(0) / num_prior_samples
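
As a quick sanity check, the estimate can be compared to a known analytic entropy. A minimal
sketch with a hypothetical toy model (assuming the pyro and lexpand imports used above are in
scope so that monte_carlo_entropy itself runs):

    import math
    import torch
    import pyro
    import pyro.distributions as dist

    def toy_model(design):
        # One independent theta ~ Normal(0, 1) per element of the design tensor
        with pyro.plate_stack("plates", design.shape):
            return pyro.sample("theta", dist.Normal(0., 1.))

    design = torch.zeros(3)
    entropy = monte_carlo_entropy(toy_model, design, "theta", num_prior_samples=2000)
    # Each entry should be close to the entropy of a standard normal, 0.5 * log(2 * pi * e) ~ 1.42
    print(entropy, 0.5 * math.log(2 * math.pi * math.e))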
Example #10
File: eig.py Project: zyxue/pyro
    def loss_fn(design, num_particles):

        expanded_design = lexpand(design, num_particles)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
        theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

        # Run through q(theta | y, d)
        conditional_guide = pyro.condition(guide, data=theta_dict)
        cond_trace = poutine.trace(conditional_guide).get_trace(
            y_dict, expanded_design, observation_labels, target_labels)
        cond_trace.compute_log_prob()

        loss = -sum(cond_trace.nodes[l]["log_prob"]
                    for l in target_labels).sum(0) / num_particles
        agg_loss = loss.sum()
        return agg_loss, loss
Example #11
    def loss_fn(design, num_particles, evaluation=False, **kwargs):

        expanded_design = lexpand(design, num_particles)

        # Sample from p(y | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}

        # Run through q(y | d)
        conditional_guide = pyro.condition(guide, data=y_dict)
        cond_trace = poutine.trace(conditional_guide).get_trace(
             expanded_design, observation_labels, target_labels)
        cond_trace.compute_log_prob()

        terms = -sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels)

        # At eval time, add p(y | theta, d) terms
        if evaluation:
            trace.compute_log_prob()
            terms += sum(trace.nodes[l]["log_prob"] for l in observation_labels)

        return _safe_mean_terms(terms)
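
The guide assumed by this loss is a marginal guide q(y | d): it only needs to accept the
arguments shown above and contain a sample site for each observation label. A minimal
hypothetical sketch for a single scalar Gaussian observation, assuming a design tensor of shape
(n_designs, feature_dim) so that the parameters broadcast over the particle dimension that
lexpand prepends:

    import torch
    import pyro
    import pyro.distributions as dist
    from torch.distributions import constraints

    def marginal_guide(design, observation_labels, target_labels):
        # Index the variational parameters by the design dimension only, so they
        # broadcast across the leading num_particles dimension of expanded_design
        n_designs = design.shape[-2]
        mu = pyro.param("q_mean", torch.zeros(n_designs))
        sd = pyro.param("q_sd", torch.ones(n_designs), constraint=constraints.positive)
        pyro.sample(observation_labels[0], dist.Normal(mu, sd))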
Example #12
def test_rvv(a, b):
    assert_equal(rvv(a, b), torch.dot(a, b), prec=1e-8)
    batched_a = lexpand(a, 5, 4)
    batched_b = lexpand(b, 5, 4)
    expected_ab = lexpand(torch.dot(a, b), 5, 4)
    assert_equal(rvv(batched_a, batched_b), expected_ab, prec=1e-8)
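
From this test, rvv is a dot product over the rightmost dimension, batched over any leading
dimensions. A plausible minimal implementation (the real helper lives in pyro.contrib.util):

    import torch

    def rvv(a, b):
        # Batched vector-vector (dot) product over the rightmost dimension
        return (a * b).sum(-1)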
Example #13
def test_lexpand():
    A = torch.tensor([[1., 2.], [-2., 0]])
    assert_equal(lexpand(A), A, prec=1e-8)
    assert_equal(lexpand(A, 4), A.expand(4, 2, 2), prec=1e-8)
    assert_equal(lexpand(A, 4, 2), A.expand(4, 2, 2, 2), prec=1e-8)
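
These cases pin down lexpand as "expand by adding broadcast dimensions on the left". A plausible
minimal implementation consistent with the test (the real helper lives in pyro.contrib.util):

    import torch

    def lexpand(A, *dimensions):
        # Prepend new broadcast dimensions of the given sizes to the left of A
        return A.expand(tuple(dimensions) + A.shape)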
Example #14
def test_rtril():
    A = torch.tensor([[1., 2.], [-2., 0]])
    assert_equal(rtril(A), torch.tril(A), prec=1e-8)
    expanded = lexpand(A, 5, 4)
    expected = lexpand(torch.tril(A), 5, 4)
    assert_equal(rtril(expanded), expected, prec=1e-8)
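
The behaviour checked here matches torch.tril, which already acts on the last two dimensions and
batches over any leading ones, so a minimal version consistent with this test is simply:

    import torch

    def rtril(A):
        # Lower-triangular part of the rightmost two dimensions, batched on the left
        return torch.tril(A)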
Example #15
def nmc_eig(model, design, observation_labels, target_labels=None,
            N=100, M=10, M_prime=None, independent_priors=False):
    """Nested Monte Carlo estimate of the expected information
    gain (EIG). When there are no random effects, the estimate is

    .. math::

        \\frac{1}{N}\\sum_{n=1}^N \\log p(y_n | \\theta_n, d) -
        \\frac{1}{N}\\sum_{n=1}^N \\log \\left(\\frac{1}{M}\\sum_{m=1}^M p(y_n | \\theta_m, d)\\right)

    where :math:`\\theta_n, y_n \\sim p(\\theta, y | d)` and :math:`\\theta_m \\sim p(\\theta)`.
    The estimate in the presence of random effects is

    .. math::

        \\frac{1}{N}\\sum_{n=1}^N  \\log \\left(\\frac{1}{M'}\\sum_{m=1}^{M'}
        p(y_n | \\theta_n, \\widetilde{\\theta}_{nm}, d)\\right)-
        \\frac{1}{N}\\sum_{n=1}^N \\log \\left(\\frac{1}{M}\\sum_{m=1}^{M}
        p(y_n | \\theta_m, \\widetilde{\\theta}_{m}, d)\\right)

    where :math:`\\widetilde{\\theta}` are the random effects with
    :math:`\\widetilde{\\theta}_{nm} \\sim p(\\widetilde{\\theta}|\\theta=\\theta_n)` and
    :math:`\\theta_m,\\widetilde{\\theta}_m \\sim p(\\theta,\\widetilde{\\theta})`.
    The latter form is used when `M_prime != None`.

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured.
    :param int N: Number of outer expectation samples.
    :param int M: Number of inner expectation samples for `p(y|d)`.
    :param int M_prime: Number of samples for `p(y | theta, d)` if required.
    :param bool independent_priors: Only used when `M_prime` is not `None`. Indicates whether the prior distributions
        for the target variables and the nuisance variables are independent. In this case, it is not necessary to
        sample the targets conditional on the nuisance variables.
    :return: EIG estimate
    :rtype: torch.Tensor
    """

    if isinstance(observation_labels, str):  # allow single site names to be passed as strings
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Take N samples of the model
    expanded_design = lexpand(design, N)  # N independent copies of the design
    trace = poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()

    if M_prime is not None:
        y_dict = {l: lexpand(trace.nodes[l]["value"], M_prime) for l in observation_labels}
        theta_dict = {l: lexpand(trace.nodes[l]["value"], M_prime) for l in target_labels}
        theta_dict.update(y_dict)
        # Resample M values of u and compute conditional probabilities
        # WARNING: currently the use of condition does not actually sample
        # the conditional distribution!
        # We need to use some importance weighting
        conditional_model = pyro.condition(model, data=theta_dict)
        if independent_priors:
            reexpanded_design = lexpand(design, M_prime, 1)
        else:
            # Not acceptable to use (M_prime, 1) here - other variables may occur after
            # theta, so need to be sampled conditional upon it
            reexpanded_design = lexpand(design, M_prime, N)
        retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
        retrace.compute_log_prob()
        conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
            - math.log(M_prime)
    else:
        # This assumes that the observations y are conditionally independent given theta
        # and that there are no other latent variables besides theta
        conditional_lp = sum(trace.nodes[l]["log_prob"] for l in observation_labels)

    y_dict = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}
    # Resample M values of theta and compute conditional probabilities
    conditional_model = pyro.condition(model, data=y_dict)
    # Using (M, 1) instead of (M, N) - acceptable to re-use thetas between ys because
    # theta comes before y in graphical model
    reexpanded_design = lexpand(design, M, 1)  # sample M theta
    retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
    retrace.compute_log_prob()
    marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
        - math.log(M)

    terms = conditional_lp - marginal_lp
    nonnan = (~torch.isnan(terms)).sum(0).type_as(terms)
    terms[torch.isnan(terms)] = 0.
    return terms.sum(0)/nonnan
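
A minimal usage sketch with a hypothetical toy model, assuming this function is importable as
pyro.contrib.oed.eig.nmc_eig. The model must broadcast over the extra leading dimensions that
lexpand prepends to the design, which pyro.plate_stack handles:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.contrib.oed.eig import nmc_eig

    def toy_model(design):
        # design has shape (..., n_designs, 1): one theta and one y per design point
        with pyro.plate_stack("plates", design.shape[:-1]):
            theta = pyro.sample("theta", dist.Normal(0., 1.))
            return pyro.sample("y", dist.Normal(theta * design[..., 0], 1.))

    candidate_designs = torch.linspace(0.1, 2.0, 5).unsqueeze(-1)  # 5 designs, 1 feature each
    eig = nmc_eig(toy_model, candidate_designs, observation_labels="y",
                  target_labels="theta", N=1000, M=50)
    print(eig)  # one EIG estimate per candidate design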
Example #16
File: eig.py Project: zyxue/pyro
def naive_rainforth_eig(model,
                        design,
                        observation_labels,
                        target_labels=None,
                        N=100,
                        M=10,
                        M_prime=None):
    """
    Naive Rainforth (i.e. Nested Monte Carlo) estimate of the expected information
    gain (EIG). The estimate is

    .. math::

        \\frac{1}{N}\\sum_{n=1}^N \\left[ \\log p(y_n | \\theta_n, d) -
        \\log \\left(\\frac{1}{M}\\sum_{m=1}^M p(y_n | \\theta_m, d)\\right) \\right]

    Monte Carlo estimation is attempted for the :math:`\\log p(y | \\theta, d)` term if
    the parameter `M_prime` is passed. Otherwise, it is assumed that :math:`\\log p(y | \\theta, d)`
    can safely be read from the model itself.

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured.
    :param int N: Number of outer expectation samples.
    :param int M: Number of inner expectation samples for `p(y|d)`.
    :param int M_prime: Number of samples for `p(y | theta, d)` if required.
    :return: EIG estimate
    :rtype: `torch.Tensor`
    """

    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Take N samples of the model
    expanded_design = lexpand(design, N)
    trace = poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()

    if M_prime is not None:
        y_dict = {
            l: lexpand(trace.nodes[l]["value"], M_prime)
            for l in observation_labels
        }
        theta_dict = {
            l: lexpand(trace.nodes[l]["value"], M_prime)
            for l in target_labels
        }
        theta_dict.update(y_dict)
        # Resample M values of u and compute conditional probabilities
        conditional_model = pyro.condition(model, data=theta_dict)
        # Not acceptable to use (M_prime, 1) here - other variables may occur after
        # theta, so need to be sampled conditional upon it
        reexpanded_design = lexpand(design, M_prime, N)
        retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
        retrace.compute_log_prob()
        conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
            - math.log(M_prime)
    else:
        # This assumes that the observations y are conditionally independent given theta
        # and that there are no other latent variables besides theta
        conditional_lp = sum(trace.nodes[l]["log_prob"]
                             for l in observation_labels)

    y_dict = {
        l: lexpand(trace.nodes[l]["value"], M)
        for l in observation_labels
    }
    # Resample M values of theta and compute conditional probabilities
    conditional_model = pyro.condition(model, data=y_dict)
    # Using (M, 1) instead of (M, N) - acceptable to re-use thetas between ys because
    # theta comes before y in graphical model
    reexpanded_design = lexpand(design, M, 1)
    retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
    retrace.compute_log_prob()
    marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
        - math.log(M)

    return (conditional_lp - marginal_lp).sum(0) / N
Example #17
def test_rmv(A, b):
    assert_equal(rmv(A, b), A.mv(b), prec=1e-8)
    batched_A = lexpand(A, 5, 4)
    batched_b = lexpand(b, 5, 4)
    expected_Ab = lexpand(A.mv(b), 5, 4)
    assert_equal(rmv(batched_A, batched_b), expected_Ab, prec=1e-8)
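
Consistent with this test, rmv is a matrix-vector product over the rightmost dimensions, batched
over any leading ones. A plausible minimal implementation (the real helper lives in
pyro.contrib.util):

    import torch

    def rmv(A, b):
        # Batched matrix-vector product: treat b as a column vector on the right
        return A.matmul(b.unsqueeze(-1)).squeeze(-1)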
Example #18
def test_rdiag():
    v = torch.tensor([1., 2., -1.])
    assert_equal(rdiag(v), torch.diag(v), prec=1e-8)
    expanded = lexpand(v, 5, 4)
    expected = lexpand(torch.diag(v), 5, 4)
    assert_equal(rdiag(expanded), expected, prec=1e-8)
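
Finally, rdiag embeds the rightmost dimension as the diagonal of a square matrix, batching over
leading dimensions; torch.diag_embed has exactly this behaviour, so a plausible minimal
implementation (the real helper lives in pyro.contrib.util) is:

    import torch

    def rdiag(v):
        # Build a diagonal matrix from the rightmost dimension of v
        return torch.diag_embed(v)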