Example #1
0
    def gen_expansions(value, proof_state):
        """Yield successor proof states that condition a prob() term on ``value``.

        Each prob(...) sub-expression of ``proof_state.root_expr`` is skipped
        when ``value`` already appears among the bindings of its variables.
        Otherwise it is rewritten as

            prob([x], [w]) -> sigma(y, product([prob([x], [y, w]), prob([y], [w])]))

        where ``y`` is a freshly named variable bound to ``value`` in the
        successor's bindings. Each successor has length ``proof_state.length + 1``,
        heuristic 0, and ``proof_state`` as its parent.
        """
        parent_bindings = proof_state.bindings
        for prob_expr, inject in E.gen_matches(E.is_prob, proof_state.root_expr):
            # Conditioning on a value the term already mentions is pointless.
            bound_values = [parent_bindings[name] for name in get_variable_order(prob_expr)]
            if value in bound_values:
                continue

            fresh_name = new_variable_name(parent_bindings)
            fresh_var = E.v(fresh_name)
            outcome = tuple(prob_expr[1])
            given = tuple(prob_expr[2])

            # p(x | y, w): the original term with the fresh variable prepended to its conditions.
            conditioned = E.prob(outcome, (fresh_var, ) + given)
            # p(y | w): the fresh variable under the original conditions.
            marginal = E.prob((fresh_var, ), given)
            replacement = E.sigma(fresh_var, E.product([conditioned, marginal]))

            child_bindings = dict(parent_bindings)
            child_bindings[fresh_name] = value
            # Splice the rewritten term back into the full expression tree.
            child_root = inject(replacement)

            note = 'conditioned %s on %s' % (
                pleasantly_fmt(parent_bindings, prob_expr),
                make_canonical_variable_name(value))

            yield ProofState(
                proof_state.length + 1,
                0,
                child_bindings,
                child_root,
                parent=proof_state,
                comment=note)
Example #2
0
def test_gen_matches():
    """gen_matches finds both variables in p(z | do(x)), each with a working inject."""
    root_expr = prob([v('z')], [do(v('x'))])

    found = list(gen_matches(is_v, root_expr))
    assert len(found) == 2
    (z_expr, z_inject), (x_expr, x_inject) = found
    assert z_expr == v('z')
    assert x_expr == v('x')

    # Each inject callable rebuilds the root with its matched leaf replaced.
    assert z_inject('banana') == prob(['banana'], [do(v('x'))])
    assert x_inject('rambutan') == prob([v('z')], [do('rambutan')])
Example #3
0
def test_gen_matches():
    """Matching is_v over p(z | do(x)) yields z then x, and each match can be substituted."""
    root_expr = prob([v("z")], [do(v("x"))])

    results = list(gen_matches(is_v, root_expr))
    assert len(results) == 2
    z_match, x_match = results
    assert z_match[0] == v("z")
    assert x_match[0] == v("x")

    # Substitution: the second element of a match rebuilds the root around a new leaf.
    banana_root = z_match[1]("banana")
    assert banana_root == prob(["banana"], [do(v("x"))])
    rambutan_root = x_match[1]("rambutan")
    assert rambutan_root == prob([v("z")], [do("rambutan")])
Example #4
0
def test_gen_matches_deep():
    """Injecting at a deep match rewrites only that leaf of a nested sigma/product tree."""

    def build(first_do_arg):
        # sigma_y { p(x|y,do(<first_do_arg>)) * p(y|do(z)) }
        return sigma(
            v("y"),
            product([
                prob([v("x")], [v("y"), do(first_do_arg)]),
                prob([v("y")], [do(v("z"))]),
            ]),
        )

    root_expr = build(v("z"))

    matches = list(gen_matches(is_v, root_expr))
    assert len(matches) == 6

    # Per the assertions below, match 3 is the z inside the first term's do(z).
    expr, inject = matches[3]
    assert expr == v("z")

    rewritten = inject("walrus")
    assert rewritten == build("walrus")
Example #5
0
def test_gen_matches_deep():
    """A deep inject rewrites only the targeted do(z) inside sigma_y { ... }."""
    # sigma_y { p(x|y,do(z)) * p(y|do(z)) }
    x_given_y_do_z = prob([v('x')], [v('y'), do(v('z'))])
    y_given_do_z = prob([v('y')], [do(v('z'))])
    root_expr = sigma(v('y'), product([x_given_y_do_z, y_given_do_z]))

    matches = list(gen_matches(is_v, root_expr))
    assert len(matches) == 6

    # Per the expected value below, match 3 is the z inside the first term's do(z).
    expr, inject = matches[3]
    assert expr == v('z')

    root_expr_prime = inject('walrus')
    expected_first = prob([v('x')], [v('y'), do('walrus')])
    expected_second = prob([v('y')], [do(v('z'))])
    assert root_expr_prime == sigma(v('y'), product([expected_first, expected_second]))
Example #6
0
 def gen_moves(root_expr):
     """Yield every target site of every sub-expression matching ``expr_predicate``.

     Each item is ``(target, inject, left, vs, dos, expr)``, where ``expr`` is the
     matched sub-expression and ``inject`` is ``compose(expr_inject, site_inject)``.
     Presumably ``inject`` substitutes a replacement for ``target`` into
     ``root_expr``; confirm this against ``compose``.
     """
     for expr, outer_inject in gen_matches(expr_predicate, root_expr):
         for target, inner_inject, left, vs, dos in gen_target_sites(expr):
             yield (target, compose(outer_inject, inner_inject), left, vs, dos, expr)
Example #7
0
 def gen_moves(root_expr):
     """Enumerate candidate moves over all sub-expressions accepted by ``expr_predicate``.

     Yields 6-tuples ``(target, inject, left, vs, dos, expr)``. ``inject`` is the
     composition of the match-level injector with the site-level one (via
     ``compose``). Presumably it maps a replacement for ``target`` back into the
     whole ``root_expr``; verify against ``compose``.
     """
     for matched, match_inject in gen_matches(expr_predicate, root_expr):
         for site in gen_target_sites(matched):
             site_target, local_inject, site_left, site_vs, site_dos = site
             full_inject = compose(match_inject, local_inject)
             yield (site_target, full_inject, site_left, site_vs, site_dos, matched)
Example #8
0
def count_sigmas(expr):
    """Return how many sub-expressions of ``expr`` satisfy ``E.is_sigma``."""
    # Count matches one at a time rather than materialising them into a list.
    return sum(1 for _ in E.gen_matches(E.is_sigma, expr))