def test_configure_latent_prior_var(proposal):
    """
    Test the latent prior when using a truncated Gaussian with a variance.

    The variance supplied via ``flow_config['model_config']['kwargs']['var']``
    is expected to end up in ``proposal.draw_latent_kwargs['var']``.
    """
    proposal.latent_prior = 'truncated_gaussian'
    proposal.flow_config = {'model_config': {'kwargs': {'var': 4}}}
    # Start from an empty dict so the final assertion only sees what
    # ``configure_latent_prior`` adds.
    proposal.draw_latent_kwargs = {}
    FlowProposal.configure_latent_prior(proposal)
    # Direct attribute access rather than ``getattr`` with a constant name.
    assert proposal.draw_latent_prior == utils.draw_truncated_gaussian
    assert proposal.draw_latent_kwargs.get('var') == 4
def test_configure_latent_prior(proposal, latent_prior, prior_func):
    """Check that the configured latent prior selects the matching sampler.

    ``latent_prior`` is the prior name set on the proposal, and
    ``prior_func`` is the name of the function in ``utils`` that
    ``configure_latent_prior`` is expected to assign to
    ``draw_latent_prior``.
    """
    # NOTE(review): the two parameters are presumably supplied by a
    # ``pytest.mark.parametrize`` decorator that is not visible here.
    proposal.latent_prior = latent_prior
    proposal.flow_config = {'model_config': {}}
    FlowProposal.configure_latent_prior(proposal)
    expected_sampler = getattr(utils, prior_func)
    assert proposal.draw_latent_prior == expected_sampler
def test_configure_latent_prior_unknown(proposal):
    """Check that an unrecognised latent prior name raises a RuntimeError.

    The error message is expected to mention the offending prior name.
    """
    proposal.latent_prior = 'truncated'
    with pytest.raises(RuntimeError) as excinfo:
        FlowProposal.configure_latent_prior(proposal)
    # Inspect the message outside the ``with`` block so the check runs.
    error_message = str(excinfo.value)
    assert 'Unknown latent prior: truncated, ' in error_message