コード例 #1
0
ファイル: test_rules.py プロジェクト: silky/d_separation
def test_claim_4():
    """
    Rule 2 applied forwards should turn pr(y|x,do(z)) into pr(y|x,z).

    Uses the 'ignore_intervention_act_forward' rule on the toy graph, where
    each of the symbols x, y and z is bound to the singleton set holding its
    own name. The rule must find exactly one rewrite site, its assumption
    test must pass on that site, and applying it must drop the do() around z.
    """
    toy_graph = make_toy_graph()
    # Each symbol names a single variable of the same name.
    variable_sets = {name: {name} for name in ('z', 'x', 'y')}
    lookup = variable_sets.get
    before = E.prob([E.v('y')], [E.v('x'), E.do(E.v('z'))])

    rule = get_rule('ignore_intervention_act_forward')

    matches = list(rule['site_gen'](before))
    assert len(matches) == 1
    (only_site,) = matches

    unpacked = prepare_rule_arguments(rule['unpack_target'], only_site)
    bound = bind_arguments(lookup, unpacked)
    assert rule['assumption_test'](g=toy_graph, **bound)

    after = rule['apply'](only_site)
    assert after == E.prob([E.v('y')], [E.v('x'), E.v('z')])
コード例 #2
0
def gen_causal_rule_moves(rules, proof_state, graph):
    """
    Yield each successor proof state reachable by one causal-rule application.

    For every rule in ``rules``, each candidate site produced by the rule's
    ``site_gen`` on ``proof_state.root_expr`` is considered. The site's
    arguments are unpacked with ``prepare_rule_arguments``, resolved through
    ``proof_state.bindings`` via ``bind_arguments``, and handed to the rule's
    ``assumption_test`` together with ``graph`` (as keyword ``g``). Sites whose
    test is falsy are skipped; the rest are rewritten with the rule's ``apply``.

    Args:
        rules: iterable of rule dicts; the keys read here are 'site_gen',
            'unpack_target', 'assumption_test', 'apply' and 'name'.
        proof_state: current state; its ``root_expr``, ``bindings`` and
            ``length`` attributes are read.
        graph: forwarded to each ``assumption_test`` as ``g``.

    Yields:
        ``ProofState`` objects with length ``proof_state.length + 1``,
        heuristic 0, a shallow copy of the parent's bindings, the rewritten
        root expression, ``parent=proof_state`` and a comment naming the rule
        and the matched sub-expression (``site[-1]``).
    """
    def lookup_binding(var_name):
        # Looked up at call time; a name absent from the bindings raises KeyError.
        return proof_state.bindings[var_name]

    for rule in rules:
        candidate_sites = rule['site_gen'](proof_state.root_expr)
        for site in candidate_sites:
            unpacked = prepare_rule_arguments(rule['unpack_target'], site)
            arguments = bind_arguments(lookup_binding, unpacked)
            if rule['assumption_test'](g=graph, **arguments):
                next_length = proof_state.length + 1
                # Copy so each successor owns its bindings independently of the parent.
                copied_bindings = dict(proof_state.bindings)
                rewritten_root = rule['apply'](site)

                matched_expr = site[-1]
                description = 'applied rule %s to %s' % (
                    rule['name'],
                    pleasantly_fmt(proof_state.bindings, matched_expr))

                yield ProofState(next_length, 0, copied_bindings, rewritten_root,
                                 parent=proof_state, comment=description)