コード例 #1
0
ファイル: test_ast_switch.py プロジェクト: probcomp/sppl
def test_simple_model_lte():
    """Check the X=1 log-probability of a Switch program and an IfElse program.

    Two programs are interpreted: one whose branches are generated by
    ``Switch`` over ``binspace(0, 1, 5)``, and one whose branches are spelled
    out by hand in an ``IfElse`` with ``Y <= threshold`` conditions. For each
    resulting model the test asserts that

    * the model's symbols are exactly ``{X, Y}``, and
    * ``logprob(X << {1})`` matches the logsumexp, over consecutive pairs
      ``(lo, hi)`` of ``linspace(0, 1, 5)``, of
      ``logprob((lo < Y) <= hi) + log(hi)``.
    """
    # Program 1: branches generated by Switch; each branch samples X with
    # success probability taken from the bin's ``right`` attribute.
    switch_program = Sequence(
        Sample(Y, beta(a=2, b=3)),
        Switch(
            Y,
            binspace(0, 1, 5),
            lambda i: Sample(X, bernoulli(p=i.right)),
        ),
    )
    model_switch = switch_program.interpret()

    # Program 2: the same prior on Y followed by an explicit chain of
    # (condition, command) pairs handed to IfElse as flat positional arguments.
    prior_on_y = Sample(Y, beta(a=2, b=3))
    ifelse_branches = [
        Y <= 0, Sample(X, bernoulli(p=0)),
        Y <= 0.25, Sample(X, bernoulli(p=.25)),
        Y <= 0.50, Sample(X, bernoulli(p=.50)),
        Y <= 0.75, Sample(X, bernoulli(p=.75)),
        Y <= 1, Sample(X, bernoulli(p=1)),
    ]
    ifelse_program = Sequence(prior_on_y, IfElse(*ifelse_branches))
    model_ifelse = ifelse_program.interpret()

    # Consecutive (lower, upper) edges of the five-point grid on [0, 1].
    edges = list(map(float, linspace(0, 1, 5)))
    intervals = list(zip(edges[:-1], edges[1:]))

    for model in (model_switch, model_ifelse):
        assert model.get_symbols() == {X, Y}

        observed = model.logprob(X << {1})
        terms = []
        for lo, hi in intervals:
            # Log-probability of Y falling in (lo, hi] plus log of the
            # Bernoulli success probability used on that interval.
            terms.append(model.logprob((lo < Y) <= hi) + log(hi))
        assert allclose(observed, logsumexp(terms))
コード例 #2
0
ファイル: test_ast_switch.py プロジェクト: probcomp/sppl
def test_error_linspace():
    """Expect an AssertionError for a Switch over ``linspace(0, .5, 5)``.

    The program samples ``Y`` from ``beta(a=2, b=3)`` and then switches on
    ``Y`` over the points of ``linspace(0, .5, 5)``, sampling ``X`` from a
    Bernoulli whose parameter is the switch value. Building and interpreting
    the program both happen inside the ``pytest.raises`` context.
    """
    def sample_x(i):
        # Branch body: X is Bernoulli with success probability equal to the
        # switch value itself.
        return Sample(X, bernoulli(p=i))

    with pytest.raises(AssertionError):
        # Per the original author's note: the Switch cases do not sum to one.
        Sequence(
            Sample(Y, beta(a=2, b=3)),
            Switch(Y, linspace(0, .5, 5), sample_x),
        ).interpret()
コード例 #3
0
def get_command_beta():
    """Build and return the command with a beta(a=2, b=3) prior on ``simAll``.

    Structure of the returned ``Sequence``:

    1. ``simAll`` is sampled from ``beta(a=2, b=3)``.
    2. ``For(0, 5, ...)`` produces, for each index ``k``, a ``Switch`` on
       ``simAll`` over ``binspace(0, 1, ns)``. For each bin ``i`` it
       samples ``sim[k]`` from ``bernoulli(p=i.right)`` and ``p1[k]`` from
       ``uniform()``, then branches with ``IfElse``:

       * if ``sim[k] << {1}``: ``p2[k]`` is defined as a transform of
         ``p1[k]`` and both ``clickA[k]`` and ``clickB[k]`` are sampled;
       * otherwise (``True``): ``p2[k]`` is sampled from ``uniform()``, and
         ``clickA[k]`` / ``clickB[k]`` are sampled through switches on
         ``p1[k]`` / ``p2[k]`` respectively.

    The command is only constructed here; it is not interpreted.
    """
    def branch_sim_is_one(k, i):
        # NOTE(review): the inner switch iterates over bins ``j`` of p1[k]
        # but both click distributions use ``i.right`` (the simAll bin), so
        # ``j`` is unused. This mirrors the original code exactly; confirm
        # whether ``j.right`` was intended.
        return Sequence(
            Transform(p2[k], p1[k]),
            Switch(
                p1[k],
                binspace(0, 1, ns),
                lambda j: Sequence(
                    Sample(clickA[k], bernoulli(p=i.right)),
                    Sample(clickB[k], bernoulli(p=i.right)),
                ),
            ),
        )

    def branch_otherwise(k):
        # p2[k] is drawn independently, and each click variable is driven
        # by the bin of its own probability variable.
        return Sequence(
            Sample(p2[k], uniform()),
            Switch(
                p1[k],
                binspace(0, 1, ns),
                lambda j: Sample(clickA[k], bernoulli(p=j.right)),
            ),
            Switch(
                p2[k],
                binspace(0, 1, ns),
                lambda j: Sample(clickB[k], bernoulli(p=j.right)),
            ),
        )

    def body_for_bin(k, i):
        # Commands executed for index k when simAll falls in bin i.
        return Sequence(
            Sample(sim[k], bernoulli(p=i.right)),
            Sample(p1[k], uniform()),
            IfElse(
                sim[k] << {1},
                branch_sim_is_one(k, i),
                True,
                branch_otherwise(k),
            ),
        )

    def body_for_index(k):
        # One Switch on simAll per loop index k.
        return Switch(
            simAll,
            binspace(0, 1, ns),
            lambda i: body_for_bin(k, i),
        )

    return Sequence(
        Sample(simAll, beta(a=2, b=3)),
        For(0, 5, body_for_index),
    )
コード例 #4
0
ファイル: test_ast_condition.py プロジェクト: probcomp/sppl
def test_condition_prob_zero():
    """Expect an exception when interpreting two Condition programs.

    Case 1: ``Y`` is sampled from a dict of weights over ``'a'``, ``'b'``,
    ``'c'`` and then conditioned on ``Y << {'d'}``, an outcome that is not
    among the dict's keys.

    Case 2: ``Y`` is sampled from ``beta(a=1, b=1)`` and then conditioned on
    ``Y > 1``.

    In both cases program construction and interpretation run inside
    ``pytest.raises(Exception)``.
    """
    # Case 1: nominal outcome absent from the sampled distribution.
    nominal_weights = {'a': .1, 'b': .1, 'c': .8}
    with pytest.raises(Exception):
        nominal_program = Sequence(
            Sample(Y, nominal_weights),
            Condition(Y << {'d'}),
        )
        nominal_program.interpret()

    # Case 2: real-valued event Y > 1 under a beta(a=1, b=1) prior.
    with pytest.raises(Exception):
        real_program = Sequence(
            Sample(Y, beta(a=1, b=1)),
            Condition(Y > 1),
        )
        real_program.interpret()
コード例 #5
0
ファイル: test_ast_condition.py プロジェクト: probcomp/sppl
def test_condition_real_continuous():
    """Check probabilities after conditioning a beta(a=1, b=1) variable on Y < .5.

    After interpreting the program, the model must assign probability
    (approximately) 1 to ``Y < .5`` and (approximately) 0 to ``Y > .5``.
    """
    model = Sequence(
        Sample(Y, beta(a=1, b=1)),
        Condition(Y < .5),
    ).interpret()

    # The conditioned event itself should now be certain.
    prob_below = model.prob(Y < .5)
    assert allclose(prob_below, 1)

    # The event on the other side of the threshold should have no mass.
    prob_above = model.prob(Y > .5)
    assert allclose(prob_above, 0)