示例#1
0
def colorclusters(smkpos, labs, MS, name="", xLo=0, xHi=3):
    """Scatter-plot every cluster's marks against position and save figures.

    Saves one combined figure "cc<name>-all" containing all MS clusters, then
    one figure "cc<name>-<m>" per cluster m, and finally prints the total
    number of labelled points that were plotted.

    smkpos   : array indexed as smkpos[i, 0] (x axis, position) and
               smkpos[i, 1..4] (the four panels of the 2 x 2 grid).
    labs     : cluster label of each row of smkpos; clusters 0 .. MS-1 are drawn.
    MS       : number of clusters; colors come from _clrs.get_colors(MS).
    name     : tag inserted into the output file names.
    xLo, xHi : x range; every panel's x-limits are padded by 10% on each side.
    """
    pad    = (xHi - xLo)*0.1
    myclrs = _clrs.get_colors(MS)
    #  row indices belonging to each cluster (computed once, reused below)
    memb   = [_N.where(labs == m)[0] for m in range(MS)]

    def _plot(fname, clusters):
        #  Draw the given clusters into one 2 x 2 figure and save it as fname.
        #  The four axes are created ONCE and reused: calling add_subplot
        #  again with the same position creates new, overlapping axes in
        #  newer matplotlib versions, which would hide earlier clusters.
        fig = _plt.figure(figsize=(12, 8))
        axs = [fig.add_subplot(2, 2, k+1) for k in range(4)]
        for m in clusters:
            inds = memb[m]
            for k, ax in enumerate(axs):
                ax.scatter(smkpos[inds, 0], smkpos[inds, k+1], color=myclrs[m], s=9)
        for ax in axs:
            ax.set_xlim(xLo - pad, xHi + pad)
        fig.savefig(fname)
        _plt.close(fig)

    _plot("cc%s-all" % name, range(MS))
    for m in range(MS):
        _plot("cc%(n)s-%(m)d" % {"n" : name, "m" : m}, [m])

    #  total number of points that belong to one of the MS clusters
    print(sum(len(inds) for inds in memb))
示例#2
0
def stochasticAssignment(oo, epc, it, Msc, M, K, l0, f, q2, u, Sg, _f_u, _u_u, _f_q2, _u_Sg, Asts, t0, mASr, xASr, rat, econt, gz, qdrMKS, freeClstr, hashthresh, cmp2Existing, nthrds=1):
    """Gibbs step: sample a cluster membership for every spike of an epoch.

    For each spike n and cluster m a log-weight cont[m, n] is built by
    _hcb.hc_qdr_sum (per the original formula comment:
    pkFRr + mkNrms - 0.5*(qdrSPC + qdrMKS)), exponentiated into econt,
    accumulated into rat and normalized.  Each spike is then assigned to the
    cluster whose [rat[m], rat[m+1]] interval contains a uniform draw; the
    boolean result is written to gz[it].  (Assumes rat[0] is 0 on entry so
    the draw is proportional to econt -- TODO confirm with caller.)

    When cmp2Existing is True, non-hash spikes lying far (> 4) from every
    known non-hash cluster in mark or position space are used to seed the
    parameters/priors of free clusters, and the assignments of all non-hash
    spikes in gz[it] are overridden (see the end of the function).

    Parameters (as used here):
      oo          : object whose .outdir is used for output file names.
      epc, it     : epoch index (file names, messages); Gibbs iteration
                    index (row of gz written).
      Msc, M      : number of signal clusters / of all clusters
                    (M == Msc when no noise cluster is used).
      K           : mark dimension.
      l0, f, q2   : per-cluster rate, spatial center, spatial variance.
                    Non-positive q2 / l0 entries are replaced IN PLACE by
                    0.3 / 0.001; seeded free clusters are overwritten.
      u, Sg       : per-cluster mark centers (reshaped to M x 1 x K) and
                    mark covariances (one K x K matrix per cluster).
      _f_u, _u_u, _f_q2, _u_Sg : prior hyperparameters, modified in place
                    for seeded free clusters.
      Asts        : spike indices of the epoch (only its length is used).
      t0          : unused.
      mASr, xASr  : marks (1 x N x K) and positions (1 x N) of the spikes.
      rat, econt  : work buffers; rat needs M + 1 rows.
      gz          : output; gz[it] receives an N x M boolean assignment.
      qdrMKS      : M x N buffer filled by _fm.multi_qdrtcs_par_func.
      freeClstr   : boolean per cluster, True for currently unused clusters.
      hashthresh  : mark threshold separating hash from non-hash.
      nthrds      : thread count passed to _fm.multi_qdrtcs_par_func.
    """
    nSpks = len(Asts)
    twpi  = 2*_N.pi

    #  Sanitize BEFORE deriving anything from q2 / l0, so the precision iq2r
    #  and the normalizer pkFRr are computed from the same valid values
    #  (otherwise q2 <= 0 gives an inf/negative precision).  In place, so the
    #  caller's arrays keep the corrected values, as before.
    q2[_N.where(q2 <= 0)[0]] = 0.3
    l0[_N.where(l0 <= 0)[0]] = 0.001

    ur     = u.reshape((M, 1, K))
    fr     = f.reshape((M, 1))              # spatial centers
    iq2r   = (1./q2).reshape((M, 1))        # spatial precisions
    iSg    = _N.linalg.inv(Sg)
    pkFRr  = (_N.log(l0) - 0.5*_N.log(twpi*q2)).reshape((M, 1))
    mkNrms = _N.log(1/_N.sqrt(twpi*_N.linalg.det(Sg))).reshape((M, 1))

    #  drawn here (before any clustering below) to keep the random stream order
    rnds   = _N.random.rand(nSpks)

    dmu    = (mASr - ur)     # mASr 1 x N x K,  ur M x 1 x K
    N      = mASr.shape[1]
    _fm.multi_qdrtcs_par_func(dmu, iSg, qdrMKS, M, N, K, nthrds=nthrds)

    #  spatial quadratic term; per the original comment equivalent to
    #  (fr - xASr)*(fr - xASr)*iq2r   (M x N)
    qdrSPC = _N.empty((M, N))
    _hcb.hc_bcast1(fr, xASr, iq2r, qdrSPC, M, N)

    if cmp2Existing:   #  compare only non-hash spikes and non-hash clusters
        #  non-hash spikes: above hashthresh on at least one channel
        newNonHashSpks = _N.where(_N.sum(mASr[0] > hashthresh, axis=1) > 0)[0]
        #  initially, assign all of them to cluster M-1 (the noise cluster)
        newNonHashSpksMemClstr = _N.ones(len(newNonHashSpks), dtype=int) * (M-1)

        #  known non-hash clusters: in use, above threshold on >= 1 channel,
        #  spatial variance below wdSpc.
        #  NOTE(review): wdSpc is not defined in this function -- presumably a
        #  module-level constant; confirm.
        knownNonHclstrs = _N.where((_N.sum(u[0:Msc] > hashthresh, axis=1) > 0) &
                                   (freeClstr == False) &
                                   (q2[0:Msc] < wdSpc))[0]

        #  per spike: distance to the nearest known non-hash cluster
        if len(knownNonHclstrs) > 0:
            nNrstMKS_d = _N.sqrt(_N.min(qdrMKS[knownNonHclstrs], axis=0)/K)
            nNrstSPC_d = _N.sqrt(_N.min(qdrSPC[knownNonHclstrs], axis=0))
        else:   #  no known clusters: every spike is far from all of them
            nNrstMKS_d = _N.full(N, _N.inf)
            nNrstSPC_d = _N.full(N, _N.inf)

        dMK = nNrstMKS_d[newNonHashSpks]
        dSP = nNrstSPC_d[newNonHashSpks]

        s = _N.empty((len(newNonHashSpks), 3))
        s[:, 0] = newNonHashSpks
        s[:, 1] = dMK
        s[:, 2] = dSP
        _N.savetxt(resFN("qdrMKSSPC%d" % epc, dir=oo.outdir), s, fmt="%d %.3e %.3e")

        #  far = > 4 from every known cluster in mark OR position space.
        #  These are indices INTO newNonHashSpks, not spike indices.
        farMKinds   = _N.where(dMK > 4)[0]
        farSPinds   = _N.where(dSP > 4)[0]
        farMKSPinds = _N.union1d(farMKinds, farSPinds)
        print(farMKinds)
        print(newNonHashSpks)

        notFarMKSPinds = _N.setdiff1d(_N.arange(newNonHashSpks.shape[0]), farMKSPinds)

        farSpks    = newNonHashSpks[farMKSPinds]
        notFarSpks = newNonHashSpks[notFarMKSPinds]
        #  rows are [position, mark_1 .. mark_K]
        farMKSP = _N.empty((len(farMKSPinds), K+1))
        farMKSP[:, 0]  = xASr[0, farSpks]
        farMKSP[:, 1:] = mASr[0, farSpks]
        notFarMKSP = _N.empty((len(notFarMKSPinds), K+1))
        notFarMKSP[:, 0]  = xASr[0, notFarSpks]
        notFarMKSP[:, 1:] = mASr[0, notFarSpks]

        nFar       = len(farMKSPinds)
        freeClstrs = _N.where(freeClstr == True)[0]
        if nFar >= 2*K:
            print("coming in here")
            labs, labsH, clstrs = emMKPOS_sep1B(farMKSP, None, TR=1, wfNClstrs=[[1, 4], [1, 4]], spNClstrs=[[1, 4], [1, 3]])
            nClstrs = clstrs[0]
            bestLab = labs

            #  NOTE(review): clrs (not _clrs) -- confirm it is imported.
            cls = clrs.get_colors(nClstrs)

            fmtMKSP = " ".join(["%.3e"] * (K+1))   # one column per position/mark
            com     = "# number of clusters %d" % nClstrs
            _U.savetxtWCom(resFN("newSpksMKSP%d" % epc, dir=oo.outdir), farMKSP, fmt=fmtMKSP, com=com)
            _U.savetxtWCom(resFN("newSpksMKSP_nf%d" % epc, dir=oo.outdir), notFarMKSP, fmt=fmtMKSP, com=com)

            unqLabs = _N.unique(bestLab)
            upto    = min(nClstrs, len(freeClstrs))  #  this should just count large clusters

            fig = _plt.figure()
            #  axes created once; repeated add_subplot calls would stack new axes
            axs = [fig.add_subplot(2, 2, w+1) for w in range(K)]
            ii  = -1   #  index of the last free cluster that was seeded
            for ic, fid in enumerate(unqLabs[0:upto]):
                iths = farMKSPinds[_N.where(bestLab == fid)[0]]   # into newNonHashSpks
                ths  = newNonHashSpks[iths]                         # spike indices

                for w in range(K):
                    axs[w].scatter(xASr[0, ths], mASr[0, ths, w], color=cls[ic])

                if len(ths) > K:
                    ii += 1
                    im = freeClstrs[ii]
                    newNonHashSpksMemClstr[iths] = im

                    _u_u[im]  = _N.mean(mASr[0, ths], axis=0)
                    u[im]     = _u_u[im]
                    _f_u[im]  = _N.mean(xASr[0, ths], axis=0)
                    f[im]     = _f_u[im]
                    q2[im]    = _N.std(xASr[0, ths], axis=0)**2 * 9
                    #  l0 = Hz * sqrt(2*_N.pi*q2)
                    l0[im]    = 10*_N.sqrt(q2[im])
                    _f_q2[im] = 1
                    _u_Sg[im] = _N.cov(mASr[0, ths], rowvar=0)*25
                    print("ep %(ep)d  new   cluster #  %(m)d" % {"ep" : epc, "m" : im})
                    print(_u_u[im])
                    print(_f_u[im])
                    print(_f_q2[im])
                else:
                    print("too small    this prob. doesn't represent a cluster")

            fig.savefig("newspks%d" % epc)
            _plt.close(fig)   #  otherwise one figure leaks per call
        elif (nFar > K) and (len(freeClstrs) > 0):
            #  Too few far spikes to cluster: seed the priors of one free
            #  cluster from all of them.  Require more than K spikes (the
            #  criterion used above); with 0 or 1 spikes the mean/covariance
            #  are NaN and would poison the priors.
            im = freeClstrs[0]
            _u_u[im]  = _N.mean(mASr[0, farSpks], axis=0)
            _f_u[im]  = _N.mean(xASr[0, farSpks], axis=0)
            _u_Sg[im] = _N.cov(mASr[0, farSpks], rowvar=0)*16
            _f_q2[im] = _N.std(xASr[0, farSpks], axis=0)**2 * 16

    #   (Mx1) + (Mx1) - (MxN + MxN)
    cont = _N.empty((M, N))
    _hcb.hc_qdr_sum(pkFRr, mkNrms, qdrSPC, qdrMKS, cont, M, N)

    #  subtract each spike's max log-weight before exp (avoids overflow)
    cont -= _N.max(cont, axis=0).reshape((1, nSpks))
    _N.exp(cont, out=econt)

    #  cumulative weights over clusters, normalized by the total in rat[M]
    for m in range(M):
        rat[m+1] = rat[m] + econt[m]
    rat /= rat[M]

    #  spike n -> cluster m when rat[m] <= rnds[n] <= rat[m+1]
    M1 = rat[1:] >= rnds
    M2 = rat[0:-1] <= rnds
    gz[it] = (M1&M2).T

    if cmp2Existing:
        #  gz is ITERS x N x Mwowonz (N # of spikes in epoch).
        #  NOTE(review): this overrides the sampled assignment of EVERY
        #  non-hash spike: those used to seed a free cluster go to it, all
        #  others go to cluster M-1 -- confirm this is intended.
        gz[it, newNonHashSpks] = False
        gz[it, newNonHashSpks, newNonHashSpksMemClstr] = True