def check_sample_shape_and_support(q: Distribution, prior: Distribution) -> None:
    """Check that shapes, support and log-probs of `q` are compatible with the prior.

    Verifies that event and batch shapes match, that samples drawn from `q` lie
    in the support of the prior (a property which holds for the true posterior
    in any case), that samples of both distributions have the same shape, and
    that each distribution assigns a finite log-probability to samples of the
    other one.

    Args:
        q: Variational distribution which is checked.
        prior: Prior providing the reference shapes and support.

    Raises:
        AssertionError: If any of the checks fails.
    """
    assert (
        q.event_shape == prior.event_shape
    ), "The event shape of q must match that of the prior"
    assert (
        q.batch_shape == prior.batch_shape
    ), "The batch shape of q must match that of the prior"

    sample_shape = torch.Size((1000,))
    samples = q.sample(sample_shape)
    samples_prior = prior.sample(sample_shape).to(samples.device)

    # Accessing `support` may raise NotImplementedError (torch distribution
    # without an implemented support) or AttributeError (custom prior without
    # the attribute). In both cases the support check is skipped.
    try:
        support = prior.support
    except (NotImplementedError, AttributeError):
        support = None

    if support is not None:
        # Reduce with `torch.all`: the builtin `all` iterates over the first
        # dimension and raises for multi-dimensional (element-wise) results,
        # which is what `check` returns for priors with a batch shape.
        assert torch.all(
            support.check(samples)
        ), "The support of q must match that of the prior"

    assert (
        samples.shape == samples_prior.shape
    ), "Something is wrong with sample shape and event_shape or batch_shape attributes."
    assert torch.isfinite(
        q.log_prob(samples_prior)
    ).all(), "Invalid values in logprob on prior samples."
    assert torch.isfinite(
        prior.log_prob(samples)
    ).all(), "Invalid values in logprob on q samples."
def process_pytorch_prior(
    prior: Distribution,
) -> Tuple[Distribution, int, bool]:
    """Validate a PyTorch prior and adapt its return type for sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If a sample from the prior is zero-dimensional, i.e. the
            prior is defined over an unwrapped scalar variable.

    Returns:
        prior: The prior, wrapped in `PytorchReturnTypeWrapper` when its samples
            are not of dtype float32.
        theta_numel: Number of elements in a single sample from the prior.
        prior_returns_numpy: Always False.
    """
    # A zero-dimensional sample means the prior is an unwrapped scalar.
    is_scalar_prior = prior.sample().ndim == 0
    if is_scalar_prior:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    sample_dtype = prior.sample().dtype
    if sample_dtype != float32:
        prior = PytorchReturnTypeWrapper(prior, return_type=float32)

    # This will fail for float64 priors.
    check_prior_return_type(prior)

    return prior, prior.sample().numel(), False
def plot_comparison(n_display: int, x_og: torch.Tensor, p_x: D.Distribution,
                    input_dims: Tuple[int]) -> Figure:
    """Plot originals next to the mean, variance and one sample of `p_x`.

    Builds a 4 x `n_display` grid of axes without spacing. Row 0 shows items of
    `x_og`, row 1 the distribution mean, row 2 its variance and row 3 a single
    sample. Column `n` shows item `n` of each batch.

    Args:
        n_display: Number of items (columns) to show.
        x_og: Batch of original inputs.
        p_x: Distribution whose sample, mean and variance are displayed.
        input_dims: Dimensions passed to `batch_reshape` for the distribution
            outputs.

    Returns:
        The created figure.
    """
    fig = plt.figure(figsize=(6, 6))
    grid = fig.add_gridspec(4,
                            n_display,
                            width_ratios=[1] * n_display,
                            height_ratios=[1, 1, 1, 1])
    grid.update(wspace=0, hspace=0)

    # All distribution outputs are clipped into [0, 1] before plotting.
    sample_imgs = batch_reshape(p_x.sample(), input_dims).clip(0, 1)
    mean_imgs = batch_reshape(p_x.mean, input_dims).clip(0, 1)
    variance_imgs = batch_reshape(p_x.variance, input_dims).clip(0, 1)

    unit_range = {"vmin": 0, "vmax": 1}
    # One entry per grid row: (image batch, extra imshow keyword arguments).
    # The variance row is plotted without a fixed colour range.
    rows = (
        (x_og, unit_range),  # Original
        (mean_imgs, unit_range),  # Mean
        (variance_imgs, {}),  # Variance
        (sample_imgs, unit_range),  # Sample
    )

    for col in range(n_display):
        for row, (images, imshow_kwargs) in enumerate(rows):
            axis = disable_ticks(plt.subplot(grid[row, col]))
            # Move the leading axis of each item last; imshow takes image
            # data as (rows, columns, channels).
            # NOTE(review): assumes items are (C, H, W) — confirm with caller.
            axis.imshow(images[col, :].permute(1, 2, 0), **imshow_kwargs)

    return fig
def get_normalization_uniform_prior(
    posterior: DirectPosterior,
    prior: Distribution,
    true_observation: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Evaluate the posterior density with and without normalization.

    A single sample is drawn from the prior and the posterior density is
    evaluated at that sample twice: once with `norm_posterior=False` and once
    with `norm_posterior=True`.

    Args:
        posterior: Estimated posterior.
        prior: Prior distribution the evaluation point is drawn from.
        true_observation: Observation passed to `posterior.leakage_correction`.

    Returns:
        The unnormalized posterior density, the normalized posterior density
        and the value returned by `posterior.leakage_correction`.
    """
    theta = prior.sample()

    def density_at_theta(normalize: bool) -> Tensor:
        # Exponentiate the log-probability to obtain the density itself.
        return torch.exp(posterior.log_prob(theta, norm_posterior=normalize))

    # Evaluation order matches the returned tuple.
    unnormalized_density = density_at_theta(False)
    normalized_density = density_at_theta(True)
    acceptance_prob = posterior.leakage_correction(x=true_observation)

    return unnormalized_density, normalized_density, acceptance_prob
def dist_sample(distribution: dist.Distribution, context: SamplingContext = None) -> torch.Tensor:
    """
    Draw samples (or modes) from a distribution and select entries per the context.

    Args:
        distribution (dists.Distribution): Base distribution to sample from.
        context (SamplingContext): Sampling configuration. The fields read here
            are ``is_mpe`` (use ``_mode`` instead of sampling), ``n`` (number of
            samples), ``repetition_indices`` (one index into the last axis per
            sample) and ``parent_indices`` (optional indices gathered along
            axis 2). Despite the ``None`` default, these attributes are read
            unconditionally, so a context has to be passed.

    Returns:
        torch.Tensor: The selected samples.
    """
    if context.is_mpe:
        drawn = _mode(distribution, context)
    else:
        drawn = distribution.sample(sample_shape=(context.n,))
        assert (
            drawn.shape[1] == 1
        ), "Something went wrong. First sample size dimension should be size 1 due to the distribution parameter dimensions. Please report this issue."
        drawn.squeeze_(1)

    num, feats, chans, _ = drawn.shape

    # For every sample keep only the slice of its own repetition.
    rep_idx = context.repetition_indices
    selected = torch.zeros(num, feats, chans, device=rep_idx.device)
    for row in range(num):
        selected[row, :, :] = drawn[row, :, :, rep_idx[row]]

    parent_idx = context.parent_indices
    if parent_idx is None:
        return selected

    # Parent indices pick one entry along axis 2 for each feature/scope.
    picked = torch.gather(selected, dim=2, index=parent_idx.unsqueeze(-1))
    return picked.squeeze(-1)
def _sample(self, distribution: dist.Distribution, sample_shape: Union[torch.Size, tuple] = torch.Size()):
    """Draw from `distribution`, reparameterized while `self.training` is set.

    Args:
        distribution: Distribution to draw from.
        sample_shape: Shape of the requested sample batch.

    Returns:
        The result of `rsample` when `self.training` is truthy, otherwise the
        result of `sample`.
    """
    draw = distribution.rsample if self.training else distribution.sample
    return draw(sample_shape=sample_shape)
def run_pmmh(
    filter_: BaseFilter,
    state: FilterResult,
    prop_kernel: Distribution,
    prop_filt,
    y: torch.Tensor,
    size=torch.Size([]),
    **kwargs,
) -> Tuple[torch.Tensor, FilterResult, BaseFilter]:
    """
    Runs one iteration of a vectorized Particle Marginal Metropolis Hastings.

    A candidate is drawn from ``prop_kernel``, written into the parameters of
    ``prop_filt.ssm`` and the proposal filter is run on ``y``. The log acceptance
    probability is the sum of the log-likelihood difference, the prior
    log-probability difference and the proposal log-probability difference.

    :param filter_: The filter holding the current parameters
    :param state: The filter result of the current parameters
    :param prop_kernel: The kernel to draw the candidate from
    :param prop_filt: The filter to run with the candidate parameters
    :param y: The data passed to ``prop_filt.longfilter``
    :param size: The sample shape passed to ``prop_kernel.sample``
    :param kwargs: Forwarded to ``prop_filt.longfilter``
    :return: Boolean tensor of accepted candidates, the filter result of the
        candidate and ``prop_filt``
    """
    candidate = prop_kernel.sample(size)
    prop_filt.ssm.parameters_from_array(candidate, constrained=False)
    candidate_result = prop_filt.longfilter(y, bar=False, **kwargs)

    loglik_delta = candidate_result.loglikelihood - state.loglikelihood

    prior_delta = prop_filt.ssm.eval_prior_log_prob(False) - filter_.ssm.eval_prior_log_prob(False)
    prior_delta = prior_delta.squeeze()

    # No proposal correction is applied for an `Independent` kernel that is
    # sampled with an empty sample shape.
    skip_proposal_term = isinstance(prop_kernel, Independent) and size == torch.Size([])
    if not skip_proposal_term:
        current = filter_.ssm.parameters_to_array(constrained=False)
        proposal_delta = prop_kernel.log_prob(current) - prop_kernel.log_prob(candidate)
    else:
        proposal_delta = 0.0

    log_acc_prob = proposal_delta + prior_delta + loglik_delta

    # Accept where log(u) < log acceptance probability, u ~ U(0, 1).
    log_uniform = torch.empty_like(log_acc_prob).uniform_().log()
    accepted: torch.Tensor = log_uniform < log_acc_prob

    return accepted, candidate_result, prop_filt
def process_pytorch_prior(
    prior: Distribution,
) -> Tuple[Distribution, int, bool]:
    """Validate a PyTorch prior and adapt it to the requirements for sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If a sample from the prior is zero-dimensional, i.e. the
            prior is defined over an unwrapped scalar variable.

    Returns:
        prior: The prior, possibly replaced by a `BoxUniform` and/or wrapped in
            `PytorchReturnTypeWrapper`.
        theta_numel: Number of elements in a single sample from the prior.
        prior_returns_numpy: Always False.
    """
    # Turn off validation of input arguments to allow `log_prob()` on samples
    # outside of the support.
    prior.set_default_validate_args(False)

    # Reject unwrapped scalar priors.
    # NOTE(review): the original comment stated that this also rejects Uniform
    # priors with dimension larger than 1; that is not evident from this check
    # alone — confirm.
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")

    # Cast 1D Uniform to BoxUniform to avoid shape error in mdn log prob.
    is_1d_uniform = isinstance(prior, Uniform) and prior.batch_shape.numel() == 1
    if is_1d_uniform:
        prior = BoxUniform(low=prior.low, high=prior.high)
        warnings.warn(
            "Casting 1D Uniform prior to BoxUniform to match sbi batch requirements."
        )

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    sample_dtype = prior.sample().dtype
    if sample_dtype != float32:
        prior = PytorchReturnTypeWrapper(prior,
                                         return_type=float32,
                                         validate_args=False)

    # This will fail for float64 priors.
    check_prior_return_type(prior)

    return prior, prior.sample().numel(), False
def mc_sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    plates: [Plate],
    amt_samples: int,
) -> torch.Tensor:
    """Draw `self.n_samples` samples from `distr`.

    `parents`, `plates` and `amt_samples` are accepted but not used here.
    """
    sample_shape = (self.n_samples,)
    return distr.sample(sample_shape)
def mc_sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    plates: [Plate],
    amt_samples: int,
) -> torch.Tensor:
    """Draw `self.n_samples` samples from `distr`.

    `parents`, `plates` and `amt_samples` are accepted but not used here.
    """
    # TODO: Why does this ignore amt_samples?
    sample_shape = (self.n_samples,)
    return distr.sample(sample_shape)
def test_process_simulator(simulator: Callable, prior: Distribution):
    """Check that a processed simulator returns a batched `Tensor`.

    The prior and simulator are passed through `process_prior` and
    `process_simulator`, then the simulator is called on a batch of two prior
    samples.
    """
    prior, _, prior_returns_numpy = process_prior(prior)
    simulator = process_simulator(simulator, prior, prior_returns_numpy)

    n_batch = 2
    theta = prior.sample((n_batch,))
    x = simulator(theta)

    assert isinstance(x, Tensor), "Processed simulator must return Tensor."
    assert (
        x.shape[0] == n_batch
    ), "Processed simulator must return as many data points as parameters in batch."
def check_transform(prior: Distribution, transform: TorchTransform) -> None:
    """Check that a round trip through `transform` preserves prior samples.

    Two samples are drawn from the prior, mapped through `transform.inv` and
    mapped back with `transform`.

    Args:
        prior: Distribution the test samples are drawn from.
        transform: Transform to check.

    Raises:
        AssertionError: If `transform.inv` changes the shape of the samples, or
            if the re-transformed samples are not close to the original ones.
    """
    constrained = prior.sample(torch.Size((2,)))
    unconstrained = transform.inv(constrained)

    shapes_agree = unconstrained.shape == constrained.shape  # type: ignore
    assert shapes_agree, (
        "Mismatch between transformed and untransformed space. Note that you "
        "cannot use a transforms when using a MultipleIndependent prior with a "
        "Dirichlet prior.")

    round_trip = transform(unconstrained)  # type: ignore
    assert torch.allclose(
        constrained, round_trip
    ), "Mismatch between true and re-transformed parameters."
def process_pytorch_prior(
    prior: Distribution,
) -> Tuple[Distribution, int, bool]:
    """Check a PyTorch prior against the requirements for SBI.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If a sample from the prior is zero-dimensional, i.e. the
            prior is defined over a scalar variable.
        AssertionError: If the batch shape of the prior is neither
            `torch.Size([])` nor `torch.Size([1])`.

    Returns:
        prior: PyTorch distribution prior (returned unchanged).
        parameter_dim: Number of elements in a single sample from the prior.
        prior_returns_numpy: Always False.
    """
    # Reject scalar priors.
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "batch_shape=torch.Size([1]), or event_shape=torch.Size([1])")

    allowed_batch_shapes = (torch.Size([1]), torch.Size([]))
    assert prior.batch_shape in allowed_batch_shapes, (
        "The prior must have batch shape torch.Size([]) or torch.Size([1]), but "
        f"has {prior.batch_shape}. ")

    check_prior_batch_behavior(prior)
    check_for_batch_reinterpretation_extra_d_uniform(prior)

    return prior, prior.sample().numel(), False
def naive_sampling(num_samples: int, potential_fn: Callable,
                   proposal: Distribution, **kwargs) -> Tensor:
    """Sample directly from the proposal, i.e. the variational posterior.

    Args:
        num_samples: Number of samples to draw.
        potential_fn: Potential function. Accepted for a uniform sampler
            signature, but not used here.
        proposal: Proposal distribution to draw the samples from.
        kwargs: Ignored.

    Returns:
        Tensor: Samples of shape (num_samples, event_shape)
    """
    sample_shape = torch.Size((num_samples,))
    return proposal.sample(sample_shape)
def psis_diagnostics(
    potential_function: Callable,
    q: Distribution,
    proposal: Optional[Distribution] = None,
    N: int = int(5e4),
) -> float:
    r"""Evaluate the quality of `q` by investigating its importance weights.

    If q is a perfect posterior approximation then
    $q(\theta) \propto p(\theta, x_o)$, thus
    $\log w(\theta) = \log \frac{p(\theta, x_o)}{q(\theta)}$ is constant. This
    function fits a generalized Pareto distribution to the largest importance
    weights. The shape parameter k serves as metric as detailed in [1]. In
    short it is related to the variance of an importance sampling estimate,
    especially for k > 1 the variance will be infinite.

    NOTE: In our experience this metric does distinguish "very bad" from "ok",
    but becomes less sensitive to distinguish "ok" from "good".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the
            potential_function.
        proposal: Proposal for samples. If None, samples are drawn from q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric (first value returned by `gpdfit`).

    Reference:
        [1] _Yes, but Did It Work?: Evaluating Variational Inference_, Yuling
        Yao, Aki Vehtari, Daniel Simpson, Andrew Gelman, 2018,
        https://arxiv.org/abs/1802.02538
    """
    # Number of largest weights used for the fit.
    tail_size = int(min(N / 5, 3 * np.sqrt(N)))

    with torch.no_grad():
        sampler = q if proposal is None else proposal
        samples = sampler.sample(Size((N, )))

        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)

        log_weights = log_potential - log_q
        # Drop non-finite log-weights before normalizing.
        finite_log_weights = log_weights[torch.isfinite(log_weights)]

        # Subtracting the maximum only rescales the weights.
        weights = torch.exp(finite_log_weights - finite_log_weights.max())

        sorted_weights, _ = weights.sort()
        k, _ = gpdfit(sorted_weights[-tail_size:])

    return k
def _mcmc_move(params: Iterable[Parameter], dist: Distribution, stacked: StackedObject, shape: int):
    """
    Performs an MCMC move to rejuvenate parameters: draws from ``dist`` and writes
    the draws into ``t_values`` of every parameter.

    :param params: The parameters to update
    :param dist: The distribution to use for sampling
    :param stacked: Provides, per parameter, the column mask into the draws
        (``mask``) and the shape passed to ``unflattify`` (``prev_shape``)
    :param shape: The number of samples to draw
    :return: Always ``True``
    """
    draws = dist.sample((shape,))

    for param, column_mask, previous_shape in zip(params, stacked.mask, stacked.prev_shape):
        param.t_values = unflattify(draws[:, column_mask], previous_shape)

    return True
def proportional_to_joint_diagnostics(
    potential_function: Callable,
    q: Distribution,
    proposal: Optional[Distribution] = None,
    N: int = int(5e4),
) -> float:
    r"""Evaluate the quality of `q` by a linear fit against the joint density.

    If q is a perfect posterior approximation then
    $q(\theta) \propto p(\theta, x_o)$. Thus we should be able to fit a line
    through the origin to $(q(\theta), p(\theta, x_o))$, whereas the slope will
    be proportional to the normalizing constant. The quality of the linear fit
    is hence a direct metric for the quality of q. We use the R2 statistic.

    NOTE: In our experience this metric does distinguish "good" from "ok", but
    becomes less sensitive to distinguish "very bad" from "ok".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the
            potential_function.
        proposal: Proposal for samples. If None, samples are drawn from q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric (R2 of the fit).
    """
    with torch.no_grad():
        sampler = q if proposal is None else proposal
        samples = sampler.sample(Size((N, )))

        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)

        # Column vectors of the densities (not log-densities).
        density_q = log_q.exp().unsqueeze(-1)
        density_target = log_potential.exp().unsqueeze(-1)

        # Least-squares slope of a regression without intercept.
        gram = density_q.T @ density_q
        moment = density_q.T @ density_target
        slope = torch.linalg.solve(gram, moment)

        residuals = density_target - slope * density_q
        ss_residual = torch.sum(residuals**2)
        ss_total = torch.sum((density_target - density_target.mean())**2)

        # R2 statistic to evaluate the fit.
        r2 = 1 - ss_residual / ss_total

    return r2.item()
def importance_resampling(
    num_samples: int,
    potential_fn: Callable,
    proposal: Distribution,
    K: int = 32,
    num_samples_batch: int = 10000,
    **kwargs,
) -> Tensor:
    """Perform sequential importance resampling (SIR).

    For every returned sample, K candidates are drawn from the proposal and one
    of them is selected with probability proportional to its importance weight
    `exp(potential_fn(theta) - proposal.log_prob(theta))`.

    Args:
        num_samples: Number of samples to draw.
        potential_fn: Potential function, this may be used to debias the
            proposal.
        proposal: Proposal distribution to propose samples.
        K: Number of proposed samples from which only one is selected based on
            its importance weight.
        num_samples_batch: Number of samples processed in parallel. For large K
            you may want to reduce this, depending on your memory capabilities.
        kwargs: Ignored.

    Raises:
        ValueError: If `num_samples`, `K` or `num_samples_batch` is smaller
            than 1.

    Returns:
        Tensor: Samples of shape (num_samples, event_shape)
    """
    if num_samples < 1 or K < 1 or num_samples_batch < 1:
        raise ValueError(
            "`num_samples`, `K` and `num_samples_batch` must all be positive "
            f"integers, but got {num_samples}, {K} and {num_samples_batch}.")

    final_samples = []
    # Track the number of samples still missing so that a final, smaller batch
    # is drawn when `num_samples` is not a multiple of `num_samples_batch`.
    remaining = num_samples
    while remaining > 0:
        batch_size = min(num_samples_batch, remaining)
        with torch.no_grad():
            thetas = proposal.sample(torch.Size((batch_size * K, )))
            logp = potential_fn(thetas)
            logq = proposal.log_prob(thetas)

            # Row-wise cumulative distribution over the K candidates.
            cdf = (logp - logq).reshape(batch_size, K).softmax(-1).cumsum(-1)
            # Rounding in the cumulative sum can leave the last entry slightly
            # below 1; pin it so that every row selects exactly one candidate.
            cdf[:, -1] = 1.0

            u = torch.rand(batch_size, 1, device=thetas.device)
            # True only at the first candidate whose cumulative weight reaches u.
            mask = torch.cumsum(cdf >= u, -1) == 1
            samples = thetas.reshape(batch_size, K, -1)[mask]

        final_samples.append(samples)
        remaining -= batch_size

    return torch.vstack(final_samples)
def check_sbi_inputs(simulator: Callable, prior: Distribution) -> None: """Assert requirements for simulator, prior and observation for usage in sbi. Args: simulator: simulator function prior: prior (Distribution like) x_shape: Shape of single simulation output $x$. """ num_prior_samples = 1 theta = prior.sample((num_prior_samples, )) theta_batch_shape, *_ = theta.shape simulation = simulator(theta) sim_batch_shape, *sim_event_shape = simulation.shape assert isinstance(theta, Tensor), "Parameters theta must be a `Tensor`." assert isinstance(simulation, Tensor), "Simulator output must be a `Tensor`." assert (theta_batch_shape == num_prior_samples ), f"""Theta batch shape {theta_batch_shape} must match num_samples={num_prior_samples}.""" assert (sim_batch_shape == num_prior_samples ), f"""Simulation batch shape {sim_batch_shape} must match
def mcmc_transform(
    prior: Distribution,
    num_prior_samples_for_zscoring: int = 1000,
    enable_transform: bool = True,
    device: str = "cpu",
    **kwargs,
) -> TorchTransform:
    """
    Builds a transform that is applied to parameters during MCMC.

    The resulting transform is defined such that the forward mapping maps from
    constrained to unconstrained space.

    It does two things:
    1) When the prior support is bounded, it transforms the parameters into
    unbounded space.
    2) Otherwise it z-scores the parameters such that MCMC is performed in a
    z-scored space.

    Args:
        prior: The prior distribution.
        num_prior_samples_for_zscoring: The number of samples drawn from the
            prior to infer the `mean` and `stddev` of the prior used for
            z-scoring. Unused if the prior has bounded support or when the
            prior provides `mean` and `stddev`.
        enable_transform: Whether or not to use a transformation during MCMC.
            If False, an identity transform is used.
        device: Device the z-scoring mean and standard deviation are moved to.
        kwargs: Ignored.

    Returns:
        A transformation whose `forward()` maps from constrained (or
        non-z-scored) to unconstrained (or z-scored) space.
    """
    if enable_transform:
        # Some distributions have a support argument but it raises a
        # NotImplementedError. We catch this case here.
        try:
            _ = prior.support
            has_support = True
        except (NotImplementedError, AttributeError):
            # NotImplementedError -> Distribution that inherits from torch dist
            # but does not implement support.
            # AttributeError -> Custom distribution that has no support
            # attribute.
            warnings.warn(
                "The passed prior has no support property, transform will be "
                "constructed from mean and std. If the passed prior is supposed "
                "to be bounded consider implementing the prior.support property.")
            has_support = False

        # If the distribution has a `support`, check if the support is bounded.
        # If it is not bounded, we want to z-score the space. This is not done
        # by `biject_to()`, so we have to deal with this case separately.
        if has_support:
            if hasattr(prior.support, "base_constraint"):
                constraint = prior.support.base_constraint  # type: ignore
            else:
                constraint = prior.support
            support_is_bounded = not isinstance(constraint, constraints._Real)
        else:
            support_is_bounded = False

        # Prior with bounded support, e.g., uniform priors.
        if has_support and support_is_bounded:
            transform = biject_to(prior.support)
        # For all other cases build affine transform with mean and std.
        else:
            # EAFP instead of `hasattr`: `hasattr` only suppresses
            # AttributeError, but torch distributions that do not implement
            # `mean`/`stddev` raise NotImplementedError on access. Both cases
            # fall back to estimating the moments from prior samples.
            try:
                prior_mean = prior.mean.to(device)
                prior_std = prior.stddev.to(device)
            except (NotImplementedError, AttributeError):
                theta = prior.sample(
                    torch.Size((num_prior_samples_for_zscoring, )))
                prior_mean = theta.mean(dim=0).to(device)
                prior_std = theta.std(dim=0).to(device)

            transform = torch_tf.AffineTransform(loc=prior_mean,
                                                 scale=prior_std)
    else:
        transform = torch_tf.identity_transform

    # Pytorch `transforms` do not sum the determinant over the parameters.
    # However, if the `transform` explicitly is an `IndependentTransform`, it
    # does. Since our `BoxUniform` is a `Independent` distribution, it will also
    # automatically get a `IndependentTransform` wrapper in `biject_to`. Our
    # solution here is to wrap all transforms as `IndependentTransform`.
    if not isinstance(transform, torch_tf.IndependentTransform):
        transform = torch_tf.IndependentTransform(transform,
                                                  reinterpreted_batch_ndims=1)

    check_transform(prior, transform)  # type: ignore

    # The transforms built above map unconstrained -> constrained; the inverse
    # is returned so that `forward()` maps constrained -> unconstrained.
    return transform.inv  # type: ignore
def run_pmmh(
    context: ParameterContext,
    state: FilterAlgorithmState,
    proposal: BaseProposal,
    proposal_kernel: Distribution,
    proposal_filter: BaseFilter,
    proposal_context: ParameterContext,
    y: torch.Tensor,
    size=torch.Size([]),
    mutate_kernel=False,
) -> torch.BoolTensor:
    r"""
    Runs one iteration of the PMMH update step in which we sample a candidate :math:`\theta^*` from the proposal
    kernel, run a filter for the considered dataset with :math:`\theta^*`, and accept or reject the candidate by
    comparing the log acceptance probability with the log of a uniform draw.

    Accepted candidates are exchanged into ``state.filter_state`` and ``context`` in place.

    Args:
        context: the parameter context of the main algorithm.
        state: the latest algorithm state.
        proposal: the proposal to use when generating the candidate sample :math:`\theta^*`; it is also used to
            build the kernel that evaluates the current parameters.
        proposal_kernel: the kernel from which to draw the candidate sample :math:`\theta^*`. To clarify,
            ``proposal`` corresponds to the ``BaseProposal`` class that was used when generating
            ``proposal_kernel``.
        proposal_filter: the proposal filter to use.
        proposal_context: the parameter context of the proposal filter.
        y: see ``pyfilter.inference.base.BaseAlgorithm``.
        size: optional parameter specifying the number of num_samples to draw from ``proposal_kernel``. Should be
            empty if we draw from an independent kernel.
        mutate_kernel: optional parameter specifying whether to update ``proposal_kernel`` with the newly accepted
            candidate sample.

    Returns:
        Returns a boolean tensor that marks which candidate sample(s) were accepted.
    """
    constrained = False

    # Draw the candidate and run the proposal filter with it.
    candidate = proposal_kernel.sample(size)
    proposal_context.unstack_parameters(candidate, constrained=constrained)
    candidate_result = proposal_filter.batch_filter(y, bar=False)

    loglik_delta = candidate_result.loglikelihood - state.filter_state.loglikelihood
    prior_delta = (proposal_context.eval_priors(constrained=constrained)
                   - context.eval_priors(constrained=constrained))

    # Kernel built from the candidate's state; it evaluates the current
    # parameters, while `proposal_kernel` evaluates the candidate.
    candidate_kernel = proposal.build(proposal_context,
                                      state.replicate(candidate_result),
                                      proposal_filter, y)
    current = context.stack_parameters(constrained=constrained)
    proposal_delta = candidate_kernel.log_prob(current) - proposal_kernel.log_prob(candidate)

    log_acc_prob = proposal_delta + prior_delta.squeeze(-1) + loglik_delta

    # Accept where log(u) < log acceptance probability, u ~ U(0, 1).
    log_uniform = torch.empty_like(log_acc_prob).uniform_().log()
    accepted: torch.BoolTensor = log_uniform < log_acc_prob

    state.filter_state.exchange(candidate_result, accepted)
    context.exchange(proposal_context, accepted)

    if mutate_kernel:
        proposal.exchange(proposal_kernel, candidate_kernel, accepted)

    return accepted
def sample(self, actions: Distribution) -> Tensor:
    """Draw one sample from `actions` and return a copy cast to int64."""
    drawn = actions.sample()
    return drawn.clone().long()