def gen_expansions(value, proof_state):
    """Yield successor proof states that condition a ``prob`` term on ``value``.

    Every ``prob`` sub-expression of ``proof_state.root_expr`` is visited. A
    term is skipped when ``value`` is already among the bound values of the
    variables that ``get_variable_order`` reports for it. Each remaining term
    is rewritten with a freshly named variable ``y``, which is bound to
    ``value``::

        prob([x], [w]) -> sigma(y, product([prob([x], [y, w]), prob([y], [w])]))

    Args:
        value: the value bound to the fresh summation variable.
        proof_state: the state being expanded. Its ``root_expr``,
            ``bindings`` and ``length`` are read. It is not modified, because
            successors receive a copy of ``bindings``.

    Yields:
        ProofState: one successor per eligible ``prob`` term. Each successor
        has length ``proof_state.length + 1``, heuristic ``0``,
        ``parent=proof_state`` and a human-readable ``comment``.
    """
    for term, inject in E.gen_matches(E.is_prob, proof_state.root_expr):
        # Values currently bound to the variables appearing in this term.
        term_values = [proof_state.bindings[name]
                       for name in get_variable_order(term)]
        if value in term_values:
            # Conditioning on a value the term already mentions is pointless.
            continue

        fresh_name = new_variable_name(proof_state.bindings)
        fresh_var = E.v(fresh_name)

        # term[1] is the left (event) list, term[2] the right (condition) list.
        event = tuple(term[1])
        condition = tuple(term[2])
        # prob(x | y, w): the original term with y added to its condition.
        widened = E.prob(event, (fresh_var, ) + condition)
        # prob(y | w): the fresh variable under the original condition.
        weight = E.prob((fresh_var, ), condition)
        rewritten = E.sigma(fresh_var, E.product([widened, weight]))

        child_bindings = {**proof_state.bindings, fresh_name: value}
        child_root = inject(rewritten)
        note = 'conditioned %s on %s' % (
            pleasantly_fmt(proof_state.bindings, term),
            make_canonical_variable_name(value))

        yield ProofState(proof_state.length + 1, 0, child_bindings, child_root,
                         parent=proof_state, comment=note)
def test_gen_matches():
    """gen_matches finds both variables of a flat prob term, in order, and
    each match's injector substitutes only at its own position."""
    # NOTE(review): another `test_gen_matches` is defined later in this source;
    # if both live in one module, the later definition replaces this one and
    # pytest never runs it. Confirm and rename or delete one.
    expr = prob([v('z')], [do(v('x'))])
    found = list(gen_matches(is_v, expr))
    assert len(found) == 2

    (z_hit, z_inject), (x_hit, x_inject) = found
    assert z_hit == v('z')
    assert x_hit == v('x')

    # Substitution: replacing one match leaves the other occurrence untouched.
    assert z_inject('banana') == prob(['banana'], [do(v('x'))])
    assert x_inject('rambutan') == prob([v('z')], [do('rambutan')])
def test_gen_matches():
    """gen_matches yields (variable, injector) pairs for prob(z | do(x)),
    and each injector rebuilds the expression with one variable replaced."""
    expr = prob([v("z")], [do(v("x"))])
    found = list(gen_matches(is_v, expr))
    assert len(found) == 2

    (first_hit, first_inject), (second_hit, second_inject) = found
    assert first_hit == v("z")
    assert second_hit == v("x")

    # Substitution machinery: only the matched position is rewritten.
    assert first_inject("banana") == prob(["banana"], [do(v("x"))])
    assert second_inject("rambutan") == prob([v("z")], [do("rambutan")])
def test_gen_matches_deep():
    """gen_matches walks into nested sigma/product/prob/do structure, and
    injecting at one match rewrites only that occurrence."""
    # sigma_y { p(x|y,do(z)) * p(y|do(z)) }
    expr = sigma(
        v("y"),
        product([
            prob([v("x")], [v("y"), do(v("z"))]),
            prob([v("y")], [do(v("z"))]),
        ]),
    )
    found = list(gen_matches(is_v, expr))
    assert len(found) == 6

    # Match index 3 is the z inside the first factor's do(...), as the
    # expected result below shows.
    hit, substitute = found[3]
    assert hit == v("z")

    replaced = substitute("walrus")
    assert replaced == sigma(
        v("y"),
        product([
            prob([v("x")], [v("y"), do("walrus")]),
            prob([v("y")], [do(v("z"))]),
        ]),
    )
def test_gen_matches_deep():
    """Nested matching: substituting at match 3 replaces only the z inside
    the first factor's do(...), leaving the second factor intact."""
    # NOTE(review): another `test_gen_matches_deep` appears earlier in this
    # source; if both live in one module, this definition replaces it.
    # Confirm and rename or delete one.
    # sigma_y { p(x|y,do(z)) * p(y|do(z)) }
    expr = sigma(
        v('y'),
        product([
            prob([v('x')], [v('y'), do(v('z'))]),
            prob([v('y')], [do(v('z'))]),
        ]))
    found = list(gen_matches(is_v, expr))
    assert len(found) == 6

    hit, substitute = found[3]
    assert hit == v('z')

    replaced = substitute('walrus')
    assert replaced == sigma(
        v('y'),
        product([
            prob([v('x')], [v('y'), do('walrus')]),
            prob([v('y')], [do(v('z'))]),
        ]))
def gen_moves(root_expr):
    """Yield every candidate move in ``root_expr``.

    Each sub-expression matched by ``expr_predicate`` is visited, and then
    each target site that ``gen_target_sites`` reports inside it.

    Yields:
        tuple: ``(target, inject, left, vs, dos, expr)``.
        ``inject = compose(expr_inject, site_inject)``, where ``expr_inject``
        comes from the ``gen_matches`` match and ``site_inject`` from the
        target site. ``left``, ``vs`` and ``dos`` are passed through
        unchanged from ``gen_target_sites``. ``expr`` is the matched
        sub-expression.
    """
    # NOTE(review): presumably `inject` maps a replacement for `target` back
    # to a full root expression; confirm against `compose`.
    for expr, expr_inject in gen_matches(expr_predicate, root_expr):
        for target, site_inject, left, vs, dos in gen_target_sites(expr):
            yield (target, compose(expr_inject, site_inject),
                   left, vs, dos, expr)
def count_sigmas(expr):
    """Return how many ``sigma`` sub-expressions ``E.gen_matches`` finds in *expr*."""
    # Count lazily instead of materialising the matches into a list.
    return sum(1 for _ in E.gen_matches(E.is_sigma, expr))