def show_posmarks(dec, setname, ylim=None, win=None, singles=False, baseFN=None):
    """Plot training marks against position together with the fitted mixture.

    For every tetrode ``nt`` and every mark dimension ``k`` (1..dec.mdim) the
    training (position, mark[k-1]) pairs are scattered in black when the
    tetrode has observed marks.  Every mixture component whose weight
    ``ms[m, 0]`` is at least ``MTHR`` times the largest weight is drawn in red
    as its (position, mark k) center with a +/- 1 standard deviation cross.

    Parameters
    ----------
    dec : decoder object; reads nTets, mdim, M, marksObserved, tr_pos,
        tr_marks, mvNrm, tt0, tt1 and usetets.
    setname : directory handed to ``resFN`` for the output files.
    ylim : optional (lo, hi) y-limits applied to every panel.
    win : window number written into the filename as ``_win=<win>`` (%d).
        When None the ``_win=`` part is omitted (formatting None with %d
        used to raise TypeError).
    singles : if True, one figure per mark dimension, saved with a ``_k<k>``
        suffix so the dimensions no longer overwrite each other's file;
        otherwise one 2x2 figure per tetrode.
    baseFN : base of the output filename; ``_tet<id>`` is appended when
        ``dec.usetets`` is set.
    """
    MTHR = 0.001   # how much smaller is mixture compared to maximum
    winSfx = "" if win is None else "_win=%d" % win

    for nt in xrange(dec.nTets):
        fn = baseFN if (dec.usetets is None) else \
            "%(bf)s_tet%(t)s" % {"bf" : baseFN, "t" : dec.usetets[nt]}
        mvNrm = dec.mvNrm[nt]

        if not singles:
            fig = _plt.figure(figsize=(7, 5))
        for k in xrange(1, dec.mdim+1):
            if singles:
                fig = _plt.figure(figsize=(4, 3))
                ax = fig.add_subplot(1, 1, 1)
            else:
                ax = fig.add_subplot(2, 2, k)

            if dec.marksObserved[nt] > 0:
                _plt.scatter(dec.tr_pos[nt], dec.tr_marks[nt][:, k-1], color="black", s=2)

            #  skip components whose weight is negligible vs. the largest one
            mThr = MTHR * _N.max(mvNrm.ms)
            for m in xrange(dec.M):
                if mvNrm.ms[m, 0] < mThr:
                    continue
                ux = mvNrm.us[m, 0]    # position
                uy = mvNrm.us[m, k]    # mark dimension k
                ex_x = _N.sqrt(mvNrm.covs[m, 0, 0])
                ex_y = _N.sqrt(mvNrm.covs[m, k, k])
                _plt.plot([ux-ex_x, ux+ex_x], [uy, uy], color="red", lw=2)
                _plt.plot([ux, ux], [uy-ex_y, uy+ex_y], color="red", lw=2)
                _plt.scatter(ux, uy, color="red", s=30)

            _plt.xlim(-6, 6)
            if ylim is not None:
                _plt.ylim(ylim[0], ylim[1])

            if singles:
                _plt.suptitle("k=%(k)d t0=%(2).2fs : t1=%(3).2fs" % {"2" : (dec.tt0/1000.), "3" : (dec.tt1/1000.), "k" : k})
                # NOTE(review): showTrajectory calls mF.arbitaryAxes; one of the two spellings may be wrong
                mF.arbitraryAxes(ax)
                mF.setLabelTicks(_plt, xlabel="position", ylabel="mark", xtickFntSz=14, ytickFntSz=14, xlabFntSz=16, ylabFntSz=16)
                fig.subplots_adjust(left=0.2, bottom=0.2, top=0.85)
                #  one file per mark dimension: without the _k suffix every
                #  dimension was written to the same file and overwritten
                _plt.savefig(resFN("%(1)s_k%(k)d%(w)s.png" % {"1" : fn, "k" : k, "w" : winSfx}, dir=setname, create=True), transparent=True)
                _plt.close()

        if not singles:
            _plt.suptitle("t0=%(2)d,t1=%(3)d" % {"2" : dec.tt0, "3" : dec.tt1})
            _plt.savefig(resFN("%(1)s%(w)s.png" % {"1" : fn, "w" : winSfx}, dir=setname, create=True), transparent=True)
            _plt.close()
def show_posmarksCNTR(dec, setname, mvNrm, ylim=None, win=None, singles=False, showScatter=True, baseFN="look", scatskip=1):
    """Plot marks vs. position over a heat map of the mixture density.

    For every tetrode and mark dimension k (1..dec.mdim), an image returned
    by ``mvNrm.evalAll(1000, k-1, ylim=...)`` (presumably the mixture density
    over position x mark k) is shown with the "Reds" colormap on
    x in [-6, 6].  Every ``scatskip``-th training point is optionally
    scattered over it in grey.

    Parameters
    ----------
    dec : decoder object; reads nTets, mdim, marksObserved, tr_pos,
        tr_marks, tt0, tt1 and usetets.
    setname : directory handed to ``resFN``.
    mvNrm : object providing ``evalAll``.
    ylim : (lo, hi) y-range used for every panel.  When None, a range is
        derived separately for each (tetrode, mark dimension) panel from the
        training marks, padded by 10% on both sides.  (Previously the range of
        the first panel was silently reused for all later panels.)
    win : window number written as ``_win=<win>`` (%d); omitted when None.
    singles : if True, one figure/file per mark dimension (``_k<k>`` suffix),
        otherwise one 2x2 figure per tetrode.
    showScatter, baseFN, scatskip : see above.
    """
    winSfx = "" if win is None else "_win=%d" % win

    for nt in xrange(dec.nTets):
        fn = baseFN if (dec.usetets is None) else \
            "%(fn)s_tet%(tets)s" % {"fn" : baseFN, "tets" : dec.usetets[nt]}

        if not singles:
            fig = _plt.figure(figsize=(7, 5))
        for k in xrange(1, dec.mdim+1):
            if singles:
                fig = _plt.figure(figsize=(4, 3))
                ax = fig.add_subplot(1, 1, 1)
            else:
                ax = fig.add_subplot(2, 2, k)

            _plt.xlim(-6, 6)
            if ylim is not None:
                yl = ylim
                _plt.ylim(yl[0], yl[1])
            else:
                #  per-panel range; kept local so it does not leak into the
                #  next panel through the ylim argument
                mk = dec.tr_marks[nt][:, k-1]
                yl = _N.empty(2)
                yl[0] = _N.min(mk)
                yl[1] = _N.max(mk)
                yAMP = yl[1] - yl[0]
                yl[0] -= 0.1*yAMP
                yl[1] += 0.1*yAMP

            if showScatter and dec.marksObserved[nt] > 0:
                _plt.scatter(dec.tr_pos[nt][::scatskip], dec.tr_marks[nt][::scatskip, k-1], color="grey", s=1)

            img = mvNrm.evalAll(1000, k-1, ylim=yl)
            _plt.imshow(img, origin="lower", extent=(-6, 6, yl[0], yl[1]), cmap=_plt.get_cmap("Reds"))

            if singles:
                _plt.suptitle("k=%(k)d t0=%(2).2fs : t1=%(3).2fs" % {"2" : (dec.tt0/1000.), "3" : (dec.tt1/1000.), "k" : k})
                mF.arbitraryAxes(ax)
                mF.setLabelTicks(_plt, xlabel="position", ylabel="mark", xtickFntSz=14, ytickFntSz=14, xlabFntSz=16, ylabFntSz=16)
                fig.subplots_adjust(left=0.2, bottom=0.2, top=0.85)
                #  one file per mark dimension (previously all k shared one name)
                _plt.savefig(resFN("%(1)s_k%(k)d%(w)s.png" % {"1" : fn, "k" : k, "w" : winSfx}, dir=setname, create=True), transparent=True)
                _plt.close()

        if not singles:
            _plt.suptitle("t0=%(2)d,t1=%(3)d" % {"2" : dec.tt0, "3" : dec.tt1})
            _plt.savefig(resFN("%(1)s%(w)s.png" % {"1" : fn, "w" : winSfx}, dir=setname, create=True), transparent=True)
            _plt.close()
def figs(self, ep1=0, ep2=None):
    """Plot per-epoch ground-truth means against the posterior modes.

    For each epoch in [ep1, ep2), columns 2, 3 and 4 of ``oo.dat`` are
    averaged over the epoch's interval ``oo.intvs[epc]:oo.intvs[epc+1]``.
    They are plotted in three stacked panels against columns ky_p_f,
    ky_p_l0 and ky_p_q2 of ``oo.prmPstMd``.  With the 3*m+ip column layout
    used by gibbs, those columns belong to cluster 0.  The figure is saved as
    "cmpModesGT" in ``oo.outdir`` and then closed.
    """
    oo = self
    ep2 = oo.epochs if (ep2 is None) else ep2
    nEp = ep2 - ep1

    mnUs   = _N.empty(nEp)
    mnL0s  = _N.empty(nEp)
    mnSq2s = _N.empty(nEp)

    for i, epc in enumerate(xrange(ep1, ep2)):
        t0 = oo.intvs[epc]
        t1 = oo.intvs[epc+1]
        mnUs[i]   = _N.mean(oo.dat[t0:t1, 2])
        mnSq2s[i] = _N.mean(oo.dat[t0:t1, 3])
        mnL0s[i]  = _N.mean(oo.dat[t0:t1, 4])

    fig = _plt.figure(figsize=(8, 9))
    panels = [(mnUs, oo.ky_p_f), (mnL0s, oo.ky_p_l0), (mnSq2s, oo.ky_p_q2)]
    for row, (groundTruth, col) in enumerate(panels):
        fig.add_subplot(3, 1, row+1)
        _plt.plot(groundTruth)
        _plt.plot(oo.prmPstMd[:, col])
    _plt.savefig(resFN("cmpModesGT", dir=oo.outdir))
    _plt.close()    #  the figure was previously left open
def showMarginalMarkDistributions(dec, setname, mklim=(-6, 8), dk=0.1):
    """Compare 1-D marginals of each tetrode's mixture with training histograms.

    Panel 1 shows the histogram of training positions against the mixture
    marginal over position, with indices 1-4 marginalized out.  Panels 3-6
    show the histogram of each of the 4 mark channels against the
    corresponding mixture marginal.  Those marginals are evaluated on a grid
    from mklim[0] to mklim[1] with spacing ~dk.  The mixture curves are
    rescaled to densities on their grids.  One file "margDists[<tet id>]" is
    written per tetrode into setname.

    Assumes 5-D components ordered (position, mark1..mark4), as implied by
    the hard-coded index sets below.
    """
    xp = _N.linspace(-6, 6, 121)     # position grid, spacing 0.1
    allinds = _N.arange(5)
    #  linspace needs an integer count (a float count fails on newer numpy)
    nPts = int(round((mklim[1]-mklim[0])/float(dk))) + 1
    bins = _N.linspace(mklim[0], mklim[1], nPts)

    def mixMarginal(mvNrm, mrgidx, grid):
        #  weighted sum over components of the 1-D gaussian that
        #  mvn.marginalPDF leaves after removing the indices in mrgidx
        p = _N.zeros(len(grid))
        for m in xrange(dec.M):
            mn, mcov = mvn.marginalPDF(mvNrm.us[m], mvNrm.covs[m], mrgidx)
            p += mvNrm.ms[m]/_N.sqrt(2*_N.pi*mcov[0, 0]) * _N.exp(-0.5*(grid - mn[0])**2 / mcov[0, 0])
        return p

    for tet in xrange(dec.nTets):
        mvNrm = dec.mvNrm[tet]
        fig = _plt.figure(figsize=(13, 12))

        ###    marginalize tetrode marks
        fig.add_subplot(3, 2, 1)
        p = mixMarginal(mvNrm, _N.array([1, 2, 3, 4]), xp)
        _plt.hist(dec.tr_pos[tet], bins=_N.linspace(-6, 6, 121), normed=True, color="black")
        _plt.plot(xp, (p/_N.sum(p))*10, color="red", lw=2)   # 10 = 1 / grid spacing

        ###    marginalize position + 3 tetrode marks
        for shk in xrange(1, 5):
            fig.add_subplot(3, 2, shk+2)
            p = mixMarginal(mvNrm, _N.setdiff1d(allinds, _N.array([shk])), bins)
            _plt.hist(dec.tr_marks[tet][:, shk-1], bins=bins, normed=True, color="black")
            _plt.plot(bins, (p/_N.sum(p))*(1./dk), color="red", lw=2)

        fn = "margDists" if (dec.usetets is None) else "margDists%s" % dec.usetets[tet]
        _plt.savefig(resFN(fn, dir=setname))
        _plt.close()
def showTrajectory(dec, t0, t1, ep, setname, dir):
    """Plot the decoded position posterior for samples [t0, t1) with the true path.

    ``dec.pX_Nm[t0:t1]`` is shown transposed as a heat map with time on x.
    The actual position ``(dec.xA + dec.pos)/dec.dxp`` is overlaid as a grey
    dashed line.  Each tetrode's mark times (``dec.marks[t, nt]`` not None)
    are drawn as a row of ticks below the heat map.  The figure is saved as
    "decode_<utets_str>_<decmth>_win=<ep/2>.eps" in setname.

    t0, t1 : sample range; x tick labels assume 1000 samples per second.
    ep     : window number (the filename uses ep/2, formatted with %d).
    dir    : unused; kept for interface compatibility.
    """
    T = t1 - t0
    fig = _plt.figure(figsize=(14, 7))
    ax = fig.add_subplot(1, 1, 1)
    _plt.imshow(dec.pX_Nm[t0:t1].T, aspect=(0.5*T/50.), cmap=_plt.get_cmap("Reds"))
    #  one point per image column so the path lines up with the heat map
    _plt.plot(_N.arange(T), (dec.xA+dec.pos[t0:t1])/dec.dxp, color="grey", lw=3, ls="--")

    _plt.xlim(0, T)
    _plt.ylim(-(dec.nTets*4), 50)

    #  ~4 ticks on whole seconds.  dt is at least 1 s (it used to be 0 for
    #  windows < 4 s, making arange fail), and each tick is labelled with the
    #  absolute second it sits on (labels used to be shifted when t0 was not
    #  a whole second, and could differ in count from the positions).
    dt = int((((int(t1/1000.)*1000) - (int(t0/1000.)*1000))/4.)/1000.)*1000
    dt = max(dt, 1000)
    firstSec = int(_N.ceil(t0/1000.))*1000
    tickPos = _N.arange(firstSec - t0, T, dt)
    _plt.xticks(tickPos, (tickPos + t0)//1000)
    _plt.yticks(_N.linspace(0, 50, 5), [-6, -3, 0, 3, 6])

    # NOTE(review): show_posmarks calls mF.arbitraryAxes; confirm which spelling mF defines
    mF.arbitaryAxes(ax, axesVis=[False, False, False, False], x_tick_positions="bottom", y_tick_positions="left")
    mF.setLabelTicks(_plt, xlabel="Time (sec.)", ylabel="Position", xtickFntSz=30, ytickFntSz=30, xlabFntSz=32, ylabFntSz=32)

    #  spike raster: tetrode nt on its own row at y = -1.5 - 3*nt
    for nt in xrange(dec.nTets):
        spkTs = [t - t0 for t in xrange(t0, t1) if dec.marks[t, nt] is not None]
        _plt.plot(spkTs, [-1.5 - 3*nt]*len(spkTs), ls="", marker="|", ms=15, color="black")

    fig.subplots_adjust(bottom=0.15, left=0.15)
    _plt.savefig(resFN("decode_%(uts)s_%(mth)s_win=%(e)d.eps" % {"e" : (ep/2), "mth" : dec.decmth, "uts" : dec.utets_str}, dir=setname, create=True))
    _plt.close()
def _gibbsHistMAP(smps, polyFit, nbins=50):
    """Estimate the mode of the 1-D sample array ``smps`` from its histogram.

    The samples are histogrammed on ``nbins`` equally spaced edges between
    their min and max.  With ``polyFit`` a parabola is fit to the bin counts,
    and its vertex is used when it is a maximum lying inside the sampled
    range.  Otherwise, or without ``polyFit``, the left edge of the first
    fullest bin is used.  Constant samples return that constant.
    """
    L = _N.min(smps)
    H = _N.max(smps)
    if L == H:      #  degenerate histogram (e.g. a value that never changed)
        return L
    cnts, bns = _N.histogram(smps, bins=_N.linspace(L, H, nbins))
    if polyFit:
        xfit = 0.5*(bns[0:-1] + bns[1:])
        ac = _N.polyfit(xfit, cnts, 2)   #  ac[0]*x^2 + ac[1]*x + ac[2]
        if ac[0] < 0:                    #  concave: vertex is a maximum
            xMAP = -ac[1] / (2*ac[0])
            #  an extrapolated vertex is not a mode of the samples (and can
            #  e.g. make a variance negative), so only accept it in range
            if L <= xMAP <= H:
                return xMAP
    return bns[_N.argmax(cnts)]


def gibbs(self, ITERS, M, ep1=0, ep2=None, savePosterior=True, gtdiffusion=False):
    """Gibbs-sample place-field parameters (l0, f, q2) of an M-component model.

    For each epoch in [ep1, ep2) the sampler runs ITERS sweeps.  Each sweep
    assigns every spike to one of the M components, then draws each
    component's center f, width q2 and rate amplitude l0 from their
    conditional posteriors.  The f and q2 conditionals are evaluated on
    numerical grids.  Histogram modes of the samples from iteration
    int(0.6*ITERS) on are stored in row ``epc-ep1`` of ``oo.prmPstMd``
    (params, column 3*m+ip) and ``oo.hypPstMd`` (hyper-params, column
    6*m+ip).  The hyper-param modes become the priors for the next epoch.
    Samples and summaries are pickled to "posteriors.dmp" in ``oo.outdir``.

    gtdiffusion:  use ground truth center of place field in calculating variance of center.  Meaning of diffPerMin different
    savePosterior: currently unused.
    """
    oo = self
    twpi = 2*_N.pi

    #  PRIORS (prefixed w/ _)
    _f_u  = _N.linspace(oo.xLo+1, oo.xHi-1, M)   #  prior centers spread over the track
    _f_q2 = _N.ones(M)                            #  wide
    #  inverse gamma prior on q2
    _q2_a = _N.ones(M)*2.1
    _q2_B = _N.ones(M)*1e-2
    #  gamma prior on l0
    _l0_a = _N.ones(M)
    _l0_B = _N.zeros(M)*(1/30.)   # NOTE(review): zeros*(1/30.) is all 0; _N.ones may have been intended

    ep2 = oo.epochs if (ep2 is None) else ep2
    oo.epochs = ep2-ep1

    #  one row per epoch in [ep1, ep2)
    oo.prmPstMd = _N.zeros((oo.epochs, 3*M))          # mode of the params
    oo.hypPstMd = _N.zeros((oo.epochs, (2+2+2)*M))    # the hyper params

    pcklme = {}

    ######################################  GIBBS samples, need for MAP estimate
    smp_prms = _N.zeros((3, ITERS, M))
    smp_hyps = _N.zeros((6, ITERS, M))

    ######################################  INITIAL VALUE OF PARAMS
    l0 = _N.array([1.,]*M)
    q2 = _N.array([0.0144]*M)
    f  = _N.array([1.1]*M)

    oo.f_  = f
    oo.q2_ = q2
    oo.l0_ = l0

    ######################################  GRID for calculating
    #  Nupx: # points in uniform sampling of exp(x)p(x) (non-spike intervals)
    #  fss : # points in sampling of f  for conditional posterior distribution
    #  q2ss: # points in sampling of q2 for conditional posterior distribution
    ux  = _N.linspace(oo.xLo, oo.xHi, oo.Nupx, endpoint=False)   # uniform x position
    q2x = _N.exp(_N.linspace(_N.log(0.000001), _N.log(400), oo.q2ss))  # log-spaced q2 grid
    d_q2x  = _N.diff(q2x)
    q2x_m1 = _N.array(q2x[0:-1])
    lq2x   = _N.log(q2x)
    iq2x   = 1./q2x
    iq2xr  = 1./q2x.reshape((oo.q2ss, 1))
    sqrt_2pi_q2x = _N.sqrt(twpi*q2x)

    x = oo.dat[:, 0]

    q2rate = oo.diffPerEpoch**2  #  unit of minutes
    ######################################  PRECOMPUTED
    posbins = _N.linspace(oo.xLo, oo.xHi, oo.Nupx+1)

    for epc in xrange(ep1, ep2):
        row = epc - ep1     #  row of prmPstMd / hypPstMd (was epc: IndexError when ep1 > 0)
        f  = 3*_N.random.rand(M)
        q2 = 3*_N.random.rand(M)*0.1
        l0 = 3*_N.random.rand(M)
        print("epoch %d" % epc)

        t0 = oo.intvs[epc]
        t1 = oo.intvs[epc+1]
        Asts = _N.where(oo.dat[t0:t1, 1] == 1)[0]   #  based at 0

        if gtdiffusion:
            #  tell me how much diffusion to expect per min.
            #  [(EXPECT) x (# of minutes between epochs)]**2
            q2rate = (oo.dat[t1-1,2]-oo.dat[t0,2])**2*oo.diffPerMin
        NSexp = t1-t0    #  length of position data
        xt0t1 = _N.array(x[t0:t1])
        px = _N.histogram(xt0t1, bins=posbins, density=True)[0]   #  occupancy density

        nSpks = len(Asts)
        gz = _N.zeros((ITERS, nSpks, M), dtype=bool)
        print("spikes %d" % nSpks)

        dSilenceX = (NSexp/float(oo.Nupx))*(oo.xHi-oo.xLo)

        xAS  = x[Asts + t0]   #  position @ spikes
        xASr = xAS.reshape((1, nSpks))
        econt = _N.empty((M, nSpks))
        rat   = _N.zeros((M+1, nSpks))

        print("-^---------")

        for it in xrange(ITERS):
            if (it % 100) == 0:
                print(it)
            fr   = f.reshape((M, 1))
            iq2r = (1./q2).reshape((M, 1))
            try:
                pkFR = _N.log(l0/_N.sqrt(twpi*q2))
            except Warning:
                print("WARNING")
                print(l0)
                print(q2)
                raise    #  pkFR is undefined past this point
            pkFRr = pkFR.reshape((M, 1))
            rnds  = _N.random.rand(nSpks)

            ###############  SPIKE ASSIGNMENT
            #  log(rate peak x gaussian) for every (component, spike)
            cont   = pkFRr - 0.5*(fr - xASr)*(fr - xASr)*iq2r
            mcontr = _N.max(cont, axis=0).reshape((1, nSpks))
            cont  -= mcontr          #  avoid overflow in exp
            _N.exp(cont, out=econt)

            #  rat[:, n]: cumulative weights over components for spike n
            for m in xrange(M):
                rat[m+1] = rat[m] + econt[m]
            rat /= rat[M]

            M1 = rat[1:] >= rnds
            M2 = rat[0:-1] <= rnds
            gz[it] = (M1&M2).T

            for m in xrange(M):
                iiq2 = 1./q2[m]
                sts = Asts[_N.where(gz[it, :, m] == 1)[0]]
                nSpksM = len(sts)

                ###############  CONDITIONAL f
                q2pr = _f_q2[m] if (_f_q2[m] > q2rate) else q2rate
                if nSpksM > 0:  #  spiking portion likelihood x prior
                    fs  = (1./nSpksM)*_N.sum(xt0t1[sts])
                    fq2 = q2[m]/nSpksM
                    U   = (fs*q2pr + _f_u[m]*fq2) / (q2pr + fq2)
                    FQ2 = (q2pr*fq2) / (q2pr + fq2)
                else:
                    U   = _f_u[m]
                    FQ2 = q2pr

                FQ  = _N.sqrt(FQ2)
                fx  = _N.linspace(U - FQ*60, U + FQ*60, oo.fss)
                fxr = fx.reshape((oo.fss, 1))

                fxrux    = -0.5*(fxr-ux)*(fxr-ux)
                f_intgrd = _N.exp((fxrux*iiq2))   #  integrand
                f_exp_px = _N.sum(f_intgrd*px, axis=1) * dSilenceX   #  function of f
                slnc = -(l0[m]*oo.dt/_N.sqrt(twpi*q2[m])) * f_exp_px  #  silence term, function of f

                funcf    = -0.5*((fx-U)*(fx-U))/FQ2 + slnc
                funcf   -= _N.max(funcf)
                condPosF = _N.exp(funcf)

                #  gaussian approximation of the gridded conditional
                norm  = 1./_N.sum(condPosF)
                f_u_  = norm*_N.sum(fx*condPosF)
                f_q2_ = norm*_N.sum(condPosF*(fx-f_u_)*(fx-f_u_))
                f[m]  = _N.sqrt(f_q2_)*_N.random.randn() + f_u_
                smp_prms[oo.ky_p_f, it, m]    = f[m]
                smp_hyps[oo.ky_h_f_u, it, m]  = f_u_
                smp_hyps[oo.ky_h_f_q2, it, m] = f_q2_

                ###############  CONDITIONAL q2
                q2_intgrd = _N.exp(-0.5*(f[m] - ux)*(f[m]-ux) * iq2xr)
                q2_exp_px = _N.sum(q2_intgrd*px, axis=1) * dSilenceX
                slnc = -((l0[m]*oo.dt)/sqrt_2pi_q2x)*q2_exp_px     #  function of q2

                if nSpksM > 0:
                    ##  (1/sqrt(sg2))^S
                    ##  (1/x)^(S/2)   = (1/x)-(a+1)
                    ##  -S/2 = -a - 1     -a = -S/2 + 1    a = S/2-1
                    xI   = (xt0t1[sts]-f[m])*(xt0t1[sts]-f[m])*0.5
                    SL_a = 0.5*nSpksM - 1   #  spiking part of likelihood
                    SL_B = _N.sum(xI)       #  spiking part of likelihood
                    #  spiking prior x prior
                    sLLkPr = -(_q2_a[m] + SL_a + 2)*lq2x - iq2x*(_q2_B[m] + SL_B)
                else:
                    #  no spiking term; defined so the diagnostics below never
                    #  hit a NameError or print a stale value
                    SL_a = SL_B = None
                    sLLkPr = -(_q2_a[m] + 1)*lq2x - iq2x*_q2_B[m]

                sat = sLLkPr + slnc
                sat -= _N.max(sat)
                condPos = _N.exp(sat)
                q2_a_, q2_B_ = ig_prmsUV(q2x, condPos, d_q2x, q2x_m1, ITER=1)

                if q2_a_ > 10000:
                    #  dump the pieces of the conditional for inspection
                    putit = _N.empty((oo.q2ss, 4))
                    putit[:, 0] = q2x
                    putit[:, 1] = sLLkPr
                    putit[:, 2] = slnc
                    putit[:, 3] = condPos
                    _N.savetxt("putit", putit, fmt="%.4e %.4e %.4e %.4e")
                    print(_q2_a[m])
                    print(SL_a)
                    print(_q2_B[m])
                    print(SL_B)
                    print("------------ %d" % nSpksM)

                assert q2_a_ < 10000, "q2_a_ too big"
                print("it %(it)d q2 m: %(m)d nS: %(ns)d q2_a_ %(a) .3e, q2_B_ %(B) .3e" % {"a" : q2_a_, "B" : q2_B_, "m" : m, "ns" : nSpksM, "it" : it})

                try:
                    q2[m] = _ss.invgamma.rvs(q2_a_ + 1, scale=q2_B_)  #  check
                except ValueError:
                    #  domain error when q2_a_ or q2_B_ is nan: dump and re-raise
                    fig = _plt.figure()
                    fig.add_subplot(2, 1, 1)
                    _N.savetxt("badcondpos", condPos)
                    _N.savetxt("slnc", slnc)
                    _N.savetxt("sLLkPr", sLLkPr)
                    _plt.plot(condPos)
                    _plt.ylim(-0.1, 1.1)
                    fig.add_subplot(2, 1, 2)
                    _plt.plot(_N.log(condPos))
                    print("ValueError it %(it)d q2_a_ %(1).3e q2_B_ %(2).3e" % {"1" : q2_a_, "2" : q2_B_, "it" : it})
                    print("m is %(m)d nSpksM is %(ns)d" % {"m" : m, "ns" : nSpksM})
                    raise

                smp_prms[oo.ky_p_q2, it, m]   = q2[m]
                smp_hyps[oo.ky_h_q2_a, it, m] = q2_a_
                smp_hyps[oo.ky_h_q2_B, it, m] = q2_B_

                ###############  CONDITIONAL l0
                #  _ss.gamma.rvs.  uses k, theta    k is 1/B  (B is our thing)
                iiq2 = 1./q2[m]
                l0_intgrd = _N.exp(-0.5*(f[m] - ux)*(f[m]-ux) * iiq2)
                l0_exp_px = _N.sum(l0_intgrd*px) * dSilenceX
                BL = (oo.dt/_N.sqrt(twpi*q2[m]))*l0_exp_px

                #  cap the prior shape at 25 keeping the ratio fixed:
                #  a'/B' = a/B   ->   B' = (B/a)a'
                _Dl0_a = _l0_a[m] if _l0_a[m] < 25 else 25
                _Dl0_B = (_l0_B[m]/_l0_a[m]) * _Dl0_a

                l0_a_ = nSpksM + _Dl0_a
                l0_B_ = BL + _Dl0_B
                if (l0_B_ > 0) and (l0_a_ > 1):
                    l0[m] = _ss.gamma.rvs(l0_a_ - 1, scale=(1/l0_B_))  #  check

                ###  l0 / _N.sqrt(twpi*q2) is f*dt used in createData2
                smp_prms[oo.ky_p_l0, it, m]   = l0[m]
                smp_hyps[oo.ky_h_l0_a, it, m] = l0_a_
                smp_hyps[oo.ky_h_l0_B, it, m] = l0_B_

        frm = int(0.6*ITERS)  #  have to test for stationarity

        #  MAP of the params -> prmPstMd[row, 3*m+ip]
        for m in xrange(M):
            for ip in xrange(3):
                xMAP = _gibbsHistMAP(smp_prms[ip, frm:, m], oo.polyFit)
                col = 3*m+ip
                if ip == oo.ky_p_l0:
                    l0[m] = oo.prmPstMd[row, col] = xMAP
                elif ip == oo.ky_p_f:
                    f[m] = oo.prmPstMd[row, col] = xMAP
                elif ip == oo.ky_p_q2:
                    q2[m] = oo.prmPstMd[row, col] = xMAP
        pcklme["cp%d" % epc] = _N.array(smp_prms)

        #  MAP of the hyper params -> hypPstMd[row, 6*m+ip] and next priors
        for m in xrange(M):
            for ip in xrange(6):
                xMAP = _gibbsHistMAP(smp_hyps[ip, frm:, m], oo.polyFit)
                col = 6*m+ip
                if ip == oo.ky_h_l0_a:
                    _l0_a[m] = oo.hypPstMd[row, col] = xMAP
                elif ip == oo.ky_h_l0_B:
                    _l0_B[m] = oo.hypPstMd[row, col] = xMAP
                elif ip == oo.ky_h_f_u:
                    _f_u[m] = oo.hypPstMd[row, col] = xMAP
                elif ip == oo.ky_h_f_q2:
                    _f_q2[m] = oo.hypPstMd[row, col] = xMAP
                elif ip == oo.ky_h_q2_a:
                    _q2_a[m] = oo.hypPstMd[row, col] = xMAP
                elif ip == oo.ky_h_q2_B:
                    _q2_B[m] = oo.hypPstMd[row, col] = xMAP

        ###  hack here.  If we don't reset the prior for
        ###  what happens when a cluster is unused?
        ###  l0 -> 0, and at the same time, the variance increases.
        ###  the prior then gets pushed to large values, but
        ###  then it becomes difficult to bring it back to small
        ###  values once that cluster becomes used again.  So
        ###  we would like unused clusters to have l0->0, but keep the
        ###  variance small.  That's why we will reset a cluster
        occ = _N.mean(gz[ITERS-1], axis=0)
        print(occ)
        for m in xrange(M):
            if (occ[m] == 0) and (l0[m] / _N.sqrt(twpi*q2[m]) < 1):
                print("resetting")
                _q2_a[m] = 1e-4
                _q2_B[m] = 1e-3

    pcklme["gz"] = gz
    pcklme["smp_hyps"] = smp_hyps
    pcklme["smp_prms"] = smp_prms
    pcklme["prmPstMd"] = oo.prmPstMd
    pcklme["intvs"] = oo.intvs
    with open(resFN("posteriors.dmp", dir=oo.outdir), "wb") as dmp:
        pickle.dump(pcklme, dmp, -1)
def finish_epoch(oo, nSpks, epc, ITERS, gz, l0, f, q2, u, Sg, _f_u, _f_q2, _q2_a, _q2_B, _l0_a, _l0_B, _u_u, _u_Sg, _Sg_nu, _Sg_PSI, smp_sp_hyps, smp_sp_prms, smp_mk_hyps, smp_mk_prms, freeClstr, M, K):
    """Turn one epoch's Gibbs samples into MAP estimates and next-epoch priors.

    Only samples from iteration int(0.7*ITERS) on are used.  When the epoch
    has spikes:
      * spatial params (l0, f, q2) and their 6 hyper-params are set from
        ``MAPvalues2``, which writes row ``epc`` of oo.sp_prmPstMd /
        oo.sp_hypPstMd;
      * mark params (u = median, Sg = mean) and hyper-params (means) are set
        from the mark samples and stored in oo.mk_prmPstMd / oo.mk_hypPstMd.
    All array arguments are updated in place.  With ``oo.resetClus`` and
    M > 1, a cluster gets broad priors and is flagged in ``freeClstr`` when
    either (a) it received < 1% of the average assignment count of the top
    two clusters and has l0/sqrt(2*pi*q2) < 1, or (b) its center f lies more
    than 5*sqrt(q2) outside [oo.xLo, oo.xHi].  The spatial samples are
    written to "posParams_<epc>.dat" in oo.outdir.  The noise cluster is not
    handled here.
    """
    #  was referenced below but never defined in this function
    twpi  = 2*_N.pi
    gkMAP = gauKer(2)
    frm   = int(0.7*ITERS)  #  have to test for stationarity

    oo.smp_sp_hyps = smp_sp_hyps
    oo.smp_sp_prms = smp_sp_prms
    oo.smp_mk_hyps = smp_mk_hyps
    oo.smp_mk_prms = smp_mk_prms

    if nSpks > 0:
        #  gz is ITERS x nSpks x M
        occ = _N.mean(_N.mean(gz[frm:ITERS-1], axis=0), axis=0)

        l_trlsNearMAP = []
        MAPvalues2(epc, smp_sp_prms, oo.sp_prmPstMd, frm, ITERS, M, 3, occ, gkMAP, l_trlsNearMAP)
        l0[0:M] = oo.sp_prmPstMd[epc, oo.ky_p_l0::3]
        f[0:M]  = oo.sp_prmPstMd[epc, oo.ky_p_f::3]
        q2[0:M] = oo.sp_prmPstMd[epc, oo.ky_p_q2::3]

        MAPvalues2(epc, smp_sp_hyps, oo.sp_hypPstMd, frm, ITERS, M, 6, occ, gkMAP, None)
        _f_u[:]  = oo.sp_hypPstMd[epc, oo.ky_h_f_u::6]
        _f_q2[:] = oo.sp_hypPstMd[epc, oo.ky_h_f_q2::6]
        _q2_a[:] = oo.sp_hypPstMd[epc, oo.ky_h_q2_a::6]
        _q2_B[:] = oo.sp_hypPstMd[epc, oo.ky_h_q2_B::6]
        _l0_a[:] = oo.sp_hypPstMd[epc, oo.ky_h_l0_a::6]
        _l0_B[:] = oo.sp_hypPstMd[epc, oo.ky_h_l0_B::6]

        ##  params and hyper parms for mark
        #  smp_mk_prms = [ K x ITERS x M,  K x K x ITERS x M ]
        #  smp_mk_hyps = [ K x ITERS x M,  K x K x ITERS x M,
        #                  1 x ITERS x M,  K x K x ITERS x M ]
        for m in xrange(M):
            u[m]  = _N.median(smp_mk_prms[0][:, frm:, m], axis=1)
            Sg[m] = _N.mean(smp_mk_prms[1][:, :, frm:, m], axis=2)
            oo.mk_prmPstMd[oo.ky_p_u][epc, m]  = u[m]
            oo.mk_prmPstMd[oo.ky_p_Sg][epc, m] = Sg[m]

            _u_u[m]    = _N.mean(smp_mk_hyps[oo.ky_h_u_u][:, frm:, m], axis=1)
            _u_Sg[m]   = _N.mean(smp_mk_hyps[oo.ky_h_u_Sg][:, :, frm:, m], axis=2)
            _Sg_nu[m]  = _N.mean(smp_mk_hyps[oo.ky_h_Sg_nu][0, frm:, m], axis=0)
            _Sg_PSI[m] = _N.mean(smp_mk_hyps[oo.ky_h_Sg_PSI][:, :, frm:, m], axis=2)
            oo.mk_hypPstMd[oo.ky_h_u_u][epc, m]    = _u_u[m]
            oo.mk_hypPstMd[oo.ky_h_u_Sg][epc, m]   = _u_Sg[m]
            oo.mk_hypPstMd[oo.ky_h_Sg_nu][epc, m]  = _Sg_nu[m]
            oo.mk_hypPstMd[oo.ky_h_Sg_PSI][epc, m] = _Sg_PSI[m]

        u[0:M]  = oo.mk_prmPstMd[oo.ky_p_u][epc]
        Sg[0:M] = oo.mk_prmPstMd[oo.ky_p_Sg][epc]

    ###  hack here.  If we don't reset the prior for
    ###  what happens when a cluster is unused?
    ###  l0 -> 0, and at the same time, the variance increases.
    ###  the prior then gets pushed to large values, but
    ###  then it becomes difficult to bring it back to small
    ###  values once that cluster becomes used again.  So
    ###  we would like unused clusters to have l0->0, but keep the
    ###  variance small.  That's why we will reset a cluster
    sq25 = 5*_N.sqrt(q2)

    if M > 1:
        occ = _N.mean(_N.sum(gz[frm:], axis=1), axis=0)  # avg. # of marks assigned to this cluster
        socc = _N.sort(occ)
        #  100 times smaller than the average of the top 2 -> consider it empty
        minAss = (0.5*(socc[-2]+socc[-1])*0.01)

    if oo.resetClus and (M > 1):
        for m in xrange(M):
            #  Sg and q2 are treated differently.  Even if no spikes are
            #  observed, q2 is updated, while Sg is not.
            #  This is because NO spikes in physical space AND trajectory
            #  information contains information about the place field.
            #  However, in mark space, not observing any marks tells you
            #  nothing about the mark distribution.  That is why f, q2
            #  are updated when there are no spikes, but u and Sg are not.
            if q2[m] < 0:
                print("????????????????")
                print(q2)
                print("q2[%(m)d] = %(q2).3f" % {"m" : m, "q2" : q2[m]})
                print(smp_sp_prms[0, :, m])
                print(smp_sp_prms[1, :, m])
                print(smp_sp_prms[2, :, m])
                print(smp_sp_hyps[4, :, m])
                print(smp_sp_hyps[5, :, m])
            if ((occ[m] < minAss) and (l0[m] / _N.sqrt(twpi*q2[m]) < 1)) or \
               (f[m] < oo.xLo-sq25[m]) or \
               (f[m] > oo.xHi+sq25[m]):
                print("resetting cluster %(m)d %(l0).3f %(f).3f" % {"m" : m, "l0" : (l0[m] / _N.sqrt(twpi*q2[m])), "f" : f[m]})
                _q2_a[m] = 1e-4
                _q2_B[m] = 1e-3
                _f_q2[m] = 4
                _u_Sg[m] = _N.identity(K)*9
                _l0_a[m] = 1e-4
                freeClstr[m] = True
            else:
                freeClstr[m] = False

    #  one row per iteration, columns grouped per cluster (3 params each)
    rsmp_sp_prms = smp_sp_prms.swapaxes(1, 0).reshape(ITERS, 3*M, order="F")
    _N.savetxt(resFN("posParams_%d.dat" % epc, dir=oo.outdir), rsmp_sp_prms, fmt=("%.4f %.4f %.4f " * M))   # the params for the non-noise
def stochasticAssignment(oo, epc, it, Msc, M, K, l0, f, q2, u, Sg, _f_u, _u_u, _f_q2, _u_Sg, Asts, t0, mASr, xASr, rat, econt, gz, qdrMKS, freeClstr, hashthresh, cmp2Existing, nthrds=1):
    """Gibbs step: sample a cluster assignment for every spike of the epoch.

    Msc  : number of signal clusters
    M    : all clusters, including the noise cluster (M == Msc when not using nzclstr)
    mASr, xASr : marks (1 x N x K) and positions (1 x N) of the spikes between t0 and t1
    rat, econt : work buffers with M+1 and M rows; only rat[1:] is written,
                 so rat[0] must hold zeros
    qdrMKS : M x N buffer that receives the mark-space quadratic forms
    gz   : assignment samples; gz[it] (N x M) is overwritten

    Non-positive q2 / l0 entries are replaced in place (0.3 / 0.001) before
    use.  When ``cmp2Existing`` is set, spikes with a mark above
    ``hashthresh`` on at least one channel are compared with the known
    non-hash clusters.  Spikes farther than 4 (scaled distance) from all of
    them, in mark or position space, seed the priors of free clusters.
    Every such above-threshold spike is then forced into the cluster chosen
    for it: the last cluster (M-1) unless it seeded a free cluster.
    """
    nSpks = len(Asts)
    twpi = 2*_N.pi

    #  repair degenerate params BEFORE deriving anything from them
    #  (iq2 used to be computed from the unrepaired q2)
    q2[q2 <= 0] = 0.3
    l0[l0 <= 0] = 0.001

    ur = u.reshape((M, 1, K))
    fr = f.reshape((M, 1))    # centers
    iq2r = (1./q2).reshape((M, 1))
    iSg = _N.linalg.inv(Sg)

    try:
        pkFR = _N.log(l0) - 0.5*_N.log(twpi*q2)   # M
    except RuntimeWarning:
        print("WARNING")
        print(l0)
        print(q2)
        raise   #  pkFR would otherwise be undefined below

    mkNrms = _N.log(1/_N.sqrt(twpi*_N.linalg.det(Sg)))
    mkNrms = mkNrms.reshape((M, 1))   # M x 1

    rnds = _N.random.rand(nSpks)
    pkFRr = pkFR.reshape((M, 1))
    dmu = (mASr - ur)     # mASr 1 x N x K, ur is M x 1 x K
    N = mASr.shape[1]

    _fm.multi_qdrtcs_par_func(dmu, iSg, qdrMKS, M, N, K, nthrds=nthrds)

    #  fr is M x 1, xASr is 1 x N, iq2r is M x 1
    qdrSPC = _N.empty((M, N))
    _hcb.hc_bcast1(fr, xASr, iq2r, qdrSPC, M, N)

    if cmp2Existing:
        #  compare only non-hash spikes and non-hash clusters
        abvthrEachCh = mASr[0] > hashthresh                   # N x K
        abvthrAtLeast1Ch = _N.sum(abvthrEachCh, axis=1) > 0   # N
        newNonHashSpks = _N.where(abvthrAtLeast1Ch)[0]
        #  initially, assign all of them to the last (noise) cluster
        newNonHashSpksMemClstr = _N.ones(len(newNonHashSpks), dtype=int) * (M-1)

        #  known clusters: non-hash, not free and spatially narrow
        #  NOTE(review): wdSpc is not defined in this function; it must come from module scope
        abvthrEachCh = u[0:Msc] > hashthresh
        abvthrAtLeast1Ch = _N.sum(abvthrEachCh, axis=1) > 0
        knownNonHclstrs = _N.where(abvthrAtLeast1Ch & (freeClstr == False) & (q2[0:Msc] < wdSpc))[0]

        #  for each spike, distance to nearest known non-hash cluster
        nNrstMKS_d = _N.sqrt(_N.min(qdrMKS[knownNonHclstrs], axis=0)/K)
        nNrstSPC_d = _N.sqrt(_N.min(qdrSPC[knownNonHclstrs], axis=0))

        s = _N.empty((len(newNonHashSpks), 3))
        s[:, 0] = newNonHashSpks
        s[:, 1] = nNrstMKS_d[newNonHashSpks]
        s[:, 2] = nNrstSPC_d[newNonHashSpks]
        _N.savetxt(resFN("qdrMKSSPC%d" % epc, dir=oo.outdir), s, fmt="%d %.3e %.3e")

        dMK = nNrstMKS_d[newNonHashSpks]
        dSP = nNrstSPC_d[newNonHashSpks]

        #  far in mark space OR in position space (4 std. deviations away)
        farMKinds = _N.where(dMK > 4)[0]
        farSPinds = _N.where(dSP > 4)[0]
        farMKSPinds = _N.union1d(farMKinds, farSPinds)
        print(farMKinds)
        print(newNonHashSpks)

        ## points in newNonHashSpks but not far
        notFarMKSPinds = _N.setdiff1d(_N.arange(newNonHashSpks.shape[0]), farMKSPinds)

        #  rows: (position, mark_1..mark_K)
        farMKSP = _N.empty((len(farMKSPinds), K+1))
        farMKSP[:, 0] = xASr[0, newNonHashSpks[farMKSPinds]]
        farMKSP[:, 1:] = mASr[0, newNonHashSpks[farMKSPinds]]
        notFarMKSP = _N.empty((len(notFarMKSPinds), K+1))
        notFarMKSP[:, 0] = xASr[0, newNonHashSpks[notFarMKSPinds]]
        notFarMKSP[:, 1:] = mASr[0, newNonHashSpks[notFarMKSPinds]]

        maxK = farMKSPinds.shape[0] // K
        maxK = maxK if (maxK < 6) else 6
        freeClstrs = _N.where(freeClstr == True)[0]
        rowFmt = " ".join(["%.3e"]*(K+1))

        if maxK >= 2:
            print("coming in here")
            labs, labsH, clstrs = emMKPOS_sep1B(farMKSP, None, TR=1, wfNClstrs=[[1, 4], [1, 4]], spNClstrs=[[1, 4], [1, 3]])
            nClstrs = clstrs[0]
            bestLab = labs
            cls = clrs.get_colors(nClstrs)

            _U.savetxtWCom(resFN("newSpksMKSP%d" % epc, dir=oo.outdir), farMKSP, fmt=rowFmt, com=("# number of clusters %d" % nClstrs))
            _U.savetxtWCom(resFN("newSpksMKSP_nf%d" % epc, dir=oo.outdir), notFarMKSP, fmt=rowFmt, com=("# number of clusters %d" % nClstrs))

            L = len(freeClstrs)
            unqLabs = _N.unique(bestLab)
            upto = nClstrs if nClstrs < L else L   # this should just count large clusters
            ii = -1     #  index into freeClstrs of the last free cluster used
            fig = _plt.figure()
            for ic, fid in enumerate(unqLabs[0:upto]):
                iths = farMKSPinds[_N.where(bestLab == fid)[0]]
                ths = newNonHashSpks[iths]

                for w in xrange(K):
                    fig.add_subplot(2, 2, w+1)
                    _plt.scatter(xASr[0, ths], mASr[0, ths, w], color=cls[ic])

                if len(ths) > K:
                    ii += 1
                    im = freeClstrs[ii]   # Asts + t0 gives absolute time
                    newNonHashSpksMemClstr[iths] = im
                    _u_u[im] = _N.mean(mASr[0, ths], axis=0)
                    u[im] = _u_u[im]
                    _f_u[im] = _N.mean(xASr[0, ths], axis=0)
                    f[im] = _f_u[im]
                    q2[im] = _N.std(xASr[0, ths], axis=0)**2 * 9
                    # l0 = Hz * sqrt(2*_N.pi*q2)
                    l0[im] = 10*_N.sqrt(q2[im])
                    _f_q2[im] = 1
                    _u_Sg[im] = _N.cov(mASr[0, ths], rowvar=0)*25
                    print("ep %(ep)d new cluster # %(m)d" % {"ep" : epc, "m" : im})
                    print(_u_u[im])
                    print(_f_u[im])
                    print(_f_q2[im])
                else:
                    print("too small this prob. doesn't represent a cluster")
            _plt.savefig("newspks%d" % epc)
            _plt.close()    #  release the figure (it was leaked on every call)
        elif (len(freeClstrs) > 0) and (len(farMKSPinds) > 1):
            #  just one cluster.  Needs a free cluster (freeClstrs[0] used to
            #  raise IndexError) and >= 2 far spikes (with 0 or 1 the mean /
            #  cov / std below are nan or 0 and would poison the prior).
            im = freeClstrs[0]   # Asts + t0 gives absolute time
            farSpks = newNonHashSpks[farMKSPinds]
            _u_u[im]  = _N.mean(mASr[0, farSpks], axis=0)
            _f_u[im]  = _N.mean(xASr[0, farSpks], axis=0)
            _u_Sg[im] = _N.cov(mASr[0, farSpks], rowvar=0)*16
            _f_q2[im] = _N.std(xASr[0, farSpks], axis=0)**2 * 16

    #  log weight of each (cluster, spike): (Mx1) + (Mx1) - (MxN + MxN)
    cont = _N.empty((M, N))
    _hcb.hc_qdr_sum(pkFRr, mkNrms, qdrSPC, qdrMKS, cont, M, N)
    mcontr = _N.max(cont, axis=0).reshape((1, nSpks))
    cont -= mcontr          #  avoid overflow in exp
    _N.exp(cont, out=econt)

    #  cumulative weights over clusters, normalized to [0, 1]
    for m in xrange(M):
        rat[m+1] = rat[m] + econt[m]
    rat /= rat[M]

    M1 = rat[1:] >= rnds
    M2 = rat[0:-1] <= rnds
    gz[it] = (M1&M2).T

    if cmp2Existing:
        #  gz is ITERS x N x M (N # of spikes in epoch)
        gz[it, newNonHashSpks] = False   # not a member of any of them
        gz[it, newNonHashSpks, newNonHashSpksMemClstr] = True
def gibbs(self, ITERS, K, ep1=0, ep2=None, savePosterior=True, gtdiffusion=False, doSepHash=True, use_spc=True, nz_pth=0., smth_pth_ker=100, ignoresilence=False, use_omp=False, nThrds=2): """ gtdiffusion: use ground truth center of place field in calculating variance of center. Meaning of diffPerMin different """ print "gibbs %.5f" % _N.random.rand() oo = self oo.nThrds = nThrds twpi = 2*_N.pi pcklme = {} ep2 = oo.epochs if (ep2 == None) else ep2 oo.epochs = ep2-ep1 ###################################### GRID for calculating #### # points in sum. #### # points in uniform sampling of exp(x)p(x) (non-spike interals) #### # points in sampling of f for conditional posterior distribution #### # points in sampling of q2 for conditional posterior distribution #### NSexp, Nupx, fss, q2ss # numerical grid ux = _N.linspace(oo.xLo, oo.xHi, oo.Nupx, endpoint=False) # uniform x position # grid over uxr = ux.reshape((1, oo.Nupx)) uxrr= ux.reshape((1, 1, oo.Nupx)) #q2x = _N.exp(_N.linspace(_N.log(1e-7), _N.log(100), oo.q2ss)) # 5 orders of q2x = _N.exp(_N.linspace(_N.log(oo.q2x_L), _N.log(oo.q2x_H), oo.q2ss)) # 5 orders of d_q2x = _N.diff(q2x) q2x_m1 = _N.array(q2x[0:-1]) lq2x = _N.log(q2x) iq2x = 1./q2x q2xr = q2x.reshape((oo.q2ss, 1)) iq2xr = 1./q2xr q2xrr = q2x.reshape((1, oo.q2ss, 1)) iq2xrr = 1./q2xrr d_q2xr = d_q2x.reshape((oo.q2ss - 1, 1)) q2x_m1 = _N.array(q2x[0:-1]) q2x_m1r = q2x_m1.reshape((oo.q2ss-1, 1)) sqrt_2pi_q2x = _N.sqrt(twpi*q2x) l_sqrt_2pi_q2x = _N.log(sqrt_2pi_q2x) freeClstr = None if smth_pth_ker > 0: gk = gauKer(smth_pth_ker) # 0.1s smoothing of motion gk /= _N.sum(gk) xf = _N.convolve(oo.dat[:, 0], gk, mode="same") oo.dat[:, 0] = xf + nz_pth*_N.random.randn(len(oo.dat[:, 0])) else: oo.dat[:, 0] += nz_pth*_N.random.randn(len(oo.dat[:, 0])) x = oo.dat[:, 0] mks = oo.dat[:, 2:] if nz_pth > 0: _N.savetxt(resFN("nzyx.txt", dir=oo.outdir), x, fmt="%.4f") f_q2_rate = (oo.diffusePerMin**2)/60000. 
# unit of minutes ###################################### PRECOMPUTED tau_l0 = oo.t_hlf_l0/_N.log(2) tau_q2 = oo.t_hlf_q2/_N.log(2) for epc in xrange(ep1, ep2): print "^^^^^^^^^^^^^^^^^^^^^^^^ epoch %d" % epc t0 = oo.intvs[epc] t1 = oo.intvs[epc+1] if epc > 0: tm1= oo.intvs[epc-1] # 0 10 30 20 - 5 = 15 0.5*((10+30) - (10+0)) = 15 dt = 0.5*((t1+t0) - (t0+tm1)) dt = (t1-t0)*0.5 xt0t1 = _N.array(x[t0:t1]) posbins = _N.linspace(oo.xLo, oo.xHi, oo.Nupx+1) # _N.sum(px)*(xbns[1]-xbns[0]) = 1 px, xbns = _N.histogram(xt0t1, bins=posbins, normed=True) pxr = px.reshape((1, oo.Nupx)) pxrr = px.reshape((1, 1, oo.Nupx)) Asts = _N.where(oo.dat[t0:t1, 1] == 1)[0] # based at 0 if epc == ep1: ### initialize labS, labH, flatlabels, M, MF, hashthresh, nHSclusters = gAMxMu.initClusters(oo, K, x, mks, t0, t1, Asts, doSepHash=doSepHash, xLo=oo.xLo, xHi=oo.xHi, oneCluster=oo.oneCluster, nzclstr=oo.nzclstr) Mwowonz = M+1 if oo.nzclstr else M #nHSclusters.append(M - nHSclusters[0]-nHSclusters[1]) # last are free clusters that are not the noise cluster u_u_ = _N.empty((M, K)) u_Sg_ = _N.empty((M, K, K)) ####### containers for GIBBS samples iterations smp_sp_prms = _N.zeros((3, ITERS, M)) smp_mk_prms = [_N.zeros((K, ITERS, M)), _N.zeros((K, K, ITERS, M))] smp_sp_hyps = _N.zeros((6, ITERS, M)) smp_mk_hyps = [_N.zeros((K, ITERS, M)), _N.zeros((K, K, ITERS, M)), _N.zeros((1, ITERS, M)), _N.zeros((K, K, ITERS, M))] oo.smp_sp_prms = smp_sp_prms oo.smp_mk_prms = smp_mk_prms oo.smp_sp_hyps = smp_sp_hyps oo.smp_mk_hyps = smp_mk_hyps if oo.nzclstr: smp_nz_l0 = _N.zeros(ITERS) smp_nz_hyps = _N.zeros((2, ITERS)) # list of freeClstrs freeClstr = _N.empty(M, dtype=_N.bool) # Actual cluster freeClstr[:] = False l0, f, q2, u, Sg = gAMxMu.declare_params(M, K, nzclstr=oo.nzclstr) # nzclstr not inited # sized to include noise cluster if needed _l0_a, _l0_B, _f_u, _f_q2, _q2_a, _q2_B, _u_u, _u_Sg, _Sg_nu, \ _Sg_PSI = gAMxMu.declare_prior_hyp_params(M, MF, K, x, mks, Asts, t0) fr = f[0:M].reshape((M, 1)) 
gAMxMu.init_params_hyps(oo, M, MF, K, l0, f, q2, u, Sg, _l0_a, _l0_B, _f_u, _f_q2, _q2_a, _q2_B, _u_u, _u_Sg, _Sg_nu, \ _Sg_PSI, Asts, t0, x, mks, flatlabels, nHSclusters, nzclstr=oo.nzclstr) U = _N.empty(M) FQ2 = _N.empty(M) _fxs0 = _N.tile(_N.linspace(0, 1, oo.fss), M).reshape(M, oo.fss) f_exp_px = _N.empty((M, oo.fss)) q2_exp_px= _N.empty((M, oo.q2ss)) if oo.nzclstr: nz_l0_intgrd = _N.exp(-0.5*ux*ux / q2[Mwowonz-1]) _nz_l0_a = 0.001 _nz_l0_B = 0.1 ###### the hyperparameters for f, q2, u, Sg, l0 during Gibbs # f_u_, f_q2_, q2_a_, q2_B_, u_u_, u_Sg_, Sg_nu, Sg_PSI_, l0_a_, l0_B_ NSexp = t1-t0 # length of position data # # of no spike positions to sum xt0t1 = _N.array(x[t0:t1]) nSpks = len(Asts) gz = _N.zeros((ITERS, nSpks, Mwowonz), dtype=_N.bool) oo.gz=gz print "spikes %d" % nSpks dSilenceX1 = (NSexp/float(oo.Nupx))*(oo.xHi-oo.xLo) dSilenceX2 = NSexp*(xbns[1]-xbns[0]) # dx of histogram print "-------------------------- %(1).4f %(2).4f" % {"1" : dSilenceX1, "2" : dSilenceX2} dSilenceX = dSilenceX1 xAS = x[Asts + t0] # position @ spikes mAS = mks[Asts + t0] # position @ spikes xASr = xAS.reshape((1, nSpks)) mASr = mAS.reshape((1, nSpks, K)) econt = _N.empty((Mwowonz, nSpks)) rat = _N.zeros((Mwowonz+1, nSpks)) qdrMKS = _N.empty((Mwowonz, nSpks)) ################################ GIBBS ITERS ITERS ITERS clstsz = _N.zeros(M, dtype=_N.int) _iu_Sg = _N.array(_u_Sg) for m in xrange(M): _iu_Sg[m] = _N.linalg.inv(_u_Sg[m]) ttA = _tm.time() for iter in xrange(ITERS): tt1 = _tm.time() iSg = _N.linalg.inv(Sg) if (iter % 100) == 0: #print "-------iter %(i)d %(r).5f" % {"i" : iter, "r" : _N.random.rand()} print "-------iter %(i)d" % {"i" : iter} gAMxMu.stochasticAssignment(oo, epc, iter, M, Mwowonz, K, l0, f, q2, u, Sg, _f_u, _u_u, _f_q2, _u_Sg, Asts, t0, mASr, xASr, rat, econt, gz, qdrMKS, freeClstr, hashthresh, ((epc > 0) and (iter == 0)), nthrds=oo.nThrds) #gAMxMu.stochasticAssignment(oo, iter, M, Mwowonz, K, l0, f, q2, u, Sg, _f_u, _u_u, Asts, t0, mASr, xASr, rat, econt, gz, 
qdrMKS, freeClstr, hashthresh, iter==0, nthrds=oo.nThrds) ############### FOR EACH CLUSTER l_sts = [] for m in xrange(M): # get the minds minds = _N.where(gz[iter, :, m] == 1)[0] sts = Asts[minds] + t0 # sts is in absolute time clstsz[m] = len(sts) l_sts.append(sts) # for m in xrange(Mwowonz): # get the minds # minds = _N.where(gz[iter, :, m] == 1)[0] # print "cluster %(m)d len %(l)d " % {"m" : m, "l" : len(minds)} # print u[m] # print f[m] #tt2 = _tm.time() ############### ############### CONDITIONAL l0 ############### # _ss.gamma.rvs. uses k, theta k is 1/B (B is our thing) iiq2 = 1./q2[0:M] iiq2r= iiq2.reshape((M, 1)) iiq2rr= iiq2.reshape((M, 1, 1)) fr = f[0:M].reshape((M, 1)) l0_intgrd = _N.exp(-0.5*(fr - ux)*(fr-ux) * iiq2r) sLLkPr = _N.empty((M, oo.q2ss)) l0_exp_px = _N.sum(l0_intgrd*pxr, axis=1) * dSilenceX BL = (oo.dt/_N.sqrt(twpi*q2[0:M]))*l0_exp_px # dim M if (epc > 0) and oo.adapt: _md_nd= _l0_a / _l0_B _Dl0_a = _l0_a * _N.exp(-dt/tau_l0) _Dl0_B = _Dl0_a / _md_nd else: _Dl0_a = _l0_a _Dl0_B = _l0_B aL = clstsz l0_a_ = aL + _Dl0_a l0_B_ = BL + _Dl0_B try: # mean is (l0_a_ / l0_B_) l0[0:M] = _ss.gamma.rvs(l0_a_, scale=(1/l0_B_)) # check except ValueError: """ print l0_B_ print _Dl0_B print BL print l0_exp_px print 1/_N.sqrt(twpi*q2[0:M]) print pxr print l0_intgrd """ _N.savetxt("fxux", (fr - ux)*(fr-ux)) _N.savetxt("fr", fr) _N.savetxt("iiq2", iiq2) _N.savetxt("l0_intgrd", l0_intgrd) raise smp_sp_prms[oo.ky_p_l0, iter] = l0[0:M] smp_sp_hyps[oo.ky_h_l0_a, iter] = l0_a_ smp_sp_hyps[oo.ky_h_l0_B, iter] = l0_B_ mcs = _N.empty((M, K)) # cluster sample means #tt3 = _tm.time() ############### ############### u ############### for m in xrange(M): if clstsz[m] > 0: u_Sg_[m] = _N.linalg.inv(_iu_Sg[m] + clstsz[m]*iSg[m]) clstx = mks[l_sts[m]] mcs[m] = _N.mean(clstx, axis=0) #u_u_[m] = _N.dot(u_Sg_[m], _N.dot(_iu_Sg[m], _u_u[m]) + clstsz[m]*_N.dot(iSg[m], mcs[m])) u_u_[m] = _N.einsum("jk,k->j", u_Sg_[m], _N.dot(_iu_Sg[m], _u_u[m]) + clstsz[m]*_N.dot(iSg[m], mcs[m])) # 
print "mean of cluster %d" % m # print mcs[m] # print u_u_[m] # hyp ######## POSITION ## mean of posterior distribution of cluster means # sigma^2 and mu are the current Gibbs-sampled values ## mean of posterior distribution of cluster means # print "for cluster %(m)d with size %(sz)d" % {"m" : m, "sz" : clstsz[m]} # print mcs[m] # print u_u_[m] # print _u_u[m] else: u_Sg_[m] = _N.array(_u_Sg[m]) u_u_[m] = _N.array(_u_u[m]) ucmvnrms= _N.random.randn(M, K) C = _N.linalg.cholesky(u_Sg_) u[0:M] = _N.einsum("njk,nk->nj", C, ucmvnrms) + u_u_ smp_mk_prms[oo.ky_p_u][:, iter] = u[0:M].T # dim of u wrong smp_mk_hyps[oo.ky_h_u_u][:, iter] = u_u_.T smp_mk_hyps[oo.ky_h_u_Sg][:, :, iter] = u_Sg_.T #tt4 = _tm.time() ############### ############### Conditional f ############### if (epc > 0) and oo.adapt: q2pr = _f_q2 + f_q2_rate * dt else: q2pr = _f_q2 for m in xrange(M): sts = l_sts[m] if clstsz[m] > 0: fs = (1./clstsz[m])*_N.sum(xt0t1[sts-t0]) fq2 = q2[m]/clstsz[m] U[m] = (fs*q2pr[m] + _f_u[m]*fq2) / (q2pr[m] + fq2) FQ2[m] = (q2pr[m]*fq2) / (q2pr[m] + fq2) else: U[m] = _f_u[m] FQ2[m] = q2pr[m] FQ = _N.sqrt(FQ2) Ur = U.reshape((M, 1)) FQr = FQ.reshape((M, 1)) FQ2r = FQ2.reshape((M, 1)) if use_spc: fxs = _N.copy(_fxs0) fxs *= (FQr*120) fxs -= (FQr*60) fxs += Ur if use_omp: M_times_N_f_intgrls_raw(fxs, ux, iiq2, dSilenceX, px, f_exp_px, M, oo.fss, oo.Nupx, oo.nThrds) else: fxsr = fxs.reshape((M, oo.fss, 1)) fxrux = -0.5*(fxsr-uxrr)*(fxsr-uxrr) # f_intgrd is M x fss x Nupx f_intgrd = _N.exp(fxrux*iiq2rr) # integrand f_exp_px = _N.sum(f_intgrd*pxrr, axis=2) * dSilenceX # f_exp_px is M x fss l0r = l0[0:M].reshape((M, 1)) q2r = q2[0:M].reshape((M, 1)) # s is (M x fss) s = -(l0r*oo.dt/_N.sqrt(twpi*q2r)) * f_exp_px # a function of x #if (iter > ITERS - 40) and (iter % 5 == 0): # print f_exp_px # _plt.plot(fxs[0], _N.s) else: s = _N.zeros(M) # U, FQ2 is dim(M) # fxs is M x fss funcf = -0.5*((fxs-Ur)*(fxs-Ur))/FQ2r + s maxes = _N.max(funcf, axis=1) maxesr = maxes.reshape((M, 1)) funcf -= 
maxesr condPosF= _N.exp(funcf) # condPosF is M x fss ttB = _tm.time() # fxs M x fss # fxs M x fss # condPosF M x fss norm = 1./_N.sum(condPosF, axis=1) # sz M f_u_ = norm*_N.sum(fxs*condPosF, axis=1) # sz M f_u_r = f_u_.reshape((M, 1)) f_q2_ = norm*_N.sum(condPosF*(fxs-f_u_r)*(fxs-f_u_r), axis=1) f[0:M] = _N.sqrt(f_q2_)*_N.random.randn() + f_u_ smp_sp_prms[oo.ky_p_f, iter] = f[0:M] smp_sp_hyps[oo.ky_h_f_u, iter] = f_u_ smp_sp_hyps[oo.ky_h_f_q2, iter] = f_q2_ #tt5 = _tm.time() ############## ############## VARIANCE, COVARIANCE ############## for m in xrange(M): if clstsz[m] >= K: ## dof of posterior distribution of cluster covariance Sg_nu_ = _Sg_nu[m, 0] + clstsz[m] ## dof of posterior distribution of cluster covariance ur = u[m].reshape((1, K)) clstx = mks[l_sts[m]] Sg_PSI_ = _Sg_PSI[m] + _N.dot((clstx - ur).T, (clstx-ur)) else: Sg_nu_ = _Sg_nu[m, 0] ## dof of posterior distribution of cluster covariance ur = u[m].reshape((1, K)) Sg_PSI_ = _Sg_PSI[m] Sg[m] = s_u.sample_invwishart(Sg_PSI_, Sg_nu_) smp_mk_hyps[oo.ky_h_Sg_nu][0, iter, m] = Sg_nu_ smp_mk_hyps[oo.ky_h_Sg_PSI][:, :, iter, m] = Sg_PSI_ ## dof of posterior distribution of cluster covariance smp_mk_prms[oo.ky_p_Sg][:, :, iter] = Sg[0:M].T #tt6 = _tm.time() ############## ############## SAMPLE SPATIAL VARIANCE ############## if use_spc: # M x q2ss x Nupx # f M x 1 x 1 # iq2xrr 1 x q2ss x 1 # uxrr 1 x 1 x Nupx if use_omp: #ux variable held fixed M_times_N_q2_intgrls_raw(f, ux, iq2x, dSilenceX, px, q2_exp_px, M, oo.q2ss, oo.Nupx, oo.nThrds) else: frr = f.reshape((M, 1, 1)) q2_intgrd = _N.exp(-0.5*(frr - uxrr)*(frr-uxrr) * iq2xrr) q2_exp_px = _N.sum(q2_intgrd*pxrr, axis=2) * dSilenceX # function of q2 s = -((l0r*oo.dt)/sqrt_2pi_q2x)*q2_exp_px else: s = _N.zeros((oo.q2ss, M)) # B' / (a' - 1) = MODE #keep mode the same after discount # B' = MODE * (a' - 1) if (epc > 0) and oo.adapt: _md_nd= _q2_B / (_q2_a + 1) _Dq2_a = _q2_a * _N.exp(-dt/tau_q2) _Dq2_B = _Dq2_a / _md_nd else: _Dq2_a = _q2_a _Dq2_B = _q2_B SL_Bs 
= _N.empty(M) SL_as = _N.empty(M) for m in xrange(M): if clstsz[m] > 0: sts = l_sts[m] xI = (xt0t1[sts-t0]-f[m])*(xt0t1[sts-t0]-f[m])*0.5 SL_a = 0.5*clstsz[m] - 1 # spiking part of likelihood SL_B = _N.sum(xI) # spiking part of likelihood SL_Bs[m] = SL_B SL_as[m] = SL_a # spiking prior x prior #sLLkPr[m] = -(SL_a + 1)*lq2x - iq2x*SL_B sLLkPr[m] = -(_q2_a[m] + SL_a + 2)*lq2x - iq2x*(_q2_B[m] + SL_B) else: sLLkPr[m] = -(_q2_a[m] + 1)*lq2x - iq2x*_q2_B[m] q2_a_, q2_B_ = mltpl_ig_prmsUV(q2xr, sLLkPr.T, s.T, d_q2xr, q2x_m1r, clstsz, iter, mks, t0, xt0t1, gz, l_sts, SL_as, SL_Bs, _q2_a, _q2_B, oo.q2_min, oo.q2_max) q2[0:M] = _ss.invgamma.rvs(q2_a_ + 1, scale=q2_B_) # check tt7 = _tm.time() smp_sp_prms[oo.ky_p_q2, iter] = q2[0:M] smp_sp_hyps[oo.ky_h_q2_a, iter] = q2_a_ smp_sp_hyps[oo.ky_h_q2_B, iter] = q2_B_ # print "timing start" # print (tt2-tt1) # print (tt3-tt2) # print (tt4-tt3) # print (tt5-tt4) # print (tt6-tt5) #print (tt7-tt1) # print "timing end" # nz clstr. fixed width if oo.nzclstr: nz_l0_exp_px = _N.sum(nz_l0_intgrd*px) * dSilenceX BL = (oo.dt/_N.sqrt(twpi*q2[Mwowonz-1]))*nz_l0_exp_px minds = len(_N.where(gz[iter, :, Mwowonz-1] == 1)[0]) l0_a_ = minds + _nz_l0_a l0_B_ = BL + _nz_l0_B l0[Mwowonz-1] = _ss.gamma.rvs(l0_a_, scale=(1/l0_B_)) smp_nz_l0[iter] = l0[Mwowonz-1] smp_nz_hyps[0, iter] = l0_a_ smp_nz_hyps[1, iter] = l0_B_ ttB = _tm.time() print (ttB-ttA) gAMxMu.finish_epoch(oo, nSpks, epc, ITERS, gz, l0, f, q2, u, Sg, _f_u, _f_q2, _q2_a, _q2_B, _l0_a, _l0_B, _u_u, _u_Sg, _Sg_nu, _Sg_PSI, smp_sp_hyps, smp_sp_prms, smp_mk_hyps, smp_mk_prms, freeClstr, M, K) # MAP of nzclstr if oo.nzclstr: frm = int(0.7*ITERS) _nz_l0_a = _N.median(smp_nz_hyps[0, frm:]) _nz_l0_B = _N.median(smp_nz_hyps[1, frm:]) pcklme["smp_sp_hyps"] = smp_sp_hyps pcklme["smp_mk_hyps"] = smp_mk_hyps pcklme["smp_sp_prms"] = smp_sp_prms pcklme["smp_mk_prms"] = smp_mk_prms pcklme["sp_prmPstMd"] = oo.sp_prmPstMd pcklme["mk_prmPstMd"] = oo.mk_prmPstMd pcklme["intvs"] = oo.intvs pcklme["occ"] = gz 
pcklme["nz_pth"] = nz_pth pcklme["M"] = M pcklme["Mwowonz"] = Mwowonz if Mwowonz > M: # or oo.nzclstr == True pcklme["nz_fs"] = f[M] pcklme["nz_q2"] = q2[M] pcklme["nz_Sg"] = Sg[M] pcklme["nz_u"] = u[M] pcklme["smp_nz_l0"] = smp_nz_l0 pcklme["smp_nz_hyps"]= smp_nz_hyps dmp = open(resFN("posteriors_%d.dmp" % epc, dir=oo.outdir), "wb") pickle.dump(pcklme, dmp, -1) dmp.close()
def timeline(bfn, datfn, itvfn, outfn="timeline", ch1=0, ch2=1, xL=0, xH=3, yticks=None, thin=1):
    """
    Save a 4x3 summary figure of spike positions and two mark channels.

    Rows 0-2 (full width) show position, mark ch1 and mark ch2 at each
    plotted spike against sample index / 1000 (labelled "time (s)").  Each
    has dashed red lines at the epoch boundaries from the interval file.
    Row 3 shows ch1 vs ch2, position vs ch1 and position vs ch2.

    bfn      : output directory passed to resFN.
    datfn    : base name of the data file, loaded via datFN("<datfn>.dat").
               Column 0 is plotted as position, rows with column 1 == 1 are
               the plotted spikes, and columns 2+ are marks.
    itvfn    : base name of the interval file; entries 1..end are drawn at
               itv[i]*N/1000 (presumably fractions of the recording; confirm
               against the writer of that file).
    outfn    : figure name prefix; saved as "<outfn>_<ch1+1>,<ch2+1>".
    ch1, ch2 : 0-based mark channels to show.
    xL, xH   : position axis range.
    yticks   : y ticks of the position-vs-time panel (default [0, 1, 2, 3]).
    thin     : plot only every thin-th spike.
    """
    if yticks is None:
        yticks = [0, 1, 2, 3]

    d   = _N.loadtxt(datFN("%s.dat" % datfn))   # marks
    itv = _N.loadtxt(datFN("%s.dat" % itvfn))
    N   = d.shape[0]
    epochs = itv.shape[0]-1

    mk1 = ch1      #  index into the marks-only arrays wvfmMin / wvfmMax
    mk2 = ch2
    ch1 += 2       #  because this is data col
    ch2 += 2

    _sts = _N.where(d[:, 1] == 1)[0]
    sts  = _sts if thin == 1 else _sts[::thin]

    wvfmMin = _N.min(d[:, 2:], axis=0)
    wvfmMax = _N.max(d[:, 2:], axis=0)

    lab1 = "mk tet%d" % (ch1-1)
    lab2 = "mk tet%d" % (ch2-1)

    def _vs_time(row, col, ylabel, ticks, ylim):
        """Full-width panel of data column `col` at spike times, with epoch lines."""
        ax = _plt.subplot2grid((4, 3), (row, 0), colspan=3)
        _plt.scatter(sts/1000., d[sts, col], s=2, color="black")
        mF.arbitraryAxes(ax, axesVis=[True, True, False, False], xtpos="bottom", ytpos="left")
        mF.setTicksAndLims(xlabel="time (s)", ylabel=ylabel, xticks=None, yticks=ticks, xticksD=None, yticksD=None, xlim=[0, N/1000.], ylim=ylim, tickFS=15, labelFS=18)
        for ep in xrange(epochs):
            _plt.axvline(x=(itv[ep+1]*N/1000.), color="red", ls="--")

    def _bottom(pcol, xcol, ycol, xlabel, ylabel, xticks, xlim, ylim):
        """Bottom-row scatter of data column `ycol` against `xcol`."""
        ax = _plt.subplot2grid((4, 3), (3, pcol), colspan=1)
        _plt.scatter(d[sts, xcol], d[sts, ycol], s=2, color="black")
        mF.arbitraryAxes(ax, axesVis=[True, True, False, False], xtpos="bottom", ytpos="left")
        mF.setTicksAndLims(xlabel=xlabel, ylabel=ylabel, xticks=xticks, yticks=[0, 3, 6], xticksD=None, yticksD=None, xlim=xlim, ylim=ylim, tickFS=15, labelFS=18)

    fig = _plt.figure(figsize=(10, 12))
    _vs_time(0, 0, "position", yticks, [xL-0.3, xH+0.3])
    #  mark limits come from the selected channels, not always channels 0 and 1
    _vs_time(1, ch1, lab1, [0, 3, 6], [wvfmMin[mk1], wvfmMax[mk1]])
    _vs_time(2, ch2, lab2, [0, 3, 6], [wvfmMin[mk2], wvfmMax[mk2]])

    _bottom(0, ch1, ch2, lab1, lab2, [0, 3, 6], [wvfmMin[mk1], wvfmMax[mk1]], [wvfmMin[mk2], wvfmMax[mk2]])
    _bottom(1, 0, ch1, "pos", lab1, _N.linspace(xL, xH, 3), [xL, xH], None)
    _bottom(2, 0, ch2, "pos", lab2, _N.linspace(xL, xH, 3), [xL, xH], None)

    fig.subplots_adjust(left=0.15, bottom=0.15, wspace=0.38, hspace=0.38)
    choutfn = "%(of)s_%(1)d,%(2)d" % {"of" : outfn, "1" : (ch1-1), "2" : (ch2-1)}
    _plt.savefig(resFN(choutfn, dir=bfn), transparent=True)
    _plt.close()
def gibbs(self, ITERS, ep1=0, ep2=None, savePosterior=True, gtdiffusion=False): """ gtdiffusion: use ground truth center of place field in calculating variance of center. Meaning of diffPerMin different """ oo = self # PRIORS # priors prefixed w/ _ _f_u = 0; _f_q2 = 1 # inverse gamma _q2_a = 1e-4; _q2_B = 1e-3 #_plt.plot(q2x, q2x**(-_q2_a-1)*_N.exp(-_q2_B / q2x)) _l0_a = 1.; _l0_B = 1/30. # mean 30Hz peak firing rate ep2 = oo.epochs if (ep2 == None) else ep2 oo.epochs = ep2-ep1 oo.prmPstMd = _N.zeros((oo.epochs, 3)) # mode of the params oo.hypPstMd = _N.zeros((oo.epochs, 2+2+2)) # the hyper params twpi = 2*_N.pi pcklme = {} # Gibbs sampling # parameters l0, f, q2 ###################################### GIBBS samples, need for MAP estimate smp_prms = _N.zeros((3, ITERS, 1)) # smp_hyps = _N.zeros((6, ITERS, 1)) ###################################### INITIAL VALUE OF PARAMS l0 = 50 q2 = 0.0144 f = 1.1 ###################################### GRID for calculating #### # points in sum. #### # points in uniform sampling of exp(x)p(x) (non-spike interals) #### # points in sampling of f for conditional posterior distribution #### # points in sampling of q2 for conditional posterior distribution #### NSexp, Nupx, fss, q2ss # numerical grid ux = _N.linspace(0, 3, oo.Nupx, endpoint=False) # uniform x position q2x = _N.exp(_N.linspace(_N.log(0.00005), _N.log(10), oo.q2ss)) # 5 orders of d_q2x = _N.diff(q2x) q2x_m1 = _N.array(q2x[0:-1]) lq2x = _N.log(q2x) iq2x = 1./q2x q2xr = q2x.reshape((oo.q2ss, 1)) iq2xr = 1./q2xr sqrt_2pi_q2x = _N.sqrt(twpi*q2x) l_sqrt_2pi_q2x = _N.log(sqrt_2pi_q2x) x = oo.dat[:, 0] q2rate = oo.diffPerEpoch**2 # unit of minutes ###################################### PRECOMPUTED posbins = _N.linspace(0, 3, oo.Nupx+1) for epc in xrange(ep1, ep2): # if i > 0: # q2x = _N.linspace(0.001, 4, q2ss) # q2xr = q2x.reshape((q2ss, 1)) # iq2xr = 1./q2xr #print q2 print "epoch %d" % epc t0 = oo.intvs[epc] t1 = oo.intvs[epc+1] sts = _N.where(oo.dat[t0:t1, 1] == 1)[0] nts = 
_N.where(oo.dat[t0:t1, 1] == 0)[0] if gtdiffusion: q2rate = (oo.dat[t1-1,2]-oo.dat[t0,2])**2*oo.diffPerMin NSexp = t1-t0 # length of position data # # of no spike positions to sum xt0t1 = _N.array(x[t0:t1]) px, xbns = _N.histogram(xt0t1, bins=posbins, normed=True) nSpks = len(sts) print "spikes %d" % nSpks dSilenceX = (NSexp/float(oo.Nupx))*3 for iter in xrange(ITERS): #print "iter %d" % iter iiq2 = 1./q2 # prior described by hyper-parameters. # prior described by function # likelihood ############### CONDITIONAL f #q2pr = _f_q2 + q2rate q2pr = _f_q2 if (_f_q2 > q2rate) else q2rate if nSpks > 0: # spiking portion likelihood x prior fs = (1./nSpks)*_N.sum(xt0t1[sts]) fq2 = q2/nSpks M = (fs*q2pr + + _f_u*fq2) / (q2pr + fq2) Sg2 = (q2pr*fq2) / (q2pr + fq2) else: M = _f_u Sg2 = q2pr Sg = _N.sqrt(Sg2) fx = _N.linspace(M - Sg*50, M + Sg*50, oo.fss) fxr = fx.reshape((oo.fss, 1)) fxrux = -0.5*(fxr-ux)**2 xI_f = (xt0t1 - fxr)**2*0.5 f_intgrd = _N.exp((fxrux*iiq2)) # integrand f_exp_px = _N.sum(f_intgrd*px, axis=1) * dSilenceX # f_exp_px is a function of f s = -(l0*oo.dt/_N.sqrt(twpi*q2)) * f_exp_px # a function of x #print Sg2 #print M funcf = -0.5*((fx-M)*(fx-M))/Sg2 + s funcf -= _N.max(funcf) condPosF= _N.exp(funcf) #print _N.sum(condPosF) """ if iter == 0: fig = _plt.figure() _plt.plot(fx, condPosF) _plt.xlim(0.8, 1.3) _plt.savefig("%(dir)s/condposF%(i)d" % {"dir" : outdir, "i" : i}) _plt.close() """ norm = 1./_N.sum(condPosF) f_u_ = norm*_N.sum(fx*condPosF) f_q2_ = norm*_N.sum(condPosF*(fx-f_u_)*(fx-f_u_)) f = _N.sqrt(f_q2_)*_N.random.randn() + f_u_ smp_prms[oo.ky_p_f, iter, 0] = f smp_hyps[oo.ky_h_f_u, iter, 0] = f_u_ smp_hyps[oo.ky_h_f_q2, iter, 0] = f_q2_ #ax1.plot(fx, L_f, color="black") # ############### CONDITIONAL q2 #xI = (xt0t1-f)*(xt0t1-f)*0.5*iq2xr q2_intgrd = _N.exp(-0.5*(f - ux)*(f-ux) * iq2xr) q2_exp_px = _N.sum(q2_intgrd*px, axis=1) * dSilenceX s = -((l0*oo.dt)/sqrt_2pi_q2x)*q2_exp_px # function of q2 ## adjust the prior to reflect how much we think PF 
can change _Dq2_a = _q2_a if _q2_a < 200 else 200 _Dq2_B = (_q2_B/(_q2_a+1))*(_Dq2_a+1) if nSpks > 0: #print _N.sum((xt0t1[sts]-f)*(xt0t1[sts]-f))/(nSpks-1) ## (1/sqrt(sg2))^S ## (1/x)^(S/2) = (1/x)-(a+1) ## -S/2 = -a - 1 -a = -S/2 + 1 a = S/2-1 xI = (xt0t1[sts]-f)*(xt0t1[sts]-f)*0.5 SL_a = 0.5*nSpks - 1 # spiking part of likelihood SL_B = _N.sum(xI) # spiking part of likelihood # spiking prior x prior sLLkPr = -(_Dq2_a + SL_a + 2)*lq2x - iq2x*(_Dq2_B + SL_B) else: sLLkPr = -(_Dq2_a + 1)*lq2x - iq2x*(_Dq2_B) sat = sLLkPr + s sat -= _N.max(sat) condPos = _N.exp(sat) """ if iter == 10: fig = _plt.figure() _plt.plot(q2x, condPos) _plt.xlim(0, 0.5) _plt.savefig("%(dir)s/condpos%(i)d" % {"dir" : outdir, "i" : i}) _plt.close() """ q2_a_, q2_B_ = ig_prmsUV(q2x, condPos, d_q2x, q2x_m1, ITER=1) #print condPos _plt.plot(q2x, condPos) q2 = _ss.invgamma.rvs(q2_a_ + 1, scale=q2_B_) # check #print ((1./nSpks)*_N.sum((xt0t1[sts]-f)*(xt0t1[sts]-f))) smp_prms[oo.ky_p_q2, iter, 0] = q2 smp_hyps[oo.ky_h_q2_a, iter, 0] = q2_a_ smp_hyps[oo.ky_h_q2_B, iter, 0] = q2_B_ ############### CONDITIONAL l0 # _ss.gamma.rvs. 
uses k, theta k is 1/B (B is our thing) iiq2 = 1./q2 # xI = (xt0t1-f)*(xt0t1-f)*0.5*iiq2 # BL = (oo.dt/_N.sqrt(twpi*q2))*_N.sum(_N.exp(-xI)) l0_intgrd = _N.exp(-0.5*(f - ux)*(f-ux) * iiq2) l0_exp_px = _N.sum(l0_intgrd*px) * dSilenceX BL = (oo.dt/_N.sqrt(twpi*q2))*l0_exp_px # if iter == 50: # print "BL %(BL).2f BL2 %(BL2).2f" % {"BL" : BL, "BL2" : BL2} aL = nSpks l0_a_ = aL + _l0_a l0_B_ = BL + _l0_B l0 = _ss.gamma.rvs(l0_a_ - 1, scale=(1/l0_B_)) # check ### l0 / _N.sqrt(twpi*q2) is f*dt used in createData2 smp_prms[oo.ky_p_l0, iter, 0] = l0 smp_hyps[oo.ky_h_l0_a, iter, 0] = l0_a_ smp_hyps[oo.ky_h_l0_B, iter, 0] = l0_B_ frm = 30 for ip in xrange(3): # params L = _N.min(smp_prms[ip, frm:, 0]); H = _N.max(smp_prms[ip, frm:, 0]) cnts, bns = _N.histogram(smp_prms[ip, frm:, 0], bins=_N.linspace(L, H, 50)) ib = _N.where(cnts == _N.max(cnts))[0][0] if ip == oo.ky_p_l0: l0 = oo.prmPstMd[epc, ip] = bns[ib] elif ip == oo.ky_p_f: f = oo.prmPstMd[epc, ip] = bns[ib] elif ip == oo.ky_p_q2: q2 = oo.prmPstMd[epc, ip] = bns[ib] pcklme["cp%d" % epc] = _N.array(smp_prms) for ip in xrange(6): # hyper params L = _N.min(smp_hyps[ip, frm:, 0]); H = _N.max(smp_hyps[ip, frm:, 0]) cnts, bns = _N.histogram(smp_hyps[ip, frm:, 0], bins=_N.linspace(L, H, 50)) ib = _N.where(cnts == _N.max(cnts))[0][0] if ip == oo.ky_h_l0_a: _l0_a = oo.hypPstMd[epc, ip] = bns[ib] elif ip == oo.ky_h_l0_B: _l0_B = oo.hypPstMd[epc, ip] = bns[ib] elif ip == oo.ky_h_f_u: _f_u = oo.hypPstMd[epc, ip] = bns[ib] elif ip == oo.ky_h_f_q2: _f_q2 = oo.hypPstMd[epc, ip] = bns[ib] elif ip == oo.ky_h_q2_a: _q2_a = oo.hypPstMd[epc, ip] = bns[ib] elif ip == oo.ky_h_q2_B: _q2_B = oo.hypPstMd[epc, ip] = bns[ib] if savePosterior: _N.savetxt(resFN("posParams.dat", dir=oo.outdir), smp_prms[:, :, 0].T, fmt="%.4f %.4f %.4f") _N.savetxt(resFN("posHypParams.dat", dir=oo.outdir), smp_hyps[:, :, 0].T, fmt="%.4f %.4f %.4f %.4f %.4f %.4f") pcklme["md"] = _N.array(oo.prmPstMd) dmp = open(resFN("posteriors.dump", dir=oo.outdir), "wb") 
pickle.dump(pcklme, dmp, -1) dmp.close() _N.savetxt(resFN("posModes.dat", dir=oo.outdir), oo.prmPstMd, fmt="%.4f %.4f %.4f") _N.savetxt(resFN("hypModes.dat", dir=oo.outdir), oo.hypPstMd, fmt="%.4f %.4f %.4f %.4f %.4f %.4f")
def gibbs(self, ITERS, K, ep1=0, ep2=None, savePosterior=True, gtdiffusion=False, Mdbg=None, doSepHash=True, use_spc=True, nz_pth=0., ignoresilence=False, use_omp=False): """ gtdiffusion: use ground truth center of place field in calculating variance of center. Meaning of diffPerMin different """ print "gibbs" oo = self twpi = 2*_N.pi pcklme = {} ep2 = oo.epochs if (ep2 == None) else ep2 oo.epochs = ep2-ep1 ###################################### GRID for calculating #### # points in sum. #### # points in uniform sampling of exp(x)p(x) (non-spike interals) #### # points in sampling of f for conditional posterior distribution #### # points in sampling of q2 for conditional posterior distribution #### NSexp, Nupx, fss, q2ss # numerical grid ux = _N.linspace(oo.xLo, oo.xHi, oo.Nupx, endpoint=False) # uniform x position q2x = _N.exp(_N.linspace(_N.log(1e-7), _N.log(100), oo.q2ss)) # 5 orders of d_q2x = _N.diff(q2x) q2x_m1 = _N.array(q2x[0:-1]) lq2x = _N.log(q2x) iq2x = 1./q2x q2xr = q2x.reshape((oo.q2ss, 1)) iq2xr = 1./q2xr sqrt_2pi_q2x = _N.sqrt(twpi*q2x) l_sqrt_2pi_q2x = _N.log(sqrt_2pi_q2x) freeClstr = None gk = gauKer(100) # 0.1s smoothing of motion gk /= _N.sum(gk) xf = _N.convolve(oo.dat[:, 0], gk, mode="same") oo.dat[:, 0] = xf + nz_pth*_N.random.randn(len(oo.dat[:, 0])) x = oo.dat[:, 0] mks = oo.dat[:, 2:] f_q2_rate = (oo.diffusePerMin**2)/60000. 
# unit of minutes ###################################### PRECOMPUTED tau_l0 = oo.t_hlf_l0/_N.log(2) tau_q2 = oo.t_hlf_q2/_N.log(2) for epc in xrange(ep1, ep2): t0 = oo.intvs[epc] t1 = oo.intvs[epc+1] if epc > 0: tm1= oo.intvs[epc-1] # 0 10 30 20 - 5 = 15 0.5*((10+30) - (10+0)) = 15 dt = 0.5*((t1+t0) - (t0+tm1)) dt = (t1-t0)*0.5 xt0t1 = _N.array(x[t0:t1]) posbins = _N.linspace(oo.xLo, oo.xHi, oo.Nupx+1) # _N.sum(px)*(xbns[1]-xbns[0]) = 1 px, xbns = _N.histogram(xt0t1, bins=posbins, normed=True) Asts = _N.where(oo.dat[t0:t1, 1] == 1)[0] # based at 0 Ants = _N.where(oo.dat[t0:t1, 1] == 0)[0] if epc == ep1: ### initialize labS, labH, lab, flatlabels, M, MF, hashthresh, nHSclusters = gAMxMu.initClusters(oo, K, x, mks, t0, t1, Asts, doSepHash=doSepHash, xLo=oo.xLo, xHi=oo.xHi, oneCluster=oo.oneCluster) # nHSclusters is # of clusters in hash and signal signalClusters = _N.where(flatlabels < nHSclusters[0])[0] Mwowonz = M if not oo.nzclstr else M + 1 ####### containers for GIBBS samples iterations smp_sp_prms = _N.zeros((3, ITERS, M)) smp_mk_prms = [_N.zeros((K, ITERS, M)), _N.zeros((K, K, ITERS, M))] smp_sp_hyps = _N.zeros((6, ITERS, M)) smp_mk_hyps = [_N.zeros((K, ITERS, M)), _N.zeros((K, K, ITERS, M)), _N.zeros((1, ITERS, M)), _N.zeros((K, K, ITERS, M))] oo.smp_sp_prms = smp_sp_prms oo.smp_mk_prms = smp_mk_prms oo.smp_sp_hyps = smp_sp_hyps oo.smp_mk_hyps = smp_mk_hyps if oo.nzclstr: smp_nz_l0 = _N.zeros(ITERS) smp_nz_hyps = _N.zeros((2, ITERS)) # list of freeClstrs freeClstr = _N.empty(M, dtype=_N.bool) # Actual cluster freeClstr[:] = False l0, f, q2, u, Sg = gAMxMu.declare_params(M, K, nzclstr=oo.nzclstr) # nzclstr not INITED, sized to include noise cluster if needed _l0_a, _l0_B, _f_u, _f_q2, _q2_a, _q2_B, _u_u, _u_Sg, _Sg_nu, \ _Sg_PSI = gAMxMu.declare_prior_hyp_params(M, MF, K, x, mks, Asts, t0) # hyper params don't include noise cluster gAMxMu.init_params_hyps(oo, M, MF, K, l0, f, q2, u, Sg, Asts, t0, x, mks, flatlabels, nzclstr=oo.nzclstr, 
signalClusters=signalClusters) ###### the hyperparameters for f, q2, u, Sg, l0 during Gibbs # f_u_, f_q2_, q2_a_, q2_B_, u_u_, u_Sg_, Sg_nu, Sg_PSI_, l0_a_, l0_B_ if oo.nzclstr: nz_l0_intgrd = _N.exp(-0.5*ux*ux / q2[Mwowonz-1]) _nz_l0_a = 0.001 _nz_l0_B = 0.1 NSexp = t1-t0 # length of position data # # of no spike positions to sum xt0t1 = _N.array(x[t0:t1]) nSpks = len(Asts) gz = _N.zeros((ITERS, nSpks, Mwowonz), dtype=_N.bool) oo.gz=gz print "spikes %d" % nSpks #dSilenceX = (NSexp/float(oo.Nupx))*(oo.xHi-oo.xLo) dSilenceX = NSexp*(xbns[1]-xbns[0]) # dx of histogram xAS = x[Asts + t0] # position @ spikes mAS = mks[Asts + t0] # position @ spikes xASr = xAS.reshape((1, nSpks)) #mASr = mAS.reshape((nSpks, 1, K)) mASr = mAS.reshape((1, nSpks, K)) econt = _N.empty((Mwowonz, nSpks)) rat = _N.zeros((Mwowonz+1, nSpks)) qdrMKS = _N.empty((Mwowonz, nSpks)) ################################ GIBBS ITERS ITERS ITERS # linalgerror #_iSg_Mu = _N.einsum("mjk,mk->mj", _N.linalg.inv(_u_Sg), _u_u) clusSz = _N.zeros(M, dtype=_N.int) _iu_Sg = _N.array(_u_Sg) for m in xrange(M): _iu_Sg[m] = _N.linalg.inv(_u_Sg[m]) ttA = _tm.time() for iter in xrange(ITERS): iSg = _N.linalg.inv(Sg) if (iter % 5) == 0: print "iter %d" % iter gAMxMu.stochasticAssignment(oo, iter, M, Mwowonz, K, l0, f, q2, u, Sg, _f_u, _u_u, Asts, t0, mASr, xASr, rat, econt, gz, qdrMKS, freeClstr, hashthresh, ((epc > 0) and (iter == 0))) # ############### FOR EACH CLUSTER for m in xrange(M): minds = _N.where(gz[iter, :, m] == 1)[0] sts = Asts[minds] + t0 nSpksM = len(sts) clusSz[m] = nSpksM ############### CONDITIONAL l0 # _ss.gamma.rvs. 
uses k, theta k is 1/B (B is our thing) iiq2 = 1./q2[m] # xI = (xt0t1-f[m])*(xt0t1-f[m])*0.5*iiq2 # BL = (oo.dt/_N.sqrt(twpi*q2[m]))*_N.sum(_N.exp(-xI)) # l0_intgrd (M x Nupx) l0_intgrd = _N.exp(-0.5*(f[m] - ux)*(f[m]-ux) * iiq2) l0_exp_px = _N.sum(l0_intgrd*px) * dSilenceX BL = (oo.dt/_N.sqrt(twpi*q2[m]))*l0_exp_px # # keep mode same after discount # a' - 1 / B' = MODE # mode is a - 1 / B # B' = (a' - 1) / MODE # discount a #if (epc > 0) and oo.adapt and (_l0_a[m] > 1.1): if (epc > 0) and oo.adapt: _md_nd= _l0_a[m] / _l0_B[m] _Dl0_a = _l0_a[m] * _N.exp(-dt/tau_l0) _Dl0_B = _Dl0_a / _md_nd else: _Dl0_a = _l0_a[m] _Dl0_B = _l0_B[m] # a'/B' = a/B # B' = (B/a)a' aL = nSpksM l0_a_ = aL + _Dl0_a l0_B_ = BL + _Dl0_B # print "------------------" # print "liklhd BL %(B).3f f %(f).3f a %(a)d B/a %(ba).3f" % {"B" : BL, "f" : f[m], "ba" : (aL/ BL), "a" : aL} # print "prior BL %(B).3f f %(f).3f a %(a)d B/a %(ba).3f" % {"B" : l0_B_, "f" : f[m], "ba" : (l0_a_/ l0_B_), "a" : l0_a_} # print (len(xt0t1)*oo.dt) # print "******************" #print "%(1).5f %(2).5f" % {"1" : l0_a_, "2" : l0_B_} try: l0[m] = _ss.gamma.rvs(l0_a_, scale=(1/l0_B_)) # check except ValueError: print "fail" print "M: %d" % M print "_l0_a[m] %.3f" % _l0_a[m] print "_l0_B[m] %.3f" % _l0_B[m] print "l0_a_ %.3f" % l0_a_ print "l0_B_ %.3f" % l0_B_ print "aL %.3f" % aL print "BL %.3f" % BL print "_Dl0_a %.3f" % _Dl0_a print "_Dl0_B %.3f" % _Dl0_B raise ### l0 / _N.sqrt(twpi*q2) is f*dt used in createData2 smp_sp_prms[oo.ky_p_l0, iter, m] = l0[m] smp_sp_hyps[oo.ky_h_l0_a, iter, m] = l0_a_ smp_sp_hyps[oo.ky_h_l0_B, iter, m] = l0_B_ mcs = _N.empty((M, K)) # cluster sample means if nSpksM >= K: u_Sg_ = _N.linalg.inv(_iu_Sg[m] + nSpksM*iSg[m]) clstx = mks[sts] mcs[m] = _N.mean(clstx, axis=0) #u_u_ = _N.einsum("jk,k->j", u_Sg_, _N.dot(_N.linalg.inv(_u_Sg[m]), _u_u[m]) + nSpksM*_N.dot(iSg[m], mcs[m])) #u_u_ = _N.einsum("jk,k->j", u_Sg_, _N.dot(_iu_Sg[m], _u_u[m]) + nSpksM*_N.dot(iSg[m], mcs[m])) # hyp ######## POSITION 
## mean of posterior distribution of cluster means # sigma^2 and mu are the current Gibbs-sampled values ## mean of posterior distribution of cluster means else: u_Sg_ = _N.array(_u_Sg[m]) u_u_ = _N.array(_u_u[m]) u[m] = _N.random.multivariate_normal(u_u_, u_Sg_) smp_mk_prms[oo.ky_p_u][:, iter, m] = u[m] smp_mk_hyps[oo.ky_h_u_u][:, iter, m] = u_u_ smp_mk_hyps[oo.ky_h_u_Sg][:, :, iter, m] = u_Sg_ """ ############################################ """ ############### CONDITIONAL f #q2pr = _f_q2[m] if (_f_q2[m] > q2rate) else q2rate if (epc > 0) and oo.adapt: q2pr = _f_q2[m] + f_q2_rate * dt else: q2pr = _f_q2[m] if nSpksM > 0: # spiking portion likelihood x prior fs = (1./nSpksM)*_N.sum(xt0t1[sts-t0]) fq2 = q2[m]/nSpksM U = (fs*q2pr + _f_u[m]*fq2) / (q2pr + fq2) FQ2 = (q2pr*fq2) / (q2pr + fq2) else: U = _f_u[m] FQ2 = q2pr FQ = _N.sqrt(FQ2) fx = _N.linspace(U - FQ*15, U + FQ*15, oo.fss) if use_spc: fxr = fx.reshape((oo.fss, 1)) fxrux = -0.5*(fxr-ux)*(fxr-ux) # f_intgrd = _N.exp((fxrux*iiq2)) # integrand f_exp_px = _N.sum(f_intgrd*px, axis=1) * dSilenceX s = -(l0[m]*oo.dt/_N.sqrt(twpi*q2[m])) * f_exp_px # a function of x else: s = 0 funcf = -0.5*((fx-U)*(fx-U))/FQ2 + s funcf -= _N.max(funcf) condPosF= _N.exp(funcf) norm = 1./_N.sum(condPosF) f_u_ = norm*_N.sum(fx*condPosF) f_q2_ = norm*_N.sum(condPosF*(fx-f_u_)*(fx-f_u_)) f[m] = _N.sqrt(f_q2_)*_N.random.randn() + f_u_ smp_sp_prms[oo.ky_p_f, iter, m] = f[m] smp_sp_hyps[oo.ky_h_f_u, iter, m] = f_u_ smp_sp_hyps[oo.ky_h_f_q2, iter, m] = f_q2_ #ttc1g = _tm.time() ############# VARIANCE, COVARIANCE if nSpksM >= K: ## dof of posterior distribution of cluster covariance Sg_nu_ = _Sg_nu[m, 0] + nSpksM ## dof of posterior distribution of cluster covariance ur = u[m].reshape((1, K)) Sg_PSI_ = _Sg_PSI[m] + _N.dot((clstx - ur).T, (clstx-ur)) Sg[m] = s_u.sample_invwishart(Sg_PSI_, Sg_nu_) else: Sg_nu_ = _Sg_nu[m, 0] ## dof of posterior distribution of cluster covariance ur = u[m].reshape((1, K)) Sg_PSI_ = _Sg_PSI[m] Sg[m] = 
s_u.sample_invwishart(Sg_PSI_, Sg_nu_) ############## SAMPLE COVARIANCES ## dof of posterior distribution of cluster covariance smp_mk_prms[oo.ky_p_Sg][:, :, iter, m] = Sg[m] smp_mk_hyps[oo.ky_h_Sg_nu][0, iter, m] = Sg_nu_ smp_mk_hyps[oo.ky_h_Sg_PSI][:, :, iter, m] = Sg_PSI_ # ############### CONDITIONAL q2 #xI = (xt0t1-f)*(xt0t1-f)*0.5*iq2xr if use_spc: q2_intgrd = _N.exp(-0.5*(f[m] - ux)*(f[m]-ux) * iq2xr) q2_exp_px = _N.sum(q2_intgrd*px, axis=1) * dSilenceX # function of q2 s = -((l0[m]*oo.dt)/sqrt_2pi_q2x)*q2_exp_px else: s = 0 # B' / (a' - 1) = MODE #keep mode the same after discount # B' = MODE * (a' - 1) if (epc > 0) and oo.adapt: _md_nd= _q2_B[m] / (_q2_a[m] + 1) _Dq2_a = _q2_a[m] * _N.exp(-dt/tau_q2) _Dq2_B = _Dq2_a / _md_nd else: _Dq2_a = _q2_a[m] _Dq2_B = _q2_B[m] if nSpksM > 0: ## (1/sqrt(sg2))^S ## (1/x)^(S/2) = (1/x)-(a+1) ## -S/2 = -a - 1 -a = -S/2 + 1 a = S/2-1 xI = (xt0t1[sts-t0]-f[m])*(xt0t1[sts-t0]-f[m])*0.5 SL_a = 0.5*nSpksM - 1 # spiking part of likelihood SL_B = _N.sum(xI) # spiking part of likelihood # spiking prior x prior sLLkPr = -(_q2_a[m] + SL_a + 2)*lq2x - iq2x*(_q2_B[m] + SL_B) else: sLLkPr = -(_q2_a[m] + 1)*lq2x - iq2x*_q2_B[m] sat = sLLkPr + s sat -= _N.max(sat) condPos = _N.exp(sat) q2_a_, q2_B_ = ig_prmsUV(q2x, sLLkPr, s, d_q2x, q2x_m1, ITER=1, nSpksM=nSpksM, clstr=m, l0=l0[m]) # sat = sLLkPr + s # sat -= _N.max(sat) # condPos = _N.exp(sat) # q2_a_, q2_B_ = ig_prmsUV(q2x, condPos, d_q2x, q2x_m1, ITER=1) q2[m] = _ss.invgamma.rvs(q2_a_ + 1, scale=q2_B_) # check #q2[m] = 1.1**2 #print ((1./nSpks)*_N.sum((xt0t1[sts]-f)*(xt0t1[sts]-f))) if q2[m] < 0: print "******** q2[%(m)d] = %(q2).3f" % {"m" : m, "q2" : q2[m]} smp_sp_prms[oo.ky_p_q2, iter, m] = q2[m] smp_sp_hyps[oo.ky_h_q2_a, iter, m] = q2_a_ smp_sp_hyps[oo.ky_h_q2_B, iter, m] = q2_B_ if q2[m] < 0: print "^^^^^^^^ q2[%(m)d] = %(q2).3f" % {"m" : m, "q2" : q2[m]} print q2[m] print smp_sp_prms[oo.ky_p_q2, 0:iter+1, m] iiq2 = 1./q2[m] #ttc1h = _tm.time() # nz clstr. 
fixed width if oo.nzclstr: nz_l0_exp_px = _N.sum(nz_l0_intgrd*px) * dSilenceX BL = (oo.dt/_N.sqrt(twpi*q2[Mwowonz-1]))*nz_l0_exp_px minds = len(_N.where(gz[iter, :, Mwowonz-1] == 1)[0]) l0_a_ = minds + _nz_l0_a l0_B_ = BL + _nz_l0_B l0[Mwowonz-1] = _ss.gamma.rvs(l0_a_, scale=(1/l0_B_)) smp_nz_l0[iter] = l0[Mwowonz-1] smp_nz_hyps[0, iter] = l0_a_ smp_nz_hyps[1, iter] = l0_B_ ttB = _tm.time() print (ttB-ttA) ### THIS LEVEL: Finished Gibbs iters for epoch gAMxMu.finish_epoch(oo, nSpks, epc, ITERS, gz, l0, f, q2, u, Sg, _f_u, _f_q2, _q2_a, _q2_B, _l0_a, _l0_B, _u_u, _u_Sg, _Sg_nu, _Sg_PSI, smp_sp_hyps, smp_sp_prms, smp_mk_hyps, smp_mk_prms, freeClstr, M, K) # MAP of nzclstr if oo.nzclstr: frm = int(0.7*ITERS) _nz_l0_a = _N.median(smp_nz_hyps[0, frm:]) _nz_l0_B = _N.median(smp_nz_hyps[1, frm:]) pcklme["smp_sp_hyps"] = smp_sp_hyps pcklme["smp_mk_hyps"] = smp_mk_hyps pcklme["smp_sp_prms"] = smp_sp_prms pcklme["smp_mk_prms"] = smp_mk_prms pcklme["sp_prmPstMd"] = oo.sp_prmPstMd pcklme["mk_prmPstMd"] = oo.mk_prmPstMd pcklme["intvs"] = oo.intvs pcklme["occ"] = gz pcklme["nz_pth"] = nz_pth pcklme["M"] = M pcklme["Mwowonz"] = Mwowonz if Mwowonz > M: # or oo.nzclstr == True pcklme["smp_nz_l0"] = smp_nz_l0 pcklme["smp_nz_hyps"]= smp_nz_hyps dmp = open(resFN("posteriors_%d.dmp" % epc, dir=oo.outdir), "wb") pickle.dump(pcklme, dmp, -1) dmp.close()