-
Notifications
You must be signed in to change notification settings - Fork 1
/
expr.py
255 lines (204 loc) · 7.12 KB
/
expr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
the original plan
v := ('v', variable_name)
do := ('do', v)
atom := do | v
prob := ('prob', [atom, ...], [atom, ...])
product := ('product', [expr, ...])
sigma := ('sigma', v, expr)
expr := prob | product | sigma
bindings := ((variable_name, subset_of_vertices), ...)
state := (bindings, expr)
claim 1:
there is a reduction from p(z|do(x)) to p(z|x) using rule 2
the same claim, written out as an explicit start state, moves and goal:
start
bindings = (('z', set([Z_vertex])), ('x', set([X_vertex])))
expr = prob([v('z')], [do(v('x'))])
moves
rule 2
goal
bindings = (('z', set([Z_vertex])), ('x', set([X_vertex])))
expr = prob([v('z')], [v('x')])
"""
from util import (compose, identity)
# define basic expression tree structure
def v(name):
    """Build a variable node, the tuple ``('v', name)``."""
    node = ('v', name)
    return node
def do(v_):
    """Build an intervention node ``('do', v_)``.

    Per the grammar in the module docstring, ``v_`` is expected to be a
    v-node; this constructor does not check that.
    """
    node = ('do', v_)
    return node
def prob(left, right):
    """Build a probability node ``('prob', left, right)``.

    ``left`` holds the atoms before the bar and ``right`` the atoms after it
    (fmt renders the node as ``pr(left|right)``). Both iterables are frozen
    into tuples, left first.
    """
    lhs = tuple(left)
    rhs = tuple(right)
    return ('prob', lhs, rhs)
def product(exprs):
    """Build a product node ``('product', factors)`` with the factors frozen into a tuple."""
    factors = tuple(exprs)
    return ('product', factors)
def sigma(v_, expr):
    """Build a summation node ``('sigma', v_, expr)``.

    ``v_`` is the index variable and ``expr`` is the body being summed over.
    """
    node = ('sigma', v_, expr)
    return node
# predicates for case matching
def tag_matches(tag):
    """Return a predicate that is true for nodes whose tag (element 0) equals ``tag``."""
    def has_tag(node):
        head = node[0]
        return head == tag
    return has_tag


# one predicate per node kind of the expression grammar
is_v = tag_matches('v')
is_do = tag_matches('do')
is_prob = tag_matches('prob')
is_product = tag_matches('product')
is_sigma = tag_matches('sigma')
# a few routines for unpacking vs and dos
def unpack_expr(expr, i=None):
    """Return the element of node ``expr`` at tuple position ``i``.

    ``i`` defaults to 1, the first slot after the tag.
    """
    position = 1 if i is None else i
    return expr[position]
def unpack_v(expr):
    """Return the variable name stored in a ``('v', name)`` node.

    Asserts that ``expr`` is a v-node (the check is skipped under ``python -O``).
    """
    assert is_v(expr)
    return unpack_expr(expr, 1)
def unpack_do(expr):
    """Return the node wrapped by a ``('do', ...)`` intervention node.

    Asserts that ``expr`` is a do-node (the check is skipped under ``python -O``).
    """
    assert is_do(expr)
    return unpack_expr(expr, 1)
# define how to ugly-print expressions
def fmt(expr):
    """Render an expression tree as a compact, human-readable string.

    v-nodes print as their bare name, do-nodes as ``do(x)``, prob-nodes as
    ``pr(a,b|c,d)``, products as ``product(...)`` and sums as
    ``sigma(index, body)``.

    Raises:
        ValueError: if ``expr`` (or any sub-expression) has an unrecognised
            tag, matching normalise_expr, instead of silently returning None.
    """
    if is_v(expr):
        return str(expr[1])
    elif is_do(expr):
        return 'do(%s)' % fmt(expr[1])
    elif is_prob(expr):
        return 'pr(%s|%s)' % (fmt_list(expr[1]), fmt_list(expr[2]))
    elif is_product(expr):
        return 'product(%s)' % fmt_list(expr[1])
    elif is_sigma(expr):
        return 'sigma(%s, %s)' % (fmt(expr[1]), fmt(expr[2]))
    raise ValueError(expr)


def fmt_list(exprs):
    """Render a sequence of expressions with fmt, joined by commas (no spaces)."""
    return ','.join(map(fmt, exprs))
# machinery for find/replacing sub-expressions of expression trees
def gen_matches(predicate, expr, inject=None):
    """Generate ``(match, inject)`` for every sub-expression satisfying ``predicate``.

    The search is pre-order: ``expr`` itself is tested first, then its
    children are searched recursively, in this order:
    - do: the wrapped node;
    - prob: the left list, then the right list;
    - product: each factor;
    - sigma: the index variable, then the body.
    The descent continues below a match, so matches nested inside other
    matches are also yielded. v-nodes have no children.

    Each yielded ``inject`` is meant to rebuild the whole tree with the match
    replaced by its argument. It is the caller-supplied ``inject`` composed
    with a function that rebuilds the current node around a new child.
    NOTE(review): this relies on ``util.compose(f, g)`` meaning "call g,
    then call f on its result" and forwarding keyword arguments to ``g``.
    Confirm against util.

    The closures built by make_list_inject also accept ``drop=True`` to
    remove the element instead of replacing it. Presumably that only works
    when the match sits directly in a prob/product child list: ``do`` (the
    innermost rebuild for a match under a do-node) takes one argument, and
    ``identity`` (used for sigma children) lives in util. TODO confirm
    the intended usage.

    Args:
        predicate: callable taking an expression node; truthy on a match.
        expr: expression tree to search.
        inject: rebuilds the tree enclosing ``expr`` from a replacement for
            ``expr``. Defaults to the identity, i.e. ``expr`` is the root.
    """
    if inject is None:
        inject = lambda x : x
    # do we match?
    if predicate(expr):
        yield expr, inject
    # recursively generate all matches in child expressions
    # n.b. there is a lot of redundancy here that could be cleaned up
    # by defining these operations for tuples and lists (hey, both of
    # those cases are essentially the same...)
    if is_v(expr):
        # v-nodes are leaves: nothing further to search
        return
    elif is_do(expr):
        next_expr = expr[1]
        # rebuilding the parent means re-wrapping the new child in do(...)
        next_inject = compose(inject, do)
        for result in gen_matches(predicate, next_expr, next_inject):
            yield result
    elif is_prob(expr):
        left, right = expr[1], expr[2]
        for i, next_expr in enumerate(left):
            # replace slot i of the left list; the right list is kept as-is
            iota = make_list_inject(i, left)
            prob_inject = make_left_inject('prob', iota, right)
            next_inject = compose(inject, prob_inject)
            for result in gen_matches(predicate, next_expr, next_inject):
                yield result
        for i, next_expr in enumerate(right):
            # replace slot i of the right list; the left list is kept as-is
            iota = make_list_inject(i, right)
            prob_inject = make_right_inject('prob', left, iota)
            next_inject = compose(inject, prob_inject)
            for result in gen_matches(predicate, next_expr, next_inject):
                yield result
    elif is_product(expr):
        children = expr[1]
        for i, next_expr in enumerate(children):
            # replace factor i; the other factors are kept as-is
            iota = make_list_inject(i, children)
            product_inject = make_unary_inject('product', iota)
            next_inject = compose(inject, product_inject)
            for result in gen_matches(predicate, next_expr, next_inject):
                yield result
    elif is_sigma(expr):
        left, right = expr[1], expr[2]
        # left case (replace index var)
        next_expr = left
        sigma_inject = make_left_inject('sigma', identity, right)
        next_inject = compose(inject, sigma_inject)
        for result in gen_matches(predicate, next_expr, next_inject):
            yield result
        # right case (replace body expr)
        next_expr = right
        sigma_inject = make_right_inject('sigma', left, identity)
        next_inject = compose(inject, sigma_inject)
        for result in gen_matches(predicate, next_expr, next_inject):
            yield result
def make_list_inject(i, a):
    """Return a function that rebuilds sequence ``a`` as a tuple with slot ``i`` changed.

    The returned function takes ``x`` (default None) and ``drop`` (default
    False). It puts ``x`` in position ``i``, or removes that position
    entirely when ``drop`` is true. ``a`` is sliced at call time, not when
    the closure is created.
    """
    def replace_slot(x=None, drop=False):
        middle = () if drop else (x, )
        before = tuple(a[:i])
        after = tuple(a[i+1:])
        return before + middle + after
    return replace_slot
def make_left_inject(tag, iota, right):
    """Return a builder for the node ``(tag, iota(...), right)``.

    All arguments given to the builder are forwarded to ``iota``, whose
    result fills the left-hand slot; ``tag`` and ``right`` are fixed.
    """
    def rebuild(*args, **kwargs):
        new_left = iota(*args, **kwargs)
        return (tag, new_left, right)
    return rebuild
def make_right_inject(tag, left, iota):
    """Return a builder for the node ``(tag, left, iota(...))``.

    All arguments given to the builder are forwarded to ``iota``, whose
    result fills the right-hand slot; ``tag`` and ``left`` are fixed.
    """
    def rebuild(*args, **kwargs):
        new_right = iota(*args, **kwargs)
        return (tag, left, new_right)
    return rebuild
def make_unary_inject(tag, iota):
    """Return a builder for the two-element node ``(tag, iota(...))``.

    All arguments given to the builder are forwarded to ``iota``.
    """
    def rebuild(*args, **kwargs):
        payload = iota(*args, **kwargs)
        return (tag, payload)
    return rebuild
# machinery for normalising things
def normalise_expr(expr):
    """Return a canonical form of ``expr`` in which child lists are sorted.

    v- and do-nodes are returned unchanged. prob, product and sigma nodes
    are delegated to their dedicated normalisers.

    Raises:
        ValueError: if ``expr`` has an unrecognised tag.
    """
    if is_v(expr) or is_do(expr):
        return expr
    if is_prob(expr):
        return normalise_prob(expr)
    if is_product(expr):
        return normalise_product(expr)
    if is_sigma(expr):
        return normalise_sigma(expr)
    raise ValueError(expr)
def normalise_expr_list(exprs):
    """Normalise every expression in ``exprs`` and return them as a sorted tuple.

    This sort is the only step of normalisation that actually reorders
    anything; the other normalise_* functions merely recurse.
    """
    normalised = [normalise_expr(e) for e in exprs]
    normalised.sort()
    return tuple(normalised)
def normalise_prob(expr):
    """Return a prob node with both atom lists normalised and sorted (left first)."""
    tag = expr[0]
    lhs = normalise_expr_list(expr[1])
    rhs = normalise_expr_list(expr[2])
    return (tag, lhs, rhs)
def normalise_product(expr):
    """Return a product node whose factors are normalised and sorted."""
    tag = expr[0]
    factors = normalise_expr_list(expr[1])
    return (tag, factors)
def normalise_sigma(expr):
    """Return a sigma node with its index variable and body each normalised.

    The two slots keep their positions; nothing is sorted at this level.
    """
    tag = expr[0]
    index_var = normalise_expr(expr[1])
    body = normalise_expr(expr[2])
    return (tag, index_var, body)
# A single-pass combined filter-map over expression trees.
# Despite the name nothing is filtered out: nodes for which the predicate is
# false are rebuilt unchanged, with the map applied to their children.
def filter_map(predicate, f, root_expr):
    """Return a copy of ``root_expr`` with each matching sub-expression replaced by ``f(sub)``.

    The tree is walked top-down. When a node satisfies ``predicate``, ``f``
    is applied to that (original) node and its children are not visited.
    Non-matching nodes are rebuilt from their recursively mapped children;
    v-nodes are leaves and are rebuilt as-is.

    Raises:
        ValueError: if a non-matching node has an unrecognised tag.
    """
    def _filter_map(expr):
        if predicate(expr):
            return f(expr)
        elif is_v(expr):
            # Rebuild the leaf from its name. Passing the whole node to v()
            # would nest it as ('v', ('v', name)).
            return v(expr[1])
        elif is_do(expr):
            return do(_filter_map(expr[1]))
        elif is_prob(expr):
            return prob(_filter_map_exprs(expr[1]), _filter_map_exprs(expr[2]))
        elif is_product(expr):
            return product(_filter_map_exprs(expr[1]))
        elif is_sigma(expr):
            return sigma(_filter_map(expr[1]), _filter_map(expr[2]))
        raise ValueError(expr)

    def _filter_map_exprs(exprs):
        """Apply _filter_map to each child in a sequence, returning a tuple."""
        return tuple(_filter_map(expr) for expr in exprs)

    return _filter_map(root_expr)
def filter_walk(predicate, f, root_expr):
    """Call ``f`` on each sub-expression of ``root_expr`` that satisfies ``predicate``.

    Used for side effects only; returns None. Traversal is top-down and
    left-to-right, and the children of a matching node are not visited.
    Non-matching v-nodes, and nodes with unrecognised tags, are ignored.
    """
    def _visit(node):
        if predicate(node):
            f(node)
        elif is_do(node):
            _visit(node[1])
        elif is_prob(node):
            for child in node[1]:
                _visit(child)
            for child in node[2]:
                _visit(child)
        elif is_product(node):
            for child in node[1]:
                _visit(child)
        elif is_sigma(node):
            _visit(node[1])
            _visit(node[2])

    _visit(root_expr)