コード例 #1
0
ファイル: test_event_evaluate.py プロジェクト: probcomp/sppl
def test_event_solve_multi():
    """Solving an event over two distinct symbols must fail.

    Both the disjunction and the conjunction of a clause on X with a
    clause on Y are expected to raise ValueError from ``solve``.
    """
    for combined in (
            (Exp(abs(3 * X**2)) > 1) | (Log(Y) < 0.5),
            (Exp(abs(3 * X**2)) > 1) & (Log(Y) < 0.5)):
        with pytest.raises(ValueError):
            combined.solve()
コード例 #2
0
def test_parse_2_closed():
    """Parse a conjunction of closed constraints on log(x) and on x.

    Each conjunct should become an ``EventInterval`` whose finite
    endpoint is closed.
    """
    # (log(x) >= 2) & (x <= exp(2))
    expr = (Log(X) >= 2) & (X <= sympy.exp(2))
    event = EventAnd([
        EventInterval(Log(Y), Interval(2, oo)),
        EventInterval(Y, Interval(-oo, sympy.exp(2))),
    ])
    assert expr == event
コード例 #3
0
def test_parse_2_open():
    """Parse a conjunction of strict constraints on log(x) and on x.

    Each conjunct should become an ``EventInterval`` over an open
    interval.
    """
    # (log(x) > 2) & (x < exp(2))
    expr = (Log(X) > 2) & (X < sympy.exp(2))
    event = EventAnd([
        EventInterval(Log(Y), Interval.open(2, oo)),
        EventInterval(Y, Interval.open(-oo, sympy.exp(2))),
    ])
    assert expr == event
コード例 #4
0
def test_solver_20():
    """Solve log(x**2 - 3) < 5 over the reals.

    The argument must satisfy 0 < x**2 - 3 < exp(5), i.e.
    sqrt(3) < |x| < sqrt(3 + exp(5)), giving two symmetric open intervals.
    """
    root3 = sympy.sqrt(3)
    bound = sympy.sqrt(3 + sympy.exp(5))
    expected = Union(
        Interval.open(-bound, -root3),
        Interval.open(root3, bound))
    assert (Log(Y**2 - 3) < 5).solve() == expected
コード例 #5
0
def test_parse_21__ci_():
    """Parse the two-sided constraint 1 <= log(x**3 - 3*x + 3) < 5.

    Solving this can only be done by numerical approximation of roots, see
    https://www.wolframalpha.com/input/?i=1+%3C%3D+log%28x**3+-+3x+%2B+3%29+%3C+5
    Here only the parse is checked: the inner cubic becomes a Poly and the
    two comparisons combine into a single half-open EventInterval.
    """
    log_cubic = Log(X**3 - 3*X + 3)
    # Coefficients are listed from the constant term up: 3 - 3x + 0x^2 + x^3.
    assert log_cubic == Log(Poly(Y, [3, -3, 0, 1]))
    lower = 1 <= log_cubic
    upper = log_cubic < 5
    assert (lower & upper) == EventInterval(log_cubic, Interval.Ropen(1, 5))
コード例 #6
0
def test_solver_9_closed():
    """Solve 2*log(x)**3 - log(x) - 5 >= 0.

    With z = log(x), the cubic 2z^3 - z - 5 has one real root, written
    below in closed form; the solution set is [exp(root), oo).
    """
    cube_root = (sympy.sqrt(2019)/36 + Rat(5,4))**(Rat(1, 3))
    real_root = 1/(6*cube_root) + cube_root
    expected = Interval(sympy.exp(real_root), oo)
    event = 2*(Log(Y))**3 - Log(Y) - 5 >= 0
    assert event.solve() == expected
コード例 #7
0
def test_parse_9_open():
    """Parse 2*log(x)**3 - log(x) - 5 > 0 as a polynomial in log(x)."""
    cubic = 2*(Log(X))**3 - Log(X) - 5
    # Coefficients run from the constant term to the leading term.
    assert cubic == Poly(Log(Y), [-5, -1, 0, 2])

    assert (cubic > 0) == EventInterval(cubic, Interval.open(0, oo))

    # Cannot add polynomials with different subexpressions: here the cube
    # is applied to 2*log(x) rather than to log(x) itself.
    with pytest.raises(ValueError):
        (2*Log(X))**3 - Log(X) - 5
コード例 #8
0
ファイル: test_product.py プロジェクト: probcomp/sppl
def test_product_condition_or_probabilithy_zero():
    """Condition a product SPE on events containing probability-zero parts.

    Covers events whose total probability is zero (conditioning raises
    ValueError and logprob is -inf) and disjunctions where some clauses
    have probability zero and drop out of the conditioned result.
    """
    X = Id('X')
    Y = Id('Y')
    spe = ProductSPE([X >> norm(loc=0, scale=1), Y >> gamma(a=1)])

    # Condition on event which has probability zero.
    event = (X > 2) & (X < 2)
    with pytest.raises(ValueError):
        spe.condition(event)
    assert spe.logprob(event) == -float('inf')

    # Condition on event which has probability zero.
    event = (Y < 0) | (Y < -1)
    with pytest.raises(ValueError):
        spe.condition(event)
    assert spe.logprob(event) == -float('inf')

    # Condition on an event where one clause has probability
    # zero, yielding a single product.
    spe_condition = spe.condition((Y < 0) | ((Log(X) >= 0) & (1 <= Y)))
    assert isinstance(spe_condition, ProductSPE)
    assert spe_condition.children[0].symbol == X
    assert spe_condition.children[0].conditioned
    assert spe_condition.children[0].support == Interval(1, oo)
    assert spe_condition.children[1].symbol == Y
    assert spe_condition.children[1].conditioned
    # The Y child (children[1]) is restricted by the clause 1 <= Y.
    assert spe_condition.children[1].support == Interval(1, oo)

    # ~(1 < exp(|3X**2|)) reduces to X = 0, so the clause
    # (X < 2) & ~(1 < exp(|3X**2|)) has probability zero under X's normal.
    # Thus Y remains unconditioned,
    #   and X is partitioned into (-oo, 0) U (0, oo) with equal weight.
    event = (Exp(abs(3 * X**2)) > 1) | ((Log(Y) < 0.5) & (X < 2))
    spe_condition = spe.condition(event)
    #
    # The most concise representation of spe_condition is:
    #   (Product (Sum [.5 .5] X|X<0 X|X>0) Y)
    assert isinstance(spe_condition, ProductSPE)
    assert isinstance(spe_condition.children[0], SumSPE)
    assert spe_condition.children[0].weights == (-log(2), -log(2))
    assert spe_condition.children[0].children[0].conditioned
    assert spe_condition.children[0].children[1].conditioned
    assert spe_condition.children[0].children[0].support \
        in [Interval.Ropen(-oo, 0), Interval.Lopen(0, oo)]
    assert spe_condition.children[0].children[1].support \
        in [Interval.Ropen(-oo, 0), Interval.Lopen(0, oo)]
    assert spe_condition.children[0].children[0].support \
        != spe_condition.children[0].children[1].support
    assert spe_condition.children[1].symbol == Y
    assert not spe_condition.children[1].conditioned
コード例 #9
0
def test_solver_10():
    """Solve exp(sqrt(log(x))) > -5.

    The expected answer is [1, oo), the set where log(x) >= 0 so that
    sqrt(log(x)) is defined. Sympy hangs on this test.
    """
    assert (Exp(Sqrt(Log(Y))) > -5).solve() == Interval(1, oo)
コード例 #10
0
def test_solver_finite_injective():
    """Solve membership in a finite set through injective transforms.

    Each case pairs an event ``f(Y) << values`` with the expected finite
    set of preimages.
    """
    sqrt3 = sympy.sqrt(3)
    cases = [
        # Identity.
        (Y << {2, 4, -10, sqrt3},
            FiniteReal(2, 4, -10, sqrt3)),
        # Exp.
        (Exp(Y) << {10, 3, sqrt3},
            FiniteReal(sympy.log(10), sympy.log(3), sympy.log(sqrt3))),
        # Exp2.
        ((2**Y) << {10, 16, sqrt3},
            FiniteReal(sympy.log(10, 2), 4, sympy.log(sqrt3, 2))),
        # Log.
        (Log(Y) << {10, -3, sqrt3},
            FiniteReal(sympy.exp(10), sympy.exp(-3), sympy.exp(sqrt3))),
        # Log2.
        (Logarithm(Y, 2) << {10, -3, sqrt3},
            FiniteReal(
                sympy.Pow(2, 10), sympy.Pow(2, -3), sympy.Pow(2, sqrt3))),
        # Radical.
        (Y**Rat(1, 4) << {7, 12, sqrt3},
            FiniteReal(7**4, 12**4, sqrt3**4)),
    ]
    for event, solution in cases:
        assert event.solve() == solution
コード例 #11
0
def test_solver_9_open():
    """Solve the strict inequality 2*log(x)**3 - log(x) - 5 > 0.

    Our solver handles this case as follows:
      expr' = 2*Z**3 - Z - 5 > 0  [[subst. Z=log(X)]]
      [Z_low, Z_high] = sympy_solver(expr')
            Z_low < Z  iff Z_low < log(X)  iff exp(Z_low) < X
            Z < Z_high iff log(X) < Z_high iff X < exp(Z_high)
      sympy_solver(expr) = [exp(Z_low), exp(Z_high)]
    For F invertible, Poly(coeffs, F) > 0 can thus be solved this way.
    """
    cube_root = (sympy.sqrt(2019)/36 + Rat(5,4))**(Rat(1, 3))
    boundary = sympy.exp(1/(6*cube_root) + cube_root)
    event = 2*(Log(Y))**3 - Log(Y) - 5 > 0
    assert event.solve() == Interval.open(boundary, oo)
コード例 #12
0
def test_dnf_factor_3():
    """Factor a DNF event using an explicit symbol-to-factor mapping.

    X0, X1, X2 map to factor 0, X3 and X4 to factor 1, and X5 to factor 2.
    Each factored clause is then indexed by factor number.
    """
    A = (Exp(X0) > 0)
    B = X0 < 10
    C = X1 < 10
    D = X4 > 0
    E = (X2**2 - 3 * X2) << (0, 10, 100)
    F = (10 * Log(X5) + 9) > 5
    G = X4 < 4

    factor_of = {X0: 0, X1: 0, X2: 0, X3: 1, X4: 1, X5: 2}
    dnf = ((A & B & C & ~D) | (E & F & G)).to_dnf()
    clauses = dnf_factor(dnf, factor_of)
    assert len(clauses) == 2
    # Clause 0: A, B, C all live in factor 0; ~D (on X4) in factor 1.
    assert clauses[0][0] == A & B & C
    assert clauses[0][1] == ~D
    # Clause 1: E (X2), G (X4) and F (X5) each land in their own factor.
    assert clauses[1][0] == E
    assert clauses[1][1] == G
    assert clauses[1][2] == F
コード例 #13
0
def test_solver_21__ci_():
    """Solve 1 <= log(x**3 - 3*x + 3) < 5 numerically.

    Can only be solved by numerical approximation of roots; the reference
    endpoints come from
    https://www.wolframalpha.com/input/?i=1+%3C%3D+log%28x**3+-+3x+%2B+3%29+%3C+5
    The expected answer is a closed interval with a negative left endpoint
    together with a right-open interval with a positive left endpoint.
    """
    solution = Union(
        Interval(
            -1.777221448430427630375448631016427343692,
            0.09418455242255462832154474245589911789464),
        Interval.Ropen(
            1.683036896007873002053903888560528225797,
            5.448658707897512189124586716091172798465))

    log_cubic = Log(Y**3 - 3*Y + 3)
    answer = ((1 <= log_cubic) & (log_cubic < 5)).solve()
    assert isinstance(answer, Union)
    assert len(answer.args) == 2

    # Identify the two pieces by the sign of their left endpoint.
    head, tail = answer.args
    first = head if head.a < 0 else tail
    second = head if head.a > 0 else tail

    # (piece, reference interval, expected left_open, expected right_open)
    checks = [
        (first, solution.args[0], False, False),
        (second, solution.args[1], False, True),
    ]
    for piece, reference, left_open, right_open in checks:
        assert bool(piece.left_open) == left_open
        assert bool(piece.right_open) == right_open
        assert allclose(float(piece.a), float(reference.a))
        assert allclose(float(piece.b), float(reference.b))
コード例 #14
0
def test_errors():
    """Ill-formed transform and event arithmetic raises an exception.

    Each entry is a thunk so that the offending expression is evaluated
    inside ``pytest.raises``.
    """
    raises_value_error = [
        lambda: 1 + Log(X) - Exp(X),
        lambda: Abs(X) ** sympy.sqrt(10),
        lambda: Log(X) * X,
        lambda: (2*Log(X)) - Rat(1, 10) * Abs(X),
        lambda: X**(1.71),
        lambda: Abs(X)**(1.1, 8),
        lambda: Abs(X)**(7, 8),
        lambda: (-3)**X,
        lambda: (Identity('Z')**2)*(Y > 0),
        lambda: (Y > 0) * (Identity('Z')**2),
        lambda: ((Y > 0) | (Identity('Z') < 3)) * (Identity('Z')**2),
        lambda: Y**2 + (0 <= Y) * Y,
        lambda: (Y <= 0)*(Y**2) + (0 <= Y)*Y**((1, 2)),
        lambda: (Y <= 0)*(Y**2) + (0 <= Identity('Z'))*Y**((1, 2)),
        lambda: (Y <= 0)*(Y**2)
            + (0 <= Identity('Z'))*Identity('Z')**((1, 2)),
    ]
    raises_type_error = [
        lambda: Log(X) ** Exp(X),
        # TypeErrors from 'return NotImplemented'.
        lambda: X + 'a',
        lambda: X * 'a',
        lambda: X / 'a',
        lambda: X**'s',
    ]
    for thunk in raises_value_error:
        with pytest.raises(ValueError):
            thunk()
    for thunk in raises_type_error:
        with pytest.raises(TypeError):
            thunk()
コード例 #15
0
def test_parse_17():
    """A product of three linear factors expands to the expected Poly.

    The factors have roots sqrt(2)/10, -10/7 and sqrt(5); the coefficient
    list (constant term first) is checked for several inner expressions Z.
    Reference expansion:
    https://www.wolframalpha.com/input/?i=Expand%5B%2810%2F7+%2B+X%29+%28-1%2F%285+Sqrt%5B2%5D%29+%2B+X%29+%28-Sqrt%5B5%5D+%2B+X%29%5D
    """
    sqrt2 = sympy.sqrt(2)
    sqrt5 = sympy.sqrt(5)
    for Z in [X, Log(X), Abs(1+X**2)]:
        product = (Z - Rat(1, 10) * sqrt2) \
            * (Z + Rat(10, 7)) \
            * (Z - sqrt5)
        coeffs = [
            sympy.sqrt(10)/7,
            1/sympy.sqrt(10) - (10*sqrt5)/7 - sqrt2/7,
            (-sqrt5 - 1/(5 * sqrt2)) + Rat(10)/7,
            1,
        ]
        assert product == Poly(Z, coeffs)
コード例 #16
0
ファイル: test_product.py プロジェクト: probcomp/sppl
def test_product_disjoint_union_numerical():
    """Disjoint-union rewriting of an event preserves its probability.

    For each event, the log-probabilities of the clauses returned by
    ``dnf_to_disjoint_union`` must log-sum-exp to the log-probability of
    the event itself.
    """
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    spe = ProductSPE([
        X >> norm(loc=0, scale=1),
        Y >> norm(loc=0, scale=2),
        Z >> norm(loc=0, scale=2),
    ])

    events = [
        (1 / X < 4) | (X > 7),
        (2 * X - 3 > 0) | (Log(Y) < 3),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | ~(X << {1, 2}),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X << {1, 3}),
    ]
    for event in events:
        disjoint = dnf_to_disjoint_union(event)
        clause_logps = [spe.logprob(clause) for clause in disjoint.subexprs]
        assert allclose(logsumexp(clause_logps), spe.logprob(event))
コード例 #17
0
def test_solver_finite_non_injective():
    """Solve finite-set membership through non-injective transforms.

    Each case pairs an event with its expected solution set, including
    negated memberships and memberships in the empty set.
    """
    sqrt2 = sympy.sqrt(2)
    cases = [
        # Abs.
        (abs(Y) << {10, 3},
            FiniteReal(-10, -3, 3, 10)),
        # Abs(Poly).
        (abs(2*Y) << {10, 3},
            FiniteReal(-5, -Rat(3,2), Rat(3,2), 5)),
        # Poly order 2.
        ((Y**2) << {2},
            FiniteReal(-sqrt2, sqrt2)),
        # Poly order 3.
        (Y**3 << {1, 27},
            FiniteReal(1, 3)),
        # Poly Abs.
        ((abs(Y))**3 << {1, 27},
            FiniteReal(-3, -1, 1, 3)),
        # Abs Not.
        (~(abs(Y) << {1}),
            Union(
                Interval.open(-oo, -1),
                Interval.open(-1, 1),
                Interval.open(1, oo))),
        # Abs in EmptySet.
        ((abs(Y))**3 << set(),
            EmptySet),
        # Abs Not in EmptySet (yields all reals).
        (~(((abs(Y))**3) << set()),
            Interval(-oo, oo)),
        # Log in Reals (yields positive reals).
        (~((Log(Y))**3 << set()),
            Interval.open(0, oo)),
    ]
    for event, solution in cases:
        assert event.solve() == solution
コード例 #18
0
    ((X >> norm(loc=0, scale=1)) & (Y >> gamma(a=1))).constrain({Y:1}),
]
@pytest.mark.parametrize('spe', spes)
def test_serialize_equal(spe):
    """An SPE survives a dict -> JSON text -> dict round trip unchanged."""
    as_text = json.dumps(spe_to_dict(spe))
    restored = spe_from_dict(json.loads(as_text))
    assert restored == spe

# Transforms and events, presumably consumed by a parametrized
# serialization test later in the file (not visible in this chunk; confirm).
# Covers plain, exponential/logarithmic, radical, reciprocal, polynomial and
# piecewise transforms, interval events, and finite-set events with real
# and nominal (string) members.
transforms = [
    X,
    X**(1,3),
    Exponential(X, base=3),
    Logarithm(X, base=2),
    2**Log(X),
    1/Exp(X),
    abs(X),
    1/X,
    2*X + X**3,
    # Piecewise: X/2 on X < 0, sqrt(X) on 0 <= X.
    (X/2)*(X<0) + (X**(1,2))*(0<=X),
    X < sqrt(3),
    # Membership in (and negated membership in) an empty collection.
    X << [],
    ~(X << []),
    EventFiniteNominal(1/X**(1,10), EmptySet),
    X << {1, 2},
    # Nominal (string-valued) finite sets.
    X << {'a', 'x'},
    ~(X << {'a', '1'}),
    (X < 3) | (X << {1,2}),
    (X < 3) & (X << {1,2}),
]
コード例 #19
0
def test_solver_2_closed():
    """Solve (log(x) >= 2) & (x <= exp(2)).

    log(x) >= 2 means x >= exp(2); combined with x <= exp(2), the only
    solution is the single point exp(2).
    """
    # (log(x) >= 2) & (x <= exp(2))
    solution = FiniteReal(sympy.exp(2))
    event = (Log(Y) >= 2) & (Y <= sympy.exp(2))
    answer = event.solve()
    assert answer == solution
コード例 #20
0
def test_parse_20():
    """Parse log(x**2 - 3) < 5 into an EventInterval over Log(Poly)."""
    parsed = Log(X**2 - 3) < 5
    # Poly coefficients run from the constant term up: -3 + 0*x + x**2.
    expected = EventInterval(
        Log(Poly(Y, [-3, 0, 1])), Interval.open(-oo, 5))
    assert parsed == expected
コード例 #21
0
def test_solver_2_open():
    """Solve (log(x) > 2) & (x < exp(2)).

    log(x) > 2 means x > exp(2), which contradicts x < exp(2), so the
    solution set is empty.
    """
    # (log(x) > 2) & (x < exp(2))
    solution = EmptySet
    event = (Log(Y) > 2) & (Y < sympy.exp(2))
    answer = event.solve()
    assert answer == solution
コード例 #22
0
def test_parse_1_open():
    """Parsing log(x) > 2 gives an EventInterval on Log(Y), open at 2."""
    expected = EventInterval(Log(Y), Interval(2, oo, left_open=True))
    assert (Log(X) > 2) == expected
コード例 #23
0
def test_solver_11_closed():
    """Solve exp(sqrt(log(x))) >= 6, i.e. log(x) >= log(6)**2."""
    threshold = sympy.exp(sympy.log(6)**2)
    event = Exp(Sqrt(Log(Y))) >= 6
    assert event.solve() == Interval(threshold, oo)
コード例 #24
0
def test_parse_10():
    """Parsing exp(sqrt(log(x))) > -5 gives an open-interval event."""
    composite = Exp(Sqrt(Log(Y)))
    expected = EventInterval(composite, Interval.open(-5, oo))
    assert (composite > -5) == expected
コード例 #25
0
def test_dnf_factor():
    """Check ``dnf_factor`` on DNF events over symbols X0..X5.

    Without a mapping, each factored clause is indexed by symbol; with a
    ``{symbol: factor}`` mapping, it is indexed by factor number.
    """
    E00 = Exp(X0) > 0
    E01 = X0 < 10
    E10 = X1 < 10
    E20 = (X2**2 - X2 * 3) < 0
    E30 = X3 > 10
    E31 = (Sqrt(2 * X3)) < 0
    E40 = X4 > 0
    E41 = X4 << [1, 5]
    E50 = 10 * Log(X5) + 9 > 5

    def factor(event, *mapping):
        # Convert to DNF, then factor (optionally using a mapping).
        return dnf_factor(event.to_dnf(), *mapping)

    # A single literal is one clause over X0.
    dnf = factor(E00)
    assert len(dnf) == 1
    assert dnf[0][X0] == E00

    # A conjunction on one symbol stays a single clause.
    dnf = factor(E00 & E01)
    assert len(dnf) == 1
    assert dnf[0][X0] == E00 & E01

    # A disjunction yields one clause per disjunct.
    dnf = factor(E00 | E01)
    assert len(dnf) == 2
    assert dnf[0][X0] == E00
    assert dnf[1][X0] == E01

    # With a mapping, X0 and X1 share factor 0.
    dnf = factor(E00 | (E01 & E10), {X0: 0, X1: 0})
    assert len(dnf) == 2
    assert dnf[0][0] == E00
    assert dnf[1][0] == E01 & E10

    # For the second clause we have:
    #   ~E41 = (-oo, 1) U (1, 5) U (5, oo)
    # so the second clause becomes
    # = (E20 & E50 & E31 & ((-oo, 1) U (1, 5) U (5, oo)))
    # = (E20 & E50 & E31 & (-oo, 1))
    #   or (E20 & E50 & E31 & (1, 5))
    #   or (E20 & E50 & E31 & (5, oo))
    event_factor = factor(
        (E00 & E01 & E10 & E30 & E40) | (E20 & E50 & E31 & ~E41))
    assert len(event_factor) == 4

    # Clause 0: the first disjunct, split by symbol.
    assert len(event_factor[0]) == 4
    assert event_factor[0][X0] == E00 & E01
    assert event_factor[0][X1] == E10
    assert event_factor[0][X3] == E30
    assert event_factor[0][X4] == E40

    # Clauses 1-3 share E31, E20, E50 and differ only in the X4 piece.
    x4_pieces = [(X4 < 1), (1 < (X4 < 5)), (5 < X4)]
    for index, x4_piece in enumerate(x4_pieces, start=1):
        clause = event_factor[index]
        assert len(clause) == 4
        assert clause[X3] == E31
        assert clause[X2] == E20
        assert clause[X4] == x4_piece
        assert clause[X5] == E50
コード例 #26
0
def test_solver_1_open():
    """log(x) > 2 holds exactly on the open ray (exp(2), oo)."""
    event = Log(Y) > 2
    assert event.solve() == Interval.open(sympy.exp(2), oo)
コード例 #27
0
def test_solver_1_closed():
    """log(x) >= 2 holds exactly on the closed ray [exp(2), oo)."""
    expected = Interval(sympy.exp(2), oo)
    assert (Log(Y) >= 2).solve() == expected
コード例 #28
0
def test_parse_1_closed():
    """Parsing log(x) >= 2 gives an EventInterval on Log(Y) over [2, oo)."""
    expected = EventInterval(Log(Y), Interval(2, oo))
    assert (Log(X) >= 2) == expected