def test_metropolis_hastings_normal_normal_10_chains():
    """Check the 10-chain Metropolis-Hastings posterior mean of n1 in a normal-normal model.

    Model: n1 ~ normal(mu=0, sigma=1), n2 ~ normal(mu=n1, sigma=1), conditioned
    on n2 = 0.5. Proposal: a normal random walk centred on the current n1 with
    sigma 5. Of the 300 steps, the first 200 are dropped as burn-in and the
    remaining draws are flattened before averaging.

    The analytic posterior of n1 here is normal with mean 0.25; the pinned
    0.21243224 is a seed-specific regression value for PRNGKey(123).
    """
    num_steps, burn_in = 300, 200

    def model(key):
        key_n1, key_n2 = jax.random.split(key, 2)
        n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)), key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    def proposal(key, **current):
        step = dist.normal(current['n1'], jnp.array(5.))
        return {'n1': func.sample('n1', step, key)}

    mcmc_model = func.metropolis_hastings(
        func.condition(model, {'n2': jnp.array(0.5)}),
        proposal,
        {'n1': jnp.array(0.)},
        num_chains=10)

    # One sampler call per key; presumably each call advances every chain one step.
    step_keys = jax.random.split(jax.random.PRNGKey(123), num_steps)
    trajectory = jnp.stack([mcmc_model(step_key)['n1'] for step_key in step_keys])
    pooled = trajectory[burn_in:].reshape(-1)
    tu.check_close(jnp.mean(pooled), 0.21243224)
def test_sample_binomial_invalid_condition():
    """Conditioning the bernoulli site n1 on a value it cannot take should yield NaN.

    n1 ~ bernoulli(0.5) is passed as the second argument of
    ``dist.binomial(1, n1)`` for n2. Both a fractional value (0.5) and an
    out-of-range integer (2) for n1 are expected to produce a NaN sample.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.bernoulli(jnp.array(0.5)), key_n1)
        return func.sample('n2', dist.binomial(jnp.array(1), n1), key_n2)

    for bad_n1 in (jnp.array(0.5), jnp.array(2)):
        conditioned_model = func.condition(model, {'n1': bad_n1})
        sample = conditioned_model(jax.random.PRNGKey(123))
        assert jnp.isnan(sample)
def test_metropolis_hastings_beta_bernoulli_10chains():
    """Check the 10-chain Metropolis-Hastings posterior mean of n1 in a beta-bernoulli model.

    Model: n1 ~ beta(0.5, 0.5), n2 ~ bernoulli(n1), conditioned on n2 = 1.
    The proposal ignores the current state and redraws n1 from the prior.
    After dropping the first 200 of 300 steps, the mean over all remaining
    draws is compared against a pinned regression value.

    NOTE(review): assuming bernoulli(p) gives P(n2 = 1) = p, the analytic
    posterior is beta(1.5, 0.5) with mean 0.75, yet the pinned value is
    0.999707. This may indicate the sampler's acceptance ratio does not
    correct for this asymmetric (independence) proposal -- TODO confirm
    against the metropolis_hastings implementation.
    """
    num_steps, burn_in = 300, 200

    def model(key):
        key_n1, key_n2 = jax.random.split(key, 2)
        n1 = func.sample('n1', dist.beta(jnp.array(0.5), jnp.array(0.5)), key_n1)
        return func.sample('n2', dist.bernoulli(n1), key_n2)

    def proposal(key, **current):
        # Reuse the prior as the proposal; `current` is intentionally unused.
        prior = dist.beta(jnp.array(0.5), jnp.array(0.5))
        return {'n1': func.sample('n1', prior, key)}

    mcmc_model = func.metropolis_hastings(
        func.condition(model, {'n2': jnp.array(1)}),
        proposal,
        {'n1': jnp.array(0.5)},
        num_chains=10)

    step_keys = jax.random.split(jax.random.PRNGKey(123), num_steps)
    trajectory = jnp.stack([mcmc_model(step_key)['n1'] for step_key in step_keys])
    tu.check_close(jnp.mean(trajectory[burn_in:]), 0.999707)
def test_metropolis_hastings_normal_normal_multidim_10_chains():
    """Check per-element Metropolis-Hastings posterior means for a (2, 2) normal-normal model.

    Every element of n1 ~ normal(0, 1) feeds n2 ~ normal(n1, 1), with n2
    conditioned to 0.5 everywhere. The proposal is a normal random walk with
    sigma 5 on the whole (2, 2) block. After 200 burn-in steps out of 300,
    the mean over the step and chain axes is compared element-wise against
    pinned regression values (the analytic posterior mean per element is 0.25).
    """
    num_steps, burn_in = 300, 200
    shape = (2, 2)

    def model(key):
        key_n1, key_n2 = jax.random.split(key, 2)
        n1 = func.sample('n1', dist.normal(jnp.zeros(shape), jnp.full(shape, 1.)), key_n1)
        return func.sample('n2', dist.normal(n1, jnp.ones(shape)), key_n2)

    def proposal(key, **current):
        step = dist.normal(current['n1'], jnp.full(shape, 5.))
        return {'n1': func.sample('n1', step, key)}

    mcmc_model = func.metropolis_hastings(
        func.condition(model, {'n2': jnp.full(shape, 0.5)}),
        proposal,
        {'n1': jnp.zeros(shape)},
        num_chains=10)

    step_keys = jax.random.split(jax.random.PRNGKey(123), num_steps)
    trajectory = jnp.stack([mcmc_model(step_key)['n1'] for step_key in step_keys])
    # trajectory is presumably (steps, chains, 2, 2); reduce the two leading axes.
    tu.check_close(
        jnp.mean(trajectory[burn_in:], axis=(0, 1)),
        jnp.array([[0.17168559, 0.14462896],
                   [0.15957117, 0.13937134]]))
def test_trace_normal_normal():
    """Trace a fully conditioned normal-normal model and inspect the recorded sites.

    n1 is conditioned to 0 and n2 to 1. Because n2's mean is n1, the traced
    distribution for n2 is expected to have mu == 0 as well.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)), key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    conditioned_model = func.condition(model, {'n1': jnp.array(0.), 'n2': jnp.array(1.)})
    tracer = func.trace(conditioned_model)
    tracer(jax.random.PRNGKey(123))
    tree = tracer.get_tree()

    # site name -> (expected mu, expected sigma, expected recorded value)
    expected = {'n1': (0., 1., 0.), 'n2': (0., 1., 1.)}
    for site, (mu, sigma, value) in expected.items():
        assert site in tree.nodes
        node = tree.nodes[site]
        assert node.distribution.mu == mu
        assert node.distribution.sigma == sigma
        assert node.value == value
def test_sample_dir_conditioned_invalid_value_error():
    """Conditioning a Dirichlet site on a vector off the simplex should give NaN.

    The conditioned value [0.8, 0.7] sums to 1.5 rather than 1. Despite the
    "error" in the name, no exception is expected: every component of the
    returned sample is checked for NaN.
    """
    def model(key):
        return func.sample('n1', dist.dirichlet(jnp.array([1.0, 0.5])), key)

    off_simplex = jnp.array([0.8, 0.7])
    sample = func.condition(model, {'n1': off_simplex})(jax.random.PRNGKey(123))
    assert jnp.all(jnp.isnan(sample))
def test_sample_beta_conditioned_invalid_value_error():
    """Conditioning a beta site on 2.0 (outside (0, 1)) should give a NaN sample.

    Despite the "error" in the name, no exception is expected; the returned
    sample itself is checked for NaN.
    """
    def model(key):
        return func.sample('n1', dist.beta(jnp.array(0.5), jnp.array(0.5)), key)

    out_of_support = jnp.array(2.0)
    sample = func.condition(model, {'n1': out_of_support})(jax.random.PRNGKey(123))
    assert jnp.isnan(sample)
def test_sample_multinomial_invalid_condition():
    """Conditioning a categorical site on [2, 0, 0] should give an all-NaN sample.

    The conditioned value presumably reads as per-category counts, recording
    more than one outcome, which the test expects to surface as NaN.
    NOTE(review): the test name says "multinomial" but the model uses
    ``dist.categorical``; confirm which distribution is intended.
    """
    def model(key):
        return func.sample('n1', dist.categorical(jnp.array([0.1, 0.8, 0.1])), key)

    more_than_one_outcome = jnp.array([2, 0, 0])
    sample = func.condition(model, {'n1': more_than_one_outcome})(jax.random.PRNGKey(123))
    assert jnp.all(jnp.isnan(sample))
def test_sample_conditioned_bernoulli_binomial():
    """Sample n2 = binomial(n1, 0.5) with the bernoulli site n1 conditioned to 1.

    Renamed from ``test_sample_conditioned``. A later test in this module
    reuses that name, so this definition was silently shadowed and pytest
    never collected it.

    Assuming ``dist.binomial(n, p)`` argument order, fixing n1 = 1 makes n2 a
    single trial with p = 0.5, so the mean of 100 draws should land near 0.5.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.bernoulli(jnp.array(0.5)), key_n1)
        return func.sample('n2', dist.binomial(n1, jnp.array(0.5)), key_n2)

    conditioned_model = func.condition(model, {'n1': jnp.array(1)})
    keys = jax.random.split(jax.random.PRNGKey(123), 100)
    # vmap the model directly; the former `lambda k: conditioned_model(k)` was redundant.
    samples = jax.vmap(conditioned_model)(keys)
    assert 0.4 < jnp.mean(samples) < 0.6
def test_sample_conditioned():
    """With the weight site pinned to 0, the mean measurement should be close to 0.

    measurement ~ normal(weight, 1); 100 independent draws are taken by
    vmapping the conditioned model over 100 split keys.
    """
    def model(key):
        key_weight, key_measurement = jax.random.split(key)
        weight = func.sample('weight', dist.normal(jnp.array(0.), jnp.array(1.)), key_weight)
        return func.sample('measurement', dist.normal(weight, jnp.array(1.)), key_measurement)

    draw = func.condition(model, {'weight': jnp.array(0.)})
    measurements = jax.vmap(draw)(jax.random.split(jax.random.PRNGKey(123), 100))
    assert abs(jnp.mean(measurements)) < 0.2
def test_log_prob_normal_normal():
    """Log-probability of a fully conditioned normal-normal trace.

    With n1 = 0 and n2 = 1 under n1 ~ normal(0, 1), n2 ~ normal(n1, 1), the
    expected total is log N(0 | 0, 1) + log N(1 | 0, 1)
    = -0.9189385 + (-1.4189385) = -2.337877.
    """
    def model(key):
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)), key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    observed = {'n1': jnp.array(0.), 'n2': jnp.array(1.)}
    tracer = func.trace(func.condition(model, observed))
    tracer(jax.random.PRNGKey(123))
    tu.check_close(func.log_prob(tracer.get_tree()), -2.337877)
def test_metropolis_hastings_wrong_proposal_or_initial_samples():
    """Metropolis-Hastings should reject a proposal/initial state for an unknown site.

    The model only defines sites 'n1' and 'n2', but both the proposal and the
    initial samples refer to a site 'n4'. A RuntimeError is expected either
    when the sampler is constructed or when it is first stepped.
    """
    def model(key):
        # Split the key per site (as the other models in this module do)
        # instead of reusing one key for two independent draws.
        key_n1, key_n2 = jax.random.split(key)
        n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)), key_n1)
        return func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)

    def proposal(key, **current):
        n4 = func.sample('n4', dist.normal(current['n4'], jnp.array(1.)), key)
        return {'n4': n4}

    conditioned_model = func.condition(model, {'n2': jnp.array(0.5)})
    initial_samples = {'n4': jnp.array(0.)}
    key = jax.random.PRNGKey(123)
    # Only the calls that may raise stay inside the context manager.
    with pytest.raises(RuntimeError):  # wrong proposal and initial samples
        mcmc_model = func.metropolis_hastings(conditioned_model, proposal,
                                              initial_samples, num_chains=1)
        mcmc_model(key)
def test_log_prob_normal():
    """Log-probability of a single normal site, unconditioned and conditioned.

    Unconditioned, n1 is drawn using PRNGKey(123), so the expected value is a
    regression value for that particular draw. Conditioned on n1 = 0, the
    expected value is log N(0 | 0, 1) = -0.5 * log(2 * pi), approximately -0.9189385.
    """
    def model(key):
        return func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)), key)

    def traced_log_prob(fn):
        # Run `fn` once under a tracer with a fixed key and score the recorded tree.
        tracer = func.trace(fn)
        tracer(jax.random.PRNGKey(123))
        return func.log_prob(tracer.get_tree())

    # Unconditioned
    tu.check_close(traced_log_prob(model), -1.2025023)
    # Conditioned
    tu.check_close(traced_log_prob(func.condition(model, {'n1': jnp.array(0.)})), -0.9189385)