Пример #1
0
def saveset(name, noparam=False):
    """Save the simulated dataset (and optionally its parameters) under *name*.

    Reads the module-level globals ``N``, ``x``, ``prbs``, ``dN`` and, unless
    *noparam* is true, ``u``, ``beta``, ``ARcoeff``, ``alfa``, ``dt``, ``stNz``
    and ``absrefr``.

    Writes:
      * ``xprbsdN.dat``: an (N+1) x 3 table whose columns are ``x``, ``prbs``
        and ``dN``, formatted with ``%.5e``.
      * ``params.py`` (skipped when *noparam* is true): Python assignments
        reproducing the generating parameters.

    :param name: directory name passed to ``resFN`` (created if needed).
    :param noparam: if true, only the data table is written.
    """
    #  u, B, singleFreqAR, dt, stNz, x, dN, prbs
    xprbsdN = _N.empty((N + 1, 3))  # x, prbs, dN must each hold N + 1 values
    xprbsdN[:, 0] = x[:]
    xprbsdN[:, 1] = prbs[:]
    xprbsdN[:, 2] = dN[:]

    _N.savetxt(resFN("xprbsdN.dat", dir=name, create=True),
               xprbsdN,
               fmt="%.5e")

    if noparam:
        return

    # Format everything before opening the file so a formatting error does
    # not leave a truncated params.py behind.
    lines = [
        "u=%.3f\n" % u,
        "beta=%s\n" % arrstr(beta),
        "ARcoeff=_N.array(%s)\n" % str(ARcoeff),
        "alfa=_N.array(%s)\n" % str(alfa),
        "#  ampAngRep=%s\n" % ampAngRep(alfa),
        "dt=%.2e\n" % dt,
        "stNz=%.3e\n" % stNz,
        "absrefr=%d\n" % absrefr,
    ]
    # `with` guarantees the handle is closed even if a write fails.
    with open(resFN("params.py", dir=name, create=True), "w") as fp:
        fp.writelines(lines)
Пример #2
0
    def initGibbs(self):  ################################ INITGIBBS
        """Set up sample storage and initial parameter values for Gibbs sampling.

        * If ``oo.bpsth``: builds the PSTH B-spline basis ``oo.B`` (stored as
          knots x time) and initialises ``oo.aS`` to the least-squares spline
          weights that reproduce a flat PSTH at ``mean(oo.u)``.
        * Creates the Kalman/AR data object ``oo._d`` and copies ``oo.y`` in.
        * Allocates the per-iteration sample arrays, sized ``burn + NMC``.
        * If ``oo.rs < 0`` (fresh start): zero latent state, initial AR roots
          from ``initF``, ``q2`` rescaled so a simulated AR trace has std ~0.5,
          optional Bernoulli init, and initial Polya-Gamma draws ``oo.ws``.
        * Stores iteration-0 samples of ``u`` and ``q2``.
        """
        oo = self
        if oo.bpsth:
            nBins = oo.t1 - oo.t0
            oo.B = patsy.bs(_N.linspace(0, nBins * oo.dt, nBins),
                            df=oo.dfPSTH,
                            knots=oo.kntsPSTH,
                            include_intercept=True)  #  spline basis
            if oo.dfPSTH is None:
                oo.dfPSTH = oo.B.shape[1]
            oo.B = oo.B.T  #  My convention for beta: (knots x time)
            #  normal equations: weights whose spline curve best fits mean(u)
            oo.aS = _N.linalg.solve(
                _N.dot(oo.B, oo.B.T),
                _N.dot(oo.B, _N.ones(nBins) * _N.mean(oo.u)))

        # #generate initial values of parameters
        oo._d = _kfardat.KFARGauObsDat(oo.TR, oo.N, oo.k)
        oo._d.copyData(oo.y)

        ###############  sample storage
        nIters = oo.burn + oo.NMC
        oo.Bsmpx = _N.zeros((oo.TR, nIters, (oo.N + 1) + 2))
        oo.smp_u = _N.zeros((oo.TR, nIters))
        oo.smp_q2 = _N.zeros((oo.TR, nIters))
        oo.smp_x00 = _N.empty((oo.TR, nIters - 1, oo.k))
        #  The builtins replace the _N.complex/_N.int/_N.float aliases, which
        #  were removed in NumPy 1.24 (they were aliases of these builtins).
        oo.allalfas = _N.empty((nIters, oo.k), dtype=complex)
        oo.uts = _N.empty((oo.TR, nIters, oo.R, oo.N + 2))
        oo.wts = _N.empty((oo.TR, nIters, oo.C, oo.N + 3))
        oo.ranks = _N.empty((nIters, oo.C), dtype=int)
        oo.pgs = _N.empty((oo.TR, nIters, oo.N + 1))
        oo.fs = _N.empty((nIters, oo.C))
        oo.amps = _N.empty((nIters, oo.C))
        if oo.bpsth:
            oo.smp_aS = _N.zeros((nIters, oo.dfPSTH))

        radians = buildLims(oo.Cn, oo.freq_lims, nzLimL=1.)
        oo.AR2lims = 2 * _N.cos(radians)

        if oo.rs < 0:
            oo.smpx = _N.zeros(
                (oo.TR, (oo.N + 1) + 2, oo.k))  #  start at 0 + u
            oo.ws = _N.empty((oo.TR, oo._d.N + 1), dtype=float)

            oo.F_alfa_rep = initF(oo.R, oo.Cs, oo.Cn,
                                  ifs=oo.ifs).tolist()  #  init F_alfa_rep

            print("begin---")
            print(ampAngRep(oo.F_alfa_rep))
            print("begin^^^")
            q20 = 1e-3
            oo.q2 = _N.ones(oo.TR) * q20

            #  AR coefficients from the roots in F_alfa_rep
            oo.F0 = (-1 * _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]
            ########  Limit the amplitude to something reasonable
            xE, nul = createDataAR(oo.N, oo.F0, q20, 0.1)
            mlt = _N.std(xE) / 0.5  #  we want amplitude around 0.5
            oo.q2 /= mlt * mlt
            #  NOTE(review): result unused; the call is kept in case
            #  createDataAR advances a shared RNG (seeded reproducibility).
            xE, nul = createDataAR(oo.N, oo.F0, oo.q2[0], 0.1)

            if oo.model == "Bernoulli":
                oo.initBernoulli()

            ####  initialize ws if starting for first time
            #  initBernoulli may have replaced oo.ws, so force a 2-D shape.
            if oo.TR == 1:
                oo.ws = oo.ws.reshape(1, oo._d.N + 1)
            for m in range(oo._d.TR):
                lw.rpg_devroye(oo.rn,
                               oo.smpx[m, 2:, 0] + oo.u[m],
                               num=(oo.N + 1),
                               out=oo.ws[m, :])

        oo.smp_u[:, 0] = oo.u
        oo.smp_q2[:, 0] = oo.q2

        if oo.bpsth:
            oo.u_a = _N.ones(oo.dfPSTH) * _N.mean(oo.u)
Пример #3
0
    def dirichletAllocate(self):  ###########################  GIBBSSAMP
        """Run the Gibbs sampler that also samples a binary per-trial state Z.

        Each sweep (``it`` = 1 .. ``burn + NMC - 1``) updates, in order:
          * per-trial state ``oo.Z`` for trials in ``oo.varz`` (only once
            ``it > oo.startZ``) and the state weights ``oo.m`` from a
            Dirichlet posterior;
          * Polya-Gamma variables ``oo.ws`` (``lw.rpg_devroye``);
          * history-spline coefficients (unless ``oo.bFixH``);
          * PSTH spline weights ``oo.aS`` (if ``oo.bpsth``);
          * per-trial offsets ``oo.us`` under ``us[0] = -sum(us[1:])``;
          * latent AR state via ``_kfar.armdl_FFBS_1itrMP`` (in a Pool when
            ``oo.processes > 1``);
          * AR roots (unless ``oo.bFixF``), scale ``oo.s[1]`` (if ``oo.doS``)
            and innovation variance ``oo.q2``.
        Every ``oo.peek`` sweeps a diagnostic figure is saved and samples are
        dumped; a final dump happens after the loop.

        NOTE(review): the Z update uses only states 0 and 1 (``oo.m[0]``,
        ``oo.m[1]``, ``sd01[0]``, ``sd01[1]``), i.e. assumes nStates == 2.
        """
        oo = self
        ooTR = oo.TR
        print(ooTR)
        ook = oo.k
        ooNMC = oo.NMC
        ooN = oo.N

        oo.allocateSmp(oo.burn + oo.NMC)
        oo.x00 = _N.array(oo.smpx[:, 2])
        oo.V00 = _N.zeros((ooTR, ook, ook))

        oo.loghist = _N.zeros(oo.N + 1)

        ARo = _N.empty((ooTR, oo._d.N + 1))  #  per-trial history term
        kpOws = _N.empty((ooTR, ooN + 1))  #  kp / ws
        lv_f = _N.zeros((ooN + 1, ooN + 1))
        smWimOm = _N.zeros(ooN + 1)

        #  prior on PSTH spline weights
        D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  #  spline
        iD_f = _N.linalg.inv(D_f)
        BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
        DB = _N.dot(D_f, oo.B)
        BTua = _N.dot(oo.B.T, oo.u_a)

        it = 0

        ###  need pointer to oo.us, but reshaped for broadcasting to work
        #  (reshape of a contiguous 1-D array is a view, so in-place updates
        #  of oo.us below are seen through oous_rs)
        oous_rs = oo.us.reshape((ooTR, 1))

        #  diagonal scale matrices for state 0 and state 1
        sd01 = _N.zeros((oo.nStates, oo.TR, oo.TR))
        _N.fill_diagonal(sd01[0], oo.s[0])
        _N.fill_diagonal(sd01[1], oo.s[1])

        smpx01 = _N.zeros((oo.nStates, oo.TR, oo.N + 1))
        zsmpx = _N.empty((oo.TR, oo.N + 1))  #  latent state scaled by Z

        zd = _N.zeros((oo.TR, oo.TR))
        izd = _N.zeros((oo.TR, oo.TR))
        ll = _N.zeros(oo.nStates)

        for m in range(ooTR):
            oo._d.f_V[m, 0] = oo.s2_x00
            oo._d.f_V[m, 1] = oo.s2_x00

        THR = _N.empty(oo.TR)
        dirArgs = _N.empty(oo.nStates)  #  dirichlet distribution args
        BaS = _N.dot(oo.B.T, oo.aS)  #  PSTH on the time grid

        alpR = oo.F_alfa_rep[0:oo.R]
        alpC = oo.F_alfa_rep[oo.R:]

        oo.nSMP_smpxC = 0
        pool = None
        if oo.processes > 1:
            print(oo.processes)
            pool = Pool(processes=oo.processes)
        print("oo.mcmcRunDir    %s" % oo.mcmcRunDir)
        #  normalise: "" or a path ending in "/"
        if oo.mcmcRunDir is None:
            oo.mcmcRunDir = ""
        elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
            oo.mcmcRunDir += "/"

        #  H shape    100 x 9
        Hbf = oo.Hbf

        #  RHS holds history coefficients: cInds fixed, vInds sampled
        RHS = _N.empty((oo.histknots, 1))

        if oo.h0_1 > 1:  #  first few are 0s
            cInds = _N.array([0, 4, 5, 6, 7, 8, 9])
            vInds = _N.array([1, 2, 3])
            RHS[cInds, 0] = 0
            RHS[0, 0] = -5
        else:
            cInds = _N.array([4, 5, 6, 7, 8, 9])
            vInds = _N.array([0, 1, 2, 3])
            RHS[cInds, 0] = 0

        Msts = [_N.where(oo.y[m] == 1)[0] for m in range(ooTR)]
        HcM = _N.empty((len(vInds), len(vInds)))

        #  HbfExpd[i, m, n]: history basis i evaluated at the time since the
        #  last spike before bin n of trial m.  zeros() (not empty()) so the
        #  bins up to and including the first spike, and trials with no
        #  spikes at all, are 0 instead of uninitialised memory.
        HbfExpd = _N.zeros((oo.histknots, ooTR, oo.N + 1))
        for i in range(oo.histknots):
            for m in range(oo.TR):
                sts = Msts[m]
                for iss in range(len(sts) - 1):
                    tPrev = sts[iss]
                    tNext = sts[iss + 1]
                    HbfExpd[i, m, tPrev + 1:tNext + 1] = Hbf[0:tNext - tPrev, i]

        if oo.hS is None:
            oo.hS = _N.zeros(oo.histknots)

        _N.dot(Hbf, oo.hS, out=oo.loghist)
        oo.stitch_Hist(ARo, oo.loghist, Msts)

        ##  ORDER OF SAMPLING
        ##  f_xx, f_V
        ##  BINARY state
        ##  DA:  PG, kpOws
        ##  history, build ARo
        ##  psth
        ##  offset
        ##  DA:  latent state
        ##  AR coefficients
        ##  q2

        while it < ooNMC + oo.burn - 1:
            tSweep0 = _tm.time()
            it += 1  #  so the first sweep has it == 1
            print("****------------  %d" % it)
            oo._d.f_x[:, 0, :, 0] = oo.x00
            if it == 1:
                for m in range(ooTR):
                    oo._d.f_V[m, 0] = oo.s2_x00
            else:
                oo._d.f_V[:, 0] = _N.mean(oo._d.f_V[:, 1:], axis=1)

            ######  binary state Z
            if it > oo.startZ:
                for tryZ in range(oo.nStates):
                    _N.dot(sd01[tryZ], oo.smpx[..., 2:, 0], out=smpx01[tryZ])

                for m in oo.varz:  #  only allow certain trials to change
                    for tryZ in range(oo.nStates):
                        #  Bernoulli log-likelihood of y[m] given state tryZ:
                        #  sum_n y*psi - log(1 + exp(psi)), computed with
                        #  logaddexp so large |psi| cannot overflow.
                        psi = (smpx01[tryZ, m] + BaS + ARo[m] + oo.us[m] +
                               oo.knownSig[m])
                        ll[tryZ] = _N.sum(oo.y[m] * psi -
                                          _N.logaddexp(0, psi))

                    #  subtract the MAX so every exp() is <= 1 (no overflow)
                    ll -= _N.max(ll)
                    p0 = oo.m[0] * _N.exp(ll[0])
                    nc = p0 + oo.m[1] * _N.exp(ll[1])

                    oo.Z[m, 0] = 0
                    oo.Z[m, 1] = 1
                    THR[m] = p0 / nc  #  posterior P(Z = state 0)
                    if _N.random.rand() < THR[m]:
                        oo.Z[m, 0] = 1
                        oo.Z[m, 1] = 0
                    oo.smp_zs[m, it] = oo.Z[m]

                for m in oo.fxdz:  #####  outside BM loop
                    oo.smp_zs[m, it] = oo.Z[m]

                ######  sample m's
                _N.add(oo.alp, _N.sum(oo.Z[oo.varz], axis=0), out=dirArgs)
                oo.m[:] = _N.random.dirichlet(dirArgs)
                oo.smp_ms[it] = oo.m
            else:
                oo.smp_ms[it] = oo.m
                oo.smp_zs[:, it, 1] = 1
                oo.smp_zs[:, it, 0] = 0

            #  Z set: per-trial scale s[Z] and its inverse
            _N.fill_diagonal(zd, oo.s[oo.Z[:, 1]])
            _N.fill_diagonal(izd, 1. / oo.s[oo.Z[:, 1]])
            _N.dot(zd, oo.smpx[..., 2:, 0], out=zsmpx)
            print(oo.m)

            lwsts = _N.where(oo.Z[:, 0] == 1)[0]
            hists = _N.where(oo.Z[:, 1] == 1)[0]

            ######  PG generate
            for m in range(ooTR):
                lw.rpg_devroye(oo.rn,
                               zsmpx[m] + oo.us[m] + BaS + ARo[m] +
                               oo.knownSig[m],
                               out=oo.ws[m])  ######  devroye

            _N.divide(oo.kp, oo.ws, out=kpOws)

            ######  history coefficients
            if not oo.bFixH:
                O = kpOws - zsmpx - oo.us.reshape(
                    (ooTR, 1)) - BaS - oo.knownSig

                iOf = vInds[0]  #  offset HcM index with RHS index.
                for i in vInds:
                    wHi = oo.ws * HbfExpd[i]
                    for j in vInds:
                        HcM[i - iOf, j - iOf] = _N.sum(wHi * HbfExpd[j])

                    RHS[i, 0] = _N.sum(wHi * O)
                    #  move the fixed (cInds) coefficients to the RHS
                    for cj in cInds:
                        RHS[i, 0] -= _N.sum(wHi * HbfExpd[cj]) * RHS[cj, 0]

                vm = _N.linalg.solve(HcM, RHS[vInds])
                Cov = _N.linalg.inv(HcM)
                print(vm)
                cfs = _N.random.multivariate_normal(vm[:, 0], Cov, size=1)

                RHS[vInds, 0] = cfs[0]
                oo.smp_hS[:, it] = RHS[:, 0]

                _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
                oo.smp_hist[:, it] = oo.loghist
                oo.stitch_Hist(ARo, oo.loghist, Msts)

            ########     PSTH sample  Do PSTH after we generate zs
            if oo.bpsth:
                Oms = kpOws - zsmpx - ARo - oous_rs - oo.knownSig
                _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)  #  sum over trials
                ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                lm_f = _N.dot(lv_f, smWimOm)  #  nondiag of 1./Bi are inf
                #  now sample
                iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                VAR = _N.linalg.inv(iVAR)  #  knots x knots
                Mn = oo.u_a + _N.dot(DB,
                                     _N.linalg.solve(BDB + lv_f, lm_f - BTua))
                oo.aS = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
                oo.smp_aS[it, :] = oo.aS
            else:
                oo.aS[:] = 0
            BaS = _N.dot(oo.B.T, oo.aS)

            ########     per trial offset sample (constraint: sum(us) == 0)
            Ons = kpOws - zsmpx - ARo - BaS - oo.knownSig

            #  solve for the mean of the distribution of us[1:]
            H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
            uRHS = _N.empty(oo.TR - 1)
            for dd in range(1, oo.TR):
                H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
                uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] - oo.ws[0] * Ons[0])

            MM = _N.linalg.solve(H, uRHS)
            Cov = _N.linalg.inv(H)

            oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
            oo.us[0] = -_N.sum(oo.us[1:])
            oo.smp_u[:, it] = oo.us

            ####  Sample latent state
            oo._d.y = _N.dot(izd, kpOws - BaS - ARo - oous_rs - oo.knownSig)
            oo._d.copyParams(oo.F0, oo.q2)
            #  (MxM)  (MxN) = (MxN)  (Rv is MxN)
            _N.dot(_N.dot(izd, izd), 1. / oo.ws, out=oo._d.Rv)

            oo._d.f_x[:, 0, :, 0] = oo.x00
            #  Was an accidental for/else (the `if it == 1:` had been
            #  commented out), which always discarded the s2_x00 reset.
            if it == 1:
                for m in range(ooTR):
                    oo._d.f_V[m, 0] = oo.s2_x00
            else:
                oo._d.f_V[:, 0] = _N.mean(oo._d.f_V[:, 1:], axis=1)

            #  list() so the arguments can be indexed (zip is lazy on Py3)
            tpl_args = list(zip(oo._d.y, oo._d.Rv, oo._d.Fs,
                                _N.linalg.inv(oo._d.Fs), oo.q2, oo._d.Ns,
                                oo._d.ks, oo._d.f_x[:, 0], oo._d.f_V[:, 0]))

            if pool is None:
                for m in range(ooTR):
                    oo.smpx[m, 2:], oo._d.f_x[m], oo._d.f_V[
                        m] = _kfar.armdl_FFBS_1itrMP(tpl_args[m])
                    oo.smpx[m, 1, 0:ook - 1] = oo.smpx[m, 2, 1:]
                    oo.smpx[m, 0, 0:ook - 2] = oo.smpx[m, 2, 2:]
                    oo.smp_q2[m, it] = oo.q2[m]
            else:
                sxv = pool.map(_kfar.armdl_FFBS_1itrMP, tpl_args)
                for m in range(ooTR):
                    oo.smpx[m, 2:] = sxv[m][0]
                    oo._d.f_x[m] = sxv[m][1]
                    oo._d.f_V[m] = sxv[m][2]
                    oo.smpx[m, 1, 0:ook - 1] = oo.smpx[m, 2, 1:]
                    oo.smpx[m, 0, 0:ook - 2] = oo.smpx[m, 2, 2:]

            stds = _N.std(oo.smpx[:, 2:, 0], axis=1)
            oo.mnStds[it] = _N.mean(stds, axis=0)
            print("mnStd  %.3f" % oo.mnStds[it])

            ######  AR coefficients
            if not oo.bFixF:
                ARcfSmpl(oo.lfc,
                         ooN + 1,
                         ook,
                         oo.AR2lims,
                         oo.smpx[:, 1:, 0:ook],
                         oo.smpx[:, :, 0:ook - 1],
                         oo.q2,
                         oo.R,
                         oo.Cs,
                         oo.Cn,
                         alpR,
                         alpC,
                         oo.TR,
                         prior=oo.use_prior,
                         accepts=80,
                         aro=oo.ARord,
                         sig_ph0L=oo.sig_ph0L,
                         sig_ph0H=oo.sig_ph0H)
                oo.F_alfa_rep = alpR + alpC  #  new constructed
                prt, rank, f, amp = ampAngRep(oo.F_alfa_rep, f_order=True)
                print(prt)
                oo.amps[it, :] = amp
                oo.fs[it, :] = f
            oo.allalfas[it] = oo.F_alfa_rep

            oo.F0 = (-1 * _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]

            print("len(lwsts) %(l)d   len(hists) %(h)d" % {
                "l": len(lwsts),
                "h": len(hists)
            })

            ######  scale s[1] of the state-1 trials
            sts2chg = hists
            if (it > oo.startZ) and oo.doS and len(sts2chg) > 0:
                AL = 0.5 * _N.sum(oo.smpx[sts2chg, 2:, 0] *
                                  oo.smpx[sts2chg, 2:, 0] * oo.ws[sts2chg])
                BRL = kpOws[sts2chg] - BaS - oous_rs[sts2chg] - ARo[
                    sts2chg] - oo.knownSig[sts2chg]
                BL = _N.sum(oo.ws[sts2chg] * BRL * oo.smpx[sts2chg, 2:, 0])
                UL = BL / (2 * AL)
                sg2 = 1. / (2 * AL)

                #  combine likelihood with a Gaussian prior N(u_pr, q2_pr)
                q2_pr = 0.0025  # 0.05**2
                u_pr = 1.
                U = (u_pr * sg2 + UL * q2_pr) / (sg2 + q2_pr)
                sg = _N.sqrt((sg2 * q2_pr) / (sg2 + q2_pr))

                print("U  %(U).4f    UL %(UL).4f s  %(s).3f" % {
                    "U": U,
                    "s": sg,
                    "UL": UL
                })
                if _N.isnan(U):
                    print("U is nan  UL %.4f" % UL)
                    print("U is nan  AL %.4f" % AL)
                    print("U is nan  BL %.4f" % BL)
                    print("U is nan  BaS ")
                    print("hists")
                    print(hists)
                    print("lwsts")
                    print(lwsts)

                oo.s[1] = U + sg * _N.random.randn()

                _N.fill_diagonal(sd01[0], oo.s[0])
                _N.fill_diagonal(sd01[1], oo.s[1])
                print(oo.s[1])
                oo.smp_ss[it] = oo.s[1]

            #####################    sample q2
            oo.a2 = oo.a_q2 + 0.5 * (ooTR * ooN + 2)  #  N + 1 - 1
            ssr = 0.  #  sum of squared AR residuals over all trials
            for m in range(ooTR):
                #   set x00
                oo.x00[m] = oo.smpx[m, 2] * 0.001
                rsd_stp = oo.smpx[m, 3:, 0] - _N.dot(oo.smpx[m, 2:-1], oo.F0).T
                ssr += _N.dot(rsd_stp, rsd_stp.T)
            #  new object: never mutates oo.B_q2 in place
            BB2 = oo.B_q2 + 0.5 * ssr
            oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)

            oo.smp_q2[:, it] = oo.q2
            tSweep1 = _tm.time()
            print("gibbs iter %.3f" % (tSweep1 - tSweep0))
            if (it > 1) and (it % oo.peek == 0):
                fig = _plt.figure(figsize=(12, 8))
                fig.add_subplot(4, 1, 1)
                _plt.plot(oo.amps[1:it, 0])
                fig.add_subplot(4, 1, 2)
                _plt.plot(oo.fs[1:it, 0])
                fig.add_subplot(4, 1, 3)
                _plt.plot(oo.mnStds[1:it])
                fig.add_subplot(4, 1, 4)
                _plt.plot(oo.smp_ms[1:it, 0])

                #  mcmcRunDir is "" or ends in "/": no extra separator, so an
                #  empty dir means the cwd rather than the filesystem root.
                _plt.savefig("%(dir)stmp-fsamps%(it)d" % {
                    "dir": oo.mcmcRunDir,
                    "it": it
                })
                _plt.close()

                oo.dump_smpsS(toiter=it, dir=oo.mcmcRunDir)

        if pool is not None:  #  release worker processes
            pool.close()
            pool.join()
        oo.dump_smpsS(dir=oo.mcmcRunDir)
Пример #4
0
    def gibbsSamp(self):  ###########################  GIBBSSAMPH
        oo = self

        ooTR = oo.TR
        ook = oo.k

        ooN = oo.N
        _kfar.init(oo.N, oo.k, oo.TR)
        oo.x00 = _N.array(oo.smpx[:, 2])
        oo.V00 = _N.zeros((ooTR, ook, ook))
        if oo.dohist:
            oo.loghist = _N.zeros(oo.N + 1)
        else:
            print "fixed hist is"
            print oo.loghist

        print "oo.mcmcRunDir    %s" % oo.mcmcRunDir
        if oo.mcmcRunDir is None:
            oo.mcmcRunDir = ""
        elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
            oo.mcmcRunDir += "/"

        ARo = _N.zeros((ooTR, ooN + 1))

        kpOws = _N.empty((ooTR, ooN + 1))
        lv_f = _N.zeros((ooN + 1, ooN + 1))
        lv_u = _N.zeros((ooTR, ooTR))
        Bii = _N.zeros((ooN + 1, ooN + 1))

        #alpC.reverse()
        #  F_alfa_rep = alpR + alpC  already in right order, no?

        Wims = _N.empty((ooTR, ooN + 1, ooN + 1))
        Oms = _N.empty((ooTR, ooN + 1))
        smWimOm = _N.zeros(ooN + 1)
        smWinOn = _N.zeros(ooTR)
        bConstPSTH = False

        D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  #  spline
        iD_f = _N.linalg.inv(D_f)
        D_u = _N.diag(_N.ones(oo.TR) * oo.s2_u)  #  This should
        iD_u = _N.linalg.inv(D_u)
        iD_u_u_u = _N.dot(iD_u, _N.ones(oo.TR) * oo.u_u)

        if oo.bpsth:
            BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
            DB = _N.dot(D_f, oo.B)
            BTua = _N.dot(oo.B.T, oo.u_a)

        it = -1

        oous_rs = oo.us.reshape((ooTR, 1))
        #runTO = ooNMC + oo.burn - 1 if (burns is None) else (burns - 1)
        runTO = oo.ITERS - 1
        oo.allocateSmp(runTO + 1, Bsmpx=oo.doBsmpx)
        alpR = oo.F_alfa_rep[0:oo.R]
        alpC = oo.F_alfa_rep[oo.R:]

        BaS = _N.zeros(oo.N + 1)  #_N.empty(oo.N+1)

        #  H shape    100 x 9
        Hbf = oo.Hbf

        RHS = _N.empty((oo.histknots, 1))

        print "-----------    histknots %d" % oo.histknots
        if oo.h0_1 > 1:  #  no spikes in first few time bins
            #cInds = _N.array([0, 1, 5, 6, 7, 8, 9, 10])
            cInds = _N.array([0, 4, 5, 6, 7, 8, 9])
            #vInds = _N.array([2, 3, 4])
            vInds = _N.array([
                1,
                2,
                3,
            ])
            RHS[cInds, 0] = 0
            RHS[0, 0] = -5
        elif oo.hist_max_at_0:  #  no refractory period
            #cInds = _N.array([5, 6, 7, 8, 9, 10])
            cInds = _N.array([
                3,
                4,
                5,
                6,
                7,
                8,
            ])
            vInds = _N.array([
                0,
                1,
                2,
            ])
            #vInds = _N.array([0, 1, 2, 3, 4])
            RHS[cInds, 0] = 0
        else:
            #cInds = _N.array([5, 6, 7, 8, 9, 10])
            cInds = _N.array([
                4,
                5,
                6,
                7,
                8,
                9,
            ])
            vInds = _N.array([
                0,
                1,
                2,
                3,
            ])
            #vInds = _N.array([0, 1, 2, 3, 4])
            RHS[cInds, 0] = 0

        Msts = []
        for m in xrange(ooTR):
            Msts.append(_N.where(oo.y[m] == 1)[0])
        HcM = _N.empty((len(vInds), len(vInds)))

        HbfExpd = _N.empty((oo.histknots, ooTR, oo.N + 1))
        #  HbfExpd is 11 x M x 1200
        #  find the mean.  For the HISTORY TERM
        for i in xrange(oo.histknots):
            for m in xrange(oo.TR):
                sts = Msts[m]
                HbfExpd[i, m, 0:sts[0]] = 0
                for iss in xrange(len(sts) - 1):
                    t0 = sts[iss]
                    t1 = sts[iss + 1]
                    HbfExpd[i, m, t0 + 1:t1 + 1] = Hbf[0:t1 - t0, i]
                HbfExpd[i, m, sts[-1] + 1:] = 0

        _N.dot(oo.B.T, oo.aS, out=BaS)
        if oo.hS is None:
            oo.hS = _N.zeros(oo.histknots)

        if oo.dohist:
            _N.dot(Hbf, oo.hS, out=oo.loghist)
        oo.stitch_Hist(ARo, oo.loghist, Msts)

        ##  ORDER OF SAMPLING
        ##  f_xx, f_V
        ##  DA:  PG, kpOws
        ##  history, build ARo
        ##  psth
        ##  offset
        ##  DA:  latent state
        ##  AR coefficients
        ##  q2

        K = _N.empty((oo.TR, oo.N + 1, oo.k))  #  kalman gain

        iterBLOCKS = oo.ITERS / oo.peek
        smpx_tmp = _N.empty((oo.TR, oo.N + 1, oo.k))
        for itrB in xrange(iterBLOCKS):
            for it in xrange(itrB * oo.peek, (itrB + 1) * oo.peek):
                ttt1 = _tm.time()

                if (it % 10) == 0:
                    print it
                #  generate latent AR state
                oo.f_x[:, 0] = oo.x00
                if it == 0:
                    for m in xrange(ooTR):
                        oo.f_V[m, 0] = oo.s2_x00
                else:
                    oo.f_V[:, 0] = _N.mean(oo.f_V[:, 1:], axis=1)

                ###  PG latent variable sample
                ttt2 = _tm.time()

                for m in xrange(ooTR):
                    lw.rpg_devroye(oo.rn,
                                   oo.smpx[m, 2:, 0] + oo.us[m] + BaS +
                                   ARo[m] + oo.knownSig[m],
                                   out=oo.ws[m])  ######  devryoe
                ttt3 = _tm.time()

                if ooTR == 1:
                    oo.ws = oo.ws.reshape(1, ooN + 1)
                _N.divide(oo.kp, oo.ws, out=kpOws)

                if oo.dohist:
                    O = kpOws - oo.smpx[..., 2:, 0] - oo.us.reshape(
                        (ooTR, 1)) - BaS - oo.knownSig

                    iOf = vInds[0]  #  offset HcM index with RHS index.
                    for i in vInds:
                        for j in vInds:
                            HcM[i - iOf, j - iOf] = _N.sum(oo.ws * HbfExpd[i] *
                                                           HbfExpd[j])

                        RHS[i, 0] = _N.sum(oo.ws * HbfExpd[i] * O)
                        for cj in cInds:
                            RHS[i, 0] -= _N.sum(
                                oo.ws * HbfExpd[i] * HbfExpd[cj]) * RHS[cj, 0]

                    # print HbfExpd
                    # print HcM
                    # print RHS[vInds]
                    vm = _N.linalg.solve(HcM, RHS[vInds])
                    Cov = _N.linalg.inv(HcM)
                    #print vm
                    cfs = _N.random.multivariate_normal(vm[:, 0], Cov, size=1)

                    RHS[vInds, 0] = cfs[0]
                    oo.smp_hS[:, it] = RHS[:, 0]

                    #RHS[2:6, 0] = vm[:, 0]
                    #vv = _N.dot(Hbf, RHS)
                    #print vv.shape
                    #print oo.loghist.shape
                    _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
                    oo.smp_hist[:, it] = oo.loghist
                    oo.stitch_Hist(ARo, oo.loghist, Msts)
                else:
                    oo.smp_hist[:, it] = oo.loghist
                    oo.stitch_Hist(ARo, oo.loghist, Msts)

                #  Now that we have PG variables, construct Gaussian timeseries
                #  ws(it+1)    using u(it), F0(it), smpx(it)

                #  cov matrix, prior of aS

                oo.gau_obs = kpOws - BaS - ARo - oous_rs - oo.knownSig

                oo.gau_var = 1 / oo.ws  #  time dependent noise

                if oo.bpsth:
                    Oms = kpOws - oo.smpx[..., 2:,
                                          0] - ARo - oous_rs - oo.knownSig
                    _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)  #  sum over
                    ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                    #  diag(_N.linalg.inv(Bi)) == diag(1./Bi).  Bii = inv(Bi)
                    _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                    lm_f = _N.dot(lv_f, smWimOm)  #  nondiag of 1./Bi are inf
                    #  now sample
                    iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                    t4a = _tm.time()
                    VAR = _N.linalg.inv(iVAR)  #  knots x knots
                    t4b = _tm.time()
                    #iBDBW = _N.linalg.inv(BDB + lv_f)   # BDB not diag
                    #Mn    = oo.u_a + _N.dot(DB, _N.dot(iBDBW, lm_f - BTua))

                    #  BDB + lv_f     (N+1 x N+1)
                    #  lm_f - BTua    (N+1)
                    Mn = oo.u_a + _N.dot(
                        DB, _N.linalg.solve(BDB + lv_f, lm_f - BTua))

                    t4c = _tm.time()

                    oo.aS = _N.random.multivariate_normal(Mn, VAR,
                                                          size=1)[0, :]
                    oo.smp_aS[it, :] = oo.aS
                    _N.dot(oo.B.T, oo.aS, out=BaS)

                ttt4 = _tm.time()
                ########     per trial offset sample  burns==None, only psth fit
                Ons = kpOws - oo.smpx[..., 2:, 0] - ARo - BaS - oo.knownSig

                #  solve for the mean of the distribution

                if not oo.bpsth:  # if not doing PSTH, don't constrain offset, as there are no confounds controlling offset
                    _N.einsum("mn,mn->m", oo.ws, Ons,
                              out=smWinOn)  #  sum over trials
                    ilv_u = _N.diag(_N.sum(oo.ws, axis=1))  #  var  of LL
                    #  diag(_N.linalg.inv(Bi)) == diag(1./Bi).  Bii = inv(Bi)
                    _N.fill_diagonal(lv_u, 1. / _N.diagonal(ilv_u))
                    lm_u = _N.dot(
                        lv_u, smWinOn)  #  nondiag of 1./Bi are inf, mean LL
                    #  now sample
                    iVAR = ilv_u + iD_u
                    VAR = _N.linalg.inv(iVAR)  #
                    Mn = _N.dot(VAR, _N.dot(ilv_u, lm_u) + iD_u_u_u)
                    oo.us[:] = _N.random.multivariate_normal(Mn, VAR,
                                                             size=1)[0, :]
                    if not oo.bIndOffset:
                        oo.us[:] = _N.mean(oo.us)
                    oo.smp_u[:, it] = oo.us
                else:
                    H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
                    uRHS = _N.empty(oo.TR - 1)
                    for dd in xrange(1, oo.TR):
                        H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
                        uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] -
                                              oo.ws[0] * Ons[0])

                    MM = _N.linalg.solve(H, uRHS)
                    Cov = _N.linalg.inv(H)

                    oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
                    oo.us[0] = -_N.sum(oo.us[1:])
                    if not oo.bIndOffset:
                        oo.us[:] = _N.mean(oo.us)
                    oo.smp_u[:, it] = oo.us

                # Ons  = kpOws - ARo
                # _N.einsum("mn,mn->m", oo.ws, Ons, out=smWinOn)  #  sum over trials
                # ilv_u  = _N.diag(_N.sum(oo.ws, axis=1))  #  var  of LL
                # #  diag(_N.linalg.inv(Bi)) == diag(1./Bi).  Bii = inv(Bi)
                # _N.fill_diagonal(lv_u, 1./_N.diagonal(ilv_u))
                # lm_u  = _N.dot(lv_u, smWinOn)  #  nondiag of 1./Bi are inf, mean LL
                # #  now sample
                # iVAR = ilv_u + iD_u
                # VAR  = _N.linalg.inv(iVAR)  #
                # Mn    = _N.dot(VAR, _N.dot(ilv_u, lm_u) + iD_u_u_u)
                # oo.us[:]  = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
                # if not oo.bIndOffset:
                #     oo.us[:] = _N.mean(oo.us)
                # oo.smp_u[:, it] = oo.us

                ttt5 = _tm.time()
                if not oo.noAR:
                    #  _d.F, _d.N, _d.ks,
                    #_kfar.armdl_FFBS_1itrMP(oo.gau_obs, oo.gau_var, oo.Fs, _N.linalg.inv(oo.Fs), oo.q2, oo.Ns, oo.ks, oo.f_x, oo.f_V, oo.p_x, oo.p_V, oo.smpx, K)
                    _kfar.armdl_FFBS_1itrMP(oo.gau_obs, oo.gau_var, oo.Fs,
                                            _N.linalg.inv(oo.Fs), oo.q2, oo.Ns,
                                            oo.ks, oo.f_x, oo.f_V, oo.p_x,
                                            oo.p_V, smpx_tmp, K)

                    oo.smpx[:, 2:] = smpx_tmp
                    oo.smpx[:, 1, 0:ook - 1] = oo.smpx[:, 2, 1:]
                    oo.smpx[:, 0, 0:ook - 2] = oo.smpx[:, 2, 2:]

                    if oo.doBsmpx and (it % oo.BsmpxSkp == 0):
                        oo.Bsmpx[:, it / oo.BsmpxSkp, 2:] = oo.smpx[:, 2:, 0]
                    stds = _N.std(oo.smpx[:, 2 + oo.ignr:, 0], axis=1)
                    oo.mnStds[it] = _N.mean(stds, axis=0)
                    print "mnStd  %.3f" % oo.mnStds[it]

                    ttt6 = _tm.time()
                    if not oo.bFixF:
                        ARcfSmpl(oo.lfc,
                                 ooN + 1 - oo.ignr,
                                 ook,
                                 oo.AR2lims,
                                 oo.smpx[:, 1 + oo.ignr:, 0:ook],
                                 oo.smpx[:, oo.ignr:, 0:ook - 1],
                                 oo.q2,
                                 oo.R,
                                 oo.Cs,
                                 oo.Cn,
                                 alpR,
                                 alpC,
                                 oo.TR,
                                 prior=oo.use_prior,
                                 accepts=50,
                                 aro=oo.ARord,
                                 sig_ph0L=oo.sig_ph0L,
                                 sig_ph0H=oo.sig_ph0H)
                        oo.F_alfa_rep = alpR + alpC  #  new constructed
                        prt, rank, f, amp = ampAngRep(oo.F_alfa_rep,
                                                      f_order=True)
                        #print prt
                    #ut, wt = FilteredTimeseries(ooN+1, ook, oo.smpx[:, 1:, 0:ook], oo.smpx[:, :, 0:ook-1], oo.q2, oo.R, oo.Cs, oo.Cn, alpR, alpC, oo.TR)
                    #ranks[it]    = rank
                    oo.allalfas[it] = oo.F_alfa_rep

                    for m in xrange(ooTR):
                        #oo.wts[m, it, :, :]   = wt[m, :, :, 0]
                        #oo.uts[m, it, :, :]   = ut[m, :, :, 0]
                        if not oo.bFixF:
                            oo.amps[it, :] = amp
                            oo.fs[it, :] = f

                    oo.F0 = (-1 *
                             _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]
                    for tr in xrange(oo.TR):
                        oo.Fs[tr, 0] = oo.F0[:]

                    #  sample u     WE USED TO Do this after smpx
                    #  u(it+1)    using ws(it+1), F0(it), smpx(it+1), ws(it+1)

                    if oo.ID_q2:
                        for m in xrange(ooTR):
                            #####################    sample q2
                            a = oo.a_q2 + 0.5 * (ooN + 1)  #  N + 1 - 1
                            rsd_stp = oo.smpx[m, 3 + oo.ignr:, 0] - _N.dot(
                                oo.smpx[m, 2 + oo.ignr:-1], oo.F0).T
                            BB = oo.B_q2 + 0.5 * _N.dot(rsd_stp, rsd_stp.T)
                            oo.q2[m] = _ss.invgamma.rvs(a, scale=BB)
                            oo.x00[m] = oo.smpx[m, 2] * 0.1
                            oo.smp_q2[m, it] = oo.q2[m]
                    else:
                        #oo.a2 = oo.a_q2 + 0.5*(ooTR*ooN + 2)  #  N + 1 - 1
                        oo.a2 = 0.5 * (ooTR *
                                       (ooN - oo.ignr) + 2)  #  N + 1 - 1
                        #BB2 = oo.B_q2
                        BB2 = 0
                        for m in xrange(ooTR):
                            #   set x00
                            oo.x00[m] = oo.smpx[m, 2] * 0.1

                            #####################    sample q2
                            rsd_stp = oo.smpx[m, 3 + oo.ignr:, 0] - _N.dot(
                                oo.smpx[m, 2 + oo.ignr:-1], oo.F0).T
                            #oo.rsds[it, m] = _N.dot(rsd_stp, rsd_stp.T)
                            BB2 += 0.5 * _N.dot(rsd_stp, rsd_stp.T)
                        oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)

                    oo.smp_q2[:, it] = oo.q2

                ttt7 = _tm.time()

    #             print ("t2-t1  %.4f" % (t2-t1))
    #             print ("t3-t2  %.4f" % (t3-t2))
    #             print ("t4-t3  %.4f" % (t4-t3))
    # #            print ("t4b-t4a  %.4f" % (t4b-t4a))
    # #            print ("t4c-t4b  %.4f" % (t4c-t4b))
    # #            print ("t4-t4c  %.4f" % (t4-t4c))
    #             print ("t5-t4  %.4f" % (t5-t4))
    #             print ("t6-t5  %.4f" % (t6-t5))
    #             print ("t7-t6  %.4f" % (t7-t6))

            if it > oo.minITERS:
                smps = _N.empty((3, it + 1))
                smps[0, :it + 1] = oo.amps[:it + 1, 0]
                smps[1, :it + 1] = oo.fs[:it + 1, 0]
                smps[2, :it + 1] = oo.mnStds[:it + 1]

                frms = _mg.stationary_from_Z_bckwd(smps, blksz=oo.peek)

                fig = _plt.figure(figsize=(8, 8))
                fig.add_subplot(3, 1, 1)
                _plt.plot(range(1, it), oo.amps[1:it, 0], color="grey", lw=1.5)
                _plt.plot(range(frms, it),
                          oo.amps[frms:it, 0],
                          color="black",
                          lw=3)
                _plt.ylabel("amp")
                fig.add_subplot(3, 1, 2)
                _plt.plot(range(1, it),
                          oo.fs[1:it, 0] / (2 * oo.dt),
                          color="grey",
                          lw=1.5)
                _plt.plot(range(frms, it),
                          oo.fs[frms:it, 0] / (2 * oo.dt),
                          color="black",
                          lw=3)
                _plt.ylabel("f")
                fig.add_subplot(3, 1, 3)
                _plt.plot(range(1, it), oo.mnStds[1:it], color="grey", lw=1.5)
                _plt.plot(range(frms, it),
                          oo.mnStds[frms:it],
                          color="black",
                          lw=3)
                _plt.ylabel("amp")
                _plt.xlabel("iter")
                _plt.savefig("%(dir)stmp-fsamps%(it)d" % {
                    "dir": oo.mcmcRunDir,
                    "it": it + 1
                })
                fig.subplots_adjust(left=0.15,
                                    bottom=0.15,
                                    right=0.95,
                                    top=0.95)
                _plt.close()

                if it - frms > oo.stationaryDuration:
                    break

        oo.dump_smps(frms, toiter=(it + 1), dir=oo.mcmcRunDir)
        oo.VIS = ARo