def test_simple_model_lte():
    """Check a Switch over bins against an equivalent explicit IfElse chain.

    Both programs sample Y from beta(a=2, b=3) and then sample X from a
    Bernoulli whose parameter depends on Y. For each interpreted model the
    log-probability of X = 1 must equal the log-sum-exp, over consecutive
    grid intervals (il, ih], of logprob(il < Y <= ih) + log(ih).
    """
    switch_program = Sequence(
        Sample(Y, beta(a=2, b=3)),
        Switch(Y, binspace(0, 1, 5), lambda i: Sample(X, bernoulli(p=i.right))),
    )
    model_switch = switch_program.interpret()

    # Build the (condition, command) pairs of the IfElse chain in order:
    # each threshold t contributes `Y <= t` followed by X ~ bernoulli(p=t).
    branches = []
    for threshold in (0, 0.25, 0.50, 0.75, 1):
        branches.append(Y <= threshold)
        branches.append(Sample(X, bernoulli(p=threshold)))
    ifelse_program = Sequence(
        Sample(Y, beta(a=2, b=3)),
        IfElse(*branches),
    )
    model_ifelse = ifelse_program.interpret()

    grid = [float(x) for x in linspace(0, 1, 5)]
    for model in (model_switch, model_ifelse):
        assert model.get_symbols() == {X, Y}
        actual = model.logprob(X << {1})
        # One term per interval: log P(il < Y <= ih) + log(ih).
        terms = []
        for il, ih in zip(grid[:-1], grid[1:]):
            terms.append(model.logprob((il < Y) <= ih) + log(ih))
        assert allclose(actual, logsumexp(terms))
def test_error_linspace():
    """Expect AssertionError for a Switch over raw linspace(0, .5, 5) values.

    Both building the command and interpreting it happen inside the
    ``pytest.raises`` block, so the error may come from either step.
    """
    with pytest.raises(AssertionError):
        # Switch cases do not sum to one.
        draw_y = Sample(Y, beta(a=2, b=3))
        branch_on_y = Switch(
            Y,
            linspace(0, .5, 5),
            lambda value: Sample(X, bernoulli(p=value)),
        )
        Sequence(draw_y, branch_on_y).interpret()
def get_command_beta():
    """Build the command for the beta-prior click model over indices 0..4.

    The program samples ``simAll`` from beta(a=2, b=3) and then, for each
    index k passed by ``For(0, 5, ...)``, switches on ``simAll`` over
    ``binspace(0, 1, ns)``. Inside each case it samples ``sim[k]`` and
    ``p1[k]``, then either ties ``p2[k]`` to ``p1[k]`` (when ``sim[k]`` is 1)
    or samples ``p2[k]`` independently, and finally samples ``clickA[k]``
    and ``clickB[k]``.

    Returns:
        The un-interpreted ``Sequence`` command.
    """

    def per_index(k):
        # Body of the For loop for a single index k.

        def per_sim_bin(i):
            # Commands for the case where simAll falls in bin i.

            def tied_clicks(j):
                # NOTE(review): this case is selected by the p1[k] bin `j`,
                # yet both clicks use `i.right` (the simAll bin) and `j` is
                # unused. Possibly intended to be `j.right` -- confirm with
                # the model's author before changing; kept as-is here.
                return Sequence(
                    Sample(clickA[k], bernoulli(p=i.right)),
                    Sample(clickB[k], bernoulli(p=i.right)),
                )

            def click_a_given_p1(j):
                return Sample(clickA[k], bernoulli(p=j.right))

            def click_b_given_p2(j):
                return Sample(clickB[k], bernoulli(p=j.right))

            return Sequence(
                Sample(sim[k], bernoulli(p=i.right)),
                Sample(p1[k], uniform()),
                IfElse(
                    sim[k] << {1},
                    Sequence(
                        Transform(p2[k], p1[k]),
                        Switch(p1[k], binspace(0, 1, ns), tied_clicks),
                    ),
                    True,
                    Sequence(
                        Sample(p2[k], uniform()),
                        Switch(p1[k], binspace(0, 1, ns), click_a_given_p1),
                        Switch(p2[k], binspace(0, 1, ns), click_b_given_p2),
                    ),
                ),
            )

        return Switch(simAll, binspace(0, 1, ns), per_sim_bin)

    return Sequence(
        Sample(simAll, beta(a=2, b=3)),
        For(0, 5, per_index),
    )
def test_condition_prob_zero():
    """Conditioning on an event the model cannot produce must raise.

    Two cases are exercised: a nominal variable conditioned on a label that
    is not among the distribution's keys, and a beta(a=1, b=1) variable
    conditioned on ``Y > 1``.
    """
    # Nominal case: 'd' is not one of the sampled labels.
    with pytest.raises(Exception):
        draw_label = Sample(Y, {'a': .1, 'b': .1, 'c': .8})
        Sequence(draw_label, Condition(Y << {'d'})).interpret()

    # Continuous case: condition on Y exceeding 1.
    with pytest.raises(Exception):
        draw_real = Sample(Y, beta(a=1, b=1))
        Sequence(draw_real, Condition(Y > 1)).interpret()
def test_condition_real_continuous():
    """After conditioning on Y < .5, all probability mass lies below .5.

    The interpreted model must report prob(Y < .5) close to 1 and
    prob(Y > .5) close to 0.
    """
    model = Sequence(
        Sample(Y, beta(a=1, b=1)),
        Condition(Y < .5),
    ).interpret()

    expectations = (
        (Y < .5, 1),
        (Y > .5, 0),
    )
    for event, expected in expectations:
        assert allclose(model.prob(event), expected)