Beispiel #1
0
    def gibbsSamp(
            self,
            smpls_fn_incl_trls=False):  ###########################  GIBBSSAMPH
        """Run the blocked Gibbs sampler, then dump the collected samples.

        ``signal_handler`` is installed for SIGINT while sampling runs; the
        sampler checks the module-level ``interrupted`` flag once per block of
        ``self.peek`` iterations (the handler presumably sets that flag —
        confirm against its definition).  The previous SIGINT handler is
        restored afterwards, even if sampling raises.

        Parameters
        ----------
        smpls_fn_incl_trls : bool
            Passed through unchanged to ``self.dump_smps``.
        """
        prev_handler = signal.signal(signal.SIGINT, signal_handler)
        try:
            self._gibbsSamp_impl(smpls_fn_incl_trls)
        finally:
            # signal.signal returns None when the previous handler was not
            # installed from Python; such a handler cannot be re-installed.
            if prev_handler is not None:
                signal.signal(signal.SIGINT, prev_handler)

    def _gibbsSamp_impl(self, smpls_fn_incl_trls):
        """Body of ``gibbsSamp``; expects the SIGINT handler to be installed.

        Each iteration samples, in order: the Polya-Gamma latent variables
        ``oo.ws`` (via ``lw.rpg_devroye``), the history coefficients (if
        ``oo.dohist``), the PSTH spline weights ``oo.aS`` (if ``oo.bpsth``),
        the per-trial offsets ``oo.us`` and, unless ``oo.noAR``, the latent
        AR state (``_kfar.armdl_FFBS_1itrMP``), the AR roots
        (``_arcfs.ARcfSmpl``) and the innovation variance ``oo.q2``.

        Iterations run in blocks of ``oo.peek``.  Between blocks sampling
        stops early if ``interrupted`` is set, or, once ``it > oo.minITERS``,
        when ``_mg.stationary_test`` returns true.  Only fully completed
        iterations are passed to ``dump_smps``.
        """
        global interrupted
        oo = self

        print("****!!!!!!!!!!!!!!!!  dohist  %s" % str(oo.dohist))

        ooTR = oo.TR
        ook = oo.k
        ooN = oo.N

        _kfar.init(oo.N, oo.k, oo.TR)
        oo.x00 = _N.array(oo.smpx[:, 2])
        oo.V00 = _N.zeros((ooTR, ook, ook))
        if oo.dohist:
            oo.loghist = _N.zeros(oo.Hbf.shape[0])
        else:
            print("fixed hist is")
            print(oo.loghist)

        print("oo.mcmcRunDir    %s" % oo.mcmcRunDir)
        if oo.mcmcRunDir is None:
            oo.mcmcRunDir = ""
        elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
            oo.mcmcRunDir += "/"

        # history contribution per (trial, bin); filled by oo.stitch_Hist
        ARo = _N.zeros((ooTR, ooN + 1))

        kpOws = _N.empty((ooTR, ooN + 1))
        lv_f = _N.zeros((ooN + 1, ooN + 1))
        lv_u = _N.zeros((ooTR, ooTR))
        smWimOm = _N.zeros(ooN + 1)
        smWinOn = _N.zeros(ooTR)

        # Gaussian priors: spline weights (PSTH) and per-trial offsets
        D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  #  spline
        iD_f = _N.linalg.inv(D_f)
        D_u = _N.diag(_N.ones(oo.TR) * oo.s2_u)
        iD_u = _N.linalg.inv(D_u)
        iD_u_u_u = _N.dot(iD_u, _N.ones(oo.TR) * oo.u_u)

        if oo.bpsth:
            BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
            DB = _N.dot(D_f, oo.B)
            BTua = _N.dot(oo.B.T, oo.u_a)

        it = -1

        # view of oo.us shaped for broadcasting; oo.us is only updated in place
        oous_rs = oo.us.reshape((ooTR, 1))
        runTO = oo.ITERS - 1
        oo.allocateSmp(runTO + 1, Bsmpx=oo.doBsmpx)
        if cython_arc:
            _arcfs.init(ooN + 1 - oo.ignr,
                        oo.k,
                        oo.TR,
                        oo.R,
                        oo.Cs,
                        oo.Cn,
                        aro=_cd.__NF__)
            alpR = _N.array(oo.F_alfa_rep[0:oo.R])
            alpC = _N.array(oo.F_alfa_rep[oo.R:])
        else:
            alpR = oo.F_alfa_rep[0:oo.R]
            alpC = oo.F_alfa_rep[oo.R:]

        BaS = _N.zeros(oo.N + 1)

        # history basis: rows index lag, column i is history knot i
        Hbf = oo.Hbf

        RHS = _N.empty((oo.histknots, 1))

        print("-----------    histknots %d" % oo.histknots)

        # knots >= iHistKnotBeginFixed are held fixed at 0; the rest are sampled
        cInds = _N.arange(oo.iHistKnotBeginFixed, oo.histknots)
        vInds = _N.arange(0, oo.iHistKnotBeginFixed)
        nv = len(vInds)

        RHS[cInds, 0] = 0

        # spike bins (where y == 1) of each trial
        Msts = [_N.where(oo.y[m] == 1)[0] for m in range(ooTR)]
        HcM = _N.ones((nv, nv))
        rhs_v = _N.empty(nv)  # right-hand side for the free knots only

        # HbfExpd[i, m, t]: history basis i at bin t of trial m, restarted
        # after every spike; zero before the first and after the last spike.
        # A trial with fewer than two spikes therefore stays all zeros.
        HbfExpd = _N.zeros((oo.histknots, ooTR, oo.N + 1))
        for m in range(ooTR):
            sts = Msts[m]
            for t0, t1 in zip(sts[:-1], sts[1:]):
                HbfExpd[:, m, t0 + 1:t1 + 1] = Hbf[0:t1 - t0,
                                                   0:oo.histknots].T

        _N.dot(oo.B.T, oo.aS, out=BaS)
        if oo.hS is None:
            oo.hS = _N.zeros(oo.histknots)

        if oo.dohist:
            _N.dot(Hbf, oo.hS, out=oo.loghist)
        oo.stitch_Hist(ARo, oo.loghist, Msts)

        ##  ORDER OF SAMPLING
        ##  f_xx, f_V
        ##  DA:  PG, kpOws
        ##  history, build ARo
        ##  psth
        ##  offset
        ##  DA:  latent state
        ##  AR coefficients
        ##  q2

        K = _N.empty((oo.TR, oo.N + 1, oo.k))  #  kalman gain

        iterBLOCKS = oo.ITERS // oo.peek
        smpx_C_cont = _N.empty((oo.TR, oo.N + 1, oo.k))  #  need C contiguous

        #  C-contiguous copies of oo.smpx[:, 1+oo.ignr:] and
        #  oo.smpx[:, oo.ignr:, 0:ook-1] for the cython ARcfSmpl
        smpx_contiguous1 = _N.zeros((oo.TR, oo.N + 2, oo.k))
        smpx_contiguous2 = _N.zeros((oo.TR, (oo.N + 1) + 2, oo.k - 1))
        if (cython_inv_v == 3) or (cython_inv_v == 5):
            oo.if_V = _N.array(oo.f_V)
            oo.chol_L_fV = _N.array(oo.f_V)

        ######  Gibbs sampling procedure
        ttts = _N.zeros((oo.ITERS, 9))  # per-iteration wall time of 9 stages
        n_done = 0  # number of fully completed iterations
        for itrB in range(iterBLOCKS):
            it = itrB * oo.peek
            if it > 0:
                print("---------it: %(it)d    mnStd  %(mnstd).3f" % {
                    "it": itrB * oo.peek,
                    "mnstd": oo.mnStds[it - 1]
                })
                if not oo.noAR:
                    print(prt)
                # iterations 0 .. it-1 have completed; average all of them
                mnttt = _N.mean(ttts[0:it], axis=0)
                for ti in range(9):
                    print("t%(2)d-t%(1)d  %(ttt).4f" % {
                        "1": ti + 1,
                        "2": ti + 2,
                        "ttt": mnttt[ti]
                    })

            if interrupted:
                break
            for it in range(itrB * oo.peek, (itrB + 1) * oo.peek):
                ttt1 = _tm.time()

                itstore = it // oo.BsmpxSkp

                #  initial filter state for the latent AR process
                oo.f_x[:, 0] = oo.x00
                if it == 0:
                    for m in range(ooTR):
                        oo.f_V[m, 0] = oo.s2_x00
                else:
                    oo.f_V[:, 0] = _N.mean(oo.f_V[:, 1:], axis=1)

                ###  PG latent variable sample
                ttt2 = _tm.time()

                for m in range(ooTR):
                    lw.rpg_devroye(oo.rn,
                                   oo.smpx[m, 2:, 0] + oo.us[m] + BaS +
                                   ARo[m] + oo.knownSig[m],
                                   out=oo.ws[m])  ######  devroye
                ttt3 = _tm.time()

                if ooTR == 1:
                    oo.ws = oo.ws.reshape(1, ooN + 1)
                _N.divide(oo.kp, oo.ws, out=kpOws)

                if oo.dohist:
                    #  pseudo-observations with every non-history term removed
                    Oh = kpOws - oo.smpx[..., 2:, 0] - oous_rs - BaS \
                        - oo.knownSig

                    #  Gaussian posterior of the free knots vInds given the
                    #  fixed knots cInds.  rhs_v is separate from RHS so the
                    #  fixed entries RHS[cInds] can never be overwritten.
                    for ii, i in enumerate(vInds):
                        wHi = oo.ws * HbfExpd[i]
                        for jj, j in enumerate(vInds):
                            HcM[ii, jj] = _N.sum(wHi * HbfExpd[j])

                        rhs_v[ii] = _N.sum(wHi * Oh)
                        for cj in cInds:
                            rhs_v[ii] -= _N.sum(wHi * HbfExpd[cj]) * RHS[cj, 0]

                    vm = _N.linalg.solve(HcM, rhs_v)
                    Cov = _N.linalg.inv(HcM)
                    cfs = _N.random.multivariate_normal(vm, Cov, size=1)

                    RHS[vInds, 0] = cfs[0]
                    oo.smp_hS[it] = RHS[:, 0]

                    _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
                oo.smp_hist[it] = oo.loghist
                oo.stitch_Hist(ARo, oo.loghist, Msts)

                #  Now that we have PG variables, construct Gaussian timeseries
                ttt4 = _tm.time()
                if oo.bpsth:
                    Oms = kpOws - oo.smpx[..., 2:, 0] - ARo - oous_rs \
                        - oo.knownSig
                    #  sum over trials
                    _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)
                    ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                    #  diag(_N.linalg.inv(Bi)) == diag(1./Bi)
                    _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                    lm_f = _N.dot(lv_f, smWimOm)
                    #  now sample
                    iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                    VAR = _N.linalg.inv(iVAR)  #  knots x knots

                    #  posterior mean via an (N+1) x (N+1) linear solve
                    Mn = oo.u_a + _N.dot(
                        DB, _N.linalg.solve(BDB + lv_f, lm_f - BTua))

                    oo.aS = _N.random.multivariate_normal(Mn, VAR,
                                                          size=1)[0, :]
                    oo.smp_aS[it] = oo.aS
                    _N.dot(oo.B.T, oo.aS, out=BaS)

                ttt5 = _tm.time()
                ########     per trial offset sample
                Ons = kpOws - oo.smpx[..., 2:, 0] - ARo - BaS - oo.knownSig

                if not oo.bpsth:
                    #  if not doing PSTH, don't constrain offset, as there are
                    #  no confounds controlling offset
                    #  sum over time bins
                    _N.einsum("mn,mn->m", oo.ws, Ons, out=smWinOn)
                    ilv_u = _N.diag(_N.sum(oo.ws, axis=1))  #  var  of LL
                    _N.fill_diagonal(lv_u, 1. / _N.diagonal(ilv_u))
                    lm_u = _N.dot(lv_u, smWinOn)  #  mean LL
                    #  now sample
                    iVAR = ilv_u + iD_u
                    VAR = _N.linalg.inv(iVAR)
                    Mn = _N.dot(VAR, _N.dot(ilv_u, lm_u) + iD_u_u_u)
                    oo.us[:] = _N.random.multivariate_normal(Mn, VAR,
                                                             size=1)[0, :]
                else:
                    #  with a PSTH the offsets are constrained to sum to
                    #  zero: sample us[1:], then us[0] = -sum(us[1:])
                    H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
                    uRHS = _N.empty(oo.TR - 1)
                    for dd in range(1, oo.TR):
                        H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
                        uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] -
                                              oo.ws[0] * Ons[0])

                    MM = _N.linalg.solve(H, uRHS)
                    Cov = _N.linalg.inv(H)

                    oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
                    oo.us[0] = -_N.sum(oo.us[1:])
                if not oo.bIndOffset:
                    oo.us[:] = _N.mean(oo.us)
                oo.smp_u[it] = oo.us

                ttt6 = _tm.time()
                if not oo.noAR:
                    oo.gau_obs = kpOws - BaS - ARo - oous_rs - oo.knownSig
                    oo.gau_var = 1 / oo.ws  #  time dependent noise

                    if (cython_inv_v == 2):
                        _kfar.armdl_FFBS_1itrMP(oo.gau_obs, oo.gau_var, oo.Fs,
                                                _N.linalg.inv(oo.Fs), oo.q2,
                                                oo.Ns, oo.ks, oo.f_x, oo.f_V,
                                                oo.p_x, oo.p_V, smpx_C_cont, K)
                    else:
                        _kfar.armdl_FFBS_1itrMP(oo.gau_obs, oo.gau_var, oo.Fs,
                                                _N.linalg.inv(oo.Fs), oo.q2,
                                                oo.Ns, oo.ks, oo.f_x, oo.f_V,
                                                oo.chol_L_fV, oo.if_V, oo.p_x,
                                                oo.p_V, smpx_C_cont, K)

                    #  rows 0 and 1 hold lagged copies of row 2
                    oo.smpx[:, 2:] = smpx_C_cont
                    oo.smpx[:, 1, 0:ook - 1] = oo.smpx[:, 2, 1:]
                    oo.smpx[:, 0, 0:ook - 2] = oo.smpx[:, 2, 2:]

                    if oo.doBsmpx and (it % oo.BsmpxSkp == 0):
                        oo.Bsmpx[it // oo.BsmpxSkp, :, 2:] = oo.smpx[:, 2:, 0]
                    stds = _N.std(oo.smpx[:, 2 + oo.ignr:, 0], axis=1)
                    oo.mnStds[it] = _N.mean(stds, axis=0)

                    ttt7 = _tm.time()

                    if cython_arc:
                        _N.copyto(smpx_contiguous1, oo.smpx[:, 1 + oo.ignr:])
                        _N.copyto(smpx_contiguous2,
                                  oo.smpx[:, oo.ignr:, 0:ook - 1])

                        oo.uts[itstore], oo.wts[itstore] = _arcfs.ARcfSmpl(
                            ooN + 1 - oo.ignr, ook, oo.TR, oo.AR2lims,
                            smpx_contiguous1, smpx_contiguous2, oo.q2, oo.R,
                            oo.Cs, oo.Cn, alpR, alpC, oo.sig_ph0L, oo.sig_ph0H,
                            0.2 * 0.2)
                    else:
                        oo.uts[itstore], oo.wts[itstore] = _arcfs.ARcfSmpl(
                            ooN + 1 - oo.ignr,
                            ook,
                            oo.AR2lims,
                            oo.smpx[:, 1 + oo.ignr:, 0:ook],
                            oo.smpx[:, oo.ignr:, 0:ook - 1],
                            oo.q2,
                            oo.R,
                            oo.Cs,
                            oo.Cn,
                            alpR,
                            alpC,
                            oo.TR,
                            aro=oo.ARord,
                            sig_ph0L=oo.sig_ph0L,
                            sig_ph0H=oo.sig_ph0H)
                    #  write alpR / alpC back into the combined root list
                    #  (presumably ARcfSmpl updated them in place — confirm)
                    oo.F_alfa_rep[0:oo.R] = alpR
                    oo.F_alfa_rep[oo.R:] = alpC

                    prt, _rank, f, amp = ampAngRep(oo.F_alfa_rep,
                                                   oo.dt,
                                                   f_order=True)
                    ttt8 = _tm.time()
                    oo.allalfas[it] = oo.F_alfa_rep

                    if not oo.bFixF:
                        oo.amps[it, :] = amp
                        oo.fs[it, :] = f

                    ttt9 = _tm.time()
                    #  AR coefficients from the sampled roots
                    oo.F0 = (-1 *
                             _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]
                    for tr in range(oo.TR):
                        oo.Fs[tr, 0] = oo.F0[:]

                    oo.a2 = oo.a_q2 + 0.5 * (ooTR * ooN + 2)  #  N + 1 - 1
                    #  accumulate without "+=" so that oo.B_q2 is never
                    #  mutated in place should it be an ndarray
                    BB2 = oo.B_q2
                    for m in range(ooTR):
                        #   set x00
                        oo.x00[m] = oo.smpx[m, 2] * 0.1

                        #####################    sample q2
                        rsd_stp = oo.smpx[m, 3 + oo.ignr:, 0] - _N.dot(
                            oo.smpx[m, 2 + oo.ignr:-1], oo.F0).T
                        BB2 = BB2 + 0.5 * _N.dot(rsd_stp, rsd_stp.T)

                    oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)
                    oo.smp_q2[it] = oo.q2
                else:
                    ttt7 = ttt8 = ttt9 = ttt6

                ttt10 = _tm.time()
                ttts[it] = _N.diff((ttt1, ttt2, ttt3, ttt4, ttt5, ttt6, ttt7,
                                    ttt8, ttt9, ttt10))
                n_done = it + 1

            oo.last_iter = it
            if it > oo.minITERS:
                if _mg.stationary_test(oo.amps[:it + 1, 0],
                                       oo.fs[:it + 1, 0],
                                       oo.mnStds[:it + 1],
                                       it + 1,
                                       blocksize=oo.mg_blocksize,
                                       points=oo.mg_points):
                    break

        oo.getComponents()
        #  n_done excludes an iteration that was never run when sampling was
        #  interrupted at the start of a block
        oo.dump_smps(0,
                     toiter=n_done,
                     dir=oo.mcmcRunDir,
                     smpls_fn_incl_trls=smpls_fn_incl_trls)
Beispiel #2
0
    def dirichletAllocate(self):  ###########################  GIBBSSAMP
        oo = self
        ooTR = oo.TR
        print ooTR
        ook = oo.k
        ooNMC = oo.NMC
        ooN = oo.N

        oo.allocateSmp(oo.burn + oo.NMC)
        oo.x00 = _N.array(oo.smpx[:, 2])
        oo.V00 = _N.zeros((ooTR, ook, ook))

        ARo = _N.empty((ooTR, oo._d.N + 1))
        ARo01 = _N.empty((oo.nStates, ooTR, oo._d.N + 1))

        kpOws = _N.empty((ooTR, ooN + 1))
        lv_f = _N.zeros((ooN + 1, ooN + 1))
        lv_u = _N.zeros((ooTR, ooTR))
        Bii = _N.zeros((ooN + 1, ooN + 1))

        #alpC.reverse()
        #  F_alfa_rep = alpR + alpC  already in right order, no?

        Wims = _N.empty((ooTR, ooN + 1, ooN + 1))
        Oms = _N.empty((ooTR, ooN + 1))
        smWimOm = _N.zeros(ooN + 1)
        smWinOn = _N.zeros(ooTR)
        bConstPSTH = False
        D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  #  spline
        iD_f = _N.linalg.inv(D_f)
        D_u = _N.diag(_N.ones(oo.TR) * oo.s2_u)  #  This should
        iD_u = _N.linalg.inv(D_u)
        iD_u_u_u = _N.dot(iD_u, _N.ones(oo.TR) * oo.u_u)
        BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
        DB = _N.dot(D_f, oo.B)
        BTua = _N.dot(oo.B.T, oo.u_a)

        it = 0

        oo.lrn = _N.empty((ooTR, ooN + 1))
        oo.s_lrn = _N.empty((ooTR, ooN + 1))
        oo.sprb = _N.empty((ooTR, ooN + 1))
        oo.lrn_scr1 = _N.empty(ooN + 1)
        oo.lrn_iscr1 = _N.empty(ooN + 1)
        oo.lrn_scr2 = _N.empty(ooN + 1)
        oo.lrn_scr3 = _N.empty(ooN + 1)
        oo.lrn_scld = _N.empty(ooN + 1)

        oo.lrn = _N.empty((ooTR, ooN + 1))
        if oo.l2 is None:
            oo.lrn[:] = 1
        else:
            for tr in xrange(ooTR):
                oo.lrn[tr] = oo.build_lrnLambda2(tr)

        ###############################  MCMC LOOP  ########################
        ###  need pointer to oo.us, but reshaped for broadcasting to work
        ###############################  MCMC LOOP  ########################
        oous_rs = oo.us.reshape((ooTR, 1))  #  done for broadcasting rules
        lrnBadLoc = _N.empty((oo.TR, oo.N + 1), dtype=_N.bool)

        sd01 = _N.zeros((oo.nStates, oo.TR, oo.TR))
        _N.fill_diagonal(sd01[0], oo.s[0])
        _N.fill_diagonal(sd01[1], oo.s[1])

        smpx01 = _N.zeros((oo.nStates, oo.TR, oo.N + 1))
        ARo01 = _N.empty((oo.nStates, oo.TR, oo.N + 1))
        zsmpx = _N.empty((oo.TR, oo.N + 1))

        #  zsmpx created
        #  PG

        zd = _N.zeros((oo.TR, oo.TR))
        izd = _N.zeros((oo.TR, oo.TR))
        ll = _N.zeros(oo.nStates)
        Bp = _N.empty((oo.nStates, oo.N + 1))

        for m in xrange(ooTR):
            oo._d.f_V[m, 0] = oo.s2_x00
            oo._d.f_V[m, 1] = oo.s2_x00

        THR = _N.empty(oo.TR)
        dirArgs = _N.empty(oo.nStates)  #  dirichlet distribution args
        expT = _N.empty(ooN + 1)
        BaS = _N.dot(oo.B.T, oo.aS)

        oo.nSMP_smpxC = 0
        if oo.processes > 1:
            print oo.processes
            pool = Pool(processes=oo.processes)

        while (it < ooNMC + oo.burn - 1):
            lowsts = _N.where(oo.Z[:, 0] == 1)
            #print "lowsts   %s" % str(lowsts)
            t1 = _tm.time()
            it += 1
            print "****------------  %d" % it

            #  generate latent AR state

            ######  Z
            #print "!!!!!!!!!!!!!!!  1"

            for tryZ in xrange(oo.nStates):
                _N.dot(sd01[tryZ], oo.smpx[..., 2:, 0], out=smpx01[tryZ])
                #oo.build_addHistory(ARo01[tryZ], smpx01[tryZ, m], BaS, oo.us, lrnBadLoc)
                oo.build_addHistory(ARo01[tryZ], smpx01[tryZ], BaS, oo.us,
                                    oo.knownSig)
                """
                for m in xrange(oo.TR):
                    locs = _N.where(lrnBadLoc[m] == True)
                    if locs[0].shape[0] > 0:
                        print "found a bad loc"
                        fig = _plt.figure(figsize=(8, 5))
                        _plt.suptitle("%d" % m)
                        _plt.subplot(2, 1, 1)
                        _plt.plot(smpx01[tryZ, m])
                        _plt.subplot(2, 1, 2)
                        _plt.plot(ARo01[tryZ, m])
                """
                #print "!!!!!!!!!!!!!!!  2"
            for m in oo.varz:
                for tryZ in xrange(
                        oo.nStates):  #  only allow certain trials to change

                    #  calculate p0, p1  p0 = m_0 x PROD_n Ber(y_n | Z_j)
                    #                       = m_0 x _N.exp(_N.log(  ))
                    #  p0, p1 not normalized
                    ll[tryZ] = 0
                    #  Ber(0 | ) and Ber(1 | )
                    _N.exp(smpx01[tryZ, m] + BaS + ARo01[tryZ, m] + oo.us[m],
                           out=expT)
                    Bp[0] = 1 / (1 + expT)
                    Bp[1] = expT / (1 + expT)

                    #   z[:, 1]   is state label

                    for n in xrange(oo.N + 1):
                        ll[tryZ] += _N.log(Bp[oo.y[m, n], n])

                ofs = _N.min(ll)
                ll -= ofs
                nc = oo.m[0] * _N.exp(ll[0]) + oo.m[1] * _N.exp(ll[1])

                iARo = 1
                oo.Z[m, 0] = 0
                oo.Z[m, 1] = 1
                THR[m] = (oo.m[0] * _N.exp(ll[0]) / nc)
                if _N.random.rand() < THR[m]:
                    oo.Z[m, 0] = 1
                    oo.Z[m, 1] = 0
                    iARo = 0
                oo.smp_zs[m, it] = oo.Z[m]
                ####  did we forget to do this?
                ARo[m] = ARo01[iARo, m]
            for m in oo.fxdz:  #####  outside BM loop
                oo.smp_zs[m, it] = oo.Z[m]
            t2 = _tm.time()

            #  Z  set
            _N.fill_diagonal(zd, oo.s[oo.Z[:, 1]])
            _N.fill_diagonal(izd, 1. / oo.s[oo.Z[:, 1]])
            #for kkk in xrange(oo.TR):
            #    print zd[kkk, kkk]
            _N.dot(zd, oo.smpx[..., 2:, 0], out=zsmpx)
            ######  sample m's
            _N.add(oo.alp, _N.sum(oo.Z[oo.varz], axis=0), out=dirArgs)
            oo.m[:] = _N.random.dirichlet(dirArgs)
            oo.smp_ms[it] = oo.m
            print oo.m

            oo.build_addHistory(ARo, zsmpx, BaS, oo.us, oo.knownSig)
            t3 = _tm.time()

            ######  PG generate
            nanLoc = _N.where(_N.isnan(BaS))

            for m in xrange(ooTR):
                lw.rpg_devroye(oo.rn,
                               zsmpx[m] + oo.us[m] + BaS + ARo[m],
                               out=oo.ws[m])  ######  devryoe  ####TRD change
                nanLoc = _N.where(_N.isnan(oo.ws[m]))
                if len(nanLoc[0]) > 0:
                    loc = nanLoc[0][0]
            _N.divide(oo.kp, oo.ws, out=kpOws)

            ########     per trial offset sample
            Ons = kpOws - zsmpx - ARo - BaS
            _N.einsum("mn,mn->m", oo.ws, Ons, out=smWinOn)  #  sum over trials
            ilv_u = _N.diag(_N.sum(oo.ws, axis=1))  #  var  of LL
            _N.fill_diagonal(lv_u, 1. / _N.diagonal(ilv_u))
            lm_u = _N.dot(lv_u, smWinOn)  #  nondiag of 1./Bi are inf, mean LL
            #  now sample
            iVAR = ilv_u + iD_u
            VAR = _N.linalg.inv(iVAR)  #
            Mn = _N.dot(VAR, _N.dot(ilv_u, lm_u) + iD_u_u_u)
            oo.us[:] = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]

            oo.smp_u[:, it] = oo.us

            ########     PSTH sample  Do PSTH after we generate zs
            if oo.bpsth:
                Oms = kpOws - zsmpx - ARo - oous_rs
                _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)  #  sum over
                ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                lm_f = _N.dot(lv_f, smWimOm)  #  nondiag of 1./Bi are inf
                #  now sample
                iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                VAR = _N.linalg.inv(iVAR)  #  knots x knots
                #iBDBW = _N.linalg.inv(BDB + lv_f)   # BDB not diag
                #Mn    = oo.u_a + _N.dot(DB, _N.dot(iBDBW, lm_f - BTua))

                Mn = oo.u_a + _N.dot(DB,
                                     _N.linalg.solve(BDB + lv_f, lm_f - BTua))
                oo.aS = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
                oo.smp_aS[it, :] = oo.aS

                #iBDBW = _N.linalg.inv(BDB + lv_f)   # BDB not diag
                #Mn    = oo.u_a + _N.dot(DB, _N.dot(iBDBW, lm_f - BTua))
                #oo.aS   = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
                #oo.smp_aS[it, :] = oo.aS
            else:
                oo.aS[:] = 0
            BaS = _N.dot(oo.B.T, oo.aS)

            t4 = _tm.time()
            ####  Sample latent state
            oo._d.y = _N.dot(izd, kpOws - BaS - ARo - oous_rs)
            oo._d.copyParams(oo.F0, oo.q2)
            #  (MxM)  (MxN) = (MxN)  (Rv is MxN)
            _N.dot(_N.dot(izd, izd), 1. / oo.ws, out=oo._d.Rv)

            oo._d.f_x[:, 0, :, 0] = oo.x00
            #if it == 1:
            for m in xrange(ooTR):
                oo._d.f_V[m, 0] = oo.s2_x00
            else:
                oo._d.f_V[:, 0] = _N.mean(oo._d.f_V[:, 1:], axis=1)

            tpl_args = zip(oo._d.y, oo._d.Rv, oo._d.Fs, oo.q2, oo._d.Ns,
                           oo._d.ks, oo._d.f_x[:, 0], oo._d.f_V[:, 0])

            t5 = _tm.time()
            if oo.processes == 1:
                for m in xrange(ooTR):
                    oo.smpx[m, 2:], oo._d.f_x[m], oo._d.f_V[
                        m] = _kfar.armdl_FFBS_1itrMP(tpl_args[m])
                    oo.smpx[m, 1, 0:ook - 1] = oo.smpx[m, 2, 1:]
                    oo.smpx[m, 0, 0:ook - 2] = oo.smpx[m, 2, 2:]
                    oo.smp_q2[m, it] = oo.q2[m]

            else:
                sxv = pool.map(_kfar.armdl_FFBS_1itrMP, tpl_args)
                for m in xrange(ooTR):
                    oo.smpx[m, 2:] = sxv[m][0]
                    oo._d.f_x[m] = sxv[m][1]
                    oo._d.f_V[m] = sxv[m][2]
                    oo.smpx[m, 1, 0:ook - 1] = oo.smpx[m, 2, 1:]
                    oo.smpx[m, 0, 0:ook - 2] = oo.smpx[m, 2, 2:]
                    #oo.Bsmpx[m, it, 2:]    = oo.smpx[m, 2:, 0]

            stds = _N.std(oo.smpx[:, 2:, 0], axis=1)
            oo.mnStds[it] = _N.mean(stds, axis=0)
            print "mnStd  %.3f" % oo.mnStds[it]
            ###

            lwsts = _N.where(oo.Z[:, 0] == 1)[0]
            hists = _N.where(oo.Z[:, 1] == 1)[0]

            sts2chg = hists
            if (it > oo.startZ) and (len(sts2chg) > 0):
                AL = 0.5 * _N.sum(oo.smpx[sts2chg, 2:, 0] *
                                  oo.smpx[sts2chg, 2:, 0] * oo.ws[sts2chg])
                BRL = kpOws[sts2chg] - BaS - oous_rs[sts2chg] - ARo[sts2chg]
                BL = _N.sum(oo.ws[sts2chg] * BRL * oo.smpx[sts2chg, 2:, 0])
                UL = BL / (2 * AL)
                sgL = 1 / _N.sqrt(2 * AL)
                U = UL
                sg = sgL

                print "U  %(U).3f    s  %(s).3f" % {"U": U, "s": sg}

                oo.s[1] = U + sg * _N.random.randn()

                _N.fill_diagonal(sd01[0], oo.s[0])
                _N.fill_diagonal(sd01[1], oo.s[1])
                print oo.s[1]
                oo.smp_ss[it] = oo.s[1]

            oo.a2 = oo.a_q2 + 0.5 * (ooTR * ooN + 2)  #  N + 1 - 1
            BB2 = oo.B_q2
            for m in xrange(ooTR):
                #   set x00
                #oo.x00[m]      = oo.smpx[m, 2]*0.1
                oo.x00[m] = oo.smpx[m, 2] * 0.001

                #####################    sample q2
                rsd_stp = oo.smpx[m, 3:, 0] - _N.dot(oo.smpx[m, 2:-1], oo.F0).T
                BB2 += 0.5 * _N.dot(rsd_stp, rsd_stp.T)
            oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)

            oo.smp_q2[:, it] = oo.q2
            t7 = _tm.time()
            print "gibbs iter %.3f" % (t7 - t1)
Beispiel #3
0
    def gibbsSamp(self):  ###########################  GIBBSSAMPH
        """Run the blocked Gibbs sampler and dump the collected samples.

        Each iteration draws, in this order (see ORDER OF SAMPLING below):
        the PG latent weights ``oo.ws`` (via ``lw.rpg_devroye``), the
        spike-history coefficients (only if ``oo.dohist``), the PSTH spline
        weights ``oo.aS`` (only if ``oo.bpsth``), the per-trial offsets
        ``oo.us``, and -- unless ``oo.noAR`` -- the latent AR state
        ``oo.smpx``, the AR roots ``oo.F_alfa_rep`` (unless ``oo.bFixF``)
        and the innovation variance ``oo.q2``.

        Iterations run in blocks of ``oo.peek``.  Only
        ``oo.ITERS // oo.peek`` full blocks are run, so any remainder of
        ``oo.ITERS`` is never sampled.  After each block, once the iteration
        index exceeds ``oo.minITERS``, ``_mg.stationary_test`` is consulted
        and sampling stops early if it returns True.

        Side effects:
          * fills ``oo.smp_hS``, ``oo.smp_hist``, ``oo.smp_aS``,
            ``oo.smp_u``, ``oo.smp_q2``, ``oo.Bsmpx``, ``oo.allalfas``,
            ``oo.amps``, ``oo.fs`` and ``oo.mnStds``;
          * sets ``oo.last_iter`` and ``oo.VIS`` (the stitched history
            term ``ARo``);
          * finally calls ``oo.dump_smps(0, toiter=it + 1,
            dir=oo.mcmcRunDir)``.

        At iteration 2000 (with ``oo.dohist``) it also writes a debug file
        ``it2000.dat`` to the current working directory.
        """
        oo = self

        print("****!!!!!!!!!!!!!!!!  dohist  %s" % str(oo.dohist))

        ooTR = oo.TR
        ook = oo.k

        ooN = oo.N
        _kfar.init(oo.N, oo.k, oo.TR)
        oo.x00 = _N.array(oo.smpx[:, 2])
        oo.V00 = _N.zeros((ooTR, ook, ook))
        if oo.dohist:
            #  history will be sampled; start from a flat (zero) log-history
            oo.loghist = _N.zeros(oo.N + 1)
        else:
            #  history is held fixed at the caller-supplied oo.loghist
            print("fixed hist is")
            print(oo.loghist)

        print("oo.mcmcRunDir    %s" % oo.mcmcRunDir)
        if oo.mcmcRunDir is None:
            oo.mcmcRunDir = ""
        elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
            oo.mcmcRunDir += "/"

        #  ARo: per-trial history contribution, (re)built by oo.stitch_Hist
        ARo = _N.zeros((ooTR, ooN + 1))

        kpOws = _N.empty((ooTR, ooN + 1))  #  oo.kp / oo.ws, filled each iter
        lv_f = _N.zeros((ooN + 1, ooN + 1))
        lv_u = _N.zeros((ooTR, ooTR))

        smWimOm = _N.zeros(ooN + 1)
        smWinOn = _N.zeros(ooTR)

        #  Prior covariances for the spline weights (s2_a) and the offsets
        #  (s2_u); the matching prior means oo.u_a / oo.u_u are used below.
        D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  #  spline
        iD_f = _N.linalg.inv(D_f)
        D_u = _N.diag(_N.ones(oo.TR) * oo.s2_u)
        iD_u = _N.linalg.inv(D_u)
        iD_u_u_u = _N.dot(iD_u, _N.ones(oo.TR) * oo.u_u)

        if oo.bpsth:
            #  iteration-invariant pieces of the PSTH posterior mean
            BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
            DB = _N.dot(D_f, oo.B)
            BTua = _N.dot(oo.B.T, oo.u_a)

        it = -1

        #  reshape() returns a view, so oous_rs tracks oo.us as long as
        #  oo.us is only ever updated in place (oo.us[...] = ...) below.
        oous_rs = oo.us.reshape((ooTR, 1))
        runTO = oo.ITERS - 1
        oo.allocateSmp(runTO + 1, Bsmpx=oo.doBsmpx)
        #  Split the AR roots into the first oo.R entries and the rest.
        #  ARcfSmpl presumably updates these lists in place; they are
        #  re-joined into oo.F_alfa_rep after each call.
        alpR = oo.F_alfa_rep[0:oo.R]
        alpC = oo.F_alfa_rep[oo.R:]

        BaS = _N.zeros(oo.N + 1)  #  PSTH term, oo.B.T @ oo.aS

        #  Hbf must be (N+1, histknots): Hbf @ coefs is written into the
        #  length-(N+1) oo.loghist below.
        Hbf = oo.Hbf

        #  RHS holds the full vector of history coefficients.  It is
        #  zero-initialised because coefficients that are in neither cInds
        #  nor vInds (e.g. index 9 in scenario 2) would otherwise keep
        #  uninitialised memory that still enters oo.loghist via Hbf @ RHS.
        RHS = _N.zeros((oo.histknots, 1))

        #  cInds: coefficients held fixed at their RHS value.
        #  vInds: coefficients re-sampled every iteration (must be
        #  contiguous -- see iOf below).
        print("-----------    histknots %d" % oo.histknots)
        if oo.h0_1 > 1:  #  no spikes in first few time bins
            print("!!!!!!!   hist scenario 1")
            cInds = _N.array([0, 5, 6, 7, 8, 9])
            vInds = _N.array([1, 2, 3, 4])
            RHS[cInds, 0] = 0
            #  presumably a refractory-like suppression at the earliest lags
            RHS[0, 0] = -5
        elif oo.hist_max_at_0:  #  no refractory period
            print("!!!!!!!   hist scenario 2")
            cInds = _N.array([0, 4, 5, 6, 7, 8])
            vInds = _N.array([1, 2, 3])
            RHS[cInds, 0] = 0
            RHS[0, 0] = 0
        else:
            print("!!!!!!!   hist scenario 3")
            cInds = _N.array([4, 5, 6, 7, 8, 9])
            vInds = _N.array([0, 1, 2, 3])
            RHS[cInds, 0] = 0

        #  spike-bin indices for each trial
        Msts = [_N.where(oo.y[m] == 1)[0] for m in range(ooTR)]
        HcM = _N.empty((len(vInds), len(vInds)))

        #  HbfExpd[i, m, t] = Hbf[t - t0 - 1, i], where t0 is the most recent
        #  spike strictly before bin t on trial m.  Bins up to and including
        #  the first spike are 0.
        #  NOTE(review): bins after the LAST spike are also 0, i.e. no
        #  history effect there -- confirm this is intended.
        #  NOTE(review): assumes Hbf has at least as many rows as the
        #  longest inter-spike interval.
        HbfExpd = _N.zeros((oo.histknots, ooTR, oo.N + 1))
        for i in range(oo.histknots):
            for m in range(oo.TR):
                sts = Msts[m]
                if len(sts) == 0:
                    continue  #  spikeless trial: no history term (row is 0)
                HbfExpd[i, m, 0:sts[0]] = 0
                for iss in range(len(sts) - 1):
                    t0 = sts[iss]
                    t1 = sts[iss + 1]
                    HbfExpd[i, m, t0 + 1:t1 + 1] = Hbf[0:t1 - t0, i]
                HbfExpd[i, m, sts[-1] + 1:] = 0

        _N.dot(oo.B.T, oo.aS, out=BaS)
        if oo.hS is None:
            oo.hS = _N.zeros(oo.histknots)

        if oo.dohist:
            _N.dot(Hbf, oo.hS, out=oo.loghist)
        oo.stitch_Hist(ARo, oo.loghist, Msts)

        ##  ORDER OF SAMPLING
        ##  f_xx, f_V
        ##  DA:  PG, kpOws
        ##  history, build ARo
        ##  psth
        ##  offset
        ##  DA:  latent state
        ##  AR coefficients
        ##  q2

        K = _N.empty((oo.TR, oo.N + 1, oo.k))  #  kalman gain

        iterBLOCKS = oo.ITERS // oo.peek
        smpx_tmp = _N.empty((oo.TR, oo.N + 1, oo.k))

        ######  Gibbs sampling procedure
        for itrB in range(iterBLOCKS):
            it = itrB * oo.peek
            if it > 0:
                print("it: %(it)d    mnStd  %(mnstd).3f" % {
                    "it": itrB * oo.peek,
                    "mnstd": oo.mnStds[it - 1]
                })

            for it in range(itrB * oo.peek, (itrB + 1) * oo.peek):
                #  initial filter state/variance for the latent AR process
                oo.f_x[:, 0] = oo.x00
                if it == 0:
                    for m in range(ooTR):
                        oo.f_V[m, 0] = oo.s2_x00
                else:
                    oo.f_V[:, 0] = _N.mean(oo.f_V[:, 1:], axis=1)

                ###  PG latent variable sample
                for m in range(ooTR):
                    lw.rpg_devroye(oo.rn,
                                   oo.smpx[m, 2:, 0] + oo.us[m] + BaS +
                                   ARo[m] + oo.knownSig[m],
                                   out=oo.ws[m])  ######  devryoe

                if ooTR == 1:
                    oo.ws = oo.ws.reshape(1, ooN + 1)
                _N.divide(oo.kp, oo.ws, out=kpOws)

                ###  history coefficients
                if oo.dohist:
                    #  pseudo-observation with every term except history removed
                    O = kpOws - oo.smpx[..., 2:, 0] - oo.us.reshape(
                        (ooTR, 1)) - BaS - oo.knownSig
                    if it == 2000:
                        #  NOTE(review): leftover debug dump into the CWD
                        _N.savetxt("it2000.dat", O)

                    iOf = vInds[0]  #  offset HcM index with RHS index.

                    #  ws-weighted normal equations for the free coefficients;
                    #  the fixed coefficients' contribution is moved to the
                    #  right-hand side.  No prior term is added here.
                    for i in vInds:
                        for j in vInds:
                            HcM[i - iOf, j - iOf] = _N.sum(oo.ws * HbfExpd[i] *
                                                           HbfExpd[j])

                        RHS[i, 0] = _N.sum(oo.ws * HbfExpd[i] * O)
                        for cj in cInds:
                            RHS[i, 0] -= _N.sum(
                                oo.ws * HbfExpd[i] * HbfExpd[cj]) * RHS[cj, 0]

                    vm = _N.linalg.solve(HcM, RHS[vInds])
                    Cov = _N.linalg.inv(HcM)
                    cfs = _N.random.multivariate_normal(vm[:, 0], Cov, size=1)

                    RHS[vInds, 0] = cfs[0]
                    oo.smp_hS[:, it] = RHS[:, 0]

                    _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
                    oo.smp_hist[:, it] = oo.loghist
                    oo.stitch_Hist(ARo, oo.loghist, Msts)
                else:
                    oo.smp_hist[:, it] = oo.loghist
                    oo.stitch_Hist(ARo, oo.loghist, Msts)

                ###  PSTH spline weights
                if oo.bpsth:
                    Oms = kpOws - oo.smpx[..., 2:,
                                          0] - ARo - oous_rs - oo.knownSig
                    #  per time bin: sum over trials of ws * Oms
                    _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)
                    ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                    #  diag(_N.linalg.inv(Bi)) == diag(1./Bi).
                    _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                    lm_f = _N.dot(lv_f, smWimOm)  #  nondiag of 1./Bi are inf
                    iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                    VAR = _N.linalg.inv(iVAR)  #  knots x knots
                    #  posterior mean; solves an (N+1) x (N+1) system
                    Mn = oo.u_a + _N.dot(
                        DB, _N.linalg.solve(BDB + lv_f, lm_f - BTua))

                    oo.aS = _N.random.multivariate_normal(Mn, VAR,
                                                          size=1)[0, :]
                    oo.smp_aS[it, :] = oo.aS
                    _N.dot(oo.B.T, oo.aS, out=BaS)

                ########     per trial offset sample
                Ons = kpOws - oo.smpx[..., 2:, 0] - ARo - BaS - oo.knownSig

                if not oo.bpsth:  # if not doing PSTH, don't constrain offset, as there are no confounds controlling offset
                    #  per trial: sum over time bins of ws * Ons
                    _N.einsum("mn,mn->m", oo.ws, Ons, out=smWinOn)
                    ilv_u = _N.diag(_N.sum(oo.ws, axis=1))  #  var  of LL
                    _N.fill_diagonal(lv_u, 1. / _N.diagonal(ilv_u))
                    lm_u = _N.dot(lv_u, smWinOn)  #  mean LL
                    #  conjugate Gaussian update with prior N(u_u, D_u)
                    iVAR = ilv_u + iD_u
                    VAR = _N.linalg.inv(iVAR)
                    Mn = _N.dot(VAR, _N.dot(ilv_u, lm_u) + iD_u_u_u)
                    oo.us[:] = _N.random.multivariate_normal(Mn, VAR,
                                                             size=1)[0, :]
                    if not oo.bIndOffset:
                        oo.us[:] = _N.mean(oo.us)
                    oo.smp_u[:, it] = oo.us
                else:
                    #  PSTH present: sample us[1:] with us[0] = -sum(us[1:]),
                    #  i.e. offsets constrained to sum to 0 (no prior term).
                    H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
                    uRHS = _N.empty(oo.TR - 1)
                    for dd in range(1, oo.TR):
                        H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
                        uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] -
                                              oo.ws[0] * Ons[0])

                    MM = _N.linalg.solve(H, uRHS)
                    Cov = _N.linalg.inv(H)

                    oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
                    oo.us[0] = -_N.sum(oo.us[1:])
                    #  NOTE(review): us sums to 0 here, so when not
                    #  oo.bIndOffset every offset becomes (about) 0.
                    if not oo.bIndOffset:
                        oo.us[:] = _N.mean(oo.us)
                    oo.smp_u[:, it] = oo.us

                ###  latent AR state, AR roots, innovation variance
                if not oo.noAR:
                    oo.gau_obs = kpOws - BaS - ARo - oous_rs - oo.knownSig
                    oo.gau_var = 1 / oo.ws  #  time dependent noise

                    _kfar.armdl_FFBS_1itrMP(oo.gau_obs, oo.gau_var, oo.Fs,
                                            _N.linalg.inv(oo.Fs), oo.q2, oo.Ns,
                                            oo.ks, oo.f_x, oo.f_V, oo.p_x,
                                            oo.p_V, smpx_tmp, K)

                    #  store the draw and its lagged copies in rows 0/1
                    oo.smpx[:, 2:] = smpx_tmp
                    oo.smpx[:, 1, 0:ook - 1] = oo.smpx[:, 2, 1:]
                    oo.smpx[:, 0, 0:ook - 2] = oo.smpx[:, 2, 2:]

                    if oo.doBsmpx and (it % oo.BsmpxSkp == 0):
                        oo.Bsmpx[:, it // oo.BsmpxSkp, 2:] = oo.smpx[:, 2:, 0]
                    stds = _N.std(oo.smpx[:, 2 + oo.ignr:, 0], axis=1)
                    oo.mnStds[it] = _N.mean(stds, axis=0)

                    if not oo.bFixF:
                        ARcfSmpl(ooN + 1 - oo.ignr,
                                 ook,
                                 oo.AR2lims,
                                 oo.smpx[:, 1 + oo.ignr:, 0:ook],
                                 oo.smpx[:, oo.ignr:, 0:ook - 1],
                                 oo.q2,
                                 oo.R,
                                 oo.Cs,
                                 oo.Cn,
                                 alpR,
                                 alpC,
                                 oo.TR,
                                 prior=oo.use_prior,
                                 accepts=8,
                                 aro=oo.ARord,
                                 sig_ph0L=oo.sig_ph0L,
                                 sig_ph0H=oo.sig_ph0H)
                        oo.F_alfa_rep = alpR + alpC  #  new constructed
                        _prt, _rank, f, amp = ampAngRep(oo.F_alfa_rep,
                                                        f_order=True)
                        oo.amps[it, :] = amp
                        oo.fs[it, :] = f
                    oo.allalfas[it] = oo.F_alfa_rep

                    #  AR polynomial coefficients from the roots
                    oo.F0 = (-1 *
                             _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]
                    for tr in range(oo.TR):
                        oo.Fs[tr, 0] = oo.F0[:]

                    #####################    sample q2 (inverse gamma)
                    #  NOTE(review): the shape counts ooTR * ooN residuals
                    #  although rsd_stp below starts at 3 + oo.ignr.
                    oo.a2 = oo.a_q2 + 0.5 * (ooTR * ooN + 2)  #  N + 1 - 1
                    BB2 = oo.B_q2
                    for m in range(ooTR):
                        #   set x00 used by the next iteration's filter
                        oo.x00[m] = oo.smpx[m, 2] * 0.1

                        rsd_stp = oo.smpx[m, 3 + oo.ignr:, 0] - _N.dot(
                            oo.smpx[m, 2 + oo.ignr:-1], oo.F0).T
                        #  rebind rather than "+=" so an array-valued
                        #  oo.B_q2 is never mutated in place
                        BB2 = BB2 + 0.5 * _N.dot(rsd_stp, rsd_stp.T)

                    oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)
                    oo.smp_q2[:, it] = oo.q2

            oo.last_iter = it
            #  NOTE(review): when oo.noAR / oo.bFixF, amps/fs/mnStds are not
            #  updated, so this test sees stale or unset values.
            if it > oo.minITERS:
                if _mg.stationary_test(oo.amps[:it + 1, 0],
                                       oo.fs[:it + 1, 0],
                                       oo.mnStds[:it + 1],
                                       it + 1,
                                       blocksize=oo.mg_blocksize,
                                       points=oo.mg_points):
                    break

        oo.dump_smps(0, toiter=(it + 1), dir=oo.mcmcRunDir)
        oo.VIS = ARo  #  to examine this from outside
    def sample_posterior(self,
                         ITER,
                         a_F0,
                         b_F0,
                         a_q2,
                         B_q2,
                         smp_Bns,
                         smp_offsets,
                         smp_F0s,
                         smp_q2s,
                         off_mu=0,
                         off_sig2=0.4,
                         random_walk=False):
        """Gibbs-sample the six AR/offset processes and pickle the samples.

        Setup:
          * Builds the +/-1 covariate indicators ``W_n``, ``T_n``, ``L_n``
            and the one-hot outcome matrix ``y_vec`` from ``hnd_dat``.
          * Forms the remaining-count matrix
            ``N_vec[n, k] = N - sum(y_vec[n, :k])`` and
            ``kappa = y_vec - N_vec / 2``.

        Sampling runs ``ITER`` sweeps.  Each sweep updates the six
        (component, covariate) processes in random order via
        ``sampleAR_and_offset``, then redraws the PG weights
        ``ws1``/``ws2`` with ``lw.rpg_devroye``.

        Output: every 50th sample is pickled to
        ``<out_dir>/<ssig>,<scov><sran>2.dmp``.

        Parameters
        ----------
        ITER : int
            Number of Gibbs sweeps.
        a_F0, b_F0, a_q2, B_q2 :
            Hyper-parameters forwarded to ``sampleAR_and_offset``
            (``a_q2``/``B_q2`` are also stored in the pickle).
        smp_Bns, smp_offsets, smp_F0s, smp_q2s :
            Ignored.  Fresh arrays are allocated locally and only written
            to the pickle, never returned.  They are kept for interface
            compatibility.
        off_mu : float
            Currently unused.
        off_sig2 : float
            Offset variance, combined each sweep with the likelihood
            variance ``1 / sum(ws)`` as
            ``off_sig2 * v / (off_sig2 + v)``.
        random_walk : bool
            Forwarded to ``sampleAR_and_offset``.

        Relies on names not defined in this method.  These are presumably
        module globals: ``Tm1``, ``Tobs``, ``hnd_dat``, ``covariates``,
        ``signal``, ``capped``, ``l_capped``, ``ssig``, ``scov``, ``sran``,
        ``out_dir``, ``lw``, ``sampleAR_and_offset`` and the
        ``_WTL``/``_RPS``/``_RELATIVE_LAST_*`` constants.

        Raises
        ------
        ValueError
            If ``covariates`` is neither ``_WTL`` nor ``_RPS``.
        """
        #  Kalman filter/predict buffers for each of the six processes.
        #  The RNG call order below is kept as-is so seeded runs reproduce.
        w1_px = _N.random.randn(Tm1)
        w1_pV = _N.ones(Tm1) * 0.2
        w1_fx = _N.zeros(Tm1)
        w1_fV = _N.ones(Tm1) * 0.1
        w2_px = _N.random.randn(Tm1)
        w2_pV = _N.random.rand(Tm1)
        w2_fx = _N.zeros(Tm1)
        w2_fV = _N.ones(Tm1) * 0.1

        t1_px = _N.random.randn(Tm1)
        t1_pV = _N.ones(Tm1) * 0.2
        t1_fx = _N.zeros(Tm1)
        t1_fV = _N.ones(Tm1) * 0.1
        t2_px = _N.random.randn(Tm1)
        t2_pV = _N.random.rand(Tm1)
        t2_fx = _N.zeros(Tm1)
        t2_fV = _N.ones(Tm1) * 0.1

        l1_px = _N.random.randn(Tm1)
        l1_pV = _N.ones(Tm1) * 0.2
        l1_fx = _N.zeros(Tm1)
        l1_fV = _N.ones(Tm1) * 0.1
        l2_px = _N.random.randn(Tm1)
        l2_pV = _N.random.rand(Tm1)
        l2_fx = _N.zeros(Tm1)
        l2_fV = _N.ones(Tm1) * 0.1

        w1_K = _N.empty(Tm1)
        t1_K = _N.empty(Tm1)
        l1_K = _N.empty(Tm1)
        w2_K = _N.empty(Tm1)
        t2_K = _N.empty(Tm1)
        l2_K = _N.empty(Tm1)

        o_w1 = _N.random.randn(Tm1)  #  start at 0 + u
        o_t1 = _N.random.randn(Tm1)  #  start at 0 + u
        o_l1 = _N.random.randn(Tm1)  #  start at 0 + u
        o_w2 = _N.random.randn(Tm1)  #  start at 0 + u
        o_t2 = _N.random.randn(Tm1)  #  start at 0 + u
        o_l2 = _N.random.randn(Tm1)  #  start at 0 + u

        B1wn = _N.random.randn(Tm1)  #  start at 0 + u
        B1tn = _N.random.randn(Tm1)  #  start at 0 + u
        B1ln = _N.random.randn(Tm1)  #  start at 0 + u
        B2wn = _N.random.randn(Tm1)  #  start at 0 + u
        B2tn = _N.random.randn(Tm1)  #  start at 0 + u
        B2ln = _N.random.randn(Tm1)  #  start at 0 + u

        q2_Bw1 = 1.
        q2_Bt1 = 1.
        q2_Bl1 = 1.
        q2_Bw2 = 1.
        q2_Bt2 = 1.
        q2_Bl2 = 1.
        F0_Bw1 = 0
        F0_Bt1 = 0
        F0_Bl1 = 0
        F0_Bw2 = 0
        F0_Bt2 = 0
        F0_Bl2 = 0

        ws1 = _N.random.rand(Tm1)
        ws2 = _N.random.rand(Tm1)

        #  builtin int: the _N.int alias was removed in NumPy 1.24
        W_n = _N.zeros(Tm1, dtype=int)
        T_n = _N.zeros(Tm1, dtype=int)
        L_n = _N.zeros(Tm1, dtype=int)

        if covariates == _WTL:
            win = _N.where(hnd_dat[0:Tm1, 2] == 1)[0]
            tie = _N.where(hnd_dat[0:Tm1, 2] == 0)[0]
            los = _N.where(hnd_dat[0:Tm1, 2] == -1)[0]
        elif covariates == _RPS:
            win = _N.where(hnd_dat[0:Tm1, 0] == 1)[0]
            tie = _N.where(hnd_dat[0:Tm1, 0] == 2)[0]
            los = _N.where(hnd_dat[0:Tm1, 0] == 3)[0]
        else:
            raise ValueError("unknown covariates %s" % str(covariates))
        W_n[win] = 1
        T_n[win] = -1
        L_n[win] = -1
        #
        W_n[tie] = -1
        T_n[tie] = 1
        L_n[tie] = -1
        #
        #  NOTE(review): by symmetry with W_n/T_n one would expect
        #  L_n[los] = 1; as written L_n is -1 everywhere -- confirm.
        W_n[los] = -1
        T_n[los] = -1
        L_n[los] = -1

        K = 3
        N_vec = _N.zeros((Tm1, K), dtype=int)  #  The N vector
        N = 1
        kappa = _N.empty((Tm1, K))

        #  hand n, n-1    --->  obs of
        #  1 < 2   R < P    1%3 < 2%3   1 < 2
        #  2 < 3   P < S    2%3 < 3%3   2 < 3
        #  3 < 1   S < R    3%3 < 1%3   0 < 1
        #  NOTE(review): if the file imports the `signal` module, that name
        #  shadows the intended global and none of these branches match.
        if signal == _RELATIVE_LAST_ME:
            col_n0 = 0  #  current
            col_n1 = 0  #  previous
        elif signal == _RELATIVE_LAST_AI:
            col_n0 = 0  #  did player copy AI's last move
            col_n1 = 1  #  or did player go to move that beat (loses) to the last AI
        elif signal == _RELATIVE_LAST_OU:
            col_n0 = 0  #  did player copy AI's last move
            col_n1 = 2  #  or did player go to move that beat (loses) to the last AI

        #  y_vec: one-hot outcome per step; y: the same outcome as -1/0/1
        y_vec = _N.zeros((Tm1, 3), dtype=int)
        y = _N.zeros(Tm1, dtype=int)  #  indices of the random var

        for n in range(1, Tobs):
            if signal != _RELATIVE_LAST_OU:
                if (hnd_dat[n, col_n0] == hnd_dat[n - 1, col_n1]):
                    y[n - 1] = 0
                    y_vec[n - 1, 0] = 1  #  [1, 0, 0]    stay
                elif ((hnd_dat[n, col_n0] == 1) and (hnd_dat[n-1, col_n1] == 3)) or \
                     ((hnd_dat[n, col_n0] == 2) and (hnd_dat[n-1, col_n1] == 1)) or \
                     ((hnd_dat[n, col_n0] == 3) and (hnd_dat[n-1, col_n1] == 2)):
                    y[n - 1] = -1
                    y_vec[n - 1, 1] = 1  #  [0, 1, 0]    choose weaker
                elif ((hnd_dat[n, col_n0] == 1) and (hnd_dat[n-1, col_n1] == 2)) or \
                     ((hnd_dat[n, col_n0] == 2) and (hnd_dat[n-1, col_n1] == 3)) or \
                     ((hnd_dat[n, col_n0] == 3) and (hnd_dat[n-1, col_n1] == 1)):
                    y[n - 1] = 1
                    y_vec[n - 1, 2] = 1  #  [0, 0, 1]    choose stronger
            else:
                if (hnd_dat[n, col_n1] == 1):  # win
                    y[n - 1] = 1
                    y_vec[n - 1, 0] = 1  #  [1, 0, 0]
                elif (hnd_dat[n, col_n1] == 0):  # tie
                    y[n - 1] = 0
                    y_vec[n - 1, 1] = 1  #  [0, 1, 0]
                elif (hnd_dat[n, col_n1] == -1):  # los
                    y[n - 1] = -1
                    y_vec[n - 1, 2] = 1  #  [0, 0, 1]

        for n in range(Tm1):
            N_vec[n, 0] = 1
            for k in range(1, K):
                N_vec[n, k] = N - _N.sum(y_vec[n, 0:k])
            for k in range(K):
                kappa[n, k] = y_vec[n, k] - 0.5 * N_vec[n, k]

        #  Steps whose second count is 0 (only one PG variable is active).
        #  This must come after N_vec is filled; reading the local N_vec
        #  earlier raised UnboundLocalError.
        zr2 = _N.where(N_vec[:, 1] == 0)[0]

        print("3")
        smp_offsets = _N.empty((6, ITER, Tm1))
        smp_Bns = _N.empty((6, ITER, Tm1))
        smp_q2s = _N.empty((ITER, 6))
        smp_F0s = _N.empty((ITER, 6))

        o_w1[:] = 0
        o_t1[:] = 0
        o_l1[:] = 0
        o_w2[:] = 0
        o_t2[:] = 0
        o_l2[:] = 0

        do_order = _N.arange(6)
        for it in range(ITER):
            if it % 1000 == 0:
                print("%(it)d   capped %(cp)d" % {"it": it, "cp": capped})

            #  likelihood variance (vrncL*) and its combination with off_sig2
            vrncL1 = 1 / _N.sum(ws1)
            vrnc1 = (off_sig2 * vrncL1) / (off_sig2 + vrncL1)
            vrncL2 = 1 / _N.sum(ws2)
            vrnc2 = (off_sig2 * vrncL2) / (off_sig2 + vrncL2)

            _N.random.shuffle(do_order)

            for di in do_order:
                if di == 0:
                    o_w1, F0_Bw1, q2_Bw1 = sampleAR_and_offset(
                        it, Tm1, vrnc1, vrncL1, B1wn, W_n, o_w1, B1tn, T_n,
                        o_t1, B1ln, L_n, o_l1, kappa[:, 0], ws1, q2_Bw1, a_F0,
                        b_F0, a_q2, B_q2, w1_px, w1_pV, w1_fx, w1_fV, w1_K,
                        random_walk)
                    smp_offsets[0, it] = o_w1[0]
                elif di == 1:
                    o_t1, F0_Bt1, q2_Bt1 = sampleAR_and_offset(
                        it, Tm1, vrnc1, vrncL1, B1tn, T_n, o_t1, B1ln, L_n,
                        o_l1, B1wn, W_n, o_w1, kappa[:, 0], ws1, q2_Bt1, a_F0,
                        b_F0, a_q2, B_q2, t1_px, t1_pV, t1_fx, t1_fV, t1_K,
                        random_walk)
                    smp_offsets[1, it] = o_t1[0]
                elif di == 2:
                    o_l1, F0_Bl1, q2_Bl1 = sampleAR_and_offset(
                        it, Tm1, vrnc1, vrncL1, B1ln, L_n, o_l1, B1wn, W_n,
                        o_w1, B1tn, T_n, o_t1, kappa[:, 0], ws1, q2_Bl1, a_F0,
                        b_F0, a_q2, B_q2, l1_px, l1_pV, l1_fx, l1_fV, l1_K,
                        random_walk)
                    smp_offsets[2, it] = o_l1[0]
                elif di == 3:
                    o_w2, F0_Bw2, q2_Bw2 = sampleAR_and_offset(
                        it, Tm1, vrnc2, vrncL2, B2wn, W_n, o_w2, B2tn, T_n,
                        o_t2, B2ln, L_n, o_l2, kappa[:, 1], ws2, q2_Bw2, a_F0,
                        b_F0, a_q2, B_q2, w2_px, w2_pV, w2_fx, w2_fV, w2_K,
                        random_walk)
                    smp_offsets[3, it] = o_w2[0]
                elif di == 4:
                    o_t2, F0_Bt2, q2_Bt2 = sampleAR_and_offset(
                        it, Tm1, vrnc2, vrncL2, B2tn, T_n, o_t2, B2ln, L_n,
                        o_l2, B2wn, W_n, o_w2, kappa[:, 1], ws2, q2_Bt2, a_F0,
                        b_F0, a_q2, B_q2, t2_px, t2_pV, t2_fx, t2_fV, t2_K,
                        random_walk)
                    smp_offsets[4, it] = o_t2[0]
                elif di == 5:
                    o_l2, F0_Bl2, q2_Bl2 = sampleAR_and_offset(
                        it, Tm1, vrnc2, vrncL2, B2ln, L_n, o_l2, B2wn, W_n,
                        o_w2, B2tn, T_n, o_t2, kappa[:, 1], ws2, q2_Bl2, a_F0,
                        b_F0, a_q2, B_q2, l2_px, l2_pV, l2_fx, l2_fV, l2_K,
                        random_walk)
                    smp_offsets[5, it] = o_l2[0]

            smp_Bns[0, it] = B1wn
            smp_Bns[1, it] = B1tn
            smp_Bns[2, it] = B1ln
            smp_Bns[3, it] = B2wn
            smp_Bns[4, it] = B2tn
            smp_Bns[5, it] = B2ln

            smp_q2s[it] = q2_Bw1, q2_Bt1, q2_Bl1, q2_Bw2, q2_Bt2, q2_Bl2
            smp_F0s[it] = F0_Bw1, F0_Bt1, F0_Bl1, F0_Bw2, F0_Bt2, F0_Bl2

            lw.rpg_devroye(N_vec[:, 0],
                           o_w1 * W_n + o_t1 * T_n + o_l1 * L_n,
                           out=ws1)
            lw.rpg_devroye(N_vec[:, 1],
                           o_w2 * W_n + o_t2 * T_n + o_l2 * L_n,
                           out=ws2)

            #  steps with no second PG variable get a negligible weight
            ws2[zr2] = 1e-20  #1e-20

        pklme = {}
        smp_every = 50
        pklme["smp_Bns"] = smp_Bns[:, ::smp_every]
        pklme["smp_q2s"] = smp_q2s[::smp_every]
        pklme["smp_F0s"] = smp_F0s[::smp_every]
        pklme["smp_offsets"] = smp_offsets[:, ::smp_every]
        pklme["smp_every"] = smp_every
        pklme["Wn"] = W_n
        pklme["Tn"] = T_n
        pklme["Ln"] = L_n
        pklme["hnd_dat"] = hnd_dat
        pklme["y_vec"] = y_vec
        pklme["N_vec"] = N_vec
        pklme["a_q2"] = a_q2
        pklme["B_q2"] = B_q2
        pklme["l_capped"] = l_capped
        #  context manager closes the file even if pickling fails
        with open(
                "%(dir)s/%(rel)s,%(cov)s%(ran)s2.dmp" % {
                    "rel": ssig,
                    "cov": scov,
                    "ran": sran,
                    "dir": out_dir
                }, "wb") as dmp:
            pickle.dump(pklme, dmp, -1)
        print("capped:  %d" % capped)
    def sample_posterior(self,
                         ITER,
                         a_F0,
                         b_F0,
                         a_q2,
                         B_q2,
                         smp_Bns,
                         smp_offsets,
                         smp_F0s,
                         smp_q2s,
                         off_mu=0,
                         off_sig2=0.4,
                         random_walk=False):
        """Run ITER Gibbs sweeps over two offset/AR components, storing draws.

        Each sweep visits components 0 and 1 in a freshly shuffled order and
        updates each one through ``sampleAR1_and_offset`` (defined elsewhere
        in the module). It then refreshes the auxiliary weights ``ws1``/``ws2``
        with ``lw.rpg_devroye``; presumably these are Polya-Gamma draws, to be
        confirmed against ``lw``.

        Args:
            ITER: number of sweeps; the sweep index ``it`` runs over
                ``range(ITER)``.
            a_F0, b_F0, a_q2, B_q2: prior hyperparameters, forwarded
                unchanged to ``sampleAR1_and_offset``.
            smp_Bns, smp_offsets: caller-allocated output arrays, written at
                ``[0, it]`` and ``[1, it]``.
            smp_F0s, smp_q2s: caller-allocated output arrays with at least
                two columns, written at row ``it``.
            off_mu, off_sig2: offset prior mean and variance. ``off_sig2`` is
                also combined with ``1/sum(ws)`` into ``vrnc1``/``vrnc2``.
            random_walk: forwarded to ``sampleAR1_and_offset``.

        Returns:
            None. Draws are left in the four ``smp_*`` arrays, and both
            columns of ``smp_F0s`` are printed at the end.

        NOTE(review): ``capped`` (printed every 1000 sweeps) is not defined
        in this method. It must exist as a module-level global, otherwise the
        first sweep raises NameError. Confirm.
        """
        oo = self

        #  Per-time-step work buffers handed to sampleAR1_and_offset.
        #  Presumably predicted/filtered means and variances plus a gain
        #  (px, pV, fx, fV, K). TODO confirm against sampleAR1_and_offset.
        #  NOTE(review): w1_pV starts at a constant 0.2, while w2_pV starts at
        #  uniform random values. Confirm the asymmetry is intended.
        w1_px = _N.random.randn(oo.Tm1)
        w1_pV = _N.ones(oo.Tm1) * 0.2
        w1_fx = _N.zeros(oo.Tm1)
        w1_fV = _N.ones(oo.Tm1) * 0.1
        w2_px = _N.random.randn(oo.Tm1)
        w2_pV = _N.random.rand(oo.Tm1)
        w2_fx = _N.zeros(oo.Tm1)
        w2_fV = _N.ones(oo.Tm1) * 0.1

        w1_K = _N.empty(oo.Tm1)
        w2_K = _N.empty(oo.Tm1)

        #  Random initial state of the chain.
        #  NOTE(review): B1wn/B2wn are never reassigned in this method, yet
        #  they are stored every sweep. Presumably sampleAR1_and_offset
        #  updates them in place. Verify.
        o_w1 = _N.random.randn(oo.Tm1)  #  start at 0 + u
        o_w2 = _N.random.randn(oo.Tm1)  #  start at 0 + u
        B1wn = _N.random.randn(oo.Tm1)  #  start at 0 + u
        B2wn = _N.random.randn(oo.Tm1)  #  start at 0 + u
        q2_Bw1 = 1.
        q2_Bw2 = 1.
        F0_Bw1 = 0
        F0_Bw2 = 0
        ws1 = _N.random.rand(oo.Tm1)
        ws2 = _N.random.rand(oo.Tm1)

        #  Rows where the second count is 0. Their ws2 entries are pinned to
        #  1e-20 after every weight refresh (end of each sweep).
        zr2 = _N.where(
            oo.N_vec[:, 1] == 0)[0]  #  dat where N_2 == 0  (only 1 PG var)
        #  NOTE(review): nzr2 is computed but never used below.
        nzr2 = _N.where(oo.N_vec[:, 1] == 1)[0]

        do_order = _N.arange(2)

        for it in range(ITER):
            if it % 1000 == 0:
                print("%(it)d   capped %(cp)d" % {"it": it, "cp": capped})

            #  vrncL = 1/sum(ws); vrnc = off_sig2*vrncL/(off_sig2+vrncL)
            #  = 1/(1/off_sig2 + 1/vrncL), i.e. the two variances combined
            #  harmonically.
            vrncL1 = 1 / _N.sum(ws1)
            vrnc1 = (off_sig2 * vrncL1) / (off_sig2 + vrncL1)
            vrncL2 = 1 / _N.sum(ws2)
            vrnc2 = (off_sig2 * vrncL2) / (off_sig2 + vrncL2)

            #  Random scan: the update order of the two components is
            #  reshuffled every sweep.
            _N.random.shuffle(do_order)

            for di in do_order:
                #################
                if di == 0:
                    #  NOTE(review): smp_offsets[0, it] receives whatever
                    #  sampleAR1_and_offset returns as o_w1. If that is an
                    #  array (as at initialization), smp_offsets needs a
                    #  trailing axis. Confirm.
                    o_w1, F0_Bw1, q2_Bw1 \
                        = sampleAR1_and_offset(it, oo.Tm1, off_mu, off_sig2,
                                               vrnc1, vrncL1,B1wn, o_w1,
                                               oo.kappa[:, 0], ws1, q2_Bw1,
                                               a_F0, b_F0, a_q2, B_q2,
                                               w1_px, w1_pV, w1_fx, w1_fV,
                                               w1_K, random_walk)
                    smp_offsets[0, it] = o_w1
                elif di == 1:
                    #################
                    o_w2, F0_Bw2, q2_Bw2 \
                        = sampleAR1_and_offset(it, oo.Tm1, off_mu, off_sig2,
                                               vrnc2, vrncL2, B2wn, o_w2,
                                               oo.kappa[:, 1], ws2, q2_Bw2,
                                               a_F0, b_F0, a_q2, B_q2,
                                               w2_px, w2_pV, w2_fx, w2_fV,
                                               w2_K, random_walk)
                    smp_offsets[1, it] = o_w2

            #  Record this sweep's state.
            smp_Bns[0, it] = B1wn
            smp_Bns[1, it] = B2wn

            #smp_q2s[it, 2*cond:2*cond+2]  = q2_Bw1, q2_Bw2y
            smp_q2s[it] = q2_Bw1, q2_Bw2
            #if random_walk:
            #    F0_Bw1 = F0_Bt1 = F0_Bl1 = F0_Bw2 = F0_Bt2 = F0_Bl2 = 1
            smp_F0s[it] = F0_Bw1, F0_Bw2

            #  Refresh auxiliary weights in place (out=), using counts from
            #  N_vec and the current linear predictor B + offset.
            lw.rpg_devroye(oo.N_vec[:, 0], B1wn + o_w1, out=ws1)
            lw.rpg_devroye(oo.N_vec[:, 1], B2wn + o_w2, out=ws2)

            #  Effectively remove rows with a zero second count.
            ws2[zr2] = 1e-20  #1e-20

        print(smp_F0s[:, 0])
        print(smp_F0s[:, 1])
Beispiel #6
0
    def dirichletAllocate(self):  ###########################  GIBBSSAMP
        """Run the Gibbs sampler with two-state trial allocation.

        Sample storage is allocated for ``burn + NMC`` iterations. The loop
        then runs while ``it < NMC + burn - 1``. ``it`` is incremented at the
        top of each pass, so index 0 of every sample array is never written.

        Per iteration, in this order:
          * trial state labels ``Z``, when ``it > startZ``, and mixing weights
            ``m`` drawn from a Dirichlet;
          * auxiliary weights ``ws`` (``lw.rpg_devroye``) and ``kp / ws``;
          * history basis coefficients, unless ``bFixH``;
          * PSTH spline weights ``aS``, when ``bpsth``;
          * per-trial offsets ``us``, constrained to sum to zero;
          * the latent AR state, via ``_kfar.armdl_FFBS_1itrMP`` (serially,
            or through a process pool when ``processes > 1``);
          * AR roots, unless ``bFixF``; the state scale ``s[1]``, when
            ``doS``; and ``q2`` from an inverse-gamma draw.
        Every ``peek`` iterations a diagnostic figure is saved and samples are
        dumped. A final dump happens at the end.
        """
        oo = self
        ooTR = oo.TR
        print(ooTR)
        ook = oo.k
        ooNMC = oo.NMC
        ooN = oo.N

        oo.allocateSmp(oo.burn + oo.NMC)
        oo.x00 = _N.array(oo.smpx[:, 2])
        oo.V00 = _N.zeros((ooTR, ook, ook))

        oo.loghist = _N.zeros(oo.N + 1)

        ARo = _N.empty((ooTR, oo._d.N + 1))

        kpOws = _N.empty((ooTR, ooN + 1))
        lv_f = _N.zeros((ooN + 1, ooN + 1))
        smWimOm = _N.zeros(ooN + 1)

        D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  #  spline
        iD_f = _N.linalg.inv(D_f)
        BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
        DB = _N.dot(D_f, oo.B)
        BTua = _N.dot(oo.B.T, oo.u_a)

        it = 0

        #  View of oo.us, so in-place updates of oo.us are seen through it.
        oous_rs = oo.us.reshape((ooTR, 1))  #  done for broadcasting rules

        sd01 = _N.zeros((oo.nStates, oo.TR, oo.TR))
        _N.fill_diagonal(sd01[0], oo.s[0])
        _N.fill_diagonal(sd01[1], oo.s[1])

        smpx01 = _N.zeros((oo.nStates, oo.TR, oo.N + 1))
        zsmpx = _N.empty((oo.TR, oo.N + 1))

        zd = _N.zeros((oo.TR, oo.TR))
        izd = _N.zeros((oo.TR, oo.TR))
        ll = _N.zeros(oo.nStates)

        for m in range(ooTR):
            oo._d.f_V[m, 0] = oo.s2_x00
            oo._d.f_V[m, 1] = oo.s2_x00

        THR = _N.empty(oo.TR)
        dirArgs = _N.empty(oo.nStates)  #  dirichlet distribution args
        BaS = _N.dot(oo.B.T, oo.aS)

        alpR = oo.F_alfa_rep[0:oo.R]
        alpC = oo.F_alfa_rep[oo.R:]

        oo.nSMP_smpxC = 0
        pool = None  #  only created when running the FFBS step in parallel
        if oo.processes > 1:
            print(oo.processes)
            pool = Pool(processes=oo.processes)
        print("oo.mcmcRunDir    %s" % oo.mcmcRunDir)
        if oo.mcmcRunDir is None:
            oo.mcmcRunDir = ""
        elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
            oo.mcmcRunDir += "/"

        #  H shape    100 x 9
        Hbf = oo.Hbf

        RHS = _N.empty((oo.histknots, 1))

        #  cInds: history knots with fixed coefficients (kept in RHS).
        #  vInds: knots that are sampled.
        #  NOTE(review): these index sets assume histknots == 10.
        if oo.h0_1 > 1:  #  first few are 0s
            cInds = _N.array([0, 4, 5, 6, 7, 8, 9])
            vInds = _N.array([1, 2, 3])
            RHS[cInds, 0] = 0
            RHS[0, 0] = -5
        else:
            cInds = _N.array([4, 5, 6, 7, 8, 9])
            vInds = _N.array([0, 1, 2, 3])
            RHS[cInds, 0] = 0

        Msts = []
        for m in range(ooTR):
            Msts.append(_N.where(oo.y[m] == 1)[0])
        HbfExpd = _N.empty((oo.histknots, ooTR, oo.N + 1))
        HcM = _N.empty((len(vInds), len(vInds)))

        #  HbfExpd[i, m] is nonzero only on each interval (sts[k], sts[k+1]]
        #  between consecutive spikes of trial m, where it holds
        #  Hbf[0:len, i]. Starting from zeros means the bin at the first spike
        #  (previously never written, so uninitialized) is 0. It also means
        #  trials with no spikes no longer raise IndexError on sts[0].
        HbfExpd = _N.zeros((oo.histknots, ooTR, oo.N + 1))
        for i in range(oo.histknots):
            for m in range(oo.TR):
                sts = Msts[m]
                for iss in range(len(sts) - 1):
                    s0 = sts[iss]
                    s1 = sts[iss + 1]
                    HbfExpd[i, m, s0 + 1:s1 + 1] = Hbf[0:s1 - s0, i]

        _N.dot(oo.B.T, oo.aS, out=BaS)
        if oo.hS is None:
            oo.hS = _N.zeros(oo.histknots)

        _N.dot(Hbf, oo.hS, out=oo.loghist)
        oo.stitch_Hist(ARo, oo.loghist, Msts)

        ##  ORDER OF SAMPLING
        ##  f_xx, f_V
        ##  BINARY state
        ##  DA:  PG, kpOws
        ##  history, build ARo
        ##  psth
        ##  offset
        ##  DA:  latent state
        ##  AR coefficients
        ##  q2

        while it < ooNMC + oo.burn - 1:
            t1 = _tm.time()
            it += 1  #  it >= 1 from here on
            print("****------------  %d" % it)
            oo._d.f_x[:, 0, :, 0] = oo.x00
            oo._d.f_V[:, 0] = _N.mean(oo._d.f_V[:, 1:], axis=1)

            #  generate latent AR state

            if it > oo.startZ:
                for tryZ in range(oo.nStates):
                    _N.dot(sd01[tryZ], oo.smpx[..., 2:, 0], out=smpx01[tryZ])

                for m in oo.varz:  #  only allow certain trials to change
                    for tryZ in range(oo.nStates):
                        #  Linear predictor of trial m under scaling state tryZ.
                        eta = (smpx01[tryZ, m] + BaS + ARo[m] + oo.us[m] +
                               oo.knownSig[m])
                        #  Bernoulli log-likelihood with a logistic link:
                        #  log p(y | eta) = y*eta - log(1 + e^eta). logaddexp
                        #  keeps large |eta| from overflowing to inf/nan.
                        ll[tryZ] = _N.sum(oo.y[m] * eta -
                                          _N.logaddexp(0., eta))

                    #  P(Z=0) = m0 L0 / (m0 L0 + m1 L1)
                    #         = m0 / (m0 + m1 exp(ll1 - ll0)).
                    #  This form stays finite; m0 e^ll0 / nc gave inf/inf = nan
                    #  when ll0 >> ll1.
                    THR[m] = oo.m[0] / (oo.m[0] +
                                        oo.m[1] * _N.exp(ll[1] - ll[0]))

                    oo.Z[m, 0] = 0
                    oo.Z[m, 1] = 1
                    if _N.random.rand() < THR[m]:
                        oo.Z[m, 0] = 1
                        oo.Z[m, 1] = 0
                    oo.smp_zs[m, it] = oo.Z[m]

                for m in oo.fxdz:  #####  outside BM loop
                    oo.smp_zs[m, it] = oo.Z[m]

                #  Z  set
                _N.fill_diagonal(zd, oo.s[oo.Z[:, 1]])
                _N.fill_diagonal(izd, 1. / oo.s[oo.Z[:, 1]])
                _N.dot(zd, oo.smpx[..., 2:, 0], out=zsmpx)
                ######  sample m's
                _N.add(oo.alp, _N.sum(oo.Z[oo.varz], axis=0), out=dirArgs)
                oo.m[:] = _N.random.dirichlet(dirArgs)
                oo.smp_ms[it] = oo.m
            else:
                _N.fill_diagonal(zd, oo.s[oo.Z[:, 1]])
                _N.fill_diagonal(izd, 1. / oo.s[oo.Z[:, 1]])

                _N.dot(zd, oo.smpx[..., 2:, 0], out=zsmpx)
                oo.smp_ms[it] = oo.m
                oo.smp_zs[:, it, 1] = 1
                oo.smp_zs[:, it, 0] = 0
            print(oo.m)

            lwsts = _N.where(oo.Z[:, 0] == 1)[0]
            hists = _N.where(oo.Z[:, 1] == 1)[0]

            ######  PG generate
            for m in range(ooTR):
                lw.rpg_devroye(oo.rn,
                               zsmpx[m] + oo.us[m] + BaS + ARo[m] +
                               oo.knownSig[m],
                               out=oo.ws[m])  ######  devryoe  ####TRD change

            _N.divide(oo.kp, oo.ws, out=kpOws)

            if not oo.bFixH:
                O = kpOws - zsmpx - oo.us.reshape(
                    (ooTR, 1)) - BaS - oo.knownSig

                iOf = vInds[0]  #  offset HcM index with RHS index.
                for i in vInds:
                    for j in vInds:
                        HcM[i - iOf,
                            j - iOf] = _N.sum(oo.ws * HbfExpd[i] * HbfExpd[j])

                    #  Move the contribution of the fixed knots to the RHS.
                    RHS[i, 0] = _N.sum(oo.ws * HbfExpd[i] * O)
                    for cj in cInds:
                        RHS[i, 0] -= _N.sum(
                            oo.ws * HbfExpd[i] * HbfExpd[cj]) * RHS[cj, 0]

                vm = _N.linalg.solve(HcM, RHS[vInds])
                Cov = _N.linalg.inv(HcM)
                print(vm)
                cfs = _N.random.multivariate_normal(vm[:, 0], Cov, size=1)

                RHS[vInds, 0] = cfs[0]
                oo.smp_hS[:, it] = RHS[:, 0]

                _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
                oo.smp_hist[:, it] = oo.loghist
                oo.stitch_Hist(ARo, oo.loghist, Msts)

            ########     PSTH sample  Do PSTH after we generate zs
            if oo.bpsth:
                Oms = kpOws - zsmpx - ARo - oous_rs - oo.knownSig
                _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)  #  sum over
                ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                lm_f = _N.dot(lv_f, smWimOm)  #  nondiag of 1./Bi are inf
                #  now sample
                iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                VAR = _N.linalg.inv(iVAR)  #  knots x knots
                Mn = oo.u_a + _N.dot(DB,
                                     _N.linalg.solve(BDB + lv_f, lm_f - BTua))
                oo.aS = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
                oo.smp_aS[it, :] = oo.aS
            else:
                oo.aS[:] = 0
            BaS = _N.dot(oo.B.T, oo.aS)

            ########     per trial offset sample
            Ons = kpOws - zsmpx - ARo - BaS - oo.knownSig

            #  Solve for the mean of the distribution. Only us[1:] is free;
            #  us[0] = -sum(us[1:]) keeps the offsets summing to zero.
            H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
            uRHS = _N.empty(oo.TR - 1)
            for dd in range(1, oo.TR):
                H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
                uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] - oo.ws[0] * Ons[0])

            MM = _N.linalg.solve(H, uRHS)
            Cov = _N.linalg.inv(H)

            oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
            oo.us[0] = -_N.sum(oo.us[1:])
            oo.smp_u[:, it] = oo.us

            ####  Sample latent state
            oo._d.y = _N.dot(izd, kpOws - BaS - ARo - oous_rs - oo.knownSig)
            oo._d.copyParams(oo.F0, oo.q2)
            #  (MxM)  (MxN) = (MxN)  (Rv is MxN)
            _N.dot(_N.dot(izd, izd), 1. / oo.ws, out=oo._d.Rv)

            oo._d.f_x[:, 0, :, 0] = oo.x00
            #  A former per-trial reset of f_V[:, 0] to s2_x00 sat in a
            #  for/else and was always overwritten by this line, so only the
            #  effective assignment is kept.
            oo._d.f_V[:, 0] = _N.mean(oo._d.f_V[:, 1:], axis=1)

            #  list() so tpl_args is indexable under Python 3 as well.
            tpl_args = list(
                zip(oo._d.y, oo._d.Rv, oo._d.Fs, _N.linalg.inv(oo._d.Fs),
                    oo.q2, oo._d.Ns, oo._d.ks, oo._d.f_x[:, 0],
                    oo._d.f_V[:, 0]))

            if pool is None:
                for m in range(ooTR):
                    oo.smpx[m, 2:], oo._d.f_x[m], oo._d.f_V[m] = \
                        _kfar.armdl_FFBS_1itrMP(tpl_args[m])
                    oo.smpx[m, 1, 0:ook - 1] = oo.smpx[m, 2, 1:]
                    oo.smpx[m, 0, 0:ook - 2] = oo.smpx[m, 2, 2:]
                    oo.smp_q2[m, it] = oo.q2[m]
            else:
                sxv = pool.map(_kfar.armdl_FFBS_1itrMP, tpl_args)
                for m in range(ooTR):
                    oo.smpx[m, 2:] = sxv[m][0]
                    oo._d.f_x[m] = sxv[m][1]
                    oo._d.f_V[m] = sxv[m][2]
                    oo.smpx[m, 1, 0:ook - 1] = oo.smpx[m, 2, 1:]
                    oo.smpx[m, 0, 0:ook - 2] = oo.smpx[m, 2, 2:]

            stds = _N.std(oo.smpx[:, 2:, 0], axis=1)
            oo.mnStds[it] = _N.mean(stds, axis=0)
            print("mnStd  %.3f" % oo.mnStds[it])
            ###
            if not oo.bFixF:
                ARcfSmpl(oo.lfc,
                         ooN + 1,
                         ook,
                         oo.AR2lims,
                         oo.smpx[:, 1:, 0:ook],
                         oo.smpx[:, :, 0:ook - 1],
                         oo.q2,
                         oo.R,
                         oo.Cs,
                         oo.Cn,
                         alpR,
                         alpC,
                         oo.TR,
                         prior=oo.use_prior,
                         accepts=80,
                         aro=oo.ARord,
                         sig_ph0L=oo.sig_ph0L,
                         sig_ph0H=oo.sig_ph0H)
                oo.F_alfa_rep = alpR + alpC  #  new constructed
                prt, rank, f, amp = ampAngRep(oo.F_alfa_rep, f_order=True)
                print(prt)
            oo.allalfas[it] = oo.F_alfa_rep

            if (not oo.bFixF) and ooTR > 0:
                oo.amps[it, :] = amp
                oo.fs[it, :] = f

            #  AR coefficients from the sampled roots.
            oo.F0 = (-1 * _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]

            print("len(lwsts) %(l)d   len(hists) %(h)d" % {
                "l": len(lwsts),
                "h": len(hists)
            })
            sts2chg = hists
            if (it > oo.startZ) and oo.doS and len(sts2chg) > 0:
                AL = 0.5 * _N.sum(oo.smpx[sts2chg, 2:, 0] *
                                  oo.smpx[sts2chg, 2:, 0] * oo.ws[sts2chg])
                BRL = kpOws[sts2chg] - BaS - oous_rs[sts2chg] - ARo[
                    sts2chg] - oo.knownSig[sts2chg]
                BL = _N.sum(oo.ws[sts2chg] * BRL * oo.smpx[sts2chg, 2:, 0])
                UL = BL / (2 * AL)
                sg2 = 1. / (2 * AL)

                #  Combine with a Gaussian prior N(u_pr, q2_pr) on s[1].
                q2_pr = 0.0025  # 0.05**2
                u_pr = 1.
                U = (u_pr * sg2 + UL * q2_pr) / (sg2 + q2_pr)
                sg = _N.sqrt((sg2 * q2_pr) / (sg2 + q2_pr))

                print("U  %(U).4f    UL %(UL).4f s  %(s).3f" % {
                    "U": U,
                    "s": sg,
                    "UL": UL
                })
                if _N.isnan(U):
                    print("U is nan  UL %.4f" % UL)
                    print("U is nan  AL %.4f" % AL)
                    print("U is nan  BL %.4f" % BL)
                    print("U is nan  BaS ")
                    print("hists")
                    print(hists)
                    print("lwsts")
                    print(lwsts)

                oo.s[1] = U + sg * _N.random.randn()

                _N.fill_diagonal(sd01[0], oo.s[0])
                _N.fill_diagonal(sd01[1], oo.s[1])
                print(oo.s[1])
                oo.smp_ss[it] = oo.s[1]

            oo.a2 = oo.a_q2 + 0.5 * (ooTR * ooN + 2)  #  N + 1 - 1
            BB2 = oo.B_q2
            for m in range(ooTR):
                #   set x00
                oo.x00[m] = oo.smpx[m, 2] * 0.001

                #####################    sample q2
                rsd_stp = oo.smpx[m, 3:, 0] - _N.dot(oo.smpx[m, 2:-1], oo.F0).T
                #  Not "+=": if oo.B_q2 were an ndarray, an in-place add would
                #  accumulate into the prior across iterations.
                BB2 = BB2 + 0.5 * _N.dot(rsd_stp, rsd_stp.T)
            oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)

            oo.smp_q2[:, it] = oo.q2
            t7 = _tm.time()
            print("gibbs iter %.3f" % (t7 - t1))
            if (it > 1) and (it % oo.peek == 0):
                fig = _plt.figure(figsize=(12, 8))
                fig.add_subplot(4, 1, 1)
                _plt.plot(oo.amps[1:it, 0])
                fig.add_subplot(4, 1, 2)
                _plt.plot(oo.fs[1:it, 0])
                fig.add_subplot(4, 1, 3)
                _plt.plot(oo.mnStds[1:it])
                fig.add_subplot(4, 1, 4)
                _plt.plot(oo.smp_ms[1:it, 0])

                _plt.savefig("%(dir)s/tmp-fsamps%(it)d" % {
                    "dir": oo.mcmcRunDir,
                    "it": it
                })
                _plt.close()

                oo.dump_smpsS(toiter=it, dir=oo.mcmcRunDir)

        #  Release worker processes; the original never closed the pool.
        if pool is not None:
            pool.close()
            pool.join()
        oo.dump_smpsS(dir=oo.mcmcRunDir)
Beispiel #7
0
                l2_px, l2_pV, l2_fx, l2_fV, l2_K, random_walk)
            smp_offsets[5, it] = o_l2[0]

    smp_Bns[0, it] = B1wn
    smp_Bns[1, it] = B1tn
    smp_Bns[2, it] = B1ln
    smp_Bns[3, it] = B2wn
    smp_Bns[4, it] = B2tn
    smp_Bns[5, it] = B2ln

    smp_q2s[it] = q2_Bw1, q2_Bt1, q2_Bl1, q2_Bw2, q2_Bt2, q2_Bl2
    #if random_walk:
    #    F0_Bw1 = F0_Bt1 = F0_Bl1 = F0_Bw2 = F0_Bt2 = F0_Bl2 = 1
    smp_F0s[it] = F0_Bw1, F0_Bt1, F0_Bl1, F0_Bw2, F0_Bt2, F0_Bl2

    lw.rpg_devroye(N_vec[:, 0], o_w1 * W_n + o_t1 * T_n + o_l1 * L_n, out=ws1)
    lw.rpg_devroye(N_vec[:, 1], o_w2 * W_n + o_t2 * T_n + o_l2 * L_n, out=ws2)

    ws2[zr2] = 1e-20  #1e-20

# Thin the chains to every smp_every-th draw and gather them, together with
# the covariates and data, into one dict (presumably pickled further on).
smp_every = 50
pklme = {
    "smp_Bns": smp_Bns[:, ::smp_every],
    "smp_q2s": smp_q2s[::smp_every],
    "smp_F0s": smp_F0s[::smp_every],
    "smp_offsets": smp_offsets[:, ::smp_every],
    "smp_every": smp_every,
    "Wn": W_n,
    "Tn": T_n,
    "Ln": L_n,
    "hnd_dat": hnd_dat,
}
Beispiel #8
0
    def dirichletAllocate(self):  ###########################  GIBBSSAMP
        oo = self

        signal.signal(signal.SIGINT, signal_handler)

        ooTR = oo.TR
        ook = oo.k
        ooN = oo.N

        runTO = oo.ITERS - 1
        oo.allocateSmp(runTO + 1, Bsmpx=oo.doBsmpx)
        #oo.allocateSmp(oo.burn + oo.NMC)
        oo.x00 = _N.array(oo.smpx[:, 2])
        oo.V00 = _N.zeros((ooTR, ook, ook))

        _kfar.init(oo.N, oo.k, oo.TR)

        if oo.dohist:
            oo.loghist = _N.zeros(oo.Hbf.shape[0])
        else:
            print("fixed hist is")
            print(oo.loghist)

        ARo = _N.zeros((ooTR, ooN + 1))

        kpOws = _N.empty((ooTR, ooN + 1))
        lv_f = _N.zeros((ooN + 1, ooN + 1))
        lv_u = _N.zeros((ooTR, ooTR))
        Bii = _N.zeros((ooN + 1, ooN + 1))

        #alpC.reverse()
        #  F_alfa_rep = alpR + alpC  already in right order, no?

        Wims = _N.empty((ooTR, ooN + 1, ooN + 1))
        Oms = _N.empty((ooTR, ooN + 1))
        smWimOm = _N.zeros(ooN + 1)
        smWinOn = _N.zeros(ooTR)
        bConstPSTH = False
        D_f = _N.diag(_N.ones(oo.B.shape[0]) * oo.s2_a)  #  spline
        iD_f = _N.linalg.inv(D_f)
        D_u = _N.diag(_N.ones(oo.TR) * oo.s2_u)  #  This should
        iD_u = _N.linalg.inv(D_u)
        iD_u_u_u = _N.dot(iD_u, _N.ones(oo.TR) * oo.u_u)
        BDB = _N.dot(oo.B.T, _N.dot(D_f, oo.B))
        DB = _N.dot(D_f, oo.B)
        BTua = _N.dot(oo.B.T, oo.u_a)

        it = 0

        ###############################  MCMC LOOP  ########################
        ###  need pointer to oo.us, but reshaped for broadcasting to work
        ###############################  MCMC LOOP  ########################
        oous_rs = oo.us.reshape((ooTR, 1))  #  done for broadcasting rules

        sd01 = _N.zeros((oo.nStates, oo.TR, oo.TR))
        _N.fill_diagonal(sd01[0], oo.s[0])
        _N.fill_diagonal(sd01[1], oo.s[1])

        smpx01 = _N.zeros((oo.nStates, oo.TR, oo.N + 1))
        zsmpx = _N.empty((oo.TR, oo.N + 1))

        #  zsmpx created
        #  PG

        zd = _N.zeros((oo.TR, oo.TR))
        izd = _N.zeros((oo.TR, oo.TR))
        ll = _N.zeros(oo.nStates)
        Bp = _N.empty((oo.nStates, oo.N + 1))

        for m in range(ooTR):
            oo.f_V[m, 0] = oo.s2_x00
            oo.f_V[m, 1] = oo.s2_x00

        THR = _N.empty(oo.TR)
        dirArgs = _N.empty(oo.nStates)  #  dirichlet distribution args
        expT = _N.empty(ooN + 1)
        BaS = _N.dot(oo.B.T, oo.aS)

        alpR = oo.F_alfa_rep[0:oo.R]
        alpC = oo.F_alfa_rep[oo.R:]

        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(oo.F_alfa_rep)
        print("*****************************")
        print(alpR)
        print(alpC)

        oo.nSMP_smpxC = 0

        if oo.mcmcRunDir is None:
            oo.mcmcRunDir = ""
        elif (len(oo.mcmcRunDir) > 0) and (oo.mcmcRunDir[-1] != "/"):
            oo.mcmcRunDir += "/"

        #  H shape    100 x 9
        Hbf = oo.Hbf

        RHS = _N.empty((oo.histknots, 1))

        cInds = _N.arange(oo.iHistKnotBeginFixed, oo.histknots)
        vInds = _N.arange(0, oo.iHistKnotBeginFixed)
        RHS[cInds, 0] = 0

        Msts = []
        for m in range(ooTR):
            Msts.append(_N.where(oo.y[m] == 1)[0])
        HcM = _N.empty((len(vInds), len(vInds)))

        HbfExpd = _N.empty((oo.histknots, ooTR, oo.N + 1))
        #  HbfExpd is 11 x M x 1200
        #  find the mean.  For the HISTORY TERM
        for i in range(oo.histknots):
            for m in range(oo.TR):
                sts = Msts[m]
                HbfExpd[i, m, 0:sts[0]] = 0
                for iss in range(len(sts) - 1):
                    t0 = sts[iss]
                    t1 = sts[iss + 1]
                    HbfExpd[i, m, t0 + 1:t1 + 1] = Hbf[0:t1 - t0, i]
                HbfExpd[i, m, sts[-1] + 1:] = 0

        _N.dot(oo.B.T, oo.aS, out=BaS)
        if oo.hS is None:
            oo.hS = _N.zeros(oo.histknots)

        _N.dot(Hbf, oo.hS, out=oo.loghist)
        oo.stitch_Hist(ARo, oo.loghist, Msts)

        K = _N.empty((oo.TR, oo.N + 1, oo.k))  #  kalman gain

        iterBLOCKS = oo.ITERS // oo.peek
        smpx_tmp = _N.empty((oo.TR, oo.N + 1, oo.k))

        ##  ORDER OF SAMPLING
        ##  f_xx, f_V
        ##  BINARY state
        ##  DA:  PG, kpOws
        ##  history, build ARo
        ##  psth
        ##  offset
        ##  DA:  latent state
        ##  AR coefficients
        ##  q2
        oo.gau_var = _N.array(oo.ws)

        #iterBLOCKS = 1
        #oo.peek = 1

        arangeNp1 = _N.arange(oo.N + 1)
        for itrB in range(iterBLOCKS):
            it = itrB * oo.peek
            if it > 0:
                print(
                    "it: %(it)d    mnStd  %(mnstd).3f   fs  %(fs).3f    m %(m).3f    [%(0).2f,%(1).2f]"
                    % {
                        "it": itrB * oo.peek,
                        "mnstd": oo.mnStds[it - 1],
                        "fs": oo.fs[it - 1, 0],
                        "m": oo.m[0],
                        "0": oo.s[0],
                        "1": oo.s[1]
                    })

            #tttA = _tm.time()
            if interrupted:
                break
            for it in range(itrB * oo.peek, (itrB + 1) * oo.peek):

                lowsts = _N.where(oo.Z[:, 0] == 1)
                #print "lowsts   %s" % str(lowsts)
                t1 = _tm.time()
                oo.f_x[:, 0] = oo.x00
                if it == 0:
                    for m in range(ooTR):
                        oo.f_V[m, 0] = oo.s2_x00
                else:
                    oo.f_V[:, 0] = _N.mean(oo.f_V[:, 1:], axis=1)

                #  generate latent AR state

                if it > oo.startZ:
                    for tryZ in range(oo.nStates):
                        _N.dot(sd01[tryZ], oo.smpx[:, 2:, 0], out=smpx01[tryZ])

                    for m in range(oo.TR):
                        for tryZ in range(
                                oo.nStates
                        ):  #  only allow certain trials to change

                            #  calculate p0, p1  p0 = m_0 x PROD_n Ber(y_n | Z_j)
                            #                       = m_0 x _N.exp(_N.log(  ))
                            #  p0, p1 not normalized
                            #  Ber(0 | ) and Ber(1 | )
                            _N.exp(smpx01[tryZ, m] + BaS + ARo[m] + oo.us[m] +
                                   oo.knownSig[m],
                                   out=expT)
                            Bp[0] = 1 / (1 + expT)
                            Bp[1] = expT / (1 + expT)

                            #   z[:, 1]   is state label
                            #ll[tryZ] = 0
                            ll[tryZ] = _N.sum(
                                _N.log(Bp[oo.y[m, arangeNp1], arangeNp1]))

                        ofs = _N.min(ll)
                        ll -= ofs
                        #nc = oo.m[0]*_N.exp(ll[0]) + oo.m[1]*_N.exp(ll[1])
                        nc = oo.m[0] + oo.m[1] * _N.exp(ll[1] - ll[0])

                        oo.Z[m, 0] = 0
                        oo.Z[m, 1] = 1
                        #THR[m] = (oo.m[0]*_N.exp(ll[0]) / nc)
                        THR[m] = (oo.m[0] / nc)
                        if _N.random.rand() < THR[m]:
                            oo.Z[m, 0] = 1
                            oo.Z[m, 1] = 0
                        oo.smp_zs[m, it] = oo.Z[m]
                    for m in oo.fxdz:  #####  outside BM loop
                        oo.smp_zs[m, it] = oo.Z[m]
                    #  Z  set
                    _N.fill_diagonal(zd, oo.s[oo.Z[:, 1]])
                    _N.fill_diagonal(izd, 1. / oo.s[oo.Z[:, 1]])

                    _N.dot(zd, oo.smpx[..., 2:, 0], out=zsmpx)
                    ######  sample m's
                    _N.add(oo.alp, _N.sum(oo.Z[oo.varz], axis=0), out=dirArgs)
                    oo.m[:] = _N.random.dirichlet(dirArgs)
                    oo.smp_ms[it] = oo.m

                else:  #  turned off dirichlet, always allocate to low state
                    _N.fill_diagonal(zd, oo.s[oo.Z[:, 1]])
                    _N.fill_diagonal(izd, 1. / oo.s[oo.Z[:, 1]])

                    _N.dot(zd, oo.smpx[:, 2:, 0], out=zsmpx)
                    ######  sample m's
                    oo.smp_ms[it] = oo.m
                    oo.smp_zs[:, it, 1] = 1
                    oo.smp_zs[:, it, 0] = 0

                lwsts = _N.where(oo.Z[:, 0] == 1)[0]
                hists = _N.where(oo.Z[:, 1] == 1)[0]

                #print(zsmpx[0, 0:20])
                #print(oo.smpx[0, 2:22, 0])
                t3 = _tm.time()

                ######  PG generate
                for m in range(ooTR):
                    ###  CHANGE 1
                    #lw.rpg_devroye(oo.rn, oo.smpx[m, 2:, 0] + oo.us[m] + BaS + ARo[m] + oo.knownSig[m], out=oo.ws[m])  ######  devryoe
                    lw.rpg_devroye(
                        oo.rn,
                        zsmpx[m] + oo.us[m] + BaS + ARo[m] + oo.knownSig[m],
                        out=oo.ws[m])  ######  devryoe  ####TRD change

                _N.divide(oo.kp, oo.ws, out=kpOws)

                if oo.dohist:
                    #O = kpOws - oo.smpx[..., 2:, 0] - oo.us.reshape((ooTR, 1)) - BaS -  oo.knownSig
                    O = kpOws - zsmpx - oo.us.reshape(
                        (ooTR, 1)) - BaS - oo.knownSig

                    for ii in range(len(vInds)):
                        #print("i   %d" % i)
                        #print(_N.sum(HbfExpd[i]))
                        i = vInds[ii]
                        for jj in range(ii, len(vInds)):
                            j = vInds[jj]
                            #print("j   %d" % j)
                            #print(_N.sum(HbfExpd[j]))
                            HcM[ii,
                                jj] = _N.sum(oo.ws * HbfExpd[i] * HbfExpd[j])
                            HcM[jj, ii] = HcM[ii, jj]

                        RHS[ii, 0] = _N.sum(oo.ws * HbfExpd[i] * O)
                        for cj in cInds:
                            RHS[ii, 0] -= _N.sum(
                                oo.ws * HbfExpd[i] * HbfExpd[cj]) * RHS[cj, 0]

                    vm = _N.linalg.solve(HcM, RHS[vInds])
                    Cov = _N.linalg.inv(HcM)
                    cfs = _N.random.multivariate_normal(vm[:, 0], Cov, size=1)

                    RHS[vInds, 0] = cfs[0]
                    oo.smp_hS[:, it] = RHS[:, 0]

                    #RHS[2:6, 0] = vm[:, 0]
                    #print HcM
                    #vv = _N.dot(Hbf, RHS)
                    #print vv.shape
                    #print oo.loghist.shape
                    _N.dot(Hbf, RHS[:, 0], out=oo.loghist)
                    oo.smp_hist[:, it] = oo.loghist
                    oo.stitch_Hist(ARo, oo.loghist, Msts)

                ########     PSTH sample  Do PSTH after we generate zs
                if oo.bpsth:
                    #Oms  = kpOws - oo.smpx[..., 2:, 0] - ARo - oous_rs - oo.knownSig
                    Oms = kpOws - zsmpx - ARo - oous_rs - oo.knownSig
                    _N.einsum("mn,mn->n", oo.ws, Oms, out=smWimOm)  #  sum over
                    ilv_f = _N.diag(_N.sum(oo.ws, axis=0))
                    _N.fill_diagonal(lv_f, 1. / _N.diagonal(ilv_f))
                    lm_f = _N.dot(lv_f, smWimOm)  #  nondiag of 1./Bi are inf
                    #  now sample
                    iVAR = _N.dot(oo.B, _N.dot(ilv_f, oo.B.T)) + iD_f
                    VAR = _N.linalg.inv(iVAR)  #  knots x knots
                    #iBDBW = _N.linalg.inv(BDB + lv_f)   # BDB not diag
                    #Mn    = oo.u_a + _N.dot(DB, _N.dot(iBDBW, lm_f - BTua))

                    Mn = oo.u_a + _N.dot(
                        DB, _N.linalg.solve(BDB + lv_f, lm_f - BTua))
                    oo.aS = _N.random.multivariate_normal(Mn, VAR,
                                                          size=1)[0, :]
                    oo.smp_aS[it, :] = oo.aS

                    #iBDBW = _N.linalg.inv(BDB + lv_f)   # BDB not diag
                    #Mn    = oo.u_a + _N.dot(DB, _N.dot(iBDBW, lm_f - BTua))
                    #oo.aS   = _N.random.multivariate_normal(Mn, VAR, size=1)[0, :]
                    #oo.smp_aS[it, :] = oo.aS
                else:
                    oo.aS[:] = 0
                _N.dot(oo.B.T, oo.aS, out=BaS)

                ########     per trial offset sample
                #Ons  = kpOws - zsmpx - ARo - BaS - oo.knownSig
                Ons = kpOws - oo.smpx[..., 2:, 0] - ARo - BaS - oo.knownSig

                #  solve for the mean of the distribution
                H = _N.ones((oo.TR - 1, oo.TR - 1)) * _N.sum(oo.ws[0])
                uRHS = _N.empty(oo.TR - 1)
                for dd in range(1, oo.TR):
                    H[dd - 1, dd - 1] += _N.sum(oo.ws[dd])
                    uRHS[dd - 1] = _N.sum(oo.ws[dd] * Ons[dd] -
                                          oo.ws[0] * Ons[0])

                MM = _N.linalg.solve(H, uRHS)
                Cov = _N.linalg.inv(H)

                oo.us[1:] = _N.random.multivariate_normal(MM, Cov, size=1)
                oo.us[0] = -_N.sum(oo.us[1:])
                oo.smp_u[:, it] = oo.us

                t4 = _tm.time()
                ####  Sample latent state
                #oo.gau_obs = kpOws - BaS - ARo - oous_rs - oo.knownSig
                oo.gau_obs = _N.dot(izd,
                                    kpOws - BaS - ARo - oous_rs - oo.knownSig)
                #oo.copyParams(oo.F0, oo.q2)
                #  (MxM)  (MxN) = (MxN)  (Rv is MxN)
                _N.dot(_N.dot(izd, izd), 1. / oo.ws, out=oo.gau_var)
                #oo.gau_var =1 / oo.ws

                t5 = _tm.time()

                _kfar.armdl_FFBS_1itrMP(oo.gau_obs, oo.gau_var, oo.Fs,
                                        _N.linalg.inv(oo.Fs), oo.q2, oo.Ns,
                                        oo.ks, oo.f_x, oo.f_V, oo.p_x, oo.p_V,
                                        smpx_tmp, K)

                oo.smpx[:, 2:] = smpx_tmp
                oo.smpx[:, 1, 0:ook - 1] = oo.smpx[:, 2, 1:]
                oo.smpx[:, 0, 0:ook - 2] = oo.smpx[:, 2, 2:]

                if oo.doBsmpx and (it % oo.BsmpxSkp == 0):
                    oo.Bsmpx[:, it // oo.BsmpxSkp, 2:] = oo.smpx[:, 2:, 0]

                stds = _N.std(oo.smpx[:, 2:, 0], axis=1)
                oo.mnStds[it] = _N.mean(stds, axis=0)
                if len(hists) == 0:
                    print("!!!!!!  length hists is 0 before ARcfSmpl")
                ###
                #_arcfs.ARcfSmpl(ooN+1, ook, oo.AR2lims, oo.smpx[:, 1:, 0:ook], oo.smpx[:, 0:, 0:ook-1], oo.q2, oo.R, oo.Cs, oo.Cn, alpR, alpC, oo.TR, prior=oo.use_prior, accepts=8, aro=oo.ARord, sig_ph0L=oo.sig_ph0L, sig_ph0H=oo.sig_ph0H)
                _arcfs.ARcfSmpl(ooN + 1,
                                ook,
                                oo.AR2lims,
                                oo.smpx[hists, 1:, 0:ook],
                                oo.smpx[hists, 0:, 0:ook - 1],
                                oo.q2,
                                oo.R,
                                oo.Cs,
                                oo.Cn,
                                alpR,
                                alpC,
                                len(hists),
                                prior=oo.use_prior,
                                accepts=8,
                                aro=oo.ARord,
                                sig_ph0L=oo.sig_ph0L,
                                sig_ph0H=oo.sig_ph0H)
                oo.F_alfa_rep = alpR + alpC  #  new constructed
                prt, rank, f, amp = ampAngRep(oo.F_alfa_rep, f_order=True)
                #ut, wt = FilteredTimeseries(ooN+1, ook, oo.smpx[:, 1:, 0:ook], oo.smpx[:, :, 0:ook-1], oo.q2, oo.R, oo.Cs, oo.Cn, alpR, alpC, oo.TR)
                #ranks[it]    = rank
                oo.allalfas[it] = oo.F_alfa_rep

                for m in range(ooTR):
                    #oo.wts[m, it, :, :]   = wt[m, :, :, 0]
                    #oo.uts[m, it, :, :]   = ut[m, :, :, 0]
                    if not oo.bFixF:
                        oo.amps[it, :] = amp
                        oo.fs[it, :] = f

                oo.F0 = (-1 * _Npp.polyfromroots(oo.F_alfa_rep)[::-1].real)[1:]
                for tr in range(oo.TR):
                    oo.Fs[tr, 0] = oo.F0[:]

                #print "len(lwsts) %(l)d   len(hists) %(h)d" % {"l" : len(lwsts), "h" : len(hists)}
                # sts2chg = hists
                # #sts2chg = lwsts
                # #if (it > oo.startZ) and oo.doS and len(sts2chg) > 0:
                # if oo.doS and len(sts2chg) > 0:
                #     AL = 0.5*_N.sum(oo.smpx[sts2chg, 2:, 0]*oo.smpx[sts2chg, 2:, 0]*oo.ws[sts2chg])
                #     #AL = 0.5*_N.sum(oo.smpx[sts2chg, 2:, 0]*oo.smpx[sts2chg, 2:, 0])
                #     BRL = kpOws[sts2chg] - BaS - oous_rs[sts2chg] - ARo[sts2chg] - oo.knownSig[sts2chg]
                #     BL = 0.5*_N.sum(oo.ws[sts2chg]*BRL*oo.smpx[sts2chg, 2:, 0])
                #     UL = BL / (2*AL)
                #     #sgL= 1/_N.sqrt(2*AL)
                #     sg2= 1./(2*AL)
                #     if it % 50 == 0:
                #         print("u  %(u).3f  %(s).3f" % {"u" : UL, "s" : _N.sqrt(sg2)})

                #     q2_pr = 0.25  # 0.05**2
                #     u_pr  = 1.
                #     #u_pr  = 0
                #     U = (u_pr * sg2 + UL * q2_pr) / (sg2 + q2_pr)
                #     sg= _N.sqrt((sg2*q2_pr) / (sg2 + q2_pr))

                #     #print "U  %(U).4f    UL %(UL).4f s  %(s).3f" % {"U" : U, "s" : sg, "UL" : UL}
                #     if _N.isnan(U):
                #         print("U is nan  UL %.4f" % UL)
                #         print("U is nan  AL %.4f" % AL)
                #         print("U is nan  BL %.4f" % BL)
                #         print("U is nan  BaS ")
                #         print("hists")
                #         print(hists)
                #         print("lwsts")
                #         print(lwsts)

                #     oo.s[1] = U + sg*_N.random.randn()
                #     #oo.s[0] = U + sg*_N.random.randn()

                #     _N.fill_diagonal(sd01[0], oo.s[0])
                #     _N.fill_diagonal(sd01[1], oo.s[1])
                #     #print oo.s[1]
                #     oo.smp_ss[it] = oo.s[1]
                #     #oo.smp_ss[it] = oo.s[0]

                #oo.a2 = oo.a_q2 + 0.5*(ooTR*ooN + 2)  #  N + 1 - 1
                oo.a2 = oo.a_q2 + 0.5 * (len(hists) * ooN + 2)  #  N + 1 - 1
                BB2 = oo.B_q2
                #for m in range(ooTR):
                for m in hists:
                    #   set x00
                    #oo.x00[m]      = oo.smpx[m, 2]*0.1
                    oo.x00[m] = oo.smpx[m, 2] * 0.1

                    #####################    sample q2
                    rsd_stp = oo.smpx[m, 3:, 0] - _N.dot(
                        oo.smpx[m, 2:-1], oo.F0).T
                    BB2 += 0.5 * _N.dot(rsd_stp, rsd_stp.T)
                oo.q2[:] = _ss.invgamma.rvs(oo.a2, scale=BB2)

                oo.smp_q2[:, it] = oo.q2
                t7 = _tm.time()