forked from aflaxman/bednet_stock_and_flow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
explore.py
executable file
·309 lines (252 loc) · 10.1 KB
/
explore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
""" Script for exploratory analysis of the bednet model estimates"""
import re
import pylab as pl
import pymc
def load_pickles(path='./'):
""" Load all of the files with name bednet_model.*pickle in the
specified directory
Example
-------
>>> db = explore.load_pickles('/home/j/Project/Models/bednets/2010_07_09/')
"""
import os, sys
file_list = os.listdir(path)
db = {}
for f in file_list:
match = re.match('^bednet_model.*pickle$', f)
if match:
print 'loading', f, '...',
sys.stdout.flush()
k=match.group()
db[k] = pymc.database.pickle.load(path + f)
print 'finished.'
return db
def plot_net_survival(db, country_list):
    """ Plot posterior uncertainty bands of the LLIN survival curve
    for each country in country_list, and save the figure as
    net_survival.png in settings.PATH.

    Parameters
    ----------
    db : dict
      maps model key -> pymc pickle database (as from load_pickles)
    country_list : list of str
      countries to include; keys whose country is not listed are skipped
    """
    import pylab as pl
    import settings

    pl.clf()
    ii = 0.
    for k, p in sorted(db.items()):
        country = k.split('_')[2]  # TODO: refactor k.split into function
        if country not in country_list:
            continue

        pr = pl.sort(p.__getattribute__('Pr[net is lost]').gettrace())
        # 95% posterior interval endpoints; array indices must be ints
        # (float indexing was deprecated and then removed in NumPy)
        pr0 = pr[int(.025*len(pr))]
        pr1 = pr[int(.975*len(pr))]

        t = pl.arange(0, 5, .1)
        # survival is modeled as (1-pr)**t, truncated to zero after 3 years
        pct0 = 100. * pl.where(t < 3, (1-pr0)**t, 0.)
        pct1 = 100. * pl.where(t < 3, (1-pr1)**t, 0.)

        # outlined band for the legend, filled band for the area
        pl.fill(pl.concatenate((t, t[::-1])),
                pl.concatenate((pct0, pct1[::-1])),
                alpha=.9, linewidth=3,
                facecolor='none',
                edgecolor=pl.cm.spectral(ii/len(country_list)),
                label=country)
        pl.fill(pl.concatenate((t, t[::-1])),
                pl.concatenate((pct0, pct1[::-1])),
                alpha=.5, linewidth=0,
                facecolor=pl.cm.spectral(ii/len(country_list)),
                zorder=-ii)
        ii += 1.
    pl.legend()
    pl.ylabel('Nets Remaining (%)')
    pl.xlabel('Time in household (years)')
    pl.title('LLIN Survival Curve Posteriors')
    pl.savefig(settings.PATH + 'net_survival.png')
def summary_table(db, table_start=2007, table_end=2010, parameter='itn coverage', midyear=True):
    """ Output a table of midyear coverage estimates by country

    Parameters
    ----------
    db : dict
      maps model key -> pymc pickle database (as from load_pickles)
    table_start, table_end : int
      first and last year (inclusive) to tabulate
    parameter : str
      name of the traced stochastic to summarize
    midyear : bool
      if True, average adjacent year columns to get mid-year estimates;
      otherwise use the jan 1 / whole-year column directly

    Returns
    -------
    tab : list of lists, a header row followed by one row per model,
      each with the posterior median and 95% interval for each year

    Example
    -------
    >>> db = explore.load_pickles('/home/j/Project/Models/bednets/2010_07_09/')
    >>> tab = explore.summary_table(db)
    >>> f = open('/home/j/Project/Models/bednets/2010_08_05/best_case.csv', 'w')
    >>> import csv
    >>> cf = csv.writer(f)
    >>> cf.writerows(tab)
    >>> f.close()
    """
    import settings
    from pylab import mean, std, sort

    headers = ['Country']
    for y in range(table_start, table_end+1):
        headers += [y, 'ui']
    tab = [headers]

    for k, p in sorted(db.items()):
        row = [k.split('_')[2]]  # TODO: refactor k.split into function
        cov = p.__getattribute__(parameter).gettrace()
        for y in range(table_start, table_end+1):
            i = y - settings.year_start
            if midyear:
                c_y = sort(.5 * (cov[:, i] + cov[:, i+1]))  # compute mid-year estimate from posterior draws
            else:
                c_y = sort(cov[:, i])  # compute jan 1 / whole-year estimate posterior draws
            n = len(c_y)
            # median and 95% interval; indices must be ints
            # (float indexing was deprecated and then removed in NumPy)
            row += ['%f' % c_y[int(.5*n)],
                    '(%f, %f)' % (c_y[int(.025*n)], c_y[int(.975*n)])]
        tab.append(row)
    return tab
def detailed_summary_table(db):
    """ Output a table that duplicates the summary information generated by individual runs

    Parameters
    ----------
    db : dict
      maps model key -> pymc pickle database (as from load_pickles)

    Returns
    -------
    tab : list of lists, a header row followed by one row per
      country-year with posterior means and 95% HPD intervals
      (counts in thousands, coverages in percent; the final year's
      shipped/distributed columns are flagged with -99 sentinels)

    Example
    -------
    >>> db = explore.load_pickles('/home/j/Project/Models/bednets/2010_09_23/')
    >>> tab = explore.summary_table(db)
    >>> f = open('/home/j/Project/Models/bednets/2010_09_23/summary.csv', 'w')
    >>> import csv
    >>> cf = csv.writer(f)
    >>> cf.writerows(tab)
    >>> f.close()
    """
    import settings
    from pylab import mean, std, sort

    # save results in output file
    headers = [
        'Country', 'Year', 'Population',
        'LLINs Shipped (Thousands)', 'LLINs Shipped Lower CI', 'LLINs Shipped Upper CI',
        'LLINs Distributed (Thousands)', 'LLINs Distributed Lower CI', 'LLINs Distributed Upper CI',
        'LLINs Not Owned Warehouse (Thousands)', 'LLINs Not Owned Lower CI', 'LLINs Not Owned Upper CI',
        'LLINs Owned (Thousands)', 'LLINs Owned Lower CI', 'LLINs Owned Upper CI',
        'non-LLIN ITNs Owned (Thousands)', 'non-LLIN ITNs Owned Lower CI', 'non-LLIN ITNs Owned Upper CI',
        'ITNs Owned (Thousands)', 'ITNs Owned Lower CI', 'ITNs Owned Upper CI',
        'LLIN Coverage (Percent)', 'LLIN Coverage Lower CI', 'LLIN Coverage Upper CI',
        'ITN Coverage (Percent)', 'ITN Coverage Lower CI', 'ITN Coverage Upper CI',
        ]
    tab = [headers]
    year_start = settings.year_start
    year_end = settings.year_end

    from data import Data
    data = Data()

    from pymc.utils import hpd
    def my_summary(stoch, i, li, ui, factor=.001):
        # posterior mean and 95% HPD interval for year i, scaled by factor
        # (closes over `trace`, rebound for each model in the loop below)
        row = []
        row += [mean(trace[stoch][:, i])*factor]
        row += list(hpd(trace[stoch][[li, ui], i], .05)*factor)
        return row

    for k, p in sorted(db.items()):
        trace = {}
        for stoch in ['llins shipped', 'llins distributed', 'llin warehouse net stock', 'household llin stock', 'non-llin household net stock', 'household itn stock', 'llin coverage', 'itn coverage']:
            trace[stoch] = sort(p.__getattribute__(stoch).gettrace(), axis=0)

        c = k.split('_')[2]  # TODO: refactor k.split into function
        pop = data.population_for(c, year_start, year_end)
        for i in range(year_end - year_start):
            row = [c, year_start + i, pop[i]]
            # 2.5% / 97.5% trace positions; fancy indices must be ints
            # (float indexing was deprecated and then removed in NumPy)
            li = int(.025 * len(trace['llins shipped'][:, 0]))
            ui = int(.975 * len(trace['llins shipped'][:, 0]))
            if i == year_end - year_start - 1:
                # no data for shipments/distributions in the final year
                row += [-99, -99, -99]
                row += [-99, -99, -99]
            else:
                row += my_summary('llins shipped', i, li, ui)
                row += my_summary('llins distributed', i, li, ui)
            row += my_summary('llin warehouse net stock', i, li, ui)
            row += my_summary('household llin stock', i, li, ui)
            row += my_summary('non-llin household net stock', i, li, ui)
            row += my_summary('household itn stock', i, li, ui)
            row += my_summary('llin coverage', i, li, ui, 100)
            row += my_summary('itn coverage', i, li, ui, 100)

            tab.append(row)
    return tab
def summarize_fits(path=''):
    """ Generate summary tables for all models in a given dir

    Parameters
    ----------
    path : str, optional
      if path is blank, use settings.PATH

    Side effects
    ------------
    writes one summary_<parameter>.csv file per parameter into `path`

    Example
    -------
    >>> explore.summarize_fits('./') # use pickle files in current directory
    """
    if not path:
        import settings
        path = settings.PATH

    db = load_pickles(path)

    import csv
    for p in ['itn coverage', 'llins distributed', 'non-llin household net stock']:
        # 'llins distributed' is an annual flow, so report whole-year
        # values; the stock/coverage parameters are reported at midyear
        midyear = (p != 'llins distributed')
        rows = summary_table(db, parameter=p, midyear=midyear)
        fname = path + 'summary_%s.csv' % p.replace(' ', '_')
        # `with` guarantees the csv file is closed even on error
        with open(fname, 'w') as f:
            csv.writer(f).writerows(rows)

    # TODO: notify that model is complete
    # e.g. http://www.al1us.net/?p=79 to notify via skype msg
def scatter_stats(db, s1, s2, f1=None, f2=None, **kwargs):
    """ Scatter-plot posterior mean +/- std of stochastic s1 against s2
    across all models in db, labeling each point with its model key.

    Parameters
    ----------
    db : dict
      maps model key -> pymc pickle database (as from load_pickles)
    s1, s2 : str
      names of the traced stochastics for the x and y axes
    f1, f2 : callable, optional
      transforms applied to each posterior draw; f1 defaults to the
      identity function, and f2 defaults to f1
    **kwargs : passed through to pylab.errorbar
    """
    if f1 is None:
        f1 = lambda x: x  # identity function
    if f2 is None:
        f2 = f1

    x, xerr = [], []
    y, yerr = [], []
    for k in db:
        x_k = [f1(x_ki) for x_ki in db[k].__getattribute__(s1).gettrace()]
        y_k = [f2(y_ki) for y_ki in db[k].__getattribute__(s2).gettrace()]

        x.append(pl.mean(x_k))
        xerr.append(pl.std(x_k))
        y.append(pl.mean(y_k))
        yerr.append(pl.std(y_k))

        pl.text(x[-1], y[-1], ' %s' % k, fontsize=8, alpha=.4, zorder=-1)

    default_args = {'fmt': 'o', 'ms': 10}
    default_args.update(kwargs)
    pl.errorbar(x, y, xerr=xerr, yerr=yerr, **default_args)

    pl.xlabel(s1)
    pl.ylabel(s2)
def compare_models(db, stoch='itn coverage', stat_func=None, plot_type='', **kwargs):
    """ Compare two model fits per country for one stochastic.

    Parameters
    ----------
    db : dict
      maps model key -> pymc pickle database; assumes exactly two model
      keys per country, paired in sorted-key order — TODO confirm
    stoch : str
      name of the traced stochastic to compare
    stat_func : callable, optional
      transform applied to each posterior draw (default: identity)
    plot_type : str
      'scatter', 'rel_diff', or 'abs_diff'
    **kwargs : passed through to pylab.errorbar (scatter plot only)

    Returns
    -------
    array of [x, y, xerr, yerr]: per-country means and stds of the
    first and second model's transformed draws

    Raises
    ------
    ValueError if plot_type is not one of the three supported values
    """
    if stat_func is None:
        stat_func = lambda x: x

    # group the transformed traces by country, in sorted-key order
    X = {}
    for k in sorted(db.keys()):
        c = k.split('_')[2]
        X[c] = []
    for k in sorted(db.keys()):
        c = k.split('_')[2]
        X[c].append(
            [stat_func(x_ki) for x_ki in
             db[k].__getattribute__(stoch).gettrace()]
            )

    # per-country summary stats: index 0 = first model, 1 = second model
    x = pl.array([pl.mean(xc[0]) for xc in X.values()])
    xerr = pl.array([pl.std(xc[0]) for xc in X.values()])
    y = pl.array([pl.mean(xc[1]) for xc in X.values()])
    yerr = pl.array([pl.std(xc[1]) for xc in X.values()])

    if plot_type == 'scatter':
        default_args = {'fmt': 'o', 'ms': 10}
        default_args.update(kwargs)
        for c in X.keys():
            pl.text(pl.mean(X[c][0]),
                    pl.mean(X[c][1]),
                    ' %s' % c, fontsize=8, alpha=.4, zorder=-1)
        pl.errorbar(x, y, xerr=xerr, yerr=yerr, **default_args)
        pl.xlabel('First Model')
        pl.ylabel('Second Model')
        pl.plot([0, 1], [0, 1], alpha=.5, linestyle='--', color='k', linewidth=2)

    elif plot_type == 'rel_diff':
        d1 = sorted(100*(x-y)/x)
        d2 = sorted(100*(xerr-yerr)/xerr)
        pl.subplot(2, 1, 1)
        pl.title('Percent Model 2 deviates from Model 1')
        pl.plot(d1, 'o')
        pl.xlabel('Countries sorted by deviation in mean')
        pl.ylabel('deviation in mean (%)')
        pl.subplot(2, 1, 2)
        pl.plot(d2, 'o')
        pl.xlabel('Countries sorted by deviation in std err')
        pl.ylabel('deviation in std err (%)')

    elif plot_type == 'abs_diff':
        d1 = sorted(x-y)
        d2 = sorted(xerr-yerr)
        pl.subplot(2, 1, 1)
        pl.title('Percent Model 2 deviates from Model 1')
        pl.plot(d1, 'o')
        pl.xlabel('Countries sorted by deviation in mean')
        pl.ylabel('deviation in mean')
        pl.subplot(2, 1, 2)
        pl.plot(d2, 'o')
        pl.xlabel('Countries sorted by deviation in std err')
        pl.ylabel('deviation in std err')

    else:
        # raise (not assert): asserts are stripped under `python -O`
        raise ValueError('plot_type must be abs_diff, rel_diff, or scatter')

    return pl.array([x, y, xerr, yerr])