Example #1
0
def test_sum_simplify_nested_sum_2():
    """Nested SumSPE children are flattened into their parent SumSPE.

    Children of a nested sum are spliced into the parent in order, and
    each one's log weight is offset by the log weight of the nested sum
    that contained it.  A non-sum child keeps its own weight.
    """
    X = Id('X')
    W = Id('W')
    children = [
        SumSPE([
            (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)),
            (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1))],
            [log(0.9), log(0.1)]),
        (X >> norm(loc=0, scale=4)) & (W >> norm(loc=0, scale=10)),
        SumSPE([
            (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)),
            (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1)),
            (X >> norm(loc=0, scale=8)) & (W >> norm(loc=0, scale=3)),],
            [log(0.4), log(0.3), log(0.3)]),
    ]
    spe = SumSPE(children, [log(0.4), log(0.4), log(0.2)])
    # Presumably 1 root + 6 flattened children x (1 product + 2 leaves) = 19.
    assert spe.size() == 19
    assert spe.children == (
        children[0].children[0], # 2 leaves
        children[0].children[1], # 2 leaves
        children[1],             # 2 leaves
        children[2].children[0], # 2 leaves
        children[2].children[1], # 2 leaves
        children[2].children[2], # 2 leaves
    )
    # Outer weight of the enclosing child + inner weight (log space).
    expected_weights = [
        log(0.4) + log(0.9),
        log(0.4) + log(0.1),
        log(0.4),
        log(0.2) + log(0.4),
        log(0.2) + log(0.3),
        log(0.2) + log(0.3),
    ]
    assert len(spe.weights) == len(expected_weights)
    for actual, expected in zip(spe.weights, expected_weights):
        assert allclose(actual, expected)
Example #2
0
def test_logpdf_mixture_nominal():
    """Check SumSPE.logpdf on a mixture of a real leaf and a nominal leaf.

    A numeric assignment is expected to equal the first child's weighted
    logpdf, and a string assignment the second child's weighted logpdf.
    """
    log_weights = [log(.4), log(.6)]
    components = [X >> norm(), X >> choice({'a':.1, 'b':.9})]
    spe = SumSPE(components, log_weights)
    for child_index, value in ((0, .5), (1, 'a')):
        assignment = {X: value}
        expected = (
            log_weights[child_index]
            + spe.children[child_index].logpdf(assignment))
        assert allclose(spe.logpdf(assignment), expected)
Example #3
0
def test_sum_simplify_product_complex():
    """spe_simplify_sum factors shared factors out of a sum of products.

    D appears in all three products and ends up as a top-level factor.
    C appears in the first two products and is factored out one level
    further down, above the remaining sum over the A/B factors.
    """
    def leaf(name, scale):
        # Fresh normal leaf on variable `name`.
        return Id(name) >> norm(loc=0, scale=scale)

    A1, A0 = leaf('A', 1), leaf('A', 2)
    B, B1, B0 = leaf('B', 1), leaf('B', 2), leaf('B', 3)
    C, C1 = leaf('C', 1), leaf('C', 2)
    D = leaf('D', 1)

    products = [
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ]
    spe = SumSPE(products, [log(0.4), log(0.4), log(0.2)])

    # Top level: (sum over the A/B/C factors) x D.
    top = spe_simplify_sum(spe)
    assert isinstance(top, ProductSPE)
    assert isinstance(top.children[0], SumSPE)
    assert top.children[1] == D

    abc_sum = top.children[0]
    # Second term: the product sharing nothing except D stays intact.
    assert isinstance(abc_sum.children[1], ProductSPE)
    assert abc_sum.children[1].children == (A0, B0, C1)

    # First term: (sum over the A/B factors) x C.
    assert isinstance(abc_sum.children[0], ProductSPE)
    assert abc_sum.children[0].children[1] == C

    ab_sum = abc_sum.children[0].children[0]
    assert isinstance(ab_sum, SumSPE)
    assert isinstance(ab_sum.children[0], ProductSPE)
    assert isinstance(ab_sum.children[1], ProductSPE)
    assert ab_sum.children[0].children == (A1, B)
    assert ab_sum.children[1].children == (A0, B1)
Example #4
0
def test_sum_simplify_nested_sum_1():
    """A SumSPE nested inside a SumSPE is flattened into its parent.

    The inner sum's children are spliced in place, and their log weights
    are offset by the log weight the inner sum had in the parent.
    """
    X = Id('X')
    inner_sum = SumSPE(
        [X >> norm(loc=0, scale=1), X >> norm(loc=0, scale=2)],
        [log(0.4), log(0.6)])
    gamma_leaf = X >> gamma(loc=0, a=1)
    spe = SumSPE([inner_sum, gamma_leaf], [log(0.7), log(0.3)])

    assert spe.size() == 4
    assert spe.children == (
        inner_sum.children[0],
        inner_sum.children[1],
        gamma_leaf,
    )
    expected = [log(0.7) + log(0.4), log(0.7) + log(0.6), log(0.3)]
    for position, target in enumerate(expected):
        assert allclose(spe.weights[position], target)
Example #5
0
def test_sum_simplify_product_collapse():
    """A sum of products whose factors are all identical collapses.

    Every leaf is a standard normal, so the three products are equal and
    spe_simplify_sum should reduce the sum to a single product.
    """
    def std_normal(name):
        # Fresh standard-normal leaf on variable `name`.
        return Id(name) >> norm(loc=0, scale=1)

    A1, A0 = std_normal('A'), std_normal('A')
    B, B1, B0 = std_normal('B'), std_normal('B'), std_normal('B')
    C, C1 = std_normal('C'), std_normal('C')
    D = std_normal('D')
    mixture = SumSPE(
        [
            ProductSPE([A1, B, C, D]),
            ProductSPE([A0, B1, C, D]),
            ProductSPE([A0, B0, C1, D]),
        ],
        [log(0.4), log(0.4), log(0.2)],
    )
    assert spe_simplify_sum(mixture) == ProductSPE([A1, B, C, D])
Example #6
0
def test_sum_simplify_leaf():
    """spe_simplify_sum on sums of leaves over a single variable.

    Distinct leaves are left unchanged, identical leaves collapse into
    one leaf, and partially duplicated leaves are merged with their
    weights summed.
    """
    # Case 1: all children differ, so simplification is a no-op.
    distinct = [Id('X') >> norm(loc=0, scale=s) for s in (1, 2, 3)]
    spe = SumSPE(distinct, [log(0.5), log(0.1), log(.4)])
    assert spe.size() == 4
    assert spe_simplify_sum(spe) == spe

    # Case 2: all children are identical, so the sum becomes one leaf.
    Xd0, Xd1, Xd2 = [Id('X') >> norm(loc=0, scale=1) for _ in range(3)]
    spe = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)])
    assert spe_simplify_sum(spe) == Xd0

    # Case 3: duplicates merge; weights add (0.5 + 0.3 and 0.1 + 0.1).
    Xd3 = Id('X') >> norm(loc=0, scale=2)
    spe = SumSPE([Xd0, Xd3, Xd1, Xd3], [log(0.5), log(0.1), log(.3), log(.1)])
    merged = spe_simplify_sum(spe)
    assert len(merged.children) == 2
    assert merged.children[0] == Xd0
    assert merged.children[1] == Xd3
    assert allclose(merged.weights[0], log(0.8))
    assert allclose(merged.weights[1], log(0.2))
Example #7
0
def test_sum_normal_gamma():
    """Mixture of a normal leaf and a gamma leaf over the same variable.

    Exercises logprob, sample, sample_func and condition.  The expected
    values imply that only the normal child contributes mass to X < 0.
    """
    X = Id('X')
    weights = [log(Fraction(2, 3)), log(Fraction(1, 3))]
    spe = SumSPE([X >> norm(loc=0, scale=1), X >> gamma(loc=0, a=1)], weights)

    def mixture_logprob(event):
        # logprob of `event` assembled by hand from the two children.
        return logsumexp([
            spe.weights[i] + spe.children[i].logprob(event) for i in (0, 1)
        ])

    assert spe.logprob(X > 0) == mixture_logprob(X > 0)
    assert spe.logprob(X < 0) == log(Fraction(2, 3)) + log(Fraction(1, 2))

    draws = spe.sample(100, prng=numpy.random.RandomState(1))
    assert all(draw[X] for draw in draws)
    # The lambda parameter names are significant here: keep them as is.
    spe.sample_func(lambda X: abs(X**3), 100)
    with pytest.raises(ValueError):
        # Presumably rejected because Y is not a variable of spe.
        spe.sample_func(lambda Y: abs(X**3), 100)

    conditioned = spe.condition(X < 0)
    assert isinstance(conditioned, ContinuousLeaf)
    assert conditioned.conditioned
    assert conditioned.logprob(X < 0) == 0
    assert all(draw[X] < 0 for draw in conditioned.sample(100))

    assert spe.logprob(X < 0) == mixture_logprob(X < 0)
Example #8
0
def test_sum_normal_nominal():
    """Mixture of a normal leaf and a nominal leaf over the same variable.

    Numeric events draw mass from the normal child, string events from the
    nominal child, and disjunctions combine both.
    """
    X = Id('X')

    def member(*values):
        # Fresh nominal membership event: X << {values...}.
        return X << set(values)

    nominal_probs = {
        'low': Fraction(3, 10),
        'high': Fraction(7, 10)
    }
    spe = SumSPE(
        [X >> norm(loc=0, scale=1), X >> choice(nominal_probs)],
        [log(Fraction(4, 7)), log(Fraction(3, 7))])

    assert allclose(spe.logprob(X < 0),
                    log(Fraction(4, 7)) + log(Fraction(1, 2)))
    assert allclose(spe.logprob(member('low')),
                    log(Fraction(3, 7)) + log(Fraction(3, 10)))

    # ~(X << {'low'}) means: X is a string and X != 'low'.
    assert allclose(spe.logprob(~member('low')), spe.logprob(member('high')))
    assert allclose(
        spe.logprob((X << FiniteNominal(b=True)) & ~member('low')),
        spe.logprob((X << FiniteNominal(b=True)) & member('high')))

    assert isinf_neg(spe.logprob((X < 0) & member('low')))
    assert allclose(spe.logprob((X < 0) | member('low')),
                    logsumexp([spe.logprob(X < 0),
                               spe.logprob(member('low'))]))

    assert isinf_neg(spe.logprob(member('a')))
    assert allclose(spe.logprob(~member('a')),
                    spe.logprob(member('low', 'high')))

    assert allclose(spe.logprob(X**2 < 9),
                    log(Fraction(4, 7)) + spe.children[0].logprob(X**2 < 9))

    real_only = spe.condition(X**2 < 9)
    assert isinstance(real_only, ContinuousLeaf)
    assert real_only.support == Interval.open(-3, 3)

    both = spe.condition((X**2 < 9) | member('low'))
    assert isinstance(both, SumSPE)
    assert both.children[0].support == Interval.open(-3, 3)
    assert both.children[1].support == FiniteNominal('low', 'high')
    assert isinf_neg(both.children[1].logprob(member('high')))

    assert both == spe.condition((X**2 < 9) | ~member('high'))
    assert allclose(spe.logprob((X < oo) | ~member('1')), 0)
def test_cache_complex_sum_of_product():
    """spe_cache_duplicate_subtrees shares equal-but-distinct subtrees.

    Two equal (==) but distinct (is not) copies of one subtree sit in
    different branches of the model.  After caching, the position that
    held the second copy holds the first copy instead.
    """
    # Test case adapted from the SPE generated by
    # test_repeat.make_model_repeat(n=2)

    def make_shared_subtree():
        # Build a fresh copy of the subtree that occurs in both branches.
        first_term = ProductSPE([
            (X[0] >> bernoulli(p=.1)),
            SumSPE([
                (Z[0] >> bernoulli(p=.5)) & (Y >> choice({'0': .1, '1': .9})),
                (Z[0] >> bernoulli(p=.1)) & (Y >> choice({'0': .9, '1': .1})),
            ], weights=[log(.730), log(.270)]),
        ])
        second_term = ProductSPE([
            Z[0] >> bernoulli(p=.1),
            Y >> choice({'0': .9, '1': .1}),
            X[0] >> bernoulli(p=.5),
        ])
        return SumSPE([first_term, second_term],
                      weights=[log(.925), log(.075)])

    duplicate_subtrees = [make_shared_subtree() for _ in range(2)]

    assert duplicate_subtrees[0] == duplicate_subtrees[1]
    assert duplicate_subtrees[0] is not duplicate_subtrees[1]

    # `>>` binds tighter than `&`; the parentheses only make that explicit.
    inner_mixture = SumSPE([
        (Y >> choice({'0': .3, '1': .7}))
        & (X[0] >> bernoulli(p=.1))
        & (Z[0] >> bernoulli(p=.1)),
        (Y >> choice({'0': .7, '1': .3}))
        & (X[0] >> bernoulli(p=.5))
        & (Z[0] >> bernoulli(p=.5)),
    ], weights=[log(.9), log(.1)])
    left_subtree = ProductSPE([
        X[1] >> bernoulli(p=.5),
        SumSPE([
            ProductSPE([duplicate_subtrees[0], Z[1] >> bernoulli(p=.5)]),
            ProductSPE([Z[1] >> bernoulli(p=.7), inner_mixture]),
        ], weights=[log(.783), log(.217)]),
    ])

    right_subtree = ProductSPE([
        Z[1] >> bernoulli(p=.8),
        X[1] >> bernoulli(p=.1),
        duplicate_subtrees[1],
    ])

    spe = (.92 * left_subtree) | (.08 * right_subtree)

    cached = spe_cache_duplicate_subtrees(spe, {})
    left_path = cached.children[0].children[1].children[0]
    assert left_path.children[0] is duplicate_subtrees[0]
    assert cached.children[1].children[2] is duplicate_subtrees[0]