Пример #1
0
def test_configure_latent_prior_var(proposal):
    """
    Test the latent prior when using a truncated Gaussian with a variance.

    A variance supplied through the flow's ``model_config`` kwargs should be
    forwarded into ``draw_latent_kwargs`` so the latent sampler uses it.
    """
    proposal.latent_prior = 'truncated_gaussian'
    proposal.flow_config = {'model_config': {'kwargs': {'var': 4}}}
    proposal.draw_latent_kwargs = {}
    FlowProposal.configure_latent_prior(proposal)
    # Direct attribute access (rather than getattr with a constant name) and
    # an identity check: the exact utils function must be installed.
    assert proposal.draw_latent_prior is utils.draw_truncated_gaussian
    assert proposal.draw_latent_kwargs.get('var') == 4
Пример #2
0
def test_configure_latent_prior(proposal, latent_prior, prior_func):
    """Check each latent prior name resolves to its matching utils sampler.

    ``latent_prior`` is the configured name and ``prior_func`` the name of
    the function in ``utils`` expected to be selected for it.
    """
    proposal.flow_config = {'model_config': {}}
    proposal.latent_prior = latent_prior

    FlowProposal.configure_latent_prior(proposal)

    expected_sampler = getattr(utils, prior_func)
    assert proposal.draw_latent_prior == expected_sampler
Пример #3
0
def test_configure_latent_prior_unknown(proposal):
    """An unrecognised latent prior name must raise a RuntimeError."""
    proposal.latent_prior = 'truncated'
    # The message has no regex metacharacters, so ``match`` (re.search)
    # behaves as a plain substring check.
    expected_msg = 'Unknown latent prior: truncated, '
    with pytest.raises(RuntimeError, match=expected_msg):
        FlowProposal.configure_latent_prior(proposal)