Exemplo n.º 1
0
def test_simple_model_eq():
    """A Switch over Y must interpret to the same model as the explicit IfElse.

    Y ~ randint(low=0, high=4), i.e. uniform on {0, 1, 2, 3}, and given
    Y = i the program samples X ~ Bernoulli(1 / (i + 1)). Both encodings are
    interpreted; each must expose exactly the symbols {X, Y} and give the
    same marginal log-probability for X = 1.
    """
    switch_program = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(Y, range(0, 4), lambda i: Sample(X, bernoulli(p=1 / (i + 1)))))
    switch_model = switch_program.interpret()

    ifelse_program = Sequence(
        Sample(Y, randint(low=0, high=4)),
        IfElse(
            Y << {0}, Sample(X, bernoulli(p=1 / (0 + 1))),
            Y << {1}, Sample(X, bernoulli(p=1 / (1 + 1))),
            Y << {2}, Sample(X, bernoulli(p=1 / (2 + 1))),
            Y << {3}, Sample(X, bernoulli(p=1 / (3 + 1))),
        ))
    ifelse_model = ifelse_program.interpret()

    # P(X = 1) = sum_i (1/4) * 1/(i + 1), evaluated in log space.
    expected_logprob = logsumexp([-log(4) - log(i + 1) for i in range(4)])
    for candidate in (switch_model, ifelse_model):
        assert candidate.get_symbols() == {X, Y}
        assert allclose(candidate.logprob(X << {1}), expected_logprob)
Exemplo n.º 2
0
def get_command_randint():
    """Build the click-graph command with a discrete (randint) similarity prior.

    Uses names defined outside this function: the variables ``simAll``,
    ``sim``, ``p1``, ``p2``, ``clickA``, ``clickB`` (all but ``simAll``
    indexed by ``k``) and the integers ``ns`` and ``nd``.

    Structure of the returned program:
      - ``simAll`` ~ randint(low=0, high=ns);
      - for each ``k`` produced by ``For(0, 5, ...)`` and each value ``i`` of
        ``simAll``:
          * ``sim[k]`` ~ Bernoulli(i / nd);
          * ``p1[k]`` ~ randint(low=0, high=ns);
          * if ``sim[k] == 1``: ``p2[k]`` is set to ``p1[k]`` and both clicks
            are drawn from Bernoulli(i / nd);
          * otherwise ``p2[k]`` is drawn independently and ``clickA[k]`` /
            ``clickB[k]`` are Bernoulli(j / nd), where ``j`` is the value of
            ``p1[k]`` / ``p2[k]`` respectively.

    Returns:
        The uninterpreted ``Sequence`` command (call ``.interpret()`` on it).
    """
    return Sequence(
        Sample(simAll, randint(low=0, high=ns)),
        For(
            0, 5, lambda k: Switch(
                simAll, range(0, ns), lambda i: Sequence(
                    Sample(sim[k], bernoulli(p=i / nd)),
                    Sample(p1[k], randint(low=0, high=ns)),
                    IfElse(
                        sim[k] << {1},
                        # Similar branch: p2 is a deterministic copy of p1.
                        Sequence(
                            Transform(p2[k], p1[k]),
                            # NOTE(review): `j` (the value of p1[k]) is unused
                            # here; both clicks use `i / nd` (the simAll value),
                            # whereas the else-branch uses `j / nd`. Confirm
                            # this is intended before relying on this fixture.
                            Switch(
                                p1[k], range(ns), lambda j: Sequence(
                                    Sample(clickA[k], bernoulli(p=i / nd)),
                                    Sample(clickB[k], bernoulli(p=i / nd))))),
                        True,
                        # Dissimilar branch: p2 is drawn independently and
                        # each click depends on its own p value.
                        Sequence(
                            Sample(p2[k], randint(low=0, high=ns)),
                            Switch(
                                p1[k], range(ns), lambda j: Sample(
                                    clickA[k], bernoulli(p=j / nd))),
                            Switch(
                                p2[k], range(ns), lambda j: Sample(
                                    clickB[k], bernoulli(p=j / nd)))))))))
Exemplo n.º 3
0
def test_error_range():
    """Interpreting a Switch whose cases miss part of Y's support must fail.

    Y ~ randint(low=0, high=4) can be 3, but the Switch only enumerates
    range(0, 3), so the case weights do not sum to one and an
    AssertionError is expected.
    """
    def branch(i):
        return Sample(X, bernoulli(p=1 / (i + 1)))

    with pytest.raises(AssertionError):
        program = Sequence(
            Sample(Y, randint(low=0, high=4)),
            Switch(Y, range(0, 3), branch))
        program.interpret()
Exemplo n.º 4
0
def test_simple_model_enumerate():
    """Switch over enumerate(...) hands both pair elements to the branch.

    Y ~ randint(low=0, high=4) is uniform on {0, 1, 2, 3}. Each pair produced
    by enumerate(range(0, 4)) has equal elements, so the branch taken for
    Y = k samples X ~ Bernoulli(1 / (k + k + 1)). The joint probability
    P(Y = k, X = 1) is therefore .25 / (2k + 1).
    """
    program = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(Y, enumerate(range(0, 4)),
               lambda idx, val: Sample(X, bernoulli(p=1 / (idx + val + 1)))))
    model = program.interpret()
    for k in range(4):
        expected = .25 * 1 / (k + k + 1)
        assert allclose(model.prob(Y << {k} & (X << {1})), expected)
Exemplo n.º 5
0
def test_randint():
    """Check support bounds and point-exclusion conditioning for randint.

    randint(low=0, high=5) has support {0, ..., 4}. Conditioning on
    ~((X + 1) << {1, 4}) removes the points X = 0 and X = 3, which leaves
    the support {1, 2} and {4}. The result is expected to be a SumSPE with
    one child per contiguous block.
    """
    X = Id('X')
    spe = X >> randint(low=0, high=5)
    assert spe.xl == 0
    assert spe.xu == 4
    assert spe.logprob(X < 5) == spe.logprob(X <= 4) == 0
    # (X + 1) not in {1, 4}, i.e., X is not in the set {0, 3}.
    spe_condition = spe.condition(~((X + 1) << {1, 4}))
    assert isinstance(spe_condition, SumSPE)
    # Do not rely on child order: identify each child by its lower bound.
    first_xl = spe_condition.children[0].xl
    idx0, idx1 = (0, 1) if first_xl == 1 else (1, 0)
    assert spe_condition.children[idx0].xl == 1
    assert spe_condition.children[idx0].xu == 2
    assert spe_condition.children[idx1].xl == 4
    assert spe_condition.children[idx1].xu == 4
    assert allclose(spe_condition.children[idx0].logprob(X << {1, 2}), 0)
    assert allclose(spe_condition.children[idx1].logprob(X << {4}), 0)
Exemplo n.º 6
0
def test_ifelse_zero_conditions():
    """IfElse branches whose condition is impossible do not appear in the model.

    Y ~ randint(low=0, high=3) takes values in {0, 1, 2}, so the Y == -1 and
    Y == 3 branches have probability zero. The test expects exactly three
    children and weights, and checks that weight k equals logprob(Y == k).
    """
    program = Sequence(
        Sample(Y, randint(low=0, high=3)),
        IfElse(
            Y << {-1}, Transform(X, Y**(-1)),
            Y << {0}, Sample(X, bernoulli(p=1)),
            Y << {1}, Transform(X, Y),
            Y << {2}, Transform(X, Y**2),
            Y << {3}, Transform(X, Y**3),
        ))
    model = program.interpret()
    assert len(model.children) == 3
    assert len(model.weights) == 3
    for branch in range(3):
        assert allclose(model.weights[branch], model.logprob(Y << {branch}))
Exemplo n.º 7
0
def test_condition_real_discrete_no_range():
    """Conditioning a discrete uniform on a finite set renormalizes over that set.

    Y ~ randint(low=0, high=4) is uniform on {0, 1, 2, 3}. After
    Condition(Y << {0, 2}), the posterior puts mass 1/2 on each of 0 and 2
    and no mass on the excluded values 1 and 3.
    """
    command = Sequence(Sample(Y, randint(low=0, high=4)),
                       Condition(Y << {0, 2}))
    model = command.interpret()
    assert allclose(model.prob(Y << {0}), .5)
    # The other retained value is 2; the original check on Y == 1 asserted
    # positive mass for a value excluded by the condition.
    assert allclose(model.prob(Y << {2}), .5)
    assert allclose(model.prob(Y << {1}), 0)
    assert allclose(model.prob(Y << {3}), 0)