Example 1
    def _prepare_object(self, Yr, T, expected_comps, new_dims=None, idx_components=None,
                        g=None, lam=None, s_min=None, bl=None, use_dense=True):

        self.expected_comps = expected_comps

        if idx_components is None:
            idx_components = range(self.A.shape[-1])

        self.A2 = self.A.tocsc()[:, idx_components]
        self.C2 = self.C[idx_components]
        self.b2 = self.b
        self.f2 = self.f
        self.S2 = self.S[idx_components]
        self.YrA2 = self.YrA[idx_components]
        self.g2 = self.g[idx_components]
        self.bl2 = self.bl[idx_components]
        self.c12 = self.c1[idx_components]
        self.neurons_sn2 = self.neurons_sn[idx_components]
        self.lam2 = self.lam[idx_components]
        self.dims2 = self.dims

        self.N = self.A2.shape[-1]
        self.M = self.gnb + self.N

        if Yr.shape[-1] != self.initbatch:
            raise Exception(
                'The movie size used for initialization does not match the minibatch size')

        if new_dims is not None:

            new_Yr = np.zeros([np.prod(new_dims), T])
            for ffrr in range(T):
                tmp = cv2.resize(Yr[:, ffrr].reshape(self.dims2, order='F'), new_dims[::-1])
                print(tmp.shape)
                new_Yr[:, ffrr] = tmp.reshape([np.prod(new_dims)], order='F')
            Yr = new_Yr
            A_new = scipy.sparse.csc_matrix(
                (np.prod(new_dims), self.A2.shape[-1]), dtype=np.float32)
            for neur in range(self.N):
                a = self.A2.tocsc()[:, neur].toarray()
                a = a.reshape(self.dims2, order='F')
                a = cv2.resize(a, new_dims[::-1]).reshape([-1, 1], order='F')

                A_new[:, neur] = scipy.sparse.csc_matrix(a)

            self.A2 = A_new
            self.b2 = self.b2.reshape(self.dims2, order='F')
            self.b2 = cv2.resize(self.b2, new_dims[::-1]).reshape([-1, 1], order='F')

            self.dims2 = new_dims

        nA = np.ravel(np.sqrt(self.A2.power(2).sum(0)))
        self.A2 /= nA
        self.C2 *= nA[:, None]
        self.YrA2 *= nA[:, None]
#        self.S2 *= nA[:, None]
        self.neurons_sn2 *= nA
        self.lam2 *= nA
        z = np.sqrt([b.T.dot(b) for b in self.b2.T])
        self.f2 *= z[:, None]
        self.b2 /= z

        self.noisyC = np.zeros((self.gnb + expected_comps, T), dtype=np.float32)
        self.C_on = np.zeros((expected_comps, T), dtype=np.float32)

#        self.noisyC[:, :self.initbatch] = np.vstack(
#            [self.C[:, :self.initbatch] + self.YrA, self.f])
#        Ab_ = scipy.sparse.csc_matrix(np.c_[self.A2, self.b2])
#        AtA_ = Ab_.T.dot(Ab_)
#        self.noisyC[:,:self.initbatch] = hals_full(Yr[:, :self.initbatch], Ab_, np.r_[self.C,self.f], iters=3)
#
#        for t in xrange(self.initbatch):
#            if t % 100 == 0:
#                print(t)
#            self.noisyC[:, t] = HALS4activity(Yr[:, t], Ab_, np.ones(N + 1) if t == 0 else
# self.noisyC[:, t - 1].copy(), AtA_, iters=30 if t == 0 else 5)
        self.noisyC[self.gnb:self.M, :self.initbatch] = self.C2 + self.YrA2
        self.noisyC[:self.gnb, :self.initbatch] = self.f2

        # if no parameter for calculating the spike size threshold is given, then use L1 penalty
        if s_min is None and self.s_min is None and self.thresh_s_min is None:
            use_L1 = True
        else:
            use_L1 = False
            
        self.OASISinstances = [oasis.OASIS(
            g=np.ravel(0.01) if self.p == 0 else (np.ravel(g)[0] if g is not None else gam[0]),
            lam=0 if not use_L1 else (l if lam is None else lam),
            # if no explicit value for s_min, use thresh_s_min * noise estimate * sqrt(1-sum(gamma))
            s_min=0 if use_L1 else (s_min if s_min is not None else
                                    (self.s_min if self.s_min is not None else
                                     (self.thresh_s_min * sn * np.sqrt(1 - np.sum(gam))))),
            b=b if bl is None else bl,
            g2=0 if self.p < 2 else (np.ravel(g)[1] if g is not None else gam[1]))
            for gam, l, b, sn in zip(self.g2, self.lam2, self.bl2, self.neurons_sn2)]

        for i, o in enumerate(self.OASISinstances):
            o.fit(self.noisyC[i + self.gnb, :self.initbatch])
            self.C_on[i, :self.initbatch] = o.c

        self.Ab, self.ind_A, self.CY, self.CC = init_shapes_and_sufficient_stats(
            Yr[:, :self.initbatch].reshape(self.dims2 + (-1,), order='F'), self.A2,
            self.C_on[:self.N, :self.initbatch], self.b2, self.noisyC[:self.gnb, :self.initbatch])

        self.CY = self.CY / self.initbatch
        self.CC = self.CC / self.initbatch

        self.A2 = scipy.sparse.csc_matrix(self.A2.astype(np.float32), dtype=np.float32)
        self.C2 = self.C2.astype(np.float32)
        self.f2 = self.f2.astype(np.float32)
        self.b2 = self.b2.astype(np.float32)
        self.Ab = scipy.sparse.csc_matrix(self.Ab.astype(np.float32), dtype=np.float32)
        self.noisyC = self.noisyC.astype(np.float32)
        self.CY = self.CY.astype(np.float32)
        self.CC = self.CC.astype(np.float32)
        print('Expecting ' + str(self.expected_comps) + ' components')
        self.CY.resize([self.expected_comps + 1, self.CY.shape[-1]])
        if use_dense:
            self.Ab_dense = np.zeros((self.CY.shape[-1], self.expected_comps + 1),
                                     dtype=np.float32)
            self.Ab_dense[:, :self.Ab.shape[1]] = self.Ab.toarray()
        self.C_on = np.vstack([self.noisyC[:self.gnb, :], self.C_on.astype(np.float32)])

        self.gSiz = np.add(np.multiply(self.gSig, 2), 1)

        self.Yr_buf = RingBuffer(Yr[:, self.initbatch - self.minibatch_shape:
                                    self.initbatch].T.copy(), self.minibatch_shape)
        self.Yres_buf = RingBuffer(self.Yr_buf - self.Ab.dot(
            self.C_on[:self.M, self.initbatch - self.minibatch_shape:self.initbatch]).T, self.minibatch_shape)
        self.rho_buf = imblur(self.Yres_buf.T.reshape(
            self.dims2 + (-1,), order='F'), sig=self.gSig, siz=self.gSiz, nDimBlur=2)**2
        self.rho_buf = np.reshape(self.rho_buf, (self.dims2[0] * self.dims2[1], -1)).T
        self.rho_buf = RingBuffer(self.rho_buf, self.minibatch_shape)
        self.AtA = (self.Ab.T.dot(self.Ab)).toarray()
        self.AtY_buf = self.Ab.T.dot(self.Yr_buf.T)
        self.sv = np.sum(self.rho_buf.get_last_frames(min(self.initbatch, self.minibatch_shape) - 1), 0)
        self.groups = list(map(list, update_order(self.Ab)[0]))
        # self.update_counter = np.zeros(self.N)
        self.update_counter = .5**(-np.linspace(0, 1, self.N, dtype=np.float32))
        self.time_neuron_added = [(neur, self.initbatch) for neur in range(self.N)]
        self.time_spend = 0
        return self
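
A central step in `_prepare_object` above is normalizing each spatial footprint (column of `A2`) to unit L2 norm and rescaling the temporal traces so that the product of footprints and traces stays the same. The following is a minimal, self-contained sketch of that step on made-up data; `A_sim` and `C_sim` are hypothetical stand-ins, not CaImAn objects.

import numpy as np
import scipy.sparse

rng = np.random.default_rng(0)
A_sim = scipy.sparse.csc_matrix(np.maximum(rng.normal(0.2, 0.5, (100, 5)), 0))  # pixels x components
C_sim = rng.random((5, 50))                                                      # components x frames

nA = np.ravel(np.sqrt(A_sim.power(2).sum(0)))            # L2 norm of each footprint column
A_norm = (A_sim @ scipy.sparse.diags(1. / nA)).tocsc()   # unit-norm footprints
C_norm = C_sim * nA[:, None]                             # traces absorb the scale

# the reconstruction A.dot(C) is unchanged by the rescaling
assert np.allclose(A_sim.dot(C_sim), A_norm.dot(C_norm))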
Example 2
def update_num_components(t,
                          sv,
                          Ab,
                          Cf,
                          Yres_buf,
                          Y_buf,
                          rho_buf,
                          dims,
                          gSig,
                          gSiz,
                          ind_A,
                          CY,
                          CC,
                          groups,
                          oases,
                          gnb=1,
                          rval_thr=0.875,
                          bSiz=3,
                          robust_std=False,
                          N_samples_exceptionality=5,
                          remove_baseline=True,
                          thresh_fitness_delta=-80,
                          thresh_fitness_raw=-20,
                          thresh_overlap=0.25,
                          batch_update_suff_stat=False,
                          sn=None,
                          g=None,
                          thresh_s_min=None,
                          s_min=None,
                          Ab_dense=None,
                          max_num_added=1,
                          min_num_trial=1,
                          loaded_model=None,
                          thresh_CNN_noisy=0.99,
                          sniper_mode=False,
                          use_peak_max=False,
                          test_both=False):
    """
    Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
    """

    ind_new = []
    gHalf = np.array(gSiz) // 2

    # total number of components (including background)
    M = np.shape(Ab)[-1]
    N = M - gnb  # number of components (without background)

    sv -= rho_buf.get_first()
    # update variance of residual buffer
    sv += rho_buf.get_last_frames(1).squeeze()
    sv = np.maximum(sv, 0)

    Ains, Cins, Cins_res, inds, ijsig_all, cnn_pos, local_max = get_candidate_components(
        sv,
        dims,
        Yres_buf=Yres_buf,
        min_num_trial=min_num_trial,
        gSig=gSig,
        gHalf=gHalf,
        sniper_mode=sniper_mode,
        rval_thr=rval_thr,
        patch_size=50,
        loaded_model=loaded_model,
        thresh_CNN_noisy=thresh_CNN_noisy,
        use_peak_max=use_peak_max,
        test_both=test_both)

    ind_new_all = ijsig_all

    num_added = len(inds)
    cnt = 0
    for ind, ain, cin, cin_res in zip(inds, Ains, Cins, Cins_res):
        cnt += 1
        ij = np.unravel_index(ind, dims)

        ijSig = [[max(i - temp_g, 0),
                  min(i + temp_g + 1, d)]
                 for i, temp_g, d in zip(ij, gHalf, dims)]
        dims_ain = (np.abs(np.diff(ijSig[1])[0]), np.abs(np.diff(ijSig[0])[0]))

        indeces = np.ravel_multi_index(
            np.ix_(*[np.arange(ij[0], ij[1]) for ij in ijSig]),
            dims,
            order='F').ravel()

        # use sparse Ain only later iff it is actually added to Ab
        Ain = np.zeros((np.prod(dims), 1), dtype=np.float32)
        Ain[indeces, :] = ain[:, None]

        cin_circ = cin.get_ordered()
        useOASIS = False  # whether to use faster OASIS for cell detection
        accepted = True  # flag indicating new component has not been rejected yet

        if Ab_dense is None:
            ff = np.where((Ab.T.dot(Ain).T > thresh_overlap)[:, gnb:])[1] + gnb
        else:
            ff = np.where(
                Ab_dense[indeces, gnb:].T.dot(ain).T > thresh_overlap)[0] + gnb

        if ff.size > 0:
            #                accepted = False
            cc = [corr(cin_circ.copy(), cins) for cins in Cf[ff, :]]
            if np.any(np.array(cc) > .25) and accepted:
                accepted = False  # reject component as duplicate

        if s_min is None:
            s_min = 0
        # use |s_min| * noise estimate * sqrt(1-sum(gamma))
        elif s_min < 0:
            # the formula has been obtained by running OASIS with s_min=0 and lambda=0 on Gaussian noise.
            # e.g. 1 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the root mean square (non-zero) spike size, sqrt(<s^2>)
            #      2 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 95% percentile of (non-zero) spike sizes
            #      3 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 99.7% percentile of (non-zero) spike sizes
            s_min = -s_min * sqrt(
                (ain**2).dot(sn[indeces]**2)) * sqrt(1 - np.sum(g))

        cin_res = cin_res.get_ordered()
        if accepted:
            if useOASIS:
                oas = oasis.OASIS(g=g,
                                  s_min=s_min,
                                  num_empty_samples=t + 1 - len(cin_res))
                for yt in cin_res:
                    oas.fit_next(yt)
                accepted = oas.get_l_of_last_pool() <= t
            else:
                fitness_delta, erfc_delta, std_rr, _ = compute_event_exceptionality(
                    np.diff(cin_res)[None, :],
                    robust_std=robust_std,
                    N=N_samples_exceptionality)
                if remove_baseline:
                    num_samps_bl = min(len(cin_res) // 5, 800)
                    bl = scipy.ndimage.percentile_filter(cin_res,
                                                         8,
                                                         size=num_samps_bl)
                else:
                    bl = 0
                fitness_raw, erfc_raw, std_rr, _ = compute_event_exceptionality(
                    (cin_res - bl)[None, :],
                    robust_std=robust_std,
                    N=N_samples_exceptionality)
                accepted = (fitness_delta < thresh_fitness_delta) or (
                    fitness_raw < thresh_fitness_raw)

#        if accepted:
#            dims_ain = (np.abs(np.diff(ijSig[1])[0]), np.abs(np.diff(ijSig[0])[0]))
#            thrcomp = threshold_components(ain[:,None],
#                                 dims_ain, medw=None, thr_method='max', maxthr=0.2,
#                                 nrgthr=0.99, extract_cc=True,
#                                 se=None, ss=None)
#
#            sznr = np.sum(thrcomp>0)
#            accepted = (sznr >= np.pi*(np.prod(gSig)/4))
#            if not accepted:
#                print('Rejected because of size')

        if accepted:
            # print('adding component' + str(N + 1) + ' at timestep ' + str(t))
            num_added += 1
            ind_new.append(ijSig)

            if oases is not None:
                if not useOASIS:
                    # lambda from Selesnick's 3*sigma*|K| rule
                    # use noise estimate from init batch or use std_rr?
                    #                    sn_ = sqrt((ain**2).dot(sn[indeces]**2)) / sqrt(1 - g**2)
                    sn_ = std_rr
                    oas = oasis.OASIS(
                        np.ravel(g)[0],
                        3 * sn_ / (sqrt(1 - g**2) if np.size(g) == 1 else sqrt(
                            (1 + g[1]) * ((1 - g[1])**2 - g[0]**2) /
                            (1 - g[1]))) if s_min == 0 else 0,
                        s_min,
                        num_empty_samples=t + 1 - len(cin_res),
                        g2=0 if np.size(g) == 1 else g[1])
                    for yt in cin_res:
                        oas.fit_next(yt)

                oases.append(oas)

            Ain_csc = scipy.sparse.csc_matrix(
                (ain, (indeces, [0] * len(indeces))), (np.prod(dims), 1),
                dtype=np.float32)
            if Ab_dense is None:
                groups = update_order(Ab, Ain, groups)[0]
            else:
                groups = update_order(Ab_dense[indeces], ain, groups)[0]
                Ab_dense = np.hstack((Ab_dense, Ain))
            # faster version of scipy.sparse.hstack
            csc_append(Ab, Ain_csc)
            ind_A.append(Ab.indices[Ab.indptr[M]:Ab.indptr[M + 1]])

            tt = t * 1.
            Y_buf_ = Y_buf
            cin_ = cin
            Cf_ = Cf
            cin_circ_ = cin_circ

            CY[M, indeces] = cin_.dot(Y_buf_[:, indeces]) / tt

            # preallocate memory for speed up?
            CC1 = np.hstack([CC, Cf_.dot(cin_circ_ / tt)[:, None]])
            CC2 = np.hstack([(Cf_.dot(cin_circ_)).T,
                             cin_circ_.dot(cin_circ_)]) / tt
            CC = np.vstack([CC1, CC2])
            Cf = np.vstack([Cf, cin_circ])

            N = N + 1
            M = M + 1

            Yres_buf[:, indeces] -= np.outer(cin, ain)
            # vb = imblur(np.reshape(Ain, dims, order='F'), sig=gSig,
            #             siz=gSiz, nDimBlur=2).ravel()
            # restrict blurring to region where component is located
            #            vb = np.reshape(Ain, dims, order='F')
            slices = tuple(
                slice(max(0, ijs[0] - 2 * sg), min(d, ijs[1] + 2 * sg))
                for ijs, sg, d in zip(ijSig, gSiz // 2, dims))  # is 2 enough?

            slice_within = tuple(
                slice(ijs[0] - sl.start, ijs[1] - sl.start)
                for ijs, sl in zip(ijSig, slices))

            ind_vb = np.ravel_multi_index(
                np.ix_(*[np.arange(ij[0], ij[1]) for ij in ijSig]),
                dims,
                order='C').ravel()

            vb_buf = [
                imblur(np.maximum(
                    0,
                    vb.reshape(dims, order='F')[slices][slice_within]),
                       sig=gSig,
                       siz=gSiz,
                       nDimBlur=len(dims)) for vb in Yres_buf
            ]

            vb_buf2 = np.stack([vb.ravel() for vb in vb_buf])

            #            ind_vb = np.ravel_multi_index(
            #                    np.ix_(*[np.arange(s.start, s.stop)
            #                           for s in slices_small]), dims).ravel()

            rho_buf[:, ind_vb] = vb_buf2**2

            sv[ind_vb] = np.sum(rho_buf[:, ind_vb], 0)
#            sv = np.sum([imblur(vb.reshape(dims,order='F'), sig=gSig, siz=gSiz, nDimBlur=len(dims))**2 for vb in Yres_buf], 0).reshape(-1)
#            plt.subplot(1,5,4)
#            plt.cla()
#            plt.imshow(sv.reshape(dims), vmax=30)
#            plt.pause(.05)
#            plt.subplot(1,5,5)
#            plt.cla()
#            plt.imshow(Yres_buf.mean(0).reshape(dims,order='F'))
#            plt.imshow(np.sum([imblur(vb.reshape(dims,order='F'),\
#                                       sig=gSig, siz=gSiz, nDimBlur=len(dims))**2\
#                                        for vb in Yres_buf],axis=0), vmax=30)
#            plt.pause(.05)

#print(np.min(sv))
#    plt.subplot(1,3,3)
#    plt.cla()
#    plt.imshow(Yres_buf.mean(0).reshape(dims, order = 'F'))
#    plt.pause(.05)
    return Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups, ind_new, ind_new_all, sv, cnn_pos
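
The patch bookkeeping used throughout `update_num_components` (`ijSig`, `indeces`) is easy to check in isolation. Below is a self-contained sketch, with made-up values for `dims`, `gSiz` and the peak index, of how the flat pixel indices of a `gSiz`-sized box clipped to the field of view are computed; it mirrors the `np.ravel_multi_index`/`np.ix_` calls above but is not taken from the CaImAn code.

import numpy as np

dims = (60, 80)                    # field of view: height x width (hypothetical)
gSiz = np.array([13, 13])          # expected neuron size in pixels (hypothetical)
gHalf = gSiz // 2

ind = 2510                         # hypothetical argmax of the residual energy sv
ij = np.unravel_index(ind, dims)   # peak coordinates
ijSig = [[max(i - gh, 0), min(i + gh + 1, d)]
         for i, gh, d in zip(ij, gHalf, dims)]   # bounding box clipped to the FOV

# flat indices of the patch pixels, column-major ('F') as used for the
# vectorized movie and residual buffers
indeces = np.ravel_multi_index(
    np.ix_(*[np.arange(lo, hi) for lo, hi in ijSig]),
    dims, order='F').ravel()

# these indices select the pixel axis of the residual buffer,
# e.g. Ypx = Yres_buf.T[indeces, :]
print(ij, ijSig, indeces.shape)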
Example 3
def update_num_components(t,
                          sv,
                          Ab,
                          Cf,
                          Yres_buf,
                          Y_buf,
                          rho_buf,
                          dims,
                          gSig,
                          gSiz,
                          ind_A,
                          CY,
                          CC,
                          groups,
                          oases,
                          gnb=1,
                          rval_thr=0.875,
                          bSiz=3,
                          robust_std=False,
                          N_samples_exceptionality=5,
                          remove_baseline=True,
                          thresh_fitness_delta=-20,
                          thresh_fitness_raw=-20,
                          thresh_overlap=0.5,
                          batch_update_suff_stat=False,
                          sn=None,
                          g=None,
                          lam=0,
                          thresh_s_min=None,
                          s_min=None,
                          Ab_dense=None,
                          max_num_added=1):

    gHalf = np.array(gSiz) // 2

    M = np.shape(Ab)[-1]
    N = M - gnb

    #    Yres = np.array(Yres_buf).T
    #    Y = np.array(Y_buf).T
    #    rhos = np.array(rho_buf).T

    first = True

    sv -= rho_buf.get_first()
    sv += rho_buf.get_last_frames(1).squeeze()

    num_added = 0
    while num_added < max_num_added:

        if first:
            sv_ = sv.copy()  # np.sum(rho_buf,0)
            first = False

        ind = np.argmax(sv_)
        ij = np.unravel_index(ind, dims)
        # ijSig = [[np.maximum(ij[c] - gHalf[c], 0), np.minimum(ij[c] + gHalf[c] + 1, dims[c])]
        #          for c in range(len(ij))]
        # faster than the commented-out NumPy call and loop above
        ijSig = [[
            max(ij[0] - gHalf[0], 0),
            min(ij[0] + gHalf[0] + 1, dims[0])
        ], [max(ij[1] - gHalf[1], 0),
            min(ij[1] + gHalf[1] + 1, dims[1])]]

        # xySig = np.meshgrid(*[np.arange(s[0], s[1]) for s in ijSig], indexing='xy')
        # arr = np.array([np.reshape(s, (1, np.size(s)), order='F').squeeze()
        #                 for s in xySig], dtype=np.int)
        # indeces = np.ravel_multi_index(arr, dims, order='F')
        indeces = np.ravel_multi_index(np.ix_(
            np.arange(ijSig[0][0], ijSig[0][1]),
            np.arange(ijSig[1][0], ijSig[1][1])),
                                       dims,
                                       order='F').ravel()

        Ypx = Yres_buf.T[indeces, :]

        ain = np.maximum(np.mean(Ypx, 1), 0)
        na = ain.dot(ain)
        if not na:
            break

        ain /= sqrt(na)

        #        new_res = sv_.copy()
        #        new_res[ np.ravel_multi_index(arr, dims, order='C')] = 10000
        #        cv2.imshow('untitled', 0.1*cv2.resize(new_res.reshape(dims,order = 'C'),(512,512))/2000)
        #        cv2.waitKey(1)

        #        for iter_ in range(15):
        #            cin_res = ain.T.dot(Ypx) / ain.dot(ain)
        #            cin = np.maximum(cin_res, 0)
        #            ain = np.maximum(Ypx.dot(cin.T) / cin.dot(cin), 0)

        ain, cin, cin_res = rank1nmf(Ypx,
                                     ain)  # expects and returns normalized ain

        rval = corr(ain.copy(), np.mean(Ypx, -1))
        #        print(rval)
        if rval > rval_thr:
            # na = sqrt(ain.dot(ain))
            # ain /= na
            # cin = na * cin
            # use sparse Ain only later iff it is actually added to Ab
            Ain = np.zeros((np.prod(dims), 1), dtype=np.float32)
            # Ain = scipy.sparse.csc_matrix((np.prod(dims), 1), dtype=np.float32)
            Ain[indeces, :] = ain[:, None]

            cin_circ = cin.get_ordered()

            #        indeces_good = (Ain[indeces]>0.01).nonzero()[0]

            # rval = np.corrcoef(ain, np.mean(Ypx, -1))[0, 1]

            # rval =
            # np.corrcoef(Ain[indeces_good].toarray().squeeze(),np.mean(Yres[indeces_good,:],-1))[0,1]

            # if rval > rval_thr:
            #            pl.cla()
            #            _ = cm.utils.visualization.plot_contours(Ain, sv.reshape(dims), thr=0.95)
            #            pl.pause(0.01)

            useOASIS = False  # whether to use faster OASIS for cell detection
            if Ab_dense is None:
                ff = np.where(
                    (Ab.T.dot(Ain).T > thresh_overlap)[:, gnb:])[1] + gnb
            else:
                ff = np.where(
                    Ab_dense[indeces,
                             gnb:].T.dot(ain).T > thresh_overlap)[0] + gnb
            if ff.size > 0:
                cc = [corr(cin_circ.copy(), cins) for cins in Cf[ff, :]]

                if np.any(np.array(cc) > .8):
                    #                    repeat = False
                    # vb = imblur(np.reshape(Ain, dims, order='F'),
                    #             sig=gSig, siz=gSiz, nDimBlur=2)
                    # restrict blurring to region where component is located
                    vb = np.reshape(Ain, dims, order='F')
                    slices = tuple(
                        slice(max(0, ijs[0] - 2 * sg),
                              min(d, ijs[1] + 2 * sg)) for ijs, sg, d in zip(
                                  ijSig, gSig, dims))  # is 2 enough?
                    vb[slices] = imblur(vb[slices],
                                        sig=gSig,
                                        siz=gSiz,
                                        nDimBlur=2)
                    sv_ -= (vb.ravel()**2) * cin.dot(cin)

                    #                    pl.imshow(np.reshape(sv,dims));pl.pause(0.001)
                    # print('Overlap at step' + str(t) + ' ' + str(cc))
                    break

            if s_min is None:  # use thresh_s_min * noise estimate * sqrt(1-sum(gamma))
                # the formula has been obtained by running OASIS with s_min=0 and lambda=0 on Gaussian noise.
                # e.g. 1 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the root mean square (non-zero) spike size, sqrt(<s^2>)
                #      2 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 95% percentile of (non-zero) spike sizes
                #      3 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 99.7% percentile of (non-zero) spike sizes
                s_min = 0 if thresh_s_min is None else thresh_s_min * \
                    sqrt((ain**2).dot(sn[indeces]**2)) * sqrt(1 - np.sum(g))

            cin_res = cin_res.get_ordered()
            if useOASIS:
                oas = oasis.OASIS(g=g,
                                  s_min=s_min,
                                  num_empty_samples=t + 1 - len(cin_res))
                for yt in cin_res:
                    oas.fit_next(yt)
                foo = oas.get_l_of_last_pool() <= t
                # cc=oas.c
                # print([np.corrcoef(cin_circ,cins)[0,1] for cins in Cf[overlap[0] > 0]])
                # print([np.corrcoef(cc,cins)[0,1] for cins in Cf[overlap[0] > 0, ]])
                # import matplotlib.pyplot as plt
                # plt.plot(cin_res); plt.plot(cc); plt.show()
                # import pdb;pdb.set_trace()
            else:
                fitness_delta, erfc_delta, std_rr, _ = compute_event_exceptionality(
                    np.diff(cin_res)[None, :],
                    robust_std=robust_std,
                    N=N_samples_exceptionality)
                if remove_baseline:
                    num_samps_bl = min(len(cin_res) // 5, 800)
                    bl = scipy.ndimage.percentile_filter(cin_res,
                                                         8,
                                                         size=num_samps_bl)
                else:
                    bl = 0
                fitness_raw, erfc_raw, std_rr, _ = compute_event_exceptionality(
                    (cin_res - bl)[None, :],
                    robust_std=robust_std,
                    N=N_samples_exceptionality)
                foo = (fitness_delta < thresh_fitness_delta) or (
                    fitness_raw < thresh_fitness_raw)

            if foo:
                # print('adding component' + str(N + 1) + ' at timestep ' + str(t))
                num_added += 1
                #                ind_a = uniform_filter(np.reshape(Ain.toarray(), dims, order='F'), size=bSiz)
                #                ind_a = np.reshape(ind_a > 1e-10, (np.prod(dims),), order='F')
                #                indeces_good = np.where(ind_a)[0]#np.where(determine_search_location(Ain,dims))[0]
                if not useOASIS:
                    # TODO: decide on a line to use for setting the parameters
                    # # lambda from init batch (e.g. mean of lambdas)  or  s_min if either s_min or s_min_thresh are given
                    # oas = oasis.OASIS(g=np.ravel(g)[0], lam if s_min==0 else 0, s_min, num_empty_samples=t + 1 - len(cin_res),
                    #                   g2=0 if np.size(g) == 1 else g[1])
                    # or
                    # lambda from Selesnick's 3*sigma*|K| rule
                    # use noise estimate from init batch or use std_rr?
                    #                    sn_ = sqrt((ain**2).dot(sn[indeces]**2)) / sqrt(1 - g**2)
                    sn_ = std_rr
                    oas = oasis.OASIS(
                        np.ravel(g)[0],
                        3 * sn_ / (sqrt(1 - g**2) if np.size(g) == 1 else sqrt(
                            (1 + g[1]) * ((1 - g[1])**2 - g[0]**2) /
                            (1 - g[1]))) if s_min == 0 else 0,
                        s_min,
                        num_empty_samples=t + 1 - len(cin_res),
                        g2=0 if np.size(g) == 1 else g[1])
                    for yt in cin_res:
                        oas.fit_next(yt)

                oases.append(oas)

                Ain_csc = scipy.sparse.csc_matrix(
                    (ain, (indeces, [0] * len(indeces))), (np.prod(dims), 1),
                    dtype=np.float32)

                if Ab_dense is None:
                    groups = update_order(Ab, Ain, groups)[0]
                else:
                    groups = update_order(Ab_dense[indeces], ain, groups)[0]
                csc_append(Ab,
                           Ain_csc)  # faster version of scipy.sparse.hstack
                ind_A.append(Ab.indices[Ab.indptr[M]:Ab.indptr[M + 1]])

                #                ccf = Cf[:,-minibatch_suff_stat:]
                #                CY = ((t*1.-1)/t) * CY + (1./t) * np.dot(ccf, Yr[:, t-minibatch_suff_stat:t].T)/minibatch_suff_stat
                #                CC = ((t*1.-1)/t) * CC + (1./t) * ccf.dot(ccf.T)/minibatch_suff_stat

                tt = t * 1.
                #                if batch_update_suff_stat and Y_buf.cur<len(Y_buf)-1:
                #                   Y_buf_ = Y_buf[Y_buf.cur+1:,:]
                #                   cin_ = cin[Y_buf.cur+1:]
                #                   n_fr_ = len(cin_)
                #                   cin_circ_= cin_circ[-n_fr_:]
                #                   Cf_ = Cf[:,-n_fr_:]
                #                else:
                Y_buf_ = Y_buf
                cin_ = cin
                Cf_ = Cf
                cin_circ_ = cin_circ

                #                CY[M, :] = Y_buf_.T.dot(cin_)[None, :] / tt
                # much faster: exploit that we only access CY[m, ind_pixels],
                # hence update only these
                CY[M, indeces] = cin_.dot(Y_buf_[:, indeces]) / tt
                #                CY = np.vstack([CY[:N,:], Y_buf.T.dot(cin / tt)[None,:], CY[ N:,:]])
                #                YC = CY.T
                #                YC = np.hstack([YC[:, :N], Y_buf.T.dot(cin / tt)[:, None], YC[:, N:]])
                #                CY = YC.T

                # preallocate memory for speed up?
                CC1 = np.hstack([CC, Cf_.dot(cin_circ_ / tt)[:, None]])
                CC2 = np.hstack([(Cf_.dot(cin_circ_)).T,
                                 cin_circ_.dot(cin_circ_)]) / tt
                CC = np.vstack([CC1, CC2])
                Cf = np.vstack([Cf, cin_circ])

                N = N + 1
                M = M + 1

                Yres_buf[:, indeces] -= np.outer(cin, ain)
                # vb = imblur(np.reshape(Ain, dims, order='F'), sig=gSig,
                #             siz=gSiz, nDimBlur=2).ravel()
                # restrict blurring to region where component is located
                vb = np.reshape(Ain, dims, order='F')
                slices = tuple(
                    slice(max(0, ijs[0] - 2 * sg), min(d, ijs[1] + 2 * sg))
                    for ijs, sg, d in zip(ijSig, gSig, dims))  # is 2 enough?
                vb[slices] = imblur(vb[slices], sig=gSig, siz=gSiz, nDimBlur=2)
                vb = vb.ravel()

                # ind_vb = np.where(vb)[0]
                ind_vb = np.ravel_multi_index(
                    np.ix_(*[np.arange(s.start, s.stop) for s in slices]),
                    dims).ravel()

                updt_res = (vb[None, ind_vb].T**2).dot(cin[None, :]**2).T
                rho_buf[:, ind_vb] -= updt_res
                updt_res_sum = np.sum(updt_res, 0)
                sv[ind_vb] -= updt_res_sum
                sv_[ind_vb] -= updt_res_sum

            else:

                num_added = max_num_added

        else:

            num_added = max_num_added

    return Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups
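
The commented-out alternating updates near the top of this example indicate what `rank1nmf` is used for here: a nonnegative rank-1 factorization of the residual patch into a unit-norm spatial footprint `ain` and a temporal trace `cin`. The sketch below reproduces that alternating scheme on synthetic data; it is an illustration of the idea, not the CaImAn implementation.

import numpy as np
from math import sqrt

rng = np.random.default_rng(1)
a_true = np.maximum(rng.normal(1.0, 0.3, 30), 0)            # ground-truth footprint (patch pixels)
c_true = np.maximum(rng.normal(0.0, 1.0, 100), 0)           # ground-truth trace (frames)
Ypx = np.outer(a_true, c_true) + 0.05 * rng.normal(size=(30, 100))

ain = np.maximum(np.mean(Ypx, 1), 0)                        # initial spatial guess
ain /= sqrt(ain.dot(ain))

for _ in range(15):
    cin = np.maximum(ain.dot(Ypx) / ain.dot(ain), 0)        # temporal update
    ain = np.maximum(Ypx.dot(cin) / cin.dot(cin), 0)        # spatial update
    ain /= sqrt(ain.dot(ain))                               # keep ain at unit norm

# ain recovers the true footprint up to scale
print(np.corrcoef(ain, a_true)[0, 1])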
Example 4
def update_num_components(t,
                          sv,
                          Ab,
                          Cf,
                          Yres_buf,
                          Y_buf,
                          rho_buf,
                          dims,
                          gSig,
                          gSiz,
                          ind_A,
                          CY,
                          CC,
                          groups,
                          oases,
                          gnb=1,
                          rval_thr=0.875,
                          bSiz=3,
                          robust_std=False,
                          N_samples_exceptionality=5,
                          remove_baseline=True,
                          thresh_fitness_delta=-80,
                          thresh_fitness_raw=-20,
                          thresh_overlap=0.25,
                          batch_update_suff_stat=False,
                          sn=None,
                          g=None,
                          thresh_s_min=None,
                          s_min=None,
                          Ab_dense=None,
                          max_num_added=1,
                          min_num_trial=1):
    """
    Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests    
    """

    order_rvl = 'C'
    gHalf = np.array(gSiz) // 2

    # total number of components (including background)
    M = np.shape(Ab)[-1]
    N = M - gnb  # number of components (without background)

    first = True

    sv -= rho_buf.get_first()
    # update variance of residual buffer
    sv += rho_buf.get_last_frames(1).squeeze()

    num_added = 0
    cnt = 0
    while num_added < max_num_added:
        cnt += 1
        if first:
            sv_ = sv.copy()  # np.sum(rho_buf,0)
            first = False

        ind = np.argmax(sv_)
        ij = np.unravel_index(ind, dims, order=order_rvl)
        # ijSig = [[np.maximum(ij[c] - gHalf[c], 0), np.minimum(ij[c] + gHalf[c] + 1, dims[c])]
        #          for c in range(len(ij))]
        # faster than the commented-out NumPy call and loop above

        #        ijSig = [[max(ij[0] - gHalf[0], 0), min(ij[0] + gHalf[0] + 1, dims[0])],
        #                 [max(ij[1] - gHalf[1], 0), min(ij[1] + gHalf[1] + 1, dims[1])]]

        ijSig = [[max(i - gh, 0), min(i + gh + 1, d)]
                 for i, gh, d in zip(ij, gHalf, dims)]

        # xySig = np.meshgrid(*[np.arange(s[0], s[1]) for s in ijSig], indexing='xy')
        # arr = np.array([np.reshape(s, (1, np.size(s)), order='F').squeeze()
        #                 for s in xySig], dtype=np.int)
        # indeces = np.ravel_multi_index(arr, dims, order='F')

        #        indeces = np.ravel_multi_index(np.ix_(np.arange(ijSig[0][0], ijSig[0][1]),
        #                                              np.arange(ijSig[1][0], ijSig[1][1])),
        #                                       dims, order='F').ravel(order=order_rvl)

        indeces = np.ravel_multi_index(
            np.ix_(*[np.arange(ij[0], ij[1]) for ij in ijSig]),
            dims,
            order='F').ravel(order=order_rvl)

        indeces_ = np.ravel_multi_index(
            np.ix_(*[np.arange(ij[0], ij[1]) for ij in ijSig]),
            dims,
            order='C').ravel(order=order_rvl)

        #        indeces_ = np.ravel_multi_index(np.ix_(np.arange(ijSig[0][0], ijSig[0][1]),
        #                                               np.arange(ijSig[1][0], ijSig[1][1])),
        #                                        dims, order='C').ravel(order=order_rvl)

        Ypx = Yres_buf.T[indeces, :]

        ain = np.maximum(np.mean(Ypx, 1), 0)
        na = ain.dot(ain)
        if not na:
            break

        ain /= sqrt(na)

        #        new_res = sv_.copy()
        #        new_res[ np.ravel_multi_index(arr, dims, order='C')] = 10000

        # expects and returns normalized ain
        ain, cin, cin_res = rank1nmf(Ypx, ain)
        # correlation coefficient
        rval = corr(ain.copy(), np.mean(Ypx, -1))

        if rval > rval_thr:
            # use sparse Ain only later iff it is actually added to Ab
            Ain = np.zeros((np.prod(dims), 1), dtype=np.float32)
            Ain[indeces, :] = ain[:, None]

            cin_circ = cin.get_ordered()

            #        indeces_good = (Ain[indeces]>0.01).nonzero()[0]

            useOASIS = False  # whether to use faster OASIS for cell detection
            foo = True  # flag indicating new component has not been rejected yet

            if Ab_dense is None:
                ff = np.where(
                    (Ab.T.dot(Ain).T > thresh_overlap)[:, gnb:])[1] + gnb
            else:
                ff = np.where(
                    Ab_dense[indeces,
                             gnb:].T.dot(ain).T > thresh_overlap)[0] + gnb
            if ff.size > 0:
                cc = [corr(cin_circ.copy(), cins) for cins in Cf[ff, :]]
                if np.any(np.array(cc) > .25):
                    #                    repeat = False
                    # vb = imblur(np.reshape(Ain, dims, order='F'),
                    #             sig=gSig, siz=gSiz, nDimBlur=2)
                    # restrict blurring to region where component is located
                    vb = np.reshape(Ain, dims, order='C')
                    slices = tuple(
                        slice(max(0, ijs[0] - 2 * sg),
                              min(d, ijs[1] + 2 * sg)) for ijs, sg, d in zip(
                                  ijSig, gSig, dims))  # is 2 enough?
                    vb[slices] = imblur(vb[slices],
                                        sig=gSig,
                                        siz=gSiz,
                                        nDimBlur=len(dims))
                    sv_ -= (vb.ravel(order=order_rvl)**2) * cin.dot(cin)
                    foo = False  # reject component as duplicate
#                    pl.imshow(np.reshape(sv,dims));pl.pause(0.001)
#  print('Overlap at step' + str(t) + ' ' + str(cc))
# break

            # use thresh_s_min * noise estimate * sqrt(1-sum(gamma))
            if s_min is None:
                # the formula has been obtained by running OASIS with s_min=0 and lambda=0 on Gaussian noise.
                # e.g. 1 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the root mean square (non-zero) spike size, sqrt(<s^2>)
                #      2 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 95% percentile of (non-zero) spike sizes
                #      3 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 99.7% percentile of (non-zero) spike sizes
                s_min = 0 if thresh_s_min is None else thresh_s_min * \
                    sqrt((ain**2).dot(sn[indeces]**2)) * sqrt(1 - np.sum(g))

            cin_res = cin_res.get_ordered()
            if foo:
                if useOASIS:
                    oas = oasis.OASIS(g=g,
                                      s_min=s_min,
                                      num_empty_samples=t + 1 - len(cin_res))
                    for yt in cin_res:
                        oas.fit_next(yt)
                    foo = oas.get_l_of_last_pool() <= t
                else:
                    fitness_delta, erfc_delta, std_rr, _ = compute_event_exceptionality(
                        np.diff(cin_res)[None, :],
                        robust_std=robust_std,
                        N=N_samples_exceptionality)
                    if remove_baseline:
                        num_samps_bl = min(len(cin_res) // 5, 800)
                        bl = scipy.ndimage.percentile_filter(cin_res,
                                                             8,
                                                             size=num_samps_bl)
                    else:
                        bl = 0
                    fitness_raw, erfc_raw, std_rr, _ = compute_event_exceptionality(
                        (cin_res - bl)[None, :],
                        robust_std=robust_std,
                        N=N_samples_exceptionality)
                    foo = (fitness_delta < thresh_fitness_delta) or (
                        fitness_raw < thresh_fitness_raw)

            if foo:
                # print('adding component' + str(N + 1) + ' at timestep ' + str(t))
                num_added += 1
                #                ind_a = uniform_filter(np.reshape(Ain.toarray(), dims, order='F'), size=bSiz)
                #                ind_a = np.reshape(ind_a > 1e-10, (np.prod(dims),), order='F')
                #                indeces_good = np.where(ind_a)[0]#np.where(determine_search_location(Ain,dims))[0]
                if oases is not None:
                    if not useOASIS:
                        # lambda from Selesnick's 3*sigma*|K| rule
                        # use noise estimate from init batch or use std_rr?
                        #                    sn_ = sqrt((ain**2).dot(sn[indeces]**2)) / sqrt(1 - g**2)
                        sn_ = std_rr
                        oas = oasis.OASIS(
                            np.ravel(g)[0],
                            3 * sn_ /
                            (sqrt(1 - g**2) if np.size(g) == 1 else sqrt(
                                (1 + g[1]) * ((1 - g[1])**2 - g[0]**2) /
                                (1 - g[1]))) if s_min == 0 else 0,
                            s_min,
                            num_empty_samples=t + 1 - len(cin_res),
                            g2=0 if np.size(g) == 1 else g[1])
                        for yt in cin_res:
                            oas.fit_next(yt)

                    oases.append(oas)

                Ain_csc = scipy.sparse.csc_matrix(
                    (ain, (indeces, [0] * len(indeces))), (np.prod(dims), 1),
                    dtype=np.float32)

                if Ab_dense is None:
                    groups = update_order(Ab, Ain, groups)[0]
                else:
                    groups = update_order(Ab_dense[indeces], ain, groups)[0]
                # faster version of scipy.sparse.hstack
                csc_append(Ab, Ain_csc)
                ind_A.append(Ab.indices[Ab.indptr[M]:Ab.indptr[M + 1]])

                tt = t * 1.
                #                if batch_update_suff_stat and Y_buf.cur<len(Y_buf)-1:
                #                   Y_buf_ = Y_buf[Y_buf.cur+1:,:]
                #                   cin_ = cin[Y_buf.cur+1:]
                #                   n_fr_ = len(cin_)
                #                   cin_circ_= cin_circ[-n_fr_:]
                #                   Cf_ = Cf[:,-n_fr_:]
                #                else:
                Y_buf_ = Y_buf
                cin_ = cin
                Cf_ = Cf
                cin_circ_ = cin_circ

                #                CY[M, :] = Y_buf_.T.dot(cin_)[None, :] / tt
                # much faster: exploit that we only access CY[m, ind_pixels],
                # hence update only these
                CY[M, indeces] = cin_.dot(Y_buf_[:, indeces]) / tt

                # preallocate memory for speed up?
                CC1 = np.hstack([CC, Cf_.dot(cin_circ_ / tt)[:, None]])
                CC2 = np.hstack([(Cf_.dot(cin_circ_)).T,
                                 cin_circ_.dot(cin_circ_)]) / tt
                CC = np.vstack([CC1, CC2])
                Cf = np.vstack([Cf, cin_circ])

                N = N + 1
                M = M + 1

                Yres_buf[:, indeces] -= np.outer(cin, ain)
                # vb = imblur(np.reshape(Ain, dims, order='F'), sig=gSig,
                #             siz=gSiz, nDimBlur=2).ravel()
                # restrict blurring to region where component is located
                vb = np.reshape(Ain, dims, order='F')
                slices = tuple(
                    slice(max(0, ijs[0] - 2 * sg), min(d, ijs[1] + 2 * sg))
                    for ijs, sg, d in zip(ijSig, gSig, dims))  # is 2 enough?
                vb[slices] = imblur(vb[slices],
                                    sig=gSig,
                                    siz=gSiz,
                                    nDimBlur=len(dims))
                vb = vb.ravel(order=order_rvl)

                # ind_vb = np.where(vb)[0]
                ind_vb = np.ravel_multi_index(
                    np.ix_(*[np.arange(s.start, s.stop) for s in slices]),
                    dims,
                    order=order_rvl).ravel(order=order_rvl)

                updt_res = (vb[None, ind_vb].T**2).dot(cin[None, :]**2).T
                rho_buf[:, ind_vb] -= updt_res
                updt_res_sum = np.sum(updt_res, 0)
                sv[ind_vb] -= updt_res_sum
                sv_[ind_vb] -= updt_res_sum

            else:
                if cnt >= min_num_trial:
                    num_added = max_num_added
                else:
                    first = False
                    sv_[indeces_] = 0

        else:
            if cnt >= min_num_trial:
                num_added = max_num_added
            else:
                first = False
                sv_[indeces_] = 0

    return Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups
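
When no explicit `s_min` is passed, the threshold on spike size is derived from `thresh_s_min`, the noise level projected onto the footprint, and the AR coefficients, as described in the comment block above. The snippet below is a worked sketch of that formula with made-up numbers, not values from a real recording.

import numpy as np
from math import sqrt

thresh_s_min = 3.0                          # roughly the 99.7th percentile of noise-driven spike sizes
g = np.array([0.95])                        # AR(1) coefficient of the calcium dynamics (hypothetical)
ain = np.array([0.1, 0.5, 0.7, 0.5, 0.1])   # footprint on the patch (hypothetical)
ain /= sqrt(ain.dot(ain))                   # unit norm, as in the code above
sn = np.array([1.2, 1.0, 0.9, 1.1, 1.3])    # per-pixel noise estimates on the patch (hypothetical)

sigma = sqrt((ain**2).dot(sn**2))           # noise std of the extracted trace
s_min = thresh_s_min * sigma * sqrt(1 - np.sum(g))
print(s_min)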