def loss_fn(design, num_particles, evaluation=False, **kwargs):
    N, M = num_particles
    expanded_design = lexpand(design, N)

    # Sample from p(y, theta | d)
    trace = poutine.trace(model).get_trace(expanded_design)
    y_dict = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}

    # Sample M times from q(theta | y, d) for each y
    reexpanded_design = lexpand(expanded_design, M)
    conditional_guide = pyro.condition(guide, data=y_dict)
    guide_trace = poutine.trace(conditional_guide).get_trace(
        y_dict, reexpanded_design, observation_labels, target_labels)
    theta_y_dict = {l: guide_trace.nodes[l]["value"] for l in target_labels}
    theta_y_dict.update(y_dict)
    guide_trace.compute_log_prob()

    # Re-run that through the model to compute the joint
    modelp = pyro.condition(model, data=theta_y_dict)
    model_trace = poutine.trace(modelp).get_trace(reexpanded_design)
    model_trace.compute_log_prob()

    terms = -sum(guide_trace.nodes[l]["log_prob"] for l in target_labels)
    terms += sum(model_trace.nodes[l]["log_prob"] for l in target_labels)
    terms += sum(model_trace.nodes[l]["log_prob"] for l in observation_labels)
    terms = -terms.logsumexp(0) + math.log(M)

    # At eval time, add p(y | theta, d) terms
    if evaluation:
        trace.compute_log_prob()
        terms += sum(trace.nodes[l]["log_prob"] for l in observation_labels)

    return _safe_mean_terms(terms)
def __init__(self, d, w_sizes, y_sizes, regressor_init=0., scale_tril_init=3.,
             use_softplus=True, **kwargs):
    """
    Guide for linear models. No amortisation happens over designs.
    Amortisation over data is taken care of by analytic formulae for
    linear models (heavy use of truth).

    :param tuple d: the shape by which to expand the guide parameters, e.g.
        `(num_batches, num_designs)`.
    :param dict w_sizes: map from variable string names to int, indicating the
        dimension of each weight vector in the linear model.
    :param float regressor_init: initial value for the regressor matrix used to
        learn the posterior mean.
    :param float scale_tril_init: initial value for the posterior `scale_tril`
        parameter.
    :param bool use_softplus: whether to transform the regressor by a softplus
        transform: useful if the regressor should be nonnegative but close to zero.
    """
    super().__init__()
    # Represent each parameter group as an independent Gaussian,
    # making a weak mean-field assumption.
    # To avoid this assumption, combine labels.
    if not isinstance(d, (tuple, list, torch.Tensor)):
        d = (d,)
    self.regressor = nn.ParameterDict({l: nn.Parameter(
        regressor_init * torch.ones(*(d + (p, sum(y_sizes.values())))))
        for l, p in w_sizes.items()})
    self.scale_tril = nn.ParameterDict({l: nn.Parameter(
        scale_tril_init * lexpand(torch.eye(p), *d))
        for l, p in w_sizes.items()})
    self.w_sizes = w_sizes
    self.use_softplus = use_softplus
    self.softplus = nn.Softplus()
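# A minimal construction sketch for the guide above. The class name
# `LinearModelPosteriorGuide` and the concrete sizes are illustrative
# assumptions, not taken from the snippet; the only requirements implied by
# `__init__` are that `d` gives the leading expansion shape and that `w_sizes`
# and `y_sizes` map site names to dimensions.
guide = LinearModelPosteriorGuide(
    d=(10,),              # 10 candidate designs
    w_sizes={"w": 2},     # one weight vector "w" of dimension 2
    y_sizes={"y": 3},     # observation site "y" of dimension 3
    regressor_init=0.,
    scale_tril_init=3.,
    use_softplus=True,
)
# With these sizes, guide.regressor["w"] has shape (10, 2, 3) and
# guide.scale_tril["w"] has shape (10, 2, 2).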
def loss_fn(design, num_particles, evaluation=False, **kwargs):
    try:
        pyro.module("h", h)
    except AssertionError:
        pass
    expanded_design = lexpand(design, num_particles)
    model_conditional_trace = poutine.trace(model_conditional).get_trace(
        expanded_design)

    if not evaluation:
        model_marginal_trace = poutine.trace(model_marginal).get_trace(
            expanded_design)

        h_joint = h(expanded_design, model_conditional_trace,
                    observation_labels, target_labels)
        h_independent = h(expanded_design, model_marginal_trace,
                          observation_labels, target_labels)

        terms = torch.nn.functional.softplus(-h_joint) + \
            torch.nn.functional.softplus(h_independent)
        return _safe_mean_terms(terms)
    else:
        h_joint = h(expanded_design, model_conditional_trace,
                    observation_labels, target_labels)
        return _safe_mean_terms(h_joint)
def loss_fn(design, num_particles, **kwargs):
    try:
        pyro.module("T", T)
    except AssertionError:
        pass
    expanded_design = lexpand(design, num_particles)

    # Unshuffled data
    unshuffled_trace = poutine.trace(model).get_trace(expanded_design)
    y_dict = {l: unshuffled_trace.nodes[l]["value"] for l in observation_labels}

    # Shuffled data
    # Not actually shuffling; re-simulate for safety
    conditional_model = pyro.condition(model, data=y_dict)
    shuffled_trace = poutine.trace(conditional_model).get_trace(expanded_design)

    T_joint = T(expanded_design, unshuffled_trace, observation_labels, target_labels)
    T_independent = T(expanded_design, shuffled_trace, observation_labels, target_labels)

    joint_expectation = T_joint.sum(0) / num_particles

    A = T_independent - math.log(num_particles)
    s, _ = torch.max(A, dim=0)
    independent_expectation = s + ewma_log((A - s).exp().sum(dim=0), s)

    loss = joint_expectation - independent_expectation
    # Switch sign, sum over batch dimensions for scalar loss
    agg_loss = -loss.sum()
    return agg_loss, loss
def loss_fn(design, num_particles, evaluation=False, **kwargs): expanded_design = lexpand(design, num_particles) # Sample from p(y | d) trace = poutine.trace(model).get_trace(expanded_design) y_dict = {l: trace.nodes[l]["value"] for l in observation_labels} theta_dict = {l: trace.nodes[l]["value"] for l in target_labels} # Run through q(y | d) qyd = pyro.condition(marginal_guide, data=y_dict) marginal_trace = poutine.trace(qyd).get_trace( expanded_design, observation_labels, target_labels) marginal_trace.compute_log_prob() # Run through q(y | theta, d) qythetad = pyro.condition(likelihood_guide, data=y_dict) cond_trace = poutine.trace(qythetad).get_trace( theta_dict, expanded_design, observation_labels, target_labels) cond_trace.compute_log_prob() terms = -sum(marginal_trace.nodes[l]["log_prob"] for l in observation_labels) # At evaluation time, use the right estimator, q(y | theta, d) - q(y | d) # At training time, use -q(y | theta, d) - q(y | d) so gradients go the same way if evaluation: terms += sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels) else: terms -= sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels) return _safe_mean_terms(terms)
def loss_fn(design, num_particles, evaluation=False, **kwargs):
    expanded_design = lexpand(design, num_particles)

    # Sample from p(y, theta | d)
    trace = poutine.trace(model).get_trace(expanded_design)
    y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
    theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

    # Run through q(theta | y, d)
    conditional_guide = pyro.condition(guide, data=theta_dict)
    cond_trace = poutine.trace(conditional_guide).get_trace(
        y_dict, expanded_design, observation_labels, target_labels)
    cond_trace.compute_log_prob()

    if evaluation and analytic_entropy:
        loss = mean_field_entropy(
            guide,
            [y_dict, expanded_design, observation_labels, target_labels],
            whitelist=target_labels).sum(0) / num_particles
        agg_loss = loss.sum()
    else:
        terms = -sum(cond_trace.nodes[l]["log_prob"] for l in target_labels)
        agg_loss, loss = _safe_mean_terms(terms)

    return agg_loss, loss
def lfire_eig(model, design, observation_labels, target_labels,
              num_y_samples, num_theta_samples, num_steps, classifier,
              optim, return_history=False, final_design=None, final_num_samples=None):
    """Estimates the EIG using the method of Likelihood-Free Inference by
    Ratio Estimation (LFIRE) as in [1]. LFIRE is run separately for several
    samples of :math:`\\theta`.

    [1] Kleinegesse, Steven, and Michael Gutmann. "Efficient Bayesian Experimental
    Design for Implicit Models." arXiv preprint arXiv:1810.09912 (2018).

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured.
    :param int num_y_samples: Number of samples to take in :math:`y` for each
        :math:`\\theta`.
    :param int num_theta_samples: Number of initial samples in :math:`\\theta`
        to take. The likelihood ratio is estimated by LFIRE for each sample.
    :param int num_steps: Number of optimization steps.
    :param function classifier: a PyTorch or Pyro classifier used to distinguish
        between samples of :math:`y` under :math:`p(y|d)` and samples under
        :math:`p(y|\\theta,d)` for some :math:`\\theta`.
    :param pyro.optim.Optim optim: Optimiser to use.
    :param bool return_history: If `True`, also returns a tensor giving the loss
        function at each step of the optimization.
    :param torch.Tensor final_design: The final design tensor to evaluate at.
        If `None`, uses `design`.
    :param int final_num_samples: The number of samples to use at the final
        evaluation. If `None`, uses `num_samples`.
    :return: EIG estimate, optionally includes full optimization history
    :rtype: torch.Tensor or tuple
    """
    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Take num_theta_samples samples of the model
    expanded_design = lexpand(design, num_theta_samples)
    trace = poutine.trace(model).get_trace(expanded_design)
    theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}
    cond_model = pyro.condition(model, data=theta_dict)

    loss = _lfire_loss(model, cond_model, classifier, observation_labels, target_labels)
    out = opt_eig_ape_loss(expanded_design, loss, num_y_samples, num_steps, optim,
                           return_history, final_design, final_num_samples)
    if return_history:
        return out[0], out[1].sum(0) / num_theta_samples
    else:
        return out.sum(0) / num_theta_samples
def mc_H_prior(model, design, observation_labels, target_labels,
               num_samples=1000):
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    expanded_design = lexpand(design, num_samples)
    trace = pyro.poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()
    lp = sum(trace.nodes[l]["log_prob"] for l in target_labels)
    return -lp.sum(0) / num_samples
def monte_carlo_entropy(model, design, target_labels, num_prior_samples=1000):
    """Computes a Monte Carlo estimate of the entropy of `model`, assuming that
    each of the sites in `target_labels` is independent and that the entropy is
    to be computed for that subset of sites only.
    """
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    expanded_design = lexpand(design, num_prior_samples)
    trace = pyro.poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()
    lp = sum(trace.nodes[l]["log_prob"] for l in target_labels)
    return -lp.sum(0) / num_prior_samples
def loss_fn(design, num_particles):
    expanded_design = lexpand(design, num_particles)

    # Sample from p(y, theta | d)
    trace = poutine.trace(model).get_trace(expanded_design)
    y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
    theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

    # Run through q(theta | y, d)
    conditional_guide = pyro.condition(guide, data=theta_dict)
    cond_trace = poutine.trace(conditional_guide).get_trace(
        y_dict, expanded_design, observation_labels, target_labels)
    cond_trace.compute_log_prob()

    loss = -sum(cond_trace.nodes[l]["log_prob"]
                for l in target_labels).sum(0) / num_particles
    agg_loss = loss.sum()
    return agg_loss, loss
def loss_fn(design, num_particles, evaluation=False, **kwargs): expanded_design = lexpand(design, num_particles) # Sample from p(y | d) trace = poutine.trace(model).get_trace(expanded_design) y_dict = {l: trace.nodes[l]["value"] for l in observation_labels} # Run through q(y | d) conditional_guide = pyro.condition(guide, data=y_dict) cond_trace = poutine.trace(conditional_guide).get_trace( expanded_design, observation_labels, target_labels) cond_trace.compute_log_prob() terms = -sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels) # At eval time, add p(y | theta, d) terms if evaluation: trace.compute_log_prob() terms += sum(trace.nodes[l]["log_prob"] for l in observation_labels) return _safe_mean_terms(terms)
def test_rvv(a, b):
    assert_equal(rvv(a, b), torch.dot(a, b), prec=1e-8)
    batched_a = lexpand(a, 5, 4)
    batched_b = lexpand(b, 5, 4)
    expected_ab = lexpand(torch.dot(a, b), 5, 4)
    assert_equal(rvv(batched_a, batched_b), expected_ab, prec=1e-8)
def test_lexpand():
    A = torch.tensor([[1., 2.], [-2., 0]])
    assert_equal(lexpand(A), A, prec=1e-8)
    assert_equal(lexpand(A, 4), A.expand(4, 2, 2), prec=1e-8)
    assert_equal(lexpand(A, 4, 2), A.expand(4, 2, 2, 2), prec=1e-8)
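# A small sketch of what the test above exercises (shapes follow directly from
# the assertions): `lexpand` adds new dimensions on the left and expands the
# tensor into them, while the right-handed helpers tested in this file
# (`rvv`, `rmv`, `rtril`, `rdiag`) operate on the rightmost dimensions and
# treat the leading dimensions as batch dimensions.
A = torch.tensor([[1., 2.], [-2., 0]])
print(lexpand(A).shape)        # torch.Size([2, 2])       - no extra dims requested
print(lexpand(A, 4).shape)     # torch.Size([4, 2, 2])    - one new leading dim of size 4
print(lexpand(A, 4, 2).shape)  # torch.Size([4, 2, 2, 2]) - two new leading dims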
def test_rtril():
    A = torch.tensor([[1., 2.], [-2., 0]])
    assert_equal(rtril(A), torch.tril(A), prec=1e-8)
    expanded = lexpand(A, 5, 4)
    expected = lexpand(torch.tril(A), 5, 4)
    assert_equal(rtril(expanded), expected, prec=1e-8)
def nmc_eig(model, design, observation_labels, target_labels=None,
            N=100, M=10, M_prime=None, independent_priors=False):
    """Nested Monte Carlo estimate of the expected information gain (EIG).

    The estimate is, in the absence of random effects,

    .. math::

        \\frac{1}{N}\\sum_{n=1}^N \\log p(y_n | \\theta_n, d) -
        \\frac{1}{N}\\sum_{n=1}^N \\log \\left(\\frac{1}{M}\\sum_{m=1}^M p(y_n | \\theta_m, d)\\right)

    where :math:`\\theta_n, y_n \\sim p(\\theta, y | d)` and :math:`\\theta_m \\sim p(\\theta)`.
    The estimate in the presence of random effects is

    .. math::

        \\frac{1}{N}\\sum_{n=1}^N \\log \\left(\\frac{1}{M'}\\sum_{m=1}^{M'}
        p(y_n | \\theta_n, \\widetilde{\\theta}_{nm}, d)\\right)-
        \\frac{1}{N}\\sum_{n=1}^N \\log \\left(\\frac{1}{M}\\sum_{m=1}^{M}
        p(y_n | \\theta_m, \\widetilde{\\theta}_{m}, d)\\right)

    where :math:`\\widetilde{\\theta}` are the random effects with
    :math:`\\widetilde{\\theta}_{nm} \\sim p(\\widetilde{\\theta}|\\theta=\\theta_n)` and
    :math:`\\theta_m,\\widetilde{\\theta}_m \\sim p(\\theta,\\widetilde{\\theta})`.
    The latter form is used when `M_prime != None`.

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured.
    :param int N: Number of outer expectation samples.
    :param int M: Number of inner expectation samples for `p(y|d)`.
    :param int M_prime: Number of samples for `p(y | theta, d)` if required.
    :param bool independent_priors: Only used when `M_prime` is not `None`. Indicates
        whether the prior distributions for the target variables and the nuisance
        variables are independent. In this case, it is not necessary to sample the
        targets conditional on the nuisance variables.
    :return: EIG estimate
    :rtype: torch.Tensor
    """
    if isinstance(observation_labels, str):  # accept a string in place of a list of strings
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Take N samples of the model
    expanded_design = lexpand(design, N)  # N copies of the model
    trace = poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()

    if M_prime is not None:
        y_dict = {l: lexpand(trace.nodes[l]["value"], M_prime)
                  for l in observation_labels}
        theta_dict = {l: lexpand(trace.nodes[l]["value"], M_prime)
                      for l in target_labels}
        theta_dict.update(y_dict)
        # Resample M_prime values of the random effects and compute conditional probabilities
        # WARNING: currently the use of `condition` does not actually sample
        # the conditional distribution! We need to use some importance weighting.
        conditional_model = pyro.condition(model, data=theta_dict)
        if independent_priors:
            reexpanded_design = lexpand(design, M_prime, 1)
        else:
            # Not acceptable to use (M_prime, 1) here - other variables may occur after
            # theta, so they need to be sampled conditional upon it
            reexpanded_design = lexpand(design, M_prime, N)
        retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
        retrace.compute_log_prob()
        conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
            - math.log(M_prime)
    else:
        # This assumes that y are independent conditional on theta.
        # Furthermore, it assumes that there are no other variables besides theta.
        conditional_lp = sum(trace.nodes[l]["log_prob"] for l in observation_labels)

    y_dict = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}
    # Resample M values of theta and compute conditional probabilities
    conditional_model = pyro.condition(model, data=y_dict)
    # Using (M, 1) instead of (M, N) - acceptable to re-use thetas between ys because
    # theta comes before y in the graphical model
    reexpanded_design = lexpand(design, M, 1)  # sample M theta
    retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
    retrace.compute_log_prob()
    marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
        - math.log(M)

    terms = conditional_lp - marginal_lp
    nonnan = (~torch.isnan(terms)).sum(0).type_as(terms)
    terms[torch.isnan(terms)] = 0.
    return terms.sum(0) / nonnan
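# A minimal usage sketch for `nmc_eig` (it applies equally to
# `naive_rainforth_eig` below). The toy linear model, the site names "w"/"y",
# and the concrete sizes are illustrative assumptions, not part of the
# estimator: the only requirements are that the model takes `design` as its
# sole argument and treats any extra leading dimensions of `design` as batch
# dimensions.
import torch
import pyro
import pyro.distributions as dist


def toy_model(design):
    # design: (..., n, p) tensor of n measurements with p features each
    batch_shape = design.shape[:-2]
    w = pyro.sample("w", dist.Normal(
        torch.zeros(batch_shape + (design.shape[-1],)), 1.).to_event(1))
    mean = torch.matmul(design, w.unsqueeze(-1)).squeeze(-1)
    return pyro.sample("y", dist.Normal(mean, 1.).to_event(1))


design = torch.randn(3, 2)  # one candidate design: 3 measurements, 2 features
eig_estimate = nmc_eig(toy_model, design, observation_labels="y",
                       target_labels="w", N=500, M=100)
# A design tensor with extra leading batch dimensions should score several
# candidate designs at once, returning one EIG estimate per design.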
def naive_rainforth_eig(model, design, observation_labels, target_labels=None,
                        N=100, M=10, M_prime=None):
    """
    Naive Rainforth (i.e. Nested Monte Carlo) estimate of the expected information
    gain (EIG). The estimate is

    .. math::

        \\frac{1}{N}\\sum_{n=1}^N \\log p(y_n | \\theta_n, d) -
        \\log \\left(\\frac{1}{M}\\sum_{m=1}^M p(y_n | \\theta_m, d)\\right)

    Monte Carlo estimation is attempted for the :math:`\\log p(y | \\theta, d)` term
    if the parameter `M_prime` is passed. Otherwise, it is assumed that
    :math:`\\log p(y | \\theta, d)` can safely be read from the model itself.

    :param function model: A pyro model accepting `design` as its only argument.
    :param torch.Tensor design: Tensor representation of the design.
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured.
    :param int N: Number of outer expectation samples.
    :param int M: Number of inner expectation samples for `p(y|d)`.
    :param int M_prime: Number of samples for `p(y | theta, d)` if required.
    :return: EIG estimate
    :rtype: `torch.Tensor`
    """
    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Take N samples of the model
    expanded_design = lexpand(design, N)
    trace = poutine.trace(model).get_trace(expanded_design)
    trace.compute_log_prob()

    if M_prime is not None:
        y_dict = {l: lexpand(trace.nodes[l]["value"], M_prime)
                  for l in observation_labels}
        theta_dict = {l: lexpand(trace.nodes[l]["value"], M_prime)
                      for l in target_labels}
        theta_dict.update(y_dict)
        # Resample M_prime values of the nuisance variables and compute conditional probabilities
        conditional_model = pyro.condition(model, data=theta_dict)
        # Not acceptable to use (M_prime, 1) here - other variables may occur after
        # theta, so they need to be sampled conditional upon it
        reexpanded_design = lexpand(design, M_prime, N)
        retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
        retrace.compute_log_prob()
        conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
            - math.log(M_prime)
    else:
        # This assumes that y are independent conditional on theta.
        # Furthermore, it assumes that there are no other variables besides theta.
        conditional_lp = sum(trace.nodes[l]["log_prob"] for l in observation_labels)

    y_dict = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}
    # Resample M values of theta and compute conditional probabilities
    conditional_model = pyro.condition(model, data=y_dict)
    # Using (M, 1) instead of (M, N) - acceptable to re-use thetas between ys because
    # theta comes before y in the graphical model
    reexpanded_design = lexpand(design, M, 1)
    retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
    retrace.compute_log_prob()
    marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
        - math.log(M)

    return (conditional_lp - marginal_lp).sum(0) / N
def test_rmv(A, b):
    assert_equal(rmv(A, b), A.mv(b), prec=1e-8)
    batched_A = lexpand(A, 5, 4)
    batched_b = lexpand(b, 5, 4)
    expected_Ab = lexpand(A.mv(b), 5, 4)
    assert_equal(rmv(batched_A, batched_b), expected_Ab, prec=1e-8)
def test_rdiag():
    v = torch.tensor([1., 2., -1.])
    assert_equal(rdiag(v), torch.diag(v), prec=1e-8)
    expanded = lexpand(v, 5, 4)
    expected = lexpand(torch.diag(v), 5, 4)
    assert_equal(rdiag(expanded), expected, prec=1e-8)