예제 #1
0
def test_can_get_acceptance_rates(net: BayesNet) -> None:
    """Check that a Metropolis-Hastings run records per-latent acceptance rates.

    A ``MetropolisHastingsSampler`` using the ``'prior'`` proposal distribution
    is given an ``AcceptanceRateTracker`` as a proposal listener and run via
    ``sample``. Afterwards, the acceptance rate reported for every latent
    vertex of ``net`` must lie in the closed interval [0, 1].
    """
    acceptance_rate_tracker = AcceptanceRateTracker()
    latents = list(net.get_latent_vertices())

    algo = MetropolisHastingsSampler(proposal_distribution='prior', proposal_listeners=[acceptance_rate_tracker])
    # The returned samples are not inspected. The call exists to run the
    # sampler, which was handed the tracker as a proposal listener above.
    sample(net=net, sample_from=latents, sampling_algorithm=algo, drop=3)

    for latent in latents:
        rate = acceptance_rate_tracker.get_acceptance_rate(latent)
        assert 0 <= rate <= 1
예제 #2
0
def test_can_track_acceptance_rate_when_iterating(net: BayesNet) -> None:
    """Acceptance rates stay within [0, 1] while samples are drawn lazily.

    ``generate_samples`` is configured with the ``'prior'`` proposal
    distribution and an ``AcceptanceRateTracker`` listener. After each of the
    first 100 draws, the tracker's rate for every latent vertex is checked.
    """
    tracker = AcceptanceRateTracker()
    latent_vertices = list(net.get_latent_vertices())

    sample_stream = generate_samples(
        net=net,
        sample_from=latent_vertices,
        proposal_distribution='prior',
        proposal_listeners=[tracker],
        drop=3,
    )

    num_draws = 100
    for _draw in islice(sample_stream, num_draws):
        # Query and check each vertex in turn, one rate per vertex per draw.
        for vertex in latent_vertices:
            current_rate = tracker.get_acceptance_rate(vertex)
            assert 0 <= current_rate <= 1
예제 #3
0
def test_it_throws_if_you_pass_in_a_proposal_listener_but_you_didnt_specify_the_proposal_type(
        net: BayesNet) -> None:
    """``sample`` rejects proposal listeners when no proposal distribution is given.

    Calling ``sample`` with ``proposal_listeners`` but without
    ``proposal_distribution`` must raise a ``TypeError`` whose message is
    exactly ``expected_message``.
    """
    expected_message = "If you pass in proposal_listeners you must also specify proposal_distribution"

    with pytest.raises(TypeError) as excinfo:
        sample(net=net, sample_from=net.get_latent_vertices(), proposal_listeners=[AcceptanceRateTracker()], drop=3)

    raised_error = excinfo.value
    assert str(raised_error) == expected_message
예제 #4
0
def test_it_throws_if_you_pass_in_a_proposal_listener_but_the_algo_isnt_metropolis(
        net: BayesNet) -> None:
    """``sample`` rejects proposal listeners for a non-Metropolis algorithm.

    Requesting ``algo="hamiltonian"`` together with a proposal listener must
    raise a ``TypeError`` carrying exactly the message in ``expected``.
    """
    expected = "Only Metropolis Hastings supports the proposal_listeners parameter"

    with pytest.raises(TypeError) as excinfo:
        sample(
            net=net,
            sample_from=net.get_latent_vertices(),
            algo="hamiltonian",
            proposal_listeners=[AcceptanceRateTracker()],
            drop=3,
        )

    raised_error = excinfo.value
    assert str(raised_error) == expected
예제 #5
0
def test_it_throws_if_you_pass_in_a_proposal_listener_but_you_didnt_specify_the_proposal_type(net: BayesNet) -> None:
    """Building a Metropolis-Hastings sampler with listeners but no proposal distribution fails.

    ``MetropolisHastingsSampler`` must raise a ``TypeError`` with exactly the
    message asserted below. ``net`` is not used by this body; presumably it is
    a pytest fixture kept for signature parity with the other tests (confirm).
    """
    with pytest.raises(TypeError) as excinfo:
        # The constructor is expected to raise, so its result is not bound.
        MetropolisHastingsSampler(proposal_listeners=[AcceptanceRateTracker()])

    assert str(excinfo.value) == "If you pass in proposal_listeners you must also specify proposal_distribution"