def test_simple_model_eq():
    """Check that a Switch over Y agrees with its hand-unrolled IfElse form.

    Both programs sample Y from randint(low=0, high=4) and then sample X
    from bernoulli(p=1 / (Y + 1)).  Each interpreted model must expose
    exactly the symbols {X, Y} and give the same log-probability for
    X << {1}.
    """
    via_switch = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(
            Y,
            range(0, 4),
            lambda i: Sample(X, bernoulli(p=1 / (i + 1))),
        ),
    ).interpret()

    # The same program with every Switch case written out explicitly.
    via_ifelse = Sequence(
        Sample(Y, randint(low=0, high=4)),
        IfElse(
            Y << {0}, Sample(X, bernoulli(p=1 / (0 + 1))),
            Y << {1}, Sample(X, bernoulli(p=1 / (1 + 1))),
            Y << {2}, Sample(X, bernoulli(p=1 / (2 + 1))),
            Y << {3}, Sample(X, bernoulli(p=1 / (3 + 1))),
        ),
    ).interpret()

    # log sum_i (1/4) * 1/(i + 1): marginalise Y out of the joint.
    expected_logprob = logsumexp([-log(4) - log(i + 1) for i in range(4)])

    for candidate in (via_switch, via_ifelse):
        assert candidate.get_symbols() == {X, Y}
        assert allclose(candidate.logprob(X << {1}), expected_logprob)
def get_command_randint():
    """Build the command that samples simAll and then five indexed trials.

    simAll is drawn from randint(low=0, high=ns).  For each k in the For
    loop over (0, 5), the command switches on the value i of simAll and
    samples sim[k], p1[k], p2[k], clickA[k] and clickB[k] as described by
    the local builders below.

    Returns:
        The composed Sequence command (not yet interpreted).
    """

    def clicks_when_similar(k, i):
        # NOTE(review): the case value j of p1[k] is unused here; both
        # clicks use the outer simAll case i.  Kept exactly as in the
        # original -- confirm this is intended.
        return Switch(
            p1[k],
            range(ns),
            lambda j: Sequence(
                Sample(clickA[k], bernoulli(p=i / nd)),
                Sample(clickB[k], bernoulli(p=i / nd)),
            ),
        )

    def similar_branch(k, i):
        # p2[k] is tied to p1[k] via a Transform before the clicks are drawn.
        return Sequence(
            Transform(p2[k], p1[k]),
            clicks_when_similar(k, i),
        )

    def dissimilar_branch(k):
        # p2[k] is sampled on its own; each click depends on its own p value.
        return Sequence(
            Sample(p2[k], randint(low=0, high=ns)),
            Switch(
                p1[k],
                range(ns),
                lambda j: Sample(clickA[k], bernoulli(p=j / nd)),
            ),
            Switch(
                p2[k],
                range(ns),
                lambda j: Sample(clickB[k], bernoulli(p=j / nd)),
            ),
        )

    def trial_given_sim_all(k, i):
        return Sequence(
            Sample(sim[k], bernoulli(p=i / nd)),
            Sample(p1[k], randint(low=0, high=ns)),
            IfElse(
                sim[k] << {1}, similar_branch(k, i),
                True, dissimilar_branch(k),
            ),
        )

    def trial(k):
        return Switch(
            simAll,
            range(0, ns),
            lambda i: trial_given_sim_all(k, i),
        )

    return Sequence(
        Sample(simAll, randint(low=0, high=ns)),
        For(0, 5, trial),
    )
def test_error_range():
    """A Switch that omits a reachable value of Y must raise AssertionError.

    Y is sampled from randint(low=0, high=4) but the Switch only lists
    the cases in range(0, 3).
    """
    with pytest.raises(AssertionError):
        # Switch cases do not sum to one.
        Sequence(
            Sample(Y, randint(low=0, high=4)),
            Switch(
                Y,
                range(0, 3),
                lambda i: Sample(X, bernoulli(p=1 / (i + 1))),
            ),
        ).interpret()
def test_simple_model_enumerate():
    """Switch over enumerate(...) passes both index and value to the body.

    With enumerate(range(0, 4)) the index i and value j coincide, so the
    joint probability of Y == v and X == 1 is .25 * 1 / (v + v + 1).
    """
    model = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(
            Y,
            enumerate(range(0, 4)),
            lambda i, j: Sample(X, bernoulli(p=1 / (i + j + 1))),
        ),
    ).interpret()

    for v in range(4):
        joint = model.prob((Y << {v}) & (X << {1}))
        assert allclose(joint, .25 * 1 / (v + v + 1))
def test_randint():
    """Check the support of randint and conditioning on a split event.

    randint(low=0, high=5) must report bounds xl == 0 and xu == 4.
    Conditioning on X + 1 not in {1, 4} must yield a SumSPE whose two
    children cover [1, 2] and [4, 4], in either order.
    """
    X = Id('X')
    spe = X >> randint(low=0, high=5)
    assert spe.xl == 0
    assert spe.xu == 4

    logprob_lt_5 = spe.logprob(X < 5)
    logprob_le_4 = spe.logprob(X <= 4)
    assert logprob_lt_5 == logprob_le_4 == 0

    # i.e., X is not in {0, 3}
    conditioned = spe.condition(~((X + 1) << {1, 4}))
    assert isinstance(conditioned, SumSPE)

    # The children may come in either order; find the one starting at 1.
    parts = conditioned.children
    if parts[0].xl == 1:
        lower, upper = parts[0], parts[1]
    else:
        lower, upper = parts[1], parts[0]

    assert lower.xl == 1
    assert lower.xu == 2
    assert upper.xl == 4
    assert upper.xu == 4
    assert allclose(lower.logprob(X << {1, 2}), 0)
    assert allclose(upper.logprob(X << {4}), 0)
def test_ifelse_zero_conditions():
    """IfElse with five listed conditions must yield a three-child model.

    Y is sampled from randint(low=0, high=3), while the IfElse also lists
    the conditions Y << {-1} and Y << {3}.  The interpreted model must
    have three children and three weights, with weights[v] equal to the
    log-probability of Y << {v} for v in 0, 1, 2.
    """
    model = Sequence(
        Sample(Y, randint(low=0, high=3)),
        IfElse(
            Y << {-1}, Transform(X, Y**(-1)),
            Y << {0}, Sample(X, bernoulli(p=1)),
            Y << {1}, Transform(X, Y),
            Y << {2}, Transform(X, Y**2),
            Y << {3}, Transform(X, Y**3),
        ),
    ).interpret()

    assert len(model.children) == 3
    assert len(model.weights) == 3
    for v in range(3):
        assert allclose(model.weights[v], model.logprob(Y << {v}))
def test_condition_real_discrete_no_range():
    """Condition a discrete uniform variable on a non-contiguous set.

    Y is sampled from randint(low=0, high=4) and then conditioned on
    Y << {0, 2}.  The conditioned model must split its mass evenly
    between 0 and 2 and assign none to the excluded value 1.
    """
    command = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Condition(Y << {0, 2}))
    model = command.interpret()
    assert allclose(model.prob(Y << {0}), .5)
    # 1 sits between the retained values but is outside the conditioning
    # set, so it must carry zero probability.
    assert allclose(model.prob(Y << {1}), 0)
    assert allclose(model.prob(Y << {2}), .5)