def process_x(x: Tensor, x_shape: torch.Size, allow_iid_x: bool = False) -> Tensor:
    """Return observed data adapted to match sbi's shape and type requirements.

    Args:
        x: Observed data as provided by the user.
        x_shape: Prescribed shape - either directly provided by the user at init or
            inferred by sbi by running a simulation and checking the output.
        allow_iid_x: Whether multiple trials in x are allowed.

    Returns:
        x: Observed data with shape ready for usage in sbi.
    """
    x = torch.as_tensor(atleast_2d(x), dtype=float32)

    input_x_shape = x.shape
    if not allow_iid_x:
        check_for_possibly_batched_x_shape(input_x_shape)
        start_idx = 0
    else:
        warn_on_iid_x(num_trials=input_x_shape[0])
        start_idx = 1

    # Number of trials can change for every new x, but single trial x shape must match.
    assert input_x_shape[start_idx:] == x_shape[start_idx:], (
        f"Observed data shape ({input_x_shape[start_idx:]}) must match "
        f"the shape of simulated data x ({x_shape[start_idx:]})."
    )
    return x

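# Hypothetical usage sketch (the example values are assumptions, not library code):
# a 1D user input gets a leading batch dimension; with `allow_iid_x=True`, several
# trials may be passed as long as the per-trial shape matches `x_shape`.
x_single = process_x(torch.tensor([1.0, 2.0]), x_shape=torch.Size([1, 2]))
assert x_single.shape == torch.Size([1, 2])

x_iid = process_x(torch.zeros(5, 2), x_shape=torch.Size([1, 2]), allow_iid_x=True)
assert x_iid.shape == torch.Size([5, 2])
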
def sample(self, num_samples: int, context: Tensor = None, **kwargs) -> Tensor:
    """
    Return samples from posterior distribution.

    Args:
        num_samples: Number of samples.
        context: Conditioning observation. Will be _true_observation if None.
        **kwargs: Additional parameters passed to MCMC sampler (thin and warmup).

    Returns:
        Samples from posterior.
    """
    context = self._context if context is None else atleast_2d(context)

    if self._sample_with_mcmc:
        return self._sample_posterior_mcmc(
            context=context,
            num_samples=num_samples,
            mcmc_method=self._mcmc_method,
            **kwargs,
        )
    else:
        # Rejection sampling.
        samples, _ = utils.sample_posterior_within_prior(
            self.neural_net, self._prior, context, num_samples=num_samples
        )
        return samples

def process_observed_data(
    observed_data: Union[Tensor, np.ndarray], simulator: Callable, prior
) -> Tuple[Tensor, int]:
    """Check and correct for requirements on the observed data.

    Args:
        observed_data: Observed data as provided by the user.
        simulator: Simulator function as provided by the user.
        prior: Prior object.

    Returns:
        observed_data: Observed data with shape corrected for usage in SBI.
        observation_dim: Number of elements in a single data point.
    """
    # Maybe add batch dimension, cast to tensor.
    observed_data = atleast_2d(observed_data)
    check_for_possibly_batched_observations(observed_data)

    # Get unbatched simulated data by sampling from prior and simulator,
    # cast to tensor for comparison.
    simulated_data = torch.as_tensor(
        simulator(prior.sample()), dtype=torch.float32
    ).squeeze(0)

    # Get data shape by omitting the batch dimension.
    observed_data_shape = observed_data.shape[1:]

    assert observed_data_shape == simulated_data.shape, (
        f"Observed data shape ({observed_data_shape}) must match "
        f"simulator output shape ({simulated_data.shape})."
    )

    observation_dim = observed_data[0, :].numel()

    return observed_data, observation_dim

def __call__(
    self,
    prior,
    posterior_nn: nn.Module,
    x: Tensor,
    method: str,
) -> Callable:
    """Return potential function.

    Switch on numpy or pyro potential function based on `method`.
    """
    self.posterior_nn = posterior_nn
    self.prior = prior
    self.device = next(posterior_nn.parameters()).device
    self.x = atleast_2d(x).to(self.device)

    if method == "slice":
        return partial(self.pyro_potential, track_gradients=False)
    elif method in ("hmc", "nuts"):
        return partial(self.pyro_potential, track_gradients=True)
    elif "slice_np" in method:
        return partial(self.posterior_potential, track_gradients=False)
    elif method == "rejection":
        return partial(self.posterior_potential, track_gradients=True)
    else:
        raise NotImplementedError

def __call__(self, theta, track_gradients: bool = True):
    # Evaluate the unnormalized log posterior: iid likelihood over trials plus prior.
    theta = atleast_2d(theta)
    with torch.set_grad_enabled(track_gradients):
        iid_ll = self.iid_likelihood(theta)

    return iid_ll + self.prior.log_prob(theta)

def test_atleast_2d():
    t1 = np.array([0.0, -1.0, 1.0])
    t2 = torch.tensor([[1, 2, 3]])

    t3, t4 = torchutils.atleast_2d(t1, t2)

    assert isinstance(t3, torch.Tensor)
    assert t3.ndim == 2
    assert t4.ndim == 2

def log_prob(self, value: Tensor) -> Tensor:
    """
    Return zeros as a constant log-prob for each input.

    Args:
        value: The parameters at which to evaluate the log-probability.

    Returns:
        Tensor of as many zeros as there were parameter sets.
    """
    value = atleast_2d(value)
    return zeros(value.shape[0])

def samples_true_posterior_linear_gaussian_uniform_prior(
    x_o: Tensor,
    likelihood_shift: Tensor,
    likelihood_cov: Tensor,
    prior: Union[Uniform, Independent],
    num_samples: int = 1000,
) -> Tensor:
    """
    Returns ground truth posterior samples for Gaussian likelihood and uniform prior.

    Args:
        x_o: The observation.
        likelihood_shift: Mean of the likelihood p(x|theta) is likelihood_shift+theta.
        likelihood_cov: Covariance matrix of likelihood.
        prior: Uniform prior distribution.
        num_samples: Desired number of samples.

    Returns:
        Samples from posterior.
    """
    # Let s denote the likelihood_shift:
    # The likelihood has the term (x-(s+theta))^2 in the exponent of the Gaussian.
    # In other words, as a function of x, the mean of the likelihood is s+theta.
    # For computing the posterior we need the likelihood as a function of theta. Hence:
    # (x-(s+theta))^2 = (theta-(-s+x))^2
    # We see that the mean is -s+x = x-s

    # Take into account iid trials.
    x_o = atleast_2d(x_o)
    num_trials, *_ = x_o.shape
    x_o_mean = x_o.mean(0)
    likelihood_mean = x_o_mean - likelihood_shift

    posterior = MultivariateNormal(
        loc=likelihood_mean, covariance_matrix=1 / num_trials * likelihood_cov
    )

    # Generate samples from ND Gaussian truncated by prior support.
    num_remaining = num_samples
    samples = []

    while num_remaining > 0:
        candidate_samples = posterior.sample(sample_shape=torch.Size((num_remaining,)))
        is_in_prior = within_support(prior, candidate_samples)
        # Accept if in prior.
        if is_in_prior.sum():
            samples.append(candidate_samples[is_in_prior, :])
            num_remaining -= int(is_in_prior.sum().item())

    return torch.cat(samples)

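# Hypothetical usage sketch (the example values below are assumptions, not library
# code): reference posterior samples for three iid trials under a box-shaped uniform
# prior. The covariance is scaled by 1/num_trials because the product of the per-trial
# Gaussian likelihoods is again Gaussian in theta, centered at the trial mean x_bar - s.
prior = Independent(Uniform(low=-2.0 * torch.ones(2), high=2.0 * torch.ones(2)), 1)
x_o = torch.tensor([[0.1, 0.2], [0.3, -0.1], [0.0, 0.4]])  # 3 iid trials.
reference_samples = samples_true_posterior_linear_gaussian_uniform_prior(
    x_o=x_o,
    likelihood_shift=torch.zeros(2),
    likelihood_cov=0.5 * torch.eye(2),
    prior=prior,
    num_samples=500,
)
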
def _log_ratios_over_trials(
    x: Tensor, theta: Tensor, net: nn.Module, track_gradients: bool = False
) -> Tensor:
    r"""Return log ratios summed over iid trials of `x`.

    Note: `x` can be a batch with batch size larger than 1. Batches in `x` are assumed
    to be iid trials, i.e., data generated based on the same parameters /
    experimental conditions.

    Repeats `x` and $\theta$ to cover all their combinations of batch entries.

    Args:
        x: Batch of iid data.
        theta: Batch of parameters.
        net: Neural net representing the classifier to approximate the ratio.
        track_gradients: Whether to track gradients.

    Returns:
        log_ratio_trial_sum: Log ratio for each parameter, summed over all batch
            entries (iid trials) in `x`.
    """
    theta_repeated, x_repeated = match_theta_and_x_batch_shapes(
        theta=atleast_2d(theta), x=atleast_2d(x)
    )
    assert (
        x_repeated.shape[0] == theta_repeated.shape[0]
    ), "x and theta must match in batch shape."
    assert (
        next(net.parameters()).device == x.device and x.device == theta.device
    ), f"""device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device},
        {theta.device}."""

    # Calculate ratios in one batch.
    with torch.set_grad_enabled(track_gradients):
        log_ratio_trial_batch = net([theta_repeated, x_repeated])
        # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
        log_ratio_trial_sum = log_ratio_trial_batch.reshape(x.shape[0], -1).sum(0)

    return log_ratio_trial_sum

def _log_likelihoods_over_trials(
    x: Tensor,
    theta: Tensor,
    net: nn.Module,
    track_gradients: bool = False,
) -> Tensor:
    r"""Return log likelihoods summed over iid trials of `x`.

    Note: `x` can be a batch with batch size larger than 1. Batches in `x` are assumed
    to be iid trials, i.e., data generated based on the same parameters /
    experimental conditions.

    Repeats `x` and $\theta$ to cover all their combinations of batch entries.

    Args:
        x: Batch of iid data.
        theta: Batch of parameters.
        net: Neural net with .log_prob().
        track_gradients: Whether to track gradients.

    Returns:
        log_likelihood_trial_sum: Log likelihood for each parameter, summed over all
            batch entries (iid trials) in `x`.
    """
    # Repeat `x` in case of evaluation on multiple `theta`. This is needed below
    # when calling nflows in order to have matching shapes of theta and context x
    # at neural network evaluation time.
    theta_repeated, x_repeated = NeuralPosterior._match_theta_and_x_batch_shapes(
        theta=theta, x=atleast_2d(x)
    )
    assert (
        x_repeated.shape[0] == theta_repeated.shape[0]
    ), "x and theta must match in batch shape."
    assert (
        next(net.parameters()).device == x.device and x.device == theta.device
    ), f"device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device}, {theta.device}."

    # Calculate likelihood in one batch.
    with torch.set_grad_enabled(track_gradients):
        log_likelihood_trial_batch = net.log_prob(x_repeated, theta_repeated)
        # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
            x.shape[0], -1
        ).sum(0)

    return log_likelihood_trial_sum

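# A minimal, self-contained sketch of the repeat-and-reshape pattern used by the two
# functions above (for both ratios and likelihoods). Assumption: the batch-matching
# helper tiles theta and repeat-interleaves x, which is the ordering the
# `reshape(x.shape[0], -1).sum(0)` above requires. The random log probs stand in for
# the output of the neural net.
num_trials, num_theta, theta_dim, x_dim = 3, 2, 2, 1
theta = torch.randn(num_theta, theta_dim)
x = torch.randn(num_trials, x_dim)

theta_repeated = theta.repeat(num_trials, 1)               # (num_trials * num_theta, theta_dim)
x_repeated = torch.repeat_interleave(x, num_theta, dim=0)  # (num_trials * num_theta, x_dim)

# Stand-in for net.log_prob(x_repeated, theta_repeated): one value per (trial, theta) pair.
log_probs_flat = torch.randn(num_trials * num_theta)
# Sum over trials to get one log likelihood per theta.
log_probs_per_theta = log_probs_flat.reshape(num_trials, num_theta).sum(0)
assert log_probs_per_theta.shape == (num_theta,)
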
def true_posterior_linear_gaussian_mvn_prior(
    x_o: Tensor,
    likelihood_shift: Tensor,
    likelihood_cov: Tensor,
    prior_mean: Tensor,
    prior_cov: Tensor,
) -> MultivariateNormal:
    """
    Returns the posterior when likelihood and prior are Gaussian.

    We follow the implementation suggested by rhashimoto here:
    https://math.stackexchange.com/questions/157172
    as it requires only one matrix inverse.

    Args:
        x_o: The observation.
        likelihood_shift: Mean of the likelihood p(x|theta) is likelihood_shift+theta.
        likelihood_cov: Covariance matrix of likelihood.
        prior_mean: Mean of prior.
        prior_cov: Covariance matrix of prior.

    Returns:
        Posterior distribution.
    """
    # Let s denote the likelihood_shift:
    # The likelihood has the term (x-(s+theta))^2 in the exponent of the Gaussian.
    # In other words, as a function of x, the mean of the likelihood is s+theta.
    # For computing the posterior we need the likelihood as a function of theta. Hence:
    # (x-(s+theta))^2 = (theta-(-s+x))^2
    # We see that the mean is -s+x = x-s

    # Take into account iid trials.
    x_o = atleast_2d(x_o)
    num_trials, *_ = x_o.shape
    x_o_mean = x_o.mean(0)
    likelihood_mean = x_o_mean - likelihood_shift

    product_mean, product_cov = multiply_gaussian_pdfs(
        likelihood_mean, 1 / num_trials * likelihood_cov, prior_mean, prior_cov
    )

    posterior_dist = MultivariateNormal(product_mean, product_cov)

    return posterior_dist

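# A minimal sketch (an assumption about what `multiply_gaussian_pdfs` computes, not the
# library code itself) of the single-inverse formula from the linked post: the product
# of N(mu_1, S_1) and N(mu_2, S_2) is, up to normalization, a Gaussian with
#   S_3 = S_1 (S_1 + S_2)^{-1} S_2
#   mu_3 = S_2 (S_1 + S_2)^{-1} mu_1 + S_1 (S_1 + S_2)^{-1} mu_2,
# so only (S_1 + S_2)^{-1} has to be inverted.
def _multiply_gaussian_pdfs_sketch(mu1: Tensor, s1: Tensor, mu2: Tensor, s2: Tensor):
    inv_sum = torch.inverse(s1 + s2)  # The only matrix inverse needed.
    cov = s1 @ inv_sum @ s2
    mean = s2 @ inv_sum @ mu1 + s1 @ inv_sum @ mu2
    return mean, cov
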
def process_x(x: Tensor, x_shape: torch.Size) -> Tensor:
    """Return observed data adapted to match sbi's shape and type requirements.

    Args:
        x: Observed data as provided by the user.
        x_shape: Prescribed shape - either directly provided by the user at init or
            inferred by sbi by running a simulation and checking the output.

    Returns:
        x: Observed data with shape ready for usage in sbi.
    """
    x = torch.as_tensor(atleast_2d(x), dtype=float32)

    check_for_possibly_batched_observations(x)
    input_x_shape = x.shape

    assert input_x_shape == x_shape, (
        f"Observed data shape ({input_x_shape}) must match "
        f"the shape of simulated data x ({x_shape})."
    )
    return x

def __call__(
    self,
    prior,
    classifier: nn.Module,
    x: Tensor,
    method: str,
) -> Callable:
    r"""Return potential function for posterior $p(\theta|x)$.

    Switch on numpy or pyro potential function based on `method`.

    Args:
        prior: Prior distribution that can be evaluated.
        classifier: Binary classifier approximating the likelihood up to a constant.
        x: Conditioning variable for posterior $p(\theta|x)$.
        method: One of `slice_np`, `slice`, `hmc`, `nuts`, or `rejection`.

    Returns:
        Potential function for sampler.
    """
    self.classifier = classifier
    self.prior = prior
    self.device = next(classifier.parameters()).device
    self.x = atleast_2d(x).to(self.device)

    if method == "slice":
        return partial(self.pyro_potential, track_gradients=False)
    elif method in ("hmc", "nuts"):
        return partial(self.pyro_potential, track_gradients=True)
    elif "slice_np" in method:
        return partial(self.posterior_potential, track_gradients=False)
    elif method == "rejection":
        return partial(self.posterior_potential, track_gradients=True)
    else:
        raise NotImplementedError

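# Hypothetical usage sketch (the class name `PotentialFunctionProvider` and the
# `prior`, `classifier`, and `x_o` variables are assumptions for illustration): the
# provider is called once to bind the classifier, prior, and observation, and returns
# the potential callable that is then handed to the chosen sampler.
provider = PotentialFunctionProvider()
potential_fn = provider(prior, classifier, x_o, method="slice_np")
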
def log_prob_iid(self, x: Tensor, theta: Tensor) -> Tensor:
    """Return log prob given a batch of iid x and a different batch of theta.

    This is different from `.log_prob()` to enable speed ups in evaluation during
    inference. The speed up is achieved by exploiting the fact that there is only a
    finite number of possible categories in the discrete part of the data: one can
    just calculate the log probs for each possible category (given the current batch
    of theta) and then copy those log probs into the entire batch of iid categories.
    For example, for the drift-diffusion model, there are only two choices, but often
    100s or 1000s of trials. With this method, an evaluation over trials passes a
    batch of `2 (one per choice) * num_thetas` into the NN, whereas the normal
    `.log_prob()` would pass `1000 * num_thetas`.

    Args:
        x: Batch of iid data, data observed given the same underlying parameters or
            experimental conditions.
        theta: Batch of parameters to be evaluated, i.e., each batch entry will be
            evaluated for the entire batch of iid x.

    Returns:
        Tensor: Log probs with shape (num_trials, num_parameters), i.e., the log prob
            for each theta for each trial.
    """
    theta = atleast_2d(theta)
    x = atleast_2d(x)
    batch_size = theta.shape[0]
    num_trials = x.shape[0]
    theta_repeated, x_repeated = match_theta_and_x_batch_shapes(theta, x)
    net_device = next(self.discrete_net.parameters()).device
    assert (
        net_device == x.device and x.device == theta.device
    ), f"device mismatch: net, x, theta: {net_device}, {x.device}, {theta.device}."

    x_cont_repeated, x_disc_repeated = _separate_x(x_repeated)
    x_cont, x_disc = _separate_x(x)

    log_prob_per_cat = torch.zeros(self.discrete_net.num_categories, batch_size)
    # Repeat categories for parameters.
    repeated_categories = torch.repeat_interleave(
        torch.arange(self.discrete_net.num_categories - 1), batch_size, dim=0
    )
    # Repeat parameters for categories.
    repeated_theta = theta.repeat(self.discrete_net.num_categories - 1, 1)
    log_prob_per_cat[:-1, :] = self.discrete_net.log_prob(
        repeated_categories,
        repeated_theta,
    ).reshape(-1, batch_size)
    # Infer the log prob of the last category from the constraint that probs sum to one.
    log_prob_per_cat[-1, :] = torch.log(1 - log_prob_per_cat[:-1, :].exp().sum(0))

    # Fill in the log probs for each category that actually occurred in the data.
    log_probs_discrete = log_prob_per_cat[
        x_disc.type_as(torch.zeros(1, dtype=torch.long)).squeeze()
    ].reshape(-1)

    # Use the repeated discrete data and theta, which match in batch shape, as flow context.
    log_probs_cont = self.continuous_net.log_prob(
        torch.log(x_cont_repeated) if self.log_transform_x else x_cont_repeated,
        context=torch.cat((theta_repeated, x_disc_repeated), dim=1),
    )

    # Combine into joint log prob with first dim over trials.
    log_probs_combined = (log_probs_discrete + log_probs_cont).reshape(
        num_trials, batch_size
    )

    # Maybe add log abs det jacobian of RTs: log(1/rt) = - log(rt).
    if self.log_transform_x:
        log_probs_combined -= torch.log(x_cont)

    # Return batch over trials as required by SBI potentials.
    return log_probs_combined

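# A minimal, self-contained sketch (toy numbers, not MNLE itself) of the speed-up
# described above: log probs are evaluated once per (category, theta) pair and then
# gathered for every trial, instead of evaluating the network once per (trial, theta)
# pair. The log_softmax stands in for the discrete net's per-category log probs.
num_categories, num_theta, num_trials = 2, 5, 1000
log_prob_per_cat = torch.log_softmax(torch.randn(num_categories, num_theta), dim=0)
observed_categories = torch.randint(num_categories, (num_trials,))
# Gather (num_trials, num_theta) log probs from only num_categories * num_theta evaluations.
log_probs_discrete = log_prob_per_cat[observed_categories]
assert log_probs_discrete.shape == (num_trials, num_theta)
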
def test_mnle_accuracy(sampler):
    def mixed_simulator(theta):
        # Extract parameters.
        beta, ps = theta[:, :1], theta[:, 1:]

        # Sample choices and rts independently.
        choices = Binomial(probs=ps).sample()
        rts = InverseGamma(
            concentration=2 * torch.ones_like(beta), rate=beta
        ).sample()

        return torch.cat((rts, choices), dim=1)

    prior = MultipleIndependent(
        [
            Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
            Beta(torch.tensor([2.0]), torch.tensor([2.0])),
        ],
        validate_args=False,
    )

    num_simulations = 2000
    num_samples = 1000

    theta = prior.sample((num_simulations,))
    x = mixed_simulator(theta)

    # MNLE
    trainer = MNLE(prior)
    trainer.append_simulations(theta, x).train()
    posterior = trainer.build_posterior()

    mcmc_kwargs = dict(
        num_chains=10,
        warmup_steps=100,
        method="slice_np_vectorized",
        init_strategy="proposal",
    )

    for num_trials in [10]:
        theta_o = prior.sample((1,))
        x_o = mixed_simulator(theta_o.repeat(num_trials, 1))

        # True posterior samples.
        transform = mcmc_transform(prior)
        true_posterior_samples = MCMCPosterior(
            PotentialFunctionProvider(prior, atleast_2d(x_o)),
            theta_transform=transform,
            proposal=prior,
            **mcmc_kwargs,
        ).sample((num_samples,), show_progress_bars=False)

        posterior = trainer.build_posterior(prior=prior, sample_with=sampler)
        posterior.set_default_x(x_o)
        if sampler == "vi":
            posterior.train()

        mnle_posterior_samples = posterior.sample(
            sample_shape=(num_samples,),
            show_progress_bars=False,
            **mcmc_kwargs if sampler == "mcmc" else {},
        )

        check_c2st(
            mnle_posterior_samples,
            true_posterior_samples,
            alg=f"MNLE with {sampler}",
        )