/
gibbsAppr0.py
341 lines (268 loc) · 12.7 KB
/
gibbsAppr0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
"""
V1.2 use adaptive range for integrating over f
variance 0.0001
"""
import scipy.stats as _ss
import os
import time as _tm
from ig_prmLib import ig_prmsUV
import numpy as _N
import matplotlib.pyplot as _plt
from EnDedirs import resFN, datFN
import pickle
class singleRecptvFld:
    """Gibbs sampler for the parameters of one 1-D Gaussian receptive field.

    The likelihood terms used in `gibbs` correspond to a Poisson spiking
    model in which, in a time bin of width `dt` at position x, the expected
    spike count is

        l0*dt/sqrt(2*pi*q2) * exp(-(x - f)**2 / (2*q2))

    so `f` is the field center, `q2` its variance and `l0` an amplitude.
    The recording is split into epochs; (l0, f, q2) are sampled in each
    epoch, and the posterior modes of one epoch become the starting values
    and hyperparameter priors of the next epoch.

    Position integrals use a fixed grid over [0, 3]; positions outside that
    range are dropped by the occupancy histogram (assumed not to occur --
    confirm against the data generator).
    """
    # index of each parameter in prmPstMd and in the first axis of the samples
    ky_p_l0 = 0; ky_p_f = 1; ky_p_q2 = 2
    # index of each hyperparameter in hypPstMd and in the first axis of the samples
    ky_h_l0_a = 0; ky_h_l0_B = 1
    ky_h_f_u = 2; ky_h_f_q2 = 3
    ky_h_q2_a = 4; ky_h_q2_B = 5
    dt = 0.001      # time-bin width (multiplies l0 in the likelihood)
    twpi = 2*_N.pi
    # sizes of the numerical grids
    Nupx = 500      # points used to integrate the rate over position
    fss = 100       # grid points for the conditional posterior of f
    q2ss = 200      # grid points for the conditional posterior of q2
    intvs = None    # epoch boundaries as row indices into dat
    dat = None      # col 0: position, col 1: spike indicator (0/1), col 2: ground-truth center
    datprms = None  # generating parameters, rows aligned with dat (read by figs)
    diffPerMin = 1.  # diffusion per minute
    # Squared, this floors the prior variance of f in every epoch (presumably
    # the expected drift of the field center per epoch).  When left as None,
    # gibbs falls back to diffPerMin.
    diffPerEpoch = None
    epochs = None   # number of epochs; gibbs overwrites it with ep2-ep1
    outdir = None
    prmPstMd = None  # (epochs, 3) posterior modes of (l0, f, q2), set by gibbs
    hypPstMd = None  # (epochs, 6) posterior modes of the hyperparameters, set by gibbs

    def __init__(self, outdir, fn, intvfn):
        """Load data, its generating parameters and the epoch boundaries.

        outdir : directory handed to resFN for every output file.
        fn     : data basename; "<datFN(fn)>.dat" and "<datFN(fn)>_prms.dat"
                 are loaded into self.dat and self.datprms.
        intvfn : basename of a file of epoch boundaries expressed as
                 fractions of the recording; they are scaled by the number
                 of rows in self.dat and truncated to row indices.
        """
        oo = self
        oo.outdir = outdir
        oo.dat = _N.loadtxt("%s.dat" % datFN(fn, create=False))
        oo.datprms = _N.loadtxt("%s_prms.dat" % datFN(fn, create=False))
        intvs = _N.loadtxt("%s.dat" % datFN(intvfn, create=False))
        # _N.int was an alias of the builtin int and was removed in NumPy 1.24
        oo.intvs = _N.array(intvs*oo.dat.shape[0], dtype=int)
        oo.epochs = oo.intvs.shape[0] - 1

    @staticmethod
    def _histMode(smps, nedges=50):
        """Approximate mode of 1-D samples: left edge of the fullest histogram bin.

        The histogram uses `nedges` equally spaced edges spanning
        [min(smps), max(smps)]; ties go to the lowest bin.  A constant
        sample returns its value.
        """
        L = _N.min(smps)
        H = _N.max(smps)
        if L == H:
            return L
        cnts, bns = _N.histogram(smps, bins=_N.linspace(L, H, nedges))
        return bns[_N.argmax(cnts)]

    def gibbs(self, ITERS, ep1=0, ep2=None, savePosterior=True,
              gtdiffusion=False, burn=30):
        """Sample (l0, f, q2) for epochs ep1..ep2-1 and store their posterior modes.

        ITERS         : Gibbs iterations per epoch.
        ep1, ep2      : half-open epoch range; ep2 defaults to self.epochs.
        savePosterior : also write the samples of the LAST epoch
                        (posParams.dat, posHypParams.dat) and a pickle of
                        every epoch's parameter samples plus the modes
                        (posteriors.dump).
        gtdiffusion   : use ground truth center of place field (column 2 of
                        dat) in calculating the variance of the center.  The
                        meaning of diffPerMin is then different: it multiplies
                        the squared ground-truth displacement over the epoch.
        burn          : samples discarded before taking modes (treated as 0
                        when ITERS <= burn, so at least one sample is used).

        Sets self.prmPstMd, self.hypPstMd and self.epochs = ep2-ep1, and
        always writes posModes.dat and hypModes.dat.
        """
        oo = self
        # PRIORS (prefixed with _).  After every epoch they are replaced by
        # the posterior modes of the hyperparameters.
        _f_u = 0; _f_q2 = 1                # normal prior on f
        _q2_a = 1e-4; _q2_B = 1e-3         # inverse gamma prior on q2
        _l0_a = 1.; _l0_B = 1/30.          # gamma prior on l0, mean 30Hz peak firing rate

        ep2 = oo.epochs if (ep2 is None) else ep2
        # NOTE(review): overwriting oo.epochs changes the default ep2 of later
        # gibbs()/figs() calls; kept because figs() relies on it.
        oo.epochs = ep2 - ep1
        # rows are relative to ep1 (row = epc - ep1)
        oo.prmPstMd = _N.zeros((oo.epochs, 3))       # modes of the params
        oo.hypPstMd = _N.zeros((oo.epochs, 2+2+2))   # modes of the hyper params
        twpi = 2*_N.pi
        pcklme = {}

        # GIBBS samples, reused for every epoch
        smp_prms = _N.zeros((3, ITERS, 1))
        smp_hyps = _N.zeros((6, ITERS, 1))
        frm = burn if ITERS > burn else 0

        # INITIAL VALUE OF PARAMS
        l0 = 50
        q2 = 0.0144
        f = 1.1

        # numerical grids
        ux = _N.linspace(0, 3, oo.Nupx, endpoint=False)   # uniform x position
        q2x = _N.exp(_N.linspace(_N.log(0.00005), _N.log(10), oo.q2ss))  # log-spaced, 5 orders of magnitude
        d_q2x = _N.diff(q2x)
        q2x_m1 = _N.array(q2x[0:-1])
        lq2x = _N.log(q2x)
        iq2x = 1./q2x
        iq2xr = iq2x.reshape((oo.q2ss, 1))   # column vector: broadcasts q2 against ux
        sqrt_2pi_q2x = _N.sqrt(twpi*q2x)
        x = oo.dat[:, 0]
        diffPerEpoch = oo.diffPerMin if (oo.diffPerEpoch is None) else oo.diffPerEpoch
        q2rate = diffPerEpoch**2   # unit of minutes
        posbins = _N.linspace(0, 3, oo.Nupx+1)

        for epc in range(ep1, ep2):
            print("epoch %d" % epc)
            row = epc - ep1
            t0 = oo.intvs[epc]
            t1 = oo.intvs[epc+1]
            sts = _N.where(oo.dat[t0:t1, 1] == 1)[0]   # spike bins, relative to t0
            if gtdiffusion:
                q2rate = (oo.dat[t1-1, 2]-oo.dat[t0, 2])**2*oo.diffPerMin
            NSexp = t1 - t0   # number of time bins in this epoch
            xt0t1 = _N.array(x[t0:t1])
            # occupancy density of position on the grid; density= replaces the
            # removed normed= keyword and is identical for equal-width bins
            px = _N.histogram(xt0t1, bins=posbins, density=True)[0]
            nSpks = len(sts)
            print("spikes %d" % nSpks)
            # sum_grid(g(ux)*px)*dSilenceX approximates sum_t g(x_t) over the epoch
            dSilenceX = (NSexp/float(oo.Nupx))*3

            # q2 prior: cap the shape at 200 while keeping the mode B/(a+1),
            # which limits how much we think the PF can change between epochs
            _Dq2_a = min(_q2_a, 200)
            _Dq2_B = (_q2_B/(_q2_a+1))*(_Dq2_a+1)

            for it in range(ITERS):
                iiq2 = 1./q2
                ############### CONDITIONAL f
                q2pr = max(_f_q2, q2rate)
                if nSpks > 0:
                    # spiking-part likelihood (normal around the spike mean) x prior
                    fs = (1./nSpks)*_N.sum(xt0t1[sts])
                    fq2 = q2/nSpks
                    M = (fs*q2pr + _f_u*fq2) / (q2pr + fq2)
                    Sg2 = (q2pr*fq2) / (q2pr + fq2)
                else:
                    M = _f_u
                    Sg2 = q2pr
                Sg = _N.sqrt(Sg2)
                # adaptive grid of f around the spike-only posterior
                fx = _N.linspace(M - Sg*50, M + Sg*50, oo.fss)
                fxr = fx.reshape((oo.fss, 1))
                fxrux = -0.5*(fxr-ux)**2
                f_intgrd = _N.exp(fxrux*iiq2)
                f_exp_px = _N.sum(f_intgrd*px, axis=1) * dSilenceX
                # minus the expected spike count of the epoch, as a function of f
                s = -(l0*oo.dt/_N.sqrt(twpi*q2)) * f_exp_px
                funcf = -0.5*((fx-M)*(fx-M))/Sg2 + s
                funcf -= _N.max(funcf)   # avoid overflow in exp
                condPosF = _N.exp(funcf)
                # moment-match the gridded posterior with a normal and sample it
                norm = 1./_N.sum(condPosF)
                f_u_ = norm*_N.sum(fx*condPosF)
                f_q2_ = norm*_N.sum(condPosF*(fx-f_u_)*(fx-f_u_))
                f = _N.sqrt(f_q2_)*_N.random.randn() + f_u_
                smp_prms[oo.ky_p_f, it, 0] = f
                smp_hyps[oo.ky_h_f_u, it, 0] = f_u_
                smp_hyps[oo.ky_h_f_q2, it, 0] = f_q2_

                ############### CONDITIONAL q2
                q2_intgrd = _N.exp(-0.5*(f - ux)*(f-ux) * iq2xr)
                q2_exp_px = _N.sum(q2_intgrd*px, axis=1) * dSilenceX
                s = -((l0*oo.dt)/sqrt_2pi_q2x)*q2_exp_px   # function of q2
                if nSpks > 0:
                    # spiking part of the likelihood is (1/q2)^(S/2) exp(-SL_B/q2)
                    xI = (xt0t1[sts]-f)*(xt0t1[sts]-f)*0.5
                    SL_a = 0.5*nSpks - 1
                    SL_B = _N.sum(xI)
                    sLLkPr = -(_Dq2_a + SL_a + 2)*lq2x - iq2x*(_Dq2_B + SL_B)
                else:
                    sLLkPr = -(_Dq2_a + 1)*lq2x - iq2x*(_Dq2_B)
                sat = sLLkPr + s
                sat -= _N.max(sat)   # avoid overflow in exp
                condPos = _N.exp(sat)
                # ig_prmsUV (project helper) presumably fits inverse-gamma
                # parameters to the gridded posterior -- confirm its convention
                q2_a_, q2_B_ = ig_prmsUV(q2x, condPos, d_q2x, q2x_m1, ITER=1)
                q2 = _ss.invgamma.rvs(q2_a_ + 1, scale=q2_B_)  # check shape convention
                smp_prms[oo.ky_p_q2, it, 0] = q2
                smp_hyps[oo.ky_h_q2_a, it, 0] = q2_a_
                smp_hyps[oo.ky_h_q2_B, it, 0] = q2_B_

                ############### CONDITIONAL l0
                # _ss.gamma.rvs uses shape k and scale theta = 1/B
                iiq2 = 1./q2
                l0_intgrd = _N.exp(-0.5*(f - ux)*(f-ux) * iiq2)
                l0_exp_px = _N.sum(l0_intgrd*px) * dSilenceX
                BL = (oo.dt/_N.sqrt(twpi*q2))*l0_exp_px
                aL = nSpks
                l0_a_ = aL + _l0_a
                l0_B_ = BL + _l0_B
                # NOTE(review): shape l0_a_-1 is 0 (invalid) when nSpks == 0
                # and _l0_a == 1; check the intended gamma convention
                l0 = _ss.gamma.rvs(l0_a_ - 1, scale=(1/l0_B_))
                ### l0 / _N.sqrt(twpi*q2) is f*dt used in createData2
                smp_prms[oo.ky_p_l0, it, 0] = l0
                smp_hyps[oo.ky_h_l0_a, it, 0] = l0_a_
                smp_hyps[oo.ky_h_l0_B, it, 0] = l0_B_

            # posterior modes after burn-in; parameter modes seed the next epoch
            for ip in range(3):
                oo.prmPstMd[row, ip] = oo._histMode(smp_prms[ip, frm:, 0])
            l0 = oo.prmPstMd[row, oo.ky_p_l0]
            f = oo.prmPstMd[row, oo.ky_p_f]
            q2 = oo.prmPstMd[row, oo.ky_p_q2]
            pcklme["cp%d" % epc] = _N.array(smp_prms)

            # hyperparameter modes become the next epoch's priors
            for ip in range(6):
                oo.hypPstMd[row, ip] = oo._histMode(smp_hyps[ip, frm:, 0])
            _l0_a = oo.hypPstMd[row, oo.ky_h_l0_a]
            _l0_B = oo.hypPstMd[row, oo.ky_h_l0_B]
            _f_u = oo.hypPstMd[row, oo.ky_h_f_u]
            _f_q2 = oo.hypPstMd[row, oo.ky_h_f_q2]
            _q2_a = oo.hypPstMd[row, oo.ky_h_q2_a]
            _q2_B = oo.hypPstMd[row, oo.ky_h_q2_B]

        if savePosterior:
            # smp_prms / smp_hyps hold only the samples of the last epoch
            _N.savetxt(resFN("posParams.dat", dir=oo.outdir), smp_prms[:, :, 0].T, fmt="%.4f %.4f %.4f")
            _N.savetxt(resFN("posHypParams.dat", dir=oo.outdir), smp_hyps[:, :, 0].T, fmt="%.4f %.4f %.4f %.4f %.4f %.4f")
            pcklme["md"] = _N.array(oo.prmPstMd)
            with open(resFN("posteriors.dump", dir=oo.outdir), "wb") as dmp:
                pickle.dump(pcklme, dmp, -1)
        _N.savetxt(resFN("posModes.dat", dir=oo.outdir), oo.prmPstMd, fmt="%.4f %.4f %.4f")
        _N.savetxt(resFN("hypModes.dat", dir=oo.outdir), oo.hypPstMd, fmt="%.4f %.4f %.4f %.4f %.4f %.4f")

    def figs(self, ep1=0, ep2=None):
        """Plot per-epoch means of the generating parameters against the modes.

        Three panels: datprms column 0 vs. the mode of f, column 2 vs. the
        mode of l0, and column 1 vs. the mode of q2.  Row 0 of prmPstMd is
        plotted against epoch ep1, so call with the ep1/ep2 used in gibbs.
        The figure is saved as "cmpModesGT" in self.outdir.
        """
        oo = self
        ep2 = oo.epochs if (ep2 is None) else ep2
        fig = _plt.figure(figsize=(8, 9))
        mnUs = _N.empty(ep2-ep1)
        mnL0s = _N.empty(ep2-ep1)
        mnSq2s = _N.empty(ep2-ep1)
        for epc in range(ep1, ep2):
            t0 = oo.intvs[epc]
            t1 = oo.intvs[epc+1]
            mnUs[epc-ep1] = _N.mean(oo.datprms[t0:t1, 0])
            mnSq2s[epc-ep1] = _N.mean(oo.datprms[t0:t1, 1])
            mnL0s[epc-ep1] = _N.mean(oo.datprms[t0:t1, 2])
        fig.add_subplot(3, 1, 1)
        _plt.plot(mnUs)
        _plt.plot(oo.prmPstMd[:, oo.ky_p_f])
        fig.add_subplot(3, 1, 2)
        _plt.plot(mnL0s)
        _plt.plot(oo.prmPstMd[:, oo.ky_p_l0])
        fig.add_subplot(3, 1, 3)
        _plt.plot(mnSq2s)
        _plt.plot(oo.prmPstMd[:, oo.ky_p_q2])
        _plt.savefig(resFN("cmpModesGT", dir=oo.outdir))