示例#1
0
def test_metropolis_hastings_normal_normal_10_chains():
    """Compare the Metropolis-Hastings mean of ``n1`` with a recorded value.

    The model draws ``n1`` from normal(0, 1) and ``n2`` from normal(n1, 1);
    ``n2`` is conditioned on 0.5. The sampler is built with ``num_chains=10``
    and called 300 times; the results of the first 200 calls are dropped
    before averaging.
    """
    num_calls = 300
    num_dropped = 200

    def model(key):
        key_n1, key_n2 = jax.random.split(key, 2)
        prior = dist.normal(jnp.array(0.), jnp.array(1.))
        n1 = func.sample('n1', prior, key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    def proposal(key, **current):
        # Propose a new n1 from a normal centred on the current n1.
        around_current = dist.normal(current['n1'], jnp.array(5.))
        return {'n1': func.sample('n1', around_current, key)}

    mcmc_model = func.metropolis_hastings(
        func.condition(model, {'n2': jnp.array(0.5)}),
        proposal,
        {'n1': jnp.array(0.)},
        num_chains=10)

    step_keys = jax.random.split(jax.random.PRNGKey(123), num_calls)
    # One sampler call per key, in key order.
    draws = [mcmc_model(step_key)['n1'] for step_key in step_keys]

    kept = jnp.stack(draws)[num_dropped:].reshape(-1)
    tu.check_close(jnp.mean(kept), 0.21243224)
示例#2
0
def test_sample_binomial_invalid_condition():
    """Expect NaN when ``n1`` is conditioned on 0.5 or on 2.

    ``n1`` is a bernoulli(0.5) site whose value is passed as the second
    argument of ``dist.binomial`` for ``n2``.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.bernoulli(jnp.array(0.5)), key_n1)
        return func.sample('n2', dist.binomial(jnp.array(1), n1), key_n2)

    # Each conditioned value is checked with the same seed.
    for conditioned_n1 in (jnp.array(0.5), jnp.array(2)):
        conditioned_model = func.condition(model, {'n1': conditioned_n1})
        sample = conditioned_model(jax.random.PRNGKey(123))
        assert jnp.isnan(sample)
示例#3
0
def test_metropolis_hastings_beta_bernoulli_10chains():
    """Compare the Metropolis-Hastings mean of ``n1`` with a recorded value.

    The model draws ``n1`` from beta(0.5, 0.5) and ``n2`` from bernoulli(n1);
    ``n2`` is conditioned on 1. The sampler is built with ``num_chains=10``
    and called 300 times; the results of the first 200 calls are dropped
    before averaging.
    """
    num_calls = 300
    num_dropped = 200

    def beta_prior():
        return dist.beta(jnp.array(0.5), jnp.array(0.5))

    def model(key):
        key_n1, key_n2 = jax.random.split(key, 2)
        n1 = func.sample('n1', beta_prior(), key_n1)
        return func.sample('n2', dist.bernoulli(n1), key_n2)

    def proposal(key, **current):  # use prior as proposal
        return {'n1': func.sample('n1', beta_prior(), key)}

    mcmc_model = func.metropolis_hastings(
        func.condition(model, {'n2': jnp.array(1)}),
        proposal,
        {'n1': jnp.array(0.5)},
        num_chains=10)

    step_keys = jax.random.split(jax.random.PRNGKey(123), num_calls)
    # One sampler call per key, in key order.
    draws = [mcmc_model(step_key)['n1'] for step_key in step_keys]

    kept = jnp.stack(draws)[num_dropped:]
    tu.check_close(jnp.mean(kept), 0.999707)
示例#4
0
def test_metropolis_hastings_normal_normal_multidim_10_chains():
    """Compare per-element Metropolis-Hastings means of a 2x2 ``n1``.

    Same normal-normal structure as the scalar case, but every parameter,
    the conditioned ``n2`` and the initial ``n1`` are 2x2 arrays. The sampler
    is called 300 times, the results of the first 200 calls are dropped, and
    the rest is averaged over its first two axes.
    """
    shape = (2, 2)
    num_calls = 300
    num_dropped = 200

    def model(key):
        key_n1, key_n2 = jax.random.split(key, 2)
        prior = dist.normal(jnp.zeros(shape), jnp.full(shape, 1.))
        n1 = func.sample('n1', prior, key_n1)
        return func.sample('n2', dist.normal(n1, jnp.ones(shape)), key_n2)

    def proposal(key, **current):
        # Propose a new n1 from a normal centred on the current n1.
        around_current = dist.normal(current['n1'], jnp.full(shape, 5.))
        return {'n1': func.sample('n1', around_current, key)}

    mcmc_model = func.metropolis_hastings(
        func.condition(model, {'n2': jnp.full(shape, 0.5)}),
        proposal,
        {'n1': jnp.zeros(shape)},
        num_chains=10)

    step_keys = jax.random.split(jax.random.PRNGKey(123), num_calls)
    # One sampler call per key, in key order.
    draws = [mcmc_model(step_key)['n1'] for step_key in step_keys]

    kept = jnp.stack(draws)[num_dropped:]
    expected = jnp.array([[0.17168559, 0.14462896],
                          [0.15957117, 0.13937134]])
    tu.check_close(jnp.mean(kept, axis=(0, 1)), expected)
示例#5
0
def test_trace_normal_normal():
    """Trace a fully conditioned normal-normal model and inspect its nodes.

    Both sites are conditioned (``n1`` on 0, ``n2`` on 1). For each site the
    traced tree must hold a node with the expected distribution parameters
    and value.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        prior = dist.normal(jnp.array(0.), jnp.array(1.))
        n1 = func.sample('n1', prior, key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    observed = {'n1': jnp.array(0.), 'n2': jnp.array(1.)}
    tracer = func.trace(func.condition(model, observed))
    tracer(jax.random.PRNGKey(123))
    tree = tracer.get_tree()

    assert 'n1' in tree.nodes
    n1_node = tree.nodes['n1']
    assert n1_node.distribution.mu == 0.
    assert n1_node.distribution.sigma == 1.
    assert n1_node.value == 0.

    assert 'n2' in tree.nodes
    n2_node = tree.nodes['n2']
    # n2 is parameterised by n1, which is conditioned on 0.
    assert n2_node.distribution.mu == 0.
    assert n2_node.distribution.sigma == 1.
    assert n2_node.value == 1.
示例#6
0
def test_sample_dir_conditioned_invalid_value_error():
    """Expect an all-NaN result for a dirichlet site conditioned on [0.8, 0.7].

    The conditioned vector sums to 1.5.
    """
    def model(key):
        return func.sample('n1', dist.dirichlet(jnp.array([1.0, 0.5])), key)

    conditioned_value = jnp.array([0.8, 0.7])
    conditioned_model = func.condition(model, {'n1': conditioned_value})
    sample = conditioned_model(jax.random.PRNGKey(123))
    assert jnp.all(jnp.isnan(sample))
示例#7
0
def test_sample_beta_conditioned_invalid_value_error():
    """Expect NaN when a beta(0.5, 0.5) site is conditioned on 2.0."""
    def model(key):
        return func.sample('n1', dist.beta(jnp.array(0.5), jnp.array(0.5)),
                           key)

    conditioned_value = jnp.array(2.0)
    conditioned_model = func.condition(model, {'n1': conditioned_value})
    sample = conditioned_model(jax.random.PRNGKey(123))
    assert jnp.isnan(sample)
示例#8
0
def test_sample_multinomial_invalid_condition():
    """Expect an all-NaN result for a categorical site conditioned on [2, 0, 0].
    """
    def model(key):
        probabilities = jnp.array([0.1, 0.8, 0.1])
        return func.sample('n1', dist.categorical(probabilities), key)

    # The conditioned value records more than 1 outcome.
    conditioned_value = jnp.array([2, 0, 0])
    conditioned_model = func.condition(model, {'n1': conditioned_value})
    sample = conditioned_model(jax.random.PRNGKey(123))
    assert jnp.all(jnp.isnan(sample))
示例#9
0
def test_sample_conditioned():
    """Check the sample mean of ``n2`` when ``n1`` is conditioned on 1.

    ``n2`` is drawn from binomial(n1, 0.5); the mean over 100 vmapped runs
    must lie strictly between 0.4 and 0.6.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.bernoulli(jnp.array(0.5)), key_n1)
        return func.sample('n2', dist.binomial(n1, jnp.array(0.5)), key_n2)

    conditioned_model = func.condition(model, {'n1': jnp.array(1)})
    run_keys = jax.random.split(jax.random.PRNGKey(123), 100)
    samples = jax.vmap(conditioned_model)(run_keys)

    sample_mean = jnp.mean(samples)
    assert 0.4 < sample_mean < 0.6
示例#10
0
def test_sample_conditioned():
    """Check the sample mean of ``measurement`` with ``weight`` fixed to 0.

    ``measurement`` is drawn from normal(weight, 1); the absolute mean over
    100 vmapped runs must stay below 0.2.
    """
    def model(key):
        key_weight, key_measurement = jax.random.split(key)
        weight_prior = dist.normal(jnp.array(0.), jnp.array(1.))
        weight = func.sample('weight', weight_prior, key_weight)
        return func.sample('measurement',
                           dist.normal(weight, jnp.array(1.)),
                           key_measurement)

    conditioned_model = func.condition(model, {'weight': jnp.array(0.)})
    run_keys = jax.random.split(jax.random.PRNGKey(123), 100)
    samples = jax.vmap(conditioned_model)(run_keys)

    sample_mean = jnp.mean(samples)
    assert abs(sample_mean) < 0.2
示例#11
0
def test_log_prob_normal_normal():
    """Compare the log-probability of a traced model with a recorded value.

    The normal-normal model is conditioned on ``n1`` = 0 and ``n2`` = 1,
    traced once, and ``func.log_prob`` is evaluated on the resulting tree.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        prior = dist.normal(jnp.array(0.), jnp.array(1.))
        n1 = func.sample('n1', prior, key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    observed = {'n1': jnp.array(0.), 'n2': jnp.array(1.)}
    tracer = func.trace(func.condition(model, observed))
    tracer(jax.random.PRNGKey(123))

    log_prob = func.log_prob(tracer.get_tree())
    tu.check_close(log_prob, -2.337877)
示例#12
0
def test_metropolis_hastings_wrong_proposal_or_initial_samples():
    """Expect a RuntimeError when the proposal and initial samples use 'n4'.

    The model only defines the sites 'n1' and 'n2', so neither the proposal
    nor the initial samples name a site of the model.
    """
    def model(key):
        # Split the key so the two sample sites never share a PRNG key.
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)),
                         key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    def proposal(key, **current):
        n4 = func.sample('n4', dist.normal(current['n4'], jnp.array(1.)), key)
        return {'n4': n4}

    conditioned_model = func.condition(model, {'n2': jnp.array(0.5)})
    initial_samples = {'n4': jnp.array(0.)}
    # Created outside the raises block so that an unrelated RuntimeError
    # from key creation cannot satisfy the expectation.
    key = jax.random.PRNGKey(123)

    # Construction and the call both stay inside the block: either one may
    # be the step that rejects the mismatched names.
    with pytest.raises(RuntimeError):  # wrong proposal and initial samples
        mcmc_model = func.metropolis_hastings(conditioned_model,
                                              proposal,
                                              initial_samples,
                                              num_chains=1)
        mcmc_model(key)
示例#13
0
def test_log_prob_normal():
    """Compare log-probabilities of a single normal(0, 1) site.

    Checked twice against recorded values: once for the unconditioned model
    and once with ``n1`` conditioned on 0. Both runs use the same seed.
    """
    def model(key):
        return func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)),
                           key)

    def traced_log_prob(target):
        # Trace one run of ``target`` and score the resulting tree.
        tracer = func.trace(target)
        tracer(jax.random.PRNGKey(123))
        return func.log_prob(tracer.get_tree())

    # Unconditioned
    tu.check_close(traced_log_prob(model), -1.2025023)

    # Conditioned
    conditioned_model = func.condition(model, {'n1': jnp.array(0.)})
    tu.check_close(traced_log_prob(conditioned_model), -0.9189385)