def saveset(name, noparam=False):
    """Write the simulated series and, optionally, its generating parameters.

    This function takes no data arguments: it reads the free (module-level)
    names ``N``, ``x``, ``prbs`` and ``dN``, and unless *noparam* is set also
    ``u``, ``beta``, ``ARcoeff``, ``alfa``, ``dt``, ``stNz`` and ``absrefr``.

    Parameters
    ----------
    name : str
        Directory passed to ``resFN`` (with ``create=True``) for both files.
    noparam : bool
        If True only ``xprbsdN.dat`` is written; ``params.py`` is skipped.
    """
    # u, B, singleFreqAR, dt, stNz, x, dN, prbs
    # One row per time bin (N + 1 rows); columns are x, prbs, dN.
    xprbsdN = _N.empty((N + 1, 3))
    xprbsdN[:, 0] = x[:]
    xprbsdN[:, 1] = prbs[:]
    xprbsdN[:, 2] = dN[:]
    _N.savetxt(resFN("xprbsdN.dat", dir=name, create=True), xprbsdN, fmt="%.5e")

    if noparam:
        return

    # ``with`` closes the file even if one of the formatting calls raises
    # (the original explicit close() was skipped on any exception).
    with open(resFN("params.py", dir=name, create=True), "w") as fp:
        fp.write("u=%.3f\n" % u)
        fp.write("beta=%s\n" % arrstr(beta))
        fp.write("ARcoeff=_N.array(%s)\n" % str(ARcoeff))
        fp.write("alfa=_N.array(%s)\n" % str(alfa))
        fp.write("# ampAngRep=%s\n" % ampAngRep(alfa))
        fp.write("dt=%.2e\n" % dt)
        fp.write("stNz=%.3e\n" % stNz)
        fp.write("absrefr=%d\n" % absrefr)
def initGibbs(self):  ################################ INITGIBBS
    """Initialise sampler state and allocate the per-iteration sample buffers.

    - If ``bpsth``: build the B-spline PSTH basis ``B`` (stored transposed,
      knots x time) and initial weights ``aS`` as the least-squares fit of a
      flat PSTH equal to ``mean(u)``; ``u_a`` is set at the end.
    - Always: create the data container ``_d`` from ``y``, allocate the sample
      arrays (length ``burn + NMC``) and compute ``AR2lims``.
    - Fresh start (``rs < 0``): zero latent state ``smpx``, initial AR roots
      ``F_alfa_rep`` and coefficients ``F0``, innovation variance ``q2``
      rescaled so that a simulated AR series has std ~0.5, and initial PG
      weights ``ws`` drawn with ``lw.rpg_devroye``.
    """
    oo = self
    if oo.bpsth:
        nT = oo.t1 - oo.t0
        oo.B = patsy.bs(_N.linspace(0, nT * oo.dt, nT), df=oo.dfPSTH,
                        knots=oo.kntsPSTH, include_intercept=True)  # spline basis
        if oo.dfPSTH is None:
            oo.dfPSTH = oo.B.shape[1]
        oo.B = oo.B.T  # My convention for beta
        # Normal equations: weights whose spline sum best matches mean(u).
        oo.aS = _N.linalg.solve(_N.dot(oo.B, oo.B.T),
                                _N.dot(oo.B, _N.ones(nT) * _N.mean(oo.u)))

    # generate initial values of parameters
    oo._d = _kfardat.KFARGauObsDat(oo.TR, oo.N, oo.k)
    oo._d.copyData(oo.y)

    nItr = oo.burn + oo.NMC
    oo.Bsmpx = _N.zeros((oo.TR, nItr, (oo.N + 1) + 2))
    oo.smp_u = _N.zeros((oo.TR, nItr))
    oo.smp_q2 = _N.zeros((oo.TR, nItr))
    oo.smp_x00 = _N.empty((oo.TR, nItr - 1, oo.k))
    # Builtin complex/int/float: the _N.complex/_N.int/_N.float aliases were
    # removed in NumPy 1.24 (they were exact aliases of these builtins).
    oo.allalfas = _N.empty((nItr, oo.k), dtype=complex)
    oo.uts = _N.empty((oo.TR, nItr, oo.R, oo.N + 2))
    oo.wts = _N.empty((oo.TR, nItr, oo.C, oo.N + 3))
    oo.ranks = _N.empty((nItr, oo.C), dtype=int)
    oo.pgs = _N.empty((oo.TR, nItr, oo.N + 1))
    oo.fs = _N.empty((nItr, oo.C))
    oo.amps = _N.empty((nItr, oo.C))
    if oo.bpsth:
        oo.smp_aS = _N.zeros((nItr, oo.dfPSTH))

    radians = buildLims(oo.Cn, oo.freq_lims, nzLimL=1.)
    oo.AR2lims = 2 * _N.cos(radians)

    if oo.rs < 0:
        oo.smpx = _N.zeros((oo.TR, (oo.N + 1) + 2, oo.k))  # start at 0 + u
        oo.ws = _N.empty((oo.TR, oo._d.N + 1), dtype=float)

        oo.F_alfa_rep = initF(oo.R, oo.Cs, oo.Cn, ifs=oo.ifs).tolist()  # init F_alfa_rep

        print("begin---")
        print(ampAngRep(oo.F_alfa_rep))
        print("begin^^^")
        q20 = 1e-3
        oo.q2 = _N.ones(oo.TR) * q20

        # AR coefficients from the roots (leading polynomial coefficient dropped).
        oo.F0 = (-1 * _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]
        # Limit the amplitude to something reasonable: scale q2 so the
        # simulated series has std ~0.5 (std scales with sqrt(q2)).
        xE, nul = createDataAR(oo.N, oo.F0, q20, 0.1)
        mlt = _N.std(xE) / 0.5
        oo.q2 /= mlt * mlt
        # NOTE(review): the result of this call is unused; it is kept in case
        # it advances a random stream that later draws depend on.
        createDataAR(oo.N, oo.F0, oo.q2[0], 0.1)

        if oo.model == "Bernoulli":
            oo.initBernoulli()

        # initialize ws if starting for first time
        if oo.TR == 1:
            oo.ws = oo.ws.reshape(1, oo._d.N + 1)
        for m in range(oo._d.TR):
            lw.rpg_devroye(oo.rn, oo.smpx[m, 2:, 0] + oo.u[m],
                           num=(oo.N + 1), out=oo.ws[m, :])

    # NOTE(review): when restarting (rs >= 0), oo.u and oo.q2 must already exist.
    oo.smp_u[:, 0] = oo.u
    oo.smp_q2[:, 0] = oo.q2

    if oo.bpsth:
        oo.u_a = _N.ones(oo.dfPSTH) * _N.mean(oo.u)
def dirichletAllocate(self):  ###########################  GIBBSSAMP
    """Gibbs sampler with a two-valued per-trial label ``Z``.

    Each sweep (``it`` = 1 .. burn + NMC - 1):
      1. after ``startZ``, resample labels ``Z`` for trials in ``varz`` and
         the mixture weights ``m`` from a Dirichlet(alp + label counts);
      2. PG weights ``ws`` (``lw.rpg_devroye``) and ``kpOws = kp / ws``;
      3. history-filter coefficients for knots ``vInds`` (unless ``bFixH``);
      4. PSTH spline weights ``aS`` (if ``bpsth``, else zeroed);
      5. zero-sum per-trial offsets ``us``;
      6. latent AR state via ``_kfar.armdl_FFBS_1itrMP`` (optionally in a
         ``multiprocessing`` pool when ``processes > 1``);
      7. AR coefficients (unless ``bFixF``);
      8. scale ``s[1]`` of trials labelled Z[:, 1] == 1 (if ``doS``);
      9. one innovation variance ``q2`` shared by all trials.
    Samples go into the ``smp_*`` arrays; ``dump_smpsS`` is called every
    ``peek`` iterations and at the end.
    """
    oo = self
    ooTR = oo.TR
    print(ooTR)
    ook = oo.k
    ooNMC = oo.NMC
    ooN = oo.N

    oo.allocateSmp(oo.burn + oo.NMC)
    oo.x00 = _N.array(oo.smpx[:, 2])
    oo.V00 = _N.zeros((ooTR, ook, ook))

    oo.loghist = _N.zeros(oo.N + 1)
    ARo = _N.zeros((ooTR, oo._d.N + 1))  # history term, filled by stitch_Hist

    kpOws = _N.empty((ooTR, ooN + 1))
    lv_f = _N.zeros((ooN + 1, ooN + 1))
    smWimOm = _N.zeros(ooN + 1)

    D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  # spline prior covariance
    iD_f = _N.linalg.inv(D_f)
    if oo.bpsth:  # only the PSTH update uses these (same as gibbsSamp)
        BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
        DB = _N.dot(D_f, oo.B)
        BTua = _N.dot(oo.B.T, oo.u_a)

    # (TR, 1) reshape of oo.us for broadcasting; a view when oo.us is
    # contiguous, so the in-place updates of oo.us below show through it.
    oous_rs = oo.us.reshape((ooTR, 1))

    # sd01[z] = diag(s[z]); smpx01[z] = latent state scaled as if every trial had label z.
    sd01 = _N.zeros((oo.nStates, oo.TR, oo.TR))
    _N.fill_diagonal(sd01[0], oo.s[0])
    _N.fill_diagonal(sd01[1], oo.s[1])
    smpx01 = _N.zeros((oo.nStates, oo.TR, oo.N + 1))
    zsmpx = _N.empty((oo.TR, oo.N + 1))  # latent state scaled by each trial's own label

    zd = _N.zeros((oo.TR, oo.TR))
    izd = _N.zeros((oo.TR, oo.TR))
    ll = _N.zeros(oo.nStates)

    for m in range(ooTR):
        oo._d.f_V[m, 0] = oo.s2_x00
        oo._d.f_V[m, 1] = oo.s2_x00

    THR = _N.empty(oo.TR)
    dirArgs = _N.empty(oo.nStates)  # dirichlet distribution args
    BaS = _N.dot(oo.B.T, oo.aS)

    alpR = oo.F_alfa_rep[0:oo.R]
    alpC = oo.F_alfa_rep[oo.R:]

    oo.nSMP_smpxC = 0

    pool = None
    if oo.processes > 1:
        print(oo.processes)
        pool = Pool(processes=oo.processes)

    print("oo.mcmcRunDir %s" % oo.mcmcRunDir)
    if oo.mcmcRunDir is None:
        oo.mcmcRunDir = ""
    elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
        oo.mcmcRunDir += "/"

    Hbf = oo.Hbf
    # History coefficients: cInds are held fixed, vInds are sampled (vInds must
    # be contiguous, since HcM is indexed by i - vInds[0]). Zero-initialised so
    # no knot ever contributes uninitialised memory to loghist.
    RHS = _N.zeros((oo.histknots, 1))
    if oo.h0_1 > 1:  # first few are 0s
        cInds = _N.array([0, 4, 5, 6, 7, 8, 9])
        vInds = _N.array([1, 2, 3])
        RHS[0, 0] = -5
    else:
        cInds = _N.array([4, 5, 6, 7, 8, 9])
        vInds = _N.array([0, 1, 2, 3])

    Msts = [_N.where(oo.y[m] == 1)[0] for m in range(ooTR)]
    HcM = _N.empty((len(vInds), len(vInds)))

    # HbfExpd[i, m, t] = Hbf[t - sp - 1, i] for sp < t <= next spike.
    # Zero up to and including the first spike and after the last spike.
    # The zeros allocation also covers the first-spike bin (previously left
    # uninitialised) and trials with no spikes (previously an IndexError).
    HbfExpd = _N.zeros((oo.histknots, ooTR, oo.N + 1))
    for m in range(ooTR):
        sts = Msts[m]
        for iss in range(len(sts) - 1):
            sp0 = sts[iss]
            sp1 = sts[iss + 1]
            HbfExpd[:, m, sp0 + 1:sp1 + 1] = Hbf[0:sp1 - sp0, :oo.histknots].T

    if oo.hS is None:
        oo.hS = _N.zeros(oo.histknots)
    _N.dot(Hbf, oo.hS, out=oo.loghist)
    oo.stitch_Hist(ARo, oo.loghist, Msts)

    ## ORDER OF SAMPLING
    ## f_xx, f_V
    ## BINARY state
    ## DA: PG, kpOws
    ## history, build ARo
    ## psth
    ## offset
    ## DA: latent state
    ## AR coefficients
    ## q2
    it = 0
    while it < ooNMC + oo.burn - 1:
        t_start = _tm.time()
        it += 1  # incremented before use: the first sweep has it == 1
        print("****------------ %d" % it)
        oo._d.f_x[:, 0, :, 0] = oo.x00
        if it == 1:  # was `it == 0`, which can never hold here
            for m in range(ooTR):
                oo._d.f_V[m, 0] = oo.s2_x00
        else:
            oo._d.f_V[:, 0] = _N.mean(oo._d.f_V[:, 1:], axis=1)

        if it > oo.startZ:
            for tryZ in range(oo.nStates):
                _N.dot(sd01[tryZ], oo.smpx[..., 2:, 0], out=smpx01[tryZ])

            for m in oo.varz:  # only these trials may change label
                for tryZ in range(oo.nStates):
                    eta = smpx01[tryZ, m] + BaS + ARo[m] + oo.us[m] + oo.knownSig[m]
                    # log Bernoulli(y | sigmoid(eta)), overflow-free:
                    # log P(1) = -log(1 + e^-eta), log P(0) = -log(1 + e^eta)
                    ll[tryZ] = -_N.sum(_N.where(oo.y[m] == 1,
                                                _N.logaddexp(0, -eta),
                                                _N.logaddexp(0, eta)))
                # Subtract the MAX (log-sum-exp): subtracting the min could
                # overflow exp() to inf and give THR = inf/inf = nan.
                ll -= _N.max(ll)
                p0 = oo.m[0] * _N.exp(ll[0])
                nc = p0 + oo.m[1] * _N.exp(ll[1])
                THR[m] = p0 / nc
                if _N.random.rand() < THR[m]:
                    oo.Z[m, 0] = 1
                    oo.Z[m, 1] = 0
                else:
                    oo.Z[m, 0] = 0
                    oo.Z[m, 1] = 1
                oo.smp_zs[m, it] = oo.Z[m]
            for m in oo.fxdz:  # fixed-label trials: record only
                oo.smp_zs[m, it] = oo.Z[m]

            # sample m's: Dirichlet posterior given the label counts
            _N.add(oo.alp, _N.sum(oo.Z[oo.varz], axis=0), out=dirArgs)
            oo.m[:] = _N.random.dirichlet(dirArgs)
            oo.smp_ms[it] = oo.m
        else:
            oo.smp_ms[it] = oo.m
            oo.smp_zs[:, it, 1] = 1
            oo.smp_zs[:, it, 0] = 0

        # Z set: per-trial scale s[label] and its inverse
        _N.fill_diagonal(zd, oo.s[oo.Z[:, 1]])
        _N.fill_diagonal(izd, 1. / oo.s[oo.Z[:, 1]])
        _N.dot(zd, oo.smpx[..., 2:, 0], out=zsmpx)

        print(oo.m)
        lwsts = _N.where(oo.Z[:, 0] == 1)[0]
        hists = _N.where(oo.Z[:, 1] == 1)[0]

        ###### PG generate
        for m in range(ooTR):
            lw.rpg_devroye(oo.rn, zsmpx[m] + oo.us[m] + BaS + ARo[m] + oo.knownSig[m],
                           out=oo.ws[m])
        _N.divide(oo.kp, oo.ws, out=kpOws)

        if not oo.bFixH:
            O = kpOws - zsmpx - oous_rs - BaS - oo.knownSig
            iOf = vInds[0]  # offset HcM index with RHS index
            for i in vInds:
                for j in vInds:
                    HcM[i - iOf, j - iOf] = _N.sum(oo.ws * HbfExpd[i] * HbfExpd[j])
                RHS[i, 0] = _N.sum(oo.ws * HbfExpd[i] * O)
                for cj in cInds:
                    RHS[i, 0] -= _N.sum(oo.ws * HbfExpd[i] * HbfExpd[cj]) * RHS[cj, 0]
            vm = _N.linalg.solve(HcM, RHS[vInds])
            Cov = _N.linalg.inv(HcM)
            print(vm)
            cfs = _N.random.multivariate_normal(vm[:, 0], Cov, size=1)
            RHS[vInds, 0] = cfs[0]
            oo.smp_hS[:, it] = RHS[:, 0]

            _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
            oo.smp_hist[:, it] = oo.loghist
            oo.stitch_Hist(ARo, oo.loghist, Msts)

        ######## PSTH sample (after the zs are generated)
        if oo.bpsth:
            Oms = kpOws - zsmpx - ARo - oous_rs - oo.knownSig
            _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)  # sum over trials
            ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
            _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
            lm_f = _N.dot(lv_f, smWimOm)
            iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
            VAR = _N.linalg.inv(iVAR)  # knots x knots
            Mn = oo.u_a + _N.dot(DB, _N.linalg.solve(BDB + lv_f, lm_f - BTua))
            oo.aS = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
            oo.smp_aS[it, :] = oo.aS
        else:
            oo.aS[:] = 0
        BaS = _N.dot(oo.B.T, oo.aS)

        ######## per trial offset sample: us[1:] drawn, us[0] = -sum(us[1:])
        Ons = kpOws - zsmpx - ARo - BaS - oo.knownSig
        H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
        uRHS = _N.empty(oo.TR - 1)
        for dd in range(1, oo.TR):
            H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
            uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] - oo.ws[0] * Ons[0])
        MM = _N.linalg.solve(H, uRHS)
        Cov = _N.linalg.inv(H)
        oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
        oo.us[0] = -_N.sum(oo.us[1:])
        oo.smp_u[:, it] = oo.us

        #### Sample latent state
        oo._d.y = _N.dot(izd, kpOws - BaS - ARo - oous_rs - oo.knownSig)
        oo._d.copyParams(oo.F0, oo.q2)
        # (MxM) (MxN) = (MxN)  (Rv is MxN)
        _N.dot(_N.dot(izd, izd), 1. / oo.ws, out=oo._d.Rv)

        oo._d.f_x[:, 0, :, 0] = oo.x00
        # Restores the `if it == 1:` that had been commented out, which had
        # turned this if/else into a for/else whose else always ran.
        if it == 1:
            for m in range(ooTR):
                oo._d.f_V[m, 0] = oo.s2_x00
        else:
            oo._d.f_V[:, 0] = _N.mean(oo._d.f_V[:, 1:], axis=1)

        # list(): indexed below (zip is an iterator on Python 3)
        tpl_args = list(zip(oo._d.y, oo._d.Rv, oo._d.Fs, _N.linalg.inv(oo._d.Fs),
                            oo.q2, oo._d.Ns, oo._d.ks, oo._d.f_x[:, 0], oo._d.f_V[:, 0]))

        if pool is None:
            for m in range(ooTR):
                oo.smpx[m, 2:], oo._d.f_x[m], oo._d.f_V[m] = \
                    _kfar.armdl_FFBS_1itrMP(tpl_args[m])
        else:
            sxv = pool.map(_kfar.armdl_FFBS_1itrMP, tpl_args)
            for m in range(ooTR):
                oo.smpx[m, 2:] = sxv[m][0]
                oo._d.f_x[m] = sxv[m][1]
                oo._d.f_V[m] = sxv[m][2]
        # lagged copies of the newest latent state
        oo.smpx[:, 1, 0:ook - 1] = oo.smpx[:, 2, 1:]
        oo.smpx[:, 0, 0:ook - 2] = oo.smpx[:, 2, 2:]

        stds = _N.std(oo.smpx[:, 2:, 0], axis=1)
        oo.mnStds[it] = _N.mean(stds, axis=0)
        print("mnStd %.3f" % oo.mnStds[it])

        if not oo.bFixF:
            ARcfSmpl(oo.lfc, ooN + 1, ook, oo.AR2lims, oo.smpx[:, 1:, 0:ook],
                     oo.smpx[:, :, 0:ook - 1], oo.q2, oo.R, oo.Cs, oo.Cn, alpR,
                     alpC, oo.TR, prior=oo.use_prior, accepts=80, aro=oo.ARord,
                     sig_ph0L=oo.sig_ph0L, sig_ph0H=oo.sig_ph0H)
            oo.F_alfa_rep = alpR + alpC  # new constructed
            prt, rank, f, amp = ampAngRep(oo.F_alfa_rep, f_order=True)
            print(prt)
            oo.amps[it, :] = amp
            oo.fs[it, :] = f
        oo.allalfas[it] = oo.F_alfa_rep

        oo.F0 = (-1 * _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]

        print("len(lwsts) %(l)d len(hists) %(h)d" % {"l": len(lwsts), "h": len(hists)})
        sts2chg = hists
        if (it > oo.startZ) and oo.doS and len(sts2chg) > 0:
            AL = 0.5 * _N.sum(oo.smpx[sts2chg, 2:, 0] * oo.smpx[sts2chg, 2:, 0] * oo.ws[sts2chg])
            BRL = kpOws[sts2chg] - BaS - oous_rs[sts2chg] - ARo[sts2chg] - oo.knownSig[sts2chg]
            BL = _N.sum(oo.ws[sts2chg] * BRL * oo.smpx[sts2chg, 2:, 0])
            UL = BL / (2 * AL)
            sg2 = 1. / (2 * AL)
            q2_pr = 0.0025  # 0.05**2
            u_pr = 1.
            # Gaussian likelihood (UL, sg2) combined with Gaussian prior (u_pr, q2_pr)
            U = (u_pr * sg2 + UL * q2_pr) / (sg2 + q2_pr)
            sg = _N.sqrt((sg2 * q2_pr) / (sg2 + q2_pr))
            print("U %(U).4f UL %(UL).4f s %(s).3f" % {"U": U, "s": sg, "UL": UL})
            if _N.isnan(U):
                print("U is nan UL %.4f" % UL)
                print("U is nan AL %.4f" % AL)
                print("U is nan BL %.4f" % BL)
                print("U is nan BaS ")
                print("hists")
                print(hists)
                print("lwsts")
                print(lwsts)
            oo.s[1] = U + sg * _N.random.randn()
            _N.fill_diagonal(sd01[0], oo.s[0])
            _N.fill_diagonal(sd01[1], oo.s[1])
            print(oo.s[1])
            oo.smp_ss[it] = oo.s[1]

        oo.a2 = oo.a_q2 + 0.5 * (ooTR * ooN + 2)  # N + 1 - 1
        BB2 = oo.B_q2
        for m in range(ooTR):
            oo.x00[m] = oo.smpx[m, 2] * 0.001  # set x00
            ##################### sample q2
            rsd_stp = oo.smpx[m, 3:, 0] - _N.dot(oo.smpx[m, 2:-1], oo.F0).T
            # rebind rather than +=, so an array-valued oo.B_q2 is never mutated
            BB2 = BB2 + 0.5 * _N.dot(rsd_stp, rsd_stp.T)
        oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)
        oo.smp_q2[:, it] = oo.q2

        print("gibbs iter %.3f" % (_tm.time() - t_start))
        if (it > 1) and (it % oo.peek == 0):
            fig = _plt.figure(figsize=(12, 8))
            traces = (oo.amps[1:it, 0], oo.fs[1:it, 0], oo.mnStds[1:it], oo.smp_ms[1:it, 0])
            for ip, trace in enumerate(traces):
                fig.add_subplot(4, 1, ip + 1)
                _plt.plot(trace)
            # mcmcRunDir is "" or ends in "/": no extra separator, otherwise an
            # empty dir would write to the filesystem root.
            _plt.savefig("%(dir)stmp-fsamps%(it)d" % {"dir": oo.mcmcRunDir, "it": it})
            _plt.close()
            oo.dump_smpsS(toiter=it, dir=oo.mcmcRunDir)

    if pool is not None:  # release the worker processes
        pool.close()
        pool.join()
    oo.dump_smpsS(dir=oo.mcmcRunDir)
def gibbsSamp(self):  ###########################  GIBBSSAMPH
    """Run the Gibbs sampler in blocks of ``peek`` sweeps until stationary.

    Each sweep draws, in order:
      1. PG weights ``ws``;
      2. history coefficients (if ``dohist``);
      3. PSTH weights ``aS`` (if ``bpsth``);
      4. per-trial offsets ``us``;
      5. unless ``noAR``: the latent AR state (``_kfar.armdl_FFBS_1itrMP``),
         AR coefficients (unless ``bFixF``) and ``q2`` (per trial if
         ``ID_q2``, else shared).
    After each block past ``minITERS``, ``_mg.stationary_from_Z_bckwd`` is run
    on the amplitude, frequency and mean latent std traces. Sampling stops once
    more than ``stationaryDuration`` sweeps lie after the returned start
    ``frms``. Samples from ``frms`` on are passed to ``dump_smps``, and
    ``VIS`` is set to the final ``ARo`` (history term from ``stitch_Hist``).
    """
    oo = self
    ooTR = oo.TR
    ook = oo.k
    ooN = oo.N
    _kfar.init(oo.N, oo.k, oo.TR)
    oo.x00 = _N.array(oo.smpx[:, 2])
    oo.V00 = _N.zeros((ooTR, ook, ook))
    if oo.dohist:
        oo.loghist = _N.zeros(oo.N + 1)
    else:
        print("fixed hist is")
        print(oo.loghist)

    print("oo.mcmcRunDir %s" % oo.mcmcRunDir)
    if oo.mcmcRunDir is None:
        oo.mcmcRunDir = ""
    elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
        oo.mcmcRunDir += "/"

    ARo = _N.zeros((ooTR, ooN + 1))
    kpOws = _N.empty((ooTR, ooN + 1))
    lv_f = _N.zeros((ooN + 1, ooN + 1))
    lv_u = _N.zeros((ooTR, ooTR))
    smWimOm = _N.zeros(ooN + 1)
    smWinOn = _N.zeros(ooTR)

    D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  # spline prior covariance
    iD_f = _N.linalg.inv(D_f)
    D_u = _N.diag(_N.ones(oo.TR) * oo.s2_u)  # offset prior covariance
    iD_u = _N.linalg.inv(D_u)
    iD_u_u_u = _N.dot(iD_u, _N.ones(oo.TR) * oo.u_u)
    if oo.bpsth:
        BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
        DB = _N.dot(D_f, oo.B)
        BTua = _N.dot(oo.B.T, oo.u_a)

    # (TR, 1) reshape of oo.us for broadcasting; a view when oo.us is
    # contiguous, so in-place updates of oo.us show through it.
    oous_rs = oo.us.reshape((ooTR, 1))
    runTO = oo.ITERS - 1
    oo.allocateSmp(runTO + 1, Bsmpx=oo.doBsmpx)
    alpR = oo.F_alfa_rep[0:oo.R]
    alpC = oo.F_alfa_rep[oo.R:]

    BaS = _N.zeros(oo.N + 1)

    Hbf = oo.Hbf
    # History coefficients: cInds held fixed, vInds sampled (contiguous).
    # Zero-initialised: with hist_max_at_0 the index sets cover only knots
    # 0-8, so any further knot used to enter loghist as uninitialised memory.
    RHS = _N.zeros((oo.histknots, 1))

    print("----------- histknots %d" % oo.histknots)
    if oo.h0_1 > 1:  # no spikes in first few time bins
        cInds = _N.array([0, 4, 5, 6, 7, 8, 9])
        vInds = _N.array([1, 2, 3])
        RHS[0, 0] = -5
    elif oo.hist_max_at_0:  # no refractory period
        cInds = _N.array([3, 4, 5, 6, 7, 8])
        vInds = _N.array([0, 1, 2])
    else:
        cInds = _N.array([4, 5, 6, 7, 8, 9])
        vInds = _N.array([0, 1, 2, 3])

    Msts = [_N.where(oo.y[m] == 1)[0] for m in range(ooTR)]
    HcM = _N.empty((len(vInds), len(vInds)))

    # HbfExpd[i, m, t] = Hbf[t - sp - 1, i] for sp < t <= next spike.
    # Zero up to and including the first spike and after the last spike.
    # zeros also fixes the first-spike bin (previously uninitialised) and
    # trials with no spikes (previously an IndexError).
    HbfExpd = _N.zeros((oo.histknots, ooTR, oo.N + 1))
    for m in range(ooTR):
        sts = Msts[m]
        for iss in range(len(sts) - 1):
            sp0 = sts[iss]
            sp1 = sts[iss + 1]
            HbfExpd[:, m, sp0 + 1:sp1 + 1] = Hbf[0:sp1 - sp0, :oo.histknots].T

    _N.dot(oo.B.T, oo.aS, out=BaS)
    if oo.hS is None:
        oo.hS = _N.zeros(oo.histknots)
    if oo.dohist:
        _N.dot(Hbf, oo.hS, out=oo.loghist)
    oo.stitch_Hist(ARo, oo.loghist, Msts)

    ## ORDER OF SAMPLING
    ## f_xx, f_V
    ## DA: PG, kpOws
    ## history, build ARo
    ## psth
    ## offset
    ## DA: latent state
    ## AR coefficients
    ## q2
    K = _N.empty((oo.TR, oo.N + 1, oo.k))  # kalman gain
    smpx_tmp = _N.empty((oo.TR, oo.N + 1, oo.k))
    iterBLOCKS = oo.ITERS // oo.peek  # integer count of blocks
    it = -1
    # NOTE(review): default start of the stationary segment, used only if
    # the test never runs (previously a NameError at dump_smps).
    frms = 0

    for itrB in range(iterBLOCKS):
        for it in range(itrB * oo.peek, (itrB + 1) * oo.peek):
            if (it % 10) == 0:
                print(it)
            oo.f_x[:, 0] = oo.x00
            if it == 0:
                for m in range(ooTR):
                    oo.f_V[m, 0] = oo.s2_x00
            else:
                oo.f_V[:, 0] = _N.mean(oo.f_V[:, 1:], axis=1)

            ### PG latent variable sample
            for m in range(ooTR):
                lw.rpg_devroye(oo.rn, oo.smpx[m, 2:, 0] + oo.us[m] + BaS + ARo[m] + oo.knownSig[m],
                               out=oo.ws[m])
            if ooTR == 1:
                oo.ws = oo.ws.reshape(1, ooN + 1)
            _N.divide(oo.kp, oo.ws, out=kpOws)

            if oo.dohist:
                O = kpOws - oo.smpx[..., 2:, 0] - oous_rs - BaS - oo.knownSig
                iOf = vInds[0]  # offset HcM index with RHS index
                for i in vInds:
                    for j in vInds:
                        HcM[i - iOf, j - iOf] = _N.sum(oo.ws * HbfExpd[i] * HbfExpd[j])
                    RHS[i, 0] = _N.sum(oo.ws * HbfExpd[i] * O)
                    for cj in cInds:
                        RHS[i, 0] -= _N.sum(oo.ws * HbfExpd[i] * HbfExpd[cj]) * RHS[cj, 0]
                vm = _N.linalg.solve(HcM, RHS[vInds])
                Cov = _N.linalg.inv(HcM)
                cfs = _N.random.multivariate_normal(vm[:, 0], Cov, size=1)
                RHS[vInds, 0] = cfs[0]
                oo.smp_hS[:, it] = RHS[:, 0]
                _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
            oo.smp_hist[:, it] = oo.loghist
            oo.stitch_Hist(ARo, oo.loghist, Msts)

            if oo.bpsth:
                Oms = kpOws - oo.smpx[..., 2:, 0] - ARo - oous_rs - oo.knownSig
                _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)  # sum over trials
                ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                # diag(inv(Bi)) == diag(1./Bi) for diagonal Bi
                _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                lm_f = _N.dot(lv_f, smWimOm)
                iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                VAR = _N.linalg.inv(iVAR)  # knots x knots
                Mn = oo.u_a + _N.dot(DB, _N.linalg.solve(BDB + lv_f, lm_f - BTua))
                oo.aS = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
                oo.smp_aS[it, :] = oo.aS
                _N.dot(oo.B.T, oo.aS, out=BaS)

            ######## per trial offset sample
            Ons = kpOws - oo.smpx[..., 2:, 0] - ARo - BaS - oo.knownSig
            if not oo.bpsth:
                # No PSTH, so nothing competes with the offsets: leave them
                # unconstrained, with a Gaussian prior (u_u, s2_u).
                _N.einsum("mn,mn->m", oo.ws, Ons, out=smWinOn)  # sum over time
                ilv_u = _N.diag(_N.sum(oo.ws, axis=1))
                _N.fill_diagonal(lv_u, 1. / _N.diagonal(ilv_u))
                lm_u = _N.dot(lv_u, smWinOn)
                iVAR = ilv_u + iD_u
                VAR = _N.linalg.inv(iVAR)
                Mn = _N.dot(VAR, _N.dot(ilv_u, lm_u) + iD_u_u_u)
                oo.us[:] = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
            else:
                # Zero-sum offsets: draw us[1:], set us[0] = -sum(us[1:]).
                H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
                uRHS = _N.empty(oo.TR - 1)
                for dd in range(1, oo.TR):
                    H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
                    uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] - oo.ws[0] * Ons[0])
                MM = _N.linalg.solve(H, uRHS)
                Cov = _N.linalg.inv(H)
                oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
                oo.us[0] = -_N.sum(oo.us[1:])
            if not oo.bIndOffset:
                oo.us[:] = _N.mean(oo.us)
            oo.smp_u[:, it] = oo.us

            # Gaussian pseudo-observations for FFBS. They are built from the
            # aS/BaS and us just drawn above, so the latent state is conditioned
            # on the current values of the other blocks. They used to be built
            # before the PSTH and offset draws, i.e. from stale values.
            oo.gau_obs = kpOws - BaS - ARo - oous_rs - oo.knownSig
            oo.gau_var = 1 / oo.ws  # time dependent noise

            if not oo.noAR:
                _kfar.armdl_FFBS_1itrMP(oo.gau_obs, oo.gau_var, oo.Fs, _N.linalg.inv(oo.Fs),
                                        oo.q2, oo.Ns, oo.ks, oo.f_x, oo.f_V, oo.p_x,
                                        oo.p_V, smpx_tmp, K)
                oo.smpx[:, 2:] = smpx_tmp
                # lagged copies of the newest latent state
                oo.smpx[:, 1, 0:ook - 1] = oo.smpx[:, 2, 1:]
                oo.smpx[:, 0, 0:ook - 2] = oo.smpx[:, 2, 2:]

                if oo.doBsmpx and (it % oo.BsmpxSkp == 0):
                    oo.Bsmpx[:, it // oo.BsmpxSkp, 2:] = oo.smpx[:, 2:, 0]
                stds = _N.std(oo.smpx[:, 2 + oo.ignr:, 0], axis=1)
                oo.mnStds[it] = _N.mean(stds, axis=0)
                print("mnStd %.3f" % oo.mnStds[it])

                if not oo.bFixF:
                    ARcfSmpl(oo.lfc, ooN + 1 - oo.ignr, ook, oo.AR2lims,
                             oo.smpx[:, 1 + oo.ignr:, 0:ook], oo.smpx[:, oo.ignr:, 0:ook - 1],
                             oo.q2, oo.R, oo.Cs, oo.Cn, alpR, alpC, oo.TR,
                             prior=oo.use_prior, accepts=50, aro=oo.ARord,
                             sig_ph0L=oo.sig_ph0L, sig_ph0H=oo.sig_ph0H)
                    oo.F_alfa_rep = alpR + alpC  # new constructed
                    prt, rank, f, amp = ampAngRep(oo.F_alfa_rep, f_order=True)
                    oo.amps[it, :] = amp
                    oo.fs[it, :] = f
                oo.allalfas[it] = oo.F_alfa_rep

                oo.F0 = (-1 * _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]
                oo.Fs[:, 0] = oo.F0  # first companion row of every trial

                if oo.ID_q2:
                    aq2 = oo.a_q2 + 0.5 * (ooN + 1)  # N + 1 - 1
                    for m in range(ooTR):
                        rsd_stp = oo.smpx[m, 3 + oo.ignr:, 0] - _N.dot(
                            oo.smpx[m, 2 + oo.ignr:-1], oo.F0).T
                        BB = oo.B_q2 + 0.5 * _N.dot(rsd_stp, rsd_stp.T)
                        oo.q2[m] = _ss.invgamma.rvs(aq2, scale=BB)
                        oo.x00[m] = oo.smpx[m, 2] * 0.1
                        oo.smp_q2[m, it] = oo.q2[m]
                else:
                    oo.a2 = 0.5 * (ooTR * (ooN - oo.ignr) + 2)  # N + 1 - 1
                    BB2 = 0
                    for m in range(ooTR):
                        oo.x00[m] = oo.smpx[m, 2] * 0.1  # set x00
                        rsd_stp = oo.smpx[m, 3 + oo.ignr:, 0] - _N.dot(
                            oo.smpx[m, 2 + oo.ignr:-1], oo.F0).T
                        BB2 += 0.5 * _N.dot(rsd_stp, rsd_stp.T)
                    oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)
                    oo.smp_q2[:, it] = oo.q2

        # once per block: stationarity test + diagnostic plot
        if it > oo.minITERS:
            smps = _N.empty((3, it + 1))
            smps[0, :it + 1] = oo.amps[:it + 1, 0]
            smps[1, :it + 1] = oo.fs[:it + 1, 0]
            smps[2, :it + 1] = oo.mnStds[:it + 1]

            frms = _mg.stationary_from_Z_bckwd(smps, blksz=oo.peek)

            fig = _plt.figure(figsize=(8, 8))
            traces = ((oo.amps[:, 0], "amp"),
                      (oo.fs[:, 0] / (2 * oo.dt), "f"),
                      (oo.mnStds, "mnStd"))  # was mislabelled "amp"
            for ip, (trace, ylab) in enumerate(traces):
                fig.add_subplot(3, 1, ip + 1)
                _plt.plot(range(1, it), trace[1:it], color="grey", lw=1.5)
                _plt.plot(range(frms, it), trace[frms:it], color="black", lw=3)
                _plt.ylabel(ylab)
            _plt.xlabel("iter")
            # margins must be set before saving to affect the written file
            fig.subplots_adjust(left=0.15, bottom=0.15, right=0.95, top=0.95)
            _plt.savefig("%(dir)stmp-fsamps%(it)d" % {"dir": oo.mcmcRunDir, "it": it + 1})
            _plt.close()

            if it - frms > oo.stationaryDuration:
                break

    oo.dump_smps(frms, toiter=(it + 1), dir=oo.mcmcRunDir)
    oo.VIS = ARo