def process_prior(
    prior, custom_prior_wrapper_kwargs: Dict = {}
) -> Tuple[Distribution, int, bool]:
    """Return PyTorch distribution-like prior from user-provided prior.

    Args:
        prior: Prior object with `.sample()` and `.log_prob()` as provided by the
            user. A sequence of priors is interpreted as independent components
            and combined with `MultipleIndependent`. A PyTorch `Distribution` is
            used directly, and a frozen `scipy.stats` distribution is wrapped into
            a PyTorch-compatible object.
        custom_prior_wrapper_kwargs: kwargs to be passed to the class that wraps a
            custom prior into a pytorch Distribution, e.g., for passing bounds for a
            prior with bounded support (lower_bound, upper_bound), or argument
            constraints (arg_constraints), see pytorch.distributions.Distribution
            for more info. Only used for custom priors. A copy is forwarded, so
            the caller's dict (and the shared default) is never handed downstream;
            `None` is accepted and treated as "no kwargs".

    Raises:
        AttributeError: If prior objects lacks `.sample()` or `.log_prob()`.

    Returns:
        prior: Prior that emits samples and evaluates log prob as PyTorch Tensors.
        theta_numel: Number of parameters - elements in a single sample from the
            prior.
        prior_returns_numpy: Whether the return type of the prior was a Numpy
            array.
    """
    # If prior is a sequence, assume independent components and check as PyTorch
    # prior.
    if isinstance(prior, Sequence):
        warnings.warn(
            f"Prior was provided as a sequence of {len(prior)} priors. They will be "
            "interpreted as independent of each other and matched in order to the "
            "components of the parameter.",
            # Attribute the warning to the caller rather than to this helper.
            stacklevel=2,
        )
        return process_pytorch_prior(MultipleIndependent(prior))

    if isinstance(prior, Distribution):
        return process_pytorch_prior(prior)

    # If prior is given as `scipy.stats` object, wrap as PyTorch.
    if isinstance(prior, (rv_frozen, multi_rv_frozen)):
        # One draw is taken only to infer the number of elements per sample.
        event_shape = torch.Size([prior.rvs().size])
        # batch_shape is passed as default
        prior = ScipyPytorchWrapper(
            prior, batch_shape=torch.Size([]), event_shape=event_shape
        )
        return process_pytorch_prior(prior)

    # Otherwise it is a custom prior - check for `.sample()` and `.log_prob()`.
    # Forward a fresh dict so the mutable `{}` default (shared across calls) and
    # the caller's own mapping cannot be modified by the wrapper.
    wrapper_kwargs = dict(custom_prior_wrapper_kwargs or {})
    return process_custom_prior(prior, wrapper_kwargs)
def test_independent_joint_shapes_and_samples(dists):
    """Check shapes and sample validity of a `MultipleIndependent` joint.

    The joint's samples and log-probs are compared against those obtained by
    sampling each component distribution on its own, reseeded identically.
    """
    seed = 0  # Reused below so both sampling passes see the same random stream.
    joint = MultipleIndependent(dists)

    # A single draw is a flat vector with a single log-prob entry.
    single = joint.sample()
    assert single.shape == torch.Size([joint.ndims])
    assert joint.log_prob(single).shape == torch.Size([1])

    num_samples = 10

    # Seeded batch draw from the joint, to be reproduced by hand below.
    torch.manual_seed(seed)
    samples = joint.sample((num_samples,))
    log_probs = joint.log_prob(samples)

    accepted_sample_shapes = (
        torch.Size([num_samples, joint.ndims]),
        torch.Size([num_samples]),
    )
    assert samples.shape in accepted_sample_shapes
    assert log_probs.shape == torch.Size([num_samples])

    # Same seed, components sampled one after another in the given order.
    torch.manual_seed(seed)
    per_dist_samples, per_dist_log_probs = [], []
    for dist in dists:
        draw = dist.sample((num_samples,))
        per_dist_samples.append(draw)
        per_dist_log_probs.append(dist.log_prob(draw).reshape(num_samples, -1))

    # Stack the component results column-wise; joint log-prob is their sum.
    true_samples = torch.cat(per_dist_samples, dim=-1)
    true_log_probs = torch.cat(per_dist_log_probs, dim=-1).sum(-1)

    # The joint must reproduce the component-wise draws and log-probs.
    assert (true_samples == samples).all()
    assert (true_log_probs == log_probs).all()

    # Every hand-made sample must lie inside the joint's support.
    assert joint.support.check(true_samples).all()
def test_invalid_inputs():
    """`MultipleIndependent` must raise AssertionError on malformed usage."""
    components = [
        Gamma(ones(1), ones(1)),
        Uniform(zeros(1), ones(1)),
        Beta(ones(1), 2 * ones(1)),
    ]
    joint = MultipleIndependent(components)

    # Each entry is one invalid operation; they are run in order.
    invalid_calls = (
        # Four columns for a joint built from three one-dimensional components.
        lambda: joint.log_prob(ones(10, 4)),
        # Nesting a joint inside another joint.
        lambda: MultipleIndependent([joint]),
        # A three-dimensional value.
        lambda: joint.log_prob(ones(10, 4, 1)),
    )
    for call in invalid_calls:
        with pytest.raises(AssertionError):
            call()
_ = inference.append_simulations(theta, x).train(max_num_epochs=2) _ = inference.build_posterior() @pytest.mark.parametrize( "dists", [ pytest.param([Beta(ones(1), 2 * ones(1))], marks=pytest.mark.xfail), # single dist. pytest.param([Gamma(ones(2), ones(1))], marks=pytest.mark.xfail), # single batched dist. pytest.param( [ Gamma(ones(2), ones(1)), MultipleIndependent( [Uniform(zeros(1), ones(1)), Uniform(zeros(1), ones(1))]), ], marks=pytest.mark.xfail, ), # nested definition. pytest.param([Uniform(0, 1), Beta(1, 2)], marks=pytest.mark.xfail), # scalar dists. [Uniform(zeros(1), ones(1)), Uniform(zeros(1), ones(1))], ( Gamma(ones(1), ones(1)), Uniform(zeros(1), ones(1)), Beta(ones(1), 2 * ones(1)), ), [MultivariateNormal(zeros(3), eye(3)), Gamma(ones(1), ones(1))],
def test_mnle_accuracy(sampler):
    """Check MNLE posterior accuracy against reference MCMC samples via C2ST.

    Training data come from a mixed simulator that returns, per row, one
    continuous value (InverseGamma) followed by one discrete value (Binomial).
    The MNLE posterior, sampled with `sampler`, is compared to MCMC samples
    drawn from the potential given by `PotentialFunctionProvider`.

    Args:
        sampler: Value passed as `sample_with` to `build_posterior`; "vi" also
            triggers `posterior.train()`, and "mcmc" forwards the MCMC kwargs
            to `posterior.sample`.
    """

    def mixed_simulator(theta):
        # Extract parameters: the first column drives the continuous output,
        # the remaining columns are Binomial probabilities.
        beta, ps = theta[:, :1], theta[:, 1:]

        # Sample choices and rts independently.
        choices = Binomial(probs=ps).sample()
        rts = InverseGamma(
            concentration=2 * torch.ones_like(beta), rate=beta
        ).sample()

        # Continuous column first, discrete column(s) last.
        return torch.cat((rts, choices), dim=1)

    prior = MultipleIndependent(
        [
            Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
            Beta(torch.tensor([2.0]), torch.tensor([2.0])),
        ],
        validate_args=False,
    )

    num_simulations = 2000
    num_samples = 1000
    theta = prior.sample((num_simulations,))
    x = mixed_simulator(theta)

    # MNLE. The posterior is built inside the loop with the requested sampler;
    # building one here as well would be discarded unused.
    trainer = MNLE(prior)
    trainer.append_simulations(theta, x).train()

    mcmc_kwargs = dict(
        num_chains=10,
        warmup_steps=100,
        method="slice_np_vectorized",
        init_strategy="proposal",
    )

    for num_trials in [10]:
        theta_o = prior.sample((1,))
        # Repeat the single ground-truth parameter to simulate i.i.d. trials.
        x_o = mixed_simulator(theta_o.repeat(num_trials, 1))

        # True posterior samples
        transform = mcmc_transform(prior)
        true_posterior_samples = MCMCPosterior(
            PotentialFunctionProvider(prior, atleast_2d(x_o)),
            theta_transform=transform,
            proposal=prior,
            **mcmc_kwargs,
        ).sample((num_samples,), show_progress_bars=False)

        posterior = trainer.build_posterior(prior=prior, sample_with=sampler)
        posterior.set_default_x(x_o)
        if sampler == "vi":
            # Variational posteriors must be fitted before they can be sampled.
            posterior.train()

        mnle_posterior_samples = posterior.sample(
            sample_shape=(num_samples,),
            show_progress_bars=False,
            # MCMC-specific kwargs are only valid for the MCMC sampler.
            **(mcmc_kwargs if sampler == "mcmc" else {}),
        )

        check_c2st(
            mnle_posterior_samples,
            true_posterior_samples,
            alg=f"MNLE with {sampler}",
        )