示例#1
0
def test_claim_3():
    """
    Rule 3 applied forwards ('ignore_intervention_entirely_forward') must
    rewrite pr(y|do(z),do(x)) into pr(y|do(z)) on the toy graph.
    """
    toy_graph = make_toy_graph()
    name_to_values = dict((name, set([name])) for name in ('z', 'x', 'y'))
    lookup = name_to_values.get
    expr = E.prob([E.v('y')], [E.do(E.v('z')), E.do(E.v('x'))])

    rule = get_rule('ignore_intervention_entirely_forward')

    candidate_sites = list(rule['site_gen'](expr))
    assert len(candidate_sites) == 2
    # The second site is the one whose rewrite drops do(x), per the expected
    # result asserted below.
    chosen = candidate_sites[1]

    args = bind_arguments(lookup,
                          prepare_rule_arguments(rule['unpack_target'], chosen))
    assert rule['assumption_test'](g=toy_graph, **args)

    assert rule['apply'](chosen) == E.prob([E.v('y')], [E.do(E.v('z'))])
示例#2
0
def test_full_problem():
    """Proof search from pr(y|do(x)) on the toy graph, with 'h' banned and
    greed=10, reaches a zero-heuristic state within 7 proof steps."""
    toy_graph = make_toy_graph()
    heuristic = make_heuristic(set([frozenset(['h'])]), greed=10)

    start_bindings = {'x': set(['x']), 'y': set(['y'])}
    start_expr = E.prob([E.v('y')], [E.do(E.v('x'))])

    start = ProofState(
        length=0,  # proof length so far
        heuristic_length=0,  # placeholder, replaced just below
        bindings=start_bindings,
        root_expr=start_expr,
    ).normalise()
    start = start.copy(heuristic_length=heuristic(start))

    def reached_goal(state):
        # The goal is any state the heuristic scores as zero.
        return state.heuristic_length == 0

    outcome = proof_search(start, toy_graph, reached_goal, heuristic,
                           max_proof_length=7)
    assert outcome['reached_goal']
示例#3
0
def test_full_problem():
    """Proof search from pr(y|do(x)) on the toy graph, with 'h' banned and
    greed=10, reaches a zero-heuristic state within 7 proof steps."""
    toy_graph = make_toy_graph()
    heuristic = make_heuristic(set([frozenset(['h'])]), greed=10)

    start_bindings = dict(x=set(['x']), y=set(['y']))
    start_expr = E.prob([E.v('y')], [E.do(E.v('x'))])

    start = ProofState(length=0, heuristic_length=0,
                       bindings=start_bindings, root_expr=start_expr)
    start = start.normalise()
    # The heuristic is evaluated on the normalised state, so the real
    # heuristic_length is only filled in afterwards.
    start = start.copy(heuristic_length=heuristic(start))

    def reached_goal(state):
        return state.heuristic_length == 0

    outcome = proof_search(start, toy_graph, reached_goal, heuristic, max_proof_length=7)
    assert outcome['reached_goal']
示例#4
0
def main():
    """Run the toy-graph proof search from the command line.

    Usage: <script> GREED

    GREED must parse as a positive float and is passed to ``make_heuristic``.
    A high greed makes the search very optimistic; in general it may then not
    find the shortest proof.

    On success, prints 'success!', shows the proof via
    ``display_proof_as_listing`` and writes the proof tree to
    'proof_tree.dot' via ``write_proof_tree``. Exits with status 1 on a bad
    command line or when the search does not reach the goal.
    """
    usage = 'usage: greediness (positive float...)\n'
    if len(sys.argv) != 2:
        sys.stderr.write(usage)
        sys.exit(1)

    try:
        greed = float(sys.argv[1])
    except ValueError:
        sys.stderr.write(usage)
        sys.exit(1)
    # The usage text promises a positive float; this form also rejects NaN.
    if not (greed > 0):
        sys.stderr.write(usage)
        sys.exit(1)

    graph = make_toy_graph()

    banned_values = set([frozenset(['h'])])
    # dial the greed parameter up high.
    # this makes the search very optimistic.
    # in general this may not find the shortest proof
    heuristic = make_heuristic(banned_values, greed)

    initial_bindings = {
        'x': frozenset(['x']),
        'y': frozenset(['y']),
    }

    initial_expr = E.prob([E.v('y')], [E.do(E.v('x'))])

    initial_proof_state = ProofState(
        length=0,  # length of proof
        heuristic_length=0,  # placeholder, filled in below
        bindings=initial_bindings,
        root_expr=initial_expr,
        parent=None,
        comment='initial state',
    ).normalise()

    # The heuristic is evaluated on the normalised state, so the state is
    # built with a placeholder heuristic_length and then copied with the
    # real value.
    initial_proof_state = initial_proof_state.copy(
        heuristic_length=heuristic(initial_proof_state))

    def goal_check(proof_state):
        return proof_state.heuristic_length == 0

    result = proof_search(initial_proof_state,
                          graph,
                          goal_check,
                          heuristic,
                          max_proof_length=7)
    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # which would report success even when the search failed.
    if not result['reached_goal']:
        sys.stderr.write('proof search did not reach the goal\n')
        sys.exit(1)
    print('success!')

    display_proof_as_listing(result['path'])

    out_file_name = 'proof_tree.dot'
    write_proof_tree(result['path'], result['closed'], out_file_name)
示例#5
0
    def gen_expansions(value, proof_state):
        """Yield successor proof states that condition a prob term on ``value``.

        Each prob(...) sub-expression of ``proof_state.root_expr`` is skipped
        if one of its variables is already bound to ``value``. Otherwise it is
        rewritten as

            prob([x], [w]) -> sigma(y, product([prob([x], [y, w]), prob([y], [w])]))

        where ``y`` is a fresh variable bound to ``value``. Each successor has
        length ``proof_state.length + 1``, a heuristic_length of 0,
        ``proof_state`` as its parent, and a 'conditioned ... on ...' comment.
        """
        bindings = proof_state.bindings
        for (term, inject) in E.gen_matches(E.is_prob, proof_state.root_expr):
            bound_values = [bindings[name] for name in get_variable_order(term)]
            if value in bound_values:
                # Conditioning on a value the term already mentions is pointless.
                continue

            fresh = new_variable_name(bindings)
            fresh_v = E.v(fresh)
            given = tuple(term[2])
            rewritten = E.sigma(
                fresh_v,
                E.product([
                    E.prob(tuple(term[1]), (fresh_v, ) + given),
                    E.prob((fresh_v, ), given),
                ]))

            next_length = proof_state.length + 1
            next_bindings = dict(bindings)
            next_bindings[fresh] = value
            next_root = inject(rewritten)

            note = 'conditioned %s on %s' % (
                pleasantly_fmt(bindings, term),
                make_canonical_variable_name(value))

            yield ProofState(next_length, 0, next_bindings, next_root,
                             parent=proof_state, comment=note)
示例#6
0
def main():
    """Run the toy-graph proof search from the command line.

    Usage: <script> GREED

    GREED must parse as a positive float and is passed to ``make_heuristic``.
    A high greed makes the search very optimistic; in general it may then not
    find the shortest proof.

    On success, prints 'success!', shows the proof via
    ``display_proof_as_listing`` and writes the proof tree to
    'proof_tree.dot' via ``write_proof_tree``. Exits with status 1 on a bad
    command line or when the search does not reach the goal.
    """
    usage = 'usage: greediness (positive float...)\n'
    if len(sys.argv) != 2:
        sys.stderr.write(usage)
        sys.exit(1)

    try:
        greed = float(sys.argv[1])
    except ValueError:
        sys.stderr.write(usage)
        sys.exit(1)
    # The usage text promises a positive float; this form also rejects NaN.
    if not (greed > 0):
        sys.stderr.write(usage)
        sys.exit(1)

    graph = make_toy_graph()

    banned_values = set([frozenset(['h'])])
    # dial the greed parameter up high.
    # this makes the search very optimistic.
    # in general this may not find the shortest proof
    heuristic = make_heuristic(banned_values, greed)

    initial_bindings = {
        'x' : frozenset(['x']),
        'y' : frozenset(['y']),
    }

    initial_expr = E.prob([E.v('y')], [E.do(E.v('x'))])

    initial_proof_state = ProofState(
        length = 0, # length of proof
        heuristic_length = 0, # placeholder, filled in below
        bindings = initial_bindings,
        root_expr = initial_expr,
        parent = None,
        comment = 'initial state',
    ).normalise()

    # The heuristic is evaluated on the normalised state, so the state is
    # built with a placeholder heuristic_length and then copied with the
    # real value.
    initial_proof_state = initial_proof_state.copy(heuristic_length=heuristic(initial_proof_state))

    def goal_check(proof_state):
        return proof_state.heuristic_length == 0

    result = proof_search(initial_proof_state, graph, goal_check, heuristic, max_proof_length=7)
    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # which would report success even when the search failed.
    if not result['reached_goal']:
        sys.stderr.write('proof search did not reach the goal\n')
        sys.exit(1)
    print('success!')

    display_proof_as_listing(result['path'])

    out_file_name = 'proof_tree.dot'
    write_proof_tree(result['path'], result['closed'], out_file_name)
示例#7
0
def test_gen_matches():
    """gen_matches(is_v, ...) yields each v() atom in order, each paired with
    an injector that rebuilds the whole expression with that atom replaced."""
    expr = prob([v('z')], [do(v('x'))])

    found = list(gen_matches(is_v, expr))
    assert len(found) == 2
    (z_atom, replace_z), (x_atom, replace_x) = found
    assert z_atom == v('z')
    assert x_atom == v('x')

    # substitution machinery: replace one atom, keep the rest intact
    assert replace_z('banana') == prob(['banana'], [do(v('x'))])
    assert replace_x('rambutan') == prob([v('z')], [do('rambutan')])
示例#8
0
def test_gen_matches():
    """gen_matches(is_v, ...) yields each v() atom in order, each paired with
    an injector that rebuilds the whole expression with that atom replaced."""
    expr = prob([v("z")], [do(v("x"))])

    found = list(gen_matches(is_v, expr))
    assert len(found) == 2
    (z_atom, replace_z), (x_atom, replace_x) = found
    assert z_atom == v("z")
    assert x_atom == v("x")

    # substitution machinery: replace one atom, keep the rest intact
    assert replace_z("banana") == prob(["banana"], [do(v("x"))])
    assert replace_x("rambutan") == prob([v("z")], [do("rambutan")])
示例#9
0
def test_normalise_single_iter():
    """A single normalisation pass sorts sub-expressions and renames variables.

    Expression ordering comes first (nb do(v()) comes before v() in sorted
    lists), giving
        sigma(x, product([prob([x y z],[do(a)]), prob([y z], [(do a) b x])]))
    so the variable order is x y z a b, and the new names are 0 1 2 3 4:
        sigma(0, product([prob([0 1 2],[do(3)]), prob([1 2], [(do 3) 4 0])]))
    """
    first = prob([v('z'), v('y')], [v('b'), v('x'), do(v('a'))])
    second = prob([v('z'), v('y'), v('x')], [do(v('a'))])
    root_expr = sigma(v('x'), product([first, second]))

    bindings = dict(x='xxx', z='zzz', y='yyy', a='aaa')

    state = ProofState(length=0, heuristic_length=0,
                       bindings=bindings, root_expr=root_expr)
    once = state.normalise(max_iters=1)

    expected = sigma(
        v(0),
        product((
            prob((v(0), v(1), v(2)), (do(v(3)), )),
            prob((v(1), v(2)), (do(v(3)), v(4), v(0))),
        )))
    assert once.root_expr == expected
示例#10
0
def test_normalise_fixed_point():
    """Normalising to a fixed point yields the fully canonical form, in which
    the second prob's conditioning tuple is (do(3), 0, 4)."""
    first = prob([v('z'), v('y')], [v('b'), v('x'), do(v('a'))])
    second = prob([v('z'), v('y'), v('x')], [do(v('a'))])
    root_expr = sigma(v('x'), product([first, second]))

    bindings = dict(x='xxx', z='zzz', y='yyy', a='aaa')

    state = ProofState(length=0, heuristic_length=0,
                       bindings=bindings, root_expr=root_expr)
    canonical = state.normalise()

    expected = sigma(
        v(0),
        product((
            prob((v(0), v(1), v(2)), (do(v(3)), )),
            prob((v(1), v(2)), (do(v(3)), v(0), v(4))),
        )))
    assert canonical.root_expr == expected
示例#11
0
def test_fmt():
    """fmt renders a prob with an intervention in pr(...|do(...)) notation."""
    expected = 'pr(z|do(x))'
    expr = prob([v('z')], [do(v('x'))])
    assert fmt(expr) == expected
示例#12
0
def test_gen_matches_deep():
    """gen_matches descends through sigma/product/prob/do nodes.

    The expression is sigma_y { p(x|y,do(z)) * p(y|do(z)) }, which contains
    six v() atoms. Match index 3 is the z inside the first factor's do();
    injecting there must leave the second factor untouched.
    """
    root_expr = sigma(v('y'), product([
        prob([v('x')], [v('y'), do(v('z'))]),
        prob([v('y')], [do(v('z'))]),
    ]))

    found = list(gen_matches(is_v, root_expr))
    assert len(found) == 6

    atom, inject = found[3]
    assert atom == v('z')

    expected = sigma(v('y'), product([
        prob([v('x')], [v('y'), do('walrus')]),
        prob([v('y')], [do(v('z'))]),
    ]))
    assert inject('walrus') == expected
示例#13
0
def test_gen_v_sites():
    """gen_v_sites on a prob whose only plain conditioning variable is w
    produces exactly one site, targeting w; the do() terms are reported
    separately."""
    root_expr = prob([v('z')], [v('w'), do(v('x')), do(v('y'))])

    produced = list(gen_v_sites(root_expr))
    assert len(produced) == 1
    site_target, site_inject, lhs, other_vs, do_terms, _unused = produced[0]

    (target_atom, ) = site_target
    assert target_atom == v('w')

    replaced = site_inject('banana')
    assert replaced == prob(
        [v('z')], ['banana', do(v('x')), do(v('y'))])
    assert lhs == (v('z'), )
    assert other_vs == []
    assert do_terms == [do(v('x')), do(v('y'))]
示例#14
0
def test_fmt():
    """fmt renders a prob with an intervention in pr(...|do(...)) notation."""
    expected = "pr(z|do(x))"
    expr = prob([v("z")], [do(v("x"))])
    assert fmt(expr) == expected
示例#15
0
def test_gen_matches_deep():
    """gen_matches descends through sigma/product/prob/do nodes.

    The expression is sigma_y { p(x|y,do(z)) * p(y|do(z)) }, which contains
    six v() atoms. Match index 3 is the z inside the first factor's do();
    injecting there must leave the second factor untouched.
    """
    root_expr = sigma(
        v("y"),
        product([
            prob([v("x")], [v("y"), do(v("z"))]),
            prob([v("y")], [do(v("z"))]),
        ]),
    )

    found = list(gen_matches(is_v, root_expr))
    assert len(found) == 6

    atom, inject = found[3]
    assert atom == v("z")

    expected = sigma(
        v("y"),
        product([
            prob([v("x")], [v("y"), do("walrus")]),
            prob([v("y")], [do(v("z"))]),
        ]),
    )
    assert inject("walrus") == expected
示例#16
0
def test_gen_v_sites():
    """gen_v_sites on a prob whose only plain conditioning variable is w
    produces exactly one site, targeting w; the do() terms are reported
    separately."""
    root_expr = prob([v('z')], [v('w'), do(v('x')), do(v('y'))])

    produced = list(gen_v_sites(root_expr))
    assert len(produced) == 1
    site_target, site_inject, lhs, other_vs, do_terms, _unused = produced[0]

    (target_atom, ) = site_target
    assert target_atom == v('w')

    expected_injected = prob([v('z')], ['banana', do(v('x')), do(v('y'))])
    assert site_inject('banana') == expected_injected
    assert lhs == (v('z'), )
    assert other_vs == []
    assert do_terms == [do(v('x')), do(v('y'))]