Exemplo n.º 1
0
def stochasticAssignment(oo, epc, it, Msc, M, K, l0, f, q2, u, Sg, _f_u, _u_u, _f_q2, _u_Sg, Asts, t0, mASr, xASr, rat, econt, gz, qdrMKS, freeClstr, hashthresh, cmp2Existing, nthrds=1):
    #  Msc   Msc signal clusters
    #  M     all clusters, including nz clstr.  M == Msc when not using nzclstr
    #  Gibbs sampling
    #  parameters l0, f, q2
    #  mASr, xASr   just the mark, position of spikes btwn t0 and t1
    #qdrMKS2 = _N.empty(qdrMKS.shape)
    t1 = _tm.time()
    nSpks = len(Asts)
    twpi = 2*_N.pi

    Kp1      = K+1
    #rat      = _N.zeros(M+1)
    pc       = _N.zeros(M)

    ur         = u.reshape((M, 1, K))
    fr         = f.reshape((M, 1))    # centers
    #print q2
    iq2        = 1./q2
    iSg        = _N.linalg.inv(Sg)
    iq2r       = iq2.reshape((M, 1))  
    try:
        ##  warnings because l0 is 0
        isN = _N.where(q2 <= 0)[0]
        if len(isN) > 0:
            q2[isN] = 0.3

        is0 = _N.where(l0 <= 0)[0]
        if len(is0) > 0:
            l0[is0] = 0.001

        pkFR       = _N.log(l0) - 0.5*_N.log(twpi*q2)   #  M
    except RuntimeWarning:
        print "WARNING"
        print l0
        print q2

    mkNrms = _N.log(1/_N.sqrt(twpi*_N.linalg.det(Sg)))
    mkNrms = mkNrms.reshape((M, 1))   #  M x 1

    rnds       = _N.random.rand(nSpks)

    pkFRr      = pkFR.reshape((M, 1))
    dmu        = (mASr - ur)     # mASr 1 x N x K,     ur  is M x 1 x K
    N          = mASr.shape[1]
    #t2 = _tm.time()
    #_N.einsum("mnj,mjk,mnk->mn", dmu, iSg, dmu, out=qdrMKS)
    #t3 = _tm.time()
    _fm.multi_qdrtcs_par_func(dmu, iSg, qdrMKS, M, N, K, nthrds=nthrds)

    #  fr is    M x 1, xASr is 1 x N, iq2r is M x 1
    #qdrSPC     = (fr - xASr)*(fr - xASr)*iq2r  #  M x nSpks   # 0.01s
    qdrSPC     = _N.empty((M, N))
    _hcb.hc_bcast1(fr, xASr, iq2r, qdrSPC, M, N)

    ###  how far is closest cluster to each newly observed mark

    #  mAS = mks[Asts+t0] 
    #  xAS = x[Asts + t0]   #  position @ spikes

    if cmp2Existing:   #  compare only non-hash spikes and non-hash clusters
        # realCl = _N.where(freeClstr == False)[0]
        # print freeClstr.shape
        # print realCl.shape

        abvthrEachCh = mASr[0] > hashthresh    #  should be NxK of
        abvthrAtLeast1Ch = _N.sum(abvthrEachCh, axis=1) > 0   # N x K
        newNonHashSpks   = _N.where(abvthrAtLeast1Ch)[0]

        newNonHashSpksMemClstr = _N.ones(len(newNonHashSpks), dtype=_N.int) * (M-1)   #  initially, assign all of them to noise cluster

        #print "spikes not hash"
         #print abvthrInds
        abvthrEachCh = u[0:Msc] > hashthresh  #  M x K  (M includes noise)
        abvthrAtLeast1Ch = _N.sum(abvthrEachCh, axis=1) > 0
        
        knownNonHclstrs  = _N.where(abvthrAtLeast1Ch & (freeClstr == False) & (q2[0:Msc] < wdSpc))[0]
        

        #print "clusters not hash"

        #  Place prior for freeClstr near new non-hash spikes that are far 
        #  from known clusters that are not hash clusters 


        nNrstMKS_d = _N.sqrt(_N.min(qdrMKS[knownNonHclstrs], axis=0)/K)  #  dim len(sts)
        nNrstSPC_d = _N.sqrt(_N.min(qdrSPC[knownNonHclstrs], axis=0))
        #  for each spike, distance to nearest non-hash cluster
        # print nNrstMKS_d
        # print nNrstSPC_d
        # print "=============="
        s = _N.empty((len(newNonHashSpks), 3))
        #  for each spike, distance to nearest cluster
        s[:, 0] = newNonHashSpks
        s[:, 1] = nNrstMKS_d[newNonHashSpks]
        s[:, 2] = nNrstSPC_d[newNonHashSpks]
        _N.savetxt(resFN("qdrMKSSPC%d" % epc, dir=oo.outdir), s, fmt="%d %.3e %.3e")

        dMK     = nNrstMKS_d[newNonHashSpks]
        dSP     = nNrstSPC_d[newNonHashSpks]

        ###  assignment into 

        farMKinds = _N.where(dMK > 4)[0]    # 
        #  mean of prior for center - mean of farMKinds
        #  cov  of prior for center - how certain am I of mean?  
        farSPinds = _N.where(dSP > 4)[0]  #  4 std. deviations away

        farMKSPinds = _N.union1d(farMKinds, farSPinds)
        print farMKinds
        print newNonHashSpks
        
        ##  points in newNonHashSpks but not in farMKinds
        notFarMKSPinds = _N.setdiff1d(_N.arange(newNonHashSpks.shape[0]), farMKSPinds)

        farMKSP = _N.empty((len(farMKSPinds), K+1))
        farMKSP[:, 0]  = xASr[0, newNonHashSpks[farMKSPinds]]
        farMKSP[:, 1:] = mASr[0, newNonHashSpks[farMKSPinds]]
        notFarMKSP = _N.empty((len(notFarMKSPinds), K+1))
        notFarMKSP[:, 0]  = xASr[0, newNonHashSpks[notFarMKSPinds]]
        notFarMKSP[:, 1:] = mASr[0, newNonHashSpks[notFarMKSPinds]]

        # farSP = _N.empty((len(farSPinds), K+1))
        # farMK = _N.empty((len(farMKinds), K+1))
        # farSP[:, 0]  = xASr[0, farSPinds]
        # farSP[:, 1:] = mASr[0, farSPinds]
        # farMK[:, 0]  = xASr[0, farMKinds]
        # farMK[:, 1:] = mASr[0, farMKinds]

        minK = 1
        maxK = farMKSPinds.shape[0] / K
        maxK = maxK if (maxK < 6) else 6

        freeClstrs = _N.where(freeClstr == True)[0]
        if maxK >= 2:
            print "coming in here"
            #labs, bics, bestLab, nClstrs = _oT.EMBICs(farMKSP, minK=minK, maxK=maxK, TR=1)
            labs, labsH, clstrs = emMKPOS_sep1B(farMKSP, None, TR=1, wfNClstrs=[[1, 4], [1, 4]], spNClstrs=[[1, 4], [1, 3]])
            nClstrs = clstrs[0]
            bestLab    = labs

            cls = clrs.get_colors(nClstrs)

            _U.savetxtWCom(resFN("newSpksMKSP%d" % epc, dir=oo.outdir), farMKSP, fmt="%.3e %.3e %.3e %.3e %.3e", com=("# number of clusters %d" % nClstrs))
            _U.savetxtWCom(resFN("newSpksMKSP_nf%d" % epc, dir=oo.outdir), notFarMKSP, fmt="%.3e %.3e %.3e %.3e %.3e", com=("# number of clusters %d" % nClstrs))

            L = len(freeClstrs)
            
            unqLabs = _N.unique(bestLab)

            upto    = nClstrs if nClstrs < L else L  #  this should just count large clusters
            ii  = -1
            fig = _plt.figure()
            
            for fid in unqLabs[0:upto]:
                iths = farMKSPinds[_N.where(bestLab == fid)[0]]
                ths = newNonHashSpks[iths]

                for w in xrange(K):
                    fig.add_subplot(2, 2, w+1)
                    _plt.scatter(xASr[0, ths], mASr[0, ths, w], color=cls[ii])

                if len(ths) > K:
                    ii += 1
                    im = freeClstrs[ii]   # Asts + t0 gives absolute time
                    newNonHashSpksMemClstr[iths] = im

                    _u_u[im]  = _N.mean(mASr[0, ths], axis=0)
                    u[im]     = _u_u[im]
                    _f_u[im]  = _N.mean(xASr[0, ths], axis=0)
                    f[im]     = _f_u[im]
                    q2[im]    = _N.std(xASr[0, ths], axis=0)**2 * 9
                    #  l0 = Hz * sqrt(2*_N.pi*q2)
                    l0[im]    =   10*_N.sqrt(q2[im])
                    _f_q2[im] = 1
                    _u_Sg[im] = _N.cov(mASr[0, ths], rowvar=0)*25
                    print "ep %(ep)d  new   cluster #  %(m)d" % {"ep" : epc, "m" : im}
                    print _u_u[im]
                    print _f_u[im]
                    print _f_q2[im]
                else:
                    print "too small    this prob. doesn't represent a cluster"

            _plt.savefig("newspks%d" % epc)


            # #######  known clusters
            # for fid in unqLabs[0:upto]:
            #     iths = farMKSPinds[_N.where(bestLab == fid)[0]]
            #     ths = newNonHashSpks[iths]

            #     for w in xrange(K):
            #         fig.add_subplot(2, 2, w+1)
            #         _plt.scatter(xASr[0, ths], mASr[0, ths, w], color=cls[ii])

            #     if len(ths) > K:
            #         ii += 1
            #         im = freeClstrs[ii]   # Asts + t0 gives absolute time
            #         newNonHashSpksMemClstr[iths] = im

            #         _u_u[im]  = _N.mean(mASr[0, ths], axis=0)
            #         u[im]     = _u_u[im]
            #         _f_u[im]  = _N.mean(xASr[0, ths], axis=0)
            #         f[im]     = _f_u[im]
            #         q2[im]    = _N.std(xASr[0, ths], axis=0)**2 * 9
            #         #  l0 = Hz * sqrt(2*_N.pi*q2)
            #         l0[im]    =   10*_N.sqrt(q2[im])
            #         _f_q2[im] = 1
            #         _u_Sg[im] = _N.cov(mASr[0, ths], rowvar=0)*25
            #         print "ep %(ep)d  new   cluster #  %(m)d" % {"ep" : epc, "m" : im}
            #         print _u_u[im]
            #         print _f_u[im]
            #         print _f_q2[im]
            #     else:
            #         print "too small    this prob. doesn't represent a cluster"

            # _plt.savefig("newspks%d" % epc)


        else:  #  just one cluster
            im = freeClstrs[0]   # Asts + t0 gives absolute time

            _u_u[im]  = _N.mean(mASr[0, newNonHashSpks[farMKSPinds]], axis=0)
            _f_u[im]  = _N.mean(xASr[0, newNonHashSpks[farMKSPinds]], axis=0)
            _u_Sg[im] = _N.cov(mASr[0, newNonHashSpks[farMKSPinds]], rowvar=0)*16
            _f_q2[im] = _N.std(xASr[0, newNonHashSpks[farMKSPinds]], axis=0)**2 * 16

        # ##  kernel density estimate
        # xs  = _N.linspace(-6, 6, 101)
        # xsr = xs.reshape(101, 1)
        # isg2= 1/(0.1**2)   #  spatial kernel bandwidth

        # # fig = _plt.figure(figsize=(6, 9))
        # # fig.add_subplot(1, 2, 1)
        # # _plt.scatter(xASr[0, newNonHashSpks[farMKinds]], mASr[0, newNonHashSpks[farMKinds], 0])
        # # fig.add_subplot(1, 2, 2)
        # # _plt.scatter(xASr[0, newNonHashSpks[farSPinds]], mASr[0, newNonHashSpks[farSPinds], 0])

        # freeClstrs = _N.where(freeClstr == True)[0]
        # L = len(freeClstrs)

        # jjj = 0
        # if (len(farSPinds) >= Kp1) and (len(farMKinds) >= Kp1):
        #     jjj = 1
        #     l1 = L/2

        #     for l in xrange(l1):  # mASr  is 1 x N x K
        #         im = freeClstrs[l]   # Asts + t0 gives absolute time
        #         _u_u[im]  = _N.mean(mASr[0, newNonHashSpks[farMKinds]], axis=0)
        #         y   = _N.exp(-0.5*(xsr - xASr[0, newNonHashSpks[farMKinds]])**2 * isg2)
        #         yc  = _N.sum(y, axis=1)
        #         ix  = _N.where(yc == _N.max(yc))[0][0]
        #         _f_u[im]  = xs[ix]
        #         _u_Sg[im] = _N.cov(mASr[0, newNonHashSpks[farMKinds]], rowvar=0)*30
        #         _f_q2[im] = _N.std(xASr[0, newNonHashSpks[farMKinds]], axis=0)**2 * 30
        #     # _plt.figure()
        #     # _plt.plot(xs, yc)

        #     for l in xrange(l1, L):
        #         im = freeClstrs[l]   # Asts + t0 gives absolute time
        #         _u_u[im]  = _N.mean(mASr[0, newNonHashSpks[farSPinds]], axis=0)
        #         y   = _N.exp(-0.5*(xsr - xASr[0, newNonHashSpks[farSPinds]])**2 * isg2)
        #         yc  = _N.sum(y, axis=1)
        #         ix  = _N.where(yc == _N.max(yc))[0][0]
        #         _f_u[im]  = xs[ix]
        #         _u_Sg[im] = _N.cov(mASr[0, newNonHashSpks[farSPinds]], rowvar=0)*30
        #         _f_q2[im] = _N.std(xASr[0, newNonHashSpks[farSPinds]], axis=0)**2 * 30
        #     # _plt.figure()
        #     # _plt.plot(xs, yc)

        # elif (len(farSPinds) >= Kp1) and (len(farMKinds) < Kp1):
        #     jjj = 2
        #     for l in xrange(L):
        #         im = freeClstrs[l]   # Asts + t0 gives absolute time
        #         _u_u[im]  = _N.mean(mASr[0, newNonHashSpks[farSPinds]], axis=0)
        #         y   = _N.exp(-0.5*(xsr - xASr[0, newNonHashSpks[farSPinds]])**2 * isg2)
        #         yc  = _N.sum(y, axis=1)
        #         ix  = _N.where(yc == _N.max(yc))[0][0]
        #         _f_u[im]  = xs[ix]
        #         _u_Sg[im] = _N.cov(mASr[0, newNonHashSpks[farSPinds]], rowvar=0)*30
        #         _f_q2[im] = _N.std(xASr[0, newNonHashSpks[farSPinds]], axis=0)**2 * 30
        #     # _plt.figure()
        #     # _plt.plot(xs, yc)

        # elif (len(farSPinds) < Kp1) and (len(farMKinds) >= Kp1):
        #     jjj = 3
        #     for l in xrange(L):
        #         im = freeClstrs[l]   # Asts + t0 gives absolute time
        #         _u_u[im]  = _N.mean(mASr[0, newNonHashSpks[farMKinds]], axis=0)
        #         y   = _N.exp(-0.5*(xsr - xASr[0, newNonHashSpks[farMKinds]])**2 * isg2)
        #         yc  = _N.sum(y, axis=1)
        #         ix  = _N.where(yc == _N.max(yc))[0][0]
        #         _f_u[im]  = xs[ix]
        #         _u_Sg[im] = _N.cov(mASr[0, newNonHashSpks[farMKinds]], rowvar=0)*30
        #         _f_q2[im] = _N.std(xASr[0, newNonHashSpks[farMKinds]], axis=0)**2 * 30
        #     # _plt.figure()
        #     # _plt.plot(xs, yc)

        """
        print "^^^^^^^^"
        print freeClstrs
        print "set priors for freeClstrs   %d" % jjj
        #print _u_u[freeClstrs]
        #print _u_Sg[freeClstrs]
        print _f_u[freeClstrs]
        print _f_q2[freeClstrs]
        """

        #if len(farSPinds) > 10:


        #  set the priors of the freeClusters to be near the far spikes


    ####  outside cmp2Existing here
    #   (Mx1) + (Mx1) - (MxN + MxN)
    #cont       = pkFRr + mkNrms - 0.5*(qdrSPC + qdrMKS)
    cont = _N.empty((M, N))
    _hcb.hc_qdr_sum(pkFRr, mkNrms, qdrSPC, qdrMKS, cont, M, N)

    mcontr     = _N.max(cont, axis=0).reshape((1, nSpks))  
    cont       -= mcontr
    _N.exp(cont, out=econt)

    for m in xrange(M):
        rat[m+1] = rat[m] + econt[m]

    rat /= rat[M]
    """
    # print f
    # print u
    # print q2
    # print Sg
    # print l0
    """

    # print rat

    M1 = rat[1:] >= rnds
    M2 = rat[0:-1] <= rnds

    gz[it] = (M1&M2).T

    if cmp2Existing:
        #  gz   is ITERS x N x Mwowonz   (N # of spikes in epoch)
        gz[it, newNonHashSpks] = False   #  not a member of any of them
        gz[it, newNonHashSpks, newNonHashSpksMemClstr] = True
Exemplo n.º 2
0
def initClusters(oo, K, x, mks, t0, t1, Asts, doSepHash=True, xLo=0, xHi=3, oneCluster=False, nzclstr=False):
    n0 = 0
    n1 = len(Asts)

    _x   = _N.empty((n1-n0, K+1))
    _x[:, 0]    = x[Asts+t0]
    _x[:, 1:]   = mks[Asts+t0]

    if oneCluster:
        unonhash = _N.arange(len(Asts))
        hashsp   = _N.array([])
        hashthresh = _N.min(_x[:, 1:], axis=0)   #  no hash spikes

        labS     = _N.zeros(len(Asts), dtype=_N.int)
        labH     = _N.array([], dtype=_N.int)
        clstrs   = _N.array([0, 1])
        lab      = _N.array(labS.tolist() + (labH + clstrs[0]).tolist())
        M        = 1
        MF       = 1
        flatlabels = _N.zeros(len(Asts), dtype=_N.int)
    else:
        if not doSepHash:
            unonhash = _N.arange(len(Asts))
            hashsp   = _N.array([])
            hashthresh = _N.min(_x[:, 1:], axis=0)   #  no hash spikes

            ###   1 cluster
            # labS = _N.zeros(len(Asts), dtype=_N.int)
            # labH = _N.array([], dtype=_N.int)
            # clstrs = _N.array([0, 1])
        else:
            unonhash, hashsp, hashthresh = sepHash(_x, BINS=20, blksz=5, xlo=oo.xLo, xhi=oo.xHi)
            #  hashthresh is dim 2

            # print len(unonhash)
            # print "--------"
            # print len(hashsp)
            # fig = _plt.figure(figsize=(5, 10))
            # fig.add_subplot(3, 1, 1)
            # _plt.scatter(_x[hashsp, 1], _x[hashsp, 2], color="red")
            # _plt.scatter(_x[unonhash, 1], _x[unonhash, 2], color="black")
            # fig.add_subplot(3, 1, 2)
            # _plt.scatter(_x[hashsp, 0], _x[hashsp, 1], color="red")
            # _plt.scatter(_x[unonhash, 0], _x[unonhash, 1], color="black")
            # fig.add_subplot(3, 1, 3)
            # _plt.scatter(_x[hashsp, 0], _x[hashsp, 2], color="red")
            # _plt.scatter(_x[unonhash, 0], _x[unonhash, 2], color="black")


        # len(hashsp)==len(labH)
        # len(unonhash)==len(labS)
        if (len(unonhash) > 0) and (len(hashsp) > 0): 
            labS, labH, clstrs = emMKPOS_sep1B(_x[unonhash], _x[hashsp])
        elif len(unonhash) == 0:
            labS, labH, clstrs = emMKPOS_sep1B(None, _x[hashsp], TR=5)
        else:
            labS, labH, clstrs = emMKPOS_sep1B(_x[unonhash], None, TR=5)
        if doSepHash:
            splitclstrs(_x[unonhash], labS)
            posMkCov0(_x[unonhash], labS)

            #mergesmallclusters(_x[unonhash], _x[hashsp], labS, labH, K+1, clstrs)
            smallClstrID, spksInSmallClstrs = findsmallclusters(_x[unonhash], labS, K+1)

            print smallClstrID
            _N.savetxt("labSb4", labS, fmt="%d")
            for nid in smallClstrID:
                ths = _N.where(labS == nid)[0]
                labS[ths] = -1#clstrs[0]+clstrs[1]-1  # -1 first for easy cpack2
            _N.savetxt("labS", labS, fmt="%d")

            # 0...clstrs[0]-1     clstrs[0]...clstrs[0]+clstrs[1]-1  (no nz)
            # 0...clstrs[0]-2     clstrs[0]-1...clstrs[0]+clstrs[1]-2  (no nz)
            contiguous_pack2(labS, startAt=-1)

            clstrs[0] = len(_N.unique(labS)) 
            clstrs[1] = len(_N.unique(labH))

            print "----------"
            print clstrs
            print "----------"
            # labS [0...#S]   labH [#S...#S+#H]
            
            nzspks = _N.where(labS == -1)[0]
            labS[nzspks] = clstrs[0]+clstrs[1]-1   #  highest ID

            contiguous_pack2(labH, startAt=(clstrs[0]-1))
            _N.savetxt("labH", labH, fmt="%d")
            _N.savetxt("labS", labS, fmt="%d")

            #contiguous_pack2(labH, startAt=(_N.max(labS)+1))

            nonnz = _N.where(labS < clstrs[0]-1)[0]
            nz    = _N.where(labS == clstrs[0]+clstrs[1]-1)[0]
            _plt.scatter(_x[hashsp, 0], _x[hashsp, 1], color="black")
            _plt.scatter(_x[unonhash[nonnz], 0], _x[unonhash[nonnz], 1], color="blue")
            _plt.scatter(_x[unonhash[nz], 0], _x[unonhash[nz], 1], color="red")

            #colorclusters(_x[hashsp], labH, clstrs[1], name="hash", xLo=xLo, xHi=xHi)
            #colorclusters(_x[unonhash], labS, clstrs[0], name="nhash", xLo=xLo, xHi=xHi)


    #     #fig = _plt.figure(figsize=(7, 10))
    #     #fig.add_subplot(2, 1, 1)

        flatlabels = _N.ones(n1-n0, dtype=_N.int)*-1   # 
        #cls = clrs.get_colors(clstrs[0] + clstrs[1])

        for i in labS:
            these = _N.where(labS == i)[0]

            if len(these) > 0:
                flatlabels[unonhash[these]] = i
            #_plt.scatter(_x[unonhash[these], 0], _x[unonhash[these], 1], color=cls[i])
        #for i in xrange(clstrs[1]):
        for i in labH:
            these = _N.where(labH == i)[0]

            if len(these) > 0:
                flatlabels[hashsp[these]] = i 
            #_plt.scatter(_x[hashsp[these], 0], _x[hashsp[these], 1], color=cls[i+clstrs[0]])

        MF     = clstrs[0] + clstrs[1]   #  includes noise
        if nzclstr:
            ths = _N.where(flatlabels == -1)[0]
            flatlabels[ths] = MF - 1
            M = int((clstrs[0]-1) * 1.3 + clstrs[1]) + 2   #  20% more clusters
        else:
            M = int(clstrs[0] * 1.3 + clstrs[1]) + 2   #  20% more clusters
        print "cluters:  %d" % M

    Mwonz     = M if (nzclstr is False) else M-1
    #####  MODES  - find from the sampling
    oo.sp_prmPstMd = _N.zeros((oo.epochs, 3*Mwonz))   # mode of params
    oo.sp_hypPstMd  = _N.zeros((oo.epochs, (2+2+2)*Mwonz)) # hyperparam
    oo.mk_prmPstMd = [_N.zeros((oo.epochs, Mwonz, K)),
                      _N.zeros((oo.epochs, Mwonz, K, K))]
                      # mode of params
    oo.mk_hypPstMd  = [_N.zeros((oo.epochs, Mwonz, K)),
                       _N.zeros((oo.epochs, Mwonz, K, K)), # hyperparam
                       _N.zeros((oo.epochs, Mwonz, 1)), # hyperparam
                       _N.zeros((oo.epochs, Mwonz, K, K))]

    print labS
    print labH
    _N.savetxt("flatlabels", flatlabels, fmt="%d")
    ##################

    # flatlabels + lab = same content, but flatlabels are temporally correct
    return labS, labH, flatlabels, Mwonz, MF, hashthresh, clstrs